diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb index ea028f9bcf76c2b665f5492bb51e1d7ff2e5f120..5e6a5f4dee89ca1fd0fcbd72141c01bbc33c4e7d 100644 --- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb +++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb @@ -82,8 +82,8 @@ "generator_class = 'Generator_2'\n", "discriminator_class = 'Discriminator_1' \n", " \n", - "scale = .1\n", - "epochs = 10\n", + "scale = .01\n", + "epochs = 5\n", "batch_size = 32\n", "num_img = 36\n", "fit_verbosity = 2\n", @@ -211,7 +211,8 @@ " batch_size = batch_size, \n", " latent_dim = latent_dim, \n", " generator_class = generator_class, \n", - " discriminator_class = discriminator_class)" + " discriminator_class = discriminator_class,\n", + " lr=0.0001)" ] }, { @@ -295,9 +296,37 @@ "metadata": {}, "outputs": [], "source": [ - "# gan = GAN.load_from_checkpoint('./run/SHEEP3/lightning_logs/version_3/checkpoints/epoch=4-step=1980.ckpt')" + "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/last.ckpt')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nb_images = 32\n", + "\n", + "# z = np.random.normal(size=(nb_images,latent_dim))\n", + "\n", + "z = torch.randn(nb_images, latent_dim)\n", + "print('z size : ',z.size())\n", + "\n", + "fake_img = gan.generator.forward(z)\n", + "print('fake_img : ', fake_img.size())\n", + "\n", + "nimg = fake_img.detach().numpy()\n", + "fidle.scrawler.images( nimg.reshape(-1,28,28), indices=range(nb_images), columns=12, x_size=1, y_size=1, \n", + " y_padding=0,spines_alpha=0, save_as='01-Sheeps')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/DCGAN-PyTorch/modules/GAN.py b/DCGAN-PyTorch/modules/GAN.py index a97065c36b1aaa0f37e769474f5147220f35ae01..43917470559c9b672a4af5c0b93591cc307d4312 100644 --- a/DCGAN-PyTorch/modules/GAN.py +++ b/DCGAN-PyTorch/modules/GAN.py @@ -116,6 +116,8 @@ class GAN(LightningModule): # These images are reals real_labels = torch.ones(batch_size, 1) + # Add random noise to the labels + # real_labels += 0.05 * torch.rand(batch_size,1) real_labels = real_labels.type_as(imgs) pred_labels = self.discriminator.forward(imgs) @@ -124,6 +126,8 @@ class GAN(LightningModule): # These images are fake fake_imgs = self.generator.forward(z) fake_labels = torch.zeros(batch_size, 1) + # Add random noise to the labels + # fake_labels += 0.05 * torch.rand(batch_size,1) fake_labels = fake_labels.type_as(imgs) fake_loss = self.adversarial_loss(self.discriminator(fake_imgs.detach()), fake_labels) @@ -143,8 +147,10 @@ class GAN(LightningModule): # With a GAN, we need 2 separate optimizer. # opt_g to optimize the generator #0 # opt_d to optimize the discriminator #1 - opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) - opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + # opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + # opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2),) + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr) return [opt_g, opt_d], [] diff --git a/DCGAN-PyTorch/modules/Generators.py b/DCGAN-PyTorch/modules/Generators.py index cb187826db83719500650a5042cea4cf170a9dad..755b592df2b347f80c48452290afb724d7a3b4e4 100644 --- a/DCGAN-PyTorch/modules/Generators.py +++ b/DCGAN-PyTorch/modules/Generators.py @@ -67,12 +67,16 @@ class Generator_2(nn.Module): nn.Linear(latent_dim, 7*7*64), nn.Unflatten(1, (64,7,7)), - nn.UpsamplingBilinear2d( scale_factor=2 ), + nn.UpsamplingNearest2d( scale_factor=2 ), + # nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 64,128, (3,3), stride=(1,1), padding=(1,1) ), + nn.BatchNorm2d(128), nn.ReLU(), - nn.UpsamplingBilinear2d( scale_factor=2 ), + nn.UpsamplingNearest2d( scale_factor=2 ), + # nn.UpsamplingBilinear2d( scale_factor=2 ), nn.Conv2d( 128,256, (3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(256), nn.ReLU(), nn.Conv2d( 256,1, (5,5), stride=(1,1), padding=(2,2)), @@ -82,6 +86,8 @@ class Generator_2(nn.Module): def forward(self, z): img = self.model(z) - img = img.view(img.size(0), *self.img_shape) + img = img.view(img.size(0), *self.img_shape) # batch_size x 1 x W x H => batch_size x W x H x 1 return img + +