Source code for openhgnn.trainerflow.base_flow

import os
import torch
from abc import ABC, abstractmethod

from ..tasks import build_task
from ..layers.HeteroLinear import HeteroFeature
from ..utils import get_nodes_dict


[docs] class BaseFlow(ABC): candidate_optimizer = { 'Adam': torch.optim.Adam, 'SGD': torch.optim.SGD, 'Adadelta': torch.optim.Adadelta } def __init__(self, args): """ Parameters ---------- args Attributes ------------- evaluate_interval: int the interval of evaluation in validation """ super(BaseFlow, self).__init__() self.evaluator = None self.evaluate_interval = getattr(args, 'evaluate_interval', 1) if hasattr(args, 'model_path'): self._checkpoint = args.model_path elif hasattr(args, '_checkpoint'): self._checkpoint = os.path.join(args._checkpoint, f"{args.model_name}_{args.dataset_name}.pt") else: if hasattr(args, 'load_from_pretrained'): self._checkpoint = os.path.join(args.output_dir, f"{args.model_name}_{args.dataset_name}_{args.task}.pt") else: self._checkpoint = None if not hasattr(args, 'HGB_results_path') and args.dataset_name[:3] == 'HGB': args.HGB_results_path = os.path.join(args.output_dir, "{}_{}_{}.txt".format(args.model_name, args.dataset_name[5:], args.seed)) # Distributed models will check this parameter during the training process to determine whether to use distributed. self.use_distributed = args.use_distributed # stage flags: whether to run the corresponding stages # todo: only take effects in node classification trainer flow # args.training_flag = getattr(args, 'training_flag', True) # args.validation_flag = getattr(args, 'validation_flag', True) args.test_flag = getattr(args, 'test_flag', True) args.prediction_flag = getattr(args, 'prediction_flag', False) args.use_uva = getattr(args, 'use_uva', False) self.args = args self.logger = self.args.logger self.model_name = args.model_name self.model = args.model self.device = args.device self.task = build_task(args) self.max_epoch = args.max_epoch self.optimizer = None if self.model_name in ["SIAN", "MeiREC", "ExpressGNN", "Ingram", "RedGNN","RedGNNT", "AdapropI", "AdapropT","RedGNNT", "Grail", "ComPILE","DisenKGAT"]: return if self.model_name == "Ingram": return if self.args.use_uva: self.hg = self.task.get_graph() else: self.hg = self.task.get_graph().to(self.device) self.args.meta_paths_dict = self.task.dataset.meta_paths_dict self.patience = args.patience self.loss_fn = self.task.get_loss_fn() def preprocess(self): r""" Every trainerflow should run the preprocess_feature if you want to get a feature preprocessing. The Parameters in input_feature will be added into optimizer and input_feature will be added into the model. Attributes ----------- input_feature : HeteroFeature It will return the processed feature if call it. """ if hasattr(self.args, 'activation'): if hasattr(self.args.activation, 'weight'): import torch.nn as nn act = nn.PReLU() else: act = self.args.activation else: act = None # useful type selection if hasattr(self.args, 'feat'): pass else: # Default 0, nothing to do. self.args.feat = 0 self.feature_preprocess(act) self.optimizer.add_param_group({'params': self.input_feature.parameters()}) # for early stop, load the model with input_feature module. self.model.add_module('input_feature', self.input_feature) self.load_from_pretrained() def feature_preprocess(self, act): """ Feat 0, 1 ,2 Node feature 1 node type & more than 1 node types no feature Returns ------- """ if self.hg.ndata.get('h', {}) == {} or self.args.feat == 2: if self.hg.ndata.get('h', {}) == {}: self.logger.feature_info('Assign embedding as features, because hg.ndata is empty.') else: self.logger.feature_info('feat2, drop features!') self.hg.ndata.pop('h') self.input_feature = HeteroFeature({}, get_nodes_dict(self.hg), self.args.hidden_dim, act=act).to(self.device) elif self.args.feat == 0: self.input_feature = self.init_feature(act) elif self.args.feat == 1: if self.args.task != 'node_classification': self.logger.feature_info('\'feat 1\' is only for node classification task, set feat 0!') self.input_feature = self.init_feature(act) else: h_dict = self.hg.ndata.pop('h') self.logger.feature_info('feat1, preserve target nodes!') self.input_feature = HeteroFeature({self.category: h_dict[self.category]}, get_nodes_dict(self.hg), self.args.hidden_dim, act=act).to(self.device) def init_feature(self, act): self.logger.feature_info("Feat is 0, nothing to do!") if isinstance(self.hg.ndata['h'], dict): # The heterogeneous contains more than one node type. input_feature = HeteroFeature(self.hg.ndata['h'], get_nodes_dict(self.hg), self.args.hidden_dim, act=act).to(self.device) elif isinstance(self.hg.ndata['h'], torch.Tensor): # The heterogeneous only contains one node type. input_feature = HeteroFeature({self.hg.ntypes[0]: self.hg.ndata['h']}, get_nodes_dict(self.hg), self.args.hidden_dim, act=act).to(self.device) return input_feature @abstractmethod def train(self): pass def _full_train_step(self): r""" Train with a full_batch graph """ raise NotImplementedError def _mini_train_step(self): r""" Train with a mini_batch seed nodes graph """ raise NotImplementedError def _full_test_step(self): r""" Test with a full_batch graph """ raise NotImplementedError def _mini_test_step(self): r""" Test with a mini_batch seed nodes graph """ raise NotImplementedError def load_from_pretrained(self): if hasattr(self.args, 'load_from_pretrained') and self.args.load_from_pretrained: try: ck_pt = torch.load(self._checkpoint) self.model.load_state_dict(ck_pt) self.logger.info('[Load Model] Load model from pretrained model:' + self._checkpoint) except FileNotFoundError: self.logger.info('[Load Model] Do not load the model from pretrained, ' '{} doesn\'t exists'.format(self._checkpoint)) # return self.model def save_checkpoint(self): if self._checkpoint and hasattr(self.model, "_parameters()"): torch.save(self.model.state_dict(), self._checkpoint)