# forward_model.py
"""A file containing the forward operator.
This file should NOT be modified.
"""


import numpy as np

from src.checks import check_cfa, check_rgb


class CFA:
    """Forward operator of a color filter array (CFA).

    The instance stores a mask built for the requested pattern. ``direct``
    multiplies an image by that mask and sums over axis 2; ``adjoint``
    multiplies the mask by its input expanded with a trailing axis.
    """

    def __init__(self, cfa: str, input_shape: tuple) -> None:
        """Constructor of the forward operator's class.

        Args:
            cfa (str): Name of the pattern. Either bayer or quad_bayer.
            input_shape (tuple): Shape of the input images of the operator.
        """
        # Validation is delegated to the project helper from src.checks.
        check_cfa(cfa)

        self.cfa = cfa
        self.input_shape = input_shape
        # The output shape is the input shape without its last entry.
        self.output_shape = input_shape[:-1]

        # Pattern names paired with the function that builds their mask;
        # the first matching name wins, and no mask is set if none matches.
        mask_builders = (
            ('bayer', get_bayer_mask),
            ('quad_bayer', get_quad_bayer_mask),
        )
        for pattern, build_mask in mask_builders:
            if self.cfa == pattern:
                self.mask = build_mask(input_shape)
                break

    def direct(self, x: np.ndarray) -> np.ndarray:
        """Applies the CFA operation to the image x.

        Args:
            x (np.ndarray): Input image.

        Returns:
            np.ndarray: Output image, i.e. the masked input summed over axis 2.
        """
        # Validation is delegated to the project helper from src.checks.
        check_rgb(x)

        # Keep only the entries selected by the mask, then collapse axis 2.
        masked = x * self.mask
        return np.sum(masked, axis=2)

    def adjoint(self, y: np.ndarray) -> np.ndarray:
        """Applies the adjoint of the CFA operation.

        Args:
            y (np.ndarray): Input image.

        Returns:
            np.ndarray: Output image, i.e. the mask scaled by y along a new
                trailing axis.
        """
        # Append a trailing axis so that y broadcasts against the mask.
        expanded = y[..., np.newaxis]
        return self.mask * expanded


def get_bayer_mask(input_shape: tuple) -> np.ndarray:
    """Return the mask of the Bayer CFA.

    The mask has three channels per pixel. Every pixel starts as
    ``[0, 1, 0]``; pixels on even rows and odd columns become ``[1, 0, 0]``
    and pixels on odd rows and even columns become ``[0, 0, 1]``.

    Args:
        input_shape (tuple): Shape of the mask. Only the first two entries
            are used; the last axis of the result always has length 3.

    Returns:
        np.ndarray: Mask.
    """
    n_rows, n_cols = input_shape[0], input_shape[1]

    # Start with the middle channel selected everywhere.
    mask = np.zeros((n_rows, n_cols, 3))
    mask[..., 1] = 1

    # Overwrite the two remaining sites of each 2x2 tile.
    mask[::2, 1::2] = (1, 0, 0)
    mask[1::2, ::2] = (0, 0, 1)

    return mask


def get_quad_bayer_mask(input_shape: tuple) -> np.ndarray:
    """Return the mask of the quad_bayer CFA.

    The mask has three channels per pixel and repeats with period 4 along
    both image axes. Every pixel starts as ``[0, 1, 0]``; within each 4x4
    tile, rows 0-1 / columns 2-3 become ``[1, 0, 0]`` and rows 2-3 /
    columns 0-1 become ``[0, 0, 1]``.

    Args:
        input_shape (tuple): Shape of the mask. Only the first two entries
            are used; the last axis of the result always has length 3.

    Returns:
        np.ndarray: Mask.
    """
    n_rows, n_cols = input_shape[0], input_shape[1]

    # Start with the middle channel selected everywhere.
    mask = np.zeros((n_rows, n_cols, 3))
    mask[..., 1] = 1

    # Top-right 2x2 quadrant of every 4x4 tile.
    for row_offset in (0, 1):
        for col_offset in (2, 3):
            mask[row_offset::4, col_offset::4] = (1, 0, 0)

    # Bottom-left 2x2 quadrant of every 4x4 tile.
    for row_offset in (2, 3):
        for col_offset in (0, 1):
            mask[row_offset::4, col_offset::4] = (0, 0, 1)

    return mask


####
####
####

####      ####                ####        #############
####      ######              ####      ##################
####      ########            ####      ####################
####      ##########          ####      ####        ########
####      ############        ####      ####            ####
####      ####  ########      ####      ####            ####
####      ####    ########    ####      ####            ####
####      ####      ########  ####      ####            ####
####      ####  ##    ######  ####      ####          ######
####      ####  ####      ##  ####      ####    ############
####      ####  ######        ####      ####    ##########
####      ####  ##########    ####      ####    ########
####      ####      ########  ####      ####
####      ####        ############      ####
####      ####          ##########      ####
####      ####            ########      ####
####      ####              ######      ####

# 2023
# Authors: Mauro Dalla Mura and Matthieu Muller