# Source code for geoprior.scripts.plot_sm3_identifiability

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

r"""Plot synthetic identifiability (SM3).

Panels
------
(a) Timescale recovery:
    log10(tau_true) vs log10(tau_est)
    + optional tau_cl overlay (closure-implied)
    + optional tau_prior marker.

(b) Permeability recovery:
    log10(K_true) vs log10(K_est)  OR
    log10(K_true) vs log10(K_cl(tau_est)).

(c) Degeneracy ridge check:
    delta_K vs (delta_Ss + 2 delta_Hd).

(d) Error vs identifiability metric:
    :math:`|\Delta \log_{10} \tau|` vs chosen identifiability metric.

Inputs
------
CSV file produced by SM3 synth runs, typically:
  results/sm3_synth_1d/sm3_synth_runs.csv

Outputs
-------
- Figure: scripts/figs/<out>.png and <out>.svg
- Summary table: scripts/out/<out_csv>
- Summary json: scripts/out/<out_json>

API conventions
---------------
- Style via scripts.utils.set_paper_style()
- Figure output via scripts.utils.resolve_fig_out()
- Tables/JSON via scripts.utils.resolve_out_out()
- Common plot text toggles via scripts.utils.add_plot_text_args()

Notes
-----
- Bootstrap CIs resample rows (paired x/y values) with replacement.
- If --k-from-tau=auto:
    * defaults to false when identify includes "both"
    * else defaults to true.
"""

from __future__ import annotations

import argparse
import json
import warnings
from collections.abc import Callable, Iterable
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from . import config as cfg
from . import utils


# -------------------------------------------------------------------
# Small numeric helpers
# -------------------------------------------------------------------
def _require_cols(
    df: pd.DataFrame,
    cols: Iterable[str],
    *,
    ctx: str = "",
) -> None:
    """Fail fast when ``df`` lacks any of the requested columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose column index is inspected.
    cols : iterable of str
        Column names that must be present.
    ctx : str, optional
        Short tag embedded in the error message.

    Raises
    ------
    KeyError
        If one or more names are absent; the message lists them in
        the order they were requested.
    """
    absent = [name for name in cols if name not in df.columns]
    if not absent:
        return

    listing = ", ".join(absent)
    raise KeyError(f"Missing cols ({ctx}): {listing}")


def _mask_pair(
    x: np.ndarray,
    y: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``x`` and ``y`` restricted to positions where both are finite.

    Both inputs are converted to float arrays first; NaN and +/-inf in
    either input drop that position from both outputs.
    """
    xs = np.asarray(x, float)
    ys = np.asarray(y, float)
    keep = np.logical_and(np.isfinite(xs), np.isfinite(ys))
    return xs[keep], ys[keep]


def _spearman(x: np.ndarray, y: np.ndarray) -> float:
    """Spearman rank correlation of the jointly finite pairs of ``x``/``y``.

    Returns NaN when fewer than three finite pairs remain or when either
    rank vector has zero spread.
    """
    xs, ys = _mask_pair(x, y)
    if xs.size < 3:
        return float("nan")

    ranks_x = pd.Series(xs).rank().to_numpy()
    ranks_y = pd.Series(ys).rank().to_numpy()

    # A constant rank vector makes the correlation undefined.
    for ranks in (ranks_x, ranks_y):
        if float(np.nanstd(ranks)) == 0.0:
            return float("nan")

    corr = np.corrcoef(ranks_x, ranks_y)
    return float(corr[0, 1])


def _linfit_stats(
    x: np.ndarray,
    y: np.ndarray,
) -> tuple[float, float, float]:
    """Least-squares line through the finite ``(x, y)`` pairs.

    Returns
    -------
    (slope, intercept, r2) : tuple of float
        All three are NaN when fewer than three finite pairs exist,
        when ``x`` has zero spread, or when the fit raises. ``r2`` alone
        is NaN when ``y`` has zero total variance.
    """
    undefined = (float("nan"), float("nan"), float("nan"))

    xs = np.asarray(x, float)
    ys = np.asarray(y, float)

    finite = np.isfinite(xs) & np.isfinite(ys)
    if int(finite.sum()) < 3:
        return undefined

    xs = xs[finite]
    ys = ys[finite]

    if float(np.nanstd(xs)) == 0.0:
        return undefined

    # Locate the rank-warning class: top-level attribute first, then the
    # polynomial utilities module, finally fall back to plain Warning.
    try:
        rank_warning = np.RankWarning
    except Exception:
        try:
            from numpy.polynomial.polyutils import (
                RankWarning as rank_warning,
            )
        except Exception:
            rank_warning = Warning

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", rank_warning)
        try:
            slope, intercept = np.polyfit(xs, ys, 1)
        except Exception:
            return undefined

    fitted = intercept + slope * xs
    ss_res = float(np.nansum((ys - fitted) ** 2))
    ss_tot = float(np.nansum((ys - np.nanmean(ys)) ** 2))

    if ss_tot <= 0:
        r2 = float("nan")
    else:
        r2 = 1.0 - ss_res / ss_tot

    return (float(slope), float(intercept), float(r2))


def _bootstrap_pair(
    x: np.ndarray,
    y: np.ndarray,
    fn: Callable[[np.ndarray, np.ndarray], float],
    *,
    n_boot: int,
    seed: int,
) -> np.ndarray:
    """Paired bootstrap of a two-sample statistic.

    Non-finite pairs are dropped, then ``n_boot`` resamples of the
    remaining pairs (same size, with replacement) are passed to ``fn``.

    Returns
    -------
    numpy.ndarray
        Length ``n_boot`` array of replicate values; all NaN when fewer
        than two finite pairs are available.
    """
    xs, ys = _mask_pair(x, y)
    n_obs = int(xs.size)
    n_rep = int(n_boot)

    if n_obs < 2:
        return np.full((n_rep,), np.nan)

    rng = np.random.default_rng(int(seed))
    replicates = np.empty((n_rep,), float)

    # One index draw per replicate keeps the random stream reproducible
    # for a given seed.
    for rep in range(n_rep):
        pick = rng.integers(0, n_obs, size=n_obs)
        replicates[rep] = float(fn(xs[pick], ys[pick]))

    return replicates


def _boot_ci(
    a: np.ndarray,
    *,
    ci: float,
) -> tuple[float, float]:
    """Percentile interval of the finite entries of ``a``.

    The bounds are the ``(1 - ci) / 2`` and ``1 - (1 - ci) / 2``
    quantiles. Returns ``(nan, nan)`` when no finite value remains.
    """
    values = np.asarray(a, float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return (float("nan"), float("nan"))

    tail = (1.0 - float(ci)) / 2.0
    lower, upper = (
        float(np.quantile(values, q)) for q in (tail, 1.0 - tail)
    )
    return (lower, upper)


def _fmt_ci(
    est: float,
    lo: float,
    hi: float,
    *,
    nd: int,
) -> str:
    """Format ``est`` with ``nd`` decimals, appending ``[lo,hi]`` if usable.

    The interval is shown only when the estimate and both bounds are
    finite; otherwise just the estimate is returned.
    """
    all_finite = all(np.isfinite(v) for v in (est, lo, hi))
    point = f"{est:.{nd}f}"
    if all_finite:
        return f"{point} [{lo:.{nd}f},{hi:.{nd}f}]"
    return point


# -------------------------------------------------------------------
# CLI
# -------------------------------------------------------------------
def _parse_args(
    argv: list[str] | None,
    *,
    prog: str | None = None,
) -> argparse.Namespace:
    """Build the CLI parser for the SM3 figure and parse ``argv``.

    Parameters
    ----------
    argv : list of str or None
        Argument vector handed to ``ArgumentParser.parse_args``.
    prog : str, optional
        Program name shown in usage; falls back to
        ``"plot-sm3-identifiability"`` when empty or ``None``.

    Returns
    -------
    argparse.Namespace
        Parsed options, including those registered by
        ``utils.add_plot_text_args``.

    Notes
    -----
    ``--ci`` must lie strictly between 0 and 1 and ``--n-boot`` must be
    non-negative; violations end the program through ``parser.error``
    instead of surfacing later as a NumPy exception during the
    bootstrap.
    """
    p = argparse.ArgumentParser(
        prog=prog or "plot-sm3-identifiability",
        description="SM3: Synthetic identifiability figure.",
    )

    p.add_argument(
        "--csv",
        type=str,
        default="results/sm3_synth_1d/sm3_synth_runs.csv",
        help="Input CSV (sm3_synth_runs.csv).",
    )

    p.add_argument(
        "--tau-units",
        type=str,
        choices=["year", "sec"],
        default="year",
        help="Units for tau panels (year/sec).",
    )

    p.add_argument(
        "--metric",
        type=str,
        choices=[
            "ridge_resid",
            "eps_prior_rms",
            "closure_consistency_rms",
        ],
        default="ridge_resid",
        help="Metric used in panel (d).",
    )

    p.add_argument(
        "--k-from-tau",
        type=str,
        default="auto",
        choices=["auto", "true", "false"],
        help="Panel (b): derive K from tau (auto/true/false).",
    )

    p.add_argument(
        "--k-cl-source",
        type=str,
        default="prior",
        choices=["prior", "true", "est"],
        help="If K is derived from tau, which Ss/Hd to use.",
    )

    p.add_argument(
        "--only-identify",
        type=str,
        default=None,
        help="Filter identify column (e.g., tau/both).",
    )

    p.add_argument(
        "--nx-min",
        type=int,
        default=None,
        help="Filter rows with nx >= nx-min.",
    )

    p.add_argument(
        "--n-boot",
        type=int,
        default=2000,
        help="Bootstrap replicates for CI.",
    )
    p.add_argument(
        "--ci",
        type=float,
        default=0.95,
        help="CI level (e.g., 0.95).",
    )
    p.add_argument(
        "--boot-seed",
        type=int,
        default=0,
        help="Bootstrap RNG seed base.",
    )

    # Boolean toggles are kept as strings here; the caller converts them.
    p.add_argument(
        "--show-ci",
        type=str,
        default="true",
        help="Show CI in stats annotations (true/false).",
    )
    p.add_argument(
        "--show-stats",
        type=str,
        default="true",
        help="Show stats annotation blocks (true/false).",
    )
    p.add_argument(
        "--show-prior",
        type=str,
        default="true",
        help="Overlay prior points in panel (a).",
    )
    p.add_argument(
        "--show-tau-cl",
        type=str,
        default="true",
        help="Overlay closure tau_cl in panel (a).",
    )

    p.add_argument(
        "--font",
        type=int,
        default=cfg.PAPER_FONT,
    )
    p.add_argument(
        "--dpi",
        type=int,
        default=cfg.PAPER_DPI,
    )

    p.add_argument(
        "--out-csv",
        type=str,
        default="sm3_identifiability_summary.csv",
        help="Summary table path (scripts/out if rel).",
    )
    p.add_argument(
        "--out-json",
        type=str,
        default="sm3_identifiability_summary.json",
        help="Summary json path (scripts/out if rel).",
    )

    utils.add_plot_text_args(
        p,
        default_out="sm3-identifiability",
    )

    args = p.parse_args(argv)

    # Reject values that would otherwise only fail inside the bootstrap
    # (quantile level out of range / negative array length).
    if not 0.0 < args.ci < 1.0:
        p.error(f"--ci must be in (0, 1), got {args.ci}")
    if args.n_boot < 0:
        p.error(f"--n-boot must be >= 0, got {args.n_boot}")

    return args


# -------------------------------------------------------------------
# Core
# -------------------------------------------------------------------
def plot_sm3_identifiability(
    df: pd.DataFrame,
    *,
    tau_units: str,
    metric: str,
    k_from_tau: bool,
    k_cl_source: str,
    show_prior: bool,
    show_tau_cl: bool,
    show_legend: bool,
    show_labels: bool,
    show_ticklabels: bool,
    show_panel_titles: bool,
    show_title: bool,
    title: str | None,
    show_stats: bool,
    show_ci: bool,
    n_boot: int,
    ci: float,
    boot_seed: int,
    out: str,
    out_csv: str,
    out_json: str,
    dpi: int,
) -> None:
    """Draw the four-panel SM3 identifiability figure and write summaries.

    Panels: (a) log10 tau_true vs log10 tau_est with optional closure
    and prior overlays; (b) log10 K_true vs log10 K_est, or vs K derived
    from tau_est through the closure; (c) delta_K vs
    delta_Ss + 2 delta_Hd; (d) absolute log10 tau error vs the chosen
    metric.

    A CSV and a JSON with every statistic (estimate plus CI bounds) are
    written to the paths resolved by ``utils.resolve_out_out`` before
    plotting; the figure is saved through ``utils.save_figure``.

    Parameters
    ----------
    df : pandas.DataFrame
        Run table. Must contain every column in ``base_cols`` below; an
        ``nx`` column is optional and only appended to panel titles.
    tau_units : str
        ``"year"`` selects the ``*_year`` tau columns; any other value
        selects the ``*_sec`` columns.
    metric : str
        ``"ridge_resid"`` uses the ridge residual computed here; any
        other value is read as a column of ``df``.
    k_from_tau : bool
        If true, panel (b) plots K implied by ``tau_est_med_sec``
        instead of ``K_est_med_m_per_year``.
    k_cl_source : str
        ``"true"`` or ``"est"`` pick the matching Ss/Hd columns; any
        other value uses the prior columns.
    show_prior, show_tau_cl : bool
        Toggle the prior / closure overlays in panel (a).
    show_legend, show_labels, show_ticklabels : bool
        Toggle shared legend, axis labels and tick labels.
    show_panel_titles, show_title : bool
        Toggle per-panel titles and the figure title.
    title : str or None
        Passed to ``utils.resolve_title`` together with a default.
    show_stats : bool
        Toggle the statistics text in each panel.
    show_ci : bool
        When false no bootstrap is run and CI bounds are NaN.
    n_boot, ci, boot_seed : int, float, int
        Bootstrap replicate count, interval level and base seed (each
        statistic adds its own fixed offset to the seed).
    out, out_csv, out_json : str
        Output names for the figure, summary CSV and summary JSON.
    dpi : int
        Resolution handed to ``utils.save_figure``.

    Raises
    ------
    KeyError
        If a required column is missing from ``df``.
    """
    utils.ensure_script_dirs()
    # NOTE(review): the CLI entry point in this file calls
    # set_paper_style(fontsize=..., dpi=...) before this argument-less
    # call -- confirm this does not reset those settings.
    utils.set_paper_style()

    base_cols = [
        "lith_idx",
        "kappa_b",
        "tau_true_sec",
        "tau_prior_sec",
        "tau_true_year",
        "tau_prior_year",
        "tau_est_med_sec",
        "tau_est_med_year",
        "K_true_mps",
        "K_prior_mps",
        "K_est_med_mps",
        "K_est_med_m_per_year",
        "Ss_true",
        "Ss_prior",
        "Ss_est_med",
        "Hd_true",
        "Hd_prior",
        "Hd_est_med",
        "vs_true_delta_K_q50",
        "vs_true_delta_Ss_q50",
        "vs_true_delta_Hd_q50",
        "eps_prior_rms",
        "closure_consistency_rms",
    ]
    _require_cols(df, base_cols, ctx="SM3")

    eps = 1e-12
    sec_per_year = float(cfg.SECONDS_PER_YEAR)
    ln10 = float(np.log(10.0))

    def _pos(name: str) -> np.ndarray:
        # Column as float array, floored at eps so log10 stays finite.
        return np.clip(df[name].to_numpy(float), eps, None)

    # -------------------------
    # helpers for stats + CI
    # -------------------------
    def _mae(a: np.ndarray, b: np.ndarray) -> float:
        return float(np.nanmean(np.abs(a - b)))

    def _r2(a: np.ndarray, b: np.ndarray) -> float:
        a, b = _mask_pair(a, b)
        if a.size < 2:
            return float("nan")
        if float(np.nanstd(a)) == 0.0:
            return float("nan")
        if float(np.nanstd(b)) == 0.0:
            return float("nan")
        with np.errstate(invalid="ignore", divide="ignore"):
            r = float(np.corrcoef(a, b)[0, 1])
        if not np.isfinite(r):
            return float("nan")
        return r * r

    def _slope(a: np.ndarray, b: np.ndarray) -> float:
        s, _, _ = _linfit_stats(a, b)
        return float(s)

    def _rho(a: np.ndarray, b: np.ndarray) -> float:
        return float(_spearman(a, b))

    def _stat_ci(
        name: str,
        x: np.ndarray,
        y: np.ndarray,
        fn: Callable[[np.ndarray, np.ndarray], float],
        *,
        nd: int,
        seed_off: int,
        est: float | None = None,
    ) -> tuple[dict, str]:
        # Point estimate comes from `est` when given, else from fn(x, y);
        # the bootstrap always uses fn.
        est_v = float(fn(x, y)) if est is None else float(est)

        lo = float("nan")
        hi = float("nan")
        if show_ci:
            bt = _bootstrap_pair(
                x,
                y,
                fn,
                n_boot=int(n_boot),
                seed=int(boot_seed + seed_off),
            )
            lo, hi = _boot_ci(bt, ci=float(ci))

        rec = {
            "name": name,
            "estimate": est_v,
            "ci_lo": lo,
            "ci_hi": hi,
            "nd": int(nd),
        }
        txt = _fmt_ci(est_v, lo, hi, nd=int(nd))
        return rec, txt

    # -------------------------
    # Panel (a): tau recovery
    # -------------------------
    if tau_units == "year":
        unit = "year"
        xlab_a = r"$\log_{10}\,\tau_{\mathrm{true}}$ (yr)"
        ylab_a = r"$\log_{10}\,\hat{\tau}$ (yr)"
    else:
        unit = "sec"
        xlab_a = r"$\log_{10}\,\tau_{\mathrm{true}}$ (s)"
        ylab_a = r"$\log_{10}\,\hat{\tau}$ (s)"

    x_tau = np.log10(_pos(f"tau_true_{unit}"))
    y_tau = np.log10(_pos(f"tau_est_med_{unit}"))
    y_pr = np.log10(_pos(f"tau_prior_{unit}"))

    kappa = _pos("kappa_b")
    K_est = _pos("K_est_med_mps")
    Ss_est = _pos("Ss_est_med")
    Hd_est = _pos("Hd_est_med")

    # Closure-implied timescale from the estimated fields.
    tau_cl = (Hd_est**2) * Ss_est / (np.pi**2 * kappa * K_est)
    tau_cl = np.clip(tau_cl, eps, None)

    if tau_units == "year":
        y_cl = np.log10(tau_cl / sec_per_year)
    else:
        y_cl = np.log10(tau_cl)

    mae_tau = float(np.nanmean(np.abs(y_tau - x_tau)))
    slope_tau, _, r2_tau = _linfit_stats(x_tau, y_tau)

    s_mae, mae_txt = _stat_ci(
        "tau_mae_log10",
        x_tau,
        y_tau,
        _mae,
        nd=2,
        seed_off=11,
        est=mae_tau,
    )
    s_r2, r2_txt = _stat_ci(
        "tau_r2",
        x_tau,
        y_tau,
        _r2,
        nd=3,
        seed_off=12,
        est=r2_tau,
    )
    s_sl, sl_txt = _stat_ci(
        "tau_slope",
        x_tau,
        y_tau,
        _slope,
        nd=2,
        seed_off=13,
        est=slope_tau,
    )

    # -------------------------
    # Panel (b): K recovery
    # -------------------------
    x_K = np.log10(_pos("K_true_mps") * sec_per_year)

    if k_from_tau:
        tau_est_sec = _pos("tau_est_med_sec")

        src = str(k_cl_source).strip().lower()
        # Anything other than "true"/"est" falls back to the prior.
        suffix = {"true": "true", "est": "est_med"}.get(src, "prior")
        Ss_use = _pos(f"Ss_{suffix}")
        Hd_use = _pos(f"Hd_{suffix}")

        K_cl_mps = (
            (Hd_use**2) * Ss_use / (np.pi**2 * kappa * tau_est_sec)
        )
        y_K = np.log10(np.clip(K_cl_mps * sec_per_year, eps, None))

        title_b = "Permeability from closure"
        ylab_b = r"$\log_{10}\,K_{cl}(\hat{\tau})$ (m/yr)"
    else:
        y_K = np.log10(_pos("K_est_med_m_per_year"))
        title_b = "Permeability recovery"
        ylab_b = r"$\log_{10}\,\hat{K}$ (m/yr)"

    mae_K = float(np.nanmean(np.abs(y_K - x_K)))
    slope_K, _, r2_K = _linfit_stats(x_K, y_K)

    s_maeK, maeK_txt = _stat_ci(
        "K_mae_log10",
        x_K,
        y_K,
        _mae,
        nd=2,
        seed_off=31,
        est=mae_K,
    )
    s_r2K, r2K_txt = _stat_ci(
        "K_r2",
        x_K,
        y_K,
        _r2,
        nd=3,
        seed_off=32,
        est=r2_K,
    )
    s_slK, slK_txt = _stat_ci(
        "K_slope",
        x_K,
        y_K,
        _slope,
        nd=2,
        seed_off=33,
        est=slope_K,
    )

    # -------------------------
    # Panel (c): ridge check
    # -------------------------
    # Deltas are divided by ln(10); presumably they are stored as
    # natural-log offsets (the axes label them as log10) -- TODO confirm.
    dK = df["vs_true_delta_K_q50"].to_numpy(float) / ln10
    dSs = df["vs_true_delta_Ss_q50"].to_numpy(float) / ln10
    dHd = df["vs_true_delta_Hd_q50"].to_numpy(float) / ln10

    ridge_x = dSs + 2.0 * dHd
    ridge_y = dK
    ridge_abs = np.abs(ridge_y - ridge_x)

    rho_ridge = _spearman(ridge_x, ridge_y)
    ridge_q50 = float(np.nanquantile(ridge_abs, 0.50))
    ridge_q95 = float(np.nanquantile(ridge_abs, 0.95))

    s_rhoR, rhoR_txt = _stat_ci(
        "ridge_spearman",
        ridge_x,
        ridge_y,
        _rho,
        nd=2,
        seed_off=41,
        est=rho_ridge,
    )

    # -------------------------
    # Panel (d): err vs metric
    # -------------------------
    err_tau = np.abs(y_tau - x_tau)

    if metric == "ridge_resid":
        y_m = np.log10(np.clip(ridge_abs, eps, None))
        ylab_m = (
            r"$\log_{10}\,|\delta_K-"
            r"(\delta_{S_s}+2\delta_{H_d})|$"
        )
    else:
        y_m = np.log10(_pos(metric))
        # Keep the column name outside math mode: inside "$...$" its
        # underscores would be typeset as subscripts.
        ylab_m = rf"$\log_{{10}}$ {metric}"

    rho_id = _spearman(err_tau, y_m)
    s_rhoI, rhoI_txt = _stat_ci(
        "err_vs_metric_spearman",
        err_tau,
        y_m,
        _rho,
        nd=2,
        seed_off=42,
        est=rho_id,
    )

    # -------------------------
    # Summary outputs (out/)
    # -------------------------
    stats_rows: list[dict] = [
        s_mae,
        s_r2,
        s_sl,
        s_maeK,
        s_r2K,
        s_slK,
        s_rhoR,
        s_rhoI,
    ]

    meta = {
        "n_rows": int(len(df)),
        "tau_units": str(tau_units),
        "metric": str(metric),
        "k_from_tau": bool(k_from_tau),
        "k_cl_source": str(k_cl_source),
        "ci": float(ci),
        "n_boot": int(n_boot),
        "boot_seed": int(boot_seed),
        "ridge_q50": float(ridge_q50),
        "ridge_q95": float(ridge_q95),
    }

    out_csv_p = utils.resolve_out_out(out_csv)
    out_json_p = utils.resolve_out_out(out_json)

    pd.DataFrame(stats_rows).to_csv(out_csv_p, index=False)
    # NOTE: NaN statistics (e.g. CI bounds when show_ci is false) are
    # emitted by json.dumps as bare NaN tokens, which strict JSON
    # parsers reject.
    out_json_p.write_text(
        json.dumps(
            {"meta": meta, "stats": stats_rows},
            indent=2,
            sort_keys=True,
        ),
        encoding="utf-8",
    )

    # -------------------------
    # Plot styling + plotting
    # -------------------------
    lith = df["lith_idx"].to_numpy(int)
    lith_ids = sorted(np.unique(lith))
    lith_names = {
        0: "Fine",
        1: "Mixed",
        2: "Coarse",
        3: "Rock",
    }
    markers = {0: "o", 1: "s", 2: "^", 3: "D"}
    stats_box = dict(
        boxstyle="round,pad=0.2",
        facecolor="white",
        alpha=0.85,
        linewidth=0.0,
    )

    def _beautify(ax: plt.Axes) -> None:
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.tick_params(direction="out", length=3, width=0.6)

    def _panel(ax: plt.Axes, lab: str) -> None:
        ax.text(
            -0.14,
            1.05,
            lab,
            transform=ax.transAxes,
            fontweight="bold",
            va="bottom",
        )

    def _scatter_by_lith(
        ax: plt.Axes,
        xs: np.ndarray,
        ys: np.ndarray,
        *,
        labelled: bool = False,
    ) -> None:
        # One scatter per lithology class, in sorted class order, so the
        # k-th class takes the k-th colour of the axes colour cycle.
        for i in lith_ids:
            sel = lith == i
            extra = (
                {"label": lith_names.get(i, f"L{i}")}
                if labelled
                else {}
            )
            ax.scatter(
                xs[sel],
                ys[sel],
                s=22,
                marker=markers.get(i, "o"),
                alpha=0.9,
                rasterized=True,
                **extra,
            )

    def _diag_limits(
        ax: plt.Axes,
        series: list[np.ndarray],
        pad_frac: float,
        *,
        zero_lines: bool = False,
    ) -> None:
        # Square limits spanning `series`, plus the 1:1 reference line.
        vals = np.concatenate([np.ravel(s) for s in series])
        lo = float(np.nanmin(vals))
        hi = float(np.nanmax(vals))
        pad = pad_frac * (hi - lo + 1e-9)
        lo -= pad
        hi += pad

        ax.plot([lo, hi], [lo, hi], "--", linewidth=0.9)
        if zero_lines:
            ax.axvline(0.0, linewidth=0.8)
            ax.axhline(0.0, linewidth=0.8)
        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)

    def _finish(
        ax: plt.Axes,
        *,
        panel_title: str,
        xlabel: str,
        ylabel: str,
        stats_text: str,
        boxed: bool,
    ) -> None:
        # Apply the shared title / label / tick / stats toggles.
        if show_panel_titles:
            ax.set_title(panel_title, pad=2)

        ax.set_xlabel(xlabel if show_labels else "")
        ax.set_ylabel(ylabel if show_labels else "")

        if not show_ticklabels:
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        if show_stats:
            extra = {"bbox": stats_box} if boxed else {}
            ax.text(
                0.03,
                0.97,
                stats_text,
                transform=ax.transAxes,
                va="top",
                ha="left",
                **extra,
            )

    nx_note = ""
    if "nx" in df.columns:
        nxu = sorted(set(df["nx"].astype(int).tolist()))
        nx_note = f" (nx={nxu})"

    fig = plt.figure(
        figsize=(7.2, 4.2),
        constrained_layout=True,
    )
    gs = fig.add_gridspec(2, 2)

    # (a)
    axA = fig.add_subplot(gs[0, 0])
    _beautify(axA)
    _panel(axA, "a")
    _scatter_by_lith(axA, x_tau, y_tau, labelled=True)

    # Limits follow exactly what is drawn: an overlay contributes to the
    # range only when it is shown, so prior markers are never clipped
    # and a hidden closure overlay does not stretch the axes.
    lim_a = [x_tau, y_tau]

    if show_tau_cl:
        axA.scatter(
            x_tau,
            y_cl,
            s=22,
            facecolors="none",
            edgecolors="k",
            linewidths=0.7,
            label=r"$\tau_{cl}$",
            rasterized=True,
        )
        lim_a.append(y_cl)

    if show_prior:
        axA.scatter(
            x_tau,
            y_pr,
            s=24,
            marker="^",
            facecolors="none",
            edgecolors="0.4",
            linewidths=0.7,
            label="prior",
            rasterized=True,
        )
        lim_a.append(y_pr)

    _diag_limits(axA, lim_a, 0.06)
    _finish(
        axA,
        panel_title="Timescale recovery" + nx_note,
        xlabel=xlab_a,
        ylabel=ylab_a,
        stats_text=(
            f"MAE = {mae_txt} (log10)\n"
            f"$R^2$ = {r2_txt}\n"
            f"slope = {sl_txt}"
        ),
        boxed=True,
    )

    # (b)
    axB = fig.add_subplot(gs[0, 1])
    _beautify(axB)
    _panel(axB, "b")
    _scatter_by_lith(axB, x_K, y_K)
    _diag_limits(axB, [x_K, y_K], 0.06)
    _finish(
        axB,
        panel_title=title_b + nx_note,
        xlabel=r"$\log_{10}\,K_{true}$ (m/yr)",
        ylabel=ylab_b,
        stats_text=(
            f"MAE = {maeK_txt} (log10)\n"
            f"$R^2$ = {r2K_txt}\n"
            f"slope = {slK_txt}"
        ),
        boxed=True,
    )

    # (c)
    axC = fig.add_subplot(gs[1, 0])
    _beautify(axC)
    _panel(axC, "c")
    _scatter_by_lith(axC, ridge_x, ridge_y)
    _diag_limits(axC, [ridge_x, ridge_y], 0.10, zero_lines=True)
    _finish(
        axC,
        panel_title="Degeneracy ridge check",
        xlabel=r"$\delta_{S_s}+2\,\delta_{H_d}$ (log$_{10}$)",
        ylabel=r"$\delta_K$ (log$_{10}$)",
        stats_text=(
            f"Spearman r = {rhoR_txt}\n"
            f"|ridge| q50 = {ridge_q50:.2f}\n"
            f"|ridge| q95 = {ridge_q95:.2f}"
        ),
        boxed=False,
    )

    # (d)
    axD = fig.add_subplot(gs[1, 1])
    _beautify(axD)
    _panel(axD, "d")
    _scatter_by_lith(axD, err_tau, y_m)
    _finish(
        axD,
        panel_title="Error vs identifiability metric",
        xlabel=r"$|\Delta\log_{10}\tau|$",
        ylabel=ylab_m,
        stats_text=f"Spearman r = {rhoI_txt}",
        boxed=False,
    )

    # shared legend
    if show_legend:
        handles: list[Line2D] = []
        for k, i in enumerate(lith_ids):
            handles.append(
                Line2D(
                    [0],
                    [0],
                    marker=markers.get(i, "o"),
                    linestyle="none",
                    markersize=5,
                    # "C<k>" is the k-th colour of the property cycle,
                    # i.e. the colour the k-th scatter call received.
                    color=f"C{k}",
                    label=lith_names.get(i, f"L{i}"),
                )
            )

        if show_tau_cl:
            handles.append(
                Line2D(
                    [0],
                    [0],
                    marker="o",
                    linestyle="none",
                    markersize=5,
                    markerfacecolor="none",
                    markeredgecolor="k",
                    label=r"$\tau_{cl}$",
                )
            )

        if show_prior:
            handles.append(
                Line2D(
                    [0],
                    [0],
                    marker="^",
                    linestyle="none",
                    markersize=5,
                    markerfacecolor="none",
                    markeredgecolor="0.4",
                    label="prior",
                )
            )

        fig.legend(
            handles=handles,
            loc="upper center",
            ncol=min(len(handles), 6),
            frameon=False,
        )

    if show_title:
        ttl = utils.resolve_title(
            default="SM3 • Synthetic identifiability",
            title=title,
        )
        fig.suptitle(ttl, fontsize=11, fontweight="bold")

    utils.save_figure(fig, out, dpi=int(dpi))
# print(f"[OK] figs -> {fig_p}.png/.svg") # print(f"[OK] table -> {out_csv_p}") # print(f"[OK] json -> {out_json_p}") # ------------------------------------------------------------------- # Main # -------------------------------------------------------------------
def plot_sm3_identifiability_main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Command-line driver for the SM3 identifiability figure.

    Parses ``argv``, applies the paper style with the requested font
    size and dpi, loads and filters the CSV, resolves the
    ``--k-from-tau`` mode and the text toggles, then calls
    ``plot_sm3_identifiability``.

    Raises
    ------
    KeyError
        If a filter is requested but its column is absent.
    ValueError
        If filtering leaves no rows.
    """
    args = _parse_args(argv, prog=prog)

    utils.set_paper_style(fontsize=int(args.font), dpi=int(args.dpi))

    table = pd.read_csv(Path(args.csv).expanduser())

    def _identify_modes(frame: pd.DataFrame) -> pd.Series:
        # Normalised (stripped, lower-cased) identify labels.
        return frame["identify"].astype(str).str.strip().str.lower()

    # strict filtering
    if args.only_identify is not None:
        if "identify" not in table.columns:
            raise KeyError(
                "--only-identify set but missing identify column."
            )
        target = str(args.only_identify).strip().lower()
        table = table.loc[_identify_modes(table) == target].copy()

    if args.nx_min is not None:
        if "nx" not in table.columns:
            raise KeyError("--nx-min set but missing nx.")
        nx_values = table["nx"].astype(int)
        table = table.loc[nx_values >= int(args.nx_min)].copy()

    if table.empty:
        raise ValueError("No rows left after filtering.")

    # decide k_from_tau: explicit value wins; in auto mode K is derived
    # from tau unless some remaining row has identify == "both".
    if args.k_from_tau != "auto":
        k_from_tau = utils.str_to_bool(args.k_from_tau)
    elif "identify" in table.columns:
        seen = set(_identify_modes(table).unique().tolist())
        k_from_tau = "both" not in seen
    else:
        k_from_tau = True

    # Each toggle arrives as a string attribute named like the keyword
    # it feeds.
    toggle_names = (
        "show_legend",
        "show_labels",
        "show_ticklabels",
        "show_title",
        "show_panel_titles",
        "show_ci",
        "show_stats",
        "show_prior",
        "show_tau_cl",
    )
    toggles = {
        name: utils.str_to_bool(getattr(args, name), default=True)
        for name in toggle_names
    }

    plot_sm3_identifiability(
        table,
        tau_units=str(args.tau_units),
        metric=str(args.metric),
        k_from_tau=bool(k_from_tau),
        k_cl_source=str(args.k_cl_source),
        title=args.title,
        n_boot=int(args.n_boot),
        ci=float(args.ci),
        boot_seed=int(args.boot_seed),
        out=str(args.out),
        out_csv=str(args.out_csv),
        out_json=str(args.out_json),
        dpi=int(args.dpi),
        **toggles,
    )
def main(
    argv: list[str] | None = None,
    *,
    prog: str | None = None,
) -> None:
    """Entry point; forwards to ``plot_sm3_identifiability_main``."""
    plot_sm3_identifiability_main(argv=argv, prog=prog)
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()