# Source code for geoprior.utils.inspect.model_init_manifest

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Model-initialization manifest generation and inspection helpers.

This module focuses on the ``model_init_manifest``
artifact saved around model construction time. It
captures the initialized model class, input/output
dimensions, architecture choices, GeoPrior physics
initialization, and the resolved scaling payload used
at initialization.

The functions here support two common needs:

1. Sphinx-Gallery examples that need a realistic
   initialization manifest without rebuilding a model.
2. Real workflow inspection when a user wants to
   verify architecture, physics knobs, feature counts,
   and scaling conventions before training.
"""

from __future__ import annotations

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .utils import (
    ArtifactRecord,
    as_path,
    clone_artifact,
    flatten_dict,
    load_artifact,
    nested_get,
    plot_boolean_checks,
    plot_metric_bars,
    read_json,
    write_json,
)

# Filesystem locations accepted by the generators / loaders.
PathLike = str | Path

# Every input form ``_as_payload`` can turn into a plain payload
# dict: a loaded ``ArtifactRecord``, an in-memory mapping, or a path
# to a JSON file (the two path forms reuse ``PathLike``).
ModelInitManifestLike = ArtifactRecord | Mapping[str, Any] | PathLike

# Public API of this module (kept in alphabetical order).
__all__ = [
    "default_model_init_manifest_payload",
    "generate_model_init_manifest",
    "inspect_model_init_manifest",
    "load_model_init_manifest",
    "model_init_architecture_frame",
    "model_init_dims_frame",
    "model_init_feature_groups_frame",
    "model_init_geoprior_frame",
    "model_init_scaling_overview_frame",
    "plot_model_init_architecture",
    "plot_model_init_boolean_summary",
    "plot_model_init_dims",
    "plot_model_init_feature_group_sizes",
    "plot_model_init_geoprior",
    "summarize_model_init_manifest",
]

_ARCH_KEYS = [
    "embed_dim",
    "hidden_units",
    "lstm_units",
    "attention_units",
    "num_heads",
    "dropout_rate",
    "memory_size",
    "vsn_units",
]

_GEO_KEYS = [
    "gamma_w",
    "h_ref_value",
    "hd_factor",
    "init_kappa",
    "init_mv",
]

_SCALING_OVERVIEW_KEYS = [
    "time_units",
    "seconds_per_time_unit",
    "coords_normalized",
    "coords_in_degrees",
    "coord_epsg_used",
    "coord_src_epsg",
    "coord_target_epsg",
    "clip_global_norm",
    "cons_residual_units",
    "gw_residual_units",
    "cons_scale_floor",
    "gw_scale_floor",
    "Q_kind",
    "Q_in_si",
    "Q_in_per_second",
    "Q_length_in_si",
    "Q_wrt_normalized_time",
    "lambda_q",
    "mv_weight",
    "mv_schedule_unit",
    "mv_delay_epochs",
    "mv_warmup_epochs",
    "physics_ramp_steps",
    "physics_warmup_steps",
]


def _as_payload(
    manifest: ModelInitManifestLike,
) -> dict[str, Any]:
    """
    Normalize any accepted manifest form into a plain ``dict``.

    An ``ArtifactRecord`` contributes its ``payload``; a mapping is
    used as-is; anything else is treated as a path and read with
    ``read_json``. The result is always a new shallow ``dict`` copy.
    """
    if isinstance(manifest, ArtifactRecord):
        source = manifest.payload
    elif isinstance(manifest, Mapping):
        source = manifest
    else:
        source = read_json(manifest)
    return dict(source)


def _selected_numeric(
    mapping: dict[str, Any] | None,
    keys: list[str] | tuple[str, ...],
) -> dict[str, float]:
    """Extract selected numeric scalar items."""
    out: dict[str, float] = {}
    for key in keys:
        value = (mapping or {}).get(key)
        if isinstance(value, bool):
            continue
        if isinstance(value, (int, float, np.number)):
            out[str(key)] = float(value)
    return out


def _selected_rows(
    mapping: dict[str, Any] | None,
    keys: list[str] | tuple[str, ...],
    *,
    section: str,
) -> list[dict[str, Any]]:
    """Build a row list for selected keys."""
    rows: list[dict[str, Any]] = []
    for key in keys:
        value = (mapping or {}).get(key)
        rows.append(
            {
                "section": section,
                "key": str(key),
                "value": value,
            }
        )
    return rows


def _feature_names(
    payload: dict[str, Any],
    kind: str,
) -> list[str]:
    """
    Read ``config.scaling_kwargs.<kind>_feature_names`` as strings.

    Returns an empty list when the entry is missing or is not a list.
    """
    raw = nested_get(
        payload,
        "config",
        "scaling_kwargs",
        f"{kind}_feature_names",
        default=[],
    )
    return [str(item) for item in raw] if isinstance(raw, list) else []


def _bounds_count(payload: dict[str, Any]) -> int:
    """
    Count entries in ``config.scaling_kwargs.bounds``.

    Returns 0 when the entry is missing or is not a ``dict``.
    """
    bounds = nested_get(
        payload, "config", "scaling_kwargs", "bounds", default={}
    )
    return len(bounds) if isinstance(bounds, dict) else 0


def default_model_init_manifest_payload(
    *,
    model_class: str = "GeoPriorSubsNet",
    forecast_horizon: int = 3,
    static_input_dim: int = 12,
    dynamic_input_dim: int = 5,
    future_input_dim: int = 1,
    output_subsidence_dim: int = 1,
    output_gwl_dim: int = 1,
    quantiles: list[float] | None = None,
    mode: str = "tft_like",
    pde_mode: str = "on",
    identifiability_regime: str | None = None,
    time_units: str = "year",
) -> dict[str, Any]:
    """
    Build a realistic default model-init manifest.

    The payload mirrors the saved structure used by the model-init
    manifest family: ``config`` + ``dims`` + ``model_class``.
    Each call returns a freshly built, independent dictionary.

    Parameters
    ----------
    model_class : str, default "GeoPriorSubsNet"
        Stored under ``model_class``.
    forecast_horizon, static_input_dim, dynamic_input_dim, \
future_input_dim, output_subsidence_dim, output_gwl_dim : int
        Stored (cast to ``int``) under ``dims``.
    quantiles : list of float, optional
        Copied into ``config.quantiles``; defaults to
        ``[0.1, 0.5, 0.9]``.
    mode, pde_mode, identifiability_regime : str or None
        Stored verbatim under ``config``.
    time_units : str, default "year"
        Written to both ``config.time_units`` and
        ``config.scaling_kwargs.time_units``.

    Notes
    -----
    ``time_units`` only sets those two labels.
    ``seconds_per_time_unit`` (31556952.0) and ``coord_ranges_si`` are
    fixed values and are not recomputed from ``time_units``.
    The feature-name lists are likewise fixed (12 static, 5 dynamic,
    1 future) regardless of the ``*_input_dim`` arguments.
    """
    resolved_quantiles = (
        [0.1, 0.5, 0.9] if quantiles is None else list(quantiles)
    )

    geoprior = {
        "gamma_w": 9810.0,
        "h_ref_mode": "auto",
        "h_ref_value": 0.0,
        "hd_factor": 0.6,
        "init_kappa": 1.0,
        "init_mv": 1e-7,
        "kappa_mode": "kb",
        "offset_mode": "mul",
        "use_effective_h": True,
    }

    # Linear bounds plus their natural-log counterparts.
    bounds = {
        "H_min": 0.1,
        "H_max": 30.0,
        "K_min": 1e-12,
        "K_max": 1e-7,
        "Ss_min": 1e-6,
        "Ss_max": 1e-3,
        "tau_min": 604800.0,
        "tau_max": 9467085600.0,
        "logK_min": -27.6310211159,
        "logK_max": -16.1180956510,
        "logSs_min": -13.8155105580,
        "logSs_max": -6.9077552790,
        "logTau_min": 13.3126531038,
        "logTau_max": 22.9710869460,
    }

    bounds_config = {
        "beta": 20.0,
        "guard": 5.0,
        "include_tau": True,
        "kind": "both",
        "mode": "soft",
        "tau_w": 1.0,
        "w": 1.0,
    }

    dynamic_feature_names = [
        "GWL_depth_bgs_m__si",
        "subsidence_cum__si",
        "rainfall_mm",
        "urban_load_global",
        "soil_thickness_censored",
    ]

    static_feature_names = [
        "lithology_Conglomerate–Sandstone",
        "lithology_Limestone–Sandstone",
        "lithology_Mudstone–Siltstone",
        "lithology_Sandstone–Siltstone",
        "lithology_Shale–Limestone",
        "lithology_Siltstone–Sandstone",
        "lithology_Siltstone–Shale",
        "lithology_Tuff–Sandstone",
        "lithology_class_Coarse-Grained Soil",
        "lithology_class_Fine-Grained Soil",
        "lithology_class_Mixed Clastics",
        "z_surf_m__si",
    ]

    # Key order is kept stable so JSON output and flattened views
    # list entries in the same sequence.
    scaling_kwargs = {
        "H_bias_si": 0.0,
        "H_scale_si": 1.0,
        "Q_in_per_second": False,
        "Q_in_si": False,
        "Q_kind": "per_volume",
        "Q_length_in_si": False,
        "Q_wrt_normalized_time": False,
        "allow_subs_residual": True,
        "bounds": bounds,
        "bounds_config": bounds_config,
        "clip_global_norm": 5.0,
        "cons_drawdown_mode": "softplus",
        "cons_drawdown_rule": "ref_minus_mean",
        "cons_drawdown_zero_at_origin": False,
        "cons_relu_beta": 20.0,
        "cons_residual_units": "second",
        "cons_scale_floor": 3e-11,
        "cons_stop_grad_ref": True,
        "coord_epsg_used": 32649,
        "coord_inv_ranges_si": {
            "t": 4.526962643830204e-09,
            "x": 2.249870632438635e-05,
            "y": 2.5461489497135583e-05,
        },
        "coord_mode": "degrees",
        "coord_order": ["t", "x", "y"],
        "coord_ranges": {"t": 7.0, "x": 44447.0, "y": 39275.0},
        "coord_ranges_si": {"t": 220898664.0, "x": 44447.0, "y": 39275.0},
        "coord_src_epsg": 4326,
        "coord_target_epsg": 32649,
        "coords_in_degrees": False,
        "coords_normalized": True,
        "debug_physics_grads": False,
        "drainage_mode": "double",
        "dt_min_units": 1e-6,
        "dynamic_feature_names": dynamic_feature_names,
        "future_feature_names": ["rainfall_mm"],
        "gw_residual_units": "second",
        "gw_scale_floor": 1e-12,
        "gwl_col": "GWL_depth_bgs_m__si",
        "gwl_driver_kind": "depth",
        "gwl_driver_sign": "down_positive",
        "gwl_dyn_col": "GWL_depth_bgs_m__si",
        "gwl_dyn_index": 0,
        "gwl_dyn_name": "GWL_depth_bgs_m__si",
        "gwl_kind": "depth_bgs",
        "gwl_sign": "down_positive",
        "gwl_target_col": "head_m__si",
        "gwl_target_kind": "head",
        "gwl_target_sign": "up_positive",
        "head_bias_si": 0.0,
        "head_scale_si": 1.0,
        "lambda_q": 5e-4,
        "log_q_diagnostics": True,
        "loss_weight_gwl": 0.8,
        "mv_alpha_disp": 0.1,
        "mv_delay_epochs": 1,
        "mv_delay_steps": 4261,
        "mv_huber_delta": 1.0,
        "mv_prior_mode": "calibrate",
        "mv_prior_units": "strict",
        "mv_schedule_unit": "epoch",
        "mv_steps_per_epoch": 4261,
        "mv_warmup_epochs": 2,
        "mv_warmup_steps": 8522,
        "mv_weight": 0.001,
        "physics_ramp_steps": 500,
        "physics_warmup_steps": 500,
        "q_policy": "warmup_off",
        "q_ramp_epochs": 6,
        "q_ramp_steps": 25566,
        "q_warmup_epochs": 2,
        "q_warmup_steps": 8522,
        "scaling_error_policy": "raise",
        "seconds_per_time_unit": 31556952.0,
        "static_feature_names": static_feature_names,
        "subs_bias_si": 0.0,
        "subs_dyn_index": 1,
        "subs_model_col": "subsidence_cum__si",
        "subs_scale_si": 1.0,
        "subsidence_kind": "cumulative",
        "time_units": time_units,
        "track_aux_metrics": False,
        "use_head_proxy": True,
        "z_surf_col": "z_surf_m__si",
        "z_surf_static_index": 11,
    }

    config = {
        "attention_units": 64,
        "dropout_rate": 0.1,
        "embed_dim": 32,
        "geoprior": geoprior,
        "hidden_units": 64,
        "identifiability_regime": identifiability_regime,
        "lstm_units": 64,
        "memory_size": 50,
        "mode": mode,
        "num_heads": 2,
        "pde_mode": pde_mode,
        "quantiles": resolved_quantiles,
        "scales": [1, 2],
        "scaling_kwargs": scaling_kwargs,
        "time_units": time_units,
        "use_batch_norm": False,
        "use_residuals": True,
        "use_vsn": True,
        "vsn_units": 32,
    }

    dims = {
        "dynamic_input_dim": int(dynamic_input_dim),
        "forecast_horizon": int(forecast_horizon),
        "future_input_dim": int(future_input_dim),
        "output_gwl_dim": int(output_gwl_dim),
        "output_subsidence_dim": int(output_subsidence_dim),
        "static_input_dim": int(static_input_dim),
    }

    return {
        "config": config,
        "dims": dims,
        "model_class": str(model_class),
    }
def generate_model_init_manifest(
    path: PathLike,
    *,
    template: Mapping[str, Any] | None = None,
    overrides: Mapping[str, Any] | None = None,
) -> Path:
    """
    Generate a model-init manifest JSON file.

    Parameters
    ----------
    path : str or pathlib.Path
        Output JSON path.
    template : mapping, optional
        Base payload (shallow-copied). If omitted, uses
        :func:`default_model_init_manifest_payload`.
    overrides : mapping, optional
        Nested overrides applied on top of the template via
        ``clone_artifact``.

    Returns
    -------
    pathlib.Path
        The path returned by ``write_json``.
    """
    if template is None:
        base = default_model_init_manifest_payload()
    else:
        base = dict(template)

    merged = clone_artifact(base, overrides=dict(overrides or {}))
    return write_json(merged, path)
def load_model_init_manifest(
    path: PathLike,
) -> ArtifactRecord:
    """
    Load a model-init manifest as ``ArtifactRecord``.

    Thin wrapper around ``load_artifact`` with
    ``kind="model_init_manifest"``.
    """
    record = load_artifact(path, kind="model_init_manifest")
    return record
def model_init_dims_frame(
    manifest: ModelInitManifestLike,
) -> pd.DataFrame:
    """
    Return a tidy frame for input/output dimensions.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``dims`` with columns
        ``section`` (always ``"dims"``), ``name`` and ``value``.
        The columns are present even when ``dims`` is missing,
        ``null`` or not a mapping (the frame is then empty).
        Note that this frame labels the entry column ``name``,
        whereas the other section frames use ``key``.
    """
    payload = _as_payload(manifest)
    dims = nested_get(payload, "dims", default={})
    if not isinstance(dims, Mapping):
        # Guard against ``"dims": null`` or malformed payloads.
        dims = {}

    rows = [
        {"section": "dims", "name": str(key), "value": value}
        for key, value in dims.items()
    ]
    # Explicit columns keep the schema stable for empty results.
    return pd.DataFrame(rows, columns=["section", "name", "value"])
def model_init_architecture_frame(
    manifest: ModelInitManifestLike,
) -> pd.DataFrame:
    """
    Return a frame for architecture choices.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.

    Returns
    -------
    pandas.DataFrame
        Columns ``section`` (``"architecture"``), ``key``, ``value``.
        One row per key in ``_ARCH_KEYS`` followed by ``mode``,
        ``pde_mode``, ``identifiability_regime``, ``quantiles`` and
        ``scales``. Missing keys yield ``None``. A missing, ``null``
        or non-mapping ``config`` is treated as empty rather than
        raising.
    """
    payload = _as_payload(manifest)
    cfg = nested_get(payload, "config", default={})
    if not isinstance(cfg, Mapping):
        cfg = {}

    keys = [
        *_ARCH_KEYS,
        "mode",
        "pde_mode",
        "identifiability_regime",
        "quantiles",
        "scales",
    ]
    return pd.DataFrame(
        _selected_rows(cfg, keys, section="architecture")
    )
def model_init_geoprior_frame(
    manifest: ModelInitManifestLike,
) -> pd.DataFrame:
    """
    Return a frame for GeoPrior-specific init settings.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.

    Returns
    -------
    pandas.DataFrame
        Columns ``section`` (``"geoprior"``), ``key``, ``value``.
        One row per key in ``_GEO_KEYS`` followed by ``h_ref_mode``,
        ``kappa_mode``, ``offset_mode`` and ``use_effective_h``.
        Missing keys yield ``None``. A missing, ``null`` or
        non-mapping ``config.geoprior`` section is treated as empty
        rather than raising.
    """
    payload = _as_payload(manifest)
    geoprior = nested_get(payload, "config", "geoprior", default={})
    if not isinstance(geoprior, Mapping):
        # e.g. ``"geoprior": null``; ``.get`` would otherwise fail.
        geoprior = {}

    keys = [
        *_GEO_KEYS,
        "h_ref_mode",
        "kappa_mode",
        "offset_mode",
        "use_effective_h",
    ]
    return pd.DataFrame(
        _selected_rows(geoprior, keys, section="geoprior")
    )
def model_init_scaling_overview_frame(
    manifest: ModelInitManifestLike,
) -> pd.DataFrame:
    """
    Return a compact overview of resolved scaling kwargs.

    Rows come from three places inside ``config.scaling_kwargs``:

    * ``scaling_overview`` rows for every key in
      ``_SCALING_OVERVIEW_KEYS`` (``None`` when missing);
    * ``coord_ranges`` rows keyed ``coord_range_<axis>``;
    * ``bounds_config`` rows keyed ``bounds_<name>`` for a fixed
      list of bounds-config entries.

    The last two groups are only added when ``scaling_kwargs`` and
    the respective sub-entry are plain ``dict`` objects.
    """
    payload = _as_payload(manifest)
    sk = nested_get(payload, "config", "scaling_kwargs", default={})
    rows = _selected_rows(
        sk, _SCALING_OVERVIEW_KEYS, section="scaling_overview"
    )
    sk_is_dict = isinstance(sk, dict)

    coord_ranges = sk.get("coord_ranges", {}) if sk_is_dict else {}
    if isinstance(coord_ranges, dict):
        rows.extend(
            {
                "section": "coord_ranges",
                "key": f"coord_range_{axis}",
                "value": span,
            }
            for axis, span in coord_ranges.items()
        )

    bounds_cfg = sk.get("bounds_config", {}) if sk_is_dict else {}
    if isinstance(bounds_cfg, dict):
        bounds_fields = (
            "mode", "kind", "include_tau", "beta", "guard", "w", "tau_w",
        )
        rows.extend(
            {
                "section": "bounds_config",
                "key": f"bounds_{field}",
                "value": bounds_cfg.get(field),
            }
            for field in bounds_fields
        )

    return pd.DataFrame(rows)
def model_init_feature_groups_frame(
    manifest: ModelInitManifestLike,
) -> pd.DataFrame:
    """
    Return a tidy frame for nested feature-name groups.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.

    Returns
    -------
    pandas.DataFrame
        Columns ``group`` (``"static"``, ``"dynamic"`` or
        ``"future"``), ``index`` (position within the group) and
        ``feature_name``. The columns are present even when no
        feature names are recorded (the frame is then empty).
    """
    payload = _as_payload(manifest)
    rows = [
        {"group": group, "index": idx, "feature_name": name}
        for group in ("static", "dynamic", "future")
        for idx, name in enumerate(_feature_names(payload, group))
    ]
    # Explicit columns: without them an empty row list produced a
    # frame with no columns at all.
    return pd.DataFrame(rows, columns=["group", "index", "feature_name"])
def summarize_model_init_manifest(
    manifest: ModelInitManifestLike,
) -> dict[str, Any]:
    """
    Return a compact summary of a model-init manifest.

    The goal is not to flatten every nested key, but to expose the
    most decision-relevant initialization facts.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.

    Returns
    -------
    dict
        Flat mapping covering model class, dimensions, architecture
        switches, scaling conventions, feature-group counts and
        GeoPrior initialization values. Missing entries are reported
        as ``None``; counts fall back to ``0``.

    Notes
    -----
    Sections that are missing, ``null`` or not mappings (for example
    ``"geoprior": null``) are treated as empty instead of raising
    ``AttributeError``. ``quantiles`` may be a list or a tuple.
    """
    payload = _as_payload(manifest)

    def _section(*keys: str) -> Mapping[str, Any]:
        # Coerce absent / null / malformed sections to {} so the
        # ``.get`` lookups below never fail.
        value = nested_get(payload, *keys, default={})
        return value if isinstance(value, Mapping) else {}

    cfg = _section("config")
    geoprior = _section("config", "geoprior")
    sk = _section("config", "scaling_kwargs")
    dims = _section("dims")

    quantiles = cfg.get("quantiles")
    quantile_count = (
        len(quantiles) if isinstance(quantiles, (list, tuple)) else 0
    )

    summary = {
        "model_class": payload.get("model_class"),
        "forecast_horizon": dims.get("forecast_horizon"),
        "static_input_dim": dims.get("static_input_dim"),
        "dynamic_input_dim": dims.get("dynamic_input_dim"),
        "future_input_dim": dims.get("future_input_dim"),
        "output_subsidence_dim": dims.get("output_subsidence_dim"),
        "output_gwl_dim": dims.get("output_gwl_dim"),
        "mode": cfg.get("mode"),
        "pde_mode": cfg.get("pde_mode"),
        "identifiability_regime": cfg.get("identifiability_regime"),
        "quantile_count": quantile_count,
        "quantiles": quantiles,
        "use_batch_norm": cfg.get("use_batch_norm"),
        "use_residuals": cfg.get("use_residuals"),
        "use_vsn": cfg.get("use_vsn"),
        # Prefer the top-level config label, fall back to scaling kwargs.
        "time_units": cfg.get("time_units") or sk.get("time_units"),
        "coords_normalized": sk.get("coords_normalized"),
        "coords_in_degrees": sk.get("coords_in_degrees"),
        "coord_order": sk.get("coord_order"),
        "q_kind": sk.get("Q_kind"),
        "gwl_kind": sk.get("gwl_kind"),
        "gwl_sign": sk.get("gwl_sign"),
        "use_head_proxy": sk.get("use_head_proxy"),
        "mv_prior_mode": sk.get("mv_prior_mode"),
        "mv_schedule_unit": sk.get("mv_schedule_unit"),
        "n_bounds": _bounds_count(payload),
        "n_static_features": len(_feature_names(payload, "static")),
        "n_dynamic_features": len(_feature_names(payload, "dynamic")),
        "n_future_features": len(_feature_names(payload, "future")),
        "kappa_mode": geoprior.get("kappa_mode"),
        "offset_mode": geoprior.get("offset_mode"),
        "use_effective_h": geoprior.get("use_effective_h"),
        "hd_factor": geoprior.get("hd_factor"),
        "gamma_w": geoprior.get("gamma_w"),
        "init_kappa": geoprior.get("init_kappa"),
        "init_mv": geoprior.get("init_mv"),
    }
    return summary
def plot_model_init_dims(
    manifest: ModelInitManifestLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Model-init dimensions",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot input/output dimensions as bars.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.
    ax : matplotlib Axes, optional
        Target axes, forwarded to ``plot_metric_bars``.
    title : str
        Plot title.
    **plot_kws
        Extra keyword arguments for ``plot_metric_bars``.

    Returns
    -------
    matplotlib Axes
        Whatever ``plot_metric_bars`` returns.

    Notes
    -----
    Only numeric ``dims`` entries are plotted. Booleans are skipped,
    matching ``_selected_numeric``. A missing, ``null`` or
    non-mapping ``dims`` section yields no bars instead of raising.
    """
    payload = _as_payload(manifest)
    dims = nested_get(payload, "dims", default={})
    if not isinstance(dims, Mapping):
        dims = {}

    values = {
        str(key): float(value)
        for key, value in dims.items()
        if isinstance(value, (int, float, np.number))
        and not isinstance(value, bool)
    }
    return plot_metric_bars(
        ax,
        values,
        title=title,
        **plot_kws,
    )
def plot_model_init_architecture(
    manifest: ModelInitManifestLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Architecture scalars",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot key architecture scalars as bars.

    The numeric values of ``_ARCH_KEYS`` found under ``config`` are
    forwarded to ``plot_metric_bars`` together with ``ax``, ``title``
    and any extra ``plot_kws``.
    """
    config = nested_get(_as_payload(manifest), "config", default={})
    arch_values = _selected_numeric(config, _ARCH_KEYS)
    return plot_metric_bars(ax, arch_values, title=title, **plot_kws)
def plot_model_init_geoprior(
    manifest: ModelInitManifestLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "GeoPrior initialization",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot key GeoPrior physics-init scalars as bars.

    The numeric values of ``_GEO_KEYS`` found under
    ``config.geoprior`` are forwarded to ``plot_metric_bars``
    together with ``ax``, ``title`` and any extra ``plot_kws``.
    """
    geo_section = nested_get(
        _as_payload(manifest), "config", "geoprior", default={}
    )
    geo_values = _selected_numeric(geo_section, _GEO_KEYS)
    return plot_metric_bars(ax, geo_values, title=title, **plot_kws)
def plot_model_init_feature_group_sizes(
    manifest: ModelInitManifestLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Feature-group sizes",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot the sizes of static/dynamic/future feature groups.

    Each bar is the number of names in
    ``config.scaling_kwargs.<group>_feature_names`` (0 when absent).
    """
    payload = _as_payload(manifest)
    sizes = {
        group: float(len(_feature_names(payload, group)))
        for group in ("static", "dynamic", "future")
    }
    return plot_metric_bars(ax, sizes, title=title, **plot_kws)
def plot_model_init_boolean_summary(
    manifest: ModelInitManifestLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Model-init checks",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot compact initialization checks as booleans.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.
    ax : matplotlib Axes, optional
        Target axes, forwarded to ``plot_boolean_checks``.
    title : str
        Plot title.
    **plot_kws
        Extra keyword arguments for ``plot_boolean_checks``.

    Returns
    -------
    matplotlib Axes
        Whatever ``plot_boolean_checks`` returns.

    Notes
    -----
    Missing flags default to ``False``, except ``coords_in_degrees``,
    which defaults to ``True`` (so ``coords_not_degrees`` is then
    ``False``). Missing, ``null`` or non-mapping ``config``,
    ``geoprior`` or ``scaling_kwargs`` sections are treated as empty
    instead of raising.
    """
    payload = _as_payload(manifest)

    cfg = nested_get(payload, "config", default={})
    geoprior = nested_get(payload, "config", "geoprior", default={})
    sk = nested_get(payload, "config", "scaling_kwargs", default={})
    # Guard against null / malformed sections (e.g. "geoprior": null).
    if not isinstance(cfg, Mapping):
        cfg = {}
    if not isinstance(geoprior, Mapping):
        geoprior = {}
    if not isinstance(sk, Mapping):
        sk = {}

    checks = {
        "coords_normalized": bool(sk.get("coords_normalized", False)),
        "coords_not_degrees": not bool(
            sk.get("coords_in_degrees", True)
        ),
        "use_head_proxy": bool(sk.get("use_head_proxy", False)),
        "use_batch_norm": bool(cfg.get("use_batch_norm", False)),
        "use_residuals": bool(cfg.get("use_residuals", False)),
        "use_vsn": bool(cfg.get("use_vsn", False)),
        "use_effective_h": bool(
            geoprior.get("use_effective_h", False)
        ),
        "track_aux_metrics": bool(sk.get("track_aux_metrics", False)),
        "has_bounds": _bounds_count(payload) > 0,
        "has_static_features": len(_feature_names(payload, "static")) > 0,
        "has_dynamic_features": len(
            _feature_names(payload, "dynamic")
        ) > 0,
        "has_future_features": len(_feature_names(payload, "future")) > 0,
    }
    return plot_boolean_checks(
        ax,
        checks,
        title=title,
        **plot_kws,
    )
def inspect_model_init_manifest(
    manifest: ModelInitManifestLike,
    *,
    output_dir: PathLike | None = None,
    stem: str = "model_init_manifest",
    save_figures: bool = True,
) -> dict[str, Any]:
    """
    Inspect a model-init manifest and optionally save figures.

    Parameters
    ----------
    manifest : ArtifactRecord, mapping, str or pathlib.Path
        Manifest record, in-memory payload, or JSON path.
    output_dir : str or pathlib.Path, optional
        Directory for PNG figures (created if missing). When omitted,
        no figures are written.
    stem : str, default "model_init_manifest"
        Prefix for figure file names, e.g. ``{stem}_dims.png``.
    save_figures : bool, default True
        Render and save figures when ``output_dir`` is given.

    Returns
    -------
    dict
        Bundle with keys:

        * ``"summary"``: output of :func:`summarize_model_init_manifest`;
        * ``"frames"``: section frames plus a ``flattened`` key/value
          frame (its columns are present even when the payload is
          empty);
        * ``"figure_paths"``: file name -> written path string (empty
          when no figures were written).
    """
    payload = _as_payload(manifest)

    flat_rows = [
        {"key": key, "value": value}
        for key, value in flatten_dict(payload).items()
    ]
    bundle: dict[str, Any] = {
        "summary": summarize_model_init_manifest(payload),
        "frames": {
            "dims": model_init_dims_frame(payload),
            "architecture": model_init_architecture_frame(payload),
            "geoprior": model_init_geoprior_frame(payload),
            "scaling_overview": model_init_scaling_overview_frame(
                payload
            ),
            "feature_groups": model_init_feature_groups_frame(payload),
            "flattened": pd.DataFrame(
                flat_rows, columns=["key", "value"]
            ),
        },
        "figure_paths": {},
    }

    if not (output_dir and save_figures):
        return bundle

    out_dir = as_path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    plots = [
        (f"{stem}_dims.png", plot_model_init_dims),
        (f"{stem}_architecture.png", plot_model_init_architecture),
        (f"{stem}_geoprior.png", plot_model_init_geoprior),
        (
            f"{stem}_feature_groups.png",
            plot_model_init_feature_group_sizes,
        ),
        (f"{stem}_checks.png", plot_model_init_boolean_summary),
    ]
    for name, plot_func in plots:
        fig, ax = plt.subplots(figsize=(8.0, 4.6))
        try:
            plot_func(payload, ax=ax)
            fig.tight_layout()
            path = out_dir / name
            fig.savefig(path, dpi=160, bbox_inches="tight")
        finally:
            # Always release the figure, even when plotting or saving
            # raises, so failed inspections do not leak open figures.
            plt.close(fig)
        bundle["figure_paths"][name] = str(path)

    return bundle