From 717b7be895e412c1018ea2190daf0a99f45bd6de Mon Sep 17 00:00:00 2001 From: Jean-Luc Parouty <Jean-Luc.Parouty@simap.grenoble-inp.fr> Date: Fri, 24 Feb 2023 20:47:28 +0100 Subject: [PATCH] Add GAN example with PyTorch Lightning --- DCGAN-PyTorch/01-DCGAN-PL.ipynb | 344 +++++++++++++++++++ DCGAN-PyTorch/modules/Discriminators.py | 41 +++ DCGAN-PyTorch/modules/GAN.py | 179 ++++++++++ DCGAN-PyTorch/modules/Generators.py | 53 +++ DCGAN-PyTorch/modules/QuickDrawDataModule.py | 71 ++++ DCGAN-PyTorch/modules/SmartProgressBar.py | 70 ++++ 6 files changed, 758 insertions(+) create mode 100644 DCGAN-PyTorch/01-DCGAN-PL.ipynb create mode 100644 DCGAN-PyTorch/modules/Discriminators.py create mode 100644 DCGAN-PyTorch/modules/GAN.py create mode 100644 DCGAN-PyTorch/modules/Generators.py create mode 100644 DCGAN-PyTorch/modules/QuickDrawDataModule.py create mode 100644 DCGAN-PyTorch/modules/SmartProgressBar.py diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb new file mode 100644 index 0000000..fdb048a --- /dev/null +++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GAN using PyTorch Lightning \n", + "\n", + "See : \n", + "- https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html\n", + "- https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/\n", + "\n", + "\n", + "Note : Need \n", + "```pip install ipywidgets lightning tqdm```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1 - Init and parameters\n", + "#### Python init" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from lightning import LightningDataModule, LightningModule, Trainer\n", + "from lightning.pytorch.callbacks.progress.tqdm_progress import TQDMProgressBar\n", + "from lightning.pytorch.callbacks.progress.base import ProgressBarBase\n", + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "from lightning.pytorch.loggers.tensorboard import TensorBoardLogger\n", + "\n", + "from tqdm import tqdm\n", + "from torch.utils.data import DataLoader\n", + "\n", + "import fidle\n", + "\n", + "from modules.SmartProgressBar import SmartProgressBar\n", + "from modules.QuickDrawDataModule import QuickDrawDataModule\n", + "\n", + "from modules.GAN import GAN\n", + "from modules.Generators import *\n", + "from modules.Discriminators import *\n", + "\n", + "# Init Fidle environment\n", + "run_id, run_dir, datasets_dir = fidle.init('SHEEP3')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Few parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "latent_dim = 128\n", + " \n", + "generator_class = 'Generator_1'\n", + "discriminator_class = 'Discriminator_1' \n", + " \n", + "scale = .05\n", + "epochs = 15\n", + "batch_size = 32\n", + "num_img = 36\n", + "fit_verbosity = 2\n", + " \n", + "dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy' \n", + "data_shape = (28,28,1)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2 - Get some nice data" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Get a Nice DataModule\n", + "Our DataModule is defined in [./modules/QuickDrawDataModule.py](./modules/QuickDrawDataModule.py) \n", + "This is a [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "dm = QuickDrawDataModule(dataset_file, scale, batch_size, num_workers=8)\n", + "dm.setup()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Have a look" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dl = dm.train_dataloader()\n", + "batch_data = next(iter(dl))\n", + "\n", + "fidle.scrawler.images( batch_data.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, \n", + " y_padding=0,spines_alpha=0, save_as='01-Sheeps')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3 - Get a nice GAN model\n", + "\n", + "Our Generators are defined in [./modules/Generators.py](./modules/Generators.py) \n", + "Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py) \n", + "\n", + "\n", + "Our GAN is defined in [./modules/GAN.py](./modules/GAN.py) " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Basic test - Just to be sure it (could) works... ;-)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print('\\nInstantiation :\\n')\n", + "generator = Generator_1(latent_dim=latent_dim, data_shape=data_shape)\n", + "discriminator = Discriminator_1(latent_dim=latent_dim, data_shape=data_shape)\n", + "\n", + "print('\\nFew tests :\\n')\n", + "z = torch.randn(batch_size, latent_dim)\n", + "print('z size : ',z.size())\n", + "\n", + "fake_img = generator.forward(z)\n", + "print('fake_img : ', fake_img.size())\n", + "\n", + "p = discriminator.forward(fake_img)\n", + "print('pred fake : ', p.size())\n", + "\n", + "print('batch_data : ',batch_data.size())\n", + "\n", + "p = discriminator.forward(batch_data)\n", + "print('pred real : ', p.size())\n", + "\n", + "nimg = fake_img.detach().numpy()\n", + "fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(batch_size), columns=12, x_size=1, y_size=1, \n", + " y_padding=0,spines_alpha=0, save_as='01-Sheeps')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### GAN model\n", + "To simplify our code, the GAN class is defined separately in the module [./modules/GAN.py](./modules/GAN.py) \n", + "Passing the classe names for generator/discriminator by parameter allows to stay modular and to use the PL checkpoints." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gan = GAN( data_shape = data_shape, \n", + " batch_size = batch_size, \n", + " latent_dim = latent_dim, \n", + " generator_class = generator_class, \n", + " discriminator_class = discriminator_class)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5 - Train it !\n", + "#### Instantiate Callbacks, Logger & co.\n", + "More about :\n", + "- [Checkpoints](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html)\n", + "- [modelCheckpoint](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "\n", + "# ---- for tensorboard logs\n", + "#\n", + "logger = TensorBoardLogger( save_dir = f'{run_dir}',\n", + " name = 'tb_logs' )\n", + "\n", + "# ---- To save checkpoints\n", + "#\n", + "callback_checkpoints = ModelCheckpoint( dirpath = f'{run_dir}/models', \n", + " filename = 'bestModel', \n", + " save_top_k = 1, \n", + " save_last = True,\n", + " every_n_epochs = 1, \n", + " monitor = \"g_loss\")\n", + "\n", + "# ---- To have a nive progress bar\n", + "#\n", + "callback_progressBar = SmartProgressBar(verbosity=2) # Usable evertywhere\n", + "# progress_bar = TQDMProgressBar(refresh_rate=1) # Usable in real jupyter lab (bug in vscode)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Train it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "\n", + "trainer = Trainer(\n", + " accelerator = \"auto\",\n", + "# devices = 1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", + " max_epochs = epochs,\n", + " callbacks = [callback_progressBar, callback_checkpoints],\n", + " log_every_n_steps = batch_size,\n", + " logger = logger\n", + ")\n", + "\n", + "trainer.fit(gan, dm)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6 - Reload a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# gan = GAN.load_from_checkpoint('./run/SHEEP3/lightning_logs/version_3/checkpoints/epoch=4-step=1980.ckpt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.2" + }, + "vscode": { + "interpreter": { + "hash": "b3929042cc22c1274d74e3e946c52b845b57cb6d84f2d591ffe0519b38e4896d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/DCGAN-PyTorch/modules/Discriminators.py b/DCGAN-PyTorch/modules/Discriminators.py new file mode 100644 index 0000000..5c6334b --- /dev/null +++ b/DCGAN-PyTorch/modules/Discriminators.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------ +# _____ _ _ _ +# | ___(_) __| | | ___ +# | |_ | |/ _` | |/ _ \ +# | _| | | (_| | | __/ +# |_| |_|\__,_|_|\___| GAN / Generators +# ------------------------------------------------------------------ +# Formation Introduction au Deep Learning (FIDLE) +# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr +# ------------------------------------------------------------------ +# by JL Parouty (feb 2023) - PyTorch Lightning example + +import numpy as np +import torch.nn as nn + +class Discriminator_1(nn.Module): + + def __init__(self, latent_dim=None, data_shape=None): + + super().__init__() + self.img_shape = data_shape + print('init discriminator : ',data_shape,' to sigmoid') + + self.model = nn.Sequential( + + nn.Flatten(), + nn.Linear(int(np.prod(data_shape)), 512), + nn.ReLU(), + + nn.Linear(512, 256), + nn.ReLU(), + + nn.Linear(256, 1), + nn.Sigmoid(), + ) + + def forward(self, img): + # img_flat = img.view(img.size(0), -1) + validity = self.model(img) + + return validity \ No newline at end of file diff --git a/DCGAN-PyTorch/modules/GAN.py b/DCGAN-PyTorch/modules/GAN.py new file mode 100644 index 0000000..8a1f3ad --- /dev/null +++ b/DCGAN-PyTorch/modules/GAN.py @@ -0,0 +1,179 @@ + +# ------------------------------------------------------------------ +# _____ _ _ _ +# | ___(_) __| | | ___ +# | |_ | |/ _` | |/ _ \ +# | _| | | (_| | | __/ +# |_| |_|\__,_|_|\___| GAN / GAN LigthningModule +# ------------------------------------------------------------------ +# Formation Introduction au Deep Learning (FIDLE) +# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr +# ------------------------------------------------------------------ +# by JL Parouty (feb 2023) - PyTorch Lightning example + + +import sys +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +from lightning import LightningModule + + +class GAN(LightningModule): + + # ------------------------------------------------------------------------- + # Init + # ------------------------------------------------------------------------- + # + def __init__( + self, + data_shape = (None,None,None), + latent_dim = None, + lr = 0.0002, + b1 = 0.5, + b2 = 0.999, + batch_size = 64, + generator_class = None, + discriminator_class = None, + **kwargs, + ): + super().__init__() + + print('\n---- GAN initialization --------------------------------------------') + + # ---- Hyperparameters + # + # Enable Lightning to store all the provided arguments under the self.hparams attribute. + # These hyperparameters will also be stored within the model checkpoint. + # + self.save_hyperparameters() + + print('Hyperarameters are :') + for name,value in self.hparams.items(): + print(f'{name:24s} : {value}') + + # ---- Generator/Discriminator instantiation + # + # self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape) + # self.discriminator = Discriminator(img_shape=data_shape) + + print('Submodels :') + module=sys.modules['__main__'] + class_g = getattr(module, generator_class) + class_d = getattr(module, discriminator_class) + self.generator = class_g( latent_dim=latent_dim, data_shape=data_shape) + self.discriminator = class_d( latent_dim=latent_dim, data_shape=data_shape) + + # ---- Validation and example data + # + self.validation_z = torch.randn(8, self.hparams.latent_dim) + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + + + def forward(self, z): + return self.generator(z) + + + def adversarial_loss(self, y_hat, y): + return F.binary_cross_entropy(y_hat, y) + + + def training_step(self, batch, batch_idx, optimizer_idx): + imgs = batch + batch_size = batch.size(0) + + # ---- Get some latent space vectors + # We use type_as() to make sure we initialize z on the right device (GPU/CPU). + # + z = torch.randn(batch_size, self.hparams.latent_dim) + z = z.type_as(imgs) + + # ---- Train generator + # Generator use optimizer #0 + # We try to generate false images that could mislead the discriminator + # + if optimizer_idx == 0: + + # Generate fake images + self.fake_imgs = self.generator.forward(z) + + # Assemble labels that say all images are real, yes it's a lie ;-) + # put on GPU because we created this tensor inside training_loop + misleading_labels = torch.ones(batch_size, 1) + misleading_labels = misleading_labels.type_as(imgs) + + # Adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator.forward(self.fake_imgs), misleading_labels) + self.log("g_loss", g_loss, prog_bar=True) + return g_loss + + # ---- Train discriminator + # Discriminator use optimizer #1 + # We try to make the difference between fake images and real ones + # + if optimizer_idx == 1: + + # These images are reals + real_labels = torch.ones(batch_size, 1) + real_labels = real_labels.type_as(imgs) + pred_labels = self.discriminator.forward(imgs) + + real_loss = self.adversarial_loss(pred_labels, real_labels) + + # These images are fake + fake_imgs = self.generator.forward(z) + fake_labels = torch.zeros(batch_size, 1) + fake_labels = fake_labels.type_as(imgs) + + fake_loss = self.adversarial_loss(self.discriminator(fake_imgs.detach()), fake_labels) + + # Discriminator loss is the average + d_loss = (real_loss + fake_loss) / 2 + self.log("d_loss", d_loss, prog_bar=True) + return d_loss + + + def configure_optimizers(self): + + lr = self.hparams.lr + b1 = self.hparams.b1 + b2 = self.hparams.b2 + + # With a GAN, we need 2 separate optimizer. + # opt_g to optimize the generator #0 + # opt_d to optimize the discriminator #1 + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return [opt_g, opt_d], [] + + + def training_epoch_end(self, outputs): + + # Get our validation latent vectors as z + # z = self.validation_z.type_as(self.generator.model[0].weight) + + # ---- Log Graph + # + if(self.current_epoch==1): + sampleImg=torch.rand((1,28,28,1)) + self.logger.experiment.add_graph(self.discriminator,sampleImg) + + # ---- Log d_loss/epoch + # + g_loss, d_loss = 0,0 + for metrics in outputs: + g_loss+=float( metrics[0]['loss'] ) + d_loss+=float( metrics[1]['loss'] ) + g_loss, d_loss = g_loss/len(outputs), d_loss/len(outputs) + self.logger.experiment.add_scalar("g_loss/epochs",g_loss, self.current_epoch) + self.logger.experiment.add_scalar("d_loss/epochs",d_loss, self.current_epoch) + + # ---- Log some of these images + # + z = torch.randn(self.hparams.batch_size, self.hparams.latent_dim) + z = z.type_as(self.generator.model[0].weight) + sample_imgs = self.generator(z) + sample_imgs = sample_imgs.permute(0, 3, 1, 2) # from NHWC to NCHW + grid = torchvision.utils.make_grid(tensor=sample_imgs, nrow=12, ) + self.logger.experiment.add_image(f"Generated images", grid,self.current_epoch) diff --git a/DCGAN-PyTorch/modules/Generators.py b/DCGAN-PyTorch/modules/Generators.py new file mode 100644 index 0000000..0c49adb --- /dev/null +++ b/DCGAN-PyTorch/modules/Generators.py @@ -0,0 +1,53 @@ + +# ------------------------------------------------------------------ +# _____ _ _ _ +# | ___(_) __| | | ___ +# | |_ | |/ _` | |/ _ \ +# | _| | | (_| | | __/ +# |_| |_|\__,_|_|\___| GAN / Generators +# ------------------------------------------------------------------ +# Formation Introduction au Deep Learning (FIDLE) +# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr +# ------------------------------------------------------------------ +# by JL Parouty (feb 2023) - PyTorch Lightning example + + +import numpy as np +import torch.nn as nn + + +class Generator_1(nn.Module): + + def __init__(self, latent_dim=None, data_shape=None): + super().__init__() + self.latent_dim = latent_dim + self.img_shape = data_shape + print('init generator : ',latent_dim,' to ',data_shape) + + self.model = nn.Sequential( + + nn.Linear(latent_dim, 128), + nn.ReLU(), + + nn.Linear(128,256), + nn.BatchNorm1d(256, 0.8), + nn.ReLU(), + + nn.Linear(256, 512), + nn.BatchNorm1d(512, 0.8), + nn.ReLU(), + + nn.Linear(512, 1024), + nn.BatchNorm1d(1024, 0.8), + nn.ReLU(), + + nn.Linear(1024, int(np.prod(data_shape))), + nn.Sigmoid() + + ) + + + def forward(self, z): + img = self.model(z) + img = img.view(img.size(0), *self.img_shape) + return img \ No newline at end of file diff --git a/DCGAN-PyTorch/modules/QuickDrawDataModule.py b/DCGAN-PyTorch/modules/QuickDrawDataModule.py new file mode 100644 index 0000000..34a4ecf --- /dev/null +++ b/DCGAN-PyTorch/modules/QuickDrawDataModule.py @@ -0,0 +1,71 @@ + +# ------------------------------------------------------------------ +# _____ _ _ _ +# | ___(_) __| | | ___ +# | |_ | |/ _` | |/ _ \ +# | _| | | (_| | | __/ +# |_| |_|\__,_|_|\___| GAN / QuickDrawDataModule +# ------------------------------------------------------------------ +# Formation Introduction au Deep Learning (FIDLE) +# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr +# ------------------------------------------------------------------ +# by JL Parouty (feb 2023) - PyTorch Lightning example + + +import numpy as np +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader + + +class QuickDrawDataModule(LightningDataModule): + + + def __init__( self, dataset_file='./sheep.npy', scale=1., batch_size=64, num_workers=4 ): + + super().__init__() + + print('\n---- QuickDrawDataModule initialization ----------------------------') + print(f'with : scale={scale} batch size={batch_size}') + + self.scale = scale + self.dataset_file = dataset_file + self.batch_size = batch_size + self.num_workers = num_workers + + self.dims = (28, 28, 1) + self.num_classes = 10 + + + + def prepare_data(self): + pass + + + def setup(self, stage=None): + print('\nDataModule Setup :') + # Load dataset + # Called at the beginning of each stage (train,val,test) + # Here, whatever the stage value, we'll have only one set. + data = np.load(self.dataset_file) + print('Original dataset shape : ',data.shape) + + # Rescale + n=int(self.scale*len(data)) + data = data[:n] + print('Rescaled dataset shape : ',data.shape) + + # Normalize, reshape and shuffle + data = data/255 + data = data.reshape(-1,28,28,1) + data = torch.from_numpy(data).float() + print('Final dataset shape : ',data.shape) + + print('Dataset loaded and ready.') + self.data_train = data + + + def train_dataloader(self): + # Note : Numpy ndarray is Dataset compliant + # Have map-style interface. See https://pytorch.org/docs/stable/data.html + return DataLoader( self.data_train, batch_size=self.batch_size, num_workers=self.num_workers ) \ No newline at end of file diff --git a/DCGAN-PyTorch/modules/SmartProgressBar.py b/DCGAN-PyTorch/modules/SmartProgressBar.py new file mode 100644 index 0000000..3ebe192 --- /dev/null +++ b/DCGAN-PyTorch/modules/SmartProgressBar.py @@ -0,0 +1,70 @@ + +# ------------------------------------------------------------------ +# _____ _ _ _ +# | ___(_) __| | | ___ +# | |_ | |/ _` | |/ _ \ +# | _| | | (_| | | __/ +# |_| |_|\__,_|_|\___| GAN / SmartProgressBar +# ------------------------------------------------------------------ +# Formation Introduction au Deep Learning (FIDLE) +# CNRS/SARI/DEVLOG MIAI/EFELIA 2023 - https://fidle.cnrs.fr +# ------------------------------------------------------------------ +# by JL Parouty (feb 2023) - PyTorch Lightning example + +from lightning.pytorch.callbacks.progress.base import ProgressBarBase +from tqdm import tqdm +import sys + +class SmartProgressBar(ProgressBarBase): + + def __init__(self, verbosity=2): + super().__init__() + self.verbosity = verbosity + + def disable(self): + self.enable = False + + + def setup(self, trainer, pl_module, stage): + super().setup(trainer, pl_module, stage) + self.stage = stage + + + def on_train_epoch_start(self, trainer, pl_module): + super().on_train_epoch_start(trainer, pl_module) + if not self.enable : return + + if self.verbosity==2: + self.progress=tqdm( total=trainer.num_training_batches, + desc=f'{self.stage} {trainer.current_epoch+1}/{trainer.max_epochs}', + ncols=100, ascii= " >", + bar_format='{l_bar}{bar}| [{elapsed}] {postfix}') + + + + def on_train_epoch_end(self, trainer, pl_module): + super().on_train_epoch_end(trainer, pl_module) + + if not self.enable : return + + if self.verbosity==2: + self.progress.close() + + if self.verbosity==1: + print(f'Train {trainer.current_epoch+1}/{trainer.max_epochs} Done.') + + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + + if not self.enable : return + + if self.verbosity==2: + metrics = {} + for name,value in trainer.logged_metrics.items(): + metrics[name]=f'{float( trainer.logged_metrics[name] ):3.3f}' + self.progress.set_postfix(metrics) + self.progress.update(1) + + +progress_bar = SmartProgressBar(verbosity=2) -- GitLab