# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Calibration-statistics generation and inspection helpers.
This module focuses on the compact calibration-stats
artifact produced by the forecast calibration workflow.
It is intentionally narrower than ``eval_physics.py``:
- ``eval_physics.py`` inspects the full evaluation payload,
- ``calibration_stats.py`` inspects the saved calibration
stats object itself.
The module also accepts the richer interpretable eval JSON
because Stage-2 stores the same calibration stats under:
``interval_calibration['factors_per_horizon_from_cal_stats']``.
"""
from __future__ import annotations
from collections.abc import Mapping
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from .utils import (
ArtifactRecord,
as_path,
clone_artifact,
empty_plot,
filter_plot_kwargs,
finalize_plot,
nested_get,
plot_boolean_checks,
plot_metric_bars,
plot_series_map,
prepare_plot,
read_json,
write_json,
)
PathLike = str | Path
CalibrationStatsLike = (
ArtifactRecord | Mapping[str, Any] | str | Path
)
__all__ = [
"calibration_stats_factors_frame",
"calibration_stats_overall_frame",
"calibration_stats_per_horizon_frame",
"default_calibration_stats_payload",
"generate_calibration_stats",
"inspect_calibration_stats",
"load_calibration_stats",
"plot_calibration_boolean_summary",
"plot_calibration_factors",
"plot_calibration_overall_metrics",
"plot_calibration_per_horizon_coverage",
"plot_calibration_per_horizon_sharpness",
"summarize_calibration_stats",
]
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _coerce_horizon_map(
value: Any,
) -> dict[str, float]:
"""Return a stable string-keyed horizon mapping."""
if value is None:
return {}
if isinstance(value, Mapping):
out: dict[str, float] = {}
for key, item in value.items():
try:
out[str(int(float(key)))] = float(item)
except Exception:
continue
return out
if isinstance(value, (list, tuple, np.ndarray)):
out = {}
for idx, item in enumerate(value, start=1):
try:
out[str(idx)] = float(item)
except Exception:
continue
return out
return {}
def _extract_payload(
obj: Mapping[str, Any],
) -> dict[str, Any]:
"""
Extract calibration stats from either:
- a direct calibration-stats payload, or
- an interpretable eval JSON containing the nested
``interval_calibration`` block.
"""
if "interval_calibration" in obj:
interval_block = obj.get("interval_calibration") or {}
if not isinstance(interval_block, Mapping):
return dict(obj)
nested = (
interval_block.get(
"factors_per_horizon_from_cal_stats"
)
or {}
)
if isinstance(nested, Mapping) and nested:
payload = dict(nested)
if (
"factors" not in payload
and "factors_per_horizon" in interval_block
):
payload["factors"] = _coerce_horizon_map(
interval_block.get("factors_per_horizon")
)
return payload
return dict(interval_block)
return dict(obj)
def _as_payload(
stats: CalibrationStatsLike,
) -> dict[str, Any]:
"""Return a plain calibration-stats payload."""
if isinstance(stats, ArtifactRecord):
return _extract_payload(stats.payload)
if isinstance(stats, Mapping):
return _extract_payload(stats)
payload = read_json(stats)
return _extract_payload(payload)
def _per_horizon_rows(
payload: dict[str, Any],
*,
which: str,
) -> list[dict[str, Any]]:
"""Build tidy per-horizon rows."""
block = nested_get(
payload, which, "per_horizon", default={}
)
if not isinstance(block, Mapping):
return []
rows: list[dict[str, Any]] = []
for key, item in block.items():
if not isinstance(item, Mapping):
continue
row = {
"which": str(which),
"horizon": str(key),
"coverage": item.get("coverage"),
"sharpness": item.get("sharpness"),
}
rows.append(row)
try:
rows.sort(key=lambda r: int(float(r["horizon"])))
except Exception:
rows.sort(key=lambda r: r["horizon"])
return rows
# ------------------------------------------------------------------
# Generation
# ------------------------------------------------------------------
[docs]
def default_calibration_stats_payload(
*,
target: float = 0.80,
interval: tuple[float, float] = (0.10, 0.90),
f_max: float = 5.0,
tol: float = 0.02,
factors: dict[str, float] | None = None,
coverage_before: float = 0.865,
coverage_after: float = 0.800,
sharpness_before: float = 33.08,
sharpness_after: float = 33.38,
) -> dict[str, Any]:
"""
Build a realistic default calibration-stats payload.
The structure follows the object saved by the
calibration workflow and later embedded into the
interpretable eval JSON.
"""
factors = factors or {
"1": 1.0,
"2": 1.0,
"3": 1.0183744430541992,
}
payload = {
"target": float(target),
"interval": [float(interval[0]), float(interval[1])],
"f_max": float(f_max),
"tol": float(tol),
"overall_key": "__overall__",
"factors_source": "fit",
"factors": {
str(k): float(v) for k, v in factors.items()
},
"eval_before": {
"coverage": float(coverage_before),
"sharpness": float(sharpness_before),
"per_horizon": {
"1": {
"coverage": 0.9790543662405667,
"sharpness": 23.244874687581664,
},
"2": {
"coverage": 0.8223728117459829,
"sharpness": 27.592224706985775,
},
"3": {
"coverage": 0.7935725653267621,
"sharpness": 48.40746668258111,
},
},
},
"eval_after": {
"coverage": float(coverage_after),
"sharpness": float(sharpness_after),
"per_horizon": {
"1": {
"coverage": 0.9790543662405667,
"sharpness": 23.244874687581664,
},
"2": {
"coverage": 0.8223728117459829,
"sharpness": 27.592224706985775,
},
"3": {
"coverage": 0.8000410698701166,
"sharpness": 49.29692692253824,
},
},
},
}
return payload
[docs]
def generate_calibration_stats(
path: PathLike,
*,
template: CalibrationStatsLike | None = None,
overrides: dict[str, Any] | None = None,
) -> Path:
"""
Generate and save a calibration-stats JSON file.
Parameters
----------
path : str or pathlib.Path
Output JSON path.
template : mapping, path, ArtifactRecord, optional
Optional source payload. If omitted, a realistic
default payload is used.
overrides : dict, optional
Deep overrides applied after template resolution.
"""
base = (
_as_payload(template)
if template is not None
else default_calibration_stats_payload()
)
payload = clone_artifact(base, overrides=overrides)
return write_json(payload, path)
# ------------------------------------------------------------------
# Loading
# ------------------------------------------------------------------
[docs]
def load_calibration_stats(
path: PathLike,
) -> ArtifactRecord:
"""
Load a calibration-stats artifact.
Notes
-----
This loader is list-safe and nested-block aware. It can
read either a direct ``calibration_stats.json`` payload or
an interpretable eval JSON from which the nested block is
extracted.
"""
raw = read_json(path)
payload = _extract_payload(raw)
p = as_path(path)
meta = {
"top_keys": list(payload),
"n_top_keys": len(payload),
"source_name": p.name,
"has_eval_before": "eval_before" in payload,
"has_eval_after": "eval_after" in payload,
"has_factors": "factors" in payload,
}
return ArtifactRecord(
path=p,
kind="calibration_stats",
payload=payload,
stage=None,
city=nested_get(raw, "city"),
model=nested_get(raw, "model"),
meta=meta,
)
# ------------------------------------------------------------------
# Frames
# ------------------------------------------------------------------
[docs]
def calibration_stats_factors_frame(
stats: CalibrationStatsLike,
) -> pd.DataFrame:
"""Return per-horizon calibration factors."""
payload = _as_payload(stats)
factors = _coerce_horizon_map(payload.get("factors"))
rows = [
{
"horizon": key,
"factor": value,
}
for key, value in factors.items()
]
frame = pd.DataFrame(rows)
if frame.empty:
return frame
try:
frame = frame.sort_values(
by="horizon",
key=lambda s: s.astype(float),
)
except Exception:
frame = frame.sort_values("horizon")
return frame.reset_index(drop=True)
[docs]
def calibration_stats_overall_frame(
stats: CalibrationStatsLike,
) -> pd.DataFrame:
"""Return before/after overall calibration metrics."""
payload = _as_payload(stats)
rows = []
target = payload.get("target")
for which in ("eval_before", "eval_after"):
block = payload.get(which) or {}
if not isinstance(block, Mapping):
continue
coverage = block.get("coverage")
sharpness = block.get("sharpness")
rows.append(
{
"which": which,
"coverage": coverage,
"sharpness": sharpness,
"coverage_error": (
abs(float(coverage) - float(target))
if coverage is not None
and target is not None
else None
),
}
)
return pd.DataFrame(rows)
[docs]
def calibration_stats_per_horizon_frame(
stats: CalibrationStatsLike,
*,
which: str = "eval_after",
) -> pd.DataFrame:
"""
Return per-horizon coverage and sharpness.
Parameters
----------
which : {'eval_before', 'eval_after'}
Which calibration stage to extract.
"""
payload = _as_payload(stats)
return pd.DataFrame(
_per_horizon_rows(payload, which=which)
)
# ------------------------------------------------------------------
# Summary
# ------------------------------------------------------------------
[docs]
def summarize_calibration_stats(
stats: CalibrationStatsLike,
) -> dict[str, Any]:
"""Return a compact summary of calibration behavior."""
payload = _as_payload(stats)
target = payload.get("target")
tol = payload.get("tol")
factors = _coerce_horizon_map(payload.get("factors"))
before_cov = nested_get(
payload, "eval_before", "coverage"
)
after_cov = nested_get(payload, "eval_after", "coverage")
before_sharp = nested_get(
payload, "eval_before", "sharpness"
)
after_sharp = nested_get(
payload, "eval_after", "sharpness"
)
before_error = (
abs(float(before_cov) - float(target))
if before_cov is not None and target is not None
else None
)
after_error = (
abs(float(after_cov) - float(target))
if after_cov is not None and target is not None
else None
)
max_factor = max(factors.values()) if factors else None
min_factor = min(factors.values()) if factors else None
summary = {
"target": target,
"interval_low": nested_get(
payload, "interval", default=[None, None]
)[0],
"interval_high": nested_get(
payload, "interval", default=[None, None]
)[1],
"tol": tol,
"n_horizons": len(factors),
"factors_source": payload.get("factors_source"),
"coverage_before": before_cov,
"coverage_after": after_cov,
"sharpness_before": before_sharp,
"sharpness_after": after_sharp,
"coverage_error_before": before_error,
"coverage_error_after": after_error,
"coverage_error_improved": (
(after_error <= before_error)
if after_error is not None
and before_error is not None
else None
),
"target_reached_after": (
(after_error <= float(tol))
if after_error is not None and tol is not None
else None
),
"max_factor": max_factor,
"min_factor": min_factor,
"has_eval_before": isinstance(
payload.get("eval_before"),
Mapping,
),
"has_eval_after": isinstance(
payload.get("eval_after"),
Mapping,
),
"has_factors": bool(factors),
"skipped": bool(payload.get("skipped", False)),
}
return summary
# ------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------
[docs]
def plot_calibration_factors(
ax: plt.Axes,
stats: CalibrationStatsLike,
*,
title: str = "Calibration factors",
show_grid: bool = True,
grid_kws: dict[str, Any] | None = None,
error: str = "ignore",
**plot_kws: Any,
) -> plt.Axes:
"""Plot per-horizon widening factors."""
fig, ax, _ = prepare_plot(ax=ax)
frame = calibration_stats_factors_frame(stats)
if frame.empty:
_, ax = empty_plot(
fig,
ax,
title=title,
message="No calibration factors",
)
return ax
bar_kws = filter_plot_kwargs(
ax.bar, plot_kws, error=error
)
ax.bar(frame["horizon"], frame["factor"], **bar_kws)
_, ax = finalize_plot(
fig,
ax,
title=title,
xlabel="horizon",
ylabel="factor",
show_grid=show_grid,
grid_kws=grid_kws or {"axis": "y", "alpha": 0.25},
)
return ax
[docs]
def plot_calibration_overall_metrics(
ax: plt.Axes,
stats: CalibrationStatsLike,
*,
title: str = "Calibration summary",
error: str = "ignore",
**plot_kws: Any,
) -> plt.Axes:
"""Plot overall before/after calibration metrics."""
fig, ax, _ = prepare_plot(ax=ax)
frame = calibration_stats_overall_frame(stats)
if frame.empty:
_, ax = empty_plot(
fig,
ax,
title=title,
message="No overall calibration metrics",
)
return ax
plot_frame = pd.DataFrame(
{
"metric": [
"coverage_before",
"coverage_after",
"sharpness_before",
"sharpness_after",
],
"value": [
frame.loc[
frame["which"] == "eval_before",
"coverage",
].iloc[0]
if (frame["which"] == "eval_before").any()
else np.nan,
frame.loc[
frame["which"] == "eval_after",
"coverage",
].iloc[0]
if (frame["which"] == "eval_after").any()
else np.nan,
frame.loc[
frame["which"] == "eval_before",
"sharpness",
].iloc[0]
if (frame["which"] == "eval_before").any()
else np.nan,
frame.loc[
frame["which"] == "eval_after",
"sharpness",
].iloc[0]
if (frame["which"] == "eval_after").any()
else np.nan,
],
}
)
return plot_metric_bars(
ax,
plot_frame,
title=title,
error=error,
**plot_kws,
)
[docs]
def plot_calibration_per_horizon_coverage(
ax: plt.Axes,
stats: CalibrationStatsLike,
*,
which: str = "eval_after",
title: str | None = None,
error: str = "ignore",
**plot_kws: Any,
) -> plt.Axes:
"""Plot per-horizon coverage."""
fig, ax, _ = prepare_plot(ax=ax)
frame = calibration_stats_per_horizon_frame(
stats,
which=which,
)
plot_title = title or f"Coverage by horizon ({which})"
if frame.empty:
_, ax = empty_plot(
fig,
ax,
title=plot_title,
message="No per-horizon coverage",
)
return ax
data = dict(
zip(frame["horizon"], frame["coverage"], strict=False)
)
return plot_series_map(
ax,
data,
title=plot_title,
xlabel="horizon",
ylabel="coverage",
error=error,
**plot_kws,
)
[docs]
def plot_calibration_per_horizon_sharpness(
ax: plt.Axes,
stats: CalibrationStatsLike,
*,
which: str = "eval_after",
title: str | None = None,
error: str = "ignore",
**plot_kws: Any,
) -> plt.Axes:
"""Plot per-horizon sharpness."""
fig, ax, _ = prepare_plot(ax=ax)
frame = calibration_stats_per_horizon_frame(
stats,
which=which,
)
plot_title = title or f"Sharpness by horizon ({which})"
if frame.empty:
_, ax = empty_plot(
fig,
ax,
title=plot_title,
message="No per-horizon sharpness",
)
return ax
data = dict(
zip(
frame["horizon"], frame["sharpness"], strict=False
)
)
return plot_series_map(
ax,
data,
title=plot_title,
xlabel="horizon",
ylabel="sharpness",
error=error,
**plot_kws,
)
[docs]
def plot_calibration_boolean_summary(
ax: plt.Axes,
stats: CalibrationStatsLike,
*,
title: str = "Calibration checks",
error: str = "ignore",
**plot_kws: Any,
) -> plt.Axes:
"""Plot compact boolean checks for calibration status."""
summary = summarize_calibration_stats(stats)
checks = {
"has_eval_before": summary.get("has_eval_before"),
"has_eval_after": summary.get("has_eval_after"),
"has_factors": summary.get("has_factors"),
"coverage_error_improved": summary.get(
"coverage_error_improved"
),
"target_reached_after": summary.get(
"target_reached_after"
),
"not_skipped": not bool(
summary.get("skipped", False)
),
}
return plot_boolean_checks(
ax,
checks,
title=title,
error=error,
**plot_kws,
)
# ------------------------------------------------------------------
# Inspection bundle
# ------------------------------------------------------------------
[docs]
def inspect_calibration_stats(
stats: CalibrationStatsLike,
) -> dict[str, Any]:
"""
Build a compact inspection bundle.
Returns
-------
dict
A dictionary containing the raw payload, a compact
summary, and tidy frames useful for gallery lessons,
notebooks, or debugging.
"""
payload = _as_payload(stats)
summary = summarize_calibration_stats(payload)
return {
"payload": payload,
"summary": summary,
"overall": calibration_stats_overall_frame(payload),
"factors": calibration_stats_factors_frame(payload),
"per_horizon_before": calibration_stats_per_horizon_frame(
payload,
which="eval_before",
),
"per_horizon_after": calibration_stats_per_horizon_frame(
payload,
which="eval_after",
),
}