import numpy as np
import pywt

def get_neighbors(img, channel, i, j, N, M):
    """Return the 3x3 neighbourhood of pixel (i, j) in one channel as a flat array.

    Row and column indices wrap around (taken modulo N and M), so pixels on
    the border borrow their missing neighbours from the opposite edge.

    The nine values are ordered row by row, starting at offset (-1, -1)::

        P1 P2 P3
        P4 P5 P6
        P7 P8 P9

    Args:
        img: array indexed as ``img[row, col, channel]``.
        channel: channel index to read.
        i, j: row and column of the centre pixel P5.
        N, M: row and column counts used for the wrap-around.

    Returns:
        1-D numpy array ``[P1, ..., P9]``.
    """
    values = []
    for di in (-1, 0, 1):
        row = (i + di) % N
        for dj in (-1, 0, 1):
            values.append(img[row, (j + dj) % M, channel])
    return np.array(values)


def get_derivatives(neighbors):
    """Return the four central-difference derivatives at the centre of a 3x3 patch.

    Args:
        neighbors: nine values ``[P1, ..., P9]`` laid out row by row::

            P1 P2 P3
            P4 P5 P6
            P7 P8 P9

            The centre value P5 is not used.

    Returns:
        List ``[D_x, D_y, D_xd, D_yd]`` where

        - ``D_x  = (P4 - P6) / 2``           horizontal
        - ``D_y  = (P2 - P8) / 2``           vertical
        - ``D_xd = (P3 - P7) / (2*sqrt(2))`` anti-diagonal (P3 to P7)
        - ``D_yd = (P1 - P9) / (2*sqrt(2))`` main diagonal (P1 to P9)

        The diagonal pairs are 2*sqrt(2) pixel units apart, hence the larger
        divisor.
    """
    # Work in float64: with unsigned integer images (e.g. uint8) a difference
    # such as P4 - P6 would otherwise wrap around instead of going negative.
    P1, P2, P3, P4, _, P6, P7, P8, P9 = np.asarray(neighbors, dtype=np.float64)

    diagonal_span = 2 * np.sqrt(2)
    D_x = (P4 - P6) / 2
    D_y = (P2 - P8) / 2
    D_xd = (P3 - P7) / diagonal_span
    D_yd = (P1 - P9) / diagonal_span

    return [D_x, D_y, D_xd, D_yd]


def get_weights(mosaic_image, i, j, channel, N, M):
    """Return edge-indicator weights for the 8 neighbours of pixel (i, j).

    For each neighbour P_k of the centre P5 (3x3 layout, row-major)::

        P1 P2 P3
        P4 P5 P6
        P7 P8 P9

    the weight is

        E_k = 1 / sqrt(1 + D(P5)**2 + D(P_k)**2)

    where D is the directional derivative along the line joining P5 and P_k,
    evaluated at the centre and at the neighbour. A strong intensity change
    along that line gives a small weight, so later interpolation de-emphasises
    neighbours lying across an edge. Opposite neighbours share a direction:

        P1, P9 -> main diagonal  (D_yd)
        P2, P8 -> vertical       (D_y)
        P3, P7 -> anti-diagonal  (D_xd)
        P4, P6 -> horizontal     (D_x)

    Args:
        mosaic_image: array indexed as ``[row, col, channel]``.
        i, j: centre pixel coordinates.
        channel: channel the derivatives are computed on.
        N, M: row/column counts for periodic wrap-around.

    Returns:
        List ``[E1, E2, E3, E4, E6, E7, E8, E9]`` (the centre has no weight).
    """
    # Derivatives ([D_x, D_y, D_xd, D_yd]) at every pixel of the 3x3 window,
    # row-major, so entry k-1 belongs to P_k and entry 4 to the centre.
    derivatives = [
        get_derivatives(get_neighbors(mosaic_image, channel, i + di, j + dj, N, M))
        for di in range(-1, 2)
        for dj in range(-1, 2)
    ]
    centre = derivatives[4]

    # (neighbour number k, index into [D_x, D_y, D_xd, D_yd]) for the
    # direction joining P5 to P_k, in output order.
    directions = (
        (1, 3),  # top-left     -> main diagonal
        (2, 1),  # top          -> vertical
        (3, 2),  # top-right    -> anti-diagonal
        (4, 0),  # left         -> horizontal
        (6, 0),  # right        -> horizontal
        (7, 2),  # bottom-left  -> anti-diagonal
        (8, 1),  # bottom       -> vertical
        (9, 3),  # bottom-right -> main diagonal
    )

    return [
        1 / np.sqrt(1 + centre[d] ** 2 + derivatives[k - 1][d] ** 2)
        for k, d in directions
    ]


def interpolate_green(weights, neighbors):
    """Estimate the centre value as a weighted mean of its four edge-adjacent neighbours.

    Only the up, left, right and down neighbours (P2, P4, P6, P8 in the
    row-major 3x3 layout ``P1..P9``) and their matching weights are used.

    Args:
        weights: ``[E1, E2, E3, E4, E6, E7, E8, E9]`` as returned by
            ``get_weights``.
        neighbors: ``[P1, ..., P9]`` as returned by ``get_neighbors``.

    Returns:
        ``(E2*P2 + E4*P4 + E6*P6 + E8*P8) / (E2 + E4 + E6 + E8)``.
    """
    _, w_up, _, w_left, w_right, _, w_down, _ = weights
    _, up, _, left, _, right, _, down, _ = neighbors

    weighted_total = w_up * up + w_left * left + w_right * right + w_down * down
    weight_total = w_up + w_left + w_right + w_down
    return weighted_total / weight_total


def interpolate_red_blue(weights, neighbors, green_neighbors):
    """Estimate a red/blue centre value from diagonal colour-to-green ratios.

    The centre green value is scaled by the weighted mean of ``P_k / G_k``
    over the four diagonal neighbours (P1, P3, P7, P9 in the row-major 3x3
    layout ``P1..P9``).

    Args:
        weights: ``[E1, E2, E3, E4, E6, E7, E8, E9]`` as returned by
            ``get_weights``.
        neighbors: ``[P1, ..., P9]`` of the channel being interpolated.
        green_neighbors: ``[G1, ..., G9]`` of the green channel.

    Returns:
        ``G5 * (E1*P1/G1 + E3*P3/G3 + E7*P7/G7 + E9*P9/G9) / (E1 + E3 + E7 + E9)``.
    """
    w_tl, _, w_tr, _, _, w_bl, _, w_br = weights
    c_tl, _, c_tr, _, _, _, c_bl, _, c_br = neighbors
    g_tl, _, g_tr, _, g_mid, _, g_bl, _, g_br = green_neighbors

    ratio_total = (w_tl * c_tl / g_tl + w_tr * c_tr / g_tr
                   + w_bl * c_bl / g_bl + w_br * c_br / g_br)
    weight_total = w_tl + w_tr + w_bl + w_br
    return g_mid * ratio_total / weight_total


def correction_green(res, i, j, N, M):
    """Re-estimate the green value at (i, j) from colour ratios of its neighbours.

    Two estimates use the four edge-adjacent neighbours (P2, P4, P6, P8 of
    the row-major 3x3 layout ``P1..P9``). Each estimate scales the centre's
    own red or blue value by a weighted mean of the matching green ratio:

        Gr5 = R5 * sum(E_k * G_k / R_k) / sum(E_k)
        Gb5 = B5 * sum(E_k * G_k / B_k) / sum(E_k)

    The result is the mean of the two estimates. Weights come from
    ``get_weights`` on channel 1.

    Args:
        res: array indexed as ``res[row, col, channel]``. Channels 0, 1, 2 are
            read as red, green and blue respectively.
        i, j: pixel coordinates.
        N, M: row/column counts for periodic wrap-around.

    Returns:
        Corrected green value at (i, j).
    """
    _, G2, _, G4, _, G6, _, G8, _ = get_neighbors(res, 1, i, j, N, M)
    _, R2, _, R4, R5, R6, _, R8, _ = get_neighbors(res, 0, i, j, N, M)
    _, B2, _, B4, B5, B6, _, B8, _ = get_neighbors(res, 2, i, j, N, M)
    _, E2, _, E4, E6, _, E8, _ = get_weights(res, i, j, 1, N, M)

    weight_sum = E2 + E4 + E6 + E8

    # Ratio rule: G/B (resp. G/R) is treated as locally constant, so the
    # neighbours' ratios are rescaled by the centre's *own* blue (resp. red)
    # value to get a green estimate.
    Gb5 = B5 * ((E2*G2)/B2 + (E4*G4)/B4 + (E6*G6)/B6 + (E8*G8)/B8) / weight_sum
    Gr5 = R5 * ((E2*G2)/R2 + (E4*G4)/R4 + (E6*G6)/R6 + (E8*G8)/R8) / weight_sum

    return (Gb5 + Gr5) / 2

def correction_red(res, i, j, N, M):
    """Re-estimate the red value at (i, j) from red-to-green ratios of all 8 neighbours.

    The centre's green value is scaled by the weighted mean of ``R_k / G_k``
    over the eight neighbours (k = 1..9, k != 5, row-major 3x3 layout):

        R5 = G5 * sum(E_k * R_k / G_k) / sum(E_k)

    Weights come from ``get_weights`` evaluated on channel 0.

    Args:
        res: array indexed as ``res[row, col, channel]``. Channel 0 is read as
            red and channel 1 as green.
        i, j: pixel coordinates.
        N, M: row/column counts for periodic wrap-around.

    Returns:
        Corrected red value at (i, j).
    """
    g1, g2, g3, g4, g_centre, g6, g7, g8, g9 = get_neighbors(res, 1, i, j, N, M)
    r1, r2, r3, r4, _, r6, r7, r8, r9 = get_neighbors(res, 0, i, j, N, M)
    e1, e2, e3, e4, e6, e7, e8, e9 = get_weights(res, i, j, 0, N, M)

    weighted_ratios = (
        (e1 * r1) / g1 + (e2 * r2) / g2 + (e3 * r3) / g3 + (e4 * r4) / g4
        + (e6 * r6) / g6 + (e7 * r7) / g7 + (e8 * r8) / g8 + (e9 * r9) / g9
    )
    total_weight = e1 + e2 + e3 + e4 + e6 + e7 + e8 + e9
    return g_centre * weighted_ratios / total_weight

def correction_blue(res, i, j, N, M):
    """Re-estimate the blue value at (i, j) from blue-to-green ratios of all 8 neighbours.

    The centre's green value is scaled by the weighted mean of ``B_k / G_k``
    over the eight neighbours (k = 1..9, k != 5, row-major 3x3 layout):

        B5 = G5 * sum(E_k * B_k / G_k) / sum(E_k)

    Weights come from ``get_weights`` evaluated on channel 2.

    Args:
        res: array indexed as ``res[row, col, channel]``. Channel 2 is read as
            blue and channel 1 as green.
        i, j: pixel coordinates.
        N, M: row/column counts for periodic wrap-around.

    Returns:
        Corrected blue value at (i, j).
    """
    g1, g2, g3, g4, g_centre, g6, g7, g8, g9 = get_neighbors(res, 1, i, j, N, M)
    b1, b2, b3, b4, _, b6, b7, b8, b9 = get_neighbors(res, 2, i, j, N, M)
    e1, e2, e3, e4, e6, e7, e8, e9 = get_weights(res, i, j, 2, N, M)

    weighted_ratios = (
        (e1 * b1) / g1 + (e2 * b2) / g2 + (e3 * b3) / g3 + (e4 * b4) / g4
        + (e6 * b6) / g6 + (e7 * b7) / g7 + (e8 * b8) / g8 + (e9 * b9) / g9
    )
    total_weight = e1 + e2 + e3 + e4 + e6 + e7 + e8 + e9
    return g_centre * weighted_ratios / total_weight