import contextlib

import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv, EdgeWeightNorm

from ..utils import transform_relation_graph_list
from . import BaseModel, register_model
@register_model('GTN')
class GTN(BaseModel):
    r"""
    GTN from paper `Graph Transformer Networks <https://arxiv.org/abs/1911.06455>`__
    in NeurIPS_2019. You can also see the extension paper `Graph Transformer
    Networks: Learning Meta-path Graphs to Improve GNNs <https://arxiv.org/abs/2106.06218.pdf>`__.
    `Code from author <https://github.com/seongjunyun/Graph_Transformer_Networks>`__.

    Given a heterogeneous graph :math:`G` and its edge relation type set :math:`\mathcal{R}`, we first
    extract the list of single-relation adjacency matrices. Convolving over this list (a learned weighted
    sum of the relations) produces a combination adjacency matrix, and multiplying combination adjacency
    matrices produces an :math:`l`-length meta-path adjacency matrix. Node representations are then
    generated by a GCN layer over each meta-path graph.

    Parameters
    ----------
    num_edge_type : int
        Number of relations.
    num_channels : int
        Number of conv channels. Each channel learns its own meta-path graph; the GCN outputs of all
        channels are concatenated before the final linear layers.
    in_dim : int
        The dimension of input feature.
    hidden_dim : int
        The dimension of hidden layer.
    num_class : int
        Number of classification type.
    num_layers : int
        Number of GT layers (must be at least 1). The first layer multiplies two combination adjacency
        matrices and every further layer multiplies in one more, so the learned meta-paths have length
        ``num_layers + 1``.
    category : string
        Type of predicted nodes.
    norm : bool
        If True, after every GT layer self-loops are removed from the meta-path graphs and their
        edge weights are normalized.
    identity : bool
        If True, the identity matrix will be added to relation matrix set.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build a GTN from an argument namespace and the heterogeneous graph ``hg``."""
        # One relation per canonical edge type, plus one extra for the identity (self-loop)
        # relation when it is requested.
        num_edge_type = len(hg.canonical_etypes) + (1 if args.identity else 0)
        # NOTE(review): in_dim is taken from hidden_dim, presumably because input features are
        # projected to hidden_dim before reaching the model — confirm with the caller.
        return cls(num_edge_type=num_edge_type, num_channels=args.num_channels,
                   in_dim=args.hidden_dim, hidden_dim=args.hidden_dim, num_class=args.out_dim,
                   num_layers=args.num_layers, category=args.category, norm=args.norm_emd_flag,
                   identity=args.identity)

    def __init__(self, num_edge_type, num_channels, in_dim, hidden_dim, num_class, num_layers, category, norm,
                 identity):
        super(GTN, self).__init__()
        if num_layers < 1:
            # forward() needs at least one GT layer to produce the meta-path graphs.
            raise ValueError(f"GTN requires num_layers >= 1, got {num_layers}")
        self.num_edge_type = num_edge_type
        self.num_channels = num_channels
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.num_class = num_class
        self.num_layers = num_layers
        self.is_norm = norm
        self.category = category
        self.identity = identity
        # Only the first layer combines two convolutions; later layers extend the previous result.
        self.layers = nn.ModuleList(
            [GTLayer(num_edge_type, num_channels, first=(i == 0)) for i in range(num_layers)]
        )
        self.gcn = GraphConv(in_feats=self.in_dim, out_feats=hidden_dim, norm='none', activation=F.relu)
        self.norm = EdgeWeightNorm(norm='right')
        self.linear1 = nn.Linear(self.hidden_dim * self.num_channels, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.num_class)
        # Lazily filled on the first forward() call and reused afterwards.
        self.category_idx = None
        self.A = None
        self.h = None

    def normalization(self, H):
        """Remove self-loops from each channel's graph and right-normalize its ``'w_sum'`` edge weights."""
        norm_H = []
        for i in range(self.num_channels):
            g = dgl.remove_self_loop(H[i])
            g.edata['w_sum'] = self.norm(g, g.edata['w_sum'])
            norm_H.append(g)
        return norm_H

    def forward(self, hg, h):
        """
        Parameters
        ----------
        hg : DGLHeteroGraph
            The heterogeneous input graph.
        h : dict
            Node features assigned to ``hg.ndata['h']`` (presumably keyed by node type — confirm).

        Returns
        -------
        dict
            ``{category: logits}`` where ``logits`` are the ``linear2`` outputs (``num_class`` columns)
            for the rows selected by ``category_idx``.
        """
        with hg.local_scope():
            hg.ndata['h'] = h
            # * =============== Extract edges in original graph ================
            if self.category_idx is None:
                # First call: build and cache the relation graph list and the target-node indices.
                self.A, h, self.category_idx = transform_relation_graph_list(hg, category=self.category,
                                                                             identity=self.identity)
            else:
                # Later calls reuse the cached relation graphs and only refresh the features.
                g = dgl.to_homogeneous(hg, ndata='h')
                h = g.ndata['h']
            A = self.A
            # * =============== Get new graph structure ================
            H = None
            for i, layer in enumerate(self.layers):
                # The second return value (detached conv weights) is only useful for inspection.
                if i == 0:
                    H, _ = layer(A)
                else:
                    H, _ = layer(A, H)
                if self.is_norm:
                    H = self.normalization(H)
            # * =============== GCN Encoder ================
            channel_outputs = []
            for meta_graph in H:
                g = dgl.remove_self_loop(meta_graph)
                edge_weight = g.edata['w_sum']
                # Relies on add_self_loop appending the loop edges after the existing edges, so
                # their unit weights are concatenated at the end in the same order.
                g = dgl.add_self_loop(g)
                loop_weight = th.ones(g.num_nodes(), dtype=edge_weight.dtype, device=edge_weight.device)
                edge_weight = self.norm(g, th.cat((edge_weight, loop_weight)))
                channel_outputs.append(self.gcn(g, h, edge_weight=edge_weight))
            X_ = th.cat(channel_outputs, dim=1)
            X_ = F.relu(self.linear1(X_))
            y = self.linear2(X_)
            return {self.category: y[self.category_idx]}
class GTLayer(nn.Module):
    r"""
    GTLayer multiplies combination adjacency matrices to build an :math:`l`-length
    meta-path adjacency matrix.

    The :math:`l`-length meta-path adjacency matrix can be described as:

    .. math::
        A_{(l)}=\Pi_{i=1}^{l} A_{i}

    where :math:`A_{i}` is the combination adjacency matrix generated by GT conv.

    Parameters
    ----------
    in_channels: int
        The input dimension of GTConv, numerically equal to the number of relations.
    out_channels: int
        The output dimension of GTConv, numerically equal to the number of channels in GTN.
    first: bool
        If True, the layer owns two GTConvs and multiplies their two combination adjacency
        matrices. Otherwise it owns a single GTConv and multiplies the previous layer's
        meta-path graphs by its combination adjacency matrix.
    """

    def __init__(self, in_channels, out_channels, first=True):
        super(GTLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first = first
        # conv1 is needed in both modes; creating it first keeps parameter names and order unchanged.
        self.conv1 = GTConv(in_channels, out_channels)
        if self.first:
            self.conv2 = GTConv(in_channels, out_channels)

    def forward(self, A, H_=None):
        """
        Parameters
        ----------
        A : list
            Single-relation graphs whose edge weights are stored under ``'w'`` (read by GTConv).
        H_ : list, optional
            The previous layer's meta-path graphs (weights under ``'w_sum'``). Required when
            ``first`` is False; ignored otherwise.

        Returns
        -------
        H : list
            One product graph per channel, with edge weights stored under ``'w_sum'``.
        W : list
            Detached, softmax-normalized weights of the GTConv(s) used by this layer.

        Raises
        ------
        ValueError
            If ``first`` is False and ``H_`` is None.
        """
        if self.first:
            result_A = self.conv1(A)
            result_B = self.conv2(A)
            W = [F.softmax(self.conv1.weight, dim=1).detach(),
                 F.softmax(self.conv2.weight, dim=1).detach()]
        else:
            if H_ is None:
                raise ValueError("GTLayer with first=False requires the previous layer's graphs H_")
            result_A = H_
            result_B = self.conv1(A)
            W = [F.softmax(self.conv1.weight, dim=1).detach()]
        # Channel-wise product of the two graph lists; both have one graph per channel.
        H = [dgl.adj_product_graph(left, right, 'w_sum') for left, right in zip(result_A, result_B)]
        return H, W
class GTConv(nn.Module):
    r"""
    We conv each sub adjacency matrix :math:`A_{R_{i}}` to a combination adjacency matrix :math:`A_{1}`:

    .. math::
        A_{1} = conv\left(A ; W_{c}\right)=\sum_{R_{i} \in R} w_{R_{i}} A_{R_{i}}

    where :math:`R_i \subseteq \mathcal{R}` and :math:`W_{c}` is the weight of each relation matrix.

    Parameters
    ----------
    in_channels : int
        Number of input relation graphs (one weight per relation and channel).
    out_channels : int
        Number of output channels; each channel produces its own combination graph.
    softmax_flag : bool
        If True (default), each channel's weights are softmax-normalized over the relations.
    """

    def __init__(self, in_channels, out_channels, softmax_flag=True):
        super(GTConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Uninitialized storage; filled immediately by reset_parameters().
        self.weight = nn.Parameter(th.empty(out_channels, in_channels))
        self.softmax_flag = softmax_flag
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the relation weights from N(0, 0.01^2)."""
        nn.init.normal_(self.weight, std=0.01)

    def forward(self, A):
        """
        Parameters
        ----------
        A : list
            Relation graphs whose edge weights are stored under ``'w'``.

        Returns
        -------
        list
            ``out_channels`` new graphs; each is ``dgl.adj_sum_graph`` of the relation graphs
            scaled by that channel's weights, with edge weights under ``'w_sum'``.
        """
        Filter = F.softmax(self.weight, dim=1) if self.softmax_flag else self.weight
        results = []
        for channel_weights in Filter:
            # The scaled weights are written inside local scopes so the caller's graphs
            # (cached by GTN across calls) are not left carrying a 'w_sum' tensor.
            with contextlib.ExitStack() as stack:
                for j, g in enumerate(A):
                    stack.enter_context(g.local_scope())
                    g.edata['w_sum'] = g.edata['w'] * channel_weights[j]
                results.append(dgl.adj_sum_graph(A, 'w_sum'))
        return results