# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
#
r""" Plot 1D physics sensitivity profiles.
Complements:
- Supp. Fig. S6 (ablations: skill + coverage),
- Supp. Fig. S7 (2D residual heatmaps),
by providing 1D "profiles" (physics-on only):
For each city (rows) and each metric (metric_prior, metric_cons),
we draw two line profiles:
- metric vs λ_prior (mean over λ_cons, pde_mode='both')
- metric vs λ_cons (mean over λ_prior, pde_mode='both')
Layout
------
The figure has 2 rows × 4 columns.
- Row 1: City A (default: Nansha), with metric-prior and
metric-cons profiles against ``λ_prior`` and ``λ_cons``.
- Row 2: City B (default: Zhongshan), with the same four panels.
Data source
-----------
We scan for JSONL under
``<root>/**/ablation_records/ablation_record*.jsonl``.
Outputs
-------
- Figure: ``<out>.png`` and ``<out>.pdf``
- Tidy table copy:
``appendix_table_A1_phys_profiles_tidy.csv`` written next to the
figure.
API conventions
---------------
- Style via ``scripts.utils.set_paper_style()``
- Output path via ``scripts.utils.resolve_fig_out()``
- JSONL discovery via ``cfg.PATTERNS["ablation_record_jsonl"]`` plus
``scripts.utils.find_all()``
- A ``main(argv)`` wrapper calls a dedicated ``*_main(argv)``
function.
Linting / format
----------------
- black + ruff, line length <= 62
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from . import config as cfg
from . import utils
# Metrics where "lower is better" (best point highlight).
_LOWER_IS_BETTER = {
"mae",
"mse",
"sharpness80",
"epsilon_prior",
"epsilon_cons",
}
# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------
def _parse_args(
argv: list[str] | None,
*,
prog: str | None = None,
) -> argparse.Namespace:
p = argparse.ArgumentParser(
prog=prog or "plot-physics-profiles",
description="Appendix A1: 1D physics profiles",
)
p.add_argument(
"--root",
type=str,
default="results",
help=(
"Root to scan for "
"**/ablation_records/ablation_record*.jsonl"
),
)
# Backward compat: explicit out dir override.
p.add_argument(
"--out-dir",
type=str,
default=None,
help="Optional output dir override.",
)
p.add_argument(
"--font",
type=int,
default=cfg.PAPER_FONT,
)
p.add_argument(
"--dpi",
type=int,
default=cfg.PAPER_DPI,
)
p.add_argument(
"--city-a",
type=str,
default="Nansha",
help="City A name as recorded in JSONL.",
)
p.add_argument(
"--city-b",
type=str,
default="Zhongshan",
help="City B name as recorded in JSONL.",
)
p.add_argument(
"--metric-prior",
type=str,
default="epsilon_prior",
choices=[
"epsilon_prior",
"coverage80",
"sharpness80",
"r2",
"mae",
"mse",
],
help="Metric used for prior profiles.",
)
p.add_argument(
"--metric-cons",
type=str,
default="epsilon_cons",
choices=[
"epsilon_cons",
"coverage80",
"sharpness80",
"r2",
"mae",
"mse",
],
help="Metric used for cons profiles.",
)
utils.add_plot_text_args(
p,
default_out="appendix_fig_A1_phys_profiles",
)
# Optional highlight of best point per panel.
p.add_argument(
"--show-best",
type=str,
default="true",
help="Mark best point (true/false).",
)
return p.parse_args(argv)
# ---------------------------------------------------------------------
# I/O
# ---------------------------------------------------------------------
def _read_records(root: Path) -> pd.DataFrame:
rows: list[dict] = []
files = utils.find_all(
root,
cfg.PATTERNS.get("ablation_record_jsonl", ()),
)
for fp in files:
try:
with fp.open("r", encoding="utf-8") as f:
for line in f:
s = line.strip()
if not s:
continue
try:
rec = json.loads(s)
except Exception:
continue
if isinstance(rec, dict):
rec["_src"] = str(fp)
rows.append(rec)
except Exception:
continue
df = pd.DataFrame(rows)
if df.empty:
return df
for col in ["lambda_prior", "lambda_cons"]:
if col in df.columns:
df[col] = pd.to_numeric(
df[col],
errors="coerce",
)
return df
# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def _city_mask(df: pd.DataFrame, city: str) -> pd.Series:
return (
df["city"]
.astype(str)
.str.strip()
.str.lower()
.eq(str(city).strip().lower())
)
def _metric_label(name: str) -> str:
k = str(name).strip()
# Prefer central physics labels when available.
if k in cfg.PHYS_LABELS:
return cfg.PHYS_LABELS[k]
kl = k.lower()
if kl == "coverage80":
return "Coverage (80%)"
if kl == "sharpness80":
return "Sharpness (80%)"
if kl == "r2":
return "R\u00b2"
if kl == "mae":
return "MAE"
if kl == "mse":
return "MSE"
return k
def _axes_cleanup(ax: plt.Axes) -> None:
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
def _only_physics_on(df: pd.DataFrame) -> pd.DataFrame:
# Some JSONLs may omit pde_mode. In that case we
# keep all rows (treat as "physics on").
if "pde_mode" not in df.columns:
return df
return df.loc[df["pde_mode"].astype(str).eq("both")]
def _profile_over(
df: pd.DataFrame,
*,
city: str,
metric: str,
axis: str,
) -> tuple[np.ndarray, np.ndarray]:
if axis not in ("lambda_prior", "lambda_cons"):
raise ValueError(f"bad axis: {axis!r}")
sub = df.loc[_city_mask(df, city)].copy()
sub = _only_physics_on(sub)
if sub.empty or metric not in sub.columns:
return np.array([]), np.array([])
g = (
sub.groupby(axis, dropna=True)[metric]
.mean()
.reset_index()
.dropna(subset=[axis])
.sort_values(axis)
)
if g.empty:
return np.array([]), np.array([])
x = g[axis].to_numpy(dtype=float)
y = pd.to_numeric(
g[metric],
errors="coerce",
).to_numpy(dtype=float)
ok = np.isfinite(x) & np.isfinite(y)
return x[ok], y[ok]
def _best_idx(y: np.ndarray, *, metric: str) -> int | None:
if y.size == 0:
return None
if not np.isfinite(y).any():
return None
try:
if metric.lower() in _LOWER_IS_BETTER:
return int(np.nanargmin(y))
return int(np.nanargmax(y))
except Exception:
return None
def _plot_profile_panel(
ax: plt.Axes,
df: pd.DataFrame,
*,
city: str,
metric: str,
axis: str,
color: str,
show_best: bool,
show_labels: bool,
show_ticks: bool,
show_title: bool,
) -> None:
x, y = _profile_over(
df,
city=city,
metric=metric,
axis=axis,
)
if x.size == 0:
ax.set_axis_off()
return
ax.plot(
x,
y,
marker="o",
linewidth=1.0,
color=color,
)
if show_best:
idx = _best_idx(y, metric=metric)
if idx is not None:
ax.scatter(
[x[idx]],
[y[idx]],
s=35,
facecolors="none",
edgecolors="black",
linewidths=1.0,
)
if show_ticks:
ax.tick_params(axis="both", which="both")
else:
ax.set_xticklabels([])
ax.set_yticklabels([])
if show_labels:
xl = (
r"$\lambda_{\mathrm{prior}}$"
if axis == "lambda_prior"
else r"$\lambda_{\mathrm{cons}}$"
)
ax.set_xlabel(xl)
ax.set_ylabel(_metric_label(metric))
else:
ax.set_xlabel("")
ax.set_ylabel("")
if show_title:
which = (
r"vs $\lambda_{\mathrm{prior}}$"
if axis == "lambda_prior"
else r"vs $\lambda_{\mathrm{cons}}$"
)
ax.set_title(
f"{city} — {_metric_label(metric)} {which}",
loc="left",
pad=4,
fontweight="bold",
)
_axes_cleanup(ax)
def _resolve_out(
*,
out: str,
out_dir: str | None,
) -> Path:
if out_dir:
base = Path(out_dir).expanduser()
return (base / Path(out).expanduser()).resolve()
return utils.resolve_fig_out(out)
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
[docs]
def figA1_phys_profiles_main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
args = _parse_args(argv, prog=prog)
utils.set_paper_style(
fontsize=int(args.font),
dpi=int(args.dpi),
)
show_labels = utils.str_to_bool(
args.show_labels,
default=True,
)
show_ticks = utils.str_to_bool(
args.show_ticklabels,
default=True,
)
show_title = utils.str_to_bool(
args.show_title,
default=True,
)
show_pan_t = utils.str_to_bool(
args.show_panel_titles,
default=True,
)
show_best = utils.str_to_bool(
args.show_best,
default=True,
)
root = utils.as_path(args.root)
df = _read_records(root)
if df.empty:
raise SystemExit(
"No ablation_record*.jsonl found under:\n"
f" {root.resolve()}\n"
"Run ablations with the logger enabled first."
)
city_a = utils.canonical_city(args.city_a)
city_b = utils.canonical_city(args.city_b)
out = _resolve_out(out=args.out, out_dir=args.out_dir)
utils.ensure_dir(out.parent)
tidy_csv = (
out.parent
/ "appendix_table_A1_phys_profiles_tidy.csv"
)
df.to_csv(tidy_csv, index=False)
print(f"[OK] table -> {tidy_csv}")
ca = cfg.CITY_COLORS.get(city_a, "#1F78B4")
cb = cfg.CITY_COLORS.get(city_b, "#E31A1C")
fig = plt.figure(figsize=(10.0, 6.5))
gs = fig.add_gridspec(
nrows=2,
ncols=4,
left=0.06,
right=0.98,
top=0.92,
bottom=0.10,
hspace=0.35,
wspace=0.35,
)
# Row 1: city A
_plot_profile_panel(
fig.add_subplot(gs[0, 0]),
df,
city=city_a,
metric=args.metric_prior,
axis="lambda_prior",
color=ca,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[0, 1]),
df,
city=city_a,
metric=args.metric_prior,
axis="lambda_cons",
color=ca,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[0, 2]),
df,
city=city_a,
metric=args.metric_cons,
axis="lambda_prior",
color=ca,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[0, 3]),
df,
city=city_a,
metric=args.metric_cons,
axis="lambda_cons",
color=ca,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
# Row 2: city B
_plot_profile_panel(
fig.add_subplot(gs[1, 0]),
df,
city=city_b,
metric=args.metric_prior,
axis="lambda_prior",
color=cb,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[1, 1]),
df,
city=city_b,
metric=args.metric_prior,
axis="lambda_cons",
color=cb,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[1, 2]),
df,
city=city_b,
metric=args.metric_cons,
axis="lambda_prior",
color=cb,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
_plot_profile_panel(
fig.add_subplot(gs[1, 3]),
df,
city=city_b,
metric=args.metric_cons,
axis="lambda_cons",
color=cb,
show_best=show_best,
show_labels=show_labels,
show_ticks=show_ticks,
show_title=show_pan_t,
)
if show_title:
default = "Appendix Fig. A1 • 1D physics sensitivity profiles"
ttl = utils.resolve_title(
default=default,
title=args.title,
)
fig.suptitle(
ttl,
fontsize=11,
fontweight="bold",
)
png = out.with_suffix(".png")
pdf = out.with_suffix(".pdf")
fig.savefig(png, bbox_inches="tight")
fig.savefig(pdf, bbox_inches="tight")
print(f"[OK] figs -> {png} | {pdf}")
[docs]
def main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
figA1_phys_profiles_main(argv, prog=prog)
if __name__ == "__main__":
main()