import numpy as np
from scipy.signal import correlate2d

from src.forward_model import CFA


def malvar_he_cutler(y: np.ndarray, op: CFA) -> np.ndarray:
    """Perform demosaicing using the Malvar-He-Cutler algorithm.

    Args:
        y (np.ndarray): Mosaicked image. It is filtered with
            ``scipy.signal.correlate2d``, so it must be 2-D.
        op (CFA): CFA operator. ``op.mask[:, :, 0]``, ``[:, :, 1]`` and
            ``[:, :, 2]`` are used as the red, green and blue sampling masks,
            ``op.input_shape`` is the shape of the returned image, and
            ``op.cfa == 'quad_bayer'`` selects the 10x10 quad-Bayer kernels
            (any other value uses the 5x5 Bayer kernels).

    Returns:
        np.ndarray: Demosaicked image of shape ``op.input_shape`` (float64,
        the ``np.empty`` default) whose channels 0, 1, 2 hold red, green, blue.
    """
    red_mask = op.mask[:, :, 0]
    green_mask = op.mask[:, :, 1]
    blue_mask = op.mask[:, :, 2]
    mosaicked_image = np.float32(y)
    demosaicked_image = np.empty(op.input_shape)

    if op.cfa == 'quad_bayer':
        filters = get_quad_bayer_filters()
    else:
        filters = get_default_filters()

    return apply_demosaicking_filters(
        mosaicked_image, demosaicked_image, red_mask, green_mask, blue_mask, filters
    )


# Un-normalised 5x5 Malvar-He-Cutler kernels. The taps of each kernel sum to 8;
# the Bayer filters scale them by 1/8 and the quad-Bayer filters (every tap
# replicated into a 2x2 block, so the taps sum to 32) by 1/32, so every
# normalised kernel sums to 1. Private: the public getters always return new
# arrays, so callers never alias these.
_MHC_BASE_KERNELS = {
    "G_at_R_and_B": np.array([
        [0, 0, -1, 0, 0],
        [0, 0, 2, 0, 0],
        [-1, 2, 4, 2, -1],
        [0, 0, 2, 0, 0],
        [0, 0, -1, 0, 0],
    ]),
    "R_at_GR_and_B_at_GB": np.array([
        [0, 0, 0.5, 0, 0],
        [0, -1, 0, -1, 0],
        [-1, 4, 5, 4, -1],
        [0, -1, 0, -1, 0],
        [0, 0, 0.5, 0, 0],
    ]),
    "R_at_GB_and_B_at_GR": np.array([
        [0, 0, -1, 0, 0],
        [0, -1, 4, -1, 0],
        [0.5, 0, 5, 0, 0.5],
        [0, -1, 4, -1, 0],
        [0, 0, -1, 0, 0],
    ]),
    "R_at_B_and_B_at_R": np.array([
        [0, 0, -1.5, 0, 0],
        [0, 2, 0, 2, 0],
        [-1.5, 0, 6, 0, -1.5],
        [0, 2, 0, 2, 0],
        [0, 0, -1.5, 0, 0],
    ]),
}


def get_quad_bayer_filters():
    """Return the Malvar-He-Cutler kernels upsampled for a quad-Bayer CFA.

    Each tap of the 5x5 Bayer kernel is replicated into a 2x2 block
    (``np.kron``), giving a 10x10 kernel, which is then scaled by 1/32.

    Returns:
        dict[str, np.ndarray]: Newly allocated 10x10 float64 kernels keyed by
        ``"G_at_R_and_B"``, ``"R_at_GR_and_B_at_GB"``,
        ``"R_at_GB_and_B_at_GR"`` and ``"R_at_B_and_B_at_R"``.
    """
    coefficient_scale = 0.03125  # 1/32: four times as many taps as the 5x5 kernels
    tap_block = np.ones((2, 2))
    return {
        name: np.kron(kernel, tap_block) * coefficient_scale
        for name, kernel in _MHC_BASE_KERNELS.items()
    }


def get_default_filters():
    """Return the 5x5 Malvar-He-Cutler kernels for a standard Bayer CFA.

    Returns:
        dict[str, np.ndarray]: Newly allocated 5x5 float64 kernels keyed by
        ``"G_at_R_and_B"``, ``"R_at_GR_and_B_at_GB"``,
        ``"R_at_GB_and_B_at_GR"`` and ``"R_at_B_and_B_at_R"``.
    """
    coefficient_scale = 0.125  # 1/8
    return {
        name: kernel * coefficient_scale
        for name, kernel in _MHC_BASE_KERNELS.items()
    }


def apply_demosaicking_filters(image, res, red_mask, green_mask, blue_mask, filters):
    """Fill ``res`` with full red, green and blue planes interpolated from ``image``.

    Args:
        image (np.ndarray): 2-D mosaicked image.
        res (np.ndarray): Output array indexed as ``res[:, :, c]``; channels
            0, 1 and 2 are overwritten in place with red, green and blue.
        red_mask, green_mask, blue_mask (np.ndarray): 2-D masks with the same
            shape as ``image``; entries equal to 1 mark the sites where that
            colour was sampled.
        filters (dict[str, np.ndarray]): Kernels keyed like the dicts returned
            by ``get_default_filters`` / ``get_quad_bayer_filters``.

    Returns:
        np.ndarray: ``res`` itself (same object that was passed in).

    Notes:
        Green sites are classified by whether their row holds red or blue
        samples and whether their column does. This assumes no row or column
        holds both (true for Bayer and quad-Bayer layouts; TODO confirm for
        any other CFA). Interpolated values are not clipped, and the kernels
        have negative taps, so outputs may leave the input range.
        NOTE(review): even-sized kernels (the 10x10 quad-Bayer ones) have no
        single centre tap; confirm the alignment ``mode="same"`` picks for
        them is the intended one.
    """
    def correlate(key):
        # "symm" mirrors the image at its borders to avoid dark edge artefacts.
        return correlate2d(image, filters[key], mode="same", boundary="symm")

    red_sites = red_mask == 1
    blue_sites = blue_mask == 1

    # Row / column indicators shaped (H, 1) and (1, W) so `&` broadcasts to (H, W).
    red_rows = np.any(red_sites, axis=1)[:, np.newaxis]
    red_cols = np.any(red_sites, axis=0)[np.newaxis, :]
    blue_rows = np.any(blue_sites, axis=1)[:, np.newaxis]
    blue_cols = np.any(blue_sites, axis=0)[np.newaxis, :]

    # The red and blue planes share the same estimators, so each full-image
    # correlation is computed once rather than once per channel.
    g_at_r_and_b = correlate('G_at_R_and_B')
    r_at_gr_b_at_gb = correlate('R_at_GR_and_B_at_GB')
    r_at_gb_b_at_gr = correlate('R_at_GB_and_B_at_GR')
    r_at_b_b_at_r = correlate('R_at_B_and_B_at_R')

    # Start from the measured samples (zero elsewhere, given 0/1 masks).
    red_channel = image * red_mask
    green_channel = image * green_mask
    blue_channel = image * blue_mask

    # Green at every red or blue site.
    green_channel = np.where(red_sites | blue_sites, g_at_r_and_b, green_channel)

    # Red/blue at green sites: the kernel depends on whether the green pixel
    # lies in a red row (blue column) or a blue row (red column).
    red_channel = np.where(red_rows & blue_cols, r_at_gr_b_at_gb, red_channel)
    red_channel = np.where(blue_rows & red_cols, r_at_gb_b_at_gr, red_channel)
    blue_channel = np.where(blue_rows & red_cols, r_at_gr_b_at_gb, blue_channel)
    blue_channel = np.where(red_rows & blue_cols, r_at_gb_b_at_gr, blue_channel)

    # Red at blue sites (blue row & blue column) and blue at red sites.
    red_channel = np.where(blue_rows & blue_cols, r_at_b_b_at_r, red_channel)
    blue_channel = np.where(red_rows & red_cols, r_at_b_b_at_r, blue_channel)

    res[:, :, 0] = red_channel
    res[:, :, 1] = green_channel
    res[:, :, 2] = blue_channel
    return res