# Source code for openhgnn.models.GTN_sparse

import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv, EdgeWeightNorm
from ..utils import transform_relation_graph_list
from . import BaseModel, register_model


@register_model('GTN')
class GTN(BaseModel):
    r"""
    GTN from paper `Graph Transformer Networks <https://arxiv.org/abs/1911.06455>`__
    in NeurIPS_2019. You can also see the extension paper `Graph Transformer Networks:
    Learning Meta-path Graphs to Improve GNNs <https://arxiv.org/abs/2106.06218.pdf>`__.

    `Code from author <https://github.com/seongjunyun/Graph_Transformer_Networks>`__.

    Given a heterogeneous graph :math:`G` and its edge relation type set
    :math:`\mathcal{R}`, we first extract the list of single-relation adjacency
    matrices. A combination adjacency matrix is generated by convolving this
    list, and an :math:`l`-length meta-path adjacency matrix is generated by
    multiplying combination adjacency matrices. Node representations are then
    produced by a GCN layer applied on each resulting meta-path graph.

    Parameters
    ----------
    num_edge_type : int
        Number of relations.
    num_channels : int
        Number of conv channels.
    in_dim : int
        The dimension of input feature.
    hidden_dim : int
        The dimension of hidden layer.
    num_class : int
        Number of classification type.
    num_layers : int
        Length of hybrid metapath. Must be at least 1.
    category : string
        Type of predicted nodes.
    norm : bool
        If True, the adjacency matrix will be normalized.
    identity : bool
        If True, the identity matrix will be added to relation matrix set.

    Raises
    ------
    ValueError
        If ``num_layers`` is smaller than 1.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build a GTN from the experiment ``args`` and the heterograph ``hg``."""
        # The optional identity relation counts as one additional edge type.
        num_edge_type = len(hg.canonical_etypes) + (1 if args.identity else 0)
        return cls(num_edge_type=num_edge_type,
                   num_channels=args.num_channels,
                   in_dim=args.hidden_dim,
                   hidden_dim=args.hidden_dim,
                   num_class=args.out_dim,
                   num_layers=args.num_layers,
                   category=args.category,
                   norm=args.norm_emd_flag,
                   identity=args.identity)

    def __init__(self, num_edge_type, num_channels, in_dim, hidden_dim, num_class,
                 num_layers, category, norm, identity):
        super().__init__()
        if num_layers < 1:
            # forward() needs at least one GTLayer to produce the meta-path graphs.
            raise ValueError(
                'num_layers must be at least 1, got {}'.format(num_layers))
        self.num_edge_type = num_edge_type
        self.num_channels = num_channels
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.num_class = num_class
        self.num_layers = num_layers
        self.is_norm = norm
        self.category = category
        self.identity = identity

        # Only the first layer builds both operands of the product itself;
        # later layers multiply the previous result by one new combination.
        self.layers = nn.ModuleList(
            [GTLayer(num_edge_type, num_channels, first=(i == 0))
             for i in range(num_layers)]
        )
        self.gcn = GraphConv(in_feats=self.in_dim, out_feats=hidden_dim,
                             norm='none', activation=F.relu)
        self.norm = EdgeWeightNorm(norm='right')
        self.linear1 = nn.Linear(self.hidden_dim * self.num_channels, self.hidden_dim)
        self.linear2 = nn.Linear(self.hidden_dim, self.num_class)

        # Filled lazily on the first forward() call and reused afterwards.
        self.category_idx = None
        self.A = None
        self.h = None

    def normalization(self, H):
        """Return self-loop-free copies of the first ``num_channels`` graphs in
        ``H`` whose ``'w_sum'`` edge weights have been passed through
        ``self.norm``."""
        norm_H = []
        for i in range(self.num_channels):
            g = dgl.remove_self_loop(H[i])
            g.edata['w_sum'] = self.norm(g, g.edata['w_sum'])
            norm_H.append(g)
        return norm_H

    def forward(self, hg, h):
        """Compute predictions for the nodes of type ``self.category``.

        Parameters
        ----------
        hg : DGLHeteroGraph
            The input heterogeneous graph.
        h : dict or tensor
            Node features; stored as ``hg.ndata['h']`` within a local scope.

        Returns
        -------
        dict
            ``{self.category: logits}`` where ``logits`` are the rows of the
            output selected by the cached category index.
        """
        with hg.local_scope():
            hg.ndata['h'] = h
            # * =============== Extract edges in original graph ================
            if self.category_idx is None:
                # First call: build and cache the relation graph list and the
                # index of the target-category nodes.
                self.A, h, self.category_idx = transform_relation_graph_list(
                    hg, category=self.category, identity=self.identity)
            else:
                # Later calls reuse the cached structure and only refresh the
                # homogeneous feature tensor.
                # NOTE(review): this assumes the same graph is passed on every
                # call - confirm with callers.
                g = dgl.to_homogeneous(hg, ndata='h')
                h = g.ndata['h']
            A = self.A

            # * =============== Get new graph structure ================
            H = None
            for i, layer in enumerate(self.layers):
                if i == 0:
                    H, _ = layer(A)
                else:
                    H, _ = layer(A, H)
                if self.is_norm:
                    H = self.normalization(H)

            # * =============== GCN Encoder ================
            channel_outputs = []
            for i in range(self.num_channels):
                g = dgl.remove_self_loop(H[i])
                edge_weight = g.edata['w_sum']
                g = dgl.add_self_loop(g)
                # One unit weight per node for the appended self-loop edges,
                # created with the same dtype/device as the existing weights so
                # the concatenation does not rely on implicit type promotion.
                self_loop_weight = edge_weight.new_ones(g.number_of_nodes())
                edge_weight = th.cat((edge_weight, self_loop_weight))
                edge_weight = self.norm(g, edge_weight)
                channel_outputs.append(self.gcn(g, h, edge_weight=edge_weight))
            # Concatenate all channels once along the feature dimension.
            X_ = th.cat(channel_outputs, dim=1)

            X_ = self.linear1(X_)
            X_ = F.relu(X_)
            y = self.linear2(X_)
            return {self.category: y[self.category_idx]}
class GTLayer(nn.Module):
    r"""
    GTLayer multiplies combination adjacency matrices to build the adjacency
    matrix of an :math:`l`-length meta-path.

    The :math:`l`-length meta-path adjacency matrix can be described as:

    .. math::
        A_{(l)}=\Pi_{i=1}^{l} A_{i}

    where :math:`A_{i}` is the combination adjacency matrix generated by GT conv.

    Parameters
    ----------
    in_channels: int
        The input dimension of GTConv which is numerically equal to the number
        of relations.
    out_channels: int
        The output dimension of GTConv which is numerically equal to the number
        of channels in GTN.
    first: bool
        If true, this layer generates two combination adjacency matrices and
        multiplies them together. Otherwise it generates a single one and
        multiplies the result passed in as ``H_`` by it.
    """

    def __init__(self, in_channels, out_channels, first=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first = first
        self.conv1 = GTConv(in_channels, out_channels)
        if self.first:
            # Only the first layer needs a second conv for the left operand.
            self.conv2 = GTConv(in_channels, out_channels)

    def forward(self, A, H_=None):
        """Multiply per-channel graphs and return them with the conv weights.

        Parameters
        ----------
        A : list
            Single-relation graphs passed to :class:`GTConv`.
        H_ : list, optional
            Per-channel graphs from the previous layer. Required when
            ``first`` is False; ignored when ``first`` is True.

        Returns
        -------
        tuple
            ``(H, W)`` where ``H`` is the list of product graphs (one per
            channel) and ``W`` is the list of detached, softmax-normalized
            weights of the convs used by this layer.

        Raises
        ------
        ValueError
            If ``first`` is False and ``H_`` is None, or if the two operand
            lists do not contain the same number of graphs.
        """
        if self.first:
            result_A = self.conv1(A)
            result_B = self.conv2(A)
            W = [F.softmax(self.conv1.weight, dim=1).detach(),
                 F.softmax(self.conv2.weight, dim=1).detach()]
        else:
            if H_ is None:
                # Fail before doing any work instead of crashing on len(None).
                raise ValueError(
                    'H_ must be provided when GTLayer is built with first=False')
            result_A = H_
            result_B = self.conv1(A)
            W = [F.softmax(self.conv1.weight, dim=1).detach()]

        if len(result_A) != len(result_B):
            raise ValueError(
                'channel count mismatch: {} vs {}'.format(
                    len(result_A), len(result_B)))

        H = [dgl.adj_product_graph(left, right, 'w_sum')
             for left, right in zip(result_A, result_B)]
        return H, W


class GTConv(nn.Module):
    r"""
    We conv each sub adjacency matrix :math:`A_{R_{i}}` to a combination
    adjacency matrix :math:`A_{1}`:

    .. math::
        A_{1} = conv\left(A ; W_{c}\right)=\sum_{R_{i} \in R} w_{R_{i}} A_{R_{i}}

    where :math:`R_i \in \mathcal{R}` and :math:`W_{c}` is the weight of each
    relation matrix.

    Parameters
    ----------
    in_channels: int
        Number of input graphs (relations) that are combined.
    out_channels: int
        Number of combination graphs (channels) that are produced.
    softmax_flag: bool
        If True, the weights of every output channel are softmax-normalized
        over the input relations before being applied.
    """

    def __init__(self, in_channels, out_channels, softmax_flag=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Uninitialized storage; values are set by reset_parameters() below.
        self.weight = nn.Parameter(th.empty(out_channels, in_channels))
        self.softmax_flag = softmax_flag
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weights from a normal distribution (std 0.01)."""
        nn.init.normal_(self.weight, std=0.01)

    def forward(self, A):
        """Return one weighted-sum graph per output channel.

        For every output channel, the ``'w'`` edge feature of each graph in
        ``A`` is scaled by that channel's weight for the graph and the graphs
        are summed with ``dgl.adj_sum_graph``.

        Note that the scaled weights are written to ``edata['w_sum']`` of the
        graphs in ``A`` in place, so the inputs are modified by this call.
        """
        if self.softmax_flag:
            Filter = F.softmax(self.weight, dim=1)
        else:
            Filter = self.weight

        results = []
        for i in range(Filter.shape[0]):
            channel_weights = Filter[i]
            for j, g in enumerate(A):
                g.edata['w_sum'] = g.edata['w'] * channel_weights[j]
            results.append(dgl.adj_sum_graph(A, 'w_sum'))
        return results