diff --git a/VAE/01-VAE with MNIST.ipynb b/VAE/01-VAE with MNIST.ipynb index 5c7450927e8956a11faf1a3355494fa52a20a01b..8eb9c45527c8cdc76b79e66fb3001635d231813b 100644 --- a/VAE/01-VAE with MNIST.ipynb +++ b/VAE/01-VAE with MNIST.ipynb @@ -22,9 +22,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IDLE 2020 - Practical Work Module\n", + " Version : 0.2.5\n", + " Run time : Tuesday 4 February 2020, 00:10:15\n", + " Matplotlib style : ../fidle/talk.mplstyle\n", + " TensorFlow version : 2.0.0\n", + " Keras version : 2.2.4-tf\n" + ] + } + ], "source": [ "import numpy as np\n", "\n", @@ -57,9 +70,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(60000, 28, 28, 1)\n", + "(10000, 28, 28, 1)\n" + ] + } + ], "source": [ "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "\n", @@ -80,14 +102,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model initialized.\n", + "Outputs will be in : ./run/001\n" + ] + } + ], "source": [ "# reload(modules.vae)\n", "# reload(modules.callbacks)\n", "\n", - "tag = '000'\n", + "tag = '001'\n", "\n", "input_shape = (28,28,1)\n", "z_dim = 2\n", @@ -122,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -141,14 +172,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "batch_size = 100\n", "epochs = 200\n", - "image_periodicity = 1 # in epoch\n", - "chkpt_periodicity = 2 # in epoch\n", + "image_periodicity = 1 # for each epoch\n", + "chkpt_periodicity = 2 # for each epoch\n", "initial_epoch = 0\n", "dataset_size = 1" ] @@ -157,7 +188,166 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/200\n", + " 100/60000 [..............................] - ETA: 23:40 - loss: 231.4378 - vae_r_loss: 231.4373 - vae_kl_loss: 5.3801e-04WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.251492). Check your callbacks.\n", + "60000/60000 [==============================] - 6s 101us/sample - loss: 67.7431 - vae_r_loss: 65.0691 - vae_kl_loss: 2.6740 - val_loss: 55.6598 - val_vae_r_loss: 52.4039 - val_vae_kl_loss: 3.2560\n", + "Epoch 2/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 54.0334 - vae_r_loss: 50.4695 - vae_kl_loss: 3.5639 - val_loss: 52.9105 - val_vae_r_loss: 49.1433 - val_vae_kl_loss: 3.7672\n", + "Epoch 3/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 51.8937 - vae_r_loss: 47.9195 - vae_kl_loss: 3.9743 - val_loss: 51.1775 - val_vae_r_loss: 47.0874 - val_vae_kl_loss: 4.0901\n", + "Epoch 4/200\n", + "60000/60000 [==============================] - 4s 59us/sample - loss: 50.4622 - vae_r_loss: 46.1359 - vae_kl_loss: 4.3264 - val_loss: 49.8507 - val_vae_r_loss: 45.2015 - val_vae_kl_loss: 4.6492\n", + "Epoch 5/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 49.3577 - vae_r_loss: 44.8123 - vae_kl_loss: 4.5454 - val_loss: 48.9416 - val_vae_r_loss: 44.3832 - val_vae_kl_loss: 4.5584\n", + "Epoch 6/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 48.5603 - vae_r_loss: 43.8800 - vae_kl_loss: 4.6803 - val_loss: 48.1800 - val_vae_r_loss: 43.5046 - val_vae_kl_loss: 4.6754\n", + "Epoch 7/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 48.0286 - vae_r_loss: 43.2646 - vae_kl_loss: 4.7640 - val_loss: 47.9362 - val_vae_r_loss: 43.2833 - val_vae_kl_loss: 4.6529\n", + "Epoch 8/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 47.6163 - vae_r_loss: 42.7828 - vae_kl_loss: 4.8336 - val_loss: 47.6161 - val_vae_r_loss: 42.7176 - val_vae_kl_loss: 4.8985\n", + "Epoch 9/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 47.2654 - vae_r_loss: 42.3804 - vae_kl_loss: 4.8850 - val_loss: 47.1385 - val_vae_r_loss: 42.2280 - val_vae_kl_loss: 4.9105\n", + "Epoch 10/200\n", + "WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.732872). Check your callbacks.\n", + " 100/60000 [..............................] - ETA: 7:23 - loss: 47.8688 - vae_r_loss: 43.0966 - vae_kl_loss: 4.7722WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (0.366450). Check your callbacks.\n", + "60000/60000 [==============================] - 4s 70us/sample - loss: 46.9698 - vae_r_loss: 42.0353 - vae_kl_loss: 4.9345 - val_loss: 47.0246 - val_vae_r_loss: 42.1103 - val_vae_kl_loss: 4.9143\n", + "Epoch 11/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 46.7538 - vae_r_loss: 41.7733 - vae_kl_loss: 4.9805 - val_loss: 46.9033 - val_vae_r_loss: 41.9019 - val_vae_kl_loss: 5.0014\n", + "Epoch 12/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 46.4962 - vae_r_loss: 41.4867 - vae_kl_loss: 5.0095 - val_loss: 46.6990 - val_vae_r_loss: 41.8006 - val_vae_kl_loss: 4.8985\n", + "Epoch 13/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 46.3232 - vae_r_loss: 41.2603 - vae_kl_loss: 5.0629 - val_loss: 46.6737 - val_vae_r_loss: 41.4675 - val_vae_kl_loss: 5.2061\n", + "Epoch 14/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 46.1505 - vae_r_loss: 41.0678 - vae_kl_loss: 5.0828 - val_loss: 46.3871 - val_vae_r_loss: 41.4687 - val_vae_kl_loss: 4.9184\n", + "Epoch 15/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 45.9750 - vae_r_loss: 40.8533 - vae_kl_loss: 5.1217 - val_loss: 46.1730 - val_vae_r_loss: 41.0982 - val_vae_kl_loss: 5.0748\n", + "Epoch 16/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 45.8053 - vae_r_loss: 40.6467 - vae_kl_loss: 5.1586 - val_loss: 46.2439 - val_vae_r_loss: 41.1142 - val_vae_kl_loss: 5.1297\n", + "Epoch 17/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 45.6415 - vae_r_loss: 40.4657 - vae_kl_loss: 5.1758 - val_loss: 46.0754 - val_vae_r_loss: 41.0632 - val_vae_kl_loss: 5.0122\n", + "Epoch 18/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 45.5121 - vae_r_loss: 40.3147 - vae_kl_loss: 5.1974 - val_loss: 45.8663 - val_vae_r_loss: 40.5329 - val_vae_kl_loss: 5.3334\n", + "Epoch 19/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 45.3686 - vae_r_loss: 40.1475 - vae_kl_loss: 5.2211 - val_loss: 46.2054 - val_vae_r_loss: 41.1238 - val_vae_kl_loss: 5.0816\n", + "Epoch 20/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 45.2161 - vae_r_loss: 39.9703 - vae_kl_loss: 5.2458 - val_loss: 45.7448 - val_vae_r_loss: 40.6166 - val_vae_kl_loss: 5.1283\n", + "Epoch 21/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 45.1159 - vae_r_loss: 39.8419 - vae_kl_loss: 5.2740 - val_loss: 45.8612 - val_vae_r_loss: 40.8692 - val_vae_kl_loss: 4.9920\n", + "Epoch 22/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.9881 - vae_r_loss: 39.7023 - vae_kl_loss: 5.2857 - val_loss: 45.8085 - val_vae_r_loss: 40.2675 - val_vae_kl_loss: 5.5410\n", + "Epoch 23/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.8471 - vae_r_loss: 39.5384 - vae_kl_loss: 5.3087 - val_loss: 45.4330 - val_vae_r_loss: 40.0743 - val_vae_kl_loss: 5.3587\n", + "Epoch 24/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 44.7550 - vae_r_loss: 39.4362 - vae_kl_loss: 5.3188 - val_loss: 45.3320 - val_vae_r_loss: 39.9992 - val_vae_kl_loss: 5.3328\n", + "Epoch 25/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.6692 - vae_r_loss: 39.3461 - vae_kl_loss: 5.3232 - val_loss: 45.3552 - val_vae_r_loss: 40.0258 - val_vae_kl_loss: 5.3294\n", + "Epoch 26/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 44.5891 - vae_r_loss: 39.2333 - vae_kl_loss: 5.3558 - val_loss: 45.2681 - val_vae_r_loss: 39.9015 - val_vae_kl_loss: 5.3666\n", + "Epoch 27/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 44.5072 - vae_r_loss: 39.1374 - vae_kl_loss: 5.3698 - val_loss: 45.3209 - val_vae_r_loss: 39.9636 - val_vae_kl_loss: 5.3574\n", + "Epoch 28/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 44.4180 - vae_r_loss: 39.0149 - vae_kl_loss: 5.4031 - val_loss: 45.2435 - val_vae_r_loss: 39.7765 - val_vae_kl_loss: 5.4671\n", + "Epoch 29/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.3102 - vae_r_loss: 38.9046 - vae_kl_loss: 5.4057 - val_loss: 45.2258 - val_vae_r_loss: 39.8441 - val_vae_kl_loss: 5.3817\n", + "Epoch 30/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 44.2489 - vae_r_loss: 38.8299 - vae_kl_loss: 5.4190 - val_loss: 45.0044 - val_vae_r_loss: 39.6516 - val_vae_kl_loss: 5.3528\n", + "Epoch 31/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.1732 - vae_r_loss: 38.7482 - vae_kl_loss: 5.4249 - val_loss: 45.0000 - val_vae_r_loss: 39.5609 - val_vae_kl_loss: 5.4391\n", + "Epoch 32/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 44.0894 - vae_r_loss: 38.6580 - vae_kl_loss: 5.4314 - val_loss: 44.9769 - val_vae_r_loss: 39.5384 - val_vae_kl_loss: 5.4385\n", + "Epoch 33/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 44.0582 - vae_r_loss: 38.6092 - vae_kl_loss: 5.4490 - val_loss: 44.9346 - val_vae_r_loss: 39.3805 - val_vae_kl_loss: 5.5541\n", + "Epoch 34/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.9458 - vae_r_loss: 38.4818 - vae_kl_loss: 5.4640 - val_loss: 45.0624 - val_vae_r_loss: 39.5811 - val_vae_kl_loss: 5.4813\n", + "Epoch 35/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.8850 - vae_r_loss: 38.4031 - vae_kl_loss: 5.4819 - val_loss: 45.0285 - val_vae_r_loss: 39.5350 - val_vae_kl_loss: 5.4935\n", + "Epoch 36/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 43.8698 - vae_r_loss: 38.3779 - vae_kl_loss: 5.4918 - val_loss: 44.9170 - val_vae_r_loss: 39.5714 - val_vae_kl_loss: 5.3456\n", + "Epoch 37/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7739 - vae_r_loss: 38.2723 - vae_kl_loss: 5.5016 - val_loss: 44.8441 - val_vae_r_loss: 39.3665 - val_vae_kl_loss: 5.4776\n", + "Epoch 38/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.7084 - vae_r_loss: 38.1933 - vae_kl_loss: 5.5151 - val_loss: 44.9233 - val_vae_r_loss: 39.5526 - val_vae_kl_loss: 5.3706\n", + "Epoch 39/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.6626 - vae_r_loss: 38.1320 - vae_kl_loss: 5.5306 - val_loss: 44.6793 - val_vae_r_loss: 39.2304 - val_vae_kl_loss: 5.4489\n", + "Epoch 40/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 43.5838 - vae_r_loss: 38.0592 - vae_kl_loss: 5.5246 - val_loss: 44.6130 - val_vae_r_loss: 39.0715 - val_vae_kl_loss: 5.5415\n", + "Epoch 41/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.5194 - vae_r_loss: 37.9840 - vae_kl_loss: 5.5354 - val_loss: 44.8512 - val_vae_r_loss: 39.6158 - val_vae_kl_loss: 5.2354\n", + "Epoch 42/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.5129 - vae_r_loss: 37.9786 - vae_kl_loss: 5.5343 - val_loss: 44.6991 - val_vae_r_loss: 39.2098 - val_vae_kl_loss: 5.4894\n", + "Epoch 43/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.4707 - vae_r_loss: 37.9237 - vae_kl_loss: 5.5470 - val_loss: 44.7121 - val_vae_r_loss: 39.2446 - val_vae_kl_loss: 5.4675\n", + "Epoch 44/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3832 - vae_r_loss: 37.8227 - vae_kl_loss: 5.5604 - val_loss: 44.9172 - val_vae_r_loss: 39.3446 - val_vae_kl_loss: 5.5726\n", + "Epoch 45/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.3868 - vae_r_loss: 37.8075 - vae_kl_loss: 5.5793 - val_loss: 44.5718 - val_vae_r_loss: 39.0284 - val_vae_kl_loss: 5.5434\n", + "Epoch 46/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.2774 - vae_r_loss: 37.6953 - vae_kl_loss: 5.5821 - val_loss: 44.6954 - val_vae_r_loss: 39.1276 - val_vae_kl_loss: 5.5678\n", + "Epoch 47/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.2765 - vae_r_loss: 37.6813 - vae_kl_loss: 5.5952 - val_loss: 44.6153 - val_vae_r_loss: 38.9606 - val_vae_kl_loss: 5.6547\n", + "Epoch 48/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 43.2385 - vae_r_loss: 37.6431 - vae_kl_loss: 5.5954 - val_loss: 44.5508 - val_vae_r_loss: 39.0830 - val_vae_kl_loss: 5.4678\n", + "Epoch 49/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.1847 - vae_r_loss: 37.5822 - vae_kl_loss: 5.6025 - val_loss: 44.8277 - val_vae_r_loss: 39.1688 - val_vae_kl_loss: 5.6589\n", + "Epoch 50/200\n", + "60000/60000 [==============================] - 4s 58us/sample - loss: 43.1557 - vae_r_loss: 37.5533 - vae_kl_loss: 5.6024 - val_loss: 44.5082 - val_vae_r_loss: 38.9529 - val_vae_kl_loss: 5.5553\n", + "Epoch 51/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0726 - vae_r_loss: 37.4533 - vae_kl_loss: 5.6193 - val_loss: 44.6332 - val_vae_r_loss: 38.9104 - val_vae_kl_loss: 5.7228\n", + "Epoch 52/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 43.1003 - vae_r_loss: 37.4708 - vae_kl_loss: 5.6295 - val_loss: 44.5279 - val_vae_r_loss: 39.0846 - val_vae_kl_loss: 5.4433\n", + "Epoch 53/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 43.0121 - vae_r_loss: 37.3923 - vae_kl_loss: 5.6198 - val_loss: 44.5675 - val_vae_r_loss: 38.9651 - val_vae_kl_loss: 5.6024\n", + "Epoch 54/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9750 - vae_r_loss: 37.3273 - vae_kl_loss: 5.6477 - val_loss: 44.6084 - val_vae_r_loss: 39.0057 - val_vae_kl_loss: 5.6027\n", + "Epoch 55/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9669 - vae_r_loss: 37.3124 - vae_kl_loss: 5.6545 - val_loss: 44.4369 - val_vae_r_loss: 38.7499 - val_vae_kl_loss: 5.6870\n", + "Epoch 56/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.9172 - vae_r_loss: 37.2666 - vae_kl_loss: 5.6506 - val_loss: 44.4817 - val_vae_r_loss: 38.8071 - val_vae_kl_loss: 5.6747\n", + "Epoch 57/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.8719 - vae_r_loss: 37.2088 - vae_kl_loss: 5.6630 - val_loss: 44.7545 - val_vae_r_loss: 39.1340 - val_vae_kl_loss: 5.6205\n", + "Epoch 58/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8724 - vae_r_loss: 37.2070 - vae_kl_loss: 5.6654 - val_loss: 44.4428 - val_vae_r_loss: 38.8374 - val_vae_kl_loss: 5.6054\n", + "Epoch 59/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.8085 - vae_r_loss: 37.1356 - vae_kl_loss: 5.6729 - val_loss: 44.3657 - val_vae_r_loss: 38.8973 - val_vae_kl_loss: 5.4684\n", + "Epoch 60/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7711 - vae_r_loss: 37.1025 - vae_kl_loss: 5.6687 - val_loss: 44.5526 - val_vae_r_loss: 38.7923 - val_vae_kl_loss: 5.7603\n", + "Epoch 61/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.7549 - vae_r_loss: 37.0712 - vae_kl_loss: 5.6837 - val_loss: 44.6274 - val_vae_r_loss: 39.1211 - val_vae_kl_loss: 5.5063\n", + "Epoch 62/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.7314 - vae_r_loss: 37.0368 - vae_kl_loss: 5.6946 - val_loss: 44.3828 - val_vae_r_loss: 38.8327 - val_vae_kl_loss: 5.5502\n", + "Epoch 63/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6688 - vae_r_loss: 36.9835 - vae_kl_loss: 5.6853 - val_loss: 44.4869 - val_vae_r_loss: 38.8497 - val_vae_kl_loss: 5.6372\n", + "Epoch 64/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.6714 - vae_r_loss: 36.9633 - vae_kl_loss: 5.7080 - val_loss: 44.4562 - val_vae_r_loss: 38.7178 - val_vae_kl_loss: 5.7384\n", + "Epoch 65/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.6547 - vae_r_loss: 36.9360 - vae_kl_loss: 5.7187 - val_loss: 44.4947 - val_vae_r_loss: 38.8561 - val_vae_kl_loss: 5.6386\n", + "Epoch 66/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5807 - vae_r_loss: 36.8625 - vae_kl_loss: 5.7182 - val_loss: 44.4270 - val_vae_r_loss: 38.7251 - val_vae_kl_loss: 5.7019\n", + "Epoch 67/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5664 - vae_r_loss: 36.8466 - vae_kl_loss: 5.7197 - val_loss: 44.5878 - val_vae_r_loss: 38.8787 - val_vae_kl_loss: 5.7091\n", + "Epoch 68/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.5503 - vae_r_loss: 36.8269 - vae_kl_loss: 5.7235 - val_loss: 44.6236 - val_vae_r_loss: 38.8846 - val_vae_kl_loss: 5.7390\n", + "Epoch 69/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.5057 - vae_r_loss: 36.7706 - vae_kl_loss: 5.7352 - val_loss: 44.5720 - val_vae_r_loss: 38.9196 - val_vae_kl_loss: 5.6525\n", + "Epoch 70/200\n", + "60000/60000 [==============================] - 3s 57us/sample - loss: 42.4955 - vae_r_loss: 36.7553 - vae_kl_loss: 5.7402 - val_loss: 44.4059 - val_vae_r_loss: 38.8886 - val_vae_kl_loss: 5.5173\n", + "Epoch 71/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4649 - vae_r_loss: 36.7251 - vae_kl_loss: 5.7398 - val_loss: 44.5864 - val_vae_r_loss: 38.8203 - val_vae_kl_loss: 5.7661\n", + "Epoch 72/200\n", + "60000/60000 [==============================] - 3s 58us/sample - loss: 42.4907 - vae_r_loss: 36.7440 - vae_kl_loss: 5.7467 - val_loss: 44.3493 - val_vae_r_loss: 38.6765 - val_vae_kl_loss: 5.6727\n", + "Epoch 73/200\n", + "60000/60000 [==============================] - 3s 56us/sample - loss: 42.4224 - vae_r_loss: 36.6558 - vae_kl_loss: 5.7666 - val_loss: 44.5477 - val_vae_r_loss: 38.7588 - val_vae_kl_loss: 5.7889\n", + "Epoch 74/200\n", + "43100/60000 [====================>.........] - ETA: 0s - loss: 42.3141 - vae_r_loss: 36.5576 - vae_kl_loss: 5.7565" + ] + } + ], "source": [ "vae.train(x_train,\n", " x_test,\n", diff --git a/VAE/modules/callbacks.py b/VAE/modules/callbacks.py index 63adb23aede91492630a7fd5e4b7dfde817b8304..e7d6e3f88a00b4fe12570663b9fbc6c5f3104465 100644 --- a/VAE/modules/callbacks.py +++ b/VAE/modules/callbacks.py @@ -1,4 +1,4 @@ -from tensorflow.keras.callbacks import Callback +from tensorflow.keras.callbacks import Callback, LearningRateScheduler import numpy as np import matplotlib.pyplot as plt import os @@ -33,3 +33,14 @@ class ImagesCallback(Callback): def on_epoch_begin(self, epoch, logs={}): self.epoch += 1 + +def step_decay_schedule(initial_lr, decay_factor=0.5, step_size=1): + ''' + Wrapper function to create a LearningRateScheduler with step decay schedule. + ''' + def schedule(epoch): + new_lr = initial_lr * (decay_factor ** np.floor(epoch/step_size)) + + return new_lr + + return LearningRateScheduler(schedule) \ No newline at end of file diff --git a/VAE/modules/vae.py b/VAE/modules/vae.py index f9efa3b28a5264e49631b3f5b29daef4a7e74df7..e98409b9e1fa27bb01ec3e0d24413bbe87588ca5 100644 --- a/VAE/modules/vae.py +++ b/VAE/modules/vae.py @@ -6,7 +6,7 @@ from tensorflow.keras import backend as K from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda from tensorflow.keras.layers import Activation, BatchNormalization, LeakyReLU, Dropout from tensorflow.keras.models import Model -from tensorflow.keras.callbacks import ModelCheckpoint +from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard from tensorflow.keras.optimizers import Adam from tensorflow.keras.utils import plot_model @@ -161,18 +161,25 @@ class VariationalAutoencoder(): self.n_test = n_test self.batch_size = batch_size - # ---- Callbacks - images_callback = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self) + # ---- Callback : Images + callbacks_images = modules.callbacks.ImagesCallback(initial_epoch, image_periodicity, self) -# lr_sched = step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) + # ---- Callback : Learning rate scheduler + lr_sched = modules.callbacks.step_decay_schedule(initial_lr=self.learning_rate, decay_factor=lr_decay, step_size=1) - filename1 = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5" - checkpoint1 = ModelCheckpoint(filename1, save_freq=n_train*chkpt_periodicity ,verbose=0) + # ---- Callback : Checkpoint + filename = self.run_directory+"/models/model-{epoch:03d}-{loss:.2f}.h5" + callback_chkpts = ModelCheckpoint(filename, save_freq=n_train*chkpt_periodicity ,verbose=0) - filename2 = self.run_directory+"/models/best_model.h5" - checkpoint2 = ModelCheckpoint(filename2, save_best_only=True, mode='min',monitor='val_loss',verbose=0) + # ---- Callback : Best model + filename = self.run_directory+"/models/best_model.h5" + callback_bestmodel = ModelCheckpoint(filename, save_best_only=True, mode='min',monitor='val_loss',verbose=0) - callbacks_list = [checkpoint1, checkpoint2, images_callback] + # ---- Callback tensorboard + dirname = self.run_directory+"/logs" + callback_tensorboard = TensorBoard(log_dir=dirname, histogram_freq=1) + + callbacks_list = [callbacks_images, callback_chkpts, callback_bestmodel, callback_tensorboard, lr_sched] self.model.fit(x_train[:n_train], x_train[:n_train], batch_size = batch_size, @@ -189,3 +196,5 @@ class VariationalAutoencoder(): plot_model(self.model, to_file=f'{d}/model.png', show_shapes = True, show_layer_names = True, expand_nested=True) plot_model(self.encoder, to_file=f'{d}/encoder.png', show_shapes = True, show_layer_names = True) plot_model(self.decoder, to_file=f'{d}/decoder.png', show_shapes = True, show_layer_names = True) + +