# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
GeoPrior subsidence model utilities.
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from warnings import warn
import numpy as np
from .. import KERAS_DEPS
# Short module-level aliases for the tensor type, dtypes and ops exposed
# by KERAS_DEPS; the rest of this module only uses these names.
Tensor = KERAS_DEPS.Tensor
tf_float32 = KERAS_DEPS.float32
tf_int32 = KERAS_DEPS.int32
tf_cast = KERAS_DEPS.cast
tf_constant = KERAS_DEPS.constant
tf_debugging = KERAS_DEPS.debugging
tf_equal = KERAS_DEPS.equal
tf_maximum = KERAS_DEPS.maximum
tf_minimum = KERAS_DEPS.minimum
tf_greater_equal = KERAS_DEPS.greater_equal
tf_rank = KERAS_DEPS.rank
tf_cond = KERAS_DEPS.cond
tf_shape = KERAS_DEPS.shape
tf_zeros_like = KERAS_DEPS.zeros_like
tf_ones = KERAS_DEPS.ones
tf_greater = KERAS_DEPS.greater
# NOTE(review): second binding of tf_cond to the same attribute as above;
# harmless but redundant.
tf_cond = KERAS_DEPS.cond
tf_concat = KERAS_DEPS.concat
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_ones_like = KERAS_DEPS.ones_like
tf_less_equal = KERAS_DEPS.less_equal
tf_abs = KERAS_DEPS.abs
tf_print = KERAS_DEPS.print
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile
# Small positive constant added to the denominator in from_si_subsidence.
_EPSILON = 1e-12
# ---------------------------------------------------------------------
# Scaling kwargs access helpers (alias-safe)
# ---------------------------------------------------------------------
# Canonical scaling key -> tuple of accepted alternative spellings.
# Lookups try the canonical key first, then the aliases in the order
# listed here. A canonical key must never list itself as an alias.
_SK_ALIASES = {
    # common naming drift
    "time_units": ("time_unit",),
    "cons_residual_units": ("cons_residual_unit",),
    # policy drift
    "scaling_error_policy": (
        "error_policy",
        "scaling_policy",
    ),
    # coord drift
    "coords_normalized": (
        "coord_normalized",
        "coords_norm",
    ),
    "coords_in_degrees": (
        "coord_in_degrees",
        "coords_deg",
    ),
    "coord_order": ("coords_order",),
    "coord_ranges": ("coord_range",),
    # feature-name list drift
    "dynamic_feature_names": (
        "dynamic_features_names",
        "dyn_feature_names",
    ),
    "future_feature_names": (
        "future_features_names",
        "fut_feature_names",
    ),
    "static_feature_names": (
        "static_features_names",
        "stat_feature_names",
    ),
    # feature-channel naming drift
    "gwl_col": (
        "gwl_dyn_name",
        "gwl_dyn_col",
        "gwl_name",
    ),
    "subs_dyn_name": (
        "subs_col",
        "subs_dyn_col",
        "subsidence_dyn_name",
    ),
    # feature-channel index drift
    "gwl_dyn_index": (
        "gwl_index",
        "gwl_feature_index",
        "gwl_channel_index",
    ),
    "subs_dyn_index": (
        "subs_index",
        "subs_feature_index",
        "subs_channel_index",
    ),
    # z_surf drift
    "z_surf_col": (
        "z_surf_key",
        "z_surf_name",
    ),
    # bounds drift (often nested under scaling_kwargs['bounds'])
    "log_tau_min": (
        "logTau_min",
        "logtau_min",
    ),
    "log_tau_max": (
        "logTau_max",
        "logtau_max",
    ),
    "tau_min": (
        "Tau_min",
        "tauMin",
        "tau_min_sec",
        "tau_min_seconds",
    ),
    "tau_max": (
        "Tau_max",
        "tauMax",
        "tau_max_sec",
        "tau_max_seconds",
    ),
    "tau_min_units": (
        "tau_min_time_units",
        "tau_min_in_time_units",
    ),
    "tau_max_units": (
        "tau_max_time_units",
        "tau_max_in_time_units",
    ),
    "Q_length_in_si": ("Q_in_m_per_s",),
}
# Consolidation drawdown drift
_SK_ALIASES.update(
    {
        "cons_drawdown_mode": (
            "drawdown_mode",
            "cons_delta_mode",
        ),
        "cons_drawdown_rule": (
            "drawdown_rule",
            "cons_delta_rule",
        ),
        "cons_stop_grad_ref": (
            "stop_grad_ref",
            "cons_stopgrad_ref",
        ),
        "cons_drawdown_zero_at_origin": (
            "drawdown_zero_at_origin",
            "cons_zero_at_origin",
        ),
        "cons_drawdown_clip_max": (
            "drawdown_clip_max",
            "cons_clip_max",
        ),
        "cons_relu_beta": (
            "relu_beta",
            "cons_beta",
        ),
    }
)
# MV prior drift (mode/weight/warmup + loss knobs)
_SK_ALIASES.update(
    {
        "mv_prior_mode": (
            "mv_mode",
            "mvprior_mode",
            "mv_prior_kind",
        ),
        "mv_weight": (
            "mv_prior_weight",
            "mvprior_weight",
            "mv_w",
        ),
        # The canonical key itself used to be listed here as an alias;
        # that entry was a no-op duplicate and has been dropped.
        "mv_warmup_steps": (
            "mv_prior_warmup_steps",
            "mv_warmup_iters",
            "mv_warmup_iterations",
        ),
        "mv_alpha_disp": (
            "mv_prior_alpha_disp",
            "mv_disp_alpha",
            "mv_alpha",
        ),
        "mv_huber_delta": (
            "mv_prior_huber_delta",
            "mv_delta",
            "mv_huber",
        ),
        "mv_prior_units": (
            "mv_units",
            "mv_gamma_units",
            "mv_gw_units",
        ),
    }
)
[docs]
def enforce_scaling_alias_consistency(
    scaling_kwargs: dict[str, Any] | None,
    *,
    where: str = "validate",
) -> None:
    """
    Report canonical keys whose aliases carry a different value.

    For every canonical key present with a non-None value, each alias
    that is also present with a non-None value is compared against it.
    Every mismatch is forwarded to ``_handle_scaling_issue`` together
    with ``where``, which decides whether it is ignored, warned about
    or raised.
    """
    cfg = scaling_kwargs or {}
    for canonical, alt_names in _SK_ALIASES.items():
        reference = cfg.get(canonical, None)
        if reference is None:
            # Missing or explicitly None: nothing to compare against.
            continue
        for alt in alt_names:
            other = cfg.get(alt, None)
            if other is None:
                continue
            if other != reference:
                _handle_scaling_issue(
                    cfg,
                    f"Conflicting scaling keys: {canonical!r}={reference!r}"
                    f" != {alt!r}={other!r}.",
                    where=where,
                )
[docs]
def canonicalize_scaling_kwargs(
    scaling_kwargs: dict[str, Any] | None,
    *,
    copy: bool = True,
) -> dict[str, Any]:
    """
    Fill in missing canonical keys from their aliases.

    A canonical key that is absent (or None) receives the value of the
    first alias, in table order, that holds a non-None value. Canonical
    values that are already set are never overwritten, and alias keys
    are left in place.

    With ``copy=True`` (default) a shallow copy is modified and
    returned; otherwise the input mapping is updated in place.
    """
    source = scaling_kwargs or {}
    result = dict(source) if copy else source
    for canonical, alt_names in _SK_ALIASES.items():
        if result.get(canonical, None) is not None:
            continue
        first_hit = next(
            (alt for alt in alt_names if result.get(alt, None) is not None),
            None,
        )
        if first_hit is not None:
            result[canonical] = result[first_hit]
    return result
[docs]
def _read_scaling_json_file(path: Path) -> dict[str, Any]:
    """Read the JSON file at *path* and return the dict it encodes."""
    path = path.expanduser()
    # is_file() (not exists()) so that a directory or a blank string
    # reports "not found" instead of failing later inside read_text().
    if not path.is_file():
        raise FileNotFoundError(
            f"Scaling kwargs JSON file not found: {str(path)!r}."
        )
    text = path.read_text(encoding="utf-8")
    try:
        obj = json.loads(text)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid JSON in scaling kwargs file: {str(path)!r}."
        ) from e
    if not isinstance(obj, dict):
        raise ValueError(
            "Scaling JSON file must decode to an object/dict, "
            f"got {type(obj).__name__}."
        )
    return obj


def load_scaling_kwargs(
    scaling_kwargs: Any | None,
    *,
    copy: bool = True,
) -> dict[str, Any]:
    """
    Load scaling kwargs from a dict-like object or JSON.

    Supported inputs
    ----------------
    - dict / Mapping:
        Returned (copied by default).
    - str:
        * If it looks like JSON ("{...}" or "[...]"), parse as JSON.
        * Else treat as a filesystem path to a JSON file.
    - pathlib.Path:
        Treated as a filesystem path to a JSON file. Handled exactly
        like a path string (``~`` expansion, same error messages).
    - None:
        Returns {}.
    - anything else:
        Passed to ``dict(...)``.

    Parameters
    ----------
    scaling_kwargs : Any
        Scaling configuration input. Can be a dict, JSON string,
        path to JSON file, or None.
    copy : bool, default=True
        If True, returns a shallow copy of a Mapping input. If False,
        the Mapping itself is returned.

    Returns
    -------
    dict
        Parsed scaling kwargs as a Python dict.

    Raises
    ------
    TypeError
        If the input type is unsupported.
    ValueError
        If JSON parsing fails or JSON does not decode to a dict.
    FileNotFoundError
        If a JSON path is given but does not point to a file.
    """
    if scaling_kwargs is None:
        return {}
    if isinstance(scaling_kwargs, Mapping):
        return dict(scaling_kwargs) if copy else scaling_kwargs
    if isinstance(scaling_kwargs, Path):
        return _read_scaling_json_file(scaling_kwargs)
    if isinstance(scaling_kwargs, str):
        s = scaling_kwargs.strip()
        # 1) Inline JSON object/array. An array parses but is rejected
        #    below with a precise "must be an object" message.
        if (s.startswith("{") and s.endswith("}")) or (
            s.startswith("[") and s.endswith("]")
        ):
            try:
                obj = json.loads(s)
            except json.JSONDecodeError as e:
                raise ValueError(
                    "Invalid scaling_kwargs JSON string."
                ) from e
            if not isinstance(obj, dict):
                raise ValueError(
                    "Scaling JSON must decode to an object/dict, "
                    f"got {type(obj).__name__}."
                )
            return obj
        # 2) Treat as file path to JSON.
        return _read_scaling_json_file(Path(s))
    # Last resort: anything dict() can consume (e.g. pairs).
    try:
        obj = dict(scaling_kwargs)
    except Exception as e:
        raise TypeError(
            "scaling_kwargs must be a dict/Mapping, JSON string, "
            "Path, or a path string to a JSON file."
        ) from e
    return obj
[docs]
def get_sk(
    scaling_kwargs,
    key: str,
    *aliases: str,
    default=None,
    required: bool = False,
    cast=None,
):
    """
    Look up ``key`` in ``scaling_kwargs``, falling back to aliases.

    Candidates are tried in this order: ``key``, the built-in aliases
    registered for it, then the explicit ``aliases`` arguments. A value
    of None or a blank string counts as missing, so the search moves on
    to the next candidate.

    Parameters
    ----------
    scaling_kwargs : mapping-like or None
        Non-Mapping inputs are passed through ``dict(...)``; if that
        fails they are treated as empty.
    key : str
        Canonical key name.
    *aliases : str
        Extra alias names (empty ones are skipped).
    default : Any, default=None
        Returned when nothing is found and ``required`` is False.
        When ``cast`` is given it is applied to a non-None default;
        if that fails the raw default is returned.
    required : bool, default=False
        If True, a missing value raises ``ValueError``.
    cast : callable, optional
        Converter applied to the found value.

    Raises
    ------
    ValueError
        If ``cast`` fails on a found value, or a required key is
        missing.
    """
    mapping = scaling_kwargs or {}
    if not isinstance(mapping, Mapping):
        try:
            mapping = dict(mapping)
        except Exception:
            mapping = {}

    candidates = [
        key,
        *_SK_ALIASES.get(key, ()),
        *(name for name in aliases if name),
    ]

    for name in candidates:
        if name not in mapping:
            continue
        value = mapping[name]
        is_blank = isinstance(value, str) and not value.strip()
        if value is None or is_blank:
            continue
        if cast is None:
            return value
        try:
            return cast(value)
        except Exception as exc:
            raise ValueError(
                f"Invalid scaling_kwargs[{name!r}]={value!r}."
            ) from exc

    if required:
        alias_txt = ", ".join(map(repr, candidates[1:])) or "none"
        raise ValueError(
            f"Missing required scaling key {key!r} (aliases: {alias_txt})."
        )

    if cast is None or default is None:
        return default
    try:
        return cast(default)
    except Exception:
        # Best effort: an uncastable default is handed back unchanged.
        return default
def _norm_policy(policy: str | None) -> str:
    """
    Return the scaling error policy in canonical form.

    The input is stripped and lower-cased. The result is one of
    'ignore', 'warn' or 'raise'; None, an empty string and any
    unrecognized value map to 'warn'.
    """
    normalized = (policy or "warn").strip().lower()
    return normalized if normalized in ("ignore", "warn", "raise") else "warn"
def _handle_scaling_issue(
    scaling_kwargs: dict[str, Any] | None,
    message: str,
    *,
    where: str = "validate",
) -> None:
    """
    Dispatch a scaling problem according to the configured policy.

    The policy is read from ``scaling_error_policy`` (alias-aware,
    default 'warn') and normalized with ``_norm_policy``.

    Notes
    -----
    - where='validate': 'ignore' returns silently, 'warn' emits a
      ``RuntimeWarning``, 'raise' raises ``ValueError``.
    - any other ``where`` (e.g. 'runtime'): 'raise' is downgraded to
      'warn' so that callers can continue with their fallback path.
    """
    cfg = scaling_kwargs or {}
    policy = _norm_policy(
        get_sk(cfg, "scaling_error_policy", default="warn")
    )
    # Outside validation a hard failure is never allowed.
    if policy == "raise" and where != "validate":
        policy = "warn"

    if policy == "raise":
        raise ValueError(message)
    if policy == "warn":
        warn(
            message,
            category=RuntimeWarning,
            stacklevel=2,
        )
    # policy == "ignore": nothing to do.
def _is_deg_mode(mode: str) -> bool:
    """Return True if ``mode`` names a degree-based coordinate mode.

    Matching ignores case and surrounding whitespace; None or an empty
    string yields False.
    """
    degree_modes = ("deg", "degree", "degrees", "lonlat", "latlon")
    cleaned = (mode or "").strip().lower()
    return cleaned in degree_modes
def _validate_scaling_kwargs(scaling_kwargs):
    """Check that the coordinate-mode settings agree with each other.

    After canonicalizing the keys and checking alias consistency, two
    policy-controlled checks are run (both with where='validate'):

    - ``coord_mode`` being degree-like must match the boolean
      ``coords_in_degrees``.
    - with ``coords_in_degrees`` set, ``coord_epsg_used`` must be None
      or 4326.
    """
    cfg = canonicalize_scaling_kwargs(scaling_kwargs)
    enforce_scaling_alias_consistency(cfg, where="validate")

    mode = str(cfg.get("coord_mode", ""))
    mode_is_deg = _is_deg_mode(mode)
    flag_is_deg = bool(cfg.get("coords_in_degrees", False))
    if mode_is_deg != flag_is_deg:
        _handle_scaling_issue(
            cfg,
            "Inconsistent coord flags: "
            f"coord_mode={mode!r} but "
            f"coords_in_degrees={flag_is_deg}. "
            "Decide: degrees(+deg_to_m_*) or "
            "projected meters (coords_in_degrees=False).",
            where="validate",
        )

    epsg_used = cfg.get("coord_epsg_used", None)
    if flag_is_deg and (epsg_used not in (None, 4326)):
        _handle_scaling_issue(
            cfg,
            "coords_in_degrees=True but "
            f"coord_epsg_used={epsg_used!r} "
            "looks projected. If you already "
            "reprojected, set coords_in_degrees=False.",
            where="validate",
        )
[docs]
def validate_scaling_kwargs(
    scaling_kwargs: dict[str, Any] | None,
) -> None:
    """
    Run sanity checks on a scaling configuration.

    Hard requirements (always raise ``ValueError``):

    - ``coords_in_degrees=True`` needs finite, positive
      ``deg_to_m_lon`` and ``deg_to_m_lat``.
    - ``coords_normalized=True`` needs a non-empty ``coord_ranges``.
    - ``time_units`` (or an alias) must be present.

    Heuristic checks (routed through the scaling error policy):

    - a cumulative-subsidence channel is listed in
      ``dynamic_feature_names`` but nothing identifies it.
    - depth->head conversion is requested (``use_head_proxy=False``
      with a non-head ``gwl_kind``) but the surface-elevation column
      cannot be located unambiguously.
    """
    cfg = canonicalize_scaling_kwargs(scaling_kwargs)
    enforce_scaling_alias_consistency(cfg, where="validate")

    # --------------------------------------------------
    # Degrees mode requires meters-per-degree factors.
    # --------------------------------------------------
    if bool(cfg.get("coords_in_degrees", False)):
        for factor_key in ("deg_to_m_lon", "deg_to_m_lat"):
            raw = cfg.get(factor_key, None)
            if raw is None:
                raise ValueError(
                    "coords_in_degrees=True but missing "
                    f"scaling_kwargs[{factor_key!r}]."
                )
            try:
                factor = float(raw)
            except (TypeError, ValueError) as exc:
                raise ValueError(
                    f"Invalid {factor_key!r}={raw!r}."
                ) from exc
            if not np.isfinite(factor) or factor <= 0.0:
                raise ValueError(f"Invalid {factor_key!r}={factor}.")

    # --------------------------------------------------
    # Normalized coords require coord_ranges.
    # --------------------------------------------------
    normalized = bool(cfg.get("coords_normalized", False))
    if normalized and not cfg.get("coord_ranges", None):
        raise ValueError(
            "coords_normalized=True but coord_ranges missing."
        )

    # --------------------------------------------------
    # Require time units (alias-safe).
    # --------------------------------------------------
    if get_sk(cfg, "time_units", default=None) is None:
        raise ValueError("time_units missing in scaling_kwargs.")

    # --------------------------------------------------
    # Heuristic checks (policy-controlled).
    # --------------------------------------------------
    dyn_names = cfg.get("dynamic_feature_names", None)
    dyn_names = [] if dyn_names is None else list(dyn_names)

    # A) Subsidence init: detect cum subs channel.
    lowered = (str(name).lower() for name in dyn_names)
    has_subs_cum = any(
        "subs" in text and "cum" in text for text in lowered
    )
    subs_idx = cfg.get("subs_dyn_index", None)
    subs_name = get_sk(cfg, "subs_dyn_name", default=None)
    meta = cfg.get("gwl_z_meta", {}) or {}
    cols = meta.get("cols", {}) or {}
    subs_meta = cols.get("subs_model", None)

    nothing_identifies_channel = (
        subs_idx is None and subs_name is None and subs_meta is None
    )
    if has_subs_cum and nothing_identifies_channel:
        _handle_scaling_issue(
            cfg,
            "dynamic_feature_names contains a cumulative "
            "subsidence channel, but no subs_dyn_index/"
            "subs_dyn_name and no gwl_z_meta.cols.subs_model. "
            "Initial settlement will fall back to zeros.",
            where="validate",
        )

    # B) Depth->head conversion needs z_surf when proxy=False.
    kind = str(cfg.get("gwl_kind", "")).lower()
    proxy = bool(cfg.get("use_head_proxy", True))
    if proxy or kind in ("head", "waterhead", "hydraulic_head"):
        return

    z_col = cfg.get("z_surf_col", None) or meta.get("z_surf_col", None)
    z_static = cols.get("z_surf_static", None)
    z_idx = cfg.get("z_surf_static_index", None)
    static_names = get_sk(
        cfg,
        "static_feature_names",
        default=None,
    )
    # Without an index or a name list, two different z_surf column
    # names leave the lookup ambiguous and the conversion may fall back.
    no_locator = z_idx is None and static_names is None
    names_disagree = (
        z_col is not None
        and z_static is not None
        and z_col != z_static
    )
    if no_locator and names_disagree:
        _handle_scaling_issue(
            cfg,
            "use_head_proxy=False and gwl_kind is depth-like, "
            "but z_surf_col differs from gwl_z_meta.cols."
            "z_surf_static, and no static_feature_names/"
            "z_surf_static_index provided. Depth->head "
            "conversion may fall back to depth.",
            where="validate",
        )
[docs]
def affine_from_cfg(
    scaling_kwargs: dict[str, Any] | None,
    *,
    scale_key: str,
    bias_key: str,
    meta_keys: tuple[str, ...] = (),
    unit_key: str | None = None,
) -> tuple[Tensor, Tensor]:
    """Return (a,b) for y_si = y_model*a + b.

    Resolution order:

    1. Explicit ``scale_key`` / ``bias_key``. If only one of them is
       set, the other defaults to 1.0 (scale) or 0.0 (bias).
    2. The first dict under ``meta_keys`` that has both a
       ``mu``/``mean`` and a ``sigma``/``std`` entry: ``a = sigma``,
       ``b = mu``.
    3. ``unit_key`` as a pure scale factor with zero bias. A missing
       or None value is treated as 1.0.
    4. Identity (1.0, 0.0).

    Both values are returned as float32 constants.
    """
    cfg = scaling_kwargs or {}

    a = cfg.get(scale_key, None)
    b = cfg.get(bias_key, None)
    if a is not None or b is not None:
        a = 1.0 if a is None else float(a)
        b = 0.0 if b is None else float(b)
        return tf_constant(a, tf_float32), tf_constant(b, tf_float32)

    for mk in meta_keys:
        meta = cfg.get(mk, None)
        if not isinstance(meta, dict):
            continue
        mu = meta.get("mu", meta.get("mean", None))
        sig = meta.get("sigma", meta.get("std", None))
        if mu is not None and sig is not None:
            return (
                tf_constant(float(sig), tf_float32),
                tf_constant(float(mu), tf_float32),
            )

    unit = 1.0
    if unit_key is not None:
        raw_unit = cfg.get(unit_key, None)
        # An explicit None (e.g. JSON null) means "not configured";
        # float(None) would otherwise raise TypeError here.
        if raw_unit is not None:
            unit = float(raw_unit)
    return tf_constant(unit, tf_float32), tf_constant(0.0, tf_float32)
[docs]
def to_si_thickness(
    H_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Map model-space thickness to SI via the configured affine.

    The (scale, bias) pair comes from ``affine_from_cfg`` using the
    ``H_scale_si`` / ``H_bias_si`` keys, the ``H_z_meta`` block, or
    ``thickness_unit_to_si``. The input is cast to float32 first.
    """
    scale, offset = affine_from_cfg(
        scaling_kwargs,
        scale_key="H_scale_si",
        bias_key="H_bias_si",
        meta_keys=("H_z_meta",),
        unit_key="thickness_unit_to_si",
    )
    thickness = tf_cast(H_model, tf_float32)
    return thickness * scale + offset
[docs]
def to_si_head(
    h_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Map model-space head/depth to SI meters via the configured affine.

    The (scale, bias) pair comes from ``affine_from_cfg`` using the
    ``head_scale_si`` / ``head_bias_si`` keys, the ``head_z_meta`` or
    ``gwl_z_meta`` block (in that order), or ``head_unit_to_si``.
    The input is cast to float32 first.
    """
    scale, offset = affine_from_cfg(
        scaling_kwargs,
        scale_key="head_scale_si",
        bias_key="head_bias_si",
        meta_keys=("head_z_meta", "gwl_z_meta"),
        unit_key="head_unit_to_si",
    )
    head = tf_cast(h_model, tf_float32)
    return head * scale + offset
[docs]
def to_si_subsidence(
    s_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Map model-space subsidence to SI meters via the configured affine.

    The (scale, bias) pair comes from ``affine_from_cfg`` using the
    ``subs_scale_si`` / ``subs_bias_si`` keys, the ``subs_z_meta``
    block, or ``subs_unit_to_si``. The input is cast to float32 first.
    """
    scale, offset = affine_from_cfg(
        scaling_kwargs,
        scale_key="subs_scale_si",
        bias_key="subs_bias_si",
        meta_keys=("subs_z_meta",),
        unit_key="subs_unit_to_si",
    )
    subsidence = tf_cast(s_model, tf_float32)
    return subsidence * scale + offset
[docs]
def from_si_subsidence(
    s_si: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Map SI subsidence back to model space: (s_si - b) / (a + eps).

    Uses the same (a, b) lookup as ``to_si_subsidence``. ``_EPSILON``
    is added to the scale in the denominator.
    """
    scale, offset = affine_from_cfg(
        scaling_kwargs,
        scale_key="subs_scale_si",
        bias_key="subs_bias_si",
        meta_keys=("subs_z_meta",),
        unit_key="subs_unit_to_si",
    )
    guard = tf_constant(_EPSILON, tf_float32)
    centered = tf_cast(s_si, tf_float32) - offset
    return centered / (scale + guard)
[docs]
def deg_to_m(
    axis: str,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """
    Return the meters-per-degree factor for one horizontal axis.

    - ``coords_in_degrees`` falsy: the factor is 1.0.
    - ``deg_to_m_lon`` (axis 'x') / ``deg_to_m_lat`` (axis 'y') set:
      that value, which must be finite and positive.
    - otherwise derived from ``lat0_deg``:
      ``111320 * cos(lat0)`` for 'x', ``110574`` for 'y'.

    Raises
    ------
    ValueError
        For an axis other than 'x'/'y', an invalid configured factor,
        or when neither the factor nor ``lat0_deg`` is available.
    """
    if axis not in ("x", "y"):
        raise ValueError(
            f"deg_to_m: axis must be 'x' or 'y', got {axis!r}."
        )

    cfg = scaling_kwargs or {}
    if not bool(cfg.get("coords_in_degrees", False)):
        return tf_constant(1.0, tf_float32)

    is_lon = axis == "x"
    key = "deg_to_m_lon" if is_lon else "deg_to_m_lat"
    raw = cfg.get(key, None)

    if raw is not None:
        try:
            factor = float(raw)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Invalid {key!r}={raw!r}.") from exc
        if not np.isfinite(factor) or factor <= 0.0:
            raise ValueError(f"Invalid {key!r}={factor}.")
        return tf_constant(factor, tf_float32)

    # No explicit factor: derive it from the reference latitude.
    lat0 = cfg.get("lat0_deg", None)
    if lat0 is None:
        raise ValueError(
            "coords_in_degrees=True but missing deg_to_m_lon/deg_to_m_lat "
            "and lat0_deg (needed for lon scaling)."
        )
    lat0 = float(lat0)
    if is_lon:
        factor = 111320.0 * float(np.cos(np.deg2rad(lat0)))
    else:
        factor = 110574.0
    return tf_constant(factor, tf_float32)
[docs]
def coord_ranges(
    scaling_kwargs: dict[str, Any] | None,
) -> tuple[float | None, float | None, float | None]:
    """Return the (t, x, y) coordinate ranges for normalized coords.

    When ``coords_normalized`` is falsy all three entries are None.
    Otherwise each range is read from ``coord_ranges[axis]`` and, if
    that is None, from the top-level fallback keys ``<axis>_range``
    then ``coord_range_<axis>``. Found values are converted to float;
    unresolved ones stay None.
    """
    cfg = scaling_kwargs or {}
    if not bool(cfg.get("coords_normalized", False)):
        return None, None, None

    nested = cfg.get("coord_ranges", {}) or {}

    def _lookup(axis: str, *fallback_keys: str) -> float | None:
        value = nested.get(axis, None)
        if value is None:
            top_level = (cfg.get(k, None) for k in fallback_keys)
            value = next((v for v in top_level if v is not None), None)
        return float(value) if value is not None else None

    return (
        _lookup("t", "t_range", "coord_range_t"),
        _lookup("x", "x_range", "coord_range_x"),
        _lookup("y", "y_range", "coord_range_y"),
    )
[docs]
def resolve_gwl_dyn_index(
    scaling_kwargs: dict[str, Any] | None,
) -> int:
    """Return the position of the GWL channel in dynamic_features.

    An explicit ``gwl_dyn_index`` wins. Otherwise the index is the
    position of ``gwl_col`` (alias-aware) in ``dynamic_feature_names``.

    Raises
    ------
    ValueError
        If neither route yields an index.
    """
    cfg = scaling_kwargs or {}

    explicit = cfg.get("gwl_dyn_index", None)
    if explicit is not None:
        return int(explicit)

    feature_names = cfg.get("dynamic_feature_names", None)
    column = get_sk(cfg, "gwl_col", default=None)
    if feature_names is not None and column is not None:
        ordered = list(feature_names)
        if column in ordered:
            return int(ordered.index(column))

    raise ValueError(
        "Cannot resolve GWL channel. Provide gwl_dyn_index "
        "or dynamic_feature_names + gwl_col."
    )
[docs]
def get_gwl_dyn_index_cached(model) -> int:
    """Return the GWL channel index, resolving it at most once.

    If ``model.gwl_dyn_index`` is already set it is returned as an int.
    Otherwise the index is resolved from ``model.scaling_kwargs`` and
    stored on the model as ``gwl_dyn_index`` before being returned.
    """
    cached = getattr(model, "gwl_dyn_index", None)
    if cached is not None:
        return int(cached)

    resolved = resolve_gwl_dyn_index(
        getattr(model, "scaling_kwargs", None)
    )
    model.gwl_dyn_index = int(resolved)
    return int(resolved)
[docs]
def resolve_subs_dyn_index(scaling_kwargs):
    """Return the position of the subsidence channel in dynamic_features.

    Resolution order:

    1. an explicit ``subs_dyn_index``;
    2. the position of ``subs_dyn_name`` (alias-aware) in
       ``dynamic_feature_names``;
    3. the same lookup using ``gwl_z_meta['cols']['subs_model']`` when
       no ``subs_dyn_name`` is configured.

    Raises
    ------
    ValueError
        If none of the routes yields an index.
    """
    cfg = scaling_kwargs or {}

    explicit = cfg.get("subs_dyn_index", None)
    if explicit is not None:
        return int(explicit)

    feature_names = cfg.get("dynamic_feature_names", None)
    column = get_sk(cfg, "subs_dyn_name", default=None)
    if column is None:
        # Fall back to the column recorded in gwl_z_meta.cols.subs_model.
        meta_cols = (cfg.get("gwl_z_meta", {}) or {}).get("cols", {}) or {}
        column = meta_cols.get("subs_model", None)

    if feature_names is not None and column is not None:
        ordered = list(feature_names)
        if column in ordered:
            return int(ordered.index(column))

    raise ValueError(
        "Cannot resolve subsidence channel. Provide subs_dyn_index "
        "or dynamic_feature_names + subs_dyn_name (or gwl_z_meta.cols.subs_model)."
    )
[docs]
def get_subs_dyn_index_cached(model) -> int:
    """Return the subsidence channel index, resolving it at most once.

    If ``model.subs_dyn_index`` is already set it is returned as an
    int. Otherwise the index is resolved from ``model.scaling_kwargs``
    and stored on the model as ``subs_dyn_index`` before being
    returned. A failed resolution propagates its exception and caches
    nothing.
    """
    cached = getattr(model, "subs_dyn_index", None)
    if cached is not None:
        return int(cached)

    resolved = resolve_subs_dyn_index(
        getattr(model, "scaling_kwargs", None)
    )
    model.subs_dyn_index = int(resolved)
    return int(resolved)
[docs]
def slice_dynamic_channel(Xh: Tensor, idx: int) -> Tensor:
    """Slice (B,T,F) -> (B,T,1) at idx.

    Parameters
    ----------
    Xh : Tensor
        Dynamic feature tensor; indexed as ``Xh[:, :, idx:idx+1]``.
    idx : int
        Channel position along the last axis. Must be non-negative.

    Raises
    ------
    ValueError
        If ``idx`` is a negative Python/NumPy integer. A negative start
        would make ``idx:idx+1`` an empty (or wrong) slice without
        tripping the upper-bound assertion.
    """
    if isinstance(idx, (int, np.integer)) and idx < 0:
        raise ValueError(
            f"Dynamic channel index must be non-negative, got {int(idx)}."
        )
    idx_t = tf_cast(idx, tf_int32)
    n_features = tf_shape(Xh)[-1]
    # This helper serves every dynamic channel (GWL, subsidence, ...),
    # so the message does not name a specific config key.
    tf_debugging.assert_less(
        idx_t,
        n_features,
        message="dynamic channel index out of range.",
    )
    return Xh[:, :, idx_t : idx_t + 1]
[docs]
def assert_dynamic_names_match_tensor(
    Xh: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> None:
    """Assert that Xh's last dimension equals len(dynamic_feature_names).

    Does nothing when ``dynamic_feature_names`` is not configured.
    """
    feature_names = (scaling_kwargs or {}).get(
        "dynamic_feature_names", None
    )
    if feature_names is None:
        return

    expected = len(list(feature_names))
    tf_debugging.assert_equal(
        tf_shape(Xh)[-1],
        tf_constant(expected, tf_int32),
        message="dynamic_feature_names != Xh last dim",
    )
[docs]
def gwl_to_head_m(
    v_m: Tensor,
    scaling_kwargs: dict[str, Any] | None,
    *,
    inputs: dict[str, Tensor] | None = None,
) -> Tensor:
    """
    Convert depth-bgs to head if possible.

    Behavior
    --------
    - If gwl_kind is head-like: return v_m (cast to float32).
      When gwl_kind is unset, the kind is guessed from ``gwl_col``:
      'depth' if the name contains "depth", else 'head'.
    - Otherwise treat as depth and try:
        head = z_surf - depth.
    - If z_surf is missing:
        * use_head_proxy=True  -> return -depth
        * use_head_proxy=False -> return depth

    Parameters
    ----------
    v_m : Tensor
        Groundwater level values.
    scaling_kwargs : dict or None
        Reads ``gwl_kind``, ``gwl_col``, ``gwl_sign``,
        ``use_head_proxy``, ``z_surf_col``, ``gwl_z_meta``,
        ``z_surf_static_index`` and ``static_feature_names``.
    inputs : dict, optional
        Batch inputs searched for a surface-elevation tensor, first
        under the z_surf key names, then inside ``static_features``.

    Notes
    -----
    A diagnostic line is printed whenever no z_surf could be found.
    It reports the effective ``use_head_proxy`` value, i.e. the one
    the fallback actually applies.
    """
    sk = scaling_kwargs or {}
    # --------------------------------------------------
    # 1) Decide whether v_m is head or depth.
    # --------------------------------------------------
    kind_raw = sk.get("gwl_kind", None)
    if kind_raw is None or str(kind_raw).strip() == "":
        gwl_col = str(get_sk(sk, "gwl_col", default=""))
        gwl_col = gwl_col.lower()
        kind = "depth" if ("depth" in gwl_col) else "head"
    else:
        kind = str(kind_raw).lower()
    if kind in ("head", "waterhead", "hydraulic_head"):
        return tf_cast(v_m, tf_float32)
    # --------------------------------------------------
    # 2) Depth convention + proxy behavior.
    # --------------------------------------------------
    sign = str(sk.get("gwl_sign", "down_positive")).lower()
    proxy = bool(sk.get("use_head_proxy", True))
    # --------------------------------------------------
    # 3) Collect possible z_surf keys.
    #    Prefer SI/static key first when available.
    # --------------------------------------------------
    meta = sk.get("gwl_z_meta", {}) or {}
    cols = meta.get("cols", {}) or {}
    z_surf_col = sk.get("z_surf_col", None)
    z_surf_col = z_surf_col or meta.get("z_surf_col", None)
    z_surf_static = cols.get("z_surf_static", None)
    z_surf_raw = cols.get("z_surf_raw", None)
    z_surf_keys = [
        k
        for k in (z_surf_static, z_surf_col, z_surf_raw)
        if k
    ]
    # Dedupe while preserving order.
    seen = set()
    z_surf_keys = [
        k
        for k in z_surf_keys
        if not (k in seen or seen.add(k))
    ]
    # --------------------------------------------------
    # 4) Convert to positive-down depth.
    # --------------------------------------------------
    v_m = tf_cast(v_m, tf_float32)
    depth_m = v_m if sign == "down_positive" else -v_m
    # --------------------------------------------------
    # 5) Try direct inputs[z_surf_key] first.
    # --------------------------------------------------
    z_surf = None
    if inputs is not None:
        for k in z_surf_keys:
            z_surf = inputs.get(k, None)
            if z_surf is not None:
                z_surf = tf_cast(z_surf, tf_float32)
                break
    # --------------------------------------------------
    # 6) If missing, try static_features lookup.
    # --------------------------------------------------
    if z_surf is None and inputs is not None:
        sf = inputs.get("static_features", None)
        if sf is not None:
            sf = tf_cast(sf, tf_float32)
            idx = sk.get("z_surf_static_index", None)
            if idx is None:
                names = get_sk(
                    sk,
                    "static_feature_names",
                    default=None,
                )
                if names is not None:
                    names = list(names)
                    for k in z_surf_keys:
                        if k in names:
                            idx = int(names.index(k))
                            break
            if idx is not None:
                idx_i = int(idx)
                tf_debugging.assert_less(
                    tf_cast(idx_i, tf_int32),
                    tf_shape(sf)[-1],
                    message="z_surf_static_index out of range.",
                )
                # Use the static rank when known; only fall back to
                # a runtime branch when it is not.
                r = getattr(sf.shape, "rank", None)
                if r == 2:
                    z_surf = sf[:, idx_i : idx_i + 1]
                elif r == 3:
                    z_surf = sf[:, :, idx_i : idx_i + 1]
                else:
                    rr = tf_rank(sf)
                    z_surf = tf_cond(
                        tf_equal(rr, 2),
                        lambda: sf[:, idx_i : idx_i + 1],
                        lambda: sf[:, :, idx_i : idx_i + 1],
                    )
    if z_surf is None:
        # Report the flag the fallback in step 8 really uses (default
        # True), not a separately defaulted lookup.
        tf_print(
            "[gwl_to_head_m] z_surf missing ->",
            "use_head_proxy=",
            proxy,
            "returning depth-like quantity (NOT true head)",
        )
    # --------------------------------------------------
    # 7) If we have z_surf: head = z_surf - depth.
    # --------------------------------------------------
    if z_surf is not None:
        r = tf_rank(z_surf)
        z_surf = tf_cond(
            tf_equal(r, 1),
            lambda: z_surf[:, None, None],
            lambda: tf_cond(
                tf_equal(r, 2),
                lambda: z_surf[:, None, :],
                lambda: z_surf,
            ),
        )
        # Broadcast z_surf to match depth_m.
        z_surf = z_surf + tf_zeros_like(depth_m)
        return z_surf - depth_m
    # --------------------------------------------------
    # 8) Fallback: proxy head or keep depth.
    # --------------------------------------------------
    return -depth_m if proxy else depth_m
def _reshape_to_b11(v: Tensor) -> Tensor:
    """Cast to float32 and add singleton axes based on the runtime rank.

    rank 1 -> ``v[:, None, None]``; rank 2 -> ``v[:, None, :]``;
    any other rank is returned unchanged.
    """
    tensor = tf_cast(v, tf_float32)
    rank = tf_rank(tensor)

    def _from_rank1():
        return tensor[:, None, None]

    def _from_rank2():
        return tensor[:, None, :]

    def _as_is():
        return tensor

    def _rank2_or_other():
        return tf_cond(tf_equal(rank, 2), _from_rank2, _as_is)

    return tf_cond(tf_equal(rank, 1), _from_rank1, _rank2_or_other)
[docs]
def get_h_hist_si(
    model,
    inputs: dict[str, Tensor],
    *,
    want_head: bool = True,
) -> Tensor:
    """Return head (or depth) history in SI meters.

    Parameters
    ----------
    model : object
        The model instance (provides ``scaling_kwargs`` and cached indices).
    inputs : dict
        Batch inputs; expects ``dynamic_features`` unless an explicit
        history key (``h_hist_si``, ``head_hist_si`` or ``gwl_hist_si``)
        is provided. Explicit keys take precedence, in that order.
    want_head : bool, default=True
        If True, convert depth-bgs to hydraulic head when possible.

    Returns
    -------
    Tensor
        (B,T,1) tensor in SI meters.

    Raises
    ------
    ValueError
        If neither an explicit history key nor ``dynamic_features``
        is available.
    """
    sk = getattr(model, "scaling_kwargs", None)
    # Explicit override (useful for scenario-driven runs)
    for k in ("h_hist_si", "head_hist_si", "gwl_hist_si"):
        if k in inputs and inputs[k] is not None:
            v = tf_cast(inputs[k], tf_float32)
            # (B,T) -> (B,T,1). Decide from the static rank when it is
            # known so that no tensor is evaluated as a Python bool;
            # keep the runtime check only for unknown-rank tensors.
            static_rank = getattr(
                getattr(v, "shape", None), "rank", None
            )
            if static_rank is None:
                if tf_equal(tf_rank(v), 2):
                    v = v[:, :, None]
            elif static_rank == 2:
                v = v[:, :, None]
            if want_head:
                v = gwl_to_head_m(v, sk, inputs=inputs)
            return v
    Xh = inputs.get("dynamic_features", None)
    if Xh is None:
        raise ValueError(
            "Cannot build head history: missing inputs['dynamic_features'] "
            "and no explicit head history key "
            "(h_hist_si/head_hist_si/gwl_hist_si)."
        )
    Xh = tf_cast(Xh, tf_float32)
    assert_dynamic_names_match_tensor(Xh, sk)
    gwl_idx = get_gwl_dyn_index_cached(model)
    gwl = slice_dynamic_channel(Xh, gwl_idx)
    gwl_si = to_si_head(gwl, sk)
    if not want_head:
        return gwl_si
    return gwl_to_head_m(gwl_si, sk, inputs=inputs)
[docs]
def get_s_init_si(
    model,
    inputs: dict[str, Tensor] | None,
    like: Tensor,
) -> Tensor:
    """Return the initial cumulative subsidence, broadcast to ``like``.

    Sources, in priority order:

    1. the first non-None explicit key in ``inputs`` among
       ``s_init_si``, ``subs_init_si``, ``subs_hist_last_si``,
       ``s_ref_si``, ``subs_ref_si``, ``s_init``, ``subs_init``;
    2. the last time step of the subsidence channel of
       ``inputs['dynamic_features']``, converted to SI;
    3. zeros.

    If the subsidence channel cannot be resolved, the problem is
    reported through the scaling error policy in runtime mode and the
    function falls back to zeros.
    """
    sk = getattr(model, "scaling_kwargs", None)
    if inputs is None:
        return tf_zeros_like(like)

    explicit_keys = (
        "s_init_si",
        "subs_init_si",
        "subs_hist_last_si",
        "s_ref_si",
        "subs_ref_si",
        "s_init",
        "subs_init",
    )
    for key in explicit_keys:
        if key in inputs and inputs[key] is not None:
            return _reshape_to_b11(inputs[key]) + tf_zeros_like(like)

    Xh = inputs.get("dynamic_features", None)
    if Xh is None:
        return tf_zeros_like(like)

    try:
        subs_idx = get_subs_dyn_index_cached(model)
    except Exception as e:
        _handle_scaling_issue(
            getattr(model, "scaling_kwargs", None),
            f"Could not resolve subsidence init channel ({e}). "
            "Falling back to zeros for s_init_si.",
            where="runtime",
        )
        return tf_zeros_like(like)

    Xh = tf_cast(Xh, tf_float32)
    assert_dynamic_names_match_tensor(Xh, sk)
    history = slice_dynamic_channel(Xh, int(subs_idx))
    # Keep the time axis: last step only.
    last_step = history[:, -1:, :]
    return to_si_subsidence(last_step, sk) + tf_zeros_like(like)
[docs]
def get_h_ref_si(
    model,
    inputs: dict[str, Tensor] | None,
    like: Tensor,
) -> Tensor:
    """Return the reference head, broadcast to the shape of ``like``.

    Sources, in priority order:

    1. the first non-None key in ``inputs`` among ``h_ref_si``,
       ``head_ref_si``, ``h_ref``, ``head_ref``;
    2. unless ``model.h_ref_config.mode`` is 'fixed': the last time
       step of the head history built from ``dynamic_features``;
    3. ``model.h_ref`` (default 0.0).
    """
    configured = getattr(
        getattr(model, "h_ref_config", None), "mode", "auto"
    )
    # Anything other than 'fixed' behaves as 'auto'.
    is_fixed = str(configured).lower().strip() == "fixed"

    if inputs is not None:
        for key in ("h_ref_si", "head_ref_si", "h_ref", "head_ref"):
            if (key in inputs) and (inputs[key] is not None):
                return _reshape_to_b11(inputs[key]) + tf_zeros_like(like)

        has_history = (
            "dynamic_features" in inputs
            and inputs["dynamic_features"] is not None
        )
        if not is_fixed and has_history:
            head_history = get_h_hist_si(model, inputs, want_head=True)
            return head_history[:, -1:, :] + tf_zeros_like(like)

    fallback = tf_cast(getattr(model, "h_ref", 0.0), tf_float32)
    fallback = fallback[None, None, None]
    return fallback + tf_zeros_like(like)
[docs]
def infer_dt_units_from_t(
    t_BH1: Tensor,
    scaling_kwargs: dict[str, Any] | None,
    *,
    eps: float = 1e-12,
) -> Tensor:
    """
    Infer per-step dt in *time_units* from time tensor t(B,H,1).

    Shapes
    ------
    t_BH1 : (B,H,1)
    returns: (B,H,1)

    Parameters
    ----------
    t_BH1 : Tensor
        Time coordinates, converted to float32.
    scaling_kwargs : dict or None
        Reads ``coords_normalized``, ``coord_ranges`` and
        ``dt_min_units``.
    eps : float, default=1e-12
        Absolute lower bound on the returned dt.

    Notes
    -----
    - dt uses diffs along H; first step uses the first diff.
    - With H <= 1, dt is 1.0 everywhere.
    - If coords are normalized, dt is multiplied by the de-normalization
      time range tR (from coord_ranges()).
    - The absolute value is taken; exact zeros are replaced by 1.0.
    - Output is clipped to >= max(dt_min_units, eps), where
      ``dt_min_units`` defaults to 1e-6.

    Raises
    ------
    ValueError
        If coords are normalized but no time range is configured
        (raised when the multi-step branch is built).
    """
    sk = scaling_kwargs or {}
    t = tf_convert_to_tensor(t_BH1, dtype=tf_float32)
    # t shape: (B,H,1)
    H = tf_shape(t)[1]
    dt_default = tf_ones_like(t)  # (B,H,1), safe in-graph

    def _multi_step():
        diffs = t[:, 1:, :] - t[:, :-1, :]  # (B,H-1,1)
        dt_first = diffs[:, :1, :]  # (B,1,1)
        dt = tf_concat([dt_first, diffs], axis=1)  # (B,H,1)
        # If coords were normalized, dt is still normalized -> scale back
        if bool(sk.get("coords_normalized", False)):
            tR, _, _ = coord_ranges(sk)
            if tR is None:
                raise ValueError(
                    "coords_normalized=True but coord_ranges missing."
                )
            dt = dt * tf_constant(float(tR), dtype=tf_float32)
        return dt

    # if H <= 1: ones; else: diffs
    dt = tf_cond(
        tf_less_equal(H, 1), lambda: dt_default, _multi_step
    )
    dt = tf_abs(dt)
    # Replace exact zeros (repeated time stamps) by the unit default.
    dt_pos = tf_greater(dt, tf_constant(0.0, tf_float32))
    dt_pos_f = tf_cast(dt_pos, tf_float32)
    dt = dt * dt_pos_f + dt_default * (1.0 - dt_pos_f)
    # Honor both floors: the configured minimum step and the caller's
    # eps. With the defaults (1e-6 vs 1e-12) the configured one wins.
    dt_floor = max(
        float(get_sk(sk, "dt_min_units", default=1e-6)),
        float(eps),
    )
    dt = tf_maximum(dt, tf_constant(dt_floor, tf_float32))
    return dt
# -------------------------------------------------
# Training strategy gates (Q and subsidence residual)
# ---------------------------------------------------------------------
[docs]
def policy_gate(
    step: Tensor,
    policy: str,
    *,
    warmup_steps: int = 0,
    ramp_steps: int = 0,
    dtype: Any = tf_float32,
) -> Tensor:
    r"""Compute a scalar gate in ``[0,1]`` from a policy name and a step.

    Parameters
    ----------
    step : Tensor
        Step counter; cast to int32 before comparison.
    policy : str
        ``always_on`` (also ``on``/``true``/``1``, or empty/None)
        gives 1. ``always_off`` (also ``off``/``false``/``0``) gives 0.
        Any other value selects the warmup schedule described below.
    warmup_steps : int, default=0
        Number of initial steps during which the gate stays at 0.
    ramp_steps : int, default=0
        Length of the linear 0->1 ramp that follows the warmup.
        If 0, the gate switches to 1 at ``warmup_steps``.
    dtype : dtype, default=tf_float32
        Output dtype.

    Notes
    -----
    With both ``warmup_steps`` and ``ramp_steps`` <= 0 the warmup
    schedule degenerates to a constant 1.
    """
    name = (policy or "always_on").strip().lower()
    if name in ("always_on", "on", "true", "1"):
        return tf_constant(1.0, dtype=dtype)
    if name in ("always_off", "off", "false", "0"):
        return tf_constant(0.0, dtype=dtype)

    warmup = int(warmup_steps or 0)
    ramp = int(ramp_steps or 0)
    if warmup <= 0 and ramp <= 0:
        return tf_constant(1.0, dtype=dtype)

    step_i = tf_cast(step, tf_int32)
    if ramp <= 0:
        # Hard switch at the end of the warmup.
        past_warmup = tf_greater_equal(
            step_i, tf_constant(warmup, tf_int32)
        )
        return tf_cast(past_warmup, dtype)

    # Linear ramp, clamped to [0, 1].
    step_f = tf_cast(step_i, dtype)
    warmup_f = tf_constant(float(warmup), dtype)
    ramp_f = tf_constant(float(ramp), dtype)
    progress = (step_f - warmup_f) / ramp_f
    progress = tf_maximum(tf_constant(0.0, dtype), progress)
    return tf_minimum(tf_constant(1.0, dtype), progress)
# ---------------------------------------------------------------------
# Derived SI conversion helpers (optional, but recommended)
# ---------------------------------------------------------------------
[docs]
def finalize_scaling_kwargs(
    sk: dict[str, Any],
) -> dict[str, Any]:
    """Return a copy of ``sk`` enriched with derived SI constants.

    Adds (when possible):

    - ``seconds_per_time_unit``: float, only if not already present.
      Unrecognized ``time_units`` values map to 1.0.
    - ``coord_ranges_si``: dict with keys ``t`` (seconds), ``x``/``y``
      (meters). Requires ``coord_ranges`` to contain ``t``, ``x`` and
      ``y``. x/y are scaled by ``deg_to_m_lon``/``deg_to_m_lat`` when
      ``coords_in_degrees`` is set and both factors are available.
    - ``coord_inv_ranges_si``: reciprocals of the above, with each
      range floored at 1e-12.

    Notes
    -----
    ``None`` is returned unchanged. The input dict is not mutated.
    Intended to be called once while assembling ``scaling_kwargs`` so
    the constants need not be recomputed later.
    """
    if sk is None:
        return sk
    out = dict(sk)

    # Build the unit -> seconds table from groups of synonyms.
    # 31556952 s = 365.2425 days, i.e. the mean Gregorian year.
    unit_groups = (
        (("second", "sec", "s"), 1.0),
        (("minute", "min", "m"), 60.0),
        (("hour", "h"), 3600.0),
        (("day", "d"), 86400.0),
        (("year", "yr", "y"), 31556952.0),
    )
    seconds_by_unit = {
        unit: seconds
        for units, seconds in unit_groups
        for unit in units
    }

    unit_name = str(get_sk(out, "time_units", default="second"))
    unit_name = unit_name.strip().lower()
    sec_u = float(seconds_by_unit.get(unit_name, 1.0))
    out.setdefault("seconds_per_time_unit", sec_u)

    ranges = get_sk(out, "coord_ranges", default=None)
    has_all_axes = isinstance(ranges, Mapping) and all(
        axis in ranges for axis in ("t", "x", "y")
    )
    if not has_all_axes:
        return out

    t_span = float(ranges.get("t", 1.0))
    x_span = float(ranges.get("x", 1.0))
    y_span = float(ranges.get("y", 1.0))

    # If coordinates are degrees, convert spans to meters.
    if bool(get_sk(out, "coords_in_degrees", default=False)):
        lon_factor = get_sk(out, "deg_to_m_lon", default=None)
        lat_factor = get_sk(out, "deg_to_m_lat", default=None)
        if lon_factor is not None and lat_factor is not None:
            x_span *= float(lon_factor)
            y_span *= float(lat_factor)

    # Convert time span to seconds (important if coords_normalized=True).
    t_span *= sec_u

    out["coord_ranges_si"] = {"t": t_span, "x": x_span, "y": y_span}
    floor = 1e-12
    out["coord_inv_ranges_si"] = {
        "t": 1.0 / max(t_span, floor),
        "x": 1.0 / max(x_span, floor),
        "y": 1.0 / max(y_span, floor),
    }
    return out
[docs]
def coord_ranges_si(
    sk: dict[str, Any],
) -> tuple[float | None, float | None, float | None]:
    """Return the (t, x, y) coordinate spans in SI units.

    A ``coord_ranges_si`` mapping already present in ``sk`` (with
    ``t``, ``x`` and ``y``) is used as is. Otherwise the spans are
    derived by running ``finalize_scaling_kwargs`` on ``sk``. If
    neither route provides all three entries, ``(None, None, None)``
    is returned.
    """

    def _spans_from(cfg):
        # Return the three spans as floats, or None if incomplete.
        spans = get_sk(cfg, "coord_ranges_si", default=None)
        if isinstance(spans, Mapping) and all(
            axis in spans for axis in ("t", "x", "y")
        ):
            return (
                float(spans["t"]),
                float(spans["x"]),
                float(spans["y"]),
            )
        return None

    found = _spans_from(sk)
    if found is None:
        found = _spans_from(finalize_scaling_kwargs(sk))
    if found is None:
        return None, None, None
    return found