Source code for openhgnn.dataset.academic_graph

import os
from dgl.data.utils import download, extract_archive
from dgl.data import DGLDataset
from dgl.data.utils import load_graphs


class AcademicDataset(DGLDataset):
    """Single-graph heterogeneous datasets distributed as zip archives on DGL's S3 bucket.

    The archive ``_prefix + _urls[name]`` is downloaded into ``./openhgnn/dataset``
    (relative to the current working directory) and extracted into
    ``./openhgnn/dataset/<name>``. :meth:`process` then reads the graph from
    ``./openhgnn/dataset/<name>/graph.bin``.

    Parameters
    ----------
    name : str
        Dataset identifier; must be one of the keys of ``_urls``.
    raw_dir : str, optional
        Accepted for signature compatibility only. It is ignored: the dataset
        root is always ``./openhgnn/dataset``.
    force_reload : bool
        Forwarded to :class:`dgl.data.DGLDataset`.
    verbose : bool
        Forwarded to :class:`dgl.data.DGLDataset`.

    Raises
    ------
    ValueError
        If ``name`` is not a supported dataset.
    """

    _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
    _urls = {
        'academic4HetGNN': 'dataset/academic4HetGNN.zip',
        'acm4GTN': 'dataset/acm4GTN.zip',
        'acm4NSHE': 'dataset/acm4NSHE.zip',
        'acm4NARS': 'dataset/acm4NARS.zip',
        'acm4HeCo': 'dataset/acm4HeCo.zip',
        'imdb4MAGNN': 'dataset/imdb4MAGNN.zip',
        'imdb4GTN': 'dataset/imdb4GTN.zip',
        'DoubanMovie': 'dataset/DoubanMovie.zip',
        'dblp4MAGNN': 'dataset/dblp4MAGNN.zip',
        'yelp4HeGAN': 'dataset/yelp4HeGAN.zip',
        'yelp4rec': 'dataset/yelp4rec.zip',
        'HNE-PubMed': 'dataset/HNE-PubMed.zip',
        'MTWM': 'dataset/MTWM.zip',
        'amazon4SLICE': 'dataset/amazon4SLICE.zip',
    }
    # Fixed local root for every archive and extracted graph (see ``raw_dir`` above).
    _default_raw_dir = './openhgnn/dataset'

    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
        # Explicit check instead of ``assert`` so validation survives ``python -O``;
        # the supported names are exactly the keys of ``_urls``.
        if name not in self._urls:
            raise ValueError('Unsupported dataset {!r}; expected one of: {}'.format(
                name, ', '.join(sorted(self._urls))))
        archive = self._urls[name]
        # NOTE(review): the caller's ``raw_dir`` is not honored, which keeps the
        # on-disk layout fixed; wire it through if a custom location is needed.
        raw_dir = self._default_raw_dir
        # Local zip location. Assumes dgl's ``download`` names the file after the
        # URL basename when ``path`` is a directory -- confirm for the DGL version in use.
        self.data_path = os.path.join(raw_dir, os.path.basename(archive))
        # ``download`` below extracts into ``<raw_dir>/<name>``; the archive
        # presumably holds ``graph.bin`` at its top level.
        self.g_path = os.path.join(raw_dir, name, 'graph.bin')
        # The paths above must be set before this call, because the base
        # constructor may invoke download()/process().
        super().__init__(name=name, url=self._prefix + archive, raw_dir=raw_dir,
                         force_reload=force_reload, verbose=verbose)

    def download(self):
        """Fetch the zip archive if it is missing and extract it when needed.

        Extraction also runs when the archive is already on disk but
        ``graph.bin`` is absent (e.g. an earlier extraction was interrupted),
        so :meth:`process` does not fail on a half-prepared directory.
        """
        archive_missing = not os.path.exists(self.data_path)
        if archive_missing:
            download(self.url, path=self.raw_dir)
        if archive_missing or not os.path.exists(self.g_path):
            extract_archive(self.data_path, os.path.join(self.raw_dir, self.name))

    def process(self):
        """Load the first graph stored in ``self.g_path`` into ``self._g``."""
        graphs, _ = load_graphs(self.g_path)
        self._g = graphs[0]

    def __getitem__(self, idx):
        """Return the dataset's only graph.

        Raises
        ------
        IndexError
            If ``idx`` is not ``0``. An ``IndexError`` also terminates Python's
            legacy ``__getitem__`` iteration protocol, so ``list(dataset)`` works.
        """
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self._g

    def __len__(self):
        """Return the number of graphs in the dataset (always 1)."""
        return 1

    def save(self):
        """No-op: the graph is already persisted as ``graph.bin`` inside the archive."""

    def load(self):
        """No-op: loading is done by :meth:`process` from ``self.g_path``."""

    def has_cache(self):
        """Report that no separately processed cache exists.

        Graph loading is handled by :meth:`process` reading ``graph.bin``.
        """
        return False