torch-geometric


PyTorch Geometric (PyG)


PyG is the standard library for Graph Neural Networks built on PyTorch. It provides data structures for graphs, 60+ GNN layer implementations, scalable mini-batch training, and support for heterogeneous graphs.
Install: `uv add torch_geometric` (or `uv pip install torch_geometric`; requires PyTorch). Optional: `pyg-lib`, `torch-scatter`, `torch-sparse`, `torch-cluster` for accelerated ops.

Core Concepts


Graph Data: `Data` and `HeteroData`


A graph lives in a `Data` object. The key attributes:
python
from torch_geometric.data import Data

data = Data(
    x=node_features,          # [num_nodes, num_node_features]
    edge_index=edge_index,     # [2, num_edges] — COO format, dtype=torch.long
    edge_attr=edge_features,   # [num_edges, num_edge_features]
    y=labels,                  # node-level [num_nodes, *] or graph-level [1, *]
    pos=positions,             # [num_nodes, num_dimensions] (for point clouds/spatial)
)
The `edge_index` format is critical: it's a `[2, num_edges]` tensor where `edge_index[0]` holds the source nodes and `edge_index[1]` holds the target nodes. It is NOT a list of tuples. If you have edge pairs as rows, transpose and call `.contiguous()`.

python
# If edges are [[src1, dst1], [src2, dst2], ...], transpose first:
edge_index = edge_pairs.t().contiguous()

For undirected graphs, include both directions: edge (0,1) needs both `[0,1]` and `[1,0]` in edge_index.

For heterogeneous graphs, use `HeteroData`; see the Heterogeneous Graphs section below.

Datasets


PyG bundles many standard datasets that auto-download and preprocess:
python
from torch_geometric.datasets import Planetoid, TUDataset

# Single-graph node classification (Cora, Citeseer, Pubmed)
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0]  # single graph with train/val/test masks

# Multi-graph classification (ENZYMES, MUTAG, IMDB-BINARY, etc.)
dataset = TUDataset(root='./data', name='ENZYMES')
# dataset[0], dataset[1], ... are individual graphs


Common datasets by task:
- **Node classification**: Planetoid (Cora/Citeseer/Pubmed), OGB (ogbn-arxiv, ogbn-products, ogbn-mag)
- **Graph classification**: TUDataset (MUTAG, ENZYMES, PROTEINS, IMDB-BINARY), OGB (ogbg-molhiv)
- **Link prediction**: OGB (ogbl-collab, ogbl-citation2)
- **Molecular**: QM7, QM9, MoleculeNet
- **Point cloud/mesh**: ShapeNet, ModelNet10/40, FAUST


Transforms


Transforms preprocess or augment graph data, analogous to torchvision transforms:
python
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

# Common transforms
T.NormalizeFeatures()   # Row-normalize node features to sum to 1
T.ToUndirected()        # Add reverse edges to make graph undirected
T.AddSelfLoops()        # Add self-loop edges
T.KNNGraph(k=6)         # Build k-NN graph from point cloud positions
T.RandomJitter(0.01)    # Random noise augmentation on positions
T.Compose([...])        # Chain multiple transforms

# Apply as pre_transform (once, saved to disk) or transform (every access)
dataset = ShapeNet(root='./data', pre_transform=T.KNNGraph(k=6), transform=T.RandomJitter(0.01))

Building GNN Models


Quick Start: Using Built-in Layers


The fastest way to build a GNN is to stack conv layers from `torch_geometric.nn`:
python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
Important: PyG conv layers do NOT include activation functions — apply them yourself after each layer. This is by design for flexibility.

Choosing a Conv Layer


Pick based on your task and graph structure:

| Layer | Best for | Key idea |
|---|---|---|
| `GCNConv` | Homogeneous, semi-supervised node classification | Spectral-inspired, degree-normalized aggregation |
| `GATConv` / `GATv2Conv` | When neighbor importance varies | Attention-weighted messages |
| `SAGEConv` | Large graphs, inductive settings | Sampling-friendly, learnable aggregation |
| `GINConv` | Graph classification, maximizing expressiveness | As powerful as the WL test |
| `TransformerConv` | Rich edge features, complex interactions | Multi-head attention with edge features |
| `EdgeConv` | Point clouds, dynamic graphs | MLP on edge features (x_i, x_j - x_i) |
| `RGCNConv` | Heterogeneous with many relation types | Relation-specific weight matrices |
| `HGTConv` | Heterogeneous graphs | Type-specific attention |

Most homogeneous conv layers accept `(x, edge_index)` at minimum, and many also accept `edge_attr` for edge features. Two exceptions in this table: `RGCNConv` also needs an `edge_type` tensor, and `HGTConv` takes `x_dict` and `edge_index_dict`.

Lazy Initialization


Use `-1` for input channels to let PyG infer dimensions automatically, which is especially useful for heterogeneous models:
python
conv = SAGEConv((-1, -1), 64)  # Input dims inferred on first forward pass

# Initialize lazy modules:
with torch.no_grad():
    out = model(data.x, data.edge_index)

High-Level Model APIs


For common architectures, PyG provides ready-made model classes:
python
from torch_geometric.nn import GraphSAGE, GCN, GAT, GIN

model = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
)

Custom Layers via MessagePassing


To implement a novel GNN layer, subclass `MessagePassing`. The framework is:
  1. `propagate()` orchestrates the message passing
  2. `message()` defines what info flows along each edge (the phi function)
  3. `aggregate()` combines messages at each node (sum/mean/max)
  4. `update()` transforms the aggregated result (the gamma function)
python
import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Pre-processing before message passing
        x = self.lin(x)
        # Start message passing
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: features of source nodes for each edge [num_edges, features]
        # The _j suffix auto-indexes source nodes, _i indexes target nodes
        return x_j
The `_i`/`_j` convention: any tensor passed to `propagate()` can be auto-indexed by appending `_i` (target/central node) or `_j` (source/neighbor node) in the `message()` signature. So if you pass `x=...` to `propagate()`, you can access `x_i` and `x_j` in `message()`.
Read `references/message_passing.md` for the full GCN and EdgeConv implementation examples.

Task-Specific Patterns


Node Classification


python
# Full-batch training on a single graph (e.g., Cora).
# Assumes `model`, `data`, and an `optimizer` already exist.
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluation
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()

Graph Classification


With multiple graphs, use `DataLoader` for mini-batching and global pooling to get graph-level representations:
python
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

loader = DataLoader(dataset, batch_size=32, shuffle=True)

class GraphClassifier(torch.nn.Module):
    def __init__(self, in_ch, hidden_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, hidden_ch)
        self.lin = torch.nn.Linear(hidden_ch, out_ch)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # [num_graphs_in_batch, hidden_ch]
        return self.lin(x)

# Training loop
for data in loader:
    out = model(data.x, data.edge_index, data.batch)
    loss = F.cross_entropy(out, data.y)

PyG's `DataLoader` batches multiple graphs by creating block-diagonal adjacency matrices. The `batch` tensor maps each node to its graph index. Pooling ops (`global_mean_pool`, `global_max_pool`, `global_add_pool`) use this to aggregate per-graph.

Link Prediction


Split edges into train/val/test, use negative sampling:
python
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)

# Encode nodes, then score edges
z = model.encode(train_data.x, train_data.edge_index)

# Positive edges
pos_score = (z[train_data.edge_label_index[0]] * z[train_data.edge_label_index[1]]).sum(dim=1)

Read `references/link_prediction.md` for the complete link prediction guide: GAE/VGAE autoencoders, full training loops, LinkNeighborLoader for large graphs, heterogeneous link prediction, and evaluation metrics.

Scaling to Large Graphs


For graphs that don't fit in GPU memory, use neighbor sampling via `NeighborLoader`:
python
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],     # Sample 15 neighbors in hop 1, 10 in hop 2
    batch_size=128,              # Number of seed nodes per batch
    input_nodes=data.train_mask, # Which nodes to sample from
    shuffle=True,
)

for batch in train_loader:
    batch = batch.to(device)
    out = model(batch.x, batch.edge_index)
    # Only use first batch_size nodes for loss (these are the seed nodes)
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
Key points about NeighborLoader:
  • `num_neighbors` list length should match GNN depth (number of message passing layers)
  • Seed nodes are always the first `batch.batch_size` nodes in the output
  • `batch.n_id` maps relabeled indices back to original node IDs
  • Works for both `Data` and `HeteroData`
  • For link prediction, use `LinkNeighborLoader` instead
  • Sampling more than 2-3 hops is generally infeasible (exponential blowup)
Other scalability options: `ClusterLoader` (ClusterGCN), `GraphSAINTSampler`, `ShaDowKHopSampler`. For multi-GPU training, DDP, PyTorch Lightning integration, and `torch.compile` support, read `references/scaling.md`.
Heterogeneous Graphs


For graphs with multiple node and edge types (social networks, knowledge graphs, recommendation):
python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Node features, indexed by node type string
data['user'].x = torch.randn(1000, 64)
data['movie'].x = torch.randn(500, 128)

# Edge indices, indexed by (src_type, edge_type, dst_type) triplet
data['user', 'rates', 'movie'].edge_index = torch.randint(0, 500, (2, 3000))
data['user', 'follows', 'user'].edge_index = torch.randint(0, 1000, (2, 5000))

# Access convenience dicts
data.x_dict            # {'user': tensor, 'movie': tensor}
data.edge_index_dict   # {('user','rates','movie'): tensor, ...}
data.metadata()        # ([node_types], [edge_types])

Three ways to build heterogeneous GNNs


**1. Auto-convert with `to_hetero()`**: write a homogeneous model, then convert it automatically:
python
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(64, dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

# Now accepts dicts:
out = model(data.x_dict, data.edge_index_dict)

Use `(-1, -1)` for bipartite input channels (source, target may differ). Lazy init handles the rest.

**2. `HeteroConv` wrapper**: different conv per edge type:

```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv

conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'rev_writes', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')
```

**3. Native heterogeneous operators** like `HGTConv`:
python
from torch_geometric.nn import HGTConv
conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), heads=4)
Important for heterogeneous graphs:
  • Use `T.ToUndirected()` to add reverse edge types for bidirectional message flow
  • Disable `add_self_loops` in bipartite conv layers (different source/dest types); use skip connections instead: `conv(x, edge_index) + lin(x)`
  • For NeighborLoader on HeteroData, specify `input_nodes` as a `('node_type', mask)` tuple
  • `num_neighbors` can be a dict keyed by edge type for fine-grained control
Read `references/heterogeneous.md` for complete examples including training loops and NeighborLoader usage with heterogeneous graphs.

Custom Datasets


For loading your own data into PyG:
  • Quick (no class needed): Create `Data` objects directly and pass a list to `DataLoader`
  • Reusable (fits in RAM): Subclass `InMemoryDataset` and override `raw_file_names`, `processed_file_names`, `download()`, `process()`
  • Large (disk-backed): Subclass `Dataset` and also override `len()` and `get()`
  • From CSV: Load node/edge tables with pandas, build mappings to consecutive indices, assemble into `Data` or `HeteroData`
  • From NetworkX: `from_networkx(G)` converts a NetworkX graph directly
  • From scipy sparse: `from_scipy_sparse_matrix(adj)` extracts edge_index
Read `references/custom_datasets.md` for complete examples with all patterns, CSV loading with encoders, and the MovieLens walkthrough.

Explainability


PyG provides `torch_geometric.explain` for interpreting GNN predictions:
python
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',  # use 'raw' if the model returns unnormalized logits
    ),
)

explanation = explainer(data.x, data.edge_index, index=10)
explanation.visualize_graph()           # Important subgraph
explanation.visualize_feature_importance(top_k=10)  # Feature importance
Available algorithms: `GNNExplainer` (optimization-based), `PGExplainer` (parametric, trained), `CaptumExplainer` (gradient-based via Captum), `AttentionExplainer` (attention weights). Works for both homogeneous and heterogeneous graphs.
Read `references/explainability.md` for all algorithms, heterogeneous explanations, evaluation metrics, and PGExplainer training.

Common Pitfalls


  1. edge_index shape: Must be `[2, num_edges]`, not `[num_edges, 2]`. Transpose if needed.
  2. Forgetting activations: Conv layers don't include ReLU or other activations; add them manually.
  3. Self-loops in hetero bipartite: Don't use `add_self_loops=True` when source and dest node types differ. Use skip connections instead.
  4. NeighborLoader slicing: Only the first `batch.batch_size` nodes are your seed nodes. Slice predictions and labels accordingly.
  5. Undirected graphs: If your graph is undirected, include edges in both directions in `edge_index`, or use `T.ToUndirected()`.
  6. Lazy init: Models with `-1` input channels need one forward pass with `torch.no_grad()` before training to initialize parameters.
  7. Global pooling for graph tasks: Use `global_mean_pool(x, batch)` (not manual reshape) to aggregate node features to graph-level.
  8. num_neighbors alignment: Keep `len(num_neighbors)` equal to the number of GNN layers. More hops than layers wastes compute; fewer means wasted model capacity.