Source code for geoprior.models.subsidence.log_offsets_diagnostics

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Diagnostics for subsidence log-offset policies and payloads."""

from __future__ import annotations

import os
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------


def _ensure_dir(path: str) -> None:
    """Create directory if it does not exist."""
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def _pick_first(
    d: dict[str, Any], *candidates: str
) -> str | None:
    """Return first key present in dict from a list of candidates."""
    for k in candidates:
        if k in d:
            return k
    return None


def _load_npz(path: str) -> dict[str, np.ndarray]:
    """Load a .npz file into a dict of arrays."""
    with np.load(path, allow_pickle=True) as data:
        return {k: data[k] for k in data.files}


def _compute_offsets_table(
    payload: dict[str, np.ndarray],
) -> pd.DataFrame:
    """
    Build a tidy table with log-offsets from the physics payload.

    This is intentionally defensive: it supports several possible
    naming conventions and only computes offsets for pairs that exist.

    Expected (but not strictly required) keys:
        - K_prior, K_eff
        - Ss_prior, Ss_eff
        - Hd_prior, Hd_eff
        - tau_prior, tau_eff

    Returns
    -------
    df : pandas.DataFrame
        Columns may include:
        ['delta_logK', 'delta_logSs', 'delta_logHd',
         'delta_log_tau', 'log_tau_prior', 'log_tau_eff',
         'depth', 'depth_index'].
    """
    keys = set(payload.keys())

    # Try to find depth / vertical coordinate (optional but nice)
    depth_key = _pick_first(
        payload,
        "depth",
        "z",
        "coord_z",
        "z_mid",
        "cell_center_z",
        "z_index",
    )

    # Utility to grab an array, or None if missing
    def arr(
        key_candidates: list[str],
    ) -> np.ndarray | None:
        k = _pick_first(payload, *key_candidates)
        return None if k is None else np.asarray(payload[k])

    # Candidate names for each physical field
    K_prior = arr(["K_prior", "k_prior", "K_lith_prior"])
    K_eff = arr(
        ["K_eff", "K_effective", "K_geo", "K_posterior"]
    )
    Ss_prior = arr(["Ss_prior", "Ss_lith_prior", "Ss0"])
    Ss_eff = arr(
        ["Ss_eff", "Ss_effective", "Ss_geo", "Ss_posterior"]
    )
    Hd_prior = arr(["Hd_prior", "H_d_prior", "Hd_lith_prior"])
    Hd_eff = arr(
        ["Hd_eff", "H_d_effective", "Hd_geo", "Hd_posterior"]
    )

    tau_prior = arr(["tau_prior", "timescale_prior", "tau0"])
    tau_eff = arr(
        ["tau_eff", "tau_effective", "tau_geoprior"]
    )

    depth = None
    if depth_key is not None:
        depth = np.asarray(payload[depth_key])

    # Determine base size and flatten
    base_size = None

    def _check_and_flatten(
        x: np.ndarray, name: str
    ) -> np.ndarray:
        nonlocal base_size
        x_flat = x.ravel()
        if base_size is None:
            base_size = x_flat.size
        elif x_flat.size != base_size:
            raise ValueError(
                f"Inconsistent size for {name}: {x_flat.size} vs {base_size}"
            )
        return x_flat

    cols: dict[str, np.ndarray] = {}

    if (K_prior is not None) and (K_eff is not None):
        delta_logK = np.log10(K_eff) - np.log10(K_prior)
        cols["delta_logK"] = _check_and_flatten(
            delta_logK, "delta_logK"
        )

    if (Ss_prior is not None) and (Ss_eff is not None):
        delta_logSs = np.log10(Ss_eff) - np.log10(Ss_prior)
        cols["delta_logSs"] = _check_and_flatten(
            delta_logSs, "delta_logSs"
        )

    if (Hd_prior is not None) and (Hd_eff is not None):
        # Hd is already in linear space; we use log10 for consistency
        delta_logHd = np.log10(Hd_eff) - np.log10(Hd_prior)
        cols["delta_logHd"] = _check_and_flatten(
            delta_logHd, "delta_logHd"
        )

    if (tau_prior is not None) and (tau_eff is not None):
        log_tau_prior = np.log10(tau_prior)
        log_tau_eff = np.log10(tau_eff)
        delta_log_tau = log_tau_eff - log_tau_prior

        cols["log_tau_prior"] = _check_and_flatten(
            log_tau_prior, "log_tau_prior"
        )
        cols["log_tau_eff"] = _check_and_flatten(
            log_tau_eff, "log_tau_eff"
        )
        cols["delta_log_tau"] = _check_and_flatten(
            delta_log_tau, "delta_log_tau"
        )

    if not cols:
        raise RuntimeError(
            "No matching K/Ss/Hd/tau pairs found in physics payload; "
            f"available keys: {sorted(keys)}"
        )

    # Optional depth
    if depth is not None:
        depth_flat = depth.ravel()
        if depth_flat.size == base_size:
            cols["depth"] = depth_flat
        else:
            # Fall back to an index if depth has different topology
            cols["depth"] = np.linspace(0.0, 1.0, base_size)

    # Always include an index
    cols["depth_index"] = np.arange(base_size, dtype=int)

    df = pd.DataFrame(cols)
    return df


def _summarize_offsets(df: pd.DataFrame) -> pd.DataFrame:
    """
    Summarise offsets: mean, std, p05, p50, p95 for each δ-column.
    """
    delta_cols = [
        c
        for c in df.columns
        if c.startswith("delta_log") or c == "delta_log_tau"
    ]
    if not delta_cols:
        raise RuntimeError("No δ-columns found in DataFrame.")

    desc = (
        df[delta_cols]
        .describe(percentiles=[0.05, 0.5, 0.95])
        .T
    )
    desc = desc.rename(
        columns={
            "mean": "mean",
            "std": "std",
            "5%": "p05",
            "50%": "p50",
            "95%": "p95",
        }
    )
    # Keep only the relevant columns
    cols = ["mean", "std", "p05", "p50", "p95"]
    desc = desc[cols]
    desc.index.name = "metric"
    return desc


def _make_plots(
    df: pd.DataFrame,
    outdir: str,
    prefix: str = "sm3_offsets",
) -> list[str]:
    """
    Produce simple histograms and τ–δτ scatter plot.
    Returns a list of saved file paths.
    """
    saved: list[str] = []
    _ensure_dir(outdir)

    # Histograms
    for col in [
        "delta_logK",
        "delta_logSs",
        "delta_logHd",
        "delta_log_tau",
    ]:
        if col not in df.columns:
            continue

        plt.figure()
        df[col].dropna().hist(bins=50)
        plt.xlabel(col)
        plt.ylabel("Count")
        plt.title(f"Distribution of {col}")
        fname = os.path.join(
            outdir, f"{prefix}_{col}_hist.png"
        )
        plt.tight_layout()
        plt.savefig(fname, dpi=200)
        plt.close()
        saved.append(fname)

    # τ–δτ scatter (prior vs offset)
    if {"log_tau_prior", "delta_log_tau"}.issubset(
        df.columns
    ):
        plt.figure()
        plt.scatter(
            df["log_tau_prior"],
            df["delta_log_tau"],
            s=4,
            alpha=0.3,
        )
        plt.axhline(0.0, linestyle="--")
        plt.xlabel(r"$\log_{10}\tau_{\mathrm{prior}}$")
        plt.ylabel(
            r"$\delta_{\tau} = \log_{10}\tau_{\mathrm{eff}} - "
            r"\log_{10}\tau_{\mathrm{prior}}$"
        )
        plt.title("Timescale prior vs log-offset (SM3)")
        fname = os.path.join(
            outdir, f"{prefix}_tau_scatter.png"
        )
        plt.tight_layout()
        plt.savefig(fname, dpi=200)
        plt.close()
        saved.append(fname)

    return saved


# ---------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------


[docs] def run_sm3_offsets_from_payload( physics_npz_path: str, outdir: str | None = None, city: str | None = None, model_name: str = "GeoPriorSubsNet", ) -> dict[str, Any]: """ High-level driver: compute SM3 diagnostics from a physics payload. Parameters ---------- physics_npz_path : str Path to ``*_phys_payload_run_val.npz`` as written by :meth:`GeoPriorSubsNet.export_physics_payload`. outdir : str, optional Directory where CSVs and plots are written. If ``None``, defaults to the directory of ``physics_npz_path``. city : str, optional City name for filenames. model_name : str, default "GeoPriorSubsNet" Model name for filenames. Returns ------- result : dict Dictionary with keys: - 'raw_csv' - 'summary_csv' - 'plots' (list of paths) """ physics_npz_path = os.path.abspath(physics_npz_path) if outdir is None: outdir = os.path.dirname(physics_npz_path) _ensure_dir(outdir) tag = [] if city: tag.append(city.lower()) if model_name: tag.append(model_name.lower()) tag_str = "_".join(tag) if tag else "geoprior" print( f"[SM3] Loading physics payload: {physics_npz_path}" ) payload = _load_npz(physics_npz_path) print( f"[SM3] Keys in physics payload: {sorted(payload.keys())}" ) df = _compute_offsets_table(payload) print(f"[SM3] Built offsets table with shape {df.shape}.") raw_csv = os.path.join( outdir, f"{tag_str}_sm3_offsets_raw.csv" ) df.to_csv(raw_csv, index=False) print(f"[SM3] Saved raw offsets CSV -> {raw_csv}") summary_df = _summarize_offsets(df) summary_csv = os.path.join( outdir, f"{tag_str}_sm3_offsets_summary.csv" ) summary_df.to_csv(summary_csv) print(f"[SM3] Saved summary offsets CSV -> {summary_csv}") plots = _make_plots( df, outdir=outdir, prefix=f"{tag_str}_sm3" ) for p in plots: print(f"[SM3] Saved plot -> {p}") return { "raw_csv": raw_csv, "summary_csv": summary_csv, "plots": plots, }