Source code for openhgnn.dataset.EdgeClassificationDataset

import torch as th
from . import BaseDataset, register_dataset
from . import Mg2vecDataSet


@register_dataset('edge_classification')
class EdgeClassificationDataset(BaseDataset):
    r"""
    The class *EdgeClassificationDataset* is a base class for datasets that can be used in the task
    *edge classification*. Its subclasses should therefore contain attributes such as graph, category,
    num_classes and so on. Besides, they should implement the functions *get_labels()* and *get_split()*.

    Attributes
    -------------
    g : dgl.DGLHeteroGraph
        The heterogeneous graph.
    category : str
        The category (or target) edge type that needs to be predicted. In general, we predict only one edge type.
    num_classes : int
        The target edges will be classified into num_classes categories.
    has_feature : bool
        Whether the dataset has features. Default ``False``.
    multi_label : bool
        Whether the edges have multiple labels. Default ``False``. For now, only HGBn-IMDB is multi-label.
    """

    def __init__(self, *args, **kwargs):
        super(EdgeClassificationDataset, self).__init__(*args, **kwargs)
        self.g = None
        self.category = None
        self.num_classes = None
        self.has_feature = False
        self.multi_label = False

    def get_labels(self):
        r"""
        The subclass of the dataset should overwrite this function. We can get the labels of the target edges through it.

        Notes
        ------
        In general, the labels are th.LongTensor. But for a multi-label dataset, they should be th.FloatTensor.
        Otherwise it will raise
        RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'
        in call to _thnn_nll_loss_forward

        return
        -------
        labels : torch.Tensor
        """
        if 'labels' in self.g.edges[self.category].data:
            labels = self.g.edges[self.category].data.pop('labels').long()
        elif 'label' in self.g.edges[self.category].data:
            labels = self.g.edges[self.category].data.pop('label').long()
        else:
            raise ValueError('Labels of edges are not in the hg.edges[category].data.')
        labels = labels.float() if self.multi_label else labels
        return labels

    def get_split(self, validation=True):
        r"""
        We can get the indices of the train, validation and test sets through it.

        Parameters
        ----------
        validation : bool
            Whether to split off a validation set. Default ``True``. If it is False, val_idx will be the same as train_idx.

        return
        -------
        train_idx, val_idx, test_idx : torch.Tensor, torch.Tensor, torch.Tensor
        """
        if 'train_mask' not in self.g.edges[self.category].data:
            self.logger.dataset_info("The dataset has no train mask. "
                                     "So split the category edges randomly. "
                                     "And the ratio of train/test is 8:2.")
            num_edges = self.g.number_of_edges(self.category)
            n_test = int(num_edges * 0.2)
            n_train = num_edges - n_test

            train, test = th.utils.data.random_split(range(num_edges), [n_train, n_test])
            train_idx = th.tensor(train.indices)
            test_idx = th.tensor(test.indices)
            if validation:
                self.logger.dataset_info("Split train into train/valid with the ratio of 8:2.")
                random_int = th.randperm(len(train_idx))
                valid_idx = train_idx[random_int[:len(train_idx) // 5]]
                train_idx = train_idx[random_int[len(train_idx) // 5:]]
            else:
                self.logger.dataset_info("Set valid set with train set.")
                valid_idx = train_idx
        else:
            train_mask = self.g.edges[self.category].data.pop('train_mask')
            test_mask = self.g.edges[self.category].data.pop('test_mask')
            train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
            test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
            if validation:
                if 'val_mask' in self.g.edges[self.category].data:
                    val_mask = self.g.edges[self.category].data.pop('val_mask')
                    valid_idx = th.nonzero(val_mask, as_tuple=False).squeeze()
                elif 'valid_mask' in self.g.edges[self.category].data:
                    val_mask = self.g.edges[self.category].data.pop('valid_mask').squeeze()
                    valid_idx = th.nonzero(val_mask, as_tuple=False).squeeze()
                else:
                    self.logger.dataset_info("Split train into train/valid with the ratio of 8:2.")
                    random_int = th.randperm(len(train_idx))
                    valid_idx = train_idx[random_int[:len(train_idx) // 5]]
                    train_idx = train_idx[random_int[len(train_idx) // 5:]]
            else:
                self.logger.dataset_info("Set valid set with train set.")
                valid_idx = train_idx
        self.train_idx = train_idx
        self.valid_idx = valid_idx
        self.test_idx = test_idx
        return self.train_idx, self.valid_idx, self.test_idx
@register_dataset('hin_edge_classification')
class HIN_EdgeClassification(EdgeClassificationDataset):
    r"""
    The HIN datasets are used in different papers, so we preprocess them and store them
    in the form of dgl.DGLHeteroGraph. The dataset name is joined to the paper name with a '4' (read as "for").

    Dataset Name :
    dblp4Mg2vec / ...
    """

    def __init__(self, dataset_name, *args, **kwargs):
        super(HIN_EdgeClassification, self).__init__(*args, **kwargs)
        self.g, self.category, self.num_classes = self.load_HIN(dataset_name)

    def load_HIN(self, name_dataset):
        if name_dataset == 'dblp4Mg2vec_4':
            # used in mg2vec with size=4
            dataset = Mg2vecDataSet(name='dblp4Mg2vec_4', raw_dir='')
            g = dataset[0].long()
            category = 'relation'
            num_classes = 3
            return g, category, num_classes
        if name_dataset == 'dblp4Mg2vec_5':
            # used in mg2vec with size=5
            dataset = Mg2vecDataSet(name='dblp4Mg2vec_5', raw_dir='')
            g = dataset[0].long()
            category = 'relation'
            num_classes = 3
            return g, category, num_classes
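A minimal usage sketch (not part of the module above): it assumes the dblp4Mg2vec_4 files are available to Mg2vecDataSet, and that whatever keyword arguments BaseDataset expects (e.g. a logger) are supplied by the caller; in OpenHGNN the trainerflow normally constructs the dataset through the 'edge_classification' / 'hin_edge_classification' registry entries rather than directly.

    # Hypothetical sketch: constructor kwargs depend on BaseDataset and are an assumption here.
    dataset = HIN_EdgeClassification('dblp4Mg2vec_4')

    # Edge-level split indices; with validation=False, valid_idx == train_idx.
    train_idx, valid_idx, test_idx = dataset.get_split(validation=True)

    # LongTensor labels (FloatTensor when multi_label is True).
    labels = dataset.get_labels()
    print(dataset.category, dataset.num_classes, labels.shape)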