Source code for openhgnn.dataset.LinkPredictionDataset

import os.path

import dgl
import math
import re
from copy import deepcopy
import numpy as np
import torch as th
import itertools
import random
from random import shuffle, choice
from collections import Counter
from os.path import join as joinpath
from os.path import isfile
from dgl.data.knowledge_graph import load_data
from . import BaseDataset, register_dataset
from . import AcademicDataset, HGBDataset, OHGBDataset, NBF_Dataset
from ..utils import add_reverse_edges
from collections import defaultdict
import os
from scipy.sparse import csr_matrix

# Public names exported by ``from <this module> import *``.
__all__ = [
    "LinkPredictionDataset",
    "HGB_LinkPrediction",
]


@register_dataset('link_prediction')
class LinkPredictionDataset(BaseDataset):
    """
    Base class of the link-prediction datasets.

    metric: Accuracy, multi-label f1 or multi-class f1. Default: `accuracy`

    Attributes
    ----------
    target_link : list of canonical etypes or None
        Relations whose edges :meth:`get_split` divides into train/valid/test.
        Subclasses are expected to set it.
    target_link_r : list of canonical etypes or None
        Reverse relation for each entry of ``target_link`` (same order), or
        ``None`` when the graph has no reverse relations to keep in sync.
    """

    def __init__(self, *args, **kwargs):
        super(LinkPredictionDataset, self).__init__(*args, **kwargs)
        self.target_link = None
        self.target_link_r = None

    @staticmethod
    def _mask_to_index(mask):
        """Return a 1-D tensor with the positions where ``mask`` is non-zero.

        Unlike ``th.nonzero(mask).squeeze()`` the result is always 1-D, so a
        single selected edge does not collapse into a 0-d tensor (which
        ``th.cat`` refuses to concatenate).
        """
        return th.nonzero(mask.reshape(-1)).reshape(-1)

    def get_split(self, val_ratio=0.1, test_ratio=0.2):
        """
        Get subgraphs for train, valid and test.

        For every relation in ``self.target_link``:

        * If its edges have no ``train_mask``, the edges are shuffled and the
          first ``val_ratio`` fraction becomes validation, the next
          ``test_ratio`` fraction becomes test, and the rest stays in training.
          When ``self.target_link_r`` is set, the reverse relation is rebuilt
          from the remaining training edges.
        * Otherwise test edges come from ``test_mask``. Validation edges come
          from ``valid_mask`` if present, else a random ``val_ratio`` fraction
          of the ``train_mask`` edges is used.

        Parameters
        ----------
        val_ratio : float
            The ratio of validation. Default: 0.1
        test_ratio : float
            The ratio of test (only used when there is no ``train_mask``).
            Default: 0.2

        Returns
        -------
        tuple
            ``(train_graph, val_graph, test_graph, None, None)`` where
            ``val_graph``/``test_graph`` contain only the held-out edges of the
            target links.
        """
        val_edge_dict = {}
        test_edge_dict = {}
        out_ntypes = []
        train_graph = self.g
        for i, etype in enumerate(self.target_link):
            num_edges = self.g.num_edges(etype)
            edata = self.g.edges[etype].data
            if 'train_mask' not in edata:
                # No predefined split: shuffle all edges and carve out val/test.
                random_int = th.randperm(num_edges)
                val_end = int(num_edges * val_ratio)
                test_end = int(num_edges * (test_ratio + val_ratio))
                val_index = random_int[:val_end]
                test_index = random_int[val_end:test_end]
                val_edge_dict[etype] = self.g.find_edges(val_index, etype)
                test_edge_dict[etype] = self.g.find_edges(test_index, etype)
                out_ntypes.append(etype[0])
                out_ntypes.append(etype[2])
                train_graph = dgl.remove_edges(train_graph, th.cat((val_index, test_index)), etype)
                if self.target_link_r is not None:
                    # Rebuild the reverse relation from the remaining training
                    # edges so held-out edges cannot leak through it.
                    reverse_edge = self.target_link_r[i]
                    train_graph = dgl.remove_edges(
                        train_graph, th.arange(train_graph.num_edges(reverse_edge)), reverse_edge)
                    edges = train_graph.edges(etype=etype)
                    train_graph = dgl.add_edges(train_graph, edges[1], edges[0], etype=reverse_edge)
            else:
                if 'valid_mask' not in edata:
                    # Sample validation edges among the *training* edge IDs.
                    # Random positions must be mapped through train_index;
                    # using them directly as edge IDs could pick test edges.
                    train_index = self._mask_to_index(edata['train_mask'])
                    random_int = th.randperm(len(train_index))
                    val_index = train_index[random_int[:int(len(train_index) * val_ratio)]]
                else:
                    val_index = self._mask_to_index(edata['valid_mask'])
                val_edge = self.g.find_edges(val_index, etype)
                test_index = self._mask_to_index(edata['test_mask'])
                test_edge = self.g.find_edges(test_index, etype)
                val_edge_dict[etype] = val_edge
                test_edge_dict[etype] = test_edge
                out_ntypes.append(etype[0])
                out_ntypes.append(etype[2])
                # Keep the test labels when the relation carries them.
                label_data = train_graph.edges[etype[1]].data
                if 'label' in label_data:
                    self.test_label = label_data['label'][test_index]
                train_graph = dgl.remove_edges(train_graph, th.cat((val_index, test_index)), etype)
        self.out_ntypes = set(out_ntypes)
        num_nodes_dict = {ntype: self.g.number_of_nodes(ntype) for ntype in self.out_ntypes}
        val_graph = dgl.heterograph(val_edge_dict, num_nodes_dict)
        test_graph = dgl.heterograph(test_edge_dict, num_nodes_dict)
        # TODO: val/test negative graphs should be created before training rather
        # than dynamically in every evaluation.
        return train_graph, val_graph, test_graph, None, None
@register_dataset('demo_link_prediction') class Test_LinkPrediction(LinkPredictionDataset): def __init__(self, dataset_name): super(Test_LinkPrediction, self).__init__( ) self.g = self.load_HIN('./openhgnn/debug/data.bin') self.target_link = 'user-item' self.has_feature = False self.meta_paths_dict = None self.preprocess( ) # self.generate_negative() def preprocess(self): test_mask = self.g.edges[self.target_link].data['test_mask'] index = th.nonzero(test_mask).squeeze( ) self.test_edge = self.g.find_edges(index, self.target_link) self.pos_test_graph = dgl.heterograph({('user', 'user-item', 'item'): self.test_edge}, {ntype: self.g.number_of_nodes(ntype) for ntype in ['user', 'item']}) self.g.remove_edges(index, self.target_link) self.g.remove_edges(index, 'item-user') self.neg_test_graph, _ = dgl.load_graphs('./openhgnn/debug/neg.bin') self.neg_test_graph = self.neg_test_graph[0] return def generate_negative(self): k = 99 e = self.pos_test_graph.edges( ) neg_src = [] neg_dst = [] for i in range(self.pos_test_graph.number_of_edges( )): src = e[0][i] exp = self.pos_test_graph.successors(src) dst = th.randint(high=self.g.number_of_nodes('item'), size=(k,)) for d in range(len(dst)): while dst[d] in exp: dst[d] = th.randint(high=self.g.number_of_nodes('item'), size=(1,)) src = src.repeat_interleave(k) neg_src.append(src) neg_dst.append(dst) neg_edge = (th.cat(neg_src), th.cat(neg_dst)) neg_graph = dgl.heterograph({('user', 'user-item', 'item'): neg_edge}, {ntype: self.g.number_of_nodes(ntype) for ntype in ['user', 'item']}) dgl.save_graphs('./openhgnn/debug/neg.bin', neg_graph) @register_dataset('hin_link_prediction') class HIN_LinkPrediction(LinkPredictionDataset): def __init__(self, dataset_name, *args, **kwargs): super(HIN_LinkPrediction, self).__init__(*args, **kwargs) self.g = self.load_HIN(dataset_name) def load_link_pred(self, path): u_list = [] v_list = [] label_list = [] with open(path) as f: for i in f.readlines( ): u, v, label = i.strip( ).split(', ') 
u_list.append(int(u)) v_list.append(int(v)) label_list.append(int(label)) return u_list, v_list, label_list def load_HIN(self, dataset_name): self.dataset_name = dataset_name if dataset_name == 'academic4HetGNN': # which is used in HetGNN dataset = AcademicDataset(name='academic4HetGNN', raw_dir='') g = dataset[0].long( ) self.train_batch = self.load_link_pred('./openhgnn/dataset/' + dataset_name + '/a_a_list_train.txt') self.test_batch = self.load_link_pred('./openhgnn/dataset/' + dataset_name + '/a_a_list_test.txt') self.category = 'author' elif dataset_name == 'Book-Crossing': g, _ = dgl.load_graphs('./openhgnn/dataset/book_graph.bin') g = g[0] self.target_link = [('user', 'user-item', 'item')] self.node_type = ['user', 'item'] elif dataset_name == 'amazon4SLICE': dataset = AcademicDataset(name='amazon4SLICE', raw_dir='') g = dataset[0].long( ) # self.target_link = [('product', 'product-1-product', 'product'), # ('product', 'product-2-product', 'product')] self.target_link = [('product', 'product-1-product', 'product')] elif dataset_name == 'MTWM': dataset = AcademicDataset(name='MTWM', raw_dir='') g = dataset[0].long( ) g = add_reverse_edges(g) self.target_link = [('user', 'user-buy-spu', 'spu')] self.target_link_r = [('spu', 'user-buy-spu-rev', 'user')] self.meta_paths_dict = { 'UPU1': [('user', 'user-buy-poi', 'poi'), ('poi', 'user-buy-poi-rev', 'user')], 'UPU2': [('user', 'user-click-poi', 'poi'), ('poi', 'user-click-poi-rev', 'user')], 'USU': [('user', 'user-buy-spu', 'spu'), ('spu', 'user-buy-spu-rev', 'user')], 'UPSPU1': [('user', 'user-buy-poi', 'poi'), ('poi', 'poi-contain-spu', 'spu'), ('spu', 'poi-contain-spu-rev', 'poi'), ('poi', 'user-buy-poi-rev', 'user') ], 'UPSPU2': [ ('user', 'user-click-poi', 'poi'), ('poi', 'poi-contain-spu', 'spu'), ('spu', 'poi-contain-spu-rev', 'poi'), ('poi', 'user-click-poi-rev', 'user') ] } self.node_type = ['user', 'spu'] elif dataset_name == 'HGBl-ACM': dataset = HGBDataset(name='HGBn-ACM', raw_dir='') g = 
dataset[0].long( ) self.has_feature = True self.target_link = [('paper', 'paper-ref-paper', 'paper')] self.node_type = ['author', 'paper', 'subject', 'term'] self.target_link_r = [('paper', 'paper-cite-paper', 'paper')] self.meta_paths_dict = {'PAP': [('paper', 'paper-author', 'author'), ('author', 'author-paper', 'paper')], 'PSP': [('paper', 'paper-subject', 'subject'), ('subject', 'subject-paper', 'paper')], 'PcPAP': [('paper', 'paper-cite-paper', 'paper'), ('paper', 'paper-author', 'author'), ('author', 'author-paper', 'paper')], 'PcPSP': [('paper', 'paper-cite-paper', 'paper'), ('paper', 'paper-subject', 'subject'), ('subject', 'subject-paper', 'paper')], 'PrPAP': [('paper', 'paper-ref-paper', 'paper'), ('paper', 'paper-author', 'author'), ('author', 'author-paper', 'paper')], 'PrPSP': [('paper', 'paper-ref-paper', 'paper'), ('paper', 'paper-subject', 'subject'), ('subject', 'subject-paper', 'paper')] } elif dataset_name == 'HGBl-DBLP': dataset = HGBDataset(name='HGBn-DBLP', raw_dir='') g = dataset[0].long( ) self.has_feature = True self.target_link = [('author', 'author-paper', 'paper')] self.node_type = ['author', 'paper', 'venue', 'term'] self.target_link_r = [('paper', 'paper-author', 'author')] self.meta_paths_dict = {'APA': [('author', 'author-paper', 'paper'), ('paper', 'paper-author', 'author')], 'APTPA': [('author', 'author-paper', 'paper'), ('paper', 'paper-term', 'term'), ('term', 'term-paper', 'paper'), ('paper', 'paper-author', 'author')], 'APVPA': [('author', 'author-paper', 'paper'), ('paper', 'paper-venue', 'venue'), ('venue', 'venue-paper', 'paper'), ('paper', 'paper-author', 'author')], 'PAP': [('paper', 'paper-author', 'author'), ('author', 'author-paper', 'paper')], 'PTP': [('paper', 'paper-term', 'term'), ('term', 'term-paper', 'paper')], 'PVP': [('paper', 'paper-venue', 'venue'), ('venue', 'venue-paper', 'paper')], } elif dataset_name == 'HGBl-IMDB': dataset = HGBDataset(name='HGBn-IMDB', raw_dir='') g = dataset[0].long( ) self.has_feature 
= True # self.target_link = [('author', 'author-paper', 'paper')] # self.node_type = ['author', 'paper', 'subject', 'term'] # self.target_link_r = [('paper', 'paper-author', 'author')] self.target_link = [('actor', 'actor->movie', 'movie')] self.node_type = ['actor', 'director', 'keyword', 'movie'] self.target_link_r = [('movie', 'movie->actor', 'actor')] self.meta_paths_dict = { 'MAM': [('movie', 'movie->actor', 'actor'), ('actor', 'actor->movie', 'movie')], 'MDM': [('movie', 'movie->director', 'director'), ('director', 'director->movie', 'movie')], 'MKM': [('movie', 'movie->keyword', 'keyword'), ('keyword', 'keyword->movie', 'movie')], # 'DMD': [('director', 'director->movie', 'movie'), ('movie', 'movie->director', 'director')], # 'DMAMD': [('director', 'director->movie', 'movie'), ('movie', 'movie->actor', 'actor'), # ('actor', 'actor->movie', 'movie'), ('movie', 'movie->director', 'director')], 'AMA': [('actor', 'actor->movie', 'movie'), ('movie', 'movie->actor', 'actor')], 'AMDMA': [('actor', 'actor->movie', 'movie'), ('movie', 'movie->director', 'director'), ('director', 'director->movie', 'movie'), ('movie', 'movie->actor', 'actor')] } return g def get_split(self, val_ratio=0.1, test_ratio=0.2): if self.dataset_name == 'academic4HetGNN': return None, None, None, None, None else: return super(HIN_LinkPrediction, self).get_split(val_ratio, test_ratio) @register_dataset('HGBl_link_prediction') class HGB_LinkPrediction(LinkPredictionDataset): r""" The HGB dataset will be used in task *link prediction*. Dataset Name : HGBn-amazon/HGBn-LastFM/HGBn-PubMed So if you want to get more information, refer to `HGB datasets <https://github.com/THUDM/HGB>`_ Attributes ----------- has_feature : bool Whether the dataset has feature. Except HGBl-LastFM, others have features. target_link : list of tuple[canonical_etypes] The etypes of test link. HGBl-amazon has two etypes of test link. other has only one. 
""" def __init__(self, dataset_name, *args, **kwargs): super(HGB_LinkPrediction, self).__init__(*args, **kwargs) self.dataset_name = dataset_name self.target_link_r = None if dataset_name == 'HGBl-amazon': dataset = HGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.has_feature = False self.target_link = [('product', 'product-product-0', 'product'), ('product', 'product-product-1', 'product')] self.target_link_r = None self.link = [0, 1] self.node_type = ["product"] self.test_edge_type = {'product-product-0': 0, 'product-product-1': 1} self.meta_paths_dict = { 'P0P': [('product', 'product-product-0', 'product'), ('product', 'product-product-1', 'product')], 'P1P': [('product', 'product-product-1', 'product'), ('product', 'product-product-0', 'product')] } elif dataset_name == 'HGBl-LastFM': dataset = HGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.has_feature = False self.target_link = [('user', 'user-artist', 'artist')] self.node_type = ['user', 'artist', 'tag'] self.test_edge_type = {'user-artist': 0} g = add_reverse_edges(g) self.target_link_r = [('artist', 'user-artist-rev', 'user')] self.meta_paths_dict = {'UU': [('user', 'user-user', 'user')], 'UAU': [('user', 'user-artist', 'artist'), ('artist', 'user-artist-rev', 'user')], 'UATAU': [('user', 'user-artist', 'artist'), ('artist', 'artist-tag', 'tag'), ('tag', 'artist-tag-rev', 'artist'), ('artist', 'user-artist-rev', 'user')], 'AUA': [('artist', 'user-artist-rev', 'user'), ('user', 'user-artist', 'artist')], 'ATA': [('artist', 'artist-tag', 'tag'), ('tag', 'artist-tag-rev', 'artist')] } elif dataset_name == 'HGBl-PubMed': dataset = HGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.has_feature = True self.target_link = [('1', '1_to_1', '1')] self.node_type = ['0', '1', '2', '3'] self.test_edge_type = {'1_to_1': 2} g = add_reverse_edges(g) self.target_link_r = [('1', '1_to_1-rev', '1')] self.meta_paths_dict = {'101': [('1', '0_to_1-rev', '0'), ('0', 
'0_to_1', '1')], '111': [('1', '1_to_1', '1'), ('1', '1_to_1-rev', '1')], '121': [('1', '2_to_1-rev', '2'), ('2', '2_to_1', '1')], '131': [('1', '3_to_1-rev', '3'), ('3', '3_to_1', '1')] } self.g = g self.shift_dict = self.calculate_node_shift( ) def load_link_pred(self, path): return def calculate_node_shift(self): node_shift_dict = {} count = 0 for type in self.node_type: node_shift_dict[type] = count count += self.g.num_nodes(type) return node_shift_dict def get_split(self): r""" Get graphs for train, valid or test. The dataset has not validation_mask, so we split train edges randomly. """ val_edge_dict = {} test_edge_dict = {} out_ntypes = [] train_graph = self.g val_ratio = 0.1 for i, etype in enumerate(self.target_link): train_mask = self.g.edges[etype].data['train_mask'].squeeze( ) train_index = th.nonzero(train_mask).squeeze( ) random_int = th.randperm(len(train_index))[:int(len(train_index) * val_ratio)] val_index = train_index[random_int] val_edge = self.g.find_edges(val_index, etype) test_mask = self.g.edges[etype].data['test_mask'].squeeze( ) test_index = th.nonzero(test_mask).squeeze( ) test_edge = self.g.find_edges(test_index, etype) val_edge_dict[etype] = val_edge test_edge_dict[etype] = test_edge out_ntypes.append(etype[0]) out_ntypes.append(etype[2]) train_graph = dgl.remove_edges(train_graph, th.cat((val_index, test_index)), etype) if self.target_link_r is None: pass else: train_graph = dgl.remove_edges(train_graph, th.cat((val_index, test_index)), self.target_link_r[i]) self.out_ntypes = set(out_ntypes) val_graph = dgl.heterograph(val_edge_dict, {ntype: self.g.number_of_nodes(ntype) for ntype in set(out_ntypes)}) test_graph = dgl.heterograph(test_edge_dict, {ntype: self.g.number_of_nodes(ntype) for ntype in set(out_ntypes)}) return train_graph, val_graph, test_graph, None, None def save_results(self, hg, score, file_path): with hg.local_scope( ): src_list = [] dst_list = [] edge_type_list = [] for etype in hg.canonical_etypes: edges = 
hg.edges(etype=etype) src_id = edges[0] + self.shift_dict[etype[0]] dst_id = edges[1] + self.shift_dict[etype[2]] src_list.append(src_id) dst_list.append(dst_id) edge_type_list.append(th.full((src_id.shape[0],), self.test_edge_type[etype[1]])) src_list = th.cat(src_list) dst_list = th.cat(dst_list) edge_type_list = th.cat(edge_type_list) with open(file_path, "w") as f: for l, r, edge_type, c in zip(src_list, dst_list, edge_type_list, score): f.write(f"{l}\t{r}\t{edge_type}\t{round(float(c), 4)}\n") @register_dataset('ohgb_link_prediction') class OHGB_LinkPrediction(LinkPredictionDataset): def __init__(self, dataset_name, *args, **kwargs): super(OHGB_LinkPrediction, self).__init__(*args, **kwargs) self.dataset_name = dataset_name self.has_feature = True if dataset_name == 'ohgbl-MTWM': dataset = OHGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.target_link = [('user', 'user-buy-spu', 'spu')] self.target_link_r = [('spu', 'user-buy-spu-rev', 'user')] self.node_type = ['user', 'spu'] elif dataset_name == 'ohgbl-yelp1': dataset = OHGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.target_link = [('user', 'user-buy-business', 'business')] self.target_link_r = [('business', 'user-buy-business-rev', 'user')] elif dataset_name == 'ohgbl-yelp2': dataset = OHGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.target_link = [('business', 'described-with', 'phrase')] self.target_link_r = [('business', 'described-with-rev', 'phrase')] elif dataset_name == 'ohgbl-Freebase': dataset = OHGBDataset(name=dataset_name, raw_dir='') g = dataset[0].long( ) self.target_link = [('BOOK', 'BOOK-and-BOOK', 'BOOK')] self.target_link_r = [('BOOK', 'BOOK-and-BOOK-rev', 'BOOK')] self.g = g def build_graph_from_triplets(num_nodes, num_rels, triplets): """ Create a DGL graph. The graph is bidirectional because RGCN authors use reversed relations. 
This function also generates edge type and normalization factor (reciprocal of node incoming degree) """ g = dgl.graph(([], [])) g.add_nodes(num_nodes) src, rel, dst = triplets src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) rel = np.concatenate((rel, rel + num_rels)) edges = sorted(zip(dst, src, rel)) dst, src, rel = np.array(edges).transpose( ) g.add_edges(src, dst) norm = comp_deg_norm(g) print("# nodes: {}, # edges: {}".format(num_nodes, len(src))) return g, rel.astype('int64'), norm.astype('int64') def comp_deg_norm(g): g = g.local_var( ) in_deg = g.in_degrees(range(g.number_of_nodes( ))).float( ).numpy( ) norm = 1.0 / in_deg norm[np.isinf(norm)] = 0 return norm @register_dataset('kg_sub_link_prediction') class KG_RedDataset(LinkPredictionDataset): def __init__(self, dataset_name, *args, **kwargs): super(KG_RedDataset, self).__init__(*args, **kwargs) self.trans_dir = os.path.join('openhgnn/dataset/data', dataset_name) self.ind_dir = self.trans_dir + '_ind' folder = os.path.exists(self.trans_dir) if not folder: os.makedirs(self.trans_dir) url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1.zip" response = requests.get(url) with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: myzip.extractall(self.trans_dir) print("--- download data ---") else: print("--- There is data! ---") folder = os.path.exists(self.ind_dir) if not folder: os.makedirs(self.ind_dir) # 下载数据 url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1_ind.zip" response = requests.get(url) with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: myzip.extractall(self.ind_dir) print("--- download data ---") else: print("--- There is data! 
---") with open(os.path.join(self.trans_dir, 'entities.txt')) as f: self.entity2id = dict() for line in f: entity, eid = line.strip().split() self.entity2id[entity] = int(eid) with open(os.path.join(self.trans_dir, 'relations.txt')) as f: self.relation2id = dict() id2relation = [] for line in f: relation, rid = line.strip().split() self.relation2id[relation] = int(rid) id2relation.append(relation) with open(os.path.join(self.ind_dir, 'entities.txt')) as f: self.entity2id_ind = dict() for line in f: entity, eid = line.strip().split() self.entity2id_ind[entity] = int(eid) for i in range(len(self.relation2id)): id2relation.append(id2relation[i] + '_inv') id2relation.append('idd') self.id2relation = id2relation self.n_ent = len(self.entity2id) self.n_rel = len(self.relation2id) self.n_ent_ind = len(self.entity2id_ind) self.tra_train = self.read_triples(self.trans_dir, 'train.txt') self.tra_valid = self.read_triples(self.trans_dir, 'valid.txt') self.tra_test = self.read_triples(self.trans_dir, 'test.txt') self.ind_train = self.read_triples(self.ind_dir, 'train.txt', 'inductive') self.ind_valid = self.read_triples(self.ind_dir, 'valid.txt', 'inductive') self.ind_test = self.read_triples(self.ind_dir, 'test.txt', 'inductive') self.val_filters = self.get_filter('valid') self.tst_filters = self.get_filter('test') for filt in self.val_filters: self.val_filters[filt] = list(self.val_filters[filt]) for filt in self.tst_filters: self.tst_filters[filt] = list(self.tst_filters[filt]) self.tra_KG, self.tra_sub = self.load_graph(self.tra_train) self.ind_KG, self.ind_sub = self.load_graph(self.ind_train, 'inductive') self.tra_train = np.array(self.tra_valid) self.tra_val_qry, self.tra_val_ans = self.load_query(self.tra_test) self.ind_val_qry, self.ind_val_ans = self.load_query(self.ind_valid) self.ind_tst_qry, self.ind_tst_ans = self.load_query(self.ind_test) self.valid_q, self.valid_a = self.tra_val_qry, self.tra_val_ans self.test_q, self.test_a = self.ind_val_qry + 
self.ind_tst_qry, self.ind_val_ans + self.ind_tst_ans self.n_train = len(self.tra_train) self.n_valid = len(self.valid_q) self.n_test = len(self.test_q) def read_triples(self, directory, filename, mode='transductive'): triples = [] with open(os.path.join(directory, filename)) as f: for line in f: h, r, t = line.strip().split() if mode == 'transductive': h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] else: h, r, t = self.entity2id_ind[h], self.relation2id[r], self.entity2id_ind[t] triples.append([h, r, t]) triples.append([t, r + self.n_rel, h]) return triples def load_graph(self, triples, mode='transductive'): n_ent = self.n_ent if mode == 'transductive' else self.n_ent_ind KG = np.array(triples) idd = np.concatenate([np.expand_dims(np.arange(n_ent), 1), 2 * self.n_rel * np.ones((n_ent, 1)), np.expand_dims(np.arange(n_ent), 1)], 1) KG = np.concatenate([KG, idd], 0) n_fact = KG.shape[0] M_sub = csr_matrix((np.ones((n_fact,)), (np.arange(n_fact), KG[:, 0])), shape=(n_fact, n_ent)) return KG, M_sub def load_query(self, triples): triples.sort(key=lambda x: (x[0], x[1])) trip_hr = defaultdict(lambda: list()) for trip in triples: h, r, t = trip trip_hr[(h, r)].append(t) queries = [] answers = [] for key in trip_hr: queries.append(key) answers.append(np.array(trip_hr[key])) return queries, answers def get_neighbors(self, nodes, mode='transductive'): # nodes: n_node x 2 with (batch_idx, node_idx) if mode == 'transductive': KG = self.tra_KG M_sub = self.tra_sub n_ent = self.n_ent else: KG = self.ind_KG M_sub = self.ind_sub n_ent = self.n_ent_ind node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(n_ent, nodes.shape[0])) edge_1hot = M_sub.dot(node_1hot) edges = np.nonzero(edge_1hot) sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], axis=1) # (batch_idx, head, rela, tail) sampled_edges = th.LongTensor(sampled_edges) # index to nodes head_nodes, head_index = th.unique(sampled_edges[:, [0, 1]], dim=0, 
sorted=True, return_inverse=True) tail_nodes, tail_index = th.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) mask = sampled_edges[:, 2] == (self.n_rel * 2) _, old_idx = head_index[mask].sort() old_nodes_new_idx = tail_index[mask][old_idx] sampled_edges = th.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) return tail_nodes, sampled_edges, old_nodes_new_idx def get_batch(self, batch_idx, steps=2, data='train'): if data == 'train': return self.tra_train[batch_idx] if data == 'valid': # print(self.) query, answer = np.array(self.valid_q), self.valid_a # np.array(self.valid_a) n_ent = self.n_ent if data == 'test': query, answer = np.array(self.test_q), self.test_a # np.array(self.test_a) n_ent = self.n_ent_ind subs = [] rels = [] objs = [] subs = query[batch_idx, 0] rels = query[batch_idx, 1] objs = np.zeros((len(batch_idx), n_ent)) for i in range(len(batch_idx)): objs[i][answer[batch_idx[i]]] = 1 return subs, rels, objs def shuffle_train(self, ): rand_idx = np.random.permutation(self.n_train) self.tra_train = self.tra_train[rand_idx] def get_filter(self, data='valid'): filters = defaultdict(lambda: set()) if data == 'valid': for triple in self.tra_train: h, r, t = triple filters[(h, r)].add(t) for triple in self.tra_valid: h, r, t = triple filters[(h, r)].add(t) for triple in self.tra_test: h, r, t = triple filters[(h, r)].add(t) else: for triple in self.ind_train: h, r, t = triple filters[(h, r)].add(t) for triple in self.ind_valid: h, r, t = triple filters[(h, r)].add(t) for triple in self.ind_test: h, r, t = triple filters[(h, r)].add(t) return filters @register_dataset('kg_subT_link_prediction') class KG_RedTDataset(LinkPredictionDataset): def __init__(self, dataset_name, *args, **kwargs): super(KG_RedTDataset, self).__init__(*args, **kwargs) self.task_dir = os.path.join('openhgnn/dataset/data', dataset_name) task_dir = self.task_dir folder = os.path.exists(self.task_dir) if not folder: os.makedirs(self.task_dir) 
url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/family.zip" response = requests.get(url) with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: myzip.extractall(self.task_dir) print("--- download data ---") else: print("--- There is data! ---") with open(os.path.join(task_dir, 'entities.txt')) as f: self.entity2id = dict() n_ent = 0 for line in f: entity = line.strip() self.entity2id[entity] = n_ent n_ent += 1 with open(os.path.join(task_dir, 'relations.txt')) as f: self.relation2id = dict() n_rel = 0 for line in f: relation = line.strip() self.relation2id[relation] = n_rel n_rel += 1 self.n_ent = n_ent self.n_rel = n_rel self.filters = defaultdict(lambda: set()) self.fact_triple = self.read_triples('facts.txt') self.train_triple = self.read_triples('train.txt') self.valid_triple = self.read_triples('valid.txt') self.test_triple = self.read_triples('test.txt') self.fact_data = self.double_triple(self.fact_triple) self.train_data = np.array(self.double_triple(self.train_triple)) self.valid_data = self.double_triple(self.valid_triple) self.test_data = self.double_triple(self.test_triple) self.load_graph(self.fact_data) self.load_test_graph(self.double_triple(self.fact_triple) + self.double_triple(self.train_triple)) self.valid_q, self.valid_a = self.load_query(self.valid_data) self.test_q, self.test_a = self.load_query(self.test_data) self.n_train = len(self.train_data) self.n_valid = len(self.valid_q) self.n_test = len(self.test_q) for filt in self.filters: self.filters[filt] = list(self.filters[filt]) print('n_train:', self.n_train, 'n_valid:', self.n_valid, 'n_test:', self.n_test) def read_triples(self, filename): triples = [] with open(os.path.join(self.task_dir, filename)) as f: for line in f: h, r, t = line.strip().split() h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] triples.append([h, r, t]) self.filters[(h, r)].add(t) self.filters[(t, r + self.n_rel)].add(h) return triples def double_triple(self, triples): 
new_triples = [] for triple in triples: h, r, t = triple new_triples.append([t, r + self.n_rel, h]) return triples + new_triples def load_graph(self, triples): idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), np.expand_dims(np.arange(self.n_ent), 1)], 1) self.KG = np.concatenate([np.array(triples), idd], 0) self.n_fact = len(self.KG) self.M_sub = csr_matrix((np.ones((self.n_fact,)), (np.arange(self.n_fact), self.KG[:, 0])), shape=(self.n_fact, self.n_ent)) def load_test_graph(self, triples): idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), np.expand_dims(np.arange(self.n_ent), 1)], 1) self.tKG = np.concatenate([np.array(triples), idd], 0) self.tn_fact = len(self.tKG) self.tM_sub = csr_matrix((np.ones((self.tn_fact,)), (np.arange(self.tn_fact), self.tKG[:, 0])), shape=(self.tn_fact, self.n_ent)) def load_query(self, triples): triples.sort(key=lambda x: (x[0], x[1])) trip_hr = defaultdict(lambda: list()) for trip in triples: h, r, t = trip trip_hr[(h, r)].append(t) queries = [] answers = [] for key in trip_hr: queries.append(key) answers.append(np.array(trip_hr[key])) return queries, answers def get_neighbors(self, nodes, mode='train'): if mode == 'train': KG = self.KG M_sub = self.M_sub else: KG = self.tKG M_sub = self.tM_sub # nodes: n_node x 2 with (batch_idx, node_idx) node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) edge_1hot = M_sub.dot(node_1hot) edges = np.nonzero(edge_1hot) sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], axis=1) # (batch_idx, head, rela, tail) sampled_edges = torch.LongTensor(sampled_edges).cuda() # index to nodes head_nodes, head_index = torch.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) tail_nodes, tail_index = torch.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) sampled_edges = 
torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1)
        # Select the edges whose relation column equals 2 * n_rel
        # (presumably the self-loop relation added by the loader -- TODO confirm).
        mask = sampled_edges[:, 2] == (self.n_rel * 2)
        _, old_idx = head_index[mask].sort()
        old_nodes_new_idx = tail_index[mask][old_idx]
        return tail_nodes, sampled_edges, old_nodes_new_idx

    def get_batch(self, batch_idx, steps=2, data='train'):
        """Return one batch of training triples or evaluation queries.

        For ``data='train'`` the selected rows of ``self.train_data`` are returned
        directly. For ``'valid'`` / ``'test'`` the result is ``(subs, rels, objs)``:
        columns 0 and 1 of the selected queries, plus a
        ``(len(batch_idx), self.n_ent)`` 0/1 matrix marking the answer entities
        of each query.

        NOTE(review): ``steps`` is unused, and any other ``data`` value leaves
        ``query``/``answer`` unbound (NameError).
        """
        if data == 'train':
            return np.array(self.train_data)[batch_idx]
        if data == 'valid':
            query, answer = np.array(self.valid_q), self.valid_a
        if data == 'test':
            query, answer = np.array(self.test_q), self.test_a

        # These placeholders are overwritten immediately below.
        subs = []
        rels = []
        objs = []

        subs = query[batch_idx, 0]
        rels = query[batch_idx, 1]
        objs = np.zeros((len(batch_idx), self.n_ent))
        for i in range(len(batch_idx)):
            objs[i][answer[batch_idx[i]]] = 1
        return subs, rels, objs

    def shuffle_train(self, ):
        """Re-shuffle fact + train triples and re-split them into facts (first 3/4) and train (last 1/4).

        Both parts go through ``self.double_triple``. ``self.n_train`` is updated and the
        fact graph is rebuilt with ``self.load_graph``.
        """
        fact_triple = np.array(self.fact_triple)
        train_triple = np.array(self.train_triple)
        all_triple = np.concatenate([fact_triple, train_triple], axis=0)
        n_all = len(all_triple)
        rand_idx = np.random.permutation(n_all)
        all_triple = all_triple[rand_idx]

        # increase the ratio of fact_data, e.g., 3/4->4/5, can increase the performance
        self.fact_data = self.double_triple(all_triple[:n_all * 3 // 4].tolist())
        self.train_data = np.array(self.double_triple(all_triple[n_all * 3 // 4:].tolist()))
        self.n_train = len(self.train_data)
        self.load_graph(self.fact_data)


@register_dataset('kg_link_prediction')
class KG_LinkPrediction(LinkPredictionDataset):
    """
    Knowledge-graph link prediction dataset.

    Loads ``wn18``, ``FB15k`` or ``FB15k-237`` with ``dgl.data.knowledge_graph.load_data``.
    It builds one heterograph per split, each with the single node type ``'_N'`` and one
    edge type ``('_N', str(rel_id), '_N')`` per relation.

    From `RGCN <https://arxiv.org/abs/1703.06103>`_, WN18 & FB15k face a data leakage.
    """

    def __init__(self, dataset_name, *args, **kwargs):
        super(KG_LinkPrediction, self).__init__(*args, **kwargs)
        # NOTE(review): for any other dataset_name none of the graph attributes below are set.
        if dataset_name in ['wn18', 'FB15k', 'FB15k-237']:
            dataset = load_data(dataset_name)
            g = dataset[0]
            self.num_rels = dataset.num_rels
            self.num_nodes = dataset.num_nodes
            self.train_hg, self.train_triplets = self._build_hg(g, 'train')
            self.valid_hg, self.valid_triplets = self._build_hg(g, 'valid')
            self.test_hg, self.test_triplets = self._build_hg(g, 'test')
            self.g = self.train_hg
            self.category = '_N'
            self.target_link = self.test_hg.canonical_etypes

    def _build_hg(self, g, mode):
        """Build the heterograph and the triplet tensor for one split.

        Keeps the edges of ``g`` selected by ``g.edata[mode + '_edge_mask']``. Node ids are
        not relabeled. Edges are grouped by ``edata['etype']`` into ``('_N', str(i), '_N')``.

        Returns
        -------
        (hg, triplets), where ``triplets`` holds one (src, etype, dst) triple per row.
        """
        sub_g = dgl.edge_subgraph(g, g.edata[mode + '_edge_mask'], relabel_nodes=False)
        src, dst = sub_g.edges( )
        etype = sub_g.edata['etype']

        edge_dict = {}
        for i in range(self.num_rels):
            mask = (etype == i)
            edge_name = ('_N', str(i), '_N')
            edge_dict[edge_name] = (src[mask], dst[mask])
        hg = dgl.heterograph(edge_dict, {'_N': self.num_nodes})

        # stack gives (3, E); .T turns it into one triple per row.
        return hg, th.stack((src, etype, dst)).T

    def modify_size(self, eval_percent, dataset_type):
        """Randomly subsample the valid or test triplets in place.

        Keeps ``ceil(N * eval_percent)`` rows drawn with ``random.sample``. Any other
        ``dataset_type`` is a no-op.
        """
        if dataset_type == 'valid':
            self.valid_triplets = th.tensor(
                random.sample(self.valid_triplets.tolist( ), math.ceil(self.valid_triplets.shape[0] * eval_percent)))
        elif dataset_type == 'test':
            self.test_triplets = th.tensor(
                random.sample(self.test_triplets.tolist( ), math.ceil(self.test_triplets.shape[0] * eval_percent)))

    def get_graph_directed_from_triples(self, triples, format='graph'):
        """Build a heterograph from a triple array.

        Columns 0, 1 and 2 of ``triples`` are the subject, relation and object.
        Only ``format='graph'`` is handled; any other value returns None.
        """
        s = th.LongTensor(triples[:, 0])
        r = th.LongTensor(triples[:, 1])
        o = th.LongTensor(triples[:, 2])
        if format == 'graph':
            edge_dict = {}
            for i in range(self.num_rels):
                mask = (r == i)
                edge_name = (self.category, str(i), self.category)
                edge_dict[edge_name] = (s[mask], o[mask])
            return dgl.heterograph(edge_dict, {self.category: self.num_nodes})

    def get_triples(self, g, mask_mode):
        '''
        Return the triples of ``g`` selected by a boolean edge mask.

        :param g: homogeneous graph carrying ``edata['etype']``
        :param mask_mode: name of the edge-data mask; should be one of 'train_mask', 'val_mask', 'test_mask'.
            The mask is popped (removed) from ``g.edata`` as a side effect.
        :return: tensor stacking (src, etype, dst), one component per row
        '''
        edges = g.edges( )
        etype = g.edata['etype']
        mask = g.edata.pop(mask_mode)
        return th.stack((edges[0][mask], etype[mask], edges[1][mask]))

    def get_all_triplets(self, dataset):
        """Return ``dataset.train``, ``dataset.valid`` and ``dataset.test`` as LongTensors."""
        train_data = th.LongTensor(dataset.train)
        valid_data = th.LongTensor(dataset.valid)
        test_data = th.LongTensor(dataset.test)
        return train_data, valid_data, test_data

    def get_split(self):
        """Return ``(train_hg, valid_hg, test_hg, None, None)``; the last two slots are always None here."""
        return self.train_hg, self.valid_hg, self.test_hg, None, None

    def split_graph(self, g, mode='train'):
        """
        Parameters
        ----------
        g: DGLGraph
            a homogeneous graph format
        mode: str
            split the subgraph according to the mode
        Returns
        -------
        hg: DGLHeterograph

        NOTE(review): 'train' reads ``edata['train_mask']``, while 'valid'/'test' read
        ``'*_edge_mask'`` (and ``_build_hg`` uses ``'train_edge_mask'``). Any other mode
        leaves ``mask`` unbound.
        """
        edges = g.edges( )
        etype = g.edata['etype']
        if mode == 'train':
            mask = g.edata['train_mask']
        elif mode == 'valid':
            mask = g.edata['valid_edge_mask']
        elif mode == 'test':
            mask = g.edata['test_edge_mask']
        hg = self.build_graph((edges[0][mask], edges[1][mask]), etype[mask])
        return hg

    def build_graph(self, edges, etype):
        """Group ``(src, dst)`` edges by relation id into a ``self.category`` heterograph."""
        edge_dict = {}
        for i in range(self.num_rels):
            mask = (etype == i)
            edge_name = (self.category, str(i), self.category)
            edge_dict[edge_name] = (edges[0][mask], edges[1][mask])
        hg = dgl.heterograph(edge_dict, {self.category: self.num_nodes})
        return hg

    def build_g(self, train):
        """Build a ``self.category`` heterograph from (subject, relation, object) rows of ``train``."""
        s = train[:, 0]
        r = train[:, 1]
        o = train[:, 2]
        edge_dict = {}
        for i in range(self.num_rels):
            mask = (r == i)
            edge_name = (self.category, str(i), self.category)
            edge_dict[edge_name] = (th.LongTensor(s[mask]), th.LongTensor(o[mask]))
        hg = dgl.heterograph(edge_dict, {self.category: self.num_nodes})
        return hg


# Dependencies of the GraIL, kg_sampler and ExpressGNN code below.
import torch
import struct
import os
import json
import logging
from scipy.sparse import csc_matrix
from scipy.special import softmax
from tqdm import tqdm
import pickle
import scipy.sparse as ssp
import lmdb
import requests
import zipfile
import io
from torch.utils.data import Dataset
import networkx as nx
# Star import: presumably supplies the helpers used below that are not defined here
# (e.g. process_files, deserialize, ssp_multigraph_to_dgl, negative_sampling) -- TODO confirm.
from ..utils.Grail_utils import *


class SubGraphDataset(Dataset):
    """Torch ``Dataset`` over GraIL enclosing subgraphs stored in an LMDB environment.

    Each item pairs one positive subgraph (from the ``db_name_pos`` sub-database) with
    ``num_neg_samples_per_link`` negative subgraphs (from ``db_name_neg``).
    """

    def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None,
                 add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='',
                 kge_model='', file_name=''):

        # Read-only LMDB environment; positives and negatives live in separate named databases.
        self.main_env = lmdb.open(db_path, readonly= True, max_dbs=3, lock=False)
        self.db_pos = self.main_env.open_db(db_name_pos.encode())
        self.db_neg = self.main_env.open_db(db_name_neg.encode())
        self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model) if use_kge_embeddings else (None, None)
        self.num_neg_samples_per_link = num_neg_samples_per_link
        self.file_name = file_name
        self.add_traspose_rels = add_traspose_rels
        ssp_graph, __, __, __, id2entity, id2relation = process_files(raw_data_paths, included_relations)
        self.num_rels = len(ssp_graph)

        # Add transpose matrices to handle both directions of relations.
        if add_traspose_rels:
            ssp_graph_t = [adj.T for adj in ssp_graph]
            ssp_graph += ssp_graph_t

        # the effective number of relations after adding symmetric adjacency matrices and/or self connections
        self.aug_num_rels = len(ssp_graph)
        self.graph = ssp_multigraph_to_dgl(ssp_graph)
        self.ssp_graph = ssp_graph
        self.id2entity = id2entity
        self.id2relation = id2relation

        # Dataset-wide statistics stored at the top level of the environment.
        # max_n_label = [max distance from sub, max distance from obj] (see the log message below).
        self.max_n_label = np.array([0, 0])
        with self.main_env.begin() as txn:
            #a = txn.get('max_n_label_sub'.encode())
            #print(a)
            self.max_n_label[0] = int.from_bytes(txn.get('max_n_label_sub'.encode()), byteorder='little')
            self.max_n_label[1] = int.from_bytes(txn.get('max_n_label_obj'.encode()), byteorder='little')

            # struct.unpack returns 1-tuples, so each statistic below is a tuple of one float.
            self.avg_subgraph_size = struct.unpack('f', txn.get('avg_subgraph_size'.encode()))
            self.min_subgraph_size = struct.unpack('f', txn.get('min_subgraph_size'.encode()))
            self.max_subgraph_size = struct.unpack('f', txn.get('max_subgraph_size'.encode()))
            self.std_subgraph_size = struct.unpack('f', txn.get('std_subgraph_size'.encode()))

            self.avg_enc_ratio = struct.unpack('f', txn.get('avg_enc_ratio'.encode()))
            self.min_enc_ratio = struct.unpack('f', txn.get('min_enc_ratio'.encode()))
            self.max_enc_ratio = struct.unpack('f', txn.get('max_enc_ratio'.encode()))
            self.std_enc_ratio = struct.unpack('f', txn.get('std_enc_ratio'.encode()))

            self.avg_num_pruned_nodes = struct.unpack('f', txn.get('avg_num_pruned_nodes'.encode()))
            self.min_num_pruned_nodes = struct.unpack('f', txn.get('min_num_pruned_nodes'.encode()))
            self.max_num_pruned_nodes = struct.unpack('f', txn.get('max_num_pruned_nodes'.encode()))
            self.std_num_pruned_nodes = struct.unpack('f', txn.get('std_num_pruned_nodes'.encode()))

        logging.info(f"Max distance from sub : {self.max_n_label[0]}, Max distance from obj : {self.max_n_label[1]}")

        # logging.info('=====================')
        # logging.info(f"Subgraph size stats: \n Avg size {self.avg_subgraph_size}, \n Min size {self.min_subgraph_size}, \n Max size {self.max_subgraph_size}, \n Std {self.std_subgraph_size}")

        # logging.info('=====================')
        # logging.info(f"Enclosed nodes ratio stats: \n Avg size {self.avg_enc_ratio}, \n Min size {self.min_enc_ratio}, \n Max size {self.max_enc_ratio}, \n Std {self.std_enc_ratio}")

        # logging.info('=====================')
        # logging.info(f"# of pruned nodes stats: \n Avg size {self.avg_num_pruned_nodes}, \n Min size {self.min_num_pruned_nodes}, \n Max size {self.max_num_pruned_nodes}, \n Std {self.std_num_pruned_nodes}")

        with self.main_env.begin(db=self.db_pos) as txn:
            self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')
        with self.main_env.begin(db=self.db_neg) as txn:
            self.num_graphs_neg = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little')

        # The result is discarded, but building item 0 sets self.n_feat_dim
        # (via _prepare_features_new) as a side effect.
        self.__getitem__(0)

    def __getitem__(self, index):
        """Return ``(subgraph_pos, g_label_pos, r_label_pos, subgraphs_neg, g_labels_neg, r_labels_neg)``.

        Keys are zero-padded 8-digit ASCII indices. The i-th negative for ``index`` is
        stored under ``index + i * num_graphs_pos`` in the negative database.
        """
        with self.main_env.begin(db=self.db_pos) as txn:
            str_id = '{:08}'.format(index).encode('ascii')
            nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values()
            subgraph_pos = self._prepare_subgraphs(nodes_pos, r_label_pos, n_labels_pos)
        subgraphs_neg = []
        r_labels_neg = []
        g_labels_neg = []
        with self.main_env.begin(db=self.db_neg) as txn:
            for i in range(self.num_neg_samples_per_link):
                str_id = '{:08}'.format(index + i * (self.num_graphs_pos)).encode('ascii')
                nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values()
                subgraphs_neg.append(self._prepare_subgraphs(nodes_neg, r_label_neg, n_labels_neg))
                r_labels_neg.append(r_label_neg)
                g_labels_neg.append(g_label_neg)

        return subgraph_pos, g_label_pos, r_label_pos, subgraphs_neg, g_labels_neg, r_labels_neg

    def __len__(self):
        """Number of positive samples."""
        return self.num_graphs_pos

    def _prepare_subgraphs(self, nodes, r_label, n_labels):
        """Build the subgraph induced by ``nodes`` and attach edge types, labels and node features.

        Every edge gets ``edata['label'] = r_label``. If there is no edge of type ``r_label``
        from local node 0 to local node 1 (the two roots, per ``edges_btw_roots``), one is added.
        """
        if not isinstance(self.graph, dgl.DGLGraph):
            subgraph = dgl.graph(self.graph.subgraph(nodes))
        else:
            subgraph = self.graph.subgraph(nodes)
        #subgraph.edata['type'] = self.graph.edata['type'][self.graph.subgraph(nodes).parent_eid]
        # Copy parent edge types using the parent edge ids DGL stores in edata[dgl.EID].
        subgraph.edata['type'] = self.graph.edata['type'][subgraph.edata[dgl.EID]]
        subgraph.edata['label'] = torch.tensor(r_label * np.ones(subgraph.edata['type'].shape), dtype=torch.long)
        #print("Please output: ")
        #print(subgraph)
        #edges_btw_roots = subgraph.edge_id(0, 1, return_array=True)
        #edges_btw_roots = subgraph.edge_ids(0, 1)
        edges_btw_roots = torch.tensor([])
        try:
            edges_btw_roots = subgraph.edge_ids(torch.tensor([0]),torch.tensor([1]))
            # edges_btw_roots = np.array([edges_btw_roots])
        # NOTE(review): bare except hides every error from edge_ids, not only "no such edge".
        # The fallback torch.tensor([]) is float-typed; indexing edata['type'] with it may fail -- TODO confirm.
        except:
            #print("Error")
            edges_btw_roots = torch.tensor([])
        edges_btw_roots = edges_btw_roots.numpy()
        rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == r_label)
        if rel_link.squeeze().nelement() == 0:
            subgraph = dgl.add_edges(subgraph, 0, 1)
            subgraph.edata['type'][-1] = torch.tensor(r_label).type(torch.LongTensor)
            subgraph.edata['label'][-1] = torch.tensor(r_label).type(torch.LongTensor)

        # map the id read by GraIL to the entity IDs as registered by the KGE embeddings
        kge_nodes = [self.kge_entity2id[self.id2entity[n]] for n in nodes] if self.kge_entity2id else None
        n_feats = self.node_features[kge_nodes] if self.node_features is not None else None
        subgraph = self._prepare_features_new(subgraph, n_labels, n_feats)

        return subgraph

    def _prepare_features(self, subgraph, n_labels, n_feats=None):
        # NOTE(review): not called here (_prepare_subgraphs uses _prepare_features_new).
        # As written, label_feats has max_n_label[0] + 1 columns, but the second assignment
        # indexes column max_n_label[0] + 1 + n_labels[:, 1], which is out of bounds.
        # `if n_feats` is also ambiguous for a multi-element ndarray.
        # One hot encode the node label feature and concat to n_featsure
        n_nodes = subgraph.number_of_nodes()
        label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1))
        label_feats[np.arange(n_nodes), n_labels] = 1
        label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1
        n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats else label_feats
        subgraph.ndata['feat'] = torch.FloatTensor(n_feats)
        self.n_feat_dim = n_feats.shape[1]  # Find cleaner way to do this -- i.e. set the n_feat_dim
        return subgraph

    def _prepare_features_new(self, subgraph, n_labels, n_feats=None):
        """Attach one-hot node-label features and head/tail markers to ``subgraph``.

        ``ndata['feat']`` concatenates a one-hot of ``n_labels[:, 0]`` (width
        ``max_n_label[0] + 1``), a one-hot of ``n_labels[:, 1]`` (width
        ``max_n_label[1] + 1``) and, when given, ``n_feats``. ``ndata['id']`` is 1 for
        nodes labelled (0, 1), 2 for nodes labelled (1, 0), and 0 otherwise.
        Also sets ``self.n_feat_dim``.
        """
        # One hot encode the node label feature and concat to n_featsure
        n_nodes = subgraph.number_of_nodes()
        label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1))
        label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1
        label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1
        # label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1))
        # label_feats[np.arange(n_nodes), 0] = 1
        # label_feats[np.arange(n_nodes), self.max_n_label[0] + 1] = 1
        n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats
        subgraph.ndata['feat'] = torch.FloatTensor(n_feats)

        head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels])
        tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels])
        n_ids = np.zeros(n_nodes)
        n_ids[head_id] = 1  # head
        n_ids[tail_id] = 2  # tail
        subgraph.ndata['id'] = torch.FloatTensor(n_ids)

        self.n_feat_dim = n_feats.shape[1]  # Find cleaner way to do this -- i.e. set the n_feat_dim
        return subgraph


@register_dataset('grail_link_prediction')
class Grail_LinkPrediction(LinkPredictionDataset):
    """GraIL link-prediction dataset.

    Makes sure ``./openhgnn/dataset/data/<dataset>`` and ``<dataset>_ind`` exist,
    downloading them otherwise. Generates the LMDB subgraph database when
    ``args.db_path`` is missing. Exposes ``self.train`` / ``self.valid`` as
    ``SubGraphDataset`` objects.
    """

    def __init__(self, dataset_name, *args, **kwargs):
        super(Grail_LinkPrediction, self).__init__(*args, **kwargs)
        self.args = kwargs['args']
        self.args.db_path = f'./openhgnn/dataset/data/{self.args.dataset}/subgraphs_en_{self.args.enclosing_sub_graph}_neg_{self.args.num_neg_samples_per_link}_hop_{self.args.hop}'
        self.args.train_file = "train"
        self.args.valid_file = "valid"
        self.args.file_paths = {
            'train': './openhgnn/dataset/data/{}/{}.txt'.format(self.args.dataset, self.args.train_file),
            'valid': './openhgnn/dataset/data/{}/{}.txt'.format(self.args.dataset, self.args.valid_file)
        }
        relation2id_path = f'./openhgnn/dataset/data/{self.args.dataset}/relation2id.json'
        self.data_folder = f'./openhgnn/dataset/data/{self.args.dataset}'
        if not os.path.exists(self.data_folder):
            os.makedirs(self.data_folder)  # makedirs also creates missing parent directories
            url = f'https://github.com/kkteru/grail/blob/master/data/{self.args.dataset}'
            self.download_folder(url,self.data_folder)
            print("--- download data ---")
        else:
            print("--- There is data! ---")
        if not os.path.exists(self.data_folder+'_ind'):
            os.makedirs(self.data_folder+'_ind')  # makedirs also creates missing parent directories
            url = f'https://github.com/kkteru/grail/blob/master/data/{self.args.dataset}_ind'
            self.download_folder(url,self.data_folder+'_ind')
            print("--- download data ---")
        else:
            print("--- There is data! ---")
        if not os.path.isdir(self.args.db_path):
            generate_subgraph_datasets(self.args, relation2id_path)
        with open(relation2id_path) as f:
            self.relation2id = json.load(f)
        self.train = SubGraphDataset(self.args.db_path, 'train_pos', 'train_neg', self.args.file_paths,add_traspose_rels=self.args.add_traspose_rels,num_neg_samples_per_link=self.args.num_neg_samples_per_link,use_kge_embeddings=self.args.use_kge_embeddings, dataset=self.args.dataset,kge_model=self.args.kge_model, file_name=self.args.train_file)
        self.valid = SubGraphDataset(self.args.db_path, 'valid_pos', 'valid_neg', self.args.file_paths,
                                     add_traspose_rels=self.args.add_traspose_rels,
                                     num_neg_samples_per_link=self.args.num_neg_samples_per_link,
                                     use_kge_embeddings=self.args.use_kge_embeddings, dataset=self.args.dataset,
                                     kge_model=self.args.kge_model, file_name= self.args.valid_file)

    def download_folder(self,url, save_path):
        """Download every file named in the body of ``url`` into ``save_path``.

        Each line of the response body is treated as a path; its last component becomes the
        file name, which is fetched from ``url + '/' + name``.

        NOTE(review): the callers pass a GitHub ``blob`` page URL, which most likely returns
        HTML rather than a file listing -- verify that this actually fetches the data files.
        Non-200 responses are silently ignored.
        """
        response = requests.get(url)
        if response.status_code == 200:
            # Make sure the save path exists
            os.makedirs(save_path, exist_ok=True)
            # Parse the response content
            content = response.content.decode('utf-8')
            lines = content.splitlines()
            for line in lines:
                # Extract the file name
                file_name = line.split('/')[-1]
                # Build the full URL of the file
                file_url = url + '/' + file_name
                # Build the local save path of the file
                file_save_path = os.path.join(save_path, file_name)
                # Download the file
                self.download_file(file_url, file_save_path)

    def download_file(self,url, save_path):
        """Write the body of ``url`` to ``save_path``; non-200 responses are silently ignored."""
        response = requests.get(url)
        if response.status_code == 200:
            with open(save_path, 'wb') as file:
                file.write(response.content)


class kg_sampler( ):
    """Builds sampled training graphs plus negative samples from a triplet array."""

    def __init__(self, ):
        # The strategy is fixed to 'uniform'; nothing in this class changes it.
        self.sampler = 'uniform'
        return

    def generate_sampled_graph_and_labels(self, triplets, sample_size, split_size, num_rels, adj_list, degrees,
                                          negative_rate, sampler="uniform"):
        """Get training graph and signals.

        First sample ``sample_size`` edges (uniformly or by neighborhood expansion, chosen
        by ``self.sampler``), then perform negative sampling to generate negative samples.
        Only ``int(sample_size * split_size)`` of the sampled edges are kept as graph structure.

        Returns
        -------
        (g, uniq_v, rel, norm, samples, labels)

        NOTE(review): the ``sampler`` argument is ignored; ``self.sampler`` decides.
        """
        # perform edge neighbor sampling
        if self.sampler == "uniform":
            edges = sample_edge_uniform(adj_list, degrees, len(triplets), sample_size)
        elif self.sampler == "neighbor":
            edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), sample_size)
        else:
            raise ValueError("Sampler type must be either 'uniform' or 'neighbor'.")

        # relabel nodes to have consecutive node ids
        edges = triplets[edges]
        src, rel, dst = edges.transpose( )
        uniq_v, edges = np.unique((src, dst), return_inverse=True)
        src, dst = np.reshape(edges, (2, -1))
        relabeled_edges = np.stack((src, rel, dst)).transpose( )

        # negative sampling
        samples, labels = negative_sampling(relabeled_edges, len(uniq_v), negative_rate)

        # further split graph: only a split_size fraction of the sampled edges is used as graph
        # structure, while the rest is used as unseen positive samples
        split_size = int(sample_size * split_size)
        graph_split_ids = np.random.choice(np.arange(sample_size), size=split_size, replace=False)
        src = src[graph_split_ids]
        dst = dst[graph_split_ids]
        rel = rel[graph_split_ids]

        # build DGL graph
        print("# sampled nodes: {}".format(len(uniq_v)))
        print("# sampled edges: {}".format(len(src) * 2))
        g, rel, norm = build_graph_from_triplets(len(uniq_v), num_rels, (src, rel, dst))
        return g, uniq_v, rel, norm, samples, labels


def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size):
    """Sample edges by neighborhood expansion.

    Each step picks a vertex already reached (weighted by its remaining degree) and takes
    one of its unpicked edges. This tends to keep the sample connected, which may help
    deeper GNNs that need information from more than one hop. When no reached vertex has
    remaining degree, sampling restarts from any vertex with remaining degree, so
    connectivity is not strictly guaranteed.

    NOTE(review): the `while picked[...]` retry loop never terminates if every edge of the
    chosen vertex is already picked -- presumably sample_counts rules this out; confirm.
    """
    edges = np.zeros((sample_size), dtype=np.int32)

    # initialize
    sample_counts = np.array([d for d in degrees])
    picked = np.array([False for _ in range(n_triplets)])
    seen = np.array([False for _ in degrees])

    for i in range(0, sample_size):
        weights = sample_counts * seen

        if np.sum(weights) == 0:
            weights = np.ones_like(weights)
            weights[np.where(sample_counts == 0)] = 0

        probabilities = (weights) / np.sum(weights)
        chosen_vertex = np.random.choice(np.arange(degrees.shape[0]), p=probabilities)
        chosen_adj_list = adj_list[chosen_vertex]
        seen[chosen_vertex] = True

        # Each adj_list row entry is read as (edge id, neighbour vertex).
        chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))
        chosen_edge = chosen_adj_list[chosen_edge]
        edge_number = chosen_edge[0]

        while picked[edge_number]:
            chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))
            chosen_edge = chosen_adj_list[chosen_edge]
            edge_number = chosen_edge[0]

        edges[i] = edge_number
        other_vertex = chosen_edge[1]
        picked[edge_number] = True
        sample_counts[chosen_vertex] -= 1
        sample_counts[other_vertex] -= 1
        seen[other_vertex] = True

    return edges


def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size):
    """Sample ``sample_size`` distinct edge ids uniformly from ``range(n_triplets)``.

    ``adj_list`` and ``degrees`` are unused; they keep the signature parallel to
    ``sample_edge_neighborhood``.
    """
    all_edges = np.arange(n_triplets)
    return np.random.choice(all_edges, sample_size, replace=False)


# --- ExpressGNN ---

# grounded rule stats code
BAD = 0  # sample not valid
FULL_OBSERVERED = 1  # sample valid, but rule contains only observed vars and does not have negation for all atoms
GOOD = 2  # sample valid


@register_dataset('express_gnn')
class ExpressGNNDataset(BaseDataset):
    """ExpressGNN dataset: facts, queries and weighted first-order logic rules.

    Data live under ``openhgnn/dataset/data/<dataset_name>`` and are downloaded as a zip
    if that folder is missing. Names starting with ``'EXP_FB15k-237'`` use the Freebase
    text-file loader (``load_method = 1``). Other names use ``self.preprocess_kinship``
    (``load_method = 0``).
    """

    def __init__(self, dataset_name, *args, **kwargs):
        super( ).__init__(*args, **kwargs)
        self.args = kwargs['args']
        self.PRED_DICT = {}
        self.dataset_name = dataset_name
        self.const_dict = ConstantDict()
        self.batchsize = self.args.batchsize
        self.shuffle_sampling = self.args.shuffle_sampling
        data_root = 'openhgnn'
        data_root = os.path.join(data_root, 'dataset')
        data_root = os.path.join(data_root, 'data')
        data_root = os.path.join(data_root, self.dataset_name)
        ext_rule_path = None
        folder = os.path.exists(data_root)
        print(data_root)
        print('folder', folder)
        if not folder:  # the folder does not exist yet: create it and download the data
            os.makedirs(data_root)  # makedirs also creates missing parent directories
            # Download the data
            url = f"https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{dataset_name}.zip"
            response = requests.get(url)
            with zipfile.ZipFile(io.BytesIO(response.content)) as myzip:
                myzip.extractall(data_root)
            print("--- download data ---")
        else:
            print("--- There is data! ---")
        # Decide the way dataset will be load, set 1 to load FBWN dataset
        load_method = 0
        # print(dataset_name[0:13])
        if dataset_name[0:13] == 'EXP_FB15k-237':
            load_method = 1
        else:
            load_method = 0
        guss_fb = 'EXP_FB15k' in data_root
        if guss_fb != (load_method == 1):
            print("WARNING: set load_method to 1 if you load Freebase dataset, otherwise 0")

        # FBWN dataset
        if load_method == 1:
            fact_path_ls = [joinpath(data_root, 'facts.txt'),
                            joinpath(data_root, 'train.txt')]
            query_path = joinpath(data_root, 'test.txt')
            pred_path = joinpath(data_root, 'relations.txt')
            const_path = joinpath(data_root, 'entities.txt')
            valid_path = joinpath(data_root, 'valid.txt')
            rule_path = joinpath(data_root, 'cleaned_rules_weight_larger_than_0.9.txt')
            print(rule_path)
            print(os.getcwd())
            # print(fact_path_ls + [query_path, pred_path, const_path, valid_path, rule_path])
            # assert all(map(isfile, fact_path_ls + [query_path, pred_path, const_path, valid_path, rule_path]))

            # assuming only one type
            TYPE_SET.update(['type'])

            # add all const
            for line in iterline(const_path):
                self.const_dict.add_const('type', line)

            # add all pred
            for line in iterline(pred_path):
                self.PRED_DICT[line] = Predicate(line, ['type', 'type'])

            # add all facts (tab-separated: head, predicate, tail)
            fact_ls = []
            for fact_path in fact_path_ls:
                for line in iterline(fact_path):
                    parts = line.split('\t')

                    assert len(parts) == 3, print(parts)

                    e1, pred_name, e2 = parts

                    assert self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2)
                    assert pred_name in self.PRED_DICT

                    fact_ls.append(Fact(pred_name, [e1, e2], 1))

            # add all validations
            valid_ls = []
            for line in iterline(valid_path):
                parts = line.split('\t')

                assert len(parts) == 3, print(parts)

                e1, pred_name, e2 = parts

                assert self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2)
                assert pred_name in self.PRED_DICT

                valid_ls.append(Fact(pred_name, [e1, e2], 1))

            # add all queries
            query_ls = []
            for line in iterline(query_path):
                parts = line.split('\t')

                assert len(parts) == 3, print(parts)

                e1, pred_name, e2 = parts

                assert self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2)
                assert pred_name in self.PRED_DICT

                query_ls.append(Fact(pred_name, [e1, e2], 1))

            # add all rules: atoms are joined by ' v '; the first atom is prefixed with the rule
            # weight, and '!' marks a negated atom
            rule_ls = []
            strip_items = lambda ls: list(map(lambda x: x.strip( ), ls))
            first_atom_reg = re.compile(r'([\d.]+) (!?)([^(]+)\((.*)\)')
            atom_reg = re.compile(r'(!?)([^(]+)\((.*)\)')
            for line in iterline(rule_path):

                atom_str_ls = strip_items(line.split(' v '))
                assert len(atom_str_ls) > 1, 'rule length must be greater than 1, but get %s' % line

                atom_ls = []
                rule_weight = 0.0
                for i, atom_str in enumerate(atom_str_ls):
                    if i == 0:
                        m = first_atom_reg.match(atom_str)
                        assert m is not None, 'matching atom failed for %s' % atom_str
                        rule_weight = float(m.group(1))
                        neg = m.group(2) == '!'
                        pred_name = m.group(3).strip( )
                        var_name_ls = strip_items(m.group(4).split(','))
                    else:
                        m = atom_reg.match(atom_str)
                        assert m is not None, 'matching atom failed for %s' % atom_str
                        neg = m.group(1) == '!'
                        pred_name = m.group(2).strip( )
                        var_name_ls = strip_items(m.group(3).split(','))

                    atom = Atom(neg, pred_name, var_name_ls, self.PRED_DICT[pred_name].var_types)
                    atom_ls.append(atom)

                rule = Formula(atom_ls, rule_weight)
                rule_ls.append(rule)
        else:
            if dataset_name == 'Cora' or dataset_name == 'kinship':
                data_root = joinpath(data_root, 'S' + str(self.args.load_s))
            elif dataset_name == 'uw_cse':
                if self.args.load_s == 1:
                    data_root = joinpath(data_root, 'ai')
                elif self.args.load_s == 2:
                    data_root = joinpath(data_root, 'graphics')
                elif self.args.load_s == 3:
                    data_root = joinpath(data_root, 'language')
                elif self.args.load_s == 4:
                    data_root = joinpath(data_root, 'systems')
                elif self.args.load_s == 5:
                    data_root = joinpath(data_root, 'theory')
                else:
                    print('Warning: Invalid load_s')
            else:
                print('Warning: Invalid dataset for load_method = 0')
            rpath = joinpath(data_root, 'rules') if ext_rule_path is None else ext_rule_path
            fact_ls, rule_ls, query_ls = self.preprocess_kinship(joinpath(data_root, 'predicates'),
                                                                 joinpath(data_root, 'facts'),
                                                                 rpath,
                                                                 joinpath(data_root, 'queries'))
            valid_ls = []

        self.const_sort_dict = dict(
            [(type_name, sorted(list(self.const_dict[type_name]))) for type_name in self.const_dict.constants.keys( )])

        # NOTE(review): const2ind is only built for load_method == 1, but _inst_var and
        # get_batch_by_q* read it unconditionally.
        if load_method == 1:
            self.const2ind = dict([(const, i) for i, const in enumerate(self.const_sort_dict['type'])])

        # linear in size of facts
        self.fact_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT)
        self.test_fact_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT)
        self.valid_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT)

        # ht_dict[pred] = [head -> set of tails, tail -> set of heads]
        self.ht_dict = dict((pred_name, [dict( ), dict( )]) for pred_name in self.PRED_DICT)
        self.ht_dict_train = dict((pred_name, [dict( ), dict( )]) for pred_name in self.PRED_DICT)

        def add_ht(pn, c_ls, ht_dict):
            # Record the (head, tail) pair c_ls of predicate pn in ht_dict.
            # NOTE(review): for load_method == 0 the head c_ls[0] is added to its own set
            # (instead of the tail c_ls[1]) and the tail->heads map is not filled -- looks
            # unintended; confirm against upstream ExpressGNN.
            if load_method == 0:
                if c_ls[0] in ht_dict[pn][0]:
                    ht_dict[pn][0][c_ls[0]].add(c_ls[0])
                else:
                    ht_dict[pn][0][c_ls[0]] = {c_ls[0]}
            elif load_method == 1:
                if c_ls[0] in ht_dict[pn][0]:
                    ht_dict[pn][0][c_ls[0]].add(c_ls[1])
                else:
                    ht_dict[pn][0][c_ls[0]] = {c_ls[1]}

                if c_ls[1] in ht_dict[pn][1]:
                    ht_dict[pn][1][c_ls[1]].add(c_ls[0])
                else:
                    ht_dict[pn][1][c_ls[1]] = {c_ls[0]}

        const_cnter = Counter()
        for fact in fact_ls:
            self.fact_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls)))
            add_ht(fact.pred_name, fact.const_ls, self.ht_dict)
            add_ht(fact.pred_name, fact.const_ls, self.ht_dict_train)
            const_cnter.update(fact.const_ls)

        for fact in valid_ls:
            self.valid_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls)))
            add_ht(fact.pred_name, fact.const_ls, self.ht_dict)

        # the sorted list version
        self.fact_dict_2 = dict((pred_name, sorted(list(self.fact_dict[pred_name])))
                                for pred_name in self.fact_dict.keys( ))
        self.valid_dict_2 = dict((pred_name, sorted(list(self.valid_dict[pred_name])))
                                 for pred_name in self.valid_dict.keys( ))

        self.rule_ls = rule_ls

        # pred_atom-key dict: per rule, pred_name -> var_name -> constant -> list of facts
        self.atom_key_dict_ls = []
        for rule in self.rule_ls:
            atom_key_dict = dict( )

            for atom in rule.atom_ls:
                atom_dict = dict((var_name, dict( )) for var_name in atom.var_name_ls)

                for i, var_name in enumerate(atom.var_name_ls):

                    if atom.pred_name not in self.fact_dict:
                        continue

                    for v in self.fact_dict[atom.pred_name]:
                        if v[1][i] not in atom_dict[var_name]:
                            atom_dict[var_name][v[1][i]] = [v]
                        else:
                            atom_dict[var_name][v[1][i]] += [v]

                # happens if predicate occurs more than once in one rule then we merge the set
                if atom.pred_name in atom_key_dict:
                    for k, v in atom_dict.items( ):
                        if k not in atom_key_dict[atom.pred_name]:
                            atom_key_dict[atom.pred_name][k] = v
                else:
                    atom_key_dict[atom.pred_name] = atom_dict

            self.atom_key_dict_ls.append(atom_key_dict)

        self.test_fact_ls = []
        self.valid_fact_ls = []

        for fact in query_ls:
            self.test_fact_ls.append((fact.val, fact.pred_name, tuple(fact.const_ls)))
            self.test_fact_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls)))
            add_ht(fact.pred_name, fact.const_ls, self.ht_dict)

        for fact in valid_ls:
            self.valid_fact_ls.append((fact.val, fact.pred_name, tuple(fact.const_ls)))

        self.num_rules = len(rule_ls)

        self.rule_gens = None
        self.reset( )

    def generate_gnd_pred(self, pred_name):
        """
        Return a list of all instantiations of a predicate function; this can be extremely large.

        :param pred_name:
            string
        :return:
            list of ``(pred_name, constants_tuple)`` over the Cartesian product of the
            constants of each argument type
        """

        assert pred_name in self.PRED_DICT

        pred = self.PRED_DICT[pred_name]
        subs = itertools.product(*[self.const_sort_dict[var_type] for var_type in pred.var_types])

        return [(pred_name, sub) for sub in subs]

    def generate_gnd_rule(self, rule):
        """Lazily enumerate every grounding of ``rule``.

        Yields ``(latent_vars, [latent_neg_mask, observed_neg_mask], isfullneg)`` for each
        assignment in the Cartesian product of the constants of each rule variable's type.
        Atoms whose grounding is in ``fact_dict`` (with value 1 or 0) count as observed; the
        rest are latent.
        """

        subs = itertools.product(*[self.const_sort_dict[rule.rule_vars[k]] for k in rule.rule_vars.keys( )])
        sub = next(subs, None)

        while sub is not None:

            latent_vars = []
            latent_neg_mask = []
            observed_neg_mask = []

            for atom in rule.atom_ls:
                grounding = tuple(sub[rule.key2ind[var_name]] for var_name in atom.var_name_ls)
                pos_gnding, neg_gnding = (1, grounding), (0, grounding)

                if pos_gnding in self.fact_dict[atom.pred_name]:
                    observed_neg_mask.append(0 if atom.neg else 1)
                elif neg_gnding in self.fact_dict[atom.pred_name]:
                    observed_neg_mask.append(1 if atom.neg else 0)
                else:
                    latent_vars.append((atom.pred_name, grounding))
                    latent_neg_mask.append(1 if atom.neg else 0)

            isfullneg = (sum(latent_neg_mask) == len(latent_neg_mask)) and \
                        (sum(observed_neg_mask) > 0)

            yield latent_vars, [latent_neg_mask, observed_neg_mask], isfullneg

            sub = next(subs, None)

    def get_batch(self, epoch_mode=False, filter_latent=True):
        """
        Collect up to ``self.batchsize`` ground formulas from the per-rule generators.

        Requires ``self.rule_gens`` to hold one generator per rule (presumably set up by
        ``reset()`` -- TODO confirm).

        Parameters
        ----------
        epoch_mode
            If True, exhausted generators are not restarted, and sampling stops once no
            generator yields data. Otherwise exhausted generators are recreated.
        filter_latent
            If True, skip groundings that have no observed atom.

        Returns
        -------
        batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts
        """

        batch_neg_mask = [[] for _ in range(len(self.rule_ls))]
        batch_latent_var_inds = [[] for _ in range(len(self.rule_ls))]
        observed_rule_cnts = [0.0 for _ in range(len(self.rule_ls))]
        flat_latent_vars = dict( )

        cnt = 0

        inds = list(range(len(self.rule_ls)))

        while cnt < self.batchsize:

            if self.shuffle_sampling:
                shuffle(inds)

            hasdata = False
            for ind in inds:
                latent_vars, neg_mask, isfullneg = next(self.rule_gens[ind], (None, None, None))

                if latent_vars is None:
                    if epoch_mode:
                        continue
                    else:
                        self.rule_gens[ind] = self.generate_gnd_rule(self.rule_ls[ind])
                        latent_vars, neg_mask, isfullneg = next(self.rule_gens[ind])

                if epoch_mode:
                    hasdata = True

                # if rule is fully latent
                if (len(neg_mask[1]) == 0) and filter_latent:
                    continue

                # if rule fully observed
                if len(latent_vars) == 0:
                    observed_rule_cnts[ind] += 0 if isfullneg else 1
                    cnt += 1
                    if cnt >= self.batchsize:
                        break
                    else:
                        continue

                batch_neg_mask[ind].append(neg_mask)

                # Assign each distinct latent atom a dense index in first-seen order.
                for latent_var in latent_vars:
                    if latent_var not in flat_latent_vars:
                        flat_latent_vars[latent_var] = len(flat_latent_vars)

                batch_latent_var_inds[ind].append([flat_latent_vars[e] for e in latent_vars])

                cnt += 1

                if cnt >= self.batchsize:
                    break

            if epoch_mode and (hasdata is False):
                break

        flat_list = sorted([(k, v) for k, v in flat_latent_vars.items( )], key=lambda x: x[1])
        flat_list = [e[0] for e in flat_list]

        return batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts

    def _instantiate_pred(self, atom, atom_dict, sub, rule, observed_prob):
        """Fill the unbound variables of ``atom`` in the substitution list ``sub`` (in place).

        With probability ``observed_prob``, copy the constants of a random observed fact that
        is consistent with the already-bound variables (looked up in ``atom_dict``). If no
        such fact exists, or in the other branch, unbound variables get random constants of
        their type.
        """

        key2ind = rule.key2ind
        rule_vars = rule.rule_vars

        # substitute with observed fact
        if np.random.rand( ) < observed_prob:

            fact_choice_set = None
            for var_name in atom.var_name_ls:
                const = sub[key2ind[var_name]]
                if const is None:
                    choice_set = itertools.chain.from_iterable([v for k, v in atom_dict[var_name].items( )])
                else:
                    if const in atom_dict[var_name]:
                        choice_set = atom_dict[var_name][const]
                    else:
                        choice_set = []

                if fact_choice_set is None:
                    fact_choice_set = set(choice_set)
                else:
                    fact_choice_set = fact_choice_set.intersection(set(choice_set))

                if len(fact_choice_set) == 0:
                    break

            if len(fact_choice_set) == 0:
                for var_name in atom.var_name_ls:
                    if sub[key2ind[var_name]] is None:
                        sub[key2ind[var_name]] = choice(self.const_sort_dict[rule_vars[var_name]])
            else:
                val, const_ls = choice(sorted(list(fact_choice_set)))
                for var_name, const in zip(atom.var_name_ls, const_ls):
                    sub[key2ind[var_name]] = const

        # substitute with random facts
        else:
            for var_name in atom.var_name_ls:
                if sub[key2ind[var_name]] is None:
                    sub[key2ind[var_name]] = choice(self.const_sort_dict[rule_vars[var_name]])

    def _gen_mask(self, rule, sub, closed_world):
        """Split the atoms of ``rule``, grounded by ``sub``, into observed and latent ones.

        With ``closed_world`` set, an unobserved atom whose predicate has no test facts is
        treated as observed-false.

        Returns ``(latent_vars, observed_vars, latent_neg_mask, observed_neg_mask)``.
        """

        latent_vars = []
        observed_vars = []
        latent_neg_mask = []
        observed_neg_mask = []

        for atom in rule.atom_ls:
            grounding = tuple(sub[rule.key2ind[var_name]] for var_name in atom.var_name_ls)
            pos_gnding, neg_gnding = (1, grounding), (0, grounding)

            if pos_gnding in self.fact_dict[atom.pred_name]:
                observed_vars.append((1, atom.pred_name))
                observed_neg_mask.append(0 if atom.neg else 1)
            elif neg_gnding in self.fact_dict[atom.pred_name]:
                observed_vars.append((0, atom.pred_name))
                observed_neg_mask.append(1 if atom.neg else 0)
            else:
                if closed_world and (len(self.test_fact_dict[atom.pred_name]) == 0):
                    observed_vars.append((0, atom.pred_name))
                    observed_neg_mask.append(1 if atom.neg else 0)
                else:
                    latent_vars.append((atom.pred_name, grounding))
                    latent_neg_mask.append(1 if atom.neg else 0)

        return latent_vars, observed_vars, latent_neg_mask, observed_neg_mask

    def _get_rule_stat(self, observed_vars, latent_vars, observed_neg_mask, filter_latent, filter_observed):
        """Classify a grounding as BAD, FULL_OBSERVERED or GOOD (module-level constants)."""

        is_full_latent = len(observed_vars) == 0
        is_full_observed = len(latent_vars) == 0

        if is_full_latent and filter_latent:
            return BAD

        if is_full_observed:

            if filter_observed:
                return BAD

            is_full_neg = sum(observed_neg_mask) == 0

            if is_full_neg:
                return BAD

            else:
                return FULL_OBSERVERED

        # if observed var already yields 1
        if sum(observed_neg_mask) > 0:
            return BAD

        return GOOD

    def _inst_var(self, sub, var2ind, var2type, at, ht_dict, gen_latent):
        """Instantiate both variables of the binary atom ``at`` in ``sub`` (in place).

        Returns ``([const2ind[c0], const2ind[c1]], islatent, succ)``. With ``gen_latent``,
        unbound variables are drawn uniformly from their type. Otherwise ``ht_dict``
        (``[head -> tails, tail -> heads]``) is used to complete an observed fact where
        possible, falling back to a random (latent) constant.

        Raises KeyError if ``at`` is not binary.

        NOTE(review): when both variables are unbound and ``ht_dict[0]`` is empty, control
        falls off the end and None is returned, but callers unpack a 3-tuple.
        """

        if len(at.var_name_ls) != 2:
            raise KeyError

        must_latent = gen_latent

        if must_latent:

            tmp = [sub[var2ind[vn]] for vn in at.var_name_ls]

            for i, subi in enumerate(tmp):
                if subi is None:
                    tmp[i] = random.choice(self.const_sort_dict[var2type[at.var_name_ls[i]]])

            islatent = (tmp[0] not in ht_dict[0]) or (tmp[1] not in ht_dict[0][tmp[0]])
            for i, vn in enumerate(at.var_name_ls):
                sub[var2ind[vn]] = tmp[i]
            return [self.const2ind[subi] for subi in tmp], islatent, islatent or at.neg

        vn0 = at.var_name_ls[0]
        sub0 = sub[var2ind[vn0]]
        vn1 = at.var_name_ls[1]
        sub1 = sub[var2ind[vn1]]

        if sub0 is None:

            if sub1 is None:
                if len(ht_dict[0]) > 0:
                    sub0 = random.choice(tuple(ht_dict[0].keys( )))
                    sub1 = random.choice(tuple(ht_dict[0][sub0]))
                    sub[var2ind[vn0]] = sub0
                    sub[var2ind[vn1]] = sub1
                    return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg

            else:
                if sub1 in ht_dict[1]:
                    sub0 = random.choice(tuple(ht_dict[1][sub1]))
                    sub[var2ind[vn0]] = sub0
                    return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg
                else:
                    sub0 = random.choice(self.const_sort_dict[var2type[vn0]])
                    sub[var2ind[vn0]] = sub0
                    return [self.const2ind[sub0], self.const2ind[sub1]], True, True

        else:

            if sub1 is None:
                if sub0 in ht_dict[0]:
                    sub1 = random.choice(tuple(ht_dict[0][sub0]))
                    sub[var2ind[vn1]] = sub1
                    return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg
                else:
                    sub1 = random.choice(self.const_sort_dict[var2type[vn1]])
                    sub[var2ind[vn1]] = sub1
                    return [self.const2ind[sub0], self.const2ind[sub1]], True, True

            else:
                islatent = (sub0 not in ht_dict[0]) or (sub1 not in ht_dict[0][sub0])
                return [self.const2ind[sub0], self.const2ind[sub1]], islatent, islatent or at.neg

    def get_batch_fast(self, batchsize, observed_prob=0.9):
        """Generator yielding ``(samples, neg_mask, latent_mask, obs_var)`` once per rule.

        For each rule, ``batchsize + 1`` groundings are attempted (``while cnt <= batchsize``).
        Atoms are instantiated in random order. The threshold for keeping an atom observed
        starts at ``observed_prob`` and is multiplied by ``prob_decay`` after each atom.
        Only successful groundings with at least one observed atom are kept in
        samples/masks. ``obs_var`` also records observed atoms of rejected groundings.
        """

        prob_decay = 0.5

        for rule in self.rule_ls:

            var2ind = rule.key2ind
            var2type = rule.rule_vars
            samples = [[atom.pred_name, []] for atom in rule.atom_ls]
            neg_mask = [[atom.pred_name, []] for atom in rule.atom_ls]
            latent_mask = [[atom.pred_name, []] for atom in rule.atom_ls]
            obs_var = [[atom.pred_name, []] for atom in rule.atom_ls]
            cnt = 0

            while cnt <= batchsize:

                sub = [None] * len(rule.rule_vars)  # substitutions

                sample_buff = [[] for _ in rule.atom_ls]
                neg_mask_buff = [[] for _ in rule.atom_ls]
                latent_mask_buff = [[] for _ in rule.atom_ls]

                atom_inds = list(range(len(rule.atom_ls)))
                shuffle(atom_inds)
                succ = True
                cur_threshold = observed_prob
                obs_list = []

                for atom_ind in atom_inds:
                    atom = rule.atom_ls[atom_ind]
                    pred_ht_dict = self.ht_dict_train[atom.pred_name]

                    gen_latent = np.random.rand( ) > cur_threshold
                    c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type,
                                                               atom, pred_ht_dict, gen_latent)

                    if not islatent:
                        obs_var[atom_ind][1].append(c_ls)

                    cur_threshold *= prob_decay
                    succ = succ and atom_succ
                    obs_list.append(not islatent)

                    if succ:
                        sample_buff[atom_ind].append(c_ls)
                        latent_mask_buff[atom_ind].append(1 if islatent else 0)
                        neg_mask_buff[atom_ind].append(0 if atom.neg else 1)

                if succ and any(obs_list):
                    for i in range(len(rule.atom_ls)):
                        samples[i][1].extend(sample_buff[i])
                        latent_mask[i][1].extend(latent_mask_buff[i])
                        neg_mask[i][1].extend(neg_mask_buff[i])

                cnt += 1

            yield samples, neg_mask, latent_mask, obs_var

    def get_batch_by_q(self, batchsize, observed_prob=1.0, validation=False):
        """Generator over test facts (valid facts when ``validation``), grouped per rule.

        For each query fact and each rule whose last atom (the head) has the query's
        predicate, the head is bound to the query constants and the body atoms are
        instantiated with ``_inst_var``. Observed body atoms also get a corrupted pair from
        ``gen_fake`` in ``neg_var``. Once ``cnt`` reaches ``batchsize`` a batch is yielded
        and the buffers are reset; a final (possibly partial) batch is yielded at the end.
        Yield order: samples, latent_mask, neg_mask, obs_var, neg_var (all per rule).
        """

        samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        cnt = 0

        num_ents = len(self.const2ind)
        ind2const = self.const_sort_dict['type']

        def gen_fake(c1, c2, pn):
            # Up to 10 tries: corrupt the head or the tail (50/50) into a pair absent from
            # ht_dict_train[pn]; returns (None, None) if every try fails.
            for _ in range(10):
                c1_fake = random.randint(0, num_ents - 1)
                c2_fake = random.randint(0, num_ents - 1)
                if np.random.rand( ) > 0.5:
                    if ind2const[c1_fake] not in self.ht_dict_train[pn][1][ind2const[c2]]:
                        return c1_fake, c2
                else:
                    if ind2const[c2_fake] not in self.ht_dict_train[pn][0][ind2const[c1]]:
                        return c1, c2_fake
            return None, None

        if validation:
            fact_ls = self.valid_fact_ls
        else:
            fact_ls = self.test_fact_ls

        for val, pred_name, consts in fact_ls:

            for rule_i, rule in enumerate(self.rule_ls):

                # find rule with pred_name as head
                if rule.atom_ls[-1].pred_name != pred_name:
                    continue

                samples = samples_by_r[rule_i]
                neg_mask = neg_mask_by_r[rule_i]
                latent_mask = latent_mask_by_r[rule_i]
                obs_var = obs_var_by_r[rule_i]
                neg_var = neg_var_by_r[rule_i]

                var2ind = rule.key2ind
                var2type = rule.rule_vars

                sub = [None] * len(rule.rule_vars)  # substitutions
                vn0, vn1 = rule.atom_ls[-1].var_name_ls
                sub[var2ind[vn0]] = consts[0]
                sub[var2ind[vn1]] = consts[1]

                sample_buff = [[] for _ in rule.atom_ls]
                neg_mask_buff = [[] for _ in rule.atom_ls]
                latent_mask_buff = [[] for _ in rule.atom_ls]

                # Body atoms only; the head (last atom) is already bound to the query.
                atom_inds = list(range(len(rule.atom_ls) - 1))
                shuffle(atom_inds)
                succ = True
                obs_list = []

                for atom_ind in atom_inds:
                    atom = rule.atom_ls[atom_ind]
                    pred_ht_dict = self.ht_dict_train[atom.pred_name]

                    gen_latent = np.random.rand( ) > observed_prob
                    c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type,
                                                               atom, pred_ht_dict, gen_latent)

                    assert atom_succ

                    if not islatent:
                        obs_var[atom_ind][1].append(c_ls)
                        c1, c2 = gen_fake(c_ls[0], c_ls[1], atom.pred_name)
                        if c1 is not None:
                            neg_var[atom_ind][1].append([c1, c2])

                    succ = succ and atom_succ
                    obs_list.append(not islatent)

                    sample_buff[atom_ind].append(c_ls)
                    latent_mask_buff[atom_ind].append(1 if islatent else 0)
                    neg_mask_buff[atom_ind].append(0 if atom.neg else 1)

                if succ and any(obs_list):
                    for i in range(len(rule.atom_ls)):
                        samples[i][1].extend(sample_buff[i])
                        latent_mask[i][1].extend(latent_mask_buff[i])
                        neg_mask[i][1].extend(neg_mask_buff[i])

                    # The query itself is appended as the (latent) head atom.
                    samples[-1][1].append([self.const2ind[consts[0]], self.const2ind[consts[1]]])
                    latent_mask[-1][1].append(1)
                    neg_mask[-1][1].append(1)

                    cnt += 1

            if cnt >= batchsize:
                yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r

                samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
                neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
                latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
                obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
                neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
                cnt = 0

        yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r

    def get_batch_by_q_v2(self, batchsize, observed_prob=1.0):
        # NOTE(review): the rest of this method lies beyond this chunk. From the visible part it
        # grounds each matching rule twice: once with the test query, and once with a corrupted
        # query from gen_fake (when one is found).

        samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls]
        cnt = 0

        num_ents = len(self.const2ind)
        ind2const = self.const_sort_dict['type']

        def gen_fake(c1, c2, pn):
            # Same corruption helper as in get_batch_by_q.
            for _ in range(10):
                c1_fake = random.randint(0, num_ents - 1)
                c2_fake = random.randint(0, num_ents - 1)
                if np.random.rand( ) > 0.5:
                    if ind2const[c1_fake] not in self.ht_dict_train[pn][1][ind2const[c2]]:
                        return c1_fake, c2
                else:
                    if ind2const[c2_fake] not in self.ht_dict_train[pn][0][ind2const[c1]]:
                        return c1, c2_fake
            return None, None

        for val, pred_name, consts in self.test_fact_ls:

            for rule_i, rule in enumerate(self.rule_ls):

                # find rule with pred_name as head
                if rule.atom_ls[-1].pred_name != pred_name:
                    continue

                samples = samples_by_r[rule_i]
                neg_mask = neg_mask_by_r[rule_i]
                latent_mask = latent_mask_by_r[rule_i]

                var2ind = rule.key2ind
                var2type = rule.rule_vars

                # sub_ls[0]: substitution for the real query; sub_ls[1]: for a corrupted query.
                sub_ls = [[None for _ in range(len(rule.rule_vars))] for _ in range(2)]  # substitutions

                vn0, vn1 = rule.atom_ls[-1].var_name_ls
                sub_ls[0][var2ind[vn0]] = consts[0]
                sub_ls[0][var2ind[vn1]] = consts[1]

                c1, c2 = gen_fake(self.const2ind[consts[0]], self.const2ind[consts[1]], pred_name)
                if c1 is not None:
                    sub_ls[1][var2ind[vn0]] = ind2const[c1]
                    sub_ls[1][var2ind[vn1]] = ind2const[c2]
                else:
                    sub_ls.pop(1)

                pos_query_succ = False

                for sub_ind, sub in enumerate(sub_ls):

                    sample_buff = [[] for _ in rule.atom_ls]
                    neg_mask_buff = [[] for _ in rule.atom_ls]
                    latent_mask_buff = [[] for _ in
rule.atom_ls] atom_inds = list(range(len(rule.atom_ls) - 1)) shuffle(atom_inds) succ = True obs_list = [] for atom_ind in atom_inds: atom = rule.atom_ls[atom_ind] pred_ht_dict = self.ht_dict_train[atom.pred_name] gen_latent = np.random.rand( ) > observed_prob if sub_ind == 1: gen_latent = np.random.rand( ) > 0.5 c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type, atom, pred_ht_dict, gen_latent) assert atom_succ succ = succ and atom_succ obs_list.append(not islatent) sample_buff[atom_ind].append(c_ls) latent_mask_buff[atom_ind].append(1 if islatent else 0) neg_mask_buff[atom_ind].append(0 if atom.neg else 1) if succ: if any(obs_list) or ((sub_ind == 1) and pos_query_succ): for i in range(len(rule.atom_ls)): samples[i][1].extend(sample_buff[i]) latent_mask[i][1].extend(latent_mask_buff[i]) neg_mask[i][1].extend(neg_mask_buff[i]) if sub_ind == 0: samples[-1][1].append([self.const2ind[consts[0]], self.const2ind[consts[1]]]) latent_mask[-1][1].append(1) neg_mask[-1][1].append(1) pos_query_succ = True cnt += 1 else: samples[-1][1].append([c1, c2]) latent_mask[-1][1].append(0) # sample a negative fact at head neg_mask[-1][1].append(1) if cnt >= batchsize: yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] cnt = 0 yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r def get_batch_rnd(self, observed_prob=0.7, filter_latent=True, closed_world=False, filter_observed=False): """ return a batch of gnd formulae by random sampling with controllable bias towards 
those containing observed variables. The overall sampling logic is that: 1) rnd sample a rule from rule_ls 2) shuffle the predicates contained in the rule 3) for each of these predicates, with (observed_prob) it will be instantiated as observed variable, and for (1-observed_prob) if will be simply uniformly instantiated. 3.1) if observed var, then sample from the knowledge base, which is self.fact_dict, if failed for any reason, go to 3.2) 3.2) if uniformly sample, then for each logic variable in the predicate, instantiate it with a uniform sample from the corresponding constant dict :param observed_prob: probability of instantiating a predicate as observed variable :param filter_latent: filter out ground formula containing only latent vars :param closed_world: if set True, reduce the sampling space of all predicates not in the test_dict to the set specified in fact_dict :param filter_observed: filter out ground formula containing only observed vars :return: """ batch_neg_mask = [[] for _ in range(len(self.rule_ls))] batch_latent_var_inds = [[] for _ in range(len(self.rule_ls))] batch_observed_vars = [[] for _ in range(len(self.rule_ls))] observed_rule_cnts = [0.0 for _ in range(len(self.rule_ls))] flat_latent_vars = dict( ) cnt = 0 inds = list(range(len(self.rule_ls))) while cnt < self.batchsize: # randomly sample a formula if self.shuffle_sampling: shuffle(inds) for ind in inds: rule = self.rule_ls[ind] atom_key_dict = self.atom_key_dict_ls[ind] sub = [None] * len(rule.rule_vars) # substitutions # randomly sample an atom from the formula atom_inds = list(range(len(rule.atom_ls))) shuffle(atom_inds) for atom_ind in atom_inds: atom = rule.atom_ls[atom_ind] atom_dict = atom_key_dict[atom.pred_name] # instantiate the predicate self._instantiate_pred(atom, atom_dict, sub, rule, observed_prob) # if variable substitution is complete already then exit if not (None in sub): break # generate latent and observed var labels and their negation masks latent_vars, 
observed_vars, \ latent_neg_mask, observed_neg_mask = self._gen_mask(rule, sub, closed_world) # check sampled ground rule status stat_code = self._get_rule_stat(observed_vars, latent_vars, observed_neg_mask, filter_latent, filter_observed) # is a valid sample with only observed vars and does not have negation on all of them if stat_code == FULL_OBSERVERED: observed_rule_cnts[ind] += 1 cnt += 1 # is a valid sample elif stat_code == GOOD: batch_neg_mask[ind].append([latent_neg_mask, observed_neg_mask]) for latent_var in latent_vars: if latent_var not in flat_latent_vars: flat_latent_vars[latent_var] = len(flat_latent_vars) batch_latent_var_inds[ind].append([flat_latent_vars[e] for e in latent_vars]) batch_observed_vars[ind].append(observed_vars) cnt += 1 # not a valid sample else: continue if cnt >= self.batchsize: break flat_list = sorted([(k, v) for k, v in flat_latent_vars.items( )], key=lambda x: x[1]) flat_list = [e[0] for e in flat_list] return batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts, batch_observed_vars def reset(self): self.rule_gens = [self.generate_gnd_rule(rule) for rule in self.rule_ls] def get_stats(self): num_ents = sum([len(v) for k, v in self.const_sort_dict.items( )]) num_rels = len(self.PRED_DICT) num_facts = sum([len(v) for k, v in self.fact_dict.items( )]) num_queries = len(self.test_fact_ls) num_gnd_atom = 0 for pred_name, pred in self.PRED_DICT.items( ): cnt = 1 for var_type in pred.var_types: cnt *= len(self.const_sort_dict[var_type]) num_gnd_atom += cnt num_gnd_rule = 0 for rule in self.rule_ls: cnt = 1 for var_type in rule.rule_vars.values( ): cnt *= len(self.const_sort_dict[var_type]) num_gnd_rule += cnt return num_ents, num_rels, num_facts, num_queries, num_gnd_atom, num_gnd_rule def preprocess_kinship(self, ppath, fpath, rpath, qpath): """ :param ppath: predicate file path :param fpath: facts file path :param rpath: rule file path :param qpath: query file path :return: """ assert all(map(isfile, [ppath, fpath, 
rpath, qpath])) strip_items = lambda ls: list(map(lambda x: x.strip( ), ls)) pred_reg = re.compile(r'(.*)\((.*)\)') with open(ppath) as f: for line in f: # skip empty lines if line.strip( ) == '': continue m = pred_reg.match(line.strip( )) assert m is not None, 'matching predicate failed for %s' % line name, var_types = m.group(1), m.group(2) var_types = list(map(lambda x: x.strip( ), var_types.split(','))) self.PRED_DICT[name] = Predicate(name, var_types) TYPE_SET.update(var_types) fact_ls = [] fact_reg = re.compile(r'(!?)(.*)\((.*)\)') with open(fpath) as f: for line in f: # skip empty lines if line.strip( ) == '': continue m = fact_reg.match(line.strip( )) assert m is not None, 'matching fact failed for %s' % line val = 0 if m.group(1) == '!' else 1 name, consts = m.group(2), m.group(3) consts = strip_items(consts.split(',')) fact_ls.append(Fact(name, consts, val)) for var_type in self.PRED_DICT[name].var_types: self.const_dict.add_const(var_type, consts.pop(0)) rule_ls = [] first_atom_reg = re.compile(r'([\d.]+) (!?)([\w\d]+)\((.*)\)') atom_reg = re.compile(r'(!?)([\w\d]+)\((.*)\)') with open(rpath) as f: for line in f: # skip empty lines if line.strip( ) == '': continue atom_str_ls = strip_items(line.strip( ).split(' v ')) assert len(atom_str_ls) > 1, 'rule length must be greater than 1, but get %s' % line atom_ls = [] rule_weight = 0.0 for i, atom_str in enumerate(atom_str_ls): if i == 0: m = first_atom_reg.match(atom_str) assert m is not None, 'matching atom failed for %s' % atom_str rule_weight = float(m.group(1)) neg = m.group(2) == '!' pred_name = m.group(3).strip( ) var_name_ls = strip_items(m.group(4).split(',')) else: m = atom_reg.match(atom_str) assert m is not None, 'matching atom failed for %s' % atom_str neg = m.group(1) == '!' 
pred_name = m.group(2).strip( ) var_name_ls = strip_items(m.group(3).split(',')) atom = Atom(neg, pred_name, var_name_ls, self.PRED_DICT[pred_name].var_types) atom_ls.append(atom) rule = Formula(atom_ls, rule_weight) rule_ls.append(rule) query_ls = [] with open(qpath) as f: for line in f: # skip empty lines if line.strip( ) == '': continue m = fact_reg.match(line.strip( )) assert m is not None, 'matching fact failed for %s' % line val = 0 if m.group(1) == '!' else 1 name, consts = m.group(2), m.group(3) consts = strip_items(consts.split(',')) query_ls.append(Fact(name, consts, val)) for var_type in self.PRED_DICT[name].var_types: self.const_dict.add_const(var_type, consts.pop(0)) return fact_ls, rule_ls, query_ls TYPE_SET = set( ) def iterline(fpath): with open(fpath) as f: for line in f: line = line.strip( ) if line == '': continue yield line class ConstantDict: def __init__(self): self.constants = {} def add_const(self, const_type, const): """ :param const_type: string :param const: string """ # if const_type not in TYPE_DICT: # TYPE_DICT[const_type] = len(TYPE_DICT) if const_type in self.constants: self.constants[const_type].add(const) else: self.constants[const_type] = {const} def __getitem__(self, key): return self.constants[key] def has_const(self, key, const): if key in self.constants: return const in self[key] else: return False class Predicate: def __init__(self, name, var_types): """ :param name: string :param var_types: list of strings """ self.name = name self.var_types = var_types self.num_args = len(var_types) def __repr__(self): return '%s(%s)' % (self.name, ','.join(self.var_types)) class Fact: def __init__(self, pred_name, const_ls, val): self.pred_name = pred_name self.const_ls = deepcopy(const_ls) self.val = val def __repr__(self): return self.pred_name + '(%s)' % ','.join(self.const_ls) class Atom: def __init__(self, neg, pred_name, var_name_ls, var_type_ls): self.neg = neg self.pred_name = pred_name self.var_name_ls = var_name_ls 
self.var_type_ls = var_type_ls def __repr__(self): return ('!' if self.neg else '') + self.pred_name + '(%s)' % ','.join(self.var_name_ls) class Formula: """ only support clause form with disjunction, e.g. ! """ def __init__(self, atom_ls, weight): self.weight = weight self.atom_ls = atom_ls self.rule_vars = dict( ) for atom in self.atom_ls: self.rule_vars.update(zip(atom.var_name_ls, atom.var_type_ls)) self.key2ind = dict(zip(self.rule_vars.keys( ), range(len(self.rule_vars.keys( ))))) def evaluate(self): pass def __repr__(self): return ' v '.join(list(map(repr, self.atom_ls))) class ConstantDict: def __init__(self): self.constants = {} def add_const(self, const_type, const): """ :param const_type: string :param const: string """ # if const_type not in TYPE_DICT: # TYPE_DICT[const_type] = len(TYPE_DICT) if const_type in self.constants: self.constants[const_type].add(const) else: self.constants[const_type] = {const} def __getitem__(self, key): return self.constants[key] def has_const(self, key, const): if key in self.constants: return const in self[key] else: return False @register_dataset('NBF_link_prediction') class NBF_LinkPrediction(LinkPredictionDataset): r""" The NBF dataset will be used in task *link prediction*. 
""" def __init__(self, dataset_name ,*args, **kwargs): # dataset_name in ['NBF_WN18RR','NBF_FB15k-237'] self.dataset = NBF_Dataset(root='./openhgnn/dataset/', name=dataset_name[4:], version="v1") import os import requests import zipfile import io @register_dataset('DisenKGAT_link_prediction') class DisenKGAT_LinkPrediction(LinkPredictionDataset): def __init__(self, dataset ,*args, **kwargs): # dataset "DisenKGAT" self.logger = kwargs.get("Logger") self.args = kwargs.get("args") self.current_dir = os.path.dirname(os.path.abspath(__file__)) self.dataset_name = dataset self.raw_dir = os.path.join(self.current_dir, self.dataset_name ,"raw_dir" ) self.processed_dir = os.path.join(self.current_dir, self.dataset_name ,"processed_dir" ) if not os.path.exists(self.raw_dir): os.makedirs(self.raw_dir) self.download() else: print("raw_dir already exists") def download(self): url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{}.zip".format(self.dataset_name) response = requests.get(url) with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: myzip.extractall(self.raw_dir) print("--- download finished---")