# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for plotting driver response analyses."""
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from . import config as cfg
from . import utils
[docs]
def robust_trend(
x: pd.Series,
y: pd.Series,
*,
nbins: int = 30,
min_n: int = 20,
) -> tuple[np.ndarray | None, np.ndarray | None]:
xv = pd.to_numeric(x, errors="coerce").to_numpy()
yv = pd.to_numeric(y, errors="coerce").to_numpy()
m = np.isfinite(xv) & np.isfinite(yv)
xv = xv[m]
yv = yv[m]
if xv.size < max(nbins * 2, 50):
return None, None
q = np.linspace(0.0, 1.0, nbins + 1)
bx = np.quantile(xv, q)
xi: list[float] = []
yi: list[float] = []
for k in range(nbins):
lo = bx[k]
hi = bx[k + 1]
sel = (xv >= lo) & (xv <= hi)
if int(sel.sum()) >= int(min_n):
xi.append(float(np.nanmedian(xv[sel])))
yi.append(float(np.nanmedian(yv[sel])))
if len(xi) < 3:
return None, None
return np.asarray(xi), np.asarray(yi)
def _panel(
ax,
df: pd.DataFrame,
*,
xcol: str,
ycol: str,
show_ylabel: bool,
gridsize: int,
vmin: int,
trend: bool,
trend_bins: int,
trend_min_n: int,
show_labels: bool,
show_ticklabels: bool,
) -> object:
x = pd.to_numeric(df[xcol], errors="coerce")
y = pd.to_numeric(df[ycol], errors="coerce")
m = np.isfinite(x) & np.isfinite(y)
hb = ax.hexbin(
x[m],
y[m],
gridsize=gridsize,
mincnt=vmin,
cmap="viridis",
linewidths=0.0,
bins="log",
)
if trend:
xi, yi = robust_trend(
x,
y,
nbins=trend_bins,
min_n=trend_min_n,
)
if xi is not None:
ax.plot(
xi,
yi,
color="k",
linewidth=1.0,
alpha=0.9,
)
if show_labels:
ax.set_xlabel(utils.label(xcol))
else:
ax.set_xlabel("")
if show_ylabel and show_labels:
ax.set_ylabel(utils.label(ycol))
else:
ax.set_ylabel("")
if not show_ticklabels:
ax.tick_params(
labelbottom=False,
labelleft=False,
)
else:
if not show_ylabel:
ax.tick_params(labelleft=False)
return hb
[docs]
def plot_driver_response(
df: pd.DataFrame,
*,
cities: list[str],
drivers: list[str],
ycol: str,
sharey: str,
out: str,
gridsize: int,
vmin: int,
trend: bool,
trend_bins: int,
trend_min_n: int,
show_legend: bool,
show_labels: bool,
show_ticklabels: bool,
show_title: bool,
show_panel_titles: bool,
title: str | None,
) -> None:
utils.ensure_script_dirs()
utils.set_paper_style()
nrows = len(cities)
ncols = len(drivers)
fig_h = 2.6 * float(nrows)
fig = plt.figure(figsize=(7.0, fig_h))
gs = GridSpec(
nrows,
ncols,
figure=fig,
wspace=0.30,
hspace=0.35,
)
axes: list[list[object]] = [
[None for _ in range(ncols)] for _ in range(nrows)
]
axes[0][0] = fig.add_subplot(gs[0, 0])
for j in range(1, ncols):
ref = axes[0][0] if sharey in ("row", "all") else None
axes[0][j] = fig.add_subplot(gs[0, j], sharey=ref)
for i in range(1, nrows):
ref0 = axes[0][0] if sharey == "all" else None
axes[i][0] = fig.add_subplot(gs[i, 0], sharey=ref0)
for j in range(1, ncols):
if sharey == "all":
ref = axes[0][0]
elif sharey == "row":
ref = axes[i][0]
else:
ref = None
axes[i][j] = fig.add_subplot(gs[i, j], sharey=ref)
last_hb = None
for i, city in enumerate(cities):
dfi = df.loc[df["city"] == city].copy()
for j, xcol in enumerate(drivers):
show_ylabel = (j == 0) and (
sharey != "all" or i == 0
)
hb = _panel(
axes[i][j],
dfi,
xcol=xcol,
ycol=ycol,
show_ylabel=show_ylabel,
gridsize=gridsize,
vmin=vmin,
trend=trend,
trend_bins=trend_bins,
trend_min_n=trend_min_n,
show_labels=show_labels,
show_ticklabels=show_ticklabels,
)
last_hb = hb
if j == 0 and show_panel_titles:
axes[i][j].set_title(
city,
loc="left",
fontweight="bold",
)
elif j == 0 and not show_panel_titles:
axes[i][j].set_title("")
if show_legend and last_hb is not None:
cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
cb = plt.colorbar(last_hb, cax=cax)
if show_labels:
cb.set_label("log$_{10}$(count)")
else:
cb.set_label("")
cb.ax.set_ylabel("")
if not show_ticklabels:
cb.ax.tick_params(labelleft=False)
default_title = (
"Driver–response sanity: response vs drivers"
)
st = utils.resolve_title(
default=default_title,
title=title,
)
if show_title:
fig.suptitle(
st,
x=0.02,
y=0.99,
ha="left",
)
base = utils.resolve_fig_out(out)
if base.suffix:
base = base.with_suffix("")
fig.savefig(str(base) + ".png", bbox_inches="tight")
fig.savefig(str(base) + ".svg", bbox_inches="tight")
plt.close(fig)
print(f"[OK] wrote {base}.png/.svg")
def _parse_list(s: str) -> list[str]:
parts = [p.strip() for p in str(s).split(",")]
return [p for p in parts if p]
def _add_plot_driver_response_args(ap) -> None:
ap.add_argument(
"--src",
type=str,
required=True,
help="CSV file or directory.",
)
ap.add_argument(
"--file",
type=str,
default=None,
help="Combined CSV name (when --src is dir).",
)
ap.add_argument(
"--ns-file",
type=str,
default="nansha_final_main_std.harmonized.csv",
)
ap.add_argument(
"--zh-file",
type=str,
default="zhongshan_final_main_std.harmonized.csv",
)
utils.add_city_flags(ap, default_both=True)
ap.add_argument(
"--year",
"-y",
type=str,
default="all",
)
ap.add_argument(
"--sample-frac",
type=float,
default=None,
)
ap.add_argument(
"--sample-n",
type=int,
default=None,
)
ap.add_argument(
"--drivers",
type=str,
default=",".join(cfg.DRIVER_RESPONSE_DEFAULT_DRIVERS),
help="Comma list of driver columns.",
)
ap.add_argument(
"--response",
"--col",
dest="response",
type=str,
default=None,
help="Response column (e.g. head_m).",
)
ap.add_argument(
"--subs-kind",
type=str,
choices=["cum", "rate"],
default="cum",
help="If --response is not set, choose "
"subsidence_cum or subsidence.",
)
ap.add_argument(
"--sharey",
type=str,
choices=["none", "row", "all"],
default="row",
)
ap.add_argument(
"--gridsize",
type=int,
default=60,
)
ap.add_argument(
"--vmin",
type=int,
default=1,
)
ap.add_argument(
"--trend",
type=str,
default="true",
)
ap.add_argument(
"--trend-bins",
type=int,
default=30,
)
ap.add_argument(
"--trend-min-n",
type=int,
default=20,
)
utils.add_plot_text_args(
ap,
default_out="figS2_driver_response",
)
[docs]
def plot_driver_response_main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
ap = argparse.ArgumentParser(
prog=prog or "plot-driver-response",
description="Driver–response sanity plot.",
)
_add_plot_driver_response_args(ap)
args = ap.parse_args(argv)
cities = utils.resolve_cities(args)
if not cities:
cities = ["Nansha", "Zhongshan"]
src = Path(args.src).expanduser()
df = utils.load_dataset_any(
src,
file=args.file,
ns_file=args.ns_file,
zh_file=args.zh_file,
)
if "city" not in df.columns:
raise KeyError("Dataset must contain 'city' column.")
df = utils.filter_year(df, args.year)
df = utils.sample_df(
df,
sample_frac=args.sample_frac,
sample_n=args.sample_n,
)
drivers = _parse_list(args.drivers)
utils.ensure_columns(
df,
aliases=cfg.COLUMN_ALIASES,
)
ycol = args.response
if not ycol:
if args.subs_kind == "cum":
ycol = "subsidence_cum"
else:
ycol = "subsidence"
if ycol not in df.columns:
alts = cfg.COLUMN_ALIASES.get(ycol, ())
raise KeyError(
f"Missing response '{ycol}'. "
f"Aliases tried: {list(alts)}"
)
for d in drivers:
if d not in df.columns:
alts = cfg.COLUMN_ALIASES.get(d, ())
raise KeyError(
f"Missing driver '{d}'. "
f"Aliases tried: {list(alts)}"
)
trend = utils.str_to_bool(args.trend)
show_legend = utils.str_to_bool(args.show_legend)
show_labels = utils.str_to_bool(args.show_labels)
show_ticks = utils.str_to_bool(args.show_ticklabels)
show_title = utils.str_to_bool(args.show_title)
show_panels = utils.str_to_bool(args.show_panel_titles)
plot_driver_response(
df,
cities=cities,
drivers=drivers,
ycol=ycol,
sharey=args.sharey,
out=args.out,
gridsize=int(args.gridsize),
vmin=int(args.vmin),
trend=trend,
trend_bins=int(args.trend_bins),
trend_min_n=int(args.trend_min_n),
show_legend=show_legend,
show_labels=show_labels,
show_ticklabels=show_ticks,
show_title=show_title,
show_panel_titles=show_panels,
title=args.title,
)
[docs]
def main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
plot_driver_response_main(argv, prog=prog)
if __name__ == "__main__":
main()