PyTorch搭建多项式回归模型(三)

当建立了数据的特征和目标集,就可以开始训练多项式回归模型了。在此教程中,我们将搭建一个多项式回归模型,根据公式f(x)=ax^3+bx^2+cx+d进行拟合。

数据预处理

import torch
import numpy as np

# 设置随机种子,保证结果可复现
torch.manual_seed(2021)

# 创建训练数据和测试数据
x_train = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float32)
y_train = np.array([3, 2, 1, 0, 1, 2, 3], dtype=np.float32)
x_train = torch.from_numpy(x_train.reshape(-1, 1))
y_train = torch.from_numpy(y_train.reshape(-1, 1))

我们使用numpy创建了特征集x_train和目标集y_train,并将它们转换成了PyTorch张量。需要将特征集进行reshape操作,使其变成n行1列的张量,以符合PyTorch的输入格式。

定义模型

import torch.nn as nn

class PolynomialRegression(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(PolynomialRegression, self).__init__()
        self.poly = nn.Sequential(
            nn.Linear(input_dim, 3),
            nn.ReLU(),
            nn.Linear(3, 4),
            nn.ReLU(),
            nn.Linear(4, output_dim)
        )

    def forward(self, x):
        out = self.poly(x)
        return out

我们使用PyTorch的nn.Module类创建了PolynomialRegression类,并重写了其构造函数和前向传播函数。模型的每一层都通过nn.Linear()函数创建,前三层分别有3个神经元、4个神经元和1个神经元。其中中间的ReLU激活函数是为了增加非线性效应。

训练模型

# 模型训练
model = PolynomialRegression(input_dim=1, output_dim=1)
criterion = nn.MSELoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)  # 优化函数

num_epochs = 2000
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

在这里,我们设置了2000轮的训练,并以均方误差作为损失函数。通过反向传播和优化函数进行优化,即可训练模型。输出结果为每100次训练后的损失函数值,用于观察模型的训练情况。

模型预测

# 预测模型结果
model.eval()
predict = model(x_train)
predict = predict.data.numpy()

# 数据可视化
import matplotlib.pyplot as plt

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original Data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()

最后,我们使用模型对训练数据进行预测,并将结果可视化,以便更好地理解模型的拟合效果。

示例说明1:将数据进行归一化处理

我们可以在数据预处理步骤中加入数据归一化的操作以提升模型效果。在下列代码中,将原始数据除以5来缩小数据范围,并在模型预测步骤中进行数据反归一化操作。

# 特征集归一化处理
x_train = (x_train - x_train.mean()) / x_train.std()
# 目标集归一化处理
y_train = (y_train - y_train.mean()) / y_train.std()

# ...(省略定义模型,训练模型部分代码)

# 预测模型结果
model.eval()
predict = model(x_train)
predict = predict * y_train.std() + y_train.mean()
predict = predict.data.numpy()

# 数据可视化
import matplotlib.pyplot as plt

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original Data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()

示例说明2:将多项式回归模型更改为二次函数拟合

我们可以将多项式回归中拟合的公式f(x)=ax^3+bx^2+cx+d变更为f(x)=ax^2+bx+c,从而实现二次函数拟合,示例如下:

class QuadraticRegression(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(QuadraticRegression, self).__init__()
        self.poly = nn.Sequential(
            nn.Linear(input_dim, 2),
            nn.ReLU(),
            nn.Linear(2, output_dim)
        )

    def forward(self, x):
        out = self.poly(x)
        return out

# ...(省略数据预处理,训练模型、预测模型、数据可视化等步骤)

# 创建二次函数拟合模型
model = QuadraticRegression(input_dim=1, output_dim=1)

通过这样的修改,我们可以更灵活地拟合不同形状的曲线。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:PyTorch搭建多项式回归模型(三) - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • Unity实现动物识别的示例代码

    下面将为你详细讲解Unity实现动物识别的示例代码攻略。 概述 动物识别是近几年比较火热的技术之一,它主要是利用深度学习技术来识别动物的种类,以便在未来为动物保护和研究提供更多的数据支撑。而Unity是目前比较流行的游戏开发引擎之一,在其基础上可以比较方便地实现动物识别的功能。 环境要求 在实现动物识别的过程中,我们需要安装一些必备的软件和插件,包括: Un…

    人工智能概论 2023年5月25日
    00
  • acrobat pro dc怎么用?adobe acrobat pro dc 2017安装+使用教程

    Acrobat Pro DC是Adobe推出的一款PDF编辑及制作工具,本文将为大家提供一份完整的安装与使用攻略。 安装Acrobat Pro DC 下载Acrobat Pro DC安装程序,可以在Adobe官网或者第三方下载站点进行下载。 双击以启动安装程序。 程序会自动检测你的计算机是否能够承受运行Acrobat Pro DC所需的最低要求,并自动显示在…

    人工智能概览 2023年5月25日
    00
  • 常用的Spring Boot调用外部接口方式实现数据交互

    Spring Boot是一款十分流行的Java框架,使用Spring Boot开发应用程序常遇到的问题之一就是需要调用外部接口实现数据交互。本篇文章将详细讲解常用的Spring Boot调用外部接口方式实现数据交互的完整攻略,主要包括以下几点。 1. 实现数据交互的方式 在前期规划时,我们需要明确如何实现数据交互。通常有以下几种方式。 RestTemplat…

    人工智能概览 2023年5月25日
    00
  • SpringCloud Stream消息驱动实例详解

    SpringCloud Stream消息驱动实例详解 本文将详细介绍Spring Cloud Stream的使用方法,包括如何使用Spring Cloud Stream进行消息驱动、如何构建生产者和消费者,并给出了两个示例说明。 什么是Spring Cloud Stream? Spring Cloud Stream是用于构建消息驱动微服务的框架,提供了一种简…

    人工智能概览 2023年5月25日
    00
  • Python提取频域特征知识点浅析

    请允许我详细讲解 “Python提取频域特征知识点浅析” 的完整攻略。 一、前言 频域特征提取是信号处理中的一个重要步骤,它允许我们将一个时域信号转换到一个频域信号,这样我们就可以通过频率分析获得更多关于信号特征的信息。Python中有很多强大的工具用于频域分析。 二、Python中的频域分析工具 1. NumPy和SciPy NumPy和SciPy是Pyt…

    人工智能概览 2023年5月25日
    00
  • 使用k8tz解决pod内的时区问题(坑的解决)

    当我们在使用 Kubernetes 部署应用时,有时会遇到时区不正确的问题。pod 内部的时区不受主机时区的影响,因此需要在容器内设置正确的时区。本文将介绍如何使用 k8tz 解决这个问题。 准备工作 在开始使用 k8tz 前,需要先为集群中的所有节点安装 tzdata 包,以保证时区信息正确。可以通过以下命令安装: apt-get update &…

    人工智能概览 2023年5月25日
    00
  • node.js博客项目开发手记

    下面我将详细讲解“node.js博客项目开发手记”的完整攻略。该攻略包含项目开发的整个过程,具体步骤如下: 第一步:准备开发环境 首先需要确保本地安装了Node.js环境和npm包管理器,然后在命令行中输入以下命令来创建一个新的博客项目: mkdir my-blog cd my-blog npm init 接下来执行以下命令安装需要的模块: npm inst…

    人工智能概览 2023年5月25日
    00
  • 解决django xadmin主题不显示和只显示bootstrap2的问题

    下面是针对 Django xadmin 主题不显示和只显示 bootstrap2 的问题的完整攻略: 问题描述 在使用 Django xadmin 后台管理系统时,我们可能会遇到以下两个问题: xadmin 主题显示异常:前端页面没有样式,显示非常原始; xadmin 只显示 bootstrap2 样式:页面只显示 bootstrap2 的样式而不是应该的主…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部