PyTorch Gradient Descent and Backpropagation: A Worked Example


This article works through a concrete example of gradient descent and backpropagation in PyTorch. The steps are laid out in detail; follow along to work through the example from start to finish.

Preface:

The purpose of backpropagation is to compute the partial derivative of the cost function C with respect to any weight w or bias b in the network. Once we have these partial derivatives, we update each weight and bias by subtracting the product of a small constant α (the learning rate) and the corresponding partial derivative of the cost function. This is the well-known gradient descent algorithm. Since the partial derivatives point in the direction of steepest ascent, we keep the following in mind as we look at the backpropagation algorithm below:

we take a small step in the opposite direction, the direction of steepest descent, which carries us toward a local minimum of the cost function.
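As a minimal illustration of this update rule (a hypothetical example, not from the original article), the sketch below applies w ← w − α·dC/dw to the one-dimensional cost C(w) = (w − 3)², whose derivative is 2(w − 3):

```python
# Minimal gradient-descent sketch for C(w) = (w - 3)^2 (illustrative example).
# The derivative dC/dw = 2 * (w - 3) points uphill, so we step the other way.

def grad(w):
    return 2.0 * (w - 3.0)

w = 0.0        # initial guess
alpha = 0.1    # learning rate (the constant mentioned above)
for _ in range(100):
    w -= alpha * grad(w)  # step opposite the gradient

print(round(w, 4))  # converges to the minimizer w = 3
```

Each step shrinks the distance to the minimizer by a constant factor (here 0.8), so the iterate settles at the bottom of the parabola.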

The problem:

The task is to fit the data with the quadratic model y = w1·x² + w2·x + w3, minimizing the value of the loss function (MSE).
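To make the setup concrete, here is a small hand check (an illustrative sketch, separate from the article's code): with the initial weights w = [1, 1, 1], the quadratic model and the per-sample squared error give the starting loss on the three data points:

```python
# Hand check of the initial per-sample squared errors (illustrative sketch).
w = [1.0, 1.0, 1.0]                       # initial weights, as in the article
x_data, y_data = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]

def forward(x):
    return w[0] * x**2 + w[1] * x + w[2]  # quadratic model

errors = [(forward(x) - y) ** 2 for x, y in zip(x_data, y_data)]
print(errors)                 # [1.0, 9.0, 49.0]
print(sum(errors) / 3)        # mean squared error, 59/3
```

Note that forward(4) = 16 + 4 + 1 = 21 with these initial weights, which matches the "predict (before training)" value printed in the run below.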

The code is as follows:

import torch
import matplotlib.pyplot as plt
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# weights, all initialized to 1
w = torch.tensor([1.0, 1.0, 1.0])
w.requires_grad = True  # gradients will be computed for w

# forward pass
def forward(x):
    return w[0] * (x ** 2) + w[1] * x + w[2]

# loss for a single sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# training
print('predict (before training)', 4, forward(4).item())
epoch_list = []
w_list = []
loss_list = []
for epoch in range(1000):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()  # backward pass
        print('\tgrad:', x, y, w.grad.data)
        w.data = w.data - 0.01 * w.grad.data  # gradient descent step
        w.grad.data.zero_()  # reset the gradient to zero
    print('progress:', epoch, l.item())
    epoch_list.append(epoch)
    w_list.append(w.data)
    loss_list.append(l.item())
print('predict (after training)', 4, forward(4).item())

# plotting
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()
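The manual update w.data = w.data - 0.01 * w.grad.data, together with the explicit gradient reset, can also be expressed with PyTorch's built-in optimizer. The sketch below is an equivalent training loop using torch.optim.SGD (an alternative formulation for comparison, not part of the original article):

```python
import torch

x_data, y_data = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.01)  # same learning rate as the manual loop

def forward(x):
    return w[0] * x**2 + w[1] * x + w[2]

for epoch in range(100):
    for x, y in zip(x_data, y_data):
        loss = (forward(x) - y) ** 2
        optimizer.zero_grad()  # clear old gradients
        loss.backward()        # compute d(loss)/dw
        optimizer.step()       # w <- w - lr * w.grad

print(forward(4).item())  # approaches the true value 8 as training continues
```

optimizer.step() performs exactly the w − lr·grad update written by hand above, and optimizer.zero_grad() replaces the manual w.grad.data.zero_() call.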

The output is as follows (intermediate epochs omitted for readability):

predict (before training) 4 21.0
	grad: 1.0 2.0 tensor([2., 2., 2.])
	grad: 2.0 4.0 tensor([22.8800, 11.4400, 5.7200])
	grad: 3.0 6.0 tensor([77.0472, 25.6824, 8.5608])
progress: 0 18.321826934814453
	grad: 1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
	grad: 2.0 4.0 tensor([-15.5367, -7.7683, -3.8842])
	grad: 3.0 6.0 tensor([-30.4322, -10.1441, -3.3814])
progress: 1 2.858394145965576
	grad: 1.0 2.0 tensor([0.3451, 0.3451, 0.3451])
	grad: 2.0 4.0 tensor([2.4273, 1.2137, 0.6068])
	grad: 3.0 6.0 tensor([19.4499, 6.4833, 2.1611])
progress: 2 1.1675907373428345
...
	grad: 1.0 2.0 tensor([0.3166, 0.3166, 0.3166])
	grad: 2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
	grad: 3.0 6.0 tensor([1.4308, 0.4769, 0.1590])
progress: 99 0.00631808303296566
predict (after training) 4 8.544171333312988
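The first gradient vectors in the log can be verified by hand. For the squared-error loss L = (ŷ − y)², the chain rule gives ∂L/∂w = 2(ŷ − y)·[x², x, 1]. A plain-Python check (an illustrative sketch, independent of PyTorch's autograd) reproduces the first two logged gradients:

```python
# Manual chain-rule check of the first logged gradients (illustrative sketch).
w = [1.0, 1.0, 1.0]
lr = 0.01

def grad(x, y):
    y_pred = w[0] * x**2 + w[1] * x + w[2]
    g = 2.0 * (y_pred - y)             # dL/dy_pred
    return [g * x**2, g * x, g * 1.0]  # chain rule through the quadratic model

g1 = grad(1.0, 2.0)
print(g1)                                    # [2.0, 2.0, 2.0], matches the log
w = [wi - lr * gi for wi, gi in zip(w, g1)]  # one SGD step
g2 = grad(2.0, 4.0)
print([round(g, 4) for g in g2])             # [22.88, 11.44, 5.72], matches the log
```

This is exactly what l.backward() computes automatically: the per-sample gradient of the squared error with respect to each of the three weights.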

The loss decreases as the number of iterations grows, as shown in the figure below:

As we can see, the prediction at x = 4 is about 8.5, which still differs from the true value 8. The gap can be narrowed by increasing the number of iterations, or by tuning the learning rate, the initial weights, and so on.
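The exact solution of this fit is w = [0, 2, 0] (the three data points lie on y = 2x), so the true value at x = 4 is 8. A plain-Python replica of the training loop (an illustrative sketch that mirrors the PyTorch code with hand-computed gradients) shows the prediction moving toward 8 as the epoch count grows:

```python
def train(epochs, lr=0.01):
    # Same per-sample SGD as the article, with hand-computed gradients.
    w = [1.0, 1.0, 1.0]
    for _ in range(epochs):
        for x, y in zip([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]):
            y_pred = w[0] * x**2 + w[1] * x + w[2]
            g = 2.0 * (y_pred - y)
            w = [w[0] - lr * g * x**2, w[1] - lr * g * x, w[2] - lr * g]
    return w

def predict(w, x):
    return w[0] * x**2 + w[1] * x + w[2]

p100 = predict(train(100), 4.0)    # roughly 8.5, as in the run above
p1000 = predict(train(1000), 4.0)  # noticeably closer to the true value 8
print(p100, p1000)
```

Running ten times as many epochs pulls the prediction substantially closer to 8, consistent with the suggestion above.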

This concludes our walkthrough of gradient descent and backpropagation in PyTorch. To truly absorb the material, run and experiment with the code yourself. For more articles on related topics, follow the industry news channel of 恰卡编程网 (Qiaka Programming Network).

Published 2022-03-09 22:49:43