Commit 7d5a70b7 authored by Florent Chatelain's avatar Florent Chatelain
Browse files

up figure codes

parents 1b3e87fa f62b1235
This diff is collapsed.
# coding: utf-8
""" Two-way puncturing utilities to prepare and puncture synthetic or real-word data.
Also comprise functions to display spectral information, limiting population spikes, clustering performances, ... as the figures in the spaper.
"""
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import scipy.linalg as lin
import scipy.stats as stats
import scipy.sparse.linalg
import scipy.special
import pandas as pd
import seaborn as sns
from numpy.random import default_rng
rng = default_rng(0)
plt.rcParams.update({"font.size": 12})
def puncture_X(X, eps):
"""Puncture the data matrix X with a selection rate eps"""
(p, n) = X.shape
cmpl = False
if eps > 0.5:
cmpl = True
eps = 1 - eps
# Compute the (asymptotic) equivalent selection rate when we sample with
# replacement.
if eps <= 0:
q = 0
else:
q = -np.log(1 - eps)
# Number of entries in the lower triangular matrix B
nedges = p * n
# Sample the entry indexes to retain
indreplace = rng.choice(nedges, replace=True, size=int(np.round(q * nedges)))
# Remove duplicates from the sampled indices
if eps > 0.01:
if not cmpl:
ind = -np.ones(nedges, dtype=int)
ind[indreplace] = indreplace
ind = ind[ind >= 0]
else:
ind = np.unique(indreplace)
if cmpl:
ind = np.arange(nedges)
ind[indreplace] = -1
ind = ind[ind >= 0]
# Create the sparsified data matrix
vals = X.ravel()[ind]
I, J = np.unravel_index(ind, X.shape)
X_S = sp.sparse.coo_matrix((vals, (I, J)), shape=X.shape)
# get the 'csc' representation for efficient column slicing
X_S = X_S.tocsc()
return X_S
def ind2sub4low(IND):
"""Subscripts from linear index for lower triangular matrix (only
elements below diagonal). This determines the equivalent subscript
values corresponding to a given single index into a 2D lower
triangular matrix, excluded all elements over the diagonal.
"""
I = np.round(np.floor(-0.5 + 0.5 * np.sqrt(1 + 8 * IND)) + 2).astype("int") - 1
J = np.round((I + 1) * (2 - I) / 2 + IND).astype("int") - 1
return I, J
def mask_B(n, eps, is_diag=1):
"""Create the sparsity mask with density eps for the kernel matrix.
The diagonal entries are omited when is_diag is zero
"""
cmpl = False
if eps > 0.5:
cmpl = True
eps = 1 - eps
# compute the (asymptotic) equivalent selection rate when we sample with replacement
if eps <= 0:
q = 0
else:
q = -np.log(1 - eps)
q = -np.log(1 - eps)
# number of entries in the lower triangular matrix B
nedges = (n * (n - 1)) // 2
# sample the entry indexes to retain
indreplace = rng.choice(nedges, replace=True, size=int(np.round(q * nedges)))
if cmpl:
ind = np.arange(nedges)
ind[indreplace] = -1
ind = ind[ind > 0]
else:
ind = indreplace
# no need to remove the duplicates (done by the sparse matrix factory)
# create the sparse selection matrix
data = np.ones(len(ind))
I, J = ind2sub4low(ind)
B = sp.sparse.coo_matrix((data, (I, J)), shape=(n, n))
# Symmetric matrix (no need to store the upper diag triangular array)
if is_diag != 0:
B += sp.sparse.eye(n)
B = B.tocsr()
# remove duplicate (possible du to csr conversion)
B.data[B.data > 1] = 1
indices = B.indices
indptr = B.indptr
return B, indices, indptr
def puncture_K_vanilla(X_S, B):
"""Compute the punctured kernel matrix with spasity mask B"""
p = X_S.shape[0]
if sp.sparse.issparse(X_S):
X_S = X_S.todense()
K = X_S.T @ X_S
K = sp.sparse.csr_matrix(B.multiply(K)) / p
K = K + sp.sparse.tril(K, k=-1).T
return K
def gen_synth_mus(p, n, cov_mu):
"""Draw the two correlated mean vectors for a two-classes model"""
c0 = p / n # dimension/sample size ratio
# set the corr coeff between the mean vectors
rho = cov_mu[0, 1] / np.sqrt(cov_mu[0, 0] * cov_mu[1, 1])
# draw the mean vectors mu_1,2
mu0 = rng.normal(size=(p,))
mu1 = np.sqrt((1 - rho ** 2)) * rng.normal(size=(p,)) + rho * mu0
mu0 = mu0 * np.sqrt(cov_mu[0, 0])
mu1 = mu1 * np.sqrt(cov_mu[1, 1])
mus = np.concatenate((mu0, mu1)).reshape(p, 2)
return mus
def gen_synth_X(p, n, mus, cs):
"""Draw the noisy data matrix X and get the population spike matrices"""
# Repmat the mean vectors for each of the n samples and stack them in P
# using the proportion cs for each class
n0 = int(0.4 * n)
J = np.zeros((n, 2))
J[:n0, 0] = 1
J[n0:, 1] = 1
P = mus @ J.T
M = np.diag(np.sqrt(cs)) @ (mus.T @ mus) @ np.diag(np.sqrt(cs))
# Population spikes
ells, vM = sp.linalg.eigh(M)
# Full data matrix
Z = rng.normal(size=(p, n))
X = Z + P
return X, ells, vM
def puncture_eigs(X, eB, eS, b=1, sparsity=0):
"""Make the simulation to get the spectral data for a two-way punctured
kernel matrix from a data matrix X.
"""
X_S = puncture_X(X, eS)
B, _, _ = mask_B(X.shape[1], eB, is_diag=b)
K = puncture_K_vanilla(X_S, B)
if sparsity:
lambdas, U = sp.sparse.linalg.eigsh(
K, k=sparsity, which="LA", tol=0, return_eigenvectors=True
)
else:
lambdas = np.linalg.eigvalsh(K.todense())
U = None
return lambdas, U
def puncture_eigs(X, eB, eS, b=1, sparsity=0):
"""Make the simulation to get the spectral data for a two-way punctured
kernel matrix from a data matrix X.
"""
X_S = puncture_X(X, eS)
B, _, _ = mask_B(X.shape[1], eB, is_diag=b)
K = puncture_K_vanilla(X_S, B)
if sparsity:
lambdas, U = sp.sparse.linalg.eigsh(
K, k=sparsity, which="LA", tol=0, return_eigenvectors=True
)
else:
lambdas = np.linalg.eigvalsh(K.todense())
U = None
return lambdas, U
def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None):
"""Plot in the first Fig. the sample eigvals distributions with the (limiting)
population spikes, and in the seconf Fig the principal sample and population
eigenvector
"""
n = len(eigvals)
# disp eigvals hist
sns.histplot(
eigvals.flatten(), # color="blue", cbar_kws={'edgecolor': 'darkblue'},
stat="density",
ax=ax[0],
)
# disp limiting empirical spectrum density
xmin = min(np.min(eigvals) * 0.8, np.min(eigvals) * 1.2) # accounting negative min
xmax = np.max(eigvals) * 1.2
xs, density = lsd(eB, eS, c0, b, xmin=xmin, xmax=xmax, nsamples=400)
ax[0].axes.plot(xs, density, "r", label="limiting density")
# disp spike eigvals
if np.isscalar(ells):
ells = np.array([ells])
nells = len(ells) # how many spikes to disp?
yoffset = np.min((2, ax[0].axes.get_ylim()[1])) * 0.004
isolated_eigs = np.zeros(ells.shape)
for i, ell in enumerate(ells):
isolated_eig = spike(eB, eS, c0, ell, b=b)[0]
lablim = ""
labsamp = ""
if i == 0:
lablim = r"limiting spikes"
labsamp = r"largest sample eigvals"
ax[0].axes.plot(
isolated_eig,
yoffset,
"og",
fillstyle="none",
markersize=10,
label=lablim,
)
ax[0].axes.plot(eigvals[-1 - i], yoffset, "ob", label=labsamp)
ax[0].axes.legend()
ax[0].axes.set_xlim([xmin, xmax])
ax[0].axes.set_ylabel("")
# disp principal spike eigenvector
if vM is not None:
zeta = spike(eB, eS, c0, np.max(ells), b)[1] # alignement index
else:
zeta = 1 / np.sqrt(2) # n0 == n1 case
vM = np.array([[0, 0], [-1, 1]])
u_gt = np.concatenate(
[
-vM[1, 0] * np.ones(n0) / np.sqrt(n0),
-vM[1, 1] * np.ones(n - n0) / np.sqrt(n - n0),
]
) * np.sqrt(
zeta
) # ground truth
ax[1].plot(u, "k")
ax[1].plot(u_gt, "r")
# Define below the useful mathematical functions introduced in the paper
def qfunc(t):
return 0.5 - 0.5 * scipy.special.erf(t / np.sqrt(2))
def coeffs_F(eB, eS, c0):
return [1, 2 / eS, 1 / eS ** 2 * (1 - c0 / eB), -2 * c0 / eS ** 3, -c0 / eS ** 4]
def F(eB, eS, c0, t):
return np.polyval(coeffs_F(eB, eS, c0), t)
def G(eB, eS, c0, b, t):
return (
eS * b
+ 1 / c0 * eB * eS * (1 + eS * t)
+ eS / (1 + eS * t)
+ eB / t / (1 + eS * t)
)
def spike(eB, eS, c0, ell, b=1):
Gamma = np.max(np.roots(coeffs_F(eB, eS, c0)))
rho = int(ell > Gamma) * G(eB, eS, c0, b, ell)
zeta = int(ell > Gamma) * (F(eB, eS, c0, ell) * eS ** 3 / ell / (1 + eS * ell) ** 3)
return rho, zeta
def phase_transition(c0, ell, res=1e-3):
eBs = np.arange(res, 1, res)
eSs = np.zeros(len(eBs))
for iB, eB in enumerate(eBs):
eSs[iB] = res
while np.max(np.roots(coeffs_F(eB, eSs[iB], c0))) > ell and eSs[iB] < 1:
eSs[iB] += res
return eBs, eSs
def lsd(eB, eS, c, b=1, xmin=-5, xmax=10, nsamples=1000):
axr = np.linspace(xmin, xmax, nsamples)
m_old = 0
iy = 1e-3
watchdog = 50000
density = np.zeros(axr.shape, dtype="float")
for i, x in enumerate(axr):
z = (x + iy * 1j) / eS
m = m_old
delta = 1
w = 0
while (delta > 1e-6) and w < watchdog:
m_bar = 1 / (b - z - eB / c * m + eB ** 3 * m ** 2 / (c * (c + eB * m)))
delta = np.abs(m - m_bar)
m = m_bar
w += 1
density[i] = (m_bar / eS).imag / np.pi
return axr, density
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment