Source code for openhgnn.trainerflow.mg2vec_trainer

import os

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from ..models import build_model
from . import BaseFlow, register_flow
from ..sampler import mg2vec_sampler
import numpy as np


@register_flow('mg2vec_trainer')
class Mg2vecTrainer(BaseFlow):
    """Trainer flow for mg2vec node embeddings.

    Node embeddings are learned from the dataset's ``meta.txt`` file through
    ``Mg2vecSampler``. They are written to ``args.output_dir`` as a ``.npy``
    array and a ``.txt`` file. They are then evaluated on the downstream task:
    each edge of ``self.hg`` is represented by concatenating the embeddings of
    its two endpoints. Only edge classification is supported for now.
    """

    def __init__(self, args):
        super().__init__(args)
        self.mg2vec_sampler = None
        self.dataloader = None
        self.model = None
        self.embeddings_file_path = os.path.join(
            self.args.output_dir, self.args.dataset + '_mg2vec_embeddings.npy')
        self.embeddings_file_path2 = os.path.join(
            self.args.output_dir, self.args.dataset + '_mg2vec_embeddings.txt')
        # When True and the .npy file already exists, load_embeddings() reuses
        # it instead of training from scratch.
        self.load_trained_embeddings = False

    def _build_sampler(self):
        """Create the sampler and its DataLoader; publish vocab stats on ``self.args``."""
        # NOTE(review): relative path, resolved against the current working
        # directory, so this only works when run from the repository root.
        input_file = "./openhgnn/dataset/{}/meta.txt".format(self.args.dataset)
        # Passed to Mg2vecSampler as its block size; presumably the number of
        # samples held per block (see read_block) -- TODO confirm.
        block_size = self.args.batch_size * 100000
        self.mg2vec_sampler = mg2vec_sampler.Mg2vecSampler(input_file, block_size, self.args.alpha)
        self.dataloader = DataLoader(
            self.mg2vec_sampler,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.args.num_workers,
        )
        data = self.mg2vec_sampler.data
        # The model is sized from these values in preprocess().
        self.args.node_num = data.node_count
        self.args.mg_num = data.mg_count
        self.args.unigram = data.unigram

    def preprocess(self):
        """Build the sampler/DataLoader and the mg2vec model on ``self.device``."""
        self._build_sampler()
        self.model = build_model(self.model_name).build_model_from_args(self.args, self.hg).to(self.device)

    def train(self):
        """Obtain node embeddings and log downstream edge-classification metrics.

        Embeddings are trained (or loaded, see ``load_embeddings``) and re-keyed
        from the sampler's internal ids to original node ids. Per-edge features
        are then passed to ``self.task.downstream_evaluate`` with the
        ``'acc_f1'`` metric.
        """
        emb = self.load_embeddings()
        if self.mg2vec_sampler is None:
            # Embeddings were loaded from disk, so preprocess() never ran; the
            # sampler is still needed for its internal-id -> node mapping.
            self._build_sampler()
        emb_dict = {
            int(node): emb[n_id]
            for n_id, node in self.mg2vec_sampler.data.node_reverse_dict.items()
        }
        # todo: only supports edge classification now
        metric = {
            'test': self.task.downstream_evaluate(
                logits=self.get_edge_embed(emb=emb_dict), evaluation_metric='acc_f1')}
        self.logger.train_info(self.logger.metric2str(metric))

    def load_embeddings(self):
        """Return the embedding array, training it first unless reuse is allowed.

        Training happens when ``load_trained_embeddings`` is False or when
        ``embeddings_file_path`` does not exist. The array is always read back
        from ``embeddings_file_path`` with ``np.load``.
        """
        if not self.load_trained_embeddings or not os.path.exists(self.embeddings_file_path):
            self.train_embeddings()
        emb = np.load(self.embeddings_file_path)
        return emb

    def train_embeddings(self):
        """Train the mg2vec model with Adam and save the learned embeddings.

        Training stops once the sampler has signalled ``epoch_end``
        ``args.max_epoch`` times. Embeddings are then written to
        ``embeddings_file_path`` (NumPy) and ``embeddings_file_path2`` (text).
        """
        self.preprocess()
        log_every = 10000  # report the running average loss every this many steps
        epoch_index = 1
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        average_loss = 0.0
        step = 0
        print("train start")
        while True:
            # Iterate over the samples the sampler currently holds; read_block()
            # below is called afterwards to advance it (presumably to the next
            # block -- TODO confirm).
            for sampled_batch in self.dataloader:
                if len(sampled_batch) == 0:
                    continue
                train_a = sampled_batch[0].to(self.device)
                train_b = sampled_batch[1].to(self.device)
                train_label = sampled_batch[2].to(self.device)
                train_freq = sampled_batch[3].reshape(-1, 1).to(self.device)
                train_weight = sampled_batch[4].reshape(-1, 1).to(self.device)
                optimizer.zero_grad()
                # Call the module rather than .forward() so nn.Module hooks run.
                loss = self.model(train_a, train_b, train_label, train_freq, train_weight, self.device)
                loss.backward()
                optimizer.step()
                average_loss += loss.item()
                step += 1
                if step % log_every == 0:
                    average_loss /= log_every
                    print('Average loss at step ', step, ': ', average_loss)
                    average_loss = 0.0
            data = self.mg2vec_sampler.data
            if data.epoch_end:
                print("epoch %d end" % epoch_index)
                epoch_index += 1
                data.epoch_end = False
                if epoch_index > self.args.max_epoch:
                    break
            data.read_block()
        print("total step: ", step)
        # Make sure the destination directory exists before writing files.
        if self.args.output_dir:
            os.makedirs(self.args.output_dir, exist_ok=True)
        self.model.save_embedding_np(self.embeddings_file_path)
        self.model.save_embedding(self.mg2vec_sampler.data.node_reverse_dict, self.embeddings_file_path2)

    def get_edge_embed(self, emb):
        """Return one feature row per edge of ``self.hg``.

        Each row is the embedding of the edge's ``core1`` endpoint followed by
        that of its ``core2`` endpoint. Local node ids are translated to
        original node ids via each node type's ``'id2node'`` data.

        Args:
            emb: mapping from original node id (int) to its embedding vector.

        Returns:
            ``np.ndarray`` built from the per-edge concatenated embeddings.
        """
        g = self.hg
        u, v = g.edges()
        core1_dict = g.nodes['core1'].data['id2node'].cpu()
        core2_dict = g.nodes['core2'].data['id2node'].cpu()
        # Translate all endpoint ids in one vectorized lookup (on CPU, matching
        # the lookup tables) instead of indexing tensors once per edge.
        src_nodes = core1_dict[u.cpu()].tolist()
        dst_nodes = core2_dict[v.cpu()].tolist()
        edge_embed = [
            np.hstack([emb[int(src)], emb[int(dst)]])
            for src, dst in zip(src_nodes, dst_nodes)
        ]
        return np.array(edge_embed)