Source code for openhgnn.models.HetGNN

import dgl
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from . import BaseModel, register_model


@register_model('HetGNN')
class HetGNN(BaseModel):
    r"""
    HetGNN[KDD2019]-
    `Heterogeneous Graph Neural Network <https://dl.acm.org/doi/abs/10.1145/3292500.3330961>`_
    `Source Code Link <https://github.com/chuxuzhang/KDD2019_HetGNN>`_

    The author of the paper only gives the academic dataset.

    Attributes
    -----------
    Het_Aggregate : nn.Module
        Het_Aggregate
    """

    # Content feature keys read for every node type, and the extra keys read
    # only for the 'paper' node type (see ``extract_feature``).
    _COMMON_FEATURES = ('dw_embedding', 'abstract')
    _PAPER_FEATURES = ('title', 'venue', 'author', 'reference')

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Alternate constructor taking ``(args, hg)`` instead of ``(hg, args)``."""
        return cls(hg, args)

    def __init__(self, hg, args):
        """
        Parameters
        ----------
        hg :
            Graph object; only ``hg.ntypes`` is read here.
        args :
            Configuration object; ``args.dim`` and ``args.device`` are read.
        """
        super(HetGNN, self).__init__()
        node_types = hg.ntypes
        self.Het_Aggregate = Het_Aggregate(node_types, args.dim)
        self.ntypes = node_types
        self.device = args.device
        self.loss_fn = HetGNN.compute_loss

    def forward(self, hg, h=None):
        """
        Run ``Het_Aggregate`` on ``hg``.

        When ``h`` is None the input features are pulled from ``hg`` with
        ``extract_feature``; otherwise ``h`` is forwarded unchanged.
        """
        features = self.extract_feature(hg, self.ntypes) if h is None else h
        return self.Het_Aggregate(hg, features)

    # NOTE(review): the four methods below read ``self.model``, ``self.hg``,
    # ``self.args``, ``self.category``, ``self.train_batch``,
    # ``self.test_batch``, ``self.labels``, ``self.train_idx`` and
    # ``self.test_idx``. None of these is assigned in ``__init__`` above;
    # presumably they are expected to be provided elsewhere -- confirm before
    # relying on these methods.
    def evaluator(self):
        """Run link prediction and then node classification evaluation."""
        self.link_preddiction()
        self.node_classification()

    def get_embedding(self):
        """Return ``self.model`` applied to the preprocessed ``self.hg``."""
        model = self.model
        input_features = model.extract_feature(self.hg, self.hg.ntypes)
        graph = model.preprocess(self.hg, self.args).to(self.args.device)
        return model(graph, input_features)

    def link_preddiction(self):
        """Pass the detached CPU embedding of ``self.category`` to ``lp_evaluator``."""
        embedding = self.get_embedding()[self.category].to('cpu').detach()
        self.model.lp_evaluator(embedding, self.train_batch, self.test_batch)

    def node_classification(self):
        """Pass the detached CPU embedding of ``self.category`` to ``nc_evaluator``."""
        embedding = self.get_embedding()[self.category].to('cpu').detach()
        self.model.nc_evaluator(embedding, self.labels, self.train_idx, self.test_idx)

    @staticmethod
    def compute_loss(pos_score, neg_score):
        """
        Log-sigmoid loss over positive and negative scores.

        For every key of ``pos_score`` the terms ``logsigmoid(pos_score[key])``
        and ``logsigmoid(-neg_score[key])`` are collected; the result is the
        negated mean of all collected values.
        """
        terms = []
        for key in pos_score:
            terms.extend((F.logsigmoid(pos_score[key]),
                          F.logsigmoid(-neg_score[key])))
        return -th.cat(terms).mean()

    @staticmethod
    def extract_feature(g, ntypes):
        """
        Collect the content features of the source nodes of ``g``.

        Returns a dict keyed by node type; each value is a dict holding
        'dw_embedding' and 'abstract', plus 'title', 'venue', 'author' and
        'reference' for the 'paper' node type. Insertion order follows that
        listing.
        """
        input_features = {}
        for ntype in ntypes:
            node_data = g.srcnodes[ntype].data
            keys = HetGNN._COMMON_FEATURES
            if ntype == 'paper':
                keys = keys + HetGNN._PAPER_FEATURES
            input_features[ntype] = {key: node_data[key] for key in keys}
        return input_features

    @staticmethod
    def pred(edge_subgraph, x):
        """
        Score every edge of ``edge_subgraph`` with the dot product of the
        features ``x`` of its two end nodes and return ``edata['score']``.
        """
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            dot_product = fn.u_dot_v('x', 'x', 'score')
            for canonical_etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(dot_product, etype=canonical_etype)
            return edge_subgraph.edata['score']
class ScorePredictor(nn.Module):
    """Edge scorer: dot product of the end-node features of every edge."""

    def forward(self, edge_subgraph, x):
        """
        Store ``x`` as node data ``'x'``, compute ``u_dot_v`` for every
        canonical edge type and return ``edge_subgraph.edata['score']``.
        """
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']


class Het_Aggregate(nn.Module):
    r"""
    The whole model of HetGNN

    Attributes
    -----------
    content_rnn : nn.Module
        het_content_encoder
    neigh_rnn : nn.Module
        aggregate_het_neigh
    atten_w : nn.ModuleDict[str, nn.Module]
        One ``nn.Linear(dim * 2, 1)`` per node type, used to score each
        (destination content, candidate embedding) pair.
    """

    def __init__(self, ntypes, dim):
        super(Het_Aggregate, self).__init__()
        # ntypes means nodes type name
        self.ntypes = ntypes
        self.dim = dim

        self.content_rnn = het_content_encoder(dim)
        self.neigh_rnn = aggregate_het_neigh(ntypes, dim)

        self.atten_w = nn.ModuleDict({})
        for n in self.ntypes:
            self.atten_w[n] = nn.Linear(in_features=dim * 2, out_features=1)

        self.softmax = nn.Softmax(dim=1)
        self.activation = nn.LeakyReLU()
        # ``drop`` and ``bn`` are not used in ``forward``; they are kept so the
        # module's attributes and state_dict keys stay unchanged.
        self.drop = nn.Dropout(p=0.5)
        self.bn = nn.BatchNorm1d(dim)
        self.embed_d = dim

    def forward(self, hg, h_dict):
        """
        Parameters
        ----------
        hg :
            Graph passed on to ``aggregate_het_neigh``.
        h_dict : dict[str, dict[str, th.Tensor]]
            Per node type, the content features fed to ``het_content_encoder``.

        Returns
        -------
        dict[str, th.Tensor]
            Per node type in ``self.ntypes``, a ``(batch_size, dim)`` tensor:
            the attention-weighted sum of the aggregated neighbour embeddings
            and the node's own content embedding.
        """
        with hg.local_scope():
            content_h = {ntype: self.content_rnn(h) for ntype, h in h_dict.items()}
            neigh_h = self.neigh_rnn(hg, content_h)
            # the content feature of the dst nodes
            dst_h = {k: v[:hg.number_of_dst_nodes(k)] for k, v in content_h.items()}

            out_h = {}
            for n in self.ntypes:
                d_h = dst_h[n]
                batch_size = d_h.shape[0]

                # Candidates attended over: every aggregated neighbour
                # embedding followed by the node's own content embedding.
                concat_emd = list(neigh_h[n]) + [d_h]
                # The slot count is taken from what was actually gathered
                # rather than assumed to be a fixed number.
                num_slots = len(concat_emd)

                concat_h = [th.cat((d_h, emd), 1) for emd in concat_emd]
                concat_h = th.hstack(concat_h).view(batch_size * num_slots, self.dim * 2)

                atten_w = self.activation(self.atten_w[n](concat_h)).view(batch_size, num_slots)
                atten_w = self.softmax(atten_w).view(batch_size, 1, num_slots)

                # weighted combination
                concat_emd = th.hstack(concat_emd).view(batch_size, num_slots, self.dim)
                out_h[n] = th.bmm(atten_w, concat_emd).view(batch_size, self.dim)
            return out_h


class het_content_encoder(nn.Module):
    r"""
    The Encoding Heterogeneous Contents(C2) in the paper
    For a specific node type, encoder different content features with a LSTM.

    In paper, it is (b) NN-1: node heterogeneous contents encoder in figure 2.

    Parameters
    ------------
    dim : int
        input dimension

    Attributes
    ------------
    content_rnn : nn.Module
        nn.LSTM encode different content feature
    """

    def __init__(self, dim):
        super(het_content_encoder, self).__init__()
        self.content_rnn = nn.LSTM(dim, dim // 2, 1, batch_first=True, bidirectional=True)
        self.content_rnn.flatten_parameters()
        self.dim = dim

    def forward(self, h_dict):
        r"""
        Parameters
        ----------
        h_dict: dict[str, th.Tensor]
            key means different content feature

        Returns
        -------
        content_h : th.tensor
            Mean of the LSTM outputs over the content sequence. The batch
            dimension is always kept, including for a batch of one node.
        """
        concate_embed = th.cat(list(h_dict.values()), 1)
        # Re-cut the concatenated features into a sequence of ``dim``-sized steps.
        concate_embed = concate_embed.view(concate_embed.shape[0], -1, self.dim)
        all_state, _ = self.content_rnn(concate_embed)
        # No trailing ``squeeze()``: the mean already removes the sequence
        # axis, and squeezing would also drop the batch axis when it is 1.
        return th.mean(all_state, 1)


class aggregate_het_neigh(nn.Module):
    r"""
    It is a Aggregating Heterogeneous Neighbors(C3)
    Same Type Neighbors Aggregation

    Holds one ``lstm_aggr`` per node type, selected by the source node type of
    each relation.
    """

    def __init__(self, ntypes, dim):
        super(aggregate_het_neigh, self).__init__()
        self.neigh_rnn = nn.ModuleDict({})
        self.ntypes = ntypes
        for n in ntypes:
            self.neigh_rnn[n] = lstm_aggr(dim)

    def forward(self, hg, inputs):
        """
        Parameters
        ----------
        hg :
            Graph whose canonical edge types are iterated.
        inputs : dict[str, th.Tensor] or tuple of two such dicts
            Node features per node type; a tuple is read as
            ``(src_inputs, dst_inputs)``.

        Returns
        -------
        dict[str, list[th.Tensor]]
            Per node type in ``self.ntypes``, the aggregated neighbour
            embeddings, one per relation that has edges and available inputs.
        """
        with hg.local_scope():
            outputs = {ntype: [] for ntype in self.ntypes}

            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            elif hg.is_block:
                src_inputs = inputs
                dst_inputs = {k: v[:hg.number_of_dst_nodes(k)] for k, v in inputs.items()}
            else:
                # Plain graph: only source features are handed to the aggregator.
                src_inputs, dst_inputs = inputs, None

            for stype, etype, dtype in hg.canonical_etypes:
                rel_graph = hg[stype, etype, dtype]
                if rel_graph.number_of_edges() == 0:
                    continue
                if stype not in src_inputs:
                    continue
                if dst_inputs is None:
                    rel_inputs = src_inputs[stype]
                elif dtype in dst_inputs:
                    rel_inputs = (src_inputs[stype], dst_inputs[dtype])
                else:
                    continue
                outputs[dtype].append(self.neigh_rnn[stype](rel_graph, rel_inputs))
            return outputs


class lstm_aggr(nn.Module):
    r"""
    Aggregate the same neighbors with LSTM
    """

    def __init__(self, dim):
        super(lstm_aggr, self).__init__()
        self.lstm = nn.LSTM(dim, dim // 2, 1, batch_first=True, bidirectional=True)
        self.lstm.flatten_parameters()

    def _lstm_reducer(self, nodes):
        """Reduce function: run the LSTM over the mailbox and average its outputs."""
        m = nodes.mailbox['m']  # (B, L, D)
        all_state, _ = self.lstm(m)
        return {'neigh': th.mean(all_state, 1)}

    def forward(self, g, inputs):
        """
        Parameters
        ----------
        g :
            Graph of a single relation.
        inputs : th.Tensor or tuple
            Source node features, or a ``(src_inputs, dst_inputs)`` pair of
            which only the source part is used.

        Returns
        -------
        The ``'neigh'`` destination-node data written by ``_lstm_reducer``.
        """
        with g.local_scope():
            if isinstance(inputs, tuple):
                src_inputs, _ = inputs
            else:
                src_inputs = inputs
            g.srcdata['h'] = src_inputs
            g.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
            h_neigh = g.dstdata['neigh']
            return h_neigh


# NOTE: unfinished draft kept from the original file.
# from openhgnn.models.micro_layer.LSTM_conv import LSTMConv
# from openhgnn.models.HeteroGraphConv import HeteroGraphConv
# class HetGNNConv(nn.Module):
#     def __init__(self, graph, ntypes, dim):
#         super(HetGNNConv, self).__init__()
#         # ntypes means nodes type name
#         self.ntypes =ntypes
#         self.dim = dim
#
#         # hetero conv modules
#         self.micro_conv = HeteroGraphConv({
#             etype: LSTMConv(dim=dim)
#             for srctype, etype, dsttype in graph.canonical_etypes
#         })
#
#         # different types aggregation module
#         self.macro_conv = AttConv(in_feats=hidden_dim * n_heads, out_feats=hidden_dim,
#                                    num_heads=n_heads,
#                                    dropout=dropout, negative_slope=0.2)
#
#         self.atten_w = nn.ModuleDict({})
#         for n in self.ntypes:
#             self.atten_w[n] = nn.Linear(in_features=dim * 2, out_features=1)
#
#         self.softmax = nn.Softmax(dim=1)
#         self.activation = nn.LeakyReLU()
#         self.drop = nn.Dropout(p=0.5)
#         self.bn = nn.BatchNorm1d(dim)
#         self.embed_d = dim
#
#     def forward(self, hg, h):
#         x = self.Het_Aggrate(hg, h)
#         return x