From fff5ebea281eeb3a0db81b6381cc54437c9018a8 Mon Sep 17 00:00:00 2001
From: Achille Mbogol Touye <achille.mbogol-touye@univ-grenoble-alpes.fr>
Date: Tue, 7 Nov 2023 10:12:54 +0100
Subject: [PATCH] Replace progressbar.py

---
 MNIST.Lightning/modules/progressbar.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/MNIST.Lightning/modules/progressbar.py b/MNIST.Lightning/modules/progressbar.py
index 075e13b..fb25a07 100644
--- a/MNIST.Lightning/modules/progressbar.py
+++ b/MNIST.Lightning/modules/progressbar.py
@@ -17,21 +17,36 @@ from lightning.pytorch.callbacks import TQDMProgressBar
 class CustomTrainProgressBar(TQDMProgressBar):
     def __init__(self):
         super().__init__()
-        self._val_progress_bar = _tqdm()
+        self._val_progress_bar     = _tqdm()
+        self._predict_progress_bar = _tqdm()
         
+    def init_predict_tqdm(self):
+        bar=super().init_test_tqdm()
+        bar.set_description("Predicting")
+        return bar
+
     def init_train_tqdm(self):
         bar=super().init_train_tqdm()
         bar.set_description("Training")
-        return bar
+        return bar    
 
     @property
     def val_progress_bar(self):
         if self._val_progress_bar is None:
             raise ValueError("The `_val_progress_bar` reference has not been set yet.")
         return self._val_progress_bar
+
+    @property
+    def predict_progress_bar(self) -> _tqdm:
+        if self._predict_progress_bar is None:
+            raise TypeError(f"The `{self.__class__.__name__}._predict_progress_bar` reference has not been set yet.")
+        return self._predict_progress_bar    
     
 
     def on_validation_start(self, trainer, pl_module):
         # Désactivez l'affichage de la barre de progression de validation
         self.val_progress_bar.disable = True  
-        
\ No newline at end of file
+
+    def on_predict_start(self, trainer, pl_module):
+        # Désactivez l'affichage de la barre de progression de validation
+        self.predict_progress_bar.disable = True 
\ No newline at end of file
-- 
GitLab