import argparse
import copy
import dgl
import numpy as np
import torch
from tqdm import tqdm
import torch.nn.functional as F
from openhgnn.models import build_model
from . import BaseFlow, register_flow
from ..tasks import build_task
import random


class GraphSampler:
r"""
First load graph data to self.hg_dict, then interate.
"""
def __init__(self, hg, k):
self.k = k
self.ets = hg.canonical_etypes
self.nt_et = {}
for et in hg.canonical_etypes:
if et[0] not in self.nt_et:
self.nt_et[et[0]] = [et]
else:
self.nt_et[et[0]].append(et)
self.hg_dict = {key: {} for key in hg.ntypes}
for nt in hg.ntypes:
for nid in range(hg.num_nodes(nt)):
if nid not in self.hg_dict[nt]:
self.hg_dict[nt][nid] = {}
for et in self.nt_et[nt]:
self.hg_dict[nt][nid][et] = hg.successors(nid, et)

    def sample_graph_for_dis(self):
        r"""
        Sample three graphs from the original graph.

        Note
        ----
        pos_hg:
            Sampled from the true graph distribution, i.e. from the original
            graph with real node pairs and real relations.
        neg_hg1:
            Sampled graph with true node pairs but a wrong relation.
        neg_hg2:
            Sampled graph with true source nodes and relations but fake node
            embeddings. The embeddings are produced by the generator, so
            ``neg_hg2`` shares the adjacency of ``pos_hg``.
        """
pos_dict = {}
neg_dict1 = {}
for nt in self.hg_dict.keys():
for src in self.hg_dict[nt].keys():
for i in range(self.k):
et = random.choice(self.nt_et[nt])
dst = random.choice(self.hg_dict[nt][src][et])
if et not in pos_dict:
pos_dict[et] = ([src], [dst])
else:
pos_dict[et][0].append(src)
pos_dict[et][1].append(dst)
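                    # Corrupt the relation only: keep the sampled (src, dst) pair
                    # and its node types, but swap in a different relation name so
                    # the negative triple stays type-consistent.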
wrong_et = random.choice(self.ets)
while wrong_et == et:
wrong_et = random.choice(self.ets)
wrong_et = (et[0], wrong_et[1], et[2])
if wrong_et not in neg_dict1:
neg_dict1[wrong_et] = ([src], [dst])
else:
neg_dict1[wrong_et][0].append(src)
neg_dict1[wrong_et][1].append(dst)
pos_hg = dgl.heterograph(pos_dict, {nt: len(self.hg_dict[nt].keys()) for nt in self.hg_dict.keys()})
neg_hg1 = dgl.heterograph(neg_dict1, {nt: len(self.hg_dict[nt].keys()) for nt in self.hg_dict.keys()})
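        # neg_hg2 reuses pos_dict: it shares pos_hg's topology, and the generator
        # later supplies fake embeddings for its edges.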
neg_hg2 = dgl.heterograph(pos_dict, {nt: len(self.hg_dict[nt].keys()) for nt in self.hg_dict.keys()})
return pos_hg, neg_hg1, neg_hg2

    def sample_graph_for_gen(self):
d = {}
for nt in self.hg_dict.keys():
for src in self.hg_dict[nt].keys():
for i in range(self.k):
                    et = random.choice(self.nt_et[nt])
                    dst = random.choice(self.hg_dict[nt][src][et])
if et not in d:
d[et] = ([src], [dst])
else:
d[et][0].append(src)
d[et][1].append(dst)
return dgl.heterograph(d, {nt: len(self.hg_dict[nt].keys()) for nt in self.hg_dict.keys()})
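

# A minimal, self-contained sketch (not part of the original flow) showing how
# GraphSampler is intended to be used. The toy node/edge types below are
# illustrative assumptions, not taken from any supported dataset; the only
# requirements are that every node type is the source of at least one relation
# (so there are at least two canonical etypes) and that every node has at least
# one out-neighbor per relation.
def _demo_graph_sampler():
    hg = dgl.heterograph({
        ('user', 'rates', 'item'): ([0, 0, 1], [0, 1, 1]),
        ('item', 'rated-by', 'user'): ([0, 1, 1], [0, 0, 1]),
    })
    sampler = GraphSampler(hg, k=2)
    # pos_hg/neg_hg1/neg_hg2 feed the discriminator step; gen_hg the generator step.
    pos_hg, neg_hg1, neg_hg2 = sampler.sample_graph_for_dis()
    gen_hg = sampler.sample_graph_for_gen()
    return pos_hg, neg_hg1, neg_hg2, gen_hg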


@register_flow('HeGAN_trainer')
class HeGANTrainer(BaseFlow):
"""Node classification flows.
Supported Model: HeGAN
Supported Dataset:yelp
The task is to classify the nodes of HIN(Heterogeneous Information Network).
Note: If the output dim is not equal the number of classes, a MLP will follow the gnn model.
"""
def __init__(self, args):
super().__init__(args)
self.num_classes = self.task.dataset.num_classes
self.category = self.task.dataset.category
self.hg = self.task.get_graph()
self.model = build_model(self.model).build_model_from_args(self.args, self.hg)
self.model = self.model.to(self.device)
self.label_smooth = args.label_smooth
self.evaluator = self.task.evaluator.classification
self.evaluate_interval = 1
self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')
self.optim_dis = torch.optim.Adam(self.model.discriminator.parameters(), lr=args.lr_dis, weight_decay=args.wd_dis)
self.optim_gen = torch.optim.Adam(self.model.generator.parameters(), lr=args.lr_gen, weight_decay=args.wd_gen)
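        # The two players get separate optimizers; train() alternates
        # discriminator and generator updates each epoch.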
self.train_idx, self.val_idx, self.test_idx = self.task.get_split()
self.labels = self.task.get_labels().to(self.device)
self.sampler = GraphSampler(self.hg, self.args.n_sample)

    def train(self):
epoch_iter = tqdm(range(self.args.max_epoch))
for epoch in epoch_iter:
if self.args.mini_batch_flag:
dis_loss, gen_loss = self._mini_train_step()
else:
dis_loss, gen_loss = self._full_train_step()
            dis_score, gen_score = self._test_step()
            print("Epoch {:d}".format(epoch))
            print("discriminator:\n\tloss: {:.4f}\n\tmicro_f1: {:.4f}\n\tmacro_f1: {:.4f}".format(dis_loss, dis_score[0], dis_score[1]))
            print("generator:\n\tloss: {:.4f}\n\tmicro_f1: {:.4f}\n\tmacro_f1: {:.4f}".format(gen_loss, gen_score[0], gen_score[1]))

    def _mini_train_step(self):
        # Mini-batch training is not implemented for HeGAN yet; use the
        # full-graph path (args.mini_batch_flag = False) instead.
        raise NotImplementedError('HeGAN does not support mini-batch training yet.')

    def _full_train_step(self):
        r"""
        Note
        ----
        pos_loss:
            loss on the positive graph.
        neg_loss1:
            loss on the negative graph with wrong relations.
        neg_loss2:
            loss on the negative graph with fake node embeddings.
        """
self.model.train()
gen_loss = None
dis_loss = None
# discriminator step
for _ in range(self.args.epoch_dis):
            pos_hg, neg_hg1, neg_hg2 = self.sampler.sample_graph_for_dis()
            pos_hg = pos_hg.to(self.device)
            neg_hg1 = neg_hg1.to(self.device)
            neg_hg2 = neg_hg2.to(self.device)
            # One Gaussian noise vector per sampled edge; the generator turns the
            # noise into fake neighbor embeddings for neg_hg2.
            noise_emb = {
                et: torch.tensor(np.random.normal(0.0, self.args.sigma, (neg_hg2.num_edges(et), self.args.emb_size)).astype('float32')).to(self.device)
                for et in neg_hg2.canonical_etypes
            }
            self.model.generator.assign_node_data(neg_hg2, None)
            self.model.generator.assign_edge_data(neg_hg2, None)
            generate_neighbor_emb = self.model.generator.generate_neighbor_emb(neg_hg2, noise_emb)
            pos_score, neg_score1, neg_score2 = self.model.discriminator(pos_hg, neg_hg1, neg_hg2, generate_neighbor_emb)
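            # Discriminator objective: push pos_score up for real triples and push
            # both negative scores down (wrong relation / generated embeddings).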
pos_loss = -torch.mean(F.logsigmoid(pos_score))
neg_loss1 = -torch.mean(F.logsigmoid(1-neg_score1 + 1e-5))
neg_loss2 = -torch.mean(F.logsigmoid(1-neg_score2 + 1e-5))
            dis_loss = pos_loss + neg_loss1 + neg_loss2
self.optim_dis.zero_grad()
dis_loss.backward()
self.optim_dis.step()
# generator step
dis_node_emb, dis_relation_matrix = self.model.discriminator.get_parameters()
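        # The generator is scored against the discriminator's current node
        # embeddings and relation matrices, which are held fixed in this loop.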
for _ in range(self.args.epoch_gen):
gen_hg = self.sampler.sample_graph_for_gen()
noise_emb = {
et: torch.tensor(np.random.normal(0.0, self.args.sigma, (gen_hg.num_edges(et), self.args.emb_size)).astype('float32')).to(self.device)
for et in gen_hg.canonical_etypes
}
gen_hg = gen_hg.to(self.device)
score = self.model.generator(gen_hg, dis_node_emb, dis_relation_matrix, noise_emb)
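            # Label-smoothed adversarial loss: weight (1 - eps) on treating the
            # generated triples as real and eps on the opposite target, with
            # eps = args.label_smooth.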
            gen_loss = -(1 - self.label_smooth) * torch.mean(F.logsigmoid(score)) \
                       - self.label_smooth * torch.mean(F.logsigmoid(1 - score + 1e-5))
self.optim_gen.zero_grad()
gen_loss.backward()
self.optim_gen.step()
return dis_loss.item(), gen_loss.item()

    def _test_step(self, split=None, logits=None):
self.model.eval()
self.model.generator.eval()
self.model.discriminator.eval()
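        # Evaluate both players by feeding the learned embeddings of the target
        # category to the task's classification evaluator.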
with torch.no_grad():
dis_emb = self.model.discriminator.nodes_embedding[self.category]
gen_emb = self.model.generator.nodes_embedding[self.category]
dis_metric = self.evaluator(dis_emb.cpu(), self.labels.cpu())
gen_metric = self.evaluator(gen_emb.cpu(), self.labels.cpu())
return dis_metric, gen_metric