前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch 拟合多项式的例子

Pytorch 拟合多项式的例子

作者头像
王云峰
发布2023-10-23 10:27:09
2590
发布2023-10-23 10:27:09
举报
文章被收录于专栏:Yunfeng's Simple Blog

1. 概述

Pytorch包含了Linear层,可以用来拟合y = w * x + b 形式的函数,其中wbias就是Linear层的weights和bias。这里写个拟合一次多项式的简单demo,作为一个小实验。

2. 拟合一次多项式

采用下面的代码,我们设计了一个包含一个线性层的网络,通过给它feed随机构造的数据(y = 1.233 * x + 0.988),结合梯度下降算法和MSE loss惩罚函数,让它学习数据的构造参数:

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        x = self.linear(x)
        return x


def run():
    torch.manual_seed(1024)

    model = Model()
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    w = 1.233
    b = 0.988
	num_iteration = 5000
    for i in range(num_iteration):
        optimizer.zero_grad()
        x = torch.rand(1)
        y = w * x + b
        pred = model(x)

        loss = F.mse_loss(y, pred)
        loss.backward()
        optimizer.step()

    for name, param in model.named_parameters():
        print(f"{name}={param.data.numpy().squeeze():.3f}")


if __name__ == "__main__":
    run()

运行这个脚本的输出结果如下:

代码语言:javascript
复制
linear.weight=1.233
linear.bias=0.988

可以看到,经过5000次的迭代,网络能成功地学习到数据构造过程中的w和b参数, 这个小网络现在可以用来替代线性回归机器学习算法了!

如果迭代周期太小则可能收敛不到我们预设的参数,可以手动修改迭代次数num_iteration为2000查看结果。

3. 如果重复Linear层会发生什么?

如果我们把同一个linear层重复执行两次,会有什么结果呢?也就是网络定义修改如下:

代码语言:javascript
复制
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        x = self.linear(x)
        x = self.linear(x)
        return x

这里调用了两次同一个linear层,因此相当于 y = w * ( w * x + b) + b,也就是一次forward更新两次参数,也可以理解成两个共享参数的线性层。 完整的示例代码如下:

代码语言:javascript
复制
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        x = self.linear(x)
        x = self.linear(x)
        return x


def run():
    torch.manual_seed(1024)

    model = Model()
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    w = 1.233
    b = 0.988
    num_iteration = 5000
    for i in range(num_iteration):
        optimizer.zero_grad()
        x = torch.rand(1)
        y = w * x + b
        pred = model(x)

        loss = F.mse_loss(y, pred)
        loss.backward()
        optimizer.step()

    for name, param in model.named_parameters():
        print(f"{name}={param.data.numpy().squeeze():.3f}")


if __name__ == "__main__":
    run()

同样的,通过我们构造 y = 1.233 * x + 0.998的数据,带入 y = w * ( w * x + b) + b,可以得到一组解 w=1.110, b=0.468,这与我们网络运行得到的结果是一致的:

代码语言:javascript
复制
linear.weight=1.110
linear.bias=0.468

同时也有一个问题:为什么没得到w为负数的另一组解呢?这是因为我这里为了保证复现性,手动设置了随机数种子为1024,设置为别的值应该可以得到另一组参数,欢迎尝试。

本文参与?腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2022-06-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客?前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与?腾讯云自媒体同步曝光计划? ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1. 概述
  • 2. 拟合一次多项式
  • 3. 如果重复Linear层会发生什么?
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
http://www.vxiaotou.com