Skip to content
Snippets Groups Projects
malvar.py 5.82 KiB
Newer Older
Theresa El Murr's avatar
Theresa El Murr committed
import numpy as np
from scipy.signal import correlate2d
from src.forward_model import CFA

def malvar_he_cutler(y: np.ndarray, op: CFA) -> np.ndarray:
    """Demosaick a mosaicked image with the Malvar-He-Cutler linear filters.

    Args:
        y (np.ndarray): Mosaicked image (converted to float32 before filtering).
        op (CFA): CFA operator. Its ``mask`` (indexed as ``[:, :, channel]``
            for R, G, B), ``input_shape`` and ``cfa`` attributes are read.

    Returns:
        np.ndarray: Demosaicked image, allocated with shape ``op.input_shape``.
    """
    # The 'quad_bayer' layout needs the 2x2-expanded kernels; every other
    # layout uses the classic 5x5 Bayer kernels.
    if op.cfa == 'quad_bayer':
        kernels = get_quad_bayer_filters()
    else:
        kernels = get_default_filters()

    sampling = op.mask
    output_buffer = np.empty(op.input_shape)

    return apply_demosaicking_filters(
        np.float32(y),
        output_buffer,
        sampling[:, :, 0],
        sampling[:, :, 1],
        sampling[:, :, 2],
        kernels,
    )

def get_quad_bayer_filters():
    """Return the Malvar-He-Cutler kernels adapted to a Quad-Bayer CFA.

    Each kernel is the 5x5 Bayer kernel with every tap replicated into a 2x2
    block of identical taps (giving a 10x10 kernel), scaled by 1/32 so that
    the taps of every kernel sum to 1.

    NOTE(review): these kernels have even size. When applied with
    ``correlate2d(..., mode="same")`` an even-sized kernel cannot be centred
    on a single pixel, so for one of the two positions along each axis of a
    2x2 quad cell the central 2x2 tap block straddles into the neighbouring
    cell — confirm this alignment is intended.

    Returns:
        dict[str, np.ndarray]: float64 10x10 kernels keyed by
        "G_at_R_and_B", "R_at_GR_and_B_at_GB", "R_at_GB_and_B_at_GR" and
        "R_at_B_and_B_at_R".
    """
    bayer_taps = {
        "G_at_R_and_B": [
            [0, 0, -1, 0, 0],
            [0, 0, 2, 0, 0],
            [-1, 2, 4, 2, -1],
            [0, 0, 2, 0, 0],
            [0, 0, -1, 0, 0],
        ],
        "R_at_GR_and_B_at_GB": [
            [0, 0, 0.5, 0, 0],
            [0, -1, 0, -1, 0],
            [-1, 4, 5, 4, -1],
            [0, -1, 0, -1, 0],
            [0, 0, 0.5, 0, 0],
        ],
        "R_at_GB_and_B_at_GR": [
            [0, 0, -1, 0, 0],
            [0, -1, 4, -1, 0],
            [0.5, 0, 5, 0, 0.5],
            [0, -1, 4, -1, 0],
            [0, 0, -1, 0, 0],
        ],
        "R_at_B_and_B_at_R": [
            [0, 0, -1.5, 0, 0],
            [0, 2, 0, 2, 0],
            [-1.5, 0, 6, 0, -1.5],
            [0, 2, 0, 2, 0],
            [0, 0, -1.5, 0, 0],
        ],
    }
    # Replicating each tap 4 times multiplies the tap sum (8) by 4, hence 1/32.
    quad_scale = 0.03125
    replicate_2x2 = np.ones((2, 2))
    return {
        key: np.kron(np.array(taps), replicate_2x2) * quad_scale
        for key, taps in bayer_taps.items()
    }

def get_default_filters():
    """Return the 5x5 Malvar-He-Cutler demosaicking kernels for a Bayer CFA.

    The raw integer/half-integer taps of every kernel sum to 8; scaling by
    1/8 normalises each kernel to unit gain.

    Returns:
        dict[str, np.ndarray]: float64 5x5 kernels keyed by
        "G_at_R_and_B", "R_at_GR_and_B_at_GB", "R_at_GB_and_B_at_GR" and
        "R_at_B_and_B_at_R".
    """
    unit_gain = 0.125
    raw_taps = {
        "G_at_R_and_B": [
            [0, 0, -1, 0, 0],
            [0, 0, 2, 0, 0],
            [-1, 2, 4, 2, -1],
            [0, 0, 2, 0, 0],
            [0, 0, -1, 0, 0],
        ],
        "R_at_GR_and_B_at_GB": [
            [0, 0, 0.5, 0, 0],
            [0, -1, 0, -1, 0],
            [-1, 4, 5, 4, -1],
            [0, -1, 0, -1, 0],
            [0, 0, 0.5, 0, 0],
        ],
        "R_at_GB_and_B_at_GR": [
            [0, 0, -1, 0, 0],
            [0, -1, 4, -1, 0],
            [0.5, 0, 5, 0, 0.5],
            [0, -1, 4, -1, 0],
            [0, 0, -1, 0, 0],
        ],
        "R_at_B_and_B_at_R": [
            [0, 0, -1.5, 0, 0],
            [0, 2, 0, 2, 0],
            [-1.5, 0, 6, 0, -1.5],
            [0, 2, 0, 2, 0],
            [0, 0, -1.5, 0, 0],
        ],
    }
    return {key: np.array(taps) * unit_gain for key, taps in raw_taps.items()}

def apply_demosaicking_filters(image, res, red_mask, green_mask, blue_mask, filters):
    """Interpolate the missing colour samples and write R, G, B into ``res``.

    Each missing sample is obtained by correlating the whole mosaicked image
    with one kernel from ``filters`` (symmetric boundary padding) and taking
    the response at the pixels where that kernel applies. Pixels that were
    actually sampled for a channel keep their mosaic value.

    Args:
        image (np.ndarray): 2-D mosaicked image.
        res (np.ndarray): Output buffer; planes ``res[:, :, 0]``,
            ``res[:, :, 1]`` and ``res[:, :, 2]`` are overwritten in place.
            Presumably shaped (H, W, 3) — confirm against callers.
        red_mask, green_mask, blue_mask (np.ndarray): Per-pixel sampling
            masks with the same 2-D shape as ``image``; a pixel counts as
            sampled where the mask equals 1.
        filters (dict[str, np.ndarray]): Kernels keyed by "G_at_R_and_B",
            "R_at_GR_and_B_at_GB", "R_at_GB_and_B_at_GR" and
            "R_at_B_and_B_at_R".

    Returns:
        np.ndarray: ``res`` itself, filled with the R, G, B planes.
    """
    def _correlate(filter_key):
        # Response of one kernel at every pixel of the mosaic.
        return correlate2d(image, filters[filter_key], mode="same", boundary="symm")

    def _fill(channel, rows, cols, values):
        # Take `values` wherever the row mask and column mask both hold.
        return np.where(np.logical_and(rows, cols), values, channel)

    # Start each channel from the samples actually present in the mosaic.
    red_channel = image * red_mask
    green_channel = image * green_mask
    blue_channel = image * blue_mask

    # Green is interpolated at every red or blue sample position.
    green_channel = np.where(
        np.logical_or(red_mask == 1, blue_mask == 1),
        _correlate("G_at_R_and_B"),
        green_channel,
    )

    # Each red/blue kernel is used for both the red and the blue channel, so
    # correlate once per kernel instead of once per use (correlate2d is the
    # expensive step, especially for the 10x10 quad-Bayer kernels).
    gr_gb_response = _correlate("R_at_GR_and_B_at_GB")
    gb_gr_response = _correlate("R_at_GB_and_B_at_GR")
    b_r_response = _correlate("R_at_B_and_B_at_R")

    # Rows / columns containing at least one red (resp. blue) sample, shaped
    # (H, 1) and (1, W) so a row mask AND a column mask broadcast to (H, W).
    red_rows = np.any(red_mask == 1, axis=1)[:, np.newaxis]
    red_cols = np.any(red_mask == 1, axis=0)[np.newaxis, :]
    blue_rows = np.any(blue_mask == 1, axis=1)[:, np.newaxis]
    blue_cols = np.any(blue_mask == 1, axis=0)[np.newaxis, :]

    # Red and blue at positions sharing a row/column with the other colour.
    red_channel = _fill(red_channel, red_rows, blue_cols, gr_gb_response)
    red_channel = _fill(red_channel, blue_rows, red_cols, gb_gr_response)

    blue_channel = _fill(blue_channel, blue_rows, red_cols, gr_gb_response)
    blue_channel = _fill(blue_channel, red_rows, blue_cols, gb_gr_response)

    # Red at blue positions and blue at red positions.
    red_channel = _fill(red_channel, blue_rows, blue_cols, b_r_response)
    blue_channel = _fill(blue_channel, red_rows, red_cols, b_r_response)

    res[:, :, 0] = red_channel
    res[:, :, 1] = green_channel
    res[:, :, 2] = blue_channel
    return res