# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for plotting spatial forecast outputs."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from . import config as cfg
from . import utils
_CITY_A = cfg.CITY_CANON.get("ns", "Nansha")
_CITY_B = cfg.CITY_CANON.get("zh", "Zhongshan")
# ---------------------------------------------------------------------
# IO helpers
# ---------------------------------------------------------------------
def _require_cols(
df: pd.DataFrame,
cols: list[str],
*,
where: str,
) -> None:
miss = [c for c in cols if c not in df.columns]
if miss:
raise KeyError(f"{where}: missing {miss}")
def _canonize(
df: pd.DataFrame,
*,
where: str,
required: list[str],
) -> pd.DataFrame:
# NOTE: ensure_columns returns a mapping;
# it mutates df in-place.
utils.ensure_columns(df, aliases=cfg._BASE_ALIASES)
_require_cols(df, required, where=where)
return df
def _load_calib_df(path: str) -> pd.DataFrame:
p = utils.as_path(path)
df = pd.read_csv(p)
_canonize(
df,
where="calibrated-forecast",
required=list(cfg._CALIB_REQUIRED),
)
for c in [
"sample_idx",
"forecast_step",
"coord_t",
"coord_x",
"coord_y",
"subsidence_actual",
"subsidence_q50",
]:
df[c] = pd.to_numeric(df[c], errors="coerce")
df = df.dropna(
subset=[
"sample_idx",
"coord_t",
"coord_x",
"coord_y",
"subsidence_actual",
"subsidence_q50",
]
).copy()
return df
def _load_future_df(path: str) -> pd.DataFrame:
p = utils.as_path(path)
df = pd.read_csv(p)
_canonize(
df,
where="future-forecast",
required=list(cfg._FUT_REQUIRED),
)
for c in [
"sample_idx",
"forecast_step",
"coord_t",
"coord_x",
"coord_y",
"subsidence_q50",
]:
df[c] = pd.to_numeric(df[c], errors="coerce")
df = df.dropna(
subset=[
"sample_idx",
"coord_t",
"coord_x",
"coord_y",
"subsidence_q50",
]
).copy()
return df
def _pick_calib_path(
art: utils.Artifacts,
split: str,
) -> tuple[Path | None, str]:
if split == "val":
return (art.forecast_val_csv, "val")
if split == "test":
return (art.forecast_test_csv, "test")
if art.forecast_test_csv is not None:
return (art.forecast_test_csv, "test")
return (art.forecast_val_csv, "val")
def _pick_future_path(
art: utils.Artifacts,
split: str,
) -> tuple[Path | None, str]:
if split == "val":
return (art.forecast_future_csv, "val")
if split == "test":
return (art.forecast_test_future_csv, "test")
if art.forecast_test_future_csv is not None:
return (art.forecast_test_future_csv, "test")
return (art.forecast_future_csv, "val")
def _resolve_city(
*,
city: str,
src: str | None,
calib: str | None,
future: str | None,
split: str,
) -> dict[str, Any]:
"""
Resolve (calib_df, future_df) for one city.
Priority:
1) manual overrides: --*-calib and --*-future
2) auto-detect from --*-src with --split selection
"""
out: dict[str, Any] = {"name": city}
out["color"] = cfg.CITY_COLORS.get(city, "#333333")
if calib and future:
out["calib_df"] = _load_calib_df(calib)
out["future_df"] = _load_future_df(future)
out["split"] = split
out["src_note"] = "manual"
return out
if not src:
raise ValueError(
f"{city}: provide --{city[:2].lower()}-src "
f"or both --{city[:2].lower()}-calib and "
f"--{city[:2].lower()}-future"
)
art = utils.detect_artifacts(src)
c_p, c_lab = _pick_calib_path(art, split)
f_p, f_lab = _pick_future_path(art, split)
if c_p is None:
raise FileNotFoundError(
f"{city}: no calibrated CSV found under {src}"
)
if f_p is None:
raise FileNotFoundError(
f"{city}: no future CSV found under {src}"
)
out["calib_df"] = _load_calib_df(str(c_p))
out["future_df"] = _load_future_df(str(f_p))
out["split"] = c_lab
out["src_note"] = (
f"{c_p.name} + {f_p.name} ({c_lab}/{f_lab})"
)
return out
# ---------------------------------------------------------------------
# Spatial helpers
# ---------------------------------------------------------------------
def _coord_kind(df: pd.DataFrame) -> str:
"""
Heuristic:
- lon/lat if x in [-180,180] and y in [-90,90]
- otherwise treat as projected (UTM, etc.)
"""
x = df["coord_x"].to_numpy()
y = df["coord_y"].to_numpy()
if x.size == 0 or y.size == 0:
return "unknown"
xm = float(np.nanmax(np.abs(x)))
ym = float(np.nanmax(np.abs(y)))
if xm <= 180.0 and ym <= 90.0:
return "lonlat"
return "projected"
def _subset_year(df: pd.DataFrame, year: int) -> pd.DataFrame:
return df.loc[
df["coord_t"].astype(int).eq(int(year))
].copy()
def _compute_cum_calib(
df: pd.DataFrame,
*,
kind: str,
) -> pd.DataFrame:
df2 = df.sort_values(["sample_idx", "coord_t"]).copy()
k = str(kind).strip().lower()
if k in ("rate", "increment"):
df2["subs_cum_actual"] = df2.groupby("sample_idx")[
"subsidence_actual"
].cumsum()
df2["subs_cum_q50"] = df2.groupby("sample_idx")[
"subsidence_q50"
].cumsum()
else:
# already cumulative
df2["subs_cum_actual"] = df2["subsidence_actual"]
df2["subs_cum_q50"] = df2["subsidence_q50"]
return df2
def _compute_cum_q50(
df: pd.DataFrame,
*,
kind: str,
) -> pd.DataFrame:
df2 = df.sort_values(["sample_idx", "coord_t"]).copy()
k = str(kind).strip().lower()
if k in ("rate", "increment"):
df2["subs_cum_q50"] = df2.groupby("sample_idx")[
"subsidence_q50"
].cumsum()
else:
df2["subs_cum_q50"] = df2["subsidence_q50"]
return df2
def _extent_from_panels(
panels: dict[
str, tuple[np.ndarray, np.ndarray, np.ndarray]
],
) -> tuple[float, float, float, float] | None:
xs: list[np.ndarray] = []
ys: list[np.ndarray] = []
for _k, (x, y, _v) in panels.items():
if x.size and y.size:
xs.append(x)
ys.append(y)
if not xs:
return None
x0 = float(np.nanmin(np.concatenate(xs)))
x1 = float(np.nanmax(np.concatenate(xs)))
y0 = float(np.nanmin(np.concatenate(ys)))
y1 = float(np.nanmax(np.concatenate(ys)))
return (x0, x1, y0, y1)
def _grid_mean(
x: np.ndarray,
y: np.ndarray,
v: np.ndarray,
*,
res: int,
extent: tuple[float, float, float, float],
) -> tuple[np.ndarray, tuple[float, float, float, float]]:
"""
Rasterize (mean in bin) to res x res grid.
Returns:
Z (shape [res,res], origin lower)
extent (xmin, xmax, ymin, ymax)
"""
xmin, xmax, ymin, ymax = extent
xb = np.linspace(xmin, xmax, res + 1)
yb = np.linspace(ymin, ymax, res + 1)
sumv, _, _ = np.histogram2d(
x, y, bins=[xb, yb], weights=v
)
cnts, _, _ = np.histogram2d(x, y, bins=[xb, yb])
with np.errstate(invalid="ignore", divide="ignore"):
z = sumv / cnts
z[cnts == 0] = np.nan
return (z.T, extent)
def _clip_bounds(
vals: list[np.ndarray],
clip: float,
) -> tuple[float, float]:
good = []
for v in vals:
if v is None:
continue
vv = v[np.isfinite(v)]
if vv.size:
good.append(vv)
if not good:
return (0.0, 1.0)
g = np.concatenate(good)
hi = float(np.nanpercentile(g, clip))
lo = float(np.nanpercentile(g, 100.0 - clip))
if not np.isfinite(lo) or not np.isfinite(hi) or lo >= hi:
lo = float(np.nanmin(g))
hi = float(np.nanmax(g))
return (lo, hi)
def _axes_cleanup(ax: plt.Axes) -> None:
ax.set_xticks([])
ax.set_yticks([])
for s in ("top", "right", "left", "bottom"):
ax.spines[s].set_visible(False)
def _infer_unit(cities: list[dict[str, Any]]) -> str:
for c in cities:
df = c.get("calib_df")
if isinstance(df, pd.DataFrame):
if "subsidence_unit" in df.columns:
u = str(df["subsidence_unit"].iloc[0])
if u and u.lower() != "nan":
return u
return "mm"
# ---------------------------------------------------------------------
# Plot (Figure 6)
# ---------------------------------------------------------------------
[docs]
def plot_fig6_spatial_forecasts(
*,
cities: list[dict[str, Any]],
year_val: int,
years_fore: list[int],
cumulative: bool,
subsidence_kind: str,
grid_res: int,
clip: float,
cmap_name: str,
hotspot_mode: str,
hotspot_q: float,
out: str,
out_hotspots: str | None,
dpi: int,
font: int,
show_legend: bool,
show_title: bool,
show_panel_titles: bool,
title: str | None,
) -> None:
"""
Two columns fixed:
- observed year_val (actual)
- predicted year_val (q50)
Remaining columns:
- years_fore (q50 or cum q50)
Rows:
- one per city (1-city and 2-city layouts supported)
"""
if not cities:
raise ValueError("No city data provided.")
utils.ensure_script_dirs()
utils.set_paper_style(fontsize=int(font), dpi=int(dpi))
_infer_unit(cities)
subs_kind = str(subsidence_kind).strip().lower()
plot_cum = bool(cumulative) or (subs_kind == "cumulative")
cmap = plt.get_cmap(cmap_name)
# Build per-city panel dictionaries and extents
for c in cities:
calib_df = c["calib_df"]
fut_df = c["future_df"]
calib_use = calib_df
fut_use = fut_df
if plot_cum:
calib_use = _compute_cum_calib(
calib_df, kind=subs_kind
)
fut_use = _compute_cum_q50(fut_df, kind=subs_kind)
obs = _subset_year(calib_use, year_val)
pred = obs.copy()
panels: dict[
str, tuple[np.ndarray, np.ndarray, np.ndarray]
] = {}
x = obs["coord_x"].to_numpy()
y = obs["coord_y"].to_numpy()
if plot_cum:
v0 = obs["subs_cum_actual"].to_numpy()
v1 = pred["subs_cum_q50"].to_numpy()
else:
v0 = obs["subsidence_actual"].to_numpy()
v1 = pred["subsidence_q50"].to_numpy()
panels["obs"] = (x, y, v0)
panels["pred"] = (x, y, v1)
for yy in years_fore:
dyy = _subset_year(fut_use, int(yy))
if dyy.empty:
continue
xf = dyy["coord_x"].to_numpy()
yf = dyy["coord_y"].to_numpy()
if plot_cum:
vf = dyy["subs_cum_q50"].to_numpy()
else:
vf = dyy["subsidence_q50"].to_numpy()
panels[f"Y{int(yy)}"] = (xf, yf, vf)
c["panels"] = panels
ext = _extent_from_panels(panels)
if ext is None:
raise ValueError(
f"{c['name']}: no points to plot."
)
c["extent"] = ext
c["coord_kind"] = _coord_kind(calib_df)
# Global vmin/vmax across all cities/panels
all_vals: list[np.ndarray] = []
for c in cities:
for _k, (_x, _y, v) in c["panels"].items():
all_vals.append(v)
vmin, vmax = _clip_bounds(all_vals, float(clip))
# Layout sizes
n_rows = len(cities)
n_cols = 2 + len(years_fore)
fig_w = 2.2 * n_cols + 1.6
fig_h = 2.6 * n_rows + 1.2
fig, axes = plt.subplots(
nrows=n_rows,
ncols=n_cols,
figsize=(fig_w, fig_h),
constrained_layout=False,
)
ax_arr = np.array(axes, ndmin=2)
plt.subplots_adjust(
left=0.06,
right=0.90 if show_legend else 0.98,
top=0.92,
bottom=0.08,
wspace=0.06,
hspace=0.18,
)
# Hotspot export accumulator
hs_rows: list[dict[str, Any]] = []
def _add_hotspots(
*,
city: str,
panel: str,
kind: str,
year: int,
extent: tuple[float, float, float, float],
z: np.ndarray,
metric: np.ndarray,
thr: float,
base: np.ndarray | None,
) -> None:
xmin, xmax, ymin, ymax = extent
ny, nx = z.shape
dx = (xmax - xmin) / float(nx)
dy = (ymax - ymin) / float(ny)
m = metric >= thr
idxs = np.argwhere(m)
for iy, ix in idxs:
vv = float(z[iy, ix])
if not np.isfinite(vv):
continue
rec: dict[str, Any] = {
"city": city,
"panel": panel,
"kind": kind,
"year": int(year),
"coord_x": float(xmin + (ix + 0.5) * dx),
"coord_y": float(ymin + (iy + 0.5) * dy),
"value": vv,
"metric_value": float(metric[iy, ix]),
"threshold": float(thr),
"hotspot_mode": str(hotspot_mode),
"hotspot_quantile": float(hotspot_q),
}
if base is not None:
rec["baseline_value"] = float(base[iy, ix])
hs_rows.append(rec)
def _draw_panel(
ax: plt.Axes,
*,
city: str,
panel_key: str,
x: np.ndarray,
y: np.ndarray,
v: np.ndarray,
extent: tuple[float, float, float, float],
coord_kind: str,
delta_base: (
tuple[np.ndarray, np.ndarray, np.ndarray] | None
),
title_txt: str,
) -> mpl.image.AxesImage | None:
if x.size == 0:
ax.set_axis_off()
return None
z, ext = _grid_mean(
x, y, v, res=grid_res, extent=extent
)
aspect = "equal"
if coord_kind == "lonlat":
aspect = "auto"
im = ax.imshow(
z,
extent=ext,
origin="lower",
cmap=cmap,
vmin=vmin,
vmax=vmax,
interpolation="nearest",
aspect=aspect,
)
# Hotspots: contours + optional CSV point export
if hotspot_mode != "none":
metric = None
base_z = None
if hotspot_mode == "absolute":
metric = np.abs(z)
else:
if delta_base is not None:
bx, by, bv = delta_base
base_z, _ = _grid_mean(
bx,
by,
bv,
res=grid_res,
extent=extent,
)
metric = np.abs(z - base_z)
if (
metric is not None
and np.isfinite(metric).any()
):
thr = float(np.nanquantile(metric, hotspot_q))
try:
ax.contour(
(metric >= thr).astype(float),
levels=[0.5],
colors="k",
linewidths=0.7,
origin="lower",
extent=ext,
)
except Exception:
pass
# Export points
if hs_rows is not None:
if panel_key.startswith("Y"):
yy = int(panel_key[1:])
kind = "forecast"
elif panel_key == "obs":
yy = int(year_val)
kind = "observed"
else:
yy = int(year_val)
kind = "predicted"
_add_hotspots(
city=city,
panel=panel_key,
kind=kind,
year=yy,
extent=ext,
z=z,
metric=metric,
thr=thr,
base=base_z,
)
if show_panel_titles:
ax.set_title(
title_txt,
pad=4,
fontweight="bold",
)
_axes_cleanup(ax)
return im
# Delta baselines (obs year) per city
baselines: dict[
str, tuple[np.ndarray, np.ndarray, np.ndarray]
] = {}
for c in cities:
x0, y0, v0 = c["panels"]["obs"]
baselines[c["name"]] = (x0, y0, v0)
# Draw all panels
for r, c in enumerate(cities):
name = str(c["name"])
panels = c["panels"]
extent = c["extent"]
ck = str(c["coord_kind"])
# fixed first two columns
keys = ["obs", "pred"]
for yy in years_fore:
keys.append(f"Y{int(yy)}")
for j, k in enumerate(keys):
ax = ax_arr[r, j]
x, y, v = panels.get(k, (np.array([]),) * 3)
if k == "obs":
t = f"{name} • {year_val} observed"
elif k == "pred":
t = f"{name} • {year_val} predicted"
else:
yy = int(k[1:])
tag = "cumulative" if plot_cum else "q50"
t = f"{name} • {yy} {tag}"
db = baselines.get(name)
if hotspot_mode == "delta" and k == "obs":
db = None
_draw_panel(
ax,
city=name,
panel_key=k,
x=x,
y=y,
v=v,
extent=extent,
coord_kind=ck,
delta_base=db,
title_txt=t,
)
# Shared colorbar
if show_legend:
sm = plt.cm.ScalarMappable(
norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax),
cmap=cmap,
)
cax = fig.add_axes([0.92, 0.20, 0.02, 0.60])
cb = fig.colorbar(sm, cax=cax)
key = "subsidence_cum" if plot_cum else "subsidence"
cb.set_label(utils.label(key))
# Suptitle
if show_title:
ttl = utils.resolve_title(
default="Fig. 6 — Spatial validation + forecasts",
title=title,
)
fig.suptitle(ttl, x=0.02, ha="left")
# Save figure
fig_p = utils.resolve_fig_out(out)
if fig_p.suffix:
fig_p = fig_p.with_suffix("")
fig.savefig(
str(fig_p) + ".png",
dpi=dpi,
bbox_inches="tight",
)
fig.savefig(str(fig_p) + ".svg", bbox_inches="tight")
plt.close(fig)
print(f"[OK] wrote {fig_p}.png/.svg")
# Save hotspot points (to scripts/out/)
if hotspot_mode != "none":
if not hs_rows:
print("[WARN] hotspot enabled but none found.")
return
if out_hotspots:
out_hs = utils.as_path(out_hotspots)
else:
out_hs = utils.resolve_out_out(
"fig6-hotspot-points.csv"
)
pd.DataFrame(hs_rows).to_csv(out_hs, index=False)
print(f"[OK] wrote {out_hs}")
# ---------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------
def _add_fig6_args(ap: argparse.ArgumentParser) -> None:
utils.add_city_flags(ap, default_both=True)
ap.add_argument(
"--ns-src",
type=str,
default=None,
help="Nansha results dir (auto-detect).",
)
ap.add_argument(
"--zh-src",
type=str,
default=None,
help="Zhongshan results dir (auto-detect).",
)
ap.add_argument(
"--ns-calib",
type=str,
default=None,
help="Override Nansha calibrated CSV.",
)
ap.add_argument(
"--zh-calib",
type=str,
default=None,
help="Override Zhongshan calibrated CSV.",
)
ap.add_argument(
"--ns-future",
type=str,
default=None,
help="Override Nansha future CSV.",
)
ap.add_argument(
"--zh-future",
type=str,
default=None,
help="Override Zhongshan future CSV.",
)
ap.add_argument(
"--split",
type=str,
choices=["auto", "val", "test"],
default="auto",
help="Pick val/test artifacts when using --*-src.",
)
ap.add_argument(
"--year-val",
type=int,
default=2022,
help="Validation year for observed/predicted maps.",
)
ap.add_argument(
"--years-forecast",
nargs="+",
type=int,
default=[2025, 2026],
help="Forecast years to plot as additional columns.",
)
ap.add_argument(
"--cumulative",
action="store_true",
help="Use cumulative q50 for forecast years.",
)
ap.add_argument(
"--subsidence-kind",
type=str,
default="cumulative",
choices=["cumulative", "rate", "increment"],
help=(
"Meaning of subsidence columns in CSVs. "
"Default: cumulative."
),
)
ap.add_argument(
"--grid-res",
type=int,
default=300,
help="Raster resolution (NxN bins).",
)
ap.add_argument(
"--clip",
type=float,
default=98.0,
help="Percentile clip for shared color scale.",
)
ap.add_argument(
"--cmap",
type=str,
default="viridis",
help="Matplotlib colormap name.",
)
ap.add_argument(
"--hotspot",
type=str,
choices=["none", "absolute", "delta"],
default="delta",
help="Hotspot contour mode.",
)
ap.add_argument(
"--hotspot-quantile",
type=float,
default=0.90,
help="Quantile threshold for hotspots.",
)
ap.add_argument(
"--hotspot-out",
type=str,
default=None,
help="CSV output for hotspot points.",
)
ap.add_argument("--dpi", type=int, default=cfg.PAPER_DPI)
ap.add_argument(
"--font",
type=int,
default=cfg.PAPER_FONT,
)
utils.add_plot_text_args(
ap,
default_out="fig6-spatial-forecasts",
)
[docs]
def plot_fig6_spatial_forecasts_main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
ap = argparse.ArgumentParser(
prog=prog or "plot-spatial-forecasts",
description="Fig. 6 spatial maps + forecasts.",
)
_add_fig6_args(ap)
args = ap.parse_args(argv)
show_legend = utils.str_to_bool(
args.show_legend, default=True
)
show_title = utils.str_to_bool(
args.show_title, default=True
)
show_pt = utils.str_to_bool(
args.show_panel_titles, default=True
)
cities0 = utils.resolve_cities(args)
if not cities0:
cities0 = [_CITY_A, _CITY_B]
want_ns = _CITY_A in cities0
want_zh = _CITY_B in cities0
if not want_ns and not want_zh:
want_ns = True
want_zh = True
cities: list[dict[str, Any]] = []
if want_ns:
cities.append(
_resolve_city(
city=_CITY_A,
src=args.ns_src,
calib=args.ns_calib,
future=args.ns_future,
split=args.split,
)
)
if want_zh:
cities.append(
_resolve_city(
city=_CITY_B,
src=args.zh_src,
calib=args.zh_calib,
future=args.zh_future,
split=args.split,
)
)
plot_fig6_spatial_forecasts(
cities=cities,
year_val=int(args.year_val),
years_fore=list(args.years_forecast),
cumulative=bool(args.cumulative),
subsidence_kind=str(args.subsidence_kind),
grid_res=int(args.grid_res),
clip=float(args.clip),
cmap_name=str(args.cmap),
hotspot_mode=str(args.hotspot),
hotspot_q=float(args.hotspot_quantile),
out=str(args.out),
out_hotspots=args.hotspot_out,
dpi=int(args.dpi),
font=int(args.font),
show_legend=bool(show_legend),
show_title=bool(show_title),
show_panel_titles=bool(show_pt),
title=args.title,
)
[docs]
def main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
plot_fig6_spatial_forecasts_main(argv, prog=prog)
if __name__ == "__main__":
main()