"""
DDFAPD - Menon (2007) Bayer CFA Demosaicing
===========================================

*Bayer* CFA (Colour Filter Array) DDFAPD - *Menon (2007)* demosaicing.

References
----------
-   :cite:`Menon2007c` : Menon, D., Andriani, S., & Calvagno, G. (2007).
    Demosaicing With Directional Filtering and a posteriori Decision. IEEE
    Transactions on Image Processing, 16(1), 132-141.
    doi:10.1109/TIP.2006.884928
"""

import numpy as np
from colour.hints import ArrayLike, Literal, NDArrayFloat
from colour.utilities import as_float_array, ones, tsplit, tstack
from scipy.ndimage.filters import convolve, convolve1d
from src.forward_model import CFA

def tensor_mask_to_RGB_mask(mask: ArrayLike, pixelPattern: str = "RGB"):
    """
    Split a stacked CFA mask into its red, green and blue planes.

    Parameters
    ----------
    mask
        Mask array indexed as ``mask[:, :, i]``; channel ``i`` holds the plane
        of the colour ``pixelPattern[i]``.
    pixelPattern
        Colour letter of each channel of ``mask``, e.g. ``"RGB"`` or ``"BGR"``.
        Letters other than ``"R"``, ``"G"`` and ``"B"`` are ignored; if a
        letter is repeated, its last occurrence wins.

    Returns
    -------
    tuple
        ``(R_m, G_m, B_m)`` mask planes.

    Raises
    ------
    ValueError
        If ``pixelPattern`` does not name a channel for each of ``"R"``,
        ``"G"`` and ``"B"``.
    """
    planes = {}
    for i, letter in enumerate(pixelPattern):
        if letter in ("R", "G", "B"):
            planes[letter] = mask[:, :, i]

    # Without this check a missing colour surfaced as an opaque
    # UnboundLocalError at the return statement.
    missing = [colour for colour in ("R", "G", "B") if colour not in planes]
    if missing:
        raise ValueError(
            f"pixelPattern {pixelPattern!r} has no channel for: "
            f"{', '.join(missing)}"
        )

    return planes["R"], planes["G"], planes["B"]

def _cnv_h(x: ArrayLike, y: ArrayLike) -> NDArrayFloat:
    """
    Convolve ``x`` with the 1-D kernel ``y`` along its last axis.

    For a 2-D image this filters every row independently, i.e. in the
    horizontal direction. Borders are extended with ``mode="mirror"``.
    """
    last_axis = -1
    return convolve1d(x, y, axis=last_axis, mode="mirror")


def _cnv_v(x: ArrayLike, y: ArrayLike) -> NDArrayFloat:
    """
    Convolve ``x`` with the 1-D kernel ``y`` along axis 0.

    For a 2-D image this filters every column independently, i.e. in the
    vertical direction. Borders are extended with ``mode="mirror"``.
    """
    return convolve1d(x, weights=y, axis=0, mode="mirror")


def demosaicing_CFA_Bayer_Menon2007(
    rawImage: ArrayLike,
    mask: ArrayLike,
    pixelPattern: str = "RGB",
    refining_step: bool = True,
) -> NDArrayFloat:
    """
    Return the demosaiced *RGB* colourspace array from given *Bayer* CFA using
    DDFAPD - *Menon (2007)* demosaicing algorithm.

    Parameters
    ----------
    rawImage
        2-D *Bayer* CFA image (mosaicked raw data). It is converted to floating
        point before filtering, so integer raw data is handled correctly.
    mask
        Stacked CFA sampling masks indexed as ``mask[:, :, i]``: channel ``i``
        is 1 where the colour ``pixelPattern[i]`` was sampled and 0 elsewhere.
    pixelPattern
        Colour letter of each channel of ``mask``, e.g. ``"RGB"``.
    refining_step
        Perform refining step.

    Returns
    -------
    :class:`numpy.ndarray`
        *RGB* colourspace array.

    Notes
    -----
    -   The definition output is not clipped in range [0, 1] : this allows for
        direct HDRI image generation on *Bayer* CFA data and post
        demosaicing of the high dynamic range data as showcased in this
        `Jupyter Notebook <https://github.com/colour-science/colour-hdri/\
blob/develop/colour_hdri/examples/\
examples_merge_from_raw_files_with_post_demosaicing.ipynb>`__.

    References
    ----------
    :cite:`Menon2007c`
    """

    # scipy.ndimage.convolve1d allocates its output with the *input* dtype:
    # integer raw data (e.g. uint16) would be truncated and, because h_1 has
    # negative taps, could wrap around. Work in floating point throughout.
    rawImage = as_float_array(rawImage)
    mask = as_float_array(mask)

    R_m, G_m, B_m = tensor_mask_to_RGB_mask(mask, pixelPattern)

    # Known samples of each colour; positions where the mask is 0 are unknown
    # and start at 0.
    R = rawImage * R_m
    G = rawImage * G_m
    B = rawImage * B_m

    # Green interpolation filters: h_0 averages the two neighbours one pixel
    # away on each side; h_1 is a second-order correction built from the
    # samples two pixels away and the centre sample.
    h_0 = as_float_array([0.0, 0.5, 0.0, 0.5, 0.0])
    h_1 = as_float_array([-0.25, 0.0, 0.5, 0.0, -0.25])

    # At non-green pixels, estimate green along the rows (G_H) and along the
    # columns (G_V); known green samples are kept unchanged.
    G_H = np.where(G_m == 0, _cnv_h(rawImage, h_0) + _cnv_h(rawImage, h_1), G)
    G_V = np.where(G_m == 0, _cnv_v(rawImage, h_0) + _cnv_v(rawImage, h_1), G)

    # Chrominance differences (R - G or B - G) at red/blue pixels, 0 at green
    # pixels, for each directional green estimate.
    C_H = np.where(R_m == 1, R - G_H, 0)
    C_H = np.where(B_m == 1, B - G_H, C_H)

    C_V = np.where(R_m == 1, R - G_V, 0)
    C_V = np.where(B_m == 1, B - G_V, C_V)

    # Absolute difference between each chrominance value and the one two
    # pixels to the right (D_H) / two pixels below (D_V). The arrays are
    # extended by two columns/rows with mode="reflect" before shifting, so
    # the last two columns/rows compare against reflected values.
    D_H = np.abs(C_H - np.pad(C_H, ((0, 0), (0, 2)), mode="reflect")[:, 2:])
    D_V = np.abs(C_V - np.pad(C_V, ((0, 2), (0, 0)), mode="reflect")[2:, :])

    del h_0, h_1, C_V, C_H

    # Classifier window: accumulates D_H over a 5x5 neighbourhood; its
    # transpose is used for D_V so both directions are weighted alike.
    k = as_float_array(
        [
            [0.0, 0.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 3.0, 0.0, 3.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 1.0],
        ]
    )

    d_H = convolve(D_H, k, mode="constant")
    d_V = convolve(D_V, np.transpose(k), mode="constant")

    del D_H, D_V

    # Keep the horizontal green estimate wherever the accumulated vertical
    # variation is at least as large as the horizontal one.
    use_horizontal = d_V >= d_H
    G = np.where(use_horizontal, G_H, G_V)
    # M records the selected direction (1: horizontal, 0: vertical); it also
    # steers the red/blue interpolation and the refining step.
    M = np.where(use_horizontal, 1, 0)

    del d_H, d_V, G_H, G_V, use_horizontal

    # Rows containing at least one red (R_r) / blue (B_r) sample, broadcast
    # to the full image shape.
    R_r = np.transpose(np.any(R_m == 1, axis=1)[None]) * ones(R.shape)
    B_r = np.transpose(np.any(B_m == 1, axis=1)[None]) * ones(B.shape)

    k_b = as_float_array([0.5, 0, 0.5])

    # Red at green pixels: interpolate the colour difference R - G from the
    # two red neighbours and add the green value back. In red rows these
    # neighbours are left/right, in blue rows they are above/below.
    R = np.where(
        np.logical_and(G_m == 1, R_r == 1),
        G + _cnv_h(R, k_b) - _cnv_h(G, k_b),
        R,
    )
    R = np.where(
        np.logical_and(G_m == 1, B_r == 1),
        G + _cnv_v(R, k_b) - _cnv_v(G, k_b),
        R,
    )

    # Blue at green pixels: same scheme with the row roles swapped.
    B = np.where(
        np.logical_and(G_m == 1, B_r == 1),
        G + _cnv_h(B, k_b) - _cnv_h(G, k_b),
        B,
    )
    B = np.where(
        np.logical_and(G_m == 1, R_r == 1),
        G + _cnv_v(B, k_b) - _cnv_v(G, k_b),
        B,
    )

    # Red at blue pixels, then blue at red pixels, interpolated along the
    # direction selected for green (M). The blue update deliberately uses the
    # red plane completed just above.
    R = np.where(
        np.logical_and(B_r == 1, B_m == 1),
        np.where(
            M == 1,
            B + _cnv_h(R, k_b) - _cnv_h(B, k_b),
            B + _cnv_v(R, k_b) - _cnv_v(B, k_b),
        ),
        R,
    )
    B = np.where(
        np.logical_and(R_r == 1, R_m == 1),
        np.where(
            M == 1,
            R + _cnv_h(B, k_b) - _cnv_h(R, k_b),
            R + _cnv_v(B, k_b) - _cnv_v(R, k_b),
        ),
        B,
    )

    # Stack the channels along the last axis.
    RGB = tstack([R, G, B])

    del R, G, B, k_b, R_r, B_r

    if refining_step:
        RGB = refining_step_Menon2007(RGB, tstack([R_m, G_m, B_m]), M)

    del M, R_m, G_m, B_m

    return RGB


def refining_step_Menon2007(
    RGB: ArrayLike, RGB_m: ArrayLike, M: ArrayLike
) -> NDArrayFloat:
    """
    Perform the refining step on given *RGB* colourspace array.

    Parameters
    ----------
    RGB
        *RGB* colourspace array.
    RGB_m
        *Bayer* CFA red, green and blue masks, stacked along the last axis.
    M
        Estimation for the best directional reconstruction: where ``M == 1``
        the horizontal direction is used, elsewhere the vertical one.

    Returns
    -------
    :class:`numpy.ndarray`
        Refined *RGB* colourspace array.
    """
    R, G, B = tsplit(RGB)
    R_m, G_m, B_m = tsplit(RGB_m)
    M = as_float_array(M)

    del RGB, RGB_m

    at_R = R_m == 1
    at_G = G_m == 1
    at_B = B_m == 1

    # 3-tap moving average, applied along the direction selected by M.
    FIR = ones(3) / 3
    horizontal = M == 1

    def _smooth_along_M(x):
        """Average ``x`` with FIR along rows where M == 1, else along columns."""
        return np.where(horizontal, _cnv_h(x, FIR), _cnv_v(x, FIR))

    # Step 1 - green at red/blue sites. Both colour differences are taken
    # from the incoming G, before either set of sites is rewritten.
    R_minus_G = R - G
    B_minus_G = B - G
    G = np.where(at_R, R - _smooth_along_M(R_minus_G), G)
    G = np.where(at_B, B - _smooth_along_M(B_minus_G), G)

    del R_minus_G, B_minus_G

    # Step 2 - red and blue at green sites. Row/column indicators mark the
    # rows (columns) that hold at least one red or blue sample.
    red_rows = np.any(at_R, axis=1)[:, None]
    red_cols = np.any(at_R, axis=0)[None, :]
    blue_rows = np.any(at_B, axis=1)[:, None]
    blue_cols = np.any(at_B, axis=0)[None, :]

    # Both differences use the refined G and the R/B from before this step.
    R_minus_G = R - G
    B_minus_G = B - G
    k_b = as_float_array([0.5, 0.0, 0.5])

    # Red: vertical neighbours for greens in blue rows, horizontal
    # neighbours for greens in blue columns.
    R = np.where(np.logical_and(at_G, blue_rows), G + _cnv_v(R_minus_G, k_b), R)
    R = np.where(np.logical_and(at_G, blue_cols), G + _cnv_h(R_minus_G, k_b), R)

    # Blue: vertical neighbours for greens in red rows, horizontal
    # neighbours for greens in red columns.
    B = np.where(np.logical_and(at_G, red_rows), G + _cnv_v(B_minus_G, k_b), B)
    B = np.where(np.logical_and(at_G, red_cols), G + _cnv_h(B_minus_G, k_b), B)

    del red_rows, red_cols, blue_rows, blue_cols, R_minus_G, B_minus_G, at_G

    # Step 3 - red at blue sites, then blue at red sites, from one shared
    # R - B difference. The blue update reads the R produced by the red
    # update (unchanged at red sites).
    R_minus_B = R - B
    R = np.where(at_B, B + _smooth_along_M(R_minus_B), R)
    B = np.where(at_R, R - _smooth_along_M(R_minus_B), B)

    return tstack([R, G, B])