# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Shared helpers for ``geoprior.utils.inspect``.
This module is intentionally generic.
It does not know the full semantics of each
artifact family. Instead, it provides the
small reusable pieces that inspection
submodules can share:
- robust JSON loading and writing,
- nested mapping access,
- safe flattening for summaries,
- artifact kind inference,
- demo payload cloning for gallery examples,
- compact tabular conversion for metrics,
- generic plotting helpers.
Submodules such as ``stage1_audit.py`` or
``training_summary.py`` should build the
artifact-specific logic on top of these
helpers rather than reimplementing them.
"""
from __future__ import annotations
import copy
import json
import math
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ...core.handlers import get_valid_kwargs
# Accepted path inputs for the helpers below: plain strings or Path
# objects. Evaluated at import time, so the PEP 604 ``|`` union
# requires Python 3.10+.
PathLike = str | Path
# Public API of this module. Private helpers (leading underscore,
# e.g. the plot-input coercion helper) are intentionally not listed.
__all__ = [
    "ArtifactRecord",
    "artifact_brief",
    "as_path",
    "bool_checks_frame",
    "clone_artifact",
    "deep_update",
    "ensure_parent_dir",
    "flatten_dict",
    "infer_artifact_kind",
    "is_number",
    "json_ready",
    "load_artifact",
    "metrics_frame",
    "nested_get",
    "numeric_items",
    # Shared plotting scaffolding used by the plot_* helpers.
    "prepare_plot",
    "empty_plot",
    "finalize_plot",
    "filter_plot_kwargs",
    "plot_boolean_checks",
    "plot_metric_bars",
    "plot_series_map",
    "read_json",
    "write_json",
]
[docs]
@dataclass(slots=True)
class ArtifactRecord:
"""
Lightweight normalized artifact container.
Parameters
----------
path : pathlib.Path
Artifact path.
kind : str
Inferred or explicit artifact kind.
payload : dict[str, Any]
Loaded JSON payload.
stage : str or None
Stage if available.
city : str or None
City if available.
model : str or None
Model if available.
meta : dict[str, Any]
Extra extracted metadata.
"""
path: Path
kind: str
payload: dict[str, Any]
stage: str | None = None
city: str | None = None
model: str | None = None
meta: dict[str, Any] = field(default_factory=dict)
[docs]
def as_path(path: PathLike) -> Path:
    """
    Normalize ``path`` into an absolute ``Path``.

    A leading ``~`` is expanded to the user's home directory first, and
    the result is then resolved (made absolute, symlinks followed).
    """
    candidate = Path(path)
    candidate = candidate.expanduser()
    return candidate.resolve()
[docs]
def ensure_parent_dir(path: PathLike) -> Path:
    """
    Make sure the directory that will hold ``path`` exists.

    Missing intermediate directories are created, and an existing
    directory is not an error. Returns the normalized path produced by
    ``as_path``.
    """
    target = as_path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    return target
[docs]
def read_json(path: PathLike) -> dict[str, Any]:
    """
    Load a UTF-8 JSON file whose top-level value is an object.

    Raises
    ------
    ValueError
        If the decoded top-level value is not a JSON object (dict).
    """
    source = as_path(path)
    with source.open("r", encoding="utf-8") as handle:
        loaded = json.load(handle)
    if isinstance(loaded, dict):
        return loaded
    raise ValueError(
        "Expected a JSON object at "
        f"{source!s}, got {type(loaded).__name__}."
    )
[docs]
def write_json(
    payload: dict[str, Any],
    path: PathLike,
    *,
    indent: int = 2,
    sort_keys: bool = False,
) -> Path:
    """
    Write ``payload`` as UTF-8 JSON.

    The payload is first normalized with ``json_ready`` and fully
    encoded to text *before* the target file is opened. A payload that
    cannot be serialized therefore raises without truncating or
    half-writing an existing file at ``path``.

    Parameters
    ----------
    payload : dict[str, Any]
        Object to serialize.
    path : str or pathlib.Path
        Destination file. Missing parent directories are created.
    indent : int, default 2
        Indentation passed to the JSON encoder.
    sort_keys : bool, default False
        Whether to sort object keys in the output.

    Returns
    -------
    pathlib.Path
        The normalized destination path.
    """
    # Encode first: opening with "w" truncates, so a failure during
    # encoding must happen before the file is touched.
    text = json.dumps(
        json_ready(payload),
        indent=indent,
        sort_keys=sort_keys,
        ensure_ascii=False,
    )
    p = ensure_parent_dir(path)
    with p.open("w", encoding="utf-8") as stream:
        stream.write(text)
        stream.write("\n")
    return p
[docs]
def is_number(value: Any) -> bool:
    """
    Tell whether ``value`` is a numeric scalar.

    Python ``int``/``float`` (and therefore ``bool``) and any NumPy
    ``np.number`` instance count. NaN and infinities are accepted too.
    NumPy ``np.bool_`` is not an ``np.number`` and is rejected.
    """
    numeric_types = (int, float, np.number)
    return isinstance(value, numeric_types)
[docs]
def json_ready(value: Any) -> Any:
    """
    Convert nested values into JSON-safe objects.

    Notes
    -----
    - dict keys are converted with ``str``.
    - lists, tuples and numpy arrays become lists.
    - sets and frozensets become lists, sorted when their converted
      elements are mutually comparable (for deterministic output).
    - numpy scalars are converted to Python scalars.
    - ``NaN`` and ``Inf`` floats are converted to ``None``.
    - ``pathlib.Path`` objects are converted to ``str``.
    - anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {
            str(k): json_ready(v) for k, v in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [json_ready(v) for v in value]
    if isinstance(value, (set, frozenset)):
        items = [json_ready(v) for v in value]
        try:
            return sorted(items)
        except TypeError:
            # Mixed/unorderable element types: keep iteration order.
            return items
    if isinstance(value, np.ndarray):
        return [json_ready(v) for v in value.tolist()]
    if isinstance(value, np.generic):
        return json_ready(value.item())
    if isinstance(value, float):
        if math.isnan(value) or math.isinf(value):
            return None
        return value
    if isinstance(value, Path):
        # Artifact manifests commonly carry paths; json cannot encode them.
        return str(value)
    return value
[docs]
def deep_update(
    base: dict[str, Any],
    updates: dict[str, Any] | None,
) -> dict[str, Any]:
    """
    Merge ``updates`` into a deep copy of ``base``, recursively.

    When both sides hold a dict under the same key, the two dicts are
    merged recursively. Otherwise the update value (deep-copied)
    replaces the base value. Neither input is mutated, and a new
    dictionary is always returned, even when ``updates`` is empty or
    ``None``.
    """
    merged = copy.deepcopy(base)
    for key, incoming in (updates or {}).items():
        current = merged.get(key)
        if isinstance(incoming, dict) and isinstance(current, dict):
            # Recurse instead of merging in place so that sub-dicts
            # shared between keys of ``base`` stay independent.
            merged[key] = deep_update(current, incoming)
        else:
            merged[key] = copy.deepcopy(incoming)
    return merged
[docs]
def clone_artifact(
    template: dict[str, Any],
    *,
    overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Return an independent copy of ``template`` with ``overrides`` merged in.

    Handy for Sphinx-Gallery examples that need a realistic artifact
    with a few controlled changes. The merge is recursive and the
    template is left untouched (see ``deep_update``).
    """
    cloned = deep_update(template, overrides)
    return cloned
[docs]
def nested_get(
    mapping: dict[str, Any] | None,
    *keys: str,
    default: Any = None,
) -> Any:
    """
    Follow ``keys`` through nested dicts, returning ``default`` on a miss.

    A miss is a missing key, or a non-dict value encountered before
    all keys are consumed. With no keys, ``mapping`` itself is returned.

    Examples
    --------
    ``nested_get(d, "config", "scaling_kwargs")``
    """
    node: Any = mapping
    for step in keys:
        if isinstance(node, dict) and step in node:
            node = node[step]
            continue
        return default
    return node
[docs]
def flatten_dict(
    mapping: dict[str, Any],
    *,
    parent_key: str = "",
    sep: str = ".",
) -> dict[str, Any]:
    """
    Collapse nested dictionaries into a single level.

    Nested keys are joined with ``sep`` (prefixed by ``parent_key``
    when that is non-empty). Leaf values, including lists and arrays,
    are stored unchanged. Empty nested dicts contribute no entries.
    """

    def _walk(node: Any, prefix: Any):
        # Depth-first, yielding (flat_key, leaf_value) in insertion order.
        for raw_key, raw_value in node.items():
            joined = (
                f"{prefix}{sep}{raw_key}" if prefix else str(raw_key)
            )
            if isinstance(raw_value, dict):
                yield from _walk(raw_value, joined)
            else:
                yield joined, raw_value

    return dict(_walk(mapping, parent_key))
[docs]
def numeric_items(
    mapping: dict[str, Any] | None,
    *,
    drop_bools: bool = True,
) -> dict[str, float]:
    """
    Keep only the numeric scalar entries of ``mapping``, as floats.

    Keys are stringified. Booleans are skipped unless ``drop_bools`` is
    False, in which case they are kept as ``1.0``/``0.0`` (``is_number``
    accepts ``bool``). An empty or ``None`` mapping yields ``{}``.
    """
    if not mapping:
        return {}
    return {
        str(name): float(raw)
        for name, raw in mapping.items()
        if not (drop_bools and isinstance(raw, bool))
        and is_number(raw)
    }
[docs]
# Ordered file-name rules: (match mode, needle, kind). First hit wins,
# and every name rule takes precedence over every key rule.
_KIND_NAME_RULES = (
    ("contains", "scaling_audit", "stage1_audit"),
    ("contains", "handshake_audit", "stage2_handshake"),
    ("endswith", ".npz.meta.json", "physics_payload_meta"),
    ("contains", "training_summary", "training_summary"),
    ("contains", "eval_diagnostics", "eval_diagnostics"),
    ("contains", "calibration_stats", "calibration_stats"),
    ("contains", "scaling_kwargs", "scaling_kwargs"),
    ("contains", "model_init_manifest", "model_init_manifest"),
    ("contains", "run_manifest", "run_manifest"),
    ("equals", "manifest.json", "manifest"),
)

# Ordered top-level-key rules: (required keys, kind). A rule matches
# when all of its keys are present in the payload. First hit wins.
_KIND_KEY_RULES = (
    (
        frozenset({"metrics_at_best", "final_epoch_metrics"}),
        "training_summary",
    ),
    (frozenset({"expected", "got", "finite"}), "stage2_handshake"),
    (
        frozenset({"provenance", "coord_scaler", "feature_split"}),
        "stage1_audit",
    ),
    (
        frozenset({"schema_version", "stage", "artifacts", "paths"}),
        "manifest",
    ),
    (
        frozenset({"config", "dims", "model_class"}),
        "model_init_manifest",
    ),
    (
        frozenset({"metrics_evaluate", "physics_diagnostics"}),
        "eval_physics",
    ),
    (
        frozenset(
            {"target", "interval", "factors", "eval_before", "eval_after"}
        ),
        "calibration_stats",
    ),
    (
        frozenset({"created_utc", "units", "payload_metrics"}),
        "physics_payload_meta",
    ),
    (
        frozenset(
            {"stage", "city", "model", "config", "paths", "artifacts"}
        ),
        "run_manifest",
    ),
)


def infer_artifact_kind(
    path: PathLike,
    payload: dict[str, Any] | None = None,
) -> str:
    """
    Guess the artifact family from its file name, then its top-level keys.

    File-name rules are tried first, in order (case-insensitive on the
    resolved file name). If none match, key-set rules are tried in
    order against ``payload``. Falls back to ``"json_artifact"``.
    Artifact-specific readers may still override the result.
    """
    file_name = as_path(path).name.lower()
    for mode, needle, kind in _KIND_NAME_RULES:
        if mode == "contains":
            hit = needle in file_name
        elif mode == "endswith":
            hit = file_name.endswith(needle)
        else:
            hit = file_name == needle
        if hit:
            return kind

    present = set(payload or {})
    for required, kind in _KIND_KEY_RULES:
        if required <= present:
            return kind
    return "json_artifact"
[docs]
def load_artifact(
    path: PathLike,
    *,
    kind: str | None = None,
) -> ArtifactRecord:
    """
    Read a JSON artifact from disk and wrap it in an ``ArtifactRecord``.

    ``kind`` overrides inference when truthy. ``stage`` and ``city`` are
    read from top-level keys of the same name. ``model`` is the first
    truthy value among ``model``, ``model_name`` and ``model_class``
    (otherwise the ``model_class`` value, possibly ``None``).
    """
    source = as_path(path)
    payload = read_json(source)
    resolved_kind = kind or infer_artifact_kind(source, payload)

    model = None
    for model_key in ("model", "model_name", "model_class"):
        model = payload.get(model_key)
        if model:
            break

    summary_meta = {
        "top_keys": list(payload),
        "n_top_keys": len(payload),
        "has_config": "config" in payload,
        "has_paths": "paths" in payload,
        "has_artifacts": "artifacts" in payload,
    }
    return ArtifactRecord(
        path=source,
        kind=resolved_kind,
        payload=payload,
        stage=payload.get("stage"),
        city=payload.get("city"),
        model=model,
        meta=summary_meta,
    )
[docs]
def artifact_brief(record: ArtifactRecord) -> dict[str, Any]:
    """
    Summarize an artifact record as a small header dictionary.

    Includes the path (as ``str``), kind, stage, city and model, plus
    the ``n_top_keys``/``top_keys`` entries from ``record.meta`` (``None``
    when absent).
    """
    meta = record.meta
    return dict(
        path=str(record.path),
        kind=record.kind,
        stage=record.stage,
        city=record.city,
        model=record.model,
        n_top_keys=meta.get("n_top_keys"),
        top_keys=meta.get("top_keys"),
    )
[docs]
def metrics_frame(
    mapping: dict[str, Any] | None,
    *,
    section: str | None = None,
    sort: bool = True,
) -> pd.DataFrame:
    """
    Tabulate the numeric scalar entries of ``mapping``.

    Columns are ``metric`` and ``value``. When ``section`` is given and
    rows exist, a leading ``section`` column is added. Rows are sorted
    by ``metric`` when ``sort`` is True. The index is always a fresh
    0..n-1 range.
    """
    pairs = numeric_items(mapping)
    table = pd.DataFrame(
        {
            "metric": list(pairs.keys()),
            "value": list(pairs.values()),
        }
    )
    has_rows = not table.empty
    if has_rows and section is not None:
        table.insert(0, "section", section)
    if has_rows and sort:
        table = table.sort_values("metric")
    return table.reset_index(drop=True)
[docs]
def bool_checks_frame(
    mapping: dict[str, Any] | None,
    *,
    section: str | None = None,
) -> pd.DataFrame:
    """
    Convert boolean checks into a tidy DataFrame.

    Only boolean values are kept: Python ``bool`` and NumPy ``np.bool_``,
    normalized to ``bool``. Other entries, including ints like ``1``,
    are ignored.

    Returns
    -------
    pandas.DataFrame
        Columns ``check`` and ``ok``, always present even when no
        checks were found. A leading ``section`` column is added when
        ``section`` is given and at least one check exists.
    """
    rows: list[dict[str, Any]] = [
        {"check": str(key), "ok": bool(value)}
        for key, value in (mapping or {}).items()
        if isinstance(value, (bool, np.bool_))
    ]
    # Explicit columns keep the schema stable for empty results, so
    # callers can index ``frame["ok"]`` without a KeyError.
    frame = pd.DataFrame(rows, columns=["check", "ok"])
    if section is not None and not frame.empty:
        frame.insert(0, "section", section)
    return frame.reset_index(drop=True)
def prepare_plot(
    ax: plt.Axes | None = None,
    *,
    figsize: tuple[float, float] | None = None,
    constrained_layout: bool = True,
) -> tuple[plt.Figure, plt.Axes, bool]:
    """
    Obtain a ``(fig, ax, created)`` plotting context.

    An existing ``ax`` is reused together with its figure (``created``
    is False). Otherwise a new figure/axes pair is made with
    ``plt.subplots`` using ``figsize`` and ``constrained_layout``
    (``created`` is True). Deliberately minimal so legacy plot helpers
    can adopt it without changing their public behavior.
    """
    if ax is not None:
        return ax.figure, ax, False
    new_fig, new_ax = plt.subplots(
        figsize=figsize,
        constrained_layout=constrained_layout,
    )
    return new_fig, new_ax, True
def empty_plot(
    fig: plt.Figure,
    ax: plt.Axes,
    *,
    title: str | None = None,
    message: str = "No data",
    axis_off: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Draw the shared "nothing to show" placeholder on ``ax``.

    Sets ``title`` when it is non-empty, writes ``message`` centered in
    axes coordinates, and hides the axis frame when ``axis_off`` is
    True. Returns ``(fig, ax)`` unchanged.
    """
    if title:
        ax.set_title(title)
    ax.text(
        0.5,
        0.5,
        message,
        transform=ax.transAxes,
        ha="center",
        va="center",
    )
    if axis_off:
        ax.set_axis_off()
    return fig, ax
def finalize_plot(
    fig: plt.Figure,
    ax: plt.Axes,
    *,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    show_grid: bool = True,
    grid_kws: dict[str, Any] | None = None,
    legend: bool | None = None,
    legend_kws: dict[str, Any] | None = None,
    rotate_xticks: float | int | None = None,
    rotate_yticks: float | int | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """
    Apply common cosmetic settings and return ``(fig, ax)``.

    Parameters
    ----------
    title, xlabel, ylabel : str or None
        Applied only when non-empty.
    show_grid : bool, default True
        When True, draw a grid using ``grid_kws`` (default
        ``{"alpha": 0.25}``). When False, the grid is explicitly
        turned off, so a style that enables grids by default does not
        override the caller's choice.
    legend : bool or None
        A legend is drawn only when this is exactly ``True``.
    legend_kws : dict or None
        Extra keyword arguments for ``ax.legend``.
    rotate_xticks, rotate_yticks : float or None
        Tick label rotation in degrees, applied when not ``None``.
    """
    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if rotate_xticks is not None:
        ax.tick_params(axis="x", rotation=rotate_xticks)
    if rotate_yticks is not None:
        ax.tick_params(axis="y", rotation=rotate_yticks)
    if show_grid:
        ax.grid(**(grid_kws or {"alpha": 0.25}))
    else:
        # rcParams/styles (e.g. ``axes.grid: True``) may have enabled a
        # grid already; honour show_grid=False explicitly.
        ax.grid(False)
    if legend is True:
        ax.legend(**(legend_kws or {}))
    return fig, ax
def filter_plot_kwargs(
    callable_obj: Callable[..., Any] | Any,
    kwargs: dict[str, Any] | None = None,
    *,
    error: str = "ignore",
) -> dict[str, Any]:
    """
    Return the subset of ``kwargs`` that ``callable_obj`` accepts.

    Filtering is first delegated to
    ``geoprior.core.handlers.get_valid_kwargs``, which takes the
    candidate keywords via ``**kwargs`` (passing a named ``kwargs=``
    would collide with its own parameter and raise ``TypeError``). If
    that call raises ``TypeError``, the callable's signature is
    inspected locally instead. A callable with ``**kwargs`` accepts
    everything, and otherwise only named parameters are kept. If no
    signature is available, an empty dict is returned.

    ``error`` is accepted for API compatibility but is not used here.
    """
    import inspect

    if not kwargs:
        return {}
    candidate = dict(kwargs)
    try:
        return get_valid_kwargs(callable_obj, **candidate)
    except TypeError:
        pass  # fall back to local signature-based filtering

    try:
        params = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return {}

    accepts_any = any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in params.values()
    )
    if accepts_any:
        return candidate
    return {
        name: val for name, val in candidate.items() if name in params
    }
def _coerce_plot_inputs(
    first: plt.Axes | dict[str, Any] | pd.DataFrame | None,
    second: dict[str, Any] | pd.DataFrame | None = None,
    *,
    ax: plt.Axes | None = None,
) -> tuple[
    plt.Figure, plt.Axes, dict[str, Any] | pd.DataFrame | None
]:
    """
    Resolve ``(fig, ax, data)`` from the supported calling styles.

    Supported forms:

    - legacy ``(ax, data)``: ``first`` is an Axes, ``second`` is data.
    - legacy ``(None, data)``: no axes yet, so a new one is created and
      ``second`` is used as data.
    - newer ``(data, ax=...)``: ``first`` is data, ``ax`` the target
      axes (a new one is created when ``ax`` is None).
    - ``(data, ax)`` with the axes passed positionally as ``second``;
      the ``ax`` keyword wins if both are given.
    """
    if isinstance(first, plt.Axes):
        return first.figure, first, second

    data = first
    target_ax = ax
    if isinstance(second, plt.Axes):
        # Positional ``(data, ax)``: previously ignored, creating a
        # stray new figure.
        if target_ax is None:
            target_ax = second
    elif data is None:
        # Legacy ``(None, data)``: previously the data was discarded.
        data = second

    fig, plot_ax, _ = prepare_plot(ax=target_ax)
    return fig, plot_ax, data
[docs]
def plot_metric_bars(
    ax: plt.Axes | dict[str, Any] | pd.DataFrame,
    metrics: dict[str, Any] | pd.DataFrame | None = None,
    *,
    title: str = "Metrics",
    top_n: int | None = None,
    sort_by_value: bool = False,
    absolute: bool = False,
    annotate: bool = True,
    xlabel: str = "value",
    show_grid: bool = True,
    grid_kws: dict[str, Any] | None = None,
    annotate_kws: dict[str, Any] | None = None,
    ax_obj: plt.Axes | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot a compact horizontal metric bar chart.

    The legacy calling style ``plot_metric_bars(ax, metrics, ...)`` is
    preserved. A newer style ``plot_metric_bars(metrics, ax_obj=ax,
    ...)`` is also accepted for gradual migration.

    Parameters
    ----------
    metrics : dict or pandas.DataFrame
        A mapping (numeric entries are extracted with ``metrics_frame``)
        or a frame with ``metric`` and ``value`` columns.
    top_n : int, optional
        Keep only the last ``top_n`` rows after sorting.
    sort_by_value : bool, default False
        Sort by plotted value instead of metric name.
    absolute : bool, default False
        Plot ``|value|``. Annotations still show the signed value.
    annotate : bool, default True
        Write each value next to its bar. ``annotate_kws`` are passed to
        ``ax.text``, and vertical alignment defaults to centered unless
        overridden there.
    **plot_kws
        Extra keyword arguments for ``ax.barh`` (filtered).

    Returns
    -------
    matplotlib.axes.Axes
        The axes drawn on.
    """
    fig, plot_ax, plot_data = _coerce_plot_inputs(
        ax,
        metrics,
        ax=ax_obj,
    )
    if isinstance(plot_data, pd.DataFrame):
        frame = plot_data.copy()
    else:
        frame = metrics_frame(plot_data)
    if frame.empty:
        _, plot_ax = empty_plot(
            fig,
            plot_ax,
            title=title,
            message="No numeric metrics",
        )
        return plot_ax
    if absolute:
        frame["plot_value"] = frame["value"].abs()
    else:
        frame["plot_value"] = frame["value"]
    if sort_by_value:
        frame = frame.sort_values("plot_value")
    else:
        frame = frame.sort_values("metric")
    if top_n is not None:
        frame = frame.tail(int(top_n))
    bar_kws = filter_plot_kwargs(
        plot_ax.barh,
        plot_kws,
        error=error,
    )
    plot_ax.barh(
        frame["metric"],
        frame["plot_value"],
        **bar_kws,
    )
    _, plot_ax = finalize_plot(
        fig,
        plot_ax,
        title=title,
        xlabel=xlabel,
        show_grid=show_grid,
        grid_kws=grid_kws or {"axis": "x", "alpha": 0.25},
    )
    if annotate:
        text_kws = dict(
            filter_plot_kwargs(
                plot_ax.text,
                annotate_kws,
                error=error,
            )
        )
        # Passing ``va=`` unconditionally collided with a user-supplied
        # ``va``/``verticalalignment`` (TypeError), so only default it.
        if (
            "va" not in text_kws
            and "verticalalignment" not in text_kws
        ):
            text_kws["va"] = "center"
        # Bars sit at y = 0..n-1 in frame order; reset_index aligns idx.
        for idx, row in frame.reset_index(
            drop=True
        ).iterrows():
            plot_ax.text(
                row["plot_value"],
                idx,
                f" {row['value']:.4g}",
                **text_kws,
            )
    return plot_ax
[docs]
def plot_boolean_checks(
    ax: plt.Axes | dict[str, Any] | pd.DataFrame,
    checks: dict[str, Any] | pd.DataFrame | None = None,
    *,
    title: str = "Checks",
    show_grid: bool = True,
    grid_kws: dict[str, Any] | None = None,
    ax_obj: plt.Axes | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Draw pass/fail checks as horizontal bars (pass = 1, fail = 0).

    Accepts the legacy ``(ax, checks)`` form as well as the newer
    ``(checks, ax_obj=ax)`` form. ``checks`` is either a mapping
    (boolean entries are extracted with ``bool_checks_frame``) or a
    frame with ``check`` and ``ok`` columns. Extra keyword arguments
    are filtered and forwarded to ``ax.barh``.
    """
    fig, plot_ax, plot_data = _coerce_plot_inputs(
        ax,
        checks,
        ax=ax_obj,
    )
    table = (
        plot_data.copy()
        if isinstance(plot_data, pd.DataFrame)
        else bool_checks_frame(plot_data)
    )
    if table.empty:
        _, plot_ax = empty_plot(
            fig,
            plot_ax,
            title=title,
            message="No boolean checks",
        )
        return plot_ax

    table["score"] = table["ok"].astype(int)
    bar_kws = filter_plot_kwargs(plot_ax.barh, plot_kws, error=error)
    plot_ax.barh(table["check"], table["score"], **bar_kws)

    # Fixed 0..1 scale with the two ticks labelled as outcomes.
    plot_ax.set_xlim(0, 1)
    plot_ax.set_xticks([0, 1])
    plot_ax.set_xticklabels(["fail", "pass"])

    _, plot_ax = finalize_plot(
        fig,
        plot_ax,
        title=title,
        show_grid=show_grid,
        grid_kws=grid_kws or {"axis": "x", "alpha": 0.25},
    )
    return plot_ax
[docs]
def plot_series_map(
    ax: plt.Axes | dict[str, Any],
    series_map: dict[str, Any] | None = None,
    *,
    title: str = "Series",
    xlabel: str = "key",
    ylabel: str = "value",
    marker: str = "o",
    show_grid: bool = True,
    grid_kws: dict[str, Any] | None = None,
    ax_obj: plt.Axes | None = None,
    error: str = "ignore",
    **plot_kws: Any,
) -> plt.Axes:
    """
    Plot a string-keyed numeric mapping as a line.

    Keeps the older ``(ax, series_map)`` form while also accepting
    ``(series_map, ax_obj=ax)``.

    Only numeric entries are plotted (see ``numeric_items``). Keys that
    parse as finite-or-infinite floats are ordered numerically first,
    and the remaining keys (including ``"nan"``) follow in text order.
    Points are placed at x = 0..n-1 and labelled with their keys.
    ``marker`` is used unless ``plot_kws`` already provides one.

    Returns
    -------
    matplotlib.axes.Axes
        The axes drawn on.
    """
    fig, plot_ax, plot_data = _coerce_plot_inputs(
        ax,
        series_map,
        ax=ax_obj,
    )
    pairs = numeric_items(plot_data)
    if not pairs:
        _, plot_ax = empty_plot(
            fig,
            plot_ax,
            title=title,
            message="No numeric series",
        )
        return plot_ax

    def _key(v: str) -> tuple[int, Any]:
        # Group 0: numeric labels by value; group 1: everything else
        # by text. The group tag prevents float/str comparisons.
        try:
            num = float(v)
        except (TypeError, ValueError):
            return (1, v)
        if math.isnan(num):
            # NaN breaks total ordering; keep it with the text labels.
            return (1, v)
        return (0, num)

    keys = sorted(pairs, key=_key)
    x = np.arange(len(keys))
    y = np.array([pairs[k] for k in keys], dtype=float)
    line_kws = filter_plot_kwargs(
        plot_ax.plot,
        plot_kws,
        error=error,
    )
    if "marker" not in line_kws:
        line_kws["marker"] = marker
    plot_ax.plot(x, y, **line_kws)
    plot_ax.set_xticks(x)
    plot_ax.set_xticklabels(keys)
    _, plot_ax = finalize_plot(
        fig,
        plot_ax,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        show_grid=show_grid,
        grid_kws=grid_kws or {"alpha": 0.25},
    )
    return plot_ax