# reconstruct.py
# Committed by mounael15.
import warnings

import numpy as np

from src.forward_model import CFA
from src.methods.ELAMRANI_Mouna.functions import *


def run_reconstruction(y: np.ndarray, cfa: str) -> np.ndarray:
    """Demosaick ``y`` with neighbor-based, gradient-weighted interpolation.

    The mosaicked image is first spread over three channels with the adjoint
    of the CFA operator, then the channels are completed in this order:

    1. channel 1 (green): every pixel whose value is 0 is interpolated from
       its nearest neighbors in the same channel;
    2. channel 0 (red), then channel 2 (blue): a first, green-guided pass
       fills one fixed parity class of pixels, and a second pass fills every
       pixel of the channel that is still 0.

    The result is finally clamped to ``[0, 1]``.

    Args:
        y: Mosaicked image; only ``y.shape[0]`` and ``y.shape[1]`` are read
            here to size the CFA operator.
        cfa: Name of the CFA requested by the caller. Only ``'bayer'`` is
            implemented: the pixel parities used by the guided passes are
            hard-coded, so the operator is always built for ``'bayer'`` and
            a warning is emitted when another name is passed.

    Returns:
        The array returned by ``op.adjoint(y)``, completed in place and
        clamped to ``[0, 1]``.
    """
    cfa_name = 'bayer'
    if cfa != cfa_name:
        # Previously the argument was silently ignored; keep the same result
        # but make the mismatch visible to the caller.
        warnings.warn(
            f"run_reconstruction only implements the '{cfa_name}' pattern; "
            f"requested CFA {cfa!r} is ignored.",
            stacklevel=2,
        )

    input_shape = (y.shape[0], y.shape[1], 3)
    op = CFA(cfa_name, input_shape)

    img_res = op.adjoint(y)
    n_rows, n_cols = img_res.shape[:2]

    def fill_guided(channel, row_start, col_start):
        """Fill ``channel`` on rows/columns of one parity, guided by green.

        Pixels are visited in row-major order and written in place, so
        values written earlier are visible to later neighbor lookups.
        """
        for i in range(row_start, n_rows, 2):
            for j in range(col_start, n_cols, 2):
                neighbors = find_Knearest_neighbors(
                    img_res, channel, i, j, n_rows, n_cols)
                neighbors_g = find_Knearest_neighbors(
                    img_res, 1, i, j, n_rows, n_cols)
                dir_deriv = calculate_directional_gradients(neighbors_g)
                # Weights are computed on the green channel (index 1), not
                # on the channel being filled.
                weights = calculate_adaptive_weights(
                    img_res, neighbors_g, dir_deriv, 1, i, j, n_rows, n_cols)
                img_res[i, j, channel] = interpolate_RedBlue(
                    neighbors, neighbors_g, weights)

    def fill_missing(channel):
        """Fill every pixel of ``channel`` that is currently exactly 0.

        NOTE(review): a genuine sample whose value is exactly 0 is treated
        as missing and overwritten as well.
        """
        for i in range(n_rows):
            for j in range(n_cols):
                if img_res[i, j, channel] != 0:
                    continue
                neighbors = find_Knearest_neighbors(
                    img_res, channel, i, j, n_rows, n_cols)
                dir_deriv = calculate_directional_gradients(neighbors)
                weights = calculate_adaptive_weights(
                    img_res, neighbors, dir_deriv, channel, i, j,
                    n_rows, n_cols)
                img_res[i, j, channel] = interpolate_pixel(neighbors, weights)

    # Interpolate each channel; green goes first because the guided red/blue
    # passes read it.
    fill_missing(1)          # green channel
    # Odd rows / even columns — presumably the blue sample sites; confirm
    # against the CFA mask.
    fill_guided(0, 1, 0)     # first pass on the red channel
    fill_missing(0)          # second pass on the red channel
    # Even rows / odd columns — presumably the red sample sites; confirm
    # against the CFA mask.
    fill_guided(2, 0, 1)     # first pass on the blue channel
    fill_missing(2)          # second pass on the blue channel

    # Clamp in place to the valid intensity range.
    np.clip(img_res, 0, 1, out=img_res)

    return img_res