Source code for openhgnn.models.GATNE

import dgl
from . import register_model, BaseModel
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import Parameter
import math
import dgl.function as fn


@register_model('GATNE-T')
class GATNE(BaseModel):
    """Embedding model registered under the name ``'GATNE-T'``.

    Every node owns one shared base embedding plus one extra embedding per
    edge type.  ``forward`` sums the per-edge-type embeddings of a node's
    neighbours, mixes them across edge types with an attention weighting,
    projects the mix with an edge-type specific matrix and adds the result
    to the base embedding.

    Parameters
    ----------
    num_nodes : int
        Number of rows in the embedding tables.
    embedding_size : int
        Width of the base (and of the returned) embeddings.
    embedding_u_size : int
        Width of the per-edge-type embeddings.
    edge_types : sequence
        Edge type identifiers, indexed by position in ``forward``.
    edge_type_count : int
        Number of edge types iterated over in ``forward``.
    att_dim : int
        Hidden width of the attention scorer.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Create the model from ``args`` (``dim``, ``edge_dim``, ``att_dim``)
        and the graph ``hg`` (``num_nodes()``, ``etypes``)."""
        total_nodes = hg.num_nodes()
        base_dim = args.dim
        edge_dim = args.edge_dim
        return cls(
            total_nodes, base_dim, edge_dim, hg.etypes, len(hg.etypes), args.att_dim
        )

    def __init__(
        self,
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        att_dim,
    ):
        super(GATNE, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_types = edge_types
        self.edge_type_count = edge_type_count
        self.att_dim = att_dim

        # Base embedding shared by all edge types.
        self.node_embeddings = Parameter(
            torch.FloatTensor(num_nodes, embedding_size)
        )
        # One extra embedding per (node, edge type).
        self.node_type_embeddings = Parameter(
            torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
        )
        # Per-edge-type projection from the edge space to the base space.
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        # Two-layer attention scorer, one set of weights per edge type.
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, att_dim)
        )
        self.trans_weights_s2 = Parameter(
            torch.FloatTensor(edge_type_count, att_dim, 1)
        )

        # The tensors above are allocated uninitialised; fill them now.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all parameters: embeddings uniformly in [-1, 1],
        weight matrices from a zero-mean normal with std
        ``1 / sqrt(embedding_size)``."""
        nn.init.uniform_(self.node_embeddings, -1.0, 1.0)
        nn.init.uniform_(self.node_type_embeddings, -1.0, 1.0)
        for weight in (
            self.trans_weights,
            self.trans_weights_s1,
            self.trans_weights_s2,
        ):
            nn.init.normal_(weight, std=1.0 / math.sqrt(self.embedding_size))

    @staticmethod
    def _tile_per_node(weight, copies, rows, cols):
        """Repeat ``weight`` ``copies`` times and flatten to ``[-1, rows, cols]``."""
        return weight.unsqueeze(0).repeat(copies, 1, 1, 1).view(-1, rows, cols)

    def forward(self, block):
        """Compute normalised embeddings for the destination nodes of ``block``.

        ``block`` must expose ``srcdata``/``dstdata`` with ``dgl.NID`` entries,
        ``number_of_dst_nodes()``, ``local_scope()`` and ``update_all``.

        Returns a tensor viewed as
        ``[num_dst_nodes, edge_type_count, embedding_size]`` that is
        L2-normalised along the last axis.
        """
        src_ids = block.srcdata[dgl.NID]
        dst_ids = block.dstdata[dgl.NID]
        num_dst = block.number_of_dst_nodes()
        type_count = self.edge_type_count
        u_dim = self.embedding_u_size

        # Sum the neighbours' edge-type embeddings, one edge type at a time.
        # local_scope keeps the temporary features off the caller's block.
        aggregated = []
        with block.local_scope():
            for idx in range(type_count):
                etype = self.edge_types[idx]
                block.srcdata[etype] = self.node_type_embeddings[src_ids, idx]
                block.dstdata[etype] = self.node_type_embeddings[dst_ids, idx]
                block.update_all(
                    fn.copy_u(etype, "m"), fn.sum("m", etype), etype=etype
                )
                aggregated.append(block.dstdata[etype])

        # Stack the per-type aggregates along a new axis 1.
        neigh_embed = torch.stack(aggregated, 1)
        flat_neigh = neigh_embed.unsqueeze(2).view(-1, 1, u_dim)

        # One copy of every per-edge-type weight for each destination node.
        proj_w = self._tile_per_node(
            self.trans_weights, num_dst, u_dim, self.embedding_size
        )
        score_w1 = self._tile_per_node(
            self.trans_weights_s1, num_dst, u_dim, self.att_dim
        )
        score_w2 = self._tile_per_node(
            self.trans_weights_s2, num_dst, self.att_dim, 1
        )

        # Attention over edge types, softmax-normalised per destination node.
        hidden = torch.tanh(torch.matmul(flat_neigh, score_w1))
        scores = torch.matmul(hidden, score_w2).squeeze(2).view(-1, type_count)
        attention = F.softmax(scores, dim=1).unsqueeze(1).repeat(1, type_count, 1)

        # Attention-weighted mix of the per-type aggregates.
        mixed = torch.matmul(attention, neigh_embed).view(-1, 1, u_dim)

        base_part = self.node_embeddings[dst_ids].unsqueeze(1).repeat(
            1, type_count, 1
        )
        edge_part = torch.matmul(mixed, proj_w).view(
            -1, type_count, self.embedding_size
        )

        # [batch_size, edge_type_count, embedding_size]
        return F.normalize(base_part + edge_part, dim=2)
class NSLoss(nn.Module):
    """Negative-sampling loss over a learned table of per-node weight vectors.

    For each input embedding the loss rewards a high dot product with the
    weight vector of its ``label`` node and a low dot product with the weight
    vectors of ``num_sampled`` randomly drawn nodes.

    Parameters
    ----------
    num_nodes : int
        Number of rows in the weight table and size of the sampling
        distribution.
    num_sampled : int
        Number of negative samples drawn per input row.
    embedding_size : int
        Width of each weight vector; must match the width of ``embs``.
    """

    def __init__(self, num_nodes, num_sampled, embedding_size):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))

        # Sampling weight of node k, before normalisation:
        #   (log(k + 2) - log(k + 1)) / log(num_nodes + 1)
        # F.normalize rescales the vector to unit L2 norm; torch.multinomial
        # accepts weights that do not sum to one.
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict keys stay unchanged.
        log_norm = math.log(num_nodes + 1)
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / log_norm
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        # The weight table is allocated uninitialised; fill it now.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weight table from N(0, 1 / embedding_size)."""
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):  # noqa: A002 - public parameter name
        """Return the mean negative-sampling loss of a batch.

        Parameters
        ----------
        input : Tensor
            Only its first dimension is read, as the batch size ``n``.
        embs : Tensor
            ``[n, embedding_size]`` embeddings to score.
        label : Tensor
            Integer indices into the weight table, one positive per row.

        Returns
        -------
        Tensor
            Scalar: ``-(sum of per-row log-likelihoods) / n``.
        """
        n = input.shape[0]

        # F.logsigmoid is used instead of log(sigmoid(x)): the latter
        # underflows to log(0) = -inf for strongly negative logits.
        positive_logits = torch.sum(torch.mul(embs, self.weights[label]), 1)
        log_target = F.logsigmoid(positive_logits)

        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])

        # bmm: [n, num_sampled, dim] x [n, dim, 1] -> [n, num_sampled, 1]
        negative_logits = torch.bmm(noise, embs.unsqueeze(2))
        # Summing over the samples leaves [n, 1]; squeeze only that axis so a
        # batch of one keeps its batch dimension.
        sum_log_sampled = torch.sum(F.logsigmoid(negative_logits), 1).squeeze(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n