# Source code for geoprior.scripts.plot_ablations_sensitivity

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
#
r"""Plot extended ablations & sensitivities.

Updated v3.2 behavior:

- Prefer ablation_record.updated*.jsonl when present.
- Dedupe (timestamp, city): keep best record (updated wins).
- Normalize pde_mode to buckets: none vs both
  (off/none -> none, else -> both).
- Optional explicit ``--input`` (csv/json/jsonl) and ``--models``
  filtering.

Plot upgrades (v3.2+):

- Configurable colormap via ``--cmap``.
- Multiple heatmap metrics in a single figure via
  ``--heatmap-metrics``.
- Optional heatmap-only layout (no bar row) via ``--no-bars``.
- Optional ``--map-kind`` values: ``heatmap`` for discrete grids,
  ``smooth`` for interpolated grids, ``contour`` for ``contourf`` on
  a pivot grid, and ``tricontour`` for scattered cells.
- Consistent lambda grid alignment across cities and metrics via
  ``--align-grid``.
- Optional marker for a chosen ``(lambda_cons, lambda_prior)`` point.
- Optional Pareto trade-off view via ``--pareto`` with a scatter of
  MAE versus sharpness, an optional non-dominated front overlay, and
  an optional ``--pareto-density`` hexbin overlay.
"""

from __future__ import annotations

import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import pandas as pd

from . import config as cfg
from . import utils

# Metric names for which a smaller value is better.
# ``_best_ij`` uses this set to choose argmin over argmax when locating the
# best grid cell (coverage metrics are special-cased there by distance to
# their target). ``_to_min_obj`` returns listed metrics unchanged and negates
# every other metric when building Pareto minimization objectives.
_LOWER_IS_BETTER = {
    "mae",
    "rmse",
    "mse",
    "sharpness80",
    "epsilon_prior",
    "epsilon_cons",
}


# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------
def _parse_args(
    argv: list[str] | None,
    *,
    prog: str | None = None,
) -> argparse.Namespace:
    """Build the S6 command-line parser and parse *argv*.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse; ``None`` lets argparse read ``sys.argv``.
    prog : str, optional
        Program name shown in usage/help; defaults to
        ``"plot-ablations-sensitivity"``.

    Returns
    -------
    argparse.Namespace
        Parsed options. Boolean-like switches (``--no-bars``,
        ``--pareto``, ``--align-grid``, ...) are declared with
        ``type=str`` and are therefore returned as raw strings.
    """
    p = argparse.ArgumentParser(
        prog=prog or "plot-ablations-sensitivity",
        description="Supplement S6: Ablations & sensitivities",
    )

    p.add_argument(
        "--root",
        type=str,
        default="results",
        help="Root to scan for ablation_record*.jsonl",
    )
    p.add_argument(
        "--input",
        type=str,
        default=None,
        action="append",
        help=(
            "Explicit ablation file/table (repeatable). "
            "Supports .jsonl/.json/.csv."
        ),
    )
    p.add_argument(
        "--out-dir",
        type=str,
        default=None,
        help="Optional output dir override.",
    )

    p.add_argument("--font", type=int, default=cfg.PAPER_FONT)
    p.add_argument("--dpi", type=int, default=cfg.PAPER_DPI)

    # --- plot controls ------------------------------------------------
    p.add_argument(
        "--cmap",
        type=str,
        default="viridis",
        help=(
            "Colormap name for maps (e.g. viridis, magma, cividis). "
            "Used by heatmap/smooth/contour/tricontour."
        ),
    )
    p.add_argument(
        "--map-kind",
        type=str,
        default="heatmap",
        choices=[
            "heatmap",
            "smooth",
            "contour",
            "tricontour",
        ],
        help=(
            "Map rendering style. 'heatmap' uses imshow on the lambda "
            "grid. 'tricontour' gives a density-like smooth surface and "
            "handles missing cells."
        ),
    )
    p.add_argument(
        "--levels",
        type=int,
        default=12,
        help="Number of contour levels for contour/tricontour.",
    )
    p.add_argument(
        "--contour-lines",
        type=str,
        default="false",
        help="Overlay contour lines on contour/tricontour (true/false).",
    )

    p.add_argument(
        "--bar-metric",
        type=str,
        default="mae",
        choices=[
            "r2",
            "mae",
            "rmse",
            "mse",
            "coverage80",
            "sharpness80",
            "epsilon_prior",
            "epsilon_cons",
        ],
        help="Metric for bar plots (used when bars are enabled).",
    )
    p.add_argument(
        "--heatmap-metric",
        type=str,
        default="sharpness80",
        choices=[
            "mae",
            "rmse",
            "mse",
            "r2",
            "coverage80",
            "sharpness80",
            "epsilon_prior",
            "epsilon_cons",
        ],
        help="Single metric for heatmaps (legacy).",
    )
    p.add_argument(
        "--heatmap-metrics",
        type=str,
        default="",
        help=(
            "Comma list of map metrics to plot in one figure. "
            "Example: mae,coverage80,sharpness80. "
            "If provided, overrides --heatmap-metric."
        ),
    )
    p.add_argument(
        "--no-bars",
        type=str,
        default="false",
        help=(
            "Disable the bar-row entirely (true/false). "
            "Recommended when using --heatmap-metrics."
        ),
    )

    p.add_argument(
        "--align-grid",
        type=str,
        default="true",
        help=(
            "Align lambda grids (ticks/ordering) across cities/metrics "
            "(true/false). When true, missing cells are left blank."
        ),
    )

    p.add_argument(
        "--models",
        type=str,
        default="",
        help="Comma list of model names to keep.",
    )

    p.add_argument(
        "--mute-values",
        type=str,
        default="false",
        help="Annotate values above bars (true/false).",
    )
    p.add_argument(
        "--no-colorbar",
        type=str,
        default="false",
        help="Hide colorbars (true/false).",
    )

    p.add_argument(
        "--mark-lambda-cons",
        type=float,
        default=None,
        help="Optional marker: lambda_cons value to mark on maps.",
    )
    p.add_argument(
        "--mark-lambda-prior",
        type=float,
        default=None,
        help="Optional marker: lambda_prior value to mark on maps.",
    )
    p.add_argument(
        "--no-best-marker",
        type=str,
        default="false",
        help="Disable automatic best-cell marker (true/false).",
    )

    p.add_argument(
        "--pareto",
        type=str,
        default="false",
        help=(
            "Plot Pareto scatter (trade-offs) instead of maps. "
            "Default: false."
        ),
    )
    p.add_argument(
        "--pareto-x",
        type=str,
        default="mae",
        choices=[
            "mae",
            "rmse",
            "mse",
            "r2",
            "coverage80",
            "sharpness80",
            "epsilon_prior",
            "epsilon_cons",
        ],
        help="X metric for Pareto scatter (default: mae).",
    )
    p.add_argument(
        "--pareto-y",
        type=str,
        default="sharpness80",
        choices=[
            "mae",
            "rmse",
            "mse",
            "r2",
            "coverage80",
            "sharpness80",
            "epsilon_prior",
            "epsilon_cons",
        ],
        help="Y metric for Pareto scatter (default: sharpness80).",
    )
    # Same metric set as the other metric options ("rmse" was missing).
    p.add_argument(
        "--pareto-color",
        type=str,
        default="coverage80",
        choices=[
            "coverage80",
            "sharpness80",
            "mae",
            "rmse",
            "mse",
            "r2",
            "epsilon_prior",
            "epsilon_cons",
        ],
        help="Color metric for Pareto scatter (default: coverage80).",
    )
    p.add_argument(
        "--pareto-front",
        type=str,
        default="true",
        help="Overlay non-dominated front (true/false).",
    )
    p.add_argument(
        "--pareto-s",
        type=float,
        default=26.0,
        help="Marker size for Pareto scatter.",
    )
    p.add_argument(
        "--pareto-alpha",
        type=float,
        default=0.85,
        help="Marker alpha for Pareto scatter.",
    )

    p.add_argument(
        "--pareto-density",
        type=str,
        default="false",
        help=(
            "Add a density overlay (hexbin) behind the Pareto scatter "
            "(true/false). Useful when many runs overlap."
        ),
    )
    p.add_argument(
        "--pareto-density-gridsize",
        type=int,
        default=35,
        help="Hexbin gridsize for Pareto density overlay.",
    )
    p.add_argument(
        "--pareto-density-bins",
        type=str,
        default="log",
        choices=["log", "linear"],
        help="Hexbin bin scaling (log or linear).",
    )
    p.add_argument(
        "--pareto-density-cmap",
        type=str,
        default="Greys",
        help="Colormap for Pareto density overlay (hexbin).",
    )
    p.add_argument(
        "--pareto-density-alpha",
        type=float,
        default=0.25,
        help="Alpha for Pareto density overlay.",
    )
    p.add_argument(
        "--pareto-density-mincnt",
        type=int,
        default=1,
        help="Minimum count per hexagon to be shown (default: 1).",
    )

    p.add_argument(
        "--cities",
        type=str,
        default="",
        help="Comma list of cities (overrides city-a/b).",
    )
    p.add_argument(
        "--city-a",
        type=str,
        default="Nansha",
        help="First city name.",
    )
    p.add_argument(
        "--city-b",
        type=str,
        default="Zhongshan",
        help="Second city name.",
    )

    utils.add_plot_text_args(
        p,
        default_out="supp_fig_S6_ablations",
    )
    return p.parse_args(argv)


# ---------------------------------------------------------------------
# I/O
# ---------------------------------------------------------------------
def _read_jsonl(fp: Path) -> list[dict[str, Any]]:
    """Load a JSON-Lines file as a list of record dicts.

    Blank lines, lines that fail to decode as JSON, and lines whose
    decoded value is not a JSON object are skipped without error.
    """
    records: list[dict[str, Any]] = []
    with fp.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            try:
                payload = json.loads(text)
            except Exception:
                # Corrupt line: treat as "not a record".
                payload = None
            if isinstance(payload, dict):
                records.append(payload)
    return records


def _load_one_input(path: Path) -> pd.DataFrame:
    """Load one ablation table (``.csv``, ``.jsonl`` or ``.json``).

    A ``.json`` file may hold either a list of records or a single record
    object. Every successfully loaded frame gets a ``_src`` column holding
    the file path, unless the data already provides one.

    Unknown suffixes, unreadable ``.json`` files and ``.json`` payloads
    that are neither a list nor an object yield an empty DataFrame.
    """
    kind = path.suffix.lower()
    frame = None

    if kind == ".csv":
        frame = pd.read_csv(path)
    elif kind == ".jsonl":
        frame = pd.DataFrame(_read_jsonl(path))
    elif kind == ".json":
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            payload = None
        if isinstance(payload, list):
            frame = pd.DataFrame(payload)
        elif isinstance(payload, dict):
            frame = pd.DataFrame([payload])

    if frame is None:
        return pd.DataFrame([])

    # Remember provenance (used later to prefer "updated" records).
    if "_src" not in frame.columns:
        frame["_src"] = str(path)
    return frame


def _scan_records(root: Path) -> pd.DataFrame:
    """Concatenate every ablation record file found under *root*.

    Files are located with ``utils.find_all`` using the
    ``cfg.PATTERNS["ablation_record_jsonl"]`` patterns (no patterns if the
    key is absent). Empty files are ignored. Returns an empty DataFrame
    when nothing usable is found.
    """
    patterns = cfg.PATTERNS.get("ablation_record_jsonl", ())
    loaded = (_load_one_input(fp) for fp in utils.find_all(root, patterns))
    frames = [frame for frame in loaded if not frame.empty]
    if frames:
        return pd.concat(frames, ignore_index=True)
    return pd.DataFrame([])


def _load_records(args: argparse.Namespace) -> pd.DataFrame:
    """Load ablation records from ``--input`` files or by scanning ``--root``.

    When one or more ``--input`` paths are given, only those are read;
    missing paths and empty tables are skipped, and an empty DataFrame is
    returned if nothing remains. Otherwise ``--root`` is scanned.
    """
    if not args.input:
        return _scan_records(utils.as_path(args.root))

    frames: list[pd.DataFrame] = []
    for raw in args.input:
        path = utils.as_path(raw)
        if not path.exists():
            continue
        frame = _load_one_input(path)
        if not frame.empty:
            frames.append(frame)

    if frames:
        return pd.concat(frames, ignore_index=True)
    return pd.DataFrame([])


# ---------------------------------------------------------------------
# Canonicalization / precedence
# ---------------------------------------------------------------------
def _canon_pde_mode(x: Any) -> str:
    """Normalize a raw ``pde_mode`` value to a canonical token.

    Returns
    -------
    str
        ``"none"`` for physics-off spellings (``none``/``off``/``no``/
        ``0``/``false``), ``"both"`` for physics-on spellings
        (``both``/``on``/``1``/``true`` and the single-term modes
        ``consolidation``/``gw_flow``), ``""`` for ``None``, and the
        stripped, lower-cased text for anything else.

    Non-string flags are stringified before matching, so ``0``, ``False``
    and ``0.0`` map to ``"none"`` while ``1``, ``True`` and ``1.0`` map to
    ``"both"``.
    """
    # Only ``None`` means "missing". The former ``x or ""`` also turned the
    # falsy flags 0 / False / 0.0 into "", which then fell into the
    # physics-on bucket even though "0"/"false" mean physics-off.
    s = "" if x is None else str(x).strip().lower()
    if s in {"none", "off", "no", "0", "0.0", "false"}:
        return "none"
    if s in {"both", "on", "1", "1.0", "true"}:
        return "both"
    # Treat any physics-on mode as the "both" bucket.
    if s in {"consolidation", "gw_flow"}:
        return "both"
    return s


def _canon_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Canonicalize names and values of the raw ablation table.

    Columns are assigned on *df* itself, which is also returned:

    1. alias column names are resolved via ``utils.ensure_columns``;
    2. ``city`` values are passed through ``utils.canonical_city``;
    3. ``pde_mode`` is normalized with ``_canon_pde_mode`` (set to
       ``"both"`` when the column is absent), and a two-valued
       ``pde_bucket`` column is derived from it;
    4. known numeric columns are coerced with
       ``pd.to_numeric(errors="coerce")``.

    An empty frame is returned untouched.
    """
    if df.empty:
        return df

    column_aliases = {
        "timestamp": ("timestamp", "ts"),
        "city": ("city", "City"),
        "model": ("model", "Model"),
        "pde_mode": ("pde_mode", "pde"),
        "lambda_prior": ("lambda_prior", "lambda_p"),
        "lambda_cons": ("lambda_cons", "lambda_c"),
        "coverage80": ("coverage80", "coverage8"),
        "sharpness80": ("sharpness80", "sharpness"),
    }
    utils.ensure_columns(df, aliases=column_aliases)

    if "city" in df.columns:
        df["city"] = df["city"].astype(str).map(utils.canonical_city)

    df["pde_mode"] = (
        df["pde_mode"].map(_canon_pde_mode)
        if "pde_mode" in df.columns
        else "both"
    )

    # Bucket consumed by the S6 layout: "none" = physics off,
    # "both" = any physics-on mode.
    physics_off = df["pde_mode"].astype(str).eq("none")
    df["pde_bucket"] = np.where(physics_off, "none", "both")

    numeric_columns = (
        "lambda_prior",
        "lambda_cons",
        "r2",
        "mae",
        "rmse",
        "mse",
        "coverage80",
        "sharpness80",
        "epsilon_prior",
        "epsilon_cons",
    )
    for col in numeric_columns:
        if col not in df.columns:
            continue
        df[col] = pd.to_numeric(df[col], errors="coerce")

    return df


def _needs_si_to_mm(row: pd.Series) -> bool:
    """Decide whether a record's metrics look like metres, not millimetres.

    An explicit ``units["subs_metrics_unit"]`` tag wins: ``"mm"`` means
    no conversion; ``"m"``/``"meter"``/``"metre"`` means convert.
    Otherwise the heuristic is that ``mae``, ``mse`` and ``sharpness80``
    must all be finite and strictly between 0 and 1.
    """
    units = row.get("units", None)
    if isinstance(units, dict):
        tag = str(units.get("subs_metrics_unit", "")).lower()
        if tag == "mm":
            return False
        if tag in {"m", "meter", "metre"}:
            return True

    candidates = [
        row.get(key, np.nan) for key in ("mae", "mse", "sharpness80")
    ]
    all_finite = all(np.isfinite(v) for v in candidates)
    return bool(
        all_finite and all(0.0 < float(v) < 1.0 for v in candidates)
    )


def _harmonize_units(df: pd.DataFrame) -> pd.DataFrame:
    """Rescale metre-based records to millimetres, in place.

    Rows flagged by ``_needs_si_to_mm`` get ``mae``/``rmse``/
    ``sharpness80`` multiplied by 1e3 and ``mse`` by 1e6 (squared
    units). Frames that are empty or lack ``mae``/``mse`` are returned
    unchanged.
    """
    if df.empty or not {"mae", "mse"}.issubset(df.columns):
        return df

    to_mm = df.apply(_needs_si_to_mm, axis=1)
    if not to_mm.any():
        return df

    factors = {
        "mae": 1000.0,
        "rmse": 1000.0,
        "sharpness80": 1000.0,
        "mse": 1e6,
    }
    for col, factor in factors.items():
        if col in df.columns:
            df.loc[to_mm, col] = df.loc[to_mm, col] * factor

    return df


def _record_score(row: pd.Series) -> int:
    """Priority of a record when de-duplicating (higher wins).

    +100 if its source path contains ``"updated"``, +50 if its units tag
    says ``"mm"``, and +10 each for a dict-valued ``legacy`` and
    ``metrics`` field.
    """
    source = str(row.get("_src", "")).lower()
    score = 100 if "updated" in source else 0

    units = row.get("units", None)
    if (
        isinstance(units, dict)
        and str(units.get("subs_metrics_unit", "")).lower() == "mm"
    ):
        score += 50

    score += sum(
        10
        for key in ("legacy", "metrics")
        if isinstance(row.get(key, None), dict)
    )
    return int(score)


def _dedupe_prefer_best(df: pd.DataFrame) -> pd.DataFrame:
    """Keep a single record per ``(timestamp, city)`` key.

    When several rows share a key (e.g. an original record and its
    ``ablation_record.updated*.jsonl`` rewrite), the row with the highest
    ``_record_score`` is kept. Rows without a timestamp cannot be matched
    to anything, so they are all kept rather than collapsed into one.
    The caller's frame is never modified.

    Returns the input unchanged when it is empty or has no ``timestamp``
    column. Otherwise rows are ordered by ``(timestamp, city)``, with
    timestamp-less rows last.
    """
    if df.empty or "timestamp" not in df.columns:
        return df

    # Copy *before* adding helper columns so the caller's frame is
    # left untouched.
    df = df.copy()
    if "city" not in df.columns:
        df["city"] = ""

    df["_score"] = df.apply(_record_score, axis=1)
    df = df.sort_values(
        by=["timestamp", "city", "_score"],
        ascending=[True, True, False],
    )

    # drop_duplicates treats NaN keys as equal, which would merge every
    # timestamp-less record of a city into one row; only dedupe keyed rows.
    has_ts = df["timestamp"].notna()
    keyed = df.loc[has_ts].drop_duplicates(
        subset=["timestamp", "city"],
        keep="first",
    )
    missing = df.loc[~has_ts]
    out = keyed if missing.empty else pd.concat([keyed, missing])

    return out.drop(columns=["_score"], errors="ignore")


def _filter_models(
    df: pd.DataFrame, models: str
) -> pd.DataFrame:
    """Keep only rows whose ``model`` is in the comma list *models*.

    Returns *df* itself (no copy) when *models* is blank or lists no
    names, or when the frame has no ``model`` column.
    """
    wanted = [
        name.strip()
        for name in str(models or "").split(",")
        if name.strip()
    ]
    if not wanted or "model" not in df.columns:
        return df
    selected = df["model"].astype(str).isin(wanted)
    return df.loc[selected].copy()


def _resolve_cities(args: argparse.Namespace) -> list[str]:
    """Return the canonical city names to plot.

    ``--cities`` (a comma list) takes precedence; otherwise the pair
    ``--city-a`` / ``--city-b`` is used.
    """
    raw = str(args.cities or "").strip()
    if raw:
        names = [tok.strip() for tok in raw.split(",") if tok.strip()]
    else:
        names = [args.city_a, args.city_b]
    return [utils.canonical_city(name) for name in names]


def _resolve_map_metrics(
    args: argparse.Namespace,
) -> list[str]:
    """Return the list of map metrics to draw.

    ``--heatmap-metrics`` (comma list) wins when it names at least one
    metric; duplicates are dropped while keeping first-seen order.
    Otherwise falls back to the single ``--heatmap-metric``.
    """
    fallback = [str(args.heatmap_metric)]
    raw = str(args.heatmap_metrics or "").strip()
    ordered = list(
        dict.fromkeys(tok.strip() for tok in raw.split(",") if tok.strip())
    )
    return ordered or fallback


# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _coverage_target(metric: str) -> float | None:
    """Nominal coverage level encoded in a ``coverageNN`` metric name.

    ``"coverage80"`` -> ``0.8`` (case/whitespace-insensitive). Returns
    ``None`` for non-coverage names or when no digits follow the prefix.
    """
    prefix = "coverage"
    name = str(metric).strip().lower()
    if not name.startswith(prefix):
        return None
    digits = name[len(prefix):]
    return int(digits) / 100.0 if digits.isdigit() else None


def _metric_label(name: str) -> str:
    """Human-readable axis/title label for metric *name*.

    Lookup order: ``cfg.PHYS_LABELS``; a generated label for coverage
    metrics (which are target-based, not monotone); the ``title`` or
    ``ylabel`` entry of ``cfg.PLOT_METRIC_META`` formatted with its
    ``unit``; a built-in label for ``r2``; finally the raw name.
    """
    key = str(name).strip()
    if key in cfg.PHYS_LABELS:
        return cfg.PHYS_LABELS[key]

    target = _coverage_target(key)
    if target is not None:
        pct = int(round(100.0 * target))
        return f"Coverage ({pct}% PI; target {target:.2f})"

    meta = cfg.PLOT_METRIC_META.get(key, None)
    if isinstance(meta, dict):
        template = str(meta.get("title") or meta.get("ylabel") or key)
        unit_text = str(meta.get("unit", "") or "")
        return template.format(unit=unit_text)

    return r"$R^2$ (↑)" if key.lower() == "r2" else key


def _best_ij(
    arr: np.ndarray, *, metric: str
) -> tuple[int, int] | None:
    """Locate the "best" cell of a 2D metric grid.

    The criterion depends on *metric*:

    - coverage metrics (``coverageNN``): closest to the nominal target;
    - metrics in ``_LOWER_IS_BETTER``: minimum;
    - anything else (e.g. ``r2``): maximum.

    Non-finite cells (NaN, +/-inf) are treated as missing for every
    criterion, so an infinite value can never be chosen as "best".

    Returns
    -------
    tuple of int or None
        ``(row, col)`` of the best cell. ``None`` when there is no finite
        cell, the input is not numeric, or it is not two-dimensional.
    """
    try:
        # np.array (not asarray) copies, so the masking below never
        # writes into the caller's buffer (e.g. DataFrame.values).
        a = np.array(arr, dtype=float)
        a[~np.isfinite(a)] = np.nan
        if not np.isfinite(a).any():
            return None

        kl = str(metric).strip().lower()
        tgt = _coverage_target(kl)

        if tgt is not None:
            idx = int(np.nanargmin(np.abs(a - tgt)))
        elif kl in _LOWER_IS_BETTER:
            idx = int(np.nanargmin(a))
        else:
            idx = int(np.nanargmax(a))

        i, j = np.unravel_index(idx, a.shape)
        return (int(i), int(j))
    except (TypeError, ValueError):
        # Non-numeric input, non-2D grid, or an unparsable metric name.
        return None


def _metric_fmt(metric: str) -> str:
    """``str.format`` template for printing *metric* values.

    Uses the ``fmt`` entry of ``cfg.PLOT_METRIC_META`` when the metric
    has a metadata dict, otherwise ``"{:.3g}"``.
    """
    default = "{:.3g}"
    meta = cfg.PLOT_METRIC_META.get(str(metric), None)
    if not isinstance(meta, dict):
        return default
    return str(meta.get("fmt", default))


def _axes_cleanup(ax: plt.Axes) -> None:
    """Hide the top and right spines of *ax*."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)


def _bars_by_lambda(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    city: str,
    metric: str,
    color: str,
    annotate: bool,
    show_legend: bool,
    show_labels: bool,
    show_ticks: bool,
    show_title: bool,
) -> None:
    """Grouped bars of *metric* versus ``lambda_prior`` for one city.

    For every distinct ``lambda_prior`` value two bars are drawn side by
    side: physics-off (``pde_bucket == "none"``, faded) and physics-on
    (``"both"``, opaque), both in *color*. Repeated runs are averaged.
    The axis is switched off when there is nothing to plot.

    Parameters
    ----------
    annotate : bool
        Print each bar's value above it (format from ``_metric_fmt``).
    show_legend, show_labels, show_ticks, show_title : bool
        Toggle the legend, axis labels, x tick labels and the
        ``"<city> — <metric>"`` title.
    """
    sub = df.loc[df["city"].astype(str).eq(str(city))].copy()
    if sub.empty or metric not in sub.columns:
        ax.set_axis_off()
        return

    keep = sub.loc[
        sub["pde_bucket"].isin(["none", "both"])
    ].copy()
    if keep.empty:
        ax.set_axis_off()
        return

    grp = (
        keep.groupby(
            ["pde_bucket", "lambda_prior"], dropna=False
        )[metric]
        .mean()
        .reset_index()
    )
    if grp.empty:
        ax.set_axis_off()
        return

    xs = sorted(grp["lambda_prior"].dropna().unique())
    if not xs:
        ax.set_axis_off()
        return

    modes = ["none", "both"]
    base = np.arange(len(xs), dtype=float)
    width = 0.8 / len(modes)
    # Centre the group of bars on each tick.
    offset0 = (len(modes) - 1) / 2.0

    fmt = _metric_fmt(metric)

    for i, m in enumerate(modes):
        # Missing (mode, lambda) combinations become NaN -> no bar.
        y = [
            grp.loc[
                (grp["pde_bucket"] == m)
                & (grp["lambda_prior"] == x),
                metric,
            ].mean()
            for x in xs
        ]

        xloc = base + (i - offset0) * width
        ax.bar(
            xloc,
            y,
            width=width,
            label=m,
            color=color,
            alpha=0.35 if m == "none" else 0.95,
            edgecolor="white",
        )

        if annotate:
            for xi, yi in zip(xloc, y, strict=False):
                if pd.notna(yi):
                    ax.text(
                        xi,
                        yi,
                        fmt.format(float(yi)),
                        ha="center",
                        va="bottom",
                        fontsize=8,
                    )

    ax.set_xticks(base)

    if show_ticks:
        ax.set_xticklabels([f"{x:.2g}" for x in xs])
    else:
        ax.set_xticklabels([])

    if show_labels:
        ax.set_xlabel(r"$\lambda_{\mathrm{prior}}$")
        ax.set_ylabel(_metric_label(metric))
    else:
        ax.set_xlabel("")
        ax.set_ylabel("")

    if show_title:
        # Separator between city and metric (was missing: "NanshaMAE").
        ax.set_title(
            f"{city} — {_metric_label(metric)}",
            loc="left",
            pad=6,
            fontweight="bold",
        )

    if show_legend:
        ax.legend(title="physics", frameon=False)

    _axes_cleanup(ax)


def _lambda_grid(
    df: pd.DataFrame,
    *,
    cities: Iterable[str],
) -> tuple[list[float], list[float]]:
    """Sorted union of lambda values used by the selected cities.

    Only physics-on rows (``pde_bucket == "both"``) that have both
    lambdas are considered. Returns ``(lambda_cons, lambda_prior)``
    value lists, each sorted ascending and free of duplicates.
    """
    wanted = {str(c) for c in cities}
    selected = df["city"].astype(str).isin(wanted) & df["pde_bucket"].eq(
        "both"
    )
    rows = df.loc[selected].dropna(subset=["lambda_cons", "lambda_prior"])

    def _axis_values(column: str) -> list[float]:
        return sorted({float(v) for v in rows[column].unique()})

    return _axis_values("lambda_cons"), _axis_values("lambda_prior")


def _pivot_grid(
    sub: pd.DataFrame,
    *,
    metric: str,
    lc_grid: list[float] | None,
    lp_grid: list[float] | None,
) -> pd.DataFrame:
    """Average *metric* onto a ``lambda_cons`` x ``lambda_prior`` table.

    Rows are ``lambda_cons`` and columns ``lambda_prior``, both sorted
    ascending. When aligned grids are given, the table is reindexed to
    them so absent cells become NaN. An empty pivot is returned as-is.
    """
    table = sub.pivot_table(
        values=metric,
        index="lambda_cons",
        columns="lambda_prior",
        aggfunc="mean",
    )
    if table.empty:
        return table

    table = table.sort_index(axis=0).sort_index(axis=1)

    targets = {}
    if lc_grid is not None:
        targets["index"] = lc_grid
    if lp_grid is not None:
        targets["columns"] = lp_grid
    return table.reindex(**targets) if targets else table


def _maybe_mark_point(
    ax: plt.Axes,
    *,
    lc_ticks: list[float],
    lp_ticks: list[float],
    mark_lc: float | None,
    mark_lp: float | None,
) -> None:
    """Outline the user-selected cell on an index-space (imshow) map.

    Draws a white square at ``(column, row)`` cell indices. Nothing is
    drawn unless both coordinates are given and exactly present in the
    tick lists.
    """
    if mark_lc is None or mark_lp is None:
        return
    try:
        row = lc_ticks.index(mark_lc)
        col = lp_ticks.index(mark_lp)
    except ValueError:
        # Requested point is not on this grid.
        return
    ax.scatter(
        [col],
        [row],
        marker="s",
        s=70,
        facecolors="none",
        edgecolors="white",
        linewidths=1.8,
    )


def _plot_map_on_grid(
    ax: plt.Axes,
    piv: pd.DataFrame,
    *,
    metric: str,
    cmap: str,
    map_kind: str,
    levels: int,
    contour_lines: bool,
    vmin: float | None,
    vmax: float | None,
) -> Any:
    """Render a pivoted lambda grid and return its colorbar artist.

    *piv* holds ``lambda_cons`` on the index and ``lambda_prior`` on the
    columns (as built by ``_pivot_grid``).

    - ``"heatmap"`` / ``"smooth"``: ``imshow`` in cell-index space with
      ``nearest`` / ``bicubic`` interpolation; NaN cells are masked and
      rendered with the colormap's light-grey "bad" colour. Returns the
      image.
    - ``"contour"``: ``contourf`` in real lambda coordinates, optionally
      with thin black iso-lines. Returns the filled ContourSet.

    *metric* is currently unused. Any other *map_kind* raises
    ``ValueError``.
    """
    values = np.asarray(piv.values, dtype=float)
    masked = np.ma.masked_invalid(values)

    colormap = plt.get_cmap(cmap).copy()
    colormap.set_bad(color=(0.85, 0.85, 0.85, 1.0))

    prior_axis = list(piv.columns.astype(float))
    cons_axis = list(piv.index.astype(float))

    if map_kind == "heatmap" or map_kind == "smooth":
        return ax.imshow(
            masked,
            aspect="auto",
            cmap=colormap,
            interpolation=("bicubic" if map_kind == "smooth" else "nearest"),
            vmin=vmin,
            vmax=vmax,
        )

    # Contour plots need real coordinate grids, not cell indices.
    gx, gy = np.meshgrid(prior_axis, cons_axis)

    if map_kind != "contour":
        raise ValueError(f"Unsupported map_kind: {map_kind}")

    filled = ax.contourf(
        gx,
        gy,
        masked,
        levels=levels,
        cmap=colormap,
        vmin=vmin,
        vmax=vmax,
    )
    if contour_lines:
        ax.contour(
            gx,
            gy,
            masked,
            levels=levels,
            colors="k",
            linewidths=0.4,
            alpha=0.35,
        )
    return filled


def _plot_map_tricontour(
    ax: plt.Axes,
    sub: pd.DataFrame,
    *,
    metric: str,
    cmap: str,
    levels: int,
    contour_lines: bool,
    vmin: float | None,
    vmax: float | None,
) -> Any:
    """Filled tricontour of *metric* over scattered lambda points.

    Points live in real ``(lambda_prior, lambda_cons)`` coordinates, so
    missing grid cells are simply interpolated over. Repeated runs at
    the same lambda pair are averaged first, consistent with the mean
    used by ``_pivot_grid``. Without this step duplicate points would be
    passed to the Delaunay triangulation, and only one of their values
    would be used.

    Returns
    -------
    ContourSet or None
        The filled contour set for the colorbar. ``None`` when there is
        no finite data, or fewer than three distinct / only collinear
        lambda points (no valid triangulation). The caller then switches
        the axis off instead of aborting the whole figure.
    """
    s = sub.dropna(
        subset=["lambda_cons", "lambda_prior", metric]
    )
    if s.empty:
        return None

    cells = s.groupby(["lambda_prior", "lambda_cons"], as_index=False)[
        metric
    ].mean()
    if len(cells) < 3:
        return None

    x = cells["lambda_prior"].astype(float).to_numpy()
    y = cells["lambda_cons"].astype(float).to_numpy()
    z = cells[metric].astype(float).to_numpy()

    try:
        tri = mtri.Triangulation(x, y)
    except (ValueError, RuntimeError):
        # e.g. all points collinear (a single lambda_cons value).
        return None

    cm = plt.get_cmap(cmap).copy()

    cs = ax.tricontourf(
        tri,
        z,
        levels=levels,
        cmap=cm,
        vmin=vmin,
        vmax=vmax,
    )
    if contour_lines:
        ax.tricontour(
            tri,
            z,
            levels=levels,
            colors="k",
            linewidths=0.4,
            alpha=0.35,
        )
    return cs


def _heatmap_one(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    city: str,
    metric: str,
    cmap: str,
    map_kind: str,
    levels: int,
    contour_lines: bool,
    lc_grid: list[float] | None,
    lp_grid: list[float] | None,
    vmin: float | None,
    vmax: float | None,
    show_labels: bool,
    show_ticks: bool,
    show_title: bool,
    mark_best: bool,
    mark_lc: float | None,
    mark_lp: float | None,
) -> Any | None:
    """Draw one city's ``lambda_prior`` x ``lambda_cons`` map of *metric*.

    Only physics-on rows (``pde_bucket == "both"``) with both lambdas
    present are used. Coordinates depend on *map_kind*:

    - ``"tricontour"`` and ``"contour"``: axes in real lambda values;
    - ``"heatmap"`` / ``"smooth"``: ``imshow`` axes in cell indices,
      with lambda values as tick labels.

    Optional overlays: a white circle on the best cell (``mark_best``,
    see ``_best_ij``) and a white square on ``(mark_lc, mark_lp)``.

    Returns the artist to use for a colorbar, or ``None`` (with the axis
    switched off) when there is nothing to draw.
    """
    sub = df.loc[df["city"].astype(str).eq(str(city))]
    sub = sub.loc[sub["pde_bucket"].eq("both")]

    if sub.empty or metric not in sub.columns:
        ax.set_axis_off()
        return None

    sub = sub.dropna(subset=["lambda_cons", "lambda_prior"])
    if sub.empty:
        ax.set_axis_off()
        return None

    def _decorate() -> None:
        # Axis labels and "<city> — <metric>" title (shared by all kinds).
        if show_labels:
            ax.set_xlabel(r"$\lambda_{\mathrm{prior}}$")
            ax.set_ylabel(r"$\lambda_{\mathrm{cons}}$")
        else:
            ax.set_xlabel("")
            ax.set_ylabel("")
        if show_title:
            ax.set_title(
                f"{city} — {_metric_label(metric)}",
                loc="left",
                pad=6,
                fontweight="bold",
            )

    def _draw_best(piv: pd.DataFrame, *, index_space: bool) -> None:
        # Circle the best cell. imshow axes use cell indices, while
        # contour/tricontour axes use the real lambda values.
        ij = _best_ij(piv.values, metric=metric)
        if ij is None:
            return
        i, j = ij
        if index_space:
            bx, by = j, i
        else:
            bx, by = float(piv.columns[j]), float(piv.index[i])
        ax.scatter(
            [bx],
            [by],
            marker="o",
            s=50,
            facecolors="none",
            edgecolors="white",
            linewidths=1.5,
        )

    def _draw_user_point() -> None:
        # User marker in real lambda coordinates.
        ax.scatter(
            [float(mark_lp)],
            [float(mark_lc)],
            marker="s",
            s=70,
            facecolors="none",
            edgecolors="white",
            linewidths=1.8,
        )

    if map_kind == "tricontour":
        artist = _plot_map_tricontour(
            ax,
            sub,
            metric=metric,
            cmap=cmap,
            levels=levels,
            contour_lines=contour_lines,
            vmin=vmin,
            vmax=vmax,
        )
        if artist is None:
            ax.set_axis_off()
            return None

        # Ticks: use aligned grids if available, else the observed values.
        lc_ticks = lc_grid or sorted(
            {float(x) for x in sub["lambda_cons"].unique()}
        )
        lp_ticks = lp_grid or sorted(
            {float(x) for x in sub["lambda_prior"].unique()}
        )
        if show_ticks:
            ax.set_xticks(lp_ticks)
            ax.set_yticks(lc_ticks)
            ax.set_xticklabels([f"{x:.2g}" for x in lp_ticks])
            ax.set_yticklabels([f"{x:.2g}" for x in lc_ticks])
        else:
            ax.set_xticks([])
            ax.set_yticks([])

        _decorate()

        # "Best" marker: approximated on the (aligned) pivot grid.
        if mark_best:
            piv = _pivot_grid(
                sub,
                metric=metric,
                lc_grid=lc_grid,
                lp_grid=lp_grid,
            )
            if not piv.empty:
                _draw_best(piv, index_space=False)

        if mark_lc is not None and mark_lp is not None:
            _draw_user_point()

        _axes_cleanup(ax)
        return artist

    piv = _pivot_grid(
        sub,
        metric=metric,
        lc_grid=lc_grid,
        lp_grid=lp_grid,
    )
    if piv.empty:
        ax.set_axis_off()
        return None

    artist = _plot_map_on_grid(
        ax,
        piv,
        metric=metric,
        cmap=cmap,
        map_kind=map_kind,
        levels=levels,
        contour_lines=contour_lines,
        vmin=vmin,
        vmax=vmax,
    )

    # "contour" draws in lambda coordinates; imshow kinds in cell indices.
    in_lambda_coords = map_kind == "contour"

    if in_lambda_coords:
        ax.set_xticks([float(c) for c in piv.columns])
        ax.set_yticks([float(r) for r in piv.index])
    else:
        ax.set_xticks(range(piv.shape[1]))
        ax.set_yticks(range(piv.shape[0]))
    if show_ticks:
        ax.set_xticklabels([f"{float(c):.2g}" for c in piv.columns])
        ax.set_yticklabels([f"{float(r):.2g}" for r in piv.index])
    else:
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    _decorate()

    if mark_best:
        # Previously always drawn at (j, i), which is off-grid for
        # "contour" whose axes are in lambda units.
        _draw_best(piv, index_space=not in_lambda_coords)

    if mark_lc is not None and mark_lp is not None:
        if in_lambda_coords:
            if (mark_lc in piv.index) and (mark_lp in piv.columns):
                _draw_user_point()
        else:
            _maybe_mark_point(
                ax,
                lc_ticks=[float(x) for x in piv.index],
                lp_ticks=[float(x) for x in piv.columns],
                mark_lc=mark_lc,
                mark_lp=mark_lp,
            )

    _axes_cleanup(ax)
    return artist


def _resolve_out(*, out: str, out_dir: str | None) -> Path:
    """Resolve the figure output path.

    With *out_dir*, *out* is joined onto it (both ``~``-expanded) and the
    result is made absolute. Without it, ``utils.resolve_fig_out``
    decides.
    """
    if not out_dir:
        return utils.resolve_fig_out(out)
    target = Path(out_dir).expanduser().joinpath(Path(out).expanduser())
    return target.resolve()


def _compute_clim(
    df: pd.DataFrame,
    *,
    cities: list[str],
    metric: str,
) -> tuple[float | None, float | None]:
    """
    Compute a robust vmin/vmax for a metric across the selected cities
    (physics-on bucket only).
    """
    if metric not in df.columns:
        return None, None

    sub = df.loc[
        df["city"].astype(str).isin([str(c) for c in cities])
    ].copy()
    sub = sub.loc[sub["pde_bucket"].eq("both")].copy()
    v = pd.to_numeric(sub[metric], errors="coerce").dropna()
    if v.empty:
        return None, None

    # robust range (2-98 percentile) to avoid one-off extremes
    lo = float(np.nanpercentile(v.to_numpy(), 2))
    hi = float(np.nanpercentile(v.to_numpy(), 98))
    if not np.isfinite(lo) or not np.isfinite(hi) or lo == hi:
        return None, None
    return lo, hi


def _to_min_obj(metric: str, v: np.ndarray) -> np.ndarray:
    """
    Express *v* as an objective to be minimised.

    Metrics named in ``_LOWER_IS_BETTER`` (matched case-insensitively)
    are returned unchanged. Every other metric is treated as
    higher-is-better and negated, so that smaller is always better.
    """
    if str(metric).lower() not in _LOWER_IS_BETTER:
        return -v
    return v


def _pareto_front_mask(
    x: np.ndarray, y: np.ndarray
) -> np.ndarray:
    """
    Non-dominated points for 2D minimization objectives.
    Returns a boolean mask of points on the Pareto front.
    """
    n = int(x.size)
    keep = np.ones(n, dtype=bool)
    for i in range(n):
        if not keep[i]:
            continue
        # dominated if exists j with xj<=xi and yj<=yi and one strict
        dom = (
            (x <= x[i])
            & (y <= y[i])
            & ((x < x[i]) | (y < y[i]))
        )
        if np.any(dom):
            keep[i] = False
    return keep


def _pareto_panel(
    ax: plt.Axes,
    df: pd.DataFrame,
    *,
    city: str,
    x_metric: str,
    y_metric: str,
    c_metric: str,
    cmap: str,
    c_norm: Any | None,
    s: float,
    alpha: float,
    show_labels: bool,
    show_title: bool,
    mark_lc: float | None,
    mark_lp: float | None,
    show_front: bool,
    density: bool,
    density_gridsize: int,
    density_bins: str,
    density_cmap: str,
    density_alpha: float,
    density_mincnt: int,
) -> Any | None:
    """
    Draw one city's Pareto trade-off scatter on *ax*.

    Rows are restricted to *city* and the physics-on bucket
    (``pde_bucket == "both"``). Each run with finite ``x_metric``,
    ``y_metric`` and ``c_metric`` values is plotted at ``(x, y)`` and
    coloured by ``c_metric`` (using *cmap* / *c_norm*). The panel can
    also show:

    - a hexbin count layer beneath the points (*density*), with a
      ``log10(count + 1)`` transform when *density_bins* is ``"log"``;
    - the non-dominated front, when *show_front* is set and there are
      at least 3 points, computed after mapping both metrics to
      minimisation objectives;
    - a square marker on the run(s) whose ``(lambda_cons,
      lambda_prior)`` equals ``(mark_lc, mark_lp)``, placed at the
      mean of the matching points.

    Returns
    -------
    The scatter collection (usable for a colorbar), or ``None`` when
    there is nothing to draw. In that case the axes is switched off.
    """
    need = [x_metric, y_metric, c_metric, "lambda_cons", "lambda_prior"]
    if any(col not in df.columns for col in ("city", "pde_bucket", *need)):
        ax.set_axis_off()
        return None

    # This city's physics-on runs with all required fields present.
    sub = df.loc[
        df["city"].astype(str).eq(str(city))
        & df["pde_bucket"].eq("both")
    ].dropna(subset=need)
    if sub.empty:
        ax.set_axis_off()
        return None

    def _num(col: str) -> np.ndarray:
        # Coerce to float; unparseable entries become NaN.
        return pd.to_numeric(sub[col], errors="coerce").to_numpy(
            dtype=float, na_value=np.nan
        )

    x = _num(x_metric)
    y = _num(y_metric)
    c = _num(c_metric)
    lam_c = _num("lambda_cons")
    lam_p = _num("lambda_prior")

    ok = np.isfinite(x) & np.isfinite(y) & np.isfinite(c)
    x, y, c = x[ok], y[ok], c[ok]
    lam_c, lam_p = lam_c[ok], lam_p[ok]
    if x.size == 0:
        ax.set_axis_off()
        return None

    # Optional density overlay (hexbin) behind points.
    # NOTE: matplotlib's built-in `bins='log'` can become visually
    # invisible when most counts are 1 (log10(1)=0 and vmin=vmax=0).
    # We therefore apply a stable transform ourselves.
    if density and x.size >= 3:
        hb = ax.hexbin(
            x,
            y,
            gridsize=int(density_gridsize),
            mincnt=int(density_mincnt),
            bins=None,
            cmap=str(density_cmap),
            linewidths=0.0,
            alpha=float(density_alpha),
            zorder=0,
        )
        # Best-effort rescaling: if it fails, the hexbin keeps
        # matplotlib's default scaling rather than aborting the figure.
        try:
            counts = np.asarray(hb.get_array(), dtype=float)
            if counts.size:
                if str(density_bins).lower() == "log":
                    vals = np.log10(counts + 1.0)
                    hb.set_array(vals)
                    vmax = float(np.nanmax(vals))
                else:
                    vmax = float(np.nanmax(counts))
                hb.set_clim(0.0, vmax if vmax > 0 else 1.0)
        except Exception:
            pass

        # Rasterize to keep PDF size reasonable
        try:
            hb.set_rasterized(True)
        except Exception:
            pass

    sc = ax.scatter(
        x,
        y,
        c=c,
        cmap=cmap,
        norm=c_norm,
        s=float(s),
        alpha=float(alpha),
        edgecolors="white",
        linewidths=0.35,
        zorder=3,
    )

    # Optional Pareto front (computed in objective space)
    if show_front and x.size >= 3:
        xo = _to_min_obj(x_metric, x)
        yo = _to_min_obj(y_metric, y)
        mask = _pareto_front_mask(xo, yo)
        if np.any(mask):
            xf = x[mask]
            yf = y[mask]
            order = np.argsort(xf)
            ax.plot(
                xf[order],
                yf[order],
                linewidth=1.2,
                alpha=0.9,
                zorder=4,
            )

    # Mark chosen (lambda_cons, lambda_prior) point if present.
    # Tolerant float matching: lambdas that went through unit
    # harmonisation or text round-trips may differ by float noise,
    # which an exact `==` would miss.
    if mark_lc is not None and mark_lp is not None:
        hit = np.isclose(
            lam_c, float(mark_lc), rtol=1e-9, atol=1e-12
        ) & np.isclose(lam_p, float(mark_lp), rtol=1e-9, atol=1e-12)
        if np.any(hit):
            # Reuse the already-coerced, finite metric arrays.
            xv = float(np.mean(x[hit]))
            yv = float(np.mean(y[hit]))
            ax.scatter(
                [xv],
                [yv],
                marker="s",
                s=90,
                facecolors="none",
                edgecolors="black",
                linewidths=1.6,
                zorder=6,
            )

    if show_labels:
        ax.set_xlabel(_metric_label(x_metric))
        ax.set_ylabel(_metric_label(y_metric))
    else:
        ax.set_xlabel("")
        ax.set_ylabel("")

    if show_title:
        ax.set_title(city, loc="left", fontweight="bold")

    _axes_cleanup(ax)
    return sc


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def plot_ablations_sensivity_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """
    Build the Supplement S6 ablation / sensitivity figure.

    Steps:

    1. Parse CLI options and apply the paper plotting style.
    2. Load ablation records, then run the column-canonicalisation,
       unit-harmonisation, dedupe and model-filter helpers.
    3. Resolve the two cities to compare (a single city fills both
       slots). Write the table actually used
       (``tableS6_ablations_used.csv``) next to the output figure.
    4. Draw one of three layouts:

       - ``--pareto``: one Pareto scatter per city with a shared colour
         scale, then save and return;
       - a single map metric with bars enabled: bars by lambda (top
         row) above one map per city (bottom row);
       - otherwise: a 2 x M grid of maps (rows = cities, columns =
         map metrics) with one horizontal colorbar per column.

    5. Add the optional suptitle and save the figure.

    Raises
    ------
    SystemExit
        If no records are loaded or no city is selected.
    """
    args = _parse_args(argv, prog=prog)
    utils.set_paper_style(
        fontsize=int(args.font), dpi=int(args.dpi)
    )

    # Boolean-like CLI options, normalised via utils.str_to_bool.
    show_legend = utils.str_to_bool(args.show_legend, default=True)
    show_labels = utils.str_to_bool(args.show_labels, default=True)
    show_ticks = utils.str_to_bool(args.show_ticklabels, default=True)
    show_title = utils.str_to_bool(args.show_title, default=True)
    show_pan_t = utils.str_to_bool(args.show_panel_titles, default=True)
    annotate = not utils.str_to_bool(args.mute_values, default=False)
    show_cb = not utils.str_to_bool(args.no_colorbar, default=False)
    no_bars = utils.str_to_bool(args.no_bars, default=False)
    align_grid = utils.str_to_bool(args.align_grid, default=True)
    contour_lines = utils.str_to_bool(args.contour_lines, default=False)
    mark_best = not utils.str_to_bool(args.no_best_marker, default=False)

    df = _load_records(args)
    if df.empty:
        root = utils.as_path(args.root)
        raise SystemExit(
            "No ablation_record*.jsonl found under:\n"
            f" {root.resolve()}\n"
            "Run ablations with the logger enabled first."
        )

    df = _canon_cols(df)
    df = _harmonize_units(df)
    df = _dedupe_prefer_best(df)
    df = _filter_models(df, args.models)

    cities = _resolve_cities(args)
    if not cities:
        raise SystemExit("No cities selected.")
    if len(cities) == 1:
        # Duplicate a lone city so the two-column layouts still work.
        cities = [cities[0], cities[0]]
    city_a, city_b = cities[0], cities[1]

    out = _resolve_out(out=args.out, out_dir=args.out_dir)
    utils.ensure_dir(out.parent)
    tidy_csv = out.parent / "tableS6_ablations_used.csv"
    df.to_csv(tidy_csv, index=False)
    print(f"[OK] table -> {tidy_csv}")

    map_metrics = _resolve_map_metrics(args)
    mark_lc = args.mark_lambda_cons
    mark_lp = args.mark_lambda_prior

    # -----------------------------------------------------------------
    # Pareto trade-off view (saves and returns early)
    # -----------------------------------------------------------------
    if utils.str_to_bool(args.pareto, default=False):
        c_metric = str(args.pareto_color)
        # Options shared by both city panels. Panel labels/titles follow
        # the same CLI flags as the map layouts (previously hard-coded
        # to True, which ignored --show-labels / --show-panel-titles).
        panel_kw: dict[str, Any] = {
            "x_metric": str(args.pareto_x),
            "y_metric": str(args.pareto_y),
            "c_metric": c_metric,
            "cmap": str(args.cmap),
            "s": float(args.pareto_s),
            "alpha": float(args.pareto_alpha),
            "show_labels": show_labels,
            "show_title": show_pan_t,
            "mark_lc": mark_lc,
            "mark_lp": mark_lp,
            "show_front": utils.str_to_bool(
                args.pareto_front, default=True
            ),
            "density": utils.str_to_bool(
                getattr(args, "pareto_density", "false"),
                default=False,
            ),
            "density_gridsize": int(
                getattr(args, "pareto_density_gridsize", 35)
            ),
            "density_bins": str(
                getattr(args, "pareto_density_bins", "log")
            ),
            "density_cmap": str(
                getattr(args, "pareto_density_cmap", "Greys")
            ),
            "density_alpha": float(
                getattr(args, "pareto_density_alpha", 0.25)
            ),
            "density_mincnt": int(
                getattr(args, "pareto_density_mincnt", 1)
            ),
        }

        # Shared color scale across cities for comparability
        cmin, cmax = _compute_clim(
            df, cities=[city_a, city_b], metric=c_metric
        )
        c_norm = None
        if cmin is not None and cmax is not None:
            c_norm = plt.Normalize(vmin=cmin, vmax=cmax)
        panel_kw["c_norm"] = c_norm

        fig, axes = plt.subplots(
            nrows=1,
            ncols=2,
            figsize=(8.6, 3.8),
            dpi=int(args.dpi),
        )
        ax1, ax2 = axes[0], axes[1]
        sc1 = _pareto_panel(ax1, df, city=city_a, **panel_kw)
        sc2 = _pareto_panel(ax2, df, city=city_b, **panel_kw)

        if show_cb:
            sc = sc1 if sc1 is not None else sc2
            if sc is not None:
                cbar = fig.colorbar(
                    sc,
                    ax=[ax1, ax2],
                    orientation="vertical",
                    fraction=0.04,
                    pad=0.04,
                )
                cbar.set_label(_metric_label(c_metric))

        if show_title:
            default = "Supplement S6 • Pareto trade-offs"
            ttl = utils.resolve_title(default=default, title=args.title)
            fig.suptitle(ttl, fontsize=11, fontweight="bold")

        utils.save_figure(fig, out)
        return

    # Auto-disable bars if multiple map metrics
    if len(map_metrics) > 1:
        no_bars = True

    # Build a shared lambda grid if requested
    lc_grid: list[float] | None = None
    lp_grid: list[float] | None = None
    if align_grid:
        lc_grid, lp_grid = _lambda_grid(df, cities=[city_a, city_b])

    # Map options identical for every _heatmap_one call.
    map_kw: dict[str, Any] = {
        "cmap": args.cmap,
        "map_kind": args.map_kind,
        "levels": int(args.levels),
        "contour_lines": contour_lines,
        "lc_grid": lc_grid,
        "lp_grid": lp_grid,
        "show_ticks": show_ticks,
        "mark_best": mark_best,
        "mark_lc": mark_lc,
        "mark_lp": mark_lp,
    }

    # -----------------------------------------------------------------
    # Layout
    # -----------------------------------------------------------------
    if not no_bars and len(map_metrics) == 1:
        # Legacy 2x2: bars (row1) + maps (row2)
        fig = plt.figure(figsize=(8.6, 7.0))
        gs = fig.add_gridspec(
            nrows=2,
            ncols=2,
            left=0.07,
            # The manual colorbar axes (x=0.92, below) is added whenever
            # show_cb is set, so reserve the margin on show_cb alone;
            # keying on show_legend too made it overlap the maps.
            right=0.90 if show_cb else 0.98,
            top=0.94,
            bottom=0.10,
            hspace=0.32,
            wspace=0.30,
        )

        bar_kw: dict[str, Any] = {
            "metric": args.bar_metric,
            "annotate": annotate,
            "show_legend": show_legend,
            "show_labels": show_labels,
            "show_ticks": show_ticks,
            "show_title": show_pan_t,
        }
        ax11 = fig.add_subplot(gs[0, 0])
        _bars_by_lambda(
            ax11,
            df,
            city=city_a,
            color=cfg.CITY_COLORS.get(city_a, "#1F78B4"),
            **bar_kw,
        )
        ax12 = fig.add_subplot(gs[0, 1])
        _bars_by_lambda(
            ax12,
            df,
            city=city_b,
            color=cfg.CITY_COLORS.get(city_b, "#E31A1C"),
            **bar_kw,
        )

        metric = map_metrics[0]
        vmin, vmax = _compute_clim(
            df, cities=[city_a, city_b], metric=metric
        )
        ax21 = fig.add_subplot(gs[1, 0])
        im_a = _heatmap_one(
            ax21,
            df,
            city=city_a,
            metric=metric,
            vmin=vmin,
            vmax=vmax,
            show_labels=show_labels,
            show_title=show_pan_t,
            **map_kw,
        )
        ax22 = fig.add_subplot(gs[1, 1])
        im_b = _heatmap_one(
            ax22,
            df,
            city=city_b,
            metric=metric,
            vmin=vmin,
            vmax=vmax,
            show_labels=show_labels,
            show_title=show_pan_t,
            **map_kw,
        )

        if show_cb:
            cax = fig.add_axes([0.92, 0.12, 0.015, 0.30])
            im_for = im_a if im_a is not None else im_b
            if im_for is not None:
                fig.colorbar(
                    im_for,
                    cax=cax,
                    orientation="vertical",
                    label=_metric_label(metric),
                )
            else:
                cax.set_axis_off()
    else:
        # Heatmap-only multi-metric: 2 rows (cities) x M cols (metrics)
        n_cols = max(1, len(map_metrics))
        fig = plt.figure(figsize=(3.2 * n_cols, 5.8))
        gs = fig.add_gridspec(
            nrows=2,
            ncols=n_cols,
            left=0.07,
            right=0.92 if show_cb else 0.98,
            top=0.93,
            bottom=0.11,
            hspace=0.20,
            wspace=0.25,
        )

        # Per column: (artist for the colorbar, top axes, bottom axes)
        columns: list[tuple[Any, Any, Any]] = []
        for j, metric in enumerate(map_metrics):
            # Shared colour range for both cities in this column.
            vmin, vmax = _compute_clim(
                df,
                cities=[city_a, city_b],
                metric=metric,
            )

            # City A (row 0)
            ax_top = fig.add_subplot(gs[0, j])
            art_top = _heatmap_one(
                ax_top,
                df,
                city=city_a,
                metric=metric,
                vmin=vmin,
                vmax=vmax,
                show_labels=False,
                show_title=False,
                **map_kw,
            )

            # City B (row 1); axis labels only on the first column
            ax_bot = fig.add_subplot(gs[1, j])
            art_bot = _heatmap_one(
                ax_bot,
                df,
                city=city_b,
                metric=metric,
                vmin=vmin,
                vmax=vmax,
                show_labels=(show_labels and j == 0),
                show_title=False,
                **map_kw,
            )

            if show_pan_t:
                # Column titles: metric names
                ax_top.set_title(
                    _metric_label(metric),
                    fontweight="bold",
                    pad=6,
                )
                # Row labels: cities on the left-most column
                if j == 0:
                    for ax_row, name in ((ax_top, city_a), (ax_bot, city_b)):
                        ax_row.text(
                            -0.20,
                            0.5,
                            name,
                            transform=ax_row.transAxes,
                            rotation=90,
                            va="center",
                            ha="center",
                            fontweight="bold",
                        )

            art = art_top if art_top is not None else art_bot
            columns.append((art, ax_top, ax_bot))

        # Colorbars: one per metric column (shared across the two cities)
        if show_cb:
            for metric, (art, ax_top, ax_bot) in zip(map_metrics, columns):
                if art is None:
                    continue
                cbar = fig.colorbar(
                    art,
                    ax=[ax_top, ax_bot],
                    orientation="horizontal",
                    fraction=0.06,
                    pad=0.10,
                )
                cbar.set_label(_metric_label(metric))

    # -----------------------------------------------------------------
    # Title + save
    # -----------------------------------------------------------------
    if show_title:
        default = "Supplement S6 • Extended ablations & sensitivities"
        ttl = utils.resolve_title(default=default, title=args.title)
        fig.suptitle(ttl, fontsize=11, fontweight="bold")

    utils.save_figure(fig, out)
def main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Console entry point; forwards to ``plot_ablations_sensivity_main``."""
    plot_ablations_sensivity_main(argv=argv, prog=prog)
if __name__ == "__main__":
    main()