Skip to content
Snippets Groups Projects
reconstruct.py 3.02 KiB
Newer Older
Tom Chardon's avatar
Tom Chardon committed
"""The main file for the reconstruction.
This file should NOT be modified except for the body of the 'run_reconstruction' function.
Students can call their own functions (declared in other files of src/methods/your_name).
"""


import numpy as np

from src.forward_model import CFA
from src.methods.chardon_tom.utils import *
import pywt

# NOTE: It is normal for the reconstruction to take several minutes (about 3 min on my computer).

def run_reconstruction(y: np.ndarray, cfa: str) -> np.ndarray:
    """Performs demosaicking on y.

    The mosaic is first spread into an RGB buffer with the CFA adjoint. The
    channels are then completed in this order:

    * green, at every site where it was not sampled;
    * red, in two passes: first a green-guided pass (``interpolate_red_blue``)
      on the sites at (odd row, even column), then a pass over every site
      still missing red (``interpolate_green``);
    * blue, in the same two passes, with the guided pass on the sites at
      (even row, odd column).

    Sampled values are never overwritten, even when they are exactly 0.

    Args:
        y (np.ndarray): Mosaicked image to be reconstructed.
        cfa (str): Name of the CFA. Can be bayer or quad_bayer.

    Returns:
        np.ndarray: Demosaicked image of shape (y.shape[0], y.shape[1], 3),
        clipped to [0, 1].
    """
    # Honour the caller's CFA. This was previously hard-coded to 'bayer',
    # which silently ignored the `cfa` argument.
    input_shape = (y.shape[0], y.shape[1], 3)
    op = CFA(cfa, input_shape)

    res = op.adjoint(y)
    N, M = input_shape[0], input_shape[1]

    # known[i, j, c] is True where channel c already holds a value that must
    # not be overwritten. Applying the adjoint to an all-ones mosaic marks the
    # sampled sites. This assumes the adjoint is nonzero exactly at sampled
    # sites (TODO confirm against src.forward_model.CFA). The mask replaces the
    # old `res[...] == 0` test, which also treated genuinely black measured
    # pixels as missing and re-interpolated them.
    known = op.adjoint(np.ones_like(y)) != 0

    def fill_guided(channel, row_start, col_start):
        """Fill `channel` on every second row/column starting at
        (row_start, col_start), using green neighbours as a guide."""
        for i in range(row_start, N, 2):
            for j in range(col_start, M, 2):
                if known[i, j, channel]:
                    continue  # sampled (or already filled): keep it
                neighbors = get_neighbors(res, channel, i, j, N, M)
                neighbors_G = get_neighbors(res, 1, i, j, N, M)
                weights = get_weights(res, i, j, channel, N, M)
                res[i, j, channel] = interpolate_red_blue(weights, neighbors, neighbors_G)
                known[i, j, channel] = True

    def fill_remaining(channel):
        """Fill every still-missing sample of `channel` in raster order."""
        for i in range(N):
            for j in range(M):
                if known[i, j, channel]:
                    continue
                # res is updated in place, so later pixels can see values
                # interpolated earlier in this same pass (as before).
                neighbors = get_neighbors(res, channel, i, j, N, M)
                weights = get_weights(res, i, j, channel, N, M)
                res[i, j, channel] = interpolate_green(weights, neighbors)
                known[i, j, channel] = True

    # Green must be complete before the guided red/blue passes, because they
    # read green neighbours.
    fill_remaining(1)

    # Red: guided pass at (odd row, even column), then every remaining site.
    # NOTE(review): these seed positions match a Bayer layout. For other CFAs
    # the guard above only prevents overwriting sampled values; reconstruction
    # quality for quad_bayer is unverified.
    fill_guided(0, 1, 0)
    fill_remaining(0)

    # Blue: guided pass at (even row, odd column), then every remaining site.
    fill_guided(2, 0, 1)
    fill_remaining(2)

    # An iterative correction step (correction_green / correction_red /
    # correction_blue, two iterations) is currently disabled; see version
    # control history.

    np.clip(res, 0, 1, out=res)
    return res