# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Physics-Informed Neural Network (PINN) Utility functions.
"""
from __future__ import annotations
import logging
import os
import warnings # noqa
from collections.abc import Callable, Sequence
from typing import (
Any,
Final,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from sklearn.preprocessing import MinMaxScaler
from ...api.util import get_table_size
from ...core.checks import (
check_datetime,
exist_features,
)
from ...core.handlers import columns_manager
from ...core.io import SaveFile
from ...decorators import isdf
from ...logging import get_logger
from ...metrics.utils import compute_quantile_diagnostics
from ...utils.data_utils import mask_by_reference
from ...utils.deps_utils import ensure_pkg
from ...utils.forecast_utils import normalize_for_pinn
from ...utils.generic_utils import (
print_box,
rename_dict_keys,
select_mode,
vlog,
)
from ...utils.geo_utils import resolve_spatial_columns
from ...utils.io_utils import save_job
from ...utils.sequence_utils import check_sequence_feasibility
from ...utils.validator import validate_positive_integer
from .. import KERAS_BACKEND, KERAS_DEPS
Model = KERAS_DEPS.Model
Tensor = KERAS_DEPS.Tensor
tf_shape = KERAS_DEPS.shape
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_float32 = KERAS_DEPS.float32
tf_cast = KERAS_DEPS.cast
tf_concat = KERAS_DEPS.concat
tf_expand_dims = KERAS_DEPS.expand_dims
tf_debugging = KERAS_DEPS.debugging
tf_fill = KERAS_DEPS.fill
tf_reshape = KERAS_DEPS.reshape
logger = get_logger(__name__)
_TW = get_table_size()
all__ = [
"format_pihalnet_predictions",
"prepare_pinn_data_sequences",
"format_pinn_predictions",
"extract_txy_in",
"extract_txy",
"plot_hydraulic_head",
"PDE_MODE_ALIASES",
]
_PDE_NONE_ALIASES = {
"none",
"off",
"false",
"0",
"disable",
"disabled",
"",
}
_PDE_BOTH_ALIASES = {
"both",
"on",
"all",
"true",
"1",
}
_PDE_TOKEN_ALIASES = {
"consolidation": "consolidation",
"cons": "consolidation",
"gw_flow": "gw_flow",
"gw": "gw_flow",
"groundwater": "gw_flow",
"groundwater_flow": "gw_flow",
}
_PDE_ORDER = ("consolidation", "gw_flow")
PDE_CANONICAL_MODES: Final[frozenset[str]] = frozenset(
{
"consolidation",
"gw_flow",
"both",
"none",
}
)
PDE_MODE_ALIASES: Final[frozenset[str]] = frozenset(
PDE_CANONICAL_MODES
| _PDE_NONE_ALIASES
| _PDE_BOTH_ALIASES
)
# Optional: strict public set for places where only canonical
# labels should be accepted after normalization.
PDE_ACTIVE_MODES: Final[frozenset[str]] = frozenset(
{
"consolidation",
"gw_flow",
"none",
}
)
def _flatten_mode_input(
value: str | Sequence[str] | None,
) -> list[str]:
"""
Convert the raw pde_mode input into a flat list of lowercase tokens.
"""
if value is None:
return ["none"]
if isinstance(value, str):
# support comma-separated strings like "consolidation,gw_flow"
parts = [p.strip().lower() for p in value.split(",")]
parts = [p for p in parts if p != ""]
return parts or ["none"]
if isinstance(value, list | tuple | set):
out: list[str] = []
for item in value:
if item is None:
out.append("none")
else:
s = str(item).strip().lower()
if s:
out.append(s)
else:
out.append("none")
return out or ["none"]
raise TypeError(
"`pde_mode` must be a string, a sequence of strings, or None."
)
def _normalize_pde_tokens(tokens: Sequence[str]) -> list[str]:
"""
Normalize aliases into canonical tokens.
Canonical tokens are:
- 'none'
- 'consolidation'
- 'gw_flow'
- 'both' (intermediate token, later expanded)
"""
normalized: list[str] = []
for tok in tokens:
if tok in _PDE_NONE_ALIASES:
normalized.append("none")
elif tok in _PDE_BOTH_ALIASES:
normalized.append("both")
elif tok in _PDE_TOKEN_ALIASES:
normalized.append(_PDE_TOKEN_ALIASES[tok])
else:
raise ValueError(
f"Unsupported pde_mode token: {tok!r}. "
"Allowed values are: "
"'none'/'off', 'consolidation', 'gw_flow', 'both'/'on'."
)
return normalized
def _canonicalize_pde_modes(
tokens: Sequence[str],
) -> list[str]:
"""
Convert normalized tokens into the final active-mode list.
Rules
-----
- 'none' cannot be combined with any active PDE mode.
- 'both' expands to ['consolidation', 'gw_flow'].
- duplicates are removed while preserving canonical order.
"""
toks = list(tokens)
has_none = "none" in toks
has_both = "both" in toks # noqa
# Expand "both" first conceptually, but reject ambiguous combinations.
if has_none:
non_none = [t for t in toks if t != "none"]
if non_none:
raise ValueError(
"Ambiguous pde_mode: 'none/off' cannot be combined with "
f"other modes: {non_none!r}."
)
return ["none"]
active = set()
for tok in toks:
if tok == "both":
active.update(("consolidation", "gw_flow"))
elif tok in ("consolidation", "gw_flow"):
active.add(tok)
else:
raise ValueError(
f"Unexpected normalized token: {tok!r}."
)
if not active:
return ["none"]
return [mode for mode in _PDE_ORDER if mode in active]
def _collapse_pde_modes_to_label(
active_modes: Sequence[str],
) -> str:
"""
Convert a canonical active-mode list back to a single label.
"""
modes = list(active_modes)
if modes == ["none"]:
return "none"
if modes == ["consolidation"]:
return "consolidation"
if modes == ["gw_flow"]:
return "gw_flow"
if modes == ["consolidation", "gw_flow"]:
return "both"
raise ValueError(
f"Cannot collapse unexpected active mode list: {modes!r}"
)
[docs]
def process_pde_modes(
pde_mode: str | Sequence[str] | None,
enforce_consolidation: bool = False,
pde_mode_config: str | Sequence[str] | None = None,
solo_return: bool = False,
) -> list[str] | str:
r"""
Normalize and validate PDE mode selection.
Parameters
----------
pde_mode : str, sequence of str, or None
Requested PDE mode(s).
Accepted canonical values are:
- ``"none"``
- ``"consolidation"``
- ``"gw_flow"``
- ``"both"``
Accepted aliases:
- ``None``, ``"off"`` -> ``"none"``
- ``"on"`` -> ``"both"``
enforce_consolidation : bool, default=False
If True, any resolved mode other than exact ``["consolidation"]``
is coerced to ``["consolidation"]`` and a warning is emitted.
This includes:
- ``["none"]``
- ``["gw_flow"]``
- ``["consolidation", "gw_flow"]``
pde_mode_config : str, sequence of str, or None, optional
Optional override. If provided, this value takes precedence over
``pde_mode``.
solo_return : bool, default=False
If False, return a canonical list of active modes.
If True, return a single canonical label:
- ``"none"``
- ``"consolidation"``
- ``"gw_flow"``
- ``"both"``
Returns
-------
list of str or str
Canonical PDE mode(s), either as a list or a single label.
Raises
------
TypeError
If the input type is invalid.
ValueError
If a token is unsupported or the mode selection is ambiguous.
Examples
--------
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
"""
source = (
pde_mode_config
if pde_mode_config is not None
else pde_mode
)
raw_tokens = _flatten_mode_input(source)
normalized_tokens = _normalize_pde_tokens(raw_tokens)
active_modes = _canonicalize_pde_modes(normalized_tokens)
if enforce_consolidation and active_modes != [
"consolidation"
]:
warnings.warn(
"This model requires PDE mode 'consolidation'. "
f"Received {active_modes!r}; coercing to ['consolidation'].",
UserWarning,
stacklevel=2,
)
active_modes = ["consolidation"]
if solo_return:
return _collapse_pde_modes_to_label(active_modes)
return active_modes
def _process_pde_modes(
pde_mode: str | list | None,
enforce_consolidation: bool = False,
pde_mode_config: str | list | None = None,
solo_return: bool = False,
) -> list:
r"""
Process and validate the `pde_mode` argument to determine the active PDE modes.
This function handles `pde_mode` inputs and processes them according to
the following rules:
- If the input is 'none', only the mode 'none' will be active.
- If the input is 'both', it will set both 'consolidation' and 'gw_flow'.
- If the input is not 'consolidation' and `enforce_consolidation` is True,
it will issue a warning and fallback to using only 'consolidation'.
- If `pde_mode_config` is provided, it overrides any other mode setting.
Parameters
----------
pde_mode : str, list of str, or None
The desired PDE modes. Can be:
- A string (e.g., 'consolidation', 'gw_flow', etc.)
- A list of strings (e.g., ['consolidation', 'gw_flow'])
- None (to set no active modes)
enforce_consolidation : bool, default=False
If True, the function ensures that 'consolidation' is the only
mode active and issues a warning if another mode is passed.
pde_mode_config : str, list of str, or None, optional
If provided, overrides the `pde_mode` argument.
Returns
-------
list of str
A list of active PDE modes. The list will contain the modes in lowercase.
If 'none' or 'both' is specified, these will be processed according to the logic.
Raises
------
TypeError
If `pde_mode` is neither a string, a list of strings, nor None.
Warnings
--------
If `enforce_consolidation` is True and a mode other than 'consolidation'
is passed, a warning will be issued, and 'consolidation' will be used
as the active mode instead.
"""
# If pde_mode_config is provided, use it, otherwise fall back to pde_mode
if pde_mode_config:
pde_mode = pde_mode_config
if isinstance(pde_mode, str):
if pde_mode.lower() == "on":
pde_mode = "both"
pde_modes_active = [pde_mode.lower()]
elif isinstance(pde_mode, list):
pde_modes_active = [
str(p_type).lower() for p_type in pde_mode
]
elif pde_mode is None:
pde_modes_active = [
"none"
] # Explicitly 'none' if None is provided
else:
raise TypeError(
"`pde_mode` must be a string, list of strings, or None."
)
# Handle special cases for "none" and "both"
if (
"none" in pde_modes_active
or "off" in pde_modes_active
): # If 'none' is present, override others
pde_modes_active = ["none"]
if (
"both" in pde_modes_active or "on" in pde_modes_active
): # If 'both' is present, use both modes
pde_modes_active = ["consolidation", "gw_flow"]
# Enforce consolidation mode if specified
if (
enforce_consolidation
and "consolidation" not in pde_modes_active
):
warnings.warn(
"You have passed a mode other than 'consolidation'. "
"Falling back to 'consolidation' as the active mode.",
UserWarning,
stacklevel=2,
)
pde_modes_active = ["consolidation"]
# Ensure 'consolidation' is the only active mode if it was defaulted,
# or if user explicitly selected only unsupported modes.
if any(
unsupported in pde_modes_active
for unsupported in {
"consolidation",
"gw_flow",
"both",
"none",
"on",
"off",
}
):
# This case means decorator didn't force 'consolidation', so we do it here
logger.info(
f"Unsupported pde_mode '{pde_mode}' "
"selected without 'consolidation'. "
"Model will use 'consolidation' mode."
)
# pde_modes_active = ['consolidation']
if solo_return:
pde_modes_active = pde_modes_active[0]
# Return the processed pde_modes_active
return pde_modes_active
def _format_target_predictions(
predictions_np: np.ndarray,
num_samples: int,
H: int, # Horizon
O: int, # Output dim for this specific target
base_target_name: str,
quantiles: list[float] | None,
verbose: int = 0,
) -> tuple[list[str], pd.DataFrame]:
"""Helper to format predictions for a single target variable."""
pred_cols_names = []
# Expected input shapes to this helper:
# Point: (N, H, O)
# Quantile: (N, H, Q, O) OR (N, H, Q) if O=1 was pre-squeezed
if quantiles:
num_q = len(quantiles)
# Ensure predictions_np is (N, H, Q, O)
if (
predictions_np.ndim == 3
and predictions_np.shape[-1] == num_q
and O == 1
):
# Case: (N, H, Q), implies O=1
preds_to_process = np.expand_dims(
predictions_np, axis=-1
) # (N,H,Q,1)
elif (
predictions_np.ndim == 3
and predictions_np.shape[-1] == num_q * O
):
# Case: (N, H, Q*O)
preds_to_process = predictions_np.reshape(
(num_samples, H, num_q, O)
)
elif (
predictions_np.ndim == 4
and predictions_np.shape[2] == num_q
and predictions_np.shape[3] == O
):
# Case: (N, H, Q, O) - already correct
preds_to_process = predictions_np
else:
raise ValueError(
f"Unexpected quantile prediction shape for {base_target_name}: "
f"{predictions_np.shape}. Expected compatible with N={num_samples}, "
f"H={H}, Q={num_q}, O={O}"
)
# Now preds_to_process is (N, H, Q, O)
df_data_for_concat = []
for o_idx in range(O):
for q_idx, q_val in enumerate(quantiles):
col_name = f"{base_target_name}"
if O > 1:
col_name += f"_{o_idx}"
col_name += f"_q{int(q_val * 100)}"
pred_cols_names.append(col_name)
df_data_for_concat.append(
preds_to_process[
:, :, q_idx, o_idx
].reshape(-1)
)
pred_df_part = pd.DataFrame(
dict(
zip(
pred_cols_names,
df_data_for_concat,
strict=False,
)
)
)
else: # Point forecast
# predictions_np should be (N, H, O)
if (
predictions_np.ndim != 3
or predictions_np.shape[-1] != O
):
raise ValueError(
f"Unexpected point prediction shape for {base_target_name}: "
f"{predictions_np.shape}. Expected (N,H,O) with O={O}"
)
df_data_for_concat = []
for o_idx in range(O):
col_name = f"{base_target_name}"
if O > 1:
col_name += f"_{o_idx}"
col_name += "_pred"
pred_cols_names.append(col_name)
df_data_for_concat.append(
predictions_np[:, :, o_idx].reshape(-1)
)
pred_df_part = pd.DataFrame(
dict(
zip(
pred_cols_names,
df_data_for_concat,
strict=False,
)
)
)
return pred_cols_names, pred_df_part
def _unpack_geoprior_outputs(
geoprior_outputs: dict[str, Tensor],
output_dims: dict[str, int],
quantiles: list[float] | None,
verbose: int = 0,
_logger: Any = None,
) -> dict[str, Tensor]:
"""
Unpacks the nested 'data_final' tensor from GeoPriorSubsNet
into a flat dictionary expected by the formatter.
"""
if "data_final" not in geoprior_outputs:
vlog(
" [WARN] 'data_final' key not found in GeoPrior model output. "
"Assuming outputs are already flat.",
level=2,
verbose=verbose,
logger=_logger,
)
return geoprior_outputs # Return as-is
data_final_tensor = geoprior_outputs["data_final"]
# Get the split index for subsidence
s_dim = output_dims.get("subs_pred")
if s_dim is None:
vlog(
" [WARN] 'output_dims' did not contain 'subs_pred'. "
"Assuming subsidence output_dim = 1.",
level=2,
verbose=verbose,
logger=_logger,
)
s_dim = 1 # Default to 1 as a fallback
# Split the tensor based on whether quantiles are present
if quantiles is not None:
# Shape is (B, H, Q, O_total)
vlog(
f" Splitting 4D quantile tensor at final axis index {s_dim}.",
level=4,
verbose=verbose,
logger=_logger,
)
s_pred_tensor = data_final_tensor[..., :s_dim]
h_pred_tensor = data_final_tensor[..., s_dim:]
else:
# Shape is (B, H, O_total)
vlog(
f" Splitting 3D point-forecast tensor at final axis index {s_dim}.",
level=4,
verbose=verbose,
logger=_logger,
)
s_pred_tensor = data_final_tensor[..., :s_dim]
h_pred_tensor = data_final_tensor[..., s_dim:]
# Create the new flat dictionary
unpacked_outputs = {
"subs_pred": s_pred_tensor,
"gwl_pred": h_pred_tensor,
}
# Pass through any other keys that aren't the standard nested ones
for k, v in geoprior_outputs.items():
if k not in [
"data_final",
"data_mean",
"phys_mean_raw",
]:
unpacked_outputs[k] = v
return unpacked_outputs
def _inverse_transform_array(
array: np.ndarray,
base_name: str,
scaler_info: dict[str, dict[str, Any]] | None,
verbose: int = 0,
_logger: Any = None,
) -> np.ndarray:
"""
Helper to inverse-transform a single numpy array using scaler_info.
This function finds the correct scaler and column index from
scaler_info and applies the inverse min-max scaling.
"""
if scaler_info is None or base_name not in scaler_info:
# No scaler info for this target, return original array
vlog(
f" - No scaler_info found for '{base_name}'. "
"Returning original (scaled) array.",
level=4,
verbose=verbose,
logger=_logger,
)
return array
try:
info = scaler_info[base_name]
scaler = info["scaler"]
col_index = info["idx"]
# Get the specific min and scale for this column from the scaler
# scaler.min_ is the 'intercept'
# scaler.scale_ is the 'coefficient'
min_val = scaler.min_[col_index]
scale_val = scaler.scale_[col_index]
# --- THIS IS THE FIX ---
# The inverse of (X * scale_) + min_ is (X - min_) / scale_
if scale_val == 0:
# Handle case where scale is zero (all original values were the same)
# The inverse is the original constant value, which we get from data_min_
try:
original_min = scaler.data_min_[col_index]
except AttributeError:
# Fallback if data_min_ is not available (should be on a fitted scaler)
original_min = 0 # This is a guess, but division by zero is worse
vlog(
f" [WARN] Scaler for '{base_name}' has scale_=0 but "
"no 'data_min_' attribute. Inverse may be incorrect.",
level=2,
verbose=verbose,
logger=_logger,
)
inverted_array = np.full_like(array, original_min)
else:
inverted_array = (array - min_val) / scale_val
vlog(
f" - Applied inverse transform to '{base_name}' array "
f"(min_attr: {min_val:.4f}, scale_attr: {scale_val:.4f})",
level=5,
verbose=verbose,
logger=_logger,
)
return inverted_array
except AttributeError:
vlog(
f" [WARN] Scaler for '{base_name}' is missing 'min_' or 'scale_'"
" attributes. Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
except IndexError:
vlog(
f" [WARN] Column index '{col_index}' out of bounds for "
f"scaler of '{base_name}'. Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
except Exception as e:
vlog(
f" [WARN] Failed to inverse transform '{base_name}': {e}"
". Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
def _add_quantiles_to_df(
predictions: dict[str, np.ndarray],
target_mapping: dict[str, str],
quantiles: list[float],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED quantile predictions into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in predictions:
continue
pred_array = predictions[
pred_key
] # Should be (B, H, Q, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = pred_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if pred_array.shape != (
num_samples,
H_inferred,
len(quantiles),
O_target,
):
vlog(
f" [WARN] Quantile array for '{pred_key}' has unexpected shape "
f"{pred_array.shape}. Expected {(num_samples, H_inferred, len(quantiles), O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
# Reshape: (B, H, Q, O) -> (B*H, Q*O)
reshaped_preds = pred_array.reshape(
num_samples * H_inferred, -1
)
# Create column names
col_names = []
for o_idx in range(O_target):
for q in quantiles:
# Format quantile string (e.g., q05, q50, q95)
q_str = f"q{int(q * 100):02d}"
if q_str == "q50":
q_str = "q50" # Standardize
elif q_str == "q10":
q_str = "q10"
elif q_str == "q90":
q_str = "q90"
elif q_str == "q05":
q_str = "q05"
elif q_str == "q95":
q_str = "q95"
col_name = f"{base_name}_{q_str}"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
if reshaped_preds.shape[1] != len(col_names):
vlog(
f" [WARN] Mismatch in quantile columns for '{base_name}'. "
f"Expected {len(col_names)} columns, found {reshaped_preds.shape[1]}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
df_parts.append(
pd.DataFrame(reshaped_preds, columns=col_names)
)
vlog(
f" - Added quantile columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
def _add_point_preds_to_df(
predictions: dict[str, np.ndarray],
target_mapping: dict[str, str],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED point predictions into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in predictions:
continue
pred_array = predictions[pred_key] # (B, H, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = pred_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if pred_array.shape != (
num_samples,
H_inferred,
O_target,
):
vlog(
f" [WARN] Point pred array for '{pred_key}' has unexpected shape "
f"{pred_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
reshaped_preds = pred_array.reshape(
num_samples * H_inferred, O_target
)
col_names = []
for o_idx in range(O_target):
col_name = f"{base_name}_pred"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
df_parts.append(
pd.DataFrame(reshaped_preds, columns=col_names)
)
vlog(
f" - Added point pred columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
def _add_actuals_to_df(
y_true_dict: dict[str, np.ndarray],
target_mapping: dict[str, str],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED actuals into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in y_true_dict:
continue
true_array = y_true_dict[pred_key] # (B, H, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = true_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if true_array.shape != (
num_samples,
H_inferred,
O_target,
):
vlog(
f" [WARN] Actuals array for '{pred_key}' has unexpected shape "
f"{true_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
reshaped_true = true_array.reshape(
num_samples * H_inferred, O_target
)
col_names = []
for o_idx in range(O_target):
col_name = f"{base_name}_actual"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
df_parts.append(
pd.DataFrame(reshaped_true, columns=col_names)
)
vlog(
f" - Added actuals columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
def _add_coords_to_df(
model_inputs: dict[str, Tensor],
coord_scaler: Any,
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
force_future_from: float | None = None,
) -> tuple[pd.DataFrame, list[str]]:
"""Formats and inverse-transforms coordinates."""
coord_names = ["coord_t", "coord_x", "coord_y"]
if (
"coords" not in model_inputs
or model_inputs["coords"] is None
):
vlog(
" 'coords' not in model_inputs. Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(), []
coords_arr = model_inputs["coords"]
if hasattr(coords_arr, "numpy"):
coords_arr = coords_arr.numpy()
if coords_arr.shape != (num_samples, H_inferred, 3):
vlog(
f" 'coords' shape mismatch ({coords_arr.shape}). Expected "
f"{(num_samples, H_inferred, 3)}. Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(), []
coords_reshaped = coords_arr.reshape(
num_samples * H_inferred, 3
)
if coord_scaler is not None:
vlog(
" Applying inverse transform to t,x,y coordinates...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
coords_reshaped = coord_scaler.inverse_transform(
coords_reshaped
)
except Exception as e:
vlog(
f" [WARN] Could not inverse transform coordinates: {e}. "
"Using normalized coordinates.",
level=2,
verbose=verbose,
logger=_logger,
)
# --- override time with forecast years, if requested ---
if force_future_from is not None:
# years_for_one_sample: e.g. [2023, 2024, 2025] for H_inferred=3
years = np.arange(
float(force_future_from),
float(force_future_from) + float(H_inferred),
dtype=coords_reshaped.dtype,
)
# Repeat for all samples and flatten to match (num_samples*H, )
years_tiled = np.tile(years, num_samples)
coords_reshaped[:, 0] = years_tiled
vlog(
f" Overriding coord_t with forecast years starting at "
f"{force_future_from} for horizon={H_inferred}.",
level=4,
verbose=verbose,
logger=_logger,
)
vlog(
f" Added coordinate columns: {coord_names}",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(
coords_reshaped, columns=coord_names
), coord_names
def _add_ids_to_df(
ids_data_array: Any,
ids_cols: list[str] | None,
ids_cols_indices: list[int] | None,
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Extracts and repeats static/ID columns."""
if ids_data_array is None:
vlog(
" No `ids_data_array` provided, skipping static/ID columns.",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
vlog(
" Processing additional static/ID columns...",
level=4,
verbose=verbose,
logger=_logger,
)
ids_np_array = None
ids_cols_to_use = []
if isinstance(ids_data_array, pd.DataFrame):
ids_cols_to_use = (
list(ids_cols)
if ids_cols
else list(ids_data_array.columns)
)
try:
ids_np_array = ids_data_array[
ids_cols_to_use
].values
except KeyError as e:
vlog(
f" [WARN] Columns not in ids_data_array: {e}. Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
elif isinstance(ids_data_array, np.ndarray):
ids_np_array = ids_data_array
if ids_cols_indices:
ids_np_array = ids_np_array[:, ids_cols_indices]
if (
ids_cols
and len(ids_cols) == ids_np_array.shape[1]
):
ids_cols_to_use = list(ids_cols)
else:
ids_cols_to_use = [
f"id_{k}"
for k in range(ids_np_array.shape[1])
]
elif isinstance(ids_data_array, Tensor): # Is a Tensor
ids_np_array = ids_data_array.numpy()
if ids_cols_indices:
ids_np_array = ids_np_array[:, ids_cols_indices]
if (
ids_cols
and len(ids_cols) == ids_np_array.shape[1]
):
ids_cols_to_use = list(ids_cols)
else:
ids_cols_to_use = [
f"id_{k}"
for k in range(ids_np_array.shape[1])
]
else:
vlog(
f" [WARN] Unsupported type for `ids_data_array`: "
f"{type(ids_data_array)}. Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
# Now, repeat the data
if (
ids_np_array is not None
and ids_np_array.shape[0] == num_samples
):
# Use np.repeat to tile each sample H_inferred times
expanded_ids_data = np.repeat(
ids_np_array, H_inferred, axis=0
)
vlog(
f" - Added static/ID columns: {ids_cols_to_use}",
level=5,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(
expanded_ids_data, columns=ids_cols_to_use
)
else:
vlog(
f" [WARN] `ids_data_array` sample size"
f" ({ids_np_array.shape[0] if ids_np_array is not None else 'None'}) "
f"does not match predictions ({num_samples}). Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
def _evaluate_coverage(
df: pd.DataFrame,
target_mapping: dict[str, str],
q_indices: tuple[int, int],
savefile: str | None = None,
name: str | None = None,
verbose: int = 0,
_logger: Any = None,
):
"""Calculates and logs quantile coverage from an INVERTED DataFrame."""
vlog(
" Evaluating forecast coverage...",
level=4,
verbose=verbose,
logger=_logger,
)
for base_name in target_mapping.values():
actual_col = f"{base_name}_actual"
# Find quantile columns
q_cols = sorted(
[
c
for c in df.columns
if c.startswith(base_name) and "_q" in c
]
)
if not q_cols or actual_col not in df.columns:
vlog(
f" Skipping coverage for '{base_name}': "
"Missing actuals or quantile columns.",
level=3,
verbose=verbose,
logger=_logger,
)
continue
try:
q_low_col = q_cols[q_indices[0]]
q_high_col = q_cols[q_indices[1]]
except IndexError:
vlog(
f" [WARN] Coverage q_indices {q_indices} are out of bounds "
f"for detected quantile columns: {q_cols}. Skipping coverage.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
q_low = df[q_low_col]
q_high = df[q_high_col]
actual = df[actual_col]
is_covered = (actual >= q_low) & (actual <= q_high)
coverage_percent = is_covered.mean() * 100
vlog(
f" Forecast Coverage ({base_name}, {q_low_col} to {q_high_col}): "
f"{coverage_percent:.2f}%",
level=1,
verbose=verbose,
logger=_logger,
)
# Store in df attributes
df.attrs[f"{base_name}_coverage"] = coverage_percent
# Use compute_quantile_diagnostics
try:
cov_savefile = None
if savefile:
# If savefile has an extension, treat it as a file path.
# Example: ".../subs_eval.csv"
# -> ".../subs_eval_diagnostics_results.json"
root, ext = os.path.splitext(savefile)
if ext:
cov_savefile = (
f"{root}_diagnostics_results.json"
)
else:
# No extension: treat `savefile` as a directory-like path
# Example: "results/nansha_eval"
# -> "results/nansha_eval/diagnostics_results.json"
cov_savefile = os.path.join(
savefile, "diagnostics_results.json"
)
compute_quantile_diagnostics(
df,
target_name=base_name,
quantiles=[
float(c.split("_q")[-1]) / 100.0
for c in q_cols
],
coverage_quantile_indices=q_indices,
savefile=cov_savefile,
name=name,
verbose=verbose,
logger=_logger,
)
except Exception as e:
vlog(
f" [WARN] Failed to compute quantile diagnostics: {e}",
level=2,
verbose=verbose,
logger=_logger,
)
def _get_model_predictions(
pihalnet_outputs: dict[str, Tensor] | None,
model: Model | None,
model_inputs: dict[str, Tensor] | None,
verbose: int,
stop_check: Callable[[], bool],
_logger: Any,
) -> dict[str, Tensor]:
"""
Obtain model outputs dict, running predict if needed.
"""
if pihalnet_outputs is not None:
return pihalnet_outputs
if model is None or model_inputs is None:
raise ValueError(
"If 'pihalnet_outputs' is None, both 'model' and "
"'model_inputs' must be provided."
)
vlog(
" Predictions not provided, generating from model...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
outputs = model.predict(model_inputs, verbose=0)
if not isinstance(outputs, dict):
raise ValueError(
"Model output is not a dict as expected from PINN models"
)
except Exception as e:
raise RuntimeError(
f"Failed to generate predictions: {e}"
) from e
vlog(
" Model predictions generated.",
level=5,
verbose=verbose,
logger=_logger,
)
if stop_check and stop_check():
raise InterruptedError("Model prediction aborted.")
return outputs
def _convert_outputs_to_numpy(
pihalnet_outputs: dict[str, Any],
verbose: int,
_logger,
) -> dict[str, np.ndarray]:
"""
Ensure prediction tensors are NumPy arrays.
"""
proc: dict[str, np.ndarray] = {}
for key, val in pihalnet_outputs.items():
if key in ("subs_pred", "gwl_pred"):
if hasattr(val, "numpy"):
proc[key] = val.numpy()
elif isinstance(val, np.ndarray):
proc[key] = val
else:
try:
proc[key] = np.array(val)
except Exception as e:
raise TypeError(
f"Could not convert output '{key}' to NumPy. "
f"Type: {type(val)}. Error: {e}"
) from e
elif key == "pde_residual":
# PDE residual can be handled differently if needed
pass # Not typically added to this output DataFrame
if not proc:
vlog(
" No 'subs_pred' or 'gwl_pred'"
" found in outputs. Returning empty DF.",
level=1,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
return proc
def _define_targets(
processed: dict[str, np.ndarray],
include_gwl: bool,
target_mapping: dict[str, str] | None = None,
stop_check: Callable[[], bool] = None,
) -> dict[str, str]:
"""
Map processed outputs to base target names.
"""
if target_mapping is None:
target_mapping = {
"subs_pred": "subsidence",
"gwl_pred": "gwl",
}
targets: dict[str, str] = {}
if "subs_pred" in processed:
targets["subs_pred"] = target_mapping["subs_pred"]
if include_gwl and "gwl_pred" in processed:
targets["gwl_pred"] = target_mapping["gwl_pred"]
if stop_check and stop_check():
raise InterruptedError(
"Target confifuration aborted."
)
return targets
def _infer_dimensions(
first_pred: np.ndarray,
forecast_horizon: int | None,
verbose: int = 0,
_logger: Any = print,
) -> tuple[int, int]:
"""
Infer num_samples and horizon (H) from the array.
"""
N = first_pred.shape[0]
H = forecast_horizon or first_pred.shape[1]
if N == 0 or H == 0:
raise ValueError(
"No samples or zero forecast horizon."
)
vlog(
f" Formatting for {N} samples, Horizon={H}.",
level=4,
verbose=verbose,
logger=_logger,
)
return N, H
def _build_sample_df(
num_samples: int, H: int
) -> pd.DataFrame:
"""
Build DataFrame of sample_idx and forecast_step.
"""
idx = np.repeat(np.arange(num_samples), H)
steps = np.tile(np.arange(1, H + 1), num_samples)
return pd.DataFrame(
{"sample_idx": idx, "forecast_step": steps}
)
def _add_coordinate_columns(
base_df: pd.DataFrame,
model_inputs: dict[str, Any] | None,
coord_scaler: Any,
include_coords: bool,
stop_check: Callable[[], bool],
_logger: Any,
verbose: int,
) -> pd.DataFrame:
"""
Extract and inverse-transform coords if requested.
"""
if not include_coords:
return pd.DataFrame()
coords_arr = None
if model_inputs and "coords" in model_inputs:
coords_arr = model_inputs["coords"]
if hasattr(coords_arr, "numpy"):
coords_arr = coords_arr.numpy()
if coords_arr is None:
return pd.DataFrame()
N = base_df["sample_idx"].nunique()
H = base_df["forecast_step"].max()
if coords_arr.shape != (N, H, 3):
vlog(
" 'coords' shape mismatch or not found."
" Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
resh = coords_arr.reshape(N * H, 3)
if coord_scaler:
vlog(
" Applying inverse transform to t,x,y coordinates...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
resh = coord_scaler.inverse_transform(resh)
except Exception:
pass
if stop_check and stop_check():
raise InterruptedError(
"Coordinates transformation aborted."
)
return pd.DataFrame(
resh, columns=["coord_t", "coord_x", "coord_y"]
)
def _add_id_columns(
base_df: pd.DataFrame,
ids_data_array: Any,
ids_cols: list[str] | None,
ids_cols_indices: list[int] | None,
stop_check: Callable[[], bool],
_logger: Any,
verbose: int,
) -> pd.DataFrame:
"""
Process and repeat static ID columns per forecast step.
Converts DataFrame, Series, NumPy array, or Tensor to
NumPy, selects requested columns or indices, then uses
`sample_idx` from base_df to expand each ID row
across all forecast steps.
"""
# Skip if no IDs provided
if ids_data_array is None:
vlog(
" No `ids_data_array` provided, skipping ID columns.",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
vlog(
" Processing additional static/ID columns...",
level=4,
verbose=verbose,
logger=_logger,
)
# Convert to NumPy array
try:
if isinstance(ids_data_array, pd.DataFrame):
arr = ids_data_array.values
col_names = ids_cols or list(
ids_data_array.columns
)
elif isinstance(ids_data_array, pd.Series):
arr = ids_data_array.to_frame().values
col_names = ids_cols or [
ids_data_array.name or "id_0"
]
if (
arr.ndim == 2
and len(col_names) != arr.shape[1]
):
vlog(
" Series IDs length mismatch, using first col.",
level=5,
verbose=verbose,
logger=_logger,
)
col_names = col_names[:1]
elif hasattr(ids_data_array, "numpy"):
arr = ids_data_array.numpy()
col_names = ids_cols or [
f"id_{i}" for i in range(arr.shape[1])
]
elif isinstance(ids_data_array, np.ndarray):
arr = ids_data_array
if ids_cols_indices is not None:
arr = arr[:, ids_cols_indices]
col_names = ids_cols or [
f"id_{i}" for i in range(arr.shape[1])
]
else:
_logger.warning(
f"Unsupported IDs type {type(ids_data_array)}."
" Skipping ID columns."
)
return pd.DataFrame()
except Exception as e:
_logger.warning(
f"Error converting IDs to array: {e}."
" Skipping ID columns."
)
return pd.DataFrame()
vlog(
f" Converted IDs to NumPy with shape {arr.shape}",
level=5,
verbose=verbose,
logger=_logger,
)
vlog(
f" Selected ID column names: {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
# Expand rows according to sample_idx
sample_indices = base_df["sample_idx"].values
if arr.shape[0] != sample_indices.max() + 1:
_logger.warning(
f"IDs rows ({arr.shape[0]}) != samples "
f"({sample_indices.max() + 1}). Skipping."
)
return pd.DataFrame()
expanded = arr[sample_indices]
id_df = pd.DataFrame(expanded, columns=col_names)
vlog(
f" Expanded ID DataFrame shape: {id_df.shape}",
level=5,
verbose=verbose,
logger=_logger,
)
# User cancellation check
if stop_check and stop_check():
raise InterruptedError(
"ID columns processing aborted."
)
return id_df
# # Logic to convert ids_data_array to 2D NumPy
# # and repeat according to base_df indices
# # ... (Detailed implementation)
# return pd.DataFrame() # stub for brevity
def _process_target_variable(
preds: np.ndarray,
pred_key: str,
base_name: str,
num_samples: int,
H: int,
quantiles: list[float] | None,
y_true_dict: dict[str, Any] | None,
scaler_info: dict[str, Any] | None,
verbose: int,
stop_check: Callable[[], bool],
_logger: Any,
) -> pd.DataFrame:
"""
Format predictions, attach actuals, and prepare scaling.
- Calls _format_target_predictions for quantiles
or point forecasts.
- Reshapes and adds true values if provided.
- Records inverse-scaling actions to apply later.
"""
# 1. Format predictions
cols_pred, df_pred = _format_target_predictions(
preds,
num_samples,
H,
O=(preds.shape[-1] if preds.ndim == 3 else 1),
base_target_name=base_name,
quantiles=quantiles,
verbose=verbose,
)
dfs: list[pd.DataFrame] = [df_pred]
# 2. Add actual values
if y_true_dict:
y_true = y_true_dict.get(pred_key)
if y_true is None:
y_true = y_true_dict.get(base_name)
if hasattr(y_true, "numpy"):
y_true = y_true.numpy()
if y_true is not None:
arr = y_true.reshape(num_samples * H, -1)
cols_act = []
for i in range(arr.shape[1]):
col = base_name
if arr.shape[1] > 1:
col += f"_{i}"
cols_act.append(f"{col}_actual")
dfs.append(pd.DataFrame(arr, columns=cols_act))
# 3. Prepare inverse-scaling info
if scaler_info and base_name in scaler_info:
info = scaler_info[base_name]
to_transform = cols_pred + cols_act
info["_cols_to_inv"] = to_transform
if stop_check and stop_check():
raise InterruptedError(
"Target variable processing aborted."
)
return pd.concat(dfs, axis=1)
def _apply_masking(
df: pd.DataFrame,
targets: dict[str, str],
apply_mask: bool,
mask_values: Any,
mask_fill_value: float,
quantiles: list[float] | None,
) -> pd.DataFrame:
"""
Mask forecasts based on reference column values.
"""
if not apply_mask:
return df
first_base = list(targets.values())[0]
ref = f"{first_base}_actual"
if ref not in df:
raise KeyError(f"Reference col '{ref}' missing")
cols = []
if quantiles:
for base in targets.values():
for q in quantiles:
cols.append(f"{base}_q{int(q * 100)}")
else:
cols = [c for c in df if c.endswith("_pred")]
return mask_by_reference(
data=df,
ref_col=ref,
values=mask_values,
find_closest=False,
fill_value=mask_fill_value,
mask_columns=cols,
error="ignore",
inplace=False,
)
def _compute_coverage(
df: pd.DataFrame,
targets: dict[str, str],
quantiles: list[float],
y_true_dict: dict[str, Any],
coverage_quantile_indices: tuple[int, int],
savefile: str | None,
name: str | None,
verbose: int,
_logger: Any,
) -> None:
"""
Compute quantile diagnostics per target.
"""
for base in targets.values():
try:
compute_quantile_diagnostics(
df,
target_name=base,
quantiles=quantiles,
coverage_quantile_indices=coverage_quantile_indices,
savefile=(
os.path.join(
os.path.dirname(savefile),
"diagnostics_results.json",
)
if savefile
else None
),
name=name,
verbose=verbose,
logger=_logger,
)
except Exception as e:
vlog(
f"Skipping coverage due to {e}",
level=2,
verbose=verbose,
logger=_logger,
)
# XXX : TODO
# revise format outputs and apply
# mask by reference
[docs]
@isdf
def prepare_pinn_data_sequences(
df: pd.DataFrame,
time_col: str,
subsidence_col: str,
gwl_col: str,
dynamic_cols: list[str],
static_cols: list[str] | None = None,
future_cols: list[str] | None = None,
spatial_cols: tuple[str, str] | None = None,
h_field_col: str | None = None,
lon_col: str | None = None,
lat_col: str | None = None,
group_id_cols: list[str] | None = None,
time_steps: int = 12,
forecast_horizon: int = 3,
output_subsidence_dim: int = 1,
output_gwl_dim: int = 1,
datetime_format: str | None = None,
normalize_coords: bool = True,
cols_to_scale: list[str] | str | None = None,
lock_physics_cols: bool = True,
protect_si_suffix: str = "__si",
return_coord_scaler: bool = False,
coord_scaler: MinMaxScaler | None = None,
fit_coord_scaler: bool = True,
mode: str | None = None,
model: str | None = None,
savefile: str | None = None,
progress_hook: Callable[[float], None] | None = None,
stop_check: Callable[[], bool] = None,
verbose: int = 0,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
**kws,
) -> (
tuple[dict[str, np.ndarray], dict[str, np.ndarray]]
| tuple[
dict[str, np.ndarray],
dict[str, np.ndarray],
MinMaxScaler | None,
]
):
# -------------------------------------------------------------------------
_to_range = lambda f, lo, hi: lo + (hi - lo) * f # noqa: E731
vlog(
"Starting PINN data sequence preparation...",
verbose=verbose,
level=1,
logger=_logger,
)
df_proc = df.copy()
# --- 1. Validate Input Parameters and Columns ---
if verbose >= 2:
logger.debug(
"Validating input parameters and columns."
)
vlog(
"Validating essential columns...",
verbose=verbose,
level=2,
logger=_logger,
)
lon_col, lat_col = resolve_spatial_columns(
df_proc,
spatial_cols=spatial_cols,
lon_col=lon_col,
lat_col=lat_col,
)
essential_cols = [
time_col,
lon_col,
lat_col,
subsidence_col,
gwl_col,
]
# --- MODIFICATION 1: Conditional Validation ---
is_geoprior = str(model).lower().strip() in (
"geoprior",
"geopriorsubsnet",
)
if is_geoprior:
if h_field_col is None:
# Check for default names
if "H_field" in df_proc.columns:
h_field_col = "H_field"
elif "soil_thickness" in df_proc.columns:
h_field_col = "soil_thickness"
else:
raise ValueError(
"`model` is 'geoprior' but `h_field_col` was not "
"provided and default names 'H_field' or 'soil_thickness' "
"were not found in the DataFrame."
)
vlog(
f"GeoPrior model detected. Using '{h_field_col}' as H_field.",
verbose=verbose,
level=3,
logger=_logger,
)
essential_cols.append(h_field_col)
vlog(
f"Validated '{h_field_col}' for GeoPriorSubsNet.",
verbose=verbose,
level=3,
logger=_logger,
)
# --- End Modification 1 ---
exist_features(
df_proc,
features=essential_cols,
message="Essential column(s) missing.",
)
vlog(
"Validating time-series dataset...",
verbose=verbose,
level=2,
logger=_logger,
)
check_datetime(
df_proc,
dt_cols=time_col,
ops="check_only",
consider_dt_as="numeric",
accept_dt=True,
allow_int=True,
)
vlog(
"Managing feature column lists...",
verbose=verbose,
level=3,
logger=_logger,
)
dynamic_cols = columns_manager(
dynamic_cols, empty_as_none=False
)
exist_features(
df_proc,
features=dynamic_cols,
name="Dynamic feature column(s)",
)
static_cols = columns_manager(static_cols)
if static_cols:
exist_features(
df_proc,
features=static_cols,
name="Static feature column(s)",
)
future_cols = columns_manager(future_cols)
if future_cols:
exist_features(
df_proc,
features=future_cols,
name="Future feature column(s)",
)
group_id_cols = columns_manager(group_id_cols)
if group_id_cols:
exist_features(
df_proc,
features=group_id_cols,
name="Group ID column(s)",
)
vlog(
"Validating time_steps and forecast_horizon...",
verbose=verbose,
level=3,
logger=_logger,
)
time_steps = validate_positive_integer(
time_steps, "time_steps"
)
forecast_horizon = validate_positive_integer(
forecast_horizon,
"forecast_horizon",
)
vlog(
"Converting time column to numeric values...",
verbose=verbose,
level=4,
logger=_logger,
)
if pd.api.types.is_numeric_dtype(df_proc[time_col]):
numerical_time_col = time_col
if verbose >= 2:
logger.debug(
f"Time column '{time_col}' is already numeric. Using it directly."
)
else:
try:
df_proc[time_col] = pd.to_datetime(
df_proc[time_col], format=datetime_format
)
df_proc[f"{time_col}_numeric"] = df_proc[
time_col
].dt.year + (
df_proc[time_col].dt.dayofyear - 1
) / (
365
+ df_proc[time_col].dt.is_leap_year.astype(
int
)
)
numerical_time_col = f"{time_col}_numeric"
if verbose >= 2:
logger.debug(
f"Converted datetime column '{time_col}'"
f" to numerical '{numerical_time_col}'."
)
vlog(
f"Time column converted to '{numerical_time_col}'",
verbose=verbose,
level=5,
logger=_logger,
)
except Exception as e:
raise ValueError(
f"Failed to convert or process time column '{time_col}'. "
f"Ensure it's datetime-like or specify `datetime_format`."
f" Error: {e}"
)
vlog(
"Pre-flight: assessing sliding-window feasibility ...",
verbose=verbose,
level=1,
logger=_logger,
)
ok, _ = check_sequence_feasibility(
df_proc.copy(),
time_col=time_col,
group_id_cols=group_id_cols,
time_steps=time_steps,
forecast_horizon=forecast_horizon,
verbose=verbose,
error="raise", # fail fast on impossibility,
logger=_logger,
)
vlog(
"Feasibility check passed — generating sequences...",
verbose=verbose,
level=1,
logger=_logger,
)
vlog(
"Starting PINN data sequence generation...",
verbose=verbose,
level=1,
logger=_logger,
)
mode = select_mode(mode, default="pihal_like")
vlog(
f"Operating in '{mode}' data preparation mode.",
level=1,
verbose=verbose,
logger=_logger,
)
exclude_cols = []
if lock_physics_cols:
exclude_cols += [subsidence_col, gwl_col]
if h_field_col is not None:
exclude_cols.append(h_field_col)
# Also exclude any "__si" columns automatically (safe for Stage-1 outputs)
if protect_si_suffix:
exclude_cols += [
c
for c in df_proc.columns
if str(c).endswith(protect_si_suffix)
]
# IMPORTANT: do NOT shift time in normalization for PINN
df_proc, coord_scaler, cols_scaler = normalize_for_pinn(
df=df_proc,
time_col=numerical_time_col,
coord_x=lon_col,
coord_y=lat_col,
scale_coords=normalize_coords,
cols_to_scale=cols_to_scale,
forecast_horizon=forecast_horizon,
shift_time_by_horizon=False,
exclude_cols=exclude_cols,
protect_si_suffix=protect_si_suffix,
verbose=verbose,
_logger=_logger,
coord_scaler=coord_scaler,
fit_coord_scaler=fit_coord_scaler,
)
# --- 2. Group and Sort Data ---
vlog(
"Grouping and sorting data...",
verbose=verbose,
level=2,
logger=_logger,
)
if verbose >= 2:
logger.debug(
f"Grouping and sorting data. Group IDs:"
f" {group_id_cols or 'None (single group)'}."
)
sort_by_cols = [numerical_time_col]
if group_id_cols:
sort_by_cols = group_id_cols + sort_by_cols
df_proc = df_proc.sort_values(
by=sort_by_cols
).reset_index(drop=True)
if group_id_cols:
grouped_data = df_proc.groupby(group_id_cols)
group_keys = list(grouped_data.groups.keys())
if verbose >= 1:
logger.info(
f"Data grouped by {group_id_cols} into {len(group_keys)} groups."
)
else:
grouped_data = [(None, df_proc)]
group_keys = [None]
if verbose >= 1:
logger.info(
"Processing entire DataFrame as a single group."
)
# --- 3. First Pass: Calculate Total Number of Sequences ---
total_sequences = 0
min_len_per_group = time_steps + forecast_horizon
valid_group_dfs = []
vlog(
"Counting valid sequences in groups...",
verbose=verbose,
level=4,
logger=_logger,
)
if verbose >= 2:
logger.debug(
"First pass: Calculating total sequences. Min"
f" length per group: {min_len_per_group}."
)
n_groups = len(group_keys) or 1
for g_idx, group_key in enumerate(group_keys):
group_df = (
grouped_data.get_group(group_key)
if group_id_cols
else df_proc
)
if stop_check and stop_check():
raise InterruptedError(
"Sequence generation aborted."
)
key_str = (
group_key
if group_key is not None
else "<Full Dataset>"
)
if progress_hook is not None:
progress_hook(
_to_range((g_idx + 1) / n_groups, 0.0, 0.50)
)
if len(group_df) < min_len_per_group:
if verbose >= 5:
logger.info(
f"Group '{key_str}' has {len(group_df)} points, less than "
f"min required ({min_len_per_group}). Skipping."
)
vlog(
f"Group {key_str} too small:"
f" {len(group_df)} < {min_len_per_group}",
verbose=verbose,
level=6,
logger=_logger,
)
continue
num_seq_in_group = (
len(group_df) - min_len_per_group + 1
)
total_sequences += num_seq_in_group
valid_group_dfs.append(group_df)
if verbose >= 5:
logger.debug(
f"Group '{group_key if group_key else '<Full Dataset>'}' "
f"will yield {num_seq_in_group} sequences."
)
vlog(
f"Group {key_str} yields {total_sequences} seqs.",
verbose=verbose,
level=6,
logger=_logger,
)
if total_sequences == 0:
raise ValueError(
"No group has enough data points to create sequences with "
f"time_steps={time_steps} and forecast_horizon={forecast_horizon}."
)
if verbose >= 1:
logger.info(
"Total valid sequences to be"
f" generated: {total_sequences}."
)
vlog(
f"Total sequences: {total_sequences}",
verbose=verbose,
level=1,
logger=_logger,
)
# --- 4. Pre-allocate NumPy Arrays ---
vlog(
"Pre-allocating arrays...",
verbose=verbose,
level=2,
logger=_logger,
)
num_dynamic_feats = len(dynamic_cols)
num_static_feats = len(static_cols) if static_cols else 0
num_future_feats = len(future_cols) if future_cols else 0
coords_horizon_arr = np.zeros(
(total_sequences, forecast_horizon, 3),
dtype=np.float32,
)
static_features_arr = np.zeros(
(total_sequences, num_static_feats), dtype=np.float32
)
dynamic_features_arr = np.zeros(
(total_sequences, time_steps, num_dynamic_feats),
dtype=np.float32,
)
if mode == "tft_like":
future_window_len = time_steps + forecast_horizon
vlog(
f"Allocating future features array for 'tft_like' mode "
f"with time dimension: {future_window_len}",
level=2,
verbose=verbose,
logger=_logger,
)
else:
future_window_len = forecast_horizon
future_features_arr = np.zeros(
(
total_sequences,
future_window_len,
num_future_feats,
),
dtype=np.float32,
)
target_subsidence_arr = np.zeros(
(
total_sequences,
forecast_horizon,
output_subsidence_dim,
),
dtype=np.float32,
)
target_gwl_arr = np.zeros(
(total_sequences, forecast_horizon, output_gwl_dim),
dtype=np.float32,
)
# --- MODIFICATION 2: Conditional Allocation ---
H_field_arr = None # Initialize as None
if is_geoprior:
H_field_arr = np.zeros(
(total_sequences, forecast_horizon, 1),
dtype=np.float32,
)
vlog(
f"Allocated 'H_field' array with shape {H_field_arr.shape}",
verbose=verbose,
level=4,
logger=_logger,
)
# --- End Modification 2 ---
vlog(
"Arrays shapes set.",
verbose=verbose,
level=5,
logger=_logger,
)
if verbose >= 2:
logger.debug(
"Pre-allocated NumPy arrays for sequence data:"
)
logger.debug(
f" Coords Horizon: {coords_horizon_arr.shape}"
)
logger.debug(
f" Static Features: {static_features_arr.shape}"
)
logger.debug(
f" Dynamic Features: {dynamic_features_arr.shape}"
)
logger.debug(
f" Future Features: {future_features_arr.shape}"
)
logger.debug(
f" Target Subsidence: {target_subsidence_arr.shape}"
)
logger.debug(f" Target GWL: {target_gwl_arr.shape}")
# --- MODIFICATION 2b: Conditional Logging ---
if H_field_arr is not None:
logger.debug(
f" H_field ({h_field_col}): {H_field_arr.shape}"
)
# --- End Modification 2b ---
# --- 5. Second Pass: Populate Arrays with Rolling Windows ---
current_seq_idx = 0
if verbose >= 2:
logger.debug(
"Second pass: Populating sequence arrays..."
)
vlog(
"Populating arrays with data...",
verbose=verbose,
level=2,
logger=_logger,
)
total_seq = (
sum(
len(gdf) - min_len_per_group + 1
for gdf in valid_group_dfs
)
or 1
)
done_seq = 0
for group_df in valid_group_dfs:
group_t_coords = group_df[numerical_time_col].values
group_x_coords = group_df[lon_col].values
group_y_coords = group_df[lat_col].values
t_min_group, t_max_group = (
group_t_coords.min(),
group_t_coords.max(),
)
if verbose >= 3:
print()
time_scale_info = (
f"{t_min_group:.4f}-{t_max_group:.4f}"
)
if normalize_coords and coord_scaler:
time_scale_info += " (normalized)"
else:
time_scale_info += " (original scale)"
print_box(
f"Group window t: {time_scale_info}",
width=_TW,
align="center",
border_char="+",
horizontal_char="-",
vertical_char="|",
padding=1,
)
group_static_vals = None
if static_cols and num_static_feats > 0:
group_static_vals = group_df.iloc[0][
static_cols
].values.astype(np.float32)
# --- MODIFICATION 3: Get static H_val for the group ---
group_H_val = None
if is_geoprior:
# Get the single static soil thickness value for this group
group_H_val = group_df.iloc[0][h_field_col]
# --- End Modification 3 ---
num_seq_in_this_group = (
len(group_df) - min_len_per_group + 1
)
for i in range(num_seq_in_this_group):
done_seq += 1
if progress_hook is not None:
progress_hook(
_to_range(done_seq / total_seq, 0.5, 1.00)
)
if stop_check and stop_check():
raise InterruptedError(
"Sequence generation aborted by user"
)
if static_cols and num_static_feats > 0:
static_features_arr[current_seq_idx] = (
group_static_vals
)
dynamic_start_idx = i
dynamic_end_idx = i + time_steps
dynamic_features_arr[current_seq_idx] = (
group_df.iloc[
dynamic_start_idx:dynamic_end_idx
][dynamic_cols].values.astype(np.float32)
)
horizon_start_idx = i + time_steps
horizon_end_idx = (
i + time_steps + forecast_horizon
)
if future_cols and num_future_feats > 0:
if mode == "tft_like":
future_start_idx = i
future_end_idx = (
i + time_steps + forecast_horizon
)
else: # 'pihal_like' mode
future_start_idx = horizon_start_idx
future_end_idx = horizon_end_idx
future_features_arr[current_seq_idx] = (
group_df.iloc[
future_start_idx:future_end_idx
][future_cols].values.astype(np.float32)
)
t_horizon = group_t_coords[
horizon_start_idx:horizon_end_idx
]
x_horizon = group_x_coords[
horizon_start_idx:horizon_end_idx
]
y_horizon = group_y_coords[
horizon_start_idx:horizon_end_idx
]
coords_horizon_arr[current_seq_idx, :, 0] = (
t_horizon
)
coords_horizon_arr[current_seq_idx, :, 1] = (
x_horizon
)
coords_horizon_arr[current_seq_idx, :, 2] = (
y_horizon
)
target_subsidence_arr[current_seq_idx] = (
group_df.iloc[
horizon_start_idx:horizon_end_idx
][subsidence_col]
.values.reshape(
forecast_horizon, output_subsidence_dim
)
.astype(np.float32)
)
target_gwl_arr[current_seq_idx] = (
group_df.iloc[
horizon_start_idx:horizon_end_idx
][gwl_col]
.values.reshape(
forecast_horizon, output_gwl_dim
)
.astype(np.float32)
)
# --- MODIFICATION 4: Populate H_field array ---
if H_field_arr is not None:
# Tile the static H_val across the forecast horizon
H_field_arr[current_seq_idx, :, 0] = (
group_H_val
)
# --- End Modification 4 ---
if verbose >= 7:
logger.debug(f" Sequence {current_seq_idx}:")
logger.debug(
" Dynamic window:"
f" {dynamic_start_idx}-{dynamic_end_idx - 1}"
)
logger.debug(
" Horizon window:"
f" {horizon_start_idx}-{horizon_end_idx - 1}"
)
logger.debug(
f" Coords (first step):"
f" {coords_horizon_arr[current_seq_idx, 0, :]}"
)
vlog(
f"Seq {current_seq_idx}:"
f" dyn {dynamic_start_idx}-{dynamic_end_idx - 1},"
f"hzn {horizon_start_idx}-{horizon_end_idx - 1}",
verbose=verbose,
level=7,
logger=_logger,
)
current_seq_idx += 1
if verbose >= 1:
logger.info("Successfully populated sequence arrays.")
vlog(
"Data population complete.",
verbose=verbose,
level=1,
logger=_logger,
)
inputs_dict = {
"coords": coords_horizon_arr,
"static_features": static_features_arr
if num_static_feats > 0
else None,
"dynamic_features": dynamic_features_arr,
"future_features": future_features_arr
if num_future_feats > 0
else None,
}
# --- MODIFICATION 5: Conditionally add H_field to inputs_dict ---
if H_field_arr is not None:
inputs_dict["H_field"] = H_field_arr
# --- End Modification 5 ---
inputs_dict = {
k: v for k, v in inputs_dict.items() if v is not None
}
targets_dict = {
"subsidence": target_subsidence_arr,
"gwl": target_gwl_arr,
}
if savefile:
vlog(
f"\nPreparing to save sequence data to '{savefile}'...",
verbose=verbose,
level=3,
logger=_logger,
)
# --- v3.2: allow split between GWL dynamic driver and GWL prediction target ---
gwl_dyn_col = kws.get("gwl_dyn_col", None) or gwl_col
if gwl_dyn_col not in dynamic_cols:
raise ValueError(
f"gwl_dyn_col={gwl_dyn_col!r} must be present in dynamic_cols.\n"
f"Got gwl_col(target)={gwl_col!r} and dynamic_cols={dynamic_cols}.\n"
"GeoPrior v3.2 expects: gwl_dyn_col=<depth__si>, gwl_col=<head__si>."
)
gwl_dyn_index = int(dynamic_cols.index(gwl_dyn_col))
job_dict = {
"static_data": static_features_arr,
"dynamic_data": dynamic_features_arr,
"future_data": future_features_arr,
"subsidence": target_subsidence_arr,
"gwl": target_gwl_arr,
"static_features": static_cols,
"dynamic_features": dynamic_cols,
"future_features": future_cols,
"inputs_dict": inputs_dict,
"targets_dict": targets_dict,
"subsidence_col": subsidence_col,
"spatial_features": spatial_cols,
"lon_col": lon_col,
"lat_col": lat_col,
"time_col": time_col,
"time_steps": time_steps,
"forecast_horizon": forecast_horizon,
"cols_scaler": cols_scaler,
"coord_scaler": coord_scaler,
"normalize_coords_flag": normalize_coords,
"saved_coord_scaler_flag": coord_scaler
is not None,
# --- Conditionally add to savefile ---
"model_type": model,
"h_field_col": h_field_col,
"H_field": H_field_arr
if H_field_arr is not None
else None,
"gwl_col": gwl_col, # target column (head)
"gwl_dyn_col": gwl_dyn_col, # dynamic driver column (depth)
"gwl_dyn_index": gwl_dyn_index, # index of the driver in dynamic_cols
}
try:
job_dict.update(get_versions())
except NameError:
vlog(
"\n `get_versions` not found, version info not saved.",
verbose=verbose,
level=1,
logger=_logger,
)
try:
save_job(
job_dict, savefile, append_versions=False
)
if verbose >= 1:
vlog(
f"Sequence data dictionary successfully "
f"saved to '{savefile}'.",
verbose=verbose,
level=1,
logger=_logger,
)
except Exception as e:
vlog(
f"Failed to save job dictionary to "
f"'{savefile}': {e}",
verbose=verbose,
level=1,
logger=_logger,
)
if verbose >= 1:
logger.info(
"PINN data sequence preparation completed."
)
if verbose >= 3:
for key, arr in inputs_dict.items():
logger.debug(
" Final input '{key}' shape:"
f" {arr.shape if arr is not None else 'None'}"
)
for key, arr in targets_dict.items():
logger.debug(
f" Final target '{key}' shape: {arr.shape}"
)
vlog(
"PINN data sequence preparation successfully completed.",
verbose=verbose,
level=3,
logger=_logger,
)
if return_coord_scaler:
return inputs_dict, targets_dict, coord_scaler
return inputs_dict, targets_dict
[docs]
def check_and_rename_keys(
inputs, y
): # ranem to check_input_keys
r"""
Helper function to check and rename keys in the inputs
and target dictionaries.
This function ensures that the necessary keys are present in both the
`inputs` and `y` dictionaries. If the keys for 'subsidence' or 'gwl'
are not found, it attempts to rename them from possible alternatives
like 'subs_pred' or 'gwl_pred'.
Parameters
----------
inputs : dict
A dictionary containing the input data. The keys 'coords' and
'dynamic_features' are expected.
y : dict
A dictionary containing the target values. The keys 'subsidence'
and 'gwl' are expected, but they could also appear as 'subs_pred'
or 'gwl_pred'.
Raises
------
ValueError
If required keys are missing in `inputs` or `y`, or if renaming
does not result in valid keys for 'subsidence' and 'gwl'.
"""
# Check if 'coords' and 'dynamic_features' are in inputs
if "coords" not in inputs or inputs["coords"] is None:
raise ValueError("Input 'coords' is missing or None.")
if (
"dynamic_features" not in inputs
or inputs["dynamic_features"] is None
):
raise ValueError(
"Input 'dynamic_features' is missing or None."
)
# Check for 'subsidence' in y, allow renaming from 'subs_pred'
# just check whether subsidence or subs_pred in y
if "subsidence" not in y and "subs_pred" in y:
y["subsidence"] = y.pop("subs_pred")
if "subsidence" not in y:
# here by explicit to hel the
raise ValueError(
"Target 'subsidence' is missing or None."
)
# use that user should provide subsidene or subs_pred
# Check for 'gwl' in y, allow renaming from 'gwl_pred'
# just check whether gwl or gwl_pred one,its in the y
#
if "gwl_pred" not in y and "gwl" in y:
# no need to rename yet, later this will handle si just check only
y["gwl_pred"] = y.pop("gwl")
if "gwl" not in y:
raise ValueError("Target 'gwl' is missing or None.")
return inputs, y
def _check_required_input_keys(
inputs,
y=None,
message=None,
):
r"""
Helper function to check and rename keys in the inputs
and target dictionaries.
This function ensures that the necessary keys are present in both the
`inputs` and `y` dictionaries. If the keys for 'subsidence' or 'gwl'
are not found, it attempts to rename them from possible alternatives
like 'subs_pred' or 'gwl_pred'.
Parameters
----------
inputs : dict
A dictionary containing the input data. The keys 'coords' and
'dynamic_features' are expected.
y : dict
A dictionary containing the target values. The keys 'subsidence'
and 'gwl' are expected, but they could also appear as 'subs_pred'
or 'gwl_pred'.
message : str, optional
Message to raise error when inputs/y are not dictionnary.
Raises
------
ValueError
If required keys are missing in `inputs` or `y`, or if renaming
does not result in valid keys for 'subsidence' and 'gwl'.
"""
if inputs is not None:
if not isinstance(inputs, dict):
message = message or (
"Inputs must be a dictionnary containing"
" 'coords' and 'dynamic_features'."
f" Got {type(inputs).__name__!r}"
)
raise TypeError(message)
# Check if 'coords' and 'dynamic_features' are in inputs
if "coords" not in inputs or inputs["coords"] is None:
raise ValueError(
"Input 'coords' is missing or None."
)
if (
"dynamic_features" not in inputs
or inputs["dynamic_features"] is None
):
raise ValueError(
"Input 'dynamic_features' is missing or None."
)
if y is not None:
if not isinstance(y, dict):
message = message or (
"Target `y` must be a dictionnary containing"
" 'subs_pred/subsidence' and 'gwl/gwl_red'."
f" Got {type(y).__name__!r}"
)
raise TypeError(message)
# Check for 'subsidence' in y, allow renaming from 'subs_pred'
if "subsidence" not in y and "subs_pred" not in y:
raise ValueError(
"Target 'subsidence' is missing or None."
" Please provide 'subsidence' or 'subs_pred'."
)
# Check for 'gwl' in y, allow renaming from 'gwl_pred'
if "gwl" not in y and "gwl_pred" not in y:
raise ValueError(
"Target 'gwl' is missing or None."
" Please provide 'gwl' or 'gwl_pred'."
)
return inputs, y
def _extract_txy_in(
inputs: Tensor
| np.ndarray
| dict[str, Tensor | np.ndarray],
coord_slice_map: dict[str, int] | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
r"""
Extracts t, x, y tensors from `inputs`, which may be:
- A single 3D tensor of shape (batch, time_steps, 3)
- A dict containing a key 'coords' with such a tensor
- A dict containing separate keys 't', 'x', and 'y'
Parameters
----------
inputs : tf.Tensor or np.ndarray or dict
If tensor/array: expected shape is (batch, time_steps, 3).
If dict:
- If 'coords' in dict: dict['coords'] must be (batch, time_steps, 3).
- Otherwise, dict must have keys 't', 'x', 'y' each of shape
(batch, time_steps, 1) or (batch, time_steps).
coord_slice_map : dict, optional
Mapping from 't', 'x', 'y' to their index in the last dimension of
the coords tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
Returns
-------
t : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to time coordinate.
x : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to x coordinate.
y : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to y coordinate.
Raises
------
ValueError
If `inputs` is not in one of the supported formats, or dimensions
are inconsistent.
"""
# Default slice map
if coord_slice_map is None:
coord_slice_map = {"t": 0, "x": 1, "y": 2}
# Helper to ensure output is a tf.Tensor with a final singleton dim
def _ensure_tensor_with_last_dim(
inp: Tensor | np.ndarray,
) -> Tensor:
if isinstance(inp, np.ndarray):
inp = tf_convert_to_tensor(inp)
if not isinstance(inp, Tensor):
raise ValueError(
f"Expected tf.Tensor or np.ndarray, got {type(inp)}"
)
# If shape is (batch, time_steps), add last dimension
if inp.ndim == 2:
inp = tf_expand_dims(inp, axis=-1)
# Now expect (batch, time_steps, 1)
if inp.ndim != 3 or inp.shape[-1] != 1:
raise ValueError(
"Coordinate array must have shape "
f"(batch, time_steps, 1); got {inp.shape}"
)
return inp
# Case 1: inputs is a dict
if isinstance(inputs, dict):
# If 'coords' key is present
if "coords" in inputs:
coords_tensor = inputs["coords"]
if isinstance(coords_tensor, np.ndarray):
coords_tensor = tf_convert_to_tensor(
coords_tensor
)
if not isinstance(coords_tensor, Tensor):
raise ValueError(
f"Expected tensor/array for 'coords';"
f" got {type(coords_tensor)}"
)
# Expect shape (batch, time_steps, 3)
if (
coords_tensor.ndim != 3
or coords_tensor.shape[-1] < 3
):
raise ValueError(
f"'coords' must have shape (batch, time_steps, ≥3);"
f" got {coords_tensor.shape}"
)
# Slice out t, x, y
t = coords_tensor[
...,
coord_slice_map["t"] : coord_slice_map["t"]
+ 1,
]
x = coords_tensor[
...,
coord_slice_map["x"] : coord_slice_map["x"]
+ 1,
]
y = coords_tensor[
...,
coord_slice_map["y"] : coord_slice_map["y"]
+ 1,
]
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
# If keys 't','x','y' exist separately
if all(k in inputs for k in ("t", "x", "y")):
t = _ensure_tensor_with_last_dim(inputs["t"])
x = _ensure_tensor_with_last_dim(inputs["x"])
y = _ensure_tensor_with_last_dim(inputs["y"])
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
raise ValueError(
"Dict `inputs` must contain either key 'coords' or keys 't', 'x', 'y'."
)
# Case 2: inputs is a single tensor/array
if isinstance(inputs, Tensor | np.ndarray):
coords_tensor = inputs
if isinstance(coords_tensor, np.ndarray):
coords_tensor = tf_convert_to_tensor(
coords_tensor
)
# Expect shape (batch, time_steps, 3)
if (
coords_tensor.ndim != 3
or coords_tensor.shape[-1] < 3
):
raise ValueError(
f"Tensor `inputs` must have shape (batch, time_steps, 3);"
f" got {coords_tensor.shape}"
)
t = coords_tensor[
...,
coord_slice_map["t"] : coord_slice_map["t"] + 1,
]
x = coords_tensor[
...,
coord_slice_map["x"] : coord_slice_map["x"] + 1,
]
y = coords_tensor[
...,
coord_slice_map["y"] : coord_slice_map["y"] + 1,
]
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
raise ValueError(
f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
)
def _get_coords(
inputs: dict[str, Any] | Sequence[Any] | Tensor,
*,
check_shape: bool = False,
is_time_dependent: bool = True,
) -> Tensor:
r"""
Extract the **coords** tensor from any input layout.
Parameters
----------
inputs : dict, sequence or tensor
* **Dict** – must contain a ``"coords"`` key.
* **Sequence** – coords are expected at index 0.
* **Tensor** – interpreted as coords directly.
check_shape : bool, default ``False``
If ``True``, validate that the last dimension equals 3 and that
the rank matches *is_time_dependent*.
is_time_dependent : bool, default ``True``
* ``True`` → expect a 3-D shape ``(B, T, 3)``
* ``False`` → allow a 2-D shape ``(B, 3)``; a 3-D shape is also
accepted (useful when the model ignores time).
Returns
-------
tf.Tensor
The extracted coords tensor.
Raises
------
KeyError
If a dict is missing the ``"coords"`` key.
TypeError
If *inputs* is not dict, sequence or tensor.
ValueError
If ``check_shape`` is ``True`` and the tensor shape is invalid.
Notes
-----
The function is lightweight and executes outside any
``tf.function`` context; use it freely inside the model’s
``train_step`` or data-pipe helpers.
"""
# ── 1. locate the tensor
if isinstance(inputs, Tensor):
coords = inputs
elif isinstance(inputs, tuple | list):
coords = inputs[0]
elif isinstance(inputs, dict):
try:
coords = inputs["coords"]
except KeyError as err:
raise KeyError(
"Input dict lacks a 'coords' entry."
) from err
else:
raise TypeError(
"inputs must be a dict, sequence, or Tensor; "
f"got {type(inputs).__name__}"
)
# ── 2. validate shape if requested
if check_shape:
shape = coords.shape # static if known, else dynamic
# last dim must be 3
if shape.rank is not None:
if shape[-1] != 3:
raise ValueError(
"coords[..., -1] must equal 3 (t, x, y); "
f"found {shape[-1]}"
)
# rank check when static
if is_time_dependent and shape.rank != 3:
raise ValueError(
"Expected coords rank 3 (B, T, 3) for time-dependent "
"data; received rank "
f"{shape.rank}"
)
if not is_time_dependent and shape.rank not in (
2,
3,
):
raise ValueError(
"Expected coords rank 2 or 3 for static data; "
f"received rank {shape.rank}"
)
else: # dynamic shapes (inside tf.function)
tf_debugging.assert_equal(
tf_shape(coords)[-1],
3,
message="coords[..., -1] must equal 3 (t, x, y)",
)
if is_time_dependent:
tf_debugging.assert_rank(coords, 3)
else:
tf_debugging.assert_rank_in(coords, (2, 3))
return coords
[docs]
@ensure_pkg(
KERAS_BACKEND or "tensorflow",
extra="TensorFlow is required for this function.",
)
def plot_hydraulic_head(
model: Model,
t_slice: float,
x_bounds: tuple[float, float],
y_bounds: tuple[float, float],
resolution: int = 100,
ax: plt.Axes | None = None,
title: str | None = None,
cmap: str = "viridis",
colorbar_label: str = "Hydraulic Head (h)",
save_path: str | None = None,
show_plot: bool = True,
**contourf_kwargs: Any,
) -> tuple[plt.Axes, ScalarMappable]:
r"""Generate and plot a 2D contour map of a hydraulic head solution.
This utility visualizes the output of a Physics-Informed Neural
Network (PINN) that solves for the hydraulic head
:math:`h(t, x, y)`. It automates the process of creating a
spatial grid, running model predictions, and generating a
publication-quality contour plot for a specific slice in time.
Parameters
----------
model : tf.keras.Model
The trained PINN model. It is expected to have a ``.predict()``
method that accepts a dictionary of tensors with keys
``{'t', 'x', 'y'}``.
t_slice : float
The specific point in time :math:`t` for which to plot the
2D spatial solution.
x_bounds : tuple of float
A tuple ``(x_min, x_max)`` defining the spatial domain for
the x-axis.
y_bounds : tuple of float
A tuple ``(y_min, y_max)`` defining the spatial domain for
the y-axis.
resolution : int, optional
The number of points to sample along each spatial axis,
creating a grid of ``resolution x resolution`` points for
prediction. Higher values result in a smoother plot.
Default is 100.
ax : matplotlib.axes.Axes, optional
A pre-existing Matplotlib Axes object to plot on. If ``None``,
a new figure and axes are created internally. This is useful
for embedding this plot within a larger figure arrangement.
Default is ``None``.
title : str, optional
A custom title for the plot. If ``None``, a default title
is generated using the value of `t_slice`. Default is ``None``.
cmap : str, optional
The name of the Matplotlib colormap to use for the contour
plot. Default is ``'viridis'``.
colorbar_label : str, optional
The text label for the color bar. Default is
``'Hydraulic Head (h)'``.
save_path : str, optional
If provided, the path (including filename and extension)
where the generated plot will be saved. This is only active
when the function creates its own figure (i.e., when `ax`
is ``None``). Default is ``None``.
show_plot : bool, optional
If ``True``, calls ``plt.show()`` to display the plot. This
is only active when the function creates its own figure.
Default is ``True``.
**contourf_kwargs : any
Additional keyword arguments that are passed directly to the
``matplotlib.pyplot.contourf`` function. This allows for
advanced customization (e.g., ``levels=20``, ``extend='both'``).
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object on which the contour plot was drawn.
contour : matplotlib.cm.ScalarMappable
The contour plot object, which can be used for further
customizations, such as modifying the color bar.
See Also
--------
geoprior.models.pinn.PiTGWFlow : The PINN model this function is
designed to visualize.
Notes
-----
The core mechanism of this function involves creating a 2D
meshgrid of :math:`(x, y)` coordinates. These grid points are then
"flattened" into a long list of points, as the PINN model expects
a batch of individual coordinates for prediction, not a grid.
The prediction process is as follows:
1. A grid of shape ``(resolution, resolution)`` is created for
:math:`x` and :math:`y`.
2. These grids are reshaped into column vectors of shape
``(resolution*resolution, 1)``.
3. A time vector of the same shape, filled with `t_slice`, is
created.
4. The model's ``.predict()`` method is called on these flat
tensors.
5. The resulting flat prediction vector is reshaped back to the
original ``(resolution, resolution)`` grid shape for plotting.
If a custom `ax` is provided, the user is responsible for calling
``plt.show()`` or saving the parent figure.
Examples
--------
>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
... def call(self, inputs):
... # A simple analytical function for demonstration
... t, x, y = inputs['t'], inputs['x'], inputs['y']
... return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()
**1. Simple Plotting Example**
This example creates a single plot and saves it to a file.
>>> ax, contour = plot_hydraulic_head(
... model=mock_model,
... t_slice=0.5,
... x_bounds=(-1, 1),
... y_bounds=(-1, 1),
... resolution=50,
... save_path="hydraulic_head_t0.5.png",
... show_plot=False # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png
**2. Advanced Example with Subplots**
This example shows how to use the `ax` parameter to draw the
solution at two different times side-by-side in one figure.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
...
>>> # Plot solution at t = 0.1
>>> plot_hydraulic_head(
... model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
... y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
...
>>> # Plot solution at t = 1.0
>>> plot_hydraulic_head(
... model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
... y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
...
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()
"""
# ... function implementation follows ...
# --- 1. Handle Matplotlib Figure and Axes ---
if ax is None:
# Create a new figure only if no axes are provided
fig, current_ax = plt.subplots(figsize=(9, 7))
# Flag to indicate we have control over the figure object
fig_created = True
else:
current_ax = ax
fig = current_ax.figure
fig_created = False
# --- 2. Create Prediction Grid ---
x_range = np.linspace(
x_bounds[0], x_bounds[1], resolution
)
y_range = np.linspace(
y_bounds[0], y_bounds[1], resolution
)
X, Y = np.meshgrid(x_range, y_range)
# Flatten the grid to a list of points for model prediction
x_flat = tf_convert_to_tensor(X.ravel(), dtype=tf_float32)
y_flat = tf_convert_to_tensor(Y.ravel(), dtype=tf_float32)
t_flat = tf_fill(x_flat.shape, t_slice)
# Reshape to column vectors (N, 1) as expected by the model
grid_coords = {
"t": tf_reshape(t_flat, (-1, 1)),
"x": tf_reshape(x_flat, (-1, 1)),
"y": tf_reshape(y_flat, (-1, 1)),
}
# --- 3. Run Model Prediction ---
h_pred_flat = model.predict(grid_coords)
# Reshape the flat predictions back to the grid shape for plotting
h_pred_grid = tf_reshape(h_pred_flat, X.shape)
# --- 4. Plotting ---
# Set default contour levels if not provided
if "levels" not in contourf_kwargs:
contourf_kwargs["levels"] = 100
contour = current_ax.contourf(
X, Y, h_pred_grid, cmap=cmap, **contourf_kwargs
)
# Add color bar
fig.colorbar(contour, ax=current_ax, label=colorbar_label)
# Set plot labels and title
if title is None:
title = f"Learned Hydraulic Head Solution at t = {t_slice}"
current_ax.set_title(title, fontsize=14)
current_ax.set_xlabel("x-coordinate")
current_ax.set_ylabel("y-coordinate")
current_ax.set_aspect("equal")
# --- 5. Save and Show ---
if fig_created:
if save_path:
fig.savefig(
save_path, dpi=300, bbox_inches="tight"
)
print(f"Plot saved to {save_path}")
if show_plot:
plt.show()
else:
# If not showing, close the figure to free up memory
plt.close(fig)
return current_ax, contour