import argparse
import json
import os
import re

import matplotlib.patheffects as patheffects
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sqlalchemy import create_engine, text

# --- DATABASE CONFIGURATION ---
# NOTE(review): read-only credentials are embedded in source — consider moving
# them to an environment variable or config file.
DB_CONFIG = "mysql+pymysql://uset_readonly:StfnPdts@uset.oma.be/uset_drawings"

# --- PATH CONFIGURATION ---
# Source images are expected at: ../data/uset/drawings/YYYY/MM/filename
SOURCE_IMAGES_ROOT = "../../drawings"
# Destination layout: ./drawings_overlays/YYYY/MM/GRID-usdXXXX.png
OUTPUT_LAYERS_ROOT = "../"

# --- VISUAL CONFIGURATION ---
GROUP_FONT_SIZE = 60      # font size (pt) of the sunspot-group number labels
GROUP_OFFSET = 20         # gap (px) between a group circle and its label
GROUP_HALO_WIDTH = 5      # width (px) of the white halo stroked behind labels
GROUP_CIRCLE_LW = 3       # line width of the red circles drawn around groups
GRID_LINE_WIDTH = 1.5     # line width of the heliographic grid lines

def _make_transparent_axes(img_w, img_h):
    """Return a transparent (fig, ax) pair sized 1:1 to an image of img_w x img_h px.

    The y-axis is inverted so axes coordinates match image pixel coordinates
    (origin at the top-left corner).
    """
    fig, ax = plt.subplots(figsize=(img_w / 100, img_h / 100), dpi=100)
    ax.set_xlim(0, img_w)
    ax.set_ylim(img_h, 0)  # inverted y: image convention
    ax.axis('off')
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)
    return fig, ax


def generate_batch_layers(start_date, end_date, db_config, output_folder):
    """Generate transparent overlay PNGs for every drawing in a date range.

    For each drawing row in [start_date, end_date], reads the disc calibration
    (center, radius) and orientation angles (P, B0) from the database, then
    writes two layers into output_folder/YYYY/MM/:
      - GRID-<usdXXXX>.png   : heliographic coordinate grid clipped to the disc
      - GROUPS-<usdXXXX>.png : red circles + DigiSun numbers for sunspot groups

    Parameters
    ----------
    start_date, end_date : str
        Inclusive date bounds, 'YYYY-MM-DD'.
    db_config : str
        SQLAlchemy database URL.
    output_folder : str
        Root of the output tree; YYYY/MM subfolders are created as needed.

    Entries whose source image is missing are skipped with a warning; any
    per-entry failure is reported and the batch continues.
    """
    engine = create_engine(db_config)

    # Parameterized query: the date bounds come from the command line, so they
    # must never be interpolated into the SQL string directly.
    entries_query = text(
        "SELECT DISTINCT DateTime, Filename FROM drawings "
        "WHERE DateTime BETWEEN :start AND :end"
    )
    df_entries = pd.read_sql(
        entries_query,
        engine,
        params={"start": f"{start_date} 00:00:00", "end": f"{end_date} 23:59:59"},
    )

    for _, entry in df_entries.iterrows():
        target_dt, filename = entry['DateTime'], str(entry['Filename'])

        # Overlays are named after the usdXXXX token of the drawing filename.
        match = re.search(r'usd\d+', filename)
        if not match:
            continue
        base_name = match.group(0)

        dt = pd.to_datetime(target_dt)
        year_str = dt.strftime('%Y')
        month_str = dt.strftime('%m')

        # --- SOURCE LOCATION ---
        # The scanned drawing is expected at SOURCE_ROOT/YYYY/MM/filename.
        local_path = os.path.join(SOURCE_IMAGES_ROOT, year_str, month_str, filename)

        if not os.path.exists(local_path):
            print(f"⚠️ Image source introuvable : {local_path}")
            continue

        # --- DESTINATION ---
        current_output_dir = os.path.join(output_folder, year_str, month_str)
        os.makedirs(current_output_dir, exist_ok=True)  # race-free creation

        try:
            params = {"dt": target_dt}
            res_calib = pd.read_sql(
                text("SELECT CenterX, CenterY, Radius FROM calibrations WHERE DateTime = :dt"),
                engine, params=params)
            res_draw = pd.read_sql(
                text("SELECT AngleP, AngleB FROM drawings WHERE DateTime = :dt"),
                engine, params=params)
            groups = pd.read_sql(
                text("SELECT DigiSunNumber, RawArea_px, PosX, PosY FROM sGroups WHERE DateTime = :dt"),
                engine, params=params)

            # Without calibration and orientation nothing can be drawn.
            if res_calib.empty or res_draw.empty:
                continue

            cx, cy, radius = res_calib.at[0, 'CenterX'], res_calib.at[0, 'CenterY'], res_calib.at[0, 'Radius']
            # P is negated to match the image orientation; angles stored in degrees.
            P_rad, B0_rad = np.radians(-res_draw.at[0, 'AngleP']), np.radians(res_draw.at[0, 'AngleB'])

            # Only the pixel dimensions are needed; the pixel data is not read.
            with Image.open(local_path) as img:
                img_w, img_h = img.size

            # --- GRID LAYER ---
            fig_g, ax_g = _make_transparent_axes(img_w, img_h)
            # Clip every grid line to the solar disc.
            mask = plt.Circle((cx, cy), radius, transform=ax_g.transData)
            for deg in range(-90, 100, 10):
                for is_meridian in [False, True]:
                    rng = np.radians(np.arange(-90, 91, 1))
                    # Parallels: fixed latitude phi; meridians: fixed longitude lam.
                    phi = np.radians(deg) if not is_meridian else rng
                    lam = rng if not is_meridian else np.radians(deg)
                    # Orthographic projection tilted by B0, then rotated by P.
                    gx = np.cos(phi) * np.sin(lam)
                    gy = np.sin(phi) * np.cos(B0_rad) - np.cos(phi) * np.sin(B0_rad) * np.cos(lam)
                    gxr, gyr = gx * np.cos(P_rad) - gy * np.sin(P_rad), gx * np.sin(P_rad) + gy * np.cos(P_rad)

                    # Highlight the equator and the central meridian (deg == 0);
                    # the original test was tautologically equivalent to this.
                    color = '#0055ff' if deg == 0 else '#00008b'

                    grid_line, = ax_g.plot(cx + gxr * radius, cy - gyr * radius,
                                           color=color, lw=GRID_LINE_WIDTH, alpha=0.7)
                    grid_line.set_clip_path(mask)

            fig_g.savefig(os.path.join(current_output_dir, f"GRID-{base_name}.png"),
                          transparent=True, bbox_inches='tight', pad_inches=0)
            plt.close(fig_g)

            # --- GROUPS LAYER ---
            fig_gr, ax_gr = _make_transparent_axes(img_w, img_h)
            for _, g in groups.iterrows():
                # Circle radius from the group's raw pixel area, padded 2.5x.
                r_px = np.sqrt(g['RawArea_px'] / np.pi) * 2.5
                ax_gr.add_patch(plt.Circle((g['PosX'], g['PosY']), r_px,
                                           color='#cc0000', fill=False, lw=GROUP_CIRCLE_LW))
                t = ax_gr.text(
                    g['PosX'] + r_px + GROUP_OFFSET,
                    g['PosY'],
                    str(int(g['DigiSunNumber'])),
                    color='black',
                    fontsize=GROUP_FONT_SIZE,
                    fontweight='bold',
                    va='center'
                )
                # White halo keeps the label readable over the drawing.
                t.set_path_effects([patheffects.withStroke(linewidth=GROUP_HALO_WIDTH, foreground='white')])

            fig_gr.savefig(os.path.join(current_output_dir, f"GROUPS-{base_name}.png"),
                           transparent=True, bbox_inches='tight', pad_inches=0)
            plt.close(fig_gr)

            print(f"✅ Traitement réussi : {year_str}/{month_str}/{base_name}")

        except Exception as e:
            # Best-effort batch: report the failure and move on to the next entry.
            print(f"❌ Erreur lors du traitement de {base_name}: {e}")

def generate_index(output_folder):
    """Recursively scan the output tree and (re)write the data.js index.

    Collects every usdXXXX identifier that has a GRID- overlay anywhere under
    output_folder and writes them, deduplicated and sorted, as a JavaScript
    array literal to ./data.js (current working directory) for the web viewer.
    Does nothing if output_folder does not exist.
    """
    if not os.path.exists(output_folder):
        return

    # A set deduplicates as we go; no need for the list -> set -> list dance.
    available = set()
    for _, _, files in os.walk(output_folder):
        for name in files:
            if name.startswith('GRID-'):
                match = re.search(r'usd\d+', name)
                if match:
                    available.add(match.group(0))

    ordered = sorted(available)
    with open('data.js', 'w') as f:
        f.write(f"const imageList = {json.dumps(ordered)};")
    print(f"📂 Index data.js mis à jour ({len(ordered)} fichiers).")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Générateur de calques USET")
    parser.add_argument("start", help="Date de début (YYYY-MM-DD)")
    parser.add_argument("end", help="Date de fin (YYYY-MM-DD)")
    
    args = parser.parse_args()
    
    generate_batch_layers(args.start, args.end, DB_CONFIG, OUTPUT_LAYERS_ROOT)
    generate_index(OUTPUT_LAYERS_ROOT)