# Source code for openhgnn.models.RHGNN

import dgl
import torch as th
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from . import BaseModel, register_model
import tqdm
import torch

from dgl.ops import edge_softmax
import dgl.function as fn


@register_model('RHGNN')
class RHGNN(BaseModel):
    r"""
    This is the main method of model RHGNN.

    Node features of every type are projected to ``n_heads * hidden_dim``, propagated through stacked
    ``R_HGNN_Layer`` modules that keep one representation per relation, fused across relations by
    ``RelationFusing`` and, in ``forward``, classified for node type ``category``.

    Parameters
    ----------
    graph: dgl.DGLHeteroGraph
        a heterogeneous graph
    input_dim_dict: dict
        node input dimension dictionary, {ntype: input_dim}
    hidden_dim: int
        node hidden dimension of each attention head
    relation_input_dim: int
        relation input dimension
    relation_hidden_dim: int
        relation hidden dimension of each attention head
    num_layers: int
        number of stacked layers
    category: node type
        node type whose fused representation is fed to the classifier in ``forward``
    out_dim: int
        output dimension of the classifier
    n_heads: int
        number of attention heads
    dropout: float
        dropout rate
    negative_slope: float
        negative slope of the LeakyReLU used by the attention modules
    residual: boolean
        residual connections or not
    norm: boolean
        layer normalization or not
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build the model from parsed arguments; node input sizes are read from ``hg.nodes[ntype].data['h']``."""
        input_dim_dict = {ntype: hg.nodes[ntype].data['h'].shape[1] for ntype in hg.ntypes}
        return cls(graph=hg,
                   input_dim_dict=input_dim_dict,
                   hidden_dim=args.hidden_dim,
                   relation_input_dim=args.relation_hidden_units,
                   relation_hidden_dim=args.relation_hidden_units,
                   num_layers=args.num_layers,
                   category=args.category,
                   out_dim=args.out_dim)

    def __init__(self, graph: dgl.DGLHeteroGraph, input_dim_dict, hidden_dim: int, relation_input_dim: int,
                 relation_hidden_dim: int, num_layers: int, category, out_dim, n_heads: int = 4,
                 dropout: float = 0.2, negative_slope: float = 0.2, residual: bool = True, norm: bool = True):
        super(RHGNN, self).__init__()
        self.category = category
        self.input_dim_dict = input_dim_dict
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.relation_input_dim = relation_input_dim
        self.relation_hidden_dim = relation_hidden_dim
        self.n_heads = n_heads
        self.dropout = dropout
        self.negative_slope = negative_slope
        self.residual = residual
        self.out_dim = out_dim
        self.norm = norm

        # relation embedding dictionary, {etype: (relation_input_dim, 1)}; flattened before use
        self.relation_embedding = nn.ParameterDict({
            etype: nn.Parameter(torch.randn(relation_input_dim, 1)) for etype in graph.etypes
        })

        # align the dimension of different types of nodes to n_heads * hidden_dim
        self.projection_layer = nn.ModuleDict({
            ntype: nn.Linear(self.input_dim_dict[ntype], hidden_dim * n_heads) for ntype in input_dim_dict
        })

        # each layer takes in the heterogeneous graph as input
        self.layers = nn.ModuleList()
        # the first layer consumes the raw relation embeddings (relation_input_dim)
        self.layers.append(R_HGNN_Layer(graph, hidden_dim * n_heads, hidden_dim, relation_input_dim,
                                        relation_hidden_dim, n_heads, dropout, negative_slope, residual, norm))
        # every layer's relation propagation outputs n_heads * relation_hidden_dim, which the next layer consumes
        for _ in range(1, self.num_layers):
            self.layers.append(R_HGNN_Layer(graph, hidden_dim * n_heads, hidden_dim, relation_hidden_dim * n_heads,
                                            relation_hidden_dim, n_heads, dropout, negative_slope, residual, norm))

        # transformation matrix for target node representation under each relation
        self.node_transformation_weight = nn.ParameterDict({
            etype: nn.Parameter(torch.randn(n_heads, hidden_dim, hidden_dim)) for etype in graph.etypes
        })
        # transformation matrix for relation representation
        self.relation_transformation_weight = nn.ParameterDict({
            etype: nn.Parameter(torch.randn(n_heads, relation_hidden_dim, hidden_dim)) for etype in graph.etypes
        })

        # different relations fusing module
        self.relation_fusing = RelationFusing(node_hidden_dim=hidden_dim,
                                              relation_hidden_dim=relation_hidden_dim,
                                              num_heads=n_heads,
                                              dropout=dropout,
                                              negative_slope=negative_slope)
        self.classifier = nn.Linear(self.hidden_dim * self.n_heads, self.out_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters (Xavier normal with ReLU gain)."""
        gain = nn.init.calculate_gain('relu')
        for etype in self.relation_embedding:
            nn.init.xavier_normal_(self.relation_embedding[etype], gain=gain)
        for ntype in self.projection_layer:
            nn.init.xavier_normal_(self.projection_layer[ntype].weight, gain=gain)
        for etype in self.node_transformation_weight:
            nn.init.xavier_normal_(self.node_transformation_weight[etype], gain=gain)
        for etype in self.relation_transformation_weight:
            nn.init.xavier_normal_(self.relation_transformation_weight[etype], gain=gain)

    def _default_relation_embedding(self):
        """Return {etype: flattened trainable relation embedding}, used when no relation embedding is given."""
        return {etype: self.relation_embedding[etype].flatten() for etype in self.relation_embedding}

    def _collect_fusion_inputs(self, dsttype, relation_target_node_features: dict, relation_embedding: dict):
        """Gather the per-relation inputs of ``RelationFusing`` for destination node type ``dsttype``.

        Returns four lists aligned by relation: node features, relation embeddings, node transformation
        weights and relation transformation weights.

        Raises
        ------
        KeyError
            if no relation in ``relation_target_node_features`` ends at ``dsttype``.
        """
        relations = [key for key in relation_target_node_features if key[2] == dsttype]
        if not relations:
            raise KeyError(f"no relation targets node type {dsttype!r}")
        etypes = [etype for _, etype, _ in relations]
        return ([relation_target_node_features[key] for key in relations],
                [relation_embedding[etype] for etype in etypes],
                [self.node_transformation_weight[etype] for etype in etypes],
                [self.relation_transformation_weight[etype] for etype in etypes])

    def forward(self, blocks: list, relation_target_node_features=None, relation_embedding: dict = None):
        r"""
        Parameters
        ----------
        blocks: list
            list of sampled dgl.DGLHeteroGraph, consumed one per layer
        relation_target_node_features: dict
            NOTE: currently ignored. Input features are always rebuilt from
            ``blocks[0].srcnodes[dsttype].data['h']`` for every relation (srctype, etype, dsttype).
            The argument is kept for interface compatibility.
        relation_embedding: dict
            embedding for each relation, e.g {etype: feature} or None to use the trainable embeddings

        Returns
        -------
        dict
            {category: classifier logits of shape (num_category_nodes, out_dim)}
        """
        # per-relation input features: the dsttype nodes on the block's source side, projected to
        # n_heads * hidden_dim by the projection layer of dsttype
        relation_target_node_features = {
            (stype, etype, dtype): self.projection_layer[dtype](
                blocks[0].srcnodes[dtype].data.get('h').to(torch.float32))
            for stype, etype, dtype in blocks[0].canonical_etypes
        }

        # each relation is associated with a specific type; if no semantic information is given,
        # the trainable hidden representation of each relation is used
        if relation_embedding is None:
            relation_embedding = self._default_relation_embedding()

        # graph convolution
        for block, layer in zip(blocks, self.layers):
            relation_target_node_features, relation_embedding = layer(block, relation_target_node_features,
                                                                      relation_embedding)

        # only the category node type is classified, so only its relation-aware representations are fused
        # Tensor, shape (nodes, n_heads * hidden_dim)
        fused = self.relation_fusing(*self._collect_fusion_inputs(self.category, relation_target_node_features,
                                                                  relation_embedding))
        return {self.category: self.classifier(fused)}

    def inference(self, graph: dgl.DGLHeteroGraph, relation_target_node_features: dict,
                  relation_embedding: dict = None, device: str = 'cuda:0'):
        r"""
        mini-batch inference of final representation over all node types.
        Outer loop: iterate the layers; inner loop: iterate the batches.

        Parameters
        ----------
        graph: dgl.DGLHeteroGraph
            The whole relational graph
        relation_target_node_features: dict
            target node features under each relation, e.g {(srctype, etype, dsttype): features}
        relation_embedding: dict
            embedding for each relation, e.g {etype: feature} or None
        device: str
            device the layers are run on

        Returns
        -------
        tuple
            (relation_fusion_embedding_dict {ntype: (nodes, n_heads * hidden_dim)},
             relation_target_node_features {(srctype, etype, dsttype): (nodes, n_heads * hidden_dim)})
        """
        with torch.no_grad():
            if relation_embedding is None:
                relation_embedding = self._default_relation_embedding()

            # full neighbourhood, one hop per layer
            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            for layer_index, layer in enumerate(self.layers):
                # features of all relations of the target nodes, stored on cpu
                y = {(stype, etype, dtype): torch.zeros(graph.number_of_nodes(dtype), self.hidden_dim * self.n_heads)
                     for stype, etype, dtype in graph.canonical_etypes}
                dataloader = dgl.dataloading.NodeDataLoader(
                    graph, {ntype: torch.arange(graph.number_of_nodes(ntype)) for ntype in graph.ntypes}, sampler,
                    batch_size=1280, shuffle=False, drop_last=False, num_workers=4)
                # `tqdm` is imported as a module, so the progress-bar class is `tqdm.tqdm`
                tqdm_dataloader = tqdm.tqdm(dataloader, ncols=120)
                # keeps the previous relation embedding if the dataloader yields nothing
                updated_relation_embedding = relation_embedding
                for batch, (input_nodes, output_nodes, blocks) in enumerate(tqdm_dataloader):
                    block = blocks[0].to(device)
                    # relational graphs with a single node type yield tensors instead of dicts
                    if len(set(blocks[0].ntypes)) == 1:
                        input_nodes = {blocks[0].ntypes[0]: input_nodes}
                        output_nodes = {blocks[0].ntypes[0]: output_nodes}
                    input_features = {
                        (stype, etype, dtype): features[input_nodes[dtype]].to(device)
                        for (stype, etype, dtype), features in relation_target_node_features.items()
                    }
                    if layer_index == 0:
                        # target relation feature projection for the first layer
                        input_features = {key: self.projection_layer[key[2]](features)
                                          for key, features in input_features.items()}
                    h, updated_relation_embedding = layer(block, input_features, relation_embedding)
                    for (stype, reltype, dtype), features in h.items():
                        y[(stype, reltype, dtype)][output_nodes[dtype]] = features.cpu()
                    tqdm_dataloader.set_description(
                        f'inference for the {batch}-th batch in model {layer_index}-th layer')
                # update the features of all nodes (after the graph convolution) in the whole graph
                relation_target_node_features = y
                # relation embedding is updated after each layer
                relation_embedding = updated_relation_embedding

            # new dict so the caller's input is never mutated
            relation_target_node_features = {key: features.to(device)
                                             for key, features in relation_target_node_features.items()}

            relation_fusion_embedding_dict = {}
            fusion_batch_size = 2560  # mini-batch the fusion to avoid running out of memory
            for dsttype in {dtype for _, _, dtype in relation_target_node_features}:
                (dst_node_features, dst_relation_embeddings, dst_node_feature_transformation_weight,
                 dst_relation_embedding_transformation_weight) = self._collect_fusion_inputs(
                    dsttype, relation_target_node_features, relation_embedding)
                num_nodes = dst_node_features[0].shape[0]
                chunks = [
                    self.relation_fusing([features[start: start + fusion_batch_size, :]
                                          for features in dst_node_features],
                                         dst_relation_embeddings,
                                         dst_node_feature_transformation_weight,
                                         dst_relation_embedding_transformation_weight)
                    for start in range(0, num_nodes, fusion_batch_size)
                ]
                relation_fusion_embedding_dict[dsttype] = torch.cat(chunks, dim=0)

            return relation_fusion_embedding_dict, relation_target_node_features
# heteroConv
class HeteroGraphConv(nn.Module):
    r"""
    A generic module for computing convolution on heterogeneous graphs.

    The heterograph convolution applies sub-modules on their associated relation graphs. Each sub-module
    reads the features from source nodes and writes the updated ones to destination nodes.
    Results are NOT aggregated across relations: one output is kept per relation.
    If a relation graph has no edge, its module is not called and the relation is absent from the outputs.

    Parameters
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge type.
    """

    def __init__(self, mods: dict):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)

    def forward(self, graph: dgl.DGLHeteroGraph, input_src: dict, input_dst: dict, relation_embedding: dict,
                node_transformation_weight: nn.ParameterDict, relation_transformation_weight: nn.ParameterDict):
        r"""
        call the forward function with each module.

        Parameters
        ----------
        graph: DGLHeteroGraph
            The Heterogeneous Graph.
        input_src: dict[tuple, Tensor]
            Input source node features {relation_type: features, }
        input_dst: dict[tuple, Tensor]
            Input destination node features {relation_type: features, }
        relation_embedding: dict[etype, Tensor]
            Input relation features {etype: feature}
        node_transformation_weight: nn.ParameterDict
            weights {ntype: (input_dim, n_heads * hidden_dim)}
        relation_transformation_weight: nn.ParameterDict
            weights {etype: (relation_input_dim, n_heads * 2 * hidden_dim)}

        Returns
        -------
        outputs: dict[tuple, Tensor]
            Output representations for every relation with edges -> {(stype, etype, dtype): features}.

        Raises
        ------
        KeyError
            if a relation with edges has no reverse relation in ``input_src``.
        """
        # reverse relation of (srctype, reltype, dsttype): the first (dsttype, etype, srctype) with etype != reltype
        reverse_relation_dict = {}
        for srctype, reltype, dsttype in input_src:
            for stype, etype, dtype in input_src:
                if stype == dsttype and dtype == srctype and etype != reltype:
                    reverse_relation_dict[reltype] = etype
                    break

        # dictionary, {(srctype, etype, dsttype): representations}
        outputs = dict()
        for stype, etype, dtype in graph.canonical_etypes:
            rel_graph = graph[stype, etype, dtype]
            if rel_graph.number_of_edges() == 0:
                continue
            if etype not in reverse_relation_dict:
                raise KeyError(f"no reverse relation found for relation {(stype, etype, dtype)}; "
                               f"R-HGNN requires every relation to have a reverse relation")
            # e.g. for (author, writes, paper): the author features are kept under the reverse relation
            # (paper, written_by, author), the paper features under (author, writes, paper)
            src_features = input_src[(dtype, reverse_relation_dict[etype], stype)]
            dst_features = input_dst[(stype, etype, dtype)]
            # dst_representation (dst_nodes, n_heads * hidden_dim)
            outputs[(stype, etype, dtype)] = self.mods[etype](rel_graph,
                                                              (src_features, dst_features),
                                                              node_transformation_weight[dtype],
                                                              node_transformation_weight[stype],
                                                              relation_embedding[etype],
                                                              relation_transformation_weight[etype])
        return outputs


# relation crossing
class RelationCrossing(nn.Module):
    """Attention over the representations of all relations ending at the same node type."""

    def __init__(self, in_feats: int, out_feats: int, num_heads: int, dropout: float = 0.0,
                 negative_slope: float = 0.2):
        r"""
        Relation crossing layer

        Parameters
        ----------
        in_feats : int
            input feature size
        out_feats : int
            output feature size of each head
        num_heads : int
            number of heads in Multi-Head Attention
        dropout : float
            optional, dropout rate, defaults: 0.0
        negative_slope : float
            optional, negative slope rate, defaults: 0.2
        """
        super(RelationCrossing, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, dsttype_node_features: torch.Tensor, relations_crossing_attention_weight: nn.Parameter):
        r"""
        Parameters
        ----------
        dsttype_node_features:
            a tensor of (dsttype_node_relations_num, num_dst_nodes, n_heads * hidden_dim)
        relations_crossing_attention_weight: Parameter
            the shape is (n_heads, hidden_dim)

        Returns
        -------
        output_features: Tensor
            shape (num_dst_nodes, n_heads * hidden_dim); with a single relation the input is returned unchanged
        """
        if len(dsttype_node_features) == 1:
            # (num_dst_nodes, n_heads * hidden_dim)
            return dsttype_node_features.squeeze(dim=0)

        num_relations = dsttype_node_features.shape[0]
        # (dsttype_node_relations_num, num_dst_nodes, n_heads, hidden_dim)
        per_head = dsttype_node_features.reshape(num_relations, -1, self._num_heads, self._out_feats)
        # (dsttype_node_relations_num, num_dst_nodes, n_heads, 1)
        scores = (per_head * relations_crossing_attention_weight).sum(dim=-1, keepdim=True)
        # normalise over the relations
        attention = F.softmax(self.leaky_relu(scores), dim=0)
        # (num_dst_nodes, n_heads, hidden_dim)
        crossed = self.dropout((per_head * attention).sum(dim=0))
        # (num_dst_nodes, n_heads * hidden_dim)
        return crossed.reshape(-1, self._num_heads * self._out_feats)


# relation fusing
class RelationFusing(nn.Module):
    """Attention-based fusion of the relation-aware representations of one node type."""

    def __init__(self, node_hidden_dim: int, relation_hidden_dim: int, num_heads: int, dropout: float = 0.0,
                 negative_slope: float = 0.2):
        r"""
        Parameters
        ----------
        node_hidden_dim: int
            node hidden feature size
        relation_hidden_dim: int
            relation hidden feature size
        num_heads: int
            number of heads in Multi-Head Attention
        dropout: float
            dropout rate, defaults: 0.0
        negative_slope: float
            negative slope, defaults: 0.2
        """
        super(RelationFusing, self).__init__()
        self.node_hidden_dim = node_hidden_dim
        self.relation_hidden_dim = relation_hidden_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, dst_node_features: list, dst_relation_embeddings: list,
                dst_node_feature_transformation_weight: list,
                dst_relation_embedding_transformation_weight: list):
        r"""
        Parameters
        ----------
        dst_node_features: list
            e.g [each shape is (num_dst_nodes, n_heads * node_hidden_dim)]
        dst_relation_embeddings: list
            e.g [each shape is (n_heads * relation_hidden_dim)]
        dst_node_feature_transformation_weight: list
            e.g [each shape is (n_heads, node_hidden_dim, node_hidden_dim)]
        dst_relation_embedding_transformation_weight: list
            e.g [each shape is (n_heads, relation_hidden_dim, node_hidden_dim)]

        Returns
        -------
        dst_node_relation_fusion_feature: Tensor
            shape (num_dst_nodes, n_heads * node_hidden_dim), the target node representation after
            relation-aware representations fusion; with a single relation the input is returned unchanged
        """
        if len(dst_node_features) == 1:
            # (num_dst_nodes, n_heads * hidden_dim)
            return dst_node_features[0]

        num_relations = len(dst_node_features)
        heads, node_dim, rel_dim = self.num_heads, self.node_hidden_dim, self.relation_hidden_dim
        # (num_dst_relations, nodes, n_heads, node_hidden_dim)
        node_features = torch.stack(dst_node_features, dim=0).reshape(num_relations, -1, heads, node_dim)
        # (num_dst_relations, n_heads, relation_hidden_dim)
        relation_embeddings = torch.stack(dst_relation_embeddings, dim=0).reshape(num_relations, heads, rel_dim)
        # (num_dst_relations, n_heads, node_hidden_dim, node_hidden_dim)
        node_weights = torch.stack(dst_node_feature_transformation_weight, dim=0).reshape(
            num_relations, heads, node_dim, node_dim)
        # (num_dst_relations, n_heads, relation_hidden_dim, node_hidden_dim)
        relation_weights = torch.stack(dst_relation_embedding_transformation_weight, dim=0).reshape(
            num_relations, heads, rel_dim, node_dim)

        # (num_dst_relations, nodes, n_heads, node_hidden_dim)
        node_features = torch.einsum('abcd,acde->abce', node_features, node_weights)
        # (num_dst_relations, n_heads, node_hidden_dim)
        relation_embeddings = torch.einsum('abc,abcd->abd', relation_embeddings, relation_weights)

        # (num_dst_relations, nodes, n_heads, 1), normalised over the relations
        attention_scores = (node_features * relation_embeddings.unsqueeze(dim=1)).sum(dim=-1, keepdim=True)
        attention_scores = F.softmax(self.leaky_relu(attention_scores), dim=0)
        # (nodes, n_heads, node_hidden_dim)
        fused = self.dropout((node_features * attention_scores).sum(dim=0))
        # (nodes, n_heads * node_hidden_dim)
        return fused.reshape(-1, heads * node_dim)


# relationGraphConv
class RelationGraphConv(nn.Module):
    """GAT-style convolution over a single relation, with attention vectors derived from the relation embedding."""

    def __init__(self, in_feats: tuple, out_feats: int, num_heads: int, dropout: float = 0.0,
                 negative_slope: float = 0.2):
        r"""
        Relation graph convolution layer

        Parameters
        ----------
        in_feats : pair of ints
            input feature size
        out_feats : int
            output feature size of each head
        num_heads : int
            number of heads in Multi-Head Attention
        dropout : float
            optional, dropout rate, defaults: 0
        negative_slope : float
            optional, negative slope rate, defaults: 0.2
        """
        super(RelationGraphConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = in_feats[0], in_feats[1]
        self._out_feats = out_feats
        self._num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.relu = nn.ReLU()

    def forward(self, graph: dgl.DGLHeteroGraph, feat: tuple, dst_node_transformation_weight: nn.Parameter,
                src_node_transformation_weight: nn.Parameter, relation_embedding: torch.Tensor,
                relation_transformation_weight: nn.Parameter):
        r"""
        Parameters
        ----------
        graph : specific relational DGLHeteroGraph
        feat : pair of torch.Tensor
            tensors of shape (N_src, input_src_dim) and (N_dst, input_dst_dim)
        dst_node_transformation_weight:
            Parameter (input_dst_dim, n_heads * hidden_dim)
        src_node_transformation_weight:
            Parameter (input_src_dim, n_heads * hidden_dim)
        relation_embedding: torch.Tensor
            (relation_input_dim)
        relation_transformation_weight:
            Parameter (relation_input_dim, n_heads * 2 * hidden_dim)

        Returns
        -------
        dst_features: torch.Tensor
            shape (N_dst, n_heads * hidden_dim), ReLU-activated
        """
        with graph.local_scope():
            # (N_src, input_src_dim), (N_dst, input_dst_dim)
            feat_src = self.dropout(feat[0])
            feat_dst = self.dropout(feat[1])
            # (N_src, n_heads, hidden_dim)
            feat_src = torch.matmul(feat_src, src_node_transformation_weight).view(-1, self._num_heads,
                                                                                   self._out_feats)
            # (N_dst, n_heads, hidden_dim)
            feat_dst = torch.matmul(feat_dst, dst_node_transformation_weight).view(-1, self._num_heads,
                                                                                   self._out_feats)
            # (n_heads, 2 * hidden_dim), relation-specific attention vector [a_dst || a_src]
            relation_attention_weight = torch.matmul(relation_embedding.unsqueeze(dim=0),
                                                     relation_transformation_weight).view(self._num_heads,
                                                                                          2 * self._out_feats)
            # a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j: score each side separately, then add per edge
            # (N_dst, n_heads, 1)
            e_dst = (feat_dst * relation_attention_weight[:, :self._out_feats]).sum(dim=-1, keepdim=True)
            # (N_src, n_heads, 1)
            e_src = (feat_src * relation_attention_weight[:, self._out_feats:]).sum(dim=-1, keepdim=True)

            graph.srcdata.update({'ft': feat_src, 'e_src': e_src})
            graph.dstdata.update({'e_dst': e_dst})
            # (edges_num, n_heads, 1)
            graph.apply_edges(fn.u_add_v('e_src', 'e_dst', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            # normalise the scores over the incoming edges of each destination node
            graph.edata['a'] = edge_softmax(graph, e)
            graph.update_all(fn.u_mul_e('ft', 'a', 'msg'), fn.sum('msg', 'feat'))
            # (N_dst, n_heads, hidden_dim) -> (N_dst, n_heads * hidden_dim)
            dst_features = graph.dstdata.pop('feat').reshape(-1, self._num_heads * self._out_feats)
            return self.relu(dst_features)


class R_HGNN_Layer(nn.Module):
    """One R-HGNN layer: per-relation convolution, residual, relation crossing, layer norm, relation update."""

    def __init__(self, graph, input_dim: int, hidden_dim: int, relation_input_dim: int,
                 relation_hidden_dim: int, n_heads: int = 8, dropout: float = 0.2, negative_slope: float = 0.2,
                 residual: bool = True, norm: bool = False):
        """
        Parameters
        ----------
        graph:
            a heterogeneous graph
        input_dim: int
            node input dimension
        hidden_dim: int
            node hidden dimension of each head
        relation_input_dim: int
            relation input dimension
        relation_hidden_dim: int
            relation hidden dimension of each head
        n_heads: int
            number of attention heads
        dropout: float
            dropout rate
        negative_slope: float
            negative slope
        residual: boolean
            residual connections or not
        norm: boolean
            layer normalization or not
        """
        super(R_HGNN_Layer, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.relation_input_dim = relation_input_dim
        self.relation_hidden_dim = relation_hidden_dim
        self.n_heads = n_heads
        self.dropout = dropout
        self.negative_slope = negative_slope
        self.residual = residual
        self.norm = norm

        # node transformation parameters of each type
        self.node_transformation_weight = nn.ParameterDict({
            ntype: nn.Parameter(torch.randn(input_dim, n_heads * hidden_dim)) for ntype in graph.ntypes
        })
        # relation transformation parameters of each type, used as attention queries
        self.relation_transformation_weight = nn.ParameterDict({
            etype: nn.Parameter(torch.randn(relation_input_dim, n_heads * 2 * hidden_dim)) for etype in graph.etypes
        })
        # relation propagation layer of each relation
        self.relation_propagation_layer = nn.ModuleDict({
            etype: nn.Linear(relation_input_dim, n_heads * relation_hidden_dim) for etype in graph.etypes
        })
        # hetero conv modules, each RelationGraphConv deals with a single type of relation
        self.hetero_conv = HeteroGraphConv({
            etype: RelationGraphConv(in_feats=(input_dim, input_dim), out_feats=hidden_dim,
                                     num_heads=n_heads, dropout=dropout, negative_slope=negative_slope)
            for etype in graph.etypes
        })

        if self.residual:
            # residual connection with a learnable gate per node type
            self.res_fc = nn.ModuleDict()
            self.residual_weight = nn.ParameterDict()
            for ntype in graph.ntypes:
                self.res_fc[ntype] = nn.Linear(input_dim, n_heads * hidden_dim)
                self.residual_weight[ntype] = nn.Parameter(torch.randn(1))

        if self.norm:
            self.layer_norm = nn.ModuleDict({ntype: nn.LayerNorm(n_heads * hidden_dim) for ntype in graph.ntypes})

        # relation type crossing attention trainable parameters
        self.relations_crossing_attention_weight = nn.ParameterDict({
            etype: nn.Parameter(torch.randn(n_heads, hidden_dim)) for etype in graph.etypes
        })
        # different relations crossing layer
        self.relations_crossing_layer = RelationCrossing(in_feats=n_heads * hidden_dim, out_feats=hidden_dim,
                                                         num_heads=n_heads, dropout=dropout,
                                                         negative_slope=negative_slope)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        for weight in self.node_transformation_weight:
            nn.init.xavier_normal_(self.node_transformation_weight[weight], gain=gain)
        for weight in self.relation_transformation_weight:
            nn.init.xavier_normal_(self.relation_transformation_weight[weight], gain=gain)
        for etype in self.relation_propagation_layer:
            nn.init.xavier_normal_(self.relation_propagation_layer[etype].weight, gain=gain)
        if self.residual:
            for ntype in self.res_fc:
                nn.init.xavier_normal_(self.res_fc[ntype].weight, gain=gain)
        for weight in self.relations_crossing_attention_weight:
            nn.init.xavier_normal_(self.relations_crossing_attention_weight[weight], gain=gain)

    def forward(self, graph: dgl.DGLHeteroGraph, relation_target_node_features: dict, relation_embedding: dict):
        """
        :param graph: dgl.DGLHeteroGraph
        :param relation_target_node_features: dict, {relation_type: target_node_features shape (N_nodes, input_dim)},
               each value represents the target node features under that relation
        :param relation_embedding: embedding for each relation, dict, {etype: feature}
        :return: output_features: dict, {relation_type: target_node_features},
                 relation_embedding_dict: dict, {etype: updated relation embedding}
        """
        # in each relation, the target type of nodes has an embedding
        input_src = relation_target_node_features
        if graph.is_block:
            # destination nodes of a block are the first number_of_dst_nodes source nodes of that type
            input_dst = {
                (srctype, etype, dsttype): features[:graph.number_of_dst_nodes(dsttype)]
                for (srctype, etype, dsttype), features in relation_target_node_features.items()
            }
        else:
            input_dst = relation_target_node_features

        # {(srctype, etype, dsttype): target_node_features}
        output_features = self.hetero_conv(graph, input_src, input_dst, relation_embedding,
                                           self.node_transformation_weight, self.relation_transformation_weight)

        # gated residual connection for the target node
        if self.residual:
            for srctype, etype, dsttype in output_features:
                alpha = torch.sigmoid(self.residual_weight[dsttype])
                key = (srctype, etype, dsttype)
                output_features[key] = output_features[key] * alpha + \
                    self.res_fc[dsttype](input_dst[key]) * (1 - alpha)

        # stack the outputs of all relations ending at each destination type once,
        # (dsttype_node_relations_num, dst_nodes_num, n_heads * hidden_dim)
        stacked_by_dsttype = {
            dsttype: torch.stack([features for key, features in output_features.items() if key[2] == dsttype], dim=0)
            for dsttype in {dtype for _, _, dtype in output_features}
        }
        # different relations crossing layer, with relation-specific attention weights
        output_features_dict = {
            (srctype, etype, dsttype): self.relations_crossing_layer(stacked_by_dsttype[dsttype],
                                                                     self.relations_crossing_attention_weight[etype])
            for srctype, etype, dsttype in output_features
        }

        # layer norm for the output
        if self.norm:
            for srctype, etype, dsttype in output_features_dict:
                key = (srctype, etype, dsttype)
                output_features_dict[key] = self.layer_norm[dsttype](output_features_dict[key])

        # relation embeddings after relation update, {etype: relation_embedding}
        relation_embedding_dict = {etype: self.relation_propagation_layer[etype](relation_embedding[etype])
                                   for etype in relation_embedding}

        return output_features_dict, relation_embedding_dict