# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""Plot cross-city transferability of GeoPriorSubsNet.
Uses ``xfer_results.csv`` produced by ``nat.com/xfer_matrix.py`` to show:
- cross-city performance versus calibration mode, including overall MAE (mm)
and overall R²
- the 80% interval coverage-sharpness tradeoff by direction, with
``sharpness80`` (mm) on the x-axis and ``coverage80`` (fraction) on the
y-axis
Panels summarise:
- Nansha → Zhongshan (A_to_B)
- Zhongshan → Nansha (B_to_A)
Example
-------
.. code-block:: bash
python -m scripts.plot_transfer \
--src results/xfer/nansha__zhongshan \
--split val
"""
from __future__ import annotations
import argparse
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from . import config as cfg
from . import utils as u
_DIR_CANON = {
"a_to_b": "A_to_B",
"b_to_a": "B_to_A",
"a_to_a": "A_to_A",
"b_to_b": "B_to_B",
}
_STRAT_MS = {
"baseline": 38,
"xfer": 50,
"warm": 62,
}
[docs]
@dataclass
class TextFlags:
show_legend: bool
show_labels: bool
show_ticklabels: bool
show_title: bool
show_panel_titles: bool
title: str | None
def _canon_dir(x: Any) -> str:
s = str(x).strip()
k = s # .lower()
return _DIR_CANON.get(k, s)
def _to_num(df: pd.DataFrame, col: str) -> None:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")
def _canon_cols(df: pd.DataFrame) -> pd.DataFrame:
aliases = {
"strategy": ("strategy",),
"rescale_mode": ("rescale_mode", "rescale_m"),
"direction": ("direction",),
"source_city": ("source_city", "source_cit"),
"target_city": ("target_city", "target_cit"),
"split": ("split",),
"calibration": ("calibration", "calibratio"),
"overall_mae": ("overall_mae",),
"overall_r2": ("overall_r2",),
"coverage80": ("coverage80", "coverage8"),
"sharpness80": ("sharpness80", "sharpness"),
}
u.ensure_columns(df, aliases=aliases)
for c in (
"strategy",
"rescale_mode",
"split",
"calibration",
"source_city",
"target_city",
):
if c in df.columns:
df[c] = df[c].astype(str).str.strip().str.lower()
if "direction" in df.columns:
df["direction"] = df["direction"].map(_canon_dir)
for c in (
"overall_mae",
"overall_r2",
"coverage80",
"sharpness80",
):
_to_num(df, c)
return df
def _find_xfer_csv(src: Any) -> Path:
pats = None
if hasattr(cfg, "PATTERNS"):
pats = cfg.PATTERNS.get("xfer_results_csv") or None
pats = pats or ("xfer_results.csv", "*xfer_results*.csv")
p = u.find_latest(src, pats, must_exist=True)
if p is None:
raise FileNotFoundError(str(src))
return Path(p)
def _subset(
df: pd.DataFrame,
*,
direction: str,
strategy: str,
split: str,
calib: str,
rescale_mode: str | None,
baseline_rescale: str,
) -> pd.DataFrame:
d = _canon_dir(direction)
s = str(strategy).lower()
sp = str(split).lower()
cm = str(calib).lower()
use_dir = d
use_rm = rescale_mode
if s == "baseline":
use_dir = cfg._BASELINE_MAP.get(d, d)
use_rm = baseline_rescale
m = df["direction"].eq(use_dir)
m &= df["strategy"].eq(s)
m &= df["split"].eq(sp)
m &= df["calibration"].eq(cm)
if use_rm is not None:
m &= df["rescale_mode"].eq(str(use_rm).lower())
return df.loc[m].copy()
def _reduce_vals(
vals: Iterable[float],
*,
how: str,
better: str,
cov_target: float,
) -> float:
a = np.asarray(list(vals), dtype=float)
a = a[np.isfinite(a)]
if a.size == 0:
return float("nan")
how = str(how).lower()
if how == "mean":
return float(a.mean())
if how == "median":
return float(np.median(a))
if better == "max":
return float(a.max())
if better == "cov":
j = int(np.argmin(np.abs(a - float(cov_target))))
return float(a[j])
return float(a.min())
def _pick(
df: pd.DataFrame,
*,
direction: str,
strategy: str,
split: str,
calib: str,
rescale_mode: str | None,
baseline_rescale: str,
col: str,
reduce: str,
better: str,
cov_target: float,
) -> float:
sub = _subset(
df,
direction=direction,
strategy=strategy,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
)
if col not in sub.columns:
return float("nan")
return _reduce_vals(
sub[col].values,
how=reduce,
better=better,
cov_target=cov_target,
)
def _dir_target_city(df: pd.DataFrame, direction: str) -> str:
d = _canon_dir(direction)
sub = df[df["direction"].eq(d)]
if sub.empty:
return ""
return str(sub.iloc[0].get("target_city", "") or "")
def _dir_label(df: pd.DataFrame, direction: str) -> str:
d = _canon_dir(direction)
sub = df[df["direction"].eq(d)]
if sub.empty:
return d
r0 = sub.iloc[0]
sc = u.canonical_city(str(r0.get("source_city", "")))
tc = u.canonical_city(str(r0.get("target_city", "")))
return f"{sc} \u2192 {tc}"
def _plot_bars(
ax: Any,
df: pd.DataFrame,
*,
metric_key: str,
split: str,
directions: list[str],
strategies: list[str],
calib_modes: list[str],
rescale_mode: str | None,
baseline_rescale: str,
reduce: str,
cov_target: float,
text: TextFlags,
) -> None:
col, ylab, better = cfg._METRIC_DEF[str(metric_key)]
n_cal = len(calib_modes)
n_dir = len(directions)
n_str = len(strategies)
x0 = np.arange(n_cal, dtype=float)
group_w = 0.84
dir_w = group_w / max(1, n_dir)
bw = dir_w / max(1, n_str)
for di, d in enumerate(directions):
tgt = _dir_target_city(df, d)
tgt = u.canonical_city(tgt).lower()
face = cfg.CITY_COLORS.get(tgt, "#777777")
for si, s in enumerate(strategies):
for i, cm in enumerate(calib_modes):
v = _pick(
df,
direction=d,
strategy=s,
split=split,
calib=cm,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
col=col,
reduce=reduce,
better=better,
cov_target=cov_target,
)
if not np.isfinite(v):
continue
off_d = -0.5 * group_w + (di + 0.5) * dir_w
off_s = -0.5 * dir_w + (si + 0.5) * bw
xx = x0[i] + off_d + off_s
ax.bar(
xx,
v,
width=bw,
color=face,
alpha=0.86,
hatch=cfg._STRAT_HATCH.get(s, ""),
edgecolor=cfg._STRAT_EDGE.get(
s, "#111111"
),
linewidth=0.6,
)
ax.set_xticks(x0)
ax.set_xticklabels(
[cfg._CAL_LABEL.get(c, c) for c in calib_modes],
rotation=15,
)
if text.show_labels:
ax.set_ylabel(ylab)
else:
ax.set_ylabel("")
if not text.show_ticklabels:
ax.set_xticklabels([])
ax.set_yticklabels([])
if metric_key == "r2":
ax.axhline(0.0, linestyle="--", linewidth=0.7)
ax.grid(False)
def _plot_cov_sharp(
ax: Any,
df: pd.DataFrame,
*,
direction: str,
split: str,
strategies: list[str],
calib_modes: list[str],
rescale_mode: str | None,
baseline_rescale: str,
reduce: str,
cov_target: float,
text: TextFlags,
) -> tuple[list[float], list[float]]:
tgt = _dir_target_city(df, direction)
tgt = u.canonical_city(tgt).lower()
face = cfg.CITY_COLORS.get(tgt, "#777777")
all_cov: list[float] = []
all_shp: list[float] = []
for s in strategies:
for cm in calib_modes:
sub = _subset(
df,
direction=direction,
strategy=s,
split=split,
calib=cm,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
)
if sub.empty:
continue
cov = sub["coverage80"].values
shp = sub["sharpness80"].values
x = _reduce_vals(
shp,
how=reduce,
better="min",
cov_target=cov_target,
)
y = _reduce_vals(
cov,
how=reduce,
better="cov",
cov_target=cov_target,
)
if not (np.isfinite(x) and np.isfinite(y)):
continue
ax.scatter(
x,
y,
s=float(_STRAT_MS.get(s, 50)),
marker=cfg._CAL_MARKER.get(cm, "o"),
facecolors=face,
edgecolors=cfg._STRAT_EDGE.get(s, "#111111"),
linewidths=0.8,
alpha=0.92,
)
all_cov.append(float(y))
all_shp.append(float(x))
ax.axhline(
float(cov_target), linestyle="--", linewidth=0.7
)
if text.show_labels:
ax.set_xlabel("Sharpness80 (mm)")
ax.set_ylabel("Coverage80")
else:
ax.set_xlabel("")
ax.set_ylabel("")
if not text.show_ticklabels:
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_ylim(0.0, 1.0)
ax.grid(False)
return all_cov, all_shp
def _add_legends(
ax: Any,
df: pd.DataFrame,
*,
directions: list[str],
strategies: list[str],
calib_modes: list[str],
) -> None:
dir_handles: list[Any] = []
for d in directions:
lab = _dir_label(df, d)
tgt = _dir_target_city(df, d)
tgt = u.canonical_city(tgt).lower()
face = cfg.CITY_COLORS.get(tgt, "#777777")
dir_handles.append(
Patch(
facecolor=face,
edgecolor="#111111",
linewidth=0.6,
label=lab,
)
)
strat_handles: list[Any] = []
for s in strategies:
strat_handles.append(
Patch(
facecolor="white",
edgecolor=cfg._STRAT_EDGE.get(s, "#111111"),
hatch=cfg._STRAT_HATCH.get(s, ""),
linewidth=0.8,
label=cfg._STRAT_LABEL.get(s, s),
)
)
cal_handles: list[Any] = []
for c in calib_modes:
cal_handles.append(
Line2D(
[],
[],
marker=cfg._CAL_MARKER.get(c, "o"),
linestyle="None",
markerfacecolor="white",
markeredgecolor="#111111",
markersize=5,
label=cfg._CAL_LABEL.get(c, c),
)
)
leg1 = ax.legend(
handles=dir_handles,
frameon=False,
loc="upper left",
title="Direction",
)
ax.add_artist(leg1)
leg2 = ax.legend(
handles=strat_handles,
frameon=False,
loc="upper right",
title="Strategy",
)
ax.add_artist(leg2)
ax.legend(
handles=cal_handles,
frameon=False,
loc="lower right",
title="Calibration",
)
[docs]
def render(
df: pd.DataFrame,
*,
split: str,
strategies: list[str],
calib_modes: list[str],
rescale_mode: str | None,
baseline_rescale: str,
reduce: str,
cov_target: float,
out: Path,
text: TextFlags,
metric_top: str = "mae",
metric_bottom: str = "rmse",
) -> tuple[Path, Path]:
u.set_paper_style()
directions = ["A_to_B", "B_to_A"]
fig = plt.figure(figsize=(7.0, 4.0))
gs = GridSpec(
2,
2,
figure=fig,
width_ratios=(1.18, 1.0),
height_ratios=(1.0, 1.0),
wspace=0.40,
hspace=0.55,
)
ax_mae = fig.add_subplot(gs[0, 0])
ax_r2 = fig.add_subplot(gs[1, 0])
ax_ab = fig.add_subplot(gs[0, 1])
ax_ba = fig.add_subplot(gs[1, 1])
_plot_bars(
ax_mae,
df,
metric_key=str(metric_top).lower(),
split=split,
directions=directions,
strategies=strategies,
calib_modes=calib_modes,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
reduce=reduce,
cov_target=cov_target,
text=text,
)
_plot_bars(
ax_r2,
df,
metric_key=str(metric_bottom).lower(),
split=split,
directions=directions,
strategies=strategies,
calib_modes=calib_modes,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
reduce=reduce,
cov_target=cov_target,
text=text,
)
cov_ab, shp_ab = _plot_cov_sharp(
ax_ab,
df,
direction="A_to_B",
split=split,
strategies=strategies,
calib_modes=calib_modes,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
reduce=reduce,
cov_target=cov_target,
text=text,
)
cov_ba, shp_ba = _plot_cov_sharp(
ax_ba,
df,
direction="B_to_A",
split=split,
strategies=strategies,
calib_modes=calib_modes,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
reduce=reduce,
cov_target=cov_target,
text=text,
)
top_title = str(metric_top).upper()
bottom_title = (
"$R^2$"
if str(metric_bottom).lower() == "r2"
else str(metric_bottom).upper()
)
if text.show_panel_titles:
ax_mae.set_title(f"(a) {top_title} vs calibration")
ax_r2.set_title(f"(a) {bottom_title} vs calibration")
ax_ab.set_title(
"(b) Coverage–sharpness: "
+ _dir_label(df, "A_to_B")
)
ax_ba.set_title(
"(c) Coverage–sharpness: "
+ _dir_label(df, "B_to_A")
)
all_cov = np.asarray(cov_ab + cov_ba, dtype=float)
all_shp = np.asarray(shp_ab + shp_ba, dtype=float)
m = np.isfinite(all_cov) & np.isfinite(all_shp)
if m.any():
ymin = max(0.0, float(all_cov[m].min()) - 0.05)
ymax = min(1.0, float(all_cov[m].max()) + 0.05)
xmin = float(all_shp[m].min())
xmax = float(all_shp[m].max())
pad = 0.05 * (xmax - xmin) if xmax > xmin else 1.0
xmin -= pad
xmax += pad
ax_ab.set_ylim(ymin, ymax)
ax_ba.set_ylim(ymin, ymax)
ax_ab.set_xlim(xmin, xmax)
ax_ba.set_xlim(xmin, xmax)
if text.show_legend:
_add_legends(
ax_ba,
df,
directions=directions,
strategies=strategies,
calib_modes=calib_modes,
)
if text.show_title:
t0 = "Supplementary Figure Sx — Cross-city transferability"
t = u.resolve_title(default=t0, title=text.title)
fig.suptitle(t, x=0.02, y=0.99, ha="left")
stem = out
if stem.suffix:
stem = stem.with_suffix("")
png = stem.with_suffix(".png")
svg = stem.with_suffix(".svg")
fig.savefig(png, bbox_inches="tight")
fig.savefig(svg, bbox_inches="tight")
plt.close(fig)
return png, svg
[docs]
def parse_args(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> Any:
ap = argparse.ArgumentParser( # type: ignore[attr-defined]
prog=prog or "plot-transfer",
description=(
"Supplementary Figure Sx — transferability from xfer CSV"
),
)
ap.add_argument(
"--src",
type=str,
default="results/xfer",
help="Folder to search latest xfer_results.csv",
)
ap.add_argument(
"--xfer-csv",
type=str,
default=None,
help="Explicit xfer_results.csv path (file or dir)",
)
ap.add_argument(
"--split",
type=str,
default="val",
choices=("val", "test"),
help="Which split to plot",
)
ap.add_argument(
"--strategies",
nargs="+",
default=list(cfg._STRAT_DEFAULT),
help="Strategies (baseline xfer warm)",
)
ap.add_argument(
"--calib-modes",
nargs="+",
default=list(cfg._CAL_ORDER),
help="Calibration modes (none source target)",
)
ap.add_argument(
"--rescale-mode",
type=str,
default="strict",
choices=("strict", "as_is"),
help="Rescale mode for xfer/warm",
)
ap.add_argument(
"--baseline-rescale",
type=str,
default="as_is",
choices=("strict", "as_is"),
help="Rescale mode for baseline fetch",
)
ap.add_argument(
"--reduce",
type=str,
default="best",
choices=("best", "mean", "median"),
help="Reduce duplicates per cell",
)
ap.add_argument(
"--cov-target",
type=float,
default=0.80,
help="Target coverage reference",
)
ap.add_argument(
"--metric-top",
type=str,
default="mae",
choices=tuple(cfg._METRIC_DEF.keys()),
help="Metric for the top-left bar panel.",
)
ap.add_argument(
"--metric-bottom",
type=str,
default="rmse",
choices=tuple(cfg._METRIC_DEF.keys()),
help="Metric for the bottom-left bar panel.",
)
u.add_plot_text_args(
ap,
default_out="figureS_xfer_transferability",
)
return ap.parse_args(argv)
def _text_flags(args: Any) -> TextFlags:
return TextFlags(
show_legend=u.str_to_bool(
args.show_legend, default=True
),
show_labels=u.str_to_bool(
args.show_labels, default=True
),
show_ticklabels=u.str_to_bool(
args.show_ticklabels,
default=True,
),
show_title=u.str_to_bool(
args.show_title, default=True
),
show_panel_titles=u.str_to_bool(
args.show_panel_titles,
default=True,
),
title=args.title,
)
[docs]
def figSx_xfer_transferability_main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
u.ensure_script_dirs()
args = parse_args(argv, prog=prog)
text = _text_flags(args)
if args.xfer_csv:
p = Path(args.xfer_csv).expanduser()
if p.is_dir():
csv_p = _find_xfer_csv(p)
else:
csv_p = p
else:
csv_p = _find_xfer_csv(args.src)
df = pd.read_csv(csv_p)
df = _canon_cols(df)
split = str(args.split).lower()
df = df[df["split"].eq(split)].copy()
if df.empty:
raise SystemExit("No rows after split filter.")
strategies = [str(s).lower() for s in args.strategies]
calib = [str(c).lower() for c in args.calib_modes]
out = u.resolve_fig_out(args.out)
png, svg = render(
df,
split=split,
strategies=strategies,
calib_modes=calib,
rescale_mode=str(args.rescale_mode).lower(),
baseline_rescale=str(args.baseline_rescale).lower(),
reduce=str(args.reduce).lower(),
cov_target=float(args.cov_target),
out=out,
text=text,
metric_top=str(args.metric_top).lower(),
metric_bottom=str(args.metric_bottom).lower(),
)
print(f"[OK] Wrote {png}")
print(f"[OK] Wrote {svg}")
[docs]
def main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
figSx_xfer_transferability_main(argv, prog=prog)
if __name__ == "__main__":
main()