diff --git a/MNIST.Lightning/modules/progressbar.py b/MNIST.Lightning/modules/progressbar.py index 075e13b93f5f1632b8c1892ece8819a4900d12a4..fb25a079b603ff54dcbcff4a252c0fda9f288416 100644 --- a/MNIST.Lightning/modules/progressbar.py +++ b/MNIST.Lightning/modules/progressbar.py @@ -17,21 +17,36 @@ from lightning.pytorch.callbacks import TQDMProgressBar class CustomTrainProgressBar(TQDMProgressBar): def __init__(self): super().__init__() - self._val_progress_bar = _tqdm() + self._val_progress_bar = _tqdm() + self._predict_progress_bar = _tqdm() + def init_predict_tqdm(self): + bar=super().init_test_tqdm() + bar.set_description("Predicting") + return bar + def init_train_tqdm(self): bar=super().init_train_tqdm() bar.set_description("Training") - return bar + return bar @property def val_progress_bar(self): if self._val_progress_bar is None: raise ValueError("The `_val_progress_bar` reference has not been set yet.") return self._val_progress_bar + + @property + def predict_progress_bar(self) -> _tqdm: + if self._predict_progress_bar is None: + raise TypeError(f"The `{self.__class__.__name__}._predict_progress_bar` reference has not been set yet.") + return self._predict_progress_bar def on_validation_start(self, trainer, pl_module): # Désactivez l'affichage de la barre de progression de validation self.val_progress_bar.disable = True - \ No newline at end of file + + def on_predict_start(self, trainer, pl_module): + # Désactivez l'affichage de la barre de progression de validation + self.predict_progress_bar.disable = True \ No newline at end of file