# Source code for geoprior.models.subsidence.plot

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Plotting helpers for subsidence training and diagnostics."""

from __future__ import annotations

import inspect
import os
import warnings
from collections.abc import Callable
from typing import (
    Any,
)

import matplotlib.pyplot as plt
import numpy as np

from .. import KERAS_DEPS

History = KERAS_DEPS.History


__all__ = [
    "plot_history_in",
    "gather_coords_flat",
    "plot_physics_values_in",
    "plot_epsilons_in",
    "plot_physics_losses_in",
    "autoplot_geoprior_history",
]


def _as_history_dict(history: Any) -> dict[str, list[float]]:
    """Return a shallow ``dict`` copy of the metrics carried by *history*.

    Accepted inputs are a Keras ``History`` (its ``.history`` mapping,
    or an empty dict when that attribute is falsy), a plain ``dict``,
    or any object exposing a ``dict`` under a ``history`` attribute.

    Raises
    ------
    TypeError
        If none of the accepted forms matches.
    """
    if isinstance(history, History):
        source = history.history or {}
    elif isinstance(history, dict):
        source = history
    else:
        # Duck-typed fallback: anything with a dict ``.history``.
        source = getattr(history, "history", None)
        if not isinstance(source, dict):
            raise TypeError(
                "history must be keras History or dict-like."
            )
    return dict(source)


def _get_valid_kwargs(fn, kwargs: dict) -> dict:
    # Filter kwargs to avoid matplotlib signature errors.
    try:
        sig = inspect.signature(fn)
        valid = set(sig.parameters.keys())
    except Exception:
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in valid}


def _has_pos_only(arrs: list[np.ndarray]) -> bool:
    # True if all finite and strictly positive.
    if not arrs:
        return False
    for a in arrs:
        a = a[np.isfinite(a)]
        if a.size == 0:
            continue
        if np.any(a <= 0.0):
            return False
    return True


def _pick_scale(
    scale: str,
    arrs: list[np.ndarray],
) -> tuple[str, dict]:
    # For requested "log", fall back to symlog if needed.
    if scale != "log":
        return scale, {}
    if _has_pos_only(arrs):
        return "log", {}
    # symlog keeps zeros/negatives visible.
    # linthresh from smallest positive, else 1e-12.
    pos = []
    for a in arrs:
        a = a[np.isfinite(a)]
        pos.append(a[a > 0.0])
    pos = np.concatenate(pos) if pos else np.asarray([])
    lt = float(np.min(pos)) if pos.size else 1e-12
    lt = max(lt, 1e-12)
    return "symlog", {"linthresh": lt}


def plot_history_in(
    history: History | dict,
    metrics: dict[str, list[str]] | None = None,
    layout: str = "subplots",
    title: str = "Model Training History",
    figsize: tuple[float, float] | None = None,
    style: str = "default",
    savefig: str | None = None,
    max_cols: int | str = "auto",
    show_grid: bool = True,
    grid_props: dict | None = None,
    yscale_settings: dict[str, str] | None = None,
    log_fn: Callable[..., None] | None = None,
    **plot_kwargs,
) -> None:
    """
    Plot Keras history (train + val) robustly.

    Parameters
    ----------
    history : History or dict
        Training history; normalised through ``_as_history_dict``.
    metrics : dict of str -> list of str, optional
        Subplot title -> history keys to draw.  When ``None``, every
        non-``val_`` key is grouped automatically: keys containing
        ``"loss"`` share a ``"Losses"`` group, every other key gets a
        group named after itself.
    layout : str
        ``"single"`` draws all keys on one axis titled
        ``"All Metrics"``; any other value draws one subplot per group.
    title : str
        Figure-level title.
    figsize : tuple of float, optional
        Figure size; defaults to ``(10, 6)`` for ``"single"`` and
        ``(6 * n_cols, 5 * n_rows)`` otherwise.
    style : str
        Matplotlib style name; ``"default"`` is used if it cannot be
        applied.
    savefig : str, optional
        Output path.  ``".png"`` is appended when the path has no
        extension.  When given, the figure is saved and closed instead
        of shown; a failure to save is reported as a warning.
    max_cols : int or "auto"
        Maximum number of subplot columns (``"auto"`` means 2).
    show_grid : bool
        Whether to draw a grid on every axis.
    grid_props : dict, optional
        Keyword arguments for ``ax.grid``.
    yscale_settings : dict of str -> str, optional
        Requested y-scale per subplot title (``"linear"`` when absent).
        A ``"log"`` request falls back to ``"symlog"`` when the plotted
        data are not strictly positive.  With ``layout="single"`` the
        ``"All Metrics"`` entry is used when present; otherwise, if the
        groups in *metrics* that have an entry all request the same
        scale, that scale is used.
    log_fn : callable, optional
        Receives the "saved" message; defaults to ``print``.
    **plot_kwargs
        Extra keyword arguments for ``ax.plot``, filtered through
        ``_get_valid_kwargs``.
    """
    log = log_fn if log_fn is not None else print

    hist = _as_history_dict(history)
    if not hist:
        warnings.warn(
            "Empty history: nothing to plot.", stacklevel=2
        )
        return

    # Style (never crash)
    try:
        plt.style.use(style)
    except Exception:
        plt.style.use("default")

    # Auto-group if not provided
    if metrics is None:
        metrics = {}
        for k in hist:
            if k.startswith("val_"):
                continue
            g = "Losses" if "loss" in k.lower() else k
            g = g.replace("_", " ").title()
            metrics.setdefault(g, []).append(k)

    if not metrics:
        warnings.warn("No metrics to plot.", stacklevel=2)
        return

    if yscale_settings is None:
        yscale_settings = {}

    grid_kws = grid_props or {"linestyle": ":", "alpha": 0.7}

    # Layout
    n_plots = len(metrics)
    if layout == "single":
        n_rows, n_cols = 1, 1
        if figsize is None:
            figsize = (10.0, 6.0)
    else:
        cols = 2 if max_cols == "auto" else int(max_cols)
        n_cols = max(1, min(cols, n_plots))
        n_rows = (n_plots + n_cols - 1) // n_cols
        if figsize is None:
            figsize = (
                float(n_cols) * 6.0,
                float(n_rows) * 5.0,
            )

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=figsize,
        squeeze=False,
    )
    axflat = axes.flatten()
    fig.suptitle(title, fontsize=14, weight="bold")

    def _series(name: str):
        # Float array for a history key, or None if missing/empty.
        vals = hist.get(name)
        if vals is None or not len(vals):
            return None
        return np.asarray(vals, dtype=float)

    def _plot_one_axis(
        ax, keys: list[str], subttl: str, req: str
    ) -> None:
        # Collect every series that will be drawn -- including the
        # auto-added val_* companions -- so the log/symlog decision
        # sees all plotted values, not only the train curves.
        arrs = []
        for k in keys:
            arr = _series(k)
            if arr is None:
                continue
            arrs.append(arr)
            if not k.startswith("val_"):
                v_arr = _series("val_" + k)
                if v_arr is not None:
                    arrs.append(v_arr)

        # Choose scale for this subplot.
        scale, skws = _pick_scale(req, arrs)
        try:
            ax.set_yscale(scale, **skws)
        except Exception:
            ax.set_yscale("linear")

        # The filtered kwargs do not depend on the key being drawn.
        pk = _get_valid_kwargs(ax.plot, plot_kwargs)

        # Names already drawn on this axis (train and val alike), so
        # a val_* series listed explicitly is not drawn a second time
        # by the auto-add step, or vice versa.
        drawn = set()
        for k in keys:
            if k in drawn:
                continue
            y = _series(k)
            if y is None:
                continue
            drawn.add(k)

            # If user passed val_* explicitly, treat it as val.
            is_val = k.startswith("val_")
            base = k[4:] if is_val else k
            pretty = base.replace("_", " ").title()

            x = np.arange(1, len(y) + 1)
            lab = ("Val " if is_val else "Train ") + pretty
            ax.plot(x, y, label=lab, **pk)

            # Auto-add val only for non-val keys.
            if is_val:
                continue
            vk = "val_" + k
            if vk in drawn:
                continue
            vy = _series(vk)
            if vy is None:
                continue
            drawn.add(vk)
            # The val series gets its own x axis: its length may
            # differ from the train series, and reusing ``x`` would
            # then make ax.plot raise on the shape mismatch.
            vx = np.arange(1, len(vy) + 1)
            ax.plot(
                vx,
                vy,
                linestyle="--",
                label="Val " + pretty,
                **pk,
            )

        ax.set_title(subttl)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Value")
        # Only build a legend when something was actually drawn.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        if show_grid:
            ax.grid(**grid_kws)

    if layout == "single":
        all_keys = [k for grp in metrics.values() for k in grp]
        # The single axis is titled "All Metrics"; honour a scale
        # registered under that name first, else fall back to the
        # group-level requests when they agree on one scale.
        req = yscale_settings.get("All Metrics")
        if req is None:
            requested = {
                yscale_settings[g]
                for g in metrics
                if g in yscale_settings
            }
            req = requested.pop() if len(requested) == 1 else "linear"
        _plot_one_axis(axflat[0], all_keys, "All Metrics", req)
        used = 1
    else:
        used = 0
        for ax, (subttl, keys) in zip(axflat, metrics.items()):
            _plot_one_axis(
                ax,
                keys,
                subttl,
                yscale_settings.get(subttl, "linear"),
            )
            used += 1

    # Hide grid cells that received no subplot.
    for ax in axflat[used:]:
        ax.set_visible(False)

    fig.tight_layout(rect=[0.0, 0.03, 1.0, 0.95])

    if savefig:
        root, ext = os.path.splitext(savefig)
        if not ext:
            savefig = root + ".png"
        try:
            os.makedirs(
                os.path.dirname(savefig) or ".", exist_ok=True
            )
            # Save this figure explicitly rather than whatever
            # pyplot currently considers the active figure.
            fig.savefig(savefig, dpi=300)
            log(f"[OK] Saved figure -> {savefig}")
        except Exception as e:
            warnings.warn(f"Save failed: {e}", stacklevel=2)
        finally:
            plt.close(fig)
    else:
        plt.show()
def gather_coords_flat(
    dataset,
    *,
    coord_key="coords",
    log_fn=None,
    max_batches=None,
):
    """
    Collect flat (t, x, y) arrays from an iterable of batches.

    Each batch may be the inputs themselves or a tuple/list whose
    first item is the inputs.  The inputs may in turn be a dict
    (coords read from ``coord_key``), a tuple/list (coords taken from
    its first item) or the coords array directly.  Objects exposing
    ``.numpy()`` are converted through it.  The last axis of the
    coords must have length 3 and is read as ``(t, x, y)``.

    Parameters
    ----------
    dataset : iterable
        Source of batches.
    coord_key : str
        Key looked up when the inputs are a dict.
    log_fn : callable, optional
        Receives a short summary; nothing is logged when ``None``.
    max_batches : int, optional
        Stop once this many batches have been consumed.

    Returns
    -------
    dict
        ``{"t": ..., "x": ..., "y": ...}`` with 1-D concatenated
        arrays.

    Raises
    ------
    KeyError
        If a dict input has no usable ``coord_key`` entry.
    ValueError
        If the last coords axis is not of length 3, or if no batch
        was collected.
    """
    emit = log_fn if log_fn is not None else (lambda *_: None)

    names = ("t", "x", "y")
    parts = {name: [] for name in names}

    for n_seen, batch in enumerate(dataset, start=1):
        # A batch is either the inputs or (inputs, targets, ...).
        inputs = batch[0] if isinstance(batch, (tuple, list)) else batch

        if isinstance(inputs, dict):
            coords = inputs.get(coord_key)
            if coords is None:
                raise KeyError(
                    f"Missing '{coord_key}' in inputs dict."
                )
        elif isinstance(inputs, (tuple, list)):
            coords = inputs[0]
        else:
            coords = inputs

        # Tensor-like objects expose .numpy(); everything else goes
        # straight through np.asarray.
        if hasattr(coords, "numpy"):
            coords = coords.numpy()
        coords = np.asarray(coords)

        if coords.shape[-1] != 3:
            raise ValueError(
                "coords[..., -1] must be 3 (t, x, y)."
            )

        # Works for (B, 3) as well as (B, T, 3): flatten per column.
        for column, name in enumerate(names):
            parts[name].append(
                np.asarray(coords[..., column]).ravel()
            )

        if max_batches is not None and n_seen >= max_batches:
            break

    if not parts["x"]:
        raise ValueError("Dataset yielded no coords.")

    out = {
        name: np.concatenate(parts[name], axis=0) for name in names
    }
    emit(
        "gather_coords_flat:",
        f"n={out['x'].shape[0]}",
    )
    return out
def plot_physics_values_in(
    payload,
    *,
    keys=None,
    dataset=None,
    coords=None,
    mode="map",
    title="Physics diagnostics",
    n_cols=2,
    figsize=None,
    savefig=None,
    show=True,
    clip_q=(0.01, 0.99),
    transform=None,
    bins=80,
    s=8,
    log_fn=None,
    **scatter_kwargs,
):
    """
    Plot physics arrays (residuals/fields) from a payload dict.

    Parameters
    ----------
    payload : dict
        Mapping of names to array-likes.
    keys : iterable of str, optional
        Entries to plot.  When ``None`` a built-in preference list is
        tried first, then every float/integer ``np.ndarray`` entry.
    dataset : iterable, optional
        Source of map coordinates, read once through
        ``gather_coords_flat``.  Ignored when *coords* is given.
    coords : dict, optional
        Map coordinates under the ``"x"`` and ``"y"`` keys.
    mode : {"map", "hist", "both"}
        Scatter map, histogram, or both per key.  Any other value is
        treated as ``"map"``.
    title : str
        Figure-level title.
    n_cols : int
        Maximum number of subplot columns (at least 1).
    figsize : tuple, optional
        Defaults to ``(6 * n_cols, 5 * n_rows)``.
    savefig : str, optional
        Output path (``".png"`` appended when it has no extension).
        When given, the figure is saved and closed and *show* is not
        consulted; a failure to save is reported as a warning.
    show : bool
        Without *savefig*: show the figure, else close it.
    clip_q : tuple of float or None
        ``(low, high)`` quantiles used as colour limits on maps.
        ``None`` or an invalid pair leaves the limits automatic.
    transform : callable or str, optional
        Applied to the flattened values: a callable, or one of
        ``"abs"``, ``"log10"``, ``"signed_log10"`` / ``"slog10"``.
        Unknown strings leave the values unchanged.
    bins : int
        Histogram bin count.
    s : float
        Scatter marker size.
    log_fn : callable, optional
        Receives the "saved" message (defaults to ``print``); also
        forwarded as-is to ``gather_coords_flat``.
    **scatter_kwargs
        Extra keyword arguments for ``ax.scatter``.

    Raises
    ------
    TypeError
        If *payload* is not a dict.
    """
    log = log_fn if log_fn is not None else print

    d = payload
    if not isinstance(d, dict):
        raise TypeError("payload must be a dict.")

    def _finite(a):
        a = np.asarray(a, dtype=float)
        return a[np.isfinite(a)]

    def _safe_vlim(v):
        # Quantile-based colour limits; (None, None) means "auto".
        if clip_q is None:
            return None, None
        v = _finite(v)
        if v.size == 0:
            return None, None
        lo, hi = clip_q
        lo = float(lo)
        hi = float(hi)
        if not (0.0 <= lo < hi <= 1.0):
            return None, None
        try:
            vmin = float(np.quantile(v, lo))
            vmax = float(np.quantile(v, hi))
        except Exception:
            return None, None
        if not np.isfinite(vmin) or not np.isfinite(vmax):
            return None, None
        if vmin == vmax:
            return None, None
        return vmin, vmax

    def _apply_transform(v):
        if transform is None:
            return v
        if callable(transform):
            return transform(v)
        t = str(transform).lower().strip()
        if t == "abs":
            return np.abs(v)
        if t == "log10":
            # Non-positive values become NaN and are masked later.
            v = np.asarray(v, dtype=float)
            v = np.where(v > 0.0, v, np.nan)
            return np.log10(v)
        if t in ("signed_log10", "slog10"):
            v = np.asarray(v, dtype=float)
            return np.sign(v) * np.log10(1.0 + np.abs(v))
        return v

    def _pick_keys(src):
        if keys is not None:
            return list(keys)
        # reasonable defaults for GeoPrior payloads
        pref = [
            "cons_res_vals",
            "R_cons",
            "epsilon_cons",
            "epsilon_gw",
            "epsilon_prior",
            "log10_tau",
            "log10_tau_prior",
            "K",
            "Ss",
            "Hd",
            "H",
        ]
        found = [k for k in pref if k in src]
        if found:
            return found
        # fallback: any numeric arrays
        return [
            k
            for k, v in src.items()
            if isinstance(v, np.ndarray)
            and v.dtype.kind in ("f", "i", "u")
        ]

    # Coordinates do not depend on the key being plotted, so they are
    # resolved lazily once and reused.  Re-reading the dataset for
    # every key is wasteful and fails on one-shot iterables.
    xy_cache = []

    def _get_xy():
        if not xy_cache:
            if coords is not None:
                # user provided coords dict
                xy = (
                    np.asarray(coords["x"]).ravel(),
                    np.asarray(coords["y"]).ravel(),
                )
            elif dataset is not None:
                # derive coords from dataset
                c = gather_coords_flat(
                    dataset,
                    log_fn=log_fn,
                )
                xy = (
                    np.asarray(c["x"]).ravel(),
                    np.asarray(c["y"]).ravel(),
                )
            else:
                warnings.warn(
                    "No coords or dataset given: map panels "
                    "are left empty.",
                    stacklevel=3,
                )
                xy = (None, None)
            xy_cache.append(xy)
        return xy_cache[0]

    def _align_xy(x, y, v):
        # try to align lengths without guessing too much
        x = np.asarray(x).ravel()
        y = np.asarray(y).ravel()
        v = np.asarray(v).ravel()

        nx = x.shape[0]
        nv = v.shape[0]
        if nx == nv:
            return x, y, v

        # The stride heuristics need both sides non-empty; otherwise
        # the modulo divides by zero or the slice step becomes 0.
        if nx and nv:
            # common case: payload is (B) but coords are (B*T)
            if nx % nv == 0:
                step = nx // nv
                x2 = x[::step]
                y2 = y[::step]
                if x2.shape[0] == nv:
                    return x2, y2, v

            # opposite: coords (B) but values (B*T)
            if nv % nx == 0:
                step = nv // nx
                v2 = v[::step]
                if v2.shape[0] == nx:
                    return x, y, v2

        # last resort: truncate to min length
        n = min(nx, nv)
        return x[:n], y[:n], v[:n]

    klist = _pick_keys(d)
    if not klist:
        warnings.warn(
            "No plot-able keys found in payload.",
            stacklevel=2,
        )
        return

    mode = str(mode).lower().strip()
    if mode not in ("map", "hist", "both"):
        mode = "map"

    n_plots = len(klist) * (2 if mode == "both" else 1)

    if n_cols < 1:
        n_cols = 1
    n_cols = min(int(n_cols), n_plots)
    n_rows = (n_plots + n_cols - 1) // n_cols

    if figsize is None:
        figsize = (n_cols * 6, n_rows * 5)

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=figsize,
        squeeze=False,
    )
    axes = axes.ravel()
    fig.suptitle(title, fontsize=16, weight="bold")

    plot_i = 0
    for k in klist:
        v = d.get(k, None)
        if v is None:
            continue

        v = np.asarray(v).ravel()
        v = _apply_transform(v)

        if mode in ("map", "both"):
            ax = axes[plot_i]
            plot_i += 1

            x, y = _get_xy()
            if x is None or y is None:
                ax.set_visible(False)
            else:
                x, y, vv = _align_xy(x, y, v)
                m = np.isfinite(vv)
                x = x[m]
                y = y[m]
                vv = vv[m]

                vmin, vmax = _safe_vlim(vv)

                sc = ax.scatter(
                    x,
                    y,
                    c=vv,
                    s=s,
                    vmin=vmin,
                    vmax=vmax,
                    **scatter_kwargs,
                )
                ax.set_title(k)
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_aspect("equal", adjustable="box")

                # Best-effort colorbar: report instead of failing
                # the whole figure.
                try:
                    fig.colorbar(sc, ax=ax, shrink=0.85)
                except Exception as e:
                    warnings.warn(
                        f"Colorbar skipped for '{k}': {e}",
                        stacklevel=2,
                    )

        if mode in ("hist", "both"):
            ax = axes[plot_i]
            plot_i += 1

            vv = _finite(v)
            if vv.size == 0:
                ax.set_visible(False)
            else:
                ax.hist(vv, bins=int(bins))
                ax.set_title(f"{k} (hist)")
                ax.set_xlabel(k)
                ax.set_ylabel("count")

    # hide unused axes
    for ax in axes[plot_i:]:
        ax.set_visible(False)

    fig.tight_layout(rect=[0, 0.03, 1, 0.96])

    if savefig:
        try:
            root, ext = os.path.splitext(savefig)
            if not ext:
                savefig = root + ".png"

            out_dir = os.path.dirname(savefig)
            if out_dir:
                # exist_ok avoids the exists()/makedirs() race.
                os.makedirs(out_dir, exist_ok=True)

            fig.savefig(
                savefig,
                dpi=300,
                bbox_inches="tight",
            )
            log(f"Saved: {savefig}")
        except Exception as e:
            warnings.warn(f"Save failed: {e}", stacklevel=2)
        finally:
            plt.close(fig)
    else:
        if show:
            plt.show()
        else:
            plt.close(fig)
def _keys_starting( hist: dict[str, list[float]], p: str ) -> list[str]: ks = [k for k in hist.keys() if not k.startswith("val_")] ks = [k for k in ks if k.startswith(p)] return sorted(ks) def _existing( hist: dict[str, list[float]], keys: list[str] ) -> list[str]: out = [] for k in keys: if k in hist and len(hist[k]): out.append(k) return out
def plot_epsilons_in(
    history: History | dict,
    *,
    title: str = "Epsilons",
    savefig: str | None = None,
    style: str = "default",
    log_fn: Callable[..., None] | None = None,
) -> None:
    """
    Plot every ``epsilon_*`` history series on one log-scaled axis.

    Only non-``val_`` keys starting with ``"epsilon_"`` are selected
    (this includes ``*_raw`` variants); ``plot_history_in`` adds the
    matching ``val_*`` curves and falls back from log to symlog when
    the data are not strictly positive.

    Parameters
    ----------
    history : History or dict
        Training history.
    title : str
        Figure title.
    savefig : str, optional
        Output path forwarded to ``plot_history_in``.
    style : str
        Matplotlib style name.
    log_fn : callable, optional
        Logger; ``print`` is used for the "no keys" notice when
        ``None``.
    """
    hist = _as_history_dict(history)
    eps_keys = _keys_starting(hist, "epsilon_")
    if not eps_keys:
        (log_fn or print)("[plot] No epsilon_* keys.")
        return

    groups = {"Epsilons": eps_keys}
    # With layout="single", plot_history_in titles its one axis
    # "All Metrics" and looks the y-scale up under that title, so the
    # log request has to be registered under that key as well;
    # keyed by group name alone it was ignored and the axis stayed
    # linear.
    ysc = {"Epsilons": "log", "All Metrics": "log"}

    plot_history_in(
        hist,
        metrics=groups,
        layout="single",
        title=title,
        style=style,
        savefig=savefig,
        yscale_settings=ysc,
        log_fn=log_fn,
    )
def plot_physics_losses_in(
    history: History | dict,
    *,
    title: str = "Physics Loss Terms",
    savefig: str | None = None,
    style: str = "default",
    log_fn: Callable[..., None] | None = None,
) -> None:
    """
    Plot the known physics loss terms on one log-scaled axis.

    Only the names listed below that are present in the history with
    a non-empty series are drawn; ``plot_history_in`` falls back from
    log to symlog when the data are not strictly positive.

    Parameters
    ----------
    history : History or dict
        Training history.
    title : str
        Figure title.
    savefig : str, optional
        Output path forwarded to ``plot_history_in``.
    style : str
        Matplotlib style name.
    log_fn : callable, optional
        Logger; ``print`` is used for the "no keys" notice when
        ``None``.
    """
    hist = _as_history_dict(history)

    keys = [
        "physics_loss",
        "physics_loss_scaled",
        "consolidation_loss",
        "gw_flow_loss",
        "prior_loss",
        "smooth_loss",
        "mv_prior_loss",
        "bounds_loss",
        # optional diagnostics if enabled:
        "q_reg_loss",
        "q_rms",
        "q_gate",
        "subs_resid_gate",
    ]
    keys = _existing(hist, keys)
    if not keys:
        (log_fn or print)("[plot] No physics loss keys.")
        return

    groups = {"Physics": keys}
    # Request log; plot_history_in will fall back to symlog.
    # With layout="single" the y-scale is looked up under the axis
    # title "All Metrics", so the request must be registered under
    # that key too; keyed by group name alone it was ignored.
    ysc = {"Physics": "log", "All Metrics": "log"}

    plot_history_in(
        hist,
        metrics=groups,
        layout="single",
        title=title,
        style=style,
        savefig=savefig,
        yscale_settings=ysc,
        log_fn=log_fn,
    )
def autoplot_geoprior_history(
    history: History | dict,
    *,
    outdir: str,
    prefix: str = "geoprior",
    style: str = "default",
    log_fn: Callable[..., None] | None = None,
) -> None:
    """
    Write the two standard diagnostic figures for *history* into *outdir*.

    The directory is created if needed, then ``plot_epsilons_in`` and
    ``plot_physics_losses_in`` are called in that order, targeting
    ``<prefix>_epsilons.png`` and ``<prefix>_physics_terms.png``.

    Parameters
    ----------
    history : History or dict
        Training history.
    outdir : str
        Destination directory.
    prefix : str
        Used in both file names and figure titles.
    style : str
        Matplotlib style name passed to both plotters.
    log_fn : callable, optional
        Logger passed to both plotters.
    """
    os.makedirs(outdir, exist_ok=True)

    # Options shared by both figures.
    shared = {"style": style, "log_fn": log_fn}

    eps_path = os.path.join(outdir, f"{prefix}_epsilons.png")
    plot_epsilons_in(
        history,
        title=f"{prefix} | epsilons",
        savefig=eps_path,
        **shared,
    )

    terms_path = os.path.join(outdir, f"{prefix}_physics_terms.png")
    plot_physics_losses_in(
        history,
        title=f"{prefix} | physics terms",
        savefig=terms_path,
        **shared,
    )