import os

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.utils.data import DataLoader
from tqdm import tqdm

from openhgnn.models import build_model
from openhgnn.tasks import build_task
from openhgnn.trainerflow import BaseFlow, register_flow


@register_flow("DHNE_trainer")
class DHNE_trainer(BaseFlow):
"""
DHNE trainer flow.
Supported Model: DHNE
Supported Datase: drug, GPS, MovieLens, wordnet
The trainerflow supports hypergraph embedding task.
For more details, please refer to the original paper: https://arxiv.org/abs/1711.10146
"""
def __init__(self, args):
super(DHNE_trainer, self).__init__(args)
self.hg = None
self.args = args
self.logger = self.args.logger
self.model_name = args.model_name
self.model = args.model
self.task = build_task(args)
self.train_dataset = self.task.train_data
self.val_dataset = self.task.val_data
self.device = self.args.device
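        # nums_type records the node count of each of the three node types in
        # the hypergraph; the model presumably uses it to size its per-type
        # autoencoders.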
self.args.nums_type = self.train_dataset.nums_type
self.model = build_model(self.model).build_model_from_args(
self.args).to(self.device)
self.optimizer = torch.optim.RMSprop(
self.model.parameters(), lr=args.lr)
self.train_dataloader = DataLoader(self.train_dataset,
batch_size=self.args.batch_size,
shuffle=True,
num_workers=4,
collate_fn=self.collate_fn)
self.val_dataloader = DataLoader(self.val_dataset,
batch_size=self.args.batch_size,
shuffle=True,
num_workers=4,
collate_fn=self.collate_fn)
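
    # Each dataset item is assumed to yield (encodes, label), where encodes
    # holds one NumPy incidence encoding per node type; the collate function
    # concatenates them into three float32 batch tensors plus a label tensor.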
def collate_fn(self, batches):
encode1, encode2, encode3 = [], [], []
labels = []
for encodes, label in batches:
encode1.append(torch.from_numpy(encodes[0]))
encode2.append(torch.from_numpy(encodes[1]))
encode3.append(torch.from_numpy(encodes[2]))
labels.append(torch.from_numpy(label))
        return ([torch.cat(encode1).float(),
                 torch.cat(encode2).float(),
                 torch.cat(encode3).float()],
                torch.cat(labels))

def preprocess(self):
pass
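
    # Main loop: one mini-batch training pass per epoch, evaluation every
    # `evaluate_interval` epochs, and tracking of the epoch with the best
    # micro-averaged AUC on the validation set.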
def train(self):
total_metric = {}
max_auc = 0
max_auc_epoch = 0
max_auc_epoch_loss = 0
for epoch in tqdm(range(self.max_epoch)):
loss = self._mini_train_step()
if epoch % self.evaluate_interval == 0:
val_metric = self._test_step()
self.logger.train_info(
f"Epoch: {epoch:03d}, train loss: {loss:.4f}. " +
"encode_0:{encode_0},encode_1:{encode_1},encode_2:{encode_2},label:{label},auc_micro:{auc_micro}".format(**val_metric))
total_metric[epoch] = val_metric
if val_metric['auc_micro'] > max_auc:
max_auc = val_metric['auc_micro']
max_auc_epoch = epoch
max_auc_epoch_loss = loss
print("[Best Info] the best training results:")
print(f"Epoch: {max_auc_epoch:03d}, train loss: {max_auc_epoch_loss:.4f}. " +
"encode_0:{encode_0},encode_1:{encode_1},encode_2:{encode_2},label:{label},auc_micro:{auc_micro}".format(
**total_metric[max_auc_epoch]))
save_path = os.path.dirname(os.path.abspath('__file__'))+'/openhgnn/output/'+self.model_name
th.save(self.model.state_dict(), save_path+'/DHNE.pth')
def _full_train_setp(self):
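        # Intentionally a no-op: DHNE trains only in mini-batches. The method
        # name is kept as-is since it is assumed to match the hook expected by
        # BaseFlow.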
pass
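
    # One optimizer step per mini-batch. Note that the returned value is the
    # loss summed over all batches of the epoch, not a per-batch average.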
    def _mini_train_step(self):
self.model.train()
all_loss = 0
loader_tqdm = tqdm(self.train_dataloader, ncols=120)
for blocks, target in loader_tqdm:
blocks = [b.to(self.device) for b in blocks]
target = target.to(self.device)
logits = self.model(blocks)
loss = self.loss_calculation(blocks+[target], logits)
all_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return all_loss
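
    # Masked reconstruction error: sign(y_true) zeroes out the error on
    # unobserved (zero) entries, so the autoencoder is only penalized on
    # observed entries of the incidence encoding, mirroring the first-order
    # loss of the DHNE paper.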
    def sparse_autoencoder_error(self, y_true, y_pred):
        return torch.mean(torch.square(torch.sign(y_true) * (y_true - y_pred)), dim=-1)
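
    # Joint objective: the mean of the three masked reconstruction losses plus
    # a binary cross-entropy term on the tuple-wise similarity logit.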
    def loss_calculation(self, targets, preds):
        loss = 0
        for i in range(3):
            target = targets[i]
            pred = preds[i]
            loss += self.sparse_autoencoder_error(target, pred)
        loss = loss.mean()
        # BCE-with-logits expects float targets, so cast the labels explicitly.
        label_target = targets[-1].reshape(-1, 1).float()
        label_pred = preds[-1]
        loss += F.binary_cross_entropy_with_logits(label_pred, label_target)
        return loss
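
    # Validation pass: accumulates per-autoencoder MSE, label accuracy at a
    # 0.5 threshold, and finally a micro-averaged ROC-AUC over the whole set.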
def _test_step(self):
self.model.eval()
metric = {"encode_0": 0, "encode_1": 0, "encode_2": 0, "label": 0}
        with torch.no_grad():
loader_tqdm = tqdm(self.val_dataloader, ncols=120)
num = 0
y_true = []
y_score = []
for blocks, target in loader_tqdm:
blocks = [b.to(self.device) for b in blocks]
targets = target.to(self.device)
logits = self.model(blocks)
for i in range(3):
pred = logits[i]
target = blocks[i]
metric["encode_%d" % i] += F.mse_loss(pred, target)
label_logit = torch.sigmoid(logits[-1])
label_pred = (label_logit > 0.5).int()
label_pred = label_pred.reshape(-1)
metric["label"] += (label_pred == targets).float().mean()
y_true.append(targets.cpu().numpy())
y_score.append(label_logit.cpu().numpy())
num += 1
metric = {k: v.cpu().numpy() / num for k, v in metric.items()}
y_true = np.hstack(y_true)
y_score = np.vstack(y_score).reshape(-1)
auc_micro = metrics.roc_auc_score(y_true, y_score, average='micro')
metric['auc_micro'] = auc_micro
return metric
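
    # self._checkpoint is assumed to be populated by BaseFlow when
    # checkpointing is enabled.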
    def save_checkpoint(self):
        if self._checkpoint and hasattr(self.model, "_parameters"):
            torch.save(self.model.state_dict(), self._checkpoint)