I have saved a model during training. However, I am facing issues when trying to load it from a saved checkpoint. The model class name is CSLRModel. When I call CSLRModel.load_from_checkpoint(ckpt_path, **kwargs), I get the following error:
    return CSLRModel.load_from_checkpoint(ckpt_path, encoder_seq, config, config.num_classes_gloss)
  File "/data/envs/ohdev/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 137, in load_from_checkpoint
    return _load_from_checkpoint(
  File "/data/envs/ohdev/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 158, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/data/envs/ohdev/lib/python3.8/site-packages/lightning_lite/utilities/cloud_io.py", line 48, in _load
    return torch.load(f, map_location=map_location)
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/serialization.py", line 789, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/serialization.py", line 1131, in _load
    result = unpickler.load()
  File "/usr/lib/python3.8/pickle.py", line 1212, in load
    dispatch[key[0]](self)
  File "/usr/lib/python3.8/pickle.py", line 1253, in load_binpersid
    self.append(self.persistent_load(pid))
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/serialization.py", line 1101, in persistent_load
    load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/serialization.py", line 1083, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/serialization.py", line 1058, in restore_location
    result = map_location(storage, location)
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/envs/ohdev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 246, in _forward_unimplemented
    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
NotImplementedError: Module [ModuleList] is missing the required "forward" function
I am not sure what I am doing wrong here. The CSLRModel class does have a forward function, and I was able to use the model class to train successfully and save checkpoints. It is only when loading from a checkpoint that I am facing an error.
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn

# Decode, initialize_losses, loss_fn and compute_external_losses are defined elsewhere in the project.

class CSLRModel(pl.LightningModule):
    def __init__(self, encoder_seq, config, num_classes_gloss):
        super().__init__()
        self.save_hyperparameters()
        self.epoch_no = 0
        ## path hardcoded. Resolve
        gloss_dict = np.load('/data/cslr_datasets/PHOENIX-2014/phoenix2014-release/phoenix-2014-multisigner/preprocess/gloss_dict.npy', allow_pickle=True).item()
        self.decoder = Decode(gloss_dict, num_classes_gloss, 'beam')
        self.config = config
        self.encoder_seq = encoder_seq
        self.num_encoders = len(encoder_seq)
        self.num_classes_gloss = num_classes_gloss
        self.classifiers = nn.ModuleDict()  ## to use features of an encoder to classify
        for enc in self.encoder_seq:
            if enc.use_ctc:
                self.classifiers[f'{enc.encoder_id}'] = nn.Linear(enc.out_size, self.num_classes_gloss)
        self.externals = {}
        self.initialize_losses(config.losses)

    def forward(self, x, len_x, label, len_label, is_training=True):
        for i, enc in enumerate(self.encoder_seq):
            x, len_x, internal_losses = enc(x, len_x)
            self.loss_value.update(internal_losses)
            if enc.use_ctc:
                logits = self.classifiers[f'{enc.encoder_id}'](x)
                self.externals[f'encoder{i+1}.logits'] = logits
                self.loss_value[f'{enc.encoder_id}.CTCLoss'] = self.loss_fn['CTC'](
                    logits.transpose(0, 1).log_softmax(-1),
                    label.cpu().int(),
                    len_x.cpu().int(),
                    len_label.cpu().int()).mean()
        return self.compute_external_losses(), logits, len_x
I am pasting only the __init__() and forward() functions of the model definition; I am not pasting the rest since it would make the post too long.
Note: I was able to load the model from the checkpoint when I used self.save_hyperparameters() in __init__(). Is this required?
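For reference, the init arguments saved by self.save_hyperparameters() can be inspected directly in the checkpoint file; this is a rough sketch, assuming a standard Lightning checkpoint, which stores them under the "hyper_parameters" key:

import torch

ckpt = torch.load(ckpt_path, map_location='cpu')
# When save_hyperparameters() was used in __init__, the init arguments show up here,
# and load_from_checkpoint can rebuild the model without extra arguments:
print(ckpt.get('hyper_parameters'))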
Hey, it makes a ton of sense now.
Here is how load_from_checkpoint works internally:
1.) We instantiate the class (CSLRModel) with the necessary init arguments
2.) We load the state dict into the class instance
For 1.) we need to get the init arguments from somewhere. There are 2 options to do this:
- save_hyperparameters() just serializes the init arguments so that you don't need to do anything (at least for easily serializable stuff; it is different if you pass an nn.Module)
- You pass them in as keyword arguments to load_from_checkpoint
Since you didn't do either of these, it makes sense that you couldn't load the model. Hope that makes it a bit clearer!
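For illustration, here is a minimal sketch of both options with a toy model (LitModel, in_dim, out_dim and "model.ckpt" are placeholders, not your actual code):

import pytorch_lightning as pl
import torch.nn as nn

class LitModel(pl.LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Option 1: record the init arguments in the checkpoint so that
        # load_from_checkpoint can re-instantiate the class by itself.
        self.save_hyperparameters()
        self.layer = nn.Linear(in_dim, out_dim)

# Option 1: the init arguments are read back from the checkpoint.
model = LitModel.load_from_checkpoint("model.ckpt")

# Option 2: if save_hyperparameters() was not called during training, supply the
# init arguments yourself as keyword arguments (they must match the trained model).
model = LitModel.load_from_checkpoint("model.ckpt", in_dim=128, out_dim=10)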
Cheers,
Justus
Yeah, I had read about these two options in the docs, and I tried both. My query was as follows:
1] Approach 1: Use self.save_hyperparameters() in the model's __init__() and only pass ckpt_path when loading at test time, i.e., CSLRModel.load_from_checkpoint(ckpt_path). This one works for me, but I do not prefer it since the self.save_hyperparameters() call takes rather long to execute.
2] Approach 2: Do not use self.save_hyperparameters() and load at test time as follows: CSLRModel.load_from_checkpoint(ckpt_path, encoder_seq, config, num_classes_gloss), i.e., pass both the saved checkpoint path and the init arguments. However, when I use this approach, I get the above-mentioned error:
NotImplementedError: Module [ModuleList] is missing the required "forward" function
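One possible way to address both points, sketched under the assumption that save_hyperparameters() is slow because it pickles encoder_seq (an nn.ModuleList), and that the positional arguments in Approach 2 end up being treated as load_from_checkpoint's map_location parameter (which would explain the ModuleList "forward" error):

import pytorch_lightning as pl

# Approach 1 variant: keep save_hyperparameters() but skip the expensive nn.ModuleList.
class CSLRModel(pl.LightningModule):
    def __init__(self, encoder_seq, config, num_classes_gloss):
        super().__init__()
        # config and num_classes_gloss are still written to the checkpoint;
        # encoder_seq is not pickled, so the call stays cheap.
        self.save_hyperparameters(ignore=['encoder_seq'])
        ...  # rest of __init__ as before

# The ignored argument then has to be supplied again at load time:
model = CSLRModel.load_from_checkpoint(ckpt_path, encoder_seq=encoder_seq)

# Approach 2 variant: no save_hyperparameters() at all; pass every init argument
# as a keyword so none of them is mistaken for map_location or hparams_file:
model = CSLRModel.load_from_checkpoint(
    ckpt_path,
    encoder_seq=encoder_seq,
    config=config,
    num_classes_gloss=config.num_classes_gloss,
)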