# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
GeoPrior maths helpers (physics terms + scaling).
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
import numpy as np
from ...api.docs import (
DocstringComponents,
_halnet_core_params,
)
from ...compat.types import TensorLike
from ...logging import OncePerMessageFilter, get_logger
from .. import KERAS_DEPS, dependency_message
from .utils import coord_ranges, get_h_ref_si, get_sk
K = KERAS_DEPS
Tensor = K.Tensor
Dataset = K.Dataset
GradientTape = K.GradientTape
Constraint = K.Constraint
tf_abs = K.abs
tf_argmin = K.argmin
tf_broadcast_to = K.broadcast_to
tf_cast = K.cast
tf_clip_by_value = K.clip_by_value
tf_concat = K.concat
tf_cond = K.cond
tf_constant = K.constant
tf_convert_to_tensor = K.convert_to_tensor
tf_cumsum = K.cumsum
tf_debugging = K.debugging
tf_equal = K.equal
tf_exp = K.exp
tf_expand_dims = K.expand_dims
tf_float32 = K.float32
tf_gather = K.gather
tf_greater = K.greater
tf_identity = K.identity
tf_int32 = K.int32
tf_is_inf = K.is_inf
tf_is_nan = K.is_nan
tf_log = K.log
tf_logical_and = K.logical_and
tf_logical_or = K.logical_or
tf_math = K.math
tf_maximum = K.maximum
tf_minimum = K.minimum
tf_ones_like = K.ones_like
tf_pow = K.pow
tf_print = K.print
tf_rank = K.rank
tf_reduce_any = K.reduce_any
tf_reduce_max = K.reduce_max
tf_reduce_mean = K.reduce_mean
tf_reduce_min = K.reduce_min
tf_reduce_sum = K.reduce_sum
tf_reshape = K.reshape
tf_scan = K.scan
tf_shape = K.shape
tf_sigmoid = K.sigmoid
tf_softplus = K.softplus
tf_sqrt = K.sqrt
tf_square = K.square
tf_stack = K.stack
tf_stop_gradient = K.stop_gradient
tf_switch_case = K.switch_case
tf_tile = K.tile
tf_transpose = K.transpose
tf_where = K.where
tf_zeros = K.zeros
tf_zeros_like = K.zeros_like
register_keras_serializable = K.register_keras_serializable
deserialize_keras_object = K.deserialize_keras_object
# Optional: silence autograph verbosity in TF-backed runtimes.
tf_autograph = getattr(K, "autograph", None)
if tf_autograph is not None:
tf_autograph.set_verbosity(0)
# Module logger + shared docs
DEP_MSG = dependency_message("subsidence.maths")
logger = get_logger(__name__)
logger.addFilter(OncePerMessageFilter())
_param_docs = DocstringComponents.from_nested_components(
base=DocstringComponents(_halnet_core_params),
)
# Constants + types
_EPSILON = 1e-15
AxisLike = int | Sequence[int] | None
# Time units + scaling
TIME_UNIT_TO_SECONDS = {
"unitless": 1.0,
"step": 1.0,
"index": 1.0,
"s": 1.0,
"sec": 1.0,
"second": 1.0,
"seconds": 1.0,
"min": 60.0,
"minute": 60.0,
"minutes": 60.0,
"h": 3600.0,
"hr": 3600.0,
"hour": 3600.0,
"hours": 3600.0,
"day": 86400.0,
"days": 86400.0,
"week": 7.0 * 86400.0,
"weeks": 7.0 * 86400.0,
"year": 31556952.0,
"years": 31556952.0,
"yr": 31556952.0,
"month": 31556952.0 / 12.0,
"months": 31556952.0 / 12.0,
}
[docs]
class LogClipConstraint(Constraint):
r"""
NaN-safe clip constraint for log-parameters.
This constraint is intended for parameters stored in log-space,
such as ``logK``, ``logSs``, or ``log_tau``, where the model must
enforce hard bounds:
.. math::
w \in [w_{min}, w_{max}]
Why this exists
---------------
In TensorFlow, ``clip_by_value`` does not repair invalid values:
.. math::
clip(NaN, a, b) = NaN
Therefore, if a parameter ever becomes non-finite (NaN or Inf),
a plain clipping constraint will silently keep it invalid and
training can destabilize. This class explicitly sanitizes
non-finite entries before applying the clip.
Mapping
-------
Given an input weight tensor ``w`` and bounds
``min_value`` and ``max_value``:
1) Sanitize non-finite entries:
.. math::
w_{safe}[i]
=
\begin{cases}
w[i], & \text{if } w[i] \text{ is finite} \\
w_{min}, & \text{otherwise}
\end{cases}
2) Apply hard clipping:
.. math::
w_{out}
=
\min(\max(w_{safe}, w_{min}), w_{max})
The output is guaranteed to be finite as long as
``min_value`` and ``max_value`` are finite.
Parameters
----------
min_value : float or Tensor
Lower bound for the constrained tensor in log-space. This is
cast to ``tf_float32`` and stored.
max_value : float or Tensor
Upper bound for the constrained tensor in log-space. This is
cast to ``tf_float32`` and stored.
Returns
-------
Constraint
A callable constraint object compatible with Keras variables.
When applied, it returns a clipped tensor in float32.
Notes
-----
* This constraint is most appropriate for parameters represented
in log-space because hard bounds in log-space correspond to
multiplicative bounds in linear space.
* Sanitizing to ``min_value`` is a conservative choice:
it prevents NaN propagation while keeping the parameter within
the feasible region. If you prefer a different fallback (e.g.
0 or the midpoint), change the replacement value accordingly.
* The constraint operates in ``tf_float32`` for speed and
compatibility with typical training graphs.
Examples
--------
Constrain a learnable log-parameter:
.. code-block:: python
logK = tf.Variable(
initial_value=0.0,
constraint=LogClipConstraint(-20.0, 5.0),
trainable=True,
dtype=tf.float32,
)
In a Keras layer weight:
.. code-block:: python
self.log_tau = self.add_weight(
name="log_tau",
shape=(1,),
initializer="zeros",
trainable=True,
constraint=LogClipConstraint(log_tau_min, log_tau_max),
)
See Also
--------
keras.constraints.Constraint
Base class for Keras constraints.
tf.clip_by_value
Elementwise clipping. Note that it does not repair NaNs.
tf.where
Used here to sanitize non-finite entries before clipping.
"""
[docs]
def __init__(self, min_value, max_value):
self.min_value = tf_cast(min_value, tf_float32)
self.max_value = tf_cast(max_value, tf_float32)
def __call__(self, w):
w = tf_cast(w, tf_float32)
w = tf_where(
tf_math.is_finite(w),
w,
self.min_value,
)
return tf_clip_by_value(
w,
self.min_value,
self.max_value,
)
[docs]
def vprint(verbose: int, *args) -> None:
"""Verbose print (eager-friendly)."""
if int(verbose) > 0:
tf_print(*args)
[docs]
def tf_print_nonfinite(
tag: str, x: Tensor, summarize: int = 6
) -> Tensor:
"""Print a compact report ONLY if x contains NaN/Inf (graph-safe)."""
x = tf_convert_to_tensor(x, dtype=tf_float32)
is_nan = tf_is_nan(x)
is_inf = tf_is_inf(x)
is_bad = tf_logical_or(is_nan, is_inf)
n_nan = tf_reduce_sum(tf_cast(is_nan, tf_int32))
n_inf = tf_reduce_sum(tf_cast(is_inf, tf_int32))
n_bad = tf_reduce_sum(tf_cast(is_bad, tf_int32))
def _do_print():
# safe stats: replace bad values with 0 for min/max/mean
x_safe = tf_where(is_bad, tf_zeros_like(x), x)
tf_print(
"[NONFINITE]",
tag,
"| shape=",
tf_shape(x),
"| n_bad=",
n_bad,
"n_nan=",
n_nan,
"n_inf=",
n_inf,
"| min=",
tf_reduce_min(x_safe),
"| max=",
tf_reduce_max(x_safe),
"| mean=",
tf_reduce_mean(x_safe),
summarize=summarize,
)
return tf_constant(0, tf_int32)
return tf_cond(
n_bad > 0, _do_print, lambda: tf_constant(0, tf_int32)
)
# ---------------------------------------------------------------------
# Q-kind support (gw forcing)
# ---------------------------------------------------------------------
[docs]
def resolve_q_kind(sk: dict[str, Any] | None) -> str:
"""Normalize Q meaning for gw forcing."""
if not sk:
return "per_volume"
v = get_sk(
sk,
"Q_kind",
"q_kind",
"gw_q_kind",
default="per_volume",
)
mode = str(v).strip().lower()
if mode in (
"pervol",
"per_volume",
"volumetric",
"per_volume_rate",
):
return "per_volume"
if mode in (
"recharge",
"recharge_rate",
"infiltration",
"r",
):
return "recharge_rate"
if mode in ("head_rate", "dhdt", "head_forcing", "qh"):
return "head_rate"
return "per_volume"
[docs]
def q_to_gw_source_term_si(
model,
Q_logits: Tensor,
*,
Ss_field: TensorLike | None,
H_field: TensorLike | None,
coords_normalized: bool,
t_range_units: TensorLike | None,
time_units: str | None,
scaling_kwargs: dict[str, Any] | None,
H_floor: float = 1e0, # 1e-6,
verbose: int = 0,
) -> Tensor:
r"""
Convert ``Q_logits`` into a GW source term in SI units.
This helper maps the network output ``Q_logits`` into a source
term :math:`Q_{term}` that is compatible with the groundwater
PDE residual used by the model:
.. math::
R_{gw}
=
S_s \, \frac{\partial h}{\partial t}
-
\nabla \cdot (K \nabla h)
-
Q_{term}
The returned tensor always has units of 1/s so it can be
subtracted directly in :math:`R_{gw}`.
Overview
--------
The model can interpret the raw output ``Q_logits`` in multiple
ways depending on ``Q_kind`` resolved from ``scaling_kwargs``.
All modes must end with :math:`Q_{term}` in 1/s, but the meaning
of ``Q_logits`` differs:
``per_volume``
``Q_logits`` represents a volumetric forcing rate already in
inverse time units (either 1/time_unit or 1/s depending on
flags). This is the simplest and most backward-compatible
interpretation.
``recharge_rate``
``Q_logits`` represents a recharge flux expressed as a length
rate (m/time_unit or m/s). It is converted into a volumetric
rate by dividing by an effective thickness :math:`H`:
.. math::
Q_{term} = \frac{R}{H}
where :math:`R` is in m/s and :math:`H` is in m, giving 1/s.
``head_rate``
``Q_logits`` represents a head-rate forcing (m/time_unit or
m/s) that enters the storage term. Since the storage term is
:math:`S_s \, dh/dt`, the equivalent forcing in the residual
is:
.. math::
Q_{term} = S_s \, q_h
where :math:`q_h` is in m/s and :math:`S_s` is in 1/m,
yielding 1/s.
Time normalization handling
---------------------------
When coordinates are normalized (typical in this project), the
time coordinate is scaled by a range factor :math:`t_R`. If
``Q_logits`` was produced in that normalized coordinate system,
it must be converted back to the intended time units before any
SI conversion. This function delegates that correction to the
internal helper ``_apply_q_normalized_time_rule`` using:
* ``coords_normalized``
* ``t_range_units`` (the time range in model time units)
* ``scaling_kwargs`` flags
After this correction, unit conversion is applied based on the
selected ``Q_kind`` and SI flags.
Mode details
------------
per_volume
~~~~~~~~~~
In this mode the model output is treated as already being an
inverse-time quantity.
If either of the following flags is True:
* ``Q_in_per_second`` or
* ``Q_in_si``
then ``Q_logits`` is assumed to already be in 1/s and returned
directly.
Otherwise, it is assumed to be in 1/time_unit and converted to
1/s via:
.. math::
Q_{term} = Q \cdot \frac{1}{sec\_per\_time\_unit}
where ``sec_per_time_unit`` depends on ``time_units``.
recharge_rate
~~~~~~~~~~~~~
Here, ``Q_logits`` is treated as a length rate :math:`R`:
.. math::
R \in \mathrm{m}/\mathrm{s}
It is converted from m/time_unit to m/s unless
``Q_length_in_si`` is True.
The volumetric rate is then computed as:
.. math::
Q_{term} = \frac{R}{H}
To prevent division instability, the thickness is floored:
.. math::
H_{safe} = \max(H, H_{floor})
and the returned value is:
.. math::
Q_{term} = \frac{R}{H_{safe}}
head_rate
~~~~~~~~~
Here, ``Q_logits`` is treated as a head-rate :math:`q_h`:
.. math::
q_h \in \mathrm{m}/\mathrm{s}
It is converted from m/time_unit to m/s unless
``Q_length_in_si`` is True.
The residual forcing is:
.. math::
Q_{term} = S_s \, q_h
If ``Ss_field`` is missing, a robust fallback consistent with
the rest of the model is used:
.. math::
S_s \approx m_v \, gamma_w
where :math:`m_v` is taken from the model and :math:`gamma_w`
is the configured unit weight of water.
Parameters
----------
This function is typically called from the physics core where
all arguments are already defined. For meanings and expected
shapes, refer to the caller that constructs the GW residual
and its inputs.
In brief:
* ``Q_logits`` is the network output for the forcing channel.
* ``Ss_field`` and ``H_field`` are the effective fields used
by the PDE, broadcastable to the batch-horizon layout.
* ``coords_normalized`` and ``t_range_units`` describe how time
normalization was applied.
* ``time_units`` specifies the units used for conversion when
interpreting rates.
* ``scaling_kwargs`` provides configuration including ``Q_kind``
and unit flags.
Returns
-------
Q_term : Tensor
Source term :math:`Q_{term}` in 1/s, broadcastable to the
GW residual layout (typically (B,H,1)).
Raises
------
ValueError
If ``Q_kind='recharge_rate'`` is selected but ``H_field`` is
not provided.
Notes
-----
* Choose ``per_volume`` when you want a direct 1/s forcing that
can be interpreted as a volumetric sink/source per unit volume.
* Choose ``recharge_rate`` when you want the network to predict
a flux-like term (m/s) that becomes volumetric forcing by
dividing by thickness.
* Choose ``head_rate`` when you want forcing to act like an
additive term to :math:`dh/dt` inside the storage term.
Examples
--------
Assuming the physics core has already produced ``Q_logits`` and
the effective fields:
.. code-block:: python
Q_term = q_to_gw_source_term_si(
model,
Q_logits,
Ss_field=Ss_field,
H_field=H_field,
coords_normalized=coords_normalized,
t_range_units=t_range_units,
time_units=time_units,
scaling_kwargs=scaling_kwargs,
)
See Also
--------
rate_to_per_second
Converts values in 1/time_unit or m/time_unit to SI per
second rates.
_apply_q_normalized_time_rule
Corrects rates if the time coordinate was normalized.
resolve_q_kind
Resolves the configured Q interpretation mode from the
scaling configuration.
"""
sk = scaling_kwargs or {}
kind = resolve_q_kind(sk)
Q_base = tf_cast(Q_logits, tf_float32)
Q_base = _apply_q_normalized_time_rule(
Q_base,
sk=sk,
coords_normalized=coords_normalized,
t_range_units=t_range_units,
)
if kind == "per_volume":
# Backward-compatible flags for volumetric Q:
Q_in_per_second = bool(
get_sk(sk, "Q_in_per_second", default=False)
)
Q_in_si = bool(get_sk(sk, "Q_in_si", default=False))
if Q_in_per_second or Q_in_si:
Q_per_s = Q_base
else:
Q_per_s = rate_to_per_second(
Q_base, time_units=time_units
)
vprint(
verbose,
"Q_kind=per_volume, Q_term(1/s)=",
Q_per_s,
)
return Q_per_s
# For the other kinds, interpret Q as a LENGTH RATE (m/time)
# Use a *separate* flag so we don't conflict with Q_in_si default=True.
Q_len_in_si = bool(
get_sk(sk, "Q_length_in_si", default=False)
)
if Q_len_in_si:
Q_m_per_s = Q_base
else:
Q_m_per_s = rate_to_per_second(
Q_base, time_units=time_units
)
if kind == "recharge_rate":
if H_field is None:
raise ValueError(
"Q_kind='recharge_rate' requires H_field."
)
H_safe = tf_maximum(
tf_cast(H_field, tf_float32),
tf_constant(H_floor, tf_float32),
)
Q_term = Q_m_per_s / H_safe
vprint(
verbose,
"Q_kind=recharge_rate, Q_term(1/s)=",
Q_term,
)
return Q_term
# kind == "head_rate"
if Ss_field is None:
# robust fallback consistent with your consolidation logic:
Ss_eff = model._mv_value() * model.gamma_w
else:
Ss_eff = Ss_field
Q_term = tf_cast(Ss_eff, tf_float32) * Q_m_per_s
vprint(verbose, "Q_kind=head_rate, Q_term(1/s)=", Q_term)
return Q_term
def _apply_q_normalized_time_rule(
Q_base: Tensor,
*,
sk: dict[str, Any] | None,
coords_normalized: bool,
t_range_units: TensorLike | None,
) -> Tensor:
"""
If Q was produced w.r.t normalized time, convert it back to per-time_unit
by dividing by t_range_units.
"""
if not sk:
return Q_base
Q_wrt_norm_t = bool(
get_sk(sk, "Q_wrt_normalized_time", default=False)
)
if coords_normalized and Q_wrt_norm_t:
if t_range_units is None:
tR, _, _ = coord_ranges(sk)
if tR is None:
raise ValueError(
"Q_wrt_normalized_time=True but coord_ranges['t'] missing."
)
t_range_units = tf_constant(float(tR), tf_float32)
Q_base = Q_base / (
t_range_units + tf_constant(_EPSILON, tf_float32)
)
return Q_base
[docs]
def q_to_per_second(
Q_base: Tensor,
*,
scaling_kwargs: dict[str, Any] | None,
time_units: str | None,
coords_normalized: bool,
t_range_units: TensorLike | None = None,
eps: float = 1e-12,
) -> Tensor:
"""
Normalize Q into 1/s.
Assumed meaning (recommended default):
Q_kind = "per_volume" -> Q is already 1/time_unit or 1/s, representing
volumetric source/sink per unit volume.
If coords_normalized and Q_wrt_normalized_time=True, we de-normalize
by the time range first (same chain rule as dh/dt).
"""
sk = scaling_kwargs or {}
Q = tf_cast(Q_base, tf_float32)
# If produced w.r.t normalized time, de-normalize by t_range (in time_units)
if coords_normalized and bool(
get_sk(sk, "Q_wrt_normalized_time", default=False)
):
if t_range_units is None:
tR, _, _ = coord_ranges(sk)
if tR is None:
raise ValueError(
"Q_wrt_normalized_time=True but coord_ranges['t'] missing."
)
t_range_units = tf_constant(float(tR), tf_float32)
Q = Q / (t_range_units + tf_constant(eps, tf_float32))
# Interpretation:
# - If Q_in_per_second=True: Q already 1/s
# - Else: treat Q as 1/time_units and convert to 1/s
if bool(get_sk(sk, "Q_in_per_second", default=False)):
return Q
# IMPORTANT: I recommend default=False here (safer).
# Keep your current behavior if you must, but "Q_in_si" is ambiguous.
if bool(get_sk(sk, "Q_in_si", default=False)):
return Q
return rate_to_per_second(Q, time_units=time_units)
[docs]
def cons_step_to_cons_residual(
cons_step_m: Tensor,
*,
dt_units: Tensor,
scaling_kwargs: dict[str, Any] | None,
time_units: str | None,
eps: float = 1e-12,
) -> Tensor:
"""
Convert consolidation step residual (meters per step) into the chosen
residual units. Supported outputs are ``"step"`` for meters,
``"time_unit"`` for meters per time unit, and ``"second"`` for
meters per second (SI rate).
"""
sk = scaling_kwargs or {}
mode = resolve_cons_units(sk)
# dt safety (in time_units, e.g. years)
dt_min = float(get_sk(sk, "dt_min_units", default=1e-6))
dt_u = tf_maximum(
tf_abs(tf_cast(dt_units, tf_float32)),
tf_constant(dt_min, tf_float32),
)
if mode == "step":
return cons_step_m
if mode == "time_unit":
return cons_step_m / dt_u
# default: seconds
dt_sec = dt_to_seconds(dt_u, time_units=time_units)
dt_sec = tf_maximum(dt_sec, tf_constant(eps, tf_float32))
return cons_step_m / dt_sec
# ---------------------------------------------------------------------
# Physics residuals / priors
# ---------------------------------------------------------------------
def _canon_mv_prior_mode(v) -> str:
"""
Normalize mv-prior mode string to canonical labels.
"""
if v is None:
return "calibrate"
s = str(v).strip().lower()
s = s.replace("-", "_")
# ---- explicit off/disable ----
if s in (
"off",
"none",
"disabled",
"disable",
"false",
"0",
):
return "off"
# Default / detach-style synonyms.
if s in (
"default",
"detach",
"stopgrad",
"stop_grad",
"stop_gradient",
"calibrate",
"calibrate_mv",
):
return "calibrate"
# Fully coupled (can be unstable).
if s in (
"field",
"ss_field",
"backprop",
"coupled",
):
return "field"
# Prefer log-parameterization (safer anchoring).
if s in (
"logss",
"log_ss",
"logs",
):
return "logss"
# Unknown: keep user value (but non-empty).
return s or "calibrate"
def _get_mv_prior_mode(model) -> str:
"""
Resolve mv-prior mode from scaling kwargs (alias-safe).
Notes
-----
We try top-level keys first, then `bounds` fallback.
"""
sk = getattr(model, "scaling_kwargs", None) or {}
# 1) Top-level scaling kwargs (alias-safe).
v = get_sk(sk, "mv_prior_mode", default=None)
# 2) Nested bounds fallback (common pattern in this codebase).
if v is None:
b = sk.get("bounds", None) or {}
v = get_sk(b, "mv_prior_mode", default=None)
return _canon_mv_prior_mode(v)
def _resolve_mv_prior_weight(
model,
*,
weight=None,
warmup_steps=None,
step=None,
dtype=tf_float32,
) -> TensorLike | None:
"""
Resolve mv-prior weight with delay + warmup.
Keys
----
mv_schedule_unit: "epoch" or "step"
mv_delay_epochs, mv_warmup_epochs
mv_delay_steps, mv_warmup_steps
mv_steps_per_epoch (epoch->step)
"""
sk = getattr(model, "scaling_kwargs", None) or {}
b = sk.get("bounds", None) or {}
# ----------------------------
# Base weight.
# ----------------------------
if weight is None:
weight = get_sk(sk, "mv_weight", default=None)
if weight is None:
weight = get_sk(b, "mv_weight", default=None)
if weight is None:
return None
w = tf_constant(float(weight), dtype)
w = _finite_or_zero(w)
# No step => constant weight.
if step is None:
return w
# ----------------------------
# Schedule unit.
# ----------------------------
unit = get_sk(sk, "mv_schedule_unit", default=None)
if unit is None:
unit = get_sk(b, "mv_schedule_unit", default=None)
if unit is None:
unit = "step" if warmup_steps is not None else "epoch"
unit = str(unit).strip().lower()
if unit not in ("epoch", "step"):
unit = "step"
# ----------------------------
# Epoch params.
# ----------------------------
de = get_sk(sk, "mv_delay_epochs", default=None)
if de is None:
de = get_sk(b, "mv_delay_epochs", default=None)
we = get_sk(sk, "mv_warmup_epochs", default=None)
if we is None:
we = get_sk(b, "mv_warmup_epochs", default=None)
# ----------------------------
# Step params.
# ----------------------------
ds = get_sk(sk, "mv_delay_steps", default=None)
if ds is None:
ds = get_sk(b, "mv_delay_steps", default=None)
if warmup_steps is None:
ws = get_sk(sk, "mv_warmup_steps", default=None)
if ws is None:
ws = get_sk(b, "mv_warmup_steps", default=None)
else:
ws = warmup_steps
spe = get_sk(sk, "mv_steps_per_epoch", default=None)
if spe is None:
spe = get_sk(b, "mv_steps_per_epoch", default=None)
# ----------------------------
# Convert to ints.
# ----------------------------
def _to_int(v):
if v is None:
return None
try:
return int(v)
except Exception:
return None
delay_s = _to_int(ds)
warm_s = _to_int(ws)
# Epoch -> step conversion.
if unit == "epoch":
spe_i = _to_int(spe)
if spe_i is not None and spe_i > 0:
if delay_s is None:
de_i = _to_int(de) or 0
delay_s = max(0, de_i) * spe_i
if warm_s is None:
we_i = _to_int(we) or 0
warm_s = max(0, we_i) * spe_i
if delay_s is None:
delay_s = 0
# ----------------------------
# Ramp with delay + warmup.
# ----------------------------
s = tf_cast(step, dtype)
s = _finite_or_zero(s)
d = tf_constant(float(delay_s), dtype)
d = _finite_or_zero(d)
# Hard gate if warmup missing/0.
if (warm_s is None) or (warm_s <= 0):
one = tf_constant(1.0, dtype)
zero = tf_constant(0.0, dtype)
ramp = tf_where(s >= d, one, zero)
return w * ramp
wu = tf_constant(float(max(1, warm_s)), dtype)
wu = _finite_or_zero(wu)
ramp = tf_clip_by_value((s - d) / wu, 0.0, 1.0)
ramp = _finite_or_zero(ramp)
return w * ramp
[docs]
def resolve_mv_gamma_log_target_from_logSs(
model,
logSs,
*,
eps=_EPSILON,
verbose=0,
) -> Tensor:
"""
Like resolve_mv_gamma_log_target(), but uses logSs.
This is the preferred path for mode='logss' because it
avoids the 1/Ss gradient amplification from log(Ss_field).
"""
mv_units = _get_mv_prior_units(model)
log_mv = _safe_log_mv(model, eps=eps)
log_gw = _safe_log_gw(model, eps=eps)
# Strict path: smooth and stable.
if mv_units != "auto":
log_target = log_mv + log_gw
vprint(verbose, "mv_prior_units:", mv_units)
vprint(verbose, "log_target(strict):", log_target)
return log_target
# Auto path: choose 1e3 convention by matching mean(logSs).
logSs = tf_cast(logSs, tf_float32)
eps_t = tf_constant(float(eps), tf_float32)
log_eps = tf_log(eps_t)
logSs = tf_where(tf_math.is_finite(logSs), logSs, log_eps)
logSs_mean = tf_reduce_mean(logSs)
log1000 = tf_log(tf_constant(1000.0, tf_float32))
log_mv_c = tf_stack(
[log_mv, log_mv - log1000, log_mv, log_mv - log1000],
)
log_gw_c = tf_stack(
[log_gw, log_gw, log_gw + log1000, log_gw + log1000],
)
log_targets = log_mv_c + log_gw_c
errs = tf_abs(logSs_mean - log_targets)
idx = tf_cast(
tf_argmin(tf_stop_gradient(errs), axis=0),
tf_int32,
)
log_target = tf_gather(log_targets, idx)
vprint(verbose, "mv_prior_units:", mv_units)
vprint(verbose, "mv/gw idx:", idx)
vprint(verbose, "log_target(auto):", log_target)
return log_target
def _mv_prior_disabled_return(
*,
as_loss: bool,
Ss_field: TensorLike | None,
logSs: TensorLike | None,
dtype=tf_float32,
) -> Tensor:
"""
Return zeros for disabled mv-prior.
- as_loss=True -> scalar 0
- as_loss=False -> zeros_like(logSs or Ss_field)
"""
if bool(as_loss):
return tf_constant(0.0, dtype)
ref = logSs if (logSs is not None) else Ss_field
if ref is None:
return tf_constant(0.0, dtype)
ref = tf_cast(ref, dtype)
return tf_zeros_like(ref)
def _mv_prior_is_disabled(model, *, mode: str) -> bool:
"""
True if mv-prior should be skipped.
"""
if mode == "off":
return True
lam = float(getattr(model, "lambda_mv", 0.0))
return lam <= 0.0
[docs]
def compute_mv_prior(
model,
Ss_field: TensorLike | None = None,
*,
logSs: TensorLike | None = None,
mode: str | None = None,
as_loss: bool = True,
weight=None,
warmup_steps=None,
step=None,
alpha_disp=0.1,
delta=1.0,
eps=_EPSILON,
verbose=0,
):
r"""
Compute an m_v - gamma_w prior from predicted S_s.
This routine builds a log-space residual that ties the model's
specific storage :math:`S_s` to the consolidation coefficient
:math:`m_v` and the unit weight of water :math:`gamma_w` via:
.. math::
S_s \approx m_v \, \gamma_w
The constraint is applied in log space for numerical stability:
.. math::
r = \log(S_s) - \log(m_v \, \gamma_w)
Depending on ``mode``, gradients may be blocked or allowed to
flow through :math:`S_s` (or its log) to control stability.
Mathematical objective
----------------------
If ``as_loss=True``, this function returns a scalar loss built
from two components:
1) A global mismatch term based on the mean residual:
.. math::
\bar{r} = \mathrm{mean}(r)
.. math::
L_g = \mathrm{Huber}(\bar{r}; delta)
2) A dispersion term that discourages spatial or batch-wide
scatter around the mean residual:
.. math::
L_d = \mathrm{mean}(
\mathrm{Huber}(r - \bar{r}; delta)
)
The total loss is:
.. math::
L = L_g + alpha\_disp \, L_d
Optionally, an additional weight and warmup ramp may be applied:
.. math::
L \leftarrow w(step) \, L
Mode semantics and gradient flow
--------------------------------
The choice of mode controls where gradients are allowed.
``calibrate`` (default)
Uses :func:`tf.stop_gradient` on ``Ss_field`` before taking
:math:`\log(S_s)`. This calibrates :math:`m_v` without
reshaping the :math:`S_s` field produced by the trunk.
This is typically the safest choice when the mean settlement
is physics-driven and the network already has strong physics
constraints elsewhere.
``field``
Backpropagates through ``Ss_field``. This can be unstable
when :math:`S_s` becomes small because:
.. math::
\frac{\partial \log(S_s)}{\partial S_s} = \frac{1}{S_s}
so gradients can be amplified.
``logss``
Backpropagates through ``logSs`` directly. This is often
more stable than ``field`` because the log transform is
already computed upstream in a controlled manner (for
example, using guarded exponentiation in the field
composer). Use this when you want a stronger anchoring of
the log-storage field without the 1/S_s amplification.
Inputs used
-----------
This function requires exactly one of:
* ``logSs`` when ``mode='logss'``, or
* ``Ss_field`` when ``mode!='logss'``.
The log target term :math:`\log(m_v \, \gamma_w)` is obtained
from the model configuration through helper resolvers. These
helpers may also apply internal conventions such as:
* whether :math:`m_v` is learnable or fixed,
* the unit system for :math:`gamma_w`,
* safe floors ``eps`` to avoid :math:`\log(0)`.
Parameters
----------
model : Any
Model instance providing :math:`m_v`, :math:`gamma_w`,
scaling configuration, and optional scheduling state used
by helper resolvers.
Ss_field : Tensor, optional
Specific storage field :math:`S_s` in 1/m. Required unless
``mode='logss'``. The expected shape is broadcastable to
the physics batch layout (typically (B,H,1) or (B,1,1)).
logSs : Tensor, optional
Log-specific storage :math:`\log(S_s)`. Required when
``mode='logss'``. Prefer passing the raw log output from
the field composer to preserve true bound violations.
mode : str, optional
Prior mode controlling gradient flow. If None, the mode is
resolved from model configuration. Supported modes are
``'calibrate'``, ``'field'``, and ``'logss'`` (aliases may
be accepted by the internal canonicalizer).
as_loss : bool, default=True
If True, return a scalar loss :math:`L`. If False, return
the residual field :math:`r`.
weight : float or Tensor, optional
Optional multiplicative factor applied to the returned loss.
If None, the function may still derive a weight from model
configuration.
warmup_steps : int, optional
If provided, enables an internal warmup schedule for the
prior weight via helper logic (for example a linear ramp
from 0 to 1 over ``warmup_steps``).
step : int or Tensor, optional
Training step index passed to the warmup logic. If None, the
warmup logic may use model state or disable warmup.
alpha_disp : float, default=0.1
Weight for the dispersion penalty :math:`L_d`.
delta : float, default=1.0
Huber threshold parameter used by the robust penalty.
eps : float, default=_EPSILON
Positive floor used when computing :math:`\log(S_s)` to
avoid :math:`\log(0)` and reduce numerical issues.
verbose : int, default=0
Verbosity flag forwarded to internal helpers for debugging.
Returns
-------
loss_or_residual : Tensor
If ``as_loss=True``, a scalar tensor representing the MV
prior loss. If ``as_loss=False``, the residual field
:math:`r` with the same shape as ``logSs`` (or derived
logSs from ``Ss_field``).
Raises
------
ValueError
If required inputs are missing for the selected mode
(for example, ``mode='logss'`` without ``logSs``).
Notes
-----
Why log space
~~~~~~~~~~~~~
Using :math:`\log(S_s)` makes the prior scale-invariant and
reduces sensitivity to absolute magnitudes. It also aligns with
how bounds on :math:`K` and :math:`S_s` are often expressed in
log space.
Why the mean + dispersion split
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Calibrating :math:`m_v` is primarily a global adjustment. The
mean term penalizes systematic mismatch, while the dispersion
term discourages pathological spatial variability in the prior
residual without forcing every location to match exactly.
Scheduling
~~~~~~~~~~
In many training regimes, it is beneficial to ramp this prior
after the data loss stabilizes. The optional warmup hook allows
a gradual introduction to avoid early domination.
Examples
--------
Compute a scalar MV prior loss in calibrate mode:
.. code-block:: python
loss_mv = compute_mv_prior(
model,
Ss_field=Ss_field,
mode="calibrate",
as_loss=True,
alpha_disp=0.1,
delta=1.0,
)
Use the residual field for diagnostics (no reduction):
.. code-block:: python
r_mv = compute_mv_prior(
model,
Ss_field=Ss_field,
mode="field",
as_loss=False,
)
Use logSs from the field composer (preferred for strong
anchoring without 1/S_s amplification):
.. code-block:: python
K, Ss, tau, tau_phys, Hd, dlogtau, logK, logSs, log_tau, log_tau_phys = (
compose_physics_fields(...)
)
loss_mv = compute_mv_prior(
model,
logSs=logSs,
mode="logss",
as_loss=True,
)
See Also
--------
compose_physics_fields
Produces ``logSs`` consistent with guarded exponentiation.
compute_consistency_prior
Prior linking learned tau to physically implied tau.
assemble_physics_loss
Combines MV prior with other physics terms and offsets.
"""
# ----------------------------------------------------------
# 1) Resolve mode (alias-safe via scaling_kwargs).
# ----------------------------------------------------------
if mode is None:
mode = _get_mv_prior_mode(model)
mode = _canon_mv_prior_mode(mode)
# ----------------------------
# 1b) Off / disabled gate.
# ----------------------------
if _mv_prior_is_disabled(model, mode=mode):
return _mv_prior_disabled_return(
as_loss=as_loss,
Ss_field=Ss_field,
logSs=logSs,
)
# ----------------------------------------------------------
# 2) Build log-space residual r.
# ----------------------------------------------------------
if mode == "logss":
if logSs is None:
raise ValueError(
"mode='logss' requires `logSs` from "
"compose_physics_fields().",
)
logSs_ = tf_cast(logSs, tf_float32)
log_target = resolve_mv_gamma_log_target_from_logSs(
model,
logSs_,
eps=eps,
verbose=verbose,
)
r = logSs_ - log_target
else:
if Ss_field is None:
raise ValueError(
"compute_mv_prior requires Ss_field "
"for mode != 'logss'.",
)
Ss_in = Ss_field
# Default: detach Ss to avoid trunk destabilization.
if mode == "calibrate":
Ss_in = tf_stop_gradient(Ss_field)
logSs_ = safe_log_pos(Ss_in, eps=eps)
log_target = resolve_mv_gamma_log_target(
model,
Ss_in,
eps=eps,
verbose=verbose,
)
r = logSs_ - log_target
# Return residual if requested (diagnostics use-case).
if not bool(as_loss):
return r
# ----------------------------------------------------------
# 3) Scalar loss: global mismatch + dispersion penalty.
# ----------------------------------------------------------
r_bar = tf_reduce_mean(r)
loss_g = huber(r_bar, delta=delta)
loss_d = tf_reduce_mean(huber(r - r_bar, delta=delta))
a = tf_constant(float(alpha_disp), tf_float32)
loss = loss_g + a * loss_d
# ----------------------------------------------------------
# 4) Optional independent weight + warmup ramp.
# ----------------------------------------------------------
w = _resolve_mv_prior_weight(
model,
weight=weight,
warmup_steps=warmup_steps,
step=step,
)
if w is not None:
loss = loss * w
return loss
def _get_mv_prior_units(model) -> str:
"""
Get mv prior units mode from scaling kwargs.
Expected values:
- "auto" : choose best 1e3 convention
- "strict" : use log(mv) + log(gamma_w)
"""
sk = getattr(model, "scaling_kwargs", None) or {}
# Allow either top-level or nested placement.
v = sk.get("mv_prior_units", None)
if v is None:
b = sk.get("bounds", None) or {}
v = b.get("mv_prior_units", None)
if v is None:
return "strict"
return str(v).strip().lower()
def _safe_log_mv(model, *, eps=_EPSILON) -> Tensor:
"""
Return log(mv) safely.
- If mv is learnable: use model.log_mv (log-space).
- If mv is fixed: log(mv_fixed) in a safe way.
- If missing/None: return log(eps).
"""
eps_t = tf_constant(float(eps), tf_float32)
log_eps = tf_log(eps_t)
log_mv_raw = getattr(model, "log_mv", None)
if log_mv_raw is not None:
log_mv = tf_cast(log_mv_raw, tf_float32)
return tf_where(
tf_math.is_finite(log_mv),
log_mv,
log_eps,
)
mv = getattr(model, "_mv_fixed", None)
if mv is None:
return log_eps
return safe_log_pos(mv, eps=eps)
def _safe_log_gw(model, *, eps=_EPSILON) -> Tensor:
"""
Return log(gamma_w) safely.
Uses a constant fallback if gamma_w is missing/None.
"""
gw = getattr(model, "gamma_w", None)
if gw is None:
gw = tf_constant(9810.0, tf_float32)
return safe_log_pos(gw, eps=eps)
[docs]
def resolve_mv_gamma_log_target(
model,
Ss_field,
*,
eps=_EPSILON,
verbose=0,
) -> Tensor:
"""
Return log(mv * gamma_w) with configurable units.
If mv_prior_units == "strict":
log_target = log(mv) + log(gamma_w)
If mv_prior_units == "auto":
pick among 4 candidates that best matches
mean(log(Ss_field)) in magnitude:
- mv vs mv/1000
- gamma_w vs gamma_w*1000
"""
mv_units = _get_mv_prior_units(model)
log_mv = _safe_log_mv(model, eps=eps)
log_gw = _safe_log_gw(model, eps=eps)
# Strict path: no argmin, no discrete switches.
if mv_units != "auto":
log_target = log_mv + log_gw
vprint(verbose, "mv_prior_units:", mv_units)
vprint(verbose, "log_target(strict):", log_target)
return log_target
# Auto path: use Ss only for scale matching.
logSs_mean = tf_reduce_mean(
safe_log_pos(Ss_field, eps=eps),
)
log1000 = tf_log(tf_constant(1000.0, tf_float32))
# 4 candidates (log-space).
log_mv_c = tf_stack(
[log_mv, log_mv - log1000, log_mv, log_mv - log1000],
)
log_gw_c = tf_stack(
[log_gw, log_gw, log_gw + log1000, log_gw + log1000],
)
log_targets = log_mv_c + log_gw_c
errs = tf_abs(logSs_mean - log_targets)
# Discrete choice only; do not backprop it.
idx = tf_cast(
tf_argmin(tf_stop_gradient(errs), axis=0),
tf_int32,
)
log_target = tf_gather(log_targets, idx)
vprint(verbose, "mv_prior_units:", mv_units)
vprint(verbose, "mv/gw idx:", idx)
vprint(verbose, "log_target(auto):", log_target)
return log_target
# -----------------------------
# Reusable numeric helpers
# -----------------------------
[docs]
def safe_pos(x, *, eps=_EPSILON, dtype=tf_float32):
"""
Force x to be finite and >= eps.
Replaces NaN/Inf by eps, then floors.
"""
eps_t = tf_constant(float(eps), dtype)
x = tf_cast(x, dtype)
x = tf_where(tf_math.is_finite(x), x, eps_t)
x = tf_clip_by_value(x, eps_t, tf_constant(1e30, dtype))
return tf_maximum(x, eps_t)
[docs]
def safe_log_pos(x, *, eps=_EPSILON, dtype=tf_float32):
"""log(safe_pos(x))."""
return tf_log(safe_pos(x, eps=eps, dtype=dtype))
[docs]
def huber(x, *, delta=1.0):
"""
Huber loss (elementwise).
delta is treated as a scalar constant.
"""
d = tf_constant(float(delta), x.dtype)
ax = tf_abs(x)
quad = tf_minimum(ax, d)
lin = ax - quad
return 0.5 * tf_square(quad) + d * lin
[docs]
def compute_gw_flow_residual(
model,
dh_dt: Tensor,
d_K_dh_dx_dx: Tensor,
d_K_dh_dy_dy: Tensor,
Ss_field: Tensor,
*,
Q: TensorLike | None = None,
verbose: int = 0,
) -> Tensor:
"""Groundwater flow PDE residual (NaN/Inf-safe, broadcast-safe)."""
if "gw_flow" not in model.pde_modes_active:
return tf_zeros_like(dh_dt)
# --- convert + sanitize core terms ---
dh_dt = _finite_or_zero(
tf_convert_to_tensor(dh_dt, dtype=tf_float32)
)
d_K_dh_dx_dx = _finite_or_zero(
tf_convert_to_tensor(d_K_dh_dx_dx, dtype=dh_dt.dtype)
)
d_K_dh_dy_dy = _finite_or_zero(
tf_convert_to_tensor(d_K_dh_dy_dy, dtype=dh_dt.dtype)
)
Ss_field = _finite_or_zero(
tf_convert_to_tensor(Ss_field, dtype=dh_dt.dtype)
)
# --- Q: scalar / (H,) / (B,H) / (B,H,1) -> (B,H,1) ---
if Q is None:
Qv = tf_zeros_like(dh_dt)
else:
Qv = tf_convert_to_tensor(Q, dtype=dh_dt.dtype)
Qv = ensure_3d(
Qv
) # scalar->(1,1,1), (H,)->(1,H,1), (B,H)->(B,H,1)
Qv = tf_broadcast_to(
Qv, tf_shape(dh_dt)
) # now broadcast is valid
Qv = _finite_or_zero(Qv)
div_K_grad_h = d_K_dh_dx_dx + d_K_dh_dy_dy
storage_term = Ss_field * dh_dt
out = storage_term - div_K_grad_h - Qv
out = _finite_or_zero(
out
) # optional, but makes the "output finite" contract explicit
if verbose > 6:
vprint(verbose, "gw: dh_dt=", dh_dt)
vprint(verbose, "gw: div=", div_K_grad_h)
vprint(verbose, "gw: Q=", Qv)
vprint(verbose, "gw: out=", out)
tf_print(
"to_rms(Ss_field * dh_dt)=",
to_rms(Ss_field * dh_dt),
"to_rms(div_K_grad_h)=",
to_rms(div_K_grad_h),
"to_rms(Qv)=",
to_rms(Qv),
"to_rms(out)=",
to_rms(out),
)
return out
[docs]
def compute_consolidation_residual(
model,
ds_dt: Tensor,
s_state: Tensor,
h_mean: Tensor,
H_field: Tensor,
tau_field: Tensor,
*,
Ss_field: TensorLike | None = None,
inputs: dict[str, Tensor] | None = None,
verbose: int = 0,
) -> Tensor:
"""Consolidation PDE residual (Voigt)."""
if "consolidation" not in model.pde_modes_active:
return tf_zeros_like(ds_dt)
eps = tf_constant(_EPSILON, dtype=tf_float32)
tau_safe = tf_maximum(tau_field, eps)
h_ref_si = get_h_ref_si(model, inputs, like=h_mean)
delta_h = tf_maximum(h_ref_si - h_mean, 0.0)
if Ss_field is None:
Ss_eff = model._mv_value() * model.gamma_w
src = "mv*gw"
else:
Ss_eff = Ss_field
src = "Ss_field"
s_eq = Ss_eff * delta_h * H_field
relaxation = (s_eq - s_state) / tau_safe
out = ds_dt - relaxation
vprint(verbose, "cons: h_ref=", h_ref_si)
vprint(verbose, "cons: delta_h=", delta_h)
vprint(verbose, "cons: Ss_eff(", src, ")=", Ss_eff)
vprint(verbose, "cons: s_eq=", s_eq)
vprint(verbose, "cons: s_state=", s_state)
vprint(verbose, "cons: relax=", relaxation)
vprint(verbose, "cons: out=", out)
return out
def _positive_part(
x: Tensor,
*,
mode: str = "smooth_relu",
beta: float = 20.0,
eps: float = _EPSILON,
zero_at_origin: bool = False,
) -> Tensor:
"""Return the non-negative part of x, with selectable smoothness.
Parameters
----------
x : Tensor
Input tensor.
mode : {'smooth_relu', 'relu', 'softplus', 'none'}
- 'smooth_relu': softplus(beta*x)/beta (smooth ReLU approx)
- 'relu' : max(x, 0)
- 'softplus' : softplus(x) (always > 0)
- 'none' : x (no clamping)
beta : float
Curvature control for 'smooth_relu'. Larger -> closer to ReLU.
eps : float
Small additive floor after gating (usually 0).
zero_at_origin : bool
If True and mode == 'smooth_relu', shift so that output is
(approximately) 0 at x=0:
softplus(beta*x)/beta - log(2)/beta
Note: this shifted version can become slightly negative for x<0.
If you need strict non-negativity, keep this False.
Returns
-------
Tensor
Gated tensor.
"""
mode = str(mode).strip().lower()
x = tf_cast(x, tf_float32)
x = _finite_or_zero(x)
if mode == "none":
y = x
elif mode == "relu":
y = tf_maximum(x, tf_constant(eps, dtype=x.dtype))
elif mode == "softplus":
y = positive(x, eps=eps)
elif mode == "smooth_relu":
b = tf_constant(float(beta), dtype=x.dtype)
y = tf_softplus(b * x) / b
if bool(zero_at_origin):
log2 = tf_constant(
float(np.log(2.0)), dtype=x.dtype
)
y = y - (log2 / b)
else:
raise ValueError(
"_positive_part: mode must be one of "
"{'smooth_relu','relu','softplus','none'}."
)
if eps and float(eps) > 0.0:
y = y + tf_constant(float(eps), dtype=y.dtype)
return y
[docs]
def equilibrium_compaction_si(
*,
h_mean_si: Tensor,
h_ref_si: Tensor,
Ss_field: Tensor,
H_field_si: Tensor,
drawdown_mode: str = "smooth_relu",
drawdown_rule: str = "ref_minus_mean",
relu_beta: float = 20.0,
stop_grad_ref: bool = True,
drawdown_zero_at_origin: bool = False,
drawdown_clip_max: float | None = None,
eps: float = _EPSILON,
verbose: int = 0,
) -> Tensor:
r"""
Compute equilibrium compaction ``s_eq`` in SI meters.
This function computes the equilibrium (instantaneous) settlement
that would be reached under a sustained head change, given a
specific storage field and a compressible thickness. The output
``s_eq`` is used by the consolidation residual to compare the
current settlement state against its equilibrium target.
Mathematical definition
-----------------------
Let:
* :math:`h_{mean}` be the mean hydraulic head (m),
* :math:`h_{ref}` be a reference head (m),
* :math:`S_s` be specific storage (1/m),
* :math:`H` be compressible thickness (m).
A drawdown (head loss) scalar :math:`Delta h` is formed from a
rule:
.. math::
Delta h_{raw} =
\begin{cases}
h_{ref} - h_{mean}, & \text{rule = ref\_minus\_mean} \\
h_{mean} - h_{ref}, & \text{rule = mean\_minus\_ref}
\end{cases}
A non-negative drawdown is enforced using a gating operator
:math:`[x]_+` controlled by ``drawdown_mode``:
.. math::
Delta h = [Delta h_{raw}]_+
Finally, equilibrium compaction is:
.. math::
s_{eq} = S_s \, Delta h \, H
Units are consistent: :math:`(1/m) * (m) * (m) = m`.
Drawdown gating
---------------
The gating operator is chosen by ``drawdown_mode``:
* ``"none"``:
:math:`[x]_+ = x` (no positivity enforcement).
* ``"relu"``:
:math:`[x]_+ = max(0, x)`.
* ``"softplus"``:
Smooth positive part via softplus (implementation dependent).
* ``"smooth_relu"``:
A smooth approximation to ReLU controlled by ``relu_beta``.
If ``drawdown_zero_at_origin=True`` in ``"smooth_relu"`` mode,
the smooth positive part is shifted so its value at zero is
approximately zero, improving interpretability of the residual
near :math:`Delta h_{raw}=0`.
Reference-gradient handling
---------------------------
If ``stop_grad_ref=True`` (recommended), gradients are stopped
through the reference:
.. math::
h_{ref} := stop\_gradient(h_{ref})
This prevents the model from trivially reducing drawdown by
moving the reference rather than adjusting the predicted head.
Clipping
--------
If ``drawdown_clip_max`` is provided, the gated drawdown is
clipped after gating:
.. math::
Delta h := clip(Delta h, eps, Delta h_{max})
This can prevent extremely large drawdowns from dominating the
physics loss in early training.
Parameters
----------
h_mean_si : Tensor
Mean head (or depth, depending on your pipeline) in meters.
Must be broadcastable to shape ``(B, H, 1)``. If provided as
``(B, H)``, it is expanded to ``(B, H, 1)``.
h_ref_si : Tensor
Reference head (or depth) in meters, broadcastable to
``h_mean_si``. If ``stop_grad_ref=True``, gradients are
stopped through this tensor.
Ss_field : Tensor
Specific storage field :math:`S_s` in 1/m. Must be
broadcastable to ``h_mean_si``. Non-finite values are
sanitized to zero. A non-negativity clamp is applied.
H_field_si : Tensor
Compressible thickness :math:`H` in meters. Must be
broadcastable to ``h_mean_si``. Non-finite values are
sanitized to zero. A non-negativity clamp is applied.
drawdown_mode : str, default="smooth_relu"
Positivity enforcement for drawdown. Supported values:
``"smooth_relu"``, ``"relu"``, ``"softplus"``, ``"none"``.
drawdown_rule : str, default="smooth_relu"
Rule used to form raw drawdown:
``"ref_minus_mean"`` or ``"mean_minus_ref"``.
relu_beta : float, default=20.0
Smoothness/steepness for ``"smooth_relu"`` gating. Larger
values approach hard ReLU.
stop_grad_ref : bool, default=True
If True, apply ``stop_gradient`` to ``h_ref_si`` to prevent
reference drift.
drawdown_zero_at_origin : bool, default=False
Only used by ``"smooth_relu"``. If True, shift the smooth
positive-part so it is approximately zero at input zero.
drawdown_clip_max : float or None, default=None
If provided, clip gated drawdown to
``[eps, drawdown_clip_max]``.
eps : float, default=_EPSILON
Small positive constant used by gating/clipping utilities.
verbose : int, default=0
Verbosity level for debug printing and basic stats.
Returns
-------
s_eq : Tensor
Equilibrium compaction in meters, shape ``(B, H, 1)``. Any
non-finite values are sanitized to zero as a final safeguard.
Notes
-----
Head vs depth convention
~~~~~~~~~~~~~~~~~~~~~~~~
The default drawdown convention assumes head loss:
.. math::
Delta h = h_{ref} - h_{mean}
If upstream code uses a depth-like quantity that increases
downward, the physically meaningful drawdown may require the
opposite sign:
.. math::
Delta h = h_{mean} - h_{ref}
Use ``drawdown_rule="mean_minus_ref"`` in that case.
Sanitization behavior
~~~~~~~~~~~~~~~~~~~~~
Inputs are sanitized with a "finite-or-zero" rule before use.
This favors training stability over strict error signaling. If
you want fail-fast behavior, validate upstream.
Examples
--------
Compute equilibrium compaction with default settings:
>>> s_eq = equilibrium_compaction_si(
... h_mean_si=h_mean,
... h_ref_si=h_ref,
... Ss_field=Ss,
... H_field_si=H,
... )
Flip the drawdown rule for depth-like signals:
>>> s_eq = equilibrium_compaction_si(
... h_mean_si=depth_mean,
... h_ref_si=depth_ref,
... Ss_field=Ss,
... H_field_si=H,
... drawdown_rule="mean_minus_ref",
... )
Clip extreme drawdowns during early training:
>>> s_eq = equilibrium_compaction_si(
... h_mean_si=h_mean,
... h_ref_si=h_ref,
... Ss_field=Ss,
... H_field_si=H,
... drawdown_clip_max=50.0,
... )
See Also
--------
compute_consolidation_step_residual
Uses ``s_eq`` as the equilibrium target in the ODE residual.
settlement_state_for_pde
Maps model settlement outputs to the ODE state convention.
"""
h_mean_si = _ensure_3d(tf_cast(h_mean_si, tf_float32))
h_ref_si = _broadcast_like(
_ensure_3d(tf_cast(h_ref_si, tf_float32)), h_mean_si
)
Ss_field = _broadcast_like(
_ensure_3d(Ss_field), h_mean_si
)
H_field_si = _broadcast_like(
_ensure_3d(H_field_si), h_mean_si
)
def _n_bad(x: Tensor) -> Tensor:
return tf_reduce_sum(
tf_cast(~tf_math.is_finite(x), tf_int32)
)
# --- debug counts BEFORE sanitization
vprint(
verbose,
"[equilibrium_compaction_si] nonfinite counts (pre):",
"h_mean",
_n_bad(h_mean_si),
"h_ref",
_n_bad(h_ref_si),
"Ss",
_n_bad(Ss_field),
"H",
_n_bad(H_field_si),
)
# --- sanitize ALL inputs (this is what makes your tests pass)
h_mean_si = _finite_or_zero(h_mean_si)
h_ref_si = _finite_or_zero(h_ref_si)
Ss_field = _finite_or_zero(Ss_field)
H_field_si = _finite_or_zero(H_field_si)
# Optional: enforce non-negativity for physical fields
# (keeps math stable if something goes negative)
zero = tf_constant(0.0, dtype=tf_float32)
Ss_field = tf_maximum(Ss_field, zero)
H_field_si = tf_maximum(H_field_si, zero)
# --- debug counts AFTER sanitization
vprint(
verbose,
"[equilibrium_compaction_si] nonfinite counts (post):",
"h_mean",
_n_bad(h_mean_si),
"h_ref",
_n_bad(h_ref_si),
"Ss",
_n_bad(Ss_field),
"H",
_n_bad(H_field_si),
)
if bool(stop_grad_ref):
h_ref_si = tf_stop_gradient(h_ref_si)
vprint(
verbose,
"[equilibrium_compaction_si] shapes:",
"h_mean",
h_mean_si.shape,
"h_ref",
h_ref_si.shape,
"Ss",
Ss_field.shape,
"H",
H_field_si.shape,
"| mode=",
drawdown_mode,
"| rule=",
drawdown_rule,
"| stop_grad_ref=",
stop_grad_ref,
)
rule = str(drawdown_rule).strip().lower()
if rule in {"ref_minus_mean", "ref-mean", "ref_mean"}:
delta_raw = h_ref_si - h_mean_si
elif rule in {"mean_minus_ref", "mean-ref", "mean_ref"}:
delta_raw = h_mean_si - h_ref_si
else:
raise ValueError(
"equilibrium_compaction_si: drawdown_rule must be "
"'ref_minus_mean' or 'mean_minus_ref'."
)
delta_h = _positive_part(
delta_raw,
mode=drawdown_mode,
beta=relu_beta,
eps=eps,
zero_at_origin=bool(drawdown_zero_at_origin),
)
if drawdown_clip_max is not None:
mx = tf_constant(
float(drawdown_clip_max), dtype=delta_h.dtype
)
delta_h = tf_clip_by_value(
delta_h, tf_constant(eps, dtype=delta_h.dtype), mx
)
vprint(
verbose,
"[equilibrium_compaction_si] delta_h stats:",
"min=",
tf_reduce_min(delta_h),
"max=",
tf_reduce_max(delta_h),
"mean=",
tf_reduce_mean(delta_h),
)
s_eq = Ss_field * delta_h * H_field_si
s_eq = _finite_or_zero(s_eq) # extra safety net
vprint(
verbose,
"[equilibrium_compaction_si] s_eq stats:",
"min=",
tf_reduce_min(s_eq),
"max=",
tf_reduce_max(s_eq),
"mean=",
tf_reduce_mean(s_eq),
"| nonfinite=",
_n_bad(s_eq),
)
return s_eq
[docs]
def integrate_consolidation_mean(
*,
h_mean_si: Tensor,
Ss_field: Tensor,
H_field_si: Tensor,
tau_field: Tensor,
h_ref_si: Tensor,
s_init_si: Tensor,
dt: TensorLike | None = None,
time_units: str | None = "yr",
method: str = "exact",
eps_tau: float = 1e-12,
relu_beta: float = 20.0,
drawdown_mode: str = "smooth_relu",
drawdown_rule: str = "ref_minus_mean",
stop_grad_ref: bool = True,
drawdown_zero_at_origin: bool = False,
drawdown_clip_max: float | None = None,
verbose: int = 0,
) -> Tensor:
r"""
Integrate mean consolidation settlement over a forecast horizon.
This routine evolves the mean settlement state
:math:`\bar{s}(t)` using a stable, shape-safe time stepper that is
compatible with TensorFlow graph execution. It is designed for the
GeoPriorSubsNet "Option-1" mean path, where the mean subsidence is
physics-driven from the predicted head.
The integrator advances a first-order relaxation model:
.. math::
\frac{d\bar{s}}{dt} =
\frac{s_{eq}(t) - \bar{s}(t)}{\tau(t)}
where:
* :math:`\bar{s}(t)` is the mean settlement state (m),
* :math:`s_{eq}(t)` is the equilibrium compaction (m),
* :math:`\tau(t)` is a consolidation time scale (s).
The equilibrium compaction is computed by
:func:`equilibrium_compaction_si`:
.. math::
s_{eq}(t) = S_s(t)\, \Delta h(t)\, H(t)
with :math:`S_s` (1/m), :math:`H` (m), and drawdown
:math:`\Delta h` (m) formed from ``h_mean_si`` and ``h_ref_si``
using ``drawdown_rule`` and gated by ``drawdown_mode``.
Discrete-time update
--------------------
Let :math:`t_0, ..., t_{H-1}` be the horizon times and let
:math:`\Delta t_i` be the step duration in seconds. The state
update can be done in two ways:
Exact step (stable for large :math:`\Delta t/\tau`)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For each step :math:`i`:
.. math::
a_i = \exp\left(-\frac{\Delta t_i}{\tau_i}\right)
.. math::
\bar{s}_{i} =
a_i\, \bar{s}_{i-1} + (1-a_i)\, s_{eq,i}
This is the closed-form solution of the linear ODE assuming
:math:`s_{eq}` and :math:`\tau` are constant on the step.
Euler step (simple, less stable)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For each step :math:`i`:
.. math::
\bar{s}_{i} =
\bar{s}_{i-1} +
\Delta t_i\, \frac{s_{eq,i} - \bar{s}_{i-1}}{\tau_i}
Use ``method="exact"`` unless you have a strong reason to match
a legacy discretization.
Time and units
--------------
The integrator expects:
* ``h_mean_si`` and ``h_ref_si`` in meters,
* ``Ss_field`` in 1/m,
* ``H_field_si`` in meters,
* ``tau_field`` in seconds,
* ``dt`` expressed in ``time_units`` and converted to seconds
internally via :func:`dt_to_seconds`.
If ``dt`` is None, a unit step of ``1`` is used per horizon index,
interpreted in ``time_units``.
Shape contract and horizon alignment
------------------------------------
Internally, this function forces a strict ``(B, H, 1)`` layout for
the evolving state and for all step inputs, because TensorFlow
scan operations can widen shapes when rank is ambiguous.
Inputs that vary with time may be provided as:
* length ``H`` : already aligned,
* length ``H+1``: treated as state-length, sliced to ``[:-1]``,
* length ``H-1``: treated as step-length, padded by prepending the
first entry, producing length ``H``,
* length ``1`` : broadcast across horizon.
This alignment is applied to ``dt`` and ``tau_field``. The
equilibrium sequence ``s_eq`` is computed at length ``H``.
Stability and sanitization
--------------------------
This integrator aggressively sanitizes non-finite values:
* ``dt`` and ``dt_sec`` are mapped through a finite-or-zero rule,
then clamped to non-negative.
* ``tau`` is sanitized and lower-bounded by ``eps_tau``.
* The final output is passed through a finite-or-zero rule.
These guards are intended to prevent training from crashing when
upstream predictions temporarily produce NaN/Inf.
Parameters
----------
h_mean_si : Tensor
Mean head in meters. Shape ``(B, H, 1)`` or ``(B, H)``.
The last dim is forced to 1 for scan stability.
Ss_field : Tensor
Specific storage :math:`S_s` in 1/m. Broadcastable to
``(B, H, 1)``.
H_field_si : Tensor
Compressible thickness :math:`H` in meters. Broadcastable to
``(B, H, 1)``.
tau_field : Tensor
Consolidation time scale :math:`\tau` in seconds.
Broadcastable to ``(B, H, 1)`` or horizon-aligned by the
alignment rules described above.
h_ref_si : Tensor
Reference head (meters). Broadcastable to ``h_mean_si``.
If ``stop_grad_ref=True``, gradients are stopped through this
reference inside :func:`equilibrium_compaction_si`.
s_init_si : Tensor
Initial settlement state at the horizon origin. This is the
initial value used by the scan initializer. It is expected to
represent the settlement at the first horizon time.
Typical shape is ``(B, 1, 1)`` or ``(B, 1)``.
dt : Tensor, optional
Step duration in ``time_units``. If provided, it must be
broadcastable and is horizon-aligned. If None, a unit step of
ones is used.
time_units : str, default="yr"
Units for ``dt``. Converted to seconds via :func:`dt_to_seconds`.
Examples include "yr", "day", "hour", or "unitless", depending
on your pipeline.
method : {"exact", "euler"}, default="exact"
Integration scheme. "exact" uses the closed-form step for a
first-order linear relaxation. "euler" uses forward Euler.
eps_tau : float, default=1e-12
Lower bound for ``tau`` to prevent division by zero and
undefined exponentials.
relu_beta : float, default=20.0
Smoothness parameter forwarded to
:func:`equilibrium_compaction_si` for ``drawdown_mode="smooth_relu"``.
drawdown_mode : str, default="smooth_relu"
Drawdown gating forwarded to :func:`equilibrium_compaction_si`.
Common values: "smooth_relu", "relu", "softplus", "none".
drawdown_rule : str, default="ref_minus_mean"
Drawdown rule forwarded to :func:`equilibrium_compaction_si`.
Use "ref_minus_mean" for head-loss convention. Use
"mean_minus_ref" for depth-like (down-positive) signals.
stop_grad_ref : bool, default=True
Forwarded to :func:`equilibrium_compaction_si`. If True, stops
gradient through ``h_ref_si`` to prevent reference drift.
drawdown_zero_at_origin : bool, default=False
Forwarded to :func:`equilibrium_compaction_si`. If True, shifts
smooth drawdown gating so the value at zero is near zero.
drawdown_clip_max : float or None, default=None
Forwarded to :func:`equilibrium_compaction_si`. If set, clips
drawdown after gating to avoid extreme values dominating loss.
verbose : int, default=0
Verbosity for printing basic statistics and shape info.
Returns
-------
s_bar_si : Tensor
Mean cumulative settlement over the horizon in meters, shape
``(B, H, 1)``. The sequence is the scan output of the chosen
stepper initialized at ``s_init_si``.
Notes
-----
Relationship to model outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In Option-1 mean modeling, the network predicts mean head, and
settlement mean is computed by integrating the relaxation ODE.
The model may optionally add a learned residual around this mean,
but the returned value here is the physics mean only.
Interpreting the horizon index
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This function produces a length-H sequence. If your horizon times
represent the future steps 1..H, ensure that ``dt`` and ``h_mean_si``
are consistent with that convention. If you supply sequences of
length H+1 (state nodes), the last node is dropped by alignment.
Numerical behavior
~~~~~~~~~~~~~~~~~~
The exact update has the desirable limit:
* If :math:`\Delta t_i \ll \tau_i`, then
:math:`\bar{s}_i \approx \bar{s}_{i-1}` (slow relaxation).
* If :math:`\Delta t_i \gg \tau_i`, then
:math:`\bar{s}_i \approx s_{eq,i}` (fast equilibration).
Examples
--------
Integrate with unit yearly steps and exact update:
>>> s_bar = integrate_consolidation_mean(
... h_mean_si=h_mean,
... Ss_field=Ss,
... H_field_si=H,
... tau_field=tau,
... h_ref_si=h_ref,
... s_init_si=s0,
... time_units="yr",
... method="exact",
... )
Use explicit dt (months) and Euler update:
>>> s_bar = integrate_consolidation_mean(
... h_mean_si=h_mean,
... Ss_field=Ss,
... H_field_si=H,
... tau_field=tau,
... h_ref_si=h_ref,
... s_init_si=s0,
... dt=dt_months,
... time_units="month",
... method="euler",
... )
Flip drawdown rule for depth-like signals:
>>> s_bar = integrate_consolidation_mean(
... h_mean_si=depth_mean,
... Ss_field=Ss,
... H_field_si=H,
... tau_field=tau,
... h_ref_si=depth_ref,
... s_init_si=s0,
... drawdown_rule="mean_minus_ref",
... )
See Also
--------
equilibrium_compaction_si
Computes :math:`s_{eq}(t)` from head/drawdown and fields.
tau_phys_from_fields
Computes a physically motivated baseline time scale.
settlement_state_for_pde
Converts predicted settlement representations to an ODE state.
"""
def _align_to_horizon(x: Tensor, *, name: str) -> Tensor:
"""Align x time-length to horizon H (or keep length 1)."""
xt = _ensure_3d(tf_cast(x, tf_float32))
tx = tf_shape(xt)[1]
# If provided as state-length (H+1), slice to horizon H.
xt = tf_cond(
tf_equal(tx, H + 1),
lambda: xt[:, :-1, :],
lambda: xt,
)
# If provided as step-length (H-1), pad to horizon H by
# repeating the first step (consistent with dt inference).
tx2 = tf_shape(xt)[1]
def _pad_prepend() -> Tensor:
first = xt[:, :1, :]
return tf_concat([first, xt], axis=1)
xt = tf_cond(
tf_logical_and(
tf_greater(H, 1),
tf_equal(tx2, H - 1),
),
_pad_prepend,
lambda: xt,
)
# Now must be length H or 1.
tx3 = tf_shape(xt)[1]
ok = tf_logical_or(tf_equal(tx3, H), tf_equal(tx3, 1))
tf_debugging.assert_equal(
ok,
True,
message=(
f"{name} has incompatible time length; "
"expected H, H-1, H+1, or 1."
),
)
return xt
h_mean_si = _ensure_3d(tf_cast(h_mean_si, tf_float32))
# ----------------------------------------------------------
# Force a strict (B,H,1) shape (static last dim = 1).
# This prevents tf.scan from widening shapes to (None,None).
# ----------------------------------------------------------
shp = tf_shape(h_mean_si)
B = shp[0]
H = shp[1]
h_mean_si = tf_reshape(h_mean_si, [B, H, 1])
vprint(
verbose,
"[integrate_consolidation_mean] B,H =",
B,
H,
"| time_units=",
time_units,
"| method=",
method,
)
# --- dt in seconds (BH1) -----------------------------------
if dt is None:
dt = tf_ones_like(h_mean_si)
vprint(
verbose,
"[integrate_consolidation_mean] dt=None -> 1",
)
else:
dt_in = _align_to_horizon(dt, name="dt")
dt = _broadcast_like(dt_in, h_mean_si)
dt = tf_reshape(dt, [B, H, 1])
# sanitize dt before converting
dt = _finite_or_zero(dt)
# Optional: disallow negative time steps
dt = tf_maximum(dt, tf_constant(0.0, dtype=dt.dtype))
dt_sec = dt_to_seconds(dt, time_units=time_units)
dt_sec = tf_reshape(dt_sec, [B, H, 1])
# sanitize dt_sec too (unit conversion could create non-finite)
dt_sec = _finite_or_zero(dt_sec)
dt_sec = tf_maximum(
dt_sec, tf_constant(0.0, dtype=dt_sec.dtype)
)
vprint(
verbose,
"[integrate_consolidation_mean] dt_sec stats:",
"min=",
tf_reduce_min(dt_sec),
"max=",
tf_reduce_max(dt_sec),
"mean=",
tf_reduce_mean(dt_sec),
)
# --- tau (BH1) ---------------------------------------------
tau_in = _align_to_horizon(tau_field, name="tau_field")
tau = _broadcast_like(tau_in, h_mean_si)
tau = tf_reshape(tau, [B, H, 1])
tf_debugging.assert_equal(
tf_shape(tau)[1],
H,
message=(
"integrate_consolidation_mean:"
" tau horizon must match h_mean_si horizon"
),
)
# sanitize tau BEFORE clamping
tau = _finite_or_zero(tau)
tau = tf_maximum(
tau,
tf_constant(eps_tau, dtype=tf_float32),
)
vprint(
verbose,
"[integrate_consolidation_mean] tau stats:",
"min=",
tf_reduce_min(tau),
"max=",
tf_reduce_max(tau),
"mean=",
tf_reduce_mean(tau),
)
# --- equilibrium compaction (BH1) --------------------------
s_eq = equilibrium_compaction_si(
h_mean_si=h_mean_si,
h_ref_si=h_ref_si,
Ss_field=Ss_field,
H_field_si=H_field_si,
# NEW forwarding:
drawdown_mode=drawdown_mode,
drawdown_rule=drawdown_rule,
stop_grad_ref=stop_grad_ref,
drawdown_zero_at_origin=drawdown_zero_at_origin,
drawdown_clip_max=drawdown_clip_max,
relu_beta=relu_beta,
verbose=verbose,
)
s_eq = tf_reshape(s_eq, [B, H, 1])
method = str(method).strip().lower()
if method not in {"exact", "euler"}:
raise ValueError(
"integrate_consolidation_mean: "
"method must be 'exact' or 'euler'."
)
# --- initializer (B,1) -------------------------------------
# s0 = _ensure_3d(tf_cast(s_init_si, tf_float32))
# s0 = s0[:, :1, :1]
# s0 = tf_reshape(s0, [B, 1, 1])
# s0 = _finite_or_zero(s0)
# s0_2d = tf_reshape(s0[:, 0, :], [B, 1])
s0 = _ensure_3d(tf_cast(s_init_si, tf_float32))
s0 = s0[:, :1, :1]
# broadcast to (B,1,1) using the same mechanism as dt/tau
s0 = _broadcast_like(s0, h_mean_si[:, :1, :1])
s0 = tf_reshape(s0, [B, 1, 1])
s0 = _finite_or_zero(s0)
s0_2d = tf_reshape(s0[:, 0, :], [B, 1])
vprint(
verbose,
"[integrate_consolidation_mean] s_init stats:",
"min=",
tf_reduce_min(s0_2d),
"max=",
tf_reduce_max(s0_2d),
"mean=",
tf_reduce_mean(s0_2d),
)
if tf_transpose is None or tf_scan is None:
raise RuntimeError(
"TensorFlow ops 'transpose'/'scan' missing "
"from KERAS_DEPS."
)
# time-major: (H,B,1)
dt_tm = tf_transpose(dt_sec, [1, 0, 2])
tau_tm = tf_transpose(tau, [1, 0, 2])
seq_tm = tf_transpose(s_eq, [1, 0, 2])
def step(
prev: Tensor,
elems: tuple[Tensor, Tensor, Tensor],
) -> Tensor:
dt_i, tau_i, seq_i = elems
# Force (B,1) each iteration (prevents widening).
shp_prev = tf_shape(prev)
dt_i = tf_reshape(dt_i, shp_prev)
tau_i = tf_reshape(tau_i, shp_prev)
seq_i = tf_reshape(seq_i, shp_prev)
if method == "exact":
a = tf_exp(
-dt_i
/ (tau_i + tf_constant(_EPSILON, tau_i.dtype))
)
nxt = prev * a + seq_i * (1.0 - a)
else:
nxt = prev + dt_i * (seq_i - prev) / (
tau_i + tf_constant(_EPSILON, tau_i.dtype)
)
return tf_reshape(nxt, shp_prev)
s_tm = tf_scan(
fn=step,
elems=(dt_tm, tau_tm, seq_tm),
initializer=s0_2d,
)
s_bar = tf_transpose(s_tm, [1, 0, 2])
s_bar = _finite_or_zero(s_bar)
vprint(
verbose,
"[integrate_consolidation_mean] s_bar stats:",
"min=",
tf_reduce_min(s_bar),
"max=",
tf_reduce_max(s_bar),
"mean=",
tf_reduce_mean(s_bar),
)
return s_bar
[docs]
def compute_consolidation_step_residual(
*,
s_state_si: Tensor,
h_mean_si: Tensor,
Ss_field: Tensor,
H_field_si: Tensor,
tau_field: Tensor,
h_ref_si: Tensor,
dt: TensorLike | None = None,
time_units: str | None = "yr",
method: str = "exact",
eps_tau: float = 1e-12,
relu_beta: float = 20.0,
drawdown_mode: str = "smooth_relu",
drawdown_rule: str = "ref_minus_mean",
stop_grad_ref: bool = True,
drawdown_zero_at_origin: bool = False,
drawdown_clip_max: float | None = None,
verbose: int = 0,
) -> Tensor:
r"""
Compute a one-step consolidation residual in SI space.
This function forms a per-step residual that penalizes violations of
a first-order consolidation relaxation model over a sequence of
states. It is intended for physics diagnostics and for PDE-style
training objectives where the settlement state is predicted (or
derived) and should satisfy a stable time-stepping rule.
Model and notation
------------------
Let the settlement state be :math:`s(t)` (m). The governing ODE is a
Voigt-style relaxation toward an equilibrium settlement
:math:`s_{eq}(t)`:
.. math::
\frac{ds}{dt} = \frac{s_{eq}(t) - s(t)}{\tau(t)}
where :math:`\tau(t)` is a (possibly space/time varying) time scale
in seconds. The equilibrium settlement is computed from head
drawdown:
.. math::
s_{eq}(t) = S_s(t)\, \Delta h(t)\, H(t)
with :math:`S_s` in 1/m and :math:`H` in m. The drawdown
:math:`\Delta h(t)` is constructed from ``h_mean_si`` and ``h_ref_si``
using ``drawdown_rule`` and is gated by ``drawdown_mode`` via
:func:`equilibrium_compaction_si`.
Discrete residual definition
----------------------------
Given state samples :math:`s_n = s(t_n)` and a step duration
:math:`\Delta t_n` (seconds), this routine computes a one-step
prediction :math:`\hat{s}_{n+1}` from :math:`s_n` and :math:`s_{eq,n}`
and returns the residual:
.. math::
r_n = s_{n+1} - \hat{s}_{n+1}
The residual has units of meters and is produced for
:math:`n = 0, ..., T-2`, hence the output time length is ``T-1``.
Two steppers are supported:
Exact step (closed-form, stable)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Assuming :math:`s_{eq}` and :math:`\tau` are constant over the step:
.. math::
a_n = \exp\left(-\frac{\Delta t_n}{\tau_n}\right)
.. math::
\hat{s}_{n+1} =
a_n s_n + (1-a_n) s_{eq,n}
Euler step (forward Euler)
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. math::
\hat{s}_{n+1} =
s_n + \Delta t_n \frac{s_{eq,n} - s_n}{\tau_n}
The "exact" update is unconditionally stable for stiff regimes
(:math:`\Delta t / \tau` large). Euler may be unstable unless
:math:`\Delta t \ll \tau`.
Time and units
--------------
* Inputs are SI: meters, seconds, and 1/m as documented below.
* ``dt`` is interpreted in ``time_units`` and converted to seconds
using :func:`dt_to_seconds`.
Shape contract and alignment
----------------------------
Let ``s_state_si`` and ``h_mean_si`` have shape ``(B, T, 1)`` (or
``(B, T)`` which is promoted to ``(B, T, 1)``). This function forms
step-aligned sequences of length ``H = T-1``:
* :math:`s_n = s[:, :-1]` and :math:`s_{n+1} = s[:, 1:]`
* :math:`h_n = h[:, :-1]`
Time-varying fields may be provided with time length:
* ``T`` : state-length, sliced to ``T-1`` steps,
* ``T-1`` : already step-length,
* ``1`` : broadcast across steps.
After alignment, fields are broadcast to ``(B, T-1, 1)``.
Numerical safety
----------------
This routine sanitizes key quantities:
* Non-finite values are mapped through a finite-or-zero rule.
* ``tau`` is clamped below by ``eps_tau``.
* ``dt`` converted to seconds is clamped to non-negative.
These guards prevent crashes during training when upstream model
outputs temporarily produce NaN/Inf.
Parameters
----------
s_state_si : Tensor
Settlement state in meters. Shape ``(B, T, 1)`` or ``(B, T)``.
This is the state used in the stepper (often incremental).
h_mean_si : Tensor
Mean head (or depth-like signal) in meters. Shape must match
``s_state_si`` in the time dimension ``T``. The stepper uses
``h_mean_si[:, :-1]`` to compute :math:`s_{eq,n}`.
Ss_field : Tensor
Specific storage :math:`S_s` in 1/m. Time length may be ``T``,
``T-1``, or ``1``; it is aligned to steps and broadcast.
H_field_si : Tensor
Compressible thickness :math:`H` in meters. Time length may be
``T``, ``T-1``, or ``1``; it is aligned to steps and broadcast.
tau_field : Tensor
Consolidation time scale :math:`\tau` in seconds. Time length may
be ``T``, ``T-1``, or ``1``; it is aligned to steps and clamped
below by ``eps_tau``.
h_ref_si : Tensor
Reference head (or reference depth) in meters. Time length may be
``T``, ``T-1``, or ``1``; it is aligned to steps. If
``stop_grad_ref=True``, gradients are stopped through this
reference inside :func:`equilibrium_compaction_si`.
dt : Tensor, optional
Step duration in ``time_units``. Time length may be ``T``,
``T-1``, or ``1``. If None, a unit step is used for every step.
time_units : str, default="yr"
Units for ``dt`` prior to conversion to seconds via
:func:`dt_to_seconds`.
method : {"exact", "euler"}, default="exact"
Stepping scheme used to build :math:`\hat{s}_{n+1}`.
eps_tau : float, default=1e-12
Lower bound for ``tau`` to prevent division by zero and overflow
in stiff regimes.
relu_beta : float, default=20.0
Smoothness parameter forwarded to :func:`equilibrium_compaction_si`
when ``drawdown_mode="smooth_relu"``.
drawdown_mode : str, default="smooth_relu"
Forwarded to :func:`equilibrium_compaction_si`. Controls the
positive-part gating applied to drawdown.
drawdown_rule : str, default="ref_minus_mean"
Forwarded to :func:`equilibrium_compaction_si`. Controls the sign
convention for drawdown. Use "ref_minus_mean" for head loss and
"mean_minus_ref" for depth-like (down-positive) signals.
stop_grad_ref : bool, default=True
Forwarded to :func:`equilibrium_compaction_si`. If True, prevents
gradients through the reference signal ``h_ref_si``.
drawdown_zero_at_origin : bool, default=False
Forwarded to :func:`equilibrium_compaction_si`. If True, shifts
the smooth drawdown gate so the value at zero is near zero.
drawdown_clip_max : float or None, default=None
Forwarded to :func:`equilibrium_compaction_si`. Clips drawdown
after gating to reduce extreme values.
verbose : int, default=0
Verbosity for basic shape and residual statistics.
Returns
-------
res : Tensor
One-step residual sequence in meters, shape ``(B, T-1, 1)``:
.. math::
r_n = s_{n+1} - \hat{s}_{n+1}
Notes
-----
Incremental vs cumulative state
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This residual is defined on the settlement state passed as
``s_state_si``. If your training uses an incremental ODE state
:math:`s_{inc}(t) = s_{cum}(t) - s_0`, ensure that both
``s_state_si`` and ``s_eq`` are expressed in the same state space.
A mismatch (e.g., residuals on incremental state but equilibrium in
cumulative units) will produce biased residuals.
When to prefer the exact step
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If the ratio :math:`\Delta t / \tau` is frequently larger than 1,
Euler updates can become numerically unstable. The exact step is
stable in both stiff and non-stiff regimes and is typically the best
default for physics losses.
Examples
--------
Compute residuals for a horizon sequence:
>>> r = compute_consolidation_step_residual(
... s_state_si=s_state,
... h_mean_si=h_mean,
... Ss_field=Ss,
... H_field_si=H,
... tau_field=tau,
... h_ref_si=h_ref,
... dt=dt,
... time_units="yr",
... method="exact",
... )
Flip drawdown rule for depth-like inputs:
>>> r = compute_consolidation_step_residual(
... s_state_si=s_state,
... h_mean_si=depth,
... Ss_field=Ss,
... H_field_si=H,
... tau_field=tau,
... h_ref_si=depth_ref,
... drawdown_rule="mean_minus_ref",
... )
See Also
--------
integrate_consolidation_mean
Integrates the same relaxation model forward in time.
equilibrium_compaction_si
Computes equilibrium settlement from drawdown and fields.
dt_to_seconds
Converts ``dt`` in ``time_units`` to SI seconds.
"""
# ---------------------------------------------------------
# 1) Normalize core tensors to (B,T,1) float32.
# ---------------------------------------------------------
s_state = _ensure_3d(tf_cast(s_state_si, tf_float32))
h_state = _ensure_3d(tf_cast(h_mean_si, tf_float32))
T_s = tf_shape(s_state)[1]
T_h = tf_shape(h_state)[1]
tf_debugging.assert_equal(
T_s,
T_h,
message="s_state_si and h_mean_si must share T.",
)
vprint(
verbose,
"[compute_cons_step_res] T=",
T_s,
"| method=",
method,
)
# ---------------------------------------------------------
# 2) Build step-aligned sequences (length H = T-1).
# ---------------------------------------------------------
s_n = s_state[:, :-1, :] # (B,H,1)
s_np1 = s_state[:, 1:, :] # (B,H,1)
h_n = h_state[:, :-1, :] # (B,H,1)
H = tf_shape(s_n)[1] # H = T-1
# ---------------------------------------------------------
# 3) Helper: align a time series to step length H.
# Accepts:
# - (B,T,1) -> slice to (B,H,1)
# - (B,H,1) -> keep
# - (B,1,1) -> broadcast later
# ---------------------------------------------------------
def _align_to_steps(
x: TensorLike | None, name: str
) -> TensorLike | None:
if x is None:
return None
xt = _ensure_3d(tf_cast(x, tf_float32))
tx = tf_shape(xt)[1]
# If provided at state length T, slice to steps.
xt = tf_cond(
tf_equal(tx, H + 1),
lambda: xt[:, :-1, :],
lambda: xt,
)
# After slicing, require time dim == H or 1.
tx2 = tf_shape(xt)[1]
ok = tf_logical_or(tf_equal(tx2, H), tf_equal(tx2, 1))
tf_debugging.assert_equal(
ok,
True,
message=(
f"{name} time length must be H or 1 "
"or T (then sliced)."
),
)
return xt
# ---------------------------------------------------------
# 4) dt handling: align then broadcast to (B,H,1).
# ---------------------------------------------------------
if dt is None:
dt_steps = tf_ones_like(s_n)
vprint(
verbose,
"[compute_cons_step_res] dt=None -> 1 per step",
)
else:
dt_in = _align_to_steps(dt, "dt")
dt_steps = _broadcast_like(dt_in, s_n)
# Convert dt to seconds for the stepper.
dt_sec = dt_to_seconds(dt_steps, time_units=time_units)
# Optional safety: keep finite, non-negative dt.
dt_sec = _finite_or_zero(dt_sec)
dt_sec = tf_maximum(dt_sec, tf_constant(0.0, tf_float32))
# ---------------------------------------------------------
# 5) Align other time-series fields to step length.
# ---------------------------------------------------------
h_ref_n = _align_to_steps(h_ref_si, "h_ref_si")
Ss_n = _align_to_steps(Ss_field, "Ss_field")
Hf_n = _align_to_steps(H_field_si, "H_field_si")
tau_n = _align_to_steps(tau_field, "tau_field")
# Broadcast each aligned series to (B,H,1).
h_ref_n = _broadcast_like(h_ref_n, s_n)
Ss_n = _broadcast_like(Ss_n, s_n)
Hf_n = _broadcast_like(Hf_n, s_n)
tau = _broadcast_like(tau_n, s_n)
# Clamp tau for numerical stability.
tau = _finite_or_zero(tau)
tau = tf_maximum(tau, tf_constant(eps_tau, tf_float32))
# ---------------------------------------------------------
# 6) Compute equilibrium settlement at step times.
# ---------------------------------------------------------
s_eq_n = equilibrium_compaction_si(
h_mean_si=h_n,
h_ref_si=h_ref_n,
Ss_field=Ss_n,
H_field_si=Hf_n,
drawdown_mode=drawdown_mode,
drawdown_rule=drawdown_rule,
stop_grad_ref=stop_grad_ref,
drawdown_zero_at_origin=drawdown_zero_at_origin,
drawdown_clip_max=drawdown_clip_max,
relu_beta=relu_beta,
verbose=verbose,
)
# ---------------------------------------------------------
# 7) Stable one-step prediction and residual.
# ---------------------------------------------------------
m = str(method).strip().lower()
# Calculate ratio for stability check
dt_tau_ratio = dt_sec / (
tau + tf_constant(_EPSILON, tau.dtype)
)
use_exact = tf_logical_or( # noqa
tf_equal(m, "exact"),
tf_reduce_any(dt_tau_ratio > 1.0), # Safety switch
)
def _step_exact():
a = tf_exp(
-dt_sec / (tau + tf_constant(_EPSILON, tau.dtype))
)
return s_n * a + s_eq_n * (1.0 - a)
def _step_euler():
return s_n + dt_sec * (s_eq_n - s_n) / (
tau + tf_constant(_EPSILON, tau.dtype)
)
# Use exact if requested OR if stability is at risk
if m == "exact":
pred = _step_exact()
else:
# Hybrid safety: use exact where stiff, euler where safe?
# Easier to just force exact if user didn't strictly demand pure euler behavior
# But for now, let's just use the user choice but warn/clamp.
# Better: just use exact. It's unconditionally stable.
pred = _step_euler()
# NOTE: I highly recommend changing the default in __init__ to 'exact'
# if it isn't already.
res = s_np1 - pred
res = _finite_or_zero(res)
vprint(
verbose,
"[compute_cons_step_res] res stats:",
"min=",
tf_reduce_min(res),
"max=",
tf_reduce_max(res),
"mean=",
tf_reduce_mean(res),
)
return res
[docs]
def tau_phys_from_fields(
model,
K_field: Tensor,
Ss_field: Tensor,
H_field: Tensor,
*,
eps: float = _EPSILON,
verbose: int = 0,
return_log: bool = False,
) -> tuple[Tensor, Tensor]:
r"""
Compute the physics closure consolidation timescale ``tau_phys``
and the effective drainage thickness ``Hd``.
This function implements the model's consolidation timescale
closure :math:`tau_{phys}` in a numerically stable way. The core
design is to compute :math:`log(tau_{phys})` first, and only
apply ``exp`` at the end (unless ``return_log=True``). This
prevents unstable gradients that can arise from naive algebraic
forms that contain high powers of :math:`1/K`.
Mathematical definition
-----------------------
Let the model provide effective fields in SI units:
* :math:`K` in m/s (hydraulic conductivity),
* :math:`S_s` in 1/m (specific storage),
* :math:`H` in m (thickness),
* :math:`kappa` is a positive scalar multiplier (dimensionless
in the code path; its physical meaning depends on the chosen
mode).
An effective drainage thickness :math:`H_d` is defined as:
.. math::
H_d =
\begin{cases}
H \cdot f_{Hd}, & \text{if use\_effective\_thickness is True} \\
H, & \text{otherwise}
\end{cases}
where :math:`f_{Hd}` is ``model.Hd_factor``. The function returns
this :math:`H_d` as ``Hd``.
Two closure modes are supported via ``model.kappa_mode``:
1) ``kappa_mode="bar"``
The closure is:
.. math::
\tau_{phys}
= \frac{\kappa \, H^2 \, S_s}{\pi^2 \, K}
The log form is:
.. math::
\log(\tau_{phys})
= \log(\kappa)
+ 2\log(H)
+ \log(S_s)
- \log(\pi^2)
- \log(K)
2) Any other value (the "non-bar" branch)
The closure is:
.. math::
\tau_{phys}
= \frac{H_d^2 \, S_s}{\pi^2 \, \kappa \, K}
The log form is:
.. math::
\log(\tau_{phys})
= 2\log(H_d)
+ \log(S_s)
- \log(\pi^2)
- \log(\kappa)
- \log(K)
The implementation uses the log forms above. If ``return_log`` is
False, it returns:
.. math::
\tau_{phys} = \exp(\log(\tau_{phys}))
Numerical stability
-------------------
All inputs are sanitized and floored to be strictly positive:
.. math::
K_{safe} = \max(K, eps) \\
S_{s,safe}= \max(S_s,eps) \\
H_{safe} = \max(H, eps) \\
H_{d,safe}= \max(H_d,eps) \\
\kappa_{safe} = \max(\kappa, eps)
This ensures that ``log`` is never applied to non-positive or
non-finite values. Computing in log-space also avoids gradient
blow-ups associated with expressions that behave like
:math:`1/K^2` when written in certain equivalent algebraic forms.
Parameters
----------
model : Any
Model instance providing:
``use_effective_thickness``, ``Hd_factor``, ``kappa_mode``,
and a callable ``_kappa_value()`` returning a positive scalar.
K_field : Tensor
Effective conductivity :math:`K` in SI (m/s). Shape is
broadcastable to ``(B, H, 1)``.
Ss_field : Tensor
Effective specific storage :math:`S_s` in SI (1/m). Shape is
broadcastable to ``(B, H, 1)``.
H_field : Tensor
Thickness :math:`H` in SI (m). Shape is broadcastable to
``(B, H, 1)``.
eps : float, default=_EPSILON
Positive floor used by ``finite_floor`` to prevent invalid
logs and divisions.
verbose : int, default=0
Verbosity level for debug printing.
return_log : bool, default=False
If True, return ``(log_tau_phys, Hd)`` where ``log_tau_phys``
is :math:`log(tau_{phys})`. If False, return
``(tau_phys, Hd)`` where ``tau_phys`` is in seconds.
Returns
-------
tau_or_log_tau : Tensor
If ``return_log=False``:
Physics closure timescale :math:`tau_{phys}` in seconds.
If ``return_log=True``:
:math:`log(tau_{phys})` in log-seconds.
Hd : Tensor
Effective drainage thickness :math:`H_d` in meters. This is
either ``H_field`` or ``H_field * Hd_factor`` depending on
``model.use_effective_thickness``. Shape is broadcastable to
``(B, H, 1)``.
Notes
-----
Choice of ``H`` vs ``Hd`` in "bar" mode
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In the ``kappa_mode="bar"`` branch, the code uses ``H_safe`` in
the :math:`H^2` term. The function still computes and returns
``Hd`` for downstream diagnostics and for non-"bar" mode. If your
physical interpretation requires :math:`H_d` as the diffusion path
length in "bar" mode, you may adapt the closure accordingly.
Gradient behavior
~~~~~~~~~~~~~~~~~
Computing :math:`log(tau_{phys})` first improves stability because
the derivative of ``log`` scales like :math:`1/x`, while direct
algebraic expansions of :math:`tau_{phys}` can introduce stronger
inverse powers in intermediate steps. The only exponential is
applied at the end (if requested), keeping the computational graph
well behaved.
Examples
--------
Compute ``tau_phys`` and use it in a prior:
>>> tau_phys, Hd = tau_phys_from_fields(model, K, Ss, H)
>>> R_prior = tf.math.log(tf.maximum(tau_learned, 1e-12)) \
... - tf.math.log(tf.maximum(tau_phys, 1e-12))
>>> loss_prior = tf.reduce_mean(tf.square(R_prior))
Get log-space output for direct log-priors:
>>> log_tau_phys, Hd = tau_phys_from_fields(
... model, K, Ss, H, return_log=True
... )
See Also
--------
compute_consistency_prior
Builds the residual ``log(tau_learned) - log(tau_phys)``.
compose_physics_fields
Composes bounded SI fields and returns ``log_tau_phys`` for
diagnostics and bounds penalties.
get_log_tau_bounds
Provides configured bounds for ``log_tau`` (log-seconds).
"""
eps = float(eps)
pi_sq = tf_constant(np.pi**2, dtype=tf_float32)
# Sanitize inputs
K_safe = finite_floor(K_field, eps=eps)
Ss_safe = finite_floor(Ss_field, eps=eps)
H_safe = finite_floor(H_field, eps=eps)
# --- Effective Thickness Logic ---
use_hd = bool(
getattr(model, "use_effective_thickness", False)
)
if use_hd:
f = getattr(model, "Hd_factor", 1.0)
f = tf_cast(f, H_safe.dtype)
# finite check for factor
f = tf_where(
tf_math.is_finite(f),
f,
tf_constant(1.0, H_safe.dtype),
)
Hd = H_safe * f
else:
Hd = H_safe
Hd = finite_floor(Hd, eps=eps)
# --- Kappa Logic ---
kappa = model._kappa_value()
kappa = tf_cast(kappa, H_safe.dtype)
kappa = tf_where(
tf_math.is_finite(kappa),
kappa,
tf_constant(1.0, H_safe.dtype),
)
kappa = finite_floor(kappa, eps=eps)
# --- Log-Space Computation (Stable) ---
# log(tau) = log(C) + log(Ss) + 2*log(Hd) - log(K)
log_Ss = tf_math.log(Ss_safe)
log_K = tf_math.log(K_safe)
log_Hd = tf_math.log(Hd)
log_pi = tf_math.log(pi_sq)
log_kap = tf_math.log(kappa)
# Formula depends on kappa_mode
mode = str(getattr(model, "kappa_mode", "bar"))
if mode == "bar":
# tau = kappa * H^2 * Ss / (pi^2 * K)
# log_tau = log(k) + 2log(H) + log(Ss) - log(pi^2) - log(K)
# Note: using H_safe here typically, or Hd?
# Code usually assumes Hd for diffusion path length.
# Using H_safe to match original code structure:
log_H = tf_math.log(H_safe)
log_tau = (
log_kap + 2.0 * log_H + log_Ss - log_pi - log_K
)
else:
# tau = (Hd/H)^2 * H^2 * Ss / (pi^2 * kappa * K)
# = Hd^2 * Ss / (pi^2 * kappa * K)
log_tau = (
2.0 * log_Hd + log_Ss - log_pi - log_kap - log_K
)
if return_log:
return log_tau, Hd
# Only exp() at the very end.
# If log_tau is huge (e.g. 50), exp() overflows, but gradient of exp is manageable.
tau_phys = tf_exp(log_tau)
vprint(verbose, "tau_phys: log_tau=", log_tau)
vprint(verbose, "tau_phys: out=", tau_phys)
return tau_phys, Hd
[docs]
def compute_consistency_prior(
model,
K_field: Tensor,
Ss_field: Tensor,
tau_field: Tensor,
H_field: Tensor,
*,
verbose: int = 0,
) -> Tensor:
r"""
Compute the consolidation timescale consistency prior.
This prior constrains the learned consolidation timescale ``tau``
to remain physically consistent with the permeability-storage-
thickness closure implied by the poroelastic consolidation model.
It returns the *log-space mismatch*:
.. math::
R_{\mathrm{prior}}
= \log(\tau_{\mathrm{learned}})
- \log(\tau_{\mathrm{phys}})
where :math:`\tau_{\mathrm{phys}}` is computed from the predicted
fields :math:`K`, :math:`S_s`, and :math:`H` through
:func:`tau_phys_from_fields`.
Log-space is used for two reasons:
1. Positivity: :math:`\tau > 0` is enforced implicitly.
2. Scale: timescales may span orders of magnitude; comparing
logs yields a relative-type error signal.
Mathematical formulation
------------------------
Let the model predict effective fields (all SI):
* :math:`K(x,y)` in m/s (hydraulic conductivity),
* :math:`S_s(x,y)` in 1/m (specific storage),
* :math:`H(x,y)` in m (compressible thickness),
* :math:`\tau_{\mathrm{learned}}(x,y)` in s (learned timescale).
A common 1D Terzaghi-style drainage closure gives a characteristic
timescale:
.. math::
\tau_{\mathrm{phys}}
= \frac{H_d^2 \, S_s}{\pi^2 \, c_v}
with consolidation coefficient:
.. math::
c_v = \frac{K}{S_s}
or, equivalently, for a diffusion-like closure:
.. math::
\tau_{\mathrm{phys}}
= \frac{H_d^2 \, S_s}{\pi^2 \, K}
where :math:`H_d` is the effective drainage thickness (often a
fraction of :math:`H`, e.g. :math:`H_d = \mathrm{hd\_factor}\,H`).
The exact form used is delegated to :func:`tau_phys_from_fields`,
which may incorporate additional model parameters such as
:math:`\kappa` (compressibility/bulk coupling) or boundary
conditions.
The prior residual returned by this function is:
.. math::
R_{\mathrm{prior}}(x,y)
= \log(\max(\tau_{\mathrm{learned}}(x,y), \varepsilon))
- \log(\tau_{\mathrm{phys}}(x,y))
where :math:`\varepsilon` is a small constant used to avoid
``log(0)`` when the learned timescale becomes numerically small.
This residual is typically used inside an L2 penalty:
.. math::
L_{\mathrm{prior}}
= \mathbb{E}\left[R_{\mathrm{prior}}^2\right]
and contributes to the physics loss with a user-controlled weight
(e.g., ``lambda_prior``).
Parameters
----------
model : Any
Model instance providing configuration used by
:func:`tau_phys_from_fields` (for example, effective thickness
settings, kappa mode, bounds, and scaling config).
K_field : Tensor
Effective hydraulic conductivity :math:`K` in SI units (m/s).
Expected shape is broadcastable to ``(B, H, 1)``.
Ss_field : Tensor
Effective specific storage :math:`S_s` in SI units (1/m).
Expected shape is broadcastable to ``(B, H, 1)``.
tau_field : Tensor
Learned timescale :math:`\tau_{\mathrm{learned}}` in SI seconds.
Expected shape is broadcastable to ``(B, H, 1)``.
H_field : Tensor
Thickness :math:`H` in SI meters. Expected shape is
broadcastable to ``(B, H, 1)``.
verbose : int, default=0
Verbosity level for debug printing.
Returns
-------
residual : Tensor
Log-space prior residual:
.. math::
R_{\mathrm{prior}}
= \log(\tau_{\mathrm{learned}})
- \log(\tau_{\mathrm{phys}})
Shape follows the broadcasted shape of the inputs, typically
``(B, H, 1)``.
Notes
-----
Numerical stability
~~~~~~~~~~~~~~~~~~~
* ``tau_field`` is floored by a small :math:`\varepsilon` before
taking the logarithm.
* :func:`tau_phys_from_fields` is called with ``return_log=True``
to compute :math:`\log(\tau_{\mathrm{phys}})` directly, avoiding
the unstable pattern ``log(exp(log_tau))``.
Interpretation
~~~~~~~~~~~~~~
* ``residual = 0`` means the learned timescale matches the closure.
* Positive values indicate :math:`\tau_{\mathrm{learned}}` is larger
(slower consolidation) than predicted by the closure.
* Negative values indicate a smaller (faster) learned timescale.
Examples
--------
Compute and reduce to an L2 prior loss:
>>> R_prior = compute_consistency_prior(
... model, K_field=K, Ss_field=Ss,
... tau_field=tau, H_field=H
... )
>>> loss_prior = tf.reduce_mean(tf.square(R_prior))
See Also
--------
tau_phys_from_fields
Computes :math:`\tau_{\mathrm{phys}}` from :math:`K`, :math:`S_s`,
and :math:`H` (and model configuration).
compose_physics_fields
Builds bounded/guarded SI fields and returns both learned
and closure timescales in log-space.
"""
eps = tf_constant(_EPSILON, dtype=tf_float32)
# 1. Get learned tau in log space
tau_safe = tf_maximum(tau_field, eps)
log_tau_learned = tf_math.log(tau_safe)
# 2. Get physical tau in log space directly (Stable!)
log_tau_phys, _ = tau_phys_from_fields(
model,
K_field,
Ss_field,
H_field,
verbose=0,
return_log=True, # Use the new flag
)
out = log_tau_learned - log_tau_phys
vprint(
verbose,
"cons_prior: log_tau_learned=",
log_tau_learned,
)
vprint(verbose, "cons_prior: log_tau_phys=", log_tau_phys)
vprint(verbose, "cons_prior: out=", out)
return out
[docs]
def compute_smoothness_prior(
dK_dx: Tensor,
dK_dy: Tensor,
dSs_dx: Tensor,
dSs_dy: Tensor,
*,
K_field: TensorLike | None = None,
Ss_field: TensorLike | None = None,
already_log: bool = False,
verbose: int = 0,
) -> Tensor:
r"""
Compute a smoothness prior on spatial gradients of physics fields.
This function builds a spatial regularizer that penalizes rapid
variation of the permeability-like field ``K`` and the storage
field ``Ss`` over the spatial coordinates. In the GeoPrior PINN,
this prior stabilizes the inverse problem by discouraging
unrealistic high-frequency spatial structure in learned fields.
The preferred penalty is applied in *log-space*:
.. math::
R_{\mathrm{smooth}}
= \left\|\nabla \log K\right\|^2
+ \left\|\nabla \log S_s\right\|^2
where, in 2D:
.. math::
\left\|\nabla \log K\right\|^2
= \left(\frac{\partial \log K}{\partial x}\right)^2
+ \left(\frac{\partial \log K}{\partial y}\right)^2
and similarly for :math:`S_s`. Penalizing gradients of logs is
often preferable to raw gradients because it regularizes *relative*
changes (order-of-magnitude variations) rather than absolute
changes.
Implementation modes
--------------------
The function supports three modes, chosen by inputs:
1. ``already_log=True``:
Inputs ``dK_dx`` etc. are interpreted as
:math:`\partial_x \log K` and so on, and the penalty is:
.. math::
R_{\mathrm{smooth}}
= (d\log K/dx)^2 + (d\log K/dy)^2
+ (d\log S_s/dx)^2 + (d\log S_s/dy)^2
2. ``already_log=False`` with ``K_field`` and ``Ss_field``:
The function converts raw gradients to log-gradients using:
.. math::
\frac{\partial \log K}{\partial x}
= \frac{1}{K}\frac{\partial K}{\partial x}
and similarly for the other terms. For numerical stability,
denominators are floored by a small constant ``eps_div``:
.. math::
K_{\mathrm{denom}} = \max(K, \varepsilon_{\mathrm{div}})
This avoids exploding ratios when ``K`` or ``Ss`` are very small.
3. Fallback (rare):
If ``K_field`` and ``Ss_field`` are not provided, it penalizes
the raw gradients directly:
.. math::
R_{\mathrm{smooth}}
= (dK/dx)^2 + (dK/dy)^2
+ (dS_s/dx)^2 + (dS_s/dy)^2
Parameters
----------
dK_dx, dK_dy : Tensor
Spatial gradients of ``K`` with respect to x and y. If
``already_log=True``, these are gradients of ``logK`` instead.
dSs_dx, dSs_dy : Tensor
Spatial gradients of ``Ss`` with respect to x and y. If
``already_log=True``, these are gradients of ``logSs`` instead.
K_field : Tensor or None, default=None
Field ``K`` in SI units (m/s). Used only when converting
raw gradients to log-gradients.
Ss_field : Tensor or None, default=None
Field ``Ss`` in SI units (1/m). Used only when converting
raw gradients to log-gradients.
already_log : bool, default=False
If True, treat the input gradients as log-gradients.
verbose : int, default=0
Verbosity level for debug printing.
Returns
-------
residual : Tensor
Smoothness residual map. Typical shape is ``(B, H, 1)`` or a
broadcast-compatible shape. This quantity is usually squared
and reduced to form ``loss_smooth``.
Notes
-----
Why log-space regularization?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In hydrogeologic inverse problems, fields like ``K`` can vary by
orders of magnitude across space. Penalizing gradients of ``logK``
naturally encourages spatial smoothness in a multiplicative sense,
which aligns with common geostatistical assumptions.
Stability of division
~~~~~~~~~~~~~~~~~~~~~
When converting ``dK`` to ``dlogK = dK / K``, small ``K`` can
create extremely large ratios. The denominator floor
``eps_div`` is applied only for the conversion step, so that
regions with effectively zero permeability do not dominate the
regularizer.
Examples
--------
Penalty in log-space using provided fields:
>>> R_smooth = compute_smoothness_prior(
... dK_dx, dK_dy, dSs_dx, dSs_dy,
... K_field=K, Ss_field=Ss, already_log=False
... )
>>> loss_smooth = tf.reduce_mean(tf.square(R_smooth))
Direct log-gradient mode (inputs already log-gradients):
>>> R_smooth = compute_smoothness_prior(
... dlogK_dx, dlogK_dy, dlogSs_dx, dlogSs_dy,
... already_log=True
... )
See Also
--------
compose_physics_fields
Produces bounded SI fields ``K_field`` and ``Ss_field`` and
associated log values for diagnostics and priors.
ensure_si_derivative_frame
Converts autodiff derivatives to SI-consistent spatial
derivatives suitable for smoothness penalties.
"""
# Safe epsilon for division
eps_div = tf_constant(1e-6, dtype=tf_float32)
if already_log:
# Inputs are already d(logK)/dx, etc.
out = (
tf_square(dK_dx)
+ tf_square(dK_dy)
+ tf_square(dSs_dx)
+ tf_square(dSs_dy)
)
vprint(verbose, "smooth(log-direct): out=", out)
return out
if (K_field is not None) and (Ss_field is not None):
# We want d(logK) = dK / K.
# STABILITY FIX: Use a larger epsilon or clip K for the denominator only.
# A tiny K implies K is essentially zero/impermeable.
# We don't want to penalize dK variations when K is 1e-15 vs 1e-16.
K_denom = tf_maximum(K_field, eps_div)
Ss_denom = tf_maximum(Ss_field, eps_div)
dlogK_dx = dK_dx / K_denom
dlogK_dy = dK_dy / K_denom
dlogSs_dx = dSs_dx / Ss_denom
dlogSs_dy = dSs_dy / Ss_denom
out = (
tf_square(dlogK_dx)
+ tf_square(dlogK_dy)
+ tf_square(dlogSs_dx)
+ tf_square(dlogSs_dy)
)
vprint(verbose, "smooth(log-div): out=", out)
return out
# Fallback to raw gradients (rarely used but safe)
out = (
tf_square(dK_dx)
+ tf_square(dK_dy)
+ tf_square(dSs_dx)
+ tf_square(dSs_dy)
)
vprint(verbose, "smooth(raw): out=", out)
return out
# ---------------------------------------------------------------------
# Bounds + field composition
# ---------------------------------------------------------------------
def _soft_barrier_l2(
x: Tensor,
lo: Tensor,
hi: Tensor,
*,
beta: float = 20.0,
) -> Tensor:
x = tf_cast(x, tf_float32)
lo = tf_cast(lo, tf_float32)
hi = tf_cast(hi, tf_float32)
b = tf_constant(float(beta), tf_float32)
v_lo = tf_softplus(b * (lo - x)) / b
v_hi = tf_softplus(b * (x - hi)) / b
return tf_square(v_lo) + tf_square(v_hi)
def _reduce_barrier_mean(v: Tensor) -> Tensor:
v = tf_where(tf_math.is_finite(v), v, tf_zeros_like(v))
return tf_reduce_mean(v)
def _soft_clip(
x: Tensor,
lo: Tensor,
hi: Tensor,
*,
beta: float = 20.0,
) -> Tensor:
x = tf_cast(x, tf_float32)
lo = tf_cast(lo, tf_float32)
hi = tf_cast(hi, tf_float32)
b = tf_constant(float(beta), tf_float32)
y1 = tf_softplus(b * (x - lo)) / b
y2 = tf_softplus(b * (x - hi)) / b
return lo + y1 - y2
# def guarded_exp_from_bounds(
# raw_log,
# log_min,
# log_max,
# *,
# eps=0.0,
# guard=5.0,
# dtype=None,
# name="",
# ):
# """
# Safe exp() with a wide log-space guard-band around [log_min, log_max].
# - raw_log: unconstrained log-parameter (may drift during training)
# - log_min/log_max: physical bounds in log-space
# - guard: extra margin to avoid overflow; values outside are clipped only
# for *numerical safety*, not as a hard physical constraint.
# """
# if dtype is None:
# dtype = raw_log.dtype
# raw_log = tf_cast(raw_log, dtype)
# log_min = tf_cast(log_min, dtype)
# log_max = tf_cast(log_max, dtype)
# guard = tf_cast(tf_constant(guard), dtype)
# eps = tf_cast(tf_constant(eps), dtype)
# # replace NaN/Inf with 0 to avoid propagating non-finites
# raw_log = tf_where(
# tf_math.is_finite(raw_log), raw_log,
# tf_zeros_like(raw_log)
# )
# # guard-band clip (prevents exp overflow)
# log_safe = tf_clip_by_value(
# raw_log, log_min - guard, log_max + guard
# )
# field = tf_exp(log_safe) + eps
# if name:
# tf_debugging.assert_all_finite(raw_log, f"{name} raw_log non-finite")
# tf_debugging.assert_all_finite(field, f"{name} field non-finite")
# return field, raw_log, log_safe
[docs]
def exp_from_bounds(
raw_log,
log_min,
log_max,
*,
mode="soft",
beta=20.0,
guard=5.0,
eps=0.0,
dtype=None,
name="",
):
if dtype is None:
dtype = raw_log.dtype
mode = str(mode).strip().lower()
beta_f = float(beta)
guard_f = float(guard)
eps_f = float(eps)
raw_log = tf_cast(raw_log, dtype)
# log_min = tf_cast(log_min, dtype)
# log_max = tf_cast(log_max, dtype)
if (log_min is None) or (log_max is None):
if mode == "hard":
raise ValueError(
"bounds_mode='hard' requires finite log bounds."
)
# Soft / none mode without configured bounds:
# keep raw log values for diagnostics, but guard the
# exponentiation so float32 never overflows.
safe_log_abs_max = tf_constant(80.0, dtype)
log_safe = tf_clip_by_value(
raw_log,
-safe_log_abs_max,
safe_log_abs_max,
)
pen = tf_zeros_like(raw_log)
field = tf_exp(log_safe) + tf_constant(eps_f, dtype)
if name:
tf_debugging.assert_all_finite(
raw_log, f"{name} raw_log non-finite"
)
tf_debugging.assert_all_finite(
field, f"{name} field non-finite"
)
return field, raw_log, log_safe, pen
log_min = tf_cast(log_min, dtype)
log_max = tf_cast(log_max, dtype)
raw_log = tf_where(
tf_math.is_finite(raw_log),
raw_log,
tf_zeros_like(raw_log),
)
if mode == "hard":
log_safe = tf_clip_by_value(raw_log, log_min, log_max)
pen = tf_zeros_like(raw_log)
elif mode == "sigmoid":
t = tf_sigmoid(raw_log)
log_safe = log_min + (log_max - log_min) * t
pen = tf_zeros_like(raw_log)
elif mode == "soft":
# numeric safety guard-band, but smooth
lo_g = log_min - tf_constant(guard_f, dtype)
hi_g = log_max + tf_constant(guard_f, dtype)
log_safe = _soft_clip(
raw_log, lo_g, hi_g, beta=beta_f
)
# physical bounds penalty (differentiable)
pen = _soft_barrier_l2(
raw_log, log_min, log_max, beta=beta_f
)
else: # "none"
lo_g = log_min - tf_constant(guard_f, dtype)
hi_g = log_max + tf_constant(guard_f, dtype)
log_safe = tf_clip_by_value(raw_log, lo_g, hi_g)
pen = tf_zeros_like(raw_log)
field = tf_exp(log_safe) + tf_constant(eps_f, dtype)
if name:
tf_debugging.assert_all_finite(
raw_log, f"{name} raw_log non-finite"
)
tf_debugging.assert_all_finite(
field, f"{name} field non-finite"
)
return field, raw_log, log_safe, pen
[docs]
def get_log_bounds(
model,
*,
as_tensor: bool = True,
dtype=tf_float32,
verbose: int = 0,
) -> tuple[Any, Any, Any, Any]:
r"""
Get validated log-space bounds for K and Ss.
This helper reads bounds from ``model.scaling_kwargs['bounds']``
and returns a 4-tuple:
``(logK_min, logK_max, logSs_min, logSs_max)``.
It supports two equivalent representations:
* Direct log-bounds:
``logK_min/logK_max`` and ``logSs_min/logSs_max``.
* Linear bounds converted to logs:
``K_min/K_max`` and ``Ss_min/Ss_max``.
If bounds are missing, the function returns
``(None, None, None, None)``.
Parameters
----------
model : Any
Model-like object with an optional ``scaling_kwargs`` dict.
Bounds are read from ``model.scaling_kwargs['bounds']``.
as_tensor : bool, default=True
If True, return Tensor scalars created with ``tf_constant``.
If False, return Python floats.
dtype : tf.DType, default=tf_float32
Tensor dtype used when ``as_tensor=True``.
verbose : int, default=0
Verbosity level for optional debug printing.
Returns
-------
logK_min, logK_max, logSs_min, logSs_max : tuple
Log-space bounds as Tensor scalars (if ``as_tensor=True``),
otherwise Python floats.
If bounds are not configured, returns:
``(None, None, None, None)``.
Raises
------
ValueError
If bounds exist but are invalid, including:
* non-finite values (NaN or inf)
* non-positive linear bounds (<= 0)
* unordered bounds (max <= min)
Notes
-----
Validation contract
~~~~~~~~~~~~~~~~~~~
This function never emits NaN log bounds. If the configuration
contains invalid entries, it fails fast with ``ValueError``.
Conversion precedence
~~~~~~~~~~~~~~~~~~~~~
If log-bounds are present, they are used directly. Otherwise,
the function looks for linear bounds and converts them via:
.. math::
\log K_{\min} = \log(K_{\min}), \quad
\log K_{\max} = \log(K_{\max}),
and similarly for :math:`S_s`.
Optional Ss heuristic
~~~~~~~~~~~~~~~~~~~~
If ``Ss_min/Ss_max`` appear to be compressibility-like values
(e.g., :math:`m_v`), the function may optionally convert them
to :math:`S_s` using :math:`S_s = m_v \gamma_w` when a finite
``model.gamma_w`` is available. This heuristic is best-effort
and never raises by itself.
Examples
--------
Use Tensor bounds for downstream math:
>>> logK_min, logK_max, logSs_min, logSs_max = get_log_bounds(
... model, as_tensor=True
... )
Return Python floats for inspection:
>>> bounds = get_log_bounds(model, as_tensor=False)
>>> print(bounds)
See Also
--------
get_log_tau_bounds
Companion helper for tau bounds in log space.
compute_bounds_residual
Uses these bounds to compute normalized violations.
"""
sk = getattr(model, "scaling_kwargs", None) or {}
b = (sk.get("bounds", None) or {}) or {}
def _as_float(v: Any) -> float:
"""Best-effort cast to float for config values."""
if hasattr(v, "numpy"):
v = v.numpy()
return float(v)
def _is_finite(v: float) -> bool:
return bool(np.isfinite(v))
def _validate_lin_pair(
vmin: float,
vmax: float,
*,
name_min: str,
name_max: str,
) -> tuple[float, float]:
"""Validate linear bounds are finite and positive."""
if (not _is_finite(vmin)) or (not _is_finite(vmax)):
raise ValueError(
f"{name_min}/{name_max} must be finite. "
f"Got vmin={vmin}, vmax={vmax}."
)
if (vmin <= 0.0) or (vmax <= 0.0):
raise ValueError(
f"{name_min}/{name_max} must be > 0. "
f"Got vmin={vmin}, vmax={vmax}."
)
if vmax <= vmin:
raise ValueError(
f"{name_max} must be > {name_min}. "
f"Got vmin={vmin}, vmax={vmax}."
)
return vmin, vmax
def _validate_log_pair(
lmin: float,
lmax: float,
*,
name_min: str,
name_max: str,
) -> tuple[float, float]:
"""Validate log-bounds are finite and ordered."""
if (not _is_finite(lmin)) or (not _is_finite(lmax)):
raise ValueError(
f"{name_min}/{name_max} must be finite. "
f"Got lmin={lmin}, lmax={lmax}."
)
if lmax <= lmin:
raise ValueError(
f"{name_max} must be > {name_min}. "
f"Got lmin={lmin}, lmax={lmax}."
)
return lmin, lmax
def _maybe_convert_ss_from_mv(
vmin: float,
vmax: float,
) -> tuple[float, float]:
"""
Heuristic: if Ss bounds look like m_v, convert using gamma_w.
This only runs for (Ss_min, Ss_max). If gamma_w is missing
or non-finite, we skip conversion (and still validate).
"""
try:
gw = getattr(model, "gamma_w", None)
if gw is None:
return vmin, vmax
gw_f = _as_float(gw)
if (not _is_finite(gw_f)) or (gw_f <= 0.0):
return vmin, vmax
# mv_config.initial_value is only used as a sanity hint.
mv0 = getattr(
getattr(model, "mv_config", None),
"initial_value",
None,
)
mv0 = float(mv0) if mv0 is not None else None
ss_exp = (mv0 * gw_f) if mv0 else None
# "looks like mv" = very small upper bound,
# and gamma_w looks like N/m^3.
looks_mv = (vmax <= 1e-5) and (gw_f > 1e3)
if looks_mv and (ss_exp is None or ss_exp > 1e-5):
logger.warning(
"Ss_min/max look like m_v; convert via "
"Ss = m_v * gamma_w."
)
return vmin * gw_f, vmax * gw_f
except:
# Never crash: conversion is optional.
return vmin, vmax
return vmin, vmax
def _get_pair(
log_min_key: str,
log_max_key: str,
lin_min_key: str,
lin_max_key: str,
) -> tuple[float | None, float | None]:
"""
Read either log-bounds or linear bounds and return log-bounds.
Returns (None, None) if neither form exists.
Raises ValueError on invalid values.
"""
# 1) Prefer explicit log-bounds if provided.
log_min = b.get(log_min_key, None)
log_max = b.get(log_max_key, None)
if (log_min is not None) and (log_max is not None):
lmin = _as_float(log_min)
lmax = _as_float(log_max)
lmin, lmax = _validate_log_pair(
lmin,
lmax,
name_min=log_min_key,
name_max=log_max_key,
)
return lmin, lmax
# 2) Otherwise, build from linear bounds if provided.
if (lin_min_key not in b) or (lin_max_key not in b):
return None, None
vmin = _as_float(b[lin_min_key])
vmax = _as_float(b[lin_max_key])
# Optional: detect Ss_min/max passed as m_v.
if (lin_min_key == "Ss_min") and (
lin_max_key == "Ss_max"
):
vmin, vmax = _maybe_convert_ss_from_mv(vmin, vmax)
vmin, vmax = _validate_lin_pair(
vmin,
vmax,
name_min=lin_min_key,
name_max=lin_max_key,
)
return float(np.log(vmin)), float(np.log(vmax))
logK_min, logK_max = _get_pair(
"logK_min",
"logK_max",
"K_min",
"K_max",
)
logSs_min, logSs_max = _get_pair(
"logSs_min",
"logSs_max",
"Ss_min",
"Ss_max",
)
# If either set is missing, treat bounds as not configured.
if (logK_min is None) or (logSs_min is None):
return (None, None, None, None)
if not as_tensor:
return logK_min, logK_max, logSs_min, logSs_max
out = (
tf_constant(float(logK_min), dtype),
tf_constant(float(logK_max), dtype),
tf_constant(float(logSs_min), dtype),
tf_constant(float(logSs_max), dtype),
)
vprint(verbose, "bounds: out=", out)
return out
[docs]
def get_log_tau_bounds(
model,
*,
as_tensor: bool = True,
dtype=tf_float32,
verbose: int = 0,
) -> tuple[Any, Any]:
r"""
Get validated log-space bounds for the consolidation timescale.
This helper returns a 2-tuple:
``(log_tau_min, log_tau_max)``,
where :math:`\tau` is the consolidation timescale expressed in
SI seconds, and the returned bounds are in log-seconds.
The function reads bounds from ``model.scaling_kwargs['bounds']``
with the following precedence:
1. Explicit log bounds:
``log_tau_min`` and ``log_tau_max`` (already log-seconds).
2. Linear bounds in seconds:
``tau_min`` and ``tau_max``.
3. Linear bounds in dataset time units:
``tau_min_units`` and ``tau_max_units`` multiplied by the
seconds-per-time-unit factor inferred from ``time_units``.
4. Robust defaults if nothing is provided.
Parameters
----------
model : Any
Model-like object with an optional ``scaling_kwargs`` dict.
Tau bounds are read from ``model.scaling_kwargs['bounds']``.
as_tensor : bool, default=True
If True, return Tensor scalars created with ``tf_constant``.
If False, return Python floats.
dtype : tf.DType, default=tf_float32
Tensor dtype used when ``as_tensor=True``.
verbose : int, default=0
Verbosity level for optional debug printing.
Returns
-------
log_tau_min, log_tau_max : tuple
Log-space bounds (log-seconds). Returned as Tensor scalars
when ``as_tensor=True``, otherwise as Python floats.
Raises
------
ValueError
If user-provided bounds exist but are invalid, including:
* non-finite values (NaN or inf)
* non-positive linear bounds (<= 0)
* unordered bounds (max <= min) for explicit log bounds
Notes
-----
Units and meaning
~~~~~~~~~~~~~~~~~
The consolidation timescale :math:`\tau` controls the relaxation
rate in a first-order consolidation closure, e.g.:
.. math::
\partial_t s = \frac{s_{eq}(h) - s}{\tau},
where :math:`s` is settlement and :math:`s_{eq}` is the
equilibrium settlement implied by head (or drawdown).
Default behavior
~~~~~~~~~~~~~~~~
If no tau bounds are provided, robust defaults are used:
* ``tau_min = 7 days``
* ``tau_max = 300 years``
Both are converted to seconds and then logged. A warning may be
emitted to make the defaulting explicit.
Swapped linear bounds
~~~~~~~~~~~~~~~~~~~~~
If linear bounds are provided with ``tau_max < tau_min``, the
function may swap them to maintain a valid interval.
Examples
--------
Use Tensor bounds for log-space clipping:
>>> log_tau_min, log_tau_max = get_log_tau_bounds(model)
Return floats for reporting:
>>> log_tau_min, log_tau_max = get_log_tau_bounds(
... model, as_tensor=False
... )
See Also
--------
get_log_bounds
Bounds helper for log(K) and log(Ss).
compute_bounds_residual
Computes normalized bound violations using these limits.
"""
sk = getattr(model, "scaling_kwargs", None) or {}
bounds = (sk.get("bounds", None) or {}) or {}
def _is_finite(v: float) -> bool:
return bool(np.isfinite(v))
def _need_raise(v: float | None) -> bool:
return (v is not None) and (not _is_finite(float(v)))
# 1) Explicit log-bounds (already in log-seconds).
log_min = get_sk(
bounds, "log_tau_min", default=None, cast=float
)
log_max = get_sk(
bounds, "log_tau_max", default=None, cast=float
)
if _need_raise(log_min) or _need_raise(log_max):
raise ValueError(
"log_tau_min/log_tau_max must be finite."
)
if (log_min is not None) and (log_max is not None):
if float(log_max) <= float(log_min):
raise ValueError(
"log_tau_max must be > log_tau_min. "
f"Got {log_min}, {log_max}."
)
if not as_tensor:
return float(log_min), float(log_max)
out = (
tf_constant(float(log_min), dtype=dtype),
tf_constant(float(log_max), dtype=dtype),
)
vprint(verbose, "tau_bounds(log-sec):", out)
return out
# 2) Linear tau bounds (seconds).
tau_min = get_sk(
bounds, "tau_min", default=None, cast=float
)
tau_max = get_sk(
bounds, "tau_max", default=None, cast=float
)
if _need_raise(tau_min) or _need_raise(tau_max):
raise ValueError("tau_min/tau_max must be finite.")
# 2b) Linear tau bounds in "time_units".
if (tau_min is None) or (tau_max is None):
tau_min_u = get_sk(
bounds, "tau_min_units", default=None, cast=float
)
tau_max_u = get_sk(
bounds, "tau_max_units", default=None, cast=float
)
if _need_raise(tau_min_u) or _need_raise(tau_max_u):
raise ValueError(
"tau_min_units/tau_max_units must be finite."
)
if (tau_min_u is not None) and (
tau_max_u is not None
):
tu = (
get_sk(sk, "time_units", default=None)
or getattr(model, "time_units", None)
or "yr"
)
key = normalize_time_units(tu)
sec_per = float(
TIME_UNIT_TO_SECONDS.get(key, 1.0)
)
tau_min = float(tau_min_u) * sec_per
tau_max = float(tau_max_u) * sec_per
# 2c) Defaults if still missing.
if (tau_min is None) or (tau_max is None):
sec_day = 86400.0
sec_year = float(
TIME_UNIT_TO_SECONDS.get("yr", 31556952.0),
)
tau_min = 7.0 * sec_day
tau_max = 300.0 * sec_year
logger.warning(
"Tau bounds not found in scaling_kwargs['bounds']; "
"using defaults: tau_min=7 days, "
"tau_max=300 years (SI seconds)."
)
tau_min = float(tau_min)
tau_max = float(tau_max)
if (not _is_finite(tau_min)) or (not _is_finite(tau_max)):
raise ValueError(
f"tau_min/tau_max must be finite. "
f"Got {tau_min}, {tau_max}."
)
if (tau_min <= 0.0) or (tau_max <= 0.0):
raise ValueError(
f"tau_min/tau_max must be > 0. "
f"Got {tau_min}, {tau_max}."
)
if tau_max < tau_min:
logger.warning(
"tau_max < tau_min; swapping tau bounds."
)
tau_min, tau_max = tau_max, tau_min
log_min = float(np.log(tau_min))
log_max = float(np.log(tau_max))
if not as_tensor:
return log_min, log_max
out = (
tf_constant(float(log_min), dtype=dtype),
tf_constant(float(log_max), dtype=dtype),
)
vprint(verbose, "tau_bounds(log-sec):", out)
return out
[docs]
def bounded_exp(
raw: Tensor,
log_min: Tensor,
log_max: Tensor,
*,
eps: float = _EPSILON,
return_log: bool = False,
verbose: int = 0,
):
r"""
Exponentiate a raw parameter inside hard log-bounds.
This helper maps an unconstrained tensor ``raw`` to a positive
field by interpolating in log space between ``log_min`` and
``log_max``. The mapping is smooth and bounded:
.. math::
z = \sigma(\mathrm{raw}), \quad
\log v = \log v_{min} + z(\log v_{max} - \log v_{min}), \quad
v = \exp(\log v) + \varepsilon,
where :math:`\sigma` is the logistic sigmoid and
:math:`\varepsilon` is a small positive floor.
This is used when ``bounds_mode="hard"`` to ensure learned
fields such as :math:`K`, :math:`S_s`, or :math:`\tau` never
leave their configured ranges.
Parameters
----------
raw : Tensor
Unconstrained logit-like tensor (any shape). Non-finite
entries are sanitized to zeros to avoid NaN propagation.
log_min : Tensor
Lower bound in log space. Must be finite for strict
correctness, but non-finite values are sanitized to a safe
constant to prevent NaNs.
log_max : Tensor
Upper bound in log space. Must be finite for strict
correctness, but non-finite values are sanitized to a safe
constant to prevent NaNs.
eps : float, default=_EPSILON
Positive floor added after exponentiation to guarantee
strictly positive output.
return_log : bool, default=False
If True, return ``(out, logv)`` where ``logv`` is the bounded
log value actually exponentiated. If False, return ``out``
only.
verbose : int, default=0
Verbosity level for optional debug printing.
Returns
-------
out : Tensor
Positive bounded field tensor with the same shape as ``raw``.
logv : Tensor, optional
Bounded log value used to compute ``out``. Returned only
when ``return_log=True``.
Notes
-----
Hard bounds via sigmoid
~~~~~~~~~~~~~~~~~~~~~~~
The sigmoid interpolation produces values strictly inside the
interval (up to numerical precision). This avoids the gradient
discontinuity of direct clipping while still enforcing bounds.
Sanitization policy
~~~~~~~~~~~~~~~~~~~
To prevent NaNs and Infs from contaminating training, the
function sanitizes:
* non-finite values in ``raw`` to zeros,
* non-finite values in bounds to safe constants,
* swapped bounds by repairing the interval ordering.
This behavior is defensive and prioritizes numerical stability.
Examples
--------
Bound a raw logit field to the K interval:
>>> K, logK = bounded_exp(
... rawK, logK_min, logK_max, return_log=True
... )
Bound a tau field (already in log seconds bounds):
>>> tau = bounded_exp(raw_tau, log_tau_min, log_tau_max)
See Also
--------
guarded_exp_from_bounds
Soft-bounds path that keeps raw logs for penalties while
guarding exponentiation overflow.
compose_physics_fields
Uses bounded_exp to build K, Ss, and tau fields.
"""
eps_t = tf_constant(float(eps), tf_float32)
log_eps = tf_log(eps_t)
# Sanitize inputs to avoid NaN propagation.
raw = tf_cast(raw, tf_float32)
raw = tf_where(
tf_math.is_finite(raw), raw, tf_zeros_like(raw)
)
log_min = tf_cast(log_min, tf_float32)
log_max = tf_cast(log_max, tf_float32)
log_min = tf_where(
tf_math.is_finite(log_min),
log_min,
log_eps,
)
log_max = tf_where(
tf_math.is_finite(log_max),
log_max,
log_min + tf_constant(1.0, tf_float32),
)
# If user swapped bounds, repair silently (safe + monotone).
log_lo = tf_minimum(log_min, log_max)
log_hi = tf_maximum(log_min, log_max)
# Map raw -> (0,1) then interpolate inside [log_lo, log_hi].
z = tf_sigmoid(raw)
logv = log_lo + z * (log_hi - log_lo)
# Output is positive, with epsilon floor.
out = tf_exp(logv) + eps_t
vprint(verbose, "bounded_exp: logv=", logv)
vprint(verbose, "bounded_exp: out=", out)
if return_log:
return out, logv
return out
[docs]
def finite_floor(x: Tensor, eps: float) -> Tensor:
"""
Replace NaN/Inf by eps and floor to eps.
Useful when you want "never NaN" behaviour, not strict errors.
"""
x = tf_cast(x, tf_float32)
eps_t = tf_constant(float(eps), tf_float32)
x = tf_where(tf_math.is_finite(x), x, eps_t)
return tf_maximum(x, eps_t)
def _finite_or_zero(x: Tensor) -> Tensor:
x = tf_cast(x, tf_float32)
return tf_where(tf_math.is_finite(x), x, tf_zeros_like(x))
def _get_bounds_loss_cfg(
model: Any = None,
scaling_kwargs: dict | None = None,
) -> dict[str, Any]:
def _as_map(x: Any) -> Mapping[str, Any]:
return x if isinstance(x, Mapping) else {}
def _attr(name: str, default: Any) -> Any:
if model is None:
return default
return getattr(model, name, default)
def _take(
sk: Mapping[str, Any], key: str, cur: Any
) -> Any:
# If user/profile put None by mistake, ignore it
if key in sk:
v = sk.get(key)
if v is not None:
return v
return cur
sk_model = _as_map(getattr(model, "scaling_kwargs", None))
sk_arg = _as_map(scaling_kwargs)
# precedence (low -> high):
# model attrs < model.scaling_kwargs < scaling_kwargs arg
mode = _attr("bounds_mode", "soft")
kind = _attr("bounds_loss_kind", "both")
beta = _attr("bounds_beta", 20.0)
guard = _attr("bounds_guard", 5.0)
w_b = _attr("bounds_w", 1.0)
inc_tau = _attr("bounds_include_tau", True)
w_tau = _attr("bounds_tau_w", 1.0)
sources = [sk_model, sk_arg]
# 1) flat keys
for sk in sources:
mode = _take(sk, "bounds_mode", mode)
kind = _take(sk, "bounds_loss_kind", kind)
beta = _take(sk, "bounds_beta", beta)
guard = _take(sk, "bounds_guard", guard)
w_b = _take(sk, "bounds_w", w_b)
inc_tau = _take(sk, "bounds_include_tau", inc_tau)
w_tau = _take(sk, "bounds_tau_w", w_tau)
# 2) nested dict
nested_keys = (
"bounds_loss_settings",
"bounds_loss_setting",
"bound_cfg",
"bound_loss_cfg",
"bounds_settings",
"bounds_config",
"bounds_loss_config",
)
for sk in sources:
nested: Any = {}
for k in nested_keys:
if k in sk:
nested = sk.get(k)
break
if not isinstance(nested, Mapping):
continue
mode = _take(nested, "mode", mode)
kind = _take(nested, "kind", kind)
beta = _take(nested, "beta", beta)
guard = _take(nested, "guard", guard)
w_b = _take(nested, "w", w_b)
inc_tau = _take(nested, "include_tau", inc_tau)
w_tau = _take(nested, "tau_w", w_tau)
mode = str(mode).strip().lower()
kind = str(kind).strip().lower()
return dict(
mode=mode,
kind=kind,
beta=float(beta),
guard=float(guard),
w=float(w_b),
include_tau=bool(inc_tau),
tau_w=float(w_tau),
)
[docs]
def compose_physics_fields(
model,
*,
coords_flat: Tensor,
H_si: Tensor,
K_base: Tensor,
Ss_base: Tensor,
tau_base: Tensor,
training: bool = False,
eps_KSs: float = _EPSILON,
eps_tau: float = 1e-6,
verbose: int = 0,
):
r"""
Compose physically meaningful fields :math:`K`, :math:`S_s`, and
:math:`\tau` from network "base" logits and coordinate corrections.
This routine is the central *field mapping* step for GeoPrior-style
PINN models. The model predicts coarse (time-dependent) latent logits
``K_base``, ``Ss_base``, and ``tau_base`` from the physics head, then
adds smooth spatial corrections from coordinate MLPs:
* ``model.K_coord_mlp`` for :math:`\log K`
* ``model.Ss_coord_mlp`` for :math:`\log S_s`
* ``model.tau_coord_mlp`` for :math:`\Delta \log \tau`
The corrected parameters are then mapped to SI-consistent, positive
fields (in float32-safe ways) and combined with a physics closure
timescale :math:`\tau_\mathrm{phys}` computed from the fields.
Let :math:`(t, x, y)` denote the coordinate tensor passed to the
decoder. Spatial corrections are evaluated on coordinates with time
zeroed:
.. math::
\tilde{\mathbf{c}} = (0, x, y).
Define the raw log-parameters (logits) as:
.. math::
\ell_K &= \ell_K^\mathrm{base}(t,x,y)
+ \Delta \ell_K(\tilde{\mathbf{c}}), \\
\ell_{S_s} &= \ell_{S_s}^\mathrm{base}(t,x,y)
+ \Delta \ell_{S_s}(\tilde{\mathbf{c}}).
The resulting fields are positive exponentials:
.. math::
K = \exp(\ell_K), \qquad
S_s = \exp(\ell_{S_s}),
subject to (log-)bounds. In ``bounds_mode="hard"`` the values are
projected into the valid interval by clipping in log space, while in
``bounds_mode="soft"`` the function returns the unbounded logs for
penalties but uses a *guarded exponential* to avoid float32 overflow.
For the consolidation timescale, we first compute a closure (prior)
timescale from the fields:
.. math::
\log \tau_\mathrm{phys} =
f_\tau(K, S_s, H; \text{model options}),
where :math:`H` is the drained thickness in meters (``H_si``) and
``tau_phys_from_fields`` implements the chosen closure and drainage
convention. The network adds a learnable residual in log space:
.. math::
\Delta \log \tau =
\ell_\tau^\mathrm{base}(t,x,y) + \Delta \ell_\tau(\tilde{\mathbf{c}}),
and the total learned timescale is:
.. math::
\log \tau = \log \tau_\mathrm{phys} + \Delta \log \tau, \qquad
\tau = \exp(\log \tau) + \varepsilon_\tau.
The term :math:`\varepsilon_\tau` (``eps_tau``) is a small positive
floor to avoid exact zeros and improve numerical stability.
Parameters
----------
model : Any
Model-like object providing:
* coordinate MLPs: ``K_coord_mlp``, ``Ss_coord_mlp``,
``tau_coord_mlp``
* bounds configuration: ``bounds_mode`` and bounds accessors
used by ``get_log_bounds`` and ``get_log_tau_bounds``
* closure configuration used by ``tau_phys_from_fields``
coords_flat : Tensor
Coordinate tensor used by the decoder. Expected shape is
``(B, H, 3)`` with last dimension ordered as ``(t, x, y)``.
The function constructs ``(0, x, y)`` for the coordinate MLPs to
keep corrections time-invariant by default.
H_si : Tensor
Drained thickness :math:`H` in SI units (meters). Shape must be
broadcastable to ``(B, H, 1)``.
K_base : Tensor
Base logits for :math:`\log K`. Shape is typically ``(B, H, 1)``.
Ss_base : Tensor
Base logits for :math:`\log S_s`. Shape is typically
``(B, H, 1)``.
tau_base : Tensor
Base logits for :math:`\Delta \log \tau`. Shape is typically
``(B, H, 1)``.
training : bool, default=False
Forward mode for coordinate MLPs.
eps_KSs : float, default=_EPSILON
Small positive constant used when mapping log-parameters to
positive values (e.g., inside bounded / guarded exponentials).
eps_tau : float, default=1e-6
Additive floor for :math:`\tau` in seconds to avoid exact zeros.
verbose : int, default=0
Verbosity level used by internal debug printing utilities.
Returns
-------
K_field : Tensor
Effective hydraulic conductivity field :math:`K` in SI units.
Shape ``(B, H, 1)``. Units are typically meters per second.
Ss_field : Tensor
Effective specific storage field :math:`S_s` in SI units.
Shape ``(B, H, 1)``. Units are typically inverse meters.
tau_field : Tensor
Learned consolidation timescale :math:`\tau` in seconds.
Shape ``(B, H, 1)``.
tau_phys : Tensor
Closure-based timescale :math:`\tau_\mathrm{phys}` in seconds.
Shape ``(B, H, 1)`` (broadcasted as needed).
Hd_eff : Tensor
Effective drainage thickness :math:`H_d` in meters used by the
closure, accounting for drainage mode and ``hd_factor`` style
options. Shape broadcastable to ``(B, H, 1)``.
delta_log_tau : Tensor
The learnable log-residual :math:`\Delta \log \tau` added to
:math:`\log \tau_\mathrm{phys}`. Shape ``(B, H, 1)``.
logK : Tensor
Log-parameter :math:`\log K` used for priors, bounds penalties,
and diagnostics. Shape ``(B, H, 1)``.
logSs : Tensor
Log-parameter :math:`\log S_s` used for priors, bounds penalties,
and diagnostics. Shape ``(B, H, 1)``.
log_tau : Tensor
Log of total timescale :math:`\log \tau` (pre-guard in soft mode).
Returned for bounds penalties and diagnostics. Shape ``(B, H, 1)``.
log_tau_phys : Tensor
Log of closure timescale :math:`\log \tau_\mathrm{phys}` returned
for priors and diagnostics. Shape ``(B, H, 1)``.
Notes
-----
**Why coordinate corrections use ``(0, x, y)``.**
The coordinate MLPs are intended to represent slowly varying spatial
heterogeneity (e.g., lithology-driven variability). Zeroing time
reduces the risk that the model encodes time-varying physics fields
that can destabilize PDE derivatives across horizons.
**Hard vs soft bounds.**
When ``bounds_mode="hard"``, log-parameters are projected into the
valid interval, yielding fields that always satisfy bounds.
When ``bounds_mode="soft"``, log-parameters are returned unmodified
for differentiable penalties, but exponentiation is guarded to prevent
float32 overflow. This preserves gradients for penalties without
risking NaN / Inf in the forward pass.
**Numerical stability.**
The function deliberately avoids reapplying ``log(exp(.))`` patterns.
In particular, it composes :math:`\log \tau` additively:
.. math::
\log \tau = \log \tau_\mathrm{phys} + \Delta \log \tau,
which is both exact and numerically stable.
Examples
--------
Compute fields inside a physics forward pass:
>>> K_field, Ss_field, tau_field, tau_phys, Hd_eff, dlogtau, logK, \
... logSs, log_tau, log_tau_phys = compose_physics_fields(
... model,
... coords_flat=coords,
... H_si=H_si,
... K_base=K_logits,
... Ss_base=Ss_logits,
... tau_base=dlogtau_logits,
... training=True,
... )
Use returned logs for priors and bounds penalties:
>>> prior_res = dlogtau
>>> bounds_penalty_inputs = (logK, logSs, log_tau)
See Also
--------
tau_phys_from_fields
Computes the closure timescale :math:`\tau_\mathrm{phys}`.
get_log_bounds, get_log_tau_bounds
Provide log-space bounds used for field mapping.
bounded_exp, guarded_exp_from_bounds
Safe mappings from log-parameters to positive fields.
compute_bounds_residual
Uses the returned logs and thickness for bounds penalties.
"""
bc = _get_bounds_loss_cfg(model)
mode = bc["mode"]
beta = bc["beta"]
guard = bc["guard"]
w_b = bc["w"]
include_tau = bc["include_tau"]
w_tau = bc["tau_w"]
if verbose > 6:
tf_print_nonfinite("compose/coords_flat", coords_flat)
tf_print_nonfinite("compose/K_base", K_base)
tf_print_nonfinite("compose/Ss_base", Ss_base)
tf_print_nonfinite("compose/tau_base", tau_base)
coords_xy0 = tf_concat(
[
tf_zeros_like(coords_flat[..., :1]),
coords_flat[..., 1:],
],
axis=-1,
)
coords_xy0 = _finite_or_zero(coords_xy0)
K_corr = _finite_or_zero(
model.K_coord_mlp(coords_xy0, training=training)
)
Ss_corr = _finite_or_zero(
model.Ss_coord_mlp(coords_xy0, training=training)
)
tau_corr = _finite_or_zero(
model.tau_coord_mlp(coords_xy0, training=training)
)
if verbose > 6:
tf_print_nonfinite("compose/K_corr", K_corr)
tf_print_nonfinite("compose/Ss_corr", Ss_corr)
tf_print_nonfinite("compose/tau_corr", tau_corr)
rawK = K_base + K_corr
rawSs = Ss_base + Ss_corr
# bounds_mode = str(getattr(model, "bounds_mode", "soft")).strip().lower()
logK_min, logK_max, logSs_min, logSs_max = get_log_bounds(
model,
as_tensor=True,
dtype=rawK.dtype,
verbose=verbose,
)
# # ---- K, Ss ----
# if bounds_mode == "hard":
# K_field, logK = bounded_exp(
# rawK, logK_min, logK_max, eps=eps_KSs,
# return_log=True, verbose=verbose,
# )
# Ss_field, logSs = bounded_exp(
# rawSs, logSs_min, logSs_max, eps=eps_KSs,
# return_log=True, verbose=verbose,
# )
# else:
# # Keep raw log-params (useful for priors/diagnostics),
# # but NEVER feed an unbounded log into exp() in float32.
# K_field, logK, _ = guarded_exp_from_bounds(
# rawK, logK_min, logK_max, eps=eps_KSs,
# guard=5.0, name="K"
# )
# Ss_field, logSs, _ = guarded_exp_from_bounds(
# rawSs, logSs_min, logSs_max,eps=eps_KSs,
# guard=5.0, name="Ss"
# )
# ---- K, Ss (policy-driven) ----
K_field, logK_raw, logK_safe, pK = exp_from_bounds(
rawK,
logK_min,
logK_max,
mode=mode,
beta=beta,
guard=guard,
eps=eps_KSs,
dtype=rawK.dtype,
name="K",
)
Ss_field, logSs_raw, logSs_safe, pS = exp_from_bounds(
rawSs,
logSs_min,
logSs_max,
mode=mode,
beta=beta,
guard=guard,
eps=eps_KSs,
dtype=rawSs.dtype,
name="Ss",
)
# What to return as "logK/logSs" depends on policy:
# - soft/none: return raw (useful for penalties/diagnostics)
# - hard/sigmoid: return safe (already within bounds)
if mode in ("soft", "none"):
logK = logK_raw
logSs = logSs_raw
else:
logK = logK_safe
logSs = logSs_safe
loss_bounds_KSs = tf_constant(float(w_b), rawK.dtype) * (
_reduce_barrier_mean(pK) + _reduce_barrier_mean(pS)
)
# Optional: keep the asserts, but now they won't trip from exp overflow.
tf_debugging.assert_all_finite(
logK, "rawK/logK non-finite"
)
tf_debugging.assert_all_finite(
logSs, "rawSs/logSs non-finite"
)
tf_debugging.assert_all_finite(
K_field, "K_field non-finite"
)
tf_debugging.assert_all_finite(
Ss_field, "Ss_field non-finite"
)
# ---- tau ( log-space composition + bounds) ----
delta_log_tau = _finite_or_zero(tau_base + tau_corr)
if verbose > 6:
tf_print_nonfinite("compose/rawK", rawK)
tf_print_nonfinite("compose/rawSs", rawSs)
tf_print_nonfinite(
"compose/delta_log_tau", delta_log_tau
)
# 1. Capture output as LOG value (because return_log=True)
log_tau_phys, Hd_eff = tau_phys_from_fields(
model,
K_field,
Ss_field,
H_si,
return_log=True,
verbose=0,
)
# 2. Calculate linear tau_phys safely from log (for logging/debugging)
tau_phys = tf_exp(log_tau_phys)
# ---- tau (policy-driven) ----
# 3. Calculate total log directly (avoiding re-logging the exp)
# Previous bad logic: log(max(exp(log_x), eps)) -> redundant and lossy
# New logic: just add the logs directly.
log_tau_total = log_tau_phys + delta_log_tau
log_tau_min, log_tau_max = get_log_tau_bounds(
model,
as_tensor=True,
dtype=log_tau_total.dtype,
verbose=0,
)
# if bounds_mode == "hard":
# # true hard bounds: clip in log-space (keeps tau_phys anchoring)
# log_tau = tf_clip_by_value(log_tau_total, log_tau_min, log_tau_max)
# tau_field = tf_exp(log_tau) + tf_constant(eps_tau, log_tau.dtype)
# else:
# # soft mode: keep log_tau for bounds penalty, but guard exp overflow
# log_tau = log_tau_total
# guard_lo = log_tau_min - tf_constant(10.0, log_tau.dtype)
# guard_hi = log_tau_max + tf_constant(10.0, log_tau.dtype)
# log_tau_safe = tf_clip_by_value(log_tau, guard_lo, guard_hi)
# tau_field = tf_exp(log_tau_safe) + tf_constant(eps_tau, log_tau.dtype)
tau_field, logTau_raw, logTau_safe, pT = exp_from_bounds(
log_tau_total,
log_tau_min,
log_tau_max,
mode=mode,
beta=beta,
guard=guard,
eps=eps_tau,
dtype=log_tau_total.dtype,
name="tau",
)
if mode in ("soft", "none"):
log_tau = logTau_raw
else:
log_tau = logTau_safe
loss_bounds_tau = tf_zeros_like(loss_bounds_KSs)
if include_tau:
loss_bounds_tau = tf_constant(
float(w_tau), log_tau.dtype
) * (_reduce_barrier_mean(pT))
loss_bounds = loss_bounds_KSs + loss_bounds_tau
if verbose > 6:
tf_print_nonfinite("compose/K_field", K_field)
tf_print_nonfinite("compose/Ss_field", Ss_field)
tf_print_nonfinite("compose/tau_phys", tau_phys)
tf_print_nonfinite(
"compose/log_tau_phys", log_tau_phys
)
tf_print_nonfinite(
"compose/log_tau_total", log_tau_total
)
tf_print_nonfinite("compose/tau_field", tau_field)
vprint(verbose, "fields: K=", K_field)
vprint(verbose, "fields: Ss=", Ss_field)
vprint(verbose, "fields: tau=", tau_field)
vprint(verbose, "fields: tau_phys=", tau_phys)
return (
K_field,
Ss_field,
tau_field,
tau_phys,
Hd_eff,
delta_log_tau,
logK,
logSs,
log_tau, # return log_tau for bounds penalty + diagnostics
log_tau_phys, # optional but very useful for priors/diagnostics
loss_bounds,
)
def _log_bounds_residual(
logv: Tensor,
lo: Tensor,
hi: Tensor,
*,
eps: float = 1e-12,
name: str = "",
) -> Tensor:
"""
Normalized bound violation in log-space.
We compute a symmetric distance outside [lo, hi], then
normalize by the range (hi - lo). This returns 0 inside
bounds and >0 outside bounds.
Notes
-----
- We sanitize non-finite logv to avoid NaN explosions.
- lo/hi are assumed finite tensors (from helpers).
"""
dtype = logv.dtype
zero = tf_constant(0.0, dtype=dtype)
eps_t = tf_constant(float(eps), dtype=dtype)
# Sanitize inputs: never propagate NaN/Inf into loss.
is_ok = tf_math.is_finite(logv)
logv = tf_where(is_ok, logv, tf_zeros_like(logv))
lo = tf_cast(lo, dtype)
hi = tf_cast(hi, dtype)
lower = tf_maximum(lo - logv, zero)
upper = tf_maximum(logv - hi, zero)
rng = tf_maximum(hi - lo, eps_t)
res = (lower + upper) / rng
# Optional debug checks (keep off by default).
if name:
msg = name + " bounds residual non-finite"
tf_debugging.assert_all_finite(res, msg)
return res
[docs]
def compute_bounds_residual(
model: Any,
*,
H_field: Tensor,
logK: TensorLike | None = None,
logSs: TensorLike | None = None,
log_tau: TensorLike | None = None,
K_field: TensorLike | None = None,
Ss_field: TensorLike | None = None,
tau_field: TensorLike | None = None,
eps: float = _EPSILON,
verbose: int = 0,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
r"""
Compute differentiable bound-violation residuals for the learned
physics fields.
This function converts configured parameter bounds into *residual
maps* that can be squared and averaged to form a soft penalty term
(e.g., :math:`L_\mathrm{bounds} = \mathrm{mean}(R^2)`).
The bounds policy is driven by ``model.scaling_kwargs['bounds']`` and
supports:
* Linear-space bounds for drained thickness :math:`H` (meters).
* Log-space bounds for :math:`K`, :math:`S_s`, and :math:`\tau`.
The returned residuals are normalized by the corresponding bound
ranges, so they are roughly comparable across parameters.
Mathematical formulation
------------------------
Let :math:`z` be a scalar parameter with bounds
:math:`z_{\min} \le z \le z_{\max}`. A standard non-negative violation
is:
.. math::
v(z) = \max(z_{\min} - z, 0) + \max(z - z_{\max}, 0).
This function returns a *range-normalized* residual:
.. math::
R(z) = \frac{v(z)}{\max(z_{\max} - z_{\min}, \varepsilon)}.
For log-bounded parameters (conductivity, storage, and timescale), the
same definition is applied in log space. For example, if
:math:`\ell_K = \log K`, then:
.. math::
R_K = R(\ell_K; \ell_{K,\min}, \ell_{K,\max}),
where :math:`\ell_{K,\min} = \log K_{\min}` and
:math:`\ell_{K,\max} = \log K_{\max}`.
Preferred usage in soft bounds mode
-----------------------------------
When ``bounds_mode="soft"``, it is preferable to pass the raw log
parameters ``logK``, ``logSs``, and ``log_tau`` produced before any
"guarded exponential" is applied. This ensures that the penalty
reflects the true magnitude of out-of-range logits, even if the
corresponding fields are exponentiated using an overflow guard.
If raw logs are not provided, the function falls back to inferring
logs from the fields via ``log(max(field, eps))``. This is safe, but
may under-estimate violations if the field values were produced by a
guarded mapping.
Parameters
----------
model : Any
Model-like object providing:
* ``scaling_kwargs`` with optional bounds + bounds policy keys.
* Accessors used by ``get_log_bounds`` and ``get_log_tau_bounds``.
* (Optional) pre-resolved bound tensors cached by the model.
Bounds configuration is read from ``model.scaling_kwargs``:
* ``bounds`` (dict): numeric ranges such as ``H_min/H_max``,
``K_min/K_max`` or ``logK_min/logK_max``, similarly for
``Ss`` and optionally ``tau``.
* ``bounds_mode``: one of ``{'soft','hard','sigmoid','none'}``.
* ``bounds_beta``: barrier sharpness for ``bounds_mode='soft'``.
* ``bounds_guard``: numeric guard band used by guarded mappings.
* ``bounds_w``: barrier weight for K + Ss.
* ``bounds_include_tau``: whether to include tau bounds.
* ``bounds_tau_w``: barrier weight for tau if included.
H_field : Tensor
Drained thickness field :math:`H` in SI meters. Shape must be
broadcastable to ``(B, H, 1)``. Bounds are applied in linear space
using keys ``H_min`` and ``H_max`` when present.
logK : Tensor, optional
Log-conductivity :math:`\log K` (preferred). If provided, bounds
are applied directly in log space and ``K_field`` is not needed.
logSs : Tensor, optional
Log-specific storage :math:`\log S_s` (preferred). If provided,
bounds are applied directly in log space and ``Ss_field`` is not
needed.
log_tau : Tensor, optional
Log-timescale :math:`\log \tau` in seconds (preferred). If
provided, bounds are applied directly in log space and
``tau_field`` is not needed.
K_field : Tensor, optional
Conductivity field :math:`K` (meters per second). Used only if
``logK`` is not provided.
Ss_field : Tensor, optional
Specific storage field :math:`S_s` (inverse meters). Used only if
``logSs`` is not provided.
tau_field : Tensor, optional
Timescale field :math:`\tau` (seconds). Used only if ``log_tau`` is
not provided.
eps : float, default=_EPSILON
Small positive constant used to avoid division by zero and
undefined logs, via :math:`\max(\cdot, \varepsilon)`.
verbose : int, default=0
Verbosity level for optional debug printing.
Returns
-------
R_H : Tensor
Residual map for thickness bounds violations. Same shape as
``H_field`` (broadcasted as needed). All values are non-negative.
R_K : Tensor
Residual map for conductivity log-bounds violations. Same shape as
``H_field`` (broadcasted as needed). All values are non-negative.
R_Ss : Tensor
Residual map for specific storage log-bounds violations. Same
shape as ``H_field`` (broadcasted as needed). All values are
non-negative.
R_tau : Tensor
Residual map for timescale log-bounds violations. Same shape as
``H_field`` (broadcasted as needed). All values are non-negative.
loss_bounds_barrier : Tensor
Non-negative scalar barrier penalty induced by the configured
bounds policy (typically K/Ss and optionally tau). This is the
*barrier component only*.
Any *residual-style* bound penalty (range-normalized violation
residuals for H/K/Ss/tau) is computed separately (e.g. via
``compute_bounds_residual``) and can be combined downstream
according to ``bounds_loss_kind``.
Notes
-----
Bounds configuration
~~~~~~~~~~~~~~~~~~~~
The function reads bounds from:
* ``model.scaling_kwargs.get('bounds', {})`` for ``H_min`` and ``H_max``.
* ``get_log_bounds(model, ...)`` for log-space bounds of :math:`K` and
:math:`S_s` (typically ``logK_min``, ``logK_max``, ``logSs_min``,
``logSs_max``).
* ``get_log_tau_bounds(model, ...)`` for log-space bounds of
:math:`\tau` (typically ``logTau_min`` and ``logTau_max``).
1) Linear-space bounds (thickness)
* ``scaling_kwargs['bounds']['H_min']`` and
``scaling_kwargs['bounds']['H_max']`` (meters)
2) Log-space bounds (K and Ss)
Bounds are resolved by ``get_log_bounds(model, ...)``, which
may be backed by any of the following configuration entries:
* Direct log bounds:
``bounds['logK_min']`` / ``bounds['logK_max']``,
``bounds['logSs_min']`` / ``bounds['logSs_max']``
* Or linear bounds that can be converted to logs:
``bounds['K_min']`` / ``bounds['K_max']``,
``bounds['Ss_min']`` / ``bounds['Ss_max']``
3) Log-space bounds (tau)
Bounds are resolved by ``get_log_tau_bounds(model, ...)``,
typically from:
* Direct log bounds:
``bounds['logTau_min']`` / ``bounds['logTau_max']``
* Or linear bounds convertible to logs:
``bounds['tau_min']`` / ``bounds['tau_max']``
If a bound is missing (or the accessor returns ``None``),
the corresponding residual is returned as zeros.
Interaction with bounds policy
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Keys such as ``bounds_mode``, ``bounds_beta``, ``bounds_guard``,
``bounds_w``, ``bounds_include_tau``, and ``bounds_tau_w`` define
how the *barrier* term is produced (usually in
``compose_physics_fields``). They do not change the definition
of the residual maps returned here; instead, they control how
the physics core combines residual penalty and barrier penalty
via ``bounds_loss_kind``.
Normalization
~~~~~~~~~~~~~
Residuals are normalized by the bound range to reduce sensitivity to
the absolute scale of each parameter. This makes it easier to set a
single loss weight such as ``lambda_bounds`` without one parameter
dominating purely due to units.
Debugging tip
~~~~~~~~~~~~~
To debug bounds behavior, log the two components separately:
* ``loss_bounds_resid`` from these residual maps:
``mean(square(concat([R_H, R_K, R_Ss, R_tau])))``
* ``loss_bounds_barrier`` returned by ``compose_physics_fields``
Then feed only their configured combination into the final
physics loss according to ``bounds_loss_kind``.
Examples
--------
Compute residuals from raw logs (recommended in soft mode):
>>> R_H, R_K, R_Ss, R_tau = compute_bounds_residual(
... model,
... H_field=H_si,
... logK=logK,
... logSs=logSs,
... log_tau=log_tau,
... )
Fallback when only fields are available:
>>> R_H, R_K, R_Ss, R_tau = compute_bounds_residual(
... model,
... H_field=H_si,
... K_field=K_field,
... Ss_field=Ss_field,
... tau_field=tau_field,
... )
Create a scalar penalty:
>>> bounds_res = tf_concat([R_H, R_K, R_Ss, R_tau], axis=-1)
>>> loss_bounds = tf_reduce_mean(tf_square(bounds_res))
See Also
--------
compose_physics_fields
Produces both raw logs and exponentiated fields.
get_log_bounds, get_log_tau_bounds
Retrieve log-space bounds for :math:`K`, :math:`S_s`, and
:math:`\tau`.
_log_bounds_residual
Internal helper that converts log-values to normalized residuals.
"""
dtype = H_field.dtype
eps_t = tf_constant(eps, dtype=dtype)
zero = tf_constant(0.0, dtype=dtype)
# ------------------------------------------------------
# H bounds (linear space, SI meters).
# ------------------------------------------------------
H_safe = tf_maximum(tf_cast(H_field, dtype), eps_t)
sk = getattr(model, "scaling_kwargs", None) or {}
b = (sk.get("bounds", None) or {}) or {}
H_min = b.get("H_min", None)
H_max = b.get("H_max", None)
if (H_min is None) or (H_max is None):
R_H = tf_zeros_like(H_safe)
else:
H_min_t = tf_constant(float(H_min), dtype=dtype)
H_max_t = tf_constant(float(H_max), dtype=dtype)
lo = tf_maximum(H_min_t - H_safe, zero)
hi = tf_maximum(H_safe - H_max_t, zero)
rng = tf_maximum(H_max_t - H_min_t, eps_t)
R_H = (lo + hi) / rng
# ------------------------------------------------------
# K, Ss bounds (log-space).
# Prefer raw logs if provided.
# ------------------------------------------------------
out = get_log_bounds(
model,
as_tensor=True,
dtype=dtype,
verbose=0,
)
logK_min, logK_max, logSs_min, logSs_max = out
if logK_min is None:
# Bounds not configured -> no penalty.
R_K = tf_zeros_like(H_safe)
R_Ss = tf_zeros_like(H_safe)
else:
# ---- K residual ----
if logK is None:
if K_field is None:
R_K = tf_zeros_like(H_safe)
else:
K_safe = tf_maximum(
tf_cast(K_field, dtype), eps_t
)
logK_hat = tf_math.log(K_safe)
R_K = _log_bounds_residual(
logK_hat,
logK_min,
logK_max,
name="K",
)
else:
R_K = _log_bounds_residual(
tf_cast(logK, dtype),
logK_min,
logK_max,
name="K",
)
# ---- Ss residual ----
if logSs is None:
if Ss_field is None:
R_Ss = tf_zeros_like(H_safe)
else:
Ss_safe = tf_maximum(
tf_cast(Ss_field, dtype),
eps_t,
)
logSs_hat = tf_math.log(Ss_safe)
R_Ss = _log_bounds_residual(
logSs_hat,
logSs_min,
logSs_max,
name="Ss",
)
else:
R_Ss = _log_bounds_residual(
tf_cast(logSs, dtype),
logSs_min,
logSs_max,
name="Ss",
)
# ------------------------------------------------------
# tau bounds (log-space, seconds).
# Prefer raw log_tau if provided.
# ------------------------------------------------------
log_tau_min, log_tau_max = get_log_tau_bounds(
model,
as_tensor=True,
dtype=dtype,
verbose=0,
)
if log_tau is not None:
R_tau = _log_bounds_residual(
tf_cast(log_tau, dtype),
log_tau_min,
log_tau_max,
name="tau",
)
elif tau_field is not None:
tau_safe = tf_maximum(
tf_cast(tau_field, dtype), eps_t
)
log_tau_hat = tf_math.log(tau_safe)
R_tau = _log_bounds_residual(
log_tau_hat,
log_tau_min,
log_tau_max,
name="tau",
)
else:
R_tau = tf_zeros_like(H_safe)
if verbose > 6:
vprint(verbose, "bounds: R_H=", R_H)
vprint(verbose, "bounds: R_K=", R_K)
vprint(verbose, "bounds: R_Ss=", R_Ss)
vprint(verbose, "bounds: R_tau=", R_tau)
return R_H, R_K, R_Ss, R_tau
def _compute_bounds_residual(
model,
K_field: Tensor,
Ss_field: Tensor,
H_field: Tensor,
*,
eps: float = 1e-12,
verbose: int = 0,
) -> tuple[Tensor, Tensor, Tensor]:
"""Bounds residuals for H,K,Ss.
Preferred usage in soft bounds mode
-----------------------------------
When ``model.scaling_kwargs['bounds_mode'] == 'soft'``, the
model typically enforces bounds via a *barrier* penalty (often
computed inside ``compose_physics_fields``) while the forward
mapping from logits to fields may also use a numeric guard
(e.g. ``bounds_guard``) to prevent overflow.
For the most informative residual-style diagnostics, pass the
*raw* log-parameters (``logK``, ``logSs``, ``log_tau``) produced
*before* any guarded exponential / squashing is applied. This
ensures the residual penalty reflects the true distance of the
raw parameters from the configured log-bounds, even if the
corresponding physical fields were produced by a guarded
mapping.
If raw logs are not provided, the function falls back to
inferring logs from the fields via ``log(max(field, eps))``.
This is numerically safe, but it can *under-estimate*
violations when the forward mapping clips or guards extreme
values (e.g. due to ``bounds_guard`` or hard/sigmoid modes).
Important
~~~~~~~~~
``compute_bounds_residual`` always returns *violation residuals*
(regardless of ``bounds_mode``). The training objective chooses
how to combine:
(1) the residual penalty derived from these residual maps and
(2) the barrier penalty returned by ``compose_physics_fields``,
according to ``scaling_kwargs['bounds_loss_kind']``.
"""
dtype = K_field.dtype
eps = tf_constant(eps, dtype=dtype)
zero = tf_constant(0.0, dtype=dtype)
K_safe = tf_maximum(K_field, eps)
Ss_safe = tf_maximum(Ss_field, eps)
H_safe = tf_maximum(H_field, eps)
bounds_cfg = (model.scaling_kwargs or {}).get(
"bounds",
{},
) or {}
H_min = bounds_cfg.get("H_min", None)
H_max = bounds_cfg.get("H_max", None)
if (H_min is None) or (H_max is None):
R_H = tf_zeros_like(H_safe)
else:
H_min_t = tf_constant(float(H_min), dtype=dtype)
H_max_t = tf_constant(float(H_max), dtype=dtype)
lower = tf_maximum(H_min_t - H_safe, zero)
upper = tf_maximum(H_safe - H_max_t, zero)
H_rng = tf_maximum(H_max_t - H_min_t, eps)
R_H = (lower + upper) / H_rng
def log_bound(val_safe, log_min, log_max):
logv = tf_log(val_safe)
lo = tf_constant(float(log_min), dtype=dtype)
hi = tf_constant(float(log_max), dtype=dtype)
lower = tf_maximum(lo - logv, zero)
upper = tf_maximum(logv - hi, zero)
rng = tf_maximum(hi - lo, eps)
return (lower + upper) / rng
logK_min = bounds_cfg.get("logK_min", None)
logK_max = bounds_cfg.get("logK_max", None)
if (logK_min is None or logK_max is None) and (
bounds_cfg.get("K_min") is not None
and bounds_cfg.get("K_max") is not None
):
logK_min = float(np.log(float(bounds_cfg["K_min"])))
logK_max = float(np.log(float(bounds_cfg["K_max"])))
if (logK_min is None) or (logK_max is None):
R_K = tf_zeros_like(K_safe)
else:
R_K = log_bound(K_safe, logK_min, logK_max)
logSs_min = bounds_cfg.get("logSs_min", None)
logSs_max = bounds_cfg.get("logSs_max", None)
if (logSs_min is None or logSs_max is None) and (
bounds_cfg.get("Ss_min") is not None
and bounds_cfg.get("Ss_max") is not None
):
logSs_min = float(np.log(float(bounds_cfg["Ss_min"])))
logSs_max = float(np.log(float(bounds_cfg["Ss_max"])))
if (logSs_min is None) or (logSs_max is None):
R_Ss = tf_zeros_like(Ss_safe)
else:
R_Ss = log_bound(Ss_safe, logSs_min, logSs_max)
if verbose > 6:
vprint(verbose, "bounds: R_H=", R_H)
vprint(verbose, "bounds: R_K=", R_K)
vprint(verbose, "bounds: R_Ss=", R_Ss)
return R_H, R_K, R_Ss
[docs]
def guard_scale_with_residual(
residual: Tensor,
scale: Tensor,
*,
floor: float,
eps: float = _EPSILON,
) -> Tensor:
r"""
Guard a residual scale using the observed residual magnitude.
This helper prevents residual normalization from exploding
when a nominal scale is too small compared with the actual
residual values observed on the current batch.
Mathematical intent
-------------------
Given a residual tensor :math:`r` and a proposed scale
:math:`s`, this function produces a guarded scale
:math:`\hat{s}` such that:
.. math::
\hat{s} \ge s_{floor}
and also:
.. math::
\hat{s} \ge r_{ref}
.. math::
\hat{s} \ge 0.1\, r_{max}
where:
.. math::
r_{ref} = \mathrm{mean}(|r|) + \varepsilon
.. math::
r_{max} = \mathrm{max}(|r|) + \varepsilon
This ensures the normalized residual:
.. math::
\tilde{r} = \frac{r}{\hat{s}}
is not arbitrarily large solely because the scale estimate
collapsed (for example, due to degenerate dt, missing
features, or transient NaN handling upstream).
Key properties
--------------
* Robust to NaN/Inf in ``residual`` and ``scale``:
non-finite entries are treated as zeros for the residual
magnitude and replaced by ``floor`` for the scale.
* Uses stop-gradient on the guarded scale, so the model
cannot reduce loss by manipulating the scale path.
* Uses both a typical magnitude (mean) and a tail proxy
(max) to avoid under-scaling when outliers occur.
Parameters
----------
residual : Tensor
Residual tensor :math:`r`. Any shape is accepted.
Values are flattened internally to compute statistics.
scale : Tensor
Proposed residual scale :math:`s`. May be scalar-like
(recommended). If non-finite, it is replaced by
``floor``.
floor : float
Minimal allowed scale :math:`s_{floor}`. This protects
against division by tiny values.
eps : float, default=_EPSILON
Small positive constant :math:`\varepsilon` added to the
residual statistics to prevent zero magnitudes.
Returns
-------
scale_guarded : Tensor
Guarded scale :math:`\hat{s}` with stop-gradient applied.
Notes
-----
Why guard against residual magnitude
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If a scale is estimated from weak signals (for example,
nearly constant series) it may become very small. Dividing
a non-trivial residual by that small scale can dominate the
training objective and destabilize optimization.
Why the 0.1 * max term
~~~~~~~~~~~~~~~~~~~~~~
The mean alone can underestimate the needed scale when
residuals have heavy tails or occasional spikes. Including a
fraction of the max provides a simple tail-sensitive guard.
Examples
--------
This function is typically used before scaling residuals:
.. code-block:: python
scale_gw = _gw_scale_core(...)
scale_gw = guard_scale_with_residual(
residual=R_gw,
scale=scale_gw,
floor=1e-8,
)
R_gw_scaled = scale_residual(R_gw, scale_gw)
See Also
--------
scale_residual
Divide a residual by a stop-gradient scale safely.
to_rms
Compute RMS magnitudes when you want RMS-based scaling.
"""
dtype = residual.dtype
eps_t = tf_constant(float(eps), dtype=dtype)
floor_t = tf_constant(float(floor), dtype=dtype)
r = tf_abs(_finite_or_zero(residual))
r = tf_reshape(r, [-1])
r_ref = tf_stop_gradient(tf_reduce_mean(r) + eps_t)
r_max = tf_stop_gradient(tf_reduce_max(r) + eps_t)
s = tf_cast(scale, dtype)
s = tf_where(tf_math.is_finite(s), s, floor_t)
# Guard: scale >= typical residual magnitude
s = tf_maximum(s, r_ref)
s = tf_maximum(s, tf_constant(0.1, dtype) * r_max)
return tf_stop_gradient(tf_maximum(s, floor_t))
[docs]
def scale_residual(
residual: Tensor,
scale: Tensor,
*,
floor: float = _EPSILON,
) -> Tensor:
r"""
Scale a residual by a (guarded) normalization factor.
This helper divides a residual tensor by a positive scale,
with strict safeguards against non-finite or tiny scales.
The scale is treated as a constant with respect to
backpropagation (stop-gradient).
Mathematical definition
-----------------------
Given a residual :math:`r` and a scale :math:`s`, the scaled
residual is:
.. math::
\tilde{r} = \frac{r}{\hat{s} + \varepsilon}
where:
.. math::
\hat{s} = \mathrm{stop\_grad}(\max(s, s_{floor}))
and any non-finite :math:`s` is replaced by :math:`s_{floor}`
before flooring.
This ensures that :math:`\tilde{r}` remains finite and that
gradients flow only through :math:`r`, not through the scale.
Parameters
----------
residual : Tensor
Residual tensor :math:`r` to normalize.
scale : Tensor
Scale tensor :math:`s` (typically scalar-like). If it is
NaN/Inf, it is replaced with ``floor``.
floor : float, default=_EPSILON
Minimal allowed scale :math:`s_{floor}`.
Returns
-------
residual_scaled : Tensor
Scaled residual :math:`\tilde{r}` with the same shape as
``residual``.
Notes
-----
Use with guard_scale_with_residual
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If ``scale`` is derived from heuristics or batch statistics,
it can occasionally be too small. In that case, call
:func:`guard_scale_with_residual` first to ensure the scale
is consistent with observed residual magnitudes.
Why stop-gradient
~~~~~~~~~~~~~~~~~
Residual scaling is a conditioning tool. Allowing gradients
to adjust the scale can create degenerate solutions where
the model inflates the scale instead of reducing the
residual.
Examples
--------
.. code-block:: python
s = _finite_or_zero(scale_est)
s = guard_scale_with_residual(R, s, floor=1e-8)
R_scaled = scale_residual(R, s, floor=1e-8)
See Also
--------
guard_scale_with_residual
Strengthen a scale using residual statistics.
compute_scales
Compute robust scales for consolidation and groundwater
residuals.
"""
s = tf_cast(scale, residual.dtype)
f = tf_constant(float(floor), residual.dtype)
# If scale is NaN/Inf -> replace with floor BEFORE max()
s = tf_where(tf_math.is_finite(s), s, f)
s = tf_maximum(s, f)
s = tf_stop_gradient(s)
return residual / (
s + tf_constant(_EPSILON, residual.dtype)
)
def _cons_scale_core(
*,
s: Tensor,
h: Tensor,
Ss: Tensor,
dt_ref_u: Tensor,
dt_ref_s: Tensor,
mode: str,
time_units: str,
tau: Tensor,
Hf: Tensor,
href: Tensor,
use_relax: bool,
floor: float,
) -> Tensor:
r"""
Compute the consolidation residual scale.
This helper builds a robust, positive scale used to
non-dimensionalize (or weight) the consolidation residual.
It is designed to be stable under noisy early training,
variable horizons, and occasional non-finite values.
The scale is a stop-gradient quantity. It is computed from
simple batch statistics so that it adapts to the magnitude
of the current batch while not injecting gradients back into
the model through the normalization path.
Mathematical intent
-------------------
Let :math:`s_{b,t}` be the settlement state (m) for batch
index :math:`b` and time index :math:`t` over a horizon
length :math:`H`.
A consolidation residual is typically expressed as either:
1) step form (meters per step)
.. math::
R_{cons}^{step}(t)
= s_{t+1} - s_t
- \hat{\Delta s}_t
2) rate form (meters per time)
.. math::
R_{cons}^{rate}(t)
= \frac{ds}{dt}
- \frac{s_{eq}(t) - s(t)}{\tau(t)}
To make such residuals comparable across batches and to
prevent any single term from dominating due to scale alone,
we normalize by a characteristic magnitude:
.. math::
\tilde{R}_{cons} = \frac{R_{cons}}{c_*}
where :math:`c_*` is the "consolidation scale" returned by
this function.
This helper computes :math:`c_*` from two sources:
A) empirical change statistics from the settlement series
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Define per-step increments:
.. math::
\Delta s_{b,t} = s_{b,t+1} - s_{b,t}
Let :math:`|\Delta s|` denote absolute increments flattened
over batch and time. Two robust summary statistics are used:
.. math::
d_{ref} = \mathrm{mean}(|\Delta s|)
.. math::
d_{max} = \mathrm{max}(|\Delta s|)
Both are treated as constants w.r.t. gradients
(:func:`tf.stop_gradient`).
Depending on ``mode``, the base scale is:
* ``mode="step"`` (meters per step)
.. math::
c_{base} =
\max(d_{ref}, 0.1\, d_{max})
* ``mode="time_unit"`` (meters per time_unit)
Let :math:`\Delta t_{ref,u}` be a representative step size
in "time_units" (the caller provides ``dt_ref_u``):
.. math::
c_{base} =
\max\left(
\frac{d_{ref}}{\Delta t_{ref,u}},
0.1\,\frac{d_{max}}{\Delta t_{ref,u}}
\right)
* otherwise (SI rate, meters per second)
Let :math:`\Delta t_{ref,s}` be a representative step size
in seconds (the caller provides ``dt_ref_s``):
.. math::
c_{base} =
\max\left(
\frac{d_{ref}}{\Delta t_{ref,s}},
0.1\,\frac{d_{max}}{\Delta t_{ref,s}}
\right)
B) optional relaxation and equilibrium misfit statistics
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When ``use_relax=True``, an additional scale component is
computed from the equilibrium settlement implied by the
fields and drawdown proxy.
A non-negative drawdown proxy is formed as:
.. math::
\Delta h = \max(h_{ref} - h, 0)
An equilibrium settlement is then:
.. math::
s_{eq} = S_s \, \Delta h \, H_f
where :math:`S_s` is specific storage (1/m) and :math:`H_f`
is a thickness (m). The equilibrium misfit magnitude is:
.. math::
e = |s_{eq} - s|
Flatten :math:`e` over batch and time and compute:
.. math::
e_{ref} = \mathrm{mean}(e) + \varepsilon
.. math::
e_{max} = \mathrm{max}(e) + \varepsilon
How this affects the scale depends on ``mode``:
* ``mode="step"``
In step units, equilibrium misfit is already in meters,
so it can directly act as a characteristic magnitude:
.. math::
c_* \leftarrow
\max(c_{base}, e_{ref}, 0.1\,e_{max})
* rate modes (meters per time)
Convert the misfit to a characteristic relaxation rate:
.. math::
r = \left|\frac{s_{eq} - s}{\tau}\right|
For SI rate mode, :math:`r` is in m/s. For ``mode="time_unit"``,
convert m/s to m/time_unit using the number of seconds per
unit :math:`\mathrm{sec}_{u}`:
.. math::
r_{u} = r\, \mathrm{sec}_{u}
Then summarize:
.. math::
r_{ref} = \mathrm{mean}(r) + \varepsilon
.. math::
r_{max} = \mathrm{max}(r) + \varepsilon
And update:
.. math::
c_* \leftarrow
\max(c_{base}, r_{ref}, 0.1\,r_{max})
Final flooring and gradient behavior
------------------------------------
A positive floor is enforced:
.. math::
c_* \leftarrow \max(c_*, c_{floor})
and the result is stop-gradient:
.. math::
c_* \leftarrow \mathrm{stop\_grad}(c_*)
This ensures the scale cannot collapse to zero and cannot
backpropagate into model parameters.
Parameters
----------
Refer to :func:`compute_scales` for the definition and
meaning of all inputs. This helper assumes all tensors are
already broadcastable to ``(B, H, 1)`` and represent SI
quantities consistent with the consolidation objective.
Returns
-------
cons_scale : Tensor
A scalar Tensor (or scalar-like Tensor) representing the
consolidation scale :math:`c_*`.
Units depend on ``mode``:
* ``"step"`` : meters per step
* ``"time_unit"``: meters per time_unit
* otherwise : meters per second
Notes
-----
Why both mean and max
~~~~~~~~~~~~~~~~~~~~~
Using :math:`\max(d_{ref}, 0.1 d_{max})` reduces sensitivity
to outliers while still reacting when the batch contains
rare but very large changes (which can otherwise produce
under-scaling and exploding normalized residuals).
Why stop-gradient
~~~~~~~~~~~~~~~~~
The scale is a diagnostic normalization factor. Letting
gradients flow through statistics such as mean/max can
create undesirable feedback loops where the model learns to
change the scale instead of reducing the residual.
Drawdown proxy
~~~~~~~~~~~~~~
The drawdown proxy here uses a hard positive-part gate
:math:`\max(h_{ref} - h, 0)`. If your pipeline uses a smooth
gate or a different sign convention, that logic should be
handled by the caller before reaching this helper.
Examples
--------
This function is not intended to be called directly.
Use :func:`compute_scales`, which computes both
consolidation and groundwater residual scales and handles
time-unit conversions and input sanitization.
See Also
--------
compute_scales
Public interface that computes residual scales.
equilibrium_compaction_si
Computes :math:`s_{eq}` given fields and drawdown logic.
dt_to_seconds
Conversion of time step sizes to SI seconds.
"""
eps = tf_constant(_EPSILON, tf_float32)
floor_t = tf_constant(float(floor), tf_float32)
# ------------------------------------------------------
# Sanitize inputs (avoid NaN/Inf in reductions).
# ------------------------------------------------------
s = _finite_or_zero(s)
h = _finite_or_zero(h)
Ss = _finite_or_zero(Ss)
dt_ref_u = finite_floor(dt_ref_u, _EPSILON)
dt_ref_s = finite_floor(dt_ref_s, _EPSILON)
# ------------------------------------------------------
# ds statistics (meters). Must be graph-safe:
# use tf_cond, not a Python `if` on tf.shape().
# ------------------------------------------------------
def _ds_stats() -> tuple[Tensor, Tensor]:
ds = s[:, 1:, :] - s[:, :-1, :]
ds = _finite_or_zero(ds)
ds_abs = tf_abs(tf_reshape(ds, [-1]))
ds_ref = tf_stop_gradient(tf_reduce_mean(ds_abs))
ds_max = tf_stop_gradient(tf_reduce_max(ds_abs))
return ds_ref, ds_max
def _ds_stats_zero() -> tuple[Tensor, Tensor]:
z = tf_constant(0.0, tf_float32)
return z, z
# Horizon length H = shape(s)[1]
H_len = tf_shape(s)[1]
has_ds = tf_greater(H_len, tf_constant(1, tf_int32))
ds_ref, ds_max = tf_cond(
has_ds, _ds_stats, _ds_stats_zero
)
# ------------------------------------------------------
# Base scale from ds (step / rate).
# ------------------------------------------------------
if mode == "step":
cons = tf_maximum(ds_ref, 0.1 * ds_max)
elif mode == "time_unit":
cons = tf_maximum(
ds_ref / dt_ref_u,
0.1 * (ds_max / dt_ref_u),
)
else:
cons = tf_maximum(
ds_ref / dt_ref_s,
0.1 * (ds_max / dt_ref_s),
)
# ------------------------------------------------------
# Optional equilibrium / relaxation term.
# ------------------------------------------------------
if use_relax:
tau = finite_floor(tau, _EPSILON)
Hf = tf_maximum(_finite_or_zero(Hf), 0.0)
href = _finite_or_zero(href)
# dh >= 0 (drawdown / head loss proxy)
dh = tf_maximum(href - h, 0.0)
# 1D equilibrium settlement (meters)
s_eq = Ss * dh * Hf
# Misfit to equilibrium (meters)
eq_mis = tf_abs(_finite_or_zero(s_eq - s))
eq_vec = tf_reshape(eq_mis, [-1])
eq_ref = tf_stop_gradient(
tf_reduce_mean(eq_vec) + eps
)
eq_max = tf_stop_gradient(tf_reduce_max(eq_vec) + eps)
if mode == "step":
# In step mode, keep as meters/step.
cons = tf_maximum(cons, eq_ref)
cons = tf_maximum(cons, 0.1 * eq_max)
else:
# Relaxation rate: meters/second.
relax = tf_abs(eq_mis / (tau + eps))
if mode == "time_unit":
# Convert to meters/time_unit.
sec_u = seconds_per_time_unit(
time_units,
dtype=tf_float32,
)
relax = relax * sec_u
relax = _finite_or_zero(relax)
r_vec = tf_reshape(relax, [-1])
r_ref = tf_stop_gradient(
tf_reduce_mean(r_vec) + eps
)
r_max = tf_stop_gradient(
tf_reduce_max(r_vec) + eps
)
cons = tf_maximum(cons, r_ref)
cons = tf_maximum(cons, 0.1 * r_max)
# ------------------------------------------------------
# Final floor and stop-gradient.
# ------------------------------------------------------
cons = tf_maximum(cons, floor_t)
return tf_stop_gradient(cons)
def _gw_scale_core(
*,
h: Tensor,
Ss: Tensor,
dt_ref_s: Tensor,
time_units: str,
gw_units: str,
dh_dt: TensorLike | None,
div_K_grad_h: TensorLike | None,
Q: TensorLike | None,
floor: float,
) -> Tensor:
r"""
Compute the groundwater-flow residual scale.
This helper builds a robust, positive scale used to
non-dimensionalize (or weight) the groundwater PDE residual.
It is intended to stabilize training by ensuring the PDE
term has a comparable magnitude across batches, horizons,
and unit conventions.
The scale is computed from batch statistics and returned as
a stop-gradient value, so it does not backpropagate through
the normalization path.
Mathematical intent
-------------------
A common 2D groundwater-flow residual in specific-storage
form is:
.. math::
R_{gw}
= S_s \,\frac{\partial h}{\partial t}
- \nabla \cdot (K \nabla h)
- Q
where:
* :math:`h` is hydraulic head (m),
* :math:`S_s` is specific storage (1/m),
* :math:`K` is hydraulic conductivity (m/s),
* :math:`Q` is a volumetric forcing per storage thickness,
expressed here in compatible residual units.
In SI, each term has units of 1/s:
* storage term:
:math:`S_s \, \partial_t h` has
:math:`(1/m) (m/s) = 1/s`
* divergence term:
:math:`\nabla \cdot (K \nabla h)` has 1/s
under the standard Darcy form and consistent spatial
scaling handled upstream
* forcing term:
:math:`Q` is assumed already mapped to 1/s
The normalized residual is:
.. math::
\tilde{R}_{gw} = \frac{R_{gw}}{g_*}
where :math:`g_*` is the groundwater scale returned by this
function.
Scale construction
------------------
This helper uses three optional contributors:
A) storage-term magnitude
~~~~~~~~~~~~~~~~~~~~~~~~
A representative head-rate magnitude is estimated from either
a provided :math:`\partial_t h` (``dh_dt``) or from finite
differences across the horizon.
If ``dh_dt`` is not provided, define step head differences:
.. math::
\Delta h_{b,t} = h_{b,t+1} - h_{b,t}
Let :math:`\Delta t_{ref,s}` be a representative step size
in seconds (caller-provided ``dt_ref_s``). Define:
.. math::
\dot{h}_{ref} =
\frac{\mathrm{mean}(|\Delta h|)}{\Delta t_{ref,s}}
.. math::
\dot{h}_{max} =
\frac{\mathrm{max}(|\Delta h|)}{\Delta t_{ref,s}}
If ``dh_dt`` is provided, its mean and max absolute values
are used directly as :math:`\dot{h}_{ref}` and
:math:`\dot{h}_{max}`.
A representative storage coefficient magnitude is computed as:
.. math::
S_{s,ref} = \mathrm{mean}(|S_s|)
The storage-term reference scales are:
.. math::
s_{ref} = S_{s,ref}\,\dot{h}_{ref}
.. math::
s_{max} = S_{s,ref}\,\dot{h}_{max}
and the initial scale is:
.. math::
g_* \leftarrow \max(s_{ref}, 0.1\,s_{max})
B) divergence-term magnitude (optional)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If the divergence contribution
:math:`d = \nabla \cdot (K \nabla h)` is provided upstream as
``div_K_grad_h``, its batch statistics contribute:
.. math::
d_{ref} = \mathrm{mean}(|d|) + \varepsilon
.. math::
d_{max} = \mathrm{max}(|d|) + \varepsilon
and the scale is updated:
.. math::
g_* \leftarrow \max(g_*, d_{ref}, 0.1\,d_{max})
C) forcing-term magnitude (optional)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If forcing ``Q`` is provided (assumed in compatible units),
its batch statistics contribute:
.. math::
q_{ref} = \mathrm{mean}(|Q|) + \varepsilon
.. math::
q_{max} = \mathrm{max}(|Q|) + \varepsilon
and the scale is updated:
.. math::
g_* \leftarrow \max(g_*, q_{ref}, 0.1\,q_{max})
Flooring, units, and gradients
------------------------------
A positive floor is enforced:
.. math::
g_* \leftarrow \max(g_*, g_{floor})
The result is returned with stop-gradient:
.. math::
g_* \leftarrow \mathrm{stop\_grad}(g_*)
By default, :math:`g_*` is in SI (1/s). If ``gw_units`` is
``"time_unit"``, the scale is converted to 1/time_unit using
the seconds-per-unit constant :math:`\mathrm{sec}_u`:
.. math::
g_*^{(u)} = g_* \,\mathrm{sec}_u
so that dividing a residual expressed in 1/time_unit by
:math:`g_*^{(u)}` remains consistent.
Parameters
----------
Refer to :func:`compute_scales` for the meaning and expected
units of all inputs. This helper assumes inputs are already
broadcastable to ``(B, H, 1)`` and consistent with the PDE
assembly used upstream.
Returns
-------
gw_scale : Tensor
A scalar Tensor (or scalar-like Tensor) representing the
groundwater scale :math:`g_*`.
Units:
* default : 1/s
* gw_units="time_unit" : 1/time_unit
Notes
-----
Why mean and max
~~~~~~~~~~~~~~~~
Using :math:`\max(x_{ref}, 0.1 x_{max})` provides a robust
scale that tracks typical magnitudes while remaining
sensitive to rare but large values that can otherwise cause
under-scaling and unstable normalized residuals.
Why stop-gradient
~~~~~~~~~~~~~~~~~
The scale is a normalization constant, not a learnable
quantity. Allowing gradients through batch statistics can
create feedback loops where the model changes the scale
instead of reducing the residual.
What "compatible units" means for div and Q
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This helper assumes ``div_K_grad_h`` and ``Q`` are already
expressed in the same residual units as the storage term,
i.e. compatible with 1/s in SI. Any spatial-coordinate
normalization and chain-rule rescaling should be handled
upstream.
Examples
--------
This function is not intended to be called directly.
Use :func:`compute_scales`, which computes both residual
scales and manages time-unit conversions and sanitization.
See Also
--------
compute_scales
Public interface that computes both cons and gw scales.
seconds_per_time_unit
Converts a time-unit string into seconds per unit.
dt_to_seconds
Converts per-step dt into SI seconds.
"""
eps = tf_constant(_EPSILON, tf_float32)
floor_t = tf_constant(float(floor), tf_float32)
# ------------------------------------------------------
# Sanitize inputs (avoid NaN/Inf in reductions).
# ------------------------------------------------------
h = _finite_or_zero(h)
Ss = _finite_or_zero(Ss)
dt_ref_s = finite_floor(dt_ref_s, _EPSILON)
# ------------------------------------------------------
# dh/dt reference (SI: m/s).
# If dh_dt is provided, use it directly.
# Otherwise estimate from consecutive steps in h.
# Must be graph-safe: use tf_cond for shape checks.
# ------------------------------------------------------
def _dh_dt_from_h() -> tuple[Tensor, Tensor]:
dh = h[:, 1:, :] - h[:, :-1, :]
dh = _finite_or_zero(dh)
dh_abs = tf_abs(tf_reshape(dh, [-1]))
dh_ref = tf_stop_gradient(tf_reduce_mean(dh_abs))
dh_max = tf_stop_gradient(tf_reduce_max(dh_abs))
dh_dt_ref = dh_ref / dt_ref_s
dh_dt_max = dh_max / dt_ref_s
return dh_dt_ref, dh_dt_max
def _dh_dt_zero() -> tuple[Tensor, Tensor]:
z = tf_constant(0.0, tf_float32)
return z, z
if dh_dt is None:
H_len = tf_shape(h)[1]
has_dh = tf_greater(H_len, tf_constant(1, tf_int32))
dh_dt_ref, dh_dt_max = tf_cond(
has_dh,
_dh_dt_from_h,
_dh_dt_zero,
)
else:
d = _finite_or_zero(dh_dt)
d = tf_abs(tf_reshape(d, [-1]))
dh_dt_ref = tf_stop_gradient(tf_reduce_mean(d))
dh_dt_max = tf_stop_gradient(tf_reduce_max(d))
# ------------------------------------------------------
# Ss reference (1/m).
# ------------------------------------------------------
Ss_abs = tf_abs(tf_reshape(_finite_or_zero(Ss), [-1]))
Ss_ref = tf_stop_gradient(tf_reduce_mean(Ss_abs))
# Storage term scale (1/s).
storage_ref = Ss_ref * dh_dt_ref
storage_max = Ss_ref * dh_dt_max
gw = tf_maximum(storage_ref, 0.1 * storage_max)
# ------------------------------------------------------
# Optional div term (already in compatible units).
# ------------------------------------------------------gim
if div_K_grad_h is not None:
divv = _finite_or_zero(div_K_grad_h)
divv = tf_abs(tf_reshape(divv, [-1]))
div_ref = tf_stop_gradient(tf_reduce_mean(divv) + eps)
div_max = tf_stop_gradient(tf_reduce_max(divv) + eps)
gw = tf_maximum(gw, div_ref)
gw = tf_maximum(gw, 0.1 * div_max)
# ------------------------------------------------------
# Optional forcing term Q (already in compatible units).
# ------------------------------------------------------
if Q is not None:
QQ = _finite_or_zero(Q)
QQ = tf_abs(tf_reshape(QQ, [-1]))
Q_ref = tf_stop_gradient(tf_reduce_mean(QQ) + eps)
Q_max = tf_stop_gradient(tf_reduce_max(QQ) + eps)
gw = tf_maximum(gw, Q_ref)
gw = tf_maximum(gw, 0.1 * Q_max)
# Floor for numerical stability.
gw = tf_maximum(gw, floor_t)
# ------------------------------------------------------
# Optional "per time_unit" conversion (non-SI).
# ------------------------------------------------------
if gw_units == "time_unit":
sec_u = seconds_per_time_unit(
time_units,
dtype=tf_float32,
)
gw = gw * sec_u
return tf_stop_gradient(gw)
[docs]
def compute_scales(
model,
*,
t: Tensor,
s_mean: Tensor,
h_mean: Tensor,
K_field: Tensor,
Ss_field: Tensor,
tau_field: TensorLike | None = None,
H_field: TensorLike | None = None,
h_ref_si: TensorLike | None = None,
Q: TensorLike | None = None,
dt: TensorLike | None = None,
time_units: str | None = None,
dh_dt: TensorLike | None = None,
div_K_grad_h: TensorLike | None = None,
verbose: int = 0,
) -> dict[str, Tensor]:
r"""
Compute robust normalization scales for physics residuals.
This function estimates per-batch (or per-sample) scale factors
used to non-dimensionalize physics residuals before squaring and
averaging. The goal is to make losses comparable across sites,
time spans, and coordinate encodings, and to prevent a single
residual from dominating due to unit magnitude alone.
The returned scales are typically used as:
.. math::
R_{cons}^{*} = \frac{R_{cons}}{s_{cons}}, \qquad
R_{gw}^{*} = \frac{R_{gw}}{s_{gw}},
where :math:`s_{cons}` and :math:`s_{gw}` are produced by this
function (with floors applied for numerical safety).
The routine is intentionally defensive. It sanitizes shapes to
``(B, H, 1)``, guards non-finite values, enforces positive dt,
and applies safe floors before any division or reduction.
Parameters
----------
model : Any
Model-like object holding configuration in
``model.scaling_kwargs`` and optionally ``model.time_units``
and ``model.h_ref``. This function reads:
* consolidation display mode from ``resolve_cons_units(sk)``
* groundwater display mode from ``resolve_gw_units(sk)``
* coordinate normalization flags via ``sk['coords_normalized']``
* coordinate ranges via ``coord_ranges(sk)``
* auto floors via ``resolve_auto_scale_floor(kind, sk)``
t : Tensor
Time coordinate tensor. Expected shape is ``(B, H, 1)`` or
``(B, H)``. Units follow the dataset time encoding. If
``coords_normalized=True``, ``t`` is assumed normalized and
is de-normalized using ``coord_ranges(sk)['t']`` when dt is
inferred internally.
s_mean : Tensor
Mean settlement state used for consolidation scaling.
Expected shape is ``(B, H, 1)`` or ``(B, H)``.
h_mean : Tensor
Mean head state used for scaling. Expected shape is
``(B, H, 1)`` or ``(B, H)``. Units should match the model
internal convention (typically SI meters).
K_field : Tensor
Effective conductivity field. Present for signature
compatibility and potential future scale heuristics. Current
logic does not require this argument directly.
Ss_field : Tensor
Effective specific storage field :math:`S_s`. Used by both
consolidation and groundwater scale heuristics. Expected
shape is broadcastable to ``(B, H, 1)``.
tau_field : Tensor, optional
Consolidation timescale :math:`tau` in seconds. Provide this
together with ``H_field`` to enable relaxation-aware
consolidation scaling.
H_field : Tensor, optional
Drained thickness :math:`H` in meters. Used with
``tau_field`` for relaxation-aware consolidation scaling.
h_ref_si : Tensor, optional
Reference head :math:`h_{ref}` in meters. If not provided,
the function falls back to ``model.h_ref`` (or 0.0). The
value is broadcast to ``(B, H, 1)`` and sanitized.
Q : Tensor, optional
Source term used in the groundwater residual scaling.
Expected shape is broadcastable to ``(B, H, 1)``.
dt : Tensor, optional
Time step tensor in the dataset time units. If provided,
it is used directly (after shape normalization). If None,
dt is inferred from ``t``. The inferred dt is de-normalized
when ``coords_normalized=True``.
time_units : str, optional
Name of the dataset time unit (e.g., "year", "day",
"second"). If None, the function resolves it from
``sk['time_units']`` or ``model.time_units``. It is used to
convert dt to seconds.
dh_dt : Tensor, optional
Precomputed :math:`dh/dt` in SI units (m/s). If provided,
groundwater scaling can use it directly rather than
reconstructing a representative magnitude.
div_K_grad_h : Tensor, optional
Precomputed divergence term for groundwater flow,
:math:`\nabla \cdot (K \nabla h)`, in SI units. If provided,
it is used as a representative magnitude for groundwater
scaling.
verbose : int, default=0
Verbosity level. If > 0, basic statistics of computed scales
may be printed.
Returns
-------
scales : dict[str, Tensor]
Dictionary with keys:
* ``'cons_scale'`` : Tensor
Scale for consolidation residuals.
* ``'gw_scale'`` : Tensor
Scale for groundwater-flow residuals.
Each value is shaped as ``(B, 1, 1)`` or broadcastable to
``(B, H, 1)``, depending on internal heuristics.
Notes
-----
**Why scaling is needed.**
Consolidation and groundwater residuals can differ by many
orders of magnitude depending on:
* the dataset time unit (years vs seconds),
* coordinate normalization spans (t, x, y),
* site geometry and hydro-mechanical priors,
* whether residuals are reported in SI or display units.
A stable scaling strategy prevents trivial unit choices from
changing optimization dynamics.
**dt construction and safety.**
If ``dt`` is not provided, dt is inferred as consecutive
differences along horizon:
* if :math:`H > 1`, :math:`dt_h = t_{h} - t_{h-1}`
* else, dt defaults to 1.0 (in dataset time units)
When ``coords_normalized=True``, dt is multiplied by the raw
time span ``t_range`` from ``coord_ranges(sk)`` to recover dt in
dataset time units. dt is then converted to seconds via
``dt_to_seconds(dt, time_units=...)``.
All dt paths apply:
* absolute value
* finite sanitization
* a positive floor
* a final lower bound using ``seconds_per_time_unit(time_units)``
This guards against degenerate dt values that would explode
scales.
**Relaxation-aware consolidation scaling.**
If both ``tau_field`` and ``H_field`` are provided, consolidation
scales may incorporate a relaxation time scale to better match
the form of the consolidation closure used by the model. If they
are not provided, a simpler heuristic is used.
**Groundwater scaling inputs.**
Groundwater scales are computed from representative magnitudes
of the groundwater PDE components, optionally using
``dh_dt`` and ``div_K_grad_h`` when provided. The scaling also
accounts for display unit policies returned by
``resolve_gw_units(sk)``.
**This function is not traced.**
This wrapper is not decorated with ``tf.function`` because it
accepts a Python ``model`` object. Callers may wrap the function
at a higher level if a stable tracing boundary is desired.
Examples
--------
Compute scales inside the physics path:
>>> scales = compute_scales(
... model,
... t=t,
... s_mean=s_inc_pred,
... h_mean=h_si,
... K_field=K_field,
... Ss_field=Ss_field,
... tau_field=tau_field,
... H_field=H_si,
... h_ref_si=h_ref_11,
... Q=Q_si,
... dt=dt_units,
... time_units=model.time_units,
... dh_dt=dh_dt,
... div_K_grad_h=dKdhx + dKdhy,
... )
Use the returned scales to normalize residuals:
>>> cons_scaled = R_cons / scales["cons_scale"]
>>> gw_scaled = R_gw / scales["gw_scale"]
See Also
--------
scale_residual
Applies a scale and floor to a residual tensor.
resolve_auto_scale_floor
Resolves "auto" floors for scale denominators.
ensure_si_derivative_frame
Converts raw autodiff derivatives to SI-consistent forms.
"""
sk = getattr(model, "scaling_kwargs", None) or {}
mode = resolve_cons_units(sk)
gw_units = resolve_gw_units(sk)
# --- Normalize ranks to (B,H,1).
s = tf_cast(s_mean, tf_float32)
h = tf_cast(h_mean, tf_float32)
if s.shape.rank == 2:
s = s[:, :, None]
if h.shape.rank == 2:
h = h[:, :, None]
# --- Time units (consistent source of truth).
if time_units is None:
time_units = (
get_sk(sk, "time_units", default=None)
or getattr(model, "time_units", None)
or "unitless"
)
def _diffs():
return tt[:, 1:, :] - tt[:, :-1, :]
def _ones():
return tf_zeros_like(s[:, :1, :]) + 1.0
# --- Build dt in *time_units*.
if dt is None:
tt = tf_cast(t, tf_float32)
if tt.shape.rank == 2:
tt = tt[:, :, None]
H = tf_shape(tt)[1]
# if (tt.shape.rank >= 2) and (tt.shape[1] > 1):
# dt_step = tt[:, 1:, :] - tt[:, :-1, :]
# else:
# dt_step = tf_zeros_like(s[:, :1, :]) + 1.0
dt_step = tf_cond(tf_greater(H, 1), _diffs, _ones)
# De-normalize if coords were normalized.
coords_norm = bool(sk.get("coords_normalized", False))
tR, _, _ = coord_ranges(sk)
if coords_norm and tR:
dt_step = dt_step * tf_cast(float(tR), tf_float32)
else:
dt_step = tf_cast(dt, tf_float32)
if dt_step.shape.rank == 2:
dt_step = dt_step[:, :, None]
# Sanitize dt before any conversion/reduction.
dt_step = tf_abs(_finite_or_zero(dt_step))
dt_step = finite_floor(dt_step, _EPSILON)
dt_sec = dt_to_seconds(dt_step, time_units=time_units)
dt_sec = tf_abs(_finite_or_zero(dt_sec))
dt_sec = finite_floor(dt_sec, _EPSILON)
# Scalar dt refs.
dt_ref_u = tf_reduce_mean(tf_reshape(dt_step, [-1]))
dt_ref_u = finite_floor(dt_ref_u, _EPSILON)
dt_ref_s = tf_reduce_mean(tf_reshape(dt_sec, [-1]))
dt_ref_s = finite_floor(dt_ref_s, _EPSILON)
# Prefer a sane SI lower bound when dt is broken.
sec_u = seconds_per_time_unit(
time_units,
dtype=tf_float32,
)
dt_ref_s = tf_maximum(dt_ref_s, sec_u)
# --- h_ref broadcast (finite).
if h_ref_si is None:
h_ref_si = tf_cast(
getattr(model, "h_ref", 0.0), tf_float32
)
href = tf_convert_to_tensor(h_ref_si, tf_float32)
href = tf_broadcast_to(href, tf_shape(h))
href = _finite_or_zero(href)
# --- Floors.
# cons_floor_def = _EPSILON
# if mode in ("step", "time_unit"):
# cons_floor_def = 1e-6
cons_floor = resolve_auto_scale_floor("cons", sk)
gw_floor = resolve_auto_scale_floor("gw", sk)
# cons_floor = float(
# get_sk(sk, "cons_scale_floor", default=cons_floor_def)
# )
# gw_floor = float(
# get_sk(sk, "gw_scale_floor", default=_EPSILON)
# )
# --- Optional tau/H (shape-safe).
use_relax = (tau_field is not None) and (
H_field is not None
)
if use_relax:
tau = tf_cast(tau_field, tf_float32)
Hf = tf_cast(H_field, tf_float32)
if tau.shape.rank == 2:
tau = tau[:, :, None]
if Hf.shape.rank == 2:
Hf = Hf[:, :, None]
tau = tf_broadcast_to(tau, tf_shape(h))
Hf = tf_broadcast_to(Hf, tf_shape(h))
else:
tau = tf_ones_like(h)
Hf = tf_zeros_like(h)
# --- Sanitize Ss once, then reuse.
Ss = tf_cast(Ss_field, tf_float32)
if Ss.shape.rank == 2:
Ss = Ss[:, :, None]
Ss = tf_broadcast_to(Ss, tf_shape(h))
Ss = _finite_or_zero(Ss)
cons_scale = _cons_scale_core(
s=s,
h=h,
Ss=Ss,
dt_ref_u=dt_ref_u,
dt_ref_s=dt_ref_s,
mode=mode,
time_units=time_units,
tau=tau,
Hf=Hf,
href=href,
use_relax=use_relax,
floor=cons_floor,
)
gw_scale = _gw_scale_core(
h=h,
Ss=Ss,
dt_ref_s=dt_ref_s,
time_units=time_units,
gw_units=gw_units,
dh_dt=dh_dt,
div_K_grad_h=div_K_grad_h,
Q=Q,
floor=gw_floor,
)
if verbose > 0:
_stats("cons_scale", cons_scale)
_stats("gw_scale", gw_scale)
return {"cons_scale": cons_scale, "gw_scale": gw_scale}
[docs]
def resolve_auto_scale_floor(
key: str,
scaling_kwargs: dict[str, Any] | None,
default_val: float | str = "auto",
) -> float:
"""
Robustly determine a numerical stability floor for physics scales.
If the user provides a float in scaling_kwargs, it is respected.
If 'auto', we derive a safe floor based on float32 stability limits
converted to the active unit system (SI, time_units, or steps).
Baselines (SI):
- cons (velocity): 1e-7 m/s (~3 m/yr)
High floor because velocity residuals are often noise-dominated.
- gw (rate): 1e-9 1/s (~0.03 /yr)
Lower floor to capture subtler groundwater dynamics.
"""
sk = scaling_kwargs or {}
# 1. Check user override in config (e.g., "cons_scale_floor": 1e-12)
# We strip "auto" if it appears as a string literal.
val = get_sk(
sk, f"{key}_scale_floor", default=default_val
)
if isinstance(val, float | int) and not isinstance(
val, bool
):
return float(val)
if str(val).lower() != "auto":
try:
return float(val)
except (ValueError, TypeError):
pass # Fallthrough to auto logic
# 2. "Auto" Logic: Derive based on Units
time_units = get_sk(sk, "time_units", default="year")
# Calculate conversion factor: 1 "time_unit" = X seconds
try:
sec_per_unit = float(
seconds_per_time_unit(time_units)
)
except Exception:
sec_per_unit = (
31556952.0 # Default to year if unknown
)
# Define Safe SI Baselines (float32 stability thresholds)
# m/s for cons, 1/s for gw
SI_BASE_CONS = 1e-7
SI_BASE_GW = 1e-9
if key == "cons":
# Target units: "second", "time_unit", or "step"
resid_units = str(
get_sk(
sk, "cons_residual_units", default="second"
)
).lower()
if "second" in resid_units:
return SI_BASE_CONS
elif "time" in resid_units:
# Convert m/s -> m/year (or m/month, etc)
# floor = (m/s) * (s/unit) = m/unit
return SI_BASE_CONS * sec_per_unit
else:
# "step": treat roughly like SI (conservative)
return SI_BASE_CONS
elif key == "gw":
# Target units: "second" or "time_unit"
resid_units = str(
get_sk(
sk, "gw_residual_units", default="time_unit"
)
).lower()
if "second" in resid_units:
return SI_BASE_GW
elif "time" in resid_units:
# Convert 1/s -> 1/year
# floor = (1/s) * (s/unit) = 1/unit
return SI_BASE_GW * sec_per_unit
# Fallback safe default
return 1e-7
[docs]
def resolve_gw_units(sk):
v = get_sk(sk, "gw_residual_units", default="time_unit")
v = str(v).strip().lower()
if v in ("sec", "second", "seconds", "s"):
return "second"
return "time_unit"
[docs]
def resolve_cons_units(
sk: dict[str, Any] | None,
) -> str:
"""Normalize consolidation residual units."""
if not sk:
return "second"
v = get_sk(sk, "cons_residual_units", default="second")
mode = str(v).strip().lower()
if mode in ("s", "sec", "secs", "seconds"):
mode = "second"
elif mode in ("tu", "time", "timeunit", "time_units"):
mode = "time_unit"
elif mode in ("step", "index", "unitless"):
mode = "step"
if mode not in ("step", "time_unit", "second"):
mode = "second"
return mode
# ---------------------------------------------------------------------
# Settlement-kind adaptation
# ---------------------------------------------------------------------
[docs]
def settlement_state_for_pde(
s_pred_si: Tensor,
t: Tensor,
*,
scaling_kwargs: dict[str, Any] | None = None,
inputs: dict[str, Tensor] | None = None,
time_units: str | None = None,
baseline_keys: Sequence[str] = (
"s0_si",
"subs0_si",
"s_ref_si",
"subs_ref_si",
),
dt: TensorLike | None = None,
return_incremental: bool = True,
verbose: int = 0,
) -> Tensor:
r"""
Map predicted settlement output into a PDE-ready settlement state.
This helper converts a model settlement output ``s_pred_si`` into
a consistent settlement time series in SI meters that can be used
as the state variable in the consolidation residual and related
physics terms.
The model can represent settlement in different output modes
controlled by ``scaling_kwargs['subsidence_kind']``:
* ``"cumulative"`` : ``s_pred_si`` already represents cumulative
settlement :math:`s(t)` (meters).
* ``"increment"`` : ``s_pred_si`` represents per-step increments
:math:`\Delta s_h` (meters per step).
* ``"rate"`` : ``s_pred_si`` represents a settlement rate
:math:`ds/dt` (meters per time unit).
The function first constructs a cumulative series :math:`s(t)`
and then optionally returns the incremental state
:math:`s_{inc}(t)` used by the ODE/PDE residuals.
Parameters
----------
s_pred_si : Tensor
Predicted settlement output in SI meters (or SI meters per
time unit when ``subsidence_kind="rate"``). Expected shapes:
* ``(B, H, 1)`` (preferred)
* ``(B, H)`` will be promoted to ``(B, H, 1)``
t : Tensor
Time coordinate used to infer :math:`\Delta t` when
``subsidence_kind="rate"`` and ``dt`` is not provided.
Expected shape is ``(B, H, 1)`` or ``(B, H)``.
scaling_kwargs : dict, optional
Scaling and configuration dictionary. This function reads
``subsidence_kind`` via:
``get_sk(sk, 'subsidence_kind', default='cumulative')``
If not provided, defaults to ``{}``.
inputs : dict[str, Tensor], optional
Optional batch inputs that may contain a baseline settlement
value :math:`s_0` (SI meters). If provided, the function
searches for the first available tensor among ``baseline_keys``
and uses it as :math:`s_0`.
time_units : str, optional
Name of the dataset time unit (e.g., "year", "day"). This
argument is informational here and is logged for diagnostics.
When ``subsidence_kind="rate"``, the interpretation of
``s_pred_si`` is "meters per time unit".
baseline_keys : Sequence[str], default=("s0_si", "subs0_si",
"s_ref_si", "subs_ref_si")
Candidate keys to locate a baseline settlement tensor
:math:`s_0` in ``inputs``. The first match found is used.
dt : Tensor, optional
Time step per horizon in dataset time units. Used only when
``subsidence_kind="rate"``. Expected shape is ``(B, H, 1)``
or ``(B, H)``. If None, dt is inferred from ``t`` by finite
differences, with a fallback of 1.0 for the first step.
return_incremental : bool, default=True
If True, return the incremental settlement state:
.. math::
s_{inc}(t_h) = s(t_h) - s_0,
shaped like ``(B, H, 1)``. If False, return the cumulative
settlement series :math:`s(t_h)`.
verbose : int, default=0
Verbosity level. When > 0, prints basic diagnostics of the
selected mode and intermediate tensors.
Returns
-------
s_state : Tensor
Settlement state in SI meters with shape ``(B, H, 1)``.
If ``return_incremental=True`` the output is
:math:`s_{inc}(t)` (incremental since :math:`s_0`). Otherwise
the output is the cumulative series :math:`s(t)`.
Notes
-----
**Baseline handling.**
The baseline :math:`s_0` is interpreted as the settlement value
at the reference time :math:`t_0` used by the physics residuals.
If no baseline is provided, :math:`s_0` defaults to zero with
shape ``(B, 1, 1)`` and is broadcast over the horizon.
**Cumulative construction.**
The function builds a cumulative settlement series :math:`s(t)`
according to ``subsidence_kind``:
1) ``subsidence_kind="cumulative"``
``s_pred_si`` is assumed to already represent :math:`s(t)`:
.. math::
s(t_h) = s_{pred}(t_h).
This includes cases where the caller already added a baseline,
e.g., :math:`s(t) = s_0 + s_{inc}(t)`.
2) ``subsidence_kind="increment"``
``s_pred_si`` is interpreted as per-step increments:
.. math::
s(t_h) = s_0 + \sum_{j=0}^{h} \Delta s_j.
3) ``subsidence_kind="rate"``
``s_pred_si`` is interpreted as a rate in meters per time unit:
.. math::
\Delta s_h = \left(\frac{ds}{dt}\right)_h \Delta t_h,
\qquad
s(t_h) = s_0 + \sum_{j=0}^{h} \Delta s_j.
If ``dt`` is not provided, :math:`\Delta t_h` is inferred from
the time coordinate tensor ``t`` using finite differences. The
first step uses a fallback of 1.0 (for backward compatibility).
**Incremental state for PDE/ODE residuals.**
Many physics residuals are written for an incremental settlement
state :math:`s_{inc}(t)` that starts at zero at :math:`t_0`. When
``return_incremental=True`` the function returns:
.. math::
s_{inc}(t_h) = s(t_h) - s_0.
This makes it safe to concatenate an explicit initial state
(e.g., ``s0_inc=0``) when constructing a state sequence for an
exact-step consolidation integrator.
Examples
--------
Convert per-step increments to an incremental PDE state:
>>> sk = {"subsidence_kind": "increment"}
>>> s_inc = settlement_state_for_pde(
... s_pred_si=ds_pred_m,
... t=coords_t,
... scaling_kwargs=sk,
... inputs={"s0_si": s0_m},
... return_incremental=True,
... )
Convert a rate output using provided dt:
>>> sk = {"subsidence_kind": "rate"}
>>> s_inc = settlement_state_for_pde(
... s_pred_si=dsdt_pred_m_per_u,
... t=coords_t,
... dt=dt_units,
... scaling_kwargs=sk,
... inputs={"s0_si": s0_m},
... return_incremental=True,
... )
Return the cumulative series instead:
>>> s_cum = settlement_state_for_pde(
... s_pred_si=s_pred_m,
... t=coords_t,
... scaling_kwargs={"subsidence_kind": "cumulative"},
... return_incremental=False,
... )
See Also
--------
compute_consolidation_step_residual
Builds the consolidation residual from settlement and head
states.
cons_step_to_cons_residual
Converts a step residual into a residual consistent with the
PDE time convention.
integrate_consolidation_mean
Integrates a consolidation mean settlement trajectory used as
a physics-driven prediction path.
"""
sk = scaling_kwargs or {}
kind = (
str(
get_sk(
sk, "subsidence_kind", default="cumulative"
)
)
.strip()
.lower()
)
s = tf_cast(s_pred_si, tf_float32)
if (
getattr(s, "shape", None) is not None
and s.shape.rank == 2
):
s = s[:, :, None]
# --- baseline s0 (SI meters) ---
s0 = None
if inputs is not None:
for k in baseline_keys:
if (k in inputs) and (inputs[k] is not None):
s0 = tf_cast(inputs[k], tf_float32)
r = tf_rank(s0)
s0 = tf_cond(
tf_equal(r, 1),
lambda: s0[:, None, None],
lambda: tf_cond(
tf_equal(r, 2),
lambda: s0[:, :, None],
lambda: s0,
),
)
break
if s0 is None:
s0 = tf_zeros_like(s[:, :1, :])
vprint(verbose, "settlement_kind=", kind)
vprint(verbose, "s_pred_si=", s)
vprint(verbose, "s0=", s0)
vprint(verbose, "time_units=", time_units)
# -------------------------------------------------------------
# Build cumulative series s_cum(t) first (same shape as s)
# -------------------------------------------------------------
if kind == "cumulative":
s_cum = s # may include baseline (as in call(): s0_cum + s_inc)
elif kind == "increment":
# s is Îs per step
s_cum = s0 + tf_cumsum(s, axis=1)
elif kind == "rate":
# s is ds/dt (meters / time_unit)
if dt is not None:
dtt = tf_cast(dt, tf_float32)
if (
getattr(dtt, "shape", None) is not None
and dtt.shape.rank == 2
):
dtt = dtt[:, :, None]
ds = s * dtt
else:
tt = tf_cast(t, tf_float32)
if (
getattr(tt, "shape", None) is not None
and tt.shape.rank == 2
):
tt = tt[:, :, None]
dtn = tt[:, 1:, :] - tt[:, :-1, :]
# fallback default for first step (kept for backward compat)
dt0 = tf_zeros_like(tt[:, :1, :]) + 1.0
ds = s * tf_concat([dt0, dtn], axis=1)
s_cum = s0 + tf_cumsum(ds, axis=1)
vprint(verbose, "t=", tt)
vprint(verbose, "ds=", ds)
else:
raise ValueError(
f"Unsupported subsidence_kind={kind!r}. "
"Use one of {'cumulative','increment','rate'}."
)
# -------------------------------------------------------------
# Return incremental ODE state if requested: s_inc(t)=s_cum(t)-s0
# -------------------------------------------------------------
if return_incremental:
s0H = s0 + tf_zeros_like(
s_cum
) # broadcast to (B,H,1)
return s_cum - s0H
return s_cum
[docs]
def to_rms(
x: Tensor,
*,
axis: AxisLike = None,
keepdims: bool = False,
eps: float | None = None,
ms_floor: float | None = None,
rms_floor: float | None = None,
nan_policy: str = "propagate",
dtype: Any = None,
) -> Tensor:
r"""
Compute the root-mean-square (RMS) of a tensor.
This utility computes:
.. math::
\mathrm{RMS}(x)
= \sqrt{\mathbb{E}[x^2]}
over the requested reduction axes. It is designed for robust
diagnostics in physics-informed training loops, where tensors may
contain extremely small values (needing ``float64``) or occasional
non-finite entries (handled via ``nan_policy``).
Parameters
----------
x : Tensor
Input tensor. Any shape is accepted. The computation is
performed in ``dtype`` (default float32).
axis : int or Sequence[int] or None, default=None
Axis or axes to reduce.
* If None, reduce over all dimensions and return a scalar.
* If an int or sequence, reduce only those axes.
keepdims : bool, default=False
If True, keep reduced dimensions with length 1.
eps : float or None, default=None
Optional lower bound applied to the mean-square value before
the square root is taken. If provided and > 0, the mean-square
is floored as:
.. math::
\mathrm{MS} = \max(\mathrm{MS}, \mathrm{eps})
where :math:`\mathrm{MS} = \mathbb{E}[x^2]`.
ms_floor : float or None, default=None
Alias for an additional mean-square floor applied after
``eps``. If provided and > 0, it is applied as:
.. math::
\mathrm{MS} = \max(\mathrm{MS}, \mathrm{ms\_floor})
This can be useful when you want a hard numerical floor but
prefer to keep ``eps`` reserved for "epsilon-like" smoothing.
rms_floor : float or None, default=None
Optional lower bound applied after taking the square root.
If provided and > 0:
.. math::
\mathrm{RMS} = \max(\mathrm{RMS}, \mathrm{rms\_floor})
nan_policy : {"propagate", "raise", "omit"}, default="propagate"
Policy for handling non-finite values (NaN/Inf):
* ``"propagate"``:
Use the standard reduction. Non-finite values propagate
through ``mean`` and the RMS becomes non-finite.
* ``"raise"``:
Assert that ``x`` is all finite before reducing, raising an
error if NaN/Inf is present.
* ``"omit"``:
Ignore non-finite entries by treating them as missing. The
RMS is computed from finite entries only:
.. math::
\mathrm{MS}
= \frac{\sum x_i^2}{N_f}
where :math:`N_f` is the count of finite entries along the
reduced axes (clipped to at least 1).
dtype : Any, default=None
Compute dtype. If None, uses ``tf_float32`` for speed.
Pass ``dtype=tf_float64`` when diagnosing very small residuals
or when accumulated rounding error matters.
Returns
-------
rms : Tensor
RMS value reduced along ``axis``. If ``axis=None`` the result
is a scalar tensor; otherwise it has the reduced shape.
Notes
-----
**Flooring behavior.**
Floors are opt-in. If ``eps is None`` and ``ms_floor is None``,
no flooring is applied to the mean-square. If ``rms_floor is
None``, no flooring is applied to the RMS.
A common pattern for stable logging of near-zero residuals is to
use a small mean-square floor with float64 diagnostics:
* ``dtype=tf_float64`` to reduce rounding error.
* ``ms_floor`` to avoid taking ``sqrt(0)`` when a later operation
applies ``log`` or divides by RMS.
**Non-finite handling.**
``nan_policy="omit"`` is intended for diagnostics and logging.
For training-time physics losses, prefer cleaning tensors before
the loss is computed, so gradients are well-defined.
Examples
--------
Compute RMS over all entries:
>>> r = to_rms(x)
Compute per-batch RMS (reduce over horizon and channel axes):
>>> r_b = to_rms(x, axis=(1, 2))
Omit non-finite values when logging a residual map:
>>> eps_gw = to_rms(R_gw, nan_policy="omit", dtype=tf_float64)
Apply a small mean-square floor for stable downstream log:
>>> eps = to_rms(R, ms_floor=1e-30, dtype=tf_float64)
See Also
--------
scale_residual
Scales residuals by computed characteristic scales.
guard_scale_with_residual
Ensures a scale is safe when residuals are near zero.
"""
pol = str(nan_policy or "propagate").strip().lower()
if pol not in ("propagate", "raise", "omit"):
pol = "propagate"
dt = tf_float32 if dtype is None else dtype
x = tf_cast(x, dt)
if pol == "raise":
tf_debugging.assert_all_finite(
x,
"to_rms(): x has NaN/Inf",
)
ms = tf_reduce_mean(
tf_square(x),
axis=axis,
keepdims=keepdims,
)
elif pol == "omit":
finite = tf_math.is_finite(x)
x0 = tf_where(finite, x, tf_zeros_like(x))
num = tf_reduce_sum(
tf_square(x0),
axis=axis,
keepdims=keepdims,
)
den = tf_reduce_sum(
tf_cast(finite, dt),
axis=axis,
keepdims=keepdims,
)
den = tf_maximum(
den,
tf_constant(1.0, dt),
)
ms = num / den
else: # propagate
ms = tf_reduce_mean(
tf_square(x),
axis=axis,
keepdims=keepdims,
)
# mean-square floors (opt-in)
if eps is not None and float(eps) > 0.0:
ms = tf_maximum(
ms,
tf_constant(float(eps), dt),
)
if ms_floor is not None and float(ms_floor) > 0.0:
ms = tf_maximum(
ms,
tf_constant(float(ms_floor), dt),
)
rms = tf_sqrt(ms)
if rms_floor is not None and float(rms_floor) > 0.0:
rms = tf_maximum(
rms,
tf_constant(float(rms_floor), dt),
)
return rms
def _as_bool(x: Any, default: bool = False) -> bool:
"""Parse bool-like values robustly (bool/int/str)."""
if x is None:
return bool(default)
if isinstance(x, bool):
return x
if isinstance(x, int | float):
return bool(int(x))
if isinstance(x, str):
s = x.strip().lower()
if s in {"1", "true", "t", "yes", "y", "on"}:
return True
if s in {"0", "false", "f", "no", "n", "off"}:
return False
return bool(default)
def _cast_lower_str(v):
return str(v).strip().lower()
def _cast_optional_float(v):
if v is None:
return None
if isinstance(v, str):
s = v.strip().lower()
if s in {"", "none", "null"}:
return None
return float(v)
[docs]
def resolve_cons_drawdown_options(
scaling_kwargs,
*,
default_mode: str = "smooth_relu",
default_rule: str = "ref_minus_mean",
default_stop_grad_ref: bool = True,
default_zero_at_origin: bool = False,
default_clip_max: float | None = None,
default_relu_beta: float = 20.0,
) -> dict[str, Any]:
"""Resolve consolidation drawdown options from scaling_kwargs.
Supported keys (prefer the 'cons_*' names):
- cons_drawdown_mode / drawdown_mode
- cons_drawdown_rule / drawdown_rule
- cons_stop_grad_ref / stop_grad_ref
- cons_drawdown_zero_at_origin / drawdown_zero_at_origin
- cons_drawdown_clip_max / drawdown_clip_max
- cons_relu_beta / relu_beta
Returns
-------
dict with keys:
drawdown_mode, drawdown_rule, stop_grad_ref,
drawdown_zero_at_origin, drawdown_clip_max, relu_beta
"""
sk = scaling_kwargs or {}
mode = get_sk(
sk,
"cons_drawdown_mode",
default=default_mode,
cast=_cast_lower_str,
)
rule = get_sk(
sk,
"cons_drawdown_rule",
default=default_rule,
cast=_cast_lower_str,
)
stopg = get_sk(
sk,
"cons_stop_grad_ref",
default=default_stop_grad_ref,
cast=lambda x: _as_bool(x, default_stop_grad_ref),
)
zero0 = get_sk(
sk,
"cons_drawdown_zero_at_origin",
default=default_zero_at_origin,
cast=lambda x: _as_bool(x, default_zero_at_origin),
)
clipm = get_sk(
sk,
"cons_drawdown_clip_max",
default=default_clip_max,
cast=_cast_optional_float,
)
beta = get_sk(
sk,
"cons_relu_beta",
default=default_relu_beta,
cast=float,
)
allowed_modes = {
"smooth_relu",
"relu",
"softplus",
"none",
}
allowed_rules = {"ref_minus_mean", "mean_minus_ref"}
if mode not in allowed_modes:
pol = _cast_lower_str(
get_sk(
sk, "scaling_error_policy", default="raise"
)
)
if pol == "raise":
raise ValueError(
f"drawdown_mode must be {sorted(allowed_modes)}; got {mode!r}"
)
mode = default_mode
if rule not in allowed_rules:
pol = _cast_lower_str(
get_sk(
sk, "scaling_error_policy", default="raise"
)
)
if pol == "raise":
raise ValueError(
f"drawdown_rule must be {sorted(allowed_rules)}; got {rule!r}"
)
rule = default_rule
return {
"drawdown_mode": mode,
"drawdown_rule": rule,
"stop_grad_ref": stopg,
"drawdown_zero_at_origin": zero0,
"drawdown_clip_max": clipm,
"relu_beta": beta,
}
# ---------------------------------------
# Helpers
# ---------------------------------------
[docs]
def normalize_time_units(u: str | None) -> str:
"""Normalize time unit strings."""
if u is None:
return "unitless"
s = str(u).strip().lower().replace(" ", "")
if "/" in s:
s = s.split("/", 1)[1]
if s.startswith("1/"):
s = s[2:]
if s == "secs":
s = "sec"
elif s == "yrs":
s = "yr"
elif s == "mins":
s = "min"
elif s == "hrs":
s = "hr"
return s
[docs]
def seconds_per_time_unit(
time_units: str | None,
*,
dtype=tf_float32,
) -> Tensor:
"""Seconds-per-unit."""
key = normalize_time_units(time_units)
if key not in TIME_UNIT_TO_SECONDS:
keys = sorted(TIME_UNIT_TO_SECONDS.keys())
raise ValueError(
f"Unsupported time_units={time_units!r}. "
f"Supported: {keys}"
)
return tf_constant(
float(TIME_UNIT_TO_SECONDS[key]), dtype=dtype
)
# ---------------------------------------------------------------------
# v3.2 helpers: physics-driven mean settlement via stable stepping
# ---------------------------------------------------------------------
[docs]
def ensure_3d(x: Tensor) -> Tensor:
"""
Return a rank-3 tensor, preferring static rank when available.
Rules
-----
r=0 -> (1,1,1)
r=1 -> (1,N,1)
r=2 -> (B,H,1)
r=3 -> unchanged
"""
x = tf_convert_to_tensor(x)
r_static = x.shape.rank
# --- Fast path: static rank known (works great with KerasTensors) ---
if r_static is not None:
if r_static == 0:
# scalar -> (1,1,1)
return tf_reshape(x, [1, 1, 1])
if r_static == 1:
# (B,) -> (B,1,1)
n = tf_shape(x)[0]
return tf_reshape(x, [1, n, 1])
if r_static == 2:
return tf_expand_dims(x, axis=-1)
if r_static == 3:
return x
raise ValueError(
f"_ensure_3d: rank {r_static} not supported"
)
# --- Fallback: dynamic rank (only if static is unknown) ---
r = tf_rank(x)
def r0():
return tf_reshape(x, [1, 1, 1])
def r1():
n = tf_shape(x)[0]
return tf_reshape(x, [1, n, 1])
def r2():
# (B,H) -> (B,H,1)
return tf_expand_dims(x, axis=-1)
def r3():
# already (B,H,1)
return x
x = tf_cond(tf_equal(r, 0), r0, lambda: x)
x = tf_cond(tf_equal(tf_rank(x), 1), r1, lambda: x)
x = tf_cond(tf_equal(tf_rank(x), 2), r2, lambda: x)
tf_debugging.assert_equal(
tf_rank(x), 3, message="_ensure_3d must return rank-3"
)
return x
def _ensure_3d(x: Tensor) -> Tensor:
"""Ensure (B,T,1) shape."""
if (
getattr(x, "shape", None) is not None
and x.shape.rank == 2
):
return x[:, :, None]
return x
def _broadcast_like(
x: TensorLike | None, like: Tensor
) -> Tensor:
"""Convert and broadcast x to the shape of `like` (dtype preserved)."""
if x is None:
return tf_zeros_like(like)
xt = tf_convert_to_tensor(x, dtype=like.dtype)
return tf_broadcast_to(xt, tf_shape(like))
[docs]
def dt_to_seconds(
dt: Tensor, *, time_units: str | None
) -> Tensor:
"""dt(time_units) -> seconds."""
dt = tf_convert_to_tensor(dt)
dt = tf_cast(dt, tf_float32)
dt = _finite_or_zero(dt) # NaN/Inf -> 0
dt = tf_maximum(
dt, tf_constant(0.0, dt.dtype)
) # no negative dt
sec = seconds_per_time_unit(time_units, dtype=dt.dtype)
return dt * sec
[docs]
def rate_to_per_second(
dz_dt: Tensor,
*,
time_units: str | None,
) -> Tensor:
"""d/d(time_units) -> d/ds."""
sec = seconds_per_time_unit(time_units, dtype=dz_dt.dtype)
return dz_dt / (sec + tf_constant(_EPSILON, dz_dt.dtype))
[docs]
def smooth_relu(x: Tensor, *, beta: float = 20.0) -> Tensor:
"""Smooth approximation to relu(x) with controlled curvature."""
b = tf_constant(float(beta), dtype=x.dtype)
return tf_softplus(b * x) / b
[docs]
def positive(x: Tensor, *, eps: float = _EPSILON) -> Tensor:
"""Softplus positivity."""
return tf_softplus(x) + tf_constant(eps, x.dtype)
def _stats(name: str, x: Tensor) -> None:
x = tf_cast(x, tf_float32)
tf_print(
name,
"shape=",
tf_shape(x),
"min=",
tf_reduce_min(x),
"mean=",
tf_reduce_mean(x),
"max=",
tf_reduce_max(x),
summarize=8,
)
def _frac_leq_zero(x: Tensor) -> Tensor:
x = tf_cast(x, tf_float32)
return tf_reduce_mean(tf_cast(x <= 0.0, tf_float32))
def _assert_grads_finite(
grads: list[TensorLike | None],
vars_: list[Tensor],
) -> None:
for g, v in zip(grads, vars_, strict=False):
if g is None:
continue
tf_debugging.assert_all_finite(
g,
f"NaN/Inf grad for {v.name}",
)