# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Core step computations for subsidence physics evaluation."""
from __future__ import annotations
from typing import Any
from ...compat.types import TensorLike
from .. import KERAS_DEPS
from ..utils import get_tensor_from
from .batch_io import _get_coords
from .debugs import (
dbg_step2_coords_checks,
dbg_step9_losses,
dbg_step33_physics_fields,
dbg_step33_physics_logits,
)
from .derivatives import (
compute_head_pde_derivatives_raw,
ensure_si_derivative_frame,
)
from .losses import (
assemble_physics_loss,
build_physics_bundle,
pack_eval_physics,
)
from .maths import (
_get_bounds_loss_cfg,
compose_physics_fields,
compute_bounds_residual,
compute_consolidation_step_residual,
compute_gw_flow_residual,
compute_mv_prior,
compute_scales,
compute_smoothness_prior,
cons_step_to_cons_residual,
guard_scale_with_residual,
q_to_gw_source_term_si,
resolve_auto_scale_floor,
resolve_cons_drawdown_options,
resolve_gw_units,
scale_residual,
seconds_per_time_unit,
settlement_state_for_pde,
to_rms,
)
from .stability import (
clamp_physics_logits,
compute_physics_warmup_gate,
sanitize_scales,
)
from .utils import (
get_h_ref_si,
get_s_init_si,
get_sk,
gwl_to_head_m,
infer_dt_units_from_t,
to_si_head,
to_si_subsidence,
to_si_thickness,
validate_scaling_kwargs,
)
# ------------------------------------------------------------------
# Backend aliases.
# Every tensor op used in this module is looked up once on the
# KERAS_DEPS container and bound to a short module-level name, so the
# physics code below never references the backend package directly.
# NOTE(review): KERAS_DEPS presumably resolves attributes lazily to the
# active TensorFlow/Keras backend — confirm in the parent package.
# ------------------------------------------------------------------
K = KERAS_DEPS
Tensor = K.Tensor
GradientTape = K.GradientTape
tf_broadcast_to = K.broadcast_to
tf_cast = K.cast
tf_concat = K.concat
tf_cond = K.cond
tf_constant = K.constant
tf_convert_to_tensor = K.convert_to_tensor
tf_equal = K.equal
tf_expand_dims = K.expand_dims
tf_float32 = K.float32
tf_float64 = K.float64
tf_greater_equal = K.greater_equal
tf_int32 = K.int32
tf_maximum = K.maximum
tf_rank = K.rank
tf_reduce_mean = K.reduce_mean
tf_reshape = K.reshape
tf_shape = K.shape
tf_square = K.square
tf_stop_gradient = K.stop_gradient
tf_tile = K.tile
tf_zeros_like = K.zeros_like
def _mean_if_quantiles(x: Tensor) -> Tensor:
    """Collapse an optional quantile axis and guarantee a channel axis.

    Two conditional steps are applied in sequence:

    1. If ``x`` has rank >= 3, it is averaged over axis 2.
    2. If the result has rank 2, a trailing singleton axis is appended.

    The intended output layout is ``(B, H, 1)``.

    Parameters
    ----------
    x : Tensor
        Prediction tensor, with or without a quantile axis.

    Returns
    -------
    Tensor
        The reduced / expanded tensor.
    """

    def _average_axis2() -> Tensor:
        return tf_reduce_mean(x, axis=2)

    def _keep_input() -> Tensor:
        return x

    has_extra_axis = tf_greater_equal(tf_rank(x), 3)
    reduced = tf_cond(has_extra_axis, _average_axis2, _keep_input)

    def _append_channel() -> Tensor:
        return tf_expand_dims(reduced, axis=-1)

    def _keep_reduced() -> Tensor:
        return reduced

    lacks_channel = tf_equal(tf_rank(reduced), 2)
    return tf_cond(lacks_channel, _append_channel, _keep_reduced)
def _ensure_bh1(x: Tensor, like: Tensor) -> Tensor:
    """Give ``x`` a trailing singleton axis and broadcast it against ``like``.

    A rank-2 ``x`` is reshaped to ``(dim0, dim1, 1)``; any other rank is
    passed through untouched. The result is then added to a zero tensor
    shaped like ``like``, which lets ordinary broadcasting expand it.

    Parameters
    ----------
    x : Tensor
        Tensor to coerce.
    like : Tensor
        Reference tensor whose shape drives the broadcast.

    Returns
    -------
    Tensor
        ``x`` (possibly reshaped) plus ``zeros_like(like)``.
    """

    def _with_channel() -> Tensor:
        return tf_reshape(
            x,
            [tf_shape(x)[0], tf_shape(x)[1], 1],
        )

    def _unchanged() -> Tensor:
        return x

    is_rank2 = tf_equal(tf_rank(x), 2)
    shaped = tf_cond(is_rank2, _with_channel, _unchanged)
    # Adding zeros broadcasts `shaped` up to the shape of `like`.
    return shaped + tf_zeros_like(like)
def _coords_to_bh3(model: Any, coords: Tensor) -> Tensor:
    """Return coords with a horizon axis, tiling rank-2 input.

    Parameters
    ----------
    model : Any
        Object optionally exposing ``forecast_horizon``; 1 is used
        when the attribute is absent.
    coords : Tensor
        Coordinate tensor. Only a static rank of exactly 2 triggers
        the expansion; any other rank is returned unchanged.

    Returns
    -------
    Tensor
        For rank-2 input, the tensor with a new axis inserted at
        position 1 and repeated ``forecast_horizon`` times along it.
    """
    if coords.shape.rank != 2:
        return coords

    expanded = tf_expand_dims(coords, axis=1)
    horizon = int(getattr(model, "forecast_horizon", 1))
    return tf_tile(expanded, [1, horizon, 1])
def _physics_is_on(model: Any) -> bool:
    """Report whether physics terms should be computed for ``model``.

    Physics is considered enabled unless the model exposes a
    ``_physics_off`` callable and that callable returns a truthy value.
    """
    return not (
        hasattr(model, "_physics_off") and bool(model._physics_off())
    )
# Public API: shared physics pathway used by both train and eval steps.
def physics_core(
    model: Any,
    inputs: dict[str, TensorLike | None],
    training: bool,
    return_maps: bool = False,
    *,
    for_train: bool = False,
) -> dict[str, Any]:
    r"""
    Compute GeoPrior physics residuals and losses for a batch.

    This function implements the shared physics pathway used by both
    training and evaluation for GeoPrior-style PINN models. It is
    designed to keep the physics logic consistent across:

    * ``train_step()`` (when physics losses are added to the total loss)
    * evaluation routines (when physics diagnostics are reported)

    At a high level, the function performs:

    1. Input preparation and SI conversions (thickness, head, coords).
    2. Forward pass through the model to obtain data predictions and
       physics logits.
    3. Mapping of physics logits to bounded physical fields
       (:math:`K`, :math:`S_s`, :math:`\tau`) and the closure prior
       :math:`\tau_{phys}`.
    4. Automatic differentiation to obtain PDE derivatives with respect
       to the model coords.
    5. Chain-rule scaling to SI-consistent derivatives.
    6. Construction of residual maps for:

       * consolidation relaxation residual,
       * groundwater flow residual,
       * time-scale prior residual,
       * smoothness prior residual,
       * bounds residual.

    7. Optional nondimensionalization / residual scaling.
    8. Assembly of physics losses, gating schedules, and diagnostic
       epsilon metrics.

    The returned dictionary contains predictions, auxiliary forward
    outputs, packed physics values (for logging), and optionally the
    full residual maps and fields.

    Parameters
    ----------
    model : object
        Model instance providing GeoPrior-style methods and attributes.
        The function expects the model to expose (at minimum):

        * ``scaling_kwargs`` : dict
          Resolved scaling and convention payload.
        * ``time_units`` : str or None
          Dataset time unit (for per-second conversions).
        * ``forecast_horizon`` : int
          Horizon length used to tile coords when needed.
        * ``_forward_all(inputs, training=...)`` : callable
          Forward pass returning ``(y_pred, aux)``.
        * ``split_data_predictions(x)`` : callable
          Split concatenated data head into subsidence and GWL.
        * ``split_physics_predictions(x)`` : callable
          Split concatenated physics head into
          ``(K_logits, Ss_logits, dlogtau_logits, Q_logits)``.
        * ``pde_modes_active`` : iterable of str
          Active PDE modes (e.g., {'consolidation', 'gw_flow'}).
        * Optional gates: ``_q_gate()``, ``_subs_resid_gate()``.
        * Optional physics switch: ``_physics_off()``.

        The function is tolerant to partial capabilities and will
        short-circuit when physics is disabled, but missing mandatory
        signals (e.g., thickness) raise errors.
    inputs : dict
        Dict input batch following the GeoPrior batch API.

        **Required entries**

        * ``coords`` : Tensor
          Coordinate tensor. Expected shape ``(B, H, 3)`` with order
          (t, x, y). If shape is ``(B, 3)``, it is tiled across
          horizon.
        * ``H_field`` or ``soil_thickness`` : Tensor
          Thickness field used by consolidation closure and priors.

        **Common optional entries**

        * ``static_features`` : Tensor
        * ``dynamic_features`` : Tensor
        * ``future_features`` : Tensor
        * ``s0_si`` : Tensor (optional state injection)
          Used by settlement-state formatting utilities.

        The exact batch layout depends on your Stage-1 export. This
        function relies on ``_get_coords(inputs)`` and ``get_tensor_from``
        to locate inputs robustly.
    training : bool
        Forward-pass training flag passed to ``model._forward_all`` and
        downstream field composition. Use True during training and
        False during evaluation.
    return_maps : bool, default False
        If True, return additional intermediate tensors and residual
        maps, including (K, Ss, tau, tau_prior, Q), SI thickness, SI head
        and reference head, and both raw and scaled residual fields.
        Enabling ``return_maps`` increases memory usage and is intended
        for debugging, diagnostics, and research analysis.
    for_train : bool, default False
        If True, apply training-time gating schedules for physics loss
        activation (warmup and ramp) based on optimizer step. Debug
        verbosity is also only taken from the model when this flag is
        set; otherwise it is forced to 0.
        This flag is separate from ``training`` to allow evaluation-style
        forward passes with training-time schedules when needed.

    Returns
    -------
    out : dict
        Output dictionary with the following common keys:

        ``'y_pred'`` : dict
            Model predictions (at least ``'subs_pred'`` and ``'gwl_pred'``).
        ``'aux'`` : dict
            Auxiliary forward outputs produced by the model forward path.
            Commonly includes:

            * ``data_mean_raw`` (optional),
            * ``phys_mean_raw`` (required by this function).
        ``'physics'`` : dict or None
            Physics bundle returned by :func:`build_physics_bundle`.
            Contains loss scalars, epsilons, and diagnostics. If physics
            is disabled, this is None.
        ``'physics_packed'`` : dict
            Packed physics values suitable for logging in evaluation
            mode. This is always returned (may be empty when physics off).
        ``'terms_scaled'`` : dict
            Dictionary of physics loss terms after scheduling gates and
            multipliers have been applied. Keys are stable across train
            and eval for consistent logging.
        ``'dt_units'`` : Tensor
            Inferred dataset time step size in dataset time units
            (not seconds). This value is used in settlement-state and
            certain Q conversions.
        ``'scales'`` : dict or None
            Optional residual scaling dictionary when physics residual
            scaling is enabled. May include per-term scale factors used
            for nondimensionalization. None when scaling is disabled or
            physics is switched off.

        If ``return_maps=True``, additional keys include (non-exhaustive):
        ``'K_field'``, ``'Ss_field'``, ``'tau_field'``, ``'tau_phys'``,
        ``'Hd_eff'``, ``'H_si'``, ``'Q_si'``, ``'h_si'``,
        ``'h_ref_si_11'``, ``'R_cons'``, ``'R_gw'``, ``'R_prior'``,
        ``'R_smooth'``, ``'R_bounds'``, and scaled counterparts.

    Raises
    ------
    ValueError
        If required inputs are missing (e.g., no thickness field).
    ValueError
        If coords do not have shape ``(B, H, 3)`` after coercion.
    ValueError
        If expected forward outputs are missing (e.g., missing
        ``'phys_mean_raw'``).

    Notes
    -----
    Physics switch behavior
    ~~~~~~~~~~~~~~~~~~~~~~~
    If the model indicates physics is disabled (via ``_physics_off``),
    the function performs only the forward pass and returns:

    * predictions and aux outputs,
    * ``physics=None``,
    * packed physics with ``physics=None``,
    * empty scaled term dict,
    * ``dt_units`` and ``scales=None``.

    This allows unified train/eval code paths without special casing.

    Derivative handling and SI conversion
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Derivatives are computed via autodiff with respect to the coords
    tensor fed to ``call()``. These raw derivatives are then converted
    to SI-consistent derivatives using coordinate ranges and
    conversions:

    * normalized coords are rescaled by spans (and spans squared for
      second derivatives),
    * degree-based spatial coords are converted to meters when needed,
    * time derivatives are converted to per-second using ``time_units``
      unless SI time spans are already supplied.

    Residual families
    ~~~~~~~~~~~~~~~~~
    The core residual maps assembled by this function correspond to:

    Groundwater flow
        .. math::

           R_{gw} = S_s \, \partial_t h
                    - \nabla \cdot (K \, \nabla h) - Q

    Consolidation relaxation
        .. math::

           R_{cons} = \partial_t s - \frac{s_{eq}(h) - s}{\tau}

    Time-scale prior
        A residual tying learned :math:`\tau` to a closure prior
        :math:`\tau_{phys}` in log space (implementation-dependent).

    Smoothness prior
        A spatial smoothness regularizer on :math:`K` and :math:`S_s`
        implemented via gradients of fields w.r.t. spatial coords.

    Bounds residual
        Residual measuring violation of declared bounds for
        (H, K, S_s, tau) or their log transforms.

    Scaling and nondimensionalization
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    When enabled (``model.scale_pde_residuals=True``), residuals may be
    scaled by data-driven or physics-driven magnitudes to produce
    dimensionless residuals with more comparable scales across sites.
    Floors are applied to prevent division by near-zero scales.

    Training gates
    ~~~~~~~~~~~~~~
    When ``for_train=True``, the physics loss is gated by a warmup/ramp
    schedule based on optimizer step:

    * warmup: physics contribution is suppressed,
    * ramp: physics contribution increases to full strength.

    This improves stability in early training by letting the data head
    learn a reasonable representation before enforcing physics strongly.

    Examples
    --------
    Compute physics losses during training:

    >>> out = physics_core(
    ...     model=model,
    ...     inputs=batch,
    ...     training=True,
    ...     for_train=True,
    ... )
    >>> float(out["physics"]["physics_loss_scaled"])
    0.0 # may be gated early in training

    Evaluate and return residual maps for debugging:

    >>> out = physics_core(
    ...     model=model,
    ...     inputs=batch,
    ...     training=False,
    ...     return_maps=True,
    ... )
    >>> sorted([k for k in out if k.startswith("R_")])[:4]
    ['R_bounds', 'R_cons', 'R_gw', 'R_prior']

    See Also
    --------
    geoprior.models.subsidence.derivatives.compute_head_pde_derivatives_raw
        Compute raw autodiff PDE derivatives w.r.t. coords.
    geoprior.models.subsidence.derivatives.ensure_si_derivative_frame
        Convert raw derivatives to SI-consistent derivatives.
    geoprior.models.subsidence.losses.assemble_physics_loss
        Assemble physics loss scalars and term dictionaries.
    geoprior.models.subsidence.losses.build_physics_bundle
        Build a packed physics bundle used for logging and metrics.
    geoprior.models.subsidence.maths.compose_physics_fields
        Map logits to bounded physical fields and tau prior.
    """
    sk = getattr(model, "scaling_kwargs", None) or {}
    validate_scaling_kwargs(sk)
    # Debug output is only emitted on the training path.
    verbose = getattr(model, "verbose", 0) if for_train else 0
    time_units = getattr(model, "time_units", None)

    # ----------------------------------------------------------
    # 1) Prepare H_si + coords + dt
    # ----------------------------------------------------------
    H_in = get_tensor_from(
        inputs,
        "H_field",
        "soil_thickness",
        auto_convert=True,
    )
    if H_in is None:
        raise ValueError(
            "physics_core() requires 'H_field' "
            "(or 'soil_thickness') in inputs."
        )
    H_field = tf_convert_to_tensor(H_in, dtype=tf_float32)
    H_si = to_si_thickness(H_field, sk)
    # Clamp thickness from below so later terms never see H <= floor.
    H_floor = float(get_sk(sk, "H_floor_si", default=1e-3))
    H_si = tf_maximum(H_si, tf_constant(H_floor, tf_float32))

    coords = tf_convert_to_tensor(
        _get_coords(inputs), tf_float32
    )
    coords = _coords_to_bh3(model, coords)
    if coords.shape.rank != 3 or coords.shape[-1] != 3:
        raise ValueError(
            "coords must be (B,H,3) with (t,x,y)."
        )

    # Shallow copy so the caller's batch dict is not mutated; the
    # forward pass must see the exact tensor that the tape watches.
    inputs_fwd = dict(inputs)
    inputs_fwd["coords"] = coords

    t = coords[..., 0:1]
    dt_units = infer_dt_units_from_t(t, sk)
    coords_norm = bool(
        get_sk(sk, "coords_normalized", default=False)
    )
    coords_deg = bool(
        get_sk(sk, "coords_in_degrees", default=False)
    )
    dbg_step2_coords_checks(
        verbose=verbose,
        coords=coords,
        inputs=inputs,
    )

    # ----------------------------------------------------------
    # 2) Physics OFF shortcut
    # ----------------------------------------------------------
    if not _physics_is_on(model):
        y_pred, aux = model._forward_all(
            inputs_fwd,
            training=training,
        )
        return {
            "y_pred": y_pred,
            "aux": aux,
            "physics": None,
            "physics_packed": pack_eval_physics(
                model,
                physics=None,
            ),
            "terms_scaled": {},
            "dt_units": dt_units,
            # Keep the key set identical to the physics-on path.
            "scales": None,
        }

    # ----------------------------------------------------------
    # 3) Forward + AD derivatives (raw coord units)
    # ----------------------------------------------------------
    # NOTE(review): the derivative helper is called inside the tape
    # context so that gradients-of-gradients can be recorded by the
    # same persistent tape — confirm against
    # compute_head_pde_derivatives_raw.
    with GradientTape(persistent=True) as tape:
        tape.watch(coords)
        y_pred, aux = model._forward_all(
            inputs_fwd,
            training=training,
        )

        # Prefer the raw mean head when the model provides it;
        # otherwise fall back to the (possibly quantile) predictions.
        data_mean_raw = aux.get("data_mean_raw", None)
        if data_mean_raw is not None:
            subs_m, gwl_m = model.split_data_predictions(
                data_mean_raw,
            )
        else:
            subs_m = _mean_if_quantiles(y_pred["subs_pred"])
            gwl_m = _mean_if_quantiles(y_pred["gwl_pred"])
        subs_mean_raw = _mean_if_quantiles(subs_m)
        gwl_mean_raw = _mean_if_quantiles(gwl_m)

        gwl_si = to_si_head(
            tf_cast(gwl_mean_raw, tf_float32), sk
        )
        h_si = gwl_to_head_m(gwl_si, sk, inputs=inputs_fwd)

        phys_mean_raw = aux.get("phys_mean_raw", None)
        if phys_mean_raw is None:
            raise ValueError("Missing 'phys_mean_raw'.")
        parts = model.split_physics_predictions(phys_mean_raw)
        K_l, Ss_l, dlt_l, Q_l = parts
        K_l, Ss_l, dlt_l, Q_l = clamp_physics_logits(
            K_l,
            Ss_l,
            dlt_l,
            Q_l,
        )

        freeze = bool(
            get_sk(
                sk,
                "freeze_physics_fields_over_time",
                default=True,
            )
        )
        if freeze:
            # Replace each logit by its mean over axis 1, broadcast
            # back to the original shape, so the field is constant
            # along that axis.
            K_b = tf_broadcast_to(
                tf_reduce_mean(K_l, axis=1, keepdims=True),
                tf_shape(K_l),
            )
            Ss_b = tf_broadcast_to(
                tf_reduce_mean(Ss_l, axis=1, keepdims=True),
                tf_shape(Ss_l),
            )
            tau_b = tf_broadcast_to(
                tf_reduce_mean(dlt_l, axis=1, keepdims=True),
                tf_shape(dlt_l),
            )
        else:
            K_b, Ss_b, tau_b = K_l, Ss_l, dlt_l

        (
            K_field,
            Ss_field,
            tau_field,
            tau_phys,
            Hd_eff,
            dlogtau,
            logK,
            logSs,
            log_tau,
            log_tau_phys,
            loss_bounds_barrier,
        ) = compose_physics_fields(
            model,
            coords_flat=coords,
            H_si=H_si,
            K_base=K_b,
            Ss_base=Ss_b,
            tau_base=tau_b,
            training=training,
            verbose=verbose,
        )

        dbg_step33_physics_logits(
            verbose=verbose,
            K_logits=K_l,
            Ss_logits=Ss_l,
            dlogtau_logits=dlt_l,
            Q_logits=Q_l,
            K_base=K_b,
            Ss_base=Ss_b,
            dlogtau_base=tau_b,
        )
        dbg_step33_physics_fields(
            verbose=verbose,
            K_field=K_field,
            Ss_field=Ss_field,
            tau_field=tau_field,
            tau_phys=tau_phys,
            Hd_eff=Hd_eff,
            delta_log_tau=dlogtau,
            logK=logK,
            logSs=logSs,
            log_tau=log_tau,
            log_tau_phys=log_tau_phys,
        )

        deriv_raw = compute_head_pde_derivatives_raw(
            tape,
            coords,
            h_si,
            K_field,
            Ss_field,
        )
    # Persistent tapes hold resources until released.
    del tape

    # ----------------------------------------------------------
    # 4) Chain-rule conversion to SI
    # ----------------------------------------------------------
    deriv_si, dmeta = ensure_si_derivative_frame(
        dh_dt_raw=deriv_raw["dh_dt_raw"],
        d_K_dh_dx_dx_raw=deriv_raw["d_K_dh_dx_dx_raw"],
        d_K_dh_dy_dy_raw=deriv_raw["d_K_dh_dy_dy_raw"],
        dK_dx_raw=deriv_raw["dK_dx_raw"],
        dK_dy_raw=deriv_raw["dK_dy_raw"],
        dSs_dx_raw=deriv_raw["dSs_dx_raw"],
        dSs_dy_raw=deriv_raw["dSs_dy_raw"],
        scaling_kwargs=sk,
        time_units=time_units,
        coords_normalized=coords_norm,
        coords_in_degrees=coords_deg,
    )
    dh_dt = deriv_si["dh_dt"]
    dKdhx = deriv_si["d_K_dh_dx_dx"]
    dKdhy = deriv_si["d_K_dh_dy_dy"]
    dK_dx = deriv_si["dK_dx"]
    dK_dy = deriv_si["dK_dy"]
    dSs_dx = deriv_si["dSs_dx"]
    dSs_dy = deriv_si["dSs_dy"]
    tR_tf = dmeta.get("t_range_units_tf", None)

    # ----------------------------------------------------------
    # 5) Q in SI + gate
    # ----------------------------------------------------------
    if Q_l is None:
        Q_si = tf_zeros_like(dh_dt)
    else:
        Q_si = q_to_gw_source_term_si(
            model,
            Q_l,
            Ss_field=Ss_field,
            H_field=H_si,
            coords_normalized=coords_norm,
            t_range_units=tR_tf,
            time_units=time_units,
            scaling_kwargs=sk,
            verbose=verbose,
        )
    Q_si = _ensure_bh1(Q_si, like=dh_dt)

    # Without a model-provided gate the source term is fully enabled.
    if hasattr(model, "_q_gate"):
        q_gate = model._q_gate()
    else:
        q_gate = tf_constant(1.0, tf_float32)
    Q_si = Q_si * q_gate
    loss_q_reg = tf_reduce_mean(tf_square(Q_si))
    q_rms = to_rms(Q_si)

    # Reported in the physics bundle only; defaults to 0 when absent.
    if hasattr(model, "_subs_resid_gate"):
        subs_gate = model._subs_resid_gate()
    else:
        subs_gate = tf_constant(0.0, tf_float32)

    # ----------------------------------------------------------
    # 6) Consolidation residual
    # ----------------------------------------------------------
    allow_resid = bool(
        get_sk(sk, "allow_subs_residual", default=False)
    )
    cons_active = "consolidation" in getattr(
        model,
        "pde_modes_active",
        (),
    )
    # The consolidation residual is computed only when the PDE mode is
    # active AND the scaling config allows a subsidence residual.
    cons_enabled = cons_active and allow_resid

    like_11 = h_si[:, :1, :1]
    h_ref_11 = get_h_ref_si(model, inputs_fwd, like=like_11)
    h_ref = h_ref_11 + tf_zeros_like(h_si)
    s_inc_pred = tf_zeros_like(h_si)

    if not cons_enabled:
        cons_res = tf_zeros_like(h_si)
    else:
        s_pred_si = to_si_subsidence(
            tf_cast(subs_mean_raw, tf_float32),
            sk,
        )
        s0_cum_11 = get_s_init_si(
            model, inputs_fwd, like=like_11
        )
        pde_inputs = dict(inputs_fwd)
        pde_inputs["s0_si"] = s0_cum_11
        s_inc_pred = settlement_state_for_pde(
            s_pred_si,
            t,
            scaling_kwargs=sk,
            inputs=pde_inputs,
            time_units=time_units,
            dt=dt_units,
            return_incremental=True,
        )
        # Prepend the initial state (zero increment, reference head)
        # along axis 1 so the stepper sees one more state than steps.
        s0_inc_11 = tf_zeros_like(s0_cum_11)
        s_state = tf_concat([s0_inc_11, s_inc_pred], axis=1)
        h_state = tf_concat([h_ref_11, h_si], axis=1)
        dd = resolve_cons_drawdown_options(sk)
        cons_step = compute_consolidation_step_residual(
            s_state_si=s_state,
            h_mean_si=h_state,
            Ss_field=Ss_field,
            H_field_si=H_si,
            tau_field=tau_field,
            h_ref_si=h_ref,
            dt=dt_units,
            time_units=time_units,
            method="exact",
            relu_beta=dd["relu_beta"],
            drawdown_mode=dd["drawdown_mode"],
            drawdown_rule=dd["drawdown_rule"],
            stop_grad_ref=dd["stop_grad_ref"],
            drawdown_zero_at_origin=dd[
                "drawdown_zero_at_origin"
            ],
            drawdown_clip_max=dd["drawdown_clip_max"],
            verbose=verbose,
        )
        cons_res = cons_step_to_cons_residual(
            cons_step,
            dt_units=dt_units,
            scaling_kwargs=sk,
            time_units=time_units,
        )

    # ----------------------------------------------------------
    # 7) GW residual + priors
    # ----------------------------------------------------------
    gw_res = compute_gw_flow_residual(
        model,
        dh_dt=dh_dt,
        d_K_dh_dx_dx=dKdhx,
        d_K_dh_dy_dy=dKdhy,
        Ss_field=Ss_field,
        Q=Q_si,
        verbose=verbose,
    )
    prior_res = dlogtau
    smooth_res = compute_smoothness_prior(
        dK_dx,
        dK_dy,
        dSs_dx,
        dSs_dy,
        K_field=K_field,
        Ss_field=Ss_field,
    )

    # Optimizer step drives the MV prior and the warmup gate; fall
    # back to step 0 when the model has no optimizer attached.
    step = getattr(
        getattr(model, "optimizer", None), "iterations", None
    )
    if step is None:
        step = tf_constant(0, tf_int32)

    loss_mv = compute_mv_prior(
        model,
        Ss_field,
        logSs=logSs,
        as_loss=True,
        step=step,
        alpha_disp=float(
            get_sk(sk, "mv_alpha_disp", default=0.1)
        ),
        delta=float(
            get_sk(sk, "mv_huber_delta", default=1.0)
        ),
        verbose=verbose,
    )

    R_H, R_K, R_Ss, R_tau = compute_bounds_residual(
        model,
        H_field=H_si,
        logK=logK,
        logSs=logSs,
        log_tau=log_tau,
        verbose=verbose,
    )
    bounds_res = tf_concat(
        [R_H, R_K, R_Ss, R_tau],
        axis=-1,
    )
    loss_bounds_resid = tf_reduce_mean(
        tf_square(bounds_res),
    )
    # If we want H enforced even in "barrier" mode:
    loss_bounds_H = tf_reduce_mean(tf_square(R_H))

    kind = (
        str(_get_bounds_loss_cfg(sk).get("kind", "barrier"))
        .strip()
        .lower()
    )
    if kind == "residual":
        loss_bounds = loss_bounds_resid
    elif kind == "barrier":
        # barrier is for K/Ss/tau; keep H from residual
        loss_bounds = loss_bounds_barrier + loss_bounds_H
    else:  # "both" (any other value also lands here)
        # WARNING: double-penalizes K/Ss/tau if barrier+residual
        loss_bounds = loss_bounds_resid + loss_bounds_barrier

    # ----------------------------------------------------------
    # 8) GW display units (raw diagnostics only)
    # ----------------------------------------------------------
    gw_units = resolve_gw_units(sk)
    gw_res_si = gw_res
    gw_res_disp = gw_res_si
    if gw_units == "time_unit":
        sec_u = seconds_per_time_unit(
            time_units,
            dtype=tf_float32,
        )
        gw_res_disp = gw_res_si * sec_u
    cons_res_raw = cons_res
    gw_res_raw = gw_res_disp

    # ----------------------------------------------------------
    # 9) Optional nondimensionalization
    # ----------------------------------------------------------
    cons_scaled = cons_res
    gw_scaled = gw_res_si
    scales: dict[str, Tensor] | None = None
    if bool(getattr(model, "scale_pde_residuals", False)):
        cons_floor = resolve_auto_scale_floor("cons", sk)
        gw_floor = resolve_auto_scale_floor("gw", sk)
        div_term = dKdhx + dKdhy
        s_for_scales = (
            tf_stop_gradient(s_inc_pred)
            if cons_enabled
            else tf_zeros_like(h_si)
        )
        scales = compute_scales(
            model,
            t=t,
            dt=dt_units,
            time_units=time_units,
            s_mean=s_for_scales,
            h_mean=h_si,
            K_field=K_field,
            Ss_field=Ss_field,
            tau_field=tau_field,
            H_field=H_si,
            h_ref_si=h_ref_11,
            Q=Q_si,
            dh_dt=dh_dt,
            div_K_grad_h=div_term,
            verbose=verbose,
        )
        scales = sanitize_scales(scales)
        # Scales are normalizers only: no gradient flows through them.
        scales = {
            k: tf_stop_gradient(v) for k, v in scales.items()
        }
        cons_s = guard_scale_with_residual(
            residual=cons_res,
            scale=scales["cons_scale"],
            floor=cons_floor,
        )
        gw_s = guard_scale_with_residual(
            residual=gw_res_si,
            scale=scales["gw_scale"],
            floor=gw_floor,
        )
        cons_scaled = scale_residual(
            cons_res,
            cons_s,
            floor=cons_floor,
        )
        gw_scaled = scale_residual(
            gw_res_si,
            gw_s,
            floor=gw_floor,
        )

    # ----------------------------------------------------------
    # 10) Losses + epsilons
    # ----------------------------------------------------------
    loss_cons = tf_reduce_mean(tf_square(cons_scaled))
    loss_gw = tf_reduce_mean(tf_square(gw_scaled))
    loss_prior = tf_reduce_mean(tf_square(prior_res))
    loss_smooth = tf_reduce_mean(tf_square(smooth_res))

    eps_prior = to_rms(prior_res)
    eps_cons_raw = to_rms(cons_res_raw)
    eps_gw_raw = to_rms(gw_res_raw, dtype=tf_float64)
    eps_cons = to_rms(cons_scaled)
    eps_gw = to_rms(gw_scaled, dtype=tf_float64)

    (
        phys_raw,
        phys_scaled,
        phys_mult,
        terms_scaled,
    ) = assemble_physics_loss(
        model,
        loss_cons=loss_cons,
        loss_gw=loss_gw,
        loss_prior=loss_prior,
        loss_smooth=loss_smooth,
        loss_mv=loss_mv,
        loss_q_reg=loss_q_reg,
        loss_bounds=loss_bounds,
    )

    if for_train:
        # Warmup/ramp gate multiplies the scaled total and each term.
        w = int(
            get_sk(sk, "physics_warmup_steps", default=500)
        )
        r = int(get_sk(sk, "physics_ramp_steps", default=500))
        gate = compute_physics_warmup_gate(
            step,
            warmup_steps=w,
            ramp_steps=r,
        )
        phys_scaled = phys_scaled * gate
        terms_scaled = {
            k: v * gate for k, v in terms_scaled.items()
        }

    physics = build_physics_bundle(
        model,
        physics_loss_raw=phys_raw,
        physics_loss_scaled=phys_scaled,
        phys_mult=phys_mult,
        loss_cons=loss_cons,
        loss_gw=loss_gw,
        loss_prior=loss_prior,
        loss_smooth=loss_smooth,
        loss_mv=loss_mv,
        loss_q_reg=loss_q_reg,
        q_rms=q_rms,
        q_gate=q_gate,
        subs_resid_gate=subs_gate,
        loss_bounds=loss_bounds,
        eps_prior=eps_prior,
        eps_cons=eps_cons,
        eps_gw=eps_gw,
        eps_cons_raw=eps_cons_raw,
        eps_gw_raw=eps_gw_raw,
    )

    out: dict[str, Any] = {
        "y_pred": y_pred,
        "aux": aux,
        "physics": physics,
        "physics_packed": pack_eval_physics(
            model,
            physics=physics,
        ),
        "terms_scaled": terms_scaled,
        "dt_units": dt_units,
        "scales": scales,
    }

    if return_maps:
        out.update(
            {
                "Q_si": Q_si,
                "K_field": K_field,
                "Ss_field": Ss_field,
                "tau_field": tau_field,
                "tau_phys": tau_phys,
                "Hd_eff": Hd_eff,
                "H_si": H_si,
                "R_cons": cons_res,
                "R_gw": gw_res_si,
                "R_prior": prior_res,
                "R_smooth": smooth_res,
                "R_bounds": bounds_res,
                "R_cons_scaled": cons_scaled,
                "R_gw_scaled": gw_scaled,
                "gw_res_display": gw_res_disp,
                "h_si": h_si,
                "h_ref_si_11": h_ref_11,
            }
        )
        # Legacy names kept for older callers.
        out.update(
            {
                "K": K_field,  # effective K (m/s)
                "Ss": Ss_field,  # effective Ss (1/m)
                "tau": tau_field,  # learned tau (s)
                "tau_prior": tau_phys,  # closure tau (s)
                "tau_closure": tau_phys,  # alias (clearer naming)
                "Hd": Hd_eff,  # effective drainage thickness (m)
                "H": H_si,  # base thickness (m)
                "H_field": H_si,  # legacy name used elsewhere
                "cons_res_vals": cons_res,  # alias
            }
        )

    dbg_step9_losses(
        verbose=verbose,
        loss_cons=loss_cons,
        loss_gw=loss_gw,
        loss_prior=loss_prior,
        loss_smooth=loss_smooth,
    )
    return out