# License: Apache-2.0
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
"""
Base classes and utilities for hyperparameter tuning of PINN models.
"""
import json
import logging
import os
import warnings
from collections.abc import Callable
from typing import (
Any,
)
from ...api.docs import (
DocstringComponents,
_pinn_tuner_common_params,
)
from ...api.property import BaseClass
from ...compat._config import Config
from ...logging import get_logger
from ...utils.deps_utils import ensure_pkg
from ...utils.generic_utils import rename_dict_keys, vlog
from .. import KERAS_DEPS
from . import HAS_KT, KT_DEPS
# Keras Tuner symbols, looked up on the ``KT_DEPS`` object imported above.
# NOTE(review): KT_DEPS is presumably an optional-dependency proxy that
# defers the real ``keras_tuner`` import -- confirm in the package __init__.
HyperModel = KT_DEPS.HyperModel
Tuner = KT_DEPS.Tuner
RandomSearch = KT_DEPS.RandomSearch
BayesianOptimization = KT_DEPS.BayesianOptimization
HyperParameters = KT_DEPS.HyperParameters
Objective = KT_DEPS.Objective
Hyperband = KT_DEPS.Hyperband
# Keras/TensorFlow symbols, looked up on ``KERAS_DEPS`` in the same way.
Model = KERAS_DEPS.Model
Callback = KERAS_DEPS.Callback
Dataset = KERAS_DEPS.Dataset
Adam = KERAS_DEPS.Adam
EarlyStopping = KERAS_DEPS.EarlyStopping
AUTOTUNE = KERAS_DEPS.AUTOTUNE
# Docstring fragments built from the shared PINN-tuner parameter text.
_pinn_tuner_docs = DocstringComponents.from_nested_components(
base=DocstringComponents(_pinn_tuner_common_params)
)
# Module-level logger; the tuner methods use it for warnings and errors.
logger = get_logger(__name__)
[docs]
class PINNTunerBase(HyperModel, BaseClass):
[docs]
@ensure_pkg(
    "keras_tuner",
    extra="'keras_tuner' is required for model tuning.",
    auto_install=Config.INSTALL_DEPS,
    use_conda=Config.USE_CONDA,
)
def __init__(
    self,
    objective: str | Objective = "val_loss",
    max_trials: int = 10,
    project_name: str = "PINN_Tuning",
    directory: str = "pinn_tuner_results",
    executions_per_trial: int = 1,
    tuner_type: str = "randomsearch",
    seed: int | None = None,
    overwrite_tuner: bool = True,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    **tuner_kwargs,
):
    """
    Store the tuning configuration and reset all result attributes.

    Parameters
    ----------
    objective : str or keras_tuner.Objective, default="val_loss"
        Metric to optimize. A string is wrapped in an ``Objective``
        whose direction is inferred from the name: ``"min"`` when the
        name contains ``"loss"`` or ``"error"`` or has one of the
        underscore-separated tokens ``mae``, ``mse``, ``rmse``,
        ``mape``, ``msle``; ``"max"`` otherwise. Pass an ``Objective``
        to choose the direction explicitly.
    max_trials : int, default=10
        Number of trials; used by ``search`` for every tuner type
        except Hyperband.
    project_name : str, default="PINN_Tuning"
        Sub-directory name for the tuner artifacts.
    directory : str, default="pinn_tuner_results"
        Root directory for the tuner artifacts.
    executions_per_trial : int, default=1
        Forwarded to the tuner as ``executions_per_trial``.
    tuner_type : str, default="randomsearch"
        Search backend; normalized by ``_validate_tuner_type``.
    seed : int or None, default=None
        Forwarded to the tuner as ``seed``.
    overwrite_tuner : bool, default=True
        Forwarded to the tuner as ``overwrite``.
    _logger : logging.Logger or callable or None, default=None
        Sink for progress messages; ``print`` is used when falsy.
    **tuner_kwargs
        Extra keyword arguments merged into the tuner constructor
        arguments by ``search``.

    Raises
    ------
    ImportError
        If ``HAS_KT`` is false, i.e. keras-tuner is unavailable.
    """
    if not HAS_KT:
        raise ImportError(
            "keras_tuner is not installed. Please run "
            "`pip install keras-tuner` to use this tuning class."
        )
    super().__init__()

    # --- configuration -------------------------------------------------
    self.objective = objective
    self.max_trials = max_trials
    self.project_name = project_name
    self.directory = directory
    self.executions_per_trial = executions_per_trial
    self.tuner_type = self._validate_tuner_type(tuner_type)
    self.seed = seed
    self.overwrite_tuner = overwrite_tuner
    self.tuner_kwargs = tuner_kwargs
    self._logger = _logger or print

    # --- results, filled in by ``search`` --------------------------------
    self.best_hps_: HyperParameters | None = None
    self.best_model_: Model | None = None
    self.tuner_: Tuner | None = None
    self.tuning_summary_: dict[str, Any] = {}

    # Initialized empty here; not read by the methods of this base class.
    self.fixed_model_params: dict[str, Any] = {}
    self.param_space_config: dict[str, Any] = {}

    if isinstance(self.objective, str):
        # A bare metric name carries no direction, so infer one. Losses
        # and error metrics (e.g. "val_mae") must be minimized; treating
        # them as "max" would make the tuner select the worst trial.
        metric_name = self.objective.lower()
        error_tokens = {"mae", "mse", "rmse", "mape", "msle"}
        minimize = (
            "loss" in metric_name
            or "error" in metric_name
            or not error_tokens.isdisjoint(metric_name.split("_"))
        )
        self.objective = Objective(
            self.objective,
            direction="min" if minimize else "max",
        )
def _validate_tuner_type(self, tuner_type: str) -> str:
valid_types = {
"randomsearch",
"bayesianoptimization",
"hyperband",
}
tt_lower = tuner_type.lower()
# Allow partial match for "random"
if "random" in tt_lower:
tt_lower = "randomsearch"
if "bayesian" in tt_lower:
tt_lower = "bayesianoptimization"
if tt_lower not in valid_types:
warnings.warn(
f"Unsupported tuner type: '{tuner_type}'. "
f"Supported types: {valid_types}. "
"Defaulting to 'randomsearch'.",
UserWarning,
stacklevel=2,
)
return "randomsearch"
return tt_lower
[docs]
def build(self, hp: HyperParameters) -> Model:
    """
    Build and compile the Keras model for one hyperparameter sample.

    This is an abstract hook: the base implementation only raises.
    Concrete tuners must override it to sample values from ``hp``,
    instantiate the model architecture, compile it, and return it.

    Parameters
    ----------
    hp : keras_tuner.HyperParameters
        Hyperparameter container supplied by Keras Tuner.

    Returns
    -------
    tf.keras.Model
        The compiled model (in overriding implementations).

    Raises
    ------
    NotImplementedError
        Always, when called on the base class.
    """
    message = "Subclasses must implement the `build(hp)` method."
    raise NotImplementedError(message)
[docs]
def search(
    self,
    train_data: Dataset,
    epochs: int,
    validation_data: Dataset | None = None,
    callbacks: list[Callback] | None = None,
    verbose: int = 1,
    patience: int = 10,
    **additional_search_kwargs,
) -> tuple[
    Model | None,
    HyperParameters | None,
    Tuner | None,
]:
    """
    Run the hyperparameter search with Keras Tuner.

    Parameters
    ----------
    train_data : tf.data.Dataset
        Training dataset whose elements are 2-tuples
        ``(inputs, targets)``. When ``targets`` is a dict, its keys
        ``"subsidence"`` and ``"gwl"`` are renamed to ``"subs_pred"``
        and ``"gwl_pred"``; any other target is passed through.
    epochs : int
        Number of epochs to train each model during a trial. Also the
        default ``max_epochs`` for Hyperband.
    validation_data : tf.data.Dataset or None, default=None
        Validation dataset, treated like ``train_data``.
    callbacks : list of tf.keras.callbacks.Callback or None, default=None
        Keras callbacks for the search phase. The list is copied, so
        the caller's list is never modified. If it contains no
        ``EarlyStopping`` instance, one monitoring the objective is
        added.
    verbose : int, default=1
        Verbosity of this method's own progress messages; values
        ``>= 2`` also make the default ``EarlyStopping`` verbose.
    patience : int, default=10
        Patience of the default ``EarlyStopping`` callback.
    **additional_search_kwargs
        Extra keyword arguments passed to the tuner ``search()``
        method. The key ``tuner_verbose`` (default 1) is removed and
        used as the tuner's ``verbose`` argument.

    Returns
    -------
    best_model : tf.keras.Model or None
        Model built from the best hyperparameters, or ``None`` if none
        could be retrieved or built.
    best_hps : keras_tuner.HyperParameters or None
        Best hyperparameters found.
    tuner : keras_tuner.Tuner or None
        Tuner instance used for the search.

    Notes
    -----
    Sets ``tuner_``, ``best_hps_`` and ``best_model_`` and finally calls
    ``_save_tuning_summary``.
    """
    from functools import partial

    tuner_verbose = additional_search_kwargs.pop("tuner_verbose", 1)
    log = partial(vlog, verbose=verbose, logger=self._logger)

    # ------------------------------------------------------------------
    # Dict targets get their keys renamed to the model's output names;
    # non-dict targets (a tensor or tuple of tensors) are left alone.
    # ------------------------------------------------------------------
    def _maybe_rename_targets(tgts):
        if not isinstance(tgts, dict):
            return tgts
        return rename_dict_keys(
            tgts,
            param_to_rename={
                "subsidence": "subs_pred",
                "gwl": "gwl_pred",
            },
        )

    def _with_renamed_targets(dataset):
        # Each element is assumed to be an (inputs, targets) pair.
        if dataset is None:
            return None
        return dataset.map(
            lambda in_dict, tgts: (
                in_dict,
                _maybe_rename_targets(tgts),
            ),
            num_parallel_calls=AUTOTUNE,
        )

    train_data = _with_renamed_targets(train_data)
    validation_data = _with_renamed_targets(validation_data)

    # ------------------------------------------------------------------
    # Build the tuner.
    # ------------------------------------------------------------------
    tuner_class_map = {
        "randomsearch": RandomSearch,
        "bayesianoptimization": BayesianOptimization,
        "hyperband": Hyperband,
    }
    TunerClass = tuner_class_map[self.tuner_type]
    tuner_params = {
        "hypermodel": self,
        "objective": self.objective,
        "executions_per_trial": self.executions_per_trial,
        "directory": self.directory,
        "project_name": self.project_name,
        "seed": self.seed,
        "overwrite": self.overwrite_tuner,
        **self.tuner_kwargs,
    }
    if self.tuner_type == "hyperband":
        # Hyperband is budgeted by epochs, not by a trial count.
        tuner_params["max_epochs"] = self.tuner_kwargs.get(
            "max_epochs", epochs
        )
        tuner_params["factor"] = self.tuner_kwargs.get("factor", 3)
        tuner_params.pop("max_trials", None)
    else:
        tuner_params["max_trials"] = self.max_trials
    self.tuner_ = TunerClass(**tuner_params)

    log(
        f"Starting hyperparameter search with {self.tuner_type.upper()}...",
        level=1,
    )
    log(
        f" Project: {self.project_name} (in {self.directory}/)",
        level=2,
    )
    log(f" Objective: {self.objective}", level=2)
    log(f" Epochs per trial: {epochs}", level=2)

    # ------------------------------------------------------------------
    # Callbacks: copy first so the caller's list is never mutated.
    # ------------------------------------------------------------------
    search_callbacks = list(callbacks) if callbacks else []
    if not any(
        isinstance(cb, EarlyStopping) for cb in search_callbacks
    ):
        # An Objective exposes the metric via ``.name``; a plain string
        # is used as-is.
        monitor_objective = getattr(
            self.objective, "name", self.objective
        )
        # Follow the objective's direction so early stopping never
        # optimizes the opposite way; fall back to Keras' default.
        direction = getattr(self.objective, "direction", None)
        mode = direction if direction in ("min", "max") else "auto"
        search_callbacks.append(
            EarlyStopping(
                monitor=str(monitor_objective),
                patience=patience,
                verbose=1 if verbose >= 2 else 0,
                mode=mode,
                restore_best_weights=True,
            )
        )
        log(
            " Added default EarlyStopping callback for search.",
            level=2,
        )

    self.tuner_.search(
        train_data,
        epochs=epochs,
        validation_data=validation_data,
        callbacks=search_callbacks,
        verbose=tuner_verbose,
        **additional_search_kwargs,
    )
    log("\nHyperparameter search complete.", level=1)

    # The printed summary is informational only; never fail on it.
    try:
        self.tuner_.results_summary(num_trials=10)
    except Exception as e:
        logger.warning(
            f"Could not display Keras Tuner results_summary: {e}"
        )

    # ------------------------------------------------------------------
    # Retrieve the best hyperparameters and a model built from them.
    # ------------------------------------------------------------------
    try:
        best_hps_list = self.tuner_.get_best_hyperparameters(
            num_trials=1
        )
        if not best_hps_list:
            logger.error(
                "Keras Tuner found no best hyperparameters."
            )
            self.best_hps_ = None
            self.best_model_ = None
        else:
            self.best_hps_ = best_hps_list[0]
            log("\n--- Best Hyperparameters Found ---", level=1)
            for hp_name, hp_value in self.best_hps_.values.items():
                log(f" {hp_name}: {hp_value}", level=2)
            log(
                "\nBuilding model with best hyperparameters...",
                level=1,
            )
            try:
                self.best_model_ = self.tuner_.hypermodel.build(
                    self.best_hps_
                )
            except Exception as build_err:
                # Fall back to the tuner's own best model, but do not
                # hide why the rebuild failed.
                logger.warning(
                    "Could not rebuild model from best "
                    f"hyperparameters ({build_err}); falling back "
                    "to tuner.get_best_models()."
                )
                self.best_model_ = self.tuner_.get_best_models(
                    num_models=1
                )[0]
    except Exception as e:
        logger.error(
            f"Error retrieving or building best model: {e}"
        )
        self.best_hps_ = None
        self.best_model_ = None

    self._save_tuning_summary(verbose=verbose)
    return self.best_model_, self.best_hps_, self.tuner_
def _save_tuning_summary(self, verbose: int = 1) -> None:
    """
    Record a summary of the finished search and write it to disk.

    The summary is stored in ``self.tuning_summary_`` and dumped as
    JSON to ``<directory>/<project_name>/tuning_summary.json``. Nothing
    is done when no tuner or no best hyperparameters are available.
    Failures while reading the best trial or writing the file are
    reported but never raised.

    Parameters
    ----------
    verbose : int, default=1
        Verbosity forwarded to ``vlog`` for the progress messages.
    """
    if self.tuner_ is None or self.best_hps_ is None:
        vlog(
            "No tuner or best HPs found to save summary.",
            verbose=verbose,
            level=2,
            logger=self._logger,
        )
        return

    # Store the objective as a plain name so it is JSON friendly.
    if isinstance(self.objective, str):
        objective_name = self.objective
    else:
        objective_name = getattr(
            self.objective, "name", str(self.objective)
        )

    summary_data = {
        "project_name": self.project_name,
        "tuner_type": self.tuner_type,
        "objective": objective_name,
        "best_hyperparameters": self.best_hps_.values
        if self.best_hps_
        else None,
    }

    # Best-effort: the score is optional, so any ordinary failure
    # degrades to "N/A" (without trapping KeyboardInterrupt/SystemExit).
    try:
        best_trial = self.tuner_.oracle.get_best_trials(1)[0]
        summary_data["best_score"] = best_trial.score
        summary_data["best_trial_id"] = best_trial.trial_id
    except Exception:
        summary_data["best_score"] = "N/A"

    self.tuning_summary_ = summary_data

    log_file_path = os.path.join(
        self.directory,
        self.project_name,
        "tuning_summary.json",
    )
    try:
        os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
        with open(log_file_path, "w", encoding="utf-8") as f:
            # default=str stringifies values JSON cannot encode natively.
            json.dump(summary_data, f, indent=4, default=str)
        vlog(
            f"Tuning summary saved to {log_file_path}",
            verbose=verbose,
            level=1,
            logger=self._logger,
        )
    except Exception as e:
        logger.warning(
            f"Could not save tuning summary log to {log_file_path}: {e}"
        )
# The class docstring is attached after the class body is defined.
PINNTunerBase.__doc__ = r"""
Base class for hyperparameter tuning of physics-informed models.

This class wraps keras-tuner orchestration for GeoPrior PINN-style
models. Subclasses are expected to implement ``build(hp)``.

Parameters
----------
objective : str or keras_tuner.Objective, default "val_loss"
    Metric to optimize during the search. A string is wrapped in a
    ``keras_tuner.Objective`` whose direction is inferred from the
    metric name; pass an ``Objective`` to set the direction explicitly.
max_trials : int, default 10
    Maximum number of hyperparameter trials to evaluate. Not forwarded
    to the tuner when ``tuner_type`` is "hyperband".
project_name : str, default "PINN_Tuning"
    Project name used for tuner artifacts.
directory : str, default "pinn_tuner_results"
    Root directory for tuner outputs.
executions_per_trial : int, default 1
    Number of repeated trainings per sampled hyperparameter set.
tuner_type : {"randomsearch", "bayesianoptimization", "hyperband"}, default "randomsearch"
    Search backend used by keras-tuner. Unsupported values trigger a
    ``UserWarning`` and fall back to "randomsearch".
seed : int or None, default None
    Random seed used for reproducibility.
overwrite_tuner : bool, default True
    Whether to overwrite an existing tuner project directory.
_logger : logging.Logger or callable or None, default None
    Sink for progress messages. The built-in ``print`` is used when no
    logger is given.
**tuner_kwargs
    Additional keyword arguments forwarded to the underlying tuner
    backend. For "hyperband", ``max_epochs`` (default: the ``epochs``
    passed to ``search``) and ``factor`` (default: 3) are read from
    here.

Attributes
----------
best_hps_ : keras_tuner.HyperParameters or None
    Best hyperparameters discovered during tuning.
best_model_ : tf.keras.Model or None
    Best compiled model recovered from the tuner.
tuner_ : keras_tuner.Tuner or None
    Underlying tuner instance.
tuning_summary_ : dict
    Compact summary of the completed tuning run.
fixed_model_params : dict
    Initialized empty by the constructor; not used by the base class
    itself.
param_space_config : dict
    Initialized empty by the constructor; not used by the base class
    itself.
"""