"""The main file for the reconstruction.
This file should NOT be modified except the body of the 'run_reconstruction' function.
Students can call their own functions (declared in other files of src/methods/your_name).
"""


import numpy as np

from src.forward_model import CFA
from src.methods.Chardon_tom.utils import *
import pywt

# NOTE: It is normal for the reconstruction to take several minutes (about 3 min on my computer).

def run_reconstruction(y: np.ndarray, cfa: str) -> np.ndarray:
    """Performs demosaicking on y.

    Green is interpolated first. Red and blue are then each done in two
    passes: a green-guided pass on a fixed Bayer sub-lattice, followed by a
    pass that fills every remaining missing pixel.

    Args:
        y (np.ndarray): Mosaicked image to be reconstructed, indexed as (N, M).
        cfa (str): Name of the CFA. Can be bayer or quad_bayer.
            NOTE(review): this argument is currently ignored; the mosaic is
            always modelled with the 'bayer' operator, and the sub-lattice
            loops below are Bayer-specific.

    Returns:
        np.ndarray: Demosaicked image of shape (N, M, 3), clipped to [0, 1].
    """

    # Define constants and operators.
    # NOTE(review): hard-coded to 'bayer' regardless of `cfa` (see docstring).
    cfa_name = 'bayer'  # bayer or quad_bayer
    input_shape = (y.shape[0], y.shape[1], 3)
    op = CFA(cfa_name, input_shape)
    N, M = input_shape[0], input_shape[1]

    # Scatter the measured samples into an (N, M, 3) image; unmeasured entries
    # are 0 and get filled below.
    res = op.adjoint(y)

    # Which (pixel, channel) entries hold a real measurement. Applying the
    # adjoint to an all-ones mosaic yields the sampling pattern (assumes CFA is
    # a per-pixel channel-selection operator). Using this instead of testing
    # `res == 0` keeps genuine zero-valued samples from being overwritten.
    known = op.adjoint(np.ones(y.shape)) != 0

    def fill_missing(channel):
        """Interpolate every still-unknown pixel of `channel` with interpolate_green.

        Pixels are written into `res` in place in row-major order, so later
        pixels may see earlier estimates through get_neighbors/get_weights.
        """
        for i in range(N):
            for j in range(M):
                if not known[i, j, channel]:
                    neighbors = get_neighbors(res, channel, i, j, N, M)
                    weights = get_weights(res, i, j, channel, N, M)
                    res[i, j, channel] = interpolate_green(weights, neighbors)
                    known[i, j, channel] = True

    def fill_from_green(channel, rows, cols):
        """Interpolate `channel` on the given sub-lattice, using green neighbours as a guide."""
        for i in rows:
            for j in cols:
                neighbors = get_neighbors(res, channel, i, j, N, M)
                neighbors_G = get_neighbors(res, 1, i, j, N, M)
                weights = get_weights(res, i, j, channel, N, M)
                res[i, j, channel] = interpolate_red_blue(weights, neighbors, neighbors_G)
                known[i, j, channel] = True

    # Green channel: fill every unmeasured pixel.
    fill_missing(1)

    # Red channel: first the (odd row, even col) sites, guided by green;
    # then whatever is still missing.
    fill_from_green(0, range(1, N, 2), range(0, M, 2))
    fill_missing(0)

    # Blue channel: first the (even row, odd col) sites, guided by green;
    # then whatever is still missing.
    fill_from_green(2, range(0, N, 2), range(1, M, 2))
    fill_missing(2)

    # NOTE: an optional iterative refinement using correction_green /
    # correction_red / correction_blue (from utils) was previously present here
    # as commented-out code; it is not applied.

    # Keep intensities in the valid [0, 1] range.
    res = np.clip(res, 0, 1)

    return res