python机器学习理论与实战(六)支持向量机

Python机器学习理论与实战(六)支持向量机

简介

支持向量机(Support Vector Machine,简称 SVM)是一个强大的分类算法,其具有优秀的泛化能力。在本文中,我们将介绍 SVM 的原理、实现及应用。

SVM 原理

SVM 的核心思想是:找到一个可以将不同类别的数据分割开的最优超平面。其中“最优”的定义是:在所有能成功分割不同类别数据的超平面中,选择距离两类样本点最近的点到超平面的距离最大的超平面。

SVM 实现

SVM 可以用于线性可分和线性不可分的情况。使用不同的核函数可以将线性不可分的情况转化为线性可分,从而解决问题。

下面是使用 scikit-learn 实现 SVM 的基本流程:

from sklearn import svm
clf = svm.SVC()  # 初始化 SVM 模型
clf.fit(X_train, y_train)  # 训练模型
y_predict = clf.predict(X_test)  # 预测结果

其中 X_train 表示训练集的特征矩阵,y_train 表示训练集的标签,X_test 表示测试集的特征矩阵,y_predict 表示模型预测出的测试集标签。

下面是一个简单的例子,使用 SVM 对 iris 数据集进行分类:

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 初始化 SVM 模型
clf = svm.SVC(kernel='linear', C=1)
clf.fit(X_train, y_train)

# 预测结果并计算准确率
y_predict = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_predict)
print('准确率:%.2f%%' % (accuracy * 100))

SVM 应用

SVM 通常用于分类问题,例如文本分类、图像分类等。以下是两个示例:

例子1:使用 SVM 进行文本分类

下面是将 SVM 用于文本分类的示例:

import numpy as np
from sklearn import svm
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 加载数据集
newsgroups_train = fetch_20newsgroups(subset='train')
categories = newsgroups_train.target_names

# 特征工程
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(newsgroups_train.data)
y = newsgroups_train.target

# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 初始化 SVM 模型
clf = svm.SVC(kernel='linear', C=1)
clf.fit(X_train, y_train)

# 预测结果并计算准确率
y_predict = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_predict)
print('准确率:%.2f%%' % (accuracy * 100))

例子2:使用 SVM 进行图像分类

下面是将 SVM 用于图像分类的示例:

import os
import cv2
import numpy as np
from sklearn import svm
from sklearn.metrics import accuracy_score

# 加载数据集
data_dir = './data'
classes = ['cats', 'dogs']

X, y = [], []
for class_id, class_name in enumerate(classes):
    for file_name in os.listdir(os.path.join(data_dir, class_name)):
        img_path = os.path.join(data_dir, class_name, file_name)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (32, 32))  # 将图像缩放至相同大小
        X.append(img.flatten())  # 将图像数据展平为一维数组
        y.append(class_id)

# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 初始化 SVM 模型
clf = svm.SVC(kernel='linear', C=1)
clf.fit(X_train, y_train)

# 预测结果并计算准确率
y_predict = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_predict)
print('准确率:%.2f%%' % (accuracy * 100))

总结

这篇文章介绍了 SVM 的原理和实现,同时给出了文本分类和图像分类的两个示例。SVM 是一种非常强大的分类算法,具有很好的泛化能力,既可以用于线性可分问题,也可以用于线性不可分问题。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:python机器学习理论与实战(六)支持向量机 - Python技术站

(0)
上一篇 2023年5月23日
下一篇 2023年5月23日

相关文章

  • 详解Python PIL UnsharpMask()方法

    下面是Python PIL库中的UnsharpMask()方法的完整攻略,希望能对您有所帮助。 什么是UnsharpMask()方法? UnsharpMask()是Python PIL(Python Imaging Library)库中的一种图像增强方法,它通过图像锐化来提高图像的清晰度和对比度。UnsharpMask()方法根据输入的图像,生成一个锐化后的…

    python-answer 2023年3月25日
    00
  • 10个python3常用排序算法详细说明与实例(快速排序,冒泡排序,桶排序,基数排序,堆排序,希尔排序,归并排序,计数排序)

    10个Python3常用排序算法详细说明与实例 排序算法是计算机科学中的基本问题之一,它的目的是将一组数据按照一定的顺序排列。Python中提供了多种排序算法,本文将介绍10个常用的排序算法,并提供详细的说明和实例。 1. 快速排序 快速排序是一种基于分治思想的排序算法,它的时间复杂度为O(nlogn)。快速排序的基本思想是选择一个基准元素,将序列分为两个子…

    python 2023年5月14日
    00
  • Python之京东商品秒杀的实现示例

    下面我将详细讲解“Python之京东商品秒杀的实现示例”的完整攻略。 简介 该示例是基于Python语言实现京东商品秒杀的完整流程。通过抓取商品信息和抢购链接信息,利用网络请求模拟登录、加入购物车和提交订单等操作,实现京东商品秒杀的效果。其中,需要用到Python的相关库,如requests、selenium等。 实现步骤 1. 抓取商品信息和抢购链接信息 …

    python 2023年6月2日
    00
  • python入门之基础语法学习笔记

    以下是关于“Python入门之基础语法学习笔记”的完整攻略: 问题描述 Python 是一种高级编程语言,易于学习和使用。本将介绍 Python 的基础语法,包括变量、数据类型、运算符、条件句、循环语句等。 解决方法 1. 变量 在 Python 中,变量是用来存储数据的容器。可以使用赋值语句来创建变量。示例代码如下: x = 10 y = "He…

    python 2023年5月13日
    00
  • python使用mysql数据库示例代码

    下面是Python使用MySQL数据库的示例代码攻略,包含了数据库连接、数据查询和数据插入等操作。 连接MySQL数据库 在Python程序中连接MySQL数据库,需要先安装MySQL-Python模块。使用以下命令可以安装该模块: pip install mysql-connector-python 连接MySQL数据库的代码示例如下: import my…

    python 2023年6月1日
    00
  • Python实现视频裁剪的示例代码

    下面我就来为你详细讲解“Python实现视频裁剪的示例代码”的完整攻略。 简介 首先来了解一下Python实现视频裁剪需要用到的几个关键概念。 OpenCV库 OpenCV是一个基于BSD许可(开源)发行的跨平台计算机视觉库,可以运行在Linux、Windows和Mac OS操作系统上。它轻量级而且高效,因此非常适合于移动端应用的开发。此外,OpenCV也具…

    python 2023年6月3日
    00
  • 详解Python sys.argv使用方法

    详解Python sys.argv使用方法 什么是sys.argv? 在Python中,sys.argv是Python解释器提供的一个命令行参数列表。它包含了命令行参数的所有参数。sys.argv至少包含一项,即当前程序的名称,其余项是用户传递的参数。 如何使用sys.argv? 使用sys.argv需要先导入sys模块,通过sys.argv获取用户传递的参…

    python 2023年6月2日
    00
  • Python爬虫必备之Xpath简介及实例讲解

    Python爬虫必备之Xpath简介及实例讲解 什么是Xpath Xpath(XML Path Language)是一种在XML文档中定位元素的语言。它可以通过标签、属性等特征,准确定位到需要抽取数据的目标元素。在Python爬虫中,Xpath是一个非常重要的工具,可以帮助我们快速准确地抽取需要的数据。 Xpath的基本语法 Xpath的语法非常简单,以下是…

    python 2023年5月14日
    00
合作推广
合作推广
分享本页
返回顶部