Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# ------------------------------------------------------------------
#    _____ _     _ _
#   |  ___(_) __| | | ___
#   | |_  | |/ _` | |/ _ \
#   |  _| | | (_| | |  __/
#   |_|   |_|\__,_|_|\___|                       Tensorboard callback
# ------------------------------------------------------------------
# Formation Introduction au Deep Learning (FIDLE) - CNRS/MIAI/UGA
# ------------------------------------------------------------------
# JL Parouty 2023
#
# See : https://keras.io/api/callbacks/
# See : https://keras.io/guides/writing_your_own_callbacks/
# See : https://pytorch.org/docs/stable/tensorboard.html
import keras
from torch.utils.tensorboard import SummaryWriter
class TensorboardCallback(keras.callbacks.Callback):
    '''
    Keras callback that records training metrics for TensorBoard.

    Metrics are written with the PyTorch SummaryWriter, so the logs can be
    produced whatever the Keras backend is. At the end of each epoch the
    train / validation curves are grouped under two charts, 'Accuracy'
    and 'Loss'.
    '''

    def __init__(self, log_dir=None):
        '''
        Init callback
        Args:
            log_dir : log directory, passed as is to SummaryWriter
                      (None lets SummaryWriter pick its default location)
        '''
        # Initialise the Keras Callback base state before adding our own.
        super().__init__()
        self.writer = SummaryWriter(log_dir=log_dir)

    def on_epoch_end(self, epoch, logs=None):
        '''
        Record logs at epoch end
        Args:
            epoch : index of the epoch that just ended, used as the step
            logs  : dict of metrics for this epoch (may be None or partial)
        '''
        # `logs` defaults to None in the signature, so never subscript it
        # directly.
        logs = logs or {}

        # ---- Simplest alternative: record every metric as its own chart
        #
        #      for k, v in logs.items():
        #          self.writer.add_scalar(k, v, epoch)

        # ---- Records and group specific metrics
        #
        self._add_group('Accuracy',
                        {'Train': logs.get('accuracy'),
                         'Validation': logs.get('val_accuracy')},
                        epoch)
        self._add_group('Loss',
                        {'Train': logs.get('loss'),
                         'Validation': logs.get('val_loss')},
                        epoch)

    def on_train_end(self, logs=None):
        '''
        Flush pending events so the last epochs are on disk once training
        is over. The writer is left open, so the callback can be reused.
        '''
        self.writer.flush()

    def _add_group(self, main_tag, curves, epoch):
        '''
        Write a group of curves under one chart, skipping missing metrics.
        Args:
            main_tag : chart name (e.g. 'Loss')
            curves   : dict curve name -> value, None meaning "not available"
            epoch    : step at which the values are recorded
        '''
        # A metric is missing when, for instance, the model is trained
        # without validation data or without an 'accuracy' metric.
        available = {name: value
                     for name, value in curves.items()
                     if value is not None}
        if available:
            self.writer.add_scalars(main_tag, available, epoch)