Skip to content
Snippets Groups Projects
progressbar.py 1.93 KiB
Newer Older
# ------------------------------------------------------------------
#     _____ _     _ _
#    |  ___(_) __| | | ___
#    | |_  | |/ _` | |/ _ \
#    |  _| | | (_| | |  __/
#    |_|   |_|\__,_|_|\___|
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning  (FIDLE)
# CNRS/SARI/DEVLOG 2023 
# ------------------------------------------------------------------
# 2.0 version by Achille Mbogol Touye (EFELIA-MIAI/SIMAP), sep 2023

from tqdm import tqdm as _tqdm
from lightning.pytorch.callbacks import TQDMProgressBar

# Créez un callback de barre de progression pour afficher les métriques d'entraînement
class CustomTrainProgressBar(TQDMProgressBar):
    """TQDM progress-bar callback that only shows training progress.

    The training bar is labelled "Training". The validation and prediction
    bars are switched off as soon as their respective loops start.
    """

    def __init__(self):
        super().__init__()
        # Placeholder bars so the properties below always have an object to
        # return. They are created disabled so that merely constructing the
        # callback does not draw two empty "0it" bars on the console.
        self._val_progress_bar = _tqdm(disable=True)
        self._predict_progress_bar = _tqdm(disable=True)

    def init_predict_tqdm(self):
        """Build the prediction bar and label it "Predicting".

        Returns:
            The bar created by the parent class, with its description set.
        """
        # Delegate to the parent's predict-bar factory (not the test one).
        bar = super().init_predict_tqdm()
        bar.set_description("Predicting")
        return bar

    def init_train_tqdm(self):
        """Build the training bar and label it "Training".

        Returns:
            The bar created by the parent class, with its description set.
        """
        bar = super().init_train_tqdm()
        bar.set_description("Training")
        return bar

    @property
    def val_progress_bar(self) -> _tqdm:
        """Return the validation bar.

        Raises:
            ValueError: If the underlying reference is ``None``.
        """
        if self._val_progress_bar is None:
            raise ValueError("The `_val_progress_bar` reference has not been set yet.")
        return self._val_progress_bar

    @val_progress_bar.setter
    def val_progress_bar(self, bar) -> None:
        # Overriding the getter alone would make this property read-only.
        # NOTE(review): the parent TQDMProgressBar is expected to assign to
        # this attribute in some hooks (e.g. the sanity check) - confirm
        # against the installed Lightning version.
        self._val_progress_bar = bar

    @property
    def predict_progress_bar(self) -> _tqdm:
        """Return the prediction bar.

        Raises:
            TypeError: If the underlying reference is ``None``.
        """
        if self._predict_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._predict_progress_bar` reference has not been set yet.")
        return self._predict_progress_bar

    @predict_progress_bar.setter
    def predict_progress_bar(self, bar) -> None:
        # Keeps the property assignable, mirroring `val_progress_bar`.
        self._predict_progress_bar = bar

    def on_validation_start(self, trainer, pl_module):
        """Hide the validation progress bar.

        Args:
            trainer: The running trainer (unused).
            pl_module: The module being validated (unused).
        """
        self.val_progress_bar.disable = True

    def on_predict_start(self, trainer, pl_module):
        """Hide the prediction progress bar.

        Args:
            trainer: The running trainer (unused).
            pl_module: The module used for prediction (unused).
        """
        self.predict_progress_bar.disable = True