Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Florent Chatelain
two way kernel matrix puncturing
Commits
dfa652b7
Commit
dfa652b7
authored
Feb 09, 2021
by
Florent Chatelain
Browse files
up figures
parent
7d5a70b7
Changes
2
Hide whitespace changes
Inline
Side-by-side
figures.ipynb
View file @
dfa652b7
This source diff could not be displayed because it is too large. You can
view the blob
instead.
punctutils.py
View file @
dfa652b7
...
...
@@ -11,16 +11,15 @@ import scipy.linalg as lin
import
scipy.stats
as
stats
import
scipy.sparse.linalg
import
scipy.special
import
sys
import
pandas
as
pd
import
seaborn
as
sns
from
numpy.random
import
default_rng
from
tensorflow.keras.datasets
import
fashion_mnist
from
numpy.random
import
default_rng
rng
=
default_rng
(
0
)
plt
.
rcParams
.
update
({
"font.size"
:
12
})
def
puncture_X
(
X
,
eps
):
"""Puncture the data matrix X with a selection rate eps"""
...
...
@@ -157,7 +156,7 @@ def gen_synth_X(p, n, mus, cs):
# Repmat the mean vectors for each of the n samples and stack them in P
# using the proportion cs for each class
n0
=
int
(
0.4
*
n
)
n0
=
int
(
cs
[
0
]
*
n
)
J
=
np
.
zeros
((
n
,
2
))
J
[:
n0
,
0
]
=
1
J
[
n0
:,
1
]
=
1
...
...
@@ -189,24 +188,44 @@ def puncture_eigs(X, eB, eS, b=1, sparsity=0):
lambdas
=
np
.
linalg
.
eigvalsh
(
K
.
todense
())
U
=
None
return
lambdas
,
U
def
puncture_eigs
(
X
,
eB
,
eS
,
b
=
1
,
sparsity
=
0
):
"""Make the simulation to get the spectral data for a two-way punctured
kernel matrix from a data matrix X.
def
gen_MNIST_vectors
(
n0
=
1024
):
"""Load, reshape and preprocess and the 'trouser' and 'pullover'
data images
"""
n1
=
n0
(
X
,
y
),
_
=
fashion_mnist
.
load_data
()
selected_labels
=
[
1
,
2
]
# trouser and pullover
X0
=
X
[
y
==
selected_labels
[
0
]].
reshape
(
-
1
,
28
**
2
)
X1
=
X
[
y
==
selected_labels
[
1
]].
reshape
(
-
1
,
28
**
2
)
X_S
=
puncture_X
(
X
,
eS
)
B
,
_
,
_
=
mask_B
(
X
.
shape
[
1
],
eB
,
is_diag
=
b
)
K
=
puncture_K_vanilla
(
X_S
,
B
)
mean0
=
X0
.
mean
()
mean1
=
X1
.
mean
()
if
sparsity
:
lambdas
,
U
=
sp
.
sparse
.
linalg
.
eigsh
(
K
,
k
=
sparsity
,
which
=
"LA"
,
tol
=
0
,
return_eigenvectors
=
True
)
else
:
lambdas
=
np
.
linalg
.
eigvalsh
(
K
.
todense
())
U
=
None
return
lambdas
,
U
data_full
=
np
.
concatenate
((
X0
,
X1
),
axis
=
0
).
astype
(
float
)
# (n,p) convention
data_full
=
X
.
reshape
(
-
1
,
28
**
2
).
astype
(
float
)
mean_full
=
data_full
.
mean
(
axis
=
0
)
sd_full_iso
=
data_full
.
ravel
().
std
(
axis
=
0
,
ddof
=
1
)
ind0
=
rng
.
choice
(
X0
.
shape
[
0
],
replace
=
False
,
size
=
n0
)
X0
=
X0
[
ind0
,
:]
ind1
=
rng
.
choice
(
X1
.
shape
[
0
],
replace
=
False
,
size
=
n1
)
X1
=
X1
[
ind1
,
:]
data
=
np
.
concatenate
((
X0
,
X1
),
axis
=
0
).
astype
(
float
)
# (n,p) convention
n
,
p
=
data
.
shape
c
=
p
/
n
X
=
(
data
-
data
.
mean
(
axis
=
0
))
/
sd_full_iso
X
=
X
.
T
# now this is (p,n) convention
# mean0 / sd_full_iso, mean1 / sd_full_iso
ell
=
(((
mean0
-
mean1
)
/
sd_full_iso
)
**
2
)
*
p
return
X
,
n0
,
n1
,
ell
...
...
@@ -279,8 +298,192 @@ def disp_eigs(ax, eigvals, u, eB, eS, c0, n0, ells, b=1, vM=None):
ax
[
1
].
plot
(
u
,
"k"
)
ax
[
1
].
plot
(
u_gt
,
"r"
)
def disp_eigs_full(axes, eigvals, u, c0, ell, lmax=15, b=1):
    """Plot the spectrum and the dominant eigenvector of the non-punctured kernel.

    On ``axes[0]``: the histogram of the sample eigenvalues truncated to
    ``[0, lmax]``, the limiting density returned by ``lsd``, and two markers
    (the value returned by ``spike`` and ``eigvals[-1]``).
    On ``axes[1]``: the sample eigenvector ``u`` together with a reference
    vector equal to ``+1/sqrt(n)`` on its first half and ``-1/sqrt(n)`` on its
    second half.

    Parameters
    ----------
    axes : sequence of two matplotlib axes
    eigvals : numpy array of sample eigenvalues
        ``eigvals[-1]`` is labelled as the largest eigenvalue, so the array is
        assumed to be sorted in ascending order -- TODO confirm with callers.
    u : dominant sample eigenvector
    c0 : forwarded to ``spike`` and ``lsd``
    ell : forwarded to ``spike``
    lmax : upper bound of the displayed spectrum (x-axis limit and
        truncation threshold of the histogram)
    b : forwarded to ``spike`` and ``lsd``. It was previously accepted but
        silently replaced by 1; the default keeps the former behaviour.
    """
    ax_spec = axes[0].axes
    ax_vec = axes[1].axes

    # Both puncturing rates are set to 1: nothing is punctured here.
    isolated_eig = spike(1, 1, c0, ell, b=b)[0]
    largest_eig = eigvals[-1]

    # Horizontal segment bracketing the two markers.
    spike_eig = [isolated_eig, largest_eig]
    in_xmax = np.max(spike_eig) * 2
    in_xmin = np.min(spike_eig) * 0.5
    ax_spec.plot([in_xmin, in_xmax], [0, 0], "-k")
    ax_spec.plot(
        isolated_eig,
        0,
        "og",
        fillstyle="none",
        markersize=10,
        label=r"limiting spike",
    )
    ax_spec.plot(largest_eig, 0, "ob", markersize=8, label=r"Largest eigval")

    # Empirical spectrum, truncated to the displayed range.
    eigvals_t = eigvals[eigvals < lmax]
    ax_spec.hist(eigvals_t, bins=np.linspace(0, lmax, 50), density=True)

    # Limiting spectral density on the same range.
    axr, density = lsd(1, 1, c0, b=b, xmin=0, xmax=lmax, nsamples=400)
    ax_spec.plot(axr, density, lw=2, color="r", label="limiting density")
    ax_spec.set_ylabel("")
    ax_spec.set_xlim([0, lmax])
    ax_spec.set_ylim([0, 0.1])

    # Reference eigenvector: the two classes are taken to be the first and
    # second halves of the samples (balanced classes).
    n_samples = len(u)
    n0 = n_samples // 2
    u_gt = np.ones(n_samples)
    u_gt[n0:] = -1
    u_gt = u_gt / np.sqrt(n_samples)

    ax_vec.plot(u)
    sns.lineplot(x=np.arange(n_samples), y=u_gt, lw=2, color="r", ax=axes[1])
    ax_vec.set_ylim([-0.05, 0.05])
def _sign_aligned_error_rate(u, u_gt, n0):
    """Return the fraction of entries of ``u`` whose sign disagrees with ``u_gt``.

    An eigenvector is only defined up to its sign, so ``u`` is first
    multiplied by the sign of (mean of its first ``n0`` entries minus mean of
    the remaining ones) before being compared with ``u_gt``.
    """
    m0 = np.mean(u[:n0])
    m1 = np.mean(u[n0:])
    u = u * np.sign(m0 - m1)
    return np.mean(u * u_gt < 0)


def get_perf_clustering(n0=5000, nbMC=10):
    """Compute the misclassification rates for two-way punctured
    spectral clustering as a function of the puncturing rates.

    For each of the ``nbMC`` Monte-Carlo runs and each epsilon_B value, two
    regimes are evaluated:

    * fixed regime: epsilon_S is kept at the reference value 0.1;
    * equi-performance regime: epsilon_S is chosen so that
      ``epsilon_S**2 * epsilon_B`` stays equal to its reference value.

    Parameters
    ----------
    n0 : value forwarded to ``gen_MNIST_vectors``
    nbMC : number of Monte-Carlo runs, must be at least 1

    Returns
    -------
    df : pandas.DataFrame
        One row per (run, epsilon_B) pair with the columns ``Beps``,
        ``Seps_equi``, ``perf_equi``, ``perf_fixed``, ``nbMC``, ``n0``
        and ``n1``.
    n0 : the first class size as returned by ``gen_MNIST_vectors``

    Raises
    ------
    ValueError
        If ``nbMC`` is smaller than 1 (there would be nothing to return).
    """
    if nbMC < 1:
        raise ValueError("nbMC must be at least 1, got {}".format(nbMC))

    s_eps_ref = 0.1
    b_eps_ref = 0.1

    # Get the two classes of image vectors, in (p, n) convention
    X, n0, n1, _ = gen_MNIST_vectors(n0)
    (p, n) = X.shape
    c = p / n

    # Set the list of epsilon_B puncturing rates
    b_eps_vals = np.array(
        [
            5e-4,
            0.001,
            0.002,
            0.003,
            0.004,
            0.005,
            0.01,
            0.012,
            0.015,
            0.02,
            0.03,
            0.05,
            0.1,
            0.5,
            1,
        ]
    )
    neps = len(b_eps_vals)

    # Arrays for empirical clustering perf
    fixed_s_eps_perf = np.zeros((neps))
    equiv_s_eps_perf = np.zeros((neps))

    # Set the epsilon_S puncturing rates in the equi-performance regime
    # NOTE(review): this equals s_eps_ref * sqrt(b_eps_ref / b_eps), which is
    # larger than 1 for b_eps < 1e-3 -- confirm puncture_X accepts such rates.
    perf_ref = np.sqrt(s_eps_ref ** 2 * b_eps_ref / c)
    s_eps_equi_vals = perf_ref / np.sqrt(b_eps_vals / c)

    # Ground truth vector with class0 and class1 coded as +1 and -1 resp.
    u_gt = np.ones((n, 1))
    u_gt[n0:, 0] = -1

    # One DataFrame per Monte-Carlo run, concatenated once at the end
    # (DataFrame.append no longer exists in pandas >= 2.0).
    frames = []
    for iMC in range(nbMC):
        # progression bar
        sys.stdout.write("\r")
        sys.stdout.write("[progression {}/{}]".format(iMC + 1, nbMC))
        sys.stdout.flush()

        # Fixed epsilon_S data
        X_S_fix = puncture_X(X, s_eps_ref)  # fixed eps_S regime
        # 'Brute force' implementation: use dense vectors to get numpy
        # multithreading and optimization
        X_S_fix = X_S_fix.todense()
        K_full_fix = X_S_fix.T @ X_S_fix

        for i, b_eps in enumerate(b_eps_vals):
            B, _, _ = mask_B(n, b_eps)

            # Fixed S perf
            K_fix = sp.sparse.csr_matrix(B.multiply(K_full_fix)) / p
            # Add the transpose of the strictly lower triangular part
            K_fix = K_fix + sp.sparse.tril(K_fix, k=-1).T
            _, u = sp.sparse.linalg.eigsh(
                K_fix, k=1, which="LA", tol=0, return_eigenvectors=True
            )
            fixed_s_eps_perf[i] = _sign_aligned_error_rate(u, u_gt, n0)

            # Equiperf eps_S regime
            s_eps_eq = s_eps_equi_vals[i]
            X_S_eq = puncture_X(X, s_eps_eq)  # equi-performance regime
            K_eq = puncture_K_vanilla(X_S_eq, B)
            _, u = sp.sparse.linalg.eigsh(
                K_eq, k=1, which="LA", tol=0, return_eigenvectors=True
            )
            equiv_s_eps_perf[i] = _sign_aligned_error_rate(u, u_gt, n0)

        d = {
            "Beps": b_eps_vals,
            "Seps_equi": s_eps_equi_vals,
            "perf_equi": equiv_s_eps_perf,
            "perf_fixed": fixed_s_eps_perf,
            "nbMC": nbMC,
            "n0": n0,
            "n1": n1,
        }
        frames.append(pd.DataFrame(data=d))

    df = pd.concat(frames, ignore_index=True)
    return df, n0
def plot_perf_clustering(df, n0, ax):
    """Draw the misclassification-rate curves of two-way punctured
    spectral clustering against the epsilon_B puncturing rate.

    Two curves are read from ``df`` (x column ``Beps``): ``perf_equi`` for
    the equi-performance regime and ``perf_fixed`` for the fixed epsilon_S
    regime, each drawn by seaborn with ``err_style="band"`` and ``ci="sd"``.

    Parameters
    ----------
    df : DataFrame holding the ``Beps``, ``perf_equi`` and ``perf_fixed``
        columns
    n0 : not used by this function
    ax : matplotlib axis receiving the plot
    """
    # Reference puncturing rates, only used to build the legend entries
    ref_s_rate = 0.1
    ref_b_rate = 0.1
    ref_index = ref_s_rate ** 2 * ref_b_rate

    # (column of df, legend entry), in drawing order
    curves = (
        ("perf_equi", r"$\epsilon_S^2 \epsilon_B=${:.3f}".format(ref_index)),
        ("perf_fixed", r"$\epsilon_S=${:.3f}".format(ref_s_rate)),
    )
    for column, legend_text in curves:
        sns.lineplot(
            data=df,
            x="Beps",
            y=column,
            err_style="band",
            ci="sd",
            label=legend_text,
            ax=ax,
        )

    # Cosmetics: legend, no y label, fixed display window, grid
    target = ax.axes
    target.legend()
    target.set_ylabel("")
    target.set_ylim([0, 0.4])
    target.set_xlim([0, 0.1])
    target.grid("on")
# Define below the useful mathematical functions introduced in the paper
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment