"""
Analyse the extracted elapsed-time CSV to understand the temporal
distribution across the dataset.

Usage:
    pixi run python exploration/analyse_timestamps.py

Reads : config/frame_elapsed_times_v2.csv
Saves : exploration/ts_*.svg   (individual plots, viewable in Cursor)
Prints: summary statistics to stdout
"""

import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ── Config ────────────────────────────────────────────────────────────
# Paths are resolved against the current working directory (presumably the
# repository root, given the usage line in the module docstring).
CSV_PATH = Path("config") / "frame_elapsed_times_v2.csv"
OUT_DIR = Path("exploration")
OUT_DIR.mkdir(exist_ok=True)

# Timestamps above this percentile are treated as outliers: reported in the
# outlier section and clipped when drawing the plots.
CLIP_PERCENTILE = 99

# ── Load ──────────────────────────────────────────────────────────────
# Every CSV column this script reads somewhere below.
REQUIRED_COLUMNS = {"exam_folder", "frame", "corrected_seconds",
                    "corrected_time", "time_source"}

if not CSV_PATH.exists():
    sys.exit(f"CSV not found: {CSV_PATH}  (is the SLURM job done?)")

df = pd.read_csv(CSV_PATH)

# Fail early with a readable message instead of a KeyError deep in the script.
missing_cols = REQUIRED_COLUMNS.difference(df.columns)
if missing_cols:
    sys.exit(f"CSV {CSV_PATH} is missing column(s): "
             f"{', '.join(sorted(missing_cols))}")

# Non-numeric entries become NaN instead of raising; they are dropped below.
df["corrected_seconds"] = pd.to_numeric(df["corrected_seconds"], errors="coerce")
print(f"Loaded {len(df):,} rows  ({df['exam_folder'].nunique()} exams)\n")

# Frame timestamps that parsed as numbers; the basis for every frame-level
# statistic and plot below.
valid = df["corrected_seconds"].dropna()
if valid.empty:
    # Later percentages divide by len(valid) and every quantile would be NaN,
    # so there is nothing meaningful to analyse.
    sys.exit(f"No numeric corrected_seconds values in {CSV_PATH}; "
             "nothing to analyse.")

# ── Per-exam aggregates ───────────────────────────────────────────────
def _pct_equal(label):
    """Build an aggregator: percentage of a group's rows equal to *label*."""
    return lambda col: (col == label).mean() * 100


exam = (
    df.groupby("exam_folder")
    .agg(
        n_frames=("frame", "count"),
        start_sec=("corrected_seconds", "min"),
        end_sec=("corrected_seconds", "max"),
        median_sec=("corrected_seconds", "median"),
        # Share of each exam's frames per time_source value in the CSV.
        pct_fa=("time_source", _pct_equal("fa")),
        pct_interp=("time_source", _pct_equal("interpolated")),
        pct_none=("time_source", _pct_equal("none")),
    )
    .reset_index()
)
# Span between an exam's earliest and latest timestamp (NaN if it has none).
exam["duration_sec"] = exam["end_sec"] - exam["start_sec"]

# ══════════════════════════════════════════════════════════════════════
#  OUTLIER ANALYSIS
# ══════════════════════════════════════════════════════════════════════
# Frames whose timestamp exceeds the CLIP_PERCENTILE-th percentile of all
# valid timestamps count as outliers.
clip_val = valid.quantile(CLIP_PERCENTILE / 100)
outliers = valid[valid > clip_val]
n_outlier_frames = len(outliers)
outlier_exams = df.loc[df["corrected_seconds"] > clip_val, "exam_folder"].unique()

rule = "=" * 62
print(rule, "  OUTLIER ANALYSIS", rule, sep="\n")
print(f"  Clip threshold ({CLIP_PERCENTILE}th pctl) : {clip_val:.1f}s  "
      f"({clip_val/60:.1f} min)")
outlier_pct = n_outlier_frames / len(valid) * 100
print(f"  Frames above threshold    : {n_outlier_frames:,} ({outlier_pct:.2f}%)")
print(f"  Exams with outlier frames : {len(outlier_exams)}")
print()

print("  TOP 10 MOST EXTREME TIMESTAMPS:")
top = df.nlargest(10, "corrected_seconds")[
    ["exam_folder", "frame", "corrected_seconds", "corrected_time", "time_source"]
]
for rec in top.itertuples(index=False):
    secs = rec.corrected_seconds
    print(f"    {rec.exam_folder:<12} frame {int(rec.frame):>3}  "
          f"{secs:>8.1f}s  ({secs/60:.1f} min)  src={rec.time_source}")
print()

# For visualization, clip to remove extreme outliers
valid_clipped = valid.clip(upper=clip_val)

# Data-driven time bins: N_BINS equal-count bins over the clipped timestamps.
# q_edges holds the N_BINS + 1 bin boundaries in seconds, running from the
# smallest clipped timestamp up to clip_val. Plots 1 and 7 draw from them.
N_BINS = 5  # quintiles; plot 7's colour list provides one colour per bin
q_edges = valid_clipped.quantile(np.linspace(0, 1, N_BINS + 1)).to_numpy()

# ══════════════════════════════════════════════════════════════════════
#  PLOTS — individual SVG files (viewable in Cursor)
# ══════════════════════════════════════════════════════════════════════
# Shared look for every figure below, applied once globally via rcParams.
STYLE = {
    "font.size": 11,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "axes.facecolor": "#ffffff",
    "figure.facecolor": "#fafafa",
}
plt.rcParams.update(STYLE)


def save(fig, name):
    """Write *fig* to ``OUT_DIR/ts_<name>.svg``, close it, and report the path.

    Closing the figure keeps pyplot from accumulating open figures across
    the many plots this script draws.
    """
    out_path = OUT_DIR.joinpath(f"ts_{name}.svg")
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved → {out_path}")


# ── 1. All timestamps (clipped) ──────────────────────────────────────
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(valid_clipped / 60, bins=80, color="#4c72b0", edgecolor="white", lw=0.4)
# Overlay the interior quintile boundaries (q_edges[1:-1]). Each line is
# labelled with its own value, so the legend lists every boundary rather
# than only the first one.
for edge in q_edges[1:-1]:
    edge_min = edge / 60
    ax.axvline(edge_min, color="#e74c3c", ls="--", alpha=0.7, lw=1.2,
               label=f"{edge_min:.1f} min")
ax.set_xlabel("Elapsed time (min)")
ax.set_ylabel("Frame count")
ax.set_title(f"Distribution of all frame timestamps  (clipped at {CLIP_PERCENTILE}th pctl)")
ax.legend(title="Quintile edges", fontsize=9)
save(fig, "all_timestamps")

# ── 2–4. Per-exam start / end / duration histograms ──────────────────
def _exam_minutes_hist(seconds, bar_color, median_color, xlabel, title, name):
    """Histogram one per-exam quantity (given in seconds) on a minutes axis.

    Values above the 99th percentile of *seconds* are clipped for display
    only; the median line is computed from the unclipped values. The figure
    is written by ``save`` under the name *name*.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    shown = seconds.clip(upper=seconds.quantile(0.99))
    ax.hist(shown / 60, bins=60, color=bar_color, edgecolor="white", lw=0.4)
    median_min = seconds.median() / 60
    ax.axvline(median_min, color=median_color, ls="-", lw=1.5,
               label=f"Median = {median_min:.1f} min")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Exam count")
    ax.set_title(title)
    ax.legend()
    save(fig, name)


# 2. Per-exam start times
_exam_minutes_hist(exam["start_sec"], "#55a868", "red",
                   "Start time (min)", "When does each exam begin?",
                   "start_times")

# 3. Per-exam end times
_exam_minutes_hist(exam["end_sec"], "#c44e52", "navy",
                   "End time (min)", "When does each exam end?",
                   "end_times")

# 4. Exam durations
_exam_minutes_hist(exam["duration_sec"], "#8172b2", "red",
                   "Duration (min)", "Total time span per examination",
                   "durations")

# ── 5. Frames per exam ───────────────────────────────────────────────
# Integer-width bins. Counts above max_shown (99th percentile + 2) are
# clipped to max_shown, so they pile up in the last displayed bin.
frames_per_exam = exam["n_frames"]
max_shown = int(frames_per_exam.quantile(0.99)) + 2
median_frames = frames_per_exam.median()

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(frames_per_exam.clip(upper=max_shown), bins=range(0, max_shown + 2),
        color="#ccb974", edgecolor="white", lw=0.4)
ax.axvline(median_frames, color="red", ls="-", lw=1.5,
           label=f"Median = {median_frames:.0f}")
ax.set_xlabel("Frames per exam")
ax.set_ylabel("Exam count")
ax.set_title(f"Frames per examination  (clipped display at {max_shown})")
ax.legend()
save(fig, "frames_per_exam")

# ── 6. Classical phase bar chart ─────────────────────────────────────
# Classical fluorescein-angiography phases as (name, lo_sec, hi_sec). A frame
# belongs to a phase when lo <= t < hi; frames with t < 0 fall in no phase.
# NOTE(review): these boundaries are approximate textbook FA timings, and they
# assume corrected_seconds counts from dye injection. Confirm both with the
# clinical team before relying on this chart.
phase_defs = [
    ("Choroidal", 0, 12),
    ("Arterial", 12, 15),
    ("Arteriovenous", 15, 20),
    ("Venous", 20, 300),
    ("Late", 300, float("inf")),
]

fig, ax = plt.subplots(figsize=(10, 5))
phase_counts = []
phase_labels = []
for name, lo, hi in phase_defs:
    cnt = ((valid >= lo) & (valid < hi)).sum()
    phase_counts.append(cnt)
    # Show an open-ended phase as "≥lo" instead of "lo–inf".
    span = f"≥{lo}s" if np.isinf(hi) else f"{lo}–{hi}s"
    phase_labels.append(f"{name}\n({span})")
colors = ["#a1dab4", "#41b6c4", "#2c7fb8", "#253494", "#081d58"]
bars = ax.bar(phase_labels, phase_counts, color=colors, edgecolor="white")
for bar, cnt in zip(bars, phase_counts):
    pct = cnt / len(valid) * 100
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
            f"{cnt:,}\n({pct:.1f}%)", ha="center", va="bottom", fontsize=9)
ax.set_ylabel("Frame count")
ax.set_title("Classical FA phases — frame distribution")
save(fig, "classical_phases")

# ── 7. Data-driven bins bar chart ────────────────────────────────────
# Count clipped timestamps per quantile bin. All bins are half-open [lo, hi)
# except the last, which keeps everything >= lo so the maximum is included.
bin_counts = []
bin_labels = []
for i in range(N_BINS):
    lo = q_edges[i]
    hi = q_edges[i + 1]
    in_bin = valid_clipped >= lo
    if i < N_BINS - 1:
        in_bin = in_bin & (valid_clipped < hi)
    bin_counts.append(in_bin.sum())
    bin_labels.append(f"Bin {i}\n{lo/60:.1f}–{hi/60:.1f} min")

colors_dd = ["#fef0d9", "#fdcc8a", "#fc8d59", "#e34a33", "#b30000"]
fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(bin_labels, bin_counts, color=colors_dd, edgecolor="#666", lw=0.5)
for bar, cnt in zip(bars, bin_counts):
    x_mid = bar.get_x() + bar.get_width() / 2
    ax.text(x_mid, bar.get_height(), f"{cnt:,}",
            ha="center", va="bottom", fontsize=10)
ax.set_ylabel("Frame count")
ax.set_title(f"Data-driven quintile bins (clipped at {CLIP_PERCENTILE}th pctl)")
save(fig, "data_driven_bins")

# ── 8. Per-exam timeline plot (20 random exams) ──────────────────────
fig, ax = plt.subplots(figsize=(12, 8))
# Fixed seed so the same exams are drawn on every run. Rows are sorted by
# start time; with the inverted y-axis the earliest-starting exam is on top.
sample = exam.sample(min(20, len(exam)), random_state=42).sort_values("start_sec")
for i, (_, row) in enumerate(sample.iterrows()):
    ex_df = df[df["exam_folder"] == row["exam_folder"]]
    times = ex_df["corrected_seconds"].dropna() / 60
    # Give an exam's points and its span line the same colour explicitly,
    # rather than relying on scatter and plot pulling matching colours from
    # the property cycle.
    colour = f"C{i}"
    ax.scatter(times, [i] * len(times), s=8, alpha=0.7, color=colour)
    ax.plot([times.min(), times.max()], [i, i], lw=0.8, alpha=0.4, color=colour)
ax.set_yticks(range(len(sample)))
ax.set_yticklabels(sample["exam_folder"].values, fontsize=8)
ax.set_xlabel("Elapsed time (min)")
# Report the number actually drawn; it is fewer than 20 for small datasets.
ax.set_title(f"Frame timestamps for {len(sample)} random exams")
ax.invert_yaxis()
save(fig, "exam_timelines")

print("\n".join((
    f"\nAll plots saved to {OUT_DIR}/ts_*.svg",
    "Open any .svg file in Cursor to view it.",
)))