Source code for geoprior.scripts.update_ablation_record_posthoc

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Update ablation record  using post-hoc calibrated metrics.

What this does
--------------
For each JSONL row (one run), this script:

- finds the run folder via the timestamped
  ``geoprior_eval_phys_<TS>_interpretable.json`` file
- replaces top-level metrics with post-hoc calibrated values:
  ``mae``, ``mse``, ``rmse``, ``r2``, ``coverage80``,
  ``sharpness80``, and per-horizon metrics
- writes units and detailed blocks into ``record["metrics"]``
- stores the old pre-patch values under
  ``record["legacy"]["pre_posthoc_patch"]``

Typical layout
--------------
.. code-block:: text

   results_root/
     run_folder_A/
       geoprior_eval_phys_<TS>_interpretable.json
       <city>_GeoPriorSubsNet_eval_diagnostics_TestSet_H3_calibrated.json
       ablations_records/ablation_record.jsonl
     run_folder_B/
       ...

The ``ablations_records/`` folder lives under each run folder.
"""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Any

_INTERP_PREFIX = "geoprior_eval_phys_"
_INTERP_SUFFIX = "_interpretable.json"


def _read_json(path: Path) -> dict[str, Any]:
    return json.loads(path.read_text(encoding="utf-8"))


def _safe_float(x: Any) -> float | None:
    if x is None:
        return None
    try:
        return float(x)
    except Exception:
        return None


def _pick_first(*vals: Any) -> Any:
    for v in vals:
        if v is not None:
            return v
    return None


def _build_ts_index(runs_root: Path) -> dict[str, Path]:
    """
    Map timestamp -> run_dir by scanning interpretable phys JSONs.
    """
    idx: dict[str, Path] = {}
    pat = f"{_INTERP_PREFIX}*{_INTERP_SUFFIX}"
    for p in runs_root.rglob(pat):
        name = p.name
        if not (
            name.startswith(_INTERP_PREFIX)
            and name.endswith(_INTERP_SUFFIX)
        ):
            continue
        ts = name[len(_INTERP_PREFIX) : -len(_INTERP_SUFFIX)]
        idx[ts] = p.parent
    return idx


def _find_diag_calibrated(
    run_dir: Path,
    city: str | None,
) -> Path | None:
    """
    Prefer <city>_*eval_diagnostics*_calibrated.json.
    Fallback to any *eval_diagnostics*_calibrated.json.
    """
    if city:
        cands = list(
            run_dir.rglob(
                f"{city}_*eval_diagnostics*_calibrated.json"
            )
        )
        if cands:
            return cands[0]

    cands = list(
        run_dir.rglob("*eval_diagnostics*_calibrated.json")
    )
    if cands:
        return cands[0]
    return None


def _hkeys_to_Hn(d: dict[str, Any]) -> dict[str, Any]:
    """
    Convert {"1": v, "2": v} -> {"H1": v, "H2": v}.
    Leaves keys that already look like "H1" unchanged.
    """
    out: dict[str, Any] = {}
    for k, v in (d or {}).items():
        ks = str(k)
        if ks.startswith("H"):
            out[ks] = v
        else:
            out[f"H{ks}"] = v
    return out


def _update_one_record(
    rec: dict[str, Any],
    run_dir: Path,
) -> tuple[dict[str, Any], dict[str, Any]]:
    ts = str(rec.get("timestamp") or "")
    city = rec.get("city")

    interp_path = (
        run_dir / f"{_INTERP_PREFIX}{ts}{_INTERP_SUFFIX}"
    )
    if not interp_path.exists():
        return rec, {
            "ok": False,
            "reason": "missing_interpretable_json",
            "expected": str(interp_path),
        }

    interp = _read_json(interp_path)

    diag_cal_path = _find_diag_calibrated(run_dir, city)
    diag_cal = (
        _read_json(diag_cal_path) if diag_cal_path else {}
    )
    overall = (diag_cal or {}).get("__overall__", {})

    # --- Preserve old (wrong) headline metrics ---
    legacy: dict[str, Any] = {}
    for k in [
        "r2",
        "mse",
        "mae",
        "rmse",
        "coverage80",
        "sharpness80",
    ]:
        if k in rec:
            legacy[k] = rec.get(k)
    rec.setdefault("legacy", {})
    rec["legacy"]["pre_posthoc_patch"] = legacy

    # --- Units (interpretable JSON is authoritative) ---
    units = interp.get("units", {}) or {}
    rec["units"] = units

    # --- Post-hoc point metrics (prefer diagnostics overall, then interp) ---
    p_point = interp.get("point_metrics", {}) or {}
    r2 = _pick_first(
        overall.get("overall_r2"), p_point.get("r2")
    )
    mse = _pick_first(
        overall.get("overall_mse"), p_point.get("mse")
    )
    mae = _pick_first(
        overall.get("overall_mae"), p_point.get("mae")
    )
    rmse = _pick_first(
        overall.get("overall_rmse"), p_point.get("rmse")
    )
    if rmse is None and mse is not None:
        try:
            rmse = math.sqrt(float(mse))
        except Exception:
            rmse = None

    # --- Post-hoc calibrated interval metrics ---
    ival = interp.get("interval_calibration", {}) or {}
    cov_cal = _pick_first(
        overall.get("coverage80"),
        ival.get("coverage80_calibrated_phys"),
        ival.get("coverage80_calibrated"),
    )
    shp_cal = _pick_first(
        overall.get("sharpness80"),
        ival.get("sharpness80_calibrated_phys"),
        ival.get("sharpness80_calibrated"),
    )

    # --- Uncalibrated (for transparency/debug) ---
    cov_uncal = _pick_first(
        ival.get("coverage80_uncalibrated_phys"),
        ival.get("coverage80_uncalibrated"),
    )
    shp_uncal = _pick_first(
        ival.get("sharpness80_uncalibrated_phys"),
        ival.get("sharpness80_uncalibrated"),
    )

    # --- Per-horizon (prefer diagnostics overall, then interp) ---
    ph_mae = overall.get("per_horizon_mae") or {}
    ph_r2 = overall.get("per_horizon_r2") or {}
    if ph_mae:
        rec["per_horizon_mae"] = _hkeys_to_Hn(ph_mae)
    else:
        rec["per_horizon_mae"] = (
            interp.get("per_horizon", {}) or {}
        ).get("mae")

    if ph_r2:
        rec["per_horizon_r2"] = _hkeys_to_Hn(ph_r2)
    else:
        rec["per_horizon_r2"] = (
            interp.get("per_horizon", {}) or {}
        ).get("r2")

    # --- Physics diagnostics (interpretable JSON) ---
    phys = interp.get("physics_diagnostics", {}) or {}
    for k in ["epsilon_prior", "epsilon_cons", "epsilon_gw"]:
        if k in phys:
            rec[k] = phys.get(k)

    # --- Write updated headline metrics (paper-consistent) ---
    rec["r2"] = _safe_float(r2)
    rec["mse"] = _safe_float(mse)
    rec["mae"] = _safe_float(mae)
    rec["rmse"] = _safe_float(rmse)
    rec["coverage80"] = _safe_float(cov_cal)
    rec["sharpness80"] = _safe_float(shp_cal)

    # --- Attach full blocks for traceability ---
    rec["metrics"] = {
        "posthoc": {
            "point_metrics": {
                "mae": _safe_float(mae),
                "mse": _safe_float(mse),
                "rmse": _safe_float(rmse),
                "r2": _safe_float(r2),
            },
            "interval": {
                "target": ival.get("target"),
                "coverage80_calibrated": _safe_float(cov_cal),
                "sharpness80_calibrated": _safe_float(
                    shp_cal
                ),
                "coverage80_uncalibrated": _safe_float(
                    cov_uncal
                ),
                "sharpness80_uncalibrated": _safe_float(
                    shp_uncal
                ),
                "factors_per_horizon": ival.get(
                    "factors_per_horizon"
                ),
            },
            "per_horizon": interp.get("per_horizon", {}),
        },
        "raw": {
            "metrics_evaluate": interp.get(
                "metrics_evaluate", {}
            ),
        },
        "sources": {
            "interpretable": str(interp_path),
            "eval_diag_cal": (
                str(diag_cal_path) if diag_cal_path else None
            ),
        },
    }

    return rec, {
        "ok": True,
        "ts": ts,
        "run_dir": str(run_dir),
        "unit": units.get("subs_metrics_unit"),
    }


[docs] def main() -> int: ap = argparse.ArgumentParser() ap.add_argument( "--runs-root", type=Path, required=True, help="Folder containing run folders (scan recursive).", ) ap.add_argument( "--ablation", type=Path, required=True, help="Path to ablation_record.jsonl", ) ap.add_argument( "--backup", action="store_true", help="Write a .bak copy before overwriting.", ) ap.add_argument( "--dry-run", action="store_true", help="Do not write; just report.", ) args = ap.parse_args() runs_root = args.runs_root.expanduser().resolve() ab_path = args.ablation.expanduser().resolve() if not ab_path.exists(): raise SystemExit(f"Missing: {ab_path}") idx = _build_ts_index(runs_root) if not idx: raise SystemExit( "No geoprior_eval_phys_*_interpretable.json found." ) lines = ab_path.read_text(encoding="utf-8").splitlines() recs: list[dict[str, Any]] = [] infos: list[dict[str, Any]] = [] for ln in lines: if not ln.strip(): continue rec = json.loads(ln) ts = str(rec.get("timestamp") or "") run_dir = idx.get(ts) if run_dir is None: infos.append( { "ok": False, "ts": ts, "reason": "ts_not_found", } ) recs.append(rec) continue rec_u, info = _update_one_record(rec, run_dir) recs.append(rec_u) infos.append(info) ok_n = sum(1 for i in infos if i.get("ok")) bad_n = len(infos) - ok_n print(f"Updated: {ok_n} | Skipped: {bad_n}") for i in infos: if not i.get("ok"): print(f" - {i}") if args.dry_run: return 0 if args.backup: bak = ab_path.with_suffix(ab_path.suffix + ".bak") bak.write_text( "\n".join(lines) + "\n", encoding="utf-8" ) print(f"Backup: {bak}") out_lines = [ json.dumps(r, ensure_ascii=False) for r in recs ] ab_path.write_text( "\n".join(out_lines) + "\n", encoding="utf-8" ) print(f"Wrote: {ab_path}") return 0
if __name__ == "__main__": raise SystemExit(main())