#!/usr/bin/env python

import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, Slider
import skimage as ski
import cv2
import scipy.ndimage as nd
from tifffile import imread
import tifffile




# Hard-coded 3D TIFF stack to segment and directory receiving the results.
INPUT_FILE = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej3/01_crop_data/data_3d/data_3d.tif"
OUTPUT_DIR = "/home/experiences/grades/saras/Documents/DATA_SWING/01_processed_data/tomo_sain61_si_dej3/01_3d_watershed/"

# Output stack paths, keyed by the role of each stack; every file lives
# directly inside OUTPUT_DIR.
OUTPUT_FILES = {
    key: os.path.join(OUTPUT_DIR, filename)
    for key, filename in (
        ("borderlow", "3D_borderlow.tiff"),
        ("seeds_3d", "seeds_3d.tiff"),
        ("lumiere", "3D_lumiere.tiff"),
    )
}

def process_image(image_path: str) -> None:
    """
    Segmentation de la lumière tubulaire par watershed 3D.

    Parameters:
    -----------
        image_name : str
            Nom de l'image à segmenter

    Returns:
    --------
        None

    """
    print(f"Traitement de l'image: {image_path}")
    # Chargement du stack en 3D
    stack_3d = imread(image_path)

    # Reshape
    stack_3d = np.transpose(stack_3d, (1,2,0))

    # Créer un array vide pour la boucle
    seeds_3d = np.zeros_like(stack_3d)
    borderlow_3d = np.zeros_like(stack_3d)

    # boucle en 3d
    for i in range(stack_3d.shape[2]):
        img = stack_3d[:,:,i]
        im_array = np.array(img)
        # Normalisation en 8 bits
        img_norm = ((im_array - im_array.min()) / np.ptp(im_array)) * 255
        img_norm = img_norm.astype(np.uint8)
        border = ski.filters.threshold_yen(img)
        border = img > border
        border = ski.morphology.remove_small_objects(border)
        borderfill = ski.morphology.area_closing(border, area_threshold=10000)
        borderlow = ski.morphology.isotropic_erosion(borderfill, 25)
        imclean = img.copy()
        imclean[~borderlow] = 0
    
        # image_vide
        image_vide = np.zeros_like(imclean)
    
        #Watershed
        #Seuil
        seuil_1 = 20000
        seuil_2 = 0
        seuil_3 = 18000
        # Masque binaire avec les pixels [seuil_2 _ seuil2)]
        masque = (imclean > seuil_2) & (imclean < seuil_3)
    
        # remplissage de l'array avec les seuils
        image_vide[masque] = 2
        image_vide[imclean > seuil_1] = 1
        image_vide[~borderlow] = 0
    
        seeds_3d[:,:,i] = image_vide
        borderlow_3d[:,:,i] = borderlow

    # Watershed
    elem = ski.morphology.disk(5)
    watershed_labels = ski.segmentation.watershed(ski.filters.sobel(ski.filters.gaussian(stack_3d,1)), seeds_3d)
    
    watershed_labels[watershed_labels == 1] = 0
    watershed_labels[watershed_labels == 2] = 1
    
    lumiere = borderlow_3d & watershed_labels

    # sauvegarde image_3d
    tifffile.imwrite(OUTPUT_FILES["seeds_3d"], seeds_3d.astype(np.uint8))
    tifffile.imwrite(OUTPUT_FILES["borderlow"], borderlow_3d.astype(np.uint8))
    tifffile.imwrite(OUTPUT_FILES["lumiere"], lumiere.astype(np.uint8))


def main() -> None:
    """Segment INPUT_FILE and write the result stacks into OUTPUT_DIR."""
    # The result files are written directly into OUTPUT_DIR, so it must
    # exist before processing starts.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    process_image(INPUT_FILE)
    print("Traitement terminé pour le stack 3d.")


if __name__ == "__main__":
    main()
