{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main notebook for experimenting with the algorithms\n",
    "\n",
    "Here is a simple notebook to test your code. You can modify it as you please.\n",
    "\n",
    "Remember to restart the Jupyter kernel each time you modify a file, since already imported modules are not reloaded automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from src.utils import load_image, save_image, psnr, ssim\n",
    "from src.forward_model import CFA\n",
    "from src.methods.RHOUCH_Oussama.reconstruct import run_reconstruction"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the input image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'images/img_4.png'\n",
    "\n",
    "img = load_image(image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shows some information on the image\n",
    "plt.imshow(img)\n",
    "plt.show()\n",
    "print(f'Shape of the image: {img.shape}.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Definition of the forward model\n",
    "\n",
    "To set up the forward operator, we just need to instantiate the `CFA` class. This class takes two arguments: `cfa_name`, the kind of pattern (bayer or kodak), and `input_shape`, the shape of the operator's inputs.\n",
    "\n",
    "This operator is linear and can therefore be represented by a matrix $A$, but this matrix is not directly accessible (one can build it if needed). Instead, the method `direct` applies $A$, and the method `adjoint` applies $A^T$.\n",
    "\n",
    "For example, let $X \\in \\mathbb R^{M \\times N \\times 3}$ be the input RGB image in its natural shape and $x \\in \\mathbb R^{3MN}$ its vectorized version. Then $A \\in \\mathbb R^{MN \\times 3MN}$, which leads to:\n",
    "\n",
    "\\begin{equation*}\n",
    "    y = Ax \\in \\mathbb R^{MN} \\quad \\text{and} \\quad z = A^Ty  \\in \\mathbb R^{3MN}\n",
    "\\end{equation*}\n",
    "\n",
    "Thanks to `direct` and `adjoint`, however, there is no need to work with vectorized images, unless you want to build the matrix $A$ explicitly."
   ]
  },
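  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is a minimal, self-contained sketch that makes the link between $A$, `direct` and `adjoint` concrete. It does **not** use the `CFA` class. Instead, it assumes a hypothetical RGGB Bayer layout and defines toy functions (`toy_bayer_mask`, `toy_direct`, `toy_adjoint` and `toy_matrix`, all hypothetical names). These functions mask the image and sum over channels, which is one plausible way to implement $A$. On a tiny $4 \\times 6$ image, the cell builds $A$ explicitly and checks that this $A \\in \\mathbb R^{MN \\times 3MN}$ agrees with the functions. It then runs the dot-product test $\\langle Ax, y \\rangle = \\langle x, A^T y \\rangle$, which any correct adjoint must pass. The same test could be run on the real operator with `op.direct` and `op.adjoint`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy stand-in for the CFA operator (hypothetical, not src.forward_model.CFA).\n",
    "# Assumption: RGGB Bayer layout with exactly one active channel per pixel.\n",
    "def toy_bayer_mask(M, N):\n",
    "    mask = np.zeros((M, N, 3))\n",
    "    mask[0::2, 0::2, 0] = 1  # red: even rows, even columns\n",
    "    mask[0::2, 1::2, 1] = 1  # green: even rows, odd columns\n",
    "    mask[1::2, 0::2, 1] = 1  # green: odd rows, even columns\n",
    "    mask[1::2, 1::2, 2] = 1  # blue: odd rows, odd columns\n",
    "    return mask\n",
    "\n",
    "def toy_direct(x, mask):\n",
    "    # A x: keep the sampled channel of each pixel, giving an (M, N) image\n",
    "    return np.sum(mask * x, axis=2)\n",
    "\n",
    "def toy_adjoint(y, mask):\n",
    "    # A^T y: put each measurement back into its channel, zeros elsewhere\n",
    "    return mask * y[:, :, np.newaxis]\n",
    "\n",
    "def toy_matrix(mask):\n",
    "    # Explicit A (MN x 3MN) acting on x = X.reshape(-1) in row-major order:\n",
    "    # row i holds the 3 mask values of pixel i in columns 3i, 3i+1 and 3i+2\n",
    "    M, N, C = mask.shape\n",
    "    A = np.zeros((M * N, M * N * C))\n",
    "    A[np.repeat(np.arange(M * N), C), np.arange(M * N * C)] = mask.reshape(-1)\n",
    "    return A\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "mask_toy = toy_bayer_mask(4, 6)\n",
    "x_toy = rng.random((4, 6, 3))  # plays the role of X\n",
    "y_toy = rng.random((4, 6))     # arbitrary element of the output space\n",
    "A_toy = toy_matrix(mask_toy)\n",
    "\n",
    "print(A_toy.shape)  # (24, 72), i.e. MN x 3MN\n",
    "# The explicit matrix and the toy functions agree\n",
    "print(np.allclose(A_toy @ x_toy.reshape(-1), toy_direct(x_toy, mask_toy).reshape(-1)))\n",
    "print(np.allclose(A_toy.T @ y_toy.reshape(-1), toy_adjoint(y_toy, mask_toy).reshape(-1)))\n",
    "# Dot-product test: <Ax, y> == <x, A^T y>\n",
    "print(np.isclose(np.sum(toy_direct(x_toy, mask_toy) * y_toy),\n",
    "                 np.sum(x_toy * toy_adjoint(y_toy, mask_toy))))"
   ]
  },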
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfa_name = 'bayer' # bayer or quad_bayer\n",
    "op = CFA(cfa_name, img.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shows the mask\n",
    "plt.imshow(op.mask[:10, :10])\n",
    "plt.show()\n",
    "print(f'Shape of the mask: {op.mask.shape}.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Applies the mask to the image\n",
    "y = op.direct(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Applies the adjoint operation to y\n",
    "z = op.adjoint(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(2, 3, figsize=(15, 10))\n",
    "axs[0, 0].imshow(img)\n",
    "axs[0, 0].set_title('Input image')\n",
    "axs[0, 1].imshow(y, cmap='gray')\n",
    "axs[0, 1].set_title('Output image')\n",
    "axs[0, 2].imshow(z)\n",
    "axs[0, 2].set_title('Adjoint image')\n",
    "axs[1, 0].imshow(img[800:864, 450:514])\n",
    "axs[1, 0].set_title('Zoomed input image')\n",
    "axs[1, 1].imshow(y[800:864, 450:514], cmap='gray')\n",
    "axs[1, 1].set_title('Zoomed output image')\n",
    "axs[1, 2].imshow(z[800:864, 450:514])\n",
    "axs[1, 2].set_title('Zoomed adjoint image')\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the reconstruction\n",
    "\n",
    "Here the goal is to reconstruct the image `img` using only `y` and `op` (using `img` is forbidden).\n",
    "\n",
    "To run the reconstruction, we simply call the function `run_reconstruction`. It takes as arguments the acquired (mosaicked) image `y` and the kind of CFA used (`cfa_name`). All parameters related to the reconstruction itself must be set inside `run_reconstruction`."
   ]
  },
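  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell illustrates what a reconstruction from `y` alone can look like. It sketches a classical baseline: bilinear demosaicing written as a normalized convolution, where each missing value of a channel is a weighted average of the known samples of that channel in its $3 \\times 3$ neighborhood. This is a hypothetical example, not the method implemented in `run_reconstruction`. The helper names (`toy_bayer_mask`, `conv3x3`, `bilinear_demosaic`) are made up, and the kernels assume an RGGB Bayer pattern (other patterns may need larger kernels). Dividing by the convolved mask keeps the known samples unchanged and handles the image borders without special cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical baseline (bilinear demosaicing), NOT the method in run_reconstruction.\n",
    "def toy_bayer_mask(M, N):\n",
    "    # Assumed RGGB layout as a binary (M, N, 3) array\n",
    "    mask = np.zeros((M, N, 3))\n",
    "    mask[0::2, 0::2, 0] = 1\n",
    "    mask[0::2, 1::2, 1] = 1\n",
    "    mask[1::2, 0::2, 1] = 1\n",
    "    mask[1::2, 1::2, 2] = 1\n",
    "    return mask\n",
    "\n",
    "def conv3x3(image, kernel):\n",
    "    # 'Same'-size 3x3 filtering with zero padding (a convolution here, since\n",
    "    # the kernels used below are symmetric)\n",
    "    M, N = image.shape\n",
    "    padded = np.pad(image, 1, mode='constant')\n",
    "    out = np.zeros((M, N))\n",
    "    for di in range(3):\n",
    "        for dj in range(3):\n",
    "            out += kernel[di, dj] * padded[di:di + M, dj:dj + N]\n",
    "    return out\n",
    "\n",
    "def bilinear_demosaic(y, mask):\n",
    "    # y: (M, N) mosaicked image, mask: (M, N, 3) binary Bayer mask\n",
    "    kernel_g = np.array([[0, 1, 0], [1, 4, 1], [0, 1, 0]], dtype=float)\n",
    "    kernel_rb = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=float)\n",
    "    kernels = [kernel_rb, kernel_g, kernel_rb]\n",
    "    rec = np.zeros(mask.shape)\n",
    "    for c in range(3):\n",
    "        samples = conv3x3(mask[:, :, c] * y, kernels[c])  # weighted sum of known samples\n",
    "        weights = conv3x3(mask[:, :, c], kernels[c])      # total weight of those samples\n",
    "        rec[:, :, c] = np.divide(samples, weights, out=np.zeros_like(samples), where=weights > 0)\n",
    "    return rec\n",
    "\n",
    "# Sanity check: a constant-color image is recovered (up to rounding)\n",
    "img_toy = np.ones((6, 8, 3)) * np.array([0.2, 0.5, 0.8])\n",
    "mask_toy = toy_bayer_mask(6, 8)\n",
    "y_toy = np.sum(mask_toy * img_toy, axis=2)  # toy mosaicking: mask, then sum over channels\n",
    "print(np.allclose(bilinear_demosaic(y_toy, mask_toy), img_toy))\n",
    "\n",
    "# Possible use in this notebook, assuming op.mask is an (M, N, 3) binary array:\n",
    "# res_baseline = bilinear_demosaic(y, op.mask)"
   ]
  },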
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = run_reconstruction(y, cfa_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prints some information on the reconstruction\n",
    "print(f'Shape of the reconstruction: {res.shape}.')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantitative and qualitative results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2, figsize=(15, 15))\n",
    "axs[0].imshow(img)\n",
    "axs[0].set_title('Original image')\n",
    "axs[1].imshow(res)\n",
    "axs[1].set_title('Reconstructed image')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computes some metrics\n",
    "print(f'PSNR: {psnr(img, res):.2f}')\n",
    "print(f'SSIM: {ssim(img, res):.4f}')"
   ]
  },
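  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the PSNR printed above is commonly defined as $\\mathrm{PSNR} = 10 \\log_{10}(d^2 / \\mathrm{MSE})$, where $d$ is the dynamic range of the pixel values and MSE is the mean squared error between the two images. Higher is better. The next cell is a minimal sketch of this textbook definition under the hypothetical name `psnr_sketch`, assuming images scaled to $[0, 1]$ (so $d = 1$). The `psnr` function from `src.utils` may use a different range or implementation, so the two values need not coincide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def psnr_sketch(reference, estimate, data_range=1.0):\n",
    "    # PSNR in dB: 10 * log10(data_range**2 / MSE); data_range=1.0 assumes values in [0, 1]\n",
    "    diff = np.asarray(reference, dtype=float) - np.asarray(estimate, dtype=float)\n",
    "    mse = np.mean(diff ** 2)\n",
    "    if mse == 0:\n",
    "        return float('inf')  # identical images\n",
    "    return 10 * np.log10(data_range ** 2 / mse)\n",
    "\n",
    "# A uniform error of 0.1 gives MSE = 0.01, hence a PSNR of about 20 dB\n",
    "print(psnr_sketch(np.zeros((4, 4, 3)), np.full((4, 4, 3), 0.1)))\n",
    "\n",
    "# Possible comparison in this notebook (hypothetical):\n",
    "# print(psnr_sketch(img, res), psnr(img, res))"
   ]
  },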
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
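    "import os\n",
    "\n",
    "reconstructed_path = 'output/reconstructed_image.png'\n",
    "os.makedirs(os.path.dirname(reconstructed_path), exist_ok=True)  # create output/ if missing\n",
    "\n",
    "save_image(reconstructed_path, res)"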
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}