图文详解梯度下降算法的原理及Python实现

yizhihongxing

图文详解梯度下降算法的原理及Python实现

梯度下降算法是机器学习中最常用的优化算法之一,它的主要作用是通过迭代的方式,不断调整模型参数使得模型的损失函数最小化。本文将详细讲解梯度下降算法的原理及Python实现,以及两个示例说明。

梯度下降算法原理

梯度下降算法的基本思想是通过不断调整模型参数,使得模型的损失函数最小化。具体来说,算法的步骤如下:

  1. 随机初始化模型参数;
  2. 计算模型的损失函数;
  3. 计算损失函数对模型参数的梯度;
  4. 根据梯度调整模型参数;
  5. 重复步骤2-4,直到损失函数收敛或达到最大迭代数。

其中,步骤3是梯度下降算法的核心,它的目的是计算损失函数对模型参数的梯度,以根据梯度调整模型参数。具体来说,对于一个模型参数 $\theta_i$,它的梯度可以表示为:

$$\frac{\partial J(\theta)}{\partial \theta_i}$$

其中,$J(\theta)$ 表示模型的损失函数,$\theta$ 表示模型的参数向量。

在计算梯度时,我们可以使用链式法则将损失函数的梯度表示为各个参数的偏导数之积。具体来说,对于一个多元函数 $f(x_1, x_2, ..., x_n)$,它的偏导数可以表示为:

$$\frac{\partial f}{\partial x_i} = \frac{\partial f}{\partial x_{i+1}} \cdot \frac{\partial x_{i+1}}{\partial x_i}$$

通过不断使用链式法则,我们可以将损失函数的梯度表示为各个参数的偏导数之积,从而计算出模型参数的梯度。

梯度下降算法Python实现

在Python中,我们可以使用NumPy库实现梯度下降算法。下面是一个简单的示例代码,用于对一个线性回归模型进行训练。

import numpy as np

# 定义模型
def model(X, theta):
    return X.dot(theta)

# 定义损失函数
def cost_function(X, y, theta):
    m = len(y)
    J = np.sum((model(X, theta) - y) ** 2) / (2 * m)
    return J

# 定义梯度下降算法
def gradient_descent(X, y, theta, alpha,_iters):
    m = len(y)
    J_history = np.zeros(num_iters)

    for i in range(num_iters):
        theta = theta - alpha * (X.T.dot(model(X, theta) - y) / m)
        J_history[i] = cost_function(X, y, theta)

    return theta, J_history

# 加载数据
data = np.loadtxt('data.csv', delimiter=',')
X = data[:, :-1]
y = data[:, -1]

# 在X前面添加一1,以便计算截距
X = np.hstack((np.ones((len(y), 1)), X))

# 随机初始化模型参数
theta = np.random.randn(X.shape[1])

# 设置学习率和迭代次数
alpha = 0.01
num_iters = 1000

# 运行梯度下降算法
theta, J_history = gradient_descent(X, y, theta, alpha, num_iters)

# 输出模型参数和损失函数的历史记录
print('theta:', theta)
print('J_history:', J_history)

在这个示例中,我们首先定义了一个线性回归模型和损失函数。然后,我们使用NumPy库加载数据,并在数据前面添加一列1,以便计算截距。接下来,我们随机初始化模型参数,并设置学习率和迭代次数。最后,我们使用定义的梯度下降算法对模型进行训练,并输出模型参数和损失函数的历史记录。

示例1:线性回归

在这个示例中,我们将使用梯度下降算法对一个线性回归模型进行训练,以便预测房价。

import numpy as np
import matplotlib.pyplot as plt

# 加载数据
data = np.loadtxt('data.csv', delimiter=',')
X = data[:, :-1]
y = data[:, -1]

# 在X前面添加一列1,以便计算截距
X = np.hstack((np.ones((len(y), 1)), X))

# 随机初始化模型参数
theta = np.random.randn(X.shape[1])

# 设置学习率和迭代次数
alpha = 0.01
num_iters = 1000

# 运行梯度下降算法
m = len(y)
J_history = np.zeros(num_iters)

for i in range(num_iters):
    theta = theta - alpha * (X.T.dot(X.dot(theta) - y) / m)
    J_history[i] = np.sum((X.dot(theta) - y) ** 2) / (2 * m)

# 输出模型参数和损失函数的历史记录
print('theta:', theta)
print('J_history:', J_history)

# 绘制损失函数的历史记录
plt.plot(J_history)
plt.xlabel('Iterations')
plt.ylabel('Cost')
plt.show()

在这个示例中,我们首先使用NumPy库加载数据,并在数据前面添加一列1,以便计算截距。接下来,我们随机初始化模型参数,并设置学习率和迭代次数。然后,我们使用梯度下降算法对模型进行训练,并输出模型参数和损失函数的历史记录。最后,我们使用Matplotlib库绘制损失函数的历史记录。

示例2:逻辑回归

在这个示例中,我们将使用梯度下降算法对一个逻辑回归模型进行训练,以便预测肿瘤为恶性。

import numpy as np
import matplotlib.pyplot as plt

# 加载数据
data = np.loadtxt('data.csv', delimiter=',')
X = data[:, :-1]
y data[:, -1]

# 在X前面添加一列1,以便计算截距
X = np.hstack((np.ones((len(y), 1)), X))

# 随机初始化模型参数
theta = np.random.randn(X.shape[1])

# 设置学习率和迭代次数
alpha = 0.01
num_iters = 1000

# 定义sigmoid函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 定义损失函数
def cost_function(X, y, theta):
    m = len(y)
    h = sigmoid(X.dot(theta))
    J = -np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)) / m
    return J

# 定义梯度下降算法
def gradient_descent(X, y, theta, alpha, num_iters):
    m = len(y)
    J_history = np.zeros(num_iters)

    for i in range(num_iters):
        h = sigmoid(X.dot(theta))
        theta = theta - alpha * (X.T.dot(h - y) / m)
        J_history[i] = cost_function(X, y, theta)

    return theta, J_history

# 运行梯度下降算法
theta, J_history = gradient_descent(X, y, theta, alpha, num_iters)

# 输出模型参数和损失函数的历史记录
print('theta:', theta)
print('J_history:', J_history)

# 绘制损失函数的历史
plt.plot(J_history)
plt.xlabel('Iterations')
plt.ylabel('Cost')
plt.show()

在这个示例中,我们首先使用NumPy库加载数据,并在数据前面添加一列1,以便计算截距。接下来,我们随机初始化模型参数,并设置学习率和迭代数。然后,我们定义了sigmoid函数和损失函数,并使用梯度下降算法对模型进行训练。最后,我们输出模型参数损失函数的历史记录,并使用Matplotlib库绘制损失函数的历史记录。

总结

本文详细讲解了梯度下降算法的原理及Python实现,以及两个示例说明。梯度下降算是机器学习中最常用的优化算法之一,它的主要作用是通过迭代的方式,不断调整模型参数,使得模型的损失函数最小化。在实际应用中,我们可以根据具体的需求选择不同的损失函数和学习率,并结合其他优化算法进行综合处理。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:图文详解梯度下降算法的原理及Python实现 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • django框架模板语言使用方法详解

    Django框架模板语言使用方法详解 Django框架的模板语言(Template Language)是一种用于在HTML模板中嵌入动态内容的语言。本文将介绍Django模板语言的基本语法和常用标签,并提供两个示例。 模板语言的基本语法 Django模板语言使用双大括号({{}})来标识动态内容。在模板中,可以使用变量、标签和过滤器来生成动态内容。 以下是一…

    python 2023年5月15日
    00
  • python中readline判断文件读取结束的方法

    在Python中,我们可以使用readline()方法来一行一行地读取文件内容。但是,如何判断文件读取结束呢?我们可以通过以下几种方式来判断。 方法一:使用while循环和readline()方法 我们可以通过在while循环中使用readline()方法来读取文件内容,每次读取一行,当readline()返回的为空字符串时,表示已经到达文件的结尾,此时应该…

    python 2023年6月3日
    00
  • Python包管理工具pip用法详解

    Python包管理工具pip用法详解 什么是pip pip是Python语言的一个常用包管理工具,它可以用来安装、升级、卸载Python包。 安装pip 如果你使用的是Python 2.7.9及以上版本或Python 3.4及以上版本,pip已经默认安装了。如果没有安装pip,你可以通过以下命令安装: sudo apt install python-pip …

    python 2023年5月18日
    00
  • python集合能干吗

    Python集合是一种无序、不重复的数据类型,可以用于存储各种类型的值,例如数字、字符串和元组等。集合非常适合用于数据去重、判断成员关系、求交集和并集等场景。 数据去重 集合最常用的功能之一就是去重。我们可以将一组数据放到一个集合中,自动去除重复的元素。使用方法如下: # 创建一个列表,包含重复元素 nums = [1, 2, 3, 2, 4, 5, 1] …

    python 2023年5月13日
    00
  • python用win32gui遍历窗口并设置窗口位置的方法

    下面是详细讲解如何使用win32gui模块来遍历窗口并设置窗口位置的方法。 1. 安装Python和win32 在使用win32gui模块前,需要先安装Python和win32。Python可以从官方下载页面下载(https://www.python.org/downloads/),安装时记得选中“Add Python to PATH”选项。 安装Pytho…

    python 2023年6月13日
    00
  • Numpy对数组的操作:创建、变形(升降维等)、计算、取值、复制、分割、合并

    当然,我很乐意为您提供“Numpy对数组的操作”的完整攻略。以下是详细步骤和示例。 Numpy对数组的操作 Numpy是中用于科学计算的一个重要库,它提供高的数组操作和数学函数。在Numpy中,数组是一个重要的数据结构,因此对数组的操作也是非常重要。下我们将介绍Numpy对数组的操作,包括创建、变形(升降维等)、计算、取值、复制、分割、合等。 1 创建数组 …

    python 2023年5月13日
    00
  • Python urllib库的使用指南详解

    Python urllib库的使用指南详解 什么是Python urllib库? Python urllib库是Python标准库中用于和网站进行交互的工具包。它可以用于发送HTTP请求,从服务器获取响应,并对响应进行处理。Python urllib库包含4个模块:urllib.request、urllib.response、urllib.parse和url…

    python 2023年6月3日
    00
  • python线程中同步锁详解

    下面是关于”Python线程中同步锁详解”的完整攻略: 什么是同步锁? 同步锁是用于多线程编程的重要工具之一,它可以确保多个线程不会同时访问共享资源,从而避免数据竞争和死锁等问题的发生。 在Python中,我们可以使用threading模块提供的Lock, RLock和Semaphore等类来实现同步锁。 Lock类详解 Lock类的基本用法 Lock类是普…

    python 2023年5月19日
    00
合作推广
合作推广
分享本页
返回顶部