# Source code for geoprior.scripts.build_model_metrics

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
#
"""Build a unified "model metrics" table from GeoPrior runs.

This scans a results root (or a single run folder) and collects
metrics from ablation records, preferring *updated* JSONL records.

Outputs (under scripts/out/ unless overridden)
----------------------------------------------
- ``<out>.csv``: wide table with one row per run
- ``<out>.json``: same content as the CSV, serialized as records
- ``<out>_long.csv``: optional long table with one row per horizon per run
- ``<out>_long.json``: optional JSON version of the long table

Data sources
------------
We scan for JSONL under
``<src>/**/ablation_records/ablation_record*.jsonl``.

Preference order (robust even if cfg differs)
---------------------------------------------
1. ``ablation_record.updated*.jsonl``
2. ``ablation_record*.jsonl``

Notes
-----
- Legacy (SI) ablation records are auto-converted to mm when
  values look like meters (heuristic).
- Per-horizon blocks are expanded to columns such as
  ``r2_H1``, ``r2_H2``, ``mae_H1``, and ``mae_H2``.
- Interval calibration fields are exported when present, including
  ``coverage80_cal_phys`` and ``sharpness80_cal_phys``.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from . import config as cfg
from . import utils

# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------


def _parse_args(
    argv: list[str] | None,
    *,
    prog: str | None = None,
) -> argparse.Namespace:
    """Parse the command-line options of the metrics builder.

    Parameters
    ----------
    argv : list of str or None
        Tokens forwarded to ``ArgumentParser.parse_args``.
    prog : str, optional
        Program name shown in the help text; when falsy,
        ``"build-model-metrics"`` is used.

    Returns
    -------
    argparse.Namespace
        Parsed options. Boolean-like switches (``include_long``,
        ``dedupe``) are kept as raw strings here.
    """
    parser = argparse.ArgumentParser(
        prog=prog or "build-model-metrics",
        description="Build a unified model metrics table.",
    )

    def _text_option(
        *flags: str,
        default: str | None,
        help_text: str,
        **extra: Any,
    ) -> None:
        # Every option of this CLI is a plain string option.
        parser.add_argument(
            *flags,
            type=str,
            default=default,
            help=help_text,
            **extra,
        )

    # --- input / output locations ---------------------------------
    _text_option(
        "--src",
        "--results-root",
        "--run-dir",
        dest="src",
        default="results",
        help_text=(
            "Results root (scan recursive) OR a single run dir."
        ),
    )
    _text_option(
        "--out",
        default="model_metrics",
        help_text=(
            "Output stem/path (scripts/out/ if relative). "
            "Suffix is ignored."
        ),
    )
    _text_option(
        "--out-dir",
        default=None,
        help_text="Optional output dir override.",
    )

    # --- filters --------------------------------------------------
    # Shared city switches are registered by the project helper.
    utils.add_city_flags(parser, default_both=False)
    _text_option(
        "--city",
        default="",
        help_text=(
            "Extra city filter (comma list). "
            "Example: --city Nansha,Zhongshan"
        ),
    )
    _text_option(
        "--models",
        default="",
        help_text=(
            "Comma list of model names to keep (exact match)."
        ),
    )

    # --- behaviour switches (parsed later as true/false) ----------
    _text_option(
        "--include-long",
        default="true",
        help_text="Write a per-horizon long table (true/false).",
    )
    _text_option(
        "--dedupe",
        default="true",
        help_text=(
            "Deduplicate by (timestamp, city, model) "
            "preferring updated (true/false)."
        ),
    )

    return parser.parse_args(argv)


# ---------------------------------------------------------------------
# IO
# ---------------------------------------------------------------------


def _read_jsonl(fp: Path) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    with fp.open("r", encoding="utf-8") as f:
        for ln in f:
            s = ln.strip()
            if not s:
                continue
            try:
                rec = json.loads(s)
            except Exception:
                continue
            if isinstance(rec, dict):
                rows.append(rec)
    return rows


def _preferred_ablation_patterns() -> tuple[str, ...]:
    """Return the glob patterns for ablation records, updated first.

    The hard-coded patterns always come first, "updated" variants ahead
    of the generic ones. Patterns configured under
    ``cfg.PATTERNS["ablation_record_jsonl"]`` follow. Empty entries and
    repeats are dropped, keeping the first occurrence.
    """
    candidates = [
        "ablation_records/ablation_record.updated*.jsonl",
        "ablation_records/ablation_record*.jsonl",
        "ablation_record.updated*.jsonl",
        "ablation_record*.jsonl",
        *cfg.PATTERNS.get("ablation_record_jsonl", ()),
    ]
    # dict.fromkeys() de-duplicates while preserving first-seen order.
    return tuple(
        dict.fromkeys(pat for pat in candidates if pat)
    )


def _scan_ablation_jsonl(src: Path) -> list[dict[str, Any]]:
    """Collect every ablation record found below *src*.

    Each returned record is a shallow copy of the parsed JSON object
    with two provenance fields:

    - ``_src``: the JSONL file the record came from (always set);
    - ``_run_dir``: the grandparent folder of that file, only filled in
      when the record does not already carry the key.

    Files that cannot be read are skipped.
    """
    collected: list[dict[str, Any]] = []
    matches = utils.find_all(
        src,
        _preferred_ablation_patterns(),
        must_exist=False,
    )

    for path in matches:
        try:
            parsed = _read_jsonl(path)
        except Exception:
            continue

        # Layout: <run_dir>/ablation_records/ablation_record*.jsonl
        origin = str(path)
        run_dir = str(path.parent.parent)
        for record in parsed:
            tagged = {**record, "_src": origin}
            tagged.setdefault("_run_dir", run_dir)
            collected.append(tagged)

    return collected


# ---------------------------------------------------------------------
# Canon / scaling
# ---------------------------------------------------------------------


def _hkeys_to_Hn(d: Any) -> dict[str, Any]:
    if not isinstance(d, dict):
        return {}
    out: dict[str, Any] = {}
    for k, v in d.items():
        ks = str(k)
        if ks.startswith("H"):
            out[ks] = v
        else:
            out[f"H{ks}"] = v
    return out


def _get_nested(d: Any, path: tuple[str, ...]) -> Any:
    cur: Any = d
    for k in path:
        if not isinstance(cur, dict):
            return None
        cur = cur.get(k, None)
    return cur


def _to_float(x: Any) -> float:
    try:
        v = float(x)
    except Exception:
        return float("nan")
    return v


def _needs_si_to_mm(rec: dict[str, Any]) -> bool:
    """Heuristic for legacy SI ablation records."""
    u = rec.get("units", None)
    if isinstance(u, dict):
        uu = str(u.get("subs_metrics_unit", "")).lower()
        if uu == "mm":
            return False
        if uu in {"m", "meter", "metre"}:
            return True

    mae = _to_float(rec.get("mae", np.nan))
    mse = _to_float(rec.get("mse", np.nan))
    shp = _to_float(rec.get("sharpness80", np.nan))

    ok = (
        np.isfinite(mae)
        and np.isfinite(mse)
        and np.isfinite(shp)
        and 0.0 < mae < 1.0
        and 0.0 < mse < 1.0
        and 0.0 < shp < 1.0
    )
    return bool(ok)


def _scale_legacy_to_mm(
    rec: dict[str, Any],
) -> dict[str, Any]:
    if not _needs_si_to_mm(rec):
        return rec

    out = dict(rec)

    for k in ["mae", "rmse", "sharpness80"]:
        v = out.get(k, None)
        if isinstance(v, int | float) and np.isfinite(v):
            out[k] = float(v) * 1000.0

    v = out.get("mse", None)
    if isinstance(v, int | float) and np.isfinite(v):
        out["mse"] = float(v) * 1e6

    u = dict(out.get("units") or {})
    u.setdefault("subs_metrics_unit", "mm")
    u.setdefault("subs_factor_si_to_real", 1000.0)
    out["units"] = u
    return out


def _record_score(rec: dict[str, Any]) -> int:
    """Rank a record so the richest duplicate can be kept.

    Points awarded:

    - 100 when the ``_src`` path contains ``"updated"``;
    - 50 when ``units["subs_metrics_unit"]`` is ``"mm"``;
    - 10 when a ``metrics`` dict is present;
    - 1 for each finite value among ``rmse``, ``pss``,
      ``epsilon_prior`` and ``epsilon_cons``.
    """
    score = 0

    if "updated" in str(rec.get("_src", "")).lower():
        score += 100

    units = rec.get("units", None)
    if (
        isinstance(units, dict)
        and str(units.get("subs_metrics_unit", "")).lower() == "mm"
    ):
        score += 50

    if isinstance(rec.get("metrics", None), dict):
        score += 10

    score += sum(
        1
        for key in ("rmse", "pss", "epsilon_prior", "epsilon_cons")
        if np.isfinite(_to_float(rec.get(key, np.nan)))
    )
    return int(score)


def _dedupe_records(
    rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Keep one record per ``(timestamp, city, model)`` key.

    Every row is first passed through ``_scale_legacy_to_mm``. Among
    rows sharing a key, the one with the highest ``_record_score``
    wins; on a tie the earliest row is kept. The winning record gets
    its ``city`` replaced by the canonical city name.
    """
    if not rows:
        return rows

    winners: dict[tuple[str, str, str], dict[str, Any]] = {}
    for raw in rows:
        candidate = _scale_legacy_to_mm(raw)
        stamp = str(candidate.get("timestamp") or "")
        city = utils.canonical_city(
            str(candidate.get("city") or "")
        )
        ident = (stamp, city, str(candidate.get("model") or ""))

        incumbent = winners.get(ident)
        if incumbent is not None and _record_score(
            candidate
        ) <= _record_score(incumbent):
            # Existing record is at least as good: keep it.
            continue

        candidate["city"] = city
        winners[ident] = candidate

    return list(winners.values())


# ---------------------------------------------------------------------
# Flatten to table
# ---------------------------------------------------------------------


def _extract_interval_cols(
    rec: dict[str, Any],
) -> dict[str, Any]:
    """Flatten the interval-calibration block of a record into columns.

    The block is searched, in order, at
    ``metrics.posthoc.interval_calibration``,
    ``metrics.posthoc.interval`` and top-level
    ``interval_calibration``. Without such a dict the result is empty.

    When a block is found, the summary columns are always emitted
    (``None`` where the source key is absent). ``factors_per_horizon``
    is added as a JSON string when present. Per-horizon coverage and
    sharpness columns are added from
    ``factors_per_horizon_from_cal_stats`` when it is a dict.
    """
    candidates = (
        _get_nested(
            rec,
            ("metrics", "posthoc", "interval_calibration"),
        ),
        _get_nested(rec, ("metrics", "posthoc", "interval")),
        rec.get("interval_calibration", None),
    )
    block = next(
        (c for c in candidates if isinstance(c, dict)),
        None,
    )
    if block is None:
        return {}

    # (output column, key in the calibration block)
    summary = (
        ("interval_target", "target"),
        # scaled
        ("coverage80_uncal", "coverage80_uncalibrated"),
        ("coverage80_cal", "coverage80_calibrated"),
        ("sharpness80_uncal", "sharpness80_uncalibrated"),
        ("sharpness80_cal", "sharpness80_calibrated"),
        # physical
        ("coverage80_uncal_phys", "coverage80_uncalibrated_phys"),
        ("coverage80_cal_phys", "coverage80_calibrated_phys"),
        ("sharpness80_uncal_phys", "sharpness80_uncalibrated_phys"),
        ("sharpness80_cal_phys", "sharpness80_calibrated_phys"),
    )
    cols: dict[str, Any] = {
        column: block.get(source) for column, source in summary
    }

    factors = block.get("factors_per_horizon")
    if factors is not None:
        try:
            cols["factors_per_horizon"] = json.dumps(factors)
        except Exception:
            # Not JSON-serializable: fall back to its repr-like text.
            cols["factors_per_horizon"] = str(factors)

    stats = block.get("factors_per_horizon_from_cal_stats")
    if isinstance(stats, dict):
        before = (stats.get("eval_before", {}) or {}).get(
            "per_horizon", {}
        ) or {}
        after = (stats.get("eval_after", {}) or {}).get(
            "per_horizon", {}
        ) or {}

        for tag, per_horizon in (("uncal", before), ("cal", after)):
            for hk, hv in per_horizon.items():
                if not isinstance(hv, dict):
                    continue
                label = (
                    str(hk)
                    if str(hk).startswith("H")
                    else f"H{hk}"
                )
                cols[f"coverage80_{tag}_{label}"] = hv.get(
                    "coverage"
                )
                cols[f"sharpness80_{tag}_{label}"] = hv.get(
                    "sharpness"
                )

    return cols


def _extract_per_horizon_cols(
    rec: dict[str, Any],
) -> dict[str, Any]:
    """Expand per-horizon R2 and MAE blocks into flat columns.

    For each metric the flat key (``per_horizon_r2`` /
    ``per_horizon_mae``) is tried first, then the nested
    ``per_horizon.<metric>`` form. Columns are named ``r2_H<n>`` and
    ``mae_H<n>``; all R2 columns precede the MAE columns.
    """
    cols: dict[str, Any] = {}

    for metric in ("r2", "mae"):
        block = rec.get(f"per_horizon_{metric}", None)
        if not isinstance(block, dict):
            block = _get_nested(rec, ("per_horizon", metric))
        for horizon, value in _hkeys_to_Hn(block).items():
            cols[f"{metric}_{horizon}"] = value

    return cols


def _record_to_row(rec0: dict[str, Any]) -> dict[str, Any]:
    """Flatten one ablation record into a single wide-table row.

    The record is passed through ``_scale_legacy_to_mm`` first. The row
    then gathers, in this order: identification/config fields,
    headline metrics (with ``rmse`` derived from ``mse`` when missing),
    unit labels, provenance fields, interval-calibration columns and
    per-horizon columns.
    """
    rec = _scale_legacy_to_mm(rec0)

    id_keys = (
        "timestamp",
        "city",
        "model",
        "pde_mode",
        "use_effective_h",
        "kappa_mode",
        "hd_factor",
        "lambda_cons",
        "lambda_gw",
        "lambda_prior",
        "lambda_smooth",
        "lambda_mv",
    )
    metric_keys = (
        "r2",
        "mae",
        "mse",
        "rmse",
        "pss",
        "coverage80",
        "sharpness80",
        "epsilon_prior",
        "epsilon_cons",
        "epsilon_gw",
        "epsilon_cons_raw",
        "epsilon_gw_raw",
    )

    # id + config (only the keys the record actually has)
    row: dict[str, Any] = {
        key: rec.get(key) for key in id_keys if key in rec
    }
    row["city"] = utils.canonical_city(
        str(row.get("city") or "")
    )

    # headline metrics
    row.update(
        (key, rec.get(key)) for key in metric_keys if key in rec
    )

    # rmse fallback: derive it from mse when absent or null
    if row.get("rmse") is None and "mse" in row:
        try:
            row["rmse"] = float(row["mse"]) ** 0.5
        except Exception:
            pass

    # units (if present)
    units = rec.get("units", None)
    if isinstance(units, dict):
        row["subs_metrics_unit"] = units.get("subs_metrics_unit")
        row["time_units"] = units.get("time_units")

    # provenance
    row["_src"] = rec.get("_src")
    row["_run_dir"] = rec.get("_run_dir")

    # expanded blocks
    row.update(_extract_interval_cols(rec))
    row.update(_extract_per_horizon_cols(rec))

    return row


def _rows_to_df(rows: list[dict[str, Any]]) -> pd.DataFrame:
    if not rows:
        return pd.DataFrame([])
    df = pd.DataFrame(rows)

    # numeric coercion for stable CSV export
    num_cols = [
        "hd_factor",
        "lambda_cons",
        "lambda_gw",
        "lambda_prior",
        "lambda_smooth",
        "lambda_mv",
        "r2",
        "mae",
        "mse",
        "rmse",
        "pss",
        "coverage80",
        "sharpness80",
        "epsilon_prior",
        "epsilon_cons",
        "epsilon_gw",
        "epsilon_cons_raw",
        "epsilon_gw_raw",
        "coverage80_uncal",
        "coverage80_cal",
        "sharpness80_uncal",
        "sharpness80_cal",
        "coverage80_uncal_phys",
        "coverage80_cal_phys",
        "sharpness80_uncal_phys",
        "sharpness80_cal_phys",
    ]
    for c in num_cols:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # expanded per-horizon cols
    for c in list(df.columns):
        if c.startswith(("r2_H", "mae_H")):
            df[c] = pd.to_numeric(df[c], errors="coerce")
        if c.startswith(("coverage80_", "sharpness80_")):
            df[c] = pd.to_numeric(df[c], errors="coerce")

    return df


def _to_long(df: pd.DataFrame) -> pd.DataFrame:
    if df.empty:
        return df

    base_cols = [
        "timestamp",
        "city",
        "model",
        "pde_mode",
        "lambda_cons",
        "lambda_gw",
        "lambda_prior",
        "lambda_smooth",
        "lambda_mv",
        "_src",
        "_run_dir",
    ]
    base_cols = [c for c in base_cols if c in df.columns]

    # discover horizons
    hs: list[str] = []
    for c in df.columns:
        if c.startswith("r2_H"):
            hs.append(c.replace("r2_", ""))
        if c.startswith("mae_H"):
            hs.append(c.replace("mae_", ""))
    hs = sorted(set(hs))

    if not hs:
        return pd.DataFrame([])

    rows: list[dict[str, Any]] = []
    for _, r in df.iterrows():
        base = {k: r.get(k) for k in base_cols}
        for H in hs:
            rec: dict[str, Any] = dict(base)
            rec["horizon"] = H
            rec["r2"] = r.get(f"r2_{H}", np.nan)
            rec["mae"] = r.get(f"mae_{H}", np.nan)

            rec["coverage80_uncal"] = r.get(
                f"coverage80_uncal_{H}", np.nan
            )
            rec["coverage80_cal"] = r.get(
                f"coverage80_cal_{H}", np.nan
            )
            rec["sharpness80_uncal"] = r.get(
                f"sharpness80_uncal_{H}", np.nan
            )
            rec["sharpness80_cal"] = r.get(
                f"sharpness80_cal_{H}", np.nan
            )
            rows.append(rec)

    out = pd.DataFrame(rows)
    for c in [
        "r2",
        "mae",
        "coverage80_uncal",
        "coverage80_cal",
    ]:
        if c in out.columns:
            out[c] = pd.to_numeric(out[c], errors="coerce")
    for c in ["sharpness80_uncal", "sharpness80_cal"]:
        if c in out.columns:
            out[c] = pd.to_numeric(out[c], errors="coerce")

    return out


# ---------------------------------------------------------------------
# Filters + output
# ---------------------------------------------------------------------


def _filter_rows(
    rows: list[dict[str, Any]],
    args: argparse.Namespace,
) -> list[dict[str, Any]]:
    """Keep only the rows matching the city and model filters.

    Cities come from ``utils.resolve_cities(args)`` plus the
    comma-separated ``args.city`` list (canonicalized). Models come
    from the comma-separated ``args.models`` list and are matched
    exactly. An empty filter keeps everything.
    """
    if not rows:
        return rows

    cities = utils.resolve_cities(args)
    for token in str(args.city or "").strip().split(","):
        token = token.strip()
        if token:
            cities.append(utils.canonical_city(token))
    allowed_cities = {c for c in cities if c}

    allowed_models = {
        name.strip()
        for name in str(args.models or "").strip().split(",")
        if name.strip()
    }

    def _wanted(row: dict[str, Any]) -> bool:
        city = utils.canonical_city(str(row.get("city") or ""))
        model = str(row.get("model") or "")
        city_ok = not allowed_cities or city in allowed_cities
        model_ok = not allowed_models or model in allowed_models
        return city_ok and model_ok

    return [row for row in rows if _wanted(row)]


def _resolve_out(*, out: str, out_dir: str | None) -> Path:
    """Like resolve_fig_out(), but targets scripts/out/ by default."""
    p = Path(out).expanduser().with_suffix("")
    if out_dir:
        base = Path(out_dir).expanduser()
        return (base / p.name).resolve()
    if not p.is_absolute():
        p = cfg.OUT_DIR / p
    return p


# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------


def _records_for_json(df: pd.DataFrame) -> list[dict[str, Any]]:
    """Return *df* as records with non-finite floats replaced by None.

    ``json.dumps`` would otherwise emit bare ``NaN``/``Infinity``
    tokens, which are not valid JSON and break strict parsers.
    """
    return [
        {
            key: (
                None
                if isinstance(value, float)
                and not np.isfinite(value)
                else value
            )
            for key, value in record.items()
        }
        for record in df.to_dict("records")
    ]


def _write_table_pair(df: pd.DataFrame, stem: Path) -> None:
    """Write *df* as ``<stem>.csv`` and ``<stem>.json`` and report it."""
    # Append the extension to the name instead of using with_suffix():
    # a dotted stem such as "metrics.v2" would otherwise lose ".v2" and
    # make the "_long" files collide with the wide ones.
    csv_p = stem.with_name(stem.name + ".csv")
    json_p = stem.with_name(stem.name + ".json")

    df.to_csv(csv_p, index=False)
    json_p.write_text(
        json.dumps(
            _records_for_json(df),
            ensure_ascii=False,
            indent=2,
        )
        + "\n",
        encoding="utf-8",
    )

    print(f"[OK] wrote: {csv_p}")
    print(f"[OK] wrote: {json_p}")


def build_model_metrics_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Scan ablation records and write the model-metrics tables.

    Steps: parse the CLI, scan *src* for ablation JSONL records, apply
    the city/model filters, then deduplicate (or just unit-convert).
    The records are flattened into a wide table and written as CSV and
    JSON; when requested, the per-horizon long table is written too.

    Parameters
    ----------
    argv : list of str or None
        Command-line tokens, forwarded to the argument parser.
    prog : str, optional
        Program name used in the help text.

    Raises
    ------
    SystemExit
        When no record is left after scanning and filtering.
    """
    args = _parse_args(argv, prog=prog)
    utils.ensure_script_dirs()

    src = utils.as_path(args.src)
    rows = _scan_ablation_jsonl(src)
    rows = _filter_rows(rows, args)

    if utils.str_to_bool(args.dedupe, default=True):
        rows = _dedupe_records(rows)
    else:
        rows = [_scale_legacy_to_mm(r) for r in rows]

    if not rows:
        raise SystemExit(
            f"No ablation records found under: {src}"
        )

    df = _rows_to_df([_record_to_row(r) for r in rows])

    out_stem = _resolve_out(
        out=args.out, out_dir=args.out_dir
    )
    # A custom --out-dir or absolute --out may not exist yet.
    out_stem.parent.mkdir(parents=True, exist_ok=True)
    _write_table_pair(df, out_stem)

    if utils.str_to_bool(args.include_long, default=True):
        df_long = _to_long(df)
        if not df_long.empty:
            _write_table_pair(
                df_long,
                out_stem.with_name(out_stem.name + "_long"),
            )
def main(argv: list[str] | None = None) -> None:
    """Script entry point; delegates to ``build_model_metrics_main``.

    Parameters
    ----------
    argv : list of str or None
        Command-line tokens, passed through unchanged.
    """
    build_model_metrics_main(argv)
# Run the CLI when the module is executed as a script.
if __name__ == "__main__":
    main()