From 717b7be895e412c1018ea2190daf0a99f45bd6de Mon Sep 17 00:00:00 2001
From: Jean-Luc Parouty <Jean-Luc.Parouty@simap.grenoble-inp.fr>
Date: Fri, 24 Feb 2023 20:47:28 +0100
Subject: [PATCH] Add GAN example with PyTorch Lightning

---
 DCGAN-PyTorch/01-DCGAN-PL.ipynb              | 344 +++++++++++++++++++
 DCGAN-PyTorch/modules/Discriminators.py      |  41 +++
 DCGAN-PyTorch/modules/GAN.py                 | 179 ++++++++++
 DCGAN-PyTorch/modules/Generators.py          |  53 +++
 DCGAN-PyTorch/modules/QuickDrawDataModule.py |  71 ++++
 DCGAN-PyTorch/modules/SmartProgressBar.py    |  70 ++++
 6 files changed, 758 insertions(+)
 create mode 100644 DCGAN-PyTorch/01-DCGAN-PL.ipynb
 create mode 100644 DCGAN-PyTorch/modules/Discriminators.py
 create mode 100644 DCGAN-PyTorch/modules/GAN.py
 create mode 100644 DCGAN-PyTorch/modules/Generators.py
 create mode 100644 DCGAN-PyTorch/modules/QuickDrawDataModule.py
 create mode 100644 DCGAN-PyTorch/modules/SmartProgressBar.py

diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
new file mode 100644
index 0000000..fdb048a
--- /dev/null
+++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
@@ -0,0 +1,344 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# GAN using PyTorch Lightning \n",
+    "\n",
+    "See : \n",
+    "- https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html\n",
+    "- https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/\n",
+    "\n",
+    "\n",
+    "Note : Need \n",
+    "```pip install ipywidgets lightning tqdm```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 1 - Init and parameters\n",
+    "#### Python init"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torchvision\n",
+    "import torchvision.transforms as transforms\n",
+    "from lightning import LightningDataModule, LightningModule, Trainer\n",
+    "from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar\n",
+    "from lightning.pytorch.callbacks.progress.base          import ProgressBarBase\n",
+    "from lightning.pytorch.callbacks                        import ModelCheckpoint\n",
+    "from lightning.pytorch.loggers.tensorboard              import TensorBoardLogger\n",
+    "\n",
+    "from tqdm import tqdm\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "import fidle\n",
+    "\n",
+    "from modules.SmartProgressBar    import SmartProgressBar\n",
+    "from modules.QuickDrawDataModule import QuickDrawDataModule\n",
+    "\n",
+    "from modules.GAN                 import GAN\n",
+    "from modules.Generators          import *\n",
+    "from modules.Discriminators      import *\n",
+    "\n",
+    "# Init Fidle environment\n",
+    "run_id, run_dir, datasets_dir = fidle.init('SHEEP3')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Few parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "latent_dim          = 128\n",
+    "    \n",
+    "generator_class     = 'Generator_1'\n",
+    "discriminator_class = 'Discriminator_1'    \n",
+    "    \n",
+    "scale               = .05\n",
+    "epochs              = 15\n",
+    "batch_size          = 32\n",
+    "num_img             = 36\n",
+    "fit_verbosity       = 2\n",
+    "    \n",
+    "dataset_file        = datasets_dir+'/QuickDraw/origine/sheep.npy' \n",
+    "data_shape          = (28,28,1)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 2 - Get some nice data"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Get a Nice DataModule\n",
+    "Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py)   \n",
+    "This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=8)\n",
+    "dm.setup()"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Have a look"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dl         = dm.train_dataloader()\n",
+    "batch_data = next(iter(dl))\n",
+    "\n",
+    "fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, \n",
+    "                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 3 - Get a nice GAN model\n",
+    "\n",
+    "Our Generators are defined in [./modules/Generators.py](./modules/Generators.py)  \n",
+    "Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py)  \n",
+    "\n",
+    "\n",
+    "Our GAN is defined in [./modules/GAN.py](./modules/GAN.py)  "
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Basic test - Just to be sure it (could) works... ;-)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "print('\\nInstantiation :\\n')\n",
+    "generator     = Generator_1(latent_dim=latent_dim, data_shape=data_shape)\n",
+    "discriminator = Discriminator_1(latent_dim=latent_dim, data_shape=data_shape)\n",
+    "\n",
+    "print('\\nFew tests :\\n')\n",
+    "z = torch.randn(batch_size, latent_dim)\n",
+    "print('z size        : ',z.size())\n",
+    "\n",
+    "fake_img = generator.forward(z)\n",
+    "print('fake_img      : ', fake_img.size())\n",
+    "\n",
+    "p = discriminator.forward(fake_img)\n",
+    "print('pred fake     : ', p.size())\n",
+    "\n",
+    "print('batch_data    : ',batch_data.size())\n",
+    "\n",
+    "p = discriminator.forward(batch_data)\n",
+    "print('pred real     : ', p.size())\n",
+    "\n",
+    "nimg = fake_img.detach().numpy()\n",
+    "fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, \n",
+    "                       y_padding=0,spines_alpha=0, save_as='01-Sheeps')"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### GAN model\n",
+    "To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py)  \n",
+    "Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gan = GAN( data_shape          = data_shape,  \n",
+    "           batch_size          = batch_size, \n",
+    "           latent_dim          = latent_dim, \n",
+    "           generator_class     = generator_class, \n",
+    "           discriminator_class = discriminator_class)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5 - Train it !\n",
+    "#### Instantiate Callbacks, Logger & co.\n",
+    "More about :\n",
+    "- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)\n",
+    "- [modelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "# ---- for tensorboard logs\n",
+    "#\n",
+    "logger       = TensorBoardLogger(       save_dir       = f'{run_dir}',\n",
+    "                                        name           = 'tb_logs'  )\n",
+    "\n",
+    "# ---- To save checkpoints\n",
+    "#\n",
+    "callback_checkpoints = ModelCheckpoint( dirpath        = f'{run_dir}/models', \n",
+    "                                        filename       = 'bestModel', \n",
+    "                                        save_top_k     = 1, \n",
+    "                                        save_last      = True,\n",
+    "                                        every_n_epochs = 1, \n",
+    "                                        monitor        = \"g_loss\")\n",
+    "\n",
+    "# ---- To have a nive progress bar\n",
+    "#\n",
+    "callback_progressBar = SmartProgressBar(verbosity=2)          # Usable evertywhere\n",
+    "# progress_bar = TQDMProgressBar(refresh_rate=1)              # Usable in real jupyter lab (bug in vscode)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Train it"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "trainer = Trainer(\n",
+    "    accelerator        = \"auto\",\n",
+    "#    devices            = 1 if torch.cuda.is_available() else None,  # limiting got iPython runs\n",
+    "    max_epochs         = epochs,\n",
+    "    callbacks          = [callback_progressBar, callback_checkpoints],\n",
+    "    log_every_n_steps  = batch_size,\n",
+    "    logger             = logger\n",
+    ")\n",
+    "\n",
+    "trainer.fit(gan, dm)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 6 - Reload a checkpoint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# gan = GAN.load_from_checkpoint('./run/SHEEP3/lightning_logs/version_3/checkpoints/epoch=4-step=1980.ckpt')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.2"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/DCGAN-PyTorch/modules/Discriminators.py b/DCGAN-PyTorch/modules/Discriminators.py
new file mode 100644
index 0000000..5c6334b
--- /dev/null
+++ b/DCGAN-PyTorch/modules/Discriminators.py
@@ -0,0 +1,41 @@
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                     GAN / Discriminators
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+import numpy as np
+import torch.nn as nn
+
+class Discriminator_1(nn.Module):
+
+    def __init__(self, latent_dim=None, data_shape=None):
+    
+        super().__init__()
+        self.img_shape = data_shape
+        print('init discriminator       : ',data_shape,' to sigmoid')
+
+        self.model = nn.Sequential(
+
+            nn.Flatten(),
+            nn.Linear(int(np.prod(data_shape)), 512),
+            nn.ReLU(),
+            
+            nn.Linear(512, 256),
+            nn.ReLU(),
+
+            nn.Linear(256, 1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, img):
+        # img_flat = img.view(img.size(0), -1)
+        validity = self.model(img)
+
+        return validity
\ No newline at end of file
diff --git a/DCGAN-PyTorch/modules/GAN.py b/DCGAN-PyTorch/modules/GAN.py
new file mode 100644
index 0000000..8a1f3ad
--- /dev/null
+++ b/DCGAN-PyTorch/modules/GAN.py
@@ -0,0 +1,179 @@
+
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                GAN / GAN LightningModule
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+
+import sys
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+from lightning import LightningModule
+
+
+class GAN(LightningModule):
+
+    # -------------------------------------------------------------------------
+    # Init
+    # -------------------------------------------------------------------------
+    #
+    def __init__(
+        self,
+        data_shape          = (None,None,None),
+        latent_dim          = None,
+        lr                  = 0.0002,
+        b1                  = 0.5,
+        b2                  = 0.999,
+        batch_size          = 64,
+        generator_class     = None,
+        discriminator_class = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        print('\n---- GAN initialization --------------------------------------------')
+
+        # ---- Hyperparameters
+        #
+        # Enable Lightning to store all the provided arguments under the self.hparams attribute.
+        # These hyperparameters will also be stored within the model checkpoint.
+        #
+        self.save_hyperparameters()
+
+        print('Hyperparameters are :')
+        for name,value in self.hparams.items():
+            print(f'{name:24s} : {value}')
+
+        # ---- Generator/Discriminator instantiation
+        #
+        # self.generator     = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
+        # self.discriminator = Discriminator(img_shape=data_shape)
+
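+        # ---- Resolve the submodel classes from their names
+        #      (a short note, assuming the notebook context : getattr() looks the
+        #      classes up in __main__, where Generators/Discriminators are
+        #      star-imported ; storing only their names in the hyperparameters
+        #      keeps the checkpoints reloadable via load_from_checkpoint().)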
+        print('Submodels :')
+        module=sys.modules['__main__']
+        class_g = getattr(module, generator_class)
+        class_d = getattr(module, discriminator_class)
+        self.generator     = class_g( latent_dim=latent_dim, data_shape=data_shape)
+        self.discriminator = class_d( latent_dim=latent_dim, data_shape=data_shape)
+
+        # ---- Validation and example data
+        #
+        self.validation_z        = torch.randn(8, self.hparams.latent_dim)
+        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
+
+
+    def forward(self, z):
+        return self.generator(z)
+
+
+    def adversarial_loss(self, y_hat, y):
+        return F.binary_cross_entropy(y_hat, y)
+
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        imgs       = batch
+        batch_size = batch.size(0)
+
+        # ---- Get some latent space vectors
+        #      We use type_as() to make sure we initialize z on the right device (GPU/CPU).
+        #
+        z = torch.randn(batch_size, self.hparams.latent_dim)
+        z = z.type_as(imgs)
+
+        # ---- Train generator
+        #      The generator uses optimizer #0
+        #      We try to generate fake images that can mislead the discriminator
+        #
+        if optimizer_idx == 0:
+
+            # Generate fake images
+            self.fake_imgs = self.generator.forward(z)
+
+            # Assemble labels that say all images are real... yes, it's a lie ;-)
+            # Put them on the right device, since this tensor is created inside the training loop
+            misleading_labels = torch.ones(batch_size, 1)
+            misleading_labels = misleading_labels.type_as(imgs)
+
+            # Adversarial loss is binary cross-entropy
+            g_loss = self.adversarial_loss(self.discriminator.forward(self.fake_imgs), misleading_labels)
+            self.log("g_loss", g_loss, prog_bar=True)
+            return g_loss
+
+        # ---- Train discriminator
+        #      The discriminator uses optimizer #1
+        #      We try to distinguish fake images from real ones
+        #
+        if optimizer_idx == 1:
+            
+            # These images are real
+            real_labels = torch.ones(batch_size, 1)
+            real_labels = real_labels.type_as(imgs)
+            pred_labels = self.discriminator.forward(imgs)
+
+            real_loss   = self.adversarial_loss(pred_labels, real_labels)
+
+            # These images are fake
+            fake_imgs   = self.generator.forward(z)
+            fake_labels = torch.zeros(batch_size, 1)
+            fake_labels = fake_labels.type_as(imgs)
+
+            fake_loss   = self.adversarial_loss(self.discriminator(fake_imgs.detach()), fake_labels)
+
+            # Discriminator loss is the average
+            d_loss = (real_loss + fake_loss) / 2
+            self.log("d_loss", d_loss, prog_bar=True)
+            return d_loss
+
+
+    def configure_optimizers(self):
+
+        lr = self.hparams.lr
+        b1 = self.hparams.b1
+        b2 = self.hparams.b2
+
+        # With a GAN, we need two separate optimizers :
+        # opt_g to optimize the generator      #0
+        # opt_d to optimize the discriminator  #1
+        # The second returned list (empty here) is for LR schedulers.
+        opt_g = torch.optim.Adam(self.generator.parameters(),     lr=lr, betas=(b1, b2))
+        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
+        return [opt_g, opt_d], []
+
+
+    def training_epoch_end(self, outputs):
+
+        # Get our validation latent vectors as z
+        # z = self.validation_z.type_as(self.generator.model[0].weight)
+
+        # ---- Log Graph
+        #
+        if self.current_epoch == 1:
+            sample_img = torch.rand((1, 28, 28, 1))
+            self.logger.experiment.add_graph(self.discriminator, sample_img)
+
+        # ---- Log g_loss & d_loss per epoch
+        #
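+        #      outputs, as used here, holds one entry per batch, with the generator
+        #      step result in metrics[0] and the discriminator step result in
+        #      metrics[1] (the code assumes each is a {'loss': ...} dict, as
+        #      provided by Lightning 1.x).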
+        g_loss, d_loss = 0,0
+        for metrics in outputs:
+            g_loss+=float( metrics[0]['loss'] )
+            d_loss+=float( metrics[1]['loss'] )
+        g_loss, d_loss = g_loss/len(outputs), d_loss/len(outputs)
+        self.logger.experiment.add_scalar("g_loss/epochs",g_loss, self.current_epoch)
+        self.logger.experiment.add_scalar("d_loss/epochs",d_loss, self.current_epoch)
+
+        # ---- Log some of these images
+        #
+        z = torch.randn(self.hparams.batch_size, self.hparams.latent_dim)
+        z = z.type_as(self.generator.model[0].weight)
+        sample_imgs = self.generator(z)
+        sample_imgs = sample_imgs.permute(0, 3, 1, 2) # from NHWC to NCHW
+        grid = torchvision.utils.make_grid(tensor=sample_imgs, nrow=12, )
+        self.logger.experiment.add_image(f"Generated images", grid,self.current_epoch)
diff --git a/DCGAN-PyTorch/modules/Generators.py b/DCGAN-PyTorch/modules/Generators.py
new file mode 100644
index 0000000..0c49adb
--- /dev/null
+++ b/DCGAN-PyTorch/modules/Generators.py
@@ -0,0 +1,53 @@
+
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                         GAN / Generators
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+
+import numpy as np
+import torch.nn as nn
+
+
+class Generator_1(nn.Module):
+
+    def __init__(self, latent_dim=None, data_shape=None):
+        super().__init__()
+        self.latent_dim = latent_dim
+        self.img_shape  = data_shape
+        print('init generator           : ',latent_dim,' to ',data_shape)
+
+        self.model = nn.Sequential(
+            
+            nn.Linear(latent_dim, 128),
+            nn.ReLU(),
+
+            nn.Linear(128,256),
+            nn.BatchNorm1d(256, 0.8),
+            nn.ReLU(),
+
+            nn.Linear(256, 512),
+            nn.BatchNorm1d(512, 0.8),
+            nn.ReLU(),
+
+            nn.Linear(512, 1024),
+            nn.BatchNorm1d(1024, 0.8),
+            nn.ReLU(),
+
+            nn.Linear(1024, int(np.prod(data_shape))),
+            nn.Sigmoid()
+
+        )
+
+
+    def forward(self, z):
+        img = self.model(z)
+        img = img.view(img.size(0), *self.img_shape)
+        return img
\ No newline at end of file
diff --git a/DCGAN-PyTorch/modules/QuickDrawDataModule.py b/DCGAN-PyTorch/modules/QuickDrawDataModule.py
new file mode 100644
index 0000000..34a4ecf
--- /dev/null
+++ b/DCGAN-PyTorch/modules/QuickDrawDataModule.py
@@ -0,0 +1,71 @@
+
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                GAN / QuickDrawDataModule
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
+
+class QuickDrawDataModule(LightningDataModule):
+
+
+    def __init__( self, dataset_file='./sheep.npy', scale=1., batch_size=64, num_workers=4 ):
+
+        super().__init__()
+
+        print('\n---- QuickDrawDataModule initialization ----------------------------')
+        print(f'with : scale={scale}  batch size={batch_size}')
+        
+        self.scale        = scale
+        self.dataset_file = dataset_file
+        self.batch_size   = batch_size
+        self.num_workers  = num_workers
+
+        self.dims         = (28, 28, 1)
+        self.num_classes  = 10
+
+
+
+    def prepare_data(self):
+        pass
+
+
+    def setup(self, stage=None):
+        print('\nDataModule Setup :')
+        # Load dataset
+        # Called at the beginning of each stage (train,val,test)
+        # Here, whatever the stage value, we'll have only one set.
+        data = np.load(self.dataset_file)
+        print('Original dataset shape : ',data.shape)
+
+        # Rescale
+        n=int(self.scale*len(data))
+        data = data[:n]
+        print('Rescaled dataset shape : ',data.shape)
+
+        # Normalize, reshape and shuffle
+        data = data/255
+        data = data.reshape(-1,28,28,1)
+        data = torch.from_numpy(data).float()
+        print('Final dataset shape    : ',data.shape)
+
+        print('Dataset loaded and ready.')
+        self.data_train = data
+
+
+    def train_dataloader(self):
+        # Note : our data tensor (like a Numpy ndarray) is Dataset compliant :
+        #        it has a map-style interface. See https://pytorch.org/docs/stable/data.html
+        return DataLoader( self.data_train, batch_size=self.batch_size, num_workers=self.num_workers )
\ No newline at end of file
diff --git a/DCGAN-PyTorch/modules/SmartProgressBar.py b/DCGAN-PyTorch/modules/SmartProgressBar.py
new file mode 100644
index 0000000..3ebe192
--- /dev/null
+++ b/DCGAN-PyTorch/modules/SmartProgressBar.py
@@ -0,0 +1,70 @@
+
+# ------------------------------------------------------------------
+#     _____ _     _ _
+#    |  ___(_) __| | | ___
+#    | |_  | |/ _` | |/ _ \
+#    |  _| | | (_| | |  __/
+#    |_|   |_|\__,_|_|\___|                   GAN / SmartProgressBar
+# ------------------------------------------------------------------
+# Formation Introduction au Deep Learning  (FIDLE)
+# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr
+# ------------------------------------------------------------------
+# by JL Parouty (feb 2023) - PyTorch Lightning example
+
+from lightning.pytorch.callbacks.progress.base import ProgressBarBase
+from tqdm import tqdm
+import sys
+
+class SmartProgressBar(ProgressBarBase):
+
+    def __init__(self, verbosity=2):
+        super().__init__()
+        self.verbosity = verbosity
+
+    def disable(self):
+        self.enable = False
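+        # Note : assigning the attribute shadows the inherited enable()/disable()
+        #        hooks ; until disable() is called, self.enable is a (truthy) bound
+        #        method, so the "if not self.enable" guards below keep the bar active.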
+
+
+    def setup(self, trainer, pl_module, stage):
+        super().setup(trainer, pl_module, stage)
+        self.stage = stage
+
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
+        if not self.enable : return
+
+        if self.verbosity==2:
+            self.progress=tqdm( total=trainer.num_training_batches,
+                                desc=f'{self.stage} {trainer.current_epoch+1}/{trainer.max_epochs}', 
+                                ncols=100, ascii= " >", 
+                                bar_format='{l_bar}{bar}| [{elapsed}] {postfix}')
+
+
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        super().on_train_epoch_end(trainer, pl_module)
+
+        if not self.enable : return
+
+        if self.verbosity==2:
+            self.progress.close()
+
+        if self.verbosity==1:
+            print(f'Train {trainer.current_epoch+1}/{trainer.max_epochs} Done.')
+
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
+
+        if not self.enable : return
+        
+        if self.verbosity==2:
+            metrics = {}
+            for name,value in trainer.logged_metrics.items():
+                metrics[name] = f'{float(value):3.3f}'
+            self.progress.set_postfix(metrics)
+            self.progress.update(1)
+
+
+progress_bar = SmartProgressBar(verbosity=2)
-- 
GitLab