
I am trying to save the model in PyTorch by using the below code:

model = utils.get_model(self.model)
torch.save({
    # 'model_state_dict': model,  # old version: saved the full model
    'model_state_dict': model.state_dict(),  # added new
}, os.path.join(self.checkpoint, 'model_{}.pth'.format(task_id)))

I am able to load the model successfully with no issues in my app, and the model is saved to a .pth file.

My second step is to take the saved model (model.pth) and load it in another application via the code below:

model.load_state_dict(torch.load("./checkpoint/model.pth"))

It is giving me the below error:

RuntimeError                              Traceback (most recent call last)
----> 1 model.load_state_dict(torch.load("/home/jovyan/.cache/torch/checkpoints/resnext50_32x4d-7cdf4587.pth"))
      2 model = model.eval()

/srv/conda/envs/notebook/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828         if len(error_msgs) > 0:
    829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830                 self.__class__.__name__, "\n\t".join(error_msgs)))
    831         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Target:
Missing key(s) in state_dict: "conv_layer.0.weight", "conv_layer.0.bias", "conv_layer.1.weight", "conv_layer.1.bias", "conv_layer.4.weight", "conv_layer.4.bias", "conv_layer.5.weight", "conv_layer.5.bias", "conv_layer.7.weight", "conv_layer.7.bias", "conv_layer.8.weight", "conv_layer.8.bias", "conv_layer.11.weight", "conv_layer.11.bias", "conv_layer.12.weight", "conv_layer.12.bias", "conv_layer.14.weight", "conv_layer.14.bias", "conv_layer.15.weight", "conv_layer.15.bias", "conv_layer.18.weight", "conv_layer.18.bias", "conv_layer.19.weight", "conv_layer.19.bias", "conv_layer.21.weight", "conv_layer.21.bias", "conv_layer.22.weight", "conv_layer.22.bias", "conv_layer.24.weight", "conv_layer.24.bias"
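For context, this missing-key error is what happens when `load_state_dict` is handed the wrapper dict from the checkpoint instead of the nested state_dict inside it. A minimal round-trip sketch of the fix (here `nn.Linear` is a stand-in for the real network, and the file path is illustrative):

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the real network returned by utils.get_model(...)
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), 'model_0.pth')

# Save the way the question does: the state_dict nested under a key.
torch.save({'model_state_dict': model.state_dict()}, path)

# torch.load returns the wrapper dict; its only key is 'model_state_dict',
# so passing it straight to load_state_dict reports every real parameter
# as missing. Extract the nested state_dict first:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
```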

Therefore, in my code I started exploring additional options to save the model state as well.
My question is: isn't the code below supposed to save my whole model, including the model state?

model = utils.get_model(self.model)
torch.save({

Apparently not, which is why I added the below to my code:

torch.save(model.state_dict(), os.path.join(self.checkpoint, 'model_state_{}.pth'.format(task_id)))

However, I am now getting this error:

in save_all_models
    'model_state_dict': model.state_dict(),
AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'

Thank you for your help in advance.

torch.save({
    'model_state_dict': model.state_dict(),
}, os.path.join(self.checkpoint, 'model_{}.pth'.format(task_id)))

If you load the checkpoint file and print its keys, you will see:

checkpoint = torch.load('...<the path>...')
print(checkpoint.keys())
# output
>>> dict_keys(['model_state_dict'])

So the solution is either:

  • save the state_dict directly and load it directly:

    # save
    torch.save(model.state_dict(), '...<path>...')
    # load
    model.load_state_dict(torch.load('...<path>...'))

  • or save the state_dict under a key and load it with THAT key:

    # save
    torch.save({
        'mymodel': model.state_dict()
    }, '...<path>...')
    # load
    model.load_state_dict(torch.load('...<path>...')['mymodel'])
    

    And for the error AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict', use type(model) to check what model actually is: the error means it is already a state_dict (an OrderedDict), not a module, so calling .state_dict() on it fails.
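    That check can be sketched in plain Python (`get_state_dict` is a hypothetical helper name, not part of PyTorch): if the object is already an (Ordered)Dict, it *is* the state_dict, so calling `.state_dict()` on it raises exactly this AttributeError.

    ```python
    from collections import OrderedDict

    def get_state_dict(obj):
        """Return a state_dict whether obj is a module or already a state_dict.

        A module exposes .state_dict(); a state_dict is itself a plain
        (Ordered)Dict, so calling .state_dict() on it raises AttributeError.
        """
        if isinstance(obj, dict):  # OrderedDict subclasses dict
            return obj
        return obj.state_dict()
    ```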

    torch.save({'model_state_dict': model.state_dict(),
    }, os.path.join(self.checkpoint, 'model_{}.pth'.format(task_id)))

    However, I am still getting the same error:

     torch.save({'model_state_dict': model.state_dict(),
    AttributeError: 'collections.OrderedDict' object has no attribute 'state_dict'
    

    I was reading the issues also discussed here: