添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。

首先,要清楚几个函数: torch.save torch.load state_dict() load_state_dict()
先举最简单的例子:

import torch
model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构。当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape
state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)
torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

  1. pth文件既可能保存了模型的图结构,也有可能没保存;
  2. 加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
  3. 在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

到这里还不清楚的可以留言提问~
参考链接:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

脚本如下:

import torch
import torchvision.models as models
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'only_weights.pth')
model_state_dict = torch.load('only_weights.pth')
model1 = models.vgg16() # describe the graph shape
model1.load_state_dict(model_state_dict)
model1.eval()
torch.save(model1, 'whole_model.pth')
model2 = torch.load('whole_model.pth')
model2.eval()
# model3 = torch.load('only_weights.pth')
# model3.eval()    # Error

model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。先举最简单的例子:import torchmodel = torch.load('my_model.pth')torch.save(model, 'new_model.pth')上面的代码非 k.replace('module.',''):v for k,v in  torch.load(config.model_path, map_location=config.device).items() model = self.model.to(config.device) * config.device 指定使用哪块 the_model = TheModelClass(*args, **kwargs) the_model.load_state_dict(torch.load(PATH)) 使用这种方法,我们需要自己导入模...
希望将训练好的模型加载到新的网络上。如上面题目所描述的,PyTorch加载之前保存模型参数的时候,遇到了问题。 Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。 表明了加载过程中,期望获得的key值为feature...,而不是...
1.作用:用来加载torch.save()保存模型文件。 torch.load()先在CPU上加载,不会依赖于保存模型的设备。如果加载失败,可能是因为没有包含某些设备,比如你在gpu上训练保存模型,而在cpu上加载,可能会报错,此时,需要使用map_location来将存储动态重新映射到可选设备上,比如map_location=torch.device('cpu'),意思是映射到cpu上,在cpu上加载模型,无论你这个模型从哪里训练保存的。 一句话:map_location适用于修改模型能在gpu.
更详细的可以参考这里https://pytorch.org/docs/stable/notes/serialization.html#preserve-storage-sharing torch.save()并torch.load()让您轻松保存加载张量:最简单的就是 t = torch.tensor([1., 2.]) torch.save(t, 'tensor.pth') torch.load('tensor.pth') 按照惯例,Py
PyTorch保存加载模型非常简单。你可以使用torch.save()函数将模型保存为.pth或.pkl文件,使用torch.load()函数从文件中加载模型。 以下是一个基本的示例,展示如何保存加载一个PyTorch模型保存模型: import torch model = YourModel() torch.save(model.state_dict(), 'model.pth') 加载模型: import torch model = YourModel() model.load_state_dict(torch.load('model.pth')) model.eval() 需要注意的是,加载模型时应该确保与保存时使用相同的代码版本、PyTorch版本和硬件设备,并且需要调用model.eval()以确保在推理过程中正确设置一些模型参数。