欢迎光临散文网 会员登陆 & 注册

一个简单的pytorch中线性回归模型的创建、训练和测试,每一步都有明确注释

2023-03-31 19:37 作者:我还在等你GDF  | 我要投稿

#调用torch库
import torch
#调用numpy库
import numpy
#可有可无,用来调用查看线型图的方法
import matplotlib.pyplot as plt
#定义一个训练集
t_data = torch.rand(100, 1)
print("t_data:", t_data)
#定义一个训练集的答案
p_data = t_data * 2
print("p_data:", p_data)

#定义一个线性回归模型
class LinearRegression(torch.nn.Module):
   #初始化线性回归模型,定义网络的结构
   def __init__(self):
       #初始化父类
       super(LinearRegression, self).__init__()
       #定义自己的神经元
       self.Linear = torch.nn.Linear(1, 1)
   #定义网络的计算:前向传递计算
   def forward(self, x):
       #将x(输入的数据,可以是训练集或者测试集)在自定义的全连接层中进行计算并返回计算得到的预测值
       pred = self.Linear(x)
       return pred
#定义模型
model = LinearRegression()
#定义损失函数
mse_loss = torch.nn.MSELoss(reduction='mean')
#定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

#查看定义的模型中的参数:可以看出模型的初始参数是随机设置的,我们想要的参数是通过训练集来进行更新得到的
for name, parameters in model.named_parameters():
   print("name:", name)
   print("parameters:", parameters)

#通过训练集对模型进行训练
for epoch in range(1000):
   #通过训练集t_data得到训练结果pred
   pred = model(t_data)
   #通过pred和p_data计算损失
   loss = mse_loss(pred, p_data)
   #输出经过计算的损失值以方便查看损失是否降低来判断训练是否有效(为了不过于繁琐,只查看第0, 200, 400, 600, 800, 1000次的损失值)
   if (epoch + 1) % 200 == 0:
       print("Epoch:", epoch + 1)
       print("loss:", loss.item())
   #在反向计算梯度前,需要先进行优化器的梯度清零,否则梯度会在每次训练过程中累加
   optimizer.zero_grad()
   #进行梯度的反向计算
   loss.backward()
   #通过优化器更新模型中的参数
   optimizer.step()
   #输出经过调试后的参数以方便查看权重参数的更替(为了不过于繁琐,只查看第0, 200, 400, 600, 800, 1000次的权重参数)
   if (epoch + 1) % 200 == 0:
       for name,parameters in model.named_parameters():
           print("name:", name)
           print("parameters:", parameters, end='\n\n')

#通过自己输入一个测试集得到预测结果来判断模型训练的如何
t_test = torch.rand(100, 1)
p_test = model(t_test)
print("t_test:", t_test)
print("p_test:", p_test)
#可有可无 用来查看线型图,通过线型图可以发现,测试集和测试结果符合的线性关系是否为2来判断训练的效果
t_test = t_test.data.numpy()
p_test = p_test.data.numpy()
plt.scatter(t_test, p_test)
plt.show()


一个简单的pytorch中线性回归模型的创建、训练和测试,每一步都有明确注释的评论 (共 条)

分享到微博请遵守国家法律