PyTorch搭建多项式回归模型(三)

当建立了数据的特征和目标集,就可以开始训练多项式回归模型了。在此教程中,我们将搭建一个多项式回归模型,根据公式f(x)=ax^3+bx^2+cx+d进行拟合。

数据预处理

import torch
import numpy as np

# 设置随机种子,保证结果可复现
torch.manual_seed(2021)

# 创建训练数据和测试数据
x_train = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float32)
y_train = np.array([3, 2, 1, 0, 1, 2, 3], dtype=np.float32)
x_train = torch.from_numpy(x_train.reshape(-1, 1))
y_train = torch.from_numpy(y_train.reshape(-1, 1))

我们使用numpy创建了特征集x_train和目标集y_train,并将它们转换成了PyTorch张量。需要将特征集进行reshape操作,使其变成n行1列的张量,以符合PyTorch的输入格式。

定义模型

import torch.nn as nn

class PolynomialRegression(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(PolynomialRegression, self).__init__()
        self.poly = nn.Sequential(
            nn.Linear(input_dim, 3),
            nn.ReLU(),
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, output_dim)
        )

    def forward(self, x):
        out = self.poly(x)
        return out

我们使用PyTorch的nn.Module类创建了PolynomialRegression类,并重写了其构造函数和前向传播函数。模型的每一层都通过nn.Linear()函数创建,前三层分别有3个神经元、4个神经元和1个神经元。其中中间的ReLU激活函数是为了增加非线性效应。

训练模型

# 模型训练
model = PolynomialRegression(input_dim=1, output_dim=1)
criterion = nn.MSELoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)  # 优化函数

num_epochs = 2000
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

在这里,我们设置了2000轮的训练,并以均方误差作为损失函数。通过反向传播和优化函数进行优化,即可训练模型。输出结果为每100次训练后的损失函数值,用于观察模型的训练情况。

模型预测

# 预测模型结果
model.eval()
predict = model(x_train)
predict = predict.data.numpy()

# 数据可视化
import matplotlib.pyplot as plt

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original Data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()

最后,我们使用模型对训练数据进行预测,并将结果可视化,以便更好地理解模型的拟合效果。

示例说明1:将数据进行归一化处理

我们可以在数据预处理步骤中加入数据归一化的操作以提升模型效果。在下列代码中,将原始数据除以5来缩小数据范围,并在模型预测步骤中进行数据反归一化操作。

# 特征集归一化处理
x_train = (x_train - x_train.mean()) / x_train.std()
# 目标集归一化处理
y_train = (y_train - y_train.mean()) / y_train.std()

# ...(省略定义模型,训练模型部分代码)

# 预测模型结果
model.eval()
predict = model(x_train)
predict = predict * y_train.std() + y_train.mean()
predict = predict.data.numpy()

# 数据可视化
import matplotlib.pyplot as plt

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original Data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()

示例说明2:将多项式回归模型更改为二次函数拟合

我们可以将多项式回归中拟合的公式f(x)=ax^3+bx^2+cx+d变更为f(x)=ax^2+bx+c,从而实现二次函数拟合,示例如下:

class QuadraticRegression(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(QuadraticRegression, self).__init__()
        self.poly = nn.Sequential(
            nn.Linear(input_dim, 2),
            nn.ReLU(),
            nn.Linear(2, output_dim)
        )

    def forward(self, x):
        out = self.poly(x)
        return out

# ...(省略数据预处理,训练模型、预测模型、数据可视化等步骤)

# 创建二次函数拟合模型
model = QuadraticRegression(input_dim=1, output_dim=1)

通过这样的修改,我们可以更灵活地拟合不同形状的曲线。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch搭建多项式回归模型(三) - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Vue项目History模式404问题解决方法

    下面是“Vue项目History模式404问题解决方法”的完整攻略: 问题背景 在Vue项目中,我们可以选择使用History模式路由,以去除URL中的#符号。但是,在使用History模式路由时,如果浏览器直接访问某个路由或者刷新当前页面,就会出现404错误。 问题原因 在使用History模式路由时,当用户在浏览器中输入某个路由地址,或者在浏览器中刷新页…

    人工智能概览 2023年5月25日
    00
  • keepalived+nginx高可用实现方法示例

    Keepalived + Nginx 高可用实现方法 在高可用架构中,Keepalived和Nginx是两个非常常用的组件。Keepalived是一个基于VRRP协议实现高可用的工具,用于将一组服务器(通常是两个或多个)作为一个虚拟的负载均衡器来使用。而Nginx则是一款高性能的Web服务器软件,可以将多个Web服务器上的流量通过反向代理的方式分发到不同的W…

    人工智能概览 2023年5月25日
    00
  • 对pytorch中不定长序列补齐的操作

    下面是对PyTorch中不定长序列补齐的操作的完整攻略。 1. 序列补齐的操作 在处理序列数据时,由于序列长度不一,常常需要对长度不足的序列进行补齐操作。补齐操作指的是将长度小于预定长度的序列,通过在序列中添加一些特殊字符(比如PAD)或者重复序列元素等方式,将其长度补齐至预定长度。补齐操作可以使得序列数据可以被组成batch,在训练神经网络时方便使用。 P…

    人工智能概论 2023年5月25日
    00
  • Qt实现文本编辑器(二)

    下面我会详细讲解“Qt实现文本编辑器(二)”的完整攻略。该攻略主要分为以下几个部分: 设置界面 定义窗口类 定义文本编辑器类 定义菜单栏、工具栏 实现快捷键功能 实现查找、替换功能 实现撤销、重做功能 实现文件操作功能 其中,步骤二、三、八为主要内容。下面我会对这几个部分逐一进行讲解。 1. 设置界面 在工具->Qt Design页面中,设置文本编辑器…

    人工智能概览 2023年5月25日
    00
  • freebsd6.2 nginx+php+mysql+zend系统优化防止ddos攻击

    针对 “freebsd6.2 nginx+php+mysql+zend系统优化防止ddos攻击”的完整攻略,我将会详细讲解该过程,并给出两个示例说明。 一、系统优化 1.升级操作系统和软件包: FreeBSD 6.2 已经过时,其内核版本较老,安全性和性能都不如现在的操作系统。所以,我们需要将操作系统更新到较新的版本,并且要保持更新操作系统和软件包,以便获得…

    人工智能概览 2023年5月25日
    00
  • centos7如何设置密码规则?centos7设置密码规则的方法

    下面是详细讲解“centos7如何设置密码规则?centos7设置密码规则的方法”的完整攻略。 设置密码规则 CentOS 7使用强密码来保护用户的帐户。在CentOS 7中,通过修改PAM(Pluggable Authentication Modules,可插入身份验证模块)配置文件,可以设置密码规则来确保用户密码的强度。下面是设置密码规则的步骤: 步骤1…

    人工智能概览 2023年5月25日
    00
  • 详解django.contirb.auth-认证

    关于Django认证模块django.contrib.auth的详细讲解,可以分为以下几个部分进行阐述: 1. 概述 Django中的认证模块django.contrib.auth提供了一系列的身份验证和授权功能,它通常用于管理用户和组,以及用户认证、注册、登录和注销等过程。其中,认证API提供了基于用户名和密码、E-mail和密码、OAuth等多种认证方式…

    人工智能概览 2023年5月25日
    00
  • Python中文分词库jieba,pkusegwg性能准确度比较

    Python中文分词库jieba,pkuseg比较 在Python中,中文分词一直是一个非常重要的任务。而jieba和pkuseg是两个比较常用的中文分词工具。在本文中,我们将对这两个工具进行比较,包括性能、准确度等因素。 jieba 首先介绍的是jieba,它是一个中文分词工具包,功能强大,使用方便,因此被广泛使用。这是非常成熟的一个工具,经过多年的开发和…

    人工智能概论 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部