5.8 模型文件的读写
前面我们讲了众多方法,来训练好一个模型,让模型能够收敛,同时又不出现过拟合和欠拟合问题。当模型训练好以后,就需要对模型参数进行保存,以便部署到不同的环境中去使用。通常,一个深度学习模型的训练需要消耗很长的时间,比如几天,如果训练过程中出现了问题,无论是软件问题还是硬件问题,又或者是外部因素比如突然断电,造成的损失都是非常大的。因此,我们不仅要在最后对模型进行保存,训练过程中也应该定时保存中间结果,以避免损失。这一节我们就来学习一下模型保存的方法。
5.8.1 张量的保存和加载
在深度学习中,模型的参数一般是张量形式的。对于单个的张量,pytorch为我们提供了方便直接的函数来进行读写。比如我们定义如下的一个张量a。
import torch
a = torch.rand(10)
a
tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
0.8449])
可以简单的用一个save函数去存储这个张量a,这里需要我们给他起一个名字,我们就叫它tensor-a,把它放在model文件夹里。
torch.save(a, 'model/tensor-a')
这就完成了张量的写入,这时我们可以在当前路径下的model文件夹里看到tensor-a这个文件。读取同样简单,只需要用一个load函数就可以完成张量的加载,传入的参数是文件的路径。
torch.load('model/tensor-a')
tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
0.8449])
如果我们要存储的不止一个张量,也没有问题,save和load函数同样支持保存张量列表。先把张量数据存储起来。
a = torch.rand(10)
b = torch.rand(10)
c = torch.rand(10)
torch.save([a,b,c], 'model/tensor-abc')
[a,b,c]
[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
0.5002]),
tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
0.2336]),
tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
0.1998])]
然后再把它读取出来。
torch.load('model/tensor-abc')
[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
0.5002]),
tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
0.2336]),
tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
0.1998])]
对于多个张量,pytorch同样支持以字典的形式来进行存储。比如我们建立一个字典tensor_dict,然后把它存起来。
a = torch.rand(10)
b = torch.rand(10)
c = torch.rand(10)
tensor_dict={'a':a, 'b':b, 'c':c}
torch.save(tensor_dict, 'model/tensor_dict')
tensor_dict
{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
0.9062]),
'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
0.0902]),
'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
0.4338])}
然后是读取。
torch.load('model/tensor_dict')
{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
0.9062]),
'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
0.0902]),
'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
0.4338])}
张量的读写非常的简单,接下来我们看看模型整体参数的读写。
5.8.2 模型参数的保存和加载
模型参数一般都是张量形式的,虽然单个张量的保存和加载非常简单,但整个模型中包含着大大小小的若干张量,单独保存这些张量会很困难。为了解决这个问题,pytorch贴心的为我们准备了内置函数来保存加载整个模型参数。我们以5.2节的多层感知机为例,来看一下如何保存。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# 定义 MLP 网络
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
return out
# 定义超参数
input_size = 28 * 28 # 输入大小
hidden_size = 512 # 隐藏层大小
num_classes = 10 # 输出大小(类别数)
然后我们实例化一个MLP网络,并随机生成一个输入X,并计算出模型的输出Y。
# 实例化 MLP 网络
model = MLP(input_size, hidden_size, num_classes)
X = torch.randn(size=(2, 28*28))
然后同样是调用save方法,我们把模型存储到model文件夹里,取名叫做mlp.params。
torch.save(model.state_dict(), 'model/mlp.params')
接下来,我们来读取保存好的模型参数,重新加载我们的模型。我们先把模型params参数读取出来,然后实例化一个模型,然后直接调用load_state_dict方法,传入模型参数params。
params = torch.load('model/mlp.params')
model_load = MLP(input_size, hidden_size, num_classes)
model_load.load_state_dict(params)
<All keys matched successfully>
此时两个模型model和model_load具有相同的参数,我们给他输入相同的X,看一下输出结果。
output1 = model(X)
output1
tensor([[ 0.0914, 0.0178, 0.0692, 0.1486, 0.1002, 0.0057, -0.1099, 0.1332,
0.0241, 0.1137],
[-0.0228, 0.0446, 0.1374, 0.2009, -0.0978, -0.0831, -0.0193, 0.1040,
0.1097, 0.1484]], grad_fn=<AddmmBackward0>)
output2 = model_load(X)
output2
tensor([[ 0.0914, 0.0178, 0.0692, 0.1486, 0.1002, 0.0057, -0.1099, 0.1332,
0.0241, 0.1137],
[-0.0228, 0.0446, 0.1374, 0.2009, -0.0978, -0.0831, -0.0193, 0.1040,
0.1097, 0.1484]], grad_fn=<AddmmBackward0>)
可以看到,输出的结果完全一致,说明我们将参数成功地读取并载入了模型中。
梗直哥提示:使用save保存的是模型参数而不是整个模型,因此在模型加载参数的时候,需要我们单独指定模型架构,并且要保证模型结构和保存的时候一致,否则可能会导致参数加载失败。如果你想了解更多内容,欢迎入群学习(加V: gengzhige99)
同步更新:Gitbub/公众号:梗直哥
