Whenever I try to save the entire model, a warning shows up, but it saves without any errors.
checkpoint = {'model': ResNet9(), 'state_dict': model.state_dict()}
torch.save(checkpoint, 'model.pth')
But when I try loading it, it gives me this error.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-1-29ed9644f643> in <module>
10 return model
---> 12 model = load_checkpoint('face.pth')
<ipython-input-1-29ed9644f643> in load_checkpoint(filepath)
1 import torch
2 def load_checkpoint(filepath):
----> 3 checkpoint = torch.load(filepath)
4 model = checkpoint['model']
5 model.load_state_dict(checkpoint['state_dict'])
~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
578 return torch.jit.load(f)
579 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 580 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
758 unpickler = pickle_module.Unpickler(f, **pickle_load_args)
759 unpickler.persistent_load = persistent_load
--> 760 result = unpickler.load()
762 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
AttributeError: Can't get attribute 'ResNet9' on <module '__main__'>
My torch version is 1.5.1 and I'm using Colab.
This is a common disadvantage of storing the model directly, as you would need to restore the exact same file structure.
Based on the error message, it seems that the definition of ResNet9 is missing.
The recommended way is to store only the state_dict and recreate the model directly using:
model = ResNet9() # you would need the definition as an import or in the current script
model.load_state_dict(checkpoint['state_dict'])
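As a runnable sketch of that recommended approach (the ResNet9 below is a simplified placeholder I made up, since the real architecture isn't shown in the thread):

```python
import torch
import torch.nn as nn

# Simplified stand-in for ResNet9; the real class definition must be
# available (imported or defined in the script) both when saving and
# when loading.
class ResNet9(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

# Training side: save only the parameters, not the model object itself.
model = ResNet9()
torch.save({'state_dict': model.state_dict()}, 'model.pth')

# Loading side: rebuild the architecture, then restore the parameters.
checkpoint = torch.load('model.pth')
model = ResNet9()
model.load_state_dict(checkpoint['state_dict'])
model.eval()
```

Because only tensors are serialized, loading no longer depends on the pickled class being discoverable under `__main__`.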
So I have to create the model architecture from scratch?
Is there any other way to use a model architecture in another Python file, just like we import models from torchvision?
You would have to load the model definition. It can be imported from another file or should be defined in the current script.
If you use the recommended approach, you are flexible in where the model definition is stored (in another file, from a library, in the current script), while loading the model directly would force you to replicate the exact file structure that was used when you stored the model.
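For example (a sketch; the file names are hypothetical), the definition can live in its own module and be imported wherever it's needed, the same way torchvision exposes its models:

```python
# model.py -- the architecture lives here, importable from anywhere
import torch.nn as nn

class ResNet9(nn.Module):  # simplified placeholder for the real network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

# predict.py -- in a separate file you would write:
#   from model import ResNet9
import torch

model = ResNet9()  # works because the class definition is in scope
# then: model.load_state_dict(torch.load('model.pth')['state_dict'])
```

As long as `model.py` is on the Python path, any script can `from model import ResNet9` and recreate the network before loading the state_dict.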
What do you mean by 'should be defined in the current script'?
I'm sorry, I'm new to all this; a lot of what you've just said is confusing.
If I'm not asking too much, could you explain it with an example?
So if I want to do training and data preprocessing in one file and use the trained model for prediction in another file, my best bet would be to import the model?
Thank you so much for helping @ptrblck