"""
GRU downstream evaluation for the Temporal Separability Benchmark.

Edit the CONFIG section at the top, then run:
    python run_downstream.py
"""

import sys
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from sklearn.preprocessing import label_binarize

sys.path.insert(0, str(Path(__file__).parent.parent / "probe"))
from binning import first_last, compute_edges_uniform, apply_edges

# ── CONFIG ────────────────────────────────────────────────────────────────────
# Root folder of saved embeddings: the "embeddings" directory that sits
# alongside this file's own directory.
EMBEDDINGS_DIR = Path(__file__).parent.parent / "embeddings"
MODEL_NAME     = "retfound_dinov2"   # sub-folder of EMBEDDINGS_DIR holding train.pt / val.pt / test.pt
STRATEGY       = "first_last"        # "first_last" | "uniform" (any value other than "first_last" takes the binned-edges path)
N_BINS         = 2               # number of bins passed to compute_edges_uniform; unused when STRATEGY is "first_last"

# GRU architecture and training hyperparameters.
HIDDEN_SIZE = 256   # GRU hidden state size (also the input width of the linear head)
N_LAYERS    = 2     # number of stacked GRU layers
DROPOUT     = 0.3   # dropout passed to nn.GRU; only applied when N_LAYERS > 1
EPOCHS      = 60    # maximum number of training epochs
BATCH_SIZE  = 32    # batch size used by every DataLoader
LR          = 1e-3  # AdamW learning rate
PATIENCE    = 10    # epochs without validation AUC-ROC improvement before early stopping
N_CLASSES   = 5     # number of output classes; metrics iterate over labels 0..N_CLASSES-1

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ─────────────────────────────────────────────────────────────────────────────


class TemporalGRU(nn.Module):
    """GRU classifier over variable-length feature sequences.

    The sequence is encoded by a (possibly stacked) GRU and the final hidden
    state of the last layer is mapped to class logits by a linear head.

    Args:
        input_size: Dimensionality of each time step's feature vector.
        hidden_size: GRU hidden size. Defaults to the module-level
            ``HIDDEN_SIZE`` when ``None``.
        n_layers: Number of stacked GRU layers. Defaults to ``N_LAYERS``.
        dropout: Dropout between GRU layers; forced to 0 for a single layer
            because ``nn.GRU`` only applies dropout between layers.
            Defaults to ``DROPOUT``.
        n_classes: Number of output logits. Defaults to ``N_CLASSES``.
    """

    def __init__(self, input_size: int, hidden_size=None, n_layers=None,
                 dropout=None, n_classes=None):
        super().__init__()
        # Fall back to the CONFIG constants only for arguments left unset, so
        # the default construction is identical to the previous behaviour.
        hidden_size = HIDDEN_SIZE if hidden_size is None else hidden_size
        n_layers = N_LAYERS if n_layers is None else n_layers
        dropout = DROPOUT if dropout is None else dropout
        n_classes = N_CLASSES if n_classes is None else n_classes

        self.gru = nn.GRU(
            input_size, hidden_size, n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.head = nn.Linear(hidden_size, n_classes)

    def forward(self, x, lengths):
        """Return class logits for a padded batch.

        Args:
            x: Padded batch laid out batch-first (batch, max_len, input_size).
            lengths: Tensor with the true length of each sequence in ``x``.

        Returns:
            Tensor of logits with one row per sequence.
        """
        # Packing makes the GRU stop at each sequence's true length, so padded
        # steps never influence the final hidden state. The lengths are moved
        # to the CPU for pack_padded_sequence.
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h = self.gru(packed)      # h: (n_layers, batch, hidden)
        return self.head(h[-1])      # last layer hidden state → logits


def load_split(split: str):
    """Load the saved embeddings file for one data split.

    Reads ``EMBEDDINGS_DIR / MODEL_NAME / "<split>.pt"`` with ``torch.load``
    and returns whatever object was stored there.

    Args:
        split: Split name used as the file stem (e.g. "train", "val", "test").
    """
    split_path = EMBEDDINGS_DIR / MODEL_NAME / f"{split}.pt"
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so only point this at embedding files you trust.
    return torch.load(split_path, weights_only=False)


def prepare_sequences(data, edges):
    """Turn raw exams into per-exam feature sequences and a label tensor.

    For every exam, frames whose ``elapsed_seconds`` is NaN are dropped, then
    the remaining frames are reduced with ``first_last`` (when ``STRATEGY`` is
    "first_last") or ``apply_edges`` (otherwise). Exams with no valid frame,
    or whose reduction is empty, are skipped entirely.

    Args:
        data: Iterable of exam dicts with "features", "elapsed_seconds" and
            "label" entries.
        edges: Bin edges forwarded to ``apply_edges``; unused for "first_last".

    Returns:
        ``(sequences, labels)``: a list of float tensors (one per kept exam)
        and a ``torch.long`` tensor with the matching labels.
    """
    use_first_last = STRATEGY == "first_last"
    kept_sequences = []
    kept_labels = []

    for record in data:
        frame_feats = record["features"]
        frame_times = record["elapsed_seconds"]

        # Keep only frames that carry a usable timestamp.
        has_time = ~frame_times.isnan()
        frame_feats = frame_feats[has_time]
        frame_times = frame_times[has_time]
        if len(frame_feats) == 0:
            continue

        if use_first_last:
            binned, _ = first_last(frame_feats, frame_times)
        else:
            binned, _ = apply_edges(frame_feats, frame_times, edges)

        # A reduction that yields no vectors gives the GRU nothing to read.
        if len(binned) == 0:
            continue

        kept_sequences.append(binned.float())
        kept_labels.append(record["label"])

    label_tensor = torch.tensor(kept_labels, dtype=torch.long)
    return kept_sequences, label_tensor


def collate_batch(batch):
    """Collate ``(sequence, label)`` pairs into a padded batch.

    Args:
        batch: Sequence of ``(sequence_tensor, label_tensor)`` pairs.

    Returns:
        ``(padded, lengths, labels)``: the sequences zero-padded to a common
        length in batch-first layout, a long tensor with each sequence's
        original length, and the labels stacked into a long tensor.
    """
    seqs, targets = map(list, zip(*batch))
    # Record the true lengths before padding hides them.
    seq_lens = torch.tensor([seq.size(0) for seq in seqs], dtype=torch.long)
    padded = pad_sequence(seqs, batch_first=True)
    stacked_targets = torch.stack(targets).to(torch.long)
    return padded, seq_lens, stacked_targets


def make_loader(sequences, labels, shuffle=False):
    """Wrap sequences and labels in a ``DataLoader`` of padded batches.

    Args:
        sequences: Indexable collection of per-exam sequence tensors.
        labels: Indexable collection of labels aligned with ``sequences``.
        shuffle: Whether the loader reshuffles the data every epoch.

    Returns:
        A ``DataLoader`` yielding ``collate_batch`` outputs in batches of
        ``BATCH_SIZE``.
    """

    class _PairDataset(Dataset):
        """Map-style view pairing ``sequences[i]`` with ``labels[i]``."""

        def __len__(self):
            return len(sequences)

        def __getitem__(self, index):
            return sequences[index], labels[index]

    return DataLoader(
        _PairDataset(),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_batch,
    )


def get_probs(model, loader):
    """Run the model over a loader and return softmax class probabilities.

    The model is switched to eval mode and evaluated under
    ``torch.inference_mode``; labels yielded by the loader are ignored.

    Args:
        model: Module called as ``model(X, lengths)`` and returning logits.
        loader: Iterable of ``(X, lengths, labels)`` batches.

    Returns:
        NumPy array with the per-batch probability rows stacked in loader
        order.
    """
    model.eval()
    with torch.inference_mode():
        per_batch = [
            torch.softmax(
                model(inputs.to(DEVICE), seq_lens.to(DEVICE)), dim=1
            ).cpu().numpy()
            for inputs, seq_lens, _ in loader
        ]
    return np.vstack(per_batch)


def compute_metrics(y_true, probs):
    """Compute macro one-vs-rest AUC-ROC, macro average precision and macro F1.

    AUC-ROC and average precision are computed per class (class ``c`` versus
    the rest, scored with ``probs[:, c]``) and averaged over the classes for
    which the metric is defined. F1 is the macro F1 of the argmax predictions.

    Args:
        y_true: 1-D array-like of integer labels in ``0..N_CLASSES-1``.
        probs: Array of shape (n_samples, N_CLASSES) with class probabilities.

    Returns:
        ``(auc, ap, f1)`` as Python floats. ``auc`` / ``ap`` are 0 when no
        class qualifies for the respective metric.
    """
    y_true = np.asarray(y_true)
    preds = probs.argmax(axis=1)
    n_samples = len(y_true)

    aucs, aps = [], []
    for c in range(N_CLASSES):
        # Build the one-vs-rest target directly; unlike label_binarize this
        # keeps one column per class even when N_CLASSES == 2.
        is_pos = (y_true == c).astype(int)
        n_pos = int(is_pos.sum())
        if n_pos == 0:
            # Class absent from this split: neither metric is meaningful.
            continue
        if n_pos < n_samples:
            # AUC-ROC is undefined when the target contains a single class,
            # so it is only computed when both positives and negatives exist.
            aucs.append(roc_auc_score(is_pos, probs[:, c]))
        aps.append(average_precision_score(is_pos, probs[:, c]))

    f1 = float(f1_score(y_true, preds, average="macro", zero_division=0))
    return float(np.mean(aucs) if aucs else 0), float(np.mean(aps) if aps else 0), f1


def main():
    """Train the GRU on the configured embeddings and print test metrics.

    Steps: load the train/val/test splits, build per-exam sequences (fitting
    uniform bin edges on the training elapsed times when ``STRATEGY`` is not
    "first_last"), train with inverse-frequency class weights and early
    stopping on validation AUC-ROC, then restore the best checkpoint and
    report AUC-ROC, AP and F1 on the test split.

    Raises:
        ValueError: If any split has no usable sequence after filtering.
    """
    print(f"Model: {MODEL_NAME}  Strategy: {STRATEGY}  Bins: {N_BINS}  Device: {DEVICE}\n")

    train_data = load_split("train")
    val_data   = load_split("val")
    test_data  = load_split("test")

    edges = None
    if STRATEGY != "first_last":
        # Bin edges are fitted on training data only and reused for val/test.
        all_elapsed = np.concatenate([
            exam["elapsed_seconds"][~exam["elapsed_seconds"].isnan()].numpy()
            for exam in train_data
        ])
        edges = compute_edges_uniform(all_elapsed, N_BINS)

    X_train, y_train = prepare_sequences(train_data, edges)
    X_val,   y_val   = prepare_sequences(val_data,   edges)
    X_test,  y_test  = prepare_sequences(test_data,  edges)

    # Fail early with a clear message instead of an IndexError on X_train[0]
    # or a stacking error on an empty loader after training has already run.
    for split_name, split_seqs in (("train", X_train), ("val", X_val), ("test", X_test)):
        if not split_seqs:
            raise ValueError(
                f"No usable sequences in the {split_name!r} split of {MODEL_NAME!r}"
            )

    train_loader = make_loader(X_train, y_train, shuffle=True)
    val_loader   = make_loader(X_val,   y_val)
    test_loader  = make_loader(X_test,  y_test)

    model     = TemporalGRU(input_size=X_train[0].shape[1]).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    # Inverse-frequency class weights, rescaled to sum to N_CLASSES; the small
    # epsilon avoids division by zero for classes absent from training.
    counts    = torch.bincount(y_train, minlength=N_CLASSES).float()
    weights   = (1.0 / (counts + 1e-6)).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=weights / weights.sum() * N_CLASSES)

    best_auc, best_state, no_improve = -1.0, None, 0

    for epoch in range(EPOCHS):
        model.train()
        for X_b, lengths, y_b in train_loader:
            X_b, lengths, y_b = X_b.to(DEVICE), lengths.to(DEVICE), y_b.to(DEVICE)
            optimizer.zero_grad()
            criterion(model(X_b, lengths), y_b).backward()
            optimizer.step()

        val_probs     = get_probs(model, val_loader)
        val_auc, _, _ = compute_metrics(y_val.numpy(), val_probs)
        val_auc       = float(np.nan_to_num(val_auc))

        if val_auc > best_auc:
            best_auc   = val_auc
            # Clone the tensors so later optimizer steps cannot overwrite the
            # snapshot of the best checkpoint.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                # Report the 1-based epoch, matching the progress lines below.
                print(f"  Early stop at epoch {epoch + 1}")
                break

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch + 1:3d}  val AUC-ROC={val_auc:.4f}")

    # best_state is only None when no epoch ran (EPOCHS <= 0); in that case
    # the freshly initialised model is evaluated as-is.
    if best_state is not None:
        model.load_state_dict(best_state)
    test_probs      = get_probs(model, test_loader)
    auc, ap, f1     = compute_metrics(y_test.numpy(), test_probs)

    print(f"\n── Test results ─────────────────────────────────────────────────────")
    print(f"  AUC-ROC : {auc:.4f}")
    print(f"  AP      : {ap:.4f}")
    print(f"  F1      : {f1:.4f}")
    print(f"  n_test  : {len(y_test)}")


# Run the evaluation only when this file is executed as a script, so that
# importing the module does not start training.
if __name__ == "__main__":
    main()
