# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
# website:https://lkouadio.com
r"""NAT workflow helpers for configs, datasets, and scaling."""
from __future__ import annotations
import datetime as dt
import hashlib
import importlib.util
import json
import os
from collections.abc import Mapping
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
import joblib
import numpy as np
# --- Optional TensorFlow import for GeoPrior helpers -----------------------
try: # pragma: no cover - defensive import
import tensorflow as tf # noqa
from tensorflow.keras.optimizers import Adam
TF_AVAILABLE = True
except Exception: # pragma: no cover
TF_AVAILABLE = False
tf = None # type: ignore[assignment]
class _AdamStub:
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ImportError(
"TensorFlow is required for NATCOM GeoPrior helpers "
"(e.g. compile_geoprior_for_eval). Please install "
"`tensorflow>=2.12`."
)
Adam = _AdamStub # type: ignore[assignment]
# ---------------------------------------------------------------------
# Optional TensorFlow typing support
# ---------------------------------------------------------------------
# We avoid importing TensorFlow at runtime from this module to keep it
# lightweight (useful for tooling / docs environments). For type checkers
# and IDEs, we expose a tf name under TYPE_CHECKING.
#
# Use string annotations like "tf.data.Dataset" and "tf.Tensor" so that
# runtime does not need TensorFlow to be installed.
if TYPE_CHECKING: # pragma: no cover
import tensorflow as tf # noqa: F401
# Shared error message used by helpers that need TensorFlow.
TF_IMPORT_ERROR_MSG = (
"geoprior.utils.nat_utils: TensorFlow is required for this helper "
"but could not be imported. Install `tensorflow` to use functions "
"that construct or consume `tf.data.Dataset` objects."
)
# -------------------------------------------------------------------
# Internal path helpers
# -------------------------------------------------------------------
def _project_root() -> str:
"""
Return the root directory of the `geoprior-learn` repository.
This is computed relative to this file:
geoprior-learn/
geoprior/
utils/
nat_utils/nat_utils.py
nat.com/
config.py
"""
here = os.path.abspath(__file__)
utils_dir = os.path.dirname(os.path.dirname(here))
fusionlab_dir = os.path.dirname(utils_dir)
root = os.path.dirname(fusionlab_dir)
return root
[docs]
def get_natcom_dir(root="nat.com") -> str:
"""
Directory containing NATCOM scripts and configuration,
typically `<repo_root>/nat.com`.
"""
return os.path.join(_project_root(), root)
[docs]
def get_config_paths(root="nat.com") -> tuple[str, str]:
"""
Return `(config_py_path, config_json_path)` for NATCOM.
"""
nat_dir = get_natcom_dir(root=root)
config_py = os.path.join(nat_dir, "config.py")
config_json = os.path.join(nat_dir, "config.json")
return config_py, config_json
def get_default_runs_root(
root: str = "nat.com",
runs_dir_name: str = ".fusionlab_runs",
) -> str:
"""
Return the base directory for GUI run artifacts.
The default is ``<project_root>/.fusionlab_runs`` where
``<project_root>`` is the same root inferred by
:func:`_project_root`.
This is *only* a convenience helper; CLI scripts keep
using their own defaults (usually ``<cwd>/results``).
The GUI overrides ``BASE_OUTPUT_DIR`` with this path so
GUI runs do not mix with CLI results.
"""
proj_root = os.path.dirname(get_natcom_dir(root=root))
runs_root = os.path.join(proj_root, runs_dir_name)
os.makedirs(runs_root, exist_ok=True)
return runs_root
# -------------------------------------------------------------------
# Low-level helpers
# -------------------------------------------------------------------
def _hash_file(path: str) -> str:
"""
Compute a SHA-256 hash of the file at `path`.
"""
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
if not chunk:
break
h.update(chunk)
return h.hexdigest()
def _import_config_module(config_py: str):
"""
Import `config.py` by absolute path, without assuming it is
on `sys.path`.
"""
if not os.path.exists(config_py):
raise FileNotFoundError(
f"NATCOM config.py not found at: {config_py}"
)
spec = importlib.util.spec_from_file_location(
"nat_config", config_py
)
if spec is None or spec.loader is None:
raise ImportError(
f"Could not load spec for {config_py!r}"
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore[attr-defined]
return module
def _is_basic_jsonable(value: Any) -> bool:
"""
Return True if the value is a simple JSON-serialisable type.
"""
return isinstance(
value, int | float | str | bool | list | dict
)
def _extract_config_dict(module) -> dict[str, Any]:
"""
Extract a flat configuration dictionary from the `config`
module by selecting suitable global variables.
- Keys starting with '_' are ignored.
- Functions, classes and modules are ignored.
- Only basic JSON-like values are kept.
Environment variables (CITY, MODEL_NAME_OVERRIDE,
JUPYTER_PROJECT_ROOT) can override some keys.
"""
cfg: dict[str, Any] = {}
for name, value in vars(module).items():
if name.startswith("_"):
continue
if callable(value):
continue
if isinstance(value, type):
continue
if _is_basic_jsonable(value):
cfg[name] = value
# Build a compact "censoring" block for Stage-2 scripts if
# it is not already present.
if "CENSORING_SPECS" in cfg and "censoring" not in cfg:
censor_block = {
"specs": cfg["CENSORING_SPECS"],
"use_effective_h_field": cfg.get(
"USE_EFFECTIVE_H_FIELD", True
),
"include_flags_as_dynamic": cfg.get(
"INCLUDE_CENSOR_FLAGS_AS_DYNAMIC", True
),
}
cfg["censoring"] = censor_block
# Optional environment overrides (advanced use).
city_env = os.getenv("CITY", "").strip()
if city_env:
cfg["CITY_NAME"] = city_env.lower()
model_env = os.getenv("MODEL_NAME_OVERRIDE", "").strip()
if model_env:
cfg["MODEL_NAME"] = model_env
root_env = os.getenv("JUPYTER_PROJECT_ROOT", "").strip()
if root_env:
cfg["DATA_DIR"] = root_env
return cfg
# -------------------------------------------------------------------
# Public API
# -------------------------------------------------------------------
def _refresh_city_files(cfg: dict[str, Any]) -> None:
city = str(cfg.get("CITY_NAME", "")).strip().lower()
var = str(cfg.get("DATASET_VARIANT", "")).strip()
btmp = cfg.get("BIG_FN_TEMPLATE", None)
stmp = cfg.get("SMALL_FN_TEMPLATE", None)
if city and var and isinstance(btmp, str):
cfg["BIG_FN"] = btmp.format(city=city, variant=var)
if city and var and isinstance(stmp, str):
cfg["SMALL_FN"] = stmp.format(city=city, variant=var)
def _apply_env_overrides(
cfg: dict[str, Any],
) -> dict[str, Any]:
changed = False
city_env = os.getenv("CITY", "").strip()
if city_env:
cfg["CITY_NAME"] = city_env.lower()
changed = True
model_env = os.getenv("MODEL_NAME_OVERRIDE", "").strip()
if model_env:
cfg["MODEL_NAME"] = model_env
changed = True
root_env = os.getenv("JUPYTER_PROJECT_ROOT", "").strip()
if root_env:
cfg["DATA_DIR"] = root_env
changed = True
if changed:
_refresh_city_files(cfg)
return cfg
[docs]
def ensure_config_json(
root: str = "nat.com",
) -> tuple[dict[str, Any], str]:
"""
Ensure that `nat.com/config.json` exists and is consistent
with `nat.com/config.py`.
Returns
-------
config : dict
The configuration dictionary (`payload["config"]`).
json_path : str
Absolute path to `config.json`.
Behaviour
---------
- If `config.json` does not exist, it is created from
`config.py`.
- If it exists but the SHA-256 hash of `config.py` has
changed, it is regenerated.
- Otherwise the existing JSON file is reused.
"""
config_py, config_json = get_config_paths(root=root)
py_hash = _hash_file(config_py)
payload: dict[str, Any] | None = None
if os.path.exists(config_json):
try:
with open(config_json, encoding="utf-8") as f:
payload = json.load(f)
except Exception:
payload = None
meta = (
payload.get("__meta__", {})
if isinstance(payload, dict)
else {}
)
if (
isinstance(payload, dict)
and meta.get("config_py_hash") == py_hash
and "config" in payload
):
cfg = dict(payload["config"])
cfg = _apply_env_overrides(cfg)
return cfg, config_json
module = _import_config_module(config_py)
cfg = _extract_config_dict(module)
cfg = _apply_env_overrides(cfg)
payload = {
"city": cfg.get("CITY_NAME"),
"model": cfg.get("MODEL_NAME"),
"config": cfg,
"__meta__": {"config_py_hash": py_hash},
}
os.makedirs(os.path.dirname(config_json), exist_ok=True)
with open(config_json, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
return cfg, config_json
[docs]
def load_nat_config(root="nat.com") -> dict[str, Any]:
"""
High-level helper used by NATCOM scripts.
Example
-------
>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]
"""
cfg, _ = ensure_config_json(root=root)
return cfg
[docs]
def load_nat_config_payload(root="nat.com") -> dict[str, Any]:
"""
Return the full `config.json` payload, including `city`,
`model` and `__meta__` fields.
This is convenient when you also want to see which hash or
city/model are currently active.
"""
config_py, config_json = get_config_paths(root=root)
if not os.path.exists(config_json):
ensure_config_json(root=root)
with open(config_json, encoding="utf-8") as f:
return json.load(f)
def _as_float1(x):
if x is None:
return None
arr = np.asarray(x).reshape(-1)
return float(arr[0])
def affine_from_scaler(scaler, idx: int = 0):
if hasattr(scaler, "data_min_") and hasattr(
scaler, "data_max_"
):
data_min = np.asarray(scaler.data_min_).reshape(-1)
data_max = np.asarray(scaler.data_max_).reshape(-1)
scale = float((data_max - data_min)[idx])
bias = float(data_min[idx])
return scale, bias
if hasattr(scaler, "scale_") and hasattr(scaler, "mean_"):
sc = np.asarray(scaler.scale_).reshape(-1)
mu = np.asarray(scaler.mean_).reshape(-1)
return float(sc[idx]), float(mu[idx])
if hasattr(scaler, "scale_") and hasattr(
scaler, "center_"
):
sc = np.asarray(scaler.scale_).reshape(-1)
ce = np.asarray(scaler.center_).reshape(-1)
return float(sc[idx]), float(ce[idx])
raise TypeError(
f"Unsupported scaler type for affine inference: {type(scaler)}"
)
[docs]
def resolve_si_affine(
cfg: dict,
scaler_info: dict,
*,
target_name: str,
prefix: str, # "SUBS" or "HEAD"
unit_factor_key: str, # "SUBS_UNIT_TO_SI" or "HEAD_UNIT_TO_SI"
scale_key: str, # "SUBS_SCALE_SI" / "HEAD_SCALE_SI"
bias_key: str, # "SUBS_BIAS_SI" / "HEAD_BIAS_SI"
):
# 1) explicit overrides win
scale = cfg.get(scale_key, None)
bias = cfg.get(bias_key, None)
unit_factor = float(cfg.get(unit_factor_key, 1.0))
auto = bool(cfg.get("AUTO_SI_AFFINE_FROM_STAGE1", True))
if (scale is None or bias is None) and auto:
info = scaler_info.get(target_name) or {}
idx = int(info.get("idx", 0))
scaler = info.get("scaler")
if scaler is None and "scaler_path" in info:
# load happens elsewhere in your code; keep it simple here
scaler = info.get("scaler")
if scaler is None:
raise RuntimeError(
f"[{prefix}] Cannot infer SI affine: scaler for target "
f"{target_name!r} not found in scaler_info."
)
s, b = affine_from_scaler(scaler, idx=idx)
if scale is None:
scale = s
if bias is None:
bias = b
# 2) apply unit conversion into the affine
# y_SI = (y_scaled*scale + bias) * unit_factor
scale_si = float(scale) * unit_factor
bias_si = float(bias) * unit_factor
return scale_si, bias_si
# -------------------------------------------------------------------------
# NATCOM training helpers
# -------------------------------------------------------------------------
[docs]
def map_targets_for_training(
y_dict: dict,
subs_key: str = "subsidence",
gwl_key: str = "gwl",
subs_pred_key: str = "subs_pred",
gwl_pred_key: str = "gwl_pred",
) -> dict:
"""
Standardise target dictionaries to the Keras compile keys.
This helper enforces a small convention used throughout the
NATCOM training scripts:
- Upstream sequence builders typically export raw targets with
keys ``subsidence`` and ``gwl``.
- The GeoPrior model is compiled with targets named
``subs_pred`` and ``gwl_pred``.
This function accepts either style and always returns a dict
keyed by ``subs_pred`` and ``gwl_pred`` for use in Keras.
Parameters
----------
y_dict : dict
Dictionary produced by the Stage-1 sequence exporter or by
a previous training script. Must contain either
(``subsidence``, ``gwl``) or (``subs_pred``, ``gwl_pred``).
subs_key : str, default="subsidence"
Name of the raw subsidence key in ``y_dict``.
gwl_key : str, default="gwl"
Name of the raw groundwater-level key in ``y_dict``.
subs_pred_key : str, default="subs_pred"
Standardised key for the subsidence prediction target.
gwl_pred_key : str, default="gwl_pred"
Standardised key for the GWL prediction target.
Returns
-------
dict
New dictionary with keys ``subs_pred`` and ``gwl_pred``.
Raises
------
KeyError
If the dictionary does not contain either of the expected
key pairs.
"""
# Case 1: raw keys from Stage-1 exporter.
if subs_key in y_dict and gwl_key in y_dict:
return {
subs_pred_key: y_dict[subs_key],
gwl_pred_key: y_dict[gwl_key],
}
# Case 2: already in compiled form.
if subs_pred_key in y_dict and gwl_pred_key in y_dict:
return y_dict
# Anything else is considered an error – we fail loudly so the
# user can fix the pipeline rather than train on the wrong data.
raise KeyError(
f"Targets must contain ({subs_key!r},{gwl_key!r}) or "
f"({subs_pred_key!r},{gwl_pred_key!r})."
)
def _np_nonfinite_report(
arr: np.ndarray,
max_items: int = 5,
) -> dict[str, Any] | None:
"""Return a small report if arr has NaN/Inf."""
bad = ~np.isfinite(arr)
n_bad = int(bad.sum())
if n_bad == 0:
return None
idx = np.argwhere(bad)
head = idx[:max_items]
samples: list[tuple[tuple[int, ...], Any]] = []
for ii in head:
t = tuple(ii.tolist())
samples.append((t, arr[t]))
finite = np.isfinite(arr)
if np.any(finite):
min_f = float(np.nanmin(arr[finite]))
max_f = float(np.nanmax(arr[finite]))
else:
min_f = None
max_f = None
return {
"shape": tuple(arr.shape),
"dtype": str(arr.dtype),
"n_nonfinite": n_bad,
"first_bad": samples,
"min_finite": min_f,
"max_finite": max_f,
}
def check_npz_dict_finite(
d: dict,
name: str,
feature_names_last_dim: list[str] | None = None,
max_bad_channels: int = 30,
max_print: int = 20,
) -> None:
"""
Validate that all numeric arrays in a dict
contain only finite values.
If `feature_names_last_dim` is provided and
matches v.shape[-1], a per-channel report
is added (helpful for dyn/fut features).
"""
problems: list[tuple[str, dict[str, Any]]] = []
for k, v in d.items():
if not isinstance(v, np.ndarray):
continue
if not np.issubdtype(v.dtype, np.number):
continue
rep = _np_nonfinite_report(v)
if rep is None:
continue
if (
feature_names_last_dim
and v.ndim >= 2
and v.shape[-1] == len(feature_names_last_dim)
):
bad = ~np.isfinite(v)
per_ch = bad.reshape(-1, v.shape[-1]).sum(axis=0)
bad_idx = np.where(per_ch > 0)[0].tolist()
rep["bad_channels"] = [
{
"index": int(i),
"name": feature_names_last_dim[i],
"n_nonfinite": int(per_ch[i]),
}
for i in bad_idx[:max_bad_channels]
]
problems.append((str(k), rep))
if not problems:
print(f"[OK] {name}: all numeric arrays finite.")
return
print(f"\n[NaN/Inf] Non-finite values in {name}:")
for k, rep in problems[:max_print]:
sh = rep.get("shape")
# dt = rep.get("dtype")
nb = rep.get("n_nonfinite")
print(f" - key={k!r} shape={sh} dtype={dt}")
print(f" nonfinite={nb}")
if "bad_channels" in rep:
print(" bad channels:")
for ch in rep["bad_channels"]:
i = ch["index"]
nm = ch["name"]
nn = ch["n_nonfinite"]
print(f" * {i:>3} {nm:<32} n={nn}")
fb = rep.get("first_bad", [])
print(f" first bad: {fb[:5]}")
raise RuntimeError(
f"Stopping: {name} contains NaN/Inf. "
"Fix Stage-1 export or cleaning."
)
def scan_tf_dataset_finite(
ds: Any,
name: str,
max_batches: int = 200,
) -> None:
"""
Eager scan of first N batches.
Useful to fail *before* model.fit().
"""
try:
import tensorflow as tf # type: ignore
except Exception as exc: # pragma: no cover
raise ImportError(TF_IMPORT_ERROR_MSG) from exc
for b, (xb, yb) in enumerate(ds):
if max_batches is not None and b >= max_batches:
break
for k, v in xb.items():
if v.dtype.is_floating or v.dtype.is_complex:
tf.debugging.assert_all_finite(
v,
f"{name}: X[{k}] NaN/Inf at batch {b}",
)
for k, v in yb.items():
if v.dtype.is_floating or v.dtype.is_complex:
tf.debugging.assert_all_finite(
v,
f"{name}: y[{k}] NaN/Inf at batch {b}",
)
print(f"[OK] {name}: first {max_batches} batches ok.")
[docs]
def make_tf_dataset(
X_np: dict,
y_np: dict,
batch_size: int,
shuffle: bool,
mode: str,
forecast_horizon: int,
*,
seed: int = 42,
drop_remainder: bool = False,
reshuffle_each_iter: bool = True,
prefetch: bool = True,
check_npz_finite: bool = False,
check_finite: bool = False,
scan_finite_batches: int = 0,
dynamic_feature_names: list[str] | None = None,
future_feature_names: list[str] | None = None,
) -> Any:
"""
Build a `tf.data.Dataset` using NATCOM
conventions.
Steps:
1) ensure_input_shapes(...) for X.
2) map_targets_for_training(...) for y.
3) tf.data pipeline (shuffle/batch/prefetch).
4) optional finite checks (NPZ + tf batches).
Parameters
----------
X_np : dict
Input dictionary, typically obtained from ``np.load`` on
the Stage-1 ``*_inputs_npz`` file.
y_np : dict
Target dictionary, typically obtained from ``np.load`` on
the Stage-1 ``*_targets_npz`` file.
batch_size : int
Number of samples per batch.
shuffle : bool
If ``True``, shuffle the dataset using a fixed seed for
reproducibility.
mode : str
Model mode passed to :func:`ensure_input_shapes`.
forecast_horizon : int
Forecast horizon passed to :func:`ensure_input_shapes`.
check_npz_finite : bool
If True, checks Xin/Yin numpy arrays
for NaN/Inf before building ds.
check_finite : bool
If True, inserts `assert_all_finite`
checks inside the tf.data pipeline.
scan_finite_batches : int
If >0, eagerly scans first N batches
right away (fails early).
dynamic_feature_names, future_feature_names
If provided, used to report bad
channels for feature tensors.
Returns
-------
tf.data.Dataset
Dataset of (X, y) pairs.
Notes
-----
TensorFlow is imported lazily inside the function so that
this module remains importable in environments where TF is
not installed (for example, for tooling or static analysis).
"""
try:
import tensorflow as tf # type: ignore
except Exception as exc: # pragma: no cover
raise ImportError(TF_IMPORT_ERROR_MSG) from exc
# Normalize inputs/targets to canonical keys.
Xin = ensure_input_shapes(
X_np,
mode=mode,
forecast_horizon=forecast_horizon,
)
Yin = map_targets_for_training(y_np)
# Optional: stop early if NPZ content is bad.
if check_npz_finite:
check_npz_dict_finite(Xin, "X_np (Xin)")
check_npz_dict_finite(Yin, "y_np (Yin)")
if (
dynamic_feature_names
and "dynamic_features" in Xin
):
check_npz_dict_finite(
{"dynamic_features": Xin["dynamic_features"]},
"Xin.dynamic_features",
feature_names_last_dim=dynamic_feature_names,
)
if future_feature_names and "future_features" in Xin:
check_npz_dict_finite(
{"future_features": Xin["future_features"]},
"Xin.future_features",
feature_names_last_dim=future_feature_names,
)
# Build dataset.
ds = tf.data.Dataset.from_tensor_slices((Xin, Yin))
# Shuffle with a stable seed.
if shuffle:
# Prefer dynamic_features for size.
if "dynamic_features" in Xin:
n = int(Xin["dynamic_features"].shape[0])
else:
# Fallback: first array in Xin.
first = next(iter(Xin.values()))
n = int(first.shape[0])
ds = ds.shuffle(
buffer_size=max(1, n),
seed=seed,
reshuffle_each_iteration=reshuffle_each_iter,
)
ds = ds.batch(
batch_size,
drop_remainder=drop_remainder,
)
# Optional: add assert_all_finite into pipeline.
if check_finite:
def _assert_batch(xb, yb):
# Assert only float/complex tensors.
for k, v in xb.items():
if v.dtype.is_floating or v.dtype.is_complex:
tf.debugging.assert_all_finite(
v,
f"X[{k}] has NaN/Inf",
)
for k, v in yb.items():
if v.dtype.is_floating or v.dtype.is_complex:
tf.debugging.assert_all_finite(
v,
f"y[{k}] has NaN/Inf",
)
return xb, yb
ds = ds.map(
_assert_batch,
num_parallel_calls=tf.data.AUTOTUNE,
)
if prefetch:
ds = ds.prefetch(tf.data.AUTOTUNE)
# Optional: force an eager scan now.
# This is what stops *before* model.fit().
if scan_finite_batches and scan_finite_batches > 0:
scan_tf_dataset_finite(
ds,
name="make_tf_dataset",
max_batches=int(scan_finite_batches),
)
return ds
[docs]
def load_scaler_info(encoders_block: dict) -> dict | None:
"""
Load the ``scaler_info`` mapping from an encoders block.
Stage-1 exporters typically store a compact description of the
scalers used to normalise the data. In many cases this takes
the form:
.. code-block:: python
encoders = {
"main_scaler": "/path/to/minmax.joblib",
"coord_scaler": "/path/to/coords.joblib",
"scaler_info": "/path/to/scaler_info.joblib",
...
}
where ``scaler_info`` is either a path to a joblib file or an
already-loaded dictionary.
This helper returns a dictionary regardless of how it was
stored, making downstream formatting/evaluation code simpler.
Parameters
----------
encoders_block : dict
The ``encoders`` part of the Stage-1 manifest
(``M["artifacts"]["encoders"]``).
Returns
-------
dict or None
The loaded ``scaler_info`` dictionary, or ``None`` if not
present / not loadable.
"""
si = encoders_block.get("scaler_info")
if isinstance(si, str) and os.path.exists(si):
try:
return joblib.load(si)
except Exception:
# If loading fails we fall back to the raw string; the
# caller can decide how to proceed.
pass
return si
[docs]
def build_censor_mask(
xb: dict,
H,
idx: int | None,
thresh: float = 0.5,
*,
source: str = "dynamic", # {"dynamic", "future"}
reduce_time: str = "any", # {"any", "last", "all"}
align: str = "broadcast", # {"broadcast", "crop", "pad_false", "pad_edge", "error"}
) -> tf.Tensor:
"""
Build a censor mask aligned to the forecast horizon: (B, H, 1).
Parameters
----------
source : {"dynamic", "future"}, default="dynamic"
Selects where the censoring flag is read from. ``"dynamic"``
reads ``xb["dynamic_features"][:, :, idx]`` from the history
window, while ``"future"`` reads
``xb["future_features"][:, :, idx]`` from the forecast window.
reduce_time : {"any", "last", "all"}, default="any"
Reduction applied when ``source="dynamic"`` and the censor flag
behaves like a per-sample label. ``"any"`` marks the sample as
censored if any history step is flagged, ``"last"`` uses only
the last history step, and ``"all"`` requires every history step
to be flagged.
align : {"broadcast", "crop", "pad_false", "pad_edge", "error"}, default="broadcast"
Policy used when the time axis does not already match the
forecast horizon ``H``. ``"broadcast"`` repeats a single-step
label across all horizon steps, ``"crop"`` keeps the last ``H``
steps, ``"pad_false"`` pads missing steps with ``False``,
``"pad_edge"`` repeats the last available step, and ``"error"``
raises on mismatch.
"""
try:
import tensorflow as tf # type: ignore
except Exception as exc: # pragma: no cover
raise ImportError(TF_IMPORT_ERROR_MSG) from exc
# Resolve batch size
if "coords" in xb:
B = tf.shape(xb["coords"])[0]
else:
B = tf.shape(
xb.get(
"dynamic_features", xb.get("future_features")
)
)[0]
H = tf.cast(H, tf.int32)
if idx is None:
return tf.zeros((B, H, 1), dtype=tf.bool)
key = (
"dynamic_features"
if source == "dynamic"
else "future_features"
)
feat = xb.get(key, None)
if feat is None:
return tf.zeros((B, H, 1), dtype=tf.bool)
nfeat = tf.shape(feat)[-1]
def _all_false():
return tf.zeros((B, H, 1), dtype=tf.bool)
def _align_time(m): # m: (B, T, 1)
T = tf.shape(m)[1]
def _broadcast_from_one(step):
return tf.tile(step, [1, H, 1])
if align == "error":
tf.debugging.assert_equal(
T, H, message=f"{key} length != H"
)
return m
if align == "crop":
# if T < H, this cannot increase length -> would still mismatch
return tf.cond(
T >= H, lambda: m[:, -H:, :], lambda: m
)
if align == "pad_false":
# pad at front so last steps line up
pad = tf.maximum(H - T, 0)
m2 = tf.pad(
m,
paddings=[[0, 0], [pad, 0], [0, 0]],
constant_values=False,
)
return m2[:, -H:, :]
if align == "pad_edge":
pad = tf.maximum(H - T, 0)
last = m[:, -1:, :]
m2 = tf.concat(
[tf.tile(last, [1, pad, 1]), m], axis=1
)
return m2[:, -H:, :]
# default: "broadcast"
# If already H, keep it; else broadcast a single-step summary.
return tf.cond(
tf.equal(T, H),
lambda: m,
lambda: _broadcast_from_one(m[:, -1:, :]),
)
def _build():
m = feat[..., idx : idx + 1] > thresh # (B, T, 1)
# If source is dynamic, we usually want a sample-level censor label.
if source == "dynamic":
if reduce_time == "any":
one = tf.reduce_any(
m, axis=1, keepdims=True
) # (B,1,1)
return tf.tile(one, [1, H, 1]) # (B,H,1)
if reduce_time == "all":
one = tf.reduce_all(m, axis=1, keepdims=True)
return tf.tile(one, [1, H, 1])
if reduce_time == "last":
one = m[:, -1:, :]
return tf.tile(one, [1, H, 1])
# Otherwise, align time dimension to horizon (future usually already matches)
return _align_time(m)
return tf.cond(tf.less(idx, nfeat), _build, _all_false)
def build_censor_mask_from_dynamic(
xb: dict,
H: int,
dyn_idx: int | None,
thresh: float = 0.5,
) -> tf.Tensor:
"""
Build a boolean censoring mask from the dynamic features.
This is used to stratify metrics by censored/uncensored cells
based on a flag stored in ``dynamic_features[..., dyn_idx]``.
The function:
- Looks up ``dynamic_features`` from the input batch.
- Applies a threshold on the selected feature column to build
a mask of shape ``(B, T_dyn, 1)``.
- If the dynamic time length differs from ``H``, it takes the
last ``H`` steps (consistent with the forecasting horizon).
- If no dynamic features or index are available, returns an
all-False mask of shape ``(B, H, 1)``.
If the censor flag only exists on the history window (T_dyn=TIME_STEPS),
but the evaluation is done on the forecast horizon (H=FORECAST_HORIZON),
we must broadcast/pad because slicing cannot increase length.
Parameters
----------
xb : dict
Batch input dictionary from a ``tf.data.Dataset`` with
at least ``"dynamic_features"`` and ``"coords"``.
H : int
Horizon length for the evaluation (number of time steps).
dyn_idx : int or None
Index of the censor flag within ``dynamic_features``.
If ``None``, returns an all-False mask.
thresh : float, default=0.5
Threshold above which a value is considered "censored".
Returns
-------
tf.Tensor
Boolean mask of shape ``(B, H, 1)`` where True indicates
censored samples.
"""
try:
import tensorflow as tf # type: ignore
except Exception as exc: # pragma: no cover
raise ImportError(TF_IMPORT_ERROR_MSG) from exc
# Resolve B
if "coords" in xb:
B = tf.shape(xb["coords"])[0]
else:
dyn0 = xb.get("dynamic_features", None)
B = tf.shape(dyn0)[0] if dyn0 is not None else 0
H = tf.cast(H, tf.int32)
# No flag → no censoring
dyn = xb.get("dynamic_features", None)
if dyn is None or dyn_idx is None:
return tf.zeros((B, H, 1), dtype=tf.bool)
# Defensive: dyn_idx range (works even if dyn.shape[-1] is None)
nfeat = tf.shape(dyn)[-1]
def _all_false():
return tf.zeros((B, H, 1), dtype=tf.bool)
def _build():
# (B, T_dyn, 1)
m_dyn = dyn[..., dyn_idx : dyn_idx + 1] > thresh
T_dyn = tf.shape(m_dyn)[1]
# Case 1: exact match
def _same():
return m_dyn
# Case 2: history longer than horizon → take last H
def _crop():
return m_dyn[:, -H:, :]
# Case 3: history shorter than horizon → broadcast last observed flag
def _broadcast_last():
last = m_dyn[:, -1:, :] # (B,1,1)
return tf.tile(last, [1, H, 1]) # (B,H,1)
return tf.case(
[
(tf.equal(T_dyn, H), _same),
(tf.greater(T_dyn, H), _crop),
(tf.less(T_dyn, H), _broadcast_last),
],
default=_broadcast_last,
exclusive=True,
)
return tf.cond(
tf.less(dyn_idx, nfeat), _build, _all_false
)
# -------------------------------------------------------------------------
# Public helpers for Stage-1/Stage-2 NPZ handling and tuned model recovery
# -------------------------------------------------------------------------
[docs]
def pick_npz_for_dataset(
manifest: dict,
split: str,
) -> tuple[dict | None, dict | None]:
"""
Load (inputs, targets) NPZ arrays for a given dataset split.
This is a public, reusable version of the internal helper that
was previously named ``_pick_npz_for_dataset``.
Parameters
----------
manifest : dict
Stage-1 manifest dictionary with the structure::
manifest["artifacts"]["numpy"] = {
"train_inputs_npz": ...,
"train_targets_npz": ...,
"val_inputs_npz": ...,
"val_targets_npz": ...,
"test_inputs_npz": ... (optional),
"test_targets_npz": ... (optional),
}
split : {"train", "val", "test"}
Which dataset to load.
Returns
-------
X : dict or None
Dictionary of input arrays for the requested split, or ``None``
if the split is unavailable (e.g. test NPZ missing).
y : dict or None
Dictionary of target arrays for the requested split, or ``None``
if targets are unavailable.
Raises
------
KeyError
If the manifest does not contain the expected NPZ entries.
ValueError
If ``split`` is not one of ``{"train", "val", "test"}``.
"""
npzs = manifest.get("artifacts", {}).get("numpy", None)
if npzs is None:
raise KeyError(
"Manifest is missing 'artifacts[\"numpy\"]' section with NPZ paths."
)
if split == "train":
x = dict(np.load(npzs["train_inputs_npz"]))
y = dict(np.load(npzs["train_targets_npz"]))
return x, y
if split == "val":
x = dict(np.load(npzs["val_inputs_npz"]))
y = dict(np.load(npzs["val_targets_npz"]))
return x, y
if split == "test":
tin = npzs.get("test_inputs_npz")
tt = npzs.get("test_targets_npz")
if not tin:
# No test split available for this run
return None, None
x = dict(np.load(tin))
y = dict(np.load(tt)) if tt else None
return x, y
raise ValueError(
"split must be one of {'train', 'val', 'test'}."
)
def infer_input_dims_from_X(X: dict) -> tuple[int, int, int]:
"""
Infer (static_input_dim, dynamic_input_dim, future_input_dim) from NPZ inputs.
This is a public, defensive version of the former
``_infer_input_dims_from_X`` helper.
Parameters
----------
X : dict
Dictionary with keys:
- ``'dynamic_features'`` (required, shape (N, T, D_dyn))
- ``'static_features'`` (optional, shape (N, D_static) or None)
- ``'future_features'`` (optional, shape (N, T_future, D_future) or None)
Returns
-------
static_dim : int
Last-dimension size of ``static_features`` (0 if missing or None).
dynamic_dim : int
Last-dimension size of ``dynamic_features``. Raises if missing.
future_dim : int
Last-dimension size of ``future_features`` (0 if missing or None).
Raises
------
KeyError
If ``'dynamic_features'`` is missing in ``X``.
"""
if "dynamic_features" not in X:
raise KeyError(
"X must contain key 'dynamic_features' with shape (N, T, D_dyn)."
)
dyn = np.asarray(X["dynamic_features"])
dynamic_dim = int(dyn.shape[-1])
static = X.get("static_features", None)
static_dim = (
int(np.asarray(static).shape[-1])
if static is not None
else 0
)
fut = X.get("future_features", None)
future_dim = (
int(np.asarray(fut).shape[-1])
if fut is not None
else 0
)
return static_dim, dynamic_dim, future_dim
def _npz_to_dict(path: Path) -> dict[str, np.ndarray]:
path = Path(path)
with np.load(str(path), allow_pickle=False) as z:
return {k: z[k] for k in z.files}
def _strip_prefix(
name: str, prefixes: tuple[str, ...]
) -> str:
low = name.lower()
for p in prefixes:
if low.startswith(p):
return name[len(p) :]
return name
def _split_bundle_npz(
data: Mapping[str, np.ndarray],
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
x: dict[str, np.ndarray] = {}
y: dict[str, np.ndarray] = {}
for k, v in data.items():
lk = k.lower()
if lk in ("subs_pred", "gwl_pred"):
y[k] = v
continue
if lk.startswith(("y_", "y.", "target")):
y[k] = v
continue
if lk.startswith(("x_", "x.", "input")):
x[k] = v
continue
x[k] = v
x2: dict[str, np.ndarray] = {}
y2: dict[str, np.ndarray] = {}
for k, v in x.items():
nk = _strip_prefix(
k,
("x_", "x.", "input_", "inputs_"),
)
x2[nk] = v
for k, v in y.items():
nk = _strip_prefix(
k,
("y_", "y.", "target_", "targets_"),
)
y2[nk] = v
return x2, y2
def _infer_targets_path(inputs_path: Path) -> Path:
p = Path(inputs_path)
name = p.name
repls = (
("_inputs", "_targets"),
("inputs", "targets"),
("_input", "_target"),
("input", "target"),
)
for a, b in repls:
if a in name:
cand = p.with_name(name.replace(a, b))
if cand.exists():
return cand
raise FileNotFoundError(
"Could not infer targets NPZ from inputs NPZ:\n"
f" inputs: {str(p)}\n"
"Pass a mapping {'inputs':..., 'targets':...} "
"to load_windows_npz()."
)
[docs]
def load_windows_npz(
path: str | Path | Mapping[str, str],
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""
Load Stage-1 windows as (x, y).
Supported:
- Bundle NPZ (contains inputs+targets in one file).
- Mapping {'inputs': <npz>, 'targets': <npz>}.
- Inputs NPZ only (targets inferred by filename).
- Directory containing inputs/targets NPZ.
Returns
-------
x : dict[str, np.ndarray]
Inputs (e.g., static_features, dynamic_features, etc.)
y : dict[str, np.ndarray]
Targets (e.g., subs_pred, gwl_pred)
"""
if isinstance(path, Mapping):
ip = Path(path["inputs"])
tp = Path(path["targets"])
return _npz_to_dict(ip), _npz_to_dict(tp)
p = Path(path)
if p.is_dir():
in_cands = (
"inputs.npz",
"train_inputs.npz",
"X_inputs.npz",
)
tg_cands = (
"targets.npz",
"train_targets.npz",
"y_targets.npz",
)
ip = None
tp = None
for n in in_cands:
cand = p / n
if cand.exists():
ip = cand
break
for n in tg_cands:
cand = p / n
if cand.exists():
tp = cand
break
if ip is None or tp is None:
raise FileNotFoundError(
"Directory does not contain recognizable "
"inputs/targets NPZ files:\n"
f" dir: {str(p)}"
)
return _npz_to_dict(ip), _npz_to_dict(tp)
if not p.exists():
raise FileNotFoundError(f"Missing NPZ: {str(p)}")
d = _npz_to_dict(p)
x, y = _split_bundle_npz(d)
if y:
return x, y
tp = _infer_targets_path(p)
return x, _npz_to_dict(tp)
[docs]
def resolve_hybrid_config(
manifest_cfg: dict, live_cfg: dict, verbose: bool = True
) -> dict:
"""
Merge Manifest config (Data Authority) with Live config (Physics Authority).
Strategy
--------
1. Base: Start with Manifest config. This guarantees that data shapes,
time steps, features, and normalization match the artifacts on disk.
2. Override: Apply specific keys from Live config (config.py) that control
architecture, physics equations, loss weights, and training dynamics.
This allows you to tune the model and physics without re-running Stage 1.
Parameters
----------
manifest_cfg : dict
Configuration dictionary loaded from `manifest.json`.
live_cfg : dict
Configuration dictionary loaded from the current `config.py`.
Returns
-------
dict
Merged configuration.
"""
# 1. Start with Manifest (Data Wins)
merged = manifest_cfg.copy()
# 2. Define "Safe" keys that Stage 2 is allowed to override.
# (Everything that does NOT affect input data shapes or target columns)
OVERRIDABLE_KEYS = {
# track City Name , so we can switch to change city as well
"CITY_NAME",
"CITY",
"MODEL_NAME",
"USE_IN_MEMORY_MODEL",
"DEBUG",
"USE_TF_SAVEDMODEL",
"TRACK_AUX_METRICS",
# --- 1. Architecture (Safe to tune if model is rebuilt) ---
"EMBED_DIM",
"HIDDEN_UNITS",
"LSTM_UNITS",
"ATTENTION_UNITS",
"NUMBER_HEADS",
"DROPOUT_RATE",
"MEMORY_SIZE",
"SCALES",
"USE_RESIDUALS",
"USE_BATCH_NORM",
"USE_VSN",
"VSN_UNITS",
"ATTENTION_LEVELS",
# --- 2. Physics Toggles & Math ---
"PDE_MODE_CONFIG",
"SCALE_PDE_RESIDUALS",
"CONSOLIDATION_STEP_RESIDUAL_METHOD",
"ALLOW_SUBS_RESIDUAL",
"OFFSET_MODE",
"PHYSICS_BOUNDS_MODE",
"TIME_UNITS",
# --- 3. Physics Parameters & Initialization ---
"GEOPRIOR_INIT_MV",
"GEOPRIOR_INIT_KAPPA",
"GEOPRIOR_GAMMA_W",
"GEOPRIOR_H_REF",
"GEOPRIOR_KAPPA_MODE",
"GEOPRIOR_USE_EFFECTIVE_H",
"GEOPRIOR_HD_FACTOR",
"PHYSICS_BOUNDS",
# --- 4. Loss Weights (Lambdas) ---
"LAMBDA_CONS",
"LAMBDA_GW",
"LAMBDA_PRIOR",
"LAMBDA_SMOOTH",
"LAMBDA_BOUNDS",
"LAMBDA_MV",
"LAMBDA_Q",
"LAMBDA_OFFSET",
"LOSS_WEIGHT_GWL",
"MV_LR_MULT",
"KAPPA_LR_MULT",
"SUBS_WEIGHTS",
"GWL_WEIGHTS", # Safe: weights don't change shape
# --- 5. Scaling, Stability & Units ---
"CONS_SCALE_FLOOR",
"GW_SCALE_FLOOR",
"GW_RESIDUAL_UNITS",
"CONSOLIDATION_RESIDUAL_UNITS",
"DT_MIN_UNITS",
"Q_WRT_NORMALIZED_TIME",
"Q_IN_SI",
"Q_IN_PER_SECOND",
"Q_KIND",
"Q_LENGTH_IN_SI",
"DRAINAGE_MODE",
"CLIP_GLOBAL_NORM",
"DEBUG_PHYSICS_GRADS",
"SCALING_ERROR_POLICY",
# --- 6. Consolidation Drawdown Gates ---
"CONS_DRAWDOWN_MODE",
"CONS_DRAWDOWN_RULE",
"CONS_STOP_GRAD_REF",
"CONS_DRAWDOWN_ZERO_AT_ORIGIN",
"CONS_DRAWDOWN_CLIP_MAX",
"CONS_RELU_BETA",
# --- 7. MV Prior Strategy ---
"MV_PRIOR_UNITS",
"MV_ALPHA_DISP",
"MV_HUBER_DELTA",
"MV_PRIOR_MODE",
"MV_WEIGHT",
"MV_SCHEDULE_UNIT",
"MV_DELAY_EPOCHS",
"MV_WARMUP_EPOCHS",
"MV_DELAY_STEPS",
"MV_WARMUP_STEPS",
# --- 8. Training Strategy & Gates (Physics-First vs Data-First) ---
"TRAINING_STRATEGY",
# Physics-First specific overrides
"Q_POLICY_PHYSICS_FIRST",
"Q_WARMUP_EPOCHS_PHYSICS_FIRST",
"Q_RAMP_EPOCHS_PHYSICS_FIRST",
"SUBS_RESID_POLICY_PHYSICS_FIRST",
"SUBS_RESID_WARMUP_EPOCHS_PHYSICS_FIRST",
"SUBS_RESID_RAMP_EPOCHS_PHYSICS_FIRST",
"LAMBDA_Q_PHYSICS_FIRST",
"LOSS_WEIGHT_GWL_PHYSICS_FIRST",
# Data-First specific overrides
"LOSS_WEIGHT_GWL_DATA_FIRST",
"LAMBDA_Q_DATA_FIRST",
"Q_POLICY_DATA_FIRST",
"Q_WARMUP_EPOCHS_DATA_FIRST",
"Q_RAMP_EPOCHS_DATA_FIRST",
"SUBS_RESID_POLICY_DATA_FIRST",
"SUBS_RESID_WARMUP_EPOCHS_DATA_FIRST",
"SUBS_RESID_RAMP_EPOCHS_DATA_FIRST",
# --- 9. Lambda Offset Scheduler ---
"USE_LAMBDA_OFFSET_SCHEDULER",
"LAMBDA_OFFSET_UNIT",
"LAMBDA_OFFSET_WHEN",
"LAMBDA_OFFSET_WARMUP",
"LAMBDA_OFFSET_START",
"LAMBDA_OFFSET_END",
"LAMBDA_OFFSET_SCHEDULE",
# --- 10. Training Loop & Logging ---
"EPOCHS",
"BATCH_SIZE",
"LEARNING_RATE",
"PATIENCE",
"LOG_Q_DIAGNOSTICS",
"AUDIT_STAGES",
"EVAL_JSON_UNITS_MODE",
"EVAL_JSON_UNITS_SCOPE",
"VERBOSE",
}
updates = []
for key in OVERRIDABLE_KEYS:
# If the key exists in your live config (config.py), it overrides manifest
if key in live_cfg:
current_val = merged.get(key)
new_val = live_cfg[key]
# Update only if different (or if manifest didn't have it)
if new_val != current_val:
merged[key] = new_val
updates.append(key)
if verbose and updates:
print(
f"[Config] Applied {len(updates)} physics/training overrides from config.py:"
)
# Print a few examples
sample = updates[:4]
print(f" {', '.join(sample)} ...")
return merged
# -------------------------------------------------------------------------
# Backward-compatible aliases for old private helper names
# -------------------------------------------------------------------------
_pick_npz_for_dataset = pick_npz_for_dataset
_infer_input_dims_from_X = infer_input_dims_from_X