Python Pytorch学习之图像检索实践

yizhihongxing

Python PyTorch学习之图像检索实践

本文将介绍如何使用Python和PyTorch实现图像检索。我们将使用一个预训练的卷积神经网络模型来提取图像特征,并使用余弦相似度来计算图像之间的相似度。本文将分为以下几个部分:

  1. 数据集准备
  2. 模型准备
  3. 图像特征提取
  4. 图像检索
  5. 示例说明

数据集准备

我们将使用CIFAR-10数据集作为我们的图像数据集。CIFAR-10数据集包含10个类别的60000张32x32彩色图像。我们将使用其中的50000张图像作为训练集,10000张图像作为测试集。

我们可以使用torchvision库来下载和加载CIFAR-10数据集。以下是加载CIFAR-10数据集的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

在这个示例中,我们首先定义了一个数据变换transform,用于将图像转换为张量并进行归一化。然后,我们使用torchvision.datasets.CIFAR10函数加载CIFAR-10数据集,并使用数据变换transform对数据进行预处理。最后,我们使用torch.utils.data.DataLoader函数创建训练集和测试集的数据加载器。

模型准备

我们将使用一个预训练的卷积神经网络模型来提取图像特征。在本文中,我们将使用ResNet-18模型。ResNet-18是一个18层的深度卷积神经网络模型,它在ImageNet数据集上进行了预训练,并在图像分类任务上取得了很好的性能。

我们可以使用torchvision.models.resnet18函数来加载ResNet-18模型。以下是加载ResNet-18模型的示例代码:

import torch
import torchvision.models as models

# Load ResNet-18 model
model = models.resnet18(pretrained=True)

在这个示例中,我们使用torchvision.models.resnet18函数加载ResNet-18模型,并将其预训练的参数加载到模型中。

图像特征提取

我们将使用ResNet-18模型来提取图像特征。我们将使用模型的最后一个全局平均池化层来提取图像特征。全局平均池化层将图像特征图转换为一个向量,这个向量可以表示整个图像的特征。

以下是使用ResNet-18模型提取图像特征的示例代码:

import torch
import torchvision.models as models

# Load ResNet-18 model
model = models.resnet18(pretrained=True)

# Remove last layer
model = torch.nn.Sequential(*list(model.children())[:-1])

# Extract features
features = []
labels = []
for images, target in trainloader:
    output = model(images)
    features.append(output)
    labels.append(target)
features = torch.cat(features, dim=0)
labels = torch.cat(labels, dim=0)

在这个示例中,我们首先加载ResNet-18模型,并使用torch.nn.Sequential函数将模型的最后一层去掉。然后,我们使用模型提取训练集中所有图像的特征,并将这些特征存储在一个张量features中。我们还将训练集中所有图像的标签存储在一个张量labels中。

图像检索

我们将使用余弦相似度来计算图像之间的相似度。余弦相似度是一种常用的相似度度量方法,它可以衡量两个向量之间的夹角余弦值。余弦相似度的取值范围为[-1, 1],值越大表示两个向量越相似。

以下是使用余弦相似度计算图像之间相似度的示例代码:

import torch

# Compute cosine similarity
def cosine_similarity(x, y):
    return torch.dot(x, y) / (torch.norm(x) * torch.norm(y))

# Compute similarity matrix
similarity_matrix = torch.zeros(len(features), len(features))
for i in range(len(features)):
    for j in range(len(features)):
        similarity_matrix[i, j] = cosine_similarity(features[i], features[j])

# Retrieve similar images
query_index = 0
similar_indices = similarity_matrix[query_index].argsort(descending=True)[:10]

在这个示例中,我们首先定义了一个计算余弦相似度的函数cosine_similarity。然后,我们使用这个函数计算所有图像之间的相似度,并将相似度存储在一个张量similarity_matrix中。最后,我们选择一个查询图像,并使用similarity_matrix找到与查询图像最相似的10张图像的索引。

示例说明

以下是两个使用Python和PyTorch实现图像检索的示例说明:

示例1:使用CIFAR-10数据集实现图像检索

以下是一个使用CIFAR-10数据集实现图像检索的示例代码:

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# Load ResNet-18 model
model = models.resnet18(pretrained=True)
model = torch.nn.Sequential(*list(model.children())[:-1])

# Extract features
features = []
labels = []
for images, target in trainloader:
    output = model(images)
    features.append(output)
    labels.append(target)
features = torch.cat(features, dim=0)
labels = torch.cat(labels, dim=0)

# Compute cosine similarity
def cosine_similarity(x, y):
    return torch.dot(x, y) / (torch.norm(x) * torch.norm(y))

# Compute similarity matrix
similarity_matrix = torch.zeros(len(features), len(features))
for i in range(len(features)):
    for j in range(len(features)):
        similarity_matrix[i, j] = cosine_similarity(features[i], features[j])

# Retrieve similar images
query_index = 0
similar_indices = similarity_matrix[query_index].argsort(descending=True)[:10]
print(similar_indices)

在这个示例中,我们首先加载CIFAR-10数据集,并使用ResNet-18模型提取训练集中所有图像的特征。然后,我们使用余弦相似度计算所有图像之间的相似度,并找到与查询图像最相似的10张图像的索引。

示例2:使用自定义数据集实现图像检索

以下是一个使用自定义数据集实现图像检索的示例代码:

import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

# Define data transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.images = []
        for filename in os.listdir(root):
            if filename.endswith('.jpg'):
                self.images.append(os.path.join(root, filename))

    def __getitem__(self, index):
        image_path = self.images[index]
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.images)

dataset = CustomDataset('./images', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

# Load ResNet-18 model
model = models.resnet18(pretrained=True)
model = torch.nn.Sequential(*list(model.children())[:-1])

# Extract features
features = []
for images in dataloader:
    output = model(images)
    features.append(output)
features = torch.cat(features, dim=0)

# Compute cosine similarity
def cosine_similarity(x, y):
    return torch.dot(x, y) / (torch.norm(x) * torch.norm(y))

# Compute similarity matrix
similarity_matrix = torch.zeros(len(features), len(features))
for i in range(len(features)):
    for j in range(len(features)):
        similarity_matrix[i, j] = cosine_similarity(features[i], features[j])

# Retrieve similar images
query_index = 0
similar_indices = similarity_matrix[query_index].argsort(descending=True)[:10]
print(similar_indices)

在这个示例中,我们首先定义了一个自定义数据集CustomDataset,并使用ResNet-18模型提取所有图像的特征。然后,我们使用余弦相似度计算所有图像之间的相似度,并找到与查询图像最相似的10张图像的索引。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Python Pytorch学习之图像检索实践 - Python技术站

(0)
上一篇 2023年5月16日
下一篇 2023年5月16日

相关文章

  • pytorch transforms图像增强怎么实现

    这篇文章主要介绍“pytorch transforms图像增强怎么实现”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“pytorch transforms图像增强怎么实现”文章能帮助大家解决问题。 一、前言 本文基于的是pytorch3.7.1 二、图像处理 深度学习是由数据驱动的,而数据的数量和分布对于模型的优劣具有…

    PyTorch 2023年4月7日
    00
  • pytorch 入门指南

    两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的。 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 GPU 加速 (cuda) 自动求导 常用网络层的API PyTorch 的特点 支持 GPU 动态神经网络 Python 优先 命令式体验 轻松扩展 1.P…

    PyTorch 2023年4月8日
    00
  • PyTorch Geometric Temporal 介绍 —— 数据结构和RGCN的概念

    Introduction PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric. PyTorch Geometric Temporal 是基于PyTorch Geometric的对时间序列图数据的扩展。 Dat…

    PyTorch 2023年4月8日
    00
  • pytorch: cudnn.benchmark=True

    import torch.backends.cudnn as cudnn cudnn.benchmark = True 设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题。如果网络的输入数据维度或类型上变化不大,也就是每次训练的图像尺寸都是一样的时候,设置 torch.backe…

    PyTorch 2023年4月8日
    00
  • pytorch调用gpu

    第一步!指定gpu import osos.environ[“CUDA_VISIBLE_DEVICES”] = ‘0’ 第二步! 对于每一个要踹到gpu去的Tensor或者model x 使用x = x.cuda()就ok了 嘤嘤嘤

    PyTorch 2023年4月6日
    00
  • pytorch踩坑记

    因为我有数学物理背景,所以清楚卷积的原理。但是在看pytorch文档的时候感到非常头大,罗列的公式以及各种令人眩晕的下标让入门新手不知所云…最初我以为torch.nn.conv1d的参数in_channel/out_channel表示图像的通道数,经过运行错误提示之后,才知道[in_channel,kernel_size]构成了卷积核。  loss函数中…

    2023年4月6日
    00
  • Pytorch从一个输入目录中加载所有的PNG图像,并将它们存储在张量中

    1 import os 2 import imageio 3 from imageio import imread 4 import torch 5 6 # batch_size = 3 7 # batch = torch.zeros(batch_size, 3, 256, 256, dtype=torch.uint8) 8 # batch.shape #t…

    PyTorch 2023年4月7日
    00
  • 初识Pytorch使用transforms的代码

    初识Pytorch使用transforms的代码 在PyTorch中,transforms是一个常用的数据预处理工具。在使用transforms时,可以对数据进行各种预处理操作,例如裁剪、缩放、旋转、翻转等。本文将介绍如何使用transforms,并演示两个示例。 示例一:对图像进行随机裁剪和水平翻转 import torch import torchvis…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部