import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
import torch_geometric.nn
from opengsl.module.metric import Cosine
from opengsl.utils.utils import scipy_sparse_to_sparse_tensor, sparse_tensor_to_scipy_sparse
def normalize(mx, style='symmetric', add_loop=True, p=None):
    '''
    Normalize the feature matrix or adj matrix.

    Parameters
    ----------
    mx : torch.tensor
        Feature matrix or adj matrix to normalize. Either sparse (COO) or dense form is supported.
    style: str
        If set as ``row``, `mx` will be row-wise normalized (each row divided by its sum).
        If set as ``symmetric``, `mx` will be normalized as in GCN.
        If set as ``softmax``, `mx` will be normalized using softmax over the last dimension.
        If set as ``row-norm``, `mx` will be normalized using `F.normalize` in pytorch.
    add_loop : bool
        Whether to add self loop. Only used when style is set as ``symmetric``.
    p : float
        The exponent value in the norm formulation. Only used when style is set as ``row-norm``.
        For sparse input only finite positive values are supported.

    Returns
    -------
    normalized_mx : torch.tensor
        The normalized matrix (sparse if `mx` is sparse).

    Raises
    ------
    KeyError
        If `style` is not one of the supported styles.
    NotImplementedError
        If ``row-norm`` is requested for a sparse `mx` with a non-finite or non-positive `p`.
    '''
    if style == 'row':
        if mx.is_sparse:
            return row_normalize_sp(mx)
        return row_nomalize(mx)
    elif style == 'symmetric':
        if mx.is_sparse:
            return normalize_sp_tensor_tractable(mx, add_loop)
        return normalize_tensor(mx, add_loop)
    elif style == 'softmax':
        if mx.is_sparse:
            return torch.sparse.softmax(mx, dim=-1)
        return F.softmax(mx, dim=-1)
    elif style == 'row-norm':
        assert p is not None
        if mx.is_sparse:
            # Previously a TODO that silently returned None.
            return _row_norm_sp(mx, p)
        return F.normalize(mx, dim=-1, p=p)
    else:
        raise KeyError("The normalize style is not provided.")


def _row_norm_sp(mx, p, eps=1e-12):
    """Sparse counterpart of ``F.normalize(mx, dim=-1, p=p)`` for a 2-D COO tensor.

    Each row is divided by ``max(||row||_p, eps)``. `eps` mirrors F.normalize's default.
    """
    if not 0 < p < float('inf'):
        raise NotImplementedError('Sparse row-norm only supports finite positive p.')
    adj = mx.coalesce()
    rows = adj.indices()[0]
    values = adj.values()
    # Accumulate |v|^p per row; rows without stored entries keep a norm of 0.
    norms = torch.zeros(adj.shape[0], dtype=values.dtype, device=values.device)
    norms = norms.index_add(0, rows, values.abs().pow(p)).pow(1.0 / p)
    scale = norms.clamp_min(eps)[rows]
    return torch.sparse_coo_tensor(adj.indices(), values / scale, adj.shape)
def row_nomalize(mx):
    """Row-normalize a dense matrix so that every row sums to one.

    Rows whose sum is exactly zero would give an infinite reciprocal; those are
    zeroed instead, so all-zero rows stay all-zero.

    Parameters
    ----------
    mx : torch.Tensor
        Dense 2-D matrix.

    Returns
    -------
    torch.Tensor
        ``diag(1 / rowsum) @ mx``, computed by broadcasting.
    """
    r_inv = mx.sum(1).pow(-1).flatten()
    r_inv = r_inv.masked_fill(torch.isinf(r_inv), 0.)
    # Broadcast instead of building an n x n diagonal and doing an O(n^3) matmul.
    return r_inv.view(-1, 1) * mx
def row_normalize_sp(mx):
    """Row-normalize a sparse 2-D COO tensor so each non-empty row sums to one.

    Parameters
    ----------
    mx : torch.Tensor
        Sparse COO matrix (coalesced or not).

    Returns
    -------
    torch.Tensor
        Sparse COO matrix with the same sparsity pattern, each stored value
        divided by ``rowsum + 1e-12``.
    """
    adj = mx.coalesce()
    rows = adj.indices()[0]
    values = adj.values()
    # Dense per-row degree indexed by row id. torch.sparse.sum(...).values() only
    # holds entries for non-empty rows, so it cannot be indexed by row id.
    degree = torch.zeros(adj.shape[0], dtype=values.dtype, device=values.device)
    degree = degree.index_add(0, rows, values)
    inv_degree = 1. / (degree + 1e-12)
    return torch.sparse_coo_tensor(adj.indices(), values * inv_degree[rows], adj.size())
def normalize_sp_tensor_tractable(adj, add_loop=True):
    """GCN-style symmetric normalization D^{-1/2} A D^{-1/2} of a sparse COO matrix.

    D is the diagonal of row sums (after the optional self loops), matching the
    dense ``normalize_tensor``.

    Parameters
    ----------
    adj : torch.Tensor
        Square sparse COO adjacency matrix.
    add_loop : bool
        If True, add the identity before normalizing.

    Returns
    -------
    torch.Tensor
        Normalized sparse COO matrix with the same pattern as ``adj`` (+ I).
    """
    n = adj.shape[0]
    if add_loop:
        eye = torch.eye(n, dtype=adj.dtype, device=adj.device).to_sparse()
        adj = adj + eye
    adj = adj.coalesce()
    rows, cols = adj.indices()
    values = adj.values()
    # Degree per row id. Rows without entries get degree 0 (their huge inverse
    # is never used because no stored value lives in such a row).
    degree = torch.zeros(n, dtype=values.dtype, device=values.device).index_add(0, rows, values)
    inv_sqrt_degree = 1. / (torch.sqrt(degree) + 1e-12)
    new_values = values * (inv_sqrt_degree[rows] * inv_sqrt_degree[cols])
    return torch.sparse_coo_tensor(adj.indices(), new_values, adj.size())
def normalize_tensor(adj, add_loop=True):
    """GCN-style symmetric normalization D^{-1/2} A D^{-1/2} of a dense matrix.

    Parameters
    ----------
    adj : torch.Tensor
        Dense square adjacency matrix.
    add_loop : bool
        If True, add the identity before normalizing.

    Returns
    -------
    torch.Tensor
        Normalized dense matrix. Rows/columns with zero degree are zeroed
        rather than producing inf.
    """
    adj_loop = adj + torch.eye(adj.shape[0], device=adj.device) if add_loop else adj
    r_inv = adj_loop.sum(1).pow(-0.5).flatten()
    r_inv = r_inv.masked_fill(torch.isinf(r_inv), 0.)
    # Scale rows then columns by broadcasting. This is equivalent to
    # diag(r) @ A @ diag(r), without O(n^3) matmuls or 0*inf NaN leakage.
    return r_inv.view(-1, 1) * adj_loop * r_inv.view(1, -1)
def normalize_sp_tensor(adj, add_loop=True):
    """Symmetrically normalize a sparse torch tensor by round-tripping through scipy.

    The tensor is converted to a scipy sparse matrix, normalized with
    ``normalize_sp_matrix`` and converted back onto the original device.
    """
    target_device = adj.device
    normalized = normalize_sp_matrix(sparse_tensor_to_scipy_sparse(adj), add_loop)
    return scipy_sparse_to_sparse_tensor(normalized, device=target_device)
def normalize_sp_matrix(adj, add_loop=True):
    """GCN-style symmetric normalization D^{-1/2} A D^{-1/2} of a scipy sparse matrix.

    D is the diagonal of row sums (after the optional self loops), consistent
    with ``normalize_tensor``. The previous implementation computed
    D^{-1/2} A^T D^{-1/2}, which is only equal for symmetric ``adj``.

    Parameters
    ----------
    adj : scipy.sparse matrix
        Square adjacency matrix.
    add_loop : bool
        If True, add the identity before normalizing.

    Returns
    -------
    scipy.sparse matrix
        The normalized matrix. Zero-degree rows/columns are zeroed.
    """
    mx = adj + sp.eye(adj.shape[0]) if add_loop else adj
    rowsum = np.asarray(mx.sum(1)).flatten()
    # Zero row sums give inf; silence the warning and zero them below.
    with np.errstate(divide='ignore'):
        r_inv_sqrt = np.power(rowsum, -0.5)
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
    return r_mat_inv_sqrt.dot(mx).dot(r_mat_inv_sqrt)
def symmetry(adj, i=2):
    """Symmetrize a square matrix as ``(A^T + A) / i``.

    Parameters
    ----------
    adj : torch.Tensor
        Square matrix, sparse COO (coalesced or not) or dense.
    i : number
        Divisor applied after summing with the transpose (2 gives the mean).

    Returns
    -------
    torch.Tensor
        Symmetrized matrix, sparse (coalesced) if ``adj`` is sparse.
    """
    if adj.is_sparse:
        # .indices()/.values() raise on uncoalesced tensors.
        adj = adj.coalesce()
        n = adj.shape[0]
        adj_t = torch.sparse_coo_tensor(adj.indices()[[1, 0]], adj.values(), (n, n))
        return (adj_t + adj).coalesce() / i
    return (adj.t() + adj) / i
def knn(adj, K, self_loop=True, set_value=None, sparse_out=False):
    """Keep, for every row of a dense similarity matrix, only its top-``K`` entries.

    Parameters
    ----------
    adj : torch.Tensor
        Dense (n, m) similarity matrix. Sparse input is not supported.
    K : int-like
        Number of entries to keep per row (cast with ``int``).
    self_loop : bool
        If False, diagonal entries are dropped even when they are in the top-K.
    set_value : number or None
        If truthy, every kept nonzero entry is overwritten with this value.
    sparse_out : bool
        If True, return a coalesced sparse COO tensor instead of a dense one.

    Returns
    -------
    torch.Tensor
        The sparsified matrix (dense, or sparse COO when ``sparse_out``).

    Raises
    ------
    NotImplementedError
        If ``adj`` is sparse (this previously returned None silently).
    """
    if adj.is_sparse:
        raise NotImplementedError('knn does not support sparse input yet.')
    device = adj.device
    k = int(K)
    n, m = adj.shape
    values, indices = adj.topk(k=k, dim=-1)
    if sparse_out:
        rows = torch.arange(n, device=device).repeat_interleave(k)
        cols = indices.flatten()
        vals = values.flatten()
        # Honour self_loop/set_value exactly as the dense path does.
        if not self_loop:
            keep = rows != cols
            rows, cols, vals = rows[keep], cols[keep], vals[keep]
        if set_value:
            vals = vals.masked_fill(vals != 0, set_value)
        return torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (n, m)).coalesce()
    mask = torch.zeros(adj.shape, device=device)
    mask[torch.arange(n, device=device).view(-1, 1), indices] = 1.
    if not self_loop:
        diag = torch.arange(min(n, m), device=device)
        mask[diag, diag] = 0
    new_adj = adj * mask
    if set_value:
        new_adj[new_adj != 0] = set_value
    return new_adj
def enn(adj, epsilon, set_value=None):
    """Keep only entries strictly greater than ``epsilon`` (epsilon-neighbourhood graph).

    Parameters
    ----------
    adj : torch.Tensor
        Sparse COO (coalesced or not) or dense matrix.
    epsilon : float
        Threshold; entries ``<= epsilon`` are dropped (sparse) or zeroed (dense).
    set_value : number or None
        If truthy, every kept entry is overwritten with this value.

    Returns
    -------
    torch.Tensor
        Thresholded matrix with the same layout (sparse/dense) and shape as ``adj``.
    """
    if adj.is_sparse:
        # Duplicates must be summed before thresholding; .values() also requires it.
        adj = adj.coalesce()
        values = adj.values()
        mask = values > epsilon
        new_values = values[mask]
        if set_value:
            new_values = torch.full_like(new_values, set_value)
        return torch.sparse_coo_tensor(adj.indices()[:, mask], new_values, adj.shape)
    mask = adj > epsilon
    new_adj = adj * mask
    if set_value:
        new_adj[mask] = set_value
    return new_adj
def to_undirected(adj):
    """Make a square adjacency matrix undirected.

    Sparse input must be an unweighted (all values == 1) COO matrix. The result
    contains every edge in both directions with value 1 (default float dtype).
    Dense input returns the element-wise maximum of ``adj`` and ``adj.T``.

    Parameters
    ----------
    adj : torch.Tensor
        Square sparse COO (coalesced or not) or dense matrix.

    Returns
    -------
    torch.Tensor
        Symmetric matrix with the same layout as ``adj``.
    """
    if adj.is_sparse:
        adj = adj.coalesce()  # .values()/.indices() raise on uncoalesced tensors
        assert (adj.values() == 1).all()
        n = adj.shape[0]
        indices = adj.indices()
        new_indices = torch.unique(torch.cat([indices, indices[[1, 0]]], dim=1), dim=1)
        new_values = torch.ones(new_indices.shape[1], device=adj.device)
        return torch.sparse_coo_tensor(new_indices, new_values, (n, n))
    # Exact max(a, b). The former a + b - a arithmetic lost precision and gave NaN for inf.
    return torch.maximum(adj, adj.T)
def knn_fast(X, k, b):
    """Batched cosine-similarity kNN graph with symmetric degree normalization.

    Rows of ``X`` are L2-normalized. For each node the ``k + 1`` most similar
    nodes are kept (this can include the node itself). Rows are processed ``b``
    at a time to bound memory. Each kept similarity is then scaled by
    ``norm[row]^-0.5 * norm[col]^-0.5``, where ``norm`` is the sum of a node's
    kept similarities as source plus as target.

    Parameters
    ----------
    X : torch.Tensor
        (n, d) feature matrix; the computation runs on ``X``'s device.
    k : int
        Number of neighbours besides the top-1 (``k + 1`` are kept per row).
    b : int
        Positive batch size (rows per similarity block).

    Returns
    -------
    tuple of torch.Tensor
        ``(rows, cols, values)``: int64 edge endpoints and normalized weights,
        each of length ``n * (k + 1)``.
    """
    X = F.normalize(X, dim=1, p=2)
    n = X.shape[0]
    device = X.device
    total = n * (k + 1)
    values = torch.zeros(total, dtype=X.dtype, device=device)
    # Index tensors are int64 from the start; float32 cannot represent ids above 2**24.
    rows = torch.zeros(total, dtype=torch.long, device=device)
    cols = torch.zeros(total, dtype=torch.long, device=device)
    norm_row = torch.zeros(n, dtype=X.dtype, device=device)
    norm_col = torch.zeros(n, dtype=X.dtype, device=device)
    for start in range(0, n, b):
        end = min(start + b, n)
        similarities = torch.mm(X[start:end], X.t())
        vals, inds = similarities.topk(k=k + 1, dim=-1)
        lo, hi = start * (k + 1), end * (k + 1)
        values[lo:hi] = vals.view(-1)
        cols[lo:hi] = inds.view(-1)
        rows[lo:hi] = torch.arange(start, end, device=device).repeat_interleave(k + 1)
        norm_row[start:end] = torch.sum(vals, dim=1)
        norm_col.index_add_(0, inds.view(-1), vals.view(-1))
    norm = norm_row + norm_col
    values *= (torch.pow(norm[rows], -0.5) * torch.pow(norm[cols], -0.5))
    return rows, cols, values
def apply_non_linearity(adj, non_linearity, i):
    """Apply a named element-wise non-linearity to ``adj``.

    ``'elu'`` computes ``elu(i * adj - i) + 1``, ``'relu'`` applies ReLU and
    ``'none'`` returns ``adj`` unchanged (same object). Any other name raises
    KeyError.
    """
    if non_linearity == 'none':
        return adj
    if non_linearity == 'relu':
        return F.relu(adj)
    if non_linearity == 'elu':
        shifted = adj * i - i
        return F.elu(shifted) + 1
    raise KeyError('We dont support the non-linearity yet')
def removeselfloop(adj):
    """Remove diagonal (self-loop) entries from a 2-D matrix.

    Parameters
    ----------
    adj : torch.Tensor
        Sparse COO (coalesced or not) or dense 2-D matrix.

    Returns
    -------
    torch.Tensor
        Same layout and shape as ``adj`` with diagonal entries removed
        (sparse) or zeroed (dense). Sparse input previously returned None.
    """
    if adj.is_sparse:
        adj = adj.coalesce()
        indices = adj.indices()
        keep = indices[0] != indices[1]
        return torch.sparse_coo_tensor(indices[:, keep], adj.values()[keep], adj.shape)
    mask = torch.eye(adj.shape[0], adj.shape[1], device=adj.device).bool()
    return adj * ~mask
if __name__ == '__main__':
    # Manual smoke check: print a seeded random 3x3 sparse matrix and its
    # epsilon-neighbourhood (threshold 0.5) version.
    from torch_geometric import seed_everything

    seed_everything(42)
    sample_adj = torch.rand(3, 3).to_sparse()
    sample_features = torch.rand(10, 3)  # unused; drawn to keep the RNG stream unchanged
    print(sample_adj)
    print(enn(sample_adj, 0.5))