Python机器学习之逻辑回归

Python机器学习之逻辑回归

逻辑回归(Logistic Regression)是一种常用的分类算法,它可以用于二分类和多分类问题。在这篇文章中,我们将介绍如何使用Python实现逻辑回归算法,并详细讲解实现原理。

实现原理

逻辑回归是一种基于概率的分类算法,它的目标是根据输入特征预测样本属于哪个类别。逻辑回归的实现原理如下:

  1. 首先定义一个逻辑回归模型,包含权重向量和偏置项。
  2. 然后定义一个损失函数,用于衡量模型预测结果与真实结果之间的差距。
  3. 接着使用梯度下降算法来最小化损失函数,从而得到最优的权重向量和偏置项。
  4. 最后使用训练好的模型来预测新的样本类别。

Python实现

下面是一个使用Python实现逻辑回归算法的示例:

import numpy as np

class LogisticRegression:
    def __init__(self, lr=0.01, num_iter=100000, fit_intercept=True, verbose=False):
        self.lr = lr
        self.num_iter = num_iter
        self.fit_intercept = fit_intercept
        self.verbose = verbose

    def __add_intercept(self, X):
        intercept = np.ones((X.shape[0], 1))
        return np.concatenate((intercept, X), axis=1)

    def __sigmoid(self, z):
        return 1 / (1 + np.exp(-z))

    def __loss(self, h, y):
        return (-y * np.log(h) - (1 - y) * np.log(1 - h)).mean()

    def fit(self, X, y):
        if self.fit_intercept:
            X = self.__add_intercept(X)

        self.theta = np.zeros(X.shape[1])

        for i in range(self.num_iter):
            z = np.dot(X, self.theta)
            h = self.__sigmoid(z)
            gradient = np.dot(X.T, (h - y)) / y.size
            self.theta -= self.lr * gradient

            if self.verbose and i % 10000 == 0:
                z = np.dot(X, self.theta)
                h = self.__sigmoid(z)
                print(f'loss: {self.__loss(h, y)} \t')

    def predict_prob(self, X):
        if self.fit_intercept:
            X = self.__add_intercept(X)

        return self.__sigmoid(np.dot(X, self.theta))

    def predict(self, X, threshold=0.5):
        return self.predict_prob(X) >= threshold

在这个示例中,我们首先定义了一个名为LogisticRegression的类,用于实现逻辑回归算法。在LogisticRegression类中,我们首先定义了一个__add_intercept函数,用于添加截距项。然后定义了一个__sigmoid函数,用于计算sigmoid函数的值。接着定义了一个__loss函数,用于计算损失函数的值。最后定义了一个fit函数,用于训练模型;一个predict_prob函数,用于预测样本属于正类的概率;一个predict函数,用于预测样本类别。

示例1:使用逻辑回归进行二分类

在这个示例中,我们将使用逻辑回归进行二分类。我们首先生成一个二分类数据集,然后使用逻辑回归模型进行训练,并使用测试集评估模型的性能。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
X = np.random.randn(200, 2)
y = np.array([0 if x[0] + x[1] 0 else 1 for x in X])

model = LogisticRegression(lr=0.1, num_iter=300000)
model.fit(X, y)

plt.scatter(X[:, 0], X[:, 1], c=y)
x1_min, x1_max = X[:, 0].min(), X[:, 0].max(),
x2_min, x2_max = X[:, 1].min(), X[:, 1].max(),
xx1, xx2 = np.meshgrid(np.linspace(x1_min, x1_max), np.linspace(x2_min, x2_max))
grid = np.c_[xx1.ravel(), xx2.ravel()]
probs = model.predict_prob(grid).reshape(xx1.shape)
plt.contour(xx1, xx2, probs, [0.5], linewidths=1, colors='red')
plt.show()

在这个示例中,我们首先使用numpy模块生成一个二分类数据集。然后使用逻辑回归模型进行训练,并使用测试集评估模型的性能。最后使用matplotlib模块绘制出决策边界。

示例2:使用逻辑回归进行多分类

在这个示例中,我们将使用逻辑回归进行多分类。我们首先加载一个手写数字数据集,然后使用逻辑回归模型进行训练,并使用测试集评估模型的性能。

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.3, random_state=0)

model = LogisticRegression(lr=0.1, num_iter=300000)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Accuracy: {accuracy}")

fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.3, wspace=0.3)

for i, ax in enumerate(axes.flat):
    ax.imshow(X_test[i].reshape(8, 8), cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(y_pred[i]), transform=ax.transAxes, color='green' if y_pred[i] == y_test[i] else 'red')
    ax.set_xticks([])
    ax.set_yticks([])

plt.show()

在这个示例中,我们首先使用sklearn.datasets模块加载一个手写数字数据集。然后使用逻辑回归模型进行训练,并使用测试集评估模型的性能。最后使用matplotlib模块绘制出预测结果。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python机器学习之逻辑回归 - Python技术站

(0)
上一篇 2023年5月14日
下一篇 2023年5月14日

相关文章

  • 对Python3.x版本print函数左右对齐详解

    对Python3.x版本print函数左右对齐详解 在Python3.x版本中,print函数有多种对齐方式,可以对字符串进行左对齐、右对齐和居中对齐。下面逐一介绍这三种对齐方式以及如何使用它们。 左对齐 采用左对齐方式可以将字符串左对齐,并在字符串右侧填充空格来实现对齐。左对齐采用“<”进行标识。 string = ‘Python’ print(‘{…

    python 2023年6月5日
    00
  • Python 中的pygame安装与配置教程详解

    下面是关于“Python中的pygame安装与配置教程详解”的完整攻略。 1. 安装Python Python是一种编程语言,pygame是Python游戏开发库,因此我们需要先安装Python才能进行pygame的配置。可以从Python的官网下载对应版本进行安装,在安装过程中注意勾选“Add Python to PATH”选项。 2. 安装pygame …

    python 2023年5月14日
    00
  • Python在字典中查找元素的3种方式

    Python中常见的数据结构之一是字典。字典是由键和值组成的无序集合,其中的键是唯一的。我们需要在字典中查找元素时,可以使用以下三种方式。 使用in关键字查找 在Python中,可以使用in关键字来检查字典中是否存在某个键,如果存在则返回True,否则返回False。 # 示例1 user_dict = {‘name’: ‘Alice’, ‘age’: 20…

    python 2023年5月13日
    00
  • Windows上配置Emacs来开发Python及用Python扩展Emacs

    Windows上配置Emacs来开发Python及用Python扩展Emacs 在Windows上配置Emacs来开发Python需要进行以下步骤: 步骤1:安装Emacs 可以从官网下载最新版本的Emacs: https://www.gnu.org/software/emacs/download.html#windows 步骤2:安装Python 可以从P…

    python 2023年6月3日
    00
  • python每天定时运行某程序代码

    以下是实现Python定时运行程序代码的完整攻略: 1. 安装第三方模块 我们可以使用Python的第三方模块schedule来完成定时运行某程序代码的功能,需要先安装该模块。可以通过使用pip这个包管理器来完成安装,具体命令如下: pip install schedule 2. 导入模块 接下来,我们需要将schedule模块导入到Python源代码中,可…

    python 2023年5月19日
    00
  • Python网络编程 Python套接字编程

    Python网络编程 Python套接字编程攻略 1. 网络编程基础 网络编程基础涉及到的主要概念有IP地址、端口、协议、套接字等。 IP地址:Internet Protocol Address,即网络协议地址。它是用于标识互联网上设备的地址。IP地址分为IPv4和IPv6两种。 端口:通过IP地址,可以找到对应设备上的进程,而端口则是用于标识这些进程的,相…

    python 2023年5月19日
    00
  • PyQt5入门之基于QListWidget版本实现图片缩略图列表功能

    我会详细讲解“PyQt5入门之基于QListWidget版本实现图片缩略图列表功能”的完整攻略。 概述 QListWidget是Qt中的列表控件,它能够展示列表式的数据,并支持图标展示。本篇攻略将会介绍如何基于QListWidget实现图片缩略图列表功能。 实现步骤 导入必要的模块 from PyQt5.QtGui import QIcon, QPixmap…

    python 2023年5月19日
    00
  • 一文带你玩转MySQL获取时间和格式转换各类操作方法详解

    一文带你玩转MySQL获取时间和格式转换各类操作方法详解 获取当前日期/时间 获取当前日期 获取当前日期可以使用函数CURDATE(),该函数返回的是当前日期的字符串。下面是一个示例: SELECT CURDATE(); 输出如下所示: CURDATE() 2021-08-03 获取当前时间 获取当前时间可以使用函数CURTIME(),该函数返回的是当前时间…

    python 2023年6月2日
    00
合作推广
合作推广
分享本页
返回顶部