Newer
Older
# ------------------------------------------------------------------
#  _____ _     _ _
# |  ___(_) __| | | ___
# | |_  | |/ _` | |/ _ \
# |  _| | | (_| | |  __/
# |_|   |_|\__,_|_|\___|
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE)
# CNRS/SARI/DEVLOG 2023
# ------------------------------------------------------------------
# 2.0 version by Achille Mbogol Touye (EFELIA-MIAI/SIMAP), sep 2023
from tqdm import tqdm as _tqdm
from lightning.pytorch.callbacks import TQDMProgressBar
class CustomTrainProgressBar(TQDMProgressBar):
    """Progress-bar callback that shows training progress and hides the others.

    The validation and prediction hooks are overridden: instead of building a
    new visible bar, they disable whichever bar object is currently stored.
    Silent placeholder bars are installed in ``__init__`` so that the
    ``val_progress_bar`` / ``predict_progress_bar`` getters never find ``None``
    when those hooks run.
    """

    def __init__(self, *args, **kwargs):
        """Create the callback.

        All positional and keyword arguments are forwarded unchanged to
        ``TQDMProgressBar.__init__``; calling it with no arguments behaves as
        before.
        """
        super().__init__(*args, **kwargs)
        # Placeholder bars, so the getters below always have something to
        # return. ``disable=True`` keeps them silent: a bare ``_tqdm()``
        # renders an empty bar on the console as soon as it is created and
        # is never closed.
        self._val_progress_bar = _tqdm(disable=True)
        self._predict_progress_bar = _tqdm(disable=True)

    def init_predict_tqdm(self):
        """Return the base class's prediction bar, labelled "Predicting".

        Note: ``on_predict_start`` in this class does not call this method.
        """
        bar = super().init_predict_tqdm()
        bar.set_description("Predicting")
        return bar

    def init_train_tqdm(self):
        """Return the base class's training bar, labelled "Training"."""
        bar = super().init_train_tqdm()
        bar.set_description("Training")
        # Hand the configured bar back. Without this return the method yields
        # None and the caller (presumably the base class's train-start hook)
        # ends up with no training bar.
        return bar

    @property
    def val_progress_bar(self) -> _tqdm:
        """Bar driven by the validation hooks.

        Raises:
            ValueError: if the stored reference is ``None``.
        """
        if self._val_progress_bar is None:
            raise ValueError("The `_val_progress_bar` reference has not been set yet.")
        return self._val_progress_bar

    @val_progress_bar.setter
    def val_progress_bar(self, bar: _tqdm) -> None:
        # Redefining only the getter would shadow any setter on the parent
        # property, making ``self.val_progress_bar = ...`` raise
        # AttributeError. The base class may perform such assignments
        # (e.g. for sanity-check bars); confirm for the installed version.
        self._val_progress_bar = bar

    @property
    def predict_progress_bar(self) -> _tqdm:
        """Bar driven by the prediction hooks.

        Raises:
            TypeError: if the stored reference is ``None``.
        """
        if self._predict_progress_bar is None:
            raise TypeError(f"The `{self.__class__.__name__}._predict_progress_bar` reference has not been set yet.")
        return self._predict_progress_bar

    @predict_progress_bar.setter
    def predict_progress_bar(self, bar: _tqdm) -> None:
        # Same reason as the validation setter: keep the property assignable.
        self._predict_progress_bar = bar

    def on_validation_start(self, trainer, pl_module):
        """Hide validation progress by disabling the current validation bar.

        The base implementation is deliberately not called, so this hook
        never creates a new validation bar.
        """
        self.val_progress_bar.disable = True

    def on_predict_start(self, trainer, pl_module):
        """Hide prediction progress by disabling the current prediction bar.

        The base implementation is deliberately not called, so this hook
        never creates a new prediction bar.
        """
        self.predict_progress_bar.disable = True