import torch
from sklearn.metrics import f1_score
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from openhgnn.models import build_model
from openhgnn.models.DMGI import LogReg
from openhgnn.trainerflow import register_flow, BaseFlow
from openhgnn.utils import EarlyStopping
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
@register_flow("DMGI_trainer")
class DMGI_trainer(BaseFlow):
    r"""Trainer flow registered under the name ``"DMGI_trainer"``.

    The flow trains the model on the full graph with early stopping on the
    training loss, then evaluates the learned embeddings (``model.H``) by
    repeatedly fitting a ``LogReg`` classifier on the train split and
    reporting Macro-/Micro-F1 on the test split.
    """

    def __init__(self, args):
        """Build the model, the optimizer, the data splits and the labels.

        Parameters
        ----------
        args :
            Run configuration. This constructor reads ``optimizer``, ``lr``,
            ``weight_decay``, ``isSemi`` and ``sup_coef`` from it and writes
            ``category``, ``in_dim`` and ``num_classes`` back onto
            ``self.args`` before the model is built.
        """
        super().__init__(args)

        # Target node type (category) of the dataset.
        self.args.category = self.task.dataset.category
        self.category = self.args.category

        # Prefer the input dimension declared by the dataset; otherwise fall
        # back to the width of the 'h' feature stored for the target type.
        if hasattr(self.task.dataset, 'in_dim'):
            self.args.in_dim = self.task.dataset.in_dim
        else:
            self.args.in_dim = self.hg.ndata['h'][self.category].shape[1]

        # Number of classes of the target category.
        self.num_classes = self.task.dataset.num_classes
        self.args.num_classes = self.task.dataset.num_classes

        self.model = build_model(self.model).build_model_from_args(self.args, self.hg)
        self.model = self.model.to(self.device)

        self.optimizer = self.candidate_optimizer[args.optimizer](
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        self.train_idx, self.val_idx, self.test_idx = self.task.get_split()

        # Labels live on the same device as the model.
        self.labels = self.task.get_labels().to(self.device)

        # Number of nodes of the target category.
        self.num_nodes = self.hg.num_nodes(self.category)

        self.isSemi = args.isSemi
        # Weight of the supervised (semi) loss term, see ``calculate_J``.
        self.sup_coef = args.sup_coef

    def preprocess(self):
        """No-op: this flow does no preprocessing."""
        pass

    def train(self):
        """Train with early stopping on the loss, then evaluate ``model.H``.

        After the loop ends (early stop or ``max_epoch`` reached) the
        checkpoint kept by the stopper is loaded back into the model and
        ``evaluate`` is run on the model's ``H`` tensor.
        """
        stopper = EarlyStopping(self.patience)
        model = self.model
        epoch_iter = tqdm(range(self.max_epoch))
        for epoch in epoch_iter:
            # Early stopping is driven by the training loss.
            loss = self._full_train_setp()
            early_stop = stopper.loss_step(loss, model)
            self.logger.train_info(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
            if early_stop:
                self.logger.train_info(f'Early Stop!\tEpoch:{epoch}')
                break

        # Evaluation
        stopper.load_model(self.model)
        model.eval()
        self.evaluate(model.H.data.detach(),)

    def _full_train_setp(self):
        """Run one full-graph optimisation step and return the loss.

        NOTE(review): the method name is misspelled ("setp"); it is kept
        unchanged so existing callers keep working.

        Returns
        -------
        numpy.ndarray
            The detached loss value, moved to CPU.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Binary targets of shape (1, 2 * num_nodes): ones for the first
        # num_nodes entries, zeros for the remaining num_nodes entries.
        # Built on self.device, i.e. where the model and labels were placed.
        lbl_1 = torch.ones(1, self.num_nodes, device=self.device)
        lbl_2 = torch.zeros(1, self.num_nodes, device=self.device)
        lbl = torch.cat((lbl_1, lbl_2), 1)

        result = self.model(self.hg)
        loss = self.calculate_J(result, lbl)
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().numpy()

    def _test_step(self, split=None, logits=None):
        """No-op: not used by this flow."""
        pass

    def _mini_train_step(self, ):
        """No-op: mini-batch training is not implemented by this flow."""
        pass

    def loss_calculation(self, positive_graph, negative_graph, embedding):
        """No-op: the loss is computed in ``calculate_J`` instead."""
        pass

    def calculate_J(self, result, lbl):
        r"""
        Two formulas to calculate the final objective :math:`\mathcal{J}`

        If isSemi = True, introduce a semi-supervised module into our framework that predicts the labels of labeled nodes from
        the consensus embedding Z. More precisely, we minimize the cross-entropy error over the labeled nodes:

        .. math::
            \begin{equation}
            \mathcal{J}_{\text {semi }}=\sum_{r \in \mathcal{R}} \mathcal{L}^{(r)}+\alpha \ell_{\mathrm{cs}}+\beta\|\Theta\|+\gamma \ell_{\text {sup }}
            \end{equation}

        Where :math:`\gamma` is the coefficient of the semi-supervised module, the way to calculate :math:`\ell_{\text {sup }}` :

        .. math::
            \begin{equation}
            \ell_{\text {sup }}=-\frac{1}{\left|\mathcal{Y}_{L}\right|} \sum_{l \in \mathcal{Y}_{L}} \sum_{i=1}^{c} Y_{l i} \ln \hat{Y}_{l i}
            \end{equation}

        If isSemi = False:

        .. math::
            \begin{equation}
            \mathcal{J}=\sum_{r \in \mathcal{R}} \mathcal{L}^{(r)}+\alpha \ell_{\mathrm{cs}}+\beta\|\Theta\|^{2}
            \end{equation}

        Where :math:`\alpha` controls the importance of the consensus regularization,
        :math:`\mathcal{L}^{(r)}` is cross entropy.

        Parameters
        ----------
        result : dict
            Model output. ``result['logits']`` is an iterable of logit
            tensors, ``result['reg_loss']`` is the regularisation term and,
            when ``isSemi`` is set, ``result['semi']`` holds the class scores
            used for the supervised term.
        lbl : torch.Tensor
            Binary targets compared against every entry of
            ``result['logits']`` (after a leading dimension is added to it).

        Returns
        -------
        torch.Tensor
            The scalar objective to back-propagate.
        """
        b_xent = nn.BCEWithLogitsLoss()

        # Sum of the binary cross-entropy over all logit tensors.  ``sum``
        # starts from 0, so an empty ``logits`` contributes nothing instead of
        # failing on a ``None`` accumulator.
        loss = sum(b_xent(logit.unsqueeze(0), lbl) for logit in result['logits'])

        loss = loss + self.args.reg_coef * result['reg_loss']

        if self.isSemi:
            # Supervised cross-entropy restricted to the training nodes.
            sup = result['semi']
            semi_loss = F.cross_entropy(sup[self.train_idx], self.labels[self.train_idx])
            loss = loss + self.sup_coef * semi_loss
        return loss

    @staticmethod
    def _classification_scores(classifier, embs, lbls):
        """Return ``(accuracy, macro-F1, micro-F1)`` of ``classifier`` on one split."""
        preds = torch.argmax(classifier(embs), dim=1)
        acc = torch.sum(preds == lbls).float() / lbls.shape[0]
        y_true, y_pred = lbls.cpu(), preds.cpu()
        return (acc.item(),
                f1_score(y_true, y_pred, average='macro'),
                f1_score(y_true, y_pred, average='micro'))

    def evaluate(self, embeds):
        r"""Evaluate embeddings with repeated ``LogReg`` classification runs.

        For each of 50 independent runs a fresh ``LogReg`` classifier is
        trained for 50 epochs on the train split.  Within a run, the test
        score is taken at the epoch with the best validation score (selected
        separately for accuracy, Macro-F1 and Micro-F1).  Mean and standard
        deviation over the runs are logged.

        Parameters
        ----------
        embeds : torch.Tensor
            Embeddings indexed as ``embeds[0, node_idx]``; ``embeds.shape[2]``
            is used as the classifier input width.

        Returns
        -------
        dict
            Aggregated metrics over all runs: ``acc``, ``macro_f1``,
            ``macro_f1_std``, ``micro_f1``, ``micro_f1_std`` and
            ``val_macro_f1``.
        """
        hid_units = embeds.shape[2]
        xent = F.cross_entropy

        train_embs = embeds[0, self.train_idx]
        val_embs = embeds[0, self.val_idx]
        test_embs = embeds[0, self.test_idx]

        train_lbls = self.labels[self.train_idx]
        val_lbls = self.labels[self.val_idx]
        test_lbls = self.labels[self.test_idx]

        # Aggregates over runs: created once, outside the run loop, so the
        # reported mean/std really cover every run.
        accs = []
        micro_f1s = []
        macro_f1s = []
        macro_f1s_val = []

        for _ in range(50):
            log = LogReg(hid_units, self.num_classes)
            # Move the classifier before handing its parameters to the optimizer.
            log.to(self.device)
            opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

            # Per-epoch histories: reset for every run so that best-epoch
            # selection never picks an epoch of a previous classifier.
            val_accs, test_accs = [], []
            val_micro_f1s, test_micro_f1s = [], []
            val_macro_f1s, test_macro_f1s = [], []

            for _ in range(50):
                # train
                log.train()
                opt.zero_grad()
                loss = xent(log(train_embs), train_lbls)
                loss.backward()
                opt.step()

                # Scoring needs no gradients.
                with torch.no_grad():
                    val_acc, val_f1_macro, val_f1_micro = self._classification_scores(
                        log, val_embs, val_lbls)
                    test_acc, test_f1_macro, test_f1_micro = self._classification_scores(
                        log, test_embs, test_lbls)

                # val
                val_accs.append(val_acc)
                val_macro_f1s.append(val_f1_macro)
                val_micro_f1s.append(val_f1_micro)

                # test
                test_accs.append(test_acc)
                test_macro_f1s.append(test_f1_macro)
                test_micro_f1s.append(test_f1_micro)

            # Model selection on the validation split, one metric at a time.
            best = val_accs.index(max(val_accs))
            accs.append(test_accs[best])

            best = val_macro_f1s.index(max(val_macro_f1s))
            macro_f1s.append(test_macro_f1s[best])
            macro_f1s_val.append(val_macro_f1s[best])

            best = val_micro_f1s.index(max(val_micro_f1s))
            micro_f1s.append(test_micro_f1s[best])

        self.logger.train_info("\t[Classification] Macro-F1: {:.4f} ({:.4f}) | Micro-F1: {:.4f} ({:.4f})".format(np.mean(macro_f1s),
                                                                                                              np.std(macro_f1s),
                                                                                                              np.mean(micro_f1s),
                                                                                                              np.std(micro_f1s)))
        return {
            'acc': float(np.mean(accs)),
            'macro_f1': float(np.mean(macro_f1s)),
            'macro_f1_std': float(np.std(macro_f1s)),
            'micro_f1': float(np.mean(micro_f1s)),
            'micro_f1_std': float(np.std(micro_f1s)),
            'val_macro_f1': float(np.mean(macro_f1s_val)),
        }