Vous avez reçu un message "Your GitLab account has been locked ..." ? Pas d'inquiétude : lisez cet article https://docs.gricad-pages.univ-grenoble-alpes.fr/help/unlock/

Commit 6134ef6f authored by Florent Chatelain's avatar Florent Chatelain
Browse files

up nb

parent c00d2cfe
......@@ -11,8 +11,7 @@ The elements of proof of the main theorems in the core article are detailed in
### GAN Data
GAN data that we generate and use in the paper are stored as [Git LFS](https://github.com/git-lfs/git-lfs/wiki) objects. Once Git LFS is [installed](https://github.com/git-lfs/git-lfs/wiki/Installation), the corresponding `.npy` files can be downloaded locally in the cloned repo
with the command
[VGG](https://arxiv.org/pdf/1409.1556.pdf) features of randomly [BigGAN](https://arxiv.org/abs/1809.11096)-generated images that we use in the paper are stored as [Git LFS](https://github.com/git-lfs/git-lfs/wiki) objects which avoids storing them locally. However this is possible to store the corresponding `.npy` files locally in order to avoid downloading them each time they are processed in the codes. This can be done manually via the GitLab web interface. An alternative, once Git LFS is [installed](https://github.com/git-lfs/git-lfs/wiki/Installation), can be is to use the command
```bash
> git lfs pull
......
......@@ -227,7 +227,36 @@ def gen_MNIST_vectors(n0=1024):
ell = (((mean0 - mean1) / sd_full_iso) ** 2) * p
return X, n0, n1, ell
def gen_GAN_vectors(n0=5000):
"""Load, reshape and preprocess and the VGG feature vectors for the
'tabby' and 'collie' BigGAN images
"""
n1 = n0
from pathlib import Path
basename = Path('./data') # relative path to data folder in the gitlab repo
infile0 = 'GAN_collie_vgg19_large.npy'
X0 = np.load((basename / infile0).as_posix())
infile1 = 'GAN_tabby_vgg19_large.npy'
X1 = np.load((basename / infile1).as_posix())
data_full = np.concatenate((X0, X1), axis=0).astype(float) # (n,p) convention
sd_full = data_full.std(axis=0, ddof=1)
X0 = X0[:n0, :]
X1 = X1[:n1, :]
data = np.concatenate((X0, X1), axis=0).astype(float) # (n,p) convention
n, p = data.shape
c = p / n
X = (data - data.mean(axis=0)) / sd_full
X = X.T # now this is (p,n) convention
# Power of the spike
vecmu = X[:,:n0].mean(axis=1)
ell = sp.linalg.norm(vecmu)**2
return X, n0, n1, ell
def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None):
"""Plot in the first Fig. the sample eigvals distributions with the (limiting)
......@@ -350,40 +379,38 @@ def disp_eigs_full(axes, eigvals, u, c0, ell, lmax= 15, b=1):
def get_perf_clustering(n0=5000, nbMC=10):
def get_perf_clustering(n0=5000, nbMC=10, isGAN=False):
"""Compute the missclassification rates for two-way punctured
spectral clustering as a function of the puncturing rates
for 'MNIST' fashion (default choice) or 'GAN' datasets
"""
s_eps_ref = 0.1
b_eps_ref = 0.1
index_ref = s_eps_ref ** 2 * b_eps_ref
# Get the two classes GAN vectors
X, n0, n1, ell = gen_MNIST_vectors(n0)
# Get the two classes GAN, or MNIST, vectors
if (isGAN):
X, n0, n1, ell = gen_GAN_vectors(n0)
# Set the list of epsilon_B puncturing rates
b_eps_vals = np.array(
[
5e-4, 0.001, 0.005, 0.01, 0.012, 0.015, 0.02, 0.03, 0.05, 0.1,
0.5, 1,
]
)
else:
X, n0, n1, ell = gen_MNIST_vectors(n0)
# Set the list of epsilon_B puncturing rates
b_eps_vals = np.array(
[
5e-4, 0.001, 0.002, 0.003, 0.004, 0.005, 0.01, 0.012, 0.015,
0.02, 0.03, 0.05, 0.1, 0.5, 1,
]
)
(p, n) = X.shape
c = p / n
# Set the list of epsilon_B puncturing rates
b_eps_vals = np.array(
[
5e-4,
0.001,
0.002,
0.003,
0.004,
0.005,
0.01,
0.012,
0.015,
0.02,
0.03,
0.05,
0.1,
0.5,
1,
]
)
neps = len(b_eps_vals)
# Arrays for empirical clustering perf
......@@ -451,7 +478,7 @@ def get_perf_clustering(n0=5000, nbMC=10):
return df, n0
def plot_perf_clustering(df, n0, ax):
def plot_perf_clustering(df, ax, isGAN=False):
"""Plot the missclassification rates for two-way punctured
spectral clustering as a function of the puncturing rates
"""
......@@ -480,7 +507,10 @@ def plot_perf_clustering(df, n0, ax):
)
ax.axes.legend()
ax.axes.set_ylabel("")
ax.axes.set_ylim([0, 0.4])
if (isGAN):
ax.axes.set_ylim([0, 0.1])
else:
ax.axes.set_ylim([0, 0.4])
ax.axes.set_xlim([0, 0.1])
ax.axes.grid("on")
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment