import argparse
import copy
import dgl
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from ..models import build_model
from ..models.HeCo import LogReg
from . import BaseFlow, register_flow
from ..tasks import build_task
from ..utils import extract_embed, EarlyStopping
from sklearn.metrics import f1_score, roc_auc_score
@register_flow("HeCo_trainer")
class HeCoTrainer(BaseFlow):
    """Trainer flow registered under the name ``"HeCo_trainer"``.

    The flow optimises the model built from ``args.model`` full-batch on the
    task graph, using the value returned by the model's forward pass as the
    loss and early stopping on that loss.  Afterwards the embeddings returned
    by ``model.get_embeds`` are scored by repeatedly fitting a ``LogReg``
    probe on the train split and reporting Macro-F1, Micro-F1 and AUC on the
    test split.

    Fields read from ``args`` by this class: ``model``, ``device``, ``lr``,
    ``weight_decay``, ``patience``, ``max_epoch``, ``eva_lr`` and ``eva_wd``.
    """

    def __init__(self, args):
        """Build the task, graph, model and optimizer described by ``args``."""
        super(HeCoTrainer, self).__init__(args)
        self.args = args
        self.model_name = args.model
        self.device = args.device

        self.task = build_task(args)
        self.hg = self.task.get_graph().to(self.device)
        self.num_classes = int(self.task.dataset.num_classes)
        # Publish the dataset's category on ``args`` before the model is
        # built from them.
        self.args.category = self.task.dataset.category
        self.category = self.args.category
        # NOTE(review): ``pos`` is presumably the positive-pair matrix used
        # by the model's loss -- confirm against the dataset class.
        self.pos = self.task.dataset.pos.to(self.device)

        # Look the model up by name.  The previous code passed ``self.model``
        # here, which this class has not assigned yet at this point and which
        # therefore only worked if the base class happened to set it.
        self.model = build_model(self.model_name).build_model_from_args(self.args, self.hg)
        print("build_model_finish")
        self.model = self.model.to(self.device)

        self.evaluator = self.task.get_evaluator('f1')
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.patience = args.patience
        self.max_epoch = args.max_epoch

        self.train_idx, self.val_idx, self.test_idx = self.task.get_split()
        self.labels = self.task.get_labels().to(self.device)

    def preprocess(self):
        """Run the base-class preprocessing; nothing is added here."""
        super(HeCoTrainer, self).preprocess()

    def train(self):
        """Train with early stopping on the loss, then evaluate embeddings.

        Runs at most ``self.max_epoch`` full-batch steps, stops early when
        ``EarlyStopping.loss_step`` says so, and finally passes the embeddings
        of the model returned by ``stopper.load_model`` to :meth:`evaluate`.
        """
        self.preprocess()
        stopper = EarlyStopping(self.args.patience)
        for epoch in range(self.max_epoch):
            loss = self._full_train_step()
            early_stop = stopper.loss_step(loss, self.model)
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
            if early_stop:
                print('Early Stop!\tEpoch:' + str(epoch))
                break

        # Evaluation.
        # NOTE(review): ``load_model`` presumably hands back the best
        # checkpoint -- confirm in ``EarlyStopping``.  Input features and
        # embeddings are both taken from that returned model so they cannot
        # come from two different sets of weights.
        model = stopper.load_model(self.model)
        model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            h_dict = model.input_feature()
            embeds = model.get_embeds(self.hg, h_dict=h_dict)
        self.evaluate(embeds)

    def _full_train_step(self):
        """Run one full-batch optimisation step.

        Returns:
            The loss of this step as a detached NumPy value on the CPU.
        """
        self.model.train()
        self.optimizer.zero_grad()
        h_dict = self.model.input_feature()
        # The model's forward pass returns the loss directly.
        loss = self.model(self.hg, h_dict, self.pos)
        loss.backward()
        self.optimizer.step()
        return loss.detach().cpu().numpy()

    def evaluate(self, embeds, num_runs=50, num_iters=200):
        """Score ``embeds`` with repeated ``LogReg`` probes and print a summary.

        For each of ``num_runs`` freshly initialised probes, the probe is
        trained for ``num_iters`` Adam steps on the train split.  After every
        step it is scored on the validation split; the test Macro-F1 and
        Micro-F1 are taken at the step with the best validation Macro-F1 and
        Micro-F1 respectively, and the test AUC at the best-Micro-F1 step.
        Mean and standard deviation over the runs are printed.

        Args:
            embeds: Embedding tensor; ``embeds.shape[1]`` is used as the
                probe's input size.  The split index tensors are indexed as
                2-D (``idx[:, 0]`` selects the label rows).
            num_runs: Number of independent probes (default 50).
            num_iters: Training steps per probe (default 200).

        Raises:
            ValueError: If ``num_runs`` or ``num_iters`` is smaller than 1.
        """
        if num_runs < 1 or num_iters < 1:
            raise ValueError("num_runs and num_iters must both be at least 1")

        hid_units = embeds.shape[1]
        xent = nn.CrossEntropyLoss()

        train_embs = embeds[self.train_idx]
        val_embs = embeds[self.val_idx]
        test_embs = embeds[self.test_idx]
        train_lbls = self.labels[self.train_idx[:, 0]].reshape(-1)
        val_lbls = self.labels[self.val_idx[:, 0]].reshape(-1)
        test_lbls = self.labels[self.test_idx[:, 0]].reshape(-1)
        # sklearn works on host arrays; copy the labels over once instead of
        # once per metric call.
        val_true = val_lbls.detach().cpu().numpy()
        test_true = test_lbls.detach().cpu().numpy()

        micro_f1s = []
        macro_f1s = []
        auc_score_list = []

        for _ in range(num_runs):
            log = LogReg(hid_units, self.num_classes).to(self.device)
            opt = torch.optim.Adam(
                log.parameters(), lr=self.args.eva_lr, weight_decay=self.args.eva_wd
            )

            val_micro_f1s = []
            val_macro_f1s = []
            logits_list = []

            for _ in range(num_iters):
                # train
                log.train()
                opt.zero_grad()
                logits = log(train_embs)
                loss = xent(logits[:, 0], train_lbls)
                loss.backward()
                opt.step()

                # val / test: pure inference, so keep autograd out of it and
                # store graph-free logits only.
                log.eval()
                with torch.no_grad():
                    val_preds = torch.argmax(log(val_embs)[:, 0], dim=1).cpu().numpy()
                    test_logits = log(test_embs)
                val_macro_f1s.append(f1_score(val_true, val_preds, average='macro'))
                val_micro_f1s.append(f1_score(val_true, val_preds, average='micro'))
                logits_list.append(test_logits)

            # Test Macro-F1 at the first step with the best validation Macro-F1.
            best_macro = val_macro_f1s.index(max(val_macro_f1s))
            macro_preds = torch.argmax(logits_list[best_macro][:, 0], dim=1).cpu().numpy()
            macro_f1s.append(f1_score(test_true, macro_preds, average='macro'))

            # Test Micro-F1 at the first step with the best validation Micro-F1.
            best_micro = val_micro_f1s.index(max(val_micro_f1s))
            micro_preds = torch.argmax(logits_list[best_micro][:, 0], dim=1).cpu().numpy()
            micro_f1s.append(f1_score(test_true, micro_preds, average='micro'))

            # auc, taken at the best-Micro-F1 step.
            best_proba = F.softmax(logits_list[best_micro][:, 0], dim=1).cpu().numpy()
            if best_proba.shape[1] == 2:
                # Binary targets: roc_auc_score expects a 1-D score for the
                # class with the greater label.
                auc = roc_auc_score(y_true=test_true, y_score=best_proba[:, 1])
            else:
                auc = roc_auc_score(y_true=test_true, y_score=best_proba, multi_class='ovr')
            auc_score_list.append(auc)

        # The "var" fields are filled with np.std (label kept for log
        # compatibility).  The AUC spread was previously passed to format()
        # without a placeholder and therefore never printed.
        print("\t[Classification] Macro-F1_mean: {:.4f} var: {:.4f} Micro-F1_mean: {:.4f} var: {:.4f} auc {:.4f} var: {:.4f}"
              .format(np.mean(macro_f1s),
                      np.std(macro_f1s),
                      np.mean(micro_f1s),
                      np.std(micro_f1s),
                      np.mean(auc_score_list),
                      np.std(auc_score_list)
                      )
              )