Vous avez reçu un message "Your GitLab account has been locked ..." ? Pas d'inquiétude : lisez cet article https://docs.gricad-pages.univ-grenoble-alpes.fr/help/unlock/

Commit f7d2b493 authored by Florent Chatelain's avatar Florent Chatelain
Browse files

up codes

parent 8414211d
This diff is collapsed.
......@@ -23,7 +23,7 @@ rng = default_rng(0)
def puncture_X(X, eps):
"""Puncture the data matrix X with a selection rate eps"""
(p, n) = X.shape
cmpl = False
if eps > 0.5:
......@@ -67,8 +67,8 @@ def puncture_X(X, eps):
def ind2sub4low(IND):
"""Subscripts from linear index for lower triangular matrix (only
elements below diagonal). This determines the equivalent subscript
values corresponding to a given single index into a 2D lower
elements below diagonal). This determines the equivalent subscript
values corresponding to a given single index into a 2D lower
triangular matrix, excluded all elements over the diagonal.
"""
......@@ -79,9 +79,9 @@ def ind2sub4low(IND):
def mask_B(n, eps, is_diag=1):
"""Create the sparsity mask with density eps for the kernel matrix.
The diagonal entries are omited when is_diag is zero
The diagonal entries are omited when is_diag is zero
"""
cmpl = False
if eps > 0.5:
cmpl = True
......@@ -126,7 +126,7 @@ def mask_B(n, eps, is_diag=1):
def puncture_K_vanilla(X_S, B):
"""Compute the punctured kernel matrix with spasity mask B"""
p = X_S.shape[0]
if sp.sparse.issparse(X_S):
X_S = X_S.todense()
......@@ -138,7 +138,7 @@ def puncture_K_vanilla(X_S, B):
def gen_synth_mus(p, n, cov_mu):
"""Draw the two correlated mean vectors for a two-classes model"""
c0 = p / n # dimension/sample size ratio
# set the corr coeff between the mean vectors
rho = cov_mu[0, 1] / np.sqrt(cov_mu[0, 0] * cov_mu[1, 1])
......@@ -153,7 +153,7 @@ def gen_synth_mus(p, n, cov_mu):
def gen_synth_X(p, n, mus, cs):
"""Draw the noisy data matrix X and get the population spike matrices"""
# Repmat the mean vectors for each of the n samples and stack them in P
# using the proportion cs for each class
n0 = int(cs[0] * n)
......@@ -172,10 +172,10 @@ def gen_synth_X(p, n, mus, cs):
def puncture_eigs(X, eB, eS, b=1, sparsity=0):
"""Make the simulation to get the spectral data for a two-way punctured
"""Make the simulation to get the spectral data for a two-way punctured
kernel matrix from a data matrix X.
"""
X_S = puncture_X(X, eS)
B, _, _ = mask_B(X.shape[1], eB, is_diag=b)
K = puncture_K_vanilla(X_S, B)
......@@ -188,15 +188,15 @@ def puncture_eigs(X, eB, eS, b=1, sparsity=0):
lambdas = np.linalg.eigvalsh(K.todense())
U = None
return lambdas, U
def gen_MNIST_vectors(n0=1024):
"""Load, reshape and preprocess and the 'trouser' and 'pullover'
data images
"""
n1 = n0
(X, y), _ = fashion_mnist.load_data()
selected_labels = [1, 2] # trouser and pullover
......@@ -236,7 +236,7 @@ def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None):
"""
n = len(eigvals)
# disp eigvals hist
sns.histplot(
eigvals.flatten(), # color="blue", cbar_kws={'edgecolor': 'darkblue'},
......@@ -298,10 +298,10 @@ def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None):
ax[1].plot(u, "k")
ax[1].plot(u_gt, "r")
def disp_eigs_full(axes, eigvals, u, c0, ell, lmax= 15, b=1):
"""Plot in the first Fig. the truncated sample eigvals distributions of the
"""Plot in the first Fig. the truncated sample eigvals distributions of the
non-punctured kernel, and in the second Fig the dominanr sample and population
eigenvector
"""
......@@ -322,18 +322,19 @@ def disp_eigs_full(axes, eigvals, u, c0, ell, lmax= 15, b=1):
label=r"limiting spike",
)
axes[0].axes.plot(eigvals[-1], 0, "ob", markersize=8, label=r"Largest eigval")
eigvals_t = eigvals[eigvals < lmax]
axes[0].axes.hist(eigvals_t, bins=np.linspace(0, lmax, 50), density=True)
axr, density = lsd(1, 1, c0, b=1, xmin=0, xmax=lmax, nsamples=400)
axes[0].axes.plot(axr, density, lw=2, color="r", label="limiting density")
axes[0].axes.set_ylabel("")
axes[0].axes.set_xlim([0, lmax])
axes[0].axes.set_ylim([0, 0.1])
print("Full case: largest eigval {:.2f} and limiting spike {:.2f}".format(
eigvals[-1], isolated_eig))
#
# data = {'eig': eigvals[-1] , 'y': 0}
# data = pd.DataFrame([data])
# sns.scatterplot(data=data, x='eig', y='y', marker="o",, ax=axins)
......@@ -350,14 +351,14 @@ def disp_eigs_full(axes, eigvals, u, c0, ell, lmax= 15, b=1):
def get_perf_clustering(n0=5000, nbMC=10):
"""Compute the missclassification rates for two-way punctured
"""Compute the missclassification rates for two-way punctured
spectral clustering as a function of the puncturing rates
"""
s_eps_ref = 0.1
b_eps_ref = 0.1
index_ref = s_eps_ref ** 2 * b_eps_ref
# Get the two classes GAN vectors
X, n0, n1, ell = gen_MNIST_vectors(n0)
(p, n) = X.shape
......@@ -451,14 +452,14 @@ def get_perf_clustering(n0=5000, nbMC=10):
def plot_perf_clustering(df, n0, ax):
"""Plot the missclassification rates for two-way punctured
"""Plot the missclassification rates for two-way punctured
spectral clustering as a function of the puncturing rates
"""
s_eps_ref = 0.1
b_eps_ref = 0.1
index_ref = s_eps_ref ** 2 * b_eps_ref
sns.lineplot(
data=df,
x="Beps",
......@@ -482,9 +483,9 @@ def plot_perf_clustering(df, n0, ax):
ax.axes.set_ylim([0, 0.4])
ax.axes.set_xlim([0, 0.1])
ax.axes.grid("on")
# Define below the useful mathematical functions introduced in the paper
def qfunc(t):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment