diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
index dd821c77a52b4f6bf92471019c5bb9834d1a5d43..e4b57a42fcfdb636f4a77795e46bb3ba87825b70 100644
--- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb
+++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb
@@ -96,9 +96,9 @@
    "source": [
     "latent_dim = 128\n",
     "\n",
-    "gan_class = 'GAN'\n",
+    "gan_class = 'WGANGP'\n",
     "generator_class = 'Generator_2'\n",
-    "discriminator_class = 'Discriminator_2' \n",
+    "discriminator_class = 'Discriminator_3' \n",
     " \n",
     "scale = 0.001\n",
     "epochs = 3\n",
@@ -181,6 +181,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -190,7 +191,27 @@
     "Our Discriminators are defined in [./modules/Discriminators.py](./modules/Discriminators.py) \n",
     "\n",
     "\n",
-    "Our GAN is defined in [./modules/GAN.py](./modules/GAN.py) "
+    "Our GAN is defined in [./modules/GAN.py](./modules/GAN.py) \n",
+    "\n",
+    "#### Class loader"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_class(class_name):\n",
+    "    module=sys.modules['__main__']\n",
+    "    class_ = getattr(module, class_name)\n",
+    "    return class_\n",
+    "    \n",
+    "def get_instance(class_name, **args):\n",
+    "    module=sys.modules['__main__']\n",
+    "    class_ = getattr(module, class_name)\n",
+    "    instance_ = class_(**args)\n",
+    "    return instance_"
+   ]
+  },
   {
@@ -219,8 +240,12 @@
     "# ----Get it, and play with them\n",
     "#\n",
     "print('\nInstantiation :\n')\n",
-    "generator = get_classByName( generator_class, latent_dim=latent_dim, data_shape=data_shape)\n",
-    "discriminator = get_classByName( discriminator_class, latent_dim=latent_dim, data_shape=data_shape)\n",
+    "\n",
+    "Generator_ = get_class(generator_class)\n",
+    "Discriminator_ = get_class(discriminator_class)\n",
+    "\n",
+    "generator = Generator_( latent_dim=latent_dim, data_shape=data_shape)\n",
+    "discriminator = Discriminator_( latent_dim=latent_dim, data_shape=data_shape)\n",
     "\n",
     "print('\nFew tests :\n')\n",
     "z = torch.randn(batch_size, latent_dim)\n",
@@ -276,14 +301,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "gan = WGANGP( data_shape = data_shape,\n",
-    "              lr = lr,\n",
-    "              b1 = b1,\n",
-    "              b2 = b2,\n",
-    "              batch_size = batch_size, \n",
-    "              latent_dim = latent_dim, \n",
-    "              generator_class = generator_class, \n",
-    "              discriminator_class = discriminator_class)"
+    "GAN_ = get_class(gan_class)\n",
+    "\n",
+    "gan = GAN_( data_shape = data_shape,\n",
+    "            lr = lr,\n",
+    "            b1 = b1,\n",
+    "            b2 = b2,\n",
+    "            batch_size = batch_size, \n",
+    "            latent_dim = latent_dim, \n",
+    "            generator_class = generator_class, \n",
+    "            discriminator_class = discriminator_class)"
    ]
   },
   {
diff --git a/DCGAN-PyTorch/modules/Discriminators.py b/DCGAN-PyTorch/modules/Discriminators.py
index 2458a61dc7bf036f22227634001eafe702de3a9c..bdbaa79c08332bfcdd6c6a6e8ad3a4cee62f02e9 100644
--- a/DCGAN-PyTorch/modules/Discriminators.py
+++ b/DCGAN-PyTorch/modules/Discriminators.py
@@ -14,6 +14,9 @@ import numpy as np
 import torch.nn as nn
 
 class Discriminator_1(nn.Module):
+    '''
+    A basic DNN discriminator, usable with a classic GAN
+    '''
 
     def __init__(self, latent_dim=None, data_shape=None):
 
@@ -43,6 +46,9 @@ class Discriminator_1(nn.Module):
 
 
 class Discriminator_2(nn.Module):
+    '''
+    A more efficient discriminator, based on a CNN, usable with a classic GAN
+    '''
 
     def __init__(self, latent_dim=None, data_shape=None):
 
@@ -77,6 +83,52 @@ class Discriminator_2(nn.Module):
             nn.Sigmoid(),
         )
 
+    def forward(self, img):
+        img_nchw = img.permute(0, 3, 1, 2) # reformat from NHWC to NCHW
+        validity = self.model(img_nchw)
+
+        return validity
+
+
+
+class Discriminator_3(nn.Module):
+    '''
+    A CNN discriminator, usable with a WGANGP.
+    This discriminator has no sigmoid and returns a critic score, not a probability
+    '''
+
+    def __init__(self, latent_dim=None, data_shape=None):
+
+        super().__init__()
+        self.img_shape = data_shape
+        print('init discriminator 3 : ', data_shape, ' to critic score')
+
+        self.model = nn.Sequential(
+
+            nn.Conv2d(1, 32, kernel_size = 3, stride = 2, padding = 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(32),
+            nn.Dropout2d(0.25),
+
+            nn.Conv2d(32, 64, kernel_size = 3, stride = 1, padding = 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(64),
+            nn.Dropout2d(0.25),
+
+            nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(128),
+            nn.Dropout2d(0.25),
+
+            nn.Conv2d(128, 256, kernel_size = 3, stride = 2, padding = 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(256),
+            nn.Dropout2d(0.25),
+
+            nn.Flatten(),
+            nn.Linear(12544, 1),
+        )
+
     def forward(self, img):
         img_nchw = img.permute(0, 3, 1, 2) # reformat from NHWC to NCHW
         validity = self.model(img_nchw)