# Source code for openhgnn.models.HeCo

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.data.utils import load_graphs
from dgl.sampling import sample_neighbors
from dgl.nn.pytorch import GATConv, GraphConv
from openhgnn.models import BaseModel, register_model
from ..utils.utils import extract_metapaths


def init_drop(dropout):
    """Build the dropout callable for a given rate.

    Parameters
    ----------
    dropout : float
        Dropout probability.

    Returns
    -------
    callable
        An ``nn.Dropout`` module when ``dropout`` is greater than zero,
        otherwise a plain identity function that returns its argument.
    """
    # Guard clause: anything that is not strictly positive disables dropout.
    if not dropout > 0:
        return lambda x: x
    return nn.Dropout(dropout)


@register_model('HeCo')
class HeCo(BaseModel):
    r"""
    **Title:** Self-supervised Heterogeneous Graph Neural Network with Co-contrastive Learning

    **Authors:** Xiao Wang, Nian Liu, Hui Han, Chuan Shi

    HeCo was introduced in `[paper] <http://shichuan.org/doc/112.pdf>`_
    and parameters are defined as follows:

    Parameters
    ----------
    meta_paths_dict : dict
        Meta-paths, keyed by name. When built from args and none are given,
        they are extracted from the graph's canonical edge types.
    network_schema : list
        Canonical edge types whose destination type is the target type.
    category : string
        The category of the nodes to be classified.
    hidden_size : int
        Hidden units size.
    feat_drop : float
        Dropout rate for projected features.
    attn_drop : float
        Dropout rate for attentions used in the two view-guided encoders.
    sample_rate : dict
        The number of neighbors of each type sampled for the network schema view.
    tau : float
        Temperature parameter used for the contrastive loss.
    lam : float
        Balance parameter for the two contrastive losses.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        r"""
        Construct a HeCo model from the experiment arguments and the graph.

        Parameters
        ----------
        args : object
            Configuration carrying ``meta_paths_dict``, ``category``,
            ``hidden_dim``, ``feat_drop``, ``attn_drop``, ``sample_rate``,
            ``tau`` and ``lam``.
        hg : DGLGraph
            The heterogeneous graph; only ``canonical_etypes`` is read here.

        Returns
        -------
        HeCo
            A freshly constructed model.
        """
        meta_paths_dict = args.meta_paths_dict
        if meta_paths_dict is None:
            # No user-supplied meta-paths: derive them from the graph schema.
            meta_paths_dict = extract_metapaths(args.category, hg.canonical_etypes)

        # Network schema view only uses edges that point INTO the target type.
        schema = [etype for etype in hg.canonical_etypes if etype[2] == args.category]

        return cls(
            meta_paths_dict=meta_paths_dict,
            network_schema=schema,
            category=args.category,
            hidden_size=args.hidden_dim,
            feat_drop=args.feat_drop,
            attn_drop=args.attn_drop,
            sample_rate=args.sample_rate,
            tau=args.tau,
            lam=args.lam,
        )

    def __init__(self, meta_paths_dict, network_schema, category, hidden_size,
                 feat_drop, attn_drop, sample_rate, tau, lam):
        super(HeCo, self).__init__()
        self.category = category  # target node type
        self.feat_drop = init_drop(feat_drop)
        self.attn_drop = attn_drop
        # Sub-modules are created in this fixed order (meta-path encoder,
        # schema encoder, contrast head).
        self.mp = Mp_encoder(meta_paths_dict, hidden_size, self.attn_drop)
        self.sc = Sc_encoder(network_schema, hidden_size, self.attn_drop,
                             sample_rate, self.category)
        self.contrast = Contrast(hidden_size, tau, lam)

    def forward(self, g, h_dict, pos):
        r"""
        This is the forward part of model HeCo.

        Parameters
        ----------
        g : DGLGraph
            A DGLGraph
        h_dict : dict
            Projected features after linear projection, keyed by node type
        pos : matrix
            A matrix to indicate the positives for each node

        Returns
        -------
        loss : float
            The optimization objective

        Note
        -----------
        Pos matrix is pre-defined by users. The relative tool is given in original code.
        """
        # Feature dropout followed by ELU, applied per node type.
        activated = {
            ntype: F.elu(self.feat_drop(feat)) for ntype, feat in h_dict.items()
        }
        z_mp = self.mp(g, activated[self.category])
        z_sc = self.sc(g, activated)
        return self.contrast(z_mp, z_sc, pos)

    def get_embeds(self, g, h_dict):
        r"""
        This is to get final embeddings of target nodes.

        Only the meta-path encoder is used, feature dropout is not applied,
        and the result is detached from the autograd graph.
        """
        target_feat = F.elu(h_dict[self.category])
        return self.mp(g, target_feat).detach()
class SelfAttention(nn.Module):
    r"""
    This part is used to calculate type-level attention and semantic-level attention, and utilize them
    to generate :math:`z^{sc}` and :math:`z^{mp}`.

    .. math::
        w_{n}&=\frac{1}{|V|}\sum\limits_{i\in V} \textbf{a}^\top \cdot \tanh\left(\textbf{W}h_i^{n}+\textbf{b}\right) \\
        \beta_{n}&=\frac{\exp\left(w_{n}\right)}{\sum_{i=1}^M\exp\left(w_{i}\right)} \\
        z &= \sum_{n=1}^M \beta_{n}\cdot h^{n}

    Parameters
    ----------
    hidden_dim : int
        Size of the embeddings being fused.
    attn_drop : float
        Dropout rate applied to the attention vector.
    txt : str
        A str to identify view, MP or SC. It is only used as the label of the
        attention weights printed in ``forward``.

    Returns
    -------
    z : matrix
        The fused embedding matrix
    """

    def __init__(self, hidden_dim, attn_drop, txt):
        super(SelfAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)
        nn.init.xavier_normal_(self.fc.weight, gain=1.414)

        self.tanh = nn.Tanh()
        self.att = nn.Parameter(torch.empty(size=(1, hidden_dim)), requires_grad=True)
        nn.init.xavier_normal_(self.att.data, gain=1.414)

        self.softmax = nn.Softmax(dim=0)
        self.attn_drop = init_drop(attn_drop)
        self.txt = txt

    def forward(self, embeds):
        r"""
        Fuse a sequence of embedding matrices with learned attention weights.

        Parameters
        ----------
        embeds : list of Tensor
            One embedding matrix per view component (meta-path or edge type);
            presumably each is ``(num_nodes, hidden_dim)`` -- confirm with callers.

        Returns
        -------
        z : Tensor
            The attention-weighted sum of ``embeds``.
        """
        attn_curr = self.attn_drop(self.att)
        # One scalar score per input: attention vector applied to the
        # node-averaged, tanh-activated projection of that input.
        scores = [
            attn_curr.matmul(self.tanh(self.fc(embed)).mean(dim=0).t())
            for embed in embeds
        ]
        beta = self.softmax(torch.cat(scores, dim=-1).view(-1))
        # NOTE: reports the attention weights on every call; `.cpu()` copies
        # them to host memory first.
        print(self.txt, beta.data.cpu().numpy())  # semantic attention
        return sum(embed * weight for embed, weight in zip(embeds, beta))


class Mp_encoder(nn.Module):
    r"""
    This part is to encode meta-path view.

    One ``GraphConv`` layer is created per meta-path (all layers share a single
    ``PReLU`` activation module) and the per-meta-path outputs are fused by
    :class:`SelfAttention`.

    Parameters
    ----------
    meta_paths_dict : dict
        Meta-paths keyed by name; the keys also name the GCN layers.
    hidden_size : int
        Input and output feature size of every GCN layer.
    attn_drop : float
        Dropout rate passed to the semantic attention.

    Returns
    -------
    z_mp : matrix
        The embedding matrix under meta-path view.
    """

    def __init__(self, meta_paths_dict, hidden_size, attn_drop):
        super(Mp_encoder, self).__init__()
        # One GCN layer for each meta path based adjacency matrix
        self.act = nn.PReLU()
        self.gcn_layers = nn.ModuleDict()
        for mp in meta_paths_dict:
            one_layer = GraphConv(hidden_size, hidden_size, activation=self.act,
                                  allow_zero_in_degree=True)
            one_layer.reset_parameters()
            self.gcn_layers[mp] = one_layer
        self.meta_paths_dict = meta_paths_dict
        self._cached_graph = None
        self._cached_coalesced_graph = {}
        self.semantic_attention = SelfAttention(hidden_size, attn_drop, "mp")

    def forward(self, g, h):
        r"""
        Encode target-node features ``h`` over every meta-path graph of ``g``.

        The meta-path reachable graphs are rebuilt only when ``g`` is a
        different object from the one seen on the previous call.
        """
        if self._cached_graph is None or self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for mp, meta_path in self.meta_paths_dict.items():
                self._cached_coalesced_graph[mp] = dgl.metapath_reachable_graph(
                    g, meta_path)

        semantic_embeddings = [
            self.gcn_layers[mp](self._cached_coalesced_graph[mp], h)
            for mp in self.meta_paths_dict
        ]
        # Semantic-level attention across meta-paths.
        return self.semantic_attention(semantic_embeddings)


class Sc_encoder(nn.Module):
    r"""
    This part is to encode network schema view.

    Parameters
    ----------
    network_schema : list
        Canonical edge types pointing to the target type; one GAT layer is
        created for each.
    hidden_size : int
        Feature size of source, destination and output embeddings.
    attn_drop : float
        Dropout rate for the GAT attention and the type-level attention.
    sample_rate : dict
        Number of neighbors to sample, indexed by source node type.
    category : str
        The target node type.

    Returns
    -------
    z_sc : matrix
        The embedding matrix under network schema view.

    Note
    -----------
    There is a different sampling strategy between original code and this code. In original code, the authors
    implement sampling without replacement if the number of neighbors exceeds a threshold, and with replacement
    if not. In this version, we simply use the API dgl.sampling.sample_neighbors to implement this operation,
    and set replacement as True.
    """

    def __init__(self, network_schema, hidden_size, attn_drop, sample_rate, category):
        super(Sc_encoder, self).__init__()
        self.gat_layers = nn.ModuleList()
        for _ in range(len(network_schema)):
            one_layer = GATConv((hidden_size, hidden_size), hidden_size, num_heads=1,
                                attn_drop=attn_drop, allow_zero_in_degree=True)
            one_layer.reset_parameters()
            self.gat_layers.append(one_layer)
        self.network_schema = list(tuple(ns) for ns in network_schema)
        # Not read anywhere in this class; kept so the attribute set is unchanged.
        self._cached_graph = None
        self._cached_coalesced_graph = {}
        self.inter = SelfAttention(hidden_size, attn_drop, "sc")
        self.sample_rate = sample_rate
        self.category = category

    def forward(self, g, h):
        r"""
        Encode the target nodes from freshly sampled neighbors of each edge type.

        Parameters
        ----------
        g : DGLGraph
            The heterogeneous graph.
        h : dict
            Node features keyed by node type.
        """
        # The seed nodes (all target nodes) are the same for every edge type,
        # so build the index tensor once instead of once per iteration.
        target_nodes = torch.arange(0, g.num_nodes(self.category)).to(g.device)

        intra_embeddings = []
        for gat_layer, etype in zip(self.gat_layers, self.network_schema):
            src_type = etype[0]
            sub_graph = sample_neighbors(g[etype],
                                         {self.category: target_nodes},
                                         {etype[1]: self.sample_rate[src_type]},
                                         replace=True)
            one = gat_layer(sub_graph, (h[src_type], h[self.category]))
            # Drop the singleton head dimension (num_heads=1).
            intra_embeddings.append(one.squeeze(1))
        # Type-level attention across edge types.
        return self.inter(intra_embeddings)


class Contrast(nn.Module):
    r"""
    This part is used to calculate the contrastive loss.

    Parameters
    ----------
    hidden_dim : int
        Embedding size; also the width of the projection head.
    tau : float
        Temperature used inside the exponential similarity.
    lam : float
        Weight of the meta-path loss; the schema loss is weighted ``1 - lam``.

    Returns
    -------
    contra_loss : float
        The calculated loss
    """

    def __init__(self, hidden_dim, tau, lam):
        super(Contrast, self).__init__()
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.tau = tau
        self.lam = lam
        for model in self.proj:
            if isinstance(model, nn.Linear):
                nn.init.xavier_normal_(model.weight, gain=1.414)

    def sim(self, z1, z2):
        r"""
        This part is used to calculate the cosine similarity of each pair of nodes from different views.

        Returns the matrix of ``exp(cosine / tau)`` values, rows indexed by
        ``z1`` and columns by ``z2``.
        """
        z1_norm = torch.norm(z1, dim=-1, keepdim=True)
        z2_norm = torch.norm(z2, dim=-1, keepdim=True)
        dot_numerator = torch.mm(z1, z2.t())
        dot_denominator = torch.mm(z1_norm, z2_norm.t())
        return torch.exp(dot_numerator / dot_denominator / self.tau)

    def forward(self, z_mp, z_sc, pos):
        r"""
        This is the forward part of contrast part.

        We firstly project the embeddings under two views into the space where contrastive loss is calculated.
        Then, we calculate the contrastive loss with projected embeddings in a cross-view way.

        .. math::
            \mathcal{L}_i^{sc}=-\log\frac{\sum_{j\in\mathbb{P}_i}exp\left(sim\left(z_i^{sc}\_proj,z_j^{mp}\_proj\right)/\tau\right)}{\sum_{k\in\{\mathbb{P}_i\bigcup\mathbb{N}_i\}}exp\left(sim\left(z_i^{sc}\_proj,z_k^{mp}\_proj\right)/\tau\right)}

        where we show the contrastive loss :math:`\mathcal{L}_i^{sc}` under network schema view, and
        :math:`\mathbb{P}_i` and :math:`\mathbb{N}_i` are positives and negatives for node :math:`i`.

        In a similar way, we can get the contrastive loss :math:`\mathcal{L}_i^{mp}` under meta-path view.
        Finally, we utilize combination parameter :math:`\lambda` to add this two losses.

        Parameters
        ----------
        z_mp : Tensor
            Embeddings under the meta-path view.
        z_sc : Tensor
            Embeddings under the network schema view.
        pos : Tensor
            Positive-pair indicator; must support ``.to_dense()``.

        Note
        -----------
        In implementation, each row of 'matrix_mp2sc' means the similarity with exponential between one node in
        meta-path view and all nodes in network schema view. Then, we conduct normalization for this row, and pick
        the results where the pair of nodes are positives. Finally, we sum these results for each row, and give a
        log to get the final loss.
        """
        z_proj_mp = self.proj(z_mp)
        z_proj_sc = self.proj(z_sc)
        matrix_mp2sc = self.sim(z_proj_mp, z_proj_sc)
        matrix_sc2mp = matrix_mp2sc.t()

        # Densify the positive mask once; it is shared by both directions.
        pos_dense = pos.to_dense()

        matrix_mp2sc = matrix_mp2sc / (torch.sum(matrix_mp2sc, dim=1).view(-1, 1) + 1e-8)
        lori_mp = -torch.log(matrix_mp2sc.mul(pos_dense).sum(dim=-1)).mean()

        matrix_sc2mp = matrix_sc2mp / (torch.sum(matrix_sc2mp, dim=1).view(-1, 1) + 1e-8)
        lori_sc = -torch.log(matrix_sc2mp.mul(pos_dense).sum(dim=-1)).mean()

        return self.lam * lori_mp + (1 - self.lam) * lori_sc


# Logistic-regression head (logreg).
class LogReg(nn.Module):
    r"""
    A single linear layer with Xavier-uniform weights and zero bias.

    Parameters
    ----------
    ft_in : int
        Size of hid_units
    nb_classes : int
        The number of category's types
    """

    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        self.fc = nn.Linear(ft_in, nb_classes)

        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        """Re-initialise ``m`` if it is a linear layer; other modules are left alone."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)

    def forward(self, seq):
        """Return the raw class scores (logits) for ``seq``."""
        return self.fc(seq)