torch-geometric

Compare original and translation side by side

🇺🇸

Original

English
🇨🇳

Translation

Chinese

PyTorch Geometric (PyG)

PyTorch Geometric (PyG)

Overview

概述

PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.
PyTorch Geometric (PyG) 是基于PyTorch构建的用于开发和训练图神经网络(GNNs)的库。将此技能应用于图和不规则结构的深度学习,包括小批量处理、多GPU训练和几何深度学习应用。

When to Use This Skill

何时使用此技能

This skill should be used when working with:
  • Graph-based machine learning: Node classification, graph classification, link prediction
  • Molecular property prediction: Drug discovery, chemical property prediction
  • Social network analysis: Community detection, influence prediction
  • Citation networks: Paper classification, recommendation systems
  • 3D geometric data: Point clouds, meshes, molecular structures
  • Heterogeneous graphs: Multi-type nodes and edges (e.g., knowledge graphs)
  • Large-scale graph learning: Neighbor sampling, distributed training
在处理以下场景时应使用此技能:
  • 基于图的机器学习:节点分类、图分类、链接预测
  • 分子属性预测:药物发现、化学属性预测
  • 社交网络分析:社区检测、影响力预测
  • 引文网络:论文分类、推荐系统
  • 3D几何数据:点云、网格、分子结构
  • 异构图:多类型节点和边(例如知识图谱)
  • 大规模图学习:邻居采样、分布式训练

Quick Start

快速开始

Installation

安装

bash
uv pip install torch_geometric
For additional dependencies (sparse operations, clustering):
bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
bash
uv pip install torch_geometric
如需额外依赖(稀疏操作、聚类):
bash
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html

Basic Graph Creation

基础图创建

python
import torch
from torch_geometric.data import Data
python
import torch
from torch_geometric.data import Data

Create a simple graph with 3 nodes

创建一个包含3个节点的简单图

edge_index = torch.tensor([[0, 1, 1, 2], # source nodes [1, 0, 2, 1]], dtype=torch.long) # target nodes x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # node features
data = Data(x=x, edge_index=edge_index) print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
undefined
edge_index = torch.tensor([[0, 1, 1, 2], # 源节点 [1, 0, 2, 1]], dtype=torch.long) # 目标节点 x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # 节点特征
data = Data(x=x, edge_index=edge_index) print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
undefined

Loading a Benchmark Dataset

加载基准数据集

python
from torch_geometric.datasets import Planetoid
python
from torch_geometric.datasets import Planetoid

Load Cora citation network

加载Cora引文网络

dataset = Planetoid(root='/tmp/Cora', name='Cora') data = dataset[0] # Get the first (and only) graph
print(f"Dataset: {dataset}") print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}") print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
undefined
dataset = Planetoid(root='/tmp/Cora', name='Cora') data = dataset[0] # 获取第一个(也是唯一一个)图
print(f"Dataset: {dataset}") print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}") print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
undefined

Core Concepts

核心概念

Data Structure

数据结构

PyG represents graphs using the
torch_geometric.data.Data
class with these key attributes:
  • data.x
    : Node feature matrix
    [num_nodes, num_node_features]
  • data.edge_index
    : Graph connectivity in COO format
    [2, num_edges]
  • data.edge_attr
    : Edge feature matrix
    [num_edges, num_edge_features]
    (optional)
  • data.y
    : Target labels for nodes or graphs
  • data.pos
    : Node spatial positions
    [num_nodes, num_dimensions]
    (optional)
  • Custom attributes: Can add any attribute (e.g.,
    data.train_mask
    ,
    data.batch
    )
Important: These attributes are not mandatory—extend Data objects with custom attributes as needed.
PyG 使用
torch_geometric.data.Data
类表示图,包含以下关键属性:
  • data.x
    :节点特征矩阵
    [num_nodes, num_node_features]
  • data.edge_index
    :COO格式的图连通性
    [2, num_edges]
  • data.edge_attr
    :边特征矩阵
    [num_edges, num_edge_features]
    (可选)
  • data.y
    :节点或图的目标标签
  • data.pos
    :节点空间位置
    [num_nodes, num_dimensions]
    (可选)
  • 自定义属性:可添加任意属性(例如
    data.train_mask
    ,
    data.batch
重要提示:这些属性并非强制要求——可根据需要扩展Data对象的自定义属性。

Edge Index Format

边索引格式

Edges are stored in COO (coordinate) format as a
[2, num_edges]
tensor:
  • First row: source node indices
  • Second row: target node indices
python
undefined
边以COO(坐标)格式存储为
[2, num_edges]
张量:
  • 第一行:源节点索引
  • 第二行:目标节点索引
python
undefined

Edge list: (0→1), (1→0), (1→2), (2→1)

边列表: (0→1), (1→0), (1→2), (2→1)

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
undefined
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
undefined

Mini-Batch Processing

小批量处理

PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:
  • Adjacency matrices are stacked diagonally
  • Node features are concatenated along the node dimension
  • A
    batch
    vector maps each node to its source graph
  • No padding needed—computationally efficient
python
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    # batch.batch maps nodes to graphs
PyG 通过创建块对角邻接矩阵来处理批处理,将多个图连接成一个大型不连通图:
  • 邻接矩阵沿对角线堆叠
  • 节点特征沿节点维度拼接
  • batch
    向量将每个节点映射到其源图
  • 无需填充——计算效率高
python
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    # batch.batch 将节点映射到对应图

Building Graph Neural Networks

构建图神经网络

Message Passing Paradigm

消息传递范式

GNNs in PyG follow a neighborhood aggregation scheme:
  1. Transform node features
  2. Propagate messages along edges
  3. Aggregate messages from neighbors
  4. Update node representations
PyG 中的GNN遵循邻域聚合方案:
  1. 转换节点特征
  2. 沿边传播消息
  3. 聚合来自邻居的消息
  4. 更新节点表示

Using Pre-Built Layers

使用预构建层

PyG provides 40+ convolutional layers. Common ones include:
GCNConv (Graph Convolutional Network):
python
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GATConv (Graph Attention Network):
python
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GraphSAGE:
python
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
PyG 提供40余种卷积层。常见的包括:
GCNConv(图卷积网络):
python
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GATConv(图注意力网络):
python
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GraphSAGE:
python
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

Custom Message Passing Layers

自定义消息传递层

For custom layers, inherit from
MessagePassing
:
python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops to adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node features
        x = self.lin(x)

        # Compute normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Propagate messages
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: features of source nodes
        return norm.view(-1, 1) * x_j
Key methods:
  • forward()
    : Main entry point
  • message()
    : Constructs messages from source to target nodes
  • aggregate()
    : Aggregates messages (usually don't override—set
    aggr
    parameter)
  • update()
    : Updates node embeddings after aggregation
Variable naming convention: Appending
_i
or
_j
to tensor names automatically maps them to target or source nodes.
如需自定义层,继承自
MessagePassing
python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", 或 "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # 向邻接矩阵添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # 转换节点特征
        x = self.lin(x)

        # 计算归一化系数
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # 传播消息
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: 源节点的特征
        return norm.view(-1, 1) * x_j
关键方法:
  • forward()
    : 主入口点
  • message()
    : 构造从源节点到目标节点的消息
  • aggregate()
    : 聚合消息(通常无需重写——设置
    aggr
    参数即可)
  • update()
    : 聚合后更新节点嵌入
变量命名约定:在张量名称后添加
_i
_j
会自动将其映射到目标节点或源节点。

Working with Datasets

处理数据集

Loading Built-in Datasets

加载内置数据集

PyG provides extensive benchmark datasets:
python
undefined
PyG 提供丰富的基准数据集:
python
undefined

Citation networks (node classification)

引文网络(节点分类)

from torch_geometric.datasets import Planetoid dataset = Planetoid(root='/tmp/Cora', name='Cora') # or 'CiteSeer', 'PubMed'
from torch_geometric.datasets import Planetoid dataset = Planetoid(root='/tmp/Cora', name='Cora') # 或 'CiteSeer', 'PubMed'

Graph classification

图分类

from torch_geometric.datasets import TUDataset dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
from torch_geometric.datasets import TUDataset dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

Molecular datasets

分子数据集

from torch_geometric.datasets import QM9 dataset = QM9(root='/tmp/QM9')
from torch_geometric.datasets import QM9 dataset = QM9(root='/tmp/QM9')

Large-scale datasets

大规模数据集

from torch_geometric.datasets import Reddit dataset = Reddit(root='/tmp/Reddit')

Check `references/datasets_reference.md` for a comprehensive list.
from torch_geometric.datasets import Reddit dataset = Reddit(root='/tmp/Reddit')

查看`references/datasets_reference.md`获取完整列表。

Creating Custom Datasets

创建自定义数据集

For datasets that fit in memory, inherit from
InMemoryDataset
:
python
from torch_geometric.data import InMemoryDataset, Data
import torch

class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['my_data.csv']  # Files needed in raw_dir

    @property
    def processed_file_names(self):
        return ['data.pt']  # Files in processed_dir

    def download(self):
        # Download raw data to self.raw_dir
        pass

    def process(self):
        # Read data, create Data objects
        data_list = []

        # Example: Create a simple graph
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

        # Apply pre_filter and pre_transform
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # Save processed data
        self.save(data_list, self.processed_paths[0])
For large datasets that don't fit in memory, inherit from
Dataset
and implement
len()
and
get(idx)
.
对于可放入内存的数据集,继承自
InMemoryDataset
python
from torch_geometric.data import InMemoryDataset, Data
import torch

class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['my_data.csv']  # raw_dir中所需的文件

    @property
    def processed_file_names(self):
        return ['data.pt']  # processed_dir中的文件

    def download(self):
        # 将原始数据下载到self.raw_dir
        pass

    def process(self):
        # 读取数据,创建Data对象
        data_list = []

        # 示例:创建一个简单图
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

        # 应用pre_filter和pre_transform
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # 保存处理后的数据
        self.save(data_list, self.processed_paths[0])
对于无法放入内存的大规模数据集,继承自
Dataset
并实现
len()
get(idx)

Loading Graphs from CSV

从CSV加载图

python
import pandas as pd
import torch
from torch_geometric.data import HeteroData
python
import pandas as pd
import torch
from torch_geometric.data import HeteroData

Load nodes

加载节点

nodes_df = pd.read_csv('nodes.csv') x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
nodes_df = pd.read_csv('nodes.csv') x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)

Load edges

加载边

edges_df = pd.read_csv('edges.csv') edge_index = torch.tensor([edges_df['source'].values, edges_df['target'].values], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
undefined
edges_df = pd.read_csv('edges.csv') edge_index = torch.tensor([edges_df['source'].values, edges_df['target'].values], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
undefined

Training Workflows

训练工作流

Node Classification (Single Graph)

节点分类(单图)

python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

Load dataset

加载数据集

dataset = Planetoid(root='/tmp/Cora', name='Cora') data = dataset[0]
dataset = Planetoid(root='/tmp/Cora', name='Cora') data = dataset[0]

Create model

创建模型

model = GCN(dataset.num_features, dataset.num_classes) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model = GCN(dataset.num_features, dataset.num_classes) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

Training

训练

model.train() for epoch in range(200): optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step()
if epoch % 10 == 0:
    print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
model.train() for epoch in range(200): optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step()
if epoch % 10 == 0:
    print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

Evaluation

评估

model.eval() pred = model(data).argmax(dim=1) correct = (pred[data.test_mask] == data.y[data.test_mask]).sum() acc = int(correct) / int(data.test_mask.sum()) print(f'Test Accuracy: {acc:.4f}')
undefined
model.eval() pred = model(data).argmax(dim=1) correct = (pred[data.test_mask] == data.y[data.test_mask]).sum() acc = int(correct) / int(data.test_mask.sum()) print(f'Test Accuracy: {acc:.4f}')
undefined

Graph Classification (Multiple Graphs)

图分类(多图)

python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # Global pooling (aggregate node features to graph-level)
        x = global_mean_pool(x, batch)

        x = self.lin(x)
        return F.log_softmax(x, dim=1)
python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)

        # 全局池化(将节点特征聚合为图级特征)
        x = global_mean_pool(x, batch)

        x = self.lin(x)
        return F.log_softmax(x, dim=1)

Load dataset

加载数据集

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes) optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES') loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes) optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Training

训练

model.train() for epoch in range(100): total_loss = 0 for batch in loader: optimizer.zero_grad() out = model(batch) loss = F.nll_loss(out, batch.y) loss.backward() optimizer.step() total_loss += loss.item()
if epoch % 10 == 0:
    print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
undefined
model.train() for epoch in range(100): total_loss = 0 for batch in loader: optimizer.zero_grad() out = model(batch) loss = F.nll_loss(out, batch.y) loss.backward() optimizer.step() total_loss += loss.item()
if epoch % 10 == 0:
    print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
undefined

Large-Scale Graphs with Neighbor Sampling

基于邻居采样的大规模图训练

For large graphs, use
NeighborLoader
to sample subgraphs:
python
from torch_geometric.loader import NeighborLoader
对于大型图,使用
NeighborLoader
采样子图:
python
from torch_geometric.loader import NeighborLoader

Create a neighbor sampler

创建邻居采样器

train_loader = NeighborLoader( data, num_neighbors=[25, 10], # Sample 25 neighbors for 1st hop, 10 for 2nd hop batch_size=128, input_nodes=data.train_mask, )
train_loader = NeighborLoader( data, num_neighbors=[25, 10], # 第一跳采样25个邻居,第二跳采样10个 batch_size=128, input_nodes=data.train_mask, )

Training

训练

model.train() for batch in train_loader: optimizer.zero_grad() out = model(batch) # Only compute loss on seed nodes (first batch_size nodes) loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size]) loss.backward() optimizer.step()

**Important**:
- Output subgraphs are directed
- Node indices are relabeled (0 to batch.num_nodes - 1)
- Only use seed node predictions for loss computation
- Sampling beyond 2-3 hops is generally not feasible
model.train() for batch in train_loader: optimizer.zero_grad() out = model(batch) # 仅对种子节点(前batch_size个节点)计算损失 loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size]) loss.backward() optimizer.step()

**重要提示**:
- 输出子图是有向的
- 节点索引会被重新标记(从0到batch.num_nodes - 1)
- 仅使用种子节点的预测结果计算损失
- 通常采样超过2-3跳不可行

Advanced Features

高级功能

Heterogeneous Graphs

异构图

For graphs with multiple node and edge types, use
HeteroData
:
python
from torch_geometric.data import HeteroData

data = HeteroData()
对于包含多种节点和边类型的图,使用
HeteroData
python
from torch_geometric.data import HeteroData

data = HeteroData()

Add node features for different types

为不同类型添加节点特征

data['paper'].x = torch.randn(100, 128) # 100 papers with 128 features data['author'].x = torch.randn(200, 64) # 200 authors with 64 features
data['paper'].x = torch.randn(100, 128) # 100篇论文,每篇128维特征 data['author'].x = torch.randn(200, 64) # 200位作者,每位64维特征

Add edges for different types (source_type, edge_type, target_type)

为不同类型添加边(源类型, 边类型, 目标类型)

data['author', 'writes', 'paper'].edge_index = torch.randint(0, 200, (2, 500)) data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)

Convert homogeneous models to heterogeneous:

```python
from torch_geometric.nn import to_hetero
data['author', 'writes', 'paper'].edge_index = torch.randint(0, 200, (2, 500)) data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)

将同构模型转换为异构模型:

```python
from torch_geometric.nn import to_hetero

Define homogeneous model

定义同构模型

model = GNN(...)
model = GNN(...)

Convert to heterogeneous

转换为异构模型

model = to_hetero(model, data.metadata(), aggr='sum')
model = to_hetero(model, data.metadata(), aggr='sum')

Use as normal

正常使用

out = model(data.x_dict, data.edge_index_dict)

Or use `HeteroConv` for custom edge-type-specific operations:

```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict
out = model(data.x_dict, data.edge_index_dict)

或使用`HeteroConv`实现自定义边类型专属操作:

```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv

class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')

        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

Transforms

变换

Apply transforms to modify graph structure or features:
python
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose
应用变换以修改图结构或特征:
python
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose

Single transform

单个变换

transform = NormalizeFeatures() dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
transform = NormalizeFeatures() dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)

Compose multiple transforms

组合多个变换

transform = Compose([ AddSelfLoops(), NormalizeFeatures(), ]) dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)

Common transforms:
- **Structure**: `ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph`
- **Features**: `NormalizeFeatures`, `NormalizeScale`, `Center`
- **Sampling**: `RandomNodeSplit`, `RandomLinkSplit`
- **Positional Encoding**: `AddLaplacianEigenvectorPE`, `AddRandomWalkPE`

See `references/transforms_reference.md` for the full list.
transform = Compose([ AddSelfLoops(), NormalizeFeatures(), ]) dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)

常见变换:
- **结构类**:`ToUndirected`, `AddSelfLoops`, `RemoveSelfLoops`, `KNNGraph`, `RadiusGraph`
- **特征类**:`NormalizeFeatures`, `NormalizeScale`, `Center`
- **采样类**:`RandomNodeSplit`, `RandomLinkSplit`
- **位置编码类**:`AddLaplacianEigenvectorPE`, `AddRandomWalkPE`

完整列表请查看`references/transforms_reference.md`。

Model Explainability

模型可解释性

PyG provides explainability tools to understand model predictions:
python
from torch_geometric.explain import Explainer, GNNExplainer
PyG 提供可解释性工具以理解模型预测:
python
from torch_geometric.explain import Explainer, GNNExplainer

Create explainer

创建解释器

explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=200), explanation_type='model', # or 'phenomenon' node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), )
explainer = Explainer( model=model, algorithm=GNNExplainer(epochs=200), explanation_type='model', # 或 'phenomenon' node_mask_type='attributes', edge_mask_type='object', model_config=dict( mode='multiclass_classification', task_level='node', return_type='log_probs', ), )

Generate explanation for a specific node

为特定节点生成解释

node_idx = 10 explanation = explainer(data.x, data.edge_index, index=node_idx)
node_idx = 10 explanation = explainer(data.x, data.edge_index, index=node_idx)

Visualize

可视化

print(f'Node {node_idx} explanation:') print(f'Important edges: {explanation.edge_mask.topk(5).indices}') print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
undefined
print(f'Node {node_idx} explanation:') print(f'Important edges: {explanation.edge_mask.topk(5).indices}') print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
undefined

Pooling Operations

池化操作

For hierarchical graph representations:
python
from torch_geometric.nn import TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)

        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)
用于分层图表示:
python
from torch_geometric.nn import TopKPooling, global_mean_pool

class HierarchicalGNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)

        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

Common Patterns and Best Practices

常见模式与最佳实践

Check Graph Properties

检查图属性

python
undefined
python
undefined

Undirected check

检查是否为无向图

from torch_geometric.utils import is_undirected print(f"Is undirected: {is_undirected(data.edge_index)}")
from torch_geometric.utils import is_undirected print(f"Is undirected: {is_undirected(data.edge_index)}")

Connected components

连通分量

from torch_geometric.utils import connected_components print(f"Connected components: {connected_components(data.edge_index)}")
from torch_geometric.utils import connected_components print(f"Connected components: {connected_components(data.edge_index)}")

Contains self-loops

是否包含自环

from torch_geometric.utils import contains_self_loops print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
undefined
from torch_geometric.utils import contains_self_loops print(f"Has self-loops: {contains_self_loops(data.edge_index)}")
undefined

GPU Training

GPU训练

python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)

For DataLoader

对于DataLoader

for batch in loader: batch = batch.to(device) # Train...
undefined
for batch in loader: batch = batch.to(device) # 训练...
undefined

Save and Load Models

保存与加载模型

python
undefined
python
undefined

Save

保存

torch.save(model.state_dict(), 'model.pth')
torch.save(model.state_dict(), 'model.pth')

Load

加载

model = GCN(num_features, num_classes) model.load_state_dict(torch.load('model.pth')) model.eval()
undefined
model = GCN(num_features, num_classes) model.load_state_dict(torch.load('model.pth')) model.eval()
undefined

Layer Capabilities

层能力

When choosing layers, consider these capabilities:
  • SparseTensor: Supports efficient sparse matrix operations
  • edge_weight: Handles one-dimensional edge weights
  • edge_attr: Processes multi-dimensional edge features
  • Bipartite: Works with bipartite graphs (different source/target dimensions)
  • Lazy: Enables initialization without specifying input dimensions
See the GNN cheatsheet at
references/layer_capabilities.md
.
选择层时,考虑以下能力:
  • SparseTensor: 支持高效稀疏矩阵操作
  • edge_weight: 处理一维边权重
  • edge_attr: 处理多维边特征
  • Bipartite: 适用于二分图(源/目标维度不同)
  • Lazy: 无需指定输入维度即可初始化
详情请查看
references/layer_capabilities.md
中的GNN速查表。

Resources

资源

Bundled References

内置参考文档

This skill includes detailed reference documentation:
  • references/layers_reference.md
    : Complete listing of all 40+ GNN layers with descriptions and capabilities
  • references/datasets_reference.md
    : Comprehensive dataset catalog organized by category
  • references/transforms_reference.md
    : All available transforms and their use cases
  • references/api_patterns.md
    : Common API patterns and coding examples
此技能包含详细的参考文档:
  • references/layers_reference.md
    : 所有40余种GNN层的完整列表,包含描述和能力说明
  • references/datasets_reference.md
    : 按类别组织的全面数据集目录
  • references/transforms_reference.md
    : 所有可用变换及其使用场景
  • references/api_patterns.md
    : 常见API模式和编码示例

Scripts

脚本

Utility scripts are provided in
scripts/
:
  • scripts/visualize_graph.py
    : Visualize graph structure using networkx and matplotlib
  • scripts/create_gnn_template.py
    : Generate boilerplate code for common GNN architectures
  • scripts/benchmark_model.py
    : Benchmark model performance on standard datasets
Execute scripts directly or read them for implementation patterns.
scripts/
目录下提供实用脚本:
  • scripts/visualize_graph.py
    : 使用networkx和matplotlib可视化图结构
  • scripts/create_gnn_template.py
    : 生成常见GNN架构的样板代码
  • scripts/benchmark_model.py
    : 在标准数据集上基准测试模型性能
可直接执行脚本或阅读脚本以了解实现模式。

Official Resources

官方资源