Hi there,
I’ve encountered this problem and have been stuck for a while. I have a fairly large labeled image dataset and chose to train a VGG16 on it, starting from PyTorch’s ImageNet example.
I first organized the data into three splits, namely train, val, and test; under each split there are subdirectories named after the class labels, each containing the image files, like:
train/
    label0/
        file0
        file1
        ...
    label1/
        ...
val/
    label0/
    label1/
...
and I ran training with the command:
export CUDA_VISIBLE_DEVICES=device_id
python3 main.py /path/to/my/dataset -a vgg16 -b 32 --lr 0.001
Training seems to go fine, reaching nearly 90% top-5 accuracy. The best checkpoint is saved as model_best.pth.tar.
After that, when I try to run inference on some images with the trained model, it fails with the following error:
RuntimeError: Error(s) in loading state_dict for VGG:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias".
Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".
Could anyone give me some advice?
Thanks in advance.
EDIT: the loading snippet:
import torch
from torchvision import models
model = models.__dict__[args.arch]() # arch is fed as 'vgg16'
model.cuda()
checkpoint = torch.load(model_file_name) # ie, model_best.pth.tar
model.load_state_dict(checkpoint['state_dict'])
and the corresponding traceback:
    main(args)
  File "classifier.py", line 78, in main
    _im, im = Classifier(args.model).infer(args.image)
  File "classifier.py", line 48, in __init__
    arch.load_state_dict(checkpoint['state_dict'])
  File "/home/xxx/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
It seems you’ve used nn.DataParallel to save the model. You could wrap your current model again in nn.DataParallel or just remove the .module keys.
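For the second option, a minimal sketch for your particular checkpoint could look like this (reusing the checkpoint dict you already load from model_best.pth.tar):
state_dict = checkpoint['state_dict']
# strip the "module." part that nn.DataParallel inserted, e.g.
# "features.module.0.weight" -> "features.0.weight"
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
The first option would amount to re-applying the same wrapping used during training before calling load_state_dict; judging by the features.module.* keys in your error, the ImageNet example wrapped model.features in nn.DataParallel.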
Here is a similar thread with some suggestions.
Can we load the state_dict of a previously trained NN into a new NN that has some changes in its structure? For example, I needed the state_dict of the previously trained NN only partially, which I did successfully. Now I add one more layer to my new NN, and it shows an error that the loaded state_dict does not contain this output layer. How should I proceed?
In my case it no longer threw any errors, but it wouldn’t correctly load the state_dict either.
In the end I just had to do this to remove the “.module” part.
Can you explain more clearly how to wrap the current model in nn.DataParallel? Could you give an example?
I get the error Missing key(s) in state_dict:, and when I save the model I just use torch.save().
I am new to PyTorch, thanks so much!
You could just wrap the model in nn.DataParallel and push it to the device:
model = Model(input_size, output_size)
model = nn.DataParallel(model)
model.to(device)
I would not recommend saving the model directly; save its state_dict instead, as explained here.
Also, after you’ve wrapped the model in nn.DataParallel, the original model will be accessible via model.module, so you might want to store the state_dict via torch.save(model.module.state_dict(), 'file_name.pt').
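Putting it together, a minimal save/load round trip under these assumptions (Model is just a stand-in for your own module) might look like:
# training: wrap the model in nn.DataParallel, but save the underlying module's state_dict
model = Model(input_size, output_size)
model = nn.DataParallel(model)
model.to(device)
# ... train ...
torch.save(model.module.state_dict(), 'file_name.pt')

# inference: a plain, unwrapped model can then load the file directly
model = Model(input_size, output_size)
model.load_state_dict(torch.load('file_name.pt'))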
I use torch.save(model.state_dict(), 'file_name.pt') to save the model. If I don’t use nn.DataParallel when saving the model, I also don’t need it when I load it, right?
If so: I save with torch.save(model.state_dict(), 'mymodel.pt') in one .py file during training, but when I try to load it in a new file with model.load_state_dict(torch.load('mymodel.pt')), it gives me the error
RuntimeError: Error(s) in loading state_dict for model: Missing key(s) in state_dict: "l1.W", "l1.b", "l2.W", "l2.b".
My model is a self-defined model with two custom layers, l1 and l2, each of which has parameters W and b.
Can you give me any advice? Thanks!
lelegu:
If I don’t use nn.DataParallel when saving the model, I also don’t need it when I load it, right?
Yes, that is correct.
I am also facing the same issue, and I don’t use nn.DataParallel. I am using a slightly modified version of this repo (https://github.com/aitorzip/PyTorch-CycleGAN) in a Kaggle notebook.
Here’s how I save:
torch.save({
    'epoch': epoch,
    'model_state_dict': netG_A2B.state_dict(),
    'optimizer_state_dict': optimizer_G.state_dict(),
    'loss_histories': loss_histories,
}, f'netG_A2B_{epoch:03d}.pth')
Here’s how I load:
checkpoint = torch.load(weights_path)['model_state_dict']
self.model.load_state_dict(checkpoint)
And here’s a representative snippet of the (longer) error message:
Missing key(s) in state_dict: "1.weight", "1.bias", "4.weight", "4.bias", "7.weight", "7.bias".
Unexpected key(s) in state_dict: "model.1.weight", "model.1.bias", "model.4.weight", "model.4.bias", "model.7.weight", "model.7.bias".
I opted for this method to fix it:
checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']
# strip the leading 'model.' prefix so the keys match the bare module
for key in list(checkpoint.keys()):
    if 'model.' in key:
        checkpoint[key.replace('model.', '')] = checkpoint[key]
        del checkpoint[key]
self.model.load_state_dict(checkpoint)
You can replace the module keys in the state_dict as follows:
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)
Ideally, if you use DataParallel, save the checkpoint file for inference as follows:
torch.save(model.module.state_dict(), 'model_ckpt.pt')
This also makes it easier to run inference on the CPU at a later time.
Ref: https://stackoverflow.com/a/61854807/3177661
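For example, a minimal sketch of loading such a checkpoint on a CPU-only machine could be (MyModel is a placeholder for the actual architecture, not wrapped in nn.DataParallel):
model = MyModel()
# map the GPU tensors in the checkpoint to CPU memory
state_dict = torch.load('model_ckpt.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()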
I got this error for a different reason: in my case, it was because I was trying to load variables that were not yet defined in the model. In the saved model, they were only registered in a non-__init__ method:
class NN(nn.Module):
    def __init__(self):
        super().__init__()

    def save(self, model_save_path):
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        save_dict = {}
        save_dict.update({'xgboost_model': copy.deepcopy(self.model)})
        save_dict.update({'model': copy.deepcopy(self.state_dict())})
        torch.save(save_dict, model_save_path)

    def fit(self, X_train, Y_train, X_val, Y_val, **kwargs):
        train_data_mean = X_train.mean(axis=0)
        train_data_std = X_train.std(axis=0)
        self.register_buffer('train_data_mean', torch.tensor(train_data_mean.values))
        self.register_buffer('train_data_std', torch.tensor(train_data_std.values))
        self.model.fit(X_train, Y_train, eval_set=[(X_train, Y_train), (X_val, Y_val)], **kwargs)

    def forward(self, inputs):
        return inputs
The solution was to initialize filler variables in the constructor, to be replaced by the loaded values:
class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('train_data_mean', torch.zeros(52))
        self.register_buffer('train_data_std', torch.ones(52))

    def save(self, model_save_path):
        os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
        save_dict = {}
        save_dict.update({'xgboost_model': copy.deepcopy(self.model)})
        save_dict.update({'model': copy.deepcopy(self.state_dict())})
        torch.save(save_dict, model_save_path)

    def fit(self, X_train, Y_train, X_val, Y_val, **kwargs):
        train_data_mean = X_train.mean(axis=0)
        train_data_std = X_train.std(axis=0)
        self.register_buffer('train_data_mean', torch.tensor(train_data_mean.values))
        self.register_buffer('train_data_std', torch.tensor(train_data_std.values))
        self.model.fit(X_train, Y_train, eval_set=[(X_train, Y_train), (X_val, Y_val)], **kwargs)

    def forward(self, inputs):
        return inputs
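With the filler buffers registered up front, loading works as expected; a minimal sketch of one possible loading side (model_save_path being the file written by save() above, and the XGBoost restore step being my assumption) might be:
model = NN()
checkpoint = torch.load(model_save_path)
model.model = checkpoint['xgboost_model']   # restore the saved XGBoost model (assumed)
model.load_state_dict(checkpoint['model'])  # buffer keys now exist, so this succeeds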