From 95dd87dc277bb2b73d2872b3f354a23f7b41cfdc Mon Sep 17 00:00:00 2001 From: Jean-Luc <Jean-Luc.Parouty@simap.grenoble-inp.fr> Date: Sat, 4 Mar 2023 12:36:43 +0100 Subject: [PATCH] Update 01-DCGAN-PL.ipynb --- DCGAN-PyTorch/01-DCGAN-PL.ipynb | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/DCGAN-PyTorch/01-DCGAN-PL.ipynb b/DCGAN-PyTorch/01-DCGAN-PL.ipynb index 6647aea..294fa50 100644 --- a/DCGAN-PyTorch/01-DCGAN-PL.ipynb +++ b/DCGAN-PyTorch/01-DCGAN-PL.ipynb @@ -98,10 +98,13 @@ "generator_class = 'Generator_2'\n", "discriminator_class = 'Discriminator_2' \n", " \n", - "scale = .005\n", - "epochs = 5\n", + "scale = 1\n", + "epochs = 3\n", + "lr = 0.0001\n", + "b1 = 0.5\n", + "b2 = 0.999\n", "batch_size = 32\n", - "num_img = 36\n", + "num_img = 48\n", "fit_verbosity = 2\n", " \n", "dataset_file = datasets_dir+'/QuickDraw/origine/sheep.npy' \n", @@ -122,9 +125,9 @@ "outputs": [], "source": [ "# You can comment these lines to keep each run...\n", - "shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)\n", - "shutil.rmtree(f'{run_dir}/models', ignore_errors=True)\n", - "shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)" + "# shutil.rmtree(f'{run_dir}/figs', ignore_errors=True)\n", + "# shutil.rmtree(f'{run_dir}/models', ignore_errors=True)\n", + "# shutil.rmtree(f'{run_dir}/tb_logs', ignore_errors=True)" ] }, { @@ -252,12 +255,14 @@ "metadata": {}, "outputs": [], "source": [ - "gan = GAN( data_shape = data_shape, \n", + "gan = GAN( data_shape = data_shape,\n", + " lr = lr,\n", + " b1 = b1,\n", + " b2 = b2,\n", " batch_size = batch_size, \n", " latent_dim = latent_dim, \n", " generator_class = generator_class, \n", - " discriminator_class = discriminator_class,\n", - " lr=0.0001)" + " discriminator_class = discriminator_class)" ] }, { @@ -328,7 +333,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -342,7 +346,7 @@ "metadata": {}, "outputs": [], "source": [ - "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/bestModel.ckpt')" + "gan = GAN.load_from_checkpoint('./run/SHEEP3/models/bestModel-v1.ckpt')" ] }, { @@ -351,7 +355,7 @@ "metadata": {}, "outputs": [], "source": [ - "nb_images = 32\n", + "nb_images = 96\n", "\n", "z = torch.randn(nb_images, latent_dim)\n", "print('z size : ',z.size())\n", @@ -384,9 +388,9 @@ ], "metadata": { "kernelspec": { - "display_name": "fidle-env", + "display_name": "Fidle-Lightning", "language": "python", - "name": "python3" + "name": "fidle-lightning" }, "language_info": { "codemirror_mode": { @@ -398,7 +402,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.10.8" }, "vscode": { "interpreter": { -- GitLab