# ------------------------------------------------------------------ # _____ _ _ _ # | ___(_) __| | | ___ # | |_ | |/ _` | |/ _ \ # | _| | | (_| | | __/ # |_| |_|\__,_|_|\___| GAN / QuickDrawDataModule # ------------------------------------------------------------------ # Formation Introduction au Deep Learning (FIDLE) # CNRS/MIAI - https://fidle.cnrs.fr # ------------------------------------------------------------------ # JL Parouty (Mars 2024) import numpy as np import torch from lightning import LightningDataModule from torch.utils.data import DataLoader class QuickDrawDataModule(LightningDataModule): def __init__( self, dataset_file='./sheep.npy', scale=1., batch_size=64, num_workers=4 ): super().__init__() print('\n---- QuickDrawDataModule initialization ----------------------------') print(f'with : scale={scale} batch size={batch_size}') self.scale = scale self.dataset_file = dataset_file self.batch_size = batch_size self.num_workers = num_workers self.dims = (28, 28, 1) self.num_classes = 10 def prepare_data(self): pass def setup(self, stage=None): print('\nDataModule Setup :') # Load dataset # Called at the beginning of each stage (train,val,test) # Here, whatever the stage value, we'll have only one set. data = np.load(self.dataset_file) print('Original dataset shape : ',data.shape) # Rescale n=int(self.scale*len(data)) data = data[:n] print('Rescaled dataset shape : ',data.shape) # Normalize, reshape and shuffle data = data/255 data = data.reshape(-1,28,28,1) data = torch.from_numpy(data).float() print('Final dataset shape : ',data.shape) print('Dataset loaded and ready.') self.data_train = data def train_dataloader(self): # Note : Numpy ndarray is Dataset compliant # Have map-style interface. See https://pytorch.org/docs/stable/data.html return DataLoader( self.data_train, batch_size=self.batch_size, num_workers=self.num_workers )