添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

4.5. 读取和存储

到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

4.5.1. 读写 NDArray

我们可以直接使用 save 函数和 load 函数分别存储和读取 NDArray 。下面的例子创建了 NDArray 变量 x ,并将其存在文件名同为 x 的文件里。

In [1]: from mxnet import nd from mxnet.gluon import nn x = nd . ones ( 3 ) nd . save ( 'x' , x )

然后我们将数据从存储的文件读回内存。

In [2]: x2 = nd . load ( 'x' ) mydict = { 'x' : x , 'y' : y } nd . save ( 'mydict' , mydict ) mydict2 = nd . load ( 'mydict' ) mydict2

4.5.2. 读写Gluon模型的参数

NDArray 以外,我们还可以读写Gluon模型的参数。Gluon的 Block 类提供了 save_parameters 函数和 load_parameters 函数来读写模型参数。为了演示方便,我们先创建一个多层感知机,并将其初始化。回忆 “模型参数的延后初始化” 一节,由于延后初始化,我们需要先运行一次前向计算才能实际初始化模型参数。

In [5]: class MLP ( nn . Block ): def __init__ ( self , ** kwargs ): super ( MLP , self ) . __init__ ( ** kwargs ) self . hidden = nn . Dense ( 256 , activation = 'relu' ) self . output = nn . Dense ( 10 ) def forward ( self , x ): return self . output ( self . hidden ( x )) net = MLP () net . initialize () X = nd . random . uniform ( shape = ( 2 , 20 )) Y = net ( X )

下面把该模型的参数存成文件,文件名为mlp.params。

In [6]: filename = 'mlp.params' net . save_parameters ( filename )

接下来,我们再实例化一次定义好的多层感知机。与随机初始化模型参数不同,我们在这里直接读取保存在文件里的参数。

In [7]: net2 = MLP () net2 . load_parameters ( filename )

因为这两个实例都有同样的模型参数,那么对同一个输入 X 的计算结果将会是一样的。我们来验证一下。

In [8]: Y2 = net2 ( X ) Y2 == Y