Source code for opengsl.data.dataset.dataset

import torch
from opengsl.data.dataset.pyg_load import pyg_load_dataset
from opengsl.data.dataset.split import get_split, k_fold
from opengsl.module.functional import normalize
import numpy as np
from opengsl.data.preprocess.control_homophily import control_homophily
import pickle
import os
import urllib.request
from torch_geometric.utils import degree
import torch_geometric.transforms as T
# from ogb.nodeproppred import PygNodePropPredDataset


[docs]class Dataset: ''' # TODO update docstring Dataset Class. This class loads, preprocesses and splits various datasets. Parameters ---------- data : str The name of dataset. feat_norm : bool Whether to normalize the features. verbose : bool Whether to print statistics. n_splits : int Number of data splits. homophily_control : float The homophily ratio `control homophily` receives. If set to `None`, the original adj will be kept unchanged. path : str Path to save dataset files. ''' def __init__(self, data, feat_norm=False, verbose=True, n_splits=1, split='public', split_params=None, homophily_control=None, path='./data/', cv=None, **kwargs): self.name = data self.feat_norm = feat_norm self.verbose = verbose self.path = path self.device = torch.device('cuda') self.single_graph = True self.split_params = split_params self.n_splits = n_splits self.split = split assert self.split in ['public', 'random'] self.cv = cv self.total_splits = n_splits * cv if cv else n_splits self.prepare_data(data, feat_norm, verbose) if self.single_graph: self.split_data(split, n_splits, cv, split_params, verbose) else: self.split_graphs(split, n_splits, cv, split_params, verbose) if homophily_control: self.adj = control_homophily(self.adj, self.labels.cpu().numpy(), homophily_control)
[docs] def prepare_data(self, ds_name, feat_norm=False, verbose=True): ''' Function to Load various datasets. Homophilous datasets are loaded via pyg, while heterophilous datasets are loaded with `hetero_load`. The results are saved as `self.feats, self.adj, self.labels, self.train_masks, self.val_masks, self.test_masks`. Noth that `self.adj` is undirected and has no self loops. Parameters ---------- ds_name : str The name of dataset. feat_norm : bool Whether to normalize the features. verbose : bool Whether to print statistics. ''' if ds_name in ['cora', 'pubmed', 'citeseer', 'amazoncom', 'amazonpho', 'coauthorcs', 'coauthorph', 'blogcatalog', 'flickr', 'wikics', 'amazon-ratings', 'questions', 'chameleon-filtered', 'squirrel-filtered', 'minesweeper', 'roman-empire', 'wiki-cooc', 'tolokers', 'cora_full', 'cora_ml', 'citeseer_full', 'dblp', 'pubmed_full'] or 'csbm' in ds_name: self.data_raw = pyg_load_dataset(ds_name, path=self.path) self.g = self.data_raw[0] self.feats = self.g.x # unnormalized if ds_name == 'flickr': self.feats = self.feats.to_dense().float() self.n_nodes = self.feats.shape[0] self.dim_feats = self.feats.shape[1] self.labels = self.g.y self.adj = torch.sparse.FloatTensor(self.g.edge_index, torch.ones(self.g.edge_index.shape[1]), [self.n_nodes, self.n_nodes]) self.n_edges = self.g.num_edges/2 self.n_classes = self.data_raw.num_classes self.feats = self.feats.to(self.device) self.labels = self.labels.to(self.device) if 'csbm' in ds_name: self.labels = self.labels.float() self.adj = self.adj.to(self.device) # normalize features if feat_norm: self.feats = normalize(self.feats, style='row') # elif ds_name in ['ogbn-arxiv']: # self.data_raw = PygNodePropPredDataset(name='ogbn-arxiv', root='./data') # self.g = self.data_raw[0] # self.feats = self.g.x # unnormalized # self.n_nodes = self.feats.shape[0] # self.dim_feats = self.feats.shape[1] # self.labels = self.g.y # reverse_edge_index = torch.stack([self.g.edge_index[1], self.g.edge_index[0]]) # # self.adj = torch.sparse.FloatTensor(torch.cat([reverse_edge_index, self.g.edge_index], dim=1), torch.ones(self.g.edge_index.shape[1]*2), # [self.n_nodes, self.n_nodes]) # self.n_edges = self.g.num_edges # self.n_classes = self.data_raw.num_classes # # self.feats = self.feats.to(self.device) # self.labels = self.labels.to(self.device).view(-1) # self.adj = self.adj.to(self.device) else: # graph level self.single_graph = False self.data_raw = pyg_load_dataset(ds_name, path=self.path) if self.data_raw.data.x is None: # 如果没有节点特征,使用度数作为节点特征 # 所有图的最大度数 max_degree = 0 degs = [] for data in self.data_raw: degs += [degree(data.edge_index[0], dtype=torch.long)] max_degree = max(max_degree, degs[-1].max().item()) # 两种安排度数特征的方式 if max_degree < 1000: self.data_raw.transform = T.OneHotDegree(max_degree) else: deg = torch.cat(degs, dim=0).to(torch.float) mean, std = deg.mean().item(), deg.std().item() self.data_raw.transform = NormalizedDegree(mean, std) self.n_graphs = len(self.data_raw) self.n_classes = self.data_raw.num_classes self.dim_feats = self.data_raw[0].x.shape[1] if verbose: if self.single_graph: print("""----Data statistics------' #Nodes %d #Edges %d #Classes %d""" % (self.n_nodes, self.n_edges, self.n_classes)) else: print("""----Data statistics------' #Graphs %d #Classes %d""" % (self.n_graphs, self.n_classes)) self.num_targets = self.n_classes
[docs] def split_data(self, split, n_splits=1, cv=None, split_params=None, verbose=True): ''' Function to conduct data splitting for various datasets. Parameters ---------- n_splits : int Number of data splits. verbose : bool Whether to print statistics. ''' self.train_masks = [] self.val_masks = [] self.test_masks = [] if split == 'public': assert self.name in ['cora', 'citeseer', 'pubmed', 'blogcatalog', 'flickr', 'roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers', 'questions', 'wikics'], 'This dataset has no public splits.' if self.name in ['cora', 'citeseer', 'pubmed']: for i in range(n_splits): self.train_masks.append(torch.nonzero(self.g.train_mask, as_tuple=False).squeeze().numpy()) self.val_masks.append(torch.nonzero(self.g.val_mask, as_tuple=False).squeeze().numpy()) self.test_masks.append(torch.nonzero(self.g.test_mask, as_tuple=False).squeeze().numpy()) elif self.name in ['blogcatalog', 'flickr']: def load_obj(file_name): with open(file_name, 'rb') as f: return pickle.load(f) def download(name): url = 'https://github.com/zhao-tong/GAug/raw/master/data/graphs/' try: print('Downloading', url + name) urllib.request.urlretrieve(url + name, os.path.join(self.path, self.name, name)) print('Done!') except: raise Exception( '''Download failed! Make sure you have stable Internet connection and enter the right name''') split_file = self.name + '_tvt_nids.pkl' if not os.path.exists(os.path.join(self.path, self.name, split_file)): download(split_file) train_indices, val_indices, test_indices = load_obj(os.path.join(self.path, self.name, split_file)) for i in range(n_splits): self.train_masks.append(train_indices) self.val_masks.append(val_indices) self.test_masks.append(test_indices) elif self.name in ['roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers', 'questions', 'wikics']: assert n_splits < 10, 'n_splits > public splits' self.train_masks = [torch.nonzero(x, as_tuple=False).squeeze().numpy() for x in self.g.train_mask.T] self.val_masks = [torch.nonzero(x, as_tuple=False).squeeze().numpy() for x in self.g.val_mask.T] self.test_masks = [torch.nonzero(x, as_tuple=False).squeeze().numpy() for x in self.g.test_mask.T] elif split == 'random': if cv: for i in range(n_splits): np.random.seed(i) train_indices, val_indices, test_indices = k_fold(self.labels.cpu().numpy(), cv) self.train_masks.extend(train_indices) self.val_masks.extend(val_indices) self.test_masks.extend(test_indices) else: assert split_params is not None, 'you need to specify the split params' for i in range(n_splits): np.random.seed(i) train_indices, val_indices, test_indices = get_split(self.labels.cpu().numpy(), split_params) self.train_masks.append(train_indices) self.val_masks.append(val_indices) self.test_masks.append(test_indices) else: raise NotImplementedError if verbose: print("""----Split statistics of %d splits------' #Train samples %d #Val samples %d #Test samples %d""" % (self.total_splits, len(self.train_masks[0]), len(self.val_masks[0]), len(self.test_masks[0])))
[docs] def split_graphs(self, split, n_splits, cv, split_params, verbose=True): ''' Function to conduct data splitting for graph-level datasets. Parameters ---------- n_splits : int Number of data splits. verbose : bool Whether to print statistics. ''' self.train_masks = [] self.val_masks = [] self.test_masks = [] if split == 'public': raise NotImplementedError elif split == 'random': for seed in range(n_splits): np.random.seed(seed) if cv: for fold, (train_idx, test_idx, val_idx) in enumerate(zip(*k_fold(self.data_raw.y.cpu().numpy(), cv))): self.train_masks.append(train_idx) self.val_masks.append(val_idx) self.test_masks.append(test_idx) else: assert split_params is not None train_indices, val_indices, test_indices = get_split(self.data_raw.y.cpu().numpy(), split_params) self.train_masks.append(train_indices) self.val_masks.append(val_indices) self.test_masks.append(test_indices)
class NormalizedDegree(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, data): deg = degree(data.edge_index[0], dtype=torch.float) deg = (deg - self.mean) / self.std data.x = deg.view(-1, 1) return data if __name__ == '__main__': dataset = Dataset('citeseer_full', split='random', n_splits=10, split_params={'train_examples_per_class':20, 'val_examples_per_class':30})