Source code for openhgnn.trainerflow.HeGAN_trainer

import argparse
import copy
import dgl
import numpy as np
import torch
from tqdm import tqdm
import torch.nn.functional as F
from openhgnn.models import build_model
from . import BaseFlow, register_flow
from ..tasks import build_task
import random

class GraphSampler:
    r"""
    First load graph data to self.hg_dict, then interate.
    """
    def __init__(self, hg, k):
        self.k = k
        self.ets = hg.canonical_etypes
        self.nt_et = {}
        for et in hg.canonical_etypes:
            if et[0] not in self.nt_et:
                self.nt_et[et[0]] = [et]
            else:
                self.nt_et[et[0]].append(et)

        # Cache every node's successors per edge type so sampling is a dict lookup.
        # This assumes every node type has at least one outgoing edge type.
        self.hg_dict = {key: {} for key in hg.ntypes}
        for nt in hg.ntypes:
            for nid in range(hg.num_nodes(nt)):
                if nid not in self.hg_dict[nt]:
                    self.hg_dict[nt][nid] = {}
                for et in self.nt_et[nt]:
                    self.hg_dict[nt][nid][et] = hg.successors(nid, et)

    def sample_graph_for_dis(self):
        r"""
        sample three graphs from original graph.

        Note
        ------------
        pos_hg:
            Sampled graph from true graph distribution, that is from the original graph with real node and real relation.
        neg_hg1:
            Sampled graph with true nodes pair but wrong realtion.
        neg_hg2:
            Sampled graph with true scr nodes and realtion but wrong node embedding.
            Embedding are generated by Generator, so we can use `pos_hg` as adjacency matrix.
        """
        pos_dict = {}
        neg_dict1 = {}

        for nt in self.hg_dict.keys():
            for src in self.hg_dict[nt].keys():
                for i in range(self.k):
                    et = random.choice(self.nt_et[nt])
                    dst = random.choice(self.hg_dict[nt][src][et])
                    if et not in pos_dict:
                        pos_dict[et] = ([src], [dst])
                    else:
                        pos_dict[et][0].append(src)
                        pos_dict[et][1].append(dst)

                    # Corrupt the relation only: keep the real (src, dst) node
                    # types but swap in a different relation name. Assumes the
                    # graph has at least two edge types.
                    wrong_et = random.choice(self.ets)
                    while wrong_et == et:
                        wrong_et = random.choice(self.ets)
                    wrong_et = (et[0], wrong_et[1], et[2])

                    if wrong_et not in neg_dict1:
                        neg_dict1[wrong_et] = ([src], [dst])
                    else:
                        neg_dict1[wrong_et][0].append(src)
                        neg_dict1[wrong_et][1].append(dst)

        num_nodes = {nt: len(self.hg_dict[nt]) for nt in self.hg_dict}
        pos_hg = dgl.heterograph(pos_dict, num_nodes)
        neg_hg1 = dgl.heterograph(neg_dict1, num_nodes)
        # neg_hg2 shares pos_hg's structure; its fake node embeddings are
        # filled in later by the generator.
        neg_hg2 = dgl.heterograph(pos_dict, num_nodes)

        return pos_hg, neg_hg1, neg_hg2
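
    # Worked micro-example (hypothetical numbers, not from any dataset): with
    # k=2, a 'user' node 0 whose two draws both pick ('user', 'rates', 'item')
    # and land on items 3 and 7 contributes
    #     {('user', 'rates', 'item'): ([0, 0], [3, 7])}
    # to pos_dict, while neg_dict1 records the same (src, dst) pairs under a
    # corrupted relation, e.g. ('user', 'likes', 'item'), which need not exist
    # in the original metagraph.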

    def sample_graph_for_gen(self):
        d = {}
        for nt in self.hg_dict.keys():
            for src in self.hg_dict[nt].keys():
                for i in range(self.k):
                    et = random.choice(self.nt_et[nt])
                    dst = random.choice(self.hg_dict[nt][src][et])
                    if et not in d:
                        d[et] = ([src], [dst])
                    else:
                        d[et][0].append(src)
                        d[et][1].append(dst)

        return dgl.heterograph(d, {nt: len(self.hg_dict[nt]) for nt in self.hg_dict})
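
# A minimal, self-contained sketch (not part of the original module) showing
# how GraphSampler behaves on a toy heterograph. The two-type graph below is a
# hypothetical stand-in for a real dataset such as yelp; only standard DGL and
# PyTorch calls are used.
def _demo_graph_sampler():
    toy_hg = dgl.heterograph({
        ('user', 'rates', 'item'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ('item', 'rated_by', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 0])),
    })
    sampler = GraphSampler(toy_hg, k=2)  # draw k edges per source node
    pos_hg, neg_hg1, neg_hg2 = sampler.sample_graph_for_dis()
    # pos_hg and neg_hg2 share the same real edges; neg_hg1 rewires them
    # under corrupted relation names.
    assert pos_hg.num_nodes('user') == toy_hg.num_nodes('user')
    gen_hg = sampler.sample_graph_for_gen()
    # k edges are drawn for each of the 4 nodes, so 8 edges in total.
    assert gen_hg.num_edges() == 8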

@register_flow('HeGAN_trainer')
class HeGANTrainer(BaseFlow):
    """Node classification flow.

    Supported Model: HeGAN
    Supported Dataset: yelp

    The task is to classify the nodes of a HIN (Heterogeneous Information Network).

    Note: If the output dimension does not equal the number of classes, an MLP
    will follow the GNN model.
    """
    def __init__(self, args):
        super().__init__(args)
        self.num_classes = self.task.dataset.num_classes
        self.category = self.task.dataset.category
        self.hg = self.task.get_graph()
        self.model = build_model(self.model).build_model_from_args(self.args, self.hg)
        self.model = self.model.to(self.device)
        self.label_smooth = args.label_smooth
        self.evaluator = self.task.evaluator.classification
        self.evaluate_interval = 1
        self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')
        # The discriminator and generator are optimized separately, GAN-style.
        self.optim_dis = torch.optim.Adam(self.model.discriminator.parameters(),
                                          lr=args.lr_dis, weight_decay=args.wd_dis)
        self.optim_gen = torch.optim.Adam(self.model.generator.parameters(),
                                          lr=args.lr_gen, weight_decay=args.wd_gen)
        self.train_idx, self.val_idx, self.test_idx = self.task.get_split()
        self.labels = self.task.get_labels().to(self.device)
        self.sampler = GraphSampler(self.hg, self.args.n_sample)

    def train(self):
        epoch_iter = tqdm(range(self.args.max_epoch))
        for epoch in epoch_iter:
            if self.args.mini_batch_flag:
                dis_loss, gen_loss = self._mini_train_step()
            else:
                dis_loss, gen_loss = self._full_train_step()
            dis_score, gen_score = self._test_step()
            print(epoch)
            print("discriminator:\n\tloss:{:.4f}\n\tmicro_f1: {:.4f},\n\tmacro_f1: {:.4f}".format(
                dis_loss, dis_score[0], dis_score[1]))
            print("generator:\n\tloss:{:.4f}\n\tmicro_f1: {:.4f},\n\tmacro_f1: {:.4f}".format(
                gen_loss, gen_score[0], gen_score[1]))

    def _mini_train_step(self):
        # Mini-batch training is not implemented for HeGAN.
        dis_loss, gen_loss = None, None
        return dis_loss, gen_loss

    def _full_train_step(self):
        r"""
        Note
        ----
        pos_loss:
            positive graph loss.
        neg_loss1:
            negative graph loss with wrong relations.
        neg_loss2:
            negative graph loss with wrong node embeddings.
        """
        self.model.train()
        gen_loss = None
        dis_loss = None

        # discriminator step
        for _ in range(self.args.epoch_dis):
            pos_hg, neg_hg1, neg_hg2 = self.sampler.sample_graph_for_dis()
            pos_hg = pos_hg.to(self.device)
            neg_hg1 = neg_hg1.to(self.device)
            neg_hg2 = neg_hg2.to(self.device)
            # Per-edge-type Gaussian noise that the generator turns into fake
            # neighbor embeddings for neg_hg2.
            noise_emb = {
                et: torch.tensor(np.random.normal(
                    0.0, self.args.sigma,
                    (neg_hg2.num_edges(et), self.args.emb_size)).astype('float32')).to(self.device)
                for et in neg_hg2.canonical_etypes
            }
            self.model.generator.assign_node_data(neg_hg2, None)
            self.model.generator.assign_edge_data(neg_hg2, None)
            generate_neighbor_emb = self.model.generator.generate_neighbor_emb(neg_hg2, noise_emb)
            pos_score, neg_score1, neg_score2 = self.model.discriminator(
                pos_hg, neg_hg1, neg_hg2, generate_neighbor_emb)
            pos_loss = -torch.mean(F.logsigmoid(pos_score))
            neg_loss1 = -torch.mean(F.logsigmoid(1 - neg_score1 + 1e-5))
            neg_loss2 = -torch.mean(F.logsigmoid(1 - neg_score2 + 1e-5))
            dis_loss = pos_loss + neg_loss1 + neg_loss2
            self.optim_dis.zero_grad()
            dis_loss.backward()
            self.optim_dis.step()

        # generator step: the generator is trained against the discriminator's
        # current node embeddings and relation matrices.
        dis_node_emb, dis_relation_matrix = self.model.discriminator.get_parameters()
        for _ in range(self.args.epoch_gen):
            gen_hg = self.sampler.sample_graph_for_gen()
            noise_emb = {
                et: torch.tensor(np.random.normal(
                    0.0, self.args.sigma,
                    (gen_hg.num_edges(et), self.args.emb_size)).astype('float32')).to(self.device)
                for et in gen_hg.canonical_etypes
            }
            gen_hg = gen_hg.to(self.device)
            score = self.model.generator(gen_hg, dis_node_emb, dis_relation_matrix, noise_emb)
            gen_loss = -torch.mean(F.logsigmoid(score)) * (1 - self.label_smooth) \
                       - torch.mean(F.logsigmoid(1 - score + 1e-5)) * self.label_smooth
            self.optim_gen.zero_grad()
            gen_loss.backward()
            self.optim_gen.step()

        return dis_loss.item(), gen_loss.item()

    def _test_step(self, split=None, logits=None):
        self.model.eval()
        self.model.generator.eval()
        self.model.discriminator.eval()
        with torch.no_grad():
            dis_emb = self.model.discriminator.nodes_embedding[self.category]
            gen_emb = self.model.generator.nodes_embedding[self.category]
            dis_metric = self.evaluator(dis_emb.cpu(), self.labels.cpu())
            gen_metric = self.evaluator(gen_emb.cpu(), self.labels.cpu())
        return dis_metric, gen_metric