# Source code for openhgnn.trainerflow.recommendation

import dgl
import numpy as np
import torch as th
from tqdm import tqdm
import torch
from openhgnn.models import build_model
from . import BaseFlow, register_flow
from ..utils import EarlyStopping


@register_flow("recommendation")
class Recommendation(BaseFlow):
    """Recommendation flows.

    Trains a model built by ``build_model`` on the task's heterogeneous graph
    with a BPR loss plus an embedding regularization term, and evaluates it
    with Recall@K and NDCG@K (``K = self.topk``) on the validation and test
    splits. Early stopping is driven by the validation recall.
    """

    def __init__(self, args=None):
        """Build the model and optimizer and cache dataset metadata.

        Parameters
        ----------
        args :
            Configuration object. This method reads ``lr``, ``weight_decay``,
            ``patience``, ``max_epoch`` and ``hidden_dim`` from it and writes
            ``out_node_type`` and ``out_dim`` onto ``self.args``.
        """
        super(Recommendation, self).__init__(args)
        self.target_link = self.task.dataset.target_link
        self.args.out_node_type = self.task.dataset.out_ntypes
        self.args.out_dim = self.args.hidden_dim

        self.model = build_model(self.model).build_model_from_args(self.args, self.hg)
        self.model = self.model.to(self.device)
        # Weight of the embedding regularization term in the total loss.
        self.reg_weight = 0.1

        self.metric = ['recall', 'ndcg']
        self.val_metric = 'recall'
        # self.topk_list = [5, 10, 20, 50, 100]
        self.topk = 20
        # self.evaluator = self.task.get_evaluator(self.metric)

        self.optimizer = th.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.patience = args.patience
        self.max_epoch = args.max_epoch

        self.num_neg = self.task.dataset.num_neg
        self.user_name = self.task.dataset.user_name
        self.item_name = self.task.dataset.item_name
        self.num_user = self.hg.num_nodes(self.user_name)
        # The item count must come from the item node type (the original
        # mistakenly queried the user node type here).
        self.num_item = self.hg.num_nodes(self.item_name)
        self.train_eid_dict = {
            etype: self.hg.edges(etype=etype, form='eid')
            for etype in self.hg.canonical_etypes
        }

    def preprocess(self):
        """Split the graph, build the negative graph and move them to the device."""
        self.train_hg, self.val_hg, self.test_hg = self.task.get_split()
        self.train_neg_hg = self.task.dataset.construct_negative_graph(self.train_hg)

        self.train_hg = self.train_hg.to(self.device)
        self.val_hg = self.val_hg.to(self.device)
        self.test_hg = self.test_hg.to(self.device)
        self.negative_graph = self.train_neg_hg.to(self.device)
        self.positive_graph = self.train_hg.edge_type_subgraph([self.target_link])

        # generate complete user-item graph for evaluation
        # src, dst = th.arange(self.num_user), th.arange(self.num_item)
        # src = src.repeat_interleave(self.num_item)
        # dst = dst.repeat(self.num_user)
        # self.eval_graph = dgl.heterograph({('user', 'user-item', 'item'): (src, dst)}, {'user': self.num_user, 'item': self.num_item}).to(self.device)
        self.preprocess_feature()
        return

    def train(self):
        """Run the training loop with early stopping, then evaluate on test.

        Returns
        -------
        tuple
            ``(test_recall, test_ndcg, epoch)`` where ``epoch`` is the index
            of the last epoch that was executed.
        """
        self.preprocess()
        # Use a single progress bar: the one we iterate over is also the one
        # whose description is updated with the validation metrics.
        epoch_iter = tqdm(range(self.max_epoch), ncols=80)
        stopper = EarlyStopping(self.args.patience, self._checkpoint)
        for epoch in epoch_iter:
            if self.args.mini_batch_flag:
                # NOTE(review): _mini_train_step is not defined in this class;
                # presumably provided elsewhere -- confirm before enabling.
                loss = self._mini_train_step()
            else:
                loss = self._full_train_step()

            if epoch % self.evaluate_interval == 0:
                metric_dic = self._test_step(split='val')
                epoch_iter.set_description(
                    f"Epoch: {epoch:03d}, Recall@K: {metric_dic['recall']:.4f}, NDCG@K: {metric_dic['ndcg']:.4f}, Loss:{loss:.4f}"
                )
                early_stop = stopper.step_score(metric_dic[self.val_metric], self.model)
                if early_stop:
                    print('Early Stop!\tEpoch:' + str(epoch))
                    break

        print(f"Valid {self.val_metric} = {stopper.best_score: .4f}")
        stopper.load_model(self.model)

        test_metric_dic = self._test_step(split="test")
        # val_metric_dic = self._test_step(split="val")
        print(f"Test Recall@K = {test_metric_dic['recall']: .4f}, NDCG@K = {test_metric_dic['ndcg']: .4f}")
        # result = dict(Test_metric=test_metric_dic, Val_metric=val_metric_dic)
        # with open(self.args.results_path, 'w') as f:
        #     json.dump(result, f)
        #     f.write('\n')
        # self.task.dataset.save_results(result, self.args.results_path)
        return test_metric_dic['recall'], test_metric_dic['ndcg'], epoch
        # return dict(Test_metric=test_metric_dic, Val_metric=val_metric_dic)

    def loss_calculation(self, positive_graph, negative_graph, embedding):
        """Return the BPR loss plus the weighted regularization loss.

        Each positive score is repeated ``self.num_neg`` times so it can be
        paired element-wise with the negative scores.
        """
        p_score = self.ScorePredictor(positive_graph, embedding).repeat_interleave(self.num_neg)
        n_score = self.ScorePredictor(negative_graph, embedding)
        bpr_loss = -torch.log(torch.sigmoid(p_score - n_score)).mean()
        reg_loss = self.regularization_loss(embedding)
        return bpr_loss + self.reg_weight * reg_loss

    def ScorePredictor(self, edge_subgraph, x):
        """Score every ``self.target_link`` edge by the dot product of its endpoints.

        ``x`` maps node type to an embedding tensor; only the user and item
        node types are read.
        """
        with edge_subgraph.local_scope():
            for ntype in [self.user_name, self.item_name]:
                edge_subgraph.nodes[ntype].data['x'] = x[ntype]
            edge_subgraph.apply_edges(
                dgl.function.u_dot_v('x', 'x', 'score'), etype=self.target_link)
            score = edge_subgraph.edges[self.target_link].data['score']
            return score.squeeze()

    def regularization_loss(self, embedding):
        """Sum the mean squared value of each embedding tensor in ``embedding``."""
        reg_loss = th.zeros(1, 1, device=self.device)
        for e in embedding.values():
            reg_loss += th.mean(e.pow(2))
        return reg_loss

    def _full_train_step(self):
        """Run one full-graph optimization step and return the loss as a float."""
        self.model.train()
        h_dict = self.input_feature()
        embedding = self.model(self.train_hg, h_dict)
        loss = self.loss_calculation(self.positive_graph, self.negative_graph, embedding)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # print(loss.item())
        return loss.item()

    def _test_step(self, split=None, logits=None):
        """Compute the metrics in ``self.metric`` for the given split.

        Parameters
        ----------
        split : str
            ``'val'`` or ``'test'``.
        logits :
            Unused.

        Returns
        -------
        dict
            Maps each metric name in ``self.metric`` to its value at ``self.topk``.

        Raises
        ------
        ValueError
            If ``split`` is neither ``'val'`` nor ``'test'``.
        NotImplementedError
            If ``self.metric`` contains an unsupported metric name.
        """
        self.model.eval()
        if split == 'val':
            test_graph = self.val_hg
        elif split == 'test':
            test_graph = self.test_hg
        else:
            raise ValueError('split must be in [val, test]')

        with th.no_grad():
            h_dict = self.input_feature()
            # NOTE(review): embeddings are computed on self.hg rather than
            # self.train_hg -- confirm this is intended for evaluation.
            embedding = self.model(self.hg, h_dict)
            score_matrix = (embedding[self.user_name] @ embedding[self.item_name].T).detach().cpu().numpy()

            # Mask out items already seen in training so they are never recommended.
            train_u, train_i = self.positive_graph.edges(etype=self.target_link)
            train_u, train_i = train_u.cpu().numpy(), train_i.cpu().numpy()
            # np.NINF was removed in NumPy 2.0; -np.inf works on all versions.
            score_matrix[train_u, train_i] = -np.inf

            rows = np.arange(score_matrix.shape[0])[:, None]
            # Unordered indices of the top-k items per user: (num_users, k).
            ind = np.argpartition(score_matrix, -self.topk)[:, -self.topk:]
            arr_ind = score_matrix[rows, ind]
            # Order the top-k candidates by descending score.
            arr_ind_argsort = np.argsort(arr_ind, axis=1)[:, ::-1]
            pred_list = ind[rows, arr_ind_argsort]  # (num_users, k)

            metric_dic = {}
            for m in self.metric:
                if m == 'recall':
                    metric_k = recall_at_k(pred_list, test_graph, self.topk, self.user_name, self.target_link)
                elif m == 'ndcg':
                    metric_k = ndcg_at_k(pred_list, test_graph, self.topk, self.user_name, self.target_link)
                else:
                    raise NotImplementedError
                metric_dic[m] = metric_k

            return metric_dic
def recall_at_k(pred_list, test_graph, k, user_name, target_link): sum = 0.0 test_users = 0 for user in range(test_graph.num_nodes(user_name)): test_items_set = set(test_graph.successors(user, etype=target_link).cpu().numpy()) pred_items_set = set(pred_list[user][:k]) if len(test_items_set) != 0: sum += len(test_items_set & pred_items_set) / float(len(test_items_set)) test_users += 1 return sum / test_users def ndcg_at_k(pred_list, test_graph, k, user_name, target_link): ndcg = [] for user in range(test_graph.num_nodes(user_name)): test_items_set = set(test_graph.successors(user, etype=target_link).cpu().numpy()) pred_items_set = pred_list[user][:k] hit_list = [1 if i in pred_items_set else 0 for i in test_items_set] GT = len(test_items_set) if GT >= k: ideal_hit_list = [1] * k else: ideal_hit_list = [1] * GT + [0] * (k - GT) # idcg = compute_DCG(sorted(hit_list, reverse=True)) idcg = compute_DCG(ideal_hit_list) if idcg: ndcg.append(compute_DCG(hit_list) / idcg) return np.mean(ndcg) def compute_DCG(l): l = np.array(l) if l.size: return np.sum(np.subtract(np.power(2, l), 1) / np.log2(np.arange(2, l.size + 2))) else: return 0.0 # if __name__ == '__main__': # dataset_name = 'Yelp' # rec_dataset = TestRecData(dataset_name) # print(rec_dataset)