Source code for openhgnn.trainerflow.DHNE_trainer

from abc import abstractmethod
import dgl
import torch as th
from torch import nn
from tqdm import tqdm
import torch
import torch.nn.functional as F
from openhgnn.trainerflow import BaseFlow, register_flow
from openhgnn.models import build_model
from openhgnn.utils import extract_embed, EarlyStopping, get_nodes_dict, add_reverse_edges
from torch.utils.data import DataLoader
from openhgnn.tasks import build_task
import os
from sklearn import metrics
import numpy as np


@register_flow("DHNE_trainer")
class DHNE_trainer(BaseFlow):
    """
    DHNE trainer flow.

    Supported Model: DHNE
    Supported Dataset: drug, GPS, MovieLens, wordnet

    The trainerflow supports the hypergraph embedding task.
    For more details, please refer to the original paper: https://arxiv.org/abs/1711.10146
    """

    def __init__(self, args):
        super(DHNE_trainer, self).__init__(args)
        self.hg = None
        self.args = args
        self.logger = self.args.logger
        self.model_name = args.model_name
        self.model = args.model
        self.task = build_task(args)
        self.train_dataset = self.task.train_data
        self.val_dataset = self.task.val_data
        self.device = self.args.device
        self.args.nums_type = self.train_dataset.nums_type
        self.model = build_model(self.model).build_model_from_args(
            self.args).to(self.device)
        self.optimizer = torch.optim.RMSprop(
            self.model.parameters(), lr=args.lr)
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.args.batch_size,
                                           shuffle=True,
                                           num_workers=4,
                                           collate_fn=self.collate_fn)
        self.val_dataloader = DataLoader(self.val_dataset,
                                         batch_size=self.args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         collate_fn=self.collate_fn)

    def collate_fn(self, batches):
        # Stack the three per-node-type encodings and the hyperedge labels
        # of a mini-batch into dense float tensors.
        encode1, encode2, encode3 = [], [], []
        labels = []
        for encodes, label in batches:
            encode1.append(torch.from_numpy(encodes[0]))
            encode2.append(torch.from_numpy(encodes[1]))
            encode3.append(torch.from_numpy(encodes[2]))
            labels.append(torch.from_numpy(label))
        return [torch.concat(encode1).type(torch.float32),
                torch.concat(encode2).type(torch.float32),
                torch.concat(encode3).type(torch.float32)], torch.concat(labels)

    def preprocess(self):
        pass

    def train(self):
        total_metric = {}
        max_auc = 0
        max_auc_epoch = 0
        max_auc_epoch_loss = 0
        for epoch in tqdm(range(self.max_epoch)):
            loss = self._mini_train_step()
            if epoch % self.evaluate_interval == 0:
                val_metric = self._test_step()
                self.logger.train_info(
                    f"Epoch: {epoch:03d}, train loss: {loss:.4f}. "
                    + "encode_0:{encode_0},encode_1:{encode_1},encode_2:{encode_2},label:{label},auc_micro:{auc_micro}".format(**val_metric))
                total_metric[epoch] = val_metric
                if val_metric['auc_micro'] > max_auc:
                    max_auc = val_metric['auc_micro']
                    max_auc_epoch = epoch
                    max_auc_epoch_loss = loss
        print("[Best Info] the best training results:")
        print(f"Epoch: {max_auc_epoch:03d}, train loss: {max_auc_epoch_loss:.4f}. "
              + "encode_0:{encode_0},encode_1:{encode_1},encode_2:{encode_2},label:{label},auc_micro:{auc_micro}".format(
                  **total_metric[max_auc_epoch]))
        save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name
        th.save(self.model.state_dict(), save_path + '/DHNE.pth')

    def _full_train_setp(self):
        pass

    def _mini_train_step(self,):
        self.model.train()
        all_loss = 0
        loader_tqdm = tqdm(self.train_dataloader, ncols=120)
        for blocks, target in loader_tqdm:
            blocks = [b.to(self.device) for b in blocks]
            target = target.to(self.device)
            logits = self.model(blocks)
            loss = self.loss_calculation(blocks + [target], logits)
            all_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return all_loss

    def sparse_autoencoder_error(self, y_true, y_pred):
        # Only the non-zero entries of the input contribute to the
        # reconstruction error (the "sparse" part of the autoencoder objective).
        return torch.mean(torch.square(torch.sign(y_true) * (y_true - y_pred)), dim=-1)

    def loss_calculation(self, targets, preds):
        loss = 0
        for i in range(3):
            target = targets[i]
            pred = preds[i]
            loss += self.sparse_autoencoder_error(target, pred)
        label_target = targets[-1]
        label_target = label_target.reshape(-1, 1)
        label_pred = preds[-1]
        loss = loss.mean()
        loss += F.binary_cross_entropy_with_logits(label_pred, label_target)
        return loss

    def _test_step(self):
        self.model.eval()
        metric = {"encode_0": 0, "encode_1": 0, "encode_2": 0, "label": 0}
        with th.no_grad():
            loader_tqdm = tqdm(self.val_dataloader, ncols=120)
            num = 0
            y_true = []
            y_score = []
            for blocks, target in loader_tqdm:
                blocks = [b.to(self.device) for b in blocks]
                targets = target.to(self.device)
                logits = self.model(blocks)
                for i in range(3):
                    pred = logits[i]
                    target = blocks[i]
                    metric["encode_%d" % i] += F.mse_loss(pred, target)
                label_logit = torch.sigmoid(logits[-1])
                label_pred = (label_logit > 0.5).int()
                label_pred = label_pred.reshape(-1)
                metric["label"] += (label_pred == targets).float().mean()
                y_true.append(targets.cpu().numpy())
                y_score.append(label_logit.cpu().numpy())
                num += 1
            metric = {k: v.cpu().numpy() / num for k, v in metric.items()}
            y_true = np.hstack(y_true)
            y_score = np.vstack(y_score).reshape(-1)
            auc_micro = metrics.roc_auc_score(y_true, y_score, average='micro')
            metric['auc_micro'] = auc_micro
        return metric

    def save_checkpoint(self):
        if self._checkpoint and hasattr(self.model, "_parameters"):
            torch.save(self.model.state_dict(), self._checkpoint)