Source code for geoprior.utils.inspect.manifest

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Stage-1 manifest generation and inspection helpers.

This module focuses on the broader ``manifest.json``
artifact exported by Stage-1 preprocessing.
Compared with the lighter Stage-2 ``run_manifest``
artifact, the Stage-1 manifest is a richer workflow
handshake that preserves:

- dataset and split identity,
- resolved columns and feature groups,
- conventions and scaling payload,
- holdout and sequence counts,
- artifact paths for CSV, encoders, NPZ files,
  and shape summaries,
- runtime/version metadata.

The helpers here support two common uses:

1. Sphinx-Gallery examples that need a realistic
   Stage-1 manifest without rerunning preprocessing.
2. Real workflow inspection when a user wants to
   verify that downstream stages can trust the
   preprocessing handshake.
"""

from __future__ import annotations

from collections.abc import Mapping
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd

from .utils import (
    ArtifactRecord,
    clone_artifact,
    load_artifact,
    nested_get,
    plot_boolean_checks,
    plot_metric_bars,
    read_json,
    write_json,
)

PathLike = str | Path
ManifestLike = ArtifactRecord | Mapping[str, Any] | str | Path

__all__ = [
    "default_manifest_payload",
    "generate_manifest",
    "inspect_manifest",
    "load_manifest",
    "manifest_artifacts_frame",
    "manifest_config_frame",
    "manifest_feature_groups_frame",
    "manifest_holdout_frame",
    "manifest_identity_frame",
    "manifest_paths_frame",
    "manifest_shapes_frame",
    "manifest_versions_frame",
    "plot_manifest_artifact_inventory",
    "plot_manifest_boolean_summary",
    "plot_manifest_coord_ranges",
    "plot_manifest_feature_group_sizes",
    "plot_manifest_holdout_counts",
    "summarize_manifest",
]

_CONFIG_KEYS = [
    "TIME_STEPS",
    "FORECAST_HORIZON_YEARS",
    "MODE",
    "TRAIN_END_YEAR",
    "FORECAST_START_YEAR",
]

_HOLDOUT_KEYS = [
    "valid_for_train",
    "valid_for_forecast",
    "kept_for_processing",
    "train_groups",
    "val_groups",
    "test_groups",
    "train_rows",
    "val_rows",
    "test_rows",
    "train_seq",
    "val_seq",
    "test_seq",
]


def _as_payload(manifest: ManifestLike) -> dict[str, Any]:
    """Return a plain Stage-1 manifest payload."""
    if isinstance(manifest, ArtifactRecord):
        return dict(manifest.payload)

    if isinstance(manifest, Mapping):
        return dict(manifest)

    return dict(read_json(manifest))


def _try_int(value: Any) -> int | None:
    """Return ``value`` as int when possible."""
    try:
        if value is None:
            return None
        return int(value)
    except Exception:
        return None


def _try_float(value: Any) -> float | None:
    """Return ``value`` as float when possible."""
    try:
        if value is None:
            return None
        return float(value)
    except Exception:
        return None


def _feature_count(payload: dict[str, Any], key: str) -> int:
    """Return the size of a Stage-1 feature group."""
    vals = nested_get(
        payload, "config", "features", key, default=[]
    )
    if isinstance(vals, list):
        return len(vals)
    return 0


def _coord_ranges(
    payload: dict[str, Any],
) -> dict[str, float]:
    """Return numeric coord ranges from nested scaling kwargs."""
    src = nested_get(
        payload,
        "config",
        "scaling_kwargs",
        "coord_ranges",
        default={},
    )
    if not isinstance(src, dict):
        return {}

    out: dict[str, float] = {}
    for key, value in src.items():
        num = _try_float(value)
        if num is None:
            continue
        out[str(key)] = num
    return out


def _walk_leaf_paths(
    obj: Any,
    *,
    prefix: str = "",
) -> list[dict[str, Any]]:
    """
    Return leaf rows for nested artifact mappings.

    String leaves are treated as path-like payload entries.
    List leaves are retained as values but not treated as paths.
    """
    rows: list[dict[str, Any]] = []

    if isinstance(obj, Mapping):
        for key, value in obj.items():
            name = f"{prefix}.{key}" if prefix else str(key)
            rows.extend(_walk_leaf_paths(value, prefix=name))
        return rows

    value = obj
    str(value) if value is not None else None
    basename = None
    is_path = False

    if isinstance(value, str) and value:
        basename = Path(value).name
        is_path = any(
            token in value.lower()
            for token in (
                ".csv",
                ".json",
                ".joblib",
                ".npz",
                ".keras",
                ".h5",
                ".pkl",
            )
        )

    rows.append(
        {
            "key": prefix,
            "value": value,
            "basename": basename,
            "is_path": is_path,
        }
    )
    return rows


def _artifact_path_count(payload: dict[str, Any]) -> int:
    """Return the number of leaf artifact path entries."""
    rows = _walk_leaf_paths(payload.get("artifacts", {}))
    return sum(1 for row in rows if row.get("is_path"))


def _shape_rows(
    payload: dict[str, Any],
) -> list[dict[str, Any]]:
    """Extract tidy rows from ``artifacts.shapes``."""
    src = nested_get(
        payload, "artifacts", "shapes", default={}
    )
    rows: list[dict[str, Any]] = []

    if not isinstance(src, dict):
        return rows

    for split_name, block in src.items():
        if not isinstance(block, dict):
            continue
        for tensor_name, shape in block.items():
            if not isinstance(shape, list):
                continue
            rows.append(
                {
                    "split": str(split_name),
                    "tensor": str(tensor_name),
                    "shape": shape,
                    "rank": len(shape),
                    "samples": _try_int(shape[0])
                    if shape
                    else None,
                    "time_dim": _try_int(shape[1])
                    if len(shape) > 1
                    else None,
                    "feature_dim": _try_int(shape[2])
                    if len(shape) > 2
                    else None,
                }
            )
    return rows


[docs] def default_manifest_payload( *, city: str = "nansha", model: str = "GeoPriorSubsNet", time_steps: int = 5, forecast_horizon_years: int = 3, mode: str = "tft_like", ) -> dict[str, Any]: """ Build a realistic default Stage-1 manifest payload. The payload mirrors the stable Stage-1 handshake: ``schema_version`` + identity fields + ``config`` + ``artifacts`` + ``paths`` + ``versions``. """ run_dir = f"results/{city}_{model}_stage1" artifacts_dir = f"{run_dir}/artifacts" payload = { "schema_version": "3.2", "timestamp": "2026-02-10 15:41:18", "city": str(city), "model": str(model), "stage": "stage1", "config": { "TIME_STEPS": int(time_steps), "FORECAST_HORIZON_YEARS": int( forecast_horizon_years ), "MODE": str(mode), "TRAIN_END_YEAR": 2022, "FORECAST_START_YEAR": 2023, "cols": { "time": "year", "lon": "longitude", "lat": "latitude", "time_numeric": "year_numeric_coord", "time_used": "year_numeric_coord", "x_base": "x_m", "y_base": "y_m", "x_used": "x_m", "y_used": "y_m", "subs_raw": "subsidence_cum", "subs_model": "subsidence_cum__si", "depth_raw": "GWL_depth_bgs_m", "head_raw": "head_m", "depth_model": "GWL_depth_bgs_m__si", "head_model": "head_m__si", "h_field_raw": "soil_thickness_eff", "h_field_model": "soil_thickness_eff__si", "z_surf_raw": "z_surf_m", "z_surf_static": "z_surf_m__si", }, "features": { "static": [ "lithology_Conglomerate–Sandstone", "lithology_Limestone–Sandstone", "lithology_Mudstone–Siltstone", "lithology_Sandstone–Siltstone", "lithology_Shale–Limestone", "lithology_Siltstone–Sandstone", "lithology_Siltstone–Shale", "lithology_Tuff–Sandstone", "lithology_class_Coarse-Grained Soil", "lithology_class_Fine-Grained Soil", "lithology_class_Mixed Clastics", "z_surf_m__si", ], "dynamic": [ "GWL_depth_bgs_m__si", "subsidence_cum__si", "rainfall_mm", "urban_load_global", "soil_thickness_censored", ], "future": ["rainfall_mm"], "group_id_cols": ["longitude", "latitude"], }, "indices": { "gwl_dyn_index": 0, "subs_dyn_index": 1, "gwl_dyn_name": "GWL_depth_bgs_m__si", "z_surf_static_index": 11, }, "conventions": { "gwl_kind": "depth_bgs", "gwl_sign": "down_positive", "use_head_proxy": False, "time_units": "year", "gwl_driver_kind": "depth", "gwl_target_kind": "head", "gwl_driver_sign": "down_positive", "gwl_target_sign": "up_positive", }, "scaler_info": { "rainfall_mm": { "scaler_path": f"{artifacts_dir}/{city}_main_scaler.joblib", "all_features": [ "rainfall_mm", "soil_thickness_censored", ], "idx": 0, }, "soil_thickness_censored": { "scaler_path": f"{artifacts_dir}/{city}_main_scaler.joblib", "all_features": [ "rainfall_mm", "soil_thickness_censored", ], "idx": 1, }, }, "scaling_kwargs": { "subs_scale_si": 1.0, "subs_bias_si": 0.0, "head_scale_si": 1.0, "head_bias_si": 0.0, "H_scale_si": 1.0, "H_bias_si": 0.0, "subsidence_kind": "cumulative", "allow_subs_residual": True, "coords_normalized": True, "coord_order": ["t", "x", "y"], "coord_ranges": { "t": 7.0, "x": 44447.0, "y": 39275.0, }, "coord_mode": "degrees", "coord_src_epsg": 4326, "coord_target_epsg": 32649, "coord_epsg_used": 32649, "coords_in_degrees": False, "cons_residual_units": "second", "gw_residual_units": "second", "Q_wrt_normalized_time": False, "Q_in_si": False, "Q_in_per_second": False, "Q_kind": "per_volume", "Q_length_in_si": False, "drainage_mode": "double", "clip_global_norm": 5.0, "gwl_dyn_index": 0, "subs_dyn_index": 1, "gwl_kind": "depth_bgs", "gwl_sign": "down_positive", "use_head_proxy": True, "bounds": { "H_min": 0.1, "H_max": 30.0, "K_min": 1e-12, "K_max": 1e-7, "Ss_min": 1e-6, "Ss_max": 1e-3, "tau_min": 604800.0, "tau_max": 9467085600.0, "logK_min": -27.6310211159, "logK_max": -16.1180956510, "logSs_min": -13.8155105580, "logSs_max": -6.9077552790, "logTau_min": 13.3126531038, "logTau_max": 22.9710869460, }, }, "feature_registry": { "optional_numeric_declared": [ "rainfall_mm", "soil_thickness_censored", ], "optional_categorical_declared": [ "lithology", "lithology_class", ], "already_normalized": [], "future_drivers_declared": ["rainfall_mm"], "resolved_optional_numeric": [ "rainfall_mm", "soil_thickness_censored", ], "resolved_optional_categorical": [ "lithology", "lithology_class", ], }, "censoring": { "specs": [], "report": {}, "use_effective_h_field": True, "flags_as_dynamic": True, "flags_as_future": False, }, "units_provenance": { "subs_unit_to_si_applied_stage1": 0.001, "thickness_unit_to_si_applied_stage1": 1.0, "head_unit_to_si_assumed_stage1": 1.0, "z_surf_unit_to_si_assumed_stage1": 1.0, }, "holdout": { "strategy": "random", "seed": 42, "val_frac": 0.2, "test_frac": 0.1, "group_cols": ["longitude", "latitude"], "coord_cols_for_blocking": { "x_col": "x_m", "y_col": "y_m", "epsg_used": 32649, }, "block_size_m": None, "requirements": { "time_col": "year", "train_end_year": 2022, "forecast_start_year": 2023, "time_steps": int(time_steps), "horizon": int(forecast_horizon_years), "years_required_for_train_windows": 8, "note": ( "Group validity enforced by " "compute_group_masks() before holdout " "splitting." ), }, "group_counts": { "valid_for_train": 194788, "valid_for_forecast": 194788, "kept_for_processing": 194788, "train_groups": 136351, "val_groups": 38958, "test_groups": 19479, }, "row_counts_hist": { "train_rows": 1090808, "val_rows": 311664, "test_rows": 155832, }, "sequence_counts": { "train_seq": 136351, "val_seq": 38958, "test_seq": 19479, }, "artifacts": { "train_groups_csv": f"{artifacts_dir}/holdout/train_groups.csv", "val_groups_csv": f"{artifacts_dir}/holdout/val_groups.csv", "test_groups_csv": f"{artifacts_dir}/holdout/test_groups.csv", }, }, }, "artifacts": { "csv": { "raw": f"{run_dir}/{city}_01_raw.csv", "clean": f"{run_dir}/{city}_02_clean.csv", "scaled": f"{run_dir}/{city}_03_scaled.csv", }, "encoders": { "ohe": { "lithology": f"{artifacts_dir}/{city}_ohe_lithology.joblib", "lithology_class": ( f"{artifacts_dir}/{city}_ohe_lithology_class.joblib" ), }, "coord_scaler": f"{artifacts_dir}/{city}_coord_scaler.joblib", "main_scaler": f"{artifacts_dir}/{city}_main_scaler.joblib", "scaled_ml_numeric_cols": [ "rainfall_mm", "soil_thickness_censored", ], }, "sequences": { "joblib_train_sequences": ( f"{artifacts_dir}/{city}_train_sequences_" f"T{time_steps}_H{forecast_horizon_years}.joblib" ), "dims": { "output_subsidence_dim": 1, "output_gwl_dim": 1, }, }, "numpy": { "train_inputs_npz": f"{artifacts_dir}/train_inputs.npz", "train_targets_npz": f"{artifacts_dir}/train_targets.npz", "val_inputs_npz": f"{artifacts_dir}/val_inputs.npz", "val_targets_npz": f"{artifacts_dir}/val_targets.npz", "test_inputs_npz": f"{artifacts_dir}/test_inputs.npz", "test_targets_npz": f"{artifacts_dir}/test_targets.npz", }, "shapes": { "train_inputs": { "H_field": [136351, 3, 1], "coords": [136351, 3, 3], "dynamic_features": [136351, 5, 5], "future_features": [136351, 8, 1], "static_features": [136351, 12], }, "train_targets": { "gwl_pred": [136351, 3, 1], "subs_pred": [136351, 3, 1], }, "val_inputs": { "H_field": [38958, 3, 1], "coords": [38958, 3, 3], "dynamic_features": [38958, 5, 5], "future_features": [38958, 8, 1], "static_features": [38958, 12], }, "val_targets": { "gwl_pred": [38958, 3, 1], "subs_pred": [38958, 3, 1], }, "test_inputs": { "H_field": [19479, 3, 1], "coords": [19479, 3, 3], "dynamic_features": [19479, 5, 5], "future_features": [19479, 8, 1], "static_features": [19479, 12], }, "test_targets": { "gwl_pred": [19479, 3, 1], "subs_pred": [19479, 3, 1], }, }, }, "paths": { "run_dir": run_dir, "artifacts_dir": artifacts_dir, }, "versions": { "python": "3.10.19", "tensorflow": "2.20.0", "numpy": "2.0.2", "pandas": "2.3.3", "sklearn": "1.7.2", }, } return payload
[docs] def generate_manifest( path: PathLike, *, template: Mapping[str, Any] | None = None, overrides: Mapping[str, Any] | None = None, ) -> Path: """Generate a Stage-1 manifest JSON artifact.""" base = ( dict(template) if template is not None else default_manifest_payload() ) payload = clone_artifact( base, overrides=dict(overrides or {}) ) return write_json(payload, path)
[docs] def load_manifest(path: PathLike) -> ArtifactRecord: """Load a Stage-1 manifest as ``ArtifactRecord``.""" return load_artifact(path, kind="manifest")
[docs] def manifest_identity_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return a compact Stage-1 identity frame.""" payload = _as_payload(manifest) rows = [ { "section": "identity", "key": "schema_version", "value": payload.get("schema_version"), }, { "section": "identity", "key": "timestamp", "value": payload.get("timestamp"), }, { "section": "identity", "key": "stage", "value": payload.get("stage"), }, { "section": "identity", "key": "city", "value": payload.get("city"), }, { "section": "identity", "key": "model", "value": payload.get("model"), }, ] return pd.DataFrame(rows)
[docs] def manifest_config_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return a tidy frame for the compact config overview.""" payload = _as_payload(manifest) cfg = nested_get(payload, "config", default={}) conv = nested_get( payload, "config", "conventions", default={} ) rows: list[dict[str, Any]] = [] for key in _CONFIG_KEYS: rows.append( { "section": "config", "key": str(key), "value": (cfg or {}).get(key), } ) for key in ( "gwl_kind", "gwl_sign", "use_head_proxy", "time_units", "gwl_driver_kind", "gwl_target_kind", ): rows.append( { "section": "conventions", "key": str(key), "value": (conv or {}).get(key), } ) rows.append( { "section": "config", "key": "coord_ranges", "value": _coord_ranges(payload), } ) return pd.DataFrame(rows)
[docs] def manifest_feature_groups_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return feature-group names and counts.""" payload = _as_payload(manifest) feats = nested_get( payload, "config", "features", default={} ) rows: list[dict[str, Any]] = [] for key in ( "static", "dynamic", "future", "group_id_cols", ): values = (feats or {}).get(key, []) if not isinstance(values, list): values = [] rows.append( { "section": "features", "group": str(key), "count": len(values), "values": values, } ) return pd.DataFrame(rows)
[docs] def manifest_holdout_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return key holdout and split counts as a tidy frame.""" payload = _as_payload(manifest) holdout = nested_get( payload, "config", "holdout", default={} ) group_counts = (holdout or {}).get( "group_counts", {} ) or {} row_counts = (holdout or {}).get( "row_counts_hist", {} ) or {} seq_counts = (holdout or {}).get( "sequence_counts", {} ) or {} rows: list[dict[str, Any]] = [] for key in _HOLDOUT_KEYS: value = None section = "holdout" if key in group_counts: value = group_counts.get(key) section = "group_counts" elif key in row_counts: value = row_counts.get(key) section = "row_counts_hist" elif key in seq_counts: value = seq_counts.get(key) section = "sequence_counts" if value is not None: rows.append( { "section": section, "key": str(key), "value": value, } ) rows.extend( [ { "section": "holdout_meta", "key": "strategy", "value": (holdout or {}).get("strategy"), }, { "section": "holdout_meta", "key": "seed", "value": (holdout or {}).get("seed"), }, { "section": "holdout_meta", "key": "val_frac", "value": (holdout or {}).get("val_frac"), }, { "section": "holdout_meta", "key": "test_frac", "value": (holdout or {}).get("test_frac"), }, ] ) return pd.DataFrame(rows)
[docs] def manifest_artifacts_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return a leaf-level artifact inventory frame.""" payload = _as_payload(manifest) rows = _walk_leaf_paths(payload.get("artifacts", {})) frame = pd.DataFrame(rows) if not frame.empty: frame.insert(0, "section", "artifacts") return frame
[docs] def manifest_paths_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return a frame describing top-level manifest paths.""" payload = _as_payload(manifest) rows: list[dict[str, Any]] = [] for key, value in ( payload.get("paths", {}) or {} ).items(): rows.append( { "section": "paths", "key": str(key), "value": value, "basename": Path(str(value)).name if value else None, } ) return pd.DataFrame(rows)
[docs] def manifest_versions_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return runtime/library versions saved in the manifest.""" payload = _as_payload(manifest) rows: list[dict[str, Any]] = [] for key, value in ( payload.get("versions", {}) or {} ).items(): rows.append( { "section": "versions", "key": str(key), "value": value, } ) return pd.DataFrame(rows)
[docs] def manifest_shapes_frame( manifest: ManifestLike, ) -> pd.DataFrame: """Return a tidy tensor-shape summary frame.""" payload = _as_payload(manifest) return pd.DataFrame(_shape_rows(payload))
[docs] def summarize_manifest( manifest: ManifestLike, ) -> dict[str, Any]: """ Summarize the Stage-1 manifest in a compact dictionary. The summary focuses on what later workflow stages usually need to verify first. """ payload = _as_payload(manifest) cfg = nested_get(payload, "config", default={}) holdout = nested_get( payload, "config", "holdout", default={} ) group_counts = (holdout or {}).get( "group_counts", {} ) or {} row_counts = (holdout or {}).get( "row_counts_hist", {} ) or {} seq_counts = (holdout or {}).get( "sequence_counts", {} ) or {} versions = payload.get("versions", {}) or {} summary = { "schema_version": payload.get("schema_version"), "timestamp": payload.get("timestamp"), "stage": payload.get("stage"), "city": payload.get("city"), "model": payload.get("model"), "time_steps": (cfg or {}).get("TIME_STEPS"), "forecast_horizon_years": (cfg or {}).get( "FORECAST_HORIZON_YEARS" ), "mode": (cfg or {}).get("MODE"), "train_end_year": (cfg or {}).get("TRAIN_END_YEAR"), "forecast_start_year": (cfg or {}).get( "FORECAST_START_YEAR" ), "static_feature_count": _feature_count( payload, "static" ), "dynamic_feature_count": _feature_count( payload, "dynamic" ), "future_feature_count": _feature_count( payload, "future" ), "group_id_cols_count": _feature_count( payload, "group_id_cols", ), "coord_ranges": _coord_ranges(payload), "artifact_path_count": _artifact_path_count(payload), "path_count": len(payload.get("paths", {}) or {}), "version_count": len(versions), "train_groups": group_counts.get("train_groups"), "val_groups": group_counts.get("val_groups"), "test_groups": group_counts.get("test_groups"), "train_rows": row_counts.get("train_rows"), "val_rows": row_counts.get("val_rows"), "test_rows": row_counts.get("test_rows"), "train_seq": seq_counts.get("train_seq"), "val_seq": seq_counts.get("val_seq"), "test_seq": seq_counts.get("test_seq"), "has_config": isinstance(payload.get("config"), dict), "has_artifacts": isinstance( payload.get("artifacts"), dict ), "has_paths": isinstance(payload.get("paths"), dict), "has_versions": isinstance( payload.get("versions"), dict ), "has_holdout": isinstance( (cfg or {}).get("holdout"), dict ), "has_scaling_kwargs": isinstance( (cfg or {}).get("scaling_kwargs"), dict, ), } return summary
[docs] def plot_manifest_artifact_inventory( ax: plt.Axes, manifest: ManifestLike, *, title: str = "Manifest artifact inventory", **plot_kws: Any, ) -> plt.Axes: """Plot artifact and metadata inventory counts.""" payload = _as_payload(manifest) metrics = { "artifact_paths": _artifact_path_count(payload), "paths": len(payload.get("paths", {}) or {}), "versions": len(payload.get("versions", {}) or {}), "shape_blocks": len( nested_get( payload, "artifacts", "shapes", default={} ) or {} ), "numpy_artifacts": len( nested_get( payload, "artifacts", "numpy", default={} ) or {} ), } return plot_metric_bars( ax, metrics, title=title, sort_by_value=False, **plot_kws, )
[docs] def plot_manifest_feature_group_sizes( ax: plt.Axes, manifest: ManifestLike, *, title: str = "Stage-1 feature groups", **plot_kws: Any, ) -> plt.Axes: """Plot Stage-1 feature-group sizes.""" payload = _as_payload(manifest) metrics = { "static": _feature_count(payload, "static"), "dynamic": _feature_count(payload, "dynamic"), "future": _feature_count(payload, "future"), "group_id_cols": _feature_count( payload, "group_id_cols" ), } return plot_metric_bars( ax, metrics, title=title, sort_by_value=False, **plot_kws, )
[docs] def plot_manifest_holdout_counts( ax: plt.Axes, manifest: ManifestLike, *, title: str = "Holdout split counts", **plot_kws: Any, ) -> plt.Axes: """Plot the main group and sequence split counts.""" payload = _as_payload(manifest) holdout = nested_get( payload, "config", "holdout", default={} ) groups = (holdout or {}).get("group_counts", {}) or {} seqs = (holdout or {}).get("sequence_counts", {}) or {} metrics = { "train_groups": groups.get("train_groups", 0), "val_groups": groups.get("val_groups", 0), "test_groups": groups.get("test_groups", 0), "train_seq": seqs.get("train_seq", 0), "val_seq": seqs.get("val_seq", 0), "test_seq": seqs.get("test_seq", 0), } return plot_metric_bars( ax, metrics, title=title, sort_by_value=False, **plot_kws, )
[docs] def plot_manifest_coord_ranges( ax: plt.Axes, manifest: ManifestLike, *, title: str = "Coordinate ranges", **plot_kws: Any, ) -> plt.Axes: """Plot nested coordinate ranges from scaling kwargs.""" return plot_metric_bars( ax, _coord_ranges(_as_payload(manifest)), title=title, sort_by_value=False, **plot_kws, )
[docs] def plot_manifest_boolean_summary( ax: plt.Axes, manifest: ManifestLike, *, title: str = "Stage-1 manifest checks", **plot_kws: Any, ) -> plt.Axes: """Plot simple boolean checks for the handshake.""" payload = _as_payload(manifest) cfg = payload.get("config", {}) or {} checks = { "has_schema_version": bool( payload.get("schema_version") ), "has_stage": bool(payload.get("stage")), "has_city": bool(payload.get("city")), "has_model": bool(payload.get("model")), "has_config": isinstance(payload.get("config"), dict), "has_artifacts": isinstance( payload.get("artifacts"), dict ), "has_paths": isinstance(payload.get("paths"), dict), "has_versions": isinstance( payload.get("versions"), dict ), "has_scaling_kwargs": isinstance( cfg.get("scaling_kwargs"), dict, ), "has_holdout": isinstance(cfg.get("holdout"), dict), "has_numpy_artifacts": bool( nested_get( payload, "artifacts", "numpy", default={} ) ), "has_shape_summary": bool( nested_get( payload, "artifacts", "shapes", default={} ) ), } return plot_boolean_checks( ax, checks, title=title, **plot_kws, )
[docs] def inspect_manifest( manifest: ManifestLike, ) -> dict[str, Any]: """ Return a bundle of useful Stage-1 inspection outputs. Returns ------- dict Normalized payload, compact summary, and the main tidy frames. """ payload = _as_payload(manifest) return { "payload": payload, "summary": summarize_manifest(payload), "identity": manifest_identity_frame(payload), "config": manifest_config_frame(payload), "feature_groups": manifest_feature_groups_frame( payload ), "holdout": manifest_holdout_frame(payload), "artifacts": manifest_artifacts_frame(payload), "paths": manifest_paths_frame(payload), "versions": manifest_versions_frame(payload), "shapes": manifest_shapes_frame(payload), }