from scipy import ndimage
import numpy as np

############################################################

    
def color_pixel(i,j,cfa = "bayer"):
    """Return the CFA colour ('red', 'green' or 'blue') sampled at pixel (i, j).

    The Bayer layout used throughout this module has green wherever
    ``i + j`` is even, red on the other sites of even rows and blue on the
    other sites of odd rows.  With ``cfa == "quad_bayer"`` every Bayer site
    is expanded to a 2x2 block, so the coordinates are first reduced to
    block indices.  Any other ``cfa`` value is treated as plain Bayer.
    """
    if cfa == "quad_bayer":
        i, j = i // 2, j // 2

    if (i + j) % 2 == 0:
        return 'green'
    return 'red' if i % 2 == 0 else 'blue'

def rmse_pixel(pixel_raw,pixel_extrapolate):
    """Return the root-mean-square difference between two values or arrays.

    For scalars this is simply the absolute difference; for arrays (e.g. the
    2x2 quad blocks) the squared residuals are averaged over all elements.
    """
    residual = pixel_raw - pixel_extrapolate
    return np.sqrt(np.mean(np.square(residual)))

######### Extrapolation method with edge detection #########
def compute_orientation_matrix(img_raw):
    """Classify every pixel by which gradient direction dominates.

    Sobel derivatives are taken along axis 0 ("vertical", V) and axis 1
    ("horizontal", H).  The result is 1 where |V| < |H| and 0 elsewhere;
    the extrapolation routines use 1 to drop a horizontal candidate and 0
    to drop a vertical one.

    Parameters
    ----------
    img_raw : 2-D array-like
        Raw CFA mosaic.

    Returns
    -------
    numpy.ndarray
        float64 array of 0/1 values with the same shape as ``img_raw``.
    """
    # Work in float: ndimage.sobel writes its result in the input dtype by
    # default, so an unsigned integer mosaic would wrap negative derivatives.
    img = np.asarray(img_raw, dtype=float)
    # Compare gradient magnitudes: the sign only tells whether intensity rises
    # or falls across the edge, not how strong the edge is.
    vertical = np.abs(ndimage.sobel(img, 0))
    horizontal = np.abs(ndimage.sobel(img, 1))
    orientation_matrix = np.zeros(img.shape)
    orientation_matrix[vertical < horizontal] = 1
    return orientation_matrix

## Green Channel ##
## Formulas for extrapolation of pixels:
def _green_directional_estimate(img_raw, i, j, di, dj):
    """Directional green estimate at (i, j), stepping by (di, dj).

    ``near``/``far`` are the samples 1 and 3 steps away and ``mid`` the one
    2 steps away (with ``color_pixel``'s layout, for a non-green centre the
    odd steps are green and ``mid`` has the centre's colour).  The estimate
    is ``near + 3/4 * (centre - mid) - 1/4 * (near - far)``.
    """
    near = img_raw[i + di, j + dj]
    mid = img_raw[i + 2 * di, j + 2 * dj]
    far = img_raw[i + 3 * di, j + 3 * dj]
    centre = img_raw[i, j]
    return near + 3/4 * (centre - mid) - 1/4 * (near - far)


def extrapolate_green_top(img_raw,i,j):
    """Green estimate at (i, j) from the three pixels above it."""
    return _green_directional_estimate(img_raw, i, j, -1, 0)


def extrapolate_green_bottom(img_raw,i,j):
    """Green estimate at (i, j) from the three pixels below it."""
    return _green_directional_estimate(img_raw, i, j, 1, 0)


def extrapolate_green_left(img_raw,i,j):
    """Green estimate at (i, j) from the three pixels to its left."""
    return _green_directional_estimate(img_raw, i, j, 0, -1)


def extrapolate_green_right(img_raw,i,j):
    """Green estimate at (i, j) from the three pixels to its right."""
    return _green_directional_estimate(img_raw, i, j, 0, 1)

## Extrapolation method:
def median_extrapolate_green_pixel(img_raw,i,j,orientations_to_drop):
    """Median of the directional green estimates at (i, j).

    Each direction ("top", "bottom", "left", "right") contributes its
    ``extrapolate_green_*`` estimate unless it is listed in
    ``orientations_to_drop``.
    """
    directional_estimators = (
        ("top", extrapolate_green_top),
        ("bottom", extrapolate_green_bottom),
        ("left", extrapolate_green_left),
        ("right", extrapolate_green_right),
    )
    kept_estimates = [
        estimator(img_raw, i, j)
        for direction, estimator in directional_estimators
        if direction not in orientations_to_drop
    ]
    return np.median(kept_estimates)

def extrapolate_green_pixel(img_raw,i,j,orientation):
    """Estimate the green value at the non-green Bayer site (i, j).

    Up to four directional estimates (``extrapolate_green_top``/``_bottom``/
    ``_left``/``_right``) are combined by ``median_extrapolate_green_pixel``.
    Two kinds of candidates are discarded first:

    1. directions whose 3-pixel stencil would leave the image;
    2. one candidate of the pair selected by ``orientation``.  The
       horizontal pair is used when ``orientation == 1`` (V < H, see
       ``compute_orientation_matrix``) and the vertical pair otherwise.
       The candidate dropped is the one whose estimate has the larger RMSE
       against ``img_raw[i, j]``.  A pair already cut by a border is left
       alone.

    Parameters
    ----------
    img_raw : 2-D array
        Raw Bayer mosaic.
    i, j : int
        Row and column of the pixel.
    orientation : number
        Orientation-matrix entry at (i, j).

    Returns
    -------
    The median of the kept estimates.
    """
    n_rows, n_cols = img_raw.shape[0], img_raw.shape[1]
    orientations_to_drop = []

    # Borders: the stencils read up to 3 pixels away.  A negative index would
    # silently wrap to the opposite edge instead of raising, so top/left need
    # i, j >= 3 and bottom/right need i, j <= size - 4.
    if i < 3:
        orientations_to_drop.append('top')
    if i > n_rows - 4:
        orientations_to_drop.append('bottom')
    if j < 3:
        orientations_to_drop.append('left')
    if j > n_cols - 4:
        orientations_to_drop.append('right')

    # Orientation test, applied to every pixel (not just those away from the
    # right border); each branch skips a pair that a border already cut.
    centre = img_raw[i,j]
    if orientation == 1:
        # V < H: eliminate one horizontal candidate.
        if 'right' not in orientations_to_drop and 'left' not in orientations_to_drop:
            rmse_pixel_left = rmse_pixel(centre, extrapolate_green_left(img_raw,i,j))
            rmse_pixel_right = rmse_pixel(centre, extrapolate_green_right(img_raw,i,j))
            orientations_to_drop.append('left' if rmse_pixel_left > rmse_pixel_right else 'right')
    else:
        # Otherwise: eliminate one vertical candidate.
        if 'top' not in orientations_to_drop and 'bottom' not in orientations_to_drop:
            rmse_pixel_top = rmse_pixel(centre, extrapolate_green_top(img_raw,i,j))
            rmse_pixel_bottom = rmse_pixel(centre, extrapolate_green_bottom(img_raw,i,j))
            orientations_to_drop.append('top' if rmse_pixel_top > rmse_pixel_bottom else 'bottom')
    return median_extrapolate_green_pixel(img_raw,i,j,orientations_to_drop)

def extrapolate_green(img_raw,extrapolate_img):
    """Fill the green plane (channel 1) of ``extrapolate_img`` in place.

    Green CFA sites are copied from ``img_raw``.  Every other site is
    estimated with ``extrapolate_green_pixel``, using the per-pixel
    orientation from ``compute_orientation_matrix``.  The same array is also
    returned.
    """
    orientation_matrix = compute_orientation_matrix(img_raw)
    for row, col in np.ndindex(img_raw.shape[0], img_raw.shape[1]):
        if color_pixel(row, col) == "green":
            extrapolate_img[row, col, 1] = img_raw[row, col]
        else:
            extrapolate_img[row, col, 1] = extrapolate_green_pixel(
                img_raw, row, col, orientation_matrix[row, col])
    return extrapolate_img

## Red and Blue Channels ##
def _colour_difference_estimate(img_raw, img_extrapolate, i, j, di, dj, centre_green):
    """Colour-difference estimate at (i, j) from the neighbour (i+di, j+dj).

    Returns ``raw(neighbour) + centre_green - green(neighbour)``, where the
    neighbour's green is read from channel 1 of ``img_extrapolate``.
    """
    ni, nj = i + di, j + dj
    return img_raw[ni, nj] + centre_green - img_extrapolate[ni, nj, 1]


def extrapolate_top(img_raw,img_extrapolate,i,j):
    """Estimate from the pixel above, with ``img_raw[i, j]`` as the centre term."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, -1, 0, img_raw[i, j])


def extrapolate_left(img_raw,img_extrapolate,i,j):
    """Estimate from the pixel on the left, with ``img_raw[i, j]`` as the centre term."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, 0, -1, img_raw[i, j])


def extrapolate_right(img_raw,img_extrapolate,i,j):
    """Estimate from the pixel on the right, with ``img_raw[i, j]`` as the centre term."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, 0, 1, img_raw[i, j])


def extrapolate_bottom(img_raw,img_extrapolate,i,j):
    """Estimate from the pixel below, with ``img_raw[i, j]`` as the centre term."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, 1, 0, img_raw[i, j])


def extrapolate_top_left(img_raw,img_extrapolate,i,j):
    """Estimate from the upper-left diagonal, with the extrapolated green at (i, j)."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, -1, -1, img_extrapolate[i, j, 1])


def extrapolate_top_right(img_raw,img_extrapolate,i,j):
    """Estimate from the upper-right diagonal, with the extrapolated green at (i, j)."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, -1, 1, img_extrapolate[i, j, 1])


def extrapolate_bottom_left(img_raw,img_extrapolate,i,j):
    """Estimate from the lower-left diagonal, with the extrapolated green at (i, j)."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, 1, -1, img_extrapolate[i, j, 1])


def extrapolate_bottom_right(img_raw,img_extrapolate,i,j):
    """Estimate from the lower-right diagonal, with the extrapolated green at (i, j)."""
    return _colour_difference_estimate(img_raw, img_extrapolate, i, j, 1, 1, img_extrapolate[i, j, 1])

def median_pixel(img_raw,img_extrapolate,i,j,orientations_to_drop):
    """Median of the colour-difference estimates at Bayer site (i, j).

    Green sites use the four edge neighbours (top/left/right/bottom); other
    sites use the four diagonal neighbours.  Directions listed in
    ``orientations_to_drop`` are skipped.
    """
    if color_pixel(i,j) == "green":
        candidates = (
            ("top", extrapolate_top),
            ("left", extrapolate_left),
            ("right", extrapolate_right),
            ("bottom", extrapolate_bottom),
        )
    else:
        candidates = (
            ("top_left", extrapolate_top_left),
            ("top_right", extrapolate_top_right),
            ("bottom_left", extrapolate_bottom_left),
            ("bottom_right", extrapolate_bottom_right),
        )
    estimates = [
        estimate(img_raw, img_extrapolate, i, j)
        for direction, estimate in candidates
        if direction not in orientations_to_drop
    ]
    return np.median(estimates)

def extrapolate_pixel(img_raw,img_extrapolate,i,j,color):
    """Estimate channel ``color`` ("red" or "blue") at Bayer site (i, j).

    Builds the list of neighbour directions to ignore, then delegates to
    ``median_pixel``, which returns the median of the remaining
    colour-difference estimates.

    * Non-green site: the four diagonal neighbours are candidates.  Those
      outside the image are dropped.  When all four are available, the one
      whose estimate has the strictly largest RMSE against ``img_raw[i, j]``
      is also dropped; if no single diagonal is strictly largest (ties),
      ``top_right`` is dropped.
    * Green site: the four edge neighbours are candidates.  Those outside
      the image are dropped.  Then either the left/right or the top/bottom
      pair (chosen by row parity and ``color``) loses the member with the
      larger RMSE against ``img_raw[i, j]``, if both members are available.

    Callers (``extrapolate_red`` / ``extrapolate_blue``) only call this for
    sites whose CFA colour differs from ``color``.  The estimates read the
    green plane (channel 1) of ``img_extrapolate``, so it must be filled
    first.
    """
    orientations_to_drop = []

    if (color_pixel(i,j)!='green'):
        # Diagonal stencils reach one pixel in each direction; drop the
        # diagonals that would fall outside the image.
        if (i<1):
            orientations_to_drop.append("top_left")
            orientations_to_drop.append("top_right")
        if (i>img_raw.shape[0]-2):
            orientations_to_drop.append("bottom_left")
            orientations_to_drop.append("bottom_right")
        if (j<1):
            orientations_to_drop.append("top_left")
            orientations_to_drop.append("bottom_left")
        if (j>img_raw.shape[1]-2):
            orientations_to_drop.append("top_right")
            orientations_to_drop.append("bottom_right")


        # Interior pixel: drop the single worst diagonal estimate.
        if ("top_left" not in orientations_to_drop and "top_right" not in orientations_to_drop and "bottom_left" not in orientations_to_drop and "bottom_right" not in orientations_to_drop):
            rmse_top_left = rmse_pixel(img_raw[i,j],extrapolate_top_left(img_raw,img_extrapolate,i,j))
            rmse_top_right = rmse_pixel(img_raw[i,j],extrapolate_top_right(img_raw,img_extrapolate,i,j))
            rmse_bottom_left = rmse_pixel(img_raw[i,j],extrapolate_bottom_left(img_raw,img_extrapolate,i,j))
            rmse_bottom_right = rmse_pixel(img_raw[i,j],extrapolate_bottom_right(img_raw,img_extrapolate,i,j))
            # Strict comparisons: the final else also catches ties.
            if (rmse_bottom_left> rmse_bottom_right and rmse_bottom_left> rmse_top_left and rmse_bottom_left> rmse_top_right):
                orientations_to_drop.append("bottom_left")
            elif (rmse_bottom_right> rmse_bottom_left and rmse_bottom_right> rmse_top_left and rmse_bottom_right> rmse_top_right):
                orientations_to_drop.append("bottom_right")
            elif (rmse_top_left> rmse_bottom_left and rmse_top_left> rmse_bottom_right and rmse_top_left> rmse_top_right):
                orientations_to_drop.append("top_left")
            else:
                orientations_to_drop.append("top_right")
    elif(color_pixel(i,j)=="green"):
        # Edge stencils reach one pixel in each direction.
        if (i<1):
            orientations_to_drop.append("top")
        if (i>img_raw.shape[0]-2):
            orientations_to_drop.append("bottom")
        if (j<1):
            orientations_to_drop.append("left")
        if (j>img_raw.shape[1]-2):
            orientations_to_drop.append("right")

        # With color_pixel's layout, the pair thinned here is the one whose
        # raw samples are NOT of `color` (e.g. on even rows left/right are red
        # and top/bottom are blue).  Only one member is removed, so one
        # off-colour estimate still enters the median.
        # NOTE(review): confirm dropping only one of the pair is intended.
        if ((i%2!=0 and color == "red") or (i%2==0 and color == "blue")):
            if ("right" not in orientations_to_drop and "left" not in orientations_to_drop):
                rmse_pixel_left = rmse_pixel(img_raw[i,j],extrapolate_left(img_raw,img_extrapolate,i,j))
                rmse_pixel_right = rmse_pixel(img_raw[i,j],extrapolate_right(img_raw,img_extrapolate,i,j))
                if (rmse_pixel_left > rmse_pixel_right):
                    orientations_to_drop.append('left')
                else:
                    orientations_to_drop.append('right')
        else:
            if ("top" not in orientations_to_drop and "bottom" not in orientations_to_drop):
                rmse_pixel_top = rmse_pixel(img_raw[i,j],extrapolate_top(img_raw,img_extrapolate,i,j))
                rmse_pixel_bottom = rmse_pixel(img_raw[i,j],extrapolate_bottom(img_raw,img_extrapolate,i,j))
                if (rmse_pixel_top > rmse_pixel_bottom):
                    orientations_to_drop.append('top')
                else:
                    orientations_to_drop.append('bottom')
    return median_pixel(img_raw,img_extrapolate,i,j,orientations_to_drop)

def extrapolate_red(img_raw,img_extrapolate):
    """Fill the red plane (channel 0) of ``img_extrapolate`` in place.

    Red CFA sites are copied from ``img_raw``.  All other sites are estimated
    by ``extrapolate_pixel(..., "red")``, which reads the green plane
    (channel 1), so that plane must already be filled.  Returns None.
    """
    for row, col in np.ndindex(img_raw.shape[0], img_raw.shape[1]):
        if color_pixel(row, col) == "red":
            img_extrapolate[row, col, 0] = img_raw[row, col]
        else:
            img_extrapolate[row, col, 0] = extrapolate_pixel(
                img_raw, img_extrapolate, row, col, "red")

def extrapolate_blue(img_raw,img_extrapolate):
    """Fill the blue plane (channel 2) of ``img_extrapolate`` in place.

    Blue CFA sites are copied from ``img_raw``.  All other sites are
    estimated by ``extrapolate_pixel(..., "blue")``, which reads the green
    plane (channel 1), so that plane must already be filled.  Returns None.
    """
    for row, col in np.ndindex(img_raw.shape[0], img_raw.shape[1]):
        if color_pixel(row, col) == "blue":
            img_extrapolate[row, col, 2] = img_raw[row, col]
        else:
            img_extrapolate[row, col, 2] = extrapolate_pixel(
                img_raw, img_extrapolate, row, col, "blue")

def extrapolate_img(img_cfa):
    """Demosaic a Bayer CFA image into an RGB image.

    The green plane is reconstructed first because the red and blue
    estimates are colour differences against it.

    Parameters
    ----------
    img_cfa : 2-D array-like
        Raw Bayer mosaic (layout as in ``color_pixel``).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``img_cfa.shape + (3,)``; channels 0, 1, 2
        hold red, green and blue.
    """
    # Work on a float copy: the extrapolation formulas subtract raw samples,
    # which would wrap around for unsigned integer (e.g. uint8/uint16) input.
    img_raw = np.asarray(img_cfa, dtype=float)
    rgb = np.zeros(img_raw.shape + (3,))
    extrapolate_green(img_raw, rgb)
    extrapolate_red(img_raw, rgb)
    extrapolate_blue(img_raw, rgb)
    return rgb


#################################################
##QUAD BAYER
#################################################


## Green Channel ##

### Formulas for extrapolation of pixels:
def _green_directional_estimate_quad(img_raw, i, j, di, dj):
    """2x2 green estimate for the quad block at (i, j), stepping by blocks.

    Applies the Bayer directional formula to each pixel (i+m, j+n) of the
    block, with the neighbours 2, 4 and 6 pixels (1, 2 and 3 blocks) away in
    direction (di, dj):
    ``near + 3/4 * (centre - mid) - 1/4 * (near - far)``.
    """
    quad = np.zeros((2, 2))
    for m, n in np.ndindex(2, 2):
        r, c = i + m, j + n
        near = img_raw[r + 2 * di, c + 2 * dj]
        mid = img_raw[r + 4 * di, c + 4 * dj]
        far = img_raw[r + 6 * di, c + 6 * dj]
        quad[m, n] = near + 3/4 * (img_raw[r, c] - mid) - 1/4 * (near - far)
    return quad


def extrapolate_green_top_quad(img_raw,i,j):
    """2x2 green estimate for block (i, j) from the blocks above it."""
    return _green_directional_estimate_quad(img_raw, i, j, -1, 0)


def extrapolate_green_bottom_quad(img_raw,i,j):
    """2x2 green estimate for block (i, j) from the blocks below it."""
    return _green_directional_estimate_quad(img_raw, i, j, 1, 0)


def extrapolate_green_left_quad(img_raw,i,j):
    """2x2 green estimate for block (i, j) from the blocks to its left."""
    return _green_directional_estimate_quad(img_raw, i, j, 0, -1)


def extrapolate_green_right_quad(img_raw,i,j):
    """2x2 green estimate for block (i, j) from the blocks to its right."""
    return _green_directional_estimate_quad(img_raw, i, j, 0, 1)

### Extrapolation method:
def median_extrapolate_green_pixel_quad(img_raw,i,j,orientations_to_drop):
    """Combine the directional 2x2 green estimates for the quad block at (i, j).

    Every direction ("top", "bottom", "left", "right") not listed in
    ``orientations_to_drop`` contributes a 2x2 estimate.  The median is
    taken independently for each of the four pixel positions across those
    candidates, so each pixel keeps its own estimate.

    Returns
    -------
    numpy.ndarray
        2x2 float array; all NaN if every direction was dropped.
    """
    candidates = []
    if "top" not in orientations_to_drop:
        candidates.append(extrapolate_green_top_quad(img_raw,i,j))
    if "bottom" not in orientations_to_drop:
        candidates.append(extrapolate_green_bottom_quad(img_raw,i,j))
    if "left" not in orientations_to_drop:
        candidates.append(extrapolate_green_left_quad(img_raw,i,j))
    if "right" not in orientations_to_drop:
        candidates.append(extrapolate_green_right_quad(img_raw,i,j))

    if not candidates:
        # Nothing to combine (block too close to every border).
        return np.full((2, 2), np.nan)
    # Axis 0 runs over the candidates: a per-position median rather than one
    # pooled median of all 4*k values.
    return np.median(np.stack(candidates), axis=0)

def extrapolate_green_pixel_quad(img_raw,i,j,orientation):
    """Estimate the 2x2 green block for the non-green quad block at (i, j).

    This is the quad-Bayer analogue of ``extrapolate_green_pixel``, and works
    in three steps:

    1. Directions whose stencil (up to 6 pixels, i.e. 3 blocks, away) would
       leave the image are dropped.
    2. One candidate of the horizontal pair (``orientation > 0.5``) or of the
       vertical pair (otherwise) is dropped: the one with the larger RMSE
       against the raw block ``img_raw[i:i+2, j:j+2]``.  A pair already cut
       by a border is left alone.
    3. The remaining candidates are combined by a per-pixel median.

    Parameters
    ----------
    img_raw : 2-D array
        Raw quad-Bayer mosaic.
    i, j : int
        Top-left corner of the block.
    orientation : float
        Mean of the orientation matrix over the 2x2 block.

    Returns
    -------
    2x2 array from ``median_extrapolate_green_pixel_quad``.
    """
    n_rows, n_cols = img_raw.shape[0], img_raw.shape[1]
    raw_block = img_raw[i:i+2, j:j+2]
    orientations_to_drop = []

    # Borders: top/left stencils read i-6 / j-6 (negative indices would wrap
    # to the other edge), bottom/right read up to i+1+6 / j+1+6.
    if i < 6:
        orientations_to_drop.append('top')
    if i > n_rows - 4*2:
        orientations_to_drop.append('bottom')
    if j < 6:
        orientations_to_drop.append('left')
    if j > n_cols - 4*2:
        orientations_to_drop.append('right')

    # Orientation test, applied to every block (not only those away from the
    # right border).
    if orientation > 0.5:
        # Mostly V < H in this block: eliminate one horizontal candidate.
        if 'right' not in orientations_to_drop and 'left' not in orientations_to_drop:
            rmse_pixel_left = rmse_pixel(raw_block, extrapolate_green_left_quad(img_raw,i,j))
            rmse_pixel_right = rmse_pixel(raw_block, extrapolate_green_right_quad(img_raw,i,j))
            orientations_to_drop.append('left' if rmse_pixel_left > rmse_pixel_right else 'right')
    else:
        # Otherwise eliminate one vertical candidate, comparing against this
        # block (not a single pixel of the diagonal neighbour block).
        if 'top' not in orientations_to_drop and 'bottom' not in orientations_to_drop:
            rmse_pixel_top = rmse_pixel(raw_block, extrapolate_green_top_quad(img_raw,i,j))
            rmse_pixel_bottom = rmse_pixel(raw_block, extrapolate_green_bottom_quad(img_raw,i,j))
            orientations_to_drop.append('top' if rmse_pixel_top > rmse_pixel_bottom else 'bottom')
    return median_extrapolate_green_pixel_quad(img_raw,i,j,orientations_to_drop)

def extrapolate_green_quad(img_raw,extrapolate_img):
    """Fill the green plane (channel 1) of ``extrapolate_img`` block by block.

    The image is walked in 2x2 quad blocks starting at even (i, j).  Green
    blocks are copied from ``img_raw``.  Other blocks are estimated by
    ``extrapolate_green_pixel_quad`` using the block orientation (sum of the
    2x2 orientation-matrix entries divided by 4).  Returns
    ``extrapolate_img``.
    """
    orientation_matrix = compute_orientation_matrix(img_raw)
    n_rows, n_cols = img_raw.shape[0], img_raw.shape[1]
    for top in range(0, n_rows, 2):
        for left in range(0, n_cols, 2):
            rows, cols = slice(top, top + 2), slice(left, left + 2)
            if color_pixel(top, left, 'quad_bayer') == "green":
                extrapolate_img[rows, cols, 1] = img_raw[rows, cols]
            else:
                block_orientation = np.sum(orientation_matrix[rows, cols]) / 4
                extrapolate_img[rows, cols, 1] = extrapolate_green_pixel_quad(
                    img_raw, top, left, block_orientation)
    return extrapolate_img

## Red and Blue Channels ##

def _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, di, dj, use_extrapolated_centre):
    """2x2 colour-difference estimate for block (i, j) from the block at (di, dj).

    For each pixel (r, c) of the block the neighbour is (r + 2*di, c + 2*dj).
    The value is ``raw(neighbour) + centre - green(neighbour)``.  Here
    ``centre`` is the extrapolated green (channel 1) at (r, c) when
    ``use_extrapolated_centre`` is true, and the raw value at (r, c)
    otherwise.
    """
    quad = np.zeros((2, 2))
    for m, n in np.ndindex(2, 2):
        r, c = i + m, j + n
        nr, nc = r + 2 * di, c + 2 * dj
        centre = img_extrapolate[r, c, 1] if use_extrapolated_centre else img_raw[r, c]
        quad[m, n] = img_raw[nr, nc] + centre - img_extrapolate[nr, nc, 1]
    return quad


def extrapolate_top_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the block above (raw centre values)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, -1, 0, False)


def extrapolate_left_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the block on the left (raw centre values)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, 0, -1, False)


def extrapolate_right_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the block on the right (raw centre values)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, 0, 1, False)


def extrapolate_bottom_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the block below (raw centre values)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, 1, 0, False)


def extrapolate_top_left_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the upper-left diagonal block (extrapolated centre green)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, -1, -1, True)


def extrapolate_top_right_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the upper-right diagonal block (extrapolated centre green)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, -1, 1, True)


def extrapolate_bottom_left_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the lower-left diagonal block (extrapolated centre green)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, 1, -1, True)


def extrapolate_bottom_right_quad(img_raw,img_extrapolate,i,j):
    """2x2 estimate from the lower-right diagonal block (extrapolated centre green)."""
    return _colour_difference_estimate_quad(img_raw, img_extrapolate, i, j, 1, 1, True)

def median_pixel_quad(img_raw,img_extrapolate,i,j,orientations_to_drop):
    """Per-pixel median of the 2x2 colour-difference estimates for quad block (i, j).

    Green blocks use the four edge-neighbour blocks and other blocks the four
    diagonal-neighbour blocks.  Directions listed in
    ``orientations_to_drop`` are skipped.  Returns a 2x2 float array.
    """
    if color_pixel(i,j,"quad_bayer") == "green":
        candidates = (
            ("top", extrapolate_top_quad),
            ("left", extrapolate_left_quad),
            ("right", extrapolate_right_quad),
            ("bottom", extrapolate_bottom_quad),
        )
    else:
        candidates = (
            ("top_left", extrapolate_top_left_quad),
            ("top_right", extrapolate_top_right_quad),
            ("bottom_left", extrapolate_bottom_left_quad),
            ("bottom_right", extrapolate_bottom_right_quad),
        )
    estimates = [
        estimate(img_raw, img_extrapolate, i, j)
        for direction, estimate in candidates
        if direction not in orientations_to_drop
    ]

    median_quad = np.zeros((2, 2))
    for m, n in np.ndindex(2, 2):
        median_quad[m, n] = np.median([quad[m, n] for quad in estimates])
    return median_quad

def extrapolate_pixel_quad(img_raw,img_extrapolate,i,j,color):
    """Estimate channel ``color`` ("red" or "blue") for the quad block at (i, j).

    Quad-Bayer analogue of ``extrapolate_pixel``.  (i, j) is the top-left
    corner of a 2x2 block, and neighbour blocks are 2 pixels away.  RMSE is
    computed by ``rmse_pixel`` over the whole raw block
    ``img_raw[i:i+2, j:j+2]``, giving one scalar per candidate.

    * Non-green block: the four diagonal neighbour blocks are candidates.
      Those outside the image are dropped.  When all four are available, the
      one with the strictly largest RMSE is dropped; ties fall through to
      ``top_right``.
    * Green block: the four edge neighbour blocks are candidates.  Those
      outside the image are dropped.  Then the member with the larger RMSE
      of either the left/right or the top/bottom pair (chosen by block-row
      parity and ``color``) is dropped.

    Returns the 2x2 array from ``median_pixel_quad``.  Callers
    (``extrapolate_red_quad`` / ``extrapolate_blue_quad``) pass even (i, j)
    and fill the green plane (channel 1) first.
    """
    orientations_to_drop = []

    if (color_pixel(i,j,"quad_bayer")!='green'):
        # Stencils reach one block (2 pixels) away.  For the even i, j the
        # callers pass, `< 1` is the same as `< 2`.
        if (i<1):
            orientations_to_drop.append("top_left")
            orientations_to_drop.append("top_right")
        if (i>img_raw.shape[0]-2*2):
            orientations_to_drop.append("bottom_left")
            orientations_to_drop.append("bottom_right")
        if (j<1):
            orientations_to_drop.append("top_left")
            orientations_to_drop.append("bottom_left")
        if (j>img_raw.shape[1]-2*2):
            orientations_to_drop.append("top_right")
            orientations_to_drop.append("bottom_right")


        # Interior block: drop the single worst diagonal estimate.
        if ("top_left" not in orientations_to_drop and "top_right" not in orientations_to_drop and "bottom_left" not in orientations_to_drop and "bottom_right" not in orientations_to_drop):
            rmse_top_left = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_top_left_quad(img_raw,img_extrapolate,i,j))
            rmse_top_right = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_top_right_quad(img_raw,img_extrapolate,i,j))
            rmse_bottom_left = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_bottom_left_quad(img_raw,img_extrapolate,i,j))
            rmse_bottom_right = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_bottom_right_quad(img_raw,img_extrapolate,i,j))
            # Strict comparisons: the final else also catches ties.
            if (rmse_bottom_left> rmse_bottom_right and rmse_bottom_left> rmse_top_left and rmse_bottom_left> rmse_top_right):
                orientations_to_drop.append("bottom_left")
            elif (rmse_bottom_right> rmse_bottom_left and rmse_bottom_right> rmse_top_left and rmse_bottom_right> rmse_top_right):
                orientations_to_drop.append("bottom_right")
            elif (rmse_top_left> rmse_bottom_left and rmse_top_left> rmse_bottom_right and rmse_top_left> rmse_top_right):
                orientations_to_drop.append("top_left")
            else:
                orientations_to_drop.append("top_right")
    elif(color_pixel(i,j,"quad_bayer")=="green"):
        if (i<1):
            orientations_to_drop.append("top")
        if (i>img_raw.shape[0]-2*2):
            orientations_to_drop.append("bottom")
        if (j<1):
            orientations_to_drop.append("left")
        if (j>img_raw.shape[1]-2*2):
            orientations_to_drop.append("right")

        # (i/2)%2 is the block-row parity.  It uses true division, which is
        # only correct for even i (as passed by the callers).
        # NOTE(review): (i//2)%2 would state the intent more robustly.
        if (((i/2)%2!=0 and color == "red") or ((i/2)%2==0 and color == "blue")):
            if ("right" not in orientations_to_drop and "left" not in orientations_to_drop):
                rmse_pixel_left = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_left_quad(img_raw,img_extrapolate,i,j))
                rmse_pixel_right = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_right_quad(img_raw,img_extrapolate,i,j))
                if (rmse_pixel_left > rmse_pixel_right):
                    orientations_to_drop.append('left')
                else:
                    orientations_to_drop.append('right')
        else:
            if ("top" not in orientations_to_drop and "bottom" not in orientations_to_drop):
                rmse_pixel_top = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_top_quad(img_raw,img_extrapolate,i,j))
                rmse_pixel_bottom = rmse_pixel(img_raw[i:i+2,j:j+2],extrapolate_bottom_quad(img_raw,img_extrapolate,i,j))
                if (rmse_pixel_top > rmse_pixel_bottom):
                    orientations_to_drop.append('top')
                else:
                    orientations_to_drop.append('bottom')
    return median_pixel_quad(img_raw,img_extrapolate,i,j,orientations_to_drop)

def extrapolate_red_quad(img_raw,img_extrapolate):
    """Fill the red plane (channel 0) of ``img_extrapolate`` block by block.

    Walks 2x2 quad blocks from even (i, j).  Red blocks are copied from
    ``img_raw``; all others are estimated by ``extrapolate_pixel_quad``,
    which reads the already-filled green plane.  Returns None.
    """
    n_rows, n_cols = img_raw.shape[0], img_raw.shape[1]
    for top in range(0, n_rows, 2):
        for left in range(0, n_cols, 2):
            rows, cols = slice(top, top + 2), slice(left, left + 2)
            if color_pixel(top, left, "quad_bayer") == "red":
                img_extrapolate[rows, cols, 0] = img_raw[rows, cols]
            else:
                img_extrapolate[rows, cols, 0] = extrapolate_pixel_quad(
                    img_raw, img_extrapolate, top, left, "red")

def extrapolate_blue_quad(img_raw,img_extrapolate):
    """Fill the blue plane (channel 2) of ``img_extrapolate`` block by block.

    Walks 2x2 quad blocks from even (i, j).  Blue blocks are copied from
    ``img_raw``; all others are estimated by ``extrapolate_pixel_quad``,
    which reads the already-filled green plane.  Returns None.
    """
    n_rows, n_cols = img_raw.shape[0], img_raw.shape[1]
    for top in range(0, n_rows, 2):
        for left in range(0, n_cols, 2):
            rows, cols = slice(top, top + 2), slice(left, left + 2)
            if color_pixel(top, left, "quad_bayer") == "blue":
                img_extrapolate[rows, cols, 2] = img_raw[rows, cols]
            else:
                img_extrapolate[rows, cols, 2] = extrapolate_pixel_quad(
                    img_raw, img_extrapolate, top, left, "blue")

def extrapolate_img_quad(img_cfa):
    """Demosaic a quad-Bayer CFA image into an RGB image.

    Same pipeline as the Bayer version, working on 2x2 quad blocks: green
    first, then red and blue as colour differences against it.

    Parameters
    ----------
    img_cfa : 2-D array-like
        Raw quad-Bayer mosaic (layout as in ``color_pixel(..., "quad_bayer")``).
        NOTE: presumably both dimensions must be even, since the block loops
        write 2x2 results; TODO confirm.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``img_cfa.shape + (3,)``; channels 0, 1, 2
        hold red, green and blue.
    """
    # Work on a float copy: the extrapolation formulas subtract raw samples,
    # which would wrap around for unsigned integer (e.g. uint8/uint16) input.
    img_raw = np.asarray(img_cfa, dtype=float)
    rgb = np.zeros(img_raw.shape + (3,))
    extrapolate_green_quad(img_raw, rgb)
    extrapolate_red_quad(img_raw, rgb)
    extrapolate_blue_quad(img_raw, rgb)
    return rgb

def extrapolate_cfa(img_cfa,cfa):
    """Demosaic ``img_cfa`` according to the named colour filter array.

    Supported values of ``cfa`` are "bayer" and "quad_bayer".  For any other
    value an error message is printed and None is returned.
    """
    if cfa == "bayer":
        return extrapolate_img(img_cfa)
    if cfa == "quad_bayer":
        return extrapolate_img_quad(img_cfa)
    print("Error: cfa not recognized")
    return None