import numpy as np

def bayer_blue_red(res, y, i, j):
    """
    Estimate the opposite chroma value (blue at a red site, red at a blue site)

    The colour difference between the green plane and the measurement is
    averaged over the four diagonal neighbours of (i, j). That average is then
    subtracted from the green estimate at (i, j).

    Args : 
        res     : estimated image; its green channel res[..., 1] is read
        y       : image to reconstruct
        i, j    : indices
    Return : 
        value   : value of the estimated pixel
    """
    diagonal_offsets = ((-1, -1), (-1, 1), (1, -1), (1, 1))
    d_nw, d_ne, d_sw, d_se = (
        res[i + di, j + dj, 1] - y[i + di, j + dj]
        for di, dj in diagonal_offsets
    )
    return res[i, j, 1] - 0.25 * (d_nw + d_ne + d_sw + d_se)

def bayer_green_vert(res, y, i, j):
    """
    Estimate a missing red/blue value at a green site from the pixels above and below

    The green-minus-measurement difference is averaged over the vertical
    neighbours (i-1, j) and (i+1, j). The result is subtracted from the
    sample y[i, j].

    Args : 
        res     : estimated image; its green channel res[..., 1] is read
        y       : image to reconstruct
        i, j    : indices
    Return : 
        value   : value of the estimated pixel
    """
    diff_above = res[i - 1, j, 1] - y[i - 1, j]
    diff_below = res[i + 1, j, 1] - y[i + 1, j]
    return y[i, j] - 0.5 * (diff_above + diff_below)

def bayer_green_hor(res, y, i, j):
    """
    Estimate a missing red/blue value at a green site from its left and right neighbours

    The green-minus-measurement difference is averaged over the horizontal
    neighbours (i, j-1) and (i, j+1). The result is subtracted from the
    sample y[i, j].

    Args : 
        res     : estimated image; its green channel res[..., 1] is read
        y       : image to reconstruct
        i, j    : indices
    Return : 
        value   : value of the estimated pixel
    """
    diff_left = res[i, j - 1, 1] - y[i, j - 1]
    diff_right = res[i, j + 1, 1] - y[i, j + 1]
    return y[i, j] - 0.5 * (diff_left + diff_right)

def interpolate_green(res, y, z):
    """
    Directional interpolation of the green channel

    Every pixel that has a neighbour on all four sides (rows and columns
    1 .. N-2) receives a green value in res[..., 1]:

    - at sites where z[..., 1] != 0, the measured y value is copied;
    - elsewhere, the two neighbours along the direction with the smaller
      absolute difference are averaged;
    - on a tie, all four neighbours are averaged.

    Pixels on the outermost rows and columns are not written.

    Args : 
        res     : estimated image, written in place (channel 1 only)
        y       : image to reconstruct (2-D CFA measurements)
        z       : bayer pattern; z[..., 1] != 0 marks green sites
    Return : 
        res     : reconstructed image (the same object passed as ``res``)
    """
    # Work in float so differences/sums cannot wrap around for integer images.
    yf = np.asarray(y, dtype=float)
    center = yf[1:-1, 1:-1]
    left, right = yf[1:-1, :-2], yf[1:-1, 2:]
    up, down = yf[:-2, 1:-1], yf[2:, 1:-1]

    # Vertical and horizontal gradient
    d_h = np.abs(left - right)
    d_v = np.abs(up - down)
    # Interpolate along the direction of least variation to avoid averaging
    # across an edge; fall back to the 4-neighbour mean on a tie.
    estimate = np.where(
        d_h > d_v,
        (up + down) / 2,
        np.where(d_v > d_h, (left + right) / 2, (left + right + up + down) / 4),
    )
    is_green = z[1:-1, 1:-1, 1] != 0
    res[1:-1, 1:-1, 1] = np.where(is_green, center, estimate)
    return res

def quad_to_bayer(y):
    """
    Convert Quad Bayer to Bayer

    Inside every 4x4 tile the function applies three swaps, in this order:

    1. columns 1 and 2;
    2. rows 1 and 2;
    3. the pixels at tile offsets (1, 1) and (2, 2).

    Rows and columns use their own bounds, so non-square arrays are handled.
    Swaps whose partner would fall outside the array (a dimension that is not
    a multiple of 4) are skipped.

    Args : 
        y       : array whose first two axes are spatial; any trailing axes
                  (e.g. colour channels of a mask) are carried along.
                  It is not modified.
    Return : 
        np.ndarray: rearranged copy of ``y``
    """
    out = np.copy(y)
    # First index of each swapped pair (1, 5, 9, ...). Stopping at size - 1
    # guarantees the partner index + 1 stays in bounds.
    rows = np.arange(1, out.shape[0] - 1, 4)
    cols = np.arange(1, out.shape[1] - 1, 4)

    # Advanced indexing on the right-hand side returns copies, and the whole
    # right-hand tuple is evaluated before assignment, so each line is a swap.
    out[:, cols], out[:, cols + 1] = out[:, cols + 1], out[:, cols]
    out[rows], out[rows + 1] = out[rows + 1], out[rows]
    first, second = np.ix_(rows, cols), np.ix_(rows + 1, cols + 1)
    out[first], out[second] = out[second], out[first]
    return out

def interpolation(y, op):
    """
    Reconstruct image

    The green channel is interpolated first by ``interpolate_green``. Red and
    blue are then estimated from colour differences with that green plane,
    using ``bayer_green_hor``, ``bayer_green_vert`` and ``bayer_blue_red``.

    Args : 
        y       : image to reconstruct (2-D CFA measurements); not modified
        op      : CFA operator. ``op.cfa``, ``op.adjoint`` and
                  ``op.input_shape`` are used, plus ``op.mask`` when
                  ``op.cfa == 'quad_bayer'``. The mask is swapped only
                  temporarily and restored before returning.
    Return : 
        np.ndarray: Demosaicked image of shape ``op.input_shape``.
    """
    is_quad = op.cfa == 'quad_bayer'
    if is_quad:
        # Rearrange Quad Bayer data into a Bayer layout. Work on copies so the
        # caller's image is untouched and a second call with the same ``op``
        # does not re-permute an already converted mask.
        y = quad_to_bayer(np.copy(y))
        quad_mask = op.mask
        op.mask = quad_to_bayer(np.copy(quad_mask))
    try:
        z = op.adjoint(y)
        # Classify pixels from the sampling pattern, not from ``z != 0``, so a
        # measured value of exactly 0 is still recognised as a sample.
        # NOTE(review): assumes op.adjoint is linear and writes each sample
        # into the channel(s) selected by the CFA mask, i.e. adjoint(ones)
        # reproduces the mask — confirm against the operator implementation.
        pattern = op.adjoint(np.ones(np.shape(y)))
    finally:
        if is_quad:
            op.mask = quad_mask

    # Seed with the measured samples: border pixels are never visited by the
    # loops below and np.empty would leave them uninitialised.
    # NOTE(review): assumes op.adjoint returns an array of op.input_shape.
    res = np.zeros(op.input_shape)
    res[...] = z

    # Interpolation of green channel
    res = interpolate_green(res, y, pattern)
    # Interpolation of R and B channels using channel correlation
    for i in range(2, y.shape[0] - 2):
        for j in range(2, y.shape[1] - 2):
            if pattern[i, j, 1] != 0:
                # Green site. If the pixel below is not red (in a Bayer layout
                # it is then blue), blue lies vertically and red horizontally;
                # otherwise the roles are reversed.
                if pattern[i + 1, j, 0] == 0:
                    red = bayer_green_hor(res, y, i, j)
                    blue = bayer_green_vert(res, y, i, j)
                else:
                    blue = bayer_green_hor(res, y, i, j)
                    red = bayer_green_vert(res, y, i, j)
            elif pattern[i, j, 0] != 0:
                # Red site: red is measured, blue comes from the diagonals.
                red = y[i, j]
                blue = bayer_blue_red(res, y, i, j)
            elif pattern[i, j, 2] != 0:
                # Blue site: blue is measured, red comes from the diagonals.
                blue = y[i, j]
                red = bayer_blue_red(res, y, i, j)
            else:
                # No CFA sample here; never reuse red/blue from a previous pixel.
                continue

            # NOTE(review): clipping assumes an 8-bit [0, 255] range; confirm
            # it matches the scale of ``y``.
            res[i, j, 0] = np.clip(red, 0, 255)
            res[i, j, 2] = np.clip(blue, 0, 255)
    return res