# Source code for openhgnn.trainerflow.slice_trainer

import copy

from dgl._ffi.base import DGLError
from openhgnn import sampler

from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler
from torch.utils.data import TensorDataset
from openhgnn.trainerflow.base_flow import BaseFlow
import dgl
from networkx.algorithms.centrality.betweenness import edge_betweenness_centrality
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
import random
import os
import pickle
import json
import time
from typing import List
import shutil
import copy
import pandas as pd
import math
from sklearn.metrics import (
    accuracy_score,
    auc,
    f1_score,
    mean_squared_error,
    precision_recall_curve,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
)

from . import BaseFlow, register_flow
from ..tasks import build_task
from openhgnn.models import build_model
from openhgnn.models.SLiCE import SLiCE, SLiCEFinetuneLayer
from ..utils import extract_embed, EarlyStopping

from ..sampler.SLiCE_sampler import SLiCESampler
@register_flow("slicetrainer")
class SLiCETrainer(BaseFlow):
    """Two-stage trainer flow: pretrain a SLiCE model, then finetune a link scorer.

    ``train()`` runs ``preprocess()``, ``pretrain()`` and ``finetune()`` in that
    order. Sampled walks, negative edges, per-batch edge subgraphs and model
    checkpoints are all cached under ``args.outdir``.

    Fields read from ``args`` by this class: ``outdir``, ``lr``, ``ft_lr``,
    ``n_epochs``, ``ft_n_epochs``, ``batch_size``, ``ft_batch_size``, ``n_pred``,
    ``beam_width``, ``max_length``, ``walk_type`` and ``path_option``.

    NOTE(review): ``self.task`` and ``self.args`` are used but never assigned
    here; presumably ``BaseFlow.__init__`` sets them -- confirm.
    """

    def __init__(self, args):
        """Build both models, their optimizers and the output directory layout."""
        super().__init__(args)
        self.out_dir = args.outdir
        self.pretrain_path = os.path.join(self.out_dir, 'pretrain/')
        self.pretrain_save_path = os.path.join(self.pretrain_path, 'best_pretrain_model.pt')
        self.finetune_path = os.path.join(self.out_dir, 'finetune/')
        self.finetune_save_path = os.path.join(self.finetune_path, 'best_finetune_model.pt')

        # Work on a homogeneous view that keeps the split masks and edge labels.
        self.g = dgl.to_homogeneous(
            self.task.dataset.g,
            edata=['train_mask', 'valid_mask', 'test_mask', 'label'],
        )

        self.model = {
            'pretrain': SLiCE.build_model_from_args(self.args, self.g),
            'finetune': SLiCEFinetuneLayer.build_model_from_args(args),
        }
        # Loss used by the pretraining stage only; finetuning uses BCE directly.
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.optimizer = {
            'pretrain': optim.Adam(self.model['pretrain'].parameters(), lr=args.lr),
            'finetune': optim.Adam(self.model['finetune'].parameters(), lr=args.ft_lr),
        }
        self.patience = 5  # passed to EarlyStopping in both stages
        self.n_epochs = {'pretrain': args.n_epochs, 'finetune': args.ft_n_epochs}
        self.batch_size = {'pretrain': args.batch_size, 'finetune': args.ft_batch_size}
        self.labels = self.g.edata['label']

        self.idx = dict()             # split name -> edge ids of that split
        self.node_subgraphs = dict()  # pretrain: split name -> list of node-walk subgraphs
        self.edges = dict()           # finetune: split name -> list of (src, dst) pairs
        self.edges_label = dict()     # finetune: split name -> list of edge labels
        self.graphs = dict()          # split name -> edge-induced subgraph of self.g
        self.best_epoch = dict()
        self.is_pretrained = False
        self.is_finetuned = False
        self.threshold = None         # decision threshold chosen in _test_step

    def _batch_file(self, task, batch):
        """Return the cache path of the edge subgraphs for batch ``batch`` of ``task``."""
        return os.path.join(self.finetune_path, '{}/edge_subgraph_{}.bin'.format(task, batch))

    def _edge_subgraphs(self, task, batch, start, end):
        """Load the cached edge subgraphs of a batch, sampling and caching on a miss."""
        batch_file = self._batch_file(task, batch)
        if os.path.exists(batch_file):
            subgraph_list, _ = dgl.load_graphs(batch_file)
        else:
            subgraph_list = self.sampler.get_edge_subgraph(self.edges[task][start:end])
            dgl.save_graphs(batch_file, subgraph_list)
        return subgraph_list

    def preprocess(self):
        """Create output directories, split edges, and build or load cached samples.

        Fills ``self.idx``, ``self.graphs``, ``self.edges``, ``self.edges_label``
        and ``self.node_subgraphs``, creates ``self.sampler``, and pre-caches the
        edge subgraphs of every finetune batch.
        """
        os.makedirs(self.pretrain_path, exist_ok=True)
        os.makedirs(self.finetune_path, exist_ok=True)

        g = self.g
        for task in ['train', 'valid', 'test']:
            os.makedirs(os.path.join(self.finetune_path, task), exist_ok=True)
            mask = g.edata[task + '_mask']
            # reshape(-1) instead of squeeze(): squeeze() collapses a split with
            # exactly one edge to a 0-dim tensor, which cannot be iterated.
            index = torch.nonzero(mask.reshape(-1)).reshape(-1)
            self.idx[task] = index
            src, dst = g.find_edges(index)
            self.edges[task] = list(zip(src.tolist(), dst.tolist()))
            self.graphs[task] = dgl.edge_subgraph(g, index)

        # The sampler receives the full graph plus the train-edge subgraph.
        self.sampler = SLiCESampler(
            g,
            self.graphs['train'],
            num_walks_per_node=self.args.n_pred,
            beam_width=self.args.beam_width,
            max_num_edges=self.args.max_length,
            walk_type=self.args.walk_type,
            path_option=self.args.path_option,
            save_path=self.out_dir,
        )
        sampler = self.sampler

        # ---- pretrain data: node-walk subgraphs, shuffled and split 80/10/10 ----
        node_walk_path = os.path.join(self.pretrain_path, 'node_walks.bin')
        if os.path.exists(node_walk_path):
            node_walks, _ = dgl.load_graphs(node_walk_path)
        else:
            node_walks = sampler.get_node_subgraph(g.nodes())
            dgl.save_graphs(node_walk_path, node_walks)
        random.shuffle(node_walks)
        total_len = len(node_walks)
        train_size = int(0.8 * total_len)
        valid_size = int(0.1 * total_len)
        self.node_subgraphs['train'] = node_walks[:train_size]
        self.node_subgraphs['valid'] = node_walks[train_size:train_size + valid_size]
        self.node_subgraphs['test'] = node_walks[train_size + valid_size:]

        # ---- finetune data: labelled edges ----
        for task in ['valid', 'test']:
            split_labels = self.labels[self.idx[task]].reshape(-1).tolist()
            self.edges_label[task] = [int(label) for label in split_labels]

        # Train edges and their labels come from the sampler (or from its cache).
        train_file = os.path.join(self.finetune_path, 'train_edges.pickle')
        if os.path.exists(train_file):
            # NOTE: pickle.load executes arbitrary code from the file; only safe
            # because this cache is expected to be written by this pipeline.
            with open(train_file, 'rb') as f:
                self.edges['train'], self.edges_label['train'] = pickle.load(f)
        else:
            self.edges['train'], self.edges_label['train'] = sampler.generate_false_edges2(
                self.edges['train'], train_file
            )
        for task in ['valid', 'test']:
            self.edges[task], self.edges_label[task] = sampler.shuffle_edge_label(
                self.edges[task], self.edges_label[task]
            )

        # ---- pre-cache the edge subgraphs of every finetune batch ----
        batch_size = self.batch_size['finetune']
        for task in ['train', 'valid', 'test']:
            total_len = len(self.edges[task])
            # ceil so the trailing partial batch is cached too, matching finetune().
            for batch in range(math.ceil(total_len / batch_size)):
                start = batch * batch_size
                end = min(start + batch_size, total_len)
                batch_file = self._batch_file(task, batch)
                if not os.path.exists(batch_file):
                    subgraph_list = sampler.get_edge_subgraph(self.edges[task][start:end])
                    dgl.save_graphs(batch_file, subgraph_list)

    def train(self):
        """Run the full pipeline: preprocess, pretrain, then finetune."""
        self.preprocess()
        self.pretrain()
        self.finetune()

    def pretrain(self):
        """Train the SLiCE model on node-walk subgraphs with early stopping.

        The model's state dict is written to ``self.pretrain_save_path`` once the
        epoch loop ends (normally or through early stopping).
        """
        print("Start Pretraining...")
        stopper = EarlyStopping(self.patience)
        batch_size = self.batch_size['pretrain']
        model = self.model['pretrain']
        optimizer = self.optimizer['pretrain']
        model.train()
        self.is_pretrained = True

        train_graphs = self.node_subgraphs['train']
        total_len = len(train_graphs)
        n_batch = math.ceil(total_len / batch_size)
        epoch = -1  # keeps best_epoch defined when zero epochs are configured
        for epoch in range(self.n_epochs['pretrain']):
            print("Epoch {}:".format(epoch))
            bar = tqdm(range(n_batch))
            avg_loss = 0
            for batch in bar:
                start = batch * batch_size
                subgraph_list = train_graphs[start:start + batch_size]
                pred_data, true_data = model(subgraph_list)
                loss = self.loss_fn(pred_data.transpose(1, 2), true_data)
                loss_value = float(loss)
                avg_loss += loss_value
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                bar.set_description("Batch {} Loss: {:.3f}".format(batch, loss_value))
            avg_loss = avg_loss / n_batch
            print("AvgLoss: {:.3f}".format(avg_loss))
            early_stop = stopper.loss_step(avg_loss, model)
            if early_stop:
                print('Early Stop!\tEpoch:' + str(epoch))
                break
        # NOTE(review): this records the last epoch that ran, and saves the
        # current weights rather than the stopper's best ones -- confirm intent.
        self.best_epoch['pretrain'] = epoch
        torch.save(model.state_dict(), self.pretrain_save_path)
        print("Evaluating for pretraining...")

    def finetune(self):
        """Train the finetune layer on frozen SLiCE outputs, then evaluate.

        Loads the pretrained checkpoint when it exists. Otherwise prints a warning
        and continues with whatever weights the pretrain model currently holds.
        Saves one finetune checkpoint per epoch, then the ``self.model`` dict to
        ``self.finetune_save_path``, and finally calls ``_test_step()``.
        """
        pretrain_model = self.model['pretrain']
        if not os.path.exists(self.pretrain_save_path):
            print("Model not pretrained!")
        else:
            pretrain_model.load_state_dict(torch.load(self.pretrain_save_path))
        pretrain_model.eval()
        pretrain_model.set_fine_tuning()
        self.model['finetune'].train()

        print("Start Finetuning...")
        stopper = EarlyStopping(self.patience)
        batch_size = self.batch_size['finetune']
        optimizer = self.optimizer['finetune']
        total_len = len(self.edges['train'])
        n_batch = math.ceil(total_len / batch_size)
        for epoch in range(self.n_epochs['finetune']):
            print("Epoch {}:".format(epoch))
            bar = tqdm(range(n_batch))
            avg_loss = 0
            for batch in bar:
                start = batch * batch_size
                end = min(start + batch_size, total_len)
                subgraph_list = self._edge_subgraphs('train', batch, start, end)
                pretrain_model.set_fine_tuning()
                # The pretrained encoder is frozen; only the finetune layer
                # receives gradients.
                with torch.no_grad():
                    _, layer_output, _ = pretrain_model(subgraph_list)
                pred_scores, _, _ = self.model['finetune'](layer_output)
                targets = torch.tensor(
                    self.edges_label['train'][start:end], dtype=torch.float
                ).reshape(-1, 1)
                loss = F.binary_cross_entropy(pred_scores, targets)
                loss_value = float(loss)
                bar.set_description('Batch {}: Loss:{:.3f}'.format(batch, loss_value))
                avg_loss += loss_value
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            torch.save(
                self.model['finetune'].state_dict(),
                os.path.join(self.finetune_path, 'model_' + str(epoch) + 'SLiCE.pt'),
            )
            avg_loss = avg_loss / n_batch
            print("AvgLoss: {:.3f}".format(avg_loss))
            early_stop = stopper.loss_step(avg_loss, self.model['finetune'])
            if early_stop:
                print('Early Stop!\tEpoch:' + str(epoch))
                break

        # Keep the current layer if the stopper never recorded a model
        # (for example when zero finetune epochs are configured).
        best_model = getattr(stopper, 'best_model', None)
        if best_model is not None:
            self.model['finetune'] = best_model
        torch.save(self.model, self.finetune_save_path)
        self.is_finetuned = True
        print("Evaluating for finetuning...")
        self._test_step()

    def _test_step(self):
        """Score the valid and test edges, choose a threshold, and print metrics.

        The threshold is chosen on the validation scores via ``get_threshold`` and
        stored in ``self.threshold``. ROC-AUC, F1 and PR-AUC are then printed for
        the test split. Before finetuning, scores are the sigmoid of the dot
        product of the first two embeddings in each subgraph; afterwards they come
        from the finetune layer.
        """
        pred_data = {'valid': [], 'test': []}
        true_data = {'valid': [], 'test': []}
        batch_size = self.batch_size['finetune']
        self.model['pretrain'].eval()
        self.model['finetune'].eval()  # do not evaluate in training mode

        with torch.no_grad():
            for task in ['valid', 'test']:
                total_len = len(self.edges[task])
                # ceil so the trailing partial batch is evaluated as well.
                for batch in range(math.ceil(total_len / batch_size)):
                    start = batch * batch_size
                    end = min(start + batch_size, total_len)
                    subgraph_list = self._edge_subgraphs(task, batch, start, end)
                    output, layer_output, _ = self.model['pretrain'](subgraph_list)
                    if not self.is_finetuned:
                        source_embed = output[:, 0, :].unsqueeze(1)
                        target_embed = output[:, 1, :].unsqueeze(1).transpose(1, 2)
                        # Multiply the embeddings to get the similarity score.
                        score = torch.bmm(source_embed, target_embed).squeeze(1)
                        score = torch.sigmoid(score).data.cpu().numpy().tolist()
                    else:
                        score, _, _ = self.model['finetune'](layer_output)
                    labels = self.edges_label[task][start:end]
                    for row, label in zip(score, labels):
                        pred_data[task].append(float(row[0]))
                        true_data[task].append(label)

        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        valid_true = np.array(true_data['valid'], dtype=int)
        self.threshold = self.get_threshold(valid_true, pred_data['valid'])[0]

        y_true = np.array(true_data['test'])
        y_scores = np.array(pred_data['test'])
        y_pred = (y_scores >= self.threshold).astype(np.int32)
        ps, rs, _ = precision_recall_curve(y_true, y_scores)

        header = "Finetuning" if self.is_finetuned else "Pretraining"
        print(f"----------------------Testing for {header}()------------")
        print(
            f"y_true.shape: {y_true.shape}, y_scores.shape: {y_scores.shape}"
            f", y_pred.shape: {y_pred.shape}"
        )
        try:
            roc_auc = roc_auc_score(y_true, y_scores)
        except ValueError:
            # Raised by roc_auc_score, e.g. when y_true has a single class.
            roc_auc = 'UNDEFINED'
        f1 = f1_score(y_true, y_pred)
        auc_value = auc(rs, ps)
        print(
            f"{header} : ROC-AUC: {roc_auc},"
            f" F1: {f1}, AUC: {auc_value}"
        )

    def get_threshold(self, target, predicted):
        """Return ``[threshold]`` where ``tpr - (1 - fpr)`` is closest to zero.

        Args:
            target: true binary labels (positive label is 1).
            predicted: predicted scores aligned with ``target``.

        Returns:
            A one-element list holding the selected ROC threshold.
        """
        fpr, tpr, threshold = roc_curve(target, predicted, pos_label=1)
        i = np.arange(len(tpr), dtype=int)  # np.int no longer exists in NumPy >= 1.24
        roc = pd.DataFrame(
            {
                "tf": pd.Series(tpr - (1 - fpr), index=i),
                "threshold": pd.Series(threshold, index=i),
            }
        )
        print()
        roc_t = roc.loc[(roc.tf - 0).abs().argsort()[:1]]
        return list(roc_t["threshold"])

    def loss_calculation(self, pos_score, neg_score):
        """Return the mean logistic (log-sigmoid) loss over positive and negative scores.

        ``pos_score`` is iterated and each yielded item is used to index both
        ``pos_score`` and ``neg_score`` -- presumably both are dicts with the same
        keys and tensor values; verify against the caller.
        """
        loss = []
        for i in pos_score:
            loss.append(F.logsigmoid(pos_score[i]))
            loss.append(F.logsigmoid(-neg_score[i]))
        loss = torch.cat(loss)
        return -loss.mean()

    def ScorePredictor(self, edge_subgraphs, pairs, x):
        """Return the BCE loss of dot-product link scores against the pair labels.

        Args:
            edge_subgraphs: one entry per scored edge; only its length is used.
            pairs: ``pairs[ii]`` unpacks as ``(src, dst, label)``.
            x: embeddings indexed as ``x[ii, 0, :]`` (source) and ``x[ii, 1, :]``
                (target) for edge ``ii``.

        Returns:
            A scalar loss tensor that stays attached to ``x``'s autograd graph.
        """
        n_edges = len(edge_subgraphs)
        # Row-wise dot product; building the scores with torch.tensor(list) would
        # copy the values and detach them from autograd.
        score = torch.sigmoid((x[:n_edges, 0, :] * x[:n_edges, 1, :]).sum(dim=-1))
        labels = [pairs[ii][2] for ii in range(n_edges)]
        target = torch.tensor(labels, dtype=score.dtype, device=score.device)
        return F.binary_cross_entropy(score, target)

    def nid_to_id(self, subgraph, src):
        """Return the first position of ``src`` in ``subgraph.ndata[dgl.NID]``, or -1."""
        matches = torch.nonzero(subgraph.ndata[dgl.NID] == src).flatten()
        if matches.numel() == 0:
            return -1
        return int(matches[0])