
Fitting a Heart-Shaped Function with PyTorch

Date: 2023-05-16
Preface (no Valentine's date? Let AI fill in!)

Valentine's Day rolled around and I still had nobody to spend it with, so I dragged an AI along instead.

At first I trained for 10,000 epochs, and the result was so far off I half expected the AI to tell me to dream on.

After I added a zero (100,000 epochs), it finally behaved.

Enough chatter; on to the main content.

Main idea

The target is the classic parametric heart curve:

x(t) = 16 sin^3(t)
y(t) = 13 cos(t) - 5 cos(2t) - 2 cos(3t) - cos(4t)

Unlike the previous post, this time the curve is given by a parametric equation, but we only fit the y component.

The input features are [cos(t), cos(2t), cos(3t), cos(4t)].

The weights to fit are [13, -5, -2, -1].

Since y is linear in these features, no activation layer is needed; a single linear layer suffices (see the closed-form check below).

Leave-one-out is used for validation.

Training runs for 100,000 epochs in total; convergence is, frankly, quite slow.
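Because y is linear in these features, the weights can even be recovered in closed form. Here is a minimal sanity check (my addition, not in the original post; it assumes a PyTorch version that provides torch.linalg.lstsq):

import torch

t = torch.linspace(-15, 15, 1000)
# One column per feature: [cos t, cos 2t, cos 3t, cos 4t]
A = torch.stack([torch.cos(i * t) for i in range(1, 5)], dim=1)
# Analytic y values of the heart curve
y = 13 * torch.cos(t) - 5 * torch.cos(2 * t) - 2 * torch.cos(3 * t) - torch.cos(4 * t)
# Ordinary least squares: solve A @ w = y for w
w = torch.linalg.lstsq(A, y.unsqueeze(1)).solution
print(w.squeeze())  # expected: tensor([13., -5., -2., -1.])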

Full code

# Target curve: x(t) = 16 sin^3(t), y(t) = 13cos(t) - 5cos(2t) - 2cos(3t) - cos(4t)
import torch
import numpy as np
import random
import matplotlib.pyplot as plt

t = torch.linspace(-15, 15, 1000)  # 1000 evenly spaced points from -15 to 15
x = 16 * torch.sin(t) ** 3
y = 13 * torch.cos(t) - 5 * torch.cos(2 * t) - 2 * torch.cos(3 * t) - torch.cos(4 * t)
plt.scatter(x.data.numpy(), y.data.numpy())  # plot the ground-truth heart
plt.show()

 

def y_features(t):
    # Build the feature matrix [cos t, cos 2t, cos 3t, cos 4t], one column per term
    t = t.unsqueeze(1)
    return torch.cat([torch.cos(t * i) for i in range(1, 5)], 1)

def x_features(t):
    # x(t) = 16 sin^3(t)
    t = t.unsqueeze(1)
    return 16 * torch.sin(t) ** 3

# Ground-truth weights for the y component
t_weights = torch.Tensor([13, -5, -2, -1]).unsqueeze(1)

def target(t):
    return t.mm(t_weights)  # matrix multiply: features (N, 4) x weights (4, 1)
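A quick consistency check (my addition): target composed with y_features should reproduce the analytic y values exactly:

t_check = torch.linspace(-15, 15, 5)
y_true = 13 * torch.cos(t_check) - 5 * torch.cos(2 * t_check) - 2 * torch.cos(3 * t_check) - torch.cos(4 * t_check)
print(torch.allclose(target(y_features(t_check)).squeeze(), y_true))  # True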

# Randomly generate training data
def get_batch_data(batch_size):
    batch_x = torch.randn(batch_size)  # sample parameter values t ~ N(0, 1)
    features_x = x_features(batch_x)
    features_y = y_features(batch_x)
    target_x = features_x
    target_y = target(features_y)
    return target_x, features_y, target_y
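For reference, a shape check (my addition). Note that torch.randn samples t from a standard normal, so most training points land in roughly t in [-3, 3]:

target_x, features_y, target_y = get_batch_data(4)
print(target_x.shape)    # torch.Size([4, 1]) - x coordinates (returned but not fitted)
print(features_y.shape)  # torch.Size([4, 4]) - cosine features
print(target_y.shape)    # torch.Size([4, 1]) - y values to regress on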

# Build the model: a single linear layer, no activation
class PolynomiaRegression(torch.nn.Module):
    def __init__(self):
        super(PolynomiaRegression, self).__init__()
        self.poly = torch.nn.Linear(4, 1)

    def forward(self, t):
        return self.poly(t)
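A smoke test of the untrained model (my addition): a batch of feature rows maps to one prediction per sample:

model = PolynomiaRegression()
print(model(y_features(torch.randn(8))).shape)  # torch.Size([8, 1])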

# Start training
import math
epochs = 100000
batch_size = 32
model = PolynomiaRegression()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), 0.001)
loss_value = np.inf  # bookkeeping from the original; its use falls in the truncated part
loss_holder = []
step = 0
for epoch in range(epochs):
    target_x, batch_x, batch_y = get_batch_data(batch_size)
    out = model(batch_x)
    loss = criterion(out, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The source is cut off at "if(loss"; reconstructed from the log below:
    # report the loss every 1000 epochs
    if epoch % 1000 == 0:
        print('Epoch:[{}/{}],loss:[{:.6f}]'.format(epoch + 1, epochs, loss.item()))

Training log (every 1000 epochs):

Epoch:[1/100000],loss:[65.298500]
Epoch:[1001/100000],loss:[2.491852]
Epoch:[2001/100000],loss:[2.300167]
Epoch:[3001/100000],loss:[0.070248]
Epoch:[4001/100000],loss:[13.497649]
Epoch:[5001/100000],loss:[0.566398]
Epoch:[6001/100000],loss:[3.217979]
Epoch:[7001/100000],loss:[0.192295]
Epoch:[8001/100000],loss:[2.993185]
Epoch:[9001/100000],loss:[0.840189]
Epoch:[10001/100000],loss:[0.044125]
Epoch:[11001/100000],loss:[0.030104]
Epoch:[12001/100000],loss:[0.057011]
Epoch:[13001/100000],loss:[0.652313]
Epoch:[14001/100000],loss:[0.055101]
Epoch:[15001/100000],loss:[0.035360]
Epoch:[16001/100000],loss:[0.404017]
Epoch:[17001/100000],loss:[0.032413]
Epoch:[18001/100000],loss:[0.026495]
Epoch:[19001/100000],loss:[0.258574]
Epoch:[20001/100000],loss:[0.013485]
Epoch:[21001/100000],loss:[0.030264]
Epoch:[22001/100000],loss:[1.169118]
Epoch:[23001/100000],loss:[1.099216]
Epoch:[24001/100000],loss:[0.470649]
Epoch:[25001/100000],loss:[0.115718]
Epoch:[26001/100000],loss:[0.105106]
Epoch:[27001/100000],loss:[0.365399]
Epoch:[28001/100000],loss:[0.002508]
Epoch:[29001/100000],loss:[0.008796]
Epoch:[30001/100000],loss:[0.333936]
Epoch:[31001/100000],loss:[0.007649]
Epoch:[32001/100000],loss:[0.012496]
Epoch:[33001/100000],loss:[0.006446]
Epoch:[34001/100000],loss:[0.060192]
Epoch:[35001/100000],loss:[0.005069]
Epoch:[36001/100000],loss:[0.003233]
Epoch:[37001/100000],loss:[0.003273]
Epoch:[38001/100000],loss:[0.001058]
Epoch:[39001/100000],loss:[0.001612]
Epoch:[40001/100000],loss:[0.037987]
Epoch:[41001/100000],loss:[0.081242]
Epoch:[42001/100000],loss:[0.008152]
Epoch:[43001/100000],loss:[0.002064]
Epoch:[44001/100000],loss:[0.001179]
Epoch:[45001/100000],loss:[0.001132]
Epoch:[46001/100000],loss:[0.003256]
Epoch:[47001/100000],loss:[0.001605]
Epoch:[48001/100000],loss:[0.021894]
Epoch:[49001/100000],loss:[0.017690]
Epoch:[50001/100000],loss:[0.001353]
Epoch:[51001/100000],loss:[0.000576]
Epoch:[52001/100000],loss:[0.000605]
Epoch:[53001/100000],loss:[0.000519]
Epoch:[54001/100000],loss:[0.002327]
Epoch:[55001/100000],loss:[0.000258]
Epoch:[56001/100000],loss:[0.000120]
Epoch:[57001/100000],loss:[0.013316]
Epoch:[58001/100000],loss:[0.000190]
Epoch:[59001/100000],loss:[0.000207]
Epoch:[60001/100000],loss:[0.000129]
Epoch:[61001/100000],loss:[0.000184]
Epoch:[62001/100000],loss:[0.001675]
Epoch:[63001/100000],loss:[0.000082]
Epoch:[64001/100000],loss:[0.002057]
Epoch:[65001/100000],loss:[0.002364]
Epoch:[66001/100000],loss:[0.000843]
Epoch:[67001/100000],loss:[0.000024]
Epoch:[68001/100000],loss:[0.000076]
Epoch:[69001/100000],loss:[0.000077]
Epoch:[70001/100000],loss:[0.000027]
Epoch:[71001/100000],loss:[0.000503]
Epoch:[72001/100000],loss:[0.001702]
Epoch:[73001/100000],loss:[0.001267]
Epoch:[74001/100000],loss:[0.000023]
Epoch:[75001/100000],loss:[0.000828]
Epoch:[76001/100000],loss:[0.000147]
Epoch:[77001/100000],loss:[0.000082]
Epoch:[78001/100000],loss:[0.000008]
Epoch:[79001/100000],loss:[0.000137]
Epoch:[80001/100000],loss:[0.000009]
Epoch:[81001/100000],loss:[0.000667]
Epoch:[82001/100000],loss:[0.000015]
Epoch:[83001/100000],loss:[0.000071]
Epoch:[84001/100000],loss:[0.000247]
Epoch:[85001/100000],loss:[0.000227]
Epoch:[86001/100000],loss:[0.000258]
Epoch:[87001/100000],loss:[0.000006]
Epoch:[88001/100000],loss:[0.000004]
Epoch:[89001/100000],loss:[0.000016]
Epoch:[90001/100000],loss:[0.000054]
Epoch:[91001/100000],loss:[0.000018]
Epoch:[92001/100000],loss:[0.000002]
Epoch:[93001/100000],loss:[0.000001]
Epoch:[94001/100000],loss:[0.000003]
Epoch:[95001/100000],loss:[0.000048]
Epoch:[96001/100000],loss:[0.000061]
Epoch:[97001/100000],loss:[0.000076]
Epoch:[98001/100000],loss:[0.000001]
Epoch:[99001/100000],loss:[0.000015]
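The training snippet in the source is cut off, so the test results mentioned below are not shown; a plausible sketch of that evaluation step (my reconstruction, reusing the names defined above) plots the fitted heart against the true one and reads off the learned weights:

# Evaluate on a dense grid and compare against the ground truth
t_test = torch.linspace(-15, 15, 1000)
with torch.no_grad():
    y_pred = model(y_features(t_test)).squeeze()
x_test = 16 * torch.sin(t_test) ** 3
y_true = target(y_features(t_test)).squeeze()
plt.scatter(x_test.numpy(), y_true.numpy(), s=2, label='target')
plt.scatter(x_test.numpy(), y_pred.numpy(), s=2, label='fitted')
plt.legend()
plt.show()
print(model.poly.weight.data)  # should approach tensor([[13., -5., -2., -1.]])
print(model.poly.bias.data)    # should approach tensor([0.])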

Results

From the training log, the loss is essentially stable after roughly 98,000 epochs. Judging from the test results, though, the model still hasn't overfit even at 100,000 epochs... it seems a few more rounds of training are in order.
