#!/usr/bin/env python

import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, Slider
import skimage as ski
import cv2
import scipy.ndimage as nd
from multiprocessing import Pool

# Variable Globale
CORES = os.cpu_count() - 10
# Dossiers d'entree et de sortie
INPUT_DIR = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej_S00017_to_S00504_480x480_gpu_1modes_200DM500ML_recons_S/00_crop_data/"
OUTPUT_DIR = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej_S00017_to_S00504_480x480_gpu_1modes_200DM500ML_recons_S/01_Watershed_Segmentation/"

# Output Dirs
OUTPUT_DIRS = {
    "borderlow": os.path.join(OUTPUT_DIR, "borderlow"),
    "imclean": os.path.join(OUTPUT_DIR, "imclean"),
    "lumiere": os.path.join(OUTPUT_DIR, "lumiere"),
    "borderfill": os.path.join(OUTPUT_DIR, "borderfill"),
    "collar_mask": os.path.join(OUTPUT_DIR, "collar_mask"),
}


def process_image(image_name: str) -> None:
    """
    Segmentation de la lumière tubulaire par watershed.

    Parameters:
    -----------
        image_name : str
            Nom de l'image à segmenter

    Returns:
    --------
        None

    """
    print(f"Traitement de l'image: {image_name}")
    # Chargement des images
    image_path = os.path.join(INPUT_DIR, image_name)
    image = Image.open(image_path)

    # Convert to array
    im_array = np.array(image)

    # Détection des contours avec le seuil de Yen
    border = ski.filters.threshold_yen(im_array)
    border = image > border

    # Remove small objects
    border = ski.morphology.remove_small_objects(border)

    # Fill holes
    borderfill = ski.morphology.area_closing(border, area_threshold=10000)

    # Appliquer une érosion isotropique avec un rayon de 10 pixels, réduisant ainsi la taille de la structure
    borderlow = ski.morphology.isotropic_erosion(borderfill, 15)

    # Suppression des pixels en dehors du masque érodé
    imclean = im_array.copy()
    imclean[~borderlow] = 0

    # Création d'une image vide (0 partout) de la même taille que imclean
    image_vide = np.zeros_like(imclean)

    # Watershed

    # 1. Préparation des marqueurs: classification des pixels
    seuil2 = 35000
    seuil3 = 0
    seuil4 = 30000

    # Création d'un masque binaire ou les pixels sont dans la plage [seuil3 - seuil4]
    test2 = (imclean > seuil3) & (imclean < seuil4)

    # Classificaton des pixels en 3 classes
    image_vide[test2] = 2
    image_vide[imclean > seuil2] = 1
    image_vide[~borderlow] = 0

    # Watershed
    elem = ski.morphology.disk(5)
    watershed_labels = ski.segmentation.watershed(
        ski.filters.sobel(
            ski.filters.gaussian(im_array, 1)
        ), image_vide
    )

    watershed_labels[watershed_labels == 1] = 0
    watershed_labels[watershed_labels == 2] = 1

    # Génération du mask de la lumière tubulaire
    lumiere = borderlow & watershed_labels

    # Sauvegarde des masques
    #Borderfill
    plt.imsave(os.path.join(OUTPUT_DIRS["borderfill"], image_name.replace(".tif", "_borderfill.tiff")), borderfill)
    
    # Borderlow
    plt.imsave(os.path.join(OUTPUT_DIRS["borderlow"], image_name.replace(".tif", "_borderlow.tiff")), borderlow)

    # imclean
    plt.imsave(os.path.join(OUTPUT_DIRS["imclean"], image_name.replace(".tif", "_imclean.tiff")), imclean)

    # lumiere_unique
    plt.imsave(os.path.join(OUTPUT_DIRS["lumiere"], image_name.replace(".tif", "_lumiere.tiff")), lumiere)
 
if __name__ == "__main__":
    # Créer le dossier de sortie s'il n'existe pas
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for dir_path in OUTPUT_DIRS.values():
        os.makedirs(dir_path, exist_ok=True)

    # Liste des fichiers dans le dossier
    images = sorted([img for img in os.listdir(INPUT_DIR) if img.endswith(".tif")])
    print("Nombre d'images détectées:", len(images))
    # print("Premières images triées", images[:5])

    with Pool(CORES) as p:
        p.map(process_image, images)

    print("Traitement terminé pour toutes les images.")
