# Source code for openhgnn.models.general_HGNN

import dgl
from ..layers import SkipConnection
from openhgnn.models import BaseModel, register_model
from ..models.HeteroMLP import HGNNPostMP, HGNNPreMP


# Lookup table from ``args.stage_type`` to the stage class that stacks the GNN
# layers in general_HGNN. Both skip variants map to HGNNSkipStage; the chosen
# stage_type string is also handed to the stage constructor, presumably so it
# can pick sum vs. concat skip connections.
stage_dict = dict(
    stack=SkipConnection.HGNNStackStage,
    skipsum=SkipConnection.HGNNSkipStage,
    skipconcat=SkipConnection.HGNNSkipStage,
)


def HG_transformation(hg, metapaths_dict):
    """Build a heterograph whose relations are meta-path-reachable graphs.

    For every ``key -> meta_path`` entry of ``metapaths_dict``, the graph
    ``dgl.metapath_reachable_graph(hg, meta_path)`` is computed. Its edges are
    stored in the result under the canonical edge type
    ``(source node type, key, destination node type)``.

    Parameters
    ----------
    hg : dgl.DGLHeteroGraph
        The original heterogeneous graph.
    metapaths_dict : dict
        Maps a new relation name to a meta-path accepted by
        ``dgl.metapath_reachable_graph`` (a sequence of edge types).

    Returns
    -------
    dgl.DGLHeteroGraph
        The transformed graph. Every node type it contains has the same number
        of nodes as in ``hg``, so node features of ``hg`` line up with it.
    """
    graph_data = {}
    num_nodes_dict = {}
    for key, mp in metapaths_dict.items():
        mp_g = dgl.metapath_reachable_graph(hg, mp)
        # The reachable graph has exactly one relation; only its end node types
        # are needed, the relation itself is renamed to ``key``.
        src_type, _, dst_type = mp_g.canonical_etypes[0]
        graph_data[(src_type, key, dst_type)] = mp_g.edges()
        # Without explicit node counts, dgl.heterograph infers them from the
        # largest node ID appearing in the edges. Trailing isolated nodes would
        # then be dropped, and the graph would no longer match the per-type
        # feature tensors of ``hg``.
        num_nodes_dict[src_type] = hg.num_nodes(src_type)
        num_nodes_dict[dst_type] = hg.num_nodes(dst_type)
    return dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict)


@register_model('general_HGNN')
class general_HGNN(BaseModel):
    """General heterogeneous GNN model.

    The model is a pipeline of up to three stages applied to a dict of node
    features keyed by node type:

    1. ``pre_mp``: extra pre-message-passing layers (``HGNNPreMP``), built
       only when ``args.layers_pre_mp - 1 > 0``;
    2. ``hgnn``: a stack of heterogeneous GNN layers whose stage class is
       looked up in ``stage_dict`` by ``args.stage_type``, built only when
       ``args.layers_gnn > 0``;
    3. ``post_mp``: post-message-passing layers (``HGNNPostMP``) applied to
       the features of ``out_node_type`` only.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        """Build the model, first deriving the graph it operates on.

        ``args.subgraph_extraction`` selects that graph:

        * ``'relation'``: ``hg`` is used unchanged;
        * ``'metapath'``: one relation per entry of ``args.meta_paths_dict``
          (built with ``HG_transformation``);
        * ``'mixed'``: the meta-path relations plus one relation per
          canonical edge type of ``hg``.

        Raises
        ------
        ValueError
            If a meta-path based extraction is requested but ``args`` has no
            ``meta_paths_dict``, or if the extraction mode is unknown.
        """
        out_node_type = args.out_node_type
        if args.subgraph_extraction == 'relation':
            new_hg = hg
            print('relation extraction!')
        elif args.subgraph_extraction == 'metapath':
            if hasattr(args, 'meta_paths_dict'):
                new_hg = HG_transformation(hg, args.meta_paths_dict)
                print('metapath extraction!')
            else:
                raise ValueError('No meta-path is specified!')
        elif args.subgraph_extraction == 'mixed':
            if not hasattr(args, 'meta_paths_dict'):
                raise ValueError('No meta-path is specified!')
            # Work on a copy: adding the relations must not mutate the
            # caller's args.meta_paths_dict, which may be reused elsewhere.
            relation_dict = dict(args.meta_paths_dict)
            for etype in hg.canonical_etypes:
                # Each original relation becomes a length-1 meta-path keyed by
                # its edge-type name (overriding a meta-path of the same name).
                relation_dict[etype[1]] = [etype]
            new_hg = HG_transformation(hg, relation_dict)
            print('mixed extraction!')
        else:
            raise ValueError('subgraph_extraction only supports relation, metapath and mixed')
        return cls(args, new_hg, out_node_type)

    def __init__(self, args, hg, out_node_type, **kwargs):
        """
        Parameters
        ----------
        args : object
            Hyper-parameter namespace. Reads ``layers_pre_mp``, ``layers_gnn``,
            ``hidden_dim``, ``stage_type``, ``gnn_type``, ``dropout``,
            ``activation``, ``has_bn``, ``has_l2norm``, ``num_heads``,
            ``macro_func``, ``layers_post_mp`` and ``out_dim``.
        hg : dgl.DGLHeteroGraph
            Graph used for message passing; stored as ``self.hg``.
        out_node_type : container of str
            Node types whose representations are passed through ``post_mp``
            and returned by ``forward``.
        **kwargs
            Accepted for interface compatibility and ignored.
        """
        super().__init__()
        self.hg = hg
        self.out_node_type = out_node_type
        # The first linear layer is applied outside the model (in the
        # trainerflow), so a pre-MP stage is built only when more layers
        # are requested.
        if args.layers_pre_mp - 1 > 0:
            self.pre_mp = HGNNPreMP(args, self.hg.ntypes, args.layers_pre_mp,
                                    args.hidden_dim, args.hidden_dim)
        # Without a GNN stage, post_mp receives the hidden_dim-wide features
        # produced by pre_mp / the trainerflow. Previously this name was left
        # undefined when layers_gnn <= 0, which raised NameError below.
        gnn_out_dim = args.hidden_dim
        if args.layers_gnn > 0:
            HGNNStage = stage_dict[args.stage_type]
            self.hgnn = HGNNStage(gnn_type=args.gnn_type,
                                  rel_names=self.hg.etypes,
                                  stage_type=args.stage_type,
                                  dim_in=args.hidden_dim,
                                  dim_out=args.hidden_dim,
                                  num_layers=args.layers_gnn,
                                  skip_every=1,
                                  dropout=args.dropout,
                                  act=args.activation,
                                  has_bn=args.has_bn,
                                  has_l2norm=args.has_l2norm,
                                  num_heads=args.num_heads,
                                  macro_func=args.macro_func)
            gnn_out_dim = self.hgnn.dim_out
        self.post_mp = HGNNPostMP(args, self.out_node_type, args.layers_post_mp,
                                  gnn_out_dim, args.out_dim)

    def forward(self, hg, h_dict):
        """Run pre-MP, GNN and post-MP stages on ``self.hg``.

        Parameters
        ----------
        hg : dgl.DGLHeteroGraph
            Unused: the graph stored in ``self.hg`` (possibly a transformed
            graph built by ``build_model_from_args``) is used instead.
        h_dict : dict
            Node features keyed by node type. Entries for node types that are
            not in ``self.hg`` are dropped.

        Returns
        -------
        dict
            Representations of the node types in ``self.out_node_type``.
        """
        graph = self.hg
        # Scope the graph actually used for message passing, so that features
        # the layers attach to it do not persist after the call. The original
        # code scoped the argument graph, which differs from self.hg whenever
        # a meta-path extraction was used.
        with graph.local_scope():
            h_dict = {key: value for key, value in h_dict.items() if key in graph.ntypes}
            if hasattr(self, 'pre_mp'):
                h_dict = self.pre_mp(h_dict)
            if hasattr(self, 'hgnn'):
                h_dict = self.hgnn(graph, h_dict)
            out_h = {key: value for key, value in h_dict.items() if key in self.out_node_type}
            if hasattr(self, 'post_mp'):
                out_h = self.post_mp(out_h)
        return out_h