# PyTorch学习笔记11: sin(x)正弦函数曲线拟合
#需要import的lib
import torch
import time
import platform
import cmath
import matplotlib.pyplot as plt
import numpy as np
#import CV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#需要import的lib
# 运行环境: Tesla K20 / Python 3.7 / PyTorch 1.2.0
# ---- Report the runtime environment (Python / PyTorch / CUDA GPUs) ----
print('——————————运行环境——————————')
print('Python Version:',platform.python_version())
print('Torch Version:',torch.__version__)
#print('OpenCV Version:',CV2.__version__)
print('CUDA GPU check:',torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA GPU num:', torch.cuda.device_count())
    n = torch.cuda.device_count()
    # Describe every visible GPU, highest index first (same order as the
    # original countdown loop).
    for gpu_idx in reversed(range(n)):
        print('CUDA GPU name:', torch.cuda.get_device_name(gpu_idx))
        print('CUDA GPU capability:', torch.cuda.get_device_capability(gpu_idx))
        print('CUDA GPU properties:', torch.cuda.get_device_properties(gpu_idx))
    # Only query the current CUDA device when CUDA is actually available.
    print('CUDA GPU index:', torch.cuda.current_device())
print('——————————运行环境——————————')
# Wall-clock start of the whole run (reported at the end of the script).
time_start = time.time()

# Computation device. For a problem this small the GPU is not necessarily
# faster, so the CPU is used; switch to the commented line to try CUDA.
# NOTE(review): no tensor in this script is moved to `device`; it is
# currently informational only.
#device=torch.device('cuda:0')
device = torch.device('cpu')

# Sample sin(x) at x = 0.0, 0.1, ..., 9.9 (3.14 rad = 180 degrees).
# Each 2xN tensor stores x in row 0 and y in row 1.
N_SAMPLES = 100
data = torch.zeros(2, N_SAMPLES)   # noisy training samples (no grad by default)
pred = torch.zeros(2, N_SAMPLES)   # x values fed to the model during training
graph = torch.zeros(2, N_SAMPLES)  # x values / fitted curve used for plotting

# Model parameters: wb[0][k] scales (0.1*x)**(k+1), wb[1][k] is a constant
# offset. Only this tensor is optimised.
wb = torch.ones(2, 5, requires_grad=True)
print(wb)

# Compute x in float64 (exactly what i*0.1 gives per sample) and add
# Gaussian noise with standard deviation 0.01, then store as float32.
x_samples = torch.arange(N_SAMPLES, dtype=torch.float64) * 0.1
noise = torch.from_numpy(np.random.randn(N_SAMPLES) * 0.01)  # numpy -> tensor
data[0] = x_samples
data[1] = torch.sin(x_samples) + noise
print(data)

# The model and the plot are evaluated at the same x values as the samples.
pred[0] = data[0]
graph[0] = data[0]
print(pred)
def func1(wb, pred):
    """Evaluate the polynomial model used to fit sin(x).

    With t = 0.1 * pred[0] and K = wb.shape[1] (5 in this script), the
    model is

        y = sum_{k=0}^{K-1} (t**(k+1) * wb[0][k] + wb[1][k])

    This is a learned polynomial fit, not a Taylor expansion: the
    coefficients are optimised, and the K offsets wb[1][k] together act
    as a single constant term.

    Args:
        wb: 2-row parameter tensor; row 0 holds the power coefficients,
            row 1 the constant offsets (one per power).
        pred: tensor whose row 0 holds the x values to evaluate.

    Returns:
        Tensor of model outputs with the same shape as pred[0].
    """
    # Scale x so that its higher powers stay small for the script's
    # x range of [0, 9.9].
    t = pred[0] * 0.1
    y = torch.zeros_like(t)
    # Accumulate in the same left-to-right order as the original
    # hand-written expression: term_1 + b_1 + term_2 + b_2 + ...
    for k in range(wb.shape[1]):
        y = y + t ** (k + 1) * wb[0][k]
        y = y + wb[1][k]
    return y
# Fit wb by minimising the mean-squared error between the model output and
# the noisy sin(x) samples, using Adam.
loss_func = torch.nn.MSELoss()
optim = torch.optim.Adam([wb], lr=1e-3)

# If training is too slow, lower NUM_STEPS for a quick trial run; the fit
# gets noticeably worse with fewer steps.
NUM_STEPS = 3000000
LOG_EVERY = 200  # print parameters and loss every LOG_EVERY steps

for step in range(NUM_STEPS):
    loss = loss_func(func1(wb, pred), data[1])
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % LOG_EVERY == 0:
        #print('step {}:x={},f(x)={}'.format(step, wb.tolist(), loss.item()))
        print('wb:',wb)
        print('loss:',loss.item())
# Evaluate the fitted model at the sample x values. No gradients are needed
# here, so skip building an autograd graph (and keep `graph` a plain tensor).
with torch.no_grad():
    graph[1] = func1(wb, graph)

# Convert to numpy for matplotlib (tensor -> numpy).
a = graph[0].detach().numpy()
b = graph[1].detach().numpy()
plt.plot(a, b, label='fitted polynomial')
c = data[0].detach().numpy()
d = data[1].detach().numpy()
plt.plot(c, d, label='noisy sin(x) samples')
plt.legend()

# Stop the clock before plt.show(): show() can block until the window is
# closed, and that viewing time should not count as computation time.
time_end = time.time()
print('Totally cost',time_end-time_start,'s')
plt.show()