Source code for geoprior.models.forecast_tuner._base_tuner

# License: Apache-2.0
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>

"""
Base classes and utilities for hyperparameter tuning of PINN models.
"""

import json
import logging
import os
import warnings
from collections.abc import Callable
from typing import (
    Any,
)

from ...api.docs import (
    DocstringComponents,
    _pinn_tuner_common_params,
)
from ...api.property import BaseClass
from ...compat._config import Config
from ...logging import get_logger
from ...utils.deps_utils import ensure_pkg
from ...utils.generic_utils import rename_dict_keys, vlog
from .. import KERAS_DEPS
from . import HAS_KT, KT_DEPS

HyperModel = KT_DEPS.HyperModel
Tuner = KT_DEPS.Tuner
RandomSearch = KT_DEPS.RandomSearch
BayesianOptimization = KT_DEPS.BayesianOptimization
HyperParameters = KT_DEPS.HyperParameters
Objective = KT_DEPS.Objective
Hyperband = KT_DEPS.Hyperband

Model = KERAS_DEPS.Model
Callback = KERAS_DEPS.Callback
Dataset = KERAS_DEPS.Dataset
Adam = KERAS_DEPS.Adam
EarlyStopping = KERAS_DEPS.EarlyStopping
AUTOTUNE = KERAS_DEPS.AUTOTUNE

_pinn_tuner_docs = DocstringComponents.from_nested_components(
    base=DocstringComponents(_pinn_tuner_common_params)
)

logger = get_logger(__name__)


[docs] class PINNTunerBase(HyperModel, BaseClass):
[docs] @ensure_pkg( "keras_tuner", extra="'keras_tuner' is required for model tuning.", auto_install=Config.INSTALL_DEPS, use_conda=Config.USE_CONDA, ) def __init__( self, objective: str | Objective = "val_loss", max_trials: int = 10, project_name: str = "PINN_Tuning", directory: str = "pinn_tuner_results", executions_per_trial: int = 1, tuner_type: str = "randomsearch", seed: int | None = None, overwrite_tuner: bool = True, _logger: logging.Logger | Callable[[str], None] | None = None, **tuner_kwargs, ): if not HAS_KT: raise ImportError( "keras_tuner is not installed. Please run " "`pip install keras-tuner` to use this tuning class." ) super().__init__() self.objective = objective self.max_trials = max_trials self.project_name = project_name self.directory = directory self.executions_per_trial = executions_per_trial self.tuner_type = self._validate_tuner_type( tuner_type ) self.seed = seed self.overwrite_tuner = overwrite_tuner self.tuner_kwargs = tuner_kwargs self._logger = _logger or print self.best_hps_: HyperParameters | None = None self.best_model_: Model | None = None self.tuner_: Tuner | None = None self.tuning_summary_: dict[str, Any] = {} self.fixed_model_params: dict[str, Any] = {} self.param_space_config: dict[str, Any] = {} if isinstance(self.objective, str): # Default: any metric name containing "loss" is minimized direction = ( "min" if "loss" in self.objective else "max" ) self.objective = Objective( self.objective, direction=direction )
def _validate_tuner_type(self, tuner_type: str) -> str: valid_types = { "randomsearch", "bayesianoptimization", "hyperband", } tt_lower = tuner_type.lower() # Allow partial match for "random" if "random" in tt_lower: tt_lower = "randomsearch" if "bayesian" in tt_lower: tt_lower = "bayesianoptimization" if tt_lower not in valid_types: warnings.warn( f"Unsupported tuner type: '{tuner_type}'. " f"Supported types: {valid_types}. " "Defaulting to 'randomsearch'.", UserWarning, stacklevel=2, ) return "randomsearch" return tt_lower
[docs] def build(self, hp: HyperParameters) -> Model: """ Builds and compiles the Keras model with hyperparameters. This method **must be overridden** by subclasses (e.g., PIHALTuner) to define the specific model architecture (like PIHALNet), sample hyperparameters using the `hp` object based on `self.param_space`, and compile the model. Args: hp (kt.HyperParameters): Keras Tuner HyperParameters object. Returns: tf.keras.Model: The compiled Keras model. """ raise NotImplementedError( "Subclasses must implement the `build(hp)` method." )
[docs] def search( self, train_data: Dataset, epochs: int, validation_data: Dataset | None = None, callbacks: list[Callback] | None = None, verbose: int = 1, patience: int = 10, **additional_search_kwargs, ) -> tuple[ Model | None, HyperParameters | None, Tuner | None, ]: """ Performs the hyperparameter search using Keras Tuner. Parameters ---------- train_data : tf.data.Dataset Training dataset. Must yield tuples of ``(inputs_dict, targets_dict)`` compatible with the model's ``train_step``. epochs : int Number of epochs to train each model during a trial. validation_data : tf.data.Dataset or None, default=None Validation dataset. callbacks : list of tf.keras.callbacks.Callback or None, default=None Keras callbacks for the search phase. verbose : int, default=1 Verbosity level for Keras Tuner search. patience : int, default=10 Early-stopping patience. **additional_search_kwargs Additional keyword arguments passed to the tuner ``search()`` method. Returns ------- best_model : tf.keras.Model or None Best model instance built with the best hyperparameters. best_hps : keras_tuner.HyperParameters or None Best hyperparameters found. tuner : keras_tuner.Tuner or None Tuner instance used for the search. """ tuner_verbose = additional_search_kwargs.pop( "tuner_verbose", 1 ) # ------------------------------------------------------------------ # Rename target‑dict keys *only if* each element’s target component # is a Python dict produced by PIHALNet. For HALNet the target is # already a Tensor, so we leave it unchanged. # ------------------------------------------------------------------ def _maybe_rename_targets(tgts): # tgts is either a dict of tensors or a single/tuple Tensor return ( rename_dict_keys( tgts, param_to_rename={ "subsidence": "subs_pred", "gwl": "gwl_pred", }, ) if isinstance(tgts, dict) else tgts ) # STEP 1: If train_data is not None, wrap it so that any target dict # inside gets its keys renamed. We assume each element of # train_data is (input_dict, target_dict). if train_data is not None: train_data = train_data.map( lambda in_dict, tgts: ( in_dict, _maybe_rename_targets(tgts), ), num_parallel_calls=AUTOTUNE, ) # STEP 2: Do the same for validation_data, if provided. if validation_data is not None: validation_data = validation_data.map( lambda in_dict, tgts: ( in_dict, _maybe_rename_targets(tgts), ), num_parallel_calls=AUTOTUNE, ) tuner_class_map = { "randomsearch": RandomSearch, "bayesianoptimization": BayesianOptimization, "hyperband": Hyperband, } TunerClass = tuner_class_map[self.tuner_type] tuner_params = { "hypermodel": self, "objective": self.objective, "executions_per_trial": self.executions_per_trial, "directory": self.directory, "project_name": self.project_name, "seed": self.seed, "overwrite": self.overwrite_tuner, **self.tuner_kwargs, } if self.tuner_type == "hyperband": tuner_params["max_epochs"] = ( self.tuner_kwargs.get("max_epochs", epochs) ) tuner_params["factor"] = self.tuner_kwargs.get( "factor", 3 ) if "max_trials" in tuner_params: del tuner_params["max_trials"] else: tuner_params["max_trials"] = self.max_trials self.tuner_ = TunerClass(**tuner_params) vlog( f"Starting hyperparameter search with {self.tuner_type.upper()}...", verbose=verbose, level=1, logger=self._logger, ) vlog( f" Project: {self.project_name} (in {self.directory}/)", verbose=verbose, level=2, logger=self._logger, ) vlog( f" Objective: {self.objective}", verbose=verbose, level=2, logger=self._logger, ) vlog( f" Epochs per trial: {epochs}", verbose=verbose, level=2, logger=self._logger, ) search_callbacks = callbacks or [] if not any( isinstance(cb, EarlyStopping) for cb in search_callbacks ): # Objective name for monitor: monitor_objective = self.objective if not isinstance( self.objective, str ) and hasattr(self.objective, "name"): monitor_objective = self.objective.name early_stopping_search = EarlyStopping( monitor=str( monitor_objective ), # Ensure it's a string patience=patience, verbose=1 if verbose >= 2 else 0, # Keras verbose mapping restore_best_weights=True, ) search_callbacks.append(early_stopping_search) vlog( " Added default EarlyStopping callback for search.", verbose=verbose, level=2, logger=self._logger, ) self.tuner_.search( train_data, epochs=epochs, validation_data=validation_data, callbacks=search_callbacks, verbose=tuner_verbose, # 1 if verbose >=1 else 0, # Keras tuner verbose **additional_search_kwargs, ) vlog( "\nHyperparameter search complete.", verbose=verbose, level=1, logger=self._logger, ) try: self.tuner_.results_summary(num_trials=10) except Exception as e: logger.warning( f"Could not display Keras Tuner results_summary: {e}" ) try: best_hps_list = ( self.tuner_.get_best_hyperparameters( num_trials=1 ) ) if not best_hps_list: logger.error( "Keras Tuner found no best hyperparameters." ) self.best_hps_ = None self.best_model_ = None else: self.best_hps_ = best_hps_list[0] vlog( "\n--- Best Hyperparameters Found ---", verbose=verbose, level=1, logger=self._logger, ) for ( hp_name, hp_value, ) in self.best_hps_.values.items(): vlog( f" {hp_name}: {hp_value}", verbose=verbose, level=2, logger=self._logger, ) vlog( "\nBuilding model with best hyperparameters...", verbose=verbose, level=1, logger=self._logger, ) try: self.best_model_ = ( self.tuner_.hypermodel.build( self.best_hps_ ) ) except: self.best_model_ = ( self.tuner_.get_best_models( num_models=1 )[0] ) # Alternative except Exception as e: logger.error( f"Error retrieving or building best model: {e}" ) self.best_hps_ = None self.best_model_ = None self._save_tuning_summary(verbose=verbose) return self.best_model_, self.best_hps_, self.tuner_
def _save_tuning_summary(self, verbose: int = 1): if self.tuner_ is None or self.best_hps_ is None: vlog( "No tuner or best HPs found to save summary.", verbose=verbose, level=2, logger=self._logger, ) return summary_data = { "project_name": self.project_name, "tuner_type": self.tuner_type, "objective": self.objective if isinstance(self.objective, str) else getattr( self.objective, "name", str(self.objective) ), "best_hyperparameters": self.best_hps_.values if self.best_hps_ else None, } try: best_trial = self.tuner_.oracle.get_best_trials( 1 )[0] summary_data["best_score"] = best_trial.score summary_data["best_trial_id"] = ( best_trial.trial_id ) except: summary_data["best_score"] = "N/A" self.tuning_summary_ = summary_data log_file_path = os.path.join( self.directory, self.project_name, "tuning_summary.json", ) try: os.makedirs( os.path.dirname(log_file_path), exist_ok=True ) with open(log_file_path, "w") as f: json.dump( summary_data, f, indent=4, default=str ) vlog( f"Tuning summary saved to {log_file_path}", verbose=verbose, level=1, logger=self._logger, ) except Exception as e: logger.warning( f"Could not save tuning summary log to {log_file_path}: {e}" )
PINNTunerBase.__doc__ = r""" Base class for hyperparameter tuning of physics-informed models. This class wraps keras-tuner orchestration for GeoPrior PINN-style models. Subclasses are expected to implement ``build(hp)``. Parameters ---------- objective : str or keras_tuner.Objective, default "val_loss" Metric to optimize during the search. max_trials : int, default 10 Maximum number of hyperparameter trials to evaluate. project_name : str, default "PINN_Tuning" Project name used for tuner artifacts. directory : str, default "pinn_tuner_results" Root directory for tuner outputs. executions_per_trial : int, default 1 Number of repeated trainings per sampled hyperparameter set. tuner_type : {"randomsearch", "bayesianoptimization", "hyperband"}, default "randomsearch" Search backend used by keras-tuner. seed : int or None, default None Random seed used for reproducibility. overwrite_tuner : bool, default True Whether to overwrite an existing tuner project directory. tuner_kwargs : dict Additional keyword arguments forwarded to the underlying tuner backend. Attributes ---------- best_hps_ : keras_tuner.HyperParameters or None Best hyperparameters discovered during tuning. best_model_ : tf.keras.Model or None Best compiled model recovered from the tuner. tuner_ : keras_tuner.Tuner or None Underlying tuner instance. tuning_summary_ : dict Compact summary of the completed tuning run. """