解析python实现Lasso回归

最初在进行Lasso回归时,一般会通过sklearn库进行实现。但是,了解其内部的Python实现对于掌握Lasso回归建模和算法的原理和特性非常有帮助。下面给出了一个Python实现的Lasso回归建模过程。

步骤一:加载数据

import numpy as np

def load_data():
    # 加载数据集
    data = np.loadtxt("data.txt", delimiter=",")

    # 将数据拆分为特征和标签
    X = data[:, :-1] # 特征
    y = data[:, -1] # 标签
    return X, y

步骤二:搭建模型

class LassoRegressor:
    def __init__(self, alpha=1, l1_ratio=0.5, max_iter=1000, tol=1e-3):
        self.alpha = alpha # l1正则化系数
        self.l1_ratio = l1_ratio # L1/L2正则化比率
        self.max_iter = max_iter # 最大迭代次数
        self.tol = tol # 损失函数变化量的阈值

    def _soft_threshold(self, x, lambda_):
        if x > 0 and lambda_ < abs(x):
            return x - lambda_
        elif x < 0 and lambda_ < abs(x):
            return x + lambda_
        else:
            return 0

    def fit(self, X, y):
        self.n_samples, self.n_features = X.shape

        # 初始化theta
        self.theta = np.zeros(self.n_features)

        # 初始化损失函数二范数
        self.cost_his = []

        for i in range(self.max_iter):
            # 计算预测结果并计算残差
            y_pred = X.dot(self.theta)
            residuals = y - y_pred

            # 更新theta
            for j in range(self.n_features):
                X_j = X[:, j]
                soft_threshold = self._soft_threshold(X_j.T.dot(residuals), self.alpha*self.l1_ratio)
                self.theta[j] = soft_threshold / (1 + self.alpha*(1-self.l1_ratio))

            # 计算损失函数
            cost = np.sum((y_pred - y)**2) + self.alpha*np.sum(np.abs(self.theta))
            self.cost_his.append(cost)

            # 判断损失函数是否收敛
            if len(self.cost_his) > 1:
                if abs(self.cost_his[-1] - self.cost_his[-2]) < self.tol:
                    break

        return self

    def predict(self, X):
        return X.dot(self.theta)

代码中,我们定义了一个 LassoRegressor 类,包含了以下几个方法:
- __init__(self, alpha, l1_ratio, max_iter, tol): 初始化方法,包括正则化系数、L1/L2比例、最大迭代次数和损失函数变化门限。
- _soft_threshold(self, x, lambda_): 辅助计算函数,主要用于计算软阈值。
- fit(self, X, y): 训练方法,拟合数据集,通过更新参数theta来减少损失函数。
- predict(self, X): 预测方法,通过theta和特征集X来预测标签。

步骤三:测试模型

通过下面这段代码来测试上面的模型:

from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# 生成样本数据
X, y = make_regression(n_samples=100, n_features=10, random_state=0, noise=0.5)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 建立sklearn的Lasso模型并拟合
lasso_sklearn = Lasso(alpha=1.0)
lasso_sklearn.fit(X_train, y_train)
y_pred_sklearn = lasso_sklearn.predict(X_test)

# 建立自己的Lasso模型并拟合
lasso_self = LassoRegressor(alpha=1.0)
lasso_self.fit(X_train, y_train)
y_pred_self = lasso_self.predict(X_test)

# 对比两者效果
print("sklearn mean squared error: ", mean_squared_error(y_test, y_pred_sklearn))
print("self mean squared error: ", mean_squared_error(y_test, y_pred_self))

# 绘制损失函数曲线
plt.plot(range(len(lasso_self.cost_his)), lasso_self.cost_his)
plt.title("Lasso Regression Cost History")
plt.xlabel("Number of Iterations")
plt.ylabel("Cost")
plt.show()

上述代码中,我们首先生成了一个包含100个样本数据和10个特征的人工数据集,然后将其分为训练集和测试集。接着,我们评估了最终的预测结果,并且输出了其均方误差。最后,我们绘制损失函数的历史曲线以及与sklearn库结果进行对比,以验证我们上述模型的准确性和可信度。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:解析python实现Lasso回归 - Python技术站

(0)
上一篇 2023年5月18日
下一篇 2023年5月18日

相关文章

  • Effective HPA:预测未来的弹性伸缩产品

    作者 胡启明,腾讯云专家工程师,专注 Kubernetes、降本增效等云原生领域,Crane 核心开发工程师,现负责成本优化开源项目 Crane 开源治理和弹性能力落地工作。 余宇飞,腾讯云专家工程师,专注云原生可观测性、成本优化等领域,Crane 核心开发者,现负责 Crane 资源预测、推荐落地、运营平台建设等相关工作。 田奇,腾讯高级工程师,专注分布式…

    2023年4月9日
    00
  • python分析inkscape路径数据方案简单介绍

    Python分析Inkscape路径数据方案简单介绍 什么是Inkscape路径数据? 在Inkscape中,路径是由节点和线段组成的,其中节点用于确定路径方向和曲率,线段用于连接节点并绘制路径。路径数据是描述路径的元素、属性和值的集合。 路径数据通常使用SVG(Scalable Vector Graphics)语法进行描述,SVG是一种可缩放的矢量图形语言…

    云计算 2023年5月18日
    00
  • Python模块域名dnspython解析

    Python模块dnspython是一个轻量级的DNS解析库,旨在提供DNS解析和相关工具的Python编程接口,常用于网络编程、域名解析和DNS服务器查询等应用场合。但是,对于初学者来说,可能会感到有些困难。下面我们将详细讲解“Python模块域名dnspython解析”的完整攻略。 安装dnspython模块 首先需要安装dnspython模块,在命令行…

    云计算 2023年5月18日
    00
  • 云计算OpenStack环境搭建(4)

    准备工作:   准备3台机器,确保yum源是可用的,分别为控制节点(192.168.11.3)、计算节点(192.168.11.4)和存储节点(192.168.11.5) 控制节点:OpenStack日常的管理服务都运行的节点(OpenStack packages、mariadb、rabbitmq、memcached、keystone、glance) 计算节…

    云计算 2023年4月11日
    00
  • 战火硝烟中的云计算 (云计算今生来世2)

    Google, 微软,亚马逊和IBM是几个云计算领域里的主要玩家。从出身来看,Google 是广告商,IBM和微软是软件公司而Amazon 是电子商务公司,除了IBM和微软,似乎其他公司本来与软件服务相去甚远。然而随着云计算的普及,领域之间的界限将日益模糊,各个厂商将进入其他领域进行竞争,谁能吸引更多的客户到自己的平台上,谁就能在竞争中立于不败之地。 Goo…

    云计算 2023年4月10日
    00
  • 云原生周刊:Kubernetes 1.27 服务器端字段校验和 OpenAPI V3 进阶至 GA

    开源项目推荐 KubeView KubeView 是一个 Kubernetes 集群可视化工具和可视化资源管理器。它允许用户在集群内部运行命令,并查看集群内部的资源使用情况、容器运行状态、网络流量等。KubeView 支持多种数据源,可以读取 Prometheus、Grafana、Kubernetes 管理等工具的数据,将集群内部的数据可视化。 kube-s…

    云计算 2023年5月8日
    00
  • 基于云边协同架构的五大应用场景革新

    从概念到场景落地,边缘云加速革新,颠覆体验,拟造丰沛生态。 边缘云的概念自明确以来已有四个多年头。 什么是边缘云? 边缘云,即把公共云的能力放在离数据发生端和消费端最近的地方,提升数据的处理效率,承载更多场景,同时降低数据的搬运成本。 在边缘云的演进过程中,阿里云提炼出边缘云技术发展的三大价值驱动力,通过云边协同的方式,推动企业数字化发展,为用户带去更多的可…

    云计算 2023年4月13日
    00
  • 在Go中使用JSON(附demo)

    让我来为您讲解如何在Go中使用JSON。 前置知识 在了解如何在Go中使用JSON之前,我们需要先了解一些前置知识: JSON简介 JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,易于人阅读和编写,同时也易于机器解析和生成。在Web应用程序中,JSON通常用于从服务器传输数据到客户端。 在JSON中,数据以键值对的…

    云计算 2023年5月17日
    00
合作推广
合作推广
分享本页
返回顶部