Commit 7d5a70b7 by Florent Chatelain

### up figure codes

parents 1b3e87fa f62b1235
This diff is collapsed.
This diff is collapsed.
punctutils.py 0 → 100644
 # coding: utf-8 """ Two-way puncturing utilities to prepare and puncture synthetic or real-word data. Also comprise functions to display spectral information, limiting population spikes, clustering performances, ... as the figures in the spaper. """ import numpy as np import scipy as sp import matplotlib.pyplot as plt import scipy.linalg as lin import scipy.stats as stats import scipy.sparse.linalg import scipy.special import pandas as pd import seaborn as sns from numpy.random import default_rng rng = default_rng(0) plt.rcParams.update({"font.size": 12}) def puncture_X(X, eps): """Puncture the data matrix X with a selection rate eps""" (p, n) = X.shape cmpl = False if eps > 0.5: cmpl = True eps = 1 - eps # Compute the (asymptotic) equivalent selection rate when we sample with # replacement. if eps <= 0: q = 0 else: q = -np.log(1 - eps) # Number of entries in the lower triangular matrix B nedges = p * n # Sample the entry indexes to retain indreplace = rng.choice(nedges, replace=True, size=int(np.round(q * nedges))) # Remove duplicates from the sampled indices if eps > 0.01: if not cmpl: ind = -np.ones(nedges, dtype=int) ind[indreplace] = indreplace ind = ind[ind >= 0] else: ind = np.unique(indreplace) if cmpl: ind = np.arange(nedges) ind[indreplace] = -1 ind = ind[ind >= 0] # Create the sparsified data matrix vals = X.ravel()[ind] I, J = np.unravel_index(ind, X.shape) X_S = sp.sparse.coo_matrix((vals, (I, J)), shape=X.shape) # get the 'csc' representation for efficient column slicing X_S = X_S.tocsc() return X_S def ind2sub4low(IND): """Subscripts from linear index for lower triangular matrix (only elements below diagonal). This determines the equivalent subscript values corresponding to a given single index into a 2D lower triangular matrix, excluded all elements over the diagonal. """ I = np.round(np.floor(-0.5 + 0.5 * np.sqrt(1 + 8 * IND)) + 2).astype("int") - 1 J = np.round((I + 1) * (2 - I) / 2 + IND).astype("int") - 1 return I, J def mask_B(n, eps, is_diag=1): """Create the sparsity mask with density eps for the kernel matrix. The diagonal entries are omited when is_diag is zero """ cmpl = False if eps > 0.5: cmpl = True eps = 1 - eps # compute the (asymptotic) equivalent selection rate when we sample with replacement if eps <= 0: q = 0 else: q = -np.log(1 - eps) q = -np.log(1 - eps) # number of entries in the lower triangular matrix B nedges = (n * (n - 1)) // 2 # sample the entry indexes to retain indreplace = rng.choice(nedges, replace=True, size=int(np.round(q * nedges))) if cmpl: ind = np.arange(nedges) ind[indreplace] = -1 ind = ind[ind > 0] else: ind = indreplace # no need to remove the duplicates (done by the sparse matrix factory) # create the sparse selection matrix data = np.ones(len(ind)) I, J = ind2sub4low(ind) B = sp.sparse.coo_matrix((data, (I, J)), shape=(n, n)) # Symmetric matrix (no need to store the upper diag triangular array) if is_diag != 0: B += sp.sparse.eye(n) B = B.tocsr() # remove duplicate (possible du to csr conversion) B.data[B.data > 1] = 1 indices = B.indices indptr = B.indptr return B, indices, indptr def puncture_K_vanilla(X_S, B): """Compute the punctured kernel matrix with spasity mask B""" p = X_S.shape[0] if sp.sparse.issparse(X_S): X_S = X_S.todense() K = X_S.T @ X_S K = sp.sparse.csr_matrix(B.multiply(K)) / p K = K + sp.sparse.tril(K, k=-1).T return K def gen_synth_mus(p, n, cov_mu): """Draw the two correlated mean vectors for a two-classes model""" c0 = p / n # dimension/sample size ratio # set the corr coeff between the mean vectors rho = cov_mu[0, 1] / np.sqrt(cov_mu[0, 0] * cov_mu[1, 1]) # draw the mean vectors mu_1,2 mu0 = rng.normal(size=(p,)) mu1 = np.sqrt((1 - rho ** 2)) * rng.normal(size=(p,)) + rho * mu0 mu0 = mu0 * np.sqrt(cov_mu[0, 0]) mu1 = mu1 * np.sqrt(cov_mu[1, 1]) mus = np.concatenate((mu0, mu1)).reshape(p, 2) return mus def gen_synth_X(p, n, mus, cs): """Draw the noisy data matrix X and get the population spike matrices""" # Repmat the mean vectors for each of the n samples and stack them in P # using the proportion cs for each class n0 = int(0.4 * n) J = np.zeros((n, 2)) J[:n0, 0] = 1 J[n0:, 1] = 1 P = mus @ J.T M = np.diag(np.sqrt(cs)) @ (mus.T @ mus) @ np.diag(np.sqrt(cs)) # Population spikes ells, vM = sp.linalg.eigh(M) # Full data matrix Z = rng.normal(size=(p, n)) X = Z + P return X, ells, vM def puncture_eigs(X, eB, eS, b=1, sparsity=0): """Make the simulation to get the spectral data for a two-way punctured kernel matrix from a data matrix X. """ X_S = puncture_X(X, eS) B, _, _ = mask_B(X.shape[1], eB, is_diag=b) K = puncture_K_vanilla(X_S, B) if sparsity: lambdas, U = sp.sparse.linalg.eigsh( K, k=sparsity, which="LA", tol=0, return_eigenvectors=True ) else: lambdas = np.linalg.eigvalsh(K.todense()) U = None return lambdas, U def puncture_eigs(X, eB, eS, b=1, sparsity=0): """Make the simulation to get the spectral data for a two-way punctured kernel matrix from a data matrix X. """ X_S = puncture_X(X, eS) B, _, _ = mask_B(X.shape[1], eB, is_diag=b) K = puncture_K_vanilla(X_S, B) if sparsity: lambdas, U = sp.sparse.linalg.eigsh( K, k=sparsity, which="LA", tol=0, return_eigenvectors=True ) else: lambdas = np.linalg.eigvalsh(K.todense()) U = None return lambdas, U def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None): """Plot in the first Fig. the sample eigvals distributions with the (limiting) population spikes, and in the seconf Fig the principal sample and population eigenvector """ n = len(eigvals) # disp eigvals hist sns.histplot( eigvals.flatten(), # color="blue", cbar_kws={'edgecolor': 'darkblue'}, stat="density", ax=ax[0], ) # disp limiting empirical spectrum density xmin = min(np.min(eigvals) * 0.8, np.min(eigvals) * 1.2) # accounting negative min xmax = np.max(eigvals) * 1.2 xs, density = lsd(eB, eS, c0, b, xmin=xmin, xmax=xmax, nsamples=400) ax[0].axes.plot(xs, density, "r", label="limiting density") # disp spike eigvals if np.isscalar(ells): ells = np.array([ells]) nells = len(ells) # how many spikes to disp? yoffset = np.min((2, ax[0].axes.get_ylim()[1])) * 0.004 isolated_eigs = np.zeros(ells.shape) for i, ell in enumerate(ells): isolated_eig = spike(eB, eS, c0, ell, b=b)[0] lablim = "" labsamp = "" if i == 0: lablim = r"limiting spikes" labsamp = r"largest sample eigvals" ax[0].axes.plot( isolated_eig, yoffset, "og", fillstyle="none", markersize=10, label=lablim, ) ax[0].axes.plot(eigvals[-1 - i], yoffset, "ob", label=labsamp) ax[0].axes.legend() ax[0].axes.set_xlim([xmin, xmax]) ax[0].axes.set_ylabel("") # disp principal spike eigenvector if vM is not None: zeta = spike(eB, eS, c0, np.max(ells), b)[1] # alignement index else: zeta = 1 / np.sqrt(2) # n0 == n1 case vM = np.array([[0, 0], [-1, 1]]) u_gt = np.concatenate( [ -vM[1, 0] * np.ones(n0) / np.sqrt(n0), -vM[1, 1] * np.ones(n - n0) / np.sqrt(n - n0), ] ) * np.sqrt( zeta ) # ground truth ax[1].plot(u, "k") ax[1].plot(u_gt, "r") # Define below the useful mathematical functions introduced in the paper def qfunc(t): return 0.5 - 0.5 * scipy.special.erf(t / np.sqrt(2)) def coeffs_F(eB, eS, c0): return [1, 2 / eS, 1 / eS ** 2 * (1 - c0 / eB), -2 * c0 / eS ** 3, -c0 / eS ** 4] def F(eB, eS, c0, t): return np.polyval(coeffs_F(eB, eS, c0), t) def G(eB, eS, c0, b, t): return ( eS * b + 1 / c0 * eB * eS * (1 + eS * t) + eS / (1 + eS * t) + eB / t / (1 + eS * t) ) def spike(eB, eS, c0, ell, b=1): Gamma = np.max(np.roots(coeffs_F(eB, eS, c0))) rho = int(ell > Gamma) * G(eB, eS, c0, b, ell) zeta = int(ell > Gamma) * (F(eB, eS, c0, ell) * eS ** 3 / ell / (1 + eS * ell) ** 3) return rho, zeta def phase_transition(c0, ell, res=1e-3): eBs = np.arange(res, 1, res) eSs = np.zeros(len(eBs)) for iB, eB in enumerate(eBs): eSs[iB] = res while np.max(np.roots(coeffs_F(eB, eSs[iB], c0))) > ell and eSs[iB] < 1: eSs[iB] += res return eBs, eSs def lsd(eB, eS, c, b=1, xmin=-5, xmax=10, nsamples=1000): axr = np.linspace(xmin, xmax, nsamples) m_old = 0 iy = 1e-3 watchdog = 50000 density = np.zeros(axr.shape, dtype="float") for i, x in enumerate(axr): z = (x + iy * 1j) / eS m = m_old delta = 1 w = 0 while (delta > 1e-6) and w < watchdog: m_bar = 1 / (b - z - eB / c * m + eB ** 3 * m ** 2 / (c * (c + eB * m))) delta = np.abs(m - m_bar) m = m_bar w += 1 density[i] = (m_bar / eS).imag / np.pi return axr, density
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!