import os
import gc
import time
import uuid
import argparse
import datetime
import numpy as np
import torch
import torch.nn.functional as F
import os
import sys
import gc
import random
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from sklearn.metrics import f1_score
from tqdm import tqdm
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from openhgnn.models import build_model
from . import BaseFlow, register_flow
from ..tasks import build_task
import functools
from contextlib import closing
import multiprocessing as mp
from multiprocessing import Pool
from tqdm import tqdm
from ..tasks import NodeClassification
@register_flow("SeHGNN_trainer")
class SeHGNNtrainer(BaseFlow):
    """Multi-stage node-classification trainer for SeHGNN on target node type ``'P'``.

    Node features (and optionally labels) are pre-propagated along metapaths with
    :meth:`hg_propagate`; training then runs one stage per entry of
    ``args.stages`` (each entry is that stage's epoch budget). From the second
    stage on, val/test nodes predicted confidently by the previous stage are
    added as weighted pseudo-labelled training nodes (see
    :meth:`train_multi_stage`).
    """

    def __init__(self, args):
        super(SeHGNNtrainer, self).__init__(args)
        # ``args.stages`` is normally a comma-separated string of per-stage epoch
        # counts; also accept an already-parsed sequence so args can be reused.
        if isinstance(args.stages, str):
            args.stages = [int(item.strip()) for item in args.stages.split(',')]
        else:
            args.stages = [int(item) for item in args.stages]
        self.args = args
        self.flow = NodeClassification(args)

    def train(self):
        """Run every training stage, checkpointing the best model of each.

        For each stage ``s`` in ``range(args.start_stage, len(args.stages))`` the
        best-validation weights are saved to ``<checkpt_file>_<s>.pkl`` and that
        model's predictions for all ``'P'`` nodes (rows in [train | val | test]
        order) to ``<checkpt_file>_<s>.pt``.

        Raises:
            FileNotFoundError: ``args.reload`` names a prediction file that does
                not exist.
            RuntimeError: a stage > 0 is reached without predictions from the
                previous stage (e.g. ``start_stage > 0`` and no ``args.reload``).
        """
        args = self.args
        if args.seed > 0:
            self.set_random_seed(args.seed)
        num_nodes = self.flow.dataset.SeHGNN_g.num_nodes("P")
        n_classes = int(self.flow.labels.max()) + 1
        evaluator = self.flow.get_evaluator("acc")

        # =======
        # Rearrange node idx (for feats & labels) so rows are ordered
        # [train | val | test]; each split is then a contiguous slice.
        # =======
        train_node_nums = len(self.flow.train_idx)
        valid_node_nums = len(self.flow.val_idx)
        test_node_nums = len(self.flow.test_idx)
        trainval_point = train_node_nums
        valtest_point = trainval_point + valid_node_nums
        total_num_nodes = train_node_nums + valid_node_nums + test_node_nums
        init2sort = torch.cat([self.flow.train_idx, self.flow.val_idx, self.flow.test_idx])
        sort2init = torch.argsort(init2sort)
        assert torch.all(self.flow.labels[init2sort][sort2init] == self.flow.labels)
        labels = self.flow.labels[init2sort]

        # =======
        # Features propagate alongside the metapath
        # =======
        tgt_type = 'P'
        max_hops = args.num_hops + 1
        # compute k-hop feature
        self.flow.dataset.SeHGNN_g = self.hg_propagate(
            self.flow.dataset.SeHGNN_g, tgt_type, args.num_hops, max_hops, echo=False)
        feats = {}
        keys = list(self.flow.dataset.SeHGNN_g.nodes[tgt_type].data.keys())
        print(f'Involved feat keys {keys}')
        for k in keys:
            feats[k] = self.flow.dataset.SeHGNN_g.nodes[tgt_type].data.pop(k)
        self.flow.dataset.SeHGNN_g = self.clear_hg(self.flow.dataset.SeHGNN_g, echo=False)
        feats = {k: v[init2sort] for k, v in feats.items()}
        gc.collect()

        # Unshuffled loader over every 'P' row, used for the end-of-stage predictions.
        all_loader = torch.utils.data.DataLoader(
            torch.arange(num_nodes), batch_size=args.batch_size, shuffle=False, drop_last=False)

        checkpt_folder = f'./openhgnn/output/SeHGNN/{args.dataset}/'
        os.makedirs(checkpt_folder, exist_ok=True)

        # ``scalar`` is the AMP gradient scaler; ``None`` means full-precision training.
        scalar = torch.cuda.amp.GradScaler() if args.amp else None

        device = "cuda:{}".format(args.gpu) if not args.cpu else 'cpu'
        labels_cuda = labels.long().to(device)

        checkpt_file = checkpt_folder + uuid.uuid4().hex
        print(checkpt_file)

        # Predictions of the previous stage; set either by reloading from disk
        # or at the end of each stage below.
        raw_preds = None
        enhance_loader = None
        predict_prob = None

        for stage in range(args.start_stage, len(args.stages)):
            epochs = args.stages[stage]

            # Optionally reload the previous stage's predictions, saved as
            # ``<args.reload>_<stage-1>.pt`` in this dataset's checkpoint folder.
            if args.reload and stage > 0:
                pt_path = f'{checkpt_folder}{args.reload}_{stage-1}.pt'
                if not os.path.exists(pt_path):
                    raise FileNotFoundError(f'Cannot reload raw_preds: {pt_path} does not exist')
                print(f'Reload raw_preds from {pt_path}', flush=True)
                raw_preds = torch.load(pt_path, map_location='cpu')

            # =======
            # Expand training set & train loader
            # =======
            if stage > 0:
                if raw_preds is None:
                    raise RuntimeError(
                        f'Stage {stage} needs the predictions of stage {stage - 1}; '
                        'set args.reload when starting from a later stage.')
                preds = raw_preds.argmax(dim=-1)
                predict_prob = raw_preds.softmax(dim=1)

                train_acc = evaluator(preds[:trainval_point], labels[:trainval_point])
                val_acc = evaluator(preds[trainval_point:valtest_point], labels[trainval_point:valtest_point])
                test_acc = evaluator(preds[valtest_point:total_num_nodes], labels[valtest_point:total_num_nodes])
                print(f'Stage {stage-1} history model:\n\t'
                      + f'Train acc {train_acc*100:.4f} Val acc {val_acc*100:.4f} Test acc {test_acc*100:.4f}')

                # Val/test nodes whose top predicted probability exceeds the
                # threshold become pseudo-labelled training nodes for this stage.
                confident_mask = predict_prob.max(1)[0] > args.threshold
                val_enhance_offset = torch.where(confident_mask[trainval_point:valtest_point])[0]
                test_enhance_offset = torch.where(confident_mask[valtest_point:total_num_nodes])[0]
                val_enhance_nid = val_enhance_offset + trainval_point
                test_enhance_nid = test_enhance_offset + valtest_point
                enhance_nid = torch.cat((val_enhance_nid, test_enhance_nid))

                print(f'Stage: {stage}, threshold {args.threshold}, confident nodes: {len(enhance_nid)} / {total_num_nodes - trainval_point}')
                val_confident_level = (predict_prob[val_enhance_nid].argmax(1) == labels[val_enhance_nid]).sum() / len(val_enhance_nid)
                print(f'\t\t val confident nodes: {len(val_enhance_nid)} / {valid_node_nums}, val confident level: {val_confident_level}')
                test_confident_level = (predict_prob[test_enhance_nid].argmax(1) == labels[test_enhance_nid]).sum() / len(test_enhance_nid)
                print(f'\t\ttest confident nodes: {len(test_enhance_nid)} / {test_node_nums}, test confident_level: {test_confident_level}')

                # Split args.batch_size between real and pseudo-labelled nodes in
                # proportion to their counts, so both loaders have about the same
                # number of batches (they are consumed in lock-step via zip).
                train_batch_size = int(args.batch_size * train_node_nums / (len(enhance_nid) + train_node_nums))
                train_loader = torch.utils.data.DataLoader(
                    torch.arange(train_node_nums), batch_size=train_batch_size, shuffle=True, drop_last=False)
                enhance_batch_size = int(args.batch_size * len(enhance_nid) / (len(enhance_nid) + train_node_nums))
                enhance_loader = torch.utils.data.DataLoader(
                    enhance_nid, batch_size=enhance_batch_size, shuffle=True, drop_last=False)
            else:
                train_loader = torch.utils.data.DataLoader(
                    torch.arange(train_node_nums), batch_size=args.batch_size, shuffle=True, drop_last=False)

            # =======
            # Labels propagate alongside the metapath
            # =======
            label_feats = {}
            if args.label_feats:
                # Seed labels: previous-stage soft predictions (stage > 0) or zeros,
                # with ground-truth one-hot labels written over the train nodes.
                if stage > 0:
                    label_onehot = predict_prob[sort2init].clone()
                else:
                    label_onehot = torch.zeros((num_nodes, n_classes))
                label_onehot[self.flow.train_idx] = F.one_hot(self.flow.labels[self.flow.train_idx], n_classes).float()

                self.flow.dataset.SeHGNN_g.nodes['P'].data['P'] = label_onehot
                print(f'Current num label hops = {args.num_label_hops}')
                max_hops = args.num_label_hops + 1
                self.flow.dataset.SeHGNN_g = self.hg_propagate(
                    self.flow.dataset.SeHGNN_g, tgt_type, args.num_label_hops, max_hops, echo=False)
                keys = list(self.flow.dataset.SeHGNN_g.nodes[tgt_type].data.keys())
                print(f'Involved label keys {keys}')
                for k in keys:
                    # The seed labels themselves (key == tgt_type) are not a label feature.
                    if k == tgt_type:
                        continue
                    label_feats[k] = self.flow.dataset.SeHGNN_g.nodes[tgt_type].data.pop(k)
                self.flow.dataset.SeHGNN_g = self.clear_hg(self.flow.dataset.SeHGNN_g, echo=False)

                # For metapaths that start and end at 'P', subtract each node's own
                # seed label scaled by a per-node weight loaded from
                # ``{dataset}_{k}_diag.pt`` (presumably the diagonal of that
                # metapath's normalized adjacency — confirm), so a node's label
                # feature does not include its own label.
                for k in ['PPP', 'PAP', 'PFP', 'PPPP', 'PAPP', 'PPAP', 'PFPP', 'PPFP']:
                    if k in label_feats:
                        diag = torch.load(f'{args.dataset}_{k}_diag.pt')
                        label_feats[k] = label_feats[k] - diag.unsqueeze(-1) * label_onehot
                        assert torch.all(label_feats[k] > -1e-6)
                        print(k, torch.sum(label_feats[k] < 0), label_feats[k].min())

                label_emb = (label_feats['PPP'] + label_feats['PAP'] + label_feats['PP'] + label_feats['PFP']) / 4
            else:
                label_emb = torch.zeros((num_nodes, n_classes))

            label_feats = {k: v[init2sort] for k, v in label_feats.items()}
            label_emb = label_emb[init2sort]

            # Stage 0 trains without per-metapath label features (label_emb is kept).
            if stage == 0:
                label_feats = {}

            # =======
            # Eval loader: precomputed, unshuffled batches of rows
            # [trainval_point, num_nodes) — i.e. val rows first, then test rows.
            # Rebinding drops the previous stage's batches.
            # =======
            eval_loader = []
            for batch_idx in range((num_nodes - trainval_point - 1) // args.batch_size + 1):
                batch_start = batch_idx * args.batch_size + trainval_point
                batch_end = min(num_nodes, (batch_idx + 1) * args.batch_size + trainval_point)

                batch_feats = {k: v[batch_start:batch_end] for k, v in feats.items()}
                batch_label_feats = {k: v[batch_start:batch_end] for k, v in label_feats.items()}
                batch_labels_emb = label_emb[batch_start:batch_end]
                eval_loader.append((batch_feats, batch_label_feats, batch_labels_emb))

            data_size = {k: v.size(-1) for k, v in feats.items()}

            # =======
            # Construct network
            # =======
            args.data_size = data_size
            args.nclass = n_classes
            args.nfeat = args.embed_size
            args.num_feats = len(feats)
            args.num_label_feats = len(label_feats)
            args.tgt_key = tgt_type
            # NOTE(review): the model goes to ``self.args.device`` while batches go
            # to the local ``device``; presumably these agree — confirm.
            model = build_model(self.args.model).build_model_from_args(self.args).to(self.args.device)
            if stage == args.start_stage:
                print(model)
                print("# Params:", self.get_n_params(model))

            loss_fcn = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                         weight_decay=args.weight_decay)

            best_epoch = 0
            best_val_acc = 0
            best_test_acc = 0
            count = 0

            for epoch in range(epochs):
                gc.collect()
                torch.cuda.empty_cache()
                start = time.time()
                if stage == 0:
                    loss, acc = self.run(model, train_loader, loss_fcn, optimizer, evaluator, device,
                                         feats, label_feats, labels_cuda, label_emb, scalar=scalar)
                else:
                    loss, acc = self.train_multi_stage(model, train_loader, enhance_loader, loss_fcn,
                                                       optimizer, evaluator, device, feats, label_feats,
                                                       labels_cuda, label_emb, predict_prob, args.gama,
                                                       scalar=scalar)
                end = time.time()

                log = "Epoch {}, Time(s): {:.4f}, estimated train loss {:.4f}, acc {:.4f}\n".format(epoch, end-start, loss, acc*100)
                torch.cuda.empty_cache()

                if epoch % args.eval_every == 0:
                    with torch.no_grad():
                        model.eval()
                        raw_preds = []

                        start = time.time()
                        for batch_feats, batch_label_feats, batch_labels_emb in eval_loader:
                            batch_feats = {k: v.to(device) for k, v in batch_feats.items()}
                            batch_label_feats = {k: v.to(device) for k, v in batch_label_feats.items()}
                            batch_labels_emb = batch_labels_emb.to(device)
                            fk = {'0': batch_feats, '1': batch_label_feats, '2': batch_labels_emb}
                            raw_preds.append(model(fk).cpu())
                        raw_preds = torch.cat(raw_preds, dim=0)

                        # raw_preds rows start at trainval_point: val first, then test.
                        loss_val = loss_fcn(raw_preds[:valid_node_nums], labels[trainval_point:valtest_point]).item()
                        loss_test = loss_fcn(raw_preds[valid_node_nums:valid_node_nums+test_node_nums], labels[valtest_point:total_num_nodes]).item()

                        preds = raw_preds.argmax(dim=-1)
                        val_acc = evaluator(preds[:valid_node_nums], labels[trainval_point:valtest_point])
                        test_acc = evaluator(preds[valid_node_nums:valid_node_nums+test_node_nums], labels[valtest_point:total_num_nodes])

                        end = time.time()
                        log += f'Time: {end-start}, Val loss: {loss_val}, Test loss: {loss_test}\n'
                        log += 'Val acc: {:.4f}, Test acc: {:.4f}\n'.format(val_acc*100, test_acc*100)

                    # Keep the best-validation weights on disk; stop after
                    # ``args.patience`` epochs without improvement.
                    if val_acc > best_val_acc:
                        best_epoch = epoch
                        best_val_acc = val_acc
                        best_test_acc = test_acc
                        torch.save(model.state_dict(), f'{checkpt_file}_{stage}.pkl')
                        count = 0
                    else:
                        count = count + args.eval_every
                        if count >= args.patience:
                            break
                    log += "Best Epoch {},Val {:.4f}, Test {:.4f}".format(best_epoch, best_val_acc*100, best_test_acc*100)
                print(log, flush=True)

            print("Best Epoch {}, Val {:.4f}, Test {:.4f}".format(best_epoch, best_val_acc*100, best_test_acc*100))

            # Predict every 'P' node with this stage's best weights; the result
            # seeds the pseudo-labels of the next stage.
            model.load_state_dict(torch.load(checkpt_file+f'_{stage}.pkl'))
            raw_preds = self.gen_output_torch(model, feats, label_feats, label_emb, all_loader, device)
            torch.save(raw_preds, checkpt_file+f'_{stage}.pt')

    def set_random_seed(self, seed=0):
        """Seed Python, NumPy and PyTorch (CPU and, if available, CUDA) RNGs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_n_params(self, model):
        """Return the total number of scalar parameters in ``model``."""
        return sum(p.numel() for p in model.parameters())

    def hg_propagate(self, new_g, tgt_type, num_hops, max_hops, echo=False):
        """Propagate node features along metapaths, one hop per iteration.

        Feature keys are metapath strings whose length is the hop at which they
        were produced (which presumes single-character node type names). At hop
        ``h``, every key of length ``h`` on the source type of each edge type is
        mean-aggregated into the destination type under ``f'{dtype}{key}'``.
        At ``h == num_hops`` only edges into ``tgt_type`` are used, and nothing
        is propagated for ``h > num_hops``. After each hop, keys of length
        ``<= h`` are removed from every non-target node type.

        Modifies ``new_g`` in place and returns it.
        """
        for hop in range(1, max_hops):
            for etype in new_g.etypes:
                stype, _, dtype = new_g.to_canonical_etype(etype)
                for k in list(new_g.nodes[stype].data.keys()):
                    if len(k) == hop:
                        current_dst_name = f'{dtype}{k}'
                        if (hop == num_hops and dtype != tgt_type) or (hop > num_hops):
                            continue
                        if echo:
                            print(k, etype, current_dst_name)
                        new_g[etype].update_all(
                            fn.copy_u(k, 'm'),
                            fn.mean('m', current_dst_name), etype=etype)

            # Remove keys that can no longer be propagated (only keys of length
            # hop + 1 are used at the next hop).
            for ntype in new_g.ntypes:
                if ntype == tgt_type:
                    continue
                removes = [k for k in new_g.nodes[ntype].data.keys() if len(k) <= hop]
                for k in removes:
                    new_g.nodes[ntype].data.pop(k)
                if echo and len(removes):
                    print('remove', removes)
            gc.collect()

            if echo:
                print(f'-- hop={hop} ---')
                for ntype in new_g.ntypes:
                    for k, v in new_g.nodes[ntype].data.items():
                        print(f'{ntype} {k} {v.shape}')
                print(f'------\n')
        return new_g

    def clear_hg(self, new_g, echo=False):
        """Pop every node-data entry from every node type of ``new_g`` and return it."""
        if echo:
            print('Remove keys left after propagation')
        for ntype in new_g.ntypes:
            keys = list(new_g.nodes[ntype].data.keys())
            if len(keys):
                if echo:
                    print(ntype, keys)
                for k in keys:
                    new_g.nodes[ntype].data.pop(k)
        return new_g

    def run(self, model, train_loader, loss_fcn, optimizer, evaluator, device,
            feats, label_feats, labels_cuda, label_emb, mask=None, scalar=None):
        """Train ``model`` for one epoch over ``train_loader`` (stage 0).

        Args:
            scalar: AMP gradient scaler; ``None`` trains in full precision.
            mask: accepted for interface compatibility; currently unused.

        Returns:
            ``(mean batch loss, evaluator(y_true, y_pred))``.
        """
        model.train()
        total_loss = 0
        iter_num = 0
        y_true, y_pred = [], []

        for batch in train_loader:
            batch_feats = {k: x[batch].to(device) for k, x in feats.items()}
            batch_labels_feats = {k: x[batch].to(device) for k, x in label_feats.items()}
            batch_label_emb = label_emb[batch].to(device)
            batch_y = labels_cuda[batch]
            fk = {'0': batch_feats, '1': batch_labels_feats, '2': batch_label_emb}

            optimizer.zero_grad()
            if scalar is not None:
                with torch.cuda.amp.autocast():
                    output_att = model(fk)
                    if isinstance(loss_fcn, nn.BCELoss):
                        output_att = torch.sigmoid(output_att)
                    loss_train = loss_fcn(output_att, batch_y)
                # Backward/step run outside autocast, as AMP recommends.
                scalar.scale(loss_train).backward()
                scalar.step(optimizer)
                scalar.update()
            else:
                output_att = model(fk)
                if isinstance(loss_fcn, nn.BCELoss):
                    output_att = torch.sigmoid(output_att)
                loss_train = loss_fcn(output_att, batch_y)
                loss_train.backward()
                optimizer.step()

            y_true.append(batch_y.cpu().to(torch.long))
            if isinstance(loss_fcn, nn.BCELoss):
                # output_att is already a sigmoid probability here, so the
                # decision threshold is 0.5 (0 would mark every entry positive).
                y_pred.append((output_att.data.cpu() > 0.5).int())
            else:
                y_pred.append(output_att.argmax(dim=-1, keepdim=True).cpu())
            total_loss += loss_train.item()
            iter_num += 1

        loss = total_loss / iter_num
        acc = evaluator(torch.cat(y_true, dim=0), torch.cat(y_pred, dim=0))
        return loss, acc

    def train_multi_stage(self, model, train_loader, enhance_loader, loss_fcn, optimizer, evaluator, device,
                          feats, label_feats, labels, label_emb, predict_prob, gama, scalar=None):
        """Train one epoch on real training nodes plus pseudo-labelled nodes.

        Each step concatenates a batch of training rows (``train_loader``) with a
        batch of confident val/test rows (``enhance_loader``); iteration stops
        when the shorter loader is exhausted. The loss is
        ``L1_ratio * L1 + gama * L2_ratio * L2``, where ``L1`` is cross-entropy
        against ``labels`` and ``L2`` is cross-entropy against the argmax of
        ``predict_prob``, weighted per node by its max probability.

        Returns:
            ``(mean batch loss, evaluator(y_true, y_pred))`` with accuracy
            measured on the real training rows only.
        """
        model.train()
        # NOTE: the ``loss_fcn`` argument is overridden; L1 is always plain CE.
        loss_fcn = nn.CrossEntropyLoss()
        y_true, y_pred = [], []
        total_loss = 0
        loss_l1, loss_l2 = 0., 0.
        iter_num = 0

        for idx_1, idx_2 in zip(train_loader, enhance_loader):
            idx = torch.cat((idx_1, idx_2), dim=0)
            L1_ratio = len(idx_1) * 1.0 / (len(idx_1) + len(idx_2))
            L2_ratio = len(idx_2) * 1.0 / (len(idx_1) + len(idx_2))

            batch_feats = {k: x[idx].to(device) for k, x in feats.items()}
            batch_labels_feats = {k: x[idx].to(device) for k, x in label_feats.items()}
            batch_label_emb = label_emb[idx].to(device)
            fk = {'0': batch_feats, '1': batch_labels_feats, '2': batch_label_emb}

            y = labels[idx_1].to(torch.long).to(device)
            extra_weight, extra_y = predict_prob[idx_2].max(dim=1)
            extra_weight = extra_weight.to(device)
            extra_y = extra_y.to(device)

            optimizer.zero_grad()
            if scalar is not None:
                with torch.cuda.amp.autocast():
                    output_att = model(fk)
                    L1 = loss_fcn(output_att[:len(idx_1)], y)
                    L2 = F.cross_entropy(output_att[len(idx_1):], extra_y, reduction='none')
                    L2 = (L2 * extra_weight).sum() / len(idx_2)
                    loss_train = L1_ratio * L1 + gama * L2_ratio * L2
                scalar.scale(loss_train).backward()
                scalar.step(optimizer)
                scalar.update()
            else:
                output_att = model(fk)
                L1 = loss_fcn(output_att[:len(idx_1)], y)
                L2 = F.cross_entropy(output_att[len(idx_1):], extra_y, reduction='none')
                L2 = (L2 * extra_weight).sum() / len(idx_2)
                loss_train = L1_ratio * L1 + gama * L2_ratio * L2
                loss_train.backward()
                optimizer.step()

            # Keep both accuracy inputs on CPU (``labels`` lives on ``device``).
            y_true.append(y.cpu())
            y_pred.append(output_att[:len(idx_1)].argmax(dim=-1, keepdim=True).cpu())
            total_loss += loss_train.item()
            loss_l1 += L1.item()
            loss_l2 += L2.item()
            iter_num += 1

        print(loss_l1 / iter_num, loss_l2 / iter_num)
        loss = total_loss / iter_num
        approx_acc = evaluator(torch.cat(y_true, dim=0), torch.cat(y_pred, dim=0))
        return loss, approx_acc

    @torch.no_grad()
    def gen_output_torch(self, model, feats, label_feats, label_emb, test_loader, device):
        """Return ``model`` outputs (on CPU) for every row index yielded by ``test_loader``."""
        model.eval()
        preds = []
        for batch in tqdm(test_loader):
            batch_feats = {k: x[batch].to(device) for k, x in feats.items()}
            batch_labels_feats = {k: x[batch].to(device) for k, x in label_feats.items()}
            batch_label_emb = label_emb[batch].to(device)
            # Same input convention as training/evaluation: one dict keyed '0'/'1'/'2'.
            fk = {'0': batch_feats, '1': batch_labels_feats, '2': batch_label_emb}
            preds.append(model(fk).cpu())
        return torch.cat(preds, dim=0)

    def get_ogb_evaluator(self, dataset):
        """Return ``f(preds, labels) -> accuracy`` backed by OGB's ``Evaluator``."""
        evaluator = Evaluator(name=dataset)
        return lambda preds, labels: evaluator.eval({
            "y_true": labels.view(-1, 1),
            "y_pred": preds.view(-1, 1),
        })["acc"]