# mourasa_reconstruct.py
# Colour image reconstruction (demosaicking) for Bayer and Quad Bayer CFAs:
# bilinear interpolation followed by Douglas-Rachford refinement.
#Imports
import cv2
import numpy as np
from scipy.fft import fft2, ifft2
from src.forward_model import CFA
from scipy.signal import convolve2d

def criterion(img_true, img):
    """Return the normalised mean squared error (NMSE) between img_true and img.

    NMSE = ||img_true - img||^2 / ||img_true||^2, where ||.|| is the
    ``np.linalg.norm`` default (Frobenius norm for 2-D arrays).

    Both inputs are converted to float64 first. Casting to a small integer
    type would wrap uint8 values above 127 to negatives and truncate float
    inputs, which silently corrupts the error.

    If img_true is all zeros the denominator is 0 and the result is inf/nan.
    """
    reference = np.asarray(img_true, dtype=np.float64)
    estimate = np.asarray(img, dtype=np.float64)
    return np.linalg.norm(reference - estimate) ** 2 / np.linalg.norm(reference) ** 2

def _split_channels(img, r_mask, g_mask, b_mask):
    """Keep the R, G and B samples of ``img`` selected by three boolean masks.

    Each plane is a uint8 (H, W) array equal to ``img[:, :, c]`` where its mask
    is True and 0 elsewhere (values are cast to uint8 on assignment).

    Returns (img_tot, img_r, idx_r, img_g, idx_g, img_b, idx_b): img_tot is the
    (H, W, 3) uint8 stack of the three planes, and each idx_* is the sorted
    list of row-major flat indices ``i*W + j`` of the sampled pixels.
    """
    planes = []
    indices = []
    for channel, mask in enumerate((r_mask, g_mask, b_mask)):
        plane = np.zeros(mask.shape, dtype=np.uint8)
        plane[mask] = img[:, :, channel][mask]
        planes.append(plane)
        # Flat indices use the row width W so they stay unique on non-square images.
        indices.append(np.flatnonzero(mask).tolist())
    img_tot = np.stack(planes, axis=2)  # a copy: the returned planes stay independent
    img_r, img_g, img_b = planes
    idx_r, idx_g, idx_b = indices
    return img_tot, img_r, idx_r, img_g, idx_g, img_b, idx_b

def _bayer_masks(rows, cols):
    """Return the (R, G, B) boolean masks of a Bayer layout on the given coordinates.

    G sits where rows + cols is even; the remaining (odd) sites are R on even
    rows and B on odd rows.
    """
    odd_site = (rows + cols) % 2 == 1
    return odd_site & (rows % 2 == 0), ~odd_site, odd_site & (rows % 2 == 1)

def mosaic_bp(img):
    """Collect the known pixel values and indices of img for the Bayer pattern.

    Pixel (i, j) keeps channel 1 (G) when i + j is even, channel 0 (R) when
    i + j is odd and i is even, and channel 2 (B) when i + j is odd and i is odd.

    Returns (img_tot, img_r, idx_r, img_g, idx_g, img_b, idx_b): uint8 sparse
    planes (zeros at unsampled pixels), their (H, W, 3) stack img_tot, and the
    sorted row-major flat indices ``i*W + j`` of each plane's sampled pixels.
    """
    rows, cols = np.indices(img.shape[:2])
    return _split_channels(img, *_bayer_masks(rows, cols))

def mosaic_qbp(img):
    """Collect the known pixel values and indices of img for the Quad Bayer pattern.

    Quad Bayer is the Bayer layout applied to 2x2 pixel blocks. Block
    (i//2, j//2) follows the same G/R/B rule as ``mosaic_bp``. Blocks cut by an
    odd image border keep only their in-bounds pixels.

    Returns the same 7-tuple as ``mosaic_bp``. Indices are sorted row-major
    ``i*W + j``.
    """
    rows, cols = np.indices(img.shape[:2])
    return _split_channels(img, *_bayer_masks(rows // 2, cols // 2))

def proxop1(X_fft, gamma):
    """Compute the proximal operator of gamma * l1-norm (complex soft-thresholding).

    Each coefficient's magnitude is reduced by gamma and floored at 0, and its
    phase is kept: max(0, |x| - gamma) * exp(1j * angle(x)).
    """
    magnitude = np.maximum(0, np.abs(X_fft) - gamma)
    return magnitude * np.exp(1j * np.angle(X_fft))

def proxop2(X_fft, Y, idx):
    """Compute the proximal operator of the indicator of the data-consistency set.

    Known pixels are forced to their observed values Y. All other pixels keep
    the value of the current estimate.

    Parameters:
        X_fft: 2-D spectrum of the current estimate.
        Y: (H, W) array of observed values, read only at the known pixels.
        idx: row-major flat indices ``i*W + j`` of the known pixels, as
            returned by ``mosaic_bp``/``mosaic_qbp``. Out-of-range entries
            are ignored.

    Returns:
        fft2 of the image equal to Y on idx and to |ifft2(X_fft)| elsewhere.
    """
    # NOTE(review): the magnitude (not the real part) of ifft2 is used, which
    # also flips the sign of negative samples; kept as-is — confirm intent.
    X = np.abs(ifft2(X_fft))
    H, W = X.shape
    flat = np.asarray(idx, dtype=np.intp).ravel()
    flat = flat[(flat >= 0) & (flat < H * W)]
    # A boolean mask gives O(H*W) work instead of a list search per pixel.
    known = np.zeros(H * W, dtype=bool)
    known[flat] = True
    output = np.where(known.reshape(H, W), Y, X)
    return fft2(output)
    
def DouglasRachford(X_fft, Y, idx, rho, gamma, NbIt, img_true):
    """Run NbIt Douglas-Rachford iterations on the spectrum X_fft.

    Each step soft-thresholds the spectrum (proxop1, threshold gamma), reflects
    it, and projects onto the data-consistency set (proxop2 with Y and idx).
    The difference is then added back with relaxation factor 2*rho.

    Returns:
        (X_fft, J): the last iterate of the spectrum, and the list of NMSE
        values (criterion against img_true) for the initial guess and after
        every iteration, i.e. NbIt + 1 entries.
    """
    # NOTE(review): classical DR takes proxop1 of the final iterate as the
    # estimate; this returns the iterate itself — confirm that is intended.
    def _score(spectrum):
        return criterion(img_true, np.abs(ifft2(spectrum)))

    history = [_score(X_fft)]
    for _ in range(NbIt):
        thresholded = proxop1(X_fft, gamma)
        reflected = 2 * thresholded - X_fft
        X_fft = X_fft + 2 * rho * (proxop2(reflected, Y, idx) - thresholded)
        history.append(_score(X_fft))
    return X_fft, history

def interpol_bp_rb(img):
    """Bilinearly fill the R or B plane of a Bayer-pattern image.

    The plane is convolved with the separable tent kernel
    outer([.5, 1, .5], [.5, 1, .5]). The output has the same size as the input
    and the borders wrap periodically.
    """
    tent = np.array([0.5, 1.0, 0.5])
    return convolve2d(img, np.outer(tent, tent), mode='same', boundary='wrap')

def interpol_qbp_rb(img):
    """Interpolate the R or B plane of a Quad Bayer image.

    The kernel is the Bayer R/B tent kernel with every tap expanded to a 2x2
    block (a 6x6 kernel), scaled by 1/4. The convolution keeps the input size
    and wraps periodically at the borders.
    """
    tent = np.array([0.5, 1.0, 0.5])
    kernel = np.kron(np.outer(tent, tent), np.ones((2, 2))) / 4
    return convolve2d(img, kernel, mode='same', boundary='wrap')

def interpol_bp_g(img):
    """Bilinearly fill the G plane of a Bayer-pattern image.

    The kernel has 1 at the centre and 0.25 on the four direct neighbours.
    The convolution keeps the input size and wraps periodically at the borders.
    """
    kernel = np.zeros((3, 3))
    kernel[1, :] = 0.25
    kernel[:, 1] = 0.25
    kernel[1, 1] = 1.0
    return convolve2d(img, kernel, mode='same', boundary='wrap')

def interpol_qbp_g(img):
    """Interpolate the G plane of a Quad Bayer image.

    The kernel is the Bayer G cross kernel with every tap expanded to a 2x2
    block (a 6x6 kernel), scaled by 1/4. The convolution keeps the input size
    and wraps periodically at the borders.
    """
    cross = np.zeros((3, 3))
    cross[1, :] = 0.25
    cross[:, 1] = 0.25
    cross[1, 1] = 1.0
    kernel = 0.25 * np.kron(cross, np.ones((2, 2)))
    return convolve2d(img, kernel, mode='same', boundary='wrap')

def _to_uint8(channel):
    """Convert a reconstructed (possibly complex) channel to uint8 pixels.

    The imaginary part left by ifft2 is dropped. Values are rounded and
    clipped to [0, 255] so out-of-range samples saturate instead of wrapping
    around in the integer cast.
    """
    return np.clip(np.rint(np.real(channel)), 0, 255).astype(np.uint8)

def _reconstruct_channels(y, mosaic, interp_rb, interp_g, rho, gamma, n_iter):
    """Interpolate each sparse colour plane of y, then refine it with Douglas-Rachford."""
    _, img_r, idx_r, img_g, idx_g, img_b, idx_b = mosaic(y)
    img_rec = np.zeros(y.shape, dtype=np.uint8)
    # (channel, sparse plane, known-pixel indices, initial interpolation)
    channels = (
        (0, img_r, idx_r, interp_rb),
        (1, img_g, idx_g, interp_g),
        (2, img_b, idx_b, interp_rb),
    )
    for c, plane, idx, interpolate in channels:
        rec_fft, _ = DouglasRachford(
            fft2(interpolate(plane)), plane, idx, rho, gamma, n_iter, y[:, :, c]
        )
        img_rec[:, :, c] = _to_uint8(ifft2(rec_fft))
    return img_rec

def mourasa_reconstruction(op: CFA, y: np.ndarray, *, rho=0.8, gamma=15, n_iter=10):
    """Reconstruct the colour image from y for the CFA described by op.

    Each channel is first filled by convolution interpolation. It is then
    refined by n_iter Douglas-Rachford iterations with relaxation rho and
    l1 threshold gamma (defaults are the previously hard-coded 0.8, 15, 10).

    Parameters:
        op: CFA operator; only ``op.cfa`` ('bayer' or 'quad_bayer') is read.
        y: image indexed as y[:, :, c] for c in 0..2.

    Returns:
        uint8 array of shape y.shape with the reconstructed channels.

    Raises:
        ValueError: if ``op.cfa`` is not a supported pattern.
    """
    # NOTE(review): the mosaic helpers store samples as uint8, so a float y in
    # [0, 1] would truncate to 0 — confirm the expected dtype/range of y.
    # NOTE(review): an RGB<->BGR cv2 conversion was left commented out here;
    # confirm the channel order of y.
    pipelines = {
        'bayer': (mosaic_bp, interpol_bp_rb, interpol_bp_g),
        'quad_bayer': (mosaic_qbp, interpol_qbp_rb, interpol_qbp_g),
    }
    try:
        mosaic, interp_rb, interp_g = pipelines[op.cfa]
    except KeyError:
        raise ValueError(f"Unsupported CFA pattern: {op.cfa!r}") from None
    return _reconstruct_channels(y, mosaic, interp_rb, interp_g, rho, gamma, n_iter)