# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""Plot impact transferability.
This figure extends ``plot_transfer.py`` by adding
decision-maker oriented diagnostics:
Panels
------
(a) Retention vs target baseline (overall):
- R² retention = R² / R²_baseline
- MAE retention = MAE_baseline / MAE
(b) Horizon retention (H1–H3) for R².
(c) Coverage–sharpness tradeoff (@80) per direction.
(d) Threshold risk skill (optional if eval CSVs exist):
- Reliability diagram for exceedance
- Brier score summary
(e) Hotspot stability (optional):
- Jaccard@K overlap of top-K hotspots
- Spearman rank correlation on overlap
- Optional error bars (mean ± std)
- Optional time-series small multiples
Inputs
------
Requires ``xfer_results.csv`` (from nat.com/xfer_matrix.py).
For panel (d), it also uses ``xfer_results.json``
if available, to locate per-job ``csv_eval`` paths.
Run (module form required)
--------------------------
.. code-block:: bash
python -m scripts plot-xfer-impact \
--src results/xfer/nansha__zhongshan \
--split val \
--calib source
"""
from __future__ import annotations
import argparse
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import (
GridSpec,
GridSpecFromSubplotSpec,
)
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from geoprior.utils.transfer import xfer_risk
from . import config as cfg
from . import utils as u
_DIR_CANON = {
"a_to_b": "A_to_B",
"b_to_a": "B_to_A",
"a_to_a": "A_to_A",
"b_to_b": "B_to_B",
}
[docs]
@dataclass
class TextFlags:
show_legend: bool
show_labels: bool
show_ticklabels: bool
show_title: bool
show_panel_titles: bool
title: str | None
def _canon_dir(x: Any) -> str:
s = str(x).strip()
k = s.lower()
return _DIR_CANON.get(k, s)
_BASELINE_INV = {v: k for (k, v) in cfg._BASELINE_MAP.items()}
def _dir_for_panel(*, direc: Any, strat: str) -> str:
d = _canon_dir(direc)
s = str(strat).lower()
if s == "baseline":
return _BASELINE_INV.get(d, d)
return d
def _to_num(df: pd.DataFrame, col: str) -> None:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors="coerce")
def _metrics_unit(df: pd.DataFrame) -> str:
for c in (
"subsidence_unit",
"metrics_unit",
"metric_unit",
):
if c in df.columns:
s = df[c].dropna()
if not s.empty:
u0 = str(s.iloc[0]).strip().lower()
if u0.startswith("m"):
return "m"
if u0.startswith("mm"):
return "mm"
return "mm"
def _canon_cols(df: pd.DataFrame) -> pd.DataFrame:
aliases = {
"strategy": ("strategy",),
"rescale_mode": ("rescale_mode", "rescale_m"),
"direction": ("direction",),
"source_city": ("source_city",),
"target_city": ("target_city",),
"split": ("split",),
"calibration": ("calibration",),
"overall_mae": ("overall_mae",),
"overall_mse": ("overall_mse",),
"overall_rmse": ("overall_rmse",),
"overall_r2": ("overall_r2",),
"coverage80": ("coverage80",),
"sharpness80": ("sharpness80",),
}
u.ensure_columns(df, aliases=aliases)
for c in (
"strategy",
"rescale_mode",
"split",
"calibration",
"source_city",
"target_city",
):
if c in df.columns:
df[c] = df[c].astype(str).str.strip().str.lower()
if "direction" in df.columns:
df["direction"] = df["direction"].map(_canon_dir)
for c in (
"overall_mae",
"overall_rmse",
"overall_rmse",
"overall_r2",
"coverage80",
"sharpness80",
):
_to_num(df, c)
for c in df.columns:
if str(c).startswith("per_horizon_"):
_to_num(df, c)
return df
def _find_xfer_csv(src: Any) -> Path:
pats = None
if hasattr(cfg, "PATTERNS"):
pats = cfg.PATTERNS.get("xfer_results_csv")
pats = pats or ("xfer_results.csv", "*xfer_results*.csv")
p = u.find_latest(src, pats, must_exist=True)
if p is None:
raise FileNotFoundError(str(src))
return Path(p)
def _find_xfer_json(csv_p: Path) -> Path | None:
root = csv_p.parent
p = u.find_latest(root, ("xfer_results.json",))
return None if p is None else Path(p)
def _is_pair_dir(p: Path) -> bool:
"""
True if p looks like a cityA__cityB folder containing
timestamp subfolders with xfer_results.csv inside.
"""
if not p.exists() or not p.is_dir():
return False
direct = (p / "xfer_results.csv").exists()
if direct:
return False
for d in p.iterdir():
if d.is_dir() and (d / "xfer_results.csv").exists():
return True
return False
def _row_matches(
df: pd.DataFrame,
*,
strategy: str,
split: str,
calib: str,
rescale_mode: str | None,
) -> bool:
if df is None or df.empty:
return False
m = df["strategy"].eq(str(strategy).lower())
m &= df["split"].eq(str(split).lower())
m &= df["calibration"].eq(str(calib).lower())
if (
rescale_mode is not None
and "rescale_mode" in df.columns
):
m &= df["rescale_mode"].eq(str(rescale_mode).lower())
return bool(m.any())
def _pick_latest_csvs_by_strategy(
pair_dir: Path,
*,
strategies: list[str],
split: str,
calib: str,
rescale_mode: str,
baseline_rescale: str,
) -> dict[str, Path]:
"""
From pair_dir (cityA__cityB), pick the latest run folder
per strategy, based on the content of xfer_results.csv.
"""
pats = ("xfer_results.csv", "*xfer_results*.csv")
cands = u.find_all(pair_dir, pats, must_exist=True)
want = [str(s).lower() for s in strategies]
picked: dict[str, Path] = {}
def _rm(s: str) -> str:
return (
baseline_rescale
if s == "baseline"
else rescale_mode
)
# pass 1: require rescale match
for fp in cands:
if len(picked) == len(want):
break
try:
d0 = _canon_cols(pd.read_csv(fp))
except Exception:
continue
for s in want:
if s in picked:
continue
if _row_matches(
d0,
strategy=s,
split=split,
calib=calib,
rescale_mode=_rm(s),
):
picked[s] = Path(fp)
# pass 2: fallback (ignore rescale) if still missing
if len(picked) < len(want):
for fp in cands:
if len(picked) == len(want):
break
try:
d0 = _canon_cols(pd.read_csv(fp))
except Exception:
continue
for s in want:
if s in picked:
continue
if _row_matches(
d0,
strategy=s,
split=split,
calib=calib,
rescale_mode=None,
):
picked[s] = Path(fp)
return picked
def _load_multi_jobs(
pair_dir: Path,
*,
strategies: list[str],
split: str,
calib: str,
rescale_mode: str,
baseline_rescale: str,
) -> tuple[
pd.DataFrame, list[dict[str, Any]], dict[str, Path]
]:
picked = _pick_latest_csvs_by_strategy(
pair_dir,
strategies=strategies,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
)
dfs: list[pd.DataFrame] = []
rows: list[dict[str, Any]] = []
seen_eval: set = set()
for _s, fp in picked.items():
try:
d0 = pd.read_csv(fp)
dfs.append(d0)
except Exception:
pass
jp = fp.parent / "xfer_results.json"
if not jp.exists():
jp2 = _find_xfer_json(fp)
jp = jp2 if jp2 is not None else jp
if jp.exists():
for r in _load_json_rows(jp):
key = str(r.get("csv_eval", "")) or str(
r.get("csv_future", "")
)
if key and key in seen_eval:
continue
if key:
seen_eval.add(key)
rows.append(r)
if not dfs:
raise SystemExit(
f"No usable xfer_results.csv under {pair_dir}"
)
return (
pd.concat(dfs, axis=0, ignore_index=True),
rows,
picked,
)
def _subset(
df: pd.DataFrame,
*,
direction: str,
strategy: str,
split: str,
calib: str,
rescale_mode: str | None,
baseline_rescale: str,
) -> pd.DataFrame:
d = _canon_dir(direction)
s = str(strategy).lower()
sp = str(split).lower()
cm = str(calib).lower()
use_dir = d
use_rm = rescale_mode
if s == "baseline":
use_dir = cfg._BASELINE_MAP.get(d, d)
use_rm = baseline_rescale
m = df["direction"].eq(use_dir)
m &= df["strategy"].eq(s)
m &= df["split"].eq(sp)
m &= df["calibration"].eq(cm)
if use_rm is not None and "rescale_mode" in df.columns:
m &= df["rescale_mode"].eq(str(use_rm).lower())
out = df.loc[m].copy()
if out.empty and s == "baseline":
m2 = df["direction"].eq(d)
m2 &= df["strategy"].eq(s)
m2 &= df["split"].eq(sp)
m2 &= df["calibration"].eq(cm)
if use_rm is not None and (
"rescale_mode" in df.columns
):
m2 &= df["rescale_mode"].eq(str(use_rm).lower())
out = df.loc[m2].copy()
return out
def _reduce_vals(vals: np.ndarray) -> float:
a = np.asarray(vals, dtype=float)
a = a[np.isfinite(a)]
if a.size == 0:
return float("nan")
return float(np.median(a))
def _pick_metric(
df: pd.DataFrame,
*,
direction: str,
strategy: str,
split: str,
calib: str,
rescale_mode: str | None,
baseline_rescale: str,
col: str,
) -> float:
sub = _subset(
df,
direction=direction,
strategy=strategy,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
)
if sub.empty or col not in sub.columns:
return float("nan")
return _reduce_vals(sub[col].values)
def _tgt_color(df: pd.DataFrame, direction: str) -> str:
d = _canon_dir(direction)
sub = df[df["direction"].eq(d)]
if sub.empty:
return "#777777"
tc = str(sub.iloc[0].get("target_city", "") or "")
tc = u.canonical_city(tc) # keep canonical casing
return cfg.CITY_COLORS.get(
tc,
cfg.CITY_COLORS.get(
tc.title(),
cfg.CITY_COLORS.get(tc.lower(), "#777777"),
),
)
def _dir_label(df: pd.DataFrame, direction: str) -> str:
d = _canon_dir(direction)
sub = df[df["direction"].eq(d)]
if sub.empty:
return d
r0 = sub.iloc[0]
sc = u.canonical_city(str(r0.get("source_city", "")))
tc = u.canonical_city(str(r0.get("target_city", "")))
return f"{sc} \u2192 {tc}"
def _clean_axes(ax: Any) -> None:
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
def _plot_retention(
ax: Any,
df: pd.DataFrame,
*,
split: str,
calib: str,
directions: list[str],
strategies: list[str],
rescale_mode: str | None,
baseline_rescale: str,
metric: str,
text: TextFlags,
) -> None:
x0 = np.arange(len(strategies), dtype=float)
group_w = 0.78
dir_w = group_w / max(1, len(directions))
bw = dir_w * 0.92
for di, d in enumerate(directions):
face = _tgt_color(df, d)
for si, s in enumerate(strategies):
v = _pick_metric(
df,
direction=d,
strategy=s,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
col=(
"overall_r2"
if metric == "r2"
else "overall_mae"
),
)
b = _pick_metric(
df,
direction=d,
strategy="baseline",
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
col=(
"overall_r2"
if metric == "r2"
else "overall_mae"
),
)
if not (np.isfinite(v) and np.isfinite(b)):
continue
if metric == "r2":
y = v / b if b != 0 else float("nan")
else:
y = b / v if v != 0 else float("nan")
if not np.isfinite(y):
continue
off = -0.5 * group_w + (di + 0.5) * dir_w
xx = x0[si] + off
ax.bar(
xx,
y,
width=bw,
color=face,
alpha=0.86,
hatch=cfg._STRAT_HATCH.get(s, ""),
edgecolor=cfg._STRAT_EDGE.get(s, "#111111"),
linewidth=0.6,
)
ax.axhline(1.0, linestyle="--", linewidth=0.7)
ax.set_xticks(x0)
ax.set_xticklabels(
[cfg._STRAT_LABEL.get(s, s) for s in strategies],
rotation=12,
ha="right",
)
if text.show_labels:
if metric == "r2":
ax.set_ylabel("R² retention (× baseline)")
else:
ax.set_ylabel("MAE retention (× baseline)")
else:
ax.set_ylabel("")
if not text.show_ticklabels:
ax.set_xticklabels([])
ax.set_yticklabels([])
_clean_axes(ax)
ax.grid(False)
def _pick_horizons(
df: pd.DataFrame,
*,
metric: str,
max_n: int = 3,
) -> list[str]:
pref = f"per_horizon_{metric}."
hs: list[str] = []
for c in df.columns:
c = str(c)
if not c.startswith(pref):
continue
h = c.split(pref, 1)[1]
if re.fullmatch(r"H\d+", h):
hs.append(h)
hs = sorted(set(hs), key=lambda s: int(s[1:]))
return (
hs[:max_n]
if hs
else [f"H{i}" for i in range(1, max_n + 1)]
)
def _plot_horizon_ret(
ax: Any,
df: pd.DataFrame,
*,
direction: str,
metric: str,
split: str,
calib: str,
strategies: list[str],
rescale_mode: str | None,
baseline_rescale: str,
text: TextFlags,
) -> None:
# hs = _pick_horizons(df, metric="r2", max_n=3)
metric = str(metric).lower()
hs = _pick_horizons(df, metric=metric, max_n=3)
opt = "max"
try:
_, _, opt = cfg._METRIC_DEF[metric]
except:
pass
opt = str(opt).lower()
x = np.arange(len(hs), dtype=float)
# baseline per-horizon values
base: list[float] = []
for h in hs:
# col = f"per_horizon_r2.{h}"
col = f"per_horizon_{metric}.{h}"
b = _pick_metric(
df,
direction=direction,
strategy="baseline",
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
col=col,
)
base.append(b)
base_arr = np.asarray(base, dtype=float)
face = _tgt_color(df, direction)
for s in strategies:
if str(s).lower() == "baseline":
continue
ys: list[float] = []
for i, h in enumerate(hs):
# col = f"per_horizon_r2.{h}"
col = f"per_horizon_{metric}.{h}"
v = _pick_metric(
df,
direction=direction,
strategy=s,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
col=col,
)
b = base_arr[i]
# ok = np.isfinite(v) and np.isfinite(b) and b != 0.0
# ys.append(v / b if ok else float("nan"))
ok = np.isfinite(v) and np.isfinite(b)
if not ok:
ys.append(float("nan"))
elif opt == "min":
ys.append(b / v if v != 0.0 else float("nan"))
else:
ys.append(v / b if b != 0.0 else float("nan"))
yarr = np.asarray(ys, dtype=float)
m = np.isfinite(yarr)
if not m.any():
continue
ax.plot(
x[m],
yarr[m],
color=face,
linestyle=cfg._STRAT_LINESTYLE.get(s, "-"),
marker=cfg._STRAT_MARKER.get(s, "o"),
markersize=3.2,
linewidth=1.0,
)
ax.axhline(1.0, linestyle="--", linewidth=0.7)
ax.set_xticks(x)
ax.set_xticklabels(
hs
) # shows H2/H3/H4 if that’s what exists
# ax.set_ylim(0.0, 1.25)
if opt == "max":
ax.set_ylim(-1.0, 1.25)
else:
ax.set_ylim(0.0, 1.25)
if text.show_labels:
# ax.set_ylabel("R² retention")
try:
_, ylab0, _ = cfg._METRIC_DEF[metric]
except:
ylab0 = metric.upper()
ax.set_ylabel(f"{ylab0} retention (× baseline)")
ax.set_xlabel("Horizon")
else:
ax.set_ylabel("")
ax.set_xlabel("")
if not text.show_ticklabels:
ax.set_xticklabels([])
ax.set_yticklabels([])
_clean_axes(ax)
ax.grid(False)
def _plot_cov_sharp(
ax: Any,
df: pd.DataFrame,
*,
direction: str,
split: str,
calib: str,
strategies: list[str],
rescale_mode: str | None,
baseline_rescale: str,
cov_target: float,
text: TextFlags,
) -> None:
face = _tgt_color(df, direction)
for s in strategies:
sub = _subset(
df,
direction=direction,
strategy=s,
split=split,
calib=calib,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
)
if sub.empty:
continue
x = _reduce_vals(sub["sharpness80"].values)
y = _reduce_vals(sub["coverage80"].values)
if not (np.isfinite(x) and np.isfinite(y)):
continue
ax.scatter(
x,
y,
s=55.0,
facecolors=face,
edgecolors=cfg._STRAT_EDGE.get(s, "#111111"),
linewidths=0.9,
alpha=0.92,
marker=cfg._STRAT_MARKER.get(s, "o"),
)
ax.axhline(
float(cov_target),
linestyle="--",
linewidth=0.7,
)
ax.set_ylim(0.0, 1.0)
unit = _metrics_unit(df)
if text.show_labels:
ax.set_xlabel(f"Sharpness80 ({unit})")
ax.set_ylabel("Coverage80")
else:
ax.set_xlabel("")
ax.set_ylabel("")
if not text.show_ticklabels:
ax.set_xticklabels([])
ax.set_yticklabels([])
_clean_axes(ax)
ax.grid(False)
def _load_json_rows(p: Path) -> list[dict[str, Any]]:
try:
obj = json.loads(p.read_text(encoding="utf-8"))
except Exception:
return []
if isinstance(obj, list):
return [x for x in obj if isinstance(x, dict)]
return []
def _norm_rel_path(s: str) -> Path:
# ss = str(s).replace("\\\\", "/").strip()
# return Path(ss)
ss = str(s).strip().replace("\\", "/")
p = Path(ss)
if p.is_absolute() or p.exists():
return p
# try repo root (scripts/..)
repo = cfg.SCRIPTS_DIR.parent
p2 = repo / p
if p2.exists():
return p2
return p
def _risk_tables(
rows: list[dict[str, Any]],
*,
split: str,
calib: str,
threshold: float,
) -> tuple[pd.DataFrame, dict[tuple[str, str], pd.DataFrame]]:
if xfer_risk is None:
return pd.DataFrame(), {}
items: list[dict[str, Any]] = []
rel: dict[tuple[str, str], list[pd.DataFrame]] = {}
for r in rows:
if str(r.get("split", "")).lower() != split:
continue
if str(r.get("calibration", "")).lower() != calib:
continue
# strat = str(r.get("strategy", "")).lower()
# direc = _canon_dir(r.get("direction", ""))
strat = str(r.get("strategy", "")).lower()
direc = _dir_for_panel(
direc=r.get("direction", ""), strat=strat
)
pth = _norm_rel_path(str(r.get("csv_eval", "")))
if not pth.exists():
continue
try:
df = pd.read_csv(pth)
except Exception:
continue
need = [
"subsidence_actual",
"subsidence_q10",
"subsidence_q50",
"subsidence_q90",
]
if any(c not in df.columns for c in need):
continue
q10 = df["subsidence_q10"].to_numpy(float)
q50 = df["subsidence_q50"].to_numpy(float)
q90 = df["subsidence_q90"].to_numpy(float)
ya = df["subsidence_actual"].to_numpy(float)
pr = xfer_risk.exceed_prob_from_quantiles(
q10,
q50,
q90,
threshold=float(threshold),
)
yy = (ya >= float(threshold)).astype(float)
bs = xfer_risk.brier_score(pr, yy)
items.append(
{
"strategy": strat,
"direction": direc,
"brier": bs,
"n": int(np.isfinite(pr).sum()),
}
)
key = (direc, strat)
rel_df = xfer_risk.reliability_bins(
pr,
yy,
n_bins=10,
)
rel.setdefault(key, []).append(rel_df)
if not items:
return pd.DataFrame(), {}
tab = pd.DataFrame(items)
rel_out: dict[tuple[str, str], pd.DataFrame] = {}
for k, parts in rel.items():
if not parts:
continue
dd = pd.concat(parts, axis=0, ignore_index=True)
g = dd.groupby(
["bin_lo", "bin_hi"],
as_index=False,
)
rel_out[k] = g.agg(
{"p_mean": "mean", "y_rate": "mean", "n": "sum"}
)
return tab, rel_out
def _plot_risk(
ax_rel: Any,
ax_bs: Any,
*,
tab: pd.DataFrame,
rel: dict[tuple[str, str], pd.DataFrame],
directions: list[str],
strategies: list[str],
dir_colors: dict[str, str],
text: TextFlags,
) -> None:
if tab.empty:
ax_rel.text(
0.5,
0.5,
"No eval CSVs for risk panel",
ha="center",
va="center",
)
ax_rel.set_axis_off()
ax_bs.set_axis_off()
return
ax_rel.plot(
[0.0, 1.0],
[0.0, 1.0],
linestyle="--",
linewidth=0.7,
)
for d in directions:
for s in strategies:
key = (_canon_dir(d), str(s).lower())
rr = rel.get(key)
if rr is None or rr.empty:
continue
color = dir_colors.get(key[0], "#777777")
ax_rel.plot(
rr["p_mean"].to_numpy(float),
rr["y_rate"].to_numpy(float),
color=color,
linestyle=cfg._STRAT_LINESTYLE.get(
key[1], "-"
),
marker=cfg._STRAT_MARKER.get(key[1], "o"),
markersize=2.8,
linewidth=1.0,
)
ax_rel.set_xlim(0.0, 1.0)
ax_rel.set_ylim(0.0, 1.0)
if text.show_labels:
ax_rel.set_xlabel("Predicted prob")
ax_rel.set_ylabel("Observed freq")
else:
ax_rel.set_xlabel("")
ax_rel.set_ylabel("")
_clean_axes(ax_rel)
ax_rel.grid(False)
g = tab.groupby(["direction", "strategy"], as_index=False)
bs = g["brier"].mean()
x0 = np.arange(len(strategies), dtype=float)
group_w = 0.78
dir_w = group_w / max(1, len(directions))
bw = dir_w * 0.92
for di, d in enumerate(directions):
face = dir_colors.get(_canon_dir(d), "#777777")
for si, s in enumerate(strategies):
m = bs["direction"].eq(_canon_dir(d))
m &= bs["strategy"].eq(str(s).lower())
if not m.any():
continue
v = float(bs.loc[m, "brier"].iloc[0])
off = -0.5 * group_w + (di + 0.5) * dir_w
xx = x0[si] + off
ax_bs.bar(
xx,
v,
width=bw,
color=face,
alpha=0.86,
hatch=cfg._STRAT_HATCH.get(s, ""),
edgecolor=cfg._STRAT_EDGE.get(s, "#111111"),
linewidth=0.6,
)
ax_bs.set_xticks(x0)
ax_bs.set_xticklabels(
[cfg._STRAT_LABEL.get(s, s) for s in strategies],
rotation=12,
ha="right",
)
if text.show_labels:
ax_bs.set_ylabel("Brier (↓)")
else:
ax_bs.set_ylabel("")
if not text.show_ticklabels:
ax_bs.set_xticklabels([])
ax_bs.set_yticklabels([])
_clean_axes(ax_bs)
ax_bs.grid(False)
def _hotspot_tables(
rows: list[dict[str, Any]],
*,
split: str,
calib: str,
strategies: list[str],
directions: list[str],
ref_strategy: str,
score: str,
horizon: str,
threshold: float,
k: int,
) -> pd.DataFrame:
if xfer_risk is None:
return pd.DataFrame()
dfs: dict[tuple[str, str], pd.DataFrame] = {}
for r in rows:
if str(r.get("split", "")).lower() != split:
continue
if str(r.get("calibration", "")).lower() != calib:
continue
strat = str(r.get("strategy", "")).lower()
direc = _canon_dir(r.get("direction", ""))
pth = _norm_rel_path(str(r.get("csv_eval", "")))
if not pth.exists():
continue
try:
df = pd.read_csv(pth)
except Exception:
continue
need = ["coord_t", "coord_x", "coord_y"]
if any(c not in df.columns for c in need):
continue
try:
rk = xfer_risk.prepare_hotspot_ranks(
df,
score=str(score).lower(),
threshold=float(threshold),
horizon=horizon,
)
except Exception:
continue
dfs[(direc, strat)] = rk
ref = str(ref_strategy).lower()
out: list[dict[str, Any]] = []
for d in directions:
dd = _canon_dir(d)
ref_rk = dfs.get((dd, ref))
if ref_rk is None or ref_rk.empty:
continue
for s in strategies:
ss = str(s).lower()
if ss == ref:
continue
oth = dfs.get((dd, ss))
if oth is None or oth.empty:
continue
st = xfer_risk.hotspot_stability(
ref_rk,
oth,
k=int(k),
time_col="coord_t",
x_col="coord_x",
y_col="coord_y",
rank_col="rank",
)
if st.empty:
continue
out.append(
{
"direction": dd,
"strategy": ss,
"jaccard": float(st["jaccard"].mean()),
"spearman": float(st["spearman"].mean()),
"n_years": int(len(st)),
"n_common": float(st["n_common"].mean()),
}
)
return pd.DataFrame(out)
def _hotspot_series(
rows: list[dict[str, Any]],
*,
split: str,
calib: str,
strategies: list[str],
directions: list[str],
ref_strategy: str,
score: str,
horizon: str,
threshold: float,
k: int,
) -> pd.DataFrame:
if xfer_risk is None:
return pd.DataFrame()
dfs: dict[tuple[str, str], pd.DataFrame] = {}
for r in rows:
if str(r.get("split", "")).lower() != split:
continue
if str(r.get("calibration", "")).lower() != calib:
continue
# strat = str(r.get("strategy", "")).lower()
# direc = _canon_dir(r.get("direction", ""))
strat = str(r.get("strategy", "")).lower()
direc = _dir_for_panel(
direc=r.get("direction", ""), strat=strat
)
pth = _norm_rel_path(str(r.get("csv_eval", "")))
if not pth.exists():
continue
try:
df = pd.read_csv(pth)
except Exception:
continue
need = ["coord_t", "coord_x", "coord_y"]
if any(c not in df.columns for c in need):
continue
try:
rk = xfer_risk.prepare_hotspot_ranks(
df,
score=str(score).lower(),
threshold=float(threshold),
horizon=horizon,
)
except Exception:
continue
dfs[(direc, strat)] = rk
ref = str(ref_strategy).lower()
out: list[pd.DataFrame] = []
for d in directions:
dd = _canon_dir(d)
ref_rk = dfs.get((dd, ref))
if ref_rk is None or ref_rk.empty:
continue
for s in strategies:
ss = str(s).lower()
if ss == ref:
continue
oth = dfs.get((dd, ss))
if oth is None or oth.empty:
continue
st = xfer_risk.hotspot_stability(
ref_rk,
oth,
k=int(k),
time_col="coord_t",
x_col="coord_x",
y_col="coord_y",
rank_col="rank",
)
if st.empty:
continue
st = st.assign(direction=dd, strategy=ss)
out.append(st)
if not out:
return pd.DataFrame()
return pd.concat(out, axis=0, ignore_index=True)
def _plot_hotspots_bar(
ax_j: Any,
ax_s: Any,
*,
tab: pd.DataFrame,
directions: list[str],
strategies: list[str],
ref_strategy: str,
dir_colors: dict[str, str],
k: int,
errorbars: bool,
text: TextFlags,
) -> None:
if tab.empty:
ax_j.text(
0.5,
0.5,
"No eval CSVs for hotspots",
ha="center",
va="center",
)
ax_j.set_axis_off()
ax_s.set_axis_off()
return
ref = str(ref_strategy).lower()
comps = [
str(s).lower()
for s in strategies
if str(s).lower() != ref
]
x0 = np.arange(len(comps), dtype=float)
group_w = 0.78
dir_w = group_w / max(1, len(directions))
bw = dir_w * 0.92
for di, d in enumerate(directions):
dd = _canon_dir(d)
face = dir_colors.get(dd, "#777777")
for si, s in enumerate(comps):
m = tab["direction"].eq(dd)
m &= tab["strategy"].eq(s)
if not m.any():
continue
jv = float(tab.loc[m, "jaccard_mean"].iloc[0])
sv = float(tab.loc[m, "spearman_mean"].iloc[0])
je = float(tab.loc[m, "jaccard_std"].iloc[0])
se = float(tab.loc[m, "spearman_std"].iloc[0])
off = -0.5 * group_w + (di + 0.5) * dir_w
xx = x0[si] + off
ec = cfg._STRAT_EDGE.get(s, "#111111")
kw = dict(
width=bw,
color=face,
alpha=0.86,
hatch=cfg._STRAT_HATCH.get(s, ""),
edgecolor=ec,
linewidth=0.6,
)
if errorbars:
ax_j.bar(xx, jv, yerr=je, capsize=2, **kw)
ax_s.bar(xx, sv, yerr=se, capsize=2, **kw)
else:
ax_j.bar(xx, jv, **kw)
ax_s.bar(xx, sv, **kw)
ax_j.set_xticks(x0)
ax_s.set_xticks(x0)
labs = [cfg._STRAT_LABEL.get(s, s) for s in comps]
ax_j.set_xticklabels(labs, rotation=12, ha="right")
ax_s.set_xticklabels(labs, rotation=12, ha="right")
ax_j.set_ylim(0.0, 1.0)
ax_s.set_ylim(-1.0, 1.0)
if text.show_labels:
ax_j.set_ylabel(f"Jaccard@{int(k)}")
ax_s.set_ylabel(f"Spearman@{int(k)}")
else:
ax_j.set_ylabel("")
ax_s.set_ylabel("")
if not text.show_ticklabels:
ax_j.set_xticklabels([])
ax_j.set_yticklabels([])
ax_s.set_xticklabels([])
ax_s.set_yticklabels([])
_clean_axes(ax_j)
_clean_axes(ax_s)
ax_j.grid(False)
ax_s.grid(False)
def _plot_hotspots_timeline(
ax_j_ab: Any,
ax_j_ba: Any,
ax_s_ab: Any,
ax_s_ba: Any,
*,
series: pd.DataFrame,
strategies: list[str],
ref_strategy: str,
k: int,
text: TextFlags,
) -> None:
if series is None or series.empty:
for ax in (ax_j_ab, ax_j_ba, ax_s_ab, ax_s_ba):
ax.text(
0.5,
0.5,
"No eval CSVs for hotspots",
ha="center",
va="center",
)
ax.set_axis_off()
return
ref = str(ref_strategy).lower()
comps = [
str(s).lower()
for s in strategies
if str(s).lower() != ref
]
ls = {"xfer": "--", "warm": "-", "baseline": "-"}
mk = {"xfer": "o", "warm": "s", "baseline": "o"}
def _one(
ax: Any,
*,
direction: str,
col: str,
ylim: tuple[float, float],
) -> None:
dd = _canon_dir(direction)
for s in comps:
sub = series[series["direction"].eq(dd)]
sub = sub[sub["strategy"].eq(s)].copy()
if sub.empty:
continue
sub = sub.sort_values("coord_t")
x = sub["coord_t"].to_numpy(float)
y = sub[col].to_numpy(float)
m = np.isfinite(x) & np.isfinite(y)
if not m.any():
continue
ax.plot(
x[m],
y[m],
linestyle=ls.get(s, "-"),
marker=mk.get(s, "o"),
markersize=2.6,
linewidth=1.0,
label=cfg._STRAT_LABEL.get(s, s),
)
ax.set_ylim(*ylim)
_clean_axes(ax)
ax.grid(False)
_one(
ax_j_ab,
direction="A_to_B",
col="jaccard",
ylim=(0.0, 1.0),
)
_one(
ax_j_ba,
direction="B_to_A",
col="jaccard",
ylim=(0.0, 1.0),
)
_one(
ax_s_ab,
direction="A_to_B",
col="spearman",
ylim=(-1.0, 1.0),
)
_one(
ax_s_ba,
direction="B_to_A",
col="spearman",
ylim=(-1.0, 1.0),
)
if text.show_labels:
ax_j_ab.set_ylabel(f"Jaccard@{int(k)}")
ax_s_ab.set_ylabel(f"Spearman@{int(k)}")
ax_s_ab.set_xlabel("Year")
ax_s_ba.set_xlabel("Year")
if not text.show_ticklabels:
for ax in (ax_j_ab, ax_j_ba, ax_s_ab, ax_s_ba):
ax.set_xticklabels([])
ax.set_yticklabels([])
if text.show_legend:
ax_j_ba.legend(frameon=False, loc="best")
def _add_legends_fig(
fig: Any,
df: pd.DataFrame,
*,
directions: list[str],
strategies: list[str],
y: float,
) -> None:
dir_handles: list[Any] = []
for d in directions:
lab = _dir_label(df, d)
face = _tgt_color(df, d)
dir_handles.append(
Patch(
facecolor=face,
edgecolor="#111111",
linewidth=0.6,
label=lab,
)
)
strat_handles = []
for s in strategies:
strat_handles.append(
Line2D(
[0],
[0],
color="#111111",
linestyle=cfg._STRAT_LINESTYLE.get(s, "-"),
marker=cfg._STRAT_MARKER.get(s, "o"),
linewidth=1.0,
markersize=4.0,
label=cfg._STRAT_LABEL.get(s, s),
)
)
fig.legend(
handles=dir_handles,
frameon=False,
loc="upper left",
bbox_to_anchor=(0.02, y),
title="Direction",
borderaxespad=0.0,
handlelength=1.2,
labelspacing=0.25,
)
fig.legend(
handles=strat_handles,
frameon=False,
loc="upper right",
bbox_to_anchor=(0.98, y),
title="Strategy",
borderaxespad=0.0,
handlelength=1.2,
labelspacing=0.25,
)
[docs]
def render(
df: pd.DataFrame,
*,
split: str,
calib: str,
strategies: list[str],
directions: list[str],
rescale_mode: str | None,
baseline_rescale: str,
horizon_metric: str,
cov_target: float,
threshold: float,
xfer_rows: list[dict[str, Any]],
add_hotspots: bool,
hotspot_k: int,
hotspot_score: str,
hotspot_horizon: str,
hotspot_ref: str,
hotspot_style: str,
hotspot_errorbars: bool,
out: Path,
text: TextFlags,
) -> tuple[Path, Path]:
u.set_paper_style()
hs = str(hotspot_style).lower().strip()
if add_hotspots and hs == "timeline":
h = 7.6
elif add_hotspots:
h = 6.2
else:
h = 4.4
r = 3 if add_hotspots else 2
fig = plt.figure(figsize=(7.4, h))
gs = GridSpec(
r,
2,
figure=fig,
wspace=0.35,
hspace=0.55,
)
gs_a = GridSpecFromSubplotSpec(
1, 2, subplot_spec=gs[0, 0], wspace=0.55
)
ax_a1 = fig.add_subplot(gs_a[0, 0])
ax_a2 = fig.add_subplot(gs_a[0, 1])
gs_b = GridSpecFromSubplotSpec(
1, 2, subplot_spec=gs[0, 1], wspace=0.55
)
ax_b1 = fig.add_subplot(gs_b[0, 0])
ax_b2 = fig.add_subplot(gs_b[0, 1])
gs_c = GridSpecFromSubplotSpec(
1, 2, subplot_spec=gs[1, 0], wspace=0.55
)
ax_c1 = fig.add_subplot(gs_c[0, 0])
ax_c2 = fig.add_subplot(gs_c[0, 1])
gs_d = GridSpecFromSubplotSpec(
1, 2, subplot_spec=gs[1, 1], wspace=0.55
)
ax_d1 = fig.add_subplot(gs_d[0, 0])
ax_d2 = fig.add_subplot(gs_d[0, 1])
ax_e1 = None
ax_e2 = None
ax_e3 = None
ax_e4 = None
if add_hotspots:
if hs == "timeline":
gs_e = GridSpecFromSubplotSpec(
2,
2,
subplot_spec=gs[2, :],
wspace=0.35,
hspace=0.45,
)
ax_e1 = fig.add_subplot(gs_e[0, 0])
ax_e2 = fig.add_subplot(gs_e[0, 1])
ax_e3 = fig.add_subplot(gs_e[1, 0])
ax_e4 = fig.add_subplot(gs_e[1, 1])
else:
gs_e = GridSpecFromSubplotSpec(
1, 2, subplot_spec=gs[2, :]
)
ax_e1 = fig.add_subplot(gs_e[0, 0])
ax_e2 = fig.add_subplot(gs_e[0, 1])
_plot_retention(
ax_a1,
df,
split=split,
calib=calib,
directions=directions,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
metric="r2",
text=text,
)
_plot_retention(
ax_a2,
df,
split=split,
calib=calib,
directions=directions,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
metric="mae",
text=text,
)
_plot_horizon_ret(
ax_b1,
df,
direction="A_to_B",
metric=horizon_metric,
split=split,
calib=calib,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
text=text,
)
_plot_horizon_ret(
ax_b2,
df,
direction="B_to_A",
metric=horizon_metric,
split=split,
calib=calib,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
text=text,
)
_plot_cov_sharp(
ax_c1,
df,
direction="A_to_B",
split=split,
calib=calib,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
cov_target=cov_target,
text=text,
)
_plot_cov_sharp(
ax_c2,
df,
direction="B_to_A",
split=split,
calib=calib,
strategies=strategies,
rescale_mode=rescale_mode,
baseline_rescale=baseline_rescale,
cov_target=cov_target,
text=text,
)
tab, rel = _risk_tables(
xfer_rows,
split=split,
calib=calib,
threshold=float(threshold),
)
dir_colors = {
_canon_dir(d): _tgt_color(df, d) for d in directions
}
_plot_risk(
ax_d1,
ax_d2,
tab=tab,
rel=rel,
directions=directions,
strategies=strategies,
dir_colors=dir_colors,
text=text,
)
if add_hotspots and ax_e1 is not None:
hs_series = _hotspot_series(
xfer_rows,
split=split,
calib=calib,
strategies=strategies,
directions=directions,
ref_strategy=hotspot_ref,
score=hotspot_score,
horizon=hotspot_horizon,
threshold=float(threshold),
k=int(hotspot_k),
)
if hs == "timeline" and ax_e3 is not None:
_plot_hotspots_timeline(
ax_e1,
ax_e2,
ax_e3,
ax_e4,
series=hs_series,
strategies=strategies,
ref_strategy=hotspot_ref,
k=int(hotspot_k),
text=text,
)
else:
if xfer_risk is None:
hs_tab = pd.DataFrame()
else:
hs_tab = xfer_risk.stability_group_summary(
hs_series,
group_cols=("direction", "strategy"),
time_col="coord_t",
ddof=1,
)
_plot_hotspots_bar(
ax_e1,
ax_e2,
tab=hs_tab,
directions=directions,
strategies=strategies,
ref_strategy=hotspot_ref,
dir_colors=dir_colors,
k=int(hotspot_k),
errorbars=bool(hotspot_errorbars),
text=text,
)
if text.show_panel_titles:
ax_a1.set_title("(a) Retention: R²")
ax_a2.set_title("(a) Retention: MAE")
ax_b1.set_title(
"(b) Horizon:\n" + _dir_label(df, "A_to_B"),
pad=2,
)
ax_b2.set_title(
"(b) Horizon:\n" + _dir_label(df, "B_to_A"),
pad=2,
)
ax_c1.set_title(
"(c) Cov–sharp:\n" + _dir_label(df, "A_to_B"),
pad=2,
)
ax_c2.set_title(
"(c) Cov–sharp:\n" + _dir_label(df, "B_to_A"),
pad=2,
)
unit = _metrics_unit(df)
ax_d1.set_title(
f"(d) Reliability @ {threshold:g} {unit}"
)
ax_d2.set_title("(d) Brier")
if add_hotspots and ax_e1 is not None:
k0 = int(hotspot_k)
if hs == "timeline" and ax_e3 is not None:
ax_e1.set_title(
f"(e) J@{k0}: " + _dir_label(df, "A_to_B")
)
ax_e2.set_title(
f"(e) J@{k0}: " + _dir_label(df, "B_to_A")
)
ax_e3.set_title(
f"(e) ρ@{k0}: " + _dir_label(df, "A_to_B")
)
ax_e4.set_title(
f"(e) ρ@{k0}: " + _dir_label(df, "B_to_A")
)
else:
ax_e1.set_title(f"(e) Hotspots J@{k0}")
ax_e2.set_title(f"(e) Hotspots ρ@{k0}")
has_leg = bool(text.show_legend)
has_ttl = bool(text.show_title)
if has_leg:
y_leg = 0.995 if not has_ttl else 0.965
_add_legends_fig(
fig,
df,
directions=directions,
strategies=strategies,
y=y_leg,
)
if has_ttl:
t0 = "Supplementary Figure Sy — Transfer impact"
t = u.resolve_title(default=t0, title=text.title)
y_ttl = 0.94 if has_leg else 0.99
fig.suptitle(t, x=0.02, y=y_ttl, ha="left")
# reserve headroom for (legend + title)
if has_leg and has_ttl:
fig.subplots_adjust(top=0.86)
elif has_leg:
fig.subplots_adjust(top=0.90)
elif has_ttl:
fig.subplots_adjust(top=0.93)
# top = 0.86 if (has_leg and has_ttl) else (
# 0.90 if has_leg else (0.93 if has_ttl else 0.98))
# fig.tight_layout(rect=[0.0, 0.0, 1.0, top])
stem = out
if stem.suffix:
stem = stem.with_suffix("")
png = stem.with_suffix(".png")
svg = stem.with_suffix(".svg")
eps = stem.with_suffix(".eps")
fig.savefig(png, bbox_inches="tight")
fig.savefig(svg, bbox_inches="tight")
fig.savefig(eps, bbox_inches="tight")
plt.close(fig)
return png, svg
[docs]
def parse_args(
argv: list[str] | None = None, *, prog: str | None = None
) -> Any:
ap = argparse.ArgumentParser(
prog=prog or "plot-xfer-impact",
description=(
"Supplementary Sy — transfer impact plots."
),
)
ap.add_argument(
"--src",
type=str,
default="results/xfer",
help="Folder to search latest xfer_results.csv",
)
ap.add_argument(
"--xfer-csv",
type=str,
default=None,
help="Explicit xfer_results.csv (file or dir)",
)
ap.add_argument(
"--xfer-json",
type=str,
default=None,
help="Explicit xfer_results.json (optional)",
)
ap.add_argument(
"--split",
type=str,
default="val",
choices=("val", "test"),
help="Which split to plot",
)
ap.add_argument(
"--calib",
type=str,
default="source",
choices=("none", "source", "target"),
help="Calibration mode for panels",
)
ap.add_argument(
"--strategies",
nargs="+",
default=["baseline", "xfer", "warm"],
help="Strategies (baseline xfer warm)",
)
ap.add_argument(
"--rescale-mode",
type=str,
default="strict",
choices=("strict", "as_is"),
help="Rescale mode for xfer/warm",
)
ap.add_argument(
"--baseline-rescale",
type=str,
default="as_is",
choices=("strict", "as_is"),
help="Rescale mode for baseline fetch",
)
ap.add_argument(
"--horizon-metric",
type=str,
default="rmse",
choices=("r2", "mae", "mse", "rmse"),
help=(
"Metric for the horizon-retention panels "
"(default: rmse)."
),
)
ap.add_argument(
"--cov-target",
type=float,
default=0.80,
help="Coverage reference line",
)
ap.add_argument(
"--threshold",
type=float,
default=50.0,
help="Exceedance threshold for risk panel",
)
ap.add_argument(
"--add-hotspots",
type=str,
default="false",
help="Add hotspot stability panels (true/false).",
)
ap.add_argument(
"--hotspot-k",
type=int,
default=100,
help="Top-K hotspots for stability.",
)
ap.add_argument(
"--hotspot-score",
type=str,
default="q50",
choices=("q50", "exceed"),
help="Hotspot score: q50 or exceed.",
)
ap.add_argument(
"--hotspot-horizon",
type=str,
default="H3",
choices=("H1", "H2", "H3", "all"),
help="Which horizon to rank (H1/H2/H3/all).",
)
ap.add_argument(
"--hotspot-ref",
type=str,
default="baseline",
help="Reference strategy for stability.",
)
ap.add_argument(
"--hotspot-style",
type=str,
default="bar",
choices=("bar", "timeline"),
help="Hotspot panel: bar or timeline.",
)
ap.add_argument(
"--hotspot-errorbars",
type=str,
default="false",
help="Show mean±std bars (true/false).",
)
u.add_plot_text_args(
ap,
default_out="figureS_xfer_impact",
)
return ap.parse_args(argv)
def _text_flags(args: Any) -> TextFlags:
return TextFlags(
show_legend=u.str_to_bool(
args.show_legend, default=True
),
show_labels=u.str_to_bool(
args.show_labels, default=True
),
show_ticklabels=u.str_to_bool(
args.show_ticklabels,
default=True,
),
show_title=u.str_to_bool(
args.show_title, default=True
),
show_panel_titles=u.str_to_bool(
args.show_panel_titles,
default=True,
),
title=args.title,
)
[docs]
def figSx_xfer_impact_main(
argv: list[str] | None = None, *, prog: str | None = None
) -> None:
u.ensure_script_dirs()
args = parse_args(argv, prog=prog)
text = _text_flags(args)
split = str(args.split).lower()
calib = str(args.calib).lower()
strategies = [str(s).lower() for s in args.strategies]
directions = ["A_to_B", "B_to_A"]
rows: list[dict[str, Any]] = []
rm = str(args.rescale_mode).lower()
brm = str(args.baseline_rescale).lower()
rows: list[dict[str, Any]] = []
picked: dict[str, Path] = {}
# --- Resolve CSV(s) ---
if args.xfer_csv:
p = Path(args.xfer_csv).expanduser()
else:
p = Path(args.src).expanduser()
if p.is_dir() and _is_pair_dir(p):
df_raw, rows_raw, picked = _load_multi_jobs(
p,
strategies=strategies,
split=split,
calib=calib,
rescale_mode=rm,
baseline_rescale=brm,
)
df = _canon_cols(df_raw)
# only use merged JSON rows if user didn't force one
if not args.xfer_json:
rows = rows_raw
if picked:
for s, fp in picked.items():
print(f"[pick] {s}: {fp.parent.name}")
else:
csv_p = _find_xfer_csv(p if p.exists() else args.src)
df = _canon_cols(pd.read_csv(csv_p))
if args.xfer_json:
jp = Path(args.xfer_json).expanduser()
if jp.is_dir():
jp = jp / "xfer_results.json"
if jp.exists():
rows = _load_json_rows(jp)
else:
jp2 = _find_xfer_json(csv_p)
if jp2 is not None and jp2.exists():
rows = _load_json_rows(jp2)
# --- Apply filters after loading ---
df = df[df["split"].eq(split)].copy()
if df.empty:
raise SystemExit("No rows after split filter.")
df = df[df["calibration"].eq(calib)].copy()
if df.empty:
raise SystemExit("No rows after calib filter.")
out = u.resolve_fig_out(args.out)
png, svg = render(
df,
split=split,
calib=calib,
strategies=strategies,
directions=directions,
rescale_mode=rm,
baseline_rescale=brm,
horizon_metric=str(args.horizon_metric).lower(),
cov_target=float(args.cov_target),
threshold=float(args.threshold),
xfer_rows=rows,
add_hotspots=u.str_to_bool(
args.add_hotspots, default=False
),
hotspot_k=int(args.hotspot_k),
hotspot_score=str(args.hotspot_score).lower(),
hotspot_horizon=str(args.hotspot_horizon),
hotspot_ref=str(args.hotspot_ref).lower(),
hotspot_style=str(args.hotspot_style).lower(),
hotspot_errorbars=u.str_to_bool(
args.hotspot_errorbars,
default=False,
),
out=out,
text=text,
)
print(f"[OK] Wrote {png}")
print(f"[OK] Wrote {svg}")
[docs]
def main(
argv: list[str] | None = None, *, prog: str | None = None
) -> None:
figSx_xfer_impact_main(argv, prog=prog)
if __name__ == "__main__":
main()