# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Plot lithology parity across cities.
Left: normalized composition bars (Nansha vs Zhongshan).
Right: difference bars (Zhongshan - Nansha).
Arguments
---------
- ``--src``: dataset directory
- ``--col``: column name, default ``lithology_class``
- ``--year``: ``all`` or a year integer
- ``--out``: output stem/path, saved into ``scripts/figs/``
- ``-ns`` / ``-zh``: city codes; the default uses both
"""
from __future__ import annotations
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.gridspec import GridSpec
from . import config as cfg
from . import utils
CITY_COLORS = cfg.CITY_COLORS
def _erf(x: np.ndarray) -> np.ndarray:
s = np.sign(x)
x = np.abs(x)
a1, a2, a3 = 0.254829592, -0.284496736, 1.421413741
a4, a5 = -1.453152027, 1.061405429
t = 1.0 / (1.0 + 0.3275911 * x)
y = 1.0 - (
(((((a5 * t + a4) * t) + a3) * t + a2) * t + a1)
* t
* np.exp(-x * x)
)
return s * y
def _phi(z: float) -> float:
zz = np.array([z], dtype=float)
return float(0.5 * (1.0 + _erf(zz / np.sqrt(2.0)))[0])
[docs]
def chisq_cramers_v(
counts_2xk: np.ndarray,
) -> tuple[float, float, int, float]:
rs = counts_2xk.sum(axis=1, keepdims=True)
cs = counts_2xk.sum(axis=0, keepdims=True)
tot = float(counts_2xk.sum())
if tot <= 0.0:
return np.nan, np.nan, 0, np.nan
exp = (rs @ cs) / tot
with np.errstate(divide="ignore", invalid="ignore"):
chi2 = np.nansum(((counts_2xk - exp) ** 2) / exp)
dof = int(
(counts_2xk.shape[0] - 1) * (counts_2xk.shape[1] - 1)
)
pval = np.nan
if dof > 0 and np.isfinite(chi2):
k = float(dof)
z = (
((chi2 / k) ** (1.0 / 3.0))
- (1.0 - 2.0 / (9.0 * k))
) / np.sqrt(2.0 / (9.0 * k))
pval = 1.0 - _phi(float(z))
pval = float(np.clip(pval, 0.0, 1.0))
r, c = counts_2xk.shape
denom = tot * max(1, min(r - 1, c - 1))
cv = float(np.sqrt(chi2 / denom)) if denom > 0 else np.nan
return float(chi2), float(pval), dof, cv
[docs]
def load_city_df(
src: Path,
filename: str,
*,
year: str = "all",
sample_frac: float | None = None,
sample_n: int | None = None,
seed: int = 42,
) -> pd.DataFrame:
fp = (src / filename).expanduser()
if not fp.exists():
raise FileNotFoundError(str(fp))
df = pd.read_csv(fp)
if year != "all":
df = df.loc[df["year"] == int(year)].copy()
if sample_n is not None:
n = min(int(sample_n), len(df))
df = df.sample(n=n, random_state=seed)
elif sample_frac is not None:
f = float(sample_frac)
f = max(0.0, min(1.0, f))
if f < 1.0:
df = df.sample(frac=f, random_state=seed)
return df
[docs]
def compute_proportions(
ns: pd.DataFrame,
zh: pd.DataFrame,
*,
col: str,
top_n: int,
group_others: bool,
) -> tuple[pd.DataFrame, list[str], np.ndarray]:
c_ns = ns[col].value_counts(dropna=False)
c_zh = zh[col].value_counts(dropna=False)
classes = list(
set(c_ns.index.tolist()) | set(c_zh.index.tolist())
)
tot = {
k: int(c_ns.get(k, 0)) + int(c_zh.get(k, 0))
for k in classes
}
classes = sorted(
classes, key=lambda k: tot[k], reverse=True
)
core = classes[: int(top_n)]
rest = classes[int(top_n) :]
def _props(vc: pd.Series) -> pd.Series:
s = float(vc.sum())
if s <= 0.0:
s = 1.0
return vc / s
p_ns = _props(c_ns)
p_zh = _props(c_zh)
data: list[tuple] = []
for cls in core:
data.append(
(cls, "Nansha", float(p_ns.get(cls, 0.0)))
)
data.append(
(cls, "Zhongshan", float(p_zh.get(cls, 0.0)))
)
if group_others and len(rest) > 0:
ns_o = float(p_ns.loc[rest].sum())
zh_o = float(p_zh.loc[rest].sum())
data.append(("Others", "Nansha", ns_o))
data.append(("Others", "Zhongshan", zh_o))
core = core + ["Others"]
dfp = pd.DataFrame(
data,
columns=["class", "city", "proportion"],
)
order = (
dfp.groupby("class")["proportion"]
.max()
.sort_values(ascending=True)
.index.tolist()
)
dfp["class"] = pd.Categorical(
dfp["class"],
categories=order,
ordered=True,
)
dfp = dfp.sort_values(["class", "city"]).reset_index(
drop=True
)
k = len(core)
mat = np.zeros((2, k), dtype=float)
for j, cls in enumerate(core):
if cls == "Others":
ns_c = (
float(c_ns.loc[rest].sum())
if len(rest)
else 0.0
)
zh_c = (
float(c_zh.loc[rest].sum())
if len(rest)
else 0.0
)
else:
ns_c = float(c_ns.get(cls, 0.0))
zh_c = float(c_zh.get(cls, 0.0))
mat[0, j] = ns_c
mat[1, j] = zh_c
return dfp, core, mat
def _extract_props(
dfp: pd.DataFrame,
classes_order: list[str],
) -> tuple[list[float], list[float]]:
ns_props: list[float] = []
zh_props: list[float] = []
for cls in classes_order:
m_ns = (dfp["class"] == cls) & (
dfp["city"] == "Nansha"
)
m_zh = (dfp["class"] == cls) & (
dfp["city"] == "Zhongshan"
)
ns_props.append(
float(dfp.loc[m_ns, "proportion"].iloc[0])
if m_ns.any()
else 0.0
)
zh_props.append(
float(dfp.loc[m_zh, "proportion"].iloc[0])
if m_zh.any()
else 0.0
)
return ns_props, zh_props
[docs]
def draw_lithology_parity(
dfp: pd.DataFrame,
counts_mat: np.ndarray,
*,
col: str,
outpath: Path,
sharey: bool,
) -> None:
utils.ensure_script_dirs()
utils.set_paper_style()
chi2, pval, dof, cv = chisq_cramers_v(counts_mat)
fig = plt.figure(figsize=(7.0, 4.0))
gs = GridSpec(
1,
2,
figure=fig,
width_ratios=[1.2, 0.8],
wspace=0.35,
)
axA = fig.add_subplot(gs[0, 0])
classes = list(dfp["class"].cat.categories)
y = np.arange(len(classes))
ns_props, zh_props = _extract_props(dfp, classes)
h = 0.35
axA.barh(
y - h / 2.0,
ns_props,
height=h,
color=CITY_COLORS["Nansha"],
label="Nansha",
)
axA.barh(
y + h / 2.0,
zh_props,
height=h,
color=CITY_COLORS["Zhongshan"],
label="Zhongshan",
)
axA.set_yticks(y)
axA.set_yticklabels(classes)
axA.set_xlabel("Proportion (normalized)")
axA.set_xlim(0.0, 1.0)
axA.legend(frameon=False, loc="lower right")
col_lbl = utils.label(col, with_unit=False)
axA.set_title(f"{col_lbl} — composition")
axB = fig.add_subplot(
gs[0, 1], sharey=axA if sharey else None
)
diff = np.asarray(zh_props) - np.asarray(ns_props)
axB.axvline(0.0, color="#444444", lw=0.8)
axB.barh(y, diff, height=0.6, color="#888888")
axB.set_yticks(y)
if sharey:
axB.tick_params(axis="y", left=False, labelleft=False)
else:
axB.set_yticklabels(classes)
axB.set_xlabel("Δ proportion (Zhongshan − Nansha)")
axB.set_title("Parity difference")
stat = (
f"χ²={chi2:.2f}, dof={dof}, p≈{pval:.3f}, V={cv:.3f}"
)
fig.suptitle(
f"Lithology parity across cities • {stat}",
x=0.02,
y=0.99,
ha="left",
)
base = utils.resolve_fig_out(str(outpath))
if base.suffix:
base = base.with_suffix("")
fig.savefig(str(base) + ".png", bbox_inches="tight")
fig.savefig(str(base) + ".svg", bbox_inches="tight")
plt.close(fig)
print(f"[OK] Wrote {base}.png/.svg")
def _build_argparser(
*, prog: str | None = None
) -> argparse.ArgumentParser:
ap = argparse.ArgumentParser(
prog=prog or "plot-litho-parity",
description="Supplementary Fig. S1.",
)
ap.add_argument(
"--src",
type=str,
default=str(
Path(r"F:\repositories\geoprior-learn" r"\data")
),
help="Final dataset directory.",
)
ap.add_argument(
"--ns-file",
type=str,
default="nansha_dataset.final.ready.csv",
)
ap.add_argument(
"--zh-file",
type=str,
default="zhongshan_dataset.final.ready.csv",
)
utils.add_city_flags(ap, default_both=True)
ap.add_argument(
"--col",
"-c",
type=str,
default="lithology_class",
)
ap.add_argument(
"--year",
"-y",
type=str,
default="all",
)
ap.add_argument(
"--sample-frac",
type=float,
default=None,
)
ap.add_argument(
"--sample-n",
type=int,
default=None,
)
ap.add_argument(
"--top-n",
type=int,
default=8,
)
ap.add_argument(
"--group-others",
type=str,
default="true",
)
ap.add_argument(
"--sharey",
type=str,
choices=["true", "false"],
default="true",
)
ap.add_argument(
"--out",
"-o",
type=str,
default="figS1_lithology_parity",
)
return ap
[docs]
def figS1_lithology_parity_main(
argv: list[str] | None = None,
*,
prog: str | None = None,
) -> None:
ap = _build_argparser(prog=prog)
args = ap.parse_args(argv)
cities = utils.resolve_cities(args)
if set(cities) != {"Nansha", "Zhongshan"}:
raise ValueError(
"Fig S1 needs both cities. Use --cities ns,zh "
"or pass -ns -zh."
)
src = Path(args.src).expanduser()
group_others = utils.str_to_bool(args.group_others)
sharey = utils.str_to_bool(args.sharey)
ns = load_city_df(
src,
args.ns_file,
year=args.year,
sample_frac=(
None if args.sample_n else args.sample_frac
),
sample_n=args.sample_n,
)
zh = load_city_df(
src,
args.zh_file,
year=args.year,
sample_frac=(
None if args.sample_n else args.sample_frac
),
sample_n=args.sample_n,
)
if (
args.col not in ns.columns
or args.col not in zh.columns
):
raise KeyError(
f"Column '{args.col}' must exist in both CSVs."
)
dfp, _, mat = compute_proportions(
ns,
zh,
col=args.col,
top_n=args.top_n,
group_others=group_others,
)
draw_lithology_parity(
dfp,
mat,
col=args.col,
outpath=Path(args.out),
sharey=sharey,
)
[docs]
def main(
argv: list[str] | None = None, *, prog: str | None = None
) -> None:
figS1_lithology_parity_main(argv, prog=prog)
if __name__ == "__main__":
main()
# python -m scripts.scripts \
# --src "F:\...\final_dataset" \
# -ns -zh \
# --col lithology_class \
# --year all \
# --out figS1_lithology_parity
# 4) How you run things (from project root)
# With dispatcher
# python -m scripts plot-driver-response --src data
# python -m scripts plot-core-ablation --ns-with ...
# python -m scripts plot-litho-parity --src data -ns -zh
# Or run a single module directly
# python -m scripts.plot_driver_response --src data
# 5) Optional: real shell commands (recommended)
# If you want true commands like plot-driver-response without
# python -m, add to your pyproject.toml:
# [project.scripts]
# plot-driver-response = "scripts.plot_driver_response:main"
# plot-core-ablation = "scripts.plot_core_ablation:main"
# plot-litho-parity = "scripts.plot_litho_parity:main"
# paper-scripts = "scripts.__main__:main"
# Then:
# pip install -e .
# plot-driver-response --src data
# paper-scripts plot-core-ablation --ns-with ...
# Recommendation for planning
# Keep one file = one figure/script (plot_*.py).
# Put all shared CLI flags and style into utils.py
# (add_plot_text_args, add_city_flags, output resolvers).
# Keep config.py as the single truth for units, column aliases,
# and PATTERNS.
# Keep __main__.py as a simple dispatcher only.
# If you want, paste your current scripts/utils.py +
# scripts/config.py (the updated ones) and I’ll make sure the three
# scripts import/use the same helpers in exactly the same way (and keep
# every line ≤62 chars).