# License: Apache-2.0
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
"""Tuner for GeoPriorSubsNet."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import (
Any,
)
import numpy as np
from ...core.handlers import _get_valid_kwargs
from ...logging import get_logger
from ...utils.generic_utils import (
cast_multiple_bool_params,
rename_dict_keys,
vlog,
)
from .. import KERAS_DEPS
from ..subsidence.models import GeoPriorSubsNet
from ..utils.pinn import check_required_input_keys
from . import KT_DEPS
from ._base_tuner import PINNTunerBase
HyperParameters = KT_DEPS.HyperParameters
Objective = KT_DEPS.Objective
Tuner = KT_DEPS.Tuner
AUTOTUNE = KERAS_DEPS.AUTOTUNE
Model = KERAS_DEPS.Model
Adam = KERAS_DEPS.Adam
MeanSquaredError = KERAS_DEPS.MeanSquaredError
MeanAbsoluteError = KERAS_DEPS.MeanAbsoluteError
Callback = KERAS_DEPS.Callback
Dataset = KERAS_DEPS.Dataset
tf_float32 = KERAS_DEPS.float32
tf_const = KERAS_DEPS.constant
tf_cast = KERAS_DEPS.cast
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_max = KERAS_DEPS.maximum
tf_expand = KERAS_DEPS.expand_dims
logger = get_logger(__name__)
# ------------------------------
# Defaults specialized to model
# ------------------------------
_DEFAULT_COMMON = {
"output_subsidence_dim": 1,
"output_gwl_dim": 1,
"quantiles": None,
"max_window_size": 10,
"memory_size": 100,
"scales": [1],
"multi_scale_agg": "last",
"final_agg": "last",
"use_residuals": True,
"use_batch_norm": False,
"activation": "relu",
"architecture_config": {
"encoder_type": "hybrid",
"decoder_attention_stack": [
"cross",
"hierarchical",
"memory",
],
"feature_processing": "vsn",
},
"loss_weights": {"subs_pred": 1.0, "gwl_pred": 1.0},
}
_DEFAULT_GEOPRIOR = {
**_DEFAULT_COMMON,
"pde_mode": "both",
"mv": 1e-7,
"kappa": 1.0,
"gamma_w": 9810.0,
"h_ref": 0.0,
"use_effective_h": False,
"hd_factor": 1.0,
"kappa_mode": "bar",
"scale_pde_residuals": True,
"scaling_kwargs": {},
}
# Compile-only HPs
_COMPILE_ONLY = {
"learning_rate",
"lambda_gw",
"lambda_cons",
"lambda_prior",
"lambda_smooth",
"lambda_mv",
"lambda_bounds",
"lambda_q",
"lambda_offset",
"scale_mv_with_offset",
"scale_q_with_offset",
"mv_lr_mult",
"kappa_lr_mult",
}
_INT_FIELDS = (
"embed_dim",
"hidden_units",
"lstm_units",
"attention_units",
"vsn_units",
"num_heads",
)
_BOOL_CAST = [
("use_vsn", True),
("use_residuals", True),
("use_batch_norm", False),
("use_effective_h", False),
("scale_pde_residuals", True),
]
[docs]
class SubsNetTuner(PINNTunerBase):
r"""
Specialized tuner for ``GeoPriorSubsNet`` models.
This class provides a flexible hyperparameter tuner for the
physics-informed ``GeoPriorSubsNet``. It builds on
``PINNTunerBase`` and uses keras-tuner backends to search
architectural, physics, and compile-time spaces. Fixed,
data-dependent parameters are separated from the search
space, and GeoPrior input checks are enforced during
``run`` and ``create``.
The recommended entry point is ``SubsNetTuner.create``,
which infers data dimensions from NumPy arrays, merges
them with robust defaults, and applies user overrides.
Parameters
----------
fixed_params : dict
Non-tunable configuration passed to the model ``__init__``.
Typical keys include ``static_input_dim``,
``dynamic_input_dim``, ``future_input_dim``,
``output_subsidence_dim``, ``output_gwl_dim``,
``forecast_horizon``, and stable flags such as
``use_batch_norm`` or ``pde_mode``.
search_space : dict, optional
Hyperparameter definitions. Each entry is either a list of
discrete choices or a typed range dictionary such as
``{"dropout_rate": {"type": "float", "min_value": 0.1,
"max_value": 0.4}}``. Supported types are ``int``,
``float``, ``choice``, and ``bool``. Items are routed to
model ``__init__`` or to ``compile`` as noted below.
objective : str or keras_tuner.Objective, default "val_loss"
Metric to optimize. If the string contains "loss",
direction is inferred as "min".
max_trials : int, default 10
Maximum number of trials evaluated by the tuner.
project_name : str, default "SubsNetrTuner_Project"
Project name used for directory layout.
directory : str, default "subsnet_tuner_results"
Root directory where tuner artifacts are saved.
executions_per_trial : int, default 1
Number of repeated trainings per hyperparameter set.
tuner_type : {'randomsearch', 'bayesianoptimization', 'hyperband'}, default 'randomsearch'
Search algorithm used by keras-tuner.
seed : int, optional
Random seed for reproducibility.
overwrite : bool, default True
If True, existing project results are overwritten.
_logger : logging.Logger or callable, optional
Logger or print-like callable for progress lines.
kwargs : dict
Forwarded to ``PINNTunerBase`` for advanced control.
Attributes
----------
model_class : Type[tf.keras.Model]
Bound to ``GeoPriorSubsNet``.
fixed_params : dict
Finalized, non-tunable model configuration.
search_space : dict
User-provided hyperparameter search definitions.
best_hps_ : keras_tuner.HyperParameters or None
Best hyperparameters found by the search.
best_model_ : tf.keras.Model or None
Model built with the best hyperparameters.
tuner_ : keras_tuner.Tuner or None
Underlying tuner instance after ``run`` or ``search``.
tuning_summary_ : dict
Compact summary saved under the project directory.
Notes
-----
**Required inputs.** ``GeoPriorSubsNet`` expects ``coords`` for
spatiotemporal coordinates, ``dynamic_features`` for time-varying
covariates, and ``H_field`` for the soil-thickness field.
The helper canonicalizes ``H_field`` from common aliases,
e.g., ``soil_thickness``, ``soil thickness``, ``h_field``.
**Target canonicalization.**
Targets are canonicalized to ``subs_pred`` and
``gwl_pred`` from ``subsidence`` and ``gwl``. This is
handled internally before data pipelines are built.
**Compile-only hyperparameters.** These search keys are not
passed to ``__init__`` and are routed to ``compile``:
- ``learning_rate``
- ``lambda_gw``, ``lambda_cons``, ``lambda_prior``,
``lambda_smooth``, ``lambda_mv``
- ``mv_lr_mult``, ``kappa_lr_mult``
**Losses and metrics.** By default the supervised heads use
mean squared error with mean absolute error metrics. If
``fixed_params["quantiles"]`` is set, a pinball loss can
be injected via a user loss factory at compile time.
**Physics objectives.** GeoPrior adds residuals consistent
with consolidation and groundwater flow:
- :math:`R_gw = Ss * d(h)/dt - div(K * grad(h)) - Q`
- :math:`R_cons = d(s)/dt - (s_eq - s) / tau`
with :math:`s_eq = m_v * gamma_w * (h_ref - h) * H`.
Weights are controlled by the compile-time lambdas.
**Typical search groups.**
- Architecture: ``embed_dim``, ``hidden_units``,
``lstm_units``, ``attention_units``, ``num_heads``,
``dropout_rate``, ``vsn_units``, ``use_vsn``,
``use_batch_norm``
- Physics: ``pde_mode``, ``mv``, ``kappa``,
``use_effective_h``, ``hd_factor``, ``kappa_mode``,
``scale_pde_residuals``
- Optimization: ``learning_rate`` and lambda weights
Workflow. ``run`` builds ``tf.data`` pipelines from NumPy inputs,
applies key canonicalization, validates GeoPrior requirements, and
delegates to ``search``. The best HPs and a built model are returned
and stored on the class. The tuning workflow combines Keras Tuner
search patterns with the poromechanics background from
:cite:p:`KerasTunerDocs,Terzaghi1943TheoreticalSoilMechanics,Bear1972DynamicsFluidsPorousMedia`.
Examples
--------
Create from arrays and tune a small space.
>>> from geoprior.models.forecast_tuner import SubsNetTuner
>>> fixed = {"forecast_horizon": 7}
>>> space = {
... "embed_dim": [32, 64],
... "num_heads": [2, 4],
... "dropout_rate": {"type": "float",
... "min_value": 0.1, "max_value": 0.3},
... "learning_rate": [1e-3, 5e-4],
... "lambda_gw": {"type": "float",
... "min_value": 0.5, "max_value": 1.5},
... }
>>> tuner = SubsNetTuner.create(
... inputs_data=inputs_np,
... targets_data=targets_np,
... search_space=space,
... fixed_params=fixed,
... max_trials=20,
... project_name="GeoPrior_HP_Search",
... )
>>> best_model, best_hps, kt = tuner.run(
... inputs=inputs_np,
... y=targets_np,
... validation_data=(val_inputs_np, val_targets_np),
... epochs=30,
... batch_size=32,
... )
See Also
--------
PINNTunerBase
Base hypermodel with the generic ``search`` routine.
GeoPriorSubsNet
Target model with physics residuals and priors.
HydroTuner
Generic PINN tuner for HAL and TransFlow models.
"""
[docs]
def __init__(
self,
fixed_params: dict[str, Any],
search_space: dict[str, Any] | None = None,
objective: str | Objective = "val_loss",
max_trials: int = 10,
project_name: str = "SubsNetrTuner_Project",
directory: str = "subsnet_tuner_results",
executions_per_trial: int = 1,
tuner_type: str = "randomsearch",
seed: int | None = None,
overwrite: bool = True,
_logger: Callable[[str], None]
| logging.Logger
| None = None,
**kwargs,
):
self._logger = _logger or print
super().__init__(
objective=objective,
max_trials=max_trials,
project_name=project_name,
directory=directory,
executions_per_trial=executions_per_trial,
tuner_type=tuner_type,
seed=seed,
overwrite_tuner=overwrite,
_logger=self._logger,
**kwargs,
)
self.fixed_params = dict(fixed_params or {})
self.search_space = dict(search_space or {})
self.model_class: type[Model] = GeoPriorSubsNet
[docs]
@classmethod
def create(
cls,
inputs_data: dict[str, np.ndarray],
targets_data: dict[str, np.ndarray],
search_space: dict[str, Any],
fixed_params: dict[str, Any] | None = None,
**tuner_kwargs,
) -> SubsNetTuner:
t_std = rename_dict_keys(
targets_data.copy(),
param_to_rename={
"subsidence": "subs_pred",
"gwl": "gwl_pred",
},
)
final_fixed = cls._infer_and_merge_params(
inputs_data=inputs_data,
targets_data=t_std,
user_fixed_params=fixed_params,
)
return cls(
fixed_params=final_fixed,
search_space=search_space,
**tuner_kwargs,
)
def _create_hyperparameter(
self, hp: HyperParameters, name: str, definition: Any
) -> int | float | str | bool:
if isinstance(definition, list):
return hp.Choice(name, definition)
if isinstance(definition, dict):
hp_type = definition.get("type", "float")
kw = {
k: v
for k, v in definition.items()
if k != "type"
}
if hp_type == "int":
return hp.Int(name, **kw)
if hp_type == "float":
return hp.Float(name, **kw)
if hp_type == "choice":
return hp.Choice(name, **kw)
if hp_type == "bool":
return hp.Boolean(name, **kw)
raise TypeError(
f"Unsupported HP def for '{name}': {definition}"
)
# ----------------------------
# Build per-trial model
# ----------------------------
[docs]
def build(self, hp: HyperParameters) -> Model:
init_params = dict(self.fixed_params)
compile_hps: dict[str, Any] = {}
for name, spec in self.search_space.items():
val = self._create_hyperparameter(hp, name, spec)
if name in _COMPILE_ONLY:
compile_hps[name] = val
else:
init_params[name] = val
for k in _INT_FIELDS:
if k in init_params:
init_params[k] = int(init_params[k])
valid_init = _get_valid_kwargs(
self.model_class.__init__,
init_params,
error="ignore",
)
cast_multiple_bool_params(
valid_init,
bool_params_to_cast=_BOOL_CAST,
)
model = self.model_class(**valid_init)
# choose losses: pinball if quantiles set
quantiles = valid_init.get("quantiles")
if quantiles:
pinball = _pinball_factory(quantiles)
loss = {
"subs_pred": pinball,
"gwl_pred": pinball,
}
metrics = {} # avoid MAE on Q-dim tensors
else:
loss = {
"subs_pred": MeanSquaredError(
name="subs_data_loss"
),
"gwl_pred": MeanSquaredError(
name="gwl_data_loss"
),
}
metrics = {
"subs_pred": [
MeanAbsoluteError(name="subs_mae")
],
"gwl_pred": [
MeanAbsoluteError(name="gwl_mae")
],
}
lr = compile_hps.pop("learning_rate", 1e-3)
optimizer = Adam(
learning_rate=lr, clipnorm=1.0, clipvalue=0.5
)
valid_compile = _get_valid_kwargs(
model.compile,
compile_hps,
error="ignore",
)
cast_multiple_bool_params(
valid_compile,
bool_params_to_cast=[
("scale_mv_with_offset", False),
("scale_q_with_offset", True),
],
)
model.compile(
optimizer=optimizer,
loss=loss,
metrics=metrics or None,
loss_weights=self.fixed_params.get(
"loss_weights",
{"subs_pred": 1.0, "gwl_pred": 1.0},
),
**valid_compile,
)
return model
# ----------------------------
# Run end-to-end search
# ----------------------------
[docs]
def run(
self,
inputs: dict[str, np.ndarray],
y: dict[str, np.ndarray],
validation_data: tuple[
dict[str, np.ndarray], dict[str, np.ndarray]
]
| None = None,
epochs: int = 10,
batch_size: int = 32,
callbacks: list[Callback] | None = None,
case_info: dict[str, Any] | None = None,
verbose: int = 1,
**search_kwargs,
) -> tuple[
Model | None,
HyperParameters | None,
Tuner | None,
]:
vlog(
"SubsNetTuner: starting run...",
verbose=verbose,
level=1,
logger=self._logger,
)
# fast fail for H_field presence
_require_h_field(inputs, "inputs")
if validation_data:
_require_h_field(validation_data[0], "val_inputs")
req = [
"static_input_dim",
"dynamic_input_dim",
"future_input_dim",
"output_subsidence_dim",
"output_gwl_dim",
"forecast_horizon",
]
if not all(k in self.fixed_params for k in req):
vlog(
"Inferring fixed params from data...",
verbose=verbose,
level=2,
logger=self._logger,
)
y_std = rename_dict_keys(
y.copy(),
param_to_rename={
"subsidence": "subs_pred",
"gwl": "gwl_pred",
},
)
self.fixed_params = self._infer_and_merge_params(
inputs_data=inputs,
targets_data=y_std,
user_fixed_params=self.fixed_params,
)
inputs, y = check_required_input_keys(
inputs, y, model_name="GeoPriorSubsNet"
)
train_ds = (
Dataset.from_tensor_slices((inputs, y))
.batch(batch_size)
.prefetch(AUTOTUNE)
)
val_ds = None
if validation_data:
vx, vy = validation_data
vx, vy = check_required_input_keys(
vx, vy, model_name="GeoPriorSubsNet"
)
val_ds = (
Dataset.from_tensor_slices((vx, vy))
.batch(batch_size)
.prefetch(AUTOTUNE)
)
metric_kind = (
"Quantile"
if self.fixed_params.get("quantiles")
else "Point"
)
self._current_run_case_info = {
"description": (
f"GeoPriorSubsNet {metric_kind} forecast"
),
}
self._current_run_case_info.update(self.fixed_params)
if case_info:
self._current_run_case_info.update(case_info)
vlog(
"Delegating to base.search()...",
verbose=verbose,
level=2,
logger=self._logger,
)
return super().search(
train_data=train_ds,
epochs=epochs,
validation_data=val_ds,
callbacks=callbacks,
verbose=verbose,
**search_kwargs,
)
@staticmethod
def _infer_and_merge_params(
inputs_data: dict[str, np.ndarray],
targets_data: dict[str, np.ndarray],
user_fixed_params: dict[str, Any] | None = None,
) -> dict[str, Any]:
final = dict(_DEFAULT_GEOPRIOR)
inputs_data, targets_data = check_required_input_keys(
inputs_data,
targets_data,
model_name="GeoPriorSubsNet",
)
inferred: dict[str, Any] = {}
if "static_features" in inputs_data:
inferred["static_input_dim"] = inputs_data[
"static_features"
].shape[-1]
else:
inferred["static_input_dim"] = 0
inferred["dynamic_input_dim"] = inputs_data[
"dynamic_features"
].shape[-1]
if "future_features" in inputs_data:
inferred["future_input_dim"] = inputs_data[
"future_features"
].shape[-1]
else:
inferred["future_input_dim"] = 0
inferred["output_subsidence_dim"] = targets_data[
"subs_pred"
].shape[-1]
inferred["output_gwl_dim"] = targets_data[
"gwl_pred"
].shape[-1]
inferred["forecast_horizon"] = targets_data[
"subs_pred"
].shape[1]
final.update(inferred)
if user_fixed_params:
final.update(user_fixed_params)
return final
def _require_h_field(d: dict[str, Any], tag: str) -> None:
if ("H_field" not in d) and ("soil_thickness" not in d):
raise ValueError(
f"{tag} must contain 'H_field' or 'soil_thickness'."
)
def _pinball_factory(qs: list[float]):
q_base = tf_const(qs, dtype=tf_float32) # shape (Q,)
def loss(y_true, y_pred):
yt = tf_cast(y_true, tf_float32)
yp = tf_cast(y_pred, tf_float32)
# --- Normalize shapes ---
# y_pred: either (B,H,Q) or (B,H,Q,1)
if yp.shape.rank == 3:
# want y_true as (B,H,1) so it broadcasts across Q
if yt.shape.rank == 2: # (B,H) -> (B,H,1)
yt = tf_expand(yt, axis=-1)
# if yt is already (B,H,1), leave it
q = q_base[
None, None, :
] # (1,1,Q) -> matches (B,H,Q)
elif yp.shape.rank == 4:
# want y_true as (B,H,1,1) to match (B,H,Q,1)
if yt.shape.rank == 2: # (B,H) -> (B,H,1)
yt = tf_expand(yt, axis=-1)
if yt.shape.rank == 3: # (B,H,1) -> (B,H,1,1)
yt = tf_expand(yt, axis=2)
q = q_base[
None, None, :, None
] # (1,1,Q,1) -> matches (B,H,Q,1)
else:
raise ValueError(
"y_pred must be rank 3 (B,H,Q) or rank 4 (B,H,Q,1)."
)
# --- Pinball loss ---
err = yt - yp
pin = tf_max(q * err, (q - 1.0) * err)
return tf_reduce_mean(pin)
return loss