# Source code for openhgnn.models.base_model

from abc import ABCMeta
import torch.nn as nn


class BaseModel(nn.Module, metaclass=ABCMeta):
    r"""Base class for models.

    Every hook defined here (``build_model_from_args``, ``forward``,
    ``extra_loss``, ``get_emb``) raises ``NotImplementedError`` in this base
    class; subclasses override the ones they need.

    Note: although the metaclass is ``ABCMeta``, no method is decorated with
    ``@abstractmethod``, so instantiating a subclass is never blocked; an
    un-overridden hook only fails when it is actually called.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        r"""Build a model instance from ``args`` and the graph ``hg``.

        Every subclass inheriting from this class should override this method.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError("Models must implement the build_model_from_args method")

    def __init__(self):
        super().__init__()

    def forward(self, *args):
        r"""Encode the original features into new features.

        The model plays the role of an encoder. The base signature accepts
        ``*args`` so subclasses may choose their own parameters; the
        conventional ones are described below.

        Parameters
        ----------
        hg : dgl.DGLHeteroGraph
            The heterogeneous graph.
        h_dict : dict[str, th.Tensor]
            The dict of heterogeneous features.

        Returns
        -------
        out_dict : dict[str, th.Tensor]
            A dict of encoded features. In general it should output the
            embeddings of all nodes; it is also allowed to output only the
            embeddings of the target nodes that participate in the loss
            calculation.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def extra_loss(self):
        r"""Return an extra loss term defined by the model.

        Intended for models that want, for example, an L2-norm penalty that
        is not applied to all parameters. (Presumably the trainer adds this to
        the main loss — confirm against the calling code.)

        Returns
        -------
        th.Tensor

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def h2dict(self, h, hdict):
        r"""Split the row-stacked tensor ``h`` into a dict keyed like ``hdict``.

        Rows of ``h`` are assigned to the keys of ``hdict`` in its iteration
        order: each key ``k`` receives the next ``hdict[k].shape[0]`` rows.
        ``hdict`` is only consulted for its keys and the first dimension of
        its values; its contents are not read.

        Parameters
        ----------
        h : th.Tensor
            Tensor whose first dimension stacks every entry — presumably the
            encoded result of concatenating ``hdict``'s values along dim 0;
            confirm with callers.
        hdict : dict[str, th.Tensor]
            Reference dict whose value lengths define the split sizes.

        Returns
        -------
        dict[str, th.Tensor]
            One slice of ``h`` per key of ``hdict``, in the same key order.
        """
        out_dict = {}
        start = 0
        for key, value in hdict.items():
            end = start + value.shape[0]
            out_dict[key] = h[start:end]
            start = end
        return out_dict

    def get_emb(self):
        r"""Return the embedding of the model for further analysis.

        Returns
        -------
        numpy.array

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError