Source code for openhgnn.tasks.link_prediction

import dgl
import torch as th
import torch.nn.functional as F
from dgl.dataloading.negative_sampler import Uniform
from . import BaseTask, register_task
from ..dataset import build_dataset
from ..utils import Evaluator


@register_task("link_prediction")
class LinkPrediction(BaseTask):
    r"""
    Link prediction tasks.

    Attributes
    -----------
    dataset :
        Task-related link prediction dataset, built by ``build_dataset(..., 'link_prediction')``.
    evaluator : Evaluator
        offer evaluation metric

    Methods
    ---------
    get_graph :
        return a graph
    get_loss_fn :
        return a loss function
    """

    def __init__(self, args):
        """
        Build the dataset, the data splits, the score predictor and the evaluation settings.

        Parameters
        ----------
        args:
            Configuration namespace. Read attributes: ``dataset``, ``logger``, ``device``,
            ``seed``, and optionally ``score_fn``, ``evaluation_metric``, ``filtered``,
            ``valid_percent``, ``test_percent``. ``args.score_fn`` is set to
            ``'distmult'`` when it is missing.
        """
        super(LinkPrediction, self).__init__()
        self.name_dataset = args.dataset
        self.logger = args.logger
        self.dataset = build_dataset(args.dataset, 'link_prediction', logger=self.logger)
        (self.train_hg, self.val_hg, self.test_hg,
         self.neg_val_graph, self.neg_test_graph) = self.dataset.get_split()
        # Graph whose edges are scored by predict(); None when the dataset provides none.
        self.pred_hg = getattr(self.dataset, 'pred_graph', None)
        # Each evaluation graph may be absent independently, so move them one by one.
        # NOTE(review): the negative graphs are not moved to args.device here; confirm the
        # dataset already places them on the right device.
        if self.val_hg is not None:
            self.val_hg = self.val_hg.to(args.device)
        if self.test_hg is not None:
            self.test_hg = self.test_hg.to(args.device)
        self.evaluator = Evaluator(args.seed)

        if not hasattr(args, 'score_fn'):
            self.ScorePredictor = HeteroDistMultPredictor()
            args.score_fn = 'distmult'
        elif args.score_fn == 'dot-product':
            self.ScorePredictor = HeteroDotProductPredictor()
        elif args.score_fn == 'distmult':
            self.ScorePredictor = HeteroDistMultPredictor()
        # deprecated, new score predictor of these score_fn are in their model
        # elif args.score_fn in ['transe', 'transh', 'transr', 'transd'] :
        #     self.ScorePredictor = HeteroTransXPredictor(args.dis_norm)

        # Uniform(1): one uniformly sampled negative per positive edge (DGL Uniform(k)).
        self.negative_sampler = Uniform(1)

        # default evaluation_metric is roc_auc
        self.evaluation_metric = getattr(args, 'evaluation_metric', 'roc_auc')
        if args.dataset in ['wn18', 'FB15k', 'FB15k-237']:
            # Knowledge-graph benchmarks are ranked with MRR.
            self.evaluation_metric = 'mrr'
            # Same default that evaluate() falls back to when the attribute is missing.
            self.filtered = getattr(args, 'filtered', 'filtered')
        if hasattr(args, "valid_percent"):
            self.dataset.modify_size(args.valid_percent, 'valid')
        if hasattr(args, "test_percent"):
            self.dataset.modify_size(args.test_percent, 'test')
        args.logger.info('[Init Task] The task: link prediction, the dataset: {}, the evaluation metric is {}, '
                         'the score function: {} '.format(self.name_dataset, self.evaluation_metric,
                                                          args.score_fn))

    def get_out_ntype(self):
        """Return the set of node types at either end (positions 0 and 2) of each target link."""
        ntypes = set()
        for link in self.dataset.target_link:
            ntypes.update((link[0], link[2]))
        return ntypes

    def get_graph(self):
        """Return the dataset's full graph ``self.dataset.g``."""
        return self.dataset.g

    def get_loss_fn(self):
        """Return the training loss: binary cross-entropy on raw logits."""
        return F.binary_cross_entropy_with_logits

    def get_evaluator(self, name):
        """Return the evaluator callable registered under ``name``, or None for an unknown name."""
        if name == 'acc':
            return self.evaluator.author_link_prediction
        elif name == 'mrr':
            return self.evaluator.mrr_
        elif name == 'academic_lp':
            return self.evaluator.author_link_prediction
        elif name == 'roc_auc':
            return self.evaluator.cal_roc_auc

    def evaluate(self, n_embedding, r_embedding=None, mode='test'):
        r"""
        Evaluate the embeddings with ``self.evaluation_metric``.

        Parameters
        ----------
        n_embedding:
            the embedding of nodes; the score predictors index it by node type
        r_embedding:
            the embedding of relation types (the DistMult predictor indexes it by relation name)
        mode: str
            ``'valid'`` or ``'test'``. The roc_auc branch accepts only these two; the mrr
            branch forwards the value to the evaluator unchanged.

        Returns
        -------
        For ``'roc_auc'``: ``dict(roc_auc=..., loss=...)``, where ``loss`` is the BCE of the
        raw scores against 1 (positive edges) / 0 (negative edges).
        For ``'mrr'``: whatever ``Evaluator.mrr_`` returns.

        Raises
        ------
        ValueError
            if ``mode`` is not ``'test'`` or ``'valid'`` under the roc_auc metric.
        """
        if self.evaluation_metric == 'acc':
            # NOTE(review): this returns the evaluator method itself, not a computed accuracy.
            acc = self.evaluator.author_link_prediction
            return dict(Accuracy=acc)
        elif self.evaluation_metric == 'mrr':
            mrr_matrix = self.evaluator.mrr_(n_embedding, r_embedding, self.dataset.train_triplets,
                                             self.dataset.valid_triplets, self.dataset.test_triplets,
                                             score_predictor=self.ScorePredictor, hits=[1, 3, 10],
                                             filtered=getattr(self, 'filtered', 'filtered'),
                                             eval_mode=mode)
            return mrr_matrix
        elif self.evaluation_metric == 'roc_auc':
            if mode == 'test':
                eval_hg = self.test_hg
                neg_hg = self.neg_test_graph
            elif mode == 'valid':
                eval_hg = self.val_hg
                neg_hg = self.neg_val_graph
            else:
                raise ValueError('Mode error, supported test and valid.')
            if neg_hg is None:
                neg_hg = self.construct_negative_graph(eval_hg)
            # The score predictors return raw, unbounded scores (logits).
            p_logit = self.ScorePredictor(eval_hg, n_embedding, r_embedding)
            n_logit = self.ScorePredictor(neg_hg, n_embedding, r_embedding)
            logits = th.cat((p_logit, n_logit))
            p_label = th.ones(len(p_logit), device=p_logit.device)
            n_label = th.zeros(len(n_logit), device=p_logit.device)
            labels = th.cat((p_label, n_label))
            roc_auc = self.evaluator.cal_roc_auc(labels.cpu(), th.sigmoid(logits).cpu())
            # binary_cross_entropy_with_logits applies the sigmoid itself, so give it raw logits.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            return dict(roc_auc=roc_auc, loss=loss)
        else:
            # NOTE(review): returns the evaluator method itself, not a computed result.
            return self.evaluator.link_prediction

    def predict(self, n_embedding, r_embedding, **kwargs):
        """Score every edge of ``self.pred_hg``; return ``(edges, sigmoid scores)``."""
        score = th.sigmoid(self.ScorePredictor(self.pred_hg, n_embedding, r_embedding))
        indices = self.pred_hg.edges()
        return indices, score

    def tranX_predict(self):
        """Score ``self.dataset.pred_triples`` (column-wise head/relation/tail) with the predictor."""
        pred_triples_T = self.dataset.pred_triples.T
        score = th.sigmoid(self.ScorePredictor(pred_triples_T[0], pred_triples_T[1], pred_triples_T[2]))
        indices = self.pred_hg.edges()
        return indices, score

    def downstream_evaluate(self, logits, evaluation_metric):
        """Run the 'academic_lp' evaluation; return None for any other metric name."""
        if evaluation_metric == 'academic_lp':
            auc, macro_f1, micro_f1 = self.evaluator.author_link_prediction(logits, self.dataset.train_batch,
                                                                            self.dataset.test_batch)
            # NOTE(review): 'Mirco_f1' is misspelled but kept, since callers may read this key.
            return dict(AUC=auc, Macro_f1=macro_f1, Mirco_f1=micro_f1)

    def get_batch(self):
        """Return the dataset's ``(train_batch, test_batch)``."""
        return self.dataset.train_batch, self.dataset.test_batch

    def get_train(self):
        """Return the training graph."""
        return self.train_hg

    def get_labels(self):
        """Return ``self.dataset.get_labels()``."""
        return self.dataset.get_labels()

    def dict2emd(self, r_embedding):
        """Stack relation embeddings keyed '0'..'num_rels-1' into one squeezed tensor, in id order."""
        r_emd = [r_embedding[str(i)] for i in range(self.dataset.num_rels)]
        return th.stack(r_emd).squeeze()

    def construct_negative_graph(self, hg):
        """
        Build a graph of negative edges for ``hg`` with the task's negative sampler.

        Every edge type of ``hg`` is sampled. The result has the same node types and
        node counts as ``hg``.
        """
        e_dict = {
            etype: hg.edges(etype=etype, form='eid')
            for etype in hg.canonical_etypes}
        neg_srcdst = self.negative_sampler(hg, e_dict)
        neg_pair_graph = dgl.heterograph(neg_srcdst,
                                         {ntype: hg.num_nodes(ntype) for ntype in hg.ntypes})
        return neg_pair_graph
class HeteroDotProductPredictor(th.nn.Module):
    """
    Score each edge with the dot product of its endpoint embeddings.

    References: `documentation of dgl <https://docs.dgl.ai/guide/training-link.html#heterogeneous-graphs>`_
    """

    def forward(self, edge_subgraph, x, *args, **kwargs):
        """
        Parameters
        ----------
        edge_subgraph: dgl.Heterograph
            the prediction graph only contains the edges of the target link
        x: dict[str: th.Tensor]
            the embedding dict. The key only contains the nodes involving with the target link.

        Returns
        -------
        score: th.Tensor
            the prediction of the edges in edge_subgraph, with at least one dimension.
            Scores of multiple edge types are concatenated in ``edata`` order.
        """
        with edge_subgraph.local_scope():
            for ntype in edge_subgraph.ntypes:
                edge_subgraph.nodes[ntype].data['x'] = x[ntype]
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            score = edge_subgraph.edata['score']
            # With several edge types, edata returns a dict keyed by etype.
            if isinstance(score, dict):
                score = th.cat(list(score.values()))
            # A bare squeeze() turns a single-edge result into a 0-d tensor, which breaks
            # len() in callers. Keep at least one dimension.
            return th.atleast_1d(score.squeeze())


class HeteroDistMultPredictor(th.nn.Module):
    """Score each edge with DistMult: source embedding * relation embedding * destination embedding."""

    def forward(self, edge_subgraph, x, r_embedding, *args, **kwargs):
        r"""
        DistMult factorization (Yang et al. 2014) as the scoring function,
        which is known to perform well on standard link prediction benchmarks when used on its own.

        In DistMult, every relation r is associated with a diagonal matrix :math:`R_{r} \in \mathbb{R}^{d \times d}`
        and a triple (s, r, o) is scored as

        .. math::
            f(s, r, o)=e_{s}^{T} R_{r} e_{o}

        Parameters
        ----------
        edge_subgraph: dgl.Heterograph
            the prediction graph only contains the edges of the target link
        x: dict[str: th.Tensor]
            the node embedding dict. The key only contains the nodes involving with the target link.
        r_embedding: mapping[str: th.Tensor]
            relation-type embeddings, looked up by relation name (``etype[1]``). Each value is
            expanded to ``(num_edges, -1)``, so presumably shape ``(d,)`` or ``(1, d)`` — TODO confirm.

        Returns
        -------
        score: th.Tensor
            the prediction of the edges in edge_subgraph: one score per edge, summed over
            the embedding dimension. Scores of multiple edge types are concatenated.
        """
        with edge_subgraph.local_scope():
            for ntype in edge_subgraph.ntypes:
                edge_subgraph.nodes[ntype].data['x'] = x[ntype]
            # A graph with one edge type takes a plain tensor in edata; otherwise it needs an etype-keyed dict.
            single_etype = 1 == len(edge_subgraph.canonical_etypes)
            for etype in edge_subgraph.canonical_etypes:
                e = r_embedding[etype[1]]
                n = edge_subgraph.num_edges(etype)
                if single_etype:
                    edge_subgraph.edata['e'] = e.expand(n, -1)
                else:
                    edge_subgraph.edata['e'] = {etype: e.expand(n, -1)}
                # Elementwise x_src * e_rel * x_dst; summed over the last dim below.
                edge_subgraph.apply_edges(
                    dgl.function.u_mul_e('x', 'e', 's'), etype=etype)
                edge_subgraph.apply_edges(
                    dgl.function.e_mul_v('s', 'x', 'score'), etype=etype)
            score = edge_subgraph.edata['score']
            if isinstance(score, dict):
                score = th.cat([th.sum(value, dim=1) for value in score.values()])
            else:
                score = th.sum(score, dim=1)
            return score

# class HeteroTransXPredictor(th.nn.Module):
#     def __init__(self, dis_norm):
#         super(HeteroTransXPredictor, self).__init__()
#         self.dis_norm = dis_norm
#
#     def forward(self, h, r, t):
#         h = F.normalize(h, 2, -1)
#         r = F.normalize(r, 2, -1)
#         t = F.normalize(t, 2, -1)
#         dist = th.norm(h+r-t, self.dis_norm, dim=-1)
#         return dist