From 4011935590b5a4196cc0df27a470af4037c651ec Mon Sep 17 00:00:00 2001
From: Achille Mbogol Touye <achille.mbogol-touye@univ-grenoble-alpes.fr>
Date: Tue, 7 Nov 2023 01:51:08 +0100
Subject: [PATCH] Replace 01-DNN-Wine-Regression-lightning.ipynb

---
 .../01-DNN-Wine-Regression-lightning.ipynb    | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb b/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb
index d27ead2..ce0b260 100644
--- a/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb
+++ b/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb
@@ -73,9 +73,11 @@
     "import torch.nn.functional as F\n",
     "import torchvision.transforms as T\n",
     "\n",
-    "from IPython.display import Markdown\n",
+    "\n",
     "from importlib import reload\n",
+    "from IPython.display import Markdown\n",
     "from torch.utils.data import Dataset, DataLoader, random_split\n",
+    "from modules.progressbar import CustomTrainProgressBar\n",
     "from modules.data_load import WineQualityDataset, Normalize, ToTensor\n",
     "from lightning.pytorch.loggers.tensorboard import TensorBoardLogger\n",
     "from torchmetrics.functional.regression import mean_absolute_error, mean_squared_error\n",
@@ -287,8 +289,9 @@
     "    def forward(self, x):                                              # forward pass\n",
     "        x = self.model(x)\n",
     "        return x        \n",
-    "        \n",
-    "     # optimizer\n",
+    "\n",
+    "   \n",
+    "    # optimizer\n",
     "    def configure_optimizers(self):                              \n",
     "        optimizer = torch.optim.RMSprop(self.parameters(),lr=1e-4)\n",
     "        return optimizer \n",
@@ -431,7 +434,8 @@
     "trainer = pl.Trainer(accelerator='auto',\n",
     "                     max_epochs=100,\n",
     "                     logger=logger,\n",
-    "                     callbacks=[savemodel_callback])\n",
+    "                     num_sanity_val_steps=0,\n",
+    "                     callbacks=[savemodel_callback,CustomTrainProgressBar()])\n",
     "\n",
     "trainer.fit(model=reg, train_dataloaders=train_loader, val_dataloaders=test_loader)"
    ]
@@ -474,7 +478,7 @@
    "source": [
     "# launch Tensorboard \n",
     "%reload_ext tensorboard\n",
-    "%tensorboard --logdir=Wine_logs/reg_logs/"
+    "%tensorboard --logdir=Wine_logs/reg_logs/ --bind_all"
    ]
   },
   {
@@ -589,6 +593,13 @@
     "---\n",
     "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
-- 
GitLab