import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from . import BaseModel, register_model
@register_model('RGCN')
class RGCN(BaseModel):
    r"""
    **Title:** `Modeling Relational Data with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`_

    **Authors:** Michael Schlichtkrull, Thomas N. Kipf, Peter Bloem, Rianne van den Berg, Ivan Titov, Max Welling

    Parameters
    ----------
    in_dim : int
        Input feature size.
    hidden_dim : int
        Hidden dimension.
    out_dim : int
        Output feature size.
    etypes : list[str]
        Relation names.
    num_bases : int, optional
        Number of bases. If None (or out of the valid range), use the number
        of relations. Default: None.
    num_hidden_layers : int
        Number of hidden-to-hidden RelGraphConvLayer.
    dropout : float, optional
        Dropout rate. Default: 0.0
    use_self_loop : bool, optional
        True to include self loop message. Default: False

    Attributes
    -----------
    RelGraphConvLayer: RelGraphConvLayer
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        # NOTE(review): in_dim is deliberately args.hidden_dim rather than a raw
        # input size — presumably node features are projected to hidden_dim
        # upstream of this model; confirm against the trainer/feature pipeline.
        # num_layers counts input and output layers, hence "- 2" hidden layers.
        return cls(args.hidden_dim,
                   args.hidden_dim,
                   args.out_dim,
                   hg.etypes,
                   args.n_bases,
                   args.num_layers - 2,
                   dropout=args.dropout)

    def __init__(self, in_dim,
                 hidden_dim,
                 out_dim,
                 etypes,
                 num_bases,
                 num_hidden_layers=1,
                 dropout=0,
                 use_self_loop=False):
        super(RGCN, self).__init__()
        self.in_dim = in_dim
        self.h_dim = hidden_dim
        self.out_dim = out_dim
        # Deduplicate and sort relation names so per-relation weight ordering
        # is deterministic across runs.
        self.rel_names = sorted(set(etypes))
        # Fall back to one basis per relation when num_bases is None or out of
        # range — the documented contract says None means "number of relations",
        # and the original `num_bases < 0` comparison crashed on None.
        if num_bases is None or num_bases < 0 or num_bases > len(self.rel_names):
            self.num_bases = len(self.rel_names)
        else:
            self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop

        self.layers = nn.ModuleList()
        # input -> hidden
        self.layers.append(RelGraphConvLayer(
            self.in_dim, self.h_dim, self.rel_names,
            self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
            dropout=self.dropout, weight=True))
        # hidden -> hidden
        for _ in range(self.num_hidden_layers):
            self.layers.append(RelGraphConvLayer(
                self.h_dim, self.h_dim, self.rel_names,
                self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
                dropout=self.dropout))
        # hidden -> output (no activation on the final layer)
        self.layers.append(RelGraphConvLayer(
            self.h_dim, self.out_dim, self.rel_names,
            self.num_bases, activation=None,
            self_loop=self.use_self_loop))

    def forward(self, hg, h_dict):
        r"""
        Support full-batch and mini-batch training.

        Parameters
        ----------
        hg : dgl.DGLHeteroGraph or list[dgl.Block]
            Input graph (full-batch) or the list of sampled blocks, one per
            layer (mini-batch).
        h_dict : dict[str, th.Tensor]
            Input node features keyed by node type.

        Returns
        -------
        h_dict : dict[str, th.Tensor]
            Output node features keyed by node type.
        """
        if hasattr(hg, 'ntypes'):
            # Full-graph training: every layer runs on the same graph.
            for layer in self.layers:
                h_dict = layer(hg, h_dict)
        else:
            # Mini-batch training: one sampled block per layer.
            for layer, block in zip(self.layers, hg):
                h_dict = layer(block, h_dict)
        return h_dict

    def l2_penalty(self):
        # NOTE(review): self.layers[0].weight only exists when the first layer
        # does NOT use basis decomposition (RelGraphConvLayer.use_basis False);
        # with bases this raises AttributeError. Also th.norm(..., dim=1)
        # returns a per-relation tensor rather than a scalar — callers
        # presumably reduce it (e.g. .sum()) before adding to the loss; confirm.
        loss = 0.0005 * th.norm(self.layers[0].weight, p=2, dim=1)
        return loss
class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer.

    We use `HeteroGraphConv <https://docs.dgl.ai/api/python/nn.pytorch.html#heterographconv>`_ to implement the model.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    num_bases : int
        Number of bases. Basis decomposition is used only when this is
        smaller than ``len(rel_names)``.
    weight : bool, optional
        True if a linear layer is applied after message passing. Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    dropout : float, optional
        Dropout rate. Default: 0.0
    batchnorm : bool, optional
        True to apply batch normalization to the layer output.
        Default: False (this was previously hard-coded; the default keeps
        the original behavior).
    """

    def __init__(self,
                 in_feat,
                 out_feat,
                 rel_names,
                 num_bases,
                 *,
                 weight=True,
                 bias=True,
                 activation=None,
                 self_loop=False,
                 dropout=0.0,
                 batchnorm=False):
        super(RelGraphConvLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        # Previously hard-coded to False (dead branch); now configurable.
        self.batchnorm = batchnorm

        # One GraphConv per relation. Per-relation weights are injected at
        # forward time via mod_kwargs, hence weight=False/bias=False here.
        self.conv = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
            for rel in rel_names
        })

        self.use_weight = weight
        # Basis decomposition only applies when there are fewer bases than
        # relations; otherwise a full per-relation weight tensor is used.
        self.use_basis = num_bases < len(self.rel_names) and weight
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
            else:
                self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
                nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight,
                                    gain=nn.init.calculate_gain('relu'))

        # optional batch norm on the layer output
        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_feat)

        self.dropout = nn.Dropout(dropout)

    def forward(self, g, inputs):
        """Forward computation.

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph (or a sampled block in mini-batch training).
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            # Split the (num_rels, in_feat, out_feat) tensor into one
            # (in_feat, out_feat) matrix per relation, keyed the way
            # HeteroGraphConv's mod_kwargs expects.
            wdict = {self.rel_names[i]: {'weight': w.squeeze(0)}
                     for i, w in enumerate(th.split(weight, 1, dim=0))}
        else:
            wdict = {}
        if g.is_block:
            # In a block, destination-node features are the leading rows of
            # the source features.
            inputs_src = inputs
            inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            inputs_src = inputs_dst = inputs
        hs = self.conv(g, inputs_src, mod_kwargs=wdict)

        def _apply(ntype, h):
            # Post-aggregation transform, in order: self-loop term, bias,
            # activation, batch norm, dropout.
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            if self.batchnorm:
                h = self.bn(h)
            return self.dropout(h)
        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}