
Whenever I try to save the entire model, a warning shows up, but it still saves without any errors.

import torch
# Storing the ResNet9 object pickles a reference to its class, not the class code
checkpoint = {'model': ResNet9(), 'state_dict': model.state_dict()}
torch.save(checkpoint, 'model.pth')

But when I try to load it, it gives me this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-29ed9644f643> in <module>
     10     return model
---> 12 model = load_checkpoint('face.pth')
<ipython-input-1-29ed9644f643> in load_checkpoint(filepath)
      1 import torch
      2 def load_checkpoint(filepath):
----> 3     checkpoint = torch.load(filepath)
      4     model = checkpoint['model']
      5     model.load_state_dict(checkpoint['state_dict'])
~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    578                     return torch.jit.load(f)
    579                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 580         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
~/anaconda3/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    758     unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    759     unpickler.persistent_load = persistent_load
--> 760     result = unpickler.load()
    762     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
AttributeError: Can't get attribute 'ResNet9' on <module '__main__'>

My torch version is 1.5.1 and I'm using Colab.

This is a common disadvantage of storing the model directly, as you would need to restore the exact same file structure.
Based on the error message it seems that the definition of ResNet9 is missing.
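Here is a minimal sketch of why the definition has to exist at load time. It uses plain `pickle`, which `torch.save` and `torch.load` rely on by default (the `unpickler.load()` frame in the traceback comes from it). The `ResNet9` class here is a hypothetical stand-in, not the real network. The `del` simulates opening a new script that never defined the class.

```python
import pickle


class ResNet9:
    """Hypothetical stand-in for the real ResNet9 nn.Module."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


# pickle stores the instance data plus a *reference* to the class
# ("__main__.ResNet9"), not the code of the class itself.
blob = pickle.dumps(ResNet9())

# Simulate a fresh script in which ResNet9 was never defined or imported.
del ResNet9

try:
    pickle.loads(blob)
    load_error = None
except AttributeError as exc:
    load_error = str(exc)  # "Can't get attribute 'ResNet9' on <module '__main__' ...>"

print(load_error)


# Once the definition is available again (defined here or imported), loading works.
class ResNet9:
    """Same definition, e.g. what an import from another file would provide."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


restored = pickle.loads(blob)
print(restored.weights)  # [0.1, 0.2, 0.3]
```

The lookup is done by module and class name, so "the exact same file structure" means the class must be reachable under the same module path that was recorded when the model was saved.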

The recommended way is to store only the state_dict and recreate the model before loading the weights into it:

import torch
checkpoint = torch.load('model.pth')
model = ResNet9()  # you need the definition as an import or in the current script
model.load_state_dict(checkpoint['state_dict'])
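Here is a minimal end-to-end sketch of that approach. It uses a hypothetical tiny `TinyNet` in place of `ResNet9` so it runs without the original definition. The file name `model.pth` follows the thread. `map_location='cpu'` is an assumption that lets the checkpoint load on a machine without a GPU.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet9; any nn.Module works the same way."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# --- training side: save only the learned parameters ---
model = TinyNet()
torch.save({'state_dict': model.state_dict()}, 'model.pth')

# --- prediction side: recreate the architecture, then load the weights ---
checkpoint = torch.load('model.pth', map_location='cpu')
restored = TinyNet()  # the class definition must be imported or defined here
restored.load_state_dict(checkpoint['state_dict'])
restored.eval()  # inference mode; matters for batch-norm/dropout layers in a real network

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```

Because no reference to `TinyNet` is stored in the checkpoint, the prediction code is free to get the class from anywhere: an import, a library, or a definition pasted into the script.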
              

So I have to create the model architecture from scratch?
Is there any other way to use a model architecture from another Python file, just like
we import models from torchvision?

You would need the model definition to be available in either case. It can be imported from another file or defined in the current script.
If you use the recommended approach, you are flexible about where the model definition is stored (in another file, in a library, or in the current script), while loading the model directly would force you to replicate the exact file structure that was used when you stored the model.

What do you mean by ‘should be defined in the current script’?
I'm sorry, I'm new to all this, and a lot of what you've just said is confusing.
If I'm not asking too much, could you explain it with an example?

So if I want to do the training and data preprocessing in one file and use the trained model
for prediction in another file, my best bet would be to import the model definition.
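Here is a minimal sketch of that layout. It assumes a hypothetical file named `my_models.py` that holds the architecture. To keep the example self-contained, it writes that file into a temporary folder and puts the folder on `sys.path`. In a real project the file would simply sit next to the training and prediction scripts, and the placeholder class would be the real `nn.Module`.

```python
import importlib
import pathlib
import sys
import tempfile

# Write the hypothetical shared module my_models.py into a temporary project folder.
project_dir = pathlib.Path(tempfile.mkdtemp())
(project_dir / "my_models.py").write_text(
    "class ResNet9:\n"
    "    '''Placeholder for the real ResNet9 nn.Module definition.'''\n"
    "    def __init__(self, num_classes=10):\n"
    "        self.num_classes = num_classes\n"
)

# Make the folder importable, as it is when the scripts live in the same directory.
sys.path.insert(0, str(project_dir))
importlib.invalidate_caches()

# Both the training script and the prediction script can now use the same
# definition, imported the same way torchvision models are imported.
from my_models import ResNet9

model = ResNet9()
print(type(model).__module__)  # my_models
```

With this split, the training script saves `model.state_dict()`, and the prediction script does `from my_models import ResNet9`, builds the model, and calls `load_state_dict` on it.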

Thank you so much for helping, @ptrblck!