Pytorch+PyG实现GraphSAGE过程示例详解

GraphSAGE是一种用于节点嵌入的图神经网络模型,它可以学习节点的低维向量表示,以便于在图上进行各种任务,如节点分类、链接预测等。在本文中,我们将介绍如何使用PyTorch和PyG实现GraphSAGE模型,并提供两个示例说明。

示例1:使用GraphSAGE进行节点分类

在这个示例中,我们将使用GraphSAGE模型对Cora数据集中的节点进行分类。Cora数据集是一个引文网络数据集,其中每个节点代表一篇论文,每个边代表一篇论文引用另一篇论文。每个节点都有一个类别标签,表示它所属的研究领域。

步骤1:加载数据集

首先,我们需要加载Cora数据集。我们可以使用PyG中的Planetoid类来加载数据集。下面是一个示例:

from torch_geometric.datasets import Planetoid

# 加载Cora数据集
dataset = Planetoid(root='data', name='Cora')

在这个示例中,我们首先导入Planetoid类,然后使用该类加载Cora数据集。我们将数据集保存在data文件夹中,并将其命名为Cora

步骤2:定义GraphSAGE模型

接下来,我们需要定义GraphSAGE模型。我们可以使用PyG中的SAGEConv类来定义GraphSAGE层。下面是一个示例:

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

在这个示例中,我们首先导入SAGEConv类和F模块,然后定义了一个名为GraphSAGE的类,该类继承自torch.nn.Module类。在__init__方法中,我们定义了两个GraphSAGE层,每个层都包含一个SAGEConv层和一个ReLU激活函数。在forward方法中,我们首先对输入张量进行第一层GraphSAGE操作,然后进行dropout操作,最后进行第二层GraphSAGE操作,并使用F.log_softmax函数计算输出。

步骤3:训练模型

接下来,我们需要训练GraphSAGE模型。我们可以使用PyG中的train函数来训练模型。下面是一个示例:

import torch
from torch_geometric.data import DataLoader

# 定义超参数
lr = 0.01
epochs = 200
hidden_channels = 16

# 加载数据集
dataset = Planetoid(root='data', name='Cora')
data = dataset[0]
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型和优化器
model = GraphSAGE(dataset.num_features, hidden_channels, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# 训练模型
model.train()
for epoch in range(epochs):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}')

在这个示例中,我们首先定义了超参数,包括学习率、训练轮数和隐藏层维度。然后,我们加载Cora数据集,并使用DataLoader类创建一个数据加载器。接下来,我们定义了GraphSAGE模型和Adam优化器。在训练过程中,我们使用model.train()语句将模型设置为训练模式,并使用optimizer.zero_grad()语句清除梯度。然后,我们计算模型输出和损失,并使用loss.backward()语句计算梯度。最后,我们使用optimizer.step()语句更新模型参数,并打印训练信息。

步骤4:测试模型

最后,我们需要测试GraphSAGE模型的性能。我们可以使用PyG中的test函数来测试模型。下面是一个示例:

model.eval()
correct = 0
for batch in loader:
    out = model(batch.x, batch.edge_index)
    pred = out.argmax(dim=1)
    correct += int((pred == batch.y).sum())

acc = correct / len(dataset)
print(f'Test Accuracy: {acc:.4f}')

在这个示例中,我们首先使用model.eval()语句将模型设置为评估模式。然后,我们使用一个循环遍历测试集中的所有批次,并计算模型输出和预测标签。最后,我们计算模型的准确率,并打印测试结果。

示例2:使用GraphSAGE进行链接预测

在这个示例中,我们将使用GraphSAGE模型对Cora数据集中的链接进行预测。具体来说,我们将使用GraphSAGE模型预测Cora数据集中每个节点的邻居节点。

步骤1:加载数据集

首先,我们需要加载Cora数据集。我们可以使用PyG中的Planetoid类来加载数据集。下面是一个示例:

from torch_geometric.datasets import Planetoid

# 加载Cora数据集
dataset = Planetoid(root='data', name='Cora')

在这个示例中,我们首先导入Planetoid类,然后使用该类加载Cora数据集。我们将数据集保存在data文件夹中,并将其命名为Cora

步骤2:定义GraphSAGE模型

接下来,我们需要定义GraphSAGE模型。我们可以使用PyG中的SAGEConv类来定义GraphSAGE层。下面是一个示例:

import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

在这个示例中,我们首先导入SAGEConv类和F模块,然后定义了一个名为GraphSAGE的类,该类继承自torch.nn.Module类。在__init__方法中,我们定义了两个GraphSAGE层,每个层都包含一个SAGEConv层和一个ReLU激活函数。在forward方法中,我们首先对输入张量进行第一层GraphSAGE操作,然后进行dropout操作,最后进行第二层GraphSAGE操作,并返回输出。

步骤3:训练模型

接下来,我们需要训练GraphSAGE模型。我们可以使用PyG中的train函数来训练模型。下面是一个示例:

import torch
from torch_geometric.data import DataLoader
from torch_geometric.utils import train_test_split

# 定义超参数
lr = 0.01
epochs = 200
hidden_channels = 16

# 加载数据集
dataset = Planetoid(root='data', name='Cora')
data = dataset[0]
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split(data, val_ratio=0.05, test_ratio=0.1)
loader = DataLoader(data, batch_size=32, shuffle=True)

# 定义模型和优化器
model = GraphSAGE(dataset.num_features, hidden_channels, dataset.num_features)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# 训练模型
model.train()
for epoch in range(epochs):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.mse_loss(out, batch.x[batch.edge_index[1]])
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}')

在这个示例中,我们首先定义了超参数,包括学习率、训练轮数和隐藏层维度。然后,我们加载Cora数据集,并使用train_test_split函数将数据集分成训练集、验证集和测试集。接下来,我们使用DataLoader类创建一个数据加载器。然后,我们定义了GraphSAGE模型和Adam优化器。在训练过程中,我们使用model.train()语句将模型设置为训练模式,并使用optimizer.zero_grad()语句清除梯度。然后,我们计算模型输出和损失,并使用loss.backward()语句计算梯度。最后,我们使用optimizer.step()语句更新模型参数,并打印训练信息。

步骤4:测试模型

最后,我们需要测试GraphSAGE模型的性能。我们可以使用PyG中的test函数来测试模型。下面是一个示例:

model.eval()
z = model(data.x, data.edge_index)
z = z.detach().cpu().numpy()

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

train_idx = data.train_mask.nonzero(as_tuple=False).view(-1).tolist()
test_idx = data.test_mask.nonzero(as_tuple=False).view(-1).tolist()

clf = LogisticRegression(max_iter=10000)
clf.fit(z[train_idx], data.y[train_idx])
pred = clf.predict(z[test_idx])
acc = accuracy_score(data.y[test_idx], pred)

print(f'Test Accuracy: {acc:.4f}')

在这个示例中,我们首先使用model.eval()语句将模型设置为评估模式。然后,我们使用模型对所有节点进行嵌入,并将嵌入张量转换为NumPy数组。接下来,我们使用Logistic回归模型对嵌入进行训练,并使用测试集进行预测。最后,我们计算模型的准确率,并打印测试结果。

总之,使用PyTorch和PyG实现GraphSAGE模型非常简单。我们可以使用PyG中的SAGEConv类定义GraphSAGE层,使用PyTorch中的torch.nn.Module类定义模型,使用PyG中的train函数训练模型,使用PyG中的test函数测试模型。我们可以使用GraphSAGE模型进行节点分类、链接预测等各种任务。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch+PyG实现GraphSAGE过程示例详解 - Python技术站

(0)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • 在Pytorch中使用样本权重(sample_weight)的正确方法

    在PyTorch中,使用样本权重(sample_weight)可以对不同样本的重要性进行加权,从而提高模型的性能。本文将详细介绍在PyTorch中使用样本权重的正确方法,并提供两个示例说明。 1. 使用torch.nn.CrossEntropyLoss实现样本权重 在PyTorch中,可以使用torch.nn.CrossEntropyLoss函数实现样本权重…

    PyTorch 2023年5月15日
    00
  • pytorch bug记录

    一 pytorch 使用tensorboard在使用tensorboard 展示PROJECTOR 的时候发现并没有显示。 writer.add_embedding(features, metadata=class_labels, label_img=images.unsqueeze(1)) 继而安装了 tensorboard 和 tensorboardx …

    PyTorch 2023年4月8日
    00
  • Pytorch框架详解之一

    Pytorch基础操作 numpy基础操作 定义数组(一维与多维) 寻找最大值 维度上升与维度下降 数组计算 矩阵reshape 矩阵维度转换 代码实现 import numpy as np a = np.array([1, 2, 3, 4, 5, 6]) # array数组 b = np.array([8, 7, 6, 5, 4, 3]) print(a.…

    2023年4月8日
    00
  • PyTorch 导数应用的使用教程

    PyTorch 导数应用的使用教程 PyTorch 是一个基于 Python 的科学计算库,它主要用于深度学习和神经网络。在 PyTorch 中,导数应用是非常重要的一个功能,它可以帮助我们计算函数的梯度,从而实现自动微分和反向传播。本文将详细讲解 PyTorch 导数应用的使用教程,并提供两个示例说明。 1. PyTorch 导数应用的基础知识 在 PyT…

    PyTorch 2023年5月16日
    00
  • pytorch(十九):MNIST打印准确率和损失

    一、例子            二、整体代码 import torch from torch.nn import functional as F import torch.nn as nn import torchvision from torchvision import datasets,transforms import torch.optim as …

    PyTorch 2023年4月7日
    00
  • pytorch 实现 AlexNet 网络模型训练自定义图片分类

    1、AlexNet网络模型,pytorch1.1.0 实现      注意:AlexNet,in_img_size >=64 输入图片矩阵的大小要大于等于64 # coding:utf-8 import torch.nn as nn import torch class alex_net(nn.Module): def __init__(self,in…

    PyTorch 2023年4月8日
    00
  • win10使用清华源快速安装pytorch-GPU版(推荐)

    Win10使用清华源快速安装PyTorch-GPU版(推荐) 在Win10上安装PyTorch-GPU版可以加速深度学习模型的训练。本文将介绍如何使用清华源快速安装PyTorch-GPU版,并提供两个示例。 安装Anaconda 首先,我们需要安装Anaconda,它是一个流行的Python发行版,包含了许多常用的Python库和工具。您可以从官方网站下载适…

    PyTorch 2023年5月16日
    00
  • pytorch中histc()函数与numpy中histogram()及histogram2d()函数

    引言   直方图是一种对数据分布的描述,在图像处理中,直方图概念非常重要,应用广泛,如图像对比度增强(直方图均衡化),图像信息量度量(信息熵),图像配准(利用两张图像的互信息度量相似度)等。 1、numpy中histogram()函数用于统计一个数据的分布 numpy.histogram(a, bins=10, range=None, normed=None…

    2023年4月8日
    00
合作推广
合作推广
分享本页
返回顶部