From 4011935590b5a4196cc0df27a470af4037c651ec Mon Sep 17 00:00:00 2001 From: Achille Mbogol Touye <achille.mbogol-touye@univ-grenoble-alpes.fr> Date: Tue, 7 Nov 2023 01:51:08 +0100 Subject: [PATCH] Replace 01-DNN-Wine-Regression-lightning.ipynb --- .../01-DNN-Wine-Regression-lightning.ipynb | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb b/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb index d27ead2..ce0b260 100644 --- a/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb +++ b/Wine.Lightning/01-DNN-Wine-Regression-lightning.ipynb @@ -73,9 +73,11 @@ "import torch.nn.functional as F\n", "import torchvision.transforms as T\n", "\n", - "from IPython.display import Markdown\n", + "\n", "from importlib import reload\n", + "from IPython.display import Markdown\n", "from torch.utils.data import Dataset, DataLoader, random_split\n", + "from modules.progressbar import CustomTrainProgressBar\n", "from modules.data_load import WineQualityDataset, Normalize, ToTensor\n", "from lightning.pytorch.loggers.tensorboard import TensorBoardLogger\n", "from torchmetrics.functional.regression import mean_absolute_error, mean_squared_error\n", @@ -287,8 +289,9 @@ " def forward(self, x): # forward pass\n", " x = self.model(x)\n", " return x \n", - " \n", - " # optimizer\n", + "\n", + " \n", + " # optimizer\n", " def configure_optimizers(self): \n", " optimizer = torch.optim.RMSprop(self.parameters(),lr=1e-4)\n", " return optimizer \n", @@ -431,7 +434,8 @@ "trainer = pl.Trainer(accelerator='auto',\n", " max_epochs=100,\n", " logger=logger,\n", - " callbacks=[savemodel_callback])\n", + " num_sanity_val_steps=0,\n", + " callbacks=[savemodel_callback,CustomTrainProgressBar()])\n", "\n", "trainer.fit(model=reg, train_dataloaders=train_loader, val_dataloaders=test_loader)" ] @@ -474,7 +478,7 @@ "source": [ "# launch Tensorboard \n", "%reload_ext tensorboard\n", - "%tensorboard --logdir=Wine_logs/reg_logs/" + "%tensorboard --logdir=Wine_logs/reg_logs/ --bind_all" ] }, { @@ -589,6 +593,13 @@ "---\n", "<img width=\"80px\" src=\"../fidle/img/logo-paysage.svg\"></img>" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { -- GitLab