torch-geometric
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChinesePyTorch Geometric (PyG)
PyTorch Geometric (PyG)
PyG is the standard library for Graph Neural Networks built on PyTorch. It provides data structures for graphs, 60+ GNN layer implementations, scalable mini-batch training, and support for heterogeneous graphs.
Install: (or ; requires PyTorch). Optional: , , , for accelerated ops.
uv add torch_geometricuv pip install torch_geometricpyg-libtorch-scattertorch-sparsetorch-clusterPyG是基于PyTorch构建的图神经网络标准库。它提供图数据结构、60余种GNN层实现、可扩展的小批量训练,并支持异构图。
安装:(或;需依赖PyTorch)。可选安装:、、、以加速运算。
uv add torch_geometricuv pip install torch_geometricpyg-libtorch-scattertorch-sparsetorch-clusterCore Concepts
核心概念
Graph Data: Data
and HeteroData
DataHeteroData图数据:Data
与HeteroData
DataHeteroDataA graph lives in a object. The key attributes:
Datapython
from torch_geometric.data import Data
data = Data(
x=node_features, # [num_nodes, num_node_features]
edge_index=edge_index, # [2, num_edges] — COO format, dtype=torch.long
edge_attr=edge_features, # [num_edges, num_edge_features]
y=labels, # node-level [num_nodes, *] or graph-level [1, *]
pos=positions, # [num_nodes, num_dimensions] (for point clouds/spatial)
)edge_index[2, num_edges]edge_index[0]edge_index[1].contiguous()python
undefined图数据存储在对象中。核心属性如下:
Datapython
from torch_geometric.data import Data
data = Data(
x=node_features, # [节点数量, 节点特征维度]
edge_index=edge_index, # [2, 边数量] — COO格式,数据类型为torch.long
edge_attr=edge_features, # [边数量, 边特征维度]
y=labels, # 节点级标签 [节点数量, *] 或图级标签 [1, *]
pos=positions, # [节点数量, 空间维度](用于点云/空间数据)
)edge_index[2, 边数量]edge_index[0]edge_index[1].contiguous()python
undefinedIf edges are [[src1, dst1], [src2, dst2], ...] — transpose first:
如果边的格式为[[src1, dst1], [src2, dst2], ...] — 先转置:
edge_index = edge_pairs.t().contiguous()
For undirected graphs, include both directions: edge (0,1) needs both `[0,1]` and `[1,0]` in edge_index.
For heterogeneous graphs, use `HeteroData` — see the Heterogeneous Graphs section below.edge_index = edge_pairs.t().contiguous()
对于无向图,需要包含双向边:边(0,1)需要同时在edge_index中添加`[0,1]`和`[1,0]`。
对于异构图,请使用`HeteroData`——详见下文的异构图章节。Datasets
数据集
PyG bundles many standard datasets that auto-download and preprocess:
python
from torch_geometric.datasets import Planetoid, TUDatasetPyG内置了许多标准数据集,可自动下载并预处理:
python
from torch_geometric.datasets import Planetoid, TUDatasetSingle-graph node classification (Cora, Citeseer, Pubmed)
单图节点分类数据集(Cora、Citeseer、Pubmed)
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0] # single graph with train/val/test masks
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0] # 包含训练/验证/测试掩码的单张图
Multi-graph classification (ENZYMES, MUTAG, IMDB-BINARY, etc.)
多图分类数据集(ENZYMES、MUTAG、IMDB-BINARY等)
dataset = TUDataset(root='./data', name='ENZYMES')
dataset = TUDataset(root='./data', name='ENZYMES')
dataset[0], dataset[1], ... are individual graphs
dataset[0], dataset[1], ... 为单个图数据
Common datasets by task:
- **Node classification**: Planetoid (Cora/Citeseer/Pubmed), OGB (ogbn-arxiv, ogbn-products, ogbn-mag)
- **Graph classification**: TUDataset (MUTAG, ENZYMES, PROTEINS, IMDB-BINARY), OGB (ogbg-molhiv)
- **Link prediction**: OGB (ogbl-collab, ogbl-citation2)
- **Molecular**: QM7, QM9, MoleculeNet
- **Point cloud/mesh**: ShapeNet, ModelNet10/40, FAUST
按任务分类的常见数据集:
- **节点分类**:Planetoid(Cora/Citeseer/Pubmed)、OGB(ogbn-arxiv、ogbn-products、ogbn-mag)
- **图分类**:TUDataset(MUTAG、ENZYMES、PROTEINS、IMDB-BINARY)、OGB(ogbg-molhiv)
- **链接预测**:OGB(ogbl-collab、ogbl-citation2)
- **分子数据**:QM7、QM9、MoleculeNet
- **点云/网格数据**:ShapeNet、ModelNet10/40、FAUSTTransforms
变换操作
Transforms preprocess or augment graph data, analogous to torchvision transforms:
python
import torch_geometric.transforms as T变换操作用于预处理或增强图数据,类似于torchvision的变换:
python
import torch_geometric.transforms as TCommon transforms
常见变换
T.NormalizeFeatures() # Row-normalize node features to sum to 1
T.ToUndirected() # Add reverse edges to make graph undirected
T.AddSelfLoops() # Add self-loop edges
T.KNNGraph(k=6) # Build k-NN graph from point cloud positions
T.RandomJitter(0.01) # Random noise augmentation on positions
T.Compose([...]) # Chain multiple transforms
T.NormalizeFeatures() # 对节点特征进行行归一化,使其和为1
T.ToUndirected() # 添加反向边,将图转换为无向图
T.AddSelfLoops() # 添加自环边
T.KNNGraph(k=6) # 根据点云位置构建k近邻图
T.RandomJitter(0.01) # 为位置添加随机噪声增强
T.Compose([...]) # 组合多个变换
Apply as pre_transform (once, saved to disk) or transform (every access)
可设置为pre_transform(仅执行一次,保存到磁盘)或transform(每次访问时执行)
dataset = ShapeNet(root='./data', pre_transform=T.KNNGraph(k=6),
transform=T.RandomJitter(0.01))
undefineddataset = ShapeNet(root='./data', pre_transform=T.KNNGraph(k=6),
transform=T.RandomJitter(0.01))
undefinedBuilding GNN Models
构建GNN模型
Quick Start: Using Built-in Layers
快速入门:使用内置层
The fastest way to build a GNN — stack conv layers from :
torch_geometric.nnpython
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return xImportant: PyG conv layers do NOT include activation functions — apply them yourself after each layer. This is by design for flexibility.
构建GNN的最快方式——堆叠中的卷积层:
torch_geometric.nnpython
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x重要提示:PyG的卷积层不包含激活函数——需手动在每层后添加。这是为了灵活性而设计的。
Choosing a Conv Layer
选择卷积层
Pick based on your task and graph structure:
| Layer | Best for | Key idea |
|---|---|---|
| Homogeneous, semi-supervised node classification | Spectral-inspired, degree-normalized aggregation |
| When neighbor importance varies | Attention-weighted messages |
| Large graphs, inductive settings | Sampling-friendly, learnable aggregation |
| Graph classification, maximizing expressiveness | As powerful as WL test |
| Rich edge features, complex interactions | Multi-head attention with edge features |
| Point clouds, dynamic graphs | MLP on edge features (x_i, x_j - x_i) |
| Heterogeneous with many relation types | Relation-specific weight matrices |
| Heterogeneous graphs | Type-specific attention |
All conv layers accept at minimum. Many also accept for edge features.
(x, edge_index)edge_attr根据任务和图结构选择合适的卷积层:
| 层 | 适用场景 | 核心思想 |
|---|---|---|
| 同构图、半监督节点分类 | 基于谱方法,度归一化聚合 |
| 邻居重要性存在差异的场景 | 注意力加权消息传递 |
| 大图、归纳式场景 | 支持采样,可学习聚合方式 |
| 图分类、最大化表达能力 | 与WL测试一样强大 |
| 边特征丰富、交互复杂的场景 | 结合边特征的多头注意力 |
| 点云、动态图 | 基于边特征(x_i, x_j - x_i)的MLP |
| 包含多种关系类型的异构图 | 关系特定的权重矩阵 |
| 异构图 | 类型特定的注意力机制 |
所有卷积层至少接受作为输入。许多层还支持传入以处理边特征。
(x, edge_index)edge_attrLazy Initialization
延迟初始化
Use for input channels to let PyG infer dimensions automatically — especially useful for heterogeneous models:
-1python
conv = SAGEConv((-1, -1), 64) # Input dims inferred on first forward pass输入通道使用,让PyG自动推断维度——在异构模型中尤其有用:
-1python
conv = SAGEConv((-1, -1), 64) # 首次前向传播时自动推断输入维度Initialize lazy modules:
初始化延迟模块:
with torch.no_grad():
out = model(data.x, data.edge_index)
undefinedwith torch.no_grad():
out = model(data.x, data.edge_index)
undefinedHigh-Level Model APIs
高级模型API
For common architectures, PyG provides ready-made model classes:
python
from torch_geometric.nn import GraphSAGE, GCN, GAT, GIN
model = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes,
num_layers=2,
)对于常见架构,PyG提供了现成的模型类:
python
from torch_geometric.nn import GraphSAGE, GCN, GAT, GIN
model = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes,
num_layers=2,
)Custom Layers via MessagePassing
通过MessagePassing实现自定义层
To implement a novel GNN layer, subclass . The framework is:
MessagePassing- orchestrates the message passing
propagate() - defines what info flows along each edge (the phi function)
message() - combines messages at each node (sum/mean/max)
aggregate() - transforms the aggregated result (the gamma function)
update()
python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class MyConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "add", "mean", or "max"
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Pre-processing before message passing
x = self.lin(x)
# Start message passing
return self.propagate(edge_index, x=x)
def message(self, x_j):
# x_j: features of source nodes for each edge [num_edges, features]
# The _j suffix auto-indexes source nodes, _i indexes target nodes
return x_jThe / convention: any tensor passed to can be auto-indexed by appending (target/central node) or (source/neighbor node) in the signature. So if you pass to propagate, you can access and in message().
_i_jpropagate()_i_jmessage()x=...x_ix_jRead for the full GCN and EdgeConv implementation examples.
references/message_passing.md要实现新型GNN层,需继承。框架如下:
MessagePassing- 协调消息传递流程
propagate() - 定义沿每条边传递的信息(phi函数)
message() - 汇总每个节点收到的消息(求和/均值/最大值)
aggregate() - 转换汇总后的结果(gamma函数)
update()
python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class MyConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "add"、"mean"或"max"
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# 消息传递前的预处理
x = self.lin(x)
# 启动消息传递
return self.propagate(edge_index, x=x)
def message(self, x_j):
# x_j: 每条边的源节点特征 [边数量, 特征维度]
# 后缀_j自动索引源节点,_i索引目标节点
return x_j_i_jpropagate()message()_i_jx=...x_ix_j查看获取完整的GCN和EdgeConv实现示例。
references/message_passing.mdTask-Specific Patterns
任务特定模式
Node Classification
节点分类
python
undefinedpython
undefinedFull-batch training on a single graph (e.g., Cora)
在单张图上进行全批量训练(如Cora)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
Evaluation
评估
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
undefinedmodel.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
undefinedGraph Classification
图分类
Multiple graphs — use for mini-batching and global pooling to get graph-level representations:
DataLoaderpython
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
loader = DataLoader(dataset, batch_size=32, shuffle=True)
class GraphClassifier(torch.nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch):
super().__init__()
self.conv1 = GCNConv(in_ch, hidden_ch)
self.conv2 = GCNConv(hidden_ch, hidden_ch)
self.lin = torch.nn.Linear(hidden_ch, out_ch)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
x = global_mean_pool(x, batch) # [num_graphs_in_batch, hidden_ch]
return self.lin(x)处理多张图时——使用进行小批量处理,并通过全局池化得到图级表示:
DataLoaderpython
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
loader = DataLoader(dataset, batch_size=32, shuffle=True)
class GraphClassifier(torch.nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch):
super().__init__()
self.conv1 = GCNConv(in_ch, hidden_ch)
self.conv2 = GCNConv(hidden_ch, hidden_ch)
self.lin = torch.nn.Linear(hidden_ch, out_ch)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
x = global_mean_pool(x, batch) # [批量中的图数量, 隐藏层维度]
return self.lin(x)Training loop
训练循环
for data in loader:
out = model(data.x, data.edge_index, data.batch)
loss = F.cross_entropy(out, data.y)
PyG's `DataLoader` batches multiple graphs by creating block-diagonal adjacency matrices. The `batch` tensor maps each node to its graph index. Pooling ops (`global_mean_pool`, `global_max_pool`, `global_add_pool`) use this to aggregate per-graph.for data in loader:
out = model(data.x, data.edge_index, data.batch)
loss = F.cross_entropy(out, data.y)
PyG的`DataLoader`通过创建块对角邻接矩阵来批量处理多张图。`batch`张量将每个节点映射到其所属的图索引。池化操作(`global_mean_pool`、`global_max_pool`、`global_add_pool`)利用该张量进行图级聚合。Link Prediction
链接预测
Split edges into train/val/test, use negative sampling:
python
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(
num_val=0.1,
num_test=0.1,
is_undirected=True,
add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)将边划分为训练/验证/测试集,使用负采样:
python
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(
num_val=0.1,
num_test=0.1,
is_undirected=True,
add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)Encode nodes, then score edges
编码节点,然后对边进行评分
z = model.encode(train_data.x, train_data.edge_index)
z = model.encode(train_data.x, train_data.edge_index)
Positive edges
正边评分
pos_score = (z[train_data.edge_label_index[0]] * z[train_data.edge_label_index[1]]).sum(dim=1)
Read `references/link_prediction.md` for the complete link prediction guide: GAE/VGAE autoencoders, full training loops, LinkNeighborLoader for large graphs, heterogeneous link prediction, and evaluation metrics.pos_score = (z[train_data.edge_label_index[0]] * z[train_data.edge_label_index[1]]).sum(dim=1)
查看`references/link_prediction.md`获取完整的链接预测指南:GAE/VGAE自动编码器、完整训练循环、适用于大图的LinkNeighborLoader、异构链接预测以及评估指标。Scaling to Large Graphs
扩展到大图
For graphs that don't fit in GPU memory, use neighbor sampling via :
NeighborLoaderpython
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[15, 10], # Sample 15 neighbors in hop 1, 10 in hop 2
batch_size=128, # Number of seed nodes per batch
input_nodes=data.train_mask, # Which nodes to sample from
shuffle=True,
)
for batch in train_loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
# Only use first batch_size nodes for loss (these are the seed nodes)
loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])Key points about NeighborLoader:
- list length should match GNN depth (number of message passing layers)
num_neighbors - Seed nodes are always the first nodes in the output
batch.batch_size - maps relabeled indices back to original node IDs
batch.n_id - Works for both and
DataHeteroData - For link prediction, use instead
LinkNeighborLoader - Sampling more than 2-3 hops is generally infeasible (exponential blowup)
Other scalability options: (ClusterGCN), , . For multi-GPU training, DDP, PyTorch Lightning integration, and support, read .
ClusterLoaderGraphSAINTSamplerShaDowKHopSamplertorch.compilereferences/scaling.md对于无法放入GPU内存的图,使用进行邻居采样:
NeighborLoaderpython
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[15, 10], # 第1跳采样15个邻居,第2跳采样10个邻居
batch_size=128, # 每个批量的种子节点数量
input_nodes=data.train_mask, # 采样的节点来源
shuffle=True,
)
for batch in train_loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
# 仅使用前batch.batch_size个节点计算损失(这些是种子节点)
loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])NeighborLoader核心要点:
- 列表长度应与GNN的深度(消息传递层数)匹配
num_neighbors - 种子节点始终是输出中的前个节点
batch.batch_size - 将重标记的索引映射回原始节点ID
batch.n_id - 适用于和
DataHeteroData - 链接预测请使用
LinkNeighborLoader - 采样超过2-3跳通常不可行(指数级增长)
其他扩展方案:(ClusterGCN)、、。关于多GPU训练、DDP、PyTorch Lightning集成和支持,请查看。
ClusterLoaderGraphSAINTSamplerShaDowKHopSamplertorch.compilereferences/scaling.mdHeterogeneous Graphs
异构图
For graphs with multiple node and edge types (social networks, knowledge graphs, recommendation):
python
from torch_geometric.data import HeteroData
data = HeteroData()适用于包含多种节点和边类型的图(社交网络、知识图谱、推荐系统):
python
from torch_geometric.data import HeteroData
data = HeteroData()Node features — indexed by node type string
节点特征——按节点类型字符串索引
data['user'].x = torch.randn(1000, 64)
data['movie'].x = torch.randn(500, 128)
data['user'].x = torch.randn(1000, 64)
data['movie'].x = torch.randn(500, 128)
Edge indices — indexed by (src_type, edge_type, dst_type) triplet
边索引——按(源类型, 边类型, 目标类型)三元组索引
data['user', 'rates', 'movie'].edge_index = torch.randint(0, 500, (2, 3000))
data['user', 'follows', 'user'].edge_index = torch.randint(0, 1000, (2, 5000))
data['user', 'rates', 'movie'].edge_index = torch.randint(0, 500, (2, 3000))
data['user', 'follows', 'user'].edge_index = torch.randint(0, 1000, (2, 5000))
Access convenience dicts
便捷访问字典
data.x_dict # {'user': tensor, 'movie': tensor}
data.edge_index_dict # {('user','rates','movie'): tensor, ...}
data.metadata() # ([node_types], [edge_types])
undefineddata.x_dict # {'user': 张量, 'movie': 张量}
data.edge_index_dict # {('user','rates','movie'): 张量, ...}
data.metadata() # ([节点类型列表], [边类型列表])
undefinedThree ways to build heterogeneous GNNs
构建异构GNN的三种方式
1. Auto-convert with — write a homogeneous model, convert automatically:
to_hetero()python
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
model = GNN(64, dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')1. 使用自动转换——编写同构模型,自动转换为异构模型:
to_hetero()python
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
model = GNN(64, dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')Now accepts dicts:
现在接受字典作为输入:
out = model(data.x_dict, data.edge_index_dict)
Use `(-1, -1)` for bipartite input channels (source, target may differ). Lazy init handles the rest.
**2. `HeteroConv` wrapper** — different conv per edge type:
```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv
conv = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'rev_writes', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')3. Native heterogeneous operators like :
HGTConvpython
from torch_geometric.nn import HGTConv
conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads=4)Important for heterogeneous graphs:
- Use to add reverse edge types for bidirectional message flow
T.ToUndirected() - Disable in bipartite conv layers (different source/dest types) — use skip connections instead:
add_self_loopsconv(x, edge_index) + lin(x) - For NeighborLoader on HeteroData, specify as
input_nodestuple('node_type', mask) - can be a dict keyed by edge type for fine-grained control
num_neighbors
Read for complete examples including training loops and NeighborLoader usage with heterogeneous graphs.
references/heterogeneous.mdout = model(data.x_dict, data.edge_index_dict)
对于 bipartite 输入通道(源和目标维度可能不同),使用`(-1, -1)`。延迟初始化会处理其余部分。
**2. `HeteroConv`包装器**——为每种边类型使用不同的卷积层:
```python
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv
conv = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'rev_writes', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')3. 原生异构算子如:
HGTConvpython
from torch_geometric.nn import HGTConv
conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads=4)异构图重要提示:
- 使用添加反向边类型,实现双向消息传递
T.ToUndirected() - 在 bipartite 卷积层中禁用(源和目标节点类型不同)——改用跳跃连接:
add_self_loopsconv(x, edge_index) + lin(x) - 在HeteroData上使用NeighborLoader时,需将指定为
input_nodes元组('节点类型', 掩码) - 可以是按边类型键控的字典,实现细粒度控制
num_neighbors
查看获取完整示例,包括训练循环和异构图的NeighborLoader用法。
references/heterogeneous.mdCustom Datasets
自定义数据集
For loading your own data into PyG:
- Quick (no class needed): Create objects directly and pass a list to
DataDataLoader - Reusable (fits in RAM): Subclass — override
InMemoryDataset,raw_file_names,processed_file_names,download()process() - Large (disk-backed): Subclass — also override
Datasetandlen()get() - From CSV: Load node/edge tables with pandas, build mappings to consecutive indices, assemble into or
DataHeteroData - From NetworkX: converts a NetworkX graph directly
from_networkx(G) - From scipy sparse: extracts edge_index
from_scipy_sparse_matrix(adj)
Read for complete examples with all patterns, CSV loading with encoders, and the MovieLens walkthrough.
references/custom_datasets.md将自有数据加载到PyG中:
- 快速方式(无需类):直接创建对象,将列表传入
DataDataLoader - 可复用方式(可放入内存):继承——重写
InMemoryDataset、raw_file_names、processed_file_names、download()process() - 大数据方式(磁盘存储):继承——还需重写
Dataset和len()get() - 从CSV加载:使用pandas加载节点/边表,构建连续索引映射,组装成或
DataHeteroData - 从NetworkX加载:直接将NetworkX图转换为PyG格式
from_networkx(G) - 从scipy稀疏矩阵加载:提取edge_index
from_scipy_sparse_matrix(adj)
查看获取所有模式的完整示例、带编码器的CSV加载以及MovieLens教程。
references/custom_datasets.mdExplainability
可解释性
PyG provides for interpreting GNN predictions:
torch_geometric.explainpython
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
explanation = explainer(data.x, data.edge_index, index=10)
explanation.visualize_graph() # Important subgraph
explanation.visualize_feature_importance(top_k=10) # Feature importanceAvailable algorithms: (optimization-based), (parametric, trained), (gradient-based via Captum), (attention weights). Works for both homogeneous and heterogeneous graphs.
GNNExplainerPGExplainerCaptumExplainerAttentionExplainerRead for all algorithms, heterogeneous explanations, evaluation metrics, and PGExplainer training.
references/explainability.mdPyG提供用于解释GNN预测:
torch_geometric.explainpython
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
explanation = explainer(data.x, data.edge_index, index=10)
explanation.visualize_graph() # 重要子图可视化
explanation.visualize_feature_importance(top_k=10) # 特征重要性可视化可用算法:(基于优化)、(参数化,需训练)、(基于梯度,依赖Captum)、(基于注意力权重)。适用于同构图和异构图。
GNNExplainerPGExplainerCaptumExplainerAttentionExplainer查看获取所有算法、异构解释、评估指标以及PGExplainer训练方法。
references/explainability.mdCommon Pitfalls
常见陷阱
- edge_index shape: Must be , not
[2, num_edges]. Transpose if needed.[num_edges, 2] - Forgetting activations: Conv layers don't include ReLU/etc — add them manually.
- Self-loops in hetero bipartite: Don't use when source and dest node types differ. Use skip connections instead.
add_self_loops=True - NeighborLoader slicing: Only the first nodes are your seed nodes. Slice predictions and labels accordingly.
batch.batch_size - Undirected graphs: If your graph is undirected, include edges in both directions in , or use
edge_index.T.ToUndirected() - Lazy init: Models with input channels need one forward pass with
-1before training to initialize parameters.torch.no_grad() - Global pooling for graph tasks: Use (not manual reshape) to aggregate node features to graph-level.
global_mean_pool(x, batch) - num_neighbors alignment: Keep equal to the number of GNN layers. More hops than layers wastes compute; fewer means wasted model capacity.
len(num_neighbors)
- edge_index形状:必须是,而非
[2, 边数量]。必要时进行转置。[边数量, 2] - 忘记添加激活函数:卷积层不包含ReLU等激活函数——需手动添加。
- 异构 bipartite 图中的自环:源和目标节点类型不同时,不要设置。改用跳跃连接。
add_self_loops=True - NeighborLoader切片:仅前个节点是种子节点。需相应地切片预测结果和标签。
batch.batch_size - 无向图处理:如果是无向图,需在中包含双向边,或使用
edge_index。T.ToUndirected() - 延迟初始化:输入通道为的模型,在训练前需执行一次带
-1的前向传播以初始化参数。torch.no_grad() - 图任务的全局池化:使用(而非手动重塑)将节点特征聚合为图级特征。
global_mean_pool(x, batch) - num_neighbors对齐:保持与GNN层数相等。跳数多于层数会浪费计算资源;跳数少于层数则会浪费模型容量。
len(num_neighbors)