import numpy as np



def _nonlocal_neighbours(padded, i, j, height, width, k, rho, N, h):
    """Return the ``N`` pixels most similar to ``(i, j)`` and their weights.

    Candidates are the pixels of the ``(2k+1) x (2k+1)`` search window centred
    on ``(i, j)``, clipped to the image.  Similarity is the sum of squared
    differences between the ``(2*rho+1) x (2*rho+1)`` colour patches taken
    from ``padded`` (the initial image padded by ``rho`` on each spatial side).

    Returns:
        (rows, cols, weights): neighbour coordinates in the unpadded image and
        their normalised weights ``exp(-d / h**2)``, which sum to 1.
    """
    size = 2 * rho + 1
    # Pixel (m, n) of the image sits at (m + rho, n + rho) in ``padded``, so
    # its patch is padded[m:m + size, n:n + size].
    ref = padded[i:i + size, j:j + size]
    rows = np.arange(max(0, i - k), min(height, i + k + 1))
    cols = np.arange(max(0, j - k), min(width, j + k + 1))
    q_rows, q_cols = np.meshgrid(rows, cols, indexing="ij")
    q_rows = q_rows.ravel()
    q_cols = q_cols.ravel()
    dist = np.array([
        np.sum((ref - padded[m:m + size, n:n + size]) ** 2)
        for m, n in zip(q_rows, q_cols)
    ])

    # The self-distance is always 0 and would dominate the weights, so (as in
    # the original algorithm) it is replaced by the smallest distance to any
    # other candidate.  With no other candidate (k == 0) it stays 0.
    is_self = (q_rows == i) & (q_cols == j)
    others = dist[~is_self]
    dist[is_self] = others.min() if others.size else 0.0

    # Stable sort keeps tie-breaking deterministic; slicing tolerates N larger
    # than the number of candidates.
    order = np.argsort(dist, kind="stable")[:N]
    nearest = dist[order]
    # Subtracting the smallest distance leaves the normalised weights
    # unchanged but keeps the largest weight at exactly 1, so the sum can
    # never underflow to 0.
    weights = np.exp(-(nearest - nearest[0]) / h ** 2)
    weights /= weights.sum()
    return q_rows[order], q_cols[order], weights


def nonlocal_int(R_0, G_0, B_0, beta, h, k, rho, N, cfa_mask):
    """Nonlocal interpolation (self-similarity driven demosaicking).

    For each pixel, the ``N`` most similar pixels in its search window are
    found by comparing colour patches of the initial reconstruction.  Each
    channel that was not sampled by the CFA at that pixel is then re-estimated
    from those neighbours, using channel intercorrelation controlled by
    ``beta``:

    - green:      G(p) = sum_n w_n (G0 - beta (R0 + B0))(q_n) + beta (R0 + B0)(p)
    - red / blue: C(p) = sum_n w_n (C0 - beta G0)(q_n) + beta G0(p)

    All estimates read only the initial channels.

    Inputs :
        - R_0, G_0, B_0 : the initial reconstructed channels, 2-D arrays of
          the same shape (height, width)
        - beta : the parameter of intercorrelation between the channels
        - h : the filtering parameter (must be > 0); weights are
          exp(-d / h**2), where d is the squared patch distance
        - k : the half-size of the search window (window is (2k+1) x (2k+1))
        - rho : the half-size of the comparison patch (patch is
          (2*rho+1) x (2*rho+1)); borders are handled by reflect-padding
        - N : the number of most similar neighbours used per pixel (>= 1)
        - cfa_mask : array of shape (height, width, 3); cfa_mask[i, j, c] == 1
          marks that channel c (0 = red, 1 = green, 2 = blue) was sampled
          at pixel (i, j)
    Outputs :
        - R, G, B : the reconstructed channels as float arrays; sampled CFA
          values are kept as given in R_0, G_0, B_0
    Raises :
        - ValueError if h <= 0, k < 0, rho < 0 or N < 1
    """
    if h <= 0:
        raise ValueError("h must be positive")
    if k < 0 or rho < 0:
        raise ValueError("k and rho must be non-negative")
    if N < 1:
        raise ValueError("N must be at least 1")

    # Work in float so squared differences and beta-weighted sums do not
    # overflow or truncate for integer images.
    R_0 = np.asarray(R_0, dtype=float)
    G_0 = np.asarray(G_0, dtype=float)
    B_0 = np.asarray(B_0, dtype=float)
    missing = np.asarray(cfa_mask) != 1
    height, width = R_0.shape

    u_0 = np.stack((R_0, G_0, B_0), axis=2)
    padded = np.pad(u_0, ((rho, rho), (rho, rho), (0, 0)), mode="reflect")

    # Start from the initial channels so that CFA-sampled values are preserved.
    R, G, B = R_0.copy(), G_0.copy(), B_0.copy()
    red_plus_blue = R_0 + B_0

    for i in range(height):
        for j in range(width):
            if not missing[i, j].any():
                continue  # every channel sampled here: nothing to estimate
            q_rows, q_cols, weights = _nonlocal_neighbours(
                padded, i, j, height, width, k, rho, N, h
            )
            green_q = G_0[q_rows, q_cols]
            # Enhancement of green channel
            if missing[i, j, 1]:
                G[i, j] = (
                    weights @ (green_q - beta * red_plus_blue[q_rows, q_cols])
                    + beta * red_plus_blue[i, j]
                )
            # Enhancement of red and blue channels
            if missing[i, j, 0]:
                R[i, j] = (
                    weights @ (R_0[q_rows, q_cols] - beta * green_q)
                    + beta * G_0[i, j]
                )
            if missing[i, j, 2]:
                B[i, j] = (
                    weights @ (B_0[q_rows, q_cols] - beta * green_q)
                    + beta * G_0[i, j]
                )

    return R, G, B