{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n", "\n", "# <!-- TITLE --> [SHEEP2] - A WGAN-GP to Draw a Sheep\n", "<!-- DESC --> Episode 2 : Draw me a sheep, revisited with a WGAN-GP\n", "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", "## Objectives :\n", " - Build and train a WGAN-GP model with the Quick Draw dataset\n", " - Understand WGAN-GP\n", "\n", "The [Quick draw dataset](https://quickdraw.withgoogle.com/data) contains about 50,000,000 drawings, made by real people...  \n", "We are using a subset of 117,555 sheep drawings.  \n", "To get the dataset : [https://github.com/googlecreativelab/quickdraw-dataset](https://github.com/googlecreativelab/quickdraw-dataset)  \n", "Datasets as numpy bitmap files : [https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap)  \n", "Sheep dataset : [https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy](https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy) (94.3 MB)\n", "\n", "\n", "## What we're going to do :\n", "\n", " - Have a look at the dataset\n", " - Define a GAN model\n", " - Build the model\n", " - Train it\n", " - Analyze the results\n", "\n", "## Acknowledgements :\n", "Thanks to **François Chollet**, whose work is the basis of this example.  
\n", "See : [https://keras.io/examples/](https://keras.io/examples/)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1 - Init python stuff" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from skimage import io\n", "import sys\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from tensorflow.keras.callbacks import TensorBoard\n", "\n", "from modules.models import WGANGP\n", "from modules.callbacks import ImagesCallback\n", "\n", "import fidle\n", "\n", "# Init Fidle environment\n", "run_id, run_dir, datasets_dir = fidle.init('SHEEP2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2 - Parameters\n", "`scale` : Fraction of the dataset to use. With scale=1, training takes 5-6 minutes on a V100 GPU... and more than 2 hours on a CPU!  \n", "`latent_dim` : Dimension of the latent space, 128 for example  \n", "`fit_verbosity` : Verbosity during training : 0 = silent, 1 = progress bar, 2 = one line per epoch  \n", "`num_img` : Number of images to visualize\n", "\n", "**Notes:**\n", "- The settings below (scale=0.01) let the notebook run on a laptop, but are too small to give even a minimal result!  \n", "- For a decent result, you need something like: scale=1.  
\n", "- Since convergence is much better here, the number of epochs can stay at epochs=3 :-)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "latent_dim = 128\n", "\n", "scale = 0.01\n", "epochs = 3\n", "n_critic = 2\n", "batch_size = 64\n", "num_img = 12\n", "fit_verbosity = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Override parameters (batch mode) - Just forget this cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.override('scale', 'latent_dim', 'epochs', 'batch_size', 'num_img', 'fit_verbosity')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3 - Load dataset and have a look  \n", "Load the sheep drawings as numpy bitmaps..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load dataset\n", "x_data = np.load(datasets_dir+'/QuickDraw/origine/sheep.npy')\n", "print('Original dataset shape : ',x_data.shape)\n", "\n", "# Rescale : keep only the first `scale` fraction of the drawings\n", "n=int(scale*len(x_data))\n", "x_data = x_data[:n]\n", "print('Rescaled dataset shape : ',x_data.shape)\n", "\n", "# Normalize to [0,1], reshape to (n,28,28,1) and shuffle\n", "x_data = x_data/255\n", "x_data = x_data.reshape(-1,28,28,1)\n", "np.random.shuffle(x_data)\n", "print('Final dataset shape : ',x_data.shape)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and have a look :  \n", "Note : These sheep were drawn ... by real humans!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.scrawler.images( x_data.reshape(-1,28,28), indices=range(72), columns=12, x_size=1, y_size=1, \n", " y_padding=0,spines_alpha=0, save_as='01-Sheeps')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4 - Create a discriminator" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(28, 28, 1))\n", "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\")(inputs)\n", "x = layers.LeakyReLU(alpha=0.2)(x)\n", "x = layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\")(x)\n", "x = layers.LeakyReLU(alpha=0.2)(x)\n", "x = layers.Conv2D(128, kernel_size=4, strides=2, padding=\"same\")(x)\n", "x = layers.LeakyReLU(alpha=0.2)(x)\n", "x = layers.Flatten()(x)\n", "x = layers.Dropout(0.2)(x)\n", "c = layers.Dense(1)(x)\n", "\n", "discriminator = keras.Model(inputs, c, name=\"discriminator\")\n", "discriminator.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5 - Create a generator" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = keras.Input(shape=(latent_dim,))\n", "x = layers.Dense(7 * 7 * 64)(inputs)\n", "x = layers.Reshape((7, 7, 64))(x)\n", "x = layers.UpSampling2D()(x)\n", "x = layers.Conv2D(128, kernel_size=3, strides=1, padding='same', activation='relu')(x)\n", "x = layers.UpSampling2D()(x)\n", "x = layers.Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)\n", "outputs = layers.Conv2D(1, kernel_size=5, strides=1, padding=\"same\", activation='sigmoid')(x)\n", "\n", "generator = keras.Model(inputs, outputs, name=\"generator\")\n", "generator.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 6 - Build, compile and train our WGAN-GP  \n", "Duration : 5 min on a V100, with : scale=0.5, epochs=10, n_critic=2  \n", "First, clean saved 
images :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!rm $run_dir/images/*.jpg >/dev/null 2>&1 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build our model :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gan = WGANGP(discriminator=discriminator, generator=generator, latent_dim=latent_dim, n_critic=n_critic)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gan.compile(\n", "# discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0001),\n", "# generator_optimizer = keras.optimizers.Adam(learning_rate=0.0001)\n", " discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9),\n", " generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add a callback to save images, train our WGAN-GP model and save it :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imagesCallback = ImagesCallback(num_img=num_img, latent_dim=latent_dim, run_dir=f'{run_dir}/images')\n", "\n", "history = gan.fit( x_data, \n", " epochs=epochs, \n", " batch_size=batch_size, \n", " callbacks=[imagesCallback], \n", " verbose=fit_verbosity )\n", "\n", "gan.save(f'{run_dir}/models/model.h5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 7 - History" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.scrawler.history(history, plot={'loss':['d_loss','g_loss']}, save_as='01-history')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reload the images saved by the callback at each epoch\n", "images=[]\n", "for epoch in range(0,epochs,1):\n", "    for i in range(num_img):\n", "        filename = f'{run_dir}/images/image-{epoch:03d}-{i:02d}.jpg'\n", "        image = io.imread(filename)\n", "        
images.append(image) \n", "\n", "fidle.scrawler.images(images, None, indices='all', columns=num_img, x_size=1,y_size=1, interpolation=None, y_padding=0, spines_alpha=0, save_as='04-learning')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 8 - Generation\n", "Reload our saved model :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "gan.reload(f'{run_dir}/models/model.h5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate some images from the latent space :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nb_images = 12*15\n", "\n", "z = np.random.normal(size=(nb_images,latent_dim))\n", "images = gan.predict(z, verbose=0)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot them :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.scrawler.images(images, None, indices='all', columns=num_img, x_size=1,y_size=1, interpolation=None, y_padding=0, spines_alpha=0, save_as='04-learning')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fidle.end()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.2 ('fidle-env')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "vscode": { "interpreter": { "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d" } } }, "nbformat": 4, "nbformat_minor": 4 }