"""A file containing some useful functions for the project.
This file should NOT be modified.
"""


import os

import numpy as np
from skimage.io import imread, imsave
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from src.checks import check_data_range, check_rgb, check_shape, check_path, check_png


def normalise_image(img: np.ndarray) -> np.ndarray:
    """Normalise the values of img in the interval [0, 1].

    The minimum of img is mapped to 0 and its maximum to 1. A constant image
    (maximum equal to minimum) is mapped to an all-zero image instead of
    dividing by zero.

    Args:
        img (np.ndarray): Image to normalise. Must not be empty.

    Returns:
        np.ndarray: Normalised image. Floating-point inputs keep their dtype,
            any other dtype yields a float64 array.

    Raises:
        ValueError: If img is empty (raised by np.min).
    """
    arr = np.asarray(img)

    # Integer arithmetic can wrap around when computing img - min or the value
    # range (e.g. int8 images spanning more than 127 levels), so non-floating
    # inputs are promoted to float64 before any subtraction.
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)

    low = np.min(arr)
    span = np.max(arr) - low

    # A constant image has no dynamic range: return zeros rather than NaN.
    if span == 0:
        return np.zeros_like(arr)

    return (arr - low) / span


def load_image(file_path: str) -> np.ndarray:
    """Reads the PNG image stored in file_path and normalises it.

    Args:
        file_path (str): Location of the image file. Must end with '.png'.

    Returns:
        np.ndarray: The image read from file_path, passed through normalise_image.
    """
    # Run both sanity checks on the path before touching the file.
    for check in (check_path, check_png):
        check(file_path)

    raw_img = imread(file_path)
    return normalise_image(raw_img)


def save_image(file_path: str, img: np.ndarray) -> None:
    """Saves img as an 8-bit image in file_path.

    Args:
        file_path (str): Path of the file in which the image will be saved. Must end with '.png'.
        img (np.ndarray): Image to save. Values are expected in [0, 1]; anything outside
            that interval is clipped before the conversion to 8 bits.
    """
    # Hand the complete parent directory to check_path (not just its last component),
    # and fall back to the current directory when file_path has no directory part.
    check_path(os.path.dirname(file_path) or '.')
    check_png(file_path)

    # Clipping prevents out-of-range values from wrapping around in the uint8 cast.
    imsave(file_path, (np.clip(img, 0, 1) * 255).astype(np.uint8))


def psnr(img1: np.ndarray, img2: np.ndarray) -> float:
    """Computes the PSNR between img1 and img2 after some sanity checks.
    img1 and img2 must:
        - have the same shape;
        - be in range [0, 1].

    Args:
        img1 (np.ndarray): First image.
        img2 (np.ndarray): Second image.

    Returns:
        float: PSNR between img1 and img2, computed with a data range of 1.
    """
    check_shape(img1, img2)

    # The range requirement applies to both images, so both are checked.
    for img in (img1, img2):
        check_data_range(img)

    return peak_signal_noise_ratio(img1, img2, data_range=1)


def ssim(img1: np.ndarray, img2: np.ndarray) -> float:
    """Computes the SSIM between img1 and img2 after some sanity checks.
    img1 and img2 must:
        - have the same shape;
        - be in range [0, 1];
        - be 3 dimensional array with 3 channels.

    Args:
        img1 (np.ndarray): First image.
        img2 (np.ndarray): Second image.

    Returns:
        float: SSIM between img1 and img2, computed with a data range of 1 and
            the last axis (axis 2) treated as the channel axis.
    """
    check_shape(img1, img2)

    # The range requirement applies to both images, so both are checked.
    for img in (img1, img2):
        check_data_range(img)

    check_rgb(img1)

    return structural_similarity(img1, img2, data_range=1, channel_axis=2)


####
####
####

####      ####                ####        #############
####      ######              ####      ##################
####      ########            ####      ####################
####      ##########          ####      ####        ########
####      ############        ####      ####            ####
####      ####  ########      ####      ####            ####
####      ####    ########    ####      ####            ####
####      ####      ########  ####      ####            ####
####      ####  ##    ######  ####      ####          ######
####      ####  ####      ##  ####      ####    ############
####      ####  ######        ####      ####    ##########
####      ####  ##########    ####      ####    ########
####      ####      ########  ####      ####
####      ####        ############      ####
####      ####          ##########      ####
####      ####            ########      ####
####      ####              ######      ####

# 2023
# Authors: Mauro Dalla Mura and Matthieu Muller