diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb index 5e6a5f4dee89ca1fd0fcbd72141c01bbc33c4e7d..d3687e449d70fa5ff984345f1da4f321817e8f06 100644 --- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb +++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb @@ -80,7 +80,7 @@ "latent_dim = 128\n", " \n", "generator_class = 'Generator_2'\n", - "discriminator_class = 'Discriminator_1' \n", + "discriminator_class = 'Discriminator_2' \n", " \n", "scale = .01\n", "epochs = 5\n", @@ -170,7 +170,7 @@ "source": [ "print('\\nInstantiation :\\n')\n", "generator = Generator_2(latent_dim=latent_dim, data_shape=data_shape)\n", - "discriminator = Discriminator_1(latent_dim=latent_dim, data_shape=data_shape)\n", + "discriminator = Discriminator_2(latent_dim=latent_dim, data_shape=data_shape)\n", "\n", "print('\\nFew tests :\\n')\n", "z = torch.randn(batch_size, latent_dim)\n", @@ -296,7 +296,7 @@ "metadata": {}, "outputs": [], "source": [ - "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/last.ckpt')" + "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/last-v1.ckpt')" ] }, { diff --git a/DCGAN-PyTorch/modules/Discriminators.py b/DCGAN-PyTorch/modules/Discriminators.py index 5c6334b254b66ae677b5d57a46541922a33d9569..86c2f706115d86ef5ce835d4d6b61ab750107c8e 100644 --- a/DCGAN-PyTorch/modules/Discriminators.py +++ b/DCGAN-PyTorch/modules/Discriminators.py @@ -19,7 +19,7 @@ class Discriminator_1(nn.Module): super().__init__() self.img_shape = data_shape - print('init discriminator : ',data_shape,' to sigmoid') + print('init discriminator 1 : ',data_shape,' to sigmoid') self.model = nn.Sequential( @@ -38,4 +38,48 @@ class Discriminator_1(nn.Module): # img_flat = img.view(img.size(0), -1) validity = self.model(img) + return validity + + + + +class Discriminator_2(nn.Module): + + def __init__(self, latent_dim=None, data_shape=None): + + super().__init__() + self.img_shape = data_shape + print('init discriminator 2 : ',data_shape,' to sigmoid') + + self.model = nn.Sequential( + + nn.Conv2d(1, 32, kernel_size = 3, stride = 2, padding = 1), + nn.ReLU(), + nn.BatchNorm2d(32), + nn.Dropout2d(0.25), + + nn.Conv2d(32, 64, kernel_size = 3, stride = 1, padding = 1), + nn.ReLU(), + nn.BatchNorm2d(64), + nn.Dropout2d(0.25), + + nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1), + nn.ReLU(), + nn.BatchNorm2d(128), + nn.Dropout2d(0.25), + + nn.Conv2d(128, 256, kernel_size = 3, stride = 2, padding = 1), + nn.ReLU(), + nn.BatchNorm2d(256), + nn.Dropout2d(0.25), + + nn.Flatten(), + nn.Linear(12544, 1), + nn.Sigmoid(), + ) + + def forward(self, img): + img_nchw = img.permute(0, 3, 1, 2) # from NHWC to NCHW + validity = self.model(img_nchw) + return validity \ No newline at end of file diff --git a/DCGAN-PyTorch/modules/Generators.py b/DCGAN-PyTorch/modules/Generators.py index 755b592df2b347f80c48452290afb724d7a3b4e4..70e523d747187daf5cd1acf38d9ea32cd2f44cad 100644 --- a/DCGAN-PyTorch/modules/Generators.py +++ b/DCGAN-PyTorch/modules/Generators.py @@ -70,14 +70,14 @@ class Generator_2(nn.Module): nn.UpsamplingNearest2d( scale_factor=2 ), # nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 64,128, (3,3), stride=(1,1), padding=(1,1) ), - nn.BatchNorm2d(128), nn.ReLU(), + nn.BatchNorm2d(128), nn.UpsamplingNearest2d( scale_factor=2 ), # nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 128,256, (3,3), stride=(1,1), padding=(1,1)), - nn.BatchNorm2d(256), nn.ReLU(), + nn.BatchNorm2d(256), nn.Conv2d( 256,1, (5,5), stride=(1,1), padding=(2,2)), nn.Sigmoid()