# License: Apache-2.0
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
r"""Target-processing helpers for GeoPrior workflows."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from .generic_utils import rename_dict_keys
# Dict keys that are skipped whenever output names are inferred from the keys
# of a targets/predictions mapping (known auxiliary entries).
# Frozen on purpose: this object is used as the default value of the
# ``exclude_keys`` argument of get_output_names() and as_tuple(), so an
# in-place mutation would silently change the behavior of every later call.
_DEFAULT_EXCLUDE_KEYS = frozenset(
    {
        "data_final",
        "data_mean_raw",
        "phys_final",
        "phys_mean_raw",
        "aux",
        "physics",
        "maps",
    }
)
def _metric_key_from_name(name: str):
    """
    Map a metric name to the output key it is reported under.

    Metric names look like ``"subs_pred_mae"``, ``"subs_pred_coverage80"``
    or ``"gwl_pred_mse"``: an output key, an underscore, then a suffix.

    Parameters
    ----------
    name : str
        Name of the metric object.

    Returns
    -------
    str or None
        ``"subs_pred"`` or ``"gwl_pred"`` when ``name`` starts with that key
        followed by an underscore. ``None`` for loss trackers (``"loss"`` or
        anything ending in ``"_loss"``) and for every other name.
    """
    # Loss trackers are excluded first, so "subs_pred_loss" maps to None.
    is_loss_tracker = name == "loss" or name.endswith("_loss")
    if is_loss_tracker:
        return None

    for output_key in ("subs_pred", "gwl_pred"):
        if name.startswith(output_key + "_"):
            return output_key
    return None
def _safe_to_numpy(x, *, mode="auto"):
    """
    Convert ``x`` through its ``.numpy()`` method when that is possible.

    Parameters
    ----------
    x : Any
        Value to convert. Objects without a callable ``numpy`` attribute
        are returned unchanged whatever the mode.
    mode : {"never", "always", "auto"}
        - ``"never"``: return ``x`` untouched.
        - ``"always"``: call ``x.numpy()`` and let any error propagate
          (the call fails in graph mode).
        - ``"auto"`` (and any other value): call ``x.numpy()`` but fall
          back to ``x`` if the call raises.

    Returns
    -------
    Any
        The result of ``x.numpy()`` or ``x`` itself.
    """
    if mode == "never":
        return x

    converter = getattr(x, "numpy", None)
    if not callable(converter):
        return x

    if mode == "always":
        return converter()

    # "auto": attempt the conversion but never crash (graph/XLA will throw).
    try:
        return converter()
    except Exception:
        return x
def get_output_names(
    model=None,
    y=None,
    y_pred=None,
    *,
    exclude_keys=_DEFAULT_EXCLUDE_KEYS,
):
    """
    Try to obtain stable output names (best-effort, Keras-3-safe).

    Lookup priority is: ``model._output_keys``, then
    ``model._output_names``, then ``model.output_names``, then keys from
    ``y_pred``, and finally keys from ``y``. The first non-empty source
    wins.

    Parameters
    ----------
    model : Any or None
        Object whose ``_output_keys`` / ``_output_names`` /
        ``output_names`` attributes are inspected. Skipped when ``None``.
    y : Mapping or Any, optional
        Targets; only consulted when it is a mapping.
    y_pred : Mapping or Any, optional
        Predictions; only consulted when it is a mapping. Checked before
        ``y``.
    exclude_keys : collection of str or None
        Keys ignored when names are inferred from ``y_pred`` / ``y``.
        ``None`` means "exclude nothing". Not applied to names read from
        ``model``.

    Returns
    -------
    list or None
        Names read from ``model`` are converted with ``str``; keys taken
        from a mapping are returned as-is in the mapping's iteration order.
        ``None`` when no source yields a non-empty result.
    """
    if model is not None:
        for attr in ("_output_keys", "_output_names", "output_names"):
            names = getattr(model, attr, None)
            if not names:
                continue
            # A bare string is one name, not a sequence of characters.
            if isinstance(names, str):
                return [names]
            return [str(n) for n in names]

    # Tolerate an explicit ``exclude_keys=None`` instead of failing on ``in``.
    excluded = () if exclude_keys is None else exclude_keys
    for obj in (y_pred, y):
        if isinstance(obj, Mapping):
            keys = [k for k in obj if k not in excluded]
            if keys:
                return keys
    return None
def as_tuple(
    obj,
    *,
    names=None,
    model=None,
    ctx="value",
    strict=True,
    exclude_keys=_DEFAULT_EXCLUDE_KEYS,
    to_numpy="never",  # "never" | "auto" | "always"
):
    """
    Convert ``obj`` to an ordered tuple of outputs.

    Supports:

    - dict: ordered by ``names`` (or inferred names); when no names are
      available, insertion order minus ``exclude_keys``
    - list/tuple: returns ``tuple(obj)``
    - tensor/ndarray/scalar (and ``str``/``bytes``): returns ``(obj,)``

    Parameters
    ----------
    obj : Any
        Targets/predictions container.
    names : list[str] or str or None
        Desired order of outputs. A single string is treated as one name.
        When ``None`` and ``obj`` is a mapping, names are inferred with
        ``get_output_names``.
    model : Any or None
        Model used to infer output names (via _output_keys/output_names).
    ctx : str
        Label placed at the start of the ``KeyError`` message.
    strict : bool
        If True, missing keys in dict raises KeyError. If False, missing
        names are skipped, so the result can be shorter than ``names``.
    exclude_keys : collection of str or None
        Keys ignored when names have to be inferred from the mapping or
        when falling back to insertion order. ``None`` excludes nothing.
    to_numpy : {"never","auto","always"}
        Optional conversion of individual leaves to numpy (only safe in
        eager).

    Returns
    -------
    tuple
        Ordered outputs.

    Raises
    ------
    KeyError
        If ``obj`` is a mapping, ``strict`` is True and at least one
        requested name is absent from it.
    """
    # Tolerate an explicit ``exclude_keys=None`` instead of failing on ``in``.
    excluded = () if exclude_keys is None else exclude_keys

    if isinstance(obj, Mapping):
        if names is None:
            names = get_output_names(
                model=model,
                y=obj,
                y_pred=None,
                exclude_keys=excluded,
            )
        # A bare string is one output name, not a sequence of characters.
        if isinstance(names, str):
            names = [names]

        if names:
            # Materialize once so one-shot iterables can be scanned twice.
            ordered = list(names)
            missing = [n for n in ordered if n not in obj]
            # Fail before converting any leaf.
            if missing and strict:
                raise KeyError(
                    f"{ctx}: missing keys {missing}. Available keys={list(obj.keys())}"
                )
            return tuple(
                _safe_to_numpy(obj[n], mode=to_numpy)
                for n in ordered
                if n in obj
            )

        # No names: keep insertion order excluding known aux keys
        return tuple(
            _safe_to_numpy(v, mode=to_numpy)
            for k, v in obj.items()
            if k not in excluded
        )

    # Avoid treating strings as sequences
    if isinstance(obj, Sequence) and not isinstance(obj, str | bytes):
        return tuple(_safe_to_numpy(v, mode=to_numpy) for v in obj)

    return (_safe_to_numpy(obj, mode=to_numpy),)
def _prune_dict(d, names):
    """
    Restrict mapping ``d`` to the keys listed in ``names``.

    Parameters
    ----------
    d : Any
        Candidate mapping. Non-mappings are returned unchanged.
    names : iterable or None
        Keys to keep, in the order they should appear. When empty or
        ``None``, ``d`` is returned unchanged.

    Returns
    -------
    dict or Any
        A new ``dict`` holding the entries of ``d`` whose key appears in
        ``names`` (ordered as in ``names``; absent names are skipped), or
        ``d`` itself when no pruning applies.
    """
    if isinstance(d, Mapping) and names:
        kept = {}
        for key in names:
            if key in d:
                kept[key] = d[key]
        return kept
    return d
def update_compiled_metrics(
    model,
    y_true,
    y_pred,
    *,
    output_names=None,
    to_numpy="never",  # keep "never" for real training
):
    """
    Keras-3-safe compiled metrics updater.

    Strategy, in order; the first attempt that does not raise wins:

    1. ``compiled_metrics.update_state`` with dicts pruned to the output
       names (matches ``compile(loss=dict, metrics=dict)``).
    2. ``update_state`` with lists ordered by the output names.
    3. Manual per-metric ``update_state``, pairing each metric with
       ``y_true[key]`` / ``y_pred[key]`` by the prefix of the metric name.

    Every failure is swallowed so that training never crashes; each one is
    reported through ``logging`` at DEBUG level for diagnosis.

    Note: converting to numpy inside train_step is generally NOT safe.

    Parameters
    ----------
    model : Any
        Object exposing ``compiled_metrics``. When the attribute is missing
        or ``None`` the function does nothing.
    y_true, y_pred : Mapping, sequence or Any
        Targets and predictions.
    output_names : list[str] or None
        Explicit output order. When falsy, names are inferred with
        ``get_output_names``.
    to_numpy : {"never","auto","always"}
        Optional leaf conversion (debug only; keep ``"never"`` for
        training).

    Returns
    -------
    None
    """
    # Function-scope import so the module's import block stays untouched.
    import logging

    log = logging.getLogger(__name__)

    cm = getattr(model, "compiled_metrics", None)
    if cm is None:
        return

    names = output_names or get_output_names(
        model=model, y=y_true, y_pred=y_pred
    )

    # Prefer dicts if available (matches compile(loss=dict, metrics=dict)).
    yt = _prune_dict(y_true, names) if isinstance(y_true, Mapping) else y_true
    yp = _prune_dict(y_pred, names) if isinstance(y_pred, Mapping) else y_pred

    # Optional numpy conversion (debug only; keep default "never").
    if to_numpy != "never":
        if isinstance(yt, Mapping):
            yt = {
                k: _safe_to_numpy(v, mode=to_numpy) for k, v in yt.items()
            }
        if isinstance(yp, Mapping):
            yp = {
                k: _safe_to_numpy(v, mode=to_numpy) for k, v in yp.items()
            }

    # Attempt 1: dict update (best match for a dict compile config).
    try:
        cm.update_state(yt, yp)
        return
    except Exception as exc:
        log.debug(
            "update_compiled_metrics: dict update failed (%r); "
            "falling back to list update",
            exc,
        )

    # Attempt 2: tuple/list update in deterministic output order.
    try:
        yt_list = as_tuple(
            y_true,
            names=names,
            model=model,
            ctx="y_true",
            to_numpy=to_numpy,
        )
        yp_list = as_tuple(
            y_pred,
            names=names,
            model=model,
            ctx="y_pred",
            to_numpy=to_numpy,
        )
        cm.update_state(list(yt_list), list(yp_list))
        return
    except Exception as exc:
        log.debug(
            "update_compiled_metrics: list update failed (%r); "
            "falling back to per-metric update",
            exc,
        )

    # Attempt 3: manual per-metric update (last resort, never crash).
    try:
        metrics = getattr(cm, "metrics", []) or []
        if not (isinstance(y_true, Mapping) and isinstance(y_pred, Mapping)):
            return
        for m in metrics:
            metric_name = getattr(m, "name", "") or ""
            key = _metric_key_from_name(metric_name)
            if not key:
                continue
            if key in y_true and key in y_pred:
                try:
                    m.update_state(y_true[key], y_pred[key])
                except Exception as exc:
                    log.debug(
                        "update_compiled_metrics: metric %r skipped (%r)",
                        metric_name,
                        exc,
                    )
                    continue
    except Exception as exc:
        log.debug(
            "update_compiled_metrics: per-metric update aborted (%r)", exc
        )
        return
def _canonicalize_targets(targets):
    """
    Return targets as a dict keyed by the compiled output names.

    Accepted inputs:

    - mapping that already has both ``"subs_pred"`` and ``"gwl_pred"``:
      returned as a plain ``dict`` copy;
    - any other mapping: copied, then passed to ``rename_dict_keys`` with
      ``"subsidence" -> "subs_pred"`` and ``"gwl" -> "gwl_pred"``;
    - tuple/list of length 2: ``(subs, gwl)`` becomes
      ``{"subs_pred": subs, "gwl_pred": gwl}``;
    - tuple/list of length 1: ``(subs,)`` becomes ``{"subs_pred": subs}``;
    - anything else (including other lengths): returned unchanged.

    This avoids Keras 3 structure-mismatch errors when y_pred is a dict.
    """
    if isinstance(targets, Mapping):
        as_dict = dict(targets)
        needs_rename = "subs_pred" not in as_dict or "gwl_pred" not in as_dict
        if needs_rename:
            return rename_dict_keys(
                as_dict,
                param_to_rename={
                    "subsidence": "subs_pred",
                    "gwl": "gwl_pred",
                },
            )
        return as_dict

    if isinstance(targets, (tuple, list)) and len(targets) in (1, 2):
        # zip stops at the shorter input, so one target yields only subs_pred.
        return dict(zip(("subs_pred", "gwl_pred"), targets))

    return targets