import os
from dgl.data.utils import download, extract_archive
from dgl.data import DGLDataset
from dgl.data.utils import load_graphs
class AcademicDataset(DGLDataset):
    """Single-graph dataset distributed as a zip archive with a ``graph.bin``.

    The archive for ``name`` is fetched from ``_prefix + _urls[name]``,
    unpacked into ``<raw_dir>/<name>/`` and the first graph stored in
    ``<raw_dir>/<name>/graph.bin`` is exposed as the only item of the dataset.

    Parameters
    ----------
    name : str
        Dataset identifier; must be one of the keys of ``_urls``.
    raw_dir : str, optional
        Directory that holds the downloaded archive and the extracted files.
        ``None`` or an empty string selects ``./openhgnn/dataset``.
    force_reload : bool
        Forwarded unchanged to ``DGLDataset``.
    verbose : bool
        Forwarded unchanged to ``DGLDataset``.

    Raises
    ------
    ValueError
        If ``name`` is not a supported dataset.
    """

    _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
    # Maps each supported dataset name to its archive path under ``_prefix``.
    # This dict is the single source of truth for the set of valid names.
    _urls = {
        'academic4HetGNN': 'dataset/academic4HetGNN.zip',
        'acm4GTN': 'dataset/acm4GTN.zip',
        'acm4NSHE': 'dataset/acm4NSHE.zip',
        'acm4NARS': 'dataset/acm4NARS.zip',
        'acm4HeCo': 'dataset/acm4HeCo.zip',
        'imdb4MAGNN': 'dataset/imdb4MAGNN.zip',
        'imdb4GTN': 'dataset/imdb4GTN.zip',
        'DoubanMovie': 'dataset/DoubanMovie.zip',
        'dblp4MAGNN': 'dataset/dblp4MAGNN.zip',
        'yelp4HeGAN': 'dataset/yelp4HeGAN.zip',
        'yelp4rec': 'dataset/yelp4rec.zip',
        'HNE-PubMed': 'dataset/HNE-PubMed.zip',
        'MTWM': 'dataset/MTWM.zip',
        'amazon4SLICE': 'dataset/amazon4SLICE.zip',
    }
    # Directory used when the caller does not supply ``raw_dir``.
    _default_raw_dir = './openhgnn/dataset'

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if name not in self._urls:
            raise ValueError(
                "Unknown dataset name {!r}; expected one of {}".format(
                    name, sorted(self._urls)))

        # ``None`` and ``''`` both fall back to the default directory, which is
        # what every call resolved to before ``raw_dir`` was honoured.
        raw_dir = raw_dir or self._default_raw_dir
        archive_name = os.path.basename(self._urls[name])

        # These attributes are read by download() and process().
        # NOTE(review): they are assigned before super().__init__() because the
        # DGLDataset constructor presumably drives download()/process() - confirm.
        self.data_path = os.path.join(raw_dir, archive_name)
        self.g_path = os.path.join(raw_dir, name, 'graph.bin')

        super().__init__(name=name,
                         url=self._prefix + self._urls[name],
                         raw_dir=raw_dir,
                         force_reload=force_reload,
                         verbose=verbose)

    def download(self):
        """Fetch the zip archive if it is missing, then unpack it.

        The archive is stored at ``self.data_path`` and extracted into
        ``<raw_dir>/<name>``. Extraction also runs when the archive is already
        on disk, so an archive left behind without its extracted directory no
        longer leaves ``graph.bin`` missing for ``process()``.
        """
        if not os.path.exists(self.data_path):
            download(self.url, path=self.data_path)
        extract_archive(self.data_path, os.path.join(self.raw_dir, self.name))

    def process(self):
        """Load the graphs stored at ``self.g_path`` and keep the first one."""
        graphs, _ = load_graphs(self.g_path)
        self._g = graphs[0]

    def __getitem__(self, idx):
        """Return the dataset's single graph; only index ``0`` is valid.

        Raises
        ------
        IndexError
            If ``idx`` is not ``0``. ``IndexError`` is the exception the
            sequence protocol expects for an out-of-range index.
        """
        if idx != 0:
            raise IndexError("This dataset has only one graph")
        return self._g

    def __len__(self):
        """Return the number of graphs in the dataset, which is always 1."""
        return 1

    def save(self):
        """Do nothing: no processed data is written to ``self.save_path``."""

    def load(self):
        """Do nothing: no processed data is read from ``self.save_path``."""

    def has_cache(self):
        """Report that no processed cache exists (``save()`` writes none)."""
        return False