import numpy as np

def find_Knearest_neighbors(z, chan, i, j, N, M):
    """Return the 3x3 neighbourhood of pixel (i, j) on channel ``chan``.

    Despite the name, the neighbourhood size is fixed at 3x3 (9 values).
    Indices wrap around modulo the image size, so border pixels take their
    neighbours from the opposite edge.

    Args:
        z: array indexed as ``z[row, col, chan]``.
        chan: channel index.
        i, j: row and column of the centre pixel.
        N, M: number of rows and columns used for the wrap-around.

    Returns:
        1-D numpy array ``[P1, ..., P9]`` in row-major order: rows i-1, i,
        i+1 and, within each row, columns j-1, j, j+1.  ``P5`` is the centre.
    """
    rows = [(i + dr) % N for dr in (-1, 0, 1)]
    cols = [(j + dc) % M for dc in (-1, 0, 1)]
    values = []
    for r in rows:
        for c in cols:
            values.append(z[r, c, chan])
    return np.array(values)

def calculate_directional_gradients(neighbors):
    """Return the four central-difference directional derivatives at P5.

    ``neighbors`` holds the 9 values ``P1..P9`` of a 3x3 window in
    row-major order (P5 is the centre):

        P1 P2 P3
        P4 P5 P6
        P7 P8 P9

    Returns:
        ``[Dx, Dy, Dxd, Dyd]`` where
        ``Dx  = (P4 - P6) / 2``           (along the row),
        ``Dy  = (P2 - P8) / 2``           (along the column),
        ``Dxd = (P3 - P7) / (2*sqrt(2))`` (anti-diagonal P3-P5-P7),
        ``Dyd = (P1 - P9) / (2*sqrt(2))`` (diagonal P1-P5-P9).
        The diagonal step is sqrt(2) pixels long, hence the extra factor.

    Raises:
        ValueError: if ``neighbors`` does not contain exactly 9 values.
    """
    # Promote to float before differencing: with unsigned-integer image
    # data (e.g. uint8) a difference such as P4 - P6 would wrap around
    # instead of becoming negative.
    P1, P2, P3, P4, _P5, P6, P7, P8, P9 = np.asarray(neighbors, dtype=float)
    diag_step = 2 * np.sqrt(2)
    Dx = (P4 - P6) / 2
    Dy = (P2 - P8) / 2
    Dxd = (P3 - P7) / diag_step
    Dyd = (P1 - P9) / diag_step
    return [Dx, Dy, Dxd, Dyd]

def calculate_adaptive_weights(z, neigh, dir_deriv, chan, i, j, N, M):
    """Compute edge-adaptive weights for the 8 neighbours of pixel (i, j).

    For each neighbour P_k (k = 1..9 in row-major order; the centre P5 is
    skipped) the weight is ``1 / sqrt(1 + D(P5)**2 + D(P_k)**2)``.  D is the
    directional derivative along the line joining P5 and P_k:

    * Dyd for P1 and P9,
    * Dy for P2 and P8,
    * Dxd for P3 and P7,
    * Dx for P4 and P6.

    Strong gradients across that line therefore yield small weights.

    Args:
        z: image array indexed as ``z[row, col, chan]``.
        neigh: 3x3 neighbourhood of (i, j). It is not used by the
            computation and is kept only for interface compatibility.
        dir_deriv: ``[Dx, Dy, Dxd, Dyd]`` at the centre pixel, in the order
            returned by ``calculate_directional_gradients``.
        chan: channel index into ``z``.
        i, j: row and column of the centre pixel.
        N, M: image rows and columns; neighbour indices wrap modulo these.

    Returns:
        List of 8 weights ordered ``[E1, E2, E3, E4, E6, E7, E8, E9]``.
    """
    Dx, Dy, Dxd, Dyd = dir_deriv
    centre = (Dx, Dy, Dxd, Dyd)
    # Position k (row-major, 1..9) -> index into [Dx, Dy, Dxd, Dyd] for the
    # direction P5 -> P_k.
    direction = {1: 3, 2: 1, 3: 2, 4: 0, 6: 0, 7: 2, 8: 1, 9: 3}
    # Row and column offsets vary independently so that all 8 neighbours
    # are visited, not just the main diagonal.
    offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1)]
    weights = []
    for pos, (di, dj) in enumerate(offsets, start=1):
        if pos == 5:
            continue  # the centre pixel itself gets no weight
        n = find_Knearest_neighbors(z, chan, i + di, j + dj, N, M)
        dd = calculate_directional_gradients(n)
        d = direction[pos]
        weights.append(1 / np.sqrt(1 + centre[d] ** 2 + dd[d] ** 2))
    return weights

def interpolate_pixel(neigh, weights):
    """Estimate the centre value as a weighted mean of its 4-connected neighbours.

    Args:
        neigh: 9 values ``[P1, ..., P9]`` of a 3x3 window. Row-major order
            (as produced by ``find_Knearest_neighbors``) is assumed, so
            P2/P8 are the pixels in the rows above/below and P4/P6 are the
            pixels left/right of the centre.
        weights: 8 weights ``[E1, E2, E3, E4, E6, E7, E8, E9]`` (no E5).

    Returns:
        ``(E2*P2 + E4*P4 + E6*P6 + E8*P8) / (E2 + E4 + E6 + E8)``.
        The corner values and weights are ignored.
    """
    _, above, _, left, _, right, _, below, _ = neigh
    _, w_above, _, w_left, w_right, _, w_below, _ = weights
    pairs = ((w_above, above), (w_left, left), (w_right, right), (w_below, below))
    weighted_total = sum(w * v for w, v in pairs)
    weight_total = sum(w for w, _ in pairs)
    return weighted_total / weight_total

def interpolate_RedBlue(neighbors, neighbors_G, weights):
    """Interpolate a red/blue value at the centre using colour ratios.

    Only the four diagonal samples (P1, P3, P7, P9) are used. Each sample
    is divided by the green value at the same position. The weighted mean
    of those ratios is then rescaled by the centre green G5:

        G5 * (E1*P1/G1 + E3*P3/G3 + E7*P7/G7 + E9*P9/G9) / (E1 + E3 + E7 + E9)

    NOTE(review): because only the diagonals contribute, this fits pixels
    whose same-channel samples sit on the diagonals. Confirm how callers use
    it at other sites. A zero green value at a diagonal yields inf/nan for
    numpy input (or ZeroDivisionError for Python numbers).

    Args:
        neighbors: 9 red-or-blue values ``[P1, ..., P9]`` of a 3x3 window.
        neighbors_G: 9 green values ``[G1, ..., G9]`` of the same window.
        weights: 8 weights ``[E1, E2, E3, E4, E6, E7, E8, E9]``.

    Returns:
        The interpolated value at the centre pixel.
    """
    P1, _, P3, _, _, _, P7, _, P9 = neighbors
    G1, _, G3, _, G5, _, G7, _, G9 = neighbors_G
    E1, _, E3, _, _, E7, _, E9 = weights
    corners = ((E1, P1, G1), (E3, P3, G3), (E7, P7, G7), (E9, P9, G9))
    ratio_total = sum(e * p / g for e, p, g in corners)
    weight_total = sum(e for e, _, _ in corners)
    return G5 * ratio_total / weight_total