Source code for openhgnn.models.DiffMG

import torch
import torch.nn as nn
import torch.nn.functional as F
from openhgnn.models import BaseModel, register_model
import numpy as np

@register_model('DiffMG')
class DiffMG(BaseModel):
    @classmethod
    def build_model_from_args(cls, args, hg):
        args.search_model = search_model
        return cls

    def __init__(self, in_dims, n_hid, n_steps, dropout=None, attn_dim=64, use_norm=True, out_nl=True):
        super(DiffMG, self).__init__()
        self.n_hid = n_hid
        self.ws = nn.ModuleList()  # node type-specific transformation
        assert (isinstance(in_dims, list))
        for i in range(len(in_dims)):
            self.ws.append(nn.Linear(in_dims[i], n_hid))
        assert (isinstance(n_steps, list))
        self.metas = nn.ModuleList()
        for i in range(len(n_steps)):
            self.metas.append(Cell(n_steps[i], n_hid, n_hid, use_norm=use_norm, use_nl=out_nl))
        # * [Optional] Combine more than one meta graph?
        self.attn_fc1 = nn.Linear(n_hid, attn_dim)
        self.attn_fc2 = nn.Linear(attn_dim, 1)
        self.feats_drop = nn.Dropout(dropout) if dropout is not None else lambda x: x

    def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res, gpu):
        if gpu > -1:
            hid = torch.zeros((len(node_types), self.n_hid)).cuda()
        else:
            hid = torch.zeros((len(node_types), self.n_hid))
        for i in range(len(node_feats)):
            hid[node_types == i] = self.ws[i](node_feats[i])
        hid = self.feats_drop(hid)
        temps = []
        attns = []
        for i, meta in enumerate(self.metas):
            hidi = meta(hid, adjs, idxes_seq[i], idxes_res[i])
            temps.append(hidi)
            attni = self.attn_fc2(torch.tanh(self.attn_fc1(temps[-1])))
            attns.append(attni)
        # combine the outputs of multiple meta graphs with attention
        hids = torch.stack(temps, dim=0).transpose(0, 1)
        attns = F.softmax(torch.cat(attns, dim=-1), dim=-1)
        out = (attns.unsqueeze(dim=-1) * hids).sum(dim=1)
        return out
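For orientation, a minimal smoke test of the derived model on toy inputs. The two-type node split, the feature dimensions, and the three-entry adjacency list are illustrative assumptions, not an OpenHGNN dataset; the last adjacency plays the role of the zero Op that Cell excludes for its sequential links.

import torch

n_nodes = 6
node_types = torch.tensor([0, 0, 0, 1, 1, 1])        # two node types (assumed)
node_feats = [torch.randn(3, 4), torch.randn(3, 4)]  # per-type features
eye = torch.eye(n_nodes).to_sparse()
zero = torch.zeros(n_nodes, n_nodes).to_sparse()
adjs = [eye, eye, zero]                              # last entry acts as the zero Op

model = DiffMG(in_dims=[4, 4], n_hid=8, n_steps=[2], dropout=0.2)
# one meta graph with 2 steps: 2 sequential links, 1 residual link
idxes_seq = [[0, 1]]
idxes_res = [[0]]
out = model(node_feats, node_types, adjs, idxes_seq, idxes_res, gpu=-1)
print(out.shape)  # torch.Size([6, 8])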
class Op(nn.Module):
    def __init__(self):
        super(Op, self).__init__()

    # graph convolution for one link: adjs is the list of adjacency matrices,
    # and idx is the index of the adjacency matrix to use.
    def forward(self, x, adjs, idx):
        return torch.spmm(adjs[idx], x)


class Cell(nn.Module):
    def __init__(self, n_step, n_hid_prev, n_hid, use_norm=True, use_nl=True):
        super(Cell, self).__init__()
        self.affine = nn.Linear(n_hid_prev, n_hid)
        self.n_step = n_step
        self.norm = nn.LayerNorm(n_hid) if use_norm is True else lambda x: x
        self.use_nl = use_nl
        self.ops_seq = nn.ModuleList()  # sequential links: state (i - 1) -> state i
        self.ops_res = nn.ModuleList()  # residual links: state j -> state i, j < i
        for i in range(self.n_step):
            self.ops_seq.append(Op())
        for i in range(1, self.n_step):
            for j in range(i):
                self.ops_res.append(Op())

    def forward(self, x, adjs, idxes_seq, idxes_res):
        x = self.affine(x)
        states = [x]
        offset = 0
        for i in range(self.n_step):
            seqi = self.ops_seq[i](states[i], adjs[:-1], idxes_seq[i])  # ! exclude zero Op
            resi = sum(self.ops_res[offset + j](h, adjs, idxes_res[offset + j])
                       for j, h in enumerate(states[:i]))
            offset += i
            states.append(seqi + resi)
        # assert(offset == len(self.ops_res))
        output = self.norm(states[-1])
        if self.use_nl:
            output = F.gelu(output)
        return output
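On the cell's bookkeeping: a Cell with n_step = K holds K sequential Ops and K(K-1)/2 residual Ops, and each entry of idxes_seq / idxes_res selects one adjacency per link. A hedged sanity check (the identity and zero adjacencies below are assumptions for illustration):

import torch

cell = Cell(n_step=3, n_hid_prev=8, n_hid=8)
assert len(cell.ops_seq) == 3 and len(cell.ops_res) == 3  # 3 and 3*(3-1)/2 links

x = torch.randn(5, 8)
eye = torch.eye(5).to_sparse()
zero = torch.zeros(5, 5).to_sparse()
adjs = [eye, eye, zero]  # last entry is the zero Op

# idxes_res[0] = 2 routes the state 0 -> state 2 residual link through the zero Op
out = cell(x, adjs, idxes_seq=[0, 1, 0], idxes_res=[2, 0, 1])
print(out.shape)  # torch.Size([5, 8])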
class Op1(nn.Module):
    '''
    operation for one link in the DAG search space
    '''
    def __init__(self):
        super(Op1, self).__init__()

    def forward(self, x, adjs, ws, idx):
        # assert(ws.size(0) == len(adjs))
        return ws[idx] * torch.spmm(adjs[idx], x)


class Cell1(nn.Module):
    '''
    the DAG search space
    '''
    def __init__(self, n_step, n_hid_prev, n_hid, cstr, use_norm=True, use_nl=True):
        super(Cell1, self).__init__()
        self.affine = nn.Linear(n_hid_prev, n_hid)
        self.n_step = n_step  # * number of intermediate states (i.e., K)
        self.norm = nn.LayerNorm(n_hid, elementwise_affine=False) if use_norm is True else lambda x: x
        self.use_nl = use_nl
        assert (isinstance(cstr, list))
        self.cstr = cstr  # * type constraint
        self.ops_seq = nn.ModuleList()  # * state (i - 1) -> state i, 1 <= i < K
        for i in range(1, self.n_step):
            self.ops_seq.append(Op1())
        self.ops_res = nn.ModuleList()  # * state j -> state i, 0 <= j < i - 1, 2 <= i < K
        for i in range(2, self.n_step):
            for j in range(i - 1):
                self.ops_res.append(Op1())
        self.last_seq = Op1()  # * state (K - 1) -> state K
        self.last_res = nn.ModuleList()  # * state i -> state K, 0 <= i < K - 1
        for i in range(self.n_step - 1):
            self.last_res.append(Op1())

    def forward(self, x, adjs, ws_seq, idxes_seq, ws_res, idxes_res):
        # assert(isinstance(ws_seq, list))
        # assert(len(ws_seq) == 2)
        x = self.affine(x)
        states = [x]
        offset = 0
        for i in range(self.n_step - 1):
            seqi = self.ops_seq[i](states[i], adjs[:-1], ws_seq[0][i], idxes_seq[0][i])  # ! exclude zero Op
            resi = sum(self.ops_res[offset + j](h, adjs, ws_res[0][offset + j], idxes_res[0][offset + j])
                       for j, h in enumerate(states[:i]))
            offset += i
            states.append(seqi + resi)
        # assert(offset == len(self.ops_res))
        adjs_cstr = [adjs[i] for i in self.cstr]
        out_seq = self.last_seq(states[-1], adjs_cstr, ws_seq[1], idxes_seq[1])
        adjs_cstr.append(adjs[-1])
        out_res = sum(self.last_res[i](h, adjs_cstr, ws_res[1][i], idxes_res[1][i])
                      for i, h in enumerate(states[:-1]))
        output = self.norm(out_seq + out_res)
        if self.use_nl:
            output = F.gelu(output)
        return output
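Cell1 mirrors Cell but weights every link by the softmaxed architecture parameter ws[idx] before aggregation, and routes its final step through the type-constrained adjacencies in cstr. A standalone sketch of its interface; in the real search the ws_* / idxes_* arguments come from search_model, so the hand-built tensors below are assumptions:

import torch
import torch.nn.functional as F

cell = Cell1(n_step=2, n_hid_prev=8, n_hid=8, cstr=[0])
x = torch.randn(5, 8)
eye = torch.eye(5).to_sparse()
zero = torch.zeros(5, 5).to_sparse()
adjs = [eye, eye, zero]  # last entry is the zero Op

ws_seq = [F.softmax(torch.randn(1, 2), dim=-1),   # ops_seq weights (zero Op excluded)
          F.softmax(torch.randn(1), dim=-1)]      # last_seq weights over cstr edge types
ws_res = [None,                                   # no ops_res links when n_step == 2
          F.softmax(torch.randn(1, 2), dim=-1)]   # last_res weights over cstr + zero Op
idxes_seq = [torch.tensor([1]), torch.tensor(0)]
idxes_res = [None, torch.tensor([0])]
out = cell(x, adjs, ws_seq, idxes_seq, ws_res, idxes_res)
print(out.shape)  # torch.Size([5, 8])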
class search_model(nn.Module):
    def __init__(self, in_dims, n_hid, n_adjs, n_steps, cstr, attn_dim=64, use_norm=True, out_nl=True):
        super(search_model, self).__init__()
        self.cstr = cstr
        self.n_adjs = n_adjs
        self.n_hid = n_hid
        self.ws = nn.ModuleList()  # * node type-specific transformation
        assert (isinstance(in_dims, list))
        for i in range(len(in_dims)):
            self.ws.append(nn.Linear(in_dims[i], n_hid))
        assert (isinstance(n_steps, list))
        # * [optional] combine more than one meta graph?
        self.metas = nn.ModuleList()
        for i in range(len(n_steps)):
            self.metas.append(Cell1(n_steps[i], n_hid, n_hid, cstr, use_norm=use_norm, use_nl=out_nl))
        self.as_seq = []  # * arch parameters for ops_seq
        self.as_last_seq = []  # * arch parameters for last_seq
        for i in range(len(n_steps)):
            if n_steps[i] > 1:
                ai = 1e-3 * torch.randn(n_steps[i] - 1, n_adjs - 1)  # ! exclude zero Op
                ai.requires_grad_(True)
                self.as_seq.append(ai)
            else:
                self.as_seq.append(None)
            ai_last = 1e-3 * torch.randn(len(cstr))
            ai_last.requires_grad_(True)
            self.as_last_seq.append(ai_last)
        ks = [sum(1 for i in range(2, n_steps[k]) for j in range(i - 1)) for k in range(len(n_steps))]
        self.as_res = []  # * arch parameters for ops_res
        self.as_last_res = []  # * arch parameters for last_res
        for i in range(len(n_steps)):
            if ks[i] > 0:
                ai = 1e-3 * torch.randn(ks[i], n_adjs)
                ai.requires_grad_(True)
                self.as_res.append(ai)
            else:
                self.as_res.append(None)
            if n_steps[i] > 1:
                ai_last = 1e-3 * torch.randn(n_steps[i] - 1, len(cstr) + 1)
                ai_last.requires_grad_(True)
                self.as_last_res.append(ai_last)
            else:
                self.as_last_res.append(None)
        assert (ks[0] + n_steps[0] + (0 if self.as_last_res[0] is None else self.as_last_res[0].size(0)) ==
                (1 + n_steps[0]) * n_steps[0] // 2)
        # * [optional] combine more than one meta graph?
        self.attn_fc1 = nn.Linear(n_hid, attn_dim)
        self.attn_fc2 = nn.Linear(attn_dim, 1)

    def alphas(self):
        alphas = []
        for each in self.as_seq:
            if each is not None:
                alphas.append(each)
        for each in self.as_last_seq:
            alphas.append(each)
        for each in self.as_res:
            if each is not None:
                alphas.append(each)
        for each in self.as_last_res:
            if each is not None:
                alphas.append(each)
        return alphas

    def sample(self, eps):
        '''
        to sample one candidate edge type per link
        '''
        idxes_seq = []
        idxes_res = []
        if np.random.uniform() < eps:
            for i in range(len(self.metas)):
                temp = []
                temp.append(None if self.as_seq[i] is None else torch.randint(
                    low=0, high=self.as_seq[i].size(-1), size=self.as_seq[i].size()[:-1]))
                temp.append(torch.randint(low=0, high=self.as_last_seq[i].size(-1), size=(1,)))
                idxes_seq.append(temp)
            for i in range(len(self.metas)):
                temp = []
                temp.append(None if self.as_res[i] is None else torch.randint(
                    low=0, high=self.as_res[i].size(-1), size=self.as_res[i].size()[:-1]))
                temp.append(None if self.as_last_res[i] is None else torch.randint(
                    low=0, high=self.as_last_res[i].size(-1), size=self.as_last_res[i].size()[:-1]))
                idxes_res.append(temp)
        else:
            for i in range(len(self.metas)):
                temp = []
                temp.append(None if self.as_seq[i] is None else torch.argmax(F.softmax(self.as_seq[i], dim=-1), dim=-1))
                temp.append(torch.argmax(F.softmax(self.as_last_seq[i], dim=-1), dim=-1))
                idxes_seq.append(temp)
            for i in range(len(self.metas)):
                temp = []
                temp.append(None if self.as_res[i] is None else torch.argmax(F.softmax(self.as_res[i], dim=-1), dim=-1))
                temp.append(None if self.as_last_res[i] is None else torch.argmax(F.softmax(self.as_last_res[i], dim=-1), dim=-1))
                idxes_res.append(temp)
        return idxes_seq, idxes_res

    def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res):
        hid = torch.zeros((len(node_types), self.n_hid))
        for i in range(len(node_feats)):
            hid[node_types == i] = self.ws[i](node_feats[i])
        temps = []
        attns = []
        for i, meta in enumerate(self.metas):
            ws_seq = []
            ws_seq.append(None if self.as_seq[i] is None else F.softmax(self.as_seq[i], dim=-1))
            ws_seq.append(F.softmax(self.as_last_seq[i], dim=-1))
            ws_res = []
            ws_res.append(None if self.as_res[i] is None else F.softmax(self.as_res[i], dim=-1))
            ws_res.append(None if self.as_last_res[i] is None else F.softmax(self.as_last_res[i], dim=-1))
            hidi = meta(hid, adjs, ws_seq, idxes_seq[i], ws_res, idxes_res[i])
            temps.append(hidi)
            attni = self.attn_fc2(torch.tanh(self.attn_fc1(temps[-1])))
            attns.append(attni)
        hids = torch.stack(temps, dim=0).transpose(0, 1)
        attns = F.softmax(torch.cat(attns, dim=-1), dim=-1)
        out = (attns.unsqueeze(dim=-1) * hids).sum(dim=1)
        return out

    def parse(self):
        '''
        to derive a meta graph indicated by arch parameters
        '''
        idxes_seq, idxes_res = self.sample(0.)
        msg_seq = []
        msg_res = []
        for i in range(len(idxes_seq)):
            map_seq = [self.cstr[idxes_seq[i][1].item()]]
            msg_seq.append(map_seq if idxes_seq[i][0] is None else idxes_seq[i][0].tolist() + map_seq)
            assert (len(msg_seq[i]) == self.metas[i].n_step)
            temp_res = []
            if idxes_res[i][1] is not None:
                for item in idxes_res[i][1].tolist():
                    if item < len(self.cstr):
                        temp_res.append(self.cstr[item])
                    else:
                        assert (item == len(self.cstr))
                        temp_res.append(self.n_adjs - 1)
                if idxes_res[i][0] is not None:
                    temp_res = idxes_res[i][0].tolist() + temp_res
                assert (len(temp_res) == self.metas[i].n_step * (self.metas[i].n_step - 1) // 2)
            msg_res.append(temp_res)
        return msg_seq, msg_res
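Putting the search pieces together: sample per-link edge types (eps-greedy), forward with the sampled architecture, alternate weight and architecture updates in a DARTS-style loop, and finally let parse derive the meta graph encoded by the arch parameters. Note that the tensors returned by alphas() are plain leaf tensors rather than registered nn.Parameters, so searcher.parameters() excludes them and the two optimizers stay disjoint. Everything below is a hedged sketch: the toy graph, the hypothetical task head, the optimizers, learning rates, and alternating schedule are illustrative assumptions, not the OpenHGNN trainer's actual recipe.

import torch
import torch.nn.functional as F

n_nodes = 6
node_types = torch.tensor([0, 0, 0, 1, 1, 1])        # two node types (assumed)
node_feats = [torch.randn(3, 4), torch.randn(3, 4)]  # per-type features
eye = torch.eye(n_nodes).to_sparse()
zero = torch.zeros(n_nodes, n_nodes).to_sparse()
adjs = [eye, eye, zero]                              # last entry acts as the zero Op

searcher = search_model(in_dims=[4, 4], n_hid=8, n_adjs=len(adjs),
                        n_steps=[2], cstr=[0])
head = torch.nn.Linear(8, 2)                         # hypothetical 2-class task head
y = torch.randint(0, 2, (n_nodes,))                  # toy labels
train_idx, valid_idx = [0, 1, 2], [3, 4, 5]          # toy split

w_optim = torch.optim.Adam(list(searcher.parameters()) + list(head.parameters()), lr=5e-3)
a_optim = torch.optim.Adam(searcher.alphas(), lr=3e-4)  # arch parameters only

for epoch in range(10):
    idxes_seq, idxes_res = searcher.sample(eps=0.3)  # eps-greedy: explore vs. argmax
    # phase 1: update network weights on the training nodes
    w_optim.zero_grad()
    logits = head(searcher(node_feats, node_types, adjs, idxes_seq, idxes_res))
    F.cross_entropy(logits[train_idx], y[train_idx]).backward()
    w_optim.step()
    # phase 2: update architecture parameters on the validation nodes
    a_optim.zero_grad()
    logits = head(searcher(node_feats, node_types, adjs, idxes_seq, idxes_res))
    F.cross_entropy(logits[valid_idx], y[valid_idx]).backward()
    a_optim.step()

msg_seq, msg_res = searcher.parse()                  # meta graph implied by the alphas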