Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Fidle
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Slim Karkar
Fidle
Commits
665b3434
Commit
665b3434
authored
5 years ago
by
Jean-Luc Parouty Jean-Luc.Parouty@simap.grenoble-inp.fr
Browse files
Options
Downloads
Patches
Plain Diff
Update VAE with leaning rate scheduler
Former-commit-id:
b5ef0e38
parent
9e4e9e88
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
VAE/01-VAE with MNIST.ipynb
+202
-12
202 additions, 12 deletions
VAE/01-VAE with MNIST.ipynb
VAE/modules/callbacks.py
+12
-1
12 additions, 1 deletion
VAE/modules/callbacks.py
VAE/modules/vae.py
+18
-9
18 additions, 9 deletions
VAE/modules/vae.py
with
232 additions
and
22 deletions
VAE/01-VAE with MNIST.ipynb
+
202
−
12
View file @
665b3434
...
@@ -22,9 +22,22 @@
...
@@ -22,9 +22,22 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count":
null
,
"execution_count":
1
,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IDLE 2020 - Practical Work Module\n",
" Version : 0.2.5\n",
" Run time : Tuesday 4 February 2020, 00:10:15\n",
" Matplotlib style : ../fidle/talk.mplstyle\n",
" TensorFlow version : 2.0.0\n",
" Keras version : 2.2.4-tf\n"
]
}
],
"source": [
"source": [
"import numpy as np\n",
"import numpy as np\n",
"\n",
"\n",
...
@@ -57,9 +70,18 @@
...
@@ -57,9 +70,18 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count":
null
,
"execution_count":
2
,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1)\n",
"(10000, 28, 28, 1)\n"
]
}
],
"source": [
"source": [
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
"\n",
...
@@ -80,14 +102,23 @@
...
@@ -80,14 +102,23 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count":
null
,
"execution_count":
3
,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model initialized.\n",
"Outputs will be in : ./run/001\n"
]
}
],
"source": [
"source": [
"# reload(modules.vae)\n",
"# reload(modules.vae)\n",
"# reload(modules.callbacks)\n",
"# reload(modules.callbacks)\n",
"\n",
"\n",
"tag = '00
0
'\n",
"tag = '00
1
'\n",
"\n",
"\n",
"input_shape = (28,28,1)\n",
"input_shape = (28,28,1)\n",
"z_dim = 2\n",
"z_dim = 2\n",
...
@@ -122,7 +153,7 @@
...
@@ -122,7 +153,7 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count":
null
,
"execution_count":
4
,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
...
@@ -141,14 +172,14 @@
...
@@ -141,14 +172,14 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count":
null
,
"execution_count":
5
,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"batch_size = 100\n",
"batch_size = 100\n",
"epochs = 200\n",
"epochs = 200\n",
"image_periodicity = 1 #
in
epoch\n",
"image_periodicity = 1 #
for each
epoch\n",
"chkpt_periodicity = 2
# in
epoch\n",
"chkpt_periodicity = 2
# for each
epoch\n",
"initial_epoch = 0\n",
"initial_epoch = 0\n",
"dataset_size = 1"
"dataset_size = 1"
]
]
...
@@ -157,7 +188,166 @@
...
@@ -157,7 +188,166 @@
"cell_type": "code",
"cell_type": "code",
"execution_count": null,
"execution_count": null,
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/200\n",
" 100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.\n",
"60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560\n",
"Epoch 2/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672\n",
"Epoch 3/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901\n",
"Epoch 4/200\n",
"60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492\n",
"Epoch 5/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584\n",
"Epoch 6/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754\n",
"Epoch 7/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529\n",
"Epoch 8/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985\n",
"Epoch 9/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105\n",
"Epoch 10/200\n",
"WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.\n",
" 100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.\n",
"60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143\n",
"Epoch 11/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014\n",
"Epoch 12/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985\n",
"Epoch 13/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061\n",
"Epoch 14/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184\n",
"Epoch 15/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748\n",
"Epoch 16/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297\n",
"Epoch 17/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122\n",
"Epoch 18/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334\n",
"Epoch 19/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816\n",
"Epoch 20/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283\n",
"Epoch 21/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920\n",
"Epoch 22/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410\n",
"Epoch 23/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587\n",
"Epoch 24/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328\n",
"Epoch 25/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294\n",
"Epoch 26/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666\n",
"Epoch 27/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574\n",
"Epoch 28/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671\n",
"Epoch 29/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817\n",
"Epoch 30/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528\n",
"Epoch 31/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391\n",
"Epoch 32/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385\n",
"Epoch 33/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541\n",
"Epoch 34/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813\n",
"Epoch 35/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935\n",
"Epoch 36/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456\n",
"Epoch 37/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776\n",
"Epoch 38/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706\n",
"Epoch 39/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489\n",
"Epoch 40/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415\n",
"Epoch 41/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354\n",
"Epoch 42/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894\n",
"Epoch 43/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675\n",
"Epoch 44/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726\n",
"Epoch 45/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434\n",
"Epoch 46/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678\n",
"Epoch 47/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547\n",
"Epoch 48/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678\n",
"Epoch 49/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589\n",
"Epoch 50/200\n",
"60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553\n",
"Epoch 51/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228\n",
"Epoch 52/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433\n",
"Epoch 53/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024\n",
"Epoch 54/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027\n",
"Epoch 55/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870\n",
"Epoch 56/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747\n",
"Epoch 57/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205\n",
"Epoch 58/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054\n",
"Epoch 59/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684\n",
"Epoch 60/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603\n",
"Epoch 61/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063\n",
"Epoch 62/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502\n",
"Epoch 63/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372\n",
"Epoch 64/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384\n",
"Epoch 65/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386\n",
"Epoch 66/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019\n",
"Epoch 67/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091\n",
"Epoch 68/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390\n",
"Epoch 69/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525\n",
"Epoch 70/200\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173\n",
"Epoch 71/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661\n",
"Epoch 72/200\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727\n",
"Epoch 73/200\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889\n",
"Epoch 74/200\n",
"43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565"
]
}
],
"source": [
"source": [
"vae.train(x_train,\n",
"vae.train(x_train,\n",
" x_test,\n",
" x_test,\n",
...
...
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
Variational AutoEncoder
Variational AutoEncoder
=======================
=======================
---
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020
## Variational AutoEncoder (VAE), with MNIST Dataset
## Variational AutoEncoder (VAE), with MNIST Dataset
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 1 - Init python stuff
## Step 1 - Init python stuff
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow.keras
as
keras
import
tensorflow.keras
as
keras
import
tensorflow.keras.datasets.mnist
as
mnist
import
tensorflow.keras.datasets.mnist
as
mnist
import
modules.vae
import
modules.vae
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
matplotlib
import
matplotlib
import
seaborn
as
sns
import
seaborn
as
sns
import
os
,
sys
,
h5py
,
json
import
os
,
sys
,
h5py
,
json
from
importlib
import
reload
from
importlib
import
reload
sys
.
path
.
append
(
'
..
'
)
sys
.
path
.
append
(
'
..
'
)
import
fidle.pwk
as
ooo
import
fidle.pwk
as
ooo
ooo
.
init
()
ooo
.
init
()
```
```
%% Output
IDLE 2020 - Practical Work Module
Version : 0.2.5
Run time : Tuesday 4 February 2020, 00:10:15
Matplotlib style : ../fidle/talk.mplstyle
TensorFlow version : 2.0.0
Keras version : 2.2.4-tf
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 2 - Get data
## Step 2 - Get data
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
(
x_train
,
y_train
),
(
x_test
,
y_test
)
=
mnist
.
load_data
()
(
x_train
,
y_train
),
(
x_test
,
y_test
)
=
mnist
.
load_data
()
x_train
=
x_train
.
astype
(
'
float32
'
)
/
255.
x_train
=
x_train
.
astype
(
'
float32
'
)
/
255.
x_train
=
np
.
expand_dims
(
x_train
,
axis
=
3
)
x_train
=
np
.
expand_dims
(
x_train
,
axis
=
3
)
x_test
=
x_test
.
astype
(
'
float32
'
)
/
255.
x_test
=
x_test
.
astype
(
'
float32
'
)
/
255.
x_test
=
np
.
expand_dims
(
x_test
,
axis
=
3
)
x_test
=
np
.
expand_dims
(
x_test
,
axis
=
3
)
print
(
x_train
.
shape
)
print
(
x_train
.
shape
)
print
(
x_test
.
shape
)
print
(
x_test
.
shape
)
```
```
%% Output
(60000, 28, 28, 1)
(10000, 28, 28, 1)
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 3 - Get VAE model
## Step 3 - Get VAE model
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
# reload(modules.vae)
# reload(modules.vae)
# reload(modules.callbacks)
# reload(modules.callbacks)
tag
=
'
00
0
'
tag
=
'
00
1
'
input_shape
=
(
28
,
28
,
1
)
input_shape
=
(
28
,
28
,
1
)
z_dim
=
2
z_dim
=
2
verbose
=
0
verbose
=
0
encoder
=
[
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
32
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
encoder
=
[
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
32
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
}
{
'
type
'
:
'
Conv2D
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
}
]
]
decoder
=
[
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
decoder
=
[
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
64
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
32
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
32
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
2
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
relu
'
},
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
1
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
sigmoid
'
}
{
'
type
'
:
'
Conv2DT
'
,
'
filters
'
:
1
,
'
kernel_size
'
:(
3
,
3
),
'
strides
'
:
1
,
'
padding
'
:
'
same
'
,
'
activation
'
:
'
sigmoid
'
}
]
]
vae
=
modules
.
vae
.
VariationalAutoencoder
(
input_shape
=
input_shape
,
vae
=
modules
.
vae
.
VariationalAutoencoder
(
input_shape
=
input_shape
,
encoder_layers
=
encoder
,
encoder_layers
=
encoder
,
decoder_layers
=
decoder
,
decoder_layers
=
decoder
,
z_dim
=
z_dim
,
z_dim
=
z_dim
,
verbose
=
verbose
,
verbose
=
verbose
,
run_tag
=
tag
)
run_tag
=
tag
)
```
```
%% Output
Model initialized.
Outputs will be in : ./run/001
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 4 - Compile it
## Step 4 - Compile it
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
learning_rate
=
0.0005
learning_rate
=
0.0005
r_loss_factor
=
1000
r_loss_factor
=
1000
vae
.
compile
(
learning_rate
,
r_loss_factor
)
vae
.
compile
(
learning_rate
,
r_loss_factor
)
```
```
%% Cell type:markdown id: tags:
%% Cell type:markdown id: tags:
## Step 5 - Train
## Step 5 - Train
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
batch_size
=
100
batch_size
=
100
epochs
=
200
epochs
=
200
image_periodicity
=
1
#
in
epoch
image_periodicity
=
1
#
for each
epoch
chkpt_periodicity
=
2
# in
epoch
chkpt_periodicity
=
2
# for each
epoch
initial_epoch
=
0
initial_epoch
=
0
dataset_size
=
1
dataset_size
=
1
```
```
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
vae
.
train
(
x_train
,
vae
.
train
(
x_train
,
x_test
,
x_test
,
batch_size
=
batch_size
,
batch_size
=
batch_size
,
epochs
=
epochs
,
epochs
=
epochs
,
image_periodicity
=
image_periodicity
,
image_periodicity
=
image_periodicity
,
chkpt_periodicity
=
chkpt_periodicity
,
chkpt_periodicity
=
chkpt_periodicity
,
initial_epoch
=
initial_epoch
,
initial_epoch
=
initial_epoch
,
dataset_size
=
dataset_size
,
dataset_size
=
dataset_size
,
lr_decay
=
1
lr_decay
=
1
)
)
```
```
%% Output
Train on 60000 samples, validate on 10000 samples
Epoch 1/200
100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.
60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560
Epoch 2/200
60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672
Epoch 3/200
60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901
Epoch 4/200
60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492
Epoch 5/200
60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584
Epoch 6/200
60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754
Epoch 7/200
60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529
Epoch 8/200
60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985
Epoch 9/200
60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105
Epoch 10/200
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.
100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.
60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143
Epoch 11/200
60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014
Epoch 12/200
60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985
Epoch 13/200
60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061
Epoch 14/200
60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184
Epoch 15/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748
Epoch 16/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297
Epoch 17/200
60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122
Epoch 18/200
60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334
Epoch 19/200
60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816
Epoch 20/200
60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283
Epoch 21/200
60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920
Epoch 22/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410
Epoch 23/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587
Epoch 24/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328
Epoch 25/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294
Epoch 26/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666
Epoch 27/200
60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574
Epoch 28/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671
Epoch 29/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817
Epoch 30/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528
Epoch 31/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391
Epoch 32/200
60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385
Epoch 33/200
60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541
Epoch 34/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813
Epoch 35/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935
Epoch 36/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456
Epoch 37/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776
Epoch 38/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706
Epoch 39/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489
Epoch 40/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415
Epoch 41/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354
Epoch 42/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894
Epoch 43/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675
Epoch 44/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726
Epoch 45/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434
Epoch 46/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678
Epoch 47/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547
Epoch 48/200
60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678
Epoch 49/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589
Epoch 50/200
60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553
Epoch 51/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228
Epoch 52/200
60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433
Epoch 53/200
60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024
Epoch 54/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027
Epoch 55/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870
Epoch 56/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747
Epoch 57/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205
Epoch 58/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054
Epoch 59/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684
Epoch 60/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603
Epoch 61/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063
Epoch 62/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502
Epoch 63/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372
Epoch 64/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384
Epoch 65/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386
Epoch 66/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019
Epoch 67/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091
Epoch 68/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390
Epoch 69/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525
Epoch 70/200
60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173
Epoch 71/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661
Epoch 72/200
60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727
Epoch 73/200
60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889
Epoch 74/200
43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
python
```
python
``
`
``
`
%%
Cell
type
:
code
id
:
tags
:
%%
Cell
type
:
code
id
:
tags
:
```
python
```
python
```
```
...
...
This diff is collapsed.
Click to expand it.
VAE/modules/callbacks.py
+
12
−
1
View file @
665b3434
from
tensorflow.keras.callbacks
import
Callback
from
tensorflow.keras.callbacks
import
Callback
,
LearningRateScheduler
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
os
import
os
...
@@ -33,3 +33,14 @@ class ImagesCallback(Callback):
...
@@ -33,3 +33,14 @@ class ImagesCallback(Callback):
def
on_epoch_begin
(
self
,
epoch
,
logs
=
{}):
def
on_epoch_begin
(
self
,
epoch
,
logs
=
{}):
self
.
epoch
+=
1
self
.
epoch
+=
1
def
step_decay_schedule
(
initial_lr
,
decay_factor
=
0.5
,
step_size
=
1
):
'''
Wrapper function to create a LearningRateScheduler with step decay schedule.
'''
def
schedule
(
epoch
):
new_lr
=
initial_lr
*
(
decay_factor
**
np
.
floor
(
epoch
/
step_size
))
return
new_lr
return
LearningRateScheduler
(
schedule
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
VAE/modules/vae.py
+
18
−
9
View file @
665b3434
...
@@ -6,7 +6,7 @@ from tensorflow.keras import backend as K
...
@@ -6,7 +6,7 @@ from tensorflow.keras import backend as K
from
tensorflow.keras.layers
import
Input
,
Conv2D
,
Flatten
,
Dense
,
Conv2DTranspose
,
Reshape
,
Lambda
from
tensorflow.keras.layers
import
Input
,
Conv2D
,
Flatten
,
Dense
,
Conv2DTranspose
,
Reshape
,
Lambda
from
tensorflow.keras.layers
import
Activation
,
BatchNormalization
,
LeakyReLU
,
Dropout
from
tensorflow.keras.layers
import
Activation
,
BatchNormalization
,
LeakyReLU
,
Dropout
from
tensorflow.keras.models
import
Model
from
tensorflow.keras.models
import
Model
from
tensorflow.keras.callbacks
import
ModelCheckpoint
from
tensorflow.keras.callbacks
import
ModelCheckpoint
,
TensorBoard
from
tensorflow.keras.optimizers
import
Adam
from
tensorflow.keras.optimizers
import
Adam
from
tensorflow.keras.utils
import
plot_model
from
tensorflow.keras.utils
import
plot_model
...
@@ -161,18 +161,25 @@ class VariationalAutoencoder():
...
@@ -161,18 +161,25 @@ class VariationalAutoencoder():
self
.
n_test
=
n_test
self
.
n_test
=
n_test
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
# ---- Callbacks
# ---- Callback
: Image
s
images_
callback
=
modules
.
callbacks
.
ImagesCallback
(
initial_epoch
,
image_periodicity
,
self
)
callback
s_images
=
modules
.
callbacks
.
ImagesCallback
(
initial_epoch
,
image_periodicity
,
self
)
# lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1)
# ---- Callback : Learning rate scheduler
lr_sched
=
modules
.
callbacks
.
step_decay_schedule
(
initial_lr
=
self
.
learning_rate
,
decay_factor
=
lr_decay
,
step_size
=
1
)
filename1
=
self
.
run_directory
+
"
/models/model-{epoch:03d}-{loss:.2f}.h5
"
# ---- Callback : Checkpoint
checkpoint1
=
ModelCheckpoint
(
filename1
,
save_freq
=
n_train
*
chkpt_periodicity
,
verbose
=
0
)
filename
=
self
.
run_directory
+
"
/models/model-{epoch:03d}-{loss:.2f}.h5
"
callback_chkpts
=
ModelCheckpoint
(
filename
,
save_freq
=
n_train
*
chkpt_periodicity
,
verbose
=
0
)
filename2
=
self
.
run_directory
+
"
/models/best_model.h5
"
# ---- Callback : Best model
checkpoint2
=
ModelCheckpoint
(
filename2
,
save_best_only
=
True
,
mode
=
'
min
'
,
monitor
=
'
val_loss
'
,
verbose
=
0
)
filename
=
self
.
run_directory
+
"
/models/best_model.h5
"
callback_bestmodel
=
ModelCheckpoint
(
filename
,
save_best_only
=
True
,
mode
=
'
min
'
,
monitor
=
'
val_loss
'
,
verbose
=
0
)
callbacks_list
=
[
checkpoint1
,
checkpoint2
,
images_callback
]
# ---- Callback tensorboard
dirname
=
self
.
run_directory
+
"
/logs
"
callback_tensorboard
=
TensorBoard
(
log_dir
=
dirname
,
histogram_freq
=
1
)
callbacks_list
=
[
callbacks_images
,
callback_chkpts
,
callback_bestmodel
,
callback_tensorboard
,
lr_sched
]
self
.
model
.
fit
(
x_train
[:
n_train
],
x_train
[:
n_train
],
self
.
model
.
fit
(
x_train
[:
n_train
],
x_train
[:
n_train
],
batch_size
=
batch_size
,
batch_size
=
batch_size
,
...
@@ -189,3 +196,5 @@ class VariationalAutoencoder():
...
@@ -189,3 +196,5 @@ class VariationalAutoencoder():
plot_model
(
self
.
model
,
to_file
=
f
'
{
d
}
/model.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
,
expand_nested
=
True
)
plot_model
(
self
.
model
,
to_file
=
f
'
{
d
}
/model.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
,
expand_nested
=
True
)
plot_model
(
self
.
encoder
,
to_file
=
f
'
{
d
}
/encoder.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
)
plot_model
(
self
.
encoder
,
to_file
=
f
'
{
d
}
/encoder.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
)
plot_model
(
self
.
decoder
,
to_file
=
f
'
{
d
}
/decoder.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
)
plot_model
(
self
.
decoder
,
to_file
=
f
'
{
d
}
/decoder.png
'
,
show_shapes
=
True
,
show_layer_names
=
True
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment