Pytorch+PyG实现GIN过程示例详解

下面是关于“Pytorch+PyG实现GIN过程示例详解”的完整攻略。

GIN简介

GIN(Graph Isomorphism Network)是一种基于图同构的神经网络模型,它可以对任意形状的图进行分类、回归和聚类等任务。GIN模型的核心思想是将每个节点的特征向量与其邻居节点的特征向量进行聚合,然后将聚合后的特征向量作为节点的新特征向量。GIN模型可以通过堆叠多个GIN层来构建深度神经网络。

Pytorch+PyG实现GIN过程示例

示例1:使用PyG实现GIN模型

我们将使用PyG(PyTorch Geometric)库来实现GIN模型。PyG是一个基于PyTorch的几何深度学习库,它提供了一组用于处理图形数据的工具和模型。我们将使用PyG中的torch_geometric.nn模块来实现GIN模型。下面是一个示例:

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

# 定义GIN层
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        super(GINConv, self).__init__(aggr='add')
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim),
                                        torch.nn.BatchNorm1d(2 * emb_dim),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(2 * emb_dim, emb_dim),
                                        torch.nn.BatchNorm1d(emb_dim),
                                        torch.nn.ReLU())

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        x = self.mlp((1 + norm) * x)
        return x

# 定义GIN模型
class GIN(torch.nn.Module):
    def __init__(self, emb_dim, num_classes):
        super(GIN, self).__init__()
        self.conv1 = GINConv(emb_dim)
        self.conv2 = GINConv(emb_dim)
        self.fc = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = global_add_pool(x, edge_index)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

在这个示例中,我们首先定义了一个名为GINConv的GIN层。该层包含一个多层感知机(MLP)和一个图消息传递函数。在forward方法中,我们首先使用add_self_loops函数将自环添加到边缘索引中,然后使用degree函数计算每个节点的度数。接下来,我们计算每个节点的归一化因子,并使用它来聚合每个节点的邻居节点的特征向量。最后,我们使用MLP对聚合后的特征向量进行非线性变换,并返回新的特征向量。

然后,我们定义了一个名为GIN的GIN模型。该模型包含两个GIN层和一个全连接层。在forward方法中,我们首先使用第一个GIN层对节点特征向量进行聚合,并使用ReLU函数进行非线性变换。然后,我们使用第二个GIN层对特征向量进行聚合,并再次使用ReLU函数进行非线性变换。接下来,我们使用global_add_pool函数对所有节点的特征向量进行汇总,并使用全连接层将其映射到类别概率。最后,我们使用F.log_softmax函数将输出转换为概率。

示例2:使用PyG实现GIN模型进行节点分类

我们将使用Cora数据集来演示如何使用PyG实现GIN模型进行节点分类。Cora数据集是一个引文网络数据集,其中每个节点表示一篇论文,每个边表示两篇论文之间的引用关系。每个节点都有一个特征向量,表示论文的词袋表示。我们的目标是预测每篇论文的类别。下面是一个示例:

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import global_add_pool
from torch_geometric.utils import train_test_split_edges

# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]

# 划分训练集、验证集和测试集
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
data.train_mask = data.val_mask = data.test_mask = None
data.y = data.y.squeeze()

# 定义模型和优化器
model = GIN(emb_dim=16, num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 训练模型
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_pos_edge_index], data.train_y)
    loss.backward()
    optimizer.step()

    # 在验证集上评估模型
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        val_loss = F.nll_loss(out[data.val_pos_edge_index], data.val_y)
        val_acc = accuracy(out[data.val_pos_edge_index], data.val_y)

    print('Epoch: {:03d}, Loss: {:.4f}, Val Loss: {:.4f}, Val Acc: {:.4f}'.format(
        epoch, loss.item(), val_loss.item(), val_acc.item()))

# 在测试集上评估模型
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    test_loss = F.nll_loss(out[data.test_pos_edge_index], data.test_y)
    test_acc = accuracy(out[data.test_pos_edge_index], data.test_y)

print('Test Loss: {:.4f}, Test Acc: {:.4f}'.format(test_loss.item(), test_acc.item()))

在这个示例中,我们首先使用Planetoid类加载Cora数据集,并使用train_test_split_edges函数将数据集划分为训练集、验证集和测试集。然后,我们定义了一个名为model的GIN模型,并使用Adam优化器进行训练。在每个时期中,我们首先使用optimizer.zero_grad()方法除梯度,然后使用模型对训练数据进行预测,并使用F.nll_loss函数计算损失。接下来,我们使用模型对验证数据进行预测,并计算验证损失和准确率。最后,我们打印出每个时期的损失、验证损失和验证准确率。

在训练结束后,我们使用模型对测试数据进行预测,并计算测试损失和准确率。最后,我们打印出测试损失和测试准确率。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch+PyG实现GIN过程示例详解 - Python技术站

(2)
上一篇 2023年5月15日
下一篇 2023年5月15日

相关文章

  • Pytorch中expand()的使用(扩展某个维度)

    PyTorch中expand()的使用(扩展某个维度) 在PyTorch中,expand()函数可以用来扩展张量的某个维度,从而实现张量的形状变换。expand()函数会自动复制张量的数据,以填充新的维度。下面是expand()函数的详细使用方法: torch.Tensor.expand(*sizes) -> Tensor 其中,*sizes是一个可变…

    PyTorch 2023年5月15日
    00
  • 使用自定义的Dataloader做数据增强、格式统一等操作/像使用pytorch一样进行训练。

    格式统一 https://detectron2.readthedocs.io/tutorials/data_loading.html 不使用train而是使用Model进行自定义训练 https://detectron2.readthedocs.io/tutorials/models.html 实现并写一个新的model层,注册到config以供使用 htt…

    PyTorch 2023年4月7日
    00
  • python pytorch numpy DNN 线性回归模型

    1、直接奉献代码,后期有入门更新,之前一直在学的是TensorFlow, import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.pyplot as plt import numpy as np x_data = np…

    2023年4月8日
    00
  • 详解win10下pytorch-gpu安装以及CUDA详细安装过程

    在Windows 10下安装PyTorch GPU版本需要安装CUDA和cuDNN,本文将详细讲解如何安装PyTorch GPU版本以及CUDA和cuDNN,并提供两个示例说明。 1. 安装PyTorch GPU版本 在安装PyTorch GPU版本之前,需要先安装CUDA和cuDNN。安装完成后,可以通过以下步骤安装PyTorch GPU版本: 打开Ana…

    PyTorch 2023年5月15日
    00
  • Pytorch中RNN和LSTM的简单应用

    目录 使用RNN执行回归任务 使用LSTM执行分类任务 使用RNN执行回归任务 import torch from torch import nn import numpy as np import matplotlib.pyplot as plt # torch.manual_seed(1) # reproducible # Hyper Parameter…

    PyTorch 2023年4月8日
    00
  • ubuntu下用anaconda快速安装 pytorch

    1.  创建虚拟环境 1 conda create -n pytorch python=3.6 2. 激活虚拟环境 1 conda activate pytorch #这里 有用 source activate pytorch,因为我用的是conda激活的,这个看个人需求 3. 安装pytorch   打开pytorch官网https://pytorch.o…

    2023年4月8日
    00
  • 基于anaconda3的Pytorch环境搭建

    安装anaconda3,版本选择新的就行 打开anaconda prompt创建虚拟环境conda create -n pytorch_gpu python=3.9,pytorch_gpu是环境名称,可自行选取,python=3.9是选择的python版本,可自行选择,conda会自动下载选择的python版本 打开cmd按照下图输入查看显卡驱动版本 查看显…

    2023年4月8日
    00
  • Pytorch实现List Tensor转Tensor,reshape拼接等操作

    以下是PyTorch实现List Tensor转Tensor、reshape、拼接等操作的两个示例说明。 示例1:将List Tensor转换为Tensor 在这个示例中,我们将使用PyTorch将List Tensor转换为Tensor。 首先,我们需要准备数据。我们将使用以下代码来生成List Tensor: import torch x1 = torc…

    PyTorch 2023年5月15日
    00
合作推广
合作推广
分享本页
返回顶部