import numpy as np
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from . import BaseModel, register_model
import torch.nn.functional as F
@register_model('KGCN')
class KGCN(BaseModel):
    r"""
    This module KGCN was introduced in `KGCN <https://dl.acm.org/doi/10.1145/3308558.3313417>`__.

    It includes two parts:

    1. Aggregate the entity representation and its neighborhood representation
       into the entity's embedding.

       The message function is defined as follows:

       :math:`\mathrm{v}_{\mathcal{N}(v)}^{u}=\sum_{e \in \mathcal{N}(v)} \tilde{\pi}_{r_{v, e}}^{u} \mathrm{e}`

       where :math:`\mathrm{e}` is the representation of an entity,
       :math:`\tilde{\pi}_{r_{v, e}}^{u}` is the scalar weight on the edge from entity to entity,
       and the result :math:`\mathrm{v}_{\mathcal{N}(v)}^{u}` stores the message passed from the
       neighbor nodes.

       There are three types of aggregators.
       The sum aggregator takes the summation of the two representation vectors,
       the concat aggregator concatenates the two representation vectors, and
       the neighbor aggregator directly takes the neighborhood representation of the entity
       as the output representation:

       :math:`\mathrm{agg}_{sum}=\sigma\left(\mathbf{W} \cdot\left(\mathrm{v}+\mathrm{v}_{\mathcal{S}(v)}^{u}\right)+\mathbf{b}\right)`

       :math:`\mathrm{agg}_{concat}=\sigma\left(\mathbf{W} \cdot \text{concat}\left(\mathrm{v}, \mathrm{v}_{\mathcal{S}(v)}^{u}\right)+\mathbf{b}\right)`

       :math:`\mathrm{agg}_{neighbor}=\sigma\left(\mathbf{W} \cdot \mathrm{v}_{\mathcal{S}(v)}^{u}+\mathbf{b}\right)`

       In the above equations, :math:`\sigma` is the nonlinear function and
       :math:`\mathbf{W}` and :math:`\mathbf{b}` are the transformation weight and bias.
       The representation of an item is bound up with its neighbors by aggregation.

    2. Obtain scores using the final entity representation and the user representation.

       The final entity representation is denoted as :math:`\mathrm{v}^{u}`.
       Taking the dot product of :math:`\mathrm{v}^{u}` with the user representation
       :math:`\mathbf{u}` yields the predicted probability. The formula is:

       :math:`\hat{y}_{u v}=f\left(\mathbf{u}, \mathrm{v}^{u}\right)`

    Parameters
    ----------
    g : DGLGraph
        A knowledge graph that preserves the relationships between entities.
    args : Config
        Model's config. This class reads ``in_dim``, ``out_dim``, ``n_relation``,
        ``n_user`` and ``device`` from it; the remaining fields are consumed by
        ``KGCN_Aggregate``.
    """

    @classmethod
    def build_model_from_args(cls, args, g):
        """Build the model from ``(args, g)``; note the swapped order vs. ``__init__``."""
        return cls(g, args)

    def __init__(self, g, args):
        super().__init__()
        self.g = g
        self.args = args
        self.in_dim = args.in_dim
        self.out_dim = args.out_dim
        # Embedding tables are allocated uninitialised here and filled in by
        # reset_parameters() below.
        self.entity_emb_matrix = nn.Parameter(th.FloatTensor(self.g.num_nodes(), self.in_dim))
        self.relation_emb_matrix = nn.Parameter(th.FloatTensor(args.n_relation, self.in_dim))
        self.user_emb_matrix = nn.Parameter(th.FloatTensor(args.n_user, self.in_dim))
        self.Aggregate = KGCN_Aggregate(args)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all three embedding tables uniformly in ``[-1, 1]``."""
        for table in (self.entity_emb_matrix, self.relation_emb_matrix, self.user_emb_matrix):
            nn.init.uniform_(table, -1, 1)

    def get_score(self):
        r"""
        Obtain scores using the final entity representation and the user representation.

        Reads ``self.userList`` and ``self.item_embeddings`` (both set by ``forward``)
        and stores the results on the instance:

        * ``self.user_embeddings`` -- rows of the user table selected by ``self.userList``
        * ``self.scores`` -- row-wise dot product of user and item embeddings
        * ``self.scores_normalized`` -- sigmoid of ``self.scores``

        Returns
        -------
        None
        """
        # Index with a long tensor on the parameter's device rather than a
        # NumPy array: same rows are selected, but an empty user list stays a
        # valid (integer) index instead of becoming a float64 array.
        user_index = th.as_tensor(
            self.userList, dtype=th.long, device=self.user_emb_matrix.device
        )
        self.user_embeddings = self.user_emb_matrix[user_index]
        self.scores = th.sum(self.user_embeddings * self.item_embeddings, dim=1)
        self.scores_normalized = th.sigmoid(self.scores)

    def get_embeddings(self):
        """Return the ``(user, entity, relation)`` embedding parameters."""
        return self.user_emb_matrix, self.entity_emb_matrix, self.relation_emb_matrix

    def forward(self, blocks, inputdata):
        r"""
        Score the (user, entity) pairs contained in ``inputdata``.

        Parameters
        ----------
        blocks : list
            Blocks save the information of neighbor nodes in each layer.
        inputdata : numpy.ndarray
            Relationship between users and entities. Column 0 is read as the
            user index, column 1 as the item and column 2 as the label.

        Returns
        -------
        labels : torch.Tensor
            Labels of the scored pairs, in the order produced by the aggregator,
            moved to ``args.device``.
        scores : torch.Tensor
            Raw dot-product scores between users and entities. The sigmoid of
            these scores (the click probability) is kept in
            ``self.scores_normalized``.
        """
        self.data = inputdata
        self.blocks = blocks
        self.user_indices = inputdata[:, 0]
        self.itemlist = inputdata[:, 1]
        # The aggregator returns users/labels re-ordered to line up with the
        # rows of item_embeddings, so labels are taken from its output.
        self.item_embeddings, self.userList, self.labelList = self.Aggregate(blocks, inputdata)
        self.get_score()
        self.labels = th.tensor(self.labelList).to(self.args.device)
        return self.labels, self.scores
class KGCN_Aggregate(nn.Module):
    r"""
    Aggregation module used by KGCN.

    Combines each destination node's own ``'embedding'`` with the weighted sum
    of its neighbors' embeddings (``'ft'``) and passes the result through a
    linear layer followed by a non-linearity.

    Parameters
    ----------
    args : Config
        Model's config. Reads ``in_dim``, ``out_dim`` and ``aggregate``
        (one of ``'SUM'``, ``'CONCAT'``, ``'NEIGHBOR'``).
    """

    # Aggregator names understood by ``aggregate``.
    _SUPPORTED_AGGREGATORS = ('SUM', 'CONCAT', 'NEIGHBOR')

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_dim = args.in_dim
        self.out_dim = args.out_dim
        # CONCAT feeds [entity ; neighborhood] into the linear layer, so its
        # input is twice as wide; every other mode keeps in_dim.
        linear_in = self.in_dim * 2 if self.args.aggregate == 'CONCAT' else self.in_dim
        self.agg = nn.Linear(linear_in, self.out_dim)

    def aggregate(self):
        r"""
        Run one aggregation step on ``self.sub_g`` (the current block).

        Populates ``self.userList``, ``self.labelList`` and
        ``self.item_embeddings`` and overwrites ``self.sub_g.dstdata['embedding']``.

        Raises
        ------
        ValueError
            If ``args.aggregate`` is not one of the supported aggregators.
        """
        mode = self.args.aggregate
        if mode not in self._SUPPORTED_AGGREGATORS:
            # Previously an unknown mode appended nothing and only surfaced
            # later as an opaque error from th.stack on an empty list.
            raise ValueError(
                "Unsupported aggregator %r; expected one of %s"
                % (mode, ', '.join(self._SUPPORTED_AGGREGATORS))
            )

        # Weighted sum of source embeddings over incoming edges -> dstdata['ft'].
        self.sub_g.update_all(fn.u_mul_e('embedding', 'weight', 'm'), fn.sum('m', 'ft'))

        # Loop-invariant lookups, fetched once instead of once per row.
        dst_frame = self.sub_g.dstdata
        dst_ids = dst_frame['_ID']
        dst_embedding = dst_frame['embedding']
        dst_neighbor = dst_frame['ft']

        self.userList = []
        self.labelList = []
        embeddingList = []
        # NOTE(review): assumes the first len(self.data) destination nodes each
        # match exactly one row of inputdata -- int(...) below fails otherwise.
        for i in range(len(self.data)):
            # Row(s) of inputdata whose item equals this destination node's ID.
            weightIndex = np.where(self.itemlist == int(dst_ids[i]))
            if mode == 'SUM':
                embeddingList.append(dst_embedding[i] + dst_neighbor[i][weightIndex])
            elif mode == 'CONCAT':
                embeddingList.append(
                    th.cat([dst_embedding[i], dst_neighbor[i][weightIndex].squeeze(0)], dim=-1)
                )
            elif mode == 'NEIGHBOR':
                # NOTE(review): this keeps the node's own embedding, while the
                # KGCN model documentation describes the neighbor aggregator as
                # using the neighborhood representation ('ft'). Behaviour is
                # left unchanged here -- confirm which one is intended.
                embeddingList.append(dst_embedding[i])
            self.userList.append(int(self.user_indices[weightIndex]))
            self.labelList.append(int(self.labels[weightIndex]))

        self.sub_g.dstdata['embedding'] = th.stack(embeddingList).squeeze(1)
        output = F.dropout(self.sub_g.dstdata['embedding'], p=0)
        # tanh on the last block, relu on all earlier ones.
        is_last_layer = self.layer + 1 == len(self.blocks)
        activation = th.tanh if is_last_layer else th.relu
        self.item_embeddings = activation(self.agg(output))

    def forward(self, blocks, inputdata):
        r"""
        Aggregate the entity representation and its neighborhood representation.

        Parameters
        ----------
        blocks : list
            Blocks save the information of neighbor nodes in each layer.
            Must contain at least one block.
        inputdata : numpy.ndarray
            Relationship between users and entities. Column 0 is read as the
            user index, column 1 as the item and column 2 as the label.

        Returns
        -------
        item_embeddings : torch.Tensor
            Items' embeddings after aggregation (result of the last block).
        userList : list
            Users corresponding to the items.
        labelList : list
            Labels corresponding to the items.

        Raises
        ------
        ValueError
            If ``blocks`` is empty; without any block there is nothing to
            aggregate and the returned attributes would be missing or stale.
        """
        if len(blocks) == 0:
            raise ValueError("blocks must contain at least one layer")

        self.data = inputdata
        self.blocks = blocks
        self.user_indices = inputdata[:, 0]
        self.itemlist = inputdata[:, 1]
        self.labels = inputdata[:, 2]
        for layer, block in enumerate(blocks):
            # aggregate() reads the current layer and block from the instance.
            self.layer = layer
            self.sub_g = block
            self.aggregate()
        return self.item_embeddings, self.userList, self.labelList