# Source code for geoprior.utils.inspect.training_summary

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Training-summary generation and inspection helpers.

This module focuses on the Stage-2 training summary
artifact. It provides:

- robust loading,
- reproducible demo-summary generation,
- compact tabular summaries,
- quick visual inspection helpers.

The functions are designed for two common uses:

1. Sphinx-Gallery examples that need a realistic
   training summary without rerunning training.
2. Real workflow inspection when a user wants to
   review best/final metrics, compile settings,
   and initialization choices at a glance.
"""

from __future__ import annotations

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd

from .utils import (
    ArtifactRecord,
    as_path,
    clone_artifact,
    deep_update,
    empty_plot,
    filter_plot_kwargs,
    finalize_plot,
    flatten_dict,
    load_artifact,
    nested_get,
    plot_boolean_checks,
    plot_metric_bars,
    prepare_plot,
    read_json,
    write_json,
)

# Filesystem location given either as a string or as a ``Path``.
PathLike = str | Path
# Every input form the summary helpers accept: an already-loaded
# ``ArtifactRecord``, an in-memory mapping, or a location that is
# handed to ``read_json``.
TrainingSummaryLike = (
    ArtifactRecord | Mapping[str, Any] | str | Path
)

# Public API of this module; underscore-prefixed helpers are
# intentionally left out.
__all__ = [
    "default_training_summary_payload",
    "generate_training_summary",
    "inspect_training_summary",
    "load_training_summary",
    "plot_training_best_metrics",
    "plot_training_boolean_summary",
    "plot_training_final_metrics",
    "plot_training_loss_family",
    "plot_training_metric_deltas",
    "training_compile_frame",
    "training_env_frame",
    "training_hp_frame",
    "training_metrics_frame",
    "training_paths_frame",
    "summarize_training_summary",
]

# Default metric names (without the ``val_`` prefix) shown by the
# best/final/delta plots when the caller does not pass ``keys``.
_CORE_METRICS = [
    "loss",
    "data_loss",
    "physics_loss",
    "physics_loss_scaled",
    "subs_pred_mae_q50",
    "gwl_pred_mae_q50",
    "subs_pred_coverage80",
    "subs_pred_sharpness80",
    "epsilon_prior",
    "epsilon_cons",
    "epsilon_gw",
    "lambda_offset",
    "physics_mult",
]

# Loss-term names (without the ``val_`` prefix) selected by
# ``plot_training_loss_family``.
_LOSS_FAMILY = [
    "loss",
    "data_loss",
    "physics_loss",
    "physics_loss_scaled",
    "consolidation_loss",
    "gw_flow_loss",
    "prior_loss",
    "smooth_loss",
    "mv_prior_loss",
    "bounds_loss",
    "q_reg_loss",
]


def _as_payload(
    summary: TrainingSummaryLike,
) -> dict[str, Any]:
    """
    Normalise any supported summary input to a plain dict.

    An ``ArtifactRecord`` contributes its ``payload``, a mapping
    is used directly, and any other value is handed to
    ``read_json``. The result is always a new top-level ``dict``
    (a shallow copy; nested containers are shared with the input).
    """
    if isinstance(summary, ArtifactRecord):
        source = summary.payload
    elif isinstance(summary, Mapping):
        source = summary
    else:
        # Neither a record nor a mapping: let the JSON reader
        # resolve it.
        source = read_json(summary)
    return dict(source)


def _default_metric_block(
    *,
    val_scale: float = 0.95,
) -> dict[str, float]:
    """Build a realistic core metric block."""
    base = {
        "bounds_loss": 1.0e-12,
        "consolidation_loss": 1.7e-5,
        "data_loss": 0.0568,
        "epsilon_cons": 5.8e-3,
        "epsilon_cons_raw": 4.1e-9,
        "epsilon_gw": 5.5e-7,
        "epsilon_gw_raw": 4.4e-13,
        "epsilon_prior": 7.7e-4,
        "gw_flow_loss": 1.9e-13,
        "gwl_pred_mae_q50": 0.240,
        "gwl_pred_mse_q50": 0.0777,
        "lambda_offset": 1.0,
        "loss": 0.0568,
        "mv_prior_loss": 9.0e-7,
        "physics_loss": 1.71e-5,
        "physics_loss_scaled": 1.71e-5,
        "physics_mult": 1.0,
        "prior_loss": 1.44e-7,
        "q_gate": 1.0,
        "q_reg_loss": 0.0,
        "q_rms": 4.6e-13,
        "smooth_loss": 0.0,
        "subs_pred_coverage80": 0.8000,
        "subs_pred_mae_q50": 0.0101,
        "subs_pred_mse_q50": 2.72e-4,
        "subs_pred_sharpness80": 0.0307,
        "subs_resid_gate": 1.0,
        "total_loss": 0.0568,
    }

    val = {}
    for key, value in base.items():
        if key in {
            "lambda_offset",
            "physics_mult",
            "q_gate",
            "subs_resid_gate",
        }:
            val[f"val_{key}"] = float(value)
        elif key.endswith("coverage80"):
            val[f"val_{key}"] = float(
                min(1.0, max(0.0, value * 1.02))
            )
        else:
            val[f"val_{key}"] = float(value) * float(
                val_scale
            )

    out = dict(base)
    out.update(val)
    return out


def _split_metric_name(
    name: str,
) -> tuple[str, str]:
    """Return ``(split, metric_name)`` from a metric key."""
    text = str(name)
    if text.startswith("val_"):
        return "val", text[4:]
    return "train", text


def _filter_metrics(
    mapping: dict[str, Any] | None,
    *,
    split: str = "all",
    keys: list[str] | tuple[str, ...] | None = None,
) -> dict[str, float]:
    """Filter scalar metrics by split and optional key set."""
    split_mode = str(split).strip().lower()
    keep = None if keys is None else {str(k) for k in keys}

    out: dict[str, float] = {}
    for name, value in (mapping or {}).items():
        if isinstance(value, bool) or not isinstance(
            value,
            (int, float),
        ):
            continue

        metric_split, metric_name = _split_metric_name(name)
        if split_mode != "all" and metric_split != split_mode:
            continue
        if keep is not None and metric_name not in keep:
            continue
        out[str(metric_name)] = float(value)
    return out


def _delta_metrics(
    best: dict[str, Any] | None,
    final: dict[str, Any] | None,
    *,
    split: str = "train",
    keys: list[str] | tuple[str, ...] | None = None,
) -> dict[str, float]:
    """
    Compute ``final - best`` for metrics present in both blocks.

    Both inputs are filtered with ``_filter_metrics`` using the
    same ``split`` and ``keys``. Only names surviving in both
    filtered blocks are reported, in sorted order.
    """
    selection = {"split": split, "keys": keys}
    at_best = _filter_metrics(best, **selection)
    at_final = _filter_metrics(final, **selection)

    deltas: dict[str, float] = {}
    for name in sorted(at_best.keys() & at_final.keys()):
        deltas[name] = float(at_final[name] - at_best[name])
    return deltas


def default_training_summary_payload(
    *,
    city: str = "demo_city",
    model: str = "GeoPriorSubsNet",
    horizon: int = 3,
    best_epoch: int = 17,
    timestamp: str = "20260222-211635",
    optimizer: str = "Adam",
    learning_rate: float = 1e-3,
    time_steps: int = 5,
    pde_mode: str = "on",
    offset_mode: str = "mul",
    quantiles: list[float] | None = None,
    attention_levels: list[str] | None = None,
    coords_normalized: bool = True,
    coord_ranges: dict[str, float] | None = None,
    run_dir: str = "results/demo_run/train_20260222-211635",
) -> dict[str, Any]:
    """
    Build a realistic default training-summary payload.

    The payload is template-based. It is not meant to
    reproduce the full training loop. Instead, it creates
    a stable and inspectable summary artifact with the
    same broad structure as the real training summary.

    Parameters
    ----------
    city, model : str
        Stored under ``"city"`` / ``"model"`` and used to build
        the file names in the ``"paths"`` section.
    horizon : int
        Stored as ``int`` under ``"horizon"`` and embedded in
        the ``_H<horizon>`` part of several file names.
    best_epoch : int
        Stored as ``int`` under ``"best_epoch"``.
    timestamp : str
        Stored as ``str`` under ``"timestamp"``.
    optimizer : str
        Stored under ``compile.optimizer``.
    learning_rate : float
        Stored under ``compile.learning_rate``.
    time_steps : int
        Stored under ``hp_init.time_steps``.
    pde_mode : str
        Stored under ``hp_init.pde_mode``.
    offset_mode : str
        Stored under ``hp_init.offset_mode`` and
        ``hp_init.model_init_params.offset_mode``.
    quantiles : list of float, optional
        Defaults to ``[0.1, 0.5, 0.9]`` when omitted or empty.
    attention_levels : list of str, optional
        Defaults to ``["cross", "hierarchical", "memory"]`` when
        omitted or empty.
    coords_normalized : bool
        Stored in both ``scaling_kwargs`` sections.
    coord_ranges : dict, optional
        Defaults to ``{"t": 7.0, "x": 44447.0, "y": 39275.0}``
        when omitted or empty.
    run_dir : str
        Directory prefix of every entry in ``"paths"``.

    Returns
    -------
    dict
        Payload with the sections ``metrics_at_best``,
        ``final_epoch_metrics``, ``env``, ``compile``,
        ``hp_init`` and ``paths`` plus the top-level run fields.

    Notes
    -----
    ``attention_levels`` and ``coord_ranges`` each appear twice
    inside ``hp_init``. Every occurrence gets its own copy, so a
    later in-place edit of one nested section cannot silently
    change the other.
    """
    q = list(quantiles or [0.1, 0.5, 0.9])
    attn = list(
        attention_levels or ["cross", "hierarchical", "memory"]
    )
    coord_ranges = dict(
        coord_ranges or {"t": 7.0, "x": 44447.0, "y": 39275.0}
    )

    metrics_best = _default_metric_block(val_scale=0.87)
    metrics_final = _default_metric_block(val_scale=0.90)
    # Hand-tuned final-epoch values so that "final" visibly
    # differs from "best" in the demo artifact.
    metrics_final.update(
        {
            "loss": 0.05445,
            "data_loss": 0.05441,
            "physics_loss": 3.51e-5,
            "physics_loss_scaled": 3.51e-5,
            "epsilon_prior": 5.89e-4,
            "val_loss": 0.04951,
            "val_data_loss": 0.04948,
            "val_physics_loss": 6.9e-5,
            "val_physics_loss_scaled": 6.9e-5,
        }
    )

    payload = {
        "timestamp": str(timestamp),
        "city": city,
        "model": model,
        "horizon": int(horizon),
        "best_epoch": int(best_epoch),
        "metrics_at_best": metrics_best,
        "final_epoch_metrics": metrics_final,
        "env": {
            "python": "3.10.19",
            "tensorflow": "2.20.0",
            "numpy": "2.0.2",
            "platform": "Windows-10-demo",
            "device": {
                "has_tf": True,
                "device_mode_requested": "auto",
                "device_mode_effective": "cpu",
                "num_cpus": 1,
                "num_gpus": 0,
                "visible_gpus": [],
                "intra_threads": None,
                "inter_threads": None,
                "gpu_memory_growth": None,
                "gpu_memory_limit_mb": None,
            },
        },
        "compile": {
            "optimizer": str(optimizer),
            "learning_rate": float(learning_rate),
            "loss_weights": {
                "subs_pred": 1.0,
                "gwl_pred": 0.8,
            },
            "metrics": {
                "subs_pred": [
                    "MAEQ50",
                    "MSEQ50",
                    "Coverage80",
                    "Sharpness80",
                ],
                "gwl_pred": ["MAEQ50", "MSEQ50"],
            },
            "physics_loss_weights": {
                "lambda_cons": 1.0,
                "lambda_gw": 0.1,
                "lambda_prior": 0.2,
                "lambda_smooth": 0.01,
                "lambda_bounds": 0.05,
                "lambda_mv": 0.01,
                "mv_lr_mult": 1.0,
                "lambda_offset": 1.0,
                "kappa_lr_mult": 5.0,
                "lambda_q": 5.0e-4,
            },
            "lambda_offset": 1.0,
        },
        "hp_init": {
            "quantiles": q,
            "subs_weights": {
                "0.1": 3.0,
                "0.5": 1.0,
                "0.9": 3.0,
            },
            "gwl_weights": {
                "0.1": 1.5,
                "0.5": 1.0,
                "0.9": 1.5,
            },
            # Independent copy: the same levels are repeated in
            # ``model_init_params`` below.
            "attention_levels": list(attn),
            "pde_mode": str(pde_mode),
            "time_steps": int(time_steps),
            "use_batch_norm": False,
            "use_vsn": True,
            "vsn_units": 32,
            "mode": "tft_like",
            "model_init_params": {
                "embed_dim": 32,
                "hidden_units": 64,
                "lstm_units": 64,
                "attention_units": 64,
                "num_heads": 2,
                "dropout_rate": 0.1,
                "memory_size": 50,
                "scales": [1, 2],
                "use_residuals": True,
                "use_batch_norm": False,
                "use_vsn": True,
                "vsn_units": 32,
                "mode": "tft_like",
                "attention_levels": list(attn),
                "scale_pde_residuals": True,
                "scaling_kwargs": {
                    "time_units": "year",
                    "coords_normalized": bool(
                        coords_normalized
                    ),
                    # Independent copy: also stored in
                    # ``hp_init.scaling_kwargs``.
                    "coord_ranges": dict(coord_ranges),
                    "coord_order": ["t", "x", "y"],
                    "gwl_kind": "depth_bgs",
                    "gwl_sign": "down_positive",
                    "use_head_proxy": True,
                    "Q_kind": "per_volume",
                },
                "bounds_mode": "soft",
                "mv": {
                    "type": "LearnableMV",
                    "initial_value": 1e-7,
                },
                "kappa": {
                    "type": "LearnableKappa",
                    "initial_value": 1.0,
                },
                "gamma_w": {
                    "type": "FixedGammaW",
                    "value": 9810.0,
                },
                "h_ref": {
                    "type": "FixedHRef",
                    "value": 0.0,
                },
                "kappa_mode": "kb",
                "use_effective_h": True,
                "hd_factor": 0.6,
                "offset_mode": str(offset_mode),
                "residual_method": "exact",
                "time_units": "year",
            },
            "offset_mode": str(offset_mode),
            "scaling_kwargs": {
                "bounds": {
                    "H_min": 0.1,
                    "H_max": 30.0,
                    "K_min": 1e-12,
                    "K_max": 1e-7,
                },
                "time_units": "year",
                "coords_normalized": bool(coords_normalized),
                "coord_ranges": dict(coord_ranges),
            },
            "identifiability_regime": None,
        },
        "paths": {
            "run_dir": str(run_dir),
            "weights_h5": (
                f"{run_dir}/{city}_{model}_H{horizon}.weights.h5"
            ),
            "arch_json": (
                f"{run_dir}/{city}_{model}_architecture.json"
            ),
            "csv_log": (
                f"{run_dir}/{city}_{model}_train_log.csv"
            ),
            "best_keras": (
                f"{run_dir}/{city}_{model}_H{horizon}_best.keras"
            ),
            "best_weights": (
                f"{run_dir}/{city}_{model}_H{horizon}_best.weights.h5"
            ),
            "model_init_manifest": (
                f"{run_dir}/model_init_manifest.json"
            ),
            "final_keras": (
                f"{run_dir}/{city}_{model}_H{horizon}_final.keras"
            ),
        },
    }
    return payload


def generate_training_summary(
    *,
    output_path: PathLike | None = None,
    template: TrainingSummaryLike | None = None,
    overrides: dict[str, Any] | None = None,
    **kwargs,
) -> dict[str, Any] | Path:
    """
    Generate a training-summary payload or file.

    Parameters
    ----------
    output_path : path-like, optional
        Destination JSON path. If omitted, the payload is
        returned instead of written.
    template : mapping, ArtifactRecord, or path, optional
        Real or synthetic training-summary template used as
        the generation base.
    overrides : dict, optional
        Nested overrides applied after template/default
        payload creation.
    **kwargs : dict
        Parameters forwarded to
        ``default_training_summary_payload`` when no template
        is given. They are not used when ``template`` is set.

    Returns
    -------
    dict or Path
        The payload when ``output_path`` is ``None``; otherwise
        whatever ``write_json`` returns for the written file.
    """
    if template is not None:
        # Work on a clone so the caller's template is untouched.
        payload = clone_artifact(
            _as_payload(template),
            overrides=None,
        )
    else:
        payload = default_training_summary_payload(**kwargs)

    if overrides:
        payload = deep_update(payload, overrides)

    return (
        payload
        if output_path is None
        else write_json(payload, output_path)
    )


def load_training_summary(
    path: PathLike,
) -> ArtifactRecord:
    """
    Load a training-summary artifact.

    The file is read through ``load_artifact`` with
    ``kind="training_summary"`` and accepted only when its
    payload has the ``metrics_at_best``, ``final_epoch_metrics``,
    ``compile`` and ``hp_init`` sections.

    Raises
    ------
    ValueError
        If the artifact does not look like a training
        summary payload.
    """
    record = load_artifact(path, kind="training_summary")

    required_sections = {
        "metrics_at_best",
        "final_epoch_metrics",
        "compile",
        "hp_init",
    }
    missing = required_sections - set(record.payload)
    if missing:
        raise ValueError(
            "The file does not contain the expected training-summary sections."
        )
    return record


def training_metrics_frame(
    summary: TrainingSummaryLike,
    *,
    section: str = "metrics_at_best",
    split: str = "all",
) -> pd.DataFrame:
    """
    Return a tidy frame for train/validation metrics.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to read.
    section : str, default="metrics_at_best"
        Payload section holding the metric mapping.
    split : {"all", "train", "val"}, default="all"
        Split to keep; compared case-insensitively after
        stripping whitespace.

    Returns
    -------
    pandas.DataFrame
        Columns ``section``, ``split``, ``metric`` and ``value``,
        sorted by ``split`` then ``metric``. Non-numeric and
        boolean entries are skipped. The columns are present even
        when no metric matches, so column access on an empty
        result does not fail.
    """
    payload = _as_payload(summary)
    source = payload.get(section, {}) or {}
    # Normalise once instead of on every loop iteration.
    split_mode = str(split).strip().lower()

    rows: list[dict[str, Any]] = []
    for name, value in source.items():
        # ``bool`` is an ``int`` subclass, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(
            value,
            (int, float),
        ):
            continue

        metric_split, metric_name = _split_metric_name(name)
        if split_mode != "all" and metric_split != split_mode:
            continue

        rows.append(
            {
                "section": section,
                "split": metric_split,
                "metric": metric_name,
                "value": float(value),
            }
        )

    # Explicit columns keep the schema stable for empty results.
    frame = pd.DataFrame(
        rows,
        columns=["section", "split", "metric", "value"],
    )
    frame = frame.sort_values(["split", "metric"])
    return frame.reset_index(drop=True)


def training_env_frame(
    summary: TrainingSummaryLike,
) -> pd.DataFrame:
    """
    Return a tidy frame for environment info.

    The ``env`` section is flattened with ``flatten_dict`` and
    each resulting entry becomes one row.

    Returns
    -------
    pandas.DataFrame
        Columns ``key``, ``value`` and ``is_numeric`` (``True``
        for ``int`` / ``float`` values that are not ``bool``).
        The columns are present even when the section is missing
        or empty.
    """
    payload = _as_payload(summary)
    env = payload.get("env", {}) or {}

    rows = [
        {
            "key": str(key),
            "value": value,
            "is_numeric": isinstance(value, (int, float))
            and not isinstance(value, bool),
        }
        for key, value in flatten_dict(env).items()
    ]
    # Explicit columns keep the schema stable for empty sections.
    return pd.DataFrame(
        rows,
        columns=["key", "value", "is_numeric"],
    )


def training_compile_frame(
    summary: TrainingSummaryLike,
) -> pd.DataFrame:
    """
    Return a tidy frame for compile settings.

    The ``compile`` section is flattened with ``flatten_dict``
    and each resulting entry becomes one row.

    Returns
    -------
    pandas.DataFrame
        Columns ``key``, ``value`` and ``is_numeric`` (``True``
        for ``int`` / ``float`` values that are not ``bool``).
        The columns are present even when the section is missing
        or empty.
    """
    payload = _as_payload(summary)
    compile_cfg = payload.get("compile", {}) or {}

    rows = [
        {
            "key": str(key),
            "value": value,
            "is_numeric": isinstance(value, (int, float))
            and not isinstance(value, bool),
        }
        for key, value in flatten_dict(compile_cfg).items()
    ]
    # Explicit columns keep the schema stable for empty sections.
    return pd.DataFrame(
        rows,
        columns=["key", "value", "is_numeric"],
    )


def training_hp_frame(
    summary: TrainingSummaryLike,
) -> pd.DataFrame:
    """
    Return a tidy frame for hp/init settings.

    The ``hp_init`` section is flattened with ``flatten_dict``
    and each resulting entry becomes one row.

    Returns
    -------
    pandas.DataFrame
        Columns ``key``, ``value`` and ``is_numeric`` (``True``
        for ``int`` / ``float`` values that are not ``bool``).
        The columns are present even when the section is missing
        or empty.
    """
    payload = _as_payload(summary)
    hp = payload.get("hp_init", {}) or {}

    rows = [
        {
            "key": str(key),
            "value": value,
            "is_numeric": isinstance(value, (int, float))
            and not isinstance(value, bool),
        }
        for key, value in flatten_dict(hp).items()
    ]
    # Explicit columns keep the schema stable for empty sections.
    return pd.DataFrame(
        rows,
        columns=["key", "value", "is_numeric"],
    )


def training_paths_frame(
    summary: TrainingSummaryLike,
) -> pd.DataFrame:
    """
    Return a two-column frame listing the output paths.

    Each entry of the ``paths`` section becomes one row with its
    name in ``key`` and its stored value in ``value``. A missing
    or empty section gives an empty frame with both columns.
    """
    paths = _as_payload(summary).get("paths", {}) or {}

    names = []
    targets = []
    for name, target in paths.items():
        names.append(name)
        targets.append(target)

    return pd.DataFrame({"key": names, "value": targets})


def summarize_training_summary(
    summary: TrainingSummaryLike,
) -> dict[str, Any]:
    """
    Build a compact semantic summary for inspection.

    Returns
    -------
    dict
        Four sections, in this order:

        ``brief``
            Identification fields copied from the payload.
        ``core_metrics``
            Selected best/final train and validation metrics plus
            the final-minus-best validation loss (``None`` when
            either loss is missing).
        ``compile``
            Optimizer, learning rate, ``lambda_offset`` and the
            names of the configured loss weights.
        ``checks``
            Boolean sanity checks.

    Notes
    -----
    ``best_epoch_is_positive`` passes for any non-negative epoch;
    a missing or falsy ``best_epoch`` is read as ``0`` and passes
    too. ``lambda_offset_stable`` compares the train-split value
    at best and final epoch and is also ``True`` when both are
    absent.
    """
    payload = _as_payload(summary)

    def _section(name: str) -> dict[str, Any]:
        # Missing and explicit-null sections both become ``{}``.
        return payload.get(name, {}) or {}

    best = _section("metrics_at_best")
    final = _section("final_epoch_metrics")
    compile_cfg = _section("compile")
    hp = _section("hp_init")
    paths = _section("paths")

    best_train = _filter_metrics(best, split="train")
    best_val = _filter_metrics(best, split="val")
    final_train = _filter_metrics(final, split="train")
    final_val = _filter_metrics(final, split="val")

    brief = {
        "kind": "training_summary",
        "city": payload.get("city"),
        "model": payload.get("model"),
        "timestamp": payload.get("timestamp"),
        "horizon": payload.get("horizon"),
        "best_epoch": payload.get("best_epoch"),
    }

    best_val_loss = best_val.get("loss")
    final_val_loss = final_val.get("loss")
    if best_val_loss is None or final_val_loss is None:
        val_loss_delta = None
    else:
        val_loss_delta = float(final_val_loss - best_val_loss)

    core_metrics = {
        "best_train_loss": best_train.get("loss"),
        "best_val_loss": best_val_loss,
        "final_train_loss": final_train.get("loss"),
        "final_val_loss": final_val_loss,
        "best_train_subs_mae_q50": best_train.get(
            "subs_pred_mae_q50"
        ),
        "best_val_subs_mae_q50": best_val.get(
            "subs_pred_mae_q50"
        ),
        "best_train_gwl_mae_q50": best_train.get(
            "gwl_pred_mae_q50"
        ),
        "best_val_gwl_mae_q50": best_val.get(
            "gwl_pred_mae_q50"
        ),
        "best_val_coverage80": best_val.get(
            "subs_pred_coverage80"
        ),
        "best_val_sharpness80": best_val.get(
            "subs_pred_sharpness80"
        ),
        "delta_final_minus_best_val_loss": val_loss_delta,
    }

    compile_summary = {
        "optimizer": compile_cfg.get("optimizer"),
        "learning_rate": compile_cfg.get("learning_rate"),
        "lambda_offset": compile_cfg.get("lambda_offset"),
        "loss_weight_keys": list(
            (compile_cfg.get("loss_weights") or {}).keys()
        ),
    }

    physics_keys = (
        "physics_loss",
        "epsilon_prior",
        "epsilon_cons",
        "epsilon_gw",
    )
    saved_model_keys = {
        "best_keras",
        "best_weights",
        "final_keras",
    }
    best_epoch = int(payload.get("best_epoch", 0) or 0)
    # Scaling settings may sit directly under hp_init or inside
    # model_init_params; either location satisfies the check.
    scaling = nested_get(
        hp, "scaling_kwargs", default={}
    ) or nested_get(
        hp,
        "model_init_params",
        "scaling_kwargs",
        default={},
    )

    checks = {
        "has_best_metrics": bool(best),
        "has_final_metrics": bool(final),
        "has_validation_metrics": bool(best_val),
        "has_physics_metrics": all(
            key in best_train for key in physics_keys
        ),
        "best_epoch_is_positive": best_epoch >= 0,
        "lambda_offset_stable": (
            best_train.get("lambda_offset")
            == final_train.get("lambda_offset")
        ),
        "quantiles_defined": bool(hp.get("quantiles", [])),
        "has_scaling_kwargs": bool(scaling),
        "has_saved_model_paths": saved_model_keys.issubset(
            paths.keys()
        ),
        "has_optimizer": bool(compile_cfg.get("optimizer")),
    }

    return {
        "brief": brief,
        "core_metrics": core_metrics,
        "compile": compile_summary,
        "checks": checks,
    }


def plot_training_best_metrics(
    summary: TrainingSummaryLike,
    *,
    split: str = "val",
    keys: list[str] | tuple[str, ...] | None = None,
    ax: plt.Axes | None = None,
    title: str | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot selected metrics from ``metrics_at_best``.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to read.
    split : str, default="val"
        Split passed to the metric filter.
    keys : sequence of str, optional
        Metric names to show; the module's core metric list is
        used when omitted or empty.
    ax : matplotlib Axes, optional
        Passed to ``prepare_plot``.
    title : str, optional
        Defaults to ``"Best metrics (<split>)"``.
    error : str, default="ignore"
        Forwarded to ``plot_metric_bars``.
    **plot_kws
        Forwarded to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by ``prepare_plot``.
    """
    _, plot_ax, _ = prepare_plot(ax=ax, figsize=(8.2, 4.8))

    best_block = _as_payload(summary).get("metrics_at_best", {})
    selected = _filter_metrics(
        best_block,
        split=split,
        keys=keys or _CORE_METRICS,
    )
    heading = title or f"Best metrics ({split})"

    plot_metric_bars(
        plot_ax,
        selected,
        title=heading,
        sort_by_value=True,
        top_n=12,
        absolute=True,
        error=error,
        **plot_kws,
    )
    return plot_ax


def plot_training_final_metrics(
    summary: TrainingSummaryLike,
    *,
    split: str = "val",
    keys: list[str] | tuple[str, ...] | None = None,
    ax: plt.Axes | None = None,
    title: str | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot selected metrics from ``final_epoch_metrics``.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to read.
    split : str, default="val"
        Split passed to the metric filter.
    keys : sequence of str, optional
        Metric names to show; the module's core metric list is
        used when omitted or empty.
    ax : matplotlib Axes, optional
        Passed to ``prepare_plot``.
    title : str, optional
        Defaults to ``"Final metrics (<split>)"``.
    error : str, default="ignore"
        Forwarded to ``plot_metric_bars``.
    **plot_kws
        Forwarded to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by ``prepare_plot``.
    """
    _, plot_ax, _ = prepare_plot(ax=ax, figsize=(8.2, 4.8))

    final_block = _as_payload(summary).get(
        "final_epoch_metrics", {}
    )
    selected = _filter_metrics(
        final_block,
        split=split,
        keys=keys or _CORE_METRICS,
    )
    heading = title or f"Final metrics ({split})"

    plot_metric_bars(
        plot_ax,
        selected,
        title=heading,
        sort_by_value=True,
        top_n=12,
        absolute=True,
        error=error,
        **plot_kws,
    )
    return plot_ax


def plot_training_metric_deltas(
    summary: TrainingSummaryLike,
    *,
    split: str = "val",
    keys: list[str] | tuple[str, ...] | None = None,
    ax: plt.Axes | None = None,
    title: str | None = None,
    xlabel: str = "delta",
    show_grid: bool = True,
    grid_kws: dict[str, Any] | None = None,
    annotate: bool = False,
    annotate_kws: dict[str, Any] | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot ``final - best`` deltas for aligned metrics.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to read.
    split : str, default="val"
        Split used to select metrics from both blocks.
    keys : sequence of str, optional
        Metric names to show; the module's core metric list is
        used when omitted or empty.
    ax : matplotlib Axes, optional
        Passed to ``prepare_plot``.
    title : str, optional
        Defaults to ``"Final - best (<split>)"``, or to
        ``"Metric deltas (<split>)"`` on the empty placeholder.
    xlabel : str, default="delta"
        X-axis label.
    show_grid : bool, default=True
        Forwarded to ``finalize_plot``.
    grid_kws : dict, optional
        Grid options; defaults to ``{"axis": "x", "alpha": 0.25}``.
    annotate : bool, default=False
        Write each delta value next to its bar.
    annotate_kws : dict, optional
        Text options for the annotations. ``va`` defaults to
        ``"center"`` and can be overridden here.
    error : str, default="ignore"
        Forwarded to ``filter_plot_kwargs``.
    **plot_kws
        Bar options, filtered against ``Axes.barh``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on. When no metric is present in
        both blocks this is the axes returned by ``empty_plot``.
    """
    fig, plot_ax, _ = prepare_plot(ax=ax, figsize=(8.2, 4.8))

    payload = _as_payload(summary)
    deltas = _delta_metrics(
        payload.get("metrics_at_best", {}),
        payload.get("final_epoch_metrics", {}),
        split=split,
        keys=keys or _CORE_METRICS,
    )

    if not deltas:
        _, plot_ax = empty_plot(
            fig,
            plot_ax,
            title=title or f"Metric deltas ({split})",
            message="No aligned delta metrics",
        )
        return plot_ax

    names = list(deltas.keys())
    vals = list(deltas.values())

    bar_kws = filter_plot_kwargs(
        plot_ax.barh,
        plot_kws,
        error=error,
    )
    bars = plot_ax.barh(names, vals, **bar_kws)

    _, plot_ax = finalize_plot(
        fig,
        plot_ax,
        title=title or f"Final - best ({split})",
        xlabel=xlabel,
        show_grid=show_grid,
        grid_kws=grid_kws or {"axis": "x", "alpha": 0.25},
    )

    if annotate:
        # Hand the filter a real dict even when the caller gave
        # no annotation options.
        text_kws = filter_plot_kwargs(
            plot_ax.text,
            annotate_kws or {},
            error=error,
        )
        # Merge instead of passing ``va`` as a separate keyword:
        # a caller-supplied ``va`` would otherwise raise
        # "multiple values for keyword argument".
        text_opts = {"va": "center", **(text_kws or {})}
        for bar, value in zip(bars, vals, strict=False):
            plot_ax.text(
                bar.get_width(),
                bar.get_y() + bar.get_height() / 2.0,
                f" {float(value):.4g}",
                **text_opts,
            )

    return plot_ax


def plot_training_loss_family(
    summary: TrainingSummaryLike,
    *,
    section: str = "metrics_at_best",
    split: str = "val",
    ax: plt.Axes | None = None,
    title: str | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot the loss-family subset for one metric section.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to read.
    section : str, default="metrics_at_best"
        Payload section holding the metric mapping.
    split : str, default="val"
        Split passed to the metric filter.
    ax : matplotlib Axes, optional
        Passed to ``prepare_plot``.
    title : str, optional
        Defaults to ``"Loss family: <section> (<split>)"``.
    error : str, default="ignore"
        Forwarded to ``plot_metric_bars``.
    **plot_kws
        Forwarded to ``plot_metric_bars``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by ``prepare_plot``.
    """
    _, plot_ax, _ = prepare_plot(ax=ax, figsize=(8.2, 4.8))

    block = _as_payload(summary).get(section, {})
    losses = _filter_metrics(
        block,
        split=split,
        keys=_LOSS_FAMILY,
    )
    heading = title or f"Loss family: {section} ({split})"

    # ``top_n=None``: no top-N limit is requested for this plot.
    plot_metric_bars(
        plot_ax,
        losses,
        title=heading,
        sort_by_value=True,
        top_n=None,
        absolute=True,
        error=error,
        **plot_kws,
    )
    return plot_ax


def plot_training_boolean_summary(
    summary: TrainingSummaryLike,
    *,
    ax: plt.Axes | None = None,
    title: str = "Training summary checks",
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot semantic pass/fail checks.

    The checks are the ``"checks"`` section produced by
    ``summarize_training_summary`` and are drawn with
    ``plot_boolean_checks``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes returned by ``prepare_plot``.
    """
    _, plot_ax, _ = prepare_plot(ax=ax, figsize=(8.0, 4.6))

    report = summarize_training_summary(summary)
    plot_boolean_checks(
        plot_ax,
        report["checks"],
        title=title,
        error=error,
        **plot_kws,
    )
    return plot_ax


def inspect_training_summary(
    summary: TrainingSummaryLike,
    *,
    output_dir: PathLike | None = None,
    stem: str = "training_summary",
    save_figures: bool = True,
) -> dict[str, Any]:
    """
    Inspect a training summary and optionally save figures.

    Parameters
    ----------
    summary : mapping, ArtifactRecord, or path
        Training summary to inspect.
    output_dir : path-like, optional
        Directory for the figures; created if needed. No figure
        is written when this is omitted.
    stem : str, default="training_summary"
        Prefix of every figure file name.
    save_figures : bool, default=True
        Set to ``False`` to skip figure generation even when
        ``output_dir`` is given.

    Returns
    -------
    dict
        Bundle containing summary, tabular frames, and
        optionally written figure paths (``"figure_paths"``
        maps file name to path string and is empty when no
        figure was written).
    """
    payload = _as_payload(summary)
    summary_map = summarize_training_summary(payload)

    bundle: dict[str, Any] = {
        "summary": summary_map,
        "frames": {
            "metrics_at_best": training_metrics_frame(
                payload,
                section="metrics_at_best",
                split="all",
            ),
            "final_epoch_metrics": training_metrics_frame(
                payload,
                section="final_epoch_metrics",
                split="all",
            ),
            "env": training_env_frame(payload),
            "compile": training_compile_frame(payload),
            "hp_init": training_hp_frame(payload),
            "paths": training_paths_frame(payload),
        },
        "figure_paths": {},
    }

    if not (output_dir and save_figures):
        return bundle

    out_dir = as_path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # File name -> (plot function, extra keyword arguments).
    plots = {
        f"{stem}_best_val_metrics.png": (
            plot_training_best_metrics,
            {"split": "val"},
        ),
        f"{stem}_final_val_metrics.png": (
            plot_training_final_metrics,
            {"split": "val"},
        ),
        f"{stem}_delta_val_metrics.png": (
            plot_training_metric_deltas,
            {"split": "val"},
        ),
        f"{stem}_best_val_losses.png": (
            plot_training_loss_family,
            {
                "section": "metrics_at_best",
                "split": "val",
            },
        ),
        f"{stem}_checks.png": (
            plot_training_boolean_summary,
            {},
        ),
    }

    for name, (func, kwargs) in plots.items():
        fig, ax = plt.subplots(figsize=(8.0, 4.6))
        try:
            func(payload, ax=ax, **kwargs)
            fig.tight_layout()
            path = out_dir / name
            fig.savefig(path, dpi=160, bbox_inches="tight")
        finally:
            # Close the figure even if plotting or saving fails,
            # so a failure does not leave open figures behind.
            plt.close(fig)
        bundle["figure_paths"][name] = str(path)

    return bundle