# Source code for openhgnn.trainerflow.GATNE_trainer

import torch as th
from tqdm import tqdm
from . import BaseFlow, register_flow
from ..models import build_model
from ..models.GATNE import NSLoss
import torch
from tqdm.auto import tqdm
from numpy import random
import dgl
from ..sampler.GATNE_sampler import NeighborSampler, generate_pairs


@register_flow("GATNE_trainer")
class GATNE(BaseFlow):
    """Trainer flow for the GATNE model, registered as ``"GATNE_trainer"``.

    Training pairs are generated once in :meth:`preprocess` from per-edge-type
    random walks over ``self.hg`` (which must have a single node type). Each
    epoch runs one optimisation pass of the ``NSLoss`` objective and then
    evaluates every edge type on the task's validation and test graphs; the
    mean validation ROC-AUC drives early stopping in :meth:`train`.
    """

    def __init__(self, args):
        """Build the model, snapshot the task's evaluation graphs and prepare data.

        Args:
            args: experiment configuration handed to ``BaseFlow``. The methods
                of this class read ``rw_walks``, ``rw_length``, ``window_size``,
                ``num_workers``, ``neighbor_samples``, ``batch_size``,
                ``neg_size``, ``dim``, ``learning_rate``, ``max_epoch`` and
                ``patience`` from ``self.args``.
        """
        super().__init__(args)
        self.model = build_model(self.model).build_model_from_args(self.args, self.hg).to(self.device)
        self.train_pairs = None
        self.train_dataloader = None
        self.nsloss = None
        self.neighbor_sampler = None
        # _full_test_step temporarily narrows task.val_hg / task.test_hg to one
        # edge type at a time; keep the full graphs so they can be restored.
        self.orig_val_hg = self.task.val_hg
        self.orig_test_hg = self.task.test_hg
        self.preprocess()

    def preprocess(self):
        """Generate random-walk training pairs and the objects that consume them.

        Sets ``train_pairs``, ``neighbor_sampler``, ``train_dataloader``,
        ``nsloss`` and ``optimizer``.
        """
        assert len(self.hg.ntypes) == 1, "GATNE_trainer expects a graph with a single node type"
        # Walks and neighbour sampling run on a CPU copy with parallel edges
        # merged (to_simple) and every edge mirrored (to_bidirected).
        bidirected_hg = dgl.to_bidirected(dgl.to_simple(self.hg.to('cpu')))
        all_walks = []
        for etype in self.hg.etypes:
            # Start ``rw_walks`` walks from every node that is the source of an
            # ``etype`` edge in the bidirected graph.
            start_nodes = torch.unique(bidirected_hg.edges(etype=etype)[0]).repeat(self.args.rw_walks)
            # Each walk takes ``rw_length - 1`` hops, all along ``etype``.
            traces, _ = dgl.sampling.random_walk(
                bidirected_hg, start_nodes, metapath=[etype] * (self.args.rw_length - 1)
            )
            all_walks.append(traces)

        # NOTE(review): pair layout is defined by generate_pairs; the evaluation
        # path builds 3-column (node, node, type) rows for the same sampler.
        self.train_pairs = generate_pairs(all_walks, self.args.window_size, self.args.num_workers)
        self.neighbor_sampler = NeighborSampler(bidirected_hg, [self.args.neighbor_samples])
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_pairs,
            batch_size=self.args.batch_size,
            collate_fn=self.neighbor_sampler.sample,
            shuffle=True,
            num_workers=self.args.num_workers,
            pin_memory=True,
        )
        self.nsloss = NSLoss(self.hg.num_nodes(), self.args.neg_size, self.args.dim).to(self.device)
        # The loss module's parameters are optimised together with the model's.
        self.optimizer = torch.optim.Adam(
            [{"params": self.model.parameters()}, {"params": self.nsloss.parameters()}],
            lr=self.args.learning_rate,
        )

    def train(self):
        """Train for up to ``max_epoch`` epochs with patience-based early stopping.

        After every epoch the score returned by :meth:`_full_test_step` is
        compared with the best seen so far; training stops once it has failed
        to improve for more than ``patience`` consecutive epochs.
        """
        best_score = 0
        patience = 0
        # self.epoch is the loop variable because _full_train_step reads it.
        for self.epoch in range(self.args.max_epoch):
            self._full_train_step()
            cur_score = self._full_test_step()
            if cur_score > best_score:
                best_score = cur_score
                patience = 0
            else:
                patience += 1
                if patience > self.args.patience:
                    self.logger.train_info(f'Early Stop!\tEpoch:{self.epoch:03d}.')
                    break

    def _full_train_step(self):
        """Run one optimisation pass over ``train_dataloader``."""
        self.model.train()
        # NOTE(review): the DataLoader already shuffles (shuffle=True); this
        # in-place numpy shuffle is kept because it also advances numpy's
        # global RNG state.
        random.shuffle(self.train_pairs)
        data_iter = tqdm(
            self.train_dataloader,
            desc=f"epoch {self.epoch}",
            total=len(self.train_dataloader),
        )
        running_loss = 0.0
        for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):
            self.optimizer.zero_grad()
            # embs: [batch_size, edge_type_count, embedding_size]
            embs = self.model(block[0].to(self.device))[head_invmap]
            # Keep, for each sample, only the embedding of its own edge type.
            type_index = (
                block_types.to(self.device)
                .view(-1, 1, 1)
                .expand(embs.shape[0], 1, embs.shape[2])
            )
            embs = embs.gather(1, type_index)[:, 0]
            loss = self.nsloss(
                block[0].dstdata[dgl.NID][head_invmap].to(self.device),
                embs,
                tails.to(self.device),
            )
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            data_iter.set_postfix({
                "epoch": self.epoch,
                "iter": i,
                "avg_loss": running_loss / (i + 1),
                "loss": batch_loss,
            })

    def _compute_etype_embeddings(self):
        """Return ``{etype: (num_nodes, dim) CPU tensor}`` of per-edge-type embeddings.

        Every node is paired with every edge type, pushed through the
        neighbour sampler and the model once, and the embedding selected for
        each edge type is written into that type's table.
        """
        self.model.eval()
        etypes = self.hg.etypes
        num_types = len(etypes)
        num_nodes = self.hg.num_nodes()
        # One preallocated table per edge type, filled row by row below.
        embeddings = {etype: th.empty(num_nodes, self.args.dim) for etype in etypes}
        # Edge-type ids 0..E-1 as a column; identical for every node.
        type_ids = torch.arange(num_types).unsqueeze(1).to(self.device)
        # Inference only: avoid building an autograd graph for every node.
        with torch.no_grad():
            for i in tqdm(range(num_nodes), desc='Evaluating...'):
                node_ids = torch.full((num_types, 1), i, dtype=torch.long, device=self.device)
                # (num_types, 3) rows of (i, i, type_id).
                pairs = torch.cat((node_ids, node_ids, type_ids), dim=1)
                blocks, invmap, _, sampled_types = self.neighbor_sampler.sample(pairs)

                node_emb = self.model(blocks[0].to(self.device))[invmap]
                type_index = (
                    sampled_types.to(self.device)
                    .view(-1, 1, 1)
                    .expand(node_emb.shape[0], 1, node_emb.shape[2])
                )
                node_emb = node_emb.gather(1, type_index)[:, 0]

                for j, etype in enumerate(etypes):
                    embeddings[etype][i] = node_emb[j]
        return embeddings

    def _full_test_step(self):
        """Evaluate every edge type on the test and validation splits.

        For each edge type, ``task.val_hg`` / ``task.test_hg`` are narrowed to
        that type's edges and ``task.evaluate`` is run with the type-specific
        embeddings; metrics are logged per edge type. The full graphs are put
        back afterwards, even if evaluation raises.

        Returns:
            The mean, across edge types, of the ``'roc_auc'`` entry reported on
            the ``'valid'`` split.
        """
        embeddings = self._compute_etype_embeddings()
        metric = {}
        scores = []
        try:
            for etype in self.hg.etypes:
                self.task.val_hg = dgl.edge_type_subgraph(self.orig_val_hg, [etype])
                self.task.test_hg = dgl.edge_type_subgraph(self.orig_test_hg, [etype])
                for split in ['test', 'valid']:
                    n_embedding = {self.hg.ntypes[0]: embeddings[etype].to(self.device)}
                    res = self.task.evaluate(n_embedding=n_embedding, mode=split)
                    metric[split] = res
                    if split == 'valid':
                        scores.append(res.get('roc_auc'))
                self.logger.train_info(etype + self.logger.metric2str(metric))
        finally:
            # Don't leak the last edge type's subgraphs into later users of the task.
            self.task.val_hg = self.orig_val_hg
            self.task.test_hg = self.orig_test_hg
        return sum(scores) / len(scores)