%% Cell type:markdown id:86fe2213-fb44-4bd4-a371-a541cba6a744 tags:
<img width="800px" src="../fidle/img/header.svg"></img>
## <!-- TITLE --> [MNIST2] - Simple classification with CNN using PyTorch Lightning
<!-- DESC --> An example of classification using a convolutional neural network for the famous MNIST dataset
<!-- AUTHOR : MBOGOL Touye Achille (AI/ML Engineer MIAI/SIMaP) -->
## Objectives :
- Recognizing handwritten numbers
- Understanding the principles of a CNN classifier
- Implementation with PyTorch Lightning
The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) (Modified National Institute of Standards and Technology) is a must for Deep Learning.
It consists of 60,000 small images of handwritten numbers for learning and 10,000 for testing.
## What we're going to do :
- Retrieve the data
- Prepare the data
- Create a model
- Train the model
- Evaluate the result
%% Cell type:markdown id:7f16101a-6612-4e02-93e9-c45ce1ac911d tags:
## Step 1 - Init python stuff
%% Cell type:code id:743c77d3-0983-491c-90be-ef2219861a47 tags:
``` python
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import lightning.pytorch as pl
import torch.nn.functional as F
import torchvision.transforms as T
import sys,os
import multiprocessing
import matplotlib.pyplot as plt
from torchvision import datasets
from torchmetrics.functional import accuracy
from torch.utils.data import Dataset, DataLoader
from modules.progressbar import CustomTrainProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
# Init Fidle environment
import fidle
run_id, run_dir, datasets_dir = fidle.init('MNIST_Lightning')
```
%% Cell type:markdown id:df10dcda-aa63-476b-8665-9b1610fe51c6 tags:
## Step 2 - Retrieve data
MNIST is one of the most famous historic datasets, included in the torchvision datasets. `torchvision` provides many built-in datasets in the `torchvision.datasets` module.
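As a quick, self-contained illustration (assuming the MNIST files can be downloaded under `data/`), each item of a torchvision dataset loaded without any transform is simply a `(PIL.Image, label)` pair, retrieved one sample at a time:

``` python
from torchvision import datasets

# Built-in dataset, no transform: samples are returned as raw PIL images
mnist_raw = datasets.MNIST(root="data", train=True, download=True, transform=None)

image, label = mnist_raw[0]                  # first training sample
print(type(image), image.size, label)        # e.g. <class 'PIL.Image.Image'> (28, 28) 5
```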
%% Cell type:code id:6668e50c-f0c6-43cf-b733-9ac29d6a3900 tags:
``` python
# Load data sets
train_dataset = datasets.MNIST(root="data", train=True, download=True, transform=None)
test_dataset= datasets.MNIST(root="data", train=False, download=False, transform=None)
```
%% Cell type:code id:a14d6fc2-b913-4eaa-9cde-5ca6785bfa12 tags:
``` python
# print info for train data
print(train_dataset)
print()
# print info for test data
print(test_dataset)
```
%% Cell type:code id:44a489f5-3e53-4a2b-8069-f265b2814dc0 tags:
``` python
# See the shape of train data and test data
print("x_train : ",train_dataset.data.shape)
print("y_train : ",train_dataset.targets.shape)
print()
print("x_test : ",test_dataset.data.shape)
print("y_test : ",test_dataset.targets.shape)
print()
# print the number of target classes and their values
print("Number of Targets :",len(np.unique(train_dataset.targets)))
print("Targets Values :", np.unique(train_dataset.targets))
print()
print("Remark that we work with torch tensors and not numpy array, not tensorflow tensor")
print(" -> x_train.dtype = ",train_dataset.data.dtype)
print(" -> y_train.dtype = ",train_dataset.targets.dtype)
```
%% Cell type:markdown id:b418adb7-33ea-450c-9793-3cdce5d5fa8c tags:
## Step 3 - Preparing your data for training with DataLoaders
The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in `minibatches`, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval. DataLoader is an iterable that abstracts this complexity for us in an easy API.
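Before batching MNIST itself, here is a minimal, self-contained sketch of the `DataLoader` API on a toy `TensorDataset` (names and sizes are arbitrary), just to show the batching and shuffling behaviour:

``` python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 10 samples of 3 features each, with integer labels
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels   = torch.arange(10)
toy_ds   = TensorDataset(features, labels)

# Minibatches of 4, reshuffled at every epoch
toy_loader = DataLoader(toy_ds, batch_size=4, shuffle=True)

for x, y in toy_loader:
    print(x.shape, y)    # torch.Size([4, 3]) ... (the last batch only has 2 samples)
```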
%% Cell type:code id:8af0bc4c-acb3-46d9-aae2-143b0327d970 tags:
``` python
# Before normalization:
x_train=train_dataset.data
print('Before normalization : Min={}, max={}'.format(x_train.min(),x_train.max()))
# After normalization:
## T.Compose creates a pipeline where the provided transformations are run in sequence
transforms = T.Compose(
    [
        # This transform takes a np.array or a PIL image of integers
        # in the range 0-255 and turns it into a float tensor in the
        # range 0.0 - 1.0
        T.ToTensor(),
    ]
)

train_dataset = datasets.MNIST(root="data", train=True,  download=True, transform=transforms)
test_dataset  = datasets.MNIST(root="data", train=False, download=True, transform=transforms)

# Get the first (image, label) sample after normalization
# iter() followed by next() returns the first item of the dataset
image, label = next(iter(train_dataset))
print('After normalization : Min={}, max={}'.format(image.min(), image.max()))
```
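The pipeline above only rescales pixel values to [0, 1]. As an optional variant, not used in the rest of this notebook, one could also standardize the images with the commonly used MNIST mean/std:

``` python
import torchvision.transforms as T

# Optional alternative pipeline: rescale to [0, 1], then standardize
transforms_std = T.Compose(
    [
        T.ToTensor(),
        T.Normalize((0.1307,), (0.3081,)),   # usual MNIST mean / std
    ]
)
```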
%% Cell type:markdown id:35d50a57-8274-4660-8765-d0f2bf7214bd tags:
### Have a look
%% Cell type:code id:a172ebc5-8858-4f30-8e2c-1e9c123ae0ee tags:
``` python
x_train=train_dataset.data
y_train=train_dataset.targets
```
%% Cell type:code id:5a487760-b43a-4f7c-bfd8-1ce2c9652769 tags:
``` python
fidle.scrawler.images(x_train, y_train, [27], x_size=5,y_size=5, colorbar=True, save_as='01-one-digit')
fidle.scrawler.images(x_train, y_train, range(5,41), columns=12, save_as='02-many-digits')
```
%% Cell type:code id:ca0a63ae-e6d6-4940-b8ff-9b11cb2737bb tags:
``` python
# Train data loader (shuffled batches)
train_loader= DataLoader(
dataset=train_dataset,
shuffle=True,
batch_size=512,
num_workers=2
)
# Test data loader (no shuffling)
test_loader= DataLoader(
dataset=test_dataset,
shuffle=False,
batch_size=512,
num_workers=2
)
# Get one batch of images and labels from the train DataLoader
image, label = next(iter(train_loader))
print('Shape of the first training batch from the PyTorch DataLoader :\nbatch images = {} \nbatch labels = {}'.format(image.shape, label.shape))
```
%% Cell type:markdown id:51bf21ee-76ca-42fa-b67f-066dbd239a72 tags:
## Step 4 - Create Model
More information about:
- [Optimizer](https://pytorch.org/docs/stable/optim.html)
- [Activation](https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
- [Loss](https://pytorch.org/docs/stable/nn.html#loss-functions)
- [Metrics](https://torchmetrics.readthedocs.io/)
`Note :` PyTorch provides losses such as cross-entropy (`nn.CrossEntropyLoss`), usually used for classification problems. We use the softmax function to predict class probabilities, and with a softmax output, cross-entropy is the natural loss. To actually compute the loss, however, we pass the raw output of the network to the loss function, not the output of the softmax. This raw output is usually called the *logits* or *scores*: in PyTorch, the cross-entropy loss already applies the softmax internally.
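As a quick sanity check, independent of the training pipeline, the sketch below shows that feeding raw logits to `F.cross_entropy` is equivalent to applying `log_softmax` followed by the negative log-likelihood loss by hand:

``` python
import torch
import torch.nn.functional as F

# Fake raw network outputs (logits) for 4 samples and 10 classes
logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 1])

loss_ce  = F.cross_entropy(logits, labels)                     # softmax is applied internally
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), labels)    # same thing, done by hand

print(loss_ce, loss_nll)    # identical values (up to floating point precision)
```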
%% Cell type:code id:16701119-71eb-4f59-a50a-f153b07a74ae tags:
``` python
class MyNet(nn.Module):
def __init__(self,num_class=10):
super().__init__()
self.num_class=num_class
self.model = nn.Sequential(
# first convolution
nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d((2,2)),
nn.Dropout2d(0.1), # Combat overfitting
# second convolution
nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d((2,2)),
nn.Dropout2d(0.1), # Combat overfitting
nn.Flatten(), # convert feature map into feature vectors
# MLP network
nn.Linear(16*5*5,100),
nn.ReLU(),
nn.Dropout1d(0.1), # Combat overfitting
            nn.Linear(100, num_class),                      # logits output
)
def forward(self, x):
x=self.model(x) # forward pass
return x
```
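As an optional shape check of the `MyNet` class defined just above: a 28x28 image is reduced to 26 by the first 3x3 convolution, to 13 by the pooling, then to 11 and finally 5, which is where the `16*5*5` input size of the first `Linear` layer comes from.

``` python
# Optional shape check: pass a dummy batch of one 28x28 image through the network
# 28 -> conv3x3 -> 26 -> maxpool2 -> 13 -> conv3x3 -> 11 -> maxpool2 -> 5
dummy = torch.zeros(1, 1, 28, 28)
print(MyNet()(dummy).shape)    # expected: torch.Size([1, 10])
```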
%% Cell type:code id:37abf99b-f8ec-4048-a65d-f173ee18b234 tags:
``` python
class LitModel(pl.LightningModule):
def __init__(self, MyNet):
super().__init__()
self.MyNet = MyNet
# forward pass
def forward(self, x):
return self.MyNet(x)
# optimizer
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def training_step(self, batch, batch_idx):
# defines the train loop
x, y = batch
# forward pass
y_hat = self.MyNet(x)
# loss function
loss= F.cross_entropy(y_hat, y)
# metrics accuracy
acc=accuracy(y_hat, y,task="multiclass",num_classes=10)
metrics = {"train_loss": loss, "train_acc": acc}
# logs metrics for each training_step
self.log_dict(metrics,
on_step=False ,
on_epoch=True,
prog_bar=True,
logger=True)
return loss
def validation_step(self, batch, batch_idx):
# defines the valid loop.
x, y = batch
# forward pass
y_hat= self.MyNet(x)
        # loss function: cross-entropy
loss = F.cross_entropy(y_hat, y)
# metrics accuracy
acc=accuracy(y_hat, y, task="multiclass", num_classes=10)
metrics = {"test_loss": loss, "test_acc": acc}
# logs metrics for each validation_step
self.log_dict(metrics,
on_step = False,
on_epoch = True,
prog_bar = True,
logger = True
)
return metrics
def predict_step(self, batch, batch_idx):
        # defines the predict loop
x, y = batch
# forward pass
y_hat = self.MyNet(x)
return y_hat
```
%% Cell type:code id:489af62f-8f7c-4d1b-a6d0-5a0417e79869 tags:
``` python
# Print the model structure
model=LitModel(MyNet())
print(model)
```
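`print(model)` only shows the module hierarchy. If you also want the number of trainable parameters, a minimal sketch in plain PyTorch (no extra dependency) could be:

``` python
# Count the trainable parameters of the LightningModule defined above
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters : {n_params:,}")
```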
%% Cell type:markdown id:fb32e85d-bd92-4ca5-a3dc-ddb5ed50ba6b tags:
## Step 5 - Train Model
%% Cell type:code id:756d5e19-6a10-42b8-8971-411389f7d19c tags:
``` python
# TensorBoard logger (logs are written under MNIST2_logs/CNN_logs)
logger = TensorBoardLogger(save_dir="MNIST2_logs",name="CNN_logs")
```
%% Cell type:code id:ce975c03-d05d-40c4-92ff-0cc90699c13e tags:
``` python
# train model
trainer = pl.Trainer(accelerator='auto',
max_epochs=16,
logger=logger,
num_sanity_val_steps=0,
callbacks=[CustomTrainProgressBar()]
)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader)
```
%% Cell type:markdown id:a1191f05-4454-415c-a5ed-e63d9ae56651 tags:
## Step 6 - Evaluate
### 6.1 - Final loss and accuracy
Note : with a plain DNN, we obtained an accuracy of about 97.7%
%% Cell type:code id:9f45316e-0d2d-4fc1-b9a8-5fb8aaf5586a tags:
``` python
# evaluate your model
score=trainer.validate(model=model,dataloaders=test_loader, verbose=False)
print('x_test / acc : {:5.4f}'.format(score[0]['test_acc']))
print('x_test / loss : {:5.4f}'.format(score[0]['test_loss']))
```
%% Cell type:code id:5cfe9bd6-654b-42e0-b430-5f3b816526b0 tags:
``` python
score=trainer.validate(model=model,dataloaders=test_loader, verbose=False)
```
%% Cell type:markdown id:e352e48d-b473-4162-a1aa-72d6d4f7aa38 tags:
### 6.2 - Plot history
%% Cell type:code id:5b1c6d11-b897-4e2b-8615-c207c8344d07 tags:
``` python
# Launch TensorBoard
%reload_ext tensorboard
%tensorboard --logdir=MNIST2_logs/CNN_logs/
# On a remote server, use instead:
# %tensorboard --logdir=MNIST2_logs/CNN_logs/ --bind_all
```
%% Cell type:markdown id:f00ded6b-a7db-4c5d-b1b2-72264db20bdb tags:
### 6.3 - Plot results
%% Cell type:code id:e387a70d-9c23-4d16-8ef7-879aec7791e2 tags:
``` python
# Predicted logits, returned as a list of tensors (one per batch)
y_logits = trainer.predict(model=model, dataloaders=test_loader)

# Concatenate into a single tensor
y_logits = torch.cat(y_logits)

# Class probabilities (softmax over the logits)
y_pred_values = F.softmax(y_logits, dim=1)

# Index of the highest probability = predicted class
y_pred = torch.argmax(y_pred_values, dim=-1)
```
%% Cell type:code id:fb2b2eeb-fcd8-453c-93ef-59a960a8bbd5 tags:
``` python
x_test=test_dataset.data
y_test=test_dataset.targets
```
%% Cell type:code id:71187fa9-2ad3-4b23-94b9-1846045bd070 tags:
``` python
fidle.scrawler.images(x_test, y_test, range(0,200), columns=12, x_size=1, y_size=1, y_pred=y_pred, save_as='04-predictions')
```
%% Cell type:markdown id:2fc7b2b9-9115-4848-9aae-2798bf7aa79a tags:
### 6.4 - Plot some errors
%% Cell type:code id:e55f17c4-fce7-423a-9adf-f2511c534ef5 tags:
``` python
errors=[ i for i in range(len(x_test)) if y_pred[i]!=y_test[i] ]
errors=errors[:min(24,len(errors))]
fidle.scrawler.images(x_test, y_test, errors[:15], columns=6, x_size=2, y_size=2, y_pred=y_pred, save_as='05-some-errors')
```
%% Cell type:code id:fea1b396-70ca-4b00-851d-0538a4b347fb tags:
``` python
fidle.scrawler.confusion_matrix(y_test,y_pred,range(10),normalize=True, save_as='06-confusion-matrix')
```
%% Cell type:code id:e982c032-cce8-4c71-8cdc-2af4b31b2914 tags:
``` python
fidle.end()
```
%% Cell type:markdown id:233838c2-c97f-4489-8c79-9247d7b7456b tags:
<div class="todo">
A few things you can do for fun:
<ul>
<li>Changing the network architecture (layers, number of neurons, etc.)</li>
<li>Display a summary of the network</li>
<li>Retrieve and display the softmax output of the network, to evaluate its "doubts".</li>
</ul>
</div>
%% Cell type:markdown id:51b87aa0-d4e9-48bb-8205-4b583f4b0b61 tags:
---
<img width="80px" src="../fidle/img/logo-paysage.svg"></img>
%% Cell type:code id:83a8ceb7-9711-407b-9546-60ad7bd2b5ba tags:
``` python
```