#!/usr/bin/env python

import numpy as np
import os
import matplotlib.pyplot as plt
import skimage as ski
from PIL import Image
from multiprocessing import Pool

# Variable Globale
CORES = os.cpu_count() - 10
# Dossiers d'entree et de sortie
INPUT_DIR = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej3/02_crop_data/crop_data_2d/"
OUTPUT_DIR = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej3/03_watershed_2d/"

# Output Dirs
OUTPUT_DIRS = {
    "borderlow": os.path.join(OUTPUT_DIR, "borderlow"),
    "imclean": os.path.join(OUTPUT_DIR, "imclean"),
    "lumiere": os.path.join(OUTPUT_DIR, "lumiere"),
    "borderfill": os.path.join(OUTPUT_DIR, "borderfill"),
}

def process_image(image_name: str) -> None:
    """
    Segmentation de la lumière tubulaire par watershed.

    Parameters:
    -----------
        image_name : str
            Nom de l'image à segmenter

    Returns:
    --------
        None

    """
    print(f"Traitement de l'image: {image_name}")
    # Chargement des images
    image_path = os.path.join(INPUT_DIR, image_name)
    image = Image.open(image_path)

    # Convert to array
    im_array = np.array(image)

    # Normalisation 8 bits
    img_norm = ((im_array - im_array.min()) / np.ptp(im_array)) * 255
    img_norm = img_norm.astype(np.uint8)

    # Détection des contours avec le seuil de Yen
    border = ski.filters.threshold_yen(img_norm)
    border = img_norm > border

    # Nettoyage et fermeture morphologique
    border = ski.morphology.remove_small_objects(border)
    borderfill = ski.morphology.area_closing(border, area_threshold=100000)

    # Érosion isotropique
    borderlow = ski.morphology.isotropic_erosion(borderfill, 10)

    # Suppression des pixels en dehors du masque érodé
    imclean = im_array.copy()
    imclean[~borderlow] = 0

    # Création d'une image vide (0 partout) de la même taille que imclean
    image_vide = np.zeros_like(imclean)

    # Watershed

    # 1. Préparation des marqueurs: classification des pixels
    seuil2 = 24000 
    seuil3 = 0
    seuil4 = 22000

    # Création d'un masque binaire ou les pixels sont dans la plage [seuil3 - seuil4]
    test2 = (imclean > seuil3) & (imclean < seuil4)

    # Classificaton des pixels en 3 classes
    image_vide[test2] = 2
    image_vide[imclean > seuil2] = 1
    image_vide[~borderlow] = 0


    # Watershed
    elem = ski.morphology.disk(5)
    watershed_labels = ski.segmentation.watershed(
        ski.filters.sobel(
            ski.filters.gaussian(im_array, 1)
        ), image_vide
    )

    watershed_labels[watershed_labels == 1] = 0
    watershed_labels[watershed_labels == 2] = 1

    # Génération du mask de la lumière tubulaire
    lumiere = borderlow & watershed_labels
    
    # Sauvegarde des masques
    # Borderfill
    tifffile.imwrite(
        os.path.join(OUTPUT_DIRS["borderfill"], image_name.replace(".tif", "_borderfill.tiff")),
        (borderfill.astype(np.uint8) * 255)
    )
    
    # Borderlow
    tifffile.imwrite(
        os.path.join(OUTPUT_DIRS["borderlow"], image_name.replace(".tif", "_borderlow.tiff")),
        (borderlow.astype(np.uint8) * 255)
    )
    
    # Imclean
    tifffile.imwrite(
        os.path.join(OUTPUT_DIRS["imclean"], image_name.replace(".tif", "_imclean.tiff")),
        (imclean.astype(np.uint8) * 255)
    )
    
    # Lumiere
    tifffile.imwrite(
        os.path.join(OUTPUT_DIRS["lumiere"], image_name.replace(".tif", "_lumiere.tiff")),
        (lumiere.astype(np.uint8) * 255)
    )


if __name__ == "__main__":
    # Créer le dossier de sortie s'il n'existe pas
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for dir_path in OUTPUT_DIRS.values():
        os.makedirs(dir_path, exist_ok=True)

    # Liste des fichiers dans le dossier
    images = sorted([img for img in os.listdir(INPUT_DIR) if img.endswith(".tif")])
    print("Nombre d'images détectées:", len(images))
    # print("Premières images triées", images[:5])

    with Pool(CORES) as p:
        p.map(process_image, images)

    print("Traitement terminé pour toutes les images.")



