# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Build tidy ablation/sensitivity tables from
ablation_records.
Design goals
------------
- Robust input: user can pass a results root, a city root,
an ablation_records/ folder, or the ablation_record.jsonl
file itself.
- Robust schema: tolerate older/newer column aliases.
- Robust units: normalize subsidence-distance metrics to mm.
- Robust paper export: compact TeX table with arrows (↑/↓)
and a single error metric (RMSE or MSE).
Run
---
python -m scripts build-ablation-table --root results
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from . import config as cfg
from . import utils
# Metric names for which a smaller value is the better one. `_sort_df`
# uses membership in this set to pick the "auto" sort direction, and
# `_metric_lower_is_better` checks it before any other rule.
_LOWER_IS_BETTER = {
"mae",
"mse",
"rmse",
"sharpness80",
"epsilon_prior",
"epsilon_cons",
"epsilon_gw",
}
# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------
def _parse_args(
    argv: list[str] | None,
    *,
    prog: str | None = None,
) -> argparse.Namespace:
    """Parse the command line of the ablation-table builder.

    Parameters
    ----------
    argv : list of str or None
        Argument vector forwarded to ``ArgumentParser.parse_args``.
    prog : str, optional
        Program name for the usage text; ``"build-ablation-table"``
        is used when it is empty or ``None``.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(
        prog=prog or "build-ablation-table",
        description="Build ablation tables from ablation_record.jsonl.",
    )

    def opt(
        flag: str, default: Any, help_text: str, **extra: Any
    ) -> None:
        # Value-taking option; string-typed unless overridden.
        extra.setdefault("type", str)
        parser.add_argument(
            flag, default=default, help=help_text, **extra
        )

    def switch(flag: str, help_text: str) -> None:
        # Boolean store_true switch.
        parser.add_argument(
            flag, action="store_true", help=help_text
        )

    # Inputs / filtering / ordering
    opt(
        "--root",
        "results",
        "Root to scan (results/ or a city/job folder). Can also "
        "be an ablation_record*.jsonl file.",
    )
    opt(
        "--input",
        None,
        "Explicit inputs (repeatable). Supports .jsonl/.json/.csv. "
        "Overrides --root scan.",
        action="append",
    )
    opt(
        "--cities",
        "",
        "Comma list of cities to keep (default: all).",
    )
    opt("--models", "", "Comma list of model names to keep.")
    opt(
        "--sort-by",
        "mae",
        "Sort by this metric (default: mae). Use 'none' to keep "
        "input order.",
    )
    opt(
        "--ascending",
        "auto",
        "Sort direction: true/false/auto. auto uses metric "
        "convention.",
    )

    # Output
    opt(
        "--out",
        "table_ablations",
        "Output stem/path (scripts/out/ if relative). Suffix is "
        "stripped and treated as stem.",
    )
    opt("--out-dir", None, "Optional output dir override.")
    opt(
        "--formats",
        "csv,json",
        "Comma list: csv,json,tex,txt. Default: csv,json",
    )

    # Table content
    opt(
        "--metrics",
        "",
        "Comma list of metric columns to keep. Default: keep all "
        "known metrics.",
    )
    opt(
        "--keep-per-horizon",
        "true",
        "Include per-horizon metrics (true/false).",
    )
    opt(
        "--max-h",
        3,
        "Max horizon columns to export (default 3).",
        type=int,
    )

    # Unit handling
    opt(
        "--metric-unit",
        "mm",
        "Unit for subsidence distance metrics in output. "
        "Default: mm.",
        choices=["mm", "m"],
    )

    # Paper mode
    switch("--for-paper", "Produce compact paper-friendly table.")
    opt(
        "--err-metric",
        "rmse",
        "Paper table keeps only one error metric (rmse or mse).",
        choices=["rmse", "mse"],
    )
    switch(
        "--keep-r2",
        "In --for-paper mode, keep R2 column as well.",
    )
    switch(
        "--sideway",
        "Write TeX as sidewaystable (landscape).",
    )
    opt(
        "--caption",
        "Extended ablations and sensitivity analysis.",
        "TeX caption.",
    )
    opt("--label", "tab:ablations", "TeX label.")
    switch(
        "--best-per-city",
        "Also export best row per city for --sort-by metric.",
    )
    switch("--stdout", "Print final table to stdout.")

    # Optional grouped outputs (Tables S6 / S7)
    opt(
        "--group-cols",
        "",
        "Extra grouped outputs. Comma list: s6,s7. "
        "s6: λ_cons×λ_prior grid. "
        "s7: toggle ablations summary.",
    )
    opt(
        "--s6-metrics",
        "",
        "Comma list of metrics for Table S6 grids. Default: use "
        "--sort-by if present else mae.",
    )
    opt(
        "--s6-agg",
        "mean",
        "Aggregation for duplicate grid points.",
        choices=["mean", "median"],
    )
    opt(
        "--s6-tex-bold-best",
        "true",
        "In S6 TeX grids, bold the best cell (true/false).",
    )
    opt(
        "--s6-tex-gray-missing",
        "true",
        "In S6 TeX grids, gray missing/invalid cells (true/false).",
    )
    opt(
        "--s6-tex-gray-level",
        "gray!15",
        "TeX color for missing cells (e.g., gray!15).",
    )
    opt(
        "--s6-tex-one-file",
        "false",
        "Write one TeX file per (city,physics) with all metrics, "
        "instead of one per metric.",
    )
    opt(
        "--s7-cols",
        "",
        "Comma list of grouping columns for Table S7. Default: a "
        "robust preset.",
    )
    opt(
        "--s7-metrics",
        "",
        "Comma list of metrics for Table S7 summary. Default: "
        "paper metrics if --for-paper else all.",
    )
    opt(
        "--s7-agg",
        "mean",
        "Aggregation for Table S7 groups.",
        choices=["mean", "median"],
    )
    return parser.parse_args(argv)
# ---------------------------------------------------------------------
# I/O
# ---------------------------------------------------------------------
def _read_jsonl(fp: Path) -> list[dict[str, Any]]:
    """Read a JSON-lines file, keeping only lines that decode to a dict.

    Blank lines, lines that fail to decode, and lines holding a
    non-dict JSON value are skipped without raising.
    """
    records: list[dict[str, Any]] = []
    with fp.open("r", encoding="utf-8") as handle:
        for text in map(str.strip, handle):
            if not text:
                continue
            try:
                parsed = json.loads(text)
            except Exception:
                # Best effort: a broken line must not abort the file.
                continue
            if isinstance(parsed, dict):
                records.append(parsed)
    return records
def _load_one(path: Path) -> pd.DataFrame:
    """Load one ``.csv``, ``.jsonl`` or ``.json`` file into a frame.

    A ``_src`` column holding ``str(path)`` is attached when the data
    does not already carry one (for ``.json`` only when the frame is
    non-empty). Any other suffix, or an unreadable ``.json`` payload,
    yields an empty frame.
    """
    ext = path.suffix.lower()
    origin = str(path)

    if ext == ".csv":
        frame = pd.read_csv(path)
        # Keeps an existing ``_src`` column, otherwise fills the path.
        frame["_src"] = frame.get("_src", origin)
        return frame

    if ext == ".jsonl":
        frame = pd.DataFrame(_read_jsonl(path))
        if "_src" not in frame.columns:
            frame["_src"] = origin
        return frame

    if ext != ".json":
        return pd.DataFrame([])

    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except Exception:
        payload = None

    if isinstance(payload, dict):
        frame = pd.DataFrame([payload])
    elif isinstance(payload, list):
        frame = pd.DataFrame(payload)
    else:
        frame = pd.DataFrame([])

    if "_src" not in frame.columns and not frame.empty:
        frame["_src"] = origin
    return frame
def _scan_records(root: Path) -> pd.DataFrame:
    """Collect every record file found under ``root`` into one frame.

    Files are located with ``utils.find_all`` using the
    ``"ablation_record_jsonl"`` entry of ``cfg.PATTERNS``; files that
    load as empty frames are ignored. Returns an empty frame when
    nothing usable is found.
    """
    patterns = cfg.PATTERNS.get("ablation_record_jsonl", ())
    loaded = (_load_one(fp) for fp in utils.find_all(root, patterns))
    frames = [frame for frame in loaded if not frame.empty]
    if frames:
        return pd.concat(frames, ignore_index=True)
    return pd.DataFrame([])
def _load_records(args: argparse.Namespace) -> pd.DataFrame:
    """Load ablation records from ``--input`` files or a ``--root`` scan.

    Explicit ``--input`` entries take precedence over ``--root``.
    Inputs that do not exist are reported on stdout and skipped.
    When ``--root`` points at a file, that file is loaded directly;
    if this gives nothing, the regular scan is still attempted.

    Returns
    -------
    pandas.DataFrame
        Concatenated records, or an empty frame when nothing loaded.
    """
    if args.input:
        blocks: list[pd.DataFrame] = []
        for raw in args.input:
            p = utils.as_path(raw)
            if not p.exists():
                # Surface typos instead of silently dropping the input.
                print(f"[Warn] input not found, skipped: {p}")
                continue
            df = _load_one(p)
            if not df.empty:
                blocks.append(df)
        if not blocks:
            return pd.DataFrame([])
        return pd.concat(blocks, ignore_index=True)

    root = utils.as_path(args.root)
    # --root is documented as possibly being a record file itself.
    if root.is_file():
        df = _load_one(root)
        if not df.empty:
            return df
    return _scan_records(root)
# ---------------------------------------------------------------------
# Canonicalization / flattening
# ---------------------------------------------------------------------
def _canon_pde_mode(x: Any) -> str:
    """Normalise a physics-mode value to ``"none"``, ``"both"`` or itself.

    Off-like spellings (``none/off/no/0/false``) map to ``"none"``;
    on-like spellings and the single-PDE names ``consolidation`` and
    ``gw_flow`` map to ``"both"``. Anything else is returned
    lower-cased and stripped; ``None`` gives ``""``.

    Numeric and boolean inputs are stringified explicitly so that the
    falsy values ``0`` and ``False`` are recognised as "off" rather
    than collapsing to an empty string.
    """
    if x is None:
        return ""
    s = str(x).strip().lower()
    # "0.0"/"1.0" cover integer flags that pandas upcast to float.
    if s in {"none", "off", "no", "0", "0.0", "false"}:
        return "none"
    if s in {"both", "on", "1", "1.0", "true"}:
        return "both"
    if s in {"consolidation", "gw_flow"}:
        return "both"
    return s
def _expand_dict_col(
    df: pd.DataFrame,
    col: str,
    *,
    prefix: str,
    max_h: int,
    only_h: bool,
) -> pd.DataFrame:
    """Flatten a column of dicts into numeric ``<prefix><key>`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to extend; it is modified in place and returned.
    col : str
        Name of the column whose cells may hold dicts.
    prefix : str
        Text placed in front of every key to form the column name.
    max_h : int
        With ``only_h``, keys ``H<k>`` with ``k > max_h`` are ignored.
    only_h : bool
        Restrict to keys shaped like ``H<digits>``.

    Notes
    -----
    Keys are collected from every dict in the column, not only the
    first one, so records with different horizon sets all contribute.
    Cells that are not dicts, or lack a key, become NaN.
    """
    if col not in df.columns:
        return df
    limit = int(max_h) if only_h else None

    def _keep(key: str) -> bool:
        if not only_h:
            return True
        if not (key.startswith("H") and key[1:].isdigit()):
            return False
        return int(key[1:]) <= limit

    cells = df[col].to_list()
    found: set[str] = set()
    for cell in cells:
        if isinstance(cell, dict):
            found.update(k for k in map(str, cell) if _keep(k))
    if not found:
        return df

    if only_h:
        # Numeric horizon order; the name breaks ties such as H1/H01.
        keys = sorted(found, key=lambda k: (int(k[1:]), k))
    else:
        keys = sorted(found)

    for k in keys:
        picked = [
            cell.get(k) if isinstance(cell, dict) else np.nan
            for cell in cells
        ]
        df[f"{prefix}{k}"] = pd.to_numeric(
            pd.Series(picked, index=df.index, dtype=object),
            errors="coerce",
        )
    return df
def _canon_cols(
    df: pd.DataFrame, *, max_h: int
) -> pd.DataFrame:
    """Canonicalise column names, categorical values and dtypes.

    Steps, in order: resolve column aliases through
    ``utils.ensure_columns``; canonicalise ``city`` and ``pde_mode``;
    derive ``pde_bucket`` (``"none"`` vs ``"both"``); flatten the
    per-horizon dict columns; coerce known numeric columns and every
    ``per_horizon_*`` column with ``pd.to_numeric(errors="coerce")``.
    An empty frame is returned untouched.
    """
    if df.empty:
        return df

    # NOTE(review): the return value is ignored, so this presumably
    # renames/creates columns in place — confirm against utils.
    utils.ensure_columns(
        df,
        aliases={
            "timestamp": ("timestamp", "ts"),
            "city": ("city", "City"),
            "model": ("model", "Model"),
            "pde_mode": ("pde_mode", "pde"),
            "lambda_prior": ("lambda_prior", "lambda_p"),
            "lambda_cons": ("lambda_cons", "lambda_c"),
            "coverage80": ("coverage80", "coverage8"),
            "sharpness80": ("sharpness80", "sharpness"),
        },
    )

    if "city" in df.columns:
        city_text = df["city"].astype(str)
        df["city"] = city_text.map(utils.canonical_city)

    if "pde_mode" not in df.columns:
        # Records without the field are treated as physics-on.
        df["pde_mode"] = "both"
    else:
        df["pde_mode"] = df["pde_mode"].map(_canon_pde_mode)

    physics_off = df["pde_mode"].astype(str).eq("none")
    df["pde_bucket"] = np.where(physics_off, "none", "both")

    # Flatten common nested blocks
    for nested in ("per_horizon_mae", "per_horizon_r2"):
        df = _expand_dict_col(
            df,
            nested,
            prefix=f"{nested}.",
            max_h=max_h,
            only_h=True,
        )

    scalar_cols = (
        "lambda_prior",
        "lambda_cons",
        "lambda_gw",
        "lambda_smooth",
        "lambda_mv",
        "lambda_bounds",
        "lambda_q",
        "hd_factor",
        "r2",
        "mae",
        "mse",
        "rmse",
        "coverage80",
        "sharpness80",
        "epsilon_prior",
        "epsilon_cons",
        "epsilon_gw",
    )
    for name in scalar_cols:
        if name in df.columns:
            df[name] = pd.to_numeric(df[name], errors="coerce")

    for name in list(df.columns):
        if name.startswith("per_horizon_"):
            df[name] = pd.to_numeric(df[name], errors="coerce")
    return df
# ---------------------------------------------------------------------
# Units + derived metrics
# ---------------------------------------------------------------------
def _needs_m_to_mm(row: pd.Series) -> bool:
    """Tell whether a record's distance metrics should be scaled to mm.

    An explicit ``units["subs_metrics_unit"]`` of ``mm`` answers
    False and one of ``m``/``meter``/``metre`` answers True. Without
    such a declaration the answer is True only when ``mae``, ``mse``
    and ``sharpness80`` are all finite and strictly inside (0, 1).
    """
    units = row.get("units", None)
    if isinstance(units, dict):
        declared = str(units.get("subs_metrics_unit", "")).lower()
        if declared == "mm":
            return False
        if declared in {"m", "meter", "metre"}:
            return True

    # Heuristic fallback
    probes = (
        row.get("mae", np.nan),
        row.get("mse", np.nan),
        row.get("sharpness80", np.nan),
    )
    return all(np.isfinite(v) for v in probes) and all(
        0.0 < float(v) < 1.0 for v in probes
    )
def _scale_rows_to_mm(df: pd.DataFrame) -> pd.DataFrame:
    """Scale rows flagged by ``_needs_m_to_mm`` from metres to mm.

    For flagged rows, columns named ``mae``/``rmse``/``sharpness80``
    (or ending in ``.mae``/``.rmse``/``.sharpness80``) are multiplied
    by 1000 and columns named ``mse`` (or ending in ``.mse``) by 1e6.
    The frame is modified in place and returned.
    """
    if df.empty:
        return df
    flagged = df.apply(_needs_m_to_mm, axis=1)
    if not flagged.any():
        return df

    # NOTE(review): names such as ``per_horizon_mae.H1`` match none of
    # these rules and therefore stay unscaled — confirm intended.
    linear_names = ("mae", "rmse", "sharpness80")
    linear_tails = tuple(f".{n}" for n in linear_names)
    linear: set = set()
    squared: set = set()
    for column in df.columns:
        low = str(column).lower()
        if low == "mse" or low.endswith(".mse"):
            squared.add(column)
        if low in linear_names or low.endswith(linear_tails):
            linear.add(column)

    for column in sorted(linear):
        df.loc[flagged, column] = df.loc[flagged, column] * 1000.0
    for column in sorted(squared):
        df.loc[flagged, column] = df.loc[flagged, column] * 1e6
    return df
def _to_unit(df: pd.DataFrame, *, unit: str) -> pd.DataFrame:
    """Return the table in the requested unit (``mm`` or ``m``).

    Values are taken to be in mm already, so ``"mm"`` (and any
    unrecognised unit) returns ``df`` itself. For ``"m"`` a copy is
    returned in which distance-like columns are multiplied by 1e-3
    and MSE-like columns by 1e-6, after numeric coercion.
    """
    if str(unit).strip().lower() != "m":
        # "mm" is the native unit; unknown units are ignored.
        return df

    converted = df.copy()
    linear_names = ("mae", "rmse", "sharpness80")
    linear_tails = tuple(f".{n}" for n in linear_names)
    linear = []
    squared = []
    for column in converted.columns:
        low = str(column).lower()
        if low == "mse" or low.endswith(".mse"):
            squared.append(column)
        if low in linear_names or low.endswith(linear_tails):
            linear.append(column)

    for column, factor in [(c, 1e-3) for c in linear] + [
        (c, 1e-6) for c in squared
    ]:
        numeric = pd.to_numeric(converted[column], errors="coerce")
        converted[column] = numeric * factor
    return converted
def _ensure_mse_rmse(df: pd.DataFrame) -> pd.DataFrame:
    """Make sure both ``mse`` and ``rmse`` are available where derivable.

    A missing column is derived from the other one
    (``rmse = sqrt(max(0, mse))``, ``mse = rmse ** 2``). When both
    columns exist, which happens after concatenating records with
    different schemas, only the NaN cells of one are filled from the
    other; existing values are never overwritten.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when it is empty, otherwise a completed copy.
    """
    if df.empty:
        return df
    out = df.copy()
    has_mse = "mse" in out.columns
    has_rmse = "rmse" in out.columns
    if not (has_mse or has_rmse):
        return out

    mse_v = (
        pd.to_numeric(out["mse"], errors="coerce").to_numpy(float)
        if has_mse
        else None
    )
    rmse_v = (
        pd.to_numeric(out["rmse"], errors="coerce").to_numpy(float)
        if has_rmse
        else None
    )

    if mse_v is not None:
        derived = np.sqrt(np.maximum(0.0, mse_v))
        if rmse_v is None:
            out["rmse"] = derived
        else:
            gap = np.isnan(rmse_v) & ~np.isnan(derived)
            if gap.any():
                out["rmse"] = np.where(gap, derived, rmse_v)

    if rmse_v is not None:
        derived = rmse_v**2
        if mse_v is None:
            out["mse"] = derived
        else:
            gap = np.isnan(mse_v) & ~np.isnan(derived)
            if gap.any():
                out["mse"] = np.where(gap, derived, mse_v)
    return out
# ---------------------------------------------------------------------
# Dedupe / filters
# ---------------------------------------------------------------------
def _record_score(row: pd.Series) -> int:
    """Rank a record for de-duplication; a higher score is preferred.

    +100 when the ``_src`` path contains ``"updated"``, +50 when
    ``units["subs_metrics_unit"]`` is ``mm``, +10 when ``metrics``
    is a dict.
    """
    score = 0
    if "updated" in str(row.get("_src", "")).lower():
        score += 100
    units = row.get("units", None)
    declares_mm = (
        isinstance(units, dict)
        and str(units.get("subs_metrics_unit", "")).lower() == "mm"
    )
    if declares_mm:
        score += 50
    if isinstance(row.get("metrics", None), dict):
        score += 10
    return int(score)
def _dedupe_prefer_best(df: pd.DataFrame) -> pd.DataFrame:
    """Keep the best-scored record for each ``(timestamp, city)`` pair.

    Rows are ordered by timestamp, city and descending
    ``_record_score``, and only the first row of every pair survives.
    Rows whose timestamp is missing are never treated as duplicates
    of one another: pandas considers NaN keys equal, which would
    otherwise collapse every timestamp-less record of a city into one.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when it is empty or has no ``timestamp``
        column, otherwise a sorted, de-duplicated copy.
    """
    if df.empty or "timestamp" not in df.columns:
        return df
    x = df.copy()
    if "city" not in x.columns:
        x["city"] = ""
    x["_score"] = x.apply(_record_score, axis=1)
    x = x.sort_values(
        by=["timestamp", "city", "_score"],
        ascending=[True, True, False],
    )
    repeated = x.duplicated(
        subset=["timestamp", "city"], keep="first"
    )
    # A missing timestamp cannot identify a duplicate run.
    repeated &= x["timestamp"].notna()
    kept = x.loc[~repeated].copy()
    return kept.drop(columns=["_score"], errors="ignore")
def _filter_models(
    df: pd.DataFrame, models: str
) -> pd.DataFrame:
    """Keep rows whose ``model`` is in the comma-separated ``models``.

    The frame is returned unchanged when the list is empty or there
    is no ``model`` column; otherwise a filtered copy is returned.
    """
    spec = str(models or "").strip()
    if "model" not in df.columns or not spec:
        return df
    wanted = [tok.strip() for tok in spec.split(",")]
    wanted = [tok for tok in wanted if tok]
    if not wanted:
        return df
    selected = df["model"].astype(str).isin(wanted)
    return df.loc[selected].copy()
def _filter_cities(
    df: pd.DataFrame, cities: str
) -> pd.DataFrame:
    """Keep rows whose ``city`` is in the comma-separated ``cities``.

    Each requested name goes through ``utils.canonical_city`` before
    it is compared with the ``city`` column. The frame is returned
    unchanged when the list is empty or there is no ``city`` column.
    """
    spec = str(cities or "").strip()
    if "city" not in df.columns or not spec:
        return df
    tokens = (tok.strip() for tok in spec.split(","))
    wanted = [utils.canonical_city(tok) for tok in tokens if tok]
    if not wanted:
        return df
    selected = df["city"].astype(str).isin(wanted)
    return df.loc[selected].copy()
def _sort_df(
    df: pd.DataFrame,
    *,
    metric: str,
    ascending: str,
) -> pd.DataFrame:
    """Sort the table by one metric column.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to sort.
    metric : str
        Column to sort by; empty or ``"none"`` keeps the input order.
        The lower-cased name is tried first, then the exact spelling,
        then a case-insensitive match, so mixed-case columns such as
        ``per_horizon_mae.H1`` can be selected.
    ascending : str
        ``true/1/yes``, ``false/0/no`` or anything else for "auto".
        Auto sorts ascending for metrics in ``_LOWER_IS_BETTER`` and
        for per-horizon MAE columns, descending otherwise.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself when no sort applies, else a sorted copy. The
        sort is stable, so ties keep their input order.
    """
    want = str(metric or "").strip()
    low = want.lower()
    if not want or low == "none":
        return df

    if low in df.columns:
        col = low
    elif want in df.columns:
        col = want
    else:
        col = next(
            (c for c in df.columns if str(c).lower() == low), None
        )
        if col is None:
            return df

    asc_raw = str(ascending or "auto").strip().lower()
    if asc_raw in {"true", "1", "yes"}:
        asc = True
    elif asc_raw in {"false", "0", "no"}:
        asc = False
    else:
        key = str(col).lower()
        asc = (
            key.startswith("per_horizon_mae.")
            or key in _LOWER_IS_BETTER
        )
    return df.sort_values(by=[col], ascending=asc, kind="mergesort")
# ---------------------------------------------------------------------
# Column selection
# ---------------------------------------------------------------------
def _default_metric_cols(df: pd.DataFrame) -> list[str]:
    """List the metric columns exported by default.

    Known scalar metrics come first, in a fixed order, restricted to
    those present in ``df``; flattened ``per_horizon_mae.*`` and
    ``per_horizon_r2.*`` columns follow in their frame order.
    """
    scalar_order = (
        "mae",
        "rmse",
        "mse",
        "r2",
        "coverage80",
        "sharpness80",
        "epsilon_prior",
        "epsilon_cons",
        "epsilon_gw",
    )
    horizon_prefixes = ("per_horizon_mae.", "per_horizon_r2.")
    scalars = [name for name in scalar_order if name in df.columns]
    per_horizon = [
        str(column)
        for column in df.columns
        if str(column).startswith(horizon_prefixes)
    ]
    return scalars + per_horizon
def _pick_metrics(df: pd.DataFrame, spec: str) -> list[str]:
    """Resolve the ``--metrics`` list to columns present in ``df``.

    An empty ``spec`` selects the default metric columns; otherwise
    the comma-separated names are kept in the given order and names
    missing from the frame are dropped.
    """
    text = str(spec or "").strip()
    if not text:
        return _default_metric_cols(df)
    present = set(df.columns)
    requested = (tok.strip() for tok in text.split(","))
    return [name for name in requested if name and name in present]
def _param_cols(df: pd.DataFrame) -> list[str]:
    """List the identifier/hyper-parameter columns present in ``df``.

    The result follows a fixed display order, not the frame order.
    """
    display_order = (
        "timestamp",
        "city",
        "model",
        "pde_mode",
        "use_effective_h",
        "kappa_mode",
        "hd_factor",
        "lambda_cons",
        "lambda_gw",
        "lambda_prior",
        "lambda_smooth",
        "lambda_mv",
        "lambda_bounds",
        "lambda_q",
    )
    present = df.columns
    return list(filter(lambda name: name in present, display_order))
def _paper_cols(
    df: pd.DataFrame,
    *,
    err_metric: str,
    keep_r2: bool,
) -> tuple[list[str], list[str]]:
    """Choose the compact column set used in ``--for-paper`` mode.

    Returns
    -------
    (params, metrics) : tuple of list of str
        Parameter columns and metric columns, each in a fixed order
        and limited to columns present in ``df``. Metrics are MAE,
        the selected error metric, coverage and sharpness, plus R2
        when ``keep_r2`` is set.
    """
    param_order = (
        "city",
        "pde_bucket",
        "lambda_cons",
        "lambda_prior",
        "lambda_gw",
        "lambda_smooth",
        "lambda_bounds",
        "lambda_mv",
        "lambda_q",
        "use_effective_h",
        "kappa_mode",
        "hd_factor",
    )
    metric_order = ["mae", str(err_metric), "coverage80", "sharpness80"]
    if keep_r2:
        metric_order += ["r2"]

    present = df.columns
    params = [name for name in param_order if name in present]
    mets = [name for name in metric_order if name in present]
    return params, mets
# ---------------------------------------------------------------------
# TeX
# ---------------------------------------------------------------------
def _tex_escape(s: Any) -> str:
    """Escape backslash, ``_``, ``%`` and ``&`` for use in TeX text.

    The value is stringified first. Other TeX specials, such as
    ``$``, ``#`` and braces, are passed through unchanged.
    """
    # Single pass, so inserted backslashes are never re-escaped.
    table = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "_": r"\_",
            "%": r"\%",
            "&": r"\&",
        }
    )
    return str(s).translate(table)
def _fmt_bool(x: Any) -> str:
    """Render a boolean-like value as ``"Yes"`` or ``"No"``.

    Real booleans and the usual textual spellings are recognised;
    any other value is returned TeX-escaped.
    """
    if isinstance(x, bool):
        return "Yes" if x else "No"
    spelled = str(x).strip().lower()
    for label, words in (
        ("Yes", ("1", "true", "yes", "y", "on")),
        ("No", ("0", "false", "no", "n", "off")),
    ):
        if spelled in words:
            return label
    return _tex_escape(x)
def _tex_header(
    col: str,
    *,
    unit: str,
) -> str:
    """Build the TeX column header for a table column.

    Lookup order: fixed identifier/hyper-parameter headers; the
    metric metadata in ``cfg.PLOT_METRIC_META`` (its title is
    formatted with the unit and then TeX-escaped); epsilon
    residuals; per-horizon MAE/R2; finally the escaped column name.

    Parameters
    ----------
    col : str
        Column name.
    unit : str
        Requested output unit; ``"m"`` swaps a metadata unit of
        ``mm``/``mm²`` for ``m``/``m$^2$``.
    """
    name = str(col)

    fixed = {
        "city": "City",
        "model": "Model",
        "pde_bucket": "Physics",
        "pde_mode": "Physics",
        "use_effective_h": r"Use $H_{\mathrm{eff}}$",
        "kappa_mode": r"$\kappa$ mode",
        "hd_factor": r"$H_d/H_{\mathrm{eff}}$",
        "lambda_cons": r"$\lambda_{\mathrm{cons}}$",
        "lambda_gw": r"$\lambda_{\mathrm{gw}}$",
        "lambda_prior": r"$\lambda_{\mathrm{prior}}$",
        "lambda_smooth": r"$\lambda_{\mathrm{smooth}}$",
        "lambda_bounds": r"$\lambda_{\mathrm{bounds}}$",
        "lambda_mv": r"$\lambda_{\mathrm{mv}}$",
        "lambda_q": r"$\lambda_{q}$",
    }
    if name in fixed:
        return fixed[name]

    info = cfg.PLOT_METRIC_META.get(name, None)
    if isinstance(info, dict):
        title = str(info.get("title", name))
        shown = str(info.get("unit", "") or "")
        if shown:
            # Override unit if caller requests meters.
            if unit == "m" and shown in {"mm", "mm²"}:
                shown = "m" if shown == "mm" else "m$^2$"
            title = title.format(unit=shown)
        return _tex_escape(title)

    residual_tags = {
        "epsilon_prior": "prior",
        "epsilon_cons": "cons",
        "epsilon_gw": "gw",
    }
    if name in residual_tags:
        tag = residual_tags[name]
        return rf"$\epsilon_{{\mathrm{{{tag}}}}}$ " + "(↓)"

    for prefix, template in (
        ("per_horizon_mae.", "MAE$_{{{h}}}$"),
        ("per_horizon_r2.", "$R^2_{{{h}}}$"),
    ):
        if name.startswith(prefix):
            horizon = name.split(".", 1)[1]
            return template.format(h=_tex_escape(horizon))
    return _tex_escape(name)
def _tex_fmt_val(
    v: Any,
    *,
    col: str,
) -> str:
    """Format one table cell for TeX output.

    ``None`` and non-finite floats render as ``--``. The
    ``use_effective_h`` column renders Yes/No and the physics
    columns On/Off. Columns with metadata in ``cfg.PLOT_METRIC_META``
    use its ``fmt`` pattern (``--`` when the value cannot be
    formatted). Remaining integers print as-is, floats with three
    significant digits, and everything else TeX-escaped.
    """
    blank = "--"
    if v is None:
        return blank
    if isinstance(v, float) and not np.isfinite(v):
        return blank

    if col == "use_effective_h":
        return _fmt_bool(v)
    if col in {"pde_bucket", "pde_mode"}:
        is_off = str(v).strip().lower() == "none"
        return "Off" if is_off else "On"

    info = cfg.PLOT_METRIC_META.get(str(col), None)
    if isinstance(info, dict):
        pattern = str(info.get("fmt", "{:.3g}"))
        try:
            return pattern.format(float(v))
        except Exception:
            return blank

    try:
        if isinstance(v, (int, np.integer)):
            return str(int(v))
        if isinstance(v, (float, np.floating)):
            return f"{float(v):.3g}"
    except Exception:
        pass
    return _tex_escape(v)
def _to_tex_table(
    df: pd.DataFrame,
    *,
    cols: list[str],
    unit: str,
    sideway: bool,
    caption: str,
    label: str,
) -> str:
    """Render ``df`` as a booktabs ``tabular`` inside a float.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to print.
    cols : list of str
        Columns to print, in order. The first two are left-aligned,
        the rest right-aligned.
    unit : str
        Unit passed on to the header builder.
    sideway : bool
        Use the ``sidewaystable`` environment instead of ``table``.
    caption : str
        Plain-text caption; it is TeX-escaped.
    label : str
        Reference key, written verbatim. Escaping it would place
        ``\\_`` inside ``\\label`` and change the key.

    Returns
    -------
    str
        The TeX source, ending with a newline.
    """
    spec = "".join("l" if i < 2 else "r" for i in range(len(cols)))
    head = " & ".join(_tex_header(c, unit=unit) for c in cols)
    env = "sidewaystable" if sideway else "table"

    # Commands take ONE backslash; only the row terminator is "\\".
    lines: list[str] = [
        rf"\begin{{{env}}}[t]",
        r"\centering",
        r"\small",
        r"\setlength{\tabcolsep}{4pt}",
        r"\renewcommand{\arraystretch}{1.15}",
        r"\begin{tabular}{" + spec + "}",
        r"\toprule",
        head + r" \\",
        r"\midrule",
    ]
    for _, row in df.iterrows():
        vals = [
            _tex_fmt_val(row.get(c, None), col=c) for c in cols
        ]
        lines.append(" & ".join(vals) + r" \\")
    lines += [
        r"\bottomrule",
        r"\end{tabular}",
        rf"\caption{{{_tex_escape(caption)}}}",
        rf"\label{{{label}}}",
        rf"\end{{{env}}}",
        "",
    ]
    return "\n".join(lines)
def _resolve_out(out: str, out_dir: str | None) -> Path:
    """Turn ``--out``/``--out-dir`` into a suffix-less output stem.

    Any suffix on ``out`` is stripped. With ``out_dir`` the stem is
    joined to it and resolved; otherwise a relative stem is placed
    under ``cfg.OUT_DIR`` and an absolute one is returned as is.
    """
    stem = Path(out).expanduser()
    stem = stem.with_suffix("") if stem.suffix else stem
    if out_dir:
        return (Path(out_dir).expanduser() / stem).resolve()
    return stem if stem.is_absolute() else cfg.OUT_DIR / stem
def _best_by_city(
    df: pd.DataFrame,
    *,
    metric: str,
    ascending: str,
) -> pd.DataFrame:
    """Return the top row of each city after sorting by ``metric``.

    The frame is returned unchanged when it is empty or has no
    ``city`` column. Cities keep their order of first appearance in
    the sorted table.
    """
    if "city" not in df.columns or df.empty:
        return df
    ranked = _sort_df(df, metric=metric, ascending=ascending)
    per_city = ranked.groupby("city", sort=False)
    return per_city.head(1).copy()
# ---------------------------------------------------------------------
# Grouped outputs (S6 / S7)
# ---------------------------------------------------------------------
def _split_csv(raw: Any) -> list[str]:
    """Split a comma-separated value into stripped, non-empty tokens.

    Falsy input (``None``, ``""``) gives an empty list.
    """
    text = str(raw or "").strip()
    if not text:
        return []
    tokens = map(str.strip, text.split(","))
    return [tok for tok in tokens if tok]
def _agg_fn(name: str) -> str:
    """Map an aggregation option to ``"mean"`` or ``"median"``.

    Only ``median`` (any case, surrounding blanks ignored) selects
    the median; every other value, including empty, gives the mean.
    """
    choice = str(name or "mean").strip().lower()
    return "median" if choice == "median" else "mean"
def _s6_metrics(args: argparse.Namespace) -> list[str]:
    """Choose the candidate metric names for the Table S6 grids.

    An explicit ``--s6-metrics`` list wins. Otherwise the
    ``--sort-by`` metric is used, and ``mae`` when ``--sort-by`` is
    empty or the ``"none"`` sentinel (which names no column).

    The fallback offers the lower-cased name first and then the exact
    spelling, so mixed-case columns such as ``per_horizon_mae.H1``
    remain reachable; the caller keeps only names present in the
    table.
    """
    explicit = _split_csv(args.s6_metrics)
    if explicit:
        return explicit
    raw = str(args.sort_by or "").strip()
    if not raw or raw.lower() == "none":
        return ["mae"]
    # dict.fromkeys de-duplicates while preserving order.
    return list(dict.fromkeys([raw.lower(), raw]))
def _s7_group_cols(
    df: pd.DataFrame, args: argparse.Namespace
) -> list[str]:
    """Choose the grouping columns for Table S7.

    ``--s7-cols`` is used when given, otherwise a preset of city,
    physics bucket, configuration switches and toggle flags. Either
    way only columns present in ``df`` are returned.
    """
    requested = _split_csv(args.s7_cols)
    candidates = requested or [
        "city",
        "pde_bucket",
        "use_effective_h",
        "kappa_mode",
        "hd_factor",
        "smooth_on",
        "bounds_on",
        "mv_on",
        "q_on",
    ]
    present = df.columns
    return [name for name in candidates if name in present]
def _s7_metrics(
    df: pd.DataFrame, args: argparse.Namespace
) -> list[str]:
    """Choose the metric columns summarised in Table S7.

    Priority: the explicit ``--s7-metrics`` list; the compact paper
    set when ``--for-paper`` is on (R2 added with ``--keep-r2``);
    otherwise the default metric columns. Only columns present in
    ``df`` are returned.
    """
    requested = _split_csv(args.s7_metrics)
    if requested:
        candidates = requested
    elif args.for_paper:
        candidates = [
            "mae",
            str(args.err_metric),
            "coverage80",
            "sharpness80",
        ]
        if args.keep_r2:
            candidates.append("r2")
    else:
        return _default_metric_cols(df)
    return [name for name in candidates if name in df.columns]
def _add_toggle_flags(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with ``on``/``off`` flags for the optional losses.

    For each of ``lambda_smooth``, ``lambda_bounds``, ``lambda_mv``
    and ``lambda_q`` present in the frame, a companion ``*_on``
    column is ``"on"`` where the weight is greater than zero and
    ``"off"`` otherwise (including non-numeric or missing weights).
    """
    flagged = df.copy()
    pairs = (
        ("lambda_smooth", "smooth_on"),
        ("lambda_bounds", "bounds_on"),
        ("lambda_mv", "mv_on"),
        ("lambda_q", "q_on"),
    )
    for weight_col, flag_col in pairs:
        if weight_col in flagged.columns:
            weight = pd.to_numeric(
                flagged[weight_col], errors="coerce"
            )
            flagged[flag_col] = np.where(weight > 0.0, "on", "off")
    return flagged
def _metric_lower_is_better(metric: str) -> bool:
    """Tell whether smaller values of ``metric`` are better.

    Resolution order: an empty name or a member of
    ``_LOWER_IS_BETTER`` is True; an arrow in the metric's title in
    ``cfg.PLOT_METRIC_META`` decides (``↓`` True, ``↑`` False);
    names containing ``r2`` or ``coverage`` are False; anything
    else is True.
    """
    name = str(metric or "").strip().lower()
    if not name or name in _LOWER_IS_BETTER:
        return True

    info = cfg.PLOT_METRIC_META.get(name, None)
    if isinstance(info, dict):
        title = str(info.get("title", ""))
        for arrow, verdict in (("↓", True), ("↑", False)):
            if arrow in title:
                return verdict

    return not ("r2" in name or "coverage" in name)
def _best_mask(
    piv: pd.DataFrame, *, metric: str
) -> pd.DataFrame:
    """Mark the best finite cell(s) of a pivot grid.

    Returns a boolean frame shaped like ``piv`` that is True where a
    finite cell equals the grid minimum (lower-is-better metrics) or
    maximum (otherwise), within a tight tolerance. The mask is
    all-False when the grid is empty or holds no finite value.
    """

    def _like_piv(values) -> pd.DataFrame:
        return pd.DataFrame(
            values, index=piv.index, columns=piv.columns
        )

    if piv.empty:
        return _like_piv(False)

    grid = piv.to_numpy(dtype=float, copy=True)
    finite = np.isfinite(grid)
    if not finite.any():
        return _like_piv(False)

    reducer = (
        np.nanmin if _metric_lower_is_better(metric) else np.nanmax
    )
    target = reducer(grid)
    hits = np.isclose(grid, target, rtol=1e-9, atol=1e-12)
    return _like_piv(finite & hits)
def _tex_pivot(
    piv: pd.DataFrame,
    *,
    metric: str,
    unit: str,
    sideway: bool,
    caption: str,
    label: str,
    bold_best: bool,
    gray_missing: bool,
    gray_level: str,
) -> str:
    """Render a lambda_cons x lambda_prior pivot grid as a TeX table.

    Parameters
    ----------
    piv : pandas.DataFrame
        Grid indexed by lambda_cons with one column per lambda_prior.
    metric : str
        Metric shown in the cells; selects the number format from
        ``cfg.PLOT_METRIC_META`` and the direction of "best".
    unit : str
        Accepted for signature parity with the other TeX writers;
        not used here.
    sideway : bool
        Use ``sidewaystable`` instead of ``table``.
    caption : str
        Caption, written verbatim: the caller composes it with TeX
        math, which escaping would destroy.
    label : str
        Reference key, written verbatim.
    bold_best : bool
        Bold the best finite cell(s).
    gray_missing : bool
        Shade missing/invalid cells with ``gray_level``.
    gray_level : str
        xcolor colour expression; blank falls back to ``gray!15``.

    Returns
    -------
    str
        The TeX source, ending with a newline.
    """
    cols = [r"$\lambda_{\mathrm{cons}}$"]
    for c in piv.columns:
        try:
            shown = f"{float(c):g}"
        except Exception:
            shown = _tex_escape(c)
        cols.append(rf"$\lambda_{{\mathrm{{prior}}}}={shown}$")

    spec = "l" + "r" * len(piv.columns)
    env = "sidewaystable" if sideway else "table"

    meta = cfg.PLOT_METRIC_META.get(metric, None)
    fmt = (
        str(meta.get("fmt", "{:.3g}"))
        if isinstance(meta, dict)
        else "{:.3g}"
    )
    best = _best_mask(piv, metric=metric) if bold_best else None
    gl = str(gray_level or "gray!15").strip() or "gray!15"
    missing = rf"\cellcolor{{{gl}}}--" if gray_missing else "--"

    lines: list[str] = []
    if gray_missing:
        lines.append(r"% Requires: \usepackage[table]{xcolor}")
    # Commands take ONE backslash; only the row terminator is "\\".
    lines += [
        r"% Requires: \usepackage{booktabs}",
        rf"\begin{{{env}}}[t]",
        r"\centering",
        r"\small",
        r"\setlength{\tabcolsep}{4pt}",
        r"\renewcommand{\arraystretch}{1.15}",
        r"\begin{tabular}{" + spec + "}",
        r"\toprule",
        " & ".join(cols) + r" \\",
        r"\midrule",
    ]

    for lc, row in piv.iterrows():
        try:
            vals = [f"{float(lc):g}"]
        except Exception:
            vals = ["--"]
        for c in piv.columns:
            try:
                num = float(row.get(c, np.nan))
            except (TypeError, ValueError):
                num = float("nan")
            if not np.isfinite(num):
                vals.append(missing)
                continue
            try:
                cell = fmt.format(num)
            except Exception:
                cell = "--"
            if best is not None and bool(best.loc[lc, c]):
                cell = rf"\textbf{{{cell}}}"
            vals.append(cell)
        lines.append(" & ".join(vals) + r" \\")

    lines += [
        r"\bottomrule",
        r"\end{tabular}",
        rf"\caption{{{caption}}}",
        rf"\label{{{label}}}",
        rf"\end{{{env}}}",
        "",
    ]
    return "\n".join(lines)
def _json_ready(value: Any) -> Any:
    """Replace NaN/inf floats with ``None``, recursing into containers."""
    if isinstance(value, dict):
        return {k: _json_ready(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_ready(v) for v in value]
    if isinstance(value, (float, np.floating)):
        return float(value) if np.isfinite(value) else None
    return value


def _write_table(
    df: pd.DataFrame,
    *,
    out: Path,
    fmts: list[str],
) -> None:
    """Write ``df`` next to the stem ``out`` in the requested formats.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to export.
    out : pathlib.Path
        Output stem. The extension is appended to its name rather
        than substituted, so a stem containing a dot (for example a
        metric like ``per_horizon_mae.H1``) is not truncated.
    fmts : list of str
        Any of ``csv``, ``json``, ``txt``; other entries are ignored.

    Notes
    -----
    In the JSON export, missing and non-finite numbers are written
    as ``null`` so the file is strict JSON, not the ``NaN`` /
    ``Infinity`` tokens ``json.dumps`` emits by default. One
    ``[OK]`` line is printed for each file written.
    """

    def _target(ext: str) -> Path:
        return out.with_name(out.name + ext)

    if "csv" in fmts:
        p = _target(".csv")
        df.to_csv(p, index=False)
        print(f"[OK] csv -> {p}")
    if "json" in fmts:
        p = _target(".json")
        records = _json_ready(df.to_dict(orient="records"))
        p.write_text(
            json.dumps(records, indent=2), encoding="utf-8"
        )
        print(f"[OK] json -> {p}")
    if "txt" in fmts:
        p = _target(".txt")
        p.write_text(
            df.to_string(index=False), encoding="utf-8"
        )
        print(f"[OK] txt -> {p}")
def _write_tex(
    text: str, *, out: Path, fmts: list[str]
) -> None:
    """Write TeX source to ``<out>.tex`` when ``tex`` is requested.

    Nothing happens unless ``"tex"`` is in ``fmts``. The ``.tex``
    extension is appended to the stem's name rather than
    substituted, so stems containing a dot are not truncated.
    """
    if "tex" not in fmts:
        return
    p = out.with_name(out.name + ".tex")
    p.write_text(text, encoding="utf-8")
    print(f"[OK] tex -> {p}")
def _build_table_s6(
    df: pd.DataFrame,
    *,
    args: argparse.Namespace,
    out: Path,
    fmts: list[str],
) -> None:
    """Export Table S6: lambda_cons x lambda_prior grids per group.

    One grid is built for each (city, physics bucket) pair and each
    selected metric. Cells aggregate duplicate grid points with
    ``--s6-agg``. Every grid is written through ``_write_table`` and,
    when ``tex`` is requested, as a TeX table: one file per metric,
    or one per group with ``--s6-tex-one-file``.

    Nothing is written when either lambda column is missing or no
    requested metric exists in ``df``. A missing ``city`` or
    ``pde_bucket`` column no longer suppresses all output: the rows
    then form a single group tagged ``all``.
    """
    if "lambda_cons" not in df.columns:
        return
    if "lambda_prior" not in df.columns:
        return

    agg = _agg_fn(args.s6_agg)
    unit = str(args.metric_unit)
    metrics = [m for m in _s6_metrics(args) if m in df.columns]
    if not metrics:
        return

    bold_best = utils.str_to_bool(
        getattr(args, "s6_tex_bold_best", "true"),
        default=True,
    )
    gray_missing = utils.str_to_bool(
        getattr(args, "s6_tex_gray_missing", "true"),
        default=True,
    )
    gray_level = (
        str(
            getattr(args, "s6_tex_gray_level", "gray!15")
            or "gray!15"
        ).strip()
        or "gray!15"
    )
    one_file = utils.str_to_bool(
        getattr(args, "s6_tex_one_file", "false"),
        default=False,
    )
    want_tex = "tex" in fmts

    # Coerce the grid axes once so the pivot index/columns match the
    # float levels used for reindexing below.
    grid = df.assign(
        lambda_cons=pd.to_numeric(
            df["lambda_cons"], errors="coerce"
        ),
        lambda_prior=pd.to_numeric(
            df["lambda_prior"], errors="coerce"
        ),
    )

    def _levels(col: str) -> list:
        # ``None`` stands for "no such column: do not filter".
        if col not in grid.columns:
            return [None]
        return sorted(set(grid[col]))

    for city in _levels("city"):
        for pb in _levels("pde_bucket"):
            sub = grid
            if city is not None:
                sub = sub.loc[sub["city"].eq(city)]
            if pb is not None:
                sub = sub.loc[sub["pde_bucket"].eq(pb)]
            if sub.empty:
                continue

            lc_vals = [
                float(x)
                for x in sorted(set(sub["lambda_cons"].dropna()))
            ]
            lp_vals = [
                float(x)
                for x in sorted(set(sub["lambda_prior"].dropna()))
            ]
            if not lc_vals or not lp_vals:
                continue

            city_tag = "all" if city is None else city
            pb_tag = "all" if pb is None else pb

            tex_blocks: list[str] = []
            for m in metrics:
                piv = sub.pivot_table(
                    index="lambda_cons",
                    columns="lambda_prior",
                    values=m,
                    aggfunc=agg,
                )
                piv = piv.reindex(index=lc_vals, columns=lp_vals)
                p2 = out.with_name(
                    out.name + f"__S6__{m}__{city_tag}__{pb_tag}"
                )
                _write_table(piv.reset_index(), out=p2, fmts=fmts)
                if not want_tex:
                    continue

                ttl = _tex_header(m, unit=unit)
                cap = (
                    f"Table S6 ({city_tag}, physics={pb_tag}): "
                    f"$\\lambda_\\mathrm{{cons}}$"
                    f"$\\times$"
                    f"$\\lambda_\\mathrm{{prior}}$ "
                    f"grid for {ttl}."
                )
                lab = f"tab:S6_{city_tag}_{pb_tag}_{m}"
                tex = _tex_pivot(
                    piv,
                    metric=m,
                    unit=unit,
                    sideway=bool(args.sideway),
                    caption=cap,
                    label=lab,
                    bold_best=bold_best,
                    gray_missing=gray_missing,
                    gray_level=gray_level,
                )
                if one_file:
                    tex_blocks.append(tex)
                else:
                    _write_tex(tex, out=p2, fmts=fmts)

            if one_file and tex_blocks and want_tex:
                p3 = out.with_name(
                    out.name + f"__S6__{city_tag}__{pb_tag}"
                )
                _write_tex(
                    "\n".join(tex_blocks),
                    out=p3,
                    fmts=fmts,
                )
def _build_table_s7(
    df: pd.DataFrame,
    *,
    args: argparse.Namespace,
    out: Path,
    fmts: list[str],
) -> None:
    """Export Table S7: metrics aggregated over ablation toggles.

    Toggle flags are added, rows are grouped by the S7 grouping
    columns (missing values form their own group) and each metric is
    aggregated with ``--s7-agg``. The group size ``n`` is placed
    between the grouping columns and the metrics. Output goes to
    ``<out>__S7`` in the requested formats.

    Nothing is written when no grouping column or metric is
    available. The TeX caption names the aggregation actually used.
    """
    x = _add_toggle_flags(df)
    grp_cols = _s7_group_cols(x, args)
    mets = _s7_metrics(x, args)
    if not grp_cols or not mets:
        return

    agg = _agg_fn(args.s7_agg)
    keep = grp_cols + mets
    sub = x.loc[:, [c for c in keep if c in x.columns]].copy()
    if sub.empty:
        return

    gb = sub.groupby(grp_cols, dropna=False)
    out_df = gb[mets].agg(agg).reset_index()
    # Both results come from the same groupby, so group order matches.
    out_df["n"] = gb.size().to_numpy(int)

    # Paper table: n first, then metrics.
    out_df = out_df.loc[:, grp_cols + ["n"] + mets]

    p2 = out.with_name(out.name + "__S7")
    _write_table(out_df, out=p2, fmts=fmts)
    if "tex" in fmts:
        # ``agg`` is "mean" or "median"; reflect it in the caption.
        cap = (
            "Table S7: ablation toggles summary "
            f"(group {agg}s)."
        )
        tex = _to_tex_table(
            out_df,
            cols=list(out_df.columns),
            unit=str(args.metric_unit),
            sideway=bool(args.sideway),
            caption=cap,
            label="tab:S7_ablations",
        )
        _write_tex(tex, out=p2, fmts=fmts)
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def build_ablation_table_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Build and export the ablation table(s) from the command line.

    Pipeline: load records; canonicalise columns; normalise units to
    mm; complete MSE/RMSE; de-duplicate; filter by model and city;
    optionally drop per-horizon columns; convert to the requested
    unit; sort; select columns; write the requested formats. The
    best-per-city and grouped S6/S7 exports follow when requested.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments, forwarded to the argument parser.
    prog : str, optional
        Program name shown in the usage text.

    Raises
    ------
    SystemExit
        When no ablation record could be loaded.
    """
    args = _parse_args(argv, prog=prog)
    utils.ensure_script_dirs()

    df = _load_records(args)
    if df.empty:
        if args.input:
            # --input overrides --root, so report what was tried.
            tried = "\n".join(f"  {s}" for s in args.input)
            raise SystemExit(
                "No ablation records could be loaded from "
                "--input:\n"
                f"{tried}\n"
                "Run sensitivity/ablations first."
            )
        root = utils.as_path(args.root)
        raise SystemExit(
            "No ablation_record*.jsonl found under:\n"
            f"  {root.resolve()}\n"
            "Run sensitivity/ablations first."
        )

    df = _canon_cols(df, max_h=int(args.max_h))
    df = _scale_rows_to_mm(df)
    df = _ensure_mse_rmse(df)
    df = _dedupe_prefer_best(df)
    df = _filter_models(df, args.models)
    df = _filter_cities(df, args.cities)

    if not utils.str_to_bool(args.keep_per_horizon, default=True):
        per_horizon = [
            c
            for c in df.columns
            if str(c).startswith("per_horizon_")
        ]
        df = df.drop(columns=per_horizon, errors="ignore")

    # Convert to requested unit after derived-metric completion.
    df = _to_unit(df, unit=args.metric_unit)
    df = _sort_df(
        df,
        metric=args.sort_by,
        ascending=args.ascending,
    )

    out = _resolve_out(args.out, args.out_dir)
    utils.ensure_dir(out.parent)

    if args.for_paper:
        params, mets = _paper_cols(
            df,
            err_metric=args.err_metric,
            keep_r2=bool(args.keep_r2),
        )
        cols = params + mets
    else:
        cols = _param_cols(df) + _pick_metrics(df, args.metrics)
    cols = [c for c in cols if c in df.columns]
    tab = df.loc[:, cols].copy()

    fmts = [
        s.strip().lower() for s in str(args.formats).split(",")
    ]
    fmts = [f for f in fmts if f] or ["csv", "json"]

    # Same writers as the grouped tables, so outputs stay consistent.
    _write_table(tab, out=out, fmts=fmts)
    if "tex" in fmts:
        if not args.for_paper:
            print(
                "[Warn] --formats tex is recommended with --for-paper "
                "(compact columns)."
            )
        tex = _to_tex_table(
            tab,
            cols=list(tab.columns),
            unit=str(args.metric_unit),
            sideway=bool(args.sideway),
            caption=str(args.caption),
            label=str(args.label),
        )
        _write_tex(tex, out=out, fmts=fmts)

    if args.best_per_city:
        best = _best_by_city(
            tab,
            metric=str(args.sort_by),
            ascending=str(args.ascending),
        )
        p = out.with_name(out.name + "__best_per_city.csv")
        best.to_csv(p, index=False)
        print(f"[OK] best/city -> {p}")

    # -------------------------------------------------
    # Optional grouped tables (S6 / S7)
    # -------------------------------------------------
    groups = [g.lower() for g in _split_csv(args.group_cols)]
    if "all" in groups:
        groups = ["s6", "s7"]
    if "s6" in groups:
        _build_table_s6(df, args=args, out=out, fmts=fmts)
    if "s7" in groups:
        _build_table_s7(df, args=args, out=out, fmts=fmts)

    if args.stdout:
        print("\n" + tab.to_string(index=False))
def main(argv: list[str] | None = None) -> None:
    """Console entry point; delegates to ``build_ablation_table_main``."""
    build_ablation_table_main(argv)


if __name__ == "__main__":
    main()