diff --git a/MNIST.Lightning/02-CNN-MNIST_Lightning.ipynb b/MNIST.Lightning/02-CNN-MNIST_Lightning.ipynb index 1845b617177379449a7b9cc19a1a658c539ad14b..a9290b454ac244616d03562ac429ad7a681ecd55 100644 --- a/MNIST.Lightning/02-CNN-MNIST_Lightning.ipynb +++ b/MNIST.Lightning/02-CNN-MNIST_Lightning.ipynb @@ -59,10 +59,12 @@ "import multiprocessing\n", "import matplotlib.pyplot as plt\n", "\n", - "from lightning.pytorch.loggers import TensorBoardLogger\n", - "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import datasets\n", "from torchmetrics.functional import accuracy\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from modules.progressbar import CustomTrainProgressBar\n", + "from lightning.pytorch.loggers import TensorBoardLogger\n", + "\n", "\n", "# Init Fidle environment\n", "import fidle\n", @@ -167,10 +169,6 @@ " # range 0.0 - 1.0\n", " T.ToTensor(),\n", "\n", - " # This then renormalizes the tensor to be between -1.0 and 1.0,\n", - " # which is a better range for modern activation functions like\n", - " # Relu\n", - " T.Normalize((0.5), (0.5)),\n", " ]\n", ")\n", "\n", @@ -222,15 +220,12 @@ "metadata": {}, "outputs": [], "source": [ - "# get the number of CPUs in your system \n", - "n_workers = multiprocessing.cpu_count()\n", - "\n", "# train bacth data\n", "train_loader= DataLoader(\n", " dataset=train_dataset, \n", " shuffle=True, \n", " batch_size=512,\n", - " num_workers=n_workers \n", + " num_workers=2\n", ")\n", "\n", "# test batch data\n", @@ -238,7 +233,7 @@ " dataset=test_dataset, \n", " shuffle=False, \n", " batch_size=512,\n", - " num_workers=n_workers \n", + " num_workers=2 \n", ")\n", "\n", "# print image and label After normalization and batch_size.\n", @@ -279,20 +274,20 @@ " nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=0),\n", " nn.ReLU(),\n", " nn.MaxPool2d((2,2)), \n", - " nn.Dropout2d(0.2), # Combat overfitting\n", + " nn.Dropout2d(0.1), # Combat overfitting\n", " \n", " # second convolution \n", " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=0),\n", " nn.ReLU(),\n", " nn.MaxPool2d((2,2)), \n", - " nn.Dropout2d(0.2), # Combat overfitting\n", + " nn.Dropout2d(0.1), # Combat overfitting\n", " \n", " nn.Flatten(), # convert feature map into feature vectors\n", " \n", " # MLP network \n", " nn.Linear(16*5*5,100),\n", " nn.ReLU(),\n", - " nn.Dropout1d(0.2), # Combat overfitting\n", + " nn.Dropout1d(0.1), # Combat overfitting\n", " \n", " nn.Linear(100, num_class), # logits outpout\n", " )\n", @@ -431,7 +426,12 @@ "outputs": [], "source": [ "# train model\n", - "trainer = pl.Trainer(accelerator='auto',max_epochs=16,logger=logger)\n", + "trainer = pl.Trainer(accelerator='auto',\n", + " max_epochs=16,\n", + " logger=logger,\n", + " num_sanity_val_steps=0,\n", + " callbacks=[CustomTrainProgressBar()]\n", + " )\n", "\n", "trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader)" ] @@ -602,6 +602,14 @@ "---\n", "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83a8ceb7-9711-407b-9546-60ad7bd2b5ba", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {