# Source code for geoprior.scripts.plot_sm3_log_offsets

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for plotting SM3 log-offset diagnostics."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from . import config as cfg
from . import utils

_EPS = 1e-12


def _load_meta(path: Path) -> dict[str, Any]:
    """Load the optional ``<path>.meta.json`` sidecar of a payload file.

    Parameters
    ----------
    path
        Path of the payload file. The sidecar name is built by appending
        ``".meta.json"`` to the full payload path (suffix included).

    Returns
    -------
    dict
        The decoded JSON object, or an empty dict when the sidecar is
        missing, unreadable, not valid JSON, or does not hold a JSON
        object at its top level.
    """
    meta_p = Path(str(path) + ".meta.json")
    try:
        meta = json.loads(meta_p.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # Metadata is best-effort: a missing/unreadable file (OSError),
        # malformed JSON or bad encoding (both ValueError subclasses), or
        # absurdly deep nesting must not abort the diagnostics run.
        return {}

    # Valid JSON may still be a list/scalar; callers expect a mapping.
    return meta if isinstance(meta, dict) else {}


def _load_payload(path: Path) -> dict[str, np.ndarray]:
    """Read a payload file into a ``{name: array}`` mapping.

    The format is chosen from the (case-insensitive) file suffix:
    ``.npz`` archives yield one entry per stored array, while ``.csv``
    and ``.parquet`` tables yield one entry per column.

    Raises
    ------
    ValueError
        If the suffix is none of ``.npz``, ``.csv`` or ``.parquet``.
    """
    ext = path.suffix.lower()

    if ext not in (".npz", ".csv", ".parquet"):
        raise ValueError(f"Unsupported payload: {path}")

    if ext == ".npz":
        # NOTE(review): allow_pickle=True lets NumPy unpickle object
        # arrays, which can run arbitrary code if the archive comes from
        # an untrusted source. Only load payloads you produced yourself.
        with np.load(str(path), allow_pickle=True) as archive:
            return {
                name: np.asarray(archive[name])
                for name in archive.files
            }

    # Both tabular formats are handled identically once read.
    reader = pd.read_csv if ext == ".csv" else pd.read_parquet
    frame = reader(str(path))
    return {name: frame[name].to_numpy() for name in frame.columns}


def _pick(
    payload: dict[str, np.ndarray],
    *keys: str,
) -> np.ndarray | None:
    """Return the first usable entry of *payload* among *keys*.

    Keys are tried in the order given; a key that is absent or maps to
    ``None`` is skipped. The match is returned as an ndarray, or ``None``
    when no key matches.
    """
    for name in keys:
        value = payload.get(name)
        if value is None:
            continue
        return np.asarray(value)
    return None


def _as_1d(x: np.ndarray) -> np.ndarray:
    """Return *x* as a flat, one-dimensional ndarray."""
    arr = np.asarray(x)
    return np.reshape(arr, -1)


def _safe_log10(x: np.ndarray, eps: float) -> np.ndarray:
    """Return ``log10`` of *x* after flooring its values at *eps*.

    The input is converted to float and every value below *eps* is
    raised to *eps* before the logarithm is taken, so zero and negative
    entries map to ``log10(eps)``.
    """
    floored = np.clip(np.asarray(x, float), eps, None)
    return np.log10(floored)


def _ensure_same_n(
    cols: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Flatten every column and check that all share one length.

    Returns a new dict, in the same key order, whose values are the 1-D
    versions of the inputs. The first column fixes the reference length.

    Raises
    ------
    ValueError
        If a later column's flattened size differs from the first one's.
    """
    flattened: dict[str, np.ndarray] = {}
    expected: int | None = None

    for name, values in cols.items():
        flat = _as_1d(np.asarray(values))
        size = int(flat.size)
        if expected is None:
            expected = size
        elif size != expected:
            raise ValueError(
                f"Inconsistent size for {name}: "
                f"{flat.size} vs {expected}"
            )
        flattened[name] = flat

    return flattened


def _prior_series(
    *,
    n: int,
    payload: dict[str, np.ndarray],
    payload_keys: tuple[str, ...],
    scalar: float | None,
) -> np.ndarray | None:
    """Build a length-*n* prior array from the payload or a scalar.

    A prior found in *payload* under any of *payload_keys* takes
    precedence and must flatten to exactly *n* values. Otherwise *scalar*
    is broadcast to a float array of length *n*. ``None`` is returned
    when neither source is available.

    Raises
    ------
    ValueError
        If the payload prior does not have *n* elements.
    """
    embedded = _pick(payload, *payload_keys)

    if embedded is None:
        # No per-sample prior in the payload: fall back to the scalar.
        if scalar is None:
            return None
        return np.full((n,), float(scalar), dtype=float)

    flat = _as_1d(embedded)
    if flat.size != n:
        raise ValueError(
            f"Prior array size mismatch: {flat.size} vs {n}"
        )
    return flat


def build_offsets_table(
    payload: dict[str, np.ndarray],
    *,
    K_prior: float | None,
    Ss_prior: float | None,
    Hd_prior: float | None,
    eps: float = _EPS,
) -> pd.DataFrame:
    """Assemble the per-sample table of log10 values and log offsets.

    Parameters
    ----------
    payload
        Mapping of array names to arrays. It must provide ``tau`` (or
        ``tau_eff``), a tau prior (``tau_prior``, ``tau_closure`` or
        ``tau_cl``), and the ``K``, ``Ss`` and ``Hd`` fields under one of
        their accepted aliases.
    K_prior, Ss_prior, Hd_prior
        Scalar priors used only when the payload has no matching prior
        array. When neither exists, that ``delta_log*`` column is omitted.
    eps
        Floor applied to values before taking ``log10``.

    Returns
    -------
    pandas.DataFrame
        Columns ``log10_tau``, ``log10_tau_prior``, ``delta_log_tau``,
        ``log10_K``, ``log10_Ss``, ``log10_Hd``, then whichever of
        ``delta_logK``, ``delta_logSs``, ``delta_logHd`` are available,
        and finally a 0-based ``index`` column.

    Raises
    ------
    KeyError
        If a required payload array is missing.
    ValueError
        If the payload arrays (or an embedded prior) differ in length.
    """
    tau_raw = _pick(payload, "tau", "tau_eff")
    tau_prior_raw = _pick(
        payload, "tau_prior", "tau_closure", "tau_cl"
    )
    # Keyed by the output column each field feeds.
    fields_raw = {
        "log10_K": _pick(payload, "K", "K_eff", "K_field"),
        "log10_Ss": _pick(payload, "Ss", "Ss_eff", "Ss_field"),
        "log10_Hd": _pick(payload, "Hd", "H_d", "H", "H_field"),
    }

    if tau_raw is None or tau_prior_raw is None:
        raise KeyError(
            "payload must contain tau and tau_prior/tau_closure"
        )
    if any(values is None for values in fields_raw.values()):
        raise KeyError(
            "payload must contain K, Ss, Hd (or H)"
        )

    tau_1d = _as_1d(tau_raw)
    tau_prior_1d = _as_1d(tau_prior_raw)
    fields = {
        col: _as_1d(values) for col, values in fields_raw.items()
    }

    n = int(tau_1d.size)
    sizes = [tau_prior_1d.size]
    sizes.extend(values.size for values in fields.values())
    if any(size != n for size in sizes):
        raise ValueError(
            "payload arrays must have same length"
        )

    log_tau = _safe_log10(tau_1d, eps)
    log_tau_prior = _safe_log10(tau_prior_1d, eps)
    cols: dict[str, np.ndarray] = {
        "log10_tau": log_tau,
        "log10_tau_prior": log_tau_prior,
        "delta_log_tau": log_tau - log_tau_prior,
    }
    for col, values in fields.items():
        cols[col] = _safe_log10(values, eps)

    # (delta column, log column it is measured from, prior aliases, scalar)
    prior_specs = (
        (
            "delta_logK",
            "log10_K",
            ("K_prior", "K_lith_prior", "k_prior"),
            K_prior,
        ),
        (
            "delta_logSs",
            "log10_Ss",
            ("Ss_prior", "Ss_lith_prior", "Ss0"),
            Ss_prior,
        ),
        (
            "delta_logHd",
            "log10_Hd",
            ("Hd_prior", "H_d_prior", "Hd_lith_prior"),
            Hd_prior,
        ),
    )
    # Resolve every prior first so size errors surface before any delta
    # column is added.
    resolved = [
        (
            delta_col,
            log_col,
            _prior_series(
                n=n,
                payload=payload,
                payload_keys=aliases,
                scalar=scalar,
            ),
        )
        for delta_col, log_col, aliases, scalar in prior_specs
    ]
    for delta_col, log_col, prior in resolved:
        if prior is not None:
            cols[delta_col] = cols[log_col] - _safe_log10(prior, eps)

    cols = _ensure_same_n(cols)
    cols["index"] = np.arange(n, dtype=int)

    return pd.DataFrame(cols)
def summarise_offsets(df: pd.DataFrame) -> pd.DataFrame:
    """Summarise every ``delta_*`` column of *df*.

    Returns
    -------
    pandas.DataFrame
        One row per ``delta_*`` column (index named ``"metric"``) with
        the columns ``mean``, ``std``, ``p05``, ``p50`` and ``p95``.

    Raises
    ------
    RuntimeError
        If *df* has no column whose name starts with ``"delta_"``.
    """
    delta_cols = [
        name for name in df.columns if name.startswith("delta_")
    ]
    if len(delta_cols) == 0:
        raise RuntimeError("No delta_* columns to summarise.")

    stats = (
        df[delta_cols]
        .describe(percentiles=[0.05, 0.5, 0.95])
        .T
        .rename(columns={"5%": "p05", "50%": "p50", "95%": "p95"})
    )

    summary = stats[["mean", "std", "p05", "p50", "p95"]]
    summary.index.name = "metric"
    return summary
def _beautify(ax: plt.Axes) -> None:
    """Hide the top/right spines of *ax* and use short outward ticks."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.tick_params(direction="out", length=3, width=0.6)
def _save_png_svg(fig: Any, stem: str, *, dpi: int) -> list[str]:
    """Save *fig* as ``<stem>.png`` (at *dpi*) and ``<stem>.svg``.

    Returns the two written paths, PNG first.
    """
    png_path = stem + ".png"
    svg_path = stem + ".svg"
    fig.savefig(png_path, dpi=dpi, bbox_inches="tight")
    fig.savefig(svg_path, bbox_inches="tight")
    return [png_path, svg_path]


def plot_offsets(
    df: pd.DataFrame,
    *,
    out_base: str,
    dpi: int,
    show_title: bool,
    title: str | None,
    bins: int,
) -> list[str]:
    """Render the log-offset diagnostic figures and save them to disk.

    Two figures are produced, each as PNG and SVG:

    * ``<base>-hists``: a 2x2 grid of histograms for ``delta_logK``,
      ``delta_logSs``, ``delta_logHd`` and ``delta_log_tau``. A panel
      whose column is absent from *df* shows a "not available" note.
    * ``<base>-tau-scatter``: ``delta_log_tau`` against
      ``log10_tau_prior``, drawn only when both columns exist.

    Parameters
    ----------
    df
        Offsets table holding the ``delta_*`` / ``log10_*`` columns.
    out_base
        Output base name, resolved through ``utils.resolve_fig_out``;
        any file suffix on the resolved path is dropped.
    dpi
        Resolution used for the PNG outputs.
    show_title
        Whether to add a figure title to the histogram grid.
    title
        Optional title override passed to ``utils.resolve_title``.
    bins
        Number of histogram bins.

    Returns
    -------
    list of str
        Paths of all files written, in the order they were saved.
    """
    utils.ensure_script_dirs()
    utils.set_paper_style()

    out_paths: list[str] = []

    base = utils.resolve_fig_out(out_base)
    if base.suffix:
        base = base.with_suffix("")

    delta_cols = [
        "delta_logK",
        "delta_logSs",
        "delta_logHd",
        "delta_log_tau",
    ]

    # --------------------------
    # Figure 1: 2x2 hist grid
    # --------------------------
    fig = plt.figure(
        figsize=(7.2, 4.2), constrained_layout=True
    )
    # try/finally: pyplot keeps a reference to every open figure, so one
    # left unclosed after a drawing or save error would leak.
    try:
        gs = fig.add_gridspec(2, 2)

        for i, col in enumerate(delta_cols):
            row, column = divmod(i, 2)
            ax = fig.add_subplot(gs[row, column])
            _beautify(ax)

            if col in df.columns:
                x = df[col].to_numpy(float)
                # Histogram only finite values; NaN/inf would break binning.
                x = x[np.isfinite(x)]
                ax.hist(x, bins=int(bins))
                ax.set_xlabel(col)
                ax.set_ylabel("Count")
            else:
                ax.text(
                    0.5,
                    0.5,
                    f"{col} not available",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
                ax.set_xticks([])
                ax.set_yticks([])

        if show_title:
            ttl = utils.resolve_title(
                default="SM3 — log-offset diagnostics",
                title=title,
            )
            fig.suptitle(ttl, x=0.02, ha="left")

        out_paths += _save_png_svg(
            fig, str(base) + "-hists", dpi=dpi
        )
    finally:
        plt.close(fig)

    # --------------------------
    # Figure 2: tau scatter
    # --------------------------
    if {"log10_tau_prior", "delta_log_tau"}.issubset(
        df.columns
    ):
        fig2 = plt.figure(
            figsize=(3.6, 3.0), constrained_layout=True
        )
        try:
            ax2 = fig2.add_subplot(1, 1, 1)
            _beautify(ax2)

            ax2.scatter(
                df["log10_tau_prior"],
                df["delta_log_tau"],
                s=6,
                alpha=0.35,
                rasterized=True,
            )
            # Zero line marks "no offset from the prior".
            ax2.axhline(0.0, linestyle="--", linewidth=0.9)
            ax2.set_xlabel(r"$\log_{10}\tau_{\mathrm{prior}}$")
            ax2.set_ylabel(r"$\delta_\tau$")

            out_paths += _save_png_svg(
                fig2, str(base) + "-tau-scatter", dpi=dpi
            )
        finally:
            plt.close(fig2)

    return out_paths
def _resolve_payload_path(
    *,
    src: str | None,
    payload: str | None,
) -> Path:
    """Pick the payload file from an explicit path or a source folder.

    An explicit *payload* wins and is returned with ``~`` expanded.
    Otherwise the latest file under *src* matching
    ``cfg.PATTERNS["physics_payload"]`` is looked up via
    ``utils.find_latest``.

    Raises
    ------
    ValueError
        If neither *payload* nor *src* is given.
    FileNotFoundError
        If no payload is found under *src*.
    """
    if payload:
        explicit = Path(payload)
        return explicit.expanduser()

    if not src:
        raise ValueError("Provide --payload or --src")

    latest = utils.find_latest(
        src, cfg.PATTERNS["physics_payload"]
    )
    if latest is not None:
        return latest

    raise FileNotFoundError(
        f"No physics payload under: {src}"
    )
def plot_sm3_log_offsets_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Run the SM3 log-offset diagnostics command-line workflow.

    Parses *argv*, locates and loads the physics payload plus its
    optional metadata sidecar, builds the offsets table and its summary,
    writes both as CSV, renders the figures, and writes a JSON report
    listing every output. One ``[OK]`` line is printed per artefact.

    Parameters
    ----------
    argv
        Argument list handed to ``argparse``; ``None`` lets ``argparse``
        read the process arguments.
    prog
        Program name shown in help; defaults to
        ``"plot-sm3-log-offsets"``.
    """
    parser = argparse.ArgumentParser(
        prog=prog or "plot-sm3-log-offsets",
        description="SM3 log-offset diagnostics (v3.2 payload).",
    )

    # Input selection.
    for flag in ("--src", "--payload"):
        parser.add_argument(flag, type=str, default=None)

    # Scalar priors, used when the payload embeds none.
    for flag in ("--K-prior", "--Ss-prior", "--Hd-prior"):
        parser.add_argument(flag, type=float, default=None)

    parser.add_argument("--bins", type=int, default=50)
    parser.add_argument("--dpi", type=int, default=cfg.PAPER_DPI)

    # Output file names.
    for flag, default_name in (
        ("--out-raw-csv", "sm3-offsets-raw.csv"),
        ("--out-summary-csv", "sm3-offsets-summary.csv"),
        ("--out-json", "sm3-offsets.json"),
    ):
        parser.add_argument(flag, type=str, default=default_name)

    utils.add_plot_text_args(
        parser,
        default_out="sm3-log-offsets",
    )

    ns = parser.parse_args(argv)

    payload_path = _resolve_payload_path(
        src=ns.src,
        payload=ns.payload,
    )
    arrays = _load_payload(payload_path)
    meta = _load_meta(payload_path)

    table = build_offsets_table(
        arrays,
        K_prior=ns.K_prior,
        Ss_prior=ns.Ss_prior,
        Hd_prior=ns.Hd_prior,
    )
    summary = summarise_offsets(table)

    raw_csv = utils.resolve_out_out(ns.out_raw_csv)
    summary_csv = utils.resolve_out_out(ns.out_summary_csv)
    report_json = utils.resolve_out_out(ns.out_json)

    table.to_csv(raw_csv, index=False)
    summary.to_csv(summary_csv)

    show_title = utils.str_to_bool(
        ns.show_title, default=True
    )

    figures = plot_offsets(
        table,
        out_base=ns.out,
        dpi=int(ns.dpi),
        show_title=show_title,
        title=ns.title,
        bins=int(ns.bins),
    )

    # Round-trip the summary through pandas' JSON writer so the report
    # holds plain JSON-serialisable values.
    summary_records = json.loads(
        summary.reset_index().to_json(orient="records")
    )
    report = {
        "payload": str(payload_path.resolve()),
        "meta": meta,
        "n": int(len(table)),
        "columns": list(table.columns),
        "summary": summary_records,
        "figures": figures,
        "tables": {
            "raw_csv": str(raw_csv),
            "summary_csv": str(summary_csv),
        },
    }
    report_json.write_text(
        json.dumps(report, indent=2),
        encoding="utf-8",
    )

    print(f"[OK] payload: {payload_path}")
    for written in (raw_csv, summary_csv, report_json, *figures):
        print(f"[OK] wrote {written}")
def main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Script entry point; forwards to ``plot_sm3_log_offsets_main``."""
    plot_sm3_log_offsets_main(argv=argv, prog=prog)
if __name__ == "__main__":
    main()

# Usage examples
# --------------
# Auto-discover payload under a run folder:
#   python -m scripts plot-sm3-log-offsets \
#     --src results/sm3_synth_1d \
#     --out sm3-log-offsets
#
# Explicit payload file:
#   python -m scripts plot-sm3-log-offsets \
#     --payload results/.../physics_payload_run_val.npz \
#     --out sm3-log-offsets
#
# If your payload does not embed priors, pass scalar priors:
#   python -m scripts plot-sm3-log-offsets \
#     --payload results/.../physics_payload_run_val.npz \
#     --K-prior 1e-7 --Ss-prior 1e-5 --Hd-prior 40 \
#     --out sm3-log-offsets