欢迎光临散文网 会员登陆 & 注册

Pytorch学习笔记11:Sinx正弦函数曲线拟合

2021-06-10 09:57 作者:车科技2020  | 我要投稿

#需要import的lib
import torch
import time
import platform
import cmath
import matplotlib.pyplot as plt
import numpy as np

#import CV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

#需要import的lib
#运行环境tesla k20/python 3.7/pytorch 1.20
print('——————————运行环境——————————')
print('Python Version:',platform.python_version())
print('Torch Version:',torch.__version__)
#print('OpenCV Version:',CV2.__version__)
print('CUDA GPU check:',torch.cuda.is_available())
if(torch.cuda.is_available()):
print('CUDA GPU num:', torch.cuda.device_count())
n=torch.cuda.device_count()
while n > 0:
  print('CUDA GPU name:', torch.cuda.get_device_name(n-1))
  print('CUDA GPU capability:', torch.cuda.get_device_capability(n-1))
  print('CUDA GPU properties:', torch.cuda.get_device_properties(n-1))
  n -= 1
print('CUDA GPU index:', torch.cuda.current_device())
print('——————————运行环境——————————')

time_start=time.time()

#device=torch.device('cuda:0')
#工程优化应用,GPU不一定更快
device=torch.device('cpu')

#对sinx进行采样
#3.14弧度=180度
data=torch.zeros(2,100)
data.requires_grad=False

pred=torch.zeros(2,100)
pred.requires_grad=False

graph=torch.zeros(2,100)
graph.requires_grad=False

wb=torch.ones(2,5)
wb.requires_grad=True
print(wb)

for i in range(0,100):#对sinx进行采样,加了正态分布的噪声
   data[0][i]=i*0.1
   mid=torch.from_numpy(np.random.randn(1)*0.01)#注意numpy转tensor
   data[1][i]=cmath.sin(i*0.1).real+mid
print(data)



pred[0]=data[0]
graph[0]=data[0]
print(pred)



def func1(wb,pred):
   y=(pred[0]*0.1)*wb[0][0]+wb[1][0]+(pred[0]*0.1)**2*wb[0][1]+wb[1][1]+(pred[0]*0.1)**3*wb[0][2]+wb[1][2]+(pred[0]*0.1)**4*wb[0][3]+wb[1][3]+(pred[0]*0.1)**5*wb[0][4]+wb[1][4] #5次方泰勒展开
   return y

loss_func=torch.nn.MSELoss()
optim=torch.optim.Adam([wb],lr=1e-3)

for step in range(3000000):#如果迭代过程太慢,可以减少这个数值做测试体验下,不过缩小后拟合效果不好
   loss=loss_func(func1(wb,pred),data[1])
   optim.zero_grad()
   loss.backward()
   optim.step()
   if step % 200 == 0:
       #print('step {}:x={},f(x)={}'.format(step, wb.tolist(), loss.item()))
       print('wb:',wb)
       print('loss:',loss.item())


graph[1]=func1(wb, graph)
a=graph[0].detach().numpy()#注意tensor转numpy
b=graph[1].detach().numpy()
plt.plot(a, b)
c=data[0].detach().numpy()
d=data[1].detach().numpy()
plt.plot(c, d)
plt.show()

time_end=time.time()
print('Totally cost',time_end-time_start,'s')

Pytorch学习笔记11:Sinx正弦函数曲线拟合的评论 (共 条)

分享到微博请遵守国家法律