# Source code for openhgnn.trainerflow.hgt_trainer

import argparse
import copy
import dgl
import numpy as np
import torch
from tqdm import tqdm
import torch.nn.functional as F
from ..models import build_model
from ..sampler import HGTsampler
from . import BaseFlow, register_flow
from ..tasks import build_task
from ..utils import extract_embed, EarlyStopping


@register_flow("hgttrainer")
class HGTTrainer(BaseFlow):
    """HGT trainer flow.

    Supported Model: HGT

    Supported Dataset: ogbn-mag

    The flow builds the task, graph and model from ``args`` and optimises the
    model with AdamW.  In mini-batch mode, sub-graphs are drawn with
    ``HGTsampler`` (see ``preprocess``).
    """

    def __init__(self, args):
        """Build the task, graph, model, optimizer and data splits.

        Parameters
        ----------
        args
            Configuration object.  This constructor reads ``model``,
            ``device``, ``out_dim`` (optional), ``lr``, ``weight_decay``,
            ``patience`` and ``max_epoch`` from it; ``args.out_dim`` is
            overwritten with the dataset's number of classes when it differs
            or is missing.
        """
        super(HGTTrainer, self).__init__(args)
        self.args = args
        self.model_name = args.model
        self.device = args.device
        self.task = build_task(args)
        self.hg = self.task.get_graph().to(self.device)
        self.num_classes = self.task.dataset.num_classes

        # The model's output dimension is forced to the number of classes, so
        # no extra MLP head is needed after the GNN.  ``getattr`` keeps this
        # working when ``args`` carries no ``out_dim`` at all.
        if getattr(args, 'out_dim', None) != self.num_classes:
            print('Modify the out_dim with num_classes')
            args.out_dim = self.num_classes

        # Look the model up by name: ``self.model`` is not assigned by this
        # class before this point, whereas ``self.model_name`` is.
        self.model = build_model(self.model_name).build_model_from_args(self.args, self.hg)
        self.model = self.model.to(self.device)

        self.evaluator = self.task.get_evaluator('acc')
        self.loss_fn = self.task.get_loss_fn()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
        self.patience = args.patience
        self.max_epoch = args.max_epoch

        self.category = self.task.dataset.category
        self.train_idx, self.val_idx, self.test_idx = self.task.get_split()
        self.labels = self.task.get_labels().to(self.device)

    def preprocess(self):
        """Create ``self.dataloader`` when mini-batch mode is enabled.

        The loader iterates over ``self.train_idx`` and collates each batch of
        seed nodes into a sub-graph via ``HGTsampler.sampler_subgraph``.
        Nothing is created when ``args.mini_batch_flag`` is falsy.
        """
        if self.args.mini_batch_flag:
            N_SAMPLE_NODES_PER_TYPE = 1280  # number of nodes to sample per node type per sampler step
            N_SAMPLE_STEPS = 6  # number of sampler steps
            # The sampler is given a CPU copy of the graph.
            sampler = HGTsampler(self.hg.to('cpu'), self.category, N_SAMPLE_NODES_PER_TYPE, N_SAMPLE_STEPS)
            self.dataloader = torch.utils.data.DataLoader(
                self.train_idx,
                batch_size=self.args.batch_size,
                collate_fn=sampler.sampler_subgraph,
                # num_workers=self.args.num_workers,
                shuffle=True,
                drop_last=False,
            )
        return

    def train(self):
        """Run the epoch loop and report test / validation metrics.

        NOTE(review): the optimisation steps are currently disabled (kept
        below as a comment); each epoch only calls ``evaluate()``, which
        restores a fixed checkpoint from disk.  Because the early stopper is
        never stepped, its best score / best model may be unset, so both are
        only used when they are not ``None``.

        Returns
        -------
        dict
            ``{'Acc': <test metric>, 'ValAcc': <validation metric>}`` as
            produced by ``_test_step``.
        """
        self.preprocess()
        stopper = EarlyStopping(self.args.patience)
        epoch_iter = tqdm(range(self.max_epoch))
        for epoch in epoch_iter:
            self.evaluate()
            # Disabled training loop, kept for reference:
            # if self.args.mini_batch_flag:
            #     train_loss = self._mini_train_step()
            # else:
            #     train_loss = self._full_train_setp()
            # print(train_loss)
            # torch.save(self.model.state_dict(), './openhgnn/output/HGT/epoch' + str(epoch) + 'HGT.pt')
            # if (epoch + 1) % self.evaluate_interval == 0:
            #     f1, losses = self._test_step()
            #     train_f1 = f1["train"]
            #     val_f1 = f1["val"]
            #     val_loss = losses["val"]
            #     epoch_iter.set_description(
            #         f"Epoch: {epoch:03d}, Train_macro_f1: {train_f1[0]:.4f},Val_macro_f1: {val_f1[0]:.4f}, train_loss:{train_loss: .4f}"
            #     )
            #     early_stop = stopper.step(val_loss, val_f1[0], self.model)
            #     if early_stop:
            #         print('Early Stop!\tEpoch:' + str(epoch))
            #         break

        # The stopper is never stepped above, so guard against unset values
        # instead of crashing on formatting ``None`` / replacing the model
        # with ``None``.
        if stopper.best_score is not None:
            print(f"Valid accurracy = {stopper.best_score: .4f}")
        if stopper.best_model is not None:
            self.model = stopper.best_model

        test_f1, _ = self._test_step(split="test")
        val_f1, _ = self._test_step(split="val")
        # NOTE(review): indexing with [0] assumes the evaluator returns an
        # indexable result -- confirm against the task's 'acc' evaluator.
        print(f"Test accuracy = {test_f1[0]:.4f}")
        return dict(Acc=test_f1, ValAcc=val_f1)

    def _full_train_setp(self):
        """Run one full-graph optimisation step and return the loss value."""
        self.model.train()
        h = self.hg.ndata['h']
        logits = self.model(self.hg, h)[self.category]
        loss = self.loss_fn(logits[self.train_idx], self.labels[self.train_idx])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _mini_train_step(self,):
        """Run one pass over ``self.dataloader`` and return the summed loss.

        Requires ``preprocess`` to have created ``self.dataloader``
        (mini-batch mode).
        """
        self.model.train()
        loss_all = 0
        for sg, _ in tqdm(self.dataloader):
            sg = sg.to(self.device)
            h = sg.ndata.pop('h')
            logits = self.model(sg, h)[self.category]
            # Map the sub-graph's target-type nodes back to their ids in the
            # full graph to fetch their labels.
            labels = self.labels[sg.ndata[dgl.NID][self.category]].squeeze()
            loss = self.loss_fn(logits, labels)
            loss_all += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return loss_all

    def _split_masks(self):
        """Return the mapping from split name to its node indices."""
        return {'train': self.train_idx, 'val': self.val_idx, 'test': self.test_idx}

    def evaluate(self, split=None):
        """Evaluate a saved checkpoint using sampled sub-graphs.

        The model weights are replaced by the checkpoint at
        ``./openhgnn/output/HGT/epoch7HGT.pt``.  Logits from every sampled
        sub-graph in ``self.dataloader`` are accumulated per original node id
        and averaged over the number of times each node was sampled; nodes
        never sampled get an all-zero average.

        Requires ``preprocess`` to have created ``self.dataloader``
        (mini-batch mode).

        Parameters
        ----------
        split : str, optional
            ``"train"``, ``"val"`` or ``"test"`` to evaluate one split.  Any
            other value (including ``None``) evaluates all three splits,
            prints the metrics and writes them to
            ``./openhgnn/output/HGT/train.log`` (the file is overwritten).

        Returns
        -------
        tuple
            ``(metric, loss)`` for a single split, or ``(metrics, losses)``
            dictionaries keyed by split name otherwise.
        """
        ck_pt = torch.load('./openhgnn/output/HGT/epoch7HGT.pt')
        self.model.load_state_dict(ck_pt)
        self.model.eval()
        with torch.no_grad():
            num_nodes = self.hg.num_nodes(self.category)
            sum_preds = torch.zeros(num_nodes, self.num_classes).to(self.device)
            counts = torch.zeros(num_nodes).to(self.device)
            for sg, _ in tqdm(self.dataloader):
                sg = sg.to(self.device)
                h = sg.ndata.pop('h')
                logits = self.model(sg, h)[self.category]
                nid = sg.ndata[dgl.NID][self.category]
                ones = torch.ones(nid.shape[0]).to(self.device)
                # Accumulate logits and hit counts at each node's original id.
                sum_preds.scatter_add_(0, nid[:, None].expand_as(logits), logits)
                counts.scatter_add_(0, nid, ones)

            avg_preds = sum_preds / counts[:, None]
            # 0 / 0 for nodes that were never sampled.
            avg_preds[torch.isnan(avg_preds)] = 0
            final_preds = sum_preds.argmax(1)
            del sum_preds
            del counts

            masks = self._split_masks()
            mask = masks.get(split)
            if mask is not None:
                loss = self.loss_fn(avg_preds[mask], self.labels[mask])
                metric = self.evaluator(self.labels[mask].to('cpu'), final_preds[mask].to('cpu'))
                return metric, loss

            metrics = {key: self.evaluator(self.labels[idx].to('cpu'), final_preds[idx].to('cpu'))
                       for key, idx in masks.items()}
            losses = {key: self.loss_fn(avg_preds[idx], self.labels[idx].squeeze())
                      for key, idx in masks.items()}
            print('Train:', metrics['train'], 'Validation:', metrics['val'], 'Test:', metrics['test'])
            # Open the log only when there is something to write, and make
            # sure the handle is closed afterwards.
            with open('./openhgnn/output/HGT/train.log', 'w') as log:
                print('Train:', metrics['train'], 'Validation:', metrics['val'], 'Test:', metrics['test'],
                      file=log)
            return metrics, losses

    def _test_step(self, split=None, logits=None):
        """Evaluate the current model on the full graph.

        Parameters
        ----------
        split : str, optional
            ``"train"``, ``"val"`` or ``"test"`` to evaluate one split.  Any
            other value (including ``None``) evaluates all three.
        logits : optional
            Pre-computed logits for the target node type.  When ``None`` the
            model is run on ``self.hg`` to obtain them.

        Returns
        -------
        tuple
            ``(metric, loss)`` for a single split, or ``(metrics, losses)``
            dictionaries keyed by split name otherwise.  Losses are Python
            floats.
        """
        self.model.eval()
        with torch.no_grad():
            # ``is None`` rather than truthiness: bool() of a multi-element
            # tensor raises.
            if logits is None:
                h = self.hg.ndata['h']
                logits = self.model(self.hg, h)[self.category]

            masks = self._split_masks()
            mask = masks.get(split)
            if mask is not None:
                loss = self.loss_fn(logits[mask], self.labels[mask]).item()
                metric = self.evaluator(self.labels[mask].to('cpu'), logits[mask].argmax(dim=1).to('cpu'))
                return metric, loss

            metrics = {key: self.evaluator(self.labels[idx].to('cpu'), logits[idx].argmax(dim=1).to('cpu'))
                       for key, idx in masks.items()}
            losses = {key: self.loss_fn(logits[idx], self.labels[idx]).item()
                      for key, idx in masks.items()}
            return metrics, losses