# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for plotting core ablation results."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from . import config as cfg
from . import utils
def _pick_rmse(
    phys: dict[str, Any],
    diag: dict[str, Any],
    mse: float,
) -> float:
    """Resolve an RMSE value from the available metric payloads.

    Sources are tried in order and the first non-NaN value wins:

    1. ``phys["point_metrics"]["rmse"]``
    2. ``phys["metrics_evaluate"]["subs_pred_rmse"]``
    3. ``"rmse"`` in the flattened evaluation diagnostics
    4. ``sqrt(max(0, mse))`` when ``mse`` is finite

    Parameters
    ----------
    phys : dict
        Physics payload; the two metric sub-dicts are optional.
    diag : dict
        Evaluation diagnostics; skipped when falsy.
    mse : float
        Mean squared error used as the last-resort fallback.

    Returns
    -------
    float
        The resolved RMSE, or NaN when no source provides one.

    Notes
    -----
    Assumes ``utils.to_float`` yields NaN for missing or invalid
    values -- TODO confirm against ``utils``.
    """
    point = phys.get("point_metrics") or {}
    value = (
        utils.to_float(point.get("rmse"))
        if "rmse" in point
        else float("nan")
    )

    evaluated = phys.get("metrics_evaluate") or {}
    if np.isnan(value):
        value = utils.to_float(evaluated.get("subs_pred_rmse"))

    # Diagnostics are flattened whenever present, even if a value was
    # already found above.
    flattened = utils.flatten_eval_diag(diag) if diag else {}
    if np.isnan(value):
        value = utils.to_float(flattened.get("rmse"))

    # Last resort: derive RMSE from MSE, clamping negatives to zero.
    if np.isnan(value) and np.isfinite(mse):
        value = float(np.sqrt(max(0.0, float(mse))))

    return float(value)
def _collect_one(
    *,
    city: str,
    variant: str,
    src: Path,
) -> dict[str, Any]:
    """Build one metrics record for a single (city, variant) run.

    Parameters
    ----------
    city : str
        City name stored verbatim in the record.
    variant : str
        Variant tag stored verbatim in the record.
    src : Path
        Location handed to ``utils.find_phys_json`` and
        ``utils.find_eval_diag_json`` to locate the two JSON files.

    Returns
    -------
    dict
        Keys ``city``, ``variant``, ``r2``, ``mae``, ``mse``, ``rmse``,
        ``coverage80``, ``sharpness80``, ``phys_json`` and ``diag_json``.
        The two path entries are empty strings when no file was found.
    """
    phys_path = utils.find_phys_json(src)
    diag_path = utils.find_eval_diag_json(src)

    phys_raw = utils.safe_load_json(phys_path)
    diag = utils.safe_load_json(diag_path)
    # NOTE(review): presumably rescales the physics payload to
    # millimetres -- confirm against ``utils.phys_json_to_mm``.
    phys = utils.phys_json_to_mm(phys_raw)

    r2, mae, mse = utils.pick_point_metrics(phys, diag)
    rmse = _pick_rmse(phys, diag, mse)
    coverage, sharpness = utils.pick_interval_metrics(phys, diag)

    record: dict[str, Any] = {"city": city, "variant": variant}
    record.update(
        r2=r2,
        mae=mae,
        mse=mse,
        rmse=rmse,
        coverage80=coverage,
        sharpness80=sharpness,
    )
    # Keep the provenance of each record; "" marks a missing file.
    record["phys_json"] = str(phys_path) if phys_path else ""
    record["diag_json"] = str(diag_path) if diag_path else ""
    return record
[docs]
def collect_fig3_metrics(
    *,
    cities: list[str],
    ns_with: str | None,
    ns_no: str | None,
    zh_with: str | None,
    zh_no: str | None,
) -> pd.DataFrame:
    """Collect one metrics row per available (city, variant) run.

    Parameters
    ----------
    cities : list of str
        Cities to include, in output order. Only ``"Nansha"`` and
        ``"Zhongshan"`` have run sources; other names contribute no rows.
    ns_with, ns_no : str or None
        Run locations for Nansha with and without physics. Falsy values
        are skipped.
    zh_with, zh_no : str or None
        Run locations for Zhongshan with and without physics. Falsy
        values are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per collected run, ordered by city and then
        ``with-phys`` before ``no-phys``. When nothing is collected the
        frame is empty but still carries the record columns, so
        consumers can select ``city`` / ``variant`` / metric columns
        without hitting a ``KeyError``.
    """
    # Column layout of a record built by ``_collect_one``; used only to
    # give the empty result a stable schema.
    empty_columns = [
        "city",
        "variant",
        "r2",
        "mae",
        "mse",
        "rmse",
        "coverage80",
        "sharpness80",
        "phys_json",
        "diag_json",
    ]

    runs_by_city = {
        "Nansha": (("with-phys", ns_with), ("no-phys", ns_no)),
        "Zhongshan": (("with-phys", zh_with), ("no-phys", zh_no)),
    }

    rows: list[dict[str, Any]] = [
        _collect_one(
            city=city,
            variant=variant,
            src=Path(src).expanduser(),
        )
        for city in cities
        for variant, src in runs_by_city.get(city, ())
        if src
    ]

    if not rows:
        return pd.DataFrame(columns=empty_columns)
    return pd.DataFrame(rows)
def _add_panel_label(
    ax: plt.Axes,
    letter: str,
    enabled: bool,
) -> None:
    """Write a bold panel letter near the top-left corner of *ax*.

    The label is positioned in axes coordinates at ``(-0.14, 1.06)``,
    i.e. slightly outside the axes box. Nothing is drawn when *enabled*
    is false.
    """
    if enabled:
        text_style = {
            "transform": ax.transAxes,
            "fontsize": 12,
            "fontweight": "bold",
            "va": "top",
            "ha": "left",
        }
        ax.text(-0.14, 1.06, letter, **text_style)
def _bar_values(
    ax: plt.Axes,
    bars: Any,
    fmt: str,
    enabled: bool,
) -> None:
    """Annotate each bar with its formatted height.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes that receives the text labels.
    bars : iterable
        Bar patches exposing ``get_height``, ``get_x`` and ``get_width``.
    fmt : str
        ``str.format`` template applied to the height, e.g. ``"{:.2f}"``.
    enabled : bool
        When false the function does nothing.

    Notes
    -----
    Bars with a non-finite height are skipped. Labels sit just outside
    the free end of the bar: above it for non-negative heights and below
    it for negative ones, so a negative value is not printed on top of
    its own bar.
    """
    if not enabled:
        return

    for bar in bars:
        height = bar.get_height()
        if not np.isfinite(height):
            continue
        ax.text(
            bar.get_x() + bar.get_width() / 2.0,
            height,
            fmt.format(height),
            ha="center",
            # Anchor the text on the side facing away from the baseline.
            va="top" if height < 0 else "bottom",
            fontsize=8,
        )
def _metric_meta(
    key: str,
) -> tuple[str, str, str]:
    """Return ``(title, ylabel, fmt)`` display metadata for a metric.

    The entry is read from ``cfg.PLOT_METRIC_META``. Its ``title`` and
    ``ylabel`` templates are formatted with the entry's ``unit`` (empty
    string when absent or falsy); ``fmt`` defaults to ``"{:.2f}"``.

    Raises
    ------
    KeyError
        If *key* has no (non-empty) entry in ``cfg.PLOT_METRIC_META``, or
        the entry lacks ``title`` / ``ylabel``.
    """
    name = str(key)
    meta = cfg.PLOT_METRIC_META.get(name)
    if not meta:
        raise KeyError(name)

    unit_txt = str(meta.get("unit", "") or "")
    title_txt, ylabel_txt = (
        str(meta[field]).format(unit=unit_txt)
        for field in ("title", "ylabel")
    )
    value_fmt = str(meta.get("fmt", "{:.2f}"))
    return title_txt, ylabel_txt, value_fmt
[docs]
def plot_fig3_core_ablation(
    df: pd.DataFrame,
    *,
    cities: list[str],
    core_metric: str,
    err_metric: str,
    out: str,
    out_csv: str,
    out_tex: str | None,
    out_xlsx: str | None,
    dpi: int,
    show_legend: bool,
    show_labels: bool,
    show_ticklabels: bool,
    show_title: bool,
    show_panel_titles: bool,
    show_values: bool,
    show_panel_labels: bool,
    title: str | None,
) -> None:
    """Draw the seven-panel core/ablation figure and export its table.

    The figure uses a 2 x 4 grid. Panel ``a`` spans both rows of the
    first column and shows the core metric of the ``with-phys`` runs.
    Panels ``b`` and ``c`` show the complementary metric and the error
    metric for the same runs. Panels ``d`` to ``g`` compare ``with-phys``
    against ``no-phys`` for the core metric, the error metric,
    ``coverage80`` and ``sharpness80``.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per (city, variant) with the metric columns. The frame is
        copied and its ``city`` column is passed through
        ``utils.canonical_city``. Missing rows or columns are plotted as
        NaN (empty bars).
    cities : list of str
        Cities in plotting order; compared verbatim against the
        canonicalised ``city`` column.
    core_metric : str
        ``"mae"`` or ``"r2"`` (case-insensitive; falsy means ``"mae"``).
        Selects the metric of panels ``a`` and ``d``; the other one goes
        to panel ``b``.
    err_metric : str
        ``"rmse"`` or ``"mse"``; shown in panels ``c`` and ``e``.
    out : str
        Figure target resolved by ``utils.resolve_fig_out``. Any suffix
        is dropped and both ``.png`` and ``.svg`` files are written.
    out_csv : str
        CSV target resolved by ``utils.resolve_out_out``; always written.
    out_tex, out_xlsx : str or None
        Optional LaTeX / Excel targets; skipped when falsy.
    dpi : int
        Resolution of the PNG output.
    show_legend, show_labels, show_ticklabels, show_title : bool
        Toggle the figure legend, y-axis labels, tick labels and the
        figure title.
    show_panel_titles, show_values, show_panel_labels : bool
        Toggle per-panel titles, bar value annotations and panel letters.
    title : str or None
        Optional title override handed to ``utils.resolve_title``.

    Raises
    ------
    ValueError
        If ``err_metric`` or ``core_metric`` is not one of the accepted
        values.
    KeyError
        If a plotted metric has no entry in ``cfg.PLOT_METRIC_META``.
    """
    utils.ensure_script_dirs()
    utils.set_paper_style()

    df = df.copy()
    has_keys = {"city", "variant"} <= set(df.columns)
    if "city" in df.columns:
        df["city"] = df["city"].map(utils.canonical_city)
    if df.empty or not has_keys:
        print("[WARN] no metric rows to plot; panels will be empty")

    if err_metric not in {"rmse", "mse"}:
        raise ValueError("err_metric must be rmse or mse")
    err_key = err_metric

    cm = str(core_metric or "mae").strip().lower()
    if cm not in {"mae", "r2"}:
        raise ValueError("core_metric must be mae or r2")
    # The ablation panel ``d`` always uses the same metric as panel ``a``.
    big_key, top_key = ("mae", "r2") if cm == "mae" else ("r2", "mae")

    big_t, big_y, big_f = _metric_meta(big_key)
    top_t, top_y, top_f = _metric_meta(top_key)
    err_t, err_y, err_f = _metric_meta(err_key)
    cov_t, cov_y, cov_f = _metric_meta("coverage80")
    shp_t, shp_y, shp_f = _metric_meta("sharpness80")

    x = np.arange(len(cities))
    colors = [cfg.CITY_COLORS.get(c, "#777777") for c in cities]
    width = 0.36

    def _vals(metric: str, variant: str) -> np.ndarray:
        # One value per city; NaN marks a missing row or column.
        if not has_keys or metric not in df.columns:
            return np.full(len(cities), np.nan, dtype=float)
        outv: list[float] = []
        for c in cities:
            sub = df[(df["city"] == c) & (df["variant"] == variant)]
            if sub.empty:
                outv.append(float("nan"))
            else:
                # Only the first matching row is used.
                outv.append(float(sub.iloc[0][metric]))
        return np.asarray(outv, dtype=float)

    def _style_xaxis(ax: plt.Axes) -> None:
        # Either label the bars with city names or strip tick labels.
        if show_ticklabels:
            ax.set_xticks(x, cities)
        else:
            ax.set_xticks([])
            ax.tick_params(labelleft=False)

    def _single(
        ax: plt.Axes,
        values: np.ndarray,
        fmt: str,
        title_txt: str,
        ylabel: str,
    ) -> None:
        # One solid bar per city (``with-phys`` runs only).
        bars = ax.bar(x, values, color=colors, edgecolor="none")
        _bar_values(ax, bars, fmt, show_values)
        if show_panel_titles:
            ax.set_title(title_txt)
        if show_labels:
            ax.set_ylabel(ylabel)
        _style_xaxis(ax)

    def _grouped(
        ax: plt.Axes,
        metric: str,
        fmt: str,
        title_txt: str,
        ylabel: str,
    ) -> None:
        # Paired bars per city: solid = with physics, hatched = without.
        b1 = ax.bar(
            x - width / 2.0,
            _vals(metric, "with-phys"),
            width=width,
            color=colors,
            edgecolor="none",
            label="with physics",
        )
        b2 = ax.bar(
            x + width / 2.0,
            _vals(metric, "no-phys"),
            width=width,
            color=colors,
            edgecolor="black",
            fill=False,
            hatch="///",
            label="no physics",
        )
        if metric == "r2":
            # R2 can be negative; mark the zero baseline.
            ax.axhline(
                0.0,
                linestyle="--",
                linewidth=0.6,
                alpha=0.6,
            )
        _bar_values(ax, b1, fmt, show_values)
        _bar_values(ax, b2, fmt, show_values)
        _style_xaxis(ax)
        if show_panel_titles:
            ax.set_title(title_txt)
        ax.grid(axis="y", alpha=0.2)
        if show_labels:
            ax.set_ylabel(ylabel)

    fig = plt.figure(figsize=(10.0, 4.5))
    try:
        gs = GridSpec(
            2,
            4,
            figure=fig,
            wspace=0.9,
            hspace=0.6,
        )
        ax_a = fig.add_subplot(gs[:, 0])
        ax_b = fig.add_subplot(gs[0, 1])
        ax_c = fig.add_subplot(gs[1, 1])
        ax_d = fig.add_subplot(gs[0, 2])
        ax_e = fig.add_subplot(gs[1, 2])
        ax_f = fig.add_subplot(gs[0, 3])
        ax_g = fig.add_subplot(gs[1, 3])

        panels = (ax_a, ax_b, ax_c, ax_d, ax_e, ax_f, ax_g)
        for letter, ax in zip("abcdefg", panels):
            _add_panel_label(ax, letter, show_panel_labels)

        _single(
            ax_a,
            _vals(big_key, "with-phys"),
            big_f,
            "Core (with physics)",
            big_y,
        )
        _single(ax_b, _vals(top_key, "with-phys"), top_f, top_t, top_y)
        _single(ax_c, _vals(err_key, "with-phys"), err_f, err_t, err_y)

        _grouped(ax_d, big_key, big_f, f"{big_t} (with vs no)", big_y)
        _grouped(ax_e, err_key, err_f, f"{err_t} (with vs no)", err_y)
        _grouped(ax_f, "coverage80", cov_f, cov_t, cov_y)
        _grouped(ax_g, "sharpness80", shp_f, shp_t, shp_y)

        if show_legend:
            # All grouped panels share the same two entries; reuse ``d``.
            h, lab = ax_d.get_legend_handles_labels()
            fig.legend(
                h,
                lab,
                loc="upper right",
                bbox_to_anchor=(0.98, 0.98),
                frameon=False,
            )

        if show_title:
            ttl = utils.resolve_title(
                default="Fig. 3 — Core & ablation",
                title=title,
            )
            fig.suptitle(ttl, x=0.02, ha="left")

        fig_p = utils.resolve_fig_out(out)
        if fig_p.suffix:
            # Both formats are always written, so drop a given suffix.
            fig_p = fig_p.with_suffix("")
        fig.savefig(
            str(fig_p) + ".png", dpi=dpi, bbox_inches="tight"
        )
        fig.savefig(str(fig_p) + ".svg", bbox_inches="tight")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

    out_csv_p = utils.resolve_out_out(out_csv)
    df.to_csv(out_csv_p, index=False)

    out_xlsx_p = None
    if out_xlsx:
        out_xlsx_p = utils.resolve_out_out(out_xlsx)
        df.to_excel(out_xlsx_p, index=False)

    out_tex_p = None
    if out_tex:
        out_tex_p = utils.resolve_out_out(out_tex)
        tex = df.to_latex(index=False)
        out_tex_p.write_text(tex, encoding="utf-8")

    print(f"[OK] wrote {fig_p}.png/.svg")
    print(f"[OK] wrote {out_csv_p}")
    if out_xlsx_p is not None:
        print(f"[OK] wrote {out_xlsx_p}")
    if out_tex_p is not None:
        print(f"[OK] wrote {out_tex_p}")
def _add_plot_fig3_args(ap: argparse.ArgumentParser) -> None:
    """Register the command-line options of the core-ablation plot.

    Adds, in order: the shared city flags, the four run-source paths,
    the metric selectors, the DPI, the table outputs, the string-valued
    on/off switches, and the shared plot text options (default output
    name ``fig3-core-ablation``).
    """
    utils.add_city_flags(ap, default_both=True)

    # Run sources: one optional path per city/variant combination.
    for flag in ("--ns-with", "--ns-no", "--zh-with", "--zh-no"):
        ap.add_argument(flag, type=str, default=None)

    ap.add_argument(
        "--core-metric",
        type=str,
        choices=["mae", "r2"],
        default="mae",
        help=("Big core metric. mae=new default, r2=legacy."),
    )
    ap.add_argument(
        "--err-metric",
        type=str,
        choices=["rmse", "mse"],
        default="mse",
    )
    ap.add_argument("--dpi", type=int, default=cfg.PAPER_DPI)

    # Table outputs; the Excel export is off unless a name is given.
    table_outputs = (
        ("--out-csv", "ext-table-fig3-metrics.csv"),
        ("--out-tex", "ext-table-fig3-metrics.tex"),
        ("--out-xlsx", None),
    )
    for flag, default in table_outputs:
        ap.add_argument(flag, type=str, default=default)

    # Switches are parsed as strings and converted by the caller.
    for flag in ("--write-tex", "--show-values", "--show-panel-labels"):
        ap.add_argument(flag, type=str, default="true")

    utils.add_plot_text_args(
        ap,
        default_out="fig3-core-ablation",
    )
[docs]
def plot_fig3_core_ablation_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Parse CLI arguments, collect the metrics and render the figure.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse; ``None`` lets ``argparse`` read ``sys.argv``.
    prog : str or None
        Program name shown in the help text; defaults to
        ``"plot-core-ablation"``.
    """
    parser = argparse.ArgumentParser(
        prog=prog or "plot-core-ablation",
        description="Fig. 3 core + ablation.",
    )
    _add_plot_fig3_args(parser)
    args = parser.parse_args(argv)

    # Fall back to both cities when the selection resolves to nothing.
    cities = utils.resolve_cities(args) or ["Nansha", "Zhongshan"]

    df = collect_fig3_metrics(
        cities=cities,
        ns_with=args.ns_with,
        ns_no=args.ns_no,
        zh_with=args.zh_with,
        zh_no=args.zh_no,
    )

    def _on(name: str) -> bool:
        # Every switch arrives as a string and defaults to enabled.
        return utils.str_to_bool(getattr(args, name), default=True)

    toggle_names = (
        "show_legend",
        "show_labels",
        "show_ticklabels",
        "show_title",
        "show_panel_titles",
        "show_values",
        "show_panel_labels",
    )
    toggles = {name: _on(name) for name in toggle_names}

    # ``--write-tex false`` suppresses the LaTeX export entirely.
    out_tex = args.out_tex if _on("write_tex") else None

    plot_fig3_core_ablation(
        df,
        cities=cities,
        core_metric=args.core_metric,
        err_metric=args.err_metric,
        out=args.out,
        out_csv=args.out_csv,
        out_tex=out_tex,
        out_xlsx=args.out_xlsx,
        dpi=int(args.dpi),
        title=args.title,
        **toggles,
    )
[docs]
def main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Entry point that forwards to ``plot_fig3_core_ablation_main``.

    Parameters
    ----------
    argv : list of str or None
        Command-line arguments, passed through unchanged.
    prog : str or None
        Program name, passed through unchanged.
    """
    plot_fig3_core_ablation_main(argv=argv, prog=prog)
# Script entry point: run the CLI with the process arguments.
# NOTE(review): this module uses relative imports (``from . import ...``),
# so it presumably has to be launched as part of its package
# (``python -m ...``) rather than as a stand-alone file -- confirm.
if __name__ == "__main__":
    main()
# Usage examples
# --------------
# Default layout (MAE as the core metric, MSE as the error metric):
#     python -m scripts plot-core-ablation
# Default layout with RMSE as the error metric:
#     python -m scripts plot-core-ablation \
#         --err-metric rmse
# Legacy layout (R² as the core metric):
#     python -m scripts plot-core-ablation \
#         --core-metric r2