Source code for openhgnn.trainerflow.MeiRec_trainer

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from ..models import build_model
from . import BaseFlow, register_flow
from ..tasks import build_task


@register_flow("MeiREC_trainer")
class MeiRECTrainer(BaseFlow):
    """Trainer flow for the MeiREC model, registered as ``"MeiREC_trainer"``.

    The flow builds the task and the model from ``args``, optimises the model
    with Adam and a binary cross-entropy loss (plus the model's own
    ``regular_loss`` term), decays the learning rate with an exponential
    schedule, and reports loss / AUC / accuracy on a tqdm progress bar.

    Attributes read from ``args``: ``model``, ``device``, ``lr``,
    ``patience`` and ``max_epoch``.
    """

    # The test set is evaluated on the last batch of every EVAL_EVERY-th epoch.
    EVAL_EVERY = 3
    # Maximum gradient norm used for clipping in each training step.
    GRAD_CLIP_NORM = 5.0
    # Multiplicative learning-rate decay of the exponential scheduler.
    LR_GAMMA = 0.99

    def __init__(self, args):
        """Build task, model, loss, optimizer and scheduler from ``args``."""
        super().__init__(args)
        self.args = args
        self.model_name = args.model
        self.device = args.device
        self.task = build_task(args)
        self.model = build_model(self.model_name).build_model_from_args(
            self.args).model
        print("build_model_finish")
        self.model = self.model.to(self.device)
        self.loss_fn = nn.BCELoss(reduction='mean')
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, self.LR_GAMMA)
        # NOTE: patience is stored but no early stopping is performed here.
        self.patience = args.patience
        self.max_epoch = args.max_epoch

    def train(self):
        """Run the full training loop for ``self.max_epoch`` epochs.

        Every batch of ``self.task.train_loader`` goes through one
        optimisation step and its metrics are shown on the progress bar.
        On the last batch of every ``EVAL_EVERY``-th epoch the whole
        ``self.task.test_loader`` is evaluated as well and the test metrics
        are added to the progress bar.
        """
        train_loader = self.task.train_loader
        test_loader = self.task.test_loader
        num_batches = len(train_loader)

        for epoch in range(1, self.max_epoch + 1):
            evaluate_this_epoch = epoch % self.EVAL_EVERY == 0
            pbar = tqdm(train_loader)
            for i, data in enumerate(pbar):
                # train step, data: batch * features
                _, train_loss, train_auc, train_acc = self._train_step(data)
                stats = OrderedDict({
                    "train_loss": train_loss.item(),
                    "train_auc": train_auc,
                    "train_acc": train_acc
                })

                if evaluate_this_epoch and i == num_batches - 1:
                    test_loss, test_auc, test_acc = self._evaluate(test_loader)
                    stats["test_loss"] = test_loss.item()
                    stats["test_auc"] = test_auc
                    stats["test_acc"] = test_acc

                pbar.set_description(
                    f"epoch[{epoch}/{self.max_epoch}], data[{i}/{num_batches}]")
                pbar.set_postfix(stats)

            # NOTE(review): assumed to run once per epoch; confirm against the
            # upstream file, whose indentation is not visible here.
            self.scheduler.step()
            # Checkpointing is currently disabled (it used to be sketched as
            # save_model every 10 epochs plus a "latest.pth.tar").

    def _train_step(self, x):
        """Run one optimisation step on batch ``x``.

        Returns:
            The ``(output, loss, auc, acc)`` tuple produced by :meth:`forward`.
        """
        self.model.train()  # Sets the module in training mode
        output, loss, auc, acc = self.forward(x)
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=self.GRAD_CLIP_NORM)
        self.optimizer.step()
        return output, loss, auc, acc

    def eval_model(self, x):
        """Run :meth:`forward` on batch ``x`` in eval mode without gradients.

        Returns:
            The ``(output, loss, auc, acc)`` tuple produced by :meth:`forward`.
        """
        self.model.eval()  # Sets the module in evaluation mode.
        with torch.no_grad():
            output, loss, auc, acc = self.forward(x)
        return output, loss, auc, acc

    def _evaluate(self, loader):
        """Evaluate every batch of ``loader`` and score the pooled predictions.

        Predictions and labels of all batches are concatenated so that loss,
        AUC and accuracy are computed over the whole set rather than averaged
        per batch.

        Returns:
            ``(loss_total, auc, acc)`` as returned by :meth:`_compute_metrics`.
        """
        outputs = []
        targets = []
        for batch in loader:
            output, _, _, _ = self.eval_model(batch)
            outputs.append(output)
            targets.append(batch["labels"])
        # Use the configured device instead of assuming CUDA is available.
        predictions = torch.cat(outputs, 0).to(self.device)
        labels = torch.cat(targets, 0).to(self.device)
        return self._compute_metrics(predictions, labels)

    def _compute_metrics(self, predicts, labels):
        """Compute total loss, AUC and accuracy for a set of predictions.

        Args:
            predicts: tensor of predicted probabilities.
            labels: tensor of binary targets matching ``predicts``.

        Returns:
            ``(loss_total, auc, acc)`` where ``loss_total`` is the BCE loss
            plus ``self.model.regular_loss`` (a tensor), ``auc`` is a float
            from :meth:`auc`, and ``acc`` is the fraction of predictions whose
            ``> 0.5`` decision equals the label.
        """
        loss_cross = self.loss_fn(predicts, labels.float())
        loss_total = loss_cross + self.model.regular_loss

        predicts_cpu = predicts.detach().cpu().numpy().ravel()
        labels_cpu = labels.detach().cpu().numpy().ravel()
        # Vectorised mean instead of Python-level sum over array elements.
        acc = np.mean((predicts_cpu > 0.5) == labels_cpu)
        auc = self.auc(labels_cpu, predicts_cpu)
        return loss_total, auc, acc

    def _confusion_matrix_at_thresholds(self, labels, preds, thresholds):
        """Count TP / FP / TN / FN at each threshold.

        A sample is predicted positive when ``pred > threshold`` and negative
        otherwise (``pred <= threshold``), so each sample labelled 0 or 1
        falls in exactly one cell per threshold. Samples with any other label
        value are ignored.

        Args:
            labels: array-like of 0/1 targets.
            preds: array-like of scores, same number of elements as ``labels``.
            thresholds: iterable of decision thresholds.

        Returns:
            Dict with keys ``'tp'``, ``'fp'``, ``'tn'``, ``'fn'``; each value
            is an int64 tensor with one entry per threshold.
        """
        labels = np.asarray(labels).ravel()
        preds = np.asarray(preds).ravel()
        is_pos = labels == 1
        is_neg = labels == 0
        num_pos = int(np.count_nonzero(is_pos))
        num_neg = int(np.count_nonzero(is_neg))

        tp = []
        fp = []
        for thr in thresholds:
            above = preds > thr
            tp.append(np.count_nonzero(above & is_pos))
            fp.append(np.count_nonzero(above & is_neg))
        tp = np.array(tp, dtype=np.int64)
        fp = np.array(fp, dtype=np.int64)

        # Complements guarantee that scores equal to a threshold are counted
        # as negatives instead of being dropped from the matrix.
        return {
            'tp': torch.from_numpy(tp),
            'fp': torch.from_numpy(fp),
            'tn': torch.from_numpy(num_neg - fp),
            'fn': torch.from_numpy(num_pos - tp),
        }

    def auc(self, labels, preds, num_thresholds=200):
        """Approximate ROC AUC by a trapezoidal sum over fixed thresholds.

        Follows the TensorFlow 1.x style AUC approximation (as noted in the
        original implementation): thresholds are evenly spaced in (0, 1) and
        padded with values just below 0 and just above 1.

        Args:
            labels: array-like of 0/1 targets.
            preds: array-like of scores.
            num_thresholds: number of thresholds; should be at least 2.

        Returns:
            The approximated AUC as a Python float.
        """
        kepsilon = 1e-7  # to account for floating point imprecisions
        thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
                      for i in range(num_thresholds - 2)]
        thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
        values = self._confusion_matrix_at_thresholds(
            labels, preds, thresholds)

        # Add epsilons to avoid dividing by 0.
        epsilon = 1.0e-6
        tp, fn = values['tp'], values['fn']
        tn, fp = values['tn'], values['fp']
        recall = torch.div(tp + epsilon, tp + fn + epsilon)
        fp_rate = torch.div(fp, fp + tn + epsilon)

        # Sum up the areas of all the trapeziums under the ROC curve.
        widths = fp_rate[:-1] - fp_rate[1:]
        heights = (recall[:-1] + recall[1:]) / 2.
        return torch.sum(torch.multiply(widths, heights)).item()

    def forward(self, x):
        """Run the model on one batch and compute its metrics.

        Args:
            x: mapping with a ``'data'`` tensor (transposed before being fed
                to the model) and a ``'labels'`` tensor.

        Returns:
            ``(predicts, loss, auc, acc)`` where ``predicts`` is the model
            output on ``self.device`` and ``loss`` is the total loss moved to
            the CPU.
        """
        # Follow the configured device rather than hard-coding CUDA, so the
        # inputs always live where the model lives.
        inputs = x['data'].T.to(self.device)
        labels = x['labels'].to(self.device)
        predicts = self.model(inputs)
        loss_total, auc, acc = self._compute_metrics(predicts, labels)
        return predicts, loss_total.cpu(), auc, acc