# Source code for openhgnn.models.HDE

import torch
import torch.nn as nn
import torch.nn.functional as F
from . import BaseModel, register_model


class GNN(nn.Module):
    """
    Two-hop neighbourhood aggregator.

    Every input row is read as three consecutive segments:

    1. the node's own features (``input_dim`` values),
    2. the features of ``num_neighbor`` first-hop neighbours,
    3. the features of ``num_neighbor * num_neighbor`` second-hop
       neighbours, grouped by the first-hop neighbour they belong to.

    Parameters
    ----------
    input_dim : int
        Number of features per node (cast with ``int``).
    output_dim : int
        Size of the produced embedding (cast with ``int``).
    num_neighbor : int
        Number of neighbours per hop; used for slicing and reshaping.
    use_bias : bool
        Stored on the instance only. It is not applied to the linear
        layers, which always keep their default bias.
    """
    def __init__(self, input_dim, output_dim, num_neighbor, use_bias=True):
        super().__init__()
        hidden_dim = 64

        self.input_dim = int(input_dim)
        self.num_fea = self.input_dim
        self.output_dim = int(output_dim)
        self.num_neighbor = num_neighbor
        self.use_bias = use_bias

        # Layers are created in this order so that seeded initialisation
        # consumes the random stream exactly as before.
        self.linear1 = nn.Linear(2 * self.input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim + self.num_fea, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, self.output_dim)

    def forward(self, fea):
        """
        Embed a batch of flattened (node, 1-hop, 2-hop) feature rows.

        Parameters
        ----------
        fea : torch.Tensor
            2-D tensor whose second dimension holds
            ``num_fea * (1 + num_neighbor + num_neighbor ** 2)`` values.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, output_dim)``; a ReLU is applied
            last, so all entries are non-negative.
        """
        width = self.num_fea
        fanout = self.num_neighbor
        hop1_end = width * (fanout + 1)

        # Split the flat row into its three segments.
        own = fea[:, :width]
        hop1 = fea[:, width:hop1_end].reshape(-1, fanout, width)
        hop2 = fea[:, hop1_end:].reshape(-1, fanout, fanout, width)

        # Average the second-hop neighbours under each first-hop neighbour,
        # then pair that summary with the first-hop features themselves.
        hop2_mean = hop2.mean(dim=2)
        paired = F.relu(self.linear1(torch.cat((hop1, hop2_mean), dim=2)))

        # Pool over the first-hop neighbours and fuse with the node itself.
        pooled = paired.mean(dim=1)
        hidden = F.relu(self.linear2(torch.cat((own, pooled), dim=1)))
        return F.relu(self.linear3(hidden))


@register_model('HDE')
class HDE(BaseModel):
    """
    Pairwise model built on a shared two-hop ``GNN`` aggregator.

    Two feature batches are embedded by the same aggregator, the two
    embeddings are concatenated, and a small two-layer head maps each
    pair to two raw output values (no softmax or other activation is
    applied to the final layer).

    Parameters
    ----------
    input_dim : int
        Number of features per node (cast with ``int``).
    output_dim : int
        Embedding size produced by the aggregator (cast with ``int``).
    num_neighbor : int
        Number of neighbours per hop expected by the aggregator.
    use_bias : bool
        Stored on the instance and forwarded to the aggregator.
    """

    def __init__(self, input_dim, output_dim, num_neighbor, use_bias=True):
        super(HDE, self).__init__()
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.num_neighbor = num_neighbor
        self.use_bias = use_bias
        # One aggregator instance is shared by both inputs of ``forward``.
        self.aggregator = GNN(input_dim=input_dim,
                              output_dim=output_dim,
                              num_neighbor=num_neighbor,
                              use_bias=use_bias)
        self.linear1 = nn.Linear(2 * self.output_dim, 32)
        self.linear2 = nn.Linear(32, 2)

    def forward(self, fea_a, fea_b):
        """
        Score pairs of nodes from their flattened neighbourhood features.

        Parameters
        ----------
        fea_a, fea_b : torch.Tensor
            Feature batches passed unchanged to the shared aggregator;
            both must yield the same batch size so the embeddings can be
            concatenated along dimension 1.

        Returns
        -------
        torch.Tensor
            Tensor with two values per pair, taken directly from the last
            linear layer.
        """
        emb_a = self.aggregator(fea_a)
        emb_b = self.aggregator(fea_b)
        emb = torch.cat([emb_a, emb_b], dim=1)
        emb = F.relu(self.linear1(emb))
        output = self.linear2(emb)
        return output

    @classmethod
    def build_model_from_args(cls, args, hg):
        """
        Build an ``HDE`` instance from an argument namespace.

        Parameters
        ----------
        args : object
            Must expose ``input_dim``, ``output_dim``, ``num_neighbor``
            and ``use_bias`` attributes.
        hg : object
            Accepted for interface compatibility; not used here.
        """
        return cls(input_dim=args.input_dim,
                   output_dim=args.output_dim,
                   num_neighbor=args.num_neighbor,
                   use_bias=args.use_bias)