diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb index d3687e449d70fa5ff984345f1da4f321817e8f06..6647aea04e763dcba088694529ff31b14956fe8b 100644 --- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb +++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb @@ -4,15 +4,30 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# GAN using PyTorch Lightning \n", + "<img width=\"800px\" src=\"../fidle/img/00-Fidle-header-01.svg\"></img>\n", "\n", - "See : \n", - "- https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html\n", - "- https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/\n", + "# <!-- TITLE --> [SHEEP3] - A DCGAN to Draw a Sheep, with Pytorch Lightning\n", + "<!-- DESC --> Episode 1 : Draw me a sheep, revisited with a DCGAN, writing in Pytorch Lightning\n", + "<!-- AUTHOR : Jean-Luc Parouty (CNRS/SIMaP) -->\n", "\n", + "## Objectives :\n", + " - Build and train a DCGAN model with the Quick Draw dataset\n", + " - Understanding DCGAN\n", "\n", - "Note : Need \n", - "```pip install ipywidgets lightning tqdm```" + "The [Quick draw dataset](https://quickdraw.withgoogle.com/data) contains about 50.000.000 drawings, made by real people... \n", + "We are using a subset of 117.555 of Sheep drawings \n", + "To get the dataset : [https://github.com/googlecreativelab/quickdraw-dataset](https://github.com/googlecreativelab/quickdraw-dataset) \n", + "Datasets in numpy bitmap file : [https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap](https://console.cloud.google.com/storage/quickdraw_dataset/full/numpy_bitmap) \n", + "Sheep dataset : [https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy](https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/sheep.npy) (94.3 Mo)\n", + "\n", + "\n", + "## What we're going to do :\n", + "\n", + " - Have a look to the dataset\n", + " - Defining a GAN model\n", + " - Build the model\n", + " - Train it\n", + " - Have a look of the results" ] }, { @@ -33,6 +48,7 @@ "source": [ "import os\n", "import sys\n", + "import shutil\n", "\n", "import numpy as np\n", "import torch\n", @@ -82,7 +98,7 @@ "generator_class = 'Generator_2'\n", "discriminator_class = 'Discriminator_2' \n", " \n", - "scale = .01\n", + "scale = .005\n", "epochs = 5\n", "batch_size = 32\n", "num_img = 36\n", @@ -92,6 +108,25 @@ "data_shape = (28,28,1)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Cleaning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# You can comment these lines to keep each run...\n", + "shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)\n", + "shutil.rmtree(f'{run_dir}/models', ignore_errors=True)\n", + "shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -168,9 +203,19 @@ }, "outputs": [], "source": [ + "# ---- A little piece of black magic to instantiate Generator and Discriminator from their class names\n", + "#\n", + "def get_classByName(class_name, **args):\n", + " module=sys.modules['__main__']\n", + " class_ = getattr(module, class_name)\n", + " instance_ = class_(**args)\n", + " return instance_\n", + "\n", + "# ----Get it, and play with them\n", + "#\n", "print('\\nInstantiation :\\n')\n", - "generator = Generator_2(latent_dim=latent_dim, data_shape=data_shape)\n", - "discriminator = Discriminator_2(latent_dim=latent_dim, data_shape=data_shape)\n", + "generator = get_classByName( generator_class, latent_dim=latent_dim, data_shape=data_shape)\n", + "discriminator = get_classByName( discriminator_class, latent_dim=latent_dim, data_shape=data_shape)\n", "\n", "print('\\nFew tests :\\n')\n", "z = torch.randn(batch_size, latent_dim)\n", @@ -273,7 +318,6 @@ "\n", "trainer = Trainer(\n", " accelerator = \"auto\",\n", - "# devices = 1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", " max_epochs = epochs,\n", " callbacks = [callback_progressBar, callback_checkpoints],\n", " log_every_n_steps = batch_size,\n", @@ -284,10 +328,12 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 6 - Reload a checkpoint" + "## Step 6 - Reload our best model\n", + "Note : " ] }, { @@ -296,7 +342,7 @@ "metadata": {}, "outputs": [], "source": [ - "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/last-v1.ckpt')" + "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/bestModel.ckpt')" ] }, { @@ -307,8 +353,6 @@ "source": [ "nb_images = 32\n", "\n", - "# z = np.random.normal(size=(nb_images,latent_dim))\n", - "\n", "z = torch.randn(nb_images, latent_dim)\n", "print('z size : ',z.size())\n", "\n", @@ -325,14 +369,17 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "fidle.end()" + ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "---\n", + "<img width=\"80px\" src=\"../fidle/img/00-Fidle-logo-01.svg\"></img>" + ] } ], "metadata": { diff --git a/DCGAN-PyTorch/modules/Discriminators.py b/DCGAN-PyTorch/modules/Discriminators.py index 86c2f706115d86ef5ce835d4d6b61ab750107c8e..2458a61dc7bf036f22227634001eafe702de3a9c 100644 --- a/DCGAN-PyTorch/modules/Discriminators.py +++ b/DCGAN-PyTorch/modules/Discriminators.py @@ -35,7 +35,6 @@ class Discriminator_1(nn.Module): ) def forward(self, img): - # img_flat = img.view(img.size(0), -1) validity = self.model(img) return validity @@ -79,7 +78,7 @@ class Discriminator_2(nn.Module): ) def forward(self, img): - img_nchw = img.permute(0, 3, 1, 2) # from NHWC to NCHW + img_nchw = img.permute(0, 3, 1, 2) # reformat from NHWC to NCHW validity = self.model(img_nchw) return validity \ No newline at end of file diff --git a/DCGAN-PyTorch/modules/Generators.py b/DCGAN-PyTorch/modules/Generators.py index 70e523d747187daf5cd1acf38d9ea32cd2f44cad..9b104d579469f51dfda08b1332c9b100b6fddaa4 100644 --- a/DCGAN-PyTorch/modules/Generators.py +++ b/DCGAN-PyTorch/modules/Generators.py @@ -67,14 +67,14 @@ class Generator_2(nn.Module): nn.Linear(latent_dim, 7*7*64), nn.Unflatten(1, (64,7,7)), - nn.UpsamplingNearest2d( scale_factor=2 ), - # nn.UpsamplingBilinear2d( scale_factor=2 ), + # nn.UpsamplingNearest2d( scale_factor=2 ), + nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 64,128, (3,3), stride=(1,1), padding=(1,1) ), nn.ReLU(), nn.BatchNorm2d(128), - nn.UpsamplingNearest2d( scale_factor=2 ), - # nn.UpsamplingBilinear2d( scale_factor=2 ), + # nn.UpsamplingNearest2d( scale_factor=2 ), + nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 128,256, (3,3), stride=(1,1), padding=(1,1)), nn.ReLU(), nn.BatchNorm2d(256), @@ -85,9 +85,10 @@ class Generator_2(nn.Module): ) def forward(self, z): - img = self.model(z) - img = img.view(img.size(0), *self.img_shape) # batch_size x 1 x W x H => batch_size x W x H x 1 - return img + img_nchw = self.model(z) + img_nhwc = img_nchw.permute(0, 2, 3, 1) # reformat from NCHW to NHWC + # img = img.view(img.size(0), *self.img_shape) # reformat from NCHW to NHWC + return img_nhwc