import numpy as np
from .find_neighbors import find_neighbors
from .find_direct_neighbors import find_direct_neighbors

def find_weights(img: np.ndarray, direct_neighbors: list, channel: int, i: int, j: int, N: int, M: int) -> list:
    """
    Compute one weight for each cell of the 3x3 window centred on pixel (i, j).

    The window is walked in row-major order (row offset -1, 0, 1; for each,
    column offset -1, 0, 1). For every cell, ``find_neighbors`` and
    ``find_direct_neighbors`` are called on the shifted position. Only the
    first four cells combine the centre pixel's values with the shifted
    pixel's values; the remaining five cells use a radicand of 1.

    Args:
        img (np.ndarray): The image to process (forwarded to ``find_neighbors``).
        direct_neighbors (list): Exactly four values for the centre pixel,
            unpacked in the order (Dx, Dy, Dxx, Dyy).
        channel (int): The index of the channel to process.
        i (int): The row index of the pixel.
        j (int): The column index of the pixel.
        N (int): Height of the image.
        M (int): Width of the image.

    Returns:
        list: Nine weights in row-major window order. Each entry is ``0``
        when its radicand is negative, otherwise ``1 / np.sqrt(radicand)``.
    """
    centre_dx, centre_dy, centre_dxx, centre_dyy = direct_neighbors

    # Row-major offsets; the 1-based position in this list selects the formula.
    offsets = [(row_shift, col_shift) for row_shift in (-1, 0, 1) for col_shift in (-1, 0, 1)]

    weights = []
    for position, (row_shift, col_shift) in enumerate(offsets, start=1):
        window = find_neighbors(img, channel, i + row_shift, j + col_shift, N, M)
        shifted = find_direct_neighbors(window)

        # All four radicands are evaluated on every pass, even when unused.
        # NOTE(review): the `* 2` factors may have been intended as squares
        # (`** 2`) -- confirm against the reference algorithm.
        radicand_dyy = 1 + centre_dyy * 2 + shifted[3] * 2
        radicand_dxx = 1 + centre_dxx * 2 + shifted[2] * 2
        radicand_dy = 1 + centre_dy * 2 + shifted[1] * 2
        radicand_dx = 1 + centre_dx * 2 + shifted[0] * 2

        # Positions 1..4 map to Dyy, Dy, Dxx, Dx respectively.
        # NOTE(review): positions 5..9 always fall back to a radicand of 1 --
        # verify that this is the intended mapping.
        by_position = (radicand_dyy, radicand_dy, radicand_dxx, radicand_dx)
        radicand = by_position[position - 1] if position <= len(by_position) else 1

        # A negative radicand has no real square root, so its weight is 0.
        # NOTE(review): a radicand of exactly 0 is not guarded here.
        weights.append(0 if radicand < 0 else 1 / np.sqrt(radicand))

    return weights