import os
import re
import base64
import io
import threading
import shutil
import time
import sys
import random
import tempfile
from typing import Optional, List, Dict, Set, Any, Tuple
from contextlib import asynccontextmanager
from collections import Counter

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
import numpy as np
import onnxruntime as ort

# White-letter solver
import cv2
import pytesseract

# ONNX animal-icon puzzle state (lazy-loaded). The session stays None until the
# loader (not visible in this section) creates it on first use.
_animals_session: Optional[ort.InferenceSession] = None
_animals_img_size: int = 256          # read from model input shape at load time
_animals_class_names: List[str] = []  # populated from the registry folder name


# ======================================================
# RESOURCE PATH
# ======================================================
def resource_path(relative_path: str) -> str:
    """Resolve *relative_path* against the application's base directory.

    Inside a PyInstaller bundle (``sys._MEIPASS`` exists) the base is the
    bundle's extraction directory; otherwise it is the directory that holds
    this source file.
    """
    base_dir = (
        sys._MEIPASS
        if hasattr(sys, "_MEIPASS")
        else os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_dir, relative_path)


# ======================================================
# TESSERACT CONFIG
# ======================================================
# Main white-letter solver does NOT need Tesseract in the normal path.
# Tesseract is only used as a fallback.
# Resolution order: $TESSERACT_CMD (if it exists) -> default Windows install
# path (if it exists) -> whatever pytesseract finds on PATH.
TESSERACT_CMD = os.getenv("TESSERACT_CMD", "").strip()

if TESSERACT_CMD and os.path.exists(TESSERACT_CMD):
    pytesseract.pytesseract.tesseract_cmd = TESSERACT_CMD
    print(f"[tesseract] Using TESSERACT_CMD: {TESSERACT_CMD}")
else:
    if TESSERACT_CMD:
        # An explicit override that points nowhere is almost always a config
        # mistake; report it instead of silently falling back.
        print(f"[tesseract] TESSERACT_CMD={TESSERACT_CMD!r} does not exist; ignoring it.")
    WINDOWS_TESSERACT_EXE = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
    if os.path.exists(WINDOWS_TESSERACT_EXE):
        pytesseract.pytesseract.tesseract_cmd = WINDOWS_TESSERACT_EXE
        print(f"[tesseract] Using Windows path: {WINDOWS_TESSERACT_EXE}")
    else:
        print("[tesseract] Using system PATH. If fallback OCR fails, install Tesseract or set TESSERACT_CMD.")


# ======================================================
# CONFIG
# ======================================================
# External model folders, one sub-folder per model. Each sub-folder is
# expected to contain an .onnx file; its folder name, after normalize_label(),
# is the lookup key (see scan_hot_models / scan_cold_models).
# NOTE(review): these are relative paths, so they resolve against the process's
# current working directory, not the EXE / app.py location. Launch from the app
# directory (or make them absolute) so that
#   ./model_hot
#   ./model_cold
# actually sit beside the EXE / app.py as intended.
HOT_DIR = "./model_hot"
COLD_DIR = "./model_cold"

# Seconds between evictor passes; also the idle time after which a hot model
# becomes eligible to move back to cold storage.
EVICTION_INTERVAL_SEC = 300
# Hot-set size: eviction never shrinks the hot set below this, and
# scan_hot_models() sends folders beyond this count straight to cold.
MIN_HOT_MODELS = 4

# Class row read from the detector output in detect().
TARGET_CLASS_INDEX = 0


# ======================================================
# NORMALIZATION
# ======================================================
def normalize_label(s: str) -> str:
    """Lower-case *s*, treat ``_`` and ``-`` as spaces, and collapse whitespace.

    Falsy input (``""`` or ``None``) yields ``""``.
    """
    if not s:
        return ""
    lowered = s.lower().replace("_", " ").replace("-", " ")
    return re.sub(r"\s+", " ", lowered).strip()


def is_unpaired_letter_question(question: str) -> bool:
    """Return True if *question* asks for the letter tile that has no matching pair.

    A group matches when every keyword in it occurs as a substring of the
    normalized question (so "letter" also matches "letters").
    """
    text = normalize_label(question)
    keyword_groups = (
        ("unpaired", "letter"),
        ("unmatched", "letter"),
        ("no", "matching", "letter"),
        ("matching", "letter", "tile"),
        ("no", "pair", "letter"),
        ("different", "letter"),         # "select the tile that shows a different letter"
        ("tile", "shows", "different"),  # alternate phrasing
    )
    for group in keyword_groups:
        if all(keyword in text for keyword in group):
            return True
    return False


def is_checklist_letter_question(question: str) -> bool:
    """Return True for prompts like 'Select all letters from the checklist based on the count'.

    A prompt matches when every keyword of at least one group appears as a
    substring of the normalized text.
    """
    normalized = normalize_label(question)
    for required in (
        ("select", "letter", "checklist"),
        ("letters", "checklist", "count"),
        ("select", "checklist", "count"),
        ("checklist", "based", "count"),
        ("letters", "from", "checklist"),
    ):
        if all(token in normalized for token in required):
            return True
    return False


# ======================================================
# HOMOGLYPH STRIPPING (Cyrillic/Greek lookalikes → Latin)
# ======================================================
_HOMOGLYPH_MAP = str.maketrans({
    '\u0410': 'A', '\u0430': 'a', '\u0415': 'E', '\u0435': 'e',
    '\u041E': 'O', '\u043E': 'o', '\u0421': 'C', '\u0441': 'c',
    '\u0420': 'P', '\u0440': 'p', '\u0425': 'X', '\u0445': 'x',
    '\u0422': 'T', '\u0412': 'B', '\u041A': 'K', '\u043A': 'k',
    '\u041C': 'M', '\u043C': 'm', '\u0423': 'Y', '\u0443': 'y',
    '\u0391': 'A', '\u03B1': 'a', '\u0395': 'E', '\u03B5': 'e',
    '\u039F': 'O', '\u03BF': 'o', '\u0399': 'I', '\u03B9': 'i',
    '\u039A': 'K', '\u03BA': 'k', '\u039C': 'M', '\u03BC': 'm',
    '\u039D': 'N', '\u03BD': 'n', '\u03A1': 'P', '\u03C1': 'p',
    '\u03A4': 'T', '\u03C4': 't', '\u03A5': 'Y', '\u03C5': 'y',
})


def is_animal_icons_question(question: str) -> bool:
    """Return True for the "select animal icons according to the counts" puzzle.

    Cyrillic/Greek lookalike characters are first mapped to Latin via
    _HOMOGLYPH_MAP so obfuscated prompts still match. Like the other
    ``is_*_question`` helpers, a falsy *question* (``""``/``None``) yields
    False instead of raising.
    """
    q = normalize_label((question or "").translate(_HOMOGLYPH_MAP))
    checks = (
        ("animal", "icons", "counts"),
        ("animal", "icons", "according"),
        ("select", "animal", "icons"),
    )
    return any(all(word in q for word in words) for words in checks)
def decode_b64_to_pil(image_b64: str) -> Image.Image:
    """Decode a base64 image (optionally a ``data:image/...`` URL) into an RGB PIL image.

    Whitespace inside the payload (e.g. line-wrapped base64) is removed and
    missing ``=`` padding is restored, since some clients strip it.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image.
    """
    if image_b64.startswith("data:image"):
        image_b64 = image_b64.split(",", 1)[-1]
    payload = "".join(image_b64.split())
    # Base64 length must be a multiple of 4; re-add any stripped padding.
    payload += "=" * (-len(payload) % 4)
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert("RGB")


def preprocess(pil_img: Image.Image, img_size: int = 256) -> np.ndarray:
    """Turn a PIL image into a float32 NCHW batch of one with values in [0, 1].

    The image is resized to ``img_size`` x ``img_size`` (aspect ratio is not
    preserved); the result has shape ``(1, channels, img_size, img_size)``.
    """
    resized = pil_img.resize((img_size, img_size))
    hwc = np.array(resized, dtype=np.float32) / 255.0
    chw = np.transpose(hwc, (2, 0, 1))
    return chw[np.newaxis, ...]


# ======================================================
# ONNX MODEL REGISTRY
# ======================================================
# Loaded ONNX Runtime sessions keyed by .onnx file path (filled by load_session()).
_model_cache: Dict[str, ort.InferenceSession] = {}
# Normalized label -> .onnx path for models currently living in HOT_DIR.
_registry: Dict[str, str] = {}
# Normalized label -> .onnx path for models parked in COLD_DIR (not loaded).
_cold_index: Dict[str, str] = {}
# .onnx path -> time.time() of last load/promotion (presumably also refreshed on
# use by code outside this section — verify); drives eviction order.
_model_last_used: Dict[str, float] = {}
# .onnx paths of the models currently in the hot set.
_hot_model_files: Set[str] = set()

# Guards every dict/set above. Slow work (session construction, folder moves)
# is done outside the lock.
_lock = threading.Lock()


# ======================================================
# ONNX SESSION LOADER
# ======================================================
def load_session(model_path: str) -> ort.InferenceSession:
    """Return the cached ONNX Runtime session for *model_path*, creating it on first use.

    The session is constructed outside ``_lock`` so a slow load does not block
    other lookups. If two threads race to load the same path, the first session
    published to the cache wins and every caller receives that same object;
    the losing session is discarded.

    Raises whatever ``ort.InferenceSession`` raises for a missing or invalid
    model file.
    """
    with _lock:
        cached = _model_cache.get(model_path)
    if cached is not None:
        return cached

    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    with _lock:
        # setdefault keeps whichever session was published first.
        return _model_cache.setdefault(model_path, sess)


def detect(
    sess: ort.InferenceSession,
    tensor: np.ndarray,
    conf_threshold: float,
) -> bool:
    """Return True if any prediction scores TARGET_CLASS_INDEX at or above *conf_threshold*.

    Only the first model output and its first batch item are read. That item
    is indexed as ``[4 + num_classes][num_predictions]``: rows 0-3 are
    presumably box coordinates (YOLOv8-style export — not verified here), and
    row ``4 + k`` holds the scores for class ``k``. Returns False if the model
    has no row for TARGET_CLASS_INDEX.
    """
    feed = {sess.get_inputs()[0].name: tensor}
    preds = sess.run(None, feed)[0][0]

    if TARGET_CLASS_INDEX >= preds.shape[0] - 4:
        return False

    class_scores = preds[4 + TARGET_CLASS_INDEX, :]
    return bool((class_scores >= conf_threshold).any())


# ======================================================
# MODEL SCANNING
# ======================================================
def _find_onnx_in_folder(folder: str) -> Optional[str]:
    if not os.path.isdir(folder):
        return None

    for f in os.listdir(folder):
        if f.lower().endswith(".onnx"):
            return os.path.join(folder, f)

    return None


def scan_hot_models():
    """Load models from HOT_DIR at startup, overflowing extras to COLD_DIR.

    Sub-folders of HOT_DIR are visited in sorted order. Each folder that
    contains an ``.onnx`` file is loaded and registered under
    ``normalize_label(folder_name)`` until MIN_HOT_MODELS models are hot. Any
    further folders are moved to COLD_DIR and recorded in the cold index.
    A folder is never moved onto an existing cold folder, because
    ``shutil.move`` would nest it inside and leave the index pointing at a
    missing file. Errors are logged per folder and never abort the scan.
    """
    os.makedirs(HOT_DIR, exist_ok=True)

    for entry in sorted(os.listdir(HOT_DIR)):  # deterministic order
        folder_path = os.path.join(HOT_DIR, entry)
        if not os.path.isdir(folder_path):
            continue

        onnx_path = _find_onnx_in_folder(folder_path)
        if not onnx_path:
            continue

        norm = normalize_label(entry)

        try:
            with _lock:
                hot_full = len(_hot_model_files) >= MIN_HOT_MODELS

            if hot_full:
                # Already at capacity — move this folder straight to cold.
                cold_folder = os.path.join(COLD_DIR, entry)
                if os.path.exists(cold_folder):
                    print(f"[hot→cold] Failed to move '{entry}': '{cold_folder}' already exists")
                    continue
                try:
                    os.makedirs(COLD_DIR, exist_ok=True)
                    shutil.move(folder_path, cold_folder)
                except Exception as e:
                    print(f"[hot→cold] Failed to move '{entry}': {e}")
                    continue
                cold_onnx = os.path.join(cold_folder, os.path.basename(onnx_path))
                with _lock:
                    _cold_index[norm] = cold_onnx
                print(f"[hot→cold] '{entry}' moved to cold at startup (hot full)")
                continue

            load_session(onnx_path)

            with _lock:
                _registry[norm] = onnx_path
                _hot_model_files.add(onnx_path)
                _model_last_used[onnx_path] = time.time()

            print(f"[hot] Loaded: {entry}")

        except Exception as e:
            print(f"[hot] Error loading {folder_path}: {e}")


def scan_cold_models():
    """Index model folders in COLD_DIR without loading them.

    Each sub-folder containing an ``.onnx`` file is recorded in the cold index
    under ``normalize_label(folder_name)``. Folders are visited in sorted order
    (as scan_hot_models does), so when two folder names normalize to the same
    label the same one wins on every platform. Existing index entries are
    never overwritten.
    """
    os.makedirs(COLD_DIR, exist_ok=True)

    count = 0

    for entry in sorted(os.listdir(COLD_DIR)):  # deterministic order
        folder_path = os.path.join(COLD_DIR, entry)
        if not os.path.isdir(folder_path):
            continue

        onnx_path = _find_onnx_in_folder(folder_path)
        if not onnx_path:
            continue

        norm = normalize_label(entry)

        with _lock:
            if norm not in _cold_index:
                _cold_index[norm] = onnx_path
                count += 1

    print(f"[cold] Indexed {count} model folder(s) in {COLD_DIR}")


# ======================================================
# PROMOTE MODEL: cold -> hot
# ======================================================
def promote_from_cold(norm_question: str) -> Optional[str]:
    """Move the cold model for *norm_question* into HOT_DIR, then load and register it.

    Args:
        norm_question: a label already passed through normalize_label().

    Returns:
        The hot ``.onnx`` path on success. If the label is already registered
        as hot (for example another request promoted it first), that path is
        returned. Returns None when there is no cold entry, the cold folder
        has vanished (the stale entry is dropped), the destination already
        exists (``shutil.move`` would nest the folder inside it), the move
        fails, no ``.onnx`` file is found, or the model fails to load.
    """
    with _lock:
        already_hot = _registry.get(norm_question)
        cold_onnx = _cold_index.get(norm_question)

    if already_hot is not None:
        return already_hot

    if cold_onnx is None:
        return None

    cold_folder = os.path.dirname(cold_onnx)
    folder_name = os.path.basename(cold_folder)
    hot_folder = os.path.join(HOT_DIR, folder_name)

    if not os.path.isdir(cold_folder):
        # Stale index entry: the folder is gone, so forget it.
        with _lock:
            _cold_index.pop(norm_question, None)
        return None

    if os.path.exists(hot_folder):
        print(f"[promote] Failed to move '{folder_name}' cold -> hot: '{hot_folder}' already exists")
        return None

    try:
        os.makedirs(HOT_DIR, exist_ok=True)
        shutil.move(cold_folder, hot_folder)
    except Exception as e:
        print(f"[promote] Failed to move '{folder_name}' cold -> hot: {e}")
        return None

    print(f"[promote] '{folder_name}' cold -> hot")

    hot_onnx = _find_onnx_in_folder(hot_folder)
    if not hot_onnx:
        print(f"[promote] No .onnx file found in '{hot_folder}'")
        return None

    try:
        load_session(hot_onnx)
    except Exception as e:
        print(f"[promote] Failed to load '{hot_onnx}': {e}")
        return None

    with _lock:
        _registry[norm_question] = hot_onnx
        _hot_model_files.add(hot_onnx)
        _model_last_used[hot_onnx] = time.time()
        _cold_index.pop(norm_question, None)

    return hot_onnx


# ======================================================
# EVICT MODEL: hot -> cold
# ======================================================
def evict_unused_models():
    """Move idle hot models back to COLD_DIR while keeping MIN_HOT_MODELS hot.

    Hot models are considered least-recently-used first, and at most
    ``len(hot) - MIN_HOT_MODELS`` are evicted. A model is evicted only if it
    has been idle for at least EVICTION_INTERVAL_SEC. The idle re-check and
    the detach of its bookkeeping happen under one lock acquisition:
      * its registry labels move to the cold index,
      * it leaves the hot set,
      * its cached session is dropped.
    The folder is then moved. If the move fails, or the cold destination
    already exists (``shutil.move`` would nest the folder inside it), the
    bookkeeping is rolled back so the model stays registered and loadable
    from HOT_DIR. The session is reloaded lazily by load_session().
    """
    now = time.time()

    with _lock:
        hot_count = len(_hot_model_files)
        if hot_count <= MIN_HOT_MODELS:
            return
        candidates = sorted(_hot_model_files, key=lambda p: _model_last_used.get(p, 0))

    evictable = hot_count - MIN_HOT_MODELS
    evicted = 0

    for hot_onnx in candidates:
        if evicted >= evictable:
            break

        hot_folder = os.path.dirname(hot_onnx)
        folder_name = os.path.basename(hot_folder)
        cold_folder = os.path.join(COLD_DIR, folder_name)
        cold_onnx = os.path.join(cold_folder, os.path.basename(hot_onnx))

        if os.path.exists(cold_folder):
            print(f"[evict] Failed to move '{folder_name}': '{cold_folder}' already exists")
            continue

        with _lock:
            last = _model_last_used.get(hot_onnx, 0)
            # Re-check idleness atomically with the detach so a request that
            # just used the model is not raced.
            if hot_onnx not in _hot_model_files or now - last < EVICTION_INTERVAL_SEC:
                continue

            keys = [q for q, p in _registry.items() if p == hot_onnx]
            for q in keys:
                del _registry[q]
                _cold_index[q] = cold_onnx

            _hot_model_files.discard(hot_onnx)
            _model_last_used.pop(hot_onnx, None)
            _model_cache.pop(hot_onnx, None)

        try:
            os.makedirs(COLD_DIR, exist_ok=True)
            shutil.move(hot_folder, cold_folder)
        except Exception as e:
            print(f"[evict] Failed to move '{folder_name}': {e}")
            # Roll back: the files are still in HOT_DIR, so keep them usable.
            with _lock:
                for q in keys:
                    _cold_index.pop(q, None)
                    _registry[q] = hot_onnx
                _hot_model_files.add(hot_onnx)
                _model_last_used[hot_onnx] = last
            continue

        print(f"[evict] '{folder_name}' hot -> cold")
        evicted += 1


# ======================================================
# BACKGROUND EVICTOR THREAD
# ======================================================
_evictor_stop = threading.Event()


def _evictor_loop():
    """Background loop: run evict_unused_models() every EVICTION_INTERVAL_SEC.

    The loop exits as soon as ``_evictor_stop`` is set, including while it is
    waiting. An exception from an eviction pass is printed, and the loop
    keeps running.
    """
    while True:
        if _evictor_stop.is_set():
            break
        _evictor_stop.wait(EVICTION_INTERVAL_SEC)
        if _evictor_stop.is_set():
            break

        try:
            evict_unused_models()
        except Exception as e:
            print(f"[evictor] error: {e}")


# ======================================================
# TILE-FIRST UNPAIRED LETTER SOLVER
# ======================================================
def _order_points(pts: np.ndarray) -> np.ndarray:
    rect = np.zeros((4, 2), dtype="float32")

    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]      # top-left
    rect[2] = pts[np.argmax(s)]      # bottom-right

    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]   # top-right
    rect[3] = pts[np.argmax(diff)]   # bottom-left

    return rect


def _warp_tile(img_bgr: np.ndarray, box: np.ndarray) -> Optional[np.ndarray]:
    """Perspective-warp the quadrilateral *box* out of *img_bgr* into an upright crop.

    The output width and height are the longer of each pair of opposite
    edges. Returns None if either dimension is below 15 px.
    """
    src = _order_points(box.astype("float32"))
    top_left, top_right, bottom_right, bottom_left = src

    width = int(max(np.linalg.norm(bottom_right - bottom_left),
                    np.linalg.norm(top_right - top_left)))
    height = int(max(np.linalg.norm(top_right - bottom_right),
                     np.linalg.norm(top_left - bottom_left)))

    if min(width, height) < 15:
        return None

    right, bottom = width - 1, height - 1
    dst = np.float32([[0, 0], [right, 0], [right, bottom], [0, bottom]])

    transform = cv2.getPerspectiveTransform(src, dst)
    return cv2.warpPerspective(img_bgr, transform, (width, height))


def _detect_tile_boxes(pil_img: Image.Image) -> List[Dict[str, Any]]:
    """
    Detect the coloured letter tiles in a puzzle image.

    Tiles are found from their borders (Canny edges -> contours -> min-area
    rectangles), not from the white letters directly, so white background
    detail (e.g. waterfall lines) is ignored. A candidate is kept when it:
      * has an area between 0.5% and 10% of the image,
      * is square-ish (side ratio <= 1.55),
      * fills at least 50% of its bounding rectangle,
      * is at least 22% covered by tile-coloured pixels (blue/cyan,
        yellow/orange/brown, or deep brown/red hues).

    Images whose shorter side is under 300 px are first upscaled ("processing
    space") so tile edges survive blurring.

    Returns one dict per tile, highest score first, with near-duplicate
    contours of the same tile removed. Keys:
      cx, cy      tile centre in the ORIGINAL ``pil_img`` pixel space
      area, ratio, rectangularity, tile_ratio, score
                  contour statistics (area is in processing space)
      crop        upright BGR crop of the tile (processing space)
      box         4 corner points from cv2.boxPoints (processing space)
    """
    img_rgb = np.array(pil_img.convert("RGB"))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    ih, iw = img_bgr.shape[:2]
    if ih == 0 or iw == 0:
        return []

    # Upscale tiny images so tile edges are detectable.
    TARGET_MIN_DIM = 300
    scale_factor = 1.0
    if min(ih, iw) < TARGET_MIN_DIM:
        scale_factor = TARGET_MIN_DIM / min(ih, iw)
        new_w = int(iw * scale_factor)
        new_h = int(ih * scale_factor)
        img_bgr = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        ih, iw = img_bgr.shape[:2]

    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    img_area = ih * iw

    # Minimum tile side: ~6% of the shorter (processing-space) dimension.
    min_tile_side = max(15, int(min(ih, iw) * 0.06))
    # Dedup radius: ~8% of the shorter processing-space dimension, converted to
    # original-image units because the stored tile centres are in those units.
    dedup_radius = max(20, int(min(ih, iw) * 0.08)) / scale_factor

    # Edge rectangle detection works well for the tile border.
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    edges = cv2.Canny(blur, 50, 150)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)

    contours, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    candidates: List[Dict[str, Any]] = []

    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < img_area * 0.005 or area > img_area * 0.10:
            continue

        rect = cv2.minAreaRect(cnt)
        (cx, cy), (rw, rh), _angle = rect
        if rw < min_tile_side or rh < min_tile_side:
            continue

        ratio = max(rw, rh) / max(1.0, min(rw, rh))
        rect_area = rw * rh
        rectangularity = area / rect_area if rect_area else 0

        # Tiles are square-ish; weak/random contours fill little of their rect.
        if ratio > 1.55 or rectangularity < 0.50:
            continue

        box = cv2.boxPoints(rect)
        crop = _warp_tile(img_bgr, box)
        if crop is None:
            continue

        # Check that the crop is mostly tile-coloured.
        hsv_crop = cv2.cvtColor(crop, cv2.COLOR_BGR2HSV)
        h = hsv_crop[:, :, 0]
        s = hsv_crop[:, :, 1]
        v = hsv_crop[:, :, 2]

        tile_bool = (
            (s > 35)
            & (v > 45)
            & (
                ((h >= 70) & (h <= 125))   # blue/cyan tiles
                | ((h >= 8) & (h <= 55))   # yellow/orange/brown tiles
                | ((h >= 0) & (h <= 8))    # deep brown/red tiles
            )
        )

        tile_ratio = float(np.count_nonzero(tile_bool) / tile_bool.size)
        if tile_ratio < 0.22:
            continue

        score = float(area * rectangularity * tile_ratio / ratio)

        candidates.append(
            {
                # Centre scaled back to original image space.
                "cx": float(cx / scale_factor),
                "cy": float(cy / scale_factor),
                "area": float(area),
                "ratio": float(ratio),
                "rectangularity": float(rectangularity),
                "tile_ratio": float(tile_ratio),
                "score": score,
                "crop": crop,
                "box": box,
            }
        )

    candidates.sort(key=lambda c: c["score"], reverse=True)

    # Deduplicate repeated contours around the same tile (highest score wins).
    filtered: List[Dict[str, Any]] = []
    for c in candidates:
        is_duplicate = any(
            ((c["cx"] - f["cx"]) ** 2 + (c["cy"] - f["cy"]) ** 2) ** 0.5 < dedup_radius
            for f in filtered
        )
        if not is_duplicate:
            filtered.append(c)

    # No cap on the count — puzzles may have 3, 4, 5 or more tiles.
    return filtered


def _letter_mask_from_tile_crop(tile_bgr: np.ndarray) -> Optional[np.ndarray]:
    """
    Isolate the white letter in a single tile crop as an 80x80 binary mask.

    Steps:
      1. Trim a 16% border from each side of the crop.
      2. Threshold near-white pixels in HSV and clean up with a 2x2 open/close.
      3. Keep the largest external contour; it must cover at least
         max(15, 3% of the trimmed area).
      4. Crop tightly around it (with a small pad window).
      5. Paste it, aspect-preserved and scaled to fit 70 px, centred on a
         black 80x80 canvas.

    Returns None for crops under 20 px or when no usable letter is found.
    """
    height, width = tile_bgr.shape[:2]
    if min(height, width) < 20:
        return None

    # Trim the tile border; the letter sits inside it.
    margin_x = int(width * 0.16)
    margin_y = int(height * 0.16)
    inner = tile_bgr[margin_y:height - margin_y, margin_x:width - margin_x]
    if inner.size == 0:
        inner = tile_bgr

    # Near-white pixels: any hue, moderate saturation, bright value.
    white = cv2.inRange(
        cv2.cvtColor(inner, cv2.COLOR_BGR2HSV),
        np.array([0, 0, 150], dtype=np.uint8),
        np.array([179, 135, 255], dtype=np.uint8),
    )
    kernel = np.ones((2, 2), np.uint8)
    white = cv2.morphologyEx(white, cv2.MORPH_OPEN, kernel, iterations=1)
    white = cv2.morphologyEx(white, cv2.MORPH_CLOSE, kernel, iterations=1)

    contours, _ = cv2.findContours(white, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    # The letter is taken to be the largest white component.
    letter = max(contours, key=cv2.contourArea)
    min_letter_area = max(15, int(inner.shape[0] * inner.shape[1] * 0.03))
    if cv2.contourArea(letter) < min_letter_area:
        return None

    x, y, box_w, box_h = cv2.boundingRect(letter)
    pad = 4
    window = white[
        max(0, y - pad):min(white.shape[0], y + box_h + pad),
        max(0, x - pad):min(white.shape[1], x + box_w + pad),
    ]

    nonzero = cv2.findNonZero(window)
    if nonzero is None:
        return None

    nx, ny, nz_w, nz_h = cv2.boundingRect(nonzero)
    glyph = window[ny:ny + nz_h, nx:nx + nz_w]

    glyph_h, glyph_w = glyph.shape[:2]
    if glyph_w <= 0 or glyph_h <= 0:
        return None

    size = 80
    fit = min((size - 10) / glyph_w, (size - 10) / glyph_h)
    new_w = max(1, int(glyph_w * fit))
    new_h = max(1, int(glyph_h * fit))
    scaled = cv2.resize(glyph, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    canvas = np.zeros((size, size), dtype=np.uint8)
    off_x = (size - new_w) // 2
    off_y = (size - new_h) // 2
    canvas[off_y:off_y + new_h, off_x:off_x + new_w] = scaled
    return canvas


def _mask_distance(a: np.ndarray, b: np.ndarray) -> float:
    """
    Rotation-tolerant IoU distance between two binary letter masks.

    *b* is rotated about its centre by each of -15, -10, -6, -3, 0, 3, 6, 10
    and 15 degrees. Both masks are dilated with a 3x3 kernel to tolerate small
    misalignment, and ``1 - IoU`` is computed. The smallest value over all
    angles is returned:
      * 0.0 means identical; larger means more different;
      * 999.0 means every comparison had an empty union.

    Both masks must have the same shape (the solver uses 80x80). The measure
    is not exactly symmetric, because only *b* is rotated.
    """
    kernel = np.ones((3, 3), np.uint8)
    # `a` is the same for every angle, so dilate it once.
    a_fat = cv2.dilate(a, kernel, iterations=1) > 0

    h, w = b.shape[:2]
    center = (w / 2, h / 2)

    best = 999.0
    for angle in (-15, -10, -6, -3, 0, 3, 6, 10, 15):
        if angle:
            rot = cv2.getRotationMatrix2D(center, angle, 1.0)
            rotated = cv2.warpAffine(
                b,
                rot,
                (w, h),
                flags=cv2.INTER_NEAREST,
                borderValue=0,
            )
        else:
            rotated = b

        b_fat = cv2.dilate(rotated, kernel, iterations=1) > 0

        union = np.logical_or(a_fat, b_fat).sum()
        if union <= 0:
            continue
        inter = np.logical_and(a_fat, b_fat).sum()

        best = min(best, float(1.0 - inter / union))

    return best


def _pick_unpaired_tile_by_shape(tiles: List[Dict[str, Any]], masks: List[np.ndarray]) -> Optional[Dict[str, Any]]:
    """
    Find the ONE tile whose letter shape is unlike all the others.

    Each tile in turn is set aside, and the remaining tiles are scored by the
    average pairwise ``_mask_distance`` among them. Removing the odd-one-out
    should leave the best-matching remainder, i.e. the lowest average. The
    tile with the lowest score wins; the first index wins ties.

    The shape answer is trusted only when the winning score is at most
    PAIR_THRESHOLD (0.55) and beats the runner-up by at least MIN_GAP (0.05).
    Otherwise None is returned so the caller can fall back to OCR.

    Args:
        tiles: tile dicts with ``cx``/``cy`` centres.
        masks: one letter mask per tile, in the same order as *tiles*.

    Returns:
        ``{"x", "y", "letter": "?", "conf", "method": "tile_shape"}``, or
        None. Fewer than 3 tiles or mismatched list lengths also give None.
    """
    # Generous thresholds: brown tiles with complex letters (M, T, W, etc.)
    # have naturally high inter-letter shape distances.
    PAIR_THRESHOLD = 0.55
    MIN_GAP = 0.05

    n = len(tiles)
    if n < 3 or len(masks) != n:
        return None

    # Full pairwise distance matrix.
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            dist[i][j] = dist[j][i] = _mask_distance(masks[i], masks[j])

    print(f"[tile-unpaired] dist={[[round(x, 3) for x in row] for row in dist]}")

    # remainder_scores[k] = average pairwise distance among all tiles except k.
    remainder_scores: List[float] = []
    for single_idx in range(n):
        rest = [i for i in range(n) if i != single_idx]
        pair_dists = [dist[a][b] for pos, a in enumerate(rest) for b in rest[pos + 1:]]
        remainder_scores.append(sum(pair_dists) / len(pair_dists))

    winner = min(range(n), key=remainder_scores.__getitem__)
    best_score = remainder_scores[winner]
    best_plan = {"single": winner, "score": best_score}

    ranked = sorted(remainder_scores)
    gap = ranked[1] - ranked[0]  # n >= 3, so a runner-up always exists

    print(f"[tile-unpaired] best_plan={best_plan} gap={gap:.3f}")

    if best_score > PAIR_THRESHOLD:
        print(f"[tile-unpaired] shape rejected: avg remainder dist {best_score:.3f} > {PAIR_THRESHOLD}")
        return None

    if gap < MIN_GAP:
        print(f"[tile-unpaired] shape rejected: gap {gap:.3f} < {MIN_GAP} (ambiguous) → OCR")
        return None

    t = tiles[winner]

    return {
        "x": int(round(t["cx"])),
        "y": int(round(t["cy"])),
        "letter": "?",
        "conf": round(100.0 - best_score * 100.0, 2),
        "method": "tile_shape",
    }


def _clean_ocr_text(text: str) -> str:
    text = (text or "").upper().strip()
    return re.sub(r"[^A-Z]", "", text)


def _make_tile_letter_strip(masks: List[np.ndarray], tiles: List[Dict[str, Any]]):
    """
    Build a single-line OCR image from per-tile letter masks (fallback path).

    The masks are:
      * ordered left to right by tile ``cx``;
      * inverted to black-on-white, which Tesseract prefers;
      * resized to 80x80;
      * joined with 34 px white gaps and 25 px white margins at both ends.

    Returns ``(strip, ordered_tiles)``, or ``(None, [])`` when there are fewer
    than 3 masks or the two lists differ in length.
    """
    count = len(masks)
    if count < 3 or len(tiles) != count:
        return None, []

    left_to_right = sorted(range(count), key=lambda idx: tiles[idx]["cx"])
    ordered_tiles = [tiles[idx] for idx in left_to_right]

    white_margin = np.full((80, 25), 255, dtype=np.uint8)
    white_gap = np.full((80, 34), 255, dtype=np.uint8)

    parts = [white_margin]
    for pos, idx in enumerate(left_to_right):
        if pos:
            parts.append(white_gap)
        glyph = cv2.bitwise_not(masks[idx])
        parts.append(cv2.resize(glyph, (80, 80), interpolation=cv2.INTER_CUBIC))
    parts.append(white_margin)

    return np.hstack(parts), ordered_tiles


def _pick_unpaired_by_ocr_fallback(tiles: List[Dict[str, Any]], masks: List[np.ndarray]) -> Optional[Dict[str, Any]]:
    """
    OCR fallback: read every tile's letter and return the tile whose letter occurs once.

    All letter masks go into one strip (built by _make_tile_letter_strip).
    Tesseract reads the strip in single-line mode (psm 7), restricted to A-Z.
    The result is used only if it yields exactly one letter per tile.

      * Exactly one letter occurs once: return that tile (method "tile_ocr").
      * Several letters occur once (typically one misread, e.g. T read as F):
        among those tiles, return the one with the largest average shape
        distance to all other tiles (method "tile_ocr_isolation").
      * Otherwise: return None.

    Note: ``conf`` in the result is the tile-detection ``score``, not an OCR
    confidence.
    """
    n = len(tiles)
    strip, ordered_tiles = _make_tile_letter_strip(masks, tiles)
    if strip is None or len(ordered_tiles) != n:
        return None

    config = "--psm 7 --oem 3 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    try:
        text = pytesseract.image_to_string(strip, config=config)
    except Exception as e:
        print(f"[tile-ocr] tesseract error: {e}")
        return None

    letters = _clean_ocr_text(text)
    print(f"[tile-ocr] strip_letters={letters!r}")

    if len(letters) != n:
        return None

    detections = [
        {
            "letter": ch,
            "x": int(round(t["cx"])),
            "y": int(round(t["cy"])),
            "conf": float(t.get("score", 0)),
        }
        for ch, t in zip(letters, ordered_tiles)
    ]

    counts = Counter(d["letter"] for d in detections)
    singles = [letter for letter, count in counts.items() if count == 1]

    print(f"[tile-ocr] detections={detections} counts={dict(counts)}")

    # Best case: exactly one letter appears only once.
    if len(singles) == 1:
        best = next(d for d in detections if d["letter"] == singles[0])
        best["method"] = "tile_ocr"
        return best

    # OCR probably misread one letter, giving several apparent singles. The
    # true odd-one-out should be the most different from all others in shape.
    if len(singles) >= 2:
        # Align masks with the left-to-right order used for the OCR strip.
        mask_by_tile = {id(t): m for t, m in zip(tiles, masks)}
        ordered_masks = [mask_by_tile[id(t)] for t in ordered_tiles]

        isolations = []
        for i in range(n):
            dists = [_mask_distance(ordered_masks[i], ordered_masks[j]) for j in range(n) if j != i]
            isolations.append(sum(dists) / len(dists))

        best_pos = max(
            (pos for pos, d in enumerate(detections) if d["letter"] in singles),
            key=lambda pos: isolations[pos],
        )
        best = detections[best_pos]
        print(f"[tile-ocr] 2-singles tiebreak by isolation: picked letter={best['letter']} isolations={[round(x,3) for x in isolations]}")
        best["method"] = "tile_ocr_isolation"
        return best

    return None


def _attempt_solve(pil_img: Image.Image, scale: float = 1.0) -> Optional[Dict[str, Any]]:
    """
    Run one detect -> extract -> pick pass over *pil_img*.

    *scale* is the factor by which *pil_img* was upscaled from the request
    image. The returned ``x``/``y`` are divided by it, so they are in
    original-image pixels. Shape matching is tried first, then the OCR
    fallback.

    Returns the picked tile dict, or None when fewer than 3 tiles (or
    letters) are found or neither method produces an answer.
    """
    MIN_TILES = 3

    tiles = _detect_tile_boxes(pil_img)
    centers = [(int(round(t['cx'])), int(round(t['cy']))) for t in tiles]
    print(f"[tile-unpaired] tiles={len(tiles)} scale={scale:.1f} centers={centers}")

    if len(tiles) < MIN_TILES:
        print(f"[tile-unpaired] only {len(tiles)} tile(s) detected (need >= {MIN_TILES})")
        return None

    extracted = [(t, _letter_mask_from_tile_crop(t["crop"])) for t in tiles]
    valid_tiles = [t for t, m in extracted if m is not None]
    masks: List[np.ndarray] = [m for _, m in extracted if m is not None]

    print(f"[tile-unpaired] valid_tile_letters={len(valid_tiles)}")

    if len(valid_tiles) < MIN_TILES:
        print(f"[tile-unpaired] only {len(valid_tiles)} letter(s) extracted (need >= {MIN_TILES})")
        return None

    best = (
        _pick_unpaired_tile_by_shape(valid_tiles, masks)
        or _pick_unpaired_by_ocr_fallback(valid_tiles, masks)
    )

    if best and scale != 1.0:
        # Map coordinates back to the original (un-upscaled) image.
        best["x"] = int(round(best["x"] / scale))
        best["y"] = int(round(best["y"] / scale))

    return best


def solve_unpaired_letter(image_b64: str) -> Optional[Dict[str, Any]]:
    """
    Find the tile whose letter has no pair, trying progressively larger upscales.

    The image is tried at 1x, then 2x, then 3x (bilinear resize). The loop stops
    at the first scale where _attempt_solve returns a result. Upscaling helps
    letter extraction succeed on small images whose tile crops would otherwise
    fall below the contour-area threshold.

    Returns:
        {"x", "y", "letter", "conf", "method"}, with x/y in original-image pixels.
        Returns None if no scale produced a result.
    """
    t0 = time.time()

    source = decode_b64_to_pil(image_b64)
    width, height = source.size

    result = None
    for factor in (1.0, 2.0, 3.0):
        if factor == 1.0:
            candidate = source
        else:
            scaled_w, scaled_h = int(width * factor), int(height * factor)
            candidate = source.resize((scaled_w, scaled_h), Image.BILINEAR)
            print(f"[tile-unpaired] retrying at {factor:.0f}x upscale ({scaled_w}x{scaled_h})")

        result = _attempt_solve(candidate, scale=factor)
        if result:
            break

    print(f"[tile-unpaired] result={result} time={round(time.time() - t0, 3)}s")

    if not result:
        return None

    return {
        "x": int(result["x"]),
        "y": int(result["y"]),
        "letter": result.get("letter", "?"),
        "conf": round(float(result.get("conf", 0)), 2),
        "method": result.get("method", ""),
    }


# ======================================================
# CHECKLIST LETTER SOLVER
# ======================================================

def _ocr_single_tile_letter(tile_bgr: np.ndarray) -> str:
    """
    Return the first A-Z letter Tesseract reads from one tile crop, or "".

    Variants tried in order (first non-empty read wins):
      1. The white-letter mask from _letter_mask_from_tile_crop. It is inverted
         to dark-on-light, resized to 80x80 and centred on a 100x100 white
         canvas, then OCR'd with --psm 10 and then --psm 8.
      2. The raw crop, upscaled so its longer side is at least 96 px and
         converted to grey. It is binarised four ways (Otsu, adaptive Gaussian,
         fixed 127, fixed 127 inverted), each result padded with a 10 px white
         border and OCR'd with --psm 10 and then --psm 8.

    Tesseract errors are swallowed so the next variant is still attempted.
    """
    configs = (
        "--psm 10 --oem 3 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        "--psm 8  --oem 3 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    )

    # Variant 1: mask-based (suits white letters on coloured tiles).
    mask = _letter_mask_from_tile_crop(tile_bgr)
    if mask is not None:
        canvas = np.full((100, 100), 255, dtype=np.uint8)
        canvas[10:90, 10:90] = cv2.resize(
            cv2.bitwise_not(mask), (80, 80), interpolation=cv2.INTER_CUBIC
        )
        for cfg in configs:
            try:
                found = _clean_ocr_text(pytesseract.image_to_string(canvas, config=cfg))[:1]
            except Exception:
                continue
            if found:
                return found

    # Variant 2: binarise the raw crop several ways.
    rows, cols = tile_bgr.shape[:2]
    factor = max(1.0, 96 / max(rows, cols, 1))
    upscaled = cv2.resize(
        tile_bgr,
        (max(1, int(cols * factor)), max(1, int(rows * factor))),
        interpolation=cv2.INTER_CUBIC,
    )
    grey = cv2.cvtColor(upscaled, cv2.COLOR_BGR2GRAY)

    binarizers = (
        lambda g: cv2.threshold(g, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1],
        lambda g: cv2.adaptiveThreshold(g, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2),
        lambda g: cv2.threshold(g, 127, 255, cv2.THRESH_BINARY)[1],
        lambda g: cv2.threshold(g, 127, 255, cv2.THRESH_BINARY_INV)[1],
    )

    for binarize in binarizers:
        # One failure abandons the remaining configs for this binarisation only.
        try:
            bw = binarize(grey)
            bordered = np.full((bw.shape[0] + 20, bw.shape[1] + 20), 255, dtype=np.uint8)
            bordered[10:-10, 10:-10] = bw
            for cfg in configs:
                found = _clean_ocr_text(pytesseract.image_to_string(bordered, config=cfg))[:1]
                if found:
                    return found
        except Exception:
            continue

    return ""


def _extract_checklist_letters(pil_img: Image.Image) -> List[str]:
    """
    Read the target letters from the checklist panel on the right of the board.

    The panel is everything right of 72% of the image width (the rightmost ~28%).
    Tiles detected there are processed top-to-bottom:
      - each tile's letter is OCR'd with _ocr_single_tile_letter;
      - its count is read with _read_count_badge (which falls back to 1).
    Tiles whose letter cannot be read are skipped.

    Returns:
        A flat list with each letter repeated by its count, e.g. ["T", "T", "T"]
        for a "3x T" entry. Returns [] when no tiles are detected.
    """
    bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    full_w = bgr.shape[1]

    x0 = int(full_w * 0.72)
    panel_bgr = bgr[:, x0:, :]
    panel_h, panel_w = panel_bgr.shape[:2]

    tiles = _detect_tile_boxes(Image.fromarray(cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2RGB)))
    print(f"[checklist] panel tiles detected: {len(tiles)} in crop x>{x0}")

    if not tiles:
        return []

    panel_grey = cv2.cvtColor(panel_bgr, cv2.COLOR_BGR2GRAY)
    letters: List[str] = []

    # Tile cx/cy are relative to the panel crop, matching panel_grey.
    for tile in sorted(tiles, key=lambda t: t["cy"]):
        letter = _ocr_single_tile_letter(tile["crop"])
        if not letter:
            print(f"[checklist] tile at ({tile['cx']:.0f},{tile['cy']:.0f}) — OCR failed")
            continue

        count = _read_count_badge(panel_grey, tile, panel_w, panel_h)
        print(f"[checklist] letter={letter!r} count={count}")
        letters.extend([letter] * count)

    return letters


def _read_count_badge(gray: np.ndarray, tile: Dict[str, Any], pw: int, ph: int) -> int:
    """
    Read the "<n>x" count badge associated with a checklist tile.

    The search region is a horizontal band across the full panel width. It runs
    from 90 px above the tile centre to 10 px below it, clamped to the panel.
    Bands whose shorter side is under 20 px are upscaled first. The band is
    Otsu-binarised as-is and then inverted, and each version is OCR'd
    (whitelist: digits and x). The first "<digits>x" match with a value in
    1..9 is returned.

    Args:
        gray: greyscale panel image.
        tile: tile dict whose "cy" is in panel coordinates.
        pw, ph: panel width and height.

    Returns:
        The badge count, or 1 when no valid badge is read.
    """
    cy = int(tile["cy"])

    top = max(0, cy - 90)
    bottom = min(ph, cy + 10)
    if bottom <= top:
        return 1

    roi = gray[top:bottom, 0:pw]

    # Give Tesseract enough pixels on very thin bands.
    shortest = min(roi.shape[0], roi.shape[1])
    if shortest < 20:
        factor = max(1, int(40 / max(shortest, 1)))
        roi = cv2.resize(roi, None, fx=factor, fy=factor, interpolation=cv2.INTER_LINEAR)

    ocr_config = "--psm 11 --oem 3 -c tessedit_char_whitelist=0123456789xX"
    count_re = re.compile(r"(\d+)\s*[xX]")

    for invert in (False, True):
        try:
            source = cv2.bitwise_not(roi) if invert else roi
            binary = cv2.threshold(source, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
            text = pytesseract.image_to_string(binary, config=ocr_config).strip().lower()
            match = count_re.search(text)
            if match:
                value = int(match.group(1))
                if 1 <= value <= 9:
                    print(f"[checklist] badge text={text!r} count={value}")
                    return value
        except Exception:
            pass

    return 1


def _ocr_all_grid_letters(pil_img: Image.Image) -> List[Dict[str, Any]]:
    """
    OCR every tile detected in the left part of the board (x < 75% of width).

    Returns:
        One {"letter", "x", "y"} dict per tile whose letter could be read, where
        x/y is the rounded tile centre. The crop starts at x=0, so these are
        also coordinates in ``pil_img``.
    """
    bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)

    # NOTE(review): 0.75 overlaps the checklist crop (which starts at 0.72);
    # presumably intentional slack for edge tiles — confirm.
    x_end = int(bgr.shape[1] * 0.75)
    grid_rgb = cv2.cvtColor(bgr[:, :x_end, :], cv2.COLOR_BGR2RGB)

    tiles = _detect_tile_boxes(Image.fromarray(grid_rgb))
    print(f"[checklist] grid tiles detected: {len(tiles)}")

    found: List[Dict[str, Any]] = []
    for tile in tiles:
        letter = _ocr_single_tile_letter(tile["crop"])
        if letter:
            found.append({
                "letter": letter,
                "x": int(round(tile["cx"])),
                "y": int(round(tile["cy"])),
            })

    return found


def solve_checklist_letters(image_b64: str) -> List[Dict[str, Any]]:
    """
    Full solver for 'Select all letters from the checklist based on the count'.

    1. Extract target letters (repeated by badge count) from the right checklist panel.
    2. OCR all tiles in the left grid.
    3. Match target letters to grid tiles, never clicking the same tile twice.
    4. Return one {"x", "y", "letter"} dict per click, in original image coordinates.

    Native resolution is tried first. A 2x upscale is tried when either list
    comes back empty. Returns [] when no target or no grid letters are found.
    """
    start = time.time()
    pil_img = decode_b64_to_pil(image_b64)
    orig_w, orig_h = pil_img.size

    target_letters: List[str] = []
    grid_letters: List[Dict[str, Any]] = []
    # Scale of the attempt whose results are kept. Grid coords are divided by it.
    used_scale = 1.0

    for scale in (1.0, 2.0):
        if scale == 1.0:
            attempt = pil_img
        else:
            attempt = pil_img.resize((int(orig_w * scale), int(orig_h * scale)), Image.BILINEAR)

        used_scale = scale
        target_letters = _extract_checklist_letters(attempt)
        grid_letters = _ocr_all_grid_letters(attempt)

        print(f"[checklist] scale={scale} targets={target_letters} grid_letters={[g['letter'] for g in grid_letters]}")

        if target_letters and grid_letters:
            break

    if not target_letters:
        print("[checklist] no target letters found in checklist panel")
        return []

    if not grid_letters:
        print("[checklist] no grid letters detected")
        return []

    def _coord(g: Dict[str, Any]) -> Tuple[int, int]:
        # Grid coordinates are in the (possibly upscaled) attempt's pixel space.
        if used_scale == 1.0:
            return g["x"], g["y"]
        return int(round(g["x"] / used_scale)), int(round(g["y"] / used_scale))

    # target_letters is already expanded by count, e.g. ["V","V","V"] for 3x V.
    # If every letter appears exactly once, the badge counts probably all fell
    # back to 1. In that case click every matching tile instead of exactly one.
    badge_ocr_uncertain = len(set(target_letters)) == len(target_letters)

    clicks: List[Dict[str, Any]] = []
    used_indices: Set[int] = set()

    if badge_ocr_uncertain:
        # dict.fromkeys keeps checklist (top-to-bottom) order; iterating a set
        # would make click order vary between runs (str hash randomisation).
        for target in dict.fromkeys(target_letters):
            for idx, g in enumerate(grid_letters):
                if idx in used_indices or g["letter"] != target:
                    continue
                x, y = _coord(g)
                clicks.append({"x": x, "y": y, "letter": target})
                used_indices.add(idx)
    else:
        # Trust the counts: one unused matching tile per target occurrence.
        for target in target_letters:
            for idx, g in enumerate(grid_letters):
                if idx in used_indices or g["letter"] != target:
                    continue
                x, y = _coord(g)
                clicks.append({"x": x, "y": y, "letter": target})
                used_indices.add(idx)
                break

    print(f"[checklist] clicks={clicks} time={round(time.time()-start,3)}s")
    return clicks


# ======================================================
# ANIMAL ICONS PUZZLE SOLVER
# "Select ALL animal icons according to the counts shown"
# ======================================================

def _load_animals_model():
    """
    Lazily load and cache the animal-icon ONNX detector.

    Looks for a .onnx file (via _find_onnx_in_folder) in HOT_DIR/animals_puzzle,
    then in COLD_DIR/animals_puzzle. A CPU InferenceSession is created for the
    first file that loads successfully.

    The square input size comes from dim 2 of the model's first input when that
    is a positive static int; otherwise it is 256. The session and size are
    cached in the module globals _animals_session / _animals_img_size, so later
    calls return them immediately.

    This function does not set _animals_class_names. Detections downstream are
    labelled by class index.

    Returns:
        (session, img_size) on success, or (None, 256) if no model could be loaded.
    """
    global _animals_session, _animals_img_size

    if _animals_session is not None:
        return _animals_session, _animals_img_size

    for root in (HOT_DIR, COLD_DIR):
        onnx_path = _find_onnx_in_folder(os.path.join(root, "animals_puzzle"))
        if not onnx_path:
            continue
        try:
            session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
            dims = session.get_inputs()[0].shape  # e.g. [1, 3, 256, 256]
            size = 256
            if len(dims) >= 4 and isinstance(dims[2], int) and dims[2] > 0:
                size = int(dims[2])
            _animals_session = session
            _animals_img_size = size
            print(f"[animals] Loaded ONNX model from {onnx_path}, input size={size}")
            return session, size
        except Exception as e:
            print(f"[animals] Failed to load {onnx_path}: {e}")

    print("[animals] .onnx not found in model_hot/animals_puzzle or model_cold/animals_puzzle")
    return None, 256


def _nms_detections(dets: List[Dict[str, Any]], iou_threshold: float = 0.45) -> List[Dict[str, Any]]:
    """
    Per-class Non-Maximum Suppression to remove duplicate overlapping boxes.
    dets must have keys: cls, conf, cx, cy, w, h (all in original pixel space).
    Returns the deduplicated list sorted by confidence descending.
    """
    if not dets:
        return []

    # Group by class, apply NMS within each class.
    from collections import defaultdict
    by_cls: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for d in dets:
        by_cls[d["cls"]].append(d)

    kept = []
    for cls_dets in by_cls.values():
        cls_dets.sort(key=lambda d: d["conf"], reverse=True)
        while cls_dets:
            best = cls_dets.pop(0)
            kept.append(best)
            bx1 = best["cx"] - best["w"] / 2
            by1 = best["cy"] - best["h"] / 2
            bx2 = best["cx"] + best["w"] / 2
            by2 = best["cy"] + best["h"] / 2
            b_area = best["w"] * best["h"]

            survivors = []
            for d in cls_dets:
                dx1 = d["cx"] - d["w"] / 2
                dy1 = d["cy"] - d["h"] / 2
                dx2 = d["cx"] + d["w"] / 2
                dy2 = d["cy"] + d["h"] / 2

                ix1 = max(bx1, dx1)
                iy1 = max(by1, dy1)
                ix2 = min(bx2, dx2)
                iy2 = min(by2, dy2)
                inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
                union = b_area + d["w"] * d["h"] - inter
                iou = inter / union if union > 0 else 0.0
                if iou < iou_threshold:
                    survivors.append(d)
            cls_dets = survivors

    return kept


def _animals_run_onnx(
    sess: ort.InferenceSession,
    img_bgr: np.ndarray,
    img_size: int,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
) -> List[Dict[str, Any]]:
    """
    Run the animals detector on a BGR image and return NMS-filtered detections.

    Assumes a YOLOv8-style first output shaped [1, 4 + num_classes, num_anchors].
    Each anchor column is read as (cx, cy, w, h, class scores...) in the
    img_size x img_size input space (TODO confirm against the exported model).

    Args:
        sess: ONNX Runtime session; its first input receives the image tensor.
        img_bgr: H x W x 3 BGR image.
        img_size: square side the image is resized to before inference.
        conf_threshold: minimum best-class score for an anchor to be kept.
        iou_threshold: IoU threshold passed to _nms_detections.

    Returns:
        {"cls": str(class_index), "conf", "cx", "cy", "w", "h"} dicts, with boxes
        in original-image pixels, after per-class NMS. Returns [] if the output
        has no class columns.
    """
    ih, iw = img_bgr.shape[:2]

    # Preprocess: square resize -> RGB -> float32 in [0, 1] -> contiguous NCHW batch of 1.
    resized = cv2.resize(img_bgr, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    tensor = np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis, ...])  # (1, 3, H, W)

    input_name = sess.get_inputs()[0].name
    raw = sess.run(None, {input_name: tensor})[0]  # expected (1, 4 + num_cls, anchors)

    preds = raw[0].T  # (anchors, 4 + num_cls)
    if preds.ndim != 2 or preds.shape[1] <= 4:
        return []

    # Vectorised best-class selection instead of a Python loop over every anchor.
    scores = preds[:, 4:]
    best_cls = scores.argmax(axis=1)
    best_conf = scores[np.arange(scores.shape[0]), best_cls]
    keep = np.flatnonzero(best_conf >= conf_threshold)

    # cx, cy, w, h are in img_size space; scale them to original pixels.
    dets = [
        {
            "cls": str(int(best_cls[i])),
            "conf": float(best_conf[i]),
            "cx": float(preds[i, 0]) / img_size * iw,
            "cy": float(preds[i, 1]) / img_size * ih,
            "w": float(preds[i, 2]) / img_size * iw,
            "h": float(preds[i, 3]) / img_size * ih,
        }
        for i in keep
    ]

    return _nms_detections(dets, iou_threshold=iou_threshold)


def solve_animal_icons(image_b64: str) -> List[Dict[str, Any]]:
    """
    Solver for the animal-icon grid puzzle (ONNX version).

    1. Run the ONNX model on the full board image.
    2. Split detections by centre x into a reference panel (left 33% of the
       non-black content width) and the grid (the rest).
    3. Collect the classes detected in the reference panel.
    4. Click every grid detection whose class is one of those reference classes.

    Returns:
        A list of {"x": int, "y": int} dicts. Returns [] if the model is
        unavailable, nothing is detected, or the panel has no detections.
    """
    sess, img_size = _load_animals_model()
    if sess is None:
        return []

    pil_img = decode_b64_to_pil(image_b64)
    # Force 3-channel RGB (as the letter solvers do) so RGBA/L/P inputs do not
    # break the RGB->BGR conversion.
    img_bgr = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
    iw = img_bgr.shape[1]

    dets = _animals_run_onnx(sess, img_bgr, img_size, conf_threshold=0.25)

    if not dets:
        print("[animals] no detections")
        return []

    # Find content width (ignore black padding on the right).
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    bright_cols = np.where(gray.max(axis=0) > 15)[0]
    content_right = int(bright_cols[-1]) if len(bright_cols) else iw
    panel_x_max = content_right * 0.33  # reference panel ≈ left 33% of content

    refs = [d for d in dets if d["cx"] < panel_x_max]
    grid = [d for d in dets if d["cx"] >= panel_x_max]

    ref_classes = {d["cls"] for d in refs}
    if not ref_classes:
        print("[animals] no reference icons detected in left panel")
        return []

    clicks = [
        {"x": int(round(d["cx"])), "y": int(round(d["cy"]))}
        for d in grid
        if d["cls"] in ref_classes
    ]

    print(f"[animals] img_size={img_size} ref_classes={ref_classes} grid_dets={len(grid)} clicks={len(clicks)}")
    for c in clicks:
        print(f"  -> click ({c['x']}, {c['y']})")

    return clicks


# ======================================================
# OPPOSITE ROTATION SOLVER  (barber-pole / canvasVideo)
# ======================================================

def _decode_webm_frames(
    video_b64_list: List[str],
) -> Tuple[List[np.ndarray], Tuple[int, int, int, int]]:
    """
    Decode base64 WebM chunks into BGR frames cropped to the non-black content area.

    Args:
        video_b64_list: base64 strings (payload["video"]). Each chunk is decoded
            separately and the bytes are concatenated in order, written to a
            temp file, and read with OpenCV.

    Returns:
        (frames, content_box), where:
          frames      – list of BGR numpy arrays cropped to the content area
          content_box – (x, y, w, h) of the content inside the full frame,
                        found from the largest bright (> 15) contour of the
                        middle frame
        Returns ([], (0, 0, 0, 0)) if OpenCV cannot open the video, and
        ([], (0, 0, fw, fh)) if it opens but yields no frames.
    """
    raw = b"".join(base64.b64decode(chunk) for chunk in video_b64_list)

    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
        f.write(raw)
        tmp_path = f.name

    frames_raw: List[np.ndarray] = []
    cap = None
    try:
        cap = cv2.VideoCapture(tmp_path)
        if not cap.isOpened():
            return [], (0, 0, 0, 0)

        fw = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fh = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Read until the decoder reports end of stream. CAP_PROP_FRAME_COUNT is
        # unreliable for WebM (browser MediaRecorder output often lacks duration
        # metadata and reports 0 or a negative count), so it must not bound the loop.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames_raw.append(frame)
    finally:
        # Release before unlinking so the file is not held open (matters on Windows),
        # and release on every path, including early return and exceptions.
        if cap is not None:
            cap.release()
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    if not frames_raw:
        return [], (0, 0, fw, fh)

    # Auto-detect the content bounding box to remove black letterbox bars.
    sample = frames_raw[len(frames_raw) // 2]
    gray = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 15, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if contours:
        bx, by, bw, bh = cv2.boundingRect(max(contours, key=cv2.contourArea))
    else:
        bx, by, bw, bh = 0, 0, fw, fh

    frames = [fr[by:by + bh, bx:bx + bw] for fr in frames_raw]
    return frames, (bx, by, bw, bh)


# Flow is computed at this scale factor for speed (upscaled back afterwards)
_FLOW_SCALE = 0.5

def _build_full_flow(frames: List[np.ndarray]) -> tuple:
    """
    Accumulate optical flow over sampled frame pairs for speed.
    - Downscales frames to 50% before computing flow (4x faster).
    - Samples every Nth pair so total pairs ≤ 60 regardless of frame count.
    - Returns (vy_mean, mag_mean, sign_consistency) at ORIGINAL resolution.
    """
    if len(frames) < 2:
        h, w = frames[0].shape[:2]
        z = np.zeros((h, w), np.float32)
        return z, z, np.ones((h, w), np.float32)

    h, w   = frames[0].shape[:2]
    sh, sw = int(h * _FLOW_SCALE), int(w * _FLOW_SCALE)

    vy_sum  = np.zeros((sh, sw), np.float32)
    mag_sum = np.zeros((sh, sw), np.float32)
    pos_cnt = np.zeros((sh, sw), np.float32)
    neg_cnt = np.zeros((sh, sw), np.float32)
    n = 0

    skip  = min(5, len(frames) // 8)
    total = len(frames) - 1 - skip
    # Sample at most 60 pairs evenly spaced — enough for stable averages
    step  = max(1, total // 60)

    for i in range(skip, len(frames) - 1, step):
        f1s = cv2.resize(frames[i],     (sw, sh), interpolation=cv2.INTER_LINEAR)
        f2s = cv2.resize(frames[i + 1], (sw, sh), interpolation=cv2.INTER_LINEAR)
        g1  = cv2.cvtColor(f1s, cv2.COLOR_BGR2GRAY)
        g2  = cv2.cvtColor(f2s, cv2.COLOR_BGR2GRAY)
        # Faster Farneback params: fewer pyramid levels, smaller window
        fl  = cv2.calcOpticalFlowFarneback(g1, g2, None, 0.5, 2, 12, 2, 5, 1.1, 0)
        vx_f = fl[..., 0]
        vy_f = fl[..., 1]
        m_f  = np.sqrt(vx_f**2 + vy_f**2)
        vy_sum  += vy_f
        mag_sum += m_f
        pos_cnt += (vy_f >  0.1).astype(np.float32)
        neg_cnt += (vy_f < -0.1).astype(np.float32)
        n += 1

    if n == 0:
        z = np.zeros((h, w), np.float32)
        return z, z, np.ones((h, w), np.float32)

    vy_mean  = vy_sum  / n
    mag_mean = mag_sum / n
    sign_con = np.abs(pos_cnt - neg_cnt) / n

    # Upscale back to original resolution
    vy_mean  = cv2.resize(vy_mean,  (w, h), interpolation=cv2.INTER_LINEAR)
    mag_mean = cv2.resize(mag_mean, (w, h), interpolation=cv2.INTER_LINEAR)
    sign_con = cv2.resize(sign_con, (w, h), interpolation=cv2.INTER_LINEAR)

    return vy_mean, mag_mean, sign_con


def _score_roi(vy: np.ndarray, mag: np.ndarray,
               sign_con: np.ndarray,
               x1: int, y1: int, x2: int, y2: int) -> float:
    """
    Compute a robust signed flow score for a bounding box.
    Uses magnitude-weighted mean vy, restricted to pixels with:
      - mag > 20th percentile of the ROI  (foreground only)
      - sign_consistency > 0.5            (consistent direction)
    """
    roi_v = vy[y1:y2, x1:x2]
    roi_m = mag[y1:y2, x1:x2]
    roi_s = sign_con[y1:y2, x1:x2]

    if roi_v.size == 0:
        return 0.0

    # Foreground mask: strong motion AND consistent sign
    fg_thresh = float(np.percentile(roi_m, 20)) if roi_m.size > 0 else 0.0
    fg = (roi_m >= fg_thresh) & (roi_s >= 0.5)

    if fg.sum() < 10:
        # fallback: just use all pixels weighted by magnitude
        w_sum = float(roi_m.sum())
        return float((roi_v * roi_m).sum() / w_sum) if w_sum > 1e-6 else 0.0

    w_sum = float(roi_m[fg].sum())
    return float((roi_v[fg] * roi_m[fg]).sum() / w_sum) if w_sum > 1e-6 else 0.0


def _find_objects_by_contour(vy: np.ndarray, mag: np.ndarray,
                              fh: int, fw: int,
                              sign_con: Optional[np.ndarray] = None) -> List[Dict[str, Any]]:
    """
    Segment the flow-magnitude map into separate moving objects and score each one.

    Pipeline (numbered to match the section comments below):
      1. Motion mask: mag >= 30th percentile of the magnitudes that exceed 0.02.
      2. Distance transform of that mask, min-max normalised to [0, 1].
      3. Seeds: pixels equal (within 1e-6) to the max of their sep x sep
         neighbourhood, where sep = max(30, 8% of fw), and whose value is > 0.35.
      4. Watershed on the magnitude image, seeded with those peaks. Each region
         is intersected with the motion mask and becomes a box (padded by 8 px,
         clamped to the frame). A box is kept if the region has >= 50 pixels and
         the box area lies strictly between 0.4% and 25% of the frame area.
      5. Fallback, used only when step 4 yields nothing: external contours of
         the mask after a 7x7 elliptical opening, with the same bounds applied
         to contour area. The 12 largest are kept.
      6. Boxes are ordered by |flow| descending. A box with IoU > 0.35 against a
         stronger kept box is dropped, and ids are renumbered 1..N.

    Args:
        vy: per-pixel vertical flow map indexed [y, x] (presumably the mean from
            _build_full_flow; confirm with caller).
        mag: per-pixel flow-magnitude map, same shape as vy.
        fh, fw: frame height and width in pixels.
        sign_con: per-pixel sign-consistency map; treated as all ones when None.

    Returns:
        List of dicts with keys id, x1, y1, x2, y2, cx, cy, flow (from _score_roi)
        and area. Area is the box area on the watershed path and the contour
        area on the fallback path. Returns [] when fewer than 100 pixels have
        mag > 0.02.
    """
    if sign_con is None:
        sign_con = np.ones_like(mag)

    nonzero = mag[mag > 0.02]
    if len(nonzero) < 100:
        return []

    # ── 1. Motion mask ────────────────────────────────────────────────────────
    thresh_val = float(np.percentile(nonzero, 30))
    binary = (mag >= thresh_val).astype(np.uint8)

    # ── 2. Distance transform ─────────────────────────────────────────────────
    dist = cv2.distanceTransform(binary, cv2.DIST_L2, 5)
    cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)

    # ── 3. Local maxima — adaptive separation based on estimated object size ──
    # Heuristic assumption: objects are roughly 8-12% of frame width apart
    sep = max(30, int(fw * 0.08))
    kernel_lm = np.ones((sep, sep), np.uint8)
    dilated   = cv2.dilate(dist, kernel_lm)
    # Stricter height threshold to avoid noise peaks
    local_max = ((dist >= dilated - 1e-6) & (dist > 0.35)).astype(np.uint8)

    # Label 0 is background, so the real peak count is num_peaks - 1.
    num_peaks, peak_labels = cv2.connectedComponents(local_max)
    print(f"[rotation] peaks={num_peaks - 1}  sep={sep}px")

    objects = []

    if num_peaks > 1:
        # ── 4. Watershed ──────────────────────────────────────────────────────
        markers = peak_labels.astype(np.int32)
        mag_u8  = (mag * 255 / (mag.max() + 1e-6)).astype(np.uint8)
        mag_bgr = cv2.cvtColor(mag_u8, cv2.COLOR_GRAY2BGR)
        cv2.watershed(mag_bgr, markers)

        # Accepted box-area range: 0.4% .. 25% of the frame (exclusive).
        min_a = fh * fw * 0.004
        max_a = fh * fw * 0.25

        for label in range(1, num_peaks):
            # Region pixels that are also in the motion mask.
            mask = ((markers == label) & (binary == 1)).astype(np.uint8)
            px   = np.where(mask)
            if len(px[0]) < 50:
                continue

            y1_, y2_ = int(px[0].min()), int(px[0].max())
            x1_, x2_ = int(px[1].min()), int(px[1].max())

            pad = 8
            x1_ = max(0, x1_-pad); y1_ = max(0, y1_-pad)
            x2_ = min(fw, x2_+pad); y2_ = min(fh, y2_+pad)

            area = (y2_-y1_) * (x2_-x1_)
            if not (min_a < area < max_a):
                continue

            flow = _score_roi(vy, mag, sign_con, x1_, y1_, x2_, y2_)
            objects.append({
                "id": len(objects)+1,
                "x1": x1_, "y1": y1_, "x2": x2_, "y2": y2_,
                "cx": (x1_+x2_)//2, "cy": (y1_+y2_)//2,
                "flow": flow, "area": area,
            })

    # ── 5. Fallback ───────────────────────────────────────────────────────────
    # Only when no peak produced an acceptable watershed region.
    if not objects:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        b2 = cv2.morphologyEx(binary, cv2.MORPH_OPEN, k)
        contours, _ = cv2.findContours(b2, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        min_a = fh * fw * 0.004
        max_a = fh * fw * 0.25
        # Keep at most the 12 largest in-range contours.
        contours = sorted([c for c in contours if min_a < cv2.contourArea(c) < max_a],
                          key=cv2.contourArea, reverse=True)[:12]
        for i, c in enumerate(contours):
            bx, by, bw, bh = cv2.boundingRect(c)
            flow = _score_roi(vy, mag, sign_con, bx, by, bx+bw, by+bh)
            objects.append({
                "id": i+1,
                "x1": bx, "y1": by, "x2": bx+bw, "y2": by+bh,
                "cx": bx+bw//2, "cy": by+bh//2,
                "flow": flow, "area": int(cv2.contourArea(c)),
            })

    # ── 6. IoU deduplication ──────────────────────────────────────────────────
    def iou(a, b):
        # Axis-aligned box IoU; 1e-6 avoids division by zero.
        ix1 = max(a["x1"],b["x1"]); iy1 = max(a["y1"],b["y1"])
        ix2 = min(a["x2"],b["x2"]); iy2 = min(a["y2"],b["y2"])
        inter = max(0, ix2-ix1) * max(0, iy2-iy1)
        ua = (a["x2"]-a["x1"]) * (a["y2"]-a["y1"])
        ub = (b["x2"]-b["x1"]) * (b["y2"]-b["y1"])
        return inter / (ua + ub - inter + 1e-6)

    # Strongest |flow| first, so overlapping duplicates lose to the stronger box.
    objects = sorted(objects, key=lambda o: abs(o["flow"]), reverse=True)
    keep = []; used = [False]*len(objects)
    for i, o in enumerate(objects):
        if used[i]: continue
        keep.append(o)
        for j in range(i+1, len(objects)):
            if not used[j] and iou(o, objects[j]) > 0.35:
                used[j] = True

    # Renumber ids so they are contiguous after deduplication.
    for i, o in enumerate(keep):
        o["id"] = i+1

    return keep


def _pick_odd_object(objects: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """
    Pick the one object rotating opposite to the majority.

    Uses a tiered approach:
    1. If there is a clear sign-opposite minority (≤ 1 object opposite),
       return the one with the strongest opposite flow.
    2. If mixed (could be 2 vs 6 etc.), pick the minority by count.
    3. If all same direction, use z-score outlier detection.
    """
    if not objects:
        return None

    flows    = np.array([o["flow"] for o in objects], dtype=np.float32)
    max_abs  = float(np.max(np.abs(flows)))
    if max_abs < 0.05:
        return objects[0]  # nothing meaningful

    # Noise floor: 8% of strongest
    noise   = max(0.05, max_abs * 0.08)
    real    = [o for o in objects if abs(o["flow"]) >= noise]
    if not real:
        real = objects

    real_flows = np.array([o["flow"] for o in real])
    n_cw   = int(np.sum(real_flows > 0))
    n_ccw  = int(np.sum(real_flows < 0))
    total  = len(real)

    print(f"[rotation] pick: n_cw={n_cw} n_ccw={n_ccw} noise={noise:.3f}")

    if n_cw > 0 and n_ccw > 0:
        # Both directions present — strict minority rule
        if n_ccw < n_cw:
            minority = [o for o in real if o["flow"] < 0]
            # Among minority, pick strongest (most negative)
            return min(minority, key=lambda o: o["flow"])
        elif n_cw < n_ccw:
            minority = [o for o in real if o["flow"] > 0]
            return max(minority, key=lambda o: o["flow"])
        else:
            # Equal — pick extremes
            best_cw  = max(real, key=lambda o: o["flow"])
            best_ccw = min(real, key=lambda o: o["flow"])
            # Pick the one furthest from 0
            if abs(best_ccw["flow"]) >= abs(best_cw["flow"]):
                return best_ccw
            return best_cw

    # All same direction — z-score outlier
    mean_f = float(np.mean(real_flows))
    std_f  = float(np.std(real_flows))
    if std_f < 1e-6:
        return real[0]

    z_scores = np.abs((real_flows - mean_f) / std_f)
    best_idx = int(np.argmax(z_scores))
    print(f"[rotation] all-same-dir outlier z={z_scores[best_idx]:.2f} id={real[best_idx]['id']}")
    return real[best_idx]

# kept for backwards compat
def _detect_barber_pole_objects(frames: List[np.ndarray]) -> List[Dict[str, Any]]:
    """
    Legacy wrapper: build averaged flow maps for ``frames`` and detect moving objects.

    Returns the objects from _find_objects_by_contour, or [] for an empty frame list.
    """
    if not frames:
        return []
    # _build_full_flow returns three maps. The old two-name unpacking raised
    # ValueError on every call.
    vy, mag, sign_con = _build_full_flow(frames)
    fh, fw = frames[0].shape[:2]
    return _find_objects_by_contour(vy, mag, fh, fw, sign_con)


def _measure_stripe_flow(frames: List[np.ndarray], obj: Dict[str, Any], start: int = 3) -> float:
    """
    Score vertical flow inside ``obj``'s box using only ``frames[start:]``.

    Args:
        frames: BGR frames.
        obj: dict with integer box keys x1, y1, x2, y2.
        start: number of leading frames to ignore.

    Returns:
        The _score_roi value for the box. Returns 0.0 when no frames remain
        after ``start``, which matches the zero flow produced when a single
        frame remains.
    """
    tail = frames[start:]
    if not tail:
        # _build_full_flow cannot size its output from an empty list.
        return 0.0
    vy, mag, sign_con = _build_full_flow(tail)
    return _score_roi(vy, mag, sign_con, obj["x1"], obj["y1"], obj["x2"], obj["y2"])


def solve_opposite_rotation(video_b64_list: List[str], screenshot_b64: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Solver for 'Find the object rotating in the opposite direction'.

    Steps:
      1. Decode the video.
      2. Build the vertical-flow field.
      3. Segment the moving objects.
      4. Tag each object with a direction from the sign of its flow.
      5. Let _pick_odd_object choose the one that disagrees with the rest.

    Args:
        video_b64_list: base64 video chunks sent by the client.
        screenshot_b64: optional base64 screenshot. When present, the click
            is rescaled to screenshot pixels. A scaling failure is logged and
            the unscaled click is kept.

    Returns:
        A dict with "x", "y", "direction", "flow", "object_id" and
        "all_objects", or None when no frames, no objects or no odd object
        were found.
    """
    frames, content_box = _decode_webm_frames(video_b64_list)

    if not frames:
        print("[rotation] ERROR: no frames decoded")
        return None

    # content_box[0:2] is added back to map cropped-frame coordinates to the
    # original (uncropped) frame.
    off_x, off_y = content_box[0], content_box[1]
    fh, fw = frames[0].shape[:2]
    print(f"[rotation] frames={len(frames)} content_box={content_box} frame_size={fw}x{fh}")

    # _build_full_flow yields (vy, mag) or (vy, mag, sign_con). When the
    # sign-consistency map is absent, default it to all-ones.
    flow = _build_full_flow(frames)
    if len(flow) != 3:
        vy, mag = flow
        sign_con = np.ones_like(mag)
    else:
        vy, mag, sign_con = flow

    objects = _find_objects_by_contour(vy, mag, fh, fw, sign_con)
    print(f"[rotation] objects detected: {len(objects)}")
    if not objects:
        print("[rotation] ERROR: no objects found")
        return None

    # Convention used throughout this solver: positive flow means clockwise.
    for obj in objects:
        obj["direction"] = "clockwise" if obj["flow"] > 0 else "counter-clockwise"
        print(f"[rotation] obj={obj['id']:>2}  flow={obj['flow']:+.4f}  "
              f"bbox=({obj['x1']},{obj['y1']})-({obj['x2']},{obj['y2']})  "
              f"-> {obj['direction']}")

    odd = _pick_odd_object(objects)
    if not odd:
        print("[rotation] ERROR: could not pick odd object")
        return None

    click_x = odd["cx"] + off_x
    click_y = odd["cy"] + off_y
    print(f"[rotation] ODD: id={odd['id']} dir={odd['direction']} "
          f"flow={odd['flow']:+.4f}  "
          f"cropped_center=({odd['cx']},{odd['cy']})  "
          f"original_click=({click_x},{click_y})")

    # Rescale the click from video space to the screenshot's pixel grid.
    if screenshot_b64:
        try:
            ss_w, ss_h = decode_b64_to_pil(screenshot_b64).size
            full_w = content_box[2] if content_box[2] > 0 else fw
            full_h = off_y + fh
            scale_x = ss_w / full_w
            scale_y = ss_h / full_h
            click_x = int(round(click_x * scale_x))
            click_y = int(round(click_y * scale_y))
            print(f"[rotation] screenshot={ss_w}x{ss_h} full={full_w}x{full_h} "
                  f"scale=({scale_x:.3f},{scale_y:.3f}) scaled_click=({click_x},{click_y})")
        except Exception as exc:
            print(f"[rotation] coord scaling failed: {exc}")

    summary = [
        {"id": o["id"], "x": o["cx"] + off_x, "y": o["cy"] + off_y,
         "flow": round(o["flow"], 4), "direction": o["direction"]}
        for o in objects
    ]
    return {
        "x": click_x,
        "y": click_y,
        "direction": odd["direction"],
        "flow": round(odd["flow"], 4),
        "object_id": odd["id"],
        "all_objects": summary,
    }


def is_opposite_rotation_question(question: str) -> bool:
    """
    Return True when *question* asks for the object rotating in the opposite direction.

    Matching is substring-based on the normalised text. The question matches
    if it contains every word of at least one keyword group below.
    Homoglyphs are folded with _HOMOGLYPH_MAP first, as
    is_different_animal_question does, so obfuscated prompts still match.
    A None question is treated as empty text.
    """
    q = normalize_label((question or "").translate(_HOMOGLYPH_MAP))
    keyword_groups = (
        ("rotating", "opposite"),
        ("rotation", "opposite"),
        ("opposite", "direction"),
        ("rotate", "opposite"),
        ("spins", "opposite"),
        ("spinning", "opposite"),
    )
    return any(all(word in q for word in group) for group in keyword_groups)


# ======================================================
# DIFFERENT-ANIMAL CARD SOLVER  (canvasVideo)
# "Select the card with a different animal"
# ======================================================

def is_different_animal_question(question: str) -> bool:
    """
    True for 'Select the card with a different animal' style prompts.

    The cards start face-down and reveal their animal over the video, so
    this puzzle uses the canvasVideo flow (like rotation).

    Homoglyphs are folded with _HOMOGLYPH_MAP before normalisation, and a
    None question is treated as empty text instead of raising
    AttributeError. Matching is substring-based: the question matches if it
    contains every word of at least one keyword group.
    """
    q = normalize_label((question or "").translate(_HOMOGLYPH_MAP))
    keyword_groups = (
        # Word order does not matter for an all()-substring test. This group
        # also covers "card ... different ... animal".
        ("different", "animal"),
        ("odd", "animal"),
        ("different", "creature"),
    )
    return any(all(word in q for word in group) for group in keyword_groups)


# Known card-tile colours in OpenCV 8-bit HSV (H 0-179, S/V 0-255). Each
# entry is an inclusive (lower, upper) bound pair, handed to cv2.inRange by
# _tile_mask. A given puzzle uses ONE of these. _detect_tile_color tries
# each and keeps whichever yields the most tile cells, so the other
# puzzle's background (e.g. green aurora vs blue ice) is not mistaken for
# tiles.
#   teal: H≈80,  S≈129, V≈176  (green/teal boards)
#   blue: H≈106, S≈165, V≈255  (bright-blue boards) — needs high S & V so the
#         icy bluish background (low S/V) is excluded.
_TILE_COLORS: Dict[str, Tuple[Tuple[int, int, int], Tuple[int, int, int]]] = {
    "teal": ((65, 60, 90),   (98, 255, 255)),
    "blue": ((96, 120, 185), (122, 255, 255)),
}

# Default range used by helpers when none is supplied (teal, for back-compat).
# NOTE: the teal and blue hue bands overlap at H 96-98; only their S/V floors
# separate them there.
_DEFAULT_TILE_RANGE = _TILE_COLORS["teal"]


def _tile_mask(bgr: np.ndarray, rng=None) -> np.ndarray:
    """
    Return the cv2.inRange mask of *bgr* for the tile-colour HSV range.

    *bgr* is converted with COLOR_BGR2HSV. *rng* is a (lower, upper) pair
    such as an entry of _TILE_COLORS; when omitted, _DEFAULT_TILE_RANGE is
    used. Pixels inside the range are 255 and all others are 0.
    """
    hsv_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    bounds = _DEFAULT_TILE_RANGE if rng is None else rng
    lower, upper = bounds[0], bounds[1]
    return cv2.inRange(hsv_image, lower, upper)


def _detect_card_cells(frames: List[np.ndarray], rng=None, max_probe: int = 12) -> List[Tuple[int, int, int, int]]:
    """
    Locate the fixed card tiles (the coloured rounded squares).

    The tiles never move. Tile-coloured, solidly filled, roughly square
    candidates are therefore collected from a subset of frames and clustered
    by centre. Each well-supported cluster contributes its median box.

    Args:
        frames: video frames. They are treated as BGR, because _tile_mask
            converts them with COLOR_BGR2HSV.
        rng: (lower, upper) HSV tile range. None means _tile_mask's default
            (teal).
        max_probe: target number of frames to sample. NOTE: the stride is
            ``len(frames) // max_probe``, so up to about 2*max_probe - 1
            frames may actually be probed (e.g. 23 frames -> stride 1 -> 23).

    Returns:
        (x, y, w, h) boxes ordered row-by-row (top->bottom), then
        left->right. Returns [] when nothing qualifies.
    """
    if not frames:
        return []

    fh, fw = frames[0].shape[:2]
    # A candidate box must cover at least 2% of the frame.
    min_area = 0.02 * fw * fh
    probe = frames[:: max(1, len(frames) // max_probe)] if len(frames) > max_probe else frames

    raw: List[Tuple[int, int, int, int]] = []
    for f in probe:
        mask = _tile_mask(f, rng)
        # Close small gaps (e.g. the animal drawn on the tile) before finding contours.
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((7, 7), np.uint8))
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for c in cnts:
            x, y, w, h = cv2.boundingRect(c)
            area = w * h
            if area < min_area:
                continue
            # Tiles are SOLID filled squares; wispy background (e.g. green
            # aurora that shares the teal hue) has a low fill ratio → reject.
            extent = cv2.contourArea(c) / float(area) if area else 0
            if extent < 0.80:
                continue
            ar = w / float(h) if h else 0
            if 0.6 < ar < 1.6:                      # single square tile
                raw.append((x, y, w, h))
            elif ar >= 1.6:                          # several tiles merged in a row -> split evenly
                k = int(round(ar))
                sw = w // k
                if sw * h >= min_area:
                    for i in range(k):
                        raw.append((x + i * sw, y, sw, h))
            elif ar <= 0.625:                        # several tiles merged in a column -> split evenly
                # Only reached for ar <= 0.6: the first branch already
                # takes 0.6 < ar < 1.6.
                k = int(round(1.0 / ar))
                sh = h // k
                if w * sh >= min_area:
                    for i in range(k):
                        raw.append((x, y + i * sh, w, sh))

    if not raw:
        return []

    # Cluster candidate boxes by centre proximity (tiles are static).
    # Greedy, order-dependent: each box joins the FIRST cluster whose running
    # centroid lies within `thresh` on both axes. `thresh` is 12% of the frame
    # WIDTH, and it is used for the vertical axis too.
    thresh = 0.12 * fw
    clusters: List[Dict[str, Any]] = []
    for (x, y, w, h) in raw:
        cx, cy = x + w / 2, y + h / 2
        placed = False
        for cl in clusters:
            if abs(cx - cl["cx"]) <= thresh and abs(cy - cl["cy"]) <= thresh:
                cl["boxes"].append((x, y, w, h))
                cl["cx"] = np.mean([b[0] + b[2] / 2 for b in cl["boxes"]])
                cl["cy"] = np.mean([b[1] + b[3] / 2 for b in cl["boxes"]])
                placed = True
                break
        if not placed:
            clusters.append({"cx": cx, "cy": cy, "boxes": [(x, y, w, h)]})

    cells: List[Tuple[int, int, int, int]] = []
    for cl in clusters:
        # Keep only clusters supported by at least a third of the probe frames.
        if len(cl["boxes"]) < max(1, len(probe) // 3):   # weak/spurious cluster
            continue
        # Median box per cluster is robust to the odd mis-segmented frame.
        bx = int(np.median([b[0] for b in cl["boxes"]]))
        by = int(np.median([b[1] for b in cl["boxes"]]))
        bw = int(np.median([b[2] for b in cl["boxes"]]))
        bh = int(np.median([b[3] for b in cl["boxes"]]))
        cells.append((bx, by, bw, bh))

    # Order row-by-row (top→bottom), then left→right. Rows are buckets of 20%
    # of the frame height, formed with round(). NOTE(review): two tiles in the
    # same visual row whose centres straddle a bucket boundary (x.5) would
    # sort into different rows.
    row_h = 0.2 * fh
    cells.sort(key=lambda b: (round((b[1] + b[3] / 2) / row_h), b[0]))
    return cells


def _detect_tile_color(frames: List[np.ndarray]):
    """
    Pick the known tile colour that best explains the board.

    Every entry of _TILE_COLORS is probed with _detect_card_cells. The colour
    with strictly the most cells wins; on ties, the earlier colour is kept.

    Returns:
        (color_name, rng, cells). This is (None, None, []) when no colour
        yields any cell.
    """
    winner_name, winner_rng, winner_cells = None, None, []
    for colour, hsv_range in _TILE_COLORS.items():
        found = _detect_card_cells(frames, hsv_range)
        print(f"[diff-animal] tile-colour probe '{colour}': {len(found)} cell(s)")
        if len(found) > len(winner_cells):
            winner_name, winner_rng, winner_cells = colour, hsv_range, found
    return winner_name, winner_rng, winner_cells


def _animal_descriptor(crop: np.ndarray, rng=None):
    """
    Describe the animal drawn on one tile crop, for later comparison.

    Processing:
      - Trim a border of about 5% (at least 2 px).
      - Take the non-tile-colour pixels and apply a 3x3 opening.
      - Keep the LARGEST 8-connected blob as the animal, so stray border or
        compression-noise pixels cannot inflate the crop to the whole tile.

    Returns:
        (gray, hist), or None when the crop is empty or the largest blob has
        fewer than 60 pixels.
          - gray: the blob's bounding box as a 64x64 float32 grayscale image,
            zero-mean and unit-std, used for structural comparison.
          - hist: an L2-normalised 18x16 hue/saturation histogram over the
            blob's own pixels, used for colour comparison.
    """
    if crop is None or crop.size == 0:
        return None

    rows, cols = crop.shape[:2]
    pad = max(2, int(0.05 * min(rows, cols)))
    inner = crop[pad:rows - pad, pad:cols - pad]
    if inner.size == 0:
        return None

    non_tile = (_tile_mask(inner, rng) == 0).astype(np.uint8)
    non_tile = cv2.morphologyEx(non_tile, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))
    count, labels, stats, _ = cv2.connectedComponentsWithStats(non_tile, connectivity=8)
    if count <= 1:  # only the background label
        return None

    label = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    if stats[label, cv2.CC_STAT_AREA] < 60:
        return None

    left = int(stats[label, cv2.CC_STAT_LEFT])
    top = int(stats[label, cv2.CC_STAT_TOP])
    width = int(stats[label, cv2.CC_STAT_WIDTH])
    height = int(stats[label, cv2.CC_STAT_HEIGHT])
    box = (slice(top, top + height), slice(left, left + width))
    patch = inner[box]
    blob_mask = (labels[box] == label).astype(np.uint8) * 255
    if patch.size == 0:
        return None

    gray = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
    gray = cv2.resize(gray, (64, 64), interpolation=cv2.INTER_AREA).astype(np.float32)
    gray = (gray - gray.mean()) / (gray.std() + 1e-6)

    patch_hsv = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)
    hist = cv2.calcHist([patch_hsv], [0, 1], blob_mask, [18, 16], [0, 180, 0, 256])
    cv2.normalize(hist, hist)
    return gray, hist


def _animal_distance(a, b) -> float:
    """
    Distance between two _animal_descriptor results.

    The distance is one minus the equal-weight mean of:
      - the normalised grayscale correlation (structure), and
      - the histogram correlation (colour).
    Both terms lie in [-1, 1], so the result is about 0 for the same animal
    and grows (up to 2) as the animals differ.
    """
    gray_a, hist_a = a
    gray_b, hist_b = b
    structure = float((gray_a * gray_b).mean())
    colour = float(cv2.compareHist(hist_a, hist_b, cv2.HISTCMP_CORREL))
    return 1.0 - (0.5 * structure + 0.5 * colour)


def _tile_fullness(crop: np.ndarray, rng=None):
    """
    Measure how fully face-on a tile is, as (presence, w_frac, h_frac).

    - presence: fraction of the cell whose pixels are NOT the tile colour.
    - w_frac / h_frac: fraction of the cell width / height spanned by the
      tile colour (from the first to the last tile-coloured column / row).

    While a tile flips it is squished horizontally, and background shows at
    the cell's sides. That background is not the tile colour, so it inflates
    `presence`. A tile that is face-on (back or front) spans the whole cell,
    giving w_frac and h_frac close to 1.0. Callers should only trust frames
    where both fractions are high.

    Returns (0.0, 0.0, 0.0) for an empty crop. The two fractions are 0.0
    when fewer than 10 tile-coloured pixels are found.
    """
    if crop.size == 0:
        return 0.0, 0.0, 0.0

    mask = _tile_mask(crop, rng)
    presence = float((mask == 0).mean())

    rows, cols = np.nonzero(mask > 0)
    if cols.size < 10:
        return presence, 0.0, 0.0

    height, width = crop.shape[:2]
    w_frac = (cols.max() - cols.min() + 1) / float(width)
    h_frac = (rows.max() - rows.min() + 1) / float(height)
    return presence, w_frac, h_frac


def _select_tile_representative(crops: List[np.ndarray], rng=None):
    """
    Pick one robust descriptor from several fully revealed crops of the SAME tile.

    Each crop is described with _animal_descriptor; crops whose descriptor is
    None are skipped. The result is the medoid: the descriptor with the
    smallest summed _animal_distance to all the others. This smooths out a
    single noisy frame.

    Returns:
        (descriptor, index), where index is the position of the chosen crop
        in *crops*. Returns (None, -1) when no crop yields a descriptor.
    """
    # Keep each descriptor paired with its crop index. Indexing into the
    # filtered list would point at the wrong crop once any crop is skipped.
    described = []
    for crop_idx, crop in enumerate(crops):
        desc = _animal_descriptor(crop, rng)
        if desc is not None:
            described.append((crop_idx, desc))

    if not described:
        return None, -1
    if len(described) == 1:
        crop_idx, desc = described[0]
        return desc, crop_idx

    n = len(described)
    totals = [0.0] * n
    for i in range(n):
        for j in range(i + 1, n):
            d = _animal_distance(described[i][1], described[j][1])
            totals[i] += d
            totals[j] += d
    best = int(np.argmin(totals))
    crop_idx, desc = described[best]
    return desc, crop_idx


def _pick_odd_tile(reps: List[Dict[str, Any]]):
    """
    Decide which revealed tile holds the DIFFERENT animal.

    Strategy:
      - Link tiles whose animals look alike (distance < SAME_THRESH) and
        form connected groups.
      - Every group of the largest size counts as the common animal. This
        way a tie for largest does not silently pick one of the tied groups.
      - Among the remaining tiles, the odd one is the lone tile, or else the
        tile farthest on average from the common group.
      - When no tile lies outside the largest group(s), fall back to the tile
        most dissimilar from everything else.

    Args:
        reps: one dict per revealed tile. Only the "desc" entry (an
            _animal_descriptor result) is read.

    Returns:
        (odd_index, distance_matrix). odd_index is None when fewer than three
        tiles are given: with two different tiles there is no way to tell
        which one is odd.
    """
    n = len(reps)
    D = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d = _animal_distance(reps[i]["desc"], reps[j]["desc"])
            D[i][j] = D[j][i] = d

    if n < 3:
        # 0-1 tiles: nothing to compare (np.argmax([]) would raise below).
        # 2 tiles: both are 'different' from each other, so it is undecidable.
        return None, D

    SAME_THRESH = 0.40  # same-animal pairs sit ~0.05; different ~0.8

    parent = list(range(n))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    def union(a, b):
        parent[find(a)] = find(b)

    for i in range(n):
        for j in range(i + 1, n):
            if D[i][j] < SAME_THRESH:
                union(i, j)

    groups: Dict[int, List[int]] = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)
    ordered = sorted(groups.values(), key=len, reverse=True)

    # All groups tied for the largest size form the "common animal" set.
    largest = len(ordered[0])
    majority = {i for g in ordered if len(g) == largest for i in g}
    outliers = [i for i in range(n) if i not in majority]
    if outliers:
        if len(outliers) == 1:
            return outliers[0], D
        # Several outliers: the one farthest from the common animal.
        outliers.sort(key=lambda i: np.mean([D[i][m] for m in majority]), reverse=True)
        return outliers[0], D

    # No clean grouping: pick the tile most dissimilar from everything else.
    totals = [sum(row) for row in D]
    return int(np.argmax(totals)), D


def solve_different_animal(
    video_b64_list: List[str],
    screenshot_b64: Optional[str] = None,
    frame_step: int = 3,
) -> Optional[Dict[str, Any]]:
    """
    Solver for 'Select the card with a different animal' — model-free, fast.

    Pipeline:
      1. Decode the video and keep every ``frame_step``-th frame. Then cap
         the scan at 60 frames, spread evenly over what remains.
      2. Auto-detect the tile colour and the fixed card cells. This goes
         through _detect_tile_color, which calls _detect_card_cells with its
         default max_probe of 12 probe frames.
      3. For each cell, keep crops where:
           - the tile is fully face-on (width/height span >= FULL_FRAC), and
           - the tile is partly covered by the animal
             (MIN_PRESENCE <= presence <= MAX_PRESENCE).
         Upscale the TOP_K most covered crops (only these small crops, never
         full frames) and take their medoid descriptor.
      4. Compare the descriptors across cells and click the odd one out.

    Args:
        video_b64_list: base64 video chunks sent by the client.
        screenshot_b64: optional base64 screenshot. When present, the click
            is rescaled to screenshot pixels. A scaling failure is logged and
            the unscaled click is kept.
        frame_step: keep every N-th decoded frame before the 60-frame cap.
            Values below 1 are treated as 1.

    Returns:
        {"x", "y", "cells"}, or None when decoding, cell detection, animal
        capture or odd-one selection fails.
    """
    frames, content_box = _decode_webm_frames(video_b64_list)
    if not frames:
        print("[diff-animal] ERROR: no frames decoded")
        return None

    cx_off, cy_off = content_box[0], content_box[1]
    fh0, fw0 = frames[0].shape[:2]          # ORIGINAL cropped-frame size

    # Honour the caller's frame_step first, then cap the scan at 60 frames
    # spread evenly across the strided sequence.
    step = max(1, int(frame_step or 1))
    sampled = frames[::step]

    MAX_SCAN = 60
    if len(sampled) > MAX_SCAN:
        idx = np.linspace(0, len(sampled) - 1, MAX_SCAN).astype(int)
        scan = [sampled[i] for i in idx]
    else:
        scan = sampled

    print(f"[diff-animal] frames={len(frames)} step={step} scanned={len(scan)} "
          f"content_box={content_box} orig_size={fw0}x{fh0}")

    # Detect tile colour and cell positions on the raw (un-upscaled) frames.
    color_name, rng, cells = _detect_tile_color(scan)
    if len(cells) < 2:
        print(f"[diff-animal] ERROR: detected only {len(cells)} card cell(s)")
        return None
    print(f"[diff-animal] tile colour='{color_name}' card cells detected: {len(cells)}")

    # Upscaling is applied to individual tile crops, not to full frames.
    # Target ~200 px per tile side for a clean descriptor.
    TARGET_CROP = 200
    MIN_PRESENCE = 0.08    # below this, the animal is not shown yet (tile back)
    MAX_PRESENCE = 0.65    # above this, the crop is mostly non-tile (not a clean front)
    FULL_FRAC    = 0.90    # tile colour must span >= 90% of the cell in both axes
    TOP_K        = 3       # 3 peak-presence crops -> fast medoid, still robust

    revealed: List[Dict[str, Any]] = []
    for (x, y, w, h) in cells:
        cxc, cyc = x + w / 2.0, y + h / 2.0
        candidates: List[Tuple[float, np.ndarray]] = []

        for f in scan:
            crop = f[y:y + h, x:x + w]
            pres, wf, hf = _tile_fullness(crop, rng)
            if wf >= FULL_FRAC and hf >= FULL_FRAC and MIN_PRESENCE <= pres <= MAX_PRESENCE:
                candidates.append((pres, crop.copy()))

        if not candidates:
            print(f"  cell center=({cxc:.0f},{cyc:.0f}) full_frames=0 -> skipped")
            continue

        # Take the TOP_K most-revealed crops and upscale them (cheap: tiny crops).
        candidates.sort(key=lambda c: c[0], reverse=True)
        top_crops = []
        cup = max(1.0, TARGET_CROP / max(w, h))   # per-crop upscale factor
        for _, crop in candidates[:TOP_K]:
            if cup > 1.01:
                crop = cv2.resize(crop,
                                  (int(round(crop.shape[1] * cup)),
                                   int(round(crop.shape[0] * cup))),
                                  interpolation=cv2.INTER_CUBIC)
            top_crops.append(crop)

        desc, _ = _select_tile_representative(top_crops, rng)
        if desc is None:
            print(f"  cell center=({cxc:.0f},{cyc:.0f}) full_frames={len(candidates)} "
                  f"-> descriptor failed, skipped")
            continue

        revealed.append({"cx": cxc, "cy": cyc, "frames": len(candidates), "desc": desc})
        print(f"  cell center=({cxc:.0f},{cyc:.0f}) full_frames={len(candidates)} "
              f"-> animal captured (crop_up={cup:.1f}x)")

    if len(revealed) < 2:
        print(f"[diff-animal] ERROR: only {len(revealed)} clean animal(s) — cannot compare")
        return None

    odd_idx, D = _pick_odd_tile(revealed)

    for i in range(len(revealed)):
        row = "  ".join(f"{D[i][j]:.3f}" for j in range(len(revealed)))
        print(f"    dist[{i}] = {row}")

    if odd_idx is None:
        print("[diff-animal] ERROR: ambiguous (only two distinct animals) — cannot pick odd one")
        return None

    odd = revealed[odd_idx]
    avg_dist = sum(D[odd_idx]) / max(1, len(revealed) - 1)
    # NOTE: odd_idx indexes `revealed` (cells that yielded a descriptor),
    # not `cells`.
    print(f"[diff-animal] ODD = cell{odd_idx} center=({odd['cx']:.0f},{odd['cy']:.0f}) "
          f"avg-dist={avg_dist:.3f} frames={odd['frames']}")

    # Cell centres are in raw-frame coordinates (no full-frame upscale);
    # adding the content_box offset maps them back to the uncropped frame.
    orig_x = odd["cx"] + cx_off
    orig_y = odd["cy"] + cy_off

    # Scale to screenshot coordinates.
    if screenshot_b64:
        try:
            ss_img = decode_b64_to_pil(screenshot_b64)
            ss_w, ss_h = ss_img.size
            full_w = content_box[2] if content_box[2] > 0 else fw0
            full_h = cy_off + fh0
            scale_x = ss_w / full_w
            scale_y = ss_h / full_h
            orig_x = orig_x * scale_x
            orig_y = orig_y * scale_y
            print(f"[diff-animal] screenshot={ss_w}x{ss_h} full={full_w}x{full_h} "
                  f"scale=({scale_x:.3f},{scale_y:.3f}) scaled=({orig_x:.0f},{orig_y:.0f})")
        except Exception as _se:
            print(f"[diff-animal] coord scaling failed: {_se}")

    return {
        "x": int(round(orig_x)),
        "y": int(round(orig_y)),
        "cells": [
            {"x": int(round(r["cx"] + cx_off)), "y": int(round(r["cy"] + cy_off)),
             "frames": r["frames"]}
            for r in revealed
        ],
    }


# ======================================================
# LIFESPAN
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook.

    Startup:
      - Index the hot and cold model sets.
      - Start the background _evictor_loop thread (daemon).
      - Log the counts.

    Shutdown: set _evictor_stop so the evictor loop can exit. This happens in
    a ``finally`` block, so it also runs when the app unwinds with an
    exception.
    """
    scan_hot_models()
    scan_cold_models()

    evictor = threading.Thread(target=_evictor_loop, daemon=True)
    evictor.start()

    print(
        f"[startup] hot={len(_hot_model_files)} models, "
        f"cold={len(_cold_index)} indexed questions"
    )

    try:
        yield
    finally:
        _evictor_stop.set()


# ======================================================
# FASTAPI
# ======================================================
# ASGI application. Startup/shutdown work (model indexing, evictor thread)
# is wired in through the `lifespan` context manager.
app = FastAPI(title="YOLO ONNX + White Letter Solver", version="11.0.0", lifespan=lifespan)


# ======================================================
# SCHEMAS
# ======================================================
class PredictRequest(BaseModel):
    """
    Body of POST /predict: run the ONNX model registered for one class on one image.

    ``class_name`` is normalised with normalize_label before the registry
    lookup. ``conf_threshold`` is passed to detect() and ``img_size`` to
    preprocess().

    NOTE(review): ``iou_threshold`` is accepted but /predict does not forward
    it; only conf_threshold reaches detect(). Confirm whether it should.
    NOTE(review): these fields are Optional, so an explicit JSON null would
    reach detect()/preprocess() as None. Confirm that is acceptable.
    """
    class_name: str
    image_base64: str
    conf_threshold: Optional[float] = 0.20
    iou_threshold: Optional[float] = 0.70
    img_size: Optional[int] = 256


class ExtensionCheckRequest(BaseModel):
    """
    Inner task payload handled by /check (built by OuterCheckRequest.to_inner).

    Unknown fields are accepted and kept as extras instead of triggering a
    422. For image puzzles, ``queries`` holds base64 images. The video
    solvers in /check also treat ``queries[0]``, when present, as the
    screenshot used to rescale their click.
    """
    model_config = {"extra": "allow"}   # absorb any unknown fields → no 422

    # Defaults to the rotation prompt when the client omits it.
    question: Optional[str] = "Find the object rotating in the opposite direction"
    questionType: Optional[str] = "objectClassify"
    # queries can be omitted OR sent as JSON null → both normalised to []
    # by the queries_list property.
    queries: Optional[List[str]] = None
    # video-puzzle fields (canvasVideo)
    canvasVideo: Optional[bool] = False
    video: Optional[List[str]] = None   # base64 video chunks for the video solvers
    format: Optional[str] = None
    type: Optional[str] = None
    # NOTE(review): a flag only — /check passes queries[0] as the screenshot image.
    screenshot: Optional[bool] = None
    websiteURL: Optional[str] = None
    websiteKEY: Optional[str] = None
    examples: Optional[List[Any]] = None
    choices: Optional[List[Any]] = None

    @property
    def queries_list(self) -> List[str]:
        """Always returns a list, never None - use this instead of .queries directly."""
        return self.queries if self.queries is not None else []


class OuterCheckRequest(BaseModel):
    """Outer envelope sent by the captcha client: { apiKey, source, task: {...} }

    Unknown top-level keys are kept (extra="allow"). This lets clients that
    post the task fields flat, without a "task" wrapper, still be handled
    by to_inner().
    """
    model_config = {"extra": "allow"}
    apiKey: Optional[str] = None
    source: Optional[str] = None
    version: Optional[str] = None
    appID: Optional[Any] = None
    task: Optional[Dict[str, Any]] = None

    def to_inner(self) -> "ExtensionCheckRequest":
        """Build the inner request from ``task`` when non-empty, otherwise from the flat extra fields."""
        payload = self.task if self.task else dict(self.model_extra or {})
        return ExtensionCheckRequest(**payload)


# ======================================================
# RESPONSE BUILDER
# ======================================================
def build_extension_response(answers, qtype):
    """
    Wrap per-tile *answers* in the extension's classification envelope.

    *qtype* is echoed back as ``questionType``; a falsy value falls back to
    "objectClassify".
    """
    meta = {"pass_report": True, "fail_report": True, "data": ""}
    question_type = qtype if qtype else "objectClassify"
    return {
        "code": 200,
        "msg": "",
        "answers": answers,
        "meta": meta,
        "questionType": question_type,
    }


def build_click_response(result: Optional[Dict[str, Any]]):
    """
    Wrap a single click in the objectClick response envelope.

    *result* needs "x" and "y" entries; both are coerced with int(). A None
    result yields the same envelope with ``answers: None`` (no answer).
    """
    if result is None:
        answers = None
    else:
        answers = [[{"x": int(result["x"]), "y": int(result["y"])}]]

    return {
        "answers": answers,
        "code": 200,
        "meta": {"data": "", "fail_report": True, "pass_report": True},
        "msg": "",
        "questionType": "objectClick",
    }


# ======================================================
# /predict
# ======================================================
@app.post("/predict")
async def predict(req: PredictRequest):
    """
    Run the ONNX detector registered for ``req.class_name`` on one image.

    Lookup order:
      1. The in-memory registry.
      2. Promotion from the cold index.
    If neither yields a model path, the reply is {"found": False}. Otherwise
    the model's last-used time is refreshed under _lock and the detect()
    result is returned as {"found": ...}.
    """
    label = normalize_label(req.class_name)
    model_path = _registry.get(label) or promote_from_cold(label)
    if not model_path:
        return {"found": False}

    session = load_session(model_path)

    with _lock:
        _model_last_used[model_path] = time.time()

    image = decode_b64_to_pil(req.image_base64)
    tensor = preprocess(image, req.img_size)
    return {"found": detect(session, tensor, req.conf_threshold)}


# ======================================================
# /check
# ======================================================
def _check_canvas_video_request() -> Dict[str, Any]:
    """Reply asking the client to record the puzzle canvas as a 6 s / 60 fps (360-frame) video."""
    return {
        "code": 400,
        "questionVariant": "canvasvideo",
        "questionType": "objectClick",
        "canvasParams": {
            "video": True,
            "duration": 6,
            "fps": 60,
            "frames": None,
            "framecount": 360,
        },
    }


def _check_multi_click_response(clicks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """objectClick envelope carrying every click in *clicks* (dicts with "x"/"y") as one answer group."""
    return {
        "answers": [[{"x": c["x"], "y": c["y"]} for c in clicks]],
        "code": 200,
        "meta": {
            "data": "",
            "fail_report": True,
            "pass_report": True,
        },
        "msg": "",
        "questionType": "objectClick",
    }


def _check_run_solver(label: str, solver, *args, **kwargs):
    """
    Call ``solver(*args, **kwargs)`` and return its result.

    If the solver raises, log "[check] <label> solver error: ..." and return
    None. The client then gets a null answer instead of an HTTP 500.
    """
    try:
        return solver(*args, **kwargs)
    except Exception as e:
        print(f"[check] {label} solver error: {e}")
        return None


@app.post("/check")
async def check(outer: OuterCheckRequest):
    """
    Main solving endpoint. Dispatches on the question text, in this order:

      1. Opposite-rotation puzzle (canvasVideo).
      2. Different-animal card puzzle (canvasVideo).
      3. "Click the card that does not match" (objectDrag): a fixed-position
         click.
      4. White-letter unpaired puzzle.
      5. Checklist-letter puzzle.
      6. Animal-icons puzzle.
      7. Default: per-tile ONNX classification through /predict.

    For the video puzzles, a request without video gets a canvasVideo request
    back, with code 400 in the body. For screenshot/video solvers, a solver
    exception is logged and answered with a null click response.

    Raises:
        HTTPException(400): when neither queries nor video is supplied, or
            when the default path has no queries.
    """
    req = outer.to_inner()
    if not req.queries_list and not req.video:
        raise HTTPException(400, "queries or video required")

    norm_q = normalize_label(req.question)

    print(
        f"[check] question={req.question!r} "
        f"norm={norm_q!r} "
        f"qtype={req.questionType!r} "
        f"queries={len(req.queries_list)} "
        f"canvasVideo={req.canvasVideo} "
        f"video_chunks={len(req.video) if req.video else 0}"
    )

    # The video solvers use the first query image (if any) as the screenshot
    # that their click is rescaled to.
    screenshot = req.queries_list[0] if req.queries_list else None

    # --------------------------------------------------
    # "Find the object rotating in the opposite direction"
    # --------------------------------------------------
    if is_opposite_rotation_question(req.question):
        # No video yet → ask the client to send one
        if not req.video:
            print("[check] rotation question but no video → requesting canvasVideo")
            return _check_canvas_video_request()

        print("[check] rotation question with video → solving")
        result = _check_run_solver("rotation", solve_opposite_rotation, req.video, screenshot)
        return build_click_response(result)

    # --------------------------------------------------
    # "Select the card with a different animal"  (canvasVideo)
    # Cards reveal their animal over time → request a video, then sample
    # every 3rd frame, classify each card's animal, click the odd one out.
    # --------------------------------------------------
    if is_different_animal_question(req.question):
        # No video yet → ask the client to record one.
        if not req.video:
            print("[check] different-animal question but no video → requesting canvasVideo")
            return _check_canvas_video_request()

        print("[check] different-animal question with video → solving")
        result = _check_run_solver(
            "different-animal",
            solve_different_animal,
            req.video,
            screenshot,
            frame_step=3,
        )
        return build_click_response(result)

    # --------------------------------------------------
    # Special case: "Click the card that does not match"
    # Answered with a jittered click at a fixed fraction of the image size.
    # NOTE(review): presumably the first card's position in this layout.
    # --------------------------------------------------
    if norm_q == "click the card that does not match" and req.questionType == "objectDrag":
        if not req.queries_list:
            return build_click_response(None)
        try:
            pil = decode_b64_to_pil(req.queries_list[0])
            img_w, img_h = pil.size

            base_x = img_w * 0.3458
            base_y = img_h * 0.7738

            first_card_x = int(base_x + random.randint(-3, 3))
            first_card_y = int(base_y + random.randint(-5, 5))

            print(f"[card-match] img={img_w}x{img_h} click=({first_card_x},{first_card_y})")

        except Exception as e:
            first_card_x = random.randint(179, 185)
            first_card_y = random.randint(210, 220)
            print(f"[card-match] fallback click=({first_card_x},{first_card_y}) err={e}")

        return _check_multi_click_response([{"x": first_card_x, "y": first_card_y}])

    # --------------------------------------------------
    # Special white-letter unpaired puzzle solver
    # --------------------------------------------------
    if is_unpaired_letter_question(req.question):
        if not req.queries_list:
            return build_click_response(None)
        print("[check] using white-letter unpaired solver")
        result = _check_run_solver("unpaired-letter", solve_unpaired_letter, req.queries_list[0])
        return build_click_response(result)

    # --------------------------------------------------
    # Checklist letter solver
    # "Select all letters from the checklist based on the count"
    # --------------------------------------------------
    if is_checklist_letter_question(req.question):
        if not req.queries_list:
            return build_click_response(None)
        print("[check] using checklist letter solver")
        clicks = _check_run_solver("checklist letter", solve_checklist_letters, req.queries_list[0])
        if not clicks:
            return build_click_response(None)
        # Return all click coordinates — one per required letter match.
        return _check_multi_click_response(clicks)

    # --------------------------------------------------
    # Animal icons puzzle solver
    # "Select ALL animal icons according to the counts shown"
    # --------------------------------------------------
    if is_animal_icons_question(req.question):
        if not req.queries_list:
            return build_click_response(None)
        print("[check] using animal icons solver")
        clicks = _check_run_solver("animal icons", solve_animal_icons, req.queries_list[0])
        if not clicks:
            return build_click_response(None)
        return _check_multi_click_response(clicks)

    # --------------------------------------------------
    # Default existing YOLO / ONNX logic
    # --------------------------------------------------
    if not req.queries_list:
        raise HTTPException(400, "queries empty")

    answers: List[bool] = []

    for tile_b64 in req.queries_list:
        try:
            result = await predict(
                PredictRequest(
                    class_name=req.question,
                    image_base64=tile_b64,
                )
            )
            answers.append(bool(result.get("found", False)))

        except Exception as e:
            print(f"[check] prediction error: {e}")
            answers.append(False)

    # NOTE(review): an all-negative result is returned as a bare [] rather
    # than an envelope; presumably the client expects this. Confirm.
    if answers and not any(answers):
        return []

    return build_extension_response(answers, req.questionType)


# ======================================================
# DEBUG ENDPOINTS
# ======================================================
@app.post("/debug/unpaired-letter")
async def debug_unpaired_letter(req: ExtensionCheckRequest):
    """Expose the intermediate stages of the unpaired-letter solver.

    Only the first entry of ``req.queries_list`` is inspected. The image is
    decoded and split into tile candidates, and each tile crop is turned into
    a letter mask. When three or more tiles yield a mask, both the
    shape-based picker and the OCR fallback picker are run on them so their
    answers can be compared side by side. The end-to-end
    ``solve_unpaired_letter`` result for the same input is included as well.

    Raises:
        HTTPException: 400 when ``queries_list`` is empty.
    """
    if not req.queries_list:
        raise HTTPException(400, "queries empty")

    image_b64 = req.queries_list[0]
    candidates = _detect_tile_boxes(decode_b64_to_pil(image_b64))

    # Keep (tile, mask) together for every tile where a letter mask was
    # found, so the two lists derived below stay index-aligned.
    letter_pairs = []
    for candidate in candidates:
        letter_mask = _letter_mask_from_tile_crop(candidate["crop"])
        if letter_mask is not None:
            letter_pairs.append((candidate, letter_mask))

    letter_tiles = [tile for tile, _ in letter_pairs]
    letter_masks = [mask for _, mask in letter_pairs]

    # The pickers are only consulted when at least three letter tiles exist.
    shape_result = ocr_result = None
    if len(letter_tiles) >= 3:
        shape_result = _pick_unpaired_tile_by_shape(letter_tiles, letter_masks)
        ocr_result = _pick_unpaired_by_ocr_fallback(letter_tiles, letter_masks)

    # The end-to-end solver receives the same base64 string as above.
    solver_result = solve_unpaired_letter(image_b64)

    def summarize(tile):
        # Rounded center plus the detector's scoring metrics for one tile.
        return {
            "x": int(round(tile["cx"])),
            "y": int(round(tile["cy"])),
            "score": round(float(tile.get("score", 0)), 2),
            "tile_ratio": round(float(tile.get("tile_ratio", 0)), 3),
            "ratio": round(float(tile.get("ratio", 0)), 3),
            "rectangularity": round(float(tile.get("rectangularity", 0)), 3),
        }

    return {
        "tiles_detected": len(candidates),
        "tile_centers": [summarize(tile) for tile in candidates],
        "valid_tile_letters": len(letter_tiles),
        "shape_result": shape_result,
        "ocr_result": ocr_result,
        "result": solver_result,
    }


@app.get("/health")
async def health():
    """Report service liveness, model-registry sizes and Tesseract status.

    ``ok`` is always True. Tesseract is only used as a fallback OCR path, so
    a missing binary is reported in ``tesseract_ok`` / ``tesseract`` rather
    than failing the health check.

    Returns:
        dict with ``ok``, ``hot_models`` (len of ``_hot_model_files``),
        ``cold_models`` (len of ``_cold_index``), ``tesseract_ok``,
        ``tesseract`` (version string, or the error text on failure) and
        ``tesseract_cmd`` (the binary path pytesseract is configured with).
    """
    import asyncio  # local import keeps this endpoint self-contained

    try:
        # get_tesseract_version() spawns the tesseract binary as a subprocess;
        # running it in the default executor keeps that blocking call off the
        # event loop so health probes do not stall concurrent requests.
        loop = asyncio.get_running_loop()
        version = await loop.run_in_executor(None, pytesseract.get_tesseract_version)
        tesseract_ok = True
        tesseract_version = str(version)
    except Exception as e:
        # Deliberate boundary: any Tesseract failure is reported, not raised.
        tesseract_ok = False
        tesseract_version = str(e)

    return {
        "ok": True,
        "hot_models": len(_hot_model_files),
        "cold_models": len(_cold_index),
        "tesseract_ok": tesseract_ok,
        "tesseract": tesseract_version,
        "tesseract_cmd": getattr(pytesseract.pytesseract, "tesseract_cmd", "tesseract"),
    }


@app.get("/models")
async def list_models():
    """Return a snapshot of the hot and cold model registries.

    ``hot`` is built from ``_registry`` and ``cold`` from ``_cold_index``.
    Each maps a registry key (presumably the question/class name; confirm
    against the registry loader) to the name of the folder that contains the
    model file. ``hot_count`` is ``len(_hot_model_files)`` and ``cold_count``
    is ``len(_cold_index)``. Every value is read while holding ``_lock``.
    """
    def folder_of(model_path):
        # Label a model by the directory its file lives in.
        return os.path.basename(os.path.dirname(model_path))

    with _lock:
        hot_view = {key: folder_of(path) for key, path in _registry.items()}
        cold_view = {key: folder_of(path) for key, path in _cold_index.items()}
        snapshot = {
            "hot_count": len(_hot_model_files),
            "hot": hot_view,
            "cold_count": len(_cold_index),
            "cold": cold_view,
        }

    return snapshot


# ======================================================
# RUN:
# uvicorn app:app --host 0.0.0.0 --port 8000
# ======================================================
