"""The main file for the reconstruction.
This file should NOT be modified except the body of the 'run_reconstruction' function.
Students can call their own functions (declared in other files of src/methods/your_name).
"""

from src.forward_model import CFA
import numpy as np
from scipy.signal import convolve2d
import cv2 

def is_green(z, i, j):
    """Tell whether pixel (i, j) of ``z`` holds a non-zero green sample.

    Args:
        z: Three-channel array indexed as ``z[row, col, channel]``.
        i: Row index.
        j: Column index.

    Returns:
        Truth value of ``z[i, j, 1] != 0``.
    """
    green_sample = z[i, j, 1]
    return green_sample != 0

def hamilton_adams_interpolation(y, op, z):
    """Estimate the full green channel of a Bayer mosaic (Hamilton-Adams).

    At every site of ``z`` without a green sample, green is interpolated
    along the direction (horizontal or vertical) with the smaller gradient.
    The average is corrected by the second-order Laplacian of the red/blue
    sample at that site, computed from the same-colour samples two pixels
    away. When both gradients are equal, the two directional estimates are
    blended. Sites that already carry a green sample keep their value.

    Args:
        y: Mosaicked image. Unused; kept for interface compatibility.
        op: CFA operator. Unused; kept for interface compatibility.
        z: Mosaic spread over three channels, shape (H, W, 3). Presumably
            zero where a channel was not sampled (output of ``op.adjoint``).

    Returns:
        np.ndarray: Float array of shape (H, W) with the interpolated green.
    """
    height, width = z.shape[:2]
    # Work in float so integer inputs neither overflow nor get truncated.
    green = np.asarray(z[:, :, 1], dtype=float)
    # Sites with a non-zero green value are treated as green sites.
    # NOTE(review): a genuinely black green sample is misclassified here;
    # op.mask would be more reliable once its layout is confirmed.
    green_mask = green != 0
    # On a Bayer mosaic each non-green site holds exactly one of R/B, and
    # the sites two pixels away along its row/column hold the same colour,
    # so R + B acts as a single plane for the Laplacian terms.
    chroma = np.asarray(z[:, :, 0], dtype=float) + np.asarray(z[:, :, 2], dtype=float)

    # Reflect padding (no edge duplication) keeps the 2x2 CFA parity, so
    # border pixels can be processed exactly like interior ones.
    pad = 2
    g = np.pad(green, pad, mode="reflect")
    c = np.pad(chroma, pad, mode="reflect")

    def shifted(a, di, dj):
        # Window of the padded array aligned so that [i, j] reads a[i+di, j+dj].
        return a[pad + di:pad + di + height, pad + dj:pad + dj + width]

    g_west, g_east = shifted(g, 0, -1), shifted(g, 0, 1)
    g_north, g_south = shifted(g, -1, 0), shifted(g, 1, 0)
    c_center = shifted(c, 0, 0)
    lap_h = 2 * c_center - shifted(c, 0, -2) - shifted(c, 0, 2)
    lap_v = 2 * c_center - shifted(c, -2, 0) - shifted(c, 2, 0)

    delta_h = np.abs(g_west - g_east) + np.abs(lap_h)
    delta_v = np.abs(g_north - g_south) + np.abs(lap_v)

    estimate_h = (g_west + g_east) / 2 + lap_h / 4
    estimate_v = (g_north + g_south) / 2 + lap_v / 4
    estimate_avg = (g_west + g_east + g_north + g_south) / 4 + (lap_h + lap_v) / 8

    # Interpolate along the smoother direction; blend on ties.
    estimate = np.where(
        delta_h < delta_v,
        estimate_h,
        np.where(delta_h > delta_v, estimate_v, estimate_avg),
    )
    return np.where(green_mask, green, estimate)

def interpolate_channel_difference(mosaicked_channel, green_channel_interpolated):
    """Reconstruct a red or blue channel from its mosaic and the full green.

    The colour difference (channel - green) is smoothed with a 3x3 box
    filter and added back to the green channel. The result is then filtered
    with the bilinear Bayer red/blue kernel.

    Args:
        mosaicked_channel: (H, W) red or blue plane, presumably zero where
            the channel was not sampled.
        green_channel_interpolated: (H, W) fully interpolated green plane.

    Returns:
        np.ndarray: (H, W) float array with the interpolated channel.
    """
    ker_box = np.ones((3, 3)) / 9
    ker_bayer_red_blue = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 4

    def filter_reflect(plane, kernel):
        # Reflect padding (no edge duplication) preserves the 2x2 CFA parity
        # at the borders. Zero fill would darken the border (the red/blue
        # kernel sums to 4 and relies on the full pattern), and wrap pulls in
        # the opposite edge and breaks parity for odd sizes. Interior pixels
        # are unaffected by this choice.
        padded = np.pad(plane, 1, mode="reflect")
        return convolve2d(padded, kernel, mode="valid")

    mosaic = np.asarray(mosaicked_channel, dtype=float)
    green = np.asarray(green_channel_interpolated, dtype=float)
    difference = mosaic - green
    difference_interpolated = filter_reflect(difference, ker_box)
    channel_interpolated = green + difference_interpolated
    return filter_reflect(channel_interpolated, ker_bayer_red_blue)

def _reconstruct_bayer(op, y, z):
    """Demosaick a Bayer-patterned ``z``: green first, then red and blue."""
    green = hamilton_adams_interpolation(y, op, z)
    red = interpolate_channel_difference(z[:, :, 0], green)
    blue = interpolate_channel_difference(z[:, :, 2], green)
    return np.stack((red, green, blue), axis=-1)


def Constant_difference_based_interpolation_reconstruction(op, y, z):
    """Demosaick ``z`` with Hamilton-Adams green and colour-difference R/B.

    Args:
        op: CFA operator; ``op.cfa`` selects the pipeline (``"bayer"`` or
            ``"quad_bayer"``). The operator is not modified.
        y: Mosaicked image, shape (H, W).
        z: Mosaic spread over three channels, shape (H, W, 3).

    Returns:
        np.ndarray: Reconstructed image of shape (H, W, 3).

    Raises:
        ValueError: If ``op.cfa`` is neither ``"bayer"`` nor ``"quad_bayer"``.
    """
    if op.cfa == 'bayer':
        return _reconstruct_bayer(op, y, z)

    if op.cfa == "quad_bayer":
        import copy

        height, width = z.shape[0], z.shape[1]
        # For even H, W, INTER_AREA at exactly half size averages each 2x2
        # block. On a quad-Bayer mosaic (2x2 same-colour blocks) this
        # presumably yields a half-resolution Bayer mosaic.
        small_z = cv2.resize(z, (width // 2, height // 2), interpolation=cv2.INTER_AREA)
        small_y = np.sum(small_z, axis=2)
        # Use a shallow copy so the caller's operator keeps its full-size mask.
        small_op = copy.copy(op)
        small_op.mask = op.mask[::2, ::2]
        small_image = _reconstruct_bayer(small_op, small_y, small_z)
        return cv2.resize(small_image, (width, height), interpolation=cv2.INTER_LINEAR)

    raise ValueError("CFA pattern not recognized")


def run_reconstruction(y: np.ndarray, cfa: str) -> np.ndarray:
    """Demosaick the mosaicked image ``y``.

    Args:
        y (np.ndarray): Mosaicked image to reconstruct.
        cfa (str): CFA name, either bayer or quad_bayer.

    Returns:
        np.ndarray: Demosaicked image.
    """
    rows, cols = y.shape[0], y.shape[1]
    operator = CFA(cfa, (rows, cols, 3))
    # Spread the single-plane mosaic over three channels before reconstructing.
    return Constant_difference_based_interpolation_reconstruction(
        operator, y, operator.adjoint(y)
    )


####
####
####

####      ####                ####        #############
####      ######              ####      ##################
####      ########            ####      ####################
####      ##########          ####      ####        ########
####      ############        ####      ####            ####
####      ####  ########      ####      ####            ####
####      ####    ########    ####      ####            ####
####      ####      ########  ####      ####            ####
####      ####  ##    ######  ####      ####          ######
####      ####  ####      ##  ####      ####    ############
####      ####  ######        ####      ####    ##########
####      ####  ##########    ####      ####    ########
####      ####      ########  ####      ####
####      ####        ############      ####
####      ####          ##########      ####
####      ####            ########      ####
####      ####              ######      ####

# 2023
# Authors: Mauro Dalla Mura and Matthieu Muller