import torch as th
from tqdm import tqdm
from . import BaseFlow, register_flow
from ..models import build_model
from ..models.GATNE import NSLoss
import torch
from tqdm.auto import tqdm
from numpy import random
import dgl
from ..sampler.GATNE_sampler import NeighborSampler, generate_pairs
[docs]
@register_flow("GATNE_trainer")
class GATNE(BaseFlow):
def __init__(self, args):
super(GATNE, self).__init__(args)
self.model = build_model(self.model).build_model_from_args(self.args, self.hg).to(self.device)
self.train_pairs = None
self.train_dataloader = None
self.nsloss = None
self.neighbor_sampler = None
self.orig_val_hg = self.task.val_hg
self.orig_test_hg = self.task.test_hg
self.preprocess()
def preprocess(self):
assert len(self.hg.ntypes) == 1
bidirected_hg = dgl.to_bidirected(dgl.to_simple(self.hg.to('cpu')))
all_walks = []
for etype in self.hg.etypes:
nodes = torch.unique(bidirected_hg.edges(etype=etype)[0]).repeat(self.args.rw_walks)
traces, types = dgl.sampling.random_walk(
bidirected_hg, nodes, metapath=[etype] * (self.args.rw_length - 1)
)
all_walks.append(traces)
self.train_pairs = generate_pairs(all_walks, self.args.window_size, self.args.num_workers)
self.neighbor_sampler = NeighborSampler(bidirected_hg, [self.args.neighbor_samples])
self.train_dataloader = torch.utils.data.DataLoader(
self.train_pairs,
batch_size=self.args.batch_size,
collate_fn=self.neighbor_sampler.sample,
shuffle=True,
num_workers=self.args.num_workers,
pin_memory=True,
)
self.nsloss = NSLoss(self.hg.num_nodes(), self.args.neg_size, self.args.dim).to(self.device)
self.optimizer = torch.optim.Adam(
[{"params": self.model.parameters()}, {"params": self.nsloss.parameters()}], lr=self.args.learning_rate
)
return
def train(self):
best_score = 0
patience = 0
for self.epoch in range(self.args.max_epoch):
self._full_train_step()
cur_score = self._full_test_step()
if cur_score > best_score:
best_score = cur_score
patience = 0
else:
patience += 1
if patience > self.args.patience:
self.logger.train_info(f'Early Stop!\tEpoch:{self.epoch:03d}.')
break
def _full_train_step(self):
self.model.train()
random.shuffle(self.train_pairs)
data_iter = tqdm(
self.train_dataloader,
desc="epoch %d" % self.epoch,
total=(len(self.train_pairs) + (self.args.batch_size - 1)) // self.args.batch_size,
)
avg_loss = 0.0
for i, (block, head_invmap, tails, block_types) in enumerate(data_iter):
self.optimizer.zero_grad()
# embs: [batch_size, edge_type_count, embedding_size]
block_types = block_types.to(self.device)
embs = self.model(block[0].to(self.device))[head_invmap]
embs = embs.gather(
1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2])
)[:, 0]
loss = self.nsloss(
block[0].dstdata[dgl.NID][head_invmap].to(self.device),
embs,
tails.to(self.device),
)
loss.backward()
self.optimizer.step()
avg_loss += loss.item()
post_fix = {
"epoch": self.epoch,
"iter": i,
"avg_loss": avg_loss / (i + 1),
"loss": loss.item(),
}
data_iter.set_postfix(post_fix)
def _full_test_step(self):
self.model.eval()
# {'1': {}, '2': {}}
final_model = dict(
zip(self.hg.etypes, [th.empty(self.hg.num_nodes(), self.args.dim) for _ in range(len(self.hg.etypes))]))
for i in tqdm(range(self.hg.num_nodes()), desc='Evaluating...'):
train_inputs = (
torch.tensor([i for _ in range(len(self.hg.etypes))])
.unsqueeze(1)
.to(self.device)
) # [i, i]
train_types = (
torch.tensor(list(range(len(self.hg.etypes)))).unsqueeze(1).to(self.device)
) # [0, 1]
pairs = torch.cat(
(train_inputs, train_inputs, train_types), dim=1
) # (2, 3)
(
train_blocks,
train_invmap,
fake_tails,
train_types,
) = self.neighbor_sampler.sample(pairs)
node_emb = self.model(train_blocks[0].to(self.device))[train_invmap]
node_emb = node_emb.gather(
1,
train_types.to(self.device)
.view(-1, 1, 1)
.expand(node_emb.shape[0], 1, node_emb.shape[2]),
)[:, 0]
for j in range(len(self.hg.etypes)):
final_model[self.hg.etypes[j]][i] = node_emb[j].detach()
metric = {}
score = []
for etype in self.hg.etypes:
self.task.val_hg = dgl.edge_type_subgraph(self.orig_val_hg, [etype])
self.task.test_hg = dgl.edge_type_subgraph(self.orig_test_hg, [etype])
for split in ['test', 'valid']:
n_embedding = {self.hg.ntypes[0]: final_model[etype].to(self.device)}
res = self.task.evaluate(n_embedding=n_embedding, mode=split)
metric[split] = res
if split == 'valid':
score.append(res.get('roc_auc'))
self.logger.train_info(etype + self.logger.metric2str(metric))
avg_score = sum(score) / len(score)
return avg_score