openhgnn.layers.MetapathConv 源代码
"""Heterograph NN modules"""
import torch as th
import torch.nn as nn
[文档]class MetapathConv(nn.Module):
r"""
MetapathConv is an aggregation function based on meta-path, which is similar with `dgl.nn.pytorch.HeteroGraphConv`.
We could choose Attention/ APPNP or any GraphConvLayer to aggregate node features.
After that we will get embeddings based on different meta-paths and fusion them.
.. math::
\mathbf{Z}=\mathcal{F}(Z^{\Phi_1},Z^{\Phi_2},...,Z^{\Phi_p})=\mathcal{F}(f(H,\Phi_1),f(H,\Phi_2),...,f(H,\Phi_p))
where :math:`\mathcal{F}` denotes semantic fusion function, such as semantic-attention. :math:`\Phi_i` denotes meta-path and
:math:`f` denotes the aggregation function, such as GAT, APPNP.
Parameters
------------
meta_paths_dict : dict[str, list[tuple(meta-path)]]
contain multiple meta-paths.
mods : nn.ModuleDict
aggregation function
macro_func : callable aggregation func
A semantic aggregation way, e.g. 'mean', 'max', 'sum' or 'attention'
"""
def __init__(self, meta_paths_dict, mods, macro_func, **kargs):
super(MetapathConv, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.mods = mods
self.meta_paths_dict = meta_paths_dict
self.SemanticConv = macro_func
def forward(self, g_dict, h_dict):
r"""
Parameters
-----------
g_dict : dict[str: dgl.DGLGraph]
A dict of DGLGraph(full batch) or DGLBlock(mini batch) extracted by metapaths.
h_dict : dict[str: torch.Tensor]
The input features
Returns
--------
h : dict[str: torch.Tensor]
The output features dict
"""
outputs = {g.dsttypes[0]: [] for s, g in g_dict.items()}
for meta_path_name, meta_path in self.meta_paths_dict.items():
new_g = g_dict[meta_path_name]
# han minibatch
if h_dict.get(meta_path_name) is not None:
h = h_dict[meta_path_name][new_g.srctypes[0]]
# full batch
else:
h = h_dict[new_g.srctypes[0]]
outputs[new_g.dsttypes[0]].append(self.mods[meta_path_name](new_g, h).flatten(1))
# semantic_embeddings = th.stack(semantic_embeddings, dim=1) # (N, M, D * K)
# Aggregate the results for each destination node type
rsts = {}
for ntype, ntype_outputs in outputs.items():
if len(ntype_outputs) != 0:
rsts[ntype] = self.SemanticConv(ntype_outputs) # (N, D * K)
return rsts