Source code for geoprior.scripts.update_ablation_records

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""Update ablation_record  using post-hoc calibrated metrics.

Why this script
---------------
Older ablation JSONL rows may store point/interval metrics in
SI units (e.g., meters) or store *uncalibrated* interval metrics.

This script patches each JSONL row by reading the authoritative
run artifacts:

- ``geoprior_eval_phys_<TS>_interpretable.json`` (preferred)
- ``geoprior_eval_phys_<TS>.json`` (fallback)
- ``*eval_diagnostics*_calibrated.json`` (fallback)

Outputs
-------
For each input ablation JSONL:

- writes a sibling file ``ablation_record.updated.jsonl`` next
  to the original and does not overwrite by default
- optionally exports a combined table as CSV and JSON under
  ``scripts/out``
"""

from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pandas as pd

from . import config as cfg
from . import utils

_PHYS_PREFIX = "geoprior_eval_phys_"
_PHYS_EXT = ".json"


[docs] @dataclass(frozen=True) class PatchInfo: ok: bool reason: str = "" ts: str = "" run_dir: str = "" phys_json: str = "" eval_diag: str = ""
def _parse_args( argv: list[str] | None, *, prog: str | None = None ) -> argparse.Namespace: p = argparse.ArgumentParser( prog=prog or "update-ablation-records", description=( "Patch ablation_record.jsonl to use post-hoc " "calibrated metrics in interpretable units." ), ) g = p.add_mutually_exclusive_group(required=True) g.add_argument( "--results-root", type=str, default=None, help=( "Scan recursively for run folders + ablation " "records under this root." ), ) g.add_argument( "--run-dir", type=str, default=None, help=( "Update only the ablation record(s) found under " "this run directory." ), ) p.add_argument( "--ablation", type=str, default=None, help=( "Optional explicit path to a JSONL ablation record. " "If provided, only this file is processed." ), ) p.add_argument( "--include-updated-inputs", type=str, default="false", help=( "Also process ablation_record.updated*.jsonl as " "inputs (true/false)." ), ) p.add_argument( "--overwrite", action="store_true", help="Overwrite existing *.updated.jsonl outputs.", ) p.add_argument( "--backup", action="store_true", help="Write a .bak copy of the input JSONL.", ) p.add_argument( "--dry-run", action="store_true", help="Do not write; just report.", ) p.add_argument( "--no-table", action="store_true", help="Skip exporting the combined CSV/JSON table.", ) p.add_argument( "--table-stem", type=str, default="ablation_records_table.updated", help=( "Output stem for the combined table under scripts/out/. " "Writes <stem>.csv and <stem>.json" ), ) p.add_argument( "--verbose", type=str, default="true", help="Print per-file status (true/false).", ) return p.parse_args(argv) def _is_updated_name(p: Path) -> bool: return "updated" in p.name.lower() def _out_updated_path(src: Path) -> Path: if src.name == "ablation_record.jsonl": return src.with_name("ablation_record.updated.jsonl") stem = src.stem if stem.endswith(".updated"): return src return src.with_name(stem + ".updated" + src.suffix) def _read_jsonl(path: Path) -> list[dict[str, Any]]: rows: list[dict[str, Any]] = [] for ln in path.read_text(encoding="utf-8").splitlines(): s = ln.strip() if not s: continue try: obj = json.loads(s) except Exception: continue if isinstance(obj, dict): rows.append(obj) return rows def _write_jsonl( path: Path, rows: list[dict[str, Any]] ) -> None: out = [json.dumps(r, ensure_ascii=False) for r in rows] path.write_text("\n".join(out) + "\n", encoding="utf-8") def _safe_float(x: Any) -> float | None: if x is None: return None try: v = float(x) except Exception: return None if not math.isfinite(v): return None return v def _parse_ts_from_phys_json(name: str) -> str | None: if not ( name.startswith(_PHYS_PREFIX) and name.endswith(_PHYS_EXT) ): return None core = name[len(_PHYS_PREFIX) : -len(_PHYS_EXT)] if core.endswith("_interpretable"): core = core[: -len("_interpretable")] return core or None def _build_ts_index(root: Path) -> dict[str, Path]: """Map <timestamp> -> run_dir by scanning phys JSONs.""" idx: dict[str, tuple[int, float, Path]] = {} for fp in root.rglob(f"{_PHYS_PREFIX}*{_PHYS_EXT}"): ts = _parse_ts_from_phys_json(fp.name) if not ts: continue is_interp = int("interpretable" in fp.name) mtime = fp.stat().st_mtime cur = idx.get(ts) if cur is None: idx[ts] = (is_interp, mtime, fp.parent) continue if (is_interp, mtime) > (cur[0], cur[1]): idx[ts] = (is_interp, mtime, fp.parent) return {k: v[2] for k, v in idx.items()} def _find_phys_json_for_ts( run_dir: Path, ts: str, ) -> Path | None: p1 = run_dir / f"{_PHYS_PREFIX}{ts}_interpretable.json" if p1.exists(): return p1 p2 = run_dir / f"{_PHYS_PREFIX}{ts}.json" if p2.exists(): return p2 # last resort: newest preferred under this run_dir return utils.find_phys_json(run_dir) def _infer_run_dir_from_ablation( ab_path: Path, ts: str, ) -> Path | None: base = ab_path.parent if base.name == "ablation_records": base = base.parent for nm in [ f"{_PHYS_PREFIX}{ts}_interpretable.json", f"{_PHYS_PREFIX}{ts}.json", ]: cand = base / nm if cand.exists(): return cand.parent # fallback: search within base for nm in [ f"{_PHYS_PREFIX}{ts}_interpretable.json", f"{_PHYS_PREFIX}{ts}.json", ]: hits = list(base.rglob(nm)) if hits: hits.sort( key=lambda x: x.stat().st_mtime, reverse=True, ) return hits[0].parent return None def _pick_diag_path(run_dir: Path) -> Path | None: return utils.find_eval_diag_json(run_dir) def _patch_record( rec: dict[str, Any], *, ab_path: Path, ts_index: dict[str, Path], ) -> tuple[dict[str, Any], PatchInfo]: ts = str(rec.get("timestamp") or "").strip() if not ts: return rec, PatchInfo(ok=False, reason="no_timestamp") run_dir = ts_index.get(ts) if run_dir is None: run_dir = _infer_run_dir_from_ablation(ab_path, ts) if run_dir is None: return rec, PatchInfo( ok=False, ts=ts, reason="run_dir_not_found", ) phys_p = _find_phys_json_for_ts(run_dir, ts) if phys_p is None or not phys_p.exists(): return rec, PatchInfo( ok=False, ts=ts, run_dir=str(run_dir), reason="phys_json_missing", ) diag_p = _pick_diag_path(run_dir) phys = utils.safe_load_json(phys_p) phys = utils.phys_json_to_mm(phys) diag = utils.safe_load_json(diag_p) if diag_p else {} r2, mae, mse = utils.pick_point_metrics(phys, diag) cov, shp = utils.pick_interval_metrics(phys, diag) pm = phys.get("point_metrics") or {} rmse = _safe_float(pm.get("rmse")) if rmse is None and _safe_float(mse) is not None: rmse = math.sqrt(float(mse)) # preserve old headline metrics legacy: dict[str, Any] = {} for k in [ "r2", "mse", "mae", "rmse", "coverage80", "sharpness80", ]: if k in rec: legacy[k] = rec.get(k) rec.setdefault("legacy", {}) rec["legacy"]["pre_posthoc_patch"] = legacy units = phys.get("units") or {} rec["units"] = units # headline metrics (paper-consistent) rec["r2"] = _safe_float(r2) rec["mse"] = _safe_float(mse) rec["mae"] = _safe_float(mae) rec["rmse"] = _safe_float(rmse) rec["coverage80"] = _safe_float(cov) rec["sharpness80"] = _safe_float(shp) # physics diagnostics (if present) phys_diag = phys.get("physics_diagnostics") or {} for k in ["epsilon_prior", "epsilon_cons", "epsilon_gw"]: v = phys_diag.get(k) if v is not None: rec[k] = _safe_float(v) # per-horizon (keep whole block for downstream scripts) if "per_horizon" in phys: rec["per_horizon"] = phys.get("per_horizon") # optional: bring PSS from diagnostics when present pss = None blk = ( diag.get("__overall__") if isinstance(diag, dict) else None ) if isinstance(blk, dict): pss = blk.get("pss") if pss is not None: rec["pss"] = _safe_float(pss) # traceability blocks rec["metrics"] = { "posthoc": { "point_metrics": phys.get("point_metrics") or {}, "interval_calibration": ( phys.get("interval_calibration") or {} ), "per_horizon": phys.get("per_horizon") or {}, }, "raw": { "metrics_evaluate": ( phys.get("metrics_evaluate") or {} ), "eval_diag_flat": utils.flatten_eval_diag(diag), }, "sources": { "phys_json": str(phys_p), "eval_diag_json": ( str(diag_p) if diag_p else None ), "run_dir": str(run_dir), }, } return rec, PatchInfo( ok=True, ts=ts, run_dir=str(run_dir), phys_json=str(phys_p), eval_diag=str(diag_p) if diag_p else "", ) def _discover_ablation_files( *, root: Path, ablation: str | None, include_updated: bool, ) -> list[Path]: if ablation: p = Path(ablation).expanduser().resolve() if not p.exists(): raise FileNotFoundError(str(p)) return [p] files = utils.find_all( root, cfg.PATTERNS.get("ablation_record_jsonl", ()), must_exist=False, ) if not include_updated: base = [p for p in files if not _is_updated_name(p)] if base: return base return files def _simple_row(rec: dict[str, Any]) -> dict[str, Any]: row: dict[str, Any] = {} for k, v in rec.items(): if ( isinstance(v, str | int | float | bool) or v is None ): row[k] = v return row
[docs] def update_ablation_records_main( argv: list[str] | None = None, *, prog: str | None = None ) -> None: args = _parse_args(argv, prog=prog) verbose = utils.str_to_bool(args.verbose, default=True) inc_upd = utils.str_to_bool( args.include_updated_inputs, default=False, ) if args.results_root: root = Path(args.results_root).expanduser().resolve() else: root = Path(args.run_dir).expanduser().resolve() # type: ignore[arg-type] if not root.exists(): raise FileNotFoundError(str(root)) utils.ensure_script_dirs() # Build <timestamp> -> run_dir map. ts_index = _build_ts_index(root) ab_files = _discover_ablation_files( root=root, ablation=args.ablation, include_updated=inc_upd, ) if not ab_files: raise FileNotFoundError( f"No ablation_record*.jsonl found under {root}" ) all_rows: list[dict[str, Any]] = [] for ab_path in ab_files: ab_path = ab_path.expanduser().resolve() rows = _read_jsonl(ab_path) if not rows: if verbose: print(f"[Skip] empty: {ab_path}") continue patched: list[dict[str, Any]] = [] infos: list[PatchInfo] = [] for rec in rows: rec2, info = _patch_record( rec, ab_path=ab_path, ts_index=ts_index, ) patched.append(rec2) infos.append(info) ok_n = sum(1 for i in infos if i.ok) bad = [i for i in infos if not i.ok] if verbose: print( f"[File] {ab_path.name}: patched {ok_n}/{len(infos)}" ) for i in bad[:8]: print(f" - ts={i.ts} reason={i.reason}") if len(bad) > 8: print(f" ... +{len(bad) - 8} more") out_p = _out_updated_path(ab_path) if out_p.exists() and not args.overwrite: if verbose: print( f"[Skip] exists (use --overwrite): {out_p}" ) else: if args.backup: bak = ab_path.with_suffix( ab_path.suffix + ".bak" ) bak.write_text( ab_path.read_text(encoding="utf-8"), encoding="utf-8", ) if verbose: print(f"[OK] backup: {bak}") if not args.dry_run: _write_jsonl(out_p, patched) if verbose: print(f"[OK] wrote: {out_p}") for r in patched: row = _simple_row(r) row["_src"] = str(ab_path) srcs = (r.get("metrics") or {}).get( "sources" ) or {} if isinstance(srcs, dict): row["_run_dir"] = srcs.get("run_dir") all_rows.append(row) if args.no_table: return if not all_rows: return df = pd.DataFrame(all_rows) df = df.sort_values( by=["city", "timestamp"], na_position="last", ) stem = str(args.table_stem).strip() or "ablation_table" out_csv = utils.resolve_out_out(stem + ".csv") out_js = utils.resolve_out_out(stem + ".json") if not args.dry_run: df.to_csv(out_csv, index=False) out_js.write_text( json.dumps( all_rows, ensure_ascii=False, indent=2, ), encoding="utf-8", ) if verbose: print(f"[OK] table: {out_csv}") print(f"[OK] table: {out_js}")
[docs] def main( argv: list[str] | None = None, *, prog: str | None = None ) -> None: update_ablation_records_main(argv, prog=prog)
if __name__ == "__main__": main()