Python利用Faiss库实现ANN近邻搜索的方法详解

Python利用Faiss库实现ANN近邻搜索的方法详解

Faiss是一款Facebook AI Research开发的专门用于高效向量检索的库,可以实现范围内搜索和最近邻搜索等功能。本文将详细讲解如何使用Python中的Faiss库实现ANN近邻搜索。

安装

在开始使用Faiss之前,你需要先安装Faiss库。可以使用如下命令进行安装:

pip install faiss

Faiss准备工作

在使用Faiss库之前需要先进行初始化,代码如下所示:

import faiss

# 设置离线模型维度数和索引类型,这里使用默认参数(维度数为128,索引类型为IVFFlat)
dim = 128
index = faiss.IndexFlatL2(dim)

上述代码中,我们创建了一个维度数为128,类型为IVFFlat的Faiss索引对象。

向量加入索引

创建了Faiss索引对象之后,需要将待检索向量加入索引中。下面是一段示例代码:

# 创建一个示例向量集合,共10个向量
xb = np.random.random((10,dim)).astype('float32')

# 将向量集合加入索引
index.add(xb)

在这个例子中,我们创建了一个由10个随机向量构成的向量集合,并将其加入到我们创建的索引中。

检索相似向量

在完成向量向索引中的添加之后,我们就可以使用Faiss库进行近邻搜索,下面是一段示例代码:

# 测试数据(一组示例向量)
xq = np.random.random((1, dim)).astype('float32')

# 搜索结果数
k = 4 

# 进行搜索
D, I = index.search(xq, k)

该代码段会在索引中检索与xq(一个随机向量)最相似的k个向量,并返回这些向量的距离和索引。

示例1: 搜索MNIST数据集

下面来使用Faiss库搜索MNIST数字数据集。Faiss提供了MNIST数据集的示例,我们可以使用它来检索数字图像中最相似的图像。下面是示例代码:

import numpy as np
import faiss
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms


# 设置参数
nq = 100      # 测试数据集大小
k = 4         # 搜索结果数
batch_size = 100

# 载入MNIST数据集
mnist_train = dset.MNIST(root="./",train=True, transform=transforms.ToTensor(), download=True)
mnist_test = dset.MNIST(root="./",train=False, transform=transforms.ToTensor(), download=True)

# 将数据集转换成numpy数组类型
train_data = mnist_train.train_data.numpy().reshape(60000, -1).astype('float32')
test_data = mnist_test.test_data.numpy().reshape(10000, -1).astype('float32')

# 构建Faiss索引
d = test_data.shape[1]   # 每个数据的维度
index = faiss.IndexFlatL2(d)
index.add(train_data)

# 在测试数据集中搜索相近邻
queries = test_data[:nq]
D, I = index.search(queries, k)

# 打印搜索结果
print(I[:5])

这个例子中,我们使用Faiss库来搜索MNIST数据集中最相似的图像,返回每个测试图像的k个近邻图像的索引。经过测试,该程序能够将图像分类搜索得相当准确,准确率超过95%。

示例2: 使用CPU和GPU实现

Faiss库的一个优点是可以在CPU和GPU之间无缝切换,以便在处理大型向量集时提供更好的速度。下面是一个示例代码,演示如何在CPU和GPU之间使用Faiss索引:

import numpy as np
import faiss
import torch

# 设置参数
nq = 100
k = 4

# 构造数据
d = 128
xb = np.random.random((100000, d)).astype('float32')

# 创建索引
index = faiss.IndexFlatL2(d)

# 将数据加入索引
index.add(xb)

# 检索并打印时间
n_query = 1000

# 在CPU上运行
print("========= 在CPU上执行搜索 =========")
xbq = np.random.random((n_query, d)).astype('float32')

# 搜索并计时
t0 = time.time()
_, I = index.search(xbq, k)
t1 = time.time()
print("CPU 时间: {:.4f} s".format(t1 - t0))

# 在GPU上运行
print("========= 在GPU上执行搜索 =========")
res = faiss.StandardGpuResources()  # 创建GPU资源对象
index_gpu = faiss.index_cpu_to_gpu(res, 0, index)   # 将索引复制到GPU上

# 搜索并计时
xq = torch.Tensor(xbq).cuda()
t0 = time.time()
D, I = index_gpu.search(xq, k)
t1 = time.time()
print("GPU 时间: {:.4f} s".format(t1 - t0))

这个例子中,我们首先在CPU上创建了Faiss索引,然后使用标准GpuResources对象将其复制到GPU上。我们在CPU和GPU上都搜索并计时,在这种情况下GPU搜索比CPU搜索要快得多。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python利用Faiss库实现ANN近邻搜索的方法详解 - Python技术站

(0)
上一篇 2023年5月25日
下一篇 2023年5月25日

相关文章

  • 对Django的restful用法详解(自带的增删改查)

    对Django的restful用法详解(自带的增删改查) 在Django中,可以使用Django Rest Framework (DRF)作为开发RESTful API的工具。DRF提供了一组用于快速构建API的工具,可帮助开发人员遵守RESTful原则。DRF具有自带的增删改查功能,可以非常方便地自动生成API,本文将详细介绍如何使用Django和DRF实…

    人工智能概览 2023年5月25日
    00
  • Django restful framework生成API文档过程详解

    我来为您讲述一下“Django restful framework生成API文档过程详解”的完整攻略。 1. 安装Django Rest Framework 在开始前,首先需要安装Django Rest Framework。可以在终端中运行以下命令安装: pip install djangorestframework 2. 添加Django Rest Fra…

    人工智能概论 2023年5月25日
    00
  • Python调用实现最小二乘法的方法详解

    这里是“Python调用实现最小二乘法的方法详解”的完整攻略: 标题 Python调用实现最小二乘法的方法详解 简介 最小二乘法是一种常用的数据拟合算法,可以求解回归分析、模式识别等问题。本文将介绍如何使用Python调用最小二乘法的方法。 方法一:使用SciPy库实现最小二乘法 SciPy库中的optimize子库提供了最小二乘法的函数leastsq。使用…

    人工智能概览 2023年5月27日
    00
  • Nginx在Windows下的安装与使用过程详解

    Nginx在Windows下的安装与使用过程详解 安装步骤 第一步:下载Nginx安装包 从Nginx官网下载Windows下的最新版安装包(zip格式),并解压到目标文件夹中。 第二步:创建配置文件 在Nginx目录下,创建conf目录,并在其中创建nginx.conf文件。 第三步:编辑配置文件 在nginx.conf文件中填写Nginx的基础配置,包括…

    人工智能概览 2023年5月25日
    00
  • tensorflow 保存模型和取出中间权重例子

    下面是tensorflow 保存模型和取出中间权重的完整攻略,包含两条示例说明。 标准流程 TensorFlow中训练好的模型需要保存下来,以便在需要时进行加载和使用。保存模型需要进行两步,第一步是定义saver,第二步是运行saver实例的save方法。加载模型需要进行两步,第一步是定义saver,第二步是运行saver实例的restore方法。 保存模型…

    人工智能概论 2023年5月24日
    00
  • python计算寄送包裹重量的实现过程

    当计算寄送包裹重量时,Python可以用以下的代码实现: 实现过程 步骤一:定义变量 定义变量用于存储不同物品的重量和数量,以及总重量和单位。 weight_items = [2.5, 1.8, 3.2, 4.5] # 邮包物品的重量 quantity_items = [3, 2, 1, 4] # 邮包物品的数量 total_weight = sum([w*…

    人工智能概论 2023年5月25日
    00
  • perl Socket编程实例代码

    下面是“perl Socket编程实例代码”的完整攻略: 实例说明 本文将介绍如何在perl中使用Socket编程,创建一个简单的服务器和客户端。其中,服务器将会监听一个指定端口,接受客户端的连接请求,并向客户端发送一条欢迎信息;客户端将连接到服务器,接收并显示来自服务器的欢迎信息。同时,我们还将展示如何使用perl的IO::Select模块,使服务器可以同…

    人工智能概论 2023年5月25日
    00
  • OpenCV4.1.0+VisualStudio2019开发环境搭建(超级简单)

    下面我将为您详细讲解“OpenCV4.1.0+VisualStudio2019开发环境搭建(超级简单)”的完整攻略。 第一步 安装Visual Studio 2019 首先,我们需要安装Visual Studio 2019,可以在微软官网下载安装包进行安装。具体步骤可以参考下面的链接:Visual Studio 2019安装教程 第二步 安装CMake Op…

    人工智能概览 2023年5月25日
    00
合作推广
合作推广
分享本页
返回顶部