# Source code for geoprior.utils.inspect.scaling_kwargs

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Scaling-kwargs generation and inspection helpers.

This module focuses on the ``scaling_kwargs.json`` artifact.
It is one of the most important configuration sidecars in the
GeoPrior workflow because it preserves the resolved scaling,
units, coordinate conventions, feature-channel identities,
and physics-specific configuration used by the model.

The functions are designed for two common uses:

1. Sphinx-Gallery examples that need a realistic
   ``scaling_kwargs`` payload without rerunning the full
   Stage-1 / Stage-2 pipeline.
2. Real workflow inspection when a user wants to verify
   coordinate normalization, SI affine maps, groundwater
   conventions, forcing interpretation, MV-prior schedule,
   and bounds before training or diagnosis.
"""

from __future__ import annotations

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd

from .utils import (
    ArtifactRecord,
    clone_artifact,
    deep_update,
    load_artifact,
    plot_boolean_checks,
    plot_metric_bars,
    read_json,
    write_json,
)

# Filesystem location given either as a string or a ``Path``.
PathLike = str | Path
# Every input form that ``_as_payload`` accepts: an already
# loaded ``ArtifactRecord``, an in-memory mapping, or a location
# (string / ``Path``) that is handed to ``read_json``.
ScalingKwargsLike = (
    ArtifactRecord | Mapping[str, Any] | str | Path
)

# Public names exported by this module, kept in alphabetical
# order. Underscore-prefixed helpers are intentionally left out.
__all__ = [
    "default_scaling_kwargs_payload",
    "generate_scaling_kwargs",
    "inspect_scaling_kwargs",
    "load_scaling_kwargs",
    "plot_scaling_kwargs_affine_maps",
    "plot_scaling_kwargs_boolean_summary",
    "plot_scaling_kwargs_bounds",
    "plot_scaling_kwargs_coord_ranges",
    "plot_scaling_kwargs_feature_group_sizes",
    "plot_scaling_kwargs_schedule_scalars",
    "scaling_kwargs_affine_frame",
    "scaling_kwargs_bounds_frame",
    "scaling_kwargs_coord_frame",
    "scaling_kwargs_feature_channels_frame",
    "scaling_kwargs_schedule_frame",
    "summarize_scaling_kwargs",
]

# Top-level payload keys holding the scale/bias pairs of the
# subs / head / H affine maps. Read by
# ``scaling_kwargs_affine_frame``, ``summarize_scaling_kwargs``
# and ``plot_scaling_kwargs_affine_maps``.
_AFFINE_KEYS = [
    "subs_scale_si",
    "subs_bias_si",
    "head_scale_si",
    "head_bias_si",
    "H_scale_si",
    "H_bias_si",
]

# Top-level numeric scalars reported by
# ``scaling_kwargs_schedule_frame`` and plotted by
# ``plot_scaling_kwargs_schedule_scalars``.
_SCHEDULE_KEYS = [
    "clip_global_norm",
    "cons_scale_floor",
    "gw_scale_floor",
    "dt_min_units",
    "mv_weight",
    "mv_alpha_disp",
    "mv_huber_delta",
    "mv_delay_epochs",
    "mv_warmup_epochs",
    "mv_delay_steps",
    "mv_warmup_steps",
]

# Row order used by ``scaling_kwargs_bounds_frame``: these keys
# of the ``bounds`` sub-mapping always come first (even when
# absent from the payload); any other keys follow, sorted.
_BOUND_ORDER = [
    "H_min",
    "H_max",
    "K_min",
    "K_max",
    "Ss_min",
    "Ss_max",
    "tau_min",
    "tau_max",
    "logK_min",
    "logK_max",
    "logSs_min",
    "logSs_max",
    "logTau_min",
    "logTau_max",
]

# Top-level flags shown by ``plot_scaling_kwargs_boolean_summary``;
# a flag is only plotted when its key is present in the payload.
_BOOLEAN_KEYS = [
    "allow_subs_residual",
    "coords_normalized",
    "coords_in_degrees",
    "Q_wrt_normalized_time",
    "Q_in_si",
    "Q_in_per_second",
    "Q_length_in_si",
    "debug_physics_grads",
    "cons_stop_grad_ref",
    "use_head_proxy",
    "track_aux_metrics",
]


def _as_payload(
    payload: ScalingKwargsLike,
) -> dict[str, Any]:
    """
    Coerce any supported input into a plain ``dict``.

    An ``ArtifactRecord`` contributes its ``payload`` attribute,
    a mapping is used directly, and anything else is passed to
    ``read_json`` as a location to load. The result is always a
    new top-level ``dict`` (shallow copy; nested containers are
    shared with the input).
    """
    if isinstance(payload, ArtifactRecord):
        source = payload.payload
    elif isinstance(payload, Mapping):
        source = payload
    else:
        source = read_json(payload)
    return dict(source)


def _try_float(
    value: Any,
) -> float | None:
    """Return ``value`` as float when possible."""
    try:
        if value is None:
            return None
        return float(value)
    except Exception:
        return None


def _string_or_none(
    value: Any,
) -> str | None:
    """Return a compact string or ``None``."""
    if value is None:
        return None
    text = str(value).strip()
    return text or None


def _numeric_subset(
    mapping: dict[str, Any] | None,
    *,
    keys: list[str] | tuple[str, ...] | None = None,
) -> dict[str, float]:
    """
    Collect the float-convertible, non-boolean items of ``mapping``.

    Parameters
    ----------
    mapping : dict or None
        Source items; ``None`` is treated as empty.
    keys : sequence of str, optional
        When given, only items whose stringified key is listed
        are kept. ``None`` keeps every numeric item.

    Returns
    -------
    dict
        ``{str(key): float(value)}`` in the iteration order of
        ``mapping`` (not the order of ``keys``).
    """
    wanted = None if keys is None else {str(k) for k in keys}

    picked: dict[str, float] = {}
    for name, raw in (mapping or {}).items():
        # Booleans would convert to 0.0 / 1.0; they are flags,
        # not measurements, so leave them out.
        if isinstance(raw, bool):
            continue
        label = str(name)
        if wanted is not None and label not in wanted:
            continue
        number = _try_float(raw)
        if number is None:
            continue
        picked[label] = float(number)
    return picked


def _feature_count(
    payload: dict[str, Any],
    key: str,
) -> int:
    """Return length of a feature-name list."""
    vals = payload.get(key, None)
    if isinstance(vals, list):
        return int(len(vals))
    return 0


def default_scaling_kwargs_payload(
    *,
    time_units: str = "year",
    coords_normalized: bool = True,
    coords_in_degrees: bool = False,
    coord_mode: str = "degrees",
    coord_order: list[str] | None = None,
    coord_ranges: dict[str, float] | None = None,
    bounds: dict[str, Any] | None = None,
    dynamic_feature_names: list[str] | None = None,
    future_feature_names: list[str] | None = None,
    static_feature_names: list[str] | None = None,
    gwl_dyn_name: str = "GWL_depth_bgs_m__si",
    gwl_dyn_index: int = 0,
    subs_dyn_index: int = 1,
    z_surf_static_index: int = 11,
    gwl_kind: str = "depth_bgs",
    gwl_sign: str = "down_positive",
    use_head_proxy: bool = True,
    q_kind: str = "per_volume",
    mv_prior_mode: str = "calibrate",
    mv_prior_units: str = "strict",
) -> dict[str, Any]:
    """
    Build a realistic default scaling-kwargs payload.

    The payload is intentionally template-based. It mirrors the
    resolved ``scaling_kwargs.json`` structure written by Stage-2
    and preserved in manifests, while staying lightweight enough
    for documentation examples.

    Parameters
    ----------
    time_units, coord_mode, gwl_dyn_name, gwl_kind, gwl_sign : str
        Stored verbatim under the same-named keys
        (``gwl_dyn_name`` is also written to ``gwl_col`` and
        ``gwl_dyn_col``).
    q_kind, mv_prior_mode, mv_prior_units : str
        Stored under ``Q_kind``, ``mv_prior_mode`` and
        ``mv_prior_units``.
    coords_normalized, coords_in_degrees, use_head_proxy : bool
        Coerced with ``bool`` before being stored;
        ``use_head_proxy`` is also copied into ``gwl_z_meta``.
    gwl_dyn_index, subs_dyn_index, z_surf_static_index : int
        Coerced with ``int`` before being stored.
    coord_order, coord_ranges, bounds : optional
        Replacement containers. ``None`` selects the built-in
        template values.
    dynamic_feature_names, future_feature_names, \
static_feature_names : list of str, optional
        Replacement feature-name lists. ``None`` selects the
        built-in template values.

    Returns
    -------
    dict
        New payload. Caller-supplied containers are shallow-copied
        so the result does not alias them at the top level.

    Notes
    -----
    Only ``None`` triggers a template default. An explicitly
    passed empty container (for example
    ``future_feature_names=[]`` or ``bounds={}``) is kept as
    given instead of being swapped for the template values.
    """
    # Resolve optional containers with ``is None`` so that empty
    # lists / dicts supplied on purpose survive.
    if coord_order is None:
        coord_order = ["t", "x", "y"]
    if coord_ranges is None:
        coord_ranges = {
            "t": 7.0,
            "x": 44447.0,
            "y": 39275.0,
        }
    if bounds is None:
        bounds = {
            "H_min": 0.1,
            "H_max": 30.0,
            "K_min": 1e-12,
            "K_max": 1e-7,
            "Ss_min": 1e-6,
            "Ss_max": 1e-3,
            "tau_min": 604800.0,
            "tau_max": 9467085600.0,
            "logK_min": -27.6310211159,
            "logK_max": -16.1180956510,
            "logSs_min": -13.8155105580,
            "logSs_max": -6.9077552790,
            "logTau_min": 13.3126531038,
            "logTau_max": 22.9710869460,
        }
    if dynamic_feature_names is None:
        dynamic_feature_names = [
            "GWL_depth_bgs_m__si",
            "subsidence_cum__si",
            "rainfall_mm",
            "urban_load_global",
            "soil_thickness_censored",
        ]
    if future_feature_names is None:
        future_feature_names = ["rainfall_mm"]
    if static_feature_names is None:
        static_feature_names = [
            "lithology_Conglomerate–Sandstone",
            "lithology_Limestone–Sandstone",
            "lithology_Mudstone–Siltstone",
            "lithology_Sandstone–Siltstone",
            "lithology_Shale–Limestone",
            "lithology_Siltstone–Sandstone",
            "lithology_Siltstone–Shale",
            "lithology_Tuff–Sandstone",
            "lithology_class_Coarse-Grained Soil",
            "lithology_class_Fine-Grained Soil",
            "lithology_class_Mixed Clastics",
            "z_surf_m__si",
        ]

    return {
        "subs_scale_si": 1.0,
        "subs_bias_si": 0.0,
        "head_scale_si": 1.0,
        "head_bias_si": 0.0,
        "H_scale_si": 1.0,
        "H_bias_si": 0.0,
        "subsidence_kind": "cumulative",
        "allow_subs_residual": True,
        "coords_normalized": bool(coords_normalized),
        "coord_order": list(coord_order),
        "coord_ranges": dict(coord_ranges),
        "coord_mode": coord_mode,
        "coord_src_epsg": 4326,
        "coord_target_epsg": 32649,
        "coord_epsg_used": 32649,
        "coords_in_degrees": bool(coords_in_degrees),
        "cons_residual_units": "second",
        "cons_scale_floor": 3e-11,
        "gw_scale_floor": 1e-12,
        "dt_min_units": 1e-6,
        "Q_wrt_normalized_time": False,
        "Q_in_si": False,
        "Q_in_per_second": False,
        "Q_kind": q_kind,
        "Q_length_in_si": False,
        "drainage_mode": "double",
        "scaling_error_policy": "raise",
        "debug_physics_grads": False,
        "gw_residual_units": "second",
        "clip_global_norm": 5.0,
        "cons_drawdown_mode": "softplus",
        "cons_drawdown_rule": "ref_minus_mean",
        "cons_stop_grad_ref": True,
        "cons_drawdown_zero_at_origin": False,
        "cons_relu_beta": 20.0,
        "mv_prior_units": mv_prior_units,
        "mv_alpha_disp": 0.1,
        "mv_huber_delta": 1.0,
        "mv_prior_mode": mv_prior_mode,
        "mv_weight": 0.001,
        "mv_schedule_unit": "epoch",
        "mv_delay_epochs": 1,
        "mv_warmup_epochs": 2,
        "mv_delay_steps": 4261,
        "mv_warmup_steps": 8522,
        "track_aux_metrics": False,
        "gwl_dyn_index": int(gwl_dyn_index),
        "subs_dyn_index": int(subs_dyn_index),
        "gwl_dyn_name": gwl_dyn_name,
        "z_surf_static_index": int(z_surf_static_index),
        "gwl_col": gwl_dyn_name,
        "gwl_dyn_col": gwl_dyn_name,
        "gwl_target_col": "head_m__si",
        "subs_model_col": "subsidence_cum__si",
        "z_surf_col": "z_surf_m__si",
        "gwl_kind": gwl_kind,
        "gwl_sign": gwl_sign,
        "gwl_driver_kind": "depth",
        "gwl_driver_sign": "down_positive",
        "gwl_target_kind": "head",
        "gwl_target_sign": "up_positive",
        "use_head_proxy": bool(use_head_proxy),
        "time_units": time_units,
        # NOTE(review): the kind/sign/column entries below are
        # fixed template values; they do not follow ``gwl_kind``,
        # ``gwl_sign`` or ``gwl_dyn_name`` -- confirm this is
        # intended.
        "gwl_z_meta": {
            "raw_kind": "depth_bgs",
            "raw_sign": "down_positive",
            "driver_kind": "depth",
            "driver_sign": "down_positive",
            "target_kind": "head",
            "target_sign": "up_positive",
            "use_head_proxy": bool(use_head_proxy),
            "z_surf_col": "z_surf_m__si",
            "head_from_depth_rule": "z_surf - depth",
            "cols": {
                "depth_raw": "GWL_depth_bgs_m",
                "head_raw": "head_m",
                "z_surf_raw": "z_surf_m",
                "depth_model": "GWL_depth_bgs_m__si",
                "head_model": "head_m__si",
                "z_surf_static": "z_surf_m__si",
                "subs_model": "subsidence_cum__si",
            },
        },
        "bounds": dict(bounds),
        "dynamic_feature_names": list(dynamic_feature_names),
        "future_feature_names": list(future_feature_names),
        "static_feature_names": list(static_feature_names),
    }
def generate_scaling_kwargs(
    path: PathLike,
    *,
    template: ScalingKwargsLike | None = None,
    overrides: dict[str, Any] | None = None,
    **kwargs: Any,
) -> Path:
    """
    Write a scaling-kwargs JSON artifact to ``path``.

    Parameters
    ----------
    path : path-like
        Destination of the JSON file.
    template : mapping or path-like, optional
        Existing payload to start from. When omitted, the
        payload comes from ``default_scaling_kwargs_payload``.
    overrides : dict, optional
        Values merged with ``deep_update`` as the last step,
        after the template / default payload has been prepared.
    **kwargs : Any
        Without a template these are forwarded to
        ``default_scaling_kwargs_payload``. With a template they
        are merged into the cloned template via ``deep_update``
        before ``overrides`` is applied.

    Returns
    -------
    pathlib.Path
        The value returned by ``write_json``.
    """
    if template is not None:
        # Work on a clone so the caller's template is not mutated.
        base = clone_artifact(_as_payload(template))
        if kwargs:
            base = deep_update(base, kwargs)
    else:
        base = default_scaling_kwargs_payload(**kwargs)

    merged = deep_update(base, overrides)
    return write_json(merged, path)
def load_scaling_kwargs(
    path: PathLike,
) -> ArtifactRecord:
    """
    Load a scaling-kwargs artifact from ``path``.

    Delegates to ``load_artifact`` with ``kind="scaling_kwargs"``.

    Returns
    -------
    ArtifactRecord
        The record produced by ``load_artifact``.
    """
    record = load_artifact(path, kind="scaling_kwargs")
    return record
def scaling_kwargs_affine_frame(
    payload: ScalingKwargsLike,
) -> pd.DataFrame:
    """
    Tabulate the affine SI-map scalars.

    Returns
    -------
    pandas.DataFrame
        One row per entry of ``_AFFINE_KEYS`` with columns
        ``parameter``, ``value`` (float, or ``None`` when the
        key is missing / not numeric) and ``unit``.
    """
    cfg = _as_payload(payload)
    units = {
        "subs_scale_si": "m / model-unit",
        "subs_bias_si": "m",
        "head_scale_si": "m / model-unit",
        "head_bias_si": "m",
        "H_scale_si": "m / model-unit",
        "H_bias_si": "m",
    }
    records: list[dict[str, Any]] = [
        {
            "parameter": name,
            "value": _try_float(cfg.get(name)),
            "unit": units.get(name),
        }
        for name in _AFFINE_KEYS
    ]
    return pd.DataFrame(records)
def scaling_kwargs_coord_frame(
    payload: ScalingKwargsLike,
) -> pd.DataFrame:
    """
    Tabulate coordinate metadata and coordinate ranges.

    Returns
    -------
    pandas.DataFrame
        Columns ``section``, ``name`` and ``value``. The
        ``coord_meta`` rows (mode, order, EPSG codes, time
        units) come first, followed by one ``coord_ranges`` row
        for each of ``t``, ``x`` and ``y``.
    """
    cfg = _as_payload(payload)
    spans = cfg.get("coord_ranges", {}) or {}
    order = cfg.get("coord_order", []) or []

    meta_items = [
        ("coord_mode", _string_or_none(cfg.get("coord_mode"))),
        # The order list is flattened to "t, x, y"-style text.
        ("coord_order", ", ".join(str(c) for c in order)),
        ("coord_src_epsg", cfg.get("coord_src_epsg")),
        ("coord_target_epsg", cfg.get("coord_target_epsg")),
        ("coord_epsg_used", cfg.get("coord_epsg_used")),
        ("time_units", _string_or_none(cfg.get("time_units"))),
    ]

    records = [
        {"section": "coord_meta", "name": label, "value": val}
        for label, val in meta_items
    ]
    records.extend(
        {
            "section": "coord_ranges",
            "name": axis,
            "value": _try_float(spans.get(axis)),
        }
        for axis in ("t", "x", "y")
    )
    return pd.DataFrame(records)
def scaling_kwargs_bounds_frame(
    payload: ScalingKwargsLike,
) -> pd.DataFrame:
    """
    Tabulate the ``bounds`` sub-mapping in tidy form.

    Returns
    -------
    pandas.DataFrame
        Columns ``bound`` and ``value``. Every key of
        ``_BOUND_ORDER`` gets a row in that order (value
        ``None`` when absent); remaining keys found in the
        payload follow in sorted order.
    """
    cfg = _as_payload(payload)
    limits = cfg.get("bounds", {}) or {}

    canonical = list(_BOUND_ORDER)
    known = set(canonical)
    extras = [k for k in sorted(limits) if k not in known]

    records: list[dict[str, Any]] = [
        {"bound": name, "value": _try_float(limits.get(name))}
        for name in canonical
    ]
    records += [
        {
            "bound": str(name),
            "value": _try_float(limits.get(name)),
        }
        for name in extras
    ]
    return pd.DataFrame(records)
def scaling_kwargs_feature_channels_frame(
    payload: ScalingKwargsLike,
) -> pd.DataFrame:
    """
    Tabulate feature-group sizes and channel identities.

    Returns
    -------
    pandas.DataFrame
        Columns ``group``, ``name`` and ``value`` with three
        consecutive groups: ``feature_counts`` (list lengths),
        ``channel_indices`` (as floats) and ``channel_names``
        (as stripped strings).
    """
    cfg = _as_payload(payload)
    records: list[dict[str, Any]] = []

    for name in (
        "dynamic_feature_names",
        "future_feature_names",
        "static_feature_names",
    ):
        records.append(
            {
                "group": "feature_counts",
                "name": name,
                "value": _feature_count(cfg, name),
            }
        )

    for name in (
        "gwl_dyn_index",
        "subs_dyn_index",
        "z_surf_static_index",
    ):
        records.append(
            {
                "group": "channel_indices",
                "name": name,
                "value": _try_float(cfg.get(name)),
            }
        )

    for name in (
        "gwl_dyn_name",
        "gwl_col",
        "gwl_target_col",
        "subs_model_col",
        "z_surf_col",
    ):
        records.append(
            {
                "group": "channel_names",
                "name": name,
                "value": _string_or_none(cfg.get(name)),
            }
        )

    return pd.DataFrame(records)
def scaling_kwargs_schedule_frame(
    payload: ScalingKwargsLike,
) -> pd.DataFrame:
    """
    Tabulate Q/MV schedule scalars and mode strings.

    Returns
    -------
    pandas.DataFrame
        Columns ``section``, ``name`` and ``value``. The
        ``scalar_schedule`` rows (one per ``_SCHEDULE_KEYS``
        entry, as floats) are followed by the ``modes`` rows
        (as stripped strings).
    """
    cfg = _as_payload(payload)
    mode_names = (
        "Q_kind",
        "mv_prior_mode",
        "mv_prior_units",
        "mv_schedule_unit",
        "cons_residual_units",
        "gw_residual_units",
        "drainage_mode",
        "scaling_error_policy",
        "gwl_kind",
        "gwl_sign",
    )

    scalar_rows = [
        {
            "section": "scalar_schedule",
            "name": name,
            "value": _try_float(cfg.get(name)),
        }
        for name in _SCHEDULE_KEYS
    ]
    mode_rows = [
        {
            "section": "modes",
            "name": name,
            "value": _string_or_none(cfg.get(name)),
        }
        for name in mode_names
    ]
    return pd.DataFrame(scalar_rows + mode_rows)
def summarize_scaling_kwargs(
    payload: ScalingKwargsLike,
) -> dict[str, Any]:
    """
    Condense a scaling-kwargs payload into a flat summary.

    Returns
    -------
    dict
        Compact summary intended for logs or gallery prose:
        coordinate and groundwater conventions, Q / MV modes,
        feature-group sizes, bounds presence and the numeric
        affine-map scalars.
    """
    cfg = _as_payload(payload)
    limits = cfg.get("bounds", {}) or {}

    def text(key: str) -> str | None:
        # Stripped string value of ``key`` or ``None``.
        return _string_or_none(cfg.get(key))

    def flag(key: str) -> bool:
        # Truthiness of ``key``; missing keys count as False.
        return bool(cfg.get(key, False))

    summary: dict[str, Any] = {}
    summary["time_units"] = text("time_units")
    summary["coord_mode"] = text("coord_mode")
    summary["coord_order"] = list(
        cfg.get("coord_order", []) or []
    )
    summary["coords_normalized"] = flag("coords_normalized")
    summary["coords_in_degrees"] = flag("coords_in_degrees")
    summary["coord_ranges"] = dict(
        cfg.get("coord_ranges", {}) or {}
    )
    summary["gwl_kind"] = text("gwl_kind")
    summary["gwl_sign"] = text("gwl_sign")
    summary["use_head_proxy"] = flag("use_head_proxy")
    summary["Q_kind"] = text("Q_kind")
    summary["mv_prior_mode"] = text("mv_prior_mode")
    summary["mv_schedule_unit"] = text("mv_schedule_unit")
    summary["dynamic_features"] = _feature_count(
        cfg, "dynamic_feature_names"
    )
    summary["future_features"] = _feature_count(
        cfg, "future_feature_names"
    )
    summary["static_features"] = _feature_count(
        cfg, "static_feature_names"
    )
    summary["n_bounds"] = len(limits)
    summary["has_bounds"] = bool(limits)
    summary["has_gwl_z_meta"] = isinstance(
        cfg.get("gwl_z_meta"), dict
    )
    summary["affine_maps"] = _numeric_subset(
        cfg, keys=_AFFINE_KEYS
    )
    return summary
def plot_scaling_kwargs_affine_maps(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Scaling affine maps",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the subs/head/H affine-map scalars as bars.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the ``_AFFINE_KEYS`` values.
    ax : matplotlib Axes, optional
        Target axes; a new 8.0 x 4.2 figure is created when
        omitted.
    title, error, **plot_kws
        Passed through to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_metric_bars``.
    """
    values = _numeric_subset(
        _as_payload(payload), keys=_AFFINE_KEYS
    )
    if ax is None:
        ax = plt.subplots(figsize=(8.0, 4.2))[1]

    bar_options: dict[str, Any] = {
        "title": title,
        "sort_by_value": False,
        "annotate": True,
        "error": error,
    }
    return plot_metric_bars(ax, values, **bar_options, **plot_kws)
def plot_scaling_kwargs_coord_ranges(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Coordinate ranges",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the numeric entries of ``coord_ranges`` as bars.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the ``coord_ranges`` sub-mapping.
    ax : matplotlib Axes, optional
        Target axes; a new 7.2 x 3.8 figure is created when
        omitted.
    title, error, **plot_kws
        Passed through to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_metric_bars``.
    """
    cfg = _as_payload(payload)
    spans = cfg.get("coord_ranges", {}) or {}
    values = _numeric_subset(spans)
    if ax is None:
        ax = plt.subplots(figsize=(7.2, 3.8))[1]

    bar_options: dict[str, Any] = {
        "title": title,
        "sort_by_value": False,
        "annotate": True,
        "error": error,
    }
    return plot_metric_bars(ax, values, **bar_options, **plot_kws)
def plot_scaling_kwargs_bounds(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Bounds overview",
    top_n: int | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the numeric bounds as a compact bar chart.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the ``bounds`` sub-mapping.
    ax : matplotlib Axes, optional
        Target axes; a new 8.4 x 5.0 figure is created when
        omitted.
    title, top_n, error, **plot_kws
        Passed through to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_metric_bars``.
    """
    table = scaling_kwargs_bounds_frame(payload)
    # Keep only bounds that have a numeric value, then expose
    # the label column under the name ``metric``.
    table = table.dropna(subset=["value"])
    table = table.rename(columns={"bound": "metric"})

    if ax is None:
        ax = plt.subplots(figsize=(8.4, 5.0))[1]

    bar_options: dict[str, Any] = {
        "title": title,
        "top_n": top_n,
        "sort_by_value": False,
        "annotate": False,
        "error": error,
    }
    return plot_metric_bars(ax, table, **bar_options, **plot_kws)
def plot_scaling_kwargs_schedule_scalars(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Schedule and runtime scalars",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the numeric schedule/runtime scalars as bars.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the ``_SCHEDULE_KEYS`` values.
    ax : matplotlib Axes, optional
        Target axes; a new 8.4 x 4.8 figure is created when
        omitted.
    title, error, **plot_kws
        Passed through to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_metric_bars``.
    """
    values = _numeric_subset(
        _as_payload(payload), keys=_SCHEDULE_KEYS
    )
    if ax is None:
        ax = plt.subplots(figsize=(8.4, 4.8))[1]

    bar_options: dict[str, Any] = {
        "title": title,
        "sort_by_value": False,
        "annotate": False,
        "error": error,
    }
    return plot_metric_bars(ax, values, **bar_options, **plot_kws)
def plot_scaling_kwargs_feature_group_sizes(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Feature group sizes",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the dynamic/future/static feature-group sizes as bars.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the three ``*_feature_names`` lists.
    ax : matplotlib Axes, optional
        Target axes; a new 7.2 x 3.6 figure is created when
        omitted.
    title, error, **plot_kws
        Passed through to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_metric_bars``.
    """
    cfg = _as_payload(payload)
    group_keys = (
        "dynamic_feature_names",
        "future_feature_names",
        "static_feature_names",
    )
    sizes = {
        name: _feature_count(cfg, name) for name in group_keys
    }
    if ax is None:
        ax = plt.subplots(figsize=(7.2, 3.6))[1]

    bar_options: dict[str, Any] = {
        "title": title,
        "sort_by_value": False,
        "annotate": True,
        "error": error,
    }
    return plot_metric_bars(ax, sizes, **bar_options, **plot_kws)
def plot_scaling_kwargs_boolean_summary(
    payload: ScalingKwargsLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Scaling boolean checks",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw the common boolean configuration flags.

    Parameters
    ----------
    payload : ScalingKwargsLike
        Source of the ``_BOOLEAN_KEYS`` flags. Keys missing
        from the payload are left out of the plot.
    ax : matplotlib Axes, optional
        Target axes; a new 8.2 x 4.2 figure is created when
        omitted.
    title, error, **plot_kws
        Passed through to ``plot_boolean_checks``.

    Returns
    -------
    matplotlib.axes.Axes
        The value returned by ``plot_boolean_checks``.
    """
    cfg = _as_payload(payload)

    flags: dict[str, bool] = {}
    for name in _BOOLEAN_KEYS:
        if name in cfg:
            flags[name] = bool(cfg[name])

    if ax is None:
        ax = plt.subplots(figsize=(8.2, 4.2))[1]

    return plot_boolean_checks(
        ax, flags, title=title, error=error, **plot_kws
    )
def inspect_scaling_kwargs(
    payload: ScalingKwargsLike,
) -> dict[str, Any]:
    """
    Bundle the summary and every tidy frame for one payload.

    The payload is resolved once and the resulting plain dict
    is handed to each view builder.

    Returns
    -------
    dict
        Keys ``summary`` (dict) plus ``affine``, ``coords``,
        ``bounds``, ``features`` and ``schedule`` (data frames).
    """
    cfg = _as_payload(payload)
    views = (
        ("summary", summarize_scaling_kwargs),
        ("affine", scaling_kwargs_affine_frame),
        ("coords", scaling_kwargs_coord_frame),
        ("bounds", scaling_kwargs_bounds_frame),
        ("features", scaling_kwargs_feature_channels_frame),
        ("schedule", scaling_kwargs_schedule_frame),
    )
    return {label: build(cfg) for label, build in views}