Utility API reference

GeoPrior-v3 exposes utilities across three complementary layers:

  • geoprior.utils for workflow-facing utilities used across staged runs, preprocessing, forecasting, diagnostics, calibration, evaluation, export, and artifact handling;

  • geoprior.models.utils for model-facing utilities used closer to forecasting models, sequence construction, PINN input preparation, tensor formatting, and PDE mode normalization;

  • geoprior.models.subsidence.utils for subsidence-physics utilities used to canonicalize scaling metadata, convert between model and SI units, resolve groundwater and subsidence channels, and support head/depth transformations and physics-aware priors.

This page documents all three layers together because they are not isolated in practice. A typical GeoPrior workflow uses them as a stack:

  • the workflow layer prepares data, resolves config, and drives staged execution;

  • the model layer standardizes sequence handling, model inputs, forecast formatting, and PINN support functions;

  • the subsidence-physics layer reconciles units, aliases, coordinate scaling, groundwater conventions, and model-side physical state extraction.

Overview

A useful mental map is:

geoprior.utils
   ├── audits / handshakes
   ├── config and NAT helpers
   ├── data / IO / export helpers
   ├── forecast formatting and evaluation
   ├── holdout and split logic
   ├── spatial and geospatial utilities
   ├── sequence and shape helpers
   └── subsidence-oriented workflow conversions

geoprior.models.utils
   ├── general model helpers
   ├── sequence construction
   ├── forecast formatting
   ├── input packing / unpacking
   ├── anomaly and evaluation helpers
   └── PINN-specific helper functions

geoprior.models.subsidence.utils
   ├── scaling alias normalization
   ├── SI conversion helpers
   ├── coordinate scaling helpers
   ├── groundwater / head conversion
   ├── dynamic channel resolution
   └── history / reference-state extraction

Why three utility layers exist

The separation is architectural rather than cosmetic.

Workflow surface

The top-level geoprior.utils package is the best entry point when reading Stage-1 through Stage-5 workflows, figure scripts, evaluation paths, and export code. Its public exports already include config loaders, dataset builders, calibration helpers, forecast formatters, holdout logic, geospatial helpers, and subsidence-oriented postprocessing utilities.

Model surface

The geoprior.models.utils package complements the workflow layer by handling lower-level model concerns such as sequence construction, preparing input tuples, forecast-to-DataFrame formatting, and PINN-oriented helper functions. The associated module docstrings explicitly position this layer in the context of sequence forecasting and Temporal-Fusion-Transformer-style workflows, which is useful for readers trying to connect GeoPrior’s utilities to broader multi-horizon forecasting practice.

Subsidence-physics surface

The geoprior.models.subsidence.utils module is narrower and more physics-aware. It is where scaling metadata is made canonical, SI conversions are enforced, coordinate policies are checked, groundwater depth can be reconciled with hydraulic head, and dynamic-channel lookup is stabilized. This layer is central when readers need to understand how raw workflow tensors become physically interpretable model quantities.

Reading order

A practical order for new readers is:

  1. geoprior.utils.nat_utils

  2. geoprior.utils.forecast_utils

  3. geoprior.utils.subsidence_utils

  4. geoprior.models.utils

  5. geoprior.models.utils.pinn

  6. geoprior.models.subsidence.utils

That sequence moves from the workflow surface into the model surface and finally into the more specialized subsidence-physics support layer.

Module indexes

These indexes use explicit fully qualified module names rather than recursive package discovery. That keeps autosummary predictable and avoids ambiguous package-relative lookup.

Workflow utility modules

audit_utils

Audit helpers for stage handshakes and scaling artifacts.

base_utils

Essential utilities for data processing and analysis in FusionLab, offering functions for normalization, interpolation, feature selection, outlier removal, and various data manipulation tasks.

calibrate

Forecast utilities.

data_utils

Data utilities.

deps_utils

Dependency utilities providing functions to handle package installation, checking, and ensuring that optional dependencies are available.

forecast_utils

Forecast utilities.

generic_utils

Provides common helper functions for validation, comparison, and other generic operations.

geo_utils

Geospatial utility helpers for GeoPrior workflows.

holdout_utils

Utility helpers for holdout and split workflows.

io_utils

Input/Output utilities for managing file paths, directories, and loading serialized data within FusionLab.

nat_utils

Public exports for NAT workflow utilities.

parallel_utils

Parallel execution helpers for GeoPrior workflows.

scale_metrics

Utilities for computing error metrics in physical units given Stage-1 scaling metadata.

sequence_utils

Sequence-building helpers for temporal model inputs.

shapes

Shape utility helpers for arrays and tensors.

spatial_utils

A collection of utilities for geospatial and positional data analysis, filtering, and transformations.

split

geoprior.utils.split

subsidence_utils

Utility helpers for subsidence data, units, and coordinates.

sys_utils

System utilities module for managing system-level operations.

target_utils

Target-processing helpers for GeoPrior workflows.

validator

Provides a comprehensive set of functions and warnings for validating and ensuring the integrity of data.

version

Vendored version parsing utilities.

Model utility modules

_utils

Utility functions for neural network models.

pinn

Physics-Informed Neural Network (PINN) utility functions.

Subsidence-physics utility modules

The subsidence-physics utility module is indexed from the dedicated subsidence API page to keep one canonical autosummary stub for geoprior.models.subsidence.utils.

Top-level workflow utility package

The public package surface below documents the aggregated export layer of geoprior.utils.

Public exports for GeoPrior utility helpers.

geoprior.utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)

Draw a stratified sample of spatial data that represents the distribution of the whole area and covers different years.

This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.

Parameters:
  • data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., ‘longitude’, ‘latitude’) and any columns specified in stratify_by.

  • sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).

  • stratify_by (list of str, optional) – List of column names to stratify by.

  • spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.

  • spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named ‘longitude’ and/or ‘latitude’ in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.

  • method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.

  • min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.

  • random_state (int, optional) – Random seed for reproducibility. Default is 42.

  • verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.

Returns:

sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.

Return type:

pandas.DataFrame

Notes

The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.

Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:

\[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil \tag{1}\]

where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).

The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
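The allocation rule in equation (1) can be checked with a small stand-alone sketch. The helper name proportional_allocation is hypothetical and not part of GeoPrior; it only reproduces the per-group formula, using exact integer ceiling division instead of floating-point arithmetic. Because every group is rounded up, the per-group counts can add up to slightly more than n.

```python
def proportional_allocation(group_sizes, n):
    """Hypothetical helper: per-group counts n_i = ceil(N_i / N * n).

    group_sizes maps a stratification key (for example a spatial-bin and
    year tuple) to the group size N_i; n is the desired overall sample size.
    """
    total = sum(group_sizes.values())  # N
    # -(-a // b) is exact integer ceiling division for positive b.
    return {key: -(-size * n // total) for key, size in group_sizes.items()}


# Three strata holding 50, 30 and 20 rows (N = 100) and a target of n = 7.
sizes = {("bin0", 2020): 50, ("bin1", 2020): 30, ("bin2", 2021): 20}
alloc = proportional_allocation(sizes, 7)
print(alloc)                # {('bin0', 2020): 4, ('bin1', 2020): 3, ('bin2', 2021): 2}
print(sum(alloc.values()))  # 9: rounding up overshoots n = 7
```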

Examples

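>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> # Synthetic input with 'longitude', 'latitude', 'year', and a
>>> # categorical stratification column.
>>> rng = np.random.default_rng(0)
>>> n = 2000
>>> df = pd.DataFrame({
...     'longitude': rng.uniform(113.0, 113.8, n),
...     'latitude': rng.uniform(22.3, 22.8, n),
...     'year': rng.integers(2015, 2023, n),
...     'geological_category': rng.choice(['clay', 'sand', 'silt'], n),
... })
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)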

See also

pandas.qcut

Quantile-based discretization function used for binning.

sklearn.model_selection.StratifiedShuffleSplit

For stratified sampling.

batch_spatial_sampling

Resample spatial data with batching.

geoprior.utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)

Cluster 2-D spatial data in df using the selected algorithm and optionally plot the results.

create_spatial_clusters extracts two coordinate columns from df and clusters them with 'kmeans', 'dbscan', or 'agglo' (agglomerative). When relevant, it uses filter_valid_kwargs to strip parameters that the chosen estimator does not accept, and it writes the cluster labels into cluster_col.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.

  • spatial_cols (list of str, optional) – Two-column list for x and y coordinates. Defaults to ['longitude','latitude'] if None.

  • cluster_col (str, default 'region') – Name of the column to store the assigned cluster labels.

  • n_clusters (int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.

  • algorithm (str, default 'kmeans') – Choice of clustering algorithm among ['kmeans','dbscan','agglo'].

  • view (bool, default True) – If True, displays a scatterplot of the final clusters.

  • figsize (tuple, default (14, 10)) – Size of the displayed figure for the cluster plot.

  • s (int, default 60) – Marker size in the scatterplot.

  • plot_style (str, default 'seaborn') – Matplotlib style used for the plot.

  • cmap (str, default 'tab20') – Colormap name used to differentiate clusters.

  • show_grid (bool, default True) – Toggles grid lines on or off.

  • grid_props (dict, optional) – Additional keyword arguments controlling the grid style.

  • auto_scale (bool, default True) – If True, standardize coordinates before clustering.

  • savefile (str, optional) – File path to which the clustered data (df plus the cluster_col label column) is saved, if desired.

  • verbose (int, default 1) – Controls console logs. Higher values yield more details about scaling and cluster detection.

  • **kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, and AgglomerativeClustering).

Returns:

A copy of df with an additional cluster_col column storing the assigned cluster labels.

Return type:

pandas.DataFrame

Notes

If auto_scale is True, the coordinate columns are normalized with a standard scaler before clustering. The scatterplot is drawn with seaborn for enhanced styling.

By default, with algorithm='kmeans', the model attempts to minimize:

\[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2 \tag{2}\]

where \(x_i\) are the scaled or raw 2-D coordinates of the \(N\) rows in df and \(\mu_j\) are the cluster centroids. If n_clusters is not provided, the function can auto-detect it using silhouette and elbow analysis.
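To make equation (2) concrete, the sketch below evaluates the objective J for a fixed set of centroids. kmeans_objective is a hypothetical helper: it does not fit centroids (the clustering estimator does that iteratively), it only scores a given set of them. With auto_scale=True, the points would be the standardized coordinates rather than raw longitude and latitude.

```python
def kmeans_objective(points, centroids):
    """Evaluate J = sum_i min_j ||x_i - mu_j||^2 for 2-D points."""
    total = 0.0
    for x, y in points:
        # Squared Euclidean distance from x_i to its nearest centroid.
        total += min((x - cx) ** 2 + (y - cy) ** 2 for cx, cy in centroids)
    return total


points = [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0)]
print(kmeans_objective(points, [(0.0, 1.0), (10.0, 0.0)]))  # 2.0
print(kmeans_objective(points, [(5.0, 1.0)]))               # 78.0
```

A single centroid leaves every point far from its cluster centre, so J is much larger; this trade-off between J and the number of clusters is what elbow-style selection of n_clusters inspects.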

Examples

>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )

See also

filter_valid_kwargs

Helps discard unsupported keyword arguments for chosen estimators.

geoprior.utils.augment_city_spatiotemporal_data(df, city, mode='interpolate', group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_config=None, augmentation_config=None, target_name=None, interpolate_target=False, verbose=True, coordinate_precision=None, savefile=None)

Apply grouped spatiotemporal augmentation with city-aware defaults.

This is a convenience wrapper around augment_spatiotemporal_data. It validates the requested city, optionally rounds coordinates before grouping, and forwards interpolation and augmentation configuration dictionaries.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame containing spatial, temporal, and feature columns.

  • city ({'nansha', 'zhongshan'}) – City identifier used for validation and defaults.

  • mode ({'interpolate', 'augment_features', 'both'}, optional) – Processing mode forwarded to augment_spatiotemporal_data.

  • group_by_cols (list of str or None, optional) – Grouping columns for interpolation.

  • time_col (str or None, optional) – Time column used for interpolation.

  • value_cols_interpolate (list of str or None, optional) – Columns to interpolate.

  • feature_cols_augment (list of str or None, optional) – Columns to augment with noise.

  • interpolation_config (dict or None, optional) – Keyword arguments for interpolate_temporal_gaps. Typical values include {'freq': 'AS', 'method': 'linear'}.

  • augmentation_config (dict or None, optional) – Keyword arguments for augment_series_features. Typical values include {'noise_level': 0.01, 'noise_type': 'gaussian'}.

  • target_name (str or None, optional) – Optional target column name used when inferring default feature sets.

  • interpolate_target (bool, optional) – Whether the target should be included in default interpolation columns.

  • verbose (bool, optional) – Whether to emit progress information.

  • coordinate_precision (int or None, optional) – Decimal precision applied to coordinates before grouping.

  • savefile (str or None, optional) – Optional output CSV path handled by the decorator.

Returns:

Augmented DataFrame.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If city or mode is invalid, or if required arguments are missing for the selected mode.

  • TypeError – If the main inputs are of the wrong type.
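coordinate_precision matters because rows are grouped on exact coordinate values: two records from the same site whose coordinates differ only by tiny jitter would otherwise form separate series. A minimal sketch of the idea in plain Python follows; the helper name and the (lon, lat, year, value) row layout are hypothetical, while the real wrapper works on DataFrame columns.

```python
from collections import defaultdict


def group_by_rounded_coords(rows, precision):
    """Group (lon, lat, year, value) rows by coordinates rounded to `precision`."""
    groups = defaultdict(list)
    for lon, lat, year, value in rows:
        key = (round(lon, precision), round(lat, precision))
        groups[key].append((year, value))
    return dict(groups)


rows = [
    (113.12341, 22.56781, 2019, 1.0),
    (113.12344, 22.56779, 2020, 1.5),  # same site, tiny coordinate jitter
    (113.50000, 22.40000, 2019, 3.0),
]
print(len(group_by_rounded_coords(rows, 6)))  # 3: jitter splits one site in two
print(len(group_by_rounded_coords(rows, 3)))  # 2: the jittered records merge
```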

geoprior.utils.augment_series_features(series_df, feature_cols, noise_level=0.01, noise_type='gaussian', random_seed=None, savefile=None)

Add random noise to selected numeric feature columns.

Parameters:
  • series_df (pandas.DataFrame) – Input DataFrame representing one or more time series.

  • feature_cols (list of str) – Feature columns to augment.

  • noise_level (float, optional) – Magnitude of the added noise. For Gaussian noise it scales the feature standard deviation, and for uniform noise it scales the feature range.

  • noise_type ({'gaussian', 'uniform'}, optional) – Type of noise distribution to use.

  • random_seed (int or None, optional) – Seed for reproducible noise generation.

  • savefile (str or None, optional) – Optional output path handled by the decorator.

Returns:

DataFrame with noise added to the selected feature columns.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If requested feature columns are missing or noise_type is invalid.

  • TypeError – If the main inputs are of the wrong type.

Notes

Non-numeric columns are skipped, and constant or invalid numeric ranges are left unchanged.
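The scaling rules above can be sketched on a plain dict-of-lists table. add_feature_noise is a hypothetical helper, not GeoPrior's implementation: it scales Gaussian noise by the column's population standard deviation and uniform noise by the column's range, skips non-numeric columns, and leaves columns with zero or non-finite spread unchanged. The symmetric interval used for uniform noise is an assumption.

```python
import math
import random
import statistics


def add_feature_noise(columns, feature_cols, noise_level=0.01,
                      noise_type="gaussian", seed=None):
    """Hypothetical sketch: add noise to numeric columns of a dict-of-lists table."""
    if noise_type not in ("gaussian", "uniform"):
        raise ValueError(f"Unknown noise_type: {noise_type!r}")
    missing = [c for c in feature_cols if c not in columns]
    if missing:
        raise ValueError(f"Missing feature columns: {missing}")
    rng = random.Random(seed)
    out = {name: list(values) for name, values in columns.items()}
    for name in feature_cols:
        values = columns[name]
        # Non-numeric columns are skipped (bools are not treated as numbers).
        if not all(isinstance(v, (int, float)) and not isinstance(v, bool)
                   for v in values):
            continue
        if noise_type == "gaussian":
            scale = noise_level * statistics.pstdev(values)    # scaled std
        else:
            scale = noise_level * (max(values) - min(values))  # scaled range
        # Constant columns (zero spread) or non-finite spread stay unchanged.
        if not math.isfinite(scale) or scale == 0:
            continue
        if noise_type == "gaussian":
            out[name] = [v + rng.gauss(0.0, scale) for v in values]
        else:
            # Assumption: symmetric uniform noise in [-scale, +scale].
            out[name] = [v + rng.uniform(-scale, scale) for v in values]
    return out


table = {"gwl": [1.0, 2.0, 3.0, 4.0], "const": [5.0] * 4, "site": list("abcd")}
noisy = add_feature_noise(table, ["gwl", "const", "site"], noise_level=0.1, seed=0)
print(noisy["const"], noisy["site"])  # both unchanged
```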

geoprior.utils.generate_dummy_pinn_data(n_samples, *, year_range=None, coords_range=None, subs_range=None, gwl_range=None, rainfall_range=None, vars_range=None)

Generate a dummy PINN data dictionary from specified or default ranges.

Parameters:
  • n_samples (int) – Number of samples to generate.

  • year_range (tuple[float, float], optional) – (min_year, max_year) for integer years. Default (2000, 2025).

  • coords_range (tuple[tuple[float, float], tuple[float, float]], optional) – ((lon_min, lon_max), (lat_min, lat_max)). Default ((113.0, 113.8), (22.3, 22.8)).

  • subs_range (tuple[float, float], optional) – (mean_subsidence, std_subsidence) for a normal distribution. Default (-20, 15).

  • gwl_range (tuple[float, float], optional) – (mean_gwl, std_gwl) for a normal distribution. Default (2.5, 1.0).

  • rainfall_range (tuple[float, float], optional) – (min_rain, max_rain) for a uniform distribution. Default (500, 2500).

  • vars_range (dict, optional) – Dictionary that may contain any of the keys: ‘year_range’, ‘coords_range’, ‘subs_range’, ‘gwl_range’, ‘rainfall_range’. Missing keys will fall back to defaults or to explicitly passed arguments.

Returns:

dummy_data_dict

Dictionary with keys:
  • "year" : integer years array

  • "longitude" : float longitudes array

  • "latitude" : float latitudes array

  • "subsidence" : float subsidence values array

  • "GWL" : float groundwater level values array

  • "rainfall_mm" : float rainfall values array

Return type:

dict[str, np.ndarray]
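The fallback rules for vars_range can be summarized in a small pure-Python sketch. It reuses the documented defaults and distributions (uniform years and coordinates, normal subsidence and GWL, uniform rainfall), but resolve_ranges and dummy_pinn_data are hypothetical names, and the real function returns NumPy arrays. Two details are assumptions: that a key present in vars_range wins over an explicitly passed argument, and that max_year is inclusive.

```python
import random

# Documented defaults of generate_dummy_pinn_data.
DEFAULTS = {
    "year_range": (2000, 2025),
    "coords_range": ((113.0, 113.8), (22.3, 22.8)),
    "subs_range": (-20, 15),
    "gwl_range": (2.5, 1.0),
    "rainfall_range": (500, 2500),
}


def resolve_ranges(vars_range=None, **explicit):
    """Assumed precedence: vars_range entry > explicit argument > default."""
    vars_range = vars_range or {}
    resolved = {}
    for key, default in DEFAULTS.items():
        if key in vars_range:
            resolved[key] = vars_range[key]
        elif explicit.get(key) is not None:
            resolved[key] = explicit[key]
        else:
            resolved[key] = default
    return resolved


def dummy_pinn_data(n_samples, seed=None, **ranges):
    """Hypothetical pure-Python analogue that returns lists, not arrays."""
    r = resolve_ranges(**ranges)
    rng = random.Random(seed)
    y0, y1 = (int(v) for v in r["year_range"])
    (lon0, lon1), (lat0, lat1) = r["coords_range"]
    draw = range(n_samples)
    return {
        "year": [rng.randint(y0, y1) for _ in draw],  # inclusive (assumption)
        "longitude": [rng.uniform(lon0, lon1) for _ in draw],
        "latitude": [rng.uniform(lat0, lat1) for _ in draw],
        "subsidence": [rng.gauss(*r["subs_range"]) for _ in draw],  # (mean, std)
        "GWL": [rng.gauss(*r["gwl_range"]) for _ in draw],          # (mean, std)
        "rainfall_mm": [rng.uniform(*r["rainfall_range"]) for _ in draw],
    }


data = dummy_pinn_data(
    100, seed=0,
    year_range=(2010, 2012),                 # overridden by vars_range below
    vars_range={"year_range": (2015, 2016)},
)
print(min(data["year"]), max(data["year"]))  # drawn from 2015..2016
```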

geoprior.utils.augment_spatiotemporal_data(df, mode, group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_kwargs=None, augmentation_kwargs=None, savefile=None, verbose=False)

Apply interpolation, feature augmentation, or both to grouped data.

Parameters:
  • df (pandas.DataFrame) – Input spatiotemporal DataFrame.

  • mode ({'interpolate', 'augment_features', 'both'}) – Processing mode. Use interpolation only, feature augmentation only, or interpolation followed by augmentation.

  • group_by_cols (list of str or None, optional) – Grouping columns used for per-location processing.

  • time_col (str or None, optional) – Time column required when interpolation is requested.

  • value_cols_interpolate (list of str or None, optional) – Value columns to interpolate when interpolation is enabled.

  • feature_cols_augment (list of str or None, optional) – Feature columns to perturb when augmentation is enabled.

  • interpolation_kwargs (dict or None, optional) – Keyword arguments forwarded to interpolate_temporal_gaps.

  • augmentation_kwargs (dict or None, optional) – Keyword arguments forwarded to augment_series_features.

  • savefile (str or None, optional) – Optional output path handled by the decorator.

  • verbose (bool, optional) – Whether to emit progress information.

Returns:

Processed DataFrame assembled from all groups.

Return type:

pandas.DataFrame

Raises:

ValueError – If mode is invalid or required arguments for the selected mode are missing.

Notes

Groups are processed independently and concatenated afterward.
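To illustrate the per-group flow of mode='interpolate', the sketch below fills yearly gaps inside each location separately and then concatenates the results, so values from one location never leak into another. It is plain Python with a hypothetical helper name, and it assumes a regular yearly grid and linear interpolation without extrapolation; the real interpolate_temporal_gaps is configured through interpolation_kwargs such as freq and method and may behave differently at the series edges.

```python
from collections import defaultdict


def interpolate_gaps_per_group(records):
    """Linearly fill yearly gaps within each site, then concatenate groups.

    records: iterable of (site, year, value); value is None for a gap.
    """
    by_site = defaultdict(dict)
    for site, year, value in records:
        by_site[site][year] = value

    result = []
    for site, series in by_site.items():             # each group is independent
        years = range(min(series), max(series) + 1)  # regular yearly grid
        known = [(y, series[y]) for y in years if series.get(y) is not None]
        for y in years:
            value = series.get(y)
            if value is None:
                left = [(ky, kv) for ky, kv in known if ky < y]
                right = [(ky, kv) for ky, kv in known if ky > y]
                # Interpolate only between known neighbours (no extrapolation).
                if left and right:
                    (y0, v0), (y1, v1) = left[-1], right[0]
                    value = v0 + (v1 - v0) * (y - y0) / (y1 - y0)
            result.append((site, y, value))
    return result


records = [
    ("A", 2019, 1.0), ("A", 2021, 3.0),
    ("B", 2019, 10.0), ("B", 2020, None), ("B", 2022, 40.0),
]
for row in interpolate_gaps_per_group(records):
    print(row)
```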

geoprior.utils.mask_by_reference(data, ref_col, values=None, find_closest=False, fill_value=0, mask_columns=None, error='raise', verbose=0, inplace=False, savefile=None)

Masks (replaces) values in columns other than the reference column for rows in which the reference column matches (or is closest to) the specified value(s).

If a row’s reference-column value is matched, that row’s values in the other columns are overwritten by fill_value. The reference column itself is not modified.

This function supports both exact and approximate matching:
  • Exact matching is used if find_closest=False.

  • Approximate (closest) matching is used if find_closest=True and the reference column is numeric.

By default, if the reference column does not exist or if the given values cannot be found (or approximated) in the reference column, an exception is raised. This behavior can be adjusted with the error parameter.

Parameters:
  • data (pd.DataFrame) – The input DataFrame containing the data to be masked.

  • ref_col (str) – The column in data serving as the reference for matching or finding the closest values.

  • values (Any or sequence of Any, optional) –

    The reference values to look for in ref_col. This can be:
    • A single value (e.g., 0 or "apple").

    • A list/tuple of values (e.g., [0, 10, 25]).

    • If values is None, all rows are masked (i.e. all rows match), effectively overwriting the entire DataFrame (except the reference column) with fill_value.

    Note that if find_closest=False, these values must appear in the reference column; otherwise, an error or warning is triggered (depending on the error setting).

  • find_closest (bool, default False) – If True, performs an approximate match for numeric reference columns. For each entry in values, the function locates the row(s) in ref_col whose value is numerically closest. Non-numeric reference columns will revert to exact matching regardless.

  • fill_value (Any, default 0) –

    The value used to fill/mask the non-reference columns wherever the condition (exact or approximate match) is met. This can be any valid type, e.g., integer, float, string, np.nan, etc. If fill_value='auto' and multiple values are given, each row matched by a particular reference value is filled with that same reference value.

    Examples:
    • If values=9 and fill_value='auto', the fill value is 9 for matched rows.

    • If values=['a', 10] and fill_value='auto', then rows matching ‘a’ are filled with ‘a’, and rows matching 10 are filled with 10.

  • mask_columns (str or list of str, optional) – If specified, only these columns are masked. If None, all columns except ref_col are masked. If any column in mask_columns does not exist in the DataFrame and error='raise', a KeyError is raised; otherwise, a warning may be issued or ignored.

  • error ({'raise', 'warn', 'ignore'}, default 'raise') –

    Controls how to handle errors:
    • 'raise': raise an error if the reference column does not exist or if any of the given values cannot be matched (or approximated).

    • 'warn': only issue a warning instead of raising an error.

    • 'ignore': silently ignore any issues.

  • verbose (int, default 0) –

    Verbosity level:
    • 0: silent (no messages).

    • 1: minimal feedback.

    • 2 or 3: more detailed messages for debugging.

  • inplace (bool, default False) – If True, performs the operation in place and returns the original DataFrame with modifications. If False, returns a modified copy, leaving the original unaltered.

  • savefile (str or None, optional) – File path where the DataFrame is saved if the decorator-based saving is active. If None, no saving occurs.

Returns:

A DataFrame where rows matching the specified condition (exact or approximate) have had their non-reference columns replaced by fill_value.

Return type:

pd.DataFrame

Raises:
  • KeyError – If error='raise' and ref_col is not in data.columns.

  • ValueError – If error='raise' and no exact/approx match can be found for one or more entries in values.

Notes

  • If values is None, all rows are masked in the non-ref columns, effectively overwriting them with fill_value.

  • When find_closest=True, approximate matching is performed only if the reference column is numeric. For non-numeric data, it falls back to exact matching.

  • When multiple reference values are provided, each is processed in turn. If fill_value='auto', each matched row is filled with that specific reference value.
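The matching rules in these notes can be sketched on a plain dict-of-lists table. match_rows and mask_rows are hypothetical helpers, not GeoPrior functions; they omit error handling, mask_columns, and inplace, and they assume that under find_closest=True every equally close row is masked, which is the tie behaviour the examples describe for values=9.

```python
def match_rows(ref, targets, find_closest=False):
    """Map row index -> the target value it matched (exact or nearest)."""
    matches = {}
    for t in targets:
        if find_closest:
            best = min(abs(v - t) for v in ref)
            # Assumption: ties are kept, so every equally close row matches.
            hits = [i for i, v in enumerate(ref) if abs(v - t) == best]
        else:
            hits = [i for i, v in enumerate(ref) if v == t]
        for i in hits:
            matches[i] = t
    return matches


def mask_rows(table, ref_col, targets, fill_value=0, find_closest=False):
    """Hypothetical sketch of mask_by_reference on a dict-of-lists table."""
    matches = match_rows(table[ref_col], targets, find_closest)
    out = {col: list(vals) for col, vals in table.items()}
    for col, vals in out.items():
        if col == ref_col:
            continue  # the reference column itself is never modified
        for i, t in matches.items():
            vals[i] = t if fill_value == "auto" else fill_value
    return out


table = {"A": [10, 0, 8, 0], "B": [2, 0.5, 18, 85],
         "C": [34, 0.8, 12, 4.5], "D": [0, 78, 25, 3.2]}
print(mask_rows(table, "A", [9], fill_value=-999, find_closest=True)["B"])
# [-999, 0.5, -999, 85]
print(mask_rows(table, "A", [0, 8], fill_value="auto")["B"])
# [2, 0, 8, 0]
```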

Examples

>>> import pandas as pd
>>> from geoprior.utils.data_utils import mask_by_reference
>>>
>>> df = pd.DataFrame({
...     "A": [10, 0, 8, 0],
...     "B": [2, 0.5, 18, 85],
...     "C": [34, 0.8, 12, 4.5],
...     "D": [0, 78, 25, 3.2]
... })
>>>
>>> # Example 1: exact matching. In rows where A == 0,
>>> # columns B, C and D are replaced with 0.
>>> masked_df = mask_by_reference(
...     data=df,
...     ref_col="A",
...     values=0,
...     fill_value=0,
...     find_closest=False,
...     error="raise"
... )
>>> print(masked_df)
>>>
>>> # Example 2: approximate matching on a numeric ref_col.
>>> # A contains [10, 0, 8, 0]; 9 is equally close to 8 and 10, so
>>> # rows 0 (A=10) and 2 (A=8) have B, C and D replaced with -999.
>>> masked_df2 = mask_by_reference(
...     data=df,
...     ref_col="A",
...     values=9,
...     find_closest=True,
...     fill_value=-999
... )
>>> print(masked_df2)
>>>
>>> # Example 3: fill_value='auto' with multiple values.
>>> # Rows with A == 0 have B, C, D replaced by 0;
>>> # rows with A == 8 have B, C, D replaced by 8.
>>> res3 = mask_by_reference(df, "A", [0, 8], fill_value="auto")
>>> print(res3)
>>>
>>> # Example 4: mask_columns=['C', 'D'] masks only C and D.
>>> # In rows where A == 0, C and D become 999 and B is unchanged.
>>> res4 = mask_by_reference(df, "A", values=0, fill_value=999,
...                         mask_columns=["C", "D"])
>>> print(res4)

geoprior.utils.nan_ops(data, auxi_data=None, data_kind=None, ops='check_only', action=None, error='raise', process=None, condition=None, savefile=None, verbose=0)

Perform operations on NaN values within data structures, handling both primary data and optional witness data based on specified parameters.

This function provides a toolkit for managing missing values (NaN) in NumPy arrays, pandas DataFrames, and pandas Series. Depending on the ops parameter, it can check for the presence of NaN values, validate data integrity, or sanitize the data by filling or dropping NaN values. It also supports auxiliary (witness) data, which matters when the row alignment between primary and auxiliary data must be preserved.

\[\begin{split}\text{Processed\_data} = \begin{cases} \text{filled\_data} & \text{if action is 'fill'} \\ \text{dropped\_data} & \text{if action is 'drop'} \\ \text{original\_data} & \text{otherwise} \end{cases}\end{split} \tag{3}\]

Parameters:
  • data (array-like, pandas.DataFrame, or pandas.Series) – The primary data structure containing NaN values to be processed.

  • auxi_data (array-like, pandas.DataFrame, or pandas.Series, optional) – Auxiliary data that accompanies the primary data. Its role depends on the data_kind parameter. If data_kind is ‘target’, auxi_data is treated as feature data, and vice versa. This is useful for operations that need to maintain the alignment between primary and witness data.

  • data_kind ({'target', 'feature', None}, optional) – Specifies the role of the primary data. If set to ‘target’, data is considered target data, and auxi_data (if provided) is treated as feature data. If set to ‘feature’, data is treated as feature data, and auxi_data is considered target data. If None, no special handling is applied, and witness data is ignored unless explicitly required by other parameters.

  • ops ({'check_only', 'validate', 'sanitize'}, default 'check_only') –

    Defines the operation to perform on the NaN values in the data:

    • 'check_only': Checks whether the data contains any NaN values and returns a boolean indicator.

    • 'validate': Validates that the data does not contain NaN values. If NaN values are found, it raises an error or warns, depending on the error parameter.

    • 'sanitize': Cleans the data by either filling or dropping NaN values based on the action, process, and condition parameters.

  • action ({'fill', 'drop'}, optional) –

    Specifies the action to take when ops is set to ‘sanitize’:

    • 'fill': Fills NaN values using the fill_NaN function with the method set to ‘both’.

    • 'drop': Drops NaN values based on the specified conditions and process. If data_kind is ‘target’, it handles NaN values in a way that preserves data integrity for machine learning models.

    • If None, defaults to ‘drop’ when sanitizing.

    Note: If ops is not ‘sanitize’ and action is set, an error is raised indicating conflicting parameters.

  • error ({'raise', 'warn', None}, default 'raise') –

    Determines the error handling policy:

    • 'raise': Raises exceptions when encountering issues.

    • 'warn': Emits warnings instead of raising exceptions.

    • None: Defaults to the base policy, which is typically ‘warn’.

    This parameter is utilized by the error_policy function to enforce consistent error handling throughout the operation.

  • process ({'do', 'do_anyway'}, optional) –

    Works in conjunction with the action parameter when action is ‘drop’:

    • 'do': Drops NaN values only if certain conditions are met.

    • 'do_anyway': Forces the dropping of NaN values regardless of conditions.

    This provides flexibility in handling NaN values based on the specific requirements of the dataset and the analysis being performed.

  • condition (callable or None, optional) – A callable that defines a condition for dropping NaN values when action is ‘drop’. For example, it can specify that the number of NaN values should not exceed a certain fraction of the dataset. If the condition is not met, the behavior is controlled by the process parameter.

  • verbose (int, default 0) –

    Controls the verbosity level of the function’s output for debugging purposes:

    • 0: No output.

    • 1: Basic informational messages.

    • 2: Detailed processing messages.

    • 3: Debug-level messages with complete trace of operations.

    Higher verbosity levels provide more insights into the function’s internal operations, aiding in debugging and monitoring.

Returns:

The sanitized data structure with NaN values handled according to the specified parameters. If auxi_data is provided and processed, a tuple containing the sanitized data and auxi_data is returned. Otherwise, only the sanitized data is returned.

Return type:

array-like, pandas.DataFrame, or pandas.Series

Raises:
  • ValueError

    • If an invalid value is provided for ops or data_kind.

    • If auxi_data does not align with data in shape.

    • If sanitization conditions are not met and the error policy is set to ‘raise’.

  • Warning

    • Emits warnings when NaN values are present and the error policy is set to ‘warn’.

Examples

>>> from geoprior.utils.data_utils import nan_ops
>>> import pandas as pd
>>> import numpy as np
>>> # Example with target data and witness feature data
>>> target = pd.Series([1, 2, np.nan, 4])
>>> features = pd.DataFrame({
...     'A': [5, np.nan, 7, 8],
...     'B': ['x', 'y', 'z', np.nan]
... })
>>> # Check for NaNs
>>> nan_ops(target, auxi_data=features, data_kind='target', ops='check_only')
(True, True)
>>> # Validate data (will raise ValueError if NaNs are present)
>>> nan_ops(target, auxi_data=features, data_kind='target', ops='validate')
Traceback (most recent call last):
    ...
ValueError: Target contains NaN values.
>>> # Sanitize data by dropping NaNs
>>> cleaned_target, cleaned_features = nan_ops(
...     target,
...     auxi_data=features,
...     data_kind='target',
...     ops='sanitize',
...     action='drop',
...     verbose=2
... )
Dropping NaN values.
Dropped NaNs successfully.
>>> cleaned_target
0    1.0
1    2.0
3    4.0
dtype: float64
>>> cleaned_features
     A    B
0  5.0    x
3  8.0  NaN

Notes

The nan_ops function is designed to provide a robust framework for handling missing values in datasets, especially in machine learning workflows where the integrity of target and feature data is paramount. By allowing conditional operations and providing flexibility in error handling, it ensures that data preprocessing can be tailored to the specific needs of the analysis.

The function leverages helper utilities such as fill_NaN, drop_nan_in, and error_policy to maintain consistency and reliability across different data structures and scenarios. The verbosity levels aid developers in tracing the function’s execution flow, making it easier to debug and verify data transformations.
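The 'drop' path for data_kind='target' can be pictured with plain pandas. The sketch below uses a hypothetical helper, drop_target_nans, which is not part of geoprior. It assumes data and auxi_data share the same index and, unlike nan_ops, it only removes rows where the target is NaN, leaving NaNs in the feature columns untouched.

```python
import numpy as np
import pandas as pd


def drop_target_nans(target: pd.Series, auxi_data: pd.DataFrame):
    """Hypothetical sketch: drop NaN-target rows from target and features.

    A single boolean mask is applied to both objects, so the returned
    target and features stay row-aligned.
    """
    if len(target) != len(auxi_data):
        raise ValueError("auxi_data does not align with data in shape.")
    keep = target.notna().to_numpy()  # True where the target is present
    return target[keep], auxi_data[keep]


target = pd.Series([1, 2, np.nan, 4])
features = pd.DataFrame({"A": [5, np.nan, 7, 8], "B": ["x", "y", "z", np.nan]})
clean_y, clean_X = drop_target_nans(target, features)
print(clean_y.index.tolist())  # [0, 1, 3]
print(clean_X.index.tolist())  # [0, 1, 3]
```

Using one mask for both objects is what keeps target and feature rows paired after dropping.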

See also

geoprior.utils.base_utils.fill_NaN

Fills NaN values in numeric data structures using specified methods.

geoprior.core.array_manager.drop_nan_in

Drops NaN values from data structures, optionally alongside witness data.

geoprior.core.utils.error_policy

Determines how errors are handled based on user-specified policies.

geoprior.core.array_manager.array_preserver

Preserves and restores the original structure of array-like data.

geoprior.utils.unpack_frames_from_file(merged, *, group_col='city', output_dir=None, output_format='csv', compression=None, use_source_col=True, source_col='source', filename_pattern='{group_value}_split', drop_columns=None, keep_columns=None, save=True, return_dict=True, save_kwargs=None, verbose=1, logger)[source]

Reverse of merge_city_frames_to_file: split an aggregated NATCOM dataset into per-city frames (and optionally write them to disk).

Parameters:
  • merged (path-like or DataFrame) –

    Aggregated dataset. If path-like, the format is inferred from the file suffix. If a DataFrame is passed, it is used directly.

  • group_col (str, optional) – Column used to split the dataset (default: 'city'). Each unique value defines one output chunk.

  • output_dir (path-like, optional) – Directory where per-group files are written. If None and merged is a path, the directory of merged is used. If merged is a DataFrame and output_dir is None, the current working directory is used.

  • output_format ({'csv', 'parquet', 'feather', 'pickle'}, optional) – Output format for per-group files. Default is 'csv'.

  • compression (str or None, optional) –

    Compression to use when writing:

    • For 'csv', forwarded to DataFrame.to_csv() as the compression argument (e.g. 'gzip').

    • For 'parquet', forwarded to DataFrame.to_parquet() (e.g. 'snappy', 'gzip').

    • Ignored for 'feather' and 'pickle' (these use their own defaults).

  • use_source_col (bool, optional) –

    If True (default) and a column named source_col exists, the helper tries to reconstruct the original file name for each group:

    • If a group has a single unique, non-null source value that looks like a filename (e.g. 'nansha_final_main_std.harmonized.csv'), that base name is used for the output (with its suffix adjusted to match output_format if needed).

    • If there are multiple unique source labels within a group, it falls back to filename_pattern.

  • source_col (str, optional) – Name of the column containing the source label (default: 'source'). This should match the column created in merge_frames_to_file() when add_source_label=True.

  • filename_pattern (str, optional) –

    Pattern used when no suitable source label is available. The following placeholders are supported:

    • {group_value} : the group value as a string

    • {group_col} : the name of the grouping column

    Example: filename_pattern="{group_col}_{group_value}_data" → "city_Nansha_data.csv".

  • drop_columns (iterable of str, optional) – Columns to drop from each group before saving/returning (e.g. ['source'] if you don’t want the bookkeeping column).

  • keep_columns (iterable of str, optional) – If provided, only these columns are kept (all others are dropped after any drop_columns processing is applied).

  • save (bool, optional) – If True (default), write each group to disk as a separate file. If False, no files are written; only the dict of DataFrames is returned (if return_dict=True).

  • return_dict (bool, optional) – If True (default), return a mapping {group_value: group_df}. If False, an empty dict is returned (useful when you only care about side-effect files).

  • save_kwargs (dict, optional) – Extra keyword arguments forwarded to the respective writer: DataFrame.to_csv(), DataFrame.to_parquet(), DataFrame.to_feather(), or DataFrame.to_pickle().

  • verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints progress information.

  • logger (None)

Returns:

out – Dictionary mapping each group value to the corresponding DataFrame. Empty if return_dict=False.

Return type:

dict

Raises:

ValueError – If group_col is not present in the merged dataset.

Examples

>>> from geoprior.utils.geo_utils import unpack_frames_from_file
>>> unpack_frames_from_file(
...     "natcom_all_cities.parquet",
...     group_col="city",
...     output_format="csv",
... )
# -> writes e.g. 'nansha_final_main_std.harmonized.csv',
#    'zhongshan_final_main_std.harmonized.csv' (if `source` labels exist),
#    and returns a dict: {'Nansha': df_nansha, 'Zhongshan': df_zhongshan}
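At its core, splitting the merged table is a group-by on group_col followed by filename formatting. The following is a minimal sketch using a hypothetical helper, split_by_group, that writes nothing to disk; it mirrors only the documented filename_pattern placeholders and ignores the source_col filename reconstruction.

```python
import pandas as pd


def split_by_group(merged: pd.DataFrame, group_col: str = "city",
                   filename_pattern: str = "{group_value}_split",
                   suffix: str = ".csv") -> dict:
    """Hypothetical sketch: map each group value to (filename, frame)."""
    if group_col not in merged.columns:
        raise ValueError(f"{group_col!r} is not present in the merged dataset.")
    out = {}
    for value, frame in merged.groupby(group_col):
        # Fill both documented placeholders, then append the format suffix.
        stem = filename_pattern.format(group_value=value, group_col=group_col)
        out[value] = (stem + suffix, frame.reset_index(drop=True))
    return out


merged = pd.DataFrame({
    "city": ["Nansha", "Zhongshan", "Nansha"],
    "subsidence": [1.0, 2.0, 3.0],
})
parts = split_by_group(merged, filename_pattern="{group_col}_{group_value}_data")
for city, (fname, frame) in parts.items():
    print(city, fname, len(frame))
```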
geoprior.utils.widen_temporal_columns(data, dt_col, spatial_cols=None, target_name=None, round_dt=True, ignore_cols=None, nan_op=None, nan_thresh=None, savefile=None, verbose=0)[source]

Convert a long PIHALNet prediction table into a wide format where each temporal slice becomes a dedicated column.

The routine pivots columns whose names follow the pattern

<base>           deterministic forecast
<base>_qXX       quantile forecast (e.g., ``subsidence_q10``)
<base>_actual    ground‑truth column

and produces columns of the form

<base>_<year>            point forecast
<base>_<year>_qXX        quantile forecast
<base>_<year>_actual     ground‑truth value

If duplicate (spatial, year) pairs are found, values are aggregated with their mean (via pandas.Series.groupby) prior to pivoting, to avoid “Index contains duplicate entries” errors.

Parameters:
  • data (PathLike object or pandas.DataFrame) – Long‑format DataFrame returned by geoprior.utils.format_pihalnet_predictions().

  • dt_col (str) – Column holding the temporal coordinate (e.g., 'coord_t'). Must be numeric or datetime‑coercible. When round_dt is True, values are rounded to integers.

  • spatial_cols ((str, str) or None, default None) – Names of x and y spatial coordinates. These are retained as leading columns in the output. If None, the function falls back to 'sample_idx' or an auto‑generated 'row_id'.

  • target_name (str or None, default None) – Restrict pivoting to a specific base (e.g., 'subsidence'). When None, every base present in data is widened.

  • round_dt (bool, default True) – Round dt_col to the nearest integer (helpful for fractional years such as 2020.0001).

  • ignore_cols (list[str] or None, default None) – Additional columns to carry through unchanged. Values are propagated per spatial location using the first non‑null entry.

  • nan_op ({'drop', 'fill', 'both', None}, default None) –

    Strategy for NaN handling after pivot:

    • 'fill' – forward‑fill then back‑fill missing values.

    • 'drop' – drop rows containing NaNs (see nan_thresh).

    • 'both' – fill then drop according to nan_thresh.

    • None – leave NaNs untouched.

  • nan_thresh (float or None, default None) –

    When nan_op contains 'drop', rows are dropped if the proportion of missing values exceeds nan_thresh. Set nan_thresh = 0 to require no NaNs, 0.5 to allow ≤ 50 % missing, etc.

    \[\text{row kept} \;\Longleftrightarrow\; \frac{\text{NaNs in row}}{\text{row width}} \le \text{nan\_thresh} \tag{4}\]

  • savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file.

  • verbose (int, default 0) – Diagnostic verbosity from 0 (silent) to 5 (trace every step).
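The nan_thresh rule above amounts to comparing each row's NaN fraction with the threshold. A minimal pandas sketch with a hypothetical helper, keep_rows; it assumes the row width is the number of columns in the frame passed in, which may differ from how the function itself counts spatial or ignored columns.

```python
import numpy as np
import pandas as pd


def keep_rows(wide: pd.DataFrame, nan_thresh: float) -> pd.DataFrame:
    """Hypothetical sketch: keep rows whose NaN fraction is <= nan_thresh."""
    if not 0.0 <= nan_thresh <= 1.0:
        raise ValueError("nan_thresh must lie in [0, 1].")
    frac_missing = wide.isna().mean(axis=1)  # NaNs in row / row width
    return wide[frac_missing <= nan_thresh]


wide = pd.DataFrame({
    "s_2018": [1.0, np.nan, np.nan],
    "s_2019": [2.0, 3.0, np.nan],
    "s_2020": [2.5, 3.5, 4.0],
})
print(keep_rows(wide, 0.4).index.tolist())  # [0, 1]
```

Row 1 has 1/3 missing and is kept at 0.4; row 2 has 2/3 missing and is dropped.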

Returns:

Wide‑format frame with spatial identifiers first, followed by year‑wise forecast, quantile, and actual columns.

Return type:

pandas.DataFrame

Raises:
  • KeyError – dt_col is missing from data, or spatial_cols are absent.

  • ValueError – No columns match target_name or nan_thresh is outside \([0, 1]\).

Notes

  • Duplicate indices are aggregated with the arithmetic mean before pivoting. Modify the aggregation lambda inside the function for alternative choices.

  • If ignore_cols is provided, their first non‑null value per spatial location is appended to the output.
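The duplicate handling described in these notes (average, then pivot) can be reproduced with plain pandas for a single column. This is an illustrative sketch, not the function's code: the column names follow the documented <base>_<year>_qXX convention, and mean aggregation is taken from the note above.

```python
import pandas as pd

df_long = pd.DataFrame({
    "coord_x": [113.15, 113.15, 113.15, 113.15],
    "coord_y": [22.63, 22.63, 22.63, 22.63],
    "coord_t": [2019, 2020, 2019, 2020],
    "subsidence_q50": [0.09, 0.10, 0.12, 0.13],
})

spatial = ["coord_x", "coord_y"]
# Average duplicate (x, y, t) rows, then move the years into the columns.
wide = (
    df_long.groupby(spatial + ["coord_t"])["subsidence_q50"]
    .mean()
    .unstack("coord_t")
)
# Rename year columns to the <base>_<year>_<suffix> convention.
wide.columns = [f"subsidence_{int(t)}_q50" for t in wide.columns]
wide = wide.reset_index()
print(wide)
```

Here the two 2019 rows (0.09 and 0.12) collapse to their mean, 0.105, before pivoting.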

Examples

Minimal usage on a tiny synthetic set

>>> import pandas as pd
>>> from geoprior.utils.data_utils import widen_temporal_columns
>>>
>>> df_long = pd.DataFrame(
...     {
...         "coord_x": [113.15, 113.15, 113.15, 113.15],
...         "coord_y": [22.63, 22.63, 22.63, 22.63],
...         "coord_t": [2019, 2020, 2019, 2020],
...         "subsidence_q50": [0.09, 0.10, 0.12, 0.13],
...         "subsidence_actual": [0.08, 0.11, 0.10, 0.14],
...     }
... )
>>>
>>> wide = widen_temporal_columns(
...     df_long,
...     dt_col="coord_t",
...     spatial_cols=("coord_x", "coord_y"),
...     verbose=2,
... )
[INFO] Initial rows: 4, columns: 2
[INFO] Widening base 'subsidence' (2 columns)
[DONE] Final wide shape: (1, 4)
>>> wide
   coord_x  coord_y  subsidence_2019_actual  subsidence_2020_actual  \
0   113.15    22.63                    0.08                    0.11

   subsidence_2019_q50  subsidence_2020_q50
0                 0.12                 0.13

End‑to‑end example with NaN handling, ignored columns, and two targets

>>> import numpy as np
>>> rng = np.arange(2018, 2021)  # forecast years 2018, 2019, 2020
>>> n = 5  # five spatial locations
>>>
>>> # build synthetic long DataFrame
>>> df_long = pd.DataFrame(
...     {
...         "sample_idx": np.repeat(np.arange(n), len(rng)),
...         "coord_x": np.repeat(np.linspace(113.4, 113.5, n), len(rng)),
...         "coord_y": np.repeat(np.linspace(22.1, 22.2, n), len(rng)),
...         "coord_t": np.tile(rng, n),
...         "region": np.repeat(["A", "B", "A", "B", "A"], len(rng)),
...         "subsidence_q10": np.random.rand(n * len(rng)),
...         "subsidence_q50": np.random.rand(n * len(rng)),
...         "subsidence_q90": np.random.rand(n * len(rng)),
...         "subsidence_actual": np.random.rand(n * len(rng)),
...         "GWL_q50": np.random.rand(n * len(rng)),
...     }
... )
>>>
>>> # introduce NaNs for demonstration
>>> df_long.loc[df_long.sample(frac=0.2).index, "subsidence_q50"] = np.nan
>>>
>>> wide = widen_temporal_columns(
...     df_long,
...     dt_col="coord_t",
...     spatial_cols=("coord_x", "coord_y"),
...     ignore_cols=["region"],
...     target_name=None,      # widen both 'subsidence' and 'GWL'
...     nan_op="both",         # fill then drop rows with many NaNs
...     nan_thresh=0.4,        # allow at most 40 % missing
...     verbose=3,
... )
[INFO] Initial rows: 15, columns: 7
[INFO] Widening base 'GWL' (1 columns)
  └─ 0 duplicate rows in 'GWL_q50' → aggregated
[INFO] Widening base 'subsidence' (4 columns)
  └─ 0 duplicate rows in 'subsidence_q10' → aggregated
  └─ 0 duplicate rows in 'subsidence_q50' → aggregated
  └─ 0 duplicate rows in 'subsidence_q90' → aggregated
  └─ 0 duplicate rows in 'subsidence_actual' → aggregated
[INFO] Missing values filled (ffill+bfill).
[INFO] Rows with >40% NaN dropped.
[DONE] Final wide shape: (5, 19)
>>> wide.iloc[:2, :8]  # show first 8 columns
   coord_x  coord_y  GWL_2018_q50  GWL_2019_q50  GWL_2020_q50  \
0  113.400      ...           ...           ...           ...
1  113.425      ...           ...           ...           ...

   subsidence_2018_actual  subsidence_2019_actual  subsidence_2020_actual
0                     ...                     ...                     ...
1                     ...                     ...                     ...

See also

pandas.DataFrame.unstack

Core pivoting method used internally.

geoprior.plot.forecast.forecast_view

Visualisation routine that consumes the resulting wide frame.

geoprior.utils.pivot_forecast_dataframe(data, id_vars, time_col, value_prefixes, static_actuals_cols=None, time_col_is_float_year='auto', round_time_col=False, verbose=0, savefile=None, _logger=None, **kws)[source]

Transforms a long-format forecast DataFrame to a wide format.

This utility reshapes time series prediction data from a “long” format, where each row represents a single time step for a given sample, to a “wide” format, where each row represents a single sample and columns correspond to values at different time steps.

Parameters:
  • data (pd.DataFrame) – The input long-format DataFrame. It must contain the columns specified in id_vars and time_col, as well as value columns that start with the strings in value_prefixes.

  • id_vars (list of str) – A list of column names that uniquely identify each sample or group. These columns will be preserved in the wide-format output. For example: ['sample_idx', 'coord_x', 'coord_y'].

  • time_col (str) – The name of the column that represents the time step or year of the forecast (e.g., ‘coord_t’ or ‘forecast_step’). This column’s values will become part of the new column names.

  • value_prefixes (list of str) – A list of prefixes for the value columns that need to be pivoted. The function identifies columns starting with these prefixes. For instance, ['subsidence', 'GWL'] would match ‘subsidence_q10’, ‘GWL_q50’, etc.

  • static_actuals_cols (list of str, optional) – A list of columns containing static “actual” or ground truth values for each sample. These values are assumed to be constant for each unique sample_idx and are merged back into the wide DataFrame after pivoting. Example: ['subsidence_actual'].

  • time_col_is_float_year (bool or 'auto', default 'auto') –

    Controls how the time_col values are formatted into new column names:

    • If 'auto', automatically detects whether time_col has a float dtype.

    • If True, treats time_col values (e.g., 2018.0) as years and converts them to integer strings (‘2018’).

    • If False, uses the string representation of the value as is.

  • round_time_col (bool, default False) – If True and time_col is a float type, its values will be rounded to the nearest integer before being used in column names. This is useful for cleaning up float years (e.g., 2018.0001 -> 2018).

  • verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent. Higher values print more details about the process.

  • savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file to that location.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

A wide-format DataFrame with one row per unique combination of id_vars. New columns are created in the format {prefix}_{time_str}{_suffix} (e.g., ‘subsidence_2018_q10’).

Return type:

pd.DataFrame

See also

pandas.pivot_table

The core function used for reshaping data.

pandas.merge

Used to re-join static columns after pivoting.

Notes

  • The combination of columns in id_vars and time_col must uniquely identify each row in data for the pivot to succeed without data loss.

  • If using static_actuals_cols, the id_vars list must contain ‘sample_idx’ to correctly merge the static data back.

Examples

>>> import pandas as pd
>>> from geoprior.utils.data_utils import pivot_forecast_dataframe
>>> data = {
...     'sample_idx':      [0, 0, 1, 1],
...     'coord_t':         [2018.0, 2019.0, 2018.0, 2019.0],
...     'coord_x':         [0.1, 0.1, 0.5, 0.5],
...     'coord_y':         [0.2, 0.2, 0.6, 0.6],
...     'subsidence_q50':  [-8, -9, -13, -14],
...     'subsidence_actual': [-8.5, -8.5, -13.2, -13.2],
...     'GWL_q50':         [1.2, 1.3, 2.2, 2.3],
... }
>>> df_long_example = pd.DataFrame(data)
>>> df_wide = pivot_forecast_dataframe(
...     data=df_long_example,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     time_col='coord_t',
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual'],
...     verbose=0
... )
>>> print(df_wide.columns)
Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
       'GWL_2018_q50', 'GWL_2019_q50', 'subsidence_2018_q50',
       'subsidence_2019_q50'],
      dtype='object')
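The time_col_is_float_year and round_time_col options only affect how time values become column-name fragments. A hypothetical helper, time_labels, sketches those documented rules; the real implementation may differ in details such as how non-year floats are handled.

```python
import pandas as pd


def time_labels(times: pd.Series, time_col_is_float_year="auto",
                round_time_col=False) -> list:
    """Hypothetical sketch: turn time values into column-name fragments."""
    is_float = pd.api.types.is_float_dtype(times)
    if time_col_is_float_year == "auto":
        time_col_is_float_year = is_float  # float dtype -> treat as years
    if round_time_col and is_float:
        times = times.round()  # e.g. 2018.0001 -> 2018.0
    if time_col_is_float_year:
        return [str(int(t)) for t in times]  # 2018.0 -> '2018'
    return [str(t) for t in times]  # use the value as is


years = pd.Series([2018.0, 2019.0])
print([f"subsidence_{t}_q50" for t in time_labels(years)])
# ['subsidence_2018_q50', 'subsidence_2019_q50']
print(time_labels(years, time_col_is_float_year=False))  # ['2018.0', '2019.0']
```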
geoprior.utils.fetch_joblib_data(job_file, *keys, error_mode='raise', verbose=0)[source]

Dynamically load data from a joblib-saved dictionary with flexible key access.

Parameters:
  • job_file (str) – Path to the joblib file containing a dictionary

  • *keys (str) – Variable-length list of dictionary keys to retrieve

  • error_mode ({'raise', 'warn', 'ignore'}, default 'raise') –

    Handling of missing keys:

    • 'raise': Immediately raise KeyError.

    • 'warn': Issue a warning and skip missing keys.

    • 'ignore': Silently skip missing keys.

  • verbose (int, default 0) –

    Verbosity level:

    • 0: No output.

    • 1: Basic loading information.

    • 2: Detailed debugging output.

Returns:

  • Full dictionary if no keys specified

  • Tuple of values for requested keys (maintaining order)

Return type:

Union[Dict, Tuple]

Raises:
  • FileNotFoundError – If specified job_file doesn’t exist

  • TypeError – If loaded data isn’t a dictionary

  • KeyError – If a requested key is not found and error_mode='raise'

Examples

>>> from geoprior.utils.io_utils import fetch_joblib_data
>>> data = fetch_joblib_data('data.joblib', 'X_train', 'y_train')
>>> X, y = fetch_joblib_data('data.joblib', 'X_val', 'y_val', verbose=1)
>>> full_dict = fetch_joblib_data('data.joblib')

Notes

  • Maintains original insertion order for Python 3.7+ dictionaries

  • Missing keys in 'warn'/'ignore' modes result in a shorter return tuple

  • Joblib files must contain dictionary objects
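The key-selection rules above can be sketched independently of file loading. The hypothetical select_keys below operates on an already-loaded dictionary (in fetch_joblib_data the dictionary comes from the joblib file) and reproduces the documented error_mode behaviour, including the shorter tuple when keys are skipped.

```python
import warnings


def select_keys(data, *keys, error_mode="raise"):
    """Hypothetical sketch of the documented key-selection rules."""
    if not isinstance(data, dict):
        raise TypeError("Loaded data is not a dictionary.")
    if not keys:
        return data  # no keys requested: return the full dictionary
    values = []
    for key in keys:  # keep the requested order
        if key in data:
            values.append(data[key])
        elif error_mode == "raise":
            raise KeyError(key)
        elif error_mode == "warn":
            warnings.warn(f"Key {key!r} not found; skipping.")
        # error_mode == "ignore": skip silently
    return tuple(values)


payload = {"X_train": [1, 2], "y_train": [0, 1]}
X, y = select_keys(payload, "X_train", "y_train")
print(select_keys(payload, "X_train", "X_val", error_mode="ignore"))  # ([1, 2],)
```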

geoprior.utils.save_job(job, savefile, *, protocol=None, append_versions=True, append_date=True, fix_imports=True, buffer_callback=None, **job_kws)[source]

Quickly save a job using joblib or the persistent Python pickle module.

Parameters:
  • job (Any) – Anything to save, preferably models stored in a dict.

  • savefile (str, or path-like object) – Name of the file in which to store the model. The file argument must have a write() method that accepts a single bytes argument. It can thus be a file object opened for binary writing, an io.BytesIO instance, or any other custom object that meets this interface.

  • append_versions (bool, default True) – Append the version of the joblib or pickle module, followed by the scikit-learn, NumPy, and pandas versions. Recording these versions helps when the file is loaded after the system or its modules have been upgraded, especially when data have been stored for a long time and the versions in use at save time are no longer known.

  • append_date (bool, default True) – Append the current date to the filename.

  • protocol (int, optional) –

    The optional protocol argument tells the pickler to use the given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. The default protocol is 4, which was introduced in Python 3.4 and is incompatible with earlier versions.

    Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the more recent the version of Python needed to read the pickle produced.

  • fix_imports (bool, default True) – If fix_imports is True and protocol is less than 3, pickle will try to map the new Python 3 names to the old module names used in Python 2, so that the pickle data stream is readable with Python 2.

  • buffer_callback (callable, optional) –

    If buffer_callback is None (the default), buffer views are serialized into file as part of the pickle stream.

    If buffer_callback is not None, then it can be called any number of times with a buffer view. If the callback returns a false value (such as None), the given buffer is out-of-band; otherwise the buffer is serialized in-band, i.e. inside the pickle stream.

    It is an error if buffer_callback is not None and protocol is None or smaller than 5.

  • job_kws (dict) – Additional keyword arguments passed to joblib.dump().

Returns:

The final filename where the job was saved.

Return type:

str

Notes

This function appends system-specific metadata like versions and date to the filename, which can aid in tracking compatibility over time.

Examples

>>> from geoprior.utils.io_utils import save_job
>>> model = {"key": "value"}  # Replace with actual model object
>>> savefile = save_job(model, "my_model", append_date=True, append_versions=True)
>>> savefile  # illustrative; the exact name depends on the date and installed versions
'my_model.20240101.sklearn_v1.0.numpy_v1.21.joblib'
geoprior.utils.normalize_time_column(df, time_col, datetime_col='datetime_temp', year_col='year_int', drop_orig=False)[source]

Normalize a time column into a datetime column and an integer year.

The input column may contain integer years, strings, or existing pandas Datetime values. The function creates datetime_col with parsed timestamps and year_col with the extracted integer year. When drop_orig=True, the original time_col is removed and datetime_col is renamed back to time_col.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame containing a time column named time_col.

  • time_col (str) – Name of the column to normalize.

  • datetime_col (str, default 'datetime_temp') – Name of the parsed datetime column.

  • year_col (str, default 'year_int') – Name of the extracted integer year column.

  • drop_orig (bool, default False) – If True, drop the original time_col after parsing and rename datetime_col back to time_col.

Returns:

A copy of df with the parsed datetime column and integer year column.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If time_col is missing or parsing fails for any entry.

  • TypeError – If df is not a pandas DataFrame.
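A minimal sketch of the documented behaviour, using a hypothetical normalize_time_sketch. One detail matters when reproducing it: pandas.to_datetime reads bare integers as nanoseconds since the epoch, so the sketch converts numeric years to strings and parses them with format="%Y". Whether geoprior handles integer years exactly this way is an assumption.

```python
import pandas as pd


def normalize_time_sketch(df, time_col, datetime_col="datetime_temp",
                          year_col="year_int", drop_orig=False):
    """Hypothetical sketch of the behaviour described above."""
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame.")
    if time_col not in df.columns:
        raise ValueError(f"Column {time_col!r} is missing.")
    out = df.copy()
    values = out[time_col]
    if pd.api.types.is_numeric_dtype(values):
        # Numeric years: parse '2018' as a year, not as epoch nanoseconds.
        parsed = pd.to_datetime(values.astype(int).astype(str), format="%Y")
    else:
        parsed = pd.to_datetime(values)  # strings or existing datetimes
    out[datetime_col] = parsed
    out[year_col] = parsed.dt.year.astype(int)
    if drop_orig:
        out = out.drop(columns=time_col).rename(columns={datetime_col: time_col})
    return out


df = pd.DataFrame({"year": [2018, 2019], "v": [1.0, 2.0]})
out = normalize_time_sketch(df, "year", drop_orig=True)
print(out.columns.tolist())  # ['v', 'year', 'year_int']
```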

geoprior.utils.convert_eval_payload_units(payload, cfg=None, *, mode='si', scope='all', savefile=None, fmt='json', indent=2, copy_payload=True)[source]

Convert GeoPriorSubsNet evaluation-payload units for reporting.

This is a post-processing helper meant for stage-2 evaluation JSON payloads (e.g. geoprior_eval_phys_<timestamp>.json).

Parameters:
  • payload (Mapping[str, Any]) – The evaluation payload dict. It is expected to contain sections like metrics_evaluate, point_metrics, per_horizon, interval_calibration and censor_stratified.

  • cfg (mapping or module, optional) – The experiment config (e.g. config module or globals()). The helper reads SUBS_UNIT_TO_SI (or stage-1 provenance) and TIME_UNITS from this object when available.

  • mode ({"si", "interpretable"}, default "si") – "si" leaves values untouched. "interpretable" converts selected subsidence and physics metrics from SI into the native units implied by SUBS_UNIT_TO_SI.

  • scope ({"all", "subsidence", "physics"}, default "all") – Which parts to convert when mode="interpretable". "subsidence" converts only subsidence metrics such as MAE, MSE, and sharpness to the native unit. "physics" converts only unambiguous physics residual rates, currently epsilon_cons_raw and epsilon_gw_raw. "all" applies both conversions.

  • savefile (str, optional) – If provided, write the converted payload to this path.

  • fmt ({"json"}, default "json") – Output format when savefile is provided.

  • indent (int, default 2) – JSON indentation.

  • copy_payload (bool, default True) – If True, operate on a deep copy of payload. If False, convert in place, which mutates the caller's payload.

Returns:

Converted payload as a plain dict.

Return type:

dict

Notes

For subsidence metrics, linear quantities such as MAE and sharpness scale by 1 / SUBS_UNIT_TO_SI, while squared quantities such as MSE scale by (1 / SUBS_UNIT_TO_SI) ** 2.

When physics conversion is enabled, epsilon_cons_raw is treated as a rate in m/s and converted to <subs_native_unit>/<TIME_UNITS> (for example mm/yr), while epsilon_gw_raw is treated as a rate in 1/s and converted to 1/<TIME_UNITS>.

The helper records unit provenance under payload["units"].
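The scaling rules in these notes reduce to a few lines of arithmetic. The sketch below is hypothetical and assumes SUBS_UNIT_TO_SI = 1e-3 (native unit mm) and a 365.25-day year for TIME_UNITS = "year"; the seconds-per-year constant geoprior actually uses is not stated here.

```python
# Assumption: a Julian year (365.25 days) when TIME_UNITS == "year".
SECONDS_PER_YEAR = 365.25 * 86400.0


def subs_metrics_to_native(mae_si, mse_si, subs_unit_to_si):
    """Linear metrics scale by 1/f, squared metrics by (1/f) ** 2."""
    inv = 1.0 / subs_unit_to_si
    return mae_si * inv, mse_si * inv ** 2


def cons_rate_to_native(eps_m_per_s, subs_unit_to_si,
                        seconds_per_unit=SECONDS_PER_YEAR):
    """Convert m/s to <native unit>/<time unit>, e.g. mm/yr."""
    return eps_m_per_s * seconds_per_unit / subs_unit_to_si


def gw_rate_to_native(eps_per_s, seconds_per_unit=SECONDS_PER_YEAR):
    """Convert 1/s to 1/<time unit>."""
    return eps_per_s * seconds_per_unit


mae_mm, mse_mm2 = subs_metrics_to_native(0.005, 2.5e-5, subs_unit_to_si=1e-3)
print(mae_mm, mse_mm2)  # approximately 5.0 (mm) and 25.0 (mm**2)
print(cons_rate_to_native(1e-9, 1e-3))  # approximately 31.56 (mm/yr)
```

Squaring the factor for MSE is what keeps RMSE = sqrt(MSE) consistent with the converted MAE units.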

geoprior.utils.postprocess_eval_json(eval_json, *, cfg=None, scope='all', out_path=None, overwrite=False, add_rmse=True, force=False, indent=2)[source]

Post-hoc convert a Stage-2 evaluation JSON from SI to interpretable units.

This is a safe wrapper around convert_eval_payload_units(…) that:

  • loads a JSON from disk (or accepts a payload dict),

  • infers unit factors from payload[“units”] when cfg is missing,

  • avoids double conversion unless force=True,

  • optionally adds RMSE fields (sqrt(MSE)),

  • writes a converted JSON file if out_path is provided.

Parameters:
  • eval_json (str | Mapping[str, Any]) – Either a file path to the saved JSON, or an in-memory mapping.

  • cfg (Mapping[str, object] | None) – Optional config mapping/module. If missing (or incomplete), this helper will synthesize the minimal keys needed for conversion:

    • “SUBS_UNIT_TO_SI”

    • “TIME_UNITS”

  • scope (Literal['all', 'subsidence', 'physics']) – Forwarded to convert_eval_payload_units(…).

  • out_path (str | None) – If given, write the converted payload there. If a directory is given, a filename is generated next to the input name (or “geoprior_eval…”).

  • overwrite (bool) – If False and out_path already exists, an error is raised.

  • add_rmse (bool) – If True, add RMSE fields wherever MSE is present (metrics_evaluate, point_metrics, per_horizon).

  • force (bool) – If False, skip conversion when payload already declares an interpretable subsidence unit (e.g., “mm”). If True, always convert.

  • indent (int) – JSON indent used when writing.

Returns:

The converted payload (always returned as a dict).

Return type:

dict

geoprior.utils.evaluate_point_forecast(model, out, y_true_subs, *, y_true_gwl=None, n_q=3, quantiles=None, q=None, use_physical=False, return_physical=None, scaler_info=None, subs_target_name='subsidence', gwl_target_name='gwl', scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None, strict=True, output_names=None)[source]

End-to-end helper for point-forecast evaluation.

Pipeline: extract predictions from model and out, canonicalize BHQO quantile outputs when needed, pick a point prediction, optionally inverse-scale it, and compute global and per-horizon metrics.

Parameters:
  • model (Any) – Passed to extract_preds(…).

  • out (Any) – Passed to extract_preds(…).

  • y_true_subs (Any) – True subsidence, shape (B, H, 1) or (B, H).

  • y_true_gwl (Any | None) – Optional true gwl/head, same shape conventions.

  • n_q (int) – Quantile-selection control for (B, H, Q, 1) outputs; see q.

  • quantiles (Sequence[float] | None) – Quantile-selection control for (B, H, Q, 1) outputs; see q.

  • q (float | int | None) – Quantile-selection controls for (B, H, Q, 1) outputs. If q is None, the median is preferred. If q is an integer, it is treated as a direct quantile index. If q is a float and quantiles is provided, the nearest quantile is selected; otherwise the float is treated as a fraction in [0, 1].

  • use_physical (bool) – If True, compute metrics in physical units.

  • return_physical (bool | None) – If None, defaults to use_physical. If True, return physical-space arrays such as subs_pred_phys and gwl_pred_phys when possible.

  • scaler_info (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).

  • scaler_entry (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).

  • scaler (Any | None) – Passed through to inverse_scale_target(...).

  • feature_index (int | None) – Passed through to inverse_scale_target(...).

  • n_features (int | None) – Passed through to inverse_scale_target(...).

  • params (Mapping[str, float] | None) – Passed through to inverse_scale_target(...).

  • subs_target_name (str)

  • gwl_target_name (str)

  • strict (bool)

  • output_names (Sequence[str] | None)

Returns:

Dictionary containing model-space predictions, optional physical-space predictions, and global and per-horizon metrics for subsidence and groundwater outputs.

Return type:

dict
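The q / quantiles selection rules can be sketched as follows. The helper select_point is hypothetical; in particular, mapping a bare float q (no quantiles given) to index round(q * (Q - 1)) is an assumption, because the documentation only says the float is treated as a fraction in [0, 1].

```python
import numpy as np


def select_point(y_pred, q=None, quantiles=None):
    """Hypothetical sketch: pick one quantile slice from (B, H, Q, 1)."""
    n_q = y_pred.shape[2]
    if q is None:
        # Prefer the median: nearest level to 0.5, else the middle index.
        if quantiles is not None:
            idx = int(np.argmin(np.abs(np.asarray(quantiles) - 0.5)))
        else:
            idx = n_q // 2
    elif isinstance(q, (int, np.integer)):
        idx = int(q)  # direct quantile index
    elif quantiles is not None:
        idx = int(np.argmin(np.abs(np.asarray(quantiles) - q)))  # nearest level
    else:
        idx = int(round(q * (n_q - 1)))  # assumption: fraction along the Q axis
    return y_pred[:, :, idx, :]  # shape (B, H, 1)


# Each quantile slice k holds the constant value k, so the pick is visible.
y_pred = np.tile(np.arange(3.0).reshape(1, 1, 3, 1), (2, 4, 1, 1))
print(select_point(y_pred).shape)  # (2, 4, 1)
```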

geoprior.utils.inverse_scale_target(y_scaled, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]

Inverse-transform a scaled target array back to physical units.

Supports three patterns:

  1. Stage-1 scaler_info dict (preferred in NATCOM): pass scaler_info=scaler_info_dict, target_name="subsidence".

  2. A bare scaler instance or a path to a joblib dump via scaler=.... If multi-feature, also pass feature_index and optionally n_features.

  3. Manual scaling parameters via params such as {"min", "max"}, {"mean", "std"}, or {"scale", "shift"}.

Parameters:
  • y_scaled (array-like) – Scaled values, e.g. of shape (N, H, 1) or (N,).

  • scaler_info (mapping, optional)

  • target_name (str, optional)

  • scaler_entry (mapping, optional)

  • scaler (object or str, optional)

  • feature_index (int, optional)

  • n_features (int, optional)

  • params (mapping, optional) –

    Manual parameters:

    • {"min", "max"} → MinMax scaling

    • {"mean", "std"} → standardization

    • {"scale", "shift"} → x = scale * x_scaled + shift.

Returns:

Array with same shape as y_scaled but in physical units when scaling information is available. If nothing usable is found, returns the input as a NumPy array.

Return type:

np.ndarray
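For the manual params route, the three documented patterns invert as in the hypothetical sketch below. It assumes MinMax scaling targeted the [0, 1] range and that standardization used x_scaled = (x - mean) / std; neither assumption is stated explicitly above.

```python
import numpy as np


def inverse_scale_manual(y_scaled, params):
    """Hypothetical sketch of the three manual `params` patterns."""
    y = np.asarray(y_scaled, dtype=float)

    def has(*keys):
        return all(k in params for k in keys)

    if has("min", "max"):
        # Undo MinMax scaling to [0, 1].
        return y * (params["max"] - params["min"]) + params["min"]
    if has("mean", "std"):
        # Undo standardization: x = x_scaled * std + mean.
        return y * params["std"] + params["mean"]
    if has("scale", "shift"):
        return params["scale"] * y + params["shift"]
    return y  # nothing usable: return the input as a NumPy array


y = np.array([0.0, 0.5, 1.0]).reshape(3, 1, 1)  # shape (N, H, 1)
print(inverse_scale_manual(y, {"min": -20.0, "max": 0.0}).ravel())
```

The shape of y_scaled is preserved in every branch, matching the documented return contract.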

geoprior.utils.deg_to_m_from_lat(lat_deg)[source]

Approximate WGS84 meters per degree at a reference latitude. Returns (deg_to_m_lon, deg_to_m_lat).

Parameters:

lat_deg (float)

Return type:

tuple[float, float]
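One widely used approximation for these two quantities is the truncated series expansion for the WGS84 ellipsoid sketched below. The documentation does not say which formula deg_to_m_from_lat uses, so treat this as an assumption; the name deg_to_m_sketch is hypothetical.

```python
import math


def deg_to_m_sketch(lat_deg: float) -> tuple:
    """Approximate meters per degree of (longitude, latitude) on WGS84."""
    phi = math.radians(lat_deg)
    m_lat = (111132.92 - 559.82 * math.cos(2 * phi)
             + 1.175 * math.cos(4 * phi) - 0.0023 * math.cos(6 * phi))
    m_lon = (111412.84 * math.cos(phi) - 93.5 * math.cos(3 * phi)
             + 0.118 * math.cos(5 * phi))
    return m_lon, m_lat


# Convert a small lon/lat offset near 22.6 degrees N into meters.
m_lon, m_lat = deg_to_m_sketch(22.6)
dx_m, dy_m = 0.01 * m_lon, 0.01 * m_lat
print(round(dx_m, 1), round(dy_m, 1))
```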

geoprior.utils.canonicalize_BHQO(y_pred, *, y_true=None, q_values=(0.1, 0.5, 0.9), n_q=None, layout=None, enforce_monotone=True, return_layout=False, verbose=0, log_fn=<built-in function print>)[source]

Canonicalize quantile outputs to (B, H, Q, O).

Supported layouts (rank-4):
  • BHQO: (B, H, Q, O) -> unchanged

  • BQHO: (B, Q, H, O) -> transpose(0, 2, 1, 3)

  • BHOQ: (B, H, O, Q) -> transpose(0, 1, 3, 2)

If the layout is ambiguous (e.g., H == Q) and y_true is given, the transform with the smallest q50 MAE against y_true is picked.

If y_true is not given, fallback is:
  1. use layout if provided

  2. else prefer BHQO if plausible

  3. else pick the layout with the lowest quantile-crossing score

Parameters:
  • y_pred (Any) – Quantile tensor, NumPy array or TF tensor.

  • y_true (Any | None) – Target tensor (B, H, O) or (B, H, 1). Used only to resolve ambiguity robustly.

  • q_values (Sequence[float]) – Quantiles in order, e.g. (0.1, 0.5, 0.9).

  • n_q (int | None) – Number of quantiles. Defaults to len(q_values).

  • layout (str | None) – Force interpretation: “BHQO”, “BQHO”, “BHOQ”. Use “auto” (or None) to infer.

  • enforce_monotone (bool) – Sort along Q axis after canonicalization.

  • return_layout (bool) – If True, return (arr, chosen_layout).

  • verbose (int) – Verbosity level for log messages.

  • log_fn (Callable[[str], None]) – Function used to emit log messages (defaults to print).

Returns:

Canonical (B, H, Q, O) and optionally the layout.

Return type:

arr or (arr, layout)
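When the layout is known, canonicalization amounts to a fixed axis permutation per layout, followed by an optional sort along Q. The sketch below assumes NumPy input and an explicit layout. The name to_bhqo is hypothetical, and the automatic inference (q50 MAE against y_true, BHQO preference, crossing score) is deliberately left out.

```python
import numpy as np

# Axis permutations that bring each supported rank-4 layout to (B, H, Q, O).
_TO_BHQO = {
    "BHQO": (0, 1, 2, 3),
    "BQHO": (0, 2, 1, 3),
    "BHOQ": (0, 1, 3, 2),
}


def to_bhqo(y_pred, layout, enforce_monotone=True):
    """Hypothetical sketch: canonicalize a known layout to (B, H, Q, O).

    Covers only the explicit-layout path; ambiguity resolution is omitted.
    """
    arr = np.asarray(y_pred)
    if arr.ndim != 4:
        raise ValueError(f"expected a rank-4 array, got ndim={arr.ndim}")
    out = np.transpose(arr, _TO_BHQO[layout.upper()])
    if enforce_monotone:
        # Sort along the quantile axis so q10 <= q50 <= q90 in every cell.
        out = np.sort(out, axis=2)
    return out
```

For example, a BQHO array of shape (2, 3, 5, 1) (three quantiles, five horizon steps) comes back with shape (2, 5, 3, 1).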

geoprior.utils.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]

Fit and apply post-hoc interval calibration for quantile forecasts.

This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can

  1. detect whether evaluation forecasts already appear calibrated,

  2. fit interval-width correction factors from evaluation data,

  3. apply those factors to evaluation and/or future forecasts,

  4. compute before/after summary diagnostics on the evaluation set,

  5. optionally save the outputs to disk.

The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.

Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].

Parameters:
  • df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.

  • df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.

  • target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.

  • column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.

  • step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.

  • interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.

  • target_coverage (float, default 0.8) – Desired empirical coverage after calibration.

  • median_q (float, default 0.5) – Central quantile used as the expansion anchor.

  • use ({"auto", True, False}, default "auto") –

    Control flag for whether calibration is performed.

    • False disables calibration and returns inputs unchanged.

    • "auto" skips calibration when evaluation forecasts already look calibrated.

    • True forces calibration even if the automatic check would skip it.

  • tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.

  • f_max (float, default 5.0) – Maximum factor allowed during fitting.

  • max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.

  • keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.

  • enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.

  • overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.

  • calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.

  • factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.

  • factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.

  • save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.

  • save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.

  • save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.

  • verbose (int, default 1) – Verbosity level forwarded to logging helpers.

  • logger (logging.Logger or None, default None) – Optional logger used for progress messages.

Returns:

  • df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.

  • df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.

  • stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain

    • the target interval and target coverage,

    • the fitted or user-specified factors,

    • skip reasons,

    • evaluation summaries before and after calibration.

Return type:

tuple[DataFrame | None, DataFrame | None, dict[str, Any]]

Notes

In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This makes the wrapper conservative in repeated workflows, where the same tables may pass through the calibration stage more than once.

The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True

See also

fit_interval_factors_df

Fit per-horizon interval-width correction factors.

apply_interval_factors_df

Apply a scalar or per-horizon factor map to quantile forecasts.

References

For the broader role of calibrated probabilistic multi-horizon forecasting, see Lim et al. [1].

For uncertainty-rich forecasting in the present project ecosystem, see Kouadio et al. [2].
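The fitting step can be pictured as a one-dimensional search for an interval-width factor. The sketch below assumes the interval is scaled around the median (both lower and upper half-widths multiplied by the same factor f) and fits a single factor by bisection on [0, f_max]. The function names are hypothetical, and per-horizon grouping on step_col, the already-calibrated check, and monotonic repair are omitted.

```python
import numpy as np


def interval_coverage(y, lo, med, hi, f):
    """Empirical coverage after scaling the interval around the median by f."""
    lo_f = med - f * (med - lo)
    hi_f = med + f * (hi - med)
    return float(np.mean((y >= lo_f) & (y <= hi_f)))


def fit_interval_factor(y, lo, med, hi, target=0.8, f_max=5.0, max_iter=32):
    """Hypothetical sketch: smallest factor in [0, f_max] reaching `target`.

    Coverage is non-decreasing in f when lo <= med <= hi, so bisection applies.
    """
    y, lo, med, hi = (np.asarray(a, dtype=float) for a in (y, lo, med, hi))
    if interval_coverage(y, lo, med, hi, f_max) < target:
        return f_max  # Target unreachable within the allowed range.
    left, right = 0.0, f_max
    for _ in range(max_iter):
        mid = 0.5 * (left + right)
        if interval_coverage(y, lo, med, hi, mid) >= target:
            right = mid  # Wide enough: try a narrower interval.
        else:
            left = mid  # Too narrow: widen.
    return right
```

A factor above 1 widens the interval and a factor below 1 narrows it. Repeating the search within each forecast_step group would yield a per-horizon factor map.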

geoprior.utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]
geoprior.utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]

Stage-1 audit. Reports:

  • raw df_train coordinate statistics (t/x/y) with a heuristic guess of their units;

  • model-fed coordinate statistics from inputs_train["coords"] (flattened);

  • coordinate scaler min/max and coord_ranges;

  • SI-channel sanity checks for physics columns (if present);

  • target array sanity checks;

  • a split of features into scaled ML, __si, and other columns.

Saves a machine-readable JSON report if save_dir is provided.

Parameters:
  • inputs_train (dict[str, Any])

  • targets_train (dict[str, Any])

  • coord_scaler (Any)

  • coord_ranges (dict[str, float] | None)

  • coord_mode (str)

  • coords_in_degrees (bool)

  • coord_epsg_used (Any)

  • coord_x_col_used (str)

  • coord_y_col_used (str)

  • x_col_used (str)

  • y_col_used (str)

  • time_col_used (str)

  • normalize_coords (bool)

  • keep_coords_raw (bool)

  • shift_raw_coords (bool)

  • subs_model_col (str | None)

  • gwl_dyn_col (str | None)

  • gwl_target_col (str | None)

  • h_field_col (str | None)

  • dynamic_features (Iterable[str] | None)

  • static_features (Iterable[str] | None)

  • future_features (Iterable[str] | None)

  • scaled_ml_numeric_cols (Iterable[str] | None)

  • main_scaler_path (str | None)

  • scaler_info (dict | None)

  • save_dir (str | None)

  • table_width (int)

  • title_prefix (str)

  • city (str)

  • model_name (str)

  • sample_rows (int)

Return type:

str | None

geoprior.utils.should_audit(audit_stages, *, stage, default=None)[source]

Convenience helper that decides whether stage should be audited, given the audit_stages setting.

Parameters:
  • audit_stages (Any)

  • stage (str)

  • default (Any)

Return type:

bool

geoprior.utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)[source]

Format PINN forecasts into evaluation and future DataFrames.

This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:

  • df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).

  • df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.

Parameters:
  • y_pred (dict) –

    Dictionary of model predictions, i.e. the output of GeoPriorSubsNet.predict post-processed into {'subs_pred': ..., 'gwl_pred': ...}.

    For subsidence, the expected shapes are:

    • Quantile mode: (B, H, Q, O) where: B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.

    • Point mode: (B, H, O).

  • y_true (dict or None) –

    Dictionary of true targets, typically

    {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}.

    If None, evaluation DataFrame is still created but without the actual-value column.

  • coords (ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped (B, H, 3) with columns [t_scaled, x_scaled, y_scaled]. Only x and y are used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.

  • quantiles (list of float or None, optional) – List of quantiles (e.g. [0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. If None, a single prediction column is emitted instead.

  • target_name (str, default 'subsidence') –

    Logical target identifier used as the default key for locating the target scaler in scaler_info and as a fallback for resolving truth arrays in y_true.

    Column naming is controlled by output_target_name (or the auto-derived output prefix when it is None).

  • output_target_name (str or None, optional) –

    Output prefix used when creating DataFrame columns for predictions and actuals.

    This controls the column naming only (e.g. the function will emit f"{output_target_name}_q10", f"{output_target_name}_pred", and f"{output_target_name}_actual").

    If None (default), the function derives the output prefix from target_name and applies a small convenience rule: if target_name ends with "_cum" or "_cumulative", that suffix is stripped for output naming.

    This keeps downstream tooling consistent (many plotting and metrics utilities expect names like subsidence_q10 rather than subsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, with target_name="subsidence_cum" and output_target_name=None, output columns become subsidence_q10, subsidence_q50, and subsidence_actual. If output_target_name="subsidence_cum", the output columns keep the suffix such as subsidence_cum_q10.

  • scaler_target_name (str or None, optional) –

    Name used to locate the target scaling block inside scaler_info and to perform inverse-transform for predictions and actuals.

    This controls the scaler key and inverse scaling, not the output column naming.

    If None (default), the scaler key is assumed to be target_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.

    A common pattern is to keep target_name="subsidence_cum" so the scaler lookup matches the Stage-1 schema, while letting output_target_name=None produce clean output columns. In that setup, inverse transform still uses the subsidence_cum scaler key, while output columns use the subsidence_ prefix because of the auto-strip rule.

  • target_key_pred (str, default 'subs_pred') – Key inside y_pred that holds the subsidence forecasts.

  • component_index (int, default 0) – Index along the output dimension O to use when output_subsidence_dim > 1. For scalar subsidence this is 0.

  • scaler_info (dict, optional) – Optional Stage-1 scaler_info mapping containing a target scaler under keys such as 'targets' or 'target'. The target block is expected to provide an sklearn-like transformer under 'scaler' together with column names under 'columns' or 'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed for target_name.

  • coord_scaler (object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transform coord_x and coord_y when coords is given and coord_columns can be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.

  • coord_columns (tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.

  • train_end_time (scalar or str or datetime, optional) – Physical time associated with the evaluation year (e.g. 2022). If eval_forecast_step is not given, the last horizon step is assumed to correspond to this time.

  • forecast_start_time (scalar or str or datetime, optional) – First time in the future forecast horizon (e.g. 2023).

  • forecast_horizon (int, optional) – Number of forecast steps in the future horizon (e.g. 3). If future_time_grid is not given, this is used together with forecast_start_time to build a regular grid.

  • future_time_grid (array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be [2023, 2024, 2025]. If provided, it overrides any automatic construction from forecast_start_time and forecast_horizon.

  • eval_forecast_step (int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.

  • eval_export ({"all", "last"} or str or int or sequence, optional) –

    Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).

    Accepted values are:

    • "all" or "full" or "horizons" : export all horizons from df_eval_all.

    • "last" or "single" or "default" : export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).

    • Other str (e.g. "2022") : interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.

    • int or scalar non-string : interpreted as a single time value (e.g. 2022).

    • sequence of values (e.g. [2021, 2022]) : interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.

    If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.

  • value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) –

    Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.

    Supported modes are:

    • "rate" : keep per-step predictions as provided by the model (current behaviour).

    • "cumulative" or "cum" : convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.

    • "absolute_cumulative" or "abs_cum" or "absolute" : same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.

    Cumulative transforms are applied consistently to:

    • the future forecast DataFrame (df_future),

    • the multi-horizon evaluation DataFrame (df_eval_all),

    • and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).

    When an unsupported string is given, the function logs a warning and falls back to "rate".

  • absolute_baseline (float or Mapping[int, float], optional) –

    Baseline value to use when value_mode requests absolute cumulative outputs ("absolute_cumulative", "abs_cum", "absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence at train_end_time (e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.

    If a scalar float is provided, the same baseline value is added to all samples. If a mapping is provided, it must map sample_idx (integers) to baseline values, allowing per-sample baselines:

    • absolute_baseline = {sample_idx: baseline_value, ...}

    Only prediction columns for target_name are shifted (e.g. "subsidence_q10", "subsidence_q50", "subsidence_q90" or "subsidence_pred"). When df_eval_all is present, the corresponding "<target_name>_actual" column is shifted as well, so evaluation metrics operate on absolute cumulative values.

    If value_mode is an absolute cumulative variant but absolute_baseline is None, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).

  • sample_index_offset (int, default 0) – Offset added to sample_idx (useful when concatenating multiple tiles).

  • city_name (str, optional) – Optional metadata used only for logging.

  • model_name (str, optional) – Optional metadata used only for logging.

  • dataset_name (str, optional) – Optional metadata used only for logging.

  • csv_eval_path (str, optional) – If provided, df_eval is written to this path (directories are created if needed).

  • csv_future_path (str, optional) – If provided, df_future is written to this path.

  • time_as_datetime (bool, default False) – If True, time values are converted using pandas.to_datetime() with the provided time_format (if any).

  • time_format (str or None, optional) – Optional format string passed to pandas.to_datetime() when time_as_datetime=True.

  • eval_metrics (bool, default False) – If True, automatically call evaluate_forecast() on the resulting df_eval to compute diagnostics. Metrics are not returned by this function; they are either written to disk (if metrics_savefile is provided) or discarded. For programmatic access to the metrics dictionary, call evaluate_forecast() directly.

  • metrics_column_map (mapping, optional) – Optional column mapping forwarded to evaluate_forecast() (see its documentation for details). If None, default column names such as 'coord_t', 'forecast_step', f'{target_name}_q10', and f'{target_name}_actual' are assumed.

  • metrics_quantile_interval (tuple of float, default (0.1, 0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded to evaluate_forecast().

  • metrics_per_horizon (bool, default False) – If True, per-horizon MAE/MSE/R² are computed by evaluate_forecast() and included in the diagnostics.

  • metrics_extra (sequence or mapping, optional) –

    Optional additional metrics to compute, forwarded to evaluate_forecast(). Can be:

    • A sequence of metric names (resolved via geoprior.metrics._registry.get_metric).

    • A mapping {name: func} where func is a callable taking (y_true, y_pred, **kwargs).

  • metrics_extra_kwargs (mapping, optional) – Optional per-metric keyword arguments, forwarded to evaluate_forecast(). Keys must match metric names in metrics_extra.

  • metrics_savefile (str, path-like, bool, or None) – If truthy, diagnostics from evaluate_forecast() are written to disk. Behavior matches the savefile argument of evaluate_forecast(). When True, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.

  • metrics_save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format for diagnostics written by evaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.

  • metrics_time_as_str (bool, default True) – If True, time keys in the diagnostics written by evaluate_forecast() are converted to strings (useful for JSON serialization).

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Logger instance; if None, a module-level LOG is used.

  • input_value_mode (str)

  • rate_first (str)

  • calibration (str | bool)

  • calibration_kwargs (Mapping[str, Any] | None)

  • calibration_save_stats (str | PathLike | None)

  • output_unit (str | None)

  • output_unit_from (str)

  • output_unit_mode (str)

  • output_unit_suffix (str)

  • output_unit_col (str | None)

Returns:

  • df_eval_to_write (pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time. Columns include:

    • 'sample_idx'

    • 'forecast_step'

    • quantile columns (e.g. subsidence_q10) or subsidence_pred

    • 'subsidence_actual' (if y_true given)

    • coord_t, coord_x, coord_y (names from coord_columns).

  • df_future (pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure as df_eval but without the actual-value column.

Return type:

tuple[DataFrame, DataFrame]
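The cumulative value modes reduce to a grouped cumulative sum over forecast_step followed by an optional baseline shift. This sketch assumes a long-format DataFrame with sample_idx and forecast_step columns. to_cumulative is a hypothetical helper, and exactly which columns format_and_forecast transforms internally (for example the actual column in df_eval_all) is not reproduced here; pass the columns you want shifted in pred_cols.

```python
import pandas as pd


def to_cumulative(df, pred_cols, baseline=None):
    """Hypothetical sketch of value_mode="cumulative" / "absolute_cumulative".

    Cumulative-sums per-step rates over forecast_step within each sample_idx,
    then optionally adds a scalar or per-sample baseline.
    """
    out = df.sort_values(["sample_idx", "forecast_step"]).copy()
    # Relative cumulative values: running sum of rates per sample.
    out[pred_cols] = out.groupby("sample_idx")[pred_cols].cumsum()
    if baseline is not None:
        if isinstance(baseline, dict):
            # Per-sample baseline: {sample_idx: value}; unmapped samples become NaN.
            shift = out["sample_idx"].map(baseline)
        else:
            shift = float(baseline)
        for col in pred_cols:
            out[col] = out[col] + shift
    return out
```

With rates of 1.0, 2.0, 3.0 for one sample, the relative cumulative trajectory is 1.0, 3.0, 6.0, and a baseline of 10.0 turns it into 11.0, 13.0, 16.0.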

Notes

This function separates scaler lookup (scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like "_cum" but downstream tools expect canonical names such as columns prefixed with subsidence_.
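The documented naming rules can be summarized in a few lines: an explicit output_target_name wins, otherwise a trailing "_cum" or "_cumulative" is stripped from target_name, and the scaler key defaults to target_name. This is a sketch of those documented rules only; the helper names are hypothetical.

```python
def output_prefix(target_name, output_target_name=None):
    """Hypothetical sketch of the documented output-prefix rule."""
    if output_target_name is not None:
        return output_target_name  # Explicit prefix wins; suffix is kept.
    for suffix in ("_cumulative", "_cum"):
        if target_name.endswith(suffix):
            return target_name[: -len(suffix)]
    return target_name


def resolve_names(target_name, output_target_name=None, scaler_target_name=None):
    """Return (column prefix, scaler key) following the documented defaults."""
    prefix = output_prefix(target_name, output_target_name)
    scaler_key = target_name if scaler_target_name is None else scaler_target_name
    return prefix, scaler_key
```

With target_name="subsidence_cum", this yields columns such as subsidence_q10 while the inverse transform still looks up the subsidence_cum scaler block.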

geoprior.utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]

Evaluate forecast diagnostics from an evaluation DataFrame.

This helper consumes the df_eval output from format_and_forecast() (or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, R², coverage, and sharpness. It can optionally also evaluate metrics per forecast horizon and apply additional user-defined metrics.

By default it expects the following columns:

  • 'sample_idx'

  • 'forecast_step'

  • 'coord_t' (time)

  • Quantile or point-prediction columns for the target, e.g.:

    • Quantile mode: f'{target_name}_q10', f'{target_name}_q50', f'{target_name}_q90', …

    • Point mode: f'{target_name}_pred'.

  • Actual column: f'{target_name}_actual'.

A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:

column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}

or, for quantile columns:

column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}

Parameters:
  • eval_data (str, path-like, or pandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved by format_and_forecast()) or an in-memory DataFrame.

  • target_name (str, default 'subsidence') – Base name for the target columns. Used to infer default column names such as f'{target_name}_q10', f'{target_name}_pred', and f'{target_name}_actual'.

  • column_map (dict, optional) –

    Optional mapping to override default column names. The following keys are recognized:

    • 'sample_idx' : sample index column name (default 'sample_idx').

    • 'forecast_step' : horizon index column name (default 'forecast_step').

    • 'coord_t' : time coordinate column (default 'coord_t').

    • 'actual' : name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.

    • 'pred' : point prediction column for non-quantile mode, default f'{target_name}_pred'.

    • 'quantiles' :

      • If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).

      • If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.

  • quantile_interval (tuple of float, default (0.1, 0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.

  • per_horizon (bool, default False) – If True, compute per-horizon MAE/MSE/R² grouped by the forecast_step column.

  • extra_metrics (sequence of str or mapping, optional) –

    Optional additional metrics to compute.

    • If a sequence of strings (e.g. ['pss', 'pit']), each name is resolved via geoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.

    • If a mapping {name: func}, each func is called as:

      func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
      

      where y_pred is the median (Q50) or point forecast.

    For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.

  • extra_metric_kwargs (mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names in extra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.

  • savefile (str, path-like, or bool, optional) –

    If provided, metrics are saved to disk.

    • If True: a filename is auto-generated near eval_data (if it is a path) or in the current working directory.

    • If a string/path without extension: the extension is taken from save_format.

    • If a string/path with extension: that extension takes precedence over save_format.

  • save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') –

    Output format when savefile is truthy. JSON preserves nested structure; CSV is flattened into a tall table.

    • For JSON, the function returns the metrics dictionary.

    • For CSV, the function returns the metrics DataFrame.

  • time_as_str (bool, default True) – If True, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Optional logger instance used by vlog().

  • overall_key (str | None)
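When column_map['quantiles'] is a sequence of column names, the quantile level is inferred from the column suffix. The sketch below assumes the f'{target_name}_q{int(q*100):d}' pattern documented above and matches it with a regular expression. infer_quantile_map is a hypothetical name, and the exact matching rules inside evaluate_forecast may differ.

```python
import re

# Trailing "_q<digits>" suffix, e.g. "_q10" or "_q90".
_Q_SUFFIX = re.compile(r"_q(\d+)$")


def infer_quantile_map(columns, target_name="subsidence"):
    """Hypothetical sketch: map quantile level -> column from `_qNN` suffixes.

    "subsidence_q10" -> 0.1; columns without the target prefix are ignored.
    """
    qmap = {}
    for col in columns:
        if not col.startswith(f"{target_name}_q"):
            continue
        match = _Q_SUFFIX.search(col)
        if match:
            qmap[int(match.group(1)) / 100.0] = col
    # Return levels in ascending order.
    return dict(sorted(qmap.items()))
```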

Returns:

results – If save_format is JSON (default), returns a dict:

  • Single time value:

    {
        "overall_mae": ...,
        "overall_mse": ...,
        "overall_r2": ...,
        "coverage80": ...,
        "sharpness80": ...,
        "per_horizon_mae": {1: ..., 2: ..., ...},
        ...
    }
    
  • Multiple time values:

    {
        "2021": { ...metrics... },
        "2022": { ...metrics... },
    }
    

If save_format is CSV, returns a DataFrame with flattened rows:

  • Columns include: coord_t, metric, horizon, and value.

Return type:

dict or pandas.DataFrame

Notes

  • Default metrics in quantile mode:

    • overall_mae, overall_mse, overall_r2

    • coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)

    If per_horizon=True, also:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).

  • Default metrics in point mode (no quantiles):

    • mae, mse, r2

    And optionally, if per_horizon=True:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2.
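For the quantile-mode interval diagnostics, a common reading is that coverage is the fraction of actual values falling inside [Q_lower, Q_upper] and sharpness is the mean interval width. The sketch below assumes those definitions (inclusive bounds, unweighted mean). coverage_and_sharpness is a hypothetical helper, not the evaluate_forecast internals.

```python
import numpy as np


def coverage_and_sharpness(y_true, q_lo, q_hi):
    """Hypothetical sketch of interval diagnostics for one quantile pair.

    coverage: fraction of observations inside [q_lo, q_hi];
    sharpness: mean interval width (smaller means sharper).
    """
    y_true, q_lo, q_hi = (
        np.asarray(a, dtype=float) for a in (y_true, q_lo, q_hi)
    )
    inside = (y_true >= q_lo) & (y_true <= q_hi)
    return float(inside.mean()), float(np.mean(q_hi - q_lo))
```

With Q10/Q90 columns, the first value plays the role of coverage80 and the second of sharpness80.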

geoprior.utils.default_results_dir(start=None, env_var='RESULTS_DIR', folder_name='results', create=False)[source]

Resolve the canonical ‘results’ directory with robust fallbacks.

Return type:

str

geoprior.utils.ensure_directory_exists(path)[source]

Ensure that a directory exists at the given path, creating it if needed.

This function checks whether the provided path exists and is a directory. If the path does not exist, it attempts to create the directory (including any necessary parent directories). If a file with the same name already exists, or if creation fails, an exception is raised.

Parameters:

path (str or pathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.

Returns:

A Path object pointing to the existing (or newly created) directory.

Return type:

pathlib.Path

Raises:
  • TypeError – If path is not a string or pathlib.Path.

  • FileExistsError – If a file (not a directory) already exists at path.

  • OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).

Examples

>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ensure_directory_exists
>>> output_dir = ensure_directory_exists("data/output")
>>> isinstance(output_dir, Path)
True
>>> # The directory "data/output" now exists on disk.

Notes

  • Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood for cross-platform compatibility.

  • If path already exists as a directory, this function returns immediately without modifying it.

See also

pathlib.Path.mkdir

Method to create a directory.

os.makedirs

Legacy function for creating directories recursively.

geoprior.utils.getenv_stripped(name, default=None, allow_empty=False)[source]

Read an environment variable and strip whitespace robustly.

Parameters:
  • name (str) – Environment variable name to read.

  • default (str or None, optional) – Value returned when the environment variable is not set.

  • allow_empty (bool, default False) – If False, empty strings are treated as missing and default is returned instead. If True, an empty string is returned unchanged.

Returns:

out – The stripped string value, the empty string (if allowed), or default when unset / empty.

Return type:

str or None
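The documented semantics of getenv_stripped can be restated as the following minimal standard-library sketch. getenv_stripped_sketch is a hypothetical name, and treating whitespace-only values as empty is an assumption.

```python
import os
from typing import Optional


def getenv_stripped_sketch(name: str, default: Optional[str] = None,
                           allow_empty: bool = False) -> Optional[str]:
    """Hypothetical re-statement of the documented getenv_stripped rules."""
    raw = os.environ.get(name)
    if raw is None:
        # Variable not set at all: fall back to the default.
        return default
    value = raw.strip()
    if value == "" and not allow_empty:
        # Empty (or whitespace-only, by assumption) counts as missing.
        return default
    return value
```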

geoprior.utils.print_config_table(sections, title=None, table_width=None, sort_keys=True, key_col_fraction=0.35, max_value_length=200, log_fn=None)[source]

Pretty-print configuration or hyperparameters as a key/value table.

This helper is intended for CLI scripts (Stage-1, training, tuning) so that the user can quickly inspect which parameters are actually in effect.

Parameters:
  • sections (dict or sequence of (str, dict)) –

    If a single dict is passed, all key/value pairs are printed in one block.

    If a sequence is passed, it must contain (name, params) tuples, where name is a section label (e.g. "Physics") and params is a dict mapping parameter names to values.

  • title (str, optional) – Optional title displayed above the table (centered).

  • table_width (int, optional) – Total width of the printed table. If None, the function tries to use geoprior.api.util.get_table_size(). If that fails, it falls back to the terminal width (via shutil.get_terminal_size) or 80 characters.

  • sort_keys (bool, default True) – Whether to sort parameter names alphabetically within each section.

  • key_col_fraction (float, default 0.35) – Fraction of the table width allocated to the parameter-name column. The remainder is used for the value column.

  • max_value_length (int, default 200) – Maximum number of characters kept from the stringified value. Longer values are truncated with an ellipsis ("...") before being wrapped onto multiple lines.

  • log_fn (callable, optional) – Function used to emit lines (defaults to print()). This allows capturing the table in logs if needed.

Returns:

The full rendered table as a single string. It is always printed via log_fn as a side effect.

Return type:

str

Notes

  • Nested containers (lists, tuples, dicts) are rendered in a compact one-line form and then wrapped to fill the value column.

  • This function is intentionally lightweight and does not depend on external tabulation libraries, so it can be safely used in lightweight Stage-1 / Stage-2 scripts.

geoprior.utils.save_all_figures(output_dir='figures', prefix='figure', fmts=('png',), close=True, dpi=150, transparent=False, timestamp=True, verbose=True)[source]

Save all currently open Matplotlib figures to disk in specified formats.

Parameters:
  • output_dir (str) – Directory where figures will be saved. Created if it does not exist.

  • prefix (str) – Filename prefix for each figure.

  • fmts (list[str] | tuple) – File formats/extensions to use (e.g., ('png', 'pdf')).

  • close (bool) – Whether to close each figure after saving. Default is True.

  • dpi (int or None) – Resolution in dots per inch. None uses the Matplotlib default.

  • transparent (bool) – Whether to save figures with a transparent background.

  • timestamp (bool) – Whether to append the current timestamp (YYYYmmddTHHMMSS) to filenames.

  • verbose (bool) – Whether to print progress messages.

Returns:

List of saved file paths.

Return type:

List[str]

Examples

>>> import matplotlib.pyplot as plt
>>> plt.figure(); plt.plot([1, 2, 3])
>>> from geoprior.utils.generic_utils import save_all_figures
>>> paths = save_all_figures(output_dir="plots", fmts=("png",))
>>> print(paths)
['plots/figure_1_20250521T153045.png']

geoprior.utils.build_censor_mask(xb, H, idx, thresh=0.5, *, source='dynamic', reduce_time='any', align='broadcast')[source]

Build a censor mask aligned to the forecast horizon: (B, H, 1).

Parameters:
  • source ({"dynamic", "future"}, default "dynamic") – Selects where the censoring flag is read from. "dynamic" reads xb["dynamic_features"][:, :, idx] from the history window, while "future" reads xb["future_features"][:, :, idx] from the forecast window.

  • reduce_time ({"any", "last", "all"}, default "any") – Reduction applied when source="dynamic" and the censor flag behaves like a per-sample label. "any" marks the sample as censored if any history step is flagged, "last" uses only the last history step, and "all" requires every history step to be flagged.

  • align ({"broadcast", "crop", "pad_false", "pad_edge", "error"}, default "broadcast") – Policy used when the time axis does not already match the forecast horizon H. "broadcast" repeats a single-step label across all horizon steps, "crop" keeps the last H steps, "pad_false" pads missing steps with False, "pad_edge" repeats the last available step, and "error" raises on mismatch.

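  • xb (dict) – Batch of model inputs containing dynamic_features and/or future_features.

  • H – Forecast horizon length; the returned mask has shape (B, H, 1).

  • idx (int | None) – Channel index of the censoring flag within the selected feature tensor.

  • thresh (float) – Threshold used to turn the flag channel into a boolean mask.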

Return type:

Tensor
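The reduce_time and align policies above can be illustrated on plain boolean arrays. The sketch below is a hypothetical NumPy reimplementation, not geoprior code. It takes flags shaped (B, T) that are already thresholded. Two choices are assumptions: the history window is always reduced when the flags come from the dynamic source, and padding is appended at the end.

```python
import numpy as np


def align_censor_flags(flags, H, *, from_history=True,
                       reduce_time="any", align="broadcast"):
    """Hypothetical sketch of the reduce_time / align policies.

    flags : bool array shaped (B, T), i.e. the censor channel after
            thresholding. Returns a bool array shaped (B, H, 1).
    """
    flags = np.asarray(flags, dtype=bool)
    if from_history:
        # source="dynamic": collapse the history steps into one label per
        # sample (this sketch always reduces; an assumption).
        if reduce_time == "any":
            flags = flags.any(axis=1, keepdims=True)
        elif reduce_time == "last":
            flags = flags[:, -1:]
        elif reduce_time == "all":
            flags = flags.all(axis=1, keepdims=True)
        else:
            raise ValueError(f"unknown reduce_time: {reduce_time!r}")
    T = flags.shape[1]
    if T != H:
        if align == "broadcast" and T == 1:
            flags = np.repeat(flags, H, axis=1)      # repeat the label
        elif align == "crop" and T > H:
            flags = flags[:, -H:]                     # keep the last H steps
        elif align == "pad_false" and T < H:
            pad = np.zeros((flags.shape[0], H - T), dtype=bool)
            flags = np.concatenate([flags, pad], axis=1)
        elif align == "pad_edge" and T < H:
            edge = np.repeat(flags[:, -1:], H - T, axis=1)
            flags = np.concatenate([flags, edge], axis=1)
        else:
            raise ValueError(f"cannot align {T} steps to H={H} "
                             f"with align={align!r}")
    return flags[:, :, None]
```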

geoprior.utils.ensure_input_shapes(x, mode, forecast_horizon)[source]

Ensure presence of zero-width static/future placeholders.

Stage-1 exporters sometimes omit static_features or future_features when there are no static/future variables for a particular experiment. Keras, however, expects these inputs to exist so that the input signature remains stable.

This helper:

  • Copies the input dict to avoid in-place modification.

  • Ensures static_features is an array of shape (N, 0) if missing.

  • Ensures future_features is an array of shape (N, T_future, 0) if missing, where:

    • T_future = dynamic_features.shape[1] when mode == "tft_like" (past+future style).

    • Otherwise, T_future = forecast_horizon.

Parameters:
  • x (dict) – Dictionary containing at least dynamic_features with shape (N, T_dyn, D_dyn).

  • mode (str) – Model mode. When "tft_like" the future sequence length is inferred from the dynamic sequence.

  • forecast_horizon (int) – Forecast horizon in time steps/years for non-TFT modes.

Returns:

Shallow copy of x with guaranteed static_features and future_features entries.

Return type:

dict
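The placeholder rule above can be sketched in a few lines of NumPy. ensure_input_shapes_sketch is hypothetical, and treating a key whose value is None as "missing" is an assumption.

```python
import numpy as np


def ensure_input_shapes_sketch(x, mode, forecast_horizon):
    """Hypothetical sketch of the zero-width placeholder logic above."""
    out = dict(x)  # shallow copy: the caller's dict is left untouched
    dyn = np.asarray(out["dynamic_features"])
    n = dyn.shape[0]
    if out.get("static_features") is None:
        # (N, 0): one empty static vector per sample.
        out["static_features"] = np.zeros((n, 0), dtype=dyn.dtype)
    if out.get("future_features") is None:
        # "tft_like" spans the dynamic window; other modes span the horizon.
        t_future = dyn.shape[1] if mode == "tft_like" else forecast_horizon
        out["future_features"] = np.zeros((n, t_future, 0), dtype=dyn.dtype)
    return out
```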

geoprior.utils.extract_preds(model, out, *, strict=True, output_names=None)[source]

Extract (subs_pred, gwl_pred) from GeoPrior outputs.

Supports:
  1. v3.2+ call(): {"subs_pred", "gwl_pred"}

  2. forward_with_aux(): (y_pred, aux)

  3. legacy: {"data_final"} + model.split_data_predictions

  4. predict(): list/tuple mapped via output names

If strict=True, list/tuple outputs must be mappable via output names. If they are not, the function raises an error instead of risking a silent swap of the subsidence and GWL predictions.

This helper normalizes the output interface across two GeoPrior generation families:

  1. New interface (preferred) model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}

  2. Legacy interface (backward compatible) model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.

Parameters:
  • model (object) –

    A Keras-like model instance that may expose split_data_predictions(data_final).

    The splitter must return a tuple:

    • subs_pred with shape (B, H, 1) or (B, H, Q, 1)

    • gwl_pred with shape (B, H, 1) or (B, H, Q, 1)

  • out (dict) –

    Output returned by the model call, typically model(inputs, training=False).

    Supported keys are either:

    • {"subs_pred", "gwl_pred"} (new interface), or

    • {"data_final"} (legacy interface).

  • strict (bool)

  • output_names (Sequence[str] | None)

Returns:

  • subs_pred (Tensor) – Predicted subsidence in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

  • gwl_pred (Tensor) – Predicted groundwater/head variable in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

Raises:
  • KeyError – If out does not contain a supported key set.

  • TypeError – If out is not a mapping/dict-like object.

Return type:

tuple[Any, Any]

Notes

This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.

The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].

Examples

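New interface (out holds "subs_pred" and "gwl_pred"):

from geoprior.utils import extract_preds

out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)

Legacy interface (out holds "data_final", split via model_inf.split_data_predictions):

out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)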

See also

subs_point_from_out

Convert subsidence predictions to a point forecast.
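The two dict interfaces described above can be dispatched as shown in this minimal sketch. extract_preds_sketch and LegacyModelStub are hypothetical, and the stub's channel layout (subsidence first, groundwater second) is assumed for illustration only. The tuple and list cases handled by the real function are not covered here.

```python
import numpy as np


class LegacyModelStub:
    """Hypothetical stand-in for a legacy checkpoint."""

    def split_data_predictions(self, data_final):
        # Assumed channel layout for this stub only:
        # channel 0 = subsidence, channel 1 = groundwater.
        return data_final[..., :1], data_final[..., 1:2]


def extract_preds_sketch(model, out):
    """Hypothetical sketch covering only the two dict interfaces."""
    if not isinstance(out, dict):
        raise TypeError("out must be a dict in this sketch")
    if "subs_pred" in out and "gwl_pred" in out:
        # New interface: the tensors are already separated.
        return out["subs_pred"], out["gwl_pred"]
    if "data_final" in out:
        # Legacy interface: let the model do the split instead of
        # slicing data_final by hand.
        return model.split_data_predictions(out["data_final"])
    raise KeyError("expected {'subs_pred', 'gwl_pred'} or {'data_final'}")
```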

geoprior.utils.load_nat_config(root='nat.com')[source]

High-level helper used by NATCOM scripts to load the active configuration as a dictionary of settings such as CITY_NAME and TIME_STEPS.

Example

>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]

Return type:

dict[str, Any]

geoprior.utils.load_nat_config_payload(root='nat.com')[source]

Return the full config.json payload, including city, model and __meta__ fields.

This is useful when you also need to know which config hash, city, or model is currently active.

Return type:

dict[str, Any]

geoprior.utils.load_scaler_info(encoders_block)[source]

Load the scaler_info mapping from an encoders block.

Stage-1 exporters typically store a compact description of the scalers used to normalise the data. In many cases this takes the form:

encoders = {
    "main_scaler": "/path/to/minmax.joblib",
    "coord_scaler": "/path/to/coords.joblib",
    "scaler_info": "/path/to/scaler_info.joblib",
    ...
}

where scaler_info is either a path to a joblib file or an already-loaded dictionary.

This helper returns a dictionary regardless of how it was stored, making downstream formatting/evaluation code simpler.

Parameters:

encoders_block (dict) – The encoders part of the Stage-1 manifest (M["artifacts"]["encoders"]).

Returns:

The loaded scaler_info dictionary, or None if not present / not loadable.

Return type:

dict or None

geoprior.utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]

Build a tf.data.Dataset using NATCOM conventions.

Steps:
  1. ensure_input_shapes(…) for X.

  2. map_targets_for_training(…) for y.

  3. tf.data pipeline (shuffle/batch/prefetch).

  4. Optional finite checks (NPZ + tf batches).

Parameters:
  • X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.

  • y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.

  • batch_size (int) – Number of samples per batch.

  • shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.

  • mode (str) – Model mode passed to ensure_input_shapes().

  • forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().

  • check_npz_finite (bool) – If True, checks the input and target NumPy arrays for NaN/Inf before the dataset is built.

  • check_finite (bool) – If True, inserts assert_all_finite checks inside the tf.data pipeline.

  • scan_finite_batches (int) – If > 0, eagerly scans the first N batches when the dataset is built, so that non-finite values fail early.

  • dynamic_feature_names (list[str] | None) – If provided, used to name the offending channels of dynamic_features when non-finite values are reported.

  • future_feature_names (list[str] | None) – If provided, used to name the offending channels of future_features when non-finite values are reported.

  • seed (int)

  • drop_remainder (bool)

  • reshuffle_each_iter (bool)

  • prefetch (bool)

Returns:

Dataset of (X, y) pairs.

Return type:

tf.data.Dataset

Notes

TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
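The NPZ-side finite check (check_npz_finite with feature names) can be sketched without TensorFlow as a per-channel NaN/Inf report. find_nonfinite_channels is a hypothetical helper. It assumes the last axis is the channel axis and skips zero-width placeholders.

```python
import numpy as np


def find_nonfinite_channels(arrays, channel_names=None):
    """Hypothetical pre-flight check for NaN/Inf, reported per channel.

    arrays        : dict of name -> array; the last axis is the channel axis.
    channel_names : optional dict of name -> list of channel labels.
    """
    channel_names = channel_names or {}
    report = {}
    for key, arr in arrays.items():
        arr = np.atleast_1d(np.asarray(arr, dtype=float))
        if arr.shape[-1] == 0:
            continue  # zero-width placeholders carry no values
        # Flatten every axis except the channel axis, then test each channel.
        bad = ~np.isfinite(arr.reshape(-1, arr.shape[-1])).all(axis=0)
        if bad.any():
            idx = np.flatnonzero(bad).tolist()
            labels = channel_names.get(key)
            report[key] = [labels[i] for i in idx] if labels else idx
    return report
```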

geoprior.utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]

Standardise target dictionaries to the Keras compile keys.

This helper enforces a small convention used throughout the NATCOM training scripts:

  • Upstream sequence builders typically export raw targets with keys subsidence and gwl.

  • The GeoPrior model is compiled with targets named subs_pred and gwl_pred.

This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.

Parameters:
  • y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).

  • subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.

  • gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.

  • subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.

  • gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.

Returns:

New dictionary with keys subs_pred and gwl_pred.

Return type:

dict

Raises:

KeyError – If the dictionary does not contain either of the expected key pairs.
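The key convention can be restated as the short sketch below. map_targets_sketch is hypothetical, and checking the already-standardised keys first is an assumption about precedence.

```python
def map_targets_sketch(y_dict, subs_key="subsidence", gwl_key="gwl",
                       subs_pred_key="subs_pred", gwl_pred_key="gwl_pred"):
    """Hypothetical sketch: return a new dict keyed by the compile keys."""
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        # Already standardised (checked first here; an assumption).
        subs, gwl = y_dict[subs_pred_key], y_dict[gwl_pred_key]
    elif subs_key in y_dict and gwl_key in y_dict:
        # Raw exporter keys: rename them to the compile keys.
        subs, gwl = y_dict[subs_key], y_dict[gwl_key]
    else:
        raise KeyError(
            f"expected ({subs_key!r}, {gwl_key!r}) or "
            f"({subs_pred_key!r}, {gwl_pred_key!r}) in y_dict"
        )
    return {subs_pred_key: subs, gwl_pred_key: gwl}
```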

geoprior.utils.name_of(obj)[source]

Return a human-readable name for an object.

This utility is handy when serialising compile configurations (e.g., turning metric callables into simple strings for JSON logs).

Parameters:

obj (object) – Any Python object (function, class instance, etc.).

Returns:

obj.__name__ if present, otherwise the class name, and finally str(obj) as a last resort.

Return type:

str
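The fallback order described above (__name__, then class name, then str) fits in a few lines. name_of_sketch is a hypothetical restatement, not the library source.

```python
def name_of_sketch(obj):
    """Hypothetical sketch of the documented fallback order."""
    name = getattr(obj, "__name__", None)
    if isinstance(name, str):
        return name              # functions, classes, modules
    try:
        return type(obj).__name__  # instances: use the class name
    except Exception:
        return str(obj)          # last resort
```

For example, a metric function such as len maps to "len", and an instance of a class Foo maps to "Foo".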

geoprior.utils.resolve_hybrid_config(manifest_cfg, live_cfg, verbose=True)[source]

Merge Manifest config (Data Authority) with Live config (Physics Authority).

Return type:

dict

geoprior.utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
Parameters:
  • cfg (dict)

  • scaler_info (dict)

  • target_name (str)

  • prefix (str)

  • unit_factor_key (str)

  • scale_key (str)

  • bias_key (str)

geoprior.utils.best_epoch_and_metrics(history, monitor='val_loss')[source]

Return the best epoch and metrics at that epoch.

Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:

  • The epoch index (0-based).

  • A dictionary mapping each metric name to its value at that epoch.

Parameters:
  • history (dict) – The history.history attribute from Keras training.

  • monitor (str, default "val_loss") – Name of the metric to minimise.

Returns:

  • best_epoch (int or None) – Index of the best epoch, or None if monitor is not present.

  • metrics_at_best (dict) – Mapping from metric name to its value at the best epoch. Empty if monitor is not present.

Return type:

tuple[int | None, dict]
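The selection rule can be restated on a plain History.history-style dict. best_epoch_sketch is hypothetical. It returns the first minimum on ties, which is an assumption about tie handling.

```python
def best_epoch_sketch(history, monitor="val_loss"):
    """Hypothetical sketch: index of the minimum of history[monitor]."""
    values = history.get(monitor)
    if not values:
        return None, {}
    # min() over indices returns the first minimum on ties.
    best = min(range(len(values)), key=values.__getitem__)
    metrics = {name: series[best] for name, series in history.items()
               if best < len(series)}
    return best, metrics
```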

geoprior.utils.subs_point_from_out(model, out, quantiles=None, med_idx=None)[source]

Convert model output into a subsidence point forecast.

This helper produces a subsidence tensor shaped (B, H, 1) in model space, regardless of whether the model emits quantiles or a point prediction.

  • If quantiles are present and the subsidence prediction is shaped (B, H, Q, 1), the function selects the median quantile slice.

  • Otherwise, it returns the point prediction directly.

Parameters:
  • model (object) – A Keras-like model instance passed to extract_preds().

  • out (dict) –

    Output returned by the model call.

    This can be either the new interface with keys "subs_pred" and "gwl_pred", or the legacy interface with key "data_final".

  • quantiles (sequence of float or None, default None) –

    Quantile levels used by the model, such as [0.1, 0.5, 0.9].

    If provided, the function may use it to interpret the rank-4 quantile output and select the median.

    If None, quantile selection is disabled unless med_idx is explicitly provided and the tensor rank indicates quantiles.

  • med_idx (int or None, default None) –

    Index along the quantile axis to use as the “point” forecast when quantiles are available.

    If None and quantiles is provided, the function selects the index closest to 0.5.

Returns:

subs_point – Subsidence point prediction in model space with shape (B, H, 1).

Return type:

Tensor

Raises:
  • ValueError – If subsidence prediction is missing or None.

  • ValueError – If a quantile tensor is detected but a valid median index cannot be resolved.

Notes

Quantile outputs are assumed to be shaped (B, H, Q, 1) where the quantile axis is the third dimension (axis=2).

If the model returns point predictions already, the function is effectively a no-op. The quantile interpretation used here follows Koenker and Bassett [25].
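The median-selection rule can be illustrated with a small NumPy sketch that works on an already-extracted subsidence array rather than on raw model output. median_slice is a hypothetical helper.

```python
import numpy as np


def median_slice(subs_pred, quantiles=None, med_idx=None):
    """Hypothetical sketch: (B, H, Q, 1) -> (B, H, 1) via the median slice."""
    subs_pred = np.asarray(subs_pred)
    if subs_pred.ndim == 3:
        return subs_pred  # already a point forecast (B, H, 1)
    if subs_pred.ndim != 4:
        raise ValueError(f"unexpected rank {subs_pred.ndim}")
    if med_idx is None:
        if quantiles is None:
            raise ValueError("cannot resolve the median quantile index")
        # Index of the quantile level closest to 0.5.
        med_idx = int(np.argmin(np.abs(np.asarray(quantiles) - 0.5)))
    return subs_pred[:, :, med_idx, :]  # quantile axis is axis=2
```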

Examples

Quantile model:

out = model_inf(xb, training=False)
s_point = subs_point_from_out(
    model_inf,
    out,
    quantiles=[0.1, 0.5, 0.9],
)

Point model:

out = model_inf(xb, training=False)
s_point = subs_point_from_out(
    model_inf,
    out,
)

See also

extract_preds

Normalize outputs across new and legacy checkpoints.

geoprior.utils.serialize_subs_params(params, cfg=None)[source]

Make GeoPrior subnet parameters JSON-friendly.

The training scripts typically pass a dictionary of model construction arguments, e.g. subsmodel_params, which contains objects such as LearnableMV or FixedGammaW that are not directly JSON-serialisable.

This helper replaces those objects by small dictionaries describing their type and scalar value, optionally using values from the NATCOM config dictionary.

Parameters:
  • params (dict) – Dictionary of model init parameters (e.g. subsmodel_params in training_NATCOM_GEOPRIOR.py).

  • cfg (dict, optional) –

    NATCOM config dictionary. If provided, scalar values are taken from:

    • GEOPRIOR_INIT_MV

    • GEOPRIOR_INIT_KAPPA

    • GEOPRIOR_GAMMA_W

    • GEOPRIOR_H_REF

    and used as the authoritative numbers.

Returns:

Copy of params where scalar GeoPrior parameters are replaced by JSON-friendly dictionaries.

Return type:

dict

Notes

This function does not import any of the GeoPrior classes. It only introspects attributes like initial_value or value when the corresponding config entry is missing.

geoprior.utils.save_ablation_record(outdir, city, model_name, cfg, eval_dict, phys_diag=None, per_h_mae=None, per_h_r2=None, log_fn=None)[source]

Append a single ablation record to ablation_record.jsonl.

Each training run (e.g., different physics toggles or weights) writes one JSON line containing:

  • Basic run identifiers (city, model, timestamp).

  • Physics configuration (PDE_MODE_CONFIG, lambda weights, effective head flags, etc.).

  • Key performance metrics (R², MSE, MAE, coverage, sharpness).

  • Optional physics diagnostics (epsilon_prior, epsilon_cons).

  • Optional per-horizon MAE/R² for more detailed analysis.

Parameters:
  • outdir (str) – Base output directory for the current run. The ablation file is created under outdir / "ablation_records".

  • city (str) – City name (e.g., "nansha" or "zhongshan").

  • model_name (str) – Model identifier (e.g., "GeoPriorSubsNet").

  • cfg (dict) – Lightweight configuration dictionary containing at least the physics-related keys used below.

  • eval_dict (dict or None) – Dictionary of evaluation metrics (R², MSE, MAE, coverage80, sharpness80). If None, metrics fields default to None.

  • phys_diag (dict or None, optional) – Physics diagnostics (e.g., from evaluate()) with keys such as "epsilon_prior" and "epsilon_cons".

  • per_h_mae (dict or None, optional) – Per-horizon MAE values (e.g., keyed by year/step).

  • per_h_r2 (dict or None, optional) – Per-horizon R² values.

Return type:

None

Notes

The output file is a JSON-Lines file, so it can be loaded with load_ablation_jsonl().
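The JSON-Lines pattern (append one JSON object per line, read back line by line) looks like this in a standard-library sketch. append_jsonl_record and read_jsonl are hypothetical helpers, and the record fields are illustrative only, not the exact schema written by save_ablation_record.

```python
import json
import time
from pathlib import Path


def append_jsonl_record(outdir, record):
    """Hypothetical sketch: append one JSON object as a single line."""
    folder = Path(outdir) / "ablation_records"
    folder.mkdir(parents=True, exist_ok=True)
    path = folder / "ablation_record.jsonl"
    row = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), **record}
    # Append mode: earlier runs stay in the file untouched.
    with path.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(row) + "\n")
    return path


def read_jsonl(path):
    """Hypothetical reader: one dict per non-empty line."""
    with Path(path).open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]
```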

geoprior.utils.cumulative_to_rate(df, *, cum_col='subsidence_cum', rate_col='subsidence', time_col='year', group_cols=('longitude', 'latitude', 'city'), first='cum_over_dtref', inplace=False)[source]

Recover a rate series from cumulative displacement.

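Within each group defined by group_cols, with times sorted by time_col:

rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i,  where dt_i = t_i - t_{i-1}, for i >= 1.

The first argument sets the rate at the first time step t_0:
  • 'nan': the first rate is NaN.

  • 'cum_over_dtref': rate(t_0) = cum(t_0) / dt_ref, where dt_ref is the median time step.

Return type: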

DataFrame with rate_col added/overwritten.
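The differencing rule above is shown here for a single group in NumPy. cumulative_to_rate_1d is a hypothetical helper, and its fallback dt_ref = 1.0 for a one-point series is an assumption. The library applies the rule per group_cols on a DataFrame.

```python
import numpy as np


def cumulative_to_rate_1d(t, cum, first="cum_over_dtref"):
    """Hypothetical single-group sketch; t and cum must be time-sorted."""
    t = np.asarray(t, dtype=float)
    cum = np.asarray(cum, dtype=float)
    dt = np.diff(t)
    rate = np.empty_like(cum)
    # rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i for i >= 1
    rate[1:] = np.diff(cum) / dt
    if first == "nan":
        rate[0] = np.nan
    elif first == "cum_over_dtref":
        dt_ref = np.median(dt) if dt.size else 1.0  # fallback: assumption
        rate[0] = cum[0] / dt_ref
    else:
        raise ValueError(f"unknown first={first!r}")
    return rate
```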

geoprior.utils.normalize_gwl_alias(df, gwl_col_user, *, prefer_depth_bgs=True, verbose=True)[source]

Normalize common GWL naming aliases.

This step only renames columns; unit conversion happens later.

Return type:

tuple[DataFrame, str | None]

geoprior.utils.rate_to_cumulative(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), initial='first_equals_rate_dt', inplace=False)[source]

Build cumulative displacement from a rate series.

Return type:

DataFrame

geoprior.utils.resolve_gwl_for_physics(df, gwl_col_user, *, prefer_depth_bgs=True, allow_keep_zscore_as_ml=True, verbose=True)[source]

Pick a groundwater-level column in meters for the physics terms; a z-scored GWL column is kept for ML features only.

Return type:

tuple[str, str | None]

geoprior.utils.resolve_head_column(df, *, depth_col, head_col='head_m', z_surf_col=None, use_head_proxy=True)[source]

Resolve a head column, creating one if needed.

Return type:

tuple[str, str | None]
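For context, the standard hydrogeological relation between hydraulic head and depth to water below ground surface is head = z_surf - depth_bgs. The sketch below applies that relation when a surface-elevation column is available. head_from_depth is hypothetical. Its fallback when no elevation is given (head ≈ -depth, which keeps the ordering of water levels but not their datum) is an assumption, since this page does not state what use_head_proxy computes.

```python
import numpy as np


def head_from_depth(depth_bgs, z_surf=None, use_head_proxy=True):
    """Hypothetical sketch: hydraulic head (m) from depth below ground (m)."""
    depth = np.asarray(depth_bgs, dtype=float)
    if z_surf is not None:
        # Head = water-table elevation = surface elevation minus depth.
        return np.asarray(z_surf, dtype=float) - depth
    if use_head_proxy:
        # Assumed proxy without a datum: deeper water means lower head.
        return -depth
    raise ValueError("z_surf is required when use_head_proxy=False")
```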

geoprior.utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]

Build a coords tensor (t, x, y) with optional shifting (translation only).

This is designed for workflows that do not normalize coordinates:
  • coordinates stay in SI units (years and meters),

  • but large UTM magnitudes (e.g. 3e5, 2.5e6) are kept out of the coordinate MLPs by shifting x and y (and optionally t).

Notes

  • This does NOT min-max scale to [0, 1]; it only translates.

  • The returned coord_mins and coord_ranges are still useful for logging and debugging.

Return type:

CoordsPack
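A translation-only shift can be sketched in NumPy as follows. shift_coords is hypothetical and returns a plain (coords, shifts) tuple rather than a CoordsPack. Two rules are assumptions: an explicit *_shift_value overrides the mode, and any mode other than "min" means no shift. The subtraction is done in float64 before casting so that UTM-sized offsets do not lose precision in float32.

```python
import numpy as np


def shift_coords(t, x, y, *, time_shift="min", xy_shift="min",
                 time_shift_value=None, x_shift_value=None,
                 y_shift_value=None, dtype="float32"):
    """Hypothetical sketch: translate (t, x, y) without rescaling."""
    def _offset(arr, mode, value):
        if value is not None:
            return float(value)        # explicit offset wins (assumption)
        if mode == "min":
            return float(np.min(arr))  # shift so the minimum becomes 0
        return 0.0                     # other modes: no shift (assumption)

    t = np.asarray(t, dtype=np.float64)
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    shifts = (_offset(t, time_shift, time_shift_value),
              _offset(x, xy_shift, x_shift_value),
              _offset(y, xy_shift, y_shift_value))
    # Subtract in float64 first, then cast to the requested dtype.
    coords = np.stack([t - shifts[0], x - shifts[1], y - shifts[2]], axis=-1)
    return coords.astype(dtype), shifts
```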

geoprior.utils.compute_group_masks(df, *, group_cols, time_col, train_end_year, time_steps, horizon)[source]
Build two group masks, where T = time_steps and H = horizon:
  • valid_for_train: groups that contain every year of the last T + H years.

  • valid_for_forecast: groups that contain every year of the last T years.

This assumes annual steps and integer years in time_col.

Return type:

GroupMasks
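The completeness test can be sketched with plain sets, given the years observed for each group. group_masks_sketch is hypothetical and returns two sets of group ids rather than a GroupMasks object. Both windows are assumed to end at train_end_year.

```python
def group_masks_sketch(years_by_group, *, train_end_year, time_steps, horizon):
    """Hypothetical sketch; both windows end at train_end_year (assumption)."""
    # Training needs T history years plus H target years.
    train_years = set(range(train_end_year - (time_steps + horizon) + 1,
                            train_end_year + 1))
    # Forecasting only needs the last T history years.
    fc_years = set(range(train_end_year - time_steps + 1, train_end_year + 1))
    valid_train = {g for g, yrs in years_by_group.items()
                   if train_years <= set(yrs)}
    valid_fc = {g for g, yrs in years_by_group.items()
                if fc_years <= set(yrs)}
    return valid_train, valid_fc
```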

geoprior.utils.split_groups_holdout(groups, *, seed=42, val_frac=0.2, test_frac=0.1, strategy='random', x_col=None, y_col=None, block_size=None)[source]

Split unique groups into train/val/test (pixel-level holdout).

strategy:
  • "random": shuffle individual groups.

  • "spatial_block": shuffle spatial blocks of groups (requires x_col, y_col, and block_size).

Return type:

HoldoutSplit
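The "random" strategy can be sketched with NumPy's seeded Generator. Unique groups are permuted once and then sliced into test, validation, and training sets, so no group appears in more than one split. random_group_split is hypothetical and returns plain arrays rather than a HoldoutSplit. Rounding the split sizes with round() is an assumption.

```python
import numpy as np


def random_group_split(groups, *, seed=42, val_frac=0.2, test_frac=0.1):
    """Hypothetical sketch of the "random" strategy on unique group ids."""
    uniq = np.unique(np.asarray(groups))
    rng = np.random.default_rng(seed)  # same seed -> same split
    order = rng.permutation(len(uniq))
    n_test = int(round(test_frac * len(uniq)))
    n_val = int(round(val_frac * len(uniq)))
    test = uniq[order[:n_test]]
    val = uniq[order[n_test:n_test + n_val]]
    train = uniq[order[n_test + n_val:]]
    return train, val, test
```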

geoprior.utils.filter_df_by_groups(df, *, group_cols, groups)[source]

Keep only the rows of df whose group_cols values appear in the groups DataFrame.

Return type:

DataFrame

geoprior.utils.build_future_sequences_npz(df_scaled, *, time_col, time_col_num, lon_col, lat_col, time_steps, train_end_time=None, forecast_start_time=None, forecast_horizon=None, subs_col=None, gwl_col=None, h_field_col=None, static_features=None, dynamic_features=None, future_features=None, group_id_cols=None, mode=None, model_name=None, artifacts_dir=None, prefix='future', future_mode='auto', normalize_coords=False, coord_scaler=None, verbose=1, logger=None, stop_check=None, progress_hook=None, **kws)[source]

Build history–future sequences and save them as compressed NPZ files.

This helper constructs, for each spatial group, a sliding window of time_steps “history” points followed by a multi–step forecast horizon and exports the resulting NumPy arrays to disk. It is time-agnostic: the time_col can be numeric (e.g. year, index), year-like floats, datetimes, or strings, as long as equality on that column is meaningful.

If train_end_time, forecast_start_time, or forecast_horizon are not provided, they are inferred from the sorted unique values in df_scaled[time_col]:

  • train_end_time: by default the second-to-last unique time, leaving at least one future step.

  • forecast_start_time: by default the first time strictly after train_end_time.

  • forecast_horizon: by default one time step ahead, clipped to the number of available future points.

For each valid group, the function builds history dynamic features of shape (time_steps, n_dynamic), future features of shape (time_steps + H, n_future) when mode starts with "tft" or (H, n_future) otherwise, one static feature vector of shape (n_static,), coordinates over the horizon of shape (H, 3) with columns [time_num, lon, lat], an H_field array of shape (H, 1), and optional subsidence and groundwater targets of shape (H, 1) each.

All per-group arrays are stacked along a new batch dimension and written as two NPZ files:

  • <prefix>_inputs.npz: coordinates, dynamic, static, future features and H field.

  • <prefix>_targets.npz: subsidence and groundwater targets.

Parameters:
  • df_scaled (pandas.DataFrame) – Pre-processed (typically scaled) dataframe containing all required columns: time, spatial coordinates, static/dynamic/ future features and optional targets.

  • time_col (str) – Name of the column encoding the temporal index (e.g. "year", "date", "t_index"). May be numeric, datetime, or string.

  • time_col_num (str or None) – Optional numeric time column used as a tie-breaker when multiple rows share the same time_col value. If provided and present in a group, the last row sorted by this column is selected for that time.

  • lon_col (str) – Name of the longitude (or x-coordinate) column.

  • lat_col (str) – Name of the latitude (or y-coordinate) column.

  • time_steps (int) – Length of the history window (number of time steps in the past). Must be strictly positive.

  • train_end_time (object, optional) – Effective end of the training period. If None, it is inferred as the second-to-last unique value in df_scaled[time_col] (after sorting).

  • forecast_start_time (object, optional) – First time step of the forecast horizon. If None, it is inferred as the first unique time strictly greater than train_end_time.

  • forecast_horizon (int, optional) – Number of future time steps to include. If None, a default horizon of 1 is used and clipped to the maximum number of available future time points.

  • subs_col (str, optional) – Name of the subsidence target column. If None or missing from a group, subsidence targets are filled with NaN.

  • gwl_col (str, optional) – Name of the groundwater-level target column. If None or missing from a group, groundwater targets are filled with NaN.

  • h_field_col (str, optional) – Name of the hydraulic-head field column used as an additional horizon-level input (H_field). If None or missing, a zero field is used.

  • static_features (list of str, optional) – Names of static (time-invariant) feature columns. Any names not present in the dataframe are silently ignored.

  • dynamic_features (list of str, optional) – Names of dynamic (history) feature columns used to build the (time_steps, n_dynamic) sequence. Missing columns are ignored.

  • future_features (list of str, optional) – Names of future covariate columns used to build the history+future or future-only sequence, depending on mode. Missing columns are ignored.

  • group_id_cols (list of str, optional) – Columns used to define spatial (or logical) groups, typically something like ["lon", "lat"] or a station identifier. If None or empty, the entire dataframe is treated as a single global group.

  • mode (str, optional) – Controls how future features are constructed. If the lower-cased value starts with "tft" (e.g. "tft_like"), future features are built on top of both history and future rows. Otherwise, only the forecast horizon rows are used.

  • model_name (str, optional) – Optional model identifier used only in logging messages.

  • artifacts_dir (str, optional) – Directory where NPZ files are written. If None or empty, the current working directory is used.

  • prefix (str, default "future") – Prefix for the output NPZ filenames: "<prefix>_inputs.npz" and "<prefix>_targets.npz".

  • future_mode ({'auto', 'pure-inference', 'pure-data-driven'}, default 'auto') –

    Strategy used to construct the future (forecast) portion of the sequences.

    • 'pure-data-driven': Use only time points that actually exist in df_scaled strictly after the history window. All future time indices must be present in the data; otherwise a ValueError is raised. This corresponds to the original, strictly data-driven behaviour.

    • 'pure-inference': Always synthesize future time points from the last history time, using the median positive time step (or 1.0 as a fallback). Future inputs are built by re-using the last available history row (for future_features, H_field, etc.), and future targets (e.g. subsidence, GWL) are filled with NaN since the true future is unknown. This mode does not require any rows beyond train_end_time.

    • 'auto': Try data-driven mode first. If there are enough actual future time points after train_end_time to cover the requested forecast_horizon, behave like 'pure-data-driven'. If not, automatically fall back to the synthetic 'pure-inference' behaviour described above and emit an informational log message via vlog.

  • verbose (int, default 1) – Verbosity level forwarded to geoprior.utils.vlog(). A value >= 3 provides detailed progress logs (temporal inference, per-group status, dropped groups, etc.).

  • logger (logging.Logger or callable, optional) – Optional logger or logging function used by geoprior.utils.vlog(). If None, messages are printed to standard output.

  • **kws – Reserved for future extensions. Currently ignored.

  • normalize_coords (bool)

  • coord_scaler (Any | None)

  • stop_check (Callable[[], bool])

  • progress_hook (Callable[[float], None] | None)

Returns:

A small dictionary with the absolute paths to the written NPZ files:

{"future_inputs_npz": <path>, "future_targets_npz": <path>}.

Return type:

dict

Raises:

ValueError – If there are not enough history points before train_end_time to satisfy time_steps, if no future points are available after forecast_start_time, or if all groups are dropped due to incomplete history/horizon windows.

Notes

Groups that do not contain all required history and future times are silently dropped, but the number of dropped groups is reported via geoprior.utils.vlog() when verbose > 0.
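Two of the rules above can be sketched in NumPy: the default time inference (train_end_time is the second-to-last unique time, forecast_start_time is the first time after it, and the default horizon is 1) and the 'pure-inference' time grid (median positive step, 1.0 as a fallback). infer_time_defaults and synthesize_future_times are hypothetical helpers. The second one assumes numeric times.

```python
import numpy as np


def infer_time_defaults(times):
    """Hypothetical sketch of the default train/forecast time inference."""
    uniq = np.unique(np.asarray(times))   # sorted unique time values
    if uniq.size < 2:
        raise ValueError("need at least two distinct time values")
    train_end = uniq[-2]                  # second-to-last unique time
    future = uniq[uniq > train_end]
    forecast_start = future[0]            # first time after train_end
    horizon = min(1, future.size)         # default 1, clipped to availability
    return train_end, forecast_start, horizon


def synthesize_future_times(history_times, horizon):
    """Hypothetical sketch of the 'pure-inference' future time grid."""
    t = np.asarray(history_times, dtype=float)
    steps = np.diff(np.sort(t))
    steps = steps[steps > 0]
    dt = float(np.median(steps)) if steps.size else 1.0
    return t.max() + dt * np.arange(1, horizon + 1)
```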

Examples

>>> from geoprior.utils import (
...     build_future_sequences_npz,
... )
>>> result = build_future_sequences_npz(
...     df_scaled=df_scaled,
...     time_col="year",
...     time_col_num="t_index",
...     lon_col="lon",
...     lat_col="lat",
...     time_steps=5,
...     # Let the function infer times/horizon:
...     train_end_time=None,
...     forecast_start_time=None,
...     forecast_horizon=None,
...     subs_col="subsidence",
...     gwl_col="gwl",
...     h_field_col="H_field",
...     static_features=["lithology_class"],
...     dynamic_features=["rainfall_mm", "GWL_depth_bgs_z"],
...     future_features=["normalized_urban_load_proxy"],
...     group_id_cols=["lon", "lat"],
...     mode="tft_like",
...     model_name="GeoPriorSubsNet",
...     artifacts_dir="results/zhongshan/future_npz",
...     prefix="zhongshan_future",
...     verbose=2,
... )
>>> result["future_inputs_npz"]
'results/zhongshan/future_npz/zhongshan_future_inputs.npz'
>>> result["future_targets_npz"]
'results/zhongshan/future_npz/zhongshan_future_targets.npz'

geoprior.utils.resolve_n_jobs(n_jobs)
Parameters:

n_jobs (int)

Return type:

int

geoprior.utils.threads_per_job(*, n_jobs, threads=0, reserve=1)
Return type:

int

geoprior.utils.apply_tf_threading(*, intra, inter)
Return type:

None

geoprior.utils.apply_thread_env(env, *, n_jobs, threads=0, reserve=1)
Return type:

dict[str, str]

geoprior.utils.resolve_device(device, *, env=None)
Return type:

str

geoprior.utils.resolve_gpu_ids(gpu_ids, *, env=None)
Return type:

list[str]

geoprior.utils.pick_gpu_id(idx, gpu_ids)
Return type:

str | None

geoprior.utils.apply_gpu_env(env, *, gpu_id, allow_growth=True)
Return type:

dict[str, str]

class geoprior.utils.ArtifactRecord(path, kind, payload, stage=None, city=None, model=None, meta=<factory>)

Bases: object

Lightweight normalized artifact container.

Parameters:
  • path (pathlib.Path) – Artifact path.

  • kind (str) – Inferred or explicit artifact kind.

  • payload (dict[str, Any]) – Loaded JSON payload.

  • stage (str or None) – Stage if available.

  • city (str or None) – City if available.

  • model (str or None) – Model if available.

  • meta (dict[str, Any]) – Extra extracted metadata.

path: Path
kind: str
payload: dict[str, Any]
stage: str | None
city: str | None
model: str | None
meta: dict[str, Any]
__init__(path, kind, payload, stage=None, city=None, model=None, meta=<factory>)
Return type:

None
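The meta=<factory> default in the signature is how Python renders a dataclass field declared with field(default_factory=...). Assuming ArtifactRecord is a standard dataclass (the field list and that default suggest so, but this is an inference), an equivalent standalone declaration looks like this. ArtifactRecordSketch is a hypothetical stand-in, not the library class, and the kind string used below is illustrative.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional


@dataclass
class ArtifactRecordSketch:
    """Stand-in with the same fields as the documented ArtifactRecord."""

    path: Path
    kind: str
    payload: Dict[str, Any]
    stage: Optional[str] = None
    city: Optional[str] = None
    model: Optional[str] = None
    # A mutable default must come from a factory so that every instance
    # gets its own dict; the signature renders this as meta=<factory>.
    meta: Dict[str, Any] = field(default_factory=dict)
```

Because of the factory, mutating record.meta on one instance never leaks into another.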

geoprior.utils.artifact_brief(record)

Return a compact artifact header summary.

Parameters:

record (ArtifactRecord)

Return type:

dict[str, Any]

geoprior.utils.as_path(path)

Return path as resolved Path.

Parameters:

path (str | Path)

Return type:

Path

geoprior.utils.bool_checks_frame(mapping, *, section=None)

Convert boolean checks into a tidy DataFrame.
Return type:

DataFrame

geoprior.utils.clone_artifact(template, *, overrides=None)

Clone a template payload and apply overrides.

This is useful in Sphinx-Gallery examples that need a realistic artifact with a few controlled changes.
Return type:

dict[str, Any]

geoprior.utils.deep_update(base, updates)

Recursively update base with updates.

Returns a new dictionary.
Return type:

dict[str, Any]
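The recursive merge that deep_update describes can be sketched as follows, under the assumption that nested mappings are merged key by key while any other value (lists included) is replaced outright. clone_artifact can plausibly be read as this merge applied to a copy of the template. deep_update_sketch is a hypothetical name, not the library's code.

```python
import copy
from typing import Any, Dict, Mapping


def deep_update_sketch(
    base: Mapping[str, Any], updates: Mapping[str, Any]
) -> Dict[str, Any]:
    """Recursively merge ``updates`` into a deep copy of ``base``."""
    merged = copy.deepcopy(dict(base))  # never mutate the caller's base
    for key, value in updates.items():
        current = merged.get(key)
        if isinstance(current, Mapping) and isinstance(value, Mapping):
            # Both sides are mappings: descend and merge key by key.
            merged[key] = deep_update_sketch(current, value)
        else:
            # Scalars, lists and new keys are replaced wholesale.
            merged[key] = copy.deepcopy(value)
    return merged
```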

geoprior.utils.ensure_parent_dir(path)

Create parent directory for path.

Parameters:

path (str | Path)

Return type:

Path

geoprior.utils.flatten_dict(mapping, *, parent_key='', sep='.')

Flatten nested dictionaries.

Non-dict values are kept as they are. Lists and arrays are not expanded.
Return type:

dict[str, Any]
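The flattening rule can be sketched as below, assuming keys are joined with sep and only dict-like values are recursed into, which matches the note that lists and arrays are left intact. flatten_dict_sketch is a hypothetical stand-in.

```python
from typing import Any, Dict, Mapping


def flatten_dict_sketch(
    mapping: Mapping[str, Any], *, parent_key: str = "", sep: str = "."
) -> Dict[str, Any]:
    flat: Dict[str, Any] = {}
    for key, value in mapping.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, Mapping):
            # Recurse into nested dictionaries only.
            flat.update(flatten_dict_sketch(value, parent_key=full_key, sep=sep))
        else:
            # Scalars, lists and arrays are stored unchanged.
            flat[full_key] = value
    return flat
```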

geoprior.utils.infer_artifact_kind(path, payload=None)

Infer artifact kind from file name and keys.

The rules are intentionally simple and stable. Artifact-specific readers can still override the inferred kind if needed.
Return type:

str

geoprior.utils.is_number(value)

Return True for numeric scalars, whether finite or non-finite (NaN, Inf).

Parameters:

value (Any)

Return type:

bool

geoprior.utils.json_ready(value)

Convert nested values into JSON-safe objects.

Notes

  • NaN and Inf are converted to None.

  • NumPy scalars are converted to Python scalars.

  • Arrays become lists.

Parameters:

value (Any)

Return type:

Any
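The three conversion rules listed in the Notes can be sketched without dependencies. To avoid importing NumPy, the sketch detects NumPy objects by duck typing on tolist(), which both NumPy arrays and NumPy scalars provide. That detection strategy, treating tuples like lists, and the name json_ready_sketch are assumptions rather than the library's implementation.

```python
import math
from typing import Any


def json_ready_sketch(value: Any) -> Any:
    """Recursively convert a value into something json.dumps accepts."""
    if isinstance(value, dict):
        return {key: json_ready_sketch(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [json_ready_sketch(item) for item in value]
    if hasattr(value, "tolist"):
        # NumPy arrays and NumPy scalars both expose tolist(); arrays become
        # (nested) lists and scalars become plain Python scalars.
        return json_ready_sketch(value.tolist())
    if isinstance(value, float) and not math.isfinite(value):
        # NaN, +Inf and -Inf have no strict-JSON representation.
        return None
    return value
```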

geoprior.utils.load_artifact(path, *, kind=None)

Load a JSON artifact into ArtifactRecord.
Return type:

ArtifactRecord

geoprior.utils.metrics_frame(mapping, *, section=None, sort=True)

Convert scalar metrics into a tidy DataFrame.
Return type:

DataFrame

geoprior.utils.nested_get(mapping, *keys, default=None)

Safely traverse nested dictionaries.

Examples

>>> nested_get(d, "config", "scaling_kwargs")
Return type:

Any
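The traversal that nested_get performs can be sketched as follows, assuming that a missing key or a non-mapping intermediate value short-circuits to default. nested_get_sketch is hypothetical.

```python
from typing import Any, Mapping


def nested_get_sketch(mapping: Mapping[str, Any], *keys: str, default: Any = None) -> Any:
    current: Any = mapping
    for key in keys:
        # Stop as soon as the path cannot be followed any further.
        if not isinstance(current, Mapping) or key not in current:
            return default
        current = current[key]
    return current
```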

geoprior.utils.numeric_items(mapping, *, drop_bools=True)

Extract numeric scalar items from a mapping.
Return type:

dict[str, float]
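Taken together, numeric_items and is_number suggest a filter like the one below. This is a sketch under stated assumptions. Bool values are dropped when drop_bools=True, which needs an explicit check because Python treats bool as a subclass of int. NaN and Inf are kept because is_number accepts non-finite scalars. numeric_items_sketch is a hypothetical name.

```python
import numbers
from typing import Any, Dict, Mapping


def numeric_items_sketch(
    mapping: Mapping[str, Any], *, drop_bools: bool = True
) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for key, value in mapping.items():
        if isinstance(value, bool):
            # bool is an int subclass, so handle it before the Real check.
            if drop_bools:
                continue
        elif not isinstance(value, numbers.Real):
            # Strings, lists, None and other non-numeric values are skipped.
            continue
        out[key] = float(value)
    return out
```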

geoprior.utils.plot_boolean_checks(ax, checks=None, *, title='Checks', show_grid=True, grid_kws=None, ax_obj=None, error='ignore', **plot_kws)

Plot boolean pass/fail checks as a bar view.

Keeps the older (ax, checks) call pattern while allowing (checks, ax_obj=ax) for newer code.
Return type:

Axes

geoprior.utils.plot_metric_bars(ax, metrics=None, *, title='Metrics', top_n=None, sort_by_value=False, absolute=False, annotate=True, xlabel='value', show_grid=True, grid_kws=None, annotate_kws=None, ax_obj=None, error='ignore', **plot_kws)

Plot a compact horizontal metric bar chart.

The legacy calling style plot_metric_bars(ax, metrics, ...) is preserved. A newer style plot_metric_bars(metrics, ax_obj=ax, ...) is also accepted for gradual migration.
Return type:

Axes

geoprior.utils.plot_series_map(ax, series_map=None, *, title='Series', xlabel='key', ylabel='value', marker='o', show_grid=True, grid_kws=None, ax_obj=None, error='ignore', **plot_kws)

Plot a string-keyed numeric mapping as a line.

Keeps the older (ax, series_map) form while also accepting (series_map, ax_obj=ax).
Return type:

Axes
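plot_boolean_checks, plot_metric_bars and plot_series_map all accept both a legacy (ax, data) and a newer (data, ax_obj=ax) call pattern. One way such dual dispatch can work is to check whether the first positional argument is a mapping. The helper below is a hypothetical illustration of that idea, not the library's actual dispatch logic, and it leaves Axes creation to the caller.

```python
from typing import Any, Mapping, Optional, Tuple


def resolve_ax_and_data(
    ax: Any, data: Optional[Mapping[str, Any]] = None, ax_obj: Any = None
) -> Tuple[Any, Mapping[str, Any]]:
    """Accept both (ax, data) and (data, ax_obj=ax) call patterns."""
    if isinstance(ax, Mapping):
        # Newer style: the first positional argument is the data itself.
        if data is not None:
            raise TypeError("data was passed twice")
        # ax_obj may be None; the caller would then create or fetch an Axes.
        return ax_obj, ax
    # Legacy style: the first positional argument is the Axes.
    if data is None:
        raise TypeError("no data to plot")
    return ax, data
```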

geoprior.utils.read_json(path)

Read a JSON file into a dictionary.

Parameters:

path (str | Path)

Return type:

dict[str, Any]

geoprior.utils.write_json(payload, path, *, indent=2, sort_keys=False)

Write payload as UTF-8 JSON.
Return type:

Path
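A round trip through the JSON helpers can be sketched with the standard library alone. write_json_sketch and read_json_sketch are hypothetical stand-ins. They assume that write_json creates missing parent directories (as ensure_parent_dir suggests) and returns the resolved path (as as_path describes). The real helpers may differ in details such as ensure_ascii handling.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, Union

PathLike = Union[str, Path]


def write_json_sketch(
    payload: Dict[str, Any], path: PathLike, *, indent: int = 2, sort_keys: bool = False
) -> Path:
    target = Path(path).resolve()
    # Create missing parent folders first (the ensure_parent_dir step).
    target.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=indent, sort_keys=sort_keys)
    target.write_text(text, encoding="utf-8")
    return target


def read_json_sketch(path: PathLike) -> Dict[str, Any]:
    return json.loads(Path(path).read_text(encoding="utf-8"))


def roundtrip_demo() -> Dict[str, Any]:
    """Write and re-read a small payload inside a throwaway directory."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "stage1" / "audit.json"  # "stage1" does not exist yet
        written = write_json_sketch({"city": "demo_city", "n_sequences": 1200}, target)
        return read_json_sketch(written)
```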

geoprior.utils.build_stage1_feature_split(*, dynamic_features, static_features, future_features, scaled_ml_numeric_cols)

Build the Stage-1 feature split section.

Parameters:
  • dynamic_features (list of str) – Dynamic input feature names.

  • static_features (list of str) – Static input feature names.

  • future_features (list of str) – Future-known feature names.

  • scaled_ml_numeric_cols (list of str) – Features scaled by the main scaler.

Returns:

Stage-1 style feature split mapping.

Return type:

dict

geoprior.utils.default_stage1_audit_payload(*, city='demo_city', model='GeoPriorSubsNet', normalize_coords=True, keep_coords_raw=None, shift_raw_coords=True, coords_in_degrees=False, coord_mode='degrees', coord_epsg_used=32649, coord_ranges=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, scaler_path='artifacts/main_scaler.joblib', n_sequences=1200, forecast_horizon=3, time_steps=5)

Build a realistic default Stage-1 audit payload.

The payload is template-based. It is not meant to reproduce the full preprocessing mathematics. Instead, it creates a stable and inspectable audit artifact with the same broad structure as the real Stage-1 audit file.

Parameters:
  • city (str)

  • model (str)

  • normalize_coords (bool)

  • keep_coords_raw (bool | None)

  • shift_raw_coords (bool)

  • coords_in_degrees (bool)

  • coord_mode (str)

  • coord_epsg_used (int | None)

  • coord_ranges (dict[str, float] | None)

  • dynamic_features (list[str] | None)

  • static_features (list[str] | None)

  • future_features (list[str] | None)

  • scaled_ml_numeric_cols (list[str] | None)

  • scaler_path (str)

  • n_sequences (int)

  • forecast_horizon (int)

  • time_steps (int)

Return type:

dict[str, Any]

geoprior.utils.generate_stage1_audit(*, output_path=None, template=None, overrides=None, **kwargs)

Generate a Stage-1 audit payload or file.

Parameters:
  • output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.

  • template (mapping, ArtifactRecord, or path, optional) – Real or synthetic Stage-1 audit template used as the generation base.

  • overrides (dict, optional) – Nested overrides applied after template/default payload creation.

  • **kwargs (dict) – Parameters forwarded to default_stage1_audit_payload when no template is given.

Return type:

dict[str, Any] | Path
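generate_stage1_audit and the other generate_* helpers that take output_path (Stage-2 handshake, training summary, evaluation diagnostics, eval physics) document the same flow: start from a template or from the default payload builder, apply nested overrides, then either return the payload or write it to output_path. The sketch below captures that flow under two simplifying assumptions: the template is a plain mapping (the real functions also accept an ArtifactRecord or a path), and the default builder is passed in explicitly. toy_stage1_defaults and generate_artifact_sketch are hypothetical.

```python
import copy
import json
from pathlib import Path
from typing import Any, Callable, Dict, Mapping, Optional, Union


def toy_stage1_defaults(*, city: str = "demo_city", time_steps: int = 5) -> Dict[str, Any]:
    """Hypothetical stand-in for a default_*_payload builder."""
    return {"city": city, "config": {"time_steps": time_steps, "forecast_horizon": 3}}


def _merge(base: Dict[str, Any], updates: Mapping[str, Any]) -> Dict[str, Any]:
    """In-place recursive merge of ``updates`` into ``base``."""
    for key, value in updates.items():
        if isinstance(base.get(key), dict) and isinstance(value, Mapping):
            _merge(base[key], value)
        else:
            base[key] = copy.deepcopy(value)
    return base


def generate_artifact_sketch(
    default_builder: Callable[..., Dict[str, Any]],
    *,
    output_path: Optional[Union[str, Path]] = None,
    template: Optional[Mapping[str, Any]] = None,
    overrides: Optional[Mapping[str, Any]] = None,
    **kwargs: Any,
) -> Union[Dict[str, Any], Path]:
    # 1. Copy the template, or build defaults (kwargs only matter here).
    if template is not None:
        payload = copy.deepcopy(dict(template))
    else:
        payload = default_builder(**kwargs)
    # 2. Apply nested overrides on top.
    if overrides:
        _merge(payload, overrides)
    # 3. Return the payload, or write it and return the file path.
    if output_path is None:
        return payload
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return path
```

Because the template is deep-copied before overrides are applied, the same template can be reused to produce several controlled variants.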

geoprior.utils.inspect_stage1_audit(audit, *, output_dir=None, stem='stage1_audit', save_figures=True)

Inspect a Stage-1 audit and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

geoprior.utils.load_stage1_audit(path)

Load a Stage-1 audit artifact.

Raises:

ValueError – If the artifact does not look like a Stage-1 audit payload.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_stage1_boolean_summary(audit, *, ax=None, title='Stage-1 audit checks', error='ignore', **plot_kws)

Plot semantic pass/fail checks.
Return type:

Axes

geoprior.utils.plot_stage1_coord_ranges(audit, *, ax=None, title='Stage-1 coord ranges', ylabel='range', show_grid=True, grid_kws=None, annotate=True, annotate_kws=None, error='ignore', **plot_kws)

Plot chain-rule coordinate ranges.
Return type:

Axes

geoprior.utils.plot_stage1_feature_split(audit, *, ax=None, title='Stage-1 feature split', xlabel='feature group', ylabel='features', show_grid=True, grid_kws=None, legend_kws=None, error='ignore', **plot_kws)

Plot feature bucket counts by feature group.
Return type:

Axes

geoprior.utils.plot_stage1_target_stats(audit, *, stat='mean', ax=None, title=None)

Plot target summary statistics.
Return type:

Axes

geoprior.utils.plot_stage1_variable_stats(audit, *, section='physics_df_stats', stat='mean', ax=None, title=None, error='ignore', **plot_kws)

Plot one statistic across variables.
Return type:

Axes

geoprior.utils.stage1_coord_ranges_frame(audit)

Return coord ranges as a tidy frame.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.stage1_feature_split_frame(audit)

Explode the feature split into tidy rows.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.stage1_stats_frame(audit, *, section='physics_df_stats')

Return a tidy frame for nested variable stats.
Return type:

DataFrame

geoprior.utils.summarize_stage1_audit(audit)

Build a compact semantic summary for inspection.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_stage2_handshake_payload(*, city='demo_city', model='GeoPriorSubsNet', time_steps=5, forecast_horizon=3, mode='tft_like', n_train=1200, n_val=320, dynamic_dim=5, future_dim=1, static_dim=12, coords_normalized=True, coord_ranges=None, coords_in_degrees=False, coord_epsg_used=32649, time_units='year', gwl_dyn_name='GWL_depth_bgs_m__si', gwl_dyn_index=0, subs_dyn_index=1, z_surf_static_index=11, use_head_proxy=True, q_kind='per_volume', drainage_mode='double')

Build a realistic default Stage-2 handshake payload.

The payload is template-based. It is not meant to rerun Stage-2 logic. Instead, it creates a stable and inspectable audit artifact with the same broad structure as the real Stage-2 handshake file.

Parameters:
  • city (str)

  • model (str)

  • time_steps (int)

  • forecast_horizon (int)

  • mode (str)

  • n_train (int)

  • n_val (int)

  • dynamic_dim (int)

  • future_dim (int)

  • static_dim (int)

  • coords_normalized (bool)

  • coord_ranges (dict[str, float] | None)

  • coords_in_degrees (bool)

  • coord_epsg_used (int | None)

  • time_units (str)

  • gwl_dyn_name (str)

  • gwl_dyn_index (int)

  • subs_dyn_index (int)

  • z_surf_static_index (int)

  • use_head_proxy (bool)

  • q_kind (str)

  • drainage_mode (str)

Return type:

dict[str, Any]

geoprior.utils.generate_stage2_handshake(*, output_path=None, template=None, overrides=None, **kwargs)

Generate a Stage-2 handshake payload or file.

Parameters:
  • output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.

  • template (mapping, ArtifactRecord, or path, optional) – Real or synthetic Stage-2 handshake template used as the generation base.

  • overrides (dict, optional) – Nested overrides applied after template/default payload creation.

  • **kwargs (dict) – Parameters forwarded to default_stage2_handshake_payload when no template is given.

Return type:

dict[str, Any] | Path

geoprior.utils.inspect_stage2_handshake(audit, *, output_dir=None, stem='stage2_handshake', save_figures=True)

Inspect a Stage-2 handshake and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

geoprior.utils.load_stage2_handshake(path)

Load a Stage-2 handshake artifact.

Raises:

ValueError – If the artifact does not look like a Stage-2 handshake payload.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_stage2_boolean_summary(audit, *, ax=None, title='Stage-2 handshake checks', error='ignore', **plot_kws)

Plot semantic pass/fail checks.
Return type:

Axes

geoprior.utils.plot_stage2_coord_range_errors(audit, *, ax=None, title='Stage-2 coord range relative errors', ylabel='relative error', show_grid=True, grid_kws=None, annotate=True, annotate_kws=None, error='ignore', **plot_kws)

Plot coord range relative errors.
Return type:

Axes

geoprior.utils.plot_stage2_coord_stats(audit, *, section='coord_stats_norm', stat='mean', ax=None, title=None, error='ignore', **plot_kws)

Plot one coord statistic across axes.
Return type:

Axes

geoprior.utils.plot_stage2_finite_ratios(audit, *, ax=None, title='Stage-2 finite ratios', error='ignore', **plot_kws)

Plot finite-ratio metrics.
Return type:

Axes

geoprior.utils.plot_stage2_sample_sizes(audit, *, ax=None, title='Stage-2 sample sizes', error='ignore', **plot_kws)

Plot training and validation counts.
Return type:

Axes

geoprior.utils.plot_stage2_scaling_summary(audit, *, ax=None, title='Stage-2 scaling summary', top_n=12, error='ignore', **plot_kws)

Plot numeric items from sk_summary.
Return type:

Axes

geoprior.utils.stage2_coord_range_frame(audit)

Return coord range spans and relative errors.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.stage2_coord_stats_frame(audit, *, section='coord_stats_norm')

Return a tidy frame for coord stat blocks.
Return type:

DataFrame

geoprior.utils.stage2_finite_frame(audit)

Return finite-ratio checks as a tidy frame.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame
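A finite ratio is most naturally read as the fraction of entries that are finite. Assuming that is what the Stage-2 handshake records for its arrays (this excerpt does not define it), the quantity can be computed as follows. finite_ratio is a hypothetical helper, and returning NaN for an empty input is a choice made for the sketch.

```python
import math
from typing import Iterable


def finite_ratio(values: Iterable[float]) -> float:
    """Share of entries that are finite (neither NaN nor +/-Inf)."""
    items = list(values)
    if not items:
        # An empty input has no meaningful ratio; report NaN.
        return float("nan")
    return sum(math.isfinite(v) for v in items) / len(items)
```

A ratio of 1.0 means every entry is usable; anything lower points at NaN or Inf values leaking into the model inputs.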

geoprior.utils.stage2_layout_frame(audit)

Return expected vs observed layout rows.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.stage2_scaling_frame(audit)

Return a tidy frame for the compact scaling summary.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.summarize_stage2_handshake(audit)

Build a compact semantic summary for inspection.

Parameters:

audit (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_training_summary_payload(*, city='demo_city', model='GeoPriorSubsNet', horizon=3, best_epoch=17, timestamp='20260222-211635', optimizer='Adam', learning_rate=0.001, time_steps=5, pde_mode='on', offset_mode='mul', quantiles=None, attention_levels=None, coords_normalized=True, coord_ranges=None, run_dir='results/demo_run/train_20260222-211635')

Build a realistic default training-summary payload.

The payload is template-based. It is not meant to reproduce the full training loop. Instead, it creates a stable and inspectable summary artifact with the same broad structure as the real training summary.
Return type:

dict[str, Any]

geoprior.utils.generate_training_summary(*, output_path=None, template=None, overrides=None, **kwargs)

Generate a training-summary payload or file.

Parameters:
  • output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.

  • template (mapping, ArtifactRecord, or path, optional) – Real or synthetic training-summary template used as the generation base.

  • overrides (dict, optional) – Nested overrides applied after template/default payload creation.

  • **kwargs (dict) – Parameters forwarded to default_training_summary_payload when no template is given.

Return type:

dict[str, Any] | Path

geoprior.utils.inspect_training_summary(summary, *, output_dir=None, stem='training_summary', save_figures=True)

Inspect a training summary and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

geoprior.utils.load_training_summary(path)

Load a training-summary artifact.

Raises:

ValueError – If the artifact does not look like a training summary payload.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_training_best_metrics(summary, *, split='val', keys=None, ax=None, title=None, error='ignore', **plot_kws)

Plot selected metrics from metrics_at_best.
Return type:

Axes

geoprior.utils.plot_training_boolean_summary(summary, *, ax=None, title='Training summary checks', error='ignore', **plot_kws)

Plot semantic pass/fail checks.
Return type:

Axes

geoprior.utils.plot_training_final_metrics(summary, *, split='val', keys=None, ax=None, title=None, error='ignore', **plot_kws)

Plot selected metrics from final_epoch_metrics.
Return type:

Axes

geoprior.utils.plot_training_loss_family(summary, *, section='metrics_at_best', split='val', ax=None, title=None, error='ignore', **plot_kws)

Plot the loss-family subset for one metric section.
Return type:

Axes

geoprior.utils.plot_training_metric_deltas(summary, *, split='val', keys=None, ax=None, title=None, xlabel='delta', show_grid=True, grid_kws=None, annotate=False, annotate_kws=None, error='ignore', **plot_kws)

Plot the final-minus-best delta for each aligned metric.
Return type:

Axes
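Reading "aligned metrics" as the metric names present in both final_epoch_metrics and metrics_at_best (an assumption), the plotted quantity is a simple per-key difference. For a loss, a positive delta means the final epoch ended worse than the best epoch. metric_deltas is a hypothetical helper.

```python
from typing import Dict, Mapping


def metric_deltas(final: Mapping[str, float], best: Mapping[str, float]) -> Dict[str, float]:
    """final - best for the metric names present in both mappings."""
    shared = sorted(set(final) & set(best))  # keys missing on either side are skipped
    return {name: final[name] - best[name] for name in shared}
```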

geoprior.utils.training_compile_frame(summary)

Return a tidy frame for compile settings.

Parameters:

summary (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.training_env_frame(summary)

Return a tidy frame for environment info.

Parameters:

summary (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.training_hp_frame(summary)

Return a tidy frame for hp/init settings.

Parameters:

summary (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.training_metrics_frame(summary, *, section='metrics_at_best', split='all')

Return a tidy frame for train/validation metrics.
Return type:

DataFrame

geoprior.utils.training_paths_frame(summary)

Return a tidy frame for output paths.

Parameters:

summary (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.summarize_training_summary(summary)

Build a compact semantic summary for inspection.

Parameters:

summary (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_eval_diagnostics_payload(*, years=None, per_horizon_mae=None, per_horizon_mse=None, per_horizon_rmse=None, per_horizon_r2=None, coverage80=None, sharpness80=None, pss=None)

Build a realistic default eval-diagnostics payload.

The payload is template-based. It is not meant to reproduce the full evaluation pipeline. Instead, it creates a stable and inspectable diagnostics artifact with the same broad structure as the real compact eval_diagnostics JSON.
Return type:

dict[str, Any]
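The coverage80 and sharpness80 fields are not defined in this excerpt. Under the usual definitions for an 80% prediction interval, coverage is the share of observations that fall inside [q_low, q_high] and sharpness is the mean interval width. Which quantiles bound the interval (q10/q90 is the typical choice) is an assumption here, as is the helper name coverage_and_sharpness.

```python
from typing import Sequence, Tuple


def coverage_and_sharpness(
    y_true: Sequence[float], q_low: Sequence[float], q_high: Sequence[float]
) -> Tuple[float, float]:
    """Empirical interval coverage and mean interval width."""
    if not (len(y_true) == len(q_low) == len(q_high)) or not y_true:
        raise ValueError("inputs must be non-empty and equally long")
    # Count observations that land inside their own interval (bounds inclusive).
    inside = sum(lo <= y <= hi for y, lo, hi in zip(y_true, q_low, q_high))
    width = sum(hi - lo for lo, hi in zip(q_low, q_high))
    n = len(y_true)
    return inside / n, width / n
```

A well-calibrated 80% interval yields coverage close to 0.8; among models with similar coverage, a smaller sharpness value means tighter intervals.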

geoprior.utils.eval_overall_frame(diagnostics)

Return a compact frame for the __overall__ block.

Parameters:

diagnostics (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_per_horizon_frame(diagnostics)

Return a tidy per-horizon metrics frame.

Parameters:

diagnostics (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_years_frame(diagnostics)

Return one row per year block.

Parameters:

diagnostics (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.generate_eval_diagnostics(*, output_path=None, template=None, overrides=None, **kwargs)

Generate an eval-diagnostics payload or file.

Parameters:
  • output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.

  • template (mapping, ArtifactRecord, or path, optional) – Real or synthetic diagnostics template used as the generation base.

  • overrides (dict, optional) – Nested overrides applied after template/default payload creation.

  • **kwargs (dict) – Parameters forwarded to default_eval_diagnostics_payload when no template is given.

Return type:

dict[str, Any] | Path

geoprior.utils.inspect_eval_diagnostics(diagnostics, *, output_dir=None, stem='eval_diagnostics', save_figures=True)

Inspect eval diagnostics and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

geoprior.utils.load_eval_diagnostics(path)

Load an eval-diagnostics artifact.

Raises:

ValueError – If the artifact does not look like a compact evaluation-diagnostics payload.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_eval_boolean_summary(diagnostics, *, ax=None, title='Evaluation diagnostics checks', error='ignore', **plot_kws)

Plot semantic pass/fail checks.
Return type:

Axes

geoprior.utils.plot_eval_overall_metrics(diagnostics, *, keys=None, ax=None, title=None, error='ignore', **plot_kws)

Plot selected top-level metrics from __overall__.
Return type:

Axes

geoprior.utils.plot_eval_per_horizon_metrics(diagnostics, *, metric='rmse', ax=None, title=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)

Plot one per-horizon metric from __overall__.
Return type:

Axes

geoprior.utils.plot_eval_year_metric_trend(diagnostics, *, metric='overall_mae', ax=None, title=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)

Plot one metric across year blocks.
Return type:

Axes

geoprior.utils.summarize_eval_diagnostics(diagnostics)

Build a compact semantic summary for inspection.

Parameters:

diagnostics (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_eval_physics_payload(*, timestamp='20260222-215049', city='demo_city', model='GeoPriorSubsNet', quantiles=None, horizon=3, batch_size=32, subs_unit='mm', time_units='year')

Build a realistic default eval-physics payload.

The payload is template-based. It is not meant to reproduce the Stage-2 evaluation path. Instead, it creates a stable and inspectable artifact with the same broad structure as the interpretable physics evaluation JSON.
Return type:

dict[str, Any]

geoprior.utils.eval_physics_calibration_frame(payload)

Return a tidy frame for top-level calibration scalars.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_calibration_per_horizon_frame(payload)

Return a tidy per-horizon calibration frame.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_censor_frame(payload)

Return a tidy frame for censor-aware metrics.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_metrics_frame(payload)

Return a tidy frame for metrics_evaluate.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_per_horizon_frame(payload)

Return a tidy frame for exported per-horizon metrics.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_point_metrics_frame(payload)

Return a tidy frame for point metrics.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.eval_physics_units_frame(payload)

Return a tidy frame for units metadata.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.generate_eval_physics(*, output_path=None, template=None, overrides=None, **kwargs)

Generate an eval-physics payload or file.

Parameters:
  • output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.

  • template (mapping, ArtifactRecord, or path, optional) – Real or synthetic eval-physics template used as the generation base.

  • overrides (dict, optional) – Nested overrides applied after template/default payload creation.

  • **kwargs (dict) – Parameters forwarded to default_eval_physics_payload when no template is given.

Return type:

dict[str, Any] | Path

geoprior.utils.inspect_eval_physics(payload, *, output_dir=None, stem='eval_physics', save_figures=True)

Inspect an eval-physics artifact and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

geoprior.utils.load_eval_physics(path)

Load an eval-physics artifact.

Raises:

ValueError – If the artifact does not look like an eval-physics payload.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_eval_physics_boolean_summary(payload, *, ax=None, title='Eval physics checks', error='ignore', **plot_kws)

Plot semantic pass/fail checks.
Return type:

Axes

geoprior.utils.plot_eval_physics_calibration_factors(payload, *, source='top', ax=None, title=None, error='ignore', **plot_kws)

Plot per-horizon calibration factors.

Parameters:
  • source ({'top', 'nested'}, default 'top') – 'top' uses factors_per_horizon. 'nested' uses factors_per_horizon_from_cal_stats.

  • payload (ArtifactRecord | Mapping[str, Any] | str | Path)

  • ax (Axes | None)

  • title (str | None)

  • error (str)

  • plot_kws (Any)

Return type:

Axes

geoprior.utils.plot_eval_physics_epsilons(payload, *, ax=None, title='Eval physics: epsilon diagnostics', error='ignore', **plot_kws)

Plot epsilon-related diagnostics.
Return type:

Axes

geoprior.utils.plot_eval_physics_metrics(payload, *, keys=None, ax=None, title=None, error='ignore', **plot_kws)

Plot selected metrics_evaluate values.
Return type:

Axes

geoprior.utils.plot_eval_physics_per_horizon_metrics(payload, *, metric='mae', ax=None, title=None, error='ignore', **plot_kws)

Plot one exported per-horizon metric map.
Return type:

Axes

geoprior.utils.plot_eval_physics_point_metrics(payload, *, ax=None, title='Eval physics: point metrics', error='ignore', **plot_kws)

Plot point-metric summary.
Return type:

Axes

geoprior.utils.summarize_eval_physics(payload)

Build a compact semantic summary for inspection.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_physics_payload_meta_payload(*, city='demo_city', model_name='GeoPriorSubsNet', split='TestSet', created_utc='2026-02-22T13:50:48Z', saved_utc=None, pde_modes_active=None, time_units='year', time_coord_units='year', gwl_kind='depth_bgs', gwl_sign='down_positive', use_head_proxy=True, kappa_mode='kb', use_effective_h=True, hd_factor=0.6, lambda_offset=1.0, kappa_value=0.046392478, tau_prior_definition='tau_closure_from_learned_fields', tau_prior_human_name='tau_closure', tau_prior_source='model.evaluate_physics() closure head', tau_closure_formula='Hd^2 * Ss / (pi^2 * kappa_b * K)', units=None, payload_metrics=None)

Build a realistic default physics-payload meta payload.

The payload is intentionally lightweight and template-based. It mirrors the structure of the real *.npz.meta.json sidecar without trying to reconstruct the full physics payload archive.

Parameters:
  • city (str)

  • model_name (str)

  • split (str)

  • created_utc (str)

  • saved_utc (str | None)

  • pde_modes_active (list[str] | None)

  • time_units (str)

  • time_coord_units (str)

  • gwl_kind (str)

  • gwl_sign (str)

  • use_head_proxy (bool)

  • kappa_mode (str)

  • use_effective_h (bool)

  • hd_factor (float)

  • lambda_offset (float)

  • kappa_value (float)

  • tau_prior_definition (str)

  • tau_prior_human_name (str)

  • tau_prior_source (str)

  • tau_closure_formula (str)

  • units (dict[str, str] | None)

  • payload_metrics (dict[str, Any] | None)

Return type:

dict[str, Any]
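The default tau_closure_formula string, Hd^2 * Ss / (pi^2 * kappa_b * K), can be evaluated directly. In the sketch below the function and argument names are hypothetical, and how hd_factor or use_effective_h feed into Hd is deliberately not modelled. With Hd in m, Ss in 1/m and K in m/s, and assuming kappa_b is dimensionless, the result is in seconds.

```python
import math


def tau_closure(hd: float, ss: float, kappa_b: float, k: float) -> float:
    """Evaluate Hd^2 * Ss / (pi^2 * kappa_b * K) for positive inputs."""
    if min(hd, ss, kappa_b, k) <= 0:
        raise ValueError("all inputs must be positive")
    return hd**2 * ss / (math.pi**2 * kappa_b * k)
```

The quadratic dependence on Hd means that doubling the drainage thickness quadruples the closure time scale, while doubling K halves it.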

geoprior.utils.generate_physics_payload_meta(path, *, template=None, overrides=None, **kwargs)

Generate and save a physics-payload meta artifact.

Parameters:
  • path (path-like) – Output JSON path.

  • template (mapping, artifact, or path, optional) – Template payload to clone. When omitted, the function starts from default_physics_payload_meta_payload.

  • overrides (dict, optional) – Deep updates applied after loading the template.

  • **kwargs (Any) – Forwarded to the default payload builder when no template is provided.

Return type:

Path

geoprior.utils.inspect_physics_payload_meta(payload)

Return a structured inspection bundle.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.load_physics_payload_meta(path)

Load a physics-payload meta artifact.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.physics_payload_meta_closure_frame(payload)

Return a tidy closure / tau-prior frame.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.physics_payload_meta_identity_frame(payload)

Return a tidy identity / convention frame.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.physics_payload_meta_metrics_frame(payload)

Return compact payload metrics as a tidy frame.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.physics_payload_meta_units_frame(payload)

Return the units block as a tidy frame.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.plot_physics_payload_meta_boolean_summary(payload, *, ax=None, title='Physics payload meta: checks', **plot_kws)[source]

Plot boolean inspection checks.

Parameters:
  • payload

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_physics_payload_meta_core_scalars(payload, *, ax=None, title='Physics payload meta: core scalars', **plot_kws)[source]

Plot selected top-level numeric metadata values.

Parameters:
  • payload

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_physics_payload_meta_payload_metrics(payload, *, ax=None, title='Physics payload meta: payload metrics', **plot_kws)[source]

Plot compact payload-level metrics.

Parameters:
  • payload

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.summarize_physics_payload_meta(payload)[source]

Build a compact semantic summary for inspection.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_scaling_kwargs_payload(*, time_units='year', coords_normalized=True, coords_in_degrees=False, coord_mode='degrees', coord_order=None, coord_ranges=None, bounds=None, dynamic_feature_names=None, future_feature_names=None, static_feature_names=None, gwl_dyn_name='GWL_depth_bgs_m__si', gwl_dyn_index=0, subs_dyn_index=1, z_surf_static_index=11, gwl_kind='depth_bgs', gwl_sign='down_positive', use_head_proxy=True, q_kind='per_volume', mv_prior_mode='calibrate', mv_prior_units='strict')[source]

Build a realistic default scaling-kwargs payload.

The payload is intentionally template-based. It mirrors the resolved scaling_kwargs.json structure written by Stage-2 and preserved in manifests, while staying lightweight enough for documentation examples.

Parameters:
  • time_units (str)

  • coords_normalized (bool)

  • coords_in_degrees (bool)

  • coord_mode (str)

  • coord_order (list[str] | None)

  • coord_ranges (dict[str, float] | None)

  • bounds (dict[str, Any] | None)

  • dynamic_feature_names (list[str] | None)

  • future_feature_names (list[str] | None)

  • static_feature_names (list[str] | None)

  • gwl_dyn_name (str)

  • gwl_dyn_index (int)

  • subs_dyn_index (int)

  • z_surf_static_index (int)

  • gwl_kind (str)

  • gwl_sign (str)

  • use_head_proxy (bool)

  • q_kind (str)

  • mv_prior_mode (str)

  • mv_prior_units (str)

Return type:

dict[str, Any]

geoprior.utils.generate_scaling_kwargs(path, *, template=None, overrides=None, **kwargs)[source]

Generate a scaling-kwargs JSON artifact.

Parameters:
  • path (path-like) – Output path for the JSON file.

  • template (mapping or path-like, optional) – Existing payload used as a template.

  • overrides (dict, optional) – Nested override values applied after the template.

  • **kwargs (Any) – Convenience keyword overrides forwarded into the default payload builder when no explicit template is provided.

Returns:

Written JSON path.

Return type:

pathlib.Path
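The defaults above pair gwl_dyn_name with gwl_dyn_index and subs_dyn_index. Assuming gwl_dyn_index is meant to point at the gwl_dyn_name entry of dynamic_feature_names, a consistency check on a scaling-kwargs mapping before writing the artifact could look like the following. check_dynamic_channels is a hypothetical helper, and the extra feature names used in testing are illustrative only.

```python
from collections.abc import Mapping
from typing import Any


def check_dynamic_channels(sk: Mapping[str, Any]) -> list:
    """Return a list of problems with the dynamic channel indices (empty if consistent)."""
    problems = []
    names = list(sk.get("dynamic_feature_names") or [])
    gwl_idx = sk.get("gwl_dyn_index")
    subs_idx = sk.get("subs_dyn_index")

    # Each index must address an existing dynamic channel.
    for label, idx in (("gwl_dyn_index", gwl_idx), ("subs_dyn_index", subs_idx)):
        if idx is not None and not 0 <= idx < len(names):
            problems.append(f"{label}={idx} is out of range for {len(names)} dynamic features")

    # Groundwater and subsidence should not share a channel.
    if gwl_idx is not None and gwl_idx == subs_idx:
        problems.append("gwl_dyn_index and subs_dyn_index point to the same channel")

    # Assumed convention: the name at gwl_dyn_index should equal gwl_dyn_name.
    gwl_name = sk.get("gwl_dyn_name")
    if gwl_name and gwl_idx is not None and 0 <= gwl_idx < len(names):
        if names[gwl_idx] != gwl_name:
            problems.append(
                f"dynamic_feature_names[{gwl_idx}]={names[gwl_idx]!r} != gwl_dyn_name={gwl_name!r}"
            )
    return problems
```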

geoprior.utils.inspect_scaling_kwargs(payload)[source]

Return a compact multi-view inspection bundle.

Returns:

Dictionary containing summary plus a set of tidy frames.

Return type:

dict

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

geoprior.utils.load_scaling_kwargs(path)[source]

Load a scaling-kwargs artifact.

Returns:

Normalized artifact wrapper.

Return type:

ArtifactRecord

Parameters:

path (str | Path)

geoprior.utils.plot_scaling_kwargs_affine_maps(payload, *, ax=None, title='Scaling affine maps', error='ignore', **plot_kws)[source]

Plot subs/head/H affine map scalars.

Parameters:
  • payload

  • ax

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_scaling_kwargs_boolean_summary(payload, *, ax=None, title='Scaling boolean checks', error='ignore', **plot_kws)[source]

Plot common boolean config flags.

Parameters:
  • payload

  • ax

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_scaling_kwargs_bounds(payload, *, ax=None, title='Bounds overview', top_n=None, error='ignore', **plot_kws)[source]

Plot numeric bounds as a compact bar chart.

Parameters:
  • payload

  • ax

  • title

  • top_n

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_scaling_kwargs_coord_ranges(payload, *, ax=None, title='Coordinate ranges', error='ignore', **plot_kws)[source]

Plot coord_ranges for t/x/y.

Parameters:
  • payload

  • ax

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_scaling_kwargs_feature_group_sizes(payload, *, ax=None, title='Feature group sizes', error='ignore', **plot_kws)[source]

Plot dynamic/future/static feature-group counts.

Parameters:
  • payload

  • ax

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_scaling_kwargs_schedule_scalars(payload, *, ax=None, title='Schedule and runtime scalars', error='ignore', **plot_kws)[source]

Plot selected numeric schedule/runtime scalars.

Parameters:
  • payload

  • ax

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.scaling_kwargs_affine_frame(payload)[source]

Return affine SI-map rows.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.scaling_kwargs_bounds_frame(payload)[source]

Return tidy bounds rows.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame
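As a rough illustration of what "tidy bounds rows" means, the hypothetical bounds_rows helper below flattens a bounds mapping into rows with a fixed set of columns. It assumes each entry is either a scalar or a (low, high) pair and silently skips anything else; the real bounds schema and column names may differ. The returned list can be passed straight to pandas.DataFrame.

```python
from collections.abc import Mapping
from typing import Any


def bounds_rows(bounds: Mapping[str, Any]) -> list:
    """Flatten a bounds block into one row per entry with fixed columns."""
    rows = []
    for name, value in bounds.items():
        if isinstance(value, (list, tuple)) and len(value) == 2:
            # (low, high) pair, e.g. a prior range.
            low, high = value
            rows.append({"name": name, "value": None, "low": low, "high": high})
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            # Single numeric bound.
            rows.append({"name": name, "value": value, "low": None, "high": None})
        # Booleans, strings and nested blocks are skipped in this sketch.
    return rows
```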

geoprior.utils.scaling_kwargs_coord_frame(payload)[source]

Return coordinate and convention rows.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.scaling_kwargs_feature_channels_frame(payload)[source]

Return feature-group and channel rows.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.scaling_kwargs_schedule_frame(payload)[source]

Return Q/MV schedule and runtime scalar rows.

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.summarize_scaling_kwargs(payload)[source]

Return a compact high-level scaling summary.

Returns:

Compact summary intended for logs or gallery prose.

Return type:

dict

Parameters:

payload (ArtifactRecord | Mapping[str, Any] | str | Path)

geoprior.utils.default_model_init_manifest_payload(*, model_class='GeoPriorSubsNet', forecast_horizon=3, static_input_dim=12, dynamic_input_dim=5, future_input_dim=1, output_subsidence_dim=1, output_gwl_dim=1, quantiles=None, mode='tft_like', pde_mode='on', identifiability_regime=None, time_units='year')[source]

Build a realistic default model-init manifest.

The payload mirrors the saved structure used by the model-init manifest family: config + dims + model_class.

Parameters:
  • model_class (str)

  • forecast_horizon (int)

  • static_input_dim (int)

  • dynamic_input_dim (int)

  • future_input_dim (int)

  • output_subsidence_dim (int)

  • output_gwl_dim (int)

  • quantiles (list[float] | None)

  • mode (str)

  • pde_mode (str)

  • identifiability_regime (str | None)

  • time_units (str)

Return type:

dict[str, Any]

geoprior.utils.generate_model_init_manifest(path, *, template=None, overrides=None)[source]

Generate a model-init manifest JSON file.

Parameters:
  • path (str or pathlib.Path) – Output JSON path.

  • template (mapping, optional) – Base payload. If omitted, uses default_model_init_manifest_payload().

  • overrides (mapping, optional) – Nested overrides applied on top of the template.

Return type:

Path

geoprior.utils.inspect_model_init_manifest(manifest, *, output_dir=None, stem='model_init_manifest', save_figures=True)[source]

Inspect a model-init manifest and optionally save figures.

Returns:

Bundle containing summary, tabular frames, and optionally written figure paths.

Return type:

dict

Parameters:
  • manifest

  • output_dir

  • stem

  • save_figures

geoprior.utils.load_model_init_manifest(path)[source]

Load a model-init manifest as ArtifactRecord.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.model_init_architecture_frame(manifest)[source]

Return a frame for architecture choices.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.model_init_dims_frame(manifest)[source]

Return a tidy frame for input/output dimensions.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.model_init_feature_groups_frame(manifest)[source]

Return a tidy frame for nested feature-name groups.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.model_init_geoprior_frame(manifest)[source]

Return a frame for GeoPrior-specific init settings.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.model_init_scaling_overview_frame(manifest)[source]

Return a compact overview of resolved scaling kwargs.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.plot_model_init_architecture(manifest, *, ax=None, title='Architecture scalars', **plot_kws)[source]

Plot key architecture scalars.

Parameters:
  • manifest

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_model_init_boolean_summary(manifest, *, ax=None, title='Model-init checks', **plot_kws)[source]

Plot compact initialization checks as booleans.

Parameters:
  • manifest

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_model_init_dims(manifest, *, ax=None, title='Model-init dimensions', **plot_kws)[source]

Plot input/output dimensions.

Parameters:
  • manifest

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_model_init_feature_group_sizes(manifest, *, ax=None, title='Feature-group sizes', **plot_kws)[source]

Plot the sizes of static/dynamic/future feature groups.

Parameters:
  • manifest

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_model_init_geoprior(manifest, *, ax=None, title='GeoPrior initialization', **plot_kws)[source]

Plot key GeoPrior physics-init scalars.

Parameters:
  • manifest

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.summarize_model_init_manifest(manifest)[source]

Return a compact summary of a model-init manifest.

The goal is not to flatten every nested key, but to expose the most decision-relevant initialization facts.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_run_manifest_payload(*, city='nansha', model='GeoPriorSubsNet', stage='stage-2-train', time_steps=5, forecast_horizon_years=3, mode='tft_like', pde_mode_config='on', quantiles=None, identifiability_regime=None)[source]

Build a realistic default run-manifest payload.

The payload mirrors the lightweight Stage-2 run manifest saved after training: stage + city + model + config + paths + artifacts.

Parameters:
  • city (str)

  • model (str)

  • stage (str)

  • time_steps (int)

  • forecast_horizon_years (int)

  • mode (str)

  • pde_mode_config (str)

  • quantiles (list[float] | None)

  • identifiability_regime (str | None)

Return type:

dict[str, Any]

geoprior.utils.generate_run_manifest(path, *, template=None, overrides=None)[source]

Generate a run-manifest JSON artifact.

Parameters:
  • path (str or pathlib.Path) – Output JSON path.

  • template (mapping, optional) – Base payload. If omitted, uses default_run_manifest_payload().

  • overrides (mapping, optional) – Nested overrides applied on top of the template.

Return type:

Path

geoprior.utils.inspect_run_manifest(manifest)[source]

Return a bundle of useful inspection outputs.

Returns:

Dictionary containing the normalized payload, compact summary, and the main tidy frames.

Return type:

dict

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

geoprior.utils.load_run_manifest(path)[source]

Load a run-manifest as ArtifactRecord.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.plot_run_manifest_boolean_summary(ax, manifest, *, title='Run-manifest checks', **plot_kws)[source]

Plot simple boolean checks for expected run outputs.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_run_manifest_coord_ranges(ax, manifest, *, title='Coordinate ranges', **plot_kws)[source]

Plot coordinate ranges from nested scaling kwargs.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_run_manifest_feature_group_sizes(ax, manifest, *, title='Feature-group sizes', **plot_kws)[source]

Plot feature-group sizes from nested scaling kwargs.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_run_manifest_path_inventory(ax, manifest, *, title='Run-manifest inventory', **plot_kws)[source]

Plot path and artifact counts for the run manifest.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.run_manifest_artifacts_frame(manifest)[source]

Return a frame describing direct artifact pointers.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.run_manifest_config_frame(manifest)[source]

Return a tidy frame for the lightweight config block.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.run_manifest_identity_frame(manifest)[source]

Return a compact run-identity frame.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.run_manifest_paths_frame(manifest)[source]

Return a frame describing exported run paths.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.run_manifest_scaling_overview_frame(manifest)[source]

Return a compact frame for nested scaling overview.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.summarize_run_manifest(manifest)[source]

Summarize a run manifest in a compact dictionary.

The summary focuses on what a user usually wants at a glance: which Stage-2 run this is, how many paths were exported, whether feature groups are present, and the essential scaling/coord context.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]
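A minimal sketch of the kind of at-a-glance summary described for summarize_run_manifest follows. It assumes the nested scaling kwargs live under config["scaling_kwargs"] and that empty path values mean "not exported". Both are assumptions, and summarize_run_manifest_sketch is a hypothetical stand-in, not the library function.

```python
from collections.abc import Mapping
from typing import Any


def summarize_run_manifest_sketch(manifest: Mapping[str, Any]) -> dict:
    """Pick a few at-a-glance facts out of a run-manifest mapping."""
    config = manifest.get("config") or {}
    paths = manifest.get("paths") or {}
    sk = config.get("scaling_kwargs") or {}  # assumed nesting
    groups = ("dynamic_feature_names", "future_feature_names", "static_feature_names")
    return {
        "stage": manifest.get("stage"),
        "city": manifest.get("city"),
        "model": manifest.get("model"),
        # Count only paths that were actually filled in.
        "n_paths": sum(1 for v in paths.values() if v),
        "has_feature_groups": any(sk.get(g) for g in groups),
        "coord_ranges": sk.get("coord_ranges"),
    }
```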

geoprior.utils.default_manifest_payload(*, city='nansha', model='GeoPriorSubsNet', time_steps=5, forecast_horizon_years=3, mode='tft_like')[source]

Build a realistic default Stage-1 manifest payload.

The payload mirrors the stable Stage-1 handshake: schema_version + identity fields + config + artifacts + paths + versions.

Parameters:
  • city (str)

  • model (str)

  • time_steps (int)

  • forecast_horizon_years (int)

  • mode (str)

Return type:

dict[str, Any]

geoprior.utils.generate_manifest(path, *, template=None, overrides=None)[source]

Generate a Stage-1 manifest JSON artifact.

Parameters:
  • path

  • template

  • overrides

Return type:

Path

geoprior.utils.inspect_manifest(manifest)[source]

Return a bundle of useful Stage-1 inspection outputs.

Returns:

Normalized payload, compact summary, and the main tidy frames.

Return type:

dict

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

geoprior.utils.load_manifest(path)[source]

Load a Stage-1 manifest as ArtifactRecord.

Parameters:

path (str | Path)

Return type:

ArtifactRecord

geoprior.utils.manifest_artifacts_frame(manifest)[source]

Return a leaf-level artifact inventory frame.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_config_frame(manifest)[source]

Return a tidy frame for the compact config overview.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_feature_groups_frame(manifest)[source]

Return feature-group names and counts.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_holdout_frame(manifest)[source]

Return key holdout and split counts as a tidy frame.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_identity_frame(manifest)[source]

Return a compact Stage-1 identity frame.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_paths_frame(manifest)[source]

Return a frame describing top-level manifest paths.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_shapes_frame(manifest)[source]

Return a tidy tensor-shape summary frame.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.manifest_versions_frame(manifest)[source]

Return runtime/library versions saved in the manifest.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.plot_manifest_artifact_inventory(ax, manifest, *, title='Manifest artifact inventory', **plot_kws)[source]

Plot artifact and metadata inventory counts.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_manifest_boolean_summary(ax, manifest, *, title='Stage-1 manifest checks', **plot_kws)[source]

Plot simple boolean checks for the handshake.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_manifest_coord_ranges(ax, manifest, *, title='Coordinate ranges', **plot_kws)[source]

Plot nested coordinate ranges from scaling kwargs.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_manifest_feature_group_sizes(ax, manifest, *, title='Stage-1 feature groups', **plot_kws)[source]

Plot Stage-1 feature-group sizes.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_manifest_holdout_counts(ax, manifest, *, title='Holdout split counts', **plot_kws)[source]

Plot the main group and sequence split counts.

Parameters:
  • ax

  • manifest

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.summarize_manifest(manifest)[source]

Summarize the Stage-1 manifest in a compact dictionary.

The summary focuses on what later workflow stages usually need to verify first.

Parameters:

manifest (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.default_xfer_results_payload(*, src_city='nansha', tgt_city='zhongshan', model_name='GeoPriorSubsNet', split='test', calibration='source', strategy='warm', rescale_mode='strict')[source]

Build a realistic default transfer-results payload.

The payload mirrors the common pattern where one xfer_results.json file stores multiple transfer records, often one per direction.

Parameters:
  • src_city (str)

  • tgt_city (str)

  • model_name (str)

  • split (str)

  • calibration (str)

  • strategy (str)

  • rescale_mode (str)

Return type:

list[dict[str, Any]]

geoprior.utils.generate_xfer_results(path, *, template=None, overrides=None)[source]

Generate a reproducible transfer-results artifact.

Parameters:
  • path (path-like) – Destination JSON path.

  • template (list of dict, optional) – Existing transfer records to reuse.

  • overrides (dict or list of dict, optional) – Optional updates applied either to all records or record-wise.

Return type:

Path
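The overrides parameter accepts either one dict, applied to every record, or a list of dicts applied record by record. The hypothetical apply_record_overrides helper sketches that dispatch. It performs a shallow top-level update and rejects a record-wise list of the wrong length; both are assumptions rather than documented behavior.

```python
import copy
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union

Overrides = Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]]


def apply_record_overrides(records: Sequence[Mapping[str, Any]], overrides: Overrides) -> list:
    """Apply one dict to every record, or a list of dicts record-wise (shallow update)."""
    out = [copy.deepcopy(dict(r)) for r in records]  # leave the inputs untouched
    if overrides is None:
        return out
    if isinstance(overrides, Mapping):
        # One mapping: broadcast it to all records.
        for rec in out:
            rec.update(copy.deepcopy(dict(overrides)))
        return out
    overrides = list(overrides)
    if len(overrides) != len(out):
        raise ValueError(f"expected {len(out)} record-wise overrides, got {len(overrides)}")
    for rec, upd in zip(out, overrides):
        rec.update(copy.deepcopy(dict(upd)))
    return out
```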

geoprior.utils.inspect_xfer_results(xfer)[source]

Build a compact inspection bundle.

Returns:

A dictionary containing:

  • summary: workflow summary;

  • overall: overall-metrics frame;

  • per_horizon: per-horizon frame;

  • schema: schema-diagnostics frame;

  • warm: warm-start frame.

Return type:

dict

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

geoprior.utils.load_xfer_results(xfer)[source]

Load transfer-results records.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

list[dict[str, Any]]

geoprior.utils.plot_xfer_boolean_summary(xfer, *, figsize=(8.4, 4.8), ax=None, title='Transfer boolean summary', show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]

Plot aggregated schema boolean pass rates.

This turns per-record schema booleans into one compact pass-rate view.

Parameters:
  • xfer

  • figsize

  • ax

  • title

  • show_grid

  • grid_kws

  • error

  • **plot_kws

Return type:

tuple[Figure, Axes]

geoprior.utils.plot_xfer_direction_metric(xfer, *, metric='overall_rmse', figsize=(7.8, 4.2), ax=None, title=None, ylabel=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]

Plot one overall metric by transfer direction.

Parameters:
  • xfer

  • metric

  • figsize

  • ax

  • title

  • ylabel

  • show_grid

  • grid_kws

  • error

  • **plot_kws

Return type:

tuple[Figure, Axes]

geoprior.utils.plot_xfer_overall_metrics(xfer, *, metrics=None, figsize=(8.4, 4.8), ax=None, title='Transfer overall metrics', ylabel='value', show_grid=True, grid_kws=None, legend=None, legend_kws=None, error='ignore', **plot_kws)[source]

Plot selected overall metrics for each record.

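Parameters:
  • xfer

  • metrics

  • figsize

  • ax

  • title

  • ylabel

  • show_grid

  • grid_kws

  • legend

  • legend_kws

  • error

  • **plot_kws

Return type: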

tuple[Figure, Axes]

geoprior.utils.plot_xfer_per_horizon_metrics(xfer, *, metric='rmse', figsize=(8.0, 4.8), ax=None, title=None, xlabel='horizon', ylabel=None, marker='o', show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]

Plot one per-horizon metric as lines over horizon.

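Parameters:
  • xfer

  • metric

  • figsize

  • ax

  • title

  • xlabel

  • ylabel

  • marker

  • show_grid

  • grid_kws

  • legend

  • legend_kws

  • error

  • **plot_kws

Return type: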

tuple[Figure, Axes]

geoprior.utils.plot_xfer_schema_counts(xfer, *, figsize=(8.0, 4.6), ax=None, title='Schema mismatch counts', ylabel='count', show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]

Plot schema mismatch counts by record.

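Parameters:
  • xfer

  • figsize

  • ax

  • title

  • ylabel

  • show_grid

  • grid_kws

  • legend

  • legend_kws

  • error

  • **plot_kws

Return type: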

tuple[Figure, Axes]

geoprior.utils.summarize_xfer_results(xfer)[source]

Build a compact transfer-results summary.

The summary is intentionally workflow-oriented rather than exhaustive.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.xfer_overall_frame(xfer)[source]

Return one tidy row per transfer record.

The frame exposes the most useful comparison columns for quick ranking and filtering.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.xfer_per_horizon_frame(xfer)[source]

Return a tidy per-horizon metric frame.

The output is useful for comparing whether transfer quality degrades differently across directions or strategies as horizon increases.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame
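To make the "tidy per-horizon" shape concrete, the sketch below turns nested per-horizon metric maps into long-form rows, one per record, metric, and horizon. The per_horizon key and the H1/H2 horizon labels are assumptions about the record layout, and per_horizon_rows is hypothetical.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def per_horizon_rows(records: Sequence[Mapping[str, Any]]) -> list:
    """Long-form rows: one per (record, metric, horizon)."""
    rows = []
    for i, rec in enumerate(records):
        per_h = rec.get("per_horizon") or {}  # assumed layout: {metric: {horizon: value}}
        for metric, by_h in per_h.items():
            for horizon, value in by_h.items():
                rows.append({
                    "record": i,
                    "direction": rec.get("direction"),
                    "metric": metric,
                    "horizon": str(horizon),
                    "value": value,
                })
    return rows
```

Long-form rows make it easy to group by direction or strategy and compare how a metric grows with horizon.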

geoprior.utils.xfer_schema_frame(xfer)[source]

Return schema-alignment diagnostics in tidy form.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.xfer_warm_frame(xfer)[source]

Return warm-start settings in tidy form.

Parameters:

xfer (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.calibration_stats_factors_frame(stats)[source]

Return per-horizon calibration factors.

Parameters:

stats (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.calibration_stats_overall_frame(stats)[source]

Return before/after overall calibration metrics.

Parameters:

stats (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.calibration_stats_per_horizon_frame(stats, *, which='eval_after')[source]

Return per-horizon coverage and sharpness.

Parameters:
  • which ({'eval_before', 'eval_after'}) – Which calibration stage to extract.

  • stats (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

DataFrame

geoprior.utils.default_calibration_stats_payload(*, target=0.8, interval=(0.1, 0.9), f_max=5.0, tol=0.02, factors=None, coverage_before=0.865, coverage_after=0.8, sharpness_before=33.08, sharpness_after=33.38)[source]

Build a realistic default calibration-stats payload.

The structure follows the object saved by the calibration workflow and later embedded into the interpretable eval JSON.

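Parameters:
  • target

  • interval

  • f_max

  • tol

  • factors

  • coverage_before

  • coverage_after

  • sharpness_before

  • sharpness_after

Return type: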

dict[str, Any]
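The defaults target=0.8 and interval=(0.1, 0.9) describe an 80% central interval, with per-horizon factors capped at f_max and a tolerance tol on coverage. The sketch below illustrates one common way to realize this. It assumes coverage is the fraction of observations inside the interval, sharpness is the mean interval width, and a factor scales both half-widths around the median. The helper names are hypothetical, and the actual GeoPrior calibration procedure may differ.

```python
from collections.abc import Sequence


def coverage(y: Sequence[float], lo: Sequence[float], hi: Sequence[float]) -> float:
    """Fraction of observations inside [lo, hi]."""
    inside = sum(1 for yi, a, b in zip(y, lo, hi) if a <= yi <= b)
    return inside / len(y)


def sharpness(lo: Sequence[float], hi: Sequence[float]) -> float:
    """Mean interval width (smaller is sharper)."""
    return sum(b - a for a, b in zip(lo, hi)) / len(lo)


def widen(lo, med, hi, factor: float, f_max: float = 5.0):
    """Scale each half-width around the median by ``factor``, capped at ``f_max``."""
    f = min(factor, f_max)
    new_lo = [m - f * (m - a) for a, m in zip(lo, med)]
    new_hi = [m + f * (b - m) for m, b in zip(med, hi)]
    return new_lo, new_hi


def fit_factor(y, lo, med, hi, target=0.8, f_max=5.0, tol=0.02, iters=50):
    """Bisect the factor so coverage lands within ``tol`` of ``target``.

    Coverage never decreases as the factor grows, so bisection on [0, f_max]
    is valid; if the target is unreachable the last midpoint is returned.
    """
    a, b = 0.0, f_max
    f = 1.0
    for _ in range(iters):
        f = 0.5 * (a + b)
        c = coverage(y, *widen(lo, med, hi, f, f_max))
        if abs(c - target) <= tol:
            break
        if c < target:
            a = f
        else:
            b = f
    return f
```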

geoprior.utils.generate_calibration_stats(path, *, template=None, overrides=None)[source]

Generate and save a calibration-stats JSON file.

Parameters:
  • path (str or pathlib.Path) – Output JSON path.

  • template (mapping, path, ArtifactRecord, optional) – Optional source payload. If omitted, a realistic default payload is used.

  • overrides (dict, optional) – Deep overrides applied after template resolution.

Return type:

Path

geoprior.utils.inspect_calibration_stats(stats)[source]

Build a compact inspection bundle.

Returns:

A dictionary containing the raw payload, a compact summary, and tidy frames useful for gallery lessons, notebooks, or debugging.

Return type:

dict

Parameters:

stats (ArtifactRecord | Mapping[str, Any] | str | Path)

geoprior.utils.load_calibration_stats(path)[source]

Load a calibration-stats artifact.

Notes

This loader is list-safe and nested-block aware. It can read either a direct calibration_stats.json payload or an interpretable eval JSON from which the nested block is extracted.

Parameters:

path (str | Path)

Return type:

ArtifactRecord
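The list-safe, nested-block-aware loading described in the Notes can be sketched as follows. The nested key name calibration_stats and the "take the first record of a list" rule are assumptions for illustration, and extract_calibration_block is a hypothetical helper.

```python
import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any


def extract_calibration_block(obj: Any, key: str = "calibration_stats") -> dict:
    """Return the calibration block from a direct payload, an eval JSON, or a list of either."""
    if isinstance(obj, (str, Path)):
        obj = json.loads(Path(obj).read_text(encoding="utf-8"))
    if isinstance(obj, list):
        if not obj:
            raise ValueError("empty list: no calibration payload found")
        obj = obj[0]  # list-safe: take the first record
    if not isinstance(obj, Mapping):
        raise TypeError(f"unsupported payload type: {type(obj).__name__}")
    if isinstance(obj.get(key), Mapping):
        return dict(obj[key])  # nested block inside an interpretable eval JSON
    return dict(obj)  # already a direct calibration-stats payload
```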

geoprior.utils.plot_calibration_boolean_summary(ax, stats, *, title='Calibration checks', error='ignore', **plot_kws)[source]

Plot compact boolean checks for calibration status.

Parameters:
  • ax

  • stats

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_calibration_factors(ax, stats, *, title='Calibration factors', show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]

Plot per-horizon widening factors.

Parameters:
  • ax

  • stats

  • title

  • show_grid

  • grid_kws

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_calibration_overall_metrics(ax, stats, *, title='Calibration summary', error='ignore', **plot_kws)[source]

Plot overall before/after calibration metrics.

Parameters:
  • ax

  • stats

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_calibration_per_horizon_coverage(ax, stats, *, which='eval_after', title=None, error='ignore', **plot_kws)[source]

Plot per-horizon coverage.

Parameters:
  • ax

  • stats

  • which

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_calibration_per_horizon_sharpness(ax, stats, *, which='eval_after', title=None, error='ignore', **plot_kws)[source]

Plot per-horizon sharpness.

Parameters:
  • ax

  • stats

  • which

  • title

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.summarize_calibration_stats(stats)[source]

Return a compact summary of calibration behavior.

Parameters:

stats (ArtifactRecord | Mapping[str, Any] | str | Path)

Return type:

dict[str, Any]

geoprior.utils.ablation_config_frame(src)[source]

Return one row per record with config knobs.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.ablation_metrics_frame(src)[source]

Return long-form scalar metric rows.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.ablation_per_horizon_frame(src)[source]

Return long-form per-horizon metric rows.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.ablation_record_flags_frame(src)[source]

Return long-form boolean/config flags.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.ablation_record_runs_frame(src)[source]

Return one tidy row per ablation record.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

DataFrame

geoprior.utils.default_ablation_record_payload(*, city='nansha', model='GeoPriorSubsNet')[source]

Build a realistic default ablation JSONL payload.

The structure mirrors the current real record shape: one JSON object per line, with top-level config knobs, repeated compact metrics, a nested metrics block, units, and per-horizon metric maps.

Parameters:
  • city

  • model

Return type:

list[dict[str, Any]]

geoprior.utils.generate_ablation_record(output_path, *, overrides=None, city='nansha', model='GeoPriorSubsNet')[source]

Write a realistic demo ablation JSONL file.

Parameters:
  • output_path

  • overrides

  • city

  • model

Return type:

Path

geoprior.utils.inspect_ablation_record(src, *, output_dir=None, stem='ablation_record', save_figures=True)[source]

Inspect ablation JSONL and optionally save figures.

Parameters:
  • src

  • output_dir

  • stem

  • save_figures

Return type:

dict[str, Any]

geoprior.utils.load_ablation_record(src)[source]

Load ablation JSONL records into a plain list.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

list[dict[str, Any]]
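Ablation records are stored as JSONL, one JSON object per line. A stdlib-only reader in that spirit is sketched below. read_jsonl is hypothetical, and skipping blank lines while reporting the offending line number on bad JSON is a design choice of the sketch rather than documented library behavior.

```python
import json
from pathlib import Path


def read_jsonl(path) -> list:
    """Read one JSON object per non-blank line into a plain list."""
    records = []
    with Path(path).open(encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank separator lines
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({exc.msg})") from exc
    return records
```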

geoprior.utils.plot_ablation_boolean_summary(src, *, ax=None, title='Ablation record checks', **plot_kws)[source]

Parameters:
  • src

  • ax

  • title

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_ablation_lambda_weights(src, *, ax=None, title='Lambda weights by variant', xlabel='variant', ylabel='weight', show_grid=True, grid_kws=None, rotate_xticks=25, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]

Parameters:
  • src

  • ax

  • title

  • xlabel

  • ylabel

  • show_grid

  • grid_kws

  • rotate_xticks

  • legend

  • legend_kws

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_ablation_metric_by_variant(src, *, metric='rmse', ax=None, title=None, xlabel='variant', ylabel=None, show_grid=True, grid_kws=None, rotate_xticks=25, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]

Parameters:
  • src

  • metric

  • ax

  • title

  • xlabel

  • ylabel

  • show_grid

  • grid_kws

  • rotate_xticks

  • annotate

  • annotate_kws

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_ablation_per_horizon_metric(src, *, metric='mae', ax=None, title=None, xlabel='horizon', ylabel=None, show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]

Parameters:
  • src

  • metric

  • ax

  • title

  • xlabel

  • ylabel

  • show_grid

  • grid_kws

  • legend

  • legend_kws

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_ablation_run_counts(src, *, ax=None, title='Ablation runs per variant', xlabel='variant', ylabel='runs', show_grid=True, grid_kws=None, rotate_xticks=25, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]

Parameters:
  • src

  • ax

  • title

  • xlabel

  • ylabel

  • show_grid

  • grid_kws

  • rotate_xticks

  • annotate

  • annotate_kws

  • error

  • **plot_kws

Return type:

Axes

geoprior.utils.plot_ablation_top_variants(src, *, metric='rmse', top_n=5, ax=None, title=None, xlabel=None, ylabel='variant', show_grid=True, grid_kws=None, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]

Parameters:
  • src

  • metric

  • top_n

  • ax

  • title

  • xlabel

  • ylabel

  • show_grid

  • grid_kws

  • annotate

  • annotate_kws

  • error

  • **plot_kws

Return type:

Axes
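plot_ablation_top_variants ranks variants by a metric. Assuming each record carries a variant label and a scalar metric, and that lower values (as for rmse) are better, the ranking step might look like the hypothetical top_variants helper below, which averages repeated runs per variant before sorting.

```python
from collections import defaultdict
from collections.abc import Mapping, Sequence
from statistics import mean
from typing import Any


def top_variants(
    records: Sequence[Mapping[str, Any]],
    metric: str = "rmse",
    top_n: int = 5,
    lower_is_better: bool = True,
) -> list:
    """Average ``metric`` per variant and return the best ``top_n`` (variant, score) pairs."""
    by_variant = defaultdict(list)
    for rec in records:
        value = rec.get(metric)
        # Skip missing or non-numeric values (bool is excluded explicitly).
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            by_variant[rec.get("variant", "unknown")].append(float(value))
    scores = [(name, mean(vals)) for name, vals in by_variant.items()]
    scores.sort(key=lambda item: item[1], reverse=not lower_is_better)
    return scores[:top_n]
```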

geoprior.utils.summarize_ablation_record(src)[source]

Return a semantic summary of ablation JSONL.

Parameters:

src (Sequence[Mapping[str, Any]] | str | Path)

Return type:

dict[str, Any]

What this package covers

The top-level utility package is intentionally broad because it supports the staged application workflow end to end.

Its current public exports include helper groups such as:

  • audits and handshakes including audit_stage1_scaling, audit_stage2_handshake, and should_audit;

  • calibration including calibrate_quantile_forecasts;

  • forecast formatting and evaluation including format_and_forecast, evaluate_forecast, and pivot_forecast_dataframe;

  • configuration and stage helpers including load_nat_config, load_nat_config_payload, make_tf_dataset, map_targets_for_training, resolve_hybrid_config, resolve_si_affine, and serialize_subs_params;

  • geospatial and spatial workflow helpers including spatial_sampling, create_spatial_clusters, augment_city_spatiotemporal_data, and deg_to_m_from_lat;

  • holdout logic including compute_group_masks, split_groups_holdout, and filter_df_by_groups;

  • subsidence-oriented workflow helpers including convert_eval_payload_units, cumulative_to_rate, rate_to_cumulative, resolve_gwl_for_physics, resolve_head_column, and make_txy_coords.

Important workflow-facing modules

Audit helpers

Audit helpers for stage handshakes and scaling artifacts.

geoprior.utils.audit_utils.resolve_audit_stages(audit_stages, *, known=('stage1', 'stage2', 'stage3'), default=None)[source]

Resolve cfg["AUDIT_STAGES"] into a canonical set such as {"stage1", "stage2"}.

Parameters:
  • audit_stages

  • known

  • default

Return type:

set[str]

geoprior.utils.audit_utils.should_audit(audit_stages, *, stage, default=None)[source]

Convenience wrapper: return True if stage is selected for auditing by audit_stages.

Parameters:
  • audit_stages (Any)

  • stage (str)

  • default (Any)

Return type:

bool
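The exact inputs accepted for AUDIT_STAGES are not listed here. The sketch below shows one plausible normalization, accepting None, booleans, "all", comma-separated strings, bare stage numbers, and spellings such as "Stage-2". All of these accepted forms are assumptions, and both functions are hypothetical stand-ins for resolve_audit_stages and should_audit.

```python
from collections.abc import Iterable
from typing import Any

KNOWN_STAGES = ("stage1", "stage2", "stage3")


def _canon(item: Any) -> str:
    """Map 'Stage-2', 'stage_2', '2' and 2 to 'stage2' (assumed spellings)."""
    token = str(item).strip().lower()
    for ch in ("-", "_", " "):
        token = token.replace(ch, "")
    return f"stage{token}" if token.isdigit() else token


def resolve_stages_sketch(audit_stages: Any, *, known: Iterable = KNOWN_STAGES,
                          default: Any = None) -> set:
    known = tuple(known)
    value = default if audit_stages is None else audit_stages
    if value is None or value is False:
        return set()
    if value is True:
        return set(known)
    if isinstance(value, str):
        items = value.split(",")  # "stage1, stage2"
    elif isinstance(value, int):
        items = [value]
    else:
        items = list(value)
    tokens = {_canon(item) for item in items}
    if tokens & {"all", "*"}:
        return set(known)
    return {t for t in tokens if t in known}  # unknown names are silently dropped


def should_audit_sketch(audit_stages: Any, *, stage: str, default: Any = None) -> bool:
    return _canon(stage) in resolve_stages_sketch(audit_stages, default=default)
```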

geoprior.utils.audit_utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]

Stage-1 audit. It reports:

  • raw df_train coordinate statistics (t/x/y) with heuristic unit guesses;

  • statistics of the model-fed coordinates from inputs_train["coords"] (flattened);

  • coordinate-scaler min/max and coord_ranges;

  • SI-channel sanity checks for physics columns (if present);

  • target-array sanity checks;

  • the split of features into scaled ML, __si, and other columns.

A machine-readable JSON report is saved if save_dir is provided.

Parameters:
  • inputs_train (dict[str, Any])

  • targets_train (dict[str, Any])

  • coord_scaler (Any)

  • coord_ranges (dict[str, float] | None)

  • coord_mode (str)

  • coords_in_degrees (bool)

  • coord_epsg_used (Any)

  • coord_x_col_used (str)

  • coord_y_col_used (str)

  • x_col_used (str)

  • y_col_used (str)

  • time_col_used (str)

  • normalize_coords (bool)

  • keep_coords_raw (bool)

  • shift_raw_coords (bool)

  • subs_model_col (str | None)

  • gwl_dyn_col (str | None)

  • gwl_target_col (str | None)

  • h_field_col (str | None)

  • dynamic_features (Iterable[str] | None)

  • static_features (Iterable[str] | None)

  • future_features (Iterable[str] | None)

  • scaled_ml_numeric_cols (Iterable[str] | None)

  • main_scaler_path (str | None)

  • scaler_info (dict | None)

  • save_dir (str | None)

  • table_width (int)

  • title_prefix (str)

  • city (str)

  • model_name (str)

  • sample_rows (int)

Return type:

str | None

geoprior.utils.audit_utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]
Parameters:
geoprior.utils.audit_utils.audit_stage1_stage2_coord_consistency(*, X_train, coord_scaler, sk_final, time_steps, forecast_horizon, time_units='year', save_dir=None, table_width=110, title_prefix='STAGE-1 <-> STAGE-2 COORD CONSISTENCY', city='Unknown', model_name='Model', log_fn=None)[source]

Cross-check coordinate semantics between the Stage-1 scaler and the Stage-2 NPZ coords.

Key facts for GeoPrior Stage-2:
  • coords have shape (N, H, 3) and correspond to the target horizon times, not the full dynamic history, so t has exactly H unique values.

  • x/y typically cover the full normalized [0, 1] range when spatial coverage is complete (often min=0 and max=1).

This audit:
  • computes normalized min/max for t/x/y in X_train["coords"]

  • derives the implied raw min/max from the MinMaxScaler attributes data_min_ and data_max_

  • checks that the implied raw ranges lie within the Stage-1 scaler bounds

  • checks that the number of unique t values equals H and that the raw t spacing is regular (≈1 year)

  • provides a UTM plausibility hint if the EPSG code is UTM-like
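The range and time checks in the list above can be sketched in NumPy. This is a minimal illustration, not the audit's implementation: it assumes a MinMaxScaler with the default (0, 1) feature range fitted on the columns [t, x, y], so that raw = norm * (data_max_ - data_min_) + data_min_. The helper names `implied_raw_ranges` and `coords_look_consistent` and the toy scaler bounds are hypothetical.

```python
import numpy as np


def implied_raw_ranges(coords_norm, data_min, data_max):
    """Map normalized (N, H, 3) coords back to raw [t, x, y] units.

    Assumes a (0, 1) MinMaxScaler fitted on [t, x, y]; data_min and
    data_max stand in for its data_min_ and data_max_ attributes.
    """
    data_min = np.asarray(data_min, dtype=float)
    data_max = np.asarray(data_max, dtype=float)
    flat = np.asarray(coords_norm, dtype=float).reshape(-1, 3)
    raw = flat * (data_max - data_min) + data_min
    return raw.min(axis=0), raw.max(axis=0), raw


def coords_look_consistent(coords_norm, data_min, data_max, horizon,
                           step=1.0, atol=1e-6):
    raw_min, raw_max, raw = implied_raw_ranges(coords_norm, data_min, data_max)
    # Implied raw ranges must stay inside the Stage-1 scaler bounds.
    in_bounds = bool(np.all(raw_min >= np.asarray(data_min) - atol)
                     and np.all(raw_max <= np.asarray(data_max) + atol))
    # Stage-2 coords hold only the H target times, evenly spaced.
    t_unique = np.unique(np.round(raw[:, 0], 6))
    right_count = t_unique.size == horizon
    regular = bool(np.allclose(np.diff(t_unique), step, atol=atol))
    return in_bounds and right_count and regular


# Toy scaler bounds: t in [2015, 2025], x in [0, 1000], y in [0, 500].
data_min = np.array([2015.0, 0.0, 0.0])
data_max = np.array([2025.0, 1000.0, 500.0])
t = np.array([2020.0, 2021.0, 2022.0])               # H = 3 target years
raw = np.stack([
    np.column_stack([t, np.full(3, 100.0), np.full(3, 50.0)]),
    np.column_stack([t, np.full(3, 900.0), np.full(3, 450.0)]),
])                                                    # shape (N=2, H=3, 3)
coords_norm = (raw - data_min) / (data_max - data_min)
print(coords_look_consistent(coords_norm, data_min, data_max, horizon=3))
```

Rounding t before np.unique keeps floating-point noise from the inverse transform from inflating the unique-value count.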

Parameters:
  • X_train (dict)

  • sk_final (dict)

  • time_steps (int)

  • forecast_horizon (int)

  • time_units (str)

  • save_dir (str | None)

  • table_width (int)

  • title_prefix (str)

  • city (str)

  • model_name (str)

geoprior.utils.audit_utils.audit_stage3_run(*, manifest_path, manifest, cfg, fixed_params, best_hps, run_dir, best_model_path, best_weights_path, use_tf_savedmodel, quantiles, forecast_horizon, mode, pred_shapes=None, eval_results=None, phys_diag=None, calibrator_factors=None, forecast_csv_eval=None, forecast_csv_future=None, metrics_json_path=None, physics_payload_path=None, save_dir=None, table_width=100, title_prefix='STAGE-3 AUDIT', city='Unknown', model_name='Model', log_fn=None)[source]

Stage-3 audit: sanity checks on tuned artifacts and evaluation outputs.

Parameters:
  • manifest_path (str | None)

  • manifest (dict[str, Any])

  • cfg (dict[str, Any])

  • fixed_params (dict[str, Any])

  • best_hps (dict[str, Any] | None)

  • run_dir (str)

  • best_model_path (str | None)

  • best_weights_path (str | None)

  • use_tf_savedmodel (bool)

  • quantiles (Any)

  • forecast_horizon (int)

  • mode (str)

  • pred_shapes (dict[str, Any] | None)

  • eval_results (dict[str, Any] | None)

  • phys_diag (dict[str, Any] | None)

  • calibrator_factors (Any)

  • forecast_csv_eval (str | None)

  • forecast_csv_future (str | None)

  • metrics_json_path (str | None)

  • physics_payload_path (str | None)

  • save_dir (str | None)

  • table_width (int)

  • title_prefix (str)

  • city (str)

  • model_name (str)

Return type:

str | None

These helpers validate Stage-1 scaling and the Stage-2 handoff before later model or physics-aware runs rely on them, and audit_stage3_run adds sanity checks on tuned Stage-3 artifacts and evaluation outputs.

Calibration helpers#

Forecast calibration utilities.

geoprior.utils.calibrate.fit_interval_factors_df(df_eval, *, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, tol=0.001, f_max=5.0, max_iter=32, verbose=1, logger=None)[source]

Fit empirical interval-width correction factors from an evaluation DataFrame.

This function learns one multiplicative factor per forecast horizon so that an existing predictive interval reaches a desired empirical coverage on evaluation data. It is designed for the common case where a model already produces quantile forecasts, but the resulting intervals are systematically too narrow or too wide.

For each horizon group (rows sharing the same step_col value), the function identifies

  • one lower quantile,

  • one median-like quantile, and

  • one upper quantile,

then searches for the smallest factor \(f \ge 1\) such that the rescaled interval reaches the requested target coverage.

If \(q_{lo}\), \(q_{md}\), and \(q_{hi}\) denote the selected lower, median, and upper forecasts, the calibrated interval used during fitting is

\[\tilde q_{lo} = q_{md} - f \, (q_{md} - q_{lo}) \tag{5}\]
\[\tilde q_{hi} = q_{md} + f \, (q_{hi} - q_{md}) \tag{6}\]

and the fitted factor is the smallest value whose empirical coverage is at least the target, up to the requested tolerance.
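The factor search described above can be sketched as a bisection on a single horizon group. This is a minimal sketch, not the library's code: it assumes the lower, median, and upper quantile arrays have already been resolved from the DataFrame and that the search interval is [1, f_max]. The names `interval_coverage` and `fit_factor` are hypothetical. Bisection is valid here because widening the interval around q_md can only keep or increase coverage, so coverage is non-decreasing in f.

```python
import numpy as np


def interval_coverage(y, q_lo, q_md, q_hi, f):
    """Empirical coverage of the interval rescaled by factor f around q_md."""
    lo = q_md - f * (q_md - q_lo)
    hi = q_md + f * (q_hi - q_md)
    return float(np.mean((y >= lo) & (y <= hi)))


def fit_factor(y, q_lo, q_md, q_hi, target=0.8, f_min=1.0, f_max=5.0,
               tol=1e-3, max_iter=32):
    """Smallest f in [f_min, f_max] whose coverage reaches target."""
    y, q_lo, q_md, q_hi = (np.asarray(a, dtype=float) for a in (y, q_lo, q_md, q_hi))
    if interval_coverage(y, q_lo, q_md, q_hi, f_min) >= target:
        return f_min
    if interval_coverage(y, q_lo, q_md, q_hi, f_max) < target:
        return f_max                       # target not reachable below the cap
    lo, hi = f_min, f_max                  # invariant: coverage(hi) >= target
    for _ in range(max_iter):
        if hi - lo <= tol:
            break
        mid = 0.5 * (lo + hi)
        if interval_coverage(y, q_lo, q_md, q_hi, mid) >= target:
            hi = mid
        else:
            lo = mid
    return hi


# One horizon group whose nominal 80% interval is too narrow.
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
q_md = y + np.array([0.0, 0.1, -0.2, 0.3, -0.4])
q_lo, q_hi = q_md - 0.1, q_md + 0.1
f = fit_factor(y, q_lo, q_md, q_hi, target=0.8)
print(round(f, 3))
```

Here four of the five observations fall inside the interval only once its half-width reaches 0.3, so the fitted factor lands close to 3.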

This is a lightweight post-hoc calibration strategy that is especially useful for multi-horizon quantile forecasts, where sharpness and coverage must be balanced carefully in downstream risk analysis [1, 2].

Parameters:
  • df_eval (pandas.DataFrame) – Evaluation table containing the observed target and the forecast quantile columns used for calibration.

  • target_name (str, default "subsidence") – Base name of the forecasted variable. When column_map is not provided, quantile columns are auto-detected using names such as "{target_name}_q10", "{target_name}_q50", and "{target_name}_q90".

  • column_map (mapping or None, default None) –

    Optional explicit column mapping, with at least the following keys:

    • "quantiles": a mapping from quantile level to column name, or a list of quantile column names,

    • "actual": the observed target column.

    This is useful when the DataFrame does not follow the default naming convention.

  • step_col (str, default "forecast_step") – Column used to fit separate factors per horizon or lead time. If this column is missing, the full DataFrame is treated as a single group and a single factor is returned with key 1.

  • interval (tuple of float, default (0.1, 0.9)) – Nominal lower and upper quantiles defining the interval to calibrate. The function uses the closest available quantiles in the DataFrame.

  • target_coverage (float, default 0.8) – Desired empirical coverage of the calibrated interval.

  • median_q (float, default 0.5) – Target quantile used as the center of the interval expansion. The nearest available quantile in the DataFrame is used.

  • tol (float, default 1e-3) – Numerical tolerance used during the bisection search.

  • f_max (float, default 5.0) – Upper bound for the searched interval factor. If the target coverage cannot be reached before this bound, f_max is returned for that group.

  • max_iter (int, default 32) – Maximum number of bisection iterations used to fit each factor.

  • verbose (int, default 1) – Verbosity level forwarded to the internal logging helper.

  • logger (logging.Logger or None, default None) – Optional logger used for progress messages.

Returns:

Dictionary mapping each forecast horizon to its fitted interval factor.

Return type:

dict[int, float]

Raises:
  • ValueError – If no actual column can be resolved, or if no valid quantile interval can be formed from the available columns.

  • KeyError – If a user-specified column in column_map is missing.

Notes

This function does not change the forecasts directly. It only learns the correction factors. To apply the fitted factors to evaluation or future forecasts, use apply_interval_factors_df().

Because the method calibrates interval width around a median-like center, it is most appropriate when miscalibration is primarily a sharpness problem rather than a severe bias in the central forecast.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import fit_interval_factors_df
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 1, 2, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 1.0, 0.5, 0.8, 1.2],
...         "subsidence_q10": [0.3, 0.5, 0.8, 0.4, 0.6, 0.9],
...         "subsidence_q50": [0.4, 0.7, 1.0, 0.5, 0.8, 1.1],
...         "subsidence_q90": [0.5, 0.9, 1.1, 0.6, 1.0, 1.3],
...     }
... )
>>> factors = fit_interval_factors_df(
...     df,
...     target_name="subsidence",
...     interval=(0.1, 0.9),
...     target_coverage=0.8,
... )
>>> isinstance(factors, dict)
True

See also

apply_interval_factors_df

Apply learned interval factors to forecast quantiles.

calibrate_quantile_forecasts

High-level wrapper that can fit and apply factors in one call.

References

Lim et al. [1] discuss multi-horizon quantile forecasting in a setting where calibration and interpretability both matter.

For subsidence-oriented uncertainty-aware forecasting and practical post-hoc probabilistic refinement in this project context, see Kouadio et al. [2].

geoprior.utils.calibrate.apply_interval_factors_df(df, factors, *, target_name='subsidence', column_map=None, step_col='forecast_step', median_q=0.5, keep_original=False, factor_col='calibration_factor', calibrated_col='is_calibrated', enforce_monotonic='cummax', verbose=1, logger=None)[source]

Apply interval-width calibration factors to quantile forecasts stored in a DataFrame.

This function rescales forecast quantiles around a median-like central forecast using either

  • a single scalar factor applied to every row, or

  • a mapping of forecast_step -> factor.

The main goal is to widen or shrink predictive intervals after the factors have been estimated on evaluation data with fit_interval_factors_df().

If \(q_{md}\) is the selected median-like forecast and \(f_h\) is the factor for horizon \(h\), then for a lower quantile \(q < q_{md}\) the calibrated value is

\[\tilde q = q_{md} - f_h \, (q_{md} - q) \tag{7}\]

and for an upper quantile \(q > q_{md}\) it is

\[\tilde q = q_{md} + f_h \, (q - q_{md}) \tag{8}\]

while the median itself is left unchanged.
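Both rules reduce to the single expression q_md + f_h (q - q_md), which also leaves the median unchanged. The NumPy sketch below applies that expression to a quantile array and then enforces ordering with a cumulative maximum or a row-wise sort, mirroring the enforce_monotonic options. It is an illustration under stated assumptions, not the library's implementation: it assumes the quantile columns are already extracted into an array ordered by level, and the helper name `apply_factors` is hypothetical.

```python
import numpy as np


def apply_factors(quantiles, levels, factors, median_q=0.5, monotonic="cummax"):
    """Rescale quantile columns around the column nearest to median_q.

    quantiles: array (n_rows, n_levels), columns ordered by `levels`.
    factors: scalar, or one factor per row (e.g. looked up by forecast step).
    """
    q = np.asarray(quantiles, dtype=float)
    levels = np.asarray(levels, dtype=float)
    j = int(np.argmin(np.abs(levels - median_q)))     # median-like column
    centre = q[:, [j]]
    f = np.broadcast_to(np.asarray(factors, dtype=float), (q.shape[0],))[:, None]
    # Lower and upper rules both reduce to q_md + f * (q - q_md).
    out = centre + f * (q - centre)
    if monotonic == "cummax":
        out = np.maximum.accumulate(out, axis=1)
    elif monotonic == "sort":
        out = np.sort(out, axis=1)
    elif monotonic != "none":
        raise ValueError(f"unknown monotonic mode: {monotonic!r}")
    return out


levels = [0.1, 0.5, 0.9]
q = np.array([[0.3, 0.4, 0.5],
              [0.4, 0.5, 0.6]])
steps = [1, 2]
factor_map = {1: 1.2, 2: 1.1}                         # forecast_step -> factor
out = apply_factors(q, levels, [factor_map[s] for s in steps])
print(out)
```

With positive factors and already ordered inputs the rescaling keeps the order, so cummax and sort only change rows whose quantiles were crossing to begin with.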

Parameters:
  • df (pandas.DataFrame) – DataFrame containing forecast quantile columns.

  • factors (mapping or float) –

    Calibration factor specification. This may be

    • a scalar applied to every row, or

    • a mapping from forecast horizon to factor.

  • target_name (str, default "subsidence") – Base variable name used to auto-detect quantile columns when column_map is not supplied.

  • column_map (mapping or None, default None) – Optional explicit mapping describing the quantile columns.

  • step_col (str, default "forecast_step") – Horizon column used when factors is a mapping. If this column is absent, it is created and filled with 1.

  • median_q (float, default 0.5) – Quantile used as the center of the recalibration. The closest available forecast quantile is used.

  • keep_original (bool, default False) – If True, each raw quantile column is preserved in an additional "{col}_raw" column before calibration is applied.

  • factor_col (str, default "calibration_factor") – Name of the column storing the factor applied to each row.

  • calibrated_col (str, default "is_calibrated") – Name of the Boolean marker column added to the returned DataFrame.

  • enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") –

    Strategy used after recalibration to keep quantiles ordered.

    • "cummax" applies a cumulative maximum across quantiles,

    • "sort" sorts the calibrated quantiles row-wise,

    • "none" leaves the result unchanged.

  • verbose (int, default 1) – Verbosity level forwarded to the internal logger helper.

  • logger (logging.Logger or None, default None) – Optional logger used for progress messages.

Returns:

A calibrated copy of df containing updated quantiles and the metadata columns specified by factor_col and calibrated_col.

Return type:

pandas.DataFrame

Raises:

ValueError – If no quantile columns can be resolved from the DataFrame or if an invalid enforce_monotonic mode is requested.

Notes

Monotonicity enforcement is important because quantile-wise post-processing can otherwise produce crossing quantiles. Keeping quantiles ordered is especially important when the calibrated outputs are used to derive intervals, exceedance probabilities, or downstream risk metrics [1, 2].

This function only modifies the quantile forecasts. It does not recompute summary coverage statistics. For fit-and-apply workflows with optional evaluation summaries, use calibrate_quantile_forecasts().

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import apply_interval_factors_df
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> out = apply_interval_factors_df(
...     df,
...     factors={1: 1.2, 2: 1.1},
...     target_name="subsidence",
... )
>>> "calibration_factor" in out.columns
True

See also

fit_interval_factors_df

Fit per-horizon interval scaling factors.

calibrate_quantile_forecasts

High-level wrapper for fitting, applying, and summarizing interval calibration.

geoprior.utils.calibrate.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]

Fit and apply post-hoc interval calibration for quantile forecasts.

This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can

  1. detect whether evaluation forecasts already appear calibrated,

  2. fit interval-width correction factors from evaluation data,

  3. apply those factors to evaluation and/or future forecasts,

  4. compute before/after summary diagnostics on the evaluation set,

  5. optionally save the outputs to disk.

The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.

Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].

Parameters:
  • df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.

  • df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.

  • target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.

  • column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.

  • step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.

  • interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.

  • target_coverage (float, default 0.8) – Desired empirical coverage after calibration.

  • median_q (float, default 0.5) – Central quantile used as the expansion anchor.

  • use ({"auto", True, False}, default "auto") –

    Control flag for whether calibration is performed.

    • False disables calibration and returns inputs unchanged.

    • "auto" skips calibration when evaluation forecasts already look calibrated.

    • True forces calibration even if the automatic check would skip it.

  • tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.

  • f_max (float, default 5.0) – Maximum factor allowed during fitting.

  • max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.

  • keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.

  • enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.

  • overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.

  • calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.

  • factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.

  • factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.

  • save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.

  • save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.

  • save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.

  • verbose (int, default 1) – Verbosity level forwarded to logging helpers.

  • logger (logging.Logger or None, default None) – Optional logger used for progress messages.

Returns:

  • df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.

  • df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.

  • stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain

    • the target interval and target coverage,

    • the fitted or user-specified factors,

    • skip reasons,

    • evaluation summaries before and after calibration.

Return type:

tuple[DataFrame | None, DataFrame | None, dict[str, Any]]

Notes

In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This avoids calibrating the same tables twice in repeated workflows, where they may pass through the calibration stage more than once.
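The exact skip rule is not spelled out here, so the sketch below is only one plausible reading. It assumes the marker column wins when every row is flagged, and that otherwise forecasts "look calibrated" when empirical coverage lies within tol of target_coverage. The function `looks_calibrated` is hypothetical.

```python
import numpy as np


def looks_calibrated(y, q_lo, q_hi, target=0.8, tol=0.02, flags=None):
    """Hypothetical stand-in for the use="auto" skip decision."""
    # 1) An explicit calibrated_col marker takes precedence (assumed: all rows set).
    if flags is not None and len(flags) > 0 and bool(np.all(flags)):
        return True
    # 2) Otherwise compare empirical interval coverage with the target.
    y, q_lo, q_hi = (np.asarray(a, dtype=float) for a in (y, q_lo, q_hi))
    coverage = float(np.mean((y >= q_lo) & (y <= q_hi)))
    return abs(coverage - target) <= tol


y = np.arange(10.0)
q_lo, q_hi = y - 0.5, y + 0.5
q_lo[8:] = y[8:] + 0.1          # two misses, so coverage is 0.8
print(looks_calibrated(y, q_lo, q_hi, target=0.8))
print(looks_calibrated(y, q_lo, q_hi, target=0.9))
```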

The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True

See also

fit_interval_factors_df

Fit per-horizon interval-width correction factors.

apply_interval_factors_df

Apply a scalar or per-horizon factor map to quantile forecasts.

References

For the broader role of calibrated probabilistic multi-horizon forecasting, see Lim et al. [1].

For uncertainty-rich forecasting in the present project ecosystem, see Kouadio et al. [2].

geoprior.utils.calibrate.calibrate_forecasts_in(y_val, forecasts_val, *forecasts_to_calibrate, quantiles=(10, 20, 30, 40, 50, 60, 70, 80, 90), prefix='subsidence', verbose=None, _logger=None)[source]

Train and apply isotonic calibration models on forecast data.

Parameters:
  • y_val (pandas.Series) – Observed target values for calibration.

  • forecasts_val (pandas.DataFrame) – Validation forecasts with columns named <prefix>_q{quantile} for each quantile.

  • *forecasts_to_calibrate (DataFrame or dict) – One or more DataFrames (or a mapping) of new forecasts to calibrate with the fitted calibrators.

  • quantiles (tuple of int, optional) – Quantile levels in percent (e.g., 10, 50, 90) to calibrate.

  • prefix (str, default 'subsidence') – Column name prefix for quantile forecasts.

  • verbose (int, optional) – Verbosity level passed to vlog for logging.

  • _logger (Logger or callable, optional) – Sink for log messages (print or logging.Logger).

Returns:

  • results (dict) – Mapping of model names to calibrated DataFrames, each containing new <prefix>_q{q}_cal_prob probability columns.

  • calibrators (dict) – Fitted IsotonicRegression models per quantile.

Notes

This function fits an isotonic regression for each quantile on validation data and then predicts calibrated probabilities for new forecasts. Use normalize_model_inputs to handle varied input formats.
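To make the per-quantile isotonic step concrete, the sketch below implements pool-adjacent-violators (PAV) isotonic regression in NumPy. It assumes, which the text does not state, that the regression target for each quantile is the hit indicator 1{y ≤ q̂} regressed on the raw forecast q̂. The helpers `pav_fit` and `pav_predict` are hypothetical and stand in for scikit-learn's IsotonicRegression. Unlike scikit-learn, this sketch does not pool tied x values, and it predicts with a step lookup instead of the linear interpolation IsotonicRegression.predict uses.

```python
import numpy as np


def pav_fit(x, z):
    """Increasing isotonic regression of z on x (pool adjacent violators).

    Returns x sorted ascending and the fitted, non-decreasing values.
    Tied x values are not pooled, unlike scikit-learn's implementation.
    """
    order = np.argsort(x, kind="mergesort")
    xs = np.asarray(x, dtype=float)[order]
    zs = np.asarray(z, dtype=float)[order]
    sums, counts = [], []                 # one (sum, count) pair per block
    for v in zs:
        sums.append(v)
        counts.append(1)
        # Merge backwards while the previous block's mean exceeds the last one.
        while len(sums) > 1 and sums[-2] * counts[-1] > sums[-1] * counts[-2]:
            s, c = sums.pop(), counts.pop()
            sums[-1] += s
            counts[-1] += c
    fitted = np.repeat([s / c for s, c in zip(sums, counts)], counts)
    return xs, fitted


def pav_predict(xs, fitted, x_new):
    """Right-continuous step lookup, clipped at both ends."""
    idx = np.searchsorted(xs, x_new, side="right") - 1
    return fitted[np.clip(idx, 0, len(xs) - 1)]


# Toy q90 forecasts on validation data and hit indicators 1{y <= q90}.
q90_val = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
hits = np.array([0, 1, 0, 1, 1, 1])
xs, fitted = pav_fit(q90_val, hits)
print(fitted)
print(pav_predict(xs, fitted, np.array([0.5, 2.5, 10.0])))
```

The violating pair (1, 0) at positions two and three is pooled into a block with mean 0.5, which is what keeps the fitted hit probabilities non-decreasing.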

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.calibrate import calibrate_forecasts_in
>>> # Create dummy validation series
>>> idx = pd.date_range('2021-01-01', periods=50)
>>> y_val = pd.Series(
...     np.random.randn(50), index=idx
... )
>>> # Simulate validation forecasts
>>> val_df = pd.DataFrame({
...     'subsidence_q10': y_val - 0.2,
...     'subsidence_q90': y_val + 0.2
... }, index=idx)
>>> # Simulate two new forecast sets
>>> newA = val_df + 0.1
>>> newB = val_df - 0.1
>>> # Calibrate forecasts
>>> results, models = calibrate_forecasts_in(
...     y_val,
...     val_df,
...     newA,
...     newB,
...     quantiles=(10, 90),
...     verbose=2
... )
>>> # Calibrated probability columns, one DataFrame per input set
>>> for name, df_cal in results.items():
...     print(name, [c for c in df_cal.columns if c.endswith('_cal_prob')])
>>> # Each fitted calibrator is a scikit-learn IsotonicRegression
>>> calibrator = next(iter(models.values()))
>>> calibrator.X_thresholds_

geoprior.utils.calibrate.calibrate_probability_forecast(df, prob_col, actual_col, method='isotonic', out_col=None, clip=True, savefile=None)[source]

Calibrate binary probability forecasts with isotonic regression or logistic scaling.

This function post-processes a column of raw forecast probabilities so that they better agree with observed binary outcomes. It returns a copy of the input DataFrame with one additional calibrated probability column.

Two calibration modes are supported:

  • "isotonic"

    Non-parametric monotone calibration using isotonic regression.

  • "logistic"

    Parametric calibration using logistic regression on the raw probabilities.

Parameters:
  • df (pandas.DataFrame) – DataFrame containing a column of raw forecast probabilities and a column of observed binary outcomes.

  • prob_col (str) – Name of the column containing the raw forecast probabilities. Values are usually expected in the interval [0, 1].

  • actual_col (str) – Name of the column containing the realized binary outcomes, encoded as 0 and 1.

  • method ({"isotonic", "logistic"}, default "isotonic") – Calibration method to apply.

  • out_col (str or None, default None) – Name of the calibrated probability column. If None, f"{prob_col}_calib" is used.

  • clip (bool, default True) – If True, calibrated outputs are clipped to [0, 1] before being written to the result.

  • savefile (str or None, default None) – Optional output path handled by the SaveFile() decorator.

Returns:

Copy of df with an additional calibrated probability column.

Return type:

pandas.DataFrame

Raises:

ValueError – If method is not one of "isotonic" or "logistic".

Notes

This function is appropriate for event-probability forecasts, not for continuous quantile forecasts. For quantile-table recalibration, use calibrate_forecasts() or calibrate_quantile_forecasts() depending on the workflow.

The isotonic option is often preferred when a flexible monotone mapping is desired, whereas the logistic option is useful when a simple parametric correction is sufficient.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_probability_forecast,
... )
>>> df = pd.DataFrame(
...     {
...         "p_raw": [0.10, 0.30, 0.60, 0.85],
...         "y": [0, 0, 1, 1],
...     }
... )
>>> out = calibrate_probability_forecast(
...     df,
...     prob_col="p_raw",
...     actual_col="y",
...     method="isotonic",
... )
>>> "p_raw_calib" in out.columns
True

See also

calibrate_forecasts

Calibrate continuous quantile forecasts by inverting calibrated CDF estimates.

calibrate_quantile_forecasts

Fit-and-apply interval calibration for tabular quantile forecasts.

geoprior.utils.calibrate.calibrate_forecasts(df, quantiles, q_prefix, actual_col, method='isotonic', out_prefix='calib', grid_mode='unit', grid_size=1001, group_by=None, savefile=None)[source]

Calibrate continuous quantile forecasts by fitting a calibrated CDF surrogate and inverting it at the nominal quantile levels.

This function works on a DataFrame containing continuous quantile forecasts and the corresponding observed target. For each requested quantile level \(q\), it builds a binary supervision signal of the form

\[y_q = \mathbb{1}\{y \le \hat q\} \tag{9}\]

where \(y\) is the observed outcome and \(\hat q\) is the raw forecast quantile at level \(q\).

A monotone classifier is then fitted to approximate the conditional CDF at that threshold. Finally, the calibrated forecast is obtained by inverting the fitted CDF on a grid and taking the first threshold whose calibrated CDF reaches the nominal quantile level.
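The inversion step can be sketched on its own. In the snippet below a toy monotone curve stands in for the fitted classifier's CDF evaluated on the grid, and the calibrated quantile is the first grid threshold whose CDF reaches the nominal level. The helper `invert_cdf_on_grid` is hypothetical, and falling back to the last grid point when the level is never reached is an assumption of this sketch.

```python
import numpy as np


def invert_cdf_on_grid(grid, cdf_values, level):
    """First grid threshold whose calibrated CDF reaches `level`.

    grid: increasing thresholds; cdf_values: fitted CDF at each threshold.
    Returns the last grid point if the CDF never reaches `level`
    (an assumption of this sketch).
    """
    grid = np.asarray(grid, dtype=float)
    cdf_values = np.asarray(cdf_values, dtype=float)
    reached = np.flatnonzero(cdf_values >= level)
    return float(grid[reached[0]]) if reached.size else float(grid[-1])


grid = np.linspace(0.0, 1.0, 1001)     # like grid_mode="unit", grid_size=1001
cdf = grid ** 2                         # toy monotone CDF surrogate
calibrated = {q: invert_cdf_on_grid(grid, cdf, q) for q in (0.1, 0.5, 0.9)}
print(calibrated)
```

Because the toy CDF is x squared, each calibrated quantile lands near the square root of its level, up to the grid spacing of 0.001.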

This provides a practical post-hoc recalibration mechanism for continuous quantile forecasts in multi-horizon settings [1, 2].

Parameters:
  • df (pandas.DataFrame) – DataFrame containing the raw forecast quantile columns and the observed target column.

  • quantiles (sequence of float) – Nominal quantile levels to recalibrate, expressed on the unit interval, for example [0.1, 0.5, 0.9].

  • q_prefix (str) – Base prefix for the raw quantile columns. For example, if q_prefix="subsidence", the function expects columns such as "subsidence_q10", "subsidence_q50", and "subsidence_q90".

  • actual_col (str) – Name of the observed continuous target column.

  • method ({"isotonic", "logistic"}, default "isotonic") – Calibration method used to estimate the monotone CDF surrogate at each quantile level.

  • out_prefix (str, default "calib") – Prefix used to build the new calibrated quantile column names. The output columns follow the pattern "{out_prefix}_{q_prefix}_qXX".

  • grid_mode ({"unit", "range"}, default "unit") –

    Domain used to construct the inversion grid.

    • "unit" builds np.linspace(0, 1, grid_size). This is mainly appropriate when the forecast support is already normalized to the unit interval.

    • "range" builds a grid from the observed range of each raw quantile column and is typically the safer choice for forecasts on the original target scale.

  • grid_size (int, default 1001) – Number of points used in the inversion grid.

  • group_by (str or None, default None) – Optional grouping column. When provided, calibration is performed separately within each group, for example by forecast horizon.

  • savefile (str or None, default None) – Optional output path handled by the SaveFile() decorator.

Returns:

Copy of df with calibrated quantile columns appended.

Return type:

pandas.DataFrame

Notes

This function differs from calibrate_quantile_forecasts(). Here, each quantile is recalibrated through a classifier-plus-inversion approach. By contrast, calibrate_quantile_forecasts() applies multiplicative interval-width factors around a median-like forecast.

The two approaches answer different needs:

  • use calibrate_forecasts() when you want to recalibrate the quantile levels themselves through a threshold-CDF view;

  • use calibrate_quantile_forecasts() when your main issue is interval under- or over-dispersion and you want a simpler, auditable width correction.

When group_by is used, the function preserves the original row index order after concatenating the calibrated groups.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import calibrate_forecasts
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.8, 0.5, 1.0],
...         "subsidence_q10": [0.3, 0.6, 0.4, 0.8],
...         "subsidence_q50": [0.4, 0.8, 0.5, 1.0],
...         "subsidence_q90": [0.5, 1.0, 0.6, 1.2],
...     }
... )
>>> out = calibrate_forecasts(
...     df,
...     quantiles=[0.1, 0.5, 0.9],
...     q_prefix="subsidence",
...     actual_col="subsidence_actual",
...     method="isotonic",
...     grid_mode="range",
...     group_by="forecast_step",
... )
>>> "calib_subsidence_q50" in out.columns
True

See also

calibrate_probability_forecast

Calibrate event probabilities rather than continuous quantiles.

calibrate_quantile_forecasts

Wrapper for interval-width calibration on tabular forecasts.

fit_interval_factors_df

Learn empirical per-horizon interval scaling factors.

References

For interpretable multi-horizon quantile forecasting, see Lim et al. [1].

For uncertainty-rich forecasting in the GeoPrior ecosystem, see Kouadio et al. [2].

The calibration layer is especially relevant to uncertainty workflows spanning Stage-2 through Stage-5 and any evaluation path that compares interval quality before and after calibration.

Forecast helpers#

Forecast utilities.

geoprior.utils.forecast_utils.detect_forecast_type(df, value_prefixes=None)[source]

Auto-detects whether a DataFrame contains deterministic or quantile forecasts, supporting both long and wide formats.

This utility inspects column names to determine the nature of the predictions.

  • It identifies a 'quantile' forecast if it finds columns containing a _qXX pattern (e.g., 'subsidence_q10', 'GWL_2022_q50').

  • It identifies a 'deterministic' forecast if no quantile columns are found but columns ending in _pred or _actual, or matching a base prefix, exist (e.g., 'subsidence_pred', 'subsidence_2022_actual', 'GWL').
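The two detection rules above can be approximated with a regular expression. This is a hypothetical re-creation for illustration, not the detect_forecast_type source: it assumes "_q" followed by digits marks a quantile column, and it skips the automatic prefix inference, so a bare base name such as "GWL" counts only when it is passed in value_prefixes.

```python
import re

# Assumed rule: "_q" followed by digits anywhere in a column name,
# as in "subsidence_q10" or "GWL_2022_q50".
QUANTILE_PATTERN = re.compile(r"_q\d+")
DETERMINISTIC_SUFFIXES = ("_pred", "_actual")


def guess_forecast_type(columns, value_prefixes=None):
    """Hypothetical re-creation of the column-name rules described above."""
    cols = [str(c) for c in columns]
    if value_prefixes:
        cols = [c for c in cols if c.startswith(tuple(value_prefixes))]
    if any(QUANTILE_PATTERN.search(c) for c in cols):
        return "quantile"
    if any(c.endswith(DETERMINISTIC_SUFFIXES) for c in cols):
        return "deterministic"
    if value_prefixes and any(c in value_prefixes for c in cols):
        return "deterministic"          # bare base prefix such as "GWL"
    return "unknown"


print(guess_forecast_type(["subsidence_q50", "GWL_q90"]))
print(guess_forecast_type(["subsidence_pred", "GWL"]))
```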

Parameters:
  • df (pd.DataFrame) – The DataFrame to inspect.

  • value_prefixes (list of str, optional) – A list of value prefixes (e.g., ['subsidence', 'GWL']) to focus the search on. If None, prefixes are inferred from column names.

Returns:

One of 'quantile', 'deterministic', or 'unknown'.

Return type:

str

Examples

>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import detect_forecast_type
>>> # Long format quantile
>>> df_quant_long = pd.DataFrame(columns=['subsidence_q50', 'GWL_q90'])
>>> detect_forecast_type(df_quant_long)
'quantile'
>>> # Wide format quantile
>>> df_quant_wide = pd.DataFrame(columns=['subsidence_2022_q50'])
>>> detect_forecast_type(df_quant_wide)
'quantile'
>>> # Deterministic forecast
>>> df_determ = pd.DataFrame(columns=['subsidence_pred', 'GWL'])
>>> detect_forecast_type(df_determ)
'deterministic'

geoprior.utils.forecast_utils.format_forecast_dataframe(df, to_wide=True, time_col='coord_t', spatial_cols=('coord_x', 'coord_y'), value_prefixes=None, _logger=None, **pivot_kwargs)[source]

Auto-detects the DataFrame format and conditionally pivots it to wide format.

This function is a wrapper around pivot_forecast_dataframe(). It first determines whether the input DataFrame is in 'long' or 'wide' forecast format from its column structure. If to_wide is True and the format is 'long', it calls pivot_forecast_dataframe() to perform the transformation.

Parameters:
  • df (pd.DataFrame) – The input DataFrame to check and potentially transform.

  • to_wide (bool, default True) –

    • If True, the function’s goal is to return a wide-format DataFrame. It will pivot a long-format frame or return a wide-format frame as is.

    • If False, the function only performs detection and returns a string (‘wide’, ‘long’, or ‘unknown’).

  • time_col (str, default 'coord_t') – The name of the column that indicates the time step. Its presence is a primary indicator of a long-format DataFrame.

  • value_prefixes (list of str, optional) – A list of prefixes for the value columns (e.g., [‘subsidence’, ‘GWL’]). If None, the function will attempt to infer them from column names that do not match common ID columns.

  • **pivot_kwargs – Additional keyword arguments to pass down to the pivot_forecast_dataframe() function if it is called. Common arguments include id_vars, static_actuals_cols, verbose, etc.

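  • spatial_cols (tuple[str]) – Names of the spatial coordinate columns (default ('coord_x', 'coord_y')).

  • _logger (Logger | Callable[[str], None] | None) – Optional logger or callable used for internal messages.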

Returns:

  • If to_wide is True, returns the (potentially pivoted) wide-format pd.DataFrame.

  • If to_wide is False, returns a string: ‘wide’, ‘long’, or ‘unknown’.

Return type:

pd.DataFrame or str

See also

pivot_forecast_dataframe

The underlying function that performs the pivot operation.

Examples

>>> from geoprior.utils.forecast_utils import format_forecast_dataframe
>>> # df_long is a typical long-format forecast output
>>> df_long.columns
Index(['sample_idx', 'forecast_step', 'coord_t', 'coord_x', ...])
>>> # Detect the format only
>>> format_forecast_dataframe(df_long, to_wide=False)
'long'
>>>
>>> # Convert to wide format
>>> df_wide = format_forecast_dataframe(
...     df_long,
...     to_wide=True,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual']
... )
>>> # print(df_wide.columns)
# Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
#        'GWL_2018_q50', ...], dtype='object')
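The long/wide detection step can be sketched as a simple heuristic. The name `sketch_detect_layout` is hypothetical. The rule that a wide frame embeds a four-digit year token in its column names (as in 'GWL_2018_q50') is an assumption drawn from the examples on this page, not a confirmed detail of format_forecast_dataframe.

```python
import re
from typing import Iterable

# Assumed wide-format marker: a 4-digit time token inside a name,
# e.g. "subsidence_2018_q50".
_WIDE_RE = re.compile(r"_\d{4}(?:_|$)")


def sketch_detect_layout(columns: Iterable[str], time_col: str = "coord_t") -> str:
    """Hypothetical heuristic mirroring the description of format_forecast_dataframe."""
    cols = [str(c) for c in columns]
    if time_col in cols:
        # A dedicated time column means one row per (sample, time step).
        return "long"
    if any(_WIDE_RE.search(c) for c in cols):
        # Time is encoded in the column names instead.
        return "wide"
    return "unknown"
```

Checking `time_col` first follows the description above, where its presence is the primary indicator of a long-format frame.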
geoprior.utils.forecast_utils.get_value_prefixes(df, exclude_cols=None, spatial_cols=('coord_x', 'coord_y'), time_col='coord_t')

Automatically detects the prefixes of value columns from a DataFrame.

This utility inspects the column names to infer the base names of the metrics being forecasted (e.g., ‘subsidence’, ‘GWL’), excluding common ID and coordinate columns. It works with both long and wide format forecast DataFrames.

Parameters:
  • df (pd.DataFrame) – The DataFrame from which to detect value prefixes.

  • exclude_cols (list of str, optional) – A list of columns to explicitly ignore during detection. If None, a default list of common ID/coordinate columns is used (e.g., ‘sample_idx’, ‘coord_x’, ‘coord_t’, etc.).

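  • spatial_cols (tuple[str, str]) – Names of the spatial coordinate columns, which are not treated as value columns (default ('coord_x', 'coord_y')).

  • time_col (str) – Name of the time column, which is not treated as a value column (default 'coord_t').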

Returns:

A sorted list of unique prefixes found in the column names.

Return type:

list of str

Examples

>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import get_value_prefixes
>>> # For a long-format DataFrame
>>> long_cols = ['sample_idx', 'coord_t', 'subsidence_q50', 'GWL_q50']
>>> df_long = pd.DataFrame(columns=long_cols)
>>> get_value_prefixes(df_long)
['GWL', 'subsidence']
>>> # For a wide-format DataFrame
>>> wide_cols = ['sample_idx', 'coord_x', 'subsidence_2022_q90', 'GWL_2022_q50']
>>> df_wide = pd.DataFrame(columns=wide_cols)
>>> get_value_prefixes(df_wide)
['GWL', 'subsidence']
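The prefix inference can be sketched by stripping a trailing time token and a trailing forecast suffix from each non-excluded column. The name `sketch_value_prefixes`, the default exclusion set, and the suffix grammar (`_YYYY`, then `_qNN`, `_pred`, or `_actual`) are assumptions for illustration, not the library's exact rules.

```python
import re
from typing import Iterable, Optional

# Trailing parts to strip: an optional 4-digit year, then an optional
# quantile / pred / actual suffix. E.g. "GWL_2022_q50" -> "GWL".
_SUFFIX_RE = re.compile(r"(?:_\d{4})?(?:_q\d+|_pred|_actual)?$")

# Assumed default ID/coordinate columns to ignore.
_DEFAULT_EXCLUDE = {"sample_idx", "forecast_step", "coord_t", "coord_x", "coord_y"}


def sketch_value_prefixes(
    columns: Iterable[str], exclude_cols: Optional[Iterable[str]] = None
) -> list[str]:
    excluded = set(exclude_cols) if exclude_cols is not None else _DEFAULT_EXCLUDE
    prefixes = set()
    for col in map(str, columns):
        if col in excluded:
            continue
        # Remove the first (and only) trailing suffix match.
        base = _SUFFIX_RE.sub("", col, count=1)
        if base:
            prefixes.add(base)
    # Sorted output: uppercase names such as "GWL" sort before lowercase ones.
    return sorted(prefixes)
```

Because only the trailing tokens are stripped, a prefix that itself contains underscores (for example 'subsidence_cum_q50') keeps its full base name, 'subsidence_cum'.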
geoprior.utils.forecast_utils.get_value_prefixes_in(df, exclude_cols=None)

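Automatically detects the prefixes of value columns from a DataFrame. This is a lower-level helper that other functions in this module depend on.

Parameters:
  • df (pd.DataFrame) – The DataFrame from which to detect value prefixes.

  • exclude_cols (list of str, optional) – Columns to ignore during detection.

Return type:

list[str]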

geoprior.utils.forecast_utils.pivot_forecast_dataframe(data, id_vars, time_col, value_prefixes, static_actuals_cols=None, time_col_is_float_year='auto', round_time_col=False, verbose=0, savefile=None, _logger=None, **kws)

Transforms a long-format forecast DataFrame to a wide format.

This utility reshapes time series prediction data from a “long” format, where each row represents a single time step for a given sample, to a “wide” format, where each row represents a single sample and columns correspond to values at different time steps.

Parameters:
  • data (pd.DataFrame) – The input long-format DataFrame. It must contain the columns specified in id_vars and time_col, as well as value columns that start with the strings in value_prefixes.

  • id_vars (list of str) – A list of column names that uniquely identify each sample or group. These columns will be preserved in the wide-format output. For example: ['sample_idx', 'coord_x', 'coord_y'].

  • time_col (str) – The name of the column that represents the time step or year of the forecast (e.g., ‘coord_t’ or ‘forecast_step’). This column’s values will become part of the new column names.

  • value_prefixes (list of str) – A list of prefixes for the value columns that need to be pivoted. The function identifies columns starting with these prefixes. For instance, ['subsidence', 'GWL'] would match ‘subsidence_q10’, ‘GWL_q50’, etc.

  • static_actuals_cols (list of str, optional) – A list of columns containing static “actual” or ground truth values for each sample. These values are assumed to be constant for each unique sample_idx and are merged back into the wide DataFrame after pivoting. Example: ['subsidence_actual'].

  • time_col_is_float_year (bool or 'auto', default 'auto') –

    Controls how the time_col values are formatted into new column names.

    • If 'auto', automatically detects whether time_col has a float dtype.

    • If True, treats time_col values (e.g., 2018.0) as years and converts them to integer strings (‘2018’).

    • If False, uses the string representation of the value as is.

  • round_time_col (bool, default False) – If True and time_col is a float type, its values will be rounded to the nearest integer before being used in column names. This is useful for cleaning up float years (e.g., 2018.0001 -> 2018).

  • verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent. Higher values print more details about the process.

  • savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file to that location.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

A wide-format DataFrame with one row per unique combination of id_vars. New columns are created in the format {prefix}_{time_str}{_suffix} (e.g., ‘subsidence_2018_q10’).

Return type:

pd.DataFrame

See also

pandas.pivot_table

The core function used for reshaping data.

pandas.merge

Used to re-join static columns after pivoting.

Notes

  • The combination of columns in id_vars and time_col must uniquely identify each row in data for the pivot to succeed without data loss.

  • If using static_actuals_cols, the id_vars list must contain ‘sample_idx’ to correctly merge the static data back.

Examples

>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import pivot_forecast_dataframe
>>> data = {
...     'sample_idx':      [0, 0, 1, 1],
...     'coord_t':         [2018.0, 2019.0, 2018.0, 2019.0],
...     'coord_x':         [0.1, 0.1, 0.5, 0.5],
...     'coord_y':         [0.2, 0.2, 0.6, 0.6],
...     'subsidence_q50':  [-8, -9, -13, -14],
...     'subsidence_actual': [-8.5, -8.5, -13.2, -13.2],
...     'GWL_q50':         [1.2, 1.3, 2.2, 2.3],
... }
>>> df_long_example = pd.DataFrame(data)
>>> df_wide = pivot_forecast_dataframe(
...     data=df_long_example,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     time_col='coord_t',
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual'],
...     verbose=0
... )
>>> print(df_wide.columns)
Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
       'GWL_2018_q50', 'GWL_2019_q50', 'subsidence_2018_q50',
       'subsidence_2019_q50'],
      dtype='object')
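The reshaping performed by pivot_forecast_dataframe can be approximated with plain pandas. The sketch below is a hypothetical illustration (the name `sketch_pivot_long_to_wide` is not part of the library). It assumes that every value column has the form `<prefix>_<suffix>` (for example 'subsidence_q50') and that the time column holds float years that can be rounded to integers. The library's column ordering and edge-case handling may differ.

```python
import pandas as pd


def sketch_pivot_long_to_wide(data, id_vars, time_col, value_cols, static_cols=None):
    """Hypothetical pandas-only sketch of the long -> wide reshaping."""
    df = data.copy()
    # Render float years such as 2018.0 as the string "2018".
    df["_time_str"] = df[time_col].round().astype(int).astype(str)
    wide = df.pivot_table(
        index=id_vars, columns="_time_str", values=list(value_cols), aggfunc="first"
    )
    # Columns are (value_col, time_str) pairs; splice the time into the name,
    # e.g. ("subsidence_q50", "2018") -> "subsidence_2018_q50".
    new_names = []
    for value_col, time_str in wide.columns:
        prefix, _, suffix = value_col.rpartition("_")
        new_names.append(f"{prefix}_{time_str}_{suffix}")
    wide.columns = new_names
    wide = wide.reset_index()
    if static_cols:
        # Static actuals are constant per sample, so keep one row per sample_idx.
        statics = data[["sample_idx", *static_cols]].drop_duplicates("sample_idx")
        wide = wide.merge(statics, on="sample_idx", how="left")
    return wide


long_df = pd.DataFrame({
    "sample_idx": [0, 0, 1, 1],
    "coord_t": [2018.0, 2019.0, 2018.0, 2019.0],
    "coord_x": [0.1, 0.1, 0.5, 0.5],
    "coord_y": [0.2, 0.2, 0.6, 0.6],
    "subsidence_q50": [-8, -9, -13, -14],
    "subsidence_actual": [-8.5, -8.5, -13.2, -13.2],
    "GWL_q50": [1.2, 1.3, 2.2, 2.3],
})
wide = sketch_pivot_long_to_wide(
    long_df, ["sample_idx", "coord_x", "coord_y"], "coord_t",
    ["subsidence_q50", "GWL_q50"], static_cols=["subsidence_actual"],
)
```

Using `pivot_table` with `aggfunc="first"` keeps the first value if id_vars and time_col do not uniquely identify a row, which is exactly the data-loss case the Notes above warn about.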
geoprior.utils.forecast_utils.get_step_names(forecast_steps, step_names=None, default_name='')

Build a step → label mapping for multi‑horizon plots.

The helper reconciles an integer list forecast_steps with an optional alias container (dict or sequence) and returns a dictionary whose keys are the integer steps and whose values are human‑readable labels.

Matching is case-insensitive and tolerant to common delimiters, e.g. "Step 1", "step-1", or "forecast step 1" will all map to integer step 1.

Parameters:
  • forecast_steps (Iterable[int]) – Ordered steps, e.g. [1, 2, 3].

  • step_names (dict | list | tuple | None, default None) –

    Custom labels. Accepted forms

    • dict – keys may be int or any string representation of the step.

    • sequence – positional, where the element at index k (0-based) labels step k+1.

    • None – no custom mapping.

  • default_name (str, default "") – Fallback label for steps missing from step_names. If empty, the step number itself is used (as a string).

Returns:

Mapping {step : label} for every element of forecast_steps.

Return type:

dict[int, str]

Notes

  • Dictionary keys are normalised with int(re.sub(r"[^0-9]", "", str(key))) before matching.

  • Duplicate keys in step_names are resolved with last-one-wins semantics.

Examples

>>> from geoprior.utils.forecast_utils import get_step_names
>>> get_step_names(
...     forecast_steps=[1, 2, 3],
...     step_names={"1": "Year 2021", 2: "2022", "step 3": "2023"},
... )
{1: 'Year 2021', 2: '2022', 3: '2023'}
>>> get_step_names(
...     forecast_steps=[1, 2, 3, 4],
...     step_names={"1": "2021", "2": "2022"},
... )
{1: '2021', 2: '2022', 3: '3', 4: '4'}
>>> get_step_names(
...     [1, 2, 3, 4],
...     step_names=None,
...     default_name="step with no name",
... )
{1: 'step with no name', 2: 'step with no name',
 3: 'step with no name', 4: 'step with no name'}
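The normalization rule in the Notes (keep only the digits of each key, let later keys win) can be sketched as follows. The name `sketch_step_names` is hypothetical, and this is a re-derivation from the documented behaviour rather than the library's source.

```python
import re
from collections.abc import Mapping, Sequence
from typing import Iterable, Optional, Union


def sketch_step_names(
    forecast_steps: Iterable[int],
    step_names: Optional[Union[Mapping, Sequence[str]]] = None,
    default_name: str = "",
) -> dict[int, str]:
    aliases: dict[int, str] = {}
    if isinstance(step_names, Mapping):
        for key, label in step_names.items():
            # Keep only the digits of the key: "step 3" -> 3, "Step-1" -> 1.
            digits = re.sub(r"[^0-9]", "", str(key))
            if digits:
                aliases[int(digits)] = str(label)  # later keys overwrite earlier ones
    elif isinstance(step_names, (list, tuple)):
        # Positional: the element at index k labels step k + 1.
        aliases = {k + 1: str(label) for k, label in enumerate(step_names)}
    # Fall back to default_name, or to the step number itself when it is empty.
    return {s: aliases.get(s, default_name or str(s)) for s in forecast_steps}
```

Stripping non-digits makes "Step 1", "step-1", and 1 all collide on the same integer key, which is why last-one-wins resolution is needed.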
geoprior.utils.forecast_utils.stack_quantile_predictions(q_lower, q_median, q_upper)

Stack three quantile trajectories into a single y_pred array of shape (n_samples, 3, n_timesteps), ready for PSS.

Parameters:
  • q_lower (array-like) – Lower-quantile trajectories, either 1D with shape (n_timesteps,), interpreted as a single sample, or 2D with shape (n_samples, n_timesteps).

  • q_median (array-like) – Median trajectories, with the same accepted shapes as q_lower.

  • q_upper (array-like) – Upper-quantile trajectories, with the same accepted shapes as q_lower.

Returns:

y_pred – Stacked predictions, where axis=1 indexes [lower, median, upper].

Return type:

np.ndarray, shape (n_samples, 3, n_timesteps)

Raises:

ValueError – If the three inputs (after promotion) do not share the same shape.
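The promotion and stacking described above can be sketched with NumPy. The name `sketch_stack_quantiles` is hypothetical. The sketch follows the documented contract: 1D inputs become a single sample, the shapes must match, and the output is (n_samples, 3, n_timesteps). It is not the library's implementation.

```python
import numpy as np


def sketch_stack_quantiles(q_lower, q_median, q_upper):
    """Stack lower/median/upper trajectories into (n_samples, 3, n_timesteps)."""
    arrays = []
    for q in (q_lower, q_median, q_upper):
        a = np.asarray(q, dtype=float)
        if a.ndim == 1:
            # A single trajectory is promoted to one sample: (T,) -> (1, T).
            a = a[np.newaxis, :]
        elif a.ndim != 2:
            raise ValueError(f"expected 1D or 2D input, got ndim={a.ndim}")
        arrays.append(a)
    if len({a.shape for a in arrays}) != 1:
        raise ValueError(f"shape mismatch: {[a.shape for a in arrays]}")
    # Stacking along axis=1 puts [lower, median, upper] on the middle axis.
    return np.stack(arrays, axis=1)


y_pred = sketch_stack_quantiles([1.0, 2.0, 3.0], [1.5, 2.5, 3.5], [2.0, 3.0, 4.0])
```

With these inputs, `y_pred[0, 1]` is the median trajectory of the single sample.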

geoprior.utils.forecast_utils.adjust_time_predictions(df, time_col, forecast_horizon, coord_scaler=None, inverse_transformed=False, verbose=1)

Shifts time predictions forward by adding the forecast horizon to the inverse-normalized time values. If the time column has already been inverse-transformed, the inverse transformation is skipped.

Parameters:
  • df (pd.DataFrame) – The DataFrame containing the time predictions. The time column specified by time_col holds the values to adjust, either still scaled (see coord_scaler) or already inverse-transformed (see inverse_transformed).

  • time_col (str) – The name of the time column in the DataFrame. This column will be adjusted by adding the forecast horizon.

  • forecast_horizon (int) – The forecast horizon (e.g., number of years or time steps) that will be added to the time predictions. This value shifts the time predictions forward.

  • coord_scaler (MinMaxScaler, optional) – The scaler that was used for the coordinates. It is necessary to reverse the scaling for the time column if it was previously normalized. If not provided, the time column should already be inverse-transformed.

  • inverse_transformed (bool, default False) – If True, skips the inverse transformation of the time column and directly adds the forecast horizon. This is useful when the time column has already been inverse-transformed, and you only need to adjust the time by the forecast horizon.

  • verbose (int, default 1) – Verbosity level for logging. Higher values (e.g., verbose=2) provide more detailed information about the operation.

Returns:

The adjusted DataFrame with the time column updated to reflect the forecast horizon. The time predictions are adjusted by adding the forecast_horizon to each entry in the time column.

Return type:

pd.DataFrame

Raises:

ValueError – If the time column is not found in the DataFrame or if the scaler is not available when necessary.

Examples

>>> import pandas as pd
>>> from sklearn.preprocessing import MinMaxScaler
>>> from geoprior.utils.forecast_utils import adjust_time_predictions
>>> # Sample data for illustration
>>> df = pd.DataFrame({
...     'year': [0.0, 0.5, 1.0],
...     'subsidence': [0.1, 0.2, 0.3]
... })
>>> scaler = MinMaxScaler()
>>> df_scaled = df.copy()
>>> df_scaled['year'] = scaler.fit_transform(df_scaled[['year']])
>>> adjusted_df = adjust_time_predictions(
...     df_scaled,
...     time_col='year',
...     forecast_horizon=4,
...     coord_scaler=scaler,
...     inverse_transformed=False,
...     verbose=2
... )
>>> # Per the description above, the inverse-transformed years
>>> # [0.0, 0.5, 1.0] are shifted forward by 4, giving [4.0, 4.5, 5.0].

Notes

  • The time column must be in a normalized scale if not already inverse-transformed.

  • If inverse_transformed=True, the time values will directly be adjusted by the forecast_horizon without applying the inverse transformation.

  • The forecast horizon is added directly to the time values after the necessary inverse transformation (if applicable).

See also

sklearn.preprocessing.MinMaxScaler

Scales features to [0,1].
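The inverse-transform-then-shift logic can be sketched without scikit-learn. `SimpleMinMax` is a hypothetical stand-in that exposes the same `fit_transform` / `inverse_transform` interface as MinMaxScaler, and `sketch_shift_time` is a hypothetical name. The real function may handle multi-column coordinate scalers differently.

```python
import numpy as np
import pandas as pd


class SimpleMinMax:
    """Hypothetical stand-in with the fit_transform / inverse_transform interface."""

    def fit_transform(self, X):
        X = np.asarray(X, dtype=float)
        self.min_ = X.min(axis=0)
        self.range_ = X.max(axis=0) - self.min_
        return (X - self.min_) / self.range_

    def inverse_transform(self, X):
        return np.asarray(X, dtype=float) * self.range_ + self.min_


def sketch_shift_time(df, time_col, forecast_horizon, coord_scaler=None,
                      inverse_transformed=False):
    if time_col not in df.columns:
        raise ValueError(f"time column {time_col!r} not found")
    out = df.copy()
    if not inverse_transformed:
        if coord_scaler is None:
            raise ValueError("coord_scaler is required to undo the scaling")
        # Undo the scaling on the single time column, then flatten back to 1D.
        out[time_col] = coord_scaler.inverse_transform(out[[time_col]]).ravel()
    # Shift every time value forward by the horizon.
    out[time_col] = out[time_col] + forecast_horizon
    return out


scaler = SimpleMinMax()
df = pd.DataFrame({"year": [2018.0, 2019.0, 2020.0], "subsidence": [0.1, 0.2, 0.3]})
df_scaled = df.assign(year=scaler.fit_transform(df[["year"]]).ravel())
shifted = sketch_shift_time(df_scaled, "year", forecast_horizon=4, coord_scaler=scaler)
```

Here the scaled years [0.0, 0.5, 1.0] map back to 2018 to 2020 and are then shifted to 2022 to 2024.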

geoprior.utils.forecast_utils.add_forecast_times(df, *, forecast_times=None, start=None, freq='YS', step_col='forecast_step', time_col='coord_t', error='raise', inplace=False, savefile=None, verbose=0)

Map each 1‑based forecast_step into an explicit calendar time.

You may either:
  1. Pass forecast_times of length H (one per step), or

  2. Pass a single start plus a pandas‐style freq to generate H dates.

If any entry in forecast_times is an integer of exactly 4 digits, it will be interpreted as January 1 of that year.

Parameters:
  • df (pd.DataFrame) – Long‐format forecast table. Must contain an integer column step_col with values 1..H.

  • forecast_times (sequence, optional) – Explicit sequence of length H specifying the target times. Each entry may be an int (interpreted as January 1 of that year) or a str, pd.Timestamp, or datetime.date.

  • start (int or str or date or Timestamp, optional) – Only used if forecast_times is None. The first time in the sequence; subsequent times will be generated via pd.date_range. If int, treated as a year at Jan 1.

  • freq (str, default "YS") – Pandas offset alias for frequency (e.g. “YS”=year start, “MS”=month start, “D”=day, etc.). Only used when start is set.

  • step_col (str, default "forecast_step") – Name of the 1‑based step index in df.

  • time_col (str, default "coord_t") – Name of the new column to create with mapped times.

  • error ({'raise','warn','ignore'}, default 'raise') –

    Policy if df[step_col].max() exceeds the number of provided times:

    • 'raise': raise a ValueError.

    • 'warn': issue a warning, then map what is available (truncate).

    • 'ignore': silently truncate to the available times.

  • inplace (bool, default False) – If True, modify df in place; otherwise return a new DataFrame.

  • savefile (str, optional) – If provided, path to CSV where the resulting DataFrame will be saved.

  • verbose (int, default 0) – Passed to vlog for debug logging.

Returns:

DataFrame with an added column time_col of dtype datetime64.

Return type:

pd.DataFrame

Raises:

ValueError – If neither forecast_times nor start is provided, or if error=’raise’ and there aren’t enough times.

Examples

>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import add_forecast_times
>>> df = pd.DataFrame({
...     "sample_idx": [0]*3 + [1]*3,
...     "forecast_step": [1,2,3]*2
... })
>>> add_forecast_times(df,
...     forecast_times=[2022,2023,2024])
   sample_idx  forecast_step     coord_t
0           0              1  2022-01-01
1           0              2  2023-01-01
2           0              3  2024-01-01
3           1              1  2022-01-01
4           1              2  2023-01-01
5           1              3  2024-01-01
>>> # Or generate from a start + yearly freq:
>>> add_forecast_times(df, start="2022-06-15", freq="YS")
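The step-to-time mapping can be sketched in a few lines of pandas. `sketch_add_forecast_times` is a hypothetical name, and the sketch implements only the 'raise' error policy. As described above, four-digit integers are read as January 1 of that year, and a `start` value is expanded with `pd.date_range`.

```python
import pandas as pd


def _to_timestamp(value):
    # A four-digit int is read as January 1 of that year.
    if isinstance(value, int) and 1000 <= value <= 9999:
        return pd.Timestamp(year=value, month=1, day=1)
    return pd.Timestamp(value)


def sketch_add_forecast_times(df, forecast_times=None, start=None, freq="YS",
                              step_col="forecast_step", time_col="coord_t"):
    n_steps = int(df[step_col].max())
    if forecast_times is not None:
        times = [_to_timestamp(t) for t in forecast_times]
    elif start is not None:
        times = list(pd.date_range(start=_to_timestamp(start), periods=n_steps, freq=freq))
    else:
        raise ValueError("provide forecast_times or start")
    if n_steps > len(times):
        # Only the 'raise' policy is sketched here.
        raise ValueError(f"{n_steps} steps but only {len(times)} times")
    out = df.copy()
    # Step k (1-based) maps to times[k - 1].
    step_to_time = {k + 1: t for k, t in enumerate(times)}
    out[time_col] = pd.to_datetime(out[step_col].map(step_to_time))
    return out


df = pd.DataFrame({"sample_idx": [0] * 3 + [1] * 3, "forecast_step": [1, 2, 3] * 2})
by_years = sketch_add_forecast_times(df, forecast_times=[2022, 2023, 2024])
by_start = sketch_add_forecast_times(df, start=2022, freq="YS")
```

Both calls produce January 1 of 2022, 2023, and 2024 for steps 1 to 3 of each sample.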
geoprior.utils.forecast_utils.pivot_forecast(df, *, index_col='sample_idx', pivot_col=None, step_col='forecast_step', time_col='coord_t', value_cols=None, spatial_cols=None, aggfunc='first', fill_value=nan, sep='_', time_formatter=<function <lambda>>, inplace=False, savefile=None, verbose=0)

Pivot a long-format forecast DataFrame into a wide one.

Rows are identified by index_col together with step_col (or the datetime time_col). Each forecast step or time is spread into its own set of columns for every entry in value_cols, and spatial_cols are then re-attached.

Parameters:
  • df (DataFrame) – Long-format forecasts. Must include index_col and at least one of step_col or time_col.

  • index_col (str) – Column that identifies each sample (e.g. “sample_idx”).

  • pivot_col (str | None) – If provided, pivot on this column instead of auto-detecting. Must be either step_col or time_col.

  • step_col (str) – Name of the integer 1‑based forecast step column.

  • time_col (str) – Name of the datetime column (e.g. “coord_t”).

  • value_cols (str | Sequence[str] | None) – Which forecast columns to pivot (e.g. ‘subsidence_q50’ or [‘subsidence_q10’, ‘subsidence_q50’, ‘subsidence_q90’]). If None, all numeric columns except the index, pivot, and spatial columns are selected automatically.

  • spatial_cols (Sequence[str] | None) – List of columns holding static spatial info (e.g. [‘longitude’, ‘latitude’]) to join back once pivoted.

  • aggfunc (str | Callable) – Aggregation function for pivot (default “first”).

  • fill_value (Any) – What to put where a sample/step is missing (default NaN).

  • sep (str) – Separator between value name and step/time in the new column names (default “_”).

  • time_formatter (Callable[[Any], str]) – How to turn a datetime/timestamp into a string for column names (default “%Y-%m-%d”).

  • inplace (bool) – If True, modifies df instead of copying.

  • savefile (str | None) – If given, writes the resulting wide DataFrame to CSV at this path.

  • verbose (int) – Passed to vlog for logging.

Returns:

Wide-format DataFrame with one row per index_col and columns like <value><sep><step> or <value><sep><formatted time>.

Return type:

pd.DataFrame

Example

>>> from geoprior.utils.forecast_utils import pivot_forecast
>>> # df_ is a long-format forecast with a datetime 'coord_t' column
>>> dff = pivot_forecast(
...    df_,
...    index_col="sample_idx",
...    pivot_col="coord_t",                # ← force pivot on the datetime
...    value_cols=["subsidence_q10","subsidence_q50","subsidence_q90"],
...    spatial_cols=["longitude","latitude"],
...    sep="_",                            # you’ll get subsidence_q50_2022 etc.
...    time_formatter=lambda t: f"{t.year}",
...    verbose=1
... )
[INFO] Pivoting on 'coord_t' for values ['subsidence_q10', 'subsidence_q50', 'subsidence_q90']
[INFO] Joining back spatial cols ['longitude', 'latitude']
>>> dff.columns
Index(['sample_idx', 'subsidence_q10_2022', 'subsidence_q10_2023',
       'subsidence_q10_2024', 'subsidence_q50_2022', 'subsidence_q50_2023',
       'subsidence_q50_2024', 'subsidence_q90_2022', 'subsidence_q90_2023',
       'subsidence_q90_2024', 'longitude', 'latitude'],
      dtype='object')
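The naming scheme `<value><sep><formatted time>` can be reproduced with `pivot_table` plus a column-flattening step. `sketch_pivot_forecast` is a hypothetical name, and this sketch always pivots on an explicit `pivot_col` (no auto-detection, fill_value, or savefile handling).

```python
import pandas as pd


def sketch_pivot_forecast(df, index_col, pivot_col, value_cols, spatial_cols=(),
                          sep="_", time_formatter=lambda t: t.strftime("%Y-%m-%d")):
    wide = df.pivot_table(index=index_col, columns=pivot_col,
                          values=list(value_cols), aggfunc="first")
    # (value, time) pairs become "<value><sep><formatted time>".
    wide.columns = [f"{value}{sep}{time_formatter(t)}" for value, t in wide.columns]
    wide = wide.reset_index()
    if spatial_cols:
        # Spatial info is static per sample: take one row per index_col and join back.
        spatial = df[[index_col, *spatial_cols]].drop_duplicates(index_col)
        wide = wide.merge(spatial, on=index_col, how="left")
    return wide


df_long = pd.DataFrame({
    "sample_idx": [0, 0, 0, 1, 1, 1],
    "coord_t": pd.to_datetime(["2022-01-01", "2023-01-01", "2024-01-01"] * 2),
    "subsidence_q50": [-1.0, -2.0, -3.0, -4.0, -5.0, -6.0],
    "longitude": [113.1] * 3 + [113.2] * 3,
    "latitude": [22.5] * 3 + [22.6] * 3,
})
wide = sketch_pivot_forecast(df_long, "sample_idx", "coord_t", ["subsidence_q50"],
                             spatial_cols=["longitude", "latitude"],
                             time_formatter=lambda t: f"{t.year}")
```

Passing `time_formatter=lambda t: f"{t.year}"` yields year-only suffixes such as 'subsidence_q50_2022', matching the example above.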
geoprior.utils.forecast_utils.plot_reliability_diagram(models_data, y_true=None, prefix='subsidence', figsize=(8, 8), title='Reliability Diagram', plot_style='seaborn-whitegrid', verbose=None, _logger=None)

Plot a reliability diagram for one or multiple models.

Parameters:
  • models_data (dict) – Mapping of model names to forecast data. Each value can be a pandas.DataFrame or a nested dict with keys ‘forecasts’, ‘color’, ‘marker’, and ‘style’.

  • y_true (pandas.Series, optional) – Observed values for empirical coverage calculations. Required when forecasts need processing.

  • prefix (str, default 'subsidence') – Column prefix for quantile forecast fields.

  • figsize (tuple of int, default (8, 8)) – Figure size (width, height) in inches.

  • title (str, default 'Reliability Diagram') – Text title displayed at the top of the plot.

  • plot_style (str, default 'seaborn-whitegrid') – Matplotlib style sheet name to apply.

  • verbose (int, optional) – Verbosity level passed to geoprior.utils.generic_utils.vlog.

  • _logger (Logger or callable, optional) – Function or logger instance for internal messages.

Returns:

Displays the calibration plot and returns nothing.

Return type:

None

Notes

This function draws a diagonal baseline (perfect calibration) and computes empirical coverage for probabilistic intervals using specified quantiles. It wraps simple DataFrame inputs into the required nested format and uses vlog for conditional logging.

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.utils.forecast_utils import plot_reliability_diagram
>>> # Create dummy true time series
>>> dates = pd.date_range('2020-01-01', periods=100)
>>> y_true = pd.Series(
...     np.random.randn(100), index=dates
... )
>>> # Create forecasts for ModelA
>>> dfA = pd.DataFrame({
...     'subsidence_q10': y_true - 0.5,
...     'subsidence_q90': y_true + 0.5
... }, index=dates)
>>> # Simple usage with one model
>>> plot_reliability_diagram(
...     models_data={'ModelA': dfA},
...     y_true=y_true,
...     verbose=2
... )
>>> # Create forecasts for ModelB with custom styling
>>> dfB = pd.DataFrame({
...     'subsidence_q10': y_true - 1.0,
...     'subsidence_q90': y_true + 1.0
... }, index=dates)
>>> custom_logger = print
>>> # Custom styling and logger
>>> plot_reliability_diagram(
...     models_data={
...         'ModelA': {
...             'forecasts': dfA,
...             'color': 'C0',
...             'marker': 'x'
...         },
...         'ModelB': {
...             'forecasts': dfB,
...             'color': 'C1',
...             'marker': 'o'
...         }
...     },
...     y_true=y_true,
...     verbose=4,
...     _logger=custom_logger
... )
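A reliability diagram compares the empirical coverage of a prediction interval with its nominal coverage. The helper below is a generic sketch of that quantity. Its name, `empirical_coverage`, is hypothetical, and it is not necessarily how plot_reliability_diagram computes coverage internally.

```python
import numpy as np


def empirical_coverage(y_true, q_low, q_high):
    """Fraction of observations falling inside [q_low, q_high]."""
    y = np.asarray(y_true, dtype=float)
    lo = np.asarray(q_low, dtype=float)
    hi = np.asarray(q_high, dtype=float)
    inside = (y >= lo) & (y <= hi)
    return float(inside.mean())


# A q10-q90 interval has a nominal coverage of 0.9 - 0.1 = 0.8;
# a well-calibrated model should land close to that value.
y = np.array([0.0, 1.0, 2.0, 3.0])
cov_wide = empirical_coverage(y, y - 0.5, y + 0.5)            # every point inside
cov_half = empirical_coverage(y, [0, 0, 5, 5], [1, 1, 6, 6])  # first two inside
```

Points above the diagonal of the diagram correspond to intervals that are wider than necessary (empirical coverage above nominal). Points below it indicate overconfident intervals.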
geoprior.utils.forecast_utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)

Format PINN forecasts into evaluation and future DataFrames.

This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:

  • df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).

  • df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.

Parameters:
  • y_pred (dict) –

    Dictionary of model predictions, as returned by GeoPriorSubsNet.predict and post-processed into {'subs_pred': ..., 'gwl_pred': ...}.

    For subsidence, the expected shapes are:

    • Quantile mode: (B, H, Q, O) where: B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.

    • Point mode: (B, H, O).

  • y_true (dict or None) –

    Dictionary of true targets, typically

    {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}.

    If None, the evaluation DataFrame is still created, but without the actual-value column.

  • coords (ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped (B, H, 3) with columns [t_scaled, x_scaled, y_scaled]. Only x and y are used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.

  • quantiles (list of float or None, optional) – List of quantiles (e.g. [0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. If None, a single prediction column is emitted instead.

  • target_name (str, default 'subsidence') –

    Logical target identifier used as the default key for locating the target scaler in scaler_info and as a fallback for resolving truth arrays in y_true.

    Column naming is controlled by output_target_name (or the auto-derived output prefix when it is None).

  • output_target_name (str or None, optional) –

    Output prefix used when creating DataFrame columns for predictions and actuals.

    This controls the column naming only (e.g. the function will emit f"{output_target_name}_q10", f"{output_target_name}_pred", and f"{output_target_name}_actual").

    If None (default), the function derives the output prefix from target_name and applies a small convenience rule: if target_name ends with "_cum" or "_cumulative", that suffix is stripped for output naming.

    This keeps downstream tooling consistent (many plotting and metrics utilities expect names like subsidence_q10 rather than subsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, with target_name="subsidence_cum" and output_target_name=None, output columns become subsidence_q10, subsidence_q50, and subsidence_actual. If output_target_name="subsidence_cum", the output columns keep the suffix such as subsidence_cum_q10.

  • scaler_target_name (str or None, optional) –

    Name used to locate the target scaling block inside scaler_info and to perform inverse-transform for predictions and actuals.

    This controls the scaler key and inverse scaling, not the output column naming.

    If None (default), the scaler key is assumed to be target_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.

    A common pattern is to keep target_name="subsidence_cum" so the scaler lookup matches the Stage-1 schema, while letting output_target_name=None produce clean output columns. In that setup, inverse transform still uses the subsidence_cum scaler key, while output columns use the subsidence_ prefix because of the auto-strip rule.

  • target_key_pred (str, default 'subs_pred') – Key inside y_pred that holds the subsidence forecasts.

  • component_index (int, default 0) – Index along the output dimension O to use when output_subsidence_dim > 1. For scalar subsidence this is 0.

  • scaler_info (dict, optional) – Optional Stage-1 scaler_info mapping containing a target scaler under keys such as 'targets' or 'target'. The target block is expected to provide an sklearn-like transformer under 'scaler' together with column names under 'columns' or 'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed for target_name.

  • coord_scaler (object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transform coord_x and coord_y when coords is given and coord_columns can be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.

  • coord_columns (tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.

  • train_end_time (scalar or str or datetime, optional) – Physical time associated with the evaluation year (e.g. 2022). If eval_forecast_step is not given, the last horizon step is assumed to correspond to this time.

  • forecast_start_time (scalar or str or datetime, optional) – First time in the future forecast horizon (e.g. 2023).

  • forecast_horizon (int, optional) – Number of forecast steps in the future horizon (e.g. 3). If future_time_grid is not given, this is used together with forecast_start_time to build a regular grid.

  • future_time_grid (array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be [2023, 2024, 2025]. If provided, it overrides any automatic construction from forecast_start_time and forecast_horizon.

  • eval_forecast_step (int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.

  • eval_export ({"all", "last"} or str or int or sequence, optional) –

    Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).

    Accepted values are:

    • "all" or "full" or "horizons" : export all horizons from df_eval_all.

    • "last" or "single" or "default" : export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).

    • Other str (e.g. "2022") : interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.

    • int or scalar non-string : interpreted as a single time value (e.g. 2022).

    • sequence of values (e.g. [2021, 2022]) : interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.

    If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.

  • value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) –

    Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.

    Supported modes are:

    • "rate" : keep per-step predictions as provided by the model (current behaviour).

    • "cumulative" or "cum" : convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.

    • "absolute_cumulative" or "abs_cum" or "absolute" : same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.

    Cumulative transforms are applied consistently to:

    • the future forecast DataFrame (df_future),

    • the multi-horizon evaluation DataFrame (df_eval_all),

    • and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).

    When an unsupported string is given, the function logs a warning and falls back to "rate".

  • absolute_baseline (float or Mapping[int, float], optional) –

    Baseline value to use when value_mode requests absolute cumulative outputs ("absolute_cumulative", "abs_cum", "absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence at train_end_time (e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.

    If a scalar float is provided, the same baseline value is added to all samples. If a mapping is provided, it must map sample_idx (integers) to baseline values, allowing per-sample baselines:

    • absolute_baseline = {sample_idx: baseline_value, ...}

    Only prediction columns for target_name are shifted (e.g. "subsidence_q10", "subsidence_q50", "subsidence_q90" or "subsidence_pred"). When df_eval_all is present, the corresponding "<target_name>_actual" column is shifted as well, so evaluation metrics operate on absolute cumulative values.

    If value_mode is an absolute cumulative variant but absolute_baseline is None, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).

  • sample_index_offset (int, default 0) – Offset added to sample_idx (useful when concatenating multiple tiles).

  • city_name (str, optional) – Optional metadata used only for logging.

  • model_name (str, optional) – Optional metadata used only for logging.

  • dataset_name (str, optional) – Optional metadata used only for logging.

  • csv_eval_path (str, optional) – If provided, df_eval is written to this path (directories are created if needed).

  • csv_future_path (str, optional) – If provided, df_future is written to this path.

  • time_as_datetime (bool, default False) – If True, time values are converted using pandas.to_datetime() with the provided time_format (if any).

  • time_format (str or None, optional) – Optional format string passed to pandas.to_datetime() when time_as_datetime=True.

  • eval_metrics (bool, default False) – If True, automatically call evaluate_forecast() on the resulting df_eval to compute diagnostics. Metrics are not returned by this function; they are either written to disk (if metrics_savefile is provided) or discarded. For programmatic access to the metrics dictionary, call evaluate_forecast() directly.

  • metrics_column_map (mapping, optional) – Optional column mapping forwarded to evaluate_forecast() (see its documentation for details). If None, default column names such as 'coord_t', 'forecast_step', f'{target_name}_q10', and f'{target_name}_actual' are assumed.

  • metrics_quantile_interval (tuple of float, default (0.1, 0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded to evaluate_forecast().

  • metrics_per_horizon (bool, default False) – If True, per-horizon MAE/MSE/R² are computed by evaluate_forecast() and included in the diagnostics.

  • metrics_extra (sequence or mapping, optional) –

    Optional additional metrics to compute, forwarded to evaluate_forecast(). Can be:

    • A sequence of metric names (resolved via geoprior.metrics._registry.get_metric).

    • A mapping {name: func} where func is a callable taking (y_true, y_pred, **kwargs).

  • metrics_extra_kwargs (mapping, optional) – Optional per-metric keyword arguments, forwarded to evaluate_forecast(). Keys must match metric names in metrics_extra.

  • metrics_savefile (str, path-like, bool, or None) – If truthy, diagnostics from evaluate_forecast() are written to disk. Behavior matches the savefile argument of evaluate_forecast(). When True, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.

  • metrics_save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format for diagnostics written by evaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.

  • metrics_time_as_str (bool, default True) – If True, time keys in the diagnostics written by evaluate_forecast() are converted to strings (useful for JSON serialization).

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Logger instance; if None, a module-level LOG is used.

  • input_value_mode (str)

  • rate_first (str)

  • calibration (str | bool)

  • calibration_kwargs (Mapping[str, Any] | None)

  • calibration_save_stats (str | PathLike | None)

  • output_unit (str | None)

  • output_unit_from (str)

  • output_unit_mode (str)

  • output_unit_suffix (str)

  • output_unit_col (str | None)

Returns:

  • df_eval_to_write (pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time. Columns include:

    • 'sample_idx'

    • 'forecast_step'

    • quantile columns (e.g. subsidence_q10) or subsidence_pred

    • 'subsidence_actual' (if y_true given)

    • coord_t, coord_x, coord_y (names from coord_columns).

  • df_future (pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure as df_eval but without the actual-value column.

Return type:

tuple[DataFrame, DataFrame]

Notes

This function separates scaler lookup (scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like "_cum" but downstream tools expect canonical names such as columns prefixed with subsidence_.
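The value_mode and absolute_baseline rules above can be illustrated with a short sketch. It works on plain row dicts instead of DataFrames, and the helper name apply_value_mode is hypothetical, not part of geoprior. It covers three steps: a cumulative sum over forecast_step restarted for each sample_idx, an optional scalar or per-sample baseline, and the documented fallbacks. It handles a single value column only. The library also shifts the "<target_name>_actual" column, and this sketch does not.

```python
import warnings
from collections.abc import Mapping

CUMULATIVE_MODES = {"cumulative", "cum"}
ABSOLUTE_MODES = {"absolute_cumulative", "abs_cum", "absolute"}


def apply_value_mode(rows, value_mode="rate", absolute_baseline=None):
    """Return a converted copy of ``rows`` (hypothetical helper, not the library API).

    Each row is a dict with 'sample_idx', 'forecast_step' and 'value' keys.
    """
    mode = str(value_mode).lower()
    if mode not in CUMULATIVE_MODES | ABSOLUTE_MODES | {"rate"}:
        warnings.warn(f"Unsupported value_mode={value_mode!r}; falling back to 'rate'.")
        mode = "rate"
    out = [dict(r) for r in rows]  # never mutate the caller's rows
    if mode == "rate":
        return out

    # Cumulative sum over forecast_step, restarted for every sample_idx.
    out.sort(key=lambda r: (r["sample_idx"], r["forecast_step"]))
    running = {}
    for r in out:
        idx = r["sample_idx"]
        running[idx] = running.get(idx, 0.0) + r["value"]
        r["value"] = running[idx]

    if mode in ABSOLUTE_MODES:
        if absolute_baseline is None:
            # Degrade gracefully to relative cumulative values.
            warnings.warn("absolute_baseline is None; returning relative cumulative values.")
        else:
            for r in out:
                if isinstance(absolute_baseline, Mapping):
                    # Assumption: samples missing from the mapping get a 0.0 baseline.
                    r["value"] += absolute_baseline.get(r["sample_idx"], 0.0)
                else:
                    r["value"] += float(absolute_baseline)
    return out


rows = [
    {"sample_idx": 0, "forecast_step": 1, "value": 2.0},  # e.g. 2023 rate
    {"sample_idx": 0, "forecast_step": 2, "value": 3.0},  # e.g. 2024 rate
    {"sample_idx": 0, "forecast_step": 3, "value": 1.0},  # e.g. 2025 rate
    {"sample_idx": 1, "forecast_step": 1, "value": 5.0},
]
print([r["value"] for r in apply_value_mode(rows, "cumulative")])
# [2.0, 5.0, 6.0, 5.0]
```

As in the value_mode description, the 2024 value (5.0) is the sum of the 2023 and 2024 rates. The behaviour for samples missing from a baseline mapping is an assumption of this sketch.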

geoprior.utils.forecast_utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]

Evaluate forecast diagnostics from an evaluation DataFrame.

This helper consumes the df_eval output from format_and_forecast() (or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, \(R^2\), coverage, and sharpness. It can also optionally evaluate metrics per forecast horizon and apply additional user-defined metrics.

By default it expects the following columns:

  • 'sample_idx'

  • 'forecast_step'

  • 'coord_t' (time)

  • Quantile or point-prediction columns for the target, e.g.:

    • Quantile mode: f'{target_name}_q10', f'{target_name}_q50', f'{target_name}_q90', …

    • Point mode: f'{target_name}_pred'.

  • Actual column: f'{target_name}_actual'.

A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:

column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}

or, for quantile columns:

column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}
Parameters:
  • eval_data (str, path-like, or pandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved by format_and_forecast()) or an in-memory DataFrame.

  • target_name (str, default 'subsidence') – Base name for the target columns. Used to infer default column names such as f'{target_name}_q10', f'{target_name}_pred', and f'{target_name}_actual'.

  • column_map (dict, optional) –

    Optional mapping to override default column names. The following keys are recognized:

    • 'sample_idx' : sample index column name (default 'sample_idx').

    • 'forecast_step' : horizon index column name (default 'forecast_step').

    • 'coord_t' : time coordinate column (default 'coord_t').

    • 'actual' : name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.

    • 'pred' : point prediction column for non-quantile mode, default f'{target_name}_pred'.

    • 'quantiles' :

      • If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).

      • If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.

  • quantile_interval (tuple of float, default (0.1, 0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.

  • per_horizon (bool, default False) – If True, compute per-horizon MAE/MSE/R² grouped by the forecast_step column.

  • extra_metrics (sequence of str or mapping, optional) –

    Optional additional metrics to compute.

    • If a sequence of strings (e.g. ['pss', 'pit']), each name is resolved via geoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.

    • If a mapping {name: func}, each func is called as:

      func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
      

      where y_pred is the median (Q50) or point forecast.

    For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.

  • extra_metric_kwargs (mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names in extra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.

  • savefile (str, path-like, or bool, optional) –

    If provided, metrics are saved to disk.

    • If True: a filename is auto-generated near eval_data (if it is a path) or in the current working directory.

    • If a string/path without extension: the extension is taken from save_format.

    • If a string/path with extension: that extension takes precedence over save_format.

  • save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') –

    Output format when savefile is truthy. JSON preserves nested structure; CSV is flattened into a tall table.

    • For JSON, the function returns the metrics dictionary.

    • For CSV, the function returns the metrics DataFrame.

  • time_as_str (bool, default True) – If True, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Optional logger instance used by vlog().

  • overall_key (str | None)

Returns:

results – If save_format is JSON (default), returns a dict:

  • Single time value:

    {
        "overall_mae": ...,
        "overall_mse": ...,
        "overall_r2": ...,
        "coverage80": ...,
        "sharpness80": ...,
        "per_horizon_mae": {1: ..., 2: ..., ...},
        ...
    }
    
  • Multiple time values:

    {
        "2021": { ...metrics... },
        "2022": { ...metrics... },
    }
    

If save_format is CSV, returns a DataFrame with flattened rows:

  • Columns include: coord_t, metric, horizon, and value.

Return type:

dict or pandas.DataFrame
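In CSV mode, the nested result becomes a tall table with coord_t, metric, horizon, and value columns. The sketch below shows one way to flatten it, using plain dicts in place of a DataFrame. The helper name flatten_metrics is hypothetical. The exact row layout and the rule for telling single-time from multi-time results are assumptions: here, the result is treated as multi-time when every top-level value is a dict.

```python
def flatten_metrics(results):
    """Flatten nested metrics into tall rows (hypothetical helper).

    Assumption: if every top-level value is a dict, the keys are time
    values (multi-time layout); otherwise the dict is a single-time result.
    """
    is_multi_time = bool(results) and all(isinstance(v, dict) for v in results.values())
    blocks = results.items() if is_multi_time else [(None, results)]
    rows = []
    for coord_t, metrics in blocks:
        for name, value in metrics.items():
            if isinstance(value, dict):  # e.g. per_horizon_mae: {1: ..., 2: ...}
                for horizon, score in value.items():
                    rows.append({"coord_t": coord_t, "metric": name,
                                 "horizon": horizon, "value": score})
            else:  # scalar metric such as overall_mae or coverage80
                rows.append({"coord_t": coord_t, "metric": name,
                             "horizon": None, "value": value})
    return rows


print(flatten_metrics({"overall_mae": 1.5, "per_horizon_mae": {1: 1.0, 2: 2.0}}))
```

The rows can be passed to pandas.DataFrame(rows) to obtain a table similar in shape to the CSV output.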

Notes

  • Default metrics in quantile mode:

    • overall_mae, overall_mse, overall_r2

    • coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)

    If per_horizon=True, also:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).

  • Default metrics in point mode (no quantiles):

    • mae, mse, r2

    And optionally, if per_horizon=True:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2.

This is one of the most frequently used workflow helper modules: it turns saved or live model outputs into reusable forecast tables, aligned metrics, and export-ready structures.

Generic helpers#

Provides common helper functions for validation, comparison, and other generic operations.

class geoprior.utils.generic_utils.ExistenceChecker[source]

Bases: object

A utility class for checking and ensuring the existence of files and directories on the filesystem.

This class provides static methods to verify whether a given path exists and to create directories or files if necessary. It raises informative exceptions when paths are invalid or cannot be created.

ensure_directory(path)[source]

Ensure a directory exists at the specified path.

ensure_file(path, create_parent_dirs=False)[source]

Ensure a file exists at the specified path, optionally creating parent directories.

Examples

>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ExistenceChecker
>>> # Ensure a directory exists
>>> dir_path = ExistenceChecker.ensure_directory("data/output")
>>> isinstance(dir_path, Path)
True
>>> # Ensure a file exists, creating parent directories
>>> file_path = ExistenceChecker.ensure_file(
...     "data/output/results.txt", create_parent_dirs=True
... )
>>> file_path.exists()
True

Notes

  • Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood to create directories.

  • Creating a file will produce an empty file if it does not exist.

  • Raises TypeError if the given path is not a str or pathlib.Path, and appropriate OSError/FileExistsError for filesystem errors.

See also

pathlib.Path.mkdir

Method to create directories.

pathlib.Path.touch

Method to create an empty file.

os.makedirs

Legacy function for creating directories recursively.

os.path.exists

Check if a path exists.

static ensure_directory(path)[source]

Ensure that a directory exists at the given path, creating it if needed.

Parameters:

path (str or pathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.

Returns:

A Path object pointing to the existing (or newly created) directory.

Return type:

pathlib.Path

Raises:
  • TypeError – If path is not a string or pathlib.Path.

  • FileExistsError – If a file (not a directory) already exists at path.

  • OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).

static ensure_file(path, create_parent_dirs=False)[source]

Ensure that a file exists at the given path, creating it if needed.

If create_parent_dirs is True, any missing parent directories will be created automatically.

Parameters:
  • path (str or pathlib.Path) – The filesystem path for the file that must exist.

  • create_parent_dirs (bool, optional) – If True, create any missing parent directories. Default is False.

Returns:

A Path object pointing to the existing (or newly created) file.

Return type:

pathlib.Path

Raises:
  • TypeError – If path is not a string or pathlib.Path.

  • FileExistsError – If a directory (not a file) already exists at path.

  • OSError – If the file or parent directories cannot be created due to filesystem errors.
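The documented semantics of ensure_directory and ensure_file map directly onto pathlib. The sketch below re-creates them with local stand-in functions; these are not imports from geoprior, and the library's error messages and edge cases may differ.

```python
from pathlib import Path


def ensure_directory(path):
    """Create ``path`` as a directory if needed (sketch of the documented behaviour)."""
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, got {type(path).__name__}")
    p = Path(path)
    if p.exists() and not p.is_dir():
        raise FileExistsError(f"A file (not a directory) already exists at {p}")
    p.mkdir(parents=True, exist_ok=True)  # no-op if the directory exists
    return p


def ensure_file(path, create_parent_dirs=False):
    """Create an empty file at ``path`` if needed (sketch of the documented behaviour)."""
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or Path, got {type(path).__name__}")
    p = Path(path)
    if p.is_dir():
        raise FileExistsError(f"A directory (not a file) already exists at {p}")
    if create_parent_dirs:
        p.parent.mkdir(parents=True, exist_ok=True)
    # Without parent directories, touch() raises FileNotFoundError (an OSError).
    p.touch(exist_ok=True)
    return p
```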

geoprior.utils.generic_utils.ensure_directory_exists(path)[source]

Ensure that a directory exists at the given path, creating it if needed.

This function checks whether the provided path exists and is a directory. If the path does not exist, it attempts to create the directory (including any necessary parent directories). If a file with the same name already exists, or if creation fails, an exception is raised.

Parameters:

path (str or pathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.

Returns:

A Path object pointing to the existing (or newly created) directory.

Return type:

pathlib.Path

Raises:
  • TypeError – If path is not a string or pathlib.Path.

  • FileExistsError – If a file (not a directory) already exists at path.

  • OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).

Examples

>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ensure_directory_exists
>>> output_dir = ensure_directory_exists("data/output")
>>> isinstance(output_dir, Path)
True
>>> # The directory "data/output" now exists on disk.

Notes

  • Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood for cross-platform compatibility.

  • If path already exists as a directory, this function returns immediately without modifying it.

See also

pathlib.Path.mkdir

Method to create a directory.

os.makedirs

Legacy function for creating directories recursively.

geoprior.utils.generic_utils.verify_identical_items(list1, list2, mode='unique', ops='check_only', error='raise', objname=None)[source]

Check if two lists contain identical elements according to the specified mode.

In “unique” mode, the function compares the unique elements in each list. In “ascending” mode, it compares elements pairwise in order.

Parameters:
  • list1 (list) – The first list of items.

  • list2 (list) – The second list of items.

  • mode ({'unique', 'ascending'}, default "unique") –

    The mode of comparison:

    • "unique": Compare unique elements (order-insensitive).

    • "ascending": Compare each element pairwise in order.

  • ops ({'check_only', 'validate'}, default "check_only") – If “check_only”, returns True/False indicating a match. If “validate”, returns the validated list.

  • error ({'raise', 'warn', 'ignore'}, default "raise") – Specifies how to handle mismatches.

  • objname (str, optional) – A name to include in error messages.

Returns:

Depending on ops, returns True/False or the validated list.

Return type:

bool or list

Examples

>>> from geoprior.utils.generic_utils import verify_identical_items
>>> list1 = [0.1, 0.5, 0.9]
>>> list2 = [0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="unique", ops="validate")
[0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="ascending", ops="check_only")
True

Notes

In “ascending” mode, both lists must have the same length, and the function compares each corresponding pair of elements. In “unique” mode, the function uses the set of unique values for comparison. If the lists contain mixed types, the function attempts to compare their string representations.
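The difference between the two comparison modes and the error policy can be shown with a small check-only sketch. The name lists_match is hypothetical. The sketch does not reproduce the library's string-representation fallback for mixed types or the ops="validate" return value.

```python
import warnings


def lists_match(list1, list2, mode="unique", error="raise", objname=None):
    """Return True if the lists match under ``mode`` (check-only sketch)."""
    if mode == "unique":
        ok = set(list1) == set(list2)  # order and duplicates are ignored
    elif mode == "ascending":
        ok = len(list1) == len(list2) and all(a == b for a, b in zip(list1, list2))
    else:
        raise ValueError(f"Unknown mode: {mode!r}")
    if not ok:
        msg = f"Mismatch in {objname or 'items'}: {list1!r} vs {list2!r} (mode={mode!r})"
        if error == "raise":
            raise ValueError(msg)
        if error == "warn":
            warnings.warn(msg)
    return ok


print(lists_match([0.1, 0.5, 0.9], [0.9, 0.5, 0.1]))  # True: same unique values
```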

geoprior.utils.generic_utils.vlog(message, verbose=None, level=3, depth='auto', mode=None, vp=True, logger=None, **kws)[source]

Log or print messages with optional indentation and bracketed tags.

This function, vlog, conditionally logs or prints messages based on a global or passed-in verbose level. Its behavior depends on mode. When mode='log', the message is printed only if \(\text{verbose} \geq \text{level}\). When mode is None or 'naive', the verbosity threshold determines a bracketed prefix (e.g. [INFO], [DEBUG], [TRACE]), unless the message already contains such a prefix. The indentation applied to the message is:

(10)\[\text{indentation} = 2 \times \text{depth}\]

where \(\text{depth}\) is either specified manually or derived automatically from level (1 = ERROR, 2 = WARNING, 3 = INFO, 4/5 = DEBUG, 6/7 = TRACE).

Parameters:
  • message (str) – The text to be printed or logged.

  • verbose (int, optional) – Overall verbosity threshold. If None, it looks for a global variable named verbose. Default is None.

  • level (int, default 3) –

    Severity or importance level of the message. Commonly:

    • 1 = ERROR

    • 2 = WARNING

    • 3 = INFO

    • 4,5 = DEBUG

    • 6,7 = TRACE

  • depth (int or str, default "auto") – Indentation level used for the printed message. If "auto", the depth is computed from level.

  • mode (str, optional) – Determines the logging mode. If set to 'log', messages are printed only if \(\text{verbose} \geq \text{level}\). Otherwise (None or 'naive'), output follows custom logic driven by verbose.

  • vp (bool, default True) – If True, the function automatically prepends bracketed tags (e.g. [INFO]) unless the message already contains one of [INFO], [DEBUG], [ERROR], [WARNING], or [TRACE].

  • logger (logging.Logger or Callable[[str], None], optional) –

    Custom sink that receives the already-formatted message string.

    • If you pass a standard logging.Logger instance, the message is routed through logger.info.

    • If you supply any callable that accepts a single str (e.g. a GUI text-append function), that callable is invoked directly.

    • Defaults to print, which writes to stdout.

  • kws (dict, optional) – Additional keyword arguments reserved for future extensions.

Returns:

This function does not return anything. Depending on verbose, level, and mode, it either emits the message (to stdout or the configured logger) or suppresses it.

Return type:

None

Notes

This function is helpful for selectively displaying or logging messages in applications that adapt to the user’s required verbosity. By default, each level has a specific bracketed tag and an auto indentation depth.

Examples

>>> from geoprior.utils.generic_utils import vlog
>>> # Example with mode='log': the message is printed only if the
>>> # global or passed-in verbose is >= level. Here verbose=3 and
>>> # level=4, so nothing is printed.
>>> vlog("Check debugging details.", verbose=3,
...      level=4, mode='log')
>>> # Example with mode='naive'
>>> # If verbose=2, it displays as [INFO] prefixed.
>>> vlog("Loading data...", verbose=2, mode='naive')

See also

globals

Used to retrieve the fallback verbose value if not explicitly passed.

geoprior.utils.generic_utils.detect_dt_format(series)[source]

Detect the datetime format of a pandas Series containing datetime values.

This function inspects a non-null sample from the datetime Series and infers the format string based on its components (year, month, day, hour, minute, and second). It returns a format string that can be used with strftime. For example, if the sample indicates only a year is relevant, it returns "%Y"; if full date information is present, it returns "%Y-%m-%d"; and if time details are also present, it extends the format accordingly.

Parameters:

series (pandas.Series) – A Series containing datetime values (dtype datetime64).

Returns:

A datetime format string (e.g., "%Y", "%Y-%m-%d", or "%Y-%m-%d %H:%M:%S") that represents the resolution of the data.

Return type:

str

Examples

>>> from geoprior.utils.generic_utils import detect_dt_format
>>> import pandas as pd
>>> dates = pd.to_datetime(['2023-01-01', '2024-01-01', '2025-01-01'])
>>> fmt = detect_dt_format(pd.Series(dates))
>>> print(fmt)
%Y

Notes

The detection logic checks if month, day, hour, minute, and second are all default values (e.g., month == 1, day == 1, hour == 0, etc.) and infers the most compact format that still represents the data accurately.
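The detection rule in the Notes can be sketched with the standard datetime module. This version checks every non-null value rather than a single sample, and it only chooses among the three formats named above. Both are assumptions, and the library may behave differently. The name detect_format is hypothetical.

```python
from datetime import datetime


def detect_format(values):
    """Pick the most compact of three strftime formats that represents ``values``."""
    vals = [v for v in values if v is not None]
    if not vals:
        raise ValueError("No non-null datetime values to inspect.")
    if any((v.hour, v.minute, v.second) != (0, 0, 0) for v in vals):
        return "%Y-%m-%d %H:%M:%S"
    if all(v.month == 1 and v.day == 1 for v in vals):
        return "%Y"  # only the year varies
    return "%Y-%m-%d"


print(detect_format([datetime(2023, 1, 1), datetime(2024, 1, 1)]))  # %Y
```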

geoprior.utils.generic_utils.get_actual_column_name(df, tname=None, actual_name=None, error='raise', default_to=None)[source]

Determines the actual target column name in the given DataFrame.

Parameters:
  • df (pandas.DataFrame) – The DataFrame containing the target column.

  • tname (str, optional) – The base target name (e.g., “subsidence”). If not found in the DataFrame, it will attempt to find a matching column using “<tname>_actual” format.

  • actual_name (str, optional) – If provided, this name will be returned as the actual target column name.

  • error ({'raise', 'warn', 'ignore'}, default 'raise') –

    Specifies how to handle the case when no valid column is found:

    • 'raise': Raises a ValueError.

    • 'warn': Issues a warning and returns None.

    • 'ignore': Silently returns None.

Returns:

The determined actual column name, or None if no match is found and error='warn' or error='ignore'.

Return type:

str or None

Raises:

ValueError – If no valid target column is found and error='raise'.

Examples

>>> import pandas as pd
>>> from geoprior.utils.generic_utils import get_actual_column_name
>>> df = pd.DataFrame({'subsidence_actual': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence_actual'
>>> df = pd.DataFrame({'subsidence': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence'
>>> df = pd.DataFrame({'actual': [1, 2, 3]})
>>> get_actual_column_name(df)
'actual'
>>> df = pd.DataFrame({'measurement': [1, 2, 3]})
>>> # Warns: Could not determine the actual target column in the DataFrame.
>>> col = get_actual_column_name(df, tname="subsidence", error="warn")
>>> col is None
True
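The lookup order implied by the parameter descriptions and examples can be sketched as follows. An explicit actual_name is returned directly. Otherwise the candidates are tried in order: tname, then "<tname>_actual", then "actual". The sketch works on a list of column names, omits default_to, and uses the hypothetical name resolve_actual_column. The precedence when several candidates exist is an assumption.

```python
import warnings


def resolve_actual_column(columns, tname=None, actual_name=None, error="raise"):
    """Return the actual-target column name from ``columns`` (sketch)."""
    if actual_name is not None:
        return actual_name
    candidates = [tname, f"{tname}_actual"] if tname else []
    candidates.append("actual")
    for name in candidates:
        if name in columns:
            return name
    msg = "Could not determine the actual target column in the DataFrame."
    if error == "raise":
        raise ValueError(msg)
    if error == "warn":
        warnings.warn(msg)
    return None
```

With a DataFrame df, the call would be resolve_actual_column(list(df.columns), tname="subsidence").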
geoprior.utils.generic_utils.transform_contributions(contributions, to_percent=True, normalize=False, norm_range=(0, 1), scale_type=None, zero_division='warn', epsilon=1e-06, log_transform=False)[source]

Converts the feature contributions either to a direct percentage, normalizes them to a custom range, or applies a scaling strategy based on the chosen parameters.

Parameters:
  • contributions (dict) – A dictionary where keys are feature names and values are the feature contributions. Each value is expected to be a numerical value representing the contribution of the respective feature.

  • to_percent (bool, optional, default True) – Whether to convert the contributions to percentages. If True, each value in contributions is multiplied by 100. This is useful when contributions are given in decimal form but are expected as percentages.

  • normalize (bool, optional, default False) – Whether to normalize the contributions using min-max scaling. If True, the values are scaled to the range defined in norm_range.

  • norm_range (tuple, optional, default (0, 1)) – A tuple specifying the range (min, max) for normalization. This range is applied when normalize is True. The contributions are rescaled so that the minimum value maps to norm_range[0] and the maximum value maps to norm_range[1].

  • scale_type (str, optional, default None) –

    The scaling strategy. Options include:

    • 'zscore': Performs Z-score normalization.

    • 'log': Applies a logarithmic transformation to the data.

    If None, no scaling is applied.

  • zero_division (str, optional, default 'warn') –

    Defines how to handle zero or missing values in the contributions. Options include:

    • 'skip': Skips zero values (no modification).

    • 'warn': Issues a warning if zero values are found.

    • 'replace': Replaces zeros with a small value defined by epsilon to avoid division by zero or undefined results.

  • epsilon (float, optional, default 1e-6) – A small value used to replace zeros when zero_division is 'replace'. This prevents division-by-zero errors during transformations such as Z-score or log transformation.

  • log_transform (bool, optional, default False) – Whether to apply a logarithmic transformation to the contributions. If True, the natural logarithm is applied to each value in the contributions dictionary. Only positive values are valid for a log transformation; zero values are either skipped or replaced based on the zero_division parameter.

Returns:

A dictionary with feature names as keys and the transformed feature contributions as values. The transformation is applied according to the chosen parameters.

Return type:

dict

See also

numpy.mean

Compute the arithmetic mean of an array.

numpy.std

Compute the standard deviation of an array.

Notes

  • If scale_type='zscore', each contribution is standardized as:

    \[Z = \frac{X - \mu}{\sigma}\]

    where \(X\) is the contribution, \(\mu\) is the mean of the contributions, and \(\sigma\) is the standard deviation of the contributions.

  • If log_transform=True, the function applies the natural logarithm:

    (11)\[\log(X) \text{ for } X > 0\]

  • The zero_division parameter handles zero values by either skipping, warning, or replacing them with a small value (epsilon).

Examples

>>> from geoprior.utils.generic_utils import transform_contributions
>>> contributions = {
...     'GWL': 2.226836617133828,
...     'rainfall_mm': 12.398293851061492,
...     'normalized_seismic_risk_score': 0.9402759347406523,
...     'normalized_density': 4.806074194258057,
...     'density_concentration': 5.666943330566496e-06,
...     'geology': 1.2798872011280326e-05,
...     'density_tier': 1.044039559604414e-05,
...     'rainfall_category': 0.0
... }
>>> transform_contributions(contributions, to_percent=True, normalize=True)
>>> transform_contributions(contributions, to_percent=False, scale_type='zscore')
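The individual transformations can be sketched in plain Python: percentage conversion, min-max rescaling to norm_range, Z-score standardization, and a log transform with epsilon substitution. Several parts of this sketch are assumptions:

- The order in which steps are applied.
- The use of a population standard deviation (like the numpy.std default).
- Always replacing non-positive values with epsilon before the log, which mirrors zero_division='replace'.

The separate log_transform flag is not modelled. The name transform_sketch is hypothetical.

```python
import math
import statistics


def transform_sketch(contribs, to_percent=True, normalize=False,
                     norm_range=(0, 1), scale_type=None, epsilon=1e-6):
    """Percent -> min-max -> z-score/log, applied in that (assumed) order."""
    factor = 100.0 if to_percent else 1.0
    out = {k: float(v) * factor for k, v in contribs.items()}
    if normalize:
        lo, hi = min(out.values()), max(out.values())
        a, b = norm_range
        span = hi - lo
        # If all values are equal, map everything to the lower bound.
        out = {k: a + (v - lo) * (b - a) / span if span else a
               for k, v in out.items()}
    if scale_type == "zscore":
        values = list(out.values())
        mu = statistics.fmean(values)
        sigma = statistics.pstdev(values)  # population std, like numpy.std's default
        out = {k: (v - mu) / sigma if sigma else 0.0 for k, v in out.items()}
    elif scale_type == "log":
        # Mirrors zero_division='replace': non-positive values become epsilon.
        out = {k: math.log(v if v > 0 else epsilon) for k, v in out.items()}
    return out


print(transform_sketch({"GWL": 0.1, "rainfall_mm": 0.3, "geology": 0.5}, normalize=True))
```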
geoprior.utils.generic_utils.exclude_duplicate_kwargs(func, existing_kwargs, user_kwargs)[source]

Prevents the user from overriding existing parameters in a target function. The method exclude_duplicate_kwargs checks both developer-specified and function-level parameter names to exclude them from user_kwargs.

(12)\[\text{final\_kwargs} = \{\,(k, v) \in \text{user\_kwargs} \,\mid\, k \notin \text{protected\_params}\,\}\]

Parameters:
  • func (callable) – The target function whose valid parameters are checked. It uses Python's introspection to gather the acceptable parameter names.

  • existing_kwargs (dict or list) –

    Developer-defined parameters to protect. Can be:

    • A dictionary of parameter-value pairs (e.g., {'ax': ax_obj, 'data': df}) whose keys are excluded from user overrides.

    • A list of parameter names (e.g., ['ax', 'data']) to protect from user overrides.

  • user_kwargs (dict) – The user-supplied keyword arguments that are candidates for merging with existing_kwargs. This dictionary is filtered to remove collisions with protected parameters.

Returns:

A filtered dictionary of user-defined arguments that do not overlap with protected parameters.

Return type:

dict[str, Any]

See also

inspect.signature

Used to introspect function parameters.

filter_valid_kwargs

Another inline function that discards user params not valid for a given function.

Notes

By default, if existing_kwargs is a dictionary, its keys are treated as protected parameter names. If it’s a list, those items are protected. The function signature of func is also used to verify that only recognized parameters are protected. Keyword-filtering patterns like this are covered in Beazley and Jones [26].

Examples

>>> from geoprior.utils.generic_utils import exclude_duplicate_kwargs
>>> import seaborn as sns
>>> # Developer has some base kwargs
>>> base_kwargs = {
...     'x': 'species',
...     'y': 'sepal_length',
...     'palette': 'viridis'
... }
>>> # User tries to override 'x' with a new value
>>> user_args = {
...     'x': 'petal_width',
...     'color': 'red'
... }
>>> # Filter out duplicates
>>> safe_args = exclude_duplicate_kwargs(
...     sns.scatterplot,
...     base_kwargs,
...     user_args
... )
>>> safe_args
{'color': 'red'}
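The filtering rule can be sketched with inspect.signature. Protected names come from the keys (or items) of existing_kwargs, and any user keyword with a protected name is dropped. How the library combines this with the signature of func is not fully specified. As an assumption, this sketch keeps only protected names that func actually declares, unless func also accepts **kwargs, in which case every protected name is kept. The name drop_protected_kwargs is hypothetical.

```python
import inspect


def drop_protected_kwargs(func, existing_kwargs, user_kwargs):
    """Remove user kwargs that would override protected parameters (sketch)."""
    protected = set(existing_kwargs)  # dict keys or list items
    params = inspect.signature(func).parameters
    accepts_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if not accepts_var_kw:
        # Assumption: only names that func declares need protecting.
        protected &= set(params)
    return {k: v for k, v in user_kwargs.items() if k not in protected}


def plot(data=None, *, x=None, y=None, palette=None, **kwargs):
    """Stand-in for a plotting function such as seaborn.scatterplot."""


print(drop_protected_kwargs(plot, {"x": "species", "palette": "viridis"},
                            {"x": "petal_width", "color": "red"}))  # {'color': 'red'}
```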
geoprior.utils.generic_utils.reorder_columns(df, columns, pos='end')[source]

Reorder columns in a DataFrame by moving specified columns to a chosen position.

This function locates <columns> in the original DataFrame <df> and rearranges them based on the parameter pos. If pos is "end", the columns are appended to the end. If pos is "begin" or "start", they are placed at the front. If pos is "center", they are inserted at the midpoint of the remaining columns.

Parameters:
  • df (pandas.DataFrame) – The input DataFrame to be modified.

  • columns (str or iterable of str) – A single column name or multiple column names to reposition. If a single string is given, it is converted to a list with one element.

  • pos (str, int, or float, default "end") –

    Determines the target placement:
    • "end": Append after all other columns.

    • "begin" or "start": Prepend at the start.

    • "center": Insert at the midpoint of remaining columns.

    • integer or float: Insert at zero-based index among the remaining columns. If out of bounds, the original DataFrame is returned unchanged.

Returns:

A new DataFrame with <columns> moved as specified by pos.

Return type:

pandas.DataFrame

`reorder_columns_in`

This method rearranges columns without altering values or data order beyond column placement.

Notes

  • The function checks if <columns> exist in <df>, ignoring columns not present.

  • A warning is issued if the position is beyond the range of valid indices.

  • Negative indices for integer pos are converted to positive by adding the total number of remaining columns.

For pos="center", the insertion index is:

(13)\[i_{\text{center}} = \left\lfloor \frac{|R|}{2} \right\rfloor,\]

where \(|R|\) is the number of remaining columns after removing the target columns. For integer or float pos, the target columns are inserted at index \(\lfloor pos \rfloor\) among the remaining columns. Column-order management follows common DataFrame practices discussed in McKinney [27].

Examples

>>> from geoprior.utils.generic_utils import reorder_columns
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'id': [1, 2, 3],
...     'latitude': [10.1, 10.2, 10.3],
...     'landslide': [0, 1, 0],
...     'longitude': [20.1, 20.2, 20.3]
... })
>>> # Move 'landslide' to the end (default)
>>> reorder_columns(data, 'landslide', pos="end")
   id  latitude  longitude  landslide
0   1      10.1       20.1          0
1   2      10.2       20.2          1
2   3      10.3       20.3          0

See also

pandas.DataFrame.reindex

Pandas method for reindexing or reordering columns more generally.

geoprior.utils.generic_utils.find_id_column(df, strategy='naive', regex_pattern=None, uniqueness_threshold=0.95, errors='raise', empty_as_none=True, as_list=False, case_sensitive=False, as_frame=False)[source]

Identify potential ID column(s) in a pandas DataFrame using multiple heuristic strategies.

The function examines column names and/or data properties to detect columns likely to serve as unique identifiers. This is particularly useful for large datasets where the ID field is not explicitly labeled, and for quick scanning of possible key columns.

Parameters:
  • df (pandas.DataFrame) – The input DataFrame in which to search for potential ID columns.

  • strategy ({'naive', 'exact', 'dtype', 'regex', 'prefix_suffix'}, default 'naive') –

    Defines the logic for detecting ID columns:

    • exact: Checks for a column name that exactly matches id (case sensitivity controlled by case_sensitive).

    • naive: Searches for columns where id is part of the name (e.g., location_id), subject to case sensitivity.

    • prefix_suffix: Considers columns prefixed or suffixed with id or _id.

    • dtype: Examines columns having data types commonly used for IDs (integer, string, or object) and checks if they show high uniqueness via \(\text{uniqueness\_ratio} \geq \text{uniqueness\_threshold}\).

    • regex: Uses the custom regular expression regex_pattern to find matches in column names.

  • regex_pattern (str, optional) – Required if strategy is ‘regex’. The pattern is compiled via re.compile, with case sensitivity determined by <case_sensitive>.

  • uniqueness_threshold (float, default 0.95) –

    For <dtype> strategy, columns are flagged as ID candidates if the ratio:

    (14)\[r = \frac{ \text{unique\_values} }{ \text{non\_NA\_rows} }\]

    satisfies \(r \geq \text{uniqueness\_threshold}\), or if the number of unique values equals the number of non-null rows.

  • errors ({'raise', 'warn', 'ignore'}, default 'raise') –

    How to handle no-match cases:
    • raise: Raises a ValueError.

    • warn: Issues a UserWarning and returns based on <as_frame> or <empty_as_none>.

    • ignore: Returns an empty result based on the same parameters without warning.

  • empty_as_none (bool, default True) – Applies only if `as_frame` is False. Defines whether to return None (if True) or an empty list (if False) when no ID column is found and <errors> is ‘warn’ or ‘ignore’.

  • as_list (bool, default False) – If True, return all matched columns. If False, return only the first match. Affects both name returns and DataFrame returns.

  • case_sensitive (bool, default False) – If False, comparisons (including regex) are performed in a case-insensitive manner.

  • as_frame (bool, default False) – If True, return the matched columns as a pandas DataFrame. If as_list is True, it may include multiple columns. If no column is found, returns an empty DataFrame (if <errors> is ‘warn’ or ‘ignore’).

Returns:

Depends on as_frame, as_list, and the number of matching columns:

  • as_frame=False, as_list=False: returns the first match as a string, or None/[] when nothing matches (see empty_as_none).

  • as_frame=False, as_list=True: returns all matching column names as a list of strings.

  • as_frame=True, as_list=False: returns a DataFrame with the first matched column. If no match is found, an empty DataFrame may be returned.

  • as_frame=True, as_list=True: returns a DataFrame with all matched columns included.

Return type:

str or List[str] or pandas.DataFrame or None

Notes

  • For <dtype> strategy, integer, string, and object columns are inspected. The function calculates a uniqueness ratio and compares it against <uniqueness_threshold>.

  • Negative or zero thresholds are invalid, as are values above 1.

  • If the DataFrame has no columns or is empty, the behavior is determined by <errors>.

  • The relational-model motivation for schema-oriented column handling goes back to Codd [28].
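A simplified sketch of three of the strategies, written in pure Python over a mapping of column name to values so the heuristics are easy to inspect. The helper find_id_candidates is hypothetical; it omits the prefix_suffix and regex strategies and the errors/as_frame handling of the real function, and its dtype check (int or str values only) is an assumption standing in for the pandas dtype predicates listed under See also.

```python
import re


def find_id_candidates(columns, strategy="naive", uniqueness_threshold=0.95,
                       case_sensitive=False):
    """Hypothetical helper: names of columns that look like identifiers.

    `columns` maps each column name to a list of its values.
    """
    if not 0 < uniqueness_threshold <= 1:
        raise ValueError("uniqueness_threshold must be in (0, 1].")
    flags = 0 if case_sensitive else re.IGNORECASE
    matches = []
    for name, values in columns.items():
        if strategy == "exact":
            hit = re.fullmatch("id", name, flags) is not None
        elif strategy == "naive":
            hit = re.search("id", name, flags) is not None
        elif strategy == "dtype":
            non_na = [v for v in values if v is not None]
            id_like = all(isinstance(v, (int, str)) for v in non_na)
            # r = unique_values / non_NA_rows
            ratio = len(set(non_na)) / len(non_na) if non_na else 0.0
            hit = id_like and ratio >= uniqueness_threshold
        else:
            raise ValueError(f"Unsupported strategy: {strategy!r}")
        if hit:
            matches.append(name)
    return matches


data = {"ID_code": [101, 102, 103],
        "Name": ["Alice", "Bob", "Charlie"],
        "value": [10, 20, 30]}
print(find_id_candidates(data, "naive"))
print(find_id_candidates(data, "dtype"))
```

On the data used in the Examples, naive finds only 'ID_code', while dtype flags all three columns because each is fully unique. Uniqueness alone is a weak signal on small tables, which is one reason name-based strategies are the default.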

Examples

>>> from geoprior.utils.generic_utils import find_id_column
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'ID_code': [101, 102, 103],
...     'Name': ['Alice', 'Bob', 'Charlie'],
...     'value': [10, 20, 30]
... })
>>> # Example using the 'naive' strategy
>>> col = find_id_column(data, strategy='naive')
>>> print(col)  # Might return 'ID_code'
>>> # Example with as_list=True
>>> cols = find_id_column(data, strategy='naive',
...                       as_list=True)
>>> print(cols)  # ['ID_code']

See also

re.compile

The regex compilation method used when strategy='regex'.

pandas.api.types.is_integer_dtype

Checks for an integer dtype.

pandas.api.types.is_string_dtype

Checks for a string dtype.

pandas.api.types.is_object_dtype

Checks for an object dtype.

geoprior.utils.generic_utils.check_group_column_validity(df, group_col, ops='check_only', max_unique=10, auto_bin=False, bins=4, error='warn', bin_labels=None, verbose=True)[source]

Validate a grouping column for categorical-style use and optionally bin it.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame holding the grouping column.

  • group_col (str) – Name of the candidate grouping column in df.

  • ops ({'check_only', 'binning', 'validate'}, optional) – Operation mode. Use "check_only" to return a boolean, "binning" to bin the column when needed and return a modified DataFrame, or "validate" to check validity while honoring error.

  • max_unique (int, optional) – Maximum number of unique numeric values allowed before the column is treated as too continuous for categorical use.

  • auto_bin (bool, optional) – Whether to auto-bin a numeric column when ops='binning'.

  • bins (int, optional) – Number of bins to create when binning is applied.

  • error ({'warn', 'raise', 'ignore'}, optional) – Policy used when validation fails.

  • bin_labels (list of str or None, optional) – Custom labels for generated bins.

  • verbose (bool, optional) – Whether to emit informational messages.

Returns:

Returns a boolean for ops='check_only'. Otherwise returns a DataFrame, possibly with a transformed group_col.

Return type:

bool or pandas.DataFrame

Notes

When quantile binning is used, interval boundaries are derived from the numeric distribution of group_col.
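A minimal sketch of the binning path, assuming that "too continuous" means a numeric column with more than max_unique distinct values and that quantile bins are produced with pandas.qcut. The helper bin_group_column is hypothetical and skips the error policy and messages of the real function.

```python
import pandas as pd


def bin_group_column(df, group_col, max_unique=10, bins=4, bin_labels=None):
    """Hypothetical helper: quantile-bin a numeric column with many unique values."""
    out = df.copy()
    col = out[group_col]
    if pd.api.types.is_numeric_dtype(col) and col.nunique() > max_unique:
        # Interval edges come from the empirical quantiles of the column.
        out[group_col] = pd.qcut(col, q=bins, labels=bin_labels, duplicates="drop")
    return out


depths = pd.DataFrame({"depth": range(100)})
binned = bin_group_column(depths, "depth", bin_labels=["Q1", "Q2", "Q3", "Q4"])
print(binned["depth"].value_counts().sort_index())
```

Columns with few distinct values are returned unchanged, so already categorical groupings pass through.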

geoprior.utils.generic_utils.save_all_figures(output_dir='figures', prefix='figure', fmts=('png',), close=True, dpi=150, transparent=False, timestamp=True, verbose=True)[source]

Save all currently open Matplotlib figures to disk in specified formats.

Parameters:
  • output_dir (str) – Directory where figures will be saved. Created if it does not exist.

  • prefix (str) – Filename prefix for each figure.

  • fmts (list or tuple of str) – File formats/extensions to use (e.g., ('png', 'pdf')).

  • close (bool) – Whether to close each figure after saving. Default is True.

  • dpi (int or None) – Resolution in dots per inch. None uses Matplotlib default.

  • transparent (bool) – Whether to save figures with transparent background.

  • timestamp (bool) – Append current timestamp (YYYYmmddTHHMMSS) to filenames.

  • verbose (bool) – Print progress messages.

Returns:

List of saved file paths.

Return type:

List[str]
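The filename layout shown in the example output ('figure_1_20250521T153045.png') can be reproduced with a small path builder. This is a sketch under the assumption that names follow {prefix}_{figure number}_{timestamp}.{fmt}; the helper figure_paths and its now argument are hypothetical. In a real loop, each path would be written with Matplotlib's Figure.savefig(path, dpi=..., transparent=...).

```python
from datetime import datetime
from pathlib import Path


def figure_paths(output_dir, prefix, fig_num, fmts=("png",), timestamp=True,
                 now=None):
    """Hypothetical helper: output paths for one figure, one per format."""
    stem = f"{prefix}_{fig_num}"
    if timestamp:
        now = now or datetime.now()
        stem += "_" + now.strftime("%Y%m%dT%H%M%S")  # YYYYmmddTHHMMSS
    return [str(Path(output_dir) / f"{stem}.{fmt}") for fmt in fmts]


print(figure_paths("plots", "figure", 1, fmts=("png", "pdf"),
                   now=datetime(2025, 5, 21, 15, 30, 45)))
```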

Examples

>>> import matplotlib.pyplot as plt
>>> from geoprior.utils.generic_utils import save_all_figures
>>> fig = plt.figure(); _ = plt.plot([1, 2, 3])
>>> paths = save_all_figures(output_dir="plots", fmts=("png",))
>>> print(paths)  # the timestamp reflects the current time
['plots/figure_1_20250521T153045.png']

geoprior.utils.generic_utils.rename_dict_keys(data, param_to_rename=None, order='forward')[source]

Renames keys in the data dictionary based on the provided param_to_rename dictionary.

This function checks each key of data against param_to_rename. Keys present in data are renamed according to the mapping; mappings whose source key is absent from data are skipped. If no rename is required, the original dictionary is returned.

Parameters:
  • data (dict) – The dictionary whose keys may be renamed. The function will iterate over the keys of this dictionary and rename them according to the mapping provided in param_to_rename.

  • param_to_rename (dict, optional) – A dictionary mapping old keys to new keys. Each key in this dictionary represents an old key that may be found in data, and the corresponding value is the new key. If None, no renaming is performed. If a key in data matches an old key in param_to_rename, that key will be renamed.

  • order ({'forward', 'reverse'}, default 'forward') –

    Order for renaming keys in a flat dict:

    • forward (default): param_to_rename = {old_key: new_key}.

    • reverse: param_to_rename = {canonical_key: alias or (alias1, alias2, ...)}. The first alias found in data is moved under the canonical key. If the canonical key already exists, nothing is changed for that mapping.

Returns:

The updated dictionary with keys renamed as per the param_to_rename mapping. If no keys need renaming, the original dictionary is returned.

Return type:

dict

Raises:

ValueError – If param_to_rename is not a dictionary, a ValueError will be raised.

Examples

>>> from geoprior.utils.generic_utils import rename_dict_keys

Example 1: Renaming a key in the dictionary:
>>> data = {"subsidence": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}

Example 2: When the key is already valid (no change needed):

>>> data = {"subs_pred": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}

Example 3: When param_to_rename is None, no renaming is performed:

>>> data = {"subsidence": 100}
>>> rename_dict_keys(data)
{'subsidence': 100}

Notes

  • If param_to_rename is None, no renaming occurs, and the data dictionary is returned as is.

  • This function raises an error if param_to_rename is not a dictionary. Ensure that the parameter is a valid dictionary of old-to-new key mappings.
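A minimal sketch of the two renaming orders described under the order parameter. The helper rename_keys is hypothetical; it works on a shallow copy of data, which is a simplification of the real function's "return the original dictionary" behavior.

```python
def rename_keys(data, mapping=None, order="forward"):
    """Hypothetical helper mirroring the forward/reverse rules."""
    if mapping is None:
        return data  # nothing to rename
    if not isinstance(mapping, dict):
        raise ValueError("mapping must be a dict.")
    out = dict(data)  # shallow copy
    if order == "forward":
        # mapping = {old_key: new_key}
        for old, new in mapping.items():
            if old in out:
                out[new] = out.pop(old)
    elif order == "reverse":
        # mapping = {canonical_key: alias or (alias1, alias2, ...)}
        for canonical, aliases in mapping.items():
            if canonical in out:
                continue  # canonical key already present: leave untouched
            if isinstance(aliases, str):
                aliases = (aliases,)
            for alias in aliases:
                if alias in out:
                    out[canonical] = out.pop(alias)
                    break  # only the first alias found is moved
    else:
        raise ValueError("order must be 'forward' or 'reverse'.")
    return out


print(rename_keys({"gwl": 2.5}, {"GWL": ("head", "gwl")}, order="reverse"))
```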

geoprior.utils.generic_utils.normalize_time_column(df, time_col, datetime_col='datetime_temp', year_col='year_int', drop_orig=False)[source]

Normalize a time column into a datetime column and an integer year.

The input column may contain integer years, strings, or existing pandas datetime values. The function creates datetime_col with parsed timestamps and year_col with the extracted integer year. When drop_orig=True, the original time_col is removed and datetime_col is renamed back to time_col.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame containing a time column named time_col.

  • time_col (str) – Name of the column to normalize.

  • datetime_col (str, default 'datetime_temp') – Name of the parsed datetime column.

  • year_col (str, default 'year_int') – Name of the extracted integer year column.

  • drop_orig (bool, default False) – If True, drop the original time_col after parsing and rename datetime_col back to time_col.

Returns:

A copy of df with the parsed datetime column and integer year column.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If time_col is missing or parsing fails for any entry.

  • TypeError – If df is not a pandas DataFrame.
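A minimal pandas sketch of the same normalization. The helper normalize_time is hypothetical. It shows one detail worth handling explicitly: pandas.to_datetime interprets bare integers as epoch offsets, so integer years are parsed through their string form with format="%Y". The real function may handle this differently.

```python
import pandas as pd


def normalize_time(df, time_col, datetime_col="datetime_temp",
                   year_col="year_int", drop_orig=False):
    """Hypothetical helper: add a parsed datetime column and an integer year."""
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame.")
    if time_col not in df.columns:
        raise ValueError(f"Column {time_col!r} not found.")
    out = df.copy()
    col = out[time_col]
    if pd.api.types.is_integer_dtype(col):
        # Integer years: parse as "YYYY" so 2020 is not read as an epoch offset.
        parsed = pd.to_datetime(col.astype(str), format="%Y", errors="raise")
    else:
        parsed = pd.to_datetime(col, errors="raise")  # ValueError on bad entries
    out[datetime_col] = parsed
    out[year_col] = parsed.dt.year.astype(int)
    if drop_orig:
        out = out.drop(columns=time_col).rename(columns={datetime_col: time_col})
    return out


print(normalize_time(pd.DataFrame({"year": [2019, 2020]}), "year"))
```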

geoprior.utils.generic_utils.select_mode(mode=None, default='pihal_like', canonical=None)[source]

Resolve a user-supplied mode alias to a canonical value.

Parameters:
  • mode (str or None, optional) – Case-insensitive mode alias. Accepted values include 'pihal', 'pihal_like', 'tft', 'tft_like', or None to fall back to default.

  • default ({'pihal', 'tft'}, optional) – Canonical value returned when mode is None.

  • canonical (dict or list or None, optional) – Custom alias mapping. A dictionary maps input strings to canonical values. A list is treated as an identity mapping for its items.

Returns:

Canonical string corresponding to the resolved mode.

Return type:

str

Raises:

ValueError – If mode does not match any accepted alias.
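A minimal sketch of alias resolution. The default alias table below is an assumption based on the accepted values listed above ('pihal_like' mapping to 'pihal', 'tft_like' mapping to 'tft'). The library's actual canonical strings may differ, and resolve_mode is a hypothetical name. The default is passed through the same table, so it works whether it is given as an alias or as a canonical value.

```python
def resolve_mode(mode=None, default="pihal_like", canonical=None):
    """Hypothetical helper: map a case-insensitive mode alias to a canonical name."""
    if canonical is None:
        # Assumed alias table; not taken from the geoprior source.
        canonical = {"pihal": "pihal", "pihal_like": "pihal",
                     "tft": "tft", "tft_like": "tft"}
    elif isinstance(canonical, (list, tuple)):
        canonical = {item: item for item in canonical}  # identity mapping
    table = {str(k).lower(): v for k, v in canonical.items()}
    key = str(default if mode is None else mode).strip().lower()
    if key not in table:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(table)}.")
    return table[key]


print(resolve_mode("TFT_like"))
```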

geoprior.utils.generic_utils.normalize_model_inputs(*data)[source]
Parameters:

data (DataFrame | Mapping[str, DataFrame] | list | tuple)

Return type:

dict[str, DataFrame]

geoprior.utils.generic_utils.print_config_table(sections, title=None, table_width=None, sort_keys=True, key_col_fraction=0.35, max_value_length=200, log_fn=None)[source]

Pretty-print configuration or hyperparameters as a key/value table.

This helper is intended for CLI scripts (Stage-1, training, tuning) so that the user can quickly inspect which parameters are actually in effect.

Parameters:
  • sections (dict or sequence of (str, dict)) –

    If a single dict is passed, all key/value pairs are printed in one block.

    If a sequence is passed, it must contain (name, params) tuples, where name is a section label (e.g. "Physics") and params is a dict mapping parameter names to values.

  • title (str, optional) – Optional title displayed above the table (centered).

  • table_width (int, optional) – Total width of the printed table. If None, the function tries to use geoprior.api.util.get_table_size(). If that fails, it falls back to the terminal width (via shutil.get_terminal_size) or 80 characters.

  • sort_keys (bool, default True) – Whether to sort parameter names alphabetically within each section.

  • key_col_fraction (float, default 0.35) – Fraction of the table width allocated to the parameter-name column. The remainder is used for the value column.

  • max_value_length (int, default 200) – Maximum number of characters kept from the stringified value. Longer values are truncated with an ellipsis ("...") before being wrapped onto multiple lines.

  • log_fn (callable, optional) – Function used to emit lines (defaults to print()). This allows capturing the table in logs if needed.

Returns:

The full rendered table as a single string. As a side effect, the table is also emitted via log_fn (print() by default).

Return type:

str

Notes

  • Nested containers (lists, tuples, dicts) are rendered in a compact one-line form and then wrapped to fill the value column.

  • This function is intentionally self-contained and does not depend on external tabulation libraries, so it can be used safely in minimal Stage-1 / Stage-2 scripts.
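The column layout described above (key column = key_col_fraction of the width, values truncated to max_value_length and then wrapped) can be sketched with textwrap. This hypothetical render_table handles a single section and uses an assumed " | " separator; it does not reproduce the real function's borders, title, or get_table_size lookup.

```python
import textwrap


def render_table(params, table_width=60, key_col_fraction=0.35,
                 max_value_length=200, sort_keys=True):
    """Hypothetical helper: render one key/value section as text lines."""
    key_w = max(1, int(table_width * key_col_fraction))
    val_w = max(1, table_width - key_w - 3)  # 3 characters for " | "
    keys = sorted(params) if sort_keys else list(params)
    lines = []
    for key in keys:
        text = str(params[key])  # nested containers become one-line strings
        if len(text) > max_value_length:
            text = text[: max_value_length - 3] + "..."
        wrapped = textwrap.wrap(text, width=val_w) or [""]
        for i, chunk in enumerate(wrapped):
            label = str(key)[:key_w] if i == 0 else ""  # key only on first line
            lines.append(f"{label:<{key_w}} | {chunk}")
    return "\n".join(lines)


print(render_table({"lr": 0.001, "epochs": 50, "quantiles": [0.1, 0.5, 0.9]}))
```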

This module contains reusable helpers for paths, result folders, configuration display, and figure-saving behavior.

Geo/spatial helpers

Geospatial utility helpers for GeoPrior workflows.

geoprior.utils.geo_utils.augment_city_spatiotemporal_data(df, city, mode='interpolate', group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_config=None, augmentation_config=None, target_name=None, interpolate_target=False, verbose=True, coordinate_precision=None, savefile=None)[source]

Apply grouped spatiotemporal augmentation with city-aware defaults.

This is a convenience wrapper around augment_spatiotemporal_data. It validates the requested city, optionally rounds coordinates before grouping, and forwards interpolation and augmentation configuration dictionaries.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame containing spatial, temporal, and feature columns.

  • city ({'nansha', 'zhongshan'}) – City identifier used for validation and defaults.

  • mode ({'interpolate', 'augment_features', 'both'}, optional) – Processing mode forwarded to augment_spatiotemporal_data.

  • group_by_cols (list of str or None, optional) – Grouping columns for interpolation.

  • time_col (str or None, optional) – Time column used for interpolation.

  • value_cols_interpolate (list of str or None, optional) – Columns to interpolate.

  • feature_cols_augment (list of str or None, optional) – Columns to augment with noise.

  • interpolation_config (dict or None, optional) – Keyword arguments for interpolate_temporal_gaps. Typical values include {'freq': 'AS', 'method': 'linear'}.

  • augmentation_config (dict or None, optional) – Keyword arguments for augment_series_features. Typical values include {'noise_level': 0.01, 'noise_type': 'gaussian'}.

  • target_name (str or None, optional) – Optional target column name used when inferring default feature sets.

  • interpolate_target (bool, optional) – Whether the target should be included in default interpolation columns.

  • verbose (bool, optional) – Whether to emit progress information.

  • coordinate_precision (int or None, optional) – Decimal precision applied to coordinates before grouping.

  • savefile (str or None, optional) – Optional output CSV path handled by the decorator.

Returns:

Augmented DataFrame.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If city or mode is invalid, or if required arguments are missing for the selected mode.

  • TypeError – If the main inputs are of the wrong type.

geoprior.utils.geo_utils.augment_series_features(series_df, feature_cols, noise_level=0.01, noise_type='gaussian', random_seed=None, savefile=None)[source]

Add random noise to selected numeric feature columns.

Parameters:
  • series_df (pandas.DataFrame) – Input DataFrame representing one or more time series.

  • feature_cols (list of str) – Feature columns to augment.

  • noise_level (float, optional) – Magnitude of the added noise. For Gaussian noise it scales the feature standard deviation, and for uniform noise it scales the feature range.

  • noise_type ({'gaussian', 'uniform'}, optional) – Type of noise distribution to use.

  • random_seed (int or None, optional) – Seed for reproducible noise generation.

  • savefile (str or None, optional) – Optional output path handled by the decorator.

Returns:

DataFrame with noise added to the selected feature columns.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If requested feature columns are missing or noise_type is invalid.

  • TypeError – If the main inputs are of the wrong type.

Notes

Non-numeric columns are skipped, and constant or invalid numeric ranges are left unchanged.
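A minimal NumPy sketch of the noise scaling described under noise_level: Gaussian noise with standard deviation noise_level times the feature standard deviation, or uniform noise whose width is noise_level times the feature range. The symmetric interval used for uniform noise is an assumption, and add_noise is a hypothetical name.

```python
import numpy as np


def add_noise(values, noise_level=0.01, noise_type="gaussian", random_seed=None):
    """Hypothetical helper: perturb one numeric feature."""
    x = np.asarray(values, dtype=float)
    rng = np.random.default_rng(random_seed)
    if noise_type == "gaussian":
        scale = noise_level * np.nanstd(x)
        if not np.isfinite(scale) or scale == 0:
            return x.copy()  # constant or invalid column: leave unchanged
        noise = rng.normal(0.0, scale, size=x.shape)
    elif noise_type == "uniform":
        span = noise_level * (np.nanmax(x) - np.nanmin(x))
        if not np.isfinite(span) or span == 0:
            return x.copy()
        # Assumed symmetric interval [-span/2, span/2).
        noise = rng.uniform(-span / 2, span / 2, size=x.shape)
    else:
        raise ValueError("noise_type must be 'gaussian' or 'uniform'.")
    return x + noise


print(add_noise([1.0, 2.0, 3.0], noise_level=0.05, random_seed=0))
```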

geoprior.utils.geo_utils.augment_spatiotemporal_data(df, mode, group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_kwargs=None, augmentation_kwargs=None, savefile=None, verbose=False)[source]

Apply interpolation, feature augmentation, or both to grouped data.

Parameters:
  • df (pandas.DataFrame) – Input spatiotemporal DataFrame.

  • mode ({'interpolate', 'augment_features', 'both'}) – Processing mode. Use interpolation only, feature augmentation only, or interpolation followed by augmentation.

  • group_by_cols (list of str or None, optional) – Grouping columns used for per-location processing.

  • time_col (str or None, optional) – Time column required when interpolation is requested.

  • value_cols_interpolate (list of str or None, optional) – Value columns to interpolate when interpolation is enabled.

  • feature_cols_augment (list of str or None, optional) – Feature columns to perturb when augmentation is enabled.

  • interpolation_kwargs (dict or None, optional) – Keyword arguments forwarded to interpolate_temporal_gaps.

  • augmentation_kwargs (dict or None, optional) – Keyword arguments forwarded to augment_series_features.

  • savefile (str or None, optional) – Optional output path handled by the decorator.

  • verbose (bool, optional) – Whether to emit progress information.

Returns:

Processed DataFrame assembled from all groups.

Return type:

pandas.DataFrame

Raises:

ValueError – If mode is invalid or required arguments for the selected mode are missing.

Notes

Groups are processed independently and concatenated afterward.
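The group-then-concatenate pattern can be sketched as follows, assuming each group is handed to a per-group function and the results are concatenated with a fresh index. The helpers process_by_group and interpolate_gwl are hypothetical; interpolate_gwl stands in for the interpolation or augmentation step.

```python
import pandas as pd


def process_by_group(df, group_by_cols, func):
    """Hypothetical helper: apply `func` to each group, then concatenate."""
    pieces = [func(group.copy())
              for _, group in df.groupby(group_by_cols, sort=False)]
    return pd.concat(pieces, ignore_index=True)


def interpolate_gwl(group):
    # Fill interior gaps within one site only, ordered by year.
    group["gwl"] = group.sort_values("year")["gwl"].interpolate()
    return group


df = pd.DataFrame({
    "site": ["A", "A", "A", "B", "B", "B"],
    "year": [2020, 2021, 2022] * 2,
    "gwl": [1.0, None, 3.0, 10.0, None, 30.0],
})
result = process_by_group(df, ["site"], interpolate_gwl)
print(result)
```

Because each site is processed on its own, a gap in one site is filled only from that site's values.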

geoprior.utils.geo_utils.generate_dummy_pinn_data(n_samples, *, year_range=None, coords_range=None, subs_range=None, gwl_range=None, rainfall_range=None, vars_range=None)[source]

Generate dummy PINN data dictionary with specified or default ranges.

Parameters:
  • n_samples (int) – Number of samples to generate.

  • year_range (tuple[float, float], optional) – (min_year, max_year) for integer years. Default (2000, 2025).

  • coords_range (tuple[tuple[float, float], tuple[float, float]], optional) – ((lon_min, lon_max), (lat_min, lat_max)). Default ((113.0, 113.8), (22.3, 22.8)).

  • subs_range (tuple[float, float], optional) – (mean_subsidence, std_subsidence) for normal distribution. Default (-20, 15).

  • gwl_range (tuple[float, float], optional) – (mean_gwl, std_gwl) for normal distribution. Default (2.5, 1.0).

  • rainfall_range (tuple[float, float], optional) – (min_rain, max_rain) for uniform distribution. Default (500, 2500).

  • vars_range (dict, optional) – Dictionary that may contain any of the keys: ‘year_range’, ‘coords_range’, ‘subs_range’, ‘gwl_range’, ‘rainfall_range’. Missing keys will fall back to defaults or to explicitly passed arguments.

Returns:

dummy_data_dict – Dictionary with keys:

  • "year": integer years array

  • "longitude": float longitudes array

  • "latitude": float latitudes array

  • "subsidence": float subsidence values array

  • "GWL": float groundwater level values array

  • "rainfall_mm": float rainfall values array

Return type:

dict[str, np.ndarray]
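A minimal NumPy sketch of the documented defaults and distributions. Several details are assumptions rather than documented behavior: years are drawn uniformly with max_year inclusive, coordinates are drawn uniformly within their ranges, and there is an added seed argument. The helper make_dummy_pinn_data is hypothetical.

```python
import numpy as np


def make_dummy_pinn_data(n_samples, year_range=(2000, 2025),
                         coords_range=((113.0, 113.8), (22.3, 22.8)),
                         subs_range=(-20, 15), gwl_range=(2.5, 1.0),
                         rainfall_range=(500, 2500), seed=None):
    """Hypothetical helper returning the same keys as generate_dummy_pinn_data."""
    rng = np.random.default_rng(seed)
    (lon_min, lon_max), (lat_min, lat_max) = coords_range
    return {
        # Integer years in [min_year, max_year] (inclusive upper bound assumed).
        "year": rng.integers(int(year_range[0]), int(year_range[1]) + 1,
                             size=n_samples),
        "longitude": rng.uniform(lon_min, lon_max, n_samples),
        "latitude": rng.uniform(lat_min, lat_max, n_samples),
        # (mean, std) pairs for the normal distributions.
        "subsidence": rng.normal(subs_range[0], subs_range[1], n_samples),
        "GWL": rng.normal(gwl_range[0], gwl_range[1], n_samples),
        # (min, max) for the uniform rainfall distribution.
        "rainfall_mm": rng.uniform(rainfall_range[0], rainfall_range[1],
                                   n_samples),
    }


data = make_dummy_pinn_data(5, seed=0)
print({k: v.shape for k, v in data.items()})
```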

geoprior.utils.geo_utils.interpolate_temporal_gaps(series_df, time_col, value_cols, freq=None, method='linear', order=None, fill_limit=None, fill_limit_direction='forward', savefile=None)[source]

Interpolates missing values in specified columns of a time series DataFrame.

This function is designed to work on a DataFrame representing a time series for a single spatial group (e.g., one monitoring location), sorted by time. If freq is provided, the DataFrame’s index is first reindexed to that frequency, which can create NaN values for missing time steps. These NaNs, along with any pre-existing NaNs in value_cols, are then interpolated.

Let \(t_1 < t_2 < \dots < t_n\) be the original timestamps. If freq yields a new index \(\{t_i'\}\) that includes times not in the original, NaNs appear at those \(t_i'\). Then for each column \(v\) in \(\{\text{value\_cols}\}\), we perform:

(15)\[v(t) = \begin{cases} \operatorname{interpolate}(v,\, t;\, \text{method}, \dots) & \text{if } t \in \{t_i'\} \setminus \{t_1, \dots, t_n\} \text{ or } v(t) \text{ is NaN},\\ v(t) & \text{if } t \in \{t_1, \dots, t_n\} \text{ and } v(t) \text{ is not NaN}. \end{cases}\]

Parameters:
  • series_df (pd.DataFrame) – Input DataFrame for a single time series, ideally sorted by time_col. The time_col should be convertible to datetime.

  • time_col (str) – Name of the column containing datetime information.

  • value_cols (List[str]) – List of column names whose missing values (NaNs) should be interpolated.

  • freq (str or None, default None) – The desired frequency for the time series (e.g., ‘D’ for daily, ‘MS’ for month start, ‘AS’ for year start). If provided, the DataFrame is reindexed to this frequency before interpolation. This helps identify and fill gaps where entire time steps are missing.

  • method (str, default 'linear') – Interpolation method to use. Passed to pandas.DataFrame.interpolate(). Common methods: ‘linear’, ‘time’, ‘polynomial’, ‘spline’. If ‘polynomial’ or ‘spline’, order must be specified.

  • order (int or None, default None) – Order for polynomial or spline interpolation. Required if method is ‘polynomial’ or ‘spline’.

  • fill_limit (int or None, default None) – Maximum number of consecutive NaNs to fill. Passed to pandas.DataFrame.interpolate().

  • fill_limit_direction (str, default 'forward') – Direction for fill_limit (‘forward’, ‘backward’, ‘both’). Passed to pandas.DataFrame.interpolate().

  • savefile (str or None, optional) – Optional output path handled by the decorator.

Returns:

DataFrame with specified columns interpolated. If freq was used, the DataFrame will have a DatetimeIndex. Other columns not in value_cols will be forward-filled after reindexing if freq is set, to propagate their last known values into new empty rows.

Return type:

pd.DataFrame

Raises:
  • TypeError – If series_df is not a DataFrame or if value_cols is not a list of strings. Also if time_col is missing from the DataFrame.

  • ValueError – If order is required but not provided for ‘polynomial’ or ‘spline’.

Examples

>>> import pandas as pd
>>> from geoprior.utils.geo_utils import interpolate_temporal_gaps
>>> # Sample time series with missing dates
>>> df = pd.DataFrame({
...     'date': ['2020-01-01', '2020-01-03', '2020-01-06'],
...     'value': [1.0, None, 4.0]
... })
>>> df
         date  value
0 2020-01-01    1.0
1 2020-01-03    NaN
2 2020-01-06    4.0
>>> result = interpolate_temporal_gaps(
...     df, time_col='date', value_cols=['value'], freq='D'
... )
>>> # Daily reindexing adds rows for 2020-01-02, -04 and -05; linear
>>> # interpolation then spans 1.0 -> 4.0 over five equal steps.
>>> result['value'].round(2).tolist()
[1.0, 1.6, 2.2, 2.8, 3.4, 4.0]

Notes

  • Ensure series_df pertains to a single spatial group and is sorted by time for meaningful interpolation.

  • The ‘time’ method for interpolation requires the index to be a DatetimeIndex.

  • Polynomial or spline methods require order to be specified.

See also

pandas.DataFrame.interpolate

Core interpolation method.

pandas.DataFrame.asfreq

Reindex DataFrame to fixed frequency.
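The reindex-then-interpolate sequence can be sketched with the two pandas methods listed under See also. The helper fill_gaps is hypothetical. Its limit and limit_direction arguments map onto the pandas.DataFrame.interpolate keywords that fill_limit and fill_limit_direction are documented to feed, and, as described under Returns, it forward-fills non-interpolated columns into the new rows.

```python
import pandas as pd


def fill_gaps(series_df, time_col, value_cols, freq=None, method="linear",
              order=None, limit=None, limit_direction="forward"):
    """Hypothetical helper: reindex one series to `freq`, then interpolate."""
    if method in ("polynomial", "spline") and order is None:
        raise ValueError("order is required for polynomial or spline.")
    out = series_df.copy()
    out[time_col] = pd.to_datetime(out[time_col])
    out = out.sort_values(time_col).set_index(time_col)
    if freq is not None:
        out = out.asfreq(freq)  # missing time steps become NaN rows
        other = [c for c in out.columns if c not in value_cols]
        if other:
            # Carry last known values of other columns into the new rows.
            out[other] = out[other].ffill()
    kwargs = {"method": method, "limit": limit, "limit_direction": limit_direction}
    if order is not None:
        kwargs["order"] = order
    out[value_cols] = out[value_cols].interpolate(**kwargs)
    return out


df = pd.DataFrame({"date": ["2020-01-01", "2020-01-03", "2020-01-06"],
                   "value": [1.0, None, 4.0]})
print(fill_gaps(df, "date", ["value"], freq="D"))
```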

geoprior.utils.geo_utils.resolve_spatial_columns(df, spatial_cols=None, lon_col=None, lat_col=None)[source]

Helper to validate and resolve spatial columns.

Accepts either explicit lon/lat columns or a list of spatial_cols. Returns (lon_col, lat_col).

  • If lon_col and lat_col are both provided, they take precedence (warn if spatial_cols also set).

  • Else if spatial_cols is provided, it must yield exactly two column names.

  • Otherwise, a ValueError is raised.

Parameters:
  • df (pd.DataFrame) – Input DataFrame for feature checks.

  • spatial_cols (list[str] or None) – Two-element list of [lon_col, lat_col].

  • lon_col (str or None) – Name of longitude column.

  • lat_col (str or None) – Name of latitude column.

Returns:

(lon_col, lat_col) – Validated column names for longitude and latitude.

Return type:

tuple of str

Raises:

ValueError – If neither lon_col/lat_col nor a valid spatial_cols is provided, or if spatial_cols does not contain exactly two names.
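A minimal sketch of the precedence rules listed above. The helper resolve_lon_lat is hypothetical and takes a list of column names instead of a DataFrame. The final check that both columns exist is an assumption suggested by the df parameter's description ("for feature checks").

```python
import warnings


def resolve_lon_lat(columns, spatial_cols=None, lon_col=None, lat_col=None):
    """Hypothetical helper: return (lon_col, lat_col) following the rules above."""
    if lon_col is not None and lat_col is not None:
        if spatial_cols is not None:
            warnings.warn("lon_col/lat_col take precedence over spatial_cols.")
        pair = (lon_col, lat_col)
    elif spatial_cols is not None:
        spatial_cols = list(spatial_cols)
        if len(spatial_cols) != 2:
            raise ValueError("spatial_cols must contain exactly two names.")
        pair = (spatial_cols[0], spatial_cols[1])
    else:
        raise ValueError("Provide lon_col and lat_col, or spatial_cols.")
    missing = [c for c in pair if c not in columns]  # assumed existence check
    if missing:
        raise ValueError(f"Columns not found: {missing}")
    return pair


print(resolve_lon_lat(["longitude", "latitude", "year"],
                      spatial_cols=["longitude", "latitude"]))
```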

geoprior.utils.geo_utils.merge_frames_to_file(sources, output_path, *, output_format='parquet', compression='snappy', check_columns='strict', excel_mode='all_sheets', sheet_names=None, add_source_label=True, source_col='source', sort_by=None, drop_duplicates=False, reset_index=True, save_kwargs=None, verbose=1)[source]

Merge multiple NATCOM city datasets into a single compressed file.

Parameters:
  • sources (iterable of {path-like, DataFrame}) –

    Input sources. Each element can be:

    • A path to a CSV file (e.g. "nansha_final...csv"),

    • A path to an Excel workbook (one or many sheets per city),

    • A pre-loaded DataFrame.

  • output_path (path-like) – Destination file path. If output_format='parquet' and the suffix is missing, '.parquet' is appended.

  • output_format ({'parquet', 'csv', 'feather', 'pickle'}, optional) – Output format. Default is 'parquet' for compact, columnar storage (recommended for Code Ocean).

  • compression (str or None, optional) –

    Compression to use for the chosen format.

    • For 'parquet' this is passed to pandas.DataFrame.to_parquet() (e.g. 'snappy', 'gzip', 'brotli').

    • For 'csv' it is passed to pandas.DataFrame.to_csv() via the compression keyword if non-None.

    • Ignored for 'feather' and 'pickle' (Feather uses its own defaults; pickle rarely benefits from extra compression at this layer).

  • check_columns ({'strict', 'subset', 'union'}, optional) –

    How to handle column consistency across sources:

    • 'strict' (default): all sources must have exactly the same set of columns (order may differ). Columns are then aligned to the order of the first DataFrame. A mismatch raises ValueError.

    • 'subset': all columns in the first DataFrame must exist in each subsequent DataFrame. Extra columns in later sources are dropped. Missing required columns raise ValueError.

    • 'union': columns are unioned across all sources. Any missing column in a particular source is added and filled with NaN before concatenation.

  • excel_mode ({'all_sheets', 'first_sheet'}, optional) –

    Behaviour when a source is an Excel workbook:

    • 'all_sheets' (default): read all sheets and treat each sheet as a separate DataFrame to merge.

    • 'first_sheet': read only the first sheet.

    If sheet_names is provided, it takes precedence.

  • sheet_names (iterable of str, optional) – Explicit sheet names to read from Excel workbooks. If provided, only these sheets are read.

  • add_source_label (bool, optional) – If True (default), add a column named source_col to each chunk before concatenation. For path-like inputs, the label is derived from the file name and, when applicable, sheet name (e.g. "nansha_final_main_std.harmonized.csv" or "zhongshan.xlsx:Sheet1"). For pre-loaded DataFrames, the label is '<in_memory>'.

  • source_col (str, optional) – Name of the column storing the source label when add_source_label=True.

  • sort_by (iterable of str, optional) – Optional column(s) to sort the merged DataFrame by at the end (e.g. ['city', 'year', 'longitude', 'latitude']).

  • drop_duplicates (bool, optional) – If True, drop duplicate rows at the end (after sorting).

  • reset_index (bool, optional) – If True (default), reset index after concatenation.

  • save_kwargs (dict, optional) – Extra keyword arguments forwarded to the corresponding to_* writer (e.g. to_parquet, to_csv, to_feather, to_pickle).

  • verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints basic progress information.

Returns:

merged – The merged DataFrame (also written to disk).

Return type:

pandas.DataFrame

Raises:

ValueError – If check_columns='strict' or 'subset' and a column mismatch is detected.

Examples

>>> from geoprior.utils.geo_utils import merge_frames_to_file
>>> merge_frames_to_file(
...    sources=[
...        "nansha_final_main_std.harmonized.csv",
...        "zhongshan_final_main_std.harmonized.csv",
...    ],
...    output_path="natcom_all_cities",
...    output_format="parquet",
...    compression="snappy",
...    sort_by=["city", "year", "longitude", "latitude"],
... )

Notes

  • All inputs are read fully into memory before concatenation. This is acceptable for the NATCOM subsidence datasets (O(10^6 - 10^7) rows) but can be refactored to a streaming/row-group approach if needed later.

  • Using output_format='parquet' with compression (e.g. 'snappy') is recommended for Code Ocean to minimise disk usage while keeping I/O efficient.
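The three check_columns policies can be sketched as a column-alignment step before concatenation. The helper align_frames is hypothetical and leaves out file reading, source labels, sorting, and writing. For 'union' it relies on pandas.concat, which adds any absent column and fills it with NaN.

```python
import pandas as pd


def align_frames(frames, check_columns="strict"):
    """Hypothetical helper: align columns per policy, then concatenate."""
    first = list(frames[0].columns)
    aligned = []
    for i, df in enumerate(frames):
        cols = set(df.columns)
        if check_columns == "strict":
            if cols != set(first):
                raise ValueError(f"Frame {i} columns differ from the first frame.")
            aligned.append(df[first])  # reorder to match the first frame
        elif check_columns == "subset":
            missing = set(first) - cols
            if missing:
                raise ValueError(f"Frame {i} is missing {sorted(missing)}.")
            aligned.append(df[first])  # extra columns are dropped
        elif check_columns == "union":
            aligned.append(df)  # concat fills absent columns with NaN
        else:
            raise ValueError("check_columns must be 'strict', 'subset' or 'union'.")
    return pd.concat(aligned, ignore_index=True)


a = pd.DataFrame({"city": ["nansha"], "year": [2020]})
b = pd.DataFrame({"year": [2021], "city": ["zhongshan"], "gwl": [1.5]})
print(align_frames([a, b], "union"))
```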

geoprior.utils.geo_utils.unpack_frames_from_file(merged, *, group_col='city', output_dir=None, output_format='csv', compression=None, use_source_col=True, source_col='source', filename_pattern='{group_value}_split', drop_columns=None, keep_columns=None, save=True, return_dict=True, save_kwargs=None, verbose=1, logger)[source]

Reverse of merge_frames_to_file(): split an aggregated NATCOM dataset into per-city frames (and optionally write them to disk).

Parameters:
  • merged (path-like or DataFrame) –

    Aggregated dataset. If path-like, the format is inferred from the file suffix. If a DataFrame is passed, it is used directly.

  • group_col (str, optional) – Column used to split the dataset (default: 'city'). Each unique value defines one output chunk.

  • output_dir (path-like, optional) – Directory where per-group files are written. If None and merged is a path, the directory of merged is used. If merged is a DataFrame and output_dir is None, the current working directory is used.

  • output_format ({'csv', 'parquet', 'feather', 'pickle'}, optional) – Output format for per-group files. Default is 'csv'.

  • compression (str or None, optional) –

    Compression to use when writing:

    • For 'csv', forwarded to DataFrame.to_csv() as the compression argument (e.g. 'gzip').

    • For 'parquet', forwarded to DataFrame.to_parquet() (e.g. 'snappy', 'gzip').

    • Ignored for 'feather' and 'pickle' (these use their own defaults).

  • use_source_col (bool, optional) –

    If True (default) and a column named source_col exists, the helper tries to reconstruct the original file name for each group:

    • If a group has a single unique, non-null source value that looks like a filename (e.g. 'nansha_final_main_std.harmonized.csv'), that base name is used for the output (with its suffix adjusted to match output_format if needed).

    • If there are multiple unique source labels within a group, it falls back to filename_pattern.

  • source_col (str, optional) – Name of the column containing the source label (default: 'source'). This should match the column created in merge_frames_to_file() when add_source_label=True.

  • filename_pattern (str, optional) –

    Pattern used when no suitable source label is available. The following placeholders are supported:

    • {group_value} : the group value as a string

    • {group_col} : the name of the grouping column

    Example: filename_pattern="{group_col}_{group_value}_data" yields "city_Nansha_data.csv" for the group value 'Nansha' when output_format='csv'.

  • drop_columns (iterable of str, optional) – Columns to drop from each group before saving/returning (e.g. ['source'] if you don’t want the bookkeeping column).

  • keep_columns (iterable of str, optional) – If provided, only these columns are kept (all others are dropped after any drop_columns processing is applied).

  • save (bool, optional) – If True (default), write each group to disk as a separate file. If False, no files are written; only the dict of DataFrames is returned (if return_dict=True).

  • return_dict (bool, optional) – If True (default), return a mapping {group_value: group_df}. If False, an empty dict is returned (useful when you only care about side-effect files).

  • save_kwargs (dict, optional) – Extra keyword arguments forwarded to the respective writer: DataFrame.to_csv(), DataFrame.to_parquet(), DataFrame.to_feather(), or DataFrame.to_pickle().

  • verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints progress information.

  • logger (None)

Returns:

out – Dictionary mapping each group value to the corresponding DataFrame. Empty if return_dict=False.

Return type:

dict

Raises:

ValueError – If group_col is not present in the merged dataset.

Examples

>>> from geoprior.utils.geo_utils import unpack_frames_from_file
>>> unpack_frames_from_file(
...     "natcom_all_cities.parquet",
...     group_col="city",
...     output_format="csv",
... )
# -> writes e.g. 'nansha_final_main_std.harmonized.csv',
#    'zhongshan_final_main_std.harmonized.csv' (if `source` labels exist),
#    and returns a dict: {'Nansha': df_nansha, 'Zhongshan': df_zhongshan}
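The split-and-name behaviour described above can be approximated with a plain pandas groupby. The sketch below is a hypothetical re-implementation, not the library's code: the helper name split_by_group, its simplified filename rule (it does not check whether a label "looks like" a filename) and its fixed suffix argument are assumptions made for illustration. It groups on group_col, reuses a single non-null source label when one exists, and otherwise falls back to filename_pattern.

```python
from pathlib import Path

import pandas as pd


def split_by_group(merged, group_col="city", source_col="source",
                   filename_pattern="{group_value}_split", suffix=".csv"):
    """Hypothetical sketch: map each group value to (filename, frame)."""
    if group_col not in merged.columns:
        raise ValueError(f"{group_col!r} is not a column of the merged dataset")
    out = {}
    # sort=False keeps groups in order of first appearance.
    for group_value, frame in merged.groupby(group_col, sort=False):
        labels = (frame[source_col].dropna().unique()
                  if source_col in frame.columns else [])
        if len(labels) == 1:
            # One source label: reuse its base name with the output suffix.
            name = Path(str(labels[0])).with_suffix(suffix).name
        else:
            # Missing or mixed labels: fall back to filename_pattern.
            name = filename_pattern.format(
                group_value=group_value, group_col=group_col) + suffix
        out[group_value] = (name, frame.reset_index(drop=True))
    return out


# Toy merged table: Nansha has a source label, Zhongshan does not.
merged = pd.DataFrame({
    "city": ["Nansha", "Nansha", "Zhongshan"],
    "source": ["nansha_final_main_std.harmonized.csv"] * 2 + [None],
    "subsidence": [1.2, 1.5, 0.8],
})
parts = split_by_group(merged)
for city, (name, frame) in parts.items():
    print(city, name, len(frame))
```

Writing each frame to disk (with to_csv, to_parquet, and so on) is left out so the sketch only shows how groups and file names are resolved.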

geospatial_utils - A collection of utilities for geospatial and positional data analysis, filtering, and transformations.

geoprior.utils.spatial_utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)[source]

Draw a stratified sample that represents the spatial distribution of the whole area and, when stratified by year, covers the different years.

This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.

Parameters:
  • data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., ‘longitude’, ‘latitude’) and any columns specified in stratify_by.

  • sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).

  • stratify_by (list of str, optional) – List of column names to stratify by.

  • spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.

  • spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named ‘longitude’ and/or ‘latitude’ in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.

  • method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.

  • min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.

  • random_state (int, optional) – Random seed for reproducibility. Default is 42.

  • verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.

Returns:

sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.

Return type:

pandas.DataFrame

Notes

The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.

Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:

\[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil \tag{16}\]

where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).

The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
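The binning and allocation rule in equation (16) can be illustrated with a short pandas sketch. This is a hypothetical helper (proportional_sample), not the library's implementation. It assumes that each spatial column is binned independently with pandas.qcut and that the bins are crossed with stratify_by to form the groups.

```python
import math

import numpy as np
import pandas as pd


def proportional_sample(data, sample_size, spatial_cols=("longitude", "latitude"),
                        spatial_bins=10, stratify_by=None, random_state=42):
    """Hypothetical sketch: quantile-bin coordinates, then allocate per equation (16)."""
    N = len(data)
    # A float is a fraction of N; an int is an absolute target count.
    n = round(sample_size * N) if isinstance(sample_size, float) else int(sample_size)
    work = data.copy()
    bin_cols = []
    for col in spatial_cols:
        bin_col = f"__{col}_bin"
        # pd.qcut puts roughly the same number of rows in each bin.
        work[bin_col] = pd.qcut(work[col], q=spatial_bins, labels=False,
                                duplicates="drop")
        bin_cols.append(bin_col)
    keys = list(stratify_by or []) + bin_cols
    parts = []
    for _, group in work.groupby(keys):
        n_i = math.ceil(len(group) / N * n)          # equation (16)
        parts.append(group.sample(n=min(n_i, len(group)),
                                  random_state=random_state))
    # Drop the helper bin columns so the output has the input's columns.
    return pd.concat(parts).drop(columns=bin_cols)


rng = np.random.default_rng(0)
df = pd.DataFrame({
    "longitude": rng.uniform(113.0, 114.0, 1000),
    "latitude": rng.uniform(22.0, 23.0, 1000),
    "year": rng.choice([2019, 2020], 1000),
})
sampled = proportional_sample(df, 0.05, spatial_bins=4, stratify_by=["year"])
print(len(sampled))
```

Because every \(n_i\) is rounded up, the total can exceed \(n\) by at most the number of non-empty groups.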

Examples

>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> import pandas as pd
>>> # Assume 'df' is a pandas DataFrame with columns
>>> # 'longitude', 'latitude', 'year', and other data.
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)

See also

pandas.qcut

Quantile-based discretization function used for binning.

sklearn.model_selection.StratifiedShuffleSplit

For stratified sampling.

batch_spatial_sampling

Resample spatial data with batching.

geoprior.utils.spatial_utils.extract_coordinates(df, as_frame=False, drop_xy=False, error='raise', verbose=0)[source]

Extract coordinate columns or their midpoint from a DataFrame.

Parameters:
  • df (pandas.DataFrame) – DataFrame expected to contain longitude/latitude or easting/northing columns.

  • as_frame (bool, optional) – If True, return the coordinate columns as a DataFrame. Otherwise return their midpoint.

  • drop_xy (bool, optional) – If True, remove detected coordinate columns from the returned DataFrame.

  • error (bool or {'raise', 'warn', 'ignore'}, optional) – Error-handling policy for invalid inputs.

  • verbose (int, optional) – Verbosity level for detection messages.

Returns:

Tuple containing the extracted coordinates or midpoint, the DataFrame with optional coordinate removal, and the detected coordinate-column names.

Return type:

tuple

Notes

Longitude/latitude are preferred over easting/northing when both are present.

geoprior.utils.spatial_utils.batch_spatial_sampling(data, sample_size=0.1, n_batches=10, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, verbose=1)[source]

Create stratified spatial sample batches from a DataFrame.

Parameters:
  • data (pandas.DataFrame) – Input DataFrame used for sampling.

  • sample_size (float or int, optional) – Total sample size as a fraction or absolute count.

  • n_batches (int, optional) – Number of batches to generate.

  • stratify_by (str or list of str or None, optional) – Additional columns used for stratification.

  • spatial_bins (int or sequence of int, optional) – Number of spatial bins used when discretizing coordinates.

  • spatial_cols (list or tuple of str or None, optional) – Spatial coordinate columns.

  • method ({'abs', 'absolute', 'relative'}, optional) – Strategy used to translate sample_size into per-batch sample counts.

  • min_relative_ratio (float, optional) – Minimum relative group size used by method='relative'.

  • random_state (int, optional) – Random seed for reproducibility.

  • verbose (int, optional) – Verbosity level.

Returns:

Stratified batches sampled without overlap.

Return type:

list of pandas.DataFrame

Notes

Spatial coordinates are discretized with pandas.qcut and combined with stratify_by columns so batches preserve the overall data distribution as closely as possible.
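One way to obtain non-overlapping batches that each mirror the strata is to deal the shuffled rows of every stratum round-robin across the batches. The sketch below assumes the overall sample has already been drawn (for example with spatial_sampling) and that the strata columns already contain the spatial bin labels. The helper stratified_batches is hypothetical and is not taken from the library source.

```python
import numpy as np
import pandas as pd


def stratified_batches(data, strata_cols, n_batches, random_state=42):
    """Hypothetical sketch: split rows into disjoint, stratum-balanced batches."""
    rng = np.random.default_rng(random_state)
    # -1 marks rows that belong to no stratum (e.g. NaN keys dropped by groupby).
    assignment = np.full(len(data), -1)
    # GroupBy.indices maps each stratum to the positional indices of its rows.
    for positions in data.groupby(strata_cols).indices.values():
        shuffled = rng.permutation(positions)
        # Deal the shuffled rows of each stratum round-robin across batches.
        assignment[shuffled] = np.arange(len(shuffled)) % n_batches
    return [data.iloc[np.flatnonzero(assignment == b)] for b in range(n_batches)]


df = pd.DataFrame({"zone": ["a"] * 6 + ["b"] * 4, "value": range(10)})
batches = stratified_batches(df, ["zone"], n_batches=2)
print([b["zone"].value_counts().to_dict() for b in batches])
```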

geoprior.utils.spatial_utils.extract_zones_from(z, threshold='auto', condition='auto', use_negative_criteria=True, percentile=10, x=None, y=None, data=None, view=False, plot_type='scatter', figsize=(8, 6), savefile=None, axis_off=False, show_grid=True, **kwargs)[source]

Extract zones by filtering values against a threshold rule.

Parameters:
  • z (array-like, pandas.Series, or str) – Input data to filter. If z is a string, it is interpreted as a column name in data.

  • threshold ({'auto'} or float or int or tuple, optional) – Filtering criterion. Use 'auto' for percentile-based thresholding, a scalar for a single cutoff, or a length-2 tuple for interval filtering.

  • condition ({'auto', 'above', 'below', 'between'}, optional) – Relation between the data and the threshold.

  • use_negative_criteria (bool, optional) – Controls the automatic condition when condition='auto'.

  • percentile (int or float, optional) – Percentile used when threshold='auto'.

  • x (array-like, pandas.Series, str, or None, optional) – Optional coordinates or column names used for plotting.

  • y (array-like, pandas.Series, str, or None, optional) – Optional coordinates or column names used for plotting.

  • data (pandas.DataFrame or None, optional) – Data source used when x, y, or z are column names.

  • view (bool, optional) – Whether to visualize the filtered result.

  • plot_type (str, optional) – Plot type used when view=True. Common values include 'scatter', 'line', and 'hist'.

  • figsize (tuple, optional) – Figure size for plotting.

  • savefile (str or None, optional) – Optional path used when saving the figure.

  • axis_off (bool, optional) – Whether to hide axes in the plot.

  • show_grid (bool, optional) – Whether to display the plot grid.

  • **kwargs (dict) – Additional plotting keyword arguments.

Returns:

Filtered values and any optional plotting outputs defined by the implementation.

Return type:

object

Notes

When x, y, or z are passed as strings, the function relies on extract_array_from to retrieve the corresponding arrays from data.
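The threshold rules can be sketched with NumPy alone. This hypothetical zone_mask helper is not the library's code. It assumes that threshold='auto' means "use np.percentile(z, percentile) as the cutoff" and that the caller passes an explicit condition. It also assumes strict comparisons for 'above'/'below' and an inclusive interval for 'between'. How condition='auto' resolves through use_negative_criteria is not specified above, so the sketch leaves it out.

```python
import numpy as np


def zone_mask(z, threshold="auto", condition="below", percentile=10):
    """Hypothetical sketch of threshold-based zone filtering."""
    z = np.asarray(z, dtype=float)
    if isinstance(threshold, str) and threshold == "auto":
        threshold = np.percentile(z, percentile)   # data-driven cutoff
    if condition == "between":
        lo, hi = threshold                         # inclusive interval
        return (z >= lo) & (z <= hi)
    if condition == "above":
        return z > threshold
    if condition == "below":
        return z < threshold
    raise ValueError(f"unknown condition: {condition!r}")


# Toy subsidence rates (mm/yr); the 20th percentile cutoff keeps the strongest one.
z = [-30.0, -12.0, -5.0, 0.0, 2.0]
print(zone_mask(z, percentile=20))
print(zone_mask(z, threshold=(-6, 1), condition="between"))
```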

geoprior.utils.spatial_utils.filter_position(df, pos, pos_cols=None, find_closest=True, threshold=0.01, error='raise', verbose=0)[source]

filter_position is a utility that filters a pandas.DataFrame based on user-specified spatial positions. It can match positions exactly or compute distances to find the closest points within a threshold.

For a single dimension, the distance is computed as:

\[d = |x - p| \tag{17}\]

For multi-dimensional data with n coordinates, the Euclidean distance is computed as:

\[d = \sqrt{\sum_{i=1}^n (x_i - p_i)^2} \tag{18}\]

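Equations (17) and (18) translate directly into a vectorized NumPy filter. The helper nearest_within below is a hypothetical sketch, not filter_position itself. The text above does not say whether filter_position keeps every row within the threshold or only the single closest one, so this sketch keeps all of them.

```python
import numpy as np
import pandas as pd


def nearest_within(df, pos, pos_cols, threshold=0.01):
    """Hypothetical sketch: keep rows whose distance to pos is <= threshold."""
    coords = df[list(pos_cols)].to_numpy(dtype=float)
    target = np.asarray(pos, dtype=float)
    # Equation (18); with a single column it reduces to |x - p|, equation (17).
    dist = np.sqrt(((coords - target) ** 2).sum(axis=1))
    return df.loc[dist <= threshold]


df = pd.DataFrame({"lon": [113.309998, 113.5], "lat": [22.831362, 22.9]})
print(nearest_within(df, (113.31, 22.83), ("lon", "lat"), threshold=0.01))
```

As the threshold description notes, the distance is in the units of the columns (degrees here), so it is only a rough proxy for ground distance.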
Parameters:
  • df (pandas.DataFrame) – The DataFrame that will be filtered. This parameter is essential and must contain columns referenced by pos_cols if pos_cols is not None.

  • pos (float or tuple of floats) – The reference position(s) to match or approximate. When pos_cols is None, pos is interpreted as an index value. Otherwise, each value in pos aligns with a specific column in pos_cols.

  • pos_cols (str or tuple of str, optional) – Name(s) of the column(s) in df to match against pos. If pos_cols=None, then pos is treated as an index filter. If multiple columns are given (e.g., latitude and longitude), each coordinate in pos should correspond to one column in pos_cols.

  • find_closest (bool, optional) – If True, nearest-neighbor filtering is performed within the distance threshold. If False, exact matches are used.

  • threshold (float, optional) – The maximum distance within which points are considered a match if find_closest is True. The unit corresponds to the column data (e.g., degrees for geographic lat/lon).

  • error ({'raise', 'warn', 'ignore'}, optional) – Specifies how to handle dimension mismatches or missing values. If 'raise', a ValueError will be raised. If 'warn', a warning is printed and extra values are ignored. If 'ignore', mismatches are silently ignored.

  • verbose (int, optional) – Controls the level of output messages: 0 = no output, 1 = basic info, 2 = additional details, >=3 = comprehensive summary.

Returns:

A new DataFrame that contains only rows matching or approximating the specified position(s) within the given threshold if find_closest is True.

Return type:

pandas.DataFrame

Notes

When pos_cols is None, the function attempts to filter by DataFrame index using the first element of pos. This approach may fail for multi-level indexes unless error='warn' or error='ignore' is used to bypass the dimension mismatch.

Examples

>>> from geoprior.utils.spatial_utils import filter_position
>>> import pandas as pd
>>> df = pd.DataFrame({
...     'lon': [113.309998, 113.310001],
...     'lat': [22.831362, 22.831364]
... })
>>> # Exact match
>>> result_exact = filter_position(df, pos=(113.309998,
...                                         22.831362),
...                                pos_cols=('lon', 'lat'),
...                                find_closest=False)
>>> # Nearest match with threshold
>>> result_close = filter_position(df,
...                                pos=(113.31,
...                                     22.83),
...                                pos_cols=('lon',
...                                          'lat'),
...                                find_closest=True,
...                                threshold=0.01)

See also

geoprior.utils.data_utils.truncate_data

Truncate multiple DataFrames based on spatial coordinates or index alignment with a base DataFrame.

geoprior.utils.spatial_utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)[source]

Cluster 2D spatial data in df using the selected algorithm and optionally plot the results.

create_spatial_clusters extracts two coordinate columns from df and clusters them with 'kmeans', 'dbscan', or 'agglo' (agglomerative clustering). Where relevant, keyword arguments that the chosen estimator does not accept are removed with filter_valid_kwargs, and the resulting cluster labels are written to cluster_col.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.

  • spatial_cols (list of str, optional) – Two-column list for x and y coordinates. Defaults to ['longitude','latitude'] if None.

  • cluster_col (str, default 'region') – Name of the column to store the assigned cluster labels.

  • n_clusters (int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.

  • algorithm (str, default 'kmeans') – Choice of clustering algorithm among ['kmeans','dbscan','agglo'].

  • view (bool, default True) – If True, displays a scatterplot of the final clusters.

  • figsize (tuple, default (14, 10)) – Size of the displayed figure for the cluster plot.

  • s (int, default 60) – Marker size in the scatterplot.

  • plot_style (str, default 'seaborn') – Matplotlib style used for the plot.

  • cmap (str, default 'tab20') – Colormap name used to differentiate clusters.

  • show_grid (bool, default True) – Toggles grid lines on or off.

  • grid_props (dict, optional) – Additional keyword arguments controlling the grid style.

  • auto_scale (bool, default True) – If True, standardize coordinates before clustering.

  • savefile (str, optional) – File path where the resulting DataFrame, including the cluster_col labels, is saved.

  • verbose (int, default 1) – Controls console logs. Higher values yield more details about scaling and cluster detection.

  • **kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, and AgglomerativeClustering).

Returns:

A copy of df with an additional cluster_col column storing the assigned cluster labels.

Return type:

pandas.DataFrame

Notes

If auto_scale is True, the coordinate columns are standardized with a standard scaler before clustering. The scatterplot is drawn with seaborn.

For algorithm='kmeans' (the default), the model minimizes the within-cluster sum of squared distances:

\[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2 \tag{19}\]

where \(x_i\) are the 2D coordinates in df (standardized when auto_scale=True), \(\mu_j\) are the cluster centroids, and \(N\) is the number of rows. If n_clusters is not provided, the function can auto-detect it with silhouette and elbow analysis.
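The objective in equation (19) and the Lloyd iterations that decrease it fit in a few lines of NumPy. This is a standalone illustration, not create_spatial_clusters itself, which presumably delegates to scikit-learn's estimators. The helpers kmeans_objective and lloyd are hypothetical and work on raw coordinates, without the auto_scale step.

```python
import numpy as np


def kmeans_objective(X, centers):
    """Equation (19): sum of squared distances to the nearest centroid."""
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.min(axis=1).sum()


def lloyd(X, k, n_iter=20, seed=0):
    """Hypothetical sketch: plain Lloyd iterations; each step cannot increase J."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign each point to its nearest centroid.
        labels = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)
        # Move each centroid to the mean of its points (keep it if empty).
        centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
    return labels, centers


# The four points from the example below, as (longitude, latitude).
X = np.array([[0.1, 1.0], [0.2, 1.1], [2.2, 2.1], [2.3, 2.2]])
labels, centers = lloyd(X, k=2)
print(labels, kmeans_objective(X, centers))
```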

Examples

>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )

See also

filter_valid_kwargs

Helps discard unsupported keyword arguments for chosen estimators.

geoprior.utils.spatial_utils.gen_negative_samples(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, use_gpd='auto', view=False, savefile=None, verbose=1)[source]

Generate synthetic negative samples for spatial binary classification tasks.

This function creates additional samples labeled as non-events within a specified spatial buffer around the positive (event) observations. The main idea is to generate negative examples that reflect realistic conditions but have not triggered an event, thereby assisting models in distinguishing occurrences from non-occurrences.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame containing the positive samples (events). Must include both the target column and the specified spatial columns.

  • target_col (str) – Column name for the binary target (e.g. landslide). Rows where this column is 1 (or True) are considered positive samples.

  • spatial_cols (tuple of str, default ('longitude', 'latitude')) – Names of the longitude and latitude columns in df, in that order.

  • feature_cols (list of str or None, default None) – List of feature columns to use or to simulate for generated negatives. If None, the function automatically detects numeric and categorical columns excluding spatial_cols and target_col.

  • buffer_km (float, default 10) – Spatial buffer in kilometers used to define the radius around each positive sample within which negative samples are created.

  • neg_feature_range (tuple of int, default (0, 5)) – Value range (minimum, maximum) used for simulating numeric feature values in negative samples if the corresponding feature column does not exist in df.

  • num_neg_per_pos (int, default 1) – Number of negative samples to generate per positive sample. For instance, if num_neg_per_pos=2, each positive sample spawns two negatives.

  • use_gpd (str, default 'auto') – If set to ‘auto’, the function tries to import GeoPandas for visualization. If ‘none’, no GeoPandas usage will occur.

  • view (bool, default False) – Whether to visualize the generated samples on a map. Attempts to use geopandas if installed; falls back to matplotlib if ‘auto’ is chosen and GeoPandas is not available.

  • savefile (str or None, default None) – Path to which the resulting DataFrame is saved if provided. Handled by the decorator that wraps this function.

  • verbose (int, default 1) – Verbosity level. 0 for silent, 1 for progress indication, 2 for more messages, 3 for debugging output.

Returns:

Combined DataFrame with both the original positive samples and the newly generated negative samples. target_col is 1 for positives and 0 for negatives.

Return type:

pandas.DataFrame

columns_manager

This internal function is used to handle the processing of columns for features and spatial parameters.

Notes

  • If a feature column exists in df, the negative samples will copy or randomly select categories for categorical columns, and sample integers within neg_feature_range for numeric columns.

  • If feature_cols is empty or does not exist in df, the function simulates all values for negative samples.

  • When view=True, circles depicting the buffer zone around each positive sample are drawn for visualization.

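  • The 111 km-per-degree approximation is close for latitude everywhere, but one degree of longitude spans roughly 111 km × cos(latitude), so the east-west extent of the buffer in kilometers shrinks away from the equator.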

Mathematically, we define the spatial buffer in degrees as:

\[\Delta = \frac{\text{buffer\_km}}{111.0}, \tag{20}\]

where \(111.0\) km approximates the length of one degree of latitude (and of longitude at the equator). For each positive sample at location \((lat, lon)\), we generate \(n\) new points (\(n\) = num_neg_per_pos) with offsets \(\delta_{lat}\) and \(\delta_{lon}\), each drawn from a uniform distribution \(U(-\Delta, \Delta)\):

\[\begin{aligned} lat_{new} &= lat + \delta_{lat},\\ lon_{new} &= lon + \delta_{lon}. \end{aligned} \tag{21}\]

Combined with randomly sampled or inferred feature values, these new samples serve as negative examples for modeling tasks such as landslide prediction.
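Equations (20) and (21) can be sketched directly with NumPy's random generator. The helper buffered_offsets is hypothetical and covers only the coordinate step, not feature simulation. Independent per-axis offsets fill a square of half-width \(\Delta\) around each point, not a circle; whether the library additionally clips to the circular buffer drawn when view=True is not stated above.

```python
import numpy as np


def buffered_offsets(lat, lon, buffer_km=10.0, n=1, seed=None):
    """Hypothetical sketch of equations (20)-(21): n offset points per input point."""
    rng = np.random.default_rng(seed)
    delta = buffer_km / 111.0                      # equation (20)
    lat = np.repeat(np.asarray(lat, dtype=float), n)
    lon = np.repeat(np.asarray(lon, dtype=float), n)
    # Independent uniform offsets per axis, as in equation (21).
    d_lat = rng.uniform(-delta, delta, size=lat.shape)
    d_lon = rng.uniform(-delta, delta, size=lon.shape)
    return lat + d_lat, lon + d_lon


lat_new, lon_new = buffered_offsets([24.5, 24.7], [113.2, 113.8],
                                    buffer_km=10, n=3, seed=0)
print(lat_new.round(4), lon_new.round(4))
```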

Examples

>>> from geoprior.utils.spatial_utils import gen_negative_samples
>>> import pandas as pd
>>> import numpy as np
>>> df_pos = pd.DataFrame({
...        'latitude': np.random.uniform(24.0, 25.0, 5),
...        'longitude': np.random.uniform(113.0, 114.0, 5),
...        'rainfall_day_1': np.random.randint(10, 30, 5),
...        'rainfall_day_2': np.random.randint(10, 30, 5),
...        'rainfall_day_3': np.random.randint(10, 30, 5),
...        'rainfall_day_4': np.random.randint(10, 30, 5),
...        'rainfall_day_5': np.random.randint(10, 30, 5),
...        'landslide': 1
...    })
>>> combined = gen_negative_samples(
...     df=df_pos,
...     target_col='landslide',
...     buffer_km=10,
...     num_neg_per_pos=2,
...     view=False,
...     verbose=2
... )
>>> print(combined.head())

See also

check_spatial_columns

Ensures the existence of required spatial columns.

exist_features

Verifies the presence of specified features in <df>.

columns_manager

Handles both feature and spatial columns for processing.

geoprior.utils.spatial_utils.gen_buffered_negative_samples(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, strategy='landslide', gauge_data=None, use_gpd='auto', id_col='auto', view=False, savefile=None, seed=None, verbose=1)[source]

Generate buffer-based negative samples around existing points or gauge stations.

This function creates additional negative samples for binary spatial events. It uses either the event rows in df (strategy='landslide' or 'event') or a separate gauge dataset (strategy='gauge') as base points and generates negatives within a buffer of buffer_km around them. Input columns and parameters are validated with _validate_negatives_sampling before the synthetic samples are built.

Parameters:
  • df (pandas.DataFrame) – The DataFrame containing positive event samples (e.g., landslides). Must include target_col and spatial_cols.

  • target_col (str) – Name of the binary target column (1 for event, 0 for no event).

  • spatial_cols (tuple of str, default ('longitude', 'latitude')) – Names of the longitude and latitude columns in df, in that order.

  • feature_cols (list of str, optional) – Additional feature columns to simulate or copy for negatives. If None, all columns except spatial_cols and target_col are used.

  • buffer_km (float, default 10) – The radial distance in kilometers for sampling negative points around each base point.

  • neg_feature_range (tuple of float, default (0, 5)) – A numeric range from which feature values are drawn for negative samples if the column is numeric.

  • num_neg_per_pos (int, default 1) – Number of negatives to generate per positive (landslide) or gauge point.

  • strategy (str, default 'landslide') – Defines the base from which negative samples are generated. Use 'landslide' or 'event' to sample around rows in df. Use 'gauge' to sample around rows in gauge_data.

  • gauge_data (pandas.DataFrame, optional) – Required if strategy is 'gauge'. Must contain spatial_cols.

  • use_gpd (bool or 'auto', default 'auto') – If 'auto', attempts to use GeoPandas for visualization if installed. Otherwise, falls back to Matplotlib. This parameter is forwarded to the underlying visualization function.

  • id_col (str or list of str, default 'auto') – Column(s) representing IDs in df. If 'auto', the function tries to detect possible ID columns. Used by _validate_negatives_sampling.

  • view (bool, default False) – Whether to visualize the sampled negatives around the base points.

  • savefile (str, optional) – If provided, saves the final combined dataset (positives and negatives) to a CSV file at the specified path.

  • seed (int, optional) – Seed for NumPy’s random generator, ensuring reproducible offsets in negative sampling.

  • verbose (int, default 1) – Controls console messages: 1 for minimal, 2 for more detailed logs.

Returns:

The combined dataset containing both the original (positive) rows, labeled with target_col = 1, and the newly generated negative rows, labeled target_col = 0.

Return type:

pandas.DataFrame

_validate_negatives_sampling

Validates required columns and parameters, including num_neg_per_pos and neg_feature_range.

visualize_negative_sampling

Generates a plot showing the negative samples around the base points if view is True.

Notes

  • If strategy is 'gauge', gauge_data must be provided and contain columns longitude and latitude.

  • When view is True, circles are drawn to illustrate the buffer radius.

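  • The 111 km-per-degree factor is accurate for latitude; along longitude one degree spans roughly 111 km × cos(latitude), so the buffer is narrower east-west (in kilometers) away from the equator.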

Formally, a buffer in degrees \(\Delta\) is computed by:

\[\Delta = \frac{\text{buffer\_km}}{111}, \tag{22}\]

where \(111\) is an approximate km-per-degree conversion factor. Each base point \((lat, lon)\) spawns \(n\) negatives, each offset by \(\delta_{lat}\), \(\delta_{lon}\) drawn from \(U(-\Delta, \Delta)\).
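The strategy switch can be sketched as "pick base points, offset them, label them 0, append". The helper buffered_negatives below is hypothetical and only mirrors the behaviour documented above. The strategy names, the gauge_data requirement, equation (22) and the 1/0 labels come from this page. It leaves out feature simulation via neg_feature_range, so feature columns come out as NaN for the negatives.

```python
import numpy as np
import pandas as pd


def buffered_negatives(df, target_col, strategy="landslide", gauge_data=None,
                       spatial_cols=("longitude", "latitude"),
                       buffer_km=10.0, num_neg_per_pos=1, seed=None):
    """Hypothetical sketch: choose base points, offset them, label them 0."""
    if strategy in ("landslide", "event"):
        base = df[df[target_col] == 1]
    elif strategy == "gauge":
        if gauge_data is None:
            raise ValueError("strategy='gauge' requires gauge_data")
        base = gauge_data
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    rng = np.random.default_rng(seed)
    delta = buffer_km / 111                      # equation (22)
    neg = pd.DataFrame({
        col: np.repeat(base[col].to_numpy(dtype=float), num_neg_per_pos)
        for col in spatial_cols
    })
    for col in spatial_cols:
        # Independent U(-delta, delta) offset per coordinate.
        neg[col] += rng.uniform(-delta, delta, size=len(neg))
    neg[target_col] = 0
    # Feature simulation is omitted; feature columns are NaN for negatives.
    return pd.concat([df, neg], ignore_index=True)


pos = pd.DataFrame({"longitude": [113.1, 113.4], "latitude": [24.2, 24.6],
                    "landslide": [1, 1]})
gauges = pd.DataFrame({"longitude": [113.5], "latitude": [24.9]})
out = buffered_negatives(pos, "landslide", num_neg_per_pos=2, seed=0)
out_g = buffered_negatives(pos, "landslide", strategy="gauge",
                           gauge_data=gauges, num_neg_per_pos=3, seed=0)
print(out_g)
```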

Examples

Below is an illustration of how to generate negative samples around both existing event locations (strategy='landslide') and separate gauge stations (strategy='gauge') using gen_buffered_negative_samples.

First, we simulate a small DataFrame of positive landslide samples with rainfall attributes, as well as a separate DataFrame for gauge stations:

>>> import numpy as np
>>> import pandas as pd
>>> np.random.seed(42)
>>> positive_samples = pd.DataFrame({
...     'id': [1, 2, 3, 4, 5],
...     'latitude': np.random.uniform(24.0, 25.0, 5),
...     'longitude': np.random.uniform(113.0, 114.0, 5),
...     'rainfall_day_1': np.random.randint(10, 30, 5),
...     'rainfall_day_2': np.random.randint(10, 30, 5),
...     'rainfall_day_3': np.random.randint(10, 30, 5),
...     'rainfall_day_4': np.random.randint(10, 30, 5),
...     'rainfall_day_5': np.random.randint(10, 30, 5),
...     'landslide': [1]*5
... })
>>> gauge_data = pd.DataFrame({
...     'gauge_id': ['G1', 'G2', 'G3'],
...     'latitude': np.random.uniform(24.0, 25.0, 3),
...     'longitude': np.random.uniform(113.0, 114.0, 3)
... })

We then call gen_buffered_negative_samples to produce negatives around these data using two different strategies:

>>> from geoprior.utils.spatial_utils import gen_buffered_negative_samples
>>> # Generate negatives around landslide points
>>> results_landslide = gen_buffered_negative_samples(
...     df=positive_samples,
...     target_col='landslide',
...     spatial_cols=('longitude', 'latitude'),
...     feature_cols=[f'rainfall_day_{i+1}' for i in range(5)],
...     buffer_km=10,
...     num_neg_per_pos=1,
...     strategy='landslide',
...     verbose=1
... )
>>> # Generate negatives around the gauge stations
>>> results_gauge = gen_buffered_negative_samples(
...     df=positive_samples,
...     target_col='landslide',
...     spatial_cols=('longitude', 'latitude'),
...     feature_cols=[f'rainfall_day_{i+1}' for i in range(5)],
...     buffer_km=10,
...     num_neg_per_pos=1,
...     strategy='gauge',
...     gauge_data=gauge_data,
...     verbose=1
... )

See also

gen_negative_samples

Generate synthetic negative samples for spatial binary classification tasks.

_validate_negatives_sampling

Ensures inputs and parameters are correct.

visualize_negative_sampling

Plots the positive and negative points for inspection.

geoprior.utils.spatial_utils.gen_negative_samples_plus(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, strategy='landslide', gauge_data=None, elevation_data=None, similarity_features=None, time_col=None, cluster_method='kmeans', use_gpd='auto', id_col='auto', view=False, savefile=None, verbose=1, seed=None)[source]

Generate negative samples for spatial modeling with a choice of strategies. When strategy is 'landslide', 'event', or 'gauge', the function delegates to gen_buffered_negative_samples; the 'hybrid' strategy builds its negatives from partial calls to generate_negative_samples. Each strategy derives the negative instances from the positive samples (or the supplied reference data) in its own way.

\[\text{buffer\_deg} = \frac{\text{buffer\_km}}{111.0} \tag{23}\]

The above formula approximates degrees from kilometers near the equator. The output is a combined dataset containing original positives and generated negatives. The ratio of negatives per positive is controlled by num_neg_per_pos.

Parameters:
  • df (pandas.DataFrame) – Input data containing spatial coordinates and features.

  • target_col (str) – Name of the binary classification target column. Rows in df are the labeled positive samples; negatives are generated from them.

  • spatial_cols (tuple of str, optional) – Columns representing longitude and latitude in df.

  • feature_cols (list of str, optional) – Additional feature columns used to drive or constrain negative sampling processes.

  • buffer_km (float, optional) – Radius in kilometers for local negative sampling; converted to degrees with Eq. (23).

  • neg_feature_range (tuple of float, optional) – Lower and upper range for continuous features in random negative generation.

  • num_neg_per_pos (int, optional) – Number of negatives to generate per positive instance.

  • strategy (str, optional) –

    Defines the sampling approach. Options include 'landslide', 'gauge', 'random_global', 'temporal_shift', 'clustered_negatives', 'environmental_similarity', 'elevation_based', and 'hybrid'.

    See the User Guide for more details.

  • gauge_data (pandas.DataFrame, optional) – Reference data for gauge-based or hybrid strategies.

  • elevation_data (pandas.DataFrame, optional) – Elevation records, used if strategy='elevation_based'.

  • similarity_features (list of str, optional) – Columns for nearest neighbor computation in 'environmental_similarity'.

  • time_col (str, optional) – Name of the time column for 'temporal_shift'. Required if using that strategy.

  • cluster_method (str, optional) – Clustering algorithm for 'clustered_negatives'. Default is 'kmeans'.

  • use_gpd (bool or str, optional) – Indicator for whether geopandas is used in buffer-based processes.

  • id_col (str, optional) – Column name used as an identifier. If ‘auto’, a default is used.

  • view (bool, optional) – Flag for visualizing or previewing the results.

  • savefile (str, optional) – Path to save the output dataset. If None, no file is saved.

  • verbose (int, optional) – Verbosity level. Higher values yield more logs.

  • seed (int, optional) – Random seed for reproducibility. If None, randomness is not fixed.

Returns:

A combined DataFrame containing the original positive samples labeled as 1 and newly generated negative samples labeled 0.

Return type:

pandas.DataFrame

Notes

When strategy='hybrid', the negatives come from two separate calls to generate_negative_samples, one for the 'landslide' sub-strategy and one for the 'gauge' sub-strategy, and the two sets are then merged.

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.spatial_utils import gen_negative_samples_plus
>>> df_example = pd.DataFrame({
...     "longitude": [10.1, 10.2, 10.3],
...     "latitude":  [45.1, 45.2, 45.3],
...     "feature":   [3.4, 2.1, 6.7],
...     "target":    [1, 1, 1]
... })
>>> # Gauge table, used by the 'gauge' and 'hybrid' strategies
>>> gauge_data = pd.DataFrame({
...     'gauge_id': ['G1', 'G2', 'G3'],
...     'latitude': np.random.uniform(24.0, 25.0, 3),
...     'longitude': np.random.uniform(113.0, 114.0, 3)
... })
>>> # Generate random global negatives
>>> result = gen_negative_samples_plus(
...     df_example,
...     target_col="target",
...     strategy="random_global"
... )
>>> print(result.head())

See also

gen_buffered_negative_samples

Generates negative samples within a buffer region around reference events or gauges.

generate_negative_samples

A simpler negative sampling utility for certain strategies.

geoprior.utils.spatial_utils.extract_spatial_roi(df, x_range, y_range, x_col='longitude', y_col='latitude', snap_to_closest=True, savefile=None, **kwargs)[source]

Extracts a spatial Region of Interest (ROI) from a DataFrame.

This function filters a DataFrame to include only the data points that fall within a specified rectangular bounding box defined by x and y coordinate ranges.

Parameters:
  • df (pd.DataFrame) – The input DataFrame containing the spatial data.

  • x_range (tuple of (float, float)) – A tuple containing the minimum and maximum desired values for the x-coordinate (e.g., longitude). The order does not matter.

  • y_range (tuple of (float, float)) – A tuple containing the minimum and maximum desired values for the y-coordinate (e.g., latitude). The order does not matter.

  • x_col (str, default 'longitude') – The name of the column in df that contains the x-coordinates.

  • y_col (str, default 'latitude') – The name of the column in df that contains the y-coordinates.

  • snap_to_closest (bool, default True) – If True, and a value in x_range or y_range does not exist in the data, the function will “snap” to the nearest available coordinate in the dataset. If False, it will use the exact boundaries provided.

  • savefile (str, optional) – The path to save the resulting DataFrame as a CSV file.

Returns:

A new DataFrame containing only the rows that fall within the specified spatial bounding box.

Return type:

pd.DataFrame

Raises:

ValueError – If x_col or y_col are not found in the DataFrame, or if the range tuples are not provided correctly.
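A minimal pandas sketch of the documented ROI logic: bounds are sorted so their order does not matter, and, when snap_to_closest is True, each bound is snapped to the nearest coordinate present in the data. The snapping rule (nearest unique value per axis) is an assumption, and roi_sketch and _nearest are hypothetical helpers rather than geoprior code.

```python
import pandas as pd


def _nearest(values, target):
    """Value in `values` closest to `target` (first one wins on ties)."""
    return min(values, key=lambda v: abs(v - target))


def roi_sketch(df, x_range, y_range, x_col="longitude", y_col="latitude",
               snap_to_closest=True):
    """Hypothetical bounding-box filter mirroring extract_spatial_roi's docs."""
    for col in (x_col, y_col):
        if col not in df.columns:
            raise ValueError(f"Column {col!r} not found in DataFrame")
    x_min, x_max = sorted(x_range)  # bound order does not matter
    y_min, y_max = sorted(y_range)
    if snap_to_closest:
        # Assumed rule: snap each bound to the nearest observed coordinate.
        xs, ys = df[x_col].unique(), df[y_col].unique()
        x_min, x_max = _nearest(xs, x_min), _nearest(xs, x_max)
        y_min, y_max = _nearest(ys, y_min), _nearest(ys, y_max)
    mask = df[x_col].between(x_min, x_max) & df[y_col].between(y_min, y_max)
    return df.loc[mask].copy()


df = pd.DataFrame({"longitude": [1.0, 2.0, 3.0, 4.0],
                   "latitude": [10.0, 20.0, 30.0, 40.0]})
snapped = roi_sketch(df, (3.4, 1.4), (5.0, 45.0))
exact = roi_sketch(df, (3.4, 1.4), (5.0, 45.0), snap_to_closest=False)
```

Snapping can widen the box as well as shrink it: here the lower longitude bound 1.4 snaps outward to 1.0, so three rows are kept instead of two.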

geoprior.utils.spatial_utils.make_forecast_ready_sample(data, sample_size=0.05, time_col='year', spatial_cols=None, group_cols=None, stratify_by=None, spatial_bins=10, time_steps=3, forecast_horizon=1, require_consecutive=True, keep_years=None, year_mode='latest', min_groups=5, max_groups=None, columns_to_keep=None, method='abs', min_relative_ratio=0.01, random_state=42, export_path=None, export_format=None, savefile=None, sort_output=True, verbose=1)[source]

Build a compact, forecast-ready panel sample.

The function samples spatial groups rather than individual rows, then reconstructs the full panel for the selected groups. This is much safer for demo/testing than row-wise sampling because each sampled location keeps enough temporal history for sequence construction.

Parameters:
  • data (pandas.DataFrame) – Input panel DataFrame.

  • sample_size (float or int, default 0.05) – Group-level sample size passed to the internal spatial sampler: a float is a fraction of eligible groups; an int is an absolute number of eligible groups.

  • time_col (str, default 'year') – Time column.

  • spatial_cols (tuple/list of str, optional) – Spatial coordinate columns. If None, the function searches for ‘longitude’ and ‘latitude’.

  • group_cols (tuple/list of str, optional) – Group identifier columns. If None, uses spatial_cols.

  • stratify_by (list/tuple of str, optional) – Extra group-level columns used for stratification. Typical examples: [‘lithology_class’] or [‘city’, ‘lithology_class’].

  • spatial_bins (int or tuple/list, default 10) – Spatial bins passed to spatial_sampling().

  • time_steps (int, default 3) – Lookback window length.

  • forecast_horizon (int, default 1) – Forecast horizon length.

  • require_consecutive (bool, default True) – If True, each kept group must contain at least one consecutive run of length time_steps + forecast_horizon.

  • keep_years (int, optional) – If provided, keep only this many years per group after sampling. Must be >= time_steps + forecast_horizon.

  • year_mode ({'all', 'latest', 'earliest', 'random'}, default 'latest') – How to trim years when keep_years is given.

  • min_groups (int, default 5) – Minimum number of eligible groups required.

  • max_groups (int, optional) – Hard cap on sampled groups after spatial sampling.

  • columns_to_keep (list/tuple of str, optional) – Restrict the returned columns. Group and time columns are always preserved.

  • method ({'abs', 'absolute', 'relative'}, default 'abs') – Same meaning as in spatial_sampling().

  • min_relative_ratio (float, default 0.01) – Minimum relative sampling ratio when method=’relative’.

  • random_state (int, default 42) – Random seed.

  • export_path (str, optional) – Explicit path for non-CSV export.

  • export_format (str, optional) – Export format for export_path. If None, inferred from suffix.

  • savefile (str, optional) – CSV save path handled by @SaveFile.

  • sort_output (bool, default True) – Sort final output by group and time.

  • verbose (int, default 1) – Verbosity level.

Returns:

A compact panel sample that preserves group-wise temporal structure for forecast demos/tests.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.utils.spatial_utils import make_forecast_ready_sample
>>> # df: an existing panel DataFrame with longitude, latitude, year,
>>> # lithology_class, and (for the second call) city columns
>>> # Small demo sample, latest 4 years per location
>>> demo = make_forecast_ready_sample(
...     data=df,
...     sample_size=0.05,
...     stratify_by=["lithology_class"],
...     spatial_cols=("longitude", "latitude"),
...     time_col="year",
...     time_steps=3,
...     forecast_horizon=1,
...     keep_years=4,
...     year_mode="latest",
...     require_consecutive=True,
...     savefile="demo_panel.csv",
... )
>>> # Slightly richer test panel, keep all years, export parquet
>>> demo = make_forecast_ready_sample(
...     data=df,
...     sample_size=150,
...     stratify_by=["city", "lithology_class"],
...     spatial_cols=("longitude", "latitude"),
...     time_col="year",
...     time_steps=3,
...     forecast_horizon=1,
...     keep_years=None,
...     export_path="demo_panel.parquet",
...     export_format="parquet",
... )
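The group-level eligibility rule behind require_consecutive, and the float/int meaning of sample_size, can be sketched in plain Python. This assumes a panel represented as a mapping from group key to a list of integer years; has_consecutive_run, eligible_groups, and sample_groups are hypothetical helpers, and the sketch ignores stratification and spatial binning, which the real function delegates to spatial_sampling().

```python
import random


def has_consecutive_run(years, run_length):
    """True if the unique years contain at least run_length consecutive values."""
    best = current = 0
    previous = None
    for year in sorted(set(years)):
        current = current + 1 if previous is not None and year == previous + 1 else 1
        best = max(best, current)
        previous = year
    return best >= run_length


def eligible_groups(panel, time_steps, forecast_horizon):
    """Groups whose history can supply one lookback window plus the horizon."""
    need = time_steps + forecast_horizon
    return sorted(g for g, years in panel.items() if has_consecutive_run(years, need))


def sample_groups(panel, sample_size, time_steps, forecast_horizon, seed=42):
    """Sample whole groups: a float is a fraction, an int an absolute count."""
    groups = eligible_groups(panel, time_steps, forecast_horizon)
    if not groups:
        return {}
    if isinstance(sample_size, float):
        n = round(sample_size * len(groups))
    else:
        n = int(sample_size)
    n = max(1, min(n, len(groups)))
    chosen = random.Random(seed).sample(groups, n)
    # Rebuild the full history of each chosen group, not a row subset.
    return {g: sorted(panel[g]) for g in chosen}


panel = {
    "A": [2015, 2016, 2017, 2018],
    "B": [2015, 2017, 2019, 2021],
    "C": [2018, 2019, 2020, 2021, 2022],
}
```

With time_steps=3 and forecast_horizon=1, group B is dropped because it never has four consecutive years, even though it has four years in total. This is the failure mode that row-wise sampling would not catch.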

Together, these modules support spatial augmentation, sampling, clustering, dummy data generation, and degree-to-meter conversion, which are all important in Stage-1 construction, spatial forecasting, and map-oriented analysis.

Holdout helpers#

Utility helpers for holdout and split workflows.

class geoprior.utils.holdout_utils.GroupMasks(required_train_years, required_forecast_years, valid_for_train, valid_for_forecast)[source]

Bases: object

Group-level validity masks for early filtering.

Parameters:
  • required_train_years (list[int])

  • required_forecast_years (list[int])

  • valid_for_train (DataFrame)

  • valid_for_forecast (DataFrame)

property keep_for_processing: DataFrame

Union of the groups in valid_for_train and valid_for_forecast, that is, every group that is valid for training or for forecasting.

__init__(required_train_years, required_forecast_years, valid_for_train, valid_for_forecast)
Return type:

None

geoprior.utils.holdout_utils.compute_group_masks(df, *, group_cols, time_col, train_end_year, time_steps, horizon)[source]
Build two group-level masks, where T = time_steps and H = horizon:
  • valid_for_train: groups that contain every one of the last T + H years

  • valid_for_forecast: groups that contain every one of the last T years

This assumes annual steps and integer years in time_col.

Return type:

GroupMasks
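A stdlib sketch of the two masks, assuming the training window is the T + H years ending at train_end_year and the forecast window is the T years ending there. Those window positions are an assumption, and group_masks_sketch is a hypothetical helper that works on (group, year) pairs instead of DataFrames.

```python
from collections import defaultdict


def group_masks_sketch(rows, train_end_year, time_steps, horizon):
    """rows: iterable of (group_key, year) pairs with integer years."""
    years_by_group = defaultdict(set)
    for group, year in rows:
        years_by_group[group].add(int(year))

    # Assumed windows: both end at train_end_year (annual steps).
    train_years = set(range(train_end_year - (time_steps + horizon) + 1, train_end_year + 1))
    forecast_years = set(range(train_end_year - time_steps + 1, train_end_year + 1))

    valid_for_train = {g for g, ys in years_by_group.items() if train_years <= ys}
    valid_for_forecast = {g for g, ys in years_by_group.items() if forecast_years <= ys}
    return {
        "required_train_years": sorted(train_years),
        "required_forecast_years": sorted(forecast_years),
        "valid_for_train": valid_for_train,
        "valid_for_forecast": valid_for_forecast,
        # Mirrors GroupMasks.keep_for_processing: union of both masks.
        "keep_for_processing": valid_for_train | valid_for_forecast,
    }


rows = (
    [("A", y) for y in range(2017, 2021)]
    + [("B", y) for y in (2018, 2019, 2020)]
    + [("C", y) for y in (2017, 2019, 2020)]
)
masks = group_masks_sketch(rows, train_end_year=2020, time_steps=3, horizon=1)
```

Group B is usable for forecasting but not for training, and group C has a gap in 2018, so it is dropped from both masks.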

geoprior.utils.holdout_utils.filter_df_by_groups(df, *, group_cols, groups)[source]

Keep only the rows of df whose group_cols key combination appears in the groups DataFrame.

Return type:

DataFrame
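The same row filter can be sketched with an inner merge on the group keys, assuming groups holds the key columns. filter_by_groups_sketch is hypothetical. Note that merge returns a fresh RangeIndex, which the library function may or may not do.

```python
import pandas as pd


def filter_by_groups_sketch(df, group_cols, groups):
    """Keep rows of df whose group_cols combination appears in groups."""
    cols = list(group_cols)
    keys = groups[cols].drop_duplicates()  # avoid duplicating rows of df
    return df.merge(keys, on=cols, how="inner")


df = pd.DataFrame({
    "longitude": [1.0, 1.0, 2.0, 3.0],
    "latitude": [5.0, 5.0, 6.0, 7.0],
    "year": [2019, 2020, 2020, 2020],
})
groups = pd.DataFrame({"longitude": [1.0, 3.0], "latitude": [5.0, 7.0]})
kept = filter_by_groups_sketch(df, ("longitude", "latitude"), groups)
```

Deduplicating the keys first matters: if groups listed the same location twice, a plain merge would duplicate that location's rows.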

class geoprior.utils.holdout_utils.HoldoutSplit(train_groups, val_groups, test_groups)[source]

Bases: object

Pixel holdout split (disjoint groups).

Parameters:
  • train_groups (DataFrame)

  • val_groups (DataFrame)

  • test_groups (DataFrame)
check_disjoint()[source]
Return type:

None

__init__(train_groups, val_groups, test_groups)
Return type:

None

geoprior.utils.holdout_utils.split_groups_holdout(groups, *, seed=42, val_frac=0.2, test_frac=0.1, strategy='random', x_col=None, y_col=None, block_size=None)[source]

Split unique groups into train/val/test (pixel-level holdout).

strategy:
  • “random”: shuffle groups

  • “spatial_block”: shuffle spatial blocks (requires x_col, y_col, and block_size)

Return type:

HoldoutSplit
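A sketch of the two strategies, assuming groups are (x, y) coordinate keys, that the test and validation sizes are rounded from the fractions, and that a spatial block is the grid cell (floor(x / block_size), floor(y / block_size)). split_groups_sketch is hypothetical and returns plain lists rather than a HoldoutSplit.

```python
import math
import random


def _split_counts(n, val_frac, test_frac):
    """Number of units assigned to validation and test."""
    return int(round(n * val_frac)), int(round(n * test_frac))


def split_groups_sketch(groups, seed=42, val_frac=0.2, test_frac=0.1,
                        strategy="random", block_size=None):
    groups = list(dict.fromkeys(groups))  # unique, order-preserving
    rng = random.Random(seed)
    if strategy == "random":
        units = [[g] for g in groups]  # every group moves on its own
    elif strategy == "spatial_block":
        if block_size is None:
            raise ValueError("spatial_block requires block_size")
        blocks = {}
        for x, y in groups:
            key = (math.floor(x / block_size), math.floor(y / block_size))
            blocks.setdefault(key, []).append((x, y))
        units = list(blocks.values())  # whole blocks move together
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")

    rng.shuffle(units)
    n_val, n_test = _split_counts(len(units), val_frac, test_frac)
    test = [g for u in units[:n_test] for g in u]
    val = [g for u in units[n_test:n_test + n_val] for g in u]
    train = [g for u in units[n_test + n_val:] for g in u]
    return train, val, test


grid = [(float(i), float(j)) for i in range(10) for j in range(10)]
train, val, test = split_groups_sketch(grid, seed=0)
b_train, b_val, b_test = split_groups_sketch(
    grid, seed=0, strategy="spatial_block", block_size=5.0
)
```

In the block variant the fractions apply to blocks, not to groups, so with only four blocks the test split can end up empty. Coarse blocks trade split balance for reduced spatial leakage between train and evaluation sets.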

These utilities are especially relevant when the workflow needs group-aware, spatially constrained, or otherwise structured holdout logic instead of purely random splitting.

IO, system, and runtime helpers#

Input/Output utilities for managing file paths, directories, and loading serialized data within FusionLab. Provides error-checked deserialization, directory management, and archive handling (e.g., .tgz, .zip), streamlining file operations and data recovery.

Adapted for FusionLab from the original geoprior.utils.io_utils.

class geoprior.utils.io_utils.FileManager(root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False)[source]

Bases: BaseClass

A class for managing and organizing files within a directory structure. This class provides methods to filter, organize, and rename files in bulk based on file extensions and name patterns. All operations are executed via the run method to ensure proper initialization and state management.

Mathematically, if \(\mathcal{F}\) represents the set of files in the root directory and \(\phi(f)\) is a filtering function that selects files based on file type and name pattern, then the FileManager produces a subset

\[\mathcal{F}' = \{ f \in \mathcal{F} \mid \phi(f) \} \tag{24}\]

and performs operations such as moving or copying to reorganize these files into a target directory.
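A minimal pathlib sketch of the filter φ(f) in Eq. (24), assuming a file passes when its extension is in file_types and its name contains any of name_patterns, with a recursive walk of root_dir. Whether FileManager matches any or all patterns, or walks subdirectories, is not stated here; phi and select_files are hypothetical helpers.

```python
import tempfile
from pathlib import Path


def phi(path, file_types=None, name_patterns=None):
    """Selection predicate: extension filter AND name-substring filter."""
    type_ok = file_types is None or path.suffix.lower() in {t.lower() for t in file_types}
    name_ok = name_patterns is None or any(p in path.name for p in name_patterns)
    return type_ok and name_ok


def select_files(root_dir, file_types=None, name_patterns=None):
    """Return F' = {f in F | phi(f)} as sorted paths relative to root_dir."""
    root = Path(root_dir)
    return sorted(
        str(p.relative_to(root))
        for p in root.rglob("*")
        if p.is_file() and phi(p, file_types, name_patterns)
    )


with tempfile.TemporaryDirectory() as tmp:
    for name in ("report_2023.csv", "notes_2023.txt", "old.json",
                 "summary.csv", "report.json"):
        (Path(tmp) / name).write_text("x")
    selected = select_files(tmp, file_types=[".csv", ".json"],
                            name_patterns=["2023", "report"])
```

Only files that pass both filters survive: notes_2023.txt fails the extension test, while old.json and summary.csv fail the name test.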

Parameters:
  • root_dir (str) – The root directory containing the files to be managed. This directory must exist and contain the files subject to filtering.

  • target_dir (str) – The directory where the organized files will be placed. If necessary, this directory can be created when create_dirs is True.

  • file_types (list of str, optional) – A list of file extensions (e.g., ['.csv', '.json']) used to filter the files. If None, no file type filtering is applied.

  • name_patterns (list of str, optional) – A list of substrings (e.g., ['2023', 'report']) to filter file names. If None, all file names are included.

  • move (bool, optional) – If True, files are moved from the source to the target directory; otherwise, they are copied. Default is False.

  • overwrite (bool, optional) – If True, existing files in the target directory will be overwritten. If False, existing files are skipped. Default is False.

  • create_dirs (bool, optional) – If True, missing directories in the target path are created. Default is False.

Variables:
  • root_dir (str) – The validated root directory from which files are managed.

  • target_dir (str) – The directory where the processed files are stored.

run(pattern, replacement)[source]

Executes the file organization process. It filters files using the criteria provided at initialization and, if a pattern and corresponding replacement are given, performs bulk renaming.

Parameters:
  • pattern (str | None)

  • replacement (str | None)

get_processed_files()[source]

Returns a list of file paths that have been processed and organized into the target directory.

Return type:

list[str]

Examples

>>> from geoprior.utils.io_utils import FileManager
>>> manager = FileManager(
...     root_dir='data/raw',
...     target_dir='data/processed',
...     file_types=['.csv', '.json'],
...     name_patterns=['2023', 'report'],
...     move=True,
...     overwrite=True,
...     create_dirs=True
... )
>>> manager.run(pattern='old', replacement='new')
>>> processed = manager.get_processed_files()
>>> print(processed)

Notes

The public method run orchestrates the file management operations by first calling the internal method _organize_files() to filter and move or copy files from the source directory to the target directory. If renaming is needed, _rename_files() is invoked with the specified pattern and replacement. The method get_processed_files() compiles a list of all files that have been organized, based on a walk of the target directory. The directory traversal and file-operation APIs are documented in [29, 30].

See also

shutil.move

To move files between directories.

shutil.copy2

To copy files while preserving file metadata.

__init__(root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False)[source]

Initialize the base class.

Parameters:
  • verbose (int, optional) – Verbosity level controlling logging (0 to 3). Defaults to 0.

  • root_dir (str)

  • target_dir (str)

  • file_types (list[str] | None)

  • name_patterns (list[str] | None)

  • move (bool)

  • overwrite (bool)

  • create_dirs (bool)

run(pattern=None, replacement=None)[source]

Executes file organization operations.

This method filters files based on the specified file types and name patterns, then organizes them by moving or copying into the target directory. Additionally, if a pattern is provided, file names containing that pattern are renamed by replacing the pattern with the specified replacement.

Parameters:
  • pattern (str, optional) – The substring to search for in file names. If provided, file names containing this pattern will be renamed.

  • replacement (str, optional) – The string to replace pattern with in file names. Required if pattern is specified.

Returns:

self – The instance itself after executing operations.

Return type:

FileManager

Examples

>>> manager = FileManager(...)
>>> manager.run(pattern='old', replacement='new')

get_processed_files()[source]

Retrieves a list of processed files in the target directory.

Returns:

files – A list containing the full paths of the files that have been organized into the target directory.

Return type:

list of str

Examples

>>> manager = FileManager(...)
>>> manager.run()
>>> files = manager.get_processed_files()
>>> print(files)

fit(pattern=None, replacement=None)

Executes file organization operations.

This method filters files based on the specified file types and name patterns, then organizes them by moving or copying into the target directory. Additionally, if a pattern is provided, file names containing that pattern are renamed by replacing the pattern with the specified replacement.

Parameters:
  • pattern (str, optional) – The substring to search for in file names. If provided, file names containing this pattern will be renamed.

  • replacement (str, optional) – The string to replace pattern with in file names. Required if pattern is specified.

Returns:

self – The instance itself after executing operations.

Return type:

FileManager

Examples

>>> manager = FileManager(...)
>>> manager.run(pattern='old', replacement='new')

help(**kwargs)

my_params = FileManager(root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False)

geoprior.utils.io_utils.cpath(savepath=None, dpath='_default_path_')[source]

Ensures a directory exists for saving files, creating it if necessary.

Parameters:
  • savepath (str, optional) – The target directory to validate or create. If None, dpath is used as the directory.

  • dpath (str, default '_default_path_') – Default directory created in the current working directory if savepath is None.

Returns:

The absolute path to the validated or created directory.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import cpath
>>> default_path = cpath()
>>> print(f"Files will be saved to: {default_path}")
>>> custom_path = cpath('/path/to/save')
>>> print(f"Files will be saved to: {custom_path}")

Notes

cpath validates the directory path and, if necessary, creates the directory tree. If a problem occurs during creation, an error message is printed.

See also

pathlib.Path.mkdir

Utility for directory creation.
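A pathlib sketch of the behavior described above, assuming creation errors are reported with print rather than raised, as the Notes suggest. cpath_sketch is hypothetical; it is not the geoprior implementation.

```python
import tempfile
from pathlib import Path


def cpath_sketch(savepath=None, dpath="_default_path_"):
    """Return an absolute directory path, creating the tree if needed."""
    target = Path(savepath) if savepath is not None else Path.cwd() / dpath
    try:
        target.mkdir(parents=True, exist_ok=True)  # no error if it exists
    except OSError as exc:
        print(f"Could not create {target}: {exc}")
    return str(target.resolve())


with tempfile.TemporaryDirectory() as tmp:
    created = cpath_sketch(str(Path(tmp) / "runs" / "stage1"))
    created_exists = Path(created).is_dir()
    again = cpath_sketch(created)  # idempotent on an existing directory
```

parents=True builds the intermediate "runs" directory, and exist_ok=True makes a second call on the same path a no-op.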

geoprior.utils.io_utils.deserialize_data(filename, verbose=0)[source]

Deserialize and load data from a serialized file using joblib or pickle.

The function attempts to load the serialized data from the provided file filename using joblib first. If joblib fails, it tries to load the data using pickle. An error is raised if both methods fail.

Parameters:
  • filename (str) – The name or path of the file containing the serialized data. This file is expected to be in a compatible format with either joblib or pickle.

  • verbose (int, optional) – Verbosity level. Messages indicating loading progress will be displayed if verbose is greater than 0.

Returns:

The data loaded from the serialized file, or None if loading fails.

Return type:

Any

Raises:
  • TypeError – If filename is not a string, as file paths must be provided as strings.

  • FileNotFoundError – If the specified filename does not exist or cannot be located.

  • IOError – If both joblib and pickle fail to deserialize the data from the file.

  • ValueError – If the file was successfully read but yielded no data (i.e., None).

Examples

>>> from geoprior.utils.io_utils import deserialize_data
>>> data = deserialize_data('path/to/serialized_data.pkl', verbose=1)
Data loaded successfully from 'path/to/serialized_data.pkl' using joblib.

Notes

The function first attempts deserialization with joblib to leverage efficient file handling for large datasets. If joblib encounters an error, it falls back to pickle, which provides broader compatibility with Python objects but may be less optimized for large datasets. Loader semantics for the two backends are documented in [31, 32].

See also

joblib.load

Joblib’s load function for fast I/O operations on large data.

pickle.load

Pickle’s load function for serializing and deserializing Python objects.
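The joblib-then-pickle fallback can be sketched as follows, assuming joblib is optional and that a None payload raises ValueError, as the Raises section states. load_with_fallback is a hypothetical helper that also reports which backend succeeded.

```python
import os
import pickle
import tempfile

try:
    import joblib  # optional fast path
except ImportError:
    joblib = None


def _pickle_load(filename):
    with open(filename, "rb") as fh:
        return pickle.load(fh)


def load_with_fallback(filename):
    """Try joblib first, then pickle; return (data, backend_name)."""
    if not isinstance(filename, str):
        raise TypeError("filename must be a string")
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    loaders = [("pickle", _pickle_load)]
    if joblib is not None:
        loaders.insert(0, ("joblib", joblib.load))
    for name, loader in loaders:
        try:
            data = loader(filename)
        except Exception:
            continue  # fall through to the next backend
        if data is None:
            raise ValueError(f"{filename!r} was read but contained no data")
        return data, name
    raise IOError(f"joblib and pickle both failed to load {filename!r}")


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "payload.pkl")
    with open(path, "wb") as fh:
        pickle.dump({"a": [1, 2, 3]}, fh)
    demo_data, demo_backend = load_with_fallback(path)
```

Either backend can read a plain pickle file, so the reported backend depends only on whether joblib is installed.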

geoprior.utils.io_utils.extract_tar_with_progress(tar, member, path)[source]

Extracts a single file from a tar archive with a progress bar.

Parameters:
  • tar (tarfile.TarFile) – Opened tar file object.

  • member (tarfile.TarInfo) – Tar member (file) to be extracted.

  • path (Path) – Directory path where the file will be extracted.

Examples

>>> import tarfile
>>> from pathlib import Path
>>> from geoprior.utils.io_utils import extract_tar_with_progress
>>> with tarfile.open('data.tar.gz', 'r:gz') as tar:
...     member = tar.getmember('file.csv')
...     extract_tar_with_progress(tar, member, Path('output_dir'))

Notes

Uses tqdm for progress tracking of the file extraction process.

geoprior.utils.io_utils.fetch_tgz_from_url(data_url, tgz_filename, data_path=None, file_to_retrieve=None, **kwargs)[source]

Downloads a .tgz file from a specified URL, saves it to a directory, and optionally extracts a specific file from the archive.

This function retrieves a .tgz file from the provided data_url and saves it to the specified data_path directory. If file_to_retrieve is specified, the function will extract only that file from the archive; otherwise, the entire archive will be extracted.

Parameters:
  • data_url (str) – The URL to download the .tgz file from.

  • tgz_filename (str) – The name to assign to the downloaded .tgz file.

  • data_path (Union[str, Path], optional) – Directory where the downloaded file will be saved. Defaults to a ‘tgz_data’ directory in the current working directory if not specified.

  • file_to_retrieve (str, optional) – Specific filename to extract from the .tgz archive. If not provided, the entire archive is extracted.

  • **kwargs (dict) – Additional keyword arguments to pass to the extraction method.

Returns:

Path to the extracted file if a specific file was requested; otherwise, returns None.

Return type:

Optional[Path]

Raises:

FileNotFoundError – If the specified file_to_retrieve is not found in the archive.

Examples

>>> from geoprior.utils.io_utils import fetch_tgz_from_url
>>> data_url = 'https://example.com/data.tar.gz'
>>> extracted_file = fetch_tgz_from_url(
...     data_url, 'data.tar.gz', data_path='data_dir', file_to_retrieve='file.csv')
>>> print(extracted_file)

Notes

Uses the tqdm progress bar for tracking download progress.

geoprior.utils.io_utils.fetch_tgz_locally(tgz_file, filename, savefile='tgz', rename_outfile=None)[source]

Extracts a specific file from a local .tgz archive and optionally renames it.

This function fetches a specific file filename from a local tar archive located at tgz_file, and saves it to savefile. If rename_outfile is specified, the file is renamed after extraction.

Parameters:
  • tgz_file (str) – Full path to the tar file.

  • filename (str) – Name of the target file to extract from the archive.

  • savefile (str, optional) – Destination directory for the extracted file, defaulting to ‘tgz’.

  • rename_outfile (str, optional) – New name for the fetched file. If not provided, retains the original name.

Returns:

Full path to the fetched and possibly renamed file.

Return type:

str

Raises:

FileNotFoundError – If the tgz_file or the specified filename is not found.

Examples

>>> from geoprior.utils.io_utils import fetch_tgz_locally
>>> fetched_file = fetch_tgz_locally(
...     'path/to/archive.tgz', 'file.csv', savefile='extracted', rename_outfile='renamed.csv')
>>> print(fetched_file)
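
A stdlib sketch of the single-member extraction and optional rename described for fetch_tgz_locally, assuming the member is matched by its base name. It streams the member to the destination with shutil.copyfileobj instead of calling extractall, which avoids path-traversal problems with untrusted archives. extract_member_sketch is hypothetical, and the tqdm progress bar is omitted.

```python
import shutil
import tarfile
import tempfile
from pathlib import Path


def extract_member_sketch(tgz_file, filename, savefile="tgz", rename_outfile=None):
    """Extract one file (matched by base name) and optionally rename it."""
    out_dir = Path(savefile)
    out_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tgz_file, "r:*") as tar:
        member = next(
            (m for m in tar.getmembers() if Path(m.name).name == filename), None
        )
        if member is None or not member.isfile():
            raise FileNotFoundError(f"{filename!r} not found in {tgz_file!r}")
        target = out_dir / (rename_outfile or Path(member.name).name)
        # Stream the bytes ourselves so the archive path is never trusted.
        with tar.extractfile(member) as src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return str(target)


with tempfile.TemporaryDirectory() as tmp:
    src_file = Path(tmp) / "file.csv"
    src_file.write_text("a,b\n1,2\n")
    archive = Path(tmp) / "archive.tgz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src_file, arcname="data/file.csv")
    out = extract_member_sketch(str(archive), "file.csv",
                                savefile=str(Path(tmp) / "extracted"),
                                rename_outfile="renamed.csv")
    out_name = Path(out).name
    out_text = Path(out).read_text()
```

The member is stored under data/file.csv inside the archive but is found by its base name and written flat into the destination directory.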
geoprior.utils.io_utils.dummy_csv_translator(csv_fn, pf, delimiter=':', destfile='pme.en.csv')[source]

Translate a CSV file using a dictionary created from a markdown-style parser file.

Parameters:
  • csv_fn (str) – Path to the source CSV file.

  • pf (str) – Path to the markdown-style file used to create the translation dictionary.

  • delimiter (str, default ':') – Delimiter used in the parser file to separate key-value pairs.

  • destfile (str, default 'pme.en.csv') – Name of the destination file for the translated CSV.

Returns:

  • DataFrame – Translated CSV data as a DataFrame.

  • list – List of untranslated terms found in the source CSV.

Notes

  • This function uses parse_md_data to read the parser file and apply translations to the CSV content.

  • Missing translations are collected and returned for review.

Examples

>>> from geoprior.utils.io_utils import dummy_csv_translator
>>> df, missing = dummy_csv_translator(
...     "data.csv", "parser_file.md", delimiter=":", destfile="output.csv")
>>> print(df.head())
>>> print(missing)

geoprior.utils.io_utils.fetch_json_data_from_url(url, todo='load')[source]

Retrieve and parse JSON data from a URL.

Parameters:
  • url (str) – Uniform Resource Locator (URL) from which the JSON data is fetched.

  • todo ({'load', 'dump'}, default 'load') – Action to perform with the JSON data: ‘load’ loads the JSON data from the URL; ‘dump’ parses and prepares the data from the URL for saving to a JSON file.

Returns:

A tuple of todo action, filename (or data source), and parsed data.

Return type:

tuple

Raises:

urllib.error.URLError – If there is an issue accessing the URL.

Notes

The function uses json.loads to parse data directly from a URL response, supporting convenient access to web-hosted JSON content.

geoprior.utils.io_utils.get_config_fname_from_varname(data, config_fname=None, config='.yml')[source]

Generate a filename based on a variable name for YAML configuration.

Parameters:
  • data (Any) – The data object from which the variable name will be derived to create a YAML configuration filename.

  • config_fname (str, optional) – Custom configuration filename. If None, the name of data will be used as the filename.

  • config (str, default '.yml') – The file extension/type for the configuration file. Can be ‘.yml’, ‘.json’, or ‘.csv’.

Returns:

A suitable filename for saving the configuration data.

Return type:

str

Raises:

ValueError – If config_fname cannot be derived or an invalid file type is provided.

Notes

This function supports dynamic filename generation based on variable names, which aids in maintaining a clear configuration structure for serialized data. Files are saved with appropriate extensions based on the config type.

geoprior.utils.io_utils.get_valid_key(input_key, default_key, substitute_key_dict=None, regex_pattern='[#&*@!,;\\s]\\s*', deep_search=True)[source]

Validates an input key and substitutes it with a valid key if necessary, based on a mapping of valid keys to their possible substitutes. If the input key is not provided or is invalid, a default key is used.

Parameters:
  • input_key (str) – The key to validate and possibly substitute.

  • default_key (str) – The default key to use if input_key is None, empty, or not found in the substitute mapping.

  • substitute_key_dict (dict, optional) – A mapping of valid keys to lists of their possible substitutes. This allows for flexible key substitution and validation.

  • regex_pattern (str, default '[#&*@!,;\s]\s*') – The base pattern used to split the text into tokens.

  • deep_search (bool, default True) – If True, the key finder is not sensitive to letter case or to whether numeric characters are included.

Returns:

A valid key, which is either the original input_key if valid, a substituted key if the original was found in the substitute mappings, or the default_key.

Return type:

str

Notes

This function also uses key_checker for deep-search validation, ensuring that the returned key is within the set of valid keys.

Example

>>> from geoprior.utils.io_utils import get_valid_key
>>> substitute_key_dict = {'valid_key1': ['vk1', 'key1'], 'valid_key2': ['vk2', 'key2']}
>>> get_valid_key('vk1', 'default_key', substitute_key_dict)
'valid_key1'
>>> get_valid_key('unknown_key', 'default_key', substitute_key_dict)
'KeyError...'

geoprior.utils.io_utils.key_checker(keys, valid_keys, regex=None, pattern=None, deep_search=False)[source]

Check whether a given key exists in valid_keys; return a list if several keys are found.

Parameters:
  • keys (str, list of str) – Key value to find in the valid_keys

  • valid_keys (list) – List of valid keys by default.

  • regex (re.Pattern, optional) –

    Regular expression object. The default is:

    >>> import re
    >>> re.compile (r'[_#&*@!_,;\s-]\s*', flags=re.IGNORECASE)
    

  • pattern (str, default '[_#&*@!_,;\s-]\s*') – The base pattern used to split the text into tokens.

  • deep_search (bool, default False) – If True, the key finder is not sensitive to letter case or to whether numeric characters are included.

Returns:

keys – The key, or list of keys, found in valid_keys.

Return type:

str or list of str

Examples

>>> from geoprior.utils.io_utils import key_checker
>>> key_checker('h502', valid_keys=['h502', 'h253', 'h2601'])
'h502'
>>> key_checker('h502+h2601', valid_keys=['h502', 'h253', 'h2601'])
['h502', 'h2601']
>>> key_checker('h502 h2601', valid_keys=['h502', 'h253', 'h2601'])
['h502', 'h2601']
>>> key_checker(['h502', 'h2601'], valid_keys=['h502', 'h253', 'h2601'])
['h502', 'h2601']
>>> key_checker(['h502', 'h2602'], valid_keys=['h502', 'h253', 'h2601'])
UserWarning: key 'h2602' is missing in ['h502', 'h2602']
'h502'
>>> key_checker(['502', 'H2601'], valid_keys=['h502', 'h253', 'h2601'],
...             deep_search=True)
['h502', 'h2601']

geoprior.utils.io_utils.key_search(keys, default_keys, parse_keys=True, regex=None, pattern=None, deep=Ellipsis, raise_exception=Ellipsis)[source]

Find key in a list of default keys and select the best match.

Parameters:
  • keys (str or list) – A key or a list of keys. When several keys are passed as one string, separate them with spaces.

  • default_keys (str or list) – The candidate keys to search. Can be literal text; in that case, providing regex is recommended so that separator characters are skipped and the text is parsed properly.

  • parse_keys (bool, default True) –

    Parse literal strings using the default pattern and regex.

    Added in version 0.2.7.

  • regex (re.Pattern, optional) –

    Regular expression object. It specifies how the data are parsed. The default is:

    >>> import re
    >>> re.compile (r'[_#&*@!_,;\s-]\s*', flags=re.IGNORECASE)
    

  • pattern (str, default '[_#&*@!_,;\s-]\s*') – The base pattern used to split the text into tokens. The pattern matters when a character that the default treats as a separator is actually part of a word. For example, the default pattern splits a column named ‘DH_Azimuth’ into two separate words, ‘DH’ and ‘Azimuth’, unless a pattern that excludes the underscore is provided explicitly.

  • deep (bool, default False) – If True, matching is not case-sensitive.

  • raise_exception (bool, default False) – If True, raise an error when no key is found.
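
The DH_Azimuth case described for the pattern parameter can be checked directly with Python's re module. The first pattern is the documented default; the second is an illustrative custom pattern that simply drops the underscore from the separator class.

```python
import re

DEFAULT_PATTERN = r"[_#&*@!_,;\s-]\s*"  # documented default separators
NO_UNDERSCORE = r"[#&*@!,;\s-]\s*"      # illustrative: keep "_" inside words

split_default = re.split(DEFAULT_PATTERN, "DH_Azimuth")
split_custom = re.split(NO_UNDERSCORE, "DH_Azimuth")
split_keys = re.split(DEFAULT_PATTERN, "h502-hh2601")

print(split_default)  # ['DH', 'Azimuth']
print(split_custom)   # ['DH_Azimuth']
print(split_keys)     # ['h502', 'hh2601']
```

The same default pattern is what turns 'h502-hh2601' into two candidate keys before they are matched against default_keys.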

Returns:

List of valid keys, or None if no key is found (default).

Return type:

list or None

Examples

>>> from geoprior.utils.io_utils import key_search
>>> key_search('h502-hh2601', default_keys=['h502', 'h253', 'HH2601'])
['h502']
>>> key_search('h502-hh2601', default_keys=['h502', 'h253', 'HH2601'],
...            deep=True)
['h502', 'HH2601']
>>> key_search('253', default_keys=("I m here to find key among h502, "
...                                 "h253 and HH2601"))
['h253']
>>> key_search('east', default_keys=['DH_East', 'DH_North'], deep=True)
['East']
>>> key_search('east', default_keys=['DH_East', 'DH_North'],
...            deep=True, parse_keys=False)
['DH_East']

geoprior.utils.io_utils.load_serialized_data(filename, verbose=0)[source]

Load data from a serialized file (e.g., pickle or joblib format).

Parameters:
  • filename (str) – Name of the file to load data from.

  • verbose (int, default 0) – Verbosity level controlling the amount of output: 0 prints nothing; values above 2 print detailed loading messages.

Returns:

Data loaded from the file, or None if deserialization fails.

Return type:

Any

Raises:

Examples

>>> from geoprior.utils.io_utils import load_serialized_data
>>> data = load_serialized_data('data/my_data.pkl', verbose=3)

Notes

This function attempts to load serialized data using joblib and falls back to pickle if needed. Verbose output provides feedback on the loading process and on the success or failure of each step.
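The joblib-then-pickle fallback can be sketched with the standard library. This is a minimal illustration, not the GeoPrior implementation: load_with_fallback is a hypothetical name, joblib is treated as an optional import, and any exception from joblib.load triggers the pickle attempt.

```python
import os
import pickle
import tempfile

try:
    import joblib  # optional: tried first when installed
except ImportError:
    joblib = None


def load_with_fallback(filename, verbose=0):
    """Try joblib.load, then pickle.load; return None if both fail."""
    if joblib is not None:
        try:
            return joblib.load(filename)
        except Exception as exc:  # any joblib failure triggers the fallback
            if verbose > 2:
                print(f"joblib could not load {filename!r}: {exc}")
    try:
        with open(filename, "rb") as fh:
            return pickle.load(fh)
    except Exception as exc:
        if verbose > 2:
            print(f"pickle could not load {filename!r}: {exc}")
        return None


# Usage: round-trip a small dict through a temporary pickle file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_data.pkl")
    with open(path, "wb") as fh:
        pickle.dump({"a": 1}, fh)
    loaded = load_with_fallback(path)
    missing = load_with_fallback(os.path.join(tmp, "absent.pkl"))
```

Returning None instead of raising matches the documented return value when deserialization fails.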

See also

joblib.load

High-performance loading utility.

pickle.load

General-purpose Python serialization library.

geoprior.utils.io_utils.load_csv(data_path, delimiter=',', **kwargs)[source]

Loads a CSV file into a pandas DataFrame.

This function reads a comma-separated values (CSV) file into a pandas DataFrame, with the ability to specify a custom delimiter. It provides support for additional options passed to pandas.read_csv for more granular control over the data loading process.

Parameters:
  • data_path (str) – The file path to the CSV file that is to be loaded. The file path must lead to a .csv file. If the file does not exist at the specified path, a FileNotFoundError is raised.

  • delimiter (str, optional) – The character used to separate values in the CSV file. The default is , for standard CSVs. If a different delimiter is used in the file (e.g., ;), it can be specified here.

  • **kwargs (dict) – Additional keyword arguments that will be passed directly to pandas.read_csv. For instance, users can specify header, index_col, dtype, and other options supported by read_csv for more customized data handling.

Returns:

A pandas DataFrame containing the loaded data, with the specified options applied.

Return type:

DataFrame

Raises:
  • FileNotFoundError – If the specified file does not exist at the provided data_path.

  • ValueError – If the file specified by data_path is not a CSV file (i.e., does not have a .csv extension), a ValueError is raised to ensure correct file type.
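The two documented error conditions can be illustrated with a hypothetical pre-check, validate_csv_path, that would run before the actual pandas.read_csv call. The order of the checks and the case-insensitive extension test are assumptions, not taken from the load_csv source.

```python
import os
import tempfile


def validate_csv_path(data_path):
    """Hypothetical pre-check mirroring the documented Raises section."""
    # Assumption: the extension is compared case-insensitively.
    if not str(data_path).lower().endswith(".csv"):
        raise ValueError(f"Expected a '.csv' file, got {data_path!r}.")
    if not os.path.isfile(data_path):
        raise FileNotFoundError(f"No such file: {data_path!r}")
    return data_path


def error_name(func, *args):
    """Return the name of the exception raised by func, or None."""
    try:
        func(*args)
    except Exception as exc:
        return type(exc).__name__
    return None


with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "example.csv")
    with open(good, "w", encoding="utf-8") as fh:
        fh.write("name,age,city\nAlice,30,New York\nBob,25,Los Angeles\n")
    checked = validate_csv_path(good)  # would then be passed to pandas.read_csv
    wrong_ext = error_name(validate_csv_path, os.path.join(tmp, "example.txt"))
    missing = error_name(validate_csv_path, os.path.join(tmp, "absent.csv"))
```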

Notes

This function simplifies the process of loading CSV data into a DataFrame, with a straightforward parameter for delimiter customization and full access to pandas.read_csv options. It is ideal for basic CSV loading tasks, as well as more complex ones requiring specific column handling, type casting, and missing value handling, which can be passed via **kwargs. CSV-oriented DataFrame loading patterns are discussed in McKinney [27].

Examples

Suppose you have a CSV file example.csv with the following content:

    name,age,city
    Alice,30,New York
    Bob,25,Los Angeles

To load this file into a DataFrame:

>>> from geoprior.utils.io_utils import load_csv
>>> df = load_csv('example.csv')
>>> print(df)
     name  age         city
0   Alice   30     New York
1     Bob   25  Los Angeles

If the file uses a semicolon (;) as the delimiter:

>>> df = load_csv('example.csv', delimiter=';')

Additionally, you can pass custom read_csv parameters through **kwargs, such as specifying a column as the index:

>>> df = load_csv('example.csv', index_col='name')
>>> print(df)
       age         city
name
Alice    30     New York
Bob      25  Los Angeles

See also

pandas.read_csv

Full documentation for loading CSV files into a DataFrame with detailed parameter options.

geoprior.utils.io_utils.move_cfile(cfile, savepath=None, **ckws)[source]

Moves a file to the specified path. If moving fails, copies and deletes the original.

Parameters:
  • cfile (str) – Name of the file to move.

  • savepath (str, optional) – Target directory. If not specified, uses default path via cpath.

Returns:

The new file path and a confirmation message.

Return type:

Tuple[str, str]

Examples

>>> from geoprior.utils.io_utils import move_cfile
>>> new_path, msg = move_cfile('myfile.txt', 'new_directory')
>>> print(new_path, msg)
geoprior.utils.io_utils.parse_csv(csv_fn=None, data=None, todo='reader', fieldnames=None, savepath=None, header=False, verbose=0, **csvkws)[source]

Parses a CSV file or serializes data to a CSV file.

This function allows loading (reading) from or dumping (writing) to a CSV file. It supports standard CSV and dictionary-based CSV formats.

Parameters:
  • csv_fn (str, optional) – The CSV filename for reading or writing. For writing operations, if data is provided and todo is set to ‘writer’ or ‘dictwriter’, this specifies the output CSV filename.

  • data (list, optional) – Data to write in the form of a list of lists or dictionaries.

  • todo (str, default 'reader') – Specifies the operation type: ‘reader’ or ‘dictreader’ reads data from a CSV file; ‘writer’ or ‘dictwriter’ writes data to a CSV file.

  • fieldnames (list of str, optional) – List of keys for dictionary-based writing to specify the field order.

  • savepath (str, optional) – Directory to save the CSV file when writing. Defaults to ‘_savecsv_’ if not provided and the path does not exist.

  • header (bool, default False) – If True, includes headers when writing with DictWriter.

  • verbose (int, default 0) – Controls the verbosity level for output messages.

  • csvkws (dict, optional) – Additional arguments passed to csv.writer or csv.DictWriter.

Returns:

Parsed data from the CSV file, as a list of lists or a list of dictionaries, based on the operation. Returns None when writing.

Return type:

Union[List[Dict], List[List[str]], None]

Notes

For writing data, the method uses either csv.writer for regular CSV or csv.DictWriter for dictionary-based CSV depending on the value of todo.
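The ‘dictwriter’ and ‘dictreader’ paths map onto csv.DictWriter and csv.DictReader from the standard library. This sketch uses an in-memory buffer instead of a file and mirrors the documented default header=False, so fieldnames must be supplied again when reading. Note that csv.DictReader returns every value as a string.

```python
import csv
import io

rows = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
fieldnames = ["name", "age"]

# 'dictwriter' path: write dictionaries in the given field order.
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=fieldnames)
# header=False (the documented default): no header row is written.
writer.writerows(rows)

# 'dictreader' path: fieldnames are required because the file has no header.
buf.seek(0)
loaded = list(csv.DictReader(buf, fieldnames=fieldnames))
```

Calling writer.writeheader() before writerows would correspond to header=True, in which case the reader could infer the field names from the first row.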

Examples

>>> from geoprior.utils.io_utils import parse_csv
>>> data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
>>> parse_csv(csv_fn='output.csv', data=data, todo='dictwriter', fieldnames=['name', 'age'])
>>> loaded_data = parse_csv(csv_fn='output.csv', todo='dictreader', fieldnames=['name', 'age'])
>>> print(loaded_data)
[{'name': 'Alice', 'age': '30'}, {'name': 'Bob', 'age': '25'}]
geoprior.utils.io_utils.parse_json(json_fn=None, data=None, todo='load', savepath=None, verbose=0, **jsonkws)[source]

Parse and manage JSON configuration files, either loading data from or saving data to a JSON file.

Parameters:
  • json_fn (str, optional) – JSON filename or URL. If data is provided and todo is ‘dump’, json_fn will be used as the output filename. If todo is ‘load’, json_fn is the input filename or URL.

  • data (Any, optional) – Data in Python object format to serialize and save if todo is ‘dump’.

  • todo ({'load', 'loads', 'dump', 'dumps'}, default 'load') – Action to perform with JSON: ‘load’ loads data from a JSON file; ‘loads’ parses a JSON string; ‘dump’ serializes data to a JSON file; ‘dumps’ serializes data to a JSON string.

  • savepath (str, optional) – Path where the JSON file will be saved if todo is ‘dump’. If savepath does not exist, it will save to the default path ‘_savejson_’.

  • verbose (int, default 0) – Controls verbosity of output messages.

  • **jsonkws (dict) – Additional keyword arguments passed to json.dump or json.dumps when saving data.

Returns:

The data loaded from the JSON file or URL if todo is ‘load’, or data after saving if todo is ‘dump’.

Return type:

Any

Raises:
  • json.JSONDecodeError – If the JSON content cannot be decoded when loading.

  • TypeError – If the JSON file or data cannot be processed.

Notes

This function uses json.load, json.loads, json.dump, and json.dumps for efficient handling of JSON files and strings.
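The four todo modes correspond one-to-one to these four json functions. run_json is a hypothetical dispatcher that only illustrates this mapping; it omits the URL loading and savepath handling that parse_json performs.

```python
import json
import os
import tempfile


def run_json(todo, *, path=None, text=None, data=None, **jsonkws):
    """Hypothetical dispatcher: map a todo mode to the matching json call."""
    if todo == "load":
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)
    if todo == "loads":
        return json.loads(text)
    if todo == "dump":
        with open(path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, **jsonkws)
        return data  # documented: the data is returned after saving
    if todo == "dumps":
        return json.dumps(data, **jsonkws)
    raise ValueError(f"Unsupported todo: {todo!r}")


config = {"name": "demo", "epochs": 50}
text = run_json("dumps", data=config, sort_keys=True)
parsed = run_json("loads", text=text)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    run_json("dump", path=path, data=config, indent=2)
    reloaded = run_json("load", path=path)
```

Extra keyword arguments such as indent or sort_keys are forwarded to json.dump or json.dumps, as described for **jsonkws.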

See also

fetch_json_data_from_url

Fetches JSON data from a given URL.

get_config_fname_from_varname

Utility for generating JSON configuration filenames based on variable names.

geoprior.utils.io_utils.parse_md(pf, delimiter=':')[source]

Parse a markdown-style file with key-value pairs separated by a delimiter.

Parameters:
  • pf (str) – Path to the markdown file containing key-value pairs.

  • delimiter (str, default ':') – Delimiter used to separate key-value pairs.

Yields:

Tuple[str, str] – A tuple containing the key and processed value.

Raises:

IOError – If the provided path does not lead to a valid file.

Notes

  • This function yields key-value pairs by reading the file line-by-line.

  • It applies sanitize_unicode_string to keys to ensure data consistency.
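A minimal sketch of the line-by-line parsing, written as a generator over any iterable of lines. iter_key_values is a hypothetical name; it strips keys and values but, unlike parse_md, does not pass keys through sanitize_unicode_string, and skipping lines without the delimiter is an assumption.

```python
import io


def iter_key_values(lines, delimiter=":"):
    """Yield (key, value) pairs from 'key<delimiter>value' lines (sketch)."""
    for line in lines:
        line = line.strip()
        if not line or delimiter not in line:
            continue  # skip blank lines and lines without a delimiter
        # Split on the first delimiter only, so values may contain it too.
        key, value = line.split(delimiter, 1)
        yield key.strip(), value.strip()


sample = io.StringIO("key1: Value1\nkey2: Value2\n\nnote without delimiter\n")
pairs = list(iter_key_values(sample))
```

Because it is a generator, pairs are produced lazily while the file is read, which matches the "Yields" contract above.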

Examples

>>> from geoprior.utils.io_utils import parse_md
>>> list(parse_md('parser_file.md', delimiter=':'))
[('key1', 'Value1'), ('key2', 'Value2')]
geoprior.utils.io_utils.parse_yaml(yml_fn=None, data=None, todo='load', savepath=None, verbose=0, **ymlkws)[source]

Parse and handle YAML configuration files for loading or saving data.

Parameters:
  • yml_fn (str, optional) – The YAML filename. If data is provided and todo is set to ‘dump’, yml_fn will be used as the output filename. If todo is set to ‘load’, yml_fn is the input filename to read from.

  • data (Any, optional) – Data in a Python object format that will be serialized and saved as a YAML file if todo is ‘dump’.

  • todo ({'load', 'dump'}, default 'load') – Action to perform with the YAML file: ‘load’ loads data from the YAML file specified by yml_fn; ‘dump’ serializes data into YAML format and saves it to yml_fn.

  • savepath (str, optional) – Path where the YAML file will be saved if todo is ‘dump’. If not provided, a default path will be used. The function will ensure that the path exists.

  • verbose (int, default 0) – Controls verbosity of output messages.

  • **ymlkws (dict) – Additional keyword arguments passed to yaml.dump when saving data.

Returns:

The data loaded from the YAML file if todo is ‘load’, or data after saving if todo is ‘dump’.

Return type:

Any

Raises:

yaml.YAMLError – If there is an issue with reading or writing the YAML file.

Notes

This function uses safe_load and safe_dump methods from PyYAML for secure handling of YAML files.

See also

get_config_fname_from_varname

Utility for generating YAML configuration filenames based on variable names.

geoprior.utils.io_utils.print_cmsg(cfile, todo='load', config='YAML')[source]

Generates output message for configuration file operations.

Parameters:
  • cfile (str) – Name of the configuration file.

  • todo (str, default 'load') – Operation performed (‘load’ or ‘dump’).

  • config (str, default 'YAML') – Type of configuration file (e.g., ‘YAML’, ‘CSV’, ‘JSON’).

Returns:

Confirmation message for the configuration operation.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import print_cmsg
>>> msg = print_cmsg('config.yml', 'dump')
>>> print(msg)
--> YAML 'config.yml' data was successfully saved.
geoprior.utils.io_utils.rename_files(src_files, dst_files, basename=None, extension=None, how='py', prefix=True, keep_copy=True, trailer='_', sortby=None, **kws)[source]

Rename files from one set of names or paths to another.

Parameters:
  • src_files (str or list of str) – Source files or a directory containing files to rename.

  • dst_files (str or list of str) – Destination file names or destination directory.

  • basename (str or None, optional) – Base name used when generating numbered destination files.

  • extension (str or None, optional) – Optional extension filter when src_files is a directory.

  • how (str, optional) – Numbering convention used when destination names are generated.

  • prefix (bool, optional) – Whether generated numbering is appended after the basename.

  • keep_copy (bool, optional) – Whether to keep copies of the original files.

  • trailer (str, optional) – Separator inserted between the basename and the generated counter.

  • sortby (regex, callable, or None, optional) – Optional sort key used when collecting files from a directory.

  • **kws (dict) – Additional keyword arguments forwarded to os.rename.

Return type:

None

geoprior.utils.io_utils.sanitize_unicode_string(str_)[source]

Removes spaces and replaces accented characters in a string.
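One common way to replace accented characters is Unicode NFKD decomposition followed by removal of the combining marks. The sketch below assumes that approach and additionally removes punctuation and lowercases the result, behavior inferred from the examples rather than stated in the description; strip_accents_and_spaces is a hypothetical name, and the library's handling of other characters may differ.

```python
import re
import unicodedata


def strip_accents_and_spaces(text):
    """Approximate the documented behavior (sketch, not the library code)."""
    # Decompose accented letters, e.g. 'é' -> 'e' + combining acute accent.
    decomposed = unicodedata.normalize("NFKD", text)
    # Drop the combining marks and keep the base letters.
    no_accents = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Remove whitespace and punctuation, then lowercase.
    return re.sub(r"[^0-9a-zA-Z]", "", no_accents).lower()


result = strip_accents_and_spaces("Élève à l'école")
```

This reproduces the "Élève à l'école" example; other inputs may not match the library output character for character.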

Parameters:
  • str_ (str) – The string to sanitize.

Returns:

The sanitized string with removed spaces and replaced accents.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import sanitize_unicode_string
>>> sentence = ('Nos clients sont extrêmement satisfaits '
...             'de la qualité du service fourni. En outre Nos clients '
...             'rachètent frequemment nos "services".')
>>> sanitize_unicode_string(sentence)
'nosclientssontextrmementsatisfaitsdelaqualitduservicefournienoutrenosclientsrachtentfrequemmentnosservices'
>>> sanitize_unicode_string("Élève à l'école")
'elevealecole'
geoprior.utils.io_utils.save_job(job, savefile, *, protocol=None, append_versions=True, append_date=True, fix_imports=True, buffer_callback=None, **job_kws)[source]

Quickly save a job using joblib or Python's persistent pickle module.

Parameters:
  • job (Any) – Anything to save, preferably a model or a dict of models.

  • savefile (str, or path-like object) – name of file to store the model. The file argument must have a write() method that accepts a single bytes argument. It can thus be a file object opened for binary writing, an io.BytesIO instance, or any other custom object that meets this interface.

  • append_versions (bool, default True) – Append the version of the joblib or pickle module, followed by the scikit-learn, NumPy, and pandas versions. Recording these versions helps when the file is loaded after the system or its modules have been upgraded, especially when data have been stored for a long time and the user no longer remembers the date and versions at the time the file was saved.

  • append_date (bool, default True) – Append the current date to the filename.

  • protocol (int, optional) –

    The optional protocol argument tells the pickler to use the given protocol; supported protocols are 0, 1, 2, 3, 4, and 5. The default protocol is 4, which was introduced in Python 3.4 and is incompatible with earlier versions.

    Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the more recent the version of Python needed to read the pickle produced.

  • fix_imports (bool, default True) – If fix_imports is True and protocol is less than 3, pickle will try to map the new Python 3 names to the old module names used in Python 2, so that the pickle data stream is readable with Python 2.

  • buffer_callback (callable, optional) –

    If buffer_callback is None (the default), buffer views are serialized into file as part of the pickle stream.

    If buffer_callback is not None, then it can be called any number of times with a buffer view. If the callback returns a false value (such as None), the given buffer is out-of-band; otherwise the buffer is serialized in-band, i.e. inside the pickle stream.

    It is an error if buffer_callback is not None and protocol is None or smaller than 5.

  • **job_kws (dict) – Additional keyword arguments passed to joblib.dump().

Returns:

The final filename where the job was saved.

Return type:

str

Notes

This function appends system-specific metadata like versions and date to the filename, which can aid in tracking compatibility over time.
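The filename stamping can be sketched as joining the base name, the date, and the library versions with dots. build_job_filename is a hypothetical helper; the versions are passed in explicitly instead of being read from installed packages, and the exact naming scheme used by save_job may differ.

```python
import datetime


def build_job_filename(savefile, versions=None, append_date=True,
                       ext=".joblib", today=None):
    """Hypothetical helper: stamp a base name with a date and versions."""
    parts = [savefile]
    if append_date:
        today = today or datetime.date.today()
        parts.append(today.strftime("%Y%m%d"))
    # One '<library>_v<version>' tag per entry, in insertion order.
    for lib, ver in (versions or {}).items():
        parts.append(f"{lib}_v{ver}")
    return ".".join(parts) + ext


name = build_job_filename(
    "my_model",
    versions={"sklearn": "1.0", "numpy": "1.21"},
    today=datetime.date(2024, 1, 1),
)
```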

Examples

>>> from geoprior.utils.io_utils import save_job
>>> model = {"key": "value"}  # Replace with actual model object
>>> savefile = save_job(model, "my_model", append_date=True, append_versions=True)
>>> print(savefile)
my_model.20240101.sklearn_v1.0.numpy_v1.21.joblib
geoprior.utils.io_utils.save_path(nameOfPath)[source]

Creates a directory if it does not exist.

Parameters:

nameOfPath (str) – Name or path of the directory to create.

Returns:

The path of the created directory. If it exists, returns the existing path.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import save_path
>>> save_path("test_directory")
'path/to/test_directory'
geoprior.utils.io_utils.serialize_data(data, filename=None, savepath=None, to=None, force=True, compress=None, pickle_protocol=5, verbose=0)[source]

Serialize and save a Python object to a binary file using either joblib or pickle. This function is designed to be robust and versatile, handling multiple cases including file naming, overwriting behavior, and compression options.

The final file path is computed as:

\[\text{filepath} = \text{savepath} \oplus \text{filename} \tag{25}\]

where \(\oplus\) denotes string concatenation.

Parameters:
  • data (Any) – The Python object to serialize. The object must be compatible with either joblib.dump or pickle.dump.

  • filename (str, optional) – The target filename for the serialized data. If None, a filename is generated using the current timestamp, e.g., "__mydumpedfile_20230315_123045.pkl".

  • savepath (str, optional) – The directory in which to save the file. If not specified, the current working directory (os.getcwd()) is used. The directory is created if it does not exist.

  • to (str, optional) – The serialization method to use. Acceptable values are 'joblib' and 'pickle'. If None, the default is 'joblib'.

  • force (bool, default True) – If True, any existing file with the same name is overwritten. If False, a timestamp is appended to the filename to ensure uniqueness.

  • compress (int or str, optional) – Compression level or method for joblib.dump. If None, no compression is applied.

  • pickle_protocol (int, default 5) – The pickle protocol to use when serializing with pickle.dump.

  • verbose (int, default 0) – Controls the verbosity of output messages. Higher values produce more detailed logging during the serialization process.

Returns:

The full path to the saved serialized file.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import serialize_data
>>> import numpy as np
>>> data = {"a": np.arange(10), "b": np.random.rand(10)}
>>> filepath = serialize_data(
...     data, filename="mydata.pkl", savepath="output",
...     to="pickle", force=False, verbose=1
... )
>>> print(filepath)
/current/working/directory/output/mydata_<timestamp>.pkl

Notes

The function first constructs the file path from savepath and filename. If a file already exists and force is False, a timestamp is appended to ensure uniqueness. Then, depending on the value of to, the function attempts to serialize the data using either joblib.dump (with optional compression via the compress parameter) or pickle.dump (using the specified pickle_protocol). If an error occurs during serialization, an IOError is raised.
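The path resolution and overwrite policy described above can be sketched as follows. resolve_filepath is a hypothetical helper; it uses os.path.join for the ⊕ concatenation, and the timestamp format is borrowed from the documented default filename.

```python
import datetime
import os
import tempfile


def resolve_filepath(filename, savepath=None, force=True, now=None):
    """Hypothetical sketch of the path and overwrite logic in the Notes."""
    savepath = savepath or os.getcwd()
    filepath = os.path.join(savepath, filename)
    if os.path.exists(filepath) and not force:
        # Keep the existing file: insert a timestamp before the extension.
        stem, ext = os.path.splitext(filename)
        stamp = (now or datetime.datetime.now()).strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(savepath, f"{stem}_{stamp}{ext}")
    return filepath


with tempfile.TemporaryDirectory() as tmp:
    first = resolve_filepath("mydata.pkl", tmp)
    with open(first, "wb"):
        pass  # simulate an existing file
    overwrite = resolve_filepath("mydata.pkl", tmp, force=True)
    unique = resolve_filepath(
        "mydata.pkl", tmp, force=False,
        now=datetime.datetime(2023, 3, 15, 12, 30, 45),
    )
    names = [os.path.basename(p) for p in (first, overwrite, unique)]
```

With force=True the same path is reused and the file is overwritten; with force=False an existing file is preserved.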

See also

joblib.dump

Serialize objects to disk using Joblib.

pickle.dump

Serialize objects to disk using Pickle.

os.getcwd

Retrieve the current working directory.

geoprior.utils.io_utils.serialize_data_in(data, filename=None, force=True, savepath=None, verbose=0)[source]

Serializes a Python object to a binary file using either joblib or pickle.

This function attempts to serialize the input data using the joblib.dump method. If this attempt fails, it falls back to using pickle.dump. The final file path is constructed by concatenating the directory specified by savepath (or the current working directory if savepath is None) with the given filename. Mathematically, the file path is given by:

\[\text{filepath} = \text{savepath} \oplus \text{filename} \tag{26}\]

where \(\oplus\) denotes string concatenation.

Parameters:
  • data (Any) – The Python object to serialize. It must be compatible with either joblib or pickle serialization.

  • filename (str, optional) – The target filename for the serialized data. If None, a filename is generated using the current timestamp formatted as "%Y%m%d%H%M%S" (e.g., "serialized_20230315123045.pkl").

  • force (bool, default True) – Determines whether to overwrite an existing file with the same filename. If False, a timestamp is appended to the filename to ensure uniqueness.

  • savepath (str, optional) – The directory in which to save the serialized file. If not specified, the file is saved to the current working directory (os.getcwd()).

  • verbose (int, default 0) – Controls the verbosity of output messages. Higher values produce more detailed logging during the serialization process.

Returns:

The complete file path to which the data has been serialized.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import serialize_data_in
>>> data = {"a": 1, "b": 2}
>>> filepath = serialize_data_in(data, filename='data.pkl',
...                              force=True, verbose=1)
>>> print(filepath)
/path/to/current/directory/data.pkl

Notes

The function first tries to serialize the input data using joblib.dump. In case of any exception during this attempt, it falls back to using pickle.dump. This dual approach improves robustness in diverse runtime environments where one serialization method might be unsupported or encounter issues with the given data type.

See also

joblib.dump

Serialize objects to disk using Joblib.

pickle.dump

Serialize objects to disk using Pickle.

os.getcwd

Retrieve the current working directory.

geoprior.utils.io_utils.spath(name_of_path)[source]

Create a directory if it does not already exist.

Parameters:

name_of_path (str) – Path-like object to create if it doesn’t exist.

Returns:

The absolute path to the created or existing directory.

Return type:

str

Examples

>>> from geoprior.utils.io_utils import spath
>>> path = spath('data/saved_models')
>>> print(f"Directory available at: {path}")

Notes

spath is useful for quickly ensuring that a specific directory is available for storing files. It provides feedback if the directory already exists.
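Assuming that spath wraps os.makedirs, its behavior reduces to a few lines. ensure_dir is a hypothetical name, and the feedback message is illustrative rather than the exact text spath prints.

```python
import os
import tempfile


def ensure_dir(name_of_path):
    """Create the directory if needed and return its absolute path (sketch)."""
    path = os.path.abspath(name_of_path)
    existed = os.path.isdir(path)
    # exist_ok=True makes repeated calls safe; parent folders are created too.
    os.makedirs(path, exist_ok=True)
    if existed:
        print(f"Directory {path!r} already exists.")
    return path


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "data", "saved_models")
    created = ensure_dir(target)
    again = ensure_dir(target)  # second call only reports that it exists
    ok = os.path.isdir(created) and created == again
```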

geoprior.utils.io_utils.store_or_write_hdf5(df, key=None, mode='a', kind=None, path_or_buf=None, encoding='utf8', csv_sep=',', index=Ellipsis, columns=None, sanitize_columns=False, func=None, args=(), applyto=None, **func_kwds)[source]

Store a DataFrame to HDF5, write it to CSV, or sanitize it in memory.

Parameters:
  • df (pandas.DataFrame or array-like) – Input data to store, export, or sanitize.

  • key (str or None, optional) – Group key used when storing to HDF5.

  • mode ({'a', 'w', 'r+'}, optional) – File mode used when opening an HDF5 store.

  • kind ({'store', 'write', None}, optional) – Operation to perform. Use 'store' for HDF5 output, 'write' for CSV export, or None to return a sanitized DataFrame.

  • path_or_buf (str, path-like, pandas.HDFStore, file-like, or None, optional) – Destination path, buffer, or open HDF5 store.

  • encoding (str, optional) – Output encoding used for CSV export.

  • csv_sep (str, optional) – Field separator used for CSV export.

  • index (bool, optional) – Whether to write the index when exporting to CSV.

  • columns (list of str or None, optional) – Column names used when constructing a DataFrame from an array.

  • sanitize_columns (bool, optional) – Whether to sanitize column names with the built-in regex helper.

  • func (callable or None, optional) – Optional custom sanitizing function applied to selected columns.

  • args (tuple, optional) – Positional arguments forwarded to func.

  • applyto (str or list of str or None, optional) – Column or columns to which func should be applied.

  • func_kwds (dict) – Keyword arguments forwarded to func.

Returns:

Returns None when kind is 'store' or 'write'. Otherwise returns the resulting DataFrame.

Return type:

None or pandas.DataFrame

geoprior.utils.io_utils.to_hdf5(data, fn, objname=None, close=True, **hdf5_kws)[source]

Store a data object in Hierarchical Data Format 5 (HDF5).

This function serializes the input data into an HDF5 file. It supports both pandas DataFrames and NumPy arrays. If data is a DataFrame, it uses pd.HDFStore (which requires the pytables package) to store the data. If data is a NumPy array, it uses h5py.File to create a dataset.

The file path is constructed by concatenating the specified savepath (or the current working directory if savepath is not provided) with the provided filename (fn). The function automatically appends the appropriate file extension: .h5 for DataFrames and .hdf5 for arrays.

\[\text{filepath} = \text{savepath} \oplus \text{filename} \oplus \text{extension} \tag{27}\]

where \(\oplus\) denotes string concatenation.

Parameters:
  • data (Any) – The data object to be stored. Must be either a NumPy array or a pandas DataFrame.

  • fn (str) – The file path (without extension) where the HDF5 file will be saved.

  • objname (str, optional) – The name under which to store the data within the HDF5 file. Defaults to 'data' if not provided.

  • close (bool, default True) – If True, the file is closed after writing. If False, the file remains open for additional modifications.

  • **hdf5_kws (dict, optional) – Additional keyword arguments to pass to the HDFStore constructor (for DataFrames) or to customize dataset creation (for arrays). Common options include mode for the file mode, complevel for compression level, complib for the compression library, and fletcher32 to enable the Fletcher32 checksum. For mode, use 'r' for read-only access, 'w' to create a new file, 'a' to append or create, and 'r+' to open an existing file for reading and writing.

Returns:

store – An IO interface for the stored data. For DataFrames, this is a pd.HDFStore object; for arrays, an h5py.File object.

Return type:

object

Examples

>>> import os
>>> import pandas as pd
>>> from geoprior.utils.io_utils import to_hdf5
>>> data = pd.DataFrame({
...     'a': [1, 2, 3],
...     'b': [4, 5, 6]
... })
>>> save_path = os.path.join('output', 'datafile')
>>> store = to_hdf5(data, fn=save_path, objname='mydata', close=False)
>>> # Access stored data while the store is still open:
>>> retrieved = store['mydata']
>>> print(retrieved.head())
>>> store.close()

Notes

Ensure the dependency pytables is installed when serializing a DataFrame. When serializing NumPy arrays, the dataset is created with the name "dataset_01". If close is set to False, the caller is responsible for closing the store. The pandas and NumPy foundations underlying this serialization path are summarized in [33, 34].

See also

joblib.dump, pickle.dump, h5py.File

geoprior.utils.io_utils.zip_extractor(zip_file, samples='*', ftype=None, savepath=None, pwd=None)[source]

Extracts files from a ZIP archive based on various filtering criteria and saves them to a specified directory.

The extraction process can be controlled by the samples parameter to limit the number of files extracted, or by the ftype parameter to filter by a specific file extension. The resulting file names are returned as a list.

\[\text{Extracted Files} = \{ f \in \mathcal{A} \mid \phi(f) \} \tag{28}\]

where \(\mathcal{A}\) is the set of all files in the archive, and \(\phi(f)\) is a predicate that checks if a file matches the desired extension and is within the specified sample count.
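The predicate φ(f) can be sketched as an extension filter followed by a cap on the number of members. select_members is a hypothetical helper; the case-insensitive extension match is an assumption, and the selected names would then be passed to ZipFile.extract or ZipFile.extractall.

```python
import io
import zipfile


def select_members(names, samples="*", ftype=None):
    """Sketch of phi(f): optional extension filter plus a sample cap."""
    if ftype is not None:
        names = [n for n in names if n.lower().endswith(ftype.lower())]
        if not names:
            raise ValueError(f"No files with extension {ftype!r} in archive.")
    if samples != "*":
        names = names[: int(samples)]
    return names


# Build a small in-memory archive to exercise the filter.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    for member in ("folder1/file1.csv", "folder2/file2.csv", "readme.txt"):
        zf.writestr(member, "x")

with zipfile.ZipFile(buf) as zf:
    csv_members = select_members(zf.namelist(), ftype=".csv")
    first_only = select_members(zf.namelist(), samples=1)
```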

Parameters:
  • zip_file (str) – Full path to the ZIP archive file.

  • samples (int or str, optional) – Number of files to extract. If set to '*', all files are extracted. Default is '*'.

  • ftype (str, optional) – File extension filter (e.g., '.csv'). Only files with this extension are extracted. If no matching files are found, a ValueError is raised.

  • savepath (str, optional) – Directory where the extracted files will be stored. If not provided, files are extracted to the current working directory.

  • pwd (str or bytes, optional) – Password for encrypted ZIP files. If provided as a string, it will be used as is (or can be encoded to bytes as needed).

Returns:

A list of extracted file names (with paths).

Return type:

list of str

Examples

>>> from geoprior.utils.io_utils import zip_extractor
>>> extracted_files = zip_extractor(
...     'data/archive.zip',
...     samples='*',
...     ftype='.csv',
...     savepath='data/extracted',
...     pwd='secret'
... )
>>> print(extracted_files)
['folder1/file1.csv', 'folder2/file2.csv', ...]

Notes

The function first validates the input ZIP file with the package's check_files helper. It then determines the sample count and filters files by extension if ftype is provided. Extraction is done via the standard ZipFile.extract or ZipFile.extractall methods.

See also

zipfile.ZipFile.extract

Extract a single file from a ZIP archive.

zipfile.ZipFile.extractall

Extract all files from a ZIP archive.

geoprior.utils.io_utils.fetch_joblib_data(job_file, *keys, error_mode='raise', verbose=0)[source]

Dynamically load data from a joblib-saved dictionary with flexible key access.

Parameters:
  • job_file (str) – Path to the joblib file containing a dictionary

  • *keys (str) – Variable-length list of dictionary keys to retrieve

  • error_mode ({'raise', 'warn', 'ignore'}, default 'raise') – Handling of missing keys: ‘raise’ immediately raises a KeyError; ‘warn’ issues a warning and skips missing keys; ‘ignore’ silently skips missing keys.

  • verbose (int, default 0) – Verbosity level: 0 produces no output; 1 prints basic loading information; 2 prints detailed debugging output.

Returns:

  • Full dictionary if no keys specified

  • Tuple of values for requested keys (maintaining order)

Return type:

Union[Dict, Tuple]

Raises:
  • FileNotFoundError – If the specified job_file does not exist.

  • TypeError – If the loaded data is not a dictionary.

  • KeyError – If a requested key is not found and error_mode='raise'.

Examples

>>> from geoprior.utils.io_utils import fetch_joblib_data
>>> data = fetch_joblib_data('data.joblib', 'X_train', 'y_train')
>>> X, y = fetch_joblib_data('data.joblib', 'X_val', 'y_val', verbose=1)
>>> full_dict = fetch_joblib_data('data.joblib')

Notes

  • Maintains original insertion order for Python 3.7+ dictionaries

  • Missing keys in ‘warn’/’ignore’ modes result in shorter return tuple

  • Joblib files must contain dictionary objects
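Once the joblib file has been loaded, the key lookup and the three error modes can be sketched on a plain dictionary. pick_keys is a hypothetical name; the sketch skips the file loading and verbosity handling that fetch_joblib_data performs.

```python
import warnings


def pick_keys(data, *keys, error_mode="raise"):
    """Sketch of the documented key lookup on an already-loaded dict."""
    if not isinstance(data, dict):
        raise TypeError("Loaded object is not a dictionary.")
    if not keys:
        return data  # no keys requested: return the whole dictionary
    values = []
    for key in keys:
        if key in data:
            values.append(data[key])
        elif error_mode == "raise":
            raise KeyError(key)
        elif error_mode == "warn":
            warnings.warn(f"Key {key!r} not found; skipping.")
        # 'ignore': skip the missing key silently
    return tuple(values)  # requested order is preserved


store = {"X_train": [1, 2], "y_train": [0, 1]}
both = pick_keys(store, "X_train", "y_train")
partial = pick_keys(store, "X_train", "X_val", error_mode="ignore")
```

As noted above, skipping a missing key shortens the returned tuple, so callers that unpack a fixed number of values should keep error_mode='raise'.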

geoprior.utils.io_utils.to_txt(d, filename=None, format='txt', indent=2, width=80, depth=None, compat=False, include_header=True, mode='w', encoding='utf-8', overwrite=True, header=None, footer=None, serializer=None, savepath=None, verbose=1, logger=None, **kwargs)[source]

Export data objects to a text or JSON file with optional custom formatting.

The function, <to_txt>, handles writing <d> (a string, dict, list, or general object) to a file named <filename>. When no filename is given, it automatically generates one based on the current date/time. If <format> is “json” and <d> is valid for JSON serialization, it attempts a JSON export. Otherwise, it falls back to text mode, leveraging Python’s built-in pformat and an optional <serializer> for advanced transformations.

\[\text{FileName}_{\text{timestamp}} \rightarrow \text{output} \tag{29}\]

where \(\text{FileName}_{\text{timestamp}}\) is an auto-generated name such as output_20230101_123456.txt, used when <filename> is not provided.

Parameters:
  • d (object) – Data to write. Can be any Python object supported by pformat, or a dict if <format> is ‘json’.

  • filename (str, optional) – Full path (or name) of the output file. If None, a time-stamped name is produced, prefixed with ‘output_’.

  • format (str, default 'txt') – File format, either "txt" or "json". If it fails to serialize as JSON, the process reverts to text.

  • indent (int, default 2) – Indentation level for pretty-printing text or JSON.

  • width (int, default 80) – Wrap width for formatted text lines.

  • depth (int, optional) – Maximum depth to which nested structures are expanded. If None, no limit is applied.

  • compat (bool, default False) – If True, instructs pformat to produce more compact text. Not used when exporting JSON.

  • include_header (bool, default True) – Whether to include a decorative header (with timestamp) at the top of the file in text mode.

  • mode (str, default 'w') – File writing mode. Typically ‘w’ for overwrite, ‘a’ for append.

  • encoding (str, default 'utf-8') – Text encoding used when opening the file.

  • overwrite (bool, default True) – If False, raises an error if the file already exists.

  • header (str, optional) – Custom header text (if <include_header> is True). Overwrites the default header if given.

  • footer (str, optional) – Custom footer text appended at the end of the file, if <include_header> is True.

  • serializer (callable, optional) – A function that transforms <d> before printing. If it fails, <d> remains unchanged.

  • verbose (int, default 1) – Verbosity level for logging. Higher values print more console messages (e.g., file statistics when <verbose> >= 3).

  • **kwargs – Additional parameters passed to the JSON serializer (json.dump) or pformat.

Returns:

The final filename used to store the output (potentially auto-generated).

Return type:

str

Notes

If <format> is “json”, the function tries json.dump with a few standard parameters. If an exception occurs, it reverts to text export. The <serializer> argument allows custom transformations, such as flattening nested dicts or converting objects to JSON-serializable representations. The standard-library JSON behavior used here is documented in Python Software Foundation [35].
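
The JSON-first, text-fallback behavior described above can be sketched with the standard library alone. This is a minimal illustration, not the to_txt implementation: the name write_with_fallback is hypothetical; headers, footers, the serializer hook, and the overwrite check are omitted; and the sketch serializes with json.dumps before opening the file, so a failed JSON attempt never leaves a half-written file.

```python
import json
import os
import tempfile
from datetime import datetime
from pprint import pformat


def write_with_fallback(d, filename=None, format="txt", indent=2, width=80):
    """Hypothetical sketch: JSON export when possible, pformat text otherwise."""
    if filename is None:
        # Time-stamped default name, e.g. output_20230101_123456.txt
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"output_{stamp}.{format}"
    if format == "json":
        try:
            text = json.dumps(d, indent=indent)
        except (TypeError, ValueError):
            text = None  # not JSON-serializable: fall back to text below
        if text is not None:
            with open(filename, "w", encoding="utf-8") as fh:
                fh.write(text)
            return filename
    with open(filename, "w", encoding="utf-8") as fh:
        fh.write(pformat(d, indent=indent, width=width))
    return filename


# Usage: a set is not JSON-serializable, so the second call falls back to text.
out_dir = tempfile.mkdtemp()
ok = write_with_fallback({"name": "Alice", "age": 30},
                         os.path.join(out_dir, "ok.json"), format="json")
fallback = write_with_fallback({"tags": {"a", "b"}},
                               os.path.join(out_dir, "fallback.json"), format="json")
```
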

Examples

>>> from geoprior.utils.io_utils import to_txt
>>> my_data = {"name":"Alice","age":30}
>>> # Basic text export
>>> txt_file = to_txt(my_data, verbose=2)
>>> # Enforce JSON format
>>> json_file = to_txt(my_data, format='json', indent=4)

See also

pformat

Pretty-print complex Python data structures.

Parallel execution helpers for GeoPrior workflows.

geoprior.utils.parallel_utils.resolve_n_jobs(n_jobs)[source]
Parameters:

n_jobs (int)

Return type:

int

geoprior.utils.parallel_utils.threads_per_job(*, n_jobs, threads=0, reserve=1)[source]
Return type:

int

geoprior.utils.parallel_utils.apply_thread_env(env, *, n_jobs, threads=0, reserve=1)[source]
Return type:

dict[str, str]
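
resolve_n_jobs, threads_per_job, and apply_thread_env are documented here only by their signatures. The sketch below shows one plausible reading, assuming joblib-style n_jobs semantics (-1 means all cores), that threads=0 means "split the remaining cores evenly across jobs", and that the per-job thread budget is exported through the common OpenMP/BLAS environment variables. All three functions are hypothetical stand-ins; the actual GeoPrior rules may differ.

```python
import os


def resolve_n_jobs_sketch(n_jobs, cpu_count=None):
    """Hypothetical: joblib-style resolution where -1 means all cores."""
    cpus = cpu_count or os.cpu_count() or 1
    if n_jobs is None or n_jobs == 0:
        return 1
    if n_jobs < 0:
        # -1 -> all cores, -2 -> all but one, ...
        return max(1, cpus + 1 + n_jobs)
    return n_jobs


def threads_per_job_sketch(*, n_jobs, threads=0, reserve=1, cpu_count=None):
    """Hypothetical: threads=0 splits the non-reserved cores evenly."""
    if threads > 0:
        return threads  # explicit request wins
    cpus = cpu_count or os.cpu_count() or 1
    free = max(1, cpus - reserve)
    return max(1, free // max(1, n_jobs))


def thread_env_sketch(env, *, n_jobs, threads=0, reserve=1, cpu_count=None):
    """Hypothetical: pin common OpenMP/BLAS thread pools for one worker."""
    n = str(threads_per_job_sketch(n_jobs=n_jobs, threads=threads,
                                   reserve=reserve, cpu_count=cpu_count))
    out = dict(env)  # copy so the caller's mapping is not mutated
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        out[var] = n
    return out


# Usage: 8 cores, 1 reserved, 2 jobs -> 3 threads per job.
worker_env = thread_env_sketch({"PATH": "/usr/bin"}, n_jobs=2, cpu_count=8)
```
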

geoprior.utils.parallel_utils.apply_tf_threading(*, intra, inter)[source]
Return type:

None

geoprior.utils.parallel_utils.detect_gpu_ids(*, env=None)[source]
Parameters:

env (dict[str, str] | None)

Return type:

list[str]

geoprior.utils.parallel_utils.resolve_device(device, *, env=None)[source]
Return type:

str

geoprior.utils.parallel_utils.resolve_gpu_ids(gpu_ids, *, env=None)[source]
Return type:

list[str]

geoprior.utils.parallel_utils.pick_gpu_id(idx, gpu_ids)[source]
Return type:

str | None

geoprior.utils.parallel_utils.apply_gpu_env(env, *, gpu_id, allow_growth=True)[source]
Return type:

dict[str, str]
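
The GPU helpers are likewise documented only by their signatures. As a hypothetical sketch, assuming GPU ids are read from CUDA_VISIBLE_DEVICES, workers are assigned to GPUs round-robin, and memory growth is requested through TensorFlow's TF_FORCE_GPU_ALLOW_GROWTH variable, the pieces could fit together as follows. The function names are illustrative, not the GeoPrior API.

```python
def visible_gpu_ids(env):
    """Hypothetical: parse CUDA_VISIBLE_DEVICES into a list of id strings."""
    raw = env.get("CUDA_VISIBLE_DEVICES", "")
    return [tok.strip() for tok in raw.split(",")
            if tok.strip() and tok.strip() != "-1"]


def round_robin_gpu(idx, gpu_ids):
    """Hypothetical: map worker idx onto the GPU list cyclically."""
    if not gpu_ids:
        return None  # no GPU available: caller should run on CPU
    return gpu_ids[idx % len(gpu_ids)]


def gpu_env_for_worker(env, *, gpu_id, allow_growth=True):
    """Hypothetical: build the environment for one worker process."""
    out = dict(env)  # never mutate the caller's mapping
    # "-1" hides every GPU, leaving a CPU-only worker.
    out["CUDA_VISIBLE_DEVICES"] = "-1" if gpu_id is None else str(gpu_id)
    if allow_growth:
        out["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
    return out


# Usage: two visible GPUs, worker 3 lands on the second one.
ids = visible_gpu_ids({"CUDA_VISIBLE_DEVICES": "0,2"})
worker_env = gpu_env_for_worker({}, gpu_id=round_robin_gpu(3, ids))
```
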

System utilities module for managing system-level operations.

This module provides utilities essential for system-level tasks such as color management, regular expression searching, and projection validation, along with other miscellaneous system operations.

class geoprior.utils.sys_utils.BatchDataFrameBuilder(chunk_size=100000, processor='auto', verbose=1)[source]

Bases: object

Manages incremental construction of a large DataFrame in controlled-size chunks. This can reduce peak memory usage and allow GPU-accelerated libraries (e.g., cudf) if they are available and desired.

The approach can be expressed mathematically as a chunking process that partitions an incoming stream of \(N\) row-dictionaries into \(k\) chunks, each holding at most \(m = \text{chunk\_size}\) rows:

(30)#\[k = \left\lceil \frac{N}{m} \right\rceil\]

Each chunk is converted into a DataFrame and stored, its row buffer is released from memory, and all chunks are concatenated at finalization time.
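
With the values used in the example below (N = 1,000,000 rows, chunk_size = 50,000), equation (30) gives k = 20 chunks. The partitioning step can be sketched in plain Python; iter_chunks is a hypothetical helper that only illustrates how rows are grouped, not how the builder converts or concatenates the chunks.

```python
import math
from itertools import islice


def iter_chunks(rows, chunk_size):
    """Yield consecutive lists of at most chunk_size rows."""
    it = iter(rows)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk


N, m = 1_000_000, 50_000
rows = ({"colA": i, "colB": i ** 2} for i in range(N))  # streamed, not materialized
sizes = [len(chunk) for chunk in iter_chunks(rows, m)]
k = math.ceil(N / m)  # equation (30)
```

Only the last chunk can be shorter than chunk_size, which is why k uses a ceiling: for 7 rows and chunk_size 3 the chunk sizes are 3, 3, and 1.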

Parameters:
  • chunk_size (int, optional) – The maximum number of rows to hold in the internal buffer before converting them into a DataFrame chunk. Default is 100000.

  • processor ({'auto', 'cpu', 'gpu'}, optional) –

    Controls the engine used to build the DataFrame:

    • 'cpu' : Always use pandas.

    • 'gpu' : Attempt to use cudf (raises an error if cudf is not available).

    • 'auto' : Use cudf if a GPU is detected and cudf is installed; otherwise fall back to pandas.

  • verbose (int, optional) –

    Verbosity level. Default is 1:

    • 0 : Silent.

    • 1 : Basic information.

    • 2 : Debug / detailed printing.

Notes

This object is intended for situations where the total row count can be very large, potentially in the millions. By breaking data into chunks, you can avoid excessive memory usage and keep the system more responsive. If processor is 'auto' or 'gpu', the module calls check_processor to verify GPU availability, then uses cudf if appropriate.

Note

If the total data is larger than your available memory (whether RAM or GPU), consider writing out each chunk to disk as a partitioned file (e.g., Parquet or Feather) instead of storing them all in memory.

Examples

>>> from geoprior.utils.sys_utils import BatchDataFrameBuilder
>>> # Suppose we have a large list of dictionaries
>>> data = [
...     {'colA': i, 'colB': i**2} for i in range(10**6)
... ]
>>> with BatchDataFrameBuilder(chunk_size=50000,
...                            processor='auto',
...                            verbose=2) as builder:
...     builder.add_rows(data)
...
>>> # After exiting the context, the final DataFrame is
>>> # automatically built and stored in builder.final_df
>>> final_df = builder.final_df
>>> print(final_df.shape)
(1000000, 2)

See also

pandas.DataFrame

Core pandas DataFrame object.

cudf.DataFrame

GPU DataFrame object from RAPIDS.

check_processor

Utility for detecting GPU availability.

__init__(chunk_size=100000, processor='auto', verbose=1)[source]

Initializes the builder, setting up chunk size, processor preference, and verbosity. Checks for GPU availability if requested.

__enter__()[source]

Enters the context manager and returns self, so the builder can be used in a with statement.

add_row(row)[source]

Adds a single row to the internal buffer.

This method appends the given dictionary row to the in-memory buffer. If the buffer reaches self.chunk_size, it is automatically flushed.

Parameters:

row (dict) – A row in dictionary form, where keys correspond to column names and values represent the row data.

Notes

Internally calls _flush() once the buffer has reached its maximum size.

add_rows(rows)[source]

Adds multiple rows to the internal buffer.

This method iterates over rows, a list of dictionaries, and calls add_row() for each element, which may trigger a flush if the buffer is full.

Parameters:

rows (list of dict) – Each dictionary should have the same structure as a typical row in the final DataFrame.

Notes

This method is merely a convenience layer over add_row().

finalize()[source]

Flushes remaining rows and concatenates all chunks.

Once the remaining rows in _rows are processed into a chunk, this method concatenates all stored chunk DataFrames (either pandas or cudf) into one final DataFrame. The resulting DataFrame is returned.

Returns:

The final DataFrame, which may be a pandas DataFrame or a cudf DataFrame (if processor is set to allow GPU usage and cudf is available).

Return type:

DataFrame

Notes

After concatenation, all chunk DataFrames are cleared from memory. This method is called automatically upon exiting the context (i.e., in __exit__()).

__exit__(exc_type, exc_val, exc_tb)[source]

Exits the context manager. Automatically finalizes the DataFrame by calling finalize(), storing the result in self.final_df.

class geoprior.utils.sys_utils.WorkflowOptimizer(parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True)[source]

Bases: object

WorkflowOptimizer is a decorator class designed to optimize the execution of computationally intensive functions by enabling parallelization, managing CPU and memory resources, and performing cleanup tasks. It provides flexibility through various parameters that allow users to customize optimization strategies according to their workflow requirements.

(31)#\[T_{\text{total}} = T_{\text{start}} + T_{\text{execution}} + T_{\text{cleanup}}\]

Here, \(T_{\text{total}}\) is the total workflow time, \(T_{\text{start}}\) is the initialization time, \(T_{\text{execution}}\) is the main execution time, and \(T_{\text{cleanup}}\) is the cleanup time.
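
Equation (31) decomposes wall-clock time, and each term can be measured directly. The decorator below is a hypothetical sketch, not WorkflowOptimizer itself: it times a start phase, the call, and a cleanup phase that runs gc.collect() (loosely mirroring memory_cleanup=True), and stores the timings on the wrapper.

```python
import functools
import gc
import time


def timed_phases(func):
    """Hypothetical decorator recording the three terms of equation (31)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        t0 = time.perf_counter()
        gc.collect()                      # start phase: stand-in setup work
        t1 = time.perf_counter()
        result = func(*args, **kwargs)    # execution phase
        t2 = time.perf_counter()
        gc.collect()                      # cleanup phase
        t3 = time.perf_counter()
        wrapper.timings = {
            "start": t1 - t0,
            "execution": t2 - t1,
            "cleanup": t3 - t2,
            "total": t3 - t0,
        }
        return result
    return wrapper


@timed_phases
def slow_square(x):
    time.sleep(0.05)  # simulate work
    return x * x


value = slow_square(4)
timings = slow_square.timings
```
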

Parameters:
  • parallelize (bool, optional) – Flag to enable or disable parallel processing. If set to True, the decorator will attempt to parallelize the execution of the decorated function using multiprocessing. Default is True.

  • memory_cleanup (bool, optional) – Whether to clean up system memory after the execution of the decorated function. This includes triggering garbage collection and clearing GPU caches if applicable. Default is False.

  • log_level (int, optional) – Level of logging verbosity. Accepts standard logging levels such as logging.INFO, logging.DEBUG, etc. Default is logging.INFO.

  • optimize_cpu (bool, optional) – Whether to optimize CPU usage by setting CPU affinity to restrict the process to specific CPU cores. If True, the decorator will bind the process to the cores specified in cpu_cores. Default is True.

  • num_processes (int, optional) – The number of parallel processes to use when parallelize is enabled. If not specified, it defaults to the minimum of the number of available CPU cores and the length of the data iterable passed to the function. Default is None.

  • cpu_cores (list or None, optional) – A list of specific CPU cores to bind the process to for optimized CPU usage. If None, the process is allowed to run on all available CPU cores. Example: [0, 1, 2, 3]. Default is None.

  • verbose (bool, optional) – Whether to print detailed logs during execution. If set to False, only essential information will be logged based on the log_level. Default is True.

Examples

>>> import logging
>>> import time
>>> from geoprior.utils.sys_utils import WorkflowOptimizer
>>>
>>> @WorkflowOptimizer(
...     parallelize=True,
...     memory_cleanup=True,
...     log_level=logging.DEBUG,
...     num_processes=4,
...     cpu_cores=[0, 1, 2, 3],
...     verbose=True
... )
... def process_data(data_chunk):
...     '''Simulate a time-consuming data processing function.'''
...     time.sleep(1)  # Simulate a time-consuming task
...     return f"Processed {data_chunk}"
...
>>> data_chunks = ['chunk1', 'chunk2', 'chunk3', 'chunk4']
>>> results = process_data(data=data_chunks)
>>> print(results)
['Processed chunk1', 'Processed chunk2', 'Processed chunk3', 'Processed chunk4']

Notes

The decorator checks for the presence of a data keyword argument to decide whether parallelization should be applied. When parallelize=True, the decorated function should be compatible with multiprocessing, meaning it should be picklable. Memory cleanup can be useful in long-running workflows, and CPU affinity may improve performance by reducing context switching and cache misses. Logging behavior follows the standard Python logging model.
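
The data-keyword dispatch described above can be illustrated with a small decorator. This sketch is an assumption about the pattern, not the WorkflowOptimizer code: map_over_data is a hypothetical name, and it uses a ThreadPoolExecutor instead of a multiprocessing pool so the example has no pickling requirements. As a consequence, pure-Python CPU-bound work would not actually run in parallel here because of the GIL.

```python
import functools
from concurrent.futures import ThreadPoolExecutor


def map_over_data(max_workers=4):
    """Hypothetical decorator: when called with data=..., map func over it."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, data=None, **kwargs):
            if data is None:
                # No data keyword: run the function once, unchanged.
                return func(*args, **kwargs)
            with ThreadPoolExecutor(max_workers=max_workers) as pool:
                # pool.map yields results in input order.
                return list(pool.map(lambda item: func(item, *args, **kwargs), data))
        return wrapper
    return decorator


@map_over_data(max_workers=4)
def process_data(data_chunk):
    return f"Processed {data_chunk}"


results = process_data(data=["chunk1", "chunk2", "chunk3", "chunk4"])
```
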

See also

multiprocessing.Pool

Provides a pool of worker processes.

psutil.Process

Allows manipulation of system processes.

__init__(parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True)[source]
Parameters:
  • parallelize (bool)

  • memory_cleanup (bool)

  • log_level (int)

  • optimize_cpu (bool)

  • num_processes (int | None)

  • cpu_cores (list[int] | None)

  • verbose (bool)

__call__(func)[source]

Makes the class instance callable so it can be used as a decorator.

Parameters:

func (function) – The function to be decorated and optimized.

Returns:

wrapper – The wrapped function with optimization strategies applied.

Return type:

function

geoprior.utils.sys_utils.check_port_in_use(port)[source]

Checks if a port is currently in use, which is useful for server-based applications.

Parameters:

port (int) – The port number to check.

Returns:

True if the port is in use, otherwise False.

Return type:

bool

Examples

>>> from geoprior.utils.sys_utils import check_port_in_use
>>> check_port_in_use(8080)
False
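
A common way to implement such a check is to attempt a TCP connection to the local port; the sketch below assumes that approach, which the source does not confirm. port_in_use is a hypothetical name, and the usage opens a listener on an OS-assigned port so the result is deterministic.

```python
import socket


def port_in_use(port, host="127.0.0.1", timeout=0.5):
    """Return True if something accepts TCP connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success instead of raising an exception.
        return sock.connect_ex((host, port)) == 0


# Usage: bind a listener to a free port chosen by the OS, then probe it.
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))
listener.listen()
busy_port = listener.getsockname()[1]
busy = port_in_use(busy_port)
listener.close()
```

A connect-based probe only detects listening services; attempting to bind the port instead would also catch ports that are held but not listening.
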
geoprior.utils.sys_utils.clean_temp_files(directory=None)[source]

Cleans up temporary files in a specified directory.

Parameters:

directory (str, optional) – The directory to clean up. If None, cleans the default temporary directory.

Return type:

None

Notes

This function is particularly useful for freeing up disk space in data-intensive applications.

Examples

>>> from geoprior.utils.sys_utils import clean_temp_files
>>> clean_temp_files("/path/to/temp/dir")
geoprior.utils.sys_utils.create_temp_dir(prefix='tmp')[source]

Creates a temporary directory and returns its path.

Parameters:

prefix (str, optional) – The prefix for the temporary directory name. Default is “tmp”.

Returns:

dir_path – The full path of the created temporary directory.

Return type:

str

Notes

This function is helpful for managing temporary directories in applications where short-term data storage is needed.

Examples

>>> from geoprior.utils.sys_utils import create_temp_dir
>>> temp_dir = create_temp_dir()
>>> print(temp_dir)
/tmp/tmpabcd1234

See also

create_temp_file

Creates a temporary file.

geoprior.utils.sys_utils.create_temp_file(suffix='', prefix='tmp')[source]

Creates a temporary file and returns its path.

Parameters:
  • suffix (str, optional) – The suffix for the temporary file. Default is an empty string.

  • prefix (str, optional) – The prefix for the temporary file. Default is “tmp”.

Returns:

file_path – The full path of the created temporary file.

Return type:

str

Notes

This function is useful for handling data temporarily in applications where files need to be stored and accessed for a short time.

Examples

>>> from geoprior.utils.sys_utils import create_temp_file
>>> temp_file = create_temp_file()
>>> print(temp_file)
/tmp/tmpabcd1234

See also

create_temp_dir

Creates a temporary directory.
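
Both helpers can be reproduced with the standard-library tempfile module. The sketch assumes that this is the underlying mechanism (the source does not say so); make_temp_dir and make_temp_file are hypothetical names. mkstemp returns an open descriptor together with the path, so the descriptor is closed and only the path is returned, matching the documented return value.

```python
import os
import tempfile


def make_temp_dir(prefix="tmp"):
    # mkdtemp creates the directory securely and returns its absolute path.
    return tempfile.mkdtemp(prefix=prefix)


def make_temp_file(suffix="", prefix="tmp"):
    # mkstemp returns (OS-level descriptor, path); close the descriptor
    # so the caller only has to deal with the path.
    fd, path = tempfile.mkstemp(suffix=suffix, prefix=prefix)
    os.close(fd)
    return path


tmp_dir = make_temp_dir(prefix="data_")
tmp_file = make_temp_file(suffix=".csv", prefix="run_")
```

Neither path is removed automatically; deleting them is the caller's job (see clean_temp_files).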

geoprior.utils.sys_utils.environment_summary()[source]

Provides a summary of the current computing environment, including information on Python version, OS, CPU, memory, available GPU(s), and a list of installed Python packages.

Returns:

env_info – Dictionary containing environment details, including:

  • python_version : The version of Python in use.

  • os : Operating system name.

  • os_version : Version of the operating system.

  • cpu_count : Number of logical CPU cores.

  • memory : Total system memory in GB.

  • device_count, device_name, memory_total, cuda_version (if available) : GPU details from detailed_gpu_info.

  • installed_packages : List of installed Python packages (first 10) in name==version format.

Return type:

dict

Notes

The function attempts to load installed packages using pkg_resources. If pkg_resources is not available, it defaults to “N/A” for installed packages.

Raises:
  • ImportError – If pkg_resources is not installed.

  • RuntimeError – If an error occurs while gathering environment information.

Return type:

dict[str, str]

Examples

>>> from geoprior.utils.sys_utils import environment_summary
>>> env_info = environment_summary()
>>> print(env_info)
{'python_version': '3.9.5', 'os': 'Linux', 'os_version': '5.4.0-80-generic',
 'cpu_count': '4', 'memory': '15.5 GB', 'device_count': '1',
 'device_name': 'NVIDIA Tesla T4', 'memory_total': '15.99 GB',
 'cuda_version': '11.1', 'installed_packages': 'numpy==1.21.0, pandas==1.3.0, ...'}
geoprior.utils.sys_utils.find_by_regex(o, pattern, func=<function match>, **kws)[source]

Find a pattern in an object, whether or not the object is iterable.

Here, a string is not treated as an iterable: it is searched as a single text.

Parameters:
  • o (str or iterable) – A text literal, or an iterable whose items may contain the pattern to match.

  • pattern (str, default = '[_#&*@!_,;\s-]\s*') – The regular-expression pattern to match. The default matches the separators commonly used to split text into columns.

  • func (callable, default re.match) –

    Regular-expression function to apply, such as re.match, re.search, or re.findall, or any other function with the same (pattern, string, ...) call signature.

    • re.match(): checks for a match only at the beginning of the string and returns a match object, or None if the string does not start with the pattern.

    • re.search(): scans the whole string and returns a match object for the first location where the pattern matches, or None if there is no match.

    • re.findall(): returns all non-overlapping matches of the pattern in the string as a list. Unlike re.search(), which stops at the first match, it collects every match in a single call.

  • kws (dict) – Additional keyword arguments passed to func (for example, flags=re.IGNORECASE).

Returns:

om – The matched objects, collected in a list.

Return type:

list

Example

>>> import re
>>> from geoprior.utils.sys_utils import find_by_regex
>>> from geoprior.datasets import load_hlogs
>>> X0, _ = load_hlogs(as_frame=True)
>>> columns = X0.columns
>>> str_columns = ','.join(columns)
>>> find_by_regex(str_columns, pattern='depth', func=re.search)
['depth']
>>> find_by_regex(columns, pattern='depth', func=re.search)
['depth_top', 'depth_bottom']
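
The string-versus-iterable rule can be sketched as follows. find_by_regex_sketch is a hypothetical re-implementation inferred from the parameter descriptions and the example output: a string yields the matched text, while an iterable yields the items that contain a match.

```python
import re


def find_by_regex_sketch(o, pattern, func=re.match, **kws):
    """Hypothetical re-implementation for illustration only."""

    def hits(text):
        result = func(pattern, text, **kws)
        if result is None:
            return []                # re.match / re.search found nothing
        if isinstance(result, list):
            return result            # re.findall already returns a list
        return [result.group()]      # a match object: keep the matched text

    if isinstance(o, str):
        # Strings are searched directly, not iterated character by character.
        return hits(o)
    # Any other iterable: keep the items in which the pattern is found.
    return [item for item in o if hits(str(item))]


columns = ["depth_top", "depth_bottom", "rock_name"]
joined = ",".join(columns)
```
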
geoprior.utils.sys_utils.find_similar_string(name, container, stripitems='_', deep=False)[source]

Find the most similar string in a container to the provided name.

This function searches for the most likely matching string in a container based on the provided name. It sanitizes the name by stripping specified characters and can perform a deep search to find partial matches.

Parameters:
  • name (str) – The string to search for in the container.

  • container (list, tuple, or dict) – The container with strings to search in.

  • stripitems (str or list of str, optional) – Characters or strings to strip from name before searching. If a string, multiple items can be separated by ‘:’, ‘,’, or ‘;’. Default is '_'.

  • deep (bool, optional) – If True, performs a deeper search by checking if name is a substring of any item in the container. Default is False.

Returns:

result – The most similar string from the container, or None if no match is found.

Return type:

str or None

Examples

>>> from geoprior.utils.sys_utils import find_similar_string
>>> container = {'dipole': 1, 'quadrupole': 2}
>>> find_similar_string('dipole_', container)
'dipole'
>>> find_similar_string('dip', container, deep=True)
'dipole'
>>> find_similar_string('+dipole__', container, stripitems='+;__', deep=True)
'dipole'
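
A sketch of the strip-then-match behavior, inferred from the parameter descriptions and examples above. find_similar_sketch is a hypothetical name, and the real function may apply additional heuristics.

```python
import re


def find_similar_sketch(name, container, stripitems="_", deep=False):
    """Hypothetical re-implementation for illustration only."""
    if isinstance(stripitems, str):
        # 'a;b', 'a,b' or 'a:b' -> ['a', 'b']
        items = [s for s in re.split(r"[:,;]", stripitems) if s]
    else:
        items = list(stripitems)
    cleaned = name
    for item in items:
        # str.strip treats its argument as a set of characters.
        cleaned = cleaned.strip(item)
    candidates = [str(c) for c in container]  # a dict contributes its keys
    if cleaned in candidates:
        return cleaned
    if deep:
        for cand in candidates:
            if cleaned in cand:  # partial match: name is a substring
                return cand
    return None


container = {"dipole": 1, "quadrupole": 2}
```
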

Notes

This function is useful when trying to find the closest matching string in a container, especially when exact matches are not guaranteed due to formatting inconsistencies or typos.

See also

str.strip

Returns a copy of the string with leading and trailing characters removed.

geoprior.utils.sys_utils.get_cpu_usage(per_cpu=False)[source]

Returns the current CPU usage as a percentage, optionally providing per-core usage for systems with multiple cores.

Parameters:

per_cpu (bool, default False) – If True, returns a list with the CPU usage percentage for each core. If False, returns the overall CPU usage as a single percentage.

Returns:

usage – If per_cpu is False, returns the overall CPU usage as a float percentage. If per_cpu is True, returns a list with each entry corresponding to the usage percentage of an individual core.

Return type:

float or list of float, optional

Notes

This function uses the psutil library to retrieve CPU usage information and samples usage over a 1-second interval, so each call blocks for about one second.

Examples

>>> from geoprior.utils.sys_utils import get_cpu_usage
>>> get_cpu_usage()
1.3
>>> get_cpu_usage(per_cpu=True)
[20.4, 25.1, 21.3, 24.5]
geoprior.utils.sys_utils.get_disk_usage(path='/')[source]

Returns disk usage statistics for a specified filesystem path, including total, used, and free disk space in gigabytes (GB).

Parameters:

path (str, default '/') – The filesystem path for which to check disk usage statistics. By default, it uses the root directory (/).

Returns:

disk_usage – A tuple containing:

  • total_disk : Total disk space in GB.

  • used_disk : Used disk space in GB.

  • free_disk : Free disk space in GB.

Return type:

tuple of float, optional

Notes

Disk usage information is gathered using the psutil library. Disk space is converted to gigabytes (GB) by dividing the values by 1024^3.

Raises:
  • FileNotFoundError – If the specified path does not exist on the filesystem.

  • PermissionError – If the program does not have permission to access the specified path.

Parameters:

path (str)

Return type:

tuple[float, float, float] | None

Examples

>>> from geoprior.utils.sys_utils import get_disk_usage
>>> total, used, free = get_disk_usage(path="/")
>>> print(f"Total: {total} GB, Used: {used} GB, Free: {free} GB")
Total: 256 GB, Used: 128 GB, Free: 128 GB
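
The same three figures are available from the standard library's shutil.disk_usage, which this sketch uses in place of psutil. disk_usage_gb is a hypothetical helper; dividing by 1024**3 matches the conversion described in the Notes.

```python
import shutil

GB = 1024 ** 3


def disk_usage_gb(path="/"):
    """Return (total, used, free) in GB for the filesystem holding path."""
    usage = shutil.disk_usage(path)  # raises FileNotFoundError for a missing path
    return usage.total / GB, usage.used / GB, usage.free / GB


total, used, free = disk_usage_gb(".")
```

On some filesystems used + free is slightly less than total because of blocks reserved for the superuser.
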
geoprior.utils.sys_utils.get_gpu_info()[source]

Provides detailed information about available GPUs, including device name, memory capacity, and CUDA version (if PyTorch is installed).

Returns:

gpu_info – Dictionary containing GPU details, including:

  • device_count : Number of available GPU devices.

  • device_name : Name of the first GPU device.

  • memory_total : Total memory of the first GPU device in GB.

  • cuda_version : CUDA version, if available.

If no GPU is available or PyTorch is not installed, returns None.

Return type:

dict or None

Notes

This function requires PyTorch to check for GPU availability. If PyTorch is not installed, it logs a warning and returns None.

Raises:
  • ImportError – If PyTorch is not installed on the system.

  • RuntimeError – If there is an issue retrieving GPU properties.

Return type:

dict[str, str] | None

Examples

>>> from geoprior.utils.sys_utils import get_gpu_info
>>> gpu_info = get_gpu_info()
>>> print(gpu_info)
{'device_count': '1', 'device_name': 'NVIDIA Tesla T4',
 'memory_total': '15.99 GB', 'cuda_version': '11.1'}
geoprior.utils.sys_utils.get_installed_packages()[source]

Lists all installed packages along with their versions in the current Python environment.

Returns:

installed_packages – A list of installed packages and their versions in the format package_name==version.

Return type:

list of str

Notes

This function is useful for dependency management and tracking installed packages, especially in data science and production environments.

Examples

>>> from geoprior.utils.sys_utils import get_installed_packages
>>> get_installed_packages()
['numpy==1.21.0', 'pandas==1.3.0', 'scikit-learn==0.24.2', ...]
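
The environment_summary Notes mention pkg_resources, which setuptools now deprecates. As a standard-library alternative (Python 3.8+), importlib.metadata can produce the same name==version list. installed_packages is a hypothetical helper, not necessarily how get_installed_packages works.

```python
from importlib import metadata


def installed_packages():
    """List installed distributions as sorted 'name==version' strings."""
    pkgs = set()
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:  # skip distributions with broken metadata
            pkgs.add(f"{name}=={dist.version}")
    return sorted(pkgs, key=str.lower)


pkgs = installed_packages()
first_ten = ", ".join(pkgs[:10])  # the slice environment_summary reports
```
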

See also

environment_summary

Summarizes the environment, including installed packages.

geoprior.utils.sys_utils.get_memory_usage()[source]

Retrieves system memory usage statistics, providing the total, used, and available memory in megabytes (MB).

Returns:

memory – A tuple containing:

  • total_memory : Total memory in MB.

  • used_memory : Used memory in MB.

  • available_memory : Available memory in MB.

Return type:

tuple of float

Notes

This function leverages the psutil library for retrieving memory usage information. The conversion to MB is performed by dividing each value by 1024^2.

Examples

>>> from geoprior.utils.sys_utils import get_memory_usage
>>> total, used, available = get_memory_usage()
>>> print(f"Total: {total} MB, Used: {used} MB, Available: {available} MB")
Total: 8192 MB, Used: 4096 MB, Available: 4096 MB
geoprior.utils.sys_utils.get_python_version()[source]

Returns the version of Python being used in the current environment.

Returns:

python_version – The version of Python currently in use.

Return type:

str

Examples

>>> from geoprior.utils.sys_utils import get_python_version
>>> get_python_version()
'3.8.5'

See also

get_system_info

Provides broader system information, including Python version.

geoprior.utils.sys_utils.get_system_info()[source]

Retrieves basic system information including OS, Python version, CPU details, and GPU availability.

Returns:

system_info – A dictionary containing basic system information:

  • os_name : Name of the operating system.

  • os_version : Version of the operating system.

  • python_version : Python version.

  • cpu_count : Number of logical CPUs.

  • gpu_available : Whether a GPU is available (True or False).

Return type:

dict

Notes

This function checks for GPU availability via PyTorch if installed, otherwise it defaults to False.

Examples

>>> from geoprior.utils.sys_utils import get_system_info
>>> get_system_info()
{'os_name': 'Linux', 'os_version': '5.4.0-81-generic', 'python_version': '3.8.5',
 'cpu_count': '8', 'gpu_available': 'True'}
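
Except for gpu_available, which needs PyTorch, the fields above map directly onto the standard-library platform and os modules. The sketch assumes that mapping; system_info is a hypothetical helper.

```python
import os
import platform


def system_info():
    """Basic OS, Python, and CPU facts from the standard library."""
    return {
        "os_name": platform.system(),       # e.g. 'Linux', 'Windows', 'Darwin'
        "os_version": platform.release(),   # kernel / OS release string
        "python_version": platform.python_version(),
        "cpu_count": os.cpu_count(),        # logical CPUs; may be None
    }


info = system_info()
```
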

See also

get_python_version

Retrieves the current Python version.

geoprior.utils.sys_utils.get_uptime()[source]

Returns the system uptime in a human-readable format.

Returns:

uptime – The system uptime formatted as “Xd:Yh:Zm:Ws”, where X, Y, Z, and W are days, hours, minutes, and seconds respectively.

Return type:

str

Notes

This function is useful for monitoring or diagnosing long-running processes on the system.

Examples

>>> from geoprior.utils.sys_utils import get_uptime
>>> get_uptime()
'2d:5h:34m:12s'
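
The Xd:Yh:Zm:Ws format is a chain of integer divisions. The sketch below formats a duration given in seconds; how the uptime itself is obtained (for example, from psutil.boot_time()) is platform-specific and not covered here, and format_uptime is a hypothetical name.

```python
def format_uptime(seconds):
    """Format a duration in seconds as 'Xd:Yh:Zm:Ws'."""
    seconds = int(seconds)
    days, rem = divmod(seconds, 86_400)
    hours, rem = divmod(rem, 3_600)
    minutes, secs = divmod(rem, 60)
    return f"{days}d:{hours}h:{minutes}m:{secs}s"


example = format_uptime(2 * 86_400 + 5 * 3_600 + 34 * 60 + 12)
```
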
geoprior.utils.sys_utils.is_gpu_available()[source]

Checks if a GPU is available for computation on the system, using the PyTorch library if it is installed.

Returns:

available – True if a GPU is available, False otherwise.

Return type:

bool

Notes

This function relies on the torch library (PyTorch) to detect GPU availability. If PyTorch is not installed, it logs a warning and returns False.

Raises:

ImportError – If PyTorch is not installed and thus the GPU availability check cannot be performed.

Return type:

bool

Examples

>>> from geoprior.utils.sys_utils import is_gpu_available
>>> is_gpu_available()
True
geoprior.utils.sys_utils.is_package_installed(package_name)[source]

Checks if a specific package is installed in the current Python environment.

Parameters:

package_name (str) – The name of the package to check.

Returns:

True if the package is installed, otherwise False.

Return type:

bool

Examples

>>> from geoprior.utils.sys_utils import is_package_installed
>>> is_package_installed("numpy")
True
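
A lightweight check uses importlib.util.find_spec, which locates a top-level module without importing it. This sketch (is_importable is a hypothetical name) assumes the argument is an import name. Distribution names can differ from import names (scikit-learn installs the module sklearn), in which case importlib.metadata.version, which raises PackageNotFoundError for unknown distributions, is the better test.

```python
import importlib.util


def is_importable(name):
    """True if a top-level module or package named `name` can be found."""
    return importlib.util.find_spec(name) is not None


checks = {name: is_importable(name) for name in ("json", "not_a_real_pkg_12345")}
```
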
geoprior.utils.sys_utils.is_path_accessible(path, permissions='r')[source]

Checks if a specified path is accessible with the given permissions.

Parameters:
  • path (str) – The path to check for accessibility.

  • permissions (str, optional) – The permission types to check: ‘r’ for read, ‘w’ for write, ‘x’ for execute. Multiple permissions can be specified, e.g., “rw”. Default is “r”.

Returns:

accessible – True if the path is accessible with the specified permissions, otherwise False.

Return type:

bool

Notes

Permissions are checked for the current user, so in a multi-user environment the same path can be accessible to one user and not to another.

Examples

>>> from geoprior.utils.sys_utils import is_path_accessible
>>> is_path_accessible("/path/to/file", permissions="rw")
True
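
The permission string maps naturally onto os.access flags. This is a sketch under that assumption; path_accessible is a hypothetical name. Note that os.access reports the permissions of the real (not effective) user, which matters for setuid programs.

```python
import os
import tempfile

_MODE_FLAGS = {"r": os.R_OK, "w": os.W_OK, "x": os.X_OK}


def path_accessible(path, permissions="r"):
    """True if path exists and the current user has every requested permission."""
    if not os.path.exists(path):
        return False
    mode = 0
    for ch in permissions:
        if ch not in _MODE_FLAGS:
            raise ValueError(f"unknown permission {ch!r}; use 'r', 'w', or 'x'")
        mode |= _MODE_FLAGS[ch]  # combine flags, e.g. "rw" -> R_OK | W_OK
    return os.access(path, mode)


with tempfile.NamedTemporaryFile(delete=False) as fh:
    sample = fh.name
readable_writable = path_accessible(sample, "rw")
os.remove(sample)
missing = path_accessible(sample, "r")
```
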
geoprior.utils.sys_utils.is_port_open(port)[source]

Checks if a specified network port is open or occupied on the local machine.

Parameters:

port (int) – The port number to check for availability.

Returns:

Returns True if the port is open (not in use), otherwise False.

Return type:

bool

Notes

This function uses a socket connection to check if the specified port is open. It is helpful in applications where network services or applications need to bind to a specific port.

Raises:

ValueError – If an invalid port number is provided.

Parameters:

port (int)

Return type:

bool

Examples

>>> from geoprior.utils.sys_utils import is_port_open
>>> is_port_open(8080)
False
geoprior.utils.sys_utils.manage_env_variable(var_name, value=None, default=None, action='get', file_path=None, overwrite=False)[source]

Manages environment variables, allowing retrieval, setting, or loading from a .env file.

Parameters:
  • var_name (str) – The name of the environment variable to retrieve, set, or load.

  • value (str, optional) – The value to set for the environment variable. Only used if action is “set”. Default is None.

  • default (str, optional) – The default value to return if the environment variable var_name is not found when action is “get”. If None, returns None when the variable is not found. Default is None.

  • action (str, default "get") –

    The action to perform. Options are:
    • “get”: Retrieves the environment variable var_name.

    • “set”: Sets the environment variable var_name to value.

    • “load”: Loads environment variables from a .env file specified by file_path.

  • file_path (str, optional) – The path to the .env file to load variables from when action is “load”. Required if action is “load”.

  • overwrite (bool, default False) – If True, allows overwriting existing environment variables when action is “load” or “set”. If False, preserves the current value of existing environment variables.

Returns:

result

  • If action is “get”, returns the value of the environment variable var_name or default if the variable is not set.

  • If action is “set” or “load”, returns None.

Return type:

str or None

Notes

  • This function is useful for managing configuration data securely by utilizing environment variables.

  • Loading from a .env file allows you to define multiple variables in a single file, each defined in the KEY=VALUE format.

Raises:
  • ValueError – If action is “set” and value is not provided, or if action is “load” and file_path is not specified.

  • FileNotFoundError – If action is “load” and file_path does not exist.

Parameters:
  • var_name (str)

  • value (str | None)

  • default (str | None)

  • action (str)

  • file_path (str | None)

  • overwrite (bool)

Return type:

str | None

Examples

>>> from geoprior.utils.sys_utils import manage_env_variable
>>> manage_env_variable('HOME', action='get')
'/home/username'
>>> manage_env_variable('NEW_VAR', value='new_value', action='set')
>>> manage_env_variable('NEW_VAR', action='get')
'new_value'
>>> manage_env_variable('NON_EXISTENT_VAR', default='default_value', action='get')
'default_value'
>>> manage_env_variable(
...     var_name='', action='load', file_path='/path/to/.env', overwrite=True)
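
Loading a .env file with the overwrite rule described above can be sketched with a minimal KEY=VALUE parser. load_dotenv_sketch is hypothetical and skips features such as variable interpolation and multi-line values. The usage writes into a plain dict so the real process environment is untouched. The third-party python-dotenv package is a fuller alternative.

```python
import os
import tempfile


def load_dotenv_sketch(file_path, overwrite=False, environ=os.environ):
    """Load KEY=VALUE lines into `environ` (hypothetical, minimal parser)."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(file_path)
    with open(file_path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue  # skip blanks, comments, and malformed lines
            key, value = line.split("=", 1)
            key, value = key.strip(), value.strip().strip("'\"")
            if overwrite or key not in environ:
                environ[key] = value  # existing keys kept unless overwrite=True


env = {"KEEP": "original"}
with tempfile.NamedTemporaryFile("w", suffix=".env", delete=False,
                                 encoding="utf-8") as fh:
    fh.write("# comment\nKEEP=new\nAPI_URL='http://localhost'\n")
    dotenv_path = fh.name
load_dotenv_sketch(dotenv_path, overwrite=False, environ=env)
os.remove(dotenv_path)
```
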

See also

os.getenv

Retrieves environment variables.

os.environ

Provides access to the environment variables.

geoprior.utils.sys_utils.manage_file_lock(file_path, action='lock', blocking=True, exclusive=True)[source]

Manages file locking and unlocking to prevent concurrent access.

This function allows both locking and unlocking actions on a file to prevent or allow concurrent access. It opens the file and applies an exclusive lock or shared lock, depending on the parameters specified.

Parameters:
  • file_path (str) – Path to the file that needs to be locked or unlocked.

  • action (str, default "lock") – Specifies the action to perform: “lock” to acquire a lock, or “unlock” to release a previously acquired lock.

  • blocking (bool, default True) – If True, the lock will block until it can be acquired. If False, the lock will raise an exception if it cannot be acquired immediately.

  • exclusive (bool, default True) – If True, an exclusive lock is applied. If False, a shared lock is applied (other processes can read the file simultaneously).

Returns:

file_descriptor – If action is “lock”, returns the file descriptor on success; otherwise, None if action is “unlock” or if locking fails.

Return type:

int or None

Notes

This function uses the fcntl module for locking, which is only available on Unix-based systems. The lock is maintained as long as the file descriptor remains open.

  • For “lock”, the function opens the file and applies a lock.

  • For “unlock”, it removes the lock and closes the file descriptor.

Raises:
  • ValueError – If the action parameter is not one of “lock” or “unlock”.

  • OSError – If locking or unlocking the file fails.

Return type:

int | None

Examples

>>> from geoprior.utils.sys_utils import manage_file_lock
>>> fd = manage_file_lock("/path/to/file", action="lock", blocking=True)
>>> if fd is not None:
...     print("File is locked.")
...     manage_file_lock(fd, action="unlock")
...     print("File is unlocked.")

See also

os.open

Opens a file descriptor.

fcntl.flock

Applies or removes file locks.
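
The locking behaviour described above can be sketched with the standard-library fcntl.flock call. This is a minimal, Unix-only sketch under stated assumptions, not the geoprior implementation: the helper names lock_file and unlock_file are hypothetical, and the sketch returns None only when a non-blocking request finds the lock already held.

```python
import fcntl
import os
from typing import Optional


def lock_file(file_path: str, blocking: bool = True,
              exclusive: bool = True) -> Optional[int]:
    """Hypothetical helper: open ``file_path`` and apply an advisory flock.

    Unix-only sketch. Returns the open file descriptor, or None when a
    non-blocking request finds the lock already held elsewhere.
    """
    flags = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
    if not blocking:
        flags |= fcntl.LOCK_NB  # fail immediately instead of waiting
    fd = os.open(file_path, os.O_RDWR | os.O_CREAT)
    try:
        fcntl.flock(fd, flags)
    except BlockingIOError:
        os.close(fd)  # a conflicting lock is held on another descriptor
        return None
    except BaseException:
        os.close(fd)  # do not leak the descriptor on other errors
        raise
    return fd


def unlock_file(fd: int) -> None:
    """Hypothetical helper: release the lock and close the descriptor."""
    try:
        fcntl.flock(fd, fcntl.LOCK_UN)
    finally:
        os.close(fd)
```

Because flock locks belong to the open file description, closing the descriptor also releases the lock, which matches the note that the lock lasts only while the descriptor stays open.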

geoprior.utils.sys_utils.manage_temp(suffix='', prefix='tmp', action='create_file', directory=None, clean_all=False)[source]

Manages temporary files and directories by creating, accessing, or cleaning them as needed.

Parameters:
  • suffix (str, optional) – Suffix for the temporary file or directory, used only if action is “create_file” or “create_dir”. Default is an empty string.

  • prefix (str, optional) – Prefix for the temporary file or directory, used only if action is “create_file” or “create_dir”. Default is “tmp”.

  • action (str, default "create_file") –

    Specifies the operation to perform. Options include:
    • “create_file”: Creates a temporary file and returns its path.

    • “create_dir”: Creates a temporary directory and returns its path.

    • “clean”: Cleans temporary files in the specified directory or the system temp directory if none is provided.

  • directory (str, optional) – Directory to clean when action is “clean”. If None, uses the system’s default temporary directory. Ignored for file or directory creation actions.

  • clean_all (bool, default False) – If True, removes all files and directories within the specified directory. If False, only deletes files or directories created by this process. Used only when action is “clean”.

Returns:

temp_path

  • For “create_file” and “create_dir” actions, returns the path of the created file or directory.

  • For “clean”, returns None.

Return type:

str or None

Raises:

ValueError – If an invalid action is specified.

Notes

This function is useful for managing temporary resources in data processing tasks, where files or directories need to be created and cleaned up after use.

Examples

>>> from geoprior.utils.sys_utils import manage_temp
>>> temp_file = manage_temp(action="create_file")
>>> print(temp_file)
/tmp/tmpabcd1234
>>> temp_dir = manage_temp(action="create_dir", prefix="data_")
>>> print(temp_dir)
/tmp/data_abcd1234
>>> manage_temp(action="clean", directory="/path/to/temp", clean_all=True)

See also

tempfile

Module for creating temporary files and directories.

shutil

High-level file operations.
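
A minimal sketch of the three actions using tempfile and shutil follows. It is an illustration under assumptions, not the library code: manage_temp_sketch and the module-level _CREATED registry (used to track which paths "this process" created) are hypothetical, and, as a deliberate safety deviation, the sketch refuses clean_all=True unless an explicit directory is given, so it can never empty the system temp directory by accident.

```python
import os
import shutil
import tempfile
from typing import List, Optional

# Hypothetical registry of paths created by this process.
_CREATED: List[str] = []


def manage_temp_sketch(suffix: str = "", prefix: str = "tmp",
                       action: str = "create_file",
                       directory: Optional[str] = None,
                       clean_all: bool = False) -> Optional[str]:
    if action == "create_file":
        fd, path = tempfile.mkstemp(suffix=suffix, prefix=prefix)
        os.close(fd)  # keep the file on disk, release the OS handle
        _CREATED.append(path)
        return path
    if action == "create_dir":
        path = tempfile.mkdtemp(suffix=suffix, prefix=prefix)
        _CREATED.append(path)
        return path
    if action == "clean":
        if clean_all and directory is None:
            raise ValueError("clean_all=True requires an explicit directory")
        root = os.path.abspath(directory or tempfile.gettempdir())
        if clean_all:
            targets = [os.path.join(root, name) for name in os.listdir(root)]
        else:
            # only paths this process created directly inside ``root``
            targets = [p for p in _CREATED if os.path.dirname(p) == root]
        for path in targets:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            elif os.path.lexists(path):
                os.remove(path)
            if path in _CREATED:
                _CREATED.remove(path)
        return None
    raise ValueError(f"Invalid action {action!r}; expected 'create_file', "
                     "'create_dir' or 'clean'.")
```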

geoprior.utils.sys_utils.parallelize_jobs(function, tasks=(), n_jobs=None, executor_type='process')[source]

Parallelize the execution of a callable across multiple processors, supporting both positional and keyword arguments.

Parameters:
  • function (Callable[..., Any]) – The function to execute in parallel. This function must be picklable if using executor_type='process'.

  • tasks (Sequence[Dict[str, Any]], optional) – A sequence of dictionaries, where each dictionary contains two keys: ‘args’ (a tuple) for positional arguments, and ‘kwargs’ (a dict) for keyword arguments, for one execution of function. Defaults to an empty sequence.

  • n_jobs (Optional[int], optional) – The number of jobs to run in parallel. None or 1 uses a single processor, any positive integer specifies the exact number of processors to use, -1 uses all available processors. Default is None (1 processor).

  • executor_type (str, optional) – The type of executor to use. Can be ‘process’ for CPU-bound tasks or ‘thread’ for I/O-bound tasks. Default is ‘process’.

Returns:

A list of results from the function executions.

Return type:

list

Raises:

ValueError – If function is not picklable when using ‘process’ as executor_type.

Examples

>>> from geoprior.utils.sys_utils import parallelize_jobs
>>> def greet(name, greeting='Hello'):
...     return f"{greeting}, {name}!"
>>> tasks = [
...     {'args': ('John',), 'kwargs': {'greeting': 'Hi'}},
...     {'args': ('Jane',), 'kwargs': {}}
... ]
>>> results = parallelize_jobs(greet, tasks, n_jobs=2)
>>> print(results)
['Hi, John!', 'Hello, Jane!']
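
The task format above (one dict per call with 'args' and 'kwargs') maps naturally onto concurrent.futures. The sketch below is an assumption-labelled illustration, not the geoprior implementation; parallelize_jobs_sketch and _call_task are hypothetical names. _call_task is defined at module level so that a process pool can pickle it by reference.

```python
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Sequence


def _call_task(function: Callable[..., Any], task: Dict[str, Any]) -> Any:
    # Unpack one task dict into positional and keyword arguments.
    return function(*task.get("args", ()), **task.get("kwargs", {}))


def parallelize_jobs_sketch(function: Callable[..., Any],
                            tasks: Sequence[Dict[str, Any]] = (),
                            n_jobs: Optional[int] = None,
                            executor_type: str = "process") -> List[Any]:
    pools = {"process": ProcessPoolExecutor, "thread": ThreadPoolExecutor}
    if executor_type not in pools:
        raise ValueError(f"executor_type must be 'process' or 'thread', "
                         f"got {executor_type!r}")
    if n_jobs is None or n_jobs == 1:
        return [_call_task(function, t) for t in tasks]  # sequential path
    workers = (os.cpu_count() or 1) if n_jobs == -1 else n_jobs
    if workers < 1:
        raise ValueError("n_jobs must be None, -1 or a positive integer")
    if executor_type == "process":
        try:
            pickle.dumps(function)  # process pools must pickle the callable
        except Exception as exc:
            raise ValueError("function must be picklable for "
                             "executor_type='process'") from exc
    with pools[executor_type](max_workers=workers) as pool:
        futures = [pool.submit(_call_task, function, t) for t in tasks]
        return [f.result() for f in futures]  # results keep task order
```

With the spawn start method (the default on Windows and macOS), the function must also be importable from a module, which is why the thread executor is the safer choice for quick interactive experiments.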

geoprior.utils.sys_utils.represent_callable(obj, skip=None)[source]

Represent callable objects by formatting their signatures.

This function generates a string representation of a callable object’s signature, including parameters and default values. It supports classes, functions, and instance methods.

Parameters:
  • obj (callable) – The callable object to format.

  • skip (str or list of str, optional) – Parameter names to skip in the representation. Useful for omitting certain attributes.

Returns:

representation – A string representation of the callable object’s signature.

Return type:

str

Raises:

TypeError – If obj is not a callable object.

Examples

>>> from geoprior.utils.sys_utils import represent_callable
>>> def example_function(a, b=2):
...     pass
>>> represent_callable(example_function)
'example_function(a, b=2)'
>>> class ExampleClass:
...     def __init__(self, x, y=10):
...         self.x = x
...         self.y = y
>>> represent_callable(ExampleClass)
'ExampleClass(x, y=10)'
>>> instance = ExampleClass(5)
>>> represent_callable(instance)
'ExampleClass(x=5, y=10)'

Notes

This function is useful for logging or displaying the parameters of callable objects in a readable format.

See also

inspect.signature

Get a signature object for the callable.
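
The behaviour in the examples can be approximated with inspect.signature. This is a minimal sketch, assuming that instances are rendered from their constructor parameters and current attribute values; represent_callable_sketch is a hypothetical name, and the real function may differ in edge cases.

```python
import inspect
from typing import Any, Iterable, Optional, Union


def represent_callable_sketch(obj: Any,
                              skip: Optional[Union[str, Iterable[str]]] = None) -> str:
    skipped = {skip} if isinstance(skip, str) else set(skip or ())
    if inspect.isclass(obj) or inspect.isroutine(obj):
        # Classes and functions: render the declared signature.
        params = [str(p) for name, p in inspect.signature(obj).parameters.items()
                  if name not in skipped]
        return f"{obj.__name__}({', '.join(params)})"
    if not hasattr(obj, "__dict__"):
        raise TypeError(f"Cannot represent object of type {type(obj).__name__!r}")
    # Instances: pair constructor parameter names with current attribute values.
    cls = type(obj)
    parts = [f"{name}={getattr(obj, name)!r}"
             for name in inspect.signature(cls).parameters
             if name not in skipped and hasattr(obj, name)]
    return f"{cls.__name__}({', '.join(parts)})"
```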

geoprior.utils.sys_utils.run_command(command, capture_output=True)[source]

Runs a shell command and optionally captures its output.

Parameters:
  • command (str) – The shell command to execute.

  • capture_output (bool, default True) – If True, captures and returns the command’s output. If False, runs the command without capturing output, which is useful for commands that produce a large output or run interactively.

Returns:

output – Returns the command output as a string if capture_output is True. If capture_output is False, returns None.

Return type:

str or None

Notes

This function uses subprocess.run to execute shell commands, which allows for error handling and logging. For example, run_command("echo Hello World") returns "Hello World\n" when capture_output=True.

Raises:

subprocess.CalledProcessError – If the command exits with a non-zero status and capture_output is True.
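
The subprocess.run call described in the Notes can be sketched as follows. This is a hedged sketch, not the library source: run_command_sketch is a hypothetical name, and it assumes shell=True because the command is documented as a shell string. Shell execution means the string is interpreted by /bin/sh (or cmd.exe on Windows), so it should never contain untrusted input.

```python
import subprocess
from typing import Optional


def run_command_sketch(command: str, capture_output: bool = True) -> Optional[str]:
    # check=capture_output mirrors the documented behaviour: a non-zero
    # exit status raises CalledProcessError only when output is captured.
    result = subprocess.run(
        command,
        shell=True,  # interpret ``command`` as a shell string
        check=capture_output,
        capture_output=capture_output,
        text=True,  # decode stdout/stderr to str
    )
    return result.stdout if capture_output else None
```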


geoprior.utils.sys_utils.safe_getattr(obj, name, default_value=None)[source]

Safely get an attribute from an object, with a helpful error message.

This function attempts to retrieve an attribute from the given object. If the attribute is not found, it can return a default value or raise an AttributeError with a suggestion for a similar attribute.

Parameters:
  • obj (object) – The object from which to retrieve the attribute.

  • name (str) – The name of the attribute to retrieve.

  • default_value (any, optional) – A default value to return if the attribute is not found. If None, an AttributeError will be raised.

Returns:

value – The value of the retrieved attribute or the default value.

Return type:

any

Raises:

AttributeError – If the attribute is not found and no default value is provided.

Examples

>>> from geoprior.utils.sys_utils import safe_getattr
>>> class MyClass:
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
>>> obj = MyClass(1, 2)
>>> safe_getattr(obj, 'a')
1
>>> safe_getattr(obj, 'c', default_value='default')
'default'
>>> safe_getattr(obj, 'c')
Traceback (most recent call last):
    ...
AttributeError: 'MyClass' object has no attribute 'c'. Did you mean 'a'?

Notes

This function enhances the built-in getattr by providing helpful suggestions when an attribute is not found.

See also

getattr

Built-in function to get an attribute from an object.
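
The "Did you mean …?" suggestion can be produced with difflib.get_close_matches from the standard library. The sketch below is illustrative only: safe_getattr_sketch is a hypothetical name, and relying on difflib's default similarity cutoff (0.6) is an assumption, so very short names such as 'c' may receive no suggestion at all.

```python
import difflib
from typing import Any


def safe_getattr_sketch(obj: Any, name: str, default_value: Any = None) -> Any:
    try:
        return getattr(obj, name)
    except AttributeError:
        if default_value is not None:
            return default_value
        # Suggest the closest public attribute name, if any is similar enough.
        public = [attr for attr in dir(obj) if not attr.startswith("_")]
        close = difflib.get_close_matches(name, public, n=1)
        hint = f" Did you mean {close[0]!r}?" if close else ""
        raise AttributeError(
            f"{type(obj).__name__!r} object has no attribute {name!r}.{hint}"
        ) from None
```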

geoprior.utils.sys_utils.safe_optimize(func=None, *, parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True, mode='strict')[source]

Optimizes a workflow by wrapping a function to measure execution time, enable parallelization, manage resources, and perform memory cleanup. It behaves similarly to the class-based decorator WorkflowOptimizer.

Class-based decorators can sometimes encounter issues when trying to pickle certain objects, especially in parallel execution contexts. This issue arises because certain objects (such as file handles, open network connections, or non-serializable class instances) cannot be passed between processes in multiprocessing environments. By ensuring compatibility with these contexts, safe_optimize helps mitigate such issues and optimize the execution of computationally intensive workflows.

This decorator is particularly suitable for workflows involving large-scale computations, such as data processing pipelines, machine learning model training, or simulations, where parallel execution and resource optimization are crucial for performance improvement.

Parameters:
  • parallelize (bool, optional) – Flag to enable or disable parallel processing (default is True).

  • memory_cleanup (bool, optional) – Whether to clean up system memory after execution (default is False).

  • log_level (int, optional) – Level of logging (default is logging.INFO). Set to logging.DEBUG for more detailed logs.

  • optimize_cpu (bool, optional) – Whether to optimize CPU core usage (default is True).

  • num_processes (Optional[int], optional) – The number of parallel processes for execution (default is None).

  • cpu_cores (Optional[List[int]], optional) – Specify a list of CPU cores to restrict the process (default is None).

  • verbose (bool, optional) – Whether to print detailed logs during execution (default is True).

  • mode (str, optional) – Mode for handling pickling issues: 'strict' to raise errors, or 'soft' to fall back to sequential execution with warnings (default is 'strict').

  • func (Callable | None)

Returns:

decorator – The wrapped function that includes optimization strategies.

Return type:

Callable

Raises:

ValueError – If an unsupported mode is specified.

Examples

>>> import logging
>>> from geoprior.utils.sys_utils import safe_optimize
>>> @safe_optimize(
...     parallelize=True,
...     memory_cleanup=True,
...     log_level=logging.DEBUG,
...     optimize_cpu=True,
...     num_processes=4,
...     cpu_cores=[0, 1, 2, 3],
...     verbose=True,
...     mode='soft'
... )
... def process_data(data):
...     # Your data processing logic here
...     return [d * 2 for d in data]
>>> data = [1, 2, 3, 4, 5]
>>> results = process_data(data)
>>> print(results)
[2, 4, 6, 8, 10]

Notes

  • This decorator uses multiprocessing for parallel execution, which may not be suitable for all environments, especially those that do not support the fork start method (e.g., Windows, which only supports spawn).

  • Ensure that the decorated function and its arguments are picklable when using parallelization.

  • The mode parameter allows handling non-picklable objects gracefully.

See also

multiprocessing.Pool

For parallel task execution.

psutil

For system and process utilities.

functools.wraps

For preserving metadata of decorated functions.
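
The signature safe_optimize(func=None, *, ...) follows the common pattern for decorators that work both bare (@deco) and with arguments (@deco(...)). The sketch below shows only that pattern plus timing and optional garbage collection; the name timed is hypothetical, and it deliberately omits the parallelization, CPU-affinity and pickling-mode features of safe_optimize.

```python
import functools
import gc
import logging
import time
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def timed(func: Optional[Callable] = None, *, memory_cleanup: bool = False,
          log_level: int = logging.INFO):
    """Hypothetical decorator usable both as ``@timed`` and ``@timed(...)``."""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                elapsed = time.perf_counter() - start
                logger.log(log_level, "%s finished in %.3f s", fn.__name__, elapsed)
                if memory_cleanup:
                    gc.collect()  # reclaim unreachable objects after the call
        return wrapper

    if func is not None:  # used bare: @timed
        return decorator(func)
    return decorator  # used with arguments: @timed(...)
```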

geoprior.utils.sys_utils.system_uptime()[source]

Retrieves the system uptime, which is the duration the system has been running since the last boot, in a human-readable format.

Returns:

uptime – System uptime in the format “Xd:Yh:Zm:Ws”, where X, Y, Z, and W represent days, hours, minutes, and seconds, respectively.

Return type:

str

Notes

This function is cross-platform and works on Windows, macOS (Darwin), and Linux. It uses different commands to retrieve uptime based on the operating system.


Examples

>>> from geoprior.utils.sys_utils import system_uptime
>>> system_uptime()
'2d:10h:33m:12s'
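
The "Xd:Yh:Zm:Ws" format amounts to repeated divmod on a number of seconds. The sketch below is an assumption-labelled illustration with hypothetical names: format_uptime does the formatting, and linux_uptime_seconds reads the first field of /proc/uptime, which on Linux holds the seconds since boot. Other platforms need a different source (for instance psutil.boot_time()).

```python
def format_uptime(seconds: float) -> str:
    """Format a duration in seconds as 'Xd:Yh:Zm:Ws' (fractions dropped)."""
    total = int(seconds)
    days, rem = divmod(total, 86_400)
    hours, rem = divmod(rem, 3_600)
    minutes, secs = divmod(rem, 60)
    return f"{days}d:{hours}h:{minutes}m:{secs}s"


def linux_uptime_seconds(path: str = "/proc/uptime") -> float:
    """Read seconds since boot from /proc/uptime (Linux only)."""
    with open(path) as fh:
        return float(fh.read().split()[0])
```

For example, format_uptime(210_792) gives '2d:10h:33m:12s', the value shown in the example above.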

geoprior.utils.sys_utils.build_large_df(forecast_results, dt_col, tname, spatial_cols=None, chunk_size=None, verbose=0)[source]

Construct a memory-optimized DataFrame from large forecast results using chunked processing.

Implements dynamic chunk sizing and dtype optimization to handle datasets exceeding available memory. Uses temporary storage and parallel processing for efficient resource utilization.

If pyarrow is installed, the function uses Parquet I/O; otherwise, it falls back to CSV files.

Parameters:
  • forecast_results (List[Dict]) – Input data as list of dictionary records. Each dictionary represents a row with column-value pairs. Minimum 1000 entries recommended for chunking benefits.

  • dt_col (str) – Name of temporal column. Accepts numeric years (e.g., 2023) or datetime strings. Automatic type detection with fallback to numpy.int32 for years >200000.

  • tname (str) – Target variable prefix for prediction columns. Quantile columns are expected in the form f"{tname}_q{quantile}" such as "subs_q10", while point predictions use f"{tname}_pred".

  • spatial_cols (List[str], optional) – Geographic columns (e.g., ['longitude', 'latitude']). Auto-detects categorical (<10% unique values) vs. continuous spatial data, using the pandas category dtype or numpy.float32, respectively.

  • chunk_size (int, optional) –

    Maximum rows per chunk. Auto-calculated using

    \[C_{\text{optimal}} = \min\left(10^5, \frac{0.8\,M_{\text{free}}}{S_{\text{row}}}\right)\]

    where \(M_{\text{free}}\) is the available memory in bytes and \(S_{\text{row}}\) is the estimated row size, assumed to be about 1 KB by default.

  • verbose (int, default 0) – Logging verbosity. Use 0 for silent mode, 1 for memory reports, 2 for chunk diagnostics, and 3 for per-chunk metrics.
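
As a worked instance of the chunk-size formula, the sketch below evaluates C_optimal in integer arithmetic. The function name optimal_chunk_size is hypothetical; in practice the free-memory figure would come from psutil.virtual_memory() (for example its available attribute), which the Notes say the implementation consults.

```python
def optimal_chunk_size(free_bytes: int, row_bytes: int = 1024,
                       cap: int = 100_000) -> int:
    """C_optimal = min(10**5, 0.8 * M_free / S_row), floored, at least 1."""
    by_memory = (free_bytes * 8) // (10 * row_bytes)  # 0.8 * M_free / S_row
    return max(1, min(cap, by_memory))
```

With 8 GiB free and 1 KB rows, the memory term is about 6.7 million rows, so the 10^5 cap wins; with only 50 MiB free, the memory term (40,960 rows) becomes the limit.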

Returns:

Combined DataFrame with optimized dtypes, preserving original column order. Memory footprint reduced by 40-60% compared to naive construction.

Return type:

pd.DataFrame

Examples

>>> from geoprior.utils.sys_utils import build_large_df
>>> import numpy as np

>>> # Basic usage with 1M rows
>>> data = [{'year': y, 'value_q50': np.random.randn()}
...         for y in range(2010, 2020) for _ in range(100000)]
>>> df = build_large_df(data, dt_col='year', tname='value')

>>> # With spatial columns
>>> geo_data = [{'date': 2023,
...              'lat': np.random.uniform(-90, 90),
...              'lon': np.random.uniform(-180, 180),
...              'pred': np.random.randn()}
...             for _ in range(500000)]
>>> df = build_large_df(geo_data, dt_col='date', tname='pred',
...                     spatial_cols=['lat', 'lon'], verbose=2)

Notes

Key implementation features include dynamic chunk adjustment using psutil.virtual_memory(), concurrent chunk reading with ThreadPoolExecutor when many chunks are present, dtype inference for temporal and spatial columns, and guaranteed tempfile cleanup via try...finally blocks.

See also

pd.DataFrame

Base DataFrame construction

pd.concat

Chunk aggregation method

geoprior.nn.utils.generate_forecast

Primary data source

geoprior.utils.memory_optimizer.reduce_mem_usage

Detailed dtype optimization

These modules help manage serialization, joblib artifacts, threading, runtime-device behavior, and larger workflow-side assembly patterns.

Scale, shape, and sequence helpers#

Utilities for computing error metrics in physical units given Stage-1 scaling metadata.

Main entry points include inverse_scale_target(...), point_metrics(...), and per_horizon_metrics(...).

These are designed to work with the NATCOM pipeline where Stage-1 stores a scaler_info dict in the manifest:

scaler_info[target_name] = {
    "scaler_path": ".../main_scaler.joblib",
    "all_features": [...],
    "idx": <index in scaler>,
    "scaler": <fitted MinMaxScaler>,  # optional, attached at Stage-2
}

They can also handle a bare scaler object, a scaler path, or manual min/max or mean/std parameters.

geoprior.utils.scale_metrics.evaluate_point_forecast(model, out, y_true_subs, *, y_true_gwl=None, n_q=3, quantiles=None, q=None, use_physical=False, return_physical=None, scaler_info=None, subs_target_name='subsidence', gwl_target_name='gwl', scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None, strict=True, output_names=None)[source]

End-to-end helper for point-forecast evaluation.

Pipeline: extract predictions from model and out, canonicalize BHQO quantile outputs when needed, pick a point prediction, optionally inverse-scale it, and compute global and per-horizon metrics.

Parameters:
  • model (Any) – Passed to extract_preds(…).

  • out (Any) – Passed to extract_preds(…).

  • y_true_subs (Any) – True subsidence, shape (B, H, 1) or (B, H).

  • y_true_gwl (Any | None) – Optional true gwl/head, same shape conventions.

  • n_q (int) – Quantile-selection control for (B, H, Q, 1) outputs; see q.

  • quantiles (Sequence[float] | None) – Quantile-selection control for (B, H, Q, 1) outputs; see q.

  • q (float | int | None) – Quantile-selection controls for (B, H, Q, 1) outputs. If q is None, the median is preferred. If q is an integer, it is treated as a direct quantile index. If q is a float and quantiles is provided, the nearest quantile is selected; otherwise the float is treated as a fraction in [0, 1].

  • use_physical (bool) – If True, compute metrics in physical units.

  • return_physical (bool | None) – If None, defaults to use_physical. If True, return physical-space arrays such as subs_pred_phys and gwl_pred_phys when possible.

  • scaler_info (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).

  • scaler_entry (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).

  • scaler (Any | None) – Passed through to inverse_scale_target(...).

  • feature_index (int | None) – Passed through to inverse_scale_target(...).

  • n_features (int | None) – Passed through to inverse_scale_target(...).

  • params (Mapping[str, float] | None) – Passed through to inverse_scale_target(...).

  • subs_target_name (str)

  • gwl_target_name (str)

  • strict (bool)

  • output_names (Sequence[str] | None)

Returns:

Dictionary containing model-space predictions, optional physical-space predictions, and global and per-horizon metrics for subsidence and groundwater outputs.

Return type:

dict
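
The quantile-selection rules listed for q can be sketched as a function returning the index along the Q axis. This is a hedged reading of the documented rules, not the library code: select_quantile_index is hypothetical, and two details are assumptions: without quantiles, the median is taken as the middle slot n_q // 2, and a bare float fraction is mapped to round(q * (n_q - 1)).

```python
from typing import Optional, Sequence, Union


def select_quantile_index(q: Optional[Union[int, float]] = None, *,
                          quantiles: Optional[Sequence[float]] = None,
                          n_q: int = 3) -> int:
    n = len(quantiles) if quantiles is not None else n_q
    if q is None:
        if quantiles is None:
            return n // 2  # assumption: middle slot holds the median
        q = 0.5  # prefer the median level when quantiles are known
    if isinstance(q, int) and not isinstance(q, bool):
        if not 0 <= q < n:
            raise IndexError(f"quantile index {q} out of range for Q={n}")
        return q  # direct index into the Q axis
    q = float(q)
    if quantiles is not None:
        # nearest quantile level to the requested value
        return min(range(n), key=lambda i: abs(quantiles[i] - q))
    if not 0.0 <= q <= 1.0:
        raise ValueError("a float q without quantiles must lie in [0, 1]")
    return int(round(q * (n - 1)))  # assumption: fraction of the Q axis
```

The point prediction would then be y[:, :, idx, :] on a (B, H, Q, 1) array.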

geoprior.utils.scale_metrics.auto_noise_std_from_increments(y_inc, *, noise_frac=0.1, percentile=95.0, min_std=0.0, max_std=None, eps=1e-12)[source]

Compute noise std as a fraction of a robust increment scale.

Parameters:
  • y_inc (np.ndarray) – Deterministic increments (before noise). Any shape; will be flattened.

  • noise_frac (float, default 0.10) – Fraction applied to the percentile scale.

  • percentile (float, default 95.0) – Percentile of the absolute increments used as the scale.

  • min_std (float, default 0.0) – Lower bound for returned std.

  • max_std (float or None, default None) – Optional upper bound for returned std.

  • eps (float, default 1e-12) – Small positive value for safe fallback.

Returns:

A finite, non-negative std value.

Return type:

float

Notes

Non-finite values are filtered first. If the percentile-based scale is approximately zero, the function falls back to the maximum absolute increment, then the mean absolute increment, and finally to eps.

geoprior.utils.scale_metrics.resolve_noise_std(y_inc, *, noise_std=None, noise_frac=0.1, percentile=95.0, min_std=0.0, max_std=None)[source]

Return the explicit noise_std when provided; otherwise, compute it from the increments with auto_noise_std_from_increments(...).

Return type:

float
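
The fallback chain in the Notes of auto_noise_std_from_increments, together with the precedence rule of resolve_noise_std, can be sketched in plain Python. Assumptions, labelled here and in the comments: the sketch takes a flat iterable (a NumPy array would be flattened first, e.g. with ravel), "approximately zero" is read as "not larger than eps", and eps stands in for the scale in the last fallback. Both function names are hypothetical.

```python
import math
from typing import Iterable, Optional


def _percentile(values, pct: float) -> float:
    # Linear interpolation between closest ranks (NumPy's default method).
    data = sorted(values)
    rank = (pct / 100.0) * (len(data) - 1)
    lo = math.floor(rank)
    hi = min(lo + 1, len(data) - 1)
    return data[lo] + (rank - lo) * (data[hi] - data[lo])


def auto_noise_std_sketch(y_inc: Iterable[float], *, noise_frac: float = 0.10,
                          percentile: float = 95.0, min_std: float = 0.0,
                          max_std: Optional[float] = None,
                          eps: float = 1e-12) -> float:
    mags = [abs(v) for v in map(float, y_inc) if math.isfinite(v)]
    scale = _percentile(mags, percentile) if mags else 0.0
    if scale <= eps and mags:
        scale = max(mags)  # fallback 1: maximum absolute increment
    if scale <= eps and mags:
        scale = sum(mags) / len(mags)  # fallback 2: mean absolute increment
    if scale <= eps:
        scale = eps  # last resort (assumption: eps replaces the scale)
    std = max(noise_frac * scale, min_std)
    if max_std is not None:
        std = min(std, max_std)
    return float(std)


def resolve_noise_std_sketch(y_inc: Iterable[float], *,
                             noise_std: Optional[float] = None, **kwargs) -> float:
    # An explicit value wins; otherwise derive it from the increments.
    if noise_std is not None:
        return float(noise_std)
    return auto_noise_std_sketch(y_inc, **kwargs)
```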

geoprior.utils.scale_metrics.scale_target(y_phys, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]

Forward-transform physical values into scaled space.

Mirrors inverse_scale_target(…) conventions.

Return type:

ndarray

geoprior.utils.scale_metrics.inverse_scale_target(y_scaled, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]

Inverse-transform a scaled target array back to physical units.

Supports three patterns:

  1. Stage-1 scaler_info dict (preferred in NATCOM): pass scaler_info=scaler_info_dict, target_name="subsidence".

  2. A bare scaler instance or a path to a joblib dump via scaler=.... If multi-feature, also pass feature_index and optionally n_features.

  3. Manual scaling parameters via params such as {"min", "max"}, {"mean", "std"}, or {"scale", "shift"}.

Parameters:
  • y_scaled (array-like) – Scaled values, e.g. of shape (N, H, 1) or (N,).

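  • scaler_info (mapping, optional) – Stage-1 scaler_info dict; used together with target_name.

  • target_name (str, optional) – Key of the target entry in scaler_info, e.g. "subsidence".

  • scaler_entry (mapping, optional)

  • scaler (object or str, optional) – Fitted scaler instance or path to a joblib dump.

  • feature_index (int, optional) – Column of the target in a multi-feature scaler.

  • n_features (int, optional) – Total number of features the scaler was fitted on.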

  • params (mapping, optional) – Manual parameters:

    • {"min", "max"} → MinMax scaling

    • {"mean", "std"} → standardization

    • {"scale", "shift"} → x = scale * x_scaled + shift

Returns:

Array with same shape as y_scaled but in physical units when scaling information is available. If nothing usable is found, returns the input as a NumPy array.

Return type:

np.ndarray
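
The manual params pattern, and the forward direction used by scale_target, reduce to simple affine maps. The sketch below illustrates those formulas on plain lists under assumptions: the helper names are hypothetical, the forward min/max map assumes the default [0, 1] feature range, and the column-wise helper relies on scikit-learn's MinMaxScaler convention x_scaled = x * scale_ + min_, read from the fitted scaler's min_ and scale_ attributes.

```python
from typing import List, Mapping, Sequence


def _has(params: Mapping[str, float], *keys: str) -> bool:
    return all(k in params for k in keys)


def inverse_from_params(y_scaled: Sequence[float],
                        params: Mapping[str, float]) -> List[float]:
    if _has(params, "min", "max"):  # MinMax scaling
        lo, hi = params["min"], params["max"]
        return [v * (hi - lo) + lo for v in y_scaled]
    if _has(params, "mean", "std"):  # standardization
        return [v * params["std"] + params["mean"] for v in y_scaled]
    if _has(params, "scale", "shift"):  # x = scale * x_scaled + shift
        return [params["scale"] * v + params["shift"] for v in y_scaled]
    return list(y_scaled)  # nothing usable: pass values through


def forward_from_params(y_phys: Sequence[float],
                        params: Mapping[str, float]) -> List[float]:
    if _has(params, "min", "max"):
        lo, hi = params["min"], params["max"]
        return [(v - lo) / (hi - lo) for v in y_phys]
    if _has(params, "mean", "std"):
        return [(v - params["mean"]) / params["std"] for v in y_phys]
    if _has(params, "scale", "shift"):
        return [(v - params["shift"]) / params["scale"] for v in y_phys]
    return list(y_phys)


def inverse_minmax_column(y_scaled: Sequence[float], min_: Sequence[float],
                          scale_: Sequence[float], feature_index: int) -> List[float]:
    # scikit-learn MinMaxScaler applies x_scaled = x * scale_[i] + min_[i]
    return [(v - min_[feature_index]) / scale_[feature_index] for v in y_scaled]
```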

geoprior.utils.scale_metrics.point_metrics(y_true, y_pred, *, use_physical=False, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]

Compute MAE/MSE/R², optionally in physical units.

If use_physical=True, both y_true and y_pred are inverse-transformed via inverse_scale_target before computing metrics.

Parameters:
  • y_true (array-like) – Arbitrary shapes (e.g. (N, H, 1)); flattened internally.

  • y_pred (array-like) – Arbitrary shapes (e.g. (N, H, 1)); flattened internally.

  • use_physical (bool, default False) – If True, inverse-transform y_true and y_pred to physical units before computing metrics.

  • scaler_info (optional) – Passed through to inverse_scale_target.

  • target_name (optional) – Passed through to inverse_scale_target.

  • scaler_entry (optional) – Passed through to inverse_scale_target.

  • scaler (optional) – Passed through to inverse_scale_target.

  • feature_index (optional) – Passed through to inverse_scale_target.

  • n_features (optional) – Passed through to inverse_scale_target.

  • params (optional) – Passed through to inverse_scale_target.

Returns:

{"mae": …, "mse": …, "r2": …}

Return type:

dict
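
For reference, the three metrics on already-flattened values follow the usual definitions: MAE is the mean absolute error, MSE the mean squared error, and R² = 1 − SS_res / SS_tot. The sketch below is a hypothetical plain-Python version for checking numbers by hand; it returns NaN for R² when y_true is constant, which is a choice of this sketch rather than documented behaviour.

```python
import math
from typing import Dict, Sequence


def point_metrics_sketch(y_true: Sequence[float],
                         y_pred: Sequence[float]) -> Dict[str, float]:
    n = len(y_true)
    if n == 0 or n != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equally long")
    errors = [t - p for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)  # residual sum of squares
    mean_t = sum(y_true) / n
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    return {
        "mae": sum(abs(e) for e in errors) / n,
        "mse": ss_res / n,
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else math.nan,
    }
```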

geoprior.utils.scale_metrics.per_horizon_metrics(y_true, y_pred, *, use_physical=False, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]

Compute MAE and R² separately for each forecast horizon, optionally in physical units.

Parameters:
  • y_true (array-like) – Shape (N, H, 1) or (N, H).

  • y_pred (array-like) – Shape (N, H, 1) or (N, H).

  • use_physical (bool, default False) – If True, inverse-transform to physical units before computing.

  • scaler_info (Mapping[str, Any] | None)

  • target_name (str | None)

  • scaler_entry (Mapping[str, Any] | None)

  • scaler (Any | None)

  • feature_index (int | None)

  • n_features (int | None)

  • params (Mapping[str, float] | None)

Returns:

  • mae_dict (dict) – {‘H1’: mae_h1, …}

  • r2_dict (dict) – {‘H1’: r2_h1, …}

Return type:

tuple[dict, dict]
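
Per-horizon metrics apply the same formulas column by column along the horizon axis and key the results as 'H1', 'H2', and so on. A self-contained sketch for (N, H) nested lists follows; the name is hypothetical, and an (N, H, 1) array would first have its trailing axis squeezed.

```python
import math
from typing import Dict, Sequence, Tuple


def per_horizon_sketch(y_true: Sequence[Sequence[float]],
                       y_pred: Sequence[Sequence[float]]
                       ) -> Tuple[Dict[str, float], Dict[str, float]]:
    mae_dict: Dict[str, float] = {}
    r2_dict: Dict[str, float] = {}
    for h in range(len(y_true[0])):
        t = [row[h] for row in y_true]  # all samples at horizon h
        p = [row[h] for row in y_pred]
        errors = [a - b for a, b in zip(t, p)]
        mean_t = sum(t) / len(t)
        ss_tot = sum((a - mean_t) ** 2 for a in t)
        ss_res = sum(e * e for e in errors)
        key = f"H{h + 1}"
        mae_dict[key] = sum(abs(e) for e in errors) / len(t)
        r2_dict[key] = 1.0 - ss_res / ss_tot if ss_tot > 0 else math.nan
    return mae_dict, r2_dict
```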

Sequence-building helpers for temporal model inputs.

geoprior.utils.sequence_utils.check_sequence_feasibility(df, *, time_col, group_id_cols=None, time_steps=12, forecast_horizon=3, engine='vectorized', mode=None, logger=<built-in function print>, verbose=0, error='warn')[source]

Quick pre-flight feasibility check for sliding-window sequence generation.

Checks whether the input table is long enough, per group, to yield at least one (look-back + horizon) sliding window, without allocating large NumPy tensors. It is typically called immediately before prepare_pinn_data_sequences() or similar generators to “fail fast” on data shortages.

Parameters:
  • df (pandas.DataFrame) – Tidy time-series table in long format. Every row represents one observation timestamp (and optionally one entity when group_id_cols is given). The function never mutates df.

  • time_col (str) – Column that defines temporal order inside each trajectory. Must be sortable; no other assumptions (numeric, datetime, …) are made.

  • group_id_cols (list of str or None, default None) – Column names that jointly identify independent trajectories (e.g. ["well_id"] or ["site", "layer_id"]). When None the whole DataFrame is treated as a single group.

  • time_steps (int, default 12) – Look-back window \(T_\text{past}\) consumed by the encoder.

  • forecast_horizon (int, default 3) – Prediction horizon \(H\) produced by the decoder.

  • engine ({'vectorized', 'native', 'pyarrow'}, default 'vectorized') –

    • ‘vectorized’ – fastest; single DataFrame.groupby.size() call (C-level) plus NumPy math.

    • ‘native’ – reproduces the original Python loop for debuggability.

    • ‘pyarrow’ – forces pandas’ Arrow backend, then runs the same vectorised logic; roughly 20% faster on very wide frames when pyarrow ≥ 14 is installed.

  • mode ({'pihal_like', 'tft_like'} or None, optional) – Present only for API symmetry. Ignored – feasibility depends solely on time_steps + forecast_horizon.

  • logger (callable, default print) – Sink for human-readable log messages. Must accept a single str.

  • verbose (int, default 0) – Verbosity level: 0 → silent, 1 → summary lines, 2 → per-group detail.

  • error ({'raise', 'warn', 'ignore'}, default 'warn') –

    Action when no group is long enough.

    • 'raise' – raise SequenceGeneratorError.

    • 'warn' – emit UserWarning, return False.

    • 'ignore' – stay silent, return False.

Returns:

  • feasible (bool) – True iff at least one sequence can be produced, otherwise False.

  • counts (dict) – Mapping of group key → number of sequences. The key is a tuple of the group values, or None when group_id_cols is None.

Raises:

SequenceGeneratorError – Raised only when error='raise' and all groups fail the length check.

Return type:

tuple[bool, dict[str | tuple, int]]

Notes

A group passes the check iff

\[\text{len(group)} \;\ge\; T_\text{past} + H\]

No validation of time-gaps, duplicates, or NaNs is performed; those are deferred to the full preparation routine.

The Arrow backend (engine='pyarrow') can accelerate very wide frames because each column is represented as a contiguous Arrow array with cheap zero-copy slicing.
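
The length check in the Notes, and the per-group counts it feeds, can be sketched without pandas. Assumption: with stride-1 rolling windows, a group of n rows yields max(0, n − (T_past + H) + 1) windows, which is the usual count for sliding windows but is not stated explicitly above; sequence_counts_sketch is a hypothetical name, and collections.Counter plays the role of groupby(...).size().

```python
from collections import Counter
from typing import Dict, Hashable, Iterable, Tuple


def sequence_counts_sketch(group_keys: Iterable[Hashable], time_steps: int,
                           forecast_horizon: int) -> Tuple[bool, Dict[Hashable, int]]:
    min_len = time_steps + forecast_horizon  # T_past + H
    sizes = Counter(group_keys)  # rows per group, like groupby(...).size()
    counts = {key: max(0, n - min_len + 1) for key, n in sizes.items()}
    return any(c > 0 for c in counts.values()), counts
```

For instance, groups of 17 rows with time_steps=6 and forecast_horizon=3 give 17 − 9 + 1 = 9 windows each.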

Examples

  • Minimal usage

>>> from geoprior.utils.sequence_utils import check_sequence_feasibility
>>> ok, counts = check_sequence_feasibility(
...     df,
...     time_col="date",
...     group_id_cols=["site"],
...     time_steps=6,
...     forecast_horizon=3,
... )
>>> ok
True
>>> counts
{'A': 9, 'B': 9}

  • Fail-fast behaviour

>>> check_sequence_feasibility(
...     df_small,
...     time_col="t",
...     time_steps=10,
...     forecast_horizon=5,
...     error="raise",
... )
Traceback (most recent call last):
...
SequenceGeneratorError: No group is long enough ...

  • Switching engines

>>> _ , _ = check_sequence_feasibility(
...     df,
...     time_col="ts",
...     group_id_cols=None,
...     engine="pyarrow",   # requires pandas 2.1+, pyarrow installed
...     verbose=1,
... )
✅ Feasible: 1 234 567 sequences possible.

References

  • McKinney, W. pandas 2.0 User Guide, sec. “GroupBy: split-apply-combine”.

  • Arrow Project. (2025). Arrow Columnar Memory Format v2.

geoprior.utils.sequence_utils.get_sequence_counts(df, *, group_id_cols, min_len, engine='vectorized', verbose=0, logger=<built-in function print>)[source]

Return the total number of feasible sliding-window sequences and a mapping group → count using the requested execution engine.

Parameters:
  • engine ({'vectorized', 'native', 'pyarrow'}, default 'vectorized') –

    Execution backend.

    • ‘vectorized’ – fast C-level DataFrame.groupby.size() (recommended).

    • ‘native’ – original Python loop (easier to debug, slower).

    • ‘pyarrow’ – forces pandas’ Arrow backend if available, then runs the vectorised path. Falls back silently to 'vectorized' when pyarrow is not installed.

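  • df (DataFrame) – Long-format table, as in check_sequence_feasibility.

  • group_id_cols (list[str] | None) – Columns identifying independent trajectories; None treats the whole frame as one group.

  • min_len (int) – Minimum number of rows a group needs to yield one window (for example time_steps + forecast_horizon).

  • verbose (int) – Verbosity level.

  • logger (callable, default print) – Sink for log messages.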

Return type:

tuple[int, dict[str | tuple, int], Series]

geoprior.utils.sequence_utils.generate_pinn_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, mode='pihal_like', normalize_coords=True, cols_to_scale=None, method='rolling', stride=1, random_samples=None, expand_step=1, n_bootstrap=0, progress_hook=None, stop_check=None, verbose=1, _logger=None, **kwargs)[source]

Generate input/target arrays for PINN models using various sampling methods (rolling, strided, random, expanding, bootstrap).

Parameters:
  • df (pd.DataFrame) – Full time-series data.

  • time_col (str) – Name of the time coordinate column.

  • subsidence_col (str) – Name of the subsidence target column.

  • gwl_col (str) – Name of the groundwater level target column.

  • dynamic_cols (list[str]) – Names of past-covariate columns.

  • static_cols (list[str], optional) – Names of static feature columns.

  • future_cols (list[str], optional) – Names of known-future feature columns.

  • spatial_cols ((str, str), optional) – Tuple of (lon_col, lat_col) for spatial coords.

  • group_id_cols (list[str], optional) – Column(s) identifying independent time-series groups.

  • time_steps (int, default 12) – Look-back window length T.

  • forecast_horizon (int, default 3) – Prediction horizon H.

  • output_subsidence_dim (int, default 1) – Last-dim of subsidence target.

  • output_gwl_dim (int, default 1) – Last-dim of GWL target.

  • mode ({'pihal_like','tft_like'}, default 'pihal_like') – Shapes the “future” window length for TFT vs. PIHALNet.

  • normalize_coords (bool, default True) – Apply MinMax scaling to (t,x,y) across all sequences.

  • cols_to_scale (list[str] or 'auto' or None) – Additional columns to scale via MinMax.

  • method ({'rolling','strided','random','expanding','bootstrap'}) – Sequence-generation strategy.

  • stride (int, default 1) – Step size for ‘strided’ sampling.

  • random_samples (int, optional) – Number of random start indices for ‘random’ sampling.

  • expand_step (int, default 1) – Increment size for ‘expanding’ sampling.

  • n_bootstrap (int, default 0) – Number of blocks for ‘bootstrap’ sampling.

  • progress_hook (callable, optional) – Called with float in [0,1] to report overall progress.

  • stop_check (callable, optional) – If returns True, aborts sequence generation early.

  • verbose (int, default 1) – Verbosity level (higher = more logs).

  • _logger (logging.Logger or callable, optional) – Logger or print-style function for vlog().

  • **kwargs – Additional keyword arguments forwarded to the internal helper.

Returns:

  • inputs (dict[str, np.ndarray]) – Contains ‘coords’, ‘dynamic_features’, optionally ‘static_features’ and ‘future_features’.

  • targets (dict[str, np.ndarray]) – Contains ‘subsidence’ and ‘gwl’ arrays.

  • coord_scaler (MinMaxScaler or None) – Fitted scaler for coords, if normalization was applied.

Return type:

tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
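
The first three sampling strategies differ only in which start indices they keep. A minimal sketch for a single group of n_rows rows follows, assuming each window spans time_steps + forecast_horizon consecutive rows; window_starts and its seed argument are hypothetical, and the 'expanding' and 'bootstrap' strategies are omitted because their exact window definitions are not spelled out above.

```python
import random
from typing import List, Optional


def window_starts(n_rows: int, time_steps: int, forecast_horizon: int,
                  method: str = "rolling", stride: int = 1,
                  random_samples: Optional[int] = None,
                  seed: Optional[int] = None) -> List[int]:
    last = n_rows - (time_steps + forecast_horizon)  # last valid start index
    if last < 0:
        return []  # group too short for a single window
    candidates = range(last + 1)
    if method == "rolling":
        return list(candidates)  # every start, stride 1
    if method == "strided":
        return list(range(0, last + 1, stride))  # every ``stride``-th start
    if method == "random":
        k = min(random_samples or len(candidates), len(candidates))
        return sorted(random.Random(seed).sample(candidates, k))  # unique starts
    raise ValueError(f"method {method!r} is not covered by this sketch")
```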

geoprior.utils.sequence_utils.generate_ts_sequences(df, time_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, group_id_cols=None, time_steps=12, forecast_horizon=1, normalize_coords=True, cols_to_scale=None, method='rolling', stride=1, random_samples=None, expand_step=1, n_bootstrap=0, progress_hook=None, stop_check=None, verbose=1, _logger=None, **kwargs)[source]

Generate time-series windows for encoder/decoder and covariates. Supports rolling, strided, random, expanding, and bootstrap.

Parameters:
  • df (pd.DataFrame) – Input frame with time and feature columns.

  • time_col (str) – Name of the time coordinate column.

  • dynamic_cols (list[str]) – Past-covariate columns for encoder inputs.

  • static_cols (list[str] or None) – Static covariate columns, repeated per window.

  • future_cols (list[str] or None) – Known-future covariates for decoder inputs.

  • spatial_cols (tuple(str,str) or None) – (lon, lat) column names for spatial coords.

  • group_id_cols (list[str] or None) – Columns to group by for independent series.

  • time_steps (int) – Number of past steps (T) per window.

  • forecast_horizon (int) – Number of future steps (H) per window.

  • normalize_coords (bool) – If True, MinMax-scale spatial coords.

  • cols_to_scale (list[str] or 'auto' or None) – Other columns to MinMax-scale.

  • method (str) – One of 'rolling', 'strided', 'random', 'expanding', 'bootstrap'.

  • stride (int) – Step size for ‘strided’ windows.

  • random_samples (int or None) – Number of random windows if method='random'.

  • expand_step (int) – Increment for 'expanding' windows.

  • n_bootstrap (int) – Number of bootstrap samples if method='bootstrap'.

  • progress_hook (callable or None) – Receives a float in [0, 1] as work progresses.

  • stop_check (callable or None) – If it returns True, generation is aborted.

  • verbose (int) – Verbosity level. >0 logs progress.

  • _logger (callable or None) – Logger to use for messages.

Returns:

  • inputs (dict of np.ndarray) – Keys 'encoder_inputs', 'static', 'future', 'coords'.

  • targets (dict of np.ndarray) – ‘decoder_targets’.

  • coord_scaler (MinMaxScaler or None) – Fitted scaler for coords, if normalized.

Raises:

SequenceGeneratorError – If no valid windows could be generated.

Return type:

tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
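To make the window-based strategies concrete, the sketch below computes window bounds for a single series of length n with T past steps and H horizon steps. It is an illustrative assumption about how 'rolling', 'strided' and 'expanding' windows could be laid out; the helper `window_bounds` is hypothetical, 'random' and 'bootstrap' are omitted, and the library's actual indexing may differ.

```python
from typing import List, Tuple


def window_bounds(
    n: int, T: int, H: int, method: str = "rolling",
    stride: int = 1, expand_step: int = 1,
) -> List[Tuple[int, int]]:
    """Hypothetical helper returning (start, split) pairs: history is
    series[start:split] and the target horizon is series[split:split + H]."""
    if method == "rolling":
        step = 1
    elif method == "strided":
        step = stride
    elif method == "expanding":
        # History always starts at 0 and grows by expand_step.
        return [(0, split) for split in range(T, n - H + 1, expand_step)]
    else:
        raise ValueError(f"unsupported method in this sketch: {method!r}")
    # Fixed-length windows of T past steps, moved forward by `step`.
    return [(s, s + T) for s in range(0, n - T - H + 1, step)]


print(window_bounds(8, T=3, H=2))                                # rolling
print(window_bounds(8, T=3, H=2, method="strided", stride=2))   # strided
print(window_bounds(8, T=3, H=2, method="expanding"))           # expanding
```

With n=8, T=3 and H=2, rolling gives four windows, striding by 2 keeps every other one, and expanding keeps the start fixed at 0 while the split point moves forward.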

geoprior.utils.sequence_utils.build_future_sequences_npz(df_scaled, *, time_col, time_col_num, lon_col, lat_col, time_steps, train_end_time=None, forecast_start_time=None, forecast_horizon=None, subs_col=None, gwl_col=None, h_field_col=None, static_features=None, dynamic_features=None, future_features=None, group_id_cols=None, mode=None, model_name=None, artifacts_dir=None, prefix='future', future_mode='auto', normalize_coords=False, coord_scaler=None, verbose=1, logger=None, stop_check=None, progress_hook=None, **kws)[source]

Build history–future sequences and save them as compressed NPZ files.

This helper constructs, for each spatial group, a sliding window of time_steps “history” points followed by a multi-step forecast horizon, and exports the resulting NumPy arrays to disk. It is time-agnostic: time_col can be numeric (e.g. year, index), year-like floats, datetimes, or strings, as long as equality on that column is meaningful.

If train_end_time, forecast_start_time, or forecast_horizon are not provided, they are inferred from the sorted unique values in df_scaled[time_col]:

  • train_end_time: by default the second-to-last unique time, leaving at least one future step.

  • forecast_start_time: by default the first time strictly after train_end_time.

  • forecast_horizon: by default one time step ahead, clipped to the number of available future points.
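The default temporal inference above can be sketched with plain pandas. This is a simplified illustration, not the library's code: the helper `infer_times` is hypothetical and assumes the time column has at least two distinct values.

```python
import pandas as pd


def infer_times(df: pd.DataFrame, time_col: str, forecast_horizon=None):
    """Hypothetical sketch of the documented default inference rules."""
    # Sorted unique times as plain Python scalars.
    times = sorted(df[time_col].drop_duplicates().tolist())
    if len(times) < 2:
        raise ValueError("need at least two distinct times")
    # Second-to-last unique time, leaving at least one future step.
    train_end_time = times[-2]
    # First time strictly after train_end_time.
    future = [t for t in times if t > train_end_time]
    forecast_start_time = future[0]
    # Default horizon of 1, clipped to the available future points.
    horizon = 1 if forecast_horizon is None else forecast_horizon
    horizon = min(horizon, len(future))
    return train_end_time, forecast_start_time, horizon


df = pd.DataFrame({"year": [2019, 2020, 2021, 2022, 2020, 2021]})
print(infer_times(df, "year"))                      # (2021, 2022, 1)
print(infer_times(df, "year", forecast_horizon=3))  # (2021, 2022, 1)
```

In the second call the requested horizon of 3 is clipped to 1 because only one future time (2022) exists.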

For each valid group, with H the forecast horizon, the function builds:

  • history dynamic features of shape (time_steps, n_dynamic);

  • future features of shape (time_steps + H, n_future) when mode starts with "tft", or (H, n_future) otherwise;

  • one static feature vector of shape (n_static,);

  • coordinates over the horizon of shape (H, 3), with columns [time_num, lon, lat];

  • an H_field array of shape (H, 1);

  • optional subsidence and groundwater targets, each of shape (H, 1).

All per-group arrays are stacked along a new batch dimension and written as two NPZ files:

  • <prefix>_inputs.npz: coordinates, dynamic, static, future features and H field.

  • <prefix>_targets.npz: subsidence and groundwater targets.
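The shapes and the stacking step above can be illustrated with dummy arrays. The sketch below builds zero-filled per-group arrays for a "tft"-style mode, stacks them along a new batch axis, and round-trips them through NumPy's compressed NPZ format. The key names (`dynamic_features`, `coords`, and so on), the feature counts and the file name are hypothetical; on real output, inspect `np.load(path).files` to see the actual keys.

```python
import os
import tempfile

import numpy as np

T, H = 5, 3                        # history length, forecast horizon
n_dyn, n_fut, n_static = 2, 1, 4   # feature counts (illustrative)
n_groups = 6


def one_group(tft_like: bool) -> dict:
    """Zero-filled arrays with the documented per-group shapes."""
    fut_len = T + H if tft_like else H
    return {
        "dynamic_features": np.zeros((T, n_dyn)),
        "future_features": np.zeros((fut_len, n_fut)),
        "static_features": np.zeros((n_static,)),
        "coords": np.zeros((H, 3)),    # columns [time_num, lon, lat]
        "H_field": np.zeros((H, 1)),
    }


groups = [one_group(tft_like=True) for _ in range(n_groups)]
# Stack every key along a new leading batch dimension.
inputs = {k: np.stack([g[k] for g in groups]) for k in groups[0]}

path = os.path.join(tempfile.mkdtemp(), "future_inputs.npz")
np.savez_compressed(path, **inputs)
with np.load(path) as npz:
    shapes = {key: npz[key].shape for key in npz.files}
print(shapes)
```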

Parameters:
  • df_scaled (pandas.DataFrame) – Pre-processed (typically scaled) dataframe containing all required columns: time, spatial coordinates, static/dynamic/future features and optional targets.

  • time_col (str) – Name of the column encoding the temporal index (e.g. "year", "date", "t_index"). May be numeric, datetime, or string.

  • time_col_num (str or None) – Optional numeric time column used as a tie-breaker when multiple rows share the same time_col value. If provided and present in a group, the last row sorted by this column is selected for that time.

  • lon_col (str) – Name of the longitude (or x-coordinate) column.

  • lat_col (str) – Name of the latitude (or y-coordinate) column.

  • time_steps (int) – Length of the history window (number of time steps in the past). Must be strictly positive.

  • train_end_time (object, optional) – Effective end of the training period. If None, it is inferred as the second-to-last unique value in df_scaled[time_col] (after sorting).

  • forecast_start_time (object, optional) – First time step of the forecast horizon. If None, it is inferred as the first unique time strictly greater than train_end_time.

  • forecast_horizon (int, optional) – Number of future time steps to include. If None, a default horizon of 1 is used and clipped to the maximum number of available future time points.

  • subs_col (str, optional) – Name of the subsidence target column. If None or missing from a group, subsidence targets are filled with NaN.

  • gwl_col (str, optional) – Name of the groundwater-level target column. If None or missing from a group, groundwater targets are filled with NaN.

  • h_field_col (str, optional) – Name of the hydraulic-head field column used as an additional horizon-level input (H_field). If None or missing, a zero field is used.

  • static_features (list of str, optional) – Names of static (time-invariant) feature columns. Any names not present in the dataframe are silently ignored.

  • dynamic_features (list of str, optional) – Names of dynamic (history) feature columns used to build the (time_steps, n_dynamic) sequence. Missing columns are ignored.

  • future_features (list of str, optional) – Names of future covariate columns used to build the history+future or future-only sequence, depending on mode. Missing columns are ignored.

  • group_id_cols (list of str, optional) – Columns used to define spatial (or logical) groups, typically something like ["lon", "lat"] or a station identifier. If None or empty, the entire dataframe is treated as a single global group.

  • mode (str, optional) – Controls how future features are constructed. If the lower-cased value starts with "tft" (e.g. "tft_like"), future features are built on top of both history and future rows. Otherwise, only the forecast horizon rows are used.

  • model_name (str, optional) – Optional model identifier used only in logging messages.

  • artifacts_dir (str, optional) – Directory where NPZ files are written. If None or empty, the current working directory is used.

  • prefix (str, default "future") – Prefix for the output NPZ filenames: "<prefix>_inputs.npz" and "<prefix>_targets.npz".

  • future_mode ({'auto', 'pure-inference', 'pure-data-driven'}, default 'auto') –

    Strategy used to construct the future (forecast) portion of the sequences.

    • 'pure-data-driven': Use only time points that actually exist in df_scaled strictly after the history window. All future time indices must be present in the data; otherwise a ValueError is raised. This corresponds to the original, strictly data-driven behaviour.

    • 'pure-inference': Always synthesize future time points from the last history time, using the median positive time step (or 1.0 as a fallback). Future inputs are built by re-using the last available history row (for future_features, H_field, etc.), and future targets (e.g. subsidence, GWL) are filled with NaN since the true future is unknown. This mode does not require any rows beyond train_end_time.

    • 'auto': Try data-driven mode first. If there are enough actual future time points after train_end_time to cover the requested forecast_horizon, behave like 'pure-data-driven'. If not, automatically fall back to the synthetic 'pure-inference' behaviour described above and emit an informational log message via vlog.

  • verbose (int, default 1) – Verbosity level forwarded to geoprior.utils.vlog(). A value >= 3 provides detailed progress logs (temporal inference, per-group status, dropped groups, etc.).

  • logger (logging.Logger or callable, optional) – Optional logger or logging function used by geoprior.utils.vlog(). If None, messages are printed to standard output.

  • **kws – Reserved for future extensions. Currently ignored.

  • normalize_coords (bool)

  • coord_scaler (Any | None)

  • stop_check (Callable[[], bool])

  • progress_hook (Callable[[float], None] | None)

Returns:

A small dictionary with the absolute paths to the written NPZ files:

{"future_inputs_npz": <path>, "future_targets_npz": <path>}.

Return type:

dict

Raises:

ValueError – If there are not enough history points before train_end_time to satisfy time_steps, if no future points are available after forecast_start_time, or if all groups are dropped due to incomplete history/horizon windows.

Notes

Groups that do not contain all required history and future times are dropped without raising an error (unless every group is dropped); the number of dropped groups is reported via geoprior.utils.vlog() when verbose > 0.
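In 'pure-inference' mode, future time points are synthesized from the last history time using the median positive time step, with 1.0 as a fallback. A minimal sketch of that rule, assuming numeric times (the helper name `synth_future_times` is hypothetical and not part of the GeoPrior API):

```python
import numpy as np


def synth_future_times(history_times, horizon: int) -> np.ndarray:
    """Hypothetical sketch: extend numeric history times by `horizon` steps."""
    t = np.asarray(history_times, dtype=float)
    diffs = np.diff(t)
    positive = diffs[diffs > 0]
    # Median positive step; fall back to 1.0 when no positive step exists.
    step = float(np.median(positive)) if positive.size else 1.0
    return t[-1] + step * np.arange(1, horizon + 1)


print(synth_future_times([2015, 2016, 2017, 2019], horizon=3))  # [2020. 2021. 2022.]
print(synth_future_times([2020.0], horizon=2))                  # [2021. 2022.]
```

Using the median rather than the mean keeps a single irregular gap (2017 to 2019 above) from distorting the synthetic step.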

Examples

>>> from geoprior.utils.sequence_utils import (
...     build_future_sequences_npz,
... )
>>> result = build_future_sequences_npz(
...     df_scaled=df_scaled,
...     time_col="year",
...     time_col_num="t_index",
...     lon_col="lon",
...     lat_col="lat",
...     time_steps=5,
...     # Let the function infer times/horizon:
...     train_end_time=None,
...     forecast_start_time=None,
...     forecast_horizon=None,
...     subs_col="subsidence",
...     gwl_col="gwl",
...     h_field_col="H_field",
...     static_features=["lithology_class"],
...     dynamic_features=["rainfall_mm", "GWL_depth_bgs_z"],
...     future_features=["normalized_urban_load_proxy"],
...     group_id_cols=["lon", "lat"],
...     mode="tft_like",
...     model_name="GeoPriorSubsNet",
...     artifacts_dir="results/zhongshan/future_npz",
...     prefix="zhongshan_future",
...     verbose=2,
... )
>>> result["future_inputs_npz"]
'results/zhongshan/future_npz/zhongshan_future_inputs.npz'
>>> result["future_targets_npz"]
'results/zhongshan/future_npz/zhongshan_future_targets.npz'

Shape utility helpers for arrays and tensors.

geoprior.utils.shapes.canonicalize_BHQO(y_pred, *, y_true=None, q_values=(0.1, 0.5, 0.9), n_q=None, layout=None, enforce_monotone=True, return_layout=False, verbose=0, log_fn=<built-in function print>)[source]

Canonicalize quantile outputs to (B, H, Q, O).

Supported layouts (rank-4):
  • BHQO: (B, H, Q, O) -> unchanged

  • BQHO: (B, Q, H, O) -> transpose(0, 2, 1, 3)

  • BHOQ: (B, H, O, Q) -> transpose(0, 1, 3, 2)

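If the layout is ambiguous (e.g., H == Q) and y_true is given, the transform whose median-quantile (q50) slice has the smallest MAE against y_true is chosen.

If y_true is not given, the fallback order is:
  1. use layout if it is provided;

  2. otherwise, prefer BHQO if it is plausible;

  3. otherwise, pick the layout with the minimum quantile-crossing score.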

Parameters:
  • y_pred (Any) – Quantile tensor, NumPy array or TF tensor.

  • y_true (Any | None) – Target tensor (B, H, O) or (B, H, 1). Used only to resolve ambiguity robustly.

  • q_values (Sequence[float]) – Quantiles in order, e.g. (0.1, 0.5, 0.9).

  • n_q (int | None) – Number of quantiles. Defaults to len(q_values).

  • layout (str | None) – Force interpretation: “BHQO”, “BQHO”, “BHOQ”. Use “auto” (or None) to infer.

  • enforce_monotone (bool) – Sort along Q axis after canonicalization.

  • return_layout (bool) – If True, return (arr, chosen_layout).

  • verbose (int) – Verbosity level for diagnostic messages.

  • log_fn (Callable[[str], None]) – Function used to emit log messages (print by default).

Returns:

The array in canonical (B, H, Q, O) layout, plus the chosen layout name when return_layout is True.

Return type:

arr or (arr, layout)
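The supported layouts differ only in axis order, so each maps to BHQO with a single transpose. The sketch below checks those permutations on a small NumPy array; it illustrates the documented transforms rather than calling canonicalize_BHQO itself.

```python
import numpy as np

B, H, Q, O = 2, 4, 3, 1
rng = np.random.default_rng(0)
bhqo = rng.normal(size=(B, H, Q, O))

# Build the two alternative layouts from the canonical one ...
bqho = bhqo.transpose(0, 2, 1, 3)   # (B, Q, H, O)
bhoq = bhqo.transpose(0, 1, 3, 2)   # (B, H, O, Q)

# ... and map them back with the transposes listed above.
assert np.array_equal(bqho.transpose(0, 2, 1, 3), bhqo)
assert np.array_equal(bhoq.transpose(0, 1, 3, 2), bhqo)
print(bqho.shape, bhoq.shape)       # (2, 3, 4, 1) (2, 4, 1, 3)
```

Both permutations swap two axes and are therefore their own inverses. When H == Q, BHQO and BQHO have identical shapes, which is exactly the ambiguity that y_true or the crossing score is used to resolve.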

geoprior.utils.shapes.canonicalize_BHQO_quantiles_np(y, n_q=3, *, verbose=0, log_fn=<built-in function print>)[source]

Return y in canonical (B,H,Q,O).

Accepts common layouts:
  • (B,H,Q,O) -> unchanged

  • (B,Q,H,O) -> transpose(0,2,1,3)

  • (B,H,O,Q) -> transpose(0,1,3,2)

If ambiguous (multiple axes match n_q), choose the transform with minimal quantile crossing score.

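Parameters:
  • y – Quantile array to canonicalize.

  • n_q (int, default 3) – Number of quantiles, used to identify the Q axis.

  • verbose (int, default 0) – Verbosity level for diagnostic messages.

  • log_fn (Callable[[str], None], default print) – Function used to emit log messages.
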
Return type:

Any
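When shapes alone cannot identify the Q axis, a quantile-crossing score can rank candidate layouts. The exact score GeoPrior uses is not documented here; the sketch below assumes one plausible definition (the fraction of adjacent quantile pairs that decrease along the Q axis), and the helper `crossing_score` is hypothetical. It also shows the np.sort repair that enforce_monotone in canonicalize_BHQO describes.

```python
import numpy as np


def crossing_score(y_bhqo: np.ndarray) -> float:
    """Hypothetical score: share of adjacent quantile pairs that cross
    (q_{k+1} < q_k) along axis 2 of a (B, H, Q, O) array."""
    d = np.diff(y_bhqo, axis=2)
    return float(np.mean(d < 0))


h = np.arange(3).reshape(1, 3, 1, 1)
q = np.arange(3).reshape(1, 1, 3, 1)
good = 10.0 * (2 - h) + q              # (1, 3, 3, 1): rises along Q, falls along H
stored = good.transpose(0, 2, 1, 3)    # same data stored as BQHO; H == Q == 3

candidates = {
    "BHQO": stored,                        # read as-is
    "BQHO": stored.transpose(0, 2, 1, 3),  # swap Q and H back
}
scores = {name: crossing_score(arr) for name, arr in candidates.items()}
best = min(scores, key=scores.get)
print(scores, best)   # {'BHQO': 1.0, 'BQHO': 0.0} BQHO

# enforce_monotone-style repair: sort along the Q axis.
crossed = np.array([[[[3.0], [1.0], [2.0]]]])   # (1, 1, 3, 1), crossing quantiles
print(crossing_score(crossed), crossing_score(np.sort(crossed, axis=2)))  # 0.5 0.0
```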

Target-processing helpers for GeoPrior workflows.

geoprior.utils.target_utils.get_output_names(model=None, y=None, y_pred=None, *, exclude_keys={'aux', 'data_final', 'data_mean_raw', 'maps', 'phys_final', 'phys_mean_raw', 'physics'})[source]

Try to obtain stable output names (best-effort, Keras-3-safe).

Lookup priority is: model._output_keys or model._output_names first, then model.output_names, then keys from y_pred, and finally keys from y.
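The lookup order reads as a chain of fallbacks. The sketch below mirrors that documented priority with getattr; it is a simplified stand-in (the name `first_output_names` is hypothetical), it omits the exclude_keys filtering and any Keras-specific handling of the real function, and the output names in the usage lines are illustrative placeholders.

```python
from types import SimpleNamespace


def first_output_names(model=None, y=None, y_pred=None):
    """Hypothetical fallback chain following the documented priority."""
    # 1-2. Private attributes first, then the public output_names.
    for attr in ("_output_keys", "_output_names", "output_names"):
        names = getattr(model, attr, None)   # safe even when model is None
        if names:
            return list(names)
    # 3-4. Keys of the prediction dict, then of the target dict.
    for container in (y_pred, y):
        if isinstance(container, dict) and container:
            return list(container)
    return None


model = SimpleNamespace(output_names=["subs_pred", "gwl_pred"])
print(first_output_names(model=model))                  # ['subs_pred', 'gwl_pred']
print(first_output_names(y_pred={"a": 1}, y={"b": 2}))  # ['a']
print(first_output_names(y={"b": 2}))                   # ['b']
```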

geoprior.utils.target_utils.as_tuple(obj, *, names=None, model=None, ctx='value', strict=True, exclude_keys={'aux', 'data_final', 'data_mean_raw', 'maps', 'phys_final', 'phys_mean_raw', 'physics'}, to_numpy='never')[source]

Convert obj to an ordered tuple of outputs.

Supports:
  • dict: ordered by names (or inferred)

  • list/tuple: returns tuple(obj)

  • tensor/ndarray/scalar: returns (obj,)

Parameters:
  • obj (Any) – Targets/predictions container.

  • names (list[str] or None) – Desired order of outputs.

  • model (Any or None) – Model used to infer output names (via _output_keys/output_names).

  • strict (bool) – If True, a key missing from a dict input raises KeyError.

  • to_numpy ({"never","auto","always"}) – Optional conversion of individual leaves to NumPy (only safe in eager mode).

Returns:

Ordered outputs.

Return type:

tuple
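The container rules listed for as_tuple can be sketched as follows. This simplified version (`to_ordered_tuple`, hypothetical) covers only the dict, list/tuple and single-leaf cases with an explicit names list; it does not infer names from a model or convert leaves to NumPy, and skipping absent keys in non-strict mode is an assumption of the sketch.

```python
def to_ordered_tuple(obj, names=None, strict=True):
    """Hypothetical sketch of the as_tuple container rules."""
    if isinstance(obj, dict):
        order = list(names) if names is not None else list(obj)
        if strict:
            missing = [k for k in order if k not in obj]
            if missing:
                raise KeyError(f"missing outputs: {missing}")
        # Non-strict mode: skip missing keys (an assumption of this sketch).
        return tuple(obj[k] for k in order if k in obj)
    if isinstance(obj, (list, tuple)):
        return tuple(obj)
    # Single tensor, array or scalar.
    return (obj,)


print(to_ordered_tuple({"gwl": 2, "subs": 1}, names=["subs", "gwl"]))  # (1, 2)
print(to_ordered_tuple([1, 2]))                                         # (1, 2)
print(to_ordered_tuple(3.5))                                            # (3.5,)
```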

geoprior.utils.target_utils.update_compiled_metrics(model, y_true, y_pred, *, output_names=None, to_numpy='never')[source]

Keras-3-safe compiled metrics updater.

  • Prefers the dict structure, matching models compiled with dict-valued losses and metrics.

  • Ensures deterministic output order via output_names/_output_keys.

  • Falls back to list/tuple update_state if needed.

  • Final fallback: manual per-metric update, so a failing metric update does not crash training.

Note: converting to NumPy inside train_step is generally not safe.

These helpers support inverse target scaling, sequence construction, canonical layout handling, and target-aware postprocessing.

Validation and helper infrastructure

Essential utilities for data processing and analysis in FusionLab, offering functions for normalization, interpolation, feature selection, outlier removal, and various data manipulation tasks.

Adapted for FusionLab from the original geoprior.utils.base_utils.

geoprior.utils.base_utils.detect_categorical_columns(data, integer_as_cat=True, float0_as_cat=True, min_unique_values=None, max_unique_values=None, handle_nan=None, return_frame=False, consider_dt_as=None, verbose=0)[source]

Detect categorical columns in a dataset by examining column types and user-defined criteria. Depending on the settings, integer columns, and float columns whose values all end in .0, can be treated as categorical. User-defined thresholds on the minimum and maximum number of unique values are also applied.

\[\forall x \in X,\; x = \lfloor x \rfloor \tag{34}\]

The equation above states that, for a float column with value set \(X\) to be treated as categorical, every value \(x\) must be a whole number, i.e. unchanged by the floor operation. This function relies on the helpers build_data_if, drop_nan_in, fill_NaN, parameter_validator, and smart_format (excluding private helpers prefixed with _).
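For a single numeric column, the condition above reduces to an element-wise comparison with the floor. A minimal pandas sketch, assuming NaNs are ignored and an optional cap on unique values; `looks_categorical` and its arguments are hypothetical and not part of detect_categorical_columns, and non-numeric columns are simply rejected here.

```python
import numpy as np
import pandas as pd


def looks_categorical(col: pd.Series, max_unique=None) -> bool:
    """Hypothetical check: integer dtype, or float values that are all whole."""
    values = col.dropna()
    if pd.api.types.is_integer_dtype(values):
        ok = True
    elif pd.api.types.is_float_dtype(values):
        arr = values.to_numpy()
        # forall x in X: x == floor(x)
        ok = bool(np.all(arr == np.floor(arr)))
    else:
        ok = False   # non-numeric columns are out of scope for this sketch
    if ok and max_unique is not None:
        ok = values.nunique() <= max_unique
    return ok


print(looks_categorical(pd.Series([1.0, 2.0, np.nan])))       # True
print(looks_categorical(pd.Series([1.0, 2.5])))               # False
print(looks_categorical(pd.Series([1, 2, 3]), max_unique=2))  # False
```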

Parameters:
  • data (DataFrame or array-like) – The input data to analyze. If not a DataFrame, it will be converted internally.

  • integer_as_cat (bool, optional) – If True, integer-type columns are considered categorical. Default is True.

  • float0_as_cat (bool, optional) – If True, float columns whose values can be cast to integer without remainder are considered categorical. Default is True.

  • min_unique_values (int or None, optional) – Minimum number of unique values in a column to qualify as categorical. If None, no minimum check is applied.

  • max_unique_values (int or 'auto' or None, optional) – Maximum number of unique values allowed for a column to be considered categorical. If 'auto', set the limit to the column’s own unique count. If None, no maximum check is applied.

  • handle_nan (str or None, optional) – Handling method for missing data. Can be 'drop' to remove rows with NaNs, 'fill' to impute them via forward/backward fill, or None for no change.

  • return_frame (bool, optional) – If True, returns a DataFrame of detected categorical columns; otherwise returns a list of column names. Default is False.

  • consider_dt_as (str, optional) – Indicates how to handle or convert datetime columns. Use None to keep datetime columns as-is. Use 'numeric' for timestamp-style conversion, 'float', 'float32' or 'float64' for float conversion, 'int', 'int32' or 'int64' for integer conversion, and 'object' or 'category' to convert them to Python objects such as strings. If conversion fails, behavior follows the configured error policy.

  • verbose (int, optional) – Verbosity level. If greater than 0, a summary of detected columns is printed.

Returns:

Either a list of column names or a DataFrame containing the categorical columns, depending on the value of return_frame.

Return type:

list or DataFrame

Examples

>>> from geoprior.utils.base_utils import detect_categorical_columns
>>> import pandas as pd
>>> df = pd.DataFrame({
...     'A': [1, 2, 3],
...     'B': [1.0, 2.0, 3.0],
...     'C': ['cat', 'dog', 'mouse']
... })
>>> detect_categorical_columns(df)
['A', 'B', 'C']

Notes

This function focuses on flexible treatment of integer and float columns. Combined with verbose settings, it can provide detailed feedback. Using 'drop' or 'fill' for handle_nan helps reduce disruptions caused by missing data. The array-programming background is discussed in Harris et al. [36].

The function uses flexible criteria for determining whether a column should be treated as categorical, allowing for detection of columns with integer values or float values ending in .0 as categorical columns. The method is useful when preparing data for machine learning algorithms that expect categorical inputs, such as decision trees or classification models.

This method uses the helper function build_data_if from geoprior.utils.validator to ensure that the input data is a DataFrame. If the input is not a DataFrame, it creates one, giving column names that start with input_name.

See also

build_data_if

Validates and converts input into a DataFrame if needed.

drop_nan_in

Drops NaN values from a DataFrame along axis=0.

fill_NaN

Fills missing data in a DataFrame using forward and backward fill.

geoprior.utils.base_utils.extract_target(data, target_names, drop=True, columns=None, return_y_X=False)[source]

Extracts the specified target column(s) from a multidimensional NumPy array or pandas DataFrame, with options to rename columns in a DataFrame and to control whether the extracted columns are dropped from the original data.

Parameters:
  • data (Union[np.ndarray, pd.DataFrame]) – The input data from which target columns are to be extracted. Can be a NumPy array or a pandas DataFrame.

  • target_names (Union[str, int, List[Union[str, int]]]) – The name(s) or integer index/indices of the column(s) to extract. If data is a DataFrame, this can be a mix of column names and indices. If data is a NumPy array, only integer indices are allowed.

  • drop (bool, default True) – If True, the extracted columns are removed from the original data. If False, the original data remains unchanged.

  • columns (Optional[List[str]], default None) – If provided and data is a DataFrame, specifies new names for the columns in data. The length of columns must match the number of columns in data. This parameter is ignored if data is a NumPy array.

  • return_y_X (bool, default False) – If True, returns a tuple (y, X) where X is the data with the target columns removed and y is the target columns. If False, returns only y.

Returns:

If return_y_X is True, returns a tuple (y, X) where y is the target columns and X is the data with the target columns removed. If return_y_X is False, returns only y.

Return type:

Union[ArrayLike, pd.Series, pd.DataFrame, Tuple[ pd.DataFrame, ArrayLike]]

Raises:

ValueError – If columns is provided and its length does not match the number of columns in data. If any of the specified target_names do not exist in data. If target_names includes a mix of strings and integers for a NumPy array input.

Examples

>>> import pandas as pd
>>> from geoprior.utils.base_utils import extract_target
>>> df = pd.DataFrame({
...     'A': [1, 2, 3],
...     'B': [4, 5, 6],
...     'C': [7, 8, 9]
... })
>>> target = extract_target(df, 'B', drop=True, return_y_X=False)
>>> print(target)
0    4
1    5
2    6
Name: B, dtype: int64
>>> target, remaining = extract_target(df, 'B', drop=True, return_y_X=True)
>>> print(target)
0    4
1    5
2    6
Name: B, dtype: int64
>>> print(remaining)
   A  C
0  1  7
1  2  8
2  3  9
>>> import numpy as np
>>> arr = np.random.rand(5, 3)
>>> target, modified_arr = extract_target(arr, 2, return_y_X=True)
>>> print(target)
>>> print(modified_arr)

geoprior.utils.base_utils.fancier_downloader(url, filename, dstpath=None, check_size=False, error='raise', verbose=True)[source]

Download a remote file with a progress bar and optional size verification.

This function downloads a file from the specified url and saves it locally with the given filename. It provides a visual progress bar during the download process and offers an option to verify the downloaded file’s size against the expected size to ensure data integrity. Additionally, the function allows for moving the downloaded file to a specified destination directory.

\[|S_{downloaded} - S_{expected}| < \epsilon \tag{35}\]

where \(S_{downloaded}\) is the size of the downloaded file, \(S_{expected}\) is the size specified by the server, and \(\epsilon\) is a small tolerance value.
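The size check amounts to comparing the file size on disk with the size reported by the server (typically the HTTP Content-Length header). The sketch below shows only the comparison and needs no network access; `size_matches` and its default tolerance of 1 byte (which, with the strict inequality, means an exact match) are hypothetical choices, not fancier_downloader's internals.

```python
import os
import tempfile


def size_matches(path: str, expected: int, eps: int = 1) -> bool:
    """Check |S_downloaded - S_expected| < eps, with sizes in bytes."""
    actual = os.path.getsize(path)
    return abs(actual - expected) < eps


# Simulate a finished download with a 1024-byte temporary file.
with tempfile.NamedTemporaryFile(delete=False) as fh:
    fh.write(b"x" * 1024)
    tmp_path = fh.name

ok_exact = size_matches(tmp_path, expected=1024)
ok_short = size_matches(tmp_path, expected=1000)
print(ok_exact, ok_short)   # True False
os.remove(tmp_path)
```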

Parameters:
  • url (str) – The URL from which to download the remote file.

  • filename (str) – The desired name for the local file. This is the name under which the file will be saved after downloading.

  • dstpath (Optional[str], default None) – The destination directory path where the downloaded file should be saved. If None, the file is saved in the current working directory.

  • check_size (bool, default False) –

    Whether to verify the size of the downloaded file against the expected size obtained from the server. This is useful for ensuring the integrity of the downloaded file. When True, the function checks:

    \[|S_{downloaded} - S_{expected}| < \epsilon \tag{36}\]

    If the size check fails:

    • If error='raise', an exception is raised.

    • If error='warn', a warning is emitted.

    • If error='ignore', the discrepancy is ignored, and the function continues.

  • error (str, default 'raise') –

    Specifies how to handle errors during the size verification process.

    • 'raise': Raises an exception if the file size does not match.

    • 'warn': Emits a warning and continues execution.

    • 'ignore': Silently ignores the size discrepancy and proceeds.

  • verbose (bool, default True) – Controls the verbosity of the function. If True, the function will print informative messages about the download status, including progress updates and success or failure notifications.

Returns:

Returns None if dstpath is provided and the file is moved to the destination. Otherwise, returns the local filename as a string.

Return type:

Optional[str]

Raises:
  • RuntimeError – If the download fails and error is set to 'raise'.

  • ValueError – If an invalid value is provided for the error parameter.

Examples

>>> from geoprior.utils.base_utils import fancier_downloader
>>> url = 'https://example.com/data/file.h5'
>>> local_filename = 'file.h5'
>>> # Download to current directory without size check
>>> fancier_downloader(url, local_filename)
>>>
>>> # Download to a specific directory with size verification
>>> fancier_downloader(
...     url,
...     local_filename,
...     dstpath='/path/to/save/',
...     check_size=True,
...     error='warn',
...     verbose=True
... )
>>>
>>> # Handle size mismatch by raising an exception
>>> fancier_downloader(
...     url,
...     local_filename,
...     check_size=True,
...     error='raise'
... )

Notes

  • Progress Bar: The function uses the tqdm library to display a progress bar during the download. If tqdm is not installed, it falls back to a basic downloader without a progress bar.

  • Directory Creation: If the specified dstpath does not exist, the function will attempt to create it to ensure the file is saved correctly.

  • File Integrity: Enabling check_size helps in verifying that the downloaded file is complete and uncorrupted. However, it does not perform a checksum verification.

  • Progress-reporting patterns and surrounding tooling are described in [37, 38].

See also

requests.get

Function to perform HTTP GET requests.

tqdm

A library for creating progress bars.

os.makedirs

Function to create directories.

geoprior.utils.base_utils.check_file_exists

Utility to check file existence.

geoprior.utils.base_utils.fillNaN(arr, method='ff')[source]

Fill NaN values in a numpy array, pandas Series, or pandas DataFrame using specified methods for forward filling, backward filling, or both.

Parameters:
  • arr (Union[np.ndarray, pd.Series, pd.DataFrame]) – The input data containing NaN values to be filled. This can be a numpy array, pandas Series, or DataFrame expected to contain numeric data.

  • method (str, optional) – The method used for filling NaN values. Valid options are:

    • 'ff': forward fill (default)

    • 'bf': backward fill

    • 'both': applies forward fill and backward fill sequentially

Returns:

The array with NaN values filled according to the specified method. The return type matches the input type (numpy array, Series, or DataFrame).

Return type:

Union[np.ndarray, pd.Series, pd.DataFrame]
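For pandas inputs, the three options correspond to pandas' own ffill and bfill methods. A minimal sketch of that mapping (the helper `fill_like` is hypothetical and handles pandas objects only); for 'both' it assumes forward fill runs first, matching the order in which the option is described.

```python
import numpy as np
import pandas as pd


def fill_like(data, method: str = "ff"):
    """Hypothetical sketch of the 'ff' / 'bf' / 'both' options."""
    if method == "ff":
        return data.ffill()
    if method == "bf":
        return data.bfill()
    if method == "both":
        # Forward fill first, then backward fill any leading NaNs.
        return data.ffill().bfill()
    raise ValueError(f"unknown method: {method!r}")


s = pd.Series([np.nan, 1.0, np.nan, 3.0])
print(fill_like(s, "ff").tolist())    # [nan, 1.0, 1.0, 3.0]
print(fill_like(s, "bf").tolist())    # [1.0, 1.0, 3.0, 3.0]
print(fill_like(s, "both").tolist())  # [1.0, 1.0, 1.0, 3.0]
```

Forward fill alone cannot fill a leading NaN, which is why 'both' exists.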

geoprior.utils.base_utils.select_features(data, features=None, dtypes_inc=None, dtypes_exc=None, coerce=False, columns=None, verify_integrity=False, parse_features=False, include_missing=None, exclude_missing=None, transform=None, regex=None, callable_selector=None, inplace=False, **astype_kwargs)[source]

Selects features from a dataset based on various criteria and returns a new DataFrame.

Conceptually, the selected columns are the subset of the input column set that satisfies the requested feature names, data-type filters, regex patterns, callable selectors, and missing-data conditions.
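As a rough analogue in plain pandas (not a call into select_features), each criterion yields a set of column names and the selection is their intersection: dtype filters map to DataFrame.select_dtypes, regex selection to DataFrame.filter(regex=...), and the missing-data filters to DataFrame.isna().

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "Color": ["Blue", "Red", "Green"],
    "Price": [200.0, 300.0, 100.0],
    "Discount": [20, 30, np.nan],
})

by_dtype = set(df.select_dtypes(include="number").columns)    # dtype filter
by_regex = set(df.filter(regex=r"^Price|Discount$").columns)   # regex filter
with_missing = set(df.columns[df.isna().any()])                # missing-data filter

# Intersect the criteria, keeping the original column order.
selected = [c for c in df.columns if c in by_dtype & by_regex]
print(selected)    # ['Price', 'Discount']

# Analogue of exclude_missing=True: drop columns containing NaN.
complete = [c for c in selected if c not in with_missing]
print(complete)    # ['Price']
```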

Parameters:
  • data (Union[pd.DataFrame, dict, np.ndarray, list]) – The dataset from which to select features. Can be a pandas DataFrame, a dictionary, a NumPy array, or a list of dictionaries/lists.

  • features (Optional[Union[List[str], Pattern, Callable[[str], bool]]], default None) – Specific feature names to select. Can also be a regex pattern or a callable that takes a column name and returns True if the column should be selected.

  • dtypes_inc (Optional[Union[str, List[str]]], default None) – The data type(s) to include in the selection. Possible values are the same as for the pandas include parameter in select_dtypes.

  • dtypes_exc (Optional[Union[str, List[str]]], default None) – The data type(s) to exclude from the selection. Possible values are the same as for the pandas exclude parameter in select_dtypes.

  • coerce (bool, default False) – If True, numeric columns are coerced to the appropriate types without selection, ignoring features, dtypes_inc, and dtypes_exc parameters.

  • columns (Optional[List[str]], default None) – Column names to use if data is a NumPy array or a list without column names.

  • verify_integrity (bool, default False) – Verifies the data type integrity and converts data to the correct types if necessary.

  • parse_features (bool, default False) – Parses string features and converts them to an iterable object (e.g., lists).

  • include_missing (Optional[bool], default None) – If True, includes only columns with missing values. If False, excludes columns with missing values.

  • exclude_missing (Optional[bool], default None) – If True, excludes columns with any missing values.

  • transform (Optional, default None) – Function or dictionary of functions to apply to the selected columns. If a dictionary is provided, keys should correspond to column names.

  • regex (Optional[Union[str, Pattern]], default None) – Regular expression pattern to select columns.

  • callable_selector (Optional[Callable[[str], bool]], default None) – A callable that takes a column name and returns True if the column should be selected.

  • inplace (bool, default False) – If True, modifies the data in place. Otherwise, returns a new DataFrame.

  • **astype_kwargs (Any) – Additional keyword arguments for pandas.DataFrame.astype.

Returns:

A new DataFrame with the selected features.

Return type:

pd.DataFrame

Raises:
  • ValueError – If no columns match the selection criteria and coerce is False.

  • TypeError – If regex is not a string or compiled regex pattern. If callable_selector is not a callable. If transform is not a callable or a dictionary of callables. If provided parameters are of incorrect types.

Examples

>>> from geoprior.utils.base_utils import select_features
>>> import pandas as pd
>>> import re
>>> import numpy as np
>>> data = {
...     "Color": ['Blue', 'Red', 'Green'],
...     "Name": ['Mary', "Daniel", "Augustine"],
...     "Price ($)": ['200', "300", "100"],
...     "Discount": [20, 30, np.nan]
... }
>>> select_features(data, dtypes_inc='number', verify_integrity=True)
   Price ($)  Discount
0      200.0      20.0
1      300.0      30.0
2      100.0       NaN
>>> select_features(data, features=['Color', 'Price ($)'])
   Color Price ($)
0   Blue       200
1    Red       300
2  Green       100
>>> select_features(
...     data,
...     regex='^Price|Discount$',
...     transform={'Price ($)': lambda x: x / 100}
... )
   Price ($)  Discount
0        2.0        20
1        3.0        30
2        1.0         NaN
>>> select_features(
...     data,
...     callable_selector=lambda col: col.startswith('C')
... )
   Color
0   Blue
1    Red
2  Green

Notes

This function is particularly useful in data preprocessing pipelines where the presence of certain features is critical for later analysis or modeling steps. When using regex patterns, ensure that the pattern accurately reflects the intended column names to avoid unintended matches. The callable provided to callable_selector should accept a single column-name string and return a boolean. Transformation functions should be designed to handle the data types of the selected columns to avoid runtime errors. Related selection and coercion behavior is documented in [39, 40, 41, 42].

geoprior.utils.base_utils.fill_NaN(arr, method='ff')[source]

Fill NaN values in an array-like structure using specified methods. Handles numeric and non-numeric data separately to preserve data integrity.

Parameters:
  • arr (array-like, pandas.DataFrame, or pandas.Series) – The input data structure containing NaN values to be filled.

  • method (str, default 'ff') –

    The method to use for filling NaN values. Accepted values:

    • Forward fill: 'forward', 'ff', 'fwd'

    • Backward fill: 'backward', 'bf', 'bwd'

    • Both: 'both', 'ffbf', 'fbwf', 'bff', 'full'

Returns:

The input data structure with NaN values filled according to the specified method.

Return type:

array-like, pandas.DataFrame, or pandas.Series

Raises:

ValueError – If the provided fill method is not recognized.

Notes

Mathematically, the function performs:

\[
\text{Filled\_array} =
\begin{cases}
\text{fillNaN}(\text{arr}, \text{method}) & \text{if arr is numeric} \\
\text{concat}\big(\text{fillNaN}(\text{numeric\_parts}, \text{method}),\ \text{non\_numeric\_parts}\big) & \text{otherwise}
\end{cases}
\tag{37}
\]

This ensures that non-numeric data remains unaltered while NaN values in numeric columns are appropriately filled.

The function preserves the original structure of the input array by utilizing array_preserver. Numeric columns are filled using the specified method, while non-numeric columns remain unchanged.
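The case split in the formula can be sketched directly in pandas: fill only the numeric columns, pass the others through unchanged, and restore the original column order. This illustrates the formula rather than the library's implementation; the helper `fill_numeric_only` is hypothetical and handles DataFrames only.

```python
import numpy as np
import pandas as pd


def fill_numeric_only(df: pd.DataFrame, method: str = "ff") -> pd.DataFrame:
    """Hypothetical sketch: fill NaNs in numeric columns only."""
    numeric = df.select_dtypes(include="number")
    other = df.drop(columns=numeric.columns)
    if method == "ff":
        filled = numeric.ffill()
    elif method == "bf":
        filled = numeric.bfill()
    else:
        raise ValueError(f"unsupported method in this sketch: {method!r}")
    # concat(fillNaN(numeric_parts), non_numeric_parts), original order kept
    return pd.concat([filled, other], axis=1)[df.columns]


frame = pd.DataFrame({"A": [1, 2, np.nan, 4], "B": ["x", np.nan, "y", "z"]})
out = fill_numeric_only(frame)
print(out["A"].tolist())   # [1.0, 2.0, 2.0, 4.0]
```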

Examples

>>> from geoprior.utils.base_utils import fill_NaN
>>> import numpy as np
>>> import pandas as pd
>>> df = pd.DataFrame({
...     'A': [1, 2, np.nan, 4],
...     'B': ['x', np.nan, 'y', 'z']
... })
>>> fill_NaN(df, method='ff')
     A    B
0  1.0    x
1  2.0    x
2  2.0    y
3  4.0    z

See also

geoprior.core.array_manager.array_preserver

Preserves and restores array structures.

geoprior.core.array_manager.to_array

Converts input to a pandas-compatible array-like structure.

geoprior.core.checks.is_numeric_dtype

Checks if the array has numeric data types.

geoprior.utils.base_utils.fillNaN

Core function to fill NaN values in numeric data.

geoprior.utils.base_utils.validate_target_in(df, target, error='raise', verbose=0)[source]

Validate and process the target variable, ensuring it is consistent with the features in the DataFrame.

Parameters:
  • df (pandas.DataFrame) – The DataFrame containing the features and possibly the target column.

  • target (str or pandas.Series or pandas.DataFrame) – The target variable to validate and process.

  • error ({'raise', 'warn', 'ignore'}, optional) – Behavior to use when target validation fails. Use 'raise' to raise an exception, 'warn' to continue with a warning, or 'ignore' to skip reporting.

  • verbose (int, optional) – Verbosity level for logging. Use 0 for no output, 1 for basic information, and 2 for detailed information.

Returns:

Dependency utilities providing functions to handle package installation, checking, and ensuring that optional dependencies are available.

geoprior.utils.deps_utils.ensure_pkg(name, extra='', error='raise', min_version=None, exception=None, dist_name=None, infer_dist_name=False, auto_install=False, use_conda=False, partial_check=False, condition=None, verbose=False)[source]

Decorator to ensure a Python package is installed before function execution.

If the specified package is not installed, or if its installed version does not meet the minimum version requirement, this decorator can optionally install or upgrade the package automatically using either pip or conda.

Parameters:
  • name (str) – The name of the package.

  • extra (str, optional) – Additional specification for the package, such as version or extras.

  • error (str, optional) – Error handling strategy if the package is missing: ‘raise’, ‘ignore’, or ‘warn’.

  • min_version (str or None, optional) – The minimum required version of the package. If not met, triggers installation.

  • exception (Exception, optional) – A custom exception to raise if the package is missing and error is ‘raise’.

  • dist_name (str, optional) – The distribution name of the package as known by package managers (e.g., pip). If provided and the module import fails, an additional check based on the distribution name is performed. This parameter is useful for packages where the distribution name differs from the importable module name.

  • infer_dist_name (bool, optional) – If True, attempt to infer the distribution name for pip installation, defaults to False.

  • auto_install (bool, optional) – Whether to automatically install the package if missing. Defaults to False.

  • use_conda (bool, optional) – Prefer conda over pip for automatic installation. Defaults to False.

  • partial_check (bool, optional) – If True, checks the existence of the package only if the condition is met. This allows for conditional package checking based on the function’s arguments or other criteria. If False, the check is always performed. Defaults to False.

  • condition (Any, optional) – A condition that determines whether to check for the package’s existence. This can be a callable that takes the same arguments as the decorated function and returns a boolean, a specific argument name to check for truthiness, or any other value that will be evaluated as a boolean. If None, the package check is performed unconditionally.

  • verbose (bool, optional) – Enable verbose output during the installation process. Defaults to False.

Returns:

A decorator that wraps functions to ensure the specified package is installed.

Return type:

Callable

Examples

>>> from geoprior.utils.deps_utils import ensure_pkg
>>> @ensure_pkg("numpy", auto_install=True)
... def use_numpy():
...     import numpy as np
...     return np.array([1, 2, 3])
>>> @ensure_pkg("pandas", min_version="1.1.0", error="warn", use_conda=True)
... def use_pandas():
...     import pandas as pd
...     return pd.DataFrame([[1, 2], [3, 4]])
>>> @ensure_pkg("matplotlib", partial_check=True, condition=lambda x, y: x > 0)
... def plot_data(x, y):
...     import matplotlib.pyplot as plt
...     plt.plot(x, y)
...     plt.show()
>>> @ensure_pkg("skimage", partial_check=True, condition=(
...     lambda *args, **kwargs: 'method' in kwargs and kwargs['method'] == 'hog')
...     )
... def check_package_installed(data, method='hog', **kwargs):
...     extractor_function = None
...     if method == 'hog':
...         from skimage.feature import hog
...         extractor_function = lambda image: hog(image, **kwargs)
...     return extractor_function
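
The conditional check performed by ensure_pkg can be illustrated with a much smaller decorator built only on the standard library. The sketch below is hypothetical (require_pkg is not part of GeoPrior): it covers only the error, partial_check, and condition semantics described above, skips version checks and automatic installation, and resolves a string condition from keyword arguments only.

```python
import functools
import importlib.util
import warnings
from typing import Any, Callable


def require_pkg(name: str, error: str = "raise", partial_check: bool = False,
                condition: Any = None) -> Callable:
    """Hypothetical, simplified stand-in for ensure_pkg (no install, no versions)."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            should_check = True
            if partial_check and condition is not None:
                if callable(condition):
                    # Callable conditions receive the wrapped function's arguments.
                    should_check = bool(condition(*args, **kwargs))
                elif isinstance(condition, str):
                    # A string names a keyword argument tested for truthiness.
                    should_check = bool(kwargs.get(condition))
                else:
                    should_check = bool(condition)
            if should_check and importlib.util.find_spec(name) is None:
                msg = f"Package {name!r} is required by {func.__name__}()."
                if error == "raise":
                    raise ImportError(msg)
                if error == "warn":
                    warnings.warn(msg)
            return func(*args, **kwargs)
        return wrapper
    return decorator
```

Checking at call time rather than at decoration time is what makes partial_check useful: the condition can inspect the actual arguments before deciding whether the package matters.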
geoprior.utils.deps_utils.ensure_pkgs(names, extra='', error='raise', min_versions=None, exception=None, dist_names=None, infer_dist_name=False, auto_install=False, use_conda=False, partial_check=False, condition=None, verbose=False)[source]

Decorator to ensure Python packages are installed before function execution.

If the specified packages are not installed, or if their installed versions do not meet the minimum version requirements, this decorator can optionally install or upgrade the packages automatically using either pip or conda.

Parameters:
  • names (str or list of str) – The name(s) of the package(s). Can be a single string with package names separated by commas, or a list of package names.

  • extra (str, optional) – Additional specification for the package(s), such as version or extras.

  • error ({'raise', 'ignore', 'warn'}, optional) – Error handling strategy if a package is missing: ‘raise’, ‘ignore’, or ‘warn’. Defaults to ‘raise’.

  • min_versions (str or list of str or None, optional) – The minimum required version(s) of the package(s). If not met, triggers installation. Can be a single version string applied to all packages or a list matching the names list.

  • exception (Exception, optional) – A custom exception to raise if a package is missing and error is ‘raise’.

  • dist_names (str or list of str or None, optional) – The distribution name(s) of the package(s) as known by package managers (e.g., pip). Useful when the distribution name differs from the importable module name. Can be a single string or a list matching the names list.

  • infer_dist_name (bool, optional) – If True, attempt to infer the distribution name for pip installation. Defaults to False.

  • auto_install (bool, optional) – Whether to automatically install missing packages. Defaults to False.

  • use_conda (bool, optional) – Prefer conda over pip for automatic installation. Defaults to False.

  • partial_check (bool, optional) – If True, checks the existence of the packages only if the condition is met. Allows for conditional package checking based on the function’s arguments or other criteria. If False, the check is always performed. Defaults to False.

  • condition (Any, optional) – A condition that determines whether to check for the packages’ existence. Can be a callable that takes the same arguments as the decorated function and returns a boolean, a specific argument name to check for truthiness, or any other value that will be evaluated as a boolean. If None, the package check is performed unconditionally.

  • verbose (bool, optional) – Enable verbose output during the installation process. Defaults to False.

Returns:

A decorator that wraps functions to ensure the specified packages are installed.

Return type:

Callable

Examples

>>> from geoprior.utils.deps_utils import ensure_pkgs
>>> @ensure_pkgs("numpy, pandas", auto_install=True)
... def use_numpy_pandas():
...     import numpy as np
...     import pandas as pd
...     return np.array([1, 2, 3]), pd.DataFrame([[1, 2], [3, 4]])
>>> @ensure_pkgs(["matplotlib", "seaborn"], min_versions=["3.0.0", "0.11.0"])
... def plot_data(x, y):
...     import matplotlib.pyplot as plt
...     import seaborn as sns
...     sns.scatterplot(x=x, y=y)
...     plt.show()
>>> @ensure_pkgs("skimage", partial_check=True, condition=(
...     lambda *args, **kwargs: 'method' in kwargs and kwargs['method'] == 'hog')
... )
... def process_image(data, method='hog', **kwargs):
...     if method == 'hog':
...         from skimage.feature import hog
...         return hog(data, **kwargs)
...     else:
...         # Other processing
...         pass
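
The way ensure_pkgs accepts either a comma-separated string or a list of names, with one version per package or a single version for all, can be sketched as a small normalization step. The helper normalize_pkg_specs is hypothetical; it only pairs names with versions and performs no checks or installs.

```python
from typing import List, Optional, Sequence, Tuple, Union


def normalize_pkg_specs(
    names: Union[str, Sequence[str]],
    min_versions: Union[None, str, Sequence[Optional[str]]] = None,
) -> List[Tuple[str, Optional[str]]]:
    """Hypothetical helper: pair each package name with its minimum version."""
    # A single string may hold several comma-separated names.
    if isinstance(names, str):
        pkg_names = [n.strip() for n in names.split(",") if n.strip()]
    else:
        pkg_names = list(names)
    # A single version (or None) is broadcast to every package.
    if min_versions is None or isinstance(min_versions, str):
        versions = [min_versions] * len(pkg_names)
    else:
        versions = list(min_versions)
        if len(versions) != len(pkg_names):
            raise ValueError("min_versions must match the number of names")
    return list(zip(pkg_names, versions))
```

For example, normalize_pkg_specs("numpy, pandas") yields one (name, None) pair per package, which a per-package check like ensure_pkg can then consume.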
geoprior.utils.deps_utils.install_package(name, dist_name=None, infer_dist_name=False, version=None, extra='', use_conda=False, verbose=True)[source]

Install a Python package at runtime, optionally specifying a version constraint or other parameters, using either conda or pip. If conda is unavailable or disabled, pip is used by default. The function includes a check for pre-existing installations, allowing users to skip redundant installs.

Parameters:
  • name (str) – Base name of the package to install (e.g., 'requests').

  • dist_name (str, optional) – Distribution name, if different from the import name. For example, scikit-learn’s import name sklearn differs from its distribution name 'scikit-learn'.

  • infer_dist_name (bool, optional) – If True, calls get_installation_name() to infer the distribution name automatically. Defaults to False.

  • version (str, optional) – Version string or comparator. Examples include '1.2.0' interpreted as '>=1.2.0', '==1.2.0', '<2.0', and '>=1.5.3'. If None, no version constraint is applied.

  • extra (str, optional) – Additional install specifiers or command-line flags passed to the installation command. For instance, ' --no-cache-dir' or '[extra]'. Default is ''.

  • use_conda (bool, optional) – If True, attempts installation via conda first. If conda is unavailable or fails, falls back to pip. Defaults to False.

  • verbose (bool, optional) – If True, prints detailed logs throughout the installation. Defaults to True.

Returns:

On success, the specified package is installed (or is already present). If the installer fails, raises a RuntimeError.

Return type:

None

Raises:

RuntimeError – If the installation cannot be completed using either conda or pip, or if conda is requested but unavailable.

Notes

If the package is already installed, as determined by is_module_installed(), no further action is taken. When using pip, a progress bar is displayed if tqdm is installed. For conda, no progress bar is shown because of console I/O capture limitations.

Conceptually, this function assembles an install spec of the form:

\[\text{install\_str} = \langle \text{name} \rangle + \langle \text{version\_spec} \rangle + \langle \text{extra} \rangle \tag{38}\]

where \(\langle \text{name} \rangle\) is the package name, \(\langle \text{version\_spec} \rangle\) is a version comparator (e.g., >=1.2.0), and \(\langle \text{extra} \rangle\) is any additional flags or arguments.
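The assembly of an install spec can be sketched as a function that builds, but does not run, a pip command. This is a hypothetical illustration (build_pip_command is not a GeoPrior function) that follows the version rule described above, where a bare '1.2.0' is read as '>=1.2.0'. It departs from the literal order of the formula in one respect: pip expects bracketed extras such as '[extra]' directly after the name, so the sketch places them there and treats any other extra text as command-line flags.

```python
import sys
from typing import List, Optional


def build_pip_command(name: str, version: Optional[str] = None,
                      extra: str = "") -> List[str]:
    """Hypothetical sketch: assemble a pip install command (not executed)."""
    extra = extra.strip()
    # Bracketed extras belong to the requirement itself, e.g. "requests[socks]".
    extras_marker = extra if extra.startswith("[") else ""
    flags = extra.split() if extra and not extras_marker else []
    version_spec = ""
    if version and version.strip():
        version_spec = version.strip()
        # A bare version such as "1.2.0" is read as a minimum: ">=1.2.0".
        if version_spec[0].isdigit():
            version_spec = ">=" + version_spec
    requirement = f"{name}{extras_marker}{version_spec}"
    return [sys.executable, "-m", "pip", "install", requirement, *flags]
```

Passing the result to subprocess.check_call would perform a real installation, so the sketch stops at building the argument list.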

Examples

>>> from geoprior.utils.deps_utils import install_package
>>> # Install requests with no version constraint, default pip:
>>> install_package('requests', verbose=True)
>>> # Install a specific version via conda (fallback to pip if conda fails):
>>> install_package(
...     'pandas',
...     version='==1.2.0',
...     use_conda=True,
...     verbose=True
... )

See also

is_module_installed

Check whether a Python module or corresponding distribution is already installed.

get_installation_name

Infer a distribution name for the given module name when needed.

geoprior.utils.deps_utils.is_installing(module, upgrade=True, action=True, DEVNULL=False, verbose=0, **subpkws)[source]

Install or uninstall a module/package, using subprocess under the hood.

Parameters:
  • module (str) – The name of the module or library to install from the Python Package Index (PyPI) with pip.

  • upgrade (bool) – If True, install the latest version of the package. Default is True.

  • DEVNULL (bool) – If True, suppress the installer’s console output.

  • action (str or bool) – Action to perform: ‘install’ or ‘uninstall’ a package. Default is True, which means ‘install’.

  • verbose (int, optional) – Verbosity level; higher values print more messages. Default is 0.

  • subpkws (dict) – Additional keyword arguments passed to subprocess.

Returns:

success – Whether the package was successfully installed.

Return type:

bool

Example

>>> from geoprior.utils.deps_utils import is_installing
>>> is_installing(
...     'tqdm', action='install', DEVNULL=True, verbose=1)
>>> is_installing(
...     'tqdm', action='uninstall', verbose=1)
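
The subprocess pattern behind is_installing, including the DEVNULL option, can be sketched with two small helpers. Both names are hypothetical, and the mapping of upgrade to pip's --upgrade flag and the use of -y to skip pip's uninstall confirmation are assumptions about how such a wrapper would behave, not a description of GeoPrior's code.

```python
import subprocess
import sys
from typing import List, Union


def pip_action_command(module: str, action: Union[bool, str] = True,
                       upgrade: bool = True) -> List[str]:
    """Hypothetical: build the pip command for installing or uninstalling."""
    if action is True or action == "install":
        verb = "install"
    elif action == "uninstall":
        verb = "uninstall"
    else:
        raise ValueError(f"Unsupported action: {action!r}")
    cmd = [sys.executable, "-m", "pip", verb, module]
    if verb == "install" and upgrade:
        cmd.append("--upgrade")
    if verb == "uninstall":
        cmd.append("-y")  # pip otherwise waits for a confirmation prompt
    return cmd


def run_quietly(cmd: List[str], devnull: bool = False, **subpkws) -> bool:
    """Run a command; return True on a zero exit status."""
    if devnull:
        # Mirrors the DEVNULL option: discard the child's console output.
        subpkws.setdefault("stdout", subprocess.DEVNULL)
        subpkws.setdefault("stderr", subprocess.DEVNULL)
    return subprocess.run(cmd, **subpkws).returncode == 0
```

Calling run_quietly(pip_action_command("tqdm"), devnull=True) would perform a real installation; separating command construction from execution keeps the first part easy to test.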
geoprior.utils.deps_utils.get_installation_name(module_name, distribution_name=None, return_bool=False)[source]

Determines the appropriate name for installing a package, considering potential discrepancies between the distribution name and the module import name. Optionally, returns a boolean indicating if the distribution name matches the import name.

Parameters:
  • module_name (str) – The import name of the module.

  • distribution_name (str, optional) – The distribution name of the package. If None, the function attempts to infer the distribution name from the module name.

  • return_bool (bool, optional) – If True, returns a boolean indicating whether the distribution name matches the module import name. Otherwise, returns the name recommended for installation.

Returns:

Depending on return_bool, returns either a boolean indicating if the distribution name matches the module name, or the name (distribution or module) recommended for installation.

Return type:

Union[str, bool]
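The name resolution described above can be sketched with a small lookup table. The table and the helper installation_name are hypothetical illustrations: the listed pairs are well-known import-name/distribution-name mismatches, but the actual function may infer names differently (for example, from installed package metadata).

```python
from typing import Dict, Optional, Union

# Well-known cases where the import name differs from the distribution name.
KNOWN_DISTRIBUTIONS: Dict[str, str] = {
    "sklearn": "scikit-learn",
    "skimage": "scikit-image",
    "cv2": "opencv-python",
    "yaml": "PyYAML",
    "PIL": "Pillow",
}


def installation_name(module_name: str, distribution_name: Optional[str] = None,
                      return_bool: bool = False) -> Union[str, bool]:
    """Hypothetical sketch: pick the name to pass to the installer."""
    # An explicit distribution name wins; otherwise fall back to the table.
    dist = distribution_name or KNOWN_DISTRIBUTIONS.get(module_name, module_name)
    if return_bool:
        return dist == module_name
    return dist
```

With return_bool=True the sketch answers the narrower question "do the two names coincide?", which is False for sklearn and True for numpy.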

geoprior.utils.deps_utils.is_module_installed(module_name, distribution_name=None)[source]

Check if a Python module is installed by attempting to import it. Optionally, a distribution name can be provided if it differs from the module name.

Parameters:
  • module_name (str) – The import name of the module to check.

  • distribution_name (str, optional) – The distribution name of the package as known by package managers (e.g., pip). If provided and the module import fails, an additional check based on the distribution name is performed. This parameter is useful for packages where the distribution name differs from the importable module name.

Returns:

True if the module can be imported or the distribution package is installed, False otherwise.

Return type:

bool

Examples

>>> from geoprior.utils.deps_utils import is_module_installed
>>> is_module_installed("sklearn")
True
>>> is_module_installed("scikit-learn", "scikit-learn")
True
>>> is_module_installed("some_nonexistent_module")
False
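
The two-step check described above (try the import, then consult the distribution name) can be sketched with importlib. The helper module_available is hypothetical and assumes Python 3.8+ for importlib.metadata.

```python
import importlib
import importlib.metadata
from typing import Optional


def module_available(module_name: str, distribution_name: Optional[str] = None) -> bool:
    """Hypothetical sketch: try the import, then fall back to distribution metadata."""
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        pass
    if distribution_name:
        try:
            # Succeeds when the distribution is installed, even if the
            # import name above did not match.
            importlib.metadata.distribution(distribution_name)
            return True
        except importlib.metadata.PackageNotFoundError:
            return False
    return False
```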
geoprior.utils.deps_utils.import_optional_dependency(name, extra='', errors='raise', min_version=None, exception=None)[source]

Import an optional dependency.

By default, if the dependency is missing, an ImportError with an informative message is raised. If the dependency is present but older than the required version, an ImportError is raised as well.

Parameters:
  • name (str) – The module name.

  • extra (str) – Additional text to include in the ImportError message.

  • errors (str {'raise', 'warn', 'ignore'}) –

    What to do when a dependency is not found or its version is too old.

    • raise : Raise an ImportError

    • warn : Only applicable when a module’s version is too old. Warns that the version is too old and returns None.

    • ignore: If the module is not installed, return None; otherwise, return the module, even if the version is too old. Users are expected to validate the version locally when using errors="ignore".

  • min_version (str, default None) – Minimum version to require for the module.

  • exception (callable, BaseException) – A custom exception (for example, your own package’s exception class) to raise instead of ImportError.

Returns:

maybe_module – The imported module, when found and the version is correct. None is returned when the package is not found and errors is 'ignore', or when the package’s version is too old and errors is 'warn'.

Return type:

Optional[ModuleType]

geoprior.utils.deps_utils.ensure_module_installed(module_name, auto_install=False, version=None, package_manager='pip', dist_name=None, extra_install_args=None)[source]

Ensure that the required module is installed, optionally installing it if missing.

Parameters:
  • module_name (str) – The name of the module to check and install if necessary.

  • auto_install (bool, optional) – If True, automatically install the module using the specified package manager if it is not already installed (default is False).

  • version (Optional[str], optional) – Specify a version or version range for the module. For example, “>=1.0.0” or “==2.0.1”. If None, no version constraints are applied (default is None).

  • package_manager (str, optional) – The package manager to use for installation. Currently, only "pip" is supported. Future versions may support other package managers like "conda" (default is "pip").

  • dist_name (Optional[str], optional) – Sometimes the module name used for importing is different from the distribution package name. This parameter allows specifying the distribution package name (default is None).

  • extra_install_args (Optional[List[str]], optional) – A list of additional arguments to pass to the package manager during installation. For example, ["--upgrade"] to upgrade the package. If None, no extra arguments are passed (default is None).

Returns:

Returns True if the module is installed or successfully installed, False otherwise.

Return type:

bool

Raises:
  • ImportError – If the module is not installed and auto_install is False, or if the installation fails.

  • ValueError – If an unsupported package manager is specified.

Examples

>>> from geoprior.utils.deps_utils import ensure_module_installed
>>> # Ensure that 'numpy' is installed
>>> ensure_module_installed("numpy")
>>> # Ensure that 'pandas' is installed, automatically installing if missing
>>> ensure_module_installed("pandas", auto_install=True)
>>> # Ensure that 'scipy' version >=1.5.0 is installed
>>> ensure_module_installed("scipy", version=">=1.5.0", auto_install=True)
>>> # Install with additional arguments
>>> ensure_module_installed(
...     "requests",
...     auto_install=True,
...     extra_install_args=["--upgrade"]
... )

Notes

This function currently supports only "pip" as the package manager. When specifying a version, ensure that the version string is compatible with the package manager’s version specification syntax. Packages that require system-level dependencies may still need manual installation steps.

See also

subprocess

For spawning new processes.

sys

System-specific parameters and functions.

geoprior.utils.deps_utils.get_versions(extras=None, distribution_mapping=None)[source]

Retrieve a dictionary containing version information for common libraries, as well as any user-specified packages and distribution name mappings.

Parameters:
  • extras (list of str, optional) – Additional packages for which to attempt version retrieval. By default, None, which means no extra packages beyond the defaults.

  • distribution_mapping (dict, optional) – Mapping from import-like names to actual distribution names. For example, the import name 'sklearn' corresponds to the distribution name 'scikit-learn'. Default is None, which uses a built-in mapping for scikit-learn and any user-provided dictionary overrides or additions.

Returns:

Dictionary of the form:

{
    "__version__": {
        "numpy": "1.24.2",
        "pandas": "1.5.0",
        "sklearn": "1.3.2",
        ...
    }
}

Return type:

dict

Notes

  • By default, this function attempts to retrieve versions for the following packages: ['numpy', 'pandas', 'sklearn', 'joblib', 'tensorflow', 'keras', 'torch'].

  • If a package is not installed, it is skipped (no error is raised).

  • If <distribution_mapping> is provided, it is merged with the built-in mapping ("sklearn" to "scikit-learn"), allowing users to specify additional name differences.

  • Python 3.8+ is recommended to ensure importlib.metadata is available.
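The lookup described in these notes can be sketched with importlib.metadata. The helper collect_versions is hypothetical; it reuses the default package list and the sklearn mapping stated above, and lets user-supplied mapping entries override the built-in one.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

DEFAULT_PACKAGES = ["numpy", "pandas", "sklearn", "joblib",
                    "tensorflow", "keras", "torch"]


def collect_versions(extras: Optional[Iterable[str]] = None,
                     distribution_mapping: Optional[Dict[str, str]] = None
                     ) -> Dict[str, Dict[str, str]]:
    """Hypothetical sketch of the version lookup described in the Notes."""
    mapping = {"sklearn": "scikit-learn"}
    mapping.update(distribution_mapping or {})  # user entries override
    found: Dict[str, str] = {}
    for name in [*DEFAULT_PACKAGES, *(extras or [])]:
        try:
            # Look up the distribution name, but report under the import name.
            found[name] = version(mapping.get(name, name))
        except PackageNotFoundError:
            continue  # missing packages are skipped silently
    return {"__version__": found}
```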

Examples

>>> from geoprior.utils.deps_utils import get_versions
>>> get_versions()
{
  "__version__": {
    "numpy": "1.24.2",
    "pandas": "1.5.0",
    ...
  }
}
>>> # Add custom package and distribution mapping:
>>> get_versions(
...   extras=["spacy"],
...   distribution_mapping={"spacy": "spacy-legacy"}
... )
{
  "__version__": {
    "numpy": "1.24.2",
    "pandas": "1.5.0",
    "spacy": "3.5.1"
  }
}

geoprior.utils.split

Group-holdout split for sequence data.

Exports:

  • train_windows_T{T}_H{H}.npz

  • val_windows_T{T}_H{H}.npz

  • test_windows_T{T}_H{H}.npz

  • future_inputs_T{T}_H{H}.npz

  • splits_groups.json

Leakage fix (Zhongshan, 2 windows per pixel): split by group_id first, then build windows inside each split, so that windows from the same group never fall into more than one split.
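The leakage fix can be illustrated with a standard-library sketch: unique group keys are shuffled with a fixed seed and cut using the default SplitCfg ratios (0.7, 0.15, 0.15), and only then are windows assigned to a split through their group. The function split_groups and the toy data are hypothetical, and rounding coordinates to decimals places to build stable keys is an assumption about what SplitCfg.decimals controls.

```python
import random
from typing import Dict, Hashable, List, Sequence, Tuple


def split_groups(keys: Sequence[Hashable], seed: int = 42,
                 ratios: Tuple[float, float, float] = (0.7, 0.15, 0.15)
                 ) -> Dict[str, List[Hashable]]:
    """Hypothetical: assign each unique group key to exactly one split."""
    unique = sorted(set(keys))            # deterministic order before shuffling
    random.Random(seed).shuffle(unique)
    n_train = int(round(ratios[0] * len(unique)))
    n_val = int(round(ratios[1] * len(unique)))
    return {
        "train": unique[:n_train],
        "val": unique[n_train:n_train + n_val],
        "test": unique[n_train + n_val:],
    }


# Toy data: 20 pixels, each contributing 2 windows.
decimals = 8  # assumed meaning: coordinate rounding used to build group keys
pixels = [(round(113.0 + 0.01 * i, decimals), round(22.5, decimals)) for i in range(20)]
windows = [(px, w) for px in pixels for w in range(2)]

splits = split_groups(pixels)
split_of = {key: name for name, keys in splits.items() for key in keys}
# Windows inherit the split of their pixel, so no pixel spans two splits.
windowed = {name: [w for w in windows if split_of[w[0]] == name] for name in splits}
```

Splitting windows directly would typically place the two windows of one pixel on both sides of the train/test boundary; routing every window through its group key rules that out by construction.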

class geoprior.utils.split.SplitCfg(seed: 'int' = 42, ratios: 'tuple[float, float, float]' = (0.7, 0.15, 0.15), decimals: 'int' = 8)[source]

Bases: object

seed: int = 42
ratios: tuple[float, float, float] = (0.7, 0.15, 0.15)
decimals: int = 8
__init__(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8)
Return type:

None

geoprior.utils.split.split_group_keys(keys, *, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8))[source]
Return type:

dict[str, ndarray]

geoprior.utils.split.subset_by_keys(df, *, group_cols, keys, decimals=8)[source]
Return type:

DataFrame

geoprior.utils.split.write_splits_json(path, *, group_cols, time_steps, horizon, train_end, cfg, splits)[source]
Return type:

str

geoprior.utils.split.pack_xy_npz(x, y)[source]
Return type:

dict[str, ndarray]

geoprior.utils.split.build_group_holdout_npzs(*, df_train, artifacts_dir, group_cols, time_col_used, x_col_used, y_col_used, subs_col, gwl_target_col, gwl_dyn_col, h_field_col, static_cols, dynamic_cols, future_cols, time_steps, horizon, mode, model_name, train_end, keys_ok, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8), normalize_coords=True)[source]

Build train/val/test windows using group holdout.

Returns dict containing paths and coord_scaler.

Return type:

dict[str, Any]

geoprior.utils.split.build_future_inputs_npz(*, df_scaled, artifacts_dir, time_col, time_col_num, lon_col, lat_col, subs_col, gwl_col, h_field_col, static_features, dynamic_features, future_features, group_cols, train_end_time, forecast_start_time, horizon, time_steps, mode, model_name, normalize_coords, coord_scaler=None)[source]
Return type:

str

Provides a comprehensive set of functions and warnings for validating and ensuring the integrity of data. This includes utilities for checking data consistency, validating machine learning targets, ensuring proper data types, and handling various validation scenarios.

exception geoprior.utils.validator.DataConversionWarning[source]

Bases: UserWarning

Warning used to notify implicit data conversions happening in the code.

This warning occurs when some input data needs to be converted or interpreted in a way that may not match the user’s expectations. For example, this warning may occur when the user:

  • passes an integer array to a function that expects float input and will convert the input;

  • requests a non-copying operation, but a copy is required to meet the implementation’s data-type expectations;

  • passes an input whose shape can be interpreted ambiguously.

Changed in version 0.18: Moved from sklearn.utils.validation.
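The first situation listed above (an integer array passed where float input is expected) can be sketched as follows. The class is a local stand-in with the same name and base class as the documented warning, and as_float_array is a hypothetical helper, not a GeoPrior function.

```python
import warnings

import numpy as np


class DataConversionWarning(UserWarning):
    """Local stand-in mirroring the warning documented above."""


def as_float_array(X) -> np.ndarray:
    """Hypothetical: convert to float64, warning on implicit conversions."""
    arr = np.asarray(X)
    if arr.dtype.kind in "biu":  # bool, signed int, unsigned int
        warnings.warn(f"Converted {arr.dtype} input to float64.",
                      DataConversionWarning)
    return arr.astype(np.float64)
```

Because the warning subclasses UserWarning, callers can silence or escalate it with the standard warnings filters.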

exception geoprior.utils.validator.PositiveSpectrumWarning[source]

Bases: UserWarning

Warning raised when the eigenvalues of a PSD matrix have issues.

This warning is typically raised by _check_psd_eigenvalues when the eigenvalues of a positive semidefinite (PSD) matrix such as a gram matrix (kernel) present significant negative eigenvalues, or bad conditioning i.e. very small non-zero eigenvalues compared to the largest eigenvalue.

Added in version 0.22.
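The two conditions described above (significant negative eigenvalues and very small non-zero eigenvalues relative to the largest one) can be sketched with NumPy. The class is a local stand-in, check_psd_spectrum is hypothetical, and the relative tolerances neg_tol and small_tol are illustrative assumptions, not the thresholds used by _check_psd_eigenvalues.

```python
import warnings

import numpy as np


class PositiveSpectrumWarning(UserWarning):
    """Local stand-in mirroring the warning documented above."""


def check_psd_spectrum(K, neg_tol: float = 1e-5,
                       small_tol: float = 1e-12) -> np.ndarray:
    """Hypothetical sketch: warn about a badly behaved PSD spectrum."""
    # eigvalsh assumes a symmetric (Hermitian) matrix, e.g. a kernel matrix.
    eigvals = np.linalg.eigvalsh(np.asarray(K, dtype=float))
    largest = np.abs(eigvals).max()
    if eigvals.min() < -neg_tol * largest:
        warnings.warn("Significant negative eigenvalues found.",
                      PositiveSpectrumWarning)
    positive = eigvals[eigvals > 0]
    if positive.size and positive.min() < small_tol * largest:
        warnings.warn("Badly conditioned: tiny non-zero eigenvalues.",
                      PositiveSpectrumWarning)
    return eigvals
```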

geoprior.utils.validator.array_to_frame(X, *, to_frame=False, columns=None, raise_exception=False, raise_warning=True, input_name='', force=False)[source]

Validates and optionally converts an array-like object to a pandas DataFrame, applying specified column names if provided or generating them if the force parameter is set.

Parameters:
  • X (array-like) – The array to potentially convert to a DataFrame.

  • columns (str or list of str, optional) – The names for the resulting DataFrame columns or the Series name.

  • to_frame (bool, default False) – If True, converts X to a DataFrame if it isn’t already one.

  • input_name (str, default '') – The name of the input variable, used for error and warning messages.

  • raise_warning (bool, default True) – If True and to_frame is True but columns are not provided, a warning is issued unless force is True.

  • raise_exception (bool, default False) – If True, raises an exception when to_frame is True but columns are not provided and force is False.

  • force (bool, default False) – Forces the conversion of X to a DataFrame by generating column names based on input_name if columns are not provided.

Returns:

The potentially converted DataFrame or Series, or X unchanged.

Return type:

pd.DataFrame or pd.Series

Examples

>>> from geoprior.utils.validator import array_to_frame
>>> from sklearn.datasets import load_iris
>>> data = load_iris()
>>> X = data.data
>>> array_to_frame(X, to_frame=True, columns=['sepal_length', 'sepal_width',
...                                           'petal_length', 'petal_width'])
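
The interplay of columns, force, and raise_warning described above can be sketched as follows, assuming to_frame=True. The helper to_frame_sketch is hypothetical, and the generated column names (input_name followed by a column index) are an assumed naming scheme; the documentation only states that names are derived from input_name.

```python
import warnings
from typing import Optional, Sequence, Union

import numpy as np
import pandas as pd


def to_frame_sketch(X, columns: Optional[Union[str, Sequence[str]]] = None,
                    input_name: str = "X", force: bool = False,
                    raise_warning: bool = True):
    """Hypothetical sketch of the to_frame / columns / force interplay."""
    arr = np.asarray(X)
    if columns is None:
        if not force:
            if raise_warning:
                warnings.warn(f"No columns given for {input_name!r}; "
                              "returning the input unchanged.")
            return X
        # Assumed naming scheme: input_name plus a column index.
        columns = (input_name if arr.ndim == 1
                   else [f"{input_name}_{i}" for i in range(arr.shape[1])])
    if arr.ndim == 1:
        name = columns if isinstance(columns, str) else columns[0]
        return pd.Series(arr, name=name)
    return pd.DataFrame(arr, columns=list(columns))
```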
geoprior.utils.validator.array_to_frame2(X, *, to_frame=False, columns=None, raise_exception=False, raise_warning=True, input_name='', force=False)[source]

Part of is_frame, split out to validate the reconversion of X and y into frames.

Parameters:
  • X (Array-like) – Array to convert to frame.

  • columns (str or list of str) – Series name or column names for pandas.Series and DataFrame.

  • to_frame (bool, default False) – If True, reconvert the array to a frame using columns; otherwise no action is performed and the same array is returned.

  • input_name (str, default "") – The data name used to construct the error message.

  • raise_warning (bool, default True) – If True, raise a warning if conversion is required. If set to 'ignore', warnings are silenced.

  • raise_exception (bool, default False) – If True, raise an exception if the array is not symmetric.

  • force (bool, default False) – Force conversion of the array to a frame if columns is not supplied, generating column names from input_name combined with the range of X.shape[1].

Returns:

X

Return type:

converted array

Example

>>> from geoprior.datasets import fetch_data
>>> from geoprior.utils.validator import array_to_frame2
>>> data = fetch_data('hlogs').frame
>>> array_to_frame2(data.k.values, to_frame=True, columns=None,
...                 input_name='y', raise_warning="silence")
array([nan, nan, nan, ..., nan, nan, nan])

geoprior.utils.validator.assert_all_finite(X, *, allow_nan=False, estimator_name=None, input_name='')[source]

Throw a ValueError if X contains NaN or infinity.

Parameters:
  • X ({ndarray, sparse matrix}) – The input data.

  • allow_nan (bool, default False) – If True, do not throw error when X contains NaN.

  • estimator_name (str, default None) – The estimator name, used to construct the error message.

  • input_name (str, default "") – The data name used to construct the error message. In particular if input_name is “X” and the data has NaN values and allow_nan is False, the error message will link to the imputer documentation.
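The check described above can be sketched for dense numeric input with NumPy; sparse matrices are out of scope here. The helper check_finite is hypothetical and mirrors only the allow_nan and input_name behaviour: infinity is always rejected, NaN only when allow_nan is False.

```python
import numpy as np


def check_finite(X, allow_nan: bool = False, input_name: str = "") -> np.ndarray:
    """Hypothetical sketch: reject infinity always, NaN unless allow_nan."""
    arr = np.asarray(X, dtype=float)
    label = f"Input {input_name}" if input_name else "Input"
    if np.isinf(arr).any():
        raise ValueError(f"{label} contains infinity.")
    if not allow_nan and np.isnan(arr).any():
        raise ValueError(f"{label} contains NaN.")
    return arr
```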

geoprior.utils.validator.assert_xy_in(x, y, *, data=None, asarray=True, to_frame=False, columns=None, xy_numeric=False, ignore=None, **kws)[source]

Assert the name of x and y in the given data.

Check whether string arguments passed to x and y are valid in the data, then retrieve the x and y array values.

Parameters:
  • x (array-like 1d or str) – One-dimensional array. In principle, if data is supplied, x should be a Series. If x and y are given as strings, data must be supplied and both names must be columns of the DataFrame; otherwise an error is raised.

  • y (array-like 1d or str) – One-dimensional array, following the same conventions as x.

  • data (pd.DataFrame) – Data containing the x and y columns. Must be supplied when x and y are given as column names.

  • asarray (bool, default True) – Return x and y as arrays rather than Series.

  • to_frame (bool, default False) – Convert data to a DataFrame using either columns or input_name when the keyword parameter force=True is passed.

  • columns (list of str, optional) – Column names used to convert the array (data) to a DataFrame.

  • xy_numeric (bool, default False) – Convert x and y to numeric values.

  • ignore (str, optional) – Either ‘x’ or ‘y’. If set, the corresponding array is ignored and not asserted.

  • kws (dict) – Keyword arguments passed to array_to_frame().

Returns:

x, y – One dimensional array or pd.Series

Return type:

Arraylike

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import assert_xy_in
>>> np.random.seed(42)
>>> x, y = np.random.rand(7), np.arange(7)
>>> data = pd.DataFrame({'x': x, 'y': y})
>>> assert_xy_in(x='x', y='y', data=data)
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=y)
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=data.y)  # y is a Series
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=data.y, asarray=False)  # return y as given
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
0    0
1    1
2    2
3    3
4    4
5    5
6    6
Name: y, dtype: int32)
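
The name resolution that assert_xy_in performs can be sketched as a small helper. resolve_xy is hypothetical and covers only the lookup by column name, the asarray switch, and a length check; to_frame, xy_numeric, and ignore are left out.

```python
import numpy as np
import pandas as pd


def resolve_xy(x, y, data=None, asarray=True):
    """Hypothetical sketch: look up x and y by name, then check lengths."""
    def _resolve(value, label):
        if isinstance(value, str):
            if data is None:
                raise TypeError(f"{label}={value!r} is a name, so data must be supplied.")
            if value not in data.columns:
                raise KeyError(f"Column {value!r} not found in data.")
            value = data[value]
        return np.asarray(value) if asarray else value

    xs, ys = _resolve(x, "x"), _resolve(y, "y")
    if len(xs) != len(ys):
        raise ValueError("x and y must have the same length.")
    return xs, ys
```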
geoprior.utils.validator.build_data_if(data, columns=None, to_frame=True, input_name='data', col_prefix='col_', force=False, error='warn', coerce_datetime=False, coerce_numeric=True, start_incr_at=0, **kw)[source]

Validates and converts data into a pandas DataFrame if requested, optionally enforcing consistent column naming. Intended to standardize data structures for downstream analysis.

See more in geoprior.utils.data_utils.build_df() for documentation details.

geoprior.utils.validator.check_X_y(X, y, accept_sparse=False, *, accept_large_sparse=True, dtype='numeric', order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, multi_output=False, ensure_min_samples=1, ensure_min_features=1, y_numeric=False, estimator=None, to_frame=False)[source]

Input validation for standard estimators.

Checks X and y for consistent length, enforces X to be 2D and y 1D. By default, X is checked to be non-empty and containing only finite values. Standard input checks are also applied to y, such as checking that y does not have np.nan or np.inf targets. For multi-label y, set multi_output=True to allow 2D and sparse y. If the dtype of X is object, attempt converting to float, raising on failure.

Parameters:
  • X ({ndarray, list, sparse matrix}) – Input data.

  • y ({ndarray, list, sparse matrix}) – Labels.

  • accept_sparse (str, bool or list of str, default False) – String[s] representing allowed sparse matrix formats, such as ‘csc’, ‘csr’, etc. If the input is sparse but not in the allowed format, it will be converted to the first listed format. True allows the input to be any format. False means that a sparse matrix input will raise an error.

  • accept_large_sparse (bool, default True) – If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by accept_sparse, accept_large_sparse=False will cause it to be accepted only if its indices are stored with a 32-bit dtype.

  • dtype ('numeric', type, list of type or None, default 'numeric') – Data type of result. If None, the dtype of the input is preserved. If “numeric”, dtype is preserved unless array.dtype is object. If dtype is a list of types, conversion on the first type is only performed if the dtype of the input is not in the list.

  • order ({'F', 'C'}, default None) – Whether an array will be forced to be Fortran- or C-style.

  • copy (bool, default False) – Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion.

  • force_all_finite (bool or 'allow-nan', default True) – Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter does not influence whether y can have np.inf, np.nan, pd.NA values. Use True to require all values of X to be finite, False to allow np.inf, np.nan, and pd.NA, or "allow-nan" to allow only np.nan and pd.NA while still rejecting infinite values. pd.NA is accepted and converted into np.nan.

  • ensure_2d (bool, default True) – Whether to raise a value error if X is not 2D.

  • allow_nd (bool, default False) – Whether to allow X.ndim > 2.

  • multi_output (bool, default False) – Whether to allow 2D y (array or sparse matrix). If false, y will be validated as a vector. y cannot have np.nan or np.inf values if multi_output=True.

  • ensure_min_samples (int, default 1) – Make sure that X has a minimum number of samples in its first axis (rows for a 2D array).

  • ensure_min_features (int, default 1) – Make sure that the 2D array has some minimum number of features (columns). The default value of 1 rejects empty datasets. This check is only enforced when X has effectively 2 dimensions or is originally 1D and ensure_2d is True. Setting to 0 disables this check.

  • y_numeric (bool, default False) – Whether to ensure that y has a numeric type. If dtype of y is object, it is converted to float64. Should only be used for regression algorithms.

  • estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.

Returns:

  • X_converted (object) – The converted and validated X.

  • y_converted (object) – The converted and validated y.
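The default rules listed above can be condensed into a short NumPy sketch. It covers dense inputs only, ignores sparse matrices, pd.NA, and the optional flags, and the name check_X_y_sketch is hypothetical; it illustrates the documented checks and is not GeoPrior's implementation.

```python
import numpy as np


def check_X_y_sketch(X, y):
    """Illustrative sketch of check_X_y's default rules for dense input."""
    X = np.asarray(X)
    y = np.asarray(y)
    # Object-dtype X is converted to float; astype raises on failure.
    if X.dtype == object:
        X = X.astype(np.float64)
    # ensure_2d=True and multi_output=False: X is 2D, y is a vector.
    if X.ndim != 2:
        raise ValueError(f"Expected 2D X, got {X.ndim}D input.")
    if y.ndim != 1:
        raise ValueError(f"Expected 1D y, got {y.ndim}D input.")
    # ensure_min_samples=1 and ensure_min_features=1.
    if X.shape[0] < 1 or X.shape[1] < 1:
        raise ValueError(f"X has invalid shape {X.shape}.")
    # X and y must share their first dimension.
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            "Inconsistent numbers of samples: "
            f"{X.shape[0]} (X) vs {y.shape[0]} (y)."
        )
    # force_all_finite=True: no NaN or inf in X.
    if not np.isfinite(X).all():
        raise ValueError("X contains NaN or infinity.")
    # Numeric y is also rejected if it holds NaN or inf.
    if np.issubdtype(y.dtype, np.number) and not np.isfinite(y).all():
        raise ValueError("y contains NaN or infinity.")
    return X, y
```

For example, check_X_y_sketch([[1, 2], [3, 4]], [0, 1]) passes, while a single-row X paired with a two-element y raises a ValueError.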

geoprior.utils.validator.check_array(array, *, accept_large_sparse=True, dtype='numeric', accept_sparse=False, order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, ensure_min_samples=1, ensure_min_features=1, estimator=None, input_name='', to_frame=True)[source]

Input validation on an array, list, or similar.

By default, the input is checked to be a non-empty 2D array containing only finite values. If the dtype of the array is object, attempt converting to float, raising on failure.

Parameters:
  • array (object) – Input object to check / convert.

  • accept_sparse (str, bool or list/tuple of str, default False) – String[s] representing allowed sparse matrix formats, such as ‘csc’, ‘csr’, etc. If the input is sparse but not in the allowed format, it will be converted to the first listed format. True allows the input to be any format. False means that a sparse matrix input will raise an error.

  • accept_large_sparse (bool, default True) – If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by accept_sparse, accept_large_sparse=False will cause it to be accepted only if its indices are stored with a 32-bit dtype.

  • dtype ('numeric', type, list of type or None, default 'numeric') – Data type of result. If None, the dtype of the input is preserved. If “numeric”, dtype is preserved unless array.dtype is object. If dtype is a list of types, conversion on the first type is only performed if the dtype of the input is not in the list.

  • order ({'F', 'C'} or None, default None) – Whether an array will be forced to be Fortran- or C-style. When order is None (default), then if copy=False, nothing is ensured about the memory layout of the output array; otherwise (copy=True) the memory layout of the returned array is kept as close as possible to the original array.

  • copy (bool, default False) – Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion.

  • force_all_finite (bool or 'allow-nan', default True) – Whether to raise an error on np.inf, np.nan, or pd.NA in array. Use True to require all values to be finite, False to allow np.inf, np.nan, and pd.NA, or "allow-nan" to allow only np.nan and pd.NA while still rejecting infinite values. pd.NA is converted into np.nan.

  • ensure_2d (bool, default True) – Whether to raise a value error if array is not 2D.

  • ensure_min_samples (int, default 1) – Make sure that the array has a minimum number of samples in its first axis (rows for a 2D array). Setting to 0 disables this check.

  • ensure_min_features (int, default 1) – Make sure that the 2D array has some minimum number of features (columns). The default value of 1 rejects empty datasets. This check is only enforced when the input data has effectively 2 dimensions or is originally 1D and ensure_2d is True. Setting to 0 disables this check.

  • estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.

  • input_name (str, default "") – The data name used to construct the error message. In particular, if input_name is “X” and the data has NaN values that are not allowed (see force_all_finite), the error message will link to the imputer documentation.

  • to_frame (bool, default True) – Reconvert the array back to pd.Series or pd.DataFrame if the original input was a pd.Series or pd.DataFrame.

Returns:

array_converted – The converted and validated array.

Return type:

object
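The three force_all_finite modes are the part of check_array that is easiest to get wrong. The sketch below shows how they differ on plain float data. It is a hypothetical helper (check_finite_sketch) that leaves out pd.NA conversion and every other check_array option.

```python
import numpy as np


def check_finite_sketch(array, force_all_finite=True):
    """Sketch of the three force_all_finite modes described above."""
    arr = np.asarray(array, dtype=np.float64)
    if force_all_finite is False:
        # False: NaN and +/-inf are both accepted.
        return arr
    if force_all_finite == "allow-nan":
        # "allow-nan": NaN passes, but infinities are still rejected.
        if np.isinf(arr).any():
            raise ValueError("Input contains infinity.")
        return arr
    if force_all_finite is True:
        # True: every value must be finite.
        if not np.isfinite(arr).all():
            raise ValueError("Input contains NaN or infinity.")
        return arr
    raise ValueError(
        "force_all_finite should be a bool or 'allow-nan', "
        f"got {force_all_finite!r}."
    )
```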

geoprior.utils.validator.check_classification_targets(*y, target_type='numeric', strategy='auto', verbose=False)[source]

Validate that the target arrays are suitable for classification tasks.

This function is designed to ensure that target arrays (y) contain only finite, categorical values, and it raises a ValueError if the targets do not meet the criteria necessary for classification tasks, such as the presence of continuous values, NaNs, or infinite values.

This validation is crucial for preprocessing steps in machine learning pipelines to ensure that the data is appropriate for classification algorithms.

Parameters:
  • *y (array-like) – One or more target arrays to be validated. The input can be in the form of lists, numpy arrays, or pandas series. Each array is checked individually to ensure it meets the criteria for classification targets.

  • target_type (str, optional) – The expected data type of the target arrays. Supported values are ‘numeric’ and ‘object’. If ‘numeric’, the function attempts to convert the target arrays to integers, raising an error if conversion is not possible due to non-numeric values. If ‘object’, the target arrays are left as numpy arrays of dtype object, suitable for categorical classification without conversion. Default is ‘numeric’.

  • strategy (str, optional) –

    Defines the approach for evaluating if the target arrays are suitable for classification based on their unique values and data types. The ‘auto’ strategy uses heuristic or automatic detection to decide whether target data should be treated as categorical, which is useful for most cases. Custom strategies can be defined to enforce specific validation rules or preprocessing steps based on the nature of the target data (e.g., ‘continuous’, ‘multilabel-indicator’, ‘unknown’). These custom strategies should align with the outcomes of a predefined type_of_target function, allowing for nuanced handling of different target data scenarios. The default value is 'auto', which applies general rules for categorization and numeric conversion where applicable.

    If a strategy other than 'auto' is specified, it directly influences how the data is validated and potentially converted, based on the expected or detected type of target data:

    • If ‘continuous’, the function checks if the data can be used for regression tasks and raises an error for classification use without explicit binning.

    • If ‘multilabel-indicator’, it validates the data for multilabel classification tasks and ensures appropriate format.

    • If ‘unknown’, it attempts to validate the data with generic checks, raising errors for any unclear or unsupported data formats.

  • verbose (bool, optional) – If set to True, the function prints a message for each target array checked, confirming that it is suitable for classification. This is helpful for debugging and when validating multiple target arrays simultaneously.

Raises:

ValueError – If any of the target arrays contain values unsuitable for classification. This includes arrays with continuous values, NaNs, infinite values, or arrays that do not represent categorical data properly.

Examples

Using the function with a single array of integer labels:

>>> from geoprior.utils.validator import check_classification_targets
>>> y = [1, 2, 3, 2, 1]
>>> check_classification_targets(y)
[array([1, 2, 3, 2, 1], dtype=object)]

Using the function with multiple arrays, including a mix of integer and string labels:

>>> y1 = [0, 1, 0, 1]
>>> y2 = ["spam", "ham", "spam", "ham"]
>>> check_classification_targets(y1, y2, verbose=True)
Targets are suitable for classification.
Targets are suitable for classification.
[array([0, 1, 0, 1], dtype=object), array(['spam', 'ham', 'spam', 'ham'], dtype=object)]

Attempting to use the function with an array containing NaN values:

>>> import numpy as np
>>> y_with_nan = [1, np.nan, 2, 1]
>>> check_classification_targets(y_with_nan)
ValueError: Target values contain NaN or infinite numbers, which are not
suitable for classification.

Attempting to use the function with a continuous target array:

>>> y_continuous = np.linspace(0, 1, 10)
>>> check_classification_targets(y_continuous)
ValueError: The number of unique values is too high for a classification task.

Validating and converting a mixed-type target array to numeric:

>>> y_mixed = [1, '2', 3.0, '4', 5]
>>> check_classification_targets(y_mixed, target_type='numeric')
ValueError: Target array at index 0 contains non-numeric values, which
cannot be converted to integers: ['2', '4']...

Validating object target arrays without attempting conversion:

>>> y_str = ["apple", "banana", "cherry"]
>>> check_classification_targets(y_str, target_type='object')
[array(['apple', 'banana', 'cherry'], dtype=object)]

geoprior.utils.validator.check_consistency_size(*arrays)[source]

Check that the given arrays are consistent in size and raise an error otherwise.

geoprior.utils.validator.check_consistent_length(*arrays)[source]

Check that all arrays have consistent first dimensions.

Checks whether all objects in arrays have the same shape or length.

Parameters:

*arrays (list or tuple of input objects.) – Objects that will be checked for consistent length.
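A minimal sketch of the consistent-length rule follows. The helper name is hypothetical, and skipping None inputs is an assumption borrowed from scikit-learn's helper of the same name rather than something this page states.

```python
def check_consistent_length_sketch(*arrays):
    """Illustrative sketch: all non-None inputs must share len()."""
    # Skipping None inputs is an assumption borrowed from scikit-learn.
    lengths = [len(a) for a in arrays if a is not None]
    if len(set(lengths)) > 1:
        raise ValueError(
            "Found input variables with inconsistent numbers of samples: "
            f"{lengths}"
        )
```

Because len() of a 2D NumPy array returns its number of rows, the same test compares first dimensions for both lists and arrays.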

geoprior.utils.validator.check_donut_inputs(values=None, data=None, labels=None, ops='check', labels_as_index=True, index=None, origin_index='drop', value_name='auto')[source]

Validate and/or build inputs for donut chart plotting.

This function accepts inputs in various forms and either returns a pair of numeric values and labels or builds a new \(n \times 1\) DataFrame from them. The function supports two modes:

  • In ops="check", it returns a tuple (values, labels) after validating that the numeric values are appropriate for plotting.

  • In ops="build", it returns a pandas DataFrame constructed from the inputs. If labels_as_index is True, the labels become the DataFrame index; otherwise, they form a separate column. If an index is provided, it is used to reset the DataFrame index and the original index is either dropped or kept based on origin_index.

The function also accepts inputs through a DataFrame or Series (data). In such cases, if values is a str, it is interpreted as a column name of the DataFrame. Similarly, if labels is a str, it is used to fetch the label column.

\[S = \{ x_i \}_{i=1}^{n} \quad \text{and} \quad L = \{ l_i \}_{i=1}^{n} \tag{39}\]

where \(S\) denotes the numeric values and \(L\) denotes the corresponding labels.

Parameters:
  • values (array-like or str, optional) – Numeric values for the donut slices. If data is a DataFrame and values is a string (e.g. "colname"), the column "colname" is used. If data is a Series and values is not provided, the Series values are used.

  • data (pandas.Series or pandas.DataFrame, optional) – Data source from which to fetch values and labels. If provided, the function extracts the corresponding numeric data. For a DataFrame, if values (or labels) is a string (e.g. "colname"), the function fetches the column named "colname".

  • labels (array-like or str, optional) – Labels for the donut slices. If data is provided and labels is a string (e.g. "colname"), the function uses the specified column as labels. If omitted, the function uses the index of the DataFrame or Series.

  • ops ({"check", "build"}, optional) – Operation mode of the function. In "check" mode, the function returns a tuple (values, labels) after validation. In "build" mode, it returns a new DataFrame built from the inputs. The default is "check".

  • labels_as_index (bool, optional) – If ops="build", this flag determines whether the labels are used as the DataFrame index. If True, the labels become the index; if False, they form a separate column. The default is True.

  • index (array-like or str, optional) – New index to assign in "build" mode. If a string is provided, it must correspond to a column in the DataFrame, and that column is used as the new index. If a list is provided, it directly replaces the DataFrame index. To retain the original index, see origin_index.

  • origin_index ({"drop", "keep"}, optional) – Specifies whether to drop or retain the original index when resetting the DataFrame index. If set to "keep", the original index is saved in a new column named origin_index. The default is "drop".

  • value_name ("auto" or str, optional) – Name to use for the numeric values in the built DataFrame (when ops="build"). If set to "auto" (or None), the default name "Value" is used unless overridden by the source data. Otherwise, the provided string (e.g. "Total") is used as the column name.

Returns:

  • If ops="check", returns a tuple (values, labels) where values is a NumPy array of numeric values and labels is a list of labels.

  • If ops="build", returns a pandas DataFrame constructed from the inputs. If labels_as_index is True, the DataFrame index is set to the provided labels (or the new index if index is specified). Otherwise, the DataFrame contains separate columns for the labels and numeric values.

Return type:

tuple of (ndarray, list) or pandas.DataFrame

Examples

Build inputs from a DataFrame with explicit column names:

>>> from geoprior.utils.validator import check_donut_inputs
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "Sales": [100, 200, 150],
...     "Country": ["USA", "Canada", "Mexico"]
... })
>>> # Build a DataFrame using "Sales" as values and "Country" as index
>>> new_df = check_donut_inputs(
...     values="Sales",
...     data=df,
...     labels="Country",
...     ops="build",
...     labels_as_index=True,
...     index="Country",
...     origin_index="drop"
... )
>>> new_df
        Sales
USA      100
Canada   200
Mexico   150

Check inputs when only numeric values are provided:

>>> values, labs = check_donut_inputs(
...     values=[10, 20, 30],
...     labels=["A", "B", "C"],
...     ops="check"
... )
>>> values
array([10., 20., 30.])
>>> labs
['A', 'B', 'C']

Notes

The function internally calls the inline helper check_numeric_dtype to ensure that the provided numeric data satisfies the necessary type constraints. The function supports grouping or multiple donut charts by using the input DataFrame directly. See also check_numeric_dtype() for numeric type validation.

geoprior.utils.validator.check_epsilon(eps, y_true=None, y_pred=None, base_epsilon=1e-10, scale_factor=1e-05)[source]

Dynamically determine or validate an epsilon value for numerical computations.

This function either validates a provided epsilon if it is a numeric value, or calculates an appropriate epsilon dynamically based on the input data. The dynamic calculation aims to adjust epsilon based on the scale of the input data, providing flexibility and adaptability in algorithms where numerical stability is critical.

Parameters:
  • eps ({'auto', float}) – The epsilon value to use. If ‘auto’, the function dynamically determines an appropriate epsilon based on y_true and y_pred. If a float, it validates this as the epsilon value.

  • y_true (array-like, optional) – True values array. Used in conjunction with y_pred to dynamically determine epsilon if eps is ‘auto’. If None, this input is ignored.

  • y_pred (array-like, optional) – Predicted values array. Used alongside y_true for epsilon determination. If None, this input is ignored.

  • base_epsilon (float, optional) – Base epsilon value used as a starting point in dynamic determination. This value is adjusted based on the scale_factor and the input data to compute the final epsilon.

  • scale_factor (float, optional) – Scaling factor applied to adjust the base epsilon in relation to the scale of the input data. Helps tailor the epsilon to the problem’s numerical scale.

Returns:

The determined or validated epsilon value. Ensures numerical operations are conducted with an appropriate epsilon to avoid division by zero or other numerical instabilities.

Return type:

float

Examples

>>> from geoprior.utils.validator import check_epsilon
>>> y_true = [1, 2, 3]
>>> y_pred = [1.1, 1.9, 3.05]
>>> # Illustrative output: the actual value depends on the determine_epsilon implementation.
>>> check_epsilon('auto', y_true, y_pred)
1e-05
>>> check_epsilon(1e-8)
1e-08

Notes

Using ‘auto’ for eps allows algorithms to adapt to different scales of data, enhancing numerical stability without manually tuning the epsilon value.
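This page does not give the formula behind eps='auto'. The sketch below assumes one plausible scale-aware rule, eps = max(base_epsilon, scale_factor × mean(|y|)), where y pools y_true and y_pred. Both the rule and the name check_epsilon_sketch are hypothetical and are not GeoPrior's determine_epsilon.

```python
import numpy as np


def check_epsilon_sketch(eps, y_true=None, y_pred=None,
                         base_epsilon=1e-10, scale_factor=1e-5):
    """Hypothetical scale-aware epsilon; NOT geoprior's actual rule."""
    if eps != "auto":
        # A user-supplied epsilon must be a positive number.
        eps = float(eps)
        if eps <= 0:
            raise ValueError("eps must be positive.")
        return eps
    # Pool whichever arrays were provided.
    parts = [np.ravel(np.asarray(a, dtype=float))
             for a in (y_true, y_pred) if a is not None]
    if not parts:
        return base_epsilon
    scale = np.mean(np.abs(np.concatenate(parts)))
    # Assumed rule: epsilon grows with the data scale, never below base.
    return float(max(base_epsilon, scale_factor * scale))
```

Under this assumed rule, larger-magnitude data yields a larger epsilon, while an all-zero or missing input falls back to base_epsilon.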

geoprior.utils.validator.check_has_run_method(estimator, msg=None, method_name='run')[source]

Check if the given estimator has a callable run method or any other specified method. This utility helps validate that an object can execute the expected method before further actions are taken.

Parameters:
  • estimator (object) – The object (instance or class) to check for the presence of the run method or another specified method.

  • msg (str, optional) – Custom error message to display if the method is missing. If None, a default message is generated based on the method_name.

  • method_name (str, default "run") – The method name to check for. This defaults to run, but you can specify any method name. The method must be callable.

Raises:

AttributeError – Raised if the run method (or any specified method) does not exist on the object or is not callable.

Examples

>>> from geoprior.utils.validator import check_has_run_method
>>> class MyClass:
...     def run(self):
...         pass
>>> check_has_run_method(MyClass())  # No error
>>> class MyClassWithoutRun:
...     pass
>>> check_has_run_method(MyClassWithoutRun())  # Raises AttributeError

Notes

This function performs several checks:

  1. Existence check: It checks whether the run method (or any other specified method) exists in the estimator object.

  2. Callable check: It ensures that the method is callable, which rules out attributes that might exist but aren’t methods.

  3. Static/class method check: The function accepts static or class methods as valid callable methods.

  4. Bound method check: It verifies that instance methods are bound to an object when required, which ensures they can be called properly in the given context.

This function can be expressed as a validation function:

\[\text{check\_has\_run\_method}(\text{estimator}, \text{method\_name}) = \begin{cases} \text{valid}, & \text{if the method exists and is callable} \\ \text{invalid}, & \text{if the method is missing or not callable} \end{cases} \tag{40}\]

In other words, the function returns normally when the method exists and is callable, and raises an AttributeError otherwise. Callable-method validation here follows the Python documentation and the staticmethod overview in [43, 44].

See also

validate_estimator_methods

A helper function to validate multiple methods on an estimator.
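The existence and callable checks (steps 1 to 3 in the notes) can be sketched with getattr and callable. The bound-method check in step 4 is omitted, and check_has_method_sketch is a hypothetical name, not the library's code.

```python
def check_has_method_sketch(estimator, msg=None, method_name="run"):
    """Illustrative sketch of the existence and callable checks."""
    # getattr resolves instance, static and class methods alike, on
    # either an instance or a class; a missing attribute becomes None.
    method = getattr(estimator, method_name, None)
    # callable(None) is False, so missing and non-callable attributes
    # are rejected by the same test.
    if not callable(method):
        owner = estimator if isinstance(estimator, type) else type(estimator)
        raise AttributeError(
            msg or f"{owner.__name__} has no callable {method_name!r} method."
        )
```

An attribute such as run = 42 exists but is not callable, so it is rejected just like a missing method.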

geoprior.utils.validator.check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)[source]

Perform is_fitted validation for estimator.

Checks if the estimator is fitted by verifying the presence of fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message.

If an estimator does not set any attributes with a trailing underscore, it can define a __sklearn_is_fitted__ or __fusionlab_is_fitted__ method returning a boolean to specify if the estimator is fitted or not.

Parameters:
  • estimator (estimator instance) – Estimator instance for which the check is performed.

  • attributes (str, list or tuple of str, default None) –

    Attribute name(s) given as a string or a list/tuple of strings, e.g. ["coef_", "estimator_", ...] or "coef_".

    If None, the estimator is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore.

  • msg (str, default None) –

    The default error message is, “This %(name)s instance is not fitted yet. Call ‘fit’ with appropriate arguments before using this estimator.”

    For custom messages, if “%(name)s” is present in the message string, it is replaced with the estimator name.

    For example: “Estimator, %(name)s, must be fitted before sparsifying”.

  • all_or_any (callable, {all, any}, default all) – Specify whether all or any of the given attributes must exist.

Raises:
  • TypeError – If the estimator is a class or not an estimator instance.

  • NotFittedError – If the attributes are not found.
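The trailing-underscore heuristic can be sketched as a predicate that returns a bool instead of raising. The name looks_fitted_sketch is hypothetical, and the __sklearn_is_fitted__ / __fusionlab_is_fitted__ hooks are not modeled.

```python
def looks_fitted_sketch(estimator, attributes=None, all_or_any=all):
    """Sketch of the fitted-attribute heuristic; returns a bool."""
    if attributes is not None:
        if isinstance(attributes, str):
            attributes = [attributes]
        # all: every listed attribute must exist; any: one is enough.
        return all_or_any(hasattr(estimator, a) for a in attributes)
    # Default rule: some instance attribute ends with "_" and does not
    # start with "__".
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )
```

A caller would then raise NotFittedError whenever the predicate returns False.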

geoprior.utils.validator.check_is_fitted2(estimator, attributes, *, msg=None)[source]

Perform is_fitted validation for estimator.

Checks if the estimator is fitted by looking for attributes set during fitting. Typically, these attributes end with an underscore (‘_’).

Parameters:
  • estimator (BaseEstimator) – An instance of a scikit-learn estimator.

  • attributes (str or list of str) – The attributes to check for. These are typically set in the ‘fit’ method.

  • msg (str, optional) – The message to raise in the NotFittedError. If not provided, a default message is used.

Raises:

NotFittedError – If the given attributes are not found in the estimator.

Examples

>>> from sklearn.ensemble import RandomForestClassifier
>>> from geoprior.utils.validator import check_is_fitted2
>>> clf = RandomForestClassifier()
>>> check_is_fitted2(clf, ['feature_importances_'])
NotFittedError: This RandomForestClassifier instance is not fitted yet.

geoprior.utils.validator.check_is_runned(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)[source]

Validate if an estimator instance has been “runned” (executed) prior to invoking dependent methods. This check ensures that the estimator is in the appropriate operational state, allowing users to identify and address runtime issues effectively.

If an estimator does not set “runned” attributes (such as _is_runned), it may define a __gofast_is_runned__ method. This method should return a boolean indicating whether the estimator is “runned” or not.

Parameters:
  • estimator (object) –

    The estimator instance or class to validate, that is, the object whose dependent methods require the “runned” state to have been reached.

    To determine the “runned” status, the function checks for specific attributes or, if defined, the __gofast_is_runned__ method.

  • attributes (str, list, or tuple of str, optional, default None) –

    Specifies the name(s) of attributes that indicate the “runned” status, such as ['_is_runned'] or ['_is_fitted']. If these attributes are present and set to True, the estimator is considered to have been runned.

    If attributes is set to None, the function will default to checking for _is_runned. This default provides flexibility for estimators that employ standard runned flags.

  • msg (str, optional, default None) –

    Custom error message to be displayed if the validation fails. By default, this error message uses the class name of the estimator in the format:

    “This %(name)s instance has not been ‘runned’ yet. Call ‘run’ with appropriate arguments before using this method.”

    To customize the message, include %(name)s as a placeholder for the estimator’s class name.

  • all_or_any (callable, {all, any}, optional, default all) – Determines whether all or any of the specified attributes must be present and set to True. By default, the function expects all attributes to be set to True. Set to any for greater flexibility with multiple attributes.

__gofast_is_runned__ (callable, optional) – If defined within the estimator, this method should return a boolean indicating the “runned” status of the estimator. This provides an alternative to using attributes.

Raises:

RuntimeError – If none of the specified attributes are set to True or if the __gofast_is_runned__ method (if present) returns False.

Notes

The check_is_runned function ensures that methods dependent on the “runned” status are only executed after the estimator has completed all required preliminary processes, like fit or run. This helper mirrors the fitted-state checks described in [45, 46].

Examples

>>> from geoprior.utils.validator import check_is_runned
>>> class ExampleClass:
...     def __init__(self):
...         self._is_runned = False
...
...     def run(self):
...         self._is_runned = True
...         print("Run completed.")
...
...     def process_data(self):
...         check_is_runned(self)
...         print("Processing data...")
>>> model = ExampleClass()
>>> model.process_data()  # Raises RuntimeError: run() has not been called yet
>>> model.run()
Run completed.
>>> model.process_data()  # Now it works
Processing data...

See also

check_is_fitted

Validates that an estimator has been “fitted” before further use.

validate_estimator_methods

Validates essential estimator methods.

geoprior.utils.validator.check_memory(memory)[source]

Check that memory is joblib.Memory-like.

joblib.Memory-like means that memory can be converted into a joblib.Memory instance (typically a str denoting the location) or has the same interface (has a cache method).

Parameters:

memory (None, str or object with the joblib.Memory interface) –

  • If string, the location where to create the joblib.Memory interface.

  • If None, no caching is done and the Memory object is completely transparent.

Returns:

memory – A correct joblib.Memory object.

Return type:

object with the joblib.Memory interface

Raises:

ValueError – If memory is not joblib.Memory-like.

geoprior.utils.validator.check_mixed_data_types(data)[source]

Checks if the given data (DataFrame or numpy array) contains both numerical and categorical columns.

Parameters:

data (pd.DataFrame or np.ndarray) – The data to check. Can be a pandas DataFrame or a numpy array. If data is a numpy array, it is temporarily converted to a DataFrame for type checking.

Returns:

True if the data contains both numerical and categorical columns, False otherwise.

Return type:

bool

Examples

Using with a pandas DataFrame:

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import check_mixed_data_types
>>> df = pd.DataFrame({'A': [1, 2, 3], 'B': ['a', 'b', 'c']})
>>> print(check_mixed_data_types(df))
True

Using with a numpy array:

>>> array = np.array([[1, 'a'], [2, 'b'], [3, 'c']])
>>> print(check_mixed_data_types(array))
True

With data containing only numerical values:

>>> df_numeric_only = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
>>> print(check_mixed_data_types(df_numeric_only))
False

With data containing only categorical values:

>>> df_categorical_only = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['d', 'e', 'f']})
>>> print(check_mixed_data_types(df_categorical_only))
False

geoprior.utils.validator.check_random_state(seed)[source]

Turn seed into a np.random.RandomState instance.

Parameters:

seed (None, int or instance of RandomState) – If seed is None, return the RandomState singleton used by np.random. If seed is an int, return a new RandomState instance seeded with seed. If seed is already a RandomState instance, return it. Otherwise raise ValueError.

Returns:

The random state object based on seed parameter.

Return type:

numpy.random.RandomState
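The seed rules above map directly onto three isinstance branches. In this sketch, check_random_state_sketch is a hypothetical name. It assumes, as scikit-learn does, that np.random.mtrand._rand (a private NumPy attribute) is the global RandomState singleton.

```python
import numbers

import numpy as np


def check_random_state_sketch(seed):
    """Illustrative sketch of the seed -> RandomState rules above."""
    if seed is None:
        # np.random.mtrand._rand is the (private) global RandomState used
        # by the np.random.* functions.
        return np.random.mtrand._rand
    if isinstance(seed, numbers.Integral):
        # A fresh generator seeded with the given integer.
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        # Already a RandomState: returned unchanged.
        return seed
    raise ValueError(f"{seed!r} cannot be used to seed a RandomState instance.")
```

Integer seeds give reproducible streams, while passing an existing RandomState shares its state with the caller.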

geoprior.utils.validator.check_scalar(x, name, target_type, *, min_val=None, max_val=None, include_boundaries='both')[source]

Validate scalar parameters type and value.

Parameters:
  • x (object) – The scalar parameter to validate.

  • name (str) – The name of the parameter to be printed in error messages.

  • target_type (type or tuple) – Acceptable data types for the parameter.

  • min_val (float or int, default None) – The minimum valid value the parameter can take. If None (default) it is implied that the parameter does not have a lower bound.

  • max_val (float or int, default None) – The maximum valid value the parameter can take. If None (default) it is implied that the parameter does not have an upper bound.

  • include_boundaries ({"left", "right", "both", "neither"}, default "both") – Whether the interval defined by min_val and max_val should include the boundaries. Use "left" for [min_val, max_val), "right" for (min_val, max_val], "both" for [min_val, max_val], or "neither" for (min_val, max_val).

Returns:

x – The validated number.

Return type:

numbers.Number

Raises:
  • TypeError – If the parameter’s type does not match the desired type.

  • ValueError – If the parameter’s value violates the given bounds, or if min_val, max_val, and include_boundaries are inconsistent.
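The include_boundaries modes translate into open or closed comparisons at each end of the interval. The sketch below is hypothetical (check_scalar_sketch). Its reading of "inconsistent" is an assumption borrowed from scikit-learn: "left" without min_val, or "right" without max_val.

```python
def check_scalar_sketch(x, name, target_type, *, min_val=None,
                        max_val=None, include_boundaries="both"):
    """Illustrative sketch of check_scalar's type and interval rules."""
    if not isinstance(x, target_type):
        raise TypeError(
            f"{name} must be an instance of {target_type}, "
            f"not {type(x).__name__}."
        )
    if include_boundaries not in ("left", "right", "both", "neither"):
        raise ValueError(f"Unknown include_boundaries={include_boundaries!r}.")
    # Assumed consistency rules: a closed side needs an explicit bound.
    if include_boundaries == "left" and min_val is None:
        raise ValueError("include_boundaries='left' requires min_val.")
    if include_boundaries == "right" and max_val is None:
        raise ValueError("include_boundaries='right' requires max_val.")
    closed_left = include_boundaries in ("left", "both")
    closed_right = include_boundaries in ("right", "both")
    if min_val is not None:
        too_small = x < min_val if closed_left else x <= min_val
        if too_small:
            op = ">=" if closed_left else ">"
            raise ValueError(f"{name} == {x}, must be {op} {min_val}.")
    if max_val is not None:
        too_large = x > max_val if closed_right else x >= max_val
        if too_large:
            op = "<=" if closed_right else "<"
            raise ValueError(f"{name} == {x}, must be {op} {max_val}.")
    return x
```

For instance, with include_boundaries="right", the interval is (min_val, max_val], so x == min_val is rejected and x == max_val is accepted.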

geoprior.utils.validator.check_symmetric(array, *, tol=1e-10, raise_warning=True, raise_exception=False)[source]

Make sure that array is 2D, square and symmetric.

If the array is not symmetric, then a symmetrized version is returned. Optionally, a warning or exception is raised if the matrix is not symmetric.

Parameters:
  • array ({ndarray, sparse matrix}) – Input object to check / convert. Must be two-dimensional and square, otherwise a ValueError will be raised.

  • tol (float, default 1e-10) – Absolute tolerance for equivalence of arrays.

  • raise_warning (bool, default True) – If True then raise a warning if conversion is required.

  • raise_exception (bool, default False) – If True then raise an exception if array is not symmetric.

Returns:

array_sym – Symmetrized version of the input array, i.e. the average of array and array.transpose(). If sparse, then duplicate entries are first summed and zeros are eliminated.

Return type:

{ndarray, sparse matrix}
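For dense input, the symmetrization described above is the element-wise average (array + array.T) / 2, applied only when the absolute-tolerance check fails. The helper below is a hypothetical sketch that omits sparse matrices and the warning/exception flags.

```python
import numpy as np


def symmetrize_sketch(array, tol=1e-10):
    """Sketch: return (array + array.T) / 2 unless already symmetric."""
    array = np.asarray(array, dtype=float)
    if array.ndim != 2 or array.shape[0] != array.shape[1]:
        raise ValueError(
            f"array must be 2-dimensional and square, got shape {array.shape}."
        )
    # Absolute-tolerance symmetry test (rtol=0, atol=tol).
    if np.allclose(array, array.T, rtol=0.0, atol=tol):
        return array
    return (array + array.T) / 2.0
```

For example, [[1, 2], [0, 1]] becomes [[1, 1], [1, 1]], because each off-diagonal pair (2, 0) is replaced by its mean.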

geoprior.utils.validator.check_y(y, multi_output=False, y_numeric=False, input_name='y', estimator=None, to_frame=False, allow_nan=False)[source]

Validates the target array y, ensuring it is suitable for classification or regression tasks based on its content and the options below.

Parameters:
  • y (array-like) – Target values to validate.

  • multi_output (bool, default False) – Whether to allow two-dimensional y values. If False, y is validated as a vector. When multi_output=True, y still cannot contain np.nan or np.inf values unless allow_nan permits NaNs.

  • y_numeric (bool, default False) – Whether to ensure that y has a numeric type. If dtype of y is object, it is converted to float64. Should only be used for regression algorithms.

  • input_name (str, default "y") – Data name used to construct the error message.

  • estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.

  • allow_nan (bool, default False) – If True, do not raise an error when y contains NaN values.

  • to_frame (bool, default False) – Reconvert the validated array to its initial pandas type when the input was provided as a pandas Series or DataFrame.

Returns:

y_converted – The converted and validated y.

Return type:

object

geoprior.utils.validator.contains_nested_objects(lst, strict=False, allowed_types=None)[source]

Determines whether a list contains nested objects.

Parameters:
  • lst (list) – The list to be checked for nested objects.

  • strict (bool, optional) – If True, all items in the list must be nested objects. If False, the function returns True if any item is a nested object. Default is False.

  • allowed_types (tuple of types, optional) – A tuple of types to consider as nested objects. If None, common nested types like list, set, dict, and tuple are checked. Default is None.

Returns:

True if the list contains nested objects according to the given parameters, otherwise False.

Return type:

bool

Notes

A nested object is any item in the list that is a container, such as a list, set, dict or tuple, rather than a primitive data type (e.g., int, float, str). The function can be customized to check for specific types using the allowed_types parameter.

Examples

>>> from geoprior.utils.validator import contains_nested_objects
>>> example_list1 = [{1, 2}, [3, 4], {'key': 'value'}]
>>> example_list2 = [1, 2, 3, [4]]
>>> example_list3 = [1, 2, 3, 4]
>>> contains_nested_objects(example_list1)  # non-strict, contains nested objects
True
>>> contains_nested_objects(example_list1, strict=True)  # strict, all are nested objects
True
>>> contains_nested_objects(example_list2)  # non-strict, at least one nested object
True
>>> contains_nested_objects(example_list2, strict=True)  # strict, not all are nested
False
>>> contains_nested_objects(example_list3)  # non-strict, no nested objects
False
>>> contains_nested_objects(example_list3, strict=True)  # strict, no nested objects
False
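The strict/non-strict rule can be sketched as an all/any test over isinstance checks. This is a minimal sketch assuming that "nested" simply means an instance of one of the container types listed in the parameter docs; the helper `has_nested` is hypothetical and not the geoprior implementation.

```python
def has_nested(lst, strict=False, allowed_types=None):
    # Default container types treated as "nested", per the parameter docs.
    nested_types = tuple(allowed_types) if allowed_types else (list, set, dict, tuple)
    flags = [isinstance(item, nested_types) for item in lst]
    # strict=True: every item must be nested; otherwise one nested item suffices.
    return all(flags) if strict else any(flags)


print(has_nested([1, 2, 3, [4]]))               # True
print(has_nested([1, 2, 3, [4]], strict=True))  # False
```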
geoprior.utils.validator.convert_array_to_pandas(X, *, to_frame=False, columns=None, input_name='X')[source]

Converts an array-like object to a pandas DataFrame or Series, applying provided column names or series name.

Parameters:
  • X (array-like) – The array to convert to a DataFrame or Series.

  • to_frame (bool, default False) – If True, converts the array to a DataFrame. Otherwise, returns the array unchanged.

  • columns (str or list of str, optional) – Name(s) for the columns of the resulting DataFrame or the name of the Series.

  • input_name (str, default 'X') – The name of the input variable; used in constructing error messages.

Returns:

  • pd.DataFrame or pd.Series – The converted DataFrame or Series. If to_frame is False, returns X unchanged.

  • columns (str or list of str) – The column names of the DataFrame or the name of the Series, if applicable.

Raises:
  • TypeError – If X is not array-like or if columns is neither a string nor a list of strings.

  • ValueError – If the conversion to DataFrame is requested but columns is not provided, or if the length of columns does not match the number of columns in X.

geoprior.utils.validator.ensure_2d(X, output_format='auto')[source]

Ensure that the input X is converted to a 2-dimensional structure.

Parameters:
  • X (array-like or pandas.DataFrame) – The input data to convert. Can be a list, numpy array, or DataFrame.

  • output_format (str, optional) – The format of the returned object. Options are “auto”, “array”, or “frame”. “auto” returns a DataFrame if X is a DataFrame, otherwise a numpy array. “array” always returns a numpy array. “frame” always returns a pandas DataFrame.

Returns:

The converted 2-dimensional structure, either as a numpy array or DataFrame.

Return type:

ndarray or DataFrame

Raises:

ValueError – If the output_format is not one of the allowed values.

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import ensure_2d
>>> X = np.array([1, 2, 3])
>>> ensure_2d(X, output_format="array")
array([[1],
       [2],
       [3]])
>>> df = pd.DataFrame([1, 2, 3])
>>> ensure_2d(df, output_format="frame")
   0
0  1
1  2
2  3
geoprior.utils.validator.ensure_non_negative(*arrays, err_msg=None)[source]

Ensure that provided arrays contain only non-negative values.

This function checks each provided array for non-negativity. If any negative values are found in any array, it raises a ValueError. This check is crucial for computations or algorithms where negative values are not permissible, such as logarithmic transformations.

Parameters:
  • *arrays (array-like) – One or more array-like structures (e.g., lists, numpy arrays). Each array is checked for non-negativity.

  • err_msg (str, optional) – Specify a custom error message if negative values are found.

Raises:

ValueError – If any array contains negative values, a ValueError is raised with a message indicating that only non-negative values are expected.

Examples

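>>> from geoprior.utils.validator import ensure_non_negative
>>> y_true = [0, 1, 2, 3]
>>> y_pred = [0.5, 2.1, 3.5, -0.1]
>>> ensure_non_negative(y_true, y_pred)
Traceback (most recent call last):
    ...
ValueError: Negative value found. Expect only non-negative values.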

Note

The function uses a variable number of arguments, allowing flexibility in the number of arrays checked in a single call.
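The variadic check can be sketched in a few lines of pure Python. This sketch assumes flat, iterable inputs and reuses the error text shown in the example above; the helper `check_non_negative` is hypothetical, not the geoprior function.

```python
def check_non_negative(*arrays, err_msg=None):
    msg = err_msg or "Negative value found. Expect only non-negative values."
    for arr in arrays:
        # A single negative entry in any of the arrays fails the whole check.
        if any(value < 0 for value in arr):
            raise ValueError(msg)


check_non_negative([0, 1, 2, 3], [0.5, 2.1])  # passes silently
try:
    check_non_negative([0, 1, 2, 3], [0.5, 2.1, 3.5, -0.1])
except ValueError as exc:
    print(exc)  # Negative value found. Expect only non-negative values.
```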

geoprior.utils.validator.filter_valid_kwargs(callable_obj, kwargs)[source]

Filter and return only the valid keyword arguments for a given callable object.

This function checks if the arguments in kwargs are valid for the provided callable object (function, lambda function, method, or class). If any argument is not valid, it is removed from kwargs. The function returns only the valid kwargs.

Parameters:
  • callable_obj (callable) – The callable object (function, lambda function, method, or class) for which the keyword arguments need to be validated.

  • kwargs (dict) – Dictionary of keyword arguments to be validated against the callable object.

Returns:

valid_kwargs – Dictionary containing only the valid keyword arguments for the callable object.

Return type:

dict

Examples

>>> from geoprior.utils.validator import filter_valid_kwargs
>>> def example_func(a, b, c=3):
...     pass
>>> kwargs = {'a': 1, 'b': 2, 'd': 4}
>>> filter_valid_kwargs(example_func, kwargs)
{'a': 1, 'b': 2}
>>> class ExampleClass:
...     def __init__(self, x, y, z=10):
...         pass
>>> kwargs = {'x': 1, 'y': 2, 'a': 3}
>>> filter_valid_kwargs(ExampleClass, kwargs)
{'x': 1, 'y': 2}
>>> filter_valid_kwargs(ExampleClass(1, 2), kwargs)
{'x': 1, 'y': 2}

Notes

This function uses the inspect module to retrieve the signature of the given callable object and validate the keyword arguments.
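A minimal sketch of the same idea with `inspect.signature`: only parameters that can be passed by keyword are kept. Two details are assumptions of this sketch rather than documented geoprior behavior: callables that accept `**kwargs` keep every keyword, and the helper name `keep_valid_kwargs` is hypothetical.

```python
import inspect


def keep_valid_kwargs(callable_obj, kwargs):
    # For a class, inspect.signature reports the __init__ parameters (minus self).
    params = inspect.signature(callable_obj).parameters
    # Assumption: a callable that accepts **kwargs accepts every keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    # Only parameters that can actually be passed by keyword are valid.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {name for name, p in params.items() if p.kind in keyword_kinds}
    return {key: value for key, value in kwargs.items() if key in accepted}


def example_func(a, b, c=3):
    pass


class ExampleClass:
    def __init__(self, x, y, z=10):
        pass


print(keep_valid_kwargs(example_func, {"a": 1, "b": 2, "d": 4}))  # {'a': 1, 'b': 2}
print(keep_valid_kwargs(ExampleClass, {"x": 1, "y": 2, "a": 3}))  # {'x': 1, 'y': 2}
```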

geoprior.utils.validator.get_estimator_name(estimator)[source]

Get the name of an estimator, whether it is an instantiated object or a class.

Parameters:

estimator – A callable (e.g., an estimator class) or an instantiated object that has a fit method.

Returns:

str – Name of the estimator.
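One plausible way to resolve the name for both classes and instances is to read `__name__` and fall back to the object's class. This is a sketch under that assumption; `estimator_name` and `LinearModel` are hypothetical names used for illustration only.

```python
def estimator_name(estimator):
    # Classes and functions carry __name__; instances fall back to their class.
    return getattr(estimator, "__name__", type(estimator).__name__)


class LinearModel:
    def fit(self, X, y):
        return self


print(estimator_name(LinearModel))    # LinearModel
print(estimator_name(LinearModel()))  # LinearModel
```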

geoprior.utils.validator.handle_zero_division(y_true, zero_division='warn', metric_name='metric computation', epsilon=1e-15, replace_with=None)[source]

Preprocess input arrays to handle cases where zero could cause division errors in subsequent metric computations.

Parameters:
  • y_true (array-like) – The input data array where zeros might cause division errors.

  • zero_division ({'warn', 'raise', 'ignore'}, default 'warn') – Determines the action to perform when a zero is encountered. Use "warn" to issue a warning and replace zeros with replace_with or epsilon, "raise" to raise an error, or "ignore" to leave zeros unchanged when the metric can handle them natively.

  • metric_name (str, optional) – Name of the metric for which this preprocessing is being done, to be included in warnings or error messages for better context.

  • epsilon (float, optional) – Small value to use as default replacement if replace_with is None, default is 1e-15.

  • replace_with (float or None, optional) – A specific value to replace zeros with, if None, epsilon is used.

Returns:

The processed array with modifications based on the zero_division strategy.

Return type:

numpy.ndarray

Raises:

ValueError – If zero_division is ‘raise’ and zero is found in y_true.

Notes

Using replace_with allows for custom behavior when handling zeros, which can be tailored to the specific requirements of different metric computations.

Examples

>>> from geoprior.utils.validator import handle_zero_division
>>> y_true = [0, 1, 2, 3, 0]
>>> processed_y_true = handle_zero_division(
...     y_true, replace_with=0.001, zero_division='warn'
... )
>>> print(processed_y_true)
[1.e-03 1.e+00 2.e+00 3.e+00 1.e-03]
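The three zero_division strategies can be sketched with plain lists and the standard warnings module. This is a hedged illustration, not the geoprior implementation: the helper `replace_zeros` is hypothetical, it returns a list of floats rather than a NumPy array, and the message texts are assumptions.

```python
import warnings


def replace_zeros(y_true, zero_division="warn", metric_name="metric computation",
                  epsilon=1e-15, replace_with=None):
    if zero_division not in {"warn", "raise", "ignore"}:
        raise ValueError(f"Invalid zero_division: {zero_division!r}")
    values = [float(v) for v in y_true]
    # Nothing to do when there are no zeros or the metric handles them itself.
    if not any(v == 0 for v in values) or zero_division == "ignore":
        return values
    if zero_division == "raise":
        raise ValueError(f"Zero found in y_true during {metric_name}.")
    warnings.warn(f"Zeros in y_true replaced for {metric_name}.")
    # replace_with takes precedence; epsilon is the fallback replacement.
    fill = epsilon if replace_with is None else replace_with
    return [fill if v == 0 else v for v in values]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(replace_zeros([0, 1, 2, 3, 0], replace_with=0.001))  # [0.001, 1.0, 2.0, 3.0, 0.001]
print(len(caught))  # 1
```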
geoprior.utils.validator.has_methods(models, methods, strict=True, check_status='check_only', msg=None)[source]

Validate that one or more model objects implement required methods.

Parameters:
  • models (object or list of objects) – Model instance or collection of model instances to validate.

  • methods (list of str) – Public method names that each model must implement.

  • strict (bool, optional) – If True, raise an AttributeError when a required method is missing.

  • check_status ({'validate', 'check_only'}, optional) – Return mode. Use 'validate' to return validated models and 'check_only' to return a boolean flag.

  • msg (str or None, optional) – Optional custom error message using {model} and {methods} placeholders.

Returns:

Validated models when check_status='validate' or a boolean flag when check_status='check_only'.

Return type:

list of objects or bool

Raises:

AttributeError – If strict is True and a model does not implement a required method.

geoprior.utils.validator.has_fit_parameter(estimator, parameter)[source]

Check whether the estimator’s fit method supports the given parameter.

Parameters:
  • estimator (object) – An estimator to inspect.

  • parameter (str) – The searched parameter.

Returns:

is_parameter – Whether the parameter was found to be a named parameter of the estimator’s fit method.

Return type:

bool

Examples

>>> from sklearn.svm import SVC
>>> from geoprior.utils.validator import has_fit_parameter
>>> has_fit_parameter(SVC(), "sample_weight")
True
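The check reduces to inspecting the signature of the estimator's fit method. The sketch below assumes that is all the function does; `fit_accepts` and `ToyRegressor` are hypothetical names, and no scikit-learn import is needed.

```python
import inspect


def fit_accepts(estimator, parameter):
    # Inspect the signature of the (bound) fit method; self is already excluded.
    return parameter in inspect.signature(estimator.fit).parameters


class ToyRegressor:
    def fit(self, X, y, sample_weight=None):
        return self


print(fit_accepts(ToyRegressor(), "sample_weight"))  # True
print(fit_accepts(ToyRegressor(), "groups"))         # False
```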
geoprior.utils.validator.has_required_attributes(model, attributes)[source]

Check if the model has all required Keras-specific attributes.

This function is part of the deep validation process to ensure that the model not only inherits from Keras model classes but also implements essential methods.

Parameters:
  • model (Any) – The model object to inspect.

  • attributes (list of str) – A list of strings representing the names of the attributes to check for in the model.

Returns:

True if the model contains all specified attributes, False otherwise.

Return type:

bool
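Checking a list of attribute names reduces to an all-of-hasattr test. This is a minimal sketch assuming plain attribute lookup (no Keras import). `has_all_attributes` and `KerasLikeModel` are hypothetical stand-ins.

```python
def has_all_attributes(model, attributes):
    # Every requested name must resolve on the object (methods included).
    return all(hasattr(model, name) for name in attributes)


class KerasLikeModel:
    def compile(self):
        pass

    def fit(self):
        pass

    def predict(self):
        pass


required = ["compile", "fit", "predict"]
print(has_all_attributes(KerasLikeModel(), required))                 # True
print(has_all_attributes(KerasLikeModel(), required + ["evaluate"]))  # False
```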

geoprior.utils.validator.is_binary_class(y, accept_multioutput=False)[source]

Check whether the target array represents binary classification. Optionally, handle multi-output arrays if each output is binary.

Parameters:
  • y (array-like) – The target array to be checked. This can be a 1D array for single output or a 2D array for multiple outputs if accept_multioutput is True.

  • accept_multioutput (bool, default False) – If True, the function checks if each column in a multi-dimensional array is binary. If False, the function checks if the entire array is binary.

Returns:

Returns True if y is binary (or each output is binary if multi-output is accepted), False otherwise.

Return type:

bool

Examples

>>> from geoprior.utils.validator import is_binary_class
>>> is_binary_class([0, 1, 1, 0])
True
>>> is_binary_class([[0, 1], [1, 0], [0, 1], [1, 0]], accept_multioutput=True)
True
>>> is_binary_class([0, 1, 2, 3])
False
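The difference between the flattened and per-output checks can be sketched in pure Python. The sketch assumes "binary" means at most two distinct labels, which may differ from geoprior's exact rule, and the helper `looks_binary` is hypothetical.

```python
def looks_binary(y, accept_multioutput=False):
    # Assumption: "binary" means at most two distinct labels.
    is_2d = len(y) > 0 and isinstance(y[0], (list, tuple))
    if is_2d and accept_multioutput:
        # Multi-output: each column must be binary on its own.
        n_outputs = len(y[0])
        return all(len({row[j] for row in y}) <= 2 for j in range(n_outputs))
    # Otherwise the whole array (flattened if 2D) must be binary.
    values = [v for row in y for v in row] if is_2d else list(y)
    return len(set(values)) <= 2


print(looks_binary([[0, 5], [1, 6]], accept_multioutput=True))  # True
print(looks_binary([[0, 5], [1, 6]]))                           # False
```

With accept_multioutput=True each column is judged separately, so two binary columns with different label sets still pass, while the flattened check sees four distinct labels.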
geoprior.utils.validator.is_categorical(data, column, strict=False, error='raise')[source]

Checks if a specified column in a DataFrame or Series is of a categorical type.

Parameters:
  • data (DataFrame or Series) – The DataFrame or Series to check.

  • column (str) – The name of the column to check.

  • strict (bool, optional) – If True, only considers pandas CategoricalDtype as categorical. If False, also considers object dtype that often represents categorical data. Default is False.

  • error (str, optional) – Specifies how to handle situations when the column does not exist. Options are ‘raise’, ‘warn’, or ‘ignore’. Default is ‘raise’.

Returns:

True if the column is categorical, otherwise False.

Return type:

bool

Raises:

ValueError – If the column does not exist and error is set to ‘raise’.

Examples

>>> import pandas as pd
>>> from geoprior.utils.validator import is_categorical
>>> df = pd.DataFrame({
...     'fruit': ['Apple', 'Banana', 'Cherry'],
...     'count': [10, 20, 15]
... })
>>> df['fruit'] = df['fruit'].astype('category')
>>> print(is_categorical(df, 'fruit'))
True
>>> print(is_categorical(df, 'count'))
False
>>> print(is_categorical(df, 'non_existent', error='warn'))
Warning: Column 'non_existent' not found in the dataframe.
False
geoprior.utils.validator.is_frame(arr, df_only=False, raise_exception=False, objname=None, error='raise')[source]

Check if arr is a pandas DataFrame or Series.

If df_only=True, the function checks strictly for a pandas DataFrame. Otherwise, it accepts either a pandas DataFrame or Series. This utility is often used to validate input data before processing, ensuring that the input conforms to expected types.

Parameters:
  • arr (object) – The object to examine. Typically a pandas DataFrame or Series, but can be any Python object.

  • df_only (bool, optional) – If True, only verifies that arr is a DataFrame. If False, checks for either a DataFrame or a Series. Default is False.

  • raise_exception (bool, optional) – If True, this overrides error and behaves as error="raise". This parameter is deprecated and will be removed in a future release. Default is False.

  • error (str, optional) – Determines the action when arr is not a valid frame. Default is "raise". Options are:

    • "raise": raises a TypeError.

    • "warn": issues a warning.

    • "ignore": does nothing.

  • objname (str or None, optional) – A custom name used in the error message if error is set to "raise". If None, a generic name is used.

Returns:

True if arr is a DataFrame or Series (or strictly a DataFrame if df_only=True), otherwise False.

Return type:

bool

Raises:

TypeError – If error="raise" and arr is not a valid frame. The error message guides the user to provide the correct type (a DataFrame, or a DataFrame or Series).

Notes

This function does not convert or modify arr. It merely checks its compatibility with common DataFrame/Series interfaces by examining attributes such as ‘columns’ or ‘name’. For a DataFrame, arr.columns should exist, and for a Series, a ‘name’ attribute is often present. Both DataFrame and Series implement __array__, making them NumPy array-like.

Examples

>>> import pandas as pd
>>> from geoprior.utils.validator import is_frame
>>> df = pd.DataFrame({'A': [1,2,3]})
>>> is_frame(df)
True
>>> s = pd.Series([4,5,6], name='S')
>>> is_frame(s)
True
>>> is_frame(s, df_only=True)
False

If error="raise":

>>> is_frame(s, df_only=True, error="raise", objname='Input')
Traceback (most recent call last):
    ...
TypeError: 'Input' parameter expects a DataFrame. Got 'Series'
geoprior.utils.validator.is_installed(module)[source]

Checks whether a Python module is installed.

This function attempts to find the module's package specification without importing it. It is a lightweight way to verify that a package such as TensorFlow is present in the environment.

Returns:

True if the module is installed, False otherwise.

Return type:

bool

Parameters:

module (str) – Name of the module to look up, e.g. "tensorflow".

Examples

>>> from geoprior.utils.validator import is_installed
>>> print(is_installed("tensorflow"))  # True if TensorFlow is installed, False otherwise
True
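The description above matches what `importlib.util.find_spec` provides: it locates a module's spec without importing the module. A minimal sketch, assuming a top-level module name (for dotted names, find_spec imports the parent packages). The helper `module_available` is hypothetical.

```python
import importlib.util


def module_available(module):
    # find_spec returns None when no spec can be found for a top-level name.
    return importlib.util.find_spec(module) is not None


print(module_available("json"))                          # True
print(module_available("surely_not_an_installed_pkg"))   # False
```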
geoprior.utils.validator.is_normalized(arr, method='sum')[source]

Checks if the provided array is normalized according to the specified method.

Parameters:
  • arr (array-like) – The array to check for normalization.

  • method (str, optional) – The normalization method to check against. Use "01" to confirm values are within [0, 1] with minimum 0 and maximum 1, "zscore" to confirm mean 0 and standard deviation 1, or "sum" to confirm the array sums to 1. Default is "sum".

Returns:

Returns True if the array is normalized according to the specified method, False otherwise.

Return type:

bool

Examples

>>> import numpy as np
>>> from geoprior.utils.validator import is_normalized
>>> arr = np.array([0.25, 0.25, 0.25, 0.25])
>>> is_normalized(arr, method='sum')
True
>>> arr = np.array([0, 0.5, 1])
>>> is_normalized(arr, method='01')
True
>>> arr = np.array([1, -1, 1, -1])
>>> is_normalized(arr, method='zscore')
True
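The three criteria can be expressed with the standard math module. This sketch is illustrative. Three details are assumptions, not geoprior's documented behavior: the `tol` argument, the use of the population standard deviation (ddof=0, NumPy's default) for "zscore", and the hypothetical helper name `check_normalized`.

```python
import math


def check_normalized(arr, method="sum", tol=1e-9):
    values = [float(v) for v in arr]
    if method == "sum":
        # Values must add up to 1.
        return math.isclose(sum(values), 1.0, abs_tol=tol)
    if method == "01":
        # Minimum exactly 0 and maximum exactly 1 (within tolerance).
        return (math.isclose(min(values), 0.0, abs_tol=tol)
                and math.isclose(max(values), 1.0, abs_tol=tol))
    if method == "zscore":
        n = len(values)
        mean = sum(values) / n
        # Population standard deviation (ddof=0), as numpy.std uses by default.
        std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
        return (math.isclose(mean, 0.0, abs_tol=tol)
                and math.isclose(std, 1.0, abs_tol=tol))
    raise ValueError(f"Unknown method: {method!r}")


print(check_normalized([1, -1, 1, -1], method="zscore"))  # True
```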
geoprior.utils.validator.is_square_matrix(data, data_type=None)[source]

Determine whether the input, either a DataFrame or an array-like structure, forms a square matrix.

Automatically detects the data type unless specified. Supports data inputs that can be converted to a NumPy array.

Parameters:
  • data (DataFrame, array-like, or any object convertible to a numpy array) – The input data to check.

  • data_type (str, optional) – The expected type of the input data. Valid options are ‘array’ or ‘dataframe’. If not specified, the data type is inferred. Default interpretation is as an ‘array’.

Returns:

Returns True if the data is a square matrix, otherwise False.

Return type:

bool

Raises:
  • ValueError – If data_type is neither ‘array’ nor ‘dataframe’.

  • TypeError – If the input data does not match the expected format or cannot be processed.

Examples

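>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import is_square_matrix
>>> is_square_matrix(np.array([[1, 2], [3, 4]]))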
True
>>> is_square_matrix(pd.DataFrame([[1, 2, 3], [4, 5, 6]]))
False
>>> is_square_matrix([[1, 2], [3, 4]], data_type='array')
True

Notes

A square matrix has an equal number of rows and columns. This function checks the dimensionality and shape of the data to confirm if it meets this criterion.

geoprior.utils.validator.is_time_series(data, time_col, check_time_interval=False)[source]

Check if the provided DataFrame is time series data.

Parameters:
  • data (pandas.DataFrame) – The DataFrame to be checked.

  • time_col (str) – The name of the column in data expected to represent time.

Returns:

True if data is a time series, False otherwise.

Return type:

bool

Example

>>> import pandas as pd
>>> from geoprior.utils.validator import is_time_series
>>> df = pd.DataFrame({
...     'Date': ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04', '2021-01-05'],
...     'Value': [1, 2, 3, 4, 5]
... })
>>> # Should return True if the Date column
>>> # can be converted to datetime
>>> print(is_time_series(df, 'Date'))

Validates the nan_policy or any policy argument to ensure it is one of the acceptable options (allowed_policies).

This function enforces conformity to predefined NaN handling strategies in data processing tasks.

Parameters:
  • nan_policy (str) – The NaN handling policy to validate. Acceptable values are:

    • 'propagate': NaN values are propagated, i.e., no action is taken.

    • 'omit': NaN values are omitted before proceeding with the operation.

    • 'raise': an error is raised if NaN values are present.

  • allowed_policies (list of str, optional) – A list of allowable policy options. If None, defaults to [‘propagate’, ‘omit’, ‘raise’].

Raises:

ValueError – If nan_policy is not one of the valid options in allowed_policies.

Returns:

The verified nan_policy value, confirming it is within allowed parameters.

Return type:

str

Examples

>>> from geoprior.utils.validator import is_valid_policies
>>> is_valid_policies('omit')  # This should pass without an error.
>>> is_valid_policies('ignore')  # This should raise a ValueError.
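The validation amounts to a membership test against the allowed list, returning the policy unchanged on success. A minimal sketch under that assumption; the helper `check_policy` and its error message are hypothetical.

```python
def check_policy(nan_policy, allowed_policies=None):
    # Default allowed options, as documented above.
    allowed = allowed_policies or ["propagate", "omit", "raise"]
    if nan_policy not in allowed:
        raise ValueError(
            f"Invalid policy {nan_policy!r}; expected one of {allowed}."
        )
    return nan_policy


print(check_policy("omit"))  # omit
```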
geoprior.utils.validator.normalize_array(arr, normalize='auto', method='01')[source]

Checks if an array is normalized according to the specified method and normalizes it if required based on the ‘normalize’ parameter.

Parameters:
  • arr (array-like) – The input array to check and potentially normalize.

  • normalize (str, optional) – Controls whether normalization is applied. Use "auto" to normalize only when the array is not already normalized for the selected method. Use True to always normalize and False to return the array unchanged. Default is "auto".

  • method (str, optional) – Normalization method to apply. Use "01" for min-max scaling, "zscore" for standardization, or "sum" to scale values so they sum to 1. Default is "01".

Returns:

The normalized array, or the original array if no normalization was applied.

Return type:

np.ndarray

Raises:

ValueError – If an unknown normalization method is specified or if normalization cannot be performed due to data characteristics (e.g., zero variance).

Examples

>>> import numpy as np
>>> from geoprior.utils.validator import normalize_array
>>> data = np.array([1, 2, 3, 4, 5])
>>> normalized_data = normalize_array(data, normalize=True, method='01')
>>> print("Normalized between 0 and 1:", normalized_data)
Normalized between 0 and 1: [0.   0.25 0.5  0.75 1.  ]
>>> zscore_data = normalize_array(data, normalize=True, method='zscore')
>>> print("Standardized (Z-score):", zscore_data)
Standardized (Z-score): [-1.41421356 -0.70710678  0.          0.70710678  1.41421356]
>>> sum_data = normalize_array(data, normalize=True, method='sum')
>>> print("Normalized by sum:", sum_data)
Normalized by sum: [0.06666667 0.13333333 0.2        0.26666667 0.33333333]
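The three methods and the "auto" switch can be sketched in pure Python. This is an illustrative sketch, not the geoprior implementation. Three details are assumptions: the hypothetical helper name `normalize_values`, the use of the population standard deviation for "zscore", and the handling of "auto". Because each transform maps an already-normalized array to itself, "auto" can return the input whenever normalizing would not change it.

```python
import math


def normalize_values(arr, normalize="auto", method="01"):
    values = [float(v) for v in arr]
    if normalize is False:
        return values
    if method == "01":
        lo, hi = min(values), max(values)
        if hi == lo:
            raise ValueError("Cannot min-max scale a constant array.")
        out = [(v - lo) / (hi - lo) for v in values]
    elif method == "zscore":
        mean = sum(values) / len(values)
        # Population standard deviation (ddof=0).
        std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
        if std == 0:
            raise ValueError("Cannot standardize an array with zero variance.")
        out = [(v - mean) / std for v in values]
    elif method == "sum":
        total = sum(values)
        if total == 0:
            raise ValueError("Cannot sum-normalize an array that sums to zero.")
        out = [v / total for v in values]
    else:
        raise ValueError(f"Unknown normalization method: {method!r}")
    # "auto": keep the input when it is already normalized for this method.
    if normalize == "auto" and all(
        math.isclose(a, b, abs_tol=1e-9) for a, b in zip(values, out)
    ):
        return values
    return out


print(normalize_values([1, 2, 3, 4, 5], normalize=True, method="01"))
# [0.0, 0.25, 0.5, 0.75, 1.0]
```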
geoprior.utils.validator.parameter_validator(param_name, target_strs, match_method='contains', raise_exception=True, **kws)[source]

Creates a validator function for ensuring a parameter’s value matches one of the allowed target strings, optionally applying normalization.

This higher-order function returns a validator that can be used to check if a given parameter value matches allowed criteria, optionally raising an exception or normalizing the input.

Parameters:
  • param_name (str) – Name of the parameter to be validated. Used in error messages to indicate which parameter failed validation.

  • target_strs (list of str) – A list of acceptable string values for the parameter.

  • match_method (str, optional) – The method used to match the input string against the target strings. The default method is ‘contains’, which checks if the input string contains any of the target strings.

  • raise_exception (bool, optional) – Specifies whether an exception should be raised if validation fails. Defaults to True, raising an exception on failure.

  • **kws (dict) – Keyword arguments passed to geoprior.core.utils.normalize_string().

Returns:

A closure that takes a single string argument (the parameter value) and returns a normalized version of it if the parameter matches the target criteria. If the parameter does not match and raise_exception is True, it raises an exception; otherwise, it returns the original value.

Return type:

function

Examples

>>> from geoprior.utils.validator import parameter_validator
>>> validate_outlier_method = parameter_validator(
...  'outlier_method', ['z_score', 'iqr'])
>>> outlier_method = "z_score"
>>> validate_outlier_method(outlier_method)
'z_score'
>>> validate_fill_missing = parameter_validator(
...  'fill_missing', ['median', 'mean', 'mode'], raise_exception=False)
>>> fill_missing = "average"  # This does not match but won't raise an exception.
>>> validate_fill_missing(fill_missing)
'average'

Notes

  • The function leverages a custom utility function normalize_string from a module named geoprior.core.utils. This utility is assumed to handle string normalization and matching based on the provided match_method.

  • If raise_exception is set to False and the input does not match any target string, the input string is returned unchanged. This behavior allows for optional enforcement of the validation rules.

  • The primary use case for this function is to validate and optionally normalize parameters for configuration settings or function arguments where only specific values are allowed.
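The higher-order pattern described above can be sketched as a closure. This sketch does not call geoprior's normalize_string. It substitutes simple lowercase/strip normalization and a substring test for the "contains" method. The "exact" branch, the error message and the helper name `make_validator` are hypothetical.

```python
def make_validator(param_name, target_strs, match_method="contains",
                   raise_exception=True):
    def validate(value):
        norm = str(value).strip().lower()
        for target in target_strs:
            t = target.lower()
            if match_method == "exact":  # hypothetical alternative mode
                hit = norm == t
            else:  # "contains": the input contains the target string
                hit = t in norm
            if hit:
                return target
        if raise_exception:
            raise ValueError(
                f"Invalid {param_name!r}: {value!r}. Expected one of {target_strs}."
            )
        # Non-strict mode: hand back the original value unchanged.
        return value

    return validate


validate_outlier = make_validator("outlier_method", ["z_score", "iqr"])
print(validate_outlier("Z_Score "))  # z_score
```

The closure captures param_name, target_strs and the matching options once, so the returned validator can be reused for every incoming value of the same setting.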

geoprior.utils.validator.process_y_pairs(*ys, error='warn', solo_return=False, ops='check_only')[source]

Process and validate paired arrays of ground truth (y_true) and predicted values (y_pred) for machine learning evaluation.

Parameters:
  • *ys (ArrayLike) – Variable-length sequence of array-likes containing alternating (y_true, y_pred) pairs. Must contain even number of inputs.

  • error ({'raise', 'warn', 'ignore'}, default 'warn') – Handling strategy for validation errors:

    • 'raise': immediately raise a ValueError.

    • 'warn': issue a UserWarning but continue processing.

    • 'ignore': silently skip invalid pairs.

  • solo_return (bool, default False) – When processing single pair, return as individual arrays instead of length-1 lists.

  • ops ({'check_only', 'validate'}, default 'check_only') – Processing mode:

    • 'check_only': verify pair lengths without modification.

    • 'validate': clean the data (remove NaNs) and validate dtypes.

Returns:

Processed pairs as (y_trues, y_preds) tuple. Return type depends on solo_return and number of valid pairs.

Return type:

Tuple[List[ArrayLike], List[ArrayLike]] or Tuple[ArrayLike, ArrayLike]

Raises:
  • ValueError

    • If input count is odd and error='raise'

    • Length mismatch in pairs when error='raise'

    • Invalid error or ops values

  • UserWarning

    • When odd input count and error='warn'

    • Length mismatches when error='warn'

Examples

Basic usage with valid pairs:

>>> from geoprior.utils.validator import process_y_pairs
>>> y_true1 = [1.2, 2.3, 3.4]
>>> y_pred1 = [1.1, 2.4, 3.3]
>>> y_true2 = [4.5, 5.6]
>>> y_pred2 = [4.4, 5.7]
>>> process_y_pairs(y_true1, y_pred1, y_true2, y_pred2)
([[1.2, 2.3, 3.4], [4.5, 5.6]], [[1.1, 2.4, 3.3], [4.4, 5.7]])

Handling mismatched pair with warnings:

>>> y_bad = [1, 2, 3]
>>> p_bad = [1, 2]
>>> process_y_pairs(y_bad, p_bad, error='warn')
UserWarning: Length mismatch in pair 0: 3 vs 2
([], [])

Full validation pipeline:

>>> import numpy as np
>>> y_clean, p_clean = process_y_pairs(
...     [1, np.nan, 3], [np.nan, 2.1, 3.2],
...     ops='validate', solo_return=True
... )
>>> y_clean
array([3.])
>>> p_clean
array([3.2])

Notes

Ensures input pairs meet requirements for downstream analysis through:

\[
\begin{aligned}
&\forall i \in \{0, 2, 4, \dots\},\ (y_{true}^i, y_{pred}^i) \rightarrow (\tilde{y}_{true}^i, \tilde{y}_{pred}^i)\ \text{where} \\
&\text{len}(\tilde{y}_{true}^i) = \text{len}(\tilde{y}_{pred}^i) \\
&\text{and}\ \tilde{y}_{true}^i \in \mathbb{R}^{n},\ \tilde{y}_{pred}^i \in \mathbb{R}^{n}
\end{aligned}
\tag{41}
\]

  1. Uses drop_nan_in for NaN removal and index resetting during validation

  2. Applies validate_yy for dtype consistency checks and array flattening

  3. Forward references for ArrayLike keep the interface flexible: any array-like structure (list, NumPy array, pandas Series, etc.) is accepted.

  4. The type and array-handling conventions rely on the Python language reference and NumPy’s array-programming model [36, 47].

See also

drop_nan_in

Core NaN removal and index resetting function

validate_yy

Array validation and dtype consistency checker

sklearn.utils.check_consistent_length

Scikit-learn’s length validation

geoprior.utils.validator.to_dtype_str(arr, return_values=False)[source]

Convert numeric or object dtype to string dtype.

This avoids a TypeError that can arise when an array contains both np.nan and string values: converting the array to dtype str, rather than keeping the 'object' dtype, sidesteps the error.

Parameters:
  • arr (array-like) – Array with any NumPy or pandas dtype.

  • return_values (bool, default False) – If True, return the array values with dtype str. This is useful when a Series with object or numeric dtype is passed.

Returns:

array-like – Array-like with dtype str. If a DataFrame or Series is passed, the object dtype is converted only when return_values is True; otherwise the same object is returned.

geoprior.utils.validator.validate_and_adjust_ranges(**kwargs)[source]

Validates and adjusts the provided range tuples to ensure each is composed of two numerical values and is sorted in ascending order.

This function takes multiple range specifications as keyword arguments, each expected to be a tuple of two numerical values (min, max). It validates the format and contents of each range, adjusting them if necessary to ensure that each tuple is ordered as (min, max).

Parameters:

**kwargs (dict) – Keyword arguments where each key is the name of a range (e.g., ‘lat_range’) and its corresponding value is a tuple of two numerical values representing the minimum and maximum of that range.

Returns:

A dictionary with the same keys as the input, but with each tuple value adjusted to ensure it is in the format (min, max).

Return type:

dict

Raises:

ValueError – If any provided range tuple does not contain exactly two values, contains non-numerical values, or if the min value is not less than the max value.

Examples

>>> from geoprior.utils.validator import validate_and_adjust_ranges
>>> validate_and_adjust_ranges(lat_range=(34.00, 36.00), lon_range=(-118.50, -117.00))
{'lat_range': (34.0, 36.0), 'lon_range': (-118.5, -117.0)}
>>> validate_and_adjust_ranges(time_range=(10.0, 0.01))
{'time_range': (0.01, 10.0)}
>>> validate_and_adjust_ranges(invalid_range=(1, 'a'))
Traceback (most recent call last):
    ...
ValueError: invalid_range must contain numerical values.

Notes

This function is particularly useful for preprocessing input ranges for various analyses, ensuring consistency and correctness of range specifications. It automates the adjustment of provided ranges, simplifying the setup process for further data processing or modeling tasks.
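The validate-then-sort behavior shown in the examples can be sketched in pure Python. This sketch has some assumptions beyond the documentation: it rejects booleans as non-numerical, it reorders any (max, min) pair instead of raising, and the helper name `adjust_ranges` is hypothetical.

```python
from numbers import Real


def adjust_ranges(**ranges):
    adjusted = {}
    for name, value in ranges.items():
        if not isinstance(value, (tuple, list)) or len(value) != 2:
            raise ValueError(f"{name} must be a tuple of two values.")
        # Assumption: bool is excluded even though it subclasses int.
        if not all(isinstance(v, Real) and not isinstance(v, bool) for v in value):
            raise ValueError(f"{name} must contain numerical values.")
        # Reorder into (min, max).
        low, high = sorted(value)
        adjusted[name] = (low, high)
    return adjusted


print(adjust_ranges(time_range=(10.0, 0.01)))  # {'time_range': (0.01, 10.0)}
```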

geoprior.utils.validator.validate_batch_size(batch_size, n_samples, min_batch_size=1, max_batch_size=None)[source]

Validate the batch size against the number of samples.

This function checks whether the provided batch_size is appropriate given the total number of samples n_samples. It ensures that the batch size meets specified minimum and maximum limits, raising appropriate errors if any constraints are violated.

Parameters:
  • batch_size (int) – The size of each batch. This must be a positive integer, as batches must contain at least one sample. A ValueError will be raised if this value is less than the minimum allowed batch size or exceeds the total number of samples.

  • n_samples (int) – The total number of samples in the dataset. This value must be positive and greater than or equal to the batch_size. If batch_size is greater than n_samples, a ValueError is raised.

  • min_batch_size (int, optional) – The minimum allowed batch size (default is 1). This parameter defines the smallest permissible batch size. A ValueError will be raised if the batch_size is less than this value.

  • max_batch_size (int, optional) – The maximum allowed batch size (default is None, meaning no upper limit). This parameter can be used to restrict the size of the batch to a specified maximum value. If max_batch_size is provided, a ValueError will be raised if the batch_size exceeds this limit.

Returns:

batch_size – The validated batch size.

Return type:

int

Raises:

ValueError – If the batch_size is less than the min_batch_size, greater than the n_samples, or exceeds the max_batch_size if specified. Additionally, if batch_size is not a positive integer, a ValueError is raised.

Notes

Let B represent the batch_size and N represent the n_samples. The validation can be expressed mathematically as:

\[
\text{If } B < \text{min\_batch\_size} \text{ or } B > N \text{ or } B > \text{max\_batch\_size}: \quad \text{raise ValueError}
\tag{42}
\]

This function is essential for managing data batching in machine learning workflows, where improper batch sizes can lead to inefficient training or runtime errors. The practical mini-batch constraint follows standard deep-learning training guidance [48].

Examples

>>> from geoprior.utils.validator import validate_batch_size
>>> validate_batch_size(32, 100)  # Valid case
32
>>> validate_batch_size(0, 100)  # Raises ValueError
>>> validate_batch_size(150, 100)  # Raises ValueError
>>> validate_batch_size(32, 100, max_batch_size=32)  # Valid case
32
>>> validate_batch_size(40, 100, max_batch_size=32)  # Raises ValueError
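To make the rule in the Notes concrete, here is a minimal standalone sketch of the same checks. The name check_batch_size is hypothetical and this is not the geoprior implementation; it only assumes the constraints listed above (positive integer, at least min_batch_size, at most n_samples, and at most max_batch_size when one is given).

```python
from typing import Optional


def check_batch_size(
    batch_size: int,
    n_samples: int,
    min_batch_size: int = 1,
    max_batch_size: Optional[int] = None,
) -> int:
    """Hypothetical sketch of the batch-size rule described above."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise ValueError("batch_size must be an integer.")
    if batch_size < max(1, min_batch_size):
        raise ValueError(
            f"batch_size={batch_size} is below the minimum "
            f"of {max(1, min_batch_size)}."
        )
    if batch_size > n_samples:
        raise ValueError(
            f"batch_size={batch_size} exceeds n_samples={n_samples}."
        )
    # The upper bound only applies when max_batch_size is given.
    if max_batch_size is not None and batch_size > max_batch_size:
        raise ValueError(
            f"batch_size={batch_size} exceeds "
            f"max_batch_size={max_batch_size}."
        )
    return batch_size


print(check_batch_size(32, 100))  # 32
try:
    check_batch_size(40, 100, max_batch_size=32)
except ValueError as exc:
    print("rejected:", exc)
```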
geoprior.utils.validator.validate_comparison_data(df, alignment='auto')[source]

Validates a DataFrame to ensure it is a square matrix and that the index and column names match. Optionally aligns the index names to the column names or vice versa based on the alignment parameter.

Parameters:
  • df (pandas.DataFrame) – The DataFrame to validate.

  • alignment (str, default 'auto') – Controls how the DataFrame’s index and columns are aligned if they do not match. Options are ‘auto’, ‘index_to_columns’, and ‘columns_to_index’.

Returns:

The validated and potentially modified DataFrame.

Return type:

pandas.DataFrame

Raises:

ValueError – If the DataFrame is not square or if index and column names do not match and no suitable alignment option is specified.

Examples

>>> import pandas as pd
>>> from geoprior.utils.validator import validate_comparison_data
>>> data = pd.DataFrame({
...     'A': [1, 0.9, 0.8],
...     'B': [0.9, 1, 0.85],
...     'C': [0.8, 0.85, 1]
... }, index=['A', 'B', 'X'])
>>> print(validate_comparison_data(data, alignment='index_to_columns'))
>>> data = pd.DataFrame({
...     1: [1, 0.9, 0.8],
...     2: [0.9, 1, 0.85],
...     3: [0.8, 0.85, 1]
... }, index=[1, 2, 'X'])
>>> print(validate_comparison_data(data, alignment='auto'))
geoprior.utils.validator.validate_data_types(data, expected_type='numeric', nan_policy='omit', return_data=False, error='raise')[source]

Checks for mixed data types in a pandas Series or DataFrame and handles according to the specified policies. This function is designed to ensure data consistency by verifying that data matches expected type criteria, offering options to manage and report any discrepancies.

Parameters:
  • data (pd.Series or pd.DataFrame) – The data to be checked. This can be a pandas Series or DataFrame.

  • expected_type ({'numeric', 'categoric', 'both'}, default 'numeric') –

    Specifies the type of data expected:

    • ‘numeric’: All data should be of numeric types (int, float).

    • ‘categoric’: All data should be categorical, typically strings or pandas Categorical datatype.

    • ‘both’: Any mix of numeric and categorical data is considered valid.

  • nan_policy ({'raise', 'warn', 'omit', 'propagate'}, default 'omit') –

    Determines how NaN values are handled:

    • ‘raise’: Raises an error if NaN values are found.

    • ‘warn’: Issues a warning if NaN values are found but proceeds.

    • ‘propagate’: Continues execution without addressing NaNs.

  • return_data (bool, default False) – If True, returns a DataFrame or Series (depending on the input) that only includes data rows that conform to the expected_type. If False, returns None.

  • error ({'raise', 'warn'}, default 'raise') –

    Configures the error handling behavior when data types do not conform to the expected_type:

    • ‘raise’: Raises a TypeError if mixed types are detected.

    • ‘warn’: Emits a warning but attempts to continue by filtering non-conforming data if return_data is True.

Returns:

Depending on return_data, this function may return a filtered version of data that conforms to the expected_type or None if return_data is False.

Return type:

pd.Series or pd.DataFrame or None

Raises:
  • ValueError – If NaN values are present and nan_policy is set to ‘raise’.

  • TypeError – If data types do not conform to expected_type and error is set to ‘raise’.

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import validate_data_types
>>> df = pd.DataFrame({'A': [1, 2, 'a', 3.5, np.nan], 'B': ['x', 'y', 'z', None, 't']})
>>> validate_data_types(df, expected_type='numeric', nan_policy='warn',
...                  return_data=True, error='warn')
UserWarning: NaN values found in the data, but processing will continue.
UserWarning: Expected numeric types but found mixed types.
Non-numeric data will be ignored.
   A
0  1.0
1  2.0
3  3.5

Notes

The validate_data_types function is useful in preprocessing, particularly when you need to ensure that data fed into a machine-learning algorithm meets certain type requirements. Handling mixed data types early can prevent issues in model training and evaluation.

geoprior.utils.validator.validate_dates(start_date, end_date, return_as_date_str=False, date_format='%Y-%m-%d')[source]

Validates and parses start and end years/dates, with options for output formatting.

This function ensures the validity of provided start and end years or dates, checks if they fall within a reasonable range, and allows the option to return the validated years or dates in a specified string format.

Parameters:
  • start_date (int, float, or str) – The starting year or date. Can be an integer, float (converted to integer), or string in “YYYY” or “YYYY-MM-DD” format.

  • end_date (int, float, or str) – The ending year or date, with the same format options as start_date.

  • return_as_date_str (bool, optional) – If True, returns the start and end dates as strings in the specified format. Default is False, returning years as integers.

  • date_format (str, optional) – The format string for output dates if return_as_date_str is True. Default format is “%Y-%m-%d”.

Returns:

A tuple of two elements, either integers (years) or strings (formatted dates), representing the validated start and end years or dates.

Return type:

tuple

Raises:

ValueError – If the input years or dates are invalid, out of the acceptable range, or if the start year/date does not precede the end year/date.

Examples

>>> from geoprior.utils.validator import validate_dates
>>> validate_dates(1999, 2001)
(1999, 2001)
>>> validate_dates("1999/01/01", "2001/12/31", return_as_date_str=True)
('1999-01-01', '2001-12-31')
>>> validate_dates("1999", "1998")
ValueError: The start date/time must precede the end date/time.
>>> validate_dates("1899", "2001")
ValueError: Years must be within the valid range: 1900 to [current year].

Notes

The function supports flexible input formats for years and dates, including handling both slash “/” and dash “-” separators in date strings. It enforces logical and chronological order between start and end inputs and allows customization of the output format for date strings.
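The parsing and ordering rules can be illustrated with the standard datetime module. This is a minimal sketch, not the geoprior code: check_dates and parse_bound are hypothetical names, and the lower bound of 1900 and the upper bound of the current year are taken from the example error message above.

```python
from datetime import date, datetime


def parse_bound(value):
    """Turn a year or a 'YYYY', 'YYYY-MM-DD' or 'YYYY/MM/DD' string into a date."""
    if isinstance(value, (int, float)):
        # Numeric input is read as a year; floats are truncated to int.
        return date(int(value), 1, 1)
    # Normalize slash separators to dashes before parsing.
    text = str(value).strip().replace("/", "-")
    fmt = "%Y" if len(text) == 4 else "%Y-%m-%d"
    return datetime.strptime(text, fmt).date()


def check_dates(start, end, as_str=False, fmt="%Y-%m-%d"):
    """Hypothetical sketch of the validation rules listed above."""
    start_d, end_d = parse_bound(start), parse_bound(end)
    current_year = date.today().year
    for d in (start_d, end_d):
        if not 1900 <= d.year <= current_year:
            raise ValueError(
                "Years must be within the valid range: "
                f"1900 to {current_year}."
            )
    if start_d >= end_d:
        raise ValueError(
            "The start date/time must precede the end date/time."
        )
    if as_str:
        return start_d.strftime(fmt), end_d.strftime(fmt)
    return start_d.year, end_d.year


print(check_dates(1999, 2001))  # (1999, 2001)
print(check_dates("1999/01/01", "2001/12/31", as_str=True))
# ('1999-01-01', '2001-12-31')
```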

geoprior.utils.validator.validate_distribution(distribution, elements=None, kind=None, check_normalization=True)[source]

Validates or generates distributions for given elements, ensuring the sum equals 1 if check_normalization is True.

Parameters:
  • distribution (str, tuple, list) – The distribution to be validated or generated. If ‘auto’, generates a random distribution for the specified number of elements. Can also be a tuple or list representing an explicit distribution.

  • elements (int, list of str, optional) – Defines how many elements the distribution should be generated for when ‘auto’ is used. If a list of strings is provided, its length is used to determine the number of elements.

  • kind (str, optional) – Specifies the kind of distribution. Use "probs" for probability distributions, whose values must be non-negative and sum to 1.

  • check_normalization (bool, optional) – If True, ensures that the sum of the distribution equals 1. Default is True.

Returns:

A tuple representing the validated or generated distribution.

Return type:

tuple

Raises:

ValueError – If the provided distribution does not meet the specified conditions.

Examples

>>> from geoprior.utils.validator import validate_distribution
>>> validate_distribution("auto", elements=['positive', 'neutral', 'negative'])  # doctest: +SKIP
(0.1450318690603951, 0.5660028611331361, 0.2889652698064687)
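Because the ‘auto’ output is random, the exact values above vary between calls. Only their sum is fixed. The sketch below shows one way such a generator and check could work. It is a hypothetical standalone helper (make_distribution) drawn with random.random and rescaled by the total. One assumption differs from the documented function: it always rejects negative values rather than tying that rule to kind="probs".

```python
import math
import random
from typing import List, Optional, Sequence, Tuple, Union


def make_distribution(
    distribution: Union[str, Sequence[float]],
    elements: Optional[Union[int, List[str]]] = None,
    check_normalization: bool = True,
) -> Tuple[float, ...]:
    """Hypothetical sketch of 'auto' generation and normalization checks."""
    if distribution == "auto":
        if elements is None:
            raise ValueError("elements is required when distribution='auto'.")
        n = len(elements) if isinstance(elements, list) else int(elements)
        raw = [random.random() for _ in range(n)]
        total = sum(raw)
        # Dividing by the total rescales the draws so they sum to 1.
        values = tuple(v / total for v in raw)
    else:
        values = tuple(float(v) for v in distribution)
    # Assumption: negatives are always rejected (the documented function
    # ties this rule to kind="probs").
    if any(v < 0 for v in values):
        raise ValueError("Distribution values must be non-negative.")
    # isclose tolerates rounding, e.g. a sum of 0.9999999999999999.
    if check_normalization and not math.isclose(sum(values), 1.0):
        raise ValueError(f"Distribution must sum to 1, got {sum(values)}.")
    return values


print(make_distribution([0.2, 0.3, 0.5]))  # (0.2, 0.3, 0.5)
```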
geoprior.utils.validator.validate_dtype_selector(dtype_selector)[source]

Validates and categorizes the dtype_selector using regex, including handling cases where ‘only’ is specifically included.

Parameters:

dtype_selector (str) – Input dtype selector string.

Returns:

Categorized dtype_selector based on predefined patterns. If "only" is included, the returned category reflects this so it can drive specific data-type handling.

Return type:

str

Raises:

ValueError – If the input dtype_selector does not match any predefined category.

geoprior.utils.validator.validate_estimator_methods(estimator, methods, msg=None)[source]

Validate that the specified methods exist and are callable on the given estimator.

This utility function is designed to check whether an estimator (or any object) contains the required methods, such as fit or predict, and ensures that those methods are callable. It helps prevent runtime errors by verifying the presence of expected methods.

Parameters:
  • estimator (object) – The object (instance or class) to check for the presence of the specified methods. The estimator can be an instance of a class or the class itself, and it should implement the required methods.

  • methods (list of str) – List of method names (as strings) to validate. Each method name must exist on the estimator and be callable. Examples of methods might include fit, run, or predict.

  • msg (str, optional) – Custom error message to display if any method is missing or not callable. If None, a default message is generated for each missing or invalid method based on the method name.

Raises:

AttributeError – If any method in methods is not present or not callable on the estimator.

Examples

>>> from geoprior.utils.validator import validate_estimator_methods
>>> class MyClass:
...     def fit(self):
...         pass
...     def run(self):
...         pass
>>> validate_estimator_methods(MyClass(), ['fit', 'run'])  # No error
>>> class IncompleteClass:
...     def fit(self):
...         pass
>>> # Raises AttributeError because `run` is missing:
>>> validate_estimator_methods(IncompleteClass(), ['fit', 'run'])

Notes

This helper is useful when you want to ensure that an object, such as an estimator or a model, exposes several callable methods before proceeding. If any method is missing or not callable, the function raises an AttributeError. Method-callability checks follow the Python documentation and the callable-object discussion in [43, 49].
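The core check can be expressed with the built-ins getattr and callable. The sketch below uses a hypothetical name (require_methods) and a hypothetical message format. It is meant to show the technique, not the exact geoprior behaviour or error text.

```python
from typing import Iterable, Optional


def require_methods(
    obj: object, methods: Iterable[str], msg: Optional[str] = None
) -> None:
    """Hypothetical sketch: raise unless every named method is callable."""
    # Works for both classes and instances.
    owner = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    for name in methods:
        # getattr with a default maps "missing" to None, so a missing
        # attribute and a non-callable attribute share the same branch.
        if not callable(getattr(obj, name, None)):
            raise AttributeError(
                msg or f"{owner} must implement a callable '{name}' method."
            )


class Incomplete:
    def fit(self):
        pass

    run = "not a method"


require_methods(Incomplete(), ["fit"])  # passes silently
try:
    require_methods(Incomplete, ["fit", "run"])
except AttributeError as exc:
    print(exc)  # Incomplete must implement a callable 'run' method.
```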

See also

check_has_run_method

Validate the presence of a single method, defaulting to run.

geoprior.utils.validator.validate_fit_weights(y, sample_weight=None, weighted_y=False)[source]

Validate and compute sample weights for fitting.

Parameters:
  • y (array-like of shape (n_samples,)) – Target values.

  • sample_weight (array-like of shape (n_samples,), default None) – Sample weights. If None, then samples are equally weighted.

  • weighted_y (bool, default False) – If True, compute the weighted target values.

Returns:

  • sample_weight (array-like of shape (n_samples,)) – Validated sample weights.

  • weighted_y_values (array-like of shape (n_samples,), optional) – Weighted target values if weighted_y is True.

Raises:

ValueError – If sample_weight is not None and its length does not match the length of y, or if any value in sample_weight is negative.

Notes

This function checks the input sample weights, ensuring they are consistent with the target values y. If sample_weight is None, it returns an array of ones indicating equal weighting. Otherwise, it validates and returns the given sample weights. If weighted_y is True, it also computes and returns the weighted target values.
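The logic in these Notes reduces to a few lines. The sketch below is list-based for portability. It uses a hypothetical name (fit_weights) and returns Python lists where the documented function returns arrays. The weighted targets are the element-wise product of each weight with its target value.

```python
from typing import List, Optional, Sequence, Tuple, Union


def fit_weights(
    y: Sequence[float],
    sample_weight: Optional[Sequence[float]] = None,
    weighted_y: bool = False,
) -> Union[List[float], Tuple[List[float], List[float]]]:
    """Hypothetical list-based sketch of the behaviour described above."""
    if sample_weight is None:
        # Equal weighting when no weights are given.
        weights = [1.0] * len(y)
    else:
        weights = [float(w) for w in sample_weight]
        if len(weights) != len(y):
            raise ValueError("sample_weight and y must have the same length.")
        if any(w < 0 for w in weights):
            raise ValueError("sample_weight must be non-negative.")
    if weighted_y:
        # Element-wise product of each weight with its target value.
        return weights, [w * v for w, v in zip(weights, y)]
    return weights


y = [0, 1, 1, 0, 1]
print(fit_weights(y, [1, 0.5, 1, 1.5, 1], weighted_y=True))
# ([1.0, 0.5, 1.0, 1.5, 1.0], [0.0, 0.5, 1.0, 0.0, 1.0])
```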

Examples

>>> import numpy as np
>>> from geoprior.utils.validator import validate_fit_weights
>>> y = np.array([0, 1, 1, 0, 1])
>>> validate_fit_weights(y)
array([1., 1., 1., 1., 1.])
>>> sample_weight = np.array([1, 0.5, 1, 1.5, 1])
>>> validate_fit_weights(y, sample_weight)
array([1. , 0.5, 1. , 1.5, 1. ])
>>> validate_fit_weights(y, sample_weight, weighted_y=True)
(array([1. , 0.5, 1. , 1.5, 1. ]), array([0. , 0.5, 1. , 0. , 1. ]))
>>> validate_fit_weights(y, weighted_y=True)
(array([1., 1., 1., 1., 1.]), array([0., 1., 1., 0., 1.]))
geoprior.utils.validator.validate_length_range(length_range, sorted_values=True, param_name=None)[source]

Validates a length range, ensuring it is a tuple of two integers whose first value is less than the second.

Parameters:
  • length_range (tuple) – A tuple of two values representing the minimum and maximum lengths.

  • sorted_values (bool, default True) – If True, the function expects the input length range to be sorted in ascending order and will automatically sort it if not. If False, the input length range is not expected to be sorted, and it will remain as provided.

  • param_name (str, optional) – The name of the parameter being validated. If None, the default name ‘length_range’ will be used in error messages.

Returns:

The validated length range.

Return type:

tuple

Raises:

ValueError – If the length range does not meet the requirements.

Examples

>>> from geoprior.utils.validator import validate_length_range
>>> validate_length_range((202, 25))
(25, 202)
>>> validate_length_range((202,))
ValueError: length_range must be a tuple with two elements.
geoprior.utils.validator.validate_multiclass_target(y, accept_multioutput=False, return_classes=False)[source]

Validates that the target data is suitable for multiclass classification. Optionally accepts multi-output targets and can return the unique classes.

Parameters:
  • y (array-like) – The target data to be validated, expected to contain class labels for multiclass classification. Can be a multi-output array if accept_multioutput is set to True.

  • accept_multioutput (bool, optional) – Allows the target array to be multi-dimensional (default is False).

  • return_classes (bool, optional) – If True, returns the unique classes instead of a validation boolean.

Returns:

If return_classes is False, returns True if the target data is valid for multiclass classification, otherwise raises a ValueError. If return_classes is True, returns the unique classes in the target data.

Return type:

bool or array

Raises:

ValueError – If any of the following conditions is not met:

  • If accept_multioutput is False, the target data must be one-dimensional.

  • All elements in the target array must be non-negative integers.

  • The target array must contain at least two distinct classes.

Examples

>>> from geoprior.utils.validator import validate_multiclass_target
>>> validate_multiclass_target([0, 1, 2, 1, 0])
array([0, 1, 2, 1, 0])
>>> validate_multiclass_target([0, 0, 0])
ValueError: Target array must contain at least two distinct classes.
>>> validate_multiclass_target([0.5, 1.2, 2.3])
ValueError: All elements in the target array must be non-negative integers.
>>> validate_multiclass_target([[1, 2], [2, 3]], accept_multioutput=True,
...                              return_classes=True)
(array([1, 2, 2, 3]), 3)
True
geoprior.utils.validator.validate_multioutput(value, extra='')[source]

Validate the multioutput parameter value and handle special cases.

This function checks whether the provided multioutput value is one of the accepted strings (‘raw_values’, ‘uniform_average’, ‘raise’, ‘warn’). For ‘warn’ it emits a warning, and for ‘raise’ it raises an error, stating that the multioutput parameter is not applicable in the current context.

Parameters:
  • value (str) – The value of the multioutput parameter to be validated. Accepted values are ‘raw_values’, ‘uniform_average’, ‘raise’, ‘warn’.

  • extra (str, optional) – Additional text to include in the warning or error message if multioutput is not applicable.

Returns:

The validated multioutput value in lowercase if it’s one of the accepted values. If the value is ‘warn’ or ‘raise’, the function handles the case accordingly without returning a value.

Return type:

str

Raises:

ValueError – If value is ‘raise’, or if value is not one of the accepted strings.

Examples

>>> from geoprior.utils.validator import validate_multioutput
>>> validate_multioutput('raw_values')
'raw_values'
>>> # Warns that the multioutput parameter is not applicable for the
>>> # Dice Similarity Coefficient.
>>> validate_multioutput('warn', extra=' for Dice Similarity Coefficient')
>>> # Raises a ValueError stating that the multioutput parameter is not
>>> # applicable for the Gini Coefficient.
>>> validate_multioutput('raise', extra=' for Gini Coefficient')
>>> # Raises a ValueError because 'average' is not a valid value for
>>> # the multioutput parameter.
>>> validate_multioutput('average')

Note

The function is designed to ensure API consistency across various metrics functions by providing a standard way to handle multioutput parameter values, especially in contexts where multiple outputs are not applicable.

geoprior.utils.validator.validate_nan_policy(nan_policy, *arrays, sample_weights=None)[source]

Validates and applies a specified nan_policy to input arrays and optionally to sample weights. This utility is essential for pre-processing data prior to statistical analyses or model training, where appropriate handling of NaN values is critical to ensure accurate and reliable outcomes.

Parameters:
  • nan_policy ({'propagate', 'raise', 'omit'}) – Defines how to handle NaNs in the input arrays. ‘propagate’ returns the input data without changes. ‘raise’ throws an error if NaNs are detected. ‘omit’ removes rows with NaNs across all input arrays and sample weights.

  • *arrays (array-like) – Variable number of input arrays to be validated and adjusted based on the specified nan_policy.

  • sample_weights (array-like, optional) – Sample weights array to be validated and adjusted in tandem with the input arrays according to nan_policy. Defaults to None.

Returns:

  • arrays (tuple of np.ndarray) – Adjusted input arrays, with modifications applied based on nan_policy. The order of arrays in the tuple corresponds to the order of input.

  • sample_weights (np.ndarray or None) – Adjusted sample weights, modified according to nan_policy if provided. Returns None if no sample_weights were provided.

Raises:

ValueError – If nan_policy is not among the valid options (‘propagate’, ‘raise’, ‘omit’) or if NaNs are detected when nan_policy is set to ‘raise’.

Notes

Handling NaN values is a critical preprocessing step for datasets with missing values. The choice of nan_policy changes what downstream statistics or models see: observations with missing values are kept (‘propagate’), dropped (‘omit’), or rejected with an error (‘raise’). This function applies the chosen policy consistently across all input arrays and the optional sample weights, so they stay aligned row by row.

Examples

>>> import numpy as np
>>> from geoprior.utils.validator import validate_nan_policy
>>> y_true = np.array([1, np.nan, 3])
>>> y_pred = np.array([1, 2, 3])
>>> sample_weights = np.array([0.5, 0.5, 1.0])
>>> arrays, sw = validate_nan_policy('omit', y_true, y_pred,
...                                  sample_weights=sample_weights)
>>> arrays
(array([1., 3.]), array([1., 3.]))
>>> sw
array([0.5, 1. ])
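The ‘omit’ policy is the subtle case. A row must be dropped from every array, and from the weights, whenever any one of them is NaN at that position, which keeps all outputs aligned. The following is a stdlib-only sketch of that idea. It uses a hypothetical name (apply_nan_policy) and returns lists of floats rather than NumPy arrays.

```python
import math
from typing import List, Optional, Sequence, Tuple


def apply_nan_policy(
    nan_policy: str,
    *arrays: Sequence[float],
    sample_weights: Optional[Sequence[float]] = None,
) -> Tuple[Tuple[List[float], ...], Optional[List[float]]]:
    """Hypothetical sketch of 'propagate', 'raise' and 'omit' handling."""
    if nan_policy not in {"propagate", "raise", "omit"}:
        raise ValueError(f"Invalid nan_policy {nan_policy!r}.")
    cols = [[float(v) for v in a] for a in arrays]
    weights = (
        None if sample_weights is None
        else [float(w) for w in sample_weights]
    )
    if nan_policy == "propagate":
        return tuple(cols), weights
    # A row is dropped if ANY array (or the weights) is NaN there, which
    # keeps all outputs aligned row by row.
    checked = cols + ([weights] if weights is not None else [])
    bad = {i for col in checked for i, v in enumerate(col) if math.isnan(v)}
    if bad and nan_policy == "raise":
        raise ValueError(f"NaN values found at rows {sorted(bad)}.")

    def keep(col: List[float]) -> List[float]:
        return [v for i, v in enumerate(col) if i not in bad]

    return (
        tuple(keep(c) for c in cols),
        None if weights is None else keep(weights),
    )


nan = float("nan")
print(apply_nan_policy("omit", [1, nan, 3], [1, 2, 3],
                       sample_weights=[0.5, 0.5, 1.0]))
# (([1.0, 3.0], [1.0, 3.0]), [0.5, 1.0])
```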
geoprior.utils.validator.validate_numeric(value, convert_to='float', allow_negative=True, min_value=None, max_value=None, check_mode='soft')[source]

Validates if a given value is numeric. It can accept numeric strings and numpy arrays of single values. Optionally converts the value to either float or integer.

Parameters:
  • value (Any) – The value to be validated as numeric. This can be of any type but is expected to be convertible to a numeric type. Accepted types include numeric strings (e.g., "42"), single-element numpy arrays (e.g., np.array([3.14])), integers, and floats.

  • convert_to (str, optional) – Type to convert the validated numeric value to. Use "float" for floating-point output or "int" for integer output. Defaults to "float".

  • allow_negative (bool, optional) – Whether to allow negative values. If False, negative values raise a ValueError. Defaults to True.

  • min_value (float or int, optional) – The minimum value allowed. If None, no minimum value check is applied. Defaults to None.

  • max_value (float or int, optional) – The maximum value allowed. If None, no maximum value check is applied. Defaults to None.

  • check_mode (str, optional) – Validation mode. Use "soft" to accept single-element iterables and validate their single value, or "strict" to accept only non-iterable numeric inputs. Defaults to "soft".

Returns:

The validated and optionally converted numeric value. The type of the return value is determined by the convert_to parameter.

Return type:

float or int

Raises:

ValueError – If the value is not numeric or does not meet the specified criteria.

Notes

The function can coerce single-element NumPy arrays, numeric strings, and, in soft mode, single-element iterables before validating the result. The validated value is then converted to float or int and checked against the sign and range constraints. Array coercion details are documented in NumPy developers [50].
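The coercion order described here (unwrap a single-element container in soft mode, convert to a number, then check sign and range) can be sketched without NumPy. The helper name to_number is hypothetical. The sketch treats any non-string iterable as a container, so in this simplified version single-element arrays are unwrapped only in soft mode. It is an illustration, not the geoprior implementation.

```python
from collections.abc import Iterable
from typing import Any, Optional, Union


def to_number(
    value: Any,
    convert_to: str = "float",
    allow_negative: bool = True,
    min_value: Optional[float] = None,
    max_value: Optional[float] = None,
    check_mode: str = "soft",
) -> Union[int, float]:
    """Hypothetical stdlib-only sketch of the coercion steps above."""
    # Strings are iterable too, so only unwrap non-string containers.
    if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
        items = list(value)
        if check_mode != "soft" or len(items) != 1:
            raise ValueError(f"Value '{value}' is not a numeric type.")
        value = items[0]
    try:
        number = float(value)
    except (TypeError, ValueError):
        raise ValueError(f"Value '{value}' is not a numeric type.") from None
    if not allow_negative and number < 0:
        raise ValueError(f"Negative values are not allowed: {number}")
    if min_value is not None and number < min_value:
        raise ValueError(f"{number} is below min_value={min_value}.")
    if max_value is not None and number > max_value:
        raise ValueError(f"{number} is above max_value={max_value}.")
    return int(number) if convert_to == "int" else number


print(to_number("42", convert_to="int"))  # 42
print(to_number([123]))                   # 123.0
```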

Examples

>>> from geoprior.utils.validator import validate_numeric
>>> validate_numeric("42", convert_to='int')
42
>>> validate_numeric(np.array([3.14]), convert_to='float')
3.14
>>> validate_numeric([123], check_mode='soft')
123.0
>>> validate_numeric([123], check_mode='strict')
Traceback (most recent call last):
    ...
ValueError: Value '[123]' is not a numeric type.
>>> validate_numeric("-123.45", allow_negative=False)
Traceback (most recent call last):
    ...
ValueError: Negative values are not allowed: -123.45

See also

numpy.array

Numpy arrays, which can be validated by this function.

geoprior.utils.validator.validate_performance_data(model_performance_data=None, nan_policy='raise', convert_integers=True, check_performance_range=True, verbose=False)[source]

Validates and preprocesses model performance data to ensure it conforms to the necessary structure and constraints for statistical and machine learning analysis. The function accepts either a dictionary or a DataFrame as input and performs the following tasks:

  1. Converts data to a DataFrame if it is provided as a dictionary.

  2. Converts integer values to floats, ensuring compatibility with statistical processing.

  3. Manages NaN values according to the specified nan_policy.

  4. Validates that performance data falls within a valid range, ensuring values lie within [0, 1].

The function is adaptable, capable of being used directly or as a decorator, with or without configuration parameters.

Parameters:
  • model_performance_data (Union[Dict[str, List[float]], pd.DataFrame], optional) – The input model performance data to validate. Can be provided as either a dictionary (with model names as keys and performance metrics as lists) or a DataFrame where each column represents a model.

  • nan_policy (str, default 'raise') –

    The policy to handle NaN values:

    • ‘raise’: Raises a ValueError if NaNs are detected.

    • ‘omit’: Drops rows with NaNs.

    • ‘propagate’: Ignores NaNs during performance range checks.

  • convert_integers (bool, default True) – Converts integer values within the data to floats if set to True, which is useful for consistency when computing metrics.

  • check_performance_range (bool, default True) – Ensures that performance values lie within the range [0, 1]. If any value falls outside this range, an error is raised unless nan_policy is set to ‘propagate’.

  • verbose (bool, default False) – If True, displays steps of the data validation process for tracking operations and debugging.

geoprior.utils.validator.actual_validate_performance_data(data)

Validates and processes the data according to specified policies and constraints.

Usage

This function can be used in three ways:

  1. As a function: pass the data directly to validate it.

>>> from geoprior.utils.validator import validate_performance_data
>>> data = {'model1': [0.85, 0.90, 0.92], 'model2': [0.80, 0.87, 0.88]}
>>> validate_performance_data(data)

  2. As a decorator: validate the first argument of the decorated function. When used without parentheses, the default parameter values apply.

>>> @validate_performance_data
... def process_data(validated_data):
...     print(validated_data)

  3. As a decorator with parameters: customize validation by passing keyword arguments.

>>> @validate_performance_data(nan_policy='omit', verbose=True)
... def process_data(validated_data):
...     print(validated_data)

Notes

The validation process includes statistical pre-checks, using custom modules to convert data and handle NaNs. For integer-to-float conversion, the convert_to_numeric function is utilized, while NaN policies are verified using is_valid_policies. The comparison framing for multiple models follows Demšar [51].
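Supporting all three usage forms with one name is a standard Python pattern: inspect the first argument, and treat a callable as the function being decorated, a non-None value as data, and None as a request for a configured decorator. The sketch below illustrates that pattern with a hypothetical performance_validator. It implements only integer-to-float conversion and the [0, 1] range check, not the full set of NaN policies.

```python
import functools
from typing import Any, Callable, Dict, List


def performance_validator(arg: Any = None, *, check_range: bool = True):
    """Hypothetical sketch of a validator usable three ways."""

    def validate(data: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Integers become floats; optionally enforce the [0, 1] range.
        out = {name: [float(v) for v in vals] for name, vals in data.items()}
        if check_range and any(
            not 0.0 <= v <= 1.0 for vals in out.values() for v in vals
        ):
            raise ValueError("Performance values must lie in [0, 1].")
        return out

    def decorate(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(data, *args, **kwargs):
            # Only the first positional argument is validated.
            return func(validate(data), *args, **kwargs)

        return wrapper

    if callable(arg):      # bare @performance_validator
        return decorate(arg)
    if arg is not None:    # direct call: performance_validator(data)
        return validate(arg)
    return decorate        # @performance_validator(check_range=False)


print(performance_validator({"model1": [1, 0.9]}))
# {'model1': [1.0, 0.9]}


@performance_validator
def summarize(validated):
    return {name: max(vals) for name, vals in validated.items()}


print(summarize({"model1": [0.85, 0.90], "model2": [0.80, 0.87]}))
# {'model1': 0.9, 'model2': 0.87}
```

Checking callable(arg) first is what lets one name act as both a function and a decorator. The trade-off is that a callable data object would be mistaken for a function to decorate.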

See also

DataFrameFormatter

Formatter for handling DataFrame structures.

MultiFrameFormatter

Formatter for handling multiple DataFrames.

geoprior.utils.validator.validate_positive_integer(value, variable_name, include_zero=False, round_float=None, msg=None)[source]

Validates whether the given value is a positive integer (or zero, when include_zero is True) and rounds float values according to round_float.

Parameters:
  • value (int or float) – The value to validate.

  • variable_name (str) – The name of the variable for error message purposes.

  • include_zero (bool, optional) – If True, zero is considered a valid value. Default is False.

  • round_float (str, optional) – If “ceil”, rounds float values up; if “floor”, rounds them down; if None, truncates them toward zero.

  • msg (str, optional) – Custom error message used when the type check fails.

Returns:

The validated value converted to an integer.

Return type:

int

Raises:

ValueError – If the value is not a positive integer or zero (based on include_zero), or if the round_float parameter is improperly specified.
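The three rounding modes map directly onto math.ceil, math.floor and int() (which truncates toward zero). The following minimal sketch uses a hypothetical to_positive_int. Its error messages are illustrative and not taken from geoprior.

```python
import math
from typing import Optional, Union


def to_positive_int(
    value: Union[int, float],
    variable_name: str,
    include_zero: bool = False,
    round_float: Optional[str] = None,
    msg: Optional[str] = None,
) -> int:
    """Hypothetical sketch of the rounding and sign rules above."""
    if round_float not in (None, "ceil", "floor"):
        raise ValueError("round_float must be None, 'ceil' or 'floor'.")
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(
            msg or f"{variable_name} must be a number, "
            f"got {type(value).__name__}."
        )
    if isinstance(value, float):
        if round_float == "ceil":
            value = math.ceil(value)
        elif round_float == "floor":
            value = math.floor(value)
        else:
            value = int(value)  # int() truncates toward zero
    lower = 0 if include_zero else 1
    if value < lower:
        raise ValueError(f"{variable_name} must be >= {lower}, got {value}.")
    return value


print(to_positive_int(2.7, "n_epochs"))                      # 2
print(to_positive_int(2.1, "n_epochs", round_float="ceil"))  # 3
```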

geoprior.utils.validator.validate_sample_weights(weights, y, normalize=False)[source]

Validates that the sample weights are suitable for use in calculations.

This function checks that the sample weights are non-negative and match the length of the target array y. It raises an error if any conditions are not met. If a single number is provided as weights, it will be converted into an array with repeated values matching the length of y.

Parameters:
  • weights (array-like or number) – The sample weights to be validated. Each weight must be non-negative. A single number will be converted to an array with repeated values.

  • y (array-like) – The target array that the weights should correspond to. The length of weights must match the length of y.

  • normalize (bool, optional) – If True, weights will be normalized to sum to 1. Default is False.

Returns:

The validated sample weights as a numpy array.

Return type:

numpy.ndarray

Raises:

ValueError – If weights are not one-dimensional, if any weight is negative, or if the length of weights does not match the length of y.

Examples

>>> from geoprior.utils.validator import validate_sample_weights
>>> y = [0, 1, 2, 3]
>>> weights = [0.1, 0.2, 0.3, 0.4]
>>> validate_sample_weights(weights, y)
array([0.1, 0.2, 0.3, 0.4])
>>> weights = [-0.1, 0.2, 0.3, 0.4]
>>> validate_sample_weights(weights, y)
ValueError: Sample weights must be non-negative.
>>> weights = [0.1, 0.2, 0.3]
>>> validate_sample_weights(weights, y)
ValueError: Length of sample weights must match length of y.
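The checks above (scalar broadcasting, one-dimensionality, length match, non-negativity and optional normalization) can be sketched in plain Python. check_sample_weights is a hypothetical helper and returns a list rather than a NumPy array. Its behaviour for weights that sum to zero is an assumption, because the documentation does not cover that case.

```python
from numbers import Real
from typing import List, Sequence, Union


def check_sample_weights(
    weights: Union[Real, Sequence[Real]],
    y: Sequence,
    normalize: bool = False,
) -> List[float]:
    """Hypothetical list-based sketch of the checks listed above."""
    if isinstance(weights, Real):
        # A single number is repeated once per target value.
        values = [float(weights)] * len(y)
    else:
        if any(isinstance(w, (list, tuple)) for w in weights):
            raise ValueError("Sample weights must be one-dimensional.")
        values = [float(w) for w in weights]
    if len(values) != len(y):
        raise ValueError("Length of sample weights must match length of y.")
    if any(w < 0 for w in values):
        raise ValueError("Sample weights must be non-negative.")
    if normalize:
        total = sum(values)
        # Assumption: an all-zero weight vector cannot be normalized.
        if total == 0:
            raise ValueError("Cannot normalize weights that sum to zero.")
        values = [w / total for w in values]
    return values


print(check_sample_weights(2, [0, 1, 2, 3], normalize=True))
# [0.25, 0.25, 0.25, 0.25]
```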
geoprior.utils.validator.validate_sets(data, mode='base', allow_empty=True, element_type=None, key_type=<class 'str'>)[source]

Validates whether the input data is a set in ‘base’ mode or a dictionary of sets in ‘deep’ mode. Provides additional parameters for flexibility and versatility. Returns the data if it passes validation.

Parameters:
  • data (Union[set, Dict[str, set]]) –

    The input data to validate. It can be either a single set or a dictionary where keys are set names and values are sets.

    • base mode : A single set.

    • deep mode : A dictionary of sets.

  • mode (str, optional) – The mode in which to validate the data. Options are ‘base’ for a single set and ‘deep’ for a dictionary of sets. Default is ‘base’.

  • allow_empty (bool, optional) – Whether to allow empty sets or dictionaries. Default is True.

  • element_type (type, optional) – The expected type of elements in the set(s). If provided, the function checks whether all elements are of this type. Default is None (no type check).

  • key_type (type, optional) – The expected type of keys in the dictionary when in ‘deep’ mode. Default is str.

Returns:

The original data if it matches the specified mode and additional criteria. Raises ValueError if validation fails.

Return type:

Union[set, Dict[str, set]]

Examples

>>> from geoprior.utils.validator import validate_sets
>>> validate_sets({1, 2, 3}, mode='base')
{1, 2, 3}
>>> validate_sets({"Set1": {1, 2, 3}, "Set2": {3, 4, 5}}, mode='deep')
{'Set1': {1, 2, 3}, 'Set2': {3, 4, 5}}
>>> validate_sets({"Set1": {1, 2, 3}, "Set2": [3, 4, 5]}, mode='deep')
Traceback (most recent call last):
    ...
ValueError: Data validation failed: expected all values to be sets
>>> validate_sets(set(), mode='base', allow_empty=False)
Traceback (most recent call last):
    ...
ValueError: Data validation failed: empty set is not allowed
>>> validate_sets({"Set1": set()}, mode='deep', allow_empty=False)
Traceback (most recent call last):
    ...
ValueError: Data validation failed: empty dictionary is not allowed
>>> validate_sets({"Set1": {1, 2, 3}}, mode='deep', element_type=int)
{'Set1': {1, 2, 3}}

Notes

This function checks the type of the input data based on the specified mode. In ‘base’ mode, it ensures the data is a set. In ‘deep’ mode, it ensures the data is a dictionary where all values are sets. Additional parameters allow for checking if sets are empty, if elements are of a specific type, and if dictionary keys are of a specific type. The core type test used here is documented in Python Software Foundation [52].
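The mode-dependent isinstance checks described in these Notes can be sketched as follows. check_sets is a hypothetical helper, and its error messages only approximate those shown in the examples.

```python
from typing import Dict, Optional, Union


def check_sets(
    data: Union[set, Dict[str, set]],
    mode: str = "base",
    allow_empty: bool = True,
    element_type: Optional[type] = None,
    key_type: type = str,
):
    """Hypothetical sketch of the 'base' and 'deep' checks above."""
    if mode == "base":
        if not isinstance(data, set):
            raise ValueError("Data validation failed: expected a set")
        groups = [data]
    elif mode == "deep":
        if not isinstance(data, dict):
            raise ValueError("Data validation failed: expected a dict")
        if not allow_empty and not data:
            raise ValueError(
                "Data validation failed: empty dictionary is not allowed"
            )
        if not all(isinstance(k, key_type) for k in data):
            raise ValueError(
                f"Data validation failed: keys must be {key_type.__name__}"
            )
        if not all(isinstance(v, set) for v in data.values()):
            raise ValueError(
                "Data validation failed: expected all values to be sets"
            )
        groups = list(data.values())
    else:
        raise ValueError("mode must be 'base' or 'deep'.")
    # Emptiness and element-type checks apply to every set found above.
    for group in groups:
        if not allow_empty and not group:
            raise ValueError(
                "Data validation failed: empty set is not allowed"
            )
        if element_type is not None and not all(
            isinstance(e, element_type) for e in group
        ):
            raise ValueError(
                "Data validation failed: elements must be "
                f"{element_type.__name__}"
            )
    return data


print(check_sets({"Set1": {1, 2, 3}}, mode="deep", element_type=int))
# {'Set1': {1, 2, 3}}
```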

See also

isinstance

Python built-in function to check an object’s type.

geoprior.utils.validator.validate_strategy(strategy=None, error='raise', ops='validate', rename_key=False, **kwargs)[source]

Validate and construct a strategy dictionary for imputing missing data.

This function processes the input strategy to ensure it conforms to the expected format for imputing missing values in numerical and categorical features. It provides flexibility in handling different strategies and error management, making it suitable for integration with scikit-learn’s imputation tools.

Parameters:
  • strategy (Optional[Union[str, Dict[str, str]]], default None) – Defines the imputation strategy for numerical and categorical features. A string is parsed into a dictionary with keys "numeric" and "categorical", a dictionary is used directly, and None selects the default strategy.

  • error (str, default 'raise') – Error handling behavior for invalid strategy tokens. Use "raise" to raise a ValueError, "warn" to emit a warning, or "ignore" to skip invalid tokens silently.

  • ops (str, default 'validate') – Operation mode of the validator. Use "passthrough" to return the input strategy unchanged when it is already a dictionary, "check_only" to validate without modifying it, or "validate" to validate and construct the strategy dictionary from the input.

  • rename_key (bool, default False) – If True, rename aliases such as "num", "numeric", or "numerical" to "numeric", and aliases such as "cat", "categorical", or "categoric" to "categorical". Other keys remain unchanged.

  • **kwargs – Additional keyword arguments for future extensions.

Returns:

Returns the input strategy dictionary for ops='passthrough', True or False for ops='check_only', and the validated or modified strategy dictionary for ops='validate'.

Return type:

Union[Dict[str, str], bool]

Raises:

ValueError – If an invalid error or ops parameter is provided, or if the strategy tokens are invalid and error is set to ‘raise’.

Notes

The function limits numerical strategies to "median" and "mean" while categorical strategies default to "constant". It also handles key aliasing to keep the returned dictionary consistent.
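The string parsing and key aliasing described above can be sketched in plain Python. This is an illustrative sketch, not the geoprior implementation: `parse_strategy` is a hypothetical name, it ignores the `error` and `ops` modes, and falling back to the defaults on an unrecognised numeric token is an assumption based on the `error='warn'` example below.

```python
# Hypothetical sketch of the parsing/aliasing rules described for
# validate_strategy; it is not the geoprior implementation.
NUMERIC_ALIASES = {"num", "numeric", "numerical"}
CATEGORICAL_ALIASES = {"cat", "categorical", "categoric"}
NUMERIC_CHOICES = {"median", "mean"}
DEFAULT = {"numeric": "median", "categorical": "constant"}


def parse_strategy(strategy=None, rename_key=False):
    """Return a {'numeric': ..., 'categorical': ...} strategy dict."""
    if strategy is None:
        return dict(DEFAULT)
    if isinstance(strategy, str):
        tokens = strategy.split()
        # The first token is read as the numeric strategy; unknown tokens
        # fall back to the default (assumption).
        numeric = tokens[0] if tokens and tokens[0] in NUMERIC_CHOICES else None
        # Categorical features always use "constant" here.
        return {"numeric": numeric or DEFAULT["numeric"],
                "categorical": DEFAULT["categorical"]}
    result = {}
    for key, value in strategy.items():
        if rename_key and key in NUMERIC_ALIASES:
            key = "numeric"
        elif rename_key and key in CATEGORICAL_ALIASES:
            key = "categorical"
        result[key] = value
    return result
```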

Examples

>>> from geoprior.utils.validator import validate_strategy
>>> validate_strategy('mean constant')
{'numeric': 'mean', 'categorical': 'constant'}
>>> validate_strategy({'num': 'mean', 'cat': 'constant'}, rename_key=True)
{'numeric': 'mean', 'categorical': 'constant'}
>>> validate_strategy('numeric categorical', ops='check_only')
False
>>> validate_strategy('invalid_strategy', error='warn')
{'numeric': 'median', 'categorical': 'constant'}

See also

sklearn.impute.SimpleImputer

Imputation transformer for completing missing values.

geoprior.utils.validator.validate_scores(scores, true_labels=None, mode='strict', accept_multi_output=False)[source]

Validates that the scores represent valid probability distributions and checks consistency between scores and true labels in multi-output scenarios.

Parameters:
  • scores (list or np.ndarray) – A list of np.ndarrays for multi-output probabilities, or a single np.ndarray for single-output probabilities. Each ndarray should contain probability distributions where each row sums to approximately 1 and has non-negative values.

  • true_labels (list or np.ndarray, optional) – The true labels corresponding to the scores. This parameter must be provided in multi-output scenarios to check the alignment of labels and scores. Each element or row in true_labels should correspond to the equivalent in scores.

  • mode (str, optional (default "strict")) – Validation mode for checking probability distributions. Use "strict" to require each row to sum to 1 within numerical tolerance, "soft" to require non-negative scores with totals no greater than 1, or "passthrough" to only check that each score lies in the interval [0, 1].

  • accept_multi_output (bool, default False) – Flag indicating whether scores with multiple outputs are accepted. If False and scores are provided as a list, a ValueError will be raised.

Returns:

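The validated scores: a NumPy array for single-output input, or a list of NumPy arrays when multi-output scores are accepted (see the multi-output example below).

Return type:

np.ndarray or list of np.ndarray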

Raises:

ValueError – If multi-output scores are provided and not accepted. If there is a mismatch in the number of outputs between scores and true_labels. If scores or any subset of scores do not form valid probability distributions. If there is a mismatch in format expectations between scores and true_labels in terms of multi-output handling.

Notes

The function is designed to handle both single and multi-output probability distributions. For multi-output scenarios, both scores and true_labels should be lists of np.ndarrays. This function is particularly useful in scenarios involving machine learning models where output probabilities need to be validated before further processing or metrics calculations.
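The three `mode` rules can be expressed as a short NumPy check. `check_probabilities` is a hypothetical helper, not part of geoprior; it assumes each probability distribution lies along the last axis and uses an arbitrary tolerance of `1e-6`.

```python
import numpy as np


def check_probabilities(scores, mode="strict", atol=1e-6):
    """Return True if ``scores`` passes the given validation mode (sketch)."""
    scores = np.asarray(scores, dtype=float)
    if mode == "passthrough":
        # Only require every score to lie in [0, 1].
        return bool(np.all((scores >= 0) & (scores <= 1)))
    if np.any(scores < 0):
        return False
    row_sums = scores.sum(axis=-1)
    if mode == "strict":
        # Each row must sum to 1 within a numerical tolerance.
        return bool(np.allclose(row_sums, 1.0, atol=atol))
    if mode == "soft":
        # Non-negative scores whose row totals do not exceed 1.
        return bool(np.all(row_sums <= 1.0 + atol))
    raise ValueError(f"Unknown mode: {mode!r}")
```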

Examples

>>> import numpy as np
>>> from geoprior.utils.validator import validate_scores
>>> scores_single = np.array([[0.1, 0.9], [0.8, 0.2]])
>>> print(validate_scores(scores_single))
[[0.1 0.9]
 [0.8 0.2]]
>>> scores_multi = [np.array([[0.1, 0.9]]), np.array([[0.8, 0.2]])]
>>> true_labels_multi = [np.array([1]), np.array([0])]
>>> print(validate_scores(scores_multi, true_labels_multi, accept_multi_output=True))
[array([[0.1, 0.9]]), array([[0.8, 0.2]])]

geoprior.utils.validator.validate_square_matrix(data, align=False, align_mode='auto', message='')[source]

Validate that the input data forms a square matrix and optionally aligns its indices and columns if specified.

Parameters:
  • data (DataFrame or array-like) – The input data to validate as a square matrix.

  • align (bool, default False) – Whether to align the DataFrame’s index with its columns.

  • align_mode (str, default 'auto') – Alignment mode if indices and columns do not match. Options are ‘auto’, ‘index_to_columns’, and ‘columns_to_index’.

  • message (str, default '') – Additional message to append to the error if validation fails.

Returns:

The validated or aligned square matrix.

Return type:

DataFrame or np.ndarray (same type as the input data)

Raises:

ValueError – If the input is not a square matrix.

Examples

>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import validate_square_matrix
>>> validate_square_matrix(np.array([[1, 2], [3, 4]]))
array([[1, 2],
       [3, 4]])
>>> validate_square_matrix(pd.DataFrame([[1, 2], [3, 4, 5]]))
ValueError: Input must be a square matrix.

Notes

A square matrix is defined as having an equal number of rows and columns. This function checks the dimensionality of the data and optionally aligns the index and columns if align is set to True.

geoprior.utils.validator.validate_weights(weights, min_value=None, max_value=None, normalize=False, allowed_dims=1)[source]

Validates and optionally normalizes the given weights array to ensure all elements meet specified criteria and the structure is suitable for computations.

Parameters:
  • weights (array-like) – Weights to be validated. Can be a list, tuple, or numpy array.

  • min_value (float, optional) – Minimum allowable value for weights (inclusive). If None, weights are expected to be non-negative. Explicitly set to a negative value if negative weights are allowed.

  • max_value (float or None, optional) – Maximum allowable value for weights (inclusive). If None, no upper limit is enforced.

  • normalize (bool, optional) – If True, weights will be normalized to sum to 1. Default is False.

  • allowed_dims (int or tuple, optional) – Specifies the allowed dimensions of the weights array. Default is 1 (one-dimensional). If a tuple is provided, weights must match one of the dimensions specified in the tuple.

Returns:

A numpy array of the validated and optionally normalized weights.

Return type:

np.ndarray

Raises:

ValueError – If weights contain values outside the specified range, or if the format or dimensions are not suitable.
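The checks listed above (dimension, range, optional normalization) fit in a few lines of NumPy. `check_weights` is a hypothetical helper that only mirrors the documented rules; the order of the checks and the zero-sum guard are assumptions.

```python
import numpy as np


def check_weights(weights, min_value=None, max_value=None, normalize=False,
                  allowed_dims=1):
    """Validate weights under the documented rules (illustrative sketch)."""
    w = np.asarray(weights, dtype=float)
    dims = allowed_dims if isinstance(allowed_dims, tuple) else (allowed_dims,)
    if w.ndim not in dims:
        raise ValueError("Weights dimensions not allowed.")
    # min_value=None means weights must be non-negative.
    lower = 0.0 if min_value is None else min_value
    if np.any(w < lower):
        raise ValueError(f"Weights must be >= {lower}.")
    if max_value is not None and np.any(w > max_value):
        raise ValueError(f"Weights must not exceed {max_value}.")
    if normalize:
        total = w.sum()
        if total == 0:
            raise ValueError("Cannot normalize weights that sum to zero.")
        w = w / total
    return w
```

Normalization divides by the total, so `[0.25, 0.75, 0.5]` (sum 1.5) becomes roughly `[0.1667, 0.5, 0.3333]`.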

Examples

>>> from geoprior.utils.validator import validate_weights
>>> validate_weights([0.25, 0.75, 0.5], normalize=True)
array([0.16666667, 0.5       , 0.33333333])
>>> validate_weights([-0.1, 0.9], min_value=0)
ValueError: Weights must be non-negative.
>>> validate_weights([0.1, 0.2, 0.7], max_value=0.5)
ValueError: Weights must not exceed 0.5.
>>> validate_weights([1, 2, 3], allowed_dims=2)
ValueError: Weights dimensions not allowed.

geoprior.utils.validator.validate_yy(y_true, y_pred, expected_type=None, *, validation_mode='strict', flatten=False)[source]

Validates the shapes and types of actual and predicted target arrays, ensuring they are compatible for further analysis or metrics calculation.

Parameters:
  • y_true (array-like) – True target values.

  • y_pred (array-like) – Predicted target values.

  • expected_type (str, optional) – The expected sklearn type of the target (‘binary’, ‘multiclass’, etc.).

  • validation_mode (str, optional) – Validation strictness. Currently, only ‘strict’ is implemented, which requires y_true and y_pred to have the same shape and match the expected_type.

  • flatten (bool, optional) – If True, both y_true and y_pred are flattened to one-dimensional arrays.

Raises:

ValueError – If y_true and y_pred do not meet the validation criteria.

Returns:

The validated y_true and y_pred arrays, potentially flattened.

Return type:

tuple
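The core of the strict mode is a shape comparison after optional flattening. The sketch below, with the hypothetical name `check_yy`, omits the `expected_type` check (which compares against scikit-learn target types) so that it runs with NumPy alone.

```python
import numpy as np


def check_yy(y_true, y_pred, flatten=False):
    """Shape check in the spirit of validation_mode='strict' (sketch)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if flatten:
        # Compare 1-D views, e.g. (N, 1) column targets against (N,) outputs.
        y_true, y_pred = y_true.ravel(), y_pred.ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )
    return y_true, y_pred
```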

These modules form the lighter infrastructure layer behind the more visible staged helpers.

NAT/workflow contract helpers#

Public exports for NAT workflow utilities.

geoprior.utils.nat_utils.build_censor_mask(xb, H, idx, thresh=0.5, *, source='dynamic', reduce_time='any', align='broadcast')[source]

Build a censor mask aligned to the forecast horizon: (B, H, 1).

Parameters:
  • source ({"dynamic", "future"}, default "dynamic") – Selects where the censoring flag is read from. "dynamic" reads xb["dynamic_features"][:, :, idx] from the history window, while "future" reads xb["future_features"][:, :, idx] from the forecast window.

  • reduce_time ({"any", "last", "all"}, default "any") – Reduction applied when source="dynamic" and the censor flag behaves like a per-sample label. "any" marks the sample as censored if any history step is flagged, "last" uses only the last history step, and "all" requires every history step to be flagged.

  • align ({"broadcast", "crop", "pad_false", "pad_edge", "error"}, default "broadcast") – Policy used when the time axis does not already match the forecast horizon H. "broadcast" repeats a single-step label across all horizon steps, "crop" keeps the last H steps, "pad_false" pads missing steps with False, "pad_edge" repeats the last available step, and "error" raises on mismatch.

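  • xb (dict) – Batch dictionary holding the feature tensors; dynamic_features or future_features is read depending on source.

  • H (int) – Forecast horizon; the returned mask has shape (B, H, 1).

  • idx (int | None) – Channel index of the censor flag within the selected feature tensor.

  • thresh (float, default 0.5) – Threshold used to binarize the censor flag values.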

Return type:

Tensor
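The interaction of `source`, `reduce_time`, and `align` can be illustrated with NumPy instead of TensorFlow. This sketch assumes the flag is binarized with `value > thresh`, that padding is appended after the available steps, and that the `reduce_time` reduction is always applied for `source="dynamic"`; `censor_mask_np` is hypothetical and does not handle `idx=None`.

```python
import numpy as np


def censor_mask_np(xb, H, idx, thresh=0.5, source="dynamic",
                   reduce_time="any", align="broadcast"):
    """NumPy sketch of a (B, H, 1) censor mask (illustrative only)."""
    key = "dynamic_features" if source == "dynamic" else "future_features"
    flag = xb[key][:, :, idx] > thresh          # (B, T) boolean
    if source == "dynamic":
        # Collapse the history window into one label per sample.
        if reduce_time == "any":
            flag = flag.any(axis=1, keepdims=True)
        elif reduce_time == "all":
            flag = flag.all(axis=1, keepdims=True)
        elif reduce_time == "last":
            flag = flag[:, -1:]
        else:
            raise ValueError(f"Unknown reduce_time: {reduce_time!r}")
    T = flag.shape[1]
    if T != H:
        if align == "broadcast" and T == 1:
            flag = np.repeat(flag, H, axis=1)
        elif align == "crop" and T > H:
            flag = flag[:, -H:]                 # keep the last H steps
        elif align in ("pad_false", "pad_edge") and T < H:
            mode = "constant" if align == "pad_false" else "edge"
            flag = np.pad(flag, ((0, 0), (0, H - T)), mode=mode)
        else:
            raise ValueError(f"Cannot align T={T} to H={H} with {align!r}")
    return flag[:, :, None]                     # (B, H, 1)
```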

geoprior.utils.nat_utils.ensure_input_shapes(x, mode, forecast_horizon)[source]

Ensure presence of zero-width static/future placeholders.

Stage-1 exporters sometimes omit static_features or future_features when there are no static/future variables for a particular experiment. Keras, however, expects these inputs to exist so that the input signature remains stable.

This helper:

  • Copies the input dict to avoid in-place modification.

  • Ensures static_features is an array of shape (N, 0) if missing.

  • Ensures future_features is an array of shape (N, T_future, 0) if missing, where:

    • T_future = dynamic_features.shape[1] when mode == "tft_like" (past+future style).

    • Otherwise, T_future = forecast_horizon.

Parameters:
  • x (dict) – Dictionary containing at least dynamic_features with shape (N, T_dyn, D_dyn).

  • mode (str) – Model mode. When "tft_like" the future sequence length is inferred from the dynamic sequence.

  • forecast_horizon (int) – Forecast horizon in time steps/years for non-TFT modes.

Returns:

Shallow copy of x with guaranteed static_features and future_features entries.

Return type:

dict
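The placeholder logic above amounts to a shallow copy plus two zero-width arrays. A minimal NumPy sketch, assuming an absent key or a `None` value both count as missing; `ensure_placeholders` is a hypothetical name.

```python
import numpy as np


def ensure_placeholders(x, mode, forecast_horizon):
    """Add zero-width static/future arrays when they are missing (sketch)."""
    out = dict(x)                                   # shallow copy, x untouched
    dyn = out["dynamic_features"]                   # (N, T_dyn, D_dyn)
    n = dyn.shape[0]
    if out.get("static_features") is None:
        out["static_features"] = np.zeros((n, 0), dtype=dyn.dtype)
    if out.get("future_features") is None:
        # TFT-like models use the dynamic length; others use the horizon.
        t_future = dyn.shape[1] if mode == "tft_like" else forecast_horizon
        out["future_features"] = np.zeros((n, t_future, 0), dtype=dyn.dtype)
    return out
```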

geoprior.utils.nat_utils.extract_preds(model, out, *, strict=True, output_names=None)[source]

Extract (subs_pred, gwl_pred) from GeoPrior outputs.

Supports:
  1. v3.2+ call(): {"subs_pred", "gwl_pred"}

  2. forward_with_aux(): (y_pred, aux)

  3. legacy: {"data_final"} + model.split_data_predictions

  4. predict(): list/tuple mapped via output names

If strict=True, list/tuple outputs must be mappable via output names; unmappable outputs raise an error instead of being assigned by position, which avoids silently swapping the subsidence and GWL predictions.

This helper normalizes the output interface across two GeoPrior generation families:

  1. New interface (preferred) model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}

  2. Legacy interface (backward compatible) model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.
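The dispatch between these layouts can be sketched as follows. `extract_pair` is hypothetical: it takes the splitter as a plain callable instead of reading `model.split_data_predictions`, skips the `forward_with_aux()` tuple case, and its positional fallback for `strict=False` is an assumption.

```python
def extract_pair(out, split_fn=None, output_names=None, strict=True):
    """Return (subs_pred, gwl_pred) from several output layouts (sketch)."""
    if isinstance(out, dict):
        if "subs_pred" in out and "gwl_pred" in out:
            return out["subs_pred"], out["gwl_pred"]        # new interface
        if "data_final" in out and split_fn is not None:
            return split_fn(out["data_final"])              # legacy interface
        raise KeyError(f"Unsupported output keys: {sorted(out)}")
    if isinstance(out, (list, tuple)):
        if output_names is not None:
            named = dict(zip(output_names, out))
            if "subs_pred" in named and "gwl_pred" in named:
                return named["subs_pred"], named["gwl_pred"]
        if strict:
            # Refuse to guess the order: a silent swap would corrupt results.
            raise ValueError("Cannot map list/tuple outputs without output names.")
        return out[0], out[1]
    raise TypeError(f"Unsupported output type: {type(out).__name__}")
```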

Parameters:
  • model (object) –

    A Keras-like model instance that may expose split_data_predictions(data_final).

    The splitter must return a tuple:

    • subs_pred with shape (B, H, 1) or (B, H, Q, 1)

    • gwl_pred with shape (B, H, 1) or (B, H, Q, 1)

  • out (dict) –

    Output returned by the model call, typically model(inputs, training=False).

    Supported keys are either:

    • {"subs_pred", "gwl_pred"} (new interface), or

    • {"data_final"} (legacy interface).

  • strict (bool)

  • output_names (Sequence[str] | None)

Returns:

  • subs_pred (Tensor) – Predicted subsidence in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

  • gwl_pred (Tensor) – Predicted groundwater/head variable in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

Raises:
  • KeyError – If out does not contain a supported key set.

  • TypeError – If out is not a mapping/dict-like object.

Return type:

tuple[Any, Any]

Notes

This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.

The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].

Examples

New interface:

from geoprior.utils.nat_utils import extract_preds

out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)

Legacy interface:

out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)

See also

subs_point_from_out

Convert subsidence predictions to a point forecast.

geoprior.utils.nat_utils.load_nat_config(root='nat.com')[source]

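Return the NATCOM configuration as a dictionary of settings such as CITY_NAME and TIME_STEPS. High-level helper used by NATCOM scripts; use load_nat_config_payload() to also get the city, model and __meta__ fields.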

Example

>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]

Return type:

dict[str, Any]

geoprior.utils.nat_utils.load_nat_config_payload(root='nat.com')[source]

Return the full config.json payload, including city, model and __meta__ fields.

This is convenient when you also need to know which config hash, city, and model are currently active.

Return type:

dict[str, Any]

geoprior.utils.nat_utils.load_scaler_info(encoders_block)[source]

Load the scaler_info mapping from an encoders block.

Stage-1 exporters typically store a compact description of the scalers used to normalise the data. In many cases this takes the form:

encoders = {
    "main_scaler": "/path/to/minmax.joblib",
    "coord_scaler": "/path/to/coords.joblib",
    "scaler_info": "/path/to/scaler_info.joblib",
    ...
}

where scaler_info is either a path to a joblib file or an already-loaded dictionary.

This helper returns a dictionary regardless of how it was stored, making downstream formatting/evaluation code simpler.

Parameters:

encoders_block (dict) – The encoders part of the Stage-1 manifest (M["artifacts"]["encoders"]).

Returns:

The loaded scaler_info dictionary, or None if not present / not loadable.

Return type:

dict or None

geoprior.utils.nat_utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]

Build a tf.data.Dataset using NATCOM conventions.

Steps:

  1. ensure_input_shapes(…) for X.

  2. map_targets_for_training(…) for y.

  3. tf.data pipeline (shuffle/batch/prefetch).

  4. Optional finite checks (NPZ + tf batches).

Parameters:
  • X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.

  • y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.

  • batch_size (int) – Number of samples per batch.

  • shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.

  • mode (str) – Model mode passed to ensure_input_shapes().

  • forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().

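  • check_npz_finite (bool, default False) – If True, checks the X_np/y_np NumPy arrays for NaN/Inf before building the dataset.

  • check_finite (bool, default False) – If True, inserts assert_all_finite checks inside the tf.data pipeline.

  • scan_finite_batches (int, default 0) – If greater than 0, eagerly scans the first N batches when the dataset is built, so non-finite values fail early.

  • dynamic_feature_names (list[str] | None) – If provided, used to name the offending channels when dynamic_features contains non-finite values.

  • future_feature_names (list[str] | None) – If provided, used to name the offending channels when future_features contains non-finite values.

  • seed (int, default 42) – Random seed used for shuffling.

  • drop_remainder (bool, default False) – Whether to drop the last incomplete batch.

  • reshuffle_each_iter (bool, default True) – Whether to reshuffle the data on each iteration over the dataset.

  • prefetch (bool, default True) – Whether to append a prefetch stage to the pipeline.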

Returns:

Dataset of (X, y) pairs.

Return type:

tf.data.Dataset

Notes

TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
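The `check_npz_finite` idea can be approximated with NumPy before TensorFlow is involved. A sketch, assuming inputs are a dict of arrays and 3-D feature arrays are laid out as `(N, T, D)`; `find_nonfinite_channels` is hypothetical and only mimics how the feature-name arguments are described as being used to report bad channels.

```python
import numpy as np


def find_nonfinite_channels(arrays, feature_names=None):
    """Report NaN/Inf per array, and per channel for 3-D feature arrays."""
    report = {}
    for key, arr in arrays.items():
        arr = np.asarray(arr)
        if not np.issubdtype(arr.dtype, np.floating):
            continue                        # integer arrays cannot hold NaN/Inf
        bad = ~np.isfinite(arr)
        if not bad.any():
            continue
        if arr.ndim == 3:
            # Channels are on the last axis: (N, T, D).
            idx = np.flatnonzero(bad.any(axis=(0, 1)))
            names = feature_names.get(key) if feature_names else None
            report[key] = [names[i] if names else int(i) for i in idx]
        else:
            report[key] = int(bad.sum())    # count of non-finite entries
    return report
```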

geoprior.utils.nat_utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]

Standardise target dictionaries to the Keras compile keys.

This helper enforces a small convention used throughout the NATCOM training scripts:

  • Upstream sequence builders typically export raw targets with keys subsidence and gwl.

  • The GeoPrior model is compiled with targets named subs_pred and gwl_pred.

This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.

Parameters:
  • y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).

  • subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.

  • gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.

  • subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.

  • gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.

Returns:

New dictionary with keys subs_pred and gwl_pred.

Return type:

dict

Raises:

KeyError – If the dictionary does not contain either of the expected key pairs.
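The key convention above reduces to a small lookup. In this sketch (`map_targets` is a hypothetical name), preferring the already-standardised keys when both pairs are present is an assumption.

```python
def map_targets(y_dict, subs_key="subsidence", gwl_key="gwl",
                subs_pred_key="subs_pred", gwl_pred_key="gwl_pred"):
    """Return a new dict keyed by the Keras compile names (sketch)."""
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        pair = (y_dict[subs_pred_key], y_dict[gwl_pred_key])
    elif subs_key in y_dict and gwl_key in y_dict:
        pair = (y_dict[subs_key], y_dict[gwl_key])
    else:
        raise KeyError(
            f"Expected ({subs_key!r}, {gwl_key!r}) or "
            f"({subs_pred_key!r}, {gwl_pred_key!r}); got {sorted(y_dict)}"
        )
    return {subs_pred_key: pair[0], gwl_pred_key: pair[1]}
```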

geoprior.utils.nat_utils.name_of(obj)[source]

Return a human-readable name for an object.

This utility is handy when serialising compile configurations (e.g., turning metric callables into simple strings for JSON logs).

Parameters:

obj (object) – Any Python object (function, class instance, etc.).

Returns:

obj.__name__ if present, otherwise the class name, and finally str(obj) as a last resort.

Return type:

str
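The fallback order described above can be sketched as follows; `readable_name` is a hypothetical stand-in. A `functools.partial` object is a typical case with no `__name__`, so it falls back to its class name.

```python
def readable_name(obj):
    """Return obj.__name__, else the class name, else str(obj) (sketch)."""
    name = getattr(obj, "__name__", None)
    if name:
        return str(name)
    try:
        return type(obj).__name__
    except Exception:
        # Last resort, mirroring the documented str(obj) fallback.
        return str(obj)
```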

geoprior.utils.nat_utils.resolve_hybrid_config(manifest_cfg, live_cfg, verbose=True)[source]

Merge Manifest config (Data Authority) with Live config (Physics Authority).

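Parameters:
  • manifest_cfg (dict) – Config recorded in the Stage-1 manifest (data authority).

  • live_cfg (dict) – Current live config (physics authority).

  • verbose (bool, default True)
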
Return type:

dict

geoprior.utils.nat_utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
Parameters:
  • cfg (dict)

  • scaler_info (dict)

  • target_name (str)

  • prefix (str)

  • unit_factor_key (str)

  • scale_key (str)

  • bias_key (str)

geoprior.utils.nat_utils.best_epoch_and_metrics(history, monitor='val_loss')[source]

Return the best epoch and metrics at that epoch.

Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:

  • The epoch index (0-based).

  • A dictionary mapping each metric name to its value at that epoch.

Parameters:
  • history (dict) – The history.history attribute from Keras training.

  • monitor (str, default "val_loss") – Name of the metric to minimise.

Returns:

  • best_epoch (int or None) – Index of the best epoch, or None if monitor is not present.

  • metrics_at_best (dict) – Mapping from metric name to its value at the best epoch. Empty if monitor is not present.

Return type:

tuple[int | None, dict]
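The selection above is an argmin over the monitored series. A plain-Python sketch; `best_epoch` is hypothetical, and skipping NaN values is an assumption not stated in the description.

```python
import math


def best_epoch(history, monitor="val_loss"):
    """Return (best_epoch, metrics_at_best) from a Keras History.history dict."""
    values = history.get(monitor)
    if not values:
        return None, {}
    # Ignore NaN entries when searching for the minimum (assumption).
    candidates = [(v, i) for i, v in enumerate(values) if not math.isnan(v)]
    if not candidates:
        return None, {}
    best = min(candidates)[1]       # ties resolve to the earliest epoch
    metrics = {k: v[best] for k, v in history.items() if best < len(v)}
    return best, metrics
```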

geoprior.utils.nat_utils.subs_point_from_out(model, out, quantiles=None, med_idx=None)[source]

Convert model output into a subsidence point forecast.

This helper produces a subsidence tensor shaped (B, H, 1) in model space, regardless of whether the model emits quantiles or a point prediction.

  • If quantiles are present and the subsidence prediction is shaped (B, H, Q, 1), the function selects the median quantile slice.

  • Otherwise, it returns the point prediction directly.

Parameters:
  • model (object) – A Keras-like model instance passed to extract_preds().

  • out (dict) –

    Output returned by the model call.

    This can be either the new interface with keys "subs_pred" and "gwl_pred", or the legacy interface with key "data_final".

  • quantiles (sequence of float or None, default None) –

    Quantile levels used by the model, such as [0.1, 0.5, 0.9].

    If provided, the function may use it to interpret the rank-4 quantile output and select the median.

    If None, quantile selection is disabled unless med_idx is explicitly provided and the tensor rank indicates quantiles.

  • med_idx (int or None, default None) –

    Index along the quantile axis to use as the “point” forecast when quantiles are available.

    If None and quantiles is provided, the function selects the index closest to 0.5.

Returns:

subs_point – Subsidence point prediction in model space with shape (B, H, 1).

Return type:

Tensor

Raises:
  • ValueError – If subsidence prediction is missing or None.

  • ValueError – If a quantile tensor is detected but a valid median index cannot be resolved.

Notes

Quantile outputs are assumed to be shaped (B, H, Q, 1) where the quantile axis is the third dimension (axis=2).

If the model returns point predictions already, the function is effectively a no-op. The quantile interpretation used here follows Koenker and Bassett [25].
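The median selection described above is a slice along `axis=2`. A NumPy sketch operating directly on the subsidence array (the output-extraction step is omitted); `median_point` is a hypothetical name.

```python
import numpy as np


def median_point(subs_pred, quantiles=None, med_idx=None):
    """Reduce (B, H, Q, 1) quantile output to a (B, H, 1) point forecast."""
    subs_pred = np.asarray(subs_pred)
    if subs_pred.ndim != 4:
        return subs_pred                        # already a point forecast
    if med_idx is None:
        if quantiles is None:
            raise ValueError("Quantile tensor found but no median index.")
        # Index of the quantile level closest to 0.5.
        med_idx = int(np.argmin(np.abs(np.asarray(quantiles) - 0.5)))
    return subs_pred[:, :, med_idx, :]          # axis=2 is the quantile axis
```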

Examples

Quantile model:

from geoprior.utils.nat_utils import subs_point_from_out

out = model_inf(xb, training=False)
s_point = subs_point_from_out(
    model_inf,
    out,
    quantiles=[0.1, 0.5, 0.9],
)

Point model:

out = model_inf(xb, training=False)
s_point = subs_point_from_out(
    model_inf,
    out,
)

See also

extract_preds

Normalize outputs across new and legacy checkpoints.

geoprior.utils.nat_utils.serialize_subs_params(params, cfg=None)[source]

Make GeoPrior subnet parameters JSON-friendly.

The training scripts typically pass a dictionary of model construction arguments, e.g. subsmodel_params, which contains objects such as LearnableMV or FixedGammaW that are not directly JSON-serialisable.

This helper replaces those objects by small dictionaries describing their type and scalar value, optionally using values from the NATCOM config dictionary.

Parameters:
  • params (dict) – Dictionary of model init parameters (e.g. subsmodel_params in training_NATCOM_GEOPRIOR.py).

  • cfg (dict, optional) –

    NATCOM config dictionary. If provided, scalar values are taken from:

    • GEOPRIOR_INIT_MV

    • GEOPRIOR_INIT_KAPPA

    • GEOPRIOR_GAMMA_W

    • GEOPRIOR_H_REF

    and used as the authoritative numbers.

Returns:

Copy of params where scalar GeoPrior parameters are replaced by JSON-friendly dictionaries.

Return type:

dict

Notes

This function does not import any of the GeoPrior classes. It only introspects attributes like initial_value or value when the corresponding config entry is missing.
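The replacement described above can be sketched as follows. Both the parameter-name-to-config-key mapping (`CFG_KEYS`) and the `{"type": ..., "value": ...}` layout are hypothetical; only the config key names and the `initial_value`/`value` attributes come from the description, and `jsonable_params` is not a geoprior function.

```python
import json

# Hypothetical mapping from parameter name to NATCOM config key.
CFG_KEYS = {
    "mv": "GEOPRIOR_INIT_MV",
    "kappa": "GEOPRIOR_INIT_KAPPA",
    "gamma_w": "GEOPRIOR_GAMMA_W",
    "h_ref": "GEOPRIOR_H_REF",
}


def jsonable_params(params, cfg=None):
    """Replace non-JSON parameter objects by {'type', 'value'} dicts (sketch)."""
    cfg = cfg or {}
    out = {}
    for name, obj in params.items():
        try:
            json.dumps(obj)
            out[name] = obj                     # already JSON-friendly
            continue
        except TypeError:
            pass
        # Config values are authoritative when present.
        value = cfg.get(CFG_KEYS.get(name))
        if value is None:
            # Fall back to introspecting the object itself.
            value = getattr(obj, "initial_value", getattr(obj, "value", None))
        out[name] = {"type": type(obj).__name__, "value": value}
    return out
```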

geoprior.utils.nat_utils.save_ablation_record(outdir, city, model_name, cfg, eval_dict, phys_diag=None, per_h_mae=None, per_h_r2=None, log_fn=None)[source]

Append a single ablation record to ablation_record.jsonl.

Each training run (e.g., different physics toggles or weights) writes one JSON line containing:

  • Basic run identifiers (city, model, timestamp).

  • Physics configuration (PDE_MODE_CONFIG, lambda weights, effective head flags, etc.).

  • Key performance metrics (R², MSE, MAE, coverage, sharpness).

  • Optional physics diagnostics (epsilon_prior, epsilon_cons).

  • Optional per-horizon MAE/R² for more detailed analysis.

Parameters:
  • outdir (str) – Base output directory for the current run. The ablation file is created under outdir / "ablation_records".

  • city (str) – City name (e.g., "nansha" or "zhongshan").

  • model_name (str) – Model identifier (e.g., "GeoPriorSubsNet").

  • cfg (dict) – Lightweight configuration dictionary containing at least the physics-related keys used below.

  • eval_dict (dict or None) – Dictionary of evaluation metrics (R², MSE, MAE, coverage80, sharpness80). If None, metrics fields default to None.

  • phys_diag (dict or None, optional) – Physics diagnostics (e.g., from evaluate()) with keys such as "epsilon_prior" and "epsilon_cons".

  • per_h_mae (dict or None, optional) – Per-horizon MAE values (e.g., keyed by year/step).

  • per_h_r2 (dict or None, optional) – Per-horizon R² values.

Return type:

None

Notes

The output file is a JSON-Lines file, so it can be loaded with load_ablation_jsonl().
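Writing and reading a JSON-Lines file needs only the standard library. A sketch of the append pattern, assuming the file layout described above (ablation_record.jsonl under outdir / "ablation_records"); the helper names and the timestamp format are hypothetical, and the record fields are whatever the caller supplies.

```python
import json
import time
from pathlib import Path


def append_jsonl_record(outdir, record):
    """Append one JSON object as a line to ablation_record.jsonl (sketch)."""
    path = Path(outdir) / "ablation_records" / "ablation_record.jsonl"
    path.parent.mkdir(parents=True, exist_ok=True)
    record = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), **record}
    # Append mode: earlier runs stay in the file, one JSON object per line.
    with path.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(record) + "\n")
    return path


def read_jsonl(path):
    """Load every non-empty line of a JSON-Lines file."""
    with Path(path).open(encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]
```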

geoprior.utils.nat_utils.load_windows_npz(path)[source]

Load Stage-1 windows as (x, y).

Supported inputs:

  • Bundle NPZ (contains inputs and targets in one file).

  • Mapping {"inputs": <npz>, "targets": <npz>}.

  • Inputs NPZ only (targets inferred from the filename).

  • Directory containing the inputs/targets NPZ files.

Returns:

  • x (dict[str, np.ndarray]) – Inputs (e.g., static_features, dynamic_features, etc.)

  • y (dict[str, np.ndarray]) – Targets (e.g., subs_pred, gwl_pred)

Parameters:

path (str | Path | Mapping[str, str])

Return type:

tuple[dict[str, ndarray], dict[str, ndarray]]

geoprior.utils.nat_utils.load_tuned_hps_near_model(model_path, *, prefer='keras', required=True, log_fn=None)[source]
Parameters:
Return type:

dict

geoprior.utils.nat_utils.load_trained_hps_near_model(model_path, *, allowed, required=False, log_fn=None)[source]
Parameters:
Return type:

dict

geoprior.utils.nat_utils.sanitize_inputs_np(X)[source]
Parameters:

X (dict)

Return type:

dict

geoprior.utils.nat_utils.load_hps_auto_near_model(model_path, *, allowed, prefer='keras', required=False, log_fn=None)[source]
Parameters:
Return type:

dict

geoprior.utils.nat_utils.load_or_rebuild_geoprior_model(model_path, manifest, X_sample, out_s_dim, out_g_dim, mode, horizon, quantiles, city_name=None, compile_on_load=True, verbose=1)[source]

Load a tuned or trained GeoPriorSubsNet, with robust rebuild fallback.

Parameters:
geoprior.utils.nat_utils.compile_for_eval(model, manifest, best_hps, quantiles, *, include_metrics=True)[source]

Recompile a GeoPriorSubsNet instance for evaluation / diagnostics.

This is intended for:

  • tuned models loaded from a .keras archive, or

  • models rebuilt from best_hps.

It does NOT change the architecture or weights, only the compile configuration (optimizer, losses, and physics loss weights).

Parameters:
  • model (GeoPriorSubsNet) – Loaded or freshly-built GeoPriorSubsNet instance.

  • manifest (dict) – Stage-1 manifest; training config is taken from manifest['config'].

  • best_hps (dict or None) – Dictionary of tuned hyperparameters. If empty/None, reasonable defaults are inferred from the manifest.

  • quantiles (list of float or None) – Quantiles used for probabilistic subsidence/GWL outputs.

  • include_metrics (bool, default True) – If True, attach MAE/MSE + coverage/sharpness metrics to match the training script; if False, only losses are configured.

Returns:

The same model instance, compiled in-place.

Return type:

model

geoprior.utils.nat_utils.load_best_hps_near_model(model_path, *, model_name='GeoPriorSubsNet', prefer='keras', log_fn=None)[source]

Load best hyperparameters saved near a model artifact.

Supports model file names such as:

  • <city>_<model_name>_H{H}_best.keras

  • <city>_<model_name>_H{H}_best.weights.h5

Parameters:
  • model_path (str) – Path to a model file or its run directory.

  • model_name (str or None, default "GeoPriorSubsNet") – Model name token in filenames.

  • prefer ({"keras", "weights"}, default "keras") – Which artifact type to infer the prefix from.

  • log_fn (callable or None, default None) – Logger (e.g. print). None disables logs.

Returns:

best_hps – Non-empty hyperparameter dictionary.

Return type:

dict

Raises:
geoprior.utils.nat_utils.pick_npz_for_dataset(manifest, split)[source]

Load (inputs, targets) NPZ arrays for a given dataset split.

This is a public, reusable version of the internal helper that was previously named _pick_npz_for_dataset.

Parameters:
  • manifest (dict) –

    Stage-1 manifest dictionary with the structure:

    manifest["artifacts"]["numpy"] = {
        "train_inputs_npz": ...,
        "train_targets_npz": ...,
        "val_inputs_npz": ...,
        "val_targets_npz": ...,
        "test_inputs_npz": ... (optional),
        "test_targets_npz": ... (optional),
    }
    

  • split ({"train", "val", "test"}) – Which dataset to load.

Returns:

  • X (dict or None) – Dictionary of input arrays for the requested split, or None if the split is unavailable (e.g. test NPZ missing).

  • y (dict or None) – Dictionary of target arrays for the requested split, or None if targets are unavailable.

Raises:
  • KeyError – If the manifest does not contain the expected NPZ entries.

  • ValueError – If split is not one of {"train", "val", "test"}.

Return type:

tuple[dict | None, dict | None]

geoprior.utils.nat_utils.ensure_config_json(root='nat.com')[source]

Ensure that nat.com/config.json exists and is consistent with nat.com/config.py.

Returns:

  • config (dict) – The configuration dictionary (payload["config"]).

  • json_path (str) – Absolute path to config.json.

Behaviour:

  • If config.json does not exist, it is created from config.py.

  • If it exists but the SHA-256 hash of config.py has changed, it is regenerated.

  • Otherwise the existing JSON file is reused.

Parameters:

root (str)

Return type:

tuple[dict[str, Any], str]
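The hash-based regeneration rule above can be sketched with `hashlib`. The `__meta__["config_sha256"]` field name and the `build_config` callable (standing in for however config.py is turned into a dictionary) are assumptions, and `ensure_json_cache` is a hypothetical name.

```python
import hashlib
import json
from pathlib import Path


def sha256_of(path):
    """Hex SHA-256 digest of a file's bytes."""
    return hashlib.sha256(Path(path).read_bytes()).hexdigest()


def ensure_json_cache(config_py, config_json, build_config):
    """Regenerate config_json when config_py's SHA-256 changes (sketch)."""
    config_py, config_json = Path(config_py), Path(config_json)
    digest = sha256_of(config_py)
    if config_json.exists():
        payload = json.loads(config_json.read_text(encoding="utf-8"))
        if payload.get("__meta__", {}).get("config_sha256") == digest:
            return payload["config"], str(config_json.resolve())   # reuse
    # Missing or stale: rebuild from config.py and store the new hash.
    payload = {"config": build_config(config_py),
               "__meta__": {"config_sha256": digest}}
    config_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return payload["config"], str(config_json.resolve())
```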

geoprior.utils.nat_utils.get_natcom_dir(root='nat.com')[source]

Return the directory containing the NATCOM scripts and configuration, typically <repo_root>/nat.com.

Return type:

str

geoprior.utils.nat_utils.get_config_paths(root='nat.com')[source]

Return (config_py_path, config_json_path) for NATCOM.

Return type:

tuple[str, str]

geoprior.utils.nat_utils.compile_geoprior_for_eval(model, manifest, best_hps, quantiles)[source]

(Re)compile a GeoPriorSubsNet-like model for evaluation.

This helper uses the Stage-1 manifest and tuned hyperparameters to configure:

  • the pinball losses for subsidence and GWL outputs,

  • loss weights for the two heads,

  • physics loss weights (lambda_*),

  • learning rate and LR multipliers.

TensorFlow and geoprior are imported lazily inside this function so that nat_utils can be imported even in non-TF environments.

Parameters:
  • model (GeoPriorSubsNet-like) – An instance of the GeoPriorSubsNet model (or any model exposing the same compile signature).

  • manifest (dict) – Stage-1 manifest dictionary. The config entry is used to retrieve default loss weights and physics settings.

  • best_hps (dict) – Hyperparameters loaded from the tuning run (e.g. via load_best_hps_near_model()).

  • quantiles (list of float or None) – Quantile levels used for probabilistic outputs. If None, mean-squared error is used instead of pinball loss.

Returns:

The same model instance, compiled in-place.

Return type:

model

Raises:

ImportError – If TensorFlow or geoprior’s make_weighted_pinball cannot be imported.

This is one of the most important workflow modules in the GeoPrior stack. It includes config loading, dataset building, target mapping, scaling resolution, prediction extraction, artifact serialization, and model-evaluation support.

Subsidence-oriented workflow helpers#

Utility helpers for subsidence data, units, and coordinates.

geoprior.utils.subsidence_utils.normalize_gwl_alias(df, gwl_col_user, *, prefer_depth_bgs=True, verbose=True)[source]

Normalize common GWL naming aliases.

Naming-only: unit conversion happens later.

Parameters:
Return type:

tuple[DataFrame, str | None]

geoprior.utils.subsidence_utils.resolve_gwl_for_physics(df, gwl_col_user, *, prefer_depth_bgs=True, allow_keep_zscore_as_ml=True, verbose=True)[source]

Select a GWL column in meters for the physics terms; a z-scored GWL column is kept for ML features only.

Parameters:
Return type:

tuple[str, str | None]

geoprior.utils.subsidence_utils.resolve_head_column(df, *, depth_col, head_col='head_m', z_surf_col=None, use_head_proxy=True)[source]

Resolve a head column, creating one if needed.

Parameters:
Return type:

tuple[str, str | None]

geoprior.utils.subsidence_utils.convert_target_units_df(df, *, base, from_unit='m', to_unit='mm', mode='overwrite', suffix='_mm', columns=None, unit_col=None, copy_df=True, overwrite_cols=False, strict=False)[source]

Convert target-like columns between "m" and "mm".

Selected columns:

  • base

  • base + "_*" (quantiles, intervals, …)

If mode="add", new columns are created using suffix instead of overwriting the selected columns.

Parameters:
Return type:

DataFrame | None
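Column selection and the unit factor can be illustrated without pandas. In this sketch a dict of lists stands in for the DataFrame, and naming added columns as `column + suffix` is an assumption; all helper names are hypothetical.

```python
UNIT_TO_M = {"m": 1.0, "mm": 1e-3}


def select_target_columns(columns, base):
    """Columns named exactly `base` or starting with `base + '_'`."""
    return [c for c in columns if c == base or c.startswith(base + "_")]


def unit_factor(from_unit="m", to_unit="mm"):
    """Multiplicative factor converting from_unit values to to_unit."""
    return UNIT_TO_M[from_unit] / UNIT_TO_M[to_unit]


def convert_columns(frame, base, from_unit="m", to_unit="mm",
                    mode="overwrite", suffix="_mm"):
    """Scale the selected columns of a dict-of-lists 'frame' (sketch)."""
    factor = unit_factor(from_unit, to_unit)
    out = dict(frame)
    for col in select_target_columns(list(frame), base):
        target = col if mode == "overwrite" else col + suffix
        out[target] = [v * factor for v in frame[col]]
    return out
```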

geoprior.utils.subsidence_utils.subs_unit_to_si(cfg=None, *, default=0.001, units_prov_key='units_provenance', stage1_key='subs_unit_to_si_applied_stage1', cfg_key='SUBS_UNIT_TO_SI')[source]

Return the factor that converts subsidence values from their stored unit to SI (meters).

Priority:

  1. cfg[units_prov_key][stage1_key]

  2. cfg[cfg_key]

  3. default

Return type:

float

geoprior.utils.subsidence_utils.subs_native_unit(cfg=None, *, default='mm')[source]

Infer the “native” subsidence unit from cfg.

  • unit_to_si ~= 1e-3 -> “mm”

  • unit_to_si ~= 1.0 -> “m”

Return type:

str
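The lookup priority of subs_unit_to_si and the factor-to-unit mapping of subs_native_unit can be sketched together. This is a hypothetical illustration using the documented default key names; the two function names are invented here, and the relative tolerance standing in for "~=" is an assumption, as is deriving the native unit from the resolved factor.

```python
import math


def subs_unit_to_si_sketch(cfg=None, default=1e-3):
    """Hypothetical sketch of the documented lookup priority."""
    cfg = cfg or {}
    # 1) Stage-1 provenance, if Stage-1 recorded the factor it applied.
    prov = cfg.get("units_provenance") or {}
    if prov.get("subs_unit_to_si_applied_stage1") is not None:
        return float(prov["subs_unit_to_si_applied_stage1"])
    # 2) Explicit config key.
    if cfg.get("SUBS_UNIT_TO_SI") is not None:
        return float(cfg["SUBS_UNIT_TO_SI"])
    # 3) Fallback default (millimeters -> meters).
    return float(default)


def subs_native_unit_sketch(cfg=None, default="mm"):
    """Map the resolved factor back to a unit name (tolerance is assumed)."""
    factor = subs_unit_to_si_sketch(cfg)
    if math.isclose(factor, 1e-3, rel_tol=1e-6):
        return "mm"
    if math.isclose(factor, 1.0, rel_tol=1e-6):
        return "m"
    return default
```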

geoprior.utils.subsidence_utils.add_subsidence_mm_columns(df, cfg=None, *, base='subsidence', columns=None, suffix='_mm', unit_col=None, copy_df=True, overwrite=False)[source]

Add (or overwrite) subsidence columns in millimeters.

Always assumes the current values are in meters.

Return type:

DataFrame | None

geoprior.utils.subsidence_utils.add_subsidence_native_unit_columns(df, cfg=None, *, base='subsidence', columns=None, suffix='_native', unit_col=None, copy_df=True, overwrite=False)[source]

Add columns in the cfg-inferred native unit.

If cfg says the native unit was meters, this becomes a no-op aside from optional unit_col.

Return type:

DataFrame | None

geoprior.utils.subsidence_utils.finalize_si_scaling_kwargs(scaling_kwargs, *, subs_in_si, head_in_si, thickness_in_si, force_identity_affine_if_si=True, warn=True)[source]

Prevent double SI conversion in GeoPrior scaling_kwargs.

Return type:

dict[str, Any]

geoprior.utils.subsidence_utils.finalize_si_affines_and_units(scaling_kwargs, *, subs_in_si, head_in_si, thickness_in_si, force_identity_affine_if_si=True, warn=True)

Prevent double SI conversion in GeoPrior scaling_kwargs.

Return type:

dict[str, Any]

geoprior.utils.subsidence_utils.infer_utm_epsg_from_lonlat(lon_deg, lat_deg)[source]

Infer a UTM EPSG code from lon/lat (EPSG:4326).

Uses standard UTM zoning:

zone = floor((lon + 180) / 6) + 1

EPSG = 32600 + zone (northern hemisphere), 32700 + zone (southern hemisphere)

Return type:

int
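The zoning formula above translates directly into a few lines of Python. This is a hypothetical sketch, not the GeoPrior function: infer_utm_epsg_sketch is an invented name, clamping the zone to 1..60 (so that lon = +180 does not produce zone 61) is an added assumption, the equator is treated as north, and the special zones around Norway and Svalbard are not handled.

```python
import math


def infer_utm_epsg_sketch(lon_deg: float, lat_deg: float) -> int:
    """Hypothetical sketch of standard 6-degree UTM zoning on WGS 84."""
    zone = math.floor((lon_deg + 180.0) / 6.0) + 1
    # lon = +180 would give zone 61; clamping to 1..60 is an added assumption.
    zone = min(max(zone, 1), 60)
    # WGS 84 / UTM: EPSG 326xx in the north, 327xx in the south (lat 0 -> north).
    return (32600 if lat_deg >= 0 else 32700) + zone
```

For example, a point near Shenzhen (lon 114.06, lat 22.54) falls in zone 50 and maps to EPSG:32650.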

geoprior.utils.subsidence_utils.lonlat_to_utm_m(lon_deg, lat_deg, *, src_epsg=4326, target_epsg=None)[source]

Convert lon/lat degrees to UTM meters using pyproj.

Returns:

Tuple (x_m, y_m, target_epsg).

class geoprior.utils.subsidence_utils.CoordsPack(coords: 'np.ndarray', coord_mins: 'dict[str, float]', coord_ranges: 'dict[str, float]', meta: 'dict[str, Any]')[source]

Bases: object

Parameters:
  • coords (ndarray)

  • coord_mins (dict[str, float])

  • coord_ranges (dict[str, float])

  • meta (dict[str, Any])
__init__(coords, coord_mins, coord_ranges, meta)
Return type:

None

geoprior.utils.subsidence_utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]

Build a coordinate tensor (t, x, y) with optional shifting (translation only).

This is designed for the “not normalized” workflow:

  • SI units are kept (years and meters),

  • but huge UTM magnitudes (e.g. 3e5, 2.5e6) are kept out of the coordinate MLPs by shifting x and y (and optionally t).

Notes

  • This does NOT min-max scale to [0,1]. It only translates.

  • Returning coord_mins/coord_ranges is still useful for logging/debug.

Return type:

CoordsPack
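The translation-only idea can be illustrated with NumPy. This is a hypothetical sketch of the time_shift='min' / xy_shift='min' case only: shift_txy_sketch is an invented name, it returns a plain tuple instead of a CoordsPack, and the explicit *_shift_value options are not modeled.

```python
import numpy as np


def shift_txy_sketch(t, x, y, dtype="float32"):
    """Hypothetical sketch of min-shifting (translation only, no scaling)."""
    # Work in float64 so large UTM magnitudes keep full precision.
    cols = {k: np.asarray(v, dtype="float64") for k, v in (("t", t), ("x", x), ("y", y))}
    coord_mins = {k: float(v.min()) for k, v in cols.items()}
    # Ranges are reported for logging/debugging only; nothing is divided by them.
    coord_ranges = {k: float(v.max() - v.min()) for k, v in cols.items()}
    # Translate each axis by its minimum, then cast for the model.
    coords = np.stack([cols[k] - coord_mins[k] for k in ("t", "x", "y")], axis=-1)
    return coords.astype(dtype), coord_mins, coord_ranges
```

Subtracting in float64 before casting matters: a UTM northing such as 2.5e6 stored directly in float32 loses sub-meter resolution, whereas the shifted offsets are small enough to be represented exactly.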

geoprior.utils.subsidence_utils.detect_subsidence_mode(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), tol_rel=0.05, min_points=3, max_groups=200, random_state=42)[source]

Infer whether subsidence columns represent ‘rate’, ‘cumulative’, or an inconsistent pair.

Return type:

dict

geoprior.utils.subsidence_utils.rate_to_cumulative(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), initial='first_equals_rate_dt', inplace=False)[source]

Build cumulative displacement from a rate series.

Return type:

DataFrame

geoprior.utils.subsidence_utils.cumulative_to_rate(df, *, cum_col='subsidence_cum', rate_col='subsidence', time_col='year', group_cols=('longitude', 'latitude', 'city'), first='cum_over_dtref', inplace=False)[source]

Recover a rate series from cumulative displacement.

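\(\mathrm{rate}(t_i) = \big(\mathrm{cum}(t_i) - \mathrm{cum}(t_{i-1})\big) / \Delta t_i\) for \(i \ge 1\).

The first parameter sets the rate at \(t_0\):

  • 'nan': the first rate is NaN.

  • 'cum_over_dtref': \(\mathrm{rate}(t_0) = \mathrm{cum}(t_0) / \Delta t_{\mathrm{ref}}\), where \(\Delta t_{\mathrm{ref}}\) is the median time step.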

Returns:

DataFrame with rate_col added/overwritten.
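The per-group differencing rule can be sketched with pandas. This is a hypothetical illustration, not the GeoPrior implementation: cumulative_to_rate_sketch is an invented name, its default group_cols omits 'city' for brevity, and computing dt_ref as the median step of each group (rather than over the whole table) is an assumption.

```python
import pandas as pd


def cumulative_to_rate_sketch(df, *, cum_col="subsidence_cum", rate_col="subsidence",
                              time_col="year", group_cols=("longitude", "latitude"),
                              first="cum_over_dtref"):
    """Hypothetical sketch; not the GeoPrior implementation."""
    out = df.sort_values([*group_cols, time_col]).copy()
    grouped = out.groupby(list(group_cols), sort=False)
    dt = grouped[time_col].diff()
    # rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i for i >= 1; NaN at t_0.
    out[rate_col] = grouped[cum_col].diff() / dt
    if first == "cum_over_dtref":
        # Assumption: dt_ref is the median step of each group's own series.
        dt_ref = grouped[time_col].transform(lambda s: s.diff().median())
        is_first = dt.isna()
        out.loc[is_first, rate_col] = out.loc[is_first, cum_col] / dt_ref[is_first]
    elif first != "nan":
        raise ValueError(f"unknown first={first!r}")
    return out
```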

geoprior.utils.subsidence_utils.convert_eval_payload_units(payload, cfg=None, *, mode='si', scope='all', savefile=None, fmt='json', indent=2, copy_payload=True)[source]

Convert GeoPriorSubsNet evaluation-payload units for reporting.

This is a post-processing helper meant for stage-2 evaluation JSON payloads (e.g. geoprior_eval_phys_<timestamp>.json).

Parameters:
  • payload (Mapping[str, Any]) – The evaluation payload dict. It is expected to contain sections like metrics_evaluate, point_metrics, per_horizon, interval_calibration and censor_stratified.

  • cfg (mapping or module, optional) – The experiment config (e.g. config module or globals()). The helper reads SUBS_UNIT_TO_SI (or stage-1 provenance) and TIME_UNITS from this object when available.

  • mode ({"si", "interpretable"}, default "si") – "si" leaves values untouched. "interpretable" converts selected subsidence and physics metrics from SI into the native units implied by SUBS_UNIT_TO_SI.

  • scope ({"all", "subsidence", "physics"}, default "all") – Which parts to convert when mode="interpretable". "subsidence" converts only subsidence metrics such as MAE, MSE, and sharpness to the native unit. "physics" converts only unambiguous physics residual rates, currently epsilon_cons_raw and epsilon_gw_raw. "all" applies both conversions.

  • savefile (str, optional) – If provided, write the converted payload to this path.

  • fmt ({"json"}, default "json") – Output format when savefile is provided.

  • indent (int, default 2) – JSON indentation.

  • copy_payload (bool, default True) – If True, operate on a deep copy of payload. If False, convert in place, which mutates the caller's payload.

Returns:

Converted payload as a plain dict.

Return type:

dict

Notes

For subsidence metrics, linear quantities such as MAE and sharpness scale by 1 / SUBS_UNIT_TO_SI, while squared quantities such as MSE scale by (1 / SUBS_UNIT_TO_SI) ** 2.

When physics conversion is enabled, epsilon_cons_raw is treated as a rate in m/s and converted to <subs_native_unit>/<TIME_UNITS> (for example mm/yr), while epsilon_gw_raw is treated as a rate in 1/s and converted to 1/<TIME_UNITS>.

The helper records unit provenance under payload["units"].
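The scaling rules in these notes amount to simple arithmetic. The sketch below illustrates them under loudly stated assumptions: the function names and the flat metric keys ("mae", "mse", "sharpness") are hypothetical, and "yr" is taken to be a Julian year of 365.25 days, which may differ from the constant GeoPrior uses for TIME_UNITS.

```python
# Assumption: "yr" means a Julian year of 365.25 days.
SECONDS_PER_YEAR = 365.25 * 86400.0


def subsidence_metrics_to_native(metrics, unit_to_si=1e-3):
    """Hypothetical sketch: SI subsidence metrics -> native unit (e.g. mm)."""
    k = 1.0 / unit_to_si  # meters -> native unit
    out = dict(metrics)
    for key in ("mae", "sharpness"):  # linear quantities scale by k
        if key in out:
            out[key] = out[key] * k
    if "mse" in out:  # squared quantities scale by k**2
        out["mse"] = out["mse"] * k ** 2
    return out


def eps_cons_to_native_rate(eps_m_per_s, unit_to_si=1e-3):
    """epsilon_cons_raw: m/s -> <native unit>/yr, e.g. mm/yr."""
    return eps_m_per_s / unit_to_si * SECONDS_PER_YEAR


def eps_gw_to_per_year(eps_per_s):
    """epsilon_gw_raw: 1/s -> 1/yr."""
    return eps_per_s * SECONDS_PER_YEAR
```

With SUBS_UNIT_TO_SI = 1e-3, an MAE of 0.002 m becomes 2 mm, while an MSE of 4e-6 m² becomes 4 mm², which is why MSE needs the squared factor.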

geoprior.utils.subsidence_utils.clean_gwl_zsurf(df, *, lon_col='longitude', lat_col='latitude', z_col='z_surf', gwl_col='GWL_depth_bgs', gwl_scale='auto', max_depth_m=50.0, z_outlier_iqr_mult=3.0, knn_k=25, return_report=True, savefile=None)[source]

Fix GWL units, robustly clean z_surf, then recompute head_m.

Parameters:
Return type:

tuple[DataFrame, dict[str, Any] | None, str | None]

geoprior.utils.subsidence_utils.postprocess_eval_json(eval_json, *, cfg=None, scope='all', out_path=None, overwrite=False, add_rmse=True, force=False, indent=2)[source]

Post-hoc convert a Stage-2 evaluation JSON from SI to interpretable units.

This is a safe wrapper around convert_eval_payload_units(…) that:

  • loads a JSON file from disk (or accepts a payload dict),

  • infers unit factors from payload[“units”] when cfg is missing,

  • avoids double conversion unless force=True,

  • optionally adds RMSE fields (sqrt(MSE)),

  • writes the converted JSON file if out_path is provided.

Parameters:
  • eval_json (str | Mapping[str, Any]) – Either a file path to the saved JSON, or an in-memory mapping.

  • cfg (Mapping[str, object] | None) – Optional config mapping or module. If it is missing or incomplete, this helper synthesizes the minimal keys needed for conversion:

    • “SUBS_UNIT_TO_SI”

    • “TIME_UNITS”

  • scope (Literal['all', 'subsidence', 'physics']) – Forwarded to convert_eval_payload_units(…).

  • out_path (str | None) – If given, write the converted payload there. If a directory is given, a filename is generated next to the input name (or “geoprior_eval…”).

  • overwrite (bool) – If False and out_path already exists, an error is raised.

  • add_rmse (bool) – If True, add RMSE fields wherever MSE is present (metrics_evaluate, point_metrics, per_horizon).

  • force (bool) – If False, skip conversion when payload already declares an interpretable subsidence unit (e.g., “mm”). If True, always convert.

  • indent (int) – JSON indent used when writing.

Returns:

The converted payload (always returned as a dict).

Return type:

dict
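The add_rmse behavior can be pictured as a recursive walk that places an RMSE entry next to every MSE entry. This is a hypothetical sketch: add_rmse_sketch is an invented name, and the assumption that MSE keys end in "mse" (e.g. "mse", "subs_pred_mse") is not confirmed by the source.

```python
import math


def add_rmse_sketch(node):
    """Hypothetical sketch: add an "rmse" entry next to every numeric "mse" entry."""
    if isinstance(node, dict):
        for key in list(node):  # snapshot keys, because the dict is mutated
            value = node[key]
            if isinstance(value, (dict, list)):
                add_rmse_sketch(value)
            elif (
                isinstance(key, str)
                and key.endswith("mse")
                and not key.endswith("rmse")
                and isinstance(value, (int, float))
                and not isinstance(value, bool)
                and value >= 0
            ):
                # overall_mse -> overall_rmse, mse -> rmse
                node[key[: -len("mse")] + "rmse"] = math.sqrt(value)
    elif isinstance(node, list):
        for item in node:
            add_rmse_sketch(item)
    return node
```

Because existing "rmse" keys are skipped and recomputed values are identical, running the sketch twice on the same payload is harmless.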

This module is the workflow-side bridge back to the subsidence application. It exposes unit conversion helpers, cumulative/rate transforms, groundwater-resolution helpers, head-column selection, and coordinate construction.

Selected workflow helpers#

These functions are surfaced explicitly because they appear frequently in staged workflows, export paths, and figure or reproducibility scripts.

geoprior.utils.load_nat_config(root='nat.com')[source]

High-level helper used by NATCOM scripts.

Example

>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]

Return type:

dict[str, Any]

geoprior.utils.load_nat_config_payload(root='nat.com')[source]

Return the full config.json payload, including city, model and __meta__ fields.

This is convenient when you also want to see which hash or city/model are currently active.

Return type:

dict[str, Any]

geoprior.utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]

Build a tf.data.Dataset using NATCOM conventions.

Steps:

  1. ensure_input_shapes(…) for X.

  2. map_targets_for_training(…) for y.

  3. tf.data pipeline (shuffle/batch/prefetch).

  4. Optional finite checks (NPZ + tf batches).

Parameters:
  • X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.

  • y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.

  • batch_size (int) – Number of samples per batch.

  • shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.

  • mode (str) – Model mode passed to ensure_input_shapes().

  • forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().

  • check_npz_finite (bool) – If True, check the Xin/Yin NumPy arrays for NaN/Inf before building the dataset.

  • check_finite (bool) – If True, inserts assert_all_finite checks inside the tf.data pipeline.

  • scan_finite_batches (int) – If > 0, eagerly scan the first N batches right away so that non-finite values fail early.

  • dynamic_feature_names (list[str] | None) – If provided, used to name the bad channels when reporting problems in the dynamic feature tensor.

  • future_feature_names (list[str] | None) – If provided, used to name the bad channels when reporting problems in the future feature tensor.

  • seed (int)

  • drop_remainder (bool)

  • reshuffle_each_iter (bool)

  • prefetch (bool)

Returns:

Dataset of (X, y) pairs.

Return type:

tf.data.Dataset

Notes

TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
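The NPZ-side finite check (check_npz_finite) and the channel reporting enabled by the feature-name arguments can be illustrated without TensorFlow. This is a hypothetical NumPy sketch: check_npz_finite_sketch and its channel_names mapping are invented here, and the assumption that channels live on the last axis is not confirmed by the source.

```python
import numpy as np


def check_npz_finite_sketch(arrays, label="inputs", channel_names=None):
    """Hypothetical sketch: fail early if any float array holds NaN or Inf."""
    for key, arr in arrays.items():
        a = np.asarray(arr)
        if not np.issubdtype(a.dtype, np.floating):
            continue  # integer and boolean arrays cannot hold NaN/Inf
        bad = ~np.isfinite(a)
        if bad.any():
            msg = f"{label}[{key!r}] has {int(bad.sum())} non-finite value(s)"
            names = (channel_names or {}).get(key)
            if names is not None and a.ndim >= 2 and a.shape[-1] == len(names):
                # Assumption: channels are on the last axis; reduce over the rest.
                per_channel = bad.reshape(-1, a.shape[-1]).any(axis=0)
                msg += f"; channels: {[n for n, b in zip(names, per_channel) if b]}"
            raise ValueError(msg)
```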

geoprior.utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]

Standardise target dictionaries to the Keras compile keys.

This helper enforces a small convention used throughout the NATCOM training scripts:

  • Upstream sequence builders typically export raw targets with keys subsidence and gwl.

  • The GeoPrior model is compiled with targets named subs_pred and gwl_pred.

This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.

Parameters:
  • y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).

  • subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.

  • gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.

  • subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.

  • gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.

Returns:

New dictionary with keys subs_pred and gwl_pred.

Return type:

dict

Raises:

KeyError – If the dictionary does not contain either of the expected key pairs.
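The key convention can be summarized in a short pure-Python sketch. map_targets_sketch is a hypothetical name; which pair wins when both pairs are present is not stated in the source, so checking the compile keys first is an assumption.

```python
def map_targets_sketch(y_dict, subs_key="subsidence", gwl_key="gwl",
                       subs_pred_key="subs_pred", gwl_pred_key="gwl_pred"):
    """Hypothetical sketch of the documented key convention."""
    # Already in compile-key form: copy through (checked first, an assumption).
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        return {subs_pred_key: y_dict[subs_pred_key], gwl_pred_key: y_dict[gwl_pred_key]}
    # Raw exporter keys: rename them to the compile keys.
    if subs_key in y_dict and gwl_key in y_dict:
        return {subs_pred_key: y_dict[subs_key], gwl_pred_key: y_dict[gwl_key]}
    raise KeyError(
        f"expected ({subs_key!r}, {gwl_key!r}) or ({subs_pred_key!r}, {gwl_pred_key!r}); "
        f"got {sorted(y_dict)}"
    )
```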

geoprior.utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
Parameters:
  • cfg (dict)

  • scaler_info (dict)

  • target_name (str)

  • prefix (str)

  • unit_factor_key (str)

  • scale_key (str)

  • bias_key (str)

geoprior.utils.extract_preds(model, out, *, strict=True, output_names=None)[source]

Extract (subs_pred, gwl_pred) from GeoPrior outputs.

Supports:
  1. v3.2+ call(): {"subs_pred", "gwl_pred"}

  2. forward_with_aux(): (y_pred, aux)

  3. legacy: {“data_final”} + model.split_data_predictions

  4. predict(): list/tuple mapped via output names

If strict=True, list or tuple outputs must be mappable through output names; if they cannot be mapped, an error is raised rather than risking a silent swap of the subsidence and GWL predictions.

This helper normalizes the output interface across two GeoPrior generation families:

  1. New interface (preferred) model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}

  2. Legacy interface (backward compatible) model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.

Parameters:
  • model (object) –

    A Keras-like model instance that may expose split_data_predictions(data_final).

    The splitter must return a tuple:

    • subs_pred with shape (B, H, 1) or (B, H, Q, 1)

    • gwl_pred with shape (B, H, 1) or (B, H, Q, 1)

  • out (dict) –

    Output returned by the model call, typically model(inputs, training=False).

    Supported keys are either:

    • {"subs_pred", "gwl_pred"} (new interface), or

    • {"data_final"} (legacy interface).

  • strict (bool)

  • output_names (Sequence[str] | None)

Returns:

  • subs_pred (Tensor) – Predicted subsidence in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

  • gwl_pred (Tensor) – Predicted groundwater/head variable in model space.

    Expected shapes:

    • Point mode: (B, H, 1)

    • Quantile mode: (B, H, Q, 1)

Raises:
  • KeyError – If out does not contain a supported key set.

  • TypeError – If out is not a mapping/dict-like object.

Return type:

tuple[Any, Any]

Notes

This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.

The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].

Examples

New interface:

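from geoprior.utils import extract_preds

# model_inf returns {"subs_pred": ..., "gwl_pred": ...}
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)

Legacy interface:

# model_inf returns {"data_final": ...}; extract_preds splits it
# with model_inf.split_data_predictions, so the call is unchanged.
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
    model_inf,
    out,
)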

See also

subs_point_from_stage_out

Convert subsidence predictions to a point forecast.
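The four supported output shapes can be normalized by a small dispatch function. This is a hypothetical sketch of the logic described above, not the GeoPrior code: extract_preds_sketch is an invented name, reading model.output_names as a fallback is an assumption, and raising ValueError for unmappable sequences is an illustrative choice.

```python
from collections.abc import Mapping


def extract_preds_sketch(model, out, *, strict=True, output_names=None):
    """Hypothetical sketch of the output-normalization logic described above."""
    # forward_with_aux() style: (y_pred, aux) where y_pred is a mapping.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[0], Mapping):
        out = out[0]
    if isinstance(out, Mapping):
        # New interface: named outputs.
        if "subs_pred" in out and "gwl_pred" in out:
            return out["subs_pred"], out["gwl_pred"]
        # Legacy interface: let the model split the fused tensor.
        if "data_final" in out:
            return tuple(model.split_data_predictions(out["data_final"]))
        raise KeyError("expected {'subs_pred', 'gwl_pred'} or {'data_final'}")
    if isinstance(out, (list, tuple)):
        # predict() style: map positions to names before trusting the order.
        names = list(output_names or getattr(model, "output_names", None) or [])
        if len(names) == len(out) and {"subs_pred", "gwl_pred"} <= set(names):
            by_name = dict(zip(names, out))
            return by_name["subs_pred"], by_name["gwl_pred"]
        if strict:
            raise ValueError("list/tuple outputs cannot be mapped to output names")
        return out[0], out[1]  # positional fallback, only when strict=False
    raise TypeError(f"unsupported output type: {type(out).__name__}")
```

Mapping by name matters because Keras may order predict() outputs differently from the order in which a caller expects them; the name lookup avoids the silent swap that strict=True guards against.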

geoprior.utils.best_epoch_and_metrics(history, monitor='val_loss')[source]

Return the best epoch and metrics at that epoch.

Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:

  • The epoch index (0-based).

  • A dictionary mapping each metric name to its value at that epoch.

Parameters:
  • history (dict) – The history.history attribute from Keras training.

  • monitor (str, default "val_loss") – Name of the metric to minimise.

Returns:

  • best_epoch (int or None) – Index of the best epoch, or None if monitor is not present.

  • metrics_at_best (dict) – Mapping from metric name to its value at the best epoch. Empty if monitor is not present.

Return type:

tuple[int | None, dict]
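The selection rule above is an argmin over the monitored series. This is a hypothetical sketch: best_epoch_sketch is an invented name, skipping NaN entries is an added assumption, and ties resolve to the earliest epoch.

```python
import math


def best_epoch_sketch(history, monitor="val_loss"):
    """Hypothetical sketch: argmin of `monitor`, plus every metric at that epoch."""
    values = history.get(monitor)
    if not values:
        return None, {}
    # Skip NaN entries (an added assumption) so they can never be selected.
    finite = [(v, i) for i, v in enumerate(values) if not math.isnan(v)]
    if not finite:
        return None, {}
    _, best = min(finite)  # ties resolve to the earliest epoch
    metrics = {name: vals[best] for name, vals in history.items() if best < len(vals)}
    return best, metrics
```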

geoprior.utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]

Stage-1 audit covering:

  • raw df_train coordinate statistics (t/x/y) with heuristic unit detection,

  • statistics of the model-fed coordinates from inputs_train[“coords”] (flattened),

  • coordinate scaler min/max and coord_ranges,

  • SI channel sanity checks for physics columns (if present),

  • target array sanity checks,

  • the split of features into scaled ML, __si, and other columns.

A machine-readable JSON report is saved if save_dir is provided.

Parameters:
  • inputs_train (dict[str, Any])

  • targets_train (dict[str, Any])

  • coord_scaler (Any)

  • coord_ranges (dict[str, float] | None)

  • coord_mode (str)

  • coords_in_degrees (bool)

  • coord_epsg_used (Any)

  • coord_x_col_used (str)

  • coord_y_col_used (str)

  • x_col_used (str)

  • y_col_used (str)

  • time_col_used (str)

  • normalize_coords (bool)

  • keep_coords_raw (bool)

  • shift_raw_coords (bool)

  • subs_model_col (str | None)

  • gwl_dyn_col (str | None)

  • gwl_target_col (str | None)

  • h_field_col (str | None)

  • dynamic_features (Iterable[str] | None)

  • static_features (Iterable[str] | None)

  • future_features (Iterable[str] | None)

  • scaled_ml_numeric_cols (Iterable[str] | None)

  • main_scaler_path (str | None)

  • scaler_info (dict | None)

  • save_dir (str | None)

  • table_width (int)

  • title_prefix (str)

  • city (str)

  • model_name (str)

  • sample_rows (int)

Return type:

str | None

geoprior.utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]

geoprior.utils.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]

Fit and apply post-hoc interval calibration for quantile forecasts.

This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can

  1. detect whether evaluation forecasts already appear calibrated,

  2. fit interval-width correction factors from evaluation data,

  3. apply those factors to evaluation and/or future forecasts,

  4. compute before/after summary diagnostics on the evaluation set,

  5. optionally save the outputs to disk.

The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.

Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].

Parameters:
  • df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.

  • df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.

  • target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.

  • column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.

  • step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.

  • interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.

  • target_coverage (float, default 0.8) – Desired empirical coverage after calibration.

  • median_q (float, default 0.5) – Central quantile used as the expansion anchor.

  • use ({"auto", True, False}, default "auto") –

    Control flag for whether calibration is performed.

    • False disables calibration and returns inputs unchanged.

    • "auto" skips calibration when evaluation forecasts already look calibrated.

    • True forces calibration even if the automatic check would skip it.

  • tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.

  • f_max (float, default 5.0) – Maximum factor allowed during fitting.

  • max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.

  • keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.

  • enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.

  • overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.

  • calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.

  • factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.

  • factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.

  • save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.

  • save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.

  • save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.

  • verbose (int, default 1) – Verbosity level forwarded to logging helpers.

  • logger (logging.Logger or None, default None) – Optional logger used for progress messages.

Returns:

  • df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.

  • df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.

  • stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain

    • the target interval and target coverage,

    • the fitted or user-specified factors,

    • skip reasons,

    • evaluation summaries before and after calibration.

Return type:

tuple[DataFrame | None, DataFrame | None, dict[str, Any]]

Notes

In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This makes the wrapper conservative in repeated workflows, where the same tables may pass through the calibration stage more than once.

The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.

Examples

>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True

See also

fit_interval_factors_df

Fit per-horizon interval-width correction factors.

apply_interval_factors_df

Apply a scalar or per-horizon factor map to quantile forecasts.

References

For the broader role of calibrated probabilistic multi-horizon forecasting, see Lim et al. [1].

For uncertainty-rich forecasting in the present project ecosystem, see Kouadio et al. [2].
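The factor-fitting step (widen or narrow the interval around median_q until empirical coverage reaches target_coverage, searching up to f_max with max_iter bisection steps) can be sketched for a single horizon as follows. This is a hypothetical illustration: fit_interval_factor_sketch is an invented name, the lower search bound of 0 is an assumption, and the bracketing argument relies on non-crossing quantiles (q_lo <= q_med <= q_hi), so that coverage is non-decreasing in the factor.

```python
import numpy as np


def fit_interval_factor_sketch(y, q_lo, q_med, q_hi, target_coverage=0.8,
                               f_max=5.0, max_iter=32):
    """Hypothetical sketch: smallest width factor reaching the target coverage."""
    y, q_lo, q_med, q_hi = (np.asarray(a, dtype=float) for a in (y, q_lo, q_med, q_hi))

    def coverage(f):
        # Scale each half-width around the median anchor by f.
        lo = q_med - f * (q_med - q_lo)
        hi = q_med + f * (q_hi - q_med)
        return float(np.mean((y >= lo) & (y <= hi)))

    if coverage(f_max) < target_coverage:
        return f_max  # target not reachable within the allowed range
    lo_f, hi_f = 0.0, f_max
    # Coverage is non-decreasing in f, so bisection brackets the crossing.
    for _ in range(max_iter):
        mid = 0.5 * (lo_f + hi_f)
        if coverage(mid) >= target_coverage:
            hi_f = mid
        else:
            lo_f = mid
    return hi_f
```

Per-horizon fitting would apply the same search separately to each step_col group; a factor below 1 narrows an interval that is too wide.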

geoprior.utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]

Evaluate forecast diagnostics from an evaluation DataFrame.

This helper consumes the df_eval output from format_and_forecast() (or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, \(R^2\), coverage, and sharpness. It can also optionally evaluate metrics per forecast horizon and apply additional user-defined metrics.

By default it expects the following columns:

  • 'sample_idx'

  • 'forecast_step'

  • 'coord_t' (time)

  • Quantile or point-prediction columns for the target, e.g.:

    • Quantile mode: f'{target_name}_q10', f'{target_name}_q50', f'{target_name}_q90', …

    • Point mode: f'{target_name}_pred'.

  • Actual column: f'{target_name}_actual'.

A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:

column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}

or, for quantile columns:

column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}

Parameters:
  • eval_data (str, path-like, or pandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved by format_and_forecast()) or an in-memory DataFrame.

  • target_name (str, default 'subsidence') – Base name for the target columns. Used to infer default column names such as f'{target_name}_q10', f'{target_name}_pred', and f'{target_name}_actual'.

  • column_map (dict, optional) –

    Optional mapping to override default column names. The following keys are recognized:

    • 'sample_idx' : sample index column name (default 'sample_idx').

    • 'forecast_step' : horizon index column name (default 'forecast_step').

    • 'coord_t' : time coordinate column (default 'coord_t').

    • 'actual' : name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.

    • 'pred' : point prediction column for non-quantile mode, default f'{target_name}_pred'.

    • 'quantiles' :

      • If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).

      • If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.

  • quantile_interval (tuple of float, default (0.1, 0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.

  • per_horizon (bool, default False) – If True, compute per-horizon MAE/MSE/R² grouped by the forecast_step column.

  • extra_metrics (sequence of str or mapping, optional) –

    Optional additional metrics to compute.

    • If a sequence of strings (e.g. ['pss', 'pit']), each name is resolved via geoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.

    • If a mapping {name: func}, each func is called as:

      func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
      

      where y_pred is the median (Q50) or point forecast.

    For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.

  • extra_metric_kwargs (mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names in extra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.

  • savefile (str, path-like, or bool, optional) –

    If provided, metrics are saved to disk.

    • If True: a filename is auto-generated near eval_data (if it is a path) or in the current working directory.

    • If a string/path without extension: the extension is taken from save_format.

    • If a string/path with extension: that extension takes precedence over save_format.

  • save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') –

    Output format when savefile is truthy. JSON preserves nested structure; CSV is flattened into a tall table.

    • For JSON, the function returns the metrics dictionary.

    • For CSV, the function returns the metrics DataFrame.

  • time_as_str (bool, default True) – If True, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Optional logger instance used by vlog().

  • overall_key (str | None)

Returns:

results – If save_format is JSON (default), returns a dict:

  • Single time value:

    {
        "overall_mae": ...,
        "overall_mse": ...,
        "overall_r2": ...,
        "coverage80": ...,
        "sharpness80": ...,
        "per_horizon_mae": {1: ..., 2: ..., ...},
        ...
    }
    
  • Multiple time values:

    {
        "2021": { ...metrics... },
        "2022": { ...metrics... },
    }
    

If save_format is CSV, returns a DataFrame with flattened rows:

  • Columns include: coord_t, metric, horizon, and value.

Return type:

dict or pandas.DataFrame

Notes

  • Default metrics in quantile mode:

    • overall_mae, overall_mse, overall_r2

    • coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)

    If per_horizon=True, also:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).

  • Default metrics in point mode (no quantiles):

    • mae, mse, r2

    And optionally, if per_horizon=True:

    • per_horizon_mae, per_horizon_mse, per_horizon_r2.
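The interval metrics and the per-horizon grouping can be reproduced with NumPy and pandas. This is a hypothetical sketch: the function names are invented, and defining sharpness as the mean interval width (upper minus lower quantile) is an assumption that matches the common usage but is not confirmed by the source.

```python
import numpy as np
import pandas as pd


def coverage_and_sharpness(y_true, q_lo, q_hi):
    """Empirical interval coverage and mean interval width."""
    y, lo, hi = (np.asarray(a, dtype=float) for a in (y_true, q_lo, q_hi))
    coverage = float(np.mean((y >= lo) & (y <= hi)))
    sharpness = float(np.mean(hi - lo))  # assumption: sharpness = mean width
    return coverage, sharpness


def per_horizon_mae(df, target_name="subsidence", step_col="forecast_step"):
    """MAE of the median (Q50) forecast, grouped by forecast step."""
    err = (df[f"{target_name}_actual"] - df[f"{target_name}_q50"]).abs()
    return err.groupby(df[step_col]).mean().to_dict()
```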

geoprior.utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)[source]

Format PINN forecasts into evaluation and future DataFrames.

This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:

  • df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).

  • df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.

Parameters:
  • y_pred (dict) –

    Dictionary of model predictions: the output of GeoPriorSubsNet.predict, post-processed into {'subs_pred': ..., 'gwl_pred': ...}.

    For subsidence, the expected shapes are:

    • Quantile mode: (B, H, Q, O) where: B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.

    • Point mode: (B, H, O).

  • y_true (dict or None) –

    Dictionary of true targets, typically

    {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}.

    If None, the evaluation DataFrame is still created, but without the actual-value column.

  • coords (ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped (B, H, 3) with columns [t_scaled, x_scaled, y_scaled]. Only x and y are used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.

  • quantiles (list of float or None, optional) – List of quantiles (e.g. [0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. If None, a single prediction column is emitted instead.

  • target_name (str, default 'subsidence') –

    Logical target identifier used as the default key for locating the target scaler in scaler_info and as a fallback for resolving truth arrays in y_true.

    Column naming is controlled by output_target_name (or the auto-derived output prefix when it is None).

  • output_target_name (str or None, optional) –

    Output prefix used when creating DataFrame columns for predictions and actuals.

    This controls the column naming only (e.g. the function will emit f"{output_target_name}_q10", f"{output_target_name}_pred", and f"{output_target_name}_actual").

    If None (default), the function derives the output prefix from target_name and applies a small convenience rule: if target_name ends with "_cum" or "_cumulative", that suffix is stripped for output naming.

    This keeps downstream tooling consistent (many plotting and metrics utilities expect names like subsidence_q10 rather than subsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, with target_name="subsidence_cum" and output_target_name=None, output columns become subsidence_q10, subsidence_q50, and subsidence_actual. If output_target_name="subsidence_cum", the output columns keep the suffix such as subsidence_cum_q10.

  • scaler_target_name (str or None, optional) –

    Name used to locate the target scaling block inside scaler_info and to perform inverse-transform for predictions and actuals.

    This controls the scaler key and inverse scaling, not the output column naming.

    If None (default), the scaler key is assumed to be target_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.

    A common pattern is to keep target_name="subsidence_cum" so the scaler lookup matches the Stage-1 schema, while letting output_target_name=None produce clean output columns. In that setup, inverse transform still uses the subsidence_cum scaler key, while output columns use the subsidence_ prefix because of the auto-strip rule.

  • target_key_pred (str, default 'subs_pred') – Key inside y_pred that holds the subsidence forecasts.

  • component_index (int, default 0) – Index along the output dimension O to use when output_subsidence_dim > 1. For scalar subsidence this is 0.

  • scaler_info (dict, optional) – Optional Stage-1 scaler_info mapping containing a target scaler under keys such as 'targets' or 'target'. The target block is expected to provide an sklearn-like transformer under 'scaler' together with column names under 'columns' or 'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed for target_name.

  • coord_scaler (object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transform coord_x and coord_y when coords is given and coord_columns can be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.

  • coord_columns (tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.

  • train_end_time (scalar or str or datetime, optional) – Physical time associated with the evaluation year (e.g. 2022). If eval_forecast_step is not given, the last horizon step is assumed to correspond to this time.

  • forecast_start_time (scalar or str or datetime, optional) – First time in the future forecast horizon (e.g. 2023).

  • forecast_horizon (int, optional) – Number of forecast steps in the future horizon (e.g. 3). If future_time_grid is not given, this is used together with forecast_start_time to build a regular grid.

  • future_time_grid (array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be [2023, 2024, 2025]. If provided, it overrides any automatic construction from forecast_start_time and forecast_horizon.

  • eval_forecast_step (int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.

  • eval_export ({"all", "last"} or str or int or sequence, optional) –

    Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).

    Accepted values are:

    • "all" or "full" or "horizons" : export all horizons from df_eval_all.

    • "last" or "single" or "default" : export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).

    • Other str (e.g. "2022") : interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.

    • int or scalar non-string : interpreted as a single time value (e.g. 2022).

    • sequence of values (e.g. [2021, 2022]) : interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.

    If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.

  • value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) –

    Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.

    Supported modes are:

    • "rate" : keep per-step predictions as provided by the model (current behaviour).

    • "cumulative" or "cum" : convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.

    • "absolute_cumulative" or "abs_cum" or "absolute" : same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.

    Cumulative transforms are applied consistently to:

    • the future forecast DataFrame (df_future),

    • the multi-horizon evaluation DataFrame (df_eval_all),

    • and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).

    When an unsupported string is given, the function logs a warning and falls back to "rate".

  • absolute_baseline (float or Mapping[int, float], optional) –

    Baseline value to use when value_mode requests absolute cumulative outputs ("absolute_cumulative", "abs_cum", "absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence at train_end_time (e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.

    If a scalar float is provided, the same baseline value is added to all samples. If a mapping is provided, it must map sample_idx (integers) to baseline values, allowing per-sample baselines:

    • absolute_baseline = {sample_idx: baseline_value, ...}

    Only prediction columns for target_name are shifted (e.g. "subsidence_q10", "subsidence_q50", "subsidence_q90" or "subsidence_pred"). When df_eval_all is present, the corresponding "<target_name>_actual" column is shifted as well, so evaluation metrics operate on absolute cumulative values.

    If value_mode is an absolute cumulative variant but absolute_baseline is None, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).

  • sample_index_offset (int, default 0) – Offset added to sample_idx (useful when concatenating multiple tiles).

  • city_name (str, optional) – Optional metadata used only for logging.

  • model_name (str, optional) – Optional metadata used only for logging.

  • dataset_name (str, optional) – Optional metadata used only for logging.

  • csv_eval_path (str, optional) – If provided, df_eval is written to this path (directories are created if needed).

  • csv_future_path (str, optional) – If provided, df_future is written to this path.

  • time_as_datetime (bool, default False) – If True, time values are converted using pandas.to_datetime() with the provided time_format (if any).

  • time_format (str or None, optional) – Optional format string passed to pandas.to_datetime() when time_as_datetime=True.

  • eval_metrics (bool, default False) – If True, automatically call evaluate_forecast() on the resulting df_eval to compute diagnostics. Metrics are not returned by this function; they are either written to disk (if metrics_savefile is provided) or discarded. For programmatic access to the metrics dictionary, call evaluate_forecast() directly.

  • metrics_column_map (mapping, optional) – Optional column mapping forwarded to evaluate_forecast() (see its documentation for details). If None, default column names such as 'coord_t', 'forecast_step', f'{target_name}_q10', and f'{target_name}_actual' are assumed.

  • metrics_quantile_interval (tuple of float, default (0.1, 0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded to evaluate_forecast().

  • metrics_per_horizon (bool, default False) – If True, per-horizon MAE/MSE/R² are computed by evaluate_forecast() and included in the diagnostics.

  • metrics_extra (sequence or mapping, optional) –

    Optional additional metrics to compute, forwarded to evaluate_forecast(). Can be:

    • A sequence of metric names (resolved via geoprior.metrics._registry.get_metric).

    • A mapping {name: func} where func is a callable taking (y_true, y_pred, **kwargs).

  • metrics_extra_kwargs (mapping, optional) – Optional per-metric keyword arguments, forwarded to evaluate_forecast(). Keys must match metric names in metrics_extra.

  • metrics_savefile (str, path-like, bool, or None) – If truthy, diagnostics from evaluate_forecast() are written to disk. Behavior matches the savefile argument of evaluate_forecast(). When True, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.

  • metrics_save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format for diagnostics written by evaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.

  • metrics_time_as_str (bool, default True) – If True, time keys in the diagnostics written by evaluate_forecast() are converted to strings (useful for JSON serialization).

  • verbose (int, default 1) – Verbosity level passed to vlog().

  • logger (logging.Logger, optional) – Logger instance; if None, a module-level LOG is used.

  • input_value_mode (str)

  • rate_first (str)

  • calibration (str | bool)

  • calibration_kwargs (Mapping[str, Any] | None)

  • calibration_save_stats (str | PathLike | None)

  • output_unit (str | None)

  • output_unit_from (str)

  • output_unit_mode (str)

  • output_unit_suffix (str)

  • output_unit_col (str | None)
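The value_mode transforms can be pictured with a small pandas sketch. The helper `to_cumulative` is hypothetical. It assumes a tall DataFrame with `sample_idx` and `forecast_step` columns, applies the documented cumulative sum per sample, and then adds either a scalar baseline or a `{sample_idx: baseline}` mapping.

```python
import pandas as pd


def to_cumulative(df, value_cols, baseline=None):
    """Sketch of the documented "cumulative" / "absolute_cumulative" modes.

    Hypothetical helper, not format_and_forecast's implementation.
    """
    out = df.sort_values(["sample_idx", "forecast_step"]).copy()
    # Relative cumulative values: running sum over forecast steps per sample
    out[value_cols] = out.groupby("sample_idx")[value_cols].cumsum()
    if baseline is not None:
        if isinstance(baseline, dict):
            # Per-sample baselines; samples missing from the mapping become NaN
            shift = out["sample_idx"].map(baseline)
        else:
            shift = float(baseline)
        for col in value_cols:
            out[col] = out[col] + shift
    return out


steps = pd.DataFrame({
    "sample_idx": [0, 0, 0, 1, 1, 1],
    "forecast_step": [1, 2, 3, 1, 2, 3],
    "subsidence_q50": [1.0, 2.0, 3.0, 0.5, 0.5, 0.5],
})
relative = to_cumulative(steps, ["subsidence_q50"])
absolute = to_cumulative(steps, ["subsidence_q50"], baseline={0: 10.0, 1: 20.0})
```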

Returns:

  • df_eval_to_write (pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time(s) selected by eval_export. Columns include:

    • 'sample_idx'

    • 'forecast_step'

    • quantile columns (e.g. subsidence_q10) or subsidence_pred

    • 'subsidence_actual' (if y_true given)

    • coord_t, coord_x, coord_y (names from coord_columns).

  • df_future (pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure as df_eval but without the actual-value column.

Return type:

tuple[DataFrame, DataFrame]
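A sketch of how a quantile-mode array of shape (B, H, Q, O) maps onto the tall layout of the returned DataFrames. `quantiles_to_frame` is a hypothetical helper that reproduces only the `sample_idx`, `forecast_step`, and `<prefix>_qNN` columns. Coordinates, inverse scaling, and actuals are deliberately left out, and the quantile-to-suffix rule (0.1 becomes q10) is an assumption based on the column names above.

```python
import numpy as np
import pandas as pd


def quantiles_to_frame(subs_pred, quantiles, prefix="subsidence",
                       component_index=0, sample_index_offset=0):
    """Sketch: flatten a (B, H, Q, O) array into one row per (sample, step)."""
    arr = np.asarray(subs_pred)[..., component_index]  # (B, H, Q, O) -> (B, H, Q)
    if arr.ndim != 3 or arr.shape[2] != len(quantiles):
        raise ValueError("expected (B, H, Q, O) with Q == len(quantiles)")
    n_samples, horizon = arr.shape[:2]
    frame = {
        # Row order is sample-major: all steps of sample 0, then sample 1, ...
        "sample_idx": np.repeat(np.arange(n_samples), horizon) + sample_index_offset,
        # forecast_step is 1-based, like eval_forecast_step
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    }
    for k, q in enumerate(quantiles):
        frame[f"{prefix}_q{int(round(q * 100))}"] = arr[:, :, k].reshape(-1)
    return pd.DataFrame(frame)


pred = np.arange(2 * 3 * 3 * 1, dtype=float).reshape(2, 3, 3, 1)
df_tall = quantiles_to_frame(pred, [0.1, 0.5, 0.9])
```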

Notes

This function separates scaler lookup (scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like "_cum" but downstream tools expect canonical names such as columns prefixed with subsidence_.
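The separation between scaler lookup and column naming can be expressed as a small function. `resolve_names` is hypothetical and encodes only the rules stated in the parameter descriptions: strip a trailing `_cum` or `_cumulative` when output_target_name is None, and fall back to target_name for the scaler key.

```python
def resolve_names(target_name, output_target_name=None, scaler_target_name=None):
    """Return (output_prefix, scaler_key) following the documented rules.

    Hypothetical helper for illustration only.
    """
    if output_target_name is None:
        output_prefix = target_name
        # Convenience rule: drop a trailing "_cumulative" or "_cum" suffix
        for suffix in ("_cumulative", "_cum"):
            if output_prefix.endswith(suffix):
                output_prefix = output_prefix[: -len(suffix)]
                break
    else:
        # An explicit output name is used verbatim, suffix included
        output_prefix = output_target_name
    # The scaler key defaults to the true target key, independent of naming
    scaler_key = scaler_target_name if scaler_target_name is not None else target_name
    return output_prefix, scaler_key


prefix, key = resolve_names("subsidence_cum")
quantile_col = f"{prefix}_q10"
```

With target_name="subsidence_cum", the inverse transform would still look up the subsidence_cum scaler while the columns are named subsidence_q10 and so on.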

geoprior.utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)[source]

Cluster 2D spatial data in df using <algorithm> and optionally plot the results.

This function, <create_spatial_clusters>, extracts two coordinate columns from <df> to form clusters via methods such as ‘kmeans’, ‘dbscan’, or ‘agglo’ (agglomerative). It uses the function filter_valid_kwargs (when relevant) to strip out invalid parameters for certain estimators, and writes cluster labels into <cluster_col>.

Parameters:
  • df (pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.

  • spatial_cols (list of str, optional) – Two-column list for x and y coordinates. Defaults to ['longitude','latitude'] if None.

  • cluster_col (str, default 'region') – Name of the column to store the assigned cluster labels.

  • n_clusters (int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.

  • algorithm (str, default 'kmeans') – Choice of clustering algorithm among ['kmeans','dbscan','agglo'].

  • view (bool, default True) – If True, displays a scatterplot of the final clusters.

  • figsize (tuple, default (14, 10)) – Size of the displayed figure for the cluster plot.

  • s (int, default 60) – Marker size in the scatterplot.

  • plot_style (str, default 'seaborn') – Matplotlib style used for the plot.

  • cmap (str, default 'tab20') – Colormap name used to differentiate clusters.

  • show_grid (bool, default True) – Toggles grid lines on or off.

  • grid_props (dict, optional) – Additional keyword arguments controlling the grid style.

  • auto_scale (bool, default True) – If True, standardize coordinates before clustering.

  • savefile (str, optional) – If provided, file path where the clustered DataFrame, including the <cluster_col> labels, is saved.

  • verbose (int, default 1) – Controls console logs. Higher values yield more details about scaling and cluster detection.

  • **kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, and AgglomerativeClustering).

Returns:

A copy of <df> with an additional <cluster_col> storing the assigned cluster labels.

Return type:

pandas.DataFrame

Notes

If <auto_scale> is True, the coordinate columns are normalized with a standard scaler before clustering. The scatterplot is drawn with seaborn for enhanced styling.

By default, for <algorithm> = “kmeans”, the model attempts to minimize:

\[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2 \tag{43}\]

where \(x_i\) are the (scaled or raw) 2D coordinates in <df> and \(\mu_j\) are the cluster centroids. If n_clusters is not provided, the function can auto-detect it using silhouette and elbow analyses.
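To make the objective in (43) concrete, here is a minimal NumPy sketch of Lloyd's algorithm on the four points used in the example below. It is an illustration only. create_spatial_clusters delegates to a KMeans estimator, whereas this sketch works on raw (unscaled) coordinates, uses a naive initialization (the first k points), and runs a fixed number of iterations.

```python
import numpy as np


def kmeans_objective(X, centroids):
    """J = sum_i min_j ||x_i - mu_j||^2, as in (43)."""
    # Squared distances from every point to every centroid: shape (N, k)
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return float(d2.min(axis=1).sum())


def lloyd(X, k, n_iter=10):
    """Minimal Lloyd iterations with a naive init (the first k points)."""
    centroids = X[:k].astype(float).copy()
    for _ in range(n_iter):
        # Assignment step: nearest centroid for each point
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        # Update step: move each non-empty cluster's centroid to its mean
        for j in range(k):
            if np.any(labels == j):
                centroids[j] = X[labels == j].mean(axis=0)
    return labels, centroids


X = np.array([[0.1, 1.0], [0.2, 1.1], [2.2, 2.1], [2.3, 2.2]])
labels, centroids = lloyd(X, k=2)
J = kmeans_objective(X, centroids)
```

Each step can only keep J the same or lower it, which is why Lloyd's algorithm converges to a local minimum of (43).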

Examples

>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )

See also

filter_valid_kwargs

Helps discard unsupported keyword arguments for chosen estimators.

geoprior.utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)[source]

Draw a stratified sample of spatial data that represents the distribution of the whole area and covers different years.

This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.

Parameters:
  • data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., ‘longitude’, ‘latitude’) and any columns specified in stratify_by.

  • sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).

  • stratify_by (list of str, optional) – List of column names to stratify by.

  • spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.

  • spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named ‘longitude’ and/or ‘latitude’ in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.

  • method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.

  • min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.

  • random_state (int, optional) – Random seed for reproducibility. Default is 42.

  • verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.

Returns:

sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.

Return type:

pandas.DataFrame

Notes

The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.

Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:

\[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil \tag{44}\]

where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).

The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
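The binning and allocation steps described in these notes can be sketched with pandas. `stratified_allocation` is a hypothetical helper: it bins each spatial column with pandas.qcut, crosses the bins with the stratification columns, and applies \(n_i = \lceil N_i n / N \rceil\) per group. Because every group is rounded up, the allocated total can slightly exceed \(n\).

```python
import math

import numpy as np
import pandas as pd


def stratified_allocation(data, spatial_cols, stratify_by, n, bins=10):
    """Sketch of n_i = ceil(N_i / N * n) over spatial-bin x stratum groups."""
    keys = [
        # Quantile bins give roughly equal counts per bin; repeated edges dropped
        pd.qcut(data[col], q=bins, labels=False, duplicates="drop")
        for col in spatial_cols
    ]
    keys += [data[col] for col in stratify_by]
    group_sizes = data.groupby(keys).size()  # N_i for each non-empty group
    total = len(data)  # N
    return group_sizes.apply(lambda n_i: math.ceil(n_i / total * n))


data = pd.DataFrame({
    "longitude": np.arange(8.0),
    "latitude": np.arange(8.0),
    "year": [2021] * 4 + [2022] * 4,
})
alloc = stratified_allocation(data, ["longitude", "latitude"], ["year"], n=3, bins=2)
```

In this toy case each of the two groups holds 4 of 8 rows, so each receives \(\lceil 1.5 \rceil = 2\) draws, for a total of 4 rather than 3.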

Examples

>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> import pandas as pd
>>> # Assume 'df' is a pandas DataFrame with columns
>>> # 'longitude', 'latitude', 'year', and other data.
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)

See also

pandas.qcut

Quantile-based discretization function used for binning.

sklearn.model_selection.StratifiedShuffleSplit

For stratified sampling.

batch_spatial_sampling

Resample spatial data with batching.

geoprior.utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]

Build coords tensor (t, x, y) with OPTIONAL shifting (translation only).

This is designed for workflows that keep coordinates unnormalized:
  • values stay in SI units (years and meters),

  • but huge UTM magnitudes (e.g. 3e5, 2.5e6) are kept out of the coordinate MLPs by shifting x, y (and optionally t).

Notes

  • This does NOT min-max scale to [0,1]. It only translates.

  • coord_mins/coord_ranges are still returned because they are useful for logging and debugging.
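A minimal NumPy sketch of min-shifting (translation only). `shift_coords` is hypothetical and does not reproduce the CoordsPack return type. It only illustrates that each axis is offset by its minimum, that no rescaling happens, and that the mins and ranges are kept for logging.

```python
import numpy as np


def shift_coords(t, x, y, dtype="float32"):
    """Translate t, x, y by their minima (no rescaling) and stack them."""
    cols = [np.asarray(v, dtype="float64") for v in (t, x, y)]
    mins = np.array([c.min() for c in cols])
    ranges = np.array([c.max() - c.min() for c in cols])
    # Subtract in float64 before casting so large UTM values keep precision
    coords = np.stack([c - m for c, m in zip(cols, mins)], axis=-1).astype(dtype)
    return coords, mins, ranges


coords, mins, ranges = shift_coords(
    [2015, 2016, 2017],
    [300000.0, 300050.0, 300100.0],
    [2500000.0, 2500010.0, 2500020.0],
)
```

The shifted values stay in years and meters; only their origin moves.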

Return type:

CoordsPack

Model-facing utility package#

The model-side utility package complements the workflow surface and is closer to data layout, model I/O contracts, sequence generation, and PINN-oriented helper logic.

Public exports for model utility helpers.

geoprior.models.utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]#

Compute anomaly scores for given true targets using various methods.

This utility function, compute_anomaly_scores, provides a flexible way to compute anomaly scores outside the XTFT model itself. Anomaly scores indicate how unusual individual observations are and can guide the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.

Parameters:
  • y_true (np.ndarray) –

    The ground truth target values with shape (B, H, O), where:

    • B: batch size

    • H: number of forecast horizons (time steps ahead)

    • O: output dimension (e.g., number of target variables).

    Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.

  • y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to ‘residual’, the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.

  • method (str, optional) –

    The method used to compute anomaly scores. Supported options:

    • "statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores.

      Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:

      \[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2 \tag{45}\]

      where \(\varepsilon\) is a small constant for numerical stability.

    • "domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.

    • "isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.

    • "residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:

      \[\text{MSE: } (y_{\text{true}} - y_{\text{pred}})^2 \tag{46}\]

    Default is "statistical".

  • threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.

  • domain_func (callable, optional) –

    A user-defined function for the domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the following default heuristic is used:

    \[\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases} \tag{47}\]

  • contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.

  • epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.

  • estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.

  • random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.

  • residual_metric (str, optional) –

    The metric used to compute anomalies from residuals if method is set to ‘residual’. Supported metrics:

    • "mse": mean squared error per point (residuals**2)

    • "mae": mean absolute error per point |residuals|

    • "rmse": root mean squared error sqrt((residuals**2))

    Default is "mse".

  • objective (str, optional) – Type of objective, reserved for future extensibility to tasks other than forecasting. Default is "ts" (time series).

  • verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.

Returns:

anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.

Return type:

np.ndarray

Notes

Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
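The simpler scoring rules can be written directly from the formulas in the parameter descriptions. `simple_anomaly_scores` is a hypothetical sketch, not compute_anomaly_scores itself. It assumes \(\mu\) and \(\sigma\) are taken over the whole y_true array and omits the isolation-forest path.

```python
import numpy as np


def simple_anomaly_scores(y_true, y_pred=None, method="statistical",
                          residual_metric="mse", epsilon=1e-6):
    """Sketch of the statistical, default-domain and residual scores."""
    y_true = np.asarray(y_true, dtype=float)
    if method in ("statistical", "stats"):
        # Squared standardized deviation, as in (45)
        mu, sigma = y_true.mean(), y_true.std()
        return ((y_true - mu) / (sigma + epsilon)) ** 2
    if method == "domain":
        # Default heuristic (47): 10 * |y| for negative values, else 0
        return np.where(y_true < 0, np.abs(y_true) * 10.0, 0.0)
    if method == "residual":
        residuals = y_true - np.asarray(y_pred, dtype=float)
        if residual_metric == "mse":
            return residuals ** 2
        if residual_metric == "mae":
            return np.abs(residuals)
        if residual_metric == "rmse":
            return np.sqrt(residuals ** 2)
    raise ValueError(f"unsupported method/metric: {method}/{residual_metric}")


y = np.array([-2.0, 0.0, 2.0]).reshape(1, 3, 1)  # (B, H, O)
stat = simple_anomaly_scores(y)
```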

Examples

>>> from geoprior.models.utils import compute_anomaly_scores
>>> import numpy as np
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B,H,O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain',
...                                 domain_func=my_heuristic)
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest',
...                                 contamination=0.1)
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual',
...                                 residual_metric='mae')

See also

geoprior.models.losses.objective_loss

For integrating anomaly scores into a multi-objective loss.

geoprior.models.utils.compute_forecast_horizon(data=None, dt_col=None, start_pred=None, end_pred=None, error='raise', verbose=1)[source]#

Compute the forecast horizon for time series forecasting models.

This function calculates the number of future time steps (forecast_horizon) a model should predict, based on the provided data or the specified prediction dates. It infers the frequency of the data when possible and computes the horizon accordingly, accepting several datetime formats and input types.

Parameters:
  • data (pandas.DataFrame, pandas.Series, list, or numpy.ndarray, optional) – The dataset containing datetime information. If a pandas.DataFrame is provided, the dt_col parameter must be specified to indicate which column contains the datetime data. For pandas.Series, list, or numpy.ndarray, the function attempts to infer the frequency directly.

  • dt_col (str, optional) – The name of the column in data that contains datetime information. This parameter is required if data is a pandas.DataFrame. Example: dt_col='timestamp'

  • start_pred (str, int, or datetime-like) – The starting point for forecasting. This can be a date string (e.g., ‘2023-04-10’), a datetime object, or an integer representing a year (e.g., 2024). If an integer is provided, it is interpreted as a year, and a warning is issued to inform the user of this interpretation.

  • end_pred (str, int, or datetime-like) – The ending point for forecasting. Similar to start_pred, this can be a date string, a datetime object, or an integer representing a year. The function calculates the forecast horizon based on the difference between start_pred and end_pred.

  • error ({'raise', 'warn', 'ignore'}, default 'raise') –

    Defines the error handling behavior when encountering issues such as invalid input types, missing date columns, or unparseable dates.

    • ’raise’: Raises a ValueError when an error is encountered.

    • ’warn’: Emits a warning and attempts to proceed with default behavior.

  • verbose (int, default 1) –

    Controls the level of verbosity for debug information.

    • 0: No output.

    • 1: Minimal output (e.g., starting message).

    • 2: Intermediate output (e.g., detected dates, computed horizons).

    • 3: Detailed output (e.g., types of predictions, inferred frequencies).

Returns:

The computed forecast_horizon representing the number of steps ahead the model should predict. Returns None if an error occurs and error is set to ‘warn’.

Return type:

int or None

Raises:

ValueError – If invalid parameters are provided and error is set to ‘raise’.

Examples

>>> from geoprior.models.utils import compute_forecast_horizon
>>> import pandas as pd
>>> import numpy as np
>>> from datetime import datetime, timedelta
>>>
>>> # Example 1: Using a DataFrame with a Date Column
>>> df = pd.DataFrame({
...     'date': pd.date_range(start='2023-01-01', periods=100, freq='D'),
...     'value': np.random.randn(100)
... })
>>> horizon = compute_forecast_horizon(
...     data=df,
...     dt_col='date',
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=3
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 2: Using a List of Datetimes
>>> dates = [datetime(2023, 1, 1) + timedelta(days=i) for i in range(100)]
>>> horizon = compute_forecast_horizon(
...     data=dates,
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='warn',
...     verbose=2
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 3: Handling Integer Years
>>> horizon = compute_forecast_horizon(
...     start_pred=2024,
...     end_pred=2030,
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 7
>>> # Example 4: Without Providing Data (Assuming Frequency Based on Prediction Dates)
>>> horizon = compute_forecast_horizon(
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11

Notes

  • When data is not provided, the function relies solely on the difference between start_pred and end_pred to compute the forecast horizon. In such cases, if the frequency cannot be inferred, the horizon is calculated based on the largest possible time unit (years, months, weeks, days).

  • If start_pred is after end_pred, the function returns 0 and issues a warning or raises an error based on the error parameter.

  • The function attempts to infer the frequency of the data using pandas utilities. If the frequency cannot be inferred, it defaults to calculating the horizon based on the time difference in the most significant unit.
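The documented examples count both endpoints: 2023-04-10 to 2023-04-20 is 10 days apart but yields a horizon of 11, and 2024 to 2030 yields 7. The stdlib-only sketch below reproduces that inclusive counting for a few fixed units. inclusive_horizon is a hypothetical helper written for illustration; it is not compute_forecast_horizon and performs no frequency inference.

```python
from datetime import date


def inclusive_horizon(start, end, unit="D"):
    """Hypothetical helper: number of forecast steps from start to end, inclusive.

    Integers are treated as years; strings must be ISO dates (YYYY-MM-DD).
    """
    if isinstance(start, int) and isinstance(end, int):
        # Integer inputs are interpreted as years, as documented above.
        return max(end - start + 1, 0)
    start_d, end_d = date.fromisoformat(start), date.fromisoformat(end)
    if end_d < start_d:
        # Mirrors the documented behaviour of returning 0 when start is after end.
        return 0
    if unit == "D":
        return (end_d - start_d).days + 1
    if unit == "W":
        return (end_d - start_d).days // 7 + 1
    if unit == "M":
        return (end_d.year - start_d.year) * 12 + (end_d.month - start_d.month) + 1
    if unit == "Y":
        return end_d.year - start_d.year + 1
    raise ValueError(f"unsupported unit: {unit!r}")


print(inclusive_horizon("2023-04-10", "2023-04-20"))  # 11
print(inclusive_horizon(2024, 2030))                  # 7
```

The "+ 1" in every branch is what makes the count inclusive of both start_pred and end_pred.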

See also

pandas.date_range

Generates a fixed frequency DatetimeIndex.

pandas.infer_freq

Infers the frequency of a DatetimeIndex.

geoprior.models.utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)[source]#

Create input sequences and corresponding targets for time series forecasting.

The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.

See more in User Guide.

Parameters:
  • df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.

  • sequence_length (int) – The number of past time steps to include in each input sequence.

  • target_col (str) – The name of the target column.

  • step (int, default 1) – The step size between the starts of consecutive sequences.

  • include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.

  • drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.

  • forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.

  • verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).

Returns:

A tuple containing:
  • sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).

  • targets:
    • If forecast_horizon is None: Array of target values with shape (num_sequences,).

    • If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If the DataFrame df does not contain the target_col.

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences
>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })
>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)
>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)

Notes

  • Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.

  • Forecast Horizon:
    • If forecast_horizon is None, the function creates targets for a single future time step.

    • If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.

  • Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.

  • Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.

  • Data Validation: The function uses are_all_frames_valid from geoprior.core.checks to check the integrity of the input DataFrame before processing, and exist_features to verify that the target column is present.

Sequence generation can be expressed as follows, for each sequence \(i\):

\[
\begin{aligned}
\mathbf{X}^{(i)} &= \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right], \\
y^{(i)} &= \begin{cases}
\mathbf{x}_{i+T} & \text{if } \text{forecast\_horizon} = \text{None}, \\
\left[ \mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1} \right] & \text{if } \text{forecast\_horizon} = H,
\end{cases}
\end{aligned}
\tag{48}
\]
Where:
  • \(\mathbf{X}^{(i)}\) is the \(i\)-th input sequence of length \(T\) (sequence_length), and \(\mathbf{x}_t\) is the row at time step \(t\).

  • \(y^{(i)}\) is the target value(s) following the sequence; only the target_col entry of each row is used, which is why targets have shape (num_sequences,) or (num_sequences, forecast_horizon).
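The windowing in equation (48) can be sketched with NumPy as below. sliding_windows is a hypothetical helper, not create_sequences. It keeps only windows whose targets fit inside the series, giving n - T - H + 1 windows (H = 1 when forecast_horizon is None). The create_sequences example above reports slightly fewer windows (95 and 92), so its boundary handling evidently differs; treat this strictly as an illustration of the formula.

```python
import numpy as np


def sliding_windows(values, target, T, H=None, step=1):
    """Illustrative sliding-window builder following equation (48).

    values: (n, num_features) array; target: (n,) array; T: window length;
    H: forecast horizon (None means a single next-step target).
    """
    horizon = 1 if H is None else H
    n = len(values)
    # Only keep windows whose targets fit entirely inside the series.
    starts = list(range(0, n - T - horizon + 1, step))
    X = np.stack([values[i:i + T] for i in starts])
    if H is None:
        y = np.array([target[i + T] for i in starts])
    else:
        y = np.stack([target[i + T:i + T + H] for i in starts])
    return X, y


data = np.random.rand(100, 4)  # 3 features + target, like the example above
X, y = sliding_windows(data, data[:, -1], T=4)
print(X.shape, y.shape)        # (96, 4, 4) (96,)
X, y = sliding_windows(data, data[:, -1], T=4, H=3)
print(X.shape, y.shape)        # (94, 4, 4) (94, 3)
```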

See also

geoprior.models.utils.split_static_dynamic

Function to split sequences into static and dynamic inputs.

geoprior.models.utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]#

Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.

Parameters:
  • dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.

  • num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.

  • agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.

  • errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.

Returns:

If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.

Return type:

Union[List[Tuple[Any, ...]], Optional[Tuple[Any, ...]]]

Raises:
  • TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors='raise').

  • ValueError – If num_batches_to_extract is negative, or if fewer batches are available than a specific requested number (and errors='raise').

  • RuntimeError – For unexpected errors during dataset iteration (and errors='raise').
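With agg=True, corresponding elements of every extracted batch are merged into one structure. Assuming the batches have already been converted to NumPy (for example via tf.data.Dataset.as_numpy_iterator), that merge can be sketched as below. aggregate_batches is a hypothetical name; the real function may handle nested or ragged structures differently.

```python
import numpy as np


def aggregate_batches(batches):
    """Concatenate corresponding elements of each batch tuple along axis 0 (sketch)."""
    if not batches:
        return None  # matches the documented agg=True result for zero batches
    aggregated = []
    for parts in zip(*batches):  # parts: the k-th element of every batch
        if isinstance(parts[0], dict):
            # Dictionaries are merged key by key.
            aggregated.append(
                {key: np.concatenate([p[key] for p in parts], axis=0) for key in parts[0]}
            )
        else:
            aggregated.append(np.concatenate(parts, axis=0))
    return tuple(aggregated)


batch_a = ({"static": np.ones((2, 3))}, np.zeros((2, 1)))
batch_b = ({"static": np.ones((5, 3))}, np.zeros((5, 1)))
inputs, targets = aggregate_batches([batch_a, batch_b])
print(inputs["static"].shape, targets.shape)  # (7, 3) (7, 1)
```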

geoprior.models.utils.extract_callbacks_from(fit_params, return_fit_params=False)[source]#

Extract Keras callbacks from a dictionary of fit parameters. The function scans the provided fit_params dictionary, looks for keys associated with callback instances, and removes them from fit_params, returning a list of callbacks. Optionally, it can also return the updated dictionary without these callbacks.

This function is particularly useful when working with scikit-learn-style estimators that pass parameters through **fit_params. Extracting callbacks directly makes it straightforward to integrate TensorFlow/Keras callbacks such as EarlyStopping or ModelCheckpoint into training pipelines.

The function handles two scenarios:

  1. A parameter called 'callbacks' containing a list of callbacks.

  2. Individual callback instances passed as keyword arguments.

After extraction, if return_fit_params=True, it returns a tuple (callbacks, fit_params) where callbacks is the extracted list and fit_params is the remaining dictionary. Otherwise, it returns only callbacks.

\[n_{\mathrm{callbacks}} = n_{\mathrm{list\_callbacks}} \;+\; n_{\mathrm{kwarg\_callbacks}} \tag{49}\]

Here, \(n_{\mathrm{callbacks}}\) represents the total number of extracted callbacks, \(n_{\mathrm{list\_callbacks}}\) is the number of callbacks initially found in the 'callbacks' parameter, and \(n_{\mathrm{kwarg\_callbacks}}\) is the number of callbacks discovered among the other keyword arguments.
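The two extraction scenarios and the count in equation (49) can be sketched in pure Python. To keep the sketch runnable without TensorFlow, Callback and EarlyStopping below are stand-in classes (an assumption, not the Keras ones), and split_callbacks is a hypothetical re-implementation rather than extract_callbacks_from itself.

```python
class Callback:
    """Stand-in for tf.keras.callbacks.Callback so the sketch runs without TensorFlow."""


class EarlyStopping(Callback):
    """Stand-in for tf.keras.callbacks.EarlyStopping."""


def split_callbacks(fit_params, return_fit_params=False):
    """Hypothetical sketch of the two extraction scenarios."""
    params = dict(fit_params)  # work on a copy so the caller's dict is not modified
    # Scenario 1: an explicit 'callbacks' list (contributes n_list_callbacks).
    callbacks = list(params.pop("callbacks", None) or [])
    # Scenario 2: callback instances under arbitrary keyword names (n_kwarg_callbacks).
    for key in [k for k, v in params.items() if isinstance(v, Callback)]:
        callbacks.append(params.pop(key))
    return (callbacks, params) if return_fit_params else callbacks


fit_params = {"epochs": 100, "callbacks": [EarlyStopping()], "early_stopping": EarlyStopping()}
callbacks, remaining = split_callbacks(fit_params, return_fit_params=True)
print(len(callbacks), remaining)  # 2 {'epochs': 100}
```

Here n_callbacks = 1 + 1 = 2: one callback from the 'callbacks' list and one found among the other keyword arguments.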

Parameters:
  • fit_params (dict) –

    The dictionary of parameters to be passed to a model’s training method. May contain one of the following:

    • 'callbacks': a list of callback instances.

    • Arbitrary keyword arguments that are callback instances.

  • return_fit_params (bool, optional) – If True, returns a tuple of (callbacks, fit_params) where fit_params no longer contains the extracted callbacks. If False, returns only the callbacks list. Default is False.

Returns:

  • callbacks (list of tf.keras.callbacks.Callback) – A list of extracted callback instances.

  • (callbacks, fit_params) (tuple, only if return_fit_params=True) – A tuple whose first element is the list of extracted callbacks and whose second element is the updated fit_params dictionary with the callbacks removed.

Examples

>>> from geoprior.models.utils import extract_callbacks_from
>>> from tensorflow.keras.callbacks import EarlyStopping
>>> fit_params = {
...     'epochs': 100,
...     'batch_size': 64,
...     'verbose': 1,
...     'early_stopping': EarlyStopping(patience=5)
... }
>>> callbacks, updated_params = extract_callbacks_from(fit_params,
...                                                    return_fit_params=True)
>>> print(callbacks)
[<keras.src.callbacks.EarlyStopping object at 0x...>]
>>> print(updated_params)
{'epochs': 100, 'batch_size': 64, 'verbose': 1}

Notes

Consider using this function when integrating Keras callbacks within a pipeline or estimator that follows scikit-learn conventions, where parameters are passed as fit_params. This approach enables a clean and modular integration of callbacks into your training loops.

See also

tf.keras.callbacks.Callback

Base class used to build new callbacks.

geoprior.models.utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#

Generate a multi-step forecast using the XTFT model.

This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:

\[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{50}\]

for \(i = 1, \dots, H\), where \(H\) is forecast_horizon and \(f\) is the trained XTFT model.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.

  • forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.

  • time_steps (int, optional) – The number of historical time steps used as input. Default is 3.

  • q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.

  • apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.

Returns:

A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60,
...                          freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # X_static holds the spatial coordinates "longitude" and "latitude"
>>> # in its first two columns (required when spatial_cols is given),
>>> # followed by the static feature "stat1".
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
>>> # Target output from "subsidence" reshaped to
>>> # (60, 1, 1). For multi-step forecast, forecast_horizon is 4.
>>> forecast_horizon = 4
>>> y_array = train_df["subsidence"].values.reshape(60, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=3,    # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())

For a point forecast, rebuild the model without quantiles and reuse the same inputs:

>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=3,    # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None, # set quantiles to None
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...    optimizer="adam", loss="mse",
...    )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())

Notes

  • In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.

  • In point mode, a single prediction is generated per forecast step.

  • The output prediction array is expected to have the shape \((n, forecast\_horizon, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).

  • The provided spatial_cols must correspond to the first two columns of the original training data’s X_static.

  • Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.

  • The DataFrame is constructed by iterating over each sample and each forecast step.
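The column-naming pattern described under Returns (<tname>_q10_step1, <tname>_pred_step1, ...) can be derived from an (n, forecast_horizon, m) prediction array as sketched below. wide_forecast_columns is a hypothetical helper; the column order produced by forecast_multi_step may differ, and spatial, time, and actual columns are omitted here.

```python
import numpy as np


def wide_forecast_columns(preds, tname="target", quantiles=None):
    """Map an (n, H, m) prediction array to the documented column names (sketch)."""
    _, horizon, m = preds.shape
    columns = {}
    for h in range(horizon):
        step = h + 1  # forecast steps are 1-based in the column names
        if quantiles is None:
            if m != 1:
                raise ValueError("point mode expects m == 1")
            columns[f"{tname}_pred_step{step}"] = preds[:, h, 0]
        else:
            if m != len(quantiles):
                raise ValueError("m must equal the number of quantiles")
            for j, q in enumerate(quantiles):
                # 0.1 -> "q10", 0.5 -> "q50", 0.9 -> "q90"
                columns[f"{tname}_q{round(q * 100)}_step{step}"] = preds[:, h, j]
    return columns


preds = np.random.rand(5, 4, 3)  # 5 samples, forecast_horizon=4, 3 quantiles
cols = wide_forecast_columns(preds, tname="subsidence", quantiles=[0.1, 0.5, 0.9])
print(len(cols))  # 12
```

Passing the resulting dictionary to pandas.DataFrame would yield one column per (quantile, step) pair.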

See also

forecast_single_step

Function for single-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to verify quantile ratios.

geoprior.models.utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#

Generate a single-step forecast using the XTFT model.

This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:

\[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{51}\]

where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).

  • spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data’s X_static.

  • q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.

  • apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.

Returns:

A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # and dummy spatial features ("longitude", "latitude"), as well as the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # For the static input, combine the spatial features "longitude" and
>>> # "latitude" with the static feature "stat1". If spatial_cols is
>>> # provided, forecast_single_step expects the first two columns of
>>> # X_static to hold the spatial coordinates.
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values  # shape: (50, 3)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1)
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> #   - X_static with shape (n_samples, static_input_dim)
>>> #   - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=3,         # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,        # "feat1" and "feat2"
...     future_input_dim=1,         # One future feature
...     forecast_horizon=1,         # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> from geoprior.models.losses import combined_quantile_loss
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",                # The time column name in the output
...     mode="quantile",              # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())

Notes

  • In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.

  • If spatial_cols is provided, its first two entries must correspond to the first and second columns of the original training data’s X_static.

  • The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.

  • Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.

  • The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
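A common definition of the coverage score is the fraction of observations that fall inside the interval spanned by the lowest and highest predicted quantiles (format_predictions documents using the first and last quantiles as bounds). The sketch below assumes that definition with inclusive bounds. interval_coverage is a hypothetical name, not geoprior's coverage_score, whose exact conventions are not given on this page.

```python
import numpy as np


def interval_coverage(y_true, lower, upper):
    """Fraction of y_true values with lower <= y <= upper (bounds inclusive)."""
    y_true, lower, upper = (np.asarray(a, dtype=float) for a in (y_true, lower, upper))
    inside = (y_true >= lower) & (y_true <= upper)
    return float(inside.mean())


# Single-step quantile predictions of shape (n, 1, 3) for q = [0.1, 0.5, 0.9].
preds = np.array([[[0.0, 1.0, 2.0]], [[2.5, 2.8, 3.0]], [[2.0, 3.0, 4.0]], [[5.0, 5.5, 6.0]]])
y = np.array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 1)
print(interval_coverage(y, preds[..., 0], preds[..., -1]))  # 0.5
```

Only the first and third observations lie inside their [q10, q90] intervals, hence a coverage of 0.5.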

See also

forecast_multi_step

Function for multi-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to validate quantile ratios.

geoprior.models.utils.format_predictions(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, _logger=None, **kwargs)[source]#

Formats model predictions into a structured pandas DataFrame.

This utility function takes raw model predictions (either directly as an array/tensor or generated by a provided model and its inputs) and transforms them into a long-format pandas DataFrame. It can handle point forecasts, quantile forecasts, single or multi-output predictions, and optionally include actual target values, spatial identifiers, and perform coverage score evaluation for quantile forecasts.

The output DataFrame is structured with ‘sample_idx’ (identifying the original input sequence) and ‘forecast_step’ (from 1 to H, where H is the forecast horizon).
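The long layout described above, one row per (sample_idx, forecast_step) pair, can be sketched for a point-forecast array of shape (num_samples, forecast_horizon, output_dim). to_long_format is a hypothetical helper; format_predictions also handles quantiles, actuals, spatial columns, and inverse scaling, none of which are shown.

```python
import numpy as np


def to_long_format(preds):
    """Flatten (num_samples, H, output_dim) predictions into long-format columns."""
    num_samples, horizon, output_dim = preds.shape
    return {
        # each sample index is repeated once per forecast step ...
        "sample_idx": np.repeat(np.arange(num_samples), horizon),
        # ... while the 1-based step counter cycles within each sample
        "forecast_step": np.tile(np.arange(1, horizon + 1), num_samples),
        # row-major reshape keeps (sample, step) pairs aligned with the columns above
        "values": preds.reshape(num_samples * horizon, output_dim),
    }


long_cols = to_long_format(np.arange(6, dtype=float).reshape(2, 3, 1))
print(long_cols["sample_idx"])     # [0 0 0 1 1 1]
print(long_cols["forecast_step"])  # [1 2 3 1 2 3]
```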

Parameters:
  • predictions (np.ndarray or tf.Tensor, optional) –

    The raw prediction tensor or array.

    • For point forecasts, expected shapes:
      • (num_samples, forecast_horizon, output_dim)

      • (num_samples, forecast_horizon) if output_dim=1 (will be reshaped)

      • (num_samples, output_dim) if forecast_horizon=1 (will be reshaped)

    • For quantile forecasts, expected shapes:
      • (num_samples, forecast_horizon, num_quantiles * output_dim)

      • (num_samples, forecast_horizon, num_quantiles, output_dim)

      • (num_samples, num_quantiles * output_dim) if forecast_horizon=1

    If None, model and inputs must be provided. Default is None.

  • model (tf.keras.Model, optional) – A trained Keras model to generate predictions if predictions is not provided. Used in conjunction with inputs. Default is None.

  • inputs (List[Optional[Union[np.ndarray, tf.Tensor]]], optional) – A list of input tensors (e.g., [static, dynamic, future]) required by the model to generate predictions. Required if predictions is None and model is provided. Default is None.

  • y_true_sequences (np.ndarray or tf.Tensor, optional) – The true target values corresponding to the predictions, used for including actuals in the output DataFrame and for evaluation. Expected shape: (num_samples, forecast_horizon, output_dim). Default is None.

  • target_name (str, optional) – Base name for the target variable. Used to prefix prediction and actual column names (e.g., “sales_pred”, “sales_q50”, “sales_actual”). Default is “target”.

  • quantiles (List[float], optional) – A list of quantiles that were predicted by the model (e.g., [0.1, 0.5, 0.9]). Required if the predictions are quantile forecasts. If provided, prediction columns will be named like {target_name}_q10, {target_name}_q50, etc. Default is None (for point forecasts).

  • forecast_horizon (int, optional) – The number of time steps into the future that the model predicts. If not provided, it’s inferred from predictions.shape[1] or y_true_sequences.shape[1]. Default is None.

  • output_dim (int, optional) – The number of target variables predicted at each time step (e.g., 1 for univariate, >1 for multivariate target). If not provided, it’s inferred from the shape of predictions or y_true_sequences. Default is None.

  • spatial_data_array (np.ndarray or tf.Tensor or pd.DataFrame or pd.Series, optional) –

    An array or DataFrame containing static spatial/identifier features for each of the num_samples sequences.

    • If NumPy/Tensor: Expected shape (num_samples, num_spatial_features). spatial_cols_indices must be provided.

    • If DataFrame/Series: Must have num_samples rows. spatial_cols must be provided.

    These features will be repeated for each forecast step in the output DataFrame. Default is None.

  • spatial_cols (List[str], optional) – List of column names to select from spatial_data_array if it’s a DataFrame/Series, or names to assign to columns if spatial_data_array is NumPy/Tensor and spatial_cols_indices are provided. Default is None.

  • spatial_cols_indices (List[int], optional) – List of column indices to select from spatial_data_array if it’s a NumPy/Tensor. Length must match spatial_cols if provided. Default is None.

  • evaluate_coverage (bool, default False) – If True, and at least two quantiles and y_true_sequences are both provided, computes the coverage score using the first and last quantiles as the interval bounds. Requires geoprior.metrics.coverage_score.

  • scaler (Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method) used to scale the target variable and potentially other features. If provided along with scaler_feature_names and target_idx_in_scaler, predictions and actuals for the target will be inverse-transformed. Default is None.

  • scaler_feature_names (List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required if scaler is provided and targeted inverse transformation is needed. Default is None.

  • target_idx_in_scaler (int, optional) – The index of the target_name within the scaler_feature_names list. Required if scaler is provided and targeted inverse transformation is needed. Default is None.

  • verbose (int, default 0) –

    Verbosity level for logging during processing.

    • 0: Silent.

    • 1: Basic info.

    • 3: More detailed steps.

    • 5: Very detailed shape information.

  • **kwargs (Any) – Additional keyword arguments (currently not used but included for future extensibility).

  • _logger (Logger | Callable[[str], None] | None)
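The targeted inverse transform controlled by scaler, scaler_feature_names, and target_idx_in_scaler can be sketched as follows. This is a minimal illustration, not geoprior's implementation. It assumes the scaler transforms each feature column independently (as scikit-learn's StandardScaler and MinMaxScaler do), so the non-target columns can hold placeholder zeros. _ToyScaler and inverse_transform_target are hypothetical names.

```python
import numpy as np


class _ToyScaler:
    """Hypothetical stand-in for a fitted, column-wise scikit-learn-like scaler."""

    def __init__(self, mean, scale):
        self.mean_ = np.asarray(mean, dtype=float)
        self.scale_ = np.asarray(scale, dtype=float)

    def inverse_transform(self, X):
        # Undo x_scaled = (x - mean) / scale, column by column.
        return np.asarray(X, dtype=float) * self.scale_ + self.mean_


def inverse_transform_target(values, scaler, scaler_feature_names,
                             target_idx_in_scaler):
    """Inverse-transform only the target column (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    flat = values.reshape(-1)
    # The scaler was fit on all features, so it expects one column per name
    # in scaler_feature_names; only the target column carries real data.
    placeholder = np.zeros((flat.size, len(scaler_feature_names)))
    placeholder[:, target_idx_in_scaler] = flat
    restored = scaler.inverse_transform(placeholder)[:, target_idx_in_scaler]
    # Restore the original (e.g. samples x horizon) shape.
    return restored.reshape(values.shape)


scaler = _ToyScaler(mean=[0.0, 10.0], scale=[1.0, 2.0])
scaled_preds = np.array([[0.0, 1.0], [-1.0, 0.5]])
# Target "price" sits at index 1, so each value becomes value * 2 + 10.
print(inverse_transform_target(scaled_preds, scaler, ["rainfall", "price"], 1))
```

The placeholder trick only works for column-independent scalers; a scaler that mixes features (for example PCA whitening) would need the real values of every column.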

Returns:

A long-format DataFrame containing sample_idx and forecast_step, optional spatial columns, prediction columns, and actual-value columns when y_true_sequences is provided. Point forecasts use names like {target_name}_pred or {target_name}_{output_idx}_pred. Quantile forecasts use names like {target_name}_qXX or {target_name}_{output_idx}_qXX. Actual values use {target_name}_actual or {target_name}_{output_idx}_actual. Prediction and actual values are inverse-transformed when valid scaler information is provided.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If predictions is None and model or inputs is also None. If predictions shape is invalid (not 2D, 3D, or 4D). If quantiles are provided but prediction shape is incompatible for inferring output_dim. If spatial_data_array is provided without necessary name/index parameters.

  • TypeError – If predictions or other inputs cannot be converted to the expected tensor/array types.

See also

geoprior.models.utils.forecast_multi_step

Higher-level forecasting utility.

geoprior.metrics.coverage_score

For evaluating quantile forecast intervals.
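The coverage rule used by evaluate_coverage (first and last quantiles as interval bounds) can be sketched as the fraction of observations that fall inside the interval. interval_coverage is a hypothetical helper; geoprior's coverage_score may treat NaNs, multi-output targets, or open versus closed bounds differently.

```python
import numpy as np


def interval_coverage(y_true, q_lower, q_upper):
    """Fraction of observations inside [q_lower, q_upper] (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    q_lower = np.asarray(q_lower, dtype=float)
    q_upper = np.asarray(q_upper, dtype=float)
    if not (y_true.shape == q_lower.shape == q_upper.shape):
        raise ValueError("y_true, q_lower and q_upper must share one shape")
    # An observation counts as covered when it lies within the closed interval.
    inside = (y_true >= q_lower) & (y_true <= q_upper)
    return float(inside.mean())


# Quantile predictions shaped (num_samples, forecast_horizon, num_quantiles);
# the first and last quantiles act as interval bounds.
preds = np.array([[[0.0, 1.0, 2.0], [2.5, 2.8, 3.0]],
                  [[2.0, 3.0, 4.0], [5.0, 5.5, 6.0]]])
y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
print(interval_coverage(y_true, preds[..., 0], preds[..., -1]))  # 0.5
```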

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import format_predictions_to_dataframe
>>> B, H, O = 4, 3, 1 # Batch, Horizon, OutputDim
>>> Q = [0.1, 0.5, 0.9]
>>> preds_point = tf.random.normal((B, H, O))
>>> preds_quant = tf.random.normal((B, H, len(Q))) # For O=1
>>> y_true = tf.random.normal((B, H, O))
>>> # Point forecast
>>> df_point = format_predictions_to_dataframe(
...     predictions=preds_point, y_true_sequences=y_true,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> print(df_point.head(H)) # Show first sample's horizon
   sample_idx  forecast_step  value_pred  value_actual
0           0              1   -0.576731     -0.647362
1           0              2    0.183931      1.198977
2           0              3   -0.766871      0.534040
>>> # Quantile forecast
>>> df_quant = format_predictions_to_dataframe(
...     predictions=preds_quant, y_true_sequences=y_true,
...     target_name="value", quantiles=Q,
...     forecast_horizon=H, output_dim=O
... )
>>> print(df_quant.head(H))
   sample_idx  forecast_step  value_q10  value_q50  value_q90  value_actual
0           0              1  -0.209947   0.263107  -0.308929     -0.647362
1           0              2   0.303091   0.594701  -0.225007      1.198977
2           0              3   0.136699  -1.237739   0.002834      0.534040
>>> # With spatial data (NumPy array)
>>> spatial_np = np.array([[101, 201], [102, 202], [103, 203], [104, 204]])
>>> df_spatial = format_predictions_to_dataframe(
...     predictions=preds_point,
...     spatial_data_array=spatial_np,
...     spatial_cols=['store_id', 'region_id'],
...     spatial_cols_indices=[0, 1]
... )
>>> print(df_spatial[['sample_idx', 'forecast_step', 'store_id']].head(H))
   sample_idx  forecast_step  store_id
0           0              1     101.0
1           0              2     101.0
2           0              3     101.0
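The long-format layout described under Returns can be sketched with NumPy and pandas for the single-output quantile case. quantiles_to_long is a hypothetical helper, not geoprior's code: it assumes a (num_samples, forecast_horizon, num_quantiles) array and the {target_name}_qXX naming shown in the examples, and it omits actual-value columns, multi-output naming, and inverse scaling.

```python
import numpy as np
import pandas as pd


def quantiles_to_long(preds, quantiles, target_name="target",
                      spatial=None, spatial_cols=None):
    """Flatten (num_samples, horizon, num_quantiles) predictions (hypothetical)."""
    preds = np.asarray(preds, dtype=float)
    if preds.ndim != 3 or preds.shape[-1] != len(quantiles):
        raise ValueError("expected shape (num_samples, horizon, num_quantiles)")
    n_samples, horizon, _ = preds.shape
    # One row per (sample, step): sample_idx repeats, forecast_step cycles 1..H.
    data = {
        "sample_idx": np.repeat(np.arange(n_samples), horizon),
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    }
    if spatial is not None:
        if spatial_cols is None:
            raise ValueError("spatial_cols is required with spatial data")
        spatial = np.asarray(spatial)
        # Static per-sample features are repeated once per forecast step.
        for j, name in enumerate(spatial_cols):
            data[name] = np.repeat(spatial[:, j], horizon)
    for q_idx, q in enumerate(quantiles):
        # 0.1 -> "q10", 0.5 -> "q50", 0.9 -> "q90"
        col = f"{target_name}_q{int(round(q * 100))}"
        # Row-major flattening matches the (sample, step) row order above.
        data[col] = preds[:, :, q_idx].reshape(-1)
    return pd.DataFrame(data)


preds = np.arange(12, dtype=float).reshape(2, 2, 3)
df = quantiles_to_long(preds, [0.1, 0.5, 0.9], target_name="value",
                       spatial=np.array([[101], [102]]), spatial_cols=["store_id"])
print(df)
```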
geoprior.models.utils.format_predictions_to_dataframe(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, **kwargs)

Deprecated alias for format_predictions. See format_predictions for the updated, recommended API. All original parameters are forwarded to format_predictions.

Returns:

The formatted prediction DataFrame from format_predictions.

Return type:

pd.DataFrame

Parameters:
  • predictions (ndarray | Any | None)

  • model (Model | None)

  • inputs (list[ndarray | Any | None] | None)

  • y_true_sequences (ndarray | Any | None)

  • target_name (str | None)

  • quantiles (list[float] | None)

  • forecast_horizon (int | None)

  • output_dim (int | None)

  • spatial_data_array (ndarray | Any | None)

  • spatial_cols (list[str] | None)

  • spatial_cols_indices (list[int] | None)

  • evaluate_coverage (bool)

  • scaler (Any | None)

  • scaler_feature_names (list[str] | None)

  • target_idx_in_scaler (int | None)

  • verbose (int)

  • kwargs (Any)
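A thin deprecated alias that forwards every argument to its replacement is a common Python pattern; the sketch below shows one way to write it. It is illustrative only: format_predictions here is a hypothetical stub standing in for the real function, and whether geoprior's alias actually emits a DeprecationWarning is an assumption, not something this page states.

```python
import functools
import warnings


def format_predictions(predictions=None, **kwargs):
    """Hypothetical stub standing in for the real format_predictions."""
    return {"predictions": predictions, **kwargs}


def deprecated_alias(new_func, old_name):
    """Return a wrapper that warns, then forwards all arguments to new_func."""
    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not the wrapper
        )
        return new_func(*args, **kwargs)

    wrapper.__name__ = old_name
    return wrapper


format_predictions_to_dataframe = deprecated_alias(
    format_predictions, "format_predictions_to_dataframe"
)
```

Keeping the alias as a pure forwarder means the two names can never drift apart in behaviour.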

geoprior.models.utils.generate_forecast(xtft_model, train_data, dt_col, dynamic_features, future_features=None, static_features=None, test_data=None, mode='quantile', spatial_cols=None, forecast_horizon=4, time_steps=3, q=None, tname=None, forecast_dt=None, savefile=None, verbose=3, **kw)

Generate forecasts using a pre-trained XTFT model.

This function uses a pre-trained Keras model to forecast future values from the provided historical data. The model receives three inputs, X_static, X_dynamic, and X_future, rebuilt from train_data, and outputs predictions over the specified forecast horizon.

See the User Guide for more details.

Parameters:
  • xtft_model (object) – A validated Keras model instance. It is checked by the validate_keras_model function.

  • train_data (pandas.DataFrame) – The training data containing historical records. Must include the dt_col and all required feature columns.

  • dt_col (str) – Name of the column representing time. It may be a datetime or numeric column (e.g. "year").

  • dynamic_features (list of str) – List of dynamic feature column names. They are formatted via columns_manager.

  • future_features (list of str, optional) – List of future feature names. These columns are tiled over the forecast horizon.

  • static_features (list of str, optional) – List of static feature names. If not provided, a dummy input is used.

  • test_data (pandas.DataFrame, optional) – DataFrame containing actual values used for evaluation. If provided, it is used to compute the R² score and, when mode='quantile', the coverage score.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions for multiple quantiles (default: [0.1, 0.5, 0.9]) are computed.

  • spatial_cols (list of str, optional) – List of spatial column names for grouping the data. When provided, forecasts are computed per location; otherwise, a global forecast is performed.

  • forecast_horizon (int, optional) – Number of future periods to forecast. Default is 4.

  • time_steps (int, optional) – Number of past time steps to use as input for the model. Default is 3.

  • q (list of float, optional) – List of quantiles for use in quantile mode. Default is [0.1, 0.5, 0.9]. Each quantile is validated by the assert_ratio function.

  • tname (str, optional) – Target variable name used for constructing forecast result columns. Defaults to "target".

  • forecast_dt (list or str, optional) – List of forecast dates or "auto" to derive dates from dt_col. In auto mode, if dt_col is datetime, frequency is inferred using pd.infer_freq.

  • savefile (str, optional) – Path to the CSV file where forecast results will be saved. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level (0-7). Controls the amount of execution output.
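The forecast_dt="auto" behaviour can be approximated as below. The datetime branch follows the parameter description (frequency inferred with pd.infer_freq). The numeric branch is an assumption: it advances a column such as "year" by the spacing between its last two unique values. future_dates_auto is a hypothetical name, not geoprior's code.

```python
import numpy as np
import pandas as pd


def future_dates_auto(dt_values, forecast_horizon):
    """Derive forecast_horizon future time labels from observed dt_col values."""
    series = pd.Series(dt_values)
    if pd.api.types.is_datetime64_any_dtype(series):
        index = pd.DatetimeIndex(series.drop_duplicates().sort_values())
        freq = pd.infer_freq(index)  # needs at least 3 regularly spaced dates
        if freq is None:
            raise ValueError("could not infer a frequency from dt_col")
        # Start at the last observed date, then drop it to keep only new dates.
        future = pd.date_range(start=index[-1], periods=forecast_horizon + 1,
                               freq=freq)
        return list(future[1:])
    # Numeric time column (e.g. "year"): assume a constant step.
    uniq = np.sort(series.unique())
    step = uniq[-1] - uniq[-2] if len(uniq) > 1 else 1
    return (uniq[-1] + step * np.arange(1, forecast_horizon + 1)).tolist()


dates = pd.date_range("2020-01-01", periods=5, freq="D")
print(future_dates_auto(dates, 3))
print(future_dates_auto([2019, 2020, 2021], 2))
```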

Returns:

A DataFrame containing the forecast results. In quantile mode, each forecast period includes columns for each quantile; in point mode, a single prediction column is provided.

Return type:

pandas.DataFrame

Examples

  1. Example using training data only

>>> import os
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.losses import combined_quantile_loss
>>> from geoprior.models.utils import generate_forecast
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # dynamic features "feat1", "feat2", static feature "stat1",
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> # Note: The model expects the following input shapes:
>>> # - X_static: (n_samples, static_input_dim)
>>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future:  (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # one known-future feature
...     forecast_horizon=5,           # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Create dummy input arrays for model fitting.
>>> # For simplicity, assume time_steps = 3 and use random data.
>>> X_static = train_df[["stat1"]].values      # shape: (50, 1)
>>> # Create a dummy dynamic input array of shape (50, 3, 2)
>>> X_dynamic = np.random.rand(50, 3, 2)
>>> # Create a dummy future-feature array of shape (50, 3, 1)
>>> X_future = np.random.rand(50, 3, 1)
>>> # Create dummy target output from "price"
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
  2. Example with test data included

>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df  = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values      # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future  = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array   = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # For the provided future feature
...     forecast_horizon=5,           # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> # This example uses test_df for evaluation, which will compute
>>> # metrics like R² Score and Coverage Score.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     test_data=test_df.iloc[:5, :],  # first 5 test rows, one per forecast step
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
  3. Example of point forecasting

>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Create dummy input arrays for model fitting.
>>> # X_static is derived from the static feature "stat1".
>>> X_static = train_df[["stat1"]].values      # shape: (50, 1)
>>>
>>> # X_dynamic is a dummy dynamic array for "feat1" and "feat2".
>>> # For time_steps = 3, its shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # X_future is a dummy array for future features.
>>> # Here, we assume a single future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # Provided future feature
...     forecast_horizon=5,           # Forecast 5 periods ahead
...     quantiles=None,               # no quantiles in point mode
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function in point mode.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="point",
...     verbose=3
... )
>>> print(forecast.head())

Notes

The function groups data by spatial_cols if provided, and formats features via columns_manager. It validates the time column using check_datetime and uses dummy inputs for missing static or future features. The forecast is produced by invoking xtft_model.predict on a list containing static, dynamic, and future inputs. The predictions are generated as follows:

\[\hat{y}_{t+i} = f\Bigl(X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}}\Bigr), \qquad i = 1, \dots, H \tag{52}\]

where \(f\) is the trained XTFT model, \(i\) indexes the forecast period, and \(H\) is forecast_horizon.
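Before predict can be called, each location's most recent time_steps rows must be stacked into an X_dynamic array of shape (n_locations, time_steps, n_dynamic_features). The helper below is a hypothetical sketch of that step, assuming one row per location and time. geoprior's own reshaping lives in reshape_xtft_data and may differ, for example in how short histories are handled.

```python
import numpy as np
import pandas as pd


def last_dynamic_windows(df, dt_col, dynamic_features, time_steps,
                         spatial_cols=None):
    """Stack each group's last `time_steps` rows (hypothetical helper)."""
    groups = df.groupby(spatial_cols, sort=True) if spatial_cols else [(None, df)]
    windows, keys = [], []
    for key, group in groups:
        group = group.sort_values(dt_col)
        if len(group) < time_steps:
            continue  # not enough history for a full look-back window
        values = group[dynamic_features].to_numpy(dtype=np.float32)
        windows.append(values[-time_steps:])
        keys.append(key)
    if not windows:
        raise ValueError("no group has at least `time_steps` rows")
    return np.stack(windows), keys


df = pd.DataFrame({
    "site": ["A"] * 4 + ["B"] * 4,
    "date": list(pd.date_range("2020-01-01", periods=4, freq="D")) * 2,
    "feat1": np.arange(8, dtype=float),
    "feat2": np.arange(8, dtype=float) * 10,
})
X_dynamic, sites = last_dynamic_windows(df, "date", ["feat1", "feat2"],
                                        time_steps=3, spatial_cols=["site"])
print(X_dynamic.shape)  # (2, 3, 2)
```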

See also

geoprior.models.utils.reshape_xtft_data

Function to reshape data for XTFT models.

geoprior.utils.validator.validate_keras_model

Function to validate Keras model compatibility.

geoprior.core.handlers.columns_manager

Utility to manage and format column names.

geoprior.core.checks.check_datetime

Function to check and validate datetime columns.

geoprior.core.checks.check_spatial_columns

Function to validate spatial columns in data.

geoprior.core.checks.assert_ratio

Function to validate and assert ratio values.

geoprior.metrics_special.coverage_score

Function to compute coverage score for quantile predictions.

geoprior.models.utils.generate_forecast_with(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kw)

Generate forecasts using a pre-trained XTFT model based on the forecast horizon.

There are two approaches to generating forecasts with an XTFT model:

  1. A monolithic function (e.g. generate_forecast) that handles both single-step and multi-step forecasts within a single implementation. This approach results in a single, large function that internally branches its logic based on the value of the forecast horizon.

  2. A modular design where the single-step and multi-step forecasting functionalities are separated into two distinct functions (e.g. forecast_single_step and forecast_multi_step), with a thin wrapper that dispatches to the appropriate function based on the forecast horizon.

The modular approach (2) is generally preferred because it separates concerns and improves code readability, maintainability, and unit testing. Each function is responsible for a well-defined task: one for single-step forecasts and one for multi-step forecasts, while the wrapper only selects the correct function based on the forecast horizon. This approach suits applications that need to handle both short- and long-term forecasts, as it keeps the codebase modular and easier to debug.

generate_forecast_with follows this design: it calls forecast_single_step when forecast_horizon equals 1 and forecast_multi_step when forecast_horizon is greater than 1.
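The dispatch rule can be sketched in a few lines. This is a hypothetical illustration, not geoprior's wrapper: the single-step and multi-step forecasters are passed in as callables so the routing logic can be exercised without a trained model, and the stub signatures are assumptions rather than the real forecast_single_step and forecast_multi_step signatures.

```python
from typing import Any, Callable, Sequence


def dispatch_forecast(
    inputs: Sequence[Any],
    forecast_horizon: int,
    single_step: Callable[..., Any],
    multi_step: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """Route to a single-step or multi-step forecaster (hypothetical sketch)."""
    if not isinstance(forecast_horizon, int) or forecast_horizon < 1:
        raise ValueError("forecast_horizon must be a positive integer")
    if forecast_horizon == 1:
        return single_step(inputs, **kwargs)
    return multi_step(inputs, forecast_horizon=forecast_horizon, **kwargs)


# Stub forecasters standing in for the real functions.
def fake_single(inputs, **kw):
    return ("single", len(inputs))


def fake_multi(inputs, forecast_horizon, **kw):
    return ("multi", forecast_horizon)


print(dispatch_forecast([1, 2, 3], 1, fake_single, fake_multi))  # ('single', 3)
print(dispatch_forecast([1, 2, 3], 4, fake_single, fake_multi))  # ('multi', 4)
```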

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.

  • forecast_horizon (int) – The number of future time steps to forecast. A value of 1 triggers a single-step forecast; values greater than 1 trigger a multi-step forecast.

  • y (numpy.ndarray, optional) – Actual target values for evaluation. If provided, evaluation metrics (e.g., R² Score, and in quantile mode, the coverage score) are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, the output DataFrame includes a column with these values.

  • mode (str, optional) – Forecast mode, either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.

  • time_steps (int, optional) – The number of historical time steps used as input.

  • q (list of float, optional) – List of quantile values for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name used to construct output column names (e.g., "subsidence"). Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.

  • apply_mask (bool, optional) – If True, applies masking (via mask_by_reference) to adjust predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output.

  • **kw (dict, optional) – Currently unused; reserved for future extension.
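The effect of apply_mask can be pictured with the sketch below. The exact semantics of mask_by_reference are not documented on this page, so the rule here is an assumption: rows whose reference value matches any of mask_values have every forecast step replaced by mask_fill_value. mask_predictions is a hypothetical name.

```python
import numpy as np


def mask_predictions(preds, reference, mask_values, mask_fill_value):
    """Overwrite predictions for rows whose reference matches mask_values."""
    preds = np.array(preds, dtype=float, copy=True)  # leave the input untouched
    reference = np.asarray(reference)
    # Boolean row mask; np.atleast_1d lets mask_values be a scalar or a list.
    rows = np.isin(reference, np.atleast_1d(mask_values))
    preds[rows] = mask_fill_value  # broadcasts across all forecast steps
    return preds


preds = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (locations, horizon)
reference = np.array([0.0, 1.2, 0.0])  # e.g. observed subsidence per location
print(mask_predictions(preds, reference, mask_values=0.0, mask_fill_value=0.0))
```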

Returns:

A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile and forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.

Return type:

pandas.DataFrame
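The wide layout described under Returns, with spatial coordinates taken from the first two columns of X_static, can be sketched as below. assemble_wide_forecast is a hypothetical helper; the step-major column order and the qXX naming rule are assumptions based on the examples given here (<tname>_q10_step1, <tname>_pred_step1), and the actual-value column is omitted.

```python
import numpy as np
import pandas as pd


def assemble_wide_forecast(preds, tname="target", quantiles=None,
                           X_static=None, spatial_cols=None):
    """Build a wide forecast frame with one column per step (hypothetical)."""
    preds = np.asarray(preds, dtype=float)
    if preds.ndim == 2:
        preds = preds[:, :, None]  # point forecast (N, H) -> (N, H, 1)
    _, horizon, n_out = preds.shape
    if quantiles and n_out != len(quantiles):
        raise ValueError("last axis must match the number of quantiles")
    frame = {}
    if spatial_cols:
        if X_static is None or len(spatial_cols) < 2:
            raise ValueError("spatial_cols needs two names and X_static")
        coords = np.asarray(X_static)[:, :2]  # first two static columns
        frame[spatial_cols[0]] = coords[:, 0]
        frame[spatial_cols[1]] = coords[:, 1]
    for step in range(1, horizon + 1):
        if quantiles:
            for q_idx, q in enumerate(quantiles):
                col = f"{tname}_q{int(round(q * 100))}_step{step}"
                frame[col] = preds[:, step - 1, q_idx]
        else:
            frame[f"{tname}_pred_step{step}"] = preds[:, step - 1, 0]
    return pd.DataFrame(frame)


preds = np.arange(8, dtype=float).reshape(2, 2, 2)  # (N, H, Q)
X_static = np.array([[113.2, 22.5, 0.7], [113.3, 22.6, 0.4]])
df = assemble_wide_forecast(preds, "subsidence", [0.1, 0.9], X_static,
                            ["longitude", "latitude"])
print(list(df.columns))
```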

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import generate_forecast_with
>>> import numpy as np
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> my_model = XTFT(
...     static_input_dim=10,
...     dynamic_input_dim=5,
...     future_input_dim=3,
...     forecast_horizon=1,          # This parameter will be updated in the
...                                  # wrapper function based on forecast_horizon.
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=32,
...     max_window_size=3,
...     memory_size=100,
...     num_heads=4,
...     dropout_rate=0.1,
...     lstm_units=64,
...     attention_units=64,
...     hidden_units=32
... )
>>> my_model.compile(optimizer='adam')
>>>
>>> # Create dummy input data.
>>> X_static = np.random.rand(100, 10)
>>> X_dynamic = np.random.rand(100, 3, 5)
>>> X_future  = np.random.rand(100, 3, 3)
>>> y_array   = np.random.rand(100, 1, 1)  # For single-step target output.
>>> inputs    = [X_static, X_dynamic, X_future]
>>>
>>> # Fit the model with dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=32
... )
>>>
>>> # Example for a single-step forecast:
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=1,
...     y=y_array,
...     dt_col="year",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     verbose=3
... )
>>> print(forecast_df.head())
>>>
>>> # Example for a multi-step forecast:
>>> forecast_dates = ["2023", "2024", "2025", "2026"]
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=4,
...     y=y_array,
...     dt_col="year",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     verbose=3
... )
>>> print(forecast_df.head())

See also

forecast_single_step

Generates a single-step forecast.

forecast_multi_step

Generates a multi-step forecast.

validate_keras_model

Validates a Keras model.

coverage_score

Computes the coverage score.

geoprior.models.utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)

Prepares a list of input tensors for a model’s call method.

This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.

Parameters:
  • dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).

  • static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default is None.

  • future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.

  • model_type ({'strict', 'flexible'}, default 'strict') –

    Determines how None inputs for static and future features are handled:

    • 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.

    • 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.

  • forecast_horizon (int, optional) – The forecast horizon. Used only if model_type=’strict’ and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default is None.

  • verbose (int, default 0) –

    Verbosity level. If > 0, prints information about dummy tensor creation.

    • 0: Silent.

    • 1: Basic info on dummy creation.

    • 2: More details on shapes.

Returns:

A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type=’flexible’ and original input was None). All returned tensors are cast to tf.float32.

Return type:

List[Optional[tf.Tensor]]

Raises:
  • ValueError – If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.

  • TypeError – If inputs cannot be converted to TensorFlow tensors.

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
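The 'strict' versus 'flexible' rules described above can be reproduced with plain NumPy as a sketch. This is not geoprior's implementation: it returns float32 NumPy arrays instead of tf.Tensor objects, and prepare_inputs_sketch is a hypothetical name. It shows how a zero-feature dummy keeps the three-element layout while the dummy future tensor spans past_time_steps + forecast_horizon steps.

```python
import numpy as np


def prepare_inputs_sketch(dynamic_input, static_input=None, future_input=None,
                          model_type="strict", forecast_horizon=None):
    """Return [static, dynamic, future] following the documented rules."""
    if dynamic_input is None:
        raise ValueError("dynamic_input is required")
    dynamic = np.asarray(dynamic_input, dtype=np.float32)
    if dynamic.ndim < 2:
        raise ValueError("dynamic_input needs a batch dimension")
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]

    if static_input is not None:
        static = np.asarray(static_input, dtype=np.float32)
        if static.ndim != 2:
            raise ValueError("static_input must be 2D")
    elif model_type == "strict":
        static = np.zeros((batch, 0), dtype=np.float32)  # zero static features
    else:
        static = None

    if future_input is not None:
        future = np.asarray(future_input, dtype=np.float32)
        if future.ndim != 3:
            raise ValueError("future_input must be 3D")
    elif model_type == "strict":
        # Dummy time span: past steps, plus the horizon when it is known.
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)
    else:
        future = None

    return [static, dynamic, future]


dyn = np.random.rand(2, 10, 4)
s, d, f = prepare_inputs_sketch(dyn, forecast_horizon=3)
print(s.shape, d.shape, f.shape)  # (2, 0) (2, 10, 4) (2, 13, 0)
```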
geoprior.models.utils.prepare_model_inputs_in(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0)
Prepares a list of input tensors for a model’s call method in graph-compatible mode.

This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.

Parameters:
dynamic_input : np.ndarray or tf.Tensor

The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).

static_input : np.ndarray or tf.Tensor, optional

The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default is None.

future_input : np.ndarray or tf.Tensor, optional

The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.

model_type : {‘strict’, ‘flexible’}, default ‘strict’

Determines how None inputs for static and future features are handled:

  • 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.

  • 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.

forecast_horizon : int, optional

The forecast horizon. Used only if model_type=’strict’ and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default is None.

verbose : int, default 0

Verbosity level. If > 0, prints information about dummy tensor creation.

  • 0: Silent.

  • 1: Basic info on dummy creation.

  • 2: More details on shapes.

Returns:
List[Optional[tf.Tensor]]

A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type=’flexible’ and original input was None). All returned tensors are cast to tf.float32.

Raises:
ValueError

If dynamic_input is None or is not at least 2D (a batch dimension is required), if static_input (when provided) is not 2D, or if future_input (when provided) is not 3D.

TypeError

If inputs cannot be converted to TensorFlow tensors.

Return type:

list[Any | None]

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs_in
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs_in(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
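The placeholder rules above can be sketched in plain NumPy. This is an illustration under assumptions, not the geoprior implementation: `sketch_prepare_inputs` is a hypothetical name, it uses NumPy arrays in place of tf.Tensor, and it requires a 3D dynamic input so that the time axis is unambiguous.

```python
import numpy as np


def sketch_prepare_inputs(dynamic, static=None, future=None,
                          model_type="strict", forecast_horizon=None):
    """Hypothetical NumPy sketch of the [static, dynamic, future] rules."""
    if dynamic is None:
        raise ValueError("dynamic_input is required.")
    dynamic = np.asarray(dynamic, dtype=np.float32)
    if dynamic.ndim != 3:  # simplification: assume (batch, time, features)
        raise ValueError("dynamic_input must be 3D in this sketch.")
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]

    if static is not None:
        static = np.asarray(static, dtype=np.float32)
        if static.ndim != 2:
            raise ValueError("static_input must be 2D.")
    elif model_type == "strict":
        # Zero-feature placeholder with shape (batch, 0).
        static = np.zeros((batch, 0), dtype=np.float32)

    if future is not None:
        future = np.asarray(future, dtype=np.float32)
        if future.ndim != 3:
            raise ValueError("future_input must be 3D.")
    elif model_type == "strict":
        # Time span: past steps, plus the horizon when one is given.
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)

    # In 'flexible' mode, missing inputs simply stay None.
    return [static, dynamic, future]


dyn = np.random.normal(size=(2, 10, 4))
s, d, f = sketch_prepare_inputs(dyn, forecast_horizon=3)
print(s.shape, d.shape, f.shape)  # (2, 0) (2, 10, 4) (2, 13, 0)
```

The zero-width feature axis lets a strict model keep a fixed three-input signature without carrying any real static or future features.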
geoprior.models.utils.prepare_spatial_future_data(final_processed_data, feature_columns, dynamic_feature_indices, sequence_length=1, dt_col='date', static_feature_names=None, forecast_horizon=None, future_years=None, encoded_cat_columns=None, scaling_params=None, spatial_cols=None, squeeze_last=False, verbosity=0)

Prepare future static and dynamic inputs for making predictions.

This function prepares the necessary static and dynamic inputs required for forecasting future values in time series data. It processes the provided dataset by grouping it by location_id, extracting the last sequence of data points based on the specified sequence_length, and generating future inputs for prediction over the defined forecast_horizon.

The function handles both integer and datetime representations of the dt_col, extracting the year from datetime columns when necessary. It also allows for flexibility in specifying static features and encoded categorical variables.

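Future time values are standardized with the mean \(\mu\) and standard deviation \(\sigma\) of the time column, taken from scaling_params when provided and otherwise computed from dt_col:

\[\text{scaled\_time} = \frac{\text{future\_time} - \mu}{\sigma}\]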
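As a worked instance of the time-scaling formula, the sketch below standardizes a future year in both ways the parameters allow. It assumes the population standard deviation when \(\mu\) and \(\sigma\) are computed from the data; whether geoprior uses the population or the sample estimate is not stated on this page.

```python
import statistics

observed_years = [2018, 2019, 2020]
future_years = [2021]

# Case 1: no scaling_params, so mu and sigma come from the time column.
# Assumption: population standard deviation (statistics.pstdev).
mu = statistics.mean(observed_years)        # 2019
sigma = statistics.pstdev(observed_years)   # sqrt(2/3), about 0.8165
scaled_auto = [(t - mu) / sigma for t in future_years]
print(scaled_auto)  # about [2.449], i.e. sqrt(6)

# Case 2: explicit scaling_params, shaped as in the parameter description.
scaling_params = {"year": {"mean": 2000, "std": 10}}
p = scaling_params["year"]
scaled_given = [(t - p["mean"]) / p["std"] for t in future_years]
print(scaled_given)  # [2.1]
```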
Parameters:
  • final_processed_data (pandas.DataFrame) – The processed DataFrame containing all features and targets. Must include the location_id column and the specified dt_col.

  • feature_columns (List[str]) – List of feature column names to be used for dynamic input preparation.

  • dynamic_feature_indices (List[int]) – Indices of dynamic features in feature_columns. These features are considered time-dependent and are used to prepare dynamic inputs.

  • sequence_length (int, optional) – The number of past time steps to include in each input sequence. Default is 1.

  • dt_col (str, optional) – The name of the time-related column in final_processed_data. Defaults to 'date'.

  • static_feature_names (List[str], optional) – List of static feature column names. If not provided, defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.

  • forecast_horizon (int, optional) – The number of future time steps to predict. If set to None, the function defaults to predicting the next immediate time step.

  • future_years (List[int], optional) – List of future years to predict. When forecast_horizon is provided, the length of this list must equal forecast_horizon.

  • encoded_cat_columns (List[str], optional) – List of encoded categorical column names to be treated as static features.

  • scaling_params (Dict[str, Dict[str, float]], optional) – Dictionary containing scaling parameters (mean and standard deviation) for features. Example: {'year': {'mean': 2000, 'std': 10}}. If not provided, the function computes the mean and std for the dt_col.

  • squeeze_last (bool, default False) – If True, squeeze the last axis (the output dimension of y) when its size is 1.

  • verbosity (int, optional) – Verbosity level from 0 to 7 for debugging. Higher values produce more detailed logs. Default is 0.

  • spatial_cols (tuple[str, str])

Returns:

A tuple containing:

  • future_static_inputs : numpy.ndarray

    Array of future static inputs with shape (num_samples, num_static_vars, 1).

  • future_dynamic_inputs : numpy.ndarray

    Array of future dynamic inputs with shape (num_samples, sequence_length, num_dynamic_vars, 1).

  • future_years_list : List[int]

    List of future time values corresponding to each sample.

  • location_ids_list : List[int]

    List of location IDs corresponding to each sample.

  • longitudes : List[float]

    List of longitude values corresponding to each sample.

  • latitudes : List[float]

    List of latitude values corresponding to each sample.

Return type:

Tuple[np.ndarray, np.ndarray, List[int], List[int], List[float], List[float]]

Examples

>>> from geoprior.models.utils import prepare_spatial_future_data
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'location_id': [1, 1, 1, 2, 2, 2],
...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
...     'longitude': [10.0, 10.0, 10.0, 20.0, 20.0, 20.0],
...     'latitude': [50.0, 50.0, 50.0, 60.0, 60.0, 60.0],
...     'temperature': [15, 16, 15.5, 20, 21, 20.5],
...     'rainfall': [100, 110, 105, 200, 210, 205],
...     'encoded_cat': [1, 1, 1, 2, 2, 2]
... })
>>> feature_cols = ['year', 'temperature', 'rainfall', 'encoded_cat']
>>> dynamic_indices = [0, 1, 2]
>>> (future_static, future_dynamic, future_years, loc_ids, longs,
...  lats) = prepare_spatial_future_data(
...     final_processed_data=data,
...     feature_columns=feature_cols,
...     dynamic_feature_indices=dynamic_indices,
...     sequence_length=2,
...     forecast_horizon=1,
...     future_years=[2021],
...     encoded_cat_columns=['encoded_cat'],
...     verbosity=5,
...     dt_col='year'
... )
>>> print(future_static.shape)
(2, 3, 1)
>>> print(future_dynamic.shape)
(2, 2, 3, 1)

Notes

  • The function handles both integer and datetime representations of the dt_col. If dt_col is a datetime type, the year is extracted for scaling purposes.

  • If forecast_horizon is set to None, the function defaults to generating data for the next immediate time step based on the last entry in the time column.

  • Ensure that the length of future_years matches forecast_horizon if forecast_horizon is provided.

  • The static_feature_names parameter allows for flexibility in specifying which static features to include. If not provided, it defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
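The grouping step from the description (sort each location_id by time and keep its last sequence_length rows) can be sketched with the standard library. `last_windows` and the list-of-dicts layout are hypothetical stand-ins for the pandas DataFrame the real function uses, and dropping locations with too few rows is an assumption of this sketch.

```python
from collections import defaultdict


def last_windows(rows, sequence_length, dt_col="year",
                 group_col="location_id"):
    """Hypothetical sketch: last `sequence_length` rows per location."""
    groups = defaultdict(list)
    for row in rows:
        groups[row[group_col]].append(row)

    windows = {}
    for loc_id, loc_rows in groups.items():
        loc_rows.sort(key=lambda r: r[dt_col])  # chronological order
        # Assumption: locations with too short a history are skipped.
        if len(loc_rows) >= sequence_length:
            windows[loc_id] = loc_rows[-sequence_length:]
    return windows


rows = [
    {"location_id": 1, "year": 2018, "temperature": 15.0},
    {"location_id": 1, "year": 2020, "temperature": 15.5},
    {"location_id": 1, "year": 2019, "temperature": 16.0},
    {"location_id": 2, "year": 2019, "temperature": 21.0},
    {"location_id": 2, "year": 2020, "temperature": 20.5},
]
windows = last_windows(rows, sequence_length=2)
print([r["year"] for r in windows[1]])  # [2019, 2020]
```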

See also

prepare_future_data

Main function for preparing future data inputs.

geoprior.models.utils.set_default_params(quantiles=None, scales=None, multi_scale_agg=None)

Set and validate default values for quantiles and scales, and derive the return_sequences flag from multi_scale_agg.

Parameters:
  • quantiles (str, list of float, or None, optional) –

    Specifies the quantiles to be used for probabilistic forecasting.

    If set to 'auto', it defaults to [0.1, 0.5, 0.9]. If a list is provided, each element must be a float between 0 and 1 exclusive. If None, it remains None and can be used for deterministic forecasting.

  • scales (str, list of int, or None, optional) –

    Specifies the scaling factors to be used in multi-scale processing.

    If set to 'auto' or None, it defaults to [1]. If a list is provided, each element must be a positive integer.

  • multi_scale_agg (str or None, optional) –

    Determines the aggregation method for multi-scale features.

    If set to None, return_sequences is False. Otherwise, return_sequences is True. Expected aggregation methods include 'average', 'concat', 'sum', 'last', and 'auto' (which falls back to 'last'), depending on model requirements.

Returns:

Tuple containing validated quantiles, validated scales, and the return_sequences flag derived from multi_scale_agg.

Return type:

Tuple[List[float], List[int], bool]

Raises:

ValueError – If quantiles is neither ‘auto’ nor a list of valid floats. If scales is neither ‘auto’ nor a list of valid positive integers. If multi_scale_agg is provided but not a recognized aggregation method.

Examples

>>> from geoprior.models.utils import set_default_params
>>> # Example 1: Using default 'auto' settings
>>> quantiles, scales, return_sequences = set_default_params(
...     quantiles='auto', scales='auto', multi_scale_agg='auto')
>>> print(quantiles)
[0.1, 0.5, 0.9]
>>> print(scales)
[1]
>>> print(return_sequences)
True
>>> # Example 2: Providing custom quantiles and scales
>>> quantiles, scales, return_sequences = set_default_params(
...     quantiles=[0.05, 0.5, 0.95],
...     scales=[1, 2, 4],
...     multi_scale_agg='concat'
... )
>>> print(quantiles)
[0.05, 0.5, 0.95]
>>> print(scales)
[1, 2, 4]
>>> print(return_sequences)
True
>>> # Example 3: Invalid quantiles input
>>> set_default_params(quantiles=[-0.1, 1.2])
Traceback (most recent call last):
...
ValueError: Each quantile must be a float between 0 and 1 (exclusive).
Invalid quantiles: [-0.1, 1.2]
>>> # Example 4: Invalid scales input
>>> set_default_params(scales=[0, -2])
Traceback (most recent call last):
...
ValueError: Each scale must be a positive integer. Invalid scales: [0, -2]
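The defaulting rules documented for set_default_params can be restated as a small sketch. `sketch_default_params` is hypothetical, its error messages differ from the real ones, and `KNOWN_AGG` is taken from the aggregation names listed above, which may not be the complete accepted set.

```python
# Aggregation names from the parameter description (possibly incomplete).
KNOWN_AGG = {"average", "concat", "sum", "last", "auto"}


def sketch_default_params(quantiles=None, scales=None, multi_scale_agg=None):
    """Hypothetical sketch of the documented defaulting rules."""
    if quantiles == "auto":
        quantiles = [0.1, 0.5, 0.9]
    elif quantiles is not None:
        bad = [q for q in quantiles
               if not isinstance(q, float) or not 0.0 < q < 1.0]
        if bad:
            raise ValueError(f"Invalid quantiles: {bad}")
        quantiles = list(quantiles)

    if scales is None or scales == "auto":
        scales = [1]
    else:
        bad = [s for s in scales if not isinstance(s, int) or s <= 0]
        if bad:
            raise ValueError(f"Invalid scales: {bad}")
        scales = list(scales)

    if multi_scale_agg is not None and multi_scale_agg not in KNOWN_AGG:
        raise ValueError(f"Unknown aggregation: {multi_scale_agg!r}")
    # Any aggregation method (including 'auto') turns sequence outputs on.
    return_sequences = multi_scale_agg is not None
    return quantiles, scales, return_sequences


print(sketch_default_params("auto", "auto", "auto"))
# ([0.1, 0.5, 0.9], [1], True)
```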
geoprior.models.utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)

Split sequences into static and dynamic inputs for the model.

The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.

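For a single sequence \(\mathbf{X} \in \mathbb{R}^{T \times F}\), with \(t\) the static_time_step:

\[\begin{split}\text{Static Inputs} = \mathbf{S} = \mathbf{X}_{t,\ \text{static\_indices}} \\ \text{Dynamic Inputs} = \mathbf{D} = \mathbf{X}_{:,\ \text{dynamic\_indices}}\end{split}\]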
Parameters:
  • sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).

  • static_indices (List[int]) – Indices of static features within the feature dimension.

  • dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.

  • static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).

  • reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.

  • reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.

  • static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).

  • dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).

Returns:

A tuple containing:

  • Static inputs with shape as specified.

  • Dynamic inputs with shape as specified.

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If static_time_step is out of range for the given sequence length.

Examples

>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array
>>> # Shape: (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)

Notes

  • Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.

  • Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.

  • Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
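The split equation reduces to two NumPy indexing operations followed by the default reshapes. The variable names are illustrative and this is not the function's source; it only reproduces the documented default output shapes and the out-of-range check.

```python
import numpy as np

rng = np.random.default_rng(0)
sequences = rng.random((100, 10, 5))  # (batch, time, features)
static_indices = [0, 1]
dynamic_indices = [2, 3, 4]
static_time_step = 0

n_steps = sequences.shape[1]
if not -n_steps <= static_time_step < n_steps:
    raise ValueError("static_time_step is out of range.")

# S = X[t, static_indices] for every sample in the batch -> (100, 2).
static = sequences[:, static_time_step, static_indices]
# D = X[:, dynamic_indices] keeps the whole time axis -> (100, 10, 3).
dynamic = sequences[:, :, dynamic_indices]

# Default reshapes append a trailing singleton axis.
static = static.reshape(static.shape[0], len(static_indices), 1)
dynamic = dynamic.reshape(*dynamic.shape, 1)
print(static.shape, dynamic.shape)  # (100, 2, 1) (100, 10, 3, 1)
```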

See also

geoprior.models.utils.create_sequences

Function to create input sequences and targets for time series forecasting.

geoprior.models.utils.squeeze_last_dim_if(tensors, output_dims)

Squeeze the last dimension of one or more tensors when that dimension has size 1 and the matching entry in output_dims is 1.

output_dims can be:
  • a single int: apply that rule to every tensor

  • a list of ints: must match length of tensors list, each applied to corresponding tensor

  • a dict mapping keys to ints: each entry applies to the tensor with the same key in tensors; keys in output_dims that are absent from tensors trigger a warning and are ignored

Parameters:
  • tensors (tf.Tensor or list of tf.Tensor or dict of tf.Tensor) –

    • If a single tf.Tensor, processed with output_dims if int.

    • If a list, each element is processed with its corresponding entry in output_dims if list, or with the single int if int.

    • If a dict, each value is processed with the matching int in output_dims dict, or with the single int if output_dims is int.

  • output_dims (int or list of int or dict of int) –

    • If int: all tensors use that output dimension rule.

    • If list: length must equal len(tensors) when tensors is list.

    • If dict: keys map to ints; for any key in tensors missing from output_dims, no squeeze is applied; keys in output_dims absent in tensors produce a warning.

Returns:

Same structure as tensors, but with trailing size-1 axes removed where output_dim == 1. If output_dim != 1, that tensor is unchanged.

Return type:

tf.Tensor or list or dict

Raises:
  • TypeError – If tensors is not a tf.Tensor, list, or dict; or if output_dims type is invalid; or if list lengths do not match.

  • ValueError – If list lengths mismatch.

Examples

>>> import tensorflow as tf
>>> from geoprior.models.utils import squeeze_last_dim_if
>>>
>>> # Single tensor, single int → squeeze if last dim=1
>>> t = tf.zeros((8, 5, 1))
>>> out = squeeze_last_dim_if(t, output_dims=1)
>>> out.shape
TensorShape([8, 5])
>>>
>>> # List of tensors, list of ints
>>> t_list = [tf.zeros((4, 1)), tf.zeros((4, 2, 1))]
>>> out_list = squeeze_last_dim_if(t_list, output_dims=[1, 2])
>>> [x.shape for x in out_list]
[TensorShape([4]), TensorShape([4, 2, 1])]
>>>
>>> # Dict of tensors, dict of ints
>>> t_dict = {
...     "a": tf.zeros((3, 1, 1)),
...     "b": tf.zeros((3, 1, 2))
... }
>>> out_dict = squeeze_last_dim_if(t_dict, output_dims={"a": 1, "b": 2})
>>> {k: v.shape for k, v in out_dict.items()}
{'a': TensorShape([3, 1]), 'b': TensorShape([3, 1, 2])}

Notes

  • This function does a shallow traversal: if tensors is a list, each element must be a tf.Tensor; if tensors is a dict, each value must be a tf.Tensor.

  • For dict mode: missing keys in output_dims → no change; extra keys in output_dims → warning.
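The int / list / dict dispatch described for squeeze_last_dim_if can be sketched with NumPy arrays standing in for tf.Tensor. `sketch_squeeze_last_dim_if` and `_squeeze_if` are hypothetical helpers; because the Raises section lists both TypeError and ValueError for a length mismatch, raising ValueError here is an assumption.

```python
import warnings

import numpy as np


def _squeeze_if(arr, output_dim):
    """Drop the last axis only when output_dim == 1 and that axis has size 1."""
    if output_dim == 1 and arr.ndim > 0 and arr.shape[-1] == 1:
        return np.squeeze(arr, axis=-1)
    return arr


def sketch_squeeze_last_dim_if(tensors, output_dims):
    """Hypothetical NumPy sketch of the int / list / dict dispatch."""
    if isinstance(tensors, np.ndarray):
        if not isinstance(output_dims, int):
            raise TypeError("A single array needs an int output_dims.")
        return _squeeze_if(tensors, output_dims)

    if isinstance(tensors, list):
        if isinstance(output_dims, int):
            return [_squeeze_if(t, output_dims) for t in tensors]
        if not isinstance(output_dims, (list, tuple)):
            raise TypeError("output_dims must be an int or a list here.")
        if len(output_dims) != len(tensors):
            raise ValueError("output_dims and tensors differ in length.")
        return [_squeeze_if(t, o) for t, o in zip(tensors, output_dims)]

    if isinstance(tensors, dict):
        if isinstance(output_dims, int):
            return {k: _squeeze_if(v, output_dims) for k, v in tensors.items()}
        if not isinstance(output_dims, dict):
            raise TypeError("output_dims must be an int or a dict here.")
        for extra in sorted(set(output_dims) - set(tensors)):
            warnings.warn(f"output_dims key {extra!r} is not in tensors.")
        # Keys missing from output_dims pass through unchanged.
        return {k: (_squeeze_if(v, output_dims[k]) if k in output_dims else v)
                for k, v in tensors.items()}

    raise TypeError("tensors must be an array, a list, or a dict.")


out = sketch_squeeze_last_dim_if(
    {"a": np.zeros((3, 1, 1)), "b": np.zeros((3, 1, 2))},
    output_dims={"a": 1, "b": 2},
)
print({k: v.shape for k, v in out.items()})  # {'a': (3, 1), 'b': (3, 1, 2)}
```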

See also

tf.squeeze

Remove dimensions of size 1 from the shape of a tensor.

tf.reshape

Manually reshape a tensor if more complex edits are needed.

geoprior.models.utils.step_to_long(df, tname=None, dt_col=None, spatial_cols=None, mode='quantile', quantiles=None, verbose=3, sort=True)

Convert a multi-step forecast DataFrame from wide to long format.

This function transforms a DataFrame containing multi-step forecast predictions into a long-format DataFrame. In quantile mode, forecast columns such as subsidence_q10_step1, subsidence_q50_step1, etc. are consolidated into unified columns (e.g. subsidence_q10, subsidence_q50, etc.), while in point mode, a single prediction column (subsidence_pred) is generated. The transformation also carries over additional columns (e.g. spatial coordinates and time) from the original DataFrame.

Parameters:
  • df (pandas.DataFrame) – The multi-step forecast DataFrame. Expected to contain forecast prediction columns (e.g. columns with _q or _pred_step in their names) along with other identifiers.

  • tname (str, optional) – The base name of the target variable (e.g. "subsidence"). If None, the function attempts to auto-detect the target name from the column names.

  • dt_col (str, optional) – The name of the time column to include in the final DataFrame. If not provided, time sorting is not performed.

  • spatial_cols (list of str, optional) – A list of spatial coordinate columns (e.g. ["longitude", "latitude"]) to be retained in the final output.

  • mode ({"quantile", "point"}, default "quantile") – The forecast mode. In "quantile" mode, multiple quantile forecast columns are merged into unified columns. In "point" mode, a single prediction column is produced.

  • quantiles (list of float, optional) – The quantile values for quantile mode (e.g. [0.1, 0.5, 0.9]). If not provided, defaults are used.

  • sort (bool, optional) – If True, sorts the final DataFrame by the column specified in dt_col (if present). Default is True.

  • verbose (int, optional) – Verbosity level for logging output. Higher values (e.g. 5 to 7) provide more detailed debug information.

Returns:

A long-format DataFrame containing the retained spatial columns, the time column when dt_col is provided, and the merged forecast prediction columns. In quantile mode, the output contains unified columns such as subsidence_q10 and subsidence_q50. In point mode, it contains a single subsidence_pred column.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.utils import step_to_long
>>> # Given a DataFrame `forecast_df` with columns like:
>>> # ['longitude', 'latitude', 'year', 'subsidence_actual',
>>> #  'subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1',
>>> #  'subsidence_q10_step2', ...]
>>> long_df = step_to_long(
...     df=forecast_df,
...     tname="subsidence",
...     dt_col="year",
...     spatial_cols=["longitude", "latitude"],
...     mode="quantile",
...     quantiles=[0.1, 0.5, 0.9],
...     verbose=3,
...     sort=True
... )
>>> print(long_df.head())

Notes

Internally, this function calls:

  • check_forecast_mode() to validate the user-specified quantiles.

  • validate_consistency_q() and validate_quantiles() to ensure the supplied quantiles match those auto-detected from the DataFrame.

  • Depending on mode, either _step_to_long_q() for quantile mode or _step_to_long_pred() for point mode performs the conversion.

Mathematically, let \(X \in \mathbb{R}^{n \times m}\) represent the wide-format DataFrame, where each row corresponds to one sample and each forecast step is stored in separate columns. The function reshapes \(X\) into a long-format DataFrame \(Y \in \mathbb{R}^{(n \cdot s) \times p}\), where \(s\) is the forecast horizon and \(p\) is the number of output columns after merging forecast step values.
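The wide-to-long reshape described above (\(n\) rows with \(s\) step columns become \(n \cdot s\) rows) can be sketched with the standard library for quantile columns. The `<target>_q<NN>_step<k>` pattern, the `wide_to_long` helper, and the extra `step` column are assumptions made for illustration; the real function works on a DataFrame and carries dt_col instead.

```python
import re

# One wide row per sample; each forecast step sits in its own column.
wide_rows = [
    {"longitude": 10.0, "latitude": 50.0,
     "subsidence_q10_step1": 1.0, "subsidence_q50_step1": 2.0,
     "subsidence_q90_step1": 3.0,
     "subsidence_q10_step2": 1.5, "subsidence_q50_step2": 2.5,
     "subsidence_q90_step2": 3.5},
]

# Assumed naming scheme: <target>_q<NN>_step<k>.
STEP_COL = re.compile(r"^(?P<base>\w+?_q\d+)_step(?P<step>\d+)$")


def wide_to_long(rows, keep_cols):
    """Hypothetical sketch: one output row per (sample, forecast step)."""
    long_rows = []
    for row in rows:
        by_step = {}
        for col, value in row.items():
            match = STEP_COL.match(col)
            if match:
                step = int(match["step"])
                by_step.setdefault(step, {})[match["base"]] = value
        for step in sorted(by_step):
            out = {c: row[c] for c in keep_cols}
            out["step"] = step
            out.update(by_step[step])  # e.g. subsidence_q10, subsidence_q50
            long_rows.append(out)
    return long_rows


long_rows = wide_to_long(wide_rows, keep_cols=["longitude", "latitude"])
for r in long_rows:
    print(r)
```

Here \(n = 1\) and \(s = 2\), so the output has two rows, each with \(p = 6\) columns.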

See also

_step_to_long_q

Converts multi-step quantile forecasts to long format.

_step_to_long_pred

Converts multi-step point forecasts to long format.

detect_digits

Extracts numeric values from strings for quantile detection.

geoprior.models.utils.export_keras_losses(history, keys=None, savefile=None, verbose=0, formats=('json', 'csv'))

Export losses (and any other recorded metrics) from a Keras History object.

Parameters:
  • history (History) – The History object returned by model.fit().

  • keys (list[str] | None) – List of history.history keys to export, e.g. [“loss”, “val_loss”, “physics_loss”]. If None, defaults to all keys ending with “loss”.

  • savefile (str | None) –

    Path (optionally with extension) where to write the output.

    If the extension is .json, only JSON is written. If the extension is .csv, only CSV is written. If no extension is given, all formats in formats are written using savefile as the base name.

  • verbose (int) – If >0, prints status messages.

  • formats (tuple[str, ...]) – File formats to write when savefile has no extension. Valid entries are “json” and “csv”.

Returns:

result – Dictionary containing "epochs_run" and one list entry per exported history key.

Return type:

dict
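A minimal sketch of the key selection and the result dictionary, using a plain dictionary in place of `History.history`. `sketch_export_losses` is hypothetical; treating epochs_run as the length of the first exported series, and writing the CSV as one row per epoch, are assumptions about the layout rather than the documented file format.

```python
import csv
import io
import json


def sketch_export_losses(history_dict, keys=None):
    """Hypothetical sketch: select loss keys and build the result dict."""
    if keys is None:
        keys = [k for k in history_dict if k.endswith("loss")]
    # Assumption: epochs_run is the length of the first exported series.
    epochs_run = len(history_dict[keys[0]]) if keys else 0
    result = {"epochs_run": epochs_run}
    result.update({k: list(history_dict[k]) for k in keys})
    return result


# Stand-in for History.history as returned by model.fit().
history_dict = {
    "loss": [0.9, 0.5, 0.3],
    "val_loss": [1.0, 0.7, 0.6],
    "physics_loss": [0.2, 0.1, 0.05],
    "mae": [0.8, 0.6, 0.4],
}
result = sketch_export_losses(history_dict)

# JSON: the dictionary as-is.
json_text = json.dumps(result, indent=2)

# CSV: one row per epoch, one column per exported key (assumed layout).
buffer = io.StringIO()
writer = csv.writer(buffer)
series = [k for k in result if k != "epochs_run"]
writer.writerow(["epoch", *series])
for epoch in range(result["epochs_run"]):
    writer.writerow([epoch + 1, *(result[k][epoch] for k in series)])
print(buffer.getvalue())
```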

geoprior.models.utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)

Safely retrieves the first available tensor from a dictionary using a list of possible keys.

This utility is useful for handling model inputs inside a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of tensors (e.g., tensor_a or tensor_b), which raises runtime errors, by testing each candidate explicitly with is not None.

Parameters:
  • inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).

  • *tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.

  • default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.

  • check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.

  • auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.

Returns:

The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.

Return type:

Optional[tf.Tensor]

Raises:

TypeError – If inputs is not a dictionary.

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.])}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
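The lookup pattern is simple to express in plain Python. In this sketch, `first_present` is hypothetical, NumPy arrays stand in for tensors, and no type checking or conversion is performed. The last lines show why `value_a or value_b` is unsafe: truth-testing a multi-element array raises, analogous to the error a tensor triggers.

```python
import numpy as np


def first_present(inputs, *names, default=None):
    """Hypothetical sketch: first value that is not None, by key priority."""
    if not isinstance(inputs, dict):
        raise TypeError("inputs must be a dict.")
    for name in names:
        value = inputs.get(name)
        if value is not None:  # explicit check, never `value or ...`
            return value
    return default


inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}
print(first_present(inputs, "H_field", "soil_thickness"))  # [20. 21.]

# The pattern this avoids: truth-testing a multi-element array raises.
try:
    inputs["soil_thickness"] or inputs["H_field"]
except ValueError as err:
    print("ambiguous:", err)
```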
geoprior.models.utils.format_pihalnet_predictions(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, model_name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)

Formats PIHALNet/GeoPriorSubsNet predictions into a structured pandas DataFrame, handling inverse scaling, quantiles, and coordinates.

This function is the core formatter. It performs the following steps:

  1. Gets model outputs (or uses provided ones).

  2. Unpacks ‘data_final’ if model_name is ‘geoprior’.

  3. Inverse-transforms all prediction and actual arrays using scaler_info.

  4. Builds a long-format DataFrame with sample_idx and forecast_step.

  5. Appends inverted quantile/point predictions.

  6. Appends inverted actual values.

  7. Appends inverted coordinates.

  8. Appends static/ID columns.

  9. Evaluates coverage on the inverted data.

Parameters:
  • pihalnet_outputs (dict, optional) – Raw output from model.predict(). If None, model and model_inputs must be provided.

  • model (tf.keras.Model, optional) – Trained model instance (if pihalnet_outputs is None).

  • model_inputs (dict, optional) – Inputs for the model to generate predictions (if pihalnet_outputs is None).

  • y_true_dict (dict, optional) – Dictionary of true target arrays (e.g., {‘subs_pred’: y_true_s}). Required for including actuals and evaluating coverage.

  • target_mapping (dict, optional) – Maps prediction keys to base names for DataFrame columns. Default: {‘subs_pred’: ‘subsidence’, ‘gwl_pred’: ‘gwl’}.

  • include_gwl (bool, default True) – Whether to include ‘gwl_pred’ in the final DataFrame.

  • include_coords (bool, default True) – Whether to include ‘coord_t’, ‘coord_x’, ‘coord_y’ columns.

  • quantiles (list[float], optional) – List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided, quantile columns (e.g., ‘subsidence_q10’) are created.

  • forecast_horizon (int, optional) – The forecast horizon length (H). If not provided, it’s inferred from the prediction array’s shape.

  • output_dims (dict, optional) – Maps prediction keys to their output dimension (O). E.g., {‘subs_pred’: 1, ‘gwl_pred’: 1}. Crucial for correctly splitting GeoPrior outputs and reshaping.

  • ids_data_array (np.ndarray or pd.DataFrame, optional) – Static/ID data (e.g., original coordinates) to merge. Must have the same number of samples (B) as predictions.

  • ids_cols (list[str], optional) – Column names if ids_data_array is a DataFrame.

  • ids_cols_indices (list[int], optional) – Column indices if ids_data_array is a NumPy array.

  • scaler_info (dict, optional) – Dictionary for inverse scaling. Each target entry should provide a fitted scaler, the target index inside that scaler, and the feature-name ordering used when the scaler was fit.

  • coord_scaler (fitted scikit-learn scaler, optional) – A fitted scaler (e.g., sklearn.preprocessing.MinMaxScaler) used to inverse-transform the ‘coords’ tensor.

  • evaluate_coverage (bool, default False) – If True, calculates coverage percentage for quantiles.

  • coverage_quantile_indices (tuple[int, int], default (0, -1)) – Indices of the low and high quantiles in the quantiles list used for coverage (e.g., with quantiles=[0.1, 0.5, 0.9], indices 0 and -1 select the 10th and 90th percentiles).

  • savefile (str, optional) – If provided, saves the final DataFrame to this path.

  • model_name (str, optional) – Specifies the model type. If ‘geoprior’ or ‘geopriorsubsnet’, triggers unpacking of the ‘data_final’ output.

  • apply_mask (bool, default False) – If True, masks predictions based on mask_values in the first target’s _actual column.

  • mask_values (float or int, optional) – The value in the _actual column to trigger masking.

  • mask_fill_value (float, optional) – The value to replace masked predictions with (e.g., np.nan).

  • verbose (int, default 0) – Logging verbosity.

  • _logger (logging.Logger or callable, optional) – Logger object.

  • stop_check (callable, optional) – Function to check for early stopping.

  • name (str | None)

Returns:

A long-format DataFrame with predictions, actuals, and coordinates.

Return type:

pd.DataFrame
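The long-format layout (one row per sample_idx and forecast_step) and the quantile coverage check can be sketched with NumPy for a single target with one output dimension. The shapes (B, H, Q), the column names, and the toy values are assumptions chosen for illustration; the real formatter also handles inverse scaling, coordinates, and GeoPrior output unpacking.

```python
import numpy as np

# Assumed shapes: quantile predictions (B, H, Q) for one target with O = 1,
# actuals (B, H). Values are made up for illustration.
quantiles = [0.1, 0.5, 0.9]
preds = np.array([
    [[1.0, 2.0, 3.0], [1.5, 2.5, 3.5]],
    [[0.0, 1.0, 2.0], [0.5, 1.5, 2.5]],
])  # B=2 samples, H=2 steps, Q=3 quantiles
actual = np.array([[2.2, 3.8], [1.1, 0.7]])

B, H, _ = preds.shape
# Long format: one row per (sample_idx, forecast_step).
columns = {
    "sample_idx": np.repeat(np.arange(B), H),          # [0, 0, 1, 1]
    "forecast_step": np.tile(np.arange(1, H + 1), B),  # [1, 2, 1, 2]
}
for i, q in enumerate(quantiles):
    columns[f"subsidence_q{int(round(q * 100))}"] = preds[..., i].reshape(-1)
columns["subsidence_actual"] = actual.reshape(-1)

# Coverage with coverage_quantile_indices=(0, -1): share of actuals
# that fall inside [q10, q90].
low = preds[..., 0].reshape(-1)
high = preds[..., -1].reshape(-1)
y = columns["subsidence_actual"]
coverage = np.mean((y >= low) & (y <= high))
print(coverage)  # 0.75
```

Multiplying the fraction by 100 gives the coverage percentage.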

geoprior.models.utils.normalize_for_pinn(df, time_col, coord_x, coord_y, cols_to_scale='auto', scale_coords=True, exclude_cols=None, protect_si_suffix='__si', shift_time_by_horizon=False, verbose=1, forecast_horizon=None, _logger=None, coord_scaler=None, fit_coord_scaler=True, other_scaler=None, fit_other_scaler=True, **kws)

Apply Min-Max normalization to spatial-temporal coordinates and optionally to other numeric columns. If cols_to_scale == "auto", automatically select numeric columns excluding categorical and one-hot features.

By default, this function scales the time, longitude, and latitude columns (if scale_coords=True). Then, it either scales explicitly provided columns in cols_to_scale or automatically infers numeric columns (excluding coordinates if scale_coords is False, and excluding one-hot/boolean columns).

The Min-Max scaling for a feature \(x\) is:

\[x' = \frac{x - \min(x)}{\max(x) - \min(x)}\]
Parameters:
  • df (pd.DataFrame) – The input DataFrame containing at least time_col, coord_x, and coord_y columns. The DataFrame should contain temporal and spatial information to be scaled.

  • time_col (str) – The name of the numeric time column (e.g., year as numeric or datetime). This column will be used to adjust and scale the temporal data.

  • coord_x (str) – The name of the longitude column in the DataFrame. This column will be scaled along with the latitude and time columns.

  • coord_y (str) – The name of the latitude column in the DataFrame. This column will be scaled along with the longitude and time columns.

  • cols_to_scale (list of str or "auto" or None, default "auto") – If a list of column names, scales exactly those columns. If "auto", selects all numeric columns, excluding time_col, coord_x, and coord_y if scale_coords=False, and excluding one-hot encoded columns whose values are only {0, 1}. If None, no extra columns are scaled.

  • scale_coords (bool, default True) – If True, scales the [time_col, coord_x, coord_y] columns. If False, these columns remain unchanged.

  • verbose (int, default 1) – Verbosity level for logging. Values higher than 1 provide more detailed logging information.

  • forecast_horizon (Optional[int], default None) – The number of time steps to shift the time column by. This is added to the time values before scaling if provided.

  • _logger (Optional[Union[logging.Logger, Callable[[str], None]]], default None) – Logger or function to handle logging messages. If None, the default logging mechanism is used.

  • kws (dict, optional) – Additional keyword arguments forwarded to internal data-processing or scaling helpers.

  • exclude_cols (list[str] | None)

  • protect_si_suffix (str)

  • shift_time_by_horizon (bool)

  • coord_scaler (MinMaxScaler | None)

  • fit_coord_scaler (bool)

  • other_scaler (MinMaxScaler | None)

  • fit_other_scaler (bool)

Returns:

  • df_scaled (pd.DataFrame) – A new DataFrame with the specified columns normalized.

  • coord_scaler (MinMaxScaler or None) – The fitted scaler for the [time_col, coord_x, coord_y] columns if scale_coords=True, else None.

  • other_scaler (MinMaxScaler or None) – The fitted scaler for any additional columns that were scaled (either explicitly provided or auto-selected). None if no columns were scaled beyond the coordinates.

Raises:
  • TypeError – If df is not a DataFrame, or cols_to_scale is neither a list nor “auto” nor None, or if any explicitly provided column is not a string.

  • ValueError – If required columns (time_col, coord_x, coord_y) or any of cols_to_scale do not exist in df, or cannot be converted to numeric.

Return type:

tuple[DataFrame, MinMaxScaler | None, MinMaxScaler | None]

Examples

>>> import pandas as pd
>>> from geoprior.models.utils import normalize_for_pinn
>>> data = {
...     "year_num": [0.0, 1.0, 2.0],
...     "lon": [100.0, 101.0, 102.0],
...     "lat": [30.0, 31.0, 32.0],
...     "feat1": [10.0, 20.0, 30.0],
...     "one_hot_A": [0, 1, 0]
... }
>>> df = pd.DataFrame(data)
>>> df_scaled, coord_scl, feat_scl = normalize_for_pinn(
...     df,
...     time_col="year_num",
...     coord_x="lon",
...     coord_y="lat",
...     cols_to_scale="auto",
...     scale_coords=True,
...     verbose=2
... )
>>> # 'year_num','lon','lat','feat1' get scaled; 'one_hot_A' excluded
>>> df_scaled["year_num"].tolist()
[0.0, 0.5, 1.0]
>>> df_scaled["feat1"].tolist()
[0.0, 0.5, 1.0]

Notes

  • When cols_to_scale="auto", numeric columns with only {0, 1} values are assumed to be one-hot and excluded from scaling.

  • If scale_coords=False, coordinate columns remain unchanged, and auto-selection (if used) will exclude them.

  • Returned coord_scaler is None if scale_coords=False. Returned other_scaler is None if cols_to_scale is None or results in an empty set after filtering.
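The "auto" selection rule described in the notes can be illustrated with a minimal pandas sketch. `minmax_auto_sketch` is a hypothetical helper, not part of GeoPrior. It assumes plain min-max scaling to [0, 1] and treats any numeric column whose values are a subset of {0, 1} as one-hot. It does not reproduce coordinate handling, SI-suffix protection, or scaler reuse.

```python
import pandas as pd


def minmax_auto_sketch(df, exclude=()):
    """Hypothetical helper: min-max scale numeric, non-binary columns.

    Loosely mirrors the documented "auto" heuristic; this is not the
    GeoPrior implementation.
    """
    out = df.copy()
    chosen = []
    for col in out.select_dtypes(include="number").columns:
        if col in exclude:
            continue
        values = set(out[col].dropna().unique().tolist())
        # Columns holding only 0/1 are assumed to be one-hot indicators.
        if values.issubset({0, 1}):
            continue
        chosen.append(col)

    for col in chosen:
        lo, hi = out[col].min(), out[col].max()
        span = hi - lo
        # Constant columns are mapped to 0.0 to avoid division by zero.
        out[col] = (out[col] - lo) / span if span != 0 else 0.0
    return out, chosen


df = pd.DataFrame({
    "year_num": [0.0, 1.0, 2.0],
    "lon": [100.0, 101.0, 102.0],
    "lat": [30.0, 31.0, 32.0],
    "feat1": [10.0, 20.0, 30.0],
    "one_hot_A": [0, 1, 0],
})
scaled, cols = minmax_auto_sketch(df)
print(cols)                      # ['year_num', 'lon', 'lat', 'feat1']
print(scaled["feat1"].tolist())  # [0.0, 0.5, 1.0]
```

The binary-column test is a heuristic: a genuinely continuous feature that happens to contain only 0 and 1 would also be skipped.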

See also

sklearn.preprocessing.MinMaxScaler

Scales features to [0,1].

geoprior.models.utils.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)
Parameters:
Return type:

tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]

geoprior.models.utils.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)

Formats PINN model predictions into a structured pandas DataFrame.

This general-purpose utility transforms raw model outputs (from models such as PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export. It handles multi-target outputs (e.g., subsidence and GWL), point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse-scaling of predictions and evaluation of quantile coverage.

Parameters:
  • predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model’s .predict() method. Keys should match the model’s output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.

  • model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.

  • model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model’s signature, required only if predictions is None. Default is None.

  • y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, an <target>_actual column will be added to the output DataFrame for comparison. Default is None.

  • target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.

  • include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.

  • include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.

  • quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is crucial for correctly parsing probabilistic forecasts. Default is None.

  • forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.

  • output_dims (dict of str to int, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it is inferred from the tensor shapes. Default is None.

  • ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.

  • ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.

  • ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.

  • scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., ‘subsidence’) and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.

  • coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.

  • evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.

  • coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. Default is (0, -1), which corresponds to the full range.

  • savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.

  • name (str or None) – Name of the prediction, used to label the output data and, if applicable, the coverage result.

  • model_name (str or None) – Name of the model.

  • verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).

  • **kwargs (dict) – Additional keyword arguments for future extensions.

  • _logger (Logger | Callable[[str], None] | None)

  • stop_check (Callable[[], bool])

Returns:

A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.

Return type:

pd.DataFrame

Notes

  • The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.

  • For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.

  • For point forecasts, the column is named <target_name>_pred.
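The column-naming pattern and the coverage option can be sketched with NumPy. Both helper names below are hypothetical. The sketch assumes that "unconditional coverage" means the fraction of actual values lying inside the band spanned by the lower and upper quantiles selected by `coverage_quantile_indices` from the sorted quantile list, and that quantile predictions have shape (n_samples, horizon, n_quantiles).

```python
import numpy as np


def quantile_column_names(target_name, quantiles):
    """Hypothetical helper: build names like 'subsidence_q50'."""
    return [f"{target_name}_q{int(round(q * 100))}" for q in quantiles]


def interval_coverage(y_true, q_preds, quantiles, indices=(0, -1)):
    """Fraction of y_true inside the [lower, upper] quantile band.

    q_preds has shape (n_samples, horizon, n_quantiles). Quantiles are
    sorted first so that `indices` address the sorted list.
    """
    order = np.argsort(quantiles)
    q_sorted = np.asarray(q_preds)[..., order]
    lower = q_sorted[..., indices[0]]
    upper = q_sorted[..., indices[1]]
    y = np.asarray(y_true)
    inside = (y >= lower) & (y <= upper)
    return float(inside.mean())


quantiles = [0.05, 0.5, 0.95]
names = quantile_column_names("subsidence", quantiles)
print(names)  # ['subsidence_q5', 'subsidence_q50', 'subsidence_q95']

y_true = np.array([[1.0, 2.0], [3.0, 10.0]])  # 2 samples, horizon 2
q_preds = np.stack([y_true - 1.0, y_true, y_true + 1.0], axis=-1)
q_preds[1, 1] = [0.0, 1.0, 2.0]  # this actual (10.0) falls outside its band
print(interval_coverage(y_true, q_preds, quantiles))  # 0.75
```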

See also

geoprior.plot.forecast.plot_forecasts

A powerful utility for visualizing the DataFrame produced by this function.

geoprior.models.utils.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. Can be a single tensor or a dictionary with ‘coords’ or ‘t’, ‘x’, ‘y’ keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input’s dimension. '2d' requires input shaped like (batch, 3). '3d' accepts 3D input and expands 2D input to 3D with a time dimension of 1. '3d_only' requires 3D input and raises an error for 2D input. None accepts both 2D and 3D inputs without changing their rank.

  • verbose (int, default 0) – Controls logging verbosity.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
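The `expect_dim` rules for a single coordinate tensor can be illustrated with a NumPy sketch. `extract_txy_sketch` is hypothetical and handles only the array case, not dictionary inputs. It assumes that each returned coordinate keeps a singleton last axis, giving (batch, 1) for 2D input and (batch, time, 1) for 3D input. The documentation above states only that the output rank follows the input and `expect_dim`.

```python
import numpy as np

DEFAULT_SLICE_MAP = {"t": 0, "x": 1, "y": 2}


def extract_txy_sketch(coords, coord_slice_map=None, expect_dim=None):
    """Hypothetical NumPy stand-in for the documented expect_dim rules."""
    smap = coord_slice_map or DEFAULT_SLICE_MAP
    c = np.asarray(coords)
    if c.ndim not in (2, 3) or c.shape[-1] < 3:
        raise ValueError(f"Unsupported coordinate shape: {c.shape}")
    if expect_dim == "2d" and c.ndim != 2:
        raise ValueError("expect_dim='2d' requires shape (batch, 3).")
    if expect_dim == "3d_only" and c.ndim != 3:
        raise ValueError("expect_dim='3d_only' requires a 3D input.")
    if expect_dim == "3d" and c.ndim == 2:
        # Insert a time axis of length 1: (batch, 3) -> (batch, 1, 3).
        c = c[:, np.newaxis, :]
    # Slicing with i:i+1 keeps a singleton last dimension.
    return tuple(c[..., smap[k]:smap[k] + 1] for k in ("t", "x", "y"))


t, x, y = extract_txy_sketch(np.zeros((4, 3)), expect_dim="3d")
print(t.shape)  # (4, 1, 1)
```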

geoprior.models.utils.plot_hydraulic_head(model, t_slice, x_bounds, y_bounds, resolution=100, ax=None, title=None, cmap='viridis', colorbar_label='Hydraulic Head (h)', save_path=None, show_plot=True, **contourf_kwargs)

Generate and plot a 2D contour map of a hydraulic head solution.

This utility visualizes the output of a Physics-Informed Neural Network (PINN) that solves for the hydraulic head \(h(t, x, y)\). It automates the process of creating a spatial grid, running model predictions, and generating a publication-quality contour plot for a specific slice in time.

Parameters:
  • model (tf.keras.Model) – The trained PINN model. It is expected to have a .predict() method that accepts a dictionary of tensors with keys {'t', 'x', 'y'}.

  • t_slice (float) – The specific point in time \(t\) for which to plot the 2D spatial solution.

  • x_bounds (tuple of float) – A tuple (x_min, x_max) defining the spatial domain for the x-axis.

  • y_bounds (tuple of float) – A tuple (y_min, y_max) defining the spatial domain for the y-axis.

  • resolution (int, optional) – The number of points to sample along each spatial axis, creating a grid of resolution x resolution points for prediction. Higher values result in a smoother plot. Default is 100.

  • ax (matplotlib.axes.Axes, optional) – A pre-existing Matplotlib Axes object to plot on. If None, a new figure and axes are created internally. This is useful for embedding this plot within a larger figure arrangement. Default is None.

  • title (str, optional) – A custom title for the plot. If None, a default title is generated using the value of t_slice. Default is None.

  • cmap (str, optional) – The name of the Matplotlib colormap to use for the contour plot. Default is 'viridis'.

  • colorbar_label (str, optional) – The text label for the color bar. Default is 'Hydraulic Head (h)'.

  • save_path (str, optional) – If provided, the path (including filename and extension) where the generated plot will be saved. This is only active when the function creates its own figure (i.e., when ax is None). Default is None.

  • show_plot (bool, optional) – If True, calls plt.show() to display the plot. This is only active when the function creates its own figure. Default is True.

  • **contourf_kwargs (any) – Additional keyword arguments that are passed directly to the matplotlib.pyplot.contourf function. This allows for advanced customization (e.g., levels=20, extend='both').

Returns:

  • ax (matplotlib.axes.Axes) – The Matplotlib Axes object on which the contour plot was drawn.

  • contour (matplotlib.cm.ScalarMappable) – The contour plot object, which can be used for further customizations, such as modifying the color bar.

Return type:

tuple[Axes, _ScalarMappable]

See also

geoprior.models.pinn.PiTGWFlow

The PINN model this function is designed to visualize.

Notes

The core mechanism of this function involves creating a 2D meshgrid of \((x, y)\) coordinates. These grid points are then “flattened” into a long list of points, as the PINN model expects a batch of individual coordinates for prediction, not a grid.

The prediction process is as follows:

  1. A grid of shape (resolution, resolution) is created for \(x\) and \(y\).

  2. These grids are reshaped into column vectors of shape (resolution*resolution, 1).

  3. A time vector of the same shape, filled with t_slice, is created.

  4. The model’s .predict() method is called on these flat tensors.

  5. The resulting flat prediction vector is reshaped back to the original (resolution, resolution) grid shape for plotting.

If a custom ax is provided, the user is responsible for calling plt.show() or saving the parent figure.
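The five prediction steps listed above can be sketched with NumPy alone. `head_on_grid` is a hypothetical helper, and a plain callable stands in for the model's `.predict()` method. The returned `X`, `Y`, `H` arrays are what a `contourf(X, Y, H)` call would consume.

```python
import numpy as np


def head_on_grid(predict_fn, t_slice, x_bounds, y_bounds, resolution=100):
    """Hypothetical helper following the five documented steps.

    `predict_fn` is any callable mapping a dict of (N, 1) arrays with
    keys 't', 'x', 'y' to N predictions.
    """
    xs = np.linspace(x_bounds[0], x_bounds[1], resolution)
    ys = np.linspace(y_bounds[0], y_bounds[1], resolution)
    X, Y = np.meshgrid(xs, ys)               # step 1: (res, res) grids
    x_flat = X.reshape(-1, 1)                # step 2: (res*res, 1)
    y_flat = Y.reshape(-1, 1)
    t_flat = np.full_like(x_flat, t_slice)   # step 3: constant time column
    h_flat = np.asarray(                     # step 4: predict on flat points
        predict_fn({"t": t_flat, "x": x_flat, "y": y_flat})
    )
    H = h_flat.reshape(resolution, resolution)  # step 5: back to grid shape
    return X, Y, H


def analytic(inputs):
    # Same analytical field as the mock model in the examples.
    return (np.sin(np.pi * inputs["x"]) * np.cos(np.pi * inputs["y"])
            * np.exp(-inputs["t"]))


X, Y, H = head_on_grid(analytic, 0.5, (-1, 1), (-1, 1), resolution=50)
print(H.shape)  # (50, 50)
```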

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> from geoprior.models.utils import plot_hydraulic_head
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
...     def call(self, inputs):
...         # A simple analytical function for demonstration
...         t, x, y = inputs['t'], inputs['x'], inputs['y']
...         return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()

1. Simple Plotting Example

This example creates a single plot and saves it to a file.

>>> ax, contour = plot_hydraulic_head(
...     model=mock_model,
...     t_slice=0.5,
...     x_bounds=(-1, 1),
...     y_bounds=(-1, 1),
...     resolution=50,
...     save_path="hydraulic_head_t0.5.png",
...     show_plot=False  # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png

2. Advanced Example with Subplots

This example shows how to use the ax parameter to draw the solution at two different times side-by-side in one figure.

>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> _ = fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
>>>
>>> # Plot solution at t = 0.1
>>> _ = plot_hydraulic_head(
...     model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
>>>
>>> # Plot solution at t = 1.0
>>> _ = plot_hydraulic_head(
...     model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
>>>
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()

geoprior.models.utils.make_dict_to_tuple_fn(feature_keys, target_keys=None, *, allow_missing_optional=True)

Create a tf.data.Dataset.map function that converts a feature dictionary into the positional tuple expected by a subclassed Keras model.

Parameters:
  • feature_keys (Sequence[str]) –

    Keys, in the exact positional order required by the model, used to populate the tuple. For example, for a PINN that takes coordinates and static, dynamic, and future features:

    >>> feature_keys = [
    ...     "coords",           # (B, T, 3)
    ...     "static_features",  # (B, S)
    ...     "dynamic_features", # (B, T, D)
    ...     "future_features",  # (B, H, F)
    ... ]
    

  • target_keys (Sequence[str] or None, default None) – If None, pass through the dataset’s existing target element. If a sequence is provided, extract those keys from the feature dictionary or, when available, from the target dictionary. The function returns a single tensor when len(target_keys) == 1 and a {key: tensor} mapping otherwise.

  • allow_missing_optional (bool, default True) – Whether to substitute None for missing optional feature or target keys. If False, a missing key raises KeyError.

Returns:

A mapping function that can be passed to tf.data.Dataset.map:

>>> mapper = make_dict_to_tuple_fn(feature_keys, ["subsidence", "gwl"])
>>> train_ds = train_ds.map(mapper, num_parallel_calls=tf.data.AUTOTUNE)

Return type:

Callable

Notes

  • Duplicate names in feature_keys are rejected so the positional tuple cannot be ambiguous.

  • The function is fully AutoGraph-compatible; TensorFlow stages it as part of the dataset pipeline.
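The key-ordering logic described above can be shown in plain Python. `dict_to_tuple_sketch` is a hypothetical analogue, not the GeoPrior function, and it does not build a TensorFlow graph function. It assumes that a target key is looked up in the target dictionary first and in the feature dictionary otherwise, and that duplicates are rejected with `ValueError` (the exception type is an assumption).

```python
def dict_to_tuple_sketch(feature_keys, target_keys=None,
                         allow_missing_optional=True):
    """Hypothetical pure-Python analogue of make_dict_to_tuple_fn."""
    feature_keys = list(feature_keys)
    if len(set(feature_keys)) != len(feature_keys):
        raise ValueError("Duplicate names in feature_keys.")

    def _get(mapping, key):
        if key in mapping:
            return mapping[key]
        if allow_missing_optional:
            return None  # missing optional entries become None
        raise KeyError(key)

    def _source(key, features, targets):
        # Prefer the target dict when it holds the key, else the features.
        if isinstance(targets, dict) and key in targets:
            return targets
        return features

    def mapper(features, targets=None):
        inputs = tuple(_get(features, k) for k in feature_keys)
        if target_keys is None:
            return inputs, targets  # pass the target element through
        picked = {k: _get(_source(k, features, targets), k)
                  for k in target_keys}
        if len(target_keys) == 1:
            return inputs, picked[target_keys[0]]
        return inputs, picked

    return mapper


mapper = dict_to_tuple_sketch(["coords", "static_features"],
                              ["subsidence", "gwl"])
x, y = mapper({"coords": [0.0, 1.0, 2.0]}, {"subsidence": 1.5, "gwl": -2.0})
print(x)  # ([0.0, 1.0, 2.0], None)
print(y)  # {'subsidence': 1.5, 'gwl': -2.0}
```

Because `tf.data.Dataset.map` unpacks `(features, targets)` elements into two positional arguments, a mapper with this `(features, targets=None)` signature fits that calling convention.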

geoprior.models.utils.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. A single tensor or array may be 2D with shape (batch, 3) or 3D with shape (batch, time_steps, 3). A dictionary may contain a 'coords' key with the coordinate tensor, or separate 't', 'x', and 'y' keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension. '2d' requires input shaped like (batch, 3) or a dictionary of (batch, 1) tensors. '3d' requires input shaped like (batch, time, 3) or a dictionary of (batch, time, 1) tensors. If None, both are accepted and 2D inputs are expanded to 3D.

  • verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
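Unlike `extract_txy`, this variant always resolves to 3D output. The NumPy sketch below illustrates that normalization for the three documented input forms: a bare array, a dictionary with `'coords'`, and a dictionary with separate `'t'`, `'x'`, `'y'` keys. `to_3d_txy_sketch` is hypothetical and omits the `expect_dim` check.

```python
import numpy as np


def to_3d_txy_sketch(inputs, coord_slice_map=None):
    """Hypothetical NumPy sketch: resolve t, x, y to (batch, time, 1)."""
    smap = coord_slice_map or {"t": 0, "x": 1, "y": 2}
    if isinstance(inputs, dict) and all(k in inputs for k in ("t", "x", "y")):
        parts = [np.asarray(inputs[k]) for k in ("t", "x", "y")]
    else:
        coords = inputs["coords"] if isinstance(inputs, dict) else inputs
        c = np.asarray(coords)
        if c.ndim not in (2, 3):
            raise ValueError(f"Unsupported coordinate shape: {c.shape}")
        # i:i+1 keeps a singleton last axis: (batch, 1) or (batch, T, 1).
        parts = [c[..., smap[k]:smap[k] + 1] for k in ("t", "x", "y")]

    out = []
    for p in parts:
        if p.ndim == 2:
            p = p[:, np.newaxis, :]  # (batch, 1) -> (batch, 1, 1)
        if p.ndim != 3 or p.shape[-1] != 1:
            raise ValueError(f"Cannot resolve to (batch, time, 1): {p.shape}")
        out.append(p)
    if len({p.shape for p in out}) != 1:
        raise ValueError("t, x and y have inconsistent shapes.")
    return tuple(out)


t, x, y = to_3d_txy_sketch(np.zeros((8, 3)))
print(t.shape)  # (8, 1, 1)
t, x, y = to_3d_txy_sketch({"coords": np.zeros((8, 5, 3))})
print(x.shape)  # (8, 5, 1)
```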

geoprior.models.utils.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)

Normalize and validate PDE mode selection.

Parameters:
  • pde_mode (str, sequence of str, or None) –

    Requested PDE mode(s).

    Accepted canonical values are:

    • "none"

    • "consolidation"

    • "gw_flow"

    • "both"

    Accepted aliases:

    • None, "off" -> "none"

    • "on" -> "both"

  • enforce_consolidation (bool, default False) –

    If True, any resolved mode other than exact ["consolidation"] is coerced to ["consolidation"] and a warning is emitted.

    This includes:

    • ["none"]

    • ["gw_flow"]

    • ["consolidation", "gw_flow"]

  • pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.

  • solo_return (bool, default False) –

    If False, return a canonical list of active modes.

    If True, return a single canonical label:

    • "none"

    • "consolidation"

    • "gw_flow"

    • "both"

Returns:

Canonical PDE mode(s), either as a list or a single label.

Return type:

list of str or str

Raises:
  • TypeError – If the input type is invalid.

  • ValueError – If a token is unsupported or the mode selection is ambiguous.

Examples

>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
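The normalization rules above can be sketched in plain Python. The function and the `_ALIASES` table are hypothetical. The real module exports `PDE_MODE_ALIASES`, whose exact contents are not shown here. The sketch ignores `pde_mode_config` and assumes that mixing "none" with an active mode is the ambiguous case that raises `ValueError`. It reproduces the five documented examples.

```python
import warnings

# Hypothetical tables mirroring the documented aliases and canonical values.
_ALIASES = {"off": "none", "on": "both"}
_CANONICAL = {"none", "consolidation", "gw_flow", "both"}
_ORDER = ("consolidation", "gw_flow")


def normalize_pde_mode_sketch(pde_mode, enforce_consolidation=False,
                              solo_return=False):
    """Hypothetical sketch of PDE-mode normalization (not GeoPrior code)."""
    if pde_mode is None:
        tokens = ["none"]
    elif isinstance(pde_mode, str):
        tokens = [pde_mode]
    elif isinstance(pde_mode, (list, tuple)):
        tokens = list(pde_mode)
    else:
        raise TypeError(f"Unsupported pde_mode type: {type(pde_mode).__name__}")

    active, saw_none = set(), False
    for tok in tokens:
        if not isinstance(tok, str):
            raise TypeError(f"PDE mode tokens must be strings, got {tok!r}")
        key = tok.strip().lower()
        key = _ALIASES.get(key, key)
        if key not in _CANONICAL:
            raise ValueError(f"Unsupported PDE mode: {tok!r}")
        if key == "none":
            saw_none = True
        elif key == "both":
            active.update(_ORDER)
        else:
            active.add(key)

    # Assumption: "none" combined with an active mode is ambiguous.
    if saw_none and active:
        raise ValueError("'none' cannot be combined with other PDE modes.")

    modes = [m for m in _ORDER if m in active] or ["none"]

    if enforce_consolidation and modes != ["consolidation"]:
        warnings.warn(f"Coercing PDE mode {modes} to ['consolidation'].")
        modes = ["consolidation"]

    if solo_return:
        return "both" if len(modes) == 2 else modes[0]
    return modes


print(normalize_pde_mode_sketch("on"))  # ['consolidation', 'gw_flow']
```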

What this package covers

The public exports of geoprior.models.utils combine general model helpers from geoprior.models.utils._utils with PINN-specific helpers from geoprior.models.utils.pinn.

The current export surface includes functions such as:

  • prepare_model_inputs and prepare_model_inputs_in for standardizing model input tuples;

  • create_sequences and split_static_dynamic for sequence preparation;

  • forecast_single_step and forecast_multi_step for rollout helpers;

  • format_predictions_to_dataframe and format_pinn_predictions for output formatting;

  • prepare_pinn_data_sequences and normalize_for_pinn for PINN preparation;

  • extract_txy and extract_txy_in for coordinate extraction;

  • process_pde_modes and PDE_MODE_ALIASES for PDE-mode normalization.

General model helper module

Utility functions for neural network models.

This module provides utility functions to preprocess data for Temporal Fusion Transformer (TFT) models, including splitting sequences into static and dynamic inputs and creating input sequences with corresponding targets for time series forecasting.

geoprior.models.utils._utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)

Split sequences into static and dynamic inputs for the model.

The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.

\[\begin{aligned} \text{Static Inputs} &= \mathbf{S} = \mathbf{X}_{t,\ \text{static\_indices}} \\ \text{Dynamic Inputs} &= \mathbf{D} = \mathbf{X}_{:,\ \text{dynamic\_indices}} \end{aligned} \tag{56}\]

Parameters:
  • sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).

  • static_indices (List[int]) – Indices of static features within the feature dimension.

  • dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.

  • static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).

  • reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.

  • reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.

  • static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).

  • dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).

Returns:

A tuple containing:
  • Static inputs with the specified shape.

  • Dynamic inputs with the specified shape.

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If static_time_step is out of range for the given sequence length.

Examples

>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array
>>> sequences = np.random.rand(100, 10, 5)  # (batch_size, sequence_length, num_features)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)

Notes

  • Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.

  • Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.

  • Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
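Equation (56) and the default reshaping behaviour can be sketched with NumPy fancy indexing. `split_static_dynamic_sketch` is hypothetical. It covers only the default shapes, (batch, n_static, 1) and (batch, seq_len, n_dynamic, 1), not the custom reshape arguments. It also assumes a non-negative `static_time_step`.

```python
import numpy as np


def split_static_dynamic_sketch(sequences, static_indices, dynamic_indices,
                                static_time_step=0):
    """Hypothetical NumPy sketch of the default (reshaping) behaviour."""
    seq = np.asarray(sequences)
    batch, seq_len, _ = seq.shape
    if not 0 <= static_time_step < seq_len:
        raise ValueError("static_time_step is out of range.")
    # S = X[:, t, static_indices] -> (batch, n_static) -> (batch, n_static, 1)
    static = seq[:, static_time_step, static_indices].reshape(
        batch, len(static_indices), 1)
    # D = X[:, :, dynamic_indices] -> (batch, seq_len, n_dynamic, 1)
    dynamic = seq[:, :, dynamic_indices].reshape(
        batch, seq_len, len(dynamic_indices), 1)
    return static, dynamic


seq = np.random.rand(100, 10, 5)
s, d = split_static_dynamic_sketch(seq, [0, 1], [2, 3, 4])
print(s.shape, d.shape)  # (100, 2, 1) (100, 10, 3, 1)
```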

See also

geoprior.models.utils.create_sequences

Function to create input sequences and targets for time series forecasting.

geoprior.models.utils._utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)

Create input sequences and corresponding targets for time series forecasting.

The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.

See the User Guide for more details.

Parameters:
  • df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.

  • sequence_length (int) – The number of past time steps to include in each input sequence.

  • target_col (str) – The name of the target column.

  • step (int, default 1) – The step size between the starts of consecutive sequences.

  • include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.

  • drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.

  • forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.

  • verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).

Returns:

A tuple containing:
  • sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).

  • targets:
    • If forecast_horizon is None: Array of target values with shape (num_sequences,).

    • If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If the DataFrame df does not contain the target_col.

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences
>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })
>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)
>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)

Notes

  • Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.

  • Forecast Horizon:
    • If forecast_horizon is None, the function creates targets for a single future time step.

    • If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.

  • Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.

  • Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.

  • Data Validation: The function uses are_all_frames_valid from geoprior.core.checks to check the integrity of the input DataFrame before processing, and exist_features to verify that the target column is present.

Sequence generation can be expressed as follows. For each sequence \(i\):

\[\begin{aligned} \mathbf{X}^{(i)} &= \left[\mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1}\right] \\ y^{(i)} &= \begin{cases} \mathbf{x}_{i+T} & \text{if forecast\_horizon} = \text{None} \\ \left[\mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1}\right] & \text{if forecast\_horizon} = H \end{cases} \end{aligned} \tag{57}\]

Where:
  • \(\mathbf{X}^{(i)}\) is the input sequence of length \(T\).

  • \(y^{(i)}\) is the target value(s) following the sequence.
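The windowing formula above can be sketched with NumPy. `sliding_windows_sketch` is hypothetical. It takes targets from a separate target array and keeps only windows whose targets fit inside the data, which gives n − T − h + 1 windows, with h = 1 for single-step targets and h = H otherwise. For 100 rows and T = 4, this rule yields 96 single-step windows and 94 windows for H = 3. The library example above reports 95 and 92, so the library's exact boundary handling may differ from this sketch.

```python
import numpy as np


def sliding_windows_sketch(values, target, T, H=None, step=1):
    """Hypothetical sketch of the windowing formula (57).

    values: (n, n_features) array; target: (n,) array.
    """
    values = np.asarray(values)
    target = np.asarray(target)
    h = 1 if H is None else H
    starts = list(range(0, len(values) - T - h + 1, step))
    if not starts:
        raise ValueError("Not enough rows for one window plus its target.")
    X = np.stack([values[i:i + T] for i in starts])  # X^(i) = x_i..x_{i+T-1}
    if H is None:
        y = np.array([target[i + T] for i in starts])  # single next step
    else:
        y = np.stack([target[i + T:i + T + H] for i in starts])  # next H steps
    return X, y


values = np.arange(20).reshape(10, 2)
target = np.arange(10)
X, y = sliding_windows_sketch(values, target, T=4)
print(X.shape, y.shape)  # (6, 4, 2) (6,)
```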

See also

geoprior.models.utils.split_static_dynamic

Function to split sequences into static and dynamic inputs.

geoprior.models.utils._utils.compute_forecast_horizon(data=None, dt_col=None, start_pred=None, end_pred=None, error='raise', verbose=1)

Compute the forecast horizon for time series forecasting models.

This function calculates the number of future time steps (forecast_horizon) a model should predict, either from the provided data or from the specified prediction dates. When data is given, it attempts to infer the data frequency and computes the horizon at that frequency. Prediction bounds can be given as date strings, datetime objects, or integer years.

Parameters:
  • data (pandas.DataFrame, pandas.Series, list, or numpy.ndarray, optional) – The dataset containing datetime information. If a pandas.DataFrame is provided, the dt_col parameter must be specified to indicate which column contains the datetime data. For pandas.Series, list, or numpy.ndarray, the function attempts to infer the frequency directly.

  • dt_col (str, optional) – The name of the column in data that contains datetime information. This parameter is required if data is a pandas.DataFrame. Example: dt_col='timestamp'

  • start_pred (str, int, or datetime-like) – The starting point for forecasting. This can be a date string (e.g., ‘2023-04-10’), a datetime object, or an integer representing a year (e.g., 2024). If an integer is provided, it is interpreted as a year, and a warning is issued to inform the user of this interpretation.

  • end_pred (str, int, or datetime-like) – The ending point for forecasting. Similar to start_pred, this can be a date string, a datetime object, or an integer representing a year. The function calculates the forecast horizon based on the difference between start_pred and end_pred.

  • error ({'raise', 'warn', 'ignore'}, default 'raise') –

    Defines the error handling behavior when encountering issues such as invalid input types, missing date columns, or unparseable dates.

    • ‘raise’: Raises a ValueError when an error is encountered.

    • ‘warn’: Emits a warning and attempts to proceed with default behavior.

  • verbose (int, default 1) –

    Controls the level of verbosity for debug information.

    • 0: No output.

    • 1: Minimal output (e.g., starting message).

    • 2: Intermediate output (e.g., detected dates, computed horizons).

    • 3: Detailed output (e.g., types of predictions, inferred frequencies).

Returns:

The computed forecast_horizon representing the number of steps ahead the model should predict. Returns None if an error occurs and error is set to ‘warn’.

Return type:

int or None

Raises:

ValueError – If invalid parameters are provided and error is set to ‘raise’.

Examples

>>> from geoprior.models.utils import compute_forecast_horizon
>>> import pandas as pd
>>> import numpy as np
>>> from datetime import datetime, timedelta
>>>
>>> # Example 1: Using a DataFrame with a Date Column
>>> df = pd.DataFrame({
...     'date': pd.date_range(start='2023-01-01', periods=100, freq='D'),
...     'value': np.random.randn(100)
... })
>>> horizon = compute_forecast_horizon(
...     data=df,
...     dt_col='date',
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=3
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 2: Using a List of Datetimes
>>> dates = [datetime(2023, 1, 1) + timedelta(days=i) for i in range(100)]
>>> horizon = compute_forecast_horizon(
...     data=dates,
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='warn',
...     verbose=2
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 3: Handling Integer Years
>>> horizon = compute_forecast_horizon(
...     start_pred=2024,
...     end_pred=2030,
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 7
>>> # Example 4: Without Providing Data (Assuming Frequency Based on Prediction Dates)
>>> horizon = compute_forecast_horizon(
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11

Notes

  • When data is not provided, the function relies solely on the difference between start_pred and end_pred to compute the forecast horizon. In such cases, if the frequency cannot be inferred, the horizon is calculated based on the largest possible time unit (years, months, weeks, days).

  • If start_pred is after end_pred, the function returns 0 and issues a warning or raises an error based on the error parameter.

  • The function attempts to infer the frequency of the data using pandas utilities. If the frequency cannot be inferred, it defaults to calculating the horizon based on the time difference in the most significant unit.
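The examples above count steps inclusively: 2023-04-10 to 2023-04-20 gives 11 daily steps, and 2024 to 2030 gives 7 yearly steps. That counting can be reproduced with plain pandas. The sketch below is a simplified illustration, not the library's implementation. `inclusive_horizon` is a hypothetical helper. It treats integer inputs as years and falls back to a daily frequency when no frequency can be inferred from data.

```python
import pandas as pd


def inclusive_horizon(start_pred, end_pred, data=None, default_freq="D"):
    """Count forecast steps from start_pred to end_pred, both ends included.

    Hypothetical helper: integer inputs are treated as years, and
    default_freq is used when no frequency can be inferred from data.
    """
    if isinstance(start_pred, int) and isinstance(end_pred, int):
        # Integer years: 2024..2030 inclusive gives 7 steps.
        return max(end_pred - start_pred + 1, 0)

    freq = default_freq
    if data is not None:
        # pd.infer_freq needs at least three timestamps.
        inferred = pd.infer_freq(pd.DatetimeIndex(data))
        if inferred is not None:
            freq = inferred

    # date_range includes both endpoints and is empty when start > end.
    return len(pd.date_range(start=start_pred, end=end_pred, freq=freq))


dates = pd.date_range("2023-01-01", periods=100, freq="D")
print(inclusive_horizon("2023-04-10", "2023-04-20", data=dates))  # 11
print(inclusive_horizon(2024, 2030))  # 7
```

Here a reversed range produces 0 steps, which matches the note that start_pred after end_pred returns 0. The warning or error raised by the real function is not modeled.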

See also

pandas.date_range

Generates a fixed frequency DatetimeIndex.

pandas.infer_freq

Infers the frequency of a DatetimeIndex.

geoprior.models.utils._utils.prepare_spatial_future_data(final_processed_data, feature_columns, dynamic_feature_indices, sequence_length=1, dt_col='date', static_feature_names=None, forecast_horizon=None, future_years=None, encoded_cat_columns=None, scaling_params=None, spatial_cols=None, squeeze_last=False, verbosity=0)[source]

Prepare future static and dynamic inputs for making predictions.

This function prepares the necessary static and dynamic inputs required for forecasting future values in time series data. It processes the provided dataset by grouping it by location_id, extracting the last sequence of data points based on the specified sequence_length, and generating future inputs for prediction over the defined forecast_horizon.

The function handles both integer and datetime representations of the dt_col, extracting the year from datetime columns when necessary. It also allows for flexibility in specifying static features and encoded categorical variables.

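The time values are standardized with the mean \(\mu\) and standard deviation \(\sigma\) of dt_col, taken from scaling_params when provided:

\[\text{scaled\_time} = \frac{\text{future\_time} - \mu}{\sigma}\]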
Parameters:
  • final_processed_data (pandas.DataFrame) – The processed DataFrame containing all features and targets. Must include the location_id column and the specified dt_col.

  • feature_columns (List[str]) – List of feature column names to be used for dynamic input preparation.

  • dynamic_feature_indices (List[int]) – Indices of dynamic features in feature_columns. These features are considered time-dependent and are used to prepare dynamic inputs.

  • sequence_length (int, optional) – The number of past time steps to include in each input sequence. Default is 1.

  • dt_col (str, optional) – The name of the time-related column in final_processed_data. Defaults to 'date'.

  • static_feature_names (List[str], optional) – List of static feature column names. If not provided, defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.

  • forecast_horizon (int, optional) – The number of future time steps to predict. If set to None, the function defaults to predicting the next immediate time step.

  • future_years (List[int], optional) – List of future years to predict. Its length must equal forecast_horizon when forecast_horizon is provided.

  • encoded_cat_columns (List[str], optional) – List of encoded categorical column names to be treated as static features.

  • scaling_params (Dict[str, Dict[str, float]], optional) – Dictionary containing scaling parameters (mean and standard deviation) for features. Example: {'year': {'mean': 2000, 'std': 10}}. If not provided, the function computes the mean and std for the dt_col.

  • squeeze_last (bool, default False) – If True, squeeze the last axis, which corresponds to the output dimension, when its size is 1.

  • verbosity (int, optional) – Verbosity level from 0 to 7 for debugging and understanding the process. Higher values produce more detailed logs.

  • spatial_cols (tuple[str, str])

Returns:

A tuple containing:

  • future_static_inputs : numpy.ndarray

    Array of future static inputs with shape (num_samples, num_static_vars, 1).

  • future_dynamic_inputs : numpy.ndarray

    Array of future dynamic inputs with shape (num_samples, sequence_length, num_dynamic_vars, 1).

  • future_years_list : List[int]

    List of future time values corresponding to each sample.

  • location_ids_list : List[int]

    List of location IDs corresponding to each sample.

  • longitudes : List[float]

    List of longitude values corresponding to each sample.

  • latitudes : List[float]

    List of latitude values corresponding to each sample.

Return type:

Tuple[np.ndarray, np.ndarray, List[int], List[int], List[float], List[float]]

Examples

>>> from geoprior.models.utils import prepare_spatial_future_data
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'location_id': [1, 1, 1, 2, 2, 2],
...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
...     'longitude': [10.0, 10.0, 10.0, 20.0, 20.0, 20.0],
...     'latitude': [50.0, 50.0, 50.0, 60.0, 60.0, 60.0],
...     'temperature': [15, 16, 15.5, 20, 21, 20.5],
...     'rainfall': [100, 110, 105, 200, 210, 205],
...     'encoded_cat': [1, 1, 1, 2, 2, 2]
... })
>>> feature_cols = ['year', 'temperature', 'rainfall', 'encoded_cat']
>>> dynamic_indices = [0, 1, 2]
>>> (future_static, future_dynamic, future_years, loc_ids, longs,
...  lats) = prepare_spatial_future_data(
...     final_processed_data=data,
...     feature_columns=feature_cols,
...     dynamic_feature_indices=dynamic_indices,
...     sequence_length=2,
...     forecast_horizon=1,
...     future_years=[2021],
...     encoded_cat_columns=['encoded_cat'],
...     verbosity=5,
...     dt_col='year'
... )
>>> print(future_static.shape)
(2, 3, 1)
>>> print(future_dynamic.shape)
(2, 2, 3, 1)

Notes

  • The function handles both integer and datetime representations of the dt_col. If dt_col is a datetime type, the year is extracted for scaling purposes.

  • If forecast_horizon is set to None, the function defaults to generating data for the next immediate time step based on the last entry in the time column.

  • Ensure that the length of future_years matches forecast_horizon if forecast_horizon is provided.

  • The static_feature_names parameter allows for flexibility in specifying which static features to include. If not provided, it defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
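The core of the dynamic-input preparation can be sketched with pandas and NumPy. It takes the last sequence_length rows of each location and standardizes the future time values with the formula given above. This is an illustration, not the library's code. last_windows and scale_time are hypothetical helpers, and each location is assumed to have at least sequence_length rows. With the example data and dynamic columns year, temperature, and rainfall, the stacked array has the documented (num_samples, sequence_length, num_dynamic_vars, 1) layout.

```python
import numpy as np
import pandas as pd


def last_windows(df, group_col, dt_col, cols, sequence_length):
    """Stack the last `sequence_length` rows of `cols` for each group.

    Returns an array of shape (n_groups, sequence_length, len(cols), 1).
    Assumes every group has at least `sequence_length` rows.
    """
    windows = []
    # Sort by time first; groupby keeps that row order inside each group.
    for _, grp in df.sort_values(dt_col).groupby(group_col, sort=True):
        windows.append(grp[cols].tail(sequence_length).to_numpy(dtype=float))
    return np.stack(windows)[..., np.newaxis]


def scale_time(future_time, mu, sigma):
    """Standardize future time values: (t - mu) / sigma."""
    return (np.asarray(future_time, dtype=float) - mu) / sigma


data = pd.DataFrame({
    "location_id": [1, 1, 1, 2, 2, 2],
    "year": [2018, 2019, 2020, 2018, 2019, 2020],
    "temperature": [15, 16, 15.5, 20, 21, 20.5],
    "rainfall": [100, 110, 105, 200, 210, 205],
})
dyn = last_windows(data, "location_id", "year",
                   ["year", "temperature", "rainfall"], sequence_length=2)
print(dyn.shape)  # (2, 2, 3, 1)

years = data["year"].astype(float)
print(scale_time([2021], years.mean(), years.std()))
```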

See also

prepare_future_data

Main function for preparing future data inputs.

geoprior.models.utils._utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]

Compute anomaly scores for given true targets using various methods.

This utility computes anomaly scores outside the XTFT model itself. An anomaly score indicates how unusual an observation is. Detecting and quantifying anomalies helps practitioners adjust forecasting strategies, improve predictive performance, and handle irregular patterns. The scores can also guide a model towards more robust and stable forecasts.

Parameters:
  • y_true (np.ndarray) –

    The ground truth target values with shape (B, H, O). B is the batch size, H is the number of forecast horizons (time steps ahead), and O is the output dimension (e.g., the number of target variables).

    Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.

  • y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to ‘residual’, the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.

  • method (str, optional) –

    The method used to compute anomaly scores. Supported options:

    • "statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores.

      Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:

      \[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\]

      where \(\varepsilon\) is a small constant for numerical stability.

    • "domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.

    • "isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. Random partitioning isolates anomalous points in fewer splits than normal points. Higher contamination rates allow more points to be considered anomalous.

    • "residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:

      \[\text{MSE: } (y_{\text{true}} - y_{\text{pred}})^2\]

    Default is "statistical".

  • threshold (float, optional) – Threshold factor for the statistical method. Points outside mean ± threshold * std are considered anomalous. The threshold is not applied as a mask; it only guides the interpretation of the scores. Default is 3.0.

  • domain_func (callable, optional) –

    A user-defined function for the domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the following default heuristic is used:

    \[\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\]

  • contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.

  • epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.

  • estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.

  • random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.

  • residual_metric (str, optional) –

    The metric used to compute anomalies from residuals if method is set to ‘residual’. Supported metrics:

    • "mse": mean squared error per point (residuals**2)

    • "mae": mean absolute error per point |residuals|

    • "rmse": root mean squared error sqrt((residuals**2)); pointwise, this equals |residuals|

    Default is "mse".

  • objective (str, optional) – Type of objective, reserved for future extensibility. Default is "ts" (time series).

  • verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.

Returns:

anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.

Return type:

np.ndarray

Notes

Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
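The statistical, default domain, and residual scores described in the parameters can be written directly in NumPy. This sketch illustrates the formulas and is not compute_anomaly_scores itself. It assumes the statistical method uses the global mean and population standard deviation of y_true. The helper names are hypothetical. Because the statistical score is a squared standardized deviation, a point at mean ± threshold * std scores about threshold**2, which is 9 for the default threshold.

```python
import numpy as np


def statistical_scores(y_true, epsilon=1e-6):
    """Squared standardized deviation ((y - mu) / (sigma + epsilon))**2."""
    y = np.asarray(y_true, dtype=float)
    # Assumption: global mean and population std over all elements.
    mu, sigma = y.mean(), y.std()
    return ((y - mu) / (sigma + epsilon)) ** 2


def default_domain_scores(y_true):
    """Default heuristic: |y| * 10 for negative values, 0 otherwise."""
    y = np.asarray(y_true, dtype=float)
    return np.where(y < 0, np.abs(y) * 10, 0.0)


def residual_scores(y_true, y_pred, metric="mse"):
    """Pointwise scores from the residuals y_true - y_pred."""
    r = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    if metric == "mse":
        return r ** 2
    if metric == "mae":
        return np.abs(r)
    if metric == "rmse":
        return np.sqrt(r ** 2)  # pointwise, this equals |r|
    raise ValueError(f"Unknown residual metric: {metric!r}")


rng = np.random.default_rng(0)
y_true = rng.standard_normal((32, 20, 1))  # (B, H, O)
scores = statistical_scores(y_true)
# Interpreting the default threshold: flag points beyond about 3 std.
flagged = scores > 3.0 ** 2
print(scores.shape, int(flagged.sum()))
```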

Examples

>>> from geoprior.models.utils import compute_anomaly_scores
>>> import numpy as np
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B,H,O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain',
...                                 domain_func=my_heuristic)
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest',
...                                 contamination=0.1)
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual',
...                                 residual_metric='mae')

See also

geoprior.models.losses.objective_loss

For integrating anomaly scores into a multi-objective loss.

geoprior.models.utils._utils.generate_forecast(xtft_model, train_data, dt_col, dynamic_features, future_features=None, static_features=None, test_data=None, mode='quantile', spatial_cols=None, forecast_horizon=4, time_steps=3, q=None, tname=None, forecast_dt=None, savefile=None, verbose=3, **kw)[source]

Generate forecast using the XTFT model.

This function uses a pre-trained Keras model to forecast future values from historical data. The model receives three inputs, X_static, X_dynamic, and X_future. These are rebuilt from train_data, and the model outputs predictions over the specified forecast horizon.

See more in User Guide.

Parameters:
  • xtft_model (object) – A validated Keras model instance. It is checked with the validate_keras_model function.

  • train_data (pandas.DataFrame) – The training data containing historical records. Must include the dt_col and all required feature columns.

  • dt_col (str) – Name of the column representing time. It may be a datetime or numeric column (e.g. "year").

  • dynamic_features (list of str) – List of dynamic feature column names. They are formatted via columns_manager.

  • future_features (list of str, optional) – List of future feature names. These columns are tiled over the forecast horizon.

  • static_features (list of str, optional) – List of static feature names. If not provided, a dummy input is used.

  • test_data (pandas.DataFrame, optional) – DataFrame containing actual values used for evaluation. If provided, it is used to compute the R² score and, when mode='quantile', the coverage score.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions for multiple quantiles (default: [0.1, 0.5, 0.9]) are computed.

  • spatial_cols (list of str, optional) – List of spatial column names for grouping the data. When provided, forecasts are computed per location; otherwise, a global forecast is performed.

  • forecast_horizon (int, optional) – Number of future periods to forecast. Default is 4.

  • time_steps (int, optional) – Number of past time steps to use as input for the model. Default is 3.

  • q (list of float, optional) – List of quantiles for use in quantile mode. Default is [0.1, 0.5, 0.9]. Each quantile is validated by the assert_ratio function.

  • tname (str, optional) – Target variable name used for constructing forecast result columns. Defaults to "target".

  • forecast_dt (list or str, optional) – List of forecast dates or "auto" to derive dates from dt_col. In auto mode, if dt_col is datetime, frequency is inferred using pd.infer_freq.

  • savefile (str, optional) – Path to the CSV file where forecast results will be saved. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level (0-7). Controls the amount of execution output.

Returns:

A DataFrame containing the forecast results. In quantile mode, each forecast period includes columns for each quantile; in point mode, a single prediction column is provided.

Return type:

pandas.DataFrame

Examples

  1. Example using training data only

>>> import os
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.losses import combined_quantile_loss
>>> from geoprior.models.utils import generate_forecast
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # dynamic features "feat1", "feat2", static feature "stat1",
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> # Note: The model expects the following input shapes:
>>> # - X_static: (n_samples, static_input_dim)
>>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future:  (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # one future feature
...     forecast_horizon=5,           # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Create dummy input arrays for model fitting.
>>> # For simplicity, assume time_steps = 3 and use random data.
>>> X_static = train_df[["stat1"]].values      # shape: (50, 1)
>>> # Create a dummy dynamic input array of shape (50, 3, 2)
>>> X_dynamic = np.random.rand(50, 3, 2)
>>> # Create a dummy future-feature array of shape (50, 3, 1)
>>> X_future = np.random.rand(50, 3, 1)
>>> # Create dummy target output from "price"
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())

  2. Example with test data included

>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df  = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values      # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future  = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array   = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # For the provided future feature
...     forecast_horizon=5,           # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> # This example uses test_df for evaluation, which will compute
>>> # metrics like R² Score and Coverage Score.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     test_data=test_df.iloc[:5, :], # first 5 rows, one per forecast step
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())

  3. Example of point forecasting

>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Create dummy input arrays for model fitting.
>>> # X_static is derived from the static feature "stat1".
>>> X_static = train_df[["stat1"]].values      # shape: (50, 1)
>>>
>>> # X_dynamic is a dummy dynamic array for "feat1" and "feat2".
>>> # For time_steps = 3, its shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # X_future is a dummy array for future features.
>>> # Here, we assume a single future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,           # "stat1"
...     dynamic_input_dim=2,          # "feat1" and "feat2"
...     future_input_dim=1,           # Provided future feature
...     forecast_horizon=5,           # Forecast 5 periods ahead
...     quantiles=None,    # [0.1, 0.5, 0.9] Not used in point mode
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function in point mode.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="point",
...     verbose=3
... )
>>> print(forecast.head())

Notes

The function groups data by spatial_cols if provided, and formats features via columns_manager. It validates the time column using check_datetime and uses dummy inputs for missing static or future features. The forecast is produced by invoking xtft_model.predict on a list containing static, dynamic, and future inputs. The predictions are generated as follows:

\[\hat{y}_{t+i} = f\Bigl(X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}}\Bigr)\]

where \(i\) denotes the forecast period.
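When forecast_dt="auto" and dt_col holds datetimes, the forecast dates can be derived in two steps. First, infer the frequency with pd.infer_freq. Second, step forward from the last observed date. The sketch below assumes that stepping rule and is not the library's code. auto_forecast_dates is a hypothetical helper, and it raises an error when no frequency can be inferred. The library's actual fallback in that case is not documented here.

```python
import pandas as pd


def auto_forecast_dates(dates, forecast_horizon):
    """Hypothetical helper: extend observed dates by forecast_horizon steps."""
    idx = pd.DatetimeIndex(pd.to_datetime(dates)).sort_values()
    freq = pd.infer_freq(idx)  # needs at least three timestamps
    if freq is None:
        raise ValueError("Could not infer a frequency from the time column.")
    # Start at the last observed date, then drop it to keep only future dates.
    return pd.date_range(start=idx[-1], periods=forecast_horizon + 1, freq=freq)[1:]


train_dates = pd.date_range("2020-01-01", periods=50, freq="D")
print(auto_forecast_dates(train_dates, forecast_horizon=5))
```

With the 50 daily training dates from the examples, the last observed date is 2020-02-19. The sketch therefore yields 2020-02-20 through 2020-02-24.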

See also

geoprior.models.utils.reshape_xtft_data

Function to reshape data for XTFT models.

geoprior.utils.validator.validate_keras_model

Function to validate Keras model compatibility.

geoprior.core.handlers.columns_manager

Utility to manage and format column names.

geoprior.core.checks.check_datetime

Function to check and validate datetime columns.

geoprior.core.checks.check_spatial_columns

Function to validate spatial columns in data.

geoprior.core.checks.assert_ratio

Function to validate and assert ratio values.

geoprior.metrics_special.coverage_score

Function to compute coverage score for quantile predictions.

geoprior.models.utils._utils.generate_forecast_with(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kw)[source]

Generate forecasts using a pre-trained XTFT model based on the forecast horizon.

There are two approaches to generating forecasts with an XTFT model:

  1. A monolithic function (e.g. generate_forecast) that handles both single-step and multi-step forecasts within a single implementation. This approach results in a single, large function that internally branches its logic based on the value of the forecast horizon.

  2. A modular design where the single-step and multi-step forecasting functionalities are separated into two distinct functions (forecast_single_step and forecast_multi_step), with a thin wrapper (generate_forecast_with) that dispatches to the appropriate function based on the forecast horizon.

The modular approach (2) is generally preferred because it separates concerns and improves code readability, maintainability, and unit testing. Each function handles one well-defined task: one for single-step forecasts and one for multi-step forecasts. The wrapper, generate_forecast_with, only selects the correct function based on the forecast horizon. Use this approach when your application may need to handle both short- and long-term forecasts, as it keeps the codebase modular and easier to debug.

generate_forecast_with calls forecast_single_step when forecast_horizon equals 1 and forecast_multi_step when forecast_horizon is greater than 1.
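The dispatch rule for the modular design fits in a few lines. This illustrates the pattern and is not the source of generate_forecast_with. dispatch_forecast, fake_single, and fake_multi are hypothetical stand-ins for the real forecasters. The check that forecast_horizon is a positive integer is also an assumption.

```python
from typing import Any, Callable


def dispatch_forecast(
    forecast_horizon: int,
    single_step: Callable[..., Any],
    multi_step: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """Route to a single-step or multi-step forecaster by horizon."""
    if not isinstance(forecast_horizon, int) or forecast_horizon < 1:
        raise ValueError(
            f"forecast_horizon must be a positive integer, got {forecast_horizon!r}"
        )
    if forecast_horizon == 1:
        return single_step(forecast_horizon=1, **kwargs)
    return multi_step(forecast_horizon=forecast_horizon, **kwargs)


# Hypothetical stand-ins for forecast_single_step / forecast_multi_step.
def fake_single(**kw):
    return ("single", kw["forecast_horizon"])


def fake_multi(**kw):
    return ("multi", kw["forecast_horizon"])


print(dispatch_forecast(1, fake_single, fake_multi))  # ('single', 1)
print(dispatch_forecast(4, fake_single, fake_multi))  # ('multi', 4)
```

Each branch stays small enough to test on its own, which is the main argument for the modular design.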

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.

  • forecast_horizon (int) – The number of future time steps to forecast. A value of 1 triggers a single-step forecast; values greater than 1 trigger a multi-step forecast.

  • y (numpy.ndarray, optional) – Actual target values for evaluation. If provided, evaluation metrics (e.g., R² Score, and in quantile mode, the coverage score) are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, the output DataFrame includes a column with these values.

  • mode (str, optional) – Forecast mode, either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.

  • time_steps (int, optional) – The number of historical time steps used as input.

  • q (list of float, optional) – List of quantile values for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name used to construct output column names (e.g., "subsidence"). Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.

  • apply_mask (bool, optional) – If True, applies masking (via mask_by_reference) to adjust predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output.

  • **kw (dict, optional) – Currently unused; reserved for future extension.

Returns:

A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile and forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.

Return type:

pandas.DataFrame
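The column naming described under Returns can be reproduced from a raw prediction array. Quantile mode yields <tname>_q10_step1 and similar columns, and point mode yields <tname>_pred_step1. This is an illustration, not the library's code. It assumes predictions of shape (n, horizon, n_quantiles) in quantile mode and (n, horizon) or (n, horizon, 1) in point mode. The column order (all quantiles for step 1, then step 2, and so on) is also an assumption. predictions_to_wide is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def predictions_to_wide(preds, tname="target", quantiles=None):
    """Lay out forecasts as one column per (quantile, step) or per step.

    Assumed shapes: (n, horizon, n_quantiles) in quantile mode and
    (n, horizon) or (n, horizon, 1) in point mode.
    """
    preds = np.asarray(preds, dtype=float)
    n, horizon = preds.shape[:2]
    columns = {}
    if quantiles is None:
        flat = preds.reshape(n, horizon)
        for h in range(horizon):
            columns[f"{tname}_pred_step{h + 1}"] = flat[:, h]
    else:
        for h in range(horizon):
            for j, q in enumerate(quantiles):
                label = f"{tname}_q{int(round(q * 100))}_step{h + 1}"
                columns[label] = preds[:, h, j]
    return pd.DataFrame(columns)


# 2 samples, 2 forecast steps, 3 quantiles.
preds = np.arange(12, dtype=float).reshape(2, 2, 3)
wide = predictions_to_wide(preds, tname="subsidence", quantiles=[0.1, 0.5, 0.9])
print(list(wide.columns))
```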

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import generate_forecast_with
>>> import numpy as np
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> my_model = XTFT(
...     static_input_dim=10,
...     dynamic_input_dim=5,
...     future_input_dim=3,
...     forecast_horizon=1,          # This parameter will be updated in the
...                                  # wrapper function based on forecast_horizon.
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=32,
...     max_window_size=3,
...     memory_size=100,
...     num_heads=4,
...     dropout_rate=0.1,
...     lstm_units=64,
...     attention_units=64,
...     hidden_units=32
... )
>>> my_model.compile(optimizer='adam')
>>>
>>> # Create dummy input data.
>>> X_static = np.random.rand(100, 10)
>>> X_dynamic = np.random.rand(100, 3, 5)
>>> X_future  = np.random.rand(100, 3, 3)
>>> y_array   = np.random.rand(100, 1, 1)  # For single-step target output.
>>> inputs    = [X_static, X_dynamic, X_future]
>>>
>>> # Fit the model with dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=32
... )
>>>
>>> # Example for a single-step forecast:
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=1,
...     y=y_array,
...     dt_col="year",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     verbose=3
... )
>>> print(forecast_df.head())
>>>
>>> # Example for a multi-step forecast:
>>> forecast_dates = ["2023", "2024", "2025", "2026"]
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=4,
...     y=y_array,
...     dt_col="year",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     verbose=3
... )
>>> print(forecast_df.head())

See also

forecast_single_step

Generates a single-step forecast.

forecast_multi_step

Generates a multi-step forecast.

validate_keras_model

Validates a Keras model.

coverage_score

Computes the coverage score.

geoprior.models.utils._utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]

Generate a multi-step forecast using the XTFT model.

This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:

\[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\]

for \(i = 1, \dots, H\), where \(H\) is forecast_horizon and \(f\) is the trained XTFT model.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.

  • forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.

  • time_steps (int, optional) – The number of historical time steps used as input. Default is 3.

  • q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.

  • apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.

Returns:

A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60,
...                          freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # X_static is constructed using "longitude" and "stat1".
>>> X_static = train_df[["longitude", "stat1"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
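>>> # Multi-step forecast with forecast_horizon = 4. The dummy targets
>>> # repeat the "subsidence" values across the horizon so that they
>>> # have one value per forecast step, i.e. shape (60, 4, 1).
>>> forecast_horizon = 4
>>> y_array = np.repeat(
...     train_df["subsidence"].values.reshape(60, 1, 1),
...     forecast_horizon, axis=1
... )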
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,    # "longitude" and "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())
>>>
For a point forecast:

>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,    # "longitude" and "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None, # set quantiles to None
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam", loss="mse")
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())

Notes

  • In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.

  • In point mode, a single prediction is generated per forecast step.

  • The output prediction array is expected to have the shape \((n, H, m)\), where \(n\) is the number of samples, \(H\) is forecast_horizon, and \(m\) is the number of outputs per step (e.g., the number of quantiles in quantile mode, or 1 in point mode).

  • The provided spatial_cols must correspond to the first two columns of the original training data’s X_static.

  • Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.

  • The DataFrame is constructed by iterating over each sample and each forecast step.
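The notes above describe predictions of shape \((n, H, m)\) being laid out with one row per sample. The following is a minimal NumPy/pandas sketch of that wide layout, not the library's implementation: wide_forecast_frame is a hypothetical helper, and the step-major column order (all quantiles for step 1, then step 2, and so on) is an assumption based on the column names shown in the step_to_long example.

```python
import numpy as np
import pandas as pd


def wide_forecast_frame(preds, tname="target", quantiles=None):
    """Hypothetical helper: lay out (n, H, m) predictions in wide format.

    Quantile mode uses columns like <tname>_q10_step1; point mode
    (quantiles=None, m == 1) uses <tname>_pred_step1.
    """
    preds = np.asarray(preds, dtype=float)
    n, horizon, m = preds.shape
    if quantiles:
        if m != len(quantiles):
            raise ValueError("last axis must hold one value per quantile")
        # round() avoids float artefacts such as 0.29 * 100 = 28.999...
        labels = [f"q{round(q * 100)}" for q in quantiles]
    else:
        if m != 1:
            raise ValueError("point mode expects a single output per step")
        labels = ["pred"]
    # Step-major order: every label for step 1, then step 2, ...
    columns = [
        f"{tname}_{label}_step{step}"
        for step in range(1, horizon + 1)
        for label in labels
    ]
    # A C-order reshape of (n, H, m) to (n, H * m) matches that ordering.
    return pd.DataFrame(preds.reshape(n, horizon * m), columns=columns)


preds = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
wide = wide_forecast_frame(preds, tname="subsidence",
                           quantiles=[0.1, 0.5, 0.9])
print(wide.columns[:4].tolist())
```

Because a C-order reshape keeps each sample's steps contiguous, a single reshape to \((n, H \cdot m)\) lines up with step-major column names without any explicit loop over samples.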

See also

forecast_single_step

Function for single-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to verify quantile ratios.

geoprior.models.utils._utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]

Generate a single-step forecast using the XTFT model.

This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:

\[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{64}\]

where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data’s X_static.

  • q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.

  • apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.

Returns:

A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # and dummy spatial features ("longitude", "latitude"), as well as the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # For the static input, place the spatial features "longitude" and
>>> # "latitude" first, followed by the static feature "stat1". The
>>> # forecast_single_step function expects that, if spatial_cols is
>>> # provided, the first two columns of X_static are the spatial coordinates.
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values  # shape: (50, 3)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1)
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> #   - X_static with shape (n_samples, static_input_dim)
>>> #   - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=3,         # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,        # "feat1" and "feat2"
...     future_input_dim=1,         # One future feature
...     forecast_horizon=1,         # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
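>>> from geoprior.models.losses import combined_quantile_loss
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )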
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",                # The time column name in the output
...     mode="quantile",              # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())

Notes

  • In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.

  • If spatial_cols is provided, it must correspond to the first and second columns of the original training data’s X_static.

  • The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.

  • Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.

  • The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
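To make the evaluation notes concrete, here is a small NumPy sketch of the two quantities under stated assumptions: R² uses the standard \(1 - SS_{res}/SS_{tot}\) definition applied to the 0.5 quantile, and coverage is the fraction of actual values inside the interval bounded by the first and last quantiles (the convention format_predictions documents for evaluate_coverage). evaluate_single_step is hypothetical; it is not geoprior.metrics.coverage_score, and it assumes the quantiles are sorted in ascending order and include 0.5.

```python
import numpy as np


def evaluate_single_step(preds, y, quantiles):
    """Hypothetical helper: R² of the median and interval coverage.

    preds has shape (n, 1, len(quantiles)); y has shape (n, 1, 1).
    """
    preds = np.asarray(preds, dtype=float)
    actual = np.asarray(y, dtype=float)[:, 0, 0]
    q = list(quantiles)
    median = preds[:, 0, q.index(0.5)]  # the 0.5 quantile drives R²
    lower = preds[:, 0, 0]              # first (lowest) quantile
    upper = preds[:, 0, -1]             # last (highest) quantile

    ss_res = np.sum((actual - median) ** 2)
    ss_tot = np.sum((actual - actual.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot

    # Fraction of actual values inside [lowest, highest] quantile.
    coverage = np.mean((actual >= lower) & (actual <= upper))
    return float(r2), float(coverage)


preds = np.array([[[0, 1, 2]], [[1, 2, 3]], [[2, 3, 4]], [[3, 4, 5]]])
y = np.array([1.0, 2.0, 3.5, 6.0]).reshape(4, 1, 1)
r2, cov = evaluate_single_step(preds, y, [0.1, 0.5, 0.9])
print(round(r2, 4), cov)  # 0.7004 0.75
```

Here three of the four actual values fall inside their [q10, q90] band, so the coverage is 0.75; the last value (6.0) lies above its upper bound of 5.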

See also

forecast_multi_step

Function for multi-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to validate quantile ratios.

geoprior.models.utils._utils.step_to_long(df, tname=None, dt_col=None, spatial_cols=None, mode='quantile', quantiles=None, verbose=3, sort=True)[source]

Convert a multi-step forecast DataFrame from wide to long format.

This function transforms a DataFrame containing multi-step forecast predictions into a long-format DataFrame. In quantile mode, forecast columns such as subsidence_q10_step1, subsidence_q50_step1, etc. are consolidated into unified columns (e.g. subsidence_q10, subsidence_q50, etc.), while in point mode, a single prediction column (subsidence_pred) is generated. The transformation also carries over additional columns (e.g. spatial coordinates and time) from the original DataFrame.

Parameters:
  • df (pandas.DataFrame) – The multi-step forecast DataFrame. Expected to contain forecast prediction columns (e.g. columns with _q or _pred_step in their names) along with other identifiers.

  • tname (str, optional) – The base name of the target variable (e.g. "subsidence"). If None, the function attempts to auto-detect the target name from the column names.

  • dt_col (str, optional) – The name of the time column to include in the final DataFrame. If not provided, time sorting is not performed.

  • spatial_cols (list of str, optional) – A list of spatial coordinate columns (e.g. ["longitude", "latitude"]) to be retained in the final output.

  • mode ({"quantile", "point"}, default "quantile") – The forecast mode. In "quantile" mode, multiple quantile forecast columns are merged into unified columns. In "point" mode, a single prediction column is produced.

  • quantiles (list of float, optional) – The quantile values for quantile mode (e.g. [0.1, 0.5, 0.9]). If not provided, defaults are used.

  • sort (bool, optional) – If True, sorts the final DataFrame by the column specified in dt_col (if present). Default is True.

  • verbose (int, optional) – Verbosity level for logging output. Higher values (e.g. 5 to 7) provide more detailed debug information.

Returns:

A long-format DataFrame containing the retained spatial columns, the time column when dt_col is provided, and the merged forecast prediction columns. In quantile mode, the output contains unified columns such as subsidence_q10 and subsidence_q50. In point mode, it contains a single subsidence_pred column.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.utils import step_to_long
>>> # Given a DataFrame `forecast_df` with columns like:
>>> # ['longitude', 'latitude', 'year', 'subsidence_actual',
>>> #  'subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1',
>>> #  'subsidence_q10_step2', ...]
>>> long_df = step_to_long(
...     df=forecast_df,
...     tname="subsidence",
...     dt_col="year",
...     spatial_cols=["longitude", "latitude"],
...     mode="quantile",
...     quantiles=[0.1, 0.5, 0.9],
...     verbose=3,
...     sort=True
... )
>>> print(long_df.head())

Notes

Internally, this function calls:

  • check_forecast_mode() to validate the user-specified quantiles.

  • validate_consistency_q() and validate_quantiles() to ensure the supplied quantiles match those auto-detected from the DataFrame.

  • Depending on mode, either _step_to_long_q() for quantile mode or _step_to_long_pred() for point mode performs the conversion.

Mathematically, let \(X \in \mathbb{R}^{n \times m}\) represent the wide-format DataFrame, where each row corresponds to one sample and each forecast step is stored in separate columns. The function reshapes \(X\) into a long-format DataFrame \(Y \in \mathbb{R}^{(n \cdot s) \times p}\), where \(s\) is the forecast horizon and \(p\) is the number of output columns after merging forecast step values.
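The reshape from \(n\) rows to \(n \cdot s\) rows can be sketched in plain pandas. This is an illustration, not the internals of _step_to_long_q or _step_to_long_pred: wide_to_long_sketch is a hypothetical helper, it keeps an explicit step column for readability, and it neither attaches dt_col values nor sorts the result.

```python
import re

import pandas as pd


def wide_to_long_sketch(df, tname, id_cols):
    """Hypothetical helper: one output row per (sample, step) pair.

    Columns like <tname>_q10_step2 or <tname>_pred_step2 are merged into
    <tname>_q10 / <tname>_pred; id_cols are repeated for every step.
    """
    pattern = re.compile(rf"^{re.escape(tname)}_(q\d+|pred)_step(\d+)$")
    records = []
    for _, row in df.iterrows():
        per_step = {}
        for col in df.columns:
            match = pattern.match(col)
            if match:
                label, step = match.group(1), int(match.group(2))
                per_step.setdefault(step, {})[f"{tname}_{label}"] = row[col]
        for step in sorted(per_step):
            record = {c: row[c] for c in id_cols}
            record["step"] = step
            record.update(per_step[step])
            records.append(record)
    return pd.DataFrame(records)


wide = pd.DataFrame({
    "longitude": [100.0, 101.0],
    "subsidence_q10_step1": [0.1, 0.2],
    "subsidence_q50_step1": [0.5, 0.6],
    "subsidence_q10_step2": [1.1, 1.2],
    "subsidence_q50_step2": [1.5, 1.6],
})
long_df = wide_to_long_sketch(wide, "subsidence", ["longitude"])
print(long_df.shape)  # (4, 4)
```

With \(n = 2\) samples and \(s = 2\) steps the result has \(n \cdot s = 4\) rows, and the four step-specific quantile columns collapse into two unified ones.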

See also

_step_to_long_q

Converts multi-step quantile forecasts to long format.

_step_to_long_pred

Converts multi-step point forecasts to long format.

detect_digits

Extracts numeric values from strings for quantile detection.

geoprior.models.utils._utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)[source]

Prepares a list of input tensors for a model’s call method.

This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.

Parameters:
  • dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).

  • static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default is None.

  • future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.

  • model_type ({'strict', 'flexible'}, default 'strict') –

    Determines how None inputs for static and future features are handled:

    • 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.

    • 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.

  • forecast_horizon (int, optional) – The forecast horizon. Used only if model_type=’strict’ and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default is None.

  • verbose (int, default 0) –

    Verbosity level. If > 0, prints information about dummy tensor creation.

    • 0: Silent.

    • 1: Basic info on dummy creation.

    • 2: More details on shapes.

Returns:

A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type=’flexible’ and original input was None). All returned tensors are cast to tf.float32.

Return type:

List[Optional[tf.Tensor]]

Raises:
  • ValueError – If dynamic_input is None or is not at least 2D (a batch dimension is required), if static_input (when provided) is not 2D, or if future_input (when provided) is not 3D.

  • TypeError – If inputs cannot be converted to TensorFlow tensors.
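The strict/flexible rules above can be mirrored with NumPy alone. The sketch below is a hypothetical analogue (prepare_inputs_sketch) that returns NumPy float32 arrays instead of TensorFlow tensors; it reproduces only the documented shape logic for dummy inputs and the documented dimensionality checks.

```python
import numpy as np


def prepare_inputs_sketch(dynamic, static=None, future=None,
                          model_type="strict", forecast_horizon=None):
    """Hypothetical NumPy analogue of the strict/flexible rules."""
    if dynamic is None:
        raise ValueError("dynamic_input is required.")
    dynamic = np.asarray(dynamic, dtype=np.float32)
    if dynamic.ndim < 2:
        raise ValueError("dynamic_input needs a batch dimension.")
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]
    strict = model_type == "strict"

    if static is not None:
        static = np.asarray(static, dtype=np.float32)
        if static.ndim != 2:
            raise ValueError("static_input must be 2D.")
    elif strict:
        # Zero static features, but still a valid (batch, 0) placeholder.
        static = np.zeros((batch, 0), dtype=np.float32)

    if future is not None:
        future = np.asarray(future, dtype=np.float32)
        if future.ndim != 3:
            raise ValueError("future_input must be 3D.")
    elif strict:
        # Time span: past steps plus the horizon, when one is given.
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)

    return [static, dynamic, future]


s, d, f = prepare_inputs_sketch(np.zeros((2, 10, 4)), forecast_horizon=3)
print(s.shape, d.shape, f.shape)  # (2, 0) (2, 10, 4) (2, 13, 0)
```

Zero-width placeholders keep the three-element list shape-consistent for models that always expect all three paths, while the flexible mode simply passes None through.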

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True

geoprior.models.utils._utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]

Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.

Parameters:
  • dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.

  • num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.

  • agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.

  • errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.

Returns:

If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.

Return type:

Union[List[Tuple[Any, ...]], Optional[Tuple[Any, ...]]]

Raises:
  • TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).

  • ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).

  • RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
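Outside TensorFlow, the take-and-aggregate behaviour can be sketched over any iterable of tuple batches. take_batches_sketch is hypothetical: it treats 'all' and '*' as "take every batch" (the 'auto' option and the errors policy are not modelled), and it aggregates only tuples of arrays by concatenating along axis 0, not dictionaries. With a real tf.data.Dataset, one could pass dataset.as_numpy_iterator() as the iterable.

```python
from itertools import islice

import numpy as np


def take_batches_sketch(batches, num=1, agg=False):
    """Hypothetical helper: take batches from any iterable of tuples."""
    if num in ("all", "*"):
        taken = list(batches)
    elif isinstance(num, int) and num >= 0:
        taken = list(islice(batches, num))
    else:
        raise ValueError("num must be a non-negative int, 'all' or '*'.")
    if not agg:
        return taken
    if not taken:
        return None
    # zip(*taken) groups the i-th element of every batch together.
    return tuple(np.concatenate(parts, axis=0) for parts in zip(*taken))


batches = [(np.ones((2, 3)), np.zeros(2)), (np.ones((1, 3)), np.ones(1))]
x, y = take_batches_sketch(batches, num="all", agg=True)
print(x.shape, y.tolist())  # (3, 3) [0.0, 0.0, 1.0]
```

Mirroring the documented return contract, an empty selection gives an empty list without aggregation and None with it.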

geoprior.models.utils._utils.format_predictions(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, _logger=None, **kwargs)[source]

Formats model predictions into a structured pandas DataFrame.

This utility function takes raw model predictions (either directly as an array/tensor or generated by a provided model and its inputs) and transforms them into a long-format pandas DataFrame. It can handle point forecasts, quantile forecasts, single or multi-output predictions, and optionally include actual target values, spatial identifiers, and perform coverage score evaluation for quantile forecasts.

The output DataFrame is structured with ‘sample_idx’ (identifying the original input sequence) and ‘forecast_step’ (from 1 to H, where H is the forecast horizon).

Parameters:
  • predictions (np.ndarray or tf.Tensor, optional) –

    The raw prediction tensor or array.

    • For point forecasts, expected shapes:
      • (num_samples, forecast_horizon, output_dim)

      • (num_samples, forecast_horizon) if output_dim=1 (will be reshaped)

      • (num_samples, output_dim) if forecast_horizon=1 (will be reshaped)

    • For quantile forecasts, expected shapes:
      • (num_samples, forecast_horizon, num_quantiles * output_dim)

      • (num_samples, forecast_horizon, num_quantiles, output_dim)

      • (num_samples, num_quantiles * output_dim) if forecast_horizon=1

    If None, model and inputs must be provided. Default is None.

  • model (tf.keras.Model, optional) – A trained Keras model to generate predictions if predictions is not provided. Used in conjunction with inputs. Default is None.

  • inputs (List[Optional[Union[np.ndarray, tf.Tensor]]], optional) – A list of input tensors (e.g., [static, dynamic, future]) required by the model to generate predictions. Required if predictions is None and model is provided. Default is None.

  • y_true_sequences (np.ndarray or tf.Tensor, optional) – The true target values corresponding to the predictions, used for including actuals in the output DataFrame and for evaluation. Expected shape: (num_samples, forecast_horizon, output_dim). Default is None.

  • target_name (str, optional) – Base name for the target variable. Used to prefix prediction and actual column names (e.g., “sales_pred”, “sales_q50”, “sales_actual”). Default is “target”.

  • quantiles (List[float], optional) – A list of quantiles that were predicted by the model (e.g., [0.1, 0.5, 0.9]). Required if the predictions are quantile forecasts. If provided, prediction columns will be named like {target_name}_q10, {target_name}_q50, etc. Default is None (for point forecasts).

  • forecast_horizon (int, optional) – The number of time steps into the future that the model predicts. If not provided, it’s inferred from predictions.shape[1] or y_true_sequences.shape[1]. Default is None.

  • output_dim (int, optional) – The number of target variables predicted at each time step (e.g., 1 for univariate, >1 for multivariate target). If not provided, it’s inferred from the shape of predictions or y_true_sequences. Default is None.

  • spatial_data_array (np.ndarray or tf.Tensor or pd.DataFrame or pd.Series, optional) –

    An array or DataFrame containing static spatial/identifier features for each of the num_samples sequences.

    • If NumPy/Tensor: Expected shape (num_samples, num_spatial_features). spatial_cols_indices must be provided.

    • If DataFrame/Series: Must have num_samples rows. spatial_cols must be provided.

    These features will be repeated for each forecast step in the output DataFrame. Default is None.

  • spatial_cols (List[str], optional) – List of column names to select from spatial_data_array if it’s a DataFrame/Series, or names to assign to columns if spatial_data_array is NumPy/Tensor and spatial_cols_indices are provided. Default is None.

  • spatial_cols_indices (List[int], optional) – List of column indices to select from spatial_data_array if it’s a NumPy/Tensor. Length must match spatial_cols if provided. Default is None.

  • evaluate_coverage (bool, default False) – If True, and if at least two quantiles are provided and y_true_sequences is available, computes the coverage score using the first and last quantiles as the interval bounds. Requires geoprior.metrics.coverage_score.

  • scaler (Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method) used to scale the target variable and potentially other features. If provided along with scaler_feature_names and target_idx_in_scaler, predictions and actuals for the target will be inverse-transformed. Default is None.

  • scaler_feature_names (List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required if scaler is provided and targeted inverse transformation is needed. Default is None.

  • target_idx_in_scaler (int, optional) – The index of the target_name within the scaler_feature_names list. Required if scaler is provided and targeted inverse transformation is needed. Default is None.

  • verbose (int, default 0) –

    Verbosity level for logging during processing.

    • 0: Silent.

    • 1: Basic info.

    • 3: More detailed steps.

    • 5: Very detailed shape information.

  • _logger (Logger | Callable[[str], None] | None) – Optional logger, or a callable that accepts a message string, used for log output. Default is None.

  • **kwargs (Any) – Additional keyword arguments (currently not used but included for future extensibility).

Returns:

A long-format DataFrame containing sample_idx and forecast_step, optional spatial columns, prediction columns, and actual-value columns when y_true_sequences is provided. Point forecasts use names like {target_name}_pred or {target_name}_{output_idx}_pred. Quantile forecasts use names like {target_name}_qXX or {target_name}_{output_idx}_qXX. Actual values use {target_name}_actual or {target_name}_{output_idx}_actual. Prediction and actual values are inverse-transformed when valid scaler information is provided.

Return type:

pandas.DataFrame

Raises:
  • ValueError – If predictions is None and model or inputs is also None. If predictions shape is invalid (not 2D, 3D, or 4D). If quantiles are provided but prediction shape is incompatible for inferring output_dim. If spatial_data_array is provided without necessary name/index parameters.

  • TypeError – If predictions or other inputs cannot be converted to the expected tensor/array types.
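The long layout described under Returns (one row per sample and forecast step) can be sketched for a single target with NumPy and pandas. long_frame_sketch is hypothetical and covers only output_dim == 1, without scaling, spatial columns, or coverage; the column names follow the {target_name}_qXX, {target_name}_pred, and {target_name}_actual patterns documented above.

```python
import numpy as np
import pandas as pd


def long_frame_sketch(preds, target_name="target", quantiles=None,
                      y_true=None):
    """Hypothetical helper for a single target (output_dim == 1).

    preds: (n, H, len(quantiles)) for quantiles, or (n, H, 1) for points.
    """
    preds = np.asarray(preds, dtype=float)
    n, horizon = preds.shape[:2]
    frame = pd.DataFrame({
        # sample_idx repeats once per step; forecast_step cycles 1..H.
        "sample_idx": np.repeat(np.arange(n), horizon),
        "forecast_step": np.tile(np.arange(1, horizon + 1), n),
    })
    if quantiles:
        for i, q in enumerate(quantiles):
            name = f"{target_name}_q{round(q * 100)}"
            frame[name] = preds[:, :, i].reshape(-1)
    else:
        frame[f"{target_name}_pred"] = preds[:, :, 0].reshape(-1)
    if y_true is not None:
        actual = np.asarray(y_true, dtype=float)[:, :, 0]
        frame[f"{target_name}_actual"] = actual.reshape(-1)
    return frame


preds = np.arange(12, dtype=float).reshape(2, 2, 3)
df = long_frame_sketch(preds, "value", quantiles=[0.1, 0.5, 0.9])
print(df["value_q50"].tolist())  # [1.0, 4.0, 7.0, 10.0]
```

The C-order flatten of each (n, H) slice walks sample 0's steps before sample 1's, which is exactly the order produced by np.repeat and np.tile for the index columns.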

See also

geoprior.models.utils.forecast_multi_step

Higher-level forecasting utility.

geoprior.metrics.coverage_score

For evaluating quantile forecast intervals.

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import format_predictions_to_dataframe
>>> B, H, O = 4, 3, 1 # Batch, Horizon, OutputDim
>>> Q = [0.1, 0.5, 0.9]
>>> preds_point = tf.random.normal((B, H, O))
>>> preds_quant = tf.random.normal((B, H, len(Q))) # For O=1
>>> y_true = tf.random.normal((B, H, O))
>>> # Point forecast
>>> df_point = format_predictions_to_dataframe(
...     predictions=preds_point, y_true_sequences=y_true,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> print(df_point.head(H)) # Show first sample's horizon
   sample_idx  forecast_step  value_pred  value_actual
0           0              1   -0.576731     -0.647362
1           0              2    0.183931      1.198977
2           0              3   -0.766871      0.534040
>>> # Quantile forecast
>>> df_quant = format_predictions_to_dataframe(
...     predictions=preds_quant, y_true_sequences=y_true,
...     target_name="value", quantiles=Q,
...     forecast_horizon=H, output_dim=O
... )
>>> print(df_quant.head(H))
   sample_idx  forecast_step  value_q10  value_q50  value_q90  value_actual
0           0              1  -0.209947   0.263107  -0.308929     -0.647362
1           0              2   0.303091   0.594701  -0.225007      1.198977
2           0              3   0.136699  -1.237739   0.002834      0.534040
>>> # With spatial data (NumPy array)
>>> spatial_np = np.array([[101, 201], [102, 202], [103, 203], [104, 204]])
>>> df_spatial = format_predictions_to_dataframe(
...     predictions=preds_point,
...     spatial_data_array=spatial_np,
...     spatial_cols=['store_id', 'region_id'],
...     spatial_cols_indices=[0, 1]
... )
>>> print(df_spatial[['sample_idx', 'forecast_step', 'store_id']].head(H))
   sample_idx  forecast_step  store_id
0           0              1     101.0
1           0              2     101.0
2           0              3     101.0
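The long-format layout described under Returns can be sketched with NumPy alone. The helper below (to_long_columns) is hypothetical and not part of geoprior: it assumes a single output dimension, collects columns in a plain dict instead of a DataFrame, and omits inverse scaling and spatial columns.

```python
import numpy as np


def to_long_columns(preds, target_name="value", quantiles=None, y_true=None):
    """Hypothetical sketch: flatten (B, H, 1) point or (B, H, Q) quantile
    forecasts into long-format columns, one row per (sample, step)."""
    preds = np.asarray(preds)
    n_samples, horizon, _ = preds.shape
    cols = {
        # Row order: sample 0 -> steps 1..H, then sample 1 -> steps 1..H, ...
        "sample_idx": np.repeat(np.arange(n_samples), horizon),
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    }
    if quantiles is None:
        # Point forecast; output_dim == 1 is assumed here.
        cols[f"{target_name}_pred"] = preds[..., 0].reshape(-1)
    else:
        # One column per quantile level, e.g. 0.1 -> value_q10.
        for q_idx, q in enumerate(quantiles):
            cols[f"{target_name}_q{int(round(q * 100))}"] = (
                preds[..., q_idx].reshape(-1)
            )
    if y_true is not None:
        cols[f"{target_name}_actual"] = np.asarray(y_true)[..., 0].reshape(-1)
    return cols
```

Passing the result to pandas.DataFrame gives a table with the same row order as the examples above.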
geoprior.models.utils._utils.export_keras_losses(history, keys=None, savefile=None, verbose=0, formats=('json', 'csv'))

Export loss(es) (and any other metric) from a Keras History object.

Parameters:
  • history (History) – The History object returned by model.fit().

  • keys (list[str] | None) – List of history.history keys to export, e.g. [“loss”, “val_loss”, “physics_loss”]. If None, defaults to all keys ending with “loss”.

  • savefile (str | None) –

    Path (optionally with extension) where to write the output.

    If the extension is .json, only JSON is written. If the extension is .csv, only CSV is written. If no extension is given, all formats in formats are written using savefile as the base name.

  • verbose (int) – If >0, prints status messages.

  • formats (tuple[str, ...]) – File formats to write when savefile has no extension. Valid entries are “json” and “csv”.

Returns:

result – Dictionary containing "epochs_run" and one list entry per exported history key.

Return type:

dict
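The key selection and the savefile extension rules can be sketched with the standard library. In this hypothetical export_losses_sketch, a plain dict stands in for History.history. Three details are assumptions rather than documented geoprior behavior: epochs_run is taken as the length of the longest exported series, the CSV layout is one row per epoch, and an unrecognized extension is treated like a missing one.

```python
import csv
import json
import os


def export_losses_sketch(history_dict, keys=None, savefile=None,
                         formats=("json", "csv")):
    """Hypothetical sketch of export_keras_losses using a plain dict."""
    # Default selection: every history key ending with "loss".
    if keys is None:
        keys = [k for k in history_dict if k.endswith("loss")]
    # Assumption: epochs_run is the length of the longest exported series.
    n_epochs = max((len(history_dict[k]) for k in keys), default=0)
    result = {"epochs_run": n_epochs}
    result.update({k: list(history_dict[k]) for k in keys})

    if savefile:
        ext = os.path.splitext(savefile)[1].lower().lstrip(".")
        if ext in ("json", "csv"):
            # An explicit .json or .csv extension selects exactly one format.
            targets = [(savefile, ext)]
        else:
            # Otherwise savefile is the base name for every requested format.
            targets = [(f"{savefile}.{fmt}", fmt) for fmt in formats]
        for path, fmt in targets:
            if fmt == "json":
                with open(path, "w") as fh:
                    json.dump(result, fh, indent=2)
            elif fmt == "csv":
                with open(path, "w", newline="") as fh:
                    writer = csv.writer(fh)
                    writer.writerow(["epoch", *keys])
                    for i in range(n_epochs):
                        row = [history_dict[k][i] if i < len(history_dict[k])
                               else "" for k in keys]
                        writer.writerow([i + 1, *row])
    return result
```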

geoprior.models.utils._utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)

Safely retrieves the first available tensor from a dictionary using a list of possible keys.

This utility is useful for handling model inputs inside a TensorFlow graph (e.g., in train_step). Instead of relying on Python truthiness (e.g., tensor_a or tensor_b), which raises an error for symbolic or multi-element tensors, it checks each candidate explicitly with is not None.

Parameters:
  • inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).

  • *tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.

  • default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.

  • check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.

  • auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.

Returns:

The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.

Return type:

Optional[tf.Tensor]

Raises:

TypeError – If inputs is not a dictionary.

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils._utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.], dtype=np.float32)}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
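The explicit is not None lookup can be illustrated without TensorFlow. In the hypothetical first_present helper below, a NumPy array stands in for a tensor: like a multi-element tensor, it raises an error in a boolean context such as a or b, so each candidate is checked explicitly. Type checking and conversion to tensors are left out of this sketch.

```python
import numpy as np


def first_present(inputs, *names, default=None):
    """Hypothetical sketch: return the first value found under `names`
    that is not None, without ever testing a value's truthiness."""
    if not isinstance(inputs, dict):
        raise TypeError(f"inputs must be a dict, got {type(inputs).__name__}")
    for name in names:
        value = inputs.get(name)
        if value is not None:  # explicit check instead of `value or ...`
            return value
    return default


zeros = np.array([0.0, 0.0])
batch = {"H_field": zeros, "soil_thickness": np.array([20.0, 21.0])}
picked = first_present(batch, "H_field", "soil_thickness")  # the zeros array
# By contrast, `batch["H_field"] or batch["soil_thickness"]` raises
# ValueError: the truth value of a multi-element array is ambiguous.
```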

This module collects many of the reusable helpers for sequence forecasting and model-side formatting. Its module docstring frames these utilities in the context of Temporal Fusion Transformer-style preprocessing and multi-horizon forecasting.

PINN/model physics helper module#

Physics-Informed Neural Network (PINN) utility functions.

geoprior.models.utils.pinn.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)

Normalize and validate PDE mode selection.

Parameters:
  • pde_mode (str, sequence of str, or None) –

    Requested PDE mode(s).

    Accepted canonical values are "none", "consolidation", "gw_flow", and "both".

    Accepted aliases: None and "off" map to "none"; "on" maps to "both".

  • enforce_consolidation (bool, default False) –

    If True, any resolved mode other than exact ["consolidation"] is coerced to ["consolidation"] and a warning is emitted.

    This includes ["none"], ["gw_flow"], and ["consolidation", "gw_flow"].

  • pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.

  • solo_return (bool, default False) –

    If False, return a canonical list of active modes.

    If True, return a single canonical label: "none", "consolidation", "gw_flow", or "both".

Returns:

Canonical PDE mode(s), either as a list or a single label.

Return type:

list of str or str

Raises:
  • TypeError – If the input type is invalid.

  • ValueError – If a token is unsupported or the mode selection is ambiguous.

Examples

>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
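The normalization rules above can be sketched in plain Python. normalize_pde_modes is a hypothetical re-implementation for illustration, not geoprior's code; in particular, treating a mix of "none" with an active mode as the "ambiguous" case is an assumption.

```python
import warnings

_CANONICAL = ("none", "consolidation", "gw_flow", "both")
_ALIASES = {"off": "none", "on": "both"}


def normalize_pde_modes(pde_mode, enforce_consolidation=False,
                        pde_mode_config=None, solo_return=False):
    """Hypothetical sketch of the documented PDE-mode normalization rules."""
    # pde_mode_config, when given, takes precedence over pde_mode.
    mode = pde_mode_config if pde_mode_config is not None else pde_mode
    if mode is None or isinstance(mode, str):
        tokens = [mode]
    elif isinstance(mode, (list, tuple, set)):
        tokens = list(mode)
    else:
        raise TypeError(f"Unsupported pde_mode type: {type(mode).__name__}")

    active, saw_none = set(), False
    for tok in tokens:
        if tok is None:
            tok = "none"
        elif isinstance(tok, str):
            key = tok.strip().lower()
            tok = _ALIASES.get(key, key)
        else:
            raise TypeError(f"Unsupported PDE mode token: {tok!r}")
        if tok not in _CANONICAL:
            raise ValueError(f"Unsupported PDE mode: {tok!r}")
        if tok == "none":
            saw_none = True
        elif tok == "both":
            active.update(("consolidation", "gw_flow"))
        else:
            active.add(tok)
    # Assumption: combining "none" with an active mode counts as ambiguous.
    if saw_none and active:
        raise ValueError(f"Ambiguous PDE mode selection: {tokens!r}")

    # Fixed ordering, so "both" always expands to the same list.
    modes = [m for m in ("consolidation", "gw_flow") if m in active] or ["none"]
    if enforce_consolidation and modes != ["consolidation"]:
        warnings.warn(f"PDE mode {modes} coerced to ['consolidation'].")
        modes = ["consolidation"]
    if solo_return:
        return "both" if len(modes) == 2 else modes[0]
    return modes
```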
geoprior.models.utils.pinn.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)

Formats PINN model predictions into a structured pandas DataFrame.

This general-purpose utility transforms raw model outputs (from models such as PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export. It handles multi-target outputs (e.g., subsidence and GWL) and point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse scaling of predictions and evaluation of quantile coverage.

Parameters:
  • predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model’s .predict() method. Keys should match the model’s output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.

  • model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.

  • model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model’s signature, required only if predictions is None. Default is None.

  • y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, a <target>_actual column will be added to the output DataFrame for comparison. Default is None.

  • target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.

  • include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.

  • include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.

  • quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is crucial for correctly parsing probabilistic forecasts. Default is None.

  • forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.

  • output_dims (dict of {str: int}, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it’s inferred from the tensor shapes. Default is None.

  • ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.

  • ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.

  • ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.

  • scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., ‘subsidence’) and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.

  • coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.

  • evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.

  • coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. Default is (0, -1), which corresponds to the full range.

  • savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.

  • name (str or None) – Optional name for the prediction, used when formatting the output data and the coverage result, if applicable.

  • model_name (str or None) – Name of the model.

  • verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).

  • **kwargs (dict) – Additional keyword arguments for future extensions.

  • _logger (Logger | Callable[[str], None] | None) – Optional logger, or callable that receives log messages as strings.

  • stop_check (Callable[[], bool]) – Optional callable that returns True when processing should stop early.

Returns:

A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.

Return type:

pd.DataFrame

Notes

  • The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.

  • For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.

  • For point forecasts, the column is named <target_name>_pred.
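With evaluate_coverage=True, the reported quantity is the unconditional coverage of the interval selected by coverage_quantile_indices. A minimal NumPy sketch follows, assuming quantile predictions shaped (samples, horizon, n_quantiles), actuals shaped (samples, horizon), and inclusive interval bounds. interval_coverage is a hypothetical name, not a geoprior function, and it returns a fraction rather than a percentage.

```python
import numpy as np


def interval_coverage(y_true, q_preds, quantiles, indices=(0, -1)):
    """Hypothetical sketch: fraction of actuals inside the quantile interval."""
    order = np.argsort(quantiles)               # sort the quantile levels
    q_sorted = np.asarray(q_preds)[..., order]  # reorder the last axis to match
    lower = q_sorted[..., indices[0]]
    upper = q_sorted[..., indices[1]]
    y = np.asarray(y_true)
    inside = (y >= lower) & (y <= upper)        # inclusive bounds (assumption)
    return float(inside.mean())
```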

See also

geoprior.plot.forecast.plot_forecasts

Visualizes the DataFrame produced by this function.

geoprior.models.utils.pinn.format_pihalnet_predictions(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, model_name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)

Formats PIHALNet/GeoPriorSubsNet predictions into a structured pandas DataFrame, handling inversion, quantiles, and coordinates.

This function is the core formatter. It:

  1. Gets the model outputs (or uses the provided ones).

  2. Unpacks ‘data_final’ if model_name is ‘geoprior’ or ‘geopriorsubsnet’.

  3. Inverse-transforms all prediction and actual arrays using scaler_info.

  4. Builds a long-format DataFrame with sample_idx and forecast_step.

  5. Appends inverted quantile/point predictions.

  6. Appends inverted actual values.

  7. Appends inverted coordinates.

  8. Appends static/ID columns.

  9. Evaluates coverage on the inverted data.

Parameters:
  • pihalnet_outputs (dict, optional) – Raw output from model.predict(). If None, model and model_inputs must be provided.

  • model (tf.keras.Model, optional) – Trained model instance (if pihalnet_outputs is None).

  • model_inputs (dict, optional) – Inputs for the model to generate predictions (if pihalnet_outputs is None).

  • y_true_dict (dict, optional) – Dictionary of true target arrays (e.g., {‘subs_pred’: y_true_s}). Required for including actuals and evaluating coverage.

  • target_mapping (dict, optional) – Maps prediction keys to base names for DataFrame columns. Default: {‘subs_pred’: ‘subsidence’, ‘gwl_pred’: ‘gwl’}.

  • include_gwl (bool, default True) – Whether to include ‘gwl_pred’ in the final DataFrame.

  • include_coords (bool, default True) – Whether to include ‘coord_t’, ‘coord_x’, ‘coord_y’ columns.

  • quantiles (list[float], optional) – List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided, quantile columns (e.g., ‘subsidence_q10’) are created.

  • forecast_horizon (int, optional) – The forecast horizon length (H). If not provided, it’s inferred from the prediction array’s shape.

  • output_dims (dict, optional) – Maps prediction keys to their output dimension (O). E.g., {‘subs_pred’: 1, ‘gwl_pred’: 1}. Crucial for correctly splitting GeoPrior outputs and reshaping.

  • ids_data_array (np.ndarray or pd.DataFrame, optional) – Static/ID data (e.g., original coordinates) to merge. Must have the same number of samples (B) as predictions.

  • ids_cols (list[str], optional) – Column names if ids_data_array is a DataFrame.

  • ids_cols_indices (list[int], optional) – Column indices if ids_data_array is a NumPy array.

  • scaler_info (dict, optional) – Dictionary for inverse scaling. Each target entry should provide a fitted scaler, the target index inside that scaler, and the feature-name ordering used when the scaler was fit.

  • coord_scaler (scikit-learn-like scaler, optional) – A fitted scaler object (with an inverse_transform method) used to inverse-transform the ‘coords’ tensor.

  • evaluate_coverage (bool, default False) – If True, calculates coverage percentage for quantiles.

  • coverage_quantile_indices (tuple[int, int], default (0, -1)) – Indices of the low and high quantiles in the quantiles list to use for coverage (e.g., with quantiles=[0.1, 0.5, 0.9], indices 0 and -1 select the 10th and 90th percentiles).

  • savefile (str, optional) – If provided, saves the final DataFrame to this path.

  • model_name (str, optional) – Specifies the model type. If ‘geoprior’ or ‘geopriorsubsnet’, triggers unpacking of the ‘data_final’ output.

  • apply_mask (bool, default False) – If True, masks predictions based on mask_values in the first target’s _actual column.

  • mask_values (float or int, optional) – The value in the _actual column to trigger masking.

  • mask_fill_value (float, optional) – The value to replace masked predictions with (e.g., np.nan).

  • verbose (int, default 0) – Logging verbosity.

  • _logger (logging.Logger or callable, optional) – Logger object.

  • stop_check (callable, optional) – Function to check for early stopping.

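  • name (str | None) – Optional name for the prediction, used when formatting the output data and the coverage result, if applicable.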

Returns:

A long-format DataFrame with predictions, actuals, and coordinates.

Return type:

pd.DataFrame
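The apply_mask step can be sketched on plain column arrays. This hypothetical mask_prediction_columns helper treats every column ending in _pred or _q<number> as a prediction column and fills it wherever the first target's _actual column equals the mask value; exactly which columns the real function masks is an assumption here.

```python
import re

import numpy as np

_PRED_COL = re.compile(r"(_pred|_q\d+)$")


def mask_prediction_columns(columns, first_target, mask_value,
                            fill_value=np.nan):
    """Hypothetical sketch: blank out predictions where the first target's
    actual value equals `mask_value`."""
    actual = np.asarray(columns[f"{first_target}_actual"])
    hit = actual == mask_value
    masked = {}
    for name, values in columns.items():
        if _PRED_COL.search(name):
            masked[name] = np.where(hit, fill_value,
                                    np.asarray(values, dtype=float))
        else:
            masked[name] = values  # ids, steps and actuals are left untouched
    return masked
```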

geoprior.models.utils.pinn.format_preds(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)

Main function orchestrating all helper steps. Its parameters share their names with those of format_pihalnet_predictions (which additionally accepts model_name); see that entry for their descriptions.

Return type:

DataFrame

geoprior.models.utils.pinn.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)

Return type:

tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]

geoprior.models.utils.pinn.check_and_rename_keys(inputs, y)

Helper function to check and rename keys in the inputs and target dictionaries.

This function ensures that the necessary keys are present in both the inputs and y dictionaries. If the keys for ‘subsidence’ or ‘gwl’ are not found, it attempts to rename them from possible alternatives like ‘subs_pred’ or ‘gwl_pred’.

Parameters:
  • inputs (dict) – A dictionary containing the input data. The keys ‘coords’ and ‘dynamic_features’ are expected.

  • y (dict) – A dictionary containing the target values. The keys ‘subsidence’ and ‘gwl’ are expected, but they could also appear as ‘subs_pred’ or ‘gwl_pred’.

Raises:

ValueError – If required keys are missing in inputs or y, or if renaming does not result in valid keys for ‘subsidence’ and ‘gwl’.

geoprior.models.utils.pinn.check_required_input_keys(inputs, y=None, message=None, model_name=None, do_rename=True)

Validate presence of required keys in inputs and y. Optionally canonicalize keys via reverse alias mapping.

This function ensures that the necessary keys are present in both the inputs and y dictionaries. If the keys for ‘subsidence’ or ‘gwl’ are not found, it attempts to rename them from possible alternatives like ‘subs_pred’ or ‘gwl_pred’.

Parameters:
  • inputs (dict) – A dictionary containing the input data. The keys ‘coords’ and ‘dynamic_features’ are expected.

  • y (dict, optional) – A dictionary containing the target values. The keys ‘subsidence’ and ‘gwl’ are expected, but they could also appear as ‘subs_pred’ or ‘gwl_pred’.

  • message (str, optional) – Custom error message used when inputs or y is not a dictionary.

  • model_name (str | None)

  • do_rename (bool, default True) – If True, canonicalize alias keys (e.g., ‘subs_pred’ to ‘subsidence’) via the reverse alias mapping.

Raises:

ValueError – If required keys are missing in inputs or y, or if renaming does not result in valid keys for ‘subsidence’ and ‘gwl’.

Return type:

tuple[dict[str, Any] | None, dict[str, Any] | None]
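The alias handling shared by check_and_rename_keys and check_required_input_keys can be sketched in plain Python. The alias table below contains only the two documented alternatives (‘subs_pred’, ‘gwl_pred’); the real reverse alias mapping may cover more names, and canonicalize_keys is a hypothetical helper.

```python
REQUIRED_INPUTS = ("coords", "dynamic_features")
TARGET_ALIASES = {"subsidence": ("subs_pred",), "gwl": ("gwl_pred",)}


def canonicalize_keys(inputs, y):
    """Hypothetical sketch: validate input keys and map target aliases
    such as 'subs_pred' onto canonical names such as 'subsidence'."""
    missing = [k for k in REQUIRED_INPUTS if k not in inputs]
    if missing:
        raise ValueError(f"Missing input keys: {missing}")
    y = dict(y)  # copy so the caller's dict is not mutated
    for canonical, aliases in TARGET_ALIASES.items():
        if canonical in y:
            continue
        alias = next((a for a in aliases if a in y), None)
        if alias is None:
            raise ValueError(f"Missing target key {canonical!r}")
        y[canonical] = y.pop(alias)
    return inputs, y
```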

geoprior.models.utils.pinn.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. A single tensor or array may be 2D with shape (batch, 3) or 3D with shape (batch, time_steps, 3). A dictionary may contain a 'coords' key with the coordinate tensor, or separate 't', 'x', and 'y' keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension. '2d' requires input shaped like (batch, 3) or a dictionary of (batch, 1) tensors. '3d' requires input shaped like (batch, time, 3) or a dictionary of (batch, time, 1) tensors. If None, both are accepted and 2D inputs are expanded to 3D.

  • verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
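The reshaping performed by extract_txy_in can be sketched with NumPy arrays in place of tensors. split_txy is a hypothetical helper: it accepts a (batch, 3) or (batch, time, 3) array, or a dict with a 'coords' entry or separate 't', 'x', 'y' entries, and always returns three (batch, time, 1) arrays. The expect_dim check is omitted.

```python
import numpy as np

DEFAULT_SLICE = {"t": 0, "x": 1, "y": 2}


def _to_3d(arr):
    """Lift a 2D array to 3D by inserting a time axis of length 1."""
    arr = np.asarray(arr)
    return arr[:, np.newaxis, ...] if arr.ndim == 2 else arr


def split_txy(inputs, coord_slice_map=None):
    """Hypothetical sketch: return t, x, y, each shaped (batch, time, 1)."""
    cmap = coord_slice_map or DEFAULT_SLICE
    if isinstance(inputs, dict) and "coords" not in inputs:
        # Separate entries: (batch, 1) becomes (batch, 1, 1).
        return tuple(_to_3d(inputs[k]) for k in ("t", "x", "y"))
    coords = inputs["coords"] if isinstance(inputs, dict) else inputs
    coords = _to_3d(coords)  # (batch, 3) -> (batch, 1, 3)
    if coords.ndim != 3:
        raise ValueError(f"Expected 2D or 3D coordinates, got {coords.ndim}D")
    # Slicing with k:k+1 keeps a singleton last dimension.
    return tuple(coords[..., cmap[k]:cmap[k] + 1] for k in ("t", "x", "y"))
```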

geoprior.models.utils.pinn.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. Can be a single tensor or a dictionary with ‘coords’ or ‘t’, ‘x’, ‘y’ keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input’s dimension. '2d' requires input shaped like (batch, 3). '3d' accepts 3D input and expands 2D input to 3D with a time dimension of 1. '3d_only' requires 3D input and raises an error for 2D input. None accepts both 2D and 3D inputs without changing their rank.

  • verbose (int, default 0) – Controls logging verbosity.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
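The four expect_dim policies of extract_txy differ only in how they treat the rank of the coordinate array. The hypothetical resolve_coord_rank below sketches that decision table on NumPy arrays shaped (batch, 3) or (batch, time, 3). It is not geoprior code; rejecting 3D input under '2d' is inferred from the documented requirement of (batch, 3) input.

```python
import numpy as np


def resolve_coord_rank(coords, expect_dim=None):
    """Hypothetical sketch of the expect_dim rank policy of extract_txy."""
    arr = np.asarray(coords)
    if arr.ndim not in (2, 3):
        raise ValueError(f"Coordinates must be 2D or 3D, got {arr.ndim}D")
    if expect_dim is None:
        return arr  # keep the input rank unchanged
    if expect_dim == "2d":
        if arr.ndim != 2:
            raise ValueError("expect_dim='2d' requires (batch, 3) input")
        return arr
    if expect_dim == "3d":
        # 2D input gains a time axis of length 1: (batch, 3) -> (batch, 1, 3)
        return arr[:, np.newaxis, :] if arr.ndim == 2 else arr
    if expect_dim == "3d_only":
        if arr.ndim != 3:
            raise ValueError("expect_dim='3d_only' rejects 2D input")
        return arr
    raise ValueError(f"Unknown expect_dim: {expect_dim!r}")
```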

geoprior.models.utils.pinn.plot_hydraulic_head(model, t_slice, x_bounds, y_bounds, resolution=100, ax=None, title=None, cmap='viridis', colorbar_label='Hydraulic Head (h)', save_path=None, show_plot=True, **contourf_kwargs)

Generate and plot a 2D contour map of a hydraulic head solution.

This utility visualizes the output of a Physics-Informed Neural Network (PINN) that solves for the hydraulic head \(h(t, x, y)\). It automates creating a spatial grid, running model predictions on it, and drawing a filled contour plot for a single slice in time.

Parameters:
  • model (tf.keras.Model) – The trained PINN model. It is expected to have a .predict() method that accepts a dictionary of tensors with keys {'t', 'x', 'y'}.

  • t_slice (float) – The specific point in time \(t\) for which to plot the 2D spatial solution.

  • x_bounds (tuple of float) – A tuple (x_min, x_max) defining the spatial domain for the x-axis.

  • y_bounds (tuple of float) – A tuple (y_min, y_max) defining the spatial domain for the y-axis.

  • resolution (int, optional) – The number of points to sample along each spatial axis, creating a grid of resolution x resolution points for prediction. Higher values result in a smoother plot. Default is 100.

  • ax (matplotlib.axes.Axes, optional) – A pre-existing Matplotlib Axes object to plot on. If None, a new figure and axes are created internally. This is useful for embedding this plot within a larger figure arrangement. Default is None.

  • title (str, optional) – A custom title for the plot. If None, a default title is generated using the value of t_slice. Default is None.

  • cmap (str, optional) – The name of the Matplotlib colormap to use for the contour plot. Default is 'viridis'.

  • colorbar_label (str, optional) – The text label for the color bar. Default is 'Hydraulic Head (h)'.

  • save_path (str, optional) – If provided, the path (including filename and extension) where the generated plot will be saved. This is only active when the function creates its own figure (i.e., when ax is None). Default is None.

  • show_plot (bool, optional) – If True, calls plt.show() to display the plot. This is only active when the function creates its own figure. Default is True.

  • **contourf_kwargs (any) – Additional keyword arguments that are passed directly to the matplotlib.pyplot.contourf function. This allows for advanced customization (e.g., levels=20, extend='both').

Returns:

  • ax (matplotlib.axes.Axes) – The Matplotlib Axes object on which the contour plot was drawn.

  • contour (matplotlib.cm.ScalarMappable) – The contour plot object, which can be used for further customizations, such as modifying the color bar.

Return type:

tuple[Axes, _ScalarMappable]

See also

geoprior.models.pinn.PiTGWFlow

The PINN model this function is designed to visualize.

Notes

The core mechanism of this function involves creating a 2D meshgrid of \((x, y)\) coordinates. These grid points are then “flattened” into a long list of points, as the PINN model expects a batch of individual coordinates for prediction, not a grid.

The prediction process is as follows:

  1. A grid of shape (resolution, resolution) is created for \(x\) and \(y\).

  2. These grids are reshaped into column vectors of shape (resolution*resolution, 1).

  3. A time vector of the same shape, filled with t_slice, is created.

  4. The model’s .predict() method is called on these flat tensors.

  5. The resulting flat prediction vector is reshaped back to the original (resolution, resolution) grid shape for plotting.

If a custom ax is provided, the user is responsible for calling plt.show() or saving the parent figure.
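Steps 1 to 5 above can be sketched with NumPy alone, with an ordinary function standing in for model.predict(). head_on_grid and toy_head are hypothetical names used for illustration, and plotting is left out.

```python
import numpy as np


def head_on_grid(predict, t_slice, x_bounds, y_bounds, resolution=100):
    """Hypothetical sketch of the grid -> flat batch -> grid round trip."""
    # 1. Build a (resolution, resolution) meshgrid over the spatial domain.
    xs = np.linspace(x_bounds[0], x_bounds[1], resolution)
    ys = np.linspace(y_bounds[0], y_bounds[1], resolution)
    X, Y = np.meshgrid(xs, ys)
    # 2. Flatten each grid into a column vector of shape (resolution**2, 1).
    x_flat = X.reshape(-1, 1)
    y_flat = Y.reshape(-1, 1)
    # 3. A time column of the same shape, filled with t_slice.
    t_flat = np.full_like(x_flat, t_slice)
    # 4. Predict on the flat batch (a dict keyed by 't', 'x', 'y').
    h_flat = np.asarray(predict({"t": t_flat, "x": x_flat, "y": y_flat}))
    # 5. Restore the grid shape for contour plotting.
    return X, Y, h_flat.reshape(resolution, resolution)


def toy_head(inputs):
    # Same analytical field as the MockPINN used in the examples below.
    return (np.sin(np.pi * inputs["x"]) * np.cos(np.pi * inputs["y"])
            * np.exp(-inputs["t"]))
```

The returned X, Y, H triple can be passed directly to matplotlib.pyplot.contourf(X, Y, H).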

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> from geoprior.models.utils.pinn import plot_hydraulic_head
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
...     def call(self, inputs):
...         # A simple analytical function for demonstration
...         t, x, y = inputs['t'], inputs['x'], inputs['y']
...         return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()

1. Simple Plotting Example

This example creates a single plot and saves it to a file.

>>> ax, contour = plot_hydraulic_head(
...     model=mock_model,
...     t_slice=0.5,
...     x_bounds=(-1, 1),
...     y_bounds=(-1, 1),
...     resolution=50,
...     save_path="hydraulic_head_t0.5.png",
...     show_plot=False  # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png

2. Advanced Example with Subplots

This example shows how to use the ax parameter to draw the solution at two different times side-by-side in one figure.

>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> _ = fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
>>> # Plot solution at t = 0.1
>>> _ = plot_hydraulic_head(
...     model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
>>> # Plot solution at t = 1.0
>>> _ = plot_hydraulic_head(
...     model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()

This module is the entry point for PINN-specific utility logic, including PDE mode normalization, coordinate extraction, PINN prediction formatting, and related helpers used near the physics-informed model layer.

Selected model-facing helpers#

geoprior.models.utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)

Prepares a list of input tensors for a model’s call method.

This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.

Parameters:
  • dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).

  • static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default is None.

  • future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.

  • model_type ({'strict', 'flexible'}, default 'strict') –

    Determines how None inputs for static and future features are handled:

    • 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.

    • 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.

  • forecast_horizon (int, optional) – The forecast horizon. Used only if model_type=’strict’ and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default is None.

  • verbose (int, default 0) –

    Verbosity level. If > 0, prints information about dummy tensor creation.

    • 0: Silent.

    • 1: Basic info on dummy creation.

    • 2: More details on shapes.

Returns:

A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type='flexible' and the original input was None). All returned tensors are cast to tf.float32.

Return type:

List[Optional[tf.Tensor]]

Raises:
  • ValueError – If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.

  • TypeError – If inputs cannot be converted to TensorFlow tensors.

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
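The strict/flexible rules above can be mirrored in a few lines of NumPy. This is a minimal sketch, not geoprior's implementation: the helper name prepare_inputs_sketch is hypothetical, it returns NumPy float32 arrays instead of tf.float32 tensors, and it skips the dimensionality checks listed under Raises.

```python
import numpy as np


def prepare_inputs_sketch(dynamic_input, static_input=None, future_input=None,
                          model_type="strict", forecast_horizon=None):
    """Hypothetical NumPy mirror of the strict/flexible rules (not the geoprior API)."""
    if dynamic_input is None:
        raise ValueError("dynamic_input is required")
    dyn = np.asarray(dynamic_input, dtype=np.float32)
    batch, past_steps = dyn.shape[0], dyn.shape[1]

    if static_input is not None:
        static = np.asarray(static_input, dtype=np.float32)
    elif model_type == "strict":
        # Dummy static block with a zero-size feature axis.
        static = np.zeros((batch, 0), dtype=np.float32)
    else:
        static = None  # flexible: the model handles the missing path itself

    if future_input is not None:
        future = np.asarray(future_input, dtype=np.float32)
    elif model_type == "strict":
        # Time span = past_time_steps (+ forecast_horizon when given).
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)
    else:
        future = None

    return [static, dyn, future]


dyn = np.zeros((2, 10, 4))
s, d, f = prepare_inputs_sketch(dyn, forecast_horizon=3)
print(s.shape, d.shape, f.shape)  # (2, 0) (2, 10, 4) (2, 13, 0)
```

Zero-width dummies keep the three-slot list structure intact for strict models without feeding them any actual features.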
geoprior.models.utils.prepare_model_inputs_in(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0)
Prepares a list of input tensors for a model's call method in graph-compatible mode.

This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.

Parameters:
dynamic_input : np.ndarray or tf.Tensor

The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).

static_input : np.ndarray or tf.Tensor, optional

The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default is None.

future_input : np.ndarray or tf.Tensor, optional

The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.

model_type : {'strict', 'flexible'}, default 'strict'

Determines how None inputs for static and future features are handled:

  • 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.

  • 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.

forecast_horizon : int, optional

The forecast horizon. Used only if model_type='strict' and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor's time dimension will match dynamic_input's past_time_steps. Default is None.

verbose : int, default 0

Verbosity level. If > 0, prints information about dummy tensor creation.

  • 0: Silent.

  • 1: Basic info on dummy creation.

  • 2: More details on shapes.

Returns:
List[Optional[tf.Tensor]]

A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type='flexible' and the original input was None). All returned tensors are cast to tf.float32.

Raises:
ValueError

If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.

TypeError

If inputs cannot be converted to TensorFlow tensors.

Return type:

list[Any | None]

Examples

>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs_in
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs_in(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
geoprior.models.utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)

Create input sequences and corresponding targets for time series forecasting.

The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.

See more in User Guide.

Parameters:
  • df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.

  • sequence_length (int) – The number of past time steps to include in each input sequence.

  • target_col (str) – The name of the target column.

  • step (int, default 1) – The step size between the starts of consecutive sequences.

  • include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.

  • drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.

  • forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.

  • verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).

Returns:

A tuple containing:
  • sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).

  • targets:
    • If forecast_horizon is None: Array of target values with shape (num_sequences,).

    • If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If the DataFrame df does not contain the target_col.

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences
>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })
>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)
>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)

Notes

  • Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.

  • Forecast Horizon:
    • If forecast_horizon is None, the function creates targets for a single future time step.

    • If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.

  • Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.

  • Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.

  • Data Validation: The function uses are_all_frames_valid from geoprior.core.checks to check the integrity of the input DataFrame before processing, and exist_features to verify that the target column is present.

For each sequence \(i\), sequence generation can be expressed as:

\[
\mathbf{X}^{(i)} = \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right],
\qquad
y^{(i)} =
\begin{cases}
y_{i+T} & \text{if forecast\_horizon is None} \\
\left[ y_{i+T}, y_{i+T+1}, \dots, y_{i+T+H-1} \right] & \text{if forecast\_horizon} = H
\end{cases}
\tag{65}
\]
Where:
  • \(\mathbf{X}^{(i)}\) is the input sequence of length \(T\) (sequence_length) starting at row \(i\) (with \(i\) advancing by step), and \(\mathbf{x}_t\) is the feature vector at time step \(t\).

  • \(y^{(i)}\) is the target value(s) following the sequence, and \(y_t\) is the value of target_col at time step \(t\).
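Read literally, the formula above is a sliding window over the rows of df. The sketch below implements that window in NumPy under stated assumptions: the helper sliding_window_sequences is hypothetical, every column (including target_col) is used as a feature, as the 4-feature shapes in the example above suggest, and include_overlap/drop_last are not modelled. With step=1 it yields N - T - H + 1 windows for N rows (H = 1 when forecast_horizon is None), so its counts may differ from the library's own boundary handling.

```python
import numpy as np
import pandas as pd


def sliding_window_sequences(df, sequence_length, target_col, step=1,
                             forecast_horizon=None):
    """Hypothetical sketch of the windowing formula (not create_sequences itself)."""
    values = df.to_numpy(dtype=float)            # all columns used as features
    target = df[target_col].to_numpy(dtype=float)
    horizon = 1 if forecast_horizon is None else forecast_horizon

    # Start index i of each window; the last window must leave room for H targets.
    starts = range(0, len(df) - sequence_length - horizon + 1, step)
    X = np.stack([values[i:i + sequence_length] for i in starts])
    y = np.stack([target[i + sequence_length:i + sequence_length + horizon]
                  for i in starts])
    if forecast_horizon is None:
        y = y[:, 0]                              # single-step targets are scalars
    return X, y


rng = np.random.default_rng(0)
data = pd.DataFrame(rng.random((100, 4)),
                    columns=["feature1", "feature2", "feature3", "target"])
X, y = sliding_window_sequences(data, 4, "target")
print(X.shape, y.shape)  # (96, 4, 4) (96,)
```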

See also

geoprior.models.utils.split_static_dynamic

Function to split sequences into static and dynamic inputs.

geoprior.models.utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)

Split sequences into static and dynamic inputs for the model.

The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.

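\[
\mathbf{S} = \mathbf{X}_{:,\, t,\, \text{static\_indices}}, \qquad
\mathbf{D} = \mathbf{X}_{:,\, :,\, \text{dynamic\_indices}}
\tag{66}
\]

where \(\mathbf{X}\) is the sequence array of shape (batch_size, sequence_length, num_features), \(\mathbf{S}\) and \(\mathbf{D}\) are the static and dynamic inputs before reshaping, and \(t\) is static_time_step.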
Parameters:
  • sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).

  • static_indices (List[int]) – Indices of static features within the feature dimension.

  • dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.

  • static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).

  • reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.

  • reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.

  • static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).

  • dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).

Returns:

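A tuple containing:
  • Static inputs, with the shape given by static_reshape_shape (default (batch_size, num_static_vars, 1)).

  • Dynamic inputs, with the shape given by dynamic_reshape_shape (default (batch_size, sequence_length, num_dynamic_vars, 1)).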

Return type:

Tuple[numpy.ndarray, numpy.ndarray]

Raises:

ValueError – If static_time_step is out of range for the given sequence length.

Examples

>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array
>>> # Shape: (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)

Notes

  • Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.

  • Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.

  • Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
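Under the default reshaping, the split reduces to two NumPy indexing operations followed by a trailing singleton axis. A minimal sketch, assuming the defaults described under Parameters (the helper name split_static_dynamic_sketch is hypothetical, and custom reshape shapes are not handled):

```python
import numpy as np


def split_static_dynamic_sketch(sequences, static_indices, dynamic_indices,
                                static_time_step=0):
    """Hypothetical sketch of the default split (not the geoprior function)."""
    sequences = np.asarray(sequences)
    n, t, _ = sequences.shape
    if not 0 <= static_time_step < t:
        raise ValueError(
            f"static_time_step={static_time_step} is out of range for "
            f"sequence_length={t}")

    # Static features are read at a single time step: (n, n_static).
    static = sequences[:, static_time_step, static_indices]
    # Dynamic features keep the full time axis: (n, t, n_dynamic).
    dynamic = sequences[:, :, dynamic_indices]

    # Default reshapes append a trailing axis of size 1.
    static = static.reshape(n, len(static_indices), 1)
    dynamic = dynamic.reshape(n, t, len(dynamic_indices), 1)
    return static, dynamic


seq = np.random.rand(100, 10, 5)
s, d = split_static_dynamic_sketch(seq, [0, 1], [2, 3, 4])
print(s.shape, d.shape)  # (100, 2, 1) (100, 10, 3, 1)
```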

See also

geoprior.models.utils.create_sequences

Function to create input sequences and targets for time series forecasting.

geoprior.models.utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)

Generate a single-step forecast using the XTFT model.

This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:

\[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{67}\]

where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).

  • spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data’s X_static.

  • q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name (e.g. "subsidence"), used to construct the output column names. Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.

  • apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.

Returns:

A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # and dummy spatial features ("longitude", "latitude"), as well as the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # When spatial_cols is provided, forecast_single_step expects the first
>>> # two columns of X_static to hold the spatial coordinates, so X_static
>>> # stacks "longitude", "latitude" and the static feature "stat1".
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values  # (50, 3)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1)
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> #   - X_static with shape (n_samples, static_input_dim)
>>> #   - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=3,         # "longitude", "latitude" and "stat1"
...     dynamic_input_dim=2,        # "feat1" and "feat2"
...     future_input_dim=1,         # One future feature
...     forecast_horizon=1,         # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
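>>> from geoprior.models.losses import combined_quantile_loss
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )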
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",                # The time column name in the output
...     mode="quantile",              # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())

Notes

  • In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.

  • If spatial_cols is provided, its first two entries must correspond to the first and second columns of the original training data's X_static.

  • The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.

  • Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.

  • The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
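The column layout described under Returns can be sketched as follows, assuming predictions of shape \((n, 1, m)\) and quantile suffixes formed as round(100 * q). The helper single_step_frame is hypothetical; it omits the spatial columns, dt_col, masking and CSV export that forecast_single_step also handles.

```python
import numpy as np
import pandas as pd


def single_step_frame(preds, tname="target", quantiles=None, y=None):
    """Hypothetical sketch of the output columns (not forecast_single_step)."""
    preds = np.asarray(preds, dtype=float)
    if preds.ndim != 3 or preds.shape[1] != 1:
        raise ValueError("expected predictions of shape (n, 1, m)")

    cols = {}
    if quantiles is None:
        # Point mode: a single output per sample.
        cols[f"{tname}_pred"] = preds[:, 0, 0]
    else:
        if preds.shape[2] != len(quantiles):
            raise ValueError("last axis must have one entry per quantile")
        for j, q in enumerate(quantiles):
            cols[f"{tname}_q{int(round(q * 100))}"] = preds[:, 0, j]

    if y is not None:
        cols[f"{tname}_actual"] = np.asarray(y, dtype=float).reshape(len(preds))
    return pd.DataFrame(cols)


preds = np.arange(6, dtype=float).reshape(2, 1, 3)
df = single_step_frame(preds, tname="subsidence", quantiles=[0.1, 0.5, 0.9])
print(list(df.columns))  # ['subsidence_q10', 'subsidence_q50', 'subsidence_q90']
```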

See also

forecast_multi_step

Function for multi-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to validate quantile ratios.

geoprior.models.utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)

Generate a multi-step forecast using the XTFT model.

This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:

\[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{68}\]

for \(i = 1, \dots, H\), where \(H\) is forecast_horizon and \(f\) is the trained XTFT model.

Parameters:
  • xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.

  • inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.

  • forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.

  • y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.

  • dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.

  • mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.

  • spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.

  • time_steps (int, optional) – The number of historical time steps used as input. Default is 3.

  • q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".

  • tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".

  • forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.

  • apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.

  • mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.

  • mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.

  • savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.

  • verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.

Returns:

A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.

Return type:

pandas.DataFrame

Examples

>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60,
...                          freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # Because spatial_cols is passed below, the first two columns of
>>> # X_static must be the spatial coordinates "longitude" and "latitude".
>>> X_static = train_df[["longitude", "latitude"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
>>> # Target output from "subsidence" reshaped to
>>> # (60, 1, 1). For multi-step forecast, forecast_horizon is 4.
>>> forecast_horizon = 4
>>> y_array = train_df["subsidence"].values.reshape(60, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,    # "longitude" and "latitude"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...    optimizer="adam",
...    loss=combined_quantile_loss(my_model.quantiles)
...    )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())
For a point forecast, rebuild the model with quantiles=None and an MSE loss:

>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,    # same two columns as X_static above
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None,        # no quantiles: point forecast
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...    optimizer="adam", loss="mse",
...    )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())

Notes

  • In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.

  • In point mode, a single prediction is generated per forecast step.

  • The output prediction array is expected to have the shape \((n, \text{forecast\_horizon}, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).

  • The provided spatial_cols must correspond to the first two columns of the original training data’s X_static.

  • Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.

  • The DataFrame is constructed by iterating over each sample and each forecast step.
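The evaluation step in quantile mode can be sketched as below. This assumes R² is computed on the median (0.5) forecast and that coverage is the fraction of observations falling inside the interval spanned by the lowest and highest quantiles; the helper median_r2_and_coverage is hypothetical, and geoprior's coverage_score may differ in details such as the chosen bounds.

```python
import numpy as np


def median_r2_and_coverage(y_true, q_preds, quantiles):
    """Hypothetical evaluation sketch for quantile forecasts.

    y_true: (n, H) observed targets; q_preds: (n, H, Q) quantile forecasts.
    """
    y = np.asarray(y_true, dtype=float)
    p = np.asarray(q_preds, dtype=float)
    qs = list(quantiles)

    # R² on the median forecast, used as the point estimate.
    median = p[..., qs.index(0.5)]
    ss_res = np.sum((y - median) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot

    # Coverage: share of observations inside [lowest, highest] quantile.
    lower = p[..., int(np.argmin(qs))]
    upper = p[..., int(np.argmax(qs))]
    coverage = float(np.mean((y >= lower) & (y <= upper)))
    return r2, coverage


y = np.array([[1.0, 2.0], [3.0, 4.0]])
p = np.stack([y - 1, y, y + 1], axis=-1)   # q10, q50, q90 around the truth
print(median_r2_and_coverage(y, p, [0.1, 0.5, 0.9]))  # (1.0, 1.0)
```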

See also

forecast_single_step

Function for single-step forecasts.

coverage_score

Function to compute the coverage score.

validate_keras_model

Function to validate a Keras model.

assert_ratio

Function to verify quantile ratios.

geoprior.models.utils.format_predictions_to_dataframe(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, **kwargs)

Deprecated alias for format_predictions. See format_predictions for the updated, recommended API. All original parameters are forwarded to format_predictions.

Returns:

The formatted prediction DataFrame from format_predictions.

Return type:

pd.DataFrame

Parameters:
  • predictions (ndarray | Any | None)

  • model (Model | None)

  • inputs (list[ndarray | Any | None] | None)

  • y_true_sequences (ndarray | Any | None)

  • target_name (str | None)

  • quantiles (list[float] | None)

  • forecast_horizon (int | None)

  • output_dim (int | None)

  • spatial_data_array (ndarray | Any | None)

  • spatial_cols (list[str] | None)

  • spatial_cols_indices (list[int] | None)

  • evaluate_coverage (bool)

  • scaler (Any | None)

  • scaler_feature_names (list[str] | None)

  • target_idx_in_scaler (int | None)

  • verbose (int)

  • kwargs (Any)

geoprior.models.utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]

Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.

Parameters:
  • dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.

  • num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.

  • agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.

  • errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.

Returns:

If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.

Return type:

Union[List[Tuple[Any, ...]], Optional[Tuple[Any, ...]]]

Raises:
  • TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).

  • ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).

  • RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
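The aggregation described for agg=True can be pictured with NumPy arrays standing in for the tensors a tf.data.Dataset yields. The helper aggregate_batches below is hypothetical: a minimal sketch assuming each batch is a tuple whose items are arrays or dicts of arrays, not the library's implementation.

```python
import numpy as np


def aggregate_batches(batches):
    """Concatenate a list of batch tuples along the batch axis (axis 0).

    Hypothetical helper (sketch), not part of geoprior. Each batch is a
    tuple whose items are either arrays or dicts of arrays.
    """
    if not batches:
        return None  # mirrors the documented "None when nothing was extracted"
    aggregated = []
    for items in zip(*batches):  # group the i-th element of every batch
        first = items[0]
        if isinstance(first, dict):
            aggregated.append(
                {k: np.concatenate([d[k] for d in items], axis=0) for k in first}
            )
        else:
            aggregated.append(np.concatenate(items, axis=0))
    return tuple(aggregated)


# Two fake (inputs_dict, targets) batches with batch size 2.
b1 = ({"coords": np.zeros((2, 3))}, np.ones((2, 1)))
b2 = ({"coords": np.ones((2, 3))}, np.zeros((2, 1)))
inputs, targets = aggregate_batches([b1, b2])
print(inputs["coords"].shape, targets.shape)  # (4, 3) (4, 1)
```

With real TensorFlow data, the batches would typically come from iterating dataset.take(n) and converting each tensor with .numpy() before a step like this.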

geoprior.models.utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)[source]

Safely retrieves the first available tensor from a dictionary using a list of possible keys.

This utility is crucial for handling model inputs within a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of Tensors (e.g., tensor_a or tensor_b), which causes runtime errors, by explicitly checking for is not None.

Parameters:
  • inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).

  • *tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.

  • default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.

  • check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.

  • auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.

Returns:

The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.

Return type:

Optional[tf.Tensor]

Raises:

TypeError – If inputs is not a dictionary.

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.], dtype=np.float32)}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>

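The reason for the explicit is not None check can be shown with NumPy, where multi-element arrays have an ambiguous truth value, much as symbolic TensorFlow tensors cannot be used as Python booleans inside a graph. first_present is a hypothetical, simplified sketch of the lookup order only; it performs no type checking or tensor conversion.

```python
import numpy as np


def first_present(inputs, *names, default=None):
    """Return the first value in ``inputs`` whose key matches and is not None.

    Hypothetical helper (sketch), not part of geoprior.
    """
    if not isinstance(inputs, dict):
        raise TypeError("inputs must be a dict")
    for name in names:
        value = inputs.get(name)
        if value is not None:  # never rely on the truthiness of an array
            return value
    return default


h = np.array([10.0, 11.0])
try:
    h or np.array([20.0, 21.0])  # bool(h) is ambiguous for a 2-element array
except ValueError as exc:
    print("naive `or` failed:", exc)

inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}
print(first_present(inputs, "H_field", "soil_thickness"))  # [20. 21.]
print(first_present(inputs, "missing_key"))  # None
```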
geoprior.models.utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]

Compute anomaly scores for given true targets using various methods.

This utility function, compute_anomaly_scores, provides a flexible approach to compute anomaly scores outside the XTFT model itself. Anomaly scores serve as indicators of how unusual certain observations are, guiding the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.

Parameters:
  • y_true (np.ndarray) –

    The ground truth target values with shape (B, H, O), where:

    • B: batch size

    • H: number of forecast horizons (time steps ahead)

    • O: output dimension (e.g., number of target variables).

    Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.

  • y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to ‘residual’, the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.

  • method (str, optional) –

    The method used to compute anomaly scores. Supported options:

    • "statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores.

      Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:

      \[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\]

      where \(\varepsilon\) is a small constant for numerical stability.

    • "domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.

    • "isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.

    • "residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:

      \[\text{MSE: } (y_{\text{true}} - y_{\text{pred}})^2\]

    Default is "statistical".

  • threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.

  • domain_func (callable, optional) –

    A user-defined function for domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the default heuristic:

    \[\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\]

  • contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.

  • epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.

  • estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.

  • random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.

  • residual_metric (str, optional) –

    The metric used to compute anomalies from residuals if method is set to ‘residual’. Supported metrics:

    • "mse": mean squared error per point (residuals**2)

    • "mae": mean absolute error per point |residuals|

    • "rmse": root mean squared error sqrt((residuals**2))

    Default is "mse".

  • objective (str, optional) – Specifies the type of objective, for future extensibility. Default is "ts" indicating time series. Could be extended for other tasks in the future.

  • verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.

Returns:

anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.

Return type:

np.ndarray

Notes

Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
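A minimal NumPy sketch of the "statistical" and "residual" scores defined above, assuming the mean and standard deviation are taken over the whole of y_true (this section does not say whether the library reduces per horizon or per output). statistical_scores and residual_scores are hypothetical names. Because the statistical score is a squared z-score, a point lies beyond mean ± threshold·std exactly when its score exceeds threshold², which is one way to read the threshold parameter. Per point, the "rmse" formula reduces to the absolute residual.

```python
import numpy as np


def statistical_scores(y_true, epsilon=1e-6):
    """Squared z-score per point, using the global mean and std (sketch)."""
    mu = y_true.mean()
    sigma = y_true.std()
    return ((y_true - mu) / (sigma + epsilon)) ** 2


def residual_scores(y_true, y_pred, metric="mse"):
    """Per-point residual score; y_true and y_pred must share a shape (sketch)."""
    residuals = y_true - y_pred
    if metric == "mse":
        return residuals ** 2
    if metric == "mae":
        return np.abs(residuals)
    if metric == "rmse":
        return np.sqrt(residuals ** 2)  # equals |residuals| point by point
    raise ValueError(f"unknown metric: {metric!r}")


rng = np.random.default_rng(0)
y_true = rng.normal(size=(32, 20, 1))  # (B, H, O)
y_pred = y_true + rng.normal(0, 1, y_true.shape)
stat = statistical_scores(y_true)
print(stat.shape, (stat > 3.0 ** 2).mean())  # share of points beyond 3 std
res = residual_scores(y_true, y_pred, metric="mae")
print(res.shape)
```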

Examples

>>> from geoprior.models.utils import compute_anomaly_scores
>>> import numpy as np
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B,H,O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain',
...                                 domain_func=my_heuristic)
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest',
...                                 contamination=0.1)
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual',
...                                 residual_metric='mae')

See also

geoprior.models.losses.objective_loss

For integrating anomaly scores into a multi-objective loss.

geoprior.models.utils.PDE_MODE_ALIASES = frozenset({'', '0', '1', 'all', 'both', 'consolidation', 'disable', 'disabled', 'false', 'gw_flow', 'none', 'off', 'on', 'true'})

Immutable set of the lowercase string tokens recognized as PDE mode names or aliases (for example "off", "on", "both", "consolidation", "gw_flow") when a PDE mode selection is normalized; see process_pde_modes.

geoprior.models.utils.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)[source]
Return type:

tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]

geoprior.models.utils.normalize_for_pinn(df, time_col, coord_x, coord_y, cols_to_scale='auto', scale_coords=True, exclude_cols=None, protect_si_suffix='__si', shift_time_by_horizon=False, verbose=1, forecast_horizon=None, _logger=None, coord_scaler=None, fit_coord_scaler=True, other_scaler=None, fit_other_scaler=True, **kws)[source]

Apply Min-Max normalization to spatial-temporal coordinates and optionally to other numeric columns. If cols_to_scale == "auto", automatically select numeric columns excluding categorical and one-hot features.

By default, this function scales the time, longitude, and latitude columns (if scale_coords=True). Then, it either scales explicitly provided columns in cols_to_scale or automatically infers numeric columns (excluding coordinates if scale_coords is False, and excluding one-hot/boolean columns).

The Min-Max scaling for a feature \(x\) is:

\[x' = \frac{x - \min(x)}{\max(x) - \min(x)}\]

Parameters:
  • df (pd.DataFrame) – The input DataFrame containing at least time_col, coord_x, and coord_y columns. The DataFrame should contain temporal and spatial information to be scaled.

  • time_col (str) – The name of the numeric time column (e.g., year as numeric or datetime). This column will be used to adjust and scale the temporal data.

  • coord_x (str) – The name of the longitude column in the DataFrame. This column will be scaled along with the latitude and time columns.

  • coord_y (str) – The name of the latitude column in the DataFrame. This column will be scaled along with the longitude and time columns.

  • cols_to_scale (list of str or "auto" or None, default "auto") – If a list of column names, scales exactly those columns. If "auto", selects all numeric columns, excluding time_col, coord_x, and coord_y if scale_coords=False, and excluding one-hot encoded columns whose values are only {0, 1}. If None, no extra columns are scaled.

  • scale_coords (bool, default True) – If True, scales the [time_col, coord_x, coord_y] columns. If False, these columns remain unchanged.

  • verbose (int, default 1) – Verbosity level for logging. Values higher than 1 provide more detailed logging information.

  • forecast_horizon (Optional[int], default None) – The number of time steps to shift the time column by. This is added to the time values before scaling if provided.

  • _logger (Optional[Union[logging.Logger, Callable[[str], None]]], default None) – Logger or function to handle logging messages. If None, the default logging mechanism is used.

  • kws (dict, optional) – These will be passed on to any other internal function used in the data processing or scaling steps.

  • exclude_cols (list[str] | None)

  • protect_si_suffix (str)

  • shift_time_by_horizon (bool)

  • coord_scaler (MinMaxScaler | None)

  • fit_coord_scaler (bool)

  • other_scaler (MinMaxScaler | None)

  • fit_other_scaler (bool)

Returns:

  • df_scaled (pd.DataFrame) – A new DataFrame with the specified columns normalized.

  • coord_scaler (MinMaxScaler or None) – The fitted scaler for the [time_col, coord_x, coord_y] columns if scale_coords=True, else None.

  • other_scaler (MinMaxScaler or None) – The fitted scaler for any additional columns that were scaled (either explicitly provided or auto-selected). None if no columns were scaled beyond the coordinates.

Raises:
  • TypeError – If df is not a DataFrame, or cols_to_scale is neither a list nor “auto” nor None, or if any explicitly provided column is not a string.

  • ValueError – If required columns (time_col, coord_x, coord_y) or any of cols_to_scale do not exist in df, or cannot be converted to numeric.

Return type:

tuple[DataFrame, MinMaxScaler | None, MinMaxScaler | None]

Examples

>>> import pandas as pd
>>> from geoprior.models.utils import normalize_for_pinn
>>> data = {
...     "year_num": [0.0, 1.0, 2.0],
...     "lon": [100.0, 101.0, 102.0],
...     "lat": [30.0, 31.0, 32.0],
...     "feat1": [10.0, 20.0, 30.0],
...     "one_hot_A": [0, 1, 0]
... }
>>> df = pd.DataFrame(data)
>>> df_scaled, coord_scl, feat_scl = normalize_for_pinn(
...     df,
...     time_col="year_num",
...     coord_x="lon",
...     coord_y="lat",
...     cols_to_scale="auto",
...     scale_coords=True,
...     verbose=2
... )
>>> # 'year_num','lon','lat','feat1' get scaled; 'one_hot_A' excluded
>>> df_scaled["year_num"].tolist()
[0.0, 0.5, 1.0]
>>> df_scaled["feat1"].tolist()
[0.0, 0.5, 1.0]

Notes

  • When cols_to_scale="auto", numeric columns with only {0, 1} values are assumed to be one-hot and excluded from scaling.

  • If scale_coords=False, coordinate columns remain unchanged, and auto-selection (if used) will exclude them.

  • Returned coord_scaler is None if scale_coords=False. Returned other_scaler is None if cols_to_scale is None or results in an empty set after filtering.
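The auto-selection rule and the Min-Max formula can be sketched in plain NumPy, assuming each column is scaled independently (as scikit-learn's MinMaxScaler does per feature) and that a column counts as one-hot when its values are a subset of {0, 1}. is_one_hot and min_max are hypothetical helpers; constant columns map to 0 here, mirroring how MinMaxScaler handles a zero range. The sketch reproduces the values of the documented example.

```python
import numpy as np


def is_one_hot(values):
    """True when a numeric column only contains the values 0 and 1 (sketch)."""
    return set(np.unique(values).tolist()) <= {0, 1}


def min_max(values):
    """x' = (x - min) / (max - min); constant columns map to 0 (sketch)."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        return np.zeros_like(values)
    return (values - values.min()) / span


columns = {
    "year_num": [0.0, 1.0, 2.0],
    "lon": [100.0, 101.0, 102.0],
    "lat": [30.0, 31.0, 32.0],
    "feat1": [10.0, 20.0, 30.0],
    "one_hot_A": [0, 1, 0],
}
# One-hot columns are passed through unchanged; everything else is scaled.
scaled = {
    name: (values if is_one_hot(values) else min_max(values).tolist())
    for name, values in columns.items()
}
print(scaled["year_num"], scaled["feat1"], scaled["one_hot_A"])
```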

See also

sklearn.preprocessing.MinMaxScaler

Scales features to [0,1].

geoprior.models.utils.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)[source]

Formats PINN model predictions into a structured pandas DataFrame.

This is a general-purpose utility for transforming raw model outputs (from models like PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export.

It handles multi-target outputs (e.g., subsidence and GWL), point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse scaling of predictions and evaluation of quantile coverage.

Parameters:
  • predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model’s .predict() method. Keys should match the model’s output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.

  • model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.

  • model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model’s signature, required only if predictions is None. Default is None.

  • y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, an <target>_actual column will be added to the output DataFrame for comparison. Default is None.

  • target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.

  • include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.

  • include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.

  • quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is crucial for correctly parsing probabilistic forecasts. Default is None.

  • forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.

  • output_dims (dict of str to int, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it is inferred from the tensor shapes. Default is None.

  • ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.

  • ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.

  • ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.

  • scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., ‘subsidence’) and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.

  • coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.

  • evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.

  • coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. Default is (0, -1), which corresponds to the full range.

  • savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.

  • name (str or None) – Optional name for the prediction, used to label the formatted output and the coverage result, if applicable.

  • model_name (str or None) – Name of the model.

  • verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).

  • **kwargs (dict) – Additional keyword arguments reserved for future extensions.

  • _logger (Logger | Callable[[str], None] | None)

  • stop_check (Callable[[], bool])

Returns:

A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.

Return type:

pd.DataFrame

Notes

  • The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.

  • For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.

  • For point forecasts, the column is named <target_name>_pred.

See also

geoprior.plot.forecast.plot_forecasts

A powerful utility for visualizing the DataFrame produced by this function.

geoprior.models.utils.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. Can be a single tensor or a dictionary with ‘coords’ or ‘t’, ‘x’, ‘y’ keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input’s dimension. '2d' requires input shaped like (batch, 3). '3d' accepts 3D input and expands 2D input to 3D with a time dimension of 1. '3d_only' requires 3D input and raises an error for 2D input. None accepts both 2D and 3D inputs without changing their rank.

  • verbose (int, default 0) – Controls logging verbosity.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.

geoprior.models.utils.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]

Extracts t, x, y tensors from various input formats.

This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.

Parameters:
  • inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. A single tensor or array may be 2D with shape (batch, 3) or 3D with shape (batch, time_steps, 3). A dictionary may contain a 'coords' key with the coordinate tensor, or separate 't', 'x', and 'y' keys.

  • coord_slice_map (dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.

  • expect_dim ({'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension. '2d' requires input shaped like (batch, 3) or a dictionary of (batch, 1) tensors. '3d' requires input shaped like (batch, time, 3) or a dictionary of (batch, time, 1) tensors. If None, both are accepted and 2D inputs are expanded to 3D.

  • verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.

  • _logger (Logger | Callable[[str], None] | None)

Returns:

t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).

Return type:

Tuple[tf.Tensor, tf.Tensor, tf.Tensor]

Raises:

ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
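The documented output contract of extract_txy_in (2D inputs expanded to 3D, each coordinate returned with a singleton last dimension) can be sketched with NumPy. split_txy is a hypothetical helper that handles only the single-array case with the default coord_slice_map; it is not the library function.

```python
import numpy as np

# Default mapping documented for coord_slice_map.
DEFAULT_SLICE_MAP = {"t": 0, "x": 1, "y": 2}


def split_txy(coords, slice_map=None):
    """Split a (B, 3) or (B, T, 3) coordinate array into 3D t, x, y (sketch)."""
    slice_map = slice_map or DEFAULT_SLICE_MAP
    coords = np.asarray(coords)
    if coords.ndim == 2:  # (B, 3) -> (B, 1, 3): add a time axis of length 1
        coords = coords[:, np.newaxis, :]
    if coords.ndim != 3:
        raise ValueError(f"expected 2D or 3D coords, got ndim={coords.ndim}")
    # idx:idx + 1 keeps a singleton last dimension -> (B, T, 1)
    return tuple(
        coords[..., slice_map[k]:slice_map[k] + 1] for k in ("t", "x", "y")
    )


t, x, y = split_txy(np.zeros((4, 6, 3)))
print(t.shape, x.shape, y.shape)  # (4, 6, 1) (4, 6, 1) (4, 6, 1)
t2, _, _ = split_txy(np.zeros((4, 3)))
print(t2.shape)  # (4, 1, 1)
```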

geoprior.models.utils.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)[source]

Normalize and validate PDE mode selection.

Parameters:
  • pde_mode (str, sequence of str, or None) –

    Requested PDE mode(s).

    Accepted canonical values are:

    • "none"

    • "consolidation"

    • "gw_flow"

    • "both"

    Accepted aliases:

    • None, "off" -> "none"

    • "on" -> "both"

  • enforce_consolidation (bool, default False) –

    If True, any resolved mode other than exact ["consolidation"] is coerced to ["consolidation"] and a warning is emitted.

    This includes:

    • ["none"]

    • ["gw_flow"]

    • ["consolidation", "gw_flow"]

  • pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.

  • solo_return (bool, default False) –

    If False, return a canonical list of active modes.

    If True, return a single canonical label:

    • "none"

    • "consolidation"

    • "gw_flow"

    • "both"

Returns:

Canonical PDE mode(s), either as a list or a single label.

Return type:

list of str or str

Raises:
  • TypeError – If the input type is invalid.

  • ValueError – If a token is unsupported or the mode selection is ambiguous.

Examples

>>> from geoprior.models.utils import process_pde_modes
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
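A simplified, pure-Python sketch of the normalization rules listed above. normalize_pde_mode is hypothetical; it covers only the canonical names and the aliases documented for process_pde_modes (not every token in PDE_MODE_ALIASES), and it assumes that an "ambiguous" selection means combining "none" with an active mode. It reproduces the documented examples.

```python
# Hypothetical sketch, not geoprior's implementation.
CANONICAL_ORDER = ("consolidation", "gw_flow")
TOKEN_MAP = {
    None: "none", "none": "none", "off": "none",
    "on": "both", "both": "both",
    "consolidation": "consolidation", "gw_flow": "gw_flow",
}


def normalize_pde_mode(mode, enforce_consolidation=False, solo_return=False):
    """Resolve PDE mode token(s) to a canonical list or single label."""
    tokens = list(mode) if isinstance(mode, (list, tuple)) else [mode]
    active, saw_none = set(), False
    for token in tokens:
        key = token.strip().lower() if isinstance(token, str) else token
        if key not in TOKEN_MAP:
            raise ValueError(f"unsupported PDE mode token: {token!r}")
        label = TOKEN_MAP[key]
        if label == "none":
            saw_none = True
        elif label == "both":
            active.update(CANONICAL_ORDER)
        else:
            active.add(label)
    if saw_none and active:
        raise ValueError("ambiguous PDE mode: 'none' combined with active modes")
    modes = [m for m in CANONICAL_ORDER if m in active] or ["none"]
    if enforce_consolidation and modes != ["consolidation"]:
        modes = ["consolidation"]  # the documented function also warns here
    if not solo_return:
        return modes
    return "both" if len(modes) == 2 else modes[0]


print(normalize_pde_mode("on"))  # ['consolidation', 'gw_flow']
```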

Subsidence-physics utility layer

The subsidence-physics utility layer is more specialized than the other two layers. It is not simply another helper bucket; it is a coherence layer for unit systems, scaling metadata, coordinate policies, historical states, and hydrogeological interpretation.

GeoPrior subsidence model utilities.

geoprior.models.subsidence.utils.enforce_scaling_alias_consistency(scaling_kwargs, *, where='validate')[source]

Enforce that canonical keys and aliases agree.

If both canonical and an alias exist and their values differ, apply the scaling error policy.

Return type:

None

geoprior.models.subsidence.utils.canonicalize_scaling_kwargs(scaling_kwargs, *, copy=True)[source]

Return a canonicalized scaling dict.

  • If a canonical key is missing, but one of its aliases exists, copy alias -> canonical.

  • Keeps existing canonical values unchanged.

Return type:

dict[str, Any]
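The alias rules of canonicalize_scaling_kwargs and enforce_scaling_alias_consistency can be sketched in a few lines. The alias table and both helper names are hypothetical (the real alias table lives inside the library), and where the library applies its scaling error policy (which may warn instead of raising), this sketch simply raises.

```python
import copy as _copy

# Hypothetical alias table: canonical key -> accepted aliases.
SCALING_ALIASES = {
    "time_units": ("time_unit",),
    "coords_in_degrees": ("coords_deg",),
}


def canonicalize(scaling_kwargs, copy=True):
    """Copy alias values onto missing canonical keys; keep existing ones."""
    out = _copy.deepcopy(scaling_kwargs) if copy else scaling_kwargs
    for canonical, aliases in SCALING_ALIASES.items():
        if canonical in out:
            continue  # existing canonical values are left unchanged
        for alias in aliases:
            if alias in out:
                out[canonical] = out[alias]
                break
    return out


def check_alias_consistency(scaling_kwargs):
    """Raise when a canonical key and one of its aliases disagree (sketch)."""
    for canonical, aliases in SCALING_ALIASES.items():
        for alias in aliases:
            both_present = canonical in scaling_kwargs and alias in scaling_kwargs
            if both_present and scaling_kwargs[canonical] != scaling_kwargs[alias]:
                raise ValueError(f"{canonical!r} disagrees with alias {alias!r}")


print(canonicalize({"time_unit": "year"}))  # {'time_unit': 'year', 'time_units': 'year'}
```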

geoprior.models.subsidence.utils.load_scaling_kwargs(scaling_kwargs, *, copy=True)[source]

Load scaling kwargs from a dict-like object or JSON.

Parameters:
  • scaling_kwargs (Any | None)

  • copy (bool)

Return type:

dict[str, Any]

geoprior.models.subsidence.utils.get_sk(scaling_kwargs, key, *aliases, default=None, required=False, cast=None)[source]

Fetch a key from scaling_kwargs with aliases + default.

  • Tries: key -> built-in aliases -> explicit aliases

  • Treats None and blank strings as “missing” and keeps searching.
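The lookup order of get_sk can be sketched as follows. get_key is a hypothetical stand-in that searches only the key and the explicit aliases (the built-in alias step is omitted) and treats None and blank strings as missing, as described above.

```python
def get_key(scaling_kwargs, key, *aliases, default=None, required=False, cast=None):
    """Return the first usable value for key or its aliases (sketch)."""
    source = scaling_kwargs or {}
    for name in (key, *aliases):
        value = source.get(name)
        if value is None or (isinstance(value, str) and not value.strip()):
            continue  # None and blank strings count as missing
        return cast(value) if cast is not None else value
    if required:
        raise KeyError(f"missing scaling key {key!r} (aliases: {aliases})")
    return default


# 'a' is blank, so the alias 'b' is used and cast to float.
print(get_key({"a": " ", "b": "2"}, "a", "b", cast=float))  # 2.0
```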

geoprior.models.subsidence.utils.validate_scaling_kwargs(scaling_kwargs)[source]

Basic scaling sanity checks.

This includes policy-controlled heuristic checks for common “silent fallback” cases.

Parameters:

scaling_kwargs (dict[str, Any] | None)

Return type:

None

geoprior.models.subsidence.utils.affine_from_cfg(scaling_kwargs, *, scale_key, bias_key, meta_keys=(), unit_key=None)[source]

Return (a,b) for y_si = y_model*a + b.

Return type:

tuple[Tensor, Tensor]

geoprior.models.subsidence.utils.to_si_thickness(H_model, scaling_kwargs)[source]

Convert thickness to SI.

Parameters:
  • H_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.to_si_head(h_model, scaling_kwargs)[source]

Convert head/depth to SI meters.

Parameters:
  • h_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.to_si_subsidence(s_model, scaling_kwargs)[source]

Convert subsidence to SI meters.

Parameters:
  • s_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.from_si_subsidence(s_si, scaling_kwargs)[source]

Inverse of to_si_subsidence: s_model = (s_si - b) / a.

Parameters:
  • s_si (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor
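The affine convention y_si = y_model*a + b and its inverse can be checked with plain floats. to_si and from_si are hypothetical scalar helpers (the library works on tensors), and the millimetre factor a = 1e-3, b = 0 is an assumed example, not a documented default.

```python
def to_si(y_model, a, b):
    """y_si = y_model * a + b (affine form returned by affine_from_cfg)."""
    return y_model * a + b


def from_si(y_si, a, b):
    """Inverse map: y_model = (y_si - b) / a; requires a != 0."""
    return (y_si - b) / a


# Assumed example: a model that predicts subsidence in millimetres.
a, b = 1e-3, 0.0
s_si = to_si(25.0, a, b)       # about 0.025 m
s_model = from_si(s_si, a, b)  # back to about 25.0 (float rounding aside)
print(s_si, s_model)
```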

geoprior.models.subsidence.utils.deg_to_m(axis, scaling_kwargs)[source]

Meters per degree factor for lon/lat coords.

If coords_in_degrees=True and deg_to_m_lon/lat are missing, the function tries to compute them from lat0_deg (providing lat0_deg is recommended).

Return type:

Tensor
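This section does not state the formula used to derive the factors from lat0_deg. A common spherical-Earth approximation is sketched below as an assumption: one degree of latitude spans πR/180 metres, and one degree of longitude shrinks by cos(lat0). meters_per_degree and the radius constant are hypothetical, not library values.

```python
import math

# Assumption: mean Earth radius in metres, not a geoprior constant.
EARTH_RADIUS_M = 6_371_000.0


def meters_per_degree(lat0_deg):
    """Spherical-Earth metres per degree of (longitude, latitude) at lat0."""
    m_per_deg_lat = math.pi * EARTH_RADIUS_M / 180.0
    m_per_deg_lon = m_per_deg_lat * math.cos(math.radians(lat0_deg))
    return m_per_deg_lon, m_per_deg_lat


lon_m, lat_m = meters_per_degree(30.0)
print(round(lon_m), round(lat_m))
```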

geoprior.models.subsidence.utils.coord_ranges(scaling_kwargs)[source]

Return the coordinate ranges (tR, xR, yR) if coords_normalized is set; each entry may be None.

Parameters:

scaling_kwargs (dict[str, Any] | None)

Return type:

tuple[float | None, float | None, float | None]

geoprior.models.subsidence.utils.resolve_gwl_dyn_index(scaling_kwargs)[source]

Resolve GWL channel index for dynamic_features.

Parameters:

scaling_kwargs (dict[str, Any] | None)

Return type:

int

geoprior.models.subsidence.utils.get_gwl_dyn_index_cached(model)[source]

Cache gwl_dyn_index on model after first resolve.

Return type:

int

geoprior.models.subsidence.utils.resolve_subs_dyn_index(scaling_kwargs)[source]

Resolve subsidence channel index for dynamic_features.

This is optional: v3.2 can use historical subsidence as a dynamic driver to provide a physics-friendly initial condition for the mean settlement path.

geoprior.models.subsidence.utils.get_subs_dyn_index_cached(model)[source]

Cache subs_dyn_index on model after first resolve.

Return type:

int

geoprior.models.subsidence.utils.slice_dynamic_channel(Xh, idx)[source]

Slice (B,T,F) -> (B,T,1) at idx.

Parameters:
  • Xh (Tensor)

  • idx (int)

Return type:

Tensor
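The (B,T,F) -> (B,T,1) contract amounts to slicing with idx:idx+1 instead of plain indexing, which would drop the feature axis. The NumPy sketch below shows the shape difference; TensorFlow tensors support the same slicing syntax.

```python
import numpy as np

X = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # (B=2, T=4, F=3)
idx = 1
channel = X[:, :, idx:idx + 1]  # slice keeps the feature axis -> (2, 4, 1)
flat = X[:, :, idx]             # plain indexing drops it -> (2, 4)
print(channel.shape, flat.shape)  # (2, 4, 1) (2, 4)
```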

geoprior.models.subsidence.utils.assert_dynamic_names_match_tensor(Xh, scaling_kwargs)[source]

Check dynamic_feature_names length matches Xh.

Parameters:
  • Xh (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

None

geoprior.models.subsidence.utils.gwl_to_head_m(v_m, scaling_kwargs, *, inputs=None)[source]

Convert depth below ground surface (depth-bgs) to hydraulic head, if possible.

Parameters:
  • v_m (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

  • inputs (dict[str, Tensor] | None)

Return type:

Tensor
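The standard relation behind a depth-to-head conversion is head = ground-surface elevation - depth below ground surface. Which input or scaling key supplies the surface elevation is not documented in this section, so depth_to_head below is a hypothetical sketch that takes it as an explicit argument.

```python
import numpy as np


def depth_to_head(depth_bgs_m, z_surf_m):
    """Hydraulic head (m) = surface elevation (m) - depth below ground (m)."""
    return np.asarray(z_surf_m) - np.asarray(depth_bgs_m)


depth = np.array([[5.0], [7.5]])     # depth to water below ground surface
z_surf = np.array([[12.0], [12.0]])  # hypothetical surface elevation
head = depth_to_head(depth, z_surf)
print(head.ravel())  # heads of 7.0 m and 4.5 m
```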

geoprior.models.subsidence.utils.get_h_hist_si(model, inputs, *, want_head=True)[source]

Return head (or depth) history in SI meters.

Parameters:
  • model (object) – The model instance (provides scaling_kwargs and cached indices).

  • inputs (dict) – Batch inputs; expects dynamic_features unless an explicit head history key is provided.

  • want_head (bool, default True) – If True, convert depth-bgs to hydraulic head when possible.

Returns:

(B,T,1) tensor in SI meters.

Return type:

Tensor

geoprior.models.subsidence.utils.get_s_init_si(model, inputs, like)

Return initial settlement (cumulative subsidence) in SI meters.

Sources are tried in this priority order:

  1) explicit keys in inputs (s_init_si/subs_hist_last_si/…);

  2) the last historical value from dynamic_features, if subs_dyn_index exists;

  3) zeros (broadcast).

Parameters:
  • inputs (dict[str, Tensor] | None)

  • like (Tensor)

Return type:

Tensor
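The three-level priority above can be sketched in plain Python, with lists standing in for tensors. Only two of the explicit keys are named in the documentation (the rest are elided), so the key tuple below is a partial, assumed list; the function name is hypothetical.

```python
from typing import Any, List, Mapping, Optional, Sequence

# Subset of the explicit keys named in the docs; the full list is elided there.
EXPLICIT_KEYS = ("s_init_si", "subs_hist_last_si")


def s_init_sketch(
    inputs: Optional[Mapping[str, Any]],
    subs_hist: Optional[Sequence[Sequence[float]]],
    batch_size: int,
) -> List[float]:
    """Hypothetical fallback chain for the per-sample initial settlement (m)."""
    inputs = inputs or {}
    # 1) An explicit value supplied by the caller wins.
    for key in EXPLICIT_KEYS:
        if key in inputs:
            return [float(v) for v in inputs[key]]
    # 2) Otherwise use the last observed value of each sample's history.
    if subs_hist is not None:
        return [float(series[-1]) for series in subs_hist]
    # 3) Finally, fall back to zeros for the whole batch.
    return [0.0] * batch_size


hist = [[0.01, 0.02, 0.05], [0.00, 0.10, 0.20]]
print(s_init_sketch({}, hist, batch_size=2))  # -> [0.05, 0.2]
```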

geoprior.models.subsidence.utils.get_h_ref_si(model, inputs, like)

Return h_ref in SI meters, broadcast to the shape of like.

Parameters:
  • inputs (dict[str, Tensor] | None)

  • like (Tensor)

Return type:

Tensor

geoprior.models.subsidence.utils.infer_dt_units_from_t(t_BH1, scaling_kwargs, *, eps=1e-12)

Infer the per-step dt, expressed in time_units, from a time tensor t of shape (B, H, 1).

Parameters:
  • t_BH1 (Tensor)

  • scaling_kwargs

  • eps (float, default 1e-12)

Return type:

Tensor
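For one sample, per-step dt can be sketched as consecutive differences of the horizon times, rescaled by the time range when coordinates are normalized and floored at eps. How the real helper treats the first step and normalized time is not documented here; repeating the first difference and multiplying by t_range are assumptions, and all names are hypothetical.

```python
from typing import List, Optional, Sequence


def infer_dt_sketch(
    t: Sequence[float], t_range: Optional[float] = None, eps: float = 1e-12
) -> List[float]:
    """Hypothetical per-step dt for one sample's horizon times t[0..H-1]."""
    if len(t) < 2:
        raise ValueError("need at least two time stamps to infer dt")
    diffs = [b - a for a, b in zip(t[:-1], t[1:])]
    # Assumed convention: reuse the first difference so len(dt) == H.
    dt = [diffs[0]] + diffs
    # Assumed: normalized time is mapped back to time_units via its range.
    scale = 1.0 if t_range is None else float(t_range)
    # Floor at eps so later divisions by dt stay finite.
    return [max(d * scale, eps) for d in dt]


print(infer_dt_sketch([2020.0, 2021.0, 2022.0, 2024.0]))  # -> [1.0, 1.0, 1.0, 2.0]
```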

geoprior.models.subsidence.utils.policy_gate(step, policy, *, warmup_steps=0, ramp_steps=0, dtype=tf.float32)

Return a scalar gate in [0, 1] based on a policy and the current step.

Parameters:
  • step (Tensor) – Global step counter (typically optimizer.iterations).

  • policy ({"always_on","always_off","warmup_off"}) – Gating behavior: always_on returns 1; always_off returns 0; warmup_off returns 0 while step < warmup_steps, then either ramps linearly to 1 over ramp_steps (when ramp_steps > 0) or switches to 1 immediately at warmup_steps (when ramp_steps is 0).

  • warmup_steps (int, default 0) – Number of steps to keep the gate at 0 (only for warmup_off).

  • ramp_steps (int, default 0) – Number of steps for a linear ramp from 0->1 after warmup. If 0, the gate is a hard step.

  • dtype (dtype, default tf_float32) – Output dtype.

Return type:

Tensor
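The documented gate semantics can be rendered as a pure-Python scalar function, which is handy for previewing a schedule before training. The real helper returns a tensor; the exact ramp endpoints here (0 at the end of warmup, 1 after ramp_steps more steps) are an assumption, and the function name is hypothetical.

```python
def policy_gate_sketch(
    step: int, policy: str, *, warmup_steps: int = 0, ramp_steps: int = 0
) -> float:
    """Scalar version of the documented policy gate."""
    if policy == "always_on":
        return 1.0
    if policy == "always_off":
        return 0.0
    if policy != "warmup_off":
        raise ValueError(f"unknown policy: {policy!r}")
    if step < warmup_steps:
        return 0.0  # gate stays closed during warmup
    if ramp_steps <= 0:
        return 1.0  # hard switch at warmup_steps
    # Linear ramp from 0 to 1 over ramp_steps, then held at 1.
    return min(1.0, (step - warmup_steps) / ramp_steps)


schedule = [policy_gate_sketch(s, "warmup_off", warmup_steps=2, ramp_steps=4) for s in range(8)]
print(schedule)  # -> [0.0, 0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

A typical use is multiplying a physics-loss weight by the gate so the data loss can settle before the PDE residual is enforced.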

geoprior.models.subsidence.utils.finalize_scaling_kwargs(sk)

Add derived SI conversion constants to scaling_kwargs.

When possible, the following keys are added:

  • seconds_per_time_unit: float;

  • coord_ranges_si: dict with keys t (seconds) and x/y (meters);

  • coord_inv_ranges_si: the inverse of coord_ranges_si, computed with a safe floor.

Notes

This helper is designed to be called once, when assembling scaling_kwargs (e.g., in your stage2 script), so that the model can reuse these constants instead of recomputing unit conversions in the hot training loop.

Parameters:

sk (dict[str, Any])

Return type:

dict[str, Any]
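The one-off precomputation described above can be sketched as follows. The unit table (with a year taken as 365.25 days), the coord_ranges layout and the floor value are assumptions; the sketch also assumes x/y are already in meters, so degree handling is left out.

```python
from typing import Any, Dict

# Assumed unit table; "year" is taken as a Julian year of 365.25 days.
SECONDS_PER_UNIT = {"second": 1.0, "day": 86_400.0, "year": 365.25 * 86_400.0}


def finalize_sketch(sk: Dict[str, Any], floor: float = 1e-12) -> Dict[str, Any]:
    """Hypothetical precomputation of SI constants for scaling_kwargs."""
    out = dict(sk)  # leave the caller's dict untouched
    spu = SECONDS_PER_UNIT.get(out.get("time_units"))
    if spu is None:
        return out  # unknown unit: nothing can be derived safely
    out["seconds_per_time_unit"] = spu
    ranges = out.get("coord_ranges")
    if isinstance(ranges, dict):
        si = {"t": ranges["t"] * spu, "x": ranges["x"], "y": ranges["y"]}
        out["coord_ranges_si"] = si
        # Floor the denominator so a zero-width range cannot produce inf.
        out["coord_inv_ranges_si"] = {
            k: 1.0 / max(abs(v), floor) for k, v in si.items()
        }
    return out


raw = {"time_units": "year", "coord_ranges": {"t": 10.0, "x": 2_000.0, "y": 500.0}}
sk = finalize_sketch(raw)
print(sk["coord_ranges_si"]["t"], sk["coord_inv_ranges_si"]["x"])
```

Storing inverse ranges lets the training step multiply instead of divide, and the floor keeps a degenerate range from injecting inf into gradients.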

geoprior.models.subsidence.utils.coord_ranges_si(sk)

Return coordinate spans in SI (t in seconds; x/y in meters).

If coord_ranges_si is present in sk, it is used directly. Otherwise, this is computed from coord_ranges and time_units (and degree-to-meter factors when applicable).

Parameters:

sk (dict[str, Any])

Return type:

tuple[float | None, float | None, float | None]
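The lookup order described above (precomputed spans first, then derivation from coord_ranges, time_units and, for degree coordinates, the deg_to_m_lon/lat factors) can be sketched as below. The unit table and the dict layout of coord_ranges are assumptions, and the function name is hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Optional, Tuple

# Assumed unit table for this sketch (a year = 365.25 days).
SECONDS_PER_UNIT = {"day": 86_400.0, "year": 365.25 * 86_400.0}
Span = Optional[float]


def coord_ranges_si_sketch(sk: Mapping[str, Any]) -> Tuple[Span, Span, Span]:
    """Hypothetical: prefer precomputed coord_ranges_si, else derive it."""
    pre = sk.get("coord_ranges_si")
    if isinstance(pre, Mapping):
        return pre.get("t"), pre.get("x"), pre.get("y")

    ranges = sk.get("coord_ranges") or {}
    spu = SECONDS_PER_UNIT.get(sk.get("time_units"))
    t, x, y = ranges.get("t"), ranges.get("x"), ranges.get("y")
    t_si = None if t is None or spu is None else t * spu

    if sk.get("coords_in_degrees", False):
        fx, fy = sk.get("deg_to_m_lon"), sk.get("deg_to_m_lat")
        x = None if x is None or fx is None else x * fx
        y = None if y is None or fy is None else y * fy
    return t_si, x, y


# Degree-based coordinates with explicit meters-per-degree factors.
sk = {
    "time_units": "day",
    "coord_ranges": {"t": 30.0, "x": 0.5, "y": 0.25},
    "coords_in_degrees": True,
    "deg_to_m_lon": 100_000.0,
    "deg_to_m_lat": 111_000.0,
}
print(coord_ranges_si_sketch(sk))  # -> (2592000.0, 50000.0, 27750.0)
```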

What this module covers

The current implementation includes several strongly related subfamilies of helpers:

  • scaling canonicalization and policy enforcement such as canonicalize_scaling_kwargs, validate_scaling_kwargs, enforce_scaling_alias_consistency, and finalize_scaling_kwargs;

  • SI conversion helpers such as to_si_thickness, to_si_head, to_si_subsidence, and from_si_subsidence;

  • coordinate helpers such as deg_to_m, coord_ranges, and coord_ranges_si;

  • dynamic-channel resolution such as resolve_gwl_dyn_index, resolve_subs_dyn_index, get_gwl_dyn_index_cached, and get_subs_dyn_index_cached;

  • groundwater/head reconciliation such as gwl_to_head_m and get_h_ref_si;

  • history and reference-state extraction such as get_h_hist_si and get_s_init_si.

Why this module matters

A large part of the difficulty in physics-informed forecasting lies not in the PDE term alone but in keeping units, metadata, and conventions consistent. This module makes those conventions explicit. It handles alias drift in scaling keys, coordinates recorded in degrees versus projected meters, the depth-versus-head ambiguity for groundwater levels, and the lookup of historical channels used to initialize or interpret model states.

Selected subsidence-physics helpers

geoprior.models.subsidence.utils.canonicalize_scaling_kwargs(scaling_kwargs, *, copy=True)

Return a canonicalized scaling dict.

  • If a canonical key is missing but one of its aliases exists, the alias value is copied to the canonical key.

  • Existing canonical values are kept unchanged.

Parameters:
  • scaling_kwargs

  • copy (bool, default True)

Return type:

dict[str, Any]
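The two rules above can be sketched with a small excerpt of the _SK_ALIASES table shown in the source listing. The function name is hypothetical, and "first alias in table order wins" is an assumed tie-break for inputs that carry several aliases.

```python
from typing import Any, Dict, Optional, Tuple

# A small excerpt of the _SK_ALIASES table from the source listing.
ALIASES: Dict[str, Tuple[str, ...]] = {
    "time_units": ("time_unit",),
    "coords_normalized": ("coord_normalized", "coords_norm"),
    "gwl_dyn_index": ("gwl_index", "gwl_feature_index", "gwl_channel_index"),
}


def canonicalize_sketch(
    sk: Optional[Dict[str, Any]], *, copy: bool = True
) -> Dict[str, Any]:
    """Copy alias -> canonical when the canonical key is missing."""
    if sk is None:
        sk = {}
    out = dict(sk) if copy else sk  # copy=False edits the caller's dict
    for canon, aliases in ALIASES.items():
        if canon in out:
            continue  # existing canonical values are never overwritten
        for alias in aliases:
            if alias in out:
                out[canon] = out[alias]
                break  # first alias in table order wins (assumed tie-break)
    return out


raw = {"time_unit": "year", "coords_norm": True, "gwl_dyn_index": 1, "gwl_index": 2}
sk = canonicalize_sketch(raw)
print(sk["time_units"], sk["coords_normalized"], sk["gwl_dyn_index"])  # -> year True 1
```

Note that gwl_dyn_index keeps its canonical value 1 even though the alias gwl_index says 2; detecting that kind of disagreement is the job of enforce_scaling_alias_consistency.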

geoprior.models.subsidence.utils.validate_scaling_kwargs(scaling_kwargs)

Basic scaling sanity checks.

This includes policy-controlled heuristic checks for common “silent fallback” cases.

Parameters:

scaling_kwargs (dict[str, Any] | None)

Return type:

None

geoprior.models.subsidence.utils.enforce_scaling_alias_consistency(scaling_kwargs, *, where='validate')

Enforce that canonical keys and their aliases agree.

If both a canonical key and one of its aliases are present with different values, the scaling error policy is applied.

Parameters:
  • scaling_kwargs

  • where (str, default 'validate')

Return type:

None
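A minimal sketch of conflict detection between canonical keys and aliases. The policy names "raise", "warn" and "ignore" are assumptions (the documentation only refers to "the scaling error policy"), and the where argument is used here just to tag the message.

```python
import warnings
from typing import Any, Dict, Mapping, Tuple

# Excerpt of the alias table; the policy names below are assumptions.
ALIASES: Dict[str, Tuple[str, ...]] = {
    "time_units": ("time_unit",),
    "coord_ranges": ("coord_range",),
}


def enforce_alias_consistency_sketch(
    sk: Mapping[str, Any], *, policy: str = "raise", where: str = "validate"
) -> None:
    """Report canonical/alias pairs whose values disagree."""
    for canon, aliases in ALIASES.items():
        if canon not in sk:
            continue
        for alias in aliases:
            if alias in sk and sk[alias] != sk[canon]:
                msg = (
                    f"[{where}] {canon}={sk[canon]!r} conflicts with "
                    f"alias {alias}={sk[alias]!r}"
                )
                if policy == "raise":
                    raise ValueError(msg)
                if policy == "warn":
                    warnings.warn(msg)
                # Any other policy (e.g. "ignore") accepts the conflict.


# Canonical key and alias agree, so nothing is reported.
enforce_alias_consistency_sketch({"time_units": "year", "time_unit": "year"})
```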

geoprior.models.subsidence.utils.to_si_thickness(H_model, scaling_kwargs)

Convert thickness to SI meters.

Parameters:
  • H_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.to_si_head(h_model, scaling_kwargs)

Convert head/depth to SI meters.

Parameters:
  • h_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.to_si_subsidence(s_model, scaling_kwargs)

Convert subsidence to SI meters.

Parameters:
  • s_model (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.from_si_subsidence(s_si, scaling_kwargs)

Inverse of to_si_subsidence: if to_si_subsidence applies s_si = a * s_model + b, this returns s_model = (s_si - b) / a.

Parameters:
  • s_si (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

Return type:

Tensor
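The affine pair can be checked with a round trip. Where a and b are stored in scaling_kwargs is not shown in this section, so they are passed explicitly here; for a standardized target, a would typically be the standard deviation and b the mean. The numbers and function names are hypothetical.

```python
import math


def to_si_sketch(s_model: float, a: float, b: float) -> float:
    """Forward map implied by the documented inverse: s_si = a * s_model + b."""
    return a * s_model + b


def from_si_sketch(s_si: float, a: float, b: float) -> float:
    """Documented inverse: s_model = (s_si - b) / a (a must be nonzero)."""
    return (s_si - b) / a


# Hypothetical standardization: std 0.012 m, mean 0.035 m.
a, b = 0.012, 0.035
z = 1.5                        # value in model (standardized) units
s_si = to_si_sketch(z, a, b)   # 0.035 + 1.5 * 0.012 = 0.053 m
back = from_si_sketch(s_si, a, b)
print(round(s_si, 6), math.isclose(back, z))  # -> 0.053 True
```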

geoprior.models.subsidence.utils.deg_to_m(axis, scaling_kwargs)

Meters-per-degree factor for lon/lat coordinates.

If coords_in_degrees=True and deg_to_m_lon/lat are missing, the helper tries to compute them from lat0_deg (supplying lat0_deg is recommended).

Parameters:
  • axis

  • scaling_kwargs

Return type:

Tensor

geoprior.models.subsidence.utils.resolve_gwl_dyn_index(scaling_kwargs)

Resolve the groundwater-level (GWL) channel index within dynamic_features.

Parameters:

scaling_kwargs (dict[str, Any] | None)

Return type:

int

geoprior.models.subsidence.utils.resolve_subs_dyn_index(scaling_kwargs)

Resolve the subsidence channel index within dynamic_features.

This is optional: v3.2 can use historical subsidence as a dynamic driver to provide a physics-friendly initial condition for the mean settlement path.

geoprior.models.subsidence.utils.gwl_to_head_m(v_m, scaling_kwargs, *, inputs=None)

Convert depth below ground surface (depth-bgs) to hydraulic head, when possible.

Parameters:
  • v_m (Tensor)

  • scaling_kwargs (dict[str, Any] | None)

  • inputs (dict[str, Tensor] | None)

Return type:

Tensor

geoprior.models.subsidence.utils.get_h_hist_si(model, inputs, *, want_head=True)

Return head (or depth) history in SI meters.

Parameters:
  • model (object) – The model instance (provides scaling_kwargs and cached indices).

  • inputs (dict) – Batch inputs; expects dynamic_features unless an explicit head history key is provided.

  • want_head (bool, default True) – If True, convert depth-bgs to hydraulic head when possible.

Returns:

(B,T,1) tensor in SI meters.

Return type:

Tensor

geoprior.models.subsidence.utils.get_s_init_si(model, inputs, like)

Return initial settlement (cumulative subsidence) in SI meters.

Sources are tried in this priority order:

  1) explicit keys in inputs (s_init_si/subs_hist_last_si/…);

  2) the last historical value from dynamic_features, if subs_dyn_index exists;

  3) zeros (broadcast).

Parameters:
  • inputs (dict[str, Tensor] | None)

  • like (Tensor)

Return type:

Tensor

geoprior.models.subsidence.utils.get_h_ref_si(model, inputs, like)

Return h_ref in SI meters, broadcast to the shape of like.

Parameters:
  • inputs (dict[str, Tensor] | None)

  • like (Tensor)

Return type:

Tensor

geoprior.models.subsidence.utils.policy_gate(step, policy, *, warmup_steps=0, ramp_steps=0, dtype=tf.float32)

Return a scalar gate in [0, 1] based on a policy and the current step.

Parameters:
  • step (Tensor) – Global step counter (typically optimizer.iterations).

  • policy ({"always_on","always_off","warmup_off"}) – Gating behavior: always_on returns 1; always_off returns 0; warmup_off returns 0 while step < warmup_steps, then either ramps linearly to 1 over ramp_steps (when ramp_steps > 0) or switches to 1 immediately at warmup_steps (when ramp_steps is 0).

  • warmup_steps (int, default 0) – Number of steps to keep the gate at 0 (only for warmup_off).

  • ramp_steps (int, default 0) – Number of steps for a linear ramp from 0->1 after warmup. If 0, the gate is a hard step.

  • dtype (dtype, default tf_float32) – Output dtype.

Return type:

Tensor

geoprior.models.subsidence.utils.finalize_scaling_kwargs(sk)

Add derived SI conversion constants to scaling_kwargs.

When possible, the following keys are added:

  • seconds_per_time_unit: float;

  • coord_ranges_si: dict with keys t (seconds) and x/y (meters);

  • coord_inv_ranges_si: the inverse of coord_ranges_si, computed with a safe floor.

Notes

This helper is designed to be called once, when assembling scaling_kwargs (e.g., in your stage2 script), so that the model can reuse these constants instead of recomputing unit conversions in the hot training loop.

Parameters:

sk (dict[str, Any])

Return type:

dict[str, Any]

geoprior.models.subsidence.utils.coord_ranges_si(sk)

Return coordinate spans in SI (t in seconds; x/y in meters).

If coord_ranges_si is present in sk, it is used directly. Otherwise, this is computed from coord_ranges and time_units (and degree-to-meter factors when applicable).

Parameters:

sk (dict[str, Any])

Return type:

tuple[float | None, float | None, float | None]

Connections to the scientific stack

These utility layers help explain how GeoPrior-v3 connects workflow orchestration, forecasting abstractions, and physics-aware subsidence modeling.

In particular:

  • the workflow layer explains how artifacts, staged runs, evaluation tables, and calibrated outputs are prepared;

  • the model layer explains how generic forecasting utilities interact with sequence models and PINN-style helper logic;

  • the subsidence-physics layer explains how scaling metadata, hydraulic-head interpretation, and state initialization stay coherent when models are applied to land-subsidence problems.

For readers coming from a forecasting background, the geoprior.models.utils._utils docstrings are useful because they frame several helpers explicitly in the language of sequence forecasting and Temporal Fusion Transformer workflows. For readers coming from hydrogeology or subsidence physics, the geoprior.models.subsidence.utils layer is the most direct explanation of how physically meaningful quantities are recovered from model-side tensors and metadata.

Source listings with comments

These source listings are included intentionally so the inline implementation comments remain visible in the documentation.

Top-level utility exports

geoprior/utils/__init__.py
r"""Public exports for GeoPrior utility helpers."""

from .audit_utils import (
    audit_stage1_scaling,
    audit_stage2_handshake,
    should_audit,
)
from .calibrate import calibrate_quantile_forecasts
from .data_utils import (
    mask_by_reference,
    nan_ops,
    widen_temporal_columns,
)
from .forecast_utils import (
    evaluate_forecast,
    format_and_forecast,
    pivot_forecast_dataframe,
)
from .generic_utils import (
    default_results_dir,
    ensure_directory_exists,
    getenv_stripped,
    normalize_time_column,
    print_config_table,
    save_all_figures,
)
from .geo_utils import (
    augment_city_spatiotemporal_data,
    augment_series_features,
    augment_spatiotemporal_data,
    generate_dummy_pinn_data,
    unpack_frames_from_file,
)
from .holdout_utils import (
    compute_group_masks,
    filter_df_by_groups,
    split_groups_holdout,
)
from .inspect import (
    ArtifactRecord,
    ablation_config_frame,
    ablation_metrics_frame,
    ablation_per_horizon_frame,
    ablation_record_flags_frame,
    ablation_record_runs_frame,
    artifact_brief,
    as_path,
    bool_checks_frame,
    build_stage1_feature_split,
    calibration_stats_factors_frame,
    calibration_stats_overall_frame,
    calibration_stats_per_horizon_frame,
    clone_artifact,
    deep_update,
    default_ablation_record_payload,
    default_calibration_stats_payload,
    default_eval_diagnostics_payload,
    default_eval_physics_payload,
    default_manifest_payload,
    default_model_init_manifest_payload,
    default_physics_payload_meta_payload,
    default_run_manifest_payload,
    default_scaling_kwargs_payload,
    default_stage1_audit_payload,
    default_stage2_handshake_payload,
    default_training_summary_payload,
    default_xfer_results_payload,
    ensure_parent_dir,
    eval_overall_frame,
    eval_per_horizon_frame,
    eval_physics_calibration_frame,
    eval_physics_calibration_per_horizon_frame,
    eval_physics_censor_frame,
    eval_physics_metrics_frame,
    eval_physics_per_horizon_frame,
    eval_physics_point_metrics_frame,
    eval_physics_units_frame,
    eval_years_frame,
    flatten_dict,
    generate_ablation_record,
    generate_calibration_stats,
    generate_eval_diagnostics,
    generate_eval_physics,
    generate_manifest,
    generate_model_init_manifest,
    generate_physics_payload_meta,
    generate_run_manifest,
    generate_scaling_kwargs,
    generate_stage1_audit,
    generate_stage2_handshake,
    generate_training_summary,
    generate_xfer_results,
    infer_artifact_kind,
    inspect_ablation_record,
    inspect_calibration_stats,
    inspect_eval_diagnostics,
    inspect_eval_physics,
    inspect_manifest,
    inspect_model_init_manifest,
    inspect_physics_payload_meta,
    inspect_run_manifest,
    inspect_scaling_kwargs,
    inspect_stage1_audit,
    inspect_stage2_handshake,
    inspect_training_summary,
    inspect_xfer_results,
    is_number,
    json_ready,
    load_ablation_record,
    load_artifact,
    load_calibration_stats,
    load_eval_diagnostics,
    load_eval_physics,
    load_manifest,
    load_model_init_manifest,
    load_physics_payload_meta,
    load_run_manifest,
    load_scaling_kwargs,
    load_stage1_audit,
    load_stage2_handshake,
    load_training_summary,
    load_xfer_results,
    manifest_artifacts_frame,
    manifest_config_frame,
    manifest_feature_groups_frame,
    manifest_holdout_frame,
    manifest_identity_frame,
    manifest_paths_frame,
    manifest_shapes_frame,
    manifest_versions_frame,
    metrics_frame,
    model_init_architecture_frame,
    model_init_dims_frame,
    model_init_feature_groups_frame,
    model_init_geoprior_frame,
    model_init_scaling_overview_frame,
    nested_get,
    numeric_items,
    physics_payload_meta_closure_frame,
    physics_payload_meta_identity_frame,
    physics_payload_meta_metrics_frame,
    physics_payload_meta_units_frame,
    plot_ablation_boolean_summary,
    plot_ablation_lambda_weights,
    plot_ablation_metric_by_variant,
    plot_ablation_per_horizon_metric,
    plot_ablation_run_counts,
    plot_ablation_top_variants,
    plot_boolean_checks,
    plot_calibration_boolean_summary,
    plot_calibration_factors,
    plot_calibration_overall_metrics,
    plot_calibration_per_horizon_coverage,
    plot_calibration_per_horizon_sharpness,
    plot_eval_boolean_summary,
    plot_eval_overall_metrics,
    plot_eval_per_horizon_metrics,
    plot_eval_physics_boolean_summary,
    plot_eval_physics_calibration_factors,
    plot_eval_physics_epsilons,
    plot_eval_physics_metrics,
    plot_eval_physics_per_horizon_metrics,
    plot_eval_physics_point_metrics,
    plot_eval_year_metric_trend,
    plot_manifest_artifact_inventory,
    plot_manifest_boolean_summary,
    plot_manifest_coord_ranges,
    plot_manifest_feature_group_sizes,
    plot_manifest_holdout_counts,
    plot_metric_bars,
    plot_model_init_architecture,
    plot_model_init_boolean_summary,
    plot_model_init_dims,
    plot_model_init_feature_group_sizes,
    plot_model_init_geoprior,
    plot_physics_payload_meta_boolean_summary,
    plot_physics_payload_meta_core_scalars,
    plot_physics_payload_meta_payload_metrics,
    plot_run_manifest_boolean_summary,
    plot_run_manifest_coord_ranges,
    plot_run_manifest_feature_group_sizes,
    plot_run_manifest_path_inventory,
    plot_scaling_kwargs_affine_maps,
    plot_scaling_kwargs_boolean_summary,
    plot_scaling_kwargs_bounds,
    plot_scaling_kwargs_coord_ranges,
    plot_scaling_kwargs_feature_group_sizes,
    plot_scaling_kwargs_schedule_scalars,
    plot_series_map,
    plot_stage1_boolean_summary,
    plot_stage1_coord_ranges,
    plot_stage1_feature_split,
    plot_stage1_target_stats,
    plot_stage1_variable_stats,
    plot_stage2_boolean_summary,
    plot_stage2_coord_range_errors,
    plot_stage2_coord_stats,
    plot_stage2_finite_ratios,
    plot_stage2_sample_sizes,
    plot_stage2_scaling_summary,
    plot_training_best_metrics,
    plot_training_boolean_summary,
    plot_training_final_metrics,
    plot_training_loss_family,
    plot_training_metric_deltas,
    plot_xfer_boolean_summary,
    plot_xfer_direction_metric,
    plot_xfer_overall_metrics,
    plot_xfer_per_horizon_metrics,
    plot_xfer_schema_counts,
    read_json,
    run_manifest_artifacts_frame,
    run_manifest_config_frame,
    run_manifest_identity_frame,
    run_manifest_paths_frame,
    run_manifest_scaling_overview_frame,
    scaling_kwargs_affine_frame,
    scaling_kwargs_bounds_frame,
    scaling_kwargs_coord_frame,
    scaling_kwargs_feature_channels_frame,
    scaling_kwargs_schedule_frame,
    stage1_coord_ranges_frame,
    stage1_feature_split_frame,
    stage1_stats_frame,
    stage2_coord_range_frame,
    stage2_coord_stats_frame,
    stage2_finite_frame,
    stage2_layout_frame,
    stage2_scaling_frame,
    summarize_ablation_record,
    summarize_calibration_stats,
    summarize_eval_diagnostics,
    summarize_eval_physics,
    summarize_manifest,
    summarize_model_init_manifest,
    summarize_physics_payload_meta,
    summarize_run_manifest,
    summarize_scaling_kwargs,
    summarize_stage1_audit,
    summarize_stage2_handshake,
    summarize_training_summary,
    summarize_xfer_results,
    training_compile_frame,
    training_env_frame,
    training_hp_frame,
    training_metrics_frame,
    training_paths_frame,
    write_json,
    xfer_overall_frame,
    xfer_per_horizon_frame,
    xfer_schema_frame,
    xfer_warm_frame,
)
from .io_utils import (
    fetch_joblib_data,
    save_job,
)
from .nat_utils import (
    best_epoch_and_metrics,
    build_censor_mask,
    ensure_input_shapes,
    extract_preds,
    load_nat_config,
    load_nat_config_payload,
    load_scaler_info,
    make_tf_dataset,
    map_targets_for_training,
    name_of,
    resolve_hybrid_config,
    resolve_si_affine,
    save_ablation_record,
    serialize_subs_params,
    subs_point_from_out,
)
from .parallel_utils import (
    apply_gpu_env,
    apply_tf_threading,
    apply_thread_env,
    pick_gpu_id,
    resolve_device,
    resolve_gpu_ids,
    resolve_n_jobs,
    threads_per_job,
)
from .scale_metrics import (
    evaluate_point_forecast,
    inverse_scale_target,
)
from .sequence_utils import build_future_sequences_npz
from .shapes import canonicalize_BHQO
from .spatial_utils import (
    create_spatial_clusters,
    deg_to_m_from_lat,
    spatial_sampling,
)
from .subsidence_utils import (
    convert_eval_payload_units,
    cumulative_to_rate,
    make_txy_coords,
    normalize_gwl_alias,
    postprocess_eval_json,
    rate_to_cumulative,
    resolve_gwl_for_physics,
    resolve_head_column,
)

__all__ = [
    "spatial_sampling",
    "create_spatial_clusters",
    "augment_city_spatiotemporal_data",
    "augment_series_features",
    "generate_dummy_pinn_data",
    "augment_spatiotemporal_data",
    "mask_by_reference",
    "nan_ops",
    "unpack_frames_from_file",
    "widen_temporal_columns",
    "pivot_forecast_dataframe",
    "fetch_joblib_data",
    "save_job",
    "normalize_time_column",
    "convert_eval_payload_units",
    "postprocess_eval_json",
    "evaluate_point_forecast",
    "inverse_scale_target",
    "deg_to_m_from_lat",
    "canonicalize_BHQO",
    "calibrate_quantile_forecasts",
    "audit_stage2_handshake",
    "audit_stage1_scaling",
    "should_audit",
    "format_and_forecast",
    "evaluate_forecast",
    "default_results_dir",
    "ensure_directory_exists",
    "getenv_stripped",
    "print_config_table",
    "save_all_figures",
    "build_censor_mask",
    "ensure_input_shapes",
    "extract_preds",
    "load_nat_config",
    "load_nat_config_payload",
    "load_scaler_info",
    "make_tf_dataset",
    "map_targets_for_training",
    "name_of",
    "resolve_hybrid_config",
    "resolve_si_affine",
    "best_epoch_and_metrics",
    "subs_point_from_out",
    "serialize_subs_params",
    "save_ablation_record",
    "cumulative_to_rate",
    "normalize_gwl_alias",
    "rate_to_cumulative",
    "resolve_gwl_for_physics",
    "resolve_head_column",
    "make_txy_coords",
    "build_future_sequences_npz",
    "compute_group_masks",
    "split_groups_holdout",
    "filter_df_by_groups",
    "resolve_n_jobs",
    "threads_per_job",
    "apply_tf_threading",
    "apply_thread_env",
    "resolve_device",
    "resolve_gpu_ids",
    "pick_gpu_id",
    "apply_gpu_env",
    "ArtifactRecord",
    "artifact_brief",
    "as_path",
    "bool_checks_frame",
    "clone_artifact",
    "deep_update",
    "ensure_parent_dir",
    "flatten_dict",
    "infer_artifact_kind",
    "is_number",
    "json_ready",
    "load_artifact",
    "metrics_frame",
    "nested_get",
    "numeric_items",
    "plot_boolean_checks",
    "plot_metric_bars",
    "plot_series_map",
    "read_json",
    "write_json",
    "build_stage1_feature_split",
    "default_stage1_audit_payload",
    "generate_stage1_audit",
    "inspect_stage1_audit",
    "load_stage1_audit",
    "plot_stage1_boolean_summary",
    "plot_stage1_coord_ranges",
    "plot_stage1_feature_split",
    "plot_stage1_target_stats",
    "plot_stage1_variable_stats",
    "stage1_coord_ranges_frame",
    "stage1_feature_split_frame",
    "stage1_stats_frame",
    "summarize_stage1_audit",
    "default_stage2_handshake_payload",
    "generate_stage2_handshake",
    "inspect_stage2_handshake",
    "load_stage2_handshake",
    "plot_stage2_boolean_summary",
    "plot_stage2_coord_range_errors",
    "plot_stage2_coord_stats",
    "plot_stage2_finite_ratios",
    "plot_stage2_sample_sizes",
    "plot_stage2_scaling_summary",
    "stage2_coord_range_frame",
    "stage2_coord_stats_frame",
    "stage2_finite_frame",
    "stage2_layout_frame",
    "stage2_scaling_frame",
    "summarize_stage2_handshake",
    "default_training_summary_payload",
    "generate_training_summary",
    "inspect_training_summary",
    "load_training_summary",
    "plot_training_best_metrics",
    "plot_training_boolean_summary",
    "plot_training_final_metrics",
    "plot_training_loss_family",
    "plot_training_metric_deltas",
    "training_compile_frame",
    "training_env_frame",
    "training_hp_frame",
    "training_metrics_frame",
    "training_paths_frame",
    "summarize_training_summary",
    "default_eval_diagnostics_payload",
    "eval_overall_frame",
    "eval_per_horizon_frame",
    "eval_years_frame",
    "generate_eval_diagnostics",
    "inspect_eval_diagnostics",
    "load_eval_diagnostics",
    "plot_eval_boolean_summary",
    "plot_eval_overall_metrics",
    "plot_eval_per_horizon_metrics",
    "plot_eval_year_metric_trend",
    "summarize_eval_diagnostics",
    "default_eval_physics_payload",
    "eval_physics_calibration_frame",
    "eval_physics_calibration_per_horizon_frame",
    "eval_physics_censor_frame",
    "eval_physics_metrics_frame",
    "eval_physics_per_horizon_frame",
    "eval_physics_point_metrics_frame",
    "eval_physics_units_frame",
    "generate_eval_physics",
    "inspect_eval_physics",
    "load_eval_physics",
    "plot_eval_physics_boolean_summary",
    "plot_eval_physics_calibration_factors",
    "plot_eval_physics_epsilons",
    "plot_eval_physics_metrics",
    "plot_eval_physics_per_horizon_metrics",
    "plot_eval_physics_point_metrics",
    "summarize_eval_physics",
    "default_physics_payload_meta_payload",
    "generate_physics_payload_meta",
    "inspect_physics_payload_meta",
    "load_physics_payload_meta",
    "physics_payload_meta_closure_frame",
    "physics_payload_meta_identity_frame",
    "physics_payload_meta_metrics_frame",
    "physics_payload_meta_units_frame",
    "plot_physics_payload_meta_boolean_summary",
    "plot_physics_payload_meta_core_scalars",
    "plot_physics_payload_meta_payload_metrics",
    "summarize_physics_payload_meta",
    "default_scaling_kwargs_payload",
    "generate_scaling_kwargs",
    "inspect_scaling_kwargs",
    "load_scaling_kwargs",
    "plot_scaling_kwargs_affine_maps",
    "plot_scaling_kwargs_boolean_summary",
    "plot_scaling_kwargs_bounds",
    "plot_scaling_kwargs_coord_ranges",
    "plot_scaling_kwargs_feature_group_sizes",
    "plot_scaling_kwargs_schedule_scalars",
    "scaling_kwargs_affine_frame",
    "scaling_kwargs_bounds_frame",
    "scaling_kwargs_coord_frame",
    "scaling_kwargs_feature_channels_frame",
    "scaling_kwargs_schedule_frame",
    "summarize_scaling_kwargs",
    "default_model_init_manifest_payload",
    "generate_model_init_manifest",
    "inspect_model_init_manifest",
    "load_model_init_manifest",
    "model_init_architecture_frame",
    "model_init_dims_frame",
    "model_init_feature_groups_frame",
    "model_init_geoprior_frame",
    "model_init_scaling_overview_frame",
    "plot_model_init_architecture",
    "plot_model_init_boolean_summary",
    "plot_model_init_dims",
    "plot_model_init_feature_group_sizes",
    "plot_model_init_geoprior",
    "summarize_model_init_manifest",
    "default_run_manifest_payload",
    "generate_run_manifest",
    "inspect_run_manifest",
    "load_run_manifest",
    "plot_run_manifest_boolean_summary",
    "plot_run_manifest_coord_ranges",
    "plot_run_manifest_feature_group_sizes",
    "plot_run_manifest_path_inventory",
    "run_manifest_artifacts_frame",
    "run_manifest_config_frame",
    "run_manifest_identity_frame",
    "run_manifest_paths_frame",
    "run_manifest_scaling_overview_frame",
    "summarize_run_manifest",
    "default_manifest_payload",
    "generate_manifest",
    "inspect_manifest",
    "load_manifest",
    "manifest_artifacts_frame",
    "manifest_config_frame",
    "manifest_feature_groups_frame",
    "manifest_holdout_frame",
    "manifest_identity_frame",
    "manifest_paths_frame",
    "manifest_shapes_frame",
    "manifest_versions_frame",
    "plot_manifest_artifact_inventory",
    "plot_manifest_boolean_summary",
    "plot_manifest_coord_ranges",
    "plot_manifest_feature_group_sizes",
    "plot_manifest_holdout_counts",
    "summarize_manifest",
    "default_xfer_results_payload",
    "generate_xfer_results",
    "inspect_xfer_results",
    "load_xfer_results",
    "plot_xfer_boolean_summary",
    "plot_xfer_direction_metric",
    "plot_xfer_overall_metrics",
    "plot_xfer_per_horizon_metrics",
    "plot_xfer_schema_counts",
    "summarize_xfer_results",
    "xfer_overall_frame",
    "xfer_per_horizon_frame",
    "xfer_schema_frame",
    "xfer_warm_frame",
    "calibration_stats_factors_frame",
    "calibration_stats_overall_frame",
    "calibration_stats_per_horizon_frame",
    "default_calibration_stats_payload",
    "generate_calibration_stats",
    "inspect_calibration_stats",
    "load_calibration_stats",
    "plot_calibration_boolean_summary",
    "plot_calibration_factors",
    "plot_calibration_overall_metrics",
    "plot_calibration_per_horizon_coverage",
    "plot_calibration_per_horizon_sharpness",
    "summarize_calibration_stats",
    "ablation_config_frame",
    "ablation_metrics_frame",
    "ablation_per_horizon_frame",
    "ablation_record_flags_frame",
    "ablation_record_runs_frame",
    "default_ablation_record_payload",
    "generate_ablation_record",
    "inspect_ablation_record",
    "load_ablation_record",
    "plot_ablation_boolean_summary",
    "plot_ablation_lambda_weights",
    "plot_ablation_metric_by_variant",
    "plot_ablation_per_horizon_metric",
    "plot_ablation_run_counts",
    "plot_ablation_top_variants",
    "summarize_ablation_record",
]

Model utility exports

geoprior/models/utils/__init__.py
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
# website:https://lkouadio.com
r"""Public exports for model utility helpers."""

from ._utils import (
    compute_anomaly_scores,
    compute_forecast_horizon,
    create_sequences,
    export_keras_losses,
    extract_batches_from_dataset,
    extract_callbacks_from,
    forecast_multi_step,
    forecast_single_step,
    format_predictions,
    format_predictions_to_dataframe,
    generate_forecast,
    generate_forecast_with,
    get_tensor_from,
    make_dict_to_tuple_fn,
    prepare_model_inputs,
    prepare_model_inputs_in,
    prepare_spatial_future_data,
    set_default_params,
    split_static_dynamic,
    squeeze_last_dim_if,
    step_to_long,
)
from .pinn import (
    PDE_MODE_ALIASES,
    extract_txy,
    extract_txy_in,
    format_pihalnet_predictions,
    format_pinn_predictions,
    normalize_for_pinn,
    plot_hydraulic_head,
    prepare_pinn_data_sequences,
    process_pde_modes,
)

__all__ = [
    "PDE_MODE_ALIASES",
    "compute_anomaly_scores",
    "compute_forecast_horizon",
    "create_sequences",
    "extract_batches_from_dataset",
    "extract_callbacks_from",
    "forecast_multi_step",
    "forecast_single_step",
    "format_predictions",
    "format_predictions_to_dataframe",
    "generate_forecast",
    "generate_forecast_with",
    "prepare_model_inputs",
    "prepare_model_inputs_in",
    "prepare_spatial_future_data",
    "set_default_params",
    "split_static_dynamic",
    "squeeze_last_dim_if",
    "step_to_long",
    "export_keras_losses",
    "get_tensor_from",
    "format_pihalnet_predictions",
    "normalize_for_pinn",
    "prepare_pinn_data_sequences",
    "format_pinn_predictions",
    "extract_txy",
    "plot_hydraulic_head",
    "make_dict_to_tuple_fn",
    "extract_txy_in",
    "process_pde_modes",
]

Subsidence utility implementation

geoprior/models/subsidence/utils.py
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
GeoPrior subsidence model utilities.
"""

from __future__ import annotations

import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from warnings import warn

import numpy as np

from .. import KERAS_DEPS

Tensor = KERAS_DEPS.Tensor

tf_float32 = KERAS_DEPS.float32
tf_int32 = KERAS_DEPS.int32

tf_cast = KERAS_DEPS.cast
tf_constant = KERAS_DEPS.constant
tf_debugging = KERAS_DEPS.debugging
tf_equal = KERAS_DEPS.equal
tf_maximum = KERAS_DEPS.maximum
tf_minimum = KERAS_DEPS.minimum
tf_greater_equal = KERAS_DEPS.greater_equal
tf_rank = KERAS_DEPS.rank
tf_cond = KERAS_DEPS.cond
tf_shape = KERAS_DEPS.shape
tf_zeros_like = KERAS_DEPS.zeros_like
tf_ones = KERAS_DEPS.ones
tf_greater = KERAS_DEPS.greater
tf_concat = KERAS_DEPS.concat
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_ones_like = KERAS_DEPS.ones_like
tf_less_equal = KERAS_DEPS.less_equal
tf_abs = KERAS_DEPS.abs
tf_print = KERAS_DEPS.print
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile


_EPSILON = 1e-12
# ---------------------------------------------------------------------
# Scaling kwargs access helpers (alias-safe)
# ---------------------------------------------------------------------
_SK_ALIASES = {
    # common naming drift
    "time_units": ("time_unit",),
    "cons_residual_units": ("cons_residual_unit",),
    # policy drift
    "scaling_error_policy": (
        "error_policy",
        "scaling_policy",
    ),
    # coord drift
    "coords_normalized": (
        "coord_normalized",
        "coords_norm",
    ),
    "coords_in_degrees": (
        "coord_in_degrees",
        "coords_deg",
    ),
    "coord_order": ("coords_order",),
    "coord_ranges": ("coord_range",),
    # feature-name list drift
    "dynamic_feature_names": (
        "dynamic_features_names",
        "dyn_feature_names",
    ),
    "future_feature_names": (
        "future_features_names",
        "fut_feature_names",
    ),
    "static_feature_names": (
        "static_features_names",
        "stat_feature_names",
    ),
    # feature-channel naming drift
    "gwl_col": (
        "gwl_dyn_name",
        "gwl_dyn_col",
        "gwl_name",
    ),
    "subs_dyn_name": (
        "subs_col",
        "subs_dyn_col",
        "subsidence_dyn_name",
    ),
    # feature-channel index drift
    "gwl_dyn_index": (
        "gwl_index",
        "gwl_feature_index",
        "gwl_channel_index",
    ),
    "subs_dyn_index": (
        "subs_index",
        "subs_feature_index",
        "subs_channel_index",
    ),
    # z_surf drift
    "z_surf_col": (
        "z_surf_key",
        "z_surf_name",
    ),
    # bounds drift (often nested under scaling_kwargs['bounds'])
    "log_tau_min": (
        "logTau_min",
        "logtau_min",
    ),
    "log_tau_max": (
        "logTau_max",
        "logtau_max",
    ),
    "tau_min": (
        "Tau_min",
        "tauMin",
        "tau_min_sec",
        "tau_min_seconds",
    ),
    "tau_max": (
        "Tau_max",
        "tauMax",
        "tau_max_sec",
        "tau_max_seconds",
    ),
    "tau_min_units": (
        "tau_min_time_units",
        "tau_min_in_time_units",
    ),
    "tau_max_units": (
        "tau_max_time_units",
        "tau_max_in_time_units",
    ),
    "Q_length_in_si": ("Q_in_m_per_s",),
}

_SK_ALIASES.update(
    {
        "cons_drawdown_mode": (
            "drawdown_mode",
            "cons_delta_mode",
        ),
        "cons_drawdown_rule": (
            "drawdown_rule",
            "cons_delta_rule",
        ),
        "cons_stop_grad_ref": (
            "stop_grad_ref",
            "cons_stopgrad_ref",
        ),
        "cons_drawdown_zero_at_origin": (
            "drawdown_zero_at_origin",
            "cons_zero_at_origin",
        ),
        "cons_drawdown_clip_max": (
            "drawdown_clip_max",
            "cons_clip_max",
        ),
        "cons_relu_beta": (
            "relu_beta",
            "cons_beta",
        ),
    }
)


# MV prior drift (mode/weight/warmup + loss knobs)
_SK_ALIASES.update(
    {
        "mv_prior_mode": (
            "mv_mode",
            "mvprior_mode",
            "mv_prior_kind",
        ),
        "mv_weight": (
            "mv_prior_weight",
            "mvprior_weight",
            "mv_w",
        ),
        "mv_warmup_steps": (
            "mv_prior_warmup_steps",
            "mv_warmup_steps",
            "mv_warmup_iters",
            "mv_warmup_iterations",
        ),
        "mv_alpha_disp": (
            "mv_prior_alpha_disp",
            "mv_disp_alpha",
            "mv_alpha",
        ),
        "mv_huber_delta": (
            "mv_prior_huber_delta",
            "mv_delta",
            "mv_huber",
        ),
        "mv_prior_units": (
            "mv_units",
            "mv_gamma_units",
            "mv_gw_units",
        ),
    }
)


def enforce_scaling_alias_consistency(
    scaling_kwargs: dict[str, Any] | None,
    *,
    where: str = "validate",
) -> None:
    """
    Enforce that canonical keys and aliases agree.

    If both canonical and an alias exist and their
    values differ, apply the scaling error policy.
    """
    sk = scaling_kwargs or {}

    for key, aliases in _SK_ALIASES.items():
        if key not in sk:
            continue

        v0 = sk.get(key, None)
        if v0 is None:
            continue

        for a in aliases:
            if a not in sk:
                continue

            va = sk.get(a, None)
            if va is None:
                continue

            if va != v0:
                msg = (
                    "Conflicting scaling keys: "
                    f"{key!r}={v0!r} != {a!r}={va!r}."
                )
                _handle_scaling_issue(
                    sk,
                    msg,
                    where=where,
                )


def canonicalize_scaling_kwargs(
    scaling_kwargs: dict[str, Any] | None,
    *,
    copy: bool = True,
) -> dict[str, Any]:
    """
    Return a canonicalized scaling dict.

    - If a canonical key is missing (or None) and one
      of its aliases holds a non-None value, copy
      alias -> canonical (first match wins).
    - Keeps existing canonical values unchanged.
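
    Examples
    --------
    Illustrative sketch; ``time_unit`` is one of the
    aliases registered for ``time_units`` in
    ``_SK_ALIASES``:

    >>> sk = canonicalize_scaling_kwargs({"time_unit": "year"})
    >>> sk["time_units"]
    'year'
    >>> canonicalize_scaling_kwargs(
    ...     {"time_units": "day", "time_unit": "year"}
    ... )["time_units"]
    'day'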
    """
    sk0 = scaling_kwargs or {}
    sk = dict(sk0) if copy else sk0

    for key, aliases in _SK_ALIASES.items():
        if key in sk and sk.get(key, None) is not None:
            continue

        for a in aliases:
            if a in sk and sk.get(a, None) is not None:
                sk[key] = sk[a]
                break

    return sk


def load_scaling_kwargs(
    scaling_kwargs: Any | None,
    *,
    copy: bool = True,
) -> dict[str, Any]:
    """
    Load scaling kwargs from a dict-like object or JSON.

    Supported inputs
    ----------------
    - dict / Mapping:
        Returned (copied by default).
    - str:
        * If it looks like JSON ("{...}" or "[...]"), parse as JSON.
        * Else treat as a filesystem path to a JSON file.
    - pathlib.Path:
        Treated as a filesystem path to a JSON file.
    - None:
        Returns {}.

    Parameters
    ----------
    scaling_kwargs : Any
        Scaling configuration input. Can be a dict, JSON string,
        path to JSON file, or None.
    copy : bool, default=True
        If True, returns a shallow copy of the dict.

    Returns
    -------
    dict
        Parsed scaling kwargs as a Python dict.

    Raises
    ------
    TypeError
        If the input type is unsupported.
    ValueError
        If JSON parsing fails or JSON does not decode to a dict.
    FileNotFoundError
        If a JSON path is given but does not exist.
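
    Examples
    --------
    Illustrative sketch with an inline JSON string and
    with ``None``:

    >>> load_scaling_kwargs('{"time_units": "year"}')
    {'time_units': 'year'}
    >>> load_scaling_kwargs(None)
    {}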
    """
    if scaling_kwargs is None:
        return {}

    if isinstance(scaling_kwargs, Mapping):
        return (
            dict(scaling_kwargs) if copy else scaling_kwargs
        )

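    if isinstance(scaling_kwargs, Path):
        # Mirror the path-string branch below: same
        # existence check and JSON error handling.
        path = scaling_kwargs.expanduser()
        if not path.exists():
            raise FileNotFoundError(
                f"Scaling kwargs JSON file not found: {str(path)!r}."
            )
        try:
            obj = json.loads(path.read_text(encoding="utf-8"))
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Invalid JSON in scaling kwargs file: {str(path)!r}."
            ) from e
        if not isinstance(obj, dict):
            raise ValueError(
                "Scaling JSON file must decode to an object/dict, "
                f"got {type(obj).__name__}."
            )
        return obj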

    if isinstance(scaling_kwargs, str):
        s = scaling_kwargs.strip()

        # 1) Inline JSON object/array.
        if (s.startswith("{") and s.endswith("}")) or (
            s.startswith("[") and s.endswith("]")
        ):
            try:
                obj = json.loads(s)
            except json.JSONDecodeError as e:
                raise ValueError(
                    "Invalid scaling_kwargs JSON string."
                ) from e
            if not isinstance(obj, dict):
                raise ValueError(
                    "Scaling JSON must decode to an object/dict, "
                    f"got {type(obj).__name__}."
                )

            return obj

        # 2) Treat as file path to JSON.
        path = Path(s).expanduser()
        if not path.exists():
            raise FileNotFoundError(
                f"Scaling kwargs JSON file not found: {str(path)!r}."
            )
        text = path.read_text(encoding="utf-8")
        try:
            obj = json.loads(text)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Invalid JSON in scaling kwargs file: {str(path)!r}."
            ) from e
        if not isinstance(obj, dict):
            raise ValueError(
                "Scaling JSON file must decode to an object/dict, "
                f"got {type(obj).__name__}."
            )
        return obj

    try:
        obj = dict(scaling_kwargs)
    except Exception as e:
        raise TypeError(
            "scaling_kwargs must be a dict/Mapping, JSON string, "
            "Path, or a path string to a JSON file."
        ) from e

    return obj


def get_sk(
    scaling_kwargs,
    key: str,
    *aliases: str,
    default=None,
    required: bool = False,
    cast=None,
):
    """
    Fetch a key from `scaling_kwargs` with aliases + default.

    - Tries: key -> built-in aliases -> explicit aliases
    - Treats None and blank strings as "missing" and keeps searching.
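
    Examples
    --------
    Illustrative sketch with plain dicts (alias names
    taken from ``_SK_ALIASES``):

    >>> get_sk({"time_unit": "year"}, "time_units")
    'year'
    >>> get_sk({"gwl_index": "3"}, "gwl_dyn_index", cast=int)
    3
    >>> get_sk({"coord_mode": "  "}, "coord_mode", default="xy")
    'xy'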
    """
    sk = scaling_kwargs or {}
    if not isinstance(sk, Mapping):
        try:
            sk = dict(sk)
        except Exception:
            sk = {}

    cand = [key]
    cand.extend(_SK_ALIASES.get(key, ()))
    cand.extend([a for a in aliases if a])

    for k in cand:
        if k in sk:
            v = sk[k]
            if v is None:
                continue
            if isinstance(v, str) and not v.strip():
                continue
            if cast is not None:
                try:
                    v = cast(v)
                except Exception as e:
                    raise ValueError(
                        f"Invalid scaling_kwargs[{k!r}]={v!r}."
                    ) from e
            return v

    if required:
        alias_txt = (
            ", ".join(repr(x) for x in cand[1:]) or "none"
        )
        raise ValueError(
            f"Missing required scaling key {key!r} (aliases: {alias_txt})."
        )
    if cast is not None and default is not None:
        try:
            return cast(default)
        except Exception:
            return default
    return default


def _norm_policy(policy: str | None) -> str:
    """
    Normalize scaling error policy.

    Allowed:
    - 'ignore'
    - 'warn'   (default)
    - 'raise'
    """
    p = (policy or "warn").strip().lower()
    if p not in ("ignore", "warn", "raise"):
        p = "warn"
    return p


def _handle_scaling_issue(
    scaling_kwargs: dict[str, Any] | None,
    message: str,
    *,
    where: str = "validate",
) -> None:
    """
    Apply scaling error policy.

    Notes
    -----
    Even when the policy is 'raise', runtime callers
    must not crash; they warn and use their fallback
    (e.g., zeros). Therefore:
    - where='validate': obey ignore/warn/raise
    - any other where (e.g. 'runtime'): treat 'raise'
      as 'warn'
    """
    sk = scaling_kwargs or {}
    policy = _norm_policy(
        get_sk(sk, "scaling_error_policy", default="warn")
    )

    # Runtime must not crash; still fall back later.
    if where != "validate" and policy == "raise":
        policy = "warn"

    if policy == "ignore":
        return

    if policy == "warn":
        warn(
            message,
            category=RuntimeWarning,
            stacklevel=2,
        )
        return
    # validate + raise
    raise ValueError(message)
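

# Scaling error policy, as implemented by _handle_scaling_issue
# (unknown policy strings are normalized to 'warn'):
#
#   policy  | where='validate' | any other where
#   --------+------------------+-----------------
#   ignore  | silent           | silent
#   warn    | RuntimeWarning   | RuntimeWarning
#   raise   | ValueError       | RuntimeWarning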


def _is_deg_mode(mode: str) -> bool:
    m = (mode or "").strip().lower()
    return m in {
        "deg",
        "degree",
        "degrees",
        "lonlat",
        "latlon",
    }


def _validate_scaling_kwargs(scaling_kwargs):
    sk = canonicalize_scaling_kwargs(scaling_kwargs)
    enforce_scaling_alias_consistency(sk, where="validate")

    mode = str(sk.get("coord_mode", ""))
    deg_mode = _is_deg_mode(mode)
    deg_flag = bool(sk.get("coords_in_degrees", False))

    if deg_mode != deg_flag:
        msg = (
            "Inconsistent coord flags: "
            f"coord_mode={mode!r} but "
            f"coords_in_degrees={deg_flag}. "
            "Use either degrees (with deg_to_m_lon/"
            "deg_to_m_lat) or projected meters "
            "(coords_in_degrees=False)."
        )
        _handle_scaling_issue(sk, msg, where="validate")

    epsg_used = sk.get("coord_epsg_used", None)
    if deg_flag and (epsg_used not in (None, 4326)):
        msg = (
            "coords_in_degrees=True but "
            f"coord_epsg_used={epsg_used!r} "
            "looks projected. If you already "
            "reprojected, set coords_in_degrees=False."
        )
        _handle_scaling_issue(sk, msg, where="validate")


def validate_scaling_kwargs(
    scaling_kwargs: dict[str, Any] | None,
) -> None:
    """
    Basic scaling sanity checks.

    This includes policy-controlled heuristic checks
    for common "silent fallback" cases.
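
    Examples
    --------
    Illustrative sketch: with coordinates neither
    normalized nor in degrees, ``time_units`` is the
    only required key.

    >>> validate_scaling_kwargs({"time_units": "year"})  # passes
    >>> validate_scaling_kwargs({})
    Traceback (most recent call last):
        ...
    ValueError: time_units missing in scaling_kwargs.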
    """
    sk = canonicalize_scaling_kwargs(scaling_kwargs)
    enforce_scaling_alias_consistency(sk, where="validate")

    # --------------------------------------------------
    # Degrees mode requires meters-per-degree factors,
    # unless lat0_deg lets deg_to_m() derive them.
    # --------------------------------------------------
    if bool(sk.get("coords_in_degrees", False)):
        for key in ("deg_to_m_lon", "deg_to_m_lat"):
            val = sk.get(key, None)
            if val is None:
                if sk.get("lat0_deg", None) is not None:
                    continue
                msg = (
                    "coords_in_degrees=True but missing "
                    f"scaling_kwargs[{key!r}] (and no lat0_deg)."
                )
                raise ValueError(msg)
            try:
                v = float(val)
            except (TypeError, ValueError) as e:
                raise ValueError(
                    f"Invalid {key!r}={val!r}."
                ) from e
            if not np.isfinite(v) or v <= 0.0:
                raise ValueError(f"Invalid {key!r}={v}.")

    # --------------------------------------------------
    # Normalized coords require coord_ranges.
    # --------------------------------------------------
    is_normalized = bool(sk.get("coords_normalized", False))
    if is_normalized and not sk.get("coord_ranges", None):
        raise ValueError(
            "coords_normalized=True but coord_ranges missing."
        )

    # --------------------------------------------------
    # Require time units (alias-safe).
    # --------------------------------------------------
    if get_sk(sk, "time_units", default=None) is None:
        raise ValueError(
            "time_units missing in scaling_kwargs."
        )

    # --------------------------------------------------
    # Heuristic checks (policy-controlled).
    # --------------------------------------------------
    names = sk.get("dynamic_feature_names", None)
    names = list(names) if names is not None else []

    # A) Subsidence init: detect cum subs channel.
    has_subs_cum = any(
        ("subs" in str(n).lower() and "cum" in str(n).lower())
        for n in names
    )

    subs_idx = sk.get("subs_dyn_index", None)
    subs_name = get_sk(sk, "subs_dyn_name", default=None)

    meta = sk.get("gwl_z_meta", {}) or {}
    cols = meta.get("cols", {}) or {}
    subs_meta = cols.get("subs_model", None)

    if (
        has_subs_cum
        and subs_idx is None
        and subs_name is None
    ):
        if subs_meta is None:
            msg = (
                "dynamic_feature_names contains a cumulative "
                "subsidence channel, but no subs_dyn_index/"
                "subs_dyn_name and no gwl_z_meta.cols.subs_model. "
                "Initial settlement will fall back to zeros."
            )
            _handle_scaling_issue(
                sk,
                msg,
                where="validate",
            )

    # B) Depth->head conversion needs z_surf when proxy=False.
    kind = str(sk.get("gwl_kind", "")).lower()
    proxy = bool(sk.get("use_head_proxy", True))

    if (not proxy) and (
        kind not in ("head", "waterhead", "hydraulic_head")
    ):
        z_col = sk.get("z_surf_col", None)
        z_col = z_col or meta.get("z_surf_col", None)

        z_static = cols.get("z_surf_static", None)
        z_idx = sk.get("z_surf_static_index", None)

        static_names = get_sk(
            sk,
            "static_feature_names",
            default=None,
        )

        # Without a way to locate z_surf among the static
        # features, the conversion may fall back to depth.
        if (
            z_idx is None
            and static_names is None
            and z_col is not None
            and z_static is not None
            and z_col != z_static
        ):
            msg = (
                "use_head_proxy=False and gwl_kind is depth-like, "
                "but z_surf_col differs from gwl_z_meta.cols."
                "z_surf_static, and no static_feature_names/"
                "z_surf_static_index provided. Depth->head "
                "conversion may fall back to depth."
            )
            _handle_scaling_issue(
                sk,
                msg,
                where="validate",
            )


def affine_from_cfg(
    scaling_kwargs: dict[str, Any] | None,
    *,
    scale_key: str,
    bias_key: str,
    meta_keys: tuple[str, ...] = (),
    unit_key: str | None = None,
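) -> tuple[Tensor, Tensor]:
    """
    Return ``(a, b)`` such that ``y_si = y_model * a + b``.

    Resolution order:
    1) explicit ``scale_key``/``bias_key`` (a missing one
       defaults to 1.0 or 0.0, respectively);
    2) the first dict among ``meta_keys`` that holds
       ``mu``/``mean`` and ``sigma``/``std``, giving
       ``a = sigma`` and ``b = mu`` (z-score inverse);
    3) if ``unit_key`` is given, ``cfg[unit_key]``
       (default 1.0) as a pure scale with ``b = 0``;
    4) otherwise the identity ``(1.0, 0.0)``.
    """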
    cfg = scaling_kwargs or {}

    a = cfg.get(scale_key, None)
    b = cfg.get(bias_key, None)

    if a is not None or b is not None:
        a = 1.0 if a is None else float(a)
        b = 0.0 if b is None else float(b)
        return tf_constant(a, tf_float32), tf_constant(
            b, tf_float32
        )

    for mk in meta_keys:
        meta = cfg.get(mk, None)
        if isinstance(meta, dict):
            mu = meta.get("mu", meta.get("mean", None))
            sig = meta.get("sigma", meta.get("std", None))
            if mu is not None and sig is not None:
                return (
                    tf_constant(float(sig), tf_float32),
                    tf_constant(float(mu), tf_float32),
                )

    if unit_key is not None:
        # Treat an explicit None like a missing key.
        u = cfg.get(unit_key, None)
        u = 1.0 if u is None else float(u)
        return tf_constant(u, tf_float32), tf_constant(
            0.0, tf_float32
        )

    return tf_constant(1.0, tf_float32), tf_constant(
        0.0, tf_float32
    )


def to_si_thickness(
    H_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Convert thickness to SI."""
    a, b = affine_from_cfg(
        scaling_kwargs,
        scale_key="H_scale_si",
        bias_key="H_bias_si",
        meta_keys=("H_z_meta",),
        unit_key="thickness_unit_to_si",
    )
    return tf_cast(H_model, tf_float32) * a + b


def to_si_head(
    h_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Convert head/depth to SI meters."""
    a, b = affine_from_cfg(
        scaling_kwargs,
        scale_key="head_scale_si",
        bias_key="head_bias_si",
        meta_keys=("head_z_meta", "gwl_z_meta"),
        unit_key="head_unit_to_si",
    )
    return tf_cast(h_model, tf_float32) * a + b


def to_si_subsidence(
    s_model: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Convert subsidence to SI meters."""
    a, b = affine_from_cfg(
        scaling_kwargs,
        scale_key="subs_scale_si",
        bias_key="subs_bias_si",
        meta_keys=("subs_z_meta",),
        unit_key="subs_unit_to_si",
    )
    return tf_cast(s_model, tf_float32) * a + b


def from_si_subsidence(
    s_si: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """Inverse of to_si_subsidence: s_model = (s_si - b) / a."""
    a, b = affine_from_cfg(
        scaling_kwargs,
        scale_key="subs_scale_si",
        bias_key="subs_bias_si",
        meta_keys=("subs_z_meta",),
        unit_key="subs_unit_to_si",
    )
    eps = tf_constant(_EPSILON, tf_float32)
    return (tf_cast(s_si, tf_float32) - b) / (a + eps)
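

# Worked example (illustrative values only): with
#   sk = {"subs_z_meta": {"mu": 0.02, "sigma": 0.05}}
# affine_from_cfg resolves (a, b) = (0.05, 0.02), so
#   to_si_subsidence(1.0, sk)    -> 1.0 * 0.05 + 0.02 = 0.07 m
#   from_si_subsidence(0.07, sk) -> (0.07 - 0.02) / (0.05 + eps) ~ 1.0
# up to the _EPSILON guard and float32 rounding.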


def deg_to_m(
    axis: str,
    scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
    """
    Meters per degree factor for lon/lat coords.

    If coords_in_degrees=True and deg_to_m_lon/lat are missing, we try
    to compute them from lat0_deg (recommended).
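
    Examples
    --------
    Illustrative sketch (eager TensorFlow assumed): at
    ``lat0_deg=30`` the longitude factor is
    ``111320 * cos(30 deg)``, roughly 96.4 km per degree.

    >>> cfg = {"coords_in_degrees": True, "lat0_deg": 30.0}
    >>> round(float(deg_to_m("x", cfg)), 1)
    96405.9
    >>> float(deg_to_m("x", {"coords_in_degrees": False}))
    1.0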
    """
    if axis not in ("x", "y"):
        raise ValueError(
            f"deg_to_m: axis must be 'x' or 'y', got {axis!r}."
        )

    cfg = scaling_kwargs or {}
    if not bool(cfg.get("coords_in_degrees", False)):
        return tf_constant(1.0, tf_float32)

    key = "deg_to_m_lon" if axis == "x" else "deg_to_m_lat"
    val = cfg.get(key, None)

    if val is None:
        if axis == "y":
            # The length of a degree of latitude varies by only
            # about 1% with latitude; use the equatorial value.
            return tf_constant(110574.0, tf_float32)
        lat0 = cfg.get("lat0_deg", None)
        if lat0 is None:
            raise ValueError(
                "coords_in_degrees=True but missing deg_to_m_lon "
                "and lat0_deg (needed for lon scaling)."
            )
        v = 111320.0 * float(np.cos(np.deg2rad(float(lat0))))
        return tf_constant(v, tf_float32)

    try:
        v = float(val)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Invalid {key!r}={val!r}.") from e

    if not np.isfinite(v) or v <= 0.0:
        raise ValueError(f"Invalid {key!r}={v}.")

    return tf_constant(v, tf_float32)


def coord_ranges(
    scaling_kwargs: dict[str, Any] | None,
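) -> tuple[float | None, float | None, float | None]:
    """
    Return the (tR, xR, yR) coordinate ranges.

    Returns (None, None, None) unless coords_normalized
    is True. Each range is read from coord_ranges['t'|'x'|'y'],
    falling back to flat keys (t_range or coord_range_t, etc.).
    """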
    cfg = scaling_kwargs or {}
    if not bool(cfg.get("coords_normalized", False)):
        return None, None, None

    r = cfg.get("coord_ranges", {}) or {}

    def get(name: str, *alts: str) -> float | None:
        v = r.get(name, None)
        if v is None:
            for a in alts:
                v = cfg.get(a, None)
                if v is not None:
                    break
        return None if v is None else float(v)

    tR = get("t", "t_range", "coord_range_t")
    xR = get("x", "x_range", "coord_range_x")
    yR = get("y", "y_range", "coord_range_y")
    return tR, xR, yR
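

# Illustrative lookup (hypothetical values):
#   sk = {
#       "coords_normalized": True,
#       "coord_ranges": {"t": 10.0, "x": 5000.0},
#       "y_range": 3000.0,
#   }
#   coord_ranges(sk) -> (10.0, 5000.0, 3000.0)
# The flat keys (y_range, coord_range_y, ...) are consulted
# only when the nested coord_ranges entry is missing.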


def resolve_gwl_dyn_index(
    scaling_kwargs: dict[str, Any] | None,
) -> int:
    """Resolve GWL channel index for dynamic_features."""
    sk = scaling_kwargs or {}

    idx = get_sk(sk, "gwl_dyn_index", default=None)
    if idx is not None:
        return int(idx)

    names = get_sk(sk, "dynamic_feature_names", default=None)
    gwl_col = get_sk(sk, "gwl_col", default=None)

    if names is not None and gwl_col is not None:
        names = list(names)
        if gwl_col in names:
            return int(names.index(gwl_col))

    raise ValueError(
        "Cannot resolve GWL channel. Provide gwl_dyn_index "
        "or dynamic_feature_names + gwl_col."
    )
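

# Illustrative resolution with hypothetical feature names:
#   sk = {
#       "dynamic_feature_names": ["rain", "GWL_depth", "pump"],
#       "gwl_col": "GWL_depth",
#   }
#   resolve_gwl_dyn_index(sk) -> 1
# An explicit gwl_dyn_index takes precedence over the
# name lookup.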


def get_gwl_dyn_index_cached(model) -> int:
    """Cache gwl_dyn_index on model after first resolve."""
    idx = getattr(model, "gwl_dyn_index", None)
    if idx is None:
        idx = resolve_gwl_dyn_index(
            getattr(
                model,
                "scaling_kwargs",
                None,
            )
        )
        model.gwl_dyn_index = int(idx)
    return int(idx)


def resolve_subs_dyn_index(scaling_kwargs):
    """Resolve subsidence channel index for dynamic_features.

    This is optional: v3.2 can use historical subsidence as a dynamic
    driver to provide a physics-friendly initial condition for the mean
    settlement path.
    """
    sk = scaling_kwargs or {}

    idx = get_sk(sk, "subs_dyn_index", default=None)
    if idx is not None:
        return int(idx)

    names = get_sk(sk, "dynamic_feature_names", default=None)

    subs_col = get_sk(sk, "subs_dyn_name", default=None)

    # Fallback: gwl_z_meta.cols.subs_model
    if subs_col is None:
        meta = sk.get("gwl_z_meta", {}) or {}
        cols = meta.get("cols", {}) or {}
        subs_col = cols.get("subs_model", None)

    if names is not None and subs_col is not None:
        names = list(names)
        if subs_col in names:
            return int(names.index(subs_col))

    raise ValueError(
        "Cannot resolve subsidence channel. Provide subs_dyn_index "
        "or dynamic_feature_names + subs_dyn_name (or gwl_z_meta.cols.subs_model)."
    )


def get_subs_dyn_index_cached(model) -> int:
    """Cache subs_dyn_index on model after first resolve."""
    idx = getattr(model, "subs_dyn_index", None)
    if idx is None:
        idx = resolve_subs_dyn_index(
            getattr(model, "scaling_kwargs", None)
        )
        model.subs_dyn_index = int(idx)
    return int(idx)


def slice_dynamic_channel(Xh: Tensor, idx: int) -> Tensor:
    """Slice (B,T,F) -> (B,T,1) at idx."""
    idx_t = tf_cast(idx, tf_int32)
    F = tf_shape(Xh)[-1]
    tf_debugging.assert_less(
        idx_t,
        F,
        message="dynamic channel index out of range.",
    )
    return Xh[:, :, idx_t : idx_t + 1]


def assert_dynamic_names_match_tensor(
    Xh: Tensor,
    scaling_kwargs: dict[str, Any] | None,
) -> None:
    """Check dynamic_feature_names length matches Xh."""
    sk = scaling_kwargs or {}
    names = sk.get("dynamic_feature_names", None)
    if names is None:
        return
    n = len(list(names))
    tf_debugging.assert_equal(
        tf_shape(Xh)[-1],
        tf_constant(n, tf_int32),
        message="dynamic_feature_names != Xh last dim",
    )


def gwl_to_head_m(
    v_m: Tensor,
    scaling_kwargs: dict[str, Any] | None,
    *,
    inputs: dict[str, Tensor] | None = None,
) -> Tensor:
    """
    Convert depth-bgs to head if possible.

    Behavior
    --------
    - If gwl_kind is head-like: return v_m.
    - Otherwise treat as depth and try:
      head = z_surf - depth.
    - If z_surf is missing:
      * use_head_proxy=True  -> return -depth
      * use_head_proxy=False -> return depth
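
    Examples
    --------
    Illustrative sketch (eager TensorFlow assumed): a
    depth-bgs of 3.5 m under a surface elevation of
    12.0 m, passed under a hypothetical key ``z_surf``,
    gives head = 12.0 - 3.5 = 8.5 m.

    >>> import numpy as np
    >>> sk = {"gwl_kind": "depth", "z_surf_col": "z_surf"}
    >>> depth = np.full((1, 2, 1), 3.5, dtype="float32")
    >>> z = np.array([12.0], dtype="float32")
    >>> h = gwl_to_head_m(depth, sk, inputs={"z_surf": z})
    >>> np.asarray(h).ravel().tolist()
    [8.5, 8.5]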
    """
    sk = scaling_kwargs or {}

    # --------------------------------------------------
    # 1) Decide whether v_m is head or depth.
    # --------------------------------------------------
    kind_raw = sk.get("gwl_kind", None)
    if kind_raw is None or str(kind_raw).strip() == "":
        gwl_col = str(get_sk(sk, "gwl_col", default=""))
        gwl_col = gwl_col.lower()
        kind = "depth" if ("depth" in gwl_col) else "head"
    else:
        kind = str(kind_raw).lower()

    if kind in ("head", "waterhead", "hydraulic_head"):
        return tf_cast(v_m, tf_float32)

    # --------------------------------------------------
    # 2) Depth convention + proxy behavior.
    # --------------------------------------------------
    sign = str(sk.get("gwl_sign", "down_positive")).lower()
    proxy = bool(sk.get("use_head_proxy", True))

    # --------------------------------------------------
    # 3) Collect possible z_surf keys.
    # Prefer SI/static key first when available.
    # --------------------------------------------------
    meta = sk.get("gwl_z_meta", {}) or {}
    cols = meta.get("cols", {}) or {}

    z_surf_col = sk.get("z_surf_col", None)
    z_surf_col = z_surf_col or meta.get("z_surf_col", None)

    z_surf_static = cols.get("z_surf_static", None)
    z_surf_raw = cols.get("z_surf_raw", None)

    z_surf_keys = [
        k
        for k in (z_surf_static, z_surf_col, z_surf_raw)
        if k
    ]

    # Dedupe while preserving order.
    seen = set()
    z_surf_keys = [
        k
        for k in z_surf_keys
        if not (k in seen or seen.add(k))
    ]

    # --------------------------------------------------
    # 4) Convert to positive-down depth.
    # --------------------------------------------------
    v_m = tf_cast(v_m, tf_float32)
    depth_m = v_m if sign == "down_positive" else -v_m

    # --------------------------------------------------
    # 5) Try direct inputs[z_surf_key] first.
    # --------------------------------------------------
    z_surf = None
    if inputs is not None:
        for k in z_surf_keys:
            z_surf = inputs.get(k, None)
            if z_surf is not None:
                z_surf = tf_cast(z_surf, tf_float32)
                break

    # --------------------------------------------------
    # 6) If missing, try static_features lookup.
    # --------------------------------------------------
    if z_surf is None and inputs is not None:
        sf = inputs.get("static_features", None)
        if sf is not None:
            sf = tf_cast(sf, tf_float32)

            idx = sk.get("z_surf_static_index", None)
            if idx is None:
                names = get_sk(
                    sk,
                    "static_feature_names",
                    default=None,
                )
                if names is not None:
                    names = list(names)
                    for k in z_surf_keys:
                        if k in names:
                            idx = int(names.index(k))
                            break

            if idx is not None:
                idx_i = int(idx)

                tf_debugging.assert_less(
                    tf_cast(idx_i, tf_int32),
                    tf_shape(sf)[-1],
                    message="z_surf_static_index out of range.",
                )

                r = getattr(sf.shape, "rank", None)
                if r == 2:
                    z_surf = sf[:, idx_i : idx_i + 1]
                elif r == 3:
                    z_surf = sf[:, :, idx_i : idx_i + 1]
                else:
                    rr = tf_rank(sf)
                    z_surf = tf_cond(
                        tf_equal(rr, 2),
                        lambda: sf[:, idx_i : idx_i + 1],
                        lambda: sf[:, :, idx_i : idx_i + 1],
                    )

    if z_surf is None:
        # if bool(sk.get("debug_units", False)):
        tf_print(
            "[gwl_to_head_m] z_surf missing ->",
            "use_head_proxy=",
            proxy,
            "returning depth-like quantity (NOT true head)",
        )

    # --------------------------------------------------
    # 7) If we have z_surf: head = z_surf - depth.
    # --------------------------------------------------
    if z_surf is not None:
        # (B,) -> (B,1,1), (B,K) -> (B,1,K); rank 3 unchanged.
        z_surf = _reshape_to_b11(z_surf)

        # Broadcast z_surf to match depth_m.
        z_surf = z_surf + tf_zeros_like(depth_m)
        return z_surf - depth_m

    # --------------------------------------------------
    # 8) Fallback: proxy head or keep depth.
    # --------------------------------------------------
    return -depth_m if proxy else depth_m


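def _reshape_to_b11(v: Tensor) -> Tensor:
    """Lift (B,) -> (B,1,1) and (B,K) -> (B,1,K); else no-op."""
    v = tf_cast(v, tf_float32)
    r = getattr(v.shape, "rank", None)
    # Prefer the static rank: in graph mode tf.cond traces both
    # branches, and the (B,K) slice is invalid for a rank-1 input.
    if r == 1:
        return v[:, None, None]
    if r == 2:
        return v[:, None, :]
    if r is not None:
        return v
    rr = tf_rank(v)
    return tf_cond(
        tf_equal(rr, 1),
        lambda: v[:, None, None],
        lambda: tf_cond(
            tf_equal(rr, 2),
            lambda: v[:, None, :],
            lambda: v,
        ),
    )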


def get_h_hist_si(
    model,
    inputs: dict[str, Tensor],
    *,
    want_head: bool = True,
) -> Tensor:
    """Return head (or depth) history in SI meters.

    Parameters
    ----------
    model : object
        The model instance (provides ``scaling_kwargs`` and cached indices).
    inputs : dict
        Batch inputs; expects ``dynamic_features`` unless an explicit
        head history key is provided.
    want_head : bool, default=True
        If True, convert depth-bgs to hydraulic head when possible.

    Returns
    -------
    Tensor
        (B,T,1) tensor in SI meters.
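
    Examples
    --------
    Illustrative sketch (eager TensorFlow assumed) with a
    stand-in model object (any object exposing
    ``scaling_kwargs`` works) and an explicit SI override,
    so no channel lookup is needed:

    >>> from types import SimpleNamespace
    >>> import numpy as np
    >>> m = SimpleNamespace(scaling_kwargs={"gwl_kind": "head"})
    >>> h = np.array([[1.0, 2.0]], dtype="float32")  # (B,T)
    >>> out = get_h_hist_si(m, {"h_hist_si": h})
    >>> tuple(out.shape)
    (1, 2, 1)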
    """
    sk = getattr(model, "scaling_kwargs", None)

    # Explicit override (useful for scenario-driven runs)
    for k in ("h_hist_si", "head_hist_si", "gwl_hist_si"):
        if k in inputs and inputs[k] is not None:
            v = tf_cast(inputs[k], tf_float32)
            # (B,T) -> (B,T,1); use the static rank so this
            # works in graph mode (no Python bool on a Tensor).
            if getattr(v.shape, "rank", None) == 2:
                v = v[:, :, None]
            if want_head:
                v = gwl_to_head_m(v, sk, inputs=inputs)
            return v

    Xh = inputs.get("dynamic_features", None)
    if Xh is None:
        raise ValueError(
            "Cannot build head history: missing inputs['dynamic_features'] "
            "and no explicit head history key "
            "(h_hist_si/head_hist_si/gwl_hist_si)."
        )

    Xh = tf_cast(Xh, tf_float32)
    assert_dynamic_names_match_tensor(Xh, sk)

    gwl_idx = get_gwl_dyn_index_cached(model)
    gwl = slice_dynamic_channel(Xh, gwl_idx)
    gwl_si = to_si_head(gwl, sk)

    return (
        gwl_to_head_m(gwl_si, sk, inputs=inputs)
        if want_head
        else gwl_si
    )


def get_s_init_si(
    model,
    inputs: dict[str, Tensor] | None,
    like: Tensor,
) -> Tensor:
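    """Return initial settlement (cumulative subsidence) in SI meters.

    The result is broadcast to the shape of ``like``. Sources are
    tried in order:

    1) explicit keys in ``inputs`` (s_init_si/subs_hist_last_si/...),
       whose values are assumed to already be in SI meters;
    2) the last historical value from ``dynamic_features`` when the
       subsidence channel index can be resolved;
    3) zeros (broadcast).
    """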
    sk = getattr(model, "scaling_kwargs", None)

    if inputs is not None:
        for k in (
            "s_init_si",
            "subs_init_si",
            "subs_hist_last_si",
            "s_ref_si",
            "subs_ref_si",
            "s_init",
            "subs_init",
        ):
            if k in inputs and inputs[k] is not None:
                return _reshape_to_b11(
                    inputs[k]
                ) + tf_zeros_like(like)

        Xh = inputs.get("dynamic_features", None)
        if Xh is not None:
            try:
                subs_idx = get_subs_dyn_index_cached(model)
            except Exception as e:
                _handle_scaling_issue(
                    getattr(model, "scaling_kwargs", None),
                    f"Could not resolve subsidence init channel ({e}). "
                    "Falling back to zeros for s_init_si.",
                    where="runtime",
                )
                subs_idx = None

            if subs_idx is not None:
                Xh = tf_cast(Xh, tf_float32)
                assert_dynamic_names_match_tensor(Xh, sk)
                s_hist = slice_dynamic_channel(
                    Xh, int(subs_idx)
                )
                s_last = s_hist[:, -1:, :]
                s_last_si = to_si_subsidence(s_last, sk)
                return s_last_si + tf_zeros_like(like)

    return tf_zeros_like(like)


def get_h_ref_si(
    model,
    inputs: dict[str, Tensor] | None,
    like: Tensor,
) -> Tensor:
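    """Return h_ref in SI meters, broadcast to ``like``.

    Sources are tried in order: an explicit key in ``inputs``
    (h_ref_si/head_ref_si/h_ref/head_ref); then, unless
    ``model.h_ref_config.mode`` is ``"fixed"``, the last value of the
    head history built from ``dynamic_features``; finally
    ``model.h_ref`` (default 0.0).
    """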

    mode = getattr(
        getattr(model, "h_ref_config", None), "mode", "auto"
    )
    mode = (
        "fixed"
        if str(mode).lower().strip() == "fixed"
        else "auto"
    )

    if inputs is not None:
        for k in (
            "h_ref_si",
            "head_ref_si",
            "h_ref",
            "head_ref",
        ):
            if (k in inputs) and (inputs[k] is not None):
                # (B,) -> (B,1,1) and (B,C) -> (B,1,C), then broadcast.
                return _reshape_to_b11(inputs[k]) + tf_zeros_like(like)

    if (
        mode != "fixed"
        and inputs is not None
        and "dynamic_features" in inputs
        and inputs["dynamic_features"] is not None
    ):
        h_hist = get_h_hist_si(model, inputs, want_head=True)
        return h_hist[:, -1:, :] + tf_zeros_like(like)

    h0 = tf_cast(getattr(model, "h_ref", 0.0), tf_float32)
    h0 = h0[None, None, None]
    return h0 + tf_zeros_like(like)


def infer_dt_units_from_t(
    t_BH1: Tensor,
    scaling_kwargs: dict[str, Any] | None,
    *,
    eps: float = 1e-12,
) -> Tensor:
    """
    Infer per-step dt in *time_units* from time tensor t(B,H,1).

    Shapes
    ------
    t_BH1 : (B,H,1)
    returns: (B,H,1)

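    Notes
    -----
    - dt uses diffs along H; the first step reuses the first diff.
    - If H <= 1, dt defaults to ones.
    - If coords are normalized, dt is multiplied by the de-normalization
      time range tR (from coord_ranges()).
    - Absolute values are taken, and zero steps are replaced by 1.
    - The result is clipped to >= ``dt_min_units`` from
      ``scaling_kwargs`` (default 1e-6); ``eps`` is currently unused.
    """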

    sk = scaling_kwargs or {}
    t = tf_convert_to_tensor(t_BH1, dtype=tf_float32)

    # t shape: (B,H,1)
    H = tf_shape(t)[1]
    dt_default = tf_ones_like(t)  # (B,H,1), safe in-graph

    def _multi_step():
        diffs = t[:, 1:, :] - t[:, :-1, :]  # (B,H-1,1)
        dt_first = diffs[:, :1, :]  # (B,1,1)
        dt = tf_concat([dt_first, diffs], axis=1)  # (B,H,1)

        # If coords were normalized, dt is still normalized -> scale back
        if bool(sk.get("coords_normalized", False)):
            tR, _, _ = coord_ranges(sk)
            if tR is None:
                raise ValueError(
                    "coords_normalized=True but coord_ranges missing."
                )
            dt = dt * tf_constant(float(tR), dtype=tf_float32)
        return dt

    # if H <= 1: ones; else: diffs
    dt = tf_cond(
        tf_less_equal(H, 1), lambda: dt_default, _multi_step
    )
    dt = tf_abs(dt)
    dt_pos = tf_greater(dt, tf_constant(0.0, tf_float32))
    dt_pos_f = tf_cast(dt_pos, tf_float32)
    dt = dt * dt_pos_f + dt_default * (1.0 - dt_pos_f)

    dt_eps = float(get_sk(sk, "dt_min_units", default=1e-6))
    dt = tf_maximum(dt, tf_constant(dt_eps, tf_float32))

    return dt


# ---------------------------------------------------------------------
# Training strategy gates (Q and subsidence residual)
# ---------------------------------------------------------------------
def policy_gate(
    step: Tensor,
    policy: str,
    *,
    warmup_steps: int = 0,
    ramp_steps: int = 0,
    dtype: Any = tf_float32,
) -> Tensor:
    r"""Return a scalar gate in ``[0,1]`` based on a policy + step.

    Parameters
    ----------
    step : Tensor
        Global step counter (typically ``optimizer.iterations``).
    policy : {"always_on","always_off","warmup_off"}
        Gating behavior. ``always_on`` (aliases ``on``, ``true``,
        ``1``) returns 1 and ``always_off`` (aliases ``off``,
        ``false``, ``0``) returns 0. ``warmup_off``, like any other
        value, returns 0 for ``step < warmup_steps`` before ramping
        to 1 over ``ramp_steps`` when ``ramp_steps > 0`` or switching
        immediately at ``warmup_steps`` otherwise. ``None`` is treated
        as ``always_on``.
    warmup_steps : int, default=0
        Number of steps to keep the gate at 0 (only for ``warmup_off``).
    ramp_steps : int, default=0
        Number of steps for a linear ramp from 0 to 1 after warmup.
        If 0, the gate is a hard step. If both ``warmup_steps`` and
        ``ramp_steps`` are 0, the gate is always 1.
    dtype : dtype, default=tf_float32
        Output dtype.

    Returns
    -------
    Tensor
        Scalar gate value in ``[0, 1]``.
    """
    pol = (policy or "always_on").strip().lower()
    if pol in ("always_on", "on", "true", "1"):
        return tf_constant(1.0, dtype=dtype)
    if pol in ("always_off", "off", "false", "0"):
        return tf_constant(0.0, dtype=dtype)

    w = int(warmup_steps or 0)
    r = int(ramp_steps or 0)

    if w <= 0 and r <= 0:
        return tf_constant(1.0, dtype=dtype)

    step_i = tf_cast(step, tf_int32)

    if r <= 0:
        return tf_cast(
            tf_greater_equal(
                step_i, tf_constant(w, tf_int32)
            ),
            dtype,
        )

    step_f = tf_cast(step_i, dtype)
    w_f = tf_constant(float(w), dtype)
    r_f = tf_constant(float(r), dtype)
    frac = (step_f - w_f) / r_f
    frac = tf_maximum(tf_constant(0.0, dtype), frac)
    frac = tf_minimum(tf_constant(1.0, dtype), frac)
    return frac


# ---------------------------------------------------------------------
# Derived SI conversion helpers (optional, but recommended)
# ---------------------------------------------------------------------
def finalize_scaling_kwargs(
    sk: dict[str, Any],
) -> dict[str, Any]:
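    """Add derived SI conversion constants to ``scaling_kwargs``.

    Returns a shallow copy of ``sk`` with, when possible:

    - ``seconds_per_time_unit``: float derived from ``time_units``
      (unrecognized units fall back to 1.0; an existing value is kept);
    - ``coord_ranges_si``: dict with keys ``t`` (seconds) and ``x``/``y``
      (meters), computed when ``coord_ranges`` has ``t``, ``x`` and
      ``y``; degree spans are converted to meters when
      ``coords_in_degrees`` is set and both ``deg_to_m_lon`` and
      ``deg_to_m_lat`` are available;
    - ``coord_inv_ranges_si``: reciprocals of the above, with each span
      floored at 1e-12 to avoid division by zero.

    Notes
    -----
    This helper is designed to be called *once* when assembling
    ``scaling_kwargs`` (e.g., in the Stage-2 script) so the model can
    reuse those constants without recomputing unit conversions in the
    hot training loop.
    """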
    if sk is None:
        return sk

    sk = dict(sk)

    tu = (
        str(get_sk(sk, "time_units", default="second"))
        .strip()
        .lower()
    )
    time_unit_to_seconds = {
        "second": 1.0,
        "sec": 1.0,
        "s": 1.0,
        "minute": 60.0,
        "min": 60.0,
        "m": 60.0,
        "hour": 3600.0,
        "h": 3600.0,
        "day": 86400.0,
        "d": 86400.0,
        # Mean Gregorian year (365.2425 days = 31,556,952 s),
        # to match prior_maths.py
        "year": 31556952.0,
        "yr": 31556952.0,
        "y": 31556952.0,
    }
    sec_u = float(time_unit_to_seconds.get(tu, 1.0))
    sk.setdefault("seconds_per_time_unit", sec_u)

    cr = get_sk(sk, "coord_ranges", default=None)
    if isinstance(cr, Mapping) and all(
        k in cr for k in ("t", "x", "y")
    ):
        tR = float(cr.get("t", 1.0))
        xR = float(cr.get("x", 1.0))
        yR = float(cr.get("y", 1.0))

        # If coordinates are degrees, convert spans to meters.
        if bool(
            get_sk(sk, "coords_in_degrees", default=False)
        ):
            deg_to_m_lon = get_sk(
                sk, "deg_to_m_lon", default=None
            )
            deg_to_m_lat = get_sk(
                sk, "deg_to_m_lat", default=None
            )
            if (
                deg_to_m_lon is not None
                and deg_to_m_lat is not None
            ):
                xR *= float(deg_to_m_lon)
                yR *= float(deg_to_m_lat)

        # Convert time span to seconds (important if coords_normalized=True).
        tR *= sec_u

        sk["coord_ranges_si"] = {"t": tR, "x": xR, "y": yR}
        eps = 1e-12
        sk["coord_inv_ranges_si"] = {
            "t": 1.0 / max(tR, eps),
            "x": 1.0 / max(xR, eps),
            "y": 1.0 / max(yR, eps),
        }

    return sk


def coord_ranges_si(
    sk: dict[str, Any],
) -> tuple[float | None, float | None, float | None]:
    """Return coordinate spans in SI (t in seconds; x/y in meters).

    If ``coord_ranges_si`` is present in ``sk``, it is used directly.
    Otherwise, this is computed from ``coord_ranges`` and ``time_units``
    (and degree-to-meter factors when applicable). Returns
    ``(None, None, None)`` when the spans cannot be determined.
    """
    cr_si = get_sk(sk, "coord_ranges_si", default=None)
    if isinstance(cr_si, Mapping) and all(
        k in cr_si for k in ("t", "x", "y")
    ):
        return (
            float(cr_si["t"]),
            float(cr_si["x"]),
            float(cr_si["y"]),
        )

    sk2 = finalize_scaling_kwargs(sk)
    cr_si = get_sk(sk2, "coord_ranges_si", default=None)
    if isinstance(cr_si, Mapping) and all(
        k in cr_si for k in ("t", "x", "y")
    ):
        return (
            float(cr_si["t"]),
            float(cr_si["x"]),
            float(cr_si["y"]),
        )

    return None, None, None
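The schedule computed by ``policy_gate`` is piecewise linear in the step count. The sketch below mirrors its branches with plain Python numbers instead of tensors so the warmup and ramp boundaries can be checked by hand. ``gate_value`` is a hypothetical name, not part of GeoPrior; inside training the tensor version above is the one to use.

```python
def gate_value(
    step: int,
    policy: str = "always_on",
    warmup_steps: int = 0,
    ramp_steps: int = 0,
) -> float:
    """Hypothetical pure-Python mirror of the policy_gate schedule."""
    pol = (policy or "always_on").strip().lower()
    if pol in ("always_on", "on", "true", "1"):
        return 1.0
    if pol in ("always_off", "off", "false", "0"):
        return 0.0

    # Any other policy (e.g. "warmup_off") uses the warmup/ramp schedule.
    w = int(warmup_steps or 0)
    r = int(ramp_steps or 0)
    if w <= 0 and r <= 0:
        return 1.0  # nothing to wait for: gate fully open
    if r <= 0:
        return 1.0 if step >= w else 0.0  # hard switch at warmup_steps

    # Linear ramp: 0 at step == w, 1 at step == w + r, clipped to [0, 1].
    frac = (step - w) / r
    return min(1.0, max(0.0, frac))


# Warmup of 100 steps followed by a 50-step ramp.
schedule = {
    s: gate_value(s, "warmup_off", warmup_steps=100, ramp_steps=50)
    for s in (0, 99, 100, 125, 150, 400)
}
print(schedule)
```

With ``warmup_steps=100`` and ``ramp_steps=50`` the gate stays at 0 up to step 100, reaches 0.5 at step 125, and saturates at 1 from step 150 onward.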

NAT package exports

geoprior/utils/nat_utils/__init__.py
r"""Public exports for NAT workflow utilities."""

from .nat_utils import (
    build_censor_mask,
    ensure_config_json,
    ensure_input_shapes,
    get_config_paths,
    get_natcom_dir,
    load_nat_config,
    load_nat_config_payload,
    load_scaler_info,
    load_windows_npz,
    make_tf_dataset,
    map_targets_for_training,
    pick_npz_for_dataset,
    resolve_hybrid_config,
    resolve_si_affine,
    sanitize_inputs_np,
)
from .natutils import (
    best_epoch_and_metrics,
    compile_for_eval,
    compile_geoprior_for_eval,
    extract_preds,
    load_best_hps_near_model,
    load_hps_auto_near_model,
    load_or_rebuild_geoprior_model,
    load_trained_hps_near_model,
    load_tuned_hps_near_model,
    name_of,
    save_ablation_record,
    serialize_subs_params,
    subs_point_from_out,
)

__all__ = [
    "build_censor_mask",
    "ensure_input_shapes",
    "extract_preds",
    "load_nat_config",
    "load_nat_config_payload",
    "load_scaler_info",
    "make_tf_dataset",
    "map_targets_for_training",
    "name_of",
    "resolve_hybrid_config",
    "resolve_si_affine",
    "best_epoch_and_metrics",
    "subs_point_from_out",
    "serialize_subs_params",
    "save_ablation_record",
    "load_windows_npz",
    "load_tuned_hps_near_model",
    "load_trained_hps_near_model",
    "sanitize_inputs_np",
    "load_hps_auto_near_model",
    "load_or_rebuild_geoprior_model",
    "compile_for_eval",
    "load_best_hps_near_model",
    "pick_npz_for_dataset",
    "ensure_config_json",
    "get_natcom_dir",
    "get_config_paths",
    "compile_geoprior_for_eval",
]
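``best_epoch_and_metrics``, exported above, picks the epoch with the smallest monitored value while skipping NaNs, then collects every metric recorded at that epoch. The stdlib-only sketch below reproduces that selection on a hand-written ``History.history``-style dict; ``best_epoch`` is a hypothetical name, and it uses ``math.isnan`` where the listed helper uses ``np.nanargmin``.

```python
import math


def best_epoch(history: dict, monitor: str = "val_loss"):
    """Hypothetical pure-Python mirror of best_epoch_and_metrics."""
    if not history or monitor not in history:
        return None, {}
    vals = history[monitor]
    # Index of the smallest non-NaN value (np.nanargmin in the real helper).
    be = min(
        (i for i, v in enumerate(vals) if not math.isnan(v)),
        key=lambda i: vals[i],
    )
    # Keep every metric that has an entry at the chosen epoch.
    return be, {k: float(v[be]) for k, v in history.items() if len(v) > be}


history = {
    "loss": [1.0, 0.6, 0.5, 0.45],
    "val_loss": [0.9, float("nan"), 0.55, 0.7],
    "val_mae": [0.4, 0.35, 0.3, 0.32],
}
print(best_epoch(history))  # (2, {'loss': 0.5, 'val_loss': 0.55, 'val_mae': 0.3})
```

As with ``np.nanargmin``, a monitored series that is entirely NaN raises ``ValueError`` rather than returning ``None``.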

Optional: workflow/NAT helper module
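The ``natutils`` module listed below stores one JSON object per run in ``ablation_records/ablation_record.jsonl`` and reads the file back line by line. A stdlib-only sketch of that append/read pattern follows. It writes to a temporary directory, uses hypothetical helper names (``append_record``, ``read_records``), and returns a list of dicts instead of the ``pandas.DataFrame`` built by ``load_ablation_jsonl``; the record fields are placeholders, not real results.

```python
import json
import os
import tempfile


def append_record(outdir: str, rec: dict) -> str:
    """Append one record as a JSON line (same layout as save_ablation_record)."""
    abl_dir = os.path.join(outdir, "ablation_records")
    os.makedirs(abl_dir, exist_ok=True)
    path = os.path.join(abl_dir, "ablation_record.jsonl")
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")
    return path


def read_records(path: str) -> list:
    """Parse each non-blank line as JSON (same loop as load_ablation_jsonl)."""
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


with tempfile.TemporaryDirectory() as tmp:
    # Two placeholder runs that differ only in a physics weight.
    p = append_record(tmp, {"city": "nansha", "lambda_cons": 0.1, "mae": None})
    append_record(tmp, {"city": "nansha", "lambda_cons": 1.0, "mae": None})
    records = read_records(p)

print(len(records), [r["lambda_cons"] for r in records])  # 2 [0.1, 1.0]
```

Opening the file in append mode is what lets successive runs accumulate in one file without rewriting earlier records.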

geoprior/utils/nat_utils/natutils.py
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
# Website: https://lkouadio.com
r"""NAT evaluation and artifact helpers for GeoPrior."""

from __future__ import annotations

import datetime as dt
import glob
import json
import os
from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np
import pandas as pd

# --- Optional TensorFlow import for GeoPrior helpers -----------------------
try:  # pragma: no cover - defensive import
    import tensorflow as tf  # noqa
    from tensorflow.keras.optimizers import Adam

    TF_AVAILABLE = True
except Exception:  # pragma: no cover
    TF_AVAILABLE = False
    tf = None  # type: ignore[assignment]

    class _AdamStub:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            raise ImportError(
                "TensorFlow is required for NATCOM GeoPrior helpers "
                "(e.g. compile_geoprior_for_eval). Please install "
                "`tensorflow>=2.12`."
            )

    Adam = _AdamStub  # type: ignore[assignment]


def save_ablation_record(
    outdir: str,
    city: str,
    model_name: str,
    cfg: dict,
    eval_dict: dict | None,
    phys_diag: dict | None = None,
    per_h_mae: dict | None = None,
    per_h_r2: dict | None = None,
    log_fn=None,
) -> None:
    """
    Append a single ablation record to ``ablation_record.jsonl``.

    Each training run (e.g., different physics toggles or weights)
    writes one JSON line containing:

    - Basic run identifiers (city, model, timestamp).
    - Physics configuration (``PDE_MODE_CONFIG``, lambda weights,
      effective head flags, etc.).
    - Key performance metrics (R², MSE, MAE, RMSE, coverage80,
      sharpness80), plus the full metrics payload under ``metrics``
      and, when provided, ``units``.
    - Optional physics diagnostics (``epsilon_prior``,
      ``epsilon_cons`` and, if present, ``epsilon_gw``).
    - Optional per-horizon MAE/R² for more detailed analysis.

    Parameters
    ----------
    outdir : str
        Base output directory for the current run. The ablation
        file is created under ``outdir / "ablation_records"``.
    city : str
        City name (e.g., ``"nansha"`` or ``"zhongshan"``).
    model_name : str
        Model identifier (e.g., ``"GeoPriorSubsNet"``).
    cfg : dict
        Lightweight configuration dictionary containing at least
        the physics-related keys used below.
    eval_dict : dict or None
        Dictionary of evaluation metrics (R², MSE, MAE, RMSE,
        coverage80, sharpness80). If ``None``, metrics fields
        default to ``None``.
    phys_diag : dict or None, optional
        Physics diagnostics (e.g., from ``evaluate()``) with keys
        such as ``"epsilon_prior"``, ``"epsilon_cons"`` and
        optionally ``"epsilon_gw"``.
    per_h_mae : dict or None, optional
        Per-horizon MAE values (e.g., keyed by year/step).
    per_h_r2 : dict or None, optional
        Per-horizon R² values.
    log_fn : callable or None, optional
        Function used to report the output path. Defaults to
        :func:`print`.

    Notes
    -----
    The output file is a JSON-Lines file, so it can be loaded
    with :func:`load_ablation_jsonl`.
    """
    if log_fn is None:
        log_fn = print

    metrics = dict(eval_dict or {})

    rec = {
        "timestamp": dt.datetime.now().strftime(
            "%Y%m%d-%H%M%S"
        ),
        "city": city,
        "model": model_name,
        # Physics toggles / weights
        "pde_mode": cfg.get("PDE_MODE_CONFIG"),
        "use_effective_h": bool(
            cfg.get("GEOPRIOR_USE_EFFECTIVE_H", True)
        ),
        "kappa_mode": cfg.get("GEOPRIOR_KAPPA_MODE", "bar"),
        "hd_factor": cfg.get("GEOPRIOR_HD_FACTOR", 0.6),
        "lambda_cons": cfg.get("LAMBDA_CONS"),
        "lambda_gw": cfg.get("LAMBDA_GW"),
        "lambda_prior": cfg.get("LAMBDA_PRIOR"),
        "lambda_smooth": cfg.get("LAMBDA_SMOOTH"),
        "lambda_mv": cfg.get("LAMBDA_MV"),
        "lambda_bounds": cfg.get("LAMBDA_BOUNDS"),
        "lambda_q": cfg.get("LAMBDA_Q"),
        # Key metrics
        "r2": metrics.get("r2"),
        "mse": metrics.get("mse"),
        "mae": metrics.get("mae"),
        "rmse": metrics.get("rmse"),
        "coverage80": metrics.get("coverage80"),
        "sharpness80": metrics.get("sharpness80"),
    }
    # Keep the full metrics payload (post-hoc vs evaluate(), units, etc.)
    rec["metrics"] = metrics

    # Convenience: surface units at top-level if provided.
    if isinstance(metrics.get("units"), dict):
        rec["units"] = metrics.get("units")

    if phys_diag:
        rec["epsilon_prior"] = phys_diag.get("epsilon_prior")
        rec["epsilon_cons"] = phys_diag.get("epsilon_cons")
        if "epsilon_gw" in phys_diag:
            rec["epsilon_gw"] = phys_diag.get("epsilon_gw")

    if per_h_mae is not None:
        rec["per_horizon_mae"] = per_h_mae
    if per_h_r2 is not None:
        rec["per_horizon_r2"] = per_h_r2

    abl_dir = os.path.join(outdir, "ablation_records")
    os.makedirs(abl_dir, exist_ok=True)

    jpath = os.path.join(abl_dir, "ablation_record.jsonl")
    with open(jpath, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")

    log_fn(f"[Ablation] appended -> {jpath}")


def load_ablation_jsonl(path: str) -> pd.DataFrame:
    """
    Load an ablation JSON-Lines file into a :class:`pandas.DataFrame`.

    This is the companion to :func:`save_ablation_record`. Each
    line is parsed as JSON and turned into one row.

    Parameters
    ----------
    path : str
        Path to ``ablation_record.jsonl``.

    Returns
    -------
    pandas.DataFrame
        DataFrame where each row corresponds to one ablation
        record.

    Examples
    --------
    >>> df_abl = load_ablation_jsonl(
    ...     "ablation_records/ablation_record.jsonl"
    ... )
    >>> df_abl.head()
    """
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return pd.DataFrame(rows)


def name_of(obj: object) -> str:
    """
    Return a human-readable name for an object.

    This utility is handy when serialising compile configurations
    (e.g., turning metric callables into simple strings for JSON
    logs).

    Parameters
    ----------
    obj : object
        Any Python object (function, class instance, etc.).

    Returns
    -------
    str
        ``obj.__name__`` if present, otherwise the class name, and
        finally ``str(obj)`` as a last resort.
    """
    if hasattr(obj, "__name__"):
        return obj.__name__
    if hasattr(obj, "__class__"):
        return obj.__class__.__name__
    return str(obj)


def serialize_subs_params(
    params: dict,
    cfg: dict | None = None,
) -> dict:
    """
    Make GeoPrior subnet parameters JSON-friendly.

    The training scripts typically pass a dictionary of model
    construction arguments, e.g. ``subsmodel_params``, which
    contains objects such as ``LearnableMV`` or ``FixedGammaW``
    that are not directly JSON-serialisable.

    This helper replaces those objects by small dictionaries with a
    fixed type label (``LearnableMV``, ``LearnableKappa``,
    ``FixedGammaW`` or ``FixedHRef``) and a scalar value, optionally
    taken from the NATCOM config dictionary.

    Parameters
    ----------
    params : dict
        Dictionary of model init parameters (e.g.
        ``subsmodel_params`` in ``training_NATCOM_GEOPRIOR.py``).
    cfg : dict, optional
        NATCOM config dictionary. If provided, scalar values are
        taken from:

        - ``GEOPRIOR_INIT_MV``
        - ``GEOPRIOR_INIT_KAPPA``
        - ``GEOPRIOR_GAMMA_W``
        - ``GEOPRIOR_H_REF``

        and used as the authoritative numbers.

    Returns
    -------
    dict
        Copy of ``params`` where scalar GeoPrior parameters are
        replaced by JSON-friendly dictionaries.

    Notes
    -----
    This function does **not** import any of the GeoPrior classes.
    It only introspects attributes like ``initial_value`` or
    ``value`` when the corresponding config entry is missing or not
    numeric; if neither source yields a number, the value is ``None``.
    """
    out = dict(params)
    cfg = cfg or {}

    # Helper to extract a scalar from either the config or the
    # original object (Learnable*/Fixed*).
    def _extract_scalar(obj, cfg_key: str) -> float | None:
        if cfg_key in cfg and cfg[cfg_key] is not None:
            try:
                return float(cfg[cfg_key])
            except Exception:
                pass
        # Fallback: try to read a typical attribute name.
        for attr in ("initial_value", "value"):
            if hasattr(obj, attr):
                try:
                    return float(getattr(obj, attr))
                except Exception:
                    continue
        return None

    if "mv" in out:
        mv_val = _extract_scalar(
            out["mv"], "GEOPRIOR_INIT_MV"
        )
        out["mv"] = {
            "type": "LearnableMV",
            "initial_value": mv_val,
        }

    if "kappa" in out:
        kap_val = _extract_scalar(
            out["kappa"], "GEOPRIOR_INIT_KAPPA"
        )
        out["kappa"] = {
            "type": "LearnableKappa",
            "initial_value": kap_val,
        }

    if "gamma_w" in out:
        gw_val = _extract_scalar(
            out["gamma_w"], "GEOPRIOR_GAMMA_W"
        )
        out["gamma_w"] = {
            "type": "FixedGammaW",
            "value": gw_val,
        }

    if "h_ref" in out:
        href_val = _extract_scalar(
            out["h_ref"], "GEOPRIOR_H_REF"
        )
        out["h_ref"] = {
            "type": "FixedHRef",
            "value": href_val,
        }

    return out


def best_epoch_and_metrics(
    history: dict,
    monitor: str = "val_loss",
) -> tuple[int | None, dict]:
    """
    Return the best epoch and metrics at that epoch.

    Given a ``History.history`` dictionary produced by
    ``model.fit(...)``, this helper identifies the index of the
    minimum non-NaN value for the monitored quantity (by default
    ``"val_loss"``) and returns:

    - The epoch index (0-based).
    - A dictionary mapping each metric name to its value at that
      epoch.

    Parameters
    ----------
    history : dict
        The ``history.history`` attribute from Keras training.
    monitor : str, default="val_loss"
        Name of the metric to minimise.

    Returns
    -------
    best_epoch : int or None
        Index of the best epoch, or ``None`` if ``monitor`` is
        not present.
    metrics_at_best : dict
        Mapping from metric name to its value at the best epoch.
        Empty if ``monitor`` is not present.
    """
    if not history or monitor not in history:
        return None, {}

    # nanargmin makes sure NaNs are ignored when searching for the
    # best epoch.
    be = int(np.nanargmin(history[monitor]))
    metrics_at_best = {
        k: float(v[be])
        for k, v in history.items()
        if len(v) > be
    }
    return be, metrics_at_best


def load_or_rebuild_geoprior_model(
    model_path: str,
    manifest: dict,
    X_sample: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
    city_name: str | None = None,
    compile_on_load: bool = True,
    verbose: int = 1,
):
    """
    Load a tuned *or trained* GeoPriorSubsNet, with robust rebuild fallback.

    Strategy
    --------
    1. Try ``tf.keras.models.load_model(model_path)`` with all required
       custom objects registered.

    2. If that fails:

       2a) Try tuned-model reconstruction::

              best_hps = load_best_hps_near_model(...)
              model = build_geoprior_from_hps(...)

       2b) If no best_hps JSON is found, assume a plain *trained* model
           and fall back to the ``*_training_summary.json`` recorded next
           to the checkpoint::

              training_summary = load_training_summary_near_model(...)
              model = build_geoprior_from_training_summary(...)

       In both cases, a minimal ``best_hps``-like dict is returned so
       :func:`compile_for_eval` can recreate the physics weights and
       learning rate.

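    3. For a reconstructed model, try to load the best weights
       checkpoint located by :func:`infer_best_weights_path` (called
       with ``model_path``). If no checkpoint is found or loading
       fails, the model keeps freshly initialised weights and a
       warning is printed when ``verbose`` is set.

    Parameters
    ----------
    model_path : str
        Path to the saved model.
    manifest : dict
        Stage-1 manifest dictionary, forwarded to the rebuild helpers.
    X_sample : dict
        One NPZ inputs dictionary already passed through
        :func:`ensure_input_shapes`; forwarded to the rebuild helpers.
    out_s_dim, out_g_dim, mode, horizon, quantiles :
        Same semantics as in :func:`build_geoprior_from_hps`.
    city_name : str, optional
        Label used in log messages and when locating
        ``*_training_summary.json``.
    compile_on_load : bool, default=True
        Forwarded as ``compile`` to ``load_model``.
    verbose : int, default=1
        If non-zero, print progress and fallback messages.
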
    Returns
    -------
    model :
        A GeoPriorSubsNet instance ready to be recompiled for evaluation.

    best_hps : dict or None
        Tuned hyperparameters if present, otherwise a small dict
        containing at least ``learning_rate`` and the lambda weights
        from the training summary, or ``None``.
    """
    label_city = city_name or "GeoPrior"

    # --- Lazy imports so nat_utils can be imported without TF/geoprior ---
    try:
        import tensorflow as tf  # type: ignore # noqa
        from tensorflow.keras.models import load_model  # type: ignore
        from tensorflow.keras.utils import custom_object_scope  # type: ignore
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires TensorFlow. "
            "Please install 'tensorflow>=2.12' to use this helper."
        ) from e

    try:
        from geoprior.nn.keras_metrics import (  # type: ignore
            coverage80_fn,
            sharpness80_fn,
        )
        from geoprior.nn.losses import (
            make_weighted_pinball,  # type: ignore
        )
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
        from geoprior.params import (  # type: ignore
            FixedGammaW,
            FixedHRef,
            LearnableKappa,
            LearnableMV,
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires geoprior components "
            "(GeoPriorSubsNet, LearnableMV, etc.). Ensure geoprior is "
            "installed and importable."
        ) from e

    custom_objects = {
        "GeoPriorSubsNet": GeoPriorSubsNet,
        "LearnableMV": LearnableMV,
        "LearnableKappa": LearnableKappa,
        "FixedGammaW": FixedGammaW,
        "FixedHRef": FixedHRef,
        "make_weighted_pinball": make_weighted_pinball,
        "coverage80_fn": coverage80_fn,
        "sharpness80_fn": sharpness80_fn,
    }

    best_hps: dict | None = None

    # ------------------- 1) Try direct load_model -------------------------
    with custom_object_scope(custom_objects):
        if verbose:
            print(
                f"[Model] Attempting to load model from: {model_path}"
            )

        try:
            model = load_model(
                model_path, compile=compile_on_load
            )
            if verbose:
                print(
                    f"[Model] Successfully loaded model for {label_city} "
                    f"from: {model_path}"
                )
            return model, best_hps
        except Exception as e_load:
            if verbose:
                print(
                    f"[Warn] load_model('{model_path}') failed: {e_load}\n"
                    "[Warn] Falling back to config-based reconstruction."
                )

    # ------------------- 2) Fallback: tuned HPs OR training summary -------
    try:
        # 2a) Tuned model path
        best_hps = load_best_hps_near_model(model_path)
        model = build_geoprior_from_hps(
            manifest=manifest,
            X_sample=X_sample,
            best_hps=best_hps,
            out_s_dim=out_s_dim,
            out_g_dim=out_g_dim,
            mode=mode,
            horizon=horizon,
            quantiles=quantiles,
        )
    except Exception as e_hps:
        # 2b) Tuned rebuild failed (e.g. no best_hps JSON) -> treat this
        #     as a plain trained model.
        if verbose:
            print(
                "[Fallback] Could not rebuild from tuned hyperparameters "
                "(no best_hps JSON or rebuild error); "
                "assuming a plain trained model.\n"
                f"          Reason: {e_hps}"
            )

        training_summary = load_training_summary_near_model(
            model_path, city_name=city_name
        )
        if training_summary is None:
            raise RuntimeError(
                "Failed to reconstruct GeoPriorSubsNet: neither tuned "
                "hyperparameters nor *_training_summary.json were found "
                f"near model_path={model_path!r}."
            ) from e_hps

        model = build_geoprior_from_training_summary(
            manifest=manifest,
            X_sample=X_sample,
            training_summary=training_summary,
            out_s_dim=out_s_dim,
            out_g_dim=out_g_dim,
            mode=mode,
            horizon=horizon,
            quantiles=quantiles,
        )

        # Build a minimal best_hps dict so compile_for_eval can recover the
        # training-time physics weights and learning rate.
        compile_block = (
            training_summary.get("compile", {}) or {}
        )
        phys = (
            compile_block.get("physics_loss_weights", {})
            or {}
        )
        lr = compile_block.get("learning_rate", None)

        hps_from_train: dict[str, float] = {}
        if lr is not None:
            try:
                hps_from_train["learning_rate"] = float(lr)
            except Exception:
                pass
        for k, v in phys.items():
            try:
                hps_from_train[k] = float(v)
            except Exception:
                continue

        best_hps = hps_from_train or None

    # ------------------- 3) Load weights if checkpoint exists -------------
    weights_path = infer_best_weights_path(model_path)
    if weights_path is not None:
        try:
            model.load_weights(weights_path)
            if verbose:
                print(
                    "[Fallback] Loaded weights into reconstructed "
                    f"GeoPriorSubsNet from: {weights_path}"
                )
        except Exception as e_w:
            if verbose:
                print(
                    "[Warn] Could not load weights from checkpoint:\n"
                    f"       {weights_path}\n"
                    f"       Error: {e_w}\n"
                    "       The rebuilt model is using freshly-initialised "
                    "weights. Predictions will NOT match the original run."
                )
    else:
        if verbose:
            print(
                "[Warn] No weights checkpoint found near model.\n"
                "       Using rebuilt model with freshly-initialised "
                "weights. Predictions will NOT match the original run."
            )

    return model, best_hps


def build_geoprior_from_training_summary(
    manifest: dict,
    X_sample: dict,
    training_summary: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
) -> Any:
    """
    Reconstruct a GeoPriorSubsNet from a training_summary JSON.

    This is the fallback path for plain *trained* models (no tuning),
    using the architecture recorded under ``hp_init['model_init_params']``.

    Parameters
    ----------
    manifest : dict
        Stage-1 manifest; ``manifest['config']`` supplies fallback
        defaults for values missing from the training summary.

    X_sample : dict
        One NPZ inputs dictionary already passed through
        :func:`ensure_input_shapes`. Only shapes are used.

    training_summary : dict
        Parsed ``*_training_summary.json`` for this run.

    out_s_dim, out_g_dim, mode, horizon, quantiles :
        Same semantics as in :func:`build_geoprior_from_hps`.

    Returns
    -------
    model : GeoPriorSubsNet
        Reconstructed model (uncompiled).
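
    Examples
    --------
    A minimal standalone sketch of the lookup precedence used above
    (``model_init_params`` first, then ``hp_init``, then
    ``manifest['config']``, then a hard-coded default) and of how
    unrecognised ``model_init_params`` keys are passed through in
    ``architecture_config``. The payloads, the ``resolve`` helper, the
    trimmed ``known_keys`` set, and the ``"activation"`` key are
    hypothetical; only the lookup pattern follows this function::

        ```python
        # Hypothetical payloads; real ones come from the JSON artifacts.
        model_init = {"embed_dim": 64, "activation": "gelu"}
        hp_init = {"attention_levels": ["cross"], "use_vsn": False}
        cfg = {"USE_BATCH_NORM": False, "ATTENTION_LEVELS": ["memory"]}


        def resolve(key, cfg_key, default):
            # model_init -> hp_init -> manifest config -> default
            return model_init.get(
                key, hp_init.get(key, cfg.get(cfg_key, default))
            )


        attention_levels = resolve(
            "attention_levels",
            "ATTENTION_LEVELS",
            ["cross", "hierarchical", "memory"],
        )
        use_batch_norm = bool(
            resolve("use_batch_norm", "USE_BATCH_NORM", True)
        )
        embed_dim = int(model_init.get("embed_dim", 32))

        # Keys consumed explicitly by the builder (trimmed for brevity).
        known_keys = {"embed_dim", "use_vsn", "use_batch_norm"}
        architecture_config = {
            k: v for k, v in model_init.items() if k not in known_keys
        }
        print(attention_levels, use_batch_norm, embed_dim)  # ['cross'] False 64
        print(architecture_config)  # {'activation': 'gelu'}
        ```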
    """
    try:
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "build_geoprior_from_training_summary requires "
            "'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
            "Ensure geoprior is installed and importable."
        ) from e

    cfg = manifest.get("config", {}) or {}

    # Infer input dims from the NPZ sample
    static_dim, dynamic_dim, future_dim = (
        infer_input_dims_from_X(X_sample)
    )

    hp_init = training_summary.get("hp_init", {}) or {}
    model_init = hp_init.get("model_init_params", {}) or {}

    # Quantiles: prefer explicit argument, then the training summary
    q = quantiles or hp_init.get("quantiles")

    # Attention stack
    attention_levels = model_init.get(
        "attention_levels",
        hp_init.get(
            "attention_levels",
            cfg.get(
                "ATTENTION_LEVELS",
                ["cross", "hierarchical", "memory"],
            ),
        ),
    )

    # Physics toggles
    censor_cfg = cfg.get("censoring", {}) or {}
    use_effective_h = bool(
        model_init.get(
            "use_effective_h",
            hp_init.get(
                "use_effective_h",
                censor_cfg.get("use_effective_h_field", True),
            ),
        )
    )
    pde_mode = hp_init.get(
        "pde_mode", cfg.get("PDE_MODE_CONFIG", "both")
    )
    kappa_mode = model_init.get(
        "kappa_mode",
        cfg.get("GEOPRIOR_KAPPA_MODE", "bar"),
    )
    scale_pde_residuals = bool(
        model_init.get("scale_pde_residuals", True)
    )

    # Architecture hyperparameters (as used at training time)
    embed_dim = int(model_init.get("embed_dim", 32))
    hidden_units = int(model_init.get("hidden_units", 96))
    lstm_units = int(model_init.get("lstm_units", 96))
    attention_units = int(
        model_init.get("attention_units", 32)
    )
    num_heads = int(model_init.get("num_heads", 4))
    dropout_rate = float(model_init.get("dropout_rate", 0.1))

    use_vsn = bool(
        model_init.get(
            "use_vsn", hp_init.get("use_vsn", True)
        )
    )
    vsn_units = int(
        model_init.get(
            "vsn_units", hp_init.get("vsn_units", 32)
        )
    )
    use_batch_norm = bool(
        model_init.get(
            "use_batch_norm",
            hp_init.get(
                "use_batch_norm",
                cfg.get("USE_BATCH_NORM", True),
            ),
        )
    )

    # Geomechanical parameters were serialised via `serialize_subs_params`
    # so we need to extract scalar initial values again.
    def _extract_initial(
        spec: Any, cfg_key: str, cfg_default: float
    ) -> float:
        if isinstance(spec, dict):
            if "initial_value" in spec:
                try:
                    return float(spec["initial_value"])
                except Exception:
                    pass
            if "value" in spec:
                try:
                    return float(spec["value"])
                except Exception:
                    pass
        if cfg_key in cfg and cfg[cfg_key] is not None:
            try:
                return float(cfg[cfg_key])
            except Exception:
                pass
        return float(cfg_default)

    mv_spec = model_init.get("mv", {})
    kappa_spec = model_init.get("kappa", {})

    mv_init = _extract_initial(
        mv_spec, "GEOPRIOR_INIT_MV", 5e-7
    )
    kappa_init = _extract_initial(
        kappa_spec, "GEOPRIOR_INIT_KAPPA", 1.0
    )

    # Pack the remaining architectural knobs into `architecture_config`
    known_keys = {
        "embed_dim",
        "hidden_units",
        "lstm_units",
        "attention_units",
        "num_heads",
        "dropout_rate",
        "use_vsn",
        "vsn_units",
        "use_batch_norm",
        "mv",
        "kappa",
        "gamma_w",
        "h_ref",
        "kappa_mode",
        "use_effective_h",
        "scale_pde_residuals",
        "attention_levels",
        "mode",
        "time_steps",
    }
    architecture_config = {
        k: v
        for k, v in model_init.items()
        if k not in known_keys
    }

    model = GeoPriorSubsNet(
        static_input_dim=static_dim,
        dynamic_input_dim=dynamic_dim,
        future_input_dim=future_dim,
        output_subsidence_dim=out_s_dim,
        output_gwl_dim=out_g_dim,
        forecast_horizon=horizon,
        mode=mode,
        attention_levels=attention_levels,
        quantiles=q,
        # physics switches
        pde_mode=pde_mode,
        scale_pde_residuals=scale_pde_residuals,
        kappa_mode=kappa_mode,
        use_effective_h=use_effective_h,
        # architecture hyperparameters
        embed_dim=embed_dim,
        hidden_units=hidden_units,
        lstm_units=lstm_units,
        attention_units=attention_units,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        use_vsn=use_vsn,
        vsn_units=vsn_units,
        use_batch_norm=use_batch_norm,
        # geomechanical priors
        mv=float(mv_init),
        kappa=float(kappa_init),
        architecture_config=architecture_config,
    )

    print(
        "[Fallback] Reconstructed GeoPriorSubsNet from training_summary with "
        f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
        f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
    )
    return model


def load_geoprior_for_inference(
    model_path: str,
    manifest: dict,
    X_sample: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
    city_name: str | None = None,
    include_metrics: bool = True,
    verbose: int = 1,
):
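    """
    Load a (tuned or trained) GeoPriorSubsNet and compile it for
    evaluation/inference.

    The model is obtained via :func:`load_or_rebuild_geoprior_model`
    with ``compile_on_load=False`` and then recompiled with
    :func:`compile_for_eval`. All parameters are forwarded to those
    two helpers unchanged.

    Returns
    -------
    model :
        Compiled model ready for ``predict`` / diagnostics.

    info : dict
        Small dict with the ``best_hps`` (if any) and the
        ``quantiles`` passed in, useful for logging.
    """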
    model, best_hps = load_or_rebuild_geoprior_model(
        model_path=model_path,
        manifest=manifest,
        X_sample=X_sample,
        out_s_dim=out_s_dim,
        out_g_dim=out_g_dim,
        mode=mode,
        horizon=horizon,
        quantiles=quantiles,
        city_name=city_name,
        compile_on_load=False,
        verbose=verbose,
    )

    model = compile_for_eval(
        model=model,
        manifest=manifest,
        best_hps=best_hps,
        quantiles=quantiles,
        include_metrics=include_metrics,
    )

    info = {
        "best_hps": best_hps,
        "quantiles": quantiles,
    }
    return model, info


def load_training_summary_near_model(
    model_path: str,
    city_name: str | None = None,
) -> dict | None:
    """
    Locate and load a ``*_training_summary.json`` next to a trained model.

    Strategy
    --------
    1. Prefer ``<city>_GeoPriorSubsNet_training_summary.json`` if
       ``city_name`` is given.
    2. Fallback: any file in the run directory that ends with
       ``'_training_summary.json'``, in ``os.listdir`` order (which is
       not guaranteed to be sorted). The first readable file wins.

    Parameters
    ----------
    model_path : str
        Path to the ``.keras`` archive or any checkpoint inside a
        ``train_YYYYMMDD-HHMMSS`` directory.

    city_name : str or None
        Optional city name to build a more specific candidate filename.

    Returns
    -------
    dict or None
        Parsed JSON dict if found and loadable, otherwise ``None``.
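
    Examples
    --------
    A minimal standalone sketch of the lookup order, using a temporary
    run directory. ``find_training_summary`` and the city name
    ``"nansha"`` are hypothetical; unlike the function above, the
    fallback candidates are sorted so the result is deterministic
    (``os.listdir`` order is arbitrary)::

        ```python
        import json
        import os
        import tempfile


        def find_training_summary(run_dir, city_name=None):
            # City-specific file first, then any *_training_summary.json.
            candidates = []
            if city_name:
                candidates.append(
                    os.path.join(
                        run_dir,
                        f"{city_name}_GeoPriorSubsNet_training_summary.json",
                    )
                )
            for fname in sorted(os.listdir(run_dir)):
                if fname.endswith("_training_summary.json"):
                    candidates.append(os.path.join(run_dir, fname))
            # dict.fromkeys drops duplicates while keeping order.
            for path in dict.fromkeys(candidates):
                if os.path.exists(path):
                    with open(path, encoding="utf-8") as f:
                        return json.load(f)
            return None


        with tempfile.TemporaryDirectory() as run_dir:
            files = {
                "a_training_summary.json": "generic",
                "nansha_GeoPriorSubsNet_training_summary.json": "city",
            }
            for name, tag in files.items():
                with open(
                    os.path.join(run_dir, name), "w", encoding="utf-8"
                ) as f:
                    json.dump({"tag": tag}, f)
            city_tag = find_training_summary(run_dir, "nansha")["tag"]
            generic_tag = find_training_summary(run_dir)["tag"]

        print(city_tag, generic_tag)  # city generic
        ```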
    """
    run_dir = os.path.dirname(os.path.abspath(model_path))
    candidates: list[str] = []

    if city_name:
        candidates.append(
            os.path.join(
                run_dir,
                f"{city_name}_GeoPriorSubsNet_training_summary.json",
            )
        )

    # Generic fallback: any *_training_summary.json in the run dir
    try:
        for fname in os.listdir(run_dir):
            if fname.endswith("_training_summary.json"):
                candidates.append(
                    os.path.join(run_dir, fname)
                )
    except FileNotFoundError:
        return None

    seen: set[str] = set()
    for path in candidates:
        if path in seen:
            continue
        seen.add(path)
        if os.path.exists(path):
            try:
                with open(path, encoding="utf-8") as f:
                    ts = json.load(f)
                print(
                    f"[TrainSummary] Loaded training_summary from: {path}"
                )
                return ts
            except Exception as e:  # pragma: no cover - defensive
                print(
                    f"[Warn] Could not read training_summary JSON at {path!r}: {e}"
                )
                # Try the next candidate
                continue

    return None


def extract_preds(
    model: Any,
    out: Any,
    *,
    strict: bool = True,
    output_names: Sequence[str] | None = None,
) -> tuple[Any, Any]:
    r"""
    Extract ``(subs_pred, gwl_pred)`` from GeoPrior outputs.

    This helper normalizes the output interface across GeoPrior
    model generations and call styles. Supported forms:

    1. v3.2+ ``call()``: a mapping with keys ``"subs_pred"`` and
       ``"gwl_pred"`` (preferred).
    2. ``forward_with_aux()``: a ``(y_pred, aux)`` tuple whose
       ``y_pred`` (or, failing that, ``aux``) is a supported mapping.
    3. Legacy: ``{"data_final": ...}``, split with
       ``model.split_data_predictions``.
    4. ``predict()``: a list/tuple mapped to names via
       ``output_names`` or ``model.output_names``.

    A single-key mapping wrapper is unwrapped one level and retried.
    If ``strict=True``, list/tuple outputs *must* be mappable via
    output names; otherwise a ``TypeError`` is raised to avoid
    silently swapping subsidence and groundwater predictions.

    Parameters
    ----------
    model : object
        A Keras-like model instance that may expose
        ``split_data_predictions(data_final)``.

        The splitter must return a tuple:

        - ``subs_pred`` with shape ``(B, H, 1)`` or ``(B, H, Q, 1)``
        - ``gwl_pred``  with shape ``(B, H, 1)`` or ``(B, H, Q, 1)``

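    out : Mapping, tuple, or list
        Output returned by the model, typically
        ``model(inputs, training=False)`` or ``model.predict(...)``.

        Supported forms are:

        - ``{"subs_pred", "gwl_pred"}`` (new interface),
        - ``{"data_final"}`` (legacy interface),
        - a ``(y_pred, aux)`` tuple, or
        - a list/tuple of outputs mappable to output names.

    strict : bool, default=True
        If True, refuse to guess the order of unnamed list/tuple
        outputs. If False, fall back to ``(out[0], out[1])``.

    output_names : sequence of str or None, default=None
        Names used to map list/tuple outputs. Defaults to
        ``model.output_names`` when available.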

    Returns
    -------
    subs_pred : Tensor
        Predicted subsidence in model space.

        Expected shapes:

        - Point mode: ``(B, H, 1)``
        - Quantile mode: ``(B, H, Q, 1)``

    gwl_pred : Tensor
        Predicted groundwater/head variable in model space.

        Expected shapes:

        - Point mode: ``(B, H, 1)``
        - Quantile mode: ``(B, H, Q, 1)``

    Raises
    ------
    KeyError
        If a mapping output does not contain a supported key set.
    TypeError
        If ``out`` has an unsupported type, or is a list/tuple that
        cannot be mapped to output names while ``strict=True``.

    Notes
    -----
    This function is intended for Stage-2 and Stage-3
    scripts where you may load checkpoints from older
    experiments. It avoids fragile code that slices
    ``data_final`` manually.

    The function does not validate tensor dtypes or
    numerical finiteness. Upstream code should handle
    ``NaN`` and ``Inf`` checks as needed. Output normalization
    follows the Keras model conventions documented in
    :cite:t:`KerasDocs`.
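
    The list/tuple branch matters because ``predict()`` outputs carry
    no keys. A minimal standalone sketch of name-based mapping, with a
    hypothetical ``map_outputs_by_name`` helper and placeholder strings
    instead of tensors, shows why names (not positions) decide which
    output is subsidence::

        ```python
        def map_outputs_by_name(out, names):
            # Refuse to guess when names cannot cover every output.
            if not names or len(names) != len(out):
                raise TypeError("cannot map outputs to names")
            mapped = dict(zip(names, out))
            return mapped["subs_pred"], mapped["gwl_pred"]


        # Outputs arrive in (gwl, subs) order; the names disambiguate.
        subs, gwl = map_outputs_by_name(
            ["H", "S"], ["gwl_pred", "subs_pred"]
        )
        print(subs, gwl)  # S H
        ```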

    Examples
    --------
    New interface::

        out = model_inf(xb, training=False)
        # out == {"subs_pred": ..., "gwl_pred": ...}
        s_pred, h_pred = extract_preds(model_inf, out)

    Legacy interface::

        out = model_inf(xb, training=False)
        # out == {"data_final": ...}; split internally via
        # model_inf.split_data_predictions
        s_pred, h_pred = extract_preds(model_inf, out)

    See Also
    --------
    subs_point_from_out :
        Convert subsidence predictions to a point forecast.
    """
    # ---------------------------------------------------------
    # 0) forward_with_aux() style: (y_pred, aux)
    # ---------------------------------------------------------
    if isinstance(out, tuple) and len(out) == 2:
        y_pred, aux = out
        if isinstance(y_pred, Mapping):
            out = y_pred
        elif isinstance(aux, Mapping):
            # fallback: sometimes callers pass aux by mistake
            out = aux

    # ---------------------------------------------------------
    # 1) Mapping outputs
    # ---------------------------------------------------------
    if isinstance(out, Mapping):
        has_new = ("subs_pred" in out) and ("gwl_pred" in out)
        if has_new:
            return out["subs_pred"], out["gwl_pred"]

        # Legacy: data_final -> split
        if "data_final" in out and hasattr(
            model, "split_data_predictions"
        ):
            return model.split_data_predictions(
                out["data_final"]
            )

        # Single-key wrapper: unwrap one level and retry
        if len(out) == 1:
            only_val = next(iter(out.values()))
            return extract_preds(
                model,
                only_val,
                strict=strict,
                output_names=output_names,
            )

        raise KeyError(
            "Unsupported model output keys. Expected "
            "{'subs_pred','gwl_pred'} or {'data_final'} "
            "or a single-key wrapper. "
            f"Got keys={list(out.keys())!r}."
        )

    # ---------------------------------------------------------
    # 2) predict() outputs as list/tuple
    # ---------------------------------------------------------
    if isinstance(out, list | tuple):
        names = None
        if output_names is not None:
            names = list(output_names)
        else:
            names = getattr(model, "output_names", None)

        if names and len(names) == len(out):
            mapped = dict(zip(names, out, strict=False))
            return extract_preds(
                model,
                mapped,
                strict=strict,
                output_names=names,
            )

        if not strict and len(out) >= 2:
            # last-resort, opt-in only
            return out[0], out[1]

        raise TypeError(
            "Model output is a list/tuple but cannot be mapped "
            "to names. Provide `output_names=...` or set "
            "`strict=False` to assume order."
        )

    raise TypeError(
        "Expected `out` as Mapping, (y_pred, aux), "
        "or list/tuple. "
        f"Got type={type(out)!r}."
    )


def subs_point_from_out(
    model, out, quantiles=None, med_idx=None
):
    r"""
    Convert model output into a subsidence point forecast.

    This helper produces a subsidence tensor shaped ``(B, H, 1)``
    in model space, regardless of whether the model emits
    quantiles or a point prediction.

    - If quantiles are present and the subsidence prediction
      is shaped ``(B, H, Q, 1)``, the function selects the
      median quantile slice.
    - Otherwise, it returns the point prediction directly.

    Parameters
    ----------
    model : object
        A Keras-like model instance passed to
        :func:`extract_preds`.

    out : Mapping, tuple, or list
        Output returned by the model call, in any form accepted by
        :func:`extract_preds` (for example the new interface with keys
        ``"subs_pred"`` and ``"gwl_pred"``, or the legacy interface
        with key ``"data_final"``).

    quantiles : sequence of float or None, default=None
        Quantile levels used by the model, such as
        ``[0.1, 0.5, 0.9]``.

        If provided, the function may use it to interpret
        the rank-4 quantile output and select the median.

        If ``None`` and the subsidence tensor is rank 4, ``med_idx``
        must be provided explicitly; otherwise a ``ValueError`` is
        raised.

    med_idx : int or None, default=None
        Index along the quantile axis to use as the
        "point" forecast when quantiles are available.

        If ``None`` and ``quantiles`` is provided, the
        function selects the index closest to ``0.5``.

    Returns
    -------
    subs_point : Tensor
        Subsidence point prediction in model space with
        shape ``(B, H, 1)``.

    Raises
    ------
    ValueError
        If subsidence prediction is missing or ``None``.
    ValueError
        If a quantile tensor is detected but a valid
        median index cannot be resolved.

    Notes
    -----
    Quantile outputs are assumed to be shaped
    ``(B, H, Q, 1)`` where the quantile axis is the
    third dimension (axis=2).

    If the model returns point predictions already,
    the function is effectively a no-op. The quantile
    interpretation used here follows
    :cite:t:`KoenkerBassett1978`.
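
    A minimal NumPy sketch of the median selection, using a fake
    quantile tensor whose value along the quantile axis equals the
    quantile index (the data are illustrative only)::

        ```python
        import numpy as np

        quantiles = [0.1, 0.5, 0.9]
        B, H, Q = 2, 3, len(quantiles)

        # Fake rank-4 output shaped (B, H, Q, 1).
        subs_pred = np.broadcast_to(
            np.arange(Q, dtype=float).reshape(1, 1, Q, 1), (B, H, Q, 1)
        )

        # Index of the quantile closest to 0.5, then slice axis 2.
        q = np.asarray(quantiles, dtype=float)
        med_idx = int(np.argmin(np.abs(q - 0.5)))
        subs_point = subs_pred[:, :, med_idx, :]

        print(med_idx, subs_point.shape)  # 1 (2, 3, 1)
        ```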

    Examples
    --------
    Quantile model::

        out = model_inf(xb, training=False)
        s_point = subs_point_from_out(
            model_inf,
            out,
            quantiles=[0.1, 0.5, 0.9],
        )

    Point model::

        out = model_inf(xb, training=False)
        s_point = subs_point_from_out(
            model_inf,
            out,
        )

    See Also
    --------
    extract_preds :
        Normalize outputs across new and legacy checkpoints.
    """
    subs_pred, _ = extract_preds(model, out)

    if subs_pred is None:
        raise ValueError("Model output 'subs_pred' is None.")

    # Determine the tensor rank: TF tensors expose ``shape.rank``;
    # NumPy arrays expose a plain tuple shape.
    rank = None
    if hasattr(subs_pred, "shape"):
        rank = getattr(subs_pred.shape, "rank", None)
        if rank is None:
            try:
                rank = len(subs_pred.shape)
            except Exception:
                rank = None

    is_quantile_tensor = rank == 4

    if not is_quantile_tensor:
        return subs_pred

    if med_idx is None:
        if not quantiles:
            raise ValueError(
                "Quantile tensor detected but `med_idx` "
                "is None and `quantiles` is not provided."
            )

        q = np.asarray(quantiles, dtype=float)
        med_idx = int(np.argmin(np.abs(q - 0.5)))

    if med_idx is None or int(med_idx) < 0:
        raise ValueError(
            "Invalid `med_idx` resolved for quantiles."
        )

    # Quantile outputs are assumed to be shaped (B, H, Q, 1).
    return subs_pred[:, :, int(med_idx), :]


def _extract_allowed_hps(
    obj: object,
    *,
    allowed: set[str],
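) -> dict:
    """
    Recursively collect hyperparameters named in ``allowed``.

    ``obj`` is walked depth-first through nested dicts, lists, and
    tuples. Every string key found in ``allowed`` is recorded; when
    the same key occurs more than once, the occurrence visited last
    wins.

    Parameters
    ----------
    obj : object
        Parsed JSON payload (typically a dict).
    allowed : set of str
        Hyperparameter names to keep.

    Returns
    -------
    dict
        Mapping restricted to the keys of ``allowed`` that were found.
    """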
    out: dict = {}

    def _walk(x: object) -> None:
        if isinstance(x, dict):
            for k, v in x.items():
                if isinstance(k, str) and k in allowed:
                    out[k] = v
                _walk(v)
        elif isinstance(x, list | tuple):
            for it in x:
                _walk(it)

    _walk(obj)
    return {k: out[k] for k in allowed if k in out}


def load_tuned_hps_near_model(
    model_path: str,
    *,
    prefer: str = "keras",
    required: bool = True,
    log_fn=None,
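) -> dict:
    """
    Load tuned hyperparameters stored next to a model artifact.

    Candidates are tried in order: ``<stem>_best_hps.json`` (the stem
    is derived from the ``.keras`` or ``.weights.h5`` file name,
    depending on ``prefer``), then ``tuning_summary.json``, then any
    ``*tuning_summary*.json`` in the run directory. The first
    candidate that yields a non-empty hyperparameter mapping is
    returned.

    Parameters
    ----------
    model_path : str
        Path to a model file or its run directory.
    prefer : {"keras", "weights"}, default="keras"
        Artifact suffix used to derive the filename stem.
    required : bool, default=True
        If True, raise when nothing is found; otherwise return ``{}``.
    log_fn : callable or None, default=None
        Optional logger such as ``print``.

    Returns
    -------
    dict
        Tuned hyperparameters, or ``{}`` when none are found and
        ``required=False``.

    Raises
    ------
    FileNotFoundError
        If no candidate yields hyperparameters and ``required=True``.
    """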
    log = log_fn if callable(log_fn) else None

    def _msg(s: str) -> None:
        if log is not None:
            log(s)

    def _load_json(p: str) -> dict:
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    mp = os.path.abspath(str(model_path))
    run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
    base = "" if os.path.isdir(mp) else os.path.basename(mp)

    stem = None
    if prefer == "keras":
        if base.endswith("_best.keras"):
            stem = base[: -len("_best.keras")]
        elif base.endswith(".keras"):
            stem = base[: -len(".keras")]
    else:
        if base.endswith("_best.weights.h5"):
            stem = base[: -len("_best.weights.h5")]
        elif base.endswith(".weights.h5"):
            stem = base[: -len(".weights.h5")]

    cands: list[str] = []
    if stem:
        cands.append(
            os.path.join(run_dir, stem + "_best_hps.json")
        )
    cands.append(os.path.join(run_dir, "tuning_summary.json"))
    cands.extend(
        glob.glob(
            os.path.join(run_dir, "*tuning_summary*.json")
        )
    )

    for p in cands:
        if not os.path.exists(p):
            continue
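        data = _load_json(p)
        hps = data.get("best_hps") or data.get("hps") or {}
        # ``*_best_hps.json`` files may store the mapping at top level
        # (``load_best_hps_near_model`` reads them that way).
        if not hps and p.endswith("_best_hps.json"):
            hps = data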
        if isinstance(hps, dict) and hps:
            _msg(f"[HP] tuned: {p}")
            return hps

    if required:
        raise FileNotFoundError(
            "No tuned hyperparameters found near:\n"
            f"  model_path={model_path!r}\n"
            f"  run_dir={run_dir!r}\n"
        )
    return {}


def load_trained_hps_near_model(
    model_path: str,
    *,
    allowed: set[str],
    required: bool = False,
    log_fn=None,
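) -> dict:
    """
    Load training-time hyperparameters stored next to a model artifact.

    The run directory is searched for ``model_init_manifest.json``,
    then ``*training_summary*.json``, then ``*architecture*.json``.
    The keys in ``allowed`` found (at any nesting depth) in the first
    file containing at least one of them are returned.

    Parameters
    ----------
    model_path : str
        Path to a model file or its run directory.
    allowed : set of str
        Hyperparameter names to extract.
    required : bool, default=False
        If True, raise when nothing is found; otherwise return ``{}``.
    log_fn : callable or None, default=None
        Optional logger such as ``print``.

    Returns
    -------
    dict
        Extracted hyperparameters, possibly empty.

    Raises
    ------
    FileNotFoundError
        If nothing is found and ``required=True``.
    """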
    log = log_fn if callable(log_fn) else None

    def _msg(s: str) -> None:
        if log is not None:
            log(s)

    def _load_json(p: str) -> object:
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    mp = os.path.abspath(str(model_path))
    run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)

    pats = [
        os.path.join(run_dir, "model_init_manifest.json"),
        os.path.join(run_dir, "*training_summary*.json"),
        os.path.join(run_dir, "*architecture*.json"),
    ]

    files: list[str] = []
    for pat in pats:
        files.extend(glob.glob(pat))

    for p in files:
        if not os.path.exists(p):
            continue
        data = _load_json(p)
        hps = _extract_allowed_hps(data, allowed=allowed)
        if hps:
            _msg(f"[HP] trained: {p}")
            return hps

    if required:
        raise FileNotFoundError(
            "No trained hyperparameters found near:\n"
            f"  model_path={model_path!r}\n"
            f"  run_dir={run_dir!r}\n"
        )
    return {}


def load_hps_auto_near_model(
    model_path: str,
    *,
    allowed: set[str],
    prefer: str = "keras",
    required: bool = False,
    log_fn=None,
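) -> dict:
    """
    Load hyperparameters near a model, choosing tuned or trained sources.

    The run is treated as tuned when its directory contains a
    ``*_best_hps.json`` or ``*tuning_summary*.json`` file, or when the
    string ``"tuning"`` appears anywhere in the run-directory path
    (this also matches parent directories). Tuned runs delegate to
    :func:`load_tuned_hps_near_model`; all others to
    :func:`load_trained_hps_near_model`.

    ``allowed`` is only used for trained runs and ``prefer`` only for
    tuned runs.
    """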
    mp = os.path.abspath(str(model_path))
    run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)

    tuned_hits = []
    tuned_hits.extend(
        glob.glob(os.path.join(run_dir, "*_best_hps.json"))
    )
    tuned_hits.extend(
        glob.glob(
            os.path.join(run_dir, "*tuning_summary*.json")
        )
    )
    is_tuned = bool(tuned_hits) or ("tuning" in run_dir)

    if is_tuned:
        return load_tuned_hps_near_model(
            model_path,
            prefer=prefer,
            required=required,
            log_fn=log_fn,
        )

    return load_trained_hps_near_model(
        model_path,
        allowed=allowed,
        required=required,
        log_fn=log_fn,
    )


def load_best_hps_near_model(
    model_path: str,
    *,
    model_name: str | None = "GeoPriorSubsNet",
    prefer: str = "keras",
    log_fn=None,
) -> dict:
    """
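    Load best hyperparameters saved near a model artifact.

    Sources are tried in order: explicit ``*_best_hps.json`` files
    derived from the artifact name, tuning summaries
    (``*tuning_summary*.json``), training summaries
    (``*training_summary*.json``), and architecture dumps
    (``*architecture*.json``). The first non-empty mapping wins.

    Supports artifact names such as:

    - ``<city>_<model_name>_H{H}_best.keras``
    - ``<city>_<model_name>_H{H}_best.weights.h5``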

    Parameters
    ----------
    model_path : str
        Path to a model file or its run directory.
    model_name : str or None, default="GeoPriorSubsNet"
        Model name token in filenames.
    prefer : {"keras", "weights"}, default="keras"
        Which artifact type to infer the prefix from. Any other value
        raises ``ValueError``.
    log_fn : callable or None, default=None
        Logger (e.g. print). None disables logs.

    Returns
    -------
    best_hps : dict
        Non-empty hyperparameter dictionary.

    Raises
    ------
    FileNotFoundError
        If no hyperparameter JSON is found.
    ValueError
        If a candidate JSON exists but is empty/invalid.
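
    Examples
    --------
    A minimal standalone sketch of how such artifact names can be split
    into city, model name, and horizon. ``parse_artifact_name`` and the
    city/model names are hypothetical, only ``.keras`` suffixes are
    handled, and a trailing ``_H`` token is treated as a horizon only
    when digits follow it::

        ```python
        def parse_artifact_name(base, model_name="GeoPriorSubsNet"):
            # Strip the artifact suffix to obtain the stem.
            stem = None
            for suffix in ("_best.keras", ".keras"):
                if base.endswith(suffix):
                    stem = base[: -len(suffix)]
                    break
            if stem is None:
                return None, None, None

            # Peel off a trailing _H<digits> horizon token.
            horizon = None
            left = stem
            if "_H" in left:
                head, tail = left.rsplit("_H", 1)
                digits = ""
                for ch in tail:
                    if not ch.isdigit():
                        break
                    digits += ch
                if digits:
                    horizon = int(digits)
                    left = head

            # Prefer the known model token; else split on the first "_".
            city = mname = None
            tok = "_" + model_name
            if left.endswith(tok):
                city, mname = left[: -len(tok)], model_name
            else:
                parts = left.split("_")
                if len(parts) >= 2:
                    city, mname = parts[0], "_".join(parts[1:])
            return city, mname, horizon


        print(parse_artifact_name("nansha_GeoPriorSubsNet_H3_best.keras"))
        # ('nansha', 'GeoPriorSubsNet', 3)
        print(parse_artifact_name("nansha_HybridNet_best.keras"))
        # ('nansha', 'HybridNet', None)
        ```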
    """

    log = log_fn if callable(log_fn) else None

    if prefer not in ("keras", "weights"):
        raise ValueError(
            "prefer must be 'keras' or 'weights'."
        )

    mp = os.path.abspath(str(model_path))
    run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
    base = "" if os.path.isdir(mp) else os.path.basename(mp)

    def _msg(s: str) -> None:
        if log is not None:
            log(s)

    def _load_json(p: str) -> dict:
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    def _newest(paths: list[str]) -> str | None:
        c = []
        for p in paths:
            try:
                c.append((os.path.getmtime(p), p))
            except Exception:
                pass
        if not c:
            return None
        c.sort(reverse=True)
        return c[0][1]

    # -------------------------------------------------
    # If a directory is provided, infer "base" by scan.
    # -------------------------------------------------
    if not base:
        pats = []
        if prefer == "keras":
            if model_name:
                pats.append(
                    os.path.join(
                        run_dir,
                        f"*_{model_name}_H*_best.keras",
                    )
                )
            pats.append(
                os.path.join(run_dir, "*_H*_best.keras")
            )
            pats.append(os.path.join(run_dir, "*_best.keras"))
        else:
            if model_name:
                pats.append(
                    os.path.join(
                        run_dir,
                        f"*_{model_name}_H*_best.weights.h5",
                    )
                )
            pats.append(
                os.path.join(
                    run_dir,
                    "*_H*_best.weights.h5",
                )
            )
            pats.append(
                os.path.join(run_dir, "*_best.weights.h5")
            )

        hits = []
        for pat in pats:
            hits.extend(glob.glob(pat))
        best = _newest(hits)
        if best:
            base = os.path.basename(best)

    # -------------------------------------------------
    # Infer stem/prefix from the artifact filename.
    # -------------------------------------------------
    stem = None
    if prefer == "keras":
        if base.endswith("_best.keras"):
            stem = base[: -len("_best.keras")]
        elif base.endswith(".keras"):
            stem = base[: -len(".keras")]
    else:
        if base.endswith("_best.weights.h5"):
            stem = base[: -len("_best.weights.h5")]
        elif base.endswith(".weights.h5"):
            stem = base[: -len(".weights.h5")]

    # Try to parse city / model / horizon from stem.
    city = None
    mname = None
    horizon = None
    city_model = None

    if stem:
        left = stem
        if "_H" in left:
            a, b = left.rsplit("_H", 1)
            digs = []
            for ch in b:
                if ch.isdigit():
                    digs.append(ch)
                else:
                    break
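            if digs:
                horizon = int("".join(digs))
                # Strip only a genuine _H<digits> token so that names
                # such as "city_HybridNet" keep their model part.
                left = a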

        city_model = left

        if model_name:
            tok = "_" + str(model_name)
            if city_model.endswith(tok):
                city = city_model[: -len(tok)]
                mname = str(model_name)

        if city is None:
            parts = city_model.split("_")
            if len(parts) >= 2:
                city = parts[0]
                mname = "_".join(parts[1:])

    # -------------------------------------------------
    # 1) Near-model explicit JSONs.
    # -------------------------------------------------
    cands = []
    if stem:
        cands.append(
            os.path.join(run_dir, stem + "_best_hps.json")
        )

    if city and mname:
        cands.append(
            os.path.join(
                run_dir,
                f"{city}_{mname}_best_hps.json",
            )
        )
        if horizon is not None:
            cands.append(
                os.path.join(
                    run_dir,
                    f"{city}_{mname}_H{horizon}_best_hps.json",
                )
            )

    if city:
        cands.append(
            os.path.join(
                run_dir,
                f"{city}_{mname}_best_hps.json",
            )
        )

    seen = set()
    for p in cands:
        if p in seen:
            continue
        seen.add(p)
        if os.path.exists(p):
            best_hps = _load_json(p)
            if isinstance(best_hps, dict) and best_hps:
                _msg(f"[HP] Loaded best_hps: {p}")
                return best_hps
            raise ValueError(
                f"{p!r} exists but is empty/invalid."
            )

    # -------------------------------------------------
    # 2) tuning summaries.
    # -------------------------------------------------
    sum_pats = [
        os.path.join(run_dir, "tuning_summary.json"),
        os.path.join(run_dir, "*tuning_summary*.json"),
    ]
    sums = []
    for pat in sum_pats:
        sums.extend(glob.glob(pat))
    for p in sums:
        if not os.path.exists(p):
            continue
        s = _load_json(p)
        best_hps = s.get("best_hps") or s.get("hps") or {}
        if isinstance(best_hps, dict) and best_hps:
            _msg(f"[HP] Loaded best_hps: {p}")
            return best_hps

    # -------------------------------------------------
    # 3) training summaries.
    # -------------------------------------------------
    for p in glob.glob(
        os.path.join(run_dir, "*training_summary*.json")
    ):
        s = _load_json(p)
        best_hps = (
            s.get("best_hps")
            or s.get("hps")
            or s.get("params")
            or {}
        )
        if isinstance(best_hps, dict) and best_hps:
            _msg(f"[HP] Loaded best_hps: {p}")
            return best_hps

    # -------------------------------------------------
    # 4) architecture dumps.
    # -------------------------------------------------
    for p in glob.glob(
        os.path.join(run_dir, "*architecture*.json")
    ):
        a = _load_json(p)
        best_hps = (
            a.get("best_hps")
            or a.get("hps")
            or a.get("params")
            or {}
        )
        if isinstance(best_hps, dict) and best_hps:
            _msg(f"[HP] Loaded best_hps: {p}")
            return best_hps

    raise FileNotFoundError(
        "Could not find best hyperparameters near:\n"
        f"  model_path={model_path!r}\n"
        f"  run_dir={run_dir!r}\n"
        f"  prefer={prefer!r}\n"
        "Looked for *_best_hps.json + summaries."
    )


def coerce_quantile_weights(
    d: dict | None,
    default: dict,
) -> dict:
    """
    Normalize a quantile-weight mapping to have float keys and float values.

    This helper is useful when reading JSON configs where the quantile
    keys are stored as strings (e.g. ``{'0.1': 3.0, '0.5': 1.0}``).

    Parameters
    ----------
    d : dict or None
        Original dictionary mapping quantile-like keys (str or float) to
        numeric weights. If ``None`` or empty, ``default`` is returned.

    default : dict
        Fallback dictionary, returned unchanged (not coerced) when
        ``d`` is ``None`` or empty.

    Returns
    -------
    out : dict
        Dictionary with the same entries as ``d``, but with:

        - keys coerced to float when possible (otherwise left as-is),
        - values coerced to ``float``.
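
    Examples
    --------
    JSON object keys are always strings, so quantile weights read from
    a config come back as ``{"0.1": ...}``. A minimal standalone sketch
    of the coercion (``coerce`` is a hypothetical copy of the loop in
    this function; the ``"tail"`` key is illustrative)::

        ```python
        import json

        raw = json.loads('{"0.1": 3, "0.5": 1, "tail": "2"}')


        def coerce(d, default):
            if not d:
                return default
            out = {}
            for k, v in d.items():
                try:
                    q = float(k)
                except (TypeError, ValueError):
                    q = k  # non-numeric key kept as-is
                out[q] = float(v)
            return out


        weights = coerce(raw, default={0.5: 1.0})
        print(weights)  # {0.1: 3.0, 0.5: 1.0, 'tail': 2.0}
        ```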
    """
    if not d:
        return default

    out: dict[Any, float] = {}
    for k, v in d.items():
        try:
            q = float(k)
        except (TypeError, ValueError):
            # Non-numeric key (rare): keep as-is
            q = k
        out[q] = float(v)
    return out


def compile_for_eval(
    model: Any,
    manifest: dict,
    best_hps: dict | None,
    quantiles: list[float] | None,
    *,
    include_metrics: bool = True,
) -> Any:
    """
    Recompile a GeoPriorSubsNet instance for evaluation / diagnostics.

    This is intended for:

    - tuned models loaded from a ``.keras`` archive, or
    - models rebuilt from ``best_hps``.

    It does NOT change the architecture or weights; it only changes the
    compile configuration (optimizer, losses, metrics, and physics loss
    weights).

    Parameters
    ----------
    model : GeoPriorSubsNet
        Loaded or freshly-built GeoPriorSubsNet instance.
    manifest : dict
        Stage-1 manifest; training config is taken from
        ``manifest['config']``.
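    best_hps : dict or None
        Dictionary of tuned hyperparameters. Each learning-rate or
        physics value is taken from ``best_hps`` when present and not
        ``None``, otherwise from the matching key in
        ``manifest['config']``, otherwise from a built-in default.
    quantiles : list of float or None
        Quantiles used for probabilistic subsidence/GWL outputs. If
        ``None`` or empty, mean-squared error replaces the weighted
        pinball loss.
    include_metrics : bool, default=True
        If True, attach MAE/MSE metrics to both heads (plus 80%
        coverage/sharpness metrics on the subsidence head when
        ``quantiles`` is set) to match the training script. If False,
        only losses are configured.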

    Returns
    -------
    model :
        The same model instance, compiled in-place.
    """
    if not TF_AVAILABLE:
        raise ImportError(
            "TensorFlow is required to compile GeoPriorSubsNet. "
            "Install `tensorflow>=2.12` to use "
            "`compile_for_eval`."
        )

    # Local imports so nat_utils.py itself stays lightweight
    from geoprior.nn.losses import make_weighted_pinball

    if include_metrics:
        from geoprior.nn.keras_metrics import (
            coverage80_fn,
            sharpness80_fn,
        )

    cfg = manifest.get("config", {}) or {}
    best_hps = best_hps or {}

    # ---- 1. Data loss weights / quantile weights -------------------------
    subs_raw = cfg.get(
        "SUBS_WEIGHTS",
        {0.1: 3.0, 0.5: 1.0, 0.9: 3.0},
    )
    gwl_raw = cfg.get(
        "GWL_WEIGHTS",
        {0.1: 1.5, 0.5: 1.0, 0.9: 1.5},
    )

    subs_w = coerce_quantile_weights(
        subs_raw, {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
    )
    gwl_w = coerce_quantile_weights(
        gwl_raw, {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
    )

    if quantiles:
        loss_dict = {
            "subs_pred": make_weighted_pinball(
                quantiles, subs_w
            ),
            "gwl_pred": make_weighted_pinball(
                quantiles, gwl_w
            ),
        }
    else:
        mse = tf.keras.losses.MeanSquaredError()
        loss_dict = {"subs_pred": mse, "gwl_pred": mse}

    loss_weights = {"subs_pred": 1.0, "gwl_pred": 0.5}

    # ---- 2. Physics weights: prefer best_hps, fall back to config --------
    def _hp_or_cfg(
        hp_key: str, cfg_key: str, default: float
    ) -> float:
        if (
            hp_key in best_hps
            and best_hps[hp_key] is not None
        ):
            return float(best_hps[hp_key])
        if cfg_key in cfg and cfg[cfg_key] is not None:
            return float(cfg[cfg_key])
        return float(default)

    lr = _hp_or_cfg("learning_rate", "LEARNING_RATE", 1e-4)

    physics_kwargs = {
        "lambda_gw": _hp_or_cfg(
            "lambda_gw", "LAMBDA_GW", 1.0
        ),
        "lambda_cons": _hp_or_cfg(
            "lambda_cons", "LAMBDA_CONS", 1.0
        ),
        "lambda_prior": _hp_or_cfg(
            "lambda_prior", "LAMBDA_PRIOR", 0.1
        ),
        "lambda_smooth": _hp_or_cfg(
            "lambda_smooth", "LAMBDA_SMOOTH", 0.01
        ),
        "lambda_mv": _hp_or_cfg(
            "lambda_mv", "LAMBDA_MV", 0.0
        ),
        "mv_lr_mult": _hp_or_cfg(
            "mv_lr_mult", "MV_LR_MULT", 1.0
        ),
        "kappa_lr_mult": _hp_or_cfg(
            "kappa_lr_mult", "KAPPA_LR_MULT", 1.0
        ),
    }

    compile_kwargs: dict[str, Any] = {
        "optimizer": Adam(learning_rate=lr),
        "loss": loss_dict,
        "loss_weights": loss_weights,
        **physics_kwargs,
    }

    if include_metrics:
        metrics_dict = {
            "subs_pred": ["mae", "mse"]
            + (
                [coverage80_fn, sharpness80_fn]
                if quantiles
                else []
            ),
            "gwl_pred": ["mae", "mse"],
        }
        compile_kwargs["metrics"] = metrics_dict

    model.compile(**compile_kwargs)
    return model


def compile_geoprior_for_eval(
    model: Any,  # type: ignore[override]
    manifest: dict,
    best_hps: dict,
    quantiles: list[float] | None,
) -> Any:
    """
    (Re)compile a GeoPriorSubsNet-like model for evaluation.

    This helper uses the Stage-1 manifest and tuned hyperparameters to
    configure:

    - the pinball losses for subsidence and GWL outputs,
    - loss weights for the two heads,
    - physics loss weights (lambda_*),
    - learning rate and LR multipliers.

    TensorFlow and geoprior are imported lazily inside this function so
    that ``nat_utils`` can be imported even in non-TF environments.

    Parameters
    ----------
    model : GeoPriorSubsNet-like
        An instance of the GeoPriorSubsNet model (or any model exposing
        the same compile signature).

    manifest : dict
        Stage-1 manifest dictionary. The ``config`` entry is used to
        retrieve default loss weights and physics settings.

    best_hps : dict
        Hyperparameters loaded from the tuning run
        (e.g. via :func:`load_best_hps_near_model`).

    quantiles : list of float or None
        Quantile levels used for probabilistic outputs. If ``None``,
        mean-squared error is used instead of pinball loss.

    Returns
    -------
    model
        The same model instance, compiled in-place.

    Raises
    ------
    ImportError
        If TensorFlow or geoprior's ``make_weighted_pinball`` cannot be
        imported.
    """
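    cfg = manifest.get("config", {}) or {}
    # Tolerate best_hps=None (e.g. config-only runs) instead of failing on .get
    best_hps = best_hps or {}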

    # Lazy imports so nat_utils.py is importable without TensorFlow
    try:
        import tensorflow as tf  # type: ignore
        from tensorflow.keras.optimizers import (
            Adam,  # type: ignore
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "compile_geoprior_for_eval requires TensorFlow. "
            "Please install 'tensorflow>=2.12' to use this helper."
        ) from e

    try:
        from geoprior.nn.losses import (
            make_weighted_pinball,  # type: ignore
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "compile_geoprior_for_eval requires "
            "'geoprior.nn.losses.make_weighted_pinball'. "
            "Ensure geoprior is installed and importable."
        ) from e

    # Base loss weights between subsidence and GWL heads
    loss_weights = {"subs_pred": 1.0, "gwl_pred": 0.5}

    # Quantile-specific weights from config (with robust defaults)
    subs_raw = cfg.get(
        "SUBS_WEIGHTS", {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
    )
    gwl_raw = cfg.get(
        "GWL_WEIGHTS", {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
    )

    subs_weights = coerce_quantile_weights(
        subs_raw, {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
    )
    gwl_weights = coerce_quantile_weights(
        gwl_raw, {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
    )

    if quantiles:
        loss_subs = make_weighted_pinball(
            quantiles, subs_weights
        )
        loss_gwl = make_weighted_pinball(
            quantiles, gwl_weights
        )
    else:
        loss_subs = tf.keras.losses.MSE
        loss_gwl = tf.keras.losses.MSE

    loss_dict = {"subs_pred": loss_subs, "gwl_pred": loss_gwl}

    # Learning rate: tuned value or config fallback
    lr_default = cfg.get("LEARNING_RATE", 5e-5)
    lr = float(best_hps.get("learning_rate", lr_default))
    optimizer = Adam(learning_rate=lr)

    # Physics loss weights and LR multipliers
    lambda_gw = float(
        best_hps.get("lambda_gw", cfg.get("LAMBDA_GW", 1.0))
    )
    lambda_cons = float(
        best_hps.get(
            "lambda_cons", cfg.get("LAMBDA_CONS", 1.0)
        )
    )
    lambda_prior = float(
        best_hps.get(
            "lambda_prior", cfg.get("LAMBDA_PRIOR", 1.0)
        )
    )
    lambda_smooth = float(
        best_hps.get(
            "lambda_smooth", cfg.get("LAMBDA_SMOOTH", 1.0)
        )
    )
    lambda_mv = float(
        best_hps.get("lambda_mv", cfg.get("LAMBDA_MV", 0.0))
    )
    mv_lr_mult = float(
        best_hps.get("mv_lr_mult", cfg.get("MV_LR_MULT", 1.0))
    )
    kappa_lr_mult = float(
        best_hps.get(
            "kappa_lr_mult", cfg.get("KAPPA_LR_MULT", 1.0)
        )
    )

    model.compile(
        optimizer=optimizer,
        loss=loss_dict,
        loss_weights=loss_weights,
        # physics loss weights + LR multipliers
        lambda_gw=lambda_gw,
        lambda_cons=lambda_cons,
        lambda_prior=lambda_prior,
        lambda_smooth=lambda_smooth,
        lambda_mv=lambda_mv,
        mv_lr_mult=mv_lr_mult,
        kappa_lr_mult=kappa_lr_mult,
    )
    return model


def build_geoprior_from_hps(
    manifest: dict,
    X_sample: dict,
    best_hps: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
) -> Any:
    """
    Reconstruct a GeoPriorSubsNet from Stage-1 metadata + tuned HPs.

    This function is primarily intended as a **robust fallback** when
    ``tf.keras.models.load_model`` cannot deserialize a tuned model.
    It reconstructs the network geometry and physics settings from:

    - Stage-1 ``manifest['config']`` (for fixed architecture / physics),
    - tuned hyperparameters (for variable architecture / physics),
    - the Stage-1 input NPZ (for input dimensions).

    Parameters
    ----------
    manifest : dict
        Stage-1 manifest dictionary.

    X_sample : dict
        Inputs NPZ dictionary (already sanitized and passed through
        :func:`ensure_input_shapes`). Only shapes are used.

    best_hps : dict
        Hyperparameters loaded via :func:`load_best_hps_near_model`.

    out_s_dim : int
        Output dimension for the subsidence head.

    out_g_dim : int
        Output dimension for the GWL head.

    mode : str
        Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.

    horizon : int
        Forecast horizon (number of time steps).

    quantiles : list of float or None
        Quantile levels for probabilistic outputs.

    Returns
    -------
    model : GeoPriorSubsNet
        A freshly instantiated and compiled GeoPriorSubsNet instance.

    Raises
    ------
    ImportError
        If GeoPriorSubsNet cannot be imported from geoprior.
    """
    try:
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "build_geoprior_from_hps requires "
            "'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
            "Ensure geoprior is installed and importable."
        ) from e

    cfg = manifest.get("config", {}) or {}

    # Infer input dims directly from NPZ
    static_dim, dynamic_dim, future_dim = (
        infer_input_dims_from_X(X_sample)
    )

    # Attention stack: fall back to a sensible default if not present
    attention_levels = cfg.get(
        "ATTENTION_LEVELS",
        ["cross", "hierarchical", "memory"],
    )

    # Whether we used effective H during Stage-2
    censor_cfg = cfg.get("censoring", {}) or {}
    use_effective_h = censor_cfg.get(
        "use_effective_h_field", True
    )

    # Feature-processing mode controlled by tuned HPs
    use_vsn = bool(best_hps.get("use_vsn", True))
    feature_processing = "vsn" if use_vsn else "dense"

    architecture_config = {
        "encoder_type": "hybrid",
        "decoder_attention_stack": attention_levels,
        "feature_processing": feature_processing,
    }

    # Instantiate the model core with tuned settings
    model = GeoPriorSubsNet(
        static_input_dim=static_dim,
        dynamic_input_dim=dynamic_dim,
        future_input_dim=future_dim,
        output_subsidence_dim=out_s_dim,
        output_gwl_dim=out_g_dim,
        forecast_horizon=horizon,
        mode=mode,
        attention_levels=attention_levels,
        quantiles=quantiles,
        # physics switches from best_hps
        pde_mode=best_hps.get("pde_mode", "both"),
        scale_pde_residuals=bool(
            best_hps.get("scale_pde_residuals", True)
        ),
        kappa_mode=best_hps.get("kappa_mode", "bar"),
        use_effective_h=use_effective_h,
        # architecture hyperparameters
        embed_dim=int(best_hps.get("embed_dim", 32)),
        hidden_units=int(best_hps.get("hidden_units", 96)),
        lstm_units=int(best_hps.get("lstm_units", 96)),
        attention_units=int(
            best_hps.get("attention_units", 32)
        ),
        num_heads=int(best_hps.get("num_heads", 4)),
        dropout_rate=float(best_hps.get("dropout_rate", 0.1)),
        use_vsn=use_vsn,
        vsn_units=int(best_hps.get("vsn_units", 32)),
        use_batch_norm=bool(
            best_hps.get("use_batch_norm", True)
        ),
        # geomechanical priors (floats interpreted internally by the model)
        mv=float(best_hps.get("mv", 5e-7)),
        kappa=float(best_hps.get("kappa", 1.0)),
        architecture_config=architecture_config,
    )

    # Compile using the shared helper (losses + physics weights)
    compile_geoprior_for_eval(
        model=model,
        manifest=manifest,
        best_hps=best_hps,
        quantiles=quantiles,
    )

    print(
        "[Fallback] Reconstructed GeoPriorSubsNet from best_hps with "
        f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
        f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
    )
    return model


def build_geoprior_from_cfg(
    manifest: dict,
    X_sample: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
) -> Any:
    """
    Reconstruct a GeoPriorSubsNet from the NATCOM config only.

    This is intended as a fallback for *trained* (not tuned) models that
    have no tuning JSON, or when no ``best_hps`` can be found next to the
    model archive.

    Parameters
    ----------
    manifest : dict
        Stage-1 manifest dictionary. The ``"config"`` entry is used as
        the single source of truth for architecture + physics settings.
    X_sample : dict
        NPZ inputs dict (already sanitized and passed through
        :func:`ensure_input_shapes` or equivalent). Only shapes are used.
    out_s_dim, out_g_dim : int
        Output dims for subsidence and GWL heads.
    mode : str
        Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.
    horizon : int
        Forecast horizon (number of time steps).
    quantiles : list of float or None
        Quantiles for probabilistic outputs.

    Returns
    -------
    model : GeoPriorSubsNet
        Compiled model instance ready for prediction/eval.
    """
    try:
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "build_geoprior_from_cfg requires "
            "'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
            "Ensure geoprior is installed and importable."
        ) from e

    cfg = manifest.get("config", {}) or {}

    # --- Input dims inferred from X_sample ----------------------------
    static_dim, dynamic_dim, future_dim = (
        infer_input_dims_from_X(X_sample)
    )

    # --- Attention stack / effective-H flag ---------------------------
    attention_levels = cfg.get(
        "ATTENTION_LEVELS",
        ["cross", "hierarchical", "memory"],
    )

    censor_cfg = cfg.get("censoring", {}) or {}
    use_effective_h = censor_cfg.get(
        "use_effective_h_field",
        bool(cfg.get("GEOPRIOR_USE_EFFECTIVE_H", True)),
    )

    # --- Physics switches ---------------------------------------------
    pde_mode = cfg.get(
        "PDE_MODE_CONFIG", cfg.get("PDE_MODE", "both")
    )
    scale_pde_residuals = bool(
        cfg.get(
            "SCALE_PDE_RESIDUALS",
            cfg.get("SCALE_PDE_RES", True),
        )
    )
    kappa_mode = cfg.get(
        "GEOPRIOR_KAPPA_MODE",
        cfg.get("KAPPA_MODE", "bar"),
    )

    # --- Small helpers to read ints/floats/bools from cfg ------------
    def _cfg_int(default: int, *keys: str) -> int:
        for k in keys:
            if k in cfg and cfg[k] is not None:
                try:
                    return int(cfg[k])
                except Exception:
                    pass
        return int(default)

    def _cfg_float(default: float, *keys: str) -> float:
        for k in keys:
            if k in cfg and cfg[k] is not None:
                try:
                    return float(cfg[k])
                except Exception:
                    pass
        return float(default)

    def _cfg_bool(default: bool, *keys: str) -> bool:
        for k in keys:
            if k in cfg and cfg[k] is not None:
                return bool(cfg[k])
        return bool(default)

    # --- Architecture hyperparams (config-side defaults) --------------
    embed_dim = _cfg_int(
        32, "EMBED_DIM", "GEOPRIOR_EMBED_DIM"
    )
    hidden_units = _cfg_int(
        96, "HIDDEN_UNITS", "GEOPRIOR_HIDDEN_UNITS"
    )
    lstm_units = _cfg_int(
        96, "LSTM_UNITS", "GEOPRIOR_LSTM_UNITS"
    )
    attention_units = _cfg_int(
        32, "ATTENTION_UNITS", "GEOPRIOR_ATTENTION_UNITS"
    )
    num_heads = _cfg_int(
        4, "NUM_HEADS", "NUMBER_HEADS", "GEOPRIOR_NUM_HEADS"
    )
    dropout_rate = _cfg_float(
        0.1, "DROPOUT_RATE", "GEOPRIOR_DROPOUT_RATE"
    )
    use_vsn = _cfg_bool(True, "USE_VSN", "GEOPRIOR_USE_VSN")
    vsn_units = _cfg_int(
        32, "VSN_UNITS", "GEOPRIOR_VSN_UNITS"
    )
    use_batch_norm = _cfg_bool(
        True, "USE_BATCH_NORM", "GEOPRIOR_USE_BATCH_NORM"
    )

    # --- Geomechanical priors (Terzaghi-ish) --------------------------
    mv = _cfg_float(5e-7, "GEOPRIOR_INIT_MV")
    kappa = _cfg_float(1.0, "GEOPRIOR_INIT_KAPPA")

    architecture_config = {
        "encoder_type": "hybrid",
        "decoder_attention_stack": attention_levels,
        "feature_processing": "vsn" if use_vsn else "dense",
    }

    model = GeoPriorSubsNet(
        static_input_dim=static_dim,
        dynamic_input_dim=dynamic_dim,
        future_input_dim=future_dim,
        output_subsidence_dim=out_s_dim,
        output_gwl_dim=out_g_dim,
        forecast_horizon=horizon,
        mode=mode,
        attention_levels=attention_levels,
        quantiles=quantiles,
        # physics switches
        pde_mode=pde_mode,
        scale_pde_residuals=scale_pde_residuals,
        kappa_mode=kappa_mode,
        use_effective_h=use_effective_h,
        # architecture hyperparameters
        embed_dim=embed_dim,
        hidden_units=hidden_units,
        lstm_units=lstm_units,
        attention_units=attention_units,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        use_vsn=use_vsn,
        vsn_units=vsn_units,
        use_batch_norm=use_batch_norm,
        # priors
        mv=mv,
        kappa=kappa,
        architecture_config=architecture_config,
    )

    # Compile using config-only settings
    compile_for_eval(
        model=model,
        manifest=manifest,
        best_hps=None,
        quantiles=quantiles,
        include_metrics=True,
    )

    print(
        "[Fallback] Reconstructed GeoPriorSubsNet from manifest config with "
        f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
        f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
    )
    return model


def infer_best_weights_path(model_path: str) -> str | None:
    """
    Infer the best-weights checkpoint path for a tuned GeoPrior model.

    Strategy
    --------
    1. Look for ``tuning_summary.json`` in the same folder as
       ``model_path`` and return its ``best_weights_path`` entry if
       that file exists on disk.
    2. Fallback: replace the ``.keras`` suffix of ``model_path`` with
       ``.weights.h5``, assuming the convention::

           <CITY>_GeoPrior_best.keras
           -> <CITY>_GeoPrior_best.weights.h5

    Parameters
    ----------
    model_path : str
        Path to the tuned model archive (usually ``.keras``).

    Returns
    -------
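    weights_path : str or None
        Path to the weights file if found, otherwise ``None``. The path
        is returned as stored in ``tuning_summary.json`` or as derived
        from ``model_path``, so it is not necessarily absolute.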
    """
    run_dir = os.path.dirname(os.path.abspath(model_path))

    # 1) Preferred: tuning_summary.json
    summary_path = os.path.join(
        run_dir, "tuning_summary.json"
    )
    if os.path.exists(summary_path):
        try:
            with open(summary_path, encoding="utf-8") as f:
                summary = json.load(f)
            w = summary.get("best_weights_path")
            if w and os.path.exists(w):
                return w
        except Exception as e:  # pragma: no cover - defensive
            print(
                f"[Warn] Could not read tuning_summary.json for weights: {e}"
            )

    # 2) Name-based guess from the .keras path
    root, ext = os.path.splitext(model_path)
    guess = root + ".weights.h5"
    if os.path.exists(guess):
        return guess

    return None


def _load_or_rebuild_geoprior_model(
    model_path: str,
    manifest: dict,
    X_sample: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
    city_name: str | None = None,
    compile_on_load: bool = True,
    verbose: int = 1,
):
    """
    Load a tuned GeoPriorSubsNet from disk, with robust rebuild fallback.

    This helper centralizes the logic:

    1. Try to load the model from ``model_path`` via
       :func:`tf.keras.models.load_model`, with all required custom
       objects registered (GeoPriorSubsNet, LearnableMV, etc.).

    2. If loading fails (e.g. due to environment or serialization
       changes), attempt a robust fallback:

       - Load the tuned hyperparameters via
         :func:`load_best_hps_near_model`.
       - Rebuild a compatible GeoPriorSubsNet instance using
         :func:`build_geoprior_from_hps`, based on Stage-1
         ``manifest['config']`` and an input sample ``X_sample``.
       - Find the best weights checkpoint via
         :func:`infer_best_weights_path` and load them into the
         rebuilt model, if available.

    Parameters
    ----------
    model_path : str
        Path to the tuned model archive (usually ``.keras``) produced
        by the tuner, e.g.::

            .../tuning/run_YYYYMMDD-HHMMSS/nansha_GeoPrior_best.keras

    manifest : dict
        Stage-1 manifest dictionary; its ``"config"`` entry is used to
        reconstruct the compile/physics configuration when rebuilding.

    X_sample : dict
        One NPZ inputs dictionary (e.g. validation or train NPZ) that
        has already been sanitized and passed through
        :func:`ensure_input_shapes`. Only its shapes are used to infer
        input dimensions.

    out_s_dim : int
        Output dimension for the subsidence head
        (usually from ``M['artifacts']['sequences']['dims']``).

    out_g_dim : int
        Output dimension for the GWL head.

    mode : str
        Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.

    horizon : int
        Forecast horizon (number of time steps).

    quantiles : list of float or None
        Quantile levels used for probabilistic outputs. If ``None``,
        the model is treated as a point-forecast model.

    city_name : str or None, optional
        City name for log messages. If ``None``, a neutral label is
        used in logs.

    compile_on_load : bool, default=True
        Whether to pass ``compile=True`` to :func:`load_model`. If
        ``False``, the model is loaded uncompiled, and only the
        rebuilt branch is compiled via
        :func:`build_geoprior_from_hps`.

    verbose : int, default=1
        Verbosity level for log messages (0 = silent, 1 = info).

    Returns
    -------
    model :
        A GeoPriorSubsNet instance (or compatible model) ready for
        evaluation/prediction.

    best_hps : dict or None
        Dictionary of tuned hyperparameters if they were loaded during
        the fallback path; otherwise ``None``.

    Raises
    ------
    ImportError
        If TensorFlow or required geoprior components cannot be
        imported.

    RuntimeError
        If both direct loading and fallback reconstruction fail.
    """
    label_city = city_name or "GeoPrior"

    # --- Lazy imports so nat_utils can be imported without TF/geoprior ---
    try:
        import tensorflow as tf  # type: ignore # noqa
        from tensorflow.keras.models import load_model  # type: ignore
        from tensorflow.keras.utils import custom_object_scope  # type: ignore
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires TensorFlow. "
            "Please install 'tensorflow>=2.12' to use this helper."
        ) from e

    try:
        from geoprior.nn.keras_metrics import (  # type: ignore
            coverage80_fn,
            sharpness80_fn,
        )
        from geoprior.nn.losses import (
            make_weighted_pinball,  # type: ignore
        )
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
        from geoprior.params import (  # type: ignore
            FixedGammaW,
            FixedHRef,
            LearnableKappa,
            LearnableMV,
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires geoprior components "
            "(GeoPriorSubsNet, LearnableMV, etc.). Ensure geoprior is "
            "installed and importable."
        ) from e

    # ------------------- 1) Try direct load_model -------------------------
    custom_objects = {
        "GeoPriorSubsNet": GeoPriorSubsNet,
        "LearnableMV": LearnableMV,
        "LearnableKappa": LearnableKappa,
        "FixedGammaW": FixedGammaW,
        "FixedHRef": FixedHRef,
        # custom loss factory / class
        "make_weighted_pinball": make_weighted_pinball,
        # custom metrics used in compile
        "coverage80_fn": coverage80_fn,
        "sharpness80_fn": sharpness80_fn,
    }

    best_hps: dict | None = None

    with custom_object_scope(custom_objects):
        if verbose:
            print(
                f"[Model] Attempting to load tuned model from: {model_path}"
            )

        try:
            model = load_model(
                model_path, compile=compile_on_load
            )
            if verbose:
                print(
                    f"[Model] Successfully loaded tuned model for {label_city} "
                    f"from: {model_path}"
                )
            return model, best_hps
        except Exception as e_load:
            # Fall through to the rebuild path below instead of aborting.
            if verbose:
                print(
                    f"[Warn] load_model('{model_path}') failed: {e_load}\n"
                    "[Warn] Attempting robust fallback: rebuild GeoPriorSubsNet "
                    "from tuned hyperparameters."
                )

    # ------------------- 2) Fallback: rebuild + load weights --------------
    # 2.1 Hyperparameters near the tuned model
    try:
        best_hps = load_best_hps_near_model(model_path)
    except Exception as e_hps:
        raise RuntimeError(
            "Failed to load tuned hyperparameters for fallback model "
            f"reconstruction near model_path={model_path!r}: {e_hps}"
        ) from e_hps

    # 2.2 Rebuild architecture + compile using Stage-1 manifest + best_hps
    try:
        model = build_geoprior_from_hps(
            manifest=manifest,
            X_sample=X_sample,
            best_hps=best_hps,
            out_s_dim=out_s_dim,
            out_g_dim=out_g_dim,
            mode=mode,
            horizon=horizon,
            quantiles=quantiles,
        )
    except Exception as e_build:
        raise RuntimeError(
            "Failed to reconstruct GeoPriorSubsNet from best_hps. "
            f"Error: {e_build}"
        ) from e_build

    # 2.3 Load weights into the rebuilt model, if a checkpoint is found
    weights_path = infer_best_weights_path(model_path)
    if weights_path is not None:
        try:
            model.load_weights(weights_path)
            if verbose:
                print(
                    "[Fallback] Loaded weights into rebuilt GeoPriorSubsNet "
                    f"from: {weights_path}"
                )
        except Exception as e_w:
            # We still return the rebuilt model, but warn that it is not
            # weight-identical to the tuned run.
            if verbose:
                print(
                    "[Warn] Could not load weights from checkpoint:\n"
                    f"       {weights_path}\n"
                    f"       Error: {e_w}\n"
                    "       The rebuilt model is using freshly-initialized "
                    "weights. Predictions will NOT match the tuned model."
                )
    else:
        if verbose:
            print(
                "[Warn] No weights checkpoint found near tuned model.\n"
                "       Using rebuilt model with freshly-initialized "
                "weights. Predictions will NOT match the tuned model."
            )

    return model, best_hps


def infer_input_dims_from_X(X: dict) -> tuple[int, int, int]:
    """
    Infer (static_input_dim, dynamic_input_dim, future_input_dim)
    from NPZ inputs.

    This is a public, defensive version of the former
    ``_infer_input_dims_from_X`` helper.

    Parameters
    ----------
    X : dict
        Dictionary with keys:

        - ``'dynamic_features'`` (required, shape (N, T, D_dyn))
        - ``'static_features'`` (optional, shape (N, D_static) or None)
        - ``'future_features'`` (optional, shape (N, T_future, D_future) or None)

    Returns
    -------
    static_dim : int
        Last-dimension size of ``static_features`` (0 if missing or None).

    dynamic_dim : int
        Last-dimension size of ``dynamic_features``. Raises if missing.

    future_dim : int
        Last-dimension size of ``future_features`` (0 if missing or None).

    Raises
    ------
    KeyError
        If ``'dynamic_features'`` is missing in ``X``.
    """
    if "dynamic_features" not in X:
        raise KeyError(
            "X must contain key 'dynamic_features' with shape (N, T, D_dyn)."
        )

    dyn = np.asarray(X["dynamic_features"])
    dynamic_dim = int(dyn.shape[-1])

    static = X.get("static_features", None)
    static_dim = (
        int(np.asarray(static).shape[-1])
        if static is not None
        else 0
    )

    fut = X.get("future_features", None)
    future_dim = (
        int(np.asarray(fut).shape[-1])
        if fut is not None
        else 0
    )

    return static_dim, dynamic_dim, future_dim


# -------------------------------------------------------------------------
# Backward-compatible aliases for old private helper names
# -------------------------------------------------------------------------
safe_compile = compile_for_eval

_infer_input_dims_from_X = infer_input_dims_from_X
_load_best_hps_near_model = load_best_hps_near_model
_coerce_quantile_weights = coerce_quantile_weights
_compile_geoprior_for_eval = compile_geoprior_for_eval
_build_geoprior_from_hps = build_geoprior_from_hps
_infer_best_weights_path = infer_best_weights_path

_build_geoprior_from_cfg = build_geoprior_from_cfg

# Back-compat alias (docstrings still mention it)
extract_stage_outputs = extract_preds
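`infer_input_dims_from_X` reads only the last axis of each input block, so its contract can be checked without a trained model. The sketch below mirrors that rule with plain NumPy so it runs without GeoPrior installed. The helper name `infer_dims_sketch` and the array shapes are hypothetical and exist only for illustration. In a GeoPrior environment you would call `infer_input_dims_from_X` (or its legacy alias `_infer_input_dims_from_X`) directly.

```python
import numpy as np


def infer_dims_sketch(X):
    """Hypothetical mirror of infer_input_dims_from_X: (static, dynamic, future)."""
    if "dynamic_features" not in X:
        raise KeyError("X must contain key 'dynamic_features'.")

    def last_dim(key):
        arr = X.get(key)
        # A missing key and an explicit None both mean "no features".
        return 0 if arr is None else int(np.asarray(arr).shape[-1])

    return (
        last_dim("static_features"),
        last_dim("dynamic_features"),
        last_dim("future_features"),
    )


# Hypothetical NPZ-like payload: 8 samples, 5 look-back steps, no future block.
X = {
    "dynamic_features": np.zeros((8, 5, 4)),
    "static_features": np.zeros((8, 3)),
    "future_features": None,
}
print(infer_dims_sketch(X))  # (3, 4, 0)
```

Only the last axis is inspected, so sample and time dimensions can differ between blocks without affecting the result.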

Optional: model/PINN helper module
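Within this module, `process_pde_modes` resolves the `pde_mode` argument in three stages (`_flatten_mode_input`, `_normalize_pde_tokens`, `_canonicalize_pde_modes`). The following is a standalone re-implementation of those stages for experimentation, not the module's own code. The function name `resolve_pde_modes` is hypothetical. It omits `enforce_consolidation`, `pde_mode_config`, and `solo_return`. The alias tables are copied from the module constants.

```python
# Alias tables copied from the pinn.py module constants.
NONE_ALIASES = {"none", "off", "false", "0", "disable", "disabled", ""}
BOTH_ALIASES = {"both", "on", "all", "true", "1"}
TOKEN_ALIASES = {
    "consolidation": "consolidation",
    "cons": "consolidation",
    "gw_flow": "gw_flow",
    "gw": "gw_flow",
    "groundwater": "gw_flow",
    "groundwater_flow": "gw_flow",
}
ORDER = ("consolidation", "gw_flow")


def resolve_pde_modes(value):
    """Hypothetical mirror of process_pde_modes' default path."""
    # Stage 1: flatten None / "a,b" strings / sequences into lowercase tokens.
    if value is None:
        raw = ["none"]
    elif isinstance(value, str):
        raw = [p.strip().lower() for p in value.split(",")]
        raw = [p for p in raw if p] or ["none"]
    elif isinstance(value, (list, tuple, set)):
        raw = [
            "none" if v is None else (str(v).strip().lower() or "none")
            for v in value
        ] or ["none"]
    else:
        raise TypeError("pde_mode must be a string, a sequence, or None.")

    # Stage 2: map every alias to 'none', 'both', or a canonical PDE name.
    tokens = []
    for tok in raw:
        if tok in NONE_ALIASES:
            tokens.append("none")
        elif tok in BOTH_ALIASES:
            tokens.append("both")
        elif tok in TOKEN_ALIASES:
            tokens.append(TOKEN_ALIASES[tok])
        else:
            raise ValueError(f"Unsupported pde_mode token: {tok!r}")

    # Stage 3: 'none' must stand alone; 'both' expands; order is fixed.
    if "none" in tokens:
        if any(t != "none" for t in tokens):
            raise ValueError("'none' cannot be combined with other modes.")
        return ["none"]
    active = set()
    for tok in tokens:
        active.update(ORDER if tok == "both" else (tok,))
    return [mode for mode in ORDER if mode in active]


print(resolve_pde_modes("on"))  # ['consolidation', 'gw_flow']
print(resolve_pde_modes("gw, cons"))  # ['consolidation', 'gw_flow']
print(resolve_pde_modes(["gw_flow", "groundwater"]))  # ['gw_flow']
```

The aliases are stored in plain sets, so supporting a new spelling only requires adding it to the right set. The branching logic stays the same.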

geoprior/models/utils/pinn.py
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>


"""
Physics-Informed Neural Network (PINN) Utility functions.
"""

from __future__ import annotations

import logging
import os
import warnings  # noqa
from collections.abc import Callable, Sequence
from typing import (
    Any,
    Final,
)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from sklearn.preprocessing import MinMaxScaler

from ...api.util import get_table_size
from ...core.checks import (
    check_datetime,
    exist_features,
)
from ...core.handlers import columns_manager
from ...core.io import SaveFile
from ...decorators import isdf
from ...logging import get_logger
from ...metrics.utils import compute_quantile_diagnostics
from ...utils.data_utils import mask_by_reference
from ...utils.deps_utils import ensure_pkg
from ...utils.forecast_utils import normalize_for_pinn
from ...utils.generic_utils import (
    print_box,
    rename_dict_keys,
    select_mode,
    vlog,
)
from ...utils.geo_utils import resolve_spatial_columns
from ...utils.io_utils import save_job
from ...utils.sequence_utils import check_sequence_feasibility
from ...utils.validator import validate_positive_integer
from .. import KERAS_BACKEND, KERAS_DEPS

Model = KERAS_DEPS.Model
Tensor = KERAS_DEPS.Tensor

tf_shape = KERAS_DEPS.shape
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_float32 = KERAS_DEPS.float32
tf_cast = KERAS_DEPS.cast
tf_concat = KERAS_DEPS.concat
tf_expand_dims = KERAS_DEPS.expand_dims
tf_debugging = KERAS_DEPS.debugging
tf_fill = KERAS_DEPS.fill
tf_reshape = KERAS_DEPS.reshape

logger = get_logger(__name__)
_TW = get_table_size()

__all__ = [
    "format_pihalnet_predictions",
    "prepare_pinn_data_sequences",
    "format_pinn_predictions",
    "process_pde_modes",
    "extract_txy_in",
    "extract_txy",
    "plot_hydraulic_head",
    "PDE_MODE_ALIASES",
]


_PDE_NONE_ALIASES = {
    "none",
    "off",
    "false",
    "0",
    "disable",
    "disabled",
    "",
}

_PDE_BOTH_ALIASES = {
    "both",
    "on",
    "all",
    "true",
    "1",
}

_PDE_TOKEN_ALIASES = {
    "consolidation": "consolidation",
    "cons": "consolidation",
    "gw_flow": "gw_flow",
    "gw": "gw_flow",
    "groundwater": "gw_flow",
    "groundwater_flow": "gw_flow",
}

_PDE_ORDER = ("consolidation", "gw_flow")


PDE_CANONICAL_MODES: Final[frozenset[str]] = frozenset(
    {
        "consolidation",
        "gw_flow",
        "both",
        "none",
    }
)

PDE_MODE_ALIASES: Final[frozenset[str]] = frozenset(
    PDE_CANONICAL_MODES
    | _PDE_NONE_ALIASES
    | _PDE_BOTH_ALIASES
)

# Optional: strict public set for places where only canonical
# labels should be accepted after normalization.
PDE_ACTIVE_MODES: Final[frozenset[str]] = frozenset(
    {
        "consolidation",
        "gw_flow",
        "none",
    }
)


def _flatten_mode_input(
    value: str | Sequence[str] | None,
) -> list[str]:
    """
    Convert the raw pde_mode input into a flat list of lowercase tokens.
    """
    if value is None:
        return ["none"]

    if isinstance(value, str):
        # support comma-separated strings like "consolidation,gw_flow"
        parts = [p.strip().lower() for p in value.split(",")]
        parts = [p for p in parts if p != ""]
        return parts or ["none"]

    if isinstance(value, list | tuple | set):
        out: list[str] = []
        for item in value:
            if item is None:
                out.append("none")
            else:
                s = str(item).strip().lower()
                if s:
                    out.append(s)
                else:
                    out.append("none")
        return out or ["none"]

    raise TypeError(
        "`pde_mode` must be a string, a sequence of strings, or None."
    )


def _normalize_pde_tokens(tokens: Sequence[str]) -> list[str]:
    """
    Normalize aliases into canonical tokens.

    Canonical tokens are:
    - 'none'
    - 'consolidation'
    - 'gw_flow'
    - 'both' (intermediate token, later expanded)
    """
    normalized: list[str] = []

    for tok in tokens:
        if tok in _PDE_NONE_ALIASES:
            normalized.append("none")
        elif tok in _PDE_BOTH_ALIASES:
            normalized.append("both")
        elif tok in _PDE_TOKEN_ALIASES:
            normalized.append(_PDE_TOKEN_ALIASES[tok])
        else:
            raise ValueError(
                f"Unsupported pde_mode token: {tok!r}. "
                "Allowed values are: "
                "'none'/'off', 'consolidation', 'gw_flow', 'both'/'on'."
            )

    return normalized


def _canonicalize_pde_modes(
    tokens: Sequence[str],
) -> list[str]:
    """
    Convert normalized tokens into the final active-mode list.

    Rules
    -----
    - 'none' cannot be combined with any active PDE mode.
    - 'both' expands to ['consolidation', 'gw_flow'].
    - duplicates are removed while preserving canonical order.
    """
    toks = list(tokens)

    # 'none' must stand alone: reject ambiguous combinations before
    # expanding 'both'.
    if "none" in toks:
        non_none = [t for t in toks if t != "none"]
        if non_none:
            raise ValueError(
                "Ambiguous pde_mode: 'none/off' cannot be combined with "
                f"other modes: {non_none!r}."
            )
        return ["none"]

    active = set()

    for tok in toks:
        if tok == "both":
            active.update(("consolidation", "gw_flow"))
        elif tok in ("consolidation", "gw_flow"):
            active.add(tok)
        else:
            raise ValueError(
                f"Unexpected normalized token: {tok!r}."
            )

    if not active:
        return ["none"]

    return [mode for mode in _PDE_ORDER if mode in active]


def _collapse_pde_modes_to_label(
    active_modes: Sequence[str],
) -> str:
    """
    Convert a canonical active-mode list back to a single label.
    """
    modes = list(active_modes)

    if modes == ["none"]:
        return "none"
    if modes == ["consolidation"]:
        return "consolidation"
    if modes == ["gw_flow"]:
        return "gw_flow"
    if modes == ["consolidation", "gw_flow"]:
        return "both"

    raise ValueError(
        f"Cannot collapse unexpected active mode list: {modes!r}"
    )


def process_pde_modes(
    pde_mode: str | Sequence[str] | None,
    enforce_consolidation: bool = False,
    pde_mode_config: str | Sequence[str] | None = None,
    solo_return: bool = False,
) -> list[str] | str:
    r"""
    Normalize and validate PDE mode selection.

    Parameters
    ----------
    pde_mode : str, sequence of str, or None
        Requested PDE mode(s). A sequence must be a list, tuple, or set.
        A string may hold several comma-separated tokens
        (e.g. ``"consolidation,gw_flow"``). Tokens are stripped and
        lowercased before matching.

        Accepted canonical values are:

        - ``"none"``
        - ``"consolidation"``
        - ``"gw_flow"``
        - ``"both"``

        Accepted aliases:

        - ``None``, ``""``, ``"off"``, ``"false"``, ``"0"``,
          ``"disable"``, ``"disabled"`` -> ``"none"``
        - ``"on"``, ``"all"``, ``"true"``, ``"1"`` -> ``"both"``
        - ``"cons"`` -> ``"consolidation"``
        - ``"gw"``, ``"groundwater"``, ``"groundwater_flow"``
          -> ``"gw_flow"``
    enforce_consolidation : bool, default=False
        If True, any resolved mode other than exactly ``["consolidation"]``
        is coerced to ``["consolidation"]`` and a ``UserWarning`` is
        emitted.

        This includes:

        - ``["none"]``
        - ``["gw_flow"]``
        - ``["consolidation", "gw_flow"]``
    pde_mode_config : str, sequence of str, or None, optional
        Optional override. If not None, this value takes precedence over
        ``pde_mode``.
    solo_return : bool, default=False
        If False, return the list of active modes in canonical order
        (``"consolidation"`` before ``"gw_flow"``, duplicates removed).

        If True, return a single canonical label:

        - ``"none"``
        - ``"consolidation"``
        - ``"gw_flow"``
        - ``"both"``

    Returns
    -------
    list of str or str
        Canonical PDE mode(s), either as a list or a single label.

    Raises
    ------
    TypeError
        If the input is not a string, a list, tuple, or set, or None.
    ValueError
        If a token is unsupported, or if ``"none"`` is combined with an
        active PDE mode.

    Examples
    --------
    >>> process_pde_modes(None)
    ['none']

    >>> process_pde_modes("off")
    ['none']

    >>> process_pde_modes("on")
    ['consolidation', 'gw_flow']

    >>> process_pde_modes("both", solo_return=True)
    'both'

    >>> process_pde_modes("gw_flow", enforce_consolidation=True)
    ['consolidation']
    """
    source = (
        pde_mode_config
        if pde_mode_config is not None
        else pde_mode
    )

    raw_tokens = _flatten_mode_input(source)
    normalized_tokens = _normalize_pde_tokens(raw_tokens)
    active_modes = _canonicalize_pde_modes(normalized_tokens)

    if enforce_consolidation and active_modes != [
        "consolidation"
    ]:
        warnings.warn(
            "This model requires PDE mode 'consolidation'. "
            f"Received {active_modes!r}; coercing to ['consolidation'].",
            UserWarning,
            stacklevel=2,
        )
        active_modes = ["consolidation"]

    if solo_return:
        return _collapse_pde_modes_to_label(active_modes)

    return active_modes


def _process_pde_modes(
    pde_mode: str | list | None,
    enforce_consolidation: bool = False,
    pde_mode_config: str | list | None = None,
    solo_return: bool = False,
) -> list | str:
    r"""
    Legacy resolver for ``pde_mode``. Prefer :func:`process_pde_modes`.

    This function processes ``pde_mode`` inputs according to the following
    rules:

    - If ``pde_mode_config`` is truthy, it overrides ``pde_mode``.
    - ``'none'`` or ``'off'`` anywhere in the input yields ``['none']``.
    - ``'both'`` or ``'on'`` anywhere in the input yields
      ``['consolidation', 'gw_flow']``. This check runs after the
      ``'none'`` check, so it wins when both are present.
    - If ``enforce_consolidation`` is True and ``'consolidation'`` is not
      among the resolved modes, a warning is issued and the result falls
      back to ``['consolidation']``.

    Unlike :func:`process_pde_modes`, unknown tokens are not rejected.
    They are passed through lowercased.

    Parameters
    ----------
    pde_mode : str, list of str, or None
        The desired PDE mode(s). Can be:

        - a string (e.g. ``'consolidation'``, ``'gw_flow'``, ``'both'``);
        - a list of strings (e.g. ``['consolidation', 'gw_flow']``);
        - None, which resolves to ``['none']``.
    enforce_consolidation : bool, default=False
        If True and ``'consolidation'`` is not among the resolved modes,
        warn and fall back to ``['consolidation']``. A list that already
        contains ``'consolidation'`` (e.g. both modes) is kept unchanged.
    pde_mode_config : str, list of str, or None, optional
        If truthy, overrides the ``pde_mode`` argument.
    solo_return : bool, default=False
        If True, return only the first resolved mode as a string.

    Returns
    -------
    list of str or str
        The resolved, lowercased PDE modes, or the first of them when
        ``solo_return`` is True.

    Raises
    ------
    TypeError
        If ``pde_mode`` is neither a string, a list, nor None.

    Warns
    -----
    UserWarning
        If ``enforce_consolidation`` is True and ``'consolidation'`` is not
        among the resolved modes.
    """
    # If pde_mode_config is provided, use it, otherwise fall back to pde_mode
    if pde_mode_config:
        pde_mode = pde_mode_config

    if isinstance(pde_mode, str):
        if pde_mode.lower() == "on":
            pde_mode = "both"
        pde_modes_active = [pde_mode.lower()]

    elif isinstance(pde_mode, list):
        pde_modes_active = [
            str(p_type).lower() for p_type in pde_mode
        ]
    elif pde_mode is None:
        pde_modes_active = [
            "none"
        ]  # Explicitly 'none' if None is provided
    else:
        raise TypeError(
            "`pde_mode` must be a string, list of strings, or None."
        )

    # Handle special cases for "none" and "both"
    if (
        "none" in pde_modes_active
        or "off" in pde_modes_active
    ):  # If 'none' is present, override others
        pde_modes_active = ["none"]
    if (
        "both" in pde_modes_active or "on" in pde_modes_active
    ):  # If 'both' is present, use both modes
        pde_modes_active = ["consolidation", "gw_flow"]

    # Enforce consolidation mode if specified
    if (
        enforce_consolidation
        and "consolidation" not in pde_modes_active
    ):
        warnings.warn(
            "You have passed a mode other than 'consolidation'. "
            "Falling back to 'consolidation' as the active mode.",
            UserWarning,
            stacklevel=2,
        )
        pde_modes_active = ["consolidation"]

    # Log (without coercing) any tokens outside the supported set so that
    # typos stay visible. Strict validation lives in ``process_pde_modes``.
    supported = {"consolidation", "gw_flow", "none"}
    unsupported = [
        mode for mode in pde_modes_active if mode not in supported
    ]
    if unsupported:
        logger.info(
            f"Unsupported pde_mode token(s) {unsupported!r} in "
            f"{pde_mode!r}; they are passed through unchanged."
        )

    if solo_return:
        pde_modes_active = pde_modes_active[0]

    # Return the processed pde_modes_active
    return pde_modes_active


@SaveFile
def format_pinn_predictions(
    predictions: dict[str, Tensor] | None = None,
    model: Model | None = None,
    model_inputs: dict[str, Tensor] | None = None,
    y_true_dict: dict[str, np.ndarray | Tensor] | None = None,
    target_mapping: dict[str, str] | None = None,
    include_gwl: bool = True,
    include_coords: bool = True,
    quantiles: list[float] | None = None,
    forecast_horizon: int | None = None,
    output_dims: dict[str, int] | None = None,
    ids_data_array: np.ndarray | pd.DataFrame | None = None,
    ids_cols: list[str] | None = None,
    ids_cols_indices: list[int] | None = None,
    scaler_info: dict[str, dict[str, Any]] | None = None,
    coord_scaler: Any | None = None,
    evaluate_coverage: bool = False,
    coverage_quantile_indices: tuple[int, int] = (0, -1),
    savefile: str | None = None,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    name: str | None = None,
    model_name: str | None = None,
    stop_check: Callable[[], bool] | None = None,
    verbose: int = 0,
    **kwargs,
) -> pd.DataFrame:
    r"""Format PINN model predictions into a long-format pandas DataFrame.

    This general-purpose utility transforms raw outputs from models such
    as PIHALNet or TransFlowSubsNet into a long-format DataFrame suitable
    for analysis, visualization, or export. It handles multi-target
    outputs (e.g., subsidence and GWL) and both point and quantile
    forecasts. It can optionally include true values, coordinate
    information, and other metadata. It also supports inverse-scaling of
    predictions and evaluation of quantile coverage.

    It is a thin wrapper that forwards all arguments to
    ``format_pihalnet_predictions``, passing ``predictions`` as
    ``pihalnet_outputs``.

    Parameters
    ----------
    predictions : dict of Tensors, optional
        The dictionary of prediction tensors, typically returned by a
        model's ``.predict()`` method. Keys should match the model's
        output layer names (e.g., ``'subs_pred'``, ``'gwl_pred'``).
        If ``None``, predictions are generated internally using the
        `model` and `model_inputs` arguments. Default is ``None``.
    model : keras.Model, optional
        A compiled Keras model instance used to generate predictions
        if the `predictions` dictionary is not provided.
        Default is ``None``.
    model_inputs : dict of Tensors, optional
        A dictionary of input tensors matching the model's signature,
        required only if `predictions` is ``None``. Default is ``None``.
    y_true_dict : dict, optional
        A dictionary containing the ground-truth target arrays, keyed
        by their base names (e.g., ``'subsidence'``, ``'gwl'``).
        If provided, an ``<target>_actual`` column will be added to
        the output DataFrame for comparison. Default is ``None``.
    target_mapping : dict, optional
        A custom mapping from model output keys to desired base names in
        the DataFrame columns. For example:
        ``{'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}``.
        Default is ``None``.
    include_gwl : bool, default=True
        Toggles the inclusion of groundwater level (GWL) predictions
        in the final DataFrame.
    include_coords : bool, default=True
        Toggles the inclusion of the spatio-temporal coordinate columns
        (``coord_t``, ``coord_x``, ``coord_y``) in the final DataFrame.
    quantiles : list of float, optional
        The list of quantile levels (e.g., ``[0.1, 0.5, 0.9]``) that
        the model predicted. This is crucial for correctly parsing
        probabilistic forecasts. Default is ``None``.
    forecast_horizon : int, optional
        The length of the forecast horizon. If ``None``, it is inferred
        from the shape of the prediction tensors. Default is ``None``.
    output_dims : dict of {str: int}, optional
        A dictionary specifying the feature dimension of each target,
        e.g., ``{'subs_pred': 1, 'gwl_pred': 1}``. If ``None``, it's
        inferred from the tensor shapes. Default is ``None``.
    ids_data_array : np.ndarray or pd.DataFrame, optional
        An array or DataFrame containing static identifiers (e.g., well
        IDs, site categories) for each sample. Its length must match
        the number of samples in the prediction. Default is ``None``.
    ids_cols : list of str, optional
        A list of column names for the `ids_data_array`. Required if
        `ids_data_array` is a NumPy array. Default is ``None``.
    ids_cols_indices : list of int, optional
        A list of column indices to select from `ids_data_array` if it
        is a NumPy array. Default is ``None``.
    scaler_info : dict, optional
        A dictionary providing the necessary information to perform
        inverse scaling on a per-target basis. Each key should be a
        target name (e.g., 'subsidence') and its value a dictionary
        containing ``{'scaler': obj, 'all_features': list, 'idx': int}``.
        Default is ``None``.
    coord_scaler : object, optional
        A fitted scikit-learn-like scaler object used to perform an
        inverse transform on the coordinate columns. Default is ``None``.
    evaluate_coverage : bool, default=False
        If ``True`` and quantile predictions are present, calculates the
        unconditional coverage of the prediction interval.
    coverage_quantile_indices : tuple of (int, int), default=(0, -1)
        The indices of the lower and upper quantiles in the sorted
        `quantiles` list to use for the coverage calculation. Default
        is ``(0, -1)``, which corresponds to the full range.
    savefile : str, optional
        If a file path is provided, the final DataFrame is saved to a
        CSV file at this location. Default is ``None``.
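    name : str, optional
        Name of the prediction, used to label the formatted output and
        the coverage result, if computed.
    model_name : str, optional
        Name of the model. ``'geoprior'`` or ``'geopriorsubsnet'``
        triggers unpacking of the ``'data_final'`` output.
    stop_check : callable, optional
        Zero-argument callable polled during formatting (e.g. right after
        internal prediction). If it returns True, formatting aborts with
        ``InterruptedError``.
    verbose : int, default=0
        The verbosity level, from 0 (silent) to 5 (trace every step).
    _logger : logging.Logger or callable, optional
        Logger, or a ``print``-like callable, that receives progress
        messages.
    **kwargs : dict
        Additional keyword arguments forwarded to
        ``format_pihalnet_predictions`` (e.g. ``apply_mask``,
        ``mask_values``, ``mask_fill_value``).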

    Returns
    -------
    pd.DataFrame
        A long-format DataFrame where each row represents a single
        forecast step for a single sample. Columns include sample and
        step identifiers, coordinates, predictions, and optionally
        actuals and metadata.

    Notes
    -----
    - The function returns a column-aligned DataFrame, which simplifies
      subsequent analysis and plotting.
    - For quantile forecasts, prediction columns are named using the
      pattern ``<target_name>_q<quantile*100>``, e.g., ``subsidence_q5``,
      ``subsidence_q50``, ``subsidence_q95``.
    - For point forecasts, the column is named ``<target_name>_pred``.

    See Also
    --------
    geoprior.plot.forecast.plot_forecasts : A powerful utility for
        visualizing the DataFrame produced by this function.
    """
    # Backward-compatible wrapper: ``predictions`` maps to
    # ``pihalnet_outputs``. Everything else is forwarded unchanged.
    return format_pihalnet_predictions(
        pihalnet_outputs=predictions,
        model=model,
        model_inputs=model_inputs,
        y_true_dict=y_true_dict,
        target_mapping=target_mapping,
        include_gwl=include_gwl,
        include_coords=include_coords,
        quantiles=quantiles,
        forecast_horizon=forecast_horizon,
        output_dims=output_dims,
        ids_data_array=ids_data_array,
        ids_cols=ids_cols,
        ids_cols_indices=ids_cols_indices,
        scaler_info=scaler_info,
        coord_scaler=coord_scaler,
        evaluate_coverage=evaluate_coverage,
        coverage_quantile_indices=coverage_quantile_indices,
        savefile=savefile,
        name=name,
        model_name=model_name,
        verbose=verbose,
        _logger=_logger,
        stop_check=stop_check,
        **kwargs,
    )


@SaveFile
def format_pihalnet_predictions(
    pihalnet_outputs: dict[str, Tensor] | None = None,
    model: Model | None = None,
    model_inputs: dict[str, Tensor] | None = None,
    y_true_dict: dict[str, np.ndarray | Tensor] | None = None,
    target_mapping: dict[str, str] | None = None,
    include_gwl: bool = True,
    include_coords: bool = True,
    quantiles: list[float] | None = None,
    forecast_horizon: int | None = None,
    output_dims: dict[str, int] | None = None,
    ids_data_array: np.ndarray | pd.DataFrame | None = None,
    ids_cols: list[str] | None = None,
    ids_cols_indices: list[int] | None = None,
    scaler_info: dict[str, dict[str, Any]] | None = None,
    coord_scaler: Any | None = None,
    evaluate_coverage: bool = False,
    coverage_quantile_indices: tuple[int, int] = (0, -1),
    savefile: str | None = None,
    name: str | None = None,
    model_name: str | None = None,
    apply_mask: bool = False,
    mask_values: float | int | None = None,
    mask_fill_value: float | None = None,
    verbose: int = 0,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    stop_check: Callable[[], bool] | None = None,
    **kwargs,
) -> pd.DataFrame:
    r"""
    Formats PIHALNet/GeoPriorSubsNet predictions into a structured
    pandas DataFrame, handling inversion, quantiles, and coordinates.

    This function is the core formatter. It:

    1. Obtains model outputs (or uses the ones provided).
    2. Unpacks ``'data_final'`` if ``model_name`` is ``'geoprior'`` or
       ``'geopriorsubsnet'``.
    3. Inverse-transforms all prediction and actual arrays using
       ``scaler_info``.
    4. Builds a long-format DataFrame with ``sample_idx`` and
       ``forecast_step`` columns.
    5. Appends inverted quantile or point predictions.
    6. Appends inverted actual values.
    7. Appends inverted coordinates.
    8. Appends static/ID columns.
    9. Evaluates coverage on the inverted data.

    Parameters
    ----------
    pihalnet_outputs : dict, optional
        Raw output from `model.predict()`. If None, `model` and
        `model_inputs` must be provided.
    model : tf.keras.Model, optional
        Trained model instance (if `pihalnet_outputs` is None).
    model_inputs : dict, optional
        Inputs for the model to generate predictions
        (if `pihalnet_outputs` is None).
    y_true_dict : dict, optional
        Dictionary of true target arrays. Keys may be prediction keys
        (e.g. ``{'subs_pred': y_true_s}``) or base names
        (e.g. ``{'subsidence': y_true_s}``). The prediction key is tried
        first. Required for including actuals and evaluating coverage.
    target_mapping : dict, optional
        Maps prediction keys to base names for DataFrame columns.
        Default: ``{'subs_pred': 'subsidence', 'gwl_pred': 'gwl'}``.
    include_gwl : bool, default True
        Whether to include 'gwl_pred' in the final DataFrame.
    include_coords : bool, default True
        Whether to include 'coord_t', 'coord_x', 'coord_y' columns.
    quantiles : list[float], optional
        List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided,
        quantile columns (e.g., 'subsidence_q10') are created.
    forecast_horizon : int, optional
        The forecast horizon length (H). If not provided, it's
        inferred from the prediction array's shape.
    output_dims : dict, optional
        Maps prediction keys to their output dimension (O).
        E.g., {'subs_pred': 1, 'gwl_pred': 1}. Crucial for
        correctly splitting GeoPrior outputs and reshaping.
    ids_data_array : np.ndarray or pd.DataFrame, optional
        Static/ID data (e.g., original coordinates) to merge.
        Must have the same number of samples (B) as predictions.
    ids_cols : list[str], optional
        Column names if `ids_data_array` is a DataFrame.
    ids_cols_indices : list[int], optional
        Column indices if `ids_data_array` is a NumPy array.
    scaler_info : dict, optional
        Dictionary for inverse scaling. Each target entry should provide
        a fitted scaler, the target index inside that scaler, and the
        feature-name ordering used when the scaler was fit.
    coord_scaler : object, optional
        A *fitted* scikit-learn-style scaler (exposing
        ``inverse_transform``) used to inverse transform the ``'coords'``
        tensor.
    evaluate_coverage : bool, default False
        If True, calculates coverage percentage for quantiles.
    coverage_quantile_indices : tuple[int, int], default (0, -1)
        Indices of the low and high quantiles in the ``quantiles`` list
        to use for coverage. For ``quantiles=[0.1, 0.5, 0.9]``, the
        default ``(0, -1)`` selects the 10th and 90th percentiles.
    savefile : str, optional
        If provided, saves the final DataFrame to this path.
    name : str, optional
        Name of the prediction, used to label the formatted output and
        the coverage result.
    model_name : str, optional
        Specifies the model type. If 'geoprior' or 'geopriorsubsnet',
        triggers unpacking of the 'data_final' output.
    apply_mask : bool, default False
        If True, masks predictions based on `mask_values` in the
        first target's `_actual` column.
    mask_values : float or int, optional
        The value in the `_actual` column to trigger masking.
    mask_fill_value : float, optional
        The value to replace masked predictions with (e.g., np.nan).
    verbose : int, default 0
        Logging verbosity.
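    _logger : logging.Logger or callable, optional
        Logger, or a ``print``-like callable, that receives progress
        messages.
    stop_check : callable, optional
        Zero-argument callable polled during formatting (e.g. right after
        internal prediction). If it returns True, formatting aborts with
        ``InterruptedError``.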

    Returns
    -------
    pd.DataFrame
        A long-format DataFrame with predictions, actuals, and coordinates.
    """
    vlog(
        f"Starting model prediction formatting (verbose={verbose}).",
        level=3,
        verbose=verbose,
        logger=_logger,
    )

    # --- 1. Obtain Model Predictions (if not provided) ---
    if pihalnet_outputs is None:
        if model is None or model_inputs is None:
            raise ValueError(
                "If 'pihalnet_outputs' is None, both 'model' and "
                "'model_inputs' must be provided."
            )
        vlog(
            "  Predictions not provided, generating from model...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        try:
            pihalnet_outputs = model.predict(
                model_inputs, verbose=0
            )
            if not isinstance(pihalnet_outputs, dict):
                raise ValueError(
                    "Model output is not a dictionary as expected."
                )
            vlog(
                "  Model predictions generated.",
                level=5,
                verbose=verbose,
                logger=_logger,
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to generate predictions from model: {e}"
            ) from e

        if stop_check and stop_check():
            raise InterruptedError(
                "Model prediction aborted."
            )

    # --- 2. Unpack GeoPrior Model Output (if needed) ---
    is_geoprior = str(model_name).lower().strip() in (
        "geoprior",
        "geopriorsubsnet",
    )
    if is_geoprior:
        vlog(
            f"GeoPrior model detected (model_name='{model_name}'). "
            "Unpacking 'data_final' tensor.",
            level=3,
            verbose=verbose,
            logger=_logger,
        )

        if output_dims is None:
            output_dims = {"subs_pred": 1, "gwl_pred": 1}
            vlog(
                "  `output_dims` not provided, defaulting to "
                "{'subs_pred': 1, 'gwl_pred': 1}.",
                level=4,
                verbose=verbose,
                logger=_logger,
            )

        pihalnet_outputs = _unpack_geoprior_outputs(
            geoprior_outputs=pihalnet_outputs,
            output_dims=output_dims,
            quantiles=quantiles,
            verbose=verbose,
            _logger=_logger,
        )

    # --- 3. Ensure all data is NumPy & INVERSE TRANSFORM ---
    if target_mapping is None:
        target_mapping = {
            "subs_pred": "subsidence",
            "gwl_pred": "gwl",
        }
        vlog(
            "  `target_mapping` not provided, using defaults.",
            level=4,
            verbose=verbose,
            logger=_logger,
        )

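    if not include_gwl:
        # Build a filtered copy so a caller-supplied mapping is not mutated.
        target_mapping = {
            key: value
            for key, value in target_mapping.items()
            if key != "gwl_pred"
        }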

    # Process predictions
    processed_preds = {}
    if pihalnet_outputs:
        vlog(
            "  Converting and inverse-transforming predictions...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        for pred_key, base_name in target_mapping.items():
            if pred_key in pihalnet_outputs:
                val_tensor = pihalnet_outputs[pred_key]
                pred_array = (
                    val_tensor.numpy()
                    if hasattr(val_tensor, "numpy")
                    else np.asarray(val_tensor)
                )

                inverted_array = _inverse_transform_array(
                    pred_array,
                    base_name,
                    scaler_info,
                    verbose,
                    _logger,
                )
                processed_preds[pred_key] = inverted_array
            else:
                vlog(
                    f"  [WARN] Prediction key '{pred_key}' not found "
                    "in model outputs.",
                    level=2,
                    verbose=verbose,
                    logger=_logger,
                )

    if not processed_preds:
        vlog(
            "  No valid prediction keys found in model output. "
            "Returning empty DF.",
            level=1,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    # Process actuals (y_true)
    processed_actuals = {}
    if y_true_dict:
        vlog(
            "  Converting and inverse-transforming actuals (y_true)...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        for true_key, base_name in target_mapping.items():
            # Look up the target by prediction key ('subs_pred') first,
            # then fall back to its base name ('subsidence').
            val_tensor = y_true_dict.get(true_key)
            if val_tensor is None:
                val_tensor = y_true_dict.get(base_name)

            if val_tensor is not None:
                true_array = (
                    val_tensor.numpy()
                    if hasattr(val_tensor, "numpy")
                    else np.asarray(val_tensor)
                )

                inverted_array = _inverse_transform_array(
                    true_array,
                    base_name,
                    scaler_info,
                    verbose,
                    _logger,
                )
                # Use true_key ('subs_pred') to align with processed_preds
                processed_actuals[true_key] = inverted_array
            else:
                vlog(
                    f"  [WARN] True data key '{true_key}' (or '{base_name}') "
                    "not found in y_true_dict.",
                    level=2,
                    verbose=verbose,
                    logger=_logger,
                )

    # --- 4. Initialize DataFrame (Base) ---
    first_pred_key = list(processed_preds.keys())[0]
    first_pred_array = processed_preds[first_pred_key]

    num_samples = first_pred_array.shape[0]
    H_inferred = forecast_horizon or first_pred_array.shape[1]

    if num_samples == 0 or H_inferred == 0:
        vlog(
            "  No samples or zero forecast horizon. Returning empty DF.",
            level=1,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    vlog(
        f"  Formatting for {num_samples} samples, Horizon={H_inferred}.",
        level=4,
        verbose=verbose,
        logger=_logger,
    )

    sample_indices = np.repeat(
        np.arange(num_samples), H_inferred
    )
    forecast_steps = np.tile(
        np.arange(1, H_inferred + 1), num_samples
    )

    all_data_dfs = [
        pd.DataFrame(
            {
                "sample_idx": sample_indices,
                "forecast_step": forecast_steps,
            }
        )
    ]

    # --- 5. Populate DataFrame with INVERTED Data ---
    if output_dims is None:
        vlog(
            "  [WARN] `output_dims` is None. Inferring from arrays. "
            "This may be unreliable for multi-output models.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        # The last axis holds the output dimension for both point
        # (B, H, O) and quantile (B, H, Q, O) arrays.
        output_dims = {
            k: v.shape[-1] for k, v in processed_preds.items()
        }

    if quantiles is not None:
        df_part = _add_quantiles_to_df(
            processed_preds,
            target_mapping,
            quantiles,
            output_dims,
            num_samples,
            H_inferred,
            verbose,
            _logger,
        )
        all_data_dfs.append(df_part)
    else:
        df_part = _add_point_preds_to_df(
            processed_preds,
            target_mapping,
            output_dims,
            num_samples,
            H_inferred,
            verbose,
            _logger,
        )
        all_data_dfs.append(df_part)

    if processed_actuals:
        df_part = _add_actuals_to_df(
            processed_actuals,
            target_mapping,
            output_dims,
            num_samples,
            H_inferred,
            verbose,
            _logger,
        )
        all_data_dfs.append(df_part)

    # --- 6. Add coordinates (inverse-transformed inside the helper) ---
    df_coord_cols = []
    if include_coords and model_inputs:
        df_part, df_coord_cols = _add_coords_to_df(
            model_inputs=model_inputs,
            coord_scaler=coord_scaler,
            num_samples=num_samples,
            H_inferred=H_inferred,
            verbose=verbose,
            _logger=_logger,
            force_future_from=kwargs.get(
                "forecast_start_year"
            ),
        )
        all_data_dfs.append(df_part)

    # --- 7. Add IDs (no inversion needed) ---
    df_part = _add_ids_to_df(
        ids_data_array,
        ids_cols,
        ids_cols_indices,
        num_samples,
        H_inferred,
        verbose,
        _logger,
    )
    all_data_dfs.append(df_part)

    if stop_check and stop_check():
        raise InterruptedError(
            "DataFrame population aborted."
        )

    # --- 8. Concatenate all DataFrames ---
    final_df = pd.concat(all_data_dfs, axis=1)

    # --- 9. Optional masking ---
    if apply_mask:
        vlog(
            "  Applying mask to prediction columns...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        if mask_values is None or mask_fill_value is None:
            raise ValueError(
                "When apply_mask=True, both mask_values and "
                "mask_fill_value must be provided."
            )

        # Masking is keyed on the actuals column of the first target only.
        first_base = list(target_mapping.values())[0]
        ref_col = f"{first_base}_actual"
        if ref_col not in final_df.columns:
            vlog(
                f"  [WARN] Masking reference column '{ref_col}' not in "
                "DataFrame. Skipping masking.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
        else:
            # Collect every quantile ('<base>_q..') and point
            # ('<base>_pred..') column produced for any target.
            mask_cols = [
                c
                for c in final_df.columns
                if any(
                    c.startswith(f"{base}_q")
                    for base in target_mapping.values()
                )
                or any(
                    c.startswith(f"{base}_pred")
                    for base in target_mapping.values()
                )
            ]

            if not mask_cols:
                vlog(
                    "  [WARN] No maskable forecast columns found in final_df. "
                    "Skipping masking.",
                    level=2,
                    verbose=verbose,
                    logger=_logger,
                )
            else:
                try:
                    # Use the imported mask_by_reference function
                    final_df = mask_by_reference(
                        data=final_df,
                        ref_col=ref_col,
                        values=mask_values,
                        find_closest=False,
                        fill_value=mask_fill_value,
                        mask_columns=mask_cols,
                        error="ignore",
                        inplace=False,
                    )
                    vlog(
                        f"  Successfully applied mask to {len(mask_cols)} "
                        f"columns based on '{ref_col}'.",
                        level=4,
                        verbose=verbose,
                        logger=_logger,
                    )
                except Exception as e:
                    vlog(
                        f"  [WARN] Failed to apply mask: {e}",
                        level=2,
                        verbose=verbose,
                        logger=_logger,
                    )

    # --- 10. Evaluate coverage (inverted actuals vs. inverted quantiles) ---
    if evaluate_coverage and quantiles and processed_actuals:
        _evaluate_coverage(
            df=final_df,
            target_mapping=target_mapping,
            q_indices=coverage_quantile_indices,
            savefile=savefile,
            name=name,
            verbose=verbose,
            _logger=_logger,
        )

    if stop_check and stop_check():
        raise InterruptedError("Metric/Masking aborted.")

    vlog(
        "Model prediction formatting to DataFrame complete.",
        level=3,
        verbose=verbose,
        logger=_logger,
    )

    return final_df
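

# Illustrative sketch (hypothetical helper, not part of the GeoPrior
# API): the long-format index built in the formatter above. Each of the
# N samples is repeated H times, and forecast steps run 1..H within each
# sample, so row r corresponds to sample r // H and step (r % H) + 1.
# The helpers below flatten (N, H, ...) arrays with a C-order reshape to
# (N * H, ...), which produces exactly this row order.
def _sketch_long_format_index(num_samples, horizon):
    """Return (sample_idx, forecast_step) arrays for N samples, H steps."""
    import numpy as np

    sample_idx = np.repeat(np.arange(num_samples), horizon)
    forecast_step = np.tile(np.arange(1, horizon + 1), num_samples)
    return sample_idx, forecast_step


# Usage (N=2, H=3):
#   idx, step = _sketch_long_format_index(2, 3)
#   idx  -> [0, 0, 0, 1, 1, 1]
#   step -> [1, 2, 3, 1, 2, 3]
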


def _format_target_predictions(
    predictions_np: np.ndarray,
    num_samples: int,
    H: int,  # Horizon
    O: int,  # Output dim for this specific target
    base_target_name: str,
    quantiles: list[float] | None,
    verbose: int = 0,
) -> tuple[list[str], pd.DataFrame]:
    """Helper to format predictions for a single target variable."""
    pred_cols_names = []

    # Expected input shapes to this helper:
    # Point: (N, H, O)
    # Quantile: (N, H, Q, O) OR (N, H, Q) if O=1 was pre-squeezed

    if quantiles:
        num_q = len(quantiles)
        # Ensure predictions_np is (N, H, Q, O)
        if (
            predictions_np.ndim == 3
            and predictions_np.shape[-1] == num_q
            and O == 1
        ):
            # Case: (N, H, Q), implies O=1
            preds_to_process = np.expand_dims(
                predictions_np, axis=-1
            )  # (N,H,Q,1)
        elif (
            predictions_np.ndim == 3
            and predictions_np.shape[-1] == num_q * O
        ):
            # Case: (N, H, Q*O)
            preds_to_process = predictions_np.reshape(
                (num_samples, H, num_q, O)
            )
        elif (
            predictions_np.ndim == 4
            and predictions_np.shape[2] == num_q
            and predictions_np.shape[3] == O
        ):
            # Case: (N, H, Q, O) - already correct
            preds_to_process = predictions_np
        else:
            raise ValueError(
                f"Unexpected quantile prediction shape for {base_target_name}: "
                f"{predictions_np.shape}. Expected compatible with N={num_samples}, "
                f"H={H}, Q={num_q}, O={O}"
            )

        # Now preds_to_process is (N, H, Q, O)
        df_data_for_concat = []
        for o_idx in range(O):
            for q_idx, q_val in enumerate(quantiles):
                col_name = f"{base_target_name}"
                if O > 1:
                    col_name += f"_{o_idx}"
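                # Round before converting to int: 0.29 * 100 evaluates to
                # 28.999999999999996, which int() would truncate to 28.
                col_name += f"_q{int(round(q_val * 100))}"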
                pred_cols_names.append(col_name)
                df_data_for_concat.append(
                    preds_to_process[
                        :, :, q_idx, o_idx
                    ].reshape(-1)
                )
        pred_df_part = pd.DataFrame(
            dict(
                zip(
                    pred_cols_names,
                    df_data_for_concat,
                    strict=False,
                )
            )
        )

    else:  # Point forecast
        # predictions_np should be (N, H, O)
        if (
            predictions_np.ndim != 3
            or predictions_np.shape[-1] != O
        ):
            raise ValueError(
                f"Unexpected point prediction shape for {base_target_name}: "
                f"{predictions_np.shape}. Expected (N,H,O) with O={O}"
            )

        df_data_for_concat = []
        for o_idx in range(O):
            col_name = f"{base_target_name}"
            if O > 1:
                col_name += f"_{o_idx}"
            col_name += "_pred"
            pred_cols_names.append(col_name)
            df_data_for_concat.append(
                predictions_np[:, :, o_idx].reshape(-1)
            )
        pred_df_part = pd.DataFrame(
            dict(
                zip(
                    pred_cols_names,
                    df_data_for_concat,
                    strict=False,
                )
            )
        )

    return pred_cols_names, pred_df_part
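

# Illustrative sketch (hypothetical helper): turning a quantile level
# into a column suffix. Truncating with int(q * 100) can be off by one
# because of binary floating point (0.29 * 100 is 28.999999999999996).
# Rounding first avoids this. pad=True mirrors the zero-padded 'q05'
# style of _add_quantiles_to_df; pad=False mirrors the unpadded style
# of _format_target_predictions.
def _sketch_quantile_suffix(q, pad=True):
    """Return a label such as 'q05' (pad=True) or 'q5' (pad=False)."""
    pct = int(round(q * 100))
    return f"q{pct:02d}" if pad else f"q{pct}"
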


def _unpack_geoprior_outputs(
    geoprior_outputs: dict[str, Tensor],
    output_dims: dict[str, int],
    quantiles: list[float] | None,
    verbose: int = 0,
    _logger: Any = None,
) -> dict[str, Tensor]:
    """
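    Split the combined 'data_final' tensor from GeoPriorSubsNet into
    separate 'subs_pred' and 'gwl_pred' entries for the formatter.

    The first ``output_dims['subs_pred']`` channels of the last axis
    are treated as subsidence and the remaining channels as groundwater
    level. All other keys except 'data_final', 'data_mean' and
    'phys_mean_raw' are passed through unchanged.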
    """
    if "data_final" not in geoprior_outputs:
        vlog(
            "  [WARN] 'data_final' key not found in GeoPrior model output. "
            "Assuming outputs are already flat.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return geoprior_outputs  # Return as-is

    data_final_tensor = geoprior_outputs["data_final"]

    # Get the split index for subsidence
    s_dim = output_dims.get("subs_pred")
    if s_dim is None:
        vlog(
            "  [WARN] 'output_dims' did not contain 'subs_pred'. "
            "Assuming subsidence output_dim = 1.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        s_dim = 1  # Default to 1 as a fallback

    # Split along the last (output) axis. The slicing is the same for
    # quantile tensors (B, H, Q, O_total) and point tensors
    # (B, H, O_total); only the log message differs.
    layout = (
        "4D quantile"
        if quantiles is not None
        else "3D point-forecast"
    )
    vlog(
        f"  Splitting {layout} tensor at final axis index {s_dim}.",
        level=4,
        verbose=verbose,
        logger=_logger,
    )
    s_pred_tensor = data_final_tensor[..., :s_dim]
    h_pred_tensor = data_final_tensor[..., s_dim:]

    # Create the new flat dictionary
    unpacked_outputs = {
        "subs_pred": s_pred_tensor,
        "gwl_pred": h_pred_tensor,
    }

    # Pass through any other keys that aren't the standard nested ones
    for k, v in geoprior_outputs.items():
        if k not in [
            "data_final",
            "data_mean",
            "phys_mean_raw",
        ]:
            unpacked_outputs[k] = v

    return unpacked_outputs
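

# Illustrative sketch (hypothetical helper): the channel split performed
# by _unpack_geoprior_outputs, shown on NumPy arrays standing in for the
# model tensors. Slicing with [..., :s_dim] and [..., s_dim:] works
# unchanged for point (B, H, O_total) and quantile (B, H, Q, O_total)
# layouts, because only the last axis is touched.
def _sketch_split_channels(data_final, s_dim=1):
    """Split the last axis into (subsidence, groundwater) parts."""
    return data_final[..., :s_dim], data_final[..., s_dim:]


# Usage:
#   quant = np.zeros((4, 3, 3, 2))   # B=4, H=3, Q=3, O_total=2
#   subs, gwl = _sketch_split_channels(quant, s_dim=1)
#   subs.shape, gwl.shape -> (4, 3, 3, 1), (4, 3, 3, 1)
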


def _inverse_transform_array(
    array: np.ndarray,
    base_name: str,
    scaler_info: dict[str, dict[str, Any]] | None,
    verbose: int = 0,
    _logger: Any = None,
) -> np.ndarray:
    """
    Helper to inverse-transform a single numpy array using scaler_info.

    This function finds the correct scaler and column index from
    scaler_info and applies the inverse min-max scaling.
    """
    if scaler_info is None or base_name not in scaler_info:
        # No scaler info for this target, return original array
        vlog(
            f"    - No scaler_info found for '{base_name}'. "
            "Returning original (scaled) array.",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        return array

    try:
        info = scaler_info[base_name]
        scaler = info["scaler"]
        col_index = info["idx"]

        # For a fitted MinMaxScaler the forward transform is
        # X * scale_ + min_ (scale_ acts as the coefficient, min_ as the
        # intercept), so the per-column inverse is (X - min_) / scale_.
        min_val = scaler.min_[col_index]
        scale_val = scaler.scale_[col_index]

        if scale_val == 0:
            # A zero scale_ cannot be inverted by division. scikit-learn's
            # MinMaxScaler already guards constant columns, so this is a
            # defensive fallback that uses the fitted column minimum.
            try:
                original_min = scaler.data_min_[col_index]
            except AttributeError:
                # Without data_min_ the original value is unknown;
                # 0 is an arbitrary placeholder.
                original_min = 0
                vlog(
                    f"  [WARN] Scaler for '{base_name}' has scale_=0 but "
                    "no 'data_min_' attribute. Inverse may be incorrect.",
                    level=2,
                    verbose=verbose,
                    logger=_logger,
                )
            inverted_array = np.full_like(array, original_min)
        else:
            inverted_array = (array - min_val) / scale_val

        vlog(
            f"    - Applied inverse transform to '{base_name}' array "
            f"(min_attr: {min_val:.4f}, scale_attr: {scale_val:.4f})",
            level=5,
            verbose=verbose,
            logger=_logger,
        )

        return inverted_array

    except AttributeError:
        vlog(
            f"  [WARN] Scaler for '{base_name}' is missing 'min_' or 'scale_'"
            " attributes. Returning original array.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return array
    except IndexError:
        vlog(
            f"  [WARN] Column index '{col_index}' out of bounds for "
            f"scaler of '{base_name}'. Returning original array.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return array
    except Exception as e:
        vlog(
            f"  [WARN] Failed to inverse transform '{base_name}': {e}"
            ". Returning original array.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return array
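

# Illustrative sketch (hypothetical helper): the min-max algebra that
# _inverse_transform_array relies on. It assumes the scikit-learn
# MinMaxScaler convention, where for feature_range (a, b):
#   scale_ = (b - a) / (data_max - data_min)
#   min_   = a - data_min * scale_
#   X_scaled = X * scale_ + min_   =>   X = (X_scaled - min_) / scale_
# Only NumPy is used here, so no fitted scaler object is required.
def _sketch_minmax_roundtrip(x, feature_range=(0.0, 1.0)):
    """Scale x to feature_range and invert it; returns (scaled, restored)."""
    import numpy as np

    x = np.asarray(x, dtype=float)
    a, b = feature_range
    data_range = x.max() - x.min()
    if data_range == 0:
        # Same guard as scikit-learn: a constant column's range counts as 1.
        data_range = 1.0
    scale_ = (b - a) / data_range
    min_ = a - x.min() * scale_
    scaled = x * scale_ + min_
    restored = (scaled - min_) / scale_
    return scaled, restored
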


def _add_quantiles_to_df(
    predictions: dict[str, np.ndarray],
    target_mapping: dict[str, str],
    quantiles: list[float],
    output_dims: dict[str, int],
    num_samples: int,
    H_inferred: int,
    verbose: int = 0,
    _logger: Any = None,
) -> pd.DataFrame:
    """Formats INVERTED quantile predictions into a DataFrame."""
    df_parts = []
    for pred_key, base_name in target_mapping.items():
        if pred_key not in predictions:
            continue

        pred_array = predictions[
            pred_key
        ]  # Should be (B, H, Q, O)

        O_target = output_dims.get(pred_key)
        if O_target is None:
            O_target = pred_array.shape[-1]
            vlog(
                f"  [WARN] output_dim for '{pred_key}' not specified. "
                f"Inferred {O_target} from array shape.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )

        if pred_array.shape != (
            num_samples,
            H_inferred,
            len(quantiles),
            O_target,
        ):
            vlog(
                f"  [WARN] Quantile array for '{pred_key}' has unexpected shape "
                f"{pred_array.shape}. Expected {(num_samples, H_inferred, len(quantiles), O_target)}. "
                "Skipping.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            continue

        # Move the output axis ahead of the quantile axis, then flatten:
        # (B, H, Q, O) -> (B, H, O, Q) -> (B*H, O*Q). This makes the
        # column order output-major, matching the naming loop below
        # (outer loop over outputs, inner loop over quantiles). A plain
        # reshape would be quantile-major and mislabel columns if O > 1.
        reshaped_preds = np.transpose(
            pred_array, (0, 1, 3, 2)
        ).reshape(num_samples * H_inferred, -1)

        # Create column names such as 'subsidence_q05', 'subsidence_q50'.
        col_names = []
        for o_idx in range(O_target):
            for q in quantiles:
                # Round before converting to int so that floating-point
                # products like 0.29 * 100 do not truncate to 28.
                q_str = f"q{int(round(q * 100)):02d}"
                col_name = f"{base_name}_{q_str}"
                if O_target > 1:
                    col_name += f"_{o_idx}"
                col_names.append(col_name)

        if reshaped_preds.shape[1] != len(col_names):
            vlog(
                f"  [WARN] Mismatch in quantile columns for '{base_name}'. "
                f"Expected {len(col_names)} columns, found {reshaped_preds.shape[1]}. "
                "Skipping.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            continue

        df_parts.append(
            pd.DataFrame(reshaped_preds, columns=col_names)
        )
        vlog(
            f"    - Added quantile columns for '{base_name}': {col_names}",
            level=5,
            verbose=verbose,
            logger=_logger,
        )

    return (
        pd.concat(df_parts, axis=1)
        if df_parts
        else pd.DataFrame()
    )
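

# Illustrative sketch (hypothetical helper): flattening a quantile block
# (B, H, Q, O) into long-format columns named output-major, that is, all
# quantiles of output 0 first, then all quantiles of output 1. A plain
# reshape(B * H, -1) lays the columns out quantile-major instead, so the
# output axis is moved ahead of the quantile axis before flattening.
def _sketch_flatten_quantiles(pred, base, quantiles):
    """Return (values, col_names) for a (B, H, Q, O) array."""
    import numpy as np

    b, h, q, o = pred.shape
    values = np.transpose(pred, (0, 1, 3, 2)).reshape(b * h, o * q)
    col_names = []
    for o_idx in range(o):
        for q_val in quantiles:
            name = f"{base}_q{int(round(q_val * 100)):02d}"
            if o > 1:
                name += f"_{o_idx}"
            col_names.append(name)
    return values, col_names
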


def _add_point_preds_to_df(
    predictions: dict[str, np.ndarray],
    target_mapping: dict[str, str],
    output_dims: dict[str, int],
    num_samples: int,
    H_inferred: int,
    verbose: int = 0,
    _logger: Any = None,
) -> pd.DataFrame:
    """Formats INVERTED point predictions into a DataFrame."""
    df_parts = []
    for pred_key, base_name in target_mapping.items():
        if pred_key not in predictions:
            continue

        pred_array = predictions[pred_key]  # (B, H, O)
        O_target = output_dims.get(pred_key)
        if O_target is None:
            O_target = pred_array.shape[-1]
            vlog(
                f"  [WARN] output_dim for '{pred_key}' not specified. "
                f"Inferred {O_target} from array shape.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )

        if pred_array.shape != (
            num_samples,
            H_inferred,
            O_target,
        ):
            vlog(
                f"  [WARN] Point pred array for '{pred_key}' has unexpected shape "
                f"{pred_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
                "Skipping.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            continue

        reshaped_preds = pred_array.reshape(
            num_samples * H_inferred, O_target
        )

        col_names = []
        for o_idx in range(O_target):
            col_name = f"{base_name}_pred"
            if O_target > 1:
                col_name += f"_{o_idx}"
            col_names.append(col_name)

        df_parts.append(
            pd.DataFrame(reshaped_preds, columns=col_names)
        )
        vlog(
            f"    - Added point pred columns for '{base_name}': {col_names}",
            level=5,
            verbose=verbose,
            logger=_logger,
        )

    return (
        pd.concat(df_parts, axis=1)
        if df_parts
        else pd.DataFrame()
    )


def _add_actuals_to_df(
    y_true_dict: dict[str, np.ndarray],
    target_mapping: dict[str, str],
    output_dims: dict[str, int],
    num_samples: int,
    H_inferred: int,
    verbose: int = 0,
    _logger: Any = None,
) -> pd.DataFrame:
    """Formats INVERTED actuals into a DataFrame."""
    df_parts = []
    for pred_key, base_name in target_mapping.items():
        if pred_key not in y_true_dict:
            continue

        true_array = y_true_dict[pred_key]  # (B, H, O)
        O_target = output_dims.get(pred_key)
        if O_target is None:
            O_target = true_array.shape[-1]
            vlog(
                f"  [WARN] output_dim for '{pred_key}' not specified. "
                f"Inferred {O_target} from array shape.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )

        if true_array.shape != (
            num_samples,
            H_inferred,
            O_target,
        ):
            vlog(
                f"  [WARN] Actuals array for '{pred_key}' has unexpected shape "
                f"{true_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
                "Skipping.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            continue

        reshaped_true = true_array.reshape(
            num_samples * H_inferred, O_target
        )

        col_names = []
        for o_idx in range(O_target):
            col_name = f"{base_name}_actual"
            if O_target > 1:
                col_name += f"_{o_idx}"
            col_names.append(col_name)

        df_parts.append(
            pd.DataFrame(reshaped_true, columns=col_names)
        )
        vlog(
            f"    - Added actuals columns for '{base_name}': {col_names}",
            level=5,
            verbose=verbose,
            logger=_logger,
        )

    return (
        pd.concat(df_parts, axis=1)
        if df_parts
        else pd.DataFrame()
    )


def _add_coords_to_df(
    model_inputs: dict[str, Tensor],
    coord_scaler: Any,
    num_samples: int,
    H_inferred: int,
    verbose: int = 0,
    _logger: Any = None,
    force_future_from: float | None = None,
) -> tuple[pd.DataFrame, list[str]]:
    """Formats and inverse-transforms coordinates."""

    coord_names = ["coord_t", "coord_x", "coord_y"]
    if (
        "coords" not in model_inputs
        or model_inputs["coords"] is None
    ):
        vlog(
            "  'coords' not in model_inputs. Skipping coordinate columns.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame(), []

    coords_arr = model_inputs["coords"]
    if hasattr(coords_arr, "numpy"):
        coords_arr = coords_arr.numpy()

    if coords_arr.shape != (num_samples, H_inferred, 3):
        vlog(
            f"  'coords' shape mismatch ({coords_arr.shape}). Expected "
            f"{(num_samples, H_inferred, 3)}. Skipping coordinate columns.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame(), []

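    # Copy so that the in-place time override below never mutates the
    # caller's model_inputs['coords'] through a reshaped view.
    coords_reshaped = np.array(
        coords_arr, copy=True
    ).reshape(num_samples * H_inferred, 3)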

    if coord_scaler is not None:
        vlog(
            "  Applying inverse transform to t,x,y coordinates...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        try:
            coords_reshaped = coord_scaler.inverse_transform(
                coords_reshaped
            )
        except Exception as e:
            vlog(
                f"  [WARN] Could not inverse transform coordinates: {e}. "
                "Using normalized coordinates.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )

    # ---  override time with forecast years, if requested ---
    if force_future_from is not None:
        # years_for_one_sample: e.g. [2023, 2024, 2025] for H_inferred=3
        years = np.arange(
            float(force_future_from),
            float(force_future_from) + float(H_inferred),
            dtype=coords_reshaped.dtype,
        )
        # Repeat for all samples and flatten to match (num_samples*H, )
        years_tiled = np.tile(years, num_samples)
        coords_reshaped[:, 0] = years_tiled
        vlog(
            f"  Overriding coord_t with forecast years starting at "
            f"{force_future_from} for horizon={H_inferred}.",
            level=4,
            verbose=verbose,
            logger=_logger,
        )

    vlog(
        f"  Added coordinate columns: {coord_names}",
        level=4,
        verbose=verbose,
        logger=_logger,
    )
    return pd.DataFrame(
        coords_reshaped, columns=coord_names
    ), coord_names
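

# Illustrative sketch (hypothetical helper): the forecast-year override
# applied in _add_coords_to_df when ``forecast_start_year`` is passed.
# One block of H consecutive years is built and tiled once per sample,
# which matches the long-format row order (sample-major, step-minor).
def _sketch_forecast_years(start_year, horizon, num_samples):
    """Return the coord_t column for N samples and H yearly steps."""
    import numpy as np

    years = np.arange(
        float(start_year), float(start_year) + float(horizon)
    )
    return np.tile(years, num_samples)
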


def _add_ids_to_df(
    ids_data_array: Any,
    ids_cols: list[str] | None,
    ids_cols_indices: list[int] | None,
    num_samples: int,
    H_inferred: int,
    verbose: int = 0,
    _logger: Any = None,
) -> pd.DataFrame:
    """Extracts and repeats static/ID columns."""
    if ids_data_array is None:
        vlog(
            "  No `ids_data_array` provided, skipping static/ID columns.",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    vlog(
        "  Processing additional static/ID columns...",
        level=4,
        verbose=verbose,
        logger=_logger,
    )

    ids_np_array = None
    ids_cols_to_use = []

    if isinstance(ids_data_array, pd.DataFrame):
        ids_cols_to_use = (
            list(ids_cols)
            if ids_cols
            else list(ids_data_array.columns)
        )
        try:
            ids_np_array = ids_data_array[
                ids_cols_to_use
            ].values
        except KeyError as e:
            vlog(
                f"  [WARN] Columns not in ids_data_array: {e}. Skipping IDs.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            return pd.DataFrame()

    elif isinstance(ids_data_array, np.ndarray):
        ids_np_array = ids_data_array
        if ids_cols_indices:
            ids_np_array = ids_np_array[:, ids_cols_indices]

        if (
            ids_cols
            and len(ids_cols) == ids_np_array.shape[1]
        ):
            ids_cols_to_use = list(ids_cols)
        else:
            ids_cols_to_use = [
                f"id_{k}"
                for k in range(ids_np_array.shape[1])
            ]

    elif isinstance(ids_data_array, Tensor):  # Is a Tensor
        ids_np_array = ids_data_array.numpy()
        if ids_cols_indices:
            ids_np_array = ids_np_array[:, ids_cols_indices]
        if (
            ids_cols
            and len(ids_cols) == ids_np_array.shape[1]
        ):
            ids_cols_to_use = list(ids_cols)
        else:
            ids_cols_to_use = [
                f"id_{k}"
                for k in range(ids_np_array.shape[1])
            ]
    else:
        vlog(
            f"  [WARN] Unsupported type for `ids_data_array`: "
            f"{type(ids_data_array)}. Skipping IDs.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    # Now, repeat the data
    if (
        ids_np_array is not None
        and ids_np_array.shape[0] == num_samples
    ):
        # Use np.repeat to tile each sample H_inferred times
        expanded_ids_data = np.repeat(
            ids_np_array, H_inferred, axis=0
        )
        vlog(
            f"    - Added static/ID columns: {ids_cols_to_use}",
            level=5,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame(
            expanded_ids_data, columns=ids_cols_to_use
        )
    else:
        vlog(
            f"  [WARN] `ids_data_array` sample size"
            f" ({ids_np_array.shape[0] if ids_np_array is not None else 'None'}) "
            f"does not match predictions ({num_samples}). Skipping IDs.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()
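

# Illustrative sketch (hypothetical helper): aligning one row of static
# IDs per sample with the long format. np.repeat along axis 0 copies each
# sample's row H times in succession, which matches sample_idx. Tiling
# the whole block instead (np.tile(ids, (H, 1))) would cycle through the
# samples and misalign them with the prediction rows.
def _sketch_expand_ids(ids, horizon):
    """Repeat each row of an (N, K) ID array H times -> (N * H, K)."""
    import numpy as np

    return np.repeat(np.asarray(ids), horizon, axis=0)
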


def _evaluate_coverage(
    df: pd.DataFrame,
    target_mapping: dict[str, str],
    q_indices: tuple[int, int],
    savefile: str | None = None,
    name: str | None = None,
    verbose: int = 0,
    _logger: Any = None,
):
    """Calculates and logs quantile coverage from an INVERTED DataFrame."""
    vlog(
        "  Evaluating forecast coverage...",
        level=4,
        verbose=verbose,
        logger=_logger,
    )

    for base_name in target_mapping.values():
        actual_col = f"{base_name}_actual"

        # Find quantile columns
        q_cols = sorted(
            [
                c
                for c in df.columns
                if c.startswith(base_name) and "_q" in c
            ]
        )
        if not q_cols or actual_col not in df.columns:
            vlog(
                f"  Skipping coverage for '{base_name}': "
                "Missing actuals or quantile columns.",
                level=3,
                verbose=verbose,
                logger=_logger,
            )
            continue

        try:
            q_low_col = q_cols[q_indices[0]]
            q_high_col = q_cols[q_indices[1]]
        except IndexError:
            vlog(
                f"  [WARN] Coverage q_indices {q_indices} are out of bounds "
                f"for detected quantile columns: {q_cols}. Skipping coverage.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
            continue

        q_low = df[q_low_col]
        q_high = df[q_high_col]
        actual = df[actual_col]

        is_covered = (actual >= q_low) & (actual <= q_high)
        coverage_percent = is_covered.mean() * 100

        vlog(
            f"  Forecast Coverage ({base_name}, {q_low_col} to {q_high_col}): "
            f"{coverage_percent:.2f}%",
            level=1,
            verbose=verbose,
            logger=_logger,
        )

        # Store in df attributes
        df.attrs[f"{base_name}_coverage"] = coverage_percent

        # Use compute_quantile_diagnostics
        try:
            cov_savefile = None
            if savefile:
                # If savefile has an extension, treat it as a file path.
                # Example: ".../subs_eval.csv"
                #   -> ".../subs_eval_diagnostics_results.json"
                root, ext = os.path.splitext(savefile)
                if ext:
                    cov_savefile = (
                        f"{root}_diagnostics_results.json"
                    )
                else:
                    # No extension: treat `savefile` as a directory-like path
                    # Example: "results/nansha_eval"
                    #   -> "results/nansha_eval/diagnostics_results.json"
                    cov_savefile = os.path.join(
                        savefile, "diagnostics_results.json"
                    )

            compute_quantile_diagnostics(
                df,
                target_name=base_name,
                quantiles=[
                    float(c.split("_q")[-1]) / 100.0
                    for c in q_cols
                ],
                coverage_quantile_indices=q_indices,
                savefile=cov_savefile,
                name=name,
                verbose=verbose,
                logger=_logger,
            )

        except Exception as e:
            vlog(
                f"  [WARN] Failed to compute quantile diagnostics: {e}",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
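The coverage loop above picks a lower and an upper quantile column for each target and reports the share of actuals that fall inside that band. The sketch below shows the same two steps with the standard library only. It parses `<base>_q<NN>` column names into quantile levels and computes empirical interval coverage. The helper names `parse_quantile_columns` and `interval_coverage` are hypothetical and are not part of GeoPrior.

```python
def parse_quantile_columns(columns, base_name):
    """Return (column, level) pairs for `<base>_q<NN>` columns, sorted by level."""
    pairs = []
    for c in columns:
        if c.startswith(base_name) and "_q" in c:
            # "subsidence_q10" -> "10" -> 0.10
            pairs.append((c, float(c.split("_q")[-1]) / 100.0))
    # Sort by the parsed level so that "q5" comes before "q10".
    return sorted(pairs, key=lambda p: p[1])


def interval_coverage(actual, q_low, q_high):
    """Percentage of observations inside the closed band [q_low, q_high]."""
    if not actual:
        raise ValueError("empty input")
    inside = sum(lo <= a <= hi for a, lo, hi in zip(actual, q_low, q_high))
    return 100.0 * inside / len(actual)
```

Sorting by the parsed level, rather than by the column string, is a deliberate choice here. A plain string sort places `subsidence_q10` ahead of `subsidence_q5`, which would shift what `q_indices=(0, -1)` selects when single-digit levels are present.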


def _get_model_predictions(
    pihalnet_outputs: dict[str, Tensor] | None,
    model: Model | None,
    model_inputs: dict[str, Tensor] | None,
    verbose: int,
    stop_check: Callable[[], bool],
    _logger: Any,
) -> dict[str, Tensor]:
    """
    Return the model outputs dict, calling ``model.predict`` only
    when precomputed ``pihalnet_outputs`` are not supplied.
    """
    if pihalnet_outputs is not None:
        return pihalnet_outputs
    if model is None or model_inputs is None:
        raise ValueError(
            "If 'pihalnet_outputs' is None, both 'model' and "
            "'model_inputs' must be provided."
        )
    vlog(
        "  Predictions not provided, generating from model...",
        level=4,
        verbose=verbose,
        logger=_logger,
    )
    try:
        outputs = model.predict(model_inputs, verbose=0)
        if not isinstance(outputs, dict):
            raise ValueError(
                "Model output is not a dict as expected from PINN models"
            )
    except Exception as e:
        raise RuntimeError(
            f"Failed to generate predictions: {e}"
        ) from e
    vlog(
        "  Model predictions generated.",
        level=5,
        verbose=verbose,
        logger=_logger,
    )

    if stop_check and stop_check():
        raise InterruptedError("Model prediction aborted.")
    return outputs
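`_get_model_predictions` prefers precomputed outputs and only falls back to `model.predict`. It also honours a caller-supplied `stop_check` hook for cooperative cancellation. The sketch below reproduces that control flow against a hypothetical `StubModel`, which stands in for a Keras-style model whose `predict` returns a dict. It is an illustration, not the library's code.

```python
from typing import Callable, Optional


class StubModel:
    """Hypothetical stand-in for a model whose predict() returns a dict."""

    def predict(self, inputs, verbose=0):
        return {"subs_pred": [x * 2 for x in inputs["coords"]]}


def get_outputs(
    outputs: Optional[dict],
    model=None,
    inputs: Optional[dict] = None,
    stop_check: Optional[Callable[[], bool]] = None,
) -> dict:
    # Precomputed outputs win; the model is only a fallback.
    if outputs is not None:
        return outputs
    if model is None or inputs is None:
        raise ValueError("Need `outputs` or both `model` and `inputs`.")
    result = model.predict(inputs, verbose=0)
    if not isinstance(result, dict):
        raise ValueError("Expected a dict of named outputs.")
    # Cooperative cancellation: the caller decides when to abort.
    if stop_check is not None and stop_check():
        raise InterruptedError("Model prediction aborted.")
    return result
```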


def _convert_outputs_to_numpy(
    pihalnet_outputs: dict[str, Any],
    verbose: int,
    _logger,
) -> dict[str, np.ndarray]:
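    """
    Convert the 'subs_pred' and 'gwl_pred' outputs to NumPy arrays.

    Other keys (for example 'pde_residual') are not carried into
    the tabular output.
    """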
    proc: dict[str, np.ndarray] = {}
    for key, val in pihalnet_outputs.items():
        if key in ("subs_pred", "gwl_pred"):
            if hasattr(val, "numpy"):
                proc[key] = val.numpy()
            elif isinstance(val, np.ndarray):
                proc[key] = val
            else:
                try:
                    proc[key] = np.array(val)
                except Exception as e:
                    raise TypeError(
                        f"Could not convert output '{key}' to NumPy. "
                        f"Type: {type(val)}. Error: {e}"
                    ) from e

        elif key == "pde_residual":
            # PDE residual can be handled differently if needed
            pass  # Not typically added to this output DataFrame

    if not proc:
        vlog(
            "  No 'subs_pred' or 'gwl_pred'"
            " found in outputs. Returning empty dict.",
            level=1,
            verbose=verbose,
            logger=_logger,
        )
        return {}

    return proc
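The conversion step relies on duck typing. Any value exposing `.numpy()` (as eager TensorFlow tensors do) is materialized, and only the two prediction heads are kept. The sketch below shows that filtering with a hypothetical `FakeTensor` class and plain lists in place of NumPy arrays, so it runs with the standard library alone.

```python
class FakeTensor:
    """Hypothetical tensor exposing a `.numpy()` method."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return list(self._data)


def extract_prediction_arrays(outputs, keys=("subs_pred", "gwl_pred")):
    proc = {}
    for key, val in outputs.items():
        # Auxiliary outputs such as "pde_residual" are skipped.
        if key not in keys:
            continue
        # Materialize anything that knows how to convert itself.
        proc[key] = val.numpy() if hasattr(val, "numpy") else val
    return proc
```

An empty result signals that no prediction head was found, which callers can check with a simple truthiness test.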


def _define_targets(
    processed: dict[str, np.ndarray],
    include_gwl: bool,
    target_mapping: dict[str, str] | None = None,
    stop_check: Callable[[], bool] = None,
) -> dict[str, str]:
    """
    Map processed outputs to base target names.
    """
    if target_mapping is None:
        target_mapping = {
            "subs_pred": "subsidence",
            "gwl_pred": "gwl",
        }
    targets: dict[str, str] = {}
    if "subs_pred" in processed:
        targets["subs_pred"] = target_mapping["subs_pred"]
    if include_gwl and "gwl_pred" in processed:
        targets["gwl_pred"] = target_mapping["gwl_pred"]

    if stop_check and stop_check():
        raise InterruptedError(
            "Target configuration aborted."
        )

    return targets
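`_define_targets` maps output keys to the base names used as column prefixes. By default, `subs_pred` maps to `subsidence` and `gwl_pred` maps to `gwl`. Groundwater is included only when `include_gwl` is true. A minimal restatement of that rule follows. Note that a custom mapping must provide every key that is actually present, because lookups are direct.

```python
DEFAULT_MAPPING = {"subs_pred": "subsidence", "gwl_pred": "gwl"}


def define_targets(processed, include_gwl=True, mapping=None):
    mapping = mapping or DEFAULT_MAPPING
    targets = {}
    if "subs_pred" in processed:
        targets["subs_pred"] = mapping["subs_pred"]
    # Groundwater is optional and can be switched off by the caller.
    if include_gwl and "gwl_pred" in processed:
        targets["gwl_pred"] = mapping["gwl_pred"]
    return targets
```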


def _infer_dimensions(
    first_pred: np.ndarray,
    forecast_horizon: int | None,
    verbose: int = 0,
    _logger: Any = print,
) -> tuple[int, int]:
    """
    Infer num_samples and horizon (H) from the array.
    """
    N = first_pred.shape[0]
    H = forecast_horizon or first_pred.shape[1]
    if N == 0 or H == 0:
        raise ValueError(
            "No samples or zero forecast horizon."
        )

    vlog(
        f"  Formatting for {N} samples, Horizon={H}.",
        level=4,
        verbose=verbose,
        logger=_logger,
    )

    return N, H
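`_infer_dimensions` reads `N` from the first axis. It takes `H` from `forecast_horizon` when that value is truthy, and otherwise from the second axis. The sketch below mirrors this logic on a plain shape tuple. Because it uses `or`, a horizon of `0` also falls back to the array's own length. As in the function above, an explicit horizon that disagrees with `shape[1]` is not checked here.

```python
def infer_dimensions(shape, forecast_horizon=None):
    """Return (N, H) from a prediction shape (N, H, ...)."""
    n = shape[0]
    # None (or 0) falls back to the array's own horizon.
    h = forecast_horizon or shape[1]
    if n == 0 or h == 0:
        raise ValueError("No samples or zero forecast horizon.")
    return n, h
```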


def _build_sample_df(
    num_samples: int, H: int
) -> pd.DataFrame:
    """
    Build the index frame: one row per (sample_idx, forecast_step),
    in sample-major order, with steps numbered from 1.
    """
    idx = np.repeat(np.arange(num_samples), H)
    steps = np.tile(np.arange(1, H + 1), num_samples)
    return pd.DataFrame(
        {"sample_idx": idx, "forecast_step": steps}
    )
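The `np.repeat` / `np.tile` pair produces the long-format index. Each sample index is repeated `H` times, and the step sequence `1..H` is tiled `N` times. The list-comprehension equivalent below makes the ordering explicit. It is the same sample-major order that `reshape(N * H, ...)` gives for an `(N, H, ...)` array, which is why the per-column helpers can be concatenated side by side.

```python
def build_sample_index(num_samples, horizon):
    """Sample-major rows: all steps of sample 0, then sample 1, ..."""
    sample_idx = [i for i in range(num_samples) for _ in range(horizon)]
    forecast_step = [s for _ in range(num_samples) for s in range(1, horizon + 1)]
    return sample_idx, forecast_step
```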


def _add_coordinate_columns(
    base_df: pd.DataFrame,
    model_inputs: dict[str, Any] | None,
    coord_scaler: Any,
    include_coords: bool,
    stop_check: Callable[[], bool],
    _logger: Any,
    verbose: int,
) -> pd.DataFrame:
    """
    Extract and inverse-transform coords if requested.
    """
    if not include_coords:
        return pd.DataFrame()
    coords_arr = None
    if model_inputs and "coords" in model_inputs:
        coords_arr = model_inputs["coords"]
        if hasattr(coords_arr, "numpy"):
            coords_arr = coords_arr.numpy()
    if coords_arr is None:
        return pd.DataFrame()
    N = base_df["sample_idx"].nunique()
    H = base_df["forecast_step"].max()

    if coords_arr.shape != (N, H, 3):
        vlog(
            "  'coords' shape mismatch or not found."
            " Skipping coordinate columns.",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    resh = coords_arr.reshape(N * H, 3)
    if coord_scaler:
        vlog(
            "  Applying inverse transform to t,x,y coordinates...",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
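        try:
            resh = coord_scaler.inverse_transform(resh)
        except Exception as e:
            # Keep the scaled coordinates rather than failing the export.
            vlog(
                f"  [WARN] Coordinate inverse transform failed: {e}."
                " Keeping scaled coordinates.",
                level=2,
                verbose=verbose,
                logger=_logger,
            )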

    if stop_check and stop_check():
        raise InterruptedError(
            "Coordinates transformation aborted."
        )

    return pd.DataFrame(
        resh, columns=["coord_t", "coord_x", "coord_y"]
    )
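The coordinate step flattens `(N, H, 3)` to `(N * H, 3)` and then calls `coord_scaler.inverse_transform`. The sketch below assumes the scaler is a min-max scaler with the default `(0, 1)` feature range and columns ordered `(t, x, y)`. Under those assumptions, the inverse reduces to `x = x_scaled * (max - min) + min` per column. The helper names are hypothetical.

```python
def flatten_coords(coords):
    """(N, H, 3) nested lists -> (N*H, 3) rows, sample-major like reshape."""
    return [row for sample in coords for row in sample]


def inverse_minmax(rows, data_min, data_max):
    """Undo x_scaled = (x - min) / (max - min) column by column.

    Assumes a (0, 1) feature range; other ranges need an extra shift.
    """
    return [
        [v * (hi - lo) + lo for v, lo, hi in zip(row, data_min, data_max)]
        for row in rows
    ]
```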


def _add_id_columns(
    base_df: pd.DataFrame,
    ids_data_array: Any,
    ids_cols: list[str] | None,
    ids_cols_indices: list[int] | None,
    stop_check: Callable[[], bool],
    _logger: Any,
    verbose: int,
) -> pd.DataFrame:
    """
    Process and repeat static ID columns per forecast step.

    Converts DataFrame, Series, NumPy array, or Tensor to
    NumPy, selects requested columns or indices, then uses
    `sample_idx` from base_df to expand each ID row
    across all forecast steps.
    """
    # Skip if no IDs provided
    if ids_data_array is None:
        vlog(
            "  No `ids_data_array` provided, skipping ID columns.",
            level=4,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    vlog(
        "  Processing additional static/ID columns...",
        level=4,
        verbose=verbose,
        logger=_logger,
    )

    # Convert to NumPy array
    try:
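        if isinstance(ids_data_array, pd.DataFrame):
            # Restrict to the requested ID columns, if any, so that
            # `col_names` always matches the array width.
            sel = (
                ids_data_array[ids_cols]
                if ids_cols
                else ids_data_array
            )
            arr = sel.values
            col_names = list(sel.columns)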
        elif isinstance(ids_data_array, pd.Series):
            arr = ids_data_array.to_frame().values
            col_names = ids_cols or [
                ids_data_array.name or "id_0"
            ]
            if (
                arr.ndim == 2
                and len(col_names) != arr.shape[1]
            ):
                vlog(
                    "    Series IDs length mismatch, using first col.",
                    level=5,
                    verbose=verbose,
                    logger=_logger,
                )
                col_names = col_names[:1]
        elif hasattr(ids_data_array, "numpy"):
            arr = ids_data_array.numpy()
            col_names = ids_cols or [
                f"id_{i}" for i in range(arr.shape[1])
            ]
        elif isinstance(ids_data_array, np.ndarray):
            arr = ids_data_array
            if ids_cols_indices is not None:
                arr = arr[:, ids_cols_indices]
            col_names = ids_cols or [
                f"id_{i}" for i in range(arr.shape[1])
            ]
        else:
            _logger.warning(
                f"Unsupported IDs type {type(ids_data_array)}."
                " Skipping ID columns."
            )
            return pd.DataFrame()
    except Exception as e:
        _logger.warning(
            f"Error converting IDs to array: {e}."
            " Skipping ID columns."
        )
        return pd.DataFrame()

    vlog(
        f"    Converted IDs to NumPy with shape {arr.shape}",
        level=5,
        verbose=verbose,
        logger=_logger,
    )
    vlog(
        f"    Selected ID column names: {col_names}",
        level=5,
        verbose=verbose,
        logger=_logger,
    )

    # Expand rows according to sample_idx
    sample_indices = base_df["sample_idx"].values
    if arr.shape[0] != sample_indices.max() + 1:
        _logger.warning(
            f"IDs rows ({arr.shape[0]}) != samples "
            f"({sample_indices.max() + 1}). Skipping."
        )
        return pd.DataFrame()

    expanded = arr[sample_indices]
    id_df = pd.DataFrame(expanded, columns=col_names)
    vlog(
        f"    Expanded ID DataFrame shape: {id_df.shape}",
        level=5,
        verbose=verbose,
        logger=_logger,
    )

    # User cancellation check
    if stop_check and stop_check():
        raise InterruptedError(
            "ID columns processing aborted."
        )

    return id_df
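Expanding static IDs is a gather by `sample_idx`. Because every forecast step of sample `i` carries index `i`, indexing the ID rows with that column repeats each row exactly `H` times in the right order. In NumPy this is `arr[sample_indices]`. The list version below also mirrors the row-count guard used above. The helper name is hypothetical.

```python
def expand_ids(id_rows, sample_idx):
    """Repeat one static ID row per sample across its forecast steps."""
    n_samples = max(sample_idx) + 1
    if len(id_rows) != n_samples:
        # A row-count mismatch means the IDs do not line up with the
        # sequences, so no ID columns are produced.
        return []
    return [id_rows[i] for i in sample_idx]
```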



def _process_target_variable(
    preds: np.ndarray,
    pred_key: str,
    base_name: str,
    num_samples: int,
    H: int,
    quantiles: list[float] | None,
    y_true_dict: dict[str, Any] | None,
    scaler_info: dict[str, Any] | None,
    verbose: int,
    stop_check: Callable[[], bool],
    _logger: Any,
) -> pd.DataFrame:
    """
    Format predictions, attach actuals, and prepare scaling.

    - Calls _format_target_predictions for quantiles
      or point forecasts.
    - Reshapes and adds true values if provided.
    - Records inverse-scaling actions to apply later.
    """
    # 1. Format predictions
    cols_pred, df_pred = _format_target_predictions(
        preds,
        num_samples,
        H,
        O=(preds.shape[-1] if preds.ndim == 3 else 1),
        base_target_name=base_name,
        quantiles=quantiles,
        verbose=verbose,
    )
    dfs: list[pd.DataFrame] = [df_pred]

    # 2. Add actual values (left empty when no ground truth is given,
    #    so step 3 can always concatenate the two column lists)
    cols_act: list[str] = []
    if y_true_dict:
        y_true = y_true_dict.get(pred_key)
        if y_true is None:
            y_true = y_true_dict.get(base_name)
        if hasattr(y_true, "numpy"):
            y_true = y_true.numpy()
        if y_true is not None:
            arr = y_true.reshape(num_samples * H, -1)
            for i in range(arr.shape[1]):
                col = base_name
                if arr.shape[1] > 1:
                    col += f"_{i}"
                cols_act.append(f"{col}_actual")
            dfs.append(pd.DataFrame(arr, columns=cols_act))

    # 3. Prepare inverse-scaling info
    if scaler_info and base_name in scaler_info:
        info = scaler_info[base_name]
        to_transform = cols_pred + cols_act
        info["_cols_to_inv"] = to_transform

    if stop_check and stop_check():
        raise InterruptedError(
            "Target variable processing aborted."
        )

    return pd.concat(dfs, axis=1)
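Actual columns follow a simple naming rule. A single-output target gets `<base>_actual`, while a multi-output target gets `<base>_<i>_actual`. The sketch below restates the rule. Only the single-output form matches the `f"{base_name}_actual"` lookup used by the coverage loop earlier in this module, so multi-output targets are skipped there.

```python
def actual_column_names(base_name, n_outputs):
    """`<base>_actual` for one output, `<base>_<i>_actual` for several."""
    names = []
    for i in range(n_outputs):
        col = base_name if n_outputs == 1 else f"{base_name}_{i}"
        names.append(f"{col}_actual")
    return names
```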


def _apply_masking(
    df: pd.DataFrame,
    targets: dict[str, str],
    apply_mask: bool,
    mask_values: Any,
    mask_fill_value: float,
    quantiles: list[float] | None,
) -> pd.DataFrame:
    """
    Mask forecast columns wherever the first target's
    ``<base>_actual`` column matches one of ``mask_values``.
    """
    if not apply_mask:
        return df
    first_base = list(targets.values())[0]
    ref = f"{first_base}_actual"
    if ref not in df:
        raise KeyError(f"Reference col '{ref}' missing")
    cols = []
    if quantiles:
        for base in targets.values():
            for q in quantiles:
                cols.append(f"{base}_q{int(q * 100)}")
    else:
        cols = [c for c in df if c.endswith("_pred")]
    return mask_by_reference(
        data=df,
        ref_col=ref,
        values=mask_values,
        find_closest=False,
        fill_value=mask_fill_value,
        mask_columns=cols,
        error="ignore",
        inplace=False,
    )
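The masking step delegates to `mask_by_reference` with `find_closest=False`. The sketch below assumes exact-match semantics for that call: rows whose reference value is in `mask_values` get the listed forecast columns replaced by a fill value, and the input is left untouched (as with `inplace=False`). This is an illustrative re-implementation on lists of dicts, not the library function. Its quantile-column helper uses `round`, because `int(q * 100)` truncates for some levels (for example, `0.29 * 100` evaluates to `28.999999999999996`). Whichever rule is used must match the names produced by `_format_target_predictions`.

```python
def quantile_column(base, q):
    # round() avoids int(0.29 * 100) == 28 style truncation.
    return f"{base}_q{int(round(q * 100))}"


def mask_rows(rows, ref_col, mask_values, columns, fill_value=None):
    """Set `columns` to `fill_value` where row[ref_col] is in mask_values."""
    masked = []
    for row in rows:
        new = dict(row)  # copy: the input rows are not modified
        if row[ref_col] in mask_values:
            for c in columns:
                if c in new:
                    new[c] = fill_value
        masked.append(new)
    return masked
```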


def _compute_coverage(
    df: pd.DataFrame,
    targets: dict[str, str],
    quantiles: list[float],
    y_true_dict: dict[str, Any],
    coverage_quantile_indices: tuple[int, int],
    savefile: str | None,
    name: str | None,
    verbose: int,
    _logger: Any,
) -> None:
    """
    Compute quantile diagnostics per target.
    """
    for base in targets.values():
        try:
            compute_quantile_diagnostics(
                df,
                target_name=base,
                quantiles=quantiles,
                coverage_quantile_indices=coverage_quantile_indices,
                savefile=(
                    os.path.join(
                        os.path.dirname(savefile),
                        "diagnostics_results.json",
                    )
                    if savefile
                    else None
                ),
                name=name,
                verbose=verbose,
                logger=_logger,
            )
        except Exception as e:
            vlog(
                f"  [WARN] Skipping coverage for '{base}': {e}",
                level=2,
                verbose=verbose,
                logger=_logger,
            )
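This module derives the diagnostics JSON path in two different ways. The per-target coverage loop earlier checks for an extension with `splitext`: it appends a suffix for a file path and joins a fixed name for a directory-like path. `_compute_coverage` instead always writes next to `savefile` via `dirname`. The sketch below restates both rules with `posixpath`, so the results do not depend on the host OS. With the `dirname` rule, every target in the loop receives the same path. Whether later targets overwrite earlier ones therefore depends on how `compute_quantile_diagnostics` writes the file, which is not shown here.

```python
import posixpath


def diagnostics_path_by_ext(savefile):
    """Suffix rule: file path -> sibling JSON; directory-like -> inner JSON."""
    root, ext = posixpath.splitext(savefile)
    if ext:
        return f"{root}_diagnostics_results.json"
    return posixpath.join(savefile, "diagnostics_results.json")


def diagnostics_path_by_dir(savefile):
    """Directory rule: fixed file name next to `savefile`."""
    return posixpath.join(posixpath.dirname(savefile), "diagnostics_results.json")
```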


# XXX : TODO
# revise format outputs and apply
#  mask by reference
def format_preds(
    pihalnet_outputs: dict[str, Tensor] | None = None,
    model: Model | None = None,
    model_inputs: dict[str, Tensor] | None = None,
    y_true_dict: dict[str, Any] | None = None,
    target_mapping: dict[str, str] | None = None,
    include_gwl: bool = True,
    include_coords: bool = True,
    quantiles: list[float] | None = None,
    forecast_horizon: int | None = None,
    output_dims: dict[str, int] | None = None,
    ids_data_array: Any | None = None,
    ids_cols: list[str] | None = None,
    ids_cols_indices: list[int] | None = None,
    scaler_info: dict[str, Any] | None = None,
    coord_scaler: Any | None = None,
    evaluate_coverage: bool = False,
    coverage_quantile_indices: tuple[int, int] = (0, -1),
    savefile: str | None = None,
    name: str | None = None,
    apply_mask: bool = False,
    mask_values: Any | None = None,
    mask_fill_value: float | None = None,
    verbose: int = 0,
    _logger: Any = None,
    stop_check: Callable[[], bool] = None,
    **kwargs,
) -> pd.DataFrame:
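    """
    Format PINN predictions into a long-format DataFrame.

    Runs the helper steps in order: obtain (or generate) model
    outputs, convert them to NumPy, select targets, and infer the
    sample count and horizon. It then assembles one row per
    ``(sample_idx, forecast_step)`` with optional coordinate, ID,
    prediction, and actual columns. Masking and quantile coverage
    diagnostics are applied last, when requested.
    """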
    vlog(
        "Starting formatting",
        level=3,
        verbose=verbose,
        logger=_logger,
    )
    # Step 1: obtain raw outputs
    raw_out = _get_model_predictions(
        pihalnet_outputs,
        model,
        model_inputs,
        verbose,
        stop_check,
        _logger,
    )
    # Step 2: convert to NumPy
    proc = _convert_outputs_to_numpy(
        raw_out, verbose, _logger
    )

    # Step 3: define targets
    targets = _define_targets(
        proc,
        include_gwl,
        target_mapping,
        stop_check,
    )
    if not targets:
        vlog(
            "  No valid targets to process after"
            " filtering. Returning empty DF.",
            level=1,
            verbose=verbose,
            logger=_logger,
        )
        return pd.DataFrame()

    # Step 4: infer dims
    num_samp, H = _infer_dimensions(
        proc[next(iter(targets))],
        forecast_horizon,
        verbose,
        _logger,
    )
    # Step 5: build base df
    base_df = _build_sample_df(num_samp, H)
    # Step 6: add coords
    coords_df = _add_coordinate_columns(
        base_df,
        model_inputs,
        coord_scaler,
        include_coords,
        stop_check,
        _logger,
        verbose,
    )
    # Step 7: add IDs
    ids_df = _add_id_columns(
        base_df,
        ids_data_array,
        ids_cols,
        ids_cols_indices,
        stop_check,
        _logger,
        verbose,
    )
    # Step 8: process targets
    dfs = []
    for key, base in targets.items():
        dfs.append(
            _process_target_variable(
                proc[key],
                key,
                base,
                num_samp,
                H,
                quantiles,
                y_true_dict,
                scaler_info,
                verbose,
                stop_check,
                _logger,
            )
        )
    # Step 9: concat all
    final_df = pd.concat(
        [base_df, coords_df, ids_df] + dfs, axis=1
    )
    # Step 10: masking
    final_df = _apply_masking(
        final_df,
        targets,
        apply_mask,
        mask_values,
        mask_fill_value,
        quantiles,
    )
    # Step 11: coverage
    if evaluate_coverage and quantiles:
        _compute_coverage(
            final_df,
            targets,
            quantiles,
            y_true_dict,
            coverage_quantile_indices,
            savefile,
            name,
            verbose,
            _logger,
        )
    vlog(
        "Formatting complete",
        level=3,
        verbose=verbose,
        logger=_logger,
    )
    return final_df
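The frame returned by `format_preds` has one row per `(sample_idx, forecast_step)`. The sketch below builds that layout from nested `(N, H, Q)` predictions for a single target. It assumes quantile columns are named `<base>_q<NN>`, which is what the masking and coverage helpers above look for. The function name `long_format` is hypothetical, and the sketch omits coordinates, IDs, and actuals.

```python
def long_format(preds, base_name, quantiles):
    """Flatten (N, H, Q) predictions into one dict per (sample, step)."""
    cols = [f"{base_name}_q{int(round(q * 100))}" for q in quantiles]
    rows = []
    for i, sample in enumerate(preds):
        for step, values in enumerate(sample, start=1):
            row = {"sample_idx": i, "forecast_step": step}
            row.update(zip(cols, values))  # one column per quantile
            rows.append(row)
    return rows
```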

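`prepare_pinn_data_sequences` below relies on two small calculations. First, a non-numeric time column is converted to a decimal year, `year + (day_of_year - 1) / (365 + is_leap)`. Second, each group of length `n` yields `n - (time_steps + forecast_horizon) + 1` rolling windows, and groups shorter than `time_steps + forecast_horizon` are skipped. The stdlib sketch below reproduces both formulas. The helper names are hypothetical, and `max(..., 0)` stands in for the skip.

```python
import datetime


def decimal_year(d):
    """year + (day_of_year - 1) / days_in_year."""
    start = datetime.date(d.year, 1, 1)
    days_in_year = (datetime.date(d.year + 1, 1, 1) - start).days  # 365 or 366
    return d.year + (d.timetuple().tm_yday - 1) / days_in_year


def count_windows(group_len, time_steps, forecast_horizon):
    """Number of rolling (look-back + horizon) windows in one group."""
    need = time_steps + forecast_horizon
    return max(group_len - need + 1, 0)
```

With the defaults `time_steps=12` and `forecast_horizon=3`, a group of 20 rows yields 6 windows.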

@isdf
def prepare_pinn_data_sequences(
    df: pd.DataFrame,
    time_col: str,
    subsidence_col: str,
    gwl_col: str,
    dynamic_cols: list[str],
    static_cols: list[str] | None = None,
    future_cols: list[str] | None = None,
    spatial_cols: tuple[str, str] | None = None,
    h_field_col: str | None = None,
    lon_col: str | None = None,
    lat_col: str | None = None,
    group_id_cols: list[str] | None = None,
    time_steps: int = 12,
    forecast_horizon: int = 3,
    output_subsidence_dim: int = 1,
    output_gwl_dim: int = 1,
    datetime_format: str | None = None,
    normalize_coords: bool = True,
    cols_to_scale: list[str] | str | None = None,
    lock_physics_cols: bool = True,
    protect_si_suffix: str = "__si",
    return_coord_scaler: bool = False,
    coord_scaler: MinMaxScaler | None = None,
    fit_coord_scaler: bool = True,
    mode: str | None = None,
    model: str | None = None,
    savefile: str | None = None,
    progress_hook: Callable[[float], None] | None = None,
    stop_check: Callable[[], bool] = None,
    verbose: int = 0,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    **kws,
) -> (
    tuple[dict[str, np.ndarray], dict[str, np.ndarray]]
    | tuple[
        dict[str, np.ndarray],
        dict[str, np.ndarray],
        MinMaxScaler | None,
    ]
):
    # -------------------------------------------------------------------------
    _to_range = lambda f, lo, hi: lo + (hi - lo) * f  # noqa: E731

    vlog(
        "Starting PINN data sequence preparation...",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    df_proc = df.copy()

    # --- 1. Validate Input Parameters and Columns ---
    if verbose >= 2:
        logger.debug(
            "Validating input parameters and columns."
        )

    vlog(
        "Validating essential columns...",
        verbose=verbose,
        level=2,
        logger=_logger,
    )

    lon_col, lat_col = resolve_spatial_columns(
        df_proc,
        spatial_cols=spatial_cols,
        lon_col=lon_col,
        lat_col=lat_col,
    )
    essential_cols = [
        time_col,
        lon_col,
        lat_col,
        subsidence_col,
        gwl_col,
    ]

    # --- MODIFICATION 1: Conditional Validation ---
    is_geoprior = str(model).lower().strip() in (
        "geoprior",
        "geopriorsubsnet",
    )

    if is_geoprior:
        if h_field_col is None:
            # Check for default names
            if "H_field" in df_proc.columns:
                h_field_col = "H_field"
            elif "soil_thickness" in df_proc.columns:
                h_field_col = "soil_thickness"
            else:
                raise ValueError(
                    "`model` is 'geoprior' but `h_field_col` was not "
                    "provided and default names 'H_field' or 'soil_thickness' "
                    "were not found in the DataFrame."
                )

        vlog(
            f"GeoPrior model detected. Using '{h_field_col}' as H_field.",
            verbose=verbose,
            level=3,
            logger=_logger,
        )

        essential_cols.append(h_field_col)
        vlog(
            f"Validated '{h_field_col}' for GeoPriorSubsNet.",
            verbose=verbose,
            level=3,
            logger=_logger,
        )

    # --- End Modification 1 ---

    exist_features(
        df_proc,
        features=essential_cols,
        message="Essential column(s) missing.",
    )

    vlog(
        "Validating time-series dataset...",
        verbose=verbose,
        level=2,
        logger=_logger,
    )

    check_datetime(
        df_proc,
        dt_cols=time_col,
        ops="check_only",
        consider_dt_as="numeric",
        accept_dt=True,
        allow_int=True,
    )

    vlog(
        "Managing feature column lists...",
        verbose=verbose,
        level=3,
        logger=_logger,
    )

    dynamic_cols = columns_manager(
        dynamic_cols, empty_as_none=False
    )
    exist_features(
        df_proc,
        features=dynamic_cols,
        name="Dynamic feature column(s)",
    )

    static_cols = columns_manager(static_cols)
    if static_cols:
        exist_features(
            df_proc,
            features=static_cols,
            name="Static feature column(s)",
        )

    future_cols = columns_manager(future_cols)
    if future_cols:
        exist_features(
            df_proc,
            features=future_cols,
            name="Future feature column(s)",
        )

    group_id_cols = columns_manager(group_id_cols)
    if group_id_cols:
        exist_features(
            df_proc,
            features=group_id_cols,
            name="Group ID column(s)",
        )

    vlog(
        "Validating time_steps and forecast_horizon...",
        verbose=verbose,
        level=3,
        logger=_logger,
    )

    time_steps = validate_positive_integer(
        time_steps, "time_steps"
    )
    forecast_horizon = validate_positive_integer(
        forecast_horizon,
        "forecast_horizon",
    )

    vlog(
        "Converting time column to numeric values...",
        verbose=verbose,
        level=4,
        logger=_logger,
    )
    if pd.api.types.is_numeric_dtype(df_proc[time_col]):
        numerical_time_col = time_col
        if verbose >= 2:
            logger.debug(
                f"Time column '{time_col}' is already numeric. Using it directly."
            )
    else:
        try:
            df_proc[time_col] = pd.to_datetime(
                df_proc[time_col], format=datetime_format
            )
            df_proc[f"{time_col}_numeric"] = df_proc[
                time_col
            ].dt.year + (
                df_proc[time_col].dt.dayofyear - 1
            ) / (
                365
                + df_proc[time_col].dt.is_leap_year.astype(
                    int
                )
            )
            numerical_time_col = f"{time_col}_numeric"
            if verbose >= 2:
                logger.debug(
                    f"Converted datetime column '{time_col}'"
                    f" to numerical '{numerical_time_col}'."
                )
            vlog(
                f"Time column converted to '{numerical_time_col}'",
                verbose=verbose,
                level=5,
                logger=_logger,
            )
        except Exception as e:
            raise ValueError(
                f"Failed to convert or process time column '{time_col}'. "
                "Ensure it is datetime-like or specify `datetime_format`."
                f" Error: {e}"
            ) from e
    vlog(
        "Pre-flight: assessing sliding-window feasibility ...",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    ok, _ = check_sequence_feasibility(
        df_proc.copy(),
        time_col=time_col,
        group_id_cols=group_id_cols,
        time_steps=time_steps,
        forecast_horizon=forecast_horizon,
        verbose=verbose,
        error="raise",  # fail fast if no window can be built
        logger=_logger,
    )

    vlog(
        "Feasibility check passed — generating sequences...",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    vlog(
        "Starting PINN data sequence generation...",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    mode = select_mode(mode, default="pihal_like")
    vlog(
        f"Operating in '{mode}' data preparation mode.",
        level=1,
        verbose=verbose,
        logger=_logger,
    )

    exclude_cols = []
    if lock_physics_cols:
        exclude_cols += [subsidence_col, gwl_col]
        if h_field_col is not None:
            exclude_cols.append(h_field_col)

        # Also exclude any "__si" columns automatically (safe for Stage-1 outputs)
        if protect_si_suffix:
            exclude_cols += [
                c
                for c in df_proc.columns
                if str(c).endswith(protect_si_suffix)
            ]

    # IMPORTANT: do NOT shift time in normalization for PINN
    df_proc, coord_scaler, cols_scaler = normalize_for_pinn(
        df=df_proc,
        time_col=numerical_time_col,
        coord_x=lon_col,
        coord_y=lat_col,
        scale_coords=normalize_coords,
        cols_to_scale=cols_to_scale,
        forecast_horizon=forecast_horizon,
        shift_time_by_horizon=False,
        exclude_cols=exclude_cols,
        protect_si_suffix=protect_si_suffix,
        verbose=verbose,
        _logger=_logger,
        coord_scaler=coord_scaler,
        fit_coord_scaler=fit_coord_scaler,
    )

    # --- 2. Group and Sort Data ---
    vlog(
        "Grouping and sorting data...",
        verbose=verbose,
        level=2,
        logger=_logger,
    )

    if verbose >= 2:
        logger.debug(
            f"Grouping and sorting data. Group IDs:"
            f" {group_id_cols or 'None (single group)'}."
        )

    sort_by_cols = [numerical_time_col]
    if group_id_cols:
        sort_by_cols = group_id_cols + sort_by_cols
    df_proc = df_proc.sort_values(
        by=sort_by_cols
    ).reset_index(drop=True)

    if group_id_cols:
        grouped_data = df_proc.groupby(group_id_cols)
        group_keys = list(grouped_data.groups.keys())
        if verbose >= 1:
            logger.info(
                f"Data grouped by {group_id_cols} into {len(group_keys)} groups."
            )
    else:
        grouped_data = [(None, df_proc)]
        group_keys = [None]
        if verbose >= 1:
            logger.info(
                "Processing entire DataFrame as a single group."
            )

    # --- 3. First Pass: Calculate Total Number of Sequences ---
    total_sequences = 0
    min_len_per_group = time_steps + forecast_horizon
    valid_group_dfs = []

    vlog(
        "Counting valid sequences in groups...",
        verbose=verbose,
        level=4,
        logger=_logger,
    )

    if verbose >= 2:
        logger.debug(
            "First pass: Calculating total sequences. Min"
            f" length per group: {min_len_per_group}."
        )

    n_groups = len(group_keys) or 1
    for g_idx, group_key in enumerate(group_keys):
        group_df = (
            grouped_data.get_group(group_key)
            if group_id_cols
            else df_proc
        )

        if stop_check and stop_check():
            raise InterruptedError(
                "Sequence generation aborted."
            )

        key_str = (
            group_key
            if group_key is not None
            else "<Full Dataset>"
        )
        if progress_hook is not None:
            progress_hook(
                _to_range((g_idx + 1) / n_groups, 0.0, 0.50)
            )

        if len(group_df) < min_len_per_group:
            if verbose >= 5:
                logger.info(
                    f"Group '{key_str}' has {len(group_df)} points, less than "
                    f"min required ({min_len_per_group}). Skipping."
                )
            vlog(
                f"Group {key_str} too small:"
                f" {len(group_df)} < {min_len_per_group}",
                verbose=verbose,
                level=6,
                logger=_logger,
            )

            continue

        num_seq_in_group = (
            len(group_df) - min_len_per_group + 1
        )
        total_sequences += num_seq_in_group
        valid_group_dfs.append(group_df)
        if verbose >= 5:
            logger.debug(
                f"Group '{key_str}' "
                f"will yield {num_seq_in_group} sequences."
            )
        vlog(
            f"Group {key_str} yields {num_seq_in_group} seqs"
            f" (running total: {total_sequences}).",
            verbose=verbose,
            level=6,
            logger=_logger,
        )

    if total_sequences == 0:
        raise ValueError(
            "No group has enough data points to create sequences with "
            f"time_steps={time_steps} and forecast_horizon={forecast_horizon}."
        )
    if verbose >= 1:
        logger.info(
            "Total valid sequences to be"
            f" generated: {total_sequences}."
        )

    vlog(
        f"Total sequences: {total_sequences}",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    # --- 4. Pre-allocate NumPy Arrays ---
    vlog(
        "Pre-allocating arrays...",
        verbose=verbose,
        level=2,
        logger=_logger,
    )

    num_dynamic_feats = len(dynamic_cols)
    num_static_feats = len(static_cols) if static_cols else 0
    num_future_feats = len(future_cols) if future_cols else 0

    coords_horizon_arr = np.zeros(
        (total_sequences, forecast_horizon, 3),
        dtype=np.float32,
    )
    static_features_arr = np.zeros(
        (total_sequences, num_static_feats), dtype=np.float32
    )
    dynamic_features_arr = np.zeros(
        (total_sequences, time_steps, num_dynamic_feats),
        dtype=np.float32,
    )

    if mode == "tft_like":
        future_window_len = time_steps + forecast_horizon
        vlog(
            f"Allocating future features array for 'tft_like' mode "
            f"with time dimension: {future_window_len}",
            level=2,
            verbose=verbose,
            logger=_logger,
        )
    else:
        future_window_len = forecast_horizon

    future_features_arr = np.zeros(
        (
            total_sequences,
            future_window_len,
            num_future_feats,
        ),
        dtype=np.float32,
    )
    target_subsidence_arr = np.zeros(
        (
            total_sequences,
            forecast_horizon,
            output_subsidence_dim,
        ),
        dtype=np.float32,
    )
    target_gwl_arr = np.zeros(
        (total_sequences, forecast_horizon, output_gwl_dim),
        dtype=np.float32,
    )
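
    # Shape summary (illustrative values): with N = total_sequences,
    # time_steps = 4, forecast_horizon = 2, 3 dynamic and 2 future
    # features, the buffers above have shapes
    #   coords_horizon_arr    (N, 2, 3)  -> (t, x, y) per horizon step
    #   dynamic_features_arr  (N, 4, 3)
    #   future_features_arr   (N, 6, 2) in 'tft_like' mode (T + H rows)
    #                         (N, 2, 2) otherwise (H rows only)
    #   target_*_arr          (N, 2, output_*_dim)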

    # GeoPrior only: allocate a per-sequence soil-thickness (H_field) array.
    H_field_arr = None  # Initialize as None
    if is_geoprior:
        H_field_arr = np.zeros(
            (total_sequences, forecast_horizon, 1),
            dtype=np.float32,
        )
        vlog(
            f"Allocated 'H_field' array with shape {H_field_arr.shape}",
            verbose=verbose,
            level=4,
            logger=_logger,
        )

    vlog(
        "Arrays shapes set.",
        verbose=verbose,
        level=5,
        logger=_logger,
    )

    if verbose >= 2:
        logger.debug(
            "Pre-allocated NumPy arrays for sequence data:"
        )
        logger.debug(
            f"  Coords Horizon: {coords_horizon_arr.shape}"
        )
        logger.debug(
            f"  Static Features: {static_features_arr.shape}"
        )
        logger.debug(
            f"  Dynamic Features: {dynamic_features_arr.shape}"
        )
        logger.debug(
            f"  Future Features: {future_features_arr.shape}"
        )
        logger.debug(
            f"  Target Subsidence: {target_subsidence_arr.shape}"
        )
        logger.debug(f"  Target GWL: {target_gwl_arr.shape}")
        # Log the H_field shape only when the array was allocated.
        if H_field_arr is not None:
            logger.debug(
                f"  H_field ({h_field_col}): {H_field_arr.shape}"
            )

    # --- 5. Second Pass: Populate Arrays with Rolling Windows ---
    current_seq_idx = 0
    if verbose >= 2:
        logger.debug(
            "Second pass: Populating sequence arrays..."
        )

    vlog(
        "Populating arrays with data...",
        verbose=verbose,
        level=2,
        logger=_logger,
    )

    total_seq = (
        sum(
            len(gdf) - min_len_per_group + 1
            for gdf in valid_group_dfs
        )
        or 1
    )
    done_seq = 0

    for group_df in valid_group_dfs:
        group_t_coords = group_df[numerical_time_col].values
        group_x_coords = group_df[lon_col].values
        group_y_coords = group_df[lat_col].values

        t_min_group, t_max_group = (
            group_t_coords.min(),
            group_t_coords.max(),
        )

        if verbose >= 3:
            print()
            time_scale_info = (
                f"{t_min_group:.4f}-{t_max_group:.4f}"
            )
            if normalize_coords and coord_scaler:
                time_scale_info += " (normalized)"
            else:
                time_scale_info += " (original scale)"

            print_box(
                f"Group window t: {time_scale_info}",
                width=_TW,
                align="center",
                border_char="+",
                horizontal_char="-",
                vertical_char="|",
                padding=1,
            )

        group_static_vals = None
        if static_cols and num_static_feats > 0:
            group_static_vals = group_df.iloc[0][
                static_cols
            ].values.astype(np.float32)

        # GeoPrior only: read the group's static soil thickness once.
        group_H_val = None
        if is_geoprior:
            # Get the single static soil thickness value for this group
            group_H_val = group_df.iloc[0][h_field_col]

        num_seq_in_this_group = (
            len(group_df) - min_len_per_group + 1
        )
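        # Worked example (illustrative values, assuming
        # min_len_per_group == time_steps + forecast_horizon): a group
        # of 10 rows with time_steps=4 and forecast_horizon=2 yields
        # 10 - 6 + 1 = 5 windows (i = 0..4); the last window covers
        # rows 4..9, i.e. it ends on the group's final row.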
        for i in range(num_seq_in_this_group):
            done_seq += 1
            if progress_hook is not None:
                progress_hook(
                    _to_range(done_seq / total_seq, 0.5, 1.00)
                )

            if stop_check and stop_check():
                raise InterruptedError(
                    "Sequence generation aborted by user"
                )

            if static_cols and num_static_feats > 0:
                static_features_arr[current_seq_idx] = (
                    group_static_vals
                )

            dynamic_start_idx = i
            dynamic_end_idx = i + time_steps
            dynamic_features_arr[current_seq_idx] = (
                group_df.iloc[
                    dynamic_start_idx:dynamic_end_idx
                ][dynamic_cols].values.astype(np.float32)
            )

            horizon_start_idx = i + time_steps
            horizon_end_idx = (
                i + time_steps + forecast_horizon
            )
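            # Index layout for one window (illustrative values,
            # time_steps=4, forecast_horizon=2, i=0), rows listed
            # inclusively: dynamic rows 0..3, horizon rows 4..5, and
            # future rows 0..5 in 'tft_like' mode or 4..5 otherwise.
            # The slices below use exclusive end indices.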

            if future_cols and num_future_feats > 0:
                if mode == "tft_like":
                    future_start_idx = i
                    future_end_idx = (
                        i + time_steps + forecast_horizon
                    )
                else:  # 'pihal_like' mode
                    future_start_idx = horizon_start_idx
                    future_end_idx = horizon_end_idx

                future_features_arr[current_seq_idx] = (
                    group_df.iloc[
                        future_start_idx:future_end_idx
                    ][future_cols].values.astype(np.float32)
                )

            t_horizon = group_t_coords[
                horizon_start_idx:horizon_end_idx
            ]
            x_horizon = group_x_coords[
                horizon_start_idx:horizon_end_idx
            ]
            y_horizon = group_y_coords[
                horizon_start_idx:horizon_end_idx
            ]

            coords_horizon_arr[current_seq_idx, :, 0] = (
                t_horizon
            )
            coords_horizon_arr[current_seq_idx, :, 1] = (
                x_horizon
            )
            coords_horizon_arr[current_seq_idx, :, 2] = (
                y_horizon
            )

            target_subsidence_arr[current_seq_idx] = (
                group_df.iloc[
                    horizon_start_idx:horizon_end_idx
                ][subsidence_col]
                .values.reshape(
                    forecast_horizon, output_subsidence_dim
                )
                .astype(np.float32)
            )

            target_gwl_arr[current_seq_idx] = (
                group_df.iloc[
                    horizon_start_idx:horizon_end_idx
                ][gwl_col]
                .values.reshape(
                    forecast_horizon, output_gwl_dim
                )
                .astype(np.float32)
            )

            if H_field_arr is not None:
                # Tile the static H_val across the forecast horizon
                H_field_arr[current_seq_idx, :, 0] = (
                    group_H_val
                )

            if verbose >= 7:
                logger.debug(f"  Sequence {current_seq_idx}:")
                logger.debug(
                    "    Dynamic window:"
                    f" {dynamic_start_idx}-{dynamic_end_idx - 1}"
                )
                logger.debug(
                    "    Horizon window:"
                    f" {horizon_start_idx}-{horizon_end_idx - 1}"
                )
                logger.debug(
                    f"    Coords (first step):"
                    f" {coords_horizon_arr[current_seq_idx, 0, :]}"
                )

            vlog(
                f"Seq {current_seq_idx}:"
                f" dyn {dynamic_start_idx}-{dynamic_end_idx - 1},"
                f" hzn {horizon_start_idx}-{horizon_end_idx - 1}",
                verbose=verbose,
                level=7,
                logger=_logger,
            )

            current_seq_idx += 1

    if verbose >= 1:
        logger.info("Successfully populated sequence arrays.")

    vlog(
        "Data population complete.",
        verbose=verbose,
        level=1,
        logger=_logger,
    )

    inputs_dict = {
        "coords": coords_horizon_arr,
        "static_features": static_features_arr
        if num_static_feats > 0
        else None,
        "dynamic_features": dynamic_features_arr,
        "future_features": future_features_arr
        if num_future_feats > 0
        else None,
    }

    # GeoPrior only: expose H_field as an extra model input.
    if H_field_arr is not None:
        inputs_dict["H_field"] = H_field_arr

    inputs_dict = {
        k: v for k, v in inputs_dict.items() if v is not None
    }

    targets_dict = {
        "subsidence": target_subsidence_arr,
        "gwl": target_gwl_arr,
    }

    if savefile:
        vlog(
            f"\nPreparing to save sequence data to '{savefile}'...",
            verbose=verbose,
            level=3,
            logger=_logger,
        )
        # --- v3.2: allow split between GWL dynamic driver and GWL prediction target ---
        gwl_dyn_col = kws.get("gwl_dyn_col", None) or gwl_col

        if gwl_dyn_col not in dynamic_cols:
            raise ValueError(
                f"gwl_dyn_col={gwl_dyn_col!r} must be present in dynamic_cols.\n"
                f"Got gwl_col(target)={gwl_col!r} and dynamic_cols={dynamic_cols}.\n"
                "GeoPrior v3.2 expects: gwl_dyn_col=<depth__si>, gwl_col=<head__si>."
            )

        gwl_dyn_index = int(dynamic_cols.index(gwl_dyn_col))
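        # Illustration (hypothetical column names): with
        # dynamic_cols = ["rainfall_mm", "GWL_depth__si"] and
        # gwl_dyn_col = "GWL_depth__si", gwl_dyn_index == 1, so the
        # GWL driver can be read from dynamic_data[..., 1].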

        job_dict = {
            "static_data": static_features_arr,
            "dynamic_data": dynamic_features_arr,
            "future_data": future_features_arr,
            "subsidence": target_subsidence_arr,
            "gwl": target_gwl_arr,
            "static_features": static_cols,
            "dynamic_features": dynamic_cols,
            "future_features": future_cols,
            "inputs_dict": inputs_dict,
            "targets_dict": targets_dict,
            "subsidence_col": subsidence_col,
            "spatial_features": spatial_cols,
            "lon_col": lon_col,
            "lat_col": lat_col,
            "time_col": time_col,
            "time_steps": time_steps,
            "forecast_horizon": forecast_horizon,
            "cols_scaler": cols_scaler,
            "coord_scaler": coord_scaler,
            "normalize_coords_flag": normalize_coords,
            "saved_coord_scaler_flag": coord_scaler
            is not None,
            # Model type and GeoPrior-specific metadata
            "model_type": model,
            "h_field_col": h_field_col,
            "H_field": H_field_arr
            if H_field_arr is not None
            else None,
            "gwl_col": gwl_col,  # target column (head)
            "gwl_dyn_col": gwl_dyn_col,  # dynamic driver column (depth)
            "gwl_dyn_index": gwl_dyn_index,  # index of the driver in dynamic_cols
        }
        try:
            job_dict.update(get_versions())
        except NameError:
            vlog(
                "\n  `get_versions` not found, version info not saved.",
                verbose=verbose,
                level=1,
                logger=_logger,
            )

        try:
            save_job(
                job_dict, savefile, append_versions=False
            )

            if verbose >= 1:
                vlog(
                    f"Sequence data dictionary successfully "
                    f"saved to '{savefile}'.",
                    verbose=verbose,
                    level=1,
                    logger=_logger,
                )
        except Exception as e:
            vlog(
                f"Failed to save job dictionary to "
                f"'{savefile}': {e}",
                verbose=verbose,
                level=1,
                logger=_logger,
            )

    if verbose >= 1:
        logger.info(
            "PINN data sequence preparation completed."
        )
        if verbose >= 3:
            for key, arr in inputs_dict.items():
                logger.debug(
                    f"  Final input '{key}' shape:"
                    f" {arr.shape if arr is not None else 'None'}"
                )
            for key, arr in targets_dict.items():
                logger.debug(
                    f"  Final target '{key}' shape: {arr.shape}"
                )

    vlog(
        "PINN data sequence preparation successfully completed.",
        verbose=verbose,
        level=3,
        logger=_logger,
    )

    if return_coord_scaler:
        return inputs_dict, targets_dict, coord_scaler

    return inputs_dict, targets_dict


def check_and_rename_keys(
    inputs, y
):  # rename to check_input_keys
    r"""
    Helper function to check and rename keys in the inputs
    and target dictionaries.

    This function ensures that the necessary keys are present in both the
    `inputs` and `y` dictionaries. If the keys for 'subsidence' or 'gwl'
    are not found, it attempts to rename them from possible alternatives
    like 'subs_pred' or 'gwl_pred'.

    Parameters
    ----------
    inputs : dict
        A dictionary containing the input data. The keys 'coords' and
        'dynamic_features' are expected.

    y : dict
        A dictionary containing the target values. The keys 'subsidence'
        and 'gwl' are expected, but they could also appear as 'subs_pred'
        or 'gwl_pred'.

    Raises
    ------
    ValueError
        If required keys are missing in `inputs` or `y`, or if renaming
        does not result in valid keys for 'subsidence' and 'gwl'.
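
    Examples
    --------
    A minimal sketch with dummy NumPy arrays (shapes are
    illustrative); a target dict without any subsidence key
    raises:

    >>> import numpy as np
    >>> inputs = {"coords": np.zeros((2, 3, 3)),
    ...           "dynamic_features": np.zeros((2, 4, 5))}
    >>> y = {"gwl": np.zeros((2, 3, 1))}
    >>> check_and_rename_keys(inputs, y)  # doctest: +ELLIPSIS
    Traceback (most recent call last):
        ...
    ValueError: Target 'subsidence' is missing...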
    """

    # Check if 'coords' and 'dynamic_features' are in inputs
    if "coords" not in inputs or inputs["coords"] is None:
        raise ValueError("Input 'coords' is missing or None.")
    if (
        "dynamic_features" not in inputs
        or inputs["dynamic_features"] is None
    ):
        raise ValueError(
            "Input 'dynamic_features' is missing or None."
        )

    # Ensure 'subsidence' is in y, renaming from 'subs_pred' if needed.
    if "subsidence" not in y and "subs_pred" in y:
        y["subsidence"] = y.pop("subs_pred")
    if "subsidence" not in y:
        raise ValueError(
            "Target 'subsidence' is missing. Provide 'subsidence'"
            " or 'subs_pred'."
        )

    # Ensure 'gwl' is in y, renaming from 'gwl_pred' if needed.
    if "gwl" not in y and "gwl_pred" in y:
        y["gwl"] = y.pop("gwl_pred")
    if "gwl" not in y:
        raise ValueError(
            "Target 'gwl' is missing. Provide 'gwl' or 'gwl_pred'."
        )

    return inputs, y


def _check_required_input_keys(
    inputs,
    y=None,
    message=None,
):
    r"""
    Helper function to check that required keys are present in
    the inputs and target dictionaries.

    This function ensures that `inputs` contains 'coords' and
    'dynamic_features' and that `y` contains a subsidence target
    ('subsidence' or 'subs_pred') and a GWL target ('gwl' or
    'gwl_pred'). Keys are only checked, not renamed; both
    dictionaries are returned unchanged.

    Parameters
    ----------
    inputs : dict
        A dictionary containing the input data. The keys 'coords' and
        'dynamic_features' are expected.

    y : dict, optional
        A dictionary containing the target values. The keys 'subsidence'
        and 'gwl' are expected, but they could also appear as 'subs_pred'
        or 'gwl_pred'.

    message : str, optional
        Custom error message used when `inputs` or `y` is not a dict.

    Raises
    ------
    TypeError
        If `inputs` or `y` is provided but is not a dict.
    ValueError
        If required keys are missing in `inputs` or `y`.
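
    Examples
    --------
    A minimal sketch with dummy NumPy arrays (shapes are
    illustrative). Valid dictionaries are returned unchanged:

    >>> import numpy as np
    >>> inputs = {"coords": np.zeros((2, 3, 3)),
    ...           "dynamic_features": np.zeros((2, 4, 5))}
    >>> y = {"subs_pred": np.zeros((2, 3, 1)),
    ...      "gwl": np.zeros((2, 3, 1))}
    >>> out_inputs, out_y = _check_required_input_keys(inputs, y)
    >>> out_inputs is inputs and out_y is y
    True

    A missing GWL target raises:

    >>> _check_required_input_keys(inputs, {"subsidence": 0.0})  # doctest: +ELLIPSIS
    Traceback (most recent call last):
        ...
    ValueError: Target 'gwl' is missing or None...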
    """
    if inputs is not None:
        if not isinstance(inputs, dict):
            message = message or (
                "Inputs must be a dictionary containing"
                " 'coords' and 'dynamic_features'."
                f" Got {type(inputs).__name__!r}"
            )
            raise TypeError(message)

        # Check if 'coords' and 'dynamic_features' are in inputs
        if "coords" not in inputs or inputs["coords"] is None:
            raise ValueError(
                "Input 'coords' is missing or None."
            )
        if (
            "dynamic_features" not in inputs
            or inputs["dynamic_features"] is None
        ):
            raise ValueError(
                "Input 'dynamic_features' is missing or None."
            )

    if y is not None:
        if not isinstance(y, dict):
            message = message or (
                "Target `y` must be a dictionary containing"
                " 'subs_pred/subsidence' and 'gwl/gwl_pred'."
                f" Got {type(y).__name__!r}"
            )
            raise TypeError(message)

        # Require a subsidence target ('subsidence' or 'subs_pred')
        if "subsidence" not in y and "subs_pred" not in y:
            raise ValueError(
                "Target 'subsidence' is missing or None."
                " Please provide 'subsidence' or 'subs_pred'."
            )

        # Require a GWL target ('gwl' or 'gwl_pred')
        if "gwl" not in y and "gwl_pred" not in y:
            raise ValueError(
                "Target 'gwl' is missing or None."
                " Please provide 'gwl' or 'gwl_pred'."
            )

    return inputs, y


def check_required_input_keys(
    inputs: dict[str, Any] | None,
    y: dict[str, Any] | None = None,
    message: str | None = None,
    model_name: str | None = None,
    do_rename: bool = True,
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
    """
    Validate presence of required keys in `inputs` and `y`.
    Optionally canonicalize keys via reverse alias mapping.

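    This function ensures that `inputs` contains 'coords' and
    'dynamic_features' and that `y` contains a subsidence target
    and a GWL target. When `do_rename` is True, target keys are
    canonicalized to 'subs_pred' and 'gwl_pred' (the legacy names
    'subsidence' and 'gwl' are accepted), and, for GeoPrior models,
    soil-thickness aliases in `inputs` are mapped to 'H_field'.

    Parameters
    ----------
    inputs : dict or None
        A dictionary containing the input data. The keys 'coords' and
        'dynamic_features' are expected. GeoPrior models also require
        'H_field' (aliases: 'soil_thickness', 'soil thickness',
        'h_field').

    y : dict or None, optional
        A dictionary containing the target values. The keys
        'subs_pred' and 'gwl_pred' are expected; the legacy names
        'subsidence' and 'gwl' are also accepted.

    message : str, optional
        Custom error message used when `inputs` or `y` is not a dict.

    model_name : str, optional
        Model identifier. If it matches a GeoPrior name
        ('geoprior' or 'geopriorsubsnet', case-insensitive), the
        'H_field' requirement is enforced.

    do_rename : bool, default True
        If True, canonicalize keys via reverse alias mapping before
        checking them. If False, keys are only checked.

    Returns
    -------
    inputs, y : tuple
        The validated (and possibly renamed) dictionaries.

    Raises
    ------
    TypeError
        If `inputs` or `y` is provided but is not a dict.
    ValueError
        If required keys are missing in `inputs` or `y`.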

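    Examples
    --------
    A minimal sketch with dummy NumPy arrays (shapes are
    illustrative). With ``do_rename=False`` keys are only checked,
    so legacy target names are kept as given:

    >>> import numpy as np
    >>> inputs = {"coords": np.zeros((2, 3, 3)),
    ...           "dynamic_features": np.zeros((2, 4, 5))}
    >>> y = {"subsidence": np.zeros((2, 3, 1)),
    ...      "gwl": np.zeros((2, 3, 1))}
    >>> _, y_out = check_required_input_keys(inputs, y, do_rename=False)
    >>> sorted(y_out)
    ['gwl', 'subsidence']

    GeoPrior models also require an ``H_field`` input:

    >>> check_required_input_keys(
    ...     inputs, y, model_name="geoprior", do_rename=False
    ... )  # doctest: +ELLIPSIS
    Traceback (most recent call last):
        ...
    ValueError: GeoPrior requires 'H_field' in inputs...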
    """

    # ---- inputs checks -------------------------------------------------
    if inputs is not None:
        if not isinstance(inputs, dict):
            msg = message or (
                "Inputs must be a dict with 'coords' and "
                "'dynamic_features'. Got "
                f"{type(inputs).__name__!r}."
            )
            raise TypeError(msg)

        # Canonicalize GeoPrior-specific inputs (H field aliases)
        if do_rename and model_name:
            name = str(model_name).lower()
            if name in ("geoprior", "geopriorsubsnet"):
                inputs = rename_dict_keys(
                    inputs,
                    {
                        "H_field": (
                            "H_field",
                            "soil_thickness",
                            "soil thickness",
                            "h_field",
                        )
                    },
                    order="reverse",
                )

        # Mandatory base inputs
        if (
            "coords" not in inputs
            or inputs.get("coords") is None
        ):
            raise ValueError(
                "Input 'coords' is missing or None."
            )
        if (
            "dynamic_features" not in inputs
            or inputs.get("dynamic_features") is None
        ):
            raise ValueError(
                "Input 'dynamic_features' is missing or None."
            )

        # GeoPrior requires H_field (after alias unification)
        if model_name:
            name = str(model_name).lower()
            if name in ("geoprior", "geopriorsubsnet"):
                if (
                    "H_field" not in inputs
                    or inputs.get("H_field") is None
                ):
                    raise ValueError(
                        "GeoPrior requires 'H_field' in inputs "
                        "(aliases accepted: 'soil_thickness', "
                        "'soil thickness', 'h_field')."
                    )

    # ---- target checks -------------------------------------------------
    if y is not None:
        if not isinstance(y, dict):
            msg = message or (
                "Target `y` must be a dict containing "
                "'subs_pred/subsidence' and 'gwl_pred/gwl'. Got "
                f"{type(y).__name__!r}."
            )
            raise TypeError(msg)

        # Canonicalize targets to subs_pred / gwl_pred
        if do_rename:
            y = rename_dict_keys(
                y,
                {
                    "subs_pred": ("subs_pred", "subsidence"),
                    "gwl_pred": ("gwl_pred", "gwl"),
                },
                order="reverse",
            )

        # Accept either canonical or legacy if not renaming
        has_subs = ("subs_pred" in y) or ("subsidence" in y)
        has_gwl = ("gwl_pred" in y) or ("gwl" in y)

        if not has_subs:
            raise ValueError(
                "Target missing subsidence. Provide 'subs_pred' or "
                "'subsidence'."
            )
        if not has_gwl:
            raise ValueError(
                "Target missing gwl. Provide 'gwl_pred' or 'gwl'."
            )

    return inputs, y


def _extract_txy_in(
    inputs: Tensor
    | np.ndarray
    | dict[str, Tensor | np.ndarray],
    coord_slice_map: dict[str, int] | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""
    Extracts t, x, y tensors from `inputs`, which may be:
      - A single 3D tensor of shape (batch, time_steps, 3)
      - A dict containing a key 'coords' with such a tensor
      - A dict containing separate keys 't', 'x', and 'y'

    Parameters
    ----------
    inputs : tf.Tensor or np.ndarray or dict
        If tensor/array: expected shape is (batch, time_steps, 3).
        If dict:
          - If 'coords' in dict: dict['coords'] must be (batch, time_steps, 3).
          - Otherwise, dict must have keys 't', 'x', 'y' each of shape
            (batch, time_steps, 1) or (batch, time_steps).
    coord_slice_map : dict, optional
        Mapping from 't', 'x', 'y' to their index in the last dimension of
        the coords tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.

    Returns
    -------
    t : tf.Tensor
        Tensor of shape (batch, time_steps, 1) corresponding to time coordinate.
    x : tf.Tensor
        Tensor of shape (batch, time_steps, 1) corresponding to x coordinate.
    y : tf.Tensor
        Tensor of shape (batch, time_steps, 1) corresponding to y coordinate.

    Raises
    ------
    ValueError
        If `inputs` is not in one of the supported formats, or dimensions
        are inconsistent.
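
    Examples
    --------
    A minimal sketch, assuming TensorFlow is installed (shapes are
    illustrative):

    >>> import numpy as np
    >>> coords = np.zeros((4, 5, 3), dtype="float32")
    >>> t, x, y = _extract_txy_in({"coords": coords})
    >>> tuple(t.shape)
    (4, 5, 1)
    >>> t, x, y = _extract_txy_in(
    ...     {"t": np.zeros((4, 5)), "x": np.zeros((4, 5)),
    ...      "y": np.zeros((4, 5))}
    ... )
    >>> tuple(y.shape)
    (4, 5, 1)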
    """

    # Default slice map
    if coord_slice_map is None:
        coord_slice_map = {"t": 0, "x": 1, "y": 2}

    # Helper to ensure output is a tf.Tensor with a final singleton dim
    def _ensure_tensor_with_last_dim(
        inp: Tensor | np.ndarray,
    ) -> Tensor:
        if isinstance(inp, np.ndarray):
            inp = tf_convert_to_tensor(inp)
        if not isinstance(inp, Tensor):
            raise ValueError(
                f"Expected tf.Tensor or np.ndarray, got {type(inp)}"
            )
        # If shape is (batch, time_steps), add last dimension
        if inp.ndim == 2:
            inp = tf_expand_dims(inp, axis=-1)
        # Now expect (batch, time_steps, 1)
        if inp.ndim != 3 or inp.shape[-1] != 1:
            raise ValueError(
                "Coordinate array must have shape "
                f"(batch, time_steps, 1); got {inp.shape}"
            )
        return inp

    # Case 1: inputs is a dict
    if isinstance(inputs, dict):
        # If 'coords' key is present
        if "coords" in inputs:
            coords_tensor = inputs["coords"]
            if isinstance(coords_tensor, np.ndarray):
                coords_tensor = tf_convert_to_tensor(
                    coords_tensor
                )
            if not isinstance(coords_tensor, Tensor):
                raise ValueError(
                    f"Expected tensor/array for 'coords';"
                    f" got {type(coords_tensor)}"
                )
            # Expect shape (batch, time_steps, 3)
            if (
                coords_tensor.ndim != 3
                or coords_tensor.shape[-1] < 3
            ):
                raise ValueError(
                    f"'coords' must have shape (batch, time_steps, ≥3);"
                    f" got {coords_tensor.shape}"
                )
            # Slice out t, x, y
            t = coords_tensor[
                ...,
                coord_slice_map["t"] : coord_slice_map["t"]
                + 1,
            ]
            x = coords_tensor[
                ...,
                coord_slice_map["x"] : coord_slice_map["x"]
                + 1,
            ]
            y = coords_tensor[
                ...,
                coord_slice_map["y"] : coord_slice_map["y"]
                + 1,
            ]
            return (
                tf_cast(t, tf_float32),
                tf_cast(x, tf_float32),
                tf_cast(y, tf_float32),
            )

        # If keys 't','x','y' exist separately
        if all(k in inputs for k in ("t", "x", "y")):
            t = _ensure_tensor_with_last_dim(inputs["t"])
            x = _ensure_tensor_with_last_dim(inputs["x"])
            y = _ensure_tensor_with_last_dim(inputs["y"])
            return (
                tf_cast(t, tf_float32),
                tf_cast(x, tf_float32),
                tf_cast(y, tf_float32),
            )

        raise ValueError(
            "Dict `inputs` must contain either key 'coords' or keys 't', 'x', 'y'."
        )

    # Case 2: inputs is a single tensor/array
    if isinstance(inputs, Tensor | np.ndarray):
        coords_tensor = inputs
        if isinstance(coords_tensor, np.ndarray):
            coords_tensor = tf_convert_to_tensor(
                coords_tensor
            )
        # Expect shape (batch, time_steps, 3)
        if (
            coords_tensor.ndim != 3
            or coords_tensor.shape[-1] < 3
        ):
            raise ValueError(
                f"Tensor `inputs` must have shape (batch, time_steps, ≥3);"
                f" got {coords_tensor.shape}"
            )
        t = coords_tensor[
            ...,
            coord_slice_map["t"] : coord_slice_map["t"] + 1,
        ]
        x = coords_tensor[
            ...,
            coord_slice_map["x"] : coord_slice_map["x"] + 1,
        ]
        y = coords_tensor[
            ...,
            coord_slice_map["y"] : coord_slice_map["y"] + 1,
        ]
        return (
            tf_cast(t, tf_float32),
            tf_cast(x, tf_float32),
            tf_cast(y, tf_float32),
        )

    raise ValueError(
        f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
    )


def _get_coords(
    inputs: dict[str, Any] | Sequence[Any] | Tensor,
    *,
    check_shape: bool = False,
    is_time_dependent: bool = True,
) -> Tensor:
    r"""
    Extract the **coords** tensor from any input layout.

    Parameters
    ----------
    inputs : dict, sequence or tensor
        * **Dict** – must contain a ``"coords"`` key.
        * **Sequence** – coords are expected at index 0.
        * **Tensor** – interpreted as coords directly.
    check_shape : bool, default ``False``
        If ``True``, validate that the last dimension equals 3 and that
        the rank matches *is_time_dependent*.
    is_time_dependent : bool, default ``True``
        * ``True``  → expect a 3-D shape ``(B, T, 3)``
        * ``False`` → allow a 2-D shape ``(B, 3)``; a 3-D shape is also
          accepted (useful when the model ignores time).

    Returns
    -------
    tf.Tensor
        The extracted coords tensor.

    Raises
    ------
    KeyError
        If a dict is missing the ``"coords"`` key.
    TypeError
        If *inputs* is not dict, sequence or tensor.
    ValueError
        If ``check_shape`` is ``True`` and the tensor shape is invalid.

    Notes
    -----
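    The function is lightweight. When the static rank is known it
    validates shapes with plain Python checks; otherwise it falls
    back to ``tf.debugging`` assertions, so it can be used inside a
    ``tf.function`` such as the model's ``train_step`` or in
    data-pipeline helpers.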
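
    Examples
    --------
    A minimal sketch with a NumPy placeholder (no shape check is
    run, so the object is returned as-is):

    >>> import numpy as np
    >>> coords = np.zeros((2, 4, 3))
    >>> _get_coords({"coords": coords}) is coords
    True
    >>> _get_coords([coords, None]) is coords
    True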
    """
    # ── 1. locate the tensor
    if isinstance(inputs, Tensor):
        coords = inputs
    elif isinstance(inputs, tuple | list):
        coords = inputs[0]
    elif isinstance(inputs, dict):
        try:
            coords = inputs["coords"]
        except KeyError as err:
            raise KeyError(
                "Input dict lacks a 'coords' entry."
            ) from err
    else:
        raise TypeError(
            "inputs must be a dict, sequence, or Tensor; "
            f"got {type(inputs).__name__}"
        )

    # ── 2. validate shape if requested
    if check_shape:
        shape = coords.shape  # static if known, else dynamic
        # last dim must be 3
        if shape.rank is not None:
            if shape[-1] != 3:
                raise ValueError(
                    "coords[..., -1] must equal 3 (t, x, y); "
                    f"found {shape[-1]}"
                )
            # rank check when static
            if is_time_dependent and shape.rank != 3:
                raise ValueError(
                    "Expected coords rank 3 (B, T, 3) for time-dependent "
                    "data; received rank "
                    f"{shape.rank}"
                )
            if not is_time_dependent and shape.rank not in (
                2,
                3,
            ):
                raise ValueError(
                    "Expected coords rank 2 or 3 for static data; "
                    f"received rank {shape.rank}"
                )
        else:  # dynamic shapes (inside tf.function)
            tf_debugging.assert_equal(
                tf_shape(coords)[-1],
                3,
                message="coords[..., -1] must equal 3 (t, x, y)",
            )
            if is_time_dependent:
                tf_debugging.assert_rank(coords, 3)
            else:
                tf_debugging.assert_rank_in(coords, (2, 3))

    return coords


@ensure_pkg(
    KERAS_BACKEND or "tensorflow",
    extra="TensorFlow is required for this function.",
)
def extract_txy_in(
    inputs: Tensor
    | np.ndarray
    | dict[str, Tensor | np.ndarray],
    coord_slice_map: dict[str, int] | None = None,
    expect_dim: str | None = None,
    verbose: int = 0,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    **kws,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""
    Extracts t, x, y tensors from various input formats.

    This utility standardizes coordinate inputs, accepting a single
    tensor or a dictionary, and handling both 2D (spatial/static)
    and 3D (spatio-temporal) data. It ensures a consistent 3D
    output format for robust downstream processing.

    Parameters
    ----------
    inputs : tf.Tensor, np.ndarray, or dict
        The input data containing coordinates. A single tensor or array
        may be 2D with shape ``(batch, 3)`` or 3D with shape
        ``(batch, time_steps, 3)``. A dictionary may contain a
        ``'coords'`` key with the coordinate tensor, or separate
        ``'t'``, ``'x'``, and ``'y'`` keys.

    coord_slice_map : dict, optional
        Mapping for 't', 'x', 'y' to their index in the last
        dimension of a coordinate tensor.
        Defaults to `{'t': 0, 'x': 1, 'y': 2}`.

    expect_dim : {'2d', '3d'}, optional
        If provided, enforces the rank of the input before it is
        normalized. ``'2d'`` requires input shaped like ``(batch, 3)``
        or a dictionary of ``(batch, 1)`` tensors. ``'3d'`` requires
        input shaped like ``(batch, time, 3)`` or a dictionary of
        ``(batch, time, 1)`` tensors. If ``None``, both ranks are
        accepted. In every mode, 2D inputs are expanded to 3D, so the
        outputs are always 3D.

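    verbose : int, default 0
        Controls the verbosity of logging messages. ``0`` is silent;
        this function emits its messages at levels 2 and 3, so larger
        values show more detail.

    _logger : logging.Logger or callable, optional
        Logger, or callable accepting a message string, passed to
        ``vlog`` as its ``logger`` argument.

    **kws
        Extra keyword arguments; currently not used by this function.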

    Returns
    -------
    t, x, y : Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
        The extracted t, x, and y coordinate tensors, each reshaped
        to be 3D with a singleton last dimension, e.g.,
        `(batch, time_steps, 1)`.

    Raises
    ------
    ValueError
        If the input format is unsupported, dimensions are inconsistent,
        or the `expect_dim` constraint is violated.
    TypeError
        If `inputs` is neither a tensor/array nor a dictionary.
    """
    vlog(
        "Extracting (t, x, y) coordinates from inputs...",
        level=2,
        verbose=verbose,
        logger=_logger,
    )

    if expect_dim and expect_dim not in ["2d", "3d"]:
        raise ValueError(
            "`expect_dim` must be None, '2d', or '3d'."
        )

    if coord_slice_map is None:
        coord_slice_map = {"t": 0, "x": 1, "y": 2}

    def _ensure_3d_and_validate(tensor, name):
        """Helper to convert to tensor, ensure 3D, and validate."""
        if not isinstance(tensor, Tensor):
            tensor = tf_convert_to_tensor(
                tensor, dtype=tf_float32
            )

        if expect_dim == "2d" and tensor.ndim != 2:
            raise ValueError(
                f"Input '{name}' must be 2D (expect_dim='2d'), "
                f"but got rank {tensor.ndim}."
            )
        elif expect_dim == "3d" and tensor.ndim != 3:
            raise ValueError(
                f"Input '{name}' must be 3D (expect_dim='3d'), "
                f"but got rank {tensor.ndim}."
            )

        # For consistency, always expand 2D tensors to 3D.
        if tensor.ndim == 2:
            vlog(
                f"Expanding 2D input '{name}' to 3D for consistency.",
                level=3,
                verbose=verbose,
                logger=_logger,
            )
            return tf_expand_dims(tensor, axis=1)
        elif tensor.ndim == 3:
            return tensor
        else:
            raise ValueError(
                f"Input '{name}' must be a 2D or 3D tensor, but got "
                f"rank {tensor.ndim} with shape {tensor.shape}."
            )

    # --- Main Logic ---
    if isinstance(inputs, dict):
        if "coords" in inputs:
            coords_tensor = _ensure_3d_and_validate(
                inputs["coords"], "coords"
            )
        elif all(k in inputs for k in ("t", "x", "y")):
            t = _ensure_3d_and_validate(inputs["t"], "t")
            x = _ensure_3d_and_validate(inputs["x"], "x")
            y = _ensure_3d_and_validate(inputs["y"], "y")
            # The shapes are now guaranteed to be 3D.
            return (
                tf_cast(t, tf_float32),
                tf_cast(x, tf_float32),
                tf_cast(y, tf_float32),
            )
        else:
            raise ValueError(
                "Dict `inputs` must contain either 'coords' key "
                "or all of 't', 'x', 'y' keys."
            )
    elif isinstance(inputs, Tensor | np.ndarray):
        coords_tensor = _ensure_3d_and_validate(
            inputs, "inputs"
        )
    else:
        raise TypeError(
            f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
        )

    # Slice the now-guaranteed 3D coords_tensor
    tf_debugging.assert_greater_equal(
        tf_shape(coords_tensor)[-1],
        3,
        message=(
            "Coordinate tensor must carry at least three features "
            "(t, x, y) in the last dimension,"
            f" but got shape {coords_tensor.shape}"
        ),
    )


    t = coords_tensor[
        ..., coord_slice_map["t"] : coord_slice_map["t"] + 1
    ]
    x = coords_tensor[
        ..., coord_slice_map["x"] : coord_slice_map["x"] + 1
    ]
    y = coords_tensor[
        ..., coord_slice_map["y"] : coord_slice_map["y"] + 1
    ]

    vlog(
        f"Successfully extracted t:{t.shape}, x:{x.shape}, "
        f"y:{y.shape}",
        level=2,
        verbose=verbose,
        logger=_logger,
    )

    return (
        tf_cast(t, tf_float32),
        tf_cast(x, tf_float32),
        tf_cast(y, tf_float32),
    )


@ensure_pkg(
    KERAS_BACKEND or "tensorflow",
    extra="TensorFlow is required for this function.",
)
def extract_txy(
    inputs: Tensor
    | np.ndarray
    | dict[str, Tensor | np.ndarray],
    coord_slice_map: dict[str, int] | None = None,
    expect_dim: str | None = None,
    verbose: int = 0,
    _logger: logging.Logger
    | Callable[[str], None]
    | None = None,
    **kws,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""
    Extracts t, x, y tensors from various input formats.

    This utility standardizes coordinate inputs, accepting a single
    tensor or a dictionary, and handling both 2D (spatial/static)
    and 3D (spatio-temporal) data with flexible dimension validation.

    Parameters
    ----------
    inputs : tf.Tensor, np.ndarray, or dict
        The input data containing coordinates. Can be a single tensor
        or a dictionary with 'coords' or 't', 'x', 'y' keys.

    coord_slice_map : dict, optional
        Mapping for 't', 'x', 'y' to their index in the last
        dimension of a coordinate tensor.
        Defaults to `{'t': 0, 'x': 1, 'y': 2}`.

    expect_dim : {'2d', '3d', '3d_only'}, optional
        Enforces a constraint on the input's dimension. ``'2d'``
        requires input shaped like ``(batch, 3)``. ``'3d'`` accepts
        3D input and expands 2D input to 3D with a time dimension of 1.
        ``'3d_only'`` requires 3D input and raises an error for 2D
        input. ``None`` accepts both 2D and 3D inputs without changing
        their rank.

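    verbose : int, default 0
        Controls logging verbosity. ``0`` is silent; this function
        emits its messages at levels 2 and 3.

    _logger : logging.Logger or callable, optional
        Logger, or callable accepting a message string, passed to
        ``vlog`` as its ``logger`` argument.

    **kws
        Extra keyword arguments; currently not used by this function.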

    Returns
    -------
    t, x, y : Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
        The extracted t, x, and y coordinate tensors, cast to
        ``float32``. Their rank (2D or 3D) matches the processed input,
        so it depends on both the input rank and the `expect_dim` mode.

    Raises
    ------
    ValueError
        If the input format is unsupported, dimensions are inconsistent,
        or the `expect_dim` constraint is violated.
    TypeError
        If `inputs` is neither a tensor/array nor a dictionary.
    """
    vlog(
        "Extracting (t, x, y) coordinates from inputs...",
        level=2,
        verbose=verbose,
        logger=_logger,
    )

    if expect_dim and expect_dim not in [
        "2d",
        "3d",
        "3d_only",
    ]:
        raise ValueError(
            "`expect_dim` must be None, '2d', '3d', or '3d_only'."
        )

    if coord_slice_map is None:
        coord_slice_map = {"t": 0, "x": 1, "y": 2}

    def _process_tensor(tensor, name):
        """Helper to convert to tensor and validate dimension."""
        if not isinstance(tensor, Tensor):
            tensor = tf_convert_to_tensor(
                tensor, dtype=tf_float32
            )

        input_ndim = tensor.ndim

        # Validate against expect_dim
        if expect_dim == "2d" and input_ndim != 2:
            raise ValueError(
                f"Input '{name}' must be 2D for expect_dim='2d', "
                f"but got rank {input_ndim}."
            )
        elif expect_dim == "3d_only" and input_ndim != 3:
            raise ValueError(
                f"Input '{name}' must be 3D for expect_dim='3d_only', "
                f"but got rank {input_ndim}."
            )

        # Handle expansion for '3d' mode
        if expect_dim == "3d" and input_ndim == 2:
            vlog(
                f"Expanding 2D input '{name}' to 3D for expect_dim='3d'.",
                level=3,
                verbose=verbose,
                logger=_logger,
            )
            return tf_expand_dims(tensor, axis=1)

        # For all other cases, including `expect_dim=None`, return as is
        # after basic rank validation.
        if input_ndim not in [2, 3]:
            raise ValueError(
                f"Input '{name}' must be a 2D or 3D tensor, but got "
                f"rank {input_ndim} with shape {tensor.shape}."
            )
        return tensor

    # --- Main Logic ---
    if isinstance(inputs, dict):
        if "coords" in inputs:
            coords_tensor = _process_tensor(
                inputs["coords"], "coords"
            )
        elif all(k in inputs for k in ("t", "x", "y")):
            # Separate t, x, y entries are validated (and, for
            # expect_dim='3d', expanded) one by one and returned
            # directly; no concatenation or slicing is needed.
            t_p = _process_tensor(inputs["t"], "t")
            x_p = _process_tensor(inputs["x"], "x")
            y_p = _process_tensor(inputs["y"], "y")

            return (
                tf_cast(t_p, tf_float32),
                tf_cast(x_p, tf_float32),
                tf_cast(y_p, tf_float32),
            )
        else:
            raise ValueError(
                "Dict `inputs` must contain either 'coords' key "
                "or all of 't', 'x', 'y' keys."
            )
    elif isinstance(inputs, Tensor | np.ndarray):
        coords_tensor = _process_tensor(inputs, "inputs")
    else:
        raise TypeError(
            f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
        )

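    # Slice the processed coords_tensor. Check the static shape so the
    # validation does not rely on evaluating a tensor in a Python `if`,
    # which is not reliable in graph mode (tf.function).
    n_features = coords_tensor.shape[-1]
    if n_features is not None and n_features < 3:
        raise ValueError(
            "Coordinate tensor must have at least 3 features (t,x,y) "
            f"in the last dimension, but got shape {coords_tensor.shape}"
        )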

    # Slicing keeps the original number of dimensions
    t = coords_tensor[
        ..., coord_slice_map["t"] : coord_slice_map["t"] + 1
    ]
    x = coords_tensor[
        ..., coord_slice_map["x"] : coord_slice_map["x"] + 1
    ]
    y = coords_tensor[
        ..., coord_slice_map["y"] : coord_slice_map["y"] + 1
    ]

    vlog(
        f"Successfully extracted t:{t.shape}, x:{x.shape}, "
        f"y:{y.shape}",
        level=2,
        verbose=verbose,
        logger=_logger,
    )

    return (
        tf_cast(t, tf_float32),
        tf_cast(x, tf_float32),
        tf_cast(y, tf_float32),
    )


@ensure_pkg(
    KERAS_BACKEND or "tensorflow",
    extra="TensorFlow is required for this function.",
)
def plot_hydraulic_head(
    model: Model,
    t_slice: float,
    x_bounds: tuple[float, float],
    y_bounds: tuple[float, float],
    resolution: int = 100,
    ax: plt.Axes | None = None,
    title: str | None = None,
    cmap: str = "viridis",
    colorbar_label: str = "Hydraulic Head (h)",
    save_path: str | None = None,
    show_plot: bool = True,
    **contourf_kwargs: Any,
) -> tuple[plt.Axes, ScalarMappable]:
    r"""Generate and plot a 2D contour map of a hydraulic head solution.

    This utility visualizes the output of a Physics-Informed Neural
    Network (PINN) that solves for the hydraulic head
    :math:`h(t, x, y)`. It automates the process of creating a
    spatial grid, running model predictions, and generating a
    publication-quality contour plot for a specific slice in time.

    Parameters
    ----------
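    model : tf.keras.Model
        The trained PINN model. Its ``.predict()`` method must accept a
        dictionary with keys ``'t'``, ``'x'``, and ``'y'``, each holding
        a ``float32`` tensor of shape ``(N, 1)``, and return ``N``
        predictions that can be reshaped to the
        ``(resolution, resolution)`` grid.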
    t_slice : float
        The specific point in time :math:`t` for which to plot the
        2D spatial solution.
    x_bounds : tuple of float
        A tuple ``(x_min, x_max)`` defining the spatial domain for
        the x-axis.
    y_bounds : tuple of float
        A tuple ``(y_min, y_max)`` defining the spatial domain for
        the y-axis.
    resolution : int, optional
        The number of points to sample along each spatial axis,
        creating a grid of ``resolution x resolution`` points for
        prediction. Higher values result in a smoother plot.
        Default is 100.
    ax : matplotlib.axes.Axes, optional
        A pre-existing Matplotlib Axes object to plot on. If ``None``,
        a new figure and axes are created internally. This is useful
        for embedding this plot within a larger figure arrangement.
        Default is ``None``.
    title : str, optional
        A custom title for the plot. If ``None``, a default title
        is generated using the value of `t_slice`. Default is ``None``.
    cmap : str, optional
        The name of the Matplotlib colormap to use for the contour
        plot. Default is ``'viridis'``.
    colorbar_label : str, optional
        The text label for the color bar. Default is
        ``'Hydraulic Head (h)'``.
    save_path : str, optional
        If provided, the path (including filename and extension)
        where the generated plot will be saved. This is only active
        when the function creates its own figure (i.e., when `ax`
        is ``None``). Default is ``None``.
    show_plot : bool, optional
        If ``True``, calls ``plt.show()`` to display the plot. This
        is only active when the function creates its own figure.
        Default is ``True``.
    **contourf_kwargs : Any
        Additional keyword arguments passed directly to
        ``matplotlib.axes.Axes.contourf``, for example ``levels=20`` or
        ``extend='both'``. If ``levels`` is not given, 100 levels are
        used.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Matplotlib Axes object on which the contour plot was drawn.
    contour : matplotlib.cm.ScalarMappable
        The contour plot object, which can be used for further
        customizations, such as modifying the color bar.

    See Also
    --------
    geoprior.models.pinn.PiTGWFlow : The PINN model this function is
                                 designed to visualize.

    Notes
    -----
    The core mechanism of this function involves creating a 2D
    meshgrid of :math:`(x, y)` coordinates. These grid points are then
    "flattened" into a long list of points, as the PINN model expects
    a batch of individual coordinates for prediction, not a grid.

    The prediction process is as follows:

    1.  A grid of shape ``(resolution, resolution)`` is created for
        :math:`x` and :math:`y`.
    2.  These grids are reshaped into column vectors of shape
        ``(resolution*resolution, 1)``.
    3.  A time vector of the same shape, filled with `t_slice`, is
        created.
    4.  The model's ``.predict()`` method is called on these flat
        tensors.
    5.  The resulting flat prediction vector is reshaped back to the
        original ``(resolution, resolution)`` grid shape for plotting.

    If a custom `ax` is provided, the user is responsible for calling
    ``plt.show()`` or saving the parent figure.

    Examples
    --------
    >>> import numpy as np
    >>> import tensorflow as tf
    >>> import matplotlib.pyplot as plt
    >>> # This is a mock model for demonstration purposes.
    >>> # In practice, you would use a trained PiTGWFlow model.
    >>> class MockPINN(tf.keras.Model):
    ...     def call(self, inputs):
    ...         # A simple analytical function for demonstration
    ...         t, x, y = inputs['t'], inputs['x'], inputs['y']
    ...         return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
    ...
    >>> mock_model = MockPINN()

    **1. Simple Plotting Example**

    This example creates a single plot and saves it to a file.

    >>> ax, contour = plot_hydraulic_head(
    ...     model=mock_model,
    ...     t_slice=0.5,
    ...     x_bounds=(-1, 1),
    ...     y_bounds=(-1, 1),
    ...     resolution=50,
    ...     save_path="hydraulic_head_t0.5.png",
    ...     show_plot=False  # Do not display interactively
    ... )
    Plot saved to hydraulic_head_t0.5.png

    **2. Advanced Example with Subplots**

    This example shows how to use the `ax` parameter to draw the
    solution at two different times side-by-side in one figure.

    >>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    >>> _ = fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
    >>> # Plot solution at t = 0.1
    >>> _ = plot_hydraulic_head(
    ...     model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
    ...     y_bounds=(-1, 1), ax=ax1, show_plot=False
    ... )
    >>> # Plot solution at t = 1.0
    >>> _ = plot_hydraulic_head(
    ...     model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
    ...     y_bounds=(-1, 1), ax=ax2, show_plot=False
    ... )
    >>> plt.tight_layout(rect=[0, 0, 1, 0.96])
    >>> plt.show()
    """
    # --- 1. Handle Matplotlib Figure and Axes ---
    if ax is None:
        # Create a new figure only if no axes are provided
        fig, current_ax = plt.subplots(figsize=(9, 7))
        # Flag to indicate we have control over the figure object
        fig_created = True
    else:
        current_ax = ax
        fig = current_ax.figure
        fig_created = False

    # --- 2. Create Prediction Grid ---
    x_range = np.linspace(
        x_bounds[0], x_bounds[1], resolution
    )
    y_range = np.linspace(
        y_bounds[0], y_bounds[1], resolution
    )
    X, Y = np.meshgrid(x_range, y_range)

    # Flatten the grid to a list of points for model prediction
    x_flat = tf_convert_to_tensor(X.ravel(), dtype=tf_float32)
    y_flat = tf_convert_to_tensor(Y.ravel(), dtype=tf_float32)
    # float() ensures an integer t_slice still yields a float32 time column
    t_flat = tf_fill(x_flat.shape, float(t_slice))

    # Reshape to column vectors (N, 1) as expected by the model
    grid_coords = {
        "t": tf_reshape(t_flat, (-1, 1)),
        "x": tf_reshape(x_flat, (-1, 1)),
        "y": tf_reshape(y_flat, (-1, 1)),
    }

    # --- 3. Run Model Prediction ---
    h_pred_flat = model.predict(grid_coords)
    # Reshape the flat predictions back to the grid shape for plotting
    h_pred_grid = tf_reshape(h_pred_flat, X.shape)

    # --- 4. Plotting ---
    # Set default contour levels if not provided
    if "levels" not in contourf_kwargs:
        contourf_kwargs["levels"] = 100

    contour = current_ax.contourf(
        X, Y, h_pred_grid, cmap=cmap, **contourf_kwargs
    )

    # Add color bar
    fig.colorbar(contour, ax=current_ax, label=colorbar_label)

    # Set plot labels and title
    if title is None:
        title = f"Learned Hydraulic Head Solution at t = {t_slice}"
    current_ax.set_title(title, fontsize=14)
    current_ax.set_xlabel("x-coordinate")
    current_ax.set_ylabel("y-coordinate")
    current_ax.set_aspect("equal")

    # --- 5. Save and Show ---
    if fig_created:
        if save_path:
            fig.savefig(
                save_path, dpi=300, bbox_inches="tight"
            )
            print(f"Plot saved to {save_path}")

        if show_plot:
            plt.show()
        else:
            # If not showing, close the figure to free up memory
            plt.close(fig)

    return current_ax, contour
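The flatten, predict, and reshape cycle described in the Notes of `plot_hydraulic_head` can be checked in isolation with NumPy alone. In this sketch, `fake_predict` and `head_on_grid` are hypothetical names: `fake_predict` stands in for `model.predict` and returns one value per point with shape `(N, 1)`. No Keras model or plotting is involved.

```python
import numpy as np


def fake_predict(coords):
    """Hypothetical stand-in for ``model.predict``: returns shape (N, 1)."""
    t, x, y = coords["t"], coords["x"], coords["y"]
    return np.sin(np.pi * x) * np.cos(np.pi * y) * np.exp(-t)


def head_on_grid(predict, t_slice, x_bounds, y_bounds, resolution=50):
    # 1. Build a (resolution, resolution) grid over the spatial domain.
    x_range = np.linspace(x_bounds[0], x_bounds[1], resolution)
    y_range = np.linspace(y_bounds[0], y_bounds[1], resolution)
    X, Y = np.meshgrid(x_range, y_range)

    # 2-3. Flatten to (N, 1) column vectors and add a constant time column.
    x_col = X.reshape(-1, 1).astype(np.float32)
    y_col = Y.reshape(-1, 1).astype(np.float32)
    t_col = np.full_like(x_col, float(t_slice))

    # 4. Run the predictor once on the flat batch of N = resolution**2 points.
    h_flat = predict({"t": t_col, "x": x_col, "y": y_col})

    # 5. Restore the grid layout so it can be passed to contourf(X, Y, H).
    H = np.reshape(np.asarray(h_flat), X.shape)
    return X, Y, H


X, Y, H = head_on_grid(fake_predict, 0.5, (-1, 1), (-1, 1), resolution=5)
print(H.shape)  # (5, 5)
```

Reshaping back with `X.shape` is safe because flattening and reshaping both use row-major (C) order, so the i-th prediction returns to the grid cell its coordinates came from.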

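`extract_txy_in` and `extract_txy` slice coordinates the same way but differ in how they treat rank-2 input: the former always promotes `(batch, 3)` to `(batch, 1, 3)`, while the latter keeps the input rank unless `expect_dim='3d'`. The NumPy-only sketch below mirrors the rank handling of `extract_txy` for a single coordinate array. `sketch_extract_txy` is a hypothetical name, and the dict input forms, TensorFlow dtypes, and logging are deliberately left out.

```python
import numpy as np


def sketch_extract_txy(coords, expect_dim=None, coord_slice_map=None):
    """Hypothetical NumPy mimic of ``extract_txy`` rank handling.

    Handles a single coordinate array only (no dict inputs).
    """
    if expect_dim not in (None, "2d", "3d", "3d_only"):
        raise ValueError("expect_dim must be None, '2d', '3d', or '3d_only'.")
    if coord_slice_map is None:
        coord_slice_map = {"t": 0, "x": 1, "y": 2}

    coords = np.asarray(coords, dtype=np.float32)
    if expect_dim == "2d" and coords.ndim != 2:
        raise ValueError(f"expected rank 2, got rank {coords.ndim}")
    if expect_dim == "3d_only" and coords.ndim != 3:
        raise ValueError(f"expected rank 3, got rank {coords.ndim}")
    if coords.ndim not in (2, 3):
        raise ValueError(f"expected rank 2 or 3, got rank {coords.ndim}")
    if expect_dim == "3d" and coords.ndim == 2:
        # Insert a singleton time axis: (batch, 3) -> (batch, 1, 3).
        coords = coords[:, np.newaxis, :]
    if coords.shape[-1] < 3:
        raise ValueError("last dimension must hold at least (t, x, y)")

    # Slicing with i:i+1 keeps a trailing singleton axis.
    return tuple(
        coords[..., coord_slice_map[k] : coord_slice_map[k] + 1]
        for k in ("t", "x", "y")
    )


static = np.arange(12.0).reshape(4, 3)  # 4 points, columns (t, x, y)
t, x, y = sketch_extract_txy(static)
print(t.shape)  # (4, 1): rank preserved when expect_dim is None
t, x, y = sketch_extract_txy(static, expect_dim="3d")
print(t.shape)  # (4, 1, 1): time axis inserted
```

Calling the sketch with `expect_dim='3d'` reproduces what `extract_txy_in` does for rank-2 input in every mode.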
API notes

A few practical notes on where to start:

  • geoprior.utils is the natural starting point for staged workflow code, diagnostics, export, and reproducibility scripts.

  • geoprior.models.utils is the natural starting point for model-input preparation, sequence construction, and PINN-side formatting helpers.

  • geoprior.models.subsidence.utils is where to look for scaling, SI conversion, depth/head conventions, and reference-state extraction logic.

  • The three layers complement one another: each covers a different level of the workflow, model, and physics stack rather than duplicating another layer's role.

See also