Utility API reference#
GeoPrior-v3 exposes utilities across three complementary layers:
geoprior.utils for workflow-facing utilities used across staged runs, preprocessing, forecasting, diagnostics, calibration, evaluation, export, and artifact handling;
geoprior.models.utils for model-facing utilities used closer to forecasting models, sequence construction, PINN input preparation, tensor formatting, and PDE mode normalization;
geoprior.models.subsidence.utils for subsidence-physics utilities used to canonicalize scaling metadata, convert between model and SI units, resolve groundwater and subsidence channels, and support head/depth transformations and physics-aware priors.
This page documents all three layers together because they are not isolated in practice. A typical GeoPrior workflow uses them as a stack:
the workflow layer prepares data, resolves config, and drives staged execution;
the model layer standardizes sequence handling, model inputs, forecast formatting, and PINN support functions;
the subsidence-physics layer reconciles units, aliases, coordinate scaling, groundwater conventions, and model-side physical state extraction.
Overview#
A useful mental map is:
geoprior.utils
├── audits / handshakes
├── config and NAT helpers
├── data / IO / export helpers
├── forecast formatting and evaluation
├── holdout and split logic
├── spatial and geospatial utilities
├── sequence and shape helpers
└── subsidence-oriented workflow conversions
geoprior.models.utils
├── general model helpers
├── sequence construction
├── forecast formatting
├── input packing / unpacking
├── anomaly and evaluation helpers
└── PINN-specific helper functions
geoprior.models.subsidence.utils
├── scaling alias normalization
├── SI conversion helpers
├── coordinate scaling helpers
├── groundwater / head conversion
├── dynamic channel resolution
└── history / reference-state extraction
Why three utility layers exist#
The separation is architectural rather than cosmetic.
Workflow surface#
The top-level geoprior.utils package is the best entry
point when reading Stage-1 through Stage-5 workflows, figure
scripts, evaluation paths, and export code. Its public exports
already include config loaders, dataset builders, calibration
helpers, forecast formatters, holdout logic, geospatial helpers,
and subsidence-oriented postprocessing utilities.
Model surface#
The geoprior.models.utils package complements the
workflow layer by handling lower-level model concerns such as
sequence construction, preparing input tuples,
forecast-to-DataFrame formatting, and PINN-oriented helper
functions. The associated module docstrings explicitly position
this layer in the context of sequence forecasting and
Temporal-Fusion-Transformer-style workflows, which is useful for
readers trying to connect GeoPrior’s utilities to broader
multi-horizon forecasting practice.
Subsidence-physics surface#
The geoprior.models.subsidence.utils module is narrower
and more physics-aware. It is where scaling metadata is made
canonical, SI conversions are enforced, coordinate policies are
checked, groundwater depth can be reconciled with hydraulic head,
and dynamic-channel lookup is stabilized. This layer is central
when readers need to understand how raw workflow tensors become
physically interpretable model quantities.
Reading order#
A practical order for new readers is geoprior.utils first, then geoprior.models.utils, and finally geoprior.models.subsidence.utils.
That sequence moves from the workflow surface into the model surface and finally into the more specialized subsidence-physics support layer.
Module indexes#
These indexes use explicit fully qualified module names rather than recursive package discovery. That keeps autosummary predictable and avoids ambiguous package-relative lookup.
Workflow utility modules#
Audit helpers for stage handshakes and scaling artifacts.
Essential utilities for data processing and analysis in FusionLab, offering functions for normalization, interpolation, feature selection, outlier removal, and various data manipulation tasks.
Forecast utilities.
Data utilities.
Dependency utilities providing functions to handle package installation, checking, and ensuring that optional dependencies are available.
Forecast utilities.
Provides common helper functions for validation, comparison, and other generic operations.
Geospatial utility helpers for GeoPrior workflows.
Utility helpers for holdout and split workflows.
Input/Output utilities for managing file paths, directories, and loading serialized data within FusionLab.
Public exports for NAT workflow utilities.
Parallel execution helpers for GeoPrior workflows.
Utilities for computing error metrics in physical units given Stage-1 scaling metadata.
Sequence-building helpers for temporal model inputs.
Shape utility helpers for arrays and tensors.
geospatial_utils - A collection of utilities for geospatial and positional data analysis, filtering, and transformations.
geoprior.utils.split
Utility helpers for subsidence data, units, and coordinates.
System utilities module for managing system-level operations.
Target-processing helpers for GeoPrior workflows.
Provides a comprehensive set of functions and warnings for validating and ensuring the integrity of data.
Vendored version parsing utilities.
Model utility modules#
Subsidence-physics utility modules#
The subsidence-physics utility module is indexed from the
dedicated subsidence API page to keep one canonical autosummary
stub for geoprior.models.subsidence.utils.
Top-level workflow utility package#
The public package surface below documents the aggregated export
layer of geoprior.utils.
Public exports for GeoPrior utility helpers.
- geoprior.utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)[source]
Sample spatial data intelligently to represent the distribution of the whole area and include different years.
This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.
- Parameters:
data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., ‘longitude’, ‘latitude’) and any columns specified in stratify_by.
sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).
stratify_by (list of str, optional) – List of column names to stratify by.
spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.
spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named ‘longitude’ and/or ‘latitude’ in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.
method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.
min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.
random_state (int, optional) – Random seed for reproducibility. Default is 42.
verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.
- Returns:
sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.
- Return type:
Notes
The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.
Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:
\[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil \tag{1}\]
where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).
The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
Examples
>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> import pandas as pd
>>> # Assume 'df' is a pandas DataFrame with columns
>>> # 'longitude', 'latitude', 'year', and other data.
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)
See also
pandas.qcut – Quantile-based discretization function used for binning.
sklearn.model_selection.StratifiedShuffleSplit – For stratified sampling.
batch_spatial_sampling – Resample spatial data with batching.
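The proportional allocation rule in the Notes for spatial_sampling can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: the helper name allocate_samples and the group labels are hypothetical, and integer arithmetic is used for the ceiling to avoid floating-point surprises.

```python
def allocate_samples(group_sizes, n_total):
    """Proportional allocation n_i = ceil(N_i / N * n) per group.

    Hypothetical helper for illustration only; it is not part of geoprior.
    """
    N = sum(group_sizes.values())
    # Integer ceiling of (N_i * n) / N, equivalent to ceil(N_i / N * n).
    return {
        group: (size * n_total + N - 1) // N
        for group, size in group_sizes.items()
    }


# Three hypothetical stratification groups (spatial bin x year, for example).
group_sizes = {"bin_a": 7, "bin_b": 2, "bin_c": 1}
allocation = allocate_samples(group_sizes, n_total=5)
print(allocation)
print(sum(allocation.values()))
```

Because each group is rounded up, the per-group counts can add up to slightly more than the requested sample size; here the small groups each receive at least one sample.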
- geoprior.utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)[source]
Cluster 2D spatial data in df using <algorithm> and optionally plot the results.
This function, <create_spatial_clusters>, extracts two coordinate columns from <df> to form clusters via methods such as ‘kmeans’, ‘dbscan’, or ‘agglo’ (agglomerative). It uses the function filter_valid_kwargs (when relevant) to strip out invalid parameters for certain estimators, and writes cluster labels into <cluster_col>.
- Parameters:
df (pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.
spatial_cols (list of str, optional) – Two-column list for x and y coordinates. Defaults to ['longitude', 'latitude'] if None.
cluster_col (str, default 'region') – Name of the column to store the assigned cluster labels.
n_clusters (int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.
algorithm (str, default 'kmeans') – Choice of clustering algorithm among ['kmeans', 'dbscan', 'agglo'].
view (bool, default True) – If True, displays a scatterplot of the final clusters.
figsize (tuple, default (14, 10)) – Size of the displayed figure for the cluster plot.
s (int, default 60) – Marker size in the scatterplot.
plot_style (str, default 'seaborn') – Matplotlib style used for the plot.
cmap (str, default 'tab20') – Colormap name used to differentiate clusters.
show_grid (bool, default True) – Toggles grid lines on or off.
grid_props (dict, optional) – Additional keyword arguments controlling the grid style.
auto_scale (bool, default True) – If True, standardize coordinates before clustering.
savefile (str, optional) – File path to save the data with an additional <cluster_col> storing the assigned cluster labels if desired.
verbose (int, default 1) – Controls console logs. Higher values yield more details about scaling and cluster detection.
**kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, AgglomerativeClustering).
- Returns:
A copy of <df> with an additional <cluster_col> storing the assigned cluster labels.
- Return type:
Notes
If <auto_scale> is True, it uses a standard scaler to normalize the coordinate columns. The scatterplot is generated using the library seaborn for enhanced styling.
By default, for <algorithm> = “kmeans”, the model attempts to minimize:
\[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2 \tag{2}\]
where \(x_i\) are the scaled or raw 2D coordinates in <df>. The function can optionally auto-detect n_clusters using a silhouette and elbow analysis if not provided.
Examples
>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )
See also
filter_valid_kwargs – Helps discard unsupported keyword arguments for chosen estimators.
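To make the k-means objective in the Notes for create_spatial_clusters concrete, the sketch below evaluates \(J\) for fixed candidate centroids using the four points from the example. It only computes the objective and does not run k-means itself; the function name kmeans_objective and the centroid values are assumptions chosen for illustration.

```python
def kmeans_objective(points, centroids):
    """Sum over points of the squared distance to the nearest centroid.

    Illustrative helper only; geoprior delegates clustering to an estimator.
    """
    total = 0.0
    for x, y in points:
        total += min((x - cx) ** 2 + (y - cy) ** 2 for cx, cy in centroids)
    return total


# (longitude, latitude) pairs taken from the docstring example.
points = [(0.1, 1.0), (0.2, 1.1), (2.2, 2.1), (2.3, 2.2)]

# Hypothetical centroids: one per visible cluster, then a single shared one.
J_two = kmeans_objective(points, [(0.15, 1.05), (2.25, 2.15)])
J_one = kmeans_objective(points, [(1.2, 1.6)])
print(J_two, J_one)
```

A lower value for two centroids than for one is the kind of drop an elbow analysis looks for when choosing the number of clusters.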
- geoprior.utils.augment_city_spatiotemporal_data(df, city, mode='interpolate', group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_config=None, augmentation_config=None, target_name=None, interpolate_target=False, verbose=True, coordinate_precision=None, savefile=None)[source]
Apply grouped spatiotemporal augmentation with city-aware defaults.
This is a convenience wrapper around augment_spatiotemporal_data. It validates the requested city, optionally rounds coordinates before grouping, and forwards interpolation and augmentation configuration dictionaries.
- Parameters:
df (pandas.DataFrame) – Input DataFrame containing spatial, temporal, and feature columns.
city ({'nansha', 'zhongshan'}) – City identifier used for validation and defaults.
mode ({'interpolate', 'augment_features', 'both'}, optional) – Processing mode forwarded to augment_spatiotemporal_data.
group_by_cols (list of str or None, optional) – Grouping columns for interpolation.
time_col (str or None, optional) – Time column used for interpolation.
value_cols_interpolate (list of str or None, optional) – Columns to interpolate.
feature_cols_augment (list of str or None, optional) – Columns to augment with noise.
interpolation_config (dict or None, optional) – Keyword arguments for interpolate_temporal_gaps. Typical values include {'freq': 'AS', 'method': 'linear'}.
augmentation_config (dict or None, optional) – Keyword arguments for augment_series_features. Typical values include {'noise_level': 0.01, 'noise_type': 'gaussian'}.
target_name (str or None, optional) – Optional target column name used when inferring default feature sets.
interpolate_target (bool, optional) – Whether the target should be included in default interpolation columns.
verbose (bool, optional) – Whether to emit progress information.
coordinate_precision (int or None, optional) – Decimal precision applied to coordinates before grouping.
savefile (str or None, optional) – Optional output CSV path handled by the decorator.
- Returns:
Augmented DataFrame.
- Return type:
- Raises:
ValueError – If city or mode is invalid, or if required arguments are missing for the selected mode.
TypeError – If the main inputs are of the wrong type.
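The effect of coordinate_precision on grouping can be shown without the library. This is a sketch under the assumption that rounding behaves like Python's built-in round; the helper group_by_rounded_coords and the sample rows are hypothetical.

```python
from collections import defaultdict


def group_by_rounded_coords(rows, precision=None,
                            coord_cols=("longitude", "latitude")):
    """Group rows by coordinates, optionally rounded to `precision` decimals.

    Hypothetical helper for illustration; not the geoprior implementation.
    """
    groups = defaultdict(list)
    for row in rows:
        key = tuple(
            row[c] if precision is None else round(row[c], precision)
            for c in coord_cols
        )
        groups[key].append(row)
    return dict(groups)


# Two records of the same site whose coordinates differ only by jitter.
rows = [
    {"longitude": 113.12341, "latitude": 22.50002, "year": 2020},
    {"longitude": 113.12339, "latitude": 22.49998, "year": 2021},
]
raw_groups = group_by_rounded_coords(rows)
rounded_groups = group_by_rounded_coords(rows, precision=4)
print(len(raw_groups), len(rounded_groups))
```

Without rounding, each jittered record forms its own group and there is nothing to interpolate between; after rounding, both records fall into one time series.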
- geoprior.utils.augment_series_features(series_df, feature_cols, noise_level=0.01, noise_type='gaussian', random_seed=None, savefile=None)[source]
Add random noise to selected numeric feature columns.
- Parameters:
series_df (pandas.DataFrame) – Input DataFrame representing one or more time series.
noise_level (float, optional) – Magnitude of the added noise. For Gaussian noise it scales the feature standard deviation, and for uniform noise it scales the feature range.
noise_type ({'gaussian', 'uniform'}, optional) – Type of noise distribution to use.
random_seed (int or None, optional) – Seed for reproducible noise generation.
savefile (str or None, optional) – Optional output path handled by the decorator.
- Returns:
DataFrame with noise added to the selected feature columns.
- Return type:
- Raises:
ValueError – If requested feature columns are missing or noise_type is invalid.
TypeError – If the main inputs are of the wrong type.
Notes
Non-numeric columns are skipped, and constant or invalid numeric ranges are left unchanged.
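The noise scaling described for augment_series_features can be sketched for a single numeric column. This is an illustration under stated assumptions, not the library code: the helper add_noise is hypothetical, the population standard deviation is assumed for the Gaussian case, and the uniform noise is assumed to be symmetric around zero.

```python
import random
import statistics


def add_noise(values, noise_level=0.01, noise_type="gaussian", seed=None):
    """Return a noisy copy of one numeric column (illustrative sketch)."""
    rng = random.Random(seed)
    if noise_type == "gaussian":
        # Assumption: noise std = noise_level * column standard deviation.
        scale = noise_level * statistics.pstdev(values)
        if scale == 0:
            return list(values)  # constant column is left unchanged
        return [v + rng.gauss(0.0, scale) for v in values]
    if noise_type == "uniform":
        # Assumption: noise drawn from [-noise_level * range, +noise_level * range].
        scale = noise_level * (max(values) - min(values))
        if scale == 0:
            return list(values)
        return [v + rng.uniform(-scale, scale) for v in values]
    raise ValueError(f"Unknown noise_type: {noise_type!r}")


values = [1.0, 2.0, 3.0, 4.0]
noisy_uniform = add_noise(values, noise_level=0.1, noise_type="uniform", seed=1)
noisy_gauss = add_noise(values, noise_level=0.1, noise_type="gaussian", seed=1)
constant = add_noise([5.0, 5.0, 5.0], seed=0)
print(noisy_uniform, noisy_gauss, constant)
```

Scaling by each column's own spread keeps the perturbation proportionate across features with very different units.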
- geoprior.utils.generate_dummy_pinn_data(n_samples, *, year_range=None, coords_range=None, subs_range=None, gwl_range=None, rainfall_range=None, vars_range=None)[source]
Generate dummy PINN data dictionary with specified or default ranges.
- Parameters:
n_samples (int) – Number of samples to generate.
year_range (tuple[float, float], optional) – (min_year, max_year) for integer years. Default (2000, 2025).
coords_range (tuple[tuple[float, float], tuple[float, float]], optional) – ((lon_min, lon_max), (lat_min, lat_max)). Default ((113.0, 113.8), (22.3, 22.8)).
subs_range (tuple[float, float], optional) – (mean_subsidence, std_subsidence) for normal distribution. Default (-20, 15).
gwl_range (tuple[float, float], optional) – (mean_gwl, std_gwl) for normal distribution. Default (2.5, 1.0).
rainfall_range (tuple[float, float], optional) – (min_rain, max_rain) for uniform distribution. Default (500, 2500).
vars_range (dict, optional) – Dictionary that may contain any of the keys: ‘year_range’, ‘coords_range’, ‘subs_range’, ‘gwl_range’, ‘rainfall_range’. Missing keys will fall back to defaults or to explicitly passed arguments.
- Returns:
dummy_data_dict – Dictionary with keys:
"year" : integer years array
"longitude" : float longitudes array
"latitude" : float latitudes array
"subsidence" : float subsidence values array
"GWL" : float groundwater level values array
"rainfall_mm" : float rainfall values array
- Return type:
dict[str,np.ndarray]
- geoprior.utils.augment_spatiotemporal_data(df, mode, group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_kwargs=None, augmentation_kwargs=None, savefile=None, verbose=False)[source]
Apply interpolation, feature augmentation, or both to grouped data.
- Parameters:
df (pandas.DataFrame) – Input spatiotemporal DataFrame.
mode ({'interpolate', 'augment_features', 'both'}) – Processing mode. Use interpolation only, feature augmentation only, or interpolation followed by augmentation.
group_by_cols (list of str or None, optional) – Grouping columns used for per-location processing.
time_col (str or None, optional) – Time column required when interpolation is requested.
value_cols_interpolate (list of str or None, optional) – Value columns to interpolate when interpolation is enabled.
feature_cols_augment (list of str or None, optional) – Feature columns to perturb when augmentation is enabled.
interpolation_kwargs (dict or None, optional) – Keyword arguments forwarded to interpolate_temporal_gaps.
augmentation_kwargs (dict or None, optional) – Keyword arguments forwarded to augment_series_features.
savefile (str or None, optional) – Optional output path handled by the decorator.
verbose (bool, optional) – Whether to emit progress information.
- Returns:
Processed DataFrame assembled from all groups.
- Return type:
- Raises:
ValueError – If mode is invalid or required arguments for the selected mode are missing.
Notes
Groups are processed independently and concatenated afterward.
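The group-wise pattern noted for augment_spatiotemporal_data (split by group, process each group independently, concatenate) can be sketched with plain Python records. Everything here is illustrative: process_by_group and fill_missing_years are hypothetical helpers, and simple linear interpolation over integer years stands in for whatever interpolate_temporal_gaps actually does.

```python
from collections import defaultdict


def fill_missing_years(group_rows, time_col="year", value_col="subsidence"):
    """Linearly interpolate missing integer years within one group."""
    ordered = sorted(group_rows, key=lambda r: r[time_col])
    filled = [ordered[0]]
    for prev, nxt in zip(ordered, ordered[1:]):
        gap = nxt[time_col] - prev[time_col]
        for step in range(1, gap):
            new_row = dict(prev)
            new_row[time_col] = prev[time_col] + step
            new_row[value_col] = prev[value_col] + (step / gap) * (
                nxt[value_col] - prev[value_col]
            )
            filled.append(new_row)
        filled.append(nxt)
    return filled


def process_by_group(rows, group_keys, process):
    """Apply `process` to each group separately, then concatenate results."""
    groups = defaultdict(list)
    for row in rows:
        groups[tuple(row[k] for k in group_keys)].append(row)
    result = []
    for key in groups:  # dicts preserve first-seen group order
        result.extend(process(groups[key]))
    return result


rows = [
    {"site": "A", "year": 2020, "subsidence": 1.0},
    {"site": "A", "year": 2022, "subsidence": 3.0},
    {"site": "B", "year": 2020, "subsidence": 5.0},
    {"site": "B", "year": 2021, "subsidence": 6.0},
]
result = process_by_group(rows, ["site"], fill_missing_years)
for row in result:
    print(row)
```

Processing each group on its own prevents values from one location leaking into the interpolation of another.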
- geoprior.utils.mask_by_reference(data, ref_col, values=None, find_closest=False, fill_value=0, mask_columns=None, error='raise', verbose=0, inplace=False, savefile=None)[source]
Masks (replaces) values in columns other than the reference column for rows in which the reference column matches (or is closest to) the specified value(s).
If a row’s reference-column value is matched, that row’s values in the other columns are overwritten by fill_value. The reference column itself is not modified.
This function supports both exact and approximate matching:
Exact matching is used if find_closest=False.
Approximate (closest) matching is used if find_closest=True and the reference column is numeric.
By default, if the reference column does not exist or if the given values cannot be found (or approximated) in the reference column, an exception is raised. This behavior can be adjusted with the error parameter.
- Parameters:
data (pd.DataFrame) – The input DataFrame containing the data to be masked.
ref_col (str) – The column in data serving as the reference for matching or finding the closest values.
values (Any or sequence of Any, optional) – The reference values to look for in ref_col. This can be:
A single value (e.g., 0 or "apple").
A list/tuple of values (e.g., [0, 10, 25]).
If values is None, all rows are masked (i.e. all rows match), effectively overwriting the entire DataFrame (except the reference column) with fill_value.
Note that if find_closest=False, these values must appear in the reference column; otherwise, an error or warning is triggered (depending on the error setting).
find_closest (bool, default False) – If True, performs an approximate match for numeric reference columns. For each entry in values, the function locates the row(s) in ref_col whose value is numerically closest. Non-numeric reference columns will revert to exact matching regardless.
fill_value (Any, default 0) – The value used to fill/mask the non-reference columns wherever the condition (exact or approximate match) is met. This can be any valid type, e.g., integer, float, string, np.nan, etc. If fill_value='auto' and multiple values are given, each row matched by a particular reference value is filled with that same reference value.
Examples:
If values=9 and fill_value='auto', the fill value is 9 for matched rows.
If values=['a', 10] and fill_value='auto', then rows matching ‘a’ are filled with ‘a’, and rows matching 10 are filled with 10.
mask_columns (str or list of str, optional) – If specified, only these columns are masked. If None, all columns except ref_col are masked. If any column in mask_columns does not exist in the DataFrame and error='raise', a KeyError is raised; otherwise, a warning may be issued or ignored.
error ({'raise', 'warn', 'ignore'}, default 'raise') – Controls how to handle errors:
’raise’: raise an error if the reference column does not exist or if any of the given values cannot be matched (or approximated).
’warn’: only issue a warning instead of raising an error.
’ignore’: silently ignore any issues.
verbose (int, default 0) – Verbosity level:
0: silent (no messages).
1: minimal feedback.
2 or 3: more detailed messages for debugging.
inplace (bool, default False) – If True, performs the operation in place and returns the original DataFrame with modifications. If False, returns a modified copy, leaving the original unaltered.
savefile (str or None, optional) – File path where the DataFrame is saved if the decorator-based saving is active. If None, no saving occurs.
- Returns:
A DataFrame where rows matching the specified condition (exact or approximate) have had their non-reference columns replaced by fill_value.
- Return type:
pd.DataFrame
- Raises:
KeyError – If error='raise' and ref_col is not in data.columns.
ValueError – If error='raise' and no exact/approx match can be found for one or more entries in values.
Notes
If values is None, all rows are masked in the non-ref columns, effectively overwriting them with fill_value.
When find_closest=True, approximate matching is performed only if the reference column is numeric. For non-numeric data, it falls back to exact matching.
When multiple reference values are provided, each is processed in turn. If fill_value='auto', each matched row is filled with that specific reference value.
Examples
>>> import pandas as pd
>>> from geoprior.utils.data_utils import mask_by_reference
>>>
>>> df = pd.DataFrame({
...     "A": [10, 0, 8, 0],
...     "B": [2, 0.5, 18, 85],
...     "C": [34, 0.8, 12, 4.5],
...     "D": [0, 78, 25, 3.2]
... })
>>>
>>> # Example 1: Exact matching, replace all columns except 'A' with 0
>>> masked_df = mask_by_reference(
...     data=df,
...     ref_col="A",
...     values=0,
...     fill_value=0,
...     find_closest=False,
...     error="raise"
... )
>>> print(masked_df)
>>> # 'B', 'C', 'D' for rows where A=0 are replaced with 0.
>>>
>>> # Example 2: Approximate matching for a numeric ref_col
>>> # 9 lies between 8 and 10, so rows with A=8 and A=10 are the closest
>>> # and are masked in the non-reference columns.
>>> masked_df2 = mask_by_reference(
...     data=df,
...     ref_col="A",
...     values=9,
...     find_closest=True,
...     fill_value=-999
... )
>>> print(masked_df2)
>>> # Rows 0 (A=10) and 2 (A=8) are replaced with -999 in columns B, C, D
>>>
>>> # Example 3: fill_value='auto' with multiple values
>>> # Rows matching A=0 => fill with 0; rows matching A=8 => fill with 8
>>> res3 = mask_by_reference(df, "A", [0, 8], fill_value='auto')
>>> print(res3)
>>> # Rows with A=0 => B, C, D replaced by 0
>>> # Rows with A=8 => B, C, D replaced by 8
>>>
>>> # Example 4: mask_columns=['C', 'D'] => only columns C and D are masked
>>> res4 = mask_by_reference(df, "A", values=0, fill_value=999,
...                          mask_columns=["C", "D"])
>>> print(res4)
>>> # Rows where A=0 => columns C, D replaced by 999, while B remains unchanged
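The matching rules for mask_by_reference (exact versus closest, and fill_value='auto') can be sketched on plain records. This is a simplified illustration, not the library code: mask_rows is a hypothetical helper, it has no error, mask_columns, or inplace handling, and masking every row tied for closest is an assumption.

```python
def mask_rows(rows, ref_col, values, fill_value=0, find_closest=False):
    """Return copies of `rows` with non-reference columns masked."""
    if not isinstance(values, (list, tuple)):
        values = [values]
    out = [dict(r) for r in rows]  # never mutate the input rows
    for target in values:
        if find_closest:
            # Assumption: all rows tied for the smallest distance are matched.
            best = min(abs(r[ref_col] - target) for r in rows)
            matched = [i for i, r in enumerate(rows)
                       if abs(r[ref_col] - target) == best]
        else:
            matched = [i for i, r in enumerate(rows) if r[ref_col] == target]
        # With 'auto', each matched row is filled with its own reference value.
        fill = target if fill_value == "auto" else fill_value
        for i in matched:
            for col in out[i]:
                if col != ref_col:
                    out[i][col] = fill
    return out


rows = [
    {"A": 10, "B": 2},
    {"A": 0, "B": 0.5},
    {"A": 8, "B": 18},
    {"A": 0, "B": 85},
]
exact = mask_rows(rows, "A", 0, fill_value=-1)
closest = mask_rows(rows, "A", 9, fill_value=-999, find_closest=True)
auto = mask_rows(rows, "A", [0, 8], fill_value="auto")
print(exact, closest, auto, sep="\n")
```

Working on copies mirrors the documented inplace=False default, where the original data is left unaltered.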
- geoprior.utils.nan_ops(data, auxi_data=None, data_kind=None, ops='check_only', action=None, error='raise', process=None, condition=None, savefile=None, verbose=0)[source]
Perform operations on NaN values within data structures, handling both primary data and optional witness data based on specified parameters.
This function provides a comprehensive toolkit for managing missing values (NaN) in various data structures such as NumPy arrays, pandas DataFrames, and pandas Series. Depending on the ops parameter, it can check for the presence of NaNs, validate data integrity, or sanitize the data by filling or dropping NaN values. The function also supports handling witness data, which can be crucial in scenarios where the relationship between primary and witness data must be maintained.
\[\begin{split}\text{Processed\_data} = \begin{cases} \text{filled\_data} & \text{if action is 'fill'} \\ \text{dropped\_data} & \text{if action is 'drop'} \\ \text{original\_data} & \text{otherwise} \end{cases}\end{split} \tag{3}\]
- Parameters:
data (array-like, pandas.DataFrame, or pandas.Series) – The primary data structure containing NaN values to be processed.
auxi_data (array-like, pandas.DataFrame, or pandas.Series, optional) – Auxiliary data that accompanies the primary data. Its role depends on the data_kind parameter. If data_kind is ‘target’, auxi_data is treated as feature data, and vice versa. This is useful for operations that need to maintain the alignment between primary and witness data.
data_kind ({'target', 'feature', None}, optional) – Specifies the role of the primary data. If set to ‘target’, data is considered target data, and auxi_data (if provided) is treated as feature data. If set to ‘feature’, data is treated as feature data, and auxi_data is considered target data. If None, no special handling is applied, and witness data is ignored unless explicitly required by other parameters.
ops ({'check_only', 'validate', 'sanitize'}, default 'check_only') – Defines the operation to perform on the NaN values in the data:
'check_only': Checks whether the data contains any NaN values and returns a boolean indicator.
'validate': Validates that the data does not contain NaN values. If NaNs are found, it raises an error or warns based on the error parameter.
'sanitize': Cleans the data by either filling or dropping NaN values based on the action, process, and condition parameters.
action ({'fill', 'drop'}, optional) – Specifies the action to take when ops is set to ‘sanitize’:
'fill': Fills NaN values using the fill_NaN function with the method set to ‘both’.
'drop': Drops NaN values based on the conditions and process specified. If data_kind is ‘target’, it handles NaNs in a way that preserves data integrity for machine learning models.
If None, defaults to ‘drop’ when sanitizing.
Note: If ops is not ‘sanitize’ and action is set, an error is raised indicating conflicting parameters.
error ({'raise', 'warn', None}, default 'raise') – Determines the error handling policy:
'raise': Raises exceptions when encountering issues.
'warn': Emits warnings instead of raising exceptions.
None: Defaults to the base policy, which is typically ‘warn’.
This parameter is utilized by the error_policy function to enforce consistent error handling throughout the operation.
process ({'do', 'do_anyway'}, optional) – Works in conjunction with the action parameter when action is ‘drop’:
'do': Drops NaN values only if certain conditions are met.
'do_anyway': Forces the dropping of NaN values regardless of conditions.
This provides flexibility in handling NaNs based on the specific requirements of the dataset and the analysis being performed.
condition (callable or None, optional) – A callable that defines a condition for dropping NaN values when action is ‘drop’. For example, it can specify that the number of NaNs should not exceed a certain fraction of the dataset. If the condition is not met, the behavior is controlled by the process parameter.
verbose (int, default 0) – Controls the verbosity level of the function’s output for debugging purposes:
0: No output.
1: Basic informational messages.
2: Detailed processing messages.
3: Debug-level messages with complete trace of operations.
Higher verbosity levels provide more insights into the function’s internal operations, aiding in debugging and monitoring.
- Returns:
The sanitized data structure with NaN values handled according to the specified parameters. If auxi_data is provided and processed, a tuple containing the sanitized data and auxi_data is returned. Otherwise, only the sanitized data is returned.
- Return type:
array-like, pandas.DataFrame, or pandas.Series
- Raises:
If an invalid value is provided for ops or data_kind.
If auxi_data does not align with data in shape.
If sanitization conditions are not met and the error policy is set to ‘raise’.
Warning – Emits warnings when NaN values are present and the error policy is set to ‘warn’.
Examples
>>> from geoprior.utils.data_utils import nan_ops
>>> import pandas as pd
>>> import numpy as np
>>> # Example with target data and witness feature data
>>> target = pd.Series([1, 2, np.nan, 4])
>>> features = pd.DataFrame({
...     'A': [5, np.nan, 7, 8],
...     'B': ['x', 'y', 'z', np.nan]
... })
>>> # Check for NaNs
>>> nan_ops(target, auxi_data=features, data_kind='target', ops='check_only')
(True, True)
>>> # Validate data (will raise ValueError if NaNs are present)
>>> nan_ops(target, auxi_data=features, data_kind='target', ops='validate')
Traceback (most recent call last):
    ...
ValueError: Target contains NaN values.
>>> # Sanitize data by dropping NaNs
>>> cleaned_target, cleaned_features = nan_ops(
...     target,
...     auxi_data=features,
...     data_kind='target',
...     ops='sanitize',
...     action='drop',
...     verbose=2
... )
Dropping NaN values.
Dropped NaNs successfully.
>>> cleaned_target
0    1.0
1    2.0
3    4.0
dtype: float64
>>> cleaned_features
     A    B
0  5.0    x
3  8.0  NaN
Notes
The nan_ops function is designed to provide a robust framework for handling missing values in datasets, especially in machine learning workflows where the integrity of target and feature data is paramount. By allowing conditional operations and providing flexibility in error handling, it ensures that data preprocessing can be tailored to the specific needs of the analysis.
The function leverages helper utilities such as fill_NaN, drop_nan_in, and error_policy to maintain consistency and reliability across different data structures and scenarios. The verbosity levels aid developers in tracing the function’s execution flow, making it easier to debug and verify data transformations.
See also
geoprior.utils.base_utils.fill_NaN: Fills NaN values in numeric data structures using specified methods.
geoprior.core.array_manager.drop_nan_in: Drops NaN values from data structures, optionally alongside witness data.
geoprior.core.utils.error_policy: Determines how errors are handled based on user-specified policies.
geoprior.core.array_manager.array_preserver: Preserves and restores the original structure of array-like data.
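The drop path described above can be approximated with plain pandas. The following is only a minimal sketch, assuming that rows are removed wherever the target is NaN and that the witness data is filtered with the same mask. `drop_nan_with_witness` is a hypothetical name, not a GeoPrior function, and the library may apply additional rules (for example, also inspecting NaNs in the witness data).

```python
import numpy as np
import pandas as pd


def drop_nan_with_witness(target, witness):
    """Drop rows where `target` is NaN and keep `witness` aligned.

    Hypothetical helper for illustration only; not part of geoprior.
    """
    if len(target) != len(witness):
        raise ValueError("target and witness must have the same length")
    keep = target.notna().to_numpy()  # boolean row mask
    return target[keep], witness[keep]


target = pd.Series([1, 2, np.nan, 4])
features = pd.DataFrame({"A": [5, np.nan, 7, 8], "B": ["x", "y", "z", np.nan]})
clean_target, clean_features = drop_nan_with_witness(target, features)
```

Sharing one mask between both objects is what keeps target and witness rows aligned after the drop.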
- geoprior.utils.unpack_frames_from_file(merged, *, group_col='city', output_dir=None, output_format='csv', compression=None, use_source_col=True, source_col='source', filename_pattern='{group_value}_split', drop_columns=None, keep_columns=None, save=True, return_dict=True, save_kwargs=None, verbose=1, logger)[source]
Reverse of merge_city_frames_to_file: split an aggregated NATCOM dataset into per-city frames (and optionally write them to disk).
- Parameters:
merged (path-like or DataFrame) – Aggregated dataset. If path-like, the format is inferred from the file suffix:
    .parquet → pandas.read_parquet()
    .csv → pandas.read_csv()
    .feather → pandas.read_feather()
    .pkl / .pickle → pandas.read_pickle()
If a DataFrame is passed, it is used directly.
group_col (str, optional) – Column used to split the dataset (default: 'city'). Each unique value defines one output chunk.
output_dir (path-like, optional) – Directory where per-group files are written. If None and merged is a path, the directory of merged is used. If merged is a DataFrame and output_dir is None, the current working directory is used.
output_format ({'csv', 'parquet', 'feather', 'pickle'}, optional) – Output format for per-group files. Default is 'csv'.
compression (str or None, optional) – Compression to use when writing:
    For 'csv', forwarded to DataFrame.to_csv() as the compression argument (e.g. 'gzip').
    For 'parquet', forwarded to DataFrame.to_parquet() (e.g. 'snappy', 'gzip').
    Ignored for 'feather' and 'pickle' (these use their own defaults).
use_source_col (bool, optional) – If True (default) and a column named source_col exists, the helper tries to reconstruct the original file name for each group:
    If a group has a single unique, non-null source value that looks like a filename (e.g. 'nansha_final_main_std.harmonized.csv'), that base name is used for the output (with its suffix adjusted to match output_format if needed).
    If there are multiple unique source labels within a group, it falls back to filename_pattern.
source_col (str, optional) – Name of the column containing the source label (default: 'source'). This should match the column created in merge_frames_to_file() when add_source_label=True.
filename_pattern (str, optional) – Pattern used when no suitable source label is available. The following placeholders are supported:
    {group_value}: the group value as a string
    {group_col}: the name of the grouping column
Example: filename_pattern="{group_col}_{group_value}_data" → "city_Nansha_data.csv".
drop_columns (iterable of str, optional) – Columns to drop from each group before saving/returning (e.g. ['source'] if you don’t want the bookkeeping column).
keep_columns (iterable of str, optional) – If provided, only these columns are kept (all others are dropped after any drop_columns processing is applied).
save (bool, optional) – If True (default), write each group to disk as a separate file. If False, no files are written; only the dict of DataFrames is returned (if return_dict=True).
return_dict (bool, optional) – If True (default), return a mapping {group_value: group_df}. If False, an empty dict is returned (useful when you only care about side-effect files).
save_kwargs (dict, optional) – Extra keyword arguments forwarded to the respective writer: DataFrame.to_csv(), DataFrame.to_parquet(), DataFrame.to_feather(), or DataFrame.to_pickle().
verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints progress information.
logger (None)
- Returns:
out – Dictionary mapping each group value to the corresponding DataFrame. Empty if return_dict=False.
- Return type:
- Raises:
ValueError – If group_col is not present in the merged dataset.
Examples
>>> from geoprior.utils.geo_utils import unpack_frames_from_file
>>> unpack_frames_from_file(
...     "natcom_all_cities.parquet",
...     group_col="city",
...     output_format="csv",
... )
# -> writes e.g. 'nansha_final_main_std.harmonized.csv',
#    'zhongshan_final_main_std.harmonized.csv' (if `source` labels exist),
#    and returns a dict: {'Nansha': df_nansha, 'Zhongshan': df_zhongshan}
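The naming rules for per-group files can be illustrated without touching the disk. This is a rough sketch under stated assumptions: `split_by_group` is a hypothetical helper, it does not check whether a source label "looks like a filename", and it only covers the single-label and fallback cases described in the parameter list.

```python
from pathlib import Path

import pandas as pd


def split_by_group(merged, group_col="city", source_col="source",
                   filename_pattern="{group_value}_split", suffix=".csv"):
    """Return {group_value: (filename, frame)} without writing anything.

    Hypothetical helper; it mirrors the documented naming rules only.
    """
    if group_col not in merged.columns:
        raise ValueError(f"{group_col!r} not found in merged dataset")
    out = {}
    for value, frame in merged.groupby(group_col, sort=False):
        name = None
        if source_col in frame.columns:
            labels = frame[source_col].dropna().unique()
            if len(labels) == 1:
                # single source label: reuse its base name, swap the suffix
                name = Path(str(labels[0])).with_suffix(suffix).name
        if name is None:
            # no label or several labels: fall back to the pattern
            name = filename_pattern.format(
                group_value=value, group_col=group_col) + suffix
        out[value] = (name, frame.reset_index(drop=True))
    return out


merged = pd.DataFrame({
    "city": ["Nansha", "Nansha", "Zhongshan"],
    "source": ["nansha.parquet", "nansha.parquet", None],
    "value": [1.0, 2.0, 3.0],
})
parts = split_by_group(merged)
```

Here the Nansha group reuses its source label with the suffix adjusted, while the Zhongshan group has no usable label and falls back to the pattern.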
- geoprior.utils.widen_temporal_columns(data, dt_col, spatial_cols=None, target_name=None, round_dt=True, ignore_cols=None, nan_op=None, nan_thresh=None, savefile=None, verbose=0)[source]
Convert a long PIHALNet prediction table into a wide format where each temporal slice becomes a dedicated column.
The routine pivots columns whose names follow the pattern
    <base>          deterministic forecast
    <base>_qXX      quantile forecast (e.g., subsidence_q10)
    <base>_actual   ground‑truth column
and produces columns of the form
    <base>_<year>          point forecast
    <base>_<year>_qXX      quantile forecast
    <base>_<year>_actual   ground‑truth value
If duplicate (spatial, year) pairs are found, values are aggregated with a pandas groupby mean prior to pivoting to avoid “Index contains duplicate entries” errors.
- Parameters:
data (PathLike object or pandas.DataFrame) – Long‑format DataFrame returned by geoprior.utils.format_pihalnet_predictions.
dt_col (str) – Column holding the temporal coordinate (e.g., 'coord_t'). Must be numeric or datetime‑coercible. When round_dt is True, values are rounded to integers.
spatial_cols ((str, str) or None, default None) – Names of x and y spatial coordinates. These are retained as leading columns in the output. If None, the function falls back to 'sample_idx' or an auto‑generated 'row_id'.
target_name (str or None, default None) – Restrict pivoting to a specific base (e.g., 'subsidence'). When None every base present in df is widened.
round_dt (bool, default True) – Round dt_col to the nearest integer (helpful for fractional years such as 2020.0001).
ignore_cols (list[str] or None, default None) – Additional columns to carry through unchanged. Values are propagated per spatial location using the first non‑null entry.
nan_op ({'drop', 'fill', 'both', None}, default None) – Strategy for NaN handling after pivot:
    'fill' – forward‑fill then back‑fill missing values.
    'drop' – drop rows containing NaNs (see nan_thresh).
    'both' – fill then drop according to nan_thresh.
    None – leave NaNs untouched.
nan_thresh (float or None, default None) – When nan_op contains 'drop', rows are dropped if the proportion of missing values exceeds nan_thresh. Set nan_thresh = 0 to require no NaNs, 0.5 to allow ≤ 50 % missing, etc.
\[\text{row kept} \;\Longleftrightarrow\; \frac{\text{NaNs in row}}{\text{row width}} \le \text{nan\_thresh} \tag{4}\]
savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file.
verbose (int, default 0) – Diagnostic verbosity from 0 (silent) to 5 (trace every step).
- Returns:
Wide‑format frame with spatial identifiers first, followed by year‑wise forecast, quantile, and actual columns.
- Return type:
- Raises:
KeyError – dt_col missing from df or spatial_cols absent.
ValueError – No columns match target_name or nan_thresh is outside \([0, 1]\).
Notes
Duplicate indices are aggregated with the arithmetic mean before pivoting. Modify the aggregation lambda inside the function for alternative choices.
If ignore_cols is provided, their first non‑null value per spatial location is appended to the output.
Examples
Minimal usage on a tiny synthetic set
>>> import pandas as pd
>>> from geoprior.utils.data_utils import widen_temporal_columns
>>>
>>> df_long = pd.DataFrame(
...     {
...         "coord_x": [113.15, 113.15, 113.15, 113.15],
...         "coord_y": [22.63, 22.63, 22.63, 22.63],
...         "coord_t": [2019, 2020, 2019, 2020],
...         "subsidence_q50": [0.09, 0.10, 0.12, 0.13],
...         "subsidence_actual": [0.08, 0.11, 0.10, 0.14],
...     }
... )
>>>
>>> wide = widen_temporal_columns(
...     df_long,
...     dt_col="coord_t",
...     spatial_cols=("coord_x", "coord_y"),
...     verbose=2,
... )
[INFO] Initial rows: 4, columns: 2
[INFO] Widening base 'subsidence' (2 columns)
[DONE] Final wide shape: (1, 4)
>>> wide
   coord_x  coord_y  subsidence_2019_actual  subsidence_2020_actual  \
0   113.15    22.63                    0.08                    0.11

   subsidence_2019_q50  subsidence_2020_q50
0                 0.12                 0.13
End‑to‑end example with NaN handling, ignored columns, and two targets
>>> import numpy as np
>>> rng = pd.date_range("2018", periods=3, freq="Y").year
>>> n = 5  # five spatial locations
>>>
>>> # build synthetic long DataFrame
>>> df_long = pd.DataFrame(
...     {
...         "sample_idx": np.repeat(np.arange(n), len(rng)),
...         "coord_x": np.repeat(np.linspace(113.4, 113.5, n), len(rng)),
...         "coord_y": np.repeat(np.linspace(22.1, 22.2, n), len(rng)),
...         "coord_t": np.tile(rng, n),
...         "region": np.repeat(["A", "B", "A", "B", "A"], len(rng)),
...         "subsidence_q10": np.random.rand(n * len(rng)),
...         "subsidence_q50": np.random.rand(n * len(rng)),
...         "subsidence_q90": np.random.rand(n * len(rng)),
...         "subsidence_actual": np.random.rand(n * len(rng)),
...         "GWL_q50": np.random.rand(n * len(rng)),
...     }
... )
>>>
>>> # introduce NaNs for demonstration
>>> df_long.loc[df_long.sample(frac=0.2).index, "subsidence_q50"] = np.nan
>>>
>>> wide = widen_temporal_columns(
...     df_long,
...     dt_col="coord_t",
...     spatial_cols=("coord_x", "coord_y"),
...     ignore_cols=["region"],
...     target_name=None,   # widen both 'subsidence' and 'GWL'
...     nan_op="both",      # fill then drop rows with many NaNs
...     nan_thresh=0.4,     # allow at most 40 % missing
...     verbose=3,
... )
[INFO] Initial rows: 15, columns: 7
[INFO] Widening base 'GWL' (1 columns)
  └─ 0 duplicate rows in 'GWL_q50' → aggregated
[INFO] Widening base 'subsidence' (4 columns)
  └─ 0 duplicate rows in 'subsidence_q10' → aggregated
  └─ 0 duplicate rows in 'subsidence_q50' → aggregated
  └─ 0 duplicate rows in 'subsidence_q90' → aggregated
  └─ 0 duplicate rows in 'subsidence_actual' → aggregated
[INFO] Missing values filled (ffill+bfill).
[INFO] Rows with >40% NaN dropped.
[DONE] Final wide shape: (5, 19)
>>> wide.iloc[:2, :8]  # show first 8 columns
   coord_x  coord_y  GWL_2018_q50  GWL_2019_q50  GWL_2020_q50  \
0  113.400      ...           ...           ...           ...
1  113.425      ...           ...           ...           ...

   subsidence_2018_actual  subsidence_2019_actual  subsidence_2020_actual
0                     ...                     ...                     ...
1                     ...                     ...                     ...
See also
pandas.DataFrame.unstack: Core pivoting method used internally.
geoprior.plot.forecast.forecast_view: Visualisation routine that consumes the resulting wide frame.
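The widening steps described for this function (mean aggregation of duplicates, unstacking the time level, renaming, and the nan_thresh row filter) can be sketched with plain pandas. This is an illustrative approximation, not the library code: `widen_sketch` is a hypothetical name, and it assumes the base name contains no underscore so that a column such as subsidence_q50 splits into base and suffix at the first underscore.

```python
import pandas as pd


def widen_sketch(df, dt_col, spatial_cols, value_cols, nan_thresh=None):
    """Pivot long -> wide with mean aggregation of duplicates.

    Hypothetical helper; column naming follows the documented pattern.
    """
    keys = list(spatial_cols) + [dt_col]
    work = df.copy()
    work[dt_col] = work[dt_col].round().astype(int)
    # mean over duplicate (spatial, year) pairs avoids pivot errors
    agg = work.groupby(keys)[list(value_cols)].mean()
    wide = agg.unstack(dt_col)
    # flatten ('subsidence_q50', 2019) -> 'subsidence_2019_q50'
    names = []
    for col, year in wide.columns:
        base, _, suffix = col.partition("_")
        names.append(f"{base}_{year}_{suffix}" if suffix else f"{base}_{year}")
    wide.columns = names
    if nan_thresh is not None:
        # keep a row only if its NaN fraction is <= nan_thresh
        frac = wide.isna().mean(axis=1)
        wide = wide[frac <= nan_thresh]
    return wide.reset_index()


df_long = pd.DataFrame({
    "coord_x": [113.15] * 4,
    "coord_y": [22.63] * 4,
    "coord_t": [2019, 2020, 2019, 2020],
    "subsidence_q50": [0.09, 0.10, 0.12, 0.13],
    "subsidence_actual": [0.08, 0.11, 0.10, 0.14],
})
wide = widen_sketch(df_long, "coord_t", ("coord_x", "coord_y"),
                    ["subsidence_q50", "subsidence_actual"])
```

In this sketch the two duplicate rows per year are averaged, so the single output row holds one mean value per year and column.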
- geoprior.utils.pivot_forecast_dataframe(data, id_vars, time_col, value_prefixes, static_actuals_cols=None, time_col_is_float_year='auto', round_time_col=False, verbose=0, savefile=None, _logger=None, **kws)[source]
Transforms a long-format forecast DataFrame to a wide format.
This utility reshapes time series prediction data from a “long” format, where each row represents a single time step for a given sample, to a “wide” format, where each row represents a single sample and columns correspond to values at different time steps.
- Parameters:
data (pd.DataFrame) – The input long-format DataFrame. It must contain the columns specified in id_vars and time_col, as well as value columns that start with the strings in value_prefixes.
id_vars (list of str) – A list of column names that uniquely identify each sample or group. These columns will be preserved in the wide-format output. For example: ['sample_idx', 'coord_x', 'coord_y'].
time_col (str) – The name of the column that represents the time step or year of the forecast (e.g., ‘coord_t’ or ‘forecast_step’). This column’s values will become part of the new column names.
value_prefixes (list of str) – A list of prefixes for the value columns that need to be pivoted. The function identifies columns starting with these prefixes. For instance, ['subsidence', 'GWL'] would match ‘subsidence_q10’, ‘GWL_q50’, etc.
static_actuals_cols (list of str, optional) – A list of columns containing static “actual” or ground truth values for each sample. These values are assumed to be constant for each unique sample_idx and are merged back into the wide DataFrame after pivoting. Example: ['subsidence_actual'].
time_col_is_float_year (bool or 'auto', default 'auto') – Controls how the time_col values are formatted into new column names.
    If 'auto', automatically detects if time_col has a float dtype.
    If True, treats time_col values (e.g., 2018.0) as years and converts them to integer strings (‘2018’).
    If False, uses the string representation of the value as is.
round_time_col (bool, default False) – If True and time_col is a float type, its values will be rounded to the nearest integer before being used in column names. This is useful for cleaning up float years (e.g., 2018.0001 -> 2018).
verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent. Higher values print more details about the process.
savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file to that location.
- Returns:
A wide-format DataFrame with one row per unique combination of id_vars. New columns are created in the format {prefix}_{time_str}{_suffix} (e.g., ‘subsidence_2018_q10’).
- Return type:
pd.DataFrame
See also
pandas.pivot_table: The core function used for reshaping data.
pandas.merge: Used to re-join static columns after pivoting.
Notes
The combination of columns in id_vars and time_col must uniquely identify each row in df_long for the pivot to succeed without data loss.
If using static_actuals_cols, the id_vars list must contain ‘sample_idx’ to correctly merge the static data back.
Examples
>>> import pandas as pd
>>> from geoprior.utils.data_utils import pivot_forecast_dataframe
>>> data = {
...     'sample_idx': [0, 0, 1, 1],
...     'coord_t': [2018.0, 2019.0, 2018.0, 2019.0],
...     'coord_x': [0.1, 0.1, 0.5, 0.5],
...     'coord_y': [0.2, 0.2, 0.6, 0.6],
...     'subsidence_q50': [-8, -9, -13, -14],
...     'subsidence_actual': [-8.5, -8.5, -13.2, -13.2],
...     'GWL_q50': [1.2, 1.3, 2.2, 2.3],
... }
>>> df_long_example = pd.DataFrame(data)
>>> df_wide = pivot_forecast_dataframe(
...     data=df_long_example,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     time_col='coord_t',
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual'],
...     verbose=0
... )
>>> print(df_wide.columns)
Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
       'GWL_2018_q50', 'GWL_2019_q50', 'subsidence_2018_q50',
       'subsidence_2019_q50'],
      dtype='object')
- geoprior.utils.fetch_joblib_data(job_file, *keys, error_mode='raise', verbose=0)[source]
Dynamically load data from a joblib-saved dictionary with flexible key access.
- Parameters:
job_file (str) – Path to the joblib file containing a dictionary.
*keys (str) – Variable-length list of dictionary keys to retrieve.
error_mode ({'raise', 'warn', 'ignore'}, default 'raise') – Handling of missing keys:
    ‘raise’: Immediately raise KeyError
    ‘warn’: Issue warning and skip missing keys
    ‘ignore’: Silently skip missing keys
verbose (int, default 0) – Verbosity level:
    0: No output
    1: Basic loading information
    2: Detailed debugging output
- Returns:
Full dictionary if no keys specified
Tuple of values for requested keys (maintaining order)
- Return type:
Union[Dict, Tuple]
- Raises:
FileNotFoundError – If specified job_file doesn’t exist
TypeError – If loaded data isn’t a dictionary
KeyError – If requested key not found and error_mode=’raise’
Examples
>>> from geoprior.utils.io_utils import fetch_joblib_data
>>> data = fetch_joblib_data('data.joblib', 'X_train', 'y_train')
>>> X, y = fetch_joblib_data('data.joblib', 'X_val', 'y_val', verbose=1)
>>> full_dict = fetch_joblib_data('data.joblib')
Notes
Maintains original insertion order for Python 3.7+ dictionaries
Missing keys in ‘warn’/’ignore’ modes result in shorter return tuple
Joblib files must contain dictionary objects
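The key-selection behaviour described in the parameters and notes can be sketched on a plain dictionary. The loading step is left out on purpose, and `select_keys` is a hypothetical name used only for illustration; it is an assumption about how the three error modes combine, not the library implementation.

```python
import warnings


def select_keys(payload, *keys, error_mode="raise"):
    """Pick values from a loaded dictionary following the documented rules.

    Hypothetical helper; the joblib loading step is intentionally omitted.
    """
    if not isinstance(payload, dict):
        raise TypeError("loaded data must be a dictionary")
    if error_mode not in {"raise", "warn", "ignore"}:
        raise ValueError(f"unknown error_mode: {error_mode!r}")
    if not keys:
        return payload  # no keys requested: return the full dictionary
    found = []
    for key in keys:
        if key in payload:
            found.append(payload[key])
        elif error_mode == "raise":
            raise KeyError(key)
        elif error_mode == "warn":
            warnings.warn(f"key {key!r} not found; skipping")
        # 'ignore': skip silently
    return tuple(found)


store = {"X_train": [1, 2], "y_train": [0, 1]}
pair = select_keys(store, "X_train", "y_train")
short = select_keys(store, "X_train", "missing", error_mode="ignore")
```

Because missing keys are skipped in the ‘warn’ and ‘ignore’ modes, tuple unpacking such as `X, y = ...` can fail when a key is absent; checking the tuple length first is a safer pattern.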
- geoprior.utils.save_job(job, savefile, *, protocol=None, append_versions=True, append_date=True, fix_imports=True, buffer_callback=None, **job_kws)[source]
Quickly save a job using joblib or Python’s built-in pickle module.
- Parameters:
job (Any) – Any object to save, preferably models stored in a dict.
savefile (str, or path-like object) – Name of the file used to store the model. The file argument must have a write() method that accepts a single bytes argument. It can thus be a file object opened for binary writing, an io.BytesIO instance, or any other custom object that meets this interface.
append_versions (bool, default=True) – Append the version of the joblib or pickle module, followed by the scikit-learn, numpy, and pandas versions. This records which versions were in use when the file was written, which helps when loading it after the system or the modules have been upgraded and the original save date and versions are no longer remembered.
append_date (bool, default True) – Append the current date to the filename.
protocol (int, optional) – The optional protocol argument tells the pickler to use the given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. The default protocol is 4. It was introduced in Python 3.4, and is incompatible with previous versions.
Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the more recent the version of Python needed to read the pickle produced.
fix_imports (bool, default True) – If fix_imports is True and protocol is less than 3, pickle will try to map the new Python 3 names to the old module names used in Python 2, so that the pickle data stream is readable with Python 2.
buffer_callback (callable, optional) – If buffer_callback is None (the default), buffer views are serialized into file as part of the pickle stream.
If buffer_callback is not None, then it can be called any number of times with a buffer view. If the callback returns a false value (such as None), the given buffer is out-of-band; otherwise the buffer is serialized in-band, i.e. inside the pickle stream.
It is an error if buffer_callback is not None and protocol is None or smaller than 5.
job_kws (dict) – Additional keyword arguments passed to joblib.dump().
- Returns:
The final filename where the job was saved.
- Return type:
Notes
This function appends system-specific metadata like versions and date to the filename, which can aid in tracking compatibility over time.
Examples
>>> from geoprior.utils.io_utils import save_job
>>> model = {"key": "value"}  # Replace with actual model object
>>> savefile = save_job(model, "my_model", append_date=True, append_versions=True)
>>> print(savefile)
my_model.20240101.sklearn_v1.0.numpy_v1.21.joblib
- geoprior.utils.normalize_time_column(df, time_col, datetime_col='datetime_temp', year_col='year_int', drop_orig=False)[source]
Normalize a time column into a datetime column and an integer year.
The input column may contain integer years, strings, or existing pandas Datetime values. The function creates datetime_col with parsed timestamps and year_col with the extracted integer year. When drop_orig=True, the original time_col is removed and datetime_col is renamed back to time_col.
- Parameters:
df (pandas.DataFrame) – Input DataFrame containing a time column named time_col.
time_col (str) – Name of the column to normalize.
datetime_col (str, default 'datetime_temp') – Name of the parsed datetime column.
year_col (str, default 'year_int') – Name of the extracted integer year column.
drop_orig (bool, default False) – If True, drop the original time_col after parsing and rename datetime_col back to time_col.
- Returns:
A copy of df with the parsed datetime column and integer year column.
- Return type:
- Raises:
ValueError – If time_col is missing or parsing fails for any entry.
TypeError – If df is not a pandas DataFrame.
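A minimal pandas sketch of this normalization is shown below. It is an assumption-laden illustration, not the library code: `normalize_time_sketch` is a hypothetical name, integer values are assumed to be calendar years, and the drop_orig option is omitted.

```python
import pandas as pd


def normalize_time_sketch(df, time_col, datetime_col="datetime_temp",
                          year_col="year_int"):
    """Parse `time_col` into a datetime column plus an integer year.

    Hypothetical helper; assumes integer-like values are calendar years.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame")
    if time_col not in df.columns:
        raise ValueError(f"{time_col!r} not found")
    out = df.copy()
    col = out[time_col]
    if pd.api.types.is_integer_dtype(col):
        # 2019 -> '2019' so it is read as a year, not as nanoseconds
        parsed = pd.to_datetime(col.astype(str), format="%Y")
    else:
        parsed = pd.to_datetime(col)  # raises if an entry cannot be parsed
    out[datetime_col] = parsed
    out[year_col] = parsed.dt.year.astype(int)
    return out


years = normalize_time_sketch(pd.DataFrame({"t": [2019, 2020]}), "t")
dates = normalize_time_sketch(
    pd.DataFrame({"t": ["2021-03-01", "2022-07-15"]}), "t")
```

The explicit string conversion matters because passing raw integers to `pandas.to_datetime` interprets them as offsets from the epoch rather than as years.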
- geoprior.utils.convert_eval_payload_units(payload, cfg=None, *, mode='si', scope='all', savefile=None, fmt='json', indent=2, copy_payload=True)[source]
Convert GeoPriorSubsNet evaluation-payload units for reporting.
This is a post-processing helper meant for stage-2 evaluation JSON payloads (e.g. geoprior_eval_phys_<timestamp>.json).
- Parameters:
payload (Mapping[str, Any]) – The evaluation payload dict. It is expected to contain sections like metrics_evaluate, point_metrics, per_horizon, interval_calibration and censor_stratified.
cfg (mapping or module, optional) – The experiment config (e.g. config module or globals()). The helper reads SUBS_UNIT_TO_SI (or stage-1 provenance) and TIME_UNITS from this object when available.
mode ({"si", "interpretable"}, default "si") – "si" leaves values untouched. "interpretable" converts selected subsidence and physics metrics from SI into the native units implied by SUBS_UNIT_TO_SI.
scope ({"all", "subsidence", "physics"}, default "all") – Which parts to convert when mode="interpretable".
    "subsidence" converts only subsidence metrics such as MAE, MSE, and sharpness to the native unit.
    "physics" converts only unambiguous physics residual rates, currently epsilon_cons_raw and epsilon_gw_raw.
    "all" applies both conversions.
savefile (str, optional) – If provided, write the converted payload to this path.
fmt ({"json"}, default "json") – Output format when savefile is provided.
indent (int, default 2) – JSON indentation.
copy_payload (bool, default True) – If True, operate on a deep copy of payload. If False, convert in-place (dangerous).
- Returns:
Converted payload as a plain dict.
- Return type:
Notes
For subsidence metrics, linear quantities such as MAE and sharpness scale by 1 / SUBS_UNIT_TO_SI, while squared quantities such as MSE scale by (1 / SUBS_UNIT_TO_SI) ** 2.
When physics conversion is enabled, epsilon_cons_raw is treated as a rate in m/s and converted to <subs_native_unit>/<TIME_UNITS> (for example mm/yr), while epsilon_gw_raw is treated as a rate in 1/s and converted to 1/<TIME_UNITS>.
The helper records unit provenance under payload["units"].
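The scaling rules in the notes above reduce to a few multiplications. The sketch below works through them for a native unit of millimetres and a time unit of years. The function names are hypothetical, and the length of a year (365.25 days) is an assumption made for this example; the library may use a different constant.

```python
SECONDS_PER_YEAR = 365.25 * 24 * 3600  # assumption: Julian year


def to_native(value, subs_unit_to_si, power=1):
    """Scale an SI metric to native units; power=2 for squared metrics."""
    return value * (1.0 / subs_unit_to_si) ** power


def length_rate_to_native(rate_m_per_s, subs_unit_to_si,
                          seconds_per_time_unit=SECONDS_PER_YEAR):
    """Convert a rate in m/s to <native length unit>/<time unit>."""
    return rate_m_per_s / subs_unit_to_si * seconds_per_time_unit


def inverse_time_rate_to_native(rate_per_s,
                                seconds_per_time_unit=SECONDS_PER_YEAR):
    """Convert a rate in 1/s to 1/<time unit>."""
    return rate_per_s * seconds_per_time_unit


subs_unit_to_si = 1e-3                                  # native unit is mm
mae_mm = to_native(0.004, subs_unit_to_si)              # MAE of 0.004 m
mse_mm2 = to_native(1.6e-5, subs_unit_to_si, power=2)   # MSE of 1.6e-5 m^2
eps_cons_mm_yr = length_rate_to_native(1e-10, subs_unit_to_si)
eps_gw_per_yr = inverse_time_rate_to_native(1e-9)
```

Under these assumptions the MAE becomes about 4 mm, the MSE about 16 mm², and a residual of 1e-10 m/s corresponds to roughly 3.16 mm/yr.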
- geoprior.utils.postprocess_eval_json(eval_json, *, cfg=None, scope='all', out_path=None, overwrite=False, add_rmse=True, force=False, indent=2)[source]
Post-hoc convert a Stage-2 evaluation JSON from SI to interpretable units.
This is a safe wrapper around convert_eval_payload_units(…) that:
    loads a JSON from disk (or accepts a payload dict),
    infers unit factors from payload[“units”] when cfg is missing,
    avoids double conversion unless force=True,
    optionally adds RMSE fields (sqrt(MSE)),
    writes a converted JSON file if out_path is provided.
- Parameters:
eval_json (str | Mapping[str, Any]) – Either a file path to the saved JSON, or an in-memory mapping.
cfg (Mapping[str, object] | None) – Optional config mapping/module. If missing (or incomplete), this helper will synthesize the minimal keys needed for conversion: - “SUBS_UNIT_TO_SI” - “TIME_UNITS”
scope (Literal['all', 'subsidence', 'physics']) – Forwarded to convert_eval_payload_units(…).
out_path (str | None) – If given, write the converted payload there. If a directory is given, a filename is generated next to the input name (or “geoprior_eval…”).
overwrite (bool) – If False and out_path exists, raise.
add_rmse (bool) – If True, add RMSE fields wherever MSE is present (metrics_evaluate, point_metrics, per_horizon).
force (bool) – If False, skip conversion when payload already declares an interpretable subsidence unit (e.g., “mm”). If True, always convert.
indent (int) – JSON indent used when writing.
- Returns:
The converted payload (always returned as a dict).
- Return type:
- geoprior.utils.evaluate_point_forecast(model, out, y_true_subs, *, y_true_gwl=None, n_q=3, quantiles=None, q=None, use_physical=False, return_physical=None, scaler_info=None, subs_target_name='subsidence', gwl_target_name='gwl', scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None, strict=True, output_names=None)[source]
End-to-end helper for point-forecast evaluation.
Pipeline: extract predictions from model and out, canonicalize BHQO quantile outputs when needed, pick a point prediction, optionally inverse-scale it, and compute global and per-horizon metrics.
- Parameters:
model (Any) – Passed to extract_preds(…).
out (Any) – Passed to extract_preds(…).
y_true_subs (Any) – True subsidence, shape (B, H, 1) or (B, H).
y_true_gwl (Any | None) – Optional true gwl/head, same shape conventions.
n_q (int), quantiles (Sequence[float] | None), q (float | int | None) – Quantile-selection controls for (B, H, Q, 1) outputs. If q is None, the median is preferred. If q is an integer, it is treated as a direct quantile index. If q is a float and quantiles is provided, the nearest quantile is selected; otherwise the float is treated as a fraction in [0, 1].
use_physical (bool) – If True, compute metrics in physical units.
return_physical (bool | None) – If None, defaults to use_physical. If True, return physical-space arrays such as subs_pred_phys and gwl_pred_phys when possible.
scaler_info (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).
scaler_entry (Mapping[str, Any] | None) – Passed through to inverse_scale_target(...).
scaler (Any | None) – Passed through to inverse_scale_target(...).
feature_index (int | None) – Passed through to inverse_scale_target(...).
n_features (int | None) – Passed through to inverse_scale_target(...).
params (Mapping[str, float] | None) – Passed through to inverse_scale_target(...).
subs_target_name (str)
gwl_target_name (str)
strict (bool)
- Returns:
Dictionary containing model-space predictions, optional physical-space predictions, and global and per-horizon metrics for subsidence and groundwater outputs.
- Return type:
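The quantile-selection rules for n_q, quantiles, and q can be sketched as a small index resolver. `pick_quantile_index` is a hypothetical helper, and two details are assumptions made for this sketch: without a quantile list the median is taken as the middle slot of the Q axis, and a bare float is mapped to the index round(q * (n_q - 1)).

```python
import numpy as np


def pick_quantile_index(n_q, q=None, quantiles=None):
    """Resolve which slice of the Q axis serves as the point forecast.

    Hypothetical helper that follows the documented rules for `q`.
    """
    if q is None:
        if quantiles is not None:
            # prefer the quantile closest to the median
            return int(np.argmin(np.abs(np.asarray(quantiles) - 0.5)))
        return n_q // 2  # assumption: middle slot stands in for the median
    if isinstance(q, (int, np.integer)) and not isinstance(q, bool):
        return int(q)  # direct index
    if quantiles is not None:
        return int(np.argmin(np.abs(np.asarray(quantiles) - q)))
    # float without a quantile list: treat as a fraction of the Q axis
    return int(round(float(q) * (n_q - 1)))


y_pred = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3, 1)  # (B, H, Q, 1)
idx = pick_quantile_index(3, q=0.9, quantiles=(0.1, 0.5, 0.9))
point = y_pred[:, :, idx, :]  # (B, H, 1)
```

Resolving the index once and slicing along axis 2 keeps the point forecast in the same (B, H, 1) shape as the documented target arrays.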
- geoprior.utils.inverse_scale_target(y_scaled, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]
Inverse-transform a scaled target array back to physical units.
Supports three patterns:
    Stage-1 scaler_info dict (preferred in NATCOM): pass scaler_info=scaler_info_dict, target_name="subsidence".
    A bare scaler instance or a path to a joblib dump via scaler=.... If multi-feature, also pass feature_index and optionally n_features.
    Manual scaling parameters via params such as {"min", "max"}, {"mean", "std"}, or {"scale", "shift"}.
- Parameters:
y_scaled (array-like) – Scaled values, e.g. of shape (N, H, 1) or (N,).
scaler_info (mapping, optional)
target_name (str, optional)
scaler_entry (mapping, optional)
feature_index (int, optional)
n_features (int, optional)
params (mapping, optional) – Manual parameters:
    {"min", "max"} → MinMax scaling
    {"mean", "std"} → standardization
    {"scale", "shift"} → x = scale * x_scaled + shift.
- Returns:
Array with same shape as y_scaled but in physical units when scaling information is available. If nothing usable is found, returns the input as a NumPy array.
- Return type:
np.ndarray
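The manual-parameter pattern is the simplest of the three and can be sketched with NumPy alone. `inverse_from_params` is a hypothetical helper; it assumes that the {"min", "max"} form refers to MinMax scaling onto the [0, 1] range, which the description above does not state explicitly.

```python
import numpy as np


def inverse_from_params(y_scaled, params):
    """Undo a simple affine scaling described by `params`.

    Hypothetical helper covering only the manual-parameter pattern.
    """
    y = np.asarray(y_scaled, dtype=float)
    keys = set(params)
    if {"min", "max"} <= keys:
        # assumes MinMax scaling onto the [0, 1] range
        return y * (params["max"] - params["min"]) + params["min"]
    if {"mean", "std"} <= keys:
        return y * params["std"] + params["mean"]
    if {"scale", "shift"} <= keys:
        return params["scale"] * y + params["shift"]
    return y  # nothing usable: return the input as an array


from_minmax = inverse_from_params([0.0, 0.5, 1.0], {"min": 10.0, "max": 20.0})
from_standard = inverse_from_params([0.0, 1.0], {"mean": 2.0, "std": 3.0})
```

All three forms are affine maps, so the output keeps the shape of the input array.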
- geoprior.utils.deg_to_m_from_lat(lat_deg)[source]
Approximate WGS84 meters per degree of longitude and latitude at a reference latitude lat_deg (in degrees). Returns (deg_to_m_lon, deg_to_m_lat).
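For orientation, the sketch below computes meters per degree from a commonly used truncated series for the WGS84 ellipsoid. It is a stand-in under stated assumptions: `meters_per_degree` is a hypothetical name, and the formula and coefficients used by the library are not given in this reference and may differ.

```python
import math


def meters_per_degree(lat_deg):
    """Approximate WGS84 meters per degree at latitude `lat_deg`.

    Returns (m_per_deg_lon, m_per_deg_lat). Hypothetical stand-in that uses
    a common truncated series; the library may use a different formula.
    """
    phi = math.radians(lat_deg)
    m_lat = (111132.92 - 559.82 * math.cos(2 * phi)
             + 1.175 * math.cos(4 * phi) - 0.0023 * math.cos(6 * phi))
    m_lon = (111412.84 * math.cos(phi) - 93.5 * math.cos(3 * phi)
             + 0.118 * math.cos(5 * phi))
    return m_lon, m_lat


lon_eq, lat_eq = meters_per_degree(0.0)
lon_60, lat_60 = meters_per_degree(60.0)
```

A degree of longitude shrinks roughly with the cosine of latitude, while a degree of latitude changes only slightly, which is why the function returns the two factors separately.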
- geoprior.utils.canonicalize_BHQO(y_pred, *, y_true=None, q_values=(0.1, 0.5, 0.9), n_q=None, layout=None, enforce_monotone=True, return_layout=False, verbose=0, log_fn=<built-in function print>)[source]
Canonicalize quantile outputs to (B, H, Q, O).
- Supported layouts (rank-4):
BHQO: (B, H, Q, O) -> unchanged
BQHO: (B, Q, H, O) -> transpose(0, 2, 1, 3)
BHOQ: (B, H, O, Q) -> transpose(0, 1, 3, 2)
If ambiguous (e.g., H == Q), and y_true is given, pick the transform with smallest MAE for q50.
- If y_true is not given, fallback is:
use layout if provided
else prefer BHQO if plausible
else pick by min crossing score
- Parameters:
y_pred (Any) – Quantile tensor, NumPy array or TF tensor.
y_true (Any | None) – Target tensor (B, H, O) or (B, H, 1). Used only to resolve ambiguity robustly.
q_values (Sequence[float]) – Quantiles in order, e.g. (0.1, 0.5, 0.9).
n_q (int | None) – Number of quantiles. Defaults to len(q_values).
layout (str | None) – Force interpretation: “BHQO”, “BQHO”, “BHOQ”. Use “auto” (or None) to infer.
enforce_monotone (bool) – Sort along Q axis after canonicalization.
return_layout (bool) – If True, return (arr, chosen_layout).
verbose (int) – Logging controls.
- Returns:
Canonical (B, H, Q, O) and optionally the layout.
- Return type:
arr or (arr, layout)
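When the layout is already known, canonicalization reduces to the transposes listed above plus an optional sort along the quantile axis. The sketch below covers only that case. `to_bhqo` is a hypothetical helper, and layout inference (the MAE-based and crossing-score heuristics) is deliberately left out.

```python
import numpy as np

# axis permutations that map each supported layout onto (B, H, Q, O)
_TO_BHQO = {
    "BHQO": (0, 1, 2, 3),
    "BQHO": (0, 2, 1, 3),
    "BHOQ": (0, 1, 3, 2),
}


def to_bhqo(y_pred, layout, enforce_monotone=True):
    """Transpose a rank-4 quantile array to (B, H, Q, O).

    Hypothetical helper; layout inference is deliberately left out.
    """
    arr = np.asarray(y_pred)
    if arr.ndim != 4:
        raise ValueError("expected a rank-4 array")
    out = np.transpose(arr, _TO_BHQO[layout])
    if enforce_monotone:
        out = np.sort(out, axis=2)  # non-decreasing along the Q axis
    return out


rng = np.random.default_rng(0)
raw = rng.normal(size=(8, 3, 5, 1))   # (B, Q, H, O)
canon = to_bhqo(raw, layout="BQHO")   # (B, H, Q, O)
```

Sorting along axis 2 is one simple way to remove quantile crossing; it changes values only where the quantiles were out of order.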
- geoprior.utils.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]
Fit and apply post-hoc interval calibration for quantile forecasts.
This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can
    detect whether evaluation forecasts already appear calibrated,
    fit interval-width correction factors from evaluation data,
    apply those factors to evaluation and/or future forecasts,
    compute before/after summary diagnostics on the evaluation set,
    optionally save the outputs to disk.
The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.
Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].
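The widening or narrowing idea can be sketched as a bisection on a single factor that rescales both interval halves around the median. This is an illustrative approximation under stated assumptions: `fit_width_factor` and `coverage` are hypothetical names, a single overall factor is fitted rather than one per forecast step, and the fitting rule used by the library may differ in detail.

```python
import numpy as np


def coverage(y, lo, hi):
    """Fraction of observations that fall inside [lo, hi]."""
    return float(np.mean((y >= lo) & (y <= hi)))


def fit_width_factor(y, lo, med, hi, target=0.8, f_max=5.0, max_iter=32):
    """Bisection on a factor f that rescales the interval around `med`.

    Hypothetical helper; the library's fitting rule may differ in detail.
    """
    def cov(f):
        return coverage(y, med - f * (med - lo), med + f * (hi - med))

    f_lo, f_hi = 0.0, f_max
    if cov(f_hi) < target:
        return f_hi  # even the widest allowed interval falls short
    for _ in range(max_iter):
        mid = 0.5 * (f_lo + f_hi)
        if cov(mid) >= target:
            f_hi = mid  # wide enough: try a narrower interval
        else:
            f_lo = mid
    return f_hi


rng = np.random.default_rng(0)
y = rng.normal(size=2000)
med = np.zeros_like(y)
lo, hi = med - 0.5, med + 0.5  # deliberately too narrow
f = fit_width_factor(y, lo, med, hi, target=0.8)
cov_after = coverage(y, med - f * (med - lo), med + f * (hi - med))
```

Bisection works here because coverage cannot decrease as the factor grows, provided the lower quantile stays below the median and the upper quantile above it.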
- Parameters:
df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.
df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.
target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.
column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.
step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.
interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.
target_coverage (float, default 0.8) – Desired empirical coverage after calibration.
median_q (float, default 0.5) – Central quantile used as the expansion anchor.
use ({"auto", True, False}, default "auto") – Control flag for whether calibration is performed.
False disables calibration and returns inputs unchanged.
"auto" skips calibration when evaluation forecasts already look calibrated.
True forces calibration even if the automatic check would skip it.
tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.
f_max (float, default 5.0) – Maximum factor allowed during fitting.
max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.
keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.
enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.
overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.
calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.
factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.
factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.
save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.
save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.
save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.
verbose (int, default 1) – Verbosity level forwarded to logging helpers.
logger (logging.Logger or None, default None) – Optional logger used for progress messages.
- Returns:
df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.
df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.
stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain
the target interval and target coverage,
the fitted or user-specified factors,
skip reasons,
evaluation summaries before and after calibration.
- Return type:
Notes
In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This makes the wrapper conservative in repeated workflows, where the same tables may pass through the calibration stage more than once.
The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True
See also
fit_interval_factors_dfFit per-horizon interval-width correction factors.
apply_interval_factors_dfApply a scalar or per-horizon factor map to quantile forecasts.
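The interval-widening idea described for calibrate_quantile_forecasts can be illustrated with a minimal NumPy sketch. It is not the library's implementation: the helper names scale_interval, empirical_coverage, and fit_factor are hypothetical, and the sketch assumes a single factor applied symmetrically to both half-widths around the median, found by bisection on empirical coverage.

```python
import numpy as np


def scale_interval(q_lo, q_med, q_hi, factor):
    """Widen (factor > 1) or narrow (factor < 1) an interval around the median."""
    lo = q_med - factor * (q_med - q_lo)
    hi = q_med + factor * (q_hi - q_med)
    return lo, hi


def empirical_coverage(y, lo, hi):
    """Fraction of observations that fall inside [lo, hi]."""
    return float(np.mean((y >= lo) & (y <= hi)))


def fit_factor(y, q_lo, q_med, q_hi, target=0.8, f_max=5.0, max_iter=32):
    """Bisection for the smallest factor in [0, f_max] reaching the target coverage.

    Coverage is non-decreasing in the factor when q_lo <= q_med <= q_hi,
    which is what makes bisection applicable (assumption of this sketch).
    """
    f_lo, f_hi = 0.0, f_max
    for _ in range(max_iter):
        mid = 0.5 * (f_lo + f_hi)
        lo, hi = scale_interval(q_lo, q_med, q_hi, mid)
        if empirical_coverage(y, lo, hi) >= target:
            f_hi = mid  # target reached: try a tighter interval
        else:
            f_lo = mid  # under-covered: widen
    return f_hi


# Toy data: intervals of half-width 1 around a zero median.
y = np.array([0.5, -0.5, 1.5, -2.0, 3.0])
q_med = np.zeros(5)
q_lo, q_hi = q_med - 1.0, q_med + 1.0
factor = fit_factor(y, q_lo, q_med, q_hi, target=0.8)
```

In a per-horizon setting, one such factor would be fitted for each value of step_col and then applied to the matching rows of the future table.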
- geoprior.utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]
- geoprior.utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]
Stage-1 audit:
- raw df_train coord stats (t/x/y) + heuristic units
- model-fed coords stats from inputs_train["coords"] (flattened)
- coord scaler min/max + coord_ranges
- SI channel sanity for physics cols (if present)
- target arrays sanity
- split of features: scaled ML vs __si vs other
Saves a machine-readable JSON if save_dir is provided.
- Parameters:
coord_scaler (Any)
coord_mode (str)
coords_in_degrees (bool)
coord_epsg_used (Any)
coord_x_col_used (str)
coord_y_col_used (str)
x_col_used (str)
y_col_used (str)
time_col_used (str)
normalize_coords (bool)
keep_coords_raw (bool)
shift_raw_coords (bool)
subs_model_col (str | None)
gwl_dyn_col (str | None)
gwl_target_col (str | None)
h_field_col (str | None)
main_scaler_path (str | None)
scaler_info (dict | None)
save_dir (str | None)
table_width (int)
title_prefix (str)
city (str)
model_name (str)
sample_rows (int)
- Return type:
str | None
- geoprior.utils.should_audit(audit_stages, *, stage, default=None)[source]
Convenience: should we audit this stage?
- geoprior.utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)[source]
Format PINN forecasts into evaluation and future DataFrames.
This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:
df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).
df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.
- Parameters:
y_pred (dict) – Dictionary of model predictions, as returned by GeoPriorSubsNet.predict post-processed into {'subs_pred': ..., 'gwl_pred': ...}.
For subsidence, the expected shapes are:
Quantile mode: (B, H, Q, O) where: B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.
Point mode: (B, H, O).
y_true (dict or None) – Dictionary of true targets, typically {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}.
If None, evaluation DataFrame is still created but without the actual-value column.
coords (
ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped(B, H, 3)with columns[t_scaled, x_scaled, y_scaled]. Onlyxandyare used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.quantiles (
listoffloatorNone, optional) – List of quantiles (e.g.[0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. IfNone, a single prediction column is emitted instead.target_name (
str, default'subsidence') –Logical target identifier used as the default key for locating the target scaler in
scaler_infoand as a fallback for resolving truth arrays iny_true.Column naming is controlled by
output_target_name(or the auto-derived output prefix when it isNone).output_target_name (
strorNone, optional) –Output prefix used when creating DataFrame columns for predictions and actuals.
This controls the column naming only (e.g. the function will emit
f"{output_target_name}_q10",f"{output_target_name}_pred", andf"{output_target_name}_actual").If
None(default), the function derives the output prefix fromtarget_nameand applies a small convenience rule: iftarget_nameends with"_cum"or"_cumulative", that suffix is stripped for output naming.This keeps downstream tooling consistent (many plotting and metrics utilities expect names like
subsidence_q10rather thansubsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, withtarget_name="subsidence_cum"andoutput_target_name=None, output columns becomesubsidence_q10,subsidence_q50, andsubsidence_actual. Ifoutput_target_name="subsidence_cum", the output columns keep the suffix such assubsidence_cum_q10.scaler_target_name (
strorNone, optional) –Name used to locate the target scaling block inside
scaler_infoand to perform inverse-transform for predictions and actuals.This controls the scaler key and inverse scaling, not the output column naming.
If
None(default), the scaler key is assumed to betarget_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.A common pattern is to keep
target_name="subsidence_cum"so the scaler lookup matches the Stage-1 schema, while lettingoutput_target_name=Noneproduce clean output columns. In that setup, inverse transform still uses thesubsidence_cumscaler key, while output columns use thesubsidence_prefix because of the auto-strip rule.target_key_pred (
str, default'subs_pred') – Key insidey_predthat holds the subsidence forecasts.component_index (
int, default0) – Index along the output dimension O to use whenoutput_subsidence_dim > 1. For scalar subsidence this is 0.scaler_info (
dict, optional) – Optional Stage-1scaler_infomapping containing a target scaler under keys such as'targets'or'target'. The target block is expected to provide an sklearn-like transformer under'scaler'together with column names under'columns'or'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed fortarget_name.coord_scaler (
object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transformcoord_xandcoord_ywhencoordsis given andcoord_columnscan be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.coord_columns (
tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.
train_end_time (
scalarorstrordatetime, optional) – Physical time associated with the evaluation year (e.g. 2022). Ifeval_forecast_stepis not given, the last horizon step is assumed to correspond to this time.forecast_start_time (
scalarorstrordatetime, optional) – First time in the future forecast horizon (e.g. 2023).forecast_horizon (
int, optional) – Number of forecast steps in the future horizon (e.g. 3). Iffuture_time_gridis not given, this is used together withforecast_start_timeto build a regular grid.future_time_grid (
array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be[2023, 2024, 2025]. If provided, it overrides any automatic construction fromforecast_start_timeandforecast_horizon.eval_forecast_step (
int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.
eval_export ({"all", "last"} or str or int or sequence, optional) – Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).
Accepted values are:
"all" or "full" or "horizons": export all horizons from df_eval_all.
"last" or "single" or "default": export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).
Other str (e.g. "2022"): interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.
int or scalar non-string: interpreted as a single time value (e.g. 2022).
sequence of values (e.g. [2021, 2022]): interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.
If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.
value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) – Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.
Supported modes are:
"rate": keep per-step predictions as provided by the model (current behaviour).
"cumulative" or "cum": convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.
"absolute_cumulative" or "abs_cum" or "absolute": same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.
Cumulative transforms are applied consistently to:
the future forecast DataFrame (df_future),
the multi-horizon evaluation DataFrame (df_eval_all),
and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).
When an unsupported string is given, the function logs a warning and falls back to "rate".
absolute_baseline (
floatorMapping[int,float], optional) –Baseline value to use when
value_moderequests absolute cumulative outputs ("absolute_cumulative","abs_cum","absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence attrain_end_time(e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.If a scalar
floatis provided, the same baseline value is added to all samples. If a mapping is provided, it must mapsample_idx(integers) to baseline values, allowing per-sample baselines:absolute_baseline = {sample_idx: baseline_value, ...}
Only prediction columns for
target_nameare shifted (e.g."subsidence_q10","subsidence_q50","subsidence_q90"or"subsidence_pred"). Whendf_eval_allis present, the corresponding"<target_name>_actual"column is shifted as well, so evaluation metrics operate on absolute cumulative values.If
value_modeis an absolute cumulative variant butabsolute_baselineisNone, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).sample_index_offset (
int, default0) – Offset added tosample_idx(useful when concatenating multiple tiles).city_name (
str, optional) – Optional metadata used only for logging.model_name (
str, optional) – Optional metadata used only for logging.dataset_name (
str, optional) – Optional metadata used only for logging.csv_eval_path (
str, optional) – If provided,df_evalis written to this path (directories are created if needed).csv_future_path (
str, optional) – If provided,df_futureis written to this path.time_as_datetime (
bool, defaultFalse) – IfTrue, time values are converted usingpandas.to_datetime()with the providedtime_format(if any).time_format (
strorNone, optional) – Optional format string passed topandas.to_datetime()whentime_as_datetime=True.eval_metrics (
bool, defaultFalse) – IfTrue, automatically callevaluate_forecast()on the resultingdf_evalto compute diagnostics. Metrics are not returned by this function; they are either written to disk (ifmetrics_savefileis provided) or discarded. For programmatic access to the metrics dictionary, callevaluate_forecast()directly.metrics_column_map (
mapping, optional) – Optional column mapping forwarded toevaluate_forecast()(see its documentation for details). IfNone, default column names such as'coord_t','forecast_step',f'{target_name}_q10', andf'{target_name}_actual'are assumed.metrics_quantile_interval (
tupleoffloat, default(0.1,0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded toevaluate_forecast().metrics_per_horizon (
bool, defaultFalse) – IfTrue, per-horizon MAE/MSE/R² are computed byevaluate_forecast()and included in the diagnostics.metrics_extra (
sequenceormapping, optional) –Optional additional metrics to compute, forwarded to
evaluate_forecast(). Can be:A sequence of metric names (resolved via
geoprior.metrics._registry.get_metric).A mapping
{name: func}wherefuncis a callable taking(y_true, y_pred, **kwargs).
metrics_extra_kwargs (
mapping, optional) – Optional per-metric keyword arguments, forwarded toevaluate_forecast(). Keys must match metric names inmetrics_extra.metrics_savefile (
str,path-like,bool, orNone) – If truthy, diagnostics fromevaluate_forecast()are written to disk. Behavior matches thesavefileargument ofevaluate_forecast(). WhenTrue, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.metrics_save_format (
{'.json', 'json', '.csv', 'csv'}, default'.json') – Output format for diagnostics written byevaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.metrics_time_as_str (
bool, defaultTrue) – IfTrue, time keys in the diagnostics written byevaluate_forecast()are converted to strings (useful for JSON serialization).verbose (
int, default1) – Verbosity level passed tovlog().logger (
logging.Logger, optional) – Logger instance; ifNone, a module-levelLOGis used.input_value_mode (str)
rate_first (str)
output_unit (str | None)
output_unit_from (str)
output_unit_mode (str)
output_unit_suffix (str)
output_unit_col (str | None)
- Returns:
df_eval_to_write (pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time. Columns include:
'sample_idx'
'forecast_step'
quantile columns (e.g. subsidence_q10) or subsidence_pred
'subsidence_actual' (if y_true given)
coord_t, coord_x, coord_y (names from coord_columns).
df_future (pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure as df_eval but without the actual-value column.
- Return type:
Notes
This function separates scaler lookup (
scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like"_cum"but downstream tools expect canonical names such as columns prefixed withsubsidence_.
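The "cumulative" and "absolute_cumulative" value modes described for value_mode and absolute_baseline can be sketched with pandas. This is an illustrative reading of the documented behaviour, not the function's internal code: the helper name to_cumulative is hypothetical, and treating samples missing from a baseline mapping as having a zero baseline is an assumption of the sketch.

```python
import pandas as pd


def to_cumulative(df, value_cols, baseline=None):
    """Convert per-step rates to cumulative values per sample (hypothetical helper).

    baseline: None (relative cumulative), a scalar, or a dict that maps
    sample_idx -> baseline value (absolute cumulative).
    """
    out = df.sort_values(["sample_idx", "forecast_step"]).copy()
    # Cumulative sum over forecast_step within each sample.
    out[value_cols] = out.groupby("sample_idx")[value_cols].cumsum()
    if baseline is not None:
        if isinstance(baseline, dict):
            # Assumption: samples missing from the mapping get no shift.
            shift = out["sample_idx"].map(baseline).fillna(0.0)
        else:
            shift = float(baseline)
        out[value_cols] = out[value_cols].add(shift, axis=0)
    return out


rates = pd.DataFrame(
    {
        "sample_idx": [0, 0, 0, 1, 1, 1],
        "forecast_step": [1, 2, 3, 1, 2, 3],
        "subsidence_q50": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
    }
)
relative = to_cumulative(rates, ["subsidence_q50"])
absolute = to_cumulative(rates, ["subsidence_q50"], baseline={0: 100.0, 1: 200.0})
```

In the relative table the step-2 value for sample 0 is the sum of its step-1 and step-2 rates; the absolute table adds the per-sample baseline on top of that.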
- geoprior.utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]
Evaluate forecast diagnostics from an evaluation DataFrame.
This helper consumes the
df_evaloutput fromformat_and_forecast()(or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, \(R^2\), coverage, and sharpness. It can also optionally evaluate metrics per forecast horizon and apply additional user-defined metrics.By default it expects the following columns:
'sample_idx''forecast_step''coord_t'(time)Quantile or point-prediction columns for the target, e.g.:
Quantile mode:
f'{target_name}_q10',f'{target_name}_q50',f'{target_name}_q90', …Point mode:
f'{target_name}_pred'.
Actual column:
f'{target_name}_actual'.
A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:
column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}
or, for quantile columns:
column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}
- Parameters:
eval_data (
str,path-like, orpandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved byformat_and_forecast()) or an in-memory DataFrame.target_name (
str, default'subsidence') – Base name for the target columns. Used to infer default column names such asf'{target_name}_q10',f'{target_name}_pred', andf'{target_name}_actual'.column_map (
dict, optional) – Optional mapping to override default column names. The following keys are recognized:
'sample_idx': sample index column name (default 'sample_idx').
'forecast_step': horizon index column name (default 'forecast_step').
'coord_t': time coordinate column (default 'coord_t').
'actual': name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.
'pred': point prediction column for non-quantile mode, default f'{target_name}_pred'.
'quantiles':
If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).
If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.
quantile_interval (
tupleoffloat, default(0.1,0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.per_horizon (
bool, defaultFalse) – IfTrue, compute per-horizon MAE/MSE/R² grouped by theforecast_stepcolumn.extra_metrics (
sequenceofstrormapping, optional) –Optional additional metrics to compute.
If a sequence of strings (e.g.
['pss', 'pit']), each name is resolved viageoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.If a mapping
{name: func}, eachfuncis called as:func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
where
y_predis the median (Q50) or point forecast.
For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.
extra_metric_kwargs (
mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names inextra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.savefile (
str,path-like, orbool, optional) –If provided, metrics are saved to disk.
If
True: a filename is auto-generated neareval_data(if it is a path) or in the current working directory.If a string/path without extension: the extension is taken from
save_format.If a string/path with extension: that extension takes precedence over
save_format.
save_format (
{'.json', 'json', '.csv', 'csv'}, default'.json') –Output format when
savefileis truthy. JSON preserves nested structure; CSV is flattened into a tall table.For JSON, the function returns the metrics dictionary.
For CSV, the function returns the metrics DataFrame.
time_as_str (
bool, defaultTrue) – IfTrue, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.verbose (
int, default1) – Verbosity level passed tovlog().logger (
logging.Logger, optional) – Optional logger instance used byvlog().overall_key (str | None)
- Returns:
results – If save_format is JSON (default), returns a dict:
Single time value:
{
    "overall_mae": ...,
    "overall_mse": ...,
    "overall_r2": ...,
    "coverage80": ...,
    "sharpness80": ...,
    "per_horizon_mae": {1: ..., 2: ..., ...},
    ...
}
Multiple time values:
{
    "2021": { ...metrics... },
    "2022": { ...metrics... },
}
If save_format is CSV, returns a DataFrame with flattened rows:
Columns include: coord_t, metric, horizon, and value.
- Return type:
Notes
Default metrics in quantile mode:
overall_mae, overall_mse, overall_r2
coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)
If per_horizon=True, also:
per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).
Default metrics in point mode (no quantiles):
mae, mse, r2
And optionally, if per_horizon=True:
per_horizon_mae, per_horizon_mse, per_horizon_r2.
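For orientation, the default quantile-mode metrics can be approximated by hand from a table that follows the default column naming. The sketch below is an assumption-laden illustration rather than the library's code: interval_diagnostics is a hypothetical name, coverage is taken as the fraction of actuals inside the interval, sharpness as the mean interval width, and the Q50 column as the point forecast.

```python
import numpy as np
import pandas as pd


def interval_diagnostics(df, target="subsidence", interval=(0.1, 0.9)):
    """Hand-rolled versions of the default quantile-mode metrics (sketch)."""
    lo_col = f"{target}_q{int(round(interval[0] * 100))}"
    hi_col = f"{target}_q{int(round(interval[1] * 100))}"
    y = df[f"{target}_actual"].to_numpy(dtype=float)
    med = df[f"{target}_q50"].to_numpy(dtype=float)
    lo = df[lo_col].to_numpy(dtype=float)
    hi = df[hi_col].to_numpy(dtype=float)

    err = y - med
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((y - y.mean()) ** 2))  # assumed non-zero here
    return {
        "overall_mae": float(np.mean(np.abs(err))),
        "overall_mse": float(np.mean(err ** 2)),
        "overall_r2": 1.0 - ss_res / ss_tot,
        "coverage80": float(np.mean((y >= lo) & (y <= hi))),
        "sharpness80": float(np.mean(hi - lo)),
    }


df_eval = pd.DataFrame(
    {
        "forecast_step": [1, 1, 2, 2],
        "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
        "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
        "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
        "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
    }
)
metrics = interval_diagnostics(df_eval)
```

Grouping df_eval by forecast_step and applying the same function per group would give a rough analogue of the per_horizon_* entries.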
- geoprior.utils.default_results_dir(start=None, env_var='RESULTS_DIR', folder_name='results', create=False)[source]
Resolve the canonical ‘results’ directory with robust fallbacks.
- geoprior.utils.ensure_directory_exists(path)[source]
Ensure that a directory exists at the given path, creating it if needed.
This function checks whether the provided path exists and is a directory. If the path does not exist, it attempts to create the directory (including any necessary parent directories). If a file with the same name already exists, or if creation fails, an exception is raised.
- Parameters:
path (
strorpathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.- Returns:
A Path object pointing to the existing (or newly created) directory.
- Return type:
- Raises:
TypeError – If path is not a string or pathlib.Path.
FileExistsError – If a file (not a directory) already exists at path.
OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).
Examples
>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ensure_directory_exists
>>> output_dir = ensure_directory_exists("data/output")
>>> isinstance(output_dir, Path)
True
>>> # The directory "data/output" now exists on disk.
Notes
Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood for cross-platform compatibility.
If path already exists as a directory, this function returns immediately without modifying it.
See also
pathlib.Path.mkdirMethod to create a directory.
os.makedirsLegacy function for creating directories recursively.
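The documented contract of ensure_directory_exists (type check, refusal to overwrite a file, recursive creation) can be mirrored in a few lines of pathlib. This is a sketch of the described behaviour under the name ensure_dir_sketch, which is hypothetical; the real function may differ in its messages and edge cases.

```python
from pathlib import Path


def ensure_dir_sketch(path):
    """Return a Path to an existing directory, creating it if needed (sketch)."""
    if not isinstance(path, (str, Path)):
        raise TypeError(f"path must be str or pathlib.Path, got {type(path).__name__}")
    p = Path(path)
    if p.exists() and not p.is_dir():
        # A regular file already occupies this name.
        raise FileExistsError(f"A file already exists at {p}")
    # parents=True creates missing ancestors; exist_ok=True makes the call idempotent.
    p.mkdir(parents=True, exist_ok=True)
    return p
```

Calling it twice on the same path is harmless, which matches the note that an existing directory is returned without modification.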
- geoprior.utils.getenv_stripped(name, default=None, allow_empty=False)[source]
Read an environment variable and strip whitespace robustly.
- Parameters:
name (
str) – Environment variable name to read.default (
strorNone, optional) – Value returned when the environment variable is not set.allow_empty (
bool, defaultFalse) – If False, empty strings are treated as missing and default is returned instead. If True, an empty string is returned unchanged.
- Returns:
out – The stripped string value, the empty string (if allowed), or default when unset / empty.
- Return type:
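The interaction between default and allow_empty is easiest to see in code. The following is a minimal sketch of the documented semantics using only os.environ; the name getenv_stripped_sketch and the variable GEOPRIOR_DEMO_VAR are hypothetical, and the real helper may strip more than plain whitespace.

```python
import os


def getenv_stripped_sketch(name, default=None, allow_empty=False):
    """Read an environment variable, strip whitespace, and apply the empty rule."""
    raw = os.environ.get(name)
    if raw is None:
        return default  # variable is not set
    value = raw.strip()
    if value == "" and not allow_empty:
        return default  # empty after stripping counts as missing
    return value


os.environ["GEOPRIOR_DEMO_VAR"] = "  results  "
stripped = getenv_stripped_sketch("GEOPRIOR_DEMO_VAR")

os.environ["GEOPRIOR_DEMO_VAR"] = "   "
as_missing = getenv_stripped_sketch("GEOPRIOR_DEMO_VAR", default="fallback")
as_empty = getenv_stripped_sketch("GEOPRIOR_DEMO_VAR", default="fallback", allow_empty=True)
```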
- geoprior.utils.print_config_table(sections, title=None, table_width=None, sort_keys=True, key_col_fraction=0.35, max_value_length=200, log_fn=None)[source]
Pretty-print configuration or hyperparameters as a key/value table.
This helper is intended for CLI scripts (Stage-1, training, tuning) so that the user can quickly inspect which parameters are actually in effect.
- Parameters:
sections (
dictorsequenceof(str,dict)) –If a single dict is passed, all key/value pairs are printed in one block.
If a sequence is passed, it must contain
(name, params)tuples, wherenameis a section label (e.g."Physics") andparamsis a dict mapping parameter names to values.title (
str, optional) – Optional title displayed above the table (centered).table_width (
int, optional) – Total width of the printed table. IfNone, the function tries to usegeoprior.api.util.get_table_size(). If that fails, it falls back to the terminal width (viashutil.get_terminal_size) or 80 characters.sort_keys (
bool, defaultTrue) – Whether to sort parameter names alphabetically within each section.key_col_fraction (
float, default0.35) – Fraction of the table width allocated to the parameter-name column. The remainder is used for the value column.max_value_length (
int, default200) – Maximum number of characters kept from the stringified value. Longer values are truncated with an ellipsis ("...") before being wrapped onto multiple lines.log_fn (
callable, optional) – Function used to emit lines (defaults toprint()). This allows capturing the table in logs if needed.
- Returns:
The full rendered table as a single string. It is always emitted via log_fn (print() by default) as a side effect.
- Return type:
Notes
Nested containers (lists, tuples, dicts) are rendered in a compact one-line form and then wrapped to fill the value column.
This function is intentionally lightweight and does not depend on external tabulation libraries, so it can be safely used in lightweight Stage-1 / Stage-2 scripts.
- geoprior.utils.save_all_figures(output_dir='figures', prefix='figure', fmts=('png',), close=True, dpi=150, transparent=False, timestamp=True, verbose=True)[source]
Save all currently open Matplotlib figures to disk in specified formats.
- Parameters:
output_dir (
str) – Directory where figures will be saved. Created if not exists.prefix (
str) – Filename prefix for each figure.
fmts (list or tuple of str) – File formats/extensions to use (e.g., ('png', 'pdf')).
close (
bool) – Whether to close each figure after saving. Default is True.dpi (
intorNone) – Resolution in dots per inch. None uses Matplotlib default.transparent (
bool) – Whether to save figures with transparent background.timestamp (
bool) – Append current timestamp (YYYYmmddTHHMMSS) to filenames.verbose (
bool) – Print progress messages.
- Returns:
List of saved file paths.
- Return type:
List[str]
Examples
>>> import matplotlib.pyplot as plt
>>> plt.figure(); plt.plot([1, 2, 3])
>>> from geoprior.utils.generic_utils import save_all_figures
>>> paths = save_all_figures(output_dir="plots", fmts=("png",))
>>> print(paths)
['plots/figure_1_20250521T153045.png']
- geoprior.utils.build_censor_mask(xb, H, idx, thresh=0.5, *, source='dynamic', reduce_time='any', align='broadcast')[source]
Build a censor mask aligned to the forecast horizon: (B, H, 1).
- Parameters:
source (
{"dynamic", "future"}, default"dynamic") – Selects where the censoring flag is read from."dynamic"readsxb["dynamic_features"][:, :, idx]from the history window, while"future"readsxb["future_features"][:, :, idx]from the forecast window.reduce_time (
{"any", "last", "all"}, default"any") – Reduction applied whensource="dynamic"and the censor flag behaves like a per-sample label."any"marks the sample as censored if any history step is flagged,"last"uses only the last history step, and"all"requires every history step to be flagged.align (
{"broadcast", "crop", "pad_false", "pad_edge", "error"}, default"broadcast") – Policy used when the time axis does not already match the forecast horizonH."broadcast"repeats a single-step label across all horizon steps,"crop"keeps the lastHsteps,"pad_false"pads missing steps withFalse,"pad_edge"repeats the last available step, and"error"raises on mismatch.xb (dict)
idx (int | None)
thresh (float)
- Return type:
Tensor
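The source="dynamic" path with align="broadcast" can be pictured with a small NumPy sketch. It is only an illustration of the documented options: the real function returns a Tensor and supports further alignment policies, the name censor_mask_sketch is hypothetical, and the strict "greater than thresh" comparison is an assumption.

```python
import numpy as np


def censor_mask_sketch(dynamic, H, idx, thresh=0.5, reduce_time="any"):
    """Build a (B, H, 1) boolean mask from a history-window flag channel."""
    flag = dynamic[:, :, idx] > thresh  # (B, T) per-step flags (assumed comparison)
    if reduce_time == "any":
        label = flag.any(axis=1)
    elif reduce_time == "last":
        label = flag[:, -1]
    elif reduce_time == "all":
        label = flag.all(axis=1)
    else:
        raise ValueError(f"unknown reduce_time: {reduce_time!r}")
    # Broadcast the per-sample label across all horizon steps.
    return np.broadcast_to(label[:, None, None], (flag.shape[0], H, 1)).copy()


# Two samples, three history steps, two channels; channel 1 is the censor flag.
dyn = np.zeros((2, 3, 2), dtype=np.float32)
dyn[0, 1, 1] = 1.0  # sample 0 is flagged at the middle step only
mask_any = censor_mask_sketch(dyn, H=4, idx=1, reduce_time="any")
mask_last = censor_mask_sketch(dyn, H=4, idx=1, reduce_time="last")
```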
- geoprior.utils.ensure_input_shapes(x, mode, forecast_horizon)[source]
Ensure presence of zero-width static/future placeholders.
Stage-1 exporters sometimes omit static_features or future_features when there are no static/future variables for a particular experiment. Keras, however, expects these inputs to exist so that the input signature remains stable.
This helper:
Copies the input dict to avoid in-place modification.
Ensures static_features is an array of shape (N, 0) if missing.
Ensures future_features is an array of shape (N, T_future, 0) if missing, where:
T_future = dynamic_features.shape[1] when mode == "tft_like" (past+future style).
Otherwise, T_future = forecast_horizon.
- Parameters:
- Returns:
Shallow copy of
xwith guaranteedstatic_featuresandfuture_featuresentries.- Return type:
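The placeholder rules listed above translate almost directly into NumPy. The sketch below follows those rules under two assumptions that the text does not state: N is read from dynamic_features, and the placeholders use float32. The name ensure_input_shapes_sketch is hypothetical.

```python
import numpy as np


def ensure_input_shapes_sketch(x, mode, forecast_horizon):
    """Return a shallow copy of x with zero-width placeholders filled in."""
    out = dict(x)  # shallow copy, the caller's dict is left untouched
    dyn = np.asarray(out["dynamic_features"])
    n = dyn.shape[0]
    if out.get("static_features") is None:
        out["static_features"] = np.zeros((n, 0), dtype=np.float32)
    if out.get("future_features") is None:
        # tft_like models see the future block over the history length as well.
        t_future = dyn.shape[1] if mode == "tft_like" else int(forecast_horizon)
        out["future_features"] = np.zeros((n, t_future, 0), dtype=np.float32)
    return out


x = {"dynamic_features": np.ones((8, 5, 3), dtype=np.float32)}
filled = ensure_input_shapes_sketch(x, mode="pihal_like", forecast_horizon=3)
filled_tft = ensure_input_shapes_sketch(x, mode="tft_like", forecast_horizon=3)
```

The mode string "pihal_like" is used here only as an arbitrary non-"tft_like" value.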
- geoprior.utils.extract_preds(model, out, *, strict=True, output_names=None)[source]
Extract (subs_pred, gwl_pred) from GeoPrior outputs.
- Supports:
v3.2+ call(): {"subs_pred", "gwl_pred"}
forward_with_aux(): (y_pred, aux)
legacy: {"data_final"} + model.split_data_predictions
predict(): list/tuple mapped via output names
If strict=True, list/tuple outputs must be mappable via output names; otherwise we raise to avoid silent swaps.
This helper normalizes the output interface across two GeoPrior generation families:
New interface (preferred)
model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}
Legacy interface (backward compatible)
model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.
- Parameters:
model (object) – A Keras-like model instance that may expose split_data_predictions(data_final).
The splitter must return a tuple:
subs_pred with shape (B, H, 1) or (B, H, Q, 1)
gwl_pred with shape (B, H, 1) or (B, H, Q, 1)
out (dict) – Output returned by the model call, typically model(inputs, training=False).
Supported keys are either:
{"subs_pred", "gwl_pred"} (new interface), or
{"data_final"} (legacy interface).
strict (bool)
- Returns:
subs_pred (Tensor) – Predicted subsidence in model space.
Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
gwl_pred (Tensor) – Predicted groundwater/head variable in model space.
Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
- Raises:
- Return type:
Notes
This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.
The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].
Examples
New interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(model_inf, out)
Legacy interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(model_inf, out)
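The dispatch between the two dict interfaces can be sketched in plain Python. This is a minimal illustration, not the library's code: extract_preds_sketch and LegacyModel are hypothetical names, tuple/list outputs and the strict flag are deliberately left out, and the way the stand-in splits data_final (first value is subsidence, second is head) is an assumption made only for the demo.

```python
def extract_preds_sketch(model, out):
    """Return (subs_pred, gwl_pred) from a new-style or legacy output dict."""
    if not isinstance(out, dict):
        raise TypeError("this sketch only handles dict outputs")
    if "subs_pred" in out and "gwl_pred" in out:
        return out["subs_pred"], out["gwl_pred"]
    if "data_final" in out:
        splitter = getattr(model, "split_data_predictions", None)
        if splitter is None:
            raise ValueError("legacy output needs model.split_data_predictions")
        return splitter(out["data_final"])
    raise KeyError("expected 'subs_pred'/'gwl_pred' or 'data_final'")


class LegacyModel:
    """Hypothetical stand-in for an older checkpoint."""

    def split_data_predictions(self, data_final):
        # Assumed layout for the demo: each row is (subsidence, head).
        subs = [row[0] for row in data_final]
        gwl = [row[1] for row in data_final]
        return subs, gwl


new_style = extract_preds_sketch(None, {"subs_pred": [0.1], "gwl_pred": [2.0]})
legacy = extract_preds_sketch(LegacyModel(), {"data_final": [(0.1, 2.0), (0.3, 4.0)]})
```

Delegating the legacy split to the model keeps the channel layout in one place instead of repeating slice indices in every script.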
See also
subs_point_from_out: Convert subsidence predictions to a point forecast.
- geoprior.utils.load_nat_config(root='nat.com')[source]
High-level helper used by NATCOM scripts.
Example
>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]
- geoprior.utils.load_nat_config_payload(root='nat.com')[source]
Return the full config.json payload, including city, model and __meta__ fields.
This is convenient when you also want to see which hash or city/model are currently active.
- geoprior.utils.load_scaler_info(encoders_block)[source]
Load the scaler_info mapping from an encoders block.
Stage-1 exporters typically store a compact description of the scalers used to normalise the data. In many cases this takes the form:
encoders = {
    "main_scaler": "/path/to/minmax.joblib",
    "coord_scaler": "/path/to/coords.joblib",
    "scaler_info": "/path/to/scaler_info.joblib",
    ...
}
where scaler_info is either a path to a joblib file or an already-loaded dictionary.
This helper returns a dictionary regardless of how it was stored, making downstream formatting/evaluation code simpler.
- geoprior.utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]
Build a tf.data.Dataset using NATCOM conventions.
Steps:
1) ensure_input_shapes(…) for X.
2) map_targets_for_training(…) for y.
3) tf.data pipeline (shuffle/batch/prefetch).
4) optional finite checks (NPZ + tf batches).
- Parameters:
X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.
y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.
batch_size (int) – Number of samples per batch.
shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.
mode (str) – Model mode passed to ensure_input_shapes().
forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().
check_npz_finite (bool) – If True, checks the input and target NumPy arrays for NaN/Inf before building the dataset.
check_finite (bool) – If True, inserts assert_all_finite checks inside the tf.data pipeline.
scan_finite_batches (int) – If >0, eagerly scans the first N batches right away (fails early).
dynamic_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
future_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
seed (int)
drop_remainder (bool)
reshuffle_each_iter (bool)
prefetch (bool)
- Returns:
Dataset of (X, y) pairs.
- Return type:
tf.data.Dataset
Notes
TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
- geoprior.utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]
Standardise target dictionaries to the Keras compile keys.
This helper enforces a small convention used throughout the NATCOM training scripts:
Upstream sequence builders typically export raw targets with keys subsidence and gwl.
The GeoPrior model is compiled with targets named subs_pred and gwl_pred.
This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.
- Parameters:
y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).
subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.
gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.
subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.
gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.
- Returns:
New dictionary with keys subs_pred and gwl_pred.
- Return type:
- Raises:
KeyError – If the dictionary does not contain either of the expected key pairs.
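The key convention can be sketched with plain dictionaries. This is a hedged illustration rather than the library's code: map_targets_sketch is a hypothetical name, and giving the already-standardised keys precedence when both pairs are present is an assumption.

```python
def map_targets_sketch(y_dict, subs_key="subsidence", gwl_key="gwl",
                       subs_pred_key="subs_pred", gwl_pred_key="gwl_pred"):
    """Return a new dict keyed by the standardised target names."""
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        src_subs, src_gwl = subs_pred_key, gwl_pred_key
    elif subs_key in y_dict and gwl_key in y_dict:
        src_subs, src_gwl = subs_key, gwl_key
    else:
        raise KeyError(
            f"expected ({subs_key!r}, {gwl_key!r}) or "
            f"({subs_pred_key!r}, {gwl_pred_key!r}) in y_dict"
        )
    return {subs_pred_key: y_dict[src_subs], gwl_pred_key: y_dict[src_gwl]}


raw = {"subsidence": [1.0, 2.0], "gwl": [3.0, 4.0]}
mapped = map_targets_sketch(raw)
```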
- geoprior.utils.name_of(obj)[source]
Return a human-readable name for an object.
This utility is handy when serialising compile configurations (e.g., turning metric callables into simple strings for JSON logs).
- geoprior.utils.resolve_hybrid_config(manifest_cfg, live_cfg, verbose=True)[source]
Merge Manifest config (Data Authority) with Live config (Physics Authority).
- geoprior.utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
- geoprior.utils.best_epoch_and_metrics(history, monitor='val_loss')[source]
Return the best epoch and metrics at that epoch.
Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:
The epoch index (0-based).
A dictionary mapping each metric name to its value at that epoch.
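The selection can be sketched over a plain History.history-style dictionary. This is a minimal sketch under assumptions: best_epoch_sketch is a hypothetical name, ties are resolved in favour of the earliest epoch, and metrics whose series is shorter than the best index are skipped.

```python
def best_epoch_sketch(history, monitor="val_loss"):
    """Return (best_epoch, metrics_at_best_epoch) for a history dict."""
    values = history[monitor]
    # min() over indices returns the first index holding the smallest value.
    best = min(range(len(values)), key=values.__getitem__)
    metrics = {
        name: series[best]
        for name, series in history.items()
        if len(series) > best
    }
    return best, metrics


history = {
    "loss": [1.00, 0.50, 0.40],
    "val_loss": [0.90, 0.30, 0.60],
}
best_epoch, best_metrics = best_epoch_sketch(history)
```

Because the index is 0-based, add 1 when reporting it alongside Keras progress logs, which count epochs from 1.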
- Parameters:
- Returns:
- Return type:
- geoprior.utils.subs_point_from_out(model, out, quantiles=None, med_idx=None)[source]
Convert model output into a subsidence point forecast.
This helper produces a subsidence tensor shaped (B, H, 1) in model space, regardless of whether the model emits quantiles or a point prediction.
If quantiles are present and the subsidence prediction is shaped (B, H, Q, 1), the function selects the median quantile slice.
Otherwise, it returns the point prediction directly.
- Parameters:
model (object) – A Keras-like model instance passed to extract_preds().
out (dict) – Output returned by the model call.
This can be either the new interface with keys "subs_pred" and "gwl_pred", or the legacy interface with key "data_final".
quantiles (sequence of float or None, default None) – Quantile levels used by the model, such as [0.1, 0.5, 0.9].
If provided, the function may use it to interpret the rank-4 quantile output and select the median.
If None, quantile selection is disabled unless med_idx is explicitly provided and the tensor rank indicates quantiles.
med_idx (int or None, default None) – Index along the quantile axis to use as the “point” forecast when quantiles are available.
If None and quantiles is provided, the function selects the index of the quantile level closest to 0.5.
- Returns:
subs_point – Subsidence point prediction in model space with shape (B, H, 1).
- Return type:
Tensor
- Raises:
ValueError – If subsidence prediction is missing or None.
ValueError – If a quantile tensor is detected but a valid median index cannot be resolved.
Notes
Quantile outputs are assumed to be shaped (B, H, Q, 1) where the quantile axis is the third dimension (axis=2).
If the model returns point predictions already, the function is effectively a no-op. The quantile interpretation used here follows Koenker and Bassett [25].
Examples
Quantile model:
out = model_inf(xb, training=False)
s_point = subs_point_from_out(model_inf, out, quantiles=[0.1, 0.5, 0.9])
Point model:
out = model_inf(xb, training=False)
s_point = subs_point_from_out(model_inf, out)
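The median selection itself can be sketched with NumPy arrays standing in for tensors. The names median_index and subs_point_sketch are hypothetical, and the sketch starts from an already extracted subs_pred array rather than from the model output dictionary.

```python
import numpy as np


def median_index(quantiles):
    """Index of the quantile level closest to 0.5."""
    q = np.asarray(quantiles, dtype=float)
    return int(np.argmin(np.abs(q - 0.5)))


def subs_point_sketch(subs_pred, quantiles=None, med_idx=None):
    """Reduce (B, H, Q, 1) to (B, H, 1); pass (B, H, 1) through unchanged."""
    subs_pred = np.asarray(subs_pred)
    if subs_pred.ndim != 4:
        return subs_pred  # already a point forecast
    if med_idx is None:
        if quantiles is None:
            raise ValueError("quantile tensor detected but no median index given")
        med_idx = median_index(quantiles)
    return subs_pred[:, :, med_idx, :]  # quantile axis is axis=2


q_pred = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3, 1)
point = subs_point_sketch(q_pred, quantiles=[0.1, 0.5, 0.9])
```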
See also
extract_preds: Normalize outputs across new and legacy checkpoints.
- geoprior.utils.serialize_subs_params(params, cfg=None)[source]
Make GeoPrior subnet parameters JSON-friendly.
The training scripts typically pass a dictionary of model construction arguments, e.g. subsmodel_params, which contains objects such as LearnableMV or FixedGammaW that are not directly JSON-serialisable.
This helper replaces those objects by small dictionaries describing their type and scalar value, optionally using values from the NATCOM config dictionary.
- Parameters:
params (dict) – Dictionary of model init parameters (e.g. subsmodel_params in training_NATCOM_GEOPRIOR.py).
cfg (dict, optional) – NATCOM config dictionary. If provided, scalar values are taken from:
GEOPRIOR_INIT_MV
GEOPRIOR_INIT_KAPPA
GEOPRIOR_GAMMA_W
GEOPRIOR_H_REF
and used as the authoritative numbers.
- Returns:
Copy of params where scalar GeoPrior parameters are replaced by JSON-friendly dictionaries.
- Return type:
Notes
This function does not import any of the GeoPrior classes. It only introspects attributes like initial_value or value when the corresponding config entry is missing.
- geoprior.utils.save_ablation_record(outdir, city, model_name, cfg, eval_dict, phys_diag=None, per_h_mae=None, per_h_r2=None, log_fn=None)[source]
Append a single ablation record to ablation_record.jsonl.
Each training run (e.g., different physics toggles or weights) writes one JSON line containing:
Basic run identifiers (city, model, timestamp).
Physics configuration (PDE_MODE_CONFIG, lambda weights, effective head flags, etc.).
Key performance metrics (R², MSE, MAE, coverage, sharpness).
Optional physics diagnostics (epsilon_prior, epsilon_cons).
Optional per-horizon MAE/R² for more detailed analysis.
- Parameters:
outdir (str) – Base output directory for the current run. The ablation file is created under outdir / "ablation_records".
city (str) – City name (e.g., "nansha" or "zhongshan").
model_name (str) – Model identifier (e.g., "GeoPriorSubsNet").
cfg (dict) – Lightweight configuration dictionary containing at least the physics-related keys used below.
eval_dict (dict or None) – Dictionary of evaluation metrics (R², MSE, MAE, coverage80, sharpness80). If None, metrics fields default to None.
phys_diag (dict or None, optional) – Physics diagnostics (e.g., from evaluate()) with keys such as "epsilon_prior" and "epsilon_cons".
per_h_mae (dict or None, optional) – Per-horizon MAE values (e.g., keyed by year/step).
- Return type:
None
Notes
The output file is a JSON-Lines file, so it can be loaded with load_ablation_jsonl().
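The append-one-line-per-run pattern can be sketched with the standard library. This is an illustration, not the library's writer: append_jsonl and read_jsonl are hypothetical helpers, and the record fields used in the demo (including lambda_cons) are illustrative assumptions rather than the real record schema.

```python
import json
import os
import tempfile


def append_jsonl(outdir, record, filename="ablation_record.jsonl"):
    """Append one JSON object as a single line under outdir/ablation_records."""
    folder = os.path.join(outdir, "ablation_records")
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record, sort_keys=True) + "\n")
    return path


def read_jsonl(path):
    """Read every non-empty line of a JSON-Lines file into a list of dicts."""
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


run_dir = tempfile.mkdtemp()
for lam in (0.0, 1.0):
    record = {"city": "nansha", "model": "GeoPriorSubsNet",
              "lambda_cons": lam, "mae": None}
    path = append_jsonl(run_dir, record)
records = read_jsonl(path)
```

Append mode means repeated runs never rewrite earlier records, so an interrupted run cannot corrupt the lines already on disk.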
- geoprior.utils.cumulative_to_rate(df, *, cum_col='subsidence_cum', rate_col='subsidence', time_col='year', group_cols=('longitude', 'latitude', 'city'), first='cum_over_dtref', inplace=False)[source]
Recover a rate series from cumulative displacement.
rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i for i>=1
- first:
‘nan’: first rate is NaN
‘cum_over_dtref’: rate(t0) = cum(t0)/dt_ref, where dt_ref is the median dt
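The formula and the two options for the first step can be sketched for a single series with NumPy. The name cumulative_to_rate_1d is hypothetical, and the sketch handles one already sorted group with at least two time points, whereas the documented function works on a DataFrame split by group_cols.

```python
import numpy as np


def cumulative_to_rate_1d(t, cum, first="cum_over_dtref"):
    """Rate series for one group: finite differences of cum over dt."""
    t = np.asarray(t, dtype=float)
    cum = np.asarray(cum, dtype=float)
    dt = np.diff(t)
    rate = np.empty_like(cum)
    rate[1:] = np.diff(cum) / dt  # (cum(t_i) - cum(t_{i-1})) / dt_i
    if first == "nan":
        rate[0] = np.nan
    elif first == "cum_over_dtref":
        rate[0] = cum[0] / np.median(dt)  # dt_ref is the median time step
    else:
        raise ValueError(f"unknown first: {first!r}")
    return rate


years = [2015, 2016, 2017]
cum_mm = [2.0, 5.0, 9.0]
rate = cumulative_to_rate_1d(years, cum_mm)
```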
- geoprior.utils.normalize_gwl_alias(df, gwl_col_user, *, prefer_depth_bgs=True, verbose=True)[source]
Normalize common GWL naming aliases.
Naming-only: unit conversion happens later.
- geoprior.utils.rate_to_cumulative(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), initial='first_equals_rate_dt', inplace=False)[source]
Build cumulative displacement from a rate series.
- geoprior.utils.resolve_gwl_for_physics(df, gwl_col_user, *, prefer_depth_bgs=True, allow_keep_zscore_as_ml=True, verbose=True)[source]
Pick meters GWL for physics; keep z-score ML-only.
- geoprior.utils.resolve_head_column(df, *, depth_col, head_col='head_m', z_surf_col=None, use_head_proxy=True)[source]
Resolve a head column, creating one if needed.
- geoprior.utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]
Build coords tensor (t, x, y) with OPTIONAL shifting (translation only).
- This is designed for your “not normalized” workflow:
You keep SI units (years and meters),
but you avoid feeding huge UTM magnitudes (e.g. 3e5, 2.5e6) into coord MLPs by shifting x,y (and optionally t).
Notes
This does NOT min-max scale to [0,1]. It only translates.
Returning coord_mins/coord_ranges is still useful for logging/debug.
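Translation-only shifting can be sketched with NumPy. This is a minimal illustration of the 'min' shift for 1-D inputs, assuming the function name shift_txy (hypothetical) and assuming that mins and ranges are returned alongside the coordinates for logging.

```python
import numpy as np


def shift_txy(t, x, y, dtype="float32"):
    """Stack (t, x, y) and subtract the per-column minimum (no rescaling)."""
    coords = np.stack(
        [np.asarray(t, dtype=float),
         np.asarray(x, dtype=float),
         np.asarray(y, dtype=float)],
        axis=-1,
    )
    mins = coords.min(axis=0)
    ranges = coords.max(axis=0) - mins
    # Subtract in float64 first, then cast, so large UTM values keep precision.
    shifted = (coords - mins).astype(dtype)
    return shifted, mins, ranges


t = [2015, 2016]
x = [300000.0, 300500.0]    # UTM easting in metres
y = [2500000.0, 2500100.0]  # UTM northing in metres
coords, coord_mins, coord_ranges = shift_txy(t, x, y)
```

After the shift the values are still in years and metres, only measured from the minimum of each axis.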
- geoprior.utils.compute_group_masks(df, *, group_cols, time_col, train_end_year, time_steps, horizon)[source]
- Build:
valid_for_train: groups containing all years for last (T+H)
valid_for_forecast: groups containing all years for last T
This assumes annual steps and integer years in time_col.
- geoprior.utils.split_groups_holdout(groups, *, seed=42, val_frac=0.2, test_frac=0.1, strategy='random', x_col=None, y_col=None, block_size=None)[source]
Split unique groups into train/val/test (pixel-level holdout).
- strategy:
“random”: shuffle groups
“spatial_block”: shuffle spatial blocks (needs x_col, y_col, block_size)
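The "random" strategy can be sketched with the standard library. This is an assumption-laden illustration: split_groups_random is a hypothetical name, groups are represented as a list of (x, y) tuples rather than a DataFrame, and the rounding of the validation and test counts is a guess at reasonable behaviour.

```python
import random


def split_groups_random(groups, seed=42, val_frac=0.2, test_frac=0.1):
    """Shuffle unique groups and cut them into train/val/test lists."""
    items = list(groups)
    random.Random(seed).shuffle(items)  # seeded, so the split is reproducible
    n_val = int(round(len(items) * val_frac))
    n_test = int(round(len(items) * test_frac))
    val = items[:n_val]
    test = items[n_val:n_val + n_test]
    train = items[n_val + n_test:]
    return train, val, test


pixels = [(113.0 + 0.01 * i, 22.5) for i in range(10)]
train, val, test = split_groups_random(pixels)
```

Splitting on groups rather than rows keeps every time step of a pixel in the same partition, which is the point of a pixel-level holdout.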
- geoprior.utils.filter_df_by_groups(df, *, group_cols, groups)[source]
Keep only rows in df whose (group_cols) exist in groups DataFrame.
- geoprior.utils.build_future_sequences_npz(df_scaled, *, time_col, time_col_num, lon_col, lat_col, time_steps, train_end_time=None, forecast_start_time=None, forecast_horizon=None, subs_col=None, gwl_col=None, h_field_col=None, static_features=None, dynamic_features=None, future_features=None, group_id_cols=None, mode=None, model_name=None, artifacts_dir=None, prefix='future', future_mode='auto', normalize_coords=False, coord_scaler=None, verbose=1, logger=None, stop_check=None, progress_hook=None, **kws)[source]
Build history–future sequences and save them as compressed NPZ files.
This helper constructs, for each spatial group, a sliding window of time_steps “history” points followed by a multi–step forecast horizon and exports the resulting NumPy arrays to disk. It is time-agnostic: the time_col can be numeric (e.g. year, index), year-like floats, datetimes, or strings, as long as equality on that column is meaningful.
If train_end_time, forecast_start_time, or forecast_horizon are not provided, they are inferred from the sorted unique values in df_scaled[time_col]:
train_end_time: by default the second-to-last unique time, leaving at least one future step.
forecast_start_time: by default the first time strictly after train_end_time.
forecast_horizon: by default one time step ahead, clipped to the number of available future points.
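The three inference rules can be sketched over a plain list of time values. The function name infer_time_defaults is hypothetical, and the sketch covers only the defaults described here; it does not reproduce the future_mode fallback logic.

```python
def infer_time_defaults(times, train_end_time=None,
                        forecast_start_time=None, forecast_horizon=None):
    """Fill in missing temporal settings from the sorted unique time values."""
    uniq = sorted(set(times))
    if train_end_time is None:
        if len(uniq) < 2:
            raise ValueError("need at least two unique times")
        train_end_time = uniq[-2]  # second-to-last, leaving one future step
    if forecast_start_time is None:
        later = [t for t in uniq if t > train_end_time]
        if not later:
            raise ValueError("no time value after train_end_time")
        forecast_start_time = later[0]
    if forecast_horizon is None:
        available = sum(1 for t in uniq if t >= forecast_start_time)
        forecast_horizon = min(1, available)  # default of 1, clipped
    return train_end_time, forecast_start_time, forecast_horizon


defaults = infer_time_defaults([2019, 2020, 2020, 2021, 2022])
```

Only sorting and comparison are used, so the same logic applies to integer years, datetimes, or sortable strings.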
For each valid group, the function builds history dynamic features of shape (time_steps, n_dynamic), future features of shape (time_steps + H, n_future) when mode starts with "tft" or (H, n_future) otherwise, one static feature vector of shape (n_static,), coordinates over the horizon of shape (H, 3) with columns [time_num, lon, lat], an H_field array of shape (H, 1), and optional subsidence and groundwater targets of shape (H, 1) each.
All per-group arrays are stacked along a new batch dimension and written as two NPZ files:
<prefix>_inputs.npz: coordinates, dynamic, static, future features and H field.
<prefix>_targets.npz: subsidence and groundwater targets.
- Parameters:
df_scaled (pandas.DataFrame) – Pre-processed (typically scaled) dataframe containing all required columns: time, spatial coordinates, static/dynamic/future features and optional targets.
time_col (str) – Name of the column encoding the temporal index (e.g. "year", "date", "t_index"). May be numeric, datetime, or string.
time_col_num (str or None) – Optional numeric time column used as a tie-breaker when multiple rows share the same time_col value. If provided and present in a group, the last row sorted by this column is selected for that time.
lon_col (str) – Name of the longitude (or x-coordinate) column.
lat_col (str) – Name of the latitude (or y-coordinate) column.
time_steps (int) – Length of the history window (number of time steps in the past). Must be strictly positive.
train_end_time (object, optional) – Effective end of the training period. If None, it is inferred as the second-to-last unique value in df_scaled[time_col] (after sorting).
forecast_start_time (object, optional) – First time step of the forecast horizon. If None, it is inferred as the first unique time strictly greater than train_end_time.
forecast_horizon (int, optional) – Number of future time steps to include. If None, a default horizon of 1 is used and clipped to the maximum number of available future time points.
subs_col (str, optional) – Name of the subsidence target column. If None or missing from a group, subsidence targets are filled with NaN.
gwl_col (str, optional) – Name of the groundwater-level target column. If None or missing from a group, groundwater targets are filled with NaN.
h_field_col (str, optional) – Name of the hydraulic-head field column used as an additional horizon-level input (H_field). If None or missing, a zero field is used.
static_features (list of str, optional) – Names of static (time-invariant) feature columns. Any names not present in the dataframe are silently ignored.
dynamic_features (list of str, optional) – Names of dynamic (history) feature columns used to build the (time_steps, n_dynamic) sequence. Missing columns are ignored.
future_features (list of str, optional) – Names of future covariate columns used to build the history+future or future-only sequence, depending on mode. Missing columns are ignored.
group_id_cols (list of str, optional) – Columns used to define spatial (or logical) groups, typically something like ["lon", "lat"] or a station identifier. If None or empty, the entire dataframe is treated as a single global group.
mode (str, optional) – Controls how future features are constructed. If the lower-cased value starts with "tft" (e.g. "tft_like"), future features are built on top of both history and future rows. Otherwise, only the forecast horizon rows are used.
model_name (str, optional) – Optional model identifier used only in logging messages.
artifacts_dir (str, optional) – Directory where NPZ files are written. If None or empty, the current working directory is used.
prefix (str, default "future") – Prefix for the output NPZ filenames: "<prefix>_inputs.npz" and "<prefix>_targets.npz".
future_mode ({'auto', 'pure-inference', 'pure-data-driven'}, default 'auto') – Strategy used to construct the future (forecast) portion of the sequences.
'pure-data-driven': Use only time points that actually exist in df_scaled strictly after the history window. All future time indices must be present in the data; otherwise a ValueError is raised. This corresponds to the original, strictly data-driven behaviour.
'pure-inference': Always synthesize future time points from the last history time, using the median positive time step (or 1.0 as a fallback). Future inputs are built by re-using the last available history row (for future_features, H_field, etc.), and future targets (e.g. subsidence, GWL) are filled with NaN since the true future is unknown. This mode does not require any rows beyond train_end_time.
'auto': Try data-driven mode first. If there are enough actual future time points after train_end_time to cover the requested forecast_horizon, behave like 'pure-data-driven'. If not, automatically fall back to the synthetic 'pure-inference' behaviour described above and emit an informational log message via vlog.
verbose (int, default 1) – Verbosity level forwarded to geoprior.utils.vlog(). A value >= 3 provides detailed progress logs (temporal inference, per-group status, dropped groups, etc.).
logger (logging.Logger or callable, optional) – Optional logger or logging function used by geoprior.utils.vlog(). If None, messages are printed to standard output.
**kws – Reserved for future extensions. Currently ignored.
normalize_coords (bool)
coord_scaler (Any | None)
- Returns:
A small dictionary with the absolute paths to the written NPZ files: {"future_inputs_npz": <path>, "future_targets_npz": <path>}.
- Return type:
- Raises:
ValueError – If there are not enough history points before train_end_time to satisfy time_steps, if no future points are available after forecast_start_time, or if all groups are dropped due to incomplete history/horizon windows.
Notes
Groups that do not contain all required history and future times are dropped without raising an error; the number of dropped groups is reported via geoprior.utils.vlog() when verbose > 0.
Examples
>>> from geoprior.nn.pinn.sequences import (
...     build_future_sequences_npz,
... )
>>> result = build_future_sequences_npz(
...     df_scaled=df_scaled,
...     time_col="year",
...     time_col_num="t_index",
...     lon_col="lon",
...     lat_col="lat",
...     time_steps=5,
...     # Let the function infer times/horizon:
...     train_end_time=None,
...     forecast_start_time=None,
...     forecast_horizon=None,
...     subs_col="subsidence",
...     gwl_col="gwl",
...     h_field_col="H_field",
...     static_features=["lithology_class"],
...     dynamic_features=["rainfall_mm", "GWL_depth_bgs_z"],
...     future_features=["normalized_urban_load_proxy"],
...     group_id_cols=["lon", "lat"],
...     mode="tft_like",
...     model_name="GeoPriorSubsNet",
...     artifacts_dir="results/zhongshan/future_npz",
...     prefix="zhongshan_future",
...     verbose=2,
... )
>>> result["future_inputs_npz"]
'results/zhongshan/future_npz/zhongshan_future_inputs.npz'
>>> result["future_targets_npz"]
'results/zhongshan/future_npz/zhongshan_future_targets.npz'
- geoprior.utils.threads_per_job(*, n_jobs, threads=0, reserve=1)[source]
- geoprior.utils.apply_tf_threading(*, intra, inter)[source]
- geoprior.utils.apply_thread_env(env, *, n_jobs, threads=0, reserve=1)[source]
- geoprior.utils.resolve_device(device, *, env=None)[source]
- geoprior.utils.resolve_gpu_ids(gpu_ids, *, env=None)[source]
- geoprior.utils.pick_gpu_id(idx, gpu_ids)[source]
- geoprior.utils.apply_gpu_env(env, *, gpu_id, allow_growth=True)[source]
- class geoprior.utils.ArtifactRecord(path, kind, payload, stage=None, city=None, model=None, meta=<factory>)[source]
Bases: object
Lightweight normalized artifact container.
- Parameters:
path (pathlib.Path) – Artifact path.
kind (str) – Inferred or explicit artifact kind.
payload (dict[str, Any]) – Loaded JSON payload.
meta (dict[str, Any]) – Extra extracted metadata.
- path: Path
- kind: str
- geoprior.utils.artifact_brief(record)[source]
Return a compact artifact header summary.
- geoprior.utils.as_path(path)[source]
Return path as resolved Path.
- geoprior.utils.bool_checks_frame(mapping, *, section=None)[source]
Convert boolean checks into a tidy DataFrame.
- geoprior.utils.clone_artifact(template, *, overrides=None)[source]
Clone a template payload and apply overrides.
This is useful for Sphinx-Gallery examples where we want a realistic artifact with a few controlled changes.
- geoprior.utils.deep_update(base, updates)[source]
Recursively update base with updates.
Returns a new dictionary.
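A recursive merge that returns a new dictionary can be sketched as follows. The name deep_update_sketch is hypothetical, and the use of copy.deepcopy to protect the inputs is an assumption about how "returns a new dictionary" is achieved.

```python
import copy


def deep_update_sketch(base, updates):
    """Merge updates into a copy of base, recursing into nested dicts."""
    out = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_update_sketch(out[key], value)
        else:
            out[key] = copy.deepcopy(value)  # replace scalars, lists, new keys
    return out


template = {"config": {"time_steps": 5, "mode": "tft_like"}, "city": "demo_city"}
merged = deep_update_sketch(template, {"config": {"time_steps": 7}})
```

Unlike dict.update, sibling keys inside a nested section survive the merge: only time_steps changes in the example, and mode is kept.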
- geoprior.utils.ensure_parent_dir(path)[source]
Create parent directory for path.
- geoprior.utils.flatten_dict(mapping, *, parent_key='', sep='.')[source]
Flatten nested dictionaries.
Non-dict values are kept as they are. Lists and arrays are not expanded.
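The flattening rule can be sketched in plain Python. The name flatten_dict_sketch is hypothetical; the keyword-only parent_key and sep arguments mirror the documented signature.

```python
def flatten_dict_sketch(mapping, *, parent_key="", sep="."):
    """Flatten nested dicts into one level with joined keys."""
    flat = {}
    for key, value in mapping.items():
        name = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            flat.update(flatten_dict_sketch(value, parent_key=name, sep=sep))
        else:
            flat[name] = value  # lists and arrays are kept as-is
    return flat


nested = {"config": {"scaling": {"coord_ranges": [1.0, 2.0]}, "mode": "tft_like"}, "n": 3}
flat = flatten_dict_sketch(nested)
```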
- geoprior.utils.infer_artifact_kind(path, payload=None)[source]
Infer artifact kind from file name and keys.
The rules are intentionally simple and stable. Artifact-specific readers can still override the inferred kind if needed.
- geoprior.utils.is_number(value)[source]
Return True for finite or non-finite scalars.
- geoprior.utils.json_ready(value)[source]
Convert nested values into JSON-safe objects.
Notes
NaN and Inf are converted to None.
numpy scalars are converted to Python scalars.
arrays become lists.
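The conversion rules listed in the notes can be sketched with the standard library. The name json_ready_sketch is hypothetical; NumPy objects are handled through their tolist() method so that the sketch itself does not need to import NumPy, which is an implementation assumption.

```python
import json
import math


def json_ready_sketch(value):
    """Recursively convert a value into something json.dumps accepts."""
    if isinstance(value, dict):
        return {str(k): json_ready_sketch(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [json_ready_sketch(v) for v in value]
    if isinstance(value, float) and not math.isfinite(value):
        return None  # NaN and Inf have no strict-JSON representation
    if hasattr(value, "tolist"):
        # NumPy arrays and scalars expose tolist(); recurse on the result.
        return json_ready_sketch(value.tolist())
    return value


cleaned = json_ready_sketch({"loss": float("nan"), "ci": (0.1, float("inf")), "n": 3})
text = json.dumps(cleaned)
```

Mapping non-finite values to None matters because json.dumps would otherwise emit the tokens NaN and Infinity, which strict JSON parsers reject.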
- geoprior.utils.load_artifact(path, *, kind=None)[source]
Load a JSON artifact into ArtifactRecord.
- geoprior.utils.metrics_frame(mapping, *, section=None, sort=True)[source]
Convert scalar metrics into a tidy DataFrame.
- geoprior.utils.nested_get(mapping, *keys, default=None)[source]
Safely traverse nested dictionaries.
Examples
nested_get(d, "config", "scaling_kwargs")
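Safe traversal can be sketched in a few lines. The name nested_get_sketch is hypothetical, and returning the default as soon as a level is missing or is not a dictionary is an assumption about the intended behaviour.

```python
def nested_get_sketch(mapping, *keys, default=None):
    """Follow keys through nested dicts, returning default on any miss."""
    current = mapping
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current


d = {"config": {"scaling_kwargs": {"time_units": "year"}}}
scaling = nested_get_sketch(d, "config", "scaling_kwargs")
missing = nested_get_sketch(d, "config", "quantiles", "levels", default=[])
```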
- geoprior.utils.numeric_items(mapping, *, drop_bools=True)[source]
Extract numeric scalar items from a mapping.
- geoprior.utils.plot_boolean_checks(ax, checks=None, *, title='Checks', show_grid=True, grid_kws=None, ax_obj=None, error='ignore', **plot_kws)[source]
Plot boolean pass/fail checks as a bar view.
Keeps the older (ax, checks) call pattern while allowing (checks, ax_obj=ax) for newer code.
- geoprior.utils.plot_metric_bars(ax, metrics=None, *, title='Metrics', top_n=None, sort_by_value=False, absolute=False, annotate=True, xlabel='value', show_grid=True, grid_kws=None, annotate_kws=None, ax_obj=None, error='ignore', **plot_kws)[source]
Plot a compact horizontal metric bar chart.
The legacy calling style plot_metric_bars(ax, metrics, ...) is preserved. A newer style plot_metric_bars(metrics, ax_obj=ax, ...) is also accepted for gradual migration.
- geoprior.utils.plot_series_map(ax, series_map=None, *, title='Series', xlabel='key', ylabel='value', marker='o', show_grid=True, grid_kws=None, ax_obj=None, error='ignore', **plot_kws)[source]
Plot a string-keyed numeric mapping as a line.
Keeps the older (ax, series_map) form while also accepting (series_map, ax_obj=ax).
- geoprior.utils.read_json(path)[source]
Read a JSON file into a dictionary.
- geoprior.utils.write_json(payload, path, *, indent=2, sort_keys=False)[source]
Write payload as UTF-8 JSON.
- geoprior.utils.build_stage1_feature_split(*, dynamic_features, static_features, future_features, scaled_ml_numeric_cols)[source]
Build the Stage-1 feature split section.
- Parameters:
- Returns:
Stage-1 style feature split mapping.
- Return type:
- geoprior.utils.default_stage1_audit_payload(*, city='demo_city', model='GeoPriorSubsNet', normalize_coords=True, keep_coords_raw=None, shift_raw_coords=True, coords_in_degrees=False, coord_mode='degrees', coord_epsg_used=32649, coord_ranges=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, scaler_path='artifacts/main_scaler.joblib', n_sequences=1200, forecast_horizon=3, time_steps=5)[source]
Build a realistic default Stage-1 audit payload.
The payload is template-based. It is not meant to reproduce the full preprocessing mathematics. Instead, it creates a stable and inspectable audit artifact with the same broad structure as the real Stage-1 audit file.
- geoprior.utils.generate_stage1_audit(*, output_path=None, template=None, overrides=None, **kwargs)[source]
Generate a Stage-1 audit payload or file.
- Parameters:
output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.
template (mapping, ArtifactRecord, or path, optional) – Real or synthetic Stage-1 audit template used as the generation base.
overrides (dict, optional) – Nested overrides applied after template/default payload creation.
**kwargs (dict) – Parameters forwarded to default_stage1_audit_payload when no template is given.
- Return type:
- geoprior.utils.inspect_stage1_audit(audit, *, output_dir=None, stem='stage1_audit', save_figures=True)[source]
Inspect a Stage-1 audit and optionally save figures.
- geoprior.utils.load_stage1_audit(path)[source]
Load a Stage-1 audit artifact.
- Raises:
ValueError – If the artifact does not look like a Stage-1 audit payload.
- Parameters:
- Return type:
ArtifactRecord
- geoprior.utils.plot_stage1_boolean_summary(audit, *, ax=None, title='Stage-1 audit checks', error='ignore', **plot_kws)[source]
Plot semantic pass/fail checks.
- geoprior.utils.plot_stage1_coord_ranges(audit, *, ax=None, title='Stage-1 coord ranges', ylabel='range', show_grid=True, grid_kws=None, annotate=True, annotate_kws=None, error='ignore', **plot_kws)[source]
Plot chain-rule coordinate ranges.
- geoprior.utils.plot_stage1_feature_split(audit, *, ax=None, title='Stage-1 feature split', xlabel='feature group', ylabel='features', show_grid=True, grid_kws=None, legend_kws=None, error='ignore', **plot_kws)[source]
Plot feature bucket counts by feature group.
- geoprior.utils.plot_stage1_target_stats(audit, *, stat='mean', ax=None, title=None)[source]
Plot target summary statistics.
- geoprior.utils.plot_stage1_variable_stats(audit, *, section='physics_df_stats', stat='mean', ax=None, title=None, error='ignore', **plot_kws)[source]
Plot one statistic across variables.
- geoprior.utils.stage1_coord_ranges_frame(audit)[source]
Return coord ranges as a tidy frame.
- geoprior.utils.stage1_feature_split_frame(audit)[source]
Explode the feature split into tidy rows.
- geoprior.utils.stage1_stats_frame(audit, *, section='physics_df_stats')[source]
Return a tidy frame for nested variable stats.
- geoprior.utils.summarize_stage1_audit(audit)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_stage2_handshake_payload(*, city='demo_city', model='GeoPriorSubsNet', time_steps=5, forecast_horizon=3, mode='tft_like', n_train=1200, n_val=320, dynamic_dim=5, future_dim=1, static_dim=12, coords_normalized=True, coord_ranges=None, coords_in_degrees=False, coord_epsg_used=32649, time_units='year', gwl_dyn_name='GWL_depth_bgs_m__si', gwl_dyn_index=0, subs_dyn_index=1, z_surf_static_index=11, use_head_proxy=True, q_kind='per_volume', drainage_mode='double')[source]
Build a realistic default Stage-2 handshake payload.
The payload is template-based. It is not meant to rerun Stage-2 logic. Instead, it creates a stable and inspectable audit artifact with the same broad structure as the real Stage-2 handshake file.
- Parameters:
city (str)
model (str)
time_steps (int)
forecast_horizon (int)
mode (str)
n_train (int)
n_val (int)
dynamic_dim (int)
future_dim (int)
static_dim (int)
coords_normalized (bool)
coords_in_degrees (bool)
coord_epsg_used (int | None)
time_units (str)
gwl_dyn_name (str)
gwl_dyn_index (int)
subs_dyn_index (int)
z_surf_static_index (int)
use_head_proxy (bool)
q_kind (str)
drainage_mode (str)
- Return type:
- geoprior.utils.generate_stage2_handshake(*, output_path=None, template=None, overrides=None, **kwargs)[source]
Generate a Stage-2 handshake payload or file.
- Parameters:
output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.
template (mapping, ArtifactRecord, or path, optional) – Real or synthetic Stage-2 handshake template used as the generation base.
overrides (dict, optional) – Nested overrides applied after template/default payload creation.
**kwargs (dict) – Parameters forwarded to default_stage2_handshake_payload when no template is given.
- Return type:
- geoprior.utils.inspect_stage2_handshake(audit, *, output_dir=None, stem='stage2_handshake', save_figures=True)[source]
Inspect a Stage-2 handshake and optionally save figures.
- geoprior.utils.load_stage2_handshake(path)[source]
Load a Stage-2 handshake artifact.
- Raises:
ValueError – If the artifact does not look like a Stage-2 handshake payload.
- Parameters:
- Return type:
ArtifactRecord
- geoprior.utils.plot_stage2_boolean_summary(audit, *, ax=None, title='Stage-2 handshake checks', error='ignore', **plot_kws)[source]
Plot semantic pass/fail checks.
- geoprior.utils.plot_stage2_coord_range_errors(audit, *, ax=None, title='Stage-2 coord range relative errors', ylabel='relative error', show_grid=True, grid_kws=None, annotate=True, annotate_kws=None, error='ignore', **plot_kws)[source]
Plot coord range relative errors.
- geoprior.utils.plot_stage2_coord_stats(audit, *, section='coord_stats_norm', stat='mean', ax=None, title=None, error='ignore', **plot_kws)[source]
Plot one coord statistic across axes.
- geoprior.utils.plot_stage2_finite_ratios(audit, *, ax=None, title='Stage-2 finite ratios', error='ignore', **plot_kws)[source]
Plot finite-ratio metrics.
- geoprior.utils.plot_stage2_sample_sizes(audit, *, ax=None, title='Stage-2 sample sizes', error='ignore', **plot_kws)[source]
Plot training and validation counts.
- geoprior.utils.plot_stage2_scaling_summary(audit, *, ax=None, title='Stage-2 scaling summary', top_n=12, error='ignore', **plot_kws)[source]
Plot numeric items from sk_summary.
- geoprior.utils.stage2_coord_range_frame(audit)[source]
Return coord range spans and relative errors.
- geoprior.utils.stage2_coord_stats_frame(audit, *, section='coord_stats_norm')[source]
Return a tidy frame for coord stat blocks.
- geoprior.utils.stage2_finite_frame(audit)[source]
Return finite-ratio checks as a tidy frame.
- geoprior.utils.stage2_layout_frame(audit)[source]
Return expected vs observed layout rows.
- geoprior.utils.stage2_scaling_frame(audit)[source]
Return a tidy frame for the compact scaling summary.
- geoprior.utils.summarize_stage2_handshake(audit)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_training_summary_payload(*, city='demo_city', model='GeoPriorSubsNet', horizon=3, best_epoch=17, timestamp='20260222-211635', optimizer='Adam', learning_rate=0.001, time_steps=5, pde_mode='on', offset_mode='mul', quantiles=None, attention_levels=None, coords_normalized=True, coord_ranges=None, run_dir='results/demo_run/train_20260222-211635')[source]
Build a realistic default training-summary payload.
The payload is template-based. It is not meant to reproduce the full training loop. Instead, it creates a stable and inspectable summary artifact with the same broad structure as the real training summary.
- geoprior.utils.generate_training_summary(*, output_path=None, template=None, overrides=None, **kwargs)[source]
Generate a training-summary payload or file.
- Parameters:
output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.
template (mapping, ArtifactRecord, or path, optional) – Real or synthetic training-summary template used as the generation base.
overrides (dict, optional) – Nested overrides applied after template/default payload creation.
**kwargs (dict) – Parameters forwarded to default_training_summary_payload when no template is given.
- Return type:
- geoprior.utils.inspect_training_summary(summary, *, output_dir=None, stem='training_summary', save_figures=True)[source]
Inspect a training summary and optionally save figures.
- geoprior.utils.load_training_summary(path)[source]
Load a training-summary artifact.
- Raises:
ValueError – If the artifact does not look like a training summary payload.
- Parameters:
- Return type:
ArtifactRecord
- geoprior.utils.plot_training_best_metrics(summary, *, split='val', keys=None, ax=None, title=None, error='ignore', **plot_kws)[source]
Plot selected metrics from metrics_at_best.
- geoprior.utils.plot_training_boolean_summary(summary, *, ax=None, title='Training summary checks', error='ignore', **plot_kws)[source]
Plot semantic pass/fail checks.
- geoprior.utils.plot_training_final_metrics(summary, *, split='val', keys=None, ax=None, title=None, error='ignore', **plot_kws)[source]
Plot selected metrics from final_epoch_metrics.
- geoprior.utils.plot_training_loss_family(summary, *, section='metrics_at_best', split='val', ax=None, title=None, error='ignore', **plot_kws)[source]
Plot the loss-family subset for one metric section.
- geoprior.utils.plot_training_metric_deltas(summary, *, split='val', keys=None, ax=None, title=None, xlabel='delta', show_grid=True, grid_kws=None, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]
Plot final - best deltas for aligned metrics.
- Parameters:
- Return type:
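As an illustration of what "final - best deltas for aligned metrics" means, here is a small sketch that subtracts the best-epoch value from the final-epoch value for every metric key present in both sections. It assumes each section is a flat mapping of metric name to number. metric_deltas and the metric names are hypothetical; only the section names metrics_at_best and final_epoch_metrics come from the entries above.

```python
def metric_deltas(best, final):
    """Return ``final - best`` for keys that are numeric in both mappings."""
    deltas = {}
    for key in sorted(best.keys() & final.keys()):
        b, f = best[key], final[key]
        if isinstance(b, (int, float)) and isinstance(f, (int, float)):
            deltas[key] = f - b
    return deltas


# Hypothetical metric names and values, for illustration only.
metrics_at_best = {"val_loss": 0.25, "val_mae": 0.5}
final_epoch_metrics = {"val_loss": 0.5, "val_mae": 0.75, "val_extra": 1.0}

deltas = metric_deltas(metrics_at_best, final_epoch_metrics)
print(deltas)
```

A positive delta on a loss-like metric indicates that the final epoch is worse than the best epoch. Keys that appear in only one section (val_extra here) are skipped because they cannot be aligned.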
- geoprior.utils.training_compile_frame(summary)[source]
Return a tidy frame for compile settings.
- geoprior.utils.training_env_frame(summary)[source]
Return a tidy frame for environment info.
- geoprior.utils.training_hp_frame(summary)[source]
Return a tidy frame for hp/init settings.
- geoprior.utils.training_metrics_frame(summary, *, section='metrics_at_best', split='all')[source]
Return a tidy frame for train/validation metrics.
- geoprior.utils.training_paths_frame(summary)[source]
Return a tidy frame for output paths.
- geoprior.utils.summarize_training_summary(summary)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_eval_diagnostics_payload(*, years=None, per_horizon_mae=None, per_horizon_mse=None, per_horizon_rmse=None, per_horizon_r2=None, coverage80=None, sharpness80=None, pss=None)[source]
Build a realistic default eval-diagnostics payload.
The payload is template-based. It is not meant to reproduce the full evaluation pipeline. Instead, it creates a stable and inspectable diagnostics artifact with the same broad structure as the real compact eval_diagnostics JSON.
- Parameters:
- Return type:
- geoprior.utils.eval_overall_frame(diagnostics)[source]
Return a compact frame for the __overall__ block.
- geoprior.utils.eval_per_horizon_frame(diagnostics)[source]
Return a tidy per-horizon metrics frame.
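To make "tidy per-horizon metrics frame" concrete, the sketch below flattens a mapping of per-horizon metric maps into one row per (horizon, metric) pair. It uses plain dictionaries rather than a DataFrame so that it runs without third-party packages. per_horizon_rows, the "H1"/"H2" horizon labels, and the values are assumptions for illustration and do not describe the actual payload layout.

```python
def per_horizon_rows(metric_maps):
    """Flatten {metric: {horizon: value}} into tidy row dictionaries."""
    rows = []
    for metric, by_horizon in metric_maps.items():
        for horizon, value in by_horizon.items():
            rows.append({"horizon": horizon, "metric": metric, "value": value})
    # Sort so that all metrics for one horizon sit next to each other.
    rows.sort(key=lambda row: (row["horizon"], row["metric"]))
    return rows


# Hypothetical per-horizon maps, for illustration only.
per_horizon = {
    "mae": {"H1": 1.0, "H2": 1.5},
    "rmse": {"H1": 2.0, "H2": 2.5},
}

rows = per_horizon_rows(per_horizon)
for row in rows:
    print(row)
```

Rows in this long format can be passed to a DataFrame constructor or filtered by metric without reshaping.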
- geoprior.utils.eval_years_frame(diagnostics)[source]
Return one row per year block.
- geoprior.utils.generate_eval_diagnostics(*, output_path=None, template=None, overrides=None, **kwargs)[source]
Generate an eval-diagnostics payload or file.
- Parameters:
output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.
template (mapping, ArtifactRecord, or path, optional) – Real or synthetic diagnostics template used as the generation base.
overrides (dict, optional) – Nested overrides applied after template/default payload creation.
**kwargs (dict) – Parameters forwarded to default_eval_diagnostics_payload when no template is given.
- Return type:
- geoprior.utils.inspect_eval_diagnostics(diagnostics, *, output_dir=None, stem='eval_diagnostics', save_figures=True)[source]
Inspect eval diagnostics and optionally save figures.
- geoprior.utils.load_eval_diagnostics(path)[source]
Load an eval-diagnostics artifact.
- Raises:
ValueError – If the artifact does not look like a compact evaluation-diagnostics payload.
- Parameters:
- Return type:
ArtifactRecord
- geoprior.utils.plot_eval_boolean_summary(diagnostics, *, ax=None, title='Evaluation diagnostics checks', error='ignore', **plot_kws)[source]
Plot semantic pass/fail checks.
- geoprior.utils.plot_eval_overall_metrics(diagnostics, *, keys=None, ax=None, title=None, error='ignore', **plot_kws)[source]
Plot selected top-level metrics from __overall__.
- geoprior.utils.plot_eval_per_horizon_metrics(diagnostics, *, metric='rmse', ax=None, title=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]
Plot one per-horizon metric from __overall__.
- geoprior.utils.plot_eval_year_metric_trend(diagnostics, *, metric='overall_mae', ax=None, title=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]
Plot one metric across year blocks.
- geoprior.utils.summarize_eval_diagnostics(diagnostics)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_eval_physics_payload(*, timestamp='20260222-215049', city='demo_city', model='GeoPriorSubsNet', quantiles=None, horizon=3, batch_size=32, subs_unit='mm', time_units='year')[source]
Build a realistic default eval-physics payload.
The payload is template-based. It is not meant to reproduce the Stage-2 evaluation path. Instead, it creates a stable and inspectable artifact with the same broad structure as the interpretable physics evaluation JSON.
- geoprior.utils.eval_physics_calibration_frame(payload)[source]
Return a tidy frame for top-level calibration scalars.
- geoprior.utils.eval_physics_calibration_per_horizon_frame(payload)[source]
Return a tidy per-horizon calibration frame.
- geoprior.utils.eval_physics_censor_frame(payload)[source]
Return a tidy frame for censor-aware metrics.
- geoprior.utils.eval_physics_metrics_frame(payload)[source]
Return a tidy frame for metrics_evaluate.
- geoprior.utils.eval_physics_per_horizon_frame(payload)[source]
Return a tidy frame for exported per-horizon metrics.
- geoprior.utils.eval_physics_point_metrics_frame(payload)[source]
Return a tidy frame for point metrics.
- geoprior.utils.eval_physics_units_frame(payload)[source]
Return a tidy frame for units metadata.
- geoprior.utils.generate_eval_physics(*, output_path=None, template=None, overrides=None, **kwargs)[source]
Generate an eval-physics payload or file.
- Parameters:
output_path (path-like, optional) – Destination JSON path. If omitted, the payload is returned instead of written.
template (mapping, ArtifactRecord, or path, optional) – Real or synthetic eval-physics template used as the generation base.
overrides (dict, optional) – Nested overrides applied after template/default payload creation.
**kwargs (dict) – Parameters forwarded to default_eval_physics_payload when no template is given.
- Return type:
- geoprior.utils.inspect_eval_physics(payload, *, output_dir=None, stem='eval_physics', save_figures=True)[source]
Inspect an eval-physics artifact and optionally save figures.
- geoprior.utils.load_eval_physics(path)[source]
Load an eval-physics artifact.
- Raises:
ValueError – If the artifact does not look like an eval-physics payload.
- Parameters:
- Return type:
ArtifactRecord
- geoprior.utils.plot_eval_physics_boolean_summary(payload, *, ax=None, title='Eval physics checks', error='ignore', **plot_kws)[source]
Plot semantic pass/fail checks.
- geoprior.utils.plot_eval_physics_calibration_factors(payload, *, source='top', ax=None, title=None, error='ignore', **plot_kws)[source]
Plot per-horizon calibration factors.
- geoprior.utils.plot_eval_physics_epsilons(payload, *, ax=None, title='Eval physics: epsilon diagnostics', error='ignore', **plot_kws)[source]
Plot epsilon-related diagnostics.
- geoprior.utils.plot_eval_physics_metrics(payload, *, keys=None, ax=None, title=None, error='ignore', **plot_kws)[source]
Plot selected metrics_evaluate values.
- geoprior.utils.plot_eval_physics_per_horizon_metrics(payload, *, metric='mae', ax=None, title=None, error='ignore', **plot_kws)[source]
Plot one exported per-horizon metric map.
- geoprior.utils.plot_eval_physics_point_metrics(payload, *, ax=None, title='Eval physics: point metrics', error='ignore', **plot_kws)[source]
Plot point-metric summary.
- geoprior.utils.summarize_eval_physics(payload)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_physics_payload_meta_payload(*, city='demo_city', model_name='GeoPriorSubsNet', split='TestSet', created_utc='2026-02-22T13:50:48Z', saved_utc=None, pde_modes_active=None, time_units='year', time_coord_units='year', gwl_kind='depth_bgs', gwl_sign='down_positive', use_head_proxy=True, kappa_mode='kb', use_effective_h=True, hd_factor=0.6, lambda_offset=1.0, kappa_value=0.046392478, tau_prior_definition='tau_closure_from_learned_fields', tau_prior_human_name='tau_closure', tau_prior_source='model.evaluate_physics() closure head', tau_closure_formula='Hd^2 * Ss / (pi^2 * kappa_b * K)', units=None, payload_metrics=None)[source]
Build a realistic default physics-payload meta payload.
The payload is intentionally lightweight and template-based. It mirrors the structure of the real *.npz.meta.json sidecar without trying to reconstruct the full physics payload archive.
- Parameters:
city (str)
model_name (str)
split (str)
created_utc (str)
saved_utc (str | None)
time_units (str)
time_coord_units (str)
gwl_kind (str)
gwl_sign (str)
use_head_proxy (bool)
kappa_mode (str)
use_effective_h (bool)
hd_factor (float)
lambda_offset (float)
kappa_value (float)
tau_prior_definition (str)
tau_prior_human_name (str)
tau_prior_source (str)
tau_closure_formula (str)
- Return type:
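The default tau_closure_formula string, 'Hd^2 * Ss / (pi^2 * kappa_b * K)', can be evaluated directly. The sketch below does so with the standard library. tau_closure is a hypothetical helper, the numeric inputs are invented, and the unit reading (Hd as a length, Ss as inverse length, K as length per time, kappa_b dimensionless) is an assumption rather than something this page states.

```python
import math


def tau_closure(hd, ss, kappa_b, k):
    """Evaluate Hd^2 * Ss / (pi^2 * kappa_b * K) for scalar inputs."""
    if kappa_b <= 0 or k <= 0:
        raise ValueError("kappa_b and K must be positive")
    return hd ** 2 * ss / (math.pi ** 2 * kappa_b * k)


# Hypothetical inputs in a consistent unit system (metres and years).
tau = tau_closure(hd=12.0, ss=1e-4, kappa_b=1.0, k=0.05)
print(f"tau = {tau:.4f} (same time unit as K)")

# The quadratic dependence on Hd: doubling Hd quadruples tau.
ratio = tau_closure(24.0, 1e-4, 1.0, 0.05) / tau
print(f"ratio after doubling Hd = {ratio:.1f}")
```

Under the assumed units, the result carries the time unit of K, so a K expressed in metres per year yields tau in years.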
- geoprior.utils.generate_physics_payload_meta(path, *, template=None, overrides=None, **kwargs)[source]
Generate and save a physics-payload meta artifact.
- Parameters:
path (path-like) – Output JSON path.
template (mapping, artifact, or path, optional) – Template payload to clone. When omitted, the function starts from default_physics_payload_meta_payload.
overrides (dict, optional) – Deep updates applied after loading the template.
**kwargs (Any) – Forwarded to the default payload builder when no template is provided.
- Return type:
- geoprior.utils.inspect_physics_payload_meta(payload)[source]
Return a structured inspection bundle.
- geoprior.utils.load_physics_payload_meta(path)[source]
Load a physics-payload meta artifact.
- geoprior.utils.physics_payload_meta_closure_frame(payload)[source]
Return a tidy closure / tau-prior frame.
- geoprior.utils.physics_payload_meta_identity_frame(payload)[source]
Return a tidy identity / convention frame.
- geoprior.utils.physics_payload_meta_metrics_frame(payload)[source]
Return compact payload metrics as a tidy frame.
- geoprior.utils.physics_payload_meta_units_frame(payload)[source]
Return the units block as a tidy frame.
- geoprior.utils.plot_physics_payload_meta_boolean_summary(payload, *, ax=None, title='Physics payload meta: checks', **plot_kws)[source]
Plot boolean inspection checks.
- geoprior.utils.plot_physics_payload_meta_core_scalars(payload, *, ax=None, title='Physics payload meta: core scalars', **plot_kws)[source]
Plot selected top-level numeric metadata values.
- geoprior.utils.plot_physics_payload_meta_payload_metrics(payload, *, ax=None, title='Physics payload meta: payload metrics', **plot_kws)[source]
Plot compact payload-level metrics.
- geoprior.utils.summarize_physics_payload_meta(payload)[source]
Build a compact semantic summary for inspection.
- geoprior.utils.default_scaling_kwargs_payload(*, time_units='year', coords_normalized=True, coords_in_degrees=False, coord_mode='degrees', coord_order=None, coord_ranges=None, bounds=None, dynamic_feature_names=None, future_feature_names=None, static_feature_names=None, gwl_dyn_name='GWL_depth_bgs_m__si', gwl_dyn_index=0, subs_dyn_index=1, z_surf_static_index=11, gwl_kind='depth_bgs', gwl_sign='down_positive', use_head_proxy=True, q_kind='per_volume', mv_prior_mode='calibrate', mv_prior_units='strict')[source]
Build a realistic default scaling-kwargs payload.
The payload is intentionally template-based. It mirrors the resolved scaling_kwargs.json structure written by Stage-2 and preserved in manifests, while staying lightweight enough for documentation examples.
- Parameters:
- Return type:
- geoprior.utils.generate_scaling_kwargs(path, *, template=None, overrides=None, **kwargs)[source]
Generate a scaling-kwargs JSON artifact.
- Parameters:
path (path-like) – Output path for the JSON file.
template (mapping or path-like, optional) – Existing payload used as a template.
overrides (dict, optional) – Nested override values applied after the template.
**kwargs (Any) – Convenience keyword overrides forwarded into the default payload builder when no explicit template is provided.
- Returns:
Written JSON path.
- Return type:
- geoprior.utils.inspect_scaling_kwargs(payload)[source]
Return a compact multi-view inspection bundle.
- geoprior.utils.load_scaling_kwargs(path)[source]
Load a scaling-kwargs artifact.
- geoprior.utils.plot_scaling_kwargs_affine_maps(payload, *, ax=None, title='Scaling affine maps', error='ignore', **plot_kws)[source]
Plot subs/head/H affine map scalars.
- geoprior.utils.plot_scaling_kwargs_boolean_summary(payload, *, ax=None, title='Scaling boolean checks', error='ignore', **plot_kws)[source]
Plot common boolean config flags.
- geoprior.utils.plot_scaling_kwargs_bounds(payload, *, ax=None, title='Bounds overview', top_n=None, error='ignore', **plot_kws)[source]
Plot numeric bounds as a compact bar chart.
- geoprior.utils.plot_scaling_kwargs_coord_ranges(payload, *, ax=None, title='Coordinate ranges', error='ignore', **plot_kws)[source]
Plot coord_ranges for t/x/y.
- geoprior.utils.plot_scaling_kwargs_feature_group_sizes(payload, *, ax=None, title='Feature group sizes', error='ignore', **plot_kws)[source]
Plot dynamic/future/static feature-group counts.
- geoprior.utils.plot_scaling_kwargs_schedule_scalars(payload, *, ax=None, title='Schedule and runtime scalars', error='ignore', **plot_kws)[source]
Plot selected numeric schedule/runtime scalars.
- geoprior.utils.scaling_kwargs_affine_frame(payload)[source]
Return affine SI-map rows.
- geoprior.utils.scaling_kwargs_bounds_frame(payload)[source]
Return tidy bounds rows.
- geoprior.utils.scaling_kwargs_coord_frame(payload)[source]
Return coordinate and convention rows.
- geoprior.utils.scaling_kwargs_feature_channels_frame(payload)[source]
Return feature-group and channel rows.
- geoprior.utils.scaling_kwargs_schedule_frame(payload)[source]
Return Q/MV schedule and runtime scalar rows.
- geoprior.utils.summarize_scaling_kwargs(payload)[source]
Return a compact high-level scaling summary.
- geoprior.utils.default_model_init_manifest_payload(*, model_class='GeoPriorSubsNet', forecast_horizon=3, static_input_dim=12, dynamic_input_dim=5, future_input_dim=1, output_subsidence_dim=1, output_gwl_dim=1, quantiles=None, mode='tft_like', pde_mode='on', identifiability_regime=None, time_units='year')[source]
Build a realistic default model-init manifest.
The payload mirrors the saved structure used by the model-init manifest family: config + dims + model_class.
- Parameters:
- Return type:
- geoprior.utils.generate_model_init_manifest(path, *, template=None, overrides=None)[source]
Generate a model-init manifest JSON file.
- Parameters:
path (str or pathlib.Path) – Output JSON path.
template (mapping, optional) – Base payload. If omitted, uses default_model_init_manifest_payload().
overrides (mapping, optional) – Nested overrides applied on top of the template.
- Return type:
- geoprior.utils.inspect_model_init_manifest(manifest, *, output_dir=None, stem='model_init_manifest', save_figures=True)[source]
Inspect a model-init manifest and optionally save figures.
- geoprior.utils.load_model_init_manifest(path)[source]
Load a model-init manifest as ArtifactRecord.
- geoprior.utils.model_init_architecture_frame(manifest)[source]
Return a frame for architecture choices.
- geoprior.utils.model_init_dims_frame(manifest)[source]
Return a tidy frame for input/output dimensions.
- geoprior.utils.model_init_feature_groups_frame(manifest)[source]
Return a tidy frame for nested feature-name groups.
- geoprior.utils.model_init_geoprior_frame(manifest)[source]
Return a frame for GeoPrior-specific init settings.
- geoprior.utils.model_init_scaling_overview_frame(manifest)[source]
Return a compact overview of resolved scaling kwargs.
- geoprior.utils.plot_model_init_architecture(manifest, *, ax=None, title='Architecture scalars', **plot_kws)[source]
Plot key architecture scalars.
- geoprior.utils.plot_model_init_boolean_summary(manifest, *, ax=None, title='Model-init checks', **plot_kws)[source]
Plot compact initialization checks as booleans.
- geoprior.utils.plot_model_init_dims(manifest, *, ax=None, title='Model-init dimensions', **plot_kws)[source]
Plot input/output dimensions.
- geoprior.utils.plot_model_init_feature_group_sizes(manifest, *, ax=None, title='Feature-group sizes', **plot_kws)[source]
Plot the sizes of static/dynamic/future feature groups.
- geoprior.utils.plot_model_init_geoprior(manifest, *, ax=None, title='GeoPrior initialization', **plot_kws)[source]
Plot key GeoPrior physics-init scalars.
- geoprior.utils.summarize_model_init_manifest(manifest)[source]
Return a compact summary of a model-init manifest.
The goal is not to flatten every nested key, but to expose the most decision-relevant initialization facts.
- geoprior.utils.default_run_manifest_payload(*, city='nansha', model='GeoPriorSubsNet', stage='stage-2-train', time_steps=5, forecast_horizon_years=3, mode='tft_like', pde_mode_config='on', quantiles=None, identifiability_regime=None)[source]
Build a realistic default run-manifest payload.
The payload mirrors the lightweight Stage-2 run manifest saved after training: stage + city + model + config + paths + artifacts.
- geoprior.utils.generate_run_manifest(path, *, template=None, overrides=None)[source]
Generate a run-manifest JSON artifact.
- Parameters:
path (str or pathlib.Path) – Output JSON path.
template (mapping, optional) – Base payload. If omitted, uses default_run_manifest_payload().
overrides (mapping, optional) – Nested overrides applied on top of the template.
- Return type:
- geoprior.utils.inspect_run_manifest(manifest)[source]
Return a bundle of useful inspection outputs.
- geoprior.utils.load_run_manifest(path)[source]
Load a run-manifest as ArtifactRecord.
- geoprior.utils.plot_run_manifest_boolean_summary(ax, manifest, *, title='Run-manifest checks', **plot_kws)[source]
Plot simple boolean checks for expected run outputs.
- geoprior.utils.plot_run_manifest_coord_ranges(ax, manifest, *, title='Coordinate ranges', **plot_kws)[source]
Plot coordinate ranges from nested scaling kwargs.
- geoprior.utils.plot_run_manifest_feature_group_sizes(ax, manifest, *, title='Feature-group sizes', **plot_kws)[source]
Plot feature-group sizes from nested scaling kwargs.
- geoprior.utils.plot_run_manifest_path_inventory(ax, manifest, *, title='Run-manifest inventory', **plot_kws)[source]
Plot path and artifact counts for the run manifest.
- geoprior.utils.run_manifest_artifacts_frame(manifest)[source]
Return a frame describing direct artifact pointers.
- geoprior.utils.run_manifest_config_frame(manifest)[source]
Return a tidy frame for the lightweight config block.
- geoprior.utils.run_manifest_identity_frame(manifest)[source]
Return a compact run-identity frame.
- geoprior.utils.run_manifest_paths_frame(manifest)[source]
Return a frame describing exported run paths.
- geoprior.utils.run_manifest_scaling_overview_frame(manifest)[source]
Return a compact frame for nested scaling overview.
- geoprior.utils.summarize_run_manifest(manifest)[source]
Summarize a run manifest in a compact dictionary.
The summary focuses on what a user usually wants at a glance: which Stage-2 run this is, how many paths were exported, whether feature groups are present, and the essential scaling/coord context.
- geoprior.utils.default_manifest_payload(*, city='nansha', model='GeoPriorSubsNet', time_steps=5, forecast_horizon_years=3, mode='tft_like')[source]
Build a realistic default Stage-1 manifest payload.
The payload mirrors the stable Stage-1 handshake: schema_version + identity fields + config + artifacts + paths + versions.
- geoprior.utils.generate_manifest(path, *, template=None, overrides=None)[source]
Generate a Stage-1 manifest JSON artifact.
- geoprior.utils.inspect_manifest(manifest)[source]
Return a bundle of useful Stage-1 inspection outputs.
- geoprior.utils.load_manifest(path)[source]
Load a Stage-1 manifest as ArtifactRecord.
- geoprior.utils.manifest_artifacts_frame(manifest)[source]
Return a leaf-level artifact inventory frame.
- geoprior.utils.manifest_config_frame(manifest)[source]
Return a tidy frame for the compact config overview.
- geoprior.utils.manifest_feature_groups_frame(manifest)[source]
Return feature-group names and counts.
- geoprior.utils.manifest_holdout_frame(manifest)[source]
Return key holdout and split counts as a tidy frame.
- geoprior.utils.manifest_identity_frame(manifest)[source]
Return a compact Stage-1 identity frame.
- geoprior.utils.manifest_paths_frame(manifest)[source]
Return a frame describing top-level manifest paths.
- geoprior.utils.manifest_shapes_frame(manifest)[source]
Return a tidy tensor-shape summary frame.
- geoprior.utils.manifest_versions_frame(manifest)[source]
Return runtime/library versions saved in the manifest.
- geoprior.utils.plot_manifest_artifact_inventory(ax, manifest, *, title='Manifest artifact inventory', **plot_kws)[source]
Plot artifact and metadata inventory counts.
- geoprior.utils.plot_manifest_boolean_summary(ax, manifest, *, title='Stage-1 manifest checks', **plot_kws)[source]
Plot simple boolean checks for the handshake.
- geoprior.utils.plot_manifest_coord_ranges(ax, manifest, *, title='Coordinate ranges', **plot_kws)[source]
Plot nested coordinate ranges from scaling kwargs.
- geoprior.utils.plot_manifest_feature_group_sizes(ax, manifest, *, title='Stage-1 feature groups', **plot_kws)[source]
Plot Stage-1 feature-group sizes.
- geoprior.utils.plot_manifest_holdout_counts(ax, manifest, *, title='Holdout split counts', **plot_kws)[source]
Plot the main group and sequence split counts.
- geoprior.utils.summarize_manifest(manifest)[source]
Summarize the Stage-1 manifest in a compact dictionary.
The summary focuses on what later workflow stages usually need to verify first.
- geoprior.utils.default_xfer_results_payload(*, src_city='nansha', tgt_city='zhongshan', model_name='GeoPriorSubsNet', split='test', calibration='source', strategy='warm', rescale_mode='strict')[source]
Build a realistic default transfer-results payload.
The payload mirrors the common pattern where one xfer_results.json file stores multiple transfer records, often one per direction.
- geoprior.utils.generate_xfer_results(path, *, template=None, overrides=None)[source]
Generate a reproducible transfer-results artifact.
- geoprior.utils.inspect_xfer_results(xfer)[source]
Build a compact inspection bundle.
- geoprior.utils.load_xfer_results(xfer)[source]
Load transfer-results records.
- geoprior.utils.plot_xfer_boolean_summary(xfer, *, figsize=(8.4, 4.8), ax=None, title='Transfer boolean summary', show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]
Plot aggregated schema boolean pass rates.
This turns per-record schema booleans into one compact pass-rate view.
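The aggregation described here, from per-record booleans to one pass rate per check, can be sketched in a few lines. The code assumes each record exposes its schema checks as a flat mapping of check name to bool. boolean_pass_rates and the check names are hypothetical and are not taken from the real xfer_results.json layout.

```python
def boolean_pass_rates(records):
    """Return the fraction of records in which each boolean check is True."""
    totals, passes = {}, {}
    for record in records:
        for name, flag in record.items():
            if isinstance(flag, bool):  # ignore non-boolean fields
                totals[name] = totals.get(name, 0) + 1
                passes[name] = passes.get(name, 0) + int(flag)
    return {name: passes[name] / totals[name] for name in totals}


# Hypothetical check names; one mapping per transfer record.
schema_checks = [
    {"static_aligned": True, "dynamic_aligned": True, "direction": "A_to_B"},
    {"static_aligned": False, "dynamic_aligned": True, "direction": "B_to_A"},
]

rates = boolean_pass_rates(schema_checks)
print(rates)
```

A rate of 1.0 means every record passed that check. Anything lower points to at least one direction or strategy with a schema mismatch.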
- geoprior.utils.plot_xfer_direction_metric(xfer, *, metric='overall_rmse', figsize=(7.8, 4.2), ax=None, title=None, ylabel=None, show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]
Plot one overall metric by transfer direction.
- geoprior.utils.plot_xfer_overall_metrics(xfer, *, metrics=None, figsize=(8.4, 4.8), ax=None, title='Transfer overall metrics', ylabel='value', show_grid=True, grid_kws=None, legend=None, legend_kws=None, error='ignore', **plot_kws)[source]
Plot selected overall metrics for each record.
- Parameters:
- Return type:
- geoprior.utils.plot_xfer_per_horizon_metrics(xfer, *, metric='rmse', figsize=(8.0, 4.8), ax=None, title=None, xlabel='horizon', ylabel=None, marker='o', show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]
Plot one per-horizon metric as lines over horizon.
- Parameters:
- Return type:
- geoprior.utils.plot_xfer_schema_counts(xfer, *, figsize=(8.0, 4.6), ax=None, title='Schema mismatch counts', ylabel='count', show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]
Plot schema mismatch counts by record.
- geoprior.utils.summarize_xfer_results(xfer)[source]
Build a compact transfer-results summary.
The summary is intentionally workflow-oriented rather than exhaustive.
- geoprior.utils.xfer_overall_frame(xfer)[source]
Return one tidy row per transfer record.
The frame exposes the most useful comparison columns for quick ranking and filtering.
- geoprior.utils.xfer_per_horizon_frame(xfer)[source]
Return a tidy per-horizon metric frame.
The output is useful for comparing whether transfer quality degrades differently across directions or strategies as horizon increases.
- geoprior.utils.xfer_schema_frame(xfer)[source]
Return schema-alignment diagnostics in tidy form.
- geoprior.utils.xfer_warm_frame(xfer)[source]
Return warm-start settings in tidy form.
- geoprior.utils.calibration_stats_factors_frame(stats)[source]
Return per-horizon calibration factors.
- geoprior.utils.calibration_stats_overall_frame(stats)[source]
Return before/after overall calibration metrics.
- geoprior.utils.calibration_stats_per_horizon_frame(stats, *, which='eval_after')[source]
Return per-horizon coverage and sharpness.
- geoprior.utils.default_calibration_stats_payload(*, target=0.8, interval=(0.1, 0.9), f_max=5.0, tol=0.02, factors=None, coverage_before=0.865, coverage_after=0.8, sharpness_before=33.08, sharpness_after=33.38)[source]
Build a realistic default calibration-stats payload.
The structure follows the object saved by the calibration workflow and later embedded into the interpretable eval JSON.
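The coverage and sharpness fields in this payload can be read with the usual interval-forecast definitions: coverage as the fraction of observations inside the interval, and sharpness as the mean interval width. The sketch below computes both before and after widening by a factor. These definitions, the choice to widen each interval about its midpoint, and all names and numbers are assumptions for illustration. The actual calibration workflow may differ, for example by scaling about the median.

```python
def coverage_and_sharpness(y_true, lower, upper):
    """Return (empirical coverage, mean interval width)."""
    n = len(y_true)
    inside = sum(lo <= y <= hi for y, lo, hi in zip(y_true, lower, upper))
    width = sum(hi - lo for lo, hi in zip(lower, upper))
    return inside / n, width / n


def widen(lower, upper, factor):
    """Scale each interval about its midpoint by ``factor``."""
    new_lower, new_upper = [], []
    for lo, hi in zip(lower, upper):
        mid, half = (lo + hi) / 2, (hi - lo) / 2
        new_lower.append(mid - factor * half)
        new_upper.append(mid + factor * half)
    return new_lower, new_upper


# Hypothetical observations with q10/q90 bounds (a nominal 80% interval).
y = [1.0, 2.0, 3.0, 10.0]
q10 = [0.0, 1.0, 2.0, 4.0]
q90 = [2.0, 3.0, 4.0, 8.0]

before = coverage_and_sharpness(y, q10, q90)
after = coverage_and_sharpness(y, *widen(q10, q90, 3.0))
print("before:", before)
print("after: ", after)
```

Widening trades sharpness for coverage, which is why the payload records both quantities before and after calibration.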
- geoprior.utils.generate_calibration_stats(path, *, template=None, overrides=None)[source]
Generate and save a calibration-stats JSON file.
- Parameters:
path (str or pathlib.Path) – Output JSON path.
template (mapping, path, ArtifactRecord, optional) – Optional source payload. If omitted, a realistic default payload is used.
overrides (dict, optional) – Deep overrides applied after template resolution.
- Return type:
- geoprior.utils.inspect_calibration_stats(stats)[source]
Build a compact inspection bundle.
- geoprior.utils.load_calibration_stats(path)[source]
Load a calibration-stats artifact.
Notes
This loader is list-safe and nested-block aware. It can read either a direct calibration_stats.json payload or an interpretable eval JSON from which the nested block is extracted.
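One way to picture "list-safe and nested-block aware" loading is a resolver that accepts a payload as-is when it already looks like calibration stats, and otherwise searches nested dictionaries and lists for the block. The sketch below is an assumption about the general technique, not the library's logic. The "calibration_stats" key, the "factors" marker, and both function names are hypothetical.

```python
def find_block(payload, key):
    """Depth-first search for the first mapping stored under ``key``."""
    if isinstance(payload, dict):
        if isinstance(payload.get(key), dict):
            return payload[key]
        children = payload.values()
    elif isinstance(payload, list):
        children = payload
    else:
        return None
    for child in children:
        found = find_block(child, key)
        if found is not None:
            return found
    return None


def resolve_calibration_stats(payload, key="calibration_stats", marker="factors"):
    """Return the payload itself if it looks direct, else the nested block."""
    if isinstance(payload, dict) and marker in payload:
        return payload
    return find_block(payload, key)


# Hypothetical shapes: a direct payload and an eval JSON that wraps it in a list.
direct = {"target": 0.8, "factors": {"H1": 1.0}}
wrapped = [{"metrics_evaluate": {}, "calibration_stats": direct}]

print(resolve_calibration_stats(direct) is direct)
print(resolve_calibration_stats(wrapped) == direct)
```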
- geoprior.utils.plot_calibration_boolean_summary(ax, stats, *, title='Calibration checks', error='ignore', **plot_kws)[source]
Plot compact boolean checks for calibration status.
- geoprior.utils.plot_calibration_factors(ax, stats, *, title='Calibration factors', show_grid=True, grid_kws=None, error='ignore', **plot_kws)[source]
Plot per-horizon widening factors.
- geoprior.utils.plot_calibration_overall_metrics(ax, stats, *, title='Calibration summary', error='ignore', **plot_kws)[source]
Plot overall before/after calibration metrics.
- geoprior.utils.plot_calibration_per_horizon_coverage(ax, stats, *, which='eval_after', title=None, error='ignore', **plot_kws)[source]
Plot per-horizon coverage.
- geoprior.utils.plot_calibration_per_horizon_sharpness(ax, stats, *, which='eval_after', title=None, error='ignore', **plot_kws)[source]
Plot per-horizon sharpness.
- geoprior.utils.summarize_calibration_stats(stats)[source]
Return a compact summary of calibration behavior.
- geoprior.utils.ablation_config_frame(src)[source]
Return one row per record with config knobs.
- geoprior.utils.ablation_metrics_frame(src)[source]
Return long-form scalar metric rows.
- geoprior.utils.ablation_per_horizon_frame(src)[source]
Return long-form per-horizon metric rows.
- geoprior.utils.ablation_record_flags_frame(src)[source]
Return long-form boolean/config flags.
- geoprior.utils.ablation_record_runs_frame(src)[source]
Return one tidy row per ablation record.
- geoprior.utils.default_ablation_record_payload(*, city='nansha', model='GeoPriorSubsNet')[source]
Build a realistic default ablation JSONL payload.
The structure mirrors the current real record shape: one JSON object per line, with top-level config knobs, repeated compact metrics, a nested metrics block, units, and per-horizon metric maps.
- geoprior.utils.generate_ablation_record(output_path, *, overrides=None, city='nansha', model='GeoPriorSubsNet')[source]
Write a realistic demo ablation JSONL file.
- geoprior.utils.inspect_ablation_record(src, *, output_dir=None, stem='ablation_record', save_figures=True)[source]
Inspect ablation JSONL and optionally save figures.
- geoprior.utils.load_ablation_record(src)[source]
Load ablation JSONL records into a plain list.
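The JSONL convention (one JSON object per line) is easy to read with the standard library alone. The following sketch loads such a file into a plain list, skipping blank lines and reporting the offending line number on malformed input. read_jsonl, the demo file name, and the record fields are hypothetical and only illustrate the file format, not the behaviour of load_ablation_record.

```python
import json
import tempfile
from pathlib import Path


def read_jsonl(path):
    """Read one JSON object per non-empty line into a list."""
    records = []
    with open(path, encoding="utf-8") as handle:
        for line_no, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines between records
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise ValueError(f"{path}: invalid JSON on line {line_no}") from exc
    return records


# Write a small hypothetical file, then read it back.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "ablation_record.jsonl"
    demo.write_text(
        '{"city": "nansha", "rmse": 1.5}\n\n{"city": "nansha", "rmse": 2.0}\n',
        encoding="utf-8",
    )
    records = read_jsonl(demo)

print(len(records), records[1]["rmse"])
```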
- geoprior.utils.plot_ablation_boolean_summary(src, *, ax=None, title='Ablation record checks', **plot_kws)[source]
- geoprior.utils.plot_ablation_lambda_weights(src, *, ax=None, title='Lambda weights by variant', xlabel='variant', ylabel='weight', show_grid=True, grid_kws=None, rotate_xticks=25, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]
- geoprior.utils.plot_ablation_metric_by_variant(src, *, metric='rmse', ax=None, title=None, xlabel='variant', ylabel=None, show_grid=True, grid_kws=None, rotate_xticks=25, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]
- geoprior.utils.plot_ablation_per_horizon_metric(src, *, metric='mae', ax=None, title=None, xlabel='horizon', ylabel=None, show_grid=True, grid_kws=None, legend=True, legend_kws=None, error='ignore', **plot_kws)[source]
- geoprior.utils.plot_ablation_run_counts(src, *, ax=None, title='Ablation runs per variant', xlabel='variant', ylabel='runs', show_grid=True, grid_kws=None, rotate_xticks=25, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]
- geoprior.utils.plot_ablation_top_variants(src, *, metric='rmse', top_n=5, ax=None, title=None, xlabel=None, ylabel='variant', show_grid=True, grid_kws=None, annotate=False, annotate_kws=None, error='ignore', **plot_kws)[source]
- geoprior.utils.summarize_ablation_record(src)[source]
Return a semantic summary of ablation JSONL.
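The record shape described for default_ablation_record_payload (one JSON object per line) can be read with the standard library alone. The sketch below is only an illustration of that file format, not the implementation of load_ablation_record; the helper name read_jsonl and the field names in the demo payload are hypothetical.

```python
import io
import json


def read_jsonl(handle):
    """Parse one JSON object per non-empty line into a list of dicts."""
    records = []
    for line_no, line in enumerate(handle, start=1):
        text = line.strip()
        if not text:
            continue  # tolerate blank lines between records
        try:
            records.append(json.loads(text))
        except json.JSONDecodeError as exc:
            raise ValueError(f"line {line_no}: invalid JSON ({exc.msg})") from exc
    return records


# Hypothetical two-record payload; the field names are illustrative only.
demo = io.StringIO(
    '{"variant": "full", "rmse": 0.12, "metrics": {"mae": 0.09}}\n'
    "\n"
    '{"variant": "no_physics", "rmse": 0.18, "metrics": {"mae": 0.14}}\n'
)
records = read_jsonl(demo)

# Pick the record with the lowest top-level RMSE.
best = min(records, key=lambda rec: rec["rmse"])
```

Keeping the loader's output as a plain list of dictionaries makes it easy to hand the records to a DataFrame constructor or to filter them before plotting.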
What this package covers#
The top-level utility package is intentionally broad because it supports the staged application workflow end to end.
Its current public exports include helper groups such as:
audits and handshakes, including audit_stage1_scaling, audit_stage2_handshake, and should_audit;
calibration, including calibrate_quantile_forecasts;
forecast formatting and evaluation, including format_and_forecast, evaluate_forecast, and pivot_forecast_dataframe;
configuration and stage helpers, including load_nat_config, load_nat_config_payload, make_tf_dataset, map_targets_for_training, resolve_hybrid_config, resolve_si_affine, and serialize_subs_params;
geospatial and spatial workflow helpers, including spatial_sampling, create_spatial_clusters, augment_city_spatiotemporal_data, and deg_to_m_from_lat;
holdout logic, including compute_group_masks, split_groups_holdout, and filter_df_by_groups;
subsidence-oriented workflow helpers, including convert_eval_payload_units, cumulative_to_rate, rate_to_cumulative, resolve_gwl_for_physics, resolve_head_column, and make_txy_coords.
Important workflow-facing modules#
Audit helpers#
Audit helpers for stage handshakes and scaling artifacts.
- geoprior.utils.audit_utils.resolve_audit_stages(audit_stages, *, known=('stage1', 'stage2', 'stage3'), default=None)[source]
Resolve cfg["AUDIT_STAGES"] into a canonical set like {"stage1", "stage2"}.
- geoprior.utils.audit_utils.should_audit(audit_stages, *, stage, default=None)[source]
Convenience: should we audit this stage?
- geoprior.utils.audit_utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]
Stage-1 audit:
raw df_train coord stats (t/x/y) plus heuristic units;
model-fed coord stats from inputs_train["coords"] (flattened);
coord scaler min/max plus coord_ranges;
SI channel sanity for physics columns (if present);
target array sanity;
split of features: scaled ML vs __si vs other.
Saves a machine-readable JSON if save_dir is provided.
- Parameters:
coord_scaler (Any)
coord_mode (str)
coords_in_degrees (bool)
coord_epsg_used (Any)
coord_x_col_used (str)
coord_y_col_used (str)
x_col_used (str)
y_col_used (str)
time_col_used (str)
normalize_coords (bool)
keep_coords_raw (bool)
shift_raw_coords (bool)
subs_model_col (str | None)
gwl_dyn_col (str | None)
gwl_target_col (str | None)
h_field_col (str | None)
main_scaler_path (str | None)
scaler_info (dict | None)
save_dir (str | None)
table_width (int)
title_prefix (str)
city (str)
model_name (str)
sample_rows (int)
- Return type:
str | None
- geoprior.utils.audit_utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]
- geoprior.utils.audit_utils.audit_stage1_stage2_coord_consistency(*, X_train, coord_scaler, sk_final, time_steps, forecast_horizon, time_units='year', save_dir=None, table_width=110, title_prefix='STAGE-1 <-> STAGE-2 COORD CONSISTENCY', city='Unknown', model_name='Model', log_fn=None)[source]
Cross-check coordinate semantics between Stage-1 scaler and Stage-2 NPZ coords.
- Key facts for GeoPrior Stage-2:
coords are (N, H, 3) and correspond to the target horizon times, not the full dynamic history, so t has exactly H unique values.
x/y typically cover full normalized [0,1] range if you have spatial coverage (often min=0 and max=1).
- This audit:
computes normalized min/max for t/x/y in X_train[“coords”]
derives implied raw min/max using MinMaxScaler data_min_/data_max_
checks raw ranges are within Stage-1 scaler bounds
checks t_unique count == H and t_raw_unique spacing (≈1 year)
provides UTM plausibility hint if epsg is UTM-like
- geoprior.utils.audit_utils.audit_stage3_run(*, manifest_path, manifest, cfg, fixed_params, best_hps, run_dir, best_model_path, best_weights_path, use_tf_savedmodel, quantiles, forecast_horizon, mode, pred_shapes=None, eval_results=None, phys_diag=None, calibrator_factors=None, forecast_csv_eval=None, forecast_csv_future=None, metrics_json_path=None, physics_payload_path=None, save_dir=None, table_width=100, title_prefix='STAGE-3 AUDIT', city='Unknown', model_name='Model', log_fn=None)[source]
Stage-3 audit: tuned artifacts + eval sanity.
- Parameters:
manifest_path (str | None)
run_dir (str)
best_model_path (str | None)
best_weights_path (str | None)
use_tf_savedmodel (bool)
quantiles (Any)
forecast_horizon (int)
mode (str)
calibrator_factors (Any)
forecast_csv_eval (str | None)
forecast_csv_future (str | None)
metrics_json_path (str | None)
physics_payload_path (str | None)
save_dir (str | None)
table_width (int)
title_prefix (str)
city (str)
model_name (str)
- Return type:
str | None
These helpers matter because they validate Stage-1 scaling and Stage-2 handoff assumptions before later model or physics-aware runs rely on them.
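The coordinate checks listed for audit_stage1_stage2_coord_consistency can be illustrated with plain NumPy. The sketch below is not the library's implementation: it assumes min-max scaling per channel, and the helper name implied_raw_range, the array values, and the scaler bounds are all hypothetical.

```python
import numpy as np


def implied_raw_range(coords_norm, data_min, data_max):
    """Invert min-max scaling per channel and return raw (min, max) arrays."""
    flat = coords_norm.reshape(-1, coords_norm.shape[-1])
    raw = flat * (data_max - data_min) + data_min
    return raw.min(axis=0), raw.max(axis=0)


# Hypothetical Stage-2 coords: N=4 samples, H=3 horizon steps, channels (t, x, y).
H = 3
coords = np.zeros((4, H, 3))
coords[:, :, 0] = np.array([0.75, 0.875, 1.0])    # same horizon times for every sample
coords[:, :, 1] = np.linspace(0, 1, 4)[:, None]   # x varies by sample
coords[:, :, 2] = np.linspace(1, 0, 4)[:, None]   # y varies by sample

# Hypothetical Stage-1 scaler bounds (year, easting, northing).
data_min = np.array([2015.0, 400000.0, 2500000.0])
data_max = np.array([2023.0, 410000.0, 2520000.0])

raw_min, raw_max = implied_raw_range(coords, data_min, data_max)

# Recover the raw horizon times from the unique normalized t values.
t_raw = np.unique(coords[..., 0]) * (data_max[0] - data_min[0]) + data_min[0]

within_bounds = bool(np.all(raw_min >= data_min) and np.all(raw_max <= data_max))
t_count_ok = t_raw.size == H                            # t has exactly H unique values
t_spacing_ok = bool(np.allclose(np.diff(t_raw), 1.0))   # spacing of about one year
```

Each flag corresponds to one of the checks in the list above, so a failing flag points directly at the assumption that was violated.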
Calibration helpers#
Forecast utilities.
- geoprior.utils.calibrate.fit_interval_factors_df(df_eval, *, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, tol=0.001, f_max=5.0, max_iter=32, verbose=1, logger=None)[source]
Fit empirical interval-width correction factors from an evaluation DataFrame.
This function learns one multiplicative factor per forecast horizon so that an existing predictive interval reaches a desired empirical coverage on evaluation data. It is designed for the common case where a model already produces quantile forecasts, but the resulting intervals are systematically too narrow or too wide.
For each group of rows, the function identifies
one lower quantile,
one median-like quantile, and
one upper quantile,
then searches for the smallest factor \(f \ge 1\) such that the rescaled interval reaches the requested target coverage.
If \(q_{lo}\), \(q_{md}\), and \(q_{hi}\) denote the selected lower, median, and upper forecasts, the calibrated interval used during fitting is
\[\tilde q_{lo} = q_{md} - f \, (q_{md} - q_{lo}) \tag{5}\]
\[\tilde q_{hi} = q_{md} + f \, (q_{hi} - q_{md}) \tag{6}\]
and the fitted factor is the smallest value whose empirical coverage is at least the target, up to the requested tolerance.
This is a lightweight post-hoc calibration strategy that is especially useful for multi-horizon quantile forecasts, where sharpness and coverage must be balanced carefully in downstream risk analysis [1, 2].
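The search for the smallest factor can be sketched as a bisection over \(f \in [1, f_{max}]\), using the rescaled bounds \(\tilde q_{lo}\) and \(\tilde q_{hi}\) defined above. This is a minimal single-group illustration and not the body of fit_interval_factors_df; the helper names coverage and fit_factor and the toy data are hypothetical.

```python
import numpy as np


def coverage(y, q_lo, q_md, q_hi, f):
    """Empirical coverage of the interval rescaled around q_md by factor f."""
    lo = q_md - f * (q_md - q_lo)
    hi = q_md + f * (q_hi - q_md)
    return float(np.mean((y >= lo) & (y <= hi)))


def fit_factor(y, q_lo, q_md, q_hi, target=0.8, tol=1e-3, f_max=5.0, max_iter=32):
    """Bisection for the smallest f in [1, f_max] reaching the target coverage."""
    if coverage(y, q_lo, q_md, q_hi, 1.0) >= target:
        return 1.0  # already wide enough; no expansion needed
    if coverage(y, q_lo, q_md, q_hi, f_max) < target:
        return f_max  # target unreachable within the allowed bound
    lo_f, hi_f = 1.0, f_max
    for _ in range(max_iter):
        if hi_f - lo_f <= tol:
            break
        mid = 0.5 * (lo_f + hi_f)
        if coverage(y, q_lo, q_md, q_hi, mid) >= target:
            hi_f = mid  # mid is feasible; look for a smaller factor
        else:
            lo_f = mid
    return hi_f  # always a factor that reaches the target


# Toy group: the raw interval [-0.1, 0.1] covers only 3 of 5 observations.
y = np.array([0.0, 0.05, -0.05, 0.3, -0.2])
q_md = np.zeros(5)
q_lo = np.full(5, -0.1)
q_hi = np.full(5, 0.1)

factor = fit_factor(y, q_lo, q_md, q_hi, target=0.8)
```

Bisection is valid here because coverage cannot decrease as \(f\) grows when \(q_{lo} \le q_{md} \le q_{hi}\). In the toy group the interval must double to reach the observation at -0.2, so the fitted factor lands at about 2.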
- Parameters:
df_eval (pandas.DataFrame) – Evaluation table containing the observed target and the forecast quantile columns used for calibration.
target_name (str, default "subsidence") – Base name of the forecasted variable. When column_map is not provided, quantile columns are auto-detected using names such as "{target_name}_q10", "{target_name}_q50", and "{target_name}_q90".
column_map (mapping or None, default None) – Optional explicit column mapping. This may contain at least
"quantiles": a mapping from quantile level to column name, or a list of quantile column names,
"actual": the observed target column.
This is useful when the DataFrame does not follow the default naming convention.
step_col (str, default "forecast_step") – Column used to fit separate factors per horizon or lead time. If this column is missing, the full DataFrame is treated as a single group and a single factor is returned with key 1.
interval (tuple of float, default (0.1, 0.9)) – Nominal lower and upper quantiles defining the interval to calibrate. The function uses the closest available quantiles in the DataFrame.
target_coverage (float, default 0.8) – Desired empirical coverage of the calibrated interval.
median_q (float, default 0.5) – Target quantile used as the center of the interval expansion. The nearest available quantile in the DataFrame is used.
tol (float, default 1e-3) – Numerical tolerance used during the bisection search.
f_max (float, default 5.0) – Upper bound for the searched interval factor. If the target coverage cannot be reached before this bound, f_max is returned for that group.
max_iter (int, default 32) – Maximum number of bisection iterations used to fit each factor.
verbose (int, default 1) – Verbosity level forwarded to the internal logging helper.
logger (logging.Logger or None, default None) – Optional logger used for progress messages.
- Returns:
Dictionary mapping each forecast horizon to its fitted interval factor.
- Return type:
dict[int, float]
- Raises:
ValueError – If no actual column can be resolved, or if no valid quantile interval can be formed from the available columns.
KeyError – If a user-specified column in column_map is missing.
Notes
This function does not change the forecasts directly. It only learns the correction factors. To apply the fitted factors to evaluation or future forecasts, use apply_interval_factors_df().
Because the method calibrates interval width around a median-like center, it is most appropriate when miscalibration is primarily a sharpness problem rather than a severe bias in the central forecast.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import fit_interval_factors_df
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 1, 2, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 1.0, 0.5, 0.8, 1.2],
...         "subsidence_q10": [0.3, 0.5, 0.8, 0.4, 0.6, 0.9],
...         "subsidence_q50": [0.4, 0.7, 1.0, 0.5, 0.8, 1.1],
...         "subsidence_q90": [0.5, 0.9, 1.1, 0.6, 1.0, 1.3],
...     }
... )
>>> factors = fit_interval_factors_df(
...     df,
...     target_name="subsidence",
...     interval=(0.1, 0.9),
...     target_coverage=0.8,
... )
>>> isinstance(factors, dict)
True
See also
apply_interval_factors_df: Apply learned interval factors to forecast quantiles.
calibrate_quantile_forecasts: High-level wrapper that can fit and apply factors in one call.
- geoprior.utils.calibrate.apply_interval_factors_df(df, factors, *, target_name='subsidence', column_map=None, step_col='forecast_step', median_q=0.5, keep_original=False, factor_col='calibration_factor', calibrated_col='is_calibrated', enforce_monotonic='cummax', verbose=1, logger=None)[source]
Apply interval-width calibration factors to quantile forecasts stored in a DataFrame.
This function rescales forecast quantiles around a median-like central forecast using either
a single scalar factor applied to every row, or
a mapping of forecast_step -> factor.
The main goal is to widen or shrink predictive intervals after the factors have been estimated on evaluation data with fit_interval_factors_df().
If \(q_{md}\) is the selected median-like forecast and \(f_h\) is the factor for horizon \(h\), then for a lower quantile \(q < q_{md}\) the calibrated value is
\[\tilde q = q_{md} - f_h \, (q_{md} - q) \tag{7}\]
and for an upper quantile \(q > q_{md}\) it is
\[\tilde q = q_{md} + f_h \, (q - q_{md}) \tag{8}\]
while the median itself is left unchanged.
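Both rescaling rules reduce to the single expression \(\tilde q = q_{md} + f_h \, (q - q_{md})\), which also leaves the median unchanged. The sketch below applies that expression row-wise with a per-horizon factor and then enforces quantile ordering. It is an illustration under assumed array layouts (one row per forecast, one column per quantile), not the code of apply_interval_factors_df; the helper names rescale_quantiles and enforce_order are hypothetical.

```python
import numpy as np


def rescale_quantiles(q, median_idx, factor):
    """Rescale every quantile column around the median column by `factor`."""
    q = np.asarray(q, dtype=float)
    center = q[:, [median_idx]]  # shape (n_rows, 1), broadcast over columns
    return center + factor * (q - center)


def enforce_order(q, mode="cummax"):
    """Keep quantiles non-decreasing across columns."""
    if mode == "cummax":
        return np.maximum.accumulate(q, axis=1)
    if mode == "sort":
        return np.sort(q, axis=1)
    if mode == "none":
        return q
    raise ValueError(f"unknown mode: {mode!r}")


# Columns are (q10, q50, q90); the values and factors are made up for the example.
steps = np.array([1, 1, 2])
factors = {1: 1.5, 2: 2.0}
quantiles = np.array(
    [
        [0.3, 0.4, 0.5],
        [0.5, 0.7, 0.9],
        [0.4, 0.5, 0.6],
    ]
)

# One factor per row, looked up from the forecast step.
f_row = np.array([factors[int(s)] for s in steps])[:, None]
calibrated = enforce_order(rescale_quantiles(quantiles, 1, f_row))

# Ordering only changes rows whose quantiles cross.
crossing = np.array([[0.5, 0.4, 0.6]])
fixed_cummax = enforce_order(crossing, "cummax")
fixed_sort = enforce_order(crossing, "sort")
```

The two ordering modes differ on crossing rows: the cumulative maximum raises the offending value to its left neighbour, while sorting permutes the values within the row.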
- Parameters:
df (pandas.DataFrame) – DataFrame containing forecast quantile columns.
factors (mapping or float) – Calibration factor specification. This may be
a scalar applied to every row, or
a mapping from forecast horizon to factor.
target_name (str, default "subsidence") – Base variable name used to auto-detect quantile columns when column_map is not supplied.
column_map (mapping or None, default None) – Optional explicit mapping describing the quantile columns.
step_col (str, default "forecast_step") – Horizon column used when factors is a mapping. If this column is absent, it is created and filled with 1.
median_q (float, default 0.5) – Quantile used as the center of the recalibration. The closest available forecast quantile is used.
keep_original (bool, default False) – If True, each raw quantile column is preserved in an additional "{col}_raw" column before calibration is applied.
factor_col (str, default "calibration_factor") – Name of the column storing the factor applied to each row.
calibrated_col (str, default "is_calibrated") – Name of the Boolean marker column added to the returned DataFrame.
enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used after recalibration to keep quantiles ordered.
"cummax" applies a cumulative maximum across quantiles,
"sort" sorts the calibrated quantiles row-wise,
"none" leaves the result unchanged.
verbose (int, default 1) – Verbosity level forwarded to the internal logger helper.
logger (logging.Logger or None, default None) – Optional logger used for progress messages.
- Returns:
A calibrated copy of df containing updated quantiles and the metadata columns specified by factor_col and calibrated_col.
- Return type:
- Raises:
ValueError – If no quantile columns can be resolved from the DataFrame or if an invalid enforce_monotonic mode is requested.
Notes
Monotonicity enforcement is important because quantile-wise post-processing can otherwise produce crossing quantiles. Keeping quantiles ordered is especially important when the calibrated outputs are used to derive intervals, exceedance probabilities, or downstream risk metrics [1, 2].
This function only modifies the quantile forecasts. It does not recompute summary coverage statistics. For fit-and-apply workflows with optional evaluation summaries, use calibrate_quantile_forecasts().
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import apply_interval_factors_df
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> out = apply_interval_factors_df(
...     df,
...     factors={1: 1.2, 2: 1.1},
...     target_name="subsidence",
... )
>>> "calibration_factor" in out.columns
True
See also
fit_interval_factors_df: Fit per-horizon interval scaling factors.
calibrate_quantile_forecasts: High-level wrapper for fitting, applying, and summarizing interval calibration.
- geoprior.utils.calibrate.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]
Fit and apply post-hoc interval calibration for quantile forecasts.
This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can
detect whether evaluation forecasts already appear calibrated,
fit interval-width correction factors from evaluation data,
apply those factors to evaluation and/or future forecasts,
compute before/after summary diagnostics on the evaluation set,
optionally save the outputs to disk.
The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.
Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].
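The before/after diagnostics and the already-calibrated check both rest on per-horizon coverage and interval width. The sketch below shows one way to compute them with NumPy. The decision rule in looks_calibrated (coverage within a tolerance of the target) is an assumption made for illustration; the source only states that the automatic check is coverage-based. The helper names and the toy arrays are hypothetical.

```python
import numpy as np


def interval_summary(y, q_lo, q_hi, steps):
    """Per-horizon empirical coverage and mean interval width (sharpness)."""
    summary = {}
    for step in np.unique(steps):
        mask = steps == step
        inside = (y[mask] >= q_lo[mask]) & (y[mask] <= q_hi[mask])
        summary[int(step)] = {
            "coverage": float(inside.mean()),
            "sharpness": float((q_hi[mask] - q_lo[mask]).mean()),
        }
    return summary


def looks_calibrated(coverage, target=0.8, tol=0.02):
    """Assumed rule: coverage within `tol` of the target counts as calibrated."""
    return abs(coverage - target) <= tol


# Toy evaluation arrays; the second observation falls outside its interval.
steps = np.array([1, 1, 2, 2])
y = np.array([0.4, 0.95, 0.5, 0.9])
q10 = np.array([0.3, 0.5, 0.4, 0.6])
q90 = np.array([0.5, 0.9, 0.6, 1.0])

summary = interval_summary(y, q10, q90, steps)
skip_flags = {step: looks_calibrated(row["coverage"]) for step, row in summary.items()}
```

Reporting sharpness next to coverage matters because widening an interval always helps coverage; the width shows what that gain costs.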
- Parameters:
df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.
df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.
target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.
column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.
step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.
interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.
target_coverage (float, default 0.8) – Desired empirical coverage after calibration.
median_q (float, default 0.5) – Central quantile used as the expansion anchor.
use ({"auto", True, False}, default "auto") – Control flag for whether calibration is performed.
False disables calibration and returns inputs unchanged.
"auto" skips calibration when evaluation forecasts already look calibrated.
True forces calibration even if the automatic check would skip it.
tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.
f_max (float, default 5.0) – Maximum factor allowed during fitting.
max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.
keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.
enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.
overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.
calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.
factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.
factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.
save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.
save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.
save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.
verbose (int, default 1) – Verbosity level forwarded to logging helpers.
logger (logging.Logger or None, default None) – Optional logger used for progress messages.
- Returns:
df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.
df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.
stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain
the target interval and target coverage,
the fitted or user-specified factors,
skip reasons,
evaluation summaries before and after calibration.
- Return type:
Notes
In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This makes the wrapper conservative in repeated workflows, where the same tables may pass through the calibration stage more than once.
The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True
See also
fit_interval_factors_df: Fit per-horizon interval-width correction factors.
apply_interval_factors_df: Apply a scalar or per-horizon factor map to quantile forecasts.
- geoprior.utils.calibrate.calibrate_forecasts_in(y_val, forecasts_val, *forecasts_to_calibrate, quantiles=(10, 20, 30, 40, 50, 60, 70, 80, 90), prefix='subsidence', verbose=None, _logger=None)[source]
Train and apply isotonic calibration models on forecast data.
- Parameters:
y_val (pandas.Series) – Observed target values for calibration.
forecasts_val (pandas.DataFrame) – Validation forecasts with columns named <prefix>_q{quantile} for each quantile.
*forecasts_to_calibrate (DataFrame or dict) – One or more DataFrames (or a mapping) of new forecasts to calibrate using trained models.
quantiles (tuple of int, optional) – Quantile levels (e.g., 10, 50, 90) to calibrate.
prefix (str, default 'subsidence') – Column name prefix for quantile forecasts.
verbose (int, optional) – Verbosity level passed to vlog for logging.
_logger (Logger or callable, optional) – Sink for log messages (print or logging.Logger).
- Returns:
Notes
This function fits an isotonic regression for each quantile on validation data and then predicts calibrated probabilities for new forecasts. Use normalize_model_inputs to handle varied input formats.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.utils.calibrate import calibrate_forecasts_in
>>> # Create dummy validation series
>>> idx = pd.date_range('2021-01-01', periods=50)
>>> y_val = pd.Series(
...     np.random.randn(50), index=idx
... )
>>> # Simulate validation forecasts
>>> val_df = pd.DataFrame({
...     'subsidence_q10': y_val - 0.2,
...     'subsidence_q90': y_val + 0.2
... }, index=idx)
>>> # Simulate two new forecast sets
>>> newA = val_df + 0.1
>>> newB = val_df - 0.1
>>> # Calibrate forecasts
>>> results, models = calibrate_forecasts_in(
...     y_val,
...     val_df,
...     newA,
...     newB,
...     quantiles=(10, 90),
...     verbose=2
... )
>>> # Access calibrated probabilities
>>> results['ModelA']['subsidence_q10_cal_prob']
>>> # Inspect fitted calibrator for q90
>>> models['subsidence_q90'].threshold_
- geoprior.utils.calibrate.calibrate_probability_forecast(df, prob_col, actual_col, method='isotonic', out_col=None, clip=True, savefile=None)[source]
Calibrate binary probability forecasts with isotonic regression or logistic scaling.
This function post-processes a column of raw forecast probabilities so that they better agree with observed binary outcomes. It returns a copy of the input DataFrame with one additional calibrated probability column.
Two calibration modes are supported:
"isotonic": Non-parametric monotone calibration using isotonic regression.
"logistic": Parametric calibration using logistic regression on the raw probabilities.
- Parameters:
df (pandas.DataFrame) – DataFrame containing a column of raw forecast probabilities and a column of observed binary outcomes.
prob_col (str) – Name of the column containing the raw forecast probabilities. Values are usually expected in the interval [0, 1].
actual_col (str) – Name of the column containing the realized binary outcomes, encoded as 0 and 1.
method ({"isotonic", "logistic"}, default "isotonic") – Calibration method to apply.
out_col (str or None, default None) – Name of the calibrated probability column. If None, f"{prob_col}_calib" is used.
clip (bool, default True) – If True, calibrated outputs are clipped to [0, 1] before being written to the result.
savefile (str or None, default None) – Optional output path handled by the SaveFile() decorator.
- Returns:
Copy of df with an additional calibrated probability column.
- Return type:
- Raises:
ValueError – If method is not one of "isotonic" or "logistic".
Notes
This function is appropriate for event-probability forecasts, not for continuous quantile forecasts. For quantile-table recalibration, use calibrate_forecasts() or calibrate_quantile_forecasts() depending on the workflow.
The isotonic option is often preferred when a flexible monotone mapping is desired, whereas the logistic option is useful when a simple parametric correction is sufficient.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_probability_forecast,
... )
>>> df = pd.DataFrame(
...     {
...         "p_raw": [0.10, 0.30, 0.60, 0.85],
...         "y": [0, 0, 1, 1],
...     }
... )
>>> out = calibrate_probability_forecast(
...     df,
...     prob_col="p_raw",
...     actual_col="y",
...     method="isotonic",
... )
>>> "p_raw_calib" in out.columns
True
See also
calibrate_forecasts: Calibrate continuous quantile forecasts by inverting calibrated CDF estimates.
calibrate_quantile_forecasts: Fit-and-apply interval calibration for tabular quantile forecasts.
- geoprior.utils.calibrate.calibrate_forecasts(df, quantiles, q_prefix, actual_col, method='isotonic', out_prefix='calib', grid_mode='unit', grid_size=1001, group_by=None, savefile=None)[source]
Calibrate continuous quantile forecasts by fitting a calibrated CDF surrogate and inverting it at the nominal quantile levels.
This function works on a DataFrame containing continuous quantile forecasts and the corresponding observed target. For each requested quantile level \(q\), it builds a binary supervision signal of the form
\[y_q = \mathbb{1}\{y \le \hat q\} \tag{9}\]
where \(y\) is the observed outcome and \(\hat q\) is the raw forecast quantile at level \(q\).
A monotone classifier is then fitted to approximate the conditional CDF at that threshold. Finally, the calibrated forecast is obtained by inverting the fitted CDF on a grid and taking the first threshold whose calibrated CDF reaches the nominal quantile level.
This provides a practical post-hoc recalibration mechanism for continuous quantile forecasts in multi-horizon settings [1, 2].
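Two steps of this procedure are easy to illustrate in isolation: building the indicator \(y_q\) and inverting a monotone CDF on a grid. The sketch below is not the body of calibrate_forecasts. The fitted classifier is replaced by a hypothetical monotone curve (grid squared), the fallback to the last grid point is an assumption, and the helper names are illustrative.

```python
import numpy as np


def threshold_indicator(y, q_hat):
    """Binary target y_q = 1{y <= q_hat}, one value per row."""
    return (np.asarray(y) <= np.asarray(q_hat)).astype(int)


def invert_cdf(grid, cdf_values, level):
    """First grid point whose CDF value reaches `level`.

    Falls back to the last grid point when the level is never reached
    (an assumption made for this sketch).
    """
    reached = np.asarray(cdf_values) >= level
    idx = int(np.argmax(reached)) if reached.any() else len(grid) - 1
    return float(grid[idx])


# Indicator for a q10 forecast: a calibrated q10 should give a hit rate near 0.1.
y = np.array([0.4, 0.8, 0.5, 1.0])
q10 = np.array([0.3, 0.6, 0.4, 1.1])
hits = threshold_indicator(y, q10)
hit_rate = float(hits.mean())

# Grid inversion with a stand-in monotone CDF (not a fitted model).
grid = np.linspace(0.0, 1.0, 11)
cdf = grid ** 2
q50_cal = invert_cdf(grid, cdf, 0.5)
```

The resolution of the inverted quantile is limited by the grid spacing, which is why the grid_size and grid_mode settings affect the calibrated values.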
- Parameters:
df (pandas.DataFrame) – DataFrame containing the raw forecast quantile columns and the observed target column.
quantiles (sequence of float) – Nominal quantile levels to recalibrate, expressed on the unit interval, for example [0.1, 0.5, 0.9].
q_prefix (str) – Base prefix for the raw quantile columns. For example, if q_prefix="subsidence", the function expects columns such as "subsidence_q10", "subsidence_q50", and "subsidence_q90".
actual_col (str) – Name of the observed continuous target column.
method ({"isotonic", "logistic"}, default "isotonic") – Calibration method used to estimate the monotone CDF surrogate at each quantile level.
out_prefix (str, default "calib") – Prefix used to build the new calibrated quantile column names. The output columns follow the pattern "{out_prefix}_{q_prefix}_qXX".
grid_mode ({"unit", "range"}, default "unit") – Domain used to construct the inversion grid.
"unit" builds np.linspace(0, 1, grid_size). This is mainly appropriate when the forecast support is already normalized to the unit interval.
"range" builds a grid from the observed range of each raw quantile column and is typically the safer choice for forecasts on the original target scale.
grid_size (int, default 1001) – Number of points used in the inversion grid.
group_by (str or None, default None) – Optional grouping column. When provided, calibration is performed separately within each group, for example by forecast horizon.
savefile (str or None, default None) – Optional output path handled by the SaveFile() decorator.
- Returns:
Copy of df with calibrated quantile columns appended.
- Return type:
Notes
This function differs from calibrate_quantile_forecasts(). Here, each quantile is recalibrated through a classifier-plus-inversion approach. By contrast, calibrate_quantile_forecasts() applies multiplicative interval width factors around a median-like forecast.
The two approaches answer different needs:
use calibrate_forecasts() when you want to recalibrate the quantile levels themselves through a threshold-CDF view;
use calibrate_quantile_forecasts() when your main issue is interval under- or over-dispersion and you want a simpler, auditable width correction.
When group_by is used, the function preserves the original row index order after concatenating the calibrated groups.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import calibrate_forecasts
>>> df = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.8, 0.5, 1.0],
...         "subsidence_q10": [0.3, 0.6, 0.4, 0.8],
...         "subsidence_q50": [0.4, 0.8, 0.5, 1.0],
...         "subsidence_q90": [0.5, 1.0, 0.6, 1.2],
...     }
... )
>>> out = calibrate_forecasts(
...     df,
...     quantiles=[0.1, 0.5, 0.9],
...     q_prefix="subsidence",
...     actual_col="subsidence_actual",
...     method="isotonic",
...     grid_mode="range",
...     group_by="forecast_step",
... )
>>> "calib_subsidence_q50" in out.columns
True
See also
calibrate_probability_forecast: Calibrate event probabilities rather than continuous quantiles.
calibrate_quantile_forecasts: Wrapper for interval-width calibration on tabular forecasts.
fit_interval_factors_df: Learn empirical per-horizon interval scaling factors.
The calibration layer is especially relevant to uncertainty workflows spanning Stage-2 through Stage-5 and any evaluation path that compares interval quality before and after calibration.
Forecast helpers#
Forecast utilities.
- geoprior.utils.forecast_utils.detect_forecast_type(df, value_prefixes=None)[source]
Auto-detects whether a DataFrame contains deterministic or quantile forecasts, supporting both long and wide formats.
This utility inspects column names to determine the nature of the predictions.
It identifies a ‘quantile’ forecast if it finds columns containing a _qXX pattern (e.g., ‘subsidence_q10’, ‘GWL_2022_q50’).
It identifies a ‘deterministic’ forecast if no quantile columns are found, but columns ending in _pred, _actual, or matching a base prefix exist (e.g., ‘subsidence_pred’, ‘subsidence_2022_actual’, ‘GWL’).
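The two detection rules above can be approximated with a regular expression over column names. The sketch below is a simplified stand-in, not the implementation of detect_forecast_type: the function name detect_type is hypothetical, it takes a plain list of column names instead of a DataFrame, and it only treats a bare column as a base prefix when value_prefixes lists it explicitly.

```python
import re

# A quantile column ends in _q followed by digits, e.g. subsidence_q10.
QUANTILE_PATTERN = re.compile(r"_q\d+$")


def detect_type(columns, value_prefixes=None):
    """Classify column names as 'quantile', 'deterministic', or 'unknown'."""
    cols = [str(c) for c in columns]
    # Rule 1: any quantile-style column wins.
    if any(QUANTILE_PATTERN.search(c) for c in cols):
        return "quantile"
    # Rule 2: prediction/actual suffixes or an explicitly listed base prefix.
    prefixes = tuple(value_prefixes or ())
    for c in cols:
        if c.endswith(("_pred", "_actual")) or c in prefixes:
            return "deterministic"
    return "unknown"


long_quantile = detect_type(["subsidence_q50", "GWL_q90"])
wide_quantile = detect_type(["subsidence_2022_q50"])
deterministic = detect_type(["subsidence_pred", "GWL"])
by_prefix = detect_type(["GWL"], value_prefixes=["GWL"])
unknown = detect_type(["coord_x", "coord_y"])
```

Checking the quantile pattern first mirrors the stated precedence: a table is only treated as deterministic when no quantile columns are present.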
- Parameters:
- Returns:
One of ‘quantile’, ‘deterministic’, or ‘unknown’.
- Return type:
Examples
>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import detect_forecast_type
>>> # Long format quantile
>>> df_quant_long = pd.DataFrame(columns=['subsidence_q50', 'GWL_q90'])
>>> detect_forecast_type(df_quant_long)
'quantile'
>>> # Wide format quantile
>>> df_quant_wide = pd.DataFrame(columns=['subsidence_2022_q50'])
>>> detect_forecast_type(df_quant_wide)
'quantile'
>>> # Deterministic forecast
>>> df_determ = pd.DataFrame(columns=['subsidence_pred', 'GWL'])
>>> detect_forecast_type(df_determ)
'deterministic'
- geoprior.utils.forecast_utils.format_forecast_dataframe(df, to_wide=True, time_col='coord_t', spatial_cols=('coord_x', 'coord_y'), value_prefixes=None, _logger=None, **pivot_kwargs)[source]
Auto-detects DataFrame format and conditionally pivots to wide format.
This function serves as a smart wrapper. It first determines if the input DataFrame is in a ‘long’ or ‘wide’ forecast format based on its column structure. If to_wide is True and the format is ‘long’, it calls pivot_forecast_dataframe() to perform the transformation.
- Parameters:
df (pd.DataFrame) – The input DataFrame to check and potentially transform.
to_wide (bool, default True) – If True, the function’s goal is to return a wide-format DataFrame. It will pivot a long-format frame or return a wide-format frame as is. If False, the function only performs detection and returns a string (‘wide’, ‘long’, or ‘unknown’).
time_col (str, default 'coord_t') – The name of the column that indicates the time step. Its presence is a primary indicator of a long-format DataFrame.
value_prefixes (list of str, optional) – A list of prefixes for the value columns (e.g., [‘subsidence’, ‘GWL’]). If None, the function will attempt to infer them from column names that do not match common ID columns.
**pivot_kwargs – Additional keyword arguments to pass down to the pivot_forecast_dataframe() function if it is called. Common arguments include id_vars, static_actuals_cols, verbose, etc.
- Returns:
If to_wide is True, returns the (potentially pivoted) wide-format pd.DataFrame. If to_wide is False, returns a string: ‘wide’, ‘long’, or ‘unknown’.
- Return type:
pd.DataFrame or str
See also
pivot_forecast_dataframe: The underlying function that performs the pivot operation.
Examples
>>> from geoprior.utils.forecast_utils import format_forecast_dataframe
>>> # df_long is a typical long-format forecast output
>>> df_long.columns
Index(['sample_idx', 'forecast_step', 'coord_t', 'coord_x', ...])
>>> # Detect format
>>> format_str = format_forecast_dataframe(df_long, to_wide=False)
>>> print(format_str)
long
>>> # Convert to wide format
>>> df_wide = format_forecast_dataframe(
...     df_long,
...     to_wide=True,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual']
... )
>>> # print(df_wide.columns)
>>> # Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
>>> #        'GWL_2018_q50', ...], dtype='object')
- geoprior.utils.forecast_utils.get_value_prefixes(df, exclude_cols=None, spatial_cols=('coord_x', 'coord_y'), time_col='coord_t')[source]
Automatically detects the prefixes of value columns from a DataFrame.
This utility inspects the column names to infer the base names of the metrics being forecasted (e.g., ‘subsidence’, ‘GWL’), excluding common ID and coordinate columns. It works with both long and wide format forecast DataFrames.
- Parameters:
- Returns:
A sorted list of unique prefixes found in the column names.
- Return type:
list of str
Examples
>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import get_value_prefixes
>>> # For a long-format DataFrame
>>> long_cols = ['sample_idx', 'coord_t', 'subsidence_q50', 'GWL_q50']
>>> df_long = pd.DataFrame(columns=long_cols)
>>> get_value_prefixes(df_long)
['GWL', 'subsidence']
>>> # For a wide-format DataFrame
>>> wide_cols = ['sample_idx', 'coord_x', 'subsidence_2022_q90', 'GWL_2022_q50']
>>> df_wide = pd.DataFrame(columns=wide_cols)
>>> get_value_prefixes(df_wide)
['GWL', 'subsidence']
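One way to picture the prefix inference is to drop known ID columns and then strip an optional four-digit time token and an optional `_qNN`, `_pred`, or `_actual` suffix from each remaining name. The sketch below does exactly that; the helper name `sketch_value_prefixes`, the default exclusion set, and the suffix pattern are assumptions for illustration and may differ from what `get_value_prefixes` does internally.

```python
import re

import pandas as pd

# Assumed suffix grammar: optional "_YYYY", then optional "_qNN" / "_pred" / "_actual".
_SUFFIX_RE = re.compile(r"(_\d{4})?(_q\d+|_pred|_actual)?$")

# Assumed ID / coordinate columns that never carry forecast values.
_DEFAULT_EXCLUDE = ("sample_idx", "forecast_step", "coord_t", "coord_x", "coord_y")


def sketch_value_prefixes(df, exclude_cols=_DEFAULT_EXCLUDE):
    """Return the sorted base names of value columns."""
    prefixes = set()
    for col in df.columns:
        name = str(col)
        if name in exclude_cols:
            continue
        # Remove the trailing time/quantile tokens, keeping the base name.
        prefix = _SUFFIX_RE.sub("", name, count=1)
        if prefix:
            prefixes.add(prefix)
    return sorted(prefixes)


long_frame = pd.DataFrame(columns=["sample_idx", "coord_t", "subsidence_q50", "GWL_q50"])
wide_frame = pd.DataFrame(columns=["sample_idx", "coord_x", "subsidence_2022_q90", "GWL_2022_q50"])
```

Python's default string ordering places uppercase letters first, which is why 'GWL' sorts before 'subsidence'.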
- geoprior.utils.forecast_utils.get_value_prefixes_in(df, exclude_cols=None)[source]
Automatically detects the prefixes of value columns from a DataFrame. (This is a dependency for the function below)
- geoprior.utils.forecast_utils.pivot_forecast_dataframe(data, id_vars, time_col, value_prefixes, static_actuals_cols=None, time_col_is_float_year='auto', round_time_col=False, verbose=0, savefile=None, _logger=None, **kws)[source]
Transforms a long-format forecast DataFrame to a wide format.
This utility reshapes time series prediction data from a “long” format, where each row represents a single time step for a given sample, to a “wide” format, where each row represents a single sample and columns correspond to values at different time steps.
- Parameters:
data (pd.DataFrame) – The input long-format DataFrame. It must contain the columns specified in id_vars and time_col, as well as value columns that start with the strings in value_prefixes.
id_vars (list of str) – A list of column names that uniquely identify each sample or group. These columns will be preserved in the wide-format output. For example: ['sample_idx', 'coord_x', 'coord_y'].
time_col (str) – The name of the column that represents the time step or year of the forecast (e.g., ‘coord_t’ or ‘forecast_step’). This column’s values will become part of the new column names.
value_prefixes (list of str) – A list of prefixes for the value columns that need to be pivoted. The function identifies columns starting with these prefixes. For instance, ['subsidence', 'GWL'] would match ‘subsidence_q10’, ‘GWL_q50’, etc.
static_actuals_cols (list of str, optional) – A list of columns containing static “actual” or ground truth values for each sample. These values are assumed to be constant for each unique sample_idx and are merged back into the wide DataFrame after pivoting. Example: ['subsidence_actual'].
time_col_is_float_year (bool or 'auto', default 'auto') – Controls how the time_col values are formatted into new column names.
  - If 'auto', automatically detects if time_col has a float dtype.
  - If True, treats time_col values (e.g., 2018.0) as years and converts them to integer strings (‘2018’).
  - If False, uses the string representation of the value as is.
round_time_col (bool, default False) – If True and time_col is a float type, its values will be rounded to the nearest integer before being used in column names. This is useful for cleaning up float years (e.g., 2018.0001 -> 2018).
verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent. Higher values print more details about the process.
savefile (str, optional) – If a file path is provided, the final wide-format DataFrame will be saved as a CSV file to that location.
- Returns:
A wide-format DataFrame with one row per unique combination of id_vars. New columns are created in the format {prefix}_{time_str}{_suffix} (e.g., ‘subsidence_2018_q10’).
- Return type:
pd.DataFrame
See also
pandas.pivot_table: The core function used for reshaping data.
pandas.merge: Used to re-join static columns after pivoting.
Notes
The combination of columns in id_vars and time_col must uniquely identify each row in data for the pivot to succeed without data loss.
If using static_actuals_cols, the id_vars list must contain ‘sample_idx’ to correctly merge the static data back.
Examples
>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import pivot_forecast_dataframe
>>> data = {
...     'sample_idx': [0, 0, 1, 1],
...     'coord_t': [2018.0, 2019.0, 2018.0, 2019.0],
...     'coord_x': [0.1, 0.1, 0.5, 0.5],
...     'coord_y': [0.2, 0.2, 0.6, 0.6],
...     'subsidence_q50': [-8, -9, -13, -14],
...     'subsidence_actual': [-8.5, -8.5, -13.2, -13.2],
...     'GWL_q50': [1.2, 1.3, 2.2, 2.3],
... }
>>> df_long_example = pd.DataFrame(data)
>>> df_wide = pivot_forecast_dataframe(
...     data=df_long_example,
...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
...     time_col='coord_t',
...     value_prefixes=['subsidence', 'GWL'],
...     static_actuals_cols=['subsidence_actual'],
...     verbose=0
... )
>>> print(df_wide.columns)
Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
       'GWL_2018_q50', 'GWL_2019_q50', 'subsidence_2018_q50',
       'subsidence_2019_q50'],
      dtype='object')
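For readers who want to see the reshaping mechanics without the library, the following pandas-only sketch reproduces the documented `{prefix}_{time_str}{_suffix}` naming on a tiny frame. It is an approximation under stated assumptions: float years are rounded and cast to integer strings, `pandas.pivot_table` does the reshape, and the helper `_wide_name` is hypothetical. Static actuals and coordinate columns are left out to keep the example short.

```python
import pandas as pd

df_long = pd.DataFrame(
    {
        "sample_idx": [0, 0, 1, 1],
        "coord_t": [2018.0, 2019.0, 2018.0, 2019.0],
        "subsidence_q50": [-8, -9, -13, -14],
        "GWL_q50": [1.2, 1.3, 2.2, 2.3],
    }
)

# Float years such as 2018.0 become the strings '2018' and '2019'.
year_str = df_long["coord_t"].round().astype(int).astype(str)

wide = df_long.assign(coord_t=year_str).pivot_table(
    index="sample_idx",
    columns="coord_t",
    values=["subsidence_q50", "GWL_q50"],
    aggfunc="first",
)


def _wide_name(value_col, time_str):
    """Turn ('subsidence_q50', '2018') into 'subsidence_2018_q50'."""
    prefix, suffix = value_col.rsplit("_", 1)
    return f"{prefix}_{time_str}_{suffix}"


# The pivot yields two-level columns (value column, year); flatten them.
wide.columns = [_wide_name(value, year) for value, year in wide.columns]
wide = wide.reset_index()
```

Each (sample, year) pair appears once in the input, so `aggfunc="first"` simply picks the single available value, which matches the uniqueness requirement stated in the Notes.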
- geoprior.utils.forecast_utils.get_step_names(forecast_steps, step_names=None, default_name='')[source]
Build a step → label mapping for multi‑horizon plots.
The helper reconciles an integer list forecast_steps with an optional alias container (dict or sequence) and returns a dictionary whose keys are the integer steps and whose values are human-readable labels.
Matching is case-insensitive and tolerant to common delimiters, e.g. "Step 1", "step-1", or "forecast step 1" will all map to integer step 1.
- Parameters:
forecast_steps (Iterable[int]) – Ordered steps, e.g. [1, 2, 3].
step_names (dict | list | tuple | None, default None) – Custom labels. Accepted forms:
  - dict: keys may be int or any string representation of the step.
  - sequence: positional, where the k-th element labels step k+1.
  - None: no custom mapping.
default_name (str, default "") – Fallback label for steps missing from step_names. If empty, the step number itself is used (as a string).
- Returns:
Mapping {step : label} for every element of forecast_steps.
- Return type:
dict[int, str]
Notes
Dictionary keys are normalised with int(re.sub(r"[^0-9]", "", str(key))) before matching.
Duplicate keys in step_names are resolved by last-one-wins semantics.
Examples
>>> from geoprior.utils.forecast_utils import get_step_names
>>> get_step_names(
...     forecast_steps=[1, 2, 3],
...     step_names={"1": "Year 2021", 2: "2022", "step 3": "2023"},
... )
{1: 'Year 2021', 2: '2022', 3: '2023'}
>>> get_step_names(
...     forecast_steps=[1, 2, 3, 4],
...     step_names={"1": "2021", "2": "2022"},
... )
{1: '2021', 2: '2022', 3: '3', 4: '4'}
>>> get_step_names(
...     [1, 2, 3, 4],
...     step_names=None,
...     default_name="step with no name",
... )
{1: 'step with no name', 2: 'step with no name', 3: 'step with no name', 4: 'step with no name'}
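The key normalisation quoted in the Notes is easy to try in isolation. The sketch below applies that same expression to dictionary keys and falls back to the step number when no label is found. The function `sketch_step_names` is a hypothetical stand-in written for illustration, not the library code; in particular, it assumes every dictionary key contains at least one digit, since a digit-free key would make `int("")` raise `ValueError`.

```python
import re


def sketch_step_names(forecast_steps, step_names=None, default_name=""):
    """Map integer steps to labels, mirroring the documented rules."""
    if step_names is None:
        lookup = {}
    elif isinstance(step_names, dict):
        # Strip everything except digits, so "step 3" and "3" both become 3.
        # Later duplicates overwrite earlier ones (last-one-wins).
        lookup = {
            int(re.sub(r"[^0-9]", "", str(key))): label
            for key, label in step_names.items()
        }
    else:
        # Positional sequence: element k (0-based) labels step k + 1.
        lookup = {k + 1: label for k, label in enumerate(step_names)}

    result = {}
    for step in forecast_steps:
        step = int(step)
        # An empty default_name means "use the step number itself".
        fallback = default_name if default_name else step
        result[step] = str(lookup.get(step, fallback))
    return result


mixed = sketch_step_names([1, 2, 3], {"1": "Year 2021", 2: "2022", "step 3": "2023"})
partial = sketch_step_names([1, 2, 3, 4], {"1": "2021", "2": "2022"})
positional = sketch_step_names([1, 2], ["first", "second"])
```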
- geoprior.utils.forecast_utils.stack_quantile_predictions(q_lower, q_median, q_upper)[source]
Stack three quantile trajectories into a single y_pred array of shape (n_samples, 3, n_timesteps), ready for PSS.
- Parameters:
q_lower (array-like) – Either 1D with shape (n_timesteps,), interpreted as a single sample, or 2D with shape (n_samples, n_timesteps).
q_median (array-like) – Either 1D with shape (n_timesteps,), interpreted as a single sample, or 2D with shape (n_samples, n_timesteps).
q_upper (array-like) – Either 1D with shape (n_timesteps,), interpreted as a single sample, or 2D with shape (n_samples, n_timesteps).
- Returns:
y_pred – Where axis=1 indexes [lower, median, upper].
- Return type:
np.ndarray, shape (n_samples, 3, n_timesteps)
- Raises:
ValueError – If the three inputs (after promotion) do not share the same shape.
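The shape contract above (1D inputs promoted to a single sample, then stacked along a new axis 1) can be reproduced with two NumPy calls. The following is a minimal sketch of that contract, assuming `np.atleast_2d` for the promotion and `np.stack` for the join; the name `sketch_stack_quantiles` is illustrative and the library's own function may validate inputs differently.

```python
import numpy as np


def sketch_stack_quantiles(q_lower, q_median, q_upper):
    """Stack three trajectories into shape (n_samples, 3, n_timesteps)."""
    # Promote (n_timesteps,) to (1, n_timesteps); 2D inputs pass through.
    arrays = [
        np.atleast_2d(np.asarray(q, dtype=float))
        for q in (q_lower, q_median, q_upper)
    ]
    shapes = {a.shape for a in arrays}
    if len(shapes) != 1:
        raise ValueError(f"Quantile arrays must share one shape, got {shapes}")
    # New axis 1 indexes [lower, median, upper].
    return np.stack(arrays, axis=1)


single = sketch_stack_quantiles([0.1, 0.2], [0.3, 0.4], [0.5, 0.6])
batch = sketch_stack_quantiles(np.zeros((5, 4)), np.ones((5, 4)), 2 * np.ones((5, 4)))
```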
- geoprior.utils.forecast_utils.adjust_time_predictions(df, time_col, forecast_horizon, coord_scaler=None, inverse_transformed=False, verbose=1)[source]
Adjusts time predictions by inverse-transforming the normalized time column and then adding the forecast horizon. If the time column has already been inverse-transformed, the inverse transformation is skipped.
- Parameters:
df (pd.DataFrame) – The DataFrame containing the time predictions (inverse scaled). The time column specified by time_col should contain the time values that need to be adjusted.
time_col (str) – The name of the time column in the DataFrame. This column will be adjusted by adding the forecast horizon.
forecast_horizon (int) – The forecast horizon (e.g., number of years or time steps) that will be added to the time predictions. This value shifts the time predictions forward.
coord_scaler (MinMaxScaler, optional) – The scaler that was used for the coordinates. It is necessary to reverse the scaling for the time column if it was previously normalized. If not provided, the time column should already be inverse-transformed.
inverse_transformed (bool, default False) – If True, skips the inverse transformation of the time column and directly adds the forecast horizon. This is useful when the time column has already been inverse-transformed, and you only need to adjust the time by the forecast horizon.
verbose (int, default 1) – Verbosity level for logging. Higher values (e.g., verbose=2) provide more detailed information about the operation.
- Returns:
The adjusted DataFrame with the time column updated to reflect the forecast horizon. The time predictions are adjusted by adding the forecast_horizon to each entry in the time column.
- Return type:
pd.DataFrame
- Raises:
ValueError – If the time column is not found in the DataFrame or if the scaler is not available when necessary.
Examples
>>> import pandas as pd
>>> from sklearn.preprocessing import MinMaxScaler
>>> from geoprior.utils.forecast_utils import adjust_time_predictions
>>> # Sample data for illustration
>>> df = pd.DataFrame({
...     'year': [0.0, 0.5, 1.0],
...     'subsidence': [0.1, 0.2, 0.3]
... })
>>> scaler = MinMaxScaler()
>>> df_scaled = df.copy()
>>> df_scaled['year'] = scaler.fit_transform(df_scaled[['year']])
>>> adjusted_df = adjust_time_predictions(
...     df_scaled,
...     time_col='year',
...     forecast_horizon=4,
...     coord_scaler=scaler,
...     inverse_transformed=False,
...     verbose=2
... )
>>> # After adjustment, the 'year' values are inverse-transformed and
>>> # shifted into the future by forecast_horizon.
>>> adjusted_df['year']
Notes
The time column must be in a normalized scale if not already inverse-transformed.
If inverse_transformed=True, the time values will directly be adjusted by the forecast_horizon without applying the inverse transformation.
The forecast horizon is added directly to the time values after the necessary inverse transformation (if applicable).
See also
sklearn.preprocessing.MinMaxScaler: Scales features to [0, 1].
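The arithmetic behind the adjustment is short enough to write out by hand. The sketch below assumes a min-max scaling of the time axis, so the inverse transform is `t * (t_max - t_min) + t_min`, followed by the horizon shift. The function name `sketch_adjust_time` and the explicit `t_min` / `t_max` arguments are assumptions that replace the fitted `coord_scaler` to keep the example dependency-free; the real function takes the scaler object instead.

```python
import pandas as pd


def sketch_adjust_time(df, time_col, forecast_horizon,
                       t_min=None, t_max=None, inverse_transformed=False):
    """Undo min-max scaling on the time column, then add the horizon."""
    if time_col not in df.columns:
        raise ValueError(f"Column {time_col!r} not found")
    out = df.copy()
    t = out[time_col].astype(float)
    if not inverse_transformed:
        if t_min is None or t_max is None:
            raise ValueError("Scaling bounds are required to invert the time column")
        # Inverse of the min-max transform t_scaled = (t - t_min) / (t_max - t_min).
        t = t * (t_max - t_min) + t_min
    out[time_col] = t + forecast_horizon
    return out


scaled = pd.DataFrame({"year": [0.0, 0.5, 1.0], "subsidence": [0.1, 0.2, 0.3]})
# Hypothetical training window 2015-2019, shifted 4 years ahead.
adjusted = sketch_adjust_time(scaled, "year", 4, t_min=2015, t_max=2019)
# Already in physical units: only the shift is applied.
shifted = sketch_adjust_time(
    pd.DataFrame({"year": [2019.0]}), "year", 4, inverse_transformed=True
)
```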
- geoprior.utils.forecast_utils.add_forecast_times(df, *, forecast_times=None, start=None, freq='YS', step_col='forecast_step', time_col='coord_t', error='raise', inplace=False, savefile=None, verbose=0)[source]
Map each 1‑based forecast_step into an explicit calendar time.
- You may either:
Pass forecast_times of length H (one per step), or
Pass a single start plus a pandas‐style freq to generate H dates.
If any entry in forecast_times is an integer of exactly 4 digits, it will be interpreted as January 1 of that year.
- Parameters:
df (pd.DataFrame) – Long-format forecast table. Must contain an integer column step_col with values 1..H.
forecast_times (sequence, optional) – Explicit sequence of length H specifying the target times. Each entry may be:
  - int (interpreted as January 1 of that year)
  - str / pd.Timestamp / datetime.date
start (int or str or date or Timestamp, optional) – Only used if forecast_times is None. The first time in the sequence; subsequent times will be generated via pd.date_range. If int, treated as a year at Jan 1.
freq (str, default "YS") – Pandas offset alias for frequency (e.g. “YS”=year start, “MS”=month start, “D”=day, etc.). Only used when start is set.
step_col (str, default "forecast_step") – Name of the 1-based step index in df.
time_col (str, default "coord_t") – Name of the new column to create with mapped times.
error ({'raise', 'warn', 'ignore'}, default 'raise') – Policy if df[step_col].max() > number of provided times:
  - ‘raise’: throw ValueError
  - ‘warn’: issue warning, then still map what you can (truncate)
  - ‘ignore’: silently truncate to available times
inplace (bool, default False) – If True, modify df in place; otherwise return a new DataFrame.
savefile (str, optional) – If provided, path to CSV where the resulting DataFrame will be saved.
verbose (int, default 0) – Passed to vlog for debug logging.
- Returns:
DataFrame with an added column time_col of dtype datetime64.
- Return type:
pd.DataFrame
- Raises:
ValueError – If neither forecast_times nor start is provided, or if error=’raise’ and there aren’t enough times.
Examples
>>> import pandas as pd
>>> from geoprior.utils.forecast_utils import add_forecast_times
>>> df = pd.DataFrame({
...     "sample_idx": [0]*3 + [1]*3,
...     "forecast_step": [1,2,3]*2
... })
>>> add_forecast_times(df,
...                    forecast_times=[2022,2023,2024])
   sample_idx  forecast_step    coord_t
0           0              1 2022-01-01
1           0              2 2023-01-01
2           0              3 2024-01-01
3           1              1 2022-01-01
4           1              2 2023-01-01
5           1              3 2024-01-01
>>> # Or generate from a start + yearly freq:
>>> add_forecast_times(df, start="2022-06-15", freq="YS")
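The two ways of supplying times can be compared directly in pandas. The sketch below builds an explicit grid from four-digit years and a generated grid from `pd.date_range`, then maps 1-based steps onto the explicit one. It is a plain-pandas illustration, not the function's source. One pandas behaviour worth knowing when using the `start` + `freq` form: with an anchored alias such as "YS", `pd.date_range` rolls a mid-year start forward to the next year start. Whether `add_forecast_times` post-processes that result is not stated in this reference.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "sample_idx": [0] * 3 + [1] * 3,
        "forecast_step": [1, 2, 3] * 2,
    }
)
horizon = int(df["forecast_step"].max())

# Explicit form: four-digit integers read as January 1 of that year.
explicit = pd.to_datetime([f"{year}-01-01" for year in (2022, 2023, 2024)])

# Generated form: 'YS' is anchored to year starts, so 2022-06-15
# rolls forward and the first generated date is 2023-01-01.
generated = pd.date_range(start="2022-06-15", periods=horizon, freq="YS")

# Map 1-based steps onto the explicit grid.
step_to_time = dict(zip(range(1, horizon + 1), explicit))
df_explicit = df.assign(coord_t=df["forecast_step"].map(step_to_time))
```

If the intended grid should start in 2022, passing `start="2022-01-01"` (or the integer-year form described above) avoids relying on the roll-forward.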
- geoprior.utils.forecast_utils.pivot_forecast(df, *, index_col='sample_idx', pivot_col=None, step_col='forecast_step', time_col='coord_t', value_cols=None, spatial_cols=None, aggfunc='first', fill_value=nan, sep='_', time_formatter=<function <lambda>>, inplace=False, savefile=None, verbose=0)[source]
Pivot a long-format forecast DataFrame into a wide one.
This will take rows identified by index_col + a step_col (or datetime time_col) and spread each forecast step/time into its own set of columns for each value in value_cols, then re-attach the spatial_cols.
- Parameters:
df (DataFrame) – Long-format forecasts. Must include index_col and at least one of step_col or time_col.
index_col (str) – Column that identifies each sample (e.g. “sample_idx”).
pivot_col (str | None) – If provided, pivot on this column instead of auto-detecting. Must be either step_col or time_col.
step_col (str) – Name of the integer 1‑based forecast step column.
time_col (str) – Name of the datetime column (e.g. “coord_t”).
value_cols (str | Sequence[str] | None) – Which forecast columns to pivot (e.g. “subsidence_q50” or [“subsidence_q10”,”subsidence_q50”,”subsidence_q90”]). If None, will auto-pick all numeric columns except index/pivot/spatial.
spatial_cols (Sequence[str] | None) – List of columns holding static spatial info (e.g. [“longitude”,”latitude”]) to join back once pivoted.
aggfunc (str | Callable) – Aggregation function for pivot (default “first”).
fill_value (Any) – What to put where a sample/step is missing (default NaN).
sep (str) – Separator between value name and step/time in the new column names (default “_”).
time_formatter (Callable[[Any], str]) – How to turn a datetime/timestamp into a string for column names (default “%Y-%m-%d”).
inplace (bool) – If True, modifies df instead of copying.
savefile (str | None) – If given, writes the resulting wide DataFrame to CSV at this path.
verbose (int) – Passed to vlog for logging.
- Returns:
Wide-format DataFrame with one row per index_col and columns like <value><sep><step> or <value><sep><formatted time>.
- Return type:
pd.DataFrame
Example
>>> from geoprior.utils.forecast_utils import pivot_forecast
>>> # df_ is assumed to be a long-format forecast table with a datetime
>>> # 'coord_t' column and 'longitude' / 'latitude' columns.
>>> dff = pivot_forecast(
...     df_,
...     index_col="sample_idx",
...     pivot_col="coord_t",  # force pivot on the datetime
...     value_cols=["subsidence_q10", "subsidence_q50", "subsidence_q90"],
...     spatial_cols=["longitude", "latitude"],
...     sep="_",  # you’ll get subsidence_q50_2022 etc.
...     time_formatter=lambda t: f"{t.year}",
...     verbose=1
... )
[INFO] Pivoting on 'coord_t' for values ['subsidence_q10', 'subsidence_q50', 'subsidence_q90']
[INFO] Joining back spatial cols ['longitude', 'latitude']
>>> dff.columns
Index(['sample_idx', 'subsidence_q10_2022', 'subsidence_q10_2023',
       'subsidence_q10_2024', 'subsidence_q50_2022', 'subsidence_q50_2023',
       'subsidence_q50_2024', 'subsidence_q90_2022', 'subsidence_q90_2023',
       'subsidence_q90_2024', 'longitude', 'latitude'],
      dtype='object')
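The naming scheme `<value><sep><formatted time>` and the re-attachment of spatial columns can be traced on a small frame with plain pandas. This sketch is only an approximation of the documented behaviour: the data values are made up, `format_time` plays the role of `time_formatter`, and the spatial columns are recovered with a `groupby(...).first()` on the assumption that they are constant per sample.

```python
import pandas as pd

df_long = pd.DataFrame(
    {
        "sample_idx": [0, 0, 1, 1],
        "coord_t": pd.to_datetime(["2022-01-01", "2023-01-01"] * 2),
        "subsidence_q50": [1.0, 2.0, 3.0, 4.0],
        "longitude": [113.1, 113.1, 113.2, 113.2],
        "latitude": [22.5, 22.5, 22.6, 22.6],
    }
)


def format_time(t):
    """Stand-in for time_formatter: keep only the year."""
    return f"{t.year}"


sep = "_"
labels = df_long["coord_t"].map(format_time)

wide = df_long.assign(_label=labels).pivot_table(
    index="sample_idx",
    columns="_label",
    values=["subsidence_q50"],
    aggfunc="first",
)
# Flatten ('subsidence_q50', '2022') into 'subsidence_q50_2022'.
wide.columns = [f"{value}{sep}{label}" for value, label in wide.columns]

# Spatial columns are static per sample, so one row per sample suffices.
spatial = df_long.groupby("sample_idx")[["longitude", "latitude"]].first()
wide = wide.join(spatial).reset_index()
```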
- geoprior.utils.forecast_utils.plot_reliability_diagram(models_data, y_true=None, prefix='subsidence', figsize=(8, 8), title='Reliability Diagram', plot_style='seaborn-whitegrid', verbose=None, _logger=None)[source]
Plot a reliability diagram for one or multiple models.
- Parameters:
models_data (dict) – Mapping of model names to forecast data. Each value can be a pandas.DataFrame or a nested dict with keys ‘forecasts’, ‘color’, ‘marker’, and ‘style’.
y_true (pandas.Series, optional) – Observed values for empirical coverage calculations. Required when forecasts need processing.
prefix (str, default 'subsidence') – Column prefix for quantile forecast fields.
figsize (tuple of int, default (8, 8)) – Figure size (width, height) in inches.
title (str, default 'Reliability Diagram') – Text title displayed at the top of the plot.
plot_style (str, default 'seaborn-whitegrid') – Matplotlib style sheet name to apply.
verbose (int, optional) – Verbosity level passed to geoprior.utils.generic_utils.vlog.
_logger (Logger or callable, optional) – Function or logger instance for internal messages.
- Returns:
Displays the calibration plot and returns nothing.
- Return type:
None
Notes
This function draws a diagonal baseline (perfect calibration) and computes empirical coverage for probabilistic intervals using specified quantiles. It wraps simple DataFrame inputs into the required nested format and uses vlog for conditional logging.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.utils.forecast_utils import plot_reliability_diagram
>>> # Create dummy true time series
>>> dates = pd.date_range('2020-01-01', periods=100)
>>> y_true = pd.Series(
...     np.random.randn(100), index=dates
... )
>>> # Create forecasts for ModelA
>>> dfA = pd.DataFrame({
...     'subsidence_q10': y_true - 0.5,
...     'subsidence_q90': y_true + 0.5
... }, index=dates)
>>> # Simple usage with one model
>>> plot_reliability_diagram(
...     models_data={'ModelA': dfA},
...     y_true=y_true,
...     verbose=2
... )
>>> # Create forecasts for ModelB with custom styling
>>> dfB = pd.DataFrame({
...     'subsidence_q10': y_true - 1.0,
...     'subsidence_q90': y_true + 1.0
... }, index=dates)
>>> custom_logger = print
>>> # Custom styling and logger
>>> plot_reliability_diagram(
...     models_data={
...         'ModelA': {
...             'forecasts': dfA,
...             'color': 'C0',
...             'marker': 'x'
...         },
...         'ModelB': {
...             'forecasts': dfB,
...             'color': 'C1',
...             'marker': 'o'
...         }
...     },
...     y_true=y_true,
...     verbose=4,
...     _logger=custom_logger
... )
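The quantity a reliability diagram plots against the diagonal is the observed frequency with which the truth falls at or below each predicted quantile. The NumPy sketch below computes that frequency, plus the coverage of the q10 to q90 interval, for a synthetic forecast that is calibrated by construction. The data, the variable names, and the choice of per-quantile frequencies are assumptions made for illustration; this reference does not spell out the exact coverage formula used by `plot_reliability_diagram`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic truth drawn from a standard normal distribution.
y_true = rng.normal(size=2000)

# Nominal quantile levels and matching predicted quantiles, estimated
# from an independent sample of the same distribution.
nominal = np.array([0.1, 0.5, 0.9])
reference = rng.normal(size=100_000)
q_pred = np.quantile(reference, nominal)

# Observed frequency: share of truths at or below each predicted quantile.
# A calibrated forecast gives observed close to nominal (the diagonal).
observed = np.array([(y_true <= q).mean() for q in q_pred])

# Empirical coverage of the central q10-q90 interval (nominal 0.8).
interval_coverage = ((y_true >= q_pred[0]) & (y_true <= q_pred[-1])).mean()
```

Plotting `nominal` on the x-axis against `observed` on the y-axis, together with the line y = x, gives the basic diagram; points below the diagonal indicate quantiles that are exceeded more often than their nominal level implies.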
- geoprior.utils.forecast_utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)[source]
Format PINN forecasts into evaluation and future DataFrames.
This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:
df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).
df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.
- Parameters:
y_pred (dict) – Dictionary of model predictions, as returned by GeoPriorSubsNet.predict post-processed into {'subs_pred': ..., 'gwl_pred': ...}.
For subsidence, the expected shapes are:
  - Quantile mode: (B, H, Q, O) where B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.
  - Point mode: (B, H, O).
y_true (dict or None) – Dictionary of true targets, typically {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}. If None, evaluation DataFrame is still created but without the actual-value column.
coords (ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped (B, H, 3) with columns [t_scaled, x_scaled, y_scaled]. Only x and y are used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.
quantiles (list of float or None, optional) – List of quantiles (e.g. [0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. If None, a single prediction column is emitted instead.
target_name (str, default 'subsidence') – Logical target identifier used as the default key for locating the target scaler in scaler_info and as a fallback for resolving truth arrays in y_true.
Column naming is controlled by output_target_name (or the auto-derived output prefix when it is None).
output_target_name (str or None, optional) – Output prefix used when creating DataFrame columns for predictions and actuals.
This controls the column naming only (e.g. the function will emit f"{output_target_name}_q10", f"{output_target_name}_pred", and f"{output_target_name}_actual").
If None (default), the function derives the output prefix from target_name and applies a small convenience rule: if target_name ends with "_cum" or "_cumulative", that suffix is stripped for output naming.
This keeps downstream tooling consistent (many plotting and metrics utilities expect names like subsidence_q10 rather than subsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, with target_name="subsidence_cum" and output_target_name=None, output columns become subsidence_q10, subsidence_q50, and subsidence_actual. If output_target_name="subsidence_cum", the output columns keep the suffix, such as subsidence_cum_q10.
scaler_target_name (str or None, optional) – Name used to locate the target scaling block inside scaler_info and to perform inverse-transform for predictions and actuals.
This controls the scaler key and inverse scaling, not the output column naming.
If None (default), the scaler key is assumed to be target_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.
A common pattern is to keep target_name="subsidence_cum" so the scaler lookup matches the Stage-1 schema, while letting output_target_name=None produce clean output columns. In that setup, inverse transform still uses the subsidence_cum scaler key, while output columns use the subsidence_ prefix because of the auto-strip rule.
target_key_pred (str, default 'subs_pred') – Key inside y_pred that holds the subsidence forecasts.
component_index (int, default 0) – Index along the output dimension O to use when output_subsidence_dim > 1. For scalar subsidence this is 0.
scaler_info (dict, optional) – Optional Stage-1 scaler_info mapping containing a target scaler under keys such as 'targets' or 'target'. The target block is expected to provide an sklearn-like transformer under 'scaler' together with column names under 'columns' or 'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed for target_name.
coord_scaler (object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transform coord_x and coord_y when coords is given and coord_columns can be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.
coord_columns (tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.
train_end_time (scalar or str or datetime, optional) – Physical time associated with the evaluation year (e.g. 2022). If eval_forecast_step is not given, the last horizon step is assumed to correspond to this time.
forecast_start_time (scalar or str or datetime, optional) – First time in the future forecast horizon (e.g. 2023).
forecast_horizon (int, optional) – Number of forecast steps in the future horizon (e.g. 3). If future_time_grid is not given, this is used together with forecast_start_time to build a regular grid.
future_time_grid (array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be [2023, 2024, 2025]. If provided, it overrides any automatic construction from forecast_start_time and forecast_horizon.
eval_forecast_step (int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.
eval_export ({"all", "last"} or str or int or sequence, optional) – Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).
Accepted values are:
  - "all" or "full" or "horizons": export all horizons from df_eval_all.
  - "last" or "single" or "default": export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).
  - Other str (e.g. "2022"): interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.
  - int or scalar non-string: interpreted as a single time value (e.g. 2022).
  - sequence of values (e.g. [2021, 2022]): interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.
If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.
value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) – Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.
Supported modes are:
  - "rate": keep per-step predictions as provided by the model (current behaviour).
  - "cumulative" or "cum": convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.
  - "absolute_cumulative" or "abs_cum" or "absolute": same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.
Cumulative transforms are applied consistently to:
  - the future forecast DataFrame (df_future),
  - the multi-horizon evaluation DataFrame (df_eval_all),
  - and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).
When an unsupported string is given, the function logs a warning and falls back to "rate".
absolute_baseline (float or Mapping[int, float], optional) – Baseline value to use when value_mode requests absolute cumulative outputs ("absolute_cumulative", "abs_cum", "absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence at train_end_time (e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.
If a scalar float is provided, the same baseline value is added to all samples. If a mapping is provided, it must map sample_idx (integers) to baseline values, allowing per-sample baselines:
absolute_baseline = {sample_idx: baseline_value, ...}
Only prediction columns for target_name are shifted (e.g. "subsidence_q10", "subsidence_q50", "subsidence_q90" or "subsidence_pred"). When df_eval_all is present, the corresponding "<target_name>_actual" column is shifted as well, so evaluation metrics operate on absolute cumulative values.
If value_mode is an absolute cumulative variant but absolute_baseline is None, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).
sample_index_offset (int, default 0) – Offset added to sample_idx (useful when concatenating multiple tiles).
city_name (str, optional) – Optional metadata used only for logging.
model_name (str, optional) – Optional metadata used only for logging.
dataset_name (str, optional) – Optional metadata used only for logging.
csv_eval_path (str, optional) – If provided, df_eval is written to this path (directories are created if needed).
csv_future_path (str, optional) – If provided, df_future is written to this path.
time_as_datetime (bool, default False) – If True, time values are converted using pandas.to_datetime() with the provided time_format (if any).
time_format (str or None, optional) – Optional format string passed to pandas.to_datetime() when time_as_datetime=True.
eval_metrics (bool, default False) – If True, automatically call evaluate_forecast() on the resulting df_eval to compute diagnostics. Metrics are not returned by this function; they are either written to disk (if metrics_savefile is provided) or discarded. For programmatic access to the metrics dictionary, call evaluate_forecast() directly.
metrics_column_map (mapping, optional) – Optional column mapping forwarded to evaluate_forecast() (see its documentation for details). If None, default column names such as 'coord_t', 'forecast_step', f'{target_name}_q10', and f'{target_name}_actual' are assumed.
metrics_quantile_interval (tuple of float, default (0.1, 0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded to evaluate_forecast().
metrics_per_horizon (bool, default False) – If True, per-horizon MAE/MSE/R² are computed by evaluate_forecast() and included in the diagnostics.
metrics_extra (sequence or mapping, optional) – Optional additional metrics to compute, forwarded to evaluate_forecast(). Can be:
  - A sequence of metric names (resolved via geoprior.metrics._registry.get_metric).
  - A mapping {name: func} where func is a callable taking (y_true, y_pred, **kwargs).
metrics_extra_kwargs (mapping, optional) – Optional per-metric keyword arguments, forwarded to evaluate_forecast(). Keys must match metric names in metrics_extra.
metrics_savefile (str, path-like, bool, or None) – If truthy, diagnostics from evaluate_forecast() are written to disk. Behavior matches the savefile argument of evaluate_forecast(). When True, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.
metrics_save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format for diagnostics written by evaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.
metrics_time_as_str (bool, default True) – If True, time keys in the diagnostics written by evaluate_forecast() are converted to strings (useful for JSON serialization).
verbose (int, default 1) – Verbosity level passed to vlog().
logger (logging.Logger, optional) – Logger instance; if None, a module-level LOG is used.
input_value_mode (str)
rate_first (str)
output_unit (str | None)
output_unit_from (str)
output_unit_mode (str)
output_unit_suffix (str)
output_unit_col (str | None)
- Returns:
df_eval_to_write (
pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time. Columns include:
'sample_idx'
'forecast_step'
quantile columns (e.g. subsidence_q10) or subsidence_pred
'subsidence_actual' (if y_true given)
coord_t, coord_x, coord_y (names from coord_columns).
df_future (
pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure asdf_evalbut without the actual-value column.
- Return type:
Notes
This function separates scaler lookup (
scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like"_cum"but downstream tools expect canonical names such as columns prefixed withsubsidence_.
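The baseline shift described for absolute_baseline can be illustrated with a small standard-library sketch. This is not the library's implementation: apply_absolute_baseline and the dictionary layout are hypothetical, and the inputs are assumed to be per-step increments keyed by sample_idx. The sketch accumulates over the horizon first and only then adds the scalar or per-sample baseline, falling back to relative cumulative values when no baseline is given.

```python
from collections.abc import Mapping
from itertools import accumulate


def apply_absolute_baseline(increments, absolute_baseline=None):
    """Hypothetical helper: cumulative sum per sample, then baseline shift."""
    shifted = {}
    for sample_idx, steps in increments.items():
        # Cumulative sum over the forecast horizon for this sample.
        cumulative = list(accumulate(steps))
        if absolute_baseline is None:
            shift = 0.0  # relative cumulative mode: no shift applied
        elif isinstance(absolute_baseline, Mapping):
            shift = absolute_baseline[sample_idx]  # per-sample baseline
        else:
            shift = float(absolute_baseline)  # same baseline for all samples
        shifted[sample_idx] = [value + shift for value in cumulative]
    return shifted


increments = {0: [1.0, 2.0, 3.0], 1: [0.5, 0.5, 0.5]}
relative = apply_absolute_baseline(increments)
per_sample = apply_absolute_baseline(increments, {0: 10.0, 1: 20.0})
scalar = apply_absolute_baseline(increments, 100.0)
```

Adding the baseline after the cumulative sum keeps it a constant offset per sample, so the step-to-step differences of the shifted series match the original increments.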
- geoprior.utils.forecast_utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]
Evaluate forecast diagnostics from an evaluation DataFrame.
This helper consumes the
df_evaloutput fromformat_and_forecast()(or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, \(R^2\), coverage, and sharpness. It can also optionally evaluate metrics per forecast horizon and apply additional user-defined metrics.By default it expects the following columns:
'sample_idx'
'forecast_step'
'coord_t' (time)
Quantile or point-prediction columns for the target, e.g.:
Quantile mode: f'{target_name}_q10', f'{target_name}_q50', f'{target_name}_q90', …
Point mode: f'{target_name}_pred'.
Actual column: f'{target_name}_actual'.
A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:
column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}
or, for quantile columns:
column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}
- Parameters:
eval_data (
str,path-like, orpandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved byformat_and_forecast()) or an in-memory DataFrame.target_name (
str, default'subsidence') – Base name for the target columns. Used to infer default column names such asf'{target_name}_q10',f'{target_name}_pred', andf'{target_name}_actual'.column_map (
dict, optional) –Optional mapping to override default column names. The following keys are recognized:
'sample_idx': sample index column name (default 'sample_idx').
'forecast_step': horizon index column name (default 'forecast_step').
'coord_t': time coordinate column (default 'coord_t').
'actual': name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.
'pred': point prediction column for non-quantile mode, default f'{target_name}_pred'.
'quantiles':
If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).
If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.
quantile_interval (
tupleoffloat, default(0.1,0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.per_horizon (
bool, defaultFalse) – IfTrue, compute per-horizon MAE/MSE/R² grouped by theforecast_stepcolumn.extra_metrics (
sequenceofstrormapping, optional) –Optional additional metrics to compute.
If a sequence of strings (e.g.
['pss', 'pit']), each name is resolved viageoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.If a mapping
{name: func}, eachfuncis called as:func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
where
y_predis the median (Q50) or point forecast.
For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.
extra_metric_kwargs (
mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names inextra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.savefile (
str,path-like, orbool, optional) –If provided, metrics are saved to disk.
If
True: a filename is auto-generated neareval_data(if it is a path) or in the current working directory.If a string/path without extension: the extension is taken from
save_format.If a string/path with extension: that extension takes precedence over
save_format.
save_format (
{'.json', 'json', '.csv', 'csv'}, default'.json') –Output format when
savefileis truthy. JSON preserves nested structure; CSV is flattened into a tall table.For JSON, the function returns the metrics dictionary.
For CSV, the function returns the metrics DataFrame.
time_as_str (
bool, defaultTrue) – IfTrue, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.verbose (
int, default1) – Verbosity level passed tovlog().logger (
logging.Logger, optional) – Optional logger instance used byvlog().overall_key (str | None)
- Returns:
results – If save_format is JSON (default), returns a dict.
Single time value:
{
    "overall_mae": ...,
    "overall_mse": ...,
    "overall_r2": ...,
    "coverage80": ...,
    "sharpness80": ...,
    "per_horizon_mae": {1: ..., 2: ..., ...},
    ...
}
Multiple time values:
{
    "2021": { ...metrics... },
    "2022": { ...metrics... },
}
If save_format is CSV, returns a DataFrame with flattened rows. Columns include: coord_t, metric, horizon, and value.
- Return type:
Notes
Default metrics in quantile mode:
overall_mae, overall_mse, overall_r2
coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)
If per_horizon=True, also: per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).
Default metrics in point mode (no quantiles):
mae, mse, r2
And optionally, if per_horizon=True: per_horizon_mae, per_horizon_mse, per_horizon_r2.
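To make the interval diagnostics concrete, here is a minimal sketch of coverage and sharpness for a Q10–Q90 band. It assumes the common definitions (coverage as the fraction of actual values inside the interval, sharpness as the mean interval width); the exact formulas used by evaluate_forecast() are not shown on this page, so treat interval_coverage and interval_sharpness as hypothetical stand-ins.

```python
def interval_coverage(actual, lower, upper):
    """Fraction of actual values falling inside [lower, upper]."""
    inside = [lo <= y <= hi for y, lo, hi in zip(actual, lower, upper)]
    return sum(inside) / len(inside)


def interval_sharpness(lower, upper):
    """Mean width of the prediction interval (smaller is sharper)."""
    widths = [hi - lo for lo, hi in zip(lower, upper)]
    return sum(widths) / len(widths)


actual = [1.0, 2.0, 3.0, 4.0]
q10 = [0.5, 2.5, 2.0, 3.0]
q90 = [1.5, 3.5, 4.0, 3.5]

coverage80 = interval_coverage(actual, q10, q90)   # 2 of 4 values inside
sharpness80 = interval_sharpness(q10, q90)         # mean of 1.0, 1.0, 2.0, 0.5
```

Under these definitions a well-calibrated 80% interval would show coverage near 0.8, while sharpness is only meaningful when compared at similar coverage.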
This is one of the most frequently visited workflow helper modules because it turns saved or live outputs into reusable forecast tables, aligned metrics, and export-ready structures.
Generic helpers#
Provides common helper functions for validation, comparison, and other generic operations.
- class geoprior.utils.generic_utils.ExistenceChecker[source]
Bases:
objectA utility class for checking and ensuring the existence of files and directories on the filesystem.
This class provides static methods to verify whether a given path exists and to create directories or files if necessary. It raises informative exceptions when paths are invalid or cannot be created.
- ensure_directory(path)[source]
Ensure a directory exists at the specified path.
- ensure_file(path, create_parent_dirs=False)[source]
Ensure a file exists at the specified path, optionally creating parent directories.
Examples
>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ExistenceChecker
>>> # Ensure a directory exists
>>> dir_path = ExistenceChecker.ensure_directory("data/output")
>>> isinstance(dir_path, Path)
True
>>> # Ensure a file exists, creating parent directories
>>> file_path = ExistenceChecker.ensure_file(
...     "data/output/results.txt", create_parent_dirs=True
... )
>>> file_path.exists()
True
Notes
Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood to create directories.
Creating a file will produce an empty file if it does not exist.
Raises TypeError if the given path is not a str or pathlib.Path, and appropriate OSError/FileExistsError for filesystem errors.
See also
pathlib.Path.mkdirMethod to create directories.
pathlib.Path.touchMethod to create an empty file.
os.makedirsLegacy function for creating directories recursively.
os.path.existsCheck if a path exists.
- static ensure_directory(path)[source]
Ensure that a directory exists at the given path, creating it if needed.
- Parameters:
path (
strorpathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.- Returns:
A Path object pointing to the existing (or newly created) directory.
- Return type:
- Raises:
TypeError – If path is not a string or pathlib.Path.
FileExistsError – If a file (not a directory) already exists at path.
OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).
- static ensure_file(path, create_parent_dirs=False)[source]
Ensure that a file exists at the given path, creating it if needed.
If create_parent_dirs is True, any missing parent directories will be created automatically.
- Parameters:
path (
strorpathlib.Path) – The filesystem path for the file that must exist.create_parent_dirs (
bool, optional) – If True, create any missing parent directories. Default is False.
- Returns:
A Path object pointing to the existing (or newly created) file.
- Return type:
- Raises:
TypeError – If path is not a string or pathlib.Path.
FileExistsError – If a directory (not a file) already exists at path.
OSError – If the file or parent directories cannot be created due to filesystem errors.
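The documented contract of ensure_directory and ensure_file can be approximated with pathlib alone. The sketch below is an assumption about how such helpers might look, not the ExistenceChecker source; ensure_directory_sketch and ensure_file_sketch are hypothetical names. It mirrors the documented exceptions: TypeError for non-path input and FileExistsError when the wrong kind of entry already occupies the path.

```python
import tempfile
from pathlib import Path


def ensure_directory_sketch(path):
    """Hypothetical stand-in: create the directory if it is missing."""
    if not isinstance(path, (str, Path)):
        raise TypeError("path must be a str or pathlib.Path")
    target = Path(path)
    if target.exists() and not target.is_dir():
        raise FileExistsError(f"{target} exists and is not a directory")
    target.mkdir(parents=True, exist_ok=True)
    return target


def ensure_file_sketch(path, create_parent_dirs=False):
    """Hypothetical stand-in: create an empty file if it is missing."""
    if not isinstance(path, (str, Path)):
        raise TypeError("path must be a str or pathlib.Path")
    target = Path(path)
    if target.is_dir():
        raise FileExistsError(f"{target} exists and is a directory")
    if create_parent_dirs:
        target.parent.mkdir(parents=True, exist_ok=True)
    target.touch(exist_ok=True)  # leaves existing content untouched
    return target


with tempfile.TemporaryDirectory() as tmp:
    out_dir = ensure_directory_sketch(Path(tmp) / "data" / "output")
    out_file = ensure_file_sketch(
        out_dir / "nested" / "results.txt", create_parent_dirs=True
    )
    dir_ok = out_dir.is_dir()
    file_ok = out_file.is_file()
```

With create_parent_dirs left at False, touching a file under a missing parent raises FileNotFoundError, which is a subclass of OSError and so fits the documented error behaviour.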
- geoprior.utils.generic_utils.ensure_directory_exists(path)[source]
Ensure that a directory exists at the given path, creating it if needed.
This function checks whether the provided path exists and is a directory. If the path does not exist, it attempts to create the directory (including any necessary parent directories). If a file with the same name already exists, or if creation fails, an exception is raised.
- Parameters:
path (
strorpathlib.Path) – The filesystem path for which to ensure directory existence. Can be either a string or a pathlib.Path object.- Returns:
A Path object pointing to the existing (or newly created) directory.
- Return type:
- Raises:
TypeError – If path is not a string or pathlib.Path.
FileExistsError – If a file (not a directory) already exists at path.
OSError – If the directory cannot be created for any other reason (e.g., insufficient permissions).
Examples
>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ensure_directory_exists
>>> output_dir = ensure_directory_exists("data/output")
>>> isinstance(output_dir, Path)
True
>>> # The directory "data/output" now exists on disk.
Notes
Uses pathlib.Path.mkdir(…, parents=True, exist_ok=True) under the hood for cross-platform compatibility.
If path already exists as a directory, this function returns immediately without modifying it.
See also
pathlib.Path.mkdirMethod to create a directory.
os.makedirsLegacy function for creating directories recursively.
- geoprior.utils.generic_utils.verify_identical_items(list1, list2, mode='unique', ops='check_only', error='raise', objname=None)[source]
Check if two lists contain identical elements according to the specified mode.
In “unique” mode, the function compares the unique elements in each list. In “ascending” mode, it compares elements pairwise in order.
- Parameters:
list1 (
list) – The first list of items. list2 (
list) – The second list of items.mode (
{'unique', 'ascending'}, default "unique") – The mode of comparison:
"unique": Compare unique elements (order-insensitive).
"ascending": Compare each element pairwise in order.
ops (
{'check_only', 'validate'}, default"check_only") – If “check_only”, returns True/False indicating a match. If “validate”, returns the validated list.error (
{'raise', 'warn', 'ignore'}, default"raise") – Specifies how to handle mismatches.objname (
str, optional) – A name to include in error messages.
- Returns:
Depending on ops, returns True/False or the validated list.
- Return type:
Examples
>>> from geoprior.utils.generic_utils import verify_identical_items
>>> list1 = [0.1, 0.5, 0.9]
>>> list2 = [0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="unique", ops="validate")
[0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="ascending", ops="check_only")
True
Notes
In “ascending” mode, both lists must have the same length, and the function compares each corresponding pair of elements. In “unique” mode, the function uses the set of unique values for comparison. If the lists contain mixed types, the function attempts to compare their string representations.
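The difference between the two comparison modes can be shown with a short sketch. identical_items_sketch is a hypothetical simplification covering only the boolean check; it leaves out the ops, error, and objname handling as well as the string-representation fallback for mixed types.

```python
def identical_items_sketch(list1, list2, mode="unique"):
    """Hypothetical reduction of the two comparison modes to a boolean."""
    if mode == "unique":
        # Order-insensitive: only the set of distinct values matters.
        return set(list1) == set(list2)
    if mode == "ascending":
        # Order-sensitive: same length and equal element by element.
        return len(list1) == len(list2) and all(
            a == b for a, b in zip(list1, list2)
        )
    raise ValueError(f"unknown mode: {mode!r}")


same_values = identical_items_sketch([0.1, 0.5, 0.5], [0.5, 0.1], mode="unique")
same_order = identical_items_sketch([0.1, 0.5], [0.5, 0.1], mode="ascending")
```

Here same_values is True because duplicates and ordering are ignored, while same_order is False because the first pair already differs.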
- geoprior.utils.generic_utils.vlog(message, verbose=None, level=3, depth='auto', mode=None, vp=True, logger=None, **kws)[source]
Log or print messages with optional indentation and bracketed tags.
This function, vlog, allows conditional logging or printing of messages based on a global or passed-in verbose level. By default, it behaves differently depending on whether mode is 'log' or 'naive'. When \(mode = 'log'\), the message is printed only if \(\text{verbose} \geq \text{level}\). Otherwise, for \(mode\) in [None, 'naive'], the verbosity threshold leads to various bracketed prefixes (e.g. [INFO], [DEBUG], [TRACE]) unless the message already contains such a prefix.
(10) \[\text{indentation} = 2 \times \text{depth}\]
where \(\text{depth}\) is either manually specified or auto-derived based on level (1 = ERROR, 2 = WARNING, 3 = INFO, 4/5 = DEBUG, 6/7 = TRACE).
- Parameters:
message (
str) – The text to be printed or logged.verbose (
int, optional) – Overall verbosity threshold. IfNone, it looks for a global variable namedverbose. Default isNone.level (
int, default3) –Severity or importance level of the message. Commonly:
1 = ERROR
2 = WARNING
3 = INFO
4,5 = DEBUG
6,7 = TRACE
depth (
int or str, default "auto") – Indentation level used for the printed message. If "auto", the depth is computed from level. mode (
str, optional) – Determines logging mode. If set to 'log', prints messages only if \(\text{verbose} \geq \text{level}\). Otherwise (if None or 'naive'), it follows a custom logic driven by verbose. vp (
bool, defaultTrue) – IfTrue, the function automatically prepends bracketed tags (e.g. [INFO]) unless the message already contains one of [INFO], [DEBUG], [ERROR], [WARNING], or [TRACE].logger (
logging.LoggerorCallable[[str],None], optional) –Custom sink that receives the already-formatted message string.
If you pass a standard logging.Logger instance, the message is routed through logger.info.
If you supply any callable that accepts a single str (e.g. a GUI text-append function), that callable is invoked directly.
Defaults to print, which writes to stdout.
kws (
Logging instance, optional) – For future extensions.
- Returns:
This function does not return anything. It either prints the message to stdout or omits it, depending on verbose, level, and mode.
- Return type:
Notes
This function is helpful for selectively displaying or logging messages in applications that adapt to the user’s required verbosity. By default, each level has a specific bracketed tag and an auto indentation depth.
Examples
>>> from geoprior.utils.generic_utils import vlog
>>> # Example with mode='log'
>>> # This prints only if global or passed-in
>>> # verbose >= 4.
>>> vlog("Check debugging details.", verbose=3,
...      level=4, mode='log')
>>> # Example with mode='naive'
>>> # If verbose=2, it displays as [INFO] prefixed.
>>> vlog("Loading data...", verbose=2, mode='naive')
See also
globalsUsed to retrieve the fallback verbose value if not explicitly passed.
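The tagging and indentation rules described for vlog can be sketched as follows. This is a hedged illustration only: format_message_sketch and should_emit_log_mode are hypothetical, the level-to-tag table is taken from the parameter description, and the rule for auto-deriving depth is not reproduced because it is not spelled out here (depth is passed explicitly).

```python
# Level-to-tag table as listed in the `level` parameter description.
LEVEL_TAGS = {
    1: "ERROR", 2: "WARNING", 3: "INFO",
    4: "DEBUG", 5: "DEBUG", 6: "TRACE", 7: "TRACE",
}


def format_message_sketch(message, level=3, depth=0, vp=True):
    """Hypothetical formatter: optional tag plus 2 * depth spaces of indent."""
    known_tags = [f"[{name}]" for name in set(LEVEL_TAGS.values())]
    # Prepend a tag only when requested and not already present.
    if vp and not any(tag in message for tag in known_tags):
        message = f"[{LEVEL_TAGS.get(level, 'INFO')}] {message}"
    return " " * (2 * depth) + message


def should_emit_log_mode(verbose, level):
    """In mode='log', a message is shown only if verbose >= level."""
    return verbose >= level


tagged = format_message_sketch("Loading data...", level=3, depth=1)
untouched = format_message_sketch("[DEBUG] cache hit", level=3, depth=0)
suppressed = not should_emit_log_mode(verbose=3, level=4)
```

In this sketch tagged becomes "  [INFO] Loading data..." with two leading spaces, while untouched keeps its existing [DEBUG] prefix.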
- geoprior.utils.generic_utils.detect_dt_format(series)[source]
Detect the datetime format of a pandas Series containing datetime values.
This function inspects a non-null sample from the datetime Series and infers the format string based on its components (year, month, day, hour, minute, and second). It returns a format string that can be used with
strftime. For example, if the sample indicates only a year is relevant, it returns"%Y"; if full date information is present, it returns"%Y-%m-%d"; and if time details are also present, it extends the format accordingly.- Parameters:
series (
pandas.Series) – A Series containing datetime values (dtype datetime64).- Returns:
A datetime format string (e.g.,
"%Y","%Y-%m-%d", or"%Y-%m-%d %H:%M:%S") that represents the resolution of the data.- Return type:
Examples
>>> from geoprior.utils.generic_utils import detect_dt_format
>>> import pandas as pd
>>> dates = pd.to_datetime(['2023-01-01', '2024-01-01', '2025-01-01'])
>>> fmt = detect_dt_format(pd.Series(dates))
>>> print(fmt)
%Y
Notes
The detection logic checks if month, day, hour, minute, and second are all default values (e.g., month == 1, day == 1, hour == 0, etc.) and infers the most compact format that still represents the data accurately.
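The detection logic in the note above can be approximated with the standard library. The sketch works on a single datetime sample instead of a pandas Series, and it only distinguishes the three formats named in the Returns section; infer_format_sketch is a hypothetical name, and any intermediate formats the real function may return are not covered.

```python
from datetime import datetime


def infer_format_sketch(sample):
    """Hypothetical reduction: pick the most compact of three formats."""
    has_time = (sample.hour, sample.minute, sample.second) != (0, 0, 0)
    if has_time:
        return "%Y-%m-%d %H:%M:%S"
    # Month and day at their defaults (January 1st) suggest yearly data.
    if (sample.month, sample.day) != (1, 1):
        return "%Y-%m-%d"
    return "%Y"


yearly = infer_format_sketch(datetime(2023, 1, 1))
daily = infer_format_sketch(datetime(2023, 5, 17))
timed = infer_format_sketch(datetime(2023, 5, 17, 8, 30))
```

A known limitation of this heuristic, visible in the sketch, is that a genuinely daily series sampled on January 1st at midnight is indistinguishable from yearly data.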
- geoprior.utils.generic_utils.get_actual_column_name(df, tname=None, actual_name=None, error='raise', default_to=None)[source]
Determines the actual target column name in the given DataFrame.
- Parameters:
df (
pandas.DataFrame) – The DataFrame containing the target column.tname (
str, optional) – The base target name (e.g., “subsidence”). If not found in the DataFrame, it will attempt to find a matching column using “<tname>_actual” format.actual_name (
str, optional) – If provided, this name will be returned as the actual target column name.error (
{'raise', 'warn', 'ignore'}, default 'raise') – Specifies how to handle the case when no valid column is found:
‘raise’: Raises a ValueError.
‘warn’: Issues a warning and returns None.
‘ignore’: Silently returns None.
- Returns:
The determined actual column name, or None if no match is found and error=’warn’ or error=’ignore’.
- Return type:
- Raises:
ValueError – If no valid target column is found and error=’raise’.
Examples
>>> import pandas as pd
>>> from geoprior.utils.generic_utils import get_actual_column_name
>>> df = pd.DataFrame({'subsidence_actual': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence_actual'
>>> df = pd.DataFrame({'subsidence': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence'
>>> df = pd.DataFrame({'actual': [1, 2, 3]})
>>> get_actual_column_name(df)
'actual'
>>> df = pd.DataFrame({'measurement': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence", error="warn")
Warning: Could not determine the actual target column in the DataFrame.
None
- geoprior.utils.generic_utils.transform_contributions(contributions, to_percent=True, normalize=False, norm_range=(0, 1), scale_type=None, zero_division='warn', epsilon=1e-06, log_transform=False)[source]
Converts the feature contributions either to a direct percentage, normalizes them to a custom range, or applies a scaling strategy based on the chosen parameters.
- Parameters:
contributions (dict) – A dictionary where keys are feature names and values are the feature contributions. Each value is expected to be a numerical value representing the contribution of the respective feature.
to_percent (bool, optional, default=True) – Whether to convert the contributions to percentages. If True, each value in contributions will be multiplied by 100. This is useful when contributions are given in decimal form but are expected as percentages.
normalize (bool, optional, default=False) – Whether to normalize the contributions using min-max scaling. If True, the values will be scaled to the range defined in norm_range.
norm_range (tuple, optional, default=(0, 1)) – A tuple specifying the range (min, max) for normalization. This range is applied when normalize is set to True. The contributions will be rescaled so that the minimum value maps to norm_range[0] and the maximum value maps to norm_range[1].
scale_type (str, optional, default=None) – The scaling strategy. Options include:
'zscore': Performs Z-score normalization.
'log': Applies a logarithmic transformation to the data.
If None, no scaling is applied.
zero_division (str, optional, default='warn') – Defines how to handle zero or missing values in the contributions. Options include:
'skip': Skips zero values (no modification).
'warn': Issues a warning if zero values are found.
'replace': Replaces zeros with a small value defined by epsilon to avoid division by zero or undefined results.
epsilon (float, optional, default=1e-6) – A small value used to replace zeros when zero_division is set to 'replace'. This prevents division by zero errors during transformations like Z-score or log transformation.
log_transform (bool, optional, default=False) – Whether to apply a logarithmic transformation to the contributions. If True, it applies the natural logarithm to each value in the contributions dictionary. Only positive values are valid for log transformation, and zero values are either skipped or replaced based on the zero_division parameter.
- Returns:
dictA dictionary with feature names as keys and the transformed feature contributions as values. The transformation is applied according to the chosen parameters.
See also
numpy.meanCompute the arithmetic mean of an array.
numpy.stdCompute the standard deviation of an array.
Notes
If scale_type='zscore', each contribution is standardized as:
\[\frac{X - \mu}{\sigma}\]
where \(X\) is the contribution, \(\mu\) is the mean of the contributions, and \(\sigma\) is the standard deviation of the contributions.
If log_transform=True, the function applies the natural logarithm:
(11) \[\text{log}(X) \text{ for } X > 0\]
The zero_division parameter handles zero values by either skipping, warning, or replacing them with a small value (epsilon).
Examples
>>> from geoprior.utils.generic_utils import transform_contributions
>>> contributions = {
...     'GWL': 2.226836617133828,
...     'rainfall_mm': 12.398293851061492,
...     'normalized_seismic_risk_score': 0.9402759347406523,
...     'normalized_density': 4.806074194258057,
...     'density_concentration': 5.666943330566496e-06,
...     'geology': 1.2798872011280326e-05,
...     'density_tier': 1.044039559604414e-05,
...     'rainfall_category': 0.0
... }
>>> transform_contributions(contributions, to_percent=True, normalize=True)
>>> transform_contributions(contributions, to_percent=False, scale_type='zscore')
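The min-max and Z-score options documented for this function can be written out in a few lines. This is a sketch under stated assumptions, not the function's source: min_max_sketch and zscore_sketch are hypothetical, the population standard deviation is assumed (matching the numpy.std default referenced in See also), and the handling of degenerate inputs is a choice made for the example.

```python
from statistics import mean, pstdev


def min_max_sketch(contributions, norm_range=(0.0, 1.0)):
    """Rescale values so min -> norm_range[0] and max -> norm_range[1]."""
    low, high = norm_range
    vmin, vmax = min(contributions.values()), max(contributions.values())
    span = vmax - vmin
    if span == 0:
        # Assumption for this sketch: constant inputs map to the lower bound.
        return {name: low for name in contributions}
    return {
        name: low + (value - vmin) / span * (high - low)
        for name, value in contributions.items()
    }


def zscore_sketch(contributions):
    """Standardize values as (X - mu) / sigma with population sigma."""
    mu = mean(contributions.values())
    sigma = pstdev(contributions.values())
    if sigma == 0:
        raise ValueError("standard deviation is zero; cannot standardize")
    return {name: (value - mu) / sigma for name, value in contributions.items()}


example = {"a": 1.0, "b": 2.0, "c": 3.0}
scaled = min_max_sketch(example)                 # a -> 0.0, b -> 0.5, c -> 1.0
percent_like = min_max_sketch(example, (0.0, 100.0))
standardized = zscore_sketch(example)            # b -> 0.0, a == -c
```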
- geoprior.utils.generic_utils.exclude_duplicate_kwargs(func, existing_kwargs, user_kwargs)[source]
Prevents the user from overriding existing parameters in a target function. The method exclude_duplicate_kwargs checks both developer-specified and function-level parameter names to exclude them from user_kwargs.
(12) \[\text{final\_kwargs} = \{\,(k, v) \in \text{user\_kwargs} \,\mid\, k \notin \text{protected\_params}\,\}\]
- Parameters:
func (callable) – The target function whose valid parameters are checked. It uses Python’s introspection to gather the acceptable parameter names.
existing_kwargs (dict or list) – Developer-defined parameters to protect. Can be:
A dictionary of parameter-value pairs (e.g., {'ax': ax_obj, 'data': df}) whose keys are excluded from user overrides.
A list of parameter names (e.g., ['ax', 'data']) to protect from user overrides.
user_kwargs (dict) – The user-supplied keyword arguments that are candidates for merging with existing_kwargs. This dictionary is filtered to remove collisions with protected parameters.
- Returns:
dict – A filtered dictionary of user-defined arguments that do not overlap with protected parameters.
- Return type:
See also
inspect.signatureUsed to introspect function parameters.
filter_valid_kwargsAnother inline function that discards user params not valid for a given function.
Notes
By default, if existing_kwargs is a dictionary, its keys are treated as protected parameter names. If it’s a list, those items are protected. The function signature of func is also used to verify that only recognized parameters are protected. Keyword-filtering patterns like this are covered in Beazley and Jones [26].
Examples
>>> from geoprior.utils.generic_utils import exclude_duplicate_kwargs
>>> import seaborn as sns
>>> # Developer has some base kwargs
>>> base_kwargs = {
...     'x': 'species',
...     'y': 'sepal_length',
...     'palette': 'viridis'
... }
>>> # User tries to override 'x' with new param
>>> user_args = {
...     'x': 'petal_width',
...     'color': 'red'
... }
>>> # Filter out duplicates
>>> safe_args = exclude_duplicate_kwargs(
...     sns.scatterplot,
...     base_kwargs,
...     user_args
... )
>>> safe_args
{'color': 'red'}
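The filtering rule in equation (12) reduces to a dictionary comprehension. The sketch below is a simplified, hypothetical version (exclude_protected_sketch) that only applies the protected-name filter; it does not perform the signature inspection of func that the real helper is described as doing.

```python
def exclude_protected_sketch(existing_kwargs, user_kwargs):
    """Keep only user kwargs whose names are not protected."""
    # Iterating a dict yields its keys, so dicts and lists of names both work.
    protected = set(existing_kwargs)
    return {
        name: value
        for name, value in user_kwargs.items()
        if name not in protected
    }


base_kwargs = {"x": "species", "y": "sepal_length", "palette": "viridis"}
user_args = {"x": "petal_width", "color": "red"}

from_dict = exclude_protected_sketch(base_kwargs, user_args)
from_list = exclude_protected_sketch(["x", "y"], user_args)
```

Both calls drop the user's attempt to override 'x' and keep 'color', which lets the caller merge the result with base_kwargs without a duplicate-keyword error.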
- geoprior.utils.generic_utils.reorder_columns(df, columns, pos='end')[source]
Reorder columns in a DataFrame by moving specified columns to a chosen position.
This function locates columns in the original DataFrame df and rearranges them based on the parameter pos. If pos is “end”, columns are appended to the end. If “begin” or “start”, they are placed at the front. If “center”, they are inserted at the midpoint of the remaining columns.
- Parameters:
df (
pandas.DataFrame) – The input DataFrame to be modified.columns (
stroriterableofstr) – A single column name or multiple column names to reposition. If a single string is given, it is converted to a list with one element.pos (
str, int, or float, default "end") – Determines the target placement:
"end": Append after all other columns.
"begin" or "start": Prepend at the start.
"center": Insert at the midpoint of remaining columns.
integer or float: Insert at zero-based index among the remaining columns. If out of bounds, the original DataFrame is returned unchanged.
- Returns:
A new DataFrame with columns moved as specified by pos.
- Return type:
- `reorder_columns_in`
This method rearranges columns without altering values or data order beyond column placement.
Notes
The function checks if columns exist in df, ignoring columns not present.
A warning is issued if the position is beyond the range of valid indices.
Negative indices for integer pos are converted to positive by adding the total number of remaining columns.
(13) \[i_{\text{center}} = \left\lfloor \frac{|R|}{2} \right\rfloor,\]
where \(|R|\) is the number of remaining columns after removing the target columns. For integer or float pos, the target columns are inserted at index \(\lfloor pos \rfloor\) among the remaining columns. Column-order management follows common DataFrame practices discussed in McKinney [27].
Examples
>>> from geoprior.utils.generic_utils import reorder_columns
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'id': [1, 2, 3],
...     'latitude': [10.1, 10.2, 10.3],
...     'landslide': [0, 1, 0],
...     'longitude': [20.1, 20.2, 20.3]
... })
>>> # Move 'landslide' to the end (default)
>>> reorder_columns(data, 'landslide', pos="end")
   id  latitude  longitude  landslide
0   1      10.1       20.1          0
1   2      10.2       20.2          1
2   3      10.3       20.3          0
See also
pandas.DataFrame.reindexPandas method for reindexing or reordering columns more generally.
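The placement rules, including the center index from equation (13), can be illustrated on a plain list of column names. reorder_names_sketch is a hypothetical simplification: it returns the new column order only, skips the warning on out-of-range positions, and assumes pos is one of the documented strings or a number.

```python
import math


def reorder_names_sketch(all_columns, columns, pos="end"):
    """Return the column order after moving `columns` to `pos`."""
    if isinstance(columns, str):
        columns = [columns]
    # Columns not present in the table are ignored.
    targets = [name for name in columns if name in all_columns]
    remaining = [name for name in all_columns if name not in targets]

    if pos == "end":
        index = len(remaining)
    elif pos in ("begin", "start"):
        index = 0
    elif pos == "center":
        index = len(remaining) // 2  # floor(|R| / 2)
    else:
        index = math.floor(pos)
        if index < 0:
            index += len(remaining)  # negative positions count from the end
        if not 0 <= index <= len(remaining):
            return list(all_columns)  # out of bounds: order unchanged
    return remaining[:index] + targets + remaining[index:]


names = ["id", "latitude", "landslide", "longitude"]
at_end = reorder_names_sketch(names, "landslide")
at_center = reorder_names_sketch(names, "landslide", pos="center")
at_start = reorder_names_sketch(names, "landslide", pos="begin")
at_index = reorder_names_sketch(names, "landslide", pos=1.7)
```

With a DataFrame, the resulting list could then be applied as df[new_order], which reorders columns without touching the row data.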
- geoprior.utils.generic_utils.find_id_column(df, strategy='naive', regex_pattern=None, uniqueness_threshold=0.95, errors='raise', empty_as_none=True, as_list=False, case_sensitive=False, as_frame=False)[source]
Identify potential ID column(s) in a pandas DataFrame using multiple heuristic strategies.
The function examines column names and/or data properties to detect columns likely to serve as unique identifiers. This is particularly useful for large datasets where the ID field is not explicitly labeled, and for quick scanning of possible key columns.
- Parameters:
df (
pandas.DataFrame) – The input DataFrame in which to search for potential ID columns.strategy (
{'naive', 'exact', 'dtype', 'regex', 'prefix_suffix'}, default 'naive') – Defines the logic for detecting ID columns:
exact: Checks for a column name that exactly matches id (case sensitivity controlled by case_sensitive).
naive: Searches for columns where id is part of the name (e.g., location_id), subject to case sensitivity.
prefix_suffix: Considers columns prefixed or suffixed with id or _id.
dtype: Examines columns having data types commonly used for IDs (integer, string, or object) and checks if they show high uniqueness via \(\text{uniqueness\_ratio} \geq \text{uniqueness\_threshold}\).
regex: Uses a custom regular expression regex_pattern to find matches in column names.
regex_pattern (
str, optional) – Required if strategy is ‘regex’. The pattern is compiled via re.compile, with case sensitivity determined by case_sensitive.
uniqueness_threshold (float, default 0.95) – For the dtype strategy, columns are flagged as ID candidates if the ratio:
(14) \[r = \frac{\text{unique\_values}}{\text{non\_NA\_rows}}\]
satisfies \(r \geq \text{uniqueness\_threshold}\), or if the number of unique values equals the number of non-null rows.
errors (
{'raise', 'warn', 'ignore'}, default 'raise') – How to handle no-match cases:
raise: Raises a ValueError.
warn: Issues a UserWarning and returns based on as_frame or empty_as_none.
ignore: Returns an empty result based on the same parameters without warning.
empty_as_none (
bool, default True) – Applies only if as_frame is False. Defines whether to return None (if True) or an empty list (if False) when no ID column is found and errors is ‘warn’ or ‘ignore’. as_list (
bool, default False) – If True, return all matched columns. If False, return only the first match. Affects both name returns and DataFrame returns. case_sensitive (
bool, default False) – If False, comparisons (including regex) are performed in a case-insensitive manner. as_frame (
bool, default False) – If True, return the matched columns as a pandas DataFrame. If as_list is True, it may include multiple columns. If no column is found, returns an empty DataFrame (if errors is ‘warn’ or ‘ignore’).
- Returns:
Depends on as_frame, as_list, and the number of matching columns:
as_frame=False, as_list=False: returns the first match as a string, or None/[].
as_frame=False, as_list=True: returns all matching column names as a list of strings.
as_frame=True, as_list=False: returns a DataFrame with the first matched column. If no match is found, an empty DataFrame may be returned.
as_frame=True, as_list=True: returns a DataFrame with all matched columns included.
- Return type:
strorList[str]orpandas.DataFrameorNone
Notes
For the dtype strategy, integer, string, and object columns are inspected. The function calculates a uniqueness ratio and compares it against uniqueness_threshold.
Negative or zero thresholds are invalid, as are values above 1.
If the DataFrame has no columns or is empty, the behavior is determined by errors.
The relational-model motivation for schema-oriented column handling goes back to Codd [28].
Examples
>>> from geoprior.utils.generic_utils import find_id_column
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'ID_code': [101, 102, 103],
...     'Name': ['Alice', 'Bob', 'Charlie'],
...     'value': [10, 20, 30]
... })
>>> # Example using the 'naive' strategy
>>> col = find_id_column(data, strategy='naive')
>>> print(col)  # Might return 'ID_code'
>>> # Example with as_list=True
>>> cols = find_id_column(data, strategy='naive',
...                       as_list=True)
>>> print(cols)  # ['ID_code']
See also
re.compile: The regex compilation method used when strategy='regex'.
pandas.api.types.is_integer_dtype: Checks integer type.
pandas.api.types.is_string_dtype: Checks string type.
pandas.api.types.is_object_dtype: Checks object type.
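The uniqueness ratio in equation (14) can be illustrated with a few lines of pandas. This is a minimal sketch, not the implementation of find_id_column: the helper names uniqueness_ratio and id_candidates are hypothetical, and the sketch inspects every column rather than only integer, string, and object columns.

```python
import pandas as pd


def uniqueness_ratio(series: pd.Series) -> float:
    """Return unique_values / non_NA_rows, or 0.0 for an all-NA column."""
    non_na = series.dropna()
    if non_na.empty:
        return 0.0
    return non_na.nunique() / len(non_na)


def id_candidates(df: pd.DataFrame, threshold: float = 0.95) -> list:
    """Hypothetical helper: list columns whose ratio meets the threshold."""
    if not 0 < threshold <= 1:
        raise ValueError("threshold must be in (0, 1]")
    return [c for c in df.columns if uniqueness_ratio(df[c]) >= threshold]


data = pd.DataFrame({
    "ID_code": [101, 102, 103, 104],
    "Name": ["Alice", "Bob", "Alice", "Bob"],
    "value": [10, 20, 20, 30],
})
candidates = id_candidates(data)
```

Here only ID_code has one distinct value per non-null row, so it is the only column that reaches the default threshold.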
- geoprior.utils.generic_utils.check_group_column_validity(df, group_col, ops='check_only', max_unique=10, auto_bin=False, bins=4, error='warn', bin_labels=None, verbose=True)[source]
Validate a grouping column for categorical-style use and optionally bin it.
- Parameters:
df (pandas.DataFrame) – Input DataFrame holding the grouping column.
group_col (str) – Name of the candidate grouping column in df.
ops ({'check_only', 'binning', 'validate'}, optional) – Operation mode. Use "check_only" to return a boolean, "binning" to bin the column when needed and return a modified DataFrame, or "validate" to check validity while honoring error.
max_unique (int, optional) – Maximum number of unique numeric values allowed before the column is treated as too continuous for categorical use.
auto_bin (bool, optional) – Whether to auto-bin a numeric column when ops='binning'.
bins (int, optional) – Number of bins to create when binning is applied.
error ({'warn', 'raise', 'ignore'}, optional) – Policy used when validation fails.
bin_labels (list of str or None, optional) – Custom labels for generated bins.
verbose (bool, optional) – Whether to emit informational messages.
- Returns:
Returns a boolean for ops='check_only'. Otherwise returns a DataFrame, possibly with a transformed group_col.
- Return type:
Notes
When quantile binning is used, interval boundaries are derived from the numeric distribution of group_col.
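The quantile-binning note can be made concrete with pandas.qcut. The sketch below is an illustration under stated assumptions, not the body of check_group_column_validity: the function name bin_group_column is hypothetical, and it bins whenever a numeric column has more than max_unique distinct values.

```python
import numpy as np
import pandas as pd


def bin_group_column(df, group_col, max_unique=10, bins=4, bin_labels=None):
    """Hypothetical sketch: quantile-bin a numeric column that is too continuous."""
    out = df.copy()  # leave the caller's frame untouched
    col = out[group_col]
    too_continuous = (
        pd.api.types.is_numeric_dtype(col) and col.nunique() > max_unique
    )
    if too_continuous:
        # qcut derives interval edges from the column's own quantiles
        out[group_col] = pd.qcut(col, q=bins, labels=bin_labels)
    return out


frame = pd.DataFrame({"depth": np.arange(100, dtype=float)})
binned = bin_group_column(
    frame, "depth", bins=4, bin_labels=["q1", "q2", "q3", "q4"]
)
```

Because the edges come from quantiles, each of the four bins receives the same number of rows for this evenly spread column.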
- geoprior.utils.generic_utils.save_all_figures(output_dir='figures', prefix='figure', fmts=('png',), close=True, dpi=150, transparent=False, timestamp=True, verbose=True)[source]
Save all currently open Matplotlib figures to disk in specified formats.
- Parameters:
output_dir (str) – Directory where figures will be saved. Created if it does not exist.
prefix (str) – Filename prefix for each figure.
fmts (list or tuple of str) – File formats/extensions to use (e.g., ('png', 'pdf')).
close (bool) – Whether to close each figure after saving. Default is True.
dpi (int or None) – Resolution in dots per inch. None uses Matplotlib default.
transparent (bool) – Whether to save figures with transparent background.
timestamp (bool) – Append current timestamp (YYYYmmddTHHMMSS) to filenames.
verbose (bool) – Print progress messages.
- Returns:
List of saved file paths.
- Return type:
List[str]
Examples
>>> import matplotlib.pyplot as plt
>>> plt.figure(); plt.plot([1, 2, 3])
>>> from geoprior.utils.generic_utils import save_all_figures
>>> paths = save_all_figures(output_dir="plots", fmts=("png",))
>>> print(paths)
['plots/figure_1_20250521T153045.png']
- geoprior.utils.generic_utils.rename_dict_keys(data, param_to_rename=None, order='forward')[source]
Renames keys in the data dictionary based on the provided param_to_rename dictionary.
For each mapping in param_to_rename, the function checks whether the old key exists in data. If it does, the key is renamed to the new name given by the mapping. If the old key is not present in data, that mapping is skipped and the existing keys are left untouched. If no rename is required, the original dictionary is returned.
- Parameters:
data (
dict) – The dictionary whose keys may be renamed. The function will iterate over the keys of this dictionary and rename them according to the mapping provided in param_to_rename.param_to_rename (
dict, optional) – A dictionary mapping old keys to new keys. Each key in this dictionary represents an old key that may be found in data, and the corresponding value is the new key. If None, no renaming is performed. If a key in data matches an old key in param_to_rename, that key will be renamed.order (
str, {'forward', 'reverse'}) – Order for renaming keys in a flat dict:
forward (default): param_to_rename = {old_key: new_key}.
reverse: param_to_rename = {canonical_key: alias or (alias1, alias2, ...)}. The first alias found in data is moved under the canonical key. If the canonical key already exists, nothing is changed for that mapping.
- Returns:
The updated dictionary with keys renamed as per the param_to_rename mapping. If no keys need renaming, the original dictionary is returned.
- Return type:
- Raises:
ValueError – If param_to_rename is not a dictionary.
Examples
>>> from geoprior.utils.generic_utils import rename_dict_keys
Example 1: Renaming a key in the dictionary:
>>> data = {"subsidence": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}
Example 2: When the key is already valid (no change needed):
>>> data = {"subs_pred": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}
Example 3: When param_to_rename is None, no renaming is performed:
>>> data = {"subsidence": 100}
>>> rename_dict_keys(data)
{'subsidence': 100}
Notes
If param_to_rename is None, no renaming occurs, and the data dictionary is returned as is.
This function raises an error if param_to_rename is not a dictionary. Ensure that the parameter is a valid dictionary of old-to-new key mappings.
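The forward and reverse orders described for rename_dict_keys can be sketched in plain Python. This is an illustrative reimplementation under the documented semantics, not the library source; the name rename_keys is hypothetical, and the sketch always returns a shallow copy when a mapping is given.

```python
def rename_keys(data, mapping=None, order="forward"):
    """Hypothetical sketch of forward / reverse key renaming on a flat dict."""
    if mapping is None:
        return data  # nothing to rename
    if not isinstance(mapping, dict):
        raise ValueError("mapping must be a dict")
    out = dict(data)
    if order == "forward":
        # mapping = {old_key: new_key}
        for old, new in mapping.items():
            if old in out:
                out[new] = out.pop(old)
    elif order == "reverse":
        # mapping = {canonical_key: alias or (alias1, alias2, ...)}
        for canonical, aliases in mapping.items():
            if canonical in out:
                continue  # canonical key already present: leave as is
            if isinstance(aliases, str):
                aliases = (aliases,)
            for alias in aliases:
                if alias in out:
                    out[canonical] = out.pop(alias)
                    break  # only the first alias found is moved
    else:
        raise ValueError("order must be 'forward' or 'reverse'")
    return out


forward = rename_keys({"subsidence": 100}, {"subsidence": "subs_pred"})
reverse = rename_keys({"gwl": 2.5}, {"GWL": ("head", "gwl")}, order="reverse")
```

In reverse order the mapping reads from the canonical name to its aliases, which is convenient when several legacy spellings must collapse onto one key.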
- geoprior.utils.generic_utils.normalize_time_column(df, time_col, datetime_col='datetime_temp', year_col='year_int', drop_orig=False)[source]
Normalize a time column into a datetime column and an integer year.
The input column may contain integer years, strings, or existing pandas datetime values. The function creates datetime_col with parsed timestamps and year_col with the extracted integer year. When drop_orig=True, the original time_col is removed and datetime_col is renamed back to time_col.
- Parameters:
df (pandas.DataFrame) – Input DataFrame containing a time column named time_col.
time_col (str) – Name of the column to normalize.
datetime_col (str, default 'datetime_temp') – Name of the parsed datetime column.
year_col (str, default 'year_int') – Name of the extracted integer year column.
drop_orig (bool, default False) – If True, drop the original time_col after parsing and rename datetime_col back to time_col.
- Returns:
A copy of df with the parsed datetime column and integer year column.
- Return type:
- Raises:
ValueError – If time_col is missing or parsing fails for any entry.
TypeError – If df is not a pandas DataFrame.
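A rough picture of the normalization, assuming integer years are parsed with a "%Y" format and everything else goes through pandas.to_datetime. The function name normalize_time is hypothetical, and the sketch omits drop_orig and the type checks of the documented function.

```python
import pandas as pd


def normalize_time(df, time_col, datetime_col="datetime_temp", year_col="year_int"):
    """Hypothetical sketch: add a parsed datetime column and an integer year."""
    if time_col not in df.columns:
        raise ValueError(f"{time_col!r} not found in DataFrame")
    out = df.copy()
    col = out[time_col]
    if pd.api.types.is_integer_dtype(col):
        # integer years such as 2019 become 2019-01-01
        parsed = pd.to_datetime(col.astype(str), format="%Y")
    else:
        parsed = pd.to_datetime(col)
    out[datetime_col] = parsed
    out[year_col] = parsed.dt.year.astype(int)
    return out


years = normalize_time(pd.DataFrame({"t": [2019, 2020]}), "t")
dates = normalize_time(pd.DataFrame({"t": ["2021-03-01", "2022-07-15"]}), "t")
```

Both inputs end up with the same pair of derived columns, so downstream code can rely on year_int regardless of how the time column was originally stored.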
- geoprior.utils.generic_utils.select_mode(mode=None, default='pihal_like', canonical=None)[source]
Resolve a user-supplied mode alias to a canonical value.
- Parameters:
mode (str or None, optional) – Case-insensitive mode alias. Accepted values include 'pihal', 'pihal_like', 'tft', 'tft_like', or None to fall back to default.
default ({'pihal', 'tft'}, optional) – Canonical value returned when mode is None.
canonical (dict or list or None, optional) – Custom alias mapping. A dictionary maps input strings to canonical values. A list is treated as an identity mapping for its items.
- Returns:
Canonical string corresponding to the resolved mode.
- Return type:
- Raises:
ValueError – If mode does not match any accepted alias.
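Alias resolution of this kind is usually a small lookup table. The sketch below assumes the canonical values are 'pihal_like' and 'tft_like' (an assumption based on the signature default, not confirmed by the text above); the function name resolve_mode and the built-in alias table are hypothetical.

```python
def resolve_mode(mode=None, default="pihal_like", canonical=None):
    """Hypothetical sketch: map a case-insensitive alias to a canonical value."""
    if canonical is None:
        # Assumed alias table; the real mapping may differ.
        aliases = {
            "pihal": "pihal_like",
            "pihal_like": "pihal_like",
            "tft": "tft_like",
            "tft_like": "tft_like",
        }
    elif isinstance(canonical, dict):
        aliases = {str(k).lower(): v for k, v in canonical.items()}
    else:
        # a list is treated as an identity mapping for its items
        aliases = {str(v).lower(): v for v in canonical}
    if mode is None:
        return default
    key = str(mode).strip().lower()
    if key not in aliases:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(aliases)}")
    return aliases[key]


resolved = resolve_mode("TFT")
fallback = resolve_mode(None)
custom = resolve_mode("fast", canonical={"fast": "quick_mode"})
```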
- geoprior.utils.generic_utils.normalize_model_inputs(*data)[source]
- geoprior.utils.generic_utils.print_config_table(sections, title=None, table_width=None, sort_keys=True, key_col_fraction=0.35, max_value_length=200, log_fn=None)[source]
Pretty-print configuration or hyperparameters as a key/value table.
This helper is intended for CLI scripts (Stage-1, training, tuning) so that the user can quickly inspect which parameters are actually in effect.
- Parameters:
sections (dict or sequence of (str, dict)) – If a single dict is passed, all key/value pairs are printed in one block. If a sequence is passed, it must contain (name, params) tuples, where name is a section label (e.g. "Physics") and params is a dict mapping parameter names to values.
title (str, optional) – Optional title displayed above the table (centered).
table_width (int, optional) – Total width of the printed table. If None, the function tries to use geoprior.api.util.get_table_size(). If that fails, it falls back to the terminal width (via shutil.get_terminal_size) or 80 characters.
sort_keys (bool, default True) – Whether to sort parameter names alphabetically within each section.
key_col_fraction (float, default 0.35) – Fraction of the table width allocated to the parameter-name column. The remainder is used for the value column.
max_value_length (int, default 200) – Maximum number of characters kept from the stringified value. Longer values are truncated with an ellipsis ("...") before being wrapped onto multiple lines.
log_fn (callable, optional) – Function used to emit lines (defaults to print()). This allows capturing the table in logs if needed.
- Returns:
The full rendered table as a single string. It is always emitted via log_fn as a side effect.
- Return type:
Notes
Nested containers (lists, tuples, dicts) are rendered in a compact one-line form and then wrapped to fill the value column.
This function is intentionally lightweight and does not depend on external tabulation libraries, so it can be safely used in lightweight Stage-1 / Stage-2 scripts.
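The layout rules above (a key column sized by key_col_fraction, truncation at max_value_length, wrapping of long values) can be approximated with the standard library alone. This is a simplified sketch for a single dict; render_table and its " | " separator are hypothetical and do not reproduce the real table style or section handling.

```python
import textwrap


def render_table(params, width=40, key_col_fraction=0.35,
                 max_value_length=200, log_fn=print):
    """Hypothetical sketch: render a dict as a two-column key/value table."""
    key_w = max(1, int(width * key_col_fraction))
    val_w = max(1, width - key_w - 3)  # 3 characters for the " | " separator
    lines = []
    for key in sorted(params):
        text = str(params[key])
        if len(text) > max_value_length:
            text = text[: max_value_length - 3] + "..."
        # wrap long values onto continuation rows with an empty key cell
        chunks = textwrap.wrap(text, val_w) or [""]
        for i, chunk in enumerate(chunks):
            label = key if i == 0 else ""
            lines.append(f"{label:<{key_w}} | {chunk}")
    table = "\n".join(lines)
    log_fn(table)  # emitted as a side effect, then returned
    return table


captured = []
table = render_table(
    {"lr": 0.001, "layers": [64, 64, 32]}, width=30, log_fn=captured.append
)
```

Passing a list's append method as log_fn shows how the rendered text can be captured instead of printed.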
This module contains reusable helpers for paths, result folders, configuration display, and figure-saving behavior.
Geo/spatial helpers#
Geospatial utility helpers for GeoPrior workflows.
- geoprior.utils.geo_utils.augment_city_spatiotemporal_data(df, city, mode='interpolate', group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_config=None, augmentation_config=None, target_name=None, interpolate_target=False, verbose=True, coordinate_precision=None, savefile=None)[source]
Apply grouped spatiotemporal augmentation with city-aware defaults.
This is a convenience wrapper around augment_spatiotemporal_data. It validates the requested city, optionally rounds coordinates before grouping, and forwards interpolation and augmentation configuration dictionaries.
- Parameters:
df (pandas.DataFrame) – Input DataFrame containing spatial, temporal, and feature columns.
city ({'nansha', 'zhongshan'}) – City identifier used for validation and defaults.
mode ({'interpolate', 'augment_features', 'both'}, optional) – Processing mode forwarded to augment_spatiotemporal_data.
group_by_cols (list of str or None, optional) – Grouping columns for interpolation.
time_col (str or None, optional) – Time column used for interpolation.
value_cols_interpolate (list of str or None, optional) – Columns to interpolate.
feature_cols_augment (list of str or None, optional) – Columns to augment with noise.
interpolation_config (dict or None, optional) – Keyword arguments for interpolate_temporal_gaps. Typical values include {'freq': 'AS', 'method': 'linear'}.
augmentation_config (dict or None, optional) – Keyword arguments for augment_series_features. Typical values include {'noise_level': 0.01, 'noise_type': 'gaussian'}.
target_name (str or None, optional) – Optional target column name used when inferring default feature sets.
interpolate_target (bool, optional) – Whether the target should be included in default interpolation columns.
verbose (bool, optional) – Whether to emit progress information.
coordinate_precision (int or None, optional) – Decimal precision applied to coordinates before grouping.
savefile (str or None, optional) – Optional output CSV path handled by the decorator.
- Returns:
Augmented DataFrame.
- Return type:
- Raises:
ValueError – If city or mode is invalid, or if required arguments are missing for the selected mode.
TypeError – If the main inputs are of the wrong type.
- geoprior.utils.geo_utils.augment_series_features(series_df, feature_cols, noise_level=0.01, noise_type='gaussian', random_seed=None, savefile=None)[source]
Add random noise to selected numeric feature columns.
- Parameters:
series_df (pandas.DataFrame) – Input DataFrame representing one or more time series.
noise_level (float, optional) – Magnitude of the added noise. For Gaussian noise it scales the feature standard deviation, and for uniform noise it scales the feature range.
noise_type ({'gaussian', 'uniform'}, optional) – Type of noise distribution to use.
random_seed (int or None, optional) – Seed for reproducible noise generation.
savefile (str or None, optional) – Optional output path handled by the decorator.
- Returns:
DataFrame with noise added to the selected feature columns.
- Return type:
- Raises:
ValueError – If requested feature columns are missing or noise_type is invalid.
TypeError – If the main inputs are of the wrong type.
Notes
Non-numeric columns are skipped, and constant or invalid numeric ranges are left unchanged.
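The noise scaling described for noise_level can be sketched with NumPy. This is an assumption-laden illustration rather than the library code: add_noise is a hypothetical name, Gaussian noise uses noise_level times the column standard deviation as its scale, and uniform noise is drawn from plus or minus noise_level times the column range.

```python
import numpy as np
import pandas as pd


def add_noise(df, feature_cols, noise_level=0.01, noise_type="gaussian",
              random_seed=None):
    """Hypothetical sketch: perturb numeric feature columns with random noise."""
    if noise_type not in ("gaussian", "uniform"):
        raise ValueError("noise_type must be 'gaussian' or 'uniform'")
    rng = np.random.default_rng(random_seed)
    out = df.copy()
    for col in feature_cols:
        values = out[col]
        if not pd.api.types.is_numeric_dtype(values):
            continue  # non-numeric columns are skipped
        if noise_type == "gaussian":
            scale = noise_level * values.std()
        else:
            scale = noise_level * (values.max() - values.min())
        if not scale > 0:
            continue  # constant or invalid ranges are left unchanged
        if noise_type == "gaussian":
            noise = rng.normal(0.0, scale, size=len(values))
        else:
            noise = rng.uniform(-scale, scale, size=len(values))
        out[col] = values + noise
    return out


sample = pd.DataFrame({
    "rain": [1.0, 2.0, 3.0, 4.0],
    "const": [5.0, 5.0, 5.0, 5.0],
    "name": ["a", "b", "c", "d"],
})
noisy = add_noise(sample, ["rain", "const", "name"], noise_level=0.1,
                  noise_type="uniform", random_seed=0)
```

With a fixed seed the perturbation is reproducible, and the constant and string columns pass through untouched.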
- geoprior.utils.geo_utils.augment_spatiotemporal_data(df, mode, group_by_cols=None, time_col=None, value_cols_interpolate=None, feature_cols_augment=None, interpolation_kwargs=None, augmentation_kwargs=None, savefile=None, verbose=False)[source]
Apply interpolation, feature augmentation, or both to grouped data.
- Parameters:
df (pandas.DataFrame) – Input spatiotemporal DataFrame.
mode ({'interpolate', 'augment_features', 'both'}) – Processing mode. Use interpolation only, feature augmentation only, or interpolation followed by augmentation.
group_by_cols (list of str or None, optional) – Grouping columns used for per-location processing.
time_col (str or None, optional) – Time column required when interpolation is requested.
value_cols_interpolate (list of str or None, optional) – Value columns to interpolate when interpolation is enabled.
feature_cols_augment (list of str or None, optional) – Feature columns to perturb when augmentation is enabled.
interpolation_kwargs (dict or None, optional) – Keyword arguments forwarded to interpolate_temporal_gaps.
augmentation_kwargs (dict or None, optional) – Keyword arguments forwarded to augment_series_features.
savefile (str or None, optional) – Optional output path handled by the decorator.
verbose (bool, optional) – Whether to emit progress information.
- Returns:
Processed DataFrame assembled from all groups.
- Return type:
- Raises:
ValueError – If mode is invalid or required arguments for the selected mode are missing.
Notes
Groups are processed independently and concatenated afterward.
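The "process each group, then concatenate" pattern from the note can be shown in a few lines. The helper process_groups and the per-group function are hypothetical stand-ins; the real function dispatches to interpolate_temporal_gaps and augment_series_features according to mode.

```python
import pandas as pd


def process_groups(df, group_by_cols, func):
    """Hypothetical sketch: apply func to each group independently, then concat."""
    parts = [
        func(group.copy())
        for _, group in df.groupby(group_by_cols, sort=False)
    ]
    return pd.concat(parts, ignore_index=True)


records = pd.DataFrame({
    "site": ["A", "A", "A", "B", "B", "B"],
    "v": [1.0, None, 3.0, 10.0, None, 30.0],
})
# Fill each site's gap using only that site's own values.
result = process_groups(
    records, ["site"], lambda g: g.assign(v=g["v"].interpolate())
)
```

Because each group is handled on its own copy, values from one location never influence the gap filling of another.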
- geoprior.utils.geo_utils.generate_dummy_pinn_data(n_samples, *, year_range=None, coords_range=None, subs_range=None, gwl_range=None, rainfall_range=None, vars_range=None)[source]
Generate dummy PINN data dictionary with specified or default ranges.
- Parameters:
n_samples (int) – Number of samples to generate.
year_range (tuple[float, float], optional) – (min_year, max_year) for integer years. Default (2000, 2025).
coords_range (tuple[tuple[float, float], tuple[float, float]], optional) – ((lon_min, lon_max), (lat_min, lat_max)). Default ((113.0, 113.8), (22.3, 22.8)).
subs_range (tuple[float, float], optional) – (mean_subsidence, std_subsidence) for normal distribution. Default (-20, 15).
gwl_range (tuple[float, float], optional) – (mean_gwl, std_gwl) for normal distribution. Default (2.5, 1.0).
rainfall_range (tuple[float, float], optional) – (min_rain, max_rain) for uniform distribution. Default (500, 2500).
vars_range (dict, optional) – Dictionary that may contain any of the keys: 'year_range', 'coords_range', 'subs_range', 'gwl_range', 'rainfall_range'. Missing keys will fall back to defaults or to explicitly passed arguments.
- Returns:
dummy_data_dict – Dictionary with keys:
"year": integer years array
"longitude": float longitudes array
"latitude": float latitudes array
"subsidence": float subsidence values array
"GWL": float groundwater level values array
"rainfall_mm": float rainfall values array
- Return type:
dict[str, np.ndarray]
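A generator with the documented keys and distributions might look like the following. This is a sketch, not the library function: make_dummy_data and its seed argument are hypothetical, the upper year bound is assumed to be inclusive, and the vars_range override dictionary is left out.

```python
import numpy as np


def make_dummy_data(n_samples, seed=0,
                    year_range=(2000, 2025),
                    coords_range=((113.0, 113.8), (22.3, 22.8)),
                    subs_range=(-20, 15),
                    gwl_range=(2.5, 1.0),
                    rainfall_range=(500, 2500)):
    """Hypothetical sketch: build a dict of synthetic PINN input arrays."""
    rng = np.random.default_rng(seed)
    (lon_min, lon_max), (lat_min, lat_max) = coords_range
    return {
        # integer years, upper bound assumed inclusive
        "year": rng.integers(year_range[0], year_range[1] + 1, size=n_samples),
        "longitude": rng.uniform(lon_min, lon_max, size=n_samples),
        "latitude": rng.uniform(lat_min, lat_max, size=n_samples),
        # (mean, std) pairs for the normally distributed variables
        "subsidence": rng.normal(subs_range[0], subs_range[1], size=n_samples),
        "GWL": rng.normal(gwl_range[0], gwl_range[1], size=n_samples),
        "rainfall_mm": rng.uniform(rainfall_range[0], rainfall_range[1],
                                   size=n_samples),
    }


dummy = make_dummy_data(50)
```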
- geoprior.utils.geo_utils.interpolate_temporal_gaps(series_df, time_col, value_cols, freq=None, method='linear', order=None, fill_limit=None, fill_limit_direction='forward', savefile=None)[source]
Interpolates missing values in specified columns of a time series DataFrame.
This function is designed to work on a DataFrame representing a time series for a single spatial group (e.g., one monitoring location), sorted by time. If freq is provided, the DataFrame's index is first reindexed to that frequency, which can create NaN values for missing time steps. These NaNs, along with any pre-existing NaNs in value_cols, are then interpolated.
Let \(t_1 < t_2 < \dots < t_n\) be the original timestamps. If freq yields a new index \(\{t_i'\}\) that includes times not in the original, NaNs appear at those \(t_i'\). Then for each column \(v\) in \(\{\text{value\_cols}\}\), we perform:
\[v(t) = \begin{cases} \text{interpolate}(v,\;t;\;\text{method},\;\dots) & \text{for } t \in \{t_i'\}, \\ v(t) & \text{if } t \in \{t_1,\dots,t_n\}\text{ and not NaN.} \end{cases} \tag{15}\]
- Parameters:
series_df (pd.DataFrame) – Input DataFrame for a single time series, ideally sorted by time_col. The time_col should be convertible to datetime.
time_col (str) – Name of the column containing datetime information.
value_cols (List[str]) – List of column names whose missing values (NaNs) should be interpolated.
freq (str or None, default None) – The desired frequency for the time series (e.g., 'D' for daily, 'MS' for month start, 'AS' for year start). If provided, the DataFrame is reindexed to this frequency before interpolation. This helps identify and fill gaps where entire time steps are missing.
method (str, default 'linear') – Interpolation method to use. Passed to pandas.DataFrame.interpolate(). Common methods: 'linear', 'time', 'polynomial', 'spline'. If 'polynomial' or 'spline', order must be specified.
order (int or None, default None) – Order for polynomial or spline interpolation. Required if method is 'polynomial' or 'spline'.
fill_limit (int or None, default None) – Maximum number of consecutive NaNs to fill. Passed to pandas.DataFrame.interpolate().
fill_limit_direction (str, default 'forward') – Direction for fill_limit ('forward', 'backward', 'both'). Passed to pandas.DataFrame.interpolate().
savefile (str | None)
- Returns:
DataFrame with specified columns interpolated. If freq was used, the DataFrame will have a DatetimeIndex. Other columns not in value_cols will be forward-filled after reindexing if freq is set, to propagate their last known values into new empty rows.
- Return type:
pd.DataFrame
- Raises:
TypeError – If series_df is not a DataFrame or if value_cols is not a list of strings. Also if time_col is missing from the DataFrame.
ValueError – If order is required but not provided for 'polynomial' or 'spline'.
Examples
>>> import pandas as pd
>>> from geoprior.utils.geo_utils import interpolate_temporal_gaps
>>> # Sample time series with missing dates
>>> df = pd.DataFrame({
...     'date': ['2020-01-01', '2020-01-03', '2020-01-06'],
...     'value': [1.0, None, 4.0]
... })
>>> df
         date  value
0  2020-01-01    1.0
1  2020-01-03    NaN
2  2020-01-06    4.0
>>> result = interpolate_temporal_gaps(
...     df, time_col='date', value_cols=['value'], freq='D'
... )
>>> result.head()
        date  value
0 2020-01-01    1.0
1 2020-01-02    1.6
2 2020-01-03    2.2
3 2020-01-04    2.8
4 2020-01-05    3.4
Notes
Ensure series_df pertains to a single spatial group and is sorted by time for meaningful interpolation.
The 'time' method for interpolation requires the index to be a DatetimeIndex.
Polynomial or spline methods require order to be specified.
See also
pandas.DataFrame.interpolate: Core interpolation method.
pandas.DataFrame.asfreq: Reindex DataFrame to fixed frequency.
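The reindex-then-interpolate behavior of equation (15) can be reproduced directly with the two pandas methods listed under See also. The sketch below is a simplified stand-in with a hypothetical name, fill_gaps; it skips order, fill_limit, and the forward-filling of non-value columns.

```python
import pandas as pd


def fill_gaps(series_df, time_col, value_cols, freq=None, method="linear"):
    """Hypothetical sketch: reindex to a fixed frequency, then interpolate."""
    out = series_df.copy()
    out[time_col] = pd.to_datetime(out[time_col])
    out = out.sort_values(time_col).set_index(time_col)
    if freq is not None:
        # new timestamps appear as rows filled with NaN
        out = out.asfreq(freq)
    out[value_cols] = out[value_cols].interpolate(method=method)
    return out


df = pd.DataFrame({
    "date": ["2020-01-01", "2020-01-03", "2020-01-06"],
    "value": [1.0, None, 4.0],
})
filled = fill_gaps(df, "date", ["value"], freq="D")
```

With daily reindexing the series spans six days, and linear interpolation between the two known values (1.0 and 4.0) advances by 0.6 per day.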
- geoprior.utils.geo_utils.resolve_spatial_columns(df, spatial_cols=None, lon_col=None, lat_col=None)[source]
Helper to validate and resolve spatial columns.
Accepts either explicit lon/lat columns or a list of spatial_cols. Returns (lon_col, lat_col).
If lon_col and lat_col are both provided, they take precedence (warn if spatial_cols also set).
Else if spatial_cols is provided, it must yield exactly two column names.
Otherwise, error is raised.
- Parameters:
- Returns:
(lon_col, lat_col) – Validated column names for longitude and latitude.
- Return type:
- Raises:
ValueError – If neither lon/lat nor valid spatial_cols is provided, or if spatial_cols len != 2.
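The precedence rules above translate into a short decision chain. This is a sketch of the documented logic only; resolve_lon_lat is a hypothetical name, and it takes a plain list of column names instead of a DataFrame so that it stays dependency-free.

```python
import warnings


def resolve_lon_lat(columns, spatial_cols=None, lon_col=None, lat_col=None):
    """Hypothetical sketch: return (lon_col, lat_col) following the precedence rules."""
    if lon_col is not None and lat_col is not None:
        if spatial_cols is not None:
            warnings.warn("lon_col/lat_col take precedence over spatial_cols")
        pair = (lon_col, lat_col)
    elif spatial_cols is not None:
        if len(spatial_cols) != 2:
            raise ValueError("spatial_cols must contain exactly two names")
        pair = tuple(spatial_cols)
    else:
        raise ValueError("provide lon_col and lat_col, or spatial_cols")
    missing = [c for c in pair if c not in columns]
    if missing:
        raise ValueError(f"columns not found: {missing}")
    return pair


cols = ["longitude", "latitude", "year"]
explicit = resolve_lon_lat(cols, lon_col="longitude", lat_col="latitude")
from_list = resolve_lon_lat(cols, spatial_cols=["longitude", "latitude"])
```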
- geoprior.utils.geo_utils.merge_frames_to_file(sources, output_path, *, output_format='parquet', compression='snappy', check_columns='strict', excel_mode='all_sheets', sheet_names=None, add_source_label=True, source_col='source', sort_by=None, drop_duplicates=False, reset_index=True, save_kwargs=None, verbose=1)[source]
Merge multiple NATCOM city datasets into a single compressed file.
- Parameters:
sources (iterable of {path-like, DataFrame}) – Input sources. Each element can be:
a path to a CSV file (e.g. "nansha_final...csv"),
a path to an Excel workbook (one or many sheets per city),
a pre-loaded DataFrame.
output_path (path-like) – Destination file path. If output_format='parquet' and the suffix is missing, '.parquet' is appended.
output_format ({'parquet', 'csv', 'feather', 'pickle'}, optional) – Output format. Default is 'parquet' for compact, columnar storage (recommended for Code Ocean).
compression (str or None, optional) – Compression to use for the chosen format.
For 'parquet' this is passed to pandas.DataFrame.to_parquet() (e.g. 'snappy', 'gzip', 'brotli').
For 'csv' it is passed to pandas.DataFrame.to_csv() via the compression keyword if non-None.
Ignored for 'feather' and 'pickle' (Feather uses its own defaults; pickle rarely benefits from extra compression at this layer).
check_columns ({'strict', 'subset', 'union'}, optional) – How to handle column consistency across sources:
'strict' (default): all sources must have exactly the same set of columns (order may differ). Columns are then aligned to the order of the first DataFrame. A mismatch raises ValueError.
'subset': all columns in the first DataFrame must exist in each subsequent DataFrame. Extra columns in later sources are dropped. Missing required columns raise ValueError.
'union': columns are unioned across all sources. Any missing column in a particular source is added and filled with NaN before concatenation.
excel_mode ({'all_sheets', 'first_sheet'}, optional) – Behaviour when a source is an Excel workbook:
'all_sheets' (default): read all sheets and treat each sheet as a separate DataFrame to merge.
'first_sheet': read only the first sheet.
If sheet_names is provided, it takes precedence.
sheet_names (iterable of str, optional) – Explicit sheet names to read from Excel workbooks. If provided, only these sheets are read.
add_source_label (bool, optional) – If True (default), add a column named source_col to each chunk before concatenation. For path-like inputs, the label is derived from the file name and, when applicable, sheet name (e.g. "nansha_final_main_std.harmonized.csv" or "zhongshan.xlsx:Sheet1"). For pre-loaded DataFrames, the label is '<in_memory>'.
source_col (str, optional) – Name of the column storing the source label when add_source_label=True.
sort_by (iterable of str, optional) – Optional column(s) to sort the merged DataFrame by at the end (e.g. ['city', 'year', 'longitude', 'latitude']).
drop_duplicates (bool, optional) – If True, drop duplicate rows at the end (after sorting).
reset_index (bool, optional) – If True (default), reset index after concatenation.
save_kwargs (dict, optional) – Extra keyword arguments forwarded to the corresponding to_* writer (e.g. to_parquet, to_csv, to_feather, to_pickle).
verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints basic progress information.
- Returns:
merged – The merged DataFrame (also written to disk).
- Return type:
- Raises:
ValueError – If check_columns is 'strict' or 'subset' and a column mismatch is detected.
Examples
>>> from geoprior.utils.geo_utils import merge_frames_to_file
>>> merge_frames_to_file(
...     sources=[
...         "nansha_final_main_std.harmonized.csv",
...         "zhongshan_final_main_std.harmonized.csv",
...     ],
...     output_path="natcom_all_cities",
...     output_format="parquet",
...     compression="snappy",
...     sort_by=["city", "year", "longitude", "latitude"],
... )
Notes
All inputs are read fully into memory before concatenation. This is acceptable for the NATCOM subsidence datasets (O(10^6 - 10^7) rows) but can be refactored to a streaming/row-group approach if needed later.
Using output_format='parquet' with compression (e.g. 'snappy') is recommended for Code Ocean to minimise disk usage while keeping I/O efficient.
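The three check_columns policies differ only in how frames are aligned before concatenation. The following sketch illustrates that alignment step in isolation, assuming in-memory DataFrames; align_frames is a hypothetical helper and does no file reading or writing.

```python
import pandas as pd


def align_frames(frames, check_columns="strict"):
    """Hypothetical sketch: align columns per policy, then concatenate."""
    first = list(frames[0].columns)
    if check_columns == "strict":
        for f in frames[1:]:
            if set(f.columns) != set(first):
                raise ValueError("column sets differ under 'strict'")
        aligned = [f[first] for f in frames]  # reorder to the first frame
    elif check_columns == "subset":
        for f in frames[1:]:
            missing = [c for c in first if c not in f.columns]
            if missing:
                raise ValueError(f"missing required columns: {missing}")
        aligned = [f[first] for f in frames]  # extra columns are dropped
    elif check_columns == "union":
        all_cols = list(first)
        for f in frames[1:]:
            all_cols += [c for c in f.columns if c not in all_cols]
        # columns absent from a frame are added and filled with NaN
        aligned = [f.reindex(columns=all_cols) for f in frames]
    else:
        raise ValueError("check_columns must be 'strict', 'subset' or 'union'")
    return pd.concat(aligned, ignore_index=True)


a = pd.DataFrame({"city": ["nansha"], "year": [2020]})
b = pd.DataFrame({"year": [2021], "city": ["zhongshan"], "gwl": [2.5]})
merged_subset = align_frames([a, b], check_columns="subset")
merged_union = align_frames([a, b], check_columns="union")
```

For these two frames, 'subset' keeps only city and year, 'union' keeps gwl with a missing value for the first source, and 'strict' would raise because the column sets differ.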
- geoprior.utils.geo_utils.unpack_frames_from_file(merged, *, group_col='city', output_dir=None, output_format='csv', compression=None, use_source_col=True, source_col='source', filename_pattern='{group_value}_split', drop_columns=None, keep_columns=None, save=True, return_dict=True, save_kwargs=None, verbose=1, logger)[source]
Reverse of merge_frames_to_file: split an aggregated NATCOM dataset into per-city frames (and optionally write them to disk).
- Parameters:
merged (path-like or DataFrame) – Aggregated dataset. If path-like, the format is inferred from the file suffix:
.parquet → pandas.read_parquet()
.csv → pandas.read_csv()
.feather → pandas.read_feather()
.pkl / .pickle → pandas.read_pickle()
If a DataFrame is passed, it is used directly.
group_col (str, optional) – Column used to split the dataset (default: 'city'). Each unique value defines one output chunk.
output_dir (path-like, optional) – Directory where per-group files are written. If None and merged is a path, the directory of merged is used. If merged is a DataFrame and output_dir is None, the current working directory is used.
output_format ({'csv', 'parquet', 'feather', 'pickle'}, optional) – Output format for per-group files. Default is 'csv'.
compression (str or None, optional) – Compression to use when writing:
For 'csv', forwarded to DataFrame.to_csv() as the compression argument (e.g. 'gzip').
For 'parquet', forwarded to DataFrame.to_parquet() (e.g. 'snappy', 'gzip').
Ignored for 'feather' and 'pickle' (these use their own defaults).
use_source_col (bool, optional) – If True (default) and a column named source_col exists, the helper tries to reconstruct the original file name for each group:
If a group has a single unique, non-null source value that looks like a filename (e.g. 'nansha_final_main_std.harmonized.csv'), that base name is used for the output (with its suffix adjusted to match output_format if needed).
If there are multiple unique source labels within a group, it falls back to filename_pattern.
source_col (str, optional) – Name of the column containing the source label (default: 'source'). This should match the column created in merge_frames_to_file() when add_source_label=True.
filename_pattern (str, optional) – Pattern used when no suitable source label is available. The following placeholders are supported:
{group_value}: the group value as a string
{group_col}: the name of the grouping column
Example: filename_pattern="{group_col}_{group_value}_data" → "city_Nansha_data.csv".
drop_columns (iterable of str, optional) – Columns to drop from each group before saving/returning (e.g. ['source'] if you don't want the bookkeeping column).
keep_columns (iterable of str, optional) – If provided, only these columns are kept (all others are dropped after any drop_columns processing is applied).
save (bool, optional) – If True (default), write each group to disk as a separate file. If False, no files are written; only the dict of DataFrames is returned (if return_dict=True).
return_dict (bool, optional) – If True (default), return a mapping {group_value: group_df}. If False, an empty dict is returned (useful when you only care about side-effect files).
save_kwargs (dict, optional) – Extra keyword arguments forwarded to the respective writer: DataFrame.to_csv(), DataFrame.to_parquet(), DataFrame.to_feather(), or DataFrame.to_pickle().
verbose (int, optional) – Verbosity level. 0 = silent, >=1 prints progress information.
logger (None)
- Returns:
out – Dictionary mapping each group value to the corresponding DataFrame. Empty if return_dict=False.
- Return type:
- Raises:
ValueError – If group_col is not present in the merged dataset.
Examples
>>> from geoprior.utils.geo_utils import unpack_frames_from_file
>>> unpack_frames_from_file(
...     "natcom_all_cities.parquet",
...     group_col="city",
...     output_format="csv",
... )
# -> writes e.g. 'nansha_final_main_std.harmonized.csv',
#    'zhongshan_final_main_std.harmonized.csv' (if `source` labels exist),
#    and returns a dict: {'Nansha': df_nansha, 'Zhongshan': df_zhongshan}
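The splitting and filename_pattern fallback can be sketched without touching the disk. This is an illustration only: split_by_group is a hypothetical helper that returns the file name it would use alongside each group frame, and it ignores the source-label reconstruction step.

```python
import pandas as pd


def split_by_group(merged, group_col="city",
                   filename_pattern="{group_value}_split", suffix=".csv"):
    """Hypothetical sketch: map each group value to (file name, group frame)."""
    if group_col not in merged.columns:
        raise ValueError(f"{group_col!r} is not a column of the merged dataset")
    out = {}
    for value, chunk in merged.groupby(group_col, sort=True):
        name = filename_pattern.format(
            group_value=value, group_col=group_col
        ) + suffix
        out[value] = (name, chunk.reset_index(drop=True))
    return out


merged = pd.DataFrame({
    "city": ["nansha", "zhongshan", "nansha"],
    "year": [2020, 2020, 2021],
})
parts = split_by_group(merged)
custom = split_by_group(merged, filename_pattern="{group_col}_{group_value}_data")
```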
geospatial_utils - A collection of utilities for geospatial and positional data analysis, filtering, and transformations.
- geoprior.utils.spatial_utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)[source]
Sample spatial data intelligently to represent the distribution of the whole area and include different years.
This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.
- Parameters:
data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., 'longitude', 'latitude') and any columns specified in stratify_by.
sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).
stratify_by (list of str, optional) – List of column names to stratify by.
spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.
spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named 'longitude' and/or 'latitude' in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.
method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.
min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.
random_state (int, optional) – Random seed for reproducibility. Default is 42.
verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.
- Returns:
sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.
- Return type:
Notes
The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.
Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:
(16)#\[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil\]
where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).
The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
Examples
>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> import pandas as pd
>>> # Assume 'df' is a pandas DataFrame with columns
>>> # 'longitude', 'latitude', 'year', and other data.
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)
See also
pandas.qcutQuantile-based discretization function used for binning.
sklearn.model_selection.StratifiedShuffleSplitFor stratified sampling.
batch_spatial_samplingResample spatial data with batching.
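The allocation rule in the Notes, \(n_i = \lceil N_i / N \times n \rceil\), can be checked by hand with a few lines of plain Python. This is only a sketch of the formula, not the library's implementation; `allocate_samples` and the group keys are hypothetical names chosen for illustration.

```python
import math


def allocate_samples(group_sizes, n):
    """Return n_i = ceil(N_i / N * n) for every stratification group.

    `group_sizes` maps a group key to its size N_i; `n` is the desired
    total sample size. Hypothetical helper, not part of geoprior.
    """
    total = sum(group_sizes.values())  # N
    if total == 0:
        raise ValueError("group_sizes must contain at least one record")
    # Multiply before dividing to limit floating-point rounding.
    return {
        group: math.ceil(size * n / total)
        for group, size in group_sizes.items()
    }


# Hypothetical groups keyed by (lon_bin, lat_bin, year).
group_sizes = {(0, 0, 2020): 50, (0, 1, 2020): 30, (1, 0, 2021): 20}
allocation = allocate_samples(group_sizes, n=10)
print(allocation)
```

Because each group count is rounded up, the per-group counts can sum to slightly more than \(n\) when the proportions are not whole numbers.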
- geoprior.utils.spatial_utils.extract_coordinates(df, as_frame=False, drop_xy=False, error='raise', verbose=0)[source]
Extract coordinate columns or their midpoint from a DataFrame.
- Parameters:
df (
pandas.DataFrame) – DataFrame expected to contain longitude/latitude or easting/northing columns.as_frame (
bool, optional) – IfTrue, return the coordinate columns as a DataFrame. Otherwise return their midpoint.drop_xy (
bool, optional) – IfTrue, remove detected coordinate columns from the returned DataFrame.error (
boolor{'raise', 'warn', 'ignore'}, optional) – Error-handling policy for invalid inputs.verbose (
int, optional) – Verbosity level for detection messages.
- Returns:
Tuple containing the extracted coordinates or midpoint, the DataFrame with optional coordinate removal, and the detected coordinate-column names.
- Return type:
Notes
Longitude/latitude are preferred over easting/northing when both are present.
- geoprior.utils.spatial_utils.batch_spatial_sampling(data, sample_size=0.1, n_batches=10, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, verbose=1)[source]
Create stratified spatial sample batches from a DataFrame.
- Parameters:
data (
pandas.DataFrame) – Input DataFrame used for sampling.sample_size (
floatorint, optional) – Total sample size as a fraction or absolute count.n_batches (
int, optional) – Number of batches to generate.stratify_by (
strorlistofstrorNone, optional) – Additional columns used for stratification.spatial_bins (
intorsequenceofint, optional) – Number of spatial bins used when discretizing coordinates.spatial_cols (
listortupleofstrorNone, optional) – Spatial coordinate columns.method (
{'abs', 'absolute', 'relative'}, optional) – Strategy used to translatesample_sizeinto per-batch sample counts.min_relative_ratio (
float, optional) – Minimum relative group size used bymethod='relative'.random_state (
int, optional) – Random seed for reproducibility.verbose (
int, optional) – Verbosity level.
- Returns:
Stratified batches sampled without overlap.
- Return type:
Notes
Spatial coordinates are discretized with
pandas.qcutand combined withstratify_bycolumns so batches preserve the overall data distribution as closely as possible.
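The "sampled without overlap" property of the returned batches can be illustrated independently of the stratification logic. The sketch below is an assumption-laden simplification: it ignores spatial bins and `stratify_by`, and `split_into_batches` is a hypothetical helper, not a geoprior function. It only shows one way to draw a total sample once and deal it into disjoint batches.

```python
import random


def split_into_batches(indices, sample_size, n_batches, seed=42):
    """Draw a sample once, then deal it into disjoint batches.

    A float `sample_size` is treated as a fraction of the data and an
    int as an absolute count (mirroring the documented convention).
    Stratification is deliberately left out of this sketch.
    """
    indices = list(indices)
    if isinstance(sample_size, float):
        n_sample = round(sample_size * len(indices))
    else:
        n_sample = int(sample_size)
    n_sample = min(n_sample, len(indices))

    rng = random.Random(seed)
    chosen = rng.sample(indices, n_sample)  # sampling without replacement
    # Round-robin slicing guarantees that no index lands in two batches.
    return [chosen[i::n_batches] for i in range(n_batches)]


batches = split_into_batches(range(100), sample_size=0.1, n_batches=5)
flat = [idx for batch in batches for idx in batch]
print(len(batches), len(flat), len(set(flat)))
```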
- geoprior.utils.spatial_utils.extract_zones_from(z, threshold='auto', condition='auto', use_negative_criteria=True, percentile=10, x=None, y=None, data=None, view=False, plot_type='scatter', figsize=(8, 6), savefile=None, axis_off=False, show_grid=True, **kwargs)[source]
Extract zones by filtering values against a threshold rule.
- Parameters:
z (
array-like,pandas.Series, orstr) – Input data to filter. Ifzis a string, it is interpreted as a column name indata.threshold (
{'auto'}orfloatorintortuple, optional) – Filtering criterion. Use'auto'for percentile-based thresholding, a scalar for a single cutoff, or a length-2 tuple for interval filtering.condition (
{'auto', 'above', 'below', 'between'}, optional) – Relation between the data and the threshold.use_negative_criteria (
bool, optional) – Controls the automatic condition whencondition='auto'.percentile (
intorfloat, optional) – Percentile used whenthreshold='auto'.x (
array-like,pandas.Series,str, orNone, optional) – Optional coordinates or column names used for plotting.y (
array-like,pandas.Series,str, orNone, optional) – Optional coordinates or column names used for plotting.data (
pandas.DataFrameorNone, optional) – Data source used whenx,y, orzare column names.view (
bool, optional) – Whether to visualize the filtered result.plot_type (
str, optional) – Plot type used whenview=True. Common values include'scatter','line', and'hist'.figsize (
tuple, optional) – Figure size for plotting.savefile (
strorNone, optional) – Optional path used when saving the figure.axis_off (
bool, optional) – Whether to hide axes in the plot.show_grid (
bool, optional) – Whether to display the plot grid.**kwargs (
dict) – Additional plotting keyword arguments.
- Returns:
Filtered values and any optional plotting outputs defined by the implementation.
- Return type:
Notes
When
x,y, orzare passed as strings, the function relies onextract_array_fromto retrieve the corresponding arrays fromdata.
- geoprior.utils.spatial_utils.filter_position(df, pos, pos_cols=None, find_closest=True, threshold=0.01, error='raise', verbose=0)[source]
Filter a pandas.DataFrame by user-specified spatial positions. The function can match positions exactly or compute distances to find the closest points within a threshold.
For a single dimension, the distance is computed as:
(17)#\[d = |x - p|\]
For multi-dimensional data with \(n\) coordinates, the Euclidean distance is computed as:
(18)#\[d = \sqrt{\sum_{i=1}^n (x_i - p_i)^2}\]
- Parameters:
df (
pandas.DataFrame) – The DataFrame that will be filtered. This parameter is essential and must contain columns referenced by pos_cols ifpos_colsis not None.pos (
floatortupleoffloats) – The reference position(s) to match or approximate. When pos_cols is None, pos is interpreted as an index value. Otherwise, each value in pos aligns with a specific column in pos_cols.pos_cols (
strortupleofstr, optional) – Name(s) of the column(s) in df to match against pos. Ifpos_cols=None, then pos is treated as an index filter. If multiple columns are given (e.g., latitude and longitude), each coordinate in pos should correspond to one column in pos_cols.find_closest (
bool, optional) – If True, nearest-neighbor filtering is performed within the distance threshold. If False, exact matches are used.threshold (
float, optional) – The maximum distance within which points are considered a match if find_closest is True. The unit corresponds to the column data (e.g., degrees for geographic lat/lon).error (
{'raise', 'warn', 'ignore'}, optional) – Specifies how to handle dimension mismatches or missing values. If'raise', a ValueError will be raised. If'warn', a warning is printed and extra values are ignored. If'ignore', mismatches are silently ignored.verbose (
int, optional) – Controls the level of output messages: - 0: No output - 1: Basic info - 2: Additional details - >=3: Comprehensive summary
- Returns:
A new DataFrame that contains only rows matching or approximating the specified position(s) within the given threshold if find_closest is True.
- Return type:
Notes
When pos_cols is None, the function attempts to filter by DataFrame index using the first element of pos. This approach may fail for multi-level indexes unless error='warn' or error='ignore' is used to bypass the dimension mismatch.
Examples
>>> from geoprior.utils.spatial_utils import filter_position
>>> import pandas as pd
>>> # Longitude/latitude pairs in degrees
>>> df = pd.DataFrame({
...     'lon': [113.309998, 113.310001],
...     'lat': [22.831362, 22.831364]
... })
>>> # Exact match
>>> result_exact = filter_position(
...     df,
...     pos=(113.309998, 22.831362),
...     pos_cols=('lon', 'lat'),
...     find_closest=False
... )
>>> # Nearest match with threshold
>>> result_close = filter_position(
...     df,
...     pos=(113.31, 22.83),
...     pos_cols=('lon', 'lat'),
...     find_closest=True,
...     threshold=0.01
... )
See also
geoprior.utils.data_utils.truncate_dataTruncate multiple DataFrames based on spatial coordinates or index alignment with a base DataFrame.
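The two distance formulas used by filter_position can be sketched with the standard library. The helper below is hypothetical (`rows_within_threshold` is not a geoprior name) and works on plain dictionaries rather than a DataFrame; whether the library treats the threshold as inclusive is an assumption here.

```python
import math


def rows_within_threshold(rows, pos, pos_cols, threshold):
    """Keep rows whose Euclidean distance to `pos` is <= `threshold`.

    With a single column this reduces to d = |x - p|; with several
    columns it is d = sqrt(sum((x_i - p_i)^2)). The inclusive
    comparison is an assumption made for this sketch.
    """
    kept = []
    for row in rows:
        point = [row[col] for col in pos_cols]
        if math.dist(point, pos) <= threshold:
            kept.append(row)
    return kept


rows = [
    {"lon": 113.309998, "lat": 22.831362},
    {"lon": 113.310001, "lat": 22.831364},
    {"lon": 114.000000, "lat": 23.500000},
]
close = rows_within_threshold(
    rows, pos=(113.31, 22.83), pos_cols=("lon", "lat"), threshold=0.01
)
print(len(close))
```

The threshold is expressed in the same unit as the coordinate columns, so for longitude/latitude data it is a distance in degrees, not in meters.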
- geoprior.utils.spatial_utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)[source]
Cluster 2D spatial data in df using the selected algorithm and optionally plot the results.
This function extracts two coordinate columns from df and forms clusters with 'kmeans', 'dbscan', or 'agglo' (agglomerative). It uses filter_valid_kwargs (when relevant) to strip out parameters that are invalid for the chosen estimator, and writes cluster labels into cluster_col.
- Parameters:
df (
pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.spatial_cols (
listofstr, optional) – Two-column list for x and y coordinates. Defaults to['longitude','latitude']if None.cluster_col (
str, default'region') – Name of the column to store the assigned cluster labels.n_clusters (
int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.algorithm (
str, default'kmeans') – Choice of clustering algorithm among['kmeans','dbscan','agglo'].view (
bool, defaultTrue) – If True, displays a scatterplot of the final clusters.figsize (
tuple, default(14,10)) – Size of the displayed figure for the cluster plot.s (
int, default60) – Marker size in the scatterplot.plot_style (
str, default'seaborn') – Matplotlib style used for the plot.cmap (
str, default'tab20') – Colormap name used to differentiate clusters.show_grid (
bool, defaultTrue) – Toggles grid lines on or off.grid_props (
dict, optional) – Additional keyword arguments controlling the grid style.auto_scale (
bool, defaultTrue) – If True, standardize coordinates before clustering.savefile (
str, optional) – File path used to save the output data, including the added cluster_col column of cluster labels.verbose (
int, default1) – Controls console logs. Higher values yield more details about scaling and cluster detection.**kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, AgglomerativeClustering ).
- Returns:
A copy of df with an additional cluster_col column storing the assigned cluster labels.
- Return type:
Notes
If auto_scale is True, a standard scaler normalizes the coordinate columns before clustering. The scatterplot is generated with seaborn for enhanced styling.
By default, for algorithm='kmeans', the model attempts to minimize:
(19)#\[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2\]
where \(x_i\) are the scaled or raw 2D coordinates in df and \(\mu_j\) are the cluster centroids. If n_clusters is not provided, the function can auto-detect it using silhouette and elbow analysis.
Examples
>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )
See also
filter_valid_kwargsHelps discard unsupported keyword arguments for chosen estimators.
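The k-means objective \(J\) from the Notes can be evaluated directly for a small set of points, which helps when comparing candidate centroid sets. This is a plain-Python sketch of the formula only; `kmeans_objective` and the centroid values are illustrative assumptions, not output of create_spatial_clusters.

```python
def kmeans_objective(points, centroids):
    """Sum over points of the squared distance to the nearest centroid."""
    total = 0.0
    for x, y in points:
        total += min((x - cx) ** 2 + (y - cy) ** 2 for cx, cy in centroids)
    return total


# Same four points as in the example above: (longitude, latitude).
points = [(0.1, 1.0), (0.2, 1.1), (2.2, 2.1), (2.3, 2.2)]

# Hand-picked centroids: the mean of each visually obvious pair.
two_centroids = [(0.15, 1.05), (2.25, 2.15)]
one_centroid = [(1.2, 1.6)]

j_two = kmeans_objective(points, two_centroids)
j_one = kmeans_objective(points, one_centroid)
print(j_two < j_one)
```

A lower \(J\) for the two-centroid configuration reflects the tighter fit, which is the quantity an elbow analysis tracks as the number of clusters grows.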
- geoprior.utils.spatial_utils.gen_negative_samples(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, use_gpd='auto', view=False, savefile=None, verbose=1)[source]
Generate synthetic negative samples for spatial binary classification tasks.
This function creates additional samples labeled as non-events within a specified spatial buffer around the positive (event) observations. The main idea is to generate negative examples that reflect realistic conditions but have not triggered an event, thereby assisting models in distinguishing occurrences from non-occurrences.
- Parameters:
df (
pandas.DataFrame) – Input DataFrame containing the positive samples (events). Must include both the target column and the specified spatial columns.target_col (
str) – Column name for the binary target (e.g. landslide). Rows where this column is 1 (or True) are considered positive samples.spatial_cols (
tupleofstr, default('longitude', 'latitude')) – Tuple specifying the longitude and latitude column names in df.feature_cols (
listofstrorNone, defaultNone) – List of feature columns to use or to simulate for generated negatives. IfNone, the function automatically detects numeric and categorical columns excluding spatial_cols and target_col.buffer_km (
float, default10) – Spatial buffer in kilometers used to define the radius around each positive sample within which negative samples are created.neg_feature_range (
tupleofint, default(0,5)) – Value range (minimum, maximum) used for simulating numeric feature values in negative samples if the corresponding feature column does not exist in <df>.num_neg_per_pos (
int, default1) – Number of negative samples to generate per positive sample. For instance, ifnum_neg_per_pos=2, each positive sample spawns two negatives.use_gpd (
str, default'auto') – If set to ‘auto’, the function tries to import GeoPandas for visualization. If ‘none’, no GeoPandas usage will occur.view (
bool, defaultFalse) – Whether to visualize the generated samples on a map. Attempts to use geopandas if installed; falls back to matplotlib if ‘auto’ is chosen and GeoPandas is not available.savefile (
strorNone, defaultNone) – Path to which the resulting DataFrame is saved if provided. Handled by the decorator that wraps this function.verbose (
int, default1) – Verbosity level. 0 for silent, 1 for progress indication, 2 for more messages, 3 for debugging output.
- Returns:
Combined DataFrame with both original positive samples and newly generated negative samples. The target_col column is 1 for positives and 0 for negatives.
- Return type:
- `columns_manager`
This internal function is used to handle the processing of columns for features and spatial parameters.
Notes
If a feature column exists in df, the negative samples copy or randomly select categories for categorical columns, and sample integers within neg_feature_range for numeric columns.
If feature_cols is empty or does not exist in df, the function simulates all values for negative samples.
When view=True, circles depicting the buffer zone around each positive sample are drawn for visualization.
The conversion of 1 degree to roughly 111 km is an approximation. It holds well for latitude, while a degree of longitude spans about 111 km only near the equator and shrinks with the cosine of latitude.
Mathematically, we define the spatial buffer in degrees as:
(20)#\[\Delta = \frac{\text{buffer\_km}}{111.0},\]
where \(111.0\) km approximates the distance of one degree of latitude or longitude. For each positive sample at location \((lat, lon)\), we generate \(n\) new points with offsets \(\delta_{lat}\) and \(\delta_{lon}\), each drawn from a uniform distribution \(U(-\Delta, \Delta)\):
(21)#\[\begin{aligned} lat_{new} &= lat + \delta_{lat}, \\ lon_{new} &= lon + \delta_{lon}. \end{aligned}\]
Combined with randomly sampled or inferred feature values, these new samples serve as negative examples for modeling tasks such as landslide prediction.
Examples
>>> from geoprior.utils.spatial_utils import gen_negative_samples
>>> import pandas as pd
>>> import numpy as np
>>> df_pos = pd.DataFrame({
...     'latitude': np.random.uniform(24.0, 25.0, 5),
...     'longitude': np.random.uniform(113.0, 114.0, 5),
...     'rainfall_day_1': np.random.randint(10, 30, 5),
...     'rainfall_day_2': np.random.randint(10, 30, 5),
...     'rainfall_day_3': np.random.randint(10, 30, 5),
...     'rainfall_day_4': np.random.randint(10, 30, 5),
...     'rainfall_day_5': np.random.randint(10, 30, 5),
...     'landslide': 1
... })
>>> combined = gen_negative_samples(
...     df=df_pos,
...     target_col='landslide',
...     buffer_km=10,
...     num_neg_per_pos=2,
...     view=False,
...     verbose=2
... )
>>> print(combined.head())
See also
check_spatial_columnsEnsures the existence of required spatial columns.
exist_featuresVerifies the presence of specified features in <df>.
columns_managerHandles both feature and spatial columns for processing.
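The buffer-and-offset rule described in the Notes (a buffer of buffer_km / 111 degrees, then independent uniform offsets for latitude and longitude) can be sketched with the standard library. `sample_negative_points` is a hypothetical helper that only generates coordinates; it does not simulate feature columns and is not the library's implementation.

```python
import random

KM_PER_DEGREE = 111.0  # approximation used in the Notes


def sample_negative_points(positives, buffer_km=10.0, num_neg_per_pos=1, seed=0):
    """Generate (lat, lon) negatives around each positive (lat, lon).

    Each offset is drawn from U(-delta, delta), delta = buffer_km / 111.
    """
    rng = random.Random(seed)
    delta = buffer_km / KM_PER_DEGREE
    negatives = []
    for lat, lon in positives:
        for _ in range(num_neg_per_pos):
            negatives.append(
                (lat + rng.uniform(-delta, delta),
                 lon + rng.uniform(-delta, delta))
            )
    return negatives


positives = [(24.5, 113.5), (24.7, 113.2)]
negatives = sample_negative_points(
    positives, buffer_km=10, num_neg_per_pos=2, seed=42
)
print(len(negatives))
```

Note that independent offsets in latitude and longitude fill a square of half-width \(\Delta\) in degree space, so some generated points lie slightly farther than buffer_km from the base point along the diagonals.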
- geoprior.utils.spatial_utils.gen_buffered_negative_samples(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, strategy='landslide', gauge_data=None, use_gpd='auto', id_col='auto', view=False, savefile=None, seed=None, verbose=1)[source]
Generate buffer-based negative samples around existing points or gauge stations.
This function creates additional negative samples for binary spatial events. It either takes an existing landslide dataset (when strategy is ‘landslide’) or a separate gauge dataset (if strategy is ‘gauge’) to serve as the base points for generating negatives within a circular buffer. The function validates input columns and parameters via _validate_negatives_sampling before constructing synthetic samples.
- Parameters:
df (
pandas.DataFrame) – The DataFrame containing positive event samples (e.g., landslides). Must include <target_col> and <spatial_cols>.target_col (
str) – Name of the binary target column (1 for event, 0 for no event).spatial_cols (
tupleofstr, default('longitude', 'latitude')) – Indicates which columns hold longitude and latitude in df.feature_cols (
listofstr, optional) – Additional feature columns to simulate or copy for negatives. IfNone, all columns except <spatial_cols> and target_col are used.buffer_km (
float, default10) – The radial distance in kilometers for sampling negative points around each base point.neg_feature_range (
tupleoffloat, default(0,5)) – A numeric range from which feature values are drawn for negative samples if the column is numeric.num_neg_per_pos (
int, default1) – Number of negatives to generate per positive (landslide) or gauge point.strategy (
str, default'landslide') – Defines the base from which negative samples are generated. Use'landslide'or'event'to sample around rows indf. Use'gauge'to sample around rows ingauge_data.gauge_data (
pandas.DataFrame, optional) – Required ifstrategyis'gauge'. Must containspatial_cols.use_gpd (
boolor'auto', default'auto') – If'auto', attempts to use GeoPandas for visualization if installed. Otherwise, falls back to Matplotlib. This parameter is forwarded to the underlying visualization function.id_col (
strorlistofstr, default'auto') – Column(s) representing IDs indf. If'auto', the function tries to detect possible ID columns. Used by_validate_negatives_sampling.view (
bool, defaultFalse) – Whether to visualize the sampled negatives around the base points.savefile (
str, optional) – If provided, saves the final combined dataset (positives and negatives) to a CSV file at the specified path.seed (
int, optional) – Seed for NumPy’s random generator, ensuring reproducible offsets in negative sampling.verbose (
int, default1) – Controls console messages:1for minimal,2for more detailed logs.
- Returns:
The combined dataset containing both the original (positive) rows, labeled with
target_col= 1, and the newly generated negative rows, labeledtarget_col= 0.- Return type:
- ``_validate_negatives_sampling``
Validates required columns and parameters, including
num_neg_per_posandneg_feature_range.
- ``visualize_negative_sampling``
Generates a plot showing the negative samples around the base points if
viewis True.
Notes
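If strategy is 'gauge', gauge_data must be provided and contain the columns named in spatial_cols (by default longitude and latitude).
When view is True, circles are drawn to illustrate the buffer radius.
The ratio of 1 degree to roughly 111 km is an approximation. For longitude it holds only near the equator, since a degree of longitude shrinks with the cosine of latitude.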
Formally, a buffer in degrees \(\Delta\) is computed by:
(22)#\[\Delta = \frac{\text{buffer\_km}}{111},\]where \(111\) is an approximate km-per-degree conversion factor. Each base point \((lat, lon)\) spawns \(n\) negatives, each offset by \(\delta_{lat}\), \(\delta_{lon}\) drawn from \(U(-\Delta, \Delta)\).
Examples
Below is an illustration of how to generate negative samples around both existing event locations (strategy='landslide') and separate gauge stations (strategy='gauge') using gen_buffered_negative_samples.
First, we simulate a small DataFrame of positive landslide samples with rainfall attributes, as well as a separate DataFrame for gauge stations:
>>> import numpy as np
>>> import pandas as pd
>>> np.random.seed(42)
>>> positive_samples = pd.DataFrame({
...     'id': [1, 2, 3, 4, 5],
...     'latitude': np.random.uniform(24.0, 25.0, 5),
...     'longitude': np.random.uniform(113.0, 114.0, 5),
...     'rainfall_day_1': np.random.randint(10, 30, 5),
...     'rainfall_day_2': np.random.randint(10, 30, 5),
...     'rainfall_day_3': np.random.randint(10, 30, 5),
...     'rainfall_day_4': np.random.randint(10, 30, 5),
...     'rainfall_day_5': np.random.randint(10, 30, 5),
...     'landslide': [1]*5
... })
>>> gauge_data = pd.DataFrame({
...     'gauge_id': ['G1', 'G2', 'G3'],
...     'latitude': np.random.uniform(24.0, 25.0, 3),
...     'longitude': np.random.uniform(113.0, 114.0, 3)
... })
We then call gen_buffered_negative_samples to produce negatives around these data using two different strategies:
>>> from geoprior.utils.spatial_utils import gen_buffered_negative_samples
>>> # Generate negatives around landslide points
>>> results_landslide = gen_buffered_negative_samples(
...     df=positive_samples,
...     target_col='landslide',
...     spatial_cols=('longitude', 'latitude'),
...     feature_cols=[f'rainfall_day_{i+1}' for i in range(5)],
...     buffer_km=10,
...     num_neg_per_pos=1,
...     strategy='landslide',
...     verbose=1
... )
>>> # Generate negatives around the gauge stations
>>> results_gauge = gen_buffered_negative_samples(
...     df=positive_samples,
...     target_col='landslide',
...     spatial_cols=('longitude', 'latitude'),
...     feature_cols=[f'rainfall_day_{i+1}' for i in range(5)],
...     buffer_km=10,
...     num_neg_per_pos=1,
...     strategy='gauge',
...     gauge_data=gauge_data,
...     verbose=1
... )
See also
gen_negative_samplesGenerate synthetic negative samples for spatial binary classification tasks.
_validate_negatives_samplingEnsures inputs and parameters are correct.
visualize_negative_samplingPlots the positive and negative points for inspection.
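Since the 111 km per degree ratio is only an approximation, it can be useful to see how much a latitude-aware conversion would differ at a given site. The sketch below is not what the documented functions do (they divide by a constant 111); `buffer_degrees` is a hypothetical helper that applies the standard cosine correction to the longitude direction.

```python
import math

KM_PER_DEG_LAT = 111.0  # same constant as in the Notes


def buffer_degrees(buffer_km, latitude_deg):
    """Return (lat_buffer_deg, lon_buffer_deg) for a buffer in km.

    The longitude buffer is widened by 1 / cos(latitude) because a
    degree of longitude covers less ground away from the equator.
    """
    if abs(latitude_deg) >= 90:
        raise ValueError("latitude must be strictly between -90 and 90")
    lat_buffer = buffer_km / KM_PER_DEG_LAT
    lon_buffer = buffer_km / (
        KM_PER_DEG_LAT * math.cos(math.radians(latitude_deg))
    )
    return lat_buffer, lon_buffer


lat_buf_eq, lon_buf_eq = buffer_degrees(10, latitude_deg=0.0)
lat_buf_60, lon_buf_60 = buffer_degrees(10, latitude_deg=60.0)
print(lon_buf_eq / lat_buf_eq, lon_buf_60 / lat_buf_60)
```

At the latitudes used in the examples above (around 24 to 25 degrees), the correction is under 10 percent, which is why the constant ratio is usually acceptable for demo data.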
- geoprior.utils.spatial_utils.gen_negative_samples_plus(df, target_col, spatial_cols=('longitude', 'latitude'), feature_cols=None, buffer_km=10, neg_feature_range=(0, 5), num_neg_per_pos=1, strategy='landslide', gauge_data=None, elevation_data=None, similarity_features=None, time_col=None, cluster_method='kmeans', use_gpd='auto', id_col='auto', view=False, savefile=None, verbose=1, seed=None)[source]
Generate negative samples for spatial modeling, offering multiple strategies. The function calls gen_buffered_negative_samples when the strategy argument is 'landslide', 'event', or 'gauge'. It also calls generate_negative_samples partially in the 'hybrid' strategy. Internally, each sample is augmented to produce negative instances according to the chosen method.
(23)#\[\text{buffer\_deg} = \frac{\text{buffer\_km}}{111.0}\]
The above formula approximates degrees from kilometers near the equator. The output is a combined dataset containing original positives and generated negatives. The ratio of negatives per positive is controlled by num_neg_per_pos.
- Parameters:
df (
pandas.DataFrame) – Input data containing spatial coordinates and features.target_col (
str) – Name of the classification target column. Positive samples in df are labeled, negatives will be generated.spatial_cols (
tupleofstr, optional) – Columns representing longitude and latitude in df.feature_cols (
listofstr, optional) – Additional feature columns used to drive or constrain negative sampling processes.buffer_km (
float, optional) – Radius in kilometers for local negative sampling. Used to compute buffer degrees.neg_feature_range (
tupleoffloat, optional) – Lower and upper range for continuous features in random negative generation.num_neg_per_pos (
int, optional) – Number of negatives to generate per positive instance.strategy (
str, optional) – Defines the sampling approach. Options include 'landslide', 'gauge', 'random_global', 'temporal_shift', 'clustered_negatives', 'environmental_similarity', 'elevation_based', and 'hybrid'. See more in User Guide.
gauge_data (
pandas.DataFrame, optional) – Reference data for gauge-based or hybrid strategies.elevation_data (
pandas.DataFrame, optional) – Elevation records, used ifstrategy='elevation_based'.similarity_features (
listofstr, optional) – Columns for nearest neighbor computation in'environmental_similarity'.time_col (
str, optional) – Name of the time column for'temporal_shift'. Required if using that strategy.cluster_method (
str, optional) – Clustering algorithm for'clustered_negatives'. Default is'kmeans'.use_gpd (
boolorstr, optional) – Indicator for whether geopandas is used in buffer-based processes.id_col (
str, optional) – Column name used as an identifier. If ‘auto’, a default is used.view (
bool, optional) – Flag for visualizing or previewing the results.savefile (
str, optional) – Path to save the output dataset. If None, no file is saved.verbose (
int, optional) – Verbosity level. Higher values yield more logs.seed (
int, optional) – Random seed for reproducibility. If None, randomness is not fixed.
- Returns:
A combined DataFrame containing the original positive samples labeled as 1 and newly generated negative samples labeled 0.
- Return type:
Notes
When strategy='hybrid', partial sets of negatives come from two distinct calls to generate_negative_samples for the 'landslide' and 'gauge' sub-strategies and are then merged.
Examples
>>> from geoprior.utils.spatial_utils import gen_negative_samples_plus
>>> import numpy as np
>>> import pandas as pd
>>> df_example = pd.DataFrame({
...     "longitude": [10.1, 10.2, 10.3],
...     "latitude": [45.1, 45.2, 45.3],
...     "feature": [3.4, 2.1, 6.7],
...     "target": [1, 1, 1]
... })
>>> gauge_data = pd.DataFrame({
...     'gauge_id': ['G1', 'G2', 'G3'],
...     'latitude': np.random.uniform(24.0, 25.0, 3),
...     'longitude': np.random.uniform(113.0, 114.0, 3)
... })
>>> # Generate random global negatives
>>> result = gen_negative_samples_plus(
...     df_example,
...     target_col="target",
...     strategy="random_global"
... )
>>> print(result.head())
See also
gen_buffered_negative_samplesGenerates negative samples within a buffer region around reference events or gauges.
generate_negative_samplesA simpler negative sampling utility for certain strategies.
- geoprior.utils.spatial_utils.extract_spatial_roi(df, x_range, y_range, x_col='longitude', y_col='latitude', snap_to_closest=True, savefile=None, **kwargs)[source]
Extracts a spatial Region of Interest (ROI) from a DataFrame.
This function filters a DataFrame to include only the data points that fall within a specified rectangular bounding box defined by x and y coordinate ranges.
- Parameters:
df (
pd.DataFrame) – The input DataFrame containing the spatial data.x_range (
tupleof(float,float)) – A tuple containing the minimum and maximum desired values for the x-coordinate (e.g., longitude). The order does not matter.y_range (
tupleof(float,float)) – A tuple containing the minimum and maximum desired values for the y-coordinate (e.g., latitude). The order does not matter.x_col (
str, default'longitude') – The name of the column in df that contains the x-coordinates.y_col (
str, default'latitude') – The name of the column in df that contains the y-coordinates.snap_to_closest (
bool, defaultTrue) – If True, and a value in x_range or y_range does not exist in the data, the function will “snap” to the nearest available coordinate in the dataset. If False, it will use the exact boundaries provided.savefile (
str, optional) – The path to save the resulting DataFrame as a CSV file.
- Returns:
A new DataFrame containing only the rows that fall within the specified spatial bounding box.
- Return type:
pd.DataFrame- Raises:
ValueError – If x_col or y_col are not found in the DataFrame, or if the range tuples are not provided correctly.
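The bounding-box logic, including order-insensitive ranges and the snap_to_closest behavior, can be sketched on plain dictionaries. This is one possible reading of the parameter descriptions, not the library's code: `extract_roi` is a hypothetical name, and snapping each bound to the single nearest coordinate in the data is an assumption.

```python
def extract_roi(rows, x_range, y_range, x_col="longitude",
                y_col="latitude", snap_to_closest=True):
    """Keep rows inside the rectangle defined by x_range and y_range."""

    def bounds(values, value_range):
        low, high = sorted(value_range)  # order of the tuple does not matter
        if snap_to_closest:
            # Assumed behavior: move each bound to the nearest data value.
            low = min(values, key=lambda v: abs(v - low))
            high = min(values, key=lambda v: abs(v - high))
        return low, high

    if not rows:
        return []
    x_low, x_high = bounds([r[x_col] for r in rows], x_range)
    y_low, y_high = bounds([r[y_col] for r in rows], y_range)
    return [
        r for r in rows
        if x_low <= r[x_col] <= x_high and y_low <= r[y_col] <= y_high
    ]


rows = [
    {"longitude": 113.0, "latitude": 22.0},
    {"longitude": 113.1, "latitude": 22.1},
    {"longitude": 113.2, "latitude": 22.2},
    {"longitude": 113.3, "latitude": 22.3},
]
snapped = extract_roi(rows, x_range=(113.26, 113.04), y_range=(22.0, 22.3))
exact = extract_roi(rows, x_range=(113.26, 113.04), y_range=(22.0, 22.3),
                    snap_to_closest=False)
print(len(snapped), len(exact))
```

Under this reading, snapping can widen the box (here 113.04 snaps outward to 113.0), so the snapped result may contain more rows than the exact one.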
- geoprior.utils.spatial_utils.make_forecast_ready_sample(data, sample_size=0.05, time_col='year', spatial_cols=None, group_cols=None, stratify_by=None, spatial_bins=10, time_steps=3, forecast_horizon=1, require_consecutive=True, keep_years=None, year_mode='latest', min_groups=5, max_groups=None, columns_to_keep=None, method='abs', min_relative_ratio=0.01, random_state=42, export_path=None, export_format=None, savefile=None, sort_output=True, verbose=1)[source]
Build a compact, forecast-ready panel sample.
The function samples spatial groups rather than individual rows, then reconstructs the full panel for the selected groups. This is much safer for demo/testing than row-wise sampling because each sampled location keeps enough temporal history for sequence construction.
- Parameters:
data (
pandas.DataFrame) – Input panel DataFrame.sample_size (
floatorint, default0.05) – Group-level sample size passed to the internal spatial sampler. - float: fraction of eligible groups - int: absolute number of eligible groupstime_col (
str, default'year') – Time column.spatial_cols (
tuple/listofstr, optional) – Spatial coordinate columns. If None, the function searches for ‘longitude’ and ‘latitude’.group_cols (
tuple/listofstr, optional) – Group identifier columns. If None, uses spatial_cols.stratify_by (
list/tupleofstr, optional) – Extra group-level columns used for stratification. Typical examples: [‘lithology_class’] or [‘city’, ‘lithology_class’].spatial_bins (
intortuple/list, default10) – Spatial bins passed to spatial_sampling().time_steps (
int, default3) – Lookback window length.forecast_horizon (
int, default1) – Forecast horizon length.require_consecutive (
bool, defaultTrue) – If True, each kept group must contain at least one consecutive run of length time_steps + forecast_horizon.keep_years (
int, optional) – If provided, keep only this many years per group after sampling. Must be >= time_steps + forecast_horizon.year_mode (
{'all', 'latest', 'earliest', 'random'}) – How to trim years when keep_years is given.min_groups (
int, default5) – Minimum number of eligible groups required.max_groups (
int, optional) – Hard cap on sampled groups after spatial sampling.columns_to_keep (
list/tupleofstr, optional) – Restrict the returned columns. Group and time columns are always preserved.method (
{'abs', 'absolute', 'relative'}, default'abs') – Same meaning as in spatial_sampling().min_relative_ratio (
float, default0.01) – Minimum relative sampling ratio when method=’relative’.random_state (
int, optional) – Random seed.export_path (
str, optional) – Explicit path for non-CSV export.export_format (
str, optional) – Export format for export_path. If None, inferred from suffix.savefile (
str, optional) – CSV save path handled by @SaveFile.sort_output (
bool, defaultTrue) – Sort final output by group and time.verbose (
int, default1) – Verbosity level.
- Returns:
A compact panel sample that preserves group-wise temporal structure for forecast demos/tests.
- Return type:
Examples
>>> from geoprior.utils.spatial_utils import make_forecast_ready_sample
>>> # Small demo sample, latest 4 years per location
>>> demo = make_forecast_ready_sample(
...     data=df,
...     sample_size=0.05,
...     stratify_by=["lithology_class"],
...     spatial_cols=("longitude", "latitude"),
...     time_col="year",
...     time_steps=3,
...     forecast_horizon=1,
...     keep_years=4,
...     year_mode="latest",
...     require_consecutive=True,
...     savefile="demo_panel.csv",
... )
>>> # Slightly richer test panel, keep all years, export parquet
>>> demo = make_forecast_ready_sample(
...     data=df,
...     sample_size=150,
...     stratify_by=["city", "lithology_class"],
...     spatial_cols=("longitude", "latitude"),
...     time_col="year",
...     time_steps=3,
...     forecast_horizon=1,
...     keep_years=None,
...     export_path="demo_panel.parquet",
...     export_format="parquet",
... )
Together, these modules support spatial augmentation, sampling, clustering, dummy data generation, and degree-to-meter conversion, which are all important in Stage-1 construction, spatial forecasting, and map-oriented analysis.
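The require_consecutive rule of make_forecast_ready_sample can be illustrated with a small standalone check. This is a minimal sketch, not the library's implementation: the helper names longest_consecutive_run and is_forecast_ready are hypothetical, and it assumes integer annual time steps.

```python
from typing import Iterable


def longest_consecutive_run(years: Iterable[int]) -> int:
    """Return the length of the longest run of consecutive integer years."""
    unique = sorted(set(years))
    if not unique:
        return 0
    best = current = 1
    for prev, cur in zip(unique, unique[1:]):
        # Extend the run when years are adjacent, otherwise restart it.
        current = current + 1 if cur == prev + 1 else 1
        best = max(best, current)
    return best


def is_forecast_ready(
    years: Iterable[int], time_steps: int = 3, forecast_horizon: int = 1
) -> bool:
    """True if a group has a run of at least time_steps + forecast_horizon years."""
    return longest_consecutive_run(years) >= time_steps + forecast_horizon


print(is_forecast_ready([2015, 2016, 2017, 2018]))  # unbroken 4-year run
print(is_forecast_ready([2015, 2016, 2018, 2019]))  # gap at 2017
```

A group with a gap inside its only candidate window fails the check even when it has enough years in total, which is why keep_years must be at least time_steps + forecast_horizon.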
Holdout helpers#
Utility helpers for holdout and split workflows.
- class geoprior.utils.holdout_utils.GroupMasks(required_train_years, required_forecast_years, valid_for_train, valid_for_forecast)[source]
Bases:
objectGroup-level validity masks for early filtering.
- Parameters:
- valid_for_train: DataFrame
- valid_for_forecast: DataFrame
- property keep_for_processing: DataFrame
Union(valid_for_train, valid_for_forecast).
- geoprior.utils.holdout_utils.compute_group_masks(df, *, group_cols, time_col, train_end_year, time_steps, horizon)[source]
- Build:
valid_for_train: groups containing every year of the last T + H years (T = time_steps, H = horizon)
valid_for_forecast: groups containing every year of the last T years
This assumes annual steps and integer years in time_col.
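The two masks can be sketched with plain sets. This is an illustrative sketch only: sketch_group_masks is a hypothetical helper, and it assumes that both required windows end at train_end_year, which the description above does not state explicitly.

```python
from collections import defaultdict


def sketch_group_masks(rows, train_end_year, time_steps, horizon):
    """rows: iterable of (group_id, year) pairs with integer years.

    Returns (valid_for_train, valid_for_forecast, keep_for_processing)
    as sets of group ids.
    """
    years_by_group = defaultdict(set)
    for group_id, year in rows:
        years_by_group[group_id].add(int(year))

    # Assumption: both windows end at train_end_year (inclusive).
    train_years = set(
        range(train_end_year - (time_steps + horizon) + 1, train_end_year + 1)
    )
    forecast_years = set(range(train_end_year - time_steps + 1, train_end_year + 1))

    valid_train = {g for g, ys in years_by_group.items() if train_years <= ys}
    valid_forecast = {g for g, ys in years_by_group.items() if forecast_years <= ys}
    # Mirrors GroupMasks.keep_for_processing: the union of both masks.
    return valid_train, valid_forecast, valid_train | valid_forecast


rows = (
    [("A", y) for y in range(2017, 2021)]
    + [("B", y) for y in range(2018, 2021)]
    + [("C", y) for y in (2019, 2020)]
)
train, forecast, keep = sketch_group_masks(
    rows, train_end_year=2020, time_steps=3, horizon=1
)
print(sorted(train), sorted(forecast), sorted(keep))
```

In this toy panel, group B has enough history for forecasting but not for training, so it is kept for processing only through the union.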
- geoprior.utils.holdout_utils.filter_df_by_groups(df, *, group_cols, groups)[source]
Keep only rows in df whose (group_cols) exist in groups DataFrame.
- class geoprior.utils.holdout_utils.HoldoutSplit(train_groups, val_groups, test_groups)[source]
Bases:
objectPixel holdout split (disjoint groups).
- train_groups: DataFrame
- val_groups: DataFrame
- test_groups: DataFrame
- check_disjoint()[source]
- Return type:
None
- geoprior.utils.holdout_utils.split_groups_holdout(groups, *, seed=42, val_frac=0.2, test_frac=0.1, strategy='random', x_col=None, y_col=None, block_size=None)[source]
Split unique groups into train/val/test (pixel-level holdout).
- strategy:
“random”: shuffle groups
“spatial_block”: shuffle spatial blocks (needs x_col, y_col, block_size)
These utilities are especially relevant when the workflow needs group-aware, spatially constrained, or otherwise structured holdout logic instead of purely random splitting.
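The difference between the two strategies can be sketched in pure Python. This is a hedged illustration, not the code behind split_groups_holdout: split_random and split_spatial_block are hypothetical names, groups are reduced to (x, y) tuples, and the rounding of split sizes is an assumption.

```python
import random


def split_random(groups, seed=42, val_frac=0.2, test_frac=0.1):
    """Shuffle items and cut them into disjoint train/val/test lists."""
    shuffled = list(groups)
    random.Random(seed).shuffle(shuffled)
    n_test = int(round(len(shuffled) * test_frac))
    n_val = int(round(len(shuffled) * val_frac))
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


def split_spatial_block(points, block_size, seed=42, val_frac=0.2, test_frac=0.1):
    """Assign whole spatial blocks, not single points, to each split."""
    blocks = {}
    for x, y in points:
        key = (int(x // block_size), int(y // block_size))
        blocks.setdefault(key, []).append((x, y))
    # Shuffle block keys, then expand each block back into its points.
    train_b, val_b, test_b = split_random(sorted(blocks), seed, val_frac, test_frac)

    def expand(keys):
        return [p for k in keys for p in blocks[k]]

    return expand(train_b), expand(val_b), expand(test_b)


points = [(x, y) for x in range(10) for y in range(10)]
train, val, test = split_spatial_block(points, block_size=2.0)
print(len(train), len(val), len(test))
```

Because neighbouring pixels share a block, the spatial strategy keeps them in the same split, which reduces leakage from spatial autocorrelation compared with a purely random shuffle.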
IO, system, and runtime helpers#
Input/Output utilities for managing file paths, directories, and loading serialized data within FusionLab. Provides error-checked deserialization, directory management, and archive handling (e.g., .tgz, .zip), streamlining file operations and data recovery.
Adapted for FusionLab from the original geoprior.utils.io_utils.
- class geoprior.utils.io_utils.FileManager(root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False)[source]
Bases:
BaseClassA class for managing and organizing files within a directory structure. This class provides methods to filter, organize, and rename files in bulk based on file extensions and name patterns. All operations are executed via the
runmethod to ensure proper initialization and state management.Mathematically, if \(\mathcal{F}\) represents the set of files in the root directory and \(\phi(f)\) is a filtering function that selects files based on file type and name pattern, then the FileManager produces a subset
(24)#\[\mathcal{F}' = \{ f \in \mathcal{F} \mid \phi(f) \}\]and performs operations such as moving or copying to reorganize these files into a target directory.
- Parameters:
root_dir (
str) – The root directory containing the files to be managed. This directory must exist and contain the files subject to filtering.target_dir (
str) – The directory where the organized files will be placed. If necessary, this directory can be created whencreate_dirsis True.file_types (
listofstr, optional) – A list of file extensions (e.g.,['.csv', '.json']) used to filter the files. IfNone, no file type filtering is applied.name_patterns (
listofstr, optional) – A list of substrings (e.g.,['2023', 'report']) to filter file names. IfNone, all file names are included.move (
bool, optional) – If True, files are moved from the source to the target directory; otherwise, they are copied. Default is False.overwrite (
bool, optional) – If True, existing files in the target directory will be overwritten. If False, existing files are skipped. Default is False.create_dirs (
bool, optional) – If True, missing directories in the target path are created. Default is False.
- Variables:
- run(pattern, replacement)[source]
Executes the file organization process. It filters files using the criteria provided at initialization and, if a pattern and corresponding replacement are given, performs bulk renaming.
- get_processed_files()[source]
Returns a list of file paths that have been processed and organized into the target directory.
Examples
>>> from geoprior.utils.io_utils import FileManager >>> manager = FileManager( ... root_dir='data/raw', ... target_dir='data/processed', ... file_types=['.csv', '.json'], ... name_patterns=['2023', 'report'], ... move=True, ... overwrite=True, ... create_dirs=True ... ) >>> manager.run(pattern='old', replacement='new') >>> processed = manager.get_processed_files() >>> print(processed)
Notes
The public method
runorchestrates the file management operations by first calling the internal method_organize_files()to filter and move or copy files from the source directory to the target directory. If renaming is needed,_rename_files()is invoked with the specified pattern and replacement. The methodget_processed_files()compiles a list of all files that have been organized, based on a walk of the target directory. The directory traversal and file-operation APIs are documented in [29, 30].See also
shutil.moveTo move files between directories.
shutil.copy2To copy files while preserving file metadata.
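The filtering function φ that defines the subset F' can be sketched as a predicate over paths. This is a minimal sketch under stated assumptions: phi and select_files are hypothetical names, matching is case-sensitive, a file passes if any name pattern is a substring of its name, and the search is recursive. The real FileManager may differ on each of these points.

```python
from pathlib import Path


def phi(path, file_types=None, name_patterns=None):
    """Return True if the path passes both the extension and name filters."""
    ok_type = file_types is None or path.suffix in file_types
    ok_name = name_patterns is None or any(p in path.name for p in name_patterns)
    return ok_type and ok_name


def select_files(root_dir, file_types=None, name_patterns=None):
    """Collect the files under root_dir that satisfy phi."""
    return sorted(
        p
        for p in Path(root_dir).rglob("*")
        if p.is_file() and phi(p, file_types, name_patterns)
    )


print(phi(Path("report_2023.csv"), [".csv", ".json"], ["2023", "report"]))
print(phi(Path("notes.txt"), [".csv", ".json"]))
```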
- __init__(root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False)[source]
Initialize the base class.
- run(pattern=None, replacement=None)[source]
Executes file organization operations.
This method filters files based on the specified file types and name patterns, then organizes them by moving or copying into the target directory. Additionally, if a pattern is provided, file names containing that pattern are renamed by replacing the pattern with the specified replacement.
- Parameters:
- Returns:
self – The instance itself after executing operations.
- Return type:
Examples
>>> manager = FileManager(...) >>> manager.run(pattern='old', replacement='new')
- get_processed_files()[source]
Retrieves a list of processed files in the target directory.
- Returns:
files – A list containing the full paths of the files that have been organized into the target directory.
- Return type:
Examples
>>> manager = FileManager(...) >>> manager.run() >>> files = manager.get_processed_files() >>> print(files)
- fit(pattern=None, replacement=None)
Executes file organization operations.
This method filters files based on the specified file types and name patterns, then organizes them by moving or copying into the target directory. Additionally, if a pattern is provided, file names containing that pattern are renamed by replacing the pattern with the specified replacement.
- Parameters:
- Returns:
self – The instance itself after executing operations.
- Return type:
Examples
>>> manager = FileManager(...) >>> manager.run(pattern='old', replacement='new')
- help(**kwargs)
- my_params = FileManager( root_dir, target_dir, file_types=None, name_patterns=None, move=False, overwrite=False, create_dirs=False )
- geoprior.utils.io_utils.cpath(savepath=None, dpath='_default_path_')[source]
Ensures a directory exists for saving files, creating it if necessary.
- Parameters:
- Returns:
The absolute path to the validated or created directory.
- Return type:
Examples
>>> from geoprior.utils.io_utils import cpath >>> default_path = cpath() >>> print(f"Files will be saved to: {default_path}")
>>> custom_path = cpath('/path/to/save') >>> print(f"Files will be saved to: {custom_path}")
Notes
cpath validates the directory path and, if necessary, creates the directory tree. If a problem occurs during creation, an error message is printed.
See also
pathlib.Path.mkdirUtility for directory creation.
- geoprior.utils.io_utils.deserialize_data(filename, verbose=0)[source]
Deserialize and load data from a serialized file using joblib or pickle.
The function attempts to load the serialized data from the provided file filename using joblib first. If joblib fails, it tries to load the data using pickle. An error is raised if both methods fail.
- Parameters:
- Returns:
The data loaded from the serialized file, or None if loading fails.
- Return type:
Any- Raises:
TypeError – If filename is not a string, as file paths must be provided as strings.
FileNotFoundError – If the specified filename does not exist or cannot be located.
IOError – If both joblib and pickle fail to deserialize the data from the file.
ValueError – If the file was successfully read but yielded no data (i.e., None).
Examples
>>> from geoprior.utils.io_utils import deserialize_data >>> data = deserialize_data('path/to/serialized_data.pkl', verbose=1) Data loaded successfully from 'path/to/serialized_data.pkl' using joblib.
Notes
The function first attempts deserialization with joblib to leverage efficient file handling for large datasets. If joblib encounters an error, it falls back to pickle, which provides broader compatibility with Python objects but may be less optimized for large datasets. Loader semantics for the two backends are documented in [31, 32].
See also
joblib.loadJoblib’s load function for fast I/O operations on large data.
pickle.loadPickle’s load function for serializing and deserializing Python objects.
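The joblib-then-pickle fallback described in the Notes can be sketched as follows. This is an illustration rather than the body of deserialize_data: load_with_fallback is a hypothetical name, joblib is treated as optional, and the exception mapping simply follows the Raises list above.

```python
import os
import pickle
import tempfile


def load_with_fallback(filename):
    """Try joblib first, then pickle; raise if neither can read the file."""
    if not isinstance(filename, str):
        raise TypeError("filename must be a string")
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)

    data = None
    try:
        import joblib  # optional dependency; a failed import triggers the fallback

        data = joblib.load(filename)
    except Exception:
        try:
            with open(filename, "rb") as fh:
                data = pickle.load(fh)
        except Exception as exc:
            raise IOError(f"could not deserialize {filename!r}") from exc

    if data is None:
        raise ValueError(f"{filename!r} was read but contained no data")
    return data


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.pkl")
    with open(path, "wb") as fh:
        pickle.dump({"a": 1}, fh)
    restored = load_with_fallback(path)

print(restored)
```

Both loaders can execute arbitrary code while unpickling, so this pattern is only appropriate for files from a trusted source.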
- geoprior.utils.io_utils.extract_tar_with_progress(tar, member, path)[source]
Extracts a single file from a tar archive with a progress bar.
- Parameters:
tar (
tarfile.TarFile) – Opened tar file object.member (
tarfile.TarInfo) – Tar member (file) to be extracted.path (
Path) – Directory path where the file will be extracted.
Examples
>>> import tarfile
>>> from pathlib import Path
>>> from geoprior.utils.io_utils import extract_tar_with_progress
>>> with tarfile.open('data.tar.gz', 'r:gz') as tar:
...     member = tar.getmember('file.csv')
...     extract_tar_with_progress(tar, member, Path('output_dir'))
Notes
Uses tqdm for progress tracking of the file extraction process.
- geoprior.utils.io_utils.fetch_tgz_from_url(data_url, tgz_filename, data_path=None, file_to_retrieve=None, **kwargs)[source]
Downloads a .tgz file from a specified URL, saves it to a directory, and optionally extracts a specific file from the archive.
This function retrieves a .tgz file from the provided data_url and saves it to the specified data_path directory. If file_to_retrieve is specified, the function will extract only that file from the archive; otherwise, the entire archive will be extracted.
- Parameters:
data_url (
str) – The URL to download the .tgz file from.tgz_filename (
str) – The name to assign to the downloaded .tgz file.data_path (
Union[str,Path], optional) – Directory where the downloaded file will be saved. Defaults to a ‘tgz_data’ directory in the current working directory if not specified.file_to_retrieve (
str, optional) – Specific filename to extract from the .tgz archive. If not provided, the entire archive is extracted.**kwargs (
dict) – Additional keyword arguments to pass to the extraction method.
- Returns:
Path to the extracted file if a specific file was requested; otherwise, returns None.
- Return type:
Optional[Path]- Raises:
FileNotFoundError – If the specified file_to_retrieve is not found in the archive.
Examples
>>> from geoprior.utils.io_utils import fetch_tgz_from_url >>> data_url = 'https://example.com/data.tar.gz' >>> extracted_file = fetch_tgz_from_url( ... data_url, 'data.tar.gz', data_path='data_dir', file_to_retrieve='file.csv') >>> print(extracted_file)
Notes
Uses the tqdm progress bar for tracking download progress.
- geoprior.utils.io_utils.fetch_tgz_locally(tgz_file, filename, savefile='tgz', rename_outfile=None)[source]
Extracts a specific file from a local .tgz archive and optionally renames it.
This function fetches a specific file filename from a local tar archive located at tgz_file, and saves it to savefile. If rename_outfile is specified, the file is renamed after extraction.
- Parameters:
tgz_file (
str) – Full path to the tar file.filename (
str) – Name of the target file to extract from the archive.savefile (
str, optional) – Destination directory for the extracted file, defaulting to ‘tgz’.rename_outfile (
str, optional) – New name for the fetched file. If not provided, retains the original name.
- Returns:
Full path to the fetched and possibly renamed file.
- Return type:
- Raises:
FileNotFoundError – If the tgz_file or the specified filename is not found.
Examples
>>> from geoprior.utils.io_utils import fetch_tgz_locally >>> fetched_file = fetch_tgz_locally( ... 'path/to/archive.tgz', 'file.csv', savefile='extracted', rename_outfile='renamed.csv') >>> print(fetched_file)
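Extracting one member from a local archive and optionally renaming it can be sketched with the standard tarfile module. This is a hedged sketch, not the implementation of fetch_tgz_locally: extract_member is a hypothetical name, and it copies the member through extractfile instead of extracting to disk directly.

```python
import io
import os
import shutil
import tarfile
import tempfile
from pathlib import Path


def extract_member(tgz_file, filename, dest_dir, rename_outfile=None):
    """Copy a single archive member into dest_dir and return its path."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tgz_file, "r:*") as tar:
        try:
            member = tar.getmember(filename)
        except KeyError:
            raise FileNotFoundError(f"{filename!r} not found in {tgz_file!r}") from None
        source = tar.extractfile(member)
        if source is None:  # directories and special members have no file object
            raise FileNotFoundError(f"{filename!r} is not a regular file")
        out = dest / (rename_outfile or Path(filename).name)
        with source, open(out, "wb") as fh:
            shutil.copyfileobj(source, fh)
    return str(out)


# Build a tiny archive in a temporary directory so the demo is self-contained.
with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "archive.tgz")
    payload = b"a,b\n1,2\n"
    with tarfile.open(archive, "w:gz") as tar:
        info = tarfile.TarInfo("file.csv")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))
    fetched = extract_member(
        archive, "file.csv", os.path.join(tmp, "extracted"), "renamed.csv"
    )
    fetched_name = Path(fetched).name
    fetched_bytes = Path(fetched).read_bytes()

print(fetched_name)
```

Writing the member under a name chosen by the caller avoids trusting paths stored inside the archive.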
- geoprior.utils.io_utils.dummy_csv_translator(csv_fn, pf, delimiter=':', destfile='pme.en.csv')[source]
Translate a CSV file using a dictionary created from a markdown-style parser file.
- Parameters:
csv_fn (
str) – Path to the source CSV file.pf (
str) – Path to the markdown-style file used to create the translation dictionary.delimiter (
str, default':') – Delimiter used in the parser file to separate key-value pairs.destfile (
str, default'pme.en.csv') – Name of the destination file for the translated CSV.
- Returns:
DataFrame– Translated CSV data as a DataFrame.list– List of untranslated terms found in the source CSV.
Notes
This function uses parse_md to read the parser file and apply translations to the CSV content.
Missing translations are collected and returned for review.
Examples
>>> from geoprior.utils.io_utils import dummy_csv_translator
>>> df, missing = dummy_csv_translator(
...     "data.csv", "parser_file.md", delimiter=":", destfile="output.csv")
>>> print(df.head())
>>> print(missing)
- geoprior.utils.io_utils.fetch_json_data_from_url(url, todo='load')[source]
Retrieve and parse JSON data from a URL.
- Parameters:
url (
str) – Uniform Resource Locator (URL) from which JSON data is fetched.todo (
{'load', 'dump'}, default'load') – Action to perform with JSON: - ‘load’: Load JSON data from the URL. - ‘dump’: Parse and prepare data from the URL for saving in a JSON file.
- Returns:
A tuple of todo action, filename (or data source), and parsed data.
- Return type:
- Raises:
urllib.error.URLError – If there is an issue accessing the URL.
Notes
The function uses json.loads to parse data directly from a URL response, supporting convenient access to web-hosted JSON content.
- geoprior.utils.io_utils.get_config_fname_from_varname(data, config_fname=None, config='.yml')[source]
Generate a filename based on a variable name for YAML configuration.
- Parameters:
data (
Any) – The data object from which the variable name will be derived to create a YAML configuration filename.config_fname (
str, optional) – Custom configuration filename. If None, the name of data will be used as the filename.config (
str, default'.yml') – The file extension/type for the configuration file. Can be ‘.yml’, ‘.json’, or ‘.csv’.
- Returns:
A suitable filename for saving the configuration data.
- Return type:
- Raises:
ValueError – If config_fname cannot be derived or an invalid file type is provided.
Notes
This function supports dynamic filename generation based on variable names, which aids in maintaining a clear configuration structure for serialized data. Files are saved with appropriate extensions based on the config type.
- geoprior.utils.io_utils.get_valid_key(input_key, default_key, substitute_key_dict=None, regex_pattern='[#&*@!,;\\s]\\s*', deep_search=True)[source]
Validates an input key and substitutes it with a valid key if necessary, based on a mapping of valid keys to their possible substitutes. If the input key is not provided or is invalid, a default key is used.
- Parameters:
input_key (
str) – The key to validate and possibly substitute.default_key (
str) – The default key to use if input_key is None, empty, or not found in the substitute mapping.substitute_key_dict (
dict, optional) – A mapping of valid keys to lists of their possible substitutes. This allows for flexible key substitution and validation.regex_pattern (
str, default'[#&*@!,;\s]\s*') – The base pattern used to split the text into columns.deep_search (
bool, defaultTrue) – If True, key matching is insensitive to letter case and to whether numeric characters are included.
- Returns:
A valid key, which is either the original input_key if valid, a substituted key if the original was found in the substitute mappings, or the default_key.
- Return type:
Notes
This function also leverages an external validation through key_checker for a deep search validation, ensuring the returned key is within the set of valid keys.
Example
>>> from geoprior.utils.io_utils import get_valid_key >>> substitute_key_dict = {'valid_key1': ['vk1', 'key1'], 'valid_key2': ['vk2', 'key2']} >>> get_valid_key('vk1', 'default_key', substitute_key_dict) 'valid_key1' >>> get_valid_key('unknown_key', 'default_key', substitute_key_dict) 'KeyError...'
- geoprior.utils.io_utils.key_checker(keys, valid_keys, regex=None, pattern=None, deep_search=False)[source]
Check whether a given key exists in valid_keys, and return a list if several keys are found.
- Parameters:
keys (
strorlistofstr) – Key or keys to look for in valid_keys.valid_keys (
list) – List of valid keys.regex (re object) –
Regular expression object. The default is:
>>> import re >>> re.compile (r'[_#&*@!_,;\s-]\s*', flags=re.IGNORECASE)
pattern (
str, default'[_#&*@!_,;\s-]\s*') – The base pattern used to split the text into columns.deep_search (
bool, defaultFalse) – If True, key matching is insensitive to letter case and to whether numeric characters are included.
- Returns:
keys – The key, or list of keys, that exist in valid_keys.
- Return type:
strorlist
Examples
>>> from geoprior.utils.io_utils import key_checker >>> key_checker('h502', valid_keys= ['h502', 'h253','h2601']) Out[68]: 'h502' >>> key_checker('h502+h2601', valid_keys= ['h502', 'h253','h2601']) Out[69]: ['h502', 'h2601'] >>> key_checker('h502 h2601', valid_keys= ['h502', 'h253','h2601']) Out[70]: ['h502', 'h2601'] >>> key_checker(['h502', 'h2601'], valid_keys= ['h502', 'h253','h2601']) Out[73]: ['h502', 'h2601'] >>> key_checker(['h502', 'h2602'], valid_keys= ['h502', 'h253','h2601']) UserWarning: key 'h2602' is missing in ['h502', 'h2602'] Out[82]: 'h502' >>> key_checker(['502', 'H2601'], valid_keys= ['h502', 'h253','h2601'], deep_search=True ) Out[57]: ['h502', 'h2601']
- geoprior.utils.io_utils.key_search(keys, default_keys, parse_keys=True, regex=None, pattern=None, deep=Ellipsis, raise_exception=Ellipsis)[source]
Find key in a list of default keys and select the best match.
- Parameters:
keys (
strorlist) – A key string or a list of keys. When several keys are passed as one string, separate them with spaces.default_keys (
strorlist) – The candidate keys to search. This can be literal text; in that case, it is better to provide regex so that unwanted characters are skipped and the text is split properly.parse_keys (
bool, defaultTrue) – Parse literal strings using the default pattern and regex.
Added in version 0.2.7.
regex (re object,) –
Regular expression object. The regex determines what kind of data is parsed. The default is:
>>> import re >>> re.compile (r'[_#&*@!_,;\s-]\s*', flags=re.IGNORECASE)
pattern (
str, default'[_#&*@!_,;\s-]\s*') – The base pattern used to split the text into columns. The pattern matters when a character is part of a word rather than a separator. For example, for a data column named ‘DH_Azimuth’, the default pattern splits the name into two separate words unless a pattern is explicitly provided, which is usually not the expected result.raise_exception (
bool, defaultFalse) – Raise an error when the key is not found.
- Returns:
List of valid keys, or None if no key is found (default).
- Return type:
listorNone
Examples
>>> from geoprior.utils.io_utils import key_search
>>> key_search('h502-hh2601', default_keys=['h502', 'h253', 'HH2601'])
Out[44]: ['h502']
>>> key_search('h502-hh2601', default_keys=['h502', 'h253', 'HH2601'], deep=True)
Out[46]: ['h502', 'HH2601']
>>> key_search('253', default_keys=("I m here to find key among h502, h253 and HH2601"))
Out[53]: ['h253']
>>> key_search('east', default_keys=['DH_East', 'DH_North'], deep=True)
Out[37]: ['East']
>>> key_search('east', default_keys=['DH_East', 'DH_North'], deep=True, parse_keys=False)
Out[39]: ['DH_East']
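The effect of the split pattern on names such as ‘DH_Azimuth’ can be checked directly with the re module. The sketch below is not the key_search algorithm: split_keys and match_keys are hypothetical helpers, and the "deep" rule used here (case-insensitive substring match) is an assumption chosen only to make the idea concrete.

```python
import re

# Default-style pattern: underscore, punctuation, whitespace and hyphen separate keys.
DEFAULT_SPLIT = re.compile(r"[_#&*@!,;\s-]\s*", flags=re.IGNORECASE)
# Variant without the underscore, so 'DH_Azimuth' stays a single token.
NO_UNDERSCORE = re.compile(r"[#&*@!,;\s-]\s*", flags=re.IGNORECASE)


def split_keys(text, regex=DEFAULT_SPLIT):
    """Split a key string into tokens, dropping empty pieces."""
    return [tok for tok in regex.split(text) if tok]


def match_keys(keys, valid_keys, deep=False, regex=DEFAULT_SPLIT):
    """Return the valid keys matched by the tokens in keys, or None."""
    found = []
    for key in split_keys(keys, regex):
        for valid in valid_keys:
            # Assumption: 'deep' means case-insensitive substring matching.
            hit = key.lower() in valid.lower() if deep else key == valid
            if hit and valid not in found:
                found.append(valid)
    return found or None


print(split_keys("DH_Azimuth"))
print(split_keys("DH_Azimuth", NO_UNDERSCORE))
print(match_keys("h502-hh2601", ["h502", "h253", "HH2601"], deep=True))
```

Supplying a pattern without the underscore is the kind of explicit override the pattern parameter is meant for when column names contain separator-like characters.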
- geoprior.utils.io_utils.load_serialized_data(filename, verbose=0)[source]
Load data from a serialized file (e.g., pickle or joblib format).
- Parameters:
- Returns:
Data loaded from the file, or None if deserialization fails.
- Return type:
Any- Raises:
TypeError – If filename is not a string.
FileExistsError – If the specified file does not exist.
Examples
>>> from geoprior.utils.io_utils import load_serialized_data >>> data = load_serialized_data('data/my_data.pkl', verbose=3)
Notes
This function attempts to load serialized data using joblib and falls back to pickle if needed. Verbose output provides feedback on the loading process and on the success or failure of each step.
See also
joblib.loadHigh-performance loading utility.
pickle.loadGeneral-purpose Python serialization library.
- geoprior.utils.io_utils.load_csv(data_path, delimiter=',', **kwargs)[source]
Loads a CSV file into a pandas DataFrame.
This function reads a comma-separated values (CSV) file into a pandas DataFrame, with the ability to specify a custom delimiter. It provides support for additional options passed to pandas.read_csv for more granular control over the data loading process.
- Parameters:
data_path (
str) – The file path to the CSV file that is to be loaded. The file path must lead to a .csv file. If the file does not exist at the specified path, a FileNotFoundError is raised.delimiter (
str, optional) – The character used to separate values in the CSV file. The default is , for standard CSVs. If a different delimiter is used in the file (e.g., ;), it can be specified here.**kwargs (
dict) – Additional keyword arguments that will be passed directly to pandas.read_csv. For instance, users can specify header, index_col, dtype, and other options supported by read_csv for more customized data handling.
- Returns:
A pandas DataFrame containing the loaded data, with the specified options applied.
- Return type:
DataFrame- Raises:
FileNotFoundError – If the specified file does not exist at the provided data_path.
ValueError – If the file specified by data_path is not a CSV file (i.e., does not have a .csv extension), a ValueError is raised to ensure correct file type.
Notes
This function simplifies the process of loading CSV data into a DataFrame, with a straightforward parameter for delimiter customization and full access to pandas.read_csv options. It is ideal for basic CSV loading tasks, as well as more complex ones requiring specific column handling, type casting, and missing value handling, which can be passed via **kwargs. CSV-oriented DataFrame loading patterns are discussed in McKinney [27].
Examples
Suppose you have a CSV file example.csv with the following content:
name,age,city
Alice,30,New York
Bob,25,Los Angeles

To load this file into a DataFrame:
>>> from geoprior.utils.io_utils import load_csv >>> df = load_csv('example.csv') >>> print(df) name age city 0 Alice 30 New York 1 Bob 25 Los Angeles
If the file uses a semicolon (;) as the delimiter:
>>> df = load_csv('example.csv', delimiter=';')
Additionally, you can pass custom read_csv parameters through **kwargs, such as specifying a column as the index:
>>> df = load_csv('example.csv', index_col='name') >>> print(df) age city name Alice 30 New York Bob 25 Los Angeles
See also
pandas.read_csvFull documentation for loading CSV files into a DataFrame with detailed parameter options.
- geoprior.utils.io_utils.move_cfile(cfile, savepath=None, **ckws)[source]
Moves a file to the specified path. If moving fails, copies and deletes the original.
- Parameters:
- Returns:
The new file path and a confirmation message.
- Return type:
Tuple[str,str]
Examples
>>> from geoprior.utils.io_utils import move_cfile >>> new_path, msg = move_cfile('myfile.txt', 'new_directory') >>> print(new_path, msg)
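The move-then-copy fallback can be sketched with the standard library. This is a minimal sketch, not the body of move_cfile: move_or_copy is a hypothetical name, it returns only the new path, and it assumes that a failed rename surfaces as OSError (for example when source and target are on different filesystems).

```python
import os
import shutil
import tempfile


def move_or_copy(cfile, savepath):
    """Move cfile into savepath; if renaming fails, copy it and delete the original."""
    os.makedirs(savepath, exist_ok=True)
    target = os.path.join(savepath, os.path.basename(cfile))
    try:
        os.replace(cfile, target)
    except OSError:
        shutil.copy2(cfile, target)  # copy2 also preserves file metadata
        os.remove(cfile)
    return target


with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "myfile.txt")
    with open(src, "w", encoding="utf-8") as fh:
        fh.write("demo")
    new_path = move_or_copy(src, os.path.join(tmp, "new_directory"))
    moved_ok = os.path.exists(new_path) and not os.path.exists(src)

print(os.path.basename(new_path), moved_ok)
```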
- geoprior.utils.io_utils.parse_csv(csv_fn=None, data=None, todo='reader', fieldnames=None, savepath=None, header=False, verbose=0, **csvkws)[source]
Parses a CSV file or serializes data to a CSV file.
This function allows loading (reading) from or dumping (writing) to a CSV file. It supports standard CSV and dictionary-based CSV formats.
- Parameters:
csv_fn (
str, optional) – The CSV filename for reading or writing. For writing operations, if data is provided and todo is set to ‘write’ or ‘dictwriter’, this specifies the output CSV filename.data (
list, optional) – Data to write in the form of a list of lists or dictionaries.todo (
str, default'reader') – Specifies the operation type: - ‘reader’ or ‘dictreader’: Reads data from a CSV file. - ‘writer’ or ‘dictwriter’: Writes data to a CSV file.fieldnames (
listofstr, optional) – List of keys for dictionary-based writing to specify the field order.savepath (
str, optional) – Directory to save the CSV file when writing. Defaults to ‘_savecsv_’ if not provided and the path does not exist.header (
bool, defaultFalse) – If True, includes headers when writing with DictWriter.verbose (
int, default0) – Controls the verbosity level for output messages.csvkws (
dict, optional) – Additional arguments passed to csv.writer or csv.DictWriter.
- Returns:
Parsed data from the CSV file, as a list of lists or a list of dictionaries, based on the operation. Returns None when writing.
- Return type:
Union[List[Dict],List[List[str]],None]
Notes
For writing data, the method uses either csv.writer for regular CSV or csv.DictWriter for dictionary-based CSV depending on the value of todo.
Examples
>>> from geoprior.utils.io_utils import parse_csv
>>> data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]
>>> parse_csv(csv_fn='output.csv', data=data, todo='dictwriter', fieldnames=['name', 'age'])
>>> loaded_data = parse_csv(csv_fn='output.csv', todo='dictreader', fieldnames=['name', 'age'])
>>> print(loaded_data)
[{'name': 'Alice', 'age': '30'}, {'name': 'Bob', 'age': '25'}]
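The dictionary-based write and read path rests on csv.DictWriter and csv.DictReader from the standard library. The sketch below uses those classes directly on an in-memory buffer, so it makes no claim about how parse_csv wraps them; it mainly shows that values read back from CSV are strings and need explicit casting.

```python
import csv
import io

rows = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]

# Write with a header row into an in-memory text buffer.
buffer = io.StringIO(newline="")
writer = csv.DictWriter(buffer, fieldnames=["name", "age"])
writer.writeheader()
writer.writerows(rows)

# Read back: DictReader takes the field names from the header row.
buffer.seek(0)
loaded = list(csv.DictReader(buffer))

# CSV has no types, so numeric columns come back as strings.
typed = [{"name": r["name"], "age": int(r["age"])} for r in loaded]
print(loaded)
print(typed)
```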
- geoprior.utils.io_utils.parse_json(json_fn=None, data=None, todo='load', savepath=None, verbose=0, **jsonkws)[source]
Parse and manage JSON configuration files, either loading data from or saving data to a JSON file.
- Parameters:
json_fn (
str, optional) – JSON filename or URL. If data is provided and todo is ‘dump’, json_fn will be used as the output filename. If todo is ‘load’, json_fn is the input filename or URL.data (
Any, optional) – Data in Python object format to serialize and save if todo is ‘dump’.todo (
{'load', 'loads', 'dump', 'dumps'}, default'load') – Action to perform with JSON: - ‘load’: Load data from a JSON file. - ‘loads’: Parse a JSON string. - ‘dump’: Serialize data to a JSON file. - ‘dumps’: Serialize data to a JSON string.savepath (
str, optional) – Path where the JSON file will be saved if todo is ‘dump’. If savepath does not exist, it will save to the default path ‘_savejson_’.verbose (
int, default0) – Controls verbosity of output messages.**jsonkws (
dict) – Additional keyword arguments passed to json.dump or json.dumps when saving data.
- Returns:
The data loaded from the JSON file or URL if todo is ‘load’, or data after saving if todo is ‘dump’.
- Return type:
Any- Raises:
json.JSONDecodeError – If there is an issue with reading or writing the JSON file.
TypeError – If the JSON file or data cannot be processed.
Notes
This function uses json.load, json.loads, json.dump, and json.dumps for efficient handling of JSON files and strings.
See also
fetch_json_data_from_urlFetches JSON data from a given URL.
get_config_fname_from_varnameUtility for generating JSON configuration filenames based on variable names.
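The four todo modes map naturally onto the four functions of the json module. The dispatch below is a sketch under assumptions: json_action is a hypothetical name, the JSON string for 'loads' is passed through data, and URL input and savepath handling are left out.

```python
import json
import os
import tempfile


def json_action(todo, *, json_fn=None, data=None, **jsonkws):
    """Dispatch a todo value to the matching json function."""
    if todo == "load":
        with open(json_fn, encoding="utf-8") as fh:
            return json.load(fh)
    if todo == "loads":
        return json.loads(data)  # assumption: the string arrives via data
    if todo == "dump":
        with open(json_fn, "w", encoding="utf-8") as fh:
            json.dump(data, fh, **jsonkws)
        return data
    if todo == "dumps":
        return json.dumps(data, **jsonkws)
    raise ValueError(f"unknown todo: {todo!r}")


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    json_action("dump", json_fn=path, data={"time_steps": 3}, indent=2)
    reloaded = json_action("load", json_fn=path)

text = json_action("dumps", data={"horizon": 1})
parsed = json_action("loads", data=text)
print(reloaded, parsed)
```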
- geoprior.utils.io_utils.parse_md(pf, delimiter=':')[source]
Parse a markdown-style file with key-value pairs separated by a delimiter.
- Parameters:
- Yields:
Tuple[str,str]– A tuple containing the key and processed value.- Raises:
IOError – If the provided path does not lead to a valid file.
Notes
This function yields key-value pairs by reading the file line-by-line.
It applies sanitize_unicode_string to keys to ensure data consistency.
Examples
>>> from geoprior.utils.io_utils import parse_md
>>> list(parse_md('parser_file.md', delimiter=':'))
[('key1', 'Value1'), ('key2', 'Value2')]
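Line-by-line key-value parsing of this kind is usually written as a generator. The sketch below is illustrative only: iter_key_values is a hypothetical name, it takes an iterable of lines rather than a path, it skips lines without the delimiter, and it omits the key sanitization mentioned in the Notes.

```python
def iter_key_values(lines, delimiter=":"):
    """Yield (key, value) pairs from lines of the form 'key<delimiter>value'."""
    for line in lines:
        line = line.strip()
        if not line or delimiter not in line:
            continue  # skip blank lines and lines without a delimiter
        # Split on the first delimiter only, so values may contain it.
        key, value = line.split(delimiter, 1)
        yield key.strip(), value.strip()


sample = ["key1: Value1", "", "a line without the delimiter", "key2:Value2: extra"]
pairs = list(iter_key_values(sample))
print(pairs)
```

Splitting on the first delimiter only keeps values such as timestamps or URLs intact.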
- geoprior.utils.io_utils.parse_yaml(yml_fn=None, data=None, todo='load', savepath=None, verbose=0, **ymlkws)[source]
Parse and handle YAML configuration files for loading or saving data.
- Parameters:
yml_fn (
str, optional) – The YAML filename. If data is provided and todo is set to ‘dump’, yml_fn will be used as the output filename. If todo is set to ‘load’, yml_fn is the input filename to read from.data (
Any, optional) – Data in a Python object format that will be serialized and saved as a YAML file if todo is ‘dump’.todo (
{'load', 'dump'}, default'load') – Action to perform with the YAML file: - ‘load’: Load data from the YAML file specified by yml_fn. - ‘dump’: Serialize data into a YAML format and save to yml_fn.savepath (
str, optional) – Path where the YAML file will be saved if todo is ‘dump’. If not provided, a default path will be used. The function will ensure that the path exists.verbose (
int, default0) – Controls verbosity of output messages.**ymlkws (
dict) – Additional keyword arguments passed to yaml.dump when saving data.
- Returns:
The data loaded from the YAML file if todo is ‘load’, or data after saving if todo is ‘dump’.
- Return type:
Any- Raises:
yaml.YAMLError – If there is an issue with reading or writing the YAML file.
Notes
This function uses safe_load and safe_dump methods from PyYAML for secure handling of YAML files.
See also
get_config_fname_from_varnameUtility for generating YAML configuration filenames based on variable names.
- geoprior.utils.io_utils.print_cmsg(cfile, todo='load', config='YAML')[source]
Generates output message for configuration file operations.
- Parameters:
- Returns:
Confirmation message for the configuration operation.
- Return type:
Examples
>>> from geoprior.utils.io_utils import print_cmsg >>> msg = print_cmsg('config.yml', 'dump') >>> print(msg) --> YAML 'config.yml' data was successfully saved.
- geoprior.utils.io_utils.rename_files(src_files, dst_files, basename=None, extension=None, how='py', prefix=True, keep_copy=True, trailer='_', sortby=None, **kws)[source]
Rename files from one set of names or paths to another.
- Parameters:
src_files (
strorlistofstr) – Source files or a directory containing files to rename.dst_files (
strorlistofstr) – Destination file names or destination directory.basename (
strorNone, optional) – Base name used when generating numbered destination files.extension (
strorNone, optional) – Optional extension filter whensrc_filesis a directory.how (
str, optional) – Numbering convention used when destination names are generated.prefix (
bool, optional) – Whether generated numbering is appended after the basename.keep_copy (
bool, optional) – Whether to keep copies of the original files.trailer (
str, optional) – Separator inserted between the basename and the generated counter.sortby (
regex,callable, orNone, optional) – Optional sort key used when collecting files from a directory.**kws (
dict) – Additional keyword arguments forwarded toos.rename.
- Return type:
None
- geoprior.utils.io_utils.sanitize_unicode_string(str_)[source]
Removes spaces and replaces accented characters in a string.
- Parameters:
- Returns:
The sanitized string with removed spaces and replaced accents.
- Return type:
Examples
>>> from geoprior.utils.io_utils import sanitize_unicode_string >>> sentence ='Nos clients sont extrêmement satisfaits ' 'de la qualité du service fourni. En outre Nos clients ' 'rachètent frequemment nos "services".' >>> sanitize_unicode_string (sentence) ... 'nosclientssontextrmementsatisfaitsdelaqualitduservice' 'fournienoutrenosclientsrachtentfrequemmentnosservices' >>> sanitize_unicode_string("Élève à l'école") 'elevealecole'
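One way to obtain the behavior of the second example above (accents folded to ASCII, spaces and punctuation removed) is Unicode NFKD normalization. This is a minimal sketch under that assumption; `strip_accents_and_spaces` is a hypothetical name, and the library's own implementation may treat accented characters differently, as the first example suggests.

```python
import unicodedata


def strip_accents_and_spaces(text: str) -> str:
    """Hypothetical helper: lowercase, fold accents, keep only letters and digits."""
    # NFKD splits a character such as "é" into "e" plus a combining accent.
    decomposed = unicodedata.normalize("NFKD", text.lower())
    # Drop the combining marks left behind by the decomposition.
    no_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Drop spaces, quotes, and other punctuation.
    return "".join(ch for ch in no_marks if ch.isalnum())


print(strip_accents_and_spaces("Élève à l'école"))  # elevealecole
```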
- geoprior.utils.io_utils.save_job(job, savefile, *, protocol=None, append_versions=True, append_date=True, fix_imports=True, buffer_callback=None, **job_kws)[source]
Quickly save a job using joblib or Python's pickle module.
- Parameters:
job (
Any) – The object to save, preferably a model or a dictionary of models.savefile (
str, orpath-like object) – name of file to store the model. The file argument must have a write() method that accepts a single bytes argument. It can thus be a file object opened for binary writing, an io.BytesIO instance, or any other custom object that meets this interface.append_versions (
bool, default=True) – Append the joblib or pickle module version, followed by the scikit-learn, NumPy, and pandas versions, to the filename. Recording these versions helps when reloading a file after the system or its modules have been upgraded, especially when the data were stored long ago and the versions in use at save time have been forgotten.append_date (
bool, defaultTrue,) – Append the date of the day to the filename.protocol (
int, optional) –The optional protocol argument tells the pickler to use the given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. The default protocol is 4. It was introduced in Python 3.4, and is incompatible with previous versions.
Specifying a negative protocol version selects the highest protocol version supported. The higher the protocol used, the more recent the version of Python needed to read the pickle produced.
fix_imports (
bool, defaultTrue,) – If fix_imports is True and protocol is less than 3, pickle will try to map the new Python 3 names to the old module names used in Python 2, so that the pickle data stream is readable with Python 2.buffer_callback (
callable, optional) –If buffer_callback is None (the default), buffer views are serialized into file as part of the pickle stream.
If buffer_callback is not None, then it can be called any number of times with a buffer view. If the callback returns a false value (such as None), the given buffer is out-of-band; otherwise the buffer is serialized in-band, i.e. inside the pickle stream.
It is an error if buffer_callback is not None and protocol is None or smaller than 5.
job_kws (
dict) – Additional keyword arguments passed to joblib.dump().
- Returns:
The final filename where the job was saved.
- Return type:
Notes
This function appends system-specific metadata like versions and date to the filename, which can aid in tracking compatibility over time.
Examples
>>> from geoprior.utils.io_utils import save_job >>> model = {"key": "value"} # Replace with actual model object >>> savefile = save_job(model, "my_model", append_date=True, append_versions=True) >>> print(savefile) 'my_model.20240101.sklearn_v1.0.numpy_v1.21.joblib'
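The filename convention shown in the example above can be sketched with the standard library alone. The helper name `build_versioned_name` and its arguments are hypothetical and are not part of `geoprior`; the version strings are passed in by hand rather than detected, and `pickle` stands in for `joblib`.

```python
import datetime
import os
import pickle
import tempfile


def build_versioned_name(base, versions, date=None, ext=".joblib"):
    """Hypothetical helper: '<base>.<YYYYMMDD>.<pkg>_v<version>...<ext>'."""
    date = date or datetime.date.today()
    parts = [base, date.strftime("%Y%m%d")]
    # One "<package>_v<version>" tag per entry, in insertion order.
    parts += [f"{pkg}_v{ver}" for pkg, ver in versions.items()]
    return ".".join(parts) + ext


name = build_versioned_name(
    "my_model",
    {"sklearn": "1.0", "numpy": "1.21"},
    date=datetime.date(2024, 1, 1),
)
print(name)  # my_model.20240101.sklearn_v1.0.numpy_v1.21.joblib

# Persist a small object under that name (pickle used here as a stand-in).
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, name)
    with open(path, "wb") as fh:
        pickle.dump({"key": "value"}, fh, protocol=4)
    saved_ok = os.path.exists(path)
```

Encoding versions in the name keeps the compatibility information visible without opening the file.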
- geoprior.utils.io_utils.save_path(nameOfPath)[source]
Creates a directory if it does not exist.
- Parameters:
nameOfPath (
str) – Name or path of the directory to create.- Returns:
The path of the created directory. If it exists, returns the existing path.
- Return type:
Examples
>>> from geoprior.utils.io_utils import save_path >>> save_path("test_directory") 'path/to/test_directory'
- geoprior.utils.io_utils.serialize_data(data, filename=None, savepath=None, to=None, force=True, compress=None, pickle_protocol=5, verbose=0)[source]
Serialize and save a Python object to a binary file using either
joblib or pickle. This function is designed to be robust and versatile, handling multiple cases including file naming, overwriting behavior, and compression options.
The final file path is computed as:
(25)#\[\text{filepath} = \text{savepath} \oplus \text{filename}\]
where \(\oplus\) denotes string concatenation.
- Parameters:
data (
Any) – The Python object to serialize. The object must be compatible with eitherjoblib.dumporpickle.dump.filename (
str, optional) – The target filename for the serialized data. IfNone, a filename is generated using the current timestamp, e.g.,"__mydumpedfile_20230315_123045.pkl".savepath (
str, optional) – The directory in which to save the file. If not specified, the current working directory (os.getcwd()) is used. The directory is created if it does not exist.to (
str, optional) – The serialization method to use. Acceptable values are'joblib'and'pickle'. IfNone, the default is'joblib'.force (
bool, defaultTrue) – IfTrue, any existing file with the same name is overwritten. IfFalse, a timestamp is appended to the filename to ensure uniqueness.compress (
intorstr, optional) – Compression level or method forjoblib.dump. IfNone, no compression is applied.pickle_protocol (
int, defaultpickle.HIGHEST_PROTOCOL) – The pickle protocol to use when serializing withpickle.dump.verbose (
int, default0) – Controls the verbosity of output messages. Higher values produce more detailed logging during the serialization process.
- Returns:
The full path to the saved serialized file.
- Return type:
Examples
>>> from geoprior.utils.io_utils import serialize_data >>> import numpy as np >>> data = {"a": np.arange(10), "b": np.random.rand(10)} >>> filepath = serialize_data( ... data, filename="mydata.pkl", savepath="output", ... to="pickle", force=False, verbose=1 ... ) >>> print(filepath) /current/working/directory/output/mydata_<timestamp>.pkl
Notes
The function first constructs the file path from
savepath and filename. If a file already exists and force is False, a timestamp is appended to ensure uniqueness. Then, depending on the value of to, the function attempts to serialize the data using either joblib.dump (with optional compression via the compress parameter) or pickle.dump (using the specified pickle_protocol). If an error occurs during serialization, an IOError is raised.
See also
joblib.dumpSerialize objects to disk using Joblib.
pickle.dumpSerialize objects to disk using Pickle.
os.getcwdRetrieve the current working directory.
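The path construction and the `force=False` behavior described in the Notes can be sketched as follows. `dump_unique` and its `now` argument are hypothetical and only illustrate the rule "keep the existing file, append a timestamp to the new one"; only `pickle` is used, and the timestamp format is assumed.

```python
import datetime
import os
import pickle
import tempfile


def dump_unique(data, filename, savepath, force=True, now=None):
    """Hypothetical helper: pickle `data` to savepath/filename.

    With force=False, an existing file is left untouched and a timestamp
    is appended to the stem of the new file instead.
    """
    os.makedirs(savepath, exist_ok=True)
    filepath = os.path.join(savepath, filename)
    if os.path.exists(filepath) and not force:
        stamp = (now or datetime.datetime.now()).strftime("%Y%m%d_%H%M%S")
        stem, ext = os.path.splitext(filename)
        filepath = os.path.join(savepath, f"{stem}_{stamp}{ext}")
    with open(filepath, "wb") as fh:
        pickle.dump(data, fh, protocol=pickle.HIGHEST_PROTOCOL)
    return filepath


with tempfile.TemporaryDirectory() as tmp:
    first = dump_unique({"a": 1}, "mydata.pkl", tmp)
    second = dump_unique({"a": 2}, "mydata.pkl", tmp, force=False)
    names = sorted(os.listdir(tmp))
    with open(first, "rb") as fh:
        first_loaded = pickle.load(fh)  # the original file is unchanged
```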
- geoprior.utils.io_utils.serialize_data_in(data, filename=None, force=True, savepath=None, verbose=0)[source]
Serializes a Python object to a binary file using either joblib or pickle.
This function attempts to serialize the input data using the
joblib.dump method. If this attempt fails, it falls back to using pickle.dump. The final file path is constructed by concatenating the directory specified by savepath (or the current working directory if savepath is None) with the given filename. Mathematically, the file path is given by:
(26)#\[\text{filepath} = \text{savepath} \oplus \text{filename}\]
where \(\oplus\) denotes string concatenation.
- Parameters:
data (
Any) – The Python object to serialize. It must be compatible with eitherjobliborpickleserialization.filename (
str, optional) – The target filename for the serialized data. IfNone, a filename is generated using the current timestamp formatted as"%Y%m%d%H%M%S"(e.g.,"serialized_20230315123045.pkl").force (
bool, defaultTrue) – Determines whether to overwrite an existing file with the same filename. IfFalse, a timestamp is appended to the filename to ensure uniqueness.savepath (
str, optional) – The directory in which to save the serialized file. If not specified, the file is saved to the current working directory (os.getcwd()).verbose (
int, default0) – Controls the verbosity of output messages. Higher values produce more detailed logging during the serialization process.
- Returns:
The complete file path to which the data has been serialized.
- Return type:
Examples
>>> from geoprior.utils.io_utils import serialize_data_in >>> data = {"a": 1, "b": 2} >>> filepath = serialize_data_in(data, filename='data.pkl', ... force=True, verbose=1) >>> print(filepath) /path/to/current/directory/data.pkl
Notes
The function first tries to serialize the input data using
joblib.dump. If any exception is raised during this attempt, it falls back to using pickle.dump. This dual approach improves robustness in diverse runtime environments where one serialization method might be unsupported or encounter issues with the given data type.
See also
joblib.dumpSerialize objects to disk using Joblib.
pickle.dumpSerialize objects to disk using Pickle.
os.getcwdRetrieve the current working directory.
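The try-then-fall-back pattern from the Notes can be illustrated as below. The helper names are hypothetical; the sketch also works when `joblib` is not installed, because the import itself sits inside the guarded block.

```python
import os
import pickle
import tempfile


def dump_with_fallback(data, filepath):
    """Hypothetical helper: try joblib first, fall back to pickle on any error."""
    try:
        import joblib  # third-party; may be missing

        joblib.dump(data, filepath)
        return "joblib"
    except Exception:
        with open(filepath, "wb") as fh:
            pickle.dump(data, fh)
        return "pickle"


def load_with_backend(filepath, backend):
    """Read the file back with the backend that wrote it."""
    if backend == "joblib":
        import joblib

        return joblib.load(filepath)
    with open(filepath, "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.pkl")
    backend = dump_with_fallback({"a": 1, "b": 2}, path)
    restored = load_with_backend(path, backend)
```

Returning the backend name lets the caller pick the matching loader later.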
- geoprior.utils.io_utils.spath(name_of_path)[source]
Create a directory if it does not already exist.
- Parameters:
name_of_path (
str) – Path-like object to create if it doesn’t exist.- Returns:
The absolute path to the created or existing directory.
- Return type:
Examples
>>> from geoprior.utils.io_utils import spath >>> path = spath('data/saved_models') >>> print(f"Directory available at: {path}")
Notes
spath is useful for quickly ensuring that a specific directory is available for storing files. It provides feedback if the directory already exists.
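A minimal standard-library sketch of the "create if missing, return the absolute path" behavior described above. `ensure_dir` is a hypothetical name and does not reproduce any feedback messages that `spath` may print.

```python
import os
import tempfile


def ensure_dir(path):
    """Hypothetical helper: create `path` if needed and return its absolute form."""
    os.makedirs(path, exist_ok=True)  # no error if the directory already exists
    return os.path.abspath(path)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "data", "saved_models")
    created = ensure_dir(target)
    again = ensure_dir(target)  # second call is a no-op
    is_dir = os.path.isdir(created)
```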
- geoprior.utils.io_utils.store_or_write_hdf5(df, key=None, mode='a', kind=None, path_or_buf=None, encoding='utf8', csv_sep=',', index=Ellipsis, columns=None, sanitize_columns=False, func=None, args=(), applyto=None, **func_kwds)[source]
Store a DataFrame to HDF5, write it to CSV, or sanitize it in memory.
- Parameters:
df (
pandas.DataFrameorarray-like) – Input data to store, export, or sanitize.key (
strorNone, optional) – Group key used when storing to HDF5.mode (
{'a', 'w', 'r+'}, optional) – File mode used when opening an HDF5 store.kind (
{'store', 'write', None}, optional) – Operation to perform. Use'store'for HDF5 output,'write'for CSV export, orNoneto return a sanitized DataFrame.path_or_buf (
str,path-like,pandas.HDFStore,file-like, orNone, optional) – Destination path, buffer, or open HDF5 store.encoding (
str, optional) – Output encoding used for CSV export.csv_sep (
str, optional) – Field separator used for CSV export.index (
bool, optional) – Whether to write the index when exporting to CSV.columns (
listofstrorNone, optional) – Column names used when constructing a DataFrame from an array.sanitize_columns (
bool, optional) – Whether to sanitize column names with the built-in regex helper.func (
callableorNone, optional) – Optional custom sanitizing function applied to selected columns.args (
tuple, optional) – Positional arguments forwarded tofunc.applyto (
strorlistofstrorNone, optional) – Column or columns to whichfuncshould be applied.func_kwds (
dict) – Keyword arguments forwarded tofunc.
- Returns:
Returns
None when kind is 'store' or 'write'. Otherwise returns the resulting DataFrame.
- Return type:
- geoprior.utils.io_utils.to_hdf5(data, fn, objname=None, close=True, **hdf5_kws)[source]
Store a data object in Hierarchical Data Format 5 (HDF5).
This function serializes the input
data into an HDF5 file. It supports both pandas DataFrames and NumPy arrays. If data is a DataFrame, it uses pd.HDFStore (which requires the pytables package) to store the data. If data is a NumPy array, it uses h5py.File to create a dataset.
The file path is constructed by concatenating the specified savepath (or the current working directory if savepath is not provided) with the provided filename (fn). The function automatically appends the appropriate file extension: .h5 for DataFrames and .hdf5 for arrays.
(27)#\[\text{filepath} = \text{savepath} \oplus \text{filename} \oplus \text{extension}\]
where \(\oplus\) denotes string concatenation.
- Parameters:
data (
Any) – The data object to be stored. Must be either a NumPy array or a pandas DataFrame.fn (
str) – The file path (without extension) where the HDF5 file will be saved.objname (
str, optional) – The name under which to store the data within the HDF5 file. Defaults to'data'if not provided.close (
bool, defaultTrue) – IfTrue, the file is closed after writing. IfFalse, the file remains open for additional modifications.**hdf5_kws (
dict, optional) – Additional keyword arguments to pass to the HDFStore constructor (for DataFrames) or to customize dataset creation (for arrays). Common options includemodefor the file mode,complevelfor compression level,complibfor the compression library, andfletcher32to enable the Fletcher32 checksum. Formode, use'r'for read-only access,'w'to create a new file,'a'to append or create, and'r+'to open an existing file for reading and writing.
- Returns:
store – An IO interface for the stored data. For DataFrames, this is a
pd.HDFStore object; for arrays, an h5py.File object.
- Return type:
Examples
>>> import os >>> import pandas as pd >>> from geoprior.utils.io_utils import to_hdf5 >>> data = pd.DataFrame({ ... 'a': [1, 2, 3], ... 'b': [4, 5, 6] ... }) >>> save_path = os.path.join('output', 'datafile') >>> store = to_hdf5(data, fn=save_path, objname='mydata', verbose=1) >>> # Access stored data: >>> retrieved = store['mydata'] >>> print(retrieved.head())
Notes
Ensure the dependency
pytables is installed when serializing a DataFrame. When serializing NumPy arrays, the dataset is created with the name "dataset_01". If close is set to False, the caller is responsible for closing the store. The pandas and NumPy foundations underlying this serialization path are summarized in [33, 34].
See also
joblib.dump, pickle.dump, h5py.File
- geoprior.utils.io_utils.zip_extractor(zip_file, samples='*', ftype=None, savepath=None, pwd=None)[source]
Extracts files from a ZIP archive based on various filtering criteria and saves them to a specified directory.
The extraction process can be controlled by the
samples parameter to limit the number of files extracted, or by the ftype parameter to filter by a specific file extension. The resulting file names are returned as a list.
(28)#\[\text{Extracted Files} = \{ f \in \mathcal{A} \mid \phi(f) \}\]
where \(\mathcal{A}\) is the set of all files in the archive, and \(\phi(f)\) is a predicate that checks if a file matches the desired extension and is within the specified sample count.
- Parameters:
zip_file (
str) – Full path to the ZIP archive file.samples (
intorstr, optional) – Number of files to extract. If set to'*', all files are extracted. Default is'*'.ftype (
str, optional) – File extension filter (e.g.,'.csv'). Only files with this extension are extracted. If no matching files are found, a ValueError is raised.savepath (
str, optional) – Directory where the extracted files will be stored. If not provided, files are extracted to the current working directory.pwd (
strorbytes, optional) – Password for encrypted ZIP files. If provided as a string, it will be used as is (or can be encoded to bytes as needed).
- Returns:
A list of extracted file names (with paths).
- Return type:
Examples
>>> from geoprior.utils.io_utils import zip_extractor >>> extracted_files = zip_extractor( ... 'data/archive.zip', ... samples='*', ... ftype='.csv', ... savepath='data/extracted', ... pwd='secret' ... ) >>> print(extracted_files) ['folder1/file1.csv', 'folder2/file2.csv', ...]
Notes
The function first validates the input ZIP file using
check_files (assumed to be defined in the package). It then determines the sample count and filters files by extension if ftype is provided. Extraction is done via the standard ZipFile.extract or ZipFile.extractall methods.
See also
zipfile.ZipFile.extractExtract a single file from a ZIP archive.
zipfile.ZipFile.extractallExtract all files from a ZIP archive.
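The filter-then-extract logic in the Notes maps directly onto the standard `zipfile` module. The sketch below is an illustration under stated assumptions: `extract_matching` is a hypothetical name, the input validation step is omitted, and the order in which the extension filter and the sample limit are applied is assumed.

```python
import os
import tempfile
import zipfile


def extract_matching(zip_file, samples="*", ftype=None, savepath=None, pwd=None):
    """Hypothetical helper: extract members filtered by extension and count."""
    if isinstance(pwd, str):
        pwd = pwd.encode()  # zipfile expects the password as bytes
    with zipfile.ZipFile(zip_file) as archive:
        # Skip directory entries, which end with "/".
        members = [m for m in archive.namelist() if not m.endswith("/")]
        if ftype is not None:
            members = [m for m in members if m.lower().endswith(ftype.lower())]
            if not members:
                raise ValueError(f"no {ftype!r} files in {zip_file!r}")
        if samples != "*":
            members = members[: int(samples)]
        # extract() returns the path of each file it writes.
        return [archive.extract(m, path=savepath, pwd=pwd) for m in members]


with tempfile.TemporaryDirectory() as tmp:
    archive_path = os.path.join(tmp, "archive.zip")
    with zipfile.ZipFile(archive_path, "w") as zf:
        zf.writestr("folder1/file1.csv", "a,b\n1,2\n")
        zf.writestr("folder2/file2.csv", "a,b\n3,4\n")
        zf.writestr("notes.txt", "not a csv")
    out_dir = os.path.join(tmp, "extracted")
    extracted = extract_matching(archive_path, ftype=".csv", savepath=out_dir)
    extracted_names = [
        os.path.relpath(p, out_dir).replace(os.sep, "/") for p in extracted
    ]

print(extracted_names)  # ['folder1/file1.csv', 'folder2/file2.csv']
```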
- geoprior.utils.io_utils.fetch_joblib_data(job_file, *keys, error_mode='raise', verbose=0)[source]
Dynamically load data from a joblib-saved dictionary with flexible key access.
- Parameters:
job_file (
str) – Path to the joblib file containing a dictionary*keys (
str) – Variable-length list of dictionary keys to retrieveerror_mode (
{'raise', 'warn', 'ignore'}, default'raise') – Handling of missing keys: - ‘raise’: Immediately raise KeyError - ‘warn’: Issue warning and skip missing keys - ‘ignore’: Silently skip missing keysverbose (
int, default0) – Verbosity level: - 0: No output - 1: Basic loading information - 2: Detailed debugging output
- Returns:
Full dictionary if no keys specified
Tuple of values for requested keys (maintaining order)
- Return type:
Union[Dict,Tuple]- Raises:
FileNotFoundError – If specified job_file doesn’t exist
TypeError – If loaded data isn’t a dictionary
KeyError – If requested key not found and error_mode=’raise’
Examples
>>> from geoprior.utils.io_utils import fetch_joblib_data >>> data = fetch_joblib_data('data.joblib', 'X_train', 'y_train') >>> X, y = fetch_joblib_data('data.joblib', 'X_val', 'y_val', verbose=1) >>> full_dict = fetch_joblib_data('data.joblib')
Notes
Maintains original insertion order for Python 3.7+ dictionaries
Missing keys in ‘warn’/’ignore’ modes result in shorter return tuple
Joblib files must contain dictionary objects
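The key-selection rules (full dictionary when no keys are given, ordered tuple otherwise, three ways of handling missing keys) can be sketched on a plain dictionary. `pick_keys` is a hypothetical stand-in that skips the file-loading step entirely.

```python
import warnings


def pick_keys(data, *keys, error_mode="raise"):
    """Hypothetical helper mirroring the key-selection rules described above."""
    if not isinstance(data, dict):
        raise TypeError("expected a dictionary")
    if not keys:
        return data  # no keys requested: return the whole dictionary
    values = []
    for key in keys:
        if key in data:
            values.append(data[key])
        elif error_mode == "raise":
            raise KeyError(key)
        elif error_mode == "warn":
            warnings.warn(f"key {key!r} not found; skipping")
        # error_mode == "ignore": skip silently
    return tuple(values)


bundle = {"X_train": [1, 2], "y_train": [0, 1]}
X, y = pick_keys(bundle, "X_train", "y_train")
partial = pick_keys(bundle, "X_train", "X_val", error_mode="ignore")
print(partial)  # ([1, 2],)
```

Because missing keys shorten the returned tuple in the 'warn' and 'ignore' modes, tuple unpacking such as `X, y = ...` is only safe with `error_mode='raise'`.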
- geoprior.utils.io_utils.to_txt(d, filename=None, format='txt', indent=2, width=80, depth=None, compat=False, include_header=True, mode='w', encoding='utf-8', overwrite=True, header=None, footer=None, serializer=None, savepath=None, verbose=1, logger=None, **kwargs)[source]
Export data objects to a text or JSON file with optional custom formatting.
The function, <to_txt>, handles writing <d> (a string, dict, list, or general object) to a file named <filename>. When no filename is given, it automatically generates one based on the current date/time. If <format> is “json” and <d> is valid for JSON serialization, it attempts a JSON export. Otherwise, it falls back to text mode, leveraging Python’s built-in pformat and an optional <serializer> for advanced transformations.
(29)#\[\text{FileName}_{\text{timestamp}} \rightarrow \text{output}\]
where \(\text{FileName}_{\text{timestamp}}\) is an auto-generated name like output_20230101_123456.txt if <filename> is not provided.
- Parameters:
d (
object) – Data to write. Can be any Python object supported by pformat, or a dict if <format> is ‘json’.filename (
str, optional) – Full path (or name) of the output file. If None, a time-stamped name is produced, prefixed with ‘output_’.format (
str, default'txt') – File format, either"txt"or"json". If it fails to serialize as JSON, the process reverts to text.indent (
int, default2) – Indentation level for pretty-printing text or JSON.width (
int, default80) – Wrap width for formatted text lines.depth (
int, optional) – Maximum depth to which nested structures are expanded. If None, no limit is applied.compat (
bool, defaultFalse) – If True, instructs pformat to produce more compact text. Not used when exporting JSON.include_header (
bool, defaultTrue) – Whether to include a decorative header (with timestamp) at the top of the file in text mode.mode (
str, default'w') – File writing mode. Typically ‘w’ for overwrite, ‘a’ for append.encoding (
str, default'utf-8') – Text encoding used when opening the file.overwrite (
bool, defaultTrue) – If False, raises an error if the file already exists.header (
str, optional) – Custom header text (if <include_header> is True). Overwrites the default header if given.footer (
str, optional) – Custom footer text appended at the end of the file, if <include_header> is True.serializer (
callable, optional) – A function that transforms <d> before printing. If it fails, <d> remains unchanged.verbose (
int, default1) – Verbosity level for logging. Higher values yield more console messages (e.g., file stats at <verbose>>=3).**kwargs – Additional parameters passed to the JSON serializer (json.dump) or pformat.
- Returns:
The final filename used to store the output (potentially auto-generated).
- Return type:
Notes
If <format> is “json”, the function tries json.dump with a few standard parameters. If an exception occurs, it reverts to text export. The <serializer> argument allows custom transformations, such as flattening nested dicts or converting objects to JSON- serializable representations. The standard-library JSON behavior used here is documented in Python Software Foundation [35].
Examples
>>> from geoprior.utils.io_utils import to_txt >>> my_data = {"name":"Alice","age":30} >>> # Basic text export >>> txt_file = to_txt(my_data, verbose=2) >>> # Enforce JSON format >>> json_file = to_txt(my_data, format='json', indent=4)
See also
pformatPretty-print complex Python data structures.
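The JSON-first, text-fallback behavior described in the Notes can be sketched in a few lines. `render` is a hypothetical helper that returns the formatted string rather than writing a file, and it ignores headers, footers, and the serializer hook.

```python
import json
from pprint import pformat


def render(d, format="txt", indent=2, width=80):
    """Hypothetical helper: return (text, format_actually_used)."""
    if format == "json":
        try:
            return json.dumps(d, indent=indent), "json"
        except (TypeError, ValueError):
            pass  # not JSON-serializable: fall through to text mode
    return pformat(d, indent=indent, width=width), "txt"


text, used = render({"name": "Alice", "age": 30}, format="json")
# A set is not JSON-serializable, so this request falls back to pformat.
fallback_text, fallback_used = render({"tags": {"a"}}, format="json")
print(used, fallback_used)  # json txt
```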
Parallel execution helpers for GeoPrior workflows.
- geoprior.utils.parallel_utils.resolve_n_jobs(n_jobs)[source]
- geoprior.utils.parallel_utils.threads_per_job(*, n_jobs, threads=0, reserve=1)[source]
- geoprior.utils.parallel_utils.apply_thread_env(env, *, n_jobs, threads=0, reserve=1)[source]
- geoprior.utils.parallel_utils.apply_tf_threading(*, intra, inter)[source]
- geoprior.utils.parallel_utils.detect_gpu_ids(*, env=None)[source]
- geoprior.utils.parallel_utils.resolve_device(device, *, env=None)[source]
- geoprior.utils.parallel_utils.resolve_gpu_ids(gpu_ids, *, env=None)[source]
- geoprior.utils.parallel_utils.pick_gpu_id(idx, gpu_ids)[source]
- geoprior.utils.parallel_utils.apply_gpu_env(env, *, gpu_id, allow_growth=True)[source]
System utilities module for managing system-level operations.
This module provides utilities essential for system-level tasks such as color management, regular expression searching, and projection validation, along with other miscellaneous system operations.
- class geoprior.utils.sys_utils.BatchDataFrameBuilder(chunk_size=100000, processor='auto', verbose=1)[source]
Bases:
object
Manages incremental construction of a large DataFrame in controlled-size chunks. This can reduce peak memory usage and allow the use of GPU-accelerated libraries (e.g., cudf) if they are available and desired.
The approach can be expressed mathematically as a chunking process that partitions an incoming stream of \(N\) row-dictionaries into \(k\) subsets of size \(m \le \text{chunk\_size}\):
(30)#\[k = \left\lceil \frac{N}{m} \right\rceil\]
Each subset is converted into a DataFrame chunk and its rows are released from the buffer; the stored chunks are concatenated at finalization time.
- Parameters:
chunk_size (
int, optional) – The maximum number of rows to hold in the internal buffer before converting them into a DataFrame chunk. Default is 100000.processor (
{'auto', 'cpu', 'gpu'}, optional) –Controls the engine used to build the DataFrame:
'cpu': Always use pandas.
'gpu': Attempt to use cudf (raise an error if not available).
'auto': Use cudf if a GPU is detected and cudf is installed; otherwise fall back to pandas.
verbose (
int, optional) –Verbosity level. Default is 1:
0 : Silent.
1 : Basic information.
2 : Debug / detailed printing.
Notes
This object is intended for situations where the total row count can be very large, potentially in the millions. By breaking data into chunks, you can avoid excessive memory usage and keep the system more responsive. If processor is
'auto' or 'gpu', the module calls check_processor to verify GPU availability, then uses cudf if appropriate.
Note
If the total data is larger than your available memory (whether RAM or GPU), consider writing out each chunk to disk as a partitioned file (e.g., Parquet or Feather) instead of storing them all in memory.
Examples
>>> from geoprior.utils.sys_utils import BatchDataFrameBuilder >>> # Suppose we have a large list of dictionaries >>> data = [ ... {'colA': i, 'colB': i**2} for i in range(10**6) ... ] >>> with BatchDataFrameBuilder(chunk_size=50000, ... processor='auto', ... verbose=2) as builder: ... builder.add_rows(data) ... >>> # After exiting the context, the final DataFrame is >>> # automatically built and stored in builder.final_df >>> final_df = builder.final_df >>> print(final_df.shape) (1000000, 2)
See also
pandas.DataFrameCore pandas DataFrame object.
cudf.DataFrameGPU DataFrame object from RAPIDS.
check_processorUtility for detecting GPU availability.
- __init__(chunk_size=100000, processor='auto', verbose=1)[source]
Initializes the builder, setting up chunk size, processor preference, and verbosity. Checks for GPU availability if requested.
- __enter__()[source]
Enters the context manager. Returns self so we can use it in a with-statement scope.
- add_row(row)[source]
Adds a single row to the internal buffer.
This method appends the given dictionary row to the in-memory buffer. If the buffer reaches self.chunk_size, it is automatically flushed.
- Parameters:
row (
dict) – A row in dictionary form, where keys correspond to column names and values represent the row data.
Notes
Internally calls
_flush()once the buffer has reached its maximum size.
- add_rows(rows)[source]
Adds multiple rows to the internal buffer.
This method iterates over the list of dictionaries rows. For each element,
add_row() is called, which may trigger a flush if the buffer is full.
- Parameters:
rows (
listofdict) – Each dictionary should have the same structure as a typical row in the final DataFrame.
Notes
This method is merely a convenience layer over
add_row().
- finalize()[source]
Flushes remaining rows and concatenates all chunks.
Once the remaining rows in _rows are processed into a chunk, this method concatenates all stored chunk DataFrames (either pandas or cudf) into one final DataFrame. The resulting DataFrame is returned.
- Returns:
The final DataFrame, which may be a pandas DataFrame or a cudf DataFrame (if processor is set to allow GPU usage and cudf is available).
- Return type:
DataFrame
Notes
After concatenation, all chunk DataFrames are cleared from memory. This method is called automatically upon exiting the context (i.e., in
__exit__()).
- __exit__(exc_type, exc_val, exc_tb)[source]
Exits the context manager. Automatically finalizes the DataFrame by calling
finalize(), storing the result in self.final_df.
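The buffer, flush, and finalize cycle documented for this class can be illustrated without pandas or cudf. In the sketch below, plain lists stand in for DataFrame chunks; `ChunkedRowBuffer` and its `n_chunks` attribute are hypothetical and exist only to show that the number of chunks matches \(k = \lceil N / m \rceil\).

```python
import math


class ChunkedRowBuffer:
    """Hypothetical stand-in: buffers row dicts and flushes fixed-size chunks.

    Plain lists take the place of DataFrame chunks to keep the sketch
    dependency-free.
    """

    def __init__(self, chunk_size=100_000):
        self.chunk_size = chunk_size
        self._rows = []
        self._chunks = []
        self.n_chunks = 0
        self.final = None

    def __enter__(self):
        return self

    def add_row(self, row):
        self._rows.append(row)
        if len(self._rows) >= self.chunk_size:
            self._flush()

    def add_rows(self, rows):
        for row in rows:
            self.add_row(row)

    def _flush(self):
        # Move the buffered rows into a stored chunk and reset the buffer.
        if self._rows:
            self._chunks.append(self._rows)
            self._rows = []

    def finalize(self):
        self._flush()  # pick up the last, possibly partial, chunk
        self.n_chunks = len(self._chunks)
        merged = [row for chunk in self._chunks for row in chunk]
        self._chunks.clear()
        return merged

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final = self.finalize()


rows = [{"colA": i, "colB": i**2} for i in range(1050)]
with ChunkedRowBuffer(chunk_size=100) as buf:
    buf.add_rows(rows)

print(buf.n_chunks, math.ceil(1050 / 100))  # 11 11
```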
- class geoprior.utils.sys_utils.WorkflowOptimizer(parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True)[source]
Bases:
object
WorkflowOptimizer is a decorator class designed to optimize the execution of computationally intensive functions by enabling parallelization, managing CPU and memory resources, and performing cleanup tasks. It provides flexibility through various parameters that allow users to customize optimization strategies according to their workflow requirements.
(31)#\[T_{\text{total}} = T_{\text{start}} + T_{\text{execution}} + T_{\text{cleanup}}\]
Here, \(T_{\text{total}}\) is the total workflow time, \(T_{\text{start}}\) is the initialization time, \(T_{\text{execution}}\) is the main execution time, and \(T_{\text{cleanup}}\) is the cleanup time.
- Parameters:
parallelize (
bool, optional) – Flag to enable or disable parallel processing. If set toTrue, the decorator will attempt to parallelize the execution of the decorated function using multiprocessing. Default isTrue.memory_cleanup (
bool, optional) – Whether to clean up system memory after the execution of the decorated function. This includes triggering garbage collection and clearing GPU caches if applicable. Default isFalse.log_level (
int, optional) – Level of logging verbosity. Accepts standard logging levels such aslogging.INFO,logging.DEBUG, etc. Default islogging.INFO.optimize_cpu (
bool, optional) – Whether to optimize CPU usage by setting CPU affinity to restrict the process to specific CPU cores. IfTrue, the decorator will bind the process to the cores specified incpu_cores. Default isTrue.num_processes (
int, optional) – The number of parallel processes to use whenparallelizeis enabled. If not specified, it defaults to the minimum of the number of available CPU cores and the length of thedataiterable passed to the function. Default isNone.cpu_cores (
listorNone, optional) – A list of specific CPU cores to bind the process to for optimized CPU usage. IfNone, the process is allowed to run on all available CPU cores. Example:[0, 1, 2, 3]. Default isNone.verbose (
bool, optional) – Whether to print detailed logs during execution. If set toFalse, only essential information will be logged based on thelog_level. Default isTrue.
Examples
>>> import logging
>>> import time
>>> from geoprior.utils.sys_utils import WorkflowOptimizer
>>>
>>> @WorkflowOptimizer(
...     parallelize=True,
...     memory_cleanup=True,
...     log_level=logging.DEBUG,
...     num_processes=4,
...     cpu_cores=[0, 1, 2, 3],
...     verbose=True
... )
... def process_data(data_chunk):
...     '''Simulate a time-consuming data processing function.'''
...     time.sleep(1)  # Simulate a time-consuming task
...     return f"Processed {data_chunk}"
...
>>> data_chunks = ['chunk1', 'chunk2', 'chunk3', 'chunk4']
>>> results = process_data(data=data_chunks)
>>> print(results)
['Processed chunk1', 'Processed chunk2', 'Processed chunk3', 'Processed chunk4']
Notes
The decorator checks for the presence of a
data keyword argument to decide whether parallelization should be applied. When parallelize=True, the decorated function should be compatible with multiprocessing, meaning it should be picklable. Memory cleanup can be useful in long-running workflows, and CPU affinity may improve performance by reducing context switching and cache misses. Logging behavior follows the standard Python logging model.
See also
multiprocessing.PoolProvides a pool of worker processes.
psutil.ProcessAllows manipulation of system processes.
- __init__(parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True)[source]
- __call__(func)[source]
Makes the class instance callable so it can be used as a decorator.
- Parameters:
func (
function) – The function to be decorated and optimized.- Returns:
wrapper – The wrapped function with optimization strategies applied.
- Return type:
function
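The decorator-class pattern (an instance whose `__call__` wraps a function and fans out over a `data` keyword) can be sketched as follows. This is not the library's implementation: `MapOverData` is a hypothetical class, threads are swapped in for processes so the example has no pickling requirements, and CPU affinity, memory cleanup, and logging are left out. The elapsed wall time is recorded on the wrapper as a rough counterpart to \(T_{\text{total}}\).

```python
import os
import time
from concurrent.futures import ThreadPoolExecutor
from functools import wraps


class MapOverData:
    """Hypothetical decorator class: fan out over a `data` keyword argument."""

    def __init__(self, parallelize=True, num_workers=None):
        self.parallelize = parallelize
        self.num_workers = num_workers

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                if "data" not in kwargs:
                    return func(*args, **kwargs)  # nothing to fan out over
                items = list(kwargs.pop("data"))
                if not self.parallelize or not items:
                    return [func(item, *args, **kwargs) for item in items]
                # Default worker count: min(available cores, number of items).
                workers = self.num_workers or min(os.cpu_count() or 1, len(items))
                with ThreadPoolExecutor(max_workers=workers) as pool:
                    # map() preserves the input order of `items`.
                    return list(
                        pool.map(lambda item: func(item, *args, **kwargs), items)
                    )
            finally:
                wrapper.elapsed = time.perf_counter() - start

        return wrapper


@MapOverData(num_workers=4)
def process_data(data_chunk):
    return f"Processed {data_chunk}"


results = process_data(data=["chunk1", "chunk2", "chunk3", "chunk4"])
print(results)
```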
- geoprior.utils.sys_utils.check_port_in_use(port)[source]
Checks if a port is currently in use, which is useful for server-based applications.
- Parameters:
port (
int) – The port number to check.- Returns:
True if the port is in use, otherwise False.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import check_port_in_use >>> check_port_in_use(8080) False
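A common way to implement such a check is to attempt a TCP connection and inspect the result. The sketch below assumes that approach; `port_in_use` is a hypothetical name, and the library may use a different test (for example, trying to bind the port).

```python
import socket


def port_in_use(port, host="127.0.0.1"):
    """Hypothetical helper: True if something accepts connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.settimeout(1.0)
        # connect_ex returns 0 on success instead of raising an exception.
        return probe.connect_ex((host, port)) == 0


# Open a listener on an OS-assigned free port, then probe it.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
    server.bind(("127.0.0.1", 0))
    server.listen(1)
    busy_port = server.getsockname()[1]
    busy = port_in_use(busy_port)

print(busy)  # True
```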
- geoprior.utils.sys_utils.clean_temp_files(directory=None)[source]
Cleans up temporary files in a specified directory.
- Parameters:
directory (
str, optional) – The directory to clean up. If None, cleans the default temporary directory.- Return type:
Notes
This function is particularly useful for freeing up disk space in data-intensive applications.
Examples
>>> from geoprior.utils.sys_utils import clean_temp_files >>> clean_temp_files("/path/to/temp/dir")
- geoprior.utils.sys_utils.create_temp_dir(prefix='tmp')[source]
Creates a temporary directory and returns its path.
- Parameters:
prefix (
str, optional) – The prefix for the temporary directory name. Default is “tmp”.- Returns:
dir_path – The full path of the created temporary directory.
- Return type:
Notes
This function is helpful for managing temporary directories in applications where short-term data storage is needed.
Examples
>>> from geoprior.utils.sys_utils import create_temp_dir >>> temp_dir = create_temp_dir() >>> print(temp_dir) '/tmp/tmpabcd1234'
See also
create_temp_fileCreates a temporary file.
- geoprior.utils.sys_utils.create_temp_file(suffix='', prefix='tmp')[source]
Creates a temporary file and returns its path.
- Parameters:
- Returns:
file_path – The full path of the created temporary file.
- Return type:
str
Notes
This function is useful for handling data temporarily in applications where files need to be stored and accessed for a short time.
Examples
>>> from geoprior.utils.sys_utils import create_temp_file
>>> temp_file = create_temp_file()
>>> print(temp_file)
/tmp/tmpabcd1234
See also
create_temp_dir – Creates a temporary directory.
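For orientation, the standard-library calls that helpers like these typically build on are `tempfile.mkdtemp` and `tempfile.mkstemp`. The sketch below assumes nothing about how GeoPrior wraps them. It only shows that both calls return paths and that cleanup is the caller's responsibility.

```python
import os
import tempfile

# mkdtemp creates a directory and returns its absolute path.
temp_dir = tempfile.mkdtemp(prefix="tmp")

# mkstemp creates a file and returns (open OS-level descriptor, path).
fd, temp_file = tempfile.mkstemp(suffix=".txt", prefix="tmp", dir=temp_dir)
os.close(fd)  # close the descriptor; the file itself stays on disk

created = os.path.isdir(temp_dir) and os.path.isfile(temp_file)

# Neither call cleans up automatically, so remove both explicitly.
os.remove(temp_file)
os.rmdir(temp_dir)
```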
- geoprior.utils.sys_utils.environment_summary()[source]
Provides a summary of the current computing environment, including information on Python version, OS, CPU, memory, available GPU(s), and a list of installed Python packages.
- Returns:
env_info – Dictionary containing environment details, including:
python_version : The version of Python in use.
os : Operating system name.
os_version : Version of the operating system.
cpu_count : Number of logical CPU cores.
memory : Total system memory in GB.
device_count, device_name, memory_total, cuda_version (if available) : GPU details from detailed_gpu_info.
installed_packages : List of installed Python packages (first 10) in name==version format.
- Return type:
Notes
The function attempts to load installed packages using pkg_resources. If pkg_resources is not available, it defaults to “N/A” for installed packages.
- Raises:
ImportError – If pkg_resources is not installed.
RuntimeError – If an error occurs while gathering environment information.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import environment_summary
>>> env_info = environment_summary()
>>> print(env_info)
{'python_version': '3.9.5', 'os': 'Linux', 'os_version': '5.4.0-80-generic',
 'cpu_count': '4', 'memory': '15.5 GB', 'device_count': '1',
 'device_name': 'NVIDIA Tesla T4', 'memory_total': '15.99 GB',
 'cuda_version': '11.1',
 'installed_packages': 'numpy==1.21.0, pandas==1.3.0, ...'}
- geoprior.utils.sys_utils.find_by_regex(o, pattern, func=<function match>, **kws)[source]
Find a pattern in an object, whether or not the object is an iterable.
In this context a string is not treated as an iterable.
- Parameters:
o (str or iterable) – Text literal, or an iterable that may or may not contain the object to match.
pattern (str, default='[_#&*@!_,;\s-]\s*') – The base pattern used to split the text into columns.
func (re callable, default re.match) – Regular expression search function. Can be re.match, re.search, re.findall, or any other regular expression function.
re.match(): checks for a match only at the beginning of the string. It returns a match object if the pattern matches there and None otherwise, even if the pattern occurs later in the string.
re.search(): scans the whole string and returns a match object for the first location where the pattern matches, or None if the pattern is not found.
re.findall(): returns all non-overlapping matches of the pattern as a list, whereas re.search() returns only the first match.
kws (dict) – Additional keyword arguments passed to re.match(), re.search(), or re.findall().
- Returns:
om – The matched objects, collected in a list.
- Return type:
Example
>>> import re
>>> from geoprior.utils.sys_utils import find_by_regex
>>> from geoprior.datasets import load_hlogs
>>> X0, _ = load_hlogs(as_frame=True)
>>> columns = X0.columns
>>> str_columns = ','.join(columns)
>>> find_by_regex(str_columns, pattern='depth', func=re.search)
['depth']
>>> find_by_regex(columns, pattern='depth', func=re.search)
['depth_top', 'depth_bottom']
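The difference between the three `re` functions named above is easiest to see on a single string. The snippet uses only the standard `re` module. The column names are made up for illustration and are not taken from `load_hlogs`.

```python
import re

text = "well_id,depth_top,depth_bottom"  # hypothetical column names

# re.match anchors at the start of the string, so this is None.
at_start = re.match(r"depth", text)

# re.search scans the whole string and stops at the first hit.
first_hit = re.search(r"depth", text)

# re.findall returns every non-overlapping match as a list of strings.
all_hits = re.findall(r"depth_\w+", text)
```

This is why the `func` argument matters: with the default `re.match`, a pattern that only occurs mid-string yields no match.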
- geoprior.utils.sys_utils.find_similar_string(name, container, stripitems='_', deep=False)[source]
Find the most similar string in a container to the provided name.
This function searches for the most likely matching string in a container based on the provided name. It sanitizes the name by stripping specified characters and can perform a deep search to find partial matches.
- Parameters:
name (str) – The string to search for in the container.
container (list, tuple, or dict) – The container with strings to search in.
stripitems (str or list of str, optional) – Characters or strings to strip from name before searching. If a string, multiple items can be separated by ‘:’, ‘,’, or ‘;’. Default is '_'.
deep (bool, optional) – If True, performs a deeper search by checking if name is a substring of any item in the container. Default is False.
- Returns:
result – The most similar string from the container, or None if no match is found.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import find_similar_string
>>> container = {'dipole': 1, 'quadrupole': 2}
>>> find_similar_string('dipole_', container)
'dipole'
>>> find_similar_string('dip', container, deep=True)
'dipole'
>>> find_similar_string('+dipole__', container, stripitems='+;__', deep=True)
'dipole'
Notes
This function is useful when trying to find the closest matching string in a container, especially when exact matches are not guaranteed due to formatting inconsistencies or typos.
See also
str.strip – Returns a copy of the string with leading and trailing characters removed.
- geoprior.utils.sys_utils.get_cpu_usage(per_cpu=False)[source]
Returns the current CPU usage as a percentage, optionally providing per-core usage for systems with multiple cores.
- Parameters:
per_cpu (bool, default False) – If True, returns a list with the CPU usage percentage for each core. If False, returns the overall CPU usage as a single percentage.
- Returns:
usage – If per_cpu is False, returns the overall CPU usage as a float percentage. If per_cpu is True, returns a list with each entry corresponding to the usage percentage of an individual core.
- Return type:
Notes
This function uses the psutil library to retrieve CPU usage information and requires an interval of 1 second to calculate the usage accurately.
Examples
>>> from geoprior.utils.sys_utils import get_cpu_usage
>>> get_cpu_usage()
1.3
>>> get_cpu_usage(per_cpu=True)
[20.4, 25.1, 21.3, 24.5]
- geoprior.utils.sys_utils.get_disk_usage(path='/')[source]
Returns disk usage statistics for a specified filesystem path, including total, used, and free disk space in gigabytes (GB).
- Parameters:
path (str, default '/') – The filesystem path for which to check disk usage statistics. By default, it uses the root directory (/).
- Returns:
disk_usage – A tuple containing:
total_disk : Total disk space in GB.
used_disk : Used disk space in GB.
free_disk : Free disk space in GB.
- Return type:
Notes
Disk usage information is gathered using the psutil library. Disk space is converted to gigabytes (GB) by dividing the values by 1024^3.
- Raises:
FileNotFoundError – If the specified path does not exist on the filesystem.
PermissionError – If the program does not have permission to access the specified path.
- Parameters:
path (str)
- Return type:
Examples
>>> from geoprior.utils.sys_utils import get_disk_usage
>>> total, used, free = get_disk_usage(path="/")
>>> print(f"Total: {total} GB, Used: {used} GB, Free: {free} GB")
Total: 256 GB, Used: 128 GB, Free: 128 GB
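The bytes-to-gigabytes conversion described in the Notes can be reproduced without psutil. The sketch below swaps in the standard library's `shutil.disk_usage` purely for illustration. `disk_usage_gb` is a hypothetical name, and the documented function itself relies on psutil.

```python
import shutil

GIB = 1024 ** 3  # bytes per gigabyte, as in the Notes (1024^3)


def disk_usage_gb(path="/"):
    """Hypothetical stand-in: (total, used, free) in GB for `path`."""
    usage = shutil.disk_usage(path)
    return usage.total / GIB, usage.used / GIB, usage.free / GIB


# "." keeps the demonstration portable across operating systems.
total, used, free = disk_usage_gb(".")
```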
- geoprior.utils.sys_utils.get_gpu_info()[source]
Provides detailed information about available GPUs, including device name, memory capacity, and CUDA version (if PyTorch is installed).
- Returns:
gpu_info – Dictionary containing GPU details, including:
device_count : Number of available GPU devices.
device_name : Name of the first GPU device.
memory_total : Total memory of the first GPU device in GB.
cuda_version : CUDA version, if available.
If no GPU is available or PyTorch is not installed, returns None.
- Return type:
Notes
This function requires PyTorch to check for GPU availability. If PyTorch is not installed, it logs a warning and returns None.
- Raises:
ImportError – If PyTorch is not installed on the system.
RuntimeError – If there is an issue retrieving GPU properties.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import get_gpu_info
>>> gpu_info = get_gpu_info()
>>> print(gpu_info)
{'device_count': '1', 'device_name': 'NVIDIA Tesla T4',
 'memory_total': '15.99 GB', 'cuda_version': '11.1'}
- geoprior.utils.sys_utils.get_installed_packages()[source]
Lists all installed packages along with their versions in the current Python environment.
- Returns:
installed_packages – A list of installed packages and their versions in the format package_name==version.
- Return type:
Notes
This function is useful for dependency management and tracking installed packages, especially in data science and production environments.
Examples
>>> from geoprior.utils.sys_utils import get_installed_packages
>>> get_installed_packages()
['numpy==1.21.0', 'pandas==1.3.0', 'scikit-learn==0.24.2', ...]
See also
environment_summary – Summarizes the environment, including installed packages.
- geoprior.utils.sys_utils.get_memory_usage()[source]
Retrieves system memory usage statistics, providing the total, used, and available memory in megabytes (MB).
- Returns:
memory – A tuple containing:
total_memory : Total memory in MB.
used_memory : Used memory in MB.
available_memory : Available memory in MB.
- Return type:
Notes
This function leverages the psutil library for retrieving memory usage information. The conversion to MB is performed by dividing each value by 1024^2.
Examples
>>> from geoprior.utils.sys_utils import get_memory_usage
>>> total, used, available = get_memory_usage()
>>> print(f"Total: {total} MB, Used: {used} MB, Available: {available} MB")
Total: 8192 MB, Used: 4096 MB, Available: 4096 MB
- geoprior.utils.sys_utils.get_python_version()[source]
Returns the version of Python being used in the current environment.
- Returns:
python_version – The version of Python currently in use.
- Return type:
str
Examples
>>> from geoprior.utils.sys_utils import get_python_version
>>> get_python_version()
'3.8.5'
See also
get_system_info – Provides broader system information, including Python version.
- geoprior.utils.sys_utils.get_system_info()[source]
Retrieves basic system information including OS, Python version, CPU details, and GPU availability.
- Returns:
system_info – A dictionary containing basic system information:
os_name : Name of the operating system.
os_version : Version of the operating system.
python_version : Python version.
cpu_count : Number of logical CPUs.
gpu_available : Whether a GPU is available (True or False).
- Return type:
Notes
This function checks for GPU availability via PyTorch if installed, otherwise it defaults to False.
Examples
>>> from geoprior.utils.sys_utils import get_system_info
>>> get_system_info()
{'os_name': 'Linux', 'os_version': '5.4.0-81-generic',
 'python_version': '3.8.5', 'cpu_count': '8', 'gpu_available': 'True'}
See also
get_python_version – Retrieves the current Python version.
- geoprior.utils.sys_utils.get_uptime()[source]
Returns the system uptime in a human-readable format.
- Returns:
uptime – The system uptime formatted as “Xd:Yh:Zm:Ws”, where X, Y, Z, and W are days, hours, minutes, and seconds respectively.
- Return type:
str
Notes
This function is useful for monitoring or diagnosing long-running processes on the system.
Examples
>>> from geoprior.utils.sys_utils import get_uptime
>>> get_uptime()
'2d:5h:34m:12s'
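The “Xd:Yh:Zm:Ws” layout is a plain decomposition of a number of seconds. A minimal sketch of that formatting step follows. It assumes the uptime is already available as seconds, and `format_uptime` is a hypothetical helper, not part of GeoPrior.

```python
def format_uptime(total_seconds):
    """Hypothetical helper: render seconds as 'Xd:Yh:Zm:Ws'."""
    days, remainder = divmod(int(total_seconds), 86_400)  # seconds per day
    hours, remainder = divmod(remainder, 3_600)           # seconds per hour
    minutes, seconds = divmod(remainder, 60)
    return f"{days}d:{hours}h:{minutes}m:{seconds}s"


# 2 days, 5 hours, 34 minutes and 12 seconds expressed in seconds.
example = format_uptime(2 * 86_400 + 5 * 3_600 + 34 * 60 + 12)
```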
- geoprior.utils.sys_utils.is_gpu_available()[source]
Checks if a GPU is available for computation on the system, using the PyTorch library if it is installed.
- Returns:
available – True if a GPU is available, False otherwise.
- Return type:
Notes
This function relies on the torch library (PyTorch) to detect GPU availability. If PyTorch is not installed, it logs a warning and returns False.
- Raises:
ImportError – If PyTorch is not installed and thus the GPU availability check cannot be performed.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import is_gpu_available
>>> is_gpu_available()
True
- geoprior.utils.sys_utils.is_package_installed(package_name)[source]
Checks if a specific package is installed in the current Python environment.
- Parameters:
package_name (str) – The name of the package to check.
- Returns:
True if the package is installed, otherwise False.
- Return type:
bool
Examples
>>> from geoprior.utils.sys_utils import is_package_installed
>>> is_package_installed("numpy")
True
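One common way to answer this question without importing the package is `importlib.util.find_spec`. The sketch below is an assumption about a reasonable approach, not a description of how `is_package_installed` is implemented. `package_installed` is a hypothetical name.

```python
import importlib.util


def package_installed(name):
    """Hypothetical helper: True if a top-level module `name` can be located."""
    # find_spec returns None when no importable module with that name exists.
    return importlib.util.find_spec(name) is not None


has_json = package_installed("json")  # standard-library module
has_fake = package_installed("surely_not_a_real_package_xyz")
```

Note that `find_spec` takes an import name, which may differ from the distribution name used by pip (for example, `sklearn` versus `scikit-learn`).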
- geoprior.utils.sys_utils.is_path_accessible(path, permissions='r')[source]
Checks if a specified path is accessible with the given permissions.
- Parameters:
- Returns:
accessible – True if the path is accessible with the specified permissions, otherwise False.
- Return type:
Notes
This function verifies file permissions in the current user context, ensuring flexibility for multi-user environments.
Examples
>>> from geoprior.utils.sys_utils import is_path_accessible
>>> is_path_accessible("/path/to/file", permissions="rw")
True
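A permission string such as `"rw"` maps naturally onto the flags accepted by `os.access`. The following sketch assumes that mapping. `path_accessible` and the `_FLAGS` table are hypothetical and only illustrate the idea.

```python
import os
import tempfile

# Assumed mapping from permission letters to os.access flags.
_FLAGS = {"r": os.R_OK, "w": os.W_OK, "x": os.X_OK}


def path_accessible(path, permissions="r"):
    """Hypothetical helper: check `path` for every requested permission."""
    mode = os.F_OK  # existence check; other flags are OR-ed on top
    for letter in permissions:
        mode |= _FLAGS[letter]
    return os.access(path, mode)


with tempfile.NamedTemporaryFile() as tmp:
    readable_writable = path_accessible(tmp.name, "rw")

missing_path = os.path.join(tempfile.gettempdir(), "no_such_file_for_demo_xyz")
missing = path_accessible(missing_path, "r")
```

`os.access` evaluates permissions for the current user, which matches the note above about the current user context.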
- geoprior.utils.sys_utils.is_port_open(port)[source]
Checks whether a specified network port on the local machine is open (free) or already occupied.
- Parameters:
port (int) – The port number to check for availability.
- Returns:
Returns True if the port is open (not in use), otherwise False.
- Return type:
bool
Notes
This function uses a socket connection to check if the specified port is open. It is helpful in applications where network services or applications need to bind to a specific port.
- Raises:
ValueError – If an invalid port number is provided.
- Parameters:
port (int)
- Return type:
Examples
>>> from geoprior.utils.sys_utils import is_port_open
>>> is_port_open(8080)
False
- geoprior.utils.sys_utils.manage_env_variable(var_name, value=None, default=None, action='get', file_path=None, overwrite=False)[source]
Manages environment variables, allowing retrieval, setting, or loading from a .env file.
- Parameters:
var_name (str) – The name of the environment variable to retrieve, set, or load.
value (str, optional) – The value to set for the environment variable. Only used if action is “set”. Default is None.
default (str, optional) – The default value to return if the environment variable var_name is not found when action is “get”. If None, returns None when the variable is not found. Default is None.
action (str, default "get") – The action to perform. Options are:
“get”: Retrieves the environment variable var_name.
“set”: Sets the environment variable var_name to value.
“load”: Loads environment variables from a .env file specified by file_path.
file_path (str, optional) – The path to the .env file to load variables from when action is “load”. Required if action is “load”.
overwrite (bool, default False) – If True, allows overwriting existing environment variables when action is “load” or “set”. If False, preserves the current value of existing environment variables.
- Returns:
result –
If action is “get”, returns the value of the environment variable var_name or default if the variable is not set.
If action is “set” or “load”, returns None.
- Return type:
Notes
This function is useful for managing configuration data securely by utilizing environment variables.
Loading from a .env file allows you to define multiple variables in a single file, each defined in the KEY=VALUE format.
- Raises:
ValueError – If action is “set” and value is not provided, or if action is “load” and file_path is not specified.
FileNotFoundError – If action is “load” and file_path does not exist.
- Parameters:
- Return type:
str | None
Examples
>>> from geoprior.utils.sys_utils import manage_env_variable
>>> manage_env_variable('HOME', action='get')
'/home/username'
>>> manage_env_variable('NEW_VAR', value='new_value', action='set')
>>> manage_env_variable('NEW_VAR', action='get')
'new_value'
>>> manage_env_variable('NON_EXISTENT_VAR', default='default_value', action='get')
'default_value'
>>> manage_env_variable(
...     var_name='', action='load', file_path='/path/to/.env', overwrite=True)
See also
os.getenv – Retrieves environment variables.
os.environ – Provides access to the environment variables.
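The interaction between the “load” action and `overwrite` is easier to follow with a concrete parser. The sketch below is an assumption about how KEY=VALUE loading could work. `load_env_file` is hypothetical, and skipping blank lines and `#` comments is a choice made here, not documented GeoPrior behaviour.

```python
import os
import tempfile


def load_env_file(file_path, overwrite=False):
    """Hypothetical loader for a KEY=VALUE file into os.environ."""
    with open(file_path, encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            # Assumption: blank lines, comments and malformed lines are skipped.
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            key, value = key.strip(), value.strip()
            if overwrite or key not in os.environ:
                os.environ[key] = value


os.environ["GEOPRIOR_DEMO_KEEP"] = "original"
os.environ.pop("GEOPRIOR_DEMO_NEW", None)
with tempfile.NamedTemporaryFile("w", suffix=".env", delete=False) as tmp:
    tmp.write("# demo file\nGEOPRIOR_DEMO_KEEP=from_file\nGEOPRIOR_DEMO_NEW=42\n")

load_env_file(tmp.name)                  # existing variable is preserved
kept = os.environ["GEOPRIOR_DEMO_KEEP"]
new = os.environ["GEOPRIOR_DEMO_NEW"]
load_env_file(tmp.name, overwrite=True)  # existing variable is replaced
replaced = os.environ["GEOPRIOR_DEMO_KEEP"]
os.remove(tmp.name)
```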
- geoprior.utils.sys_utils.manage_file_lock(file_path, action='lock', blocking=True, exclusive=True)[source]
Manages file locking and unlocking to prevent concurrent access.
This function allows both locking and unlocking actions on a file to prevent or allow concurrent access. It opens the file and applies an exclusive lock or shared lock, depending on the parameters specified.
- Parameters:
file_path (str) – Path to the file that needs to be locked or unlocked.
action (str, default "lock") – Specifies the action to perform: “lock” to acquire a lock, or “unlock” to release a previously acquired lock.
blocking (bool, default True) – If True, the lock will block until it can be acquired. If False, the lock will raise an exception if it cannot be acquired immediately.
exclusive (bool, default True) – If True, an exclusive lock is applied. If False, a shared lock is applied (other processes can read the file simultaneously).
- Returns:
file_descriptor – If action is “lock”, returns the file descriptor on success; otherwise, None if action is “unlock” or if locking fails.
- Return type:
Notes
This function uses the fcntl module for locking, which is only available on Unix-based systems. The lock is maintained as long as the file descriptor remains open.
For “lock”, the function opens the file and applies a lock.
For “unlock”, it removes the lock and closes the file descriptor.
- Raises:
ValueError – If the action parameter is not one of “lock” or “unlock”.
OSError – If locking or unlocking the file fails.
- Parameters:
- Return type:
int | None
Examples
>>> from geoprior.utils.sys_utils import manage_file_lock
>>> fd = manage_file_lock("/path/to/file", action="lock", blocking=True)
>>> if fd:
...     print("File is locked.")
...     manage_file_lock(fd, action="unlock")
...     print("File is unlocked.")
See also
os.open – Opens a file descriptor.
fcntl.flock – Applies or removes file locks.
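The `blocking` and `exclusive` parameters correspond directly to `fcntl.flock` flags. The Unix-only sketch below shows that correspondence under the assumption that locking is done with `flock`. `lock_file` and `unlock_file` are hypothetical helpers, not the GeoPrior API.

```python
import fcntl
import os
import tempfile


def lock_file(path, exclusive=True, blocking=True):
    """Hypothetical helper: open `path`, lock it, return the descriptor."""
    fd = os.open(path, os.O_RDWR | os.O_CREAT)
    flags = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
    if not blocking:
        flags |= fcntl.LOCK_NB  # fail immediately instead of waiting
    try:
        fcntl.flock(fd, flags)
    except OSError:
        os.close(fd)
        raise
    return fd


def unlock_file(fd):
    """Release the lock and close the descriptor."""
    fcntl.flock(fd, fcntl.LOCK_UN)
    os.close(fd)


lock_dir = tempfile.mkdtemp()
lock_path = os.path.join(lock_dir, "demo.lock")
fd = lock_file(lock_path)
try:
    # A second descriptor on the same file cannot take the exclusive lock.
    lock_file(lock_path, blocking=False)
    second_lock_refused = False
except BlockingIOError:
    second_lock_refused = True
unlock_file(fd)
os.remove(lock_path)
os.rmdir(lock_dir)
```

Because the lock lives with the open descriptor, closing the descriptor is what ultimately releases it, as the Notes above state.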
- geoprior.utils.sys_utils.manage_temp(suffix='', prefix='tmp', action='create_file', directory=None, clean_all=False)[source]
Manages temporary files and directories by creating, accessing, or cleaning them as needed.
- Parameters:
suffix (str, optional) – Suffix for the temporary file or directory, used only if action is “create_file” or “create_dir”. Default is an empty string.
prefix (str, optional) – Prefix for the temporary file or directory, used only if action is “create_file” or “create_dir”. Default is “tmp”.
action (str, default "create_file") – Specifies the operation to perform. Options include:
“create_file”: Creates a temporary file and returns its path.
“create_dir”: Creates a temporary directory and returns its path.
“clean”: Cleans temporary files in the specified directory or the system temp directory if none is provided.
directory (str, optional) – Directory to clean when action is “clean”. If None, uses the system’s default temporary directory. Ignored for file or directory creation actions.
clean_all (bool, default False) – If True, removes all files and directories within the specified directory. If False, only deletes files or directories created by this process. Used only when action is “clean”.
- Returns:
temp_path –
For “create_file” and “create_dir” actions, returns the path of the created file or directory.
For “clean”, returns None.
- Return type:
- Raises:
ValueError – If an invalid action is specified.
Notes
This function is useful for managing temporary resources in data processing tasks, where files or directories need to be created and cleaned up after use.
Examples
>>> from geoprior.utils.sys_utils import manage_temp
>>> temp_file = manage_temp(action="create_file")
>>> print(temp_file)
/tmp/tmpabcd1234
>>> temp_dir = manage_temp(action="create_dir", prefix="data_")
>>> print(temp_dir)
/tmp/data_abcd1234
>>> manage_temp(action="clean", directory="/path/to/temp", clean_all=True)
- geoprior.utils.sys_utils.parallelize_jobs(function, tasks=(), n_jobs=None, executor_type='process')[source]
Parallelize the execution of a callable across multiple processors, supporting both positional and keyword arguments.
- Parameters:
function (Callable[..., Any]) – The function to execute in parallel. This function must be picklable if using executor_type=’process’.
tasks (Sequence[Dict[str, Any]], optional) – A sequence of dictionaries, where each dictionary contains two keys: ‘args’ (a tuple) for positional arguments, and ‘kwargs’ (a dict) for keyword arguments, for one execution of function. Defaults to an empty sequence.
n_jobs (Optional[int], optional) – The number of jobs to run in parallel. None or 1 uses a single processor, any positive integer specifies the exact number of processors to use, -1 uses all available processors. Default is None (1 processor).
executor_type (str, optional) – The type of executor to use. Can be ‘process’ for CPU-bound tasks or ‘thread’ for I/O-bound tasks. Default is ‘process’.
- Returns:
A list of results from the function executions.
- Return type:
- Raises:
ValueError – If function is not picklable when using ‘process’ as executor_type.
Examples
>>> from geoprior.utils.sys_utils import parallelize_jobs
>>> def greet(name, greeting='Hello'):
...     return f"{greeting}, {name}!"
>>> tasks = [
...     {'args': ('John',), 'kwargs': {'greeting': 'Hi'}},
...     {'args': ('Jane',), 'kwargs': {}}
... ]
>>> results = parallelize_jobs(greet, tasks, n_jobs=2)
>>> print(results)
['Hi, John!', 'Hello, Jane!']
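The task layout above (`'args'` plus `'kwargs'` per call) maps onto `concurrent.futures` in a few lines. The sketch below is a simplified stand-in under stated assumptions. `run_tasks` is hypothetical, it defaults to threads so the demonstration runs without a `__main__` guard, and it does not reproduce the documented `n_jobs=None` or `-1` conventions.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def run_tasks(function, tasks=(), n_jobs=2, executor_type="thread"):
    """Hypothetical helper: run `function` once per task dict, in order."""
    if executor_type == "process":
        executor_cls = ProcessPoolExecutor  # function must be picklable
    else:
        executor_cls = ThreadPoolExecutor
    with executor_cls(max_workers=n_jobs) as pool:
        futures = [
            pool.submit(function, *task.get("args", ()), **task.get("kwargs", {}))
            for task in tasks
        ]
        # Collect in submission order so results line up with tasks.
        return [future.result() for future in futures]


def greet(name, greeting="Hello"):
    return f"{greeting}, {name}!"


tasks = [
    {"args": ("John",), "kwargs": {"greeting": "Hi"}},
    {"args": ("Jane",), "kwargs": {}},
]
results = run_tasks(greet, tasks, n_jobs=2)
```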
- geoprior.utils.sys_utils.represent_callable(obj, skip=None)[source]
Represent callable objects by formatting their signatures.
This function generates a string representation of a callable object’s signature, including parameters and default values. It supports classes, functions, and instance methods.
- Parameters:
- Returns:
representation – A string representation of the callable object’s signature.
- Return type:
- Raises:
TypeError – If obj is not a callable object.
Examples
>>> from geoprior.utils.sys_utils import represent_callable
>>> def example_function(a, b=2):
...     pass
>>> represent_callable(example_function)
'example_function(a, b=2)'
>>> class ExampleClass:
...     def __init__(self, x, y=10):
...         self.x = x
...         self.y = y
>>> represent_callable(ExampleClass)
'ExampleClass(x, y=10)'
>>> instance = ExampleClass(5)
>>> represent_callable(instance)
'ExampleClass(x=5, y=10)'
Notes
This function is useful for logging or displaying the parameters of callable objects in a readable format.
See also
inspect.signature – Get a signature object for the callable.
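For functions and classes, the representation shown in the examples can be approximated directly with `inspect.signature`. The sketch below is a reduced illustration. `describe_callable` is a hypothetical name, and it does not handle the instance case from the last example above.

```python
import inspect


def describe_callable(obj):
    """Hypothetical helper: 'name(signature)' for a function or class."""
    if not callable(obj):
        raise TypeError(f"{obj!r} is not callable")
    name = getattr(obj, "__name__", type(obj).__name__)
    # For a class, inspect.signature reports __init__ without `self`.
    return f"{name}{inspect.signature(obj)}"


def example_function(a, b=2):
    pass


class ExampleClass:
    def __init__(self, x, y=10):
        self.x = x
        self.y = y


function_repr = describe_callable(example_function)
class_repr = describe_callable(ExampleClass)
```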
- geoprior.utils.sys_utils.run_command(command, capture_output=True)[source]
Runs a shell command and optionally captures its output.
- Parameters:
- Returns:
output – Returns the command output as a string if capture_output is True. If capture_output is False, returns None.
- Return type:
Notes
This function uses subprocess.run to execute shell commands, which allows for error handling and logging. For example, run_command("echo Hello World") returns "Hello World\n" when capture_output=True.
- Raises:
subprocess.CalledProcessError – If the command exits with a non-zero status and capture_output is True.
- Parameters:
- Return type:
str | None
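The behaviour described in the Notes can be reproduced with a thin wrapper around `subprocess.run`. The sketch below makes some assumptions: `run` is a hypothetical name, `shell=True` is used because the entry speaks of shell commands, and `check=True` is what turns a non-zero exit status into `subprocess.CalledProcessError`.

```python
import subprocess


def run(command, capture_output=True):
    """Hypothetical wrapper: run a shell command, optionally return stdout."""
    completed = subprocess.run(
        command,
        shell=True,                     # interpret `command` with the shell
        check=True,                     # raise CalledProcessError on failure
        capture_output=capture_output,
        text=True,                      # decode output to str
    )
    return completed.stdout if capture_output else None


output = run("echo Hello World")
```

With `shell=True`, the command string should come from trusted input only, since the shell interprets it verbatim.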
- geoprior.utils.sys_utils.safe_getattr(obj, name, default_value=None)[source]
Safely get an attribute from an object, with a helpful error message.
This function attempts to retrieve an attribute from the given object. If the attribute is not found, it can return a default value or raise an AttributeError with a suggestion for a similar attribute.
- Parameters:
- Returns:
value – The value of the retrieved attribute or the default value.
- Return type:
any
- Raises:
AttributeError – If the attribute is not found and no default value is provided.
Examples
>>> from geoprior.utils.sys_utils import safe_getattr
>>> class MyClass:
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
>>> obj = MyClass(1, 2)
>>> safe_getattr(obj, 'a')
1
>>> safe_getattr(obj, 'c', default_value='default')
'default'
>>> safe_getattr(obj, 'c')
Traceback (most recent call last):
    ...
AttributeError: 'MyClass' object has no attribute 'c'. Did you mean 'a'?
Notes
This function enhances the built-in getattr by providing helpful suggestions when an attribute is not found.
See also
getattr – Built-in function to get an attribute from an object.
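A "did you mean" suggestion of this kind can be built on `difflib.get_close_matches`. The sketch below is one possible approach, not the GeoPrior implementation. `getattr_with_hint` is hypothetical, and it uses a private sentinel rather than `None` so that `None` remains a usable default.

```python
import difflib

_MISSING = object()  # sentinel: distinguishes "no default" from default=None


def getattr_with_hint(obj, name, default=_MISSING):
    """Hypothetical helper: getattr with a suggestion in the error message."""
    try:
        return getattr(obj, name)
    except AttributeError:
        if default is not _MISSING:
            return default
        public = [attr for attr in dir(obj) if not attr.startswith("_")]
        close = difflib.get_close_matches(name, public, n=1)
        hint = f" Did you mean {close[0]!r}?" if close else ""
        raise AttributeError(
            f"{type(obj).__name__!r} object has no attribute {name!r}.{hint}"
        ) from None


class Station:
    def __init__(self):
        self.depth = 12.5
        self.name = "S1"


station = Station()
fallback = getattr_with_hint(station, "missing", default=0)
try:
    getattr_with_hint(station, "dept")
except AttributeError as exc:
    message = str(exc)
```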
- geoprior.utils.sys_utils.safe_optimize(func=None, *, parallelize=True, memory_cleanup=False, log_level=20, optimize_cpu=True, num_processes=None, cpu_cores=None, verbose=True, mode='strict')[source]
Optimizes a workflow by wrapping a function to measure execution time, enable parallelization, manage resources, and perform memory cleanup. It behaves similarly to the class-based decorator WorflowOptimizer.
Class-based decorators can sometimes encounter issues when trying to pickle certain objects, especially in parallel execution contexts. This issue arises because certain objects (such as file handles, open network connections, or non-serializable class instances) cannot be passed between processes in multiprocessing environments. By ensuring compatibility with these contexts, safe_optimize helps mitigate such issues and optimize the execution of computationally intensive workflows.
This decorator is particularly suitable for workflows involving large-scale computations, such as data processing pipelines, machine learning model training, or simulations, where parallel execution and resource optimization are crucial for performance improvement.
- Parameters:
parallelize (bool, optional) – Flag to enable or disable parallel processing (default is True).
memory_cleanup (bool, optional) – Whether to clean up system memory after execution (default is False).
log_level (int, optional) – Level of logging (default is logging.INFO). Set to logging.DEBUG for more detailed logs.
optimize_cpu (bool, optional) – Whether to optimize CPU core usage (default is True).
num_processes (Optional[int], optional) – The number of parallel processes for execution (default is None).
cpu_cores (Optional[List[int]], optional) – Specify a list of CPU cores to restrict the process (default is None).
verbose (bool, optional) – Whether to print detailed logs during execution (default is True).
mode (str, optional) – Mode for handling pickling issues: 'strict' to raise errors, or 'soft' to fall back to sequential execution with warnings (default is 'strict').
func (Callable | None)
- Returns:
decorator – The wrapped function that includes optimization strategies.
- Return type:
Callable
- Raises:
ValueError – If an unsupported mode is specified.
Examples
>>> import logging
>>> from geoprior.utils.sys_utils import safe_optimize
>>> @safe_optimize(
...     parallelize=True,
...     memory_cleanup=True,
...     log_level=logging.DEBUG,
...     optimize_cpu=True,
...     num_processes=4,
...     cpu_cores=[0, 1, 2, 3],
...     verbose=True,
...     mode='soft'
... )
... def process_data(data):
...     # Your data processing logic here
...     return [d * 2 for d in data]
>>> data = [1, 2, 3, 4, 5]
>>> results = process_data(data)
>>> print(results)
[2, 4, 6, 8, 10]
Notes
This decorator uses multiprocessing for parallel execution, which may not be suitable for all environments, especially those that do not support forking (e.g., some Windows configurations).
Ensure that the decorated function and its arguments are picklable when using parallelization.
The mode parameter allows handling non-picklable objects gracefully.
See also
multiprocessing.Pool – For parallel task execution.
psutil – For system and process utilities.
functools.wraps – For preserving metadata of decorated functions.
- geoprior.utils.sys_utils.system_uptime()[source]
Retrieves the system uptime, which is the duration the system has been running since the last boot, in a human-readable format.
- Returns:
uptime – System uptime in the format “Xd:Yh:Zm:Ws”, where X, Y, Z, and W represent days, hours, minutes, and seconds, respectively.
- Return type:
Notes
This function is cross-platform and works on Windows, macOS (Darwin), and Linux. It uses different commands to retrieve uptime based on the operating system.
- Raises:
NotImplementedError – If the function is called on an unsupported operating system.
RuntimeError – If an error occurs while retrieving uptime.
- Return type:
Examples
>>> from geoprior.utils.sys_utils import system_uptime
>>> system_uptime()
'2d:10h:33m:12s'
- geoprior.utils.sys_utils.build_large_df(forecast_results, dt_col, tname, spatial_cols=None, chunk_size=None, verbose=0)[source]
Construct memory-optimized DataFrame from large forecast results using chunked processing.
Implements dynamic chunk sizing and dtype optimization to handle datasets exceeding available memory. Uses temporary storage and parallel processing for efficient resource utilization.
- If pyarrow is installed, the function uses parquet I/O; otherwise,
CSV files are used as a fallback.
- Parameters:
forecast_results (List[Dict]) – Input data as list of dictionary records. Each dictionary represents a row with column-value pairs. Minimum 1000 entries recommended for chunking benefits.
dt_col (str) – Name of temporal column. Accepts numeric years (e.g., 2023) or datetime strings. Automatic type detection with fallback to numpy.int32 for years >200000.
tname (str) – Target variable prefix for prediction columns. Quantile columns are expected in the form f"{tname}_q{quantile}" such as "subs_q10", while point predictions use f"{tname}_pred".
spatial_cols (List[str], optional) – Geographic columns (e.g., ['longitude', 'latitude']). Auto-detects categorical (<10% unique values) vs continuous spatial data, using pandas.Category or numpy.float32 dtypes respectively.
chunk_size (int, optional) – Maximum rows per chunk. Auto-calculated using
\[C_{optimal} = \min\left(10^5, \frac{0.8M_{free}}{S_{row}}\right) \tag{32}\]
where \(M_{free}\) is available memory in bytes and \(S_{row}\) is the estimated row size, assumed to be about 1 KB by default.
verbose (int, default 0) – Logging verbosity. Use 0 for silent mode, 1 for memory reports, 2 for chunk diagnostics, and 3 for per-chunk metrics.
- Returns:
Combined DataFrame with optimized dtypes, preserving original column order. Memory footprint reduced by 40-60% compared to naive construction.
- Return type:
pd.DataFrame
Examples
>>> from geoprior.utils.sys_utils import build_large_df
>>> import numpy as np
>>> # Basic usage with 1M rows
>>> data = [{'year': y, 'value_q50': np.random.randn()}
...         for y in range(2010, 2020) for _ in range(100000)]
>>> df = build_large_df(data, dt_col='year', tname='value')
>>> # With spatial columns
>>> geo_data = [{'lat': np.random.uniform(-90, 90),
...              'lon': np.random.uniform(-180, 180),
...              'pred': np.random.randn()}
...             for _ in range(500000)]
>>> df = build_large_df(geo_data, dt_col='date', tname='pred',
...                     spatial_cols=['lat', 'lon'], verbose=2)
Notes
Key implementation features include dynamic chunk adjustment using psutil.virtual_memory(), concurrent chunk reading with ThreadPoolExecutor when many chunks are present, dtype inference for temporal and spatial columns, and guaranteed tempfile cleanup via try...finally blocks.
See also
pd.DataFrame – Base DataFrame construction
pd.concat – Chunk aggregation method
geoprior.nn.utils.generate_forecast – Primary data source
geoprior.utils.memory_optimizer.reduce_mem_usage – Detailed dtype optimization
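The automatic chunk-size rule for `build_large_df`, \(C_{optimal} = \min(10^5, 0.8\,M_{free}/S_{row})\), can be evaluated by hand. The sketch below only restates that formula. `optimal_chunk_size` is a hypothetical helper, and truncating to an integer is an assumption made here.

```python
def optimal_chunk_size(free_bytes, row_bytes=1024, cap=100_000):
    """Hypothetical helper: min(cap, 0.8 * M_free / S_row), truncated to int."""
    return int(min(cap, 0.8 * free_bytes / row_bytes))


# With 1 GiB free, the 10^5 cap is the binding constraint.
plenty = optimal_chunk_size(1024 ** 3)

# With 50 MiB free, memory is the binding constraint: 0.8 * 50 MiB / 1 KiB.
tight = optimal_chunk_size(50 * 1024 ** 2)
```

At the default 1 KB row size, the cap of 10^5 rows applies whenever more than roughly 122 MiB is free, so the memory term only matters on constrained machines.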
These modules help manage serialization, joblib artifacts, threading, runtime-device behavior, and larger workflow-side assembly patterns.
Scale, shape, and sequence helpers#
Utilities for computing error metrics in physical units given Stage-1 scaling metadata.
Main entry points include inverse_scale_target(...),
point_metrics(...), and per_horizon_metrics(...).
These are designed to work with the NATCOM pipeline where Stage-1
stores a scaler_info dict in the manifest:
scaler_info[target_name] = {
"scaler_path": ".../main_scaler.joblib",
"all_features": [...],
"idx": <index in scaler>,
"scaler": <fitted MinMaxScaler>, # optional, attached at Stage-2
}
They can also handle a bare scaler object, a scaler path, or manual min/max or mean/std parameters.
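To make the "manual min/max" case concrete, the sketch below inverts a min-max scaling by hand and compares an error metric in scaled and physical units. It is pure Python and does not call the GeoPrior functions. The helper names and the numeric values are illustrative assumptions, and it assumes the scaler mapped the data onto the range [0, 1].

```python
def inverse_minmax(scaled, data_min, data_max):
    """Hypothetical helper: undo x_scaled = (x - min) / (max - min)."""
    span = data_max - data_min
    return [value * span + data_min for value in scaled]


def mean_absolute_error(y_true, y_pred):
    """Plain MAE over two equally long sequences."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Made-up scaled targets and predictions over three horizons.
y_true_scaled = [0.0, 0.5, 1.0]
y_pred_scaled = [0.1, 0.5, 0.9]

# Assumed physical range of the target, e.g. -40 to 0 in physical units.
y_true_phys = inverse_minmax(y_true_scaled, data_min=-40.0, data_max=0.0)
y_pred_phys = inverse_minmax(y_pred_scaled, data_min=-40.0, data_max=0.0)

mae_scaled = mean_absolute_error(y_true_scaled, y_pred_scaled)
mae_phys = mean_absolute_error(y_true_phys, y_pred_phys)
```

For a min-max scaler the physical MAE is the scaled MAE multiplied by the span (40 here), which is why metrics reported in scaled space can look deceptively small.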
- geoprior.utils.scale_metrics.evaluate_point_forecast(model, out, y_true_subs, *, y_true_gwl=None, n_q=3, quantiles=None, q=None, use_physical=False, return_physical=None, scaler_info=None, subs_target_name='subsidence', gwl_target_name='gwl', scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None, strict=True, output_names=None)[source]
End-to-end helper for point-forecast evaluation.
Pipeline: extract predictions from model and out, canonicalize BHQO quantile outputs when needed, pick a point prediction, optionally inverse-scale it, and compute global and per-horizon metrics.
- Parameters:
model (Any) – Passed to extract_preds(…).
out (Any) – Passed to extract_preds(…).
y_true_subs (Any) – True subsidence, shape (B, H, 1) or (B, H).
y_true_gwl (Any | None) – Optional true gwl/head, same shape conventions.
n_q (int), quantiles (Sequence[float] | None), q (float | int | None) – Quantile-selection controls for (B, H, Q, 1) outputs. If q is None, the median is preferred. If q is an integer, it is treated as a direct quantile index. If q is a float and quantiles is provided, the nearest quantile is selected; otherwise the float is treated as a fraction in [0, 1].
use_physical (bool) – If True, compute metrics in physical units.
return_physical (bool | None) – If None, defaults to use_physical. If True, return physical-space arrays such as subs_pred_phys and gwl_pred_phys when possible.
scaler_info (Mapping[str, Any] | None), scaler_entry (Mapping[str, Any] | None), scaler (Any | None), feature_index (int | None), n_features (int | None), params (Mapping[str, float] | None) – Passed through to inverse_scale_target(...).
subs_target_name (str)
gwl_target_name (str)
strict (bool)
- Returns:
Dictionary containing model-space predictions, optional physical-space predictions, and global and per-horizon metrics for subsidence and groundwater outputs.
- Return type:
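The quantile-selection rules described for `q`, `quantiles`, and `n_q` can be sketched as follows. This is an illustration only, not the library's code: the helper name `pick_quantile_index` is hypothetical, and mapping a bare fraction to an index by rounding `q * (n_q - 1)` is an assumption, since the text above only says the float is treated as a fraction in [0, 1].

```python
from typing import Optional, Sequence, Union

import numpy as np


def pick_quantile_index(
    n_q: int,
    q: Optional[Union[int, float]] = None,
    quantiles: Optional[Sequence[float]] = None,
) -> int:
    """Hypothetical helper: choose an index along the Q axis."""
    if q is None:
        # Prefer the median: nearest to 0.5 when quantiles are known,
        # otherwise the middle index.
        if quantiles is not None:
            return int(np.argmin(np.abs(np.asarray(quantiles, dtype=float) - 0.5)))
        return n_q // 2
    if isinstance(q, (int, np.integer)) and not isinstance(q, bool):
        return int(q)  # integer: direct quantile index
    if quantiles is not None:
        # float + known quantiles: nearest quantile wins
        return int(np.argmin(np.abs(np.asarray(quantiles, dtype=float) - float(q))))
    # Assumption: a bare fraction maps linearly onto the Q axis.
    return int(round(float(q) * (n_q - 1)))


y_pred = np.random.default_rng(0).normal(size=(4, 3, 3, 1))  # (B, H, Q, 1)
idx = pick_quantile_index(n_q=3, q=0.9, quantiles=(0.1, 0.5, 0.9))
point = y_pred[:, :, idx, :]  # point forecast of shape (B, H, 1)
```

Selecting along axis 2 presumes the output is already in canonical (B, H, Q, O) order.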
- geoprior.utils.scale_metrics.auto_noise_std_from_increments(y_inc, *, noise_frac=0.1, percentile=95.0, min_std=0.0, max_std=None, eps=1e-12)[source]
Compute noise std as a fraction of a robust increment scale.
- Parameters:
y_inc (np.ndarray) – Deterministic increments (before noise). Any shape; will be flattened.
noise_frac (float, default 0.10) – Fraction applied to the percentile scale.
percentile (float, default 95.0) – Percentile of the absolute increments used as the scale.
min_std (float, default 0.0) – Lower bound for returned std.
max_std (float or None, default None) – Optional upper bound for returned std.
eps (float, default 1e-12) – Small positive value for safe fallback.
- Returns:
A finite, non-negative std value.
- Return type:
Notes
Non-finite values are filtered first. If the percentile-based scale is approximately zero, the function falls back to the maximum absolute increment, then the mean absolute increment, and finally to eps.
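The fallback chain in the note can be sketched as below. This is a minimal sketch under stated assumptions, not the library implementation: the name `noise_std_sketch` is hypothetical, "approximately zero" is taken to mean `<= eps`, and the fraction is applied after the scale has been resolved.

```python
import numpy as np


def noise_std_sketch(y_inc, noise_frac=0.1, percentile=95.0,
                     min_std=0.0, max_std=None, eps=1e-12):
    """Illustrative fallback chain for a robust increment scale."""
    a = np.abs(np.asarray(y_inc, dtype=float).ravel())
    a = a[np.isfinite(a)]  # filter NaN / inf first
    scale = 0.0
    if a.size:
        scale = float(np.percentile(a, percentile))
        if scale <= eps:  # fallback 1: maximum absolute increment
            scale = float(a.max())
        if scale <= eps:  # fallback 2: mean absolute increment
            scale = float(a.mean())
    if scale <= eps:      # final fallback: eps itself
        scale = eps
    std = max(noise_frac * scale, min_std)
    if max_std is not None:
        std = min(std, max_std)
    return float(std)


# 99 zero increments and one jump: the 95th percentile is 0,
# so the scale falls back to the maximum absolute increment.
sparse = np.array([0.0] * 99 + [5.0])
std_sparse = noise_std_sketch(sparse)
```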
- geoprior.utils.scale_metrics.resolve_noise_std(y_inc, *, noise_std=None, noise_frac=0.1, percentile=95.0, min_std=0.0, max_std=None)[source]
Prefer explicit noise_std when provided, otherwise auto-compute from increments.
- geoprior.utils.scale_metrics.scale_target(y_phys, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]
Forward-transform physical values into scaled space.
Mirrors inverse_scale_target(…) conventions.
- geoprior.utils.scale_metrics.inverse_scale_target(y_scaled, *, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]
Inverse-transform a scaled target array back to physical units.
Supports three patterns:
Stage-1 scaler_info dict (preferred in NATCOM): pass scaler_info=scaler_info_dict, target_name="subsidence".
A bare scaler instance or a path to a joblib dump via scaler=.... If multi-feature, also pass feature_index and optionally n_features.
Manual scaling parameters via params such as {"min", "max"}, {"mean", "std"}, or {"scale", "shift"}.
- Parameters:
y_scaled (array-like) – Scaled values, e.g. of shape (N, H, 1) or (N,).
scaler_info (mapping, optional)
target_name (str, optional)
scaler_entry (mapping, optional)
feature_index (int, optional)
n_features (int, optional)
params (mapping, optional) – Manual parameters: {"min", "max"} → MinMax scaling; {"mean", "std"} → standardization; {"scale", "shift"} → x = scale * x_scaled + shift.
- Returns:
Array with same shape as y_scaled but in physical units when scaling information is available. If nothing usable is found, returns the input as a NumPy array.
- Return type:
np.ndarray
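The three manual `params` conventions can be illustrated with plain NumPy. This is a sketch, not the library's code: the function name `inverse_scale_sketch` is hypothetical, and the MinMax branch assumes the forward transform mapped values onto [0, 1].

```python
import numpy as np


def inverse_scale_sketch(y_scaled, params):
    """Illustrative inverse for one of three manual scalings."""
    y = np.asarray(y_scaled, dtype=float)
    keys = set(params)
    if {"min", "max"} <= keys:
        # Assumes the forward MinMax transform targeted [0, 1].
        return y * (params["max"] - params["min"]) + params["min"]
    if {"mean", "std"} <= keys:
        return y * params["std"] + params["mean"]
    if {"scale", "shift"} <= keys:
        return params["scale"] * y + params["shift"]
    return y  # nothing usable: return the input as a NumPy array


y_scaled = np.array([[[0.0], [0.5], [1.0]]])  # shape (N=1, H=3, 1)
y_phys = inverse_scale_sketch(y_scaled, {"min": -20.0, "max": 0.0})
```

The output keeps the input shape, which matches the documented contract that only the units change.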
- geoprior.utils.scale_metrics.point_metrics(y_true, y_pred, *, use_physical=False, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]
Compute MAE/MSE/R², optionally in physical units.
If use_physical=True, both y_true and y_pred are inverse-transformed via inverse_scale_target before computing metrics.
- Parameters:
y_true (array-like) – Arbitrary shapes (e.g. (N, H, 1)); flattened internally.
y_pred (array-like) – Arbitrary shapes (e.g. (N, H, 1)); flattened internally.
use_physical (bool, default False) – If True, inverse-transform y_true and y_pred to physical units before computing metrics.
scaler_info (optional) – Passed through to inverse_scale_target.
target_name (optional) – Passed through to inverse_scale_target.
scaler_entry (optional) – Passed through to inverse_scale_target.
scaler (optional) – Passed through to inverse_scale_target.
feature_index (optional) – Passed through to inverse_scale_target.
n_features (optional) – Passed through to inverse_scale_target.
params (optional) – Passed through to inverse_scale_target.
- Returns:
{"mae": …, "mse": …, "r2": …}
- Return type:
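For reference, the three metrics on flattened arrays can be written out directly. The sketch below is illustrative rather than the library's implementation; `point_metrics_sketch` is a hypothetical name, and returning NaN for R² when the target has zero variance is an assumption.

```python
import numpy as np


def point_metrics_sketch(y_true, y_pred):
    """Illustrative MAE / MSE / R² on flattened arrays."""
    t = np.asarray(y_true, dtype=float).ravel()
    p = np.asarray(y_pred, dtype=float).ravel()
    err = p - t
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((t - t.mean()) ** 2))
    # Assumption: R² is undefined (NaN) for a constant target.
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return {
        "mae": float(np.mean(np.abs(err))),
        "mse": float(np.mean(err ** 2)),
        "r2": r2,
    }


y_true = np.array([1.0, 2.0, 3.0, 4.0]).reshape(2, 2, 1)
y_pred = np.array([1.5, 2.0, 2.5, 4.0]).reshape(2, 2, 1)
metrics = point_metrics_sketch(y_true, y_pred)
```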
- geoprior.utils.scale_metrics.per_horizon_metrics(y_true, y_pred, *, use_physical=False, scaler_info=None, target_name=None, scaler_entry=None, scaler=None, feature_index=None, n_features=None, params=None)[source]
Per-horizon MAE/R².
- Parameters:
- Returns:
- Return type:
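A per-horizon breakdown applies the same formulas to each horizon step separately. The sketch below assumes arrays laid out as (N, H) or (N, H, 1) and a 1-based horizon key; both the name `per_horizon_sketch` and the returned dictionary layout are assumptions, since the return structure is not described here.

```python
import numpy as np


def per_horizon_sketch(y_true, y_pred):
    """Illustrative MAE and R² for each horizon step."""
    t = np.asarray(y_true, dtype=float)
    p = np.asarray(y_pred, dtype=float)
    # Collapse any trailing output axes so that axis 1 is the horizon.
    t = t.reshape(t.shape[0], t.shape[1], -1)
    p = p.reshape(p.shape[0], p.shape[1], -1)
    mae, r2 = {}, {}
    for h in range(t.shape[1]):
        th, ph = t[:, h].ravel(), p[:, h].ravel()
        ss_res = float(np.sum((ph - th) ** 2))
        ss_tot = float(np.sum((th - th.mean()) ** 2))
        mae[h + 1] = float(np.mean(np.abs(ph - th)))
        r2[h + 1] = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return {"mae": mae, "r2": r2}


y_true = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])  # (N=3, H=2)
y_pred = np.array([[1.0, 12.0], [2.0, 20.0], [3.0, 28.0]])
by_h = per_horizon_sketch(y_true, y_pred)
```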
Sequence-building helpers for temporal model inputs.
- geoprior.utils.sequence_utils.check_sequence_feasibility(df, *, time_col, group_id_cols=None, time_steps=12, forecast_horizon=3, engine='vectorized', mode=None, logger=<built-in function print>, verbose=0, error='warn')[source]
Quick pre-flight feasibility check for sliding-window sequence generation.
Checks whether the input table is long enough, per group, to yield at least one (look-back + horizon) sliding window, without allocating large NumPy tensors. It is typically called immediately before prepare_pinn_data_sequences() or similar generators to "fail fast" on data shortages.
- Parameters:
df (pandas.DataFrame) – Tidy time-series table in long format. Every row represents one observation timestamp (and optionally one entity when group_id_cols is given). The function never mutates df.
time_col (str) – Column that defines temporal order inside each trajectory. Must be sortable; no other assumptions (numeric, datetime, …) are made.
group_id_cols (list of str or None, default None) – Column names that jointly identify independent trajectories (e.g. ["well_id"] or ["site", "layer_id"]). When None, the whole DataFrame is treated as a single group.
time_steps (int, default 12) – Look-back window \(T_\text{past}\) consumed by the encoder.
forecast_horizon (int, default 3) – Prediction horizon \(H\) produced by the decoder.
engine ({'vectorized', 'native', 'pyarrow'}, default 'vectorized') –
'vectorized' – fastest; single DataFrame.groupby.size() call (C-level) plus NumPy math.
'native' – reproduces the original Python loop for debuggability.
'pyarrow' – forces pandas' Arrow backend, then runs the same vectorised logic; ~20 % faster on very wide frames when pyarrow ≥ 14 is installed.
mode ({'pihal_like', 'tft_like'} or None, optional) – Present only for API symmetry. Ignored: feasibility depends solely on time_steps + forecast_horizon.
logger (callable, default print()) – Sink for human-readable log messages. Must accept a single str.
verbose (int, default 0) – Verbosity level: 0 → silent, 1 → summary lines, 2 → per-group detail.
error ({'raise', 'warn', 'ignore'}, default 'warn') – Action when no group is long enough.
'raise' – raise SequenceGeneratorError.
'warn' – emit UserWarning, return False.
'ignore' – stay silent, return False.
- Returns:
- Raises:
SequenceGeneratorError – Raised only when error='raise' and all groups fail the length check.
- Return type:
Notes
A group passes the check iff
\[\text{len(group)} \;\ge\; T_\text{past} + H \tag{33}\]
No validation of time-gaps, duplicates, or NaNs is performed; those are deferred to the full preparation routine.
The Arrow backend (engine='pyarrow') can accelerate very wide frames because each column is represented as a contiguous Arrow array with cheap zero-copy slicing.
Examples
Minimal usage
>>> from geoprior.utils.sequence_utils import check_sequence_feasibility
>>> ok, counts = check_sequence_feasibility(
...     df,
...     time_col="date",
...     group_id_cols=["site"],
...     time_steps=6,
...     forecast_horizon=3,
... )
>>> ok
True
>>> counts
{'A': 9, 'B': 9}
Fail-fast behaviour
>>> check_sequence_feasibility(
...     df_small,
...     time_col="t",
...     time_steps=10,
...     forecast_horizon=5,
...     error="raise",
... )
Traceback (most recent call last):
    ...
SequenceGeneratorError: No group is long enough ...
Switching engines
>>> _, _ = check_sequence_feasibility(
...     df,
...     time_col="ts",
...     group_id_cols=None,
...     engine="pyarrow",  # requires pandas 2.1+, pyarrow installed
...     verbose=1,
... )
✅ Feasible: 1 234 567 sequences possible.
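The length condition behind this check needs only a group-size count. The following standalone pandas sketch shows the idea; `feasible_sketch` is a hypothetical name, and it returns raw group lengths rather than whatever structure the real function reports.

```python
import pandas as pd


def feasible_sketch(df, group_id_cols=None, time_steps=12, forecast_horizon=3):
    """Illustrative length check; no window tensors are allocated."""
    min_len = time_steps + forecast_horizon
    if group_id_cols:
        lengths = df.groupby(group_id_cols).size()
    else:
        # No grouping columns: treat the whole frame as one group.
        lengths = pd.Series({"<all>": len(df)})
    # Feasible as soon as at least one group is long enough.
    return bool((lengths >= min_len).any()), lengths.to_dict()


df = pd.DataFrame({
    "site": ["A"] * 9 + ["B"] * 5,
    "date": list(range(9)) + list(range(5)),
})
ok, lengths = feasible_sketch(df, ["site"], time_steps=6, forecast_horizon=3)
```

Here site A has exactly 6 + 3 = 9 rows and passes, while site B with 5 rows does not; the overall check still succeeds because one group is long enough.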
References
McKinney, W. pandas 2.0 User Guide, sec. "GroupBy: split-apply-combine".
Arrow Project. (2025). Arrow Columnar Memory Format v2.
- geoprior.utils.sequence_utils.get_sequence_counts(df, *, group_id_cols, min_len, engine='vectorized', verbose=0, logger=<built-in function print>)[source]
Return the total number of feasible sliding-window sequences and a mapping group → count using the requested execution engine.
- Parameters:
engine ({'vectorized', 'native', 'pyarrow'}, default 'vectorized') – Execution backend.
'vectorized' – fast C-level DataFrame.groupby.size() (recommended).
'native' – original Python loop (easier to debug, slower).
'pyarrow' – forces pandas' Arrow backend if available, then runs the vectorised path. Falls back silently to 'vectorized' when pyarrow is not installed.
df (DataFrame)
min_len (int)
verbose (int)
- Return type:
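One way to turn group lengths into window counts is shown below. It assumes stride-1 windows, so a group of length n yields max(n - min_len + 1, 0) sequences; that formula and the name `sequence_counts_sketch` are assumptions made for illustration, not a statement of how the library counts.

```python
import pandas as pd


def sequence_counts_sketch(df, group_id_cols, min_len):
    """Illustrative total and per-group window counts (stride 1)."""
    lengths = df.groupby(group_id_cols).size()
    # Assumption: one window per valid start index, never negative.
    counts = (lengths - min_len + 1).clip(lower=0)
    return int(counts.sum()), counts.to_dict()


df = pd.DataFrame({"site": ["A"] * 9 + ["B"] * 5})
total, per_group = sequence_counts_sketch(df, ["site"], min_len=4)
```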
- geoprior.utils.sequence_utils.generate_pinn_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, mode='pihal_like', normalize_coords=True, cols_to_scale=None, method='rolling', stride=1, random_samples=None, expand_step=1, n_bootstrap=0, progress_hook=None, stop_check=None, verbose=1, _logger=None, **kwargs)[source]
Generate input/target arrays for PINN models using various sampling methods (rolling, strided, random, expanding, bootstrap).
- Parameters:
df (pd.DataFrame) – Full time-series data.
time_col (str) – Name of the time coordinate column.
subsidence_col (str) – Name of the subsidence target column.
gwl_col (str) – Name of the groundwater level target column.
dynamic_cols (list[str]) – Names of past-covariate columns.
static_cols (list[str], optional) – Names of static feature columns.
future_cols (list[str], optional) – Names of known-future feature columns.
spatial_cols ((str, str), optional) – Tuple of (lon_col, lat_col) for spatial coords.
group_id_cols (list[str], optional) – Column(s) identifying independent time-series groups.
time_steps (int, default 12) – Look-back window length T.
forecast_horizon (int, default 3) – Prediction horizon H.
output_subsidence_dim (int, default 1) – Last-dim of subsidence target.
output_gwl_dim (int, default 1) – Last-dim of GWL target.
mode ({'pihal_like', 'tft_like'}, default 'pihal_like') – Shapes the "future" window length for TFT vs. PIHALNet.
normalize_coords (bool, default True) – Apply MinMax scaling to (t, x, y) across all sequences.
cols_to_scale (list[str] or 'auto' or None) – Additional columns to scale via MinMax.
method ({'rolling', 'strided', 'random', 'expanding', 'bootstrap'}) – Sequence-generation strategy.
stride (int, default 1) – Step size for 'strided' sampling.
random_samples (int, optional) – Number of random start indices for 'random' sampling.
expand_step (int, default 1) – Increment size for 'expanding' sampling.
n_bootstrap (int, default 0) – Number of blocks for 'bootstrap' sampling.
progress_hook (callable, optional) – Called with float in [0, 1] to report overall progress.
stop_check (callable, optional) – If it returns True, aborts sequence generation early.
verbose (int, default 1) – Verbosity level (higher = more logs).
_logger (logging.Logger or callable, optional) – Logger or print-style function for vlog().
**kwargs – Passed to helper.
- Returns:
inputs (dict[str, np.ndarray]) – Contains 'coords', 'dynamic_features', optionally 'static_features' and 'future_features'.
targets (dict[str, np.ndarray]) – Contains 'subsidence' and 'gwl' arrays.
coord_scaler (MinMaxScaler or None) – Fitted scaler for coords, if normalization was applied.
- Return type:
tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
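To make the sampling strategies concrete, the sketch below computes window start indices for three of the listed methods on a single series. It is an illustration under assumptions: `window_starts` is a hypothetical helper, 'random' is taken to mean sampling start indices without replacement, and 'expanding' and 'bootstrap' are left out because their exact semantics are not spelled out here.

```python
import numpy as np


def window_starts(n_rows, time_steps, horizon, method="rolling",
                  stride=1, random_samples=None, seed=0):
    """Illustrative start indices for (T + H)-long windows."""
    last = n_rows - (time_steps + horizon)  # last valid start index
    if last < 0:
        return np.empty(0, dtype=int)
    if method == "rolling":
        return np.arange(0, last + 1)
    if method == "strided":
        return np.arange(0, last + 1, stride)
    if method == "random":
        rng = np.random.default_rng(seed)
        k = min(random_samples or (last + 1), last + 1)
        # Assumption: distinct start indices, returned in time order.
        return np.sort(rng.choice(last + 1, size=k, replace=False))
    raise ValueError(f"unsupported method: {method!r}")


series = np.arange(20.0)
T, H = 12, 3
starts = window_starts(len(series), T, H, method="strided", stride=2)
past = np.stack([series[s:s + T] for s in starts])            # (n, T)
future = np.stack([series[s + T:s + T + H] for s in starts])  # (n, H)
```

With 20 rows and T + H = 15 there are six rolling windows; a stride of 2 keeps every other one.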
- geoprior.utils.sequence_utils.generate_ts_sequences(df, time_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, group_id_cols=None, time_steps=12, forecast_horizon=1, normalize_coords=True, cols_to_scale=None, method='rolling', stride=1, random_samples=None, expand_step=1, n_bootstrap=0, progress_hook=None, stop_check=None, verbose=1, _logger=None, **kwargs)[source]
Generate time-series windows for encoder/decoder and covariates. Supports rolling, strided, random, expanding, and bootstrap.
- Parameters:
df (pd.DataFrame) – Input frame with time and feature columns.
time_col (str) – Name of the time coordinate column.
dynamic_cols (list[str]) – Past-covariate columns for encoder inputs.
static_cols (list[str] or None) – Static covariate columns, repeated per window.
future_cols (list[str] or None) – Known-future covariates for decoder inputs.
spatial_cols (tuple(str, str) or None) – (lon, lat) column names for spatial coords.
group_id_cols (list[str] or None) – Columns to group by for independent series.
time_steps (int) – Number of past steps (T) per window.
forecast_horizon (int) – Number of future steps (H) per window.
normalize_coords (bool) – If True, MinMax-scale spatial coords.
cols_to_scale (list[str] or 'auto' or None) – Other columns to MinMax-scale.
method (str) – 'rolling', 'strided', 'random', 'expanding', 'bootstrap'.
stride (int) – Step size for 'strided' windows.
random_samples (int or None) – Number of random windows if method='random'.
expand_step (int) – Increment for 'expanding' windows.
n_bootstrap (int) – Number of bootstrap samples if method='bootstrap'.
progress_hook (callable or None) – Receives float in [0, 1] as work progresses.
stop_check (callable or None) – If it returns True, aborts generation.
verbose (int) – Verbosity level. >0 logs progress.
_logger (callable or None) – Logger to use for messages.
- Returns:
- Raises:
SequenceGeneratorError – If no valid windows could be generated.
- Return type:
tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
- geoprior.utils.sequence_utils.build_future_sequences_npz(df_scaled, *, time_col, time_col_num, lon_col, lat_col, time_steps, train_end_time=None, forecast_start_time=None, forecast_horizon=None, subs_col=None, gwl_col=None, h_field_col=None, static_features=None, dynamic_features=None, future_features=None, group_id_cols=None, mode=None, model_name=None, artifacts_dir=None, prefix='future', future_mode='auto', normalize_coords=False, coord_scaler=None, verbose=1, logger=None, stop_check=None, progress_hook=None, **kws)[source]
Build history–future sequences and save them as compressed NPZ files.
This helper constructs, for each spatial group, a sliding window of time_steps “history” points followed by a multi–step forecast horizon and exports the resulting NumPy arrays to disk. It is time-agnostic: the time_col can be numeric (e.g. year, index), year-like floats, datetimes, or strings, as long as equality on that column is meaningful.
If train_end_time, forecast_start_time, or forecast_horizon are not provided, they are inferred from the sorted unique values in df_scaled[time_col]:
train_end_time: by default the second-to-last unique time, leaving at least one future step.
forecast_start_time: by default the first time strictly after train_end_time.
forecast_horizon: by default one time step ahead, clipped to the number of available future points.
For each valid group, the function builds:
history dynamic features of shape (time_steps, n_dynamic);
future features of shape (time_steps + H, n_future) when mode starts with "tft", or (H, n_future) otherwise;
one static feature vector of shape (n_static,);
coordinates over the horizon of shape (H, 3) with columns [time_num, lon, lat];
an H_field array of shape (H, 1);
optional subsidence and groundwater targets of shape (H, 1) each.
All per-group arrays are stacked along a new batch dimension and written as two NPZ files:
<prefix>_inputs.npz: coordinates, dynamic, static, future features and H field.
<prefix>_targets.npz: subsidence and groundwater targets.
- Parameters:
df_scaled (pandas.DataFrame) – Pre-processed (typically scaled) dataframe containing all required columns: time, spatial coordinates, static/dynamic/future features and optional targets.
time_col (str) – Name of the column encoding the temporal index (e.g. "year", "date", "t_index"). May be numeric, datetime, or string.
time_col_num (str or None) – Optional numeric time column used as a tie-breaker when multiple rows share the same time_col value. If provided and present in a group, the last row sorted by this column is selected for that time.
lon_col (str) – Name of the longitude (or x-coordinate) column.
lat_col (str) – Name of the latitude (or y-coordinate) column.
time_steps (int) – Length of the history window (number of time steps in the past). Must be strictly positive.
train_end_time (object, optional) – Effective end of the training period. If None, it is inferred as the second-to-last unique value in df_scaled[time_col] (after sorting).
forecast_start_time (object, optional) – First time step of the forecast horizon. If None, it is inferred as the first unique time strictly greater than train_end_time.
forecast_horizon (int, optional) – Number of future time steps to include. If None, a default horizon of 1 is used and clipped to the maximum number of available future time points.
subs_col (str, optional) – Name of the subsidence target column. If None or missing from a group, subsidence targets are filled with NaN.
gwl_col (str, optional) – Name of the groundwater-level target column. If None or missing from a group, groundwater targets are filled with NaN.
h_field_col (str, optional) – Name of the hydraulic-head field column used as an additional horizon-level input (H_field). If None or missing, a zero field is used.
static_features (list of str, optional) – Names of static (time-invariant) feature columns. Any names not present in the dataframe are silently ignored.
dynamic_features (list of str, optional) – Names of dynamic (history) feature columns used to build the (time_steps, n_dynamic) sequence. Missing columns are ignored.
future_features (list of str, optional) – Names of future covariate columns used to build the history+future or future-only sequence, depending on mode. Missing columns are ignored.
group_id_cols (list of str, optional) – Columns used to define spatial (or logical) groups, typically something like ["lon", "lat"] or a station identifier. If None or empty, the entire dataframe is treated as a single global group.
mode (str, optional) – Controls how future features are constructed. If the lower-cased value starts with "tft" (e.g. "tft_like"), future features are built on top of both history and future rows. Otherwise, only the forecast horizon rows are used.
model_name (str, optional) – Optional model identifier used only in logging messages.
artifacts_dir (str, optional) – Directory where NPZ files are written. If None or empty, the current working directory is used.
prefix (str, default "future") – Prefix for the output NPZ filenames: "<prefix>_inputs.npz" and "<prefix>_targets.npz".
future_mode ({'auto', 'pure-inference', 'pure-data-driven'}, default 'auto') – Strategy used to construct the future (forecast) portion of the sequences.
'pure-data-driven': Use only time points that actually exist in df_scaled strictly after the history window. All future time indices must be present in the data; otherwise a ValueError is raised. This corresponds to the original, strictly data-driven behaviour.
'pure-inference': Always synthesize future time points from the last history time, using the median positive time step (or 1.0 as a fallback). Future inputs are built by re-using the last available history row (for future_features, H_field, etc.), and future targets (e.g. subsidence, GWL) are filled with NaN since the true future is unknown. This mode does not require any rows beyond train_end_time.
'auto': Try data-driven mode first. If there are enough actual future time points after train_end_time to cover the requested forecast_horizon, behave like 'pure-data-driven'. If not, automatically fall back to the synthetic 'pure-inference' behaviour described above and emit an informational log message via vlog.
verbose (int, default 1) – Verbosity level forwarded to geoprior.utils.vlog(). A value >= 3 provides detailed progress logs (temporal inference, per-group status, dropped groups, etc.).
logger (logging.Logger or callable, optional) – Optional logger or logging function used by geoprior.utils.vlog(). If None, messages are printed to standard output.
**kws – Reserved for future extensions. Currently ignored.
normalize_coords (bool)
coord_scaler (Any | None)
- Returns:
A small dictionary with the absolute paths to the written NPZ files: {"future_inputs_npz": <path>, "future_targets_npz": <path>}.
- Return type:
- Raises:
ValueError – If there are not enough history points before train_end_time to satisfy time_steps, if no future points are available after forecast_start_time, or if all groups are dropped due to incomplete history/horizon windows.
Notes
Groups that do not contain all required history and future times are dropped without raising an error; the number of dropped groups is reported via geoprior.utils.vlog() when verbose > 0.
Examples
>>> from geoprior.utils.sequence_utils import (
...     build_future_sequences_npz,
... )
>>> result = build_future_sequences_npz(
...     df_scaled=df_scaled,
...     time_col="year",
...     time_col_num="t_index",
...     lon_col="lon",
...     lat_col="lat",
...     time_steps=5,
...     # Let the function infer times/horizon:
...     train_end_time=None,
...     forecast_start_time=None,
...     forecast_horizon=None,
...     subs_col="subsidence",
...     gwl_col="gwl",
...     h_field_col="H_field",
...     static_features=["lithology_class"],
...     dynamic_features=["rainfall_mm", "GWL_depth_bgs_z"],
...     future_features=["normalized_urban_load_proxy"],
...     group_id_cols=["lon", "lat"],
...     mode="tft_like",
...     model_name="GeoPriorSubsNet",
...     artifacts_dir="results/zhongshan/future_npz",
...     prefix="zhongshan_future",
...     verbose=2,
... )
>>> result["future_inputs_npz"]
'results/zhongshan/future_npz/zhongshan_future_inputs.npz'
>>> result["future_targets_npz"]
'results/zhongshan/future_npz/zhongshan_future_targets.npz'
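The default temporal inference and the synthetic future times used by 'pure-inference' can be sketched with NumPy. Both helpers below are hypothetical illustrations, not the library's code: they assume numeric, sortable time values and at least two unique times, and they skip the error handling the real function documents.

```python
import numpy as np


def infer_times_sketch(times, train_end_time=None,
                       forecast_start_time=None, forecast_horizon=None):
    """Illustrative defaults taken from sorted unique time values."""
    uniq = np.unique(np.asarray(times))  # np.unique returns sorted values
    if train_end_time is None:
        train_end_time = uniq[-2]        # second-to-last unique time
    if forecast_start_time is None:
        # first unique time strictly after train_end_time
        forecast_start_time = uniq[uniq > train_end_time][0]
    if forecast_horizon is None:
        available = int(np.sum(uniq >= forecast_start_time))
        forecast_horizon = min(1, available)  # default 1, clipped
    return train_end_time, forecast_start_time, forecast_horizon


def synth_future_times(history_times, horizon):
    """Illustrative 'pure-inference' time axis."""
    t = np.sort(np.asarray(history_times, dtype=float))
    steps = np.diff(t)
    steps = steps[steps > 0]
    # Median positive step, with 1.0 as the documented fallback.
    dt = float(np.median(steps)) if steps.size else 1.0
    return t[-1] + dt * np.arange(1, horizon + 1)


end, start, horizon = infer_times_sketch(range(2015, 2023))
future = synth_future_times([2019, 2020, 2021], horizon=3)
```

For yearly data from 2015 to 2022 this gives a training end of 2021, a forecast start of 2022, and a horizon of one step.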
Shape utility helpers for arrays and tensors.
- geoprior.utils.shapes.canonicalize_BHQO(y_pred, *, y_true=None, q_values=(0.1, 0.5, 0.9), n_q=None, layout=None, enforce_monotone=True, return_layout=False, verbose=0, log_fn=<built-in function print>)[source]
Canonicalize quantile outputs to (B, H, Q, O).
- Supported layouts (rank-4):
BHQO: (B, H, Q, O) -> unchanged
BQHO: (B, Q, H, O) -> transpose(0, 2, 1, 3)
BHOQ: (B, H, O, Q) -> transpose(0, 1, 3, 2)
If ambiguous (e.g., H == Q), and y_true is given, pick the transform with smallest MAE for q50.
- If y_true is not given, fallback is:
use layout if provided
else prefer BHQO if plausible
else pick by min crossing score
- Parameters:
y_pred (Any) – Quantile tensor, NumPy array or TF tensor.
y_true (Any | None) – Target tensor (B, H, O) or (B, H, 1). Used only to resolve ambiguity robustly.
q_values (Sequence[float]) – Quantiles in order, e.g. (0.1, 0.5, 0.9).
n_q (int | None) – Number of quantiles. Defaults to len(q_values).
layout (str | None) – Force interpretation: “BHQO”, “BQHO”, “BHOQ”. Use “auto” (or None) to infer.
enforce_monotone (bool) – Sort along Q axis after canonicalization.
return_layout (bool) – If True, return (arr, chosen_layout).
verbose (int) – Logging controls.
- Returns:
Canonical (B, H, Q, O) and optionally the layout.
- Return type:
arr or (arr, layout)
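When the layout is known, canonicalization reduces to an axis permutation followed by an optional sort along Q. The sketch below uses the three transposes listed above; `to_bhqo_sketch` is a hypothetical name, and layout inference is intentionally left out.

```python
import numpy as np

# Axis permutations that map each supported layout onto (B, H, Q, O).
_PERM = {
    "BHQO": (0, 1, 2, 3),
    "BQHO": (0, 2, 1, 3),
    "BHOQ": (0, 1, 3, 2),
}


def to_bhqo_sketch(y, layout, enforce_monotone=True):
    """Illustrative transpose to (B, H, Q, O) for a known layout."""
    out = np.transpose(np.asarray(y), _PERM[layout])
    if enforce_monotone:
        out = np.sort(out, axis=2)  # sort along the Q axis
    return out


y_bqho = np.random.default_rng(0).normal(size=(8, 3, 5, 1))  # (B, Q, H, O)
y_canon = to_bhqo_sketch(y_bqho, "BQHO")
```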
- geoprior.utils.shapes.canonicalize_BHQO_quantiles_np(y, n_q=3, *, verbose=0, log_fn=<built-in function print>)[source]
Return y in canonical (B,H,Q,O).
- Accepts common layouts:
(B,H,Q,O) -> unchanged
(B,Q,H,O) -> transpose(0,2,1,3)
(B,H,O,Q) -> transpose(0,1,3,2)
If ambiguous (multiple axes match n_q), choose the transform with minimal quantile crossing score.
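The tie-breaking idea can be sketched as follows. The exact crossing score used by the library is not given here, so this example assumes one plausible definition: the fraction of adjacent quantile pairs that decrease along the Q axis. The names `crossing_score` and `pick_by_crossing` are hypothetical.

```python
import numpy as np


def crossing_score(y_bhqo):
    """Assumed score: share of adjacent quantile pairs that decrease."""
    d = np.diff(np.asarray(y_bhqo, dtype=float), axis=2)
    return float(np.mean(d < 0)) if d.size else 0.0


def pick_by_crossing(y, n_q=3):
    """Try each supported layout and keep the least-crossing one."""
    perms = {"BHQO": (0, 1, 2, 3), "BQHO": (0, 2, 1, 3), "BHOQ": (0, 1, 3, 2)}
    best = None
    for name, perm in perms.items():
        cand = np.transpose(y, perm)
        if cand.shape[2] != n_q:
            continue  # the Q axis must have n_q entries
        score = crossing_score(cand)
        if best is None or score < best[0]:
            best = (score, name, cand)
    if best is None:
        raise ValueError("no axis matches n_q")
    return best[1], best[2]


# Ambiguous case H == Q == 3: values fall with horizon, rise with quantile.
base = (-10.0 * np.arange(3)).reshape(1, 3, 1, 1) + np.zeros((4, 1, 1, 1))
offsets = np.array([-1.0, 0.0, 1.0]).reshape(1, 1, 3, 1)
y_true_layout = base + offsets                  # (B, H, Q, O), monotone in Q
y_in = np.transpose(y_true_layout, (0, 2, 1, 3))  # stored as (B, Q, H, O)
layout, y_out = pick_by_crossing(y_in, n_q=3)
```

Reading the input as BHQO would treat the horizon axis as quantiles and produce crossings everywhere, so the BQHO interpretation wins.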
Target-processing helpers for GeoPrior workflows.
- geoprior.utils.target_utils.get_output_names(model=None, y=None, y_pred=None, *, exclude_keys={'aux', 'data_final', 'data_mean_raw', 'maps', 'phys_final', 'phys_mean_raw', 'physics'})[source]
Try to obtain stable output names (best-effort, Keras-3-safe).
Lookup priority is:
model._output_keys or model._output_names first, then model.output_names, then keys from y_pred, and finally keys from y.
- geoprior.utils.target_utils.as_tuple(obj, *, names=None, model=None, ctx='value', strict=True, exclude_keys={'aux', 'data_final', 'data_mean_raw', 'maps', 'phys_final', 'phys_mean_raw', 'physics'}, to_numpy='never')[source]
Convert obj to an ordered tuple of outputs.
- Supports:
dict: ordered by names (or inferred)
list/tuple: returns tuple(obj)
tensor/ndarray/scalar: returns (obj,)
- Parameters:
obj (Any) – Targets/predictions container.
names (list[str] or None) – Desired order of outputs.
model (Any or None) – Model used to infer output names (via _output_keys/output_names).
strict (bool) – If True, missing keys in dict raises KeyError.
to_numpy ({"never", "auto", "always"}) – Optional conversion of individual leaves to numpy (only safe in eager).
- Returns:
Ordered outputs.
- Return type:
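The three supported container types map onto a small dispatch. The sketch below is illustrative, not the library's implementation: `as_tuple_sketch` is a hypothetical name, it falls back to dictionary insertion order when no names are given, and it omits model-based name inference and NumPy conversion.

```python
def as_tuple_sketch(obj, names=None, strict=True):
    """Illustrative conversion of outputs to an ordered tuple."""
    if isinstance(obj, dict):
        # Assumption: without names, keep the dict's insertion order.
        order = list(names) if names is not None else list(obj)
        missing = [k for k in order if k not in obj]
        if missing and strict:
            raise KeyError(f"missing outputs: {missing}")
        return tuple(obj[k] for k in order if k in obj)
    if isinstance(obj, (list, tuple)):
        return tuple(obj)
    return (obj,)  # tensor / ndarray / scalar


preds = {"gwl": 2.0, "subsidence": 1.0}
ordered = as_tuple_sketch(preds, names=["subsidence", "gwl"])
```

Passing explicit names keeps the output order deterministic regardless of how the dictionary was built.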
- geoprior.utils.target_utils.update_compiled_metrics(model, y_true, y_pred, *, output_names=None, to_numpy='never')[source]
Keras-3-safe compiled metrics updater.
Prefers dict structure (since you compiled with dict loss/metrics).
Ensures deterministic output order via output_names/_output_keys.
Falls back to list/tuple update_state if needed.
Final fallback: manual per-metric update (won’t crash training).
Note: converting to numpy inside train_step is generally NOT safe.
These helpers support inverse target scaling, sequence construction, canonical layout handling, and target-aware postprocessing.
Validation and helper infrastructure#
Essential utilities for data processing and analysis in FusionLab, offering functions for normalization, interpolation, feature selection, outlier removal, and various data manipulation tasks.
Adapted for FusionLab from the original geoprior.utils.base_utils.
- geoprior.utils.base_utils.detect_categorical_columns(data, integer_as_cat=True, float0_as_cat=True, min_unique_values=None, max_unique_values=None, handle_nan=None, return_frame=False, consider_dt_as=None, verbose=0)[source]
Detect categorical columns in a dataset by examining column types and user-defined criteria. Columns with integer type or float values ending with .0 can be categorized as categorical, depending on settings. Also handles user-defined thresholds for minimum and maximum unique values.
\[\forall x \in X,\; x = \lfloor x \rfloor \tag{34}\]
The equation above states that a float column is treated as categorical only when every value \(x\) is integer-valued. This function relies on the helpers build_data_if, drop_nan_in, fill_NaN, parameter_validator, and smart_format (excluding those prefixed with _).
- Parameters:
data (DataFrame or array-like) – The input data to analyze. If not a DataFrame, it will be converted internally.
integer_as_cat (bool, optional) – If True, integer-type columns are considered categorical. Default is True.
float0_as_cat (bool, optional) – If True, float columns whose values can be cast to integer without remainder are considered categorical. Default is True.
min_unique_values (int or None, optional) – Minimum number of unique values in a column to qualify as categorical. If None, no minimum check is applied.
max_unique_values (int or 'auto' or None, optional) – Maximum number of unique values allowed for a column to be considered categorical. If 'auto', set the limit to the column's own unique count. If None, no maximum check is applied.
handle_nan (str or None, optional) – Handling method for missing data. Can be 'drop' to remove rows with NaNs, 'fill' to impute them via forward/backward fill, or None for no change.
return_frame (bool, optional) – If True, returns a DataFrame of detected categorical columns; otherwise returns a list of column names. Default is False.
consider_dt_as (str, optional) – Indicates how to handle or convert datetime columns when ops='validate'. Use None to keep datetime columns as-is. Use 'numeric' for timestamp-style conversion, 'float', 'float32' or 'float64' for float conversion, 'int', 'int32' or 'int64' for integer conversion, and 'object' or 'category' to convert them to Python objects such as strings. If conversion fails, behavior follows the configured error policy.
verbose (int, optional) – Verbosity level. If greater than 0, a summary of detected columns is printed.
- Returns:
Either a list of column names or a DataFrame containing the categorical columns, depending on the value of return_frame.
- Return type:
list or DataFrame
Examples
>>> from geoprior.utils.base_utils import detect_categorical_columns
>>> import pandas as pd
>>> df = pd.DataFrame({
...     'A': [1, 2, 3],
...     'B': [1.0, 2.0, 3.0],
...     'C': ['cat', 'dog', 'mouse']
... })
>>> detect_categorical_columns(df)
['A', 'B', 'C']
Notes
This function focuses on flexible treatment of integer and float columns. Combined with verbose settings, it can provide detailed feedback. Using 'drop' or 'fill' for handle_nan helps reduce disruptions caused by missing data. The array-programming background is discussed in Harris et al. [36].
Columns with integer values, or float values ending in .0, can be detected as categorical. This is useful when preparing data for machine learning algorithms that expect categorical inputs, such as decision trees or classification models.
This method uses the helper function build_data_if from geoprior.utils.validator to ensure that the input data is a DataFrame. If the input is not a DataFrame, it creates one, giving column names that start with input_name.
See also
build_data_if – Validates and converts input into a DataFrame if needed.
drop_nan_in – Drops NaN values from a DataFrame along axis=0.
fill_NaN – Fills missing data in a DataFrame using forward and backward fill.
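The integer-valued test from the equation above is easy to express with pandas and NumPy. The sketch below is a simplified illustration, not the library's implementation: `categorical_columns_sketch` is a hypothetical name, NaNs are dropped before testing, every non-numeric column is treated as categorical, and the unique-value thresholds are ignored.

```python
import numpy as np
import pandas as pd


def categorical_columns_sketch(df, integer_as_cat=True, float0_as_cat=True):
    """Illustrative detection based on dtype and the x == floor(x) test."""
    cols = []
    for name in df.columns:
        s = df[name].dropna()
        if pd.api.types.is_integer_dtype(s):
            if integer_as_cat:
                cols.append(name)
        elif pd.api.types.is_float_dtype(s):
            # Float column qualifies only if every value is integer-valued.
            if float0_as_cat and bool((s == np.floor(s)).all()):
                cols.append(name)
        else:
            # Assumption: anything non-numeric counts as categorical.
            cols.append(name)
    return cols


df = pd.DataFrame({
    "A": [1, 2, 3],
    "B": [1.0, 2.0, 3.0],
    "C": ["cat", "dog", "mouse"],
    "D": [1.5, 2.0, 3.0],
})
detected = categorical_columns_sketch(df)
```

Column D is excluded because 1.5 fails the integer-valued test.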
- geoprior.utils.base_utils.extract_target(data, target_names, drop=True, columns=None, return_y_X=False)[source]
Extracts the specified target column(s) from a multidimensional NumPy array or pandas DataFrame, with options to rename DataFrame columns and to control whether the extracted columns are dropped from the original data.
- Parameters:
data (Union[np.ndarray, pd.DataFrame]) – The input data from which target columns are to be extracted. Can be a NumPy array or a pandas DataFrame.
target_names (Union[str, int, List[Union[str, int]]]) – The name(s) or integer index/indices of the column(s) to extract. If data is a DataFrame, this can be a mix of column names and indices. If data is a NumPy array, only integer indices are allowed.
drop (bool, default True) – If True, the extracted columns are removed from the original data. If False, the original data remains unchanged.
columns (Optional[List[str]], default None) – If provided and data is a DataFrame, specifies new names for the columns in data. The length of columns must match the number of columns in data. This parameter is ignored if data is a NumPy array.
return_y_X (bool, default False) – If True, returns a tuple (y, X) where X is the data with the target columns removed and y is the target columns. If False, returns only y.
- Returns:
If return_y_X is True, returns a tuple (y, X) where y is the target columns and X is the data with the target columns removed. If return_y_X is False, returns only y.
- Return type:
Union[ArrayLike, pd.Series, pd.DataFrame, Tuple[pd.DataFrame, ArrayLike]]
- Raises:
ValueError – If columns is provided and its length does not match the number of columns in data; if any of the specified target_names do not exist in data; or if target_names includes a mix of strings and integers for a NumPy array input.
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.base_utils import extract_target
>>> df = pd.DataFrame({
...     'A': [1, 2, 3],
...     'B': [4, 5, 6],
...     'C': [7, 8, 9]
... })
>>> target = extract_target(df, 'B', drop=True, return_y_X=False)
>>> print(target)
0    4
1    5
2    6
Name: B, dtype: int64
>>> target, remaining = extract_target(df, 'B', drop=True, return_y_X=True)
>>> print(target)
0    4
1    5
2    6
Name: B, dtype: int64
>>> print(remaining)
   A  C
0  1  7
1  2  8
2  3  9
>>> arr = np.random.rand(5, 3)
>>> target, modified_arr = extract_target(arr, 2, return_y_X=True)
>>> print(target)
>>> print(modified_arr)
- geoprior.utils.base_utils.fancier_downloader(url, filename, dstpath=None, check_size=False, error='raise', verbose=True)[source]
Download a remote file with a progress bar and optional size verification.
This function downloads a file from the specified url and saves it locally with the given filename. It provides a visual progress bar during the download process and offers an option to verify the downloaded file's size against the expected size to ensure data integrity. Additionally, the function allows for moving the downloaded file to a specified destination directory.
(35) \[|S_{downloaded} - S_{expected}| < \epsilon\]
where \(S_{downloaded}\) is the size of the downloaded file, \(S_{expected}\) is the size specified by the server, and \(\epsilon\) is a small tolerance value.
- Parameters:
url (str) – The URL from which to download the remote file.
filename (str) – The desired name for the local file. This is the name under which the file will be saved after downloading.
dstpath (Optional[str], default None) – The destination directory path where the downloaded file should be saved. If None, the file is saved in the current working directory.
check_size (bool, default False) – Whether to verify the size of the downloaded file against the expected size obtained from the server. This is useful for ensuring the integrity of the downloaded file. When True, the function checks:
(36) \[|S_{downloaded} - S_{expected}| < \epsilon\]
If the size check fails:
- If error='raise', an exception is raised.
- If error='warn', a warning is emitted.
- If error='ignore', the discrepancy is ignored, and the function continues.
error (str, default 'raise') – Specifies how to handle errors during the size verification process.
- 'raise': Raises an exception if the file size does not match.
- 'warn': Emits a warning and continues execution.
- 'ignore': Silently ignores the size discrepancy and proceeds.
verbose (bool, default True) – Controls the verbosity of the function. If True, the function will print informative messages about the download status, including progress updates and success or failure notifications.
- Returns:
Returns None if dstpath is provided and the file is moved to the destination. Otherwise, returns the local filename as a string.
- Return type:
Optional[str]
- Raises:
RuntimeError – If the download fails and error is set to 'raise'.
ValueError – If an invalid value is provided for the error parameter.
Examples
>>> from geoprior.utils.base_utils import fancier_downloader
>>> url = 'https://example.com/data/file.h5'
>>> local_filename = 'file.h5'
>>> # Download to current directory without size check
>>> fancier_downloader(url, local_filename)
>>>
>>> # Download to a specific directory with size verification
>>> fancier_downloader(
...     url,
...     local_filename,
...     dstpath='/path/to/save/',
...     check_size=True,
...     error='warn',
...     verbose=True
... )
>>>
>>> # Handle size mismatch by raising an exception
>>> fancier_downloader(
...     url,
...     local_filename,
...     check_size=True,
...     error='raise'
... )
Notes
Progress Bar: The function uses the tqdm library to display a progress bar during the download. If tqdm is not installed, it falls back to a basic downloader without a progress bar.
Directory Creation: If the specified dstpath does not exist, the function will attempt to create it to ensure the file is saved correctly.
File Integrity: Enabling check_size helps verify that the downloaded file is complete. It does not perform a checksum verification, so corruption that leaves the file size unchanged is not detected.
Progress-reporting patterns and surrounding tooling are described in [37, 38].
See also
requests.get: Function to perform HTTP GET requests.
tqdm: A library for creating progress bars.
os.makedirs: Function to create directories.
geoprior.utils.base_utils.check_file_exists: Utility to check file existence.
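The size check in equations (35) and (36), together with the three error policies, can be sketched in a few lines of standard-library Python. This is an illustration only: `size_matches`, its `eps` default, and the choice of RuntimeError for a mismatch are assumptions, not the library's actual internals.

```python
import warnings


def size_matches(downloaded, expected, eps=1.0, error="raise"):
    """Apply |S_downloaded - S_expected| < eps under an error policy.

    Hypothetical helper; sizes are byte counts, so eps=1.0 means
    the two sizes must be equal.
    """
    if error not in {"raise", "warn", "ignore"}:
        raise ValueError(f"Invalid error policy: {error!r}")

    ok = abs(downloaded - expected) < eps
    if ok or error == "ignore":
        return ok

    message = f"Size mismatch: got {downloaded} bytes, expected {expected}."
    if error == "raise":
        raise RuntimeError(message)  # exception type is an assumption
    warnings.warn(message)
    return ok


print(size_matches(1024, 1024))
print(size_matches(1000, 1024, error="ignore"))
```

Validating the policy string before comparing sizes means a typo in `error` surfaces even when the sizes happen to match.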
- geoprior.utils.base_utils.fillNaN(arr, method='ff')[source]
Fill NaN values in a numpy array, pandas Series, or pandas DataFrame using specified methods for forward filling, backward filling, or both.
- Parameters:
arr (Union[np.ndarray, pd.Series, pd.DataFrame]) – The input data containing NaN values to be filled. This can be a numpy array, pandas Series, or DataFrame expected to contain numeric data.
method (str, optional) – The method used for filling NaN values. Valid options are:
- 'ff': forward fill (default)
- 'bf': backward fill
- 'both': applies both forward and backward fill sequentially
- Returns:
The array with NaN values filled according to the specified method. The return type matches the input type (numpy array, Series, or DataFrame).
- Return type:
Union[np.ndarray,pd.Series,pd.DataFrame]
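To make the three fill modes concrete, the sketch below applies them to a plain Python list using only the standard library. It is not how fillNaN is implemented; `forward_fill` and `fill_values` are hypothetical names, and the 'both' branch assumes forward fill runs first, then backward fill.

```python
import math


def _is_nan(value):
    return isinstance(value, float) and math.isnan(value)


def forward_fill(values):
    """Replace each NaN with the most recent non-NaN value, if any."""
    filled, last = [], None
    for value in values:
        if _is_nan(value):
            # Leading NaNs have nothing to copy from and stay NaN.
            filled.append(value if last is None else last)
        else:
            last = value
            filled.append(value)
    return filled


def fill_values(values, method="ff"):
    if method == "ff":
        return forward_fill(values)
    if method == "bf":
        # Backward fill is forward fill on the reversed sequence.
        return forward_fill(values[::-1])[::-1]
    if method == "both":
        return fill_values(forward_fill(values), "bf")
    raise ValueError(f"Unknown method: {method!r}")


nan = float("nan")
data = [nan, 1.0, nan, 3.0, nan]
print(fill_values(data, "both"))
```

Forward fill alone leaves the leading NaN in place and backward fill alone leaves the trailing one, which is why 'both' exists as a separate option.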
- geoprior.utils.base_utils.select_features(data, features=None, dtypes_inc=None, dtypes_exc=None, coerce=False, columns=None, verify_integrity=False, parse_features=False, include_missing=None, exclude_missing=None, transform=None, regex=None, callable_selector=None, inplace=False, **astype_kwargs)[source]
Selects features from a dataset based on various criteria and returns a new DataFrame.
Conceptually, the selected columns are the subset of the input column set that satisfies the requested feature names, data-type filters, regex patterns, callable selectors, and missing-data conditions.
- Parameters:
data (Union[pd.DataFrame, dict, np.ndarray, list]) – The dataset from which to select features. Can be a pandas DataFrame, a dictionary, a NumPy array, or a list of dictionaries/lists.
features (Optional[Union[List[str], Pattern, Callable[[str], bool]]], default None) – Specific feature names to select. Can also be a regex pattern or a callable that takes a column name and returns True if the column should be selected.
dtypes_inc (Optional[Union[str, List[str]]], default None) – The data type(s) to include in the selection. Possible values are the same as for the pandas include parameter in select_dtypes.
dtypes_exc (Optional[Union[str, List[str]]], default None) – The data type(s) to exclude from the selection. Possible values are the same as for the pandas exclude parameter in select_dtypes.
coerce (bool, default False) – If True, numeric columns are coerced to the appropriate types without selection, ignoring the features, dtypes_inc, and dtypes_exc parameters.
columns (Optional[List[str]], default None) – Column names to use if data is a NumPy array or a list without column names.
verify_integrity (bool, default False) – Verifies the data type integrity and converts data to the correct types if necessary.
parse_features (bool, default False) – Parses string features and converts them to an iterable object (e.g., lists).
include_missing (Optional[bool], default None) – If True, includes only columns with missing values. If False, excludes columns with missing values.
exclude_missing (Optional[bool], default None) – If True, excludes columns with any missing values.
transform (Optional, default None) – Function or dictionary of functions to apply to the selected columns. If a dictionary is provided, keys should correspond to column names.
regex (Optional[Union[str, Pattern]], default None) – Regular expression pattern to select columns.
callable_selector (Optional[Callable[[str], bool]], default None) – A callable that takes a column name and returns True if the column should be selected.
inplace (bool, default False) – If True, modifies the data in place. Otherwise, returns a new DataFrame.
**astype_kwargs (Any) – Additional keyword arguments for pandas.DataFrame.astype.
- Returns:
A new DataFrame with the selected features.
- Return type:
pd.DataFrame
- Raises:
ValueError – If no columns match the selection criteria and coerce is False.
TypeError – If regex is not a string or compiled regex pattern; if callable_selector is not a callable; if transform is not a callable or a dictionary of callables; or if provided parameters are of incorrect types.
Examples
>>> from geoprior.utils.base_utils import select_features
>>> import pandas as pd
>>> import re
>>> import numpy as np
>>> data = {
...     "Color": ['Blue', 'Red', 'Green'],
...     "Name": ['Mary', "Daniel", "Augustine"],
...     "Price ($)": ['200', "300", "100"],
...     "Discount": [20, 30, np.nan]
... }
>>> select_features(data, dtypes_inc='number', verify_integrity=True)
   Price ($)  Discount
0      200.0      20.0
1      300.0      30.0
2      100.0       NaN

>>> select_features(data, features=['Color', 'Price ($)'])
   Color Price ($)
0   Blue       200
1    Red       300
2  Green       100

>>> select_features(
...     data,
...     regex='^Price|Discount$',
...     transform={'Price ($)': lambda x: x / 100}
... )
   Price ($)  Discount
0        2.0        20
1        3.0        30
2        1.0       NaN

>>> select_features(
...     data,
...     callable_selector=lambda col: col.startswith('C')
... )
   Color
0   Blue
1    Red
2  Green
Notes
This function is particularly useful in data preprocessing pipelines where the presence of certain features is critical for later analysis or modeling steps. When using regex patterns, ensure that the pattern accurately reflects the intended column names to avoid unintended matches. The callable provided to callable_selector should accept a single column-name string and return a boolean. Transformation functions should be designed to handle the data types of the selected columns to avoid runtime errors. Related selection and coercion behavior is documented in [39, 40, 41, 42].
See also
validate_feature, pandas.DataFrame.select_dtypes, pandas.DataFrame.astype
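The selection criteria described above (dtype filters, regex patterns, callable selectors, and missing-data conditions) each correspond to a short pandas idiom. The sketch below shows those idioms directly and does not call select_features; the DataFrame and variable names are made up for illustration.

```python
import re

import pandas as pd

df = pd.DataFrame({
    "Color": ["Blue", "Red", "Green"],
    "Price": [200.0, 300.0, 100.0],
    "Discount": [20.0, 30.0, None],
})

# Data-type filter: the same vocabulary as the dtypes_inc parameter.
numeric = df.select_dtypes(include="number")

# Regex filter: keep columns whose name matches the pattern.
pattern = re.compile(r"^Price|Discount$")
by_regex = df[[c for c in df.columns if pattern.search(c)]]

# Callable filter: keep columns for which the predicate returns True.
by_callable = df.loc[:, [c.startswith("C") for c in df.columns]]

# Missing-data filter: drop every column containing at least one NaN.
complete = df.dropna(axis="columns")

print(list(numeric.columns), list(by_regex.columns))
print(list(by_callable.columns), list(complete.columns))
```

Each filter yields a subset of the original columns, which matches the "subset of the input column set" framing in the description.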
- geoprior.utils.base_utils.fill_NaN(arr, method='ff')[source]
Fill NaN values in an array-like structure using specified methods. Handles numeric and non-numeric data separately to preserve data integrity.
- Parameters:
arr (array-like, pandas.DataFrame, or pandas.Series) – The input data structure containing NaN values to be filled.
method (str, default 'ff') – The method to use for filling NaN values. Accepted values:
- Forward fill: 'forward', 'ff', 'fwd'
- Backward fill: 'backward', 'bf', 'bwd'
- Both: 'both', 'ffbf', 'fbwf', 'bff', 'full'
- Returns:
The input data structure with NaN values filled according to the specified method.
- Return type:
array-like, pandas.DataFrame, or pandas.Series
- Raises:
ValueError – If the provided fill method is not recognized.
Notes
Mathematically, the function performs:
(37) \[\text{Filled\_array} = \begin{cases} \text{fillNaN(arr, method)} & \text{if arr is numeric} \\ \text{concat(fillNaN(numeric\_parts, method), non\_numeric\_parts)} & \text{otherwise} \end{cases}\]
This ensures that non-numeric data remains unaltered while NaN values in numeric columns are appropriately filled.
The function preserves the original structure of the input array by utilizing array_preserver. Numeric columns are filled using the specified method, while non-numeric columns remain unchanged.
Examples
>>> from geoprior.utils.base_utils import fill_NaN
>>> import numpy as np
>>> import pandas as pd
>>> df = pd.DataFrame({
...     'A': [1, 2, np.nan, 4],
...     'B': ['x', np.nan, 'y', 'z']
... })
>>> fill_NaN(df, method='ff')
     A  B
0  1.0  x
1  2.0  x
2  2.0  y
3  4.0  z
See also
geoprior.core.array_manager.array_preserver: Preserves and restores array structures.
geoprior.core.array_manager.to_array: Converts input to a pandas-compatible array-like structure.
geoprior.core.checks.is_numeric_dtype: Checks if the array has numeric data types.
geoprior.utils.base_utils.fillNaN: Core function to fill NaN values in numeric data.
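The accepted method aliases can be reduced to one canonical name before any filling happens. The sketch below shows one way to do that; `normalize_fill_method` is a hypothetical helper, and mapping the aliases onto the 'ff', 'bf', and 'both' names used by fillNaN is an assumption about how the two functions relate.

```python
# Alias groups copied from the parameter description above.
_ALIASES = {
    "ff": ("forward", "ff", "fwd"),
    "bf": ("backward", "bf", "bwd"),
    "both": ("both", "ffbf", "fbwf", "bff", "full"),
}


def normalize_fill_method(method):
    """Map any accepted alias to a canonical fill method name."""
    for canonical, names in _ALIASES.items():
        if method in names:
            return canonical
    # Mirrors the documented ValueError for unrecognized methods.
    raise ValueError(f"Unrecognized fill method: {method!r}")


print(normalize_fill_method("fwd"))
print(normalize_fill_method("full"))
```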
- geoprior.utils.base_utils.validate_target_in(df, target, error='raise', verbose=0)[source]
Validate and process the target variable, ensuring it is consistent with the features in the DataFrame.
- Parameters:
df (pandas.DataFrame) – The DataFrame containing the features and possibly the target column.
target (str or pandas.Series or pandas.DataFrame) – The target variable to validate and process.
error ({'raise', 'warn', 'ignore'}, optional) – Behavior to use when target validation fails. Use 'raise' to raise an exception, 'warn' to continue with a warning, or 'ignore' to skip reporting.
verbose (int, optional) – Verbosity level for logging. Use 0 for no output, 1 for basic information, and 2 for detailed information.
- Returns:
target (pandas.Series) – The processed target variable.
df (pandas.DataFrame) – The DataFrame containing the features and target.
Dependency utilities providing functions to handle package installation, checking, and ensuring that optional dependencies are available.
- geoprior.utils.deps_utils.ensure_pkg(name, extra='', error='raise', min_version=None, exception=None, dist_name=None, infer_dist_name=False, auto_install=False, use_conda=False, partial_check=False, condition=None, verbose=False)[source]
Decorator to ensure a Python package is installed before function execution.
If the specified package is not installed, or if its installed version does not meet the minimum version requirement, this decorator can optionally install or upgrade the package automatically using either pip or conda.
- Parameters:
name (str) – The name of the package.
extra (str, optional) – Additional specification for the package, such as version or extras.
error (str, optional) – Error handling strategy if the package is missing: 'raise', 'ignore', or 'warn'.
min_version (str or None, optional) – The minimum required version of the package. If not met, triggers installation.
exception (Exception, optional) – A custom exception to raise if the package is missing and error is 'raise'.
dist_name (str, optional) – The distribution name of the package as known by package managers (e.g., pip). If provided and the module import fails, an additional check based on the distribution name is performed. This parameter is useful for packages where the distribution name differs from the importable module name.
infer_dist_name (bool, optional) – If True, attempt to infer the distribution name for pip installation. Defaults to False.
auto_install (bool, optional) – Whether to automatically install the package if missing. Defaults to False.
use_conda (bool, optional) – Prefer conda over pip for automatic installation. Defaults to False.
partial_check (bool, optional) – If True, checks the existence of the package only if the condition is met. This allows for conditional package checking based on the function's arguments or other criteria. If False, the check is always performed. Defaults to False.
condition (Any, optional) – A condition that determines whether to check for the package's existence. This can be a callable that takes the same arguments as the decorated function and returns a boolean, a specific argument name to check for truthiness, or any other value that will be evaluated as a boolean. If None, the package check is performed unconditionally unless partial_check is False.
verbose (bool, optional) – Enable verbose output during the installation process. Defaults to False.
- Returns:
A decorator that wraps functions to ensure the specified package is installed.
- Return type:
Callable
Examples
>>> from geoprior.utils.deps_utils import ensure_pkg
>>> @ensure_pkg("numpy", auto_install=True)
... def use_numpy():
...     import numpy as np
...     return np.array([1, 2, 3])

>>> @ensure_pkg("pandas", min_version="1.1.0", error="warn", use_conda=True)
... def use_pandas():
...     import pandas as pd
...     return pd.DataFrame([[1, 2], [3, 4]])

>>> @ensure_pkg("matplotlib", partial_check=True, condition=lambda x, y: x > 0)
... def plot_data(x, y):
...     import matplotlib.pyplot as plt
...     plt.plot(x, y)
...     plt.show()

>>> @ensure_pkg("skimage", partial_check=True, condition=(
...     lambda *args, **kwargs: 'method' in kwargs and kwargs['method'] == 'hog')
... )
... def check_package_installed(data, method='hog', **kwargs):
...     extractor_function = None
...     if method == 'hog':
...         from skimage.feature import hog
...         extractor_function = lambda image: hog(image, **kwargs)
...     return extractor_function
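The core mechanism of such a decorator, checking for a module at call time and applying an error policy, can be sketched with the standard library alone. `require_module` below is a hypothetical, much-reduced stand-in: it has no version checks, no auto-install, and no conditional checking, and it is not how ensure_pkg itself is written.

```python
import functools
import importlib.util
import warnings


def require_module(name, error="raise"):
    """Check that `name` is importable before the wrapped function runs."""
    if error not in {"raise", "warn", "ignore"}:
        raise ValueError(f"Invalid error policy: {error!r}")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # find_spec locates a top-level module without importing it.
            if importlib.util.find_spec(name) is None:
                message = f"'{func.__name__}' needs the '{name}' package."
                if error == "raise":
                    raise ImportError(message)
                if error == "warn":
                    warnings.warn(message)
            return func(*args, **kwargs)
        return wrapper
    return decorator


@require_module("json")
def dump(value):
    import json
    return json.dumps(value)


print(dump([1, 2]))
```

Deferring the check to call time means that importing a module full of decorated functions does not fail merely because an optional dependency is absent.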
- geoprior.utils.deps_utils.ensure_pkgs(names, extra='', error='raise', min_versions=None, exception=None, dist_names=None, infer_dist_name=False, auto_install=False, use_conda=False, partial_check=False, condition=None, verbose=False)[source]
Decorator to ensure Python packages are installed before function execution.
If the specified packages are not installed, or if their installed versions do not meet the minimum version requirements, this decorator can optionally install or upgrade the packages automatically using either pip or conda.
- Parameters:
names (str or list of str) – The name(s) of the package(s). Can be a single string with package names separated by commas, or a list of package names.
extra (str, optional) – Additional specification for the package(s), such as version or extras.
error ({'raise', 'ignore', 'warn'}, optional) – Error handling strategy if a package is missing: 'raise', 'ignore', or 'warn'. Defaults to 'raise'.
min_versions (str or list of str, optional) – The minimum required version(s) of the package(s). If not met, triggers installation. Can be a single version string applied to all packages or a list matching the names list.
exception (Exception, optional) – A custom exception to raise if a package is missing and error is 'raise'.
dist_names (str or list of str, optional) – The distribution name(s) of the package(s) as known by package managers (e.g., pip). Useful when the distribution name differs from the importable module name. Can be a single string or a list matching the names list.
infer_dist_name (bool, optional) – If True, attempt to infer the distribution name for pip installation. Defaults to False.
auto_install (bool, optional) – Whether to automatically install missing packages. Defaults to False.
use_conda (bool, optional) – Prefer conda over pip for automatic installation. Defaults to False.
partial_check (bool, optional) – If True, checks the existence of the packages only if the condition is met. Allows for conditional package checking based on the function's arguments or other criteria. If False, the check is always performed. Defaults to False.
condition (Any, optional) – A condition that determines whether to check for the packages' existence. Can be a callable that takes the same arguments as the decorated function and returns a boolean, a specific argument name to check for truthiness, or any other value that will be evaluated as a boolean. If None, the package check is performed unconditionally unless partial_check is False.
verbose (bool, optional) – Enable verbose output during the installation process. Defaults to False.
- Returns:
A decorator that wraps functions to ensure the specified packages are installed.
- Return type:
Callable
Examples
>>> from geoprior.utils.deps_utils import ensure_pkgs
>>> @ensure_pkgs("numpy, pandas", auto_install=True)
... def use_numpy_pandas():
...     import numpy as np
...     import pandas as pd
...     return np.array([1, 2, 3]), pd.DataFrame([[1, 2], [3, 4]])

>>> @ensure_pkgs(["matplotlib", "seaborn"], min_versions=["3.0.0", "0.11.0"])
... def plot_data(x, y):
...     import matplotlib.pyplot as plt
...     import seaborn as sns
...     sns.scatterplot(x=x, y=y)
...     plt.show()

>>> @ensure_pkgs("skimage", partial_check=True, condition=(
...     lambda *args, **kwargs: 'method' in kwargs and kwargs['method'] == 'hog')
... )
... def process_image(data, method='hog', **kwargs):
...     if method == 'hog':
...         from skimage.feature import hog
...         return hog(data, **kwargs)
...     else:
...         # Other processing
...         pass
- geoprior.utils.deps_utils.install_package(name, dist_name=None, infer_dist_name=False, version=None, extra='', use_conda=False, verbose=True)[source]
Install a Python package at runtime, optionally specifying a version constraint or other parameters, using either conda or pip. If conda is unavailable or disabled, pip is used by default. The function includes a check for pre-existing installations, allowing users to skip redundant installs.
- Parameters:
name (str) – Base name of the package to install (e.g., 'requests').
dist_name (str, optional) – Distribution name, if different from the import name. For example, scikit-learn's import name sklearn differs from its distribution name 'scikit-learn'.
infer_dist_name (bool, optional) – If True, calls get_installation_name() to infer the distribution name automatically. Defaults to False.
version (str, optional) – Version string or comparator. Examples include '1.2.0' (interpreted as '>=1.2.0'), '==1.2.0', '<2.0', and '>=1.5.3'. If None, no version constraint is applied.
extra (str, optional) – Additional install specifiers or command-line flags passed to the installation command. For instance, ' --no-cache-dir' or '[extra]'. Default is ''.
use_conda (bool, optional) – If True, attempts installation via conda first. If conda is unavailable or fails, falls back to pip. Defaults to False.
verbose (bool, optional) – If True, prints detailed logs throughout the installation. Defaults to True.
- Returns:
On success, the specified package is installed (or is already present). If the installer fails, raises a RuntimeError.
- Return type:
None
- Raises:
RuntimeError – If the installation cannot be completed using either conda or pip, or if conda is requested but unavailable.
Notes
If the package is already installed, as determined by is_module_installed(), no further action is taken. When using pip, a progress bar is displayed if tqdm is installed. For conda, no progress bar is shown because of console I/O capture limitations.
Conceptually, this function assembles an install spec of the form:
(38) \[\text{install\_str} = \langle \text{name} \rangle + \langle \text{version\_spec} \rangle + \langle \text{extra} \rangle\]
where \(\langle \text{name} \rangle\) is the package name, \(\langle \text{version\_spec} \rangle\) is a version comparator (e.g., >=1.2.0), and \(\langle \text{extra} \rangle\) is any additional flags or arguments.
Examples
>>> from geoprior.utils.deps_utils import install_package
>>> # Install requests with no version constraint, default pip:
>>> install_package('requests', verbose=True)

>>> # Install a specific version via conda (fallback to pip if conda fails):
>>> install_package(
...     'pandas',
...     version='==1.2.0',
...     use_conda=True,
...     verbose=True
... )
See also
is_module_installed: Check whether a Python module or corresponding distribution is already installed.
get_installation_name: Infer a distribution name for the given module name when needed.
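The install-spec assembly in equation (38), including the rule that a bare version such as '1.2.0' is read as '>=1.2.0', can be sketched as follows. `build_install_spec` is a hypothetical helper that follows the formula literally; the real function may assemble its command differently.

```python
def build_install_spec(name, version=None, extra=""):
    """Concatenate name + version_spec + extra, as in equation (38)."""
    version = (version or "").strip()
    # A bare version number carries no comparator, so treat it as a
    # lower bound; explicit comparators such as '==' pass through as-is.
    if version and version[0].isdigit():
        version = ">=" + version
    return f"{name}{version}{extra}"


print(build_install_spec("requests"))
print(build_install_spec("pandas", "==1.2.0"))
print(build_install_spec("requests", "1.2.0", " --no-cache-dir"))
```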
- geoprior.utils.deps_utils.is_installing(module, upgrade=True, action=True, DEVNULL=False, verbose=0, **subpkws)[source]
Install or uninstall a module/package using the subprocess under the hood.
- Parameters:
module (str) – The module or library name to install using pip.
upgrade (bool) – Install the latest version of the package. Default is True.
DEVNULL (bool) – Suppress the standard output messages in the console.
action (str, bool) – Action to perform: 'install' or 'uninstall' a package. Default is True, which means 'install'.
verbose (int, optional) – Controls the verbosity, i.e. how many messages are output. A higher level means more messages. Default is 0.
subpkws (dict) – Additional subprocess keyword arguments.
- Returns:
success – Whether the package is successfully installed or not.
- Return type:
bool
Example
>>> from geoprior.utils.deps_utils import is_installing
>>> is_installing('tqdm', action='install', DEVNULL=True, verbose=1)
>>> is_installing('tqdm', action='uninstall', verbose=1)
- geoprior.utils.deps_utils.get_installation_name(module_name, distribution_name=None, return_bool=False)[source]
Determines the appropriate name for installing a package, considering potential discrepancies between the distribution name and the module import name. Optionally, returns a boolean indicating if the distribution name matches the import name.
- Parameters:
module_name (str) – The import name of the module.
distribution_name (str, optional) – The distribution name of the package. If None, the function attempts to infer the distribution name from the module name.
return_bool (bool, optional) – If True, returns a boolean indicating whether the distribution name matches the module import name. Otherwise, returns the name recommended for installation.
- Returns:
Depending on return_bool, returns either a boolean indicating if the distribution name matches the module name, or the name (distribution or module) recommended for installation.
- Return type:
Union[str,bool]
- geoprior.utils.deps_utils.is_module_installed(module_name, distribution_name=None)[source]
Check if a Python module is installed by attempting to import it. Optionally, a distribution name can be provided if it differs from the module name.
- Parameters:
module_name (str) – The import name of the module to check.
distribution_name (str, optional) – The distribution name of the package as known by package managers (e.g., pip). If provided and the module import fails, an additional check based on the distribution name is performed. This parameter is useful for packages where the distribution name differs from the importable module name.
- Returns:
True if the module can be imported or the distribution package is installed, False otherwise.
- Return type:
bool
Examples
>>> is_module_installed("sklearn")
True
>>> is_module_installed("scikit-learn", "scikit-learn")
True
>>> is_module_installed("some_nonexistent_module")
False
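The two-step check described above, first the import name and then the distribution name, can be approximated with the standard library. The sketch is an assumption-laden stand-in: `module_available` is a hypothetical name, and it uses `importlib.util.find_spec` to locate the module rather than actually importing it as the documented function does.

```python
import importlib.util
from importlib import metadata


def module_available(module_name, distribution_name=None):
    """Return True if the module or its distribution can be found."""
    try:
        if importlib.util.find_spec(module_name) is not None:
            return True
    except (ImportError, ValueError):
        # find_spec can raise for broken parents or odd module states.
        pass

    if distribution_name is not None:
        try:
            # Fall back to installed-distribution metadata.
            metadata.version(distribution_name)
            return True
        except metadata.PackageNotFoundError:
            return False
    return False


print(module_available("json"))
print(module_available("some_nonexistent_module_xyz"))
```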
- geoprior.utils.deps_utils.import_optional_dependency(name, extra='', errors='raise', min_version=None, exception=None)[source]
Import an optional dependency.
By default, if a dependency is missing an ImportError with a nice message will be raised. If a dependency is present, but too old, we raise.
- Parameters:
name (str) – The module name.
extra (str) – Additional text to include in the ImportError message.
errors (str {'raise', 'warn', 'ignore'}) – What to do when a dependency is not found or its version is too old.
- raise: Raise an ImportError.
- warn: Only applicable when a module's version is too old. Warns that the version is too old and returns None.
- ignore: If the module is not installed, return None; otherwise, return the module, even if the version is too old. It's expected that users validate the version locally when using errors="ignore" (see io/html.py).
min_version (str, default None) – Specify a minimum version that is different from the global pandas minimum version required.
exception (callable, BaseException) – Can be your own package exception rather than ImportError.
- Returns:
maybe_module – The imported module, when found and the version is correct. None is returned when the package is not found and errors is 'ignore', or when the package's version is too old and errors is 'warn'.
- Return type:
Optional[ModuleType]
- geoprior.utils.deps_utils.ensure_module_installed(module_name, auto_install=False, version=None, package_manager='pip', dist_name=None, extra_install_args=None)[source]
Ensure that the required module is installed, optionally installing it if missing.
- Parameters:
module_name (str) – The name of the module to check and install if necessary.
auto_install (bool, optional) – If True, automatically install the module using the specified package manager if it is not already installed (default is False).
version (Optional[str], optional) – Specify a version or version range for the module. For example, ">=1.0.0" or "==2.0.1". If None, no version constraints are applied (default is None).
package_manager (str, optional) – The package manager to use for installation. Currently, only "pip" is supported. Future versions may support other package managers like "conda" (default is "pip").
dist_name (Optional[str], optional) – Sometimes the module name used for importing is different from the distribution package name. This parameter allows specifying the distribution package name (default is None).
extra_install_args (Optional[List[str]], optional) – A list of additional arguments to pass to the package manager during installation. For example, ["--upgrade"] to upgrade the package. If None, no extra arguments are passed (default is None).
- Returns:
Returns True if the module is installed or successfully installed, False otherwise.
- Return type:
bool
- Raises:
ImportError – If the module is not installed and auto_install is False, or if the installation fails.
ValueError – If an unsupported package manager is specified.
Examples
>>> from geoprior.utils.deps_utils import ensure_module_installed
>>> # Ensure that 'numpy' is installed
>>> ensure_module_installed("numpy")

>>> # Ensure that 'pandas' is installed, automatically installing if missing
>>> ensure_module_installed("pandas", auto_install=True)

>>> # Ensure that 'scipy' version >=1.5.0 is installed
>>> ensure_module_installed("scipy", version=">=1.5.0", auto_install=True)

>>> # Install with additional arguments
>>> ensure_module_installed(
...     "requests",
...     auto_install=True,
...     extra_install_args=["--upgrade"]
... )
Notes
This function currently supports only "pip" as the package manager. When specifying a version, ensure that the version string is compatible with the package manager's version specification syntax. Packages that require system-level dependencies may still need manual installation steps.
See also
subprocess: For spawning new processes.
sys: System-specific parameters and functions.
- geoprior.utils.deps_utils.get_versions(extras=None, distribution_mapping=None)[source]
Retrieve a dictionary containing version information for common libraries, as well as any user-specified packages and distribution name mappings.
- Parameters:
extras (list of str, optional) – Additional packages for which to attempt version retrieval. By default, None, which means no extra packages beyond the defaults.
distribution_mapping (dict, optional) – Mapping from import-like names to actual distribution names. For example, the import name 'sklearn' corresponds to the distribution name 'scikit-learn'. Default is None, which uses a built-in mapping for scikit-learn and any user-provided dictionary overrides or additions.
- Returns:
Dictionary of the form:
{ "__version__": { "numpy": "1.24.2", "pandas": "1.5.0", "sklearn": "1.3.2", ... } }
- Return type:
Notes
By default, this function attempts to retrieve versions for the following packages:
['numpy', 'pandas', 'sklearn', 'joblib', 'tensorflow', 'keras', 'torch'].
If a package is not installed, it is skipped (no error is raised).
If distribution_mapping is provided, it is merged with the built-in mapping ("sklearn" → "scikit-learn"), allowing users to specify additional name differences.
Python 3.8+ is recommended to ensure importlib.metadata is available.
Examples
>>> get_versions()
{
    "__version__": {
        "numpy": "1.24.2",
        "pandas": "1.5.0",
        ...
    }
}
>>> # Add custom package and distribution mapping:
>>> get_versions(
...     extras=["spacy"],
...     distribution_mapping={"spacy": "spacy-legacy"}
... )
{
    "__version__": {
        "numpy": "1.24.2",
        "pandas": "1.5.0",
        "spacy": "3.5.1"
    }
}
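The lookup described in the Notes can be sketched with importlib.metadata alone. This is a minimal illustration, not the library's implementation: the function name collect_versions and the DEFAULT_MAPPING constant are hypothetical, and only the "sklearn" → "scikit-learn" entry is taken from the text above.

```python
from importlib import metadata

# Hypothetical default for this sketch; only the sklearn entry is documented.
DEFAULT_MAPPING = {"sklearn": "scikit-learn"}


def collect_versions(packages, distribution_mapping=None):
    """Return {"__version__": {name: version}} for the installed packages."""
    # User-provided entries override or extend the built-in mapping.
    mapping = {**DEFAULT_MAPPING, **(distribution_mapping or {})}
    found = {}
    for name in packages:
        dist_name = mapping.get(name, name)
        try:
            found[name] = metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            # Missing packages are skipped rather than raising.
            continue
    return {"__version__": found}


report = collect_versions(["numpy", "surely-not-installed-pkg-xyz"])
```

The result keeps the import-like name as the key, so a caller sees "sklearn" even though the version was read from the "scikit-learn" distribution.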
geoprior.utils.split
Group-holdout split for sequence data.
- Exports:
train_windows_T{T}_H{H}.npz
val_windows_T{T}_H{H}.npz
test_windows_T{T}_H{H}.npz
future_inputs_T{T}_H{H}.npz
splits_groups.json
- Leakage fix (Zhongshan 2 windows/pixel):
split by group_id first, then window inside split.
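The ordering matters because windows cut from the same pixel overlap in time: if windows are built first and split afterwards, one pixel can land in both train and test. A minimal sketch of the group-first idea follows, assuming NumPy only; split_group_ids is a hypothetical helper and does not reproduce split_group_keys, whose rounding and key handling are not documented here.

```python
import numpy as np


def split_group_ids(group_ids, ratios=(0.7, 0.15, 0.15), seed=42):
    """Assign each unique group to exactly one of train/val/test."""
    unique = np.unique(np.asarray(group_ids))
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(unique)
    n_train = int(round(ratios[0] * len(shuffled)))
    n_val = int(round(ratios[1] * len(shuffled)))
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


# Two windows per pixel: both windows of a pixel share one group id.
window_groups = np.repeat(np.arange(20), 2)
splits = split_group_ids(window_groups)

# Windows are selected only after the groups have been assigned.
train_mask = np.isin(window_groups, splits["train"])
```

Because the mask is derived from group membership, both windows of any pixel always fall on the same side of the split.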
- class geoprior.utils.split.SplitCfg(seed: 'int' = 42, ratios: 'tuple[float, float, float]' = (0.7, 0.15, 0.15), decimals: 'int' = 8)[source]
Bases: object
- seed: int = 42
- decimals: int = 8
- geoprior.utils.split.split_group_keys(keys, *, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8))[source]
- geoprior.utils.split.subset_by_keys(df, *, group_cols, keys, decimals=8)[source]
- geoprior.utils.split.write_splits_json(path, *, group_cols, time_steps, horizon, train_end, cfg, splits)[source]
- geoprior.utils.split.pack_xy_npz(x, y)[source]
- geoprior.utils.split.build_group_holdout_npzs(*, df_train, artifacts_dir, group_cols, time_col_used, x_col_used, y_col_used, subs_col, gwl_target_col, gwl_dyn_col, h_field_col, static_cols, dynamic_cols, future_cols, time_steps, horizon, mode, model_name, train_end, keys_ok, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8), normalize_coords=True)[source]
Build train/val/test windows using group holdout.
Returns dict containing paths and coord_scaler.
- Parameters:
df_train (DataFrame)
artifacts_dir (str)
time_col_used (str)
x_col_used (str)
y_col_used (str)
subs_col (str)
gwl_target_col (str)
gwl_dyn_col (str)
h_field_col (str)
time_steps (int)
horizon (int)
mode (str)
model_name (str)
train_end (float | None)
keys_ok (ndarray)
cfg (SplitCfg)
normalize_coords (bool)
- Return type:
- geoprior.utils.split.build_future_inputs_npz(*, df_scaled, artifacts_dir, time_col, time_col_num, lon_col, lat_col, subs_col, gwl_col, h_field_col, static_features, dynamic_features, future_features, group_cols, train_end_time, forecast_start_time, horizon, time_steps, mode, model_name, normalize_coords, coord_scaler=None)[source]
- Parameters:
df_scaled (DataFrame)
artifacts_dir (str)
time_col (str)
time_col_num (str | None)
lon_col (str)
lat_col (str)
subs_col (str)
gwl_col (str)
h_field_col (str)
train_end_time (Any)
forecast_start_time (Any)
horizon (int)
time_steps (int)
mode (str)
model_name (str)
normalize_coords (bool)
coord_scaler (Any)
- Return type:
Provides a comprehensive set of functions and warnings for validating and ensuring the integrity of data. This includes utilities for checking data consistency, validating machine learning targets, ensuring proper data types, and handling various validation scenarios.
- exception geoprior.utils.validator.DataConversionWarning[source]
Bases: UserWarning
Warning used to notify implicit data conversions happening in the code.
This warning occurs when some input data needs to be converted or interpreted in a way that may not match the user’s expectations. For example, this warning may occur when the user:
passes an integer array to a function that expects float input and will convert the input;
requests a non-copying operation, but a copy is required to meet the implementation’s data-type expectations;
passes an input whose shape can be interpreted ambiguously.
Changed in version 0.18: Moved from sklearn.utils.validation.
- exception geoprior.utils.validator.PositiveSpectrumWarning[source]
Bases: UserWarning
Warning raised when the eigenvalues of a PSD matrix have issues.
This warning is typically raised by _check_psd_eigenvalues when the eigenvalues of a positive semidefinite (PSD) matrix such as a gram matrix (kernel) present significant negative eigenvalues, or bad conditioning, i.e. very small non-zero eigenvalues compared to the largest eigenvalue.
Added in version 0.22.
- geoprior.utils.validator.array_to_frame(X, *, to_frame=False, columns=None, raise_exception=False, raise_warning=True, input_name='', force=False)[source]
Validates and optionally converts an array-like object to a pandas DataFrame, applying specified column names if provided or generating them if the force parameter is set.
- Parameters:
X (array-like) – The array to potentially convert to a DataFrame.
columns (str or list of str, optional) – The names for the resulting DataFrame columns or the Series name.
to_frame (bool, default False) – If True, converts X to a DataFrame if it isn’t already one.
input_name (str, default '') – The name of the input variable, used for error and warning messages.
raise_warning (bool, default True) – If True and to_frame is True but columns are not provided, a warning is issued unless force is True.
raise_exception (bool, default False) – If True, raises an exception when to_frame is True but columns are not provided and force is False.
force (bool, default False) – Forces the conversion of X to a DataFrame by generating column names based on input_name if columns are not provided.
- Returns:
The potentially converted DataFrame or Series, or X unchanged.
- Return type:
pd.DataFrame or pd.Series
Examples
>>> from geoprior.utils.validator import array_to_frame
>>> from sklearn.datasets import load_iris
>>> data = load_iris()
>>> X = data.data
>>> array_to_frame(X, to_frame=True, columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
- geoprior.utils.validator.array_to_frame2(X, *, to_frame=False, columns=None, raise_exception=False, raise_warning=True, input_name='', force=False)[source]
Extension of is_frame dedicated to validating the reconversion of X and y back to a frame.
- Parameters:
X (array-like) – Array to convert to a frame.
columns (str or list of str) – Series name or column names for pandas.Series and DataFrame.
to_frame (bool, default False) – If True, reconvert the array to a frame using the columns; otherwise no action is performed and the same array is returned.
input_name (str, default "") – The data name used to construct the error message.
raise_warning (bool, default True) – If True, raise a warning if conversion is required. If "ignore", warnings are silenced.
raise_exception (bool, default False) – If True then raise an exception if array is not symmetric.
force (bool, default False) – Force conversion of the array to a frame if columns is not supplied. Column names are then built from the combination of input_name and range(X.shape[1]).
- Returns:
X
- Return type:
converted array
Example
>>> from geoprior.datasets import fetch_data
>>> from geoprior.utils.validator import array_to_frame
>>> data = fetch_data('hlogs').frame
>>> array_to_frame(data.k.values, to_frame=True, columns=None, input_name='y', raise_warning="silence")
... array([nan, nan, nan, ..., nan, nan, nan]) # mute
- geoprior.utils.validator.assert_all_finite(X, *, allow_nan=False, estimator_name=None, input_name='')[source]
Throw a ValueError if X contains NaN or infinity.
- Parameters:
X ({ndarray, sparse matrix}) – The input data.
allow_nan (bool, default False) – If True, do not throw error when X contains NaN.
estimator_name (str, default None) – The estimator name, used to construct the error message.
input_name (str, default "") – The data name used to construct the error message. In particular if input_name is “X” and the data has NaN values and allow_nan is False, the error message will link to the imputer documentation.
- geoprior.utils.validator.assert_xy_in(x, y, *, data=None, asarray=True, to_frame=False, columns=None, xy_numeric=False, ignore=None, **kws)[source]
Assert the name of x and y in the given data.
Check whether string arguments passed to x and y are valid in the data, then retrieve the x and y array values.
- Parameters:
x (array-like 1d or str) – One-dimensional array. In principle, if data is supplied, it must be a series. If x and y are given as strings, data must be supplied and the names must be columns of the dataframe; otherwise an error is raised.
y (array-like 1d or str) – One-dimensional array. In principle, if data is supplied, it must be a series. If x and y are given as strings, data must be supplied and the names must be columns of the dataframe; otherwise an error is raised.
data (pd.DataFrame) – Data containing x and y names. Needs to be supplied when x and y are given as string names.
asarray (bool, default True) – Returns x and y as arrays rather than series.
to_frame (bool, default False) – Convert data to a dataframe using either the column names or the input_name when the keyword parameter force=True.
columns (list of str, optional) – Names of columns used to transform the array (data) into a dataframe.
xy_numeric (bool, default False) – Convert x and y to numeric values.
ignore (str, optional) – It should be ‘x’ or ‘y’. If set, the array is ignored and not asserted.
kws (dict) – Keyword arguments passed to array_to_frame().
- Returns:
x, y – One dimensional array or pd.Series
- Return type:
Arraylike
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import assert_xy_in
>>> x, y = np.random.rand(7), np.arange(7)
>>> data = pd.DataFrame({'x': x, 'y': y})
>>> assert_xy_in(x='x', y='y', data=data)
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=y)
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=data.y) # y is a series
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 array([0, 1, 2, 3, 4, 5, 6]))
>>> assert_xy_in(x=x, y=data.y, asarray=False) # return y like it was
(array([0.37454012, 0.95071431, 0.73199394, 0.59865848, 0.15601864,
        0.15599452, 0.05808361]),
 0    0
 1    1
 2    2
 3    3
 4    4
 5    5
 6    6
 Name: y, dtype: int32)
- geoprior.utils.validator.build_data_if(data, columns=None, to_frame=True, input_name='data', col_prefix='col_', force=False, error='warn', coerce_datetime=False, coerce_numeric=True, start_incr_at=0, **kw)[source]
Validates and converts data into a pandas DataFrame if requested, optionally enforcing consistent column naming. Intended to standardize data structures for downstream analysis.
See geoprior.utils.data_utils.build_df() for documentation details.
- geoprior.utils.validator.check_X_y(X, y, accept_sparse=False, *, accept_large_sparse=True, dtype='numeric', order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, multi_output=False, ensure_min_samples=1, ensure_min_features=1, y_numeric=False, estimator=None, to_frame=False)[source]
Input validation for standard estimators.
Checks X and y for consistent length, enforces X to be 2D and y 1D. By default, X is checked to be non-empty and containing only finite values. Standard input checks are also applied to y, such as checking that y does not have np.nan or np.inf targets. For multi-label y, set multi_output=True to allow 2D and sparse y. If the dtype of X is object, attempt converting to float, raising on failure.
- Parameters:
X ({ndarray, list, sparse matrix}) – Input data.
y ({ndarray, list, sparse matrix}) – Labels.
accept_sparse (str, bool or list of str, default False) – String[s] representing allowed sparse matrix formats, such as ‘csc’, ‘csr’, etc. If the input is sparse but not in the allowed format, it will be converted to the first listed format. True allows the input to be any format. False means that a sparse matrix input will raise an error.
accept_large_sparse (bool, default True) – If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by accept_sparse, accept_large_sparse=False will cause it to be accepted only if its indices are stored with a 32-bit dtype.
dtype ('numeric', type, list of type or None, default 'numeric') – Data type of result. If None, the dtype of the input is preserved. If “numeric”, dtype is preserved unless array.dtype is object. If dtype is a list of types, conversion on the first type is only performed if the dtype of the input is not in the list.
order ({'F', 'C'}, default None) – Whether an array will be forced to be Fortran or C-style.
copy (bool, default False) – Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion.
force_all_finite (bool or 'allow-nan', default True) – Whether to raise an error on np.inf, np.nan, pd.NA in X. This parameter does not influence whether y can have np.inf, np.nan, pd.NA values. Use True to require all values of X to be finite, False to allow np.inf, np.nan, and pd.NA, or "allow-nan" to allow only np.nan and pd.NA while still rejecting infinite values. pd.NA is accepted and converted into np.nan.
ensure_2d (bool, default True) – Whether to raise a value error if X is not 2D.
allow_nd (bool, default False) – Whether to allow X.ndim > 2.
multi_output (bool, default False) – Whether to allow 2D y (array or sparse matrix). If false, y will be validated as a vector. y cannot have np.nan or np.inf values if multi_output=True.
ensure_min_samples (int, default 1) – Make sure that X has a minimum number of samples in its first axis (rows for a 2D array).
ensure_min_features (int, default 1) – Make sure that the 2D array has some minimum number of features (columns). The default value of 1 rejects empty datasets. This check is only enforced when X has effectively 2 dimensions or is originally 1D and ensure_2d is True. Setting to 0 disables this check.
y_numeric (bool, default False) – Whether to ensure that y has a numeric type. If dtype of y is object, it is converted to float64. Should only be used for regression algorithms.
estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.
- Returns:
- geoprior.utils.validator.check_array(array, *, accept_large_sparse=True, dtype='numeric', accept_sparse=False, order=None, copy=False, force_all_finite=True, ensure_2d=True, allow_nd=False, ensure_min_samples=1, ensure_min_features=1, estimator=None, input_name='', to_frame=True)[source]
Input validation on an array, list, or similar.
By default, the input is checked to be a non-empty 2D array containing only finite values. If the dtype of the array is object, attempt converting to float, raising on failure.
- Parameters:
array (object) – Input object to check / convert.
accept_sparse (str, bool or list/tuple of str, default False) – String[s] representing allowed sparse matrix formats, such as ‘csc’, ‘csr’, etc. If the input is sparse but not in the allowed format, it will be converted to the first listed format. True allows the input to be any format. False means that a sparse matrix input will raise an error.
accept_large_sparse (bool, default True) – If a CSR, CSC, COO or BSR sparse matrix is supplied and accepted by accept_sparse, accept_large_sparse=False will cause it to be accepted only if its indices are stored with a 32-bit dtype.
dtype ('numeric', type, list of type or None, default 'numeric') – Data type of result. If None, the dtype of the input is preserved. If “numeric”, dtype is preserved unless array.dtype is object. If dtype is a list of types, conversion on the first type is only performed if the dtype of the input is not in the list.
order ({'F', 'C'} or None, default None) – Whether an array will be forced to be Fortran or C-style. When order is None (default), then if copy=False, nothing is ensured about the memory layout of the output array; otherwise (copy=True) the memory layout of the returned array is kept as close as possible to the original array.
copy (bool, default False) – Whether a forced copy will be triggered. If copy=False, a copy might be triggered by a conversion.
force_all_finite (bool or 'allow-nan', default True) – Whether to raise an error on np.inf, np.nan, or pd.NA in array. Use True to require all values to be finite, False to allow np.inf, np.nan, and pd.NA, or "allow-nan" to allow only np.nan and pd.NA while still rejecting infinite values. pd.NA is converted into np.nan.
ensure_2d (bool, default True) – Whether to raise a value error if array is not 2D.
ensure_min_samples (int, default 1) – Make sure that the array has a minimum number of samples in its first axis (rows for a 2D array). Setting to 0 disables this check.
ensure_min_features (int, default 1) – Make sure that the 2D array has some minimum number of features (columns). The default value of 1 rejects empty datasets. This check is only enforced when the input data has effectively 2 dimensions or is originally 1D and ensure_2d is True. Setting to 0 disables this check.
estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.
input_name (str, default "") – The data name used to construct the error message. In particular if input_name is “X” and the data has NaN values and allow_nan is False, the error message will link to the imputer documentation.
to_frame (bool, default True) – Reconvert array back to pd.Series or pd.DataFrame if the original array is pd.Series or pd.DataFrame.
- Returns:
array_converted – The converted and validated array.
- Return type:
- geoprior.utils.validator.check_classification_targets(*y, target_type='numeric', strategy='auto', verbose=False)[source]
Validate that the target arrays are suitable for classification tasks.
This function is designed to ensure that target arrays (y) contain only finite, categorical values, and it raises a ValueError if the targets do not meet the criteria necessary for classification tasks, such as the presence of continuous values, NaNs, or infinite values.
This validation is crucial for preprocessing steps in machine learning pipelines to ensure that the data is appropriate for classification algorithms.
- Parameters:
*y (array-like) – One or more target arrays to be validated. The input can be in the form of lists, numpy arrays, or pandas series. Each array is checked individually to ensure it meets the criteria for classification targets.
target_type (str, optional) – The expected data type of the target arrays. Supported values are ‘numeric’ and ‘object’. If ‘numeric’, the function attempts to convert the target arrays to integers, raising an error if conversion is not possible due to non-numeric values. If ‘object’, the target arrays are left as numpy arrays of dtype object, suitable for categorical classification without conversion. Default is ‘numeric’.
strategy (str, optional) – Defines the approach for evaluating if the target arrays are suitable for classification based on their unique values and data types. The ‘auto’ strategy uses heuristic or automatic detection to decide whether target data should be treated as categorical, which is useful for most cases. Custom strategies can be defined to enforce specific validation rules or preprocessing steps based on the nature of the target data (e.g., ‘continuous’, ‘multilabel-indicator’, ‘unknown’). These custom strategies should align with the outcomes of a predefined type_of_target function, allowing for nuanced handling of different target data scenarios. The default value is 'auto', which applies general rules for categorization and numeric conversion where applicable.
If a strategy other than 'auto' is specified, it directly influences how the data is validated and potentially converted, based on the expected or detected type of target data:
If ‘continuous’, the function checks if the data can be used for regression tasks and raises an error for classification use without explicit binning.
If ‘multilabel-indicator’, it validates the data for multilabel classification tasks and ensures appropriate format.
If ‘unknown’, it attempts to validate the data with generic checks, raising errors for any unclear or unsupported data formats.
verbose (bool, optional) – If set to True, the function prints a message for each target array checked, confirming that it is suitable for classification. This is helpful for debugging and when validating multiple target arrays simultaneously.
- Raises:
ValueError – If any of the target arrays contain values unsuitable for classification. This includes arrays with continuous values, NaNs, infinite values, or arrays that do not represent categorical data properly.
Examples
Using the function with a single array of integer labels:
>>> from geoprior.utils.validator import check_classification_targets
>>> y = [1, 2, 3, 2, 1]
>>> check_classification_targets(y)
[array([1, 2, 3, 2, 1], dtype=object)]
Using the function with multiple arrays, including a mix of integer and string labels:
>>> y1 = [0, 1, 0, 1]
>>> y2 = ["spam", "ham", "spam", "ham"]
>>> check_classification_targets(y1, y2, verbose=True)
Targets are suitable for classification.
Targets are suitable for classification.
[array([0, 1, 0, 1], dtype=object), array(['spam', 'ham', 'spam', 'ham'], dtype=object)]
Attempting to use the function with an array containing NaN values:
>>> import numpy as np
>>> y_with_nan = [1, np.nan, 2, 1]
>>> check_classification_targets(y_with_nan)
ValueError: Target values contain NaN or infinite numbers, which are not suitable for classification.
Attempting to use the function with a continuous target array:
>>> y_continuous = np.linspace(0, 1, 10)
>>> check_classification_targets(y_continuous)
ValueError: The number of unique values is too high for a classification task.
Validating and converting a mixed-type target array to numeric:
>>> y_mixed = [1, '2', 3.0, '4', 5]
>>> check_classification_targets(y_mixed, target_type='numeric')
ValueError: Target array at index 0 contains non-numeric values, which cannot be converted to integers: ['2', '4']...
Validating object target arrays without attempting conversion:
>>> y_str = ["apple", "banana", "cherry"]
>>> check_classification_targets(y_str, target_type='object')
[array(['apple', 'banana', 'cherry'], dtype=object)]
- geoprior.utils.validator.check_consistency_size(*arrays)[source]
Check that the arrays have consistent sizes and raise an error otherwise.
- geoprior.utils.validator.check_consistent_length(*arrays)[source]
Check that all arrays have consistent first dimensions.
Checks whether all objects in arrays have the same shape or length.
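The check amounts to comparing first-axis lengths. The sketch below is illustrative only: require_same_length is a hypothetical name, and skipping None inputs is an assumption borrowed from common practice rather than something this page states.

```python
import numpy as np


def require_same_length(*arrays):
    """Raise ValueError unless every non-None input has the same length."""
    # Assumption for this sketch: None inputs are ignored.
    lengths = [len(a) for a in arrays if a is not None]
    if len(set(lengths)) > 1:
        raise ValueError(
            f"Found inputs with inconsistent numbers of samples: {lengths}"
        )


X = np.zeros((4, 2))
y = [0, 1, 0, 1]
require_same_length(X, y)  # same first dimension, so no error

try:
    require_same_length(X, y[:3])
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

Only the first dimension is compared, so a (4, 2) feature matrix and a length-4 target list count as consistent.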
- geoprior.utils.validator.check_donut_inputs(values=None, data=None, labels=None, ops='check', labels_as_index=True, index=None, origin_index='drop', value_name='auto')[source]
Validate and/or build inputs for donut chart plotting.
This function accepts inputs in various forms and returns a pair of numeric values and labels or builds a new \(n \times 1\) DataFrame from them. The function supports two modes:
In ops="check", it returns a tuple (values, labels) after validating that the numeric values are appropriate for plotting.
In ops="build", it returns a pandas DataFrame constructed from the inputs. If labels_as_index is True, the labels become the DataFrame index; otherwise, they form a separate column. If an index is provided, it is used to reset the DataFrame index and the original index is either dropped or kept based on origin_index.
The function also accepts inputs through a DataFrame or Series (data). In such cases, if values is a str, it is interpreted as a column name of the DataFrame. Similarly, if labels is a str, it is used to fetch the label column.
\[S = \{ x_i \}_{i=1}^{n} \quad \text{and} \quad L = \{ l_i \}_{i=1}^{n} \tag{39}\]
where \(S\) denotes the numeric values and \(L\) denotes the corresponding labels.
- Parameters:
values (
array-likeorstr, optional) – Numeric values for the donut slices. Ifdatais a DataFrame andvaluesis a double backtick string`` ("colname"), then the column"colname"is used. Ifdatais a Series andvaluesis not provided, the series values are used.data (
pandas.Seriesorpandas.DataFrame, optional) – Data source from which to fetchvaluesandlabels. If provided, the function extracts the corresponding numeric data. For a DataFrame, ifvalues(orlabels) is a double backtick string`` ("colname"), the function fetches the column named"colname".labels (
array-likeorstr, optional) – Labels for the donut slices. Ifdatais provided andlabelsis a double backtick string`` ("colname"), then the function uses the specified column as labels. If omitted, the function uses the index of the DataFrame or Series.ops (:py:class:
``”check”:py:class:``or :py:class:``”build”:py:class:``, optional) – Operation mode of the function. In"check"mode, the function returns a tuple(values, labels)after validation. In"build"mode, it returns a new DataFrame built from the inputs. The default is"check".labels_as_index (
bool, optional) – Ifops="build", this flag determines whether the labels are used as the DataFrame index. IfTrue, the labels become the index; ifFalse, they form a separate column. The default isTrue.index (
array-likeorstr, optional) – New index to assign in"build"mode. If a double backtick string`` is provided, it must correspond to a column in the DataFrame and that column is used as the new index. If a list is provided, it directly replaces the DataFrame index. In case the original index is to be retained, seeorigin_index.origin_index (:py:class:
``”drop”:py:class:``or :py:class:``”keep”:py:class:``, optional) – Specifies whether to drop or retain the original index when resetting the DataFrame index. If set to"keep", the original index is saved in a new column namedorigin_index. The default is"drop".value_name (:py:class:
``”auto”:py:class:``orstr, optional) – Name to use for the numeric values in the built DataFrame (whenops="build"). If set to"auto"(orNone), the default name"Value"is used unless overridden by the source data. Otherwise, the provided double backtick string`` (e.g.,"Total") is used as the column name.
- Returns:
If ops="check", returns a tuple (values, labels) where values is a NumPy array of numeric values and labels is a list of labels.
If ops="build", returns a pandas DataFrame constructed from the inputs. If labels_as_index is True, the DataFrame index is set to the provided labels (or the new index if index is specified). Otherwise, the DataFrame contains separate columns for the labels and numeric values.
- Return type:
tuple of (ndarray, list) or pandas.DataFrame
Examples
Build inputs from a DataFrame with explicit column names:
>>> from geoprior.utils.validator import check_donut_inputs
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "Sales": [100, 200, 150],
...     "Country": ["USA", "Canada", "Mexico"]
... })
>>> # Build a DataFrame using "Sales" as values and "Country" as index
>>> new_df = check_donut_inputs(
...     values="Sales",
...     data=df,
...     labels="Country",
...     ops="build",
...     labels_as_index=True,
...     index="Country",
...     origin_index="drop"
... )
>>> new_df
        Sales
USA       100
Canada    200
Mexico    150
Check inputs when only numeric values are provided:
>>> values, labs = check_donut_inputs(
...     values=[10, 20, 30],
...     labels=["A", "B", "C"],
...     ops="check"
... )
>>> values
array([10., 20., 30.])
>>> labs
['A', 'B', 'C']
Notes
The function internally calls the inline helper check_numeric_dtype to ensure that the provided numeric data satisfies the necessary type constraints. The function supports grouping or multiple donut charts by using the input DataFrame directly.
See also check_numeric_dtype() for numeric type validation.
- geoprior.utils.validator.check_epsilon(eps, y_true=None, y_pred=None, base_epsilon=1e-10, scale_factor=1e-05)[source]
Dynamically determine or validate an epsilon value for numerical computations.
This function either validates a provided epsilon if it is a numeric value, or calculates an appropriate epsilon dynamically based on the input data. The dynamic calculation aims to adjust epsilon based on the scale of the input data, providing flexibility and adaptability in algorithms where numerical stability is critical.
- Parameters:
eps (
{'auto', float}) – The epsilon value to use. If ‘auto’, the function dynamically determines an appropriate epsilon based on y_true and y_pred. If a float, it validates this as the epsilon value.y_true (
array-like, optional) – True values array. Used in conjunction with y_pred to dynamically determine epsilon if eps is ‘auto’. If None, this input is ignored.y_pred (
array-like, optional) – Predicted values array. Used alongside y_true for epsilon determination. If None, this input is ignored.base_epsilon (
float, optional) – Base epsilon value used as a starting point in dynamic determination. This value is adjusted based on the scale_factor and the input data to compute the final epsilon.scale_factor (
float, optional) – Scaling factor applied to adjust the base epsilon in relation to the scale of the input data. Helps tailor the epsilon to the problem’s numerical scale.
- Returns:
The determined or validated epsilon value. Ensures numerical operations are conducted with an appropriate epsilon to avoid division by zero or other numerical instabilities.
- Return type:
float
Examples
>>> from geoprior.utils.validator import check_epsilon
>>> y_true = [1, 2, 3]
>>> y_pred = [1.1, 1.9, 3.05]
>>> check_epsilon('auto', y_true, y_pred)
0.00001 # Example output, actual value depends on `determine_epsilon` implementation.
>>> check_epsilon(1e-8)
1e-8
Notes
Using ‘auto’ for eps allows algorithms to adapt to different scales of data, enhancing numerical stability without manually tuning the epsilon value.
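To make the idea of a data-scaled epsilon concrete, here is one plausible rule, written as a sketch under stated assumptions. It is not the rule used by check_epsilon or determine_epsilon, whose internals are not documented on this page; resolve_epsilon and the "largest absolute value times scale_factor, floored at base_epsilon" formula are hypothetical.

```python
import numpy as np


def resolve_epsilon(eps, y_true=None, y_pred=None,
                    base_epsilon=1e-10, scale_factor=1e-5):
    """Validate a numeric eps, or derive one from the data when eps='auto'."""
    if isinstance(eps, str):
        if eps != "auto":
            raise ValueError("eps must be 'auto' or a positive number.")
        arrays = [
            np.asarray(a, dtype=float).ravel()
            for a in (y_true, y_pred)
            if a is not None
        ]
        if not arrays:
            return base_epsilon
        # Assumed rule: scale with the largest magnitude seen in the data.
        scale = float(np.max(np.abs(np.concatenate(arrays))))
        return max(base_epsilon, scale_factor * scale)
    eps = float(eps)
    if eps <= 0:
        raise ValueError("eps must be strictly positive.")
    return eps


auto_eps = resolve_epsilon("auto", [1, 2, 3], [1.1, 1.9, 3.05])
fixed_eps = resolve_epsilon(1e-8)
```

Under this assumed rule the epsilon grows with the data magnitude, so a denominator guarded by it stays small relative to the values being divided.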
- geoprior.utils.validator.check_has_run_method(estimator, msg=None, method_name='run')[source]
Check if the given estimator has a callable run method or any other specified method. This utility helps validate that an object can execute the expected method before further actions are taken.
- Parameters:
estimator (
object) – The object (instance or class) to check for the presence of the run method or another specified method.msg (
str, optional) – Custom error message to display if the method is missing. If None, a default message is generated based on the method_name.method_name (
str, default"run") – The method name to check for. This defaults to run, but you can specify any method name. The method must be callable.
- Raises:
AttributeError – Raised if the run method (or any specified method) does not exist on the object or is not callable.
Examples
>>> from geoprior.utils.validator import check_has_run_method
>>> class MyClass:
...     def run(self):
...         pass
>>> check_has_run_method(MyClass()) # No error
>>> class MyClassWithoutRun:
...     pass
>>> check_has_run_method(MyClassWithoutRun()) # Raises AttributeError
Notes
This function performs several checks:
Existence check: It checks whether the run method (or any other specified method) exists in the estimator object.
Callable check: It ensures that the method is callable, which rules out attributes that might exist but aren’t methods.
Static/class method check: The function accepts static or class methods as valid callable methods.
Bound method check: It verifies that instance methods are bound to an object when required, which ensures they can be called properly in the given context.
This function can be expressed as a validation function:
\[\text{check\_has\_method}(estimator, method\_name) = \begin{cases} \text{valid}, & \text{if method exists and callable} \\ \text{invalid}, & \text{if method is missing or not callable} \end{cases} \tag{40}\]
It determines whether the method is callable or raises an error otherwise. Callable-method validation here follows the Python documentation and the staticmethod overview in [43, 44].
See also
validate_estimator_methods: A helper function to validate multiple methods on an estimator.
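The existence and callable checks listed in the Notes reduce to a getattr lookup followed by callable(). The following is a minimal sketch of that pattern; require_callable_method is a hypothetical name, and the error wording is invented for illustration.

```python
def require_callable_method(obj, method_name="run", msg=None):
    """Raise AttributeError unless obj exposes a callable method_name."""
    # getattr with a default covers the "method is missing" case.
    method = getattr(obj, method_name, None)
    if not callable(method):
        owner = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise AttributeError(
            msg or f"{owner!r} has no callable {method_name!r} method."
        )


class WithRun:
    def run(self):
        return "done"


class WithNonCallableRun:
    run = 3  # the attribute exists but cannot be called


require_callable_method(WithRun())  # passes silently

try:
    require_callable_method(WithNonCallableRun())
    rejected = False
except AttributeError:
    rejected = True
```

The second class shows why the callable check is separate from the existence check: hasattr alone would accept it.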
- geoprior.utils.validator.check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)[source]
Perform is_fitted validation for estimator.
Checks if the estimator is fitted by verifying the presence of fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message.
If an estimator does not set any attributes with a trailing underscore, it can define a __sklearn_is_fitted__ or __fusionlab_is_fitted__ method returning a boolean to specify if the estimator is fitted or not.
- Parameters:
estimator (
estimator instance) – Estimator instance for which the check is performed.attributes (
str, list or tuple of str, default None) – Attribute name(s) given as a string or a list/tuple of strings, e.g. ["coef_", "estimator_", ...] or "coef_".
If None, the estimator is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore.
msg (str, default None) – The default error message is, “This %(name)s instance is not fitted yet. Call ‘fit’ with appropriate arguments before using this estimator.”
For custom messages, if “%(name)s” is present in the message string, it is substituted with the estimator name.
E.g.: “Estimator, %(name)s, must be fitted before sparsifying”.
all_or_any (
callable, {all, any}, default all) – Specify whether all or any of the given attributes must exist.
- Raises:
TypeError – If the estimator is a class or not an estimator instance
NotFittedError – If the attributes are not found.
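The trailing-underscore convention and the all_or_any switch described above can be sketched in a few lines. This is only an illustration under stated assumptions: `looks_fitted` and `TinyMean` are hypothetical names, the sketch returns a boolean rather than raising NotFittedError, and it ignores the `__sklearn_is_fitted__` hook.

```python
def looks_fitted(estimator, attributes=None, all_or_any=all):
    """Boolean sketch of a fitted-state check based on attribute names."""
    if attributes is None:
        # Default rule from the text: at least one instance attribute ends
        # with an underscore and does not start with a double underscore.
        return any(
            name.endswith("_") and not name.startswith("__")
            for name in vars(estimator)
        )
    if isinstance(attributes, str):
        attributes = [attributes]
    # all_or_any is the built-in all or any, applied to presence flags.
    return all_or_any(hasattr(estimator, name) for name in attributes)


class TinyMean:
    """Hypothetical estimator that stores its result in `mean_`."""

    def fit(self, values):
        self.mean_ = sum(values) / len(values)
        return self


unfitted = TinyMean()
fitted = TinyMean().fit([1, 2, 3])

print(looks_fitted(unfitted))                                    # False
print(looks_fitted(fitted))                                      # True
print(looks_fitted(fitted, ["mean_", "scale_"]))                 # False
print(looks_fitted(fitted, ["mean_", "scale_"], all_or_any=any)) # True
```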
- geoprior.utils.validator.check_is_fitted2(estimator, attributes, *, msg=None)[source]
Perform is_fitted validation for estimator.
Checks if the estimator is fitted by looking for attributes set during fitting. Typically, these attributes end with an underscore (‘_’).
- Parameters:
- Raises:
NotFittedError – If the given attributes are not found in the estimator.
Examples
>>> from sklearn.ensemble import RandomForestClassifier
>>> from geoprior.utils.validator import check_is_fitted2
>>> clf = RandomForestClassifier()
>>> check_is_fitted2(clf, ['feature_importances_'])
Traceback (most recent call last):
    ...
NotFittedError: This RandomForestClassifier instance is not fitted yet.
- geoprior.utils.validator.check_is_runned(estimator, attributes=None, *, msg=None, all_or_any=<built-in function all>)[source]
Validate if an estimator instance has been “runned” (executed) prior to invoking dependent methods. This check ensures that the estimator is in the appropriate operational state, allowing users to identify and address runtime issues effectively.
If an estimator does not set “runned” attributes (such as _is_runned), it may define a __gofast_is_runned__ method. This method should return a boolean indicating whether the estimator is “runned” or not.
- Parameters:
estimator (
object) – The instance of the estimator or class being validated. This parameter represents the object in which dependent methods are validated to confirm that the “runned” state has been achieved.
To determine the “runned” status, the function checks for specific attributes or, if defined, the __gofast_is_runned__ method.
attributes (str, list, or tuple of str, optional, default None) – Specifies the name(s) of attributes that indicate the “runned” status, such as ['_is_runned'] or ['_is_fitted']. If these attributes are present and set to True, the estimator is considered to have been runned.
If attributes is set to None, the function will default to checking for _is_runned. This default provides flexibility for estimators that employ standard runned flags.
msg (str, optional, default None) – Custom error message to be displayed if the validation fails. By default, this error message uses the class name of the estimator in the format:
”This %(name)s instance has not been ‘runned’ yet. Call ‘run’ with appropriate arguments before using this method.”
To customize the message, include %(name)s as a placeholder for the estimator’s class name.
all_or_any (
callable, {all, any}, optional, default all) – Determines whether all or any of the specified attributes must be present and set to True. By default, the function expects all attributes to be set to True. Set to any for greater flexibility with multiple attributes.
__gofast_is_runned__ (callable, optional) – If defined within the estimator, this method should return a boolean indicating the “runned” status of the estimator. This provides an alternative to using attributes.
- Raises:
RuntimeError – If none of the specified attributes are set to True or if the __gofast_is_runned__ method (if present) returns False.
Notes
The check_is_runned function ensures that methods dependent on the “runned” status are only executed after the estimator has completed all required preliminary processes, like fit or run. This helper mirrors the fitted-state checks described in [45, 46].
Examples
>>> from geoprior.utils.validator import check_is_runned
>>> class ExampleClass:
...     def __init__(self):
...         self._is_runned = False
...
...     def run(self):
...         self._is_runned = True
...         print("Run completed.")
...
...     def process_data(self):
...         check_is_runned(self)
...         print("Processing data...")
>>> model = ExampleClass()
>>> model.process_data()  # Raises RuntimeError
>>> model.run()
>>> model.process_data()  # Now it works
See also
check_is_fitted: Validates that an estimator has been “fitted” before further use.
validate_estimator_methods: Validates essential estimator methods.
- geoprior.utils.validator.check_memory(memory)[source]
Check that memory is joblib.Memory-like.
joblib.Memory-like means that memory can be converted into a joblib.Memory instance (typically a str denoting the location) or has the same interface (has a cache method).
- Parameters:
memory (None, str or object with the joblib.Memory interface) – If string, the location where to create the joblib.Memory interface.
If None, no caching is done and the Memory object is completely transparent.
- Returns:
memory – A correct joblib.Memory object.
- Return type:
object with the joblib.Memory interface
- Raises:
ValueError – If memory is not joblib.Memory-like.
- geoprior.utils.validator.check_mixed_data_types(data)[source]
Checks if the given data (DataFrame or numpy array) contains both numerical and categorical columns.
- Parameters:
data (
pd.DataFrame or np.ndarray) – The data to check. Can be a pandas DataFrame or a numpy array. If data is a numpy array, it is temporarily converted to a DataFrame for type checking.
- Returns:
True if the data contains both numerical and categorical columns, False otherwise.
- Return type:
Examples
Using with a pandas DataFrame:
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import check_mixed_data_types
>>> df = pd.DataFrame({'A': [1, 2, 3], 'B': ['a', 'b', 'c']})
>>> print(check_mixed_data_types(df))
True
Using with a numpy array:
>>> array = np.array([[1, 'a'], [2, 'b'], [3, 'c']])
>>> print(check_mixed_data_types(array))
True
With data containing only numerical values:
>>> df_numeric_only = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
>>> print(check_mixed_data_types(df_numeric_only))
False
With data containing only categorical values:
>>> df_categorical_only = pd.DataFrame({'A': ['a', 'b', 'c'], 'B': ['d', 'e', 'f']})
>>> print(check_mixed_data_types(df_categorical_only))
False
- geoprior.utils.validator.check_random_state(seed)[source]
Turn seed into a np.random.RandomState instance.
- Parameters:
seed (
None, int or instance of RandomState) – If seed is None, return the RandomState singleton used by np.random. If seed is an int, return a new RandomState instance seeded with seed. If seed is already a RandomState instance, return it. Otherwise raise ValueError.
- Returns:
The random state object based on seed parameter.
- Return type:
- geoprior.utils.validator.check_scalar(x, name, target_type, *, min_val=None, max_val=None, include_boundaries='both')[source]
Validate scalar parameters type and value.
- Parameters:
x (
object) – The scalar parameter to validate.
name (str) – The name of the parameter to be printed in error messages.
target_type (type or tuple) – Acceptable data types for the parameter.
min_val (float or int, default None) – The minimum valid value the parameter can take. If None (default), the parameter has no lower bound.
max_val (float or int, default None) – The maximum valid value the parameter can take. If None (default), the parameter has no upper bound.
include_boundaries ({"left", "right", "both", "neither"}, default "both") – Whether the interval defined by min_val and max_val should include the boundaries. Use "left" for [min_val, max_val), "right" for (min_val, max_val], "both" for [min_val, max_val], or "neither" for (min_val, max_val).
- Returns:
x – The validated number.
- Return type:
- Raises:
TypeError – If the parameter’s type does not match the desired type.
ValueError – If the parameter’s value violates the given bounds. If min_val, max_val and include_boundaries are inconsistent.
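The four include_boundaries options map onto strict or non-strict comparisons at each end of the interval. The sketch below illustrates only that mapping; `in_interval` is a hypothetical helper, it returns a boolean instead of raising, and it omits the type check that check_scalar also performs.

```python
import operator


def in_interval(x, min_val=None, max_val=None, include_boundaries="both"):
    """Return True if x lies in the interval selected by include_boundaries."""
    if include_boundaries not in {"left", "right", "both", "neither"}:
        raise ValueError(f"Unknown include_boundaries: {include_boundaries!r}")
    # "left" and "both" close the lower end; "right" and "both" close the upper end.
    lower_ok = operator.ge if include_boundaries in {"left", "both"} else operator.gt
    upper_ok = operator.le if include_boundaries in {"right", "both"} else operator.lt
    if min_val is not None and not lower_ok(x, min_val):
        return False
    if max_val is not None and not upper_ok(x, max_val):
        return False
    return True


print(in_interval(0, 0, 1, "both"))       # True:  [0, 1] contains 0
print(in_interval(0, 0, 1, "right"))      # False: (0, 1] excludes 0
print(in_interval(1, 0, 1, "left"))       # False: [0, 1) excludes 1
print(in_interval(0.5, 0, 1, "neither"))  # True:  (0, 1) contains 0.5
```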
- geoprior.utils.validator.check_symmetric(array, *, tol=1e-10, raise_warning=True, raise_exception=False)[source]
Make sure that array is 2D, square and symmetric.
If the array is not symmetric, then a symmetrized version is returned. Optionally, a warning or exception is raised if the matrix is not symmetric.
- Parameters:
array (
{ndarray, sparse matrix}) – Input object to check / convert. Must be two-dimensional and square, otherwise a ValueError will be raised.
tol (float, default 1e-10) – Absolute tolerance for equivalence of arrays.
raise_warning (bool, default True) – If True, raise a warning if conversion is required.
raise_exception (bool, default False) – If True, raise an exception if array is not symmetric.
- Returns:
array_sym – Symmetrized version of the input array, i.e. the average of array and array.transpose(). If sparse, then duplicate entries are first summed and zeros are eliminated.
- Return type:
{ndarray, sparse matrix}
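The symmetrized result is described as the average of the array and its transpose. A pure-Python sketch of that rule on nested lists follows; `symmetrize` is a hypothetical name, and the sketch covers neither sparse matrices nor the warning and exception options.

```python
def symmetrize(matrix, tol=1e-10):
    """Return (symmetric_matrix, was_already_symmetric) for a square nested list."""
    n = len(matrix)
    if any(len(row) != n for row in matrix):
        raise ValueError("matrix must be two-dimensional and square")
    # Compare each entry above the diagonal with its mirror below it.
    already_symmetric = all(
        abs(matrix[i][j] - matrix[j][i]) <= tol
        for i in range(n)
        for j in range(i + 1, n)
    )
    if already_symmetric:
        return [list(row) for row in matrix], True
    # Average of the matrix and its transpose, entry by entry.
    averaged = [
        [(matrix[i][j] + matrix[j][i]) / 2 for j in range(n)]
        for i in range(n)
    ]
    return averaged, False


print(symmetrize([[1, 2], [2, 1]]))  # ([[1, 2], [2, 1]], True)
print(symmetrize([[1, 2], [4, 3]]))  # ([[1.0, 3.0], [3.0, 3.0]], False)
```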
- geoprior.utils.validator.check_y(y, multi_output=False, y_numeric=False, input_name='y', estimator=None, to_frame=False, allow_nan=False)[source]
Validates the target array y, ensuring it is suitable for classification or regression tasks based on its content and the specified strategy.
- Parameters:
y (
array-like) – Target values to validate.
multi_output (bool, default False) – Whether to allow two-dimensional y values. If False, y is validated as a vector. When multi_output=True, y still cannot contain np.nan or np.inf values unless allow_nan permits NaNs.
y_numeric (bool, default False) – Whether to ensure that y has a numeric type. If dtype of y is object, it is converted to float64. Should only be used for regression algorithms.
input_name (str, default "y") – Data name used to construct the error message.
estimator (str or estimator instance, default None) – If passed, include the name of the estimator in warning messages.
allow_nan (bool, default False) – If True, do not raise an error when y contains NaN values.
to_frame (bool, default False) – Reconvert the validated array to its initial pandas type when the input was provided as a pandas Series or DataFrame.
- Returns:
y_converted – The converted and validated y.
- Return type:
- geoprior.utils.validator.contains_nested_objects(lst, strict=False, allowed_types=None)[source]
Determines whether a list contains nested objects.
- Parameters:
lst (
list) – The list to be checked for nested objects.
strict (bool, optional) – If True, all items in the list must be nested objects. If False, the function returns True if any item is a nested object. Default is False.
allowed_types (tuple of types, optional) – A tuple of types to consider as nested objects. If None, common nested types like list, set, dict, and tuple are checked. Default is None.
- Returns:
True if the list contains nested objects according to the given parameters, otherwise False.
- Return type:
Notes
A nested object is defined as any item within the list that is not a primitive data type (e.g., int, float, str) or is a complex structure like lists, sets, dictionaries, etc. The function can be customized to check for specific types using the allowed_types parameter.
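The strict and non-strict modes correspond to Python's built-in all and any applied to per-item type checks. A minimal sketch under that assumption is shown below; `has_nested` is a hypothetical name, and returning False for an empty list in strict mode is a choice made here, not documented behaviour.

```python
def has_nested(items, strict=False, allowed_types=None):
    """Return True if `items` holds nested containers (any item, or all if strict)."""
    # Default container types follow the parameter description above.
    nested_types = allowed_types or (list, set, dict, tuple)
    flags = [isinstance(item, nested_types) for item in items]
    if strict:
        # Every item must be nested; an empty list is treated as not nested.
        return bool(flags) and all(flags)
    return any(flags)


all_nested = [{1, 2}, [3, 4], {"key": "value"}]
one_nested = [1, 2, 3, [4]]
flat = [1, 2, 3, 4]

print(has_nested(all_nested), has_nested(all_nested, strict=True))  # True True
print(has_nested(one_nested), has_nested(one_nested, strict=True))  # True False
print(has_nested(flat), has_nested(flat, strict=True))              # False False
```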
Examples
>>> from geoprior.utils.validator import contains_nested_objects
>>> example_list1 = [{1, 2}, [3, 4], {'key': 'value'}]
>>> example_list2 = [1, 2, 3, [4]]
>>> example_list3 = [1, 2, 3, 4]
>>> contains_nested_objects(example_list1)  # non-strict, contains nested objects
True
>>> contains_nested_objects(example_list1, strict=True)  # strict, all are nested objects
True
>>> contains_nested_objects(example_list2)  # non-strict, contains at least one nested object
True
>>> contains_nested_objects(example_list2, strict=True)  # strict, not all are nested objects
False
>>> contains_nested_objects(example_list3)  # non-strict, no nested objects
False
>>> contains_nested_objects(example_list3, strict=True)  # strict, no nested objects
False
- geoprior.utils.validator.convert_array_to_pandas(X, *, to_frame=False, columns=None, input_name='X')[source]
Converts an array-like object to a pandas DataFrame or Series, applying provided column names or series name.
- Parameters:
X (
array-like) – The array to convert to a DataFrame or Series.
to_frame (bool, default False) – If True, converts the array to a DataFrame. Otherwise, returns the array unchanged.
columns (str or list of str, optional) – Name(s) for the columns of the resulting DataFrame or the name of the Series.
input_name (str, default 'X') – The name of the input variable; used in constructing error messages.
- Returns:
- Raises:
TypeError – If X is not array-like or if columns is neither a string nor a list of strings.
ValueError – If the conversion to DataFrame is requested but columns is not provided, or if the length of columns does not match the number of columns in X.
- geoprior.utils.validator.ensure_2d(X, output_format='auto')[source]
Ensure that the input X is converted to a 2-dimensional structure.
- Parameters:
X (
array-like or pandas.DataFrame) – The input data to convert. Can be a list, numpy array, or DataFrame.
output_format (str, optional) – The format of the returned object. Options are “auto”, “array”, or “frame”. “auto” returns a DataFrame if X is a DataFrame, otherwise a numpy array. “array” always returns a numpy array. “frame” always returns a pandas DataFrame.
- Returns:
The converted 2-dimensional structure, either as a numpy array or DataFrame.
- Return type:
ndarray or DataFrame
- Raises:
ValueError – If the output_format is not one of the allowed values.
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import ensure_2d
>>> X = np.array([1, 2, 3])
>>> ensure_2d(X, output_format="array")
array([[1],
       [2],
       [3]])
>>> df = pd.DataFrame([1, 2, 3])
>>> ensure_2d(df, output_format="frame")
   0
0  1
1  2
2  3
- geoprior.utils.validator.ensure_non_negative(*arrays, err_msg=None)[source]
Ensure that provided arrays contain only non-negative values.
This function checks each provided array for non-negativity. If any negative values are found in any array, it raises a ValueError. This check is crucial for computations or algorithms where negative values are not permissible, such as logarithmic transformations.
- Parameters:
*arrays (
array-like) – One or more array-like structures (e.g., lists, numpy arrays). Each array is checked for non-negativity.
err_msg (str, optional) – Custom error message to use if negative values are found.
- Raises:
ValueError – If any array contains negative values, a ValueError is raised with a message indicating that only non-negative values are expected.
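Examples
>>> from geoprior.utils.validator import ensure_non_negative
>>> y_true = [0, 1, 2, 3]
>>> y_pred = [0.5, 2.1, 3.5, -0.1]
>>> ensure_non_negative(y_true, y_pred)
Traceback (most recent call last):
    ...
ValueError: Negative value found. Expect only non-negative values.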
Note
The function uses a variable number of arguments, allowing flexibility in the number of arrays checked in a single call.
- geoprior.utils.validator.filter_valid_kwargs(callable_obj, kwargs)[source]
Filter and return only the valid keyword arguments for a given callable object.
This function checks if the arguments in kwargs are valid for the provided callable object (function, lambda function, method, or class). If any argument is not valid, it is removed from kwargs. The function returns only the valid kwargs.
- Parameters:
callable_obj (
callable) – The callable object (function, lambda function, method, or class) for which the keyword arguments need to be validated.
kwargs (dict) – Dictionary of keyword arguments to be validated against the callable object.
- Returns:
valid_kwargs – Dictionary containing only the valid keyword arguments for the callable object.
- Return type:
Examples
>>> from geoprior.utils.validator import filter_valid_kwargs
>>> def example_func(a, b, c=3):
...     pass
>>> kwargs = {'a': 1, 'b': 2, 'd': 4}
>>> filter_valid_kwargs(example_func, kwargs)
{'a': 1, 'b': 2}
>>> class ExampleClass:
...     def __init__(self, x, y, z=10):
...         pass
>>> kwargs = {'x': 1, 'y': 2, 'a': 3}
>>> filter_valid_kwargs(ExampleClass, kwargs)
{'x': 1, 'y': 2}
>>> filter_valid_kwargs(ExampleClass(1, 2), kwargs)
{'x': 1, 'y': 2}
Notes
This function uses the inspect module to retrieve the signature of the given callable object and validate the keyword arguments.
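The signature-based filtering mentioned in the Notes can be sketched with inspect.signature. The helper name `keep_valid_kwargs` is hypothetical, and passing every key through when the callable accepts `**kwargs` is an assumption of this sketch rather than documented geoprior behaviour.

```python
import inspect


def keep_valid_kwargs(func, kwargs):
    """Return only the entries of `kwargs` that `func` accepts by name."""
    params = inspect.signature(func).parameters
    # Assumption: a callable with **kwargs accepts every keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    }
    return {key: value for key, value in kwargs.items() if key in accepted}


def example_func(a, b, c=3):
    return a + b + c


class ExampleClass:
    def __init__(self, x, y, z=10):
        self.total = x + y + z


print(keep_valid_kwargs(example_func, {"a": 1, "b": 2, "d": 4}))  # {'a': 1, 'b': 2}
print(keep_valid_kwargs(ExampleClass, {"x": 1, "y": 2, "a": 3}))  # {'x': 1, 'y': 2}
```

For a class, inspect.signature reports the `__init__` parameters without `self`, which is why the class case needs no special handling here.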
- geoprior.utils.validator.get_estimator_name(estimator)[source]
Get the estimator name, whether it is an instantiated object or not.
- Parameters:
estimator (callable or instantiated object) – Callable or instance object that has a fit method.
- Returns:
str, name of the estimator.
- geoprior.utils.validator.handle_zero_division(y_true, zero_division='warn', metric_name='metric computation', epsilon=1e-15, replace_with=None)[source]
Preprocess input arrays to handle cases where zero could cause division errors in subsequent metric computations.
- Parameters:
y_true (
array-like) – The input data array where zeros might cause division errors.
zero_division ({'warn', 'raise', 'ignore'}, default 'warn') – Determines the action to perform when a zero is encountered. Use "warn" to issue a warning and replace zeros with replace_with or epsilon, "raise" to raise an error, or "ignore" to leave zeros unchanged when the metric can handle them natively.
metric_name (str, optional) – Name of the metric for which this preprocessing is being done, to be included in warnings or error messages for better context.
epsilon (float, optional) – Small value to use as the default replacement if replace_with is None. Default is 1e-15.
replace_with (float or None, optional) – A specific value to replace zeros with. If None, epsilon is used.
- Returns:
The processed array with modifications based on the zero_division strategy.
- Return type:
- Raises:
ValueError – If zero_division is ‘raise’ and zero is found in y_true.
Notes
Using replace_with allows for custom behavior when handling zeros, which can be tailored to the specific requirements of different metric computations.
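The three zero_division strategies can be sketched with the standard library alone. `guard_zeros` is a hypothetical stand-in: it works on plain lists, returns floats, and uses a generic warning message, none of which is claimed to match the geoprior implementation.

```python
import warnings


def guard_zeros(values, zero_division="warn", epsilon=1e-15, replace_with=None):
    """Replace, reject, or keep zeros according to `zero_division`."""
    if zero_division not in {"warn", "raise", "ignore"}:
        raise ValueError(f"Unknown zero_division: {zero_division!r}")
    values = [float(v) for v in values]
    if zero_division == "ignore" or 0.0 not in values:
        return values
    if zero_division == "raise":
        raise ValueError("Zero found in input; division would be undefined.")
    # "warn": fall back to epsilon unless an explicit replacement is given.
    fill = epsilon if replace_with is None else replace_with
    warnings.warn(f"Zeros replaced with {fill}.", UserWarning)
    return [fill if v == 0.0 else v for v in values]


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cleaned = guard_zeros([0, 1, 2, 3, 0], replace_with=0.001)

print(cleaned)      # [0.001, 1.0, 2.0, 3.0, 0.001]
print(len(caught))  # 1
```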
Examples
>>> from geoprior.utils.validator import handle_zero_division
>>> y_true = [0, 1, 2, 3, 0]
>>> processed_y_true = handle_zero_division(
...     y_true, replace_with=0.001, zero_division='warn'
... )
>>> print(processed_y_true)
[1.e-03 1.e+00 2.e+00 3.e+00 1.e-03]
- geoprior.utils.validator.has_methods(models, methods, strict=True, check_status='check_only', msg=None)[source]
Validate that one or more model objects implement required methods.
- Parameters:
models (
object or list of objects) – Model instance or collection of model instances to validate.
methods (list of str) – Public method names that each model must implement.
strict (bool, optional) – If True, raise an AttributeError when a required method is missing.
check_status ({'validate', 'check_only'}, optional) – Return mode. Use 'validate' to return validated models and 'check_only' to return a boolean flag.
msg (str or None, optional) – Optional custom error message using {model} and {methods} placeholders.
- Returns:
Validated models when check_status='validate' or a boolean flag when check_status='check_only'.
- Return type:
- Raises:
AttributeError – If a required method is missing and strict=True.
TypeError – If methods is not a list of strings.
ValueError – If check_status is invalid.
- geoprior.utils.validator.has_fit_parameter(estimator, parameter)[source]
Check whether the estimator’s fit method supports the given parameter.
- Parameters:
- Returns:
is_parameter – Whether the parameter was found to be a named parameter of the estimator’s fit method.
- Return type:
Examples
>>> from sklearn.svm import SVC
>>> from sklearn.utils.validation import has_fit_parameter
>>> has_fit_parameter(SVC(), "sample_weight")
True
- geoprior.utils.validator.has_required_attributes(model, attributes)[source]
Check if the model has all required Keras-specific attributes.
This function is part of the deep validation process to ensure that the model not only inherits from Keras model classes but also implements essential methods.
- geoprior.utils.validator.is_binary_class(y, accept_multioutput=False)[source]
Check whether the target array represents binary classification. Optionally, handle multi-output arrays if each output is binary.
- Parameters:
y (
array-like) – The target array to be checked. This can be a 1D array for single output or a 2D array for multiple outputs if accept_multioutput is True.
accept_multioutput (bool, default False) – If True, the function checks if each column in a multi-dimensional array is binary. If False, the function checks if the entire array is binary.
- Returns:
Returns True if y is binary (or each output is binary if multi-output is accepted), False otherwise.
- Return type:
Examples
>>> from geoprior.utils.validator import is_binary_class
>>> is_binary_class([0, 1, 1, 0])
True
>>> is_binary_class([[0, 1], [1, 0], [0, 1], [1, 0]], accept_multioutput=True)
True
>>> is_binary_class([0, 1, 2, 3])
False
- geoprior.utils.validator.is_categorical(data, column, strict=False, error='raise')[source]
Checks if a specified column in a DataFrame or Series is of a categorical type.
- Parameters:
data (
DataFrame or Series) – The DataFrame or Series to check.
column (str) – The name of the column to check.
strict (bool, optional) – If True, only considers pandas CategoricalDtype as categorical. If False, also considers object dtype that often represents categorical data. Default is False.
error (str, optional) – Specifies how to handle situations when the column does not exist. Options are ‘raise’, ‘warn’, or ‘ignore’. Default is ‘raise’.
- Returns:
True if the column is categorical, otherwise False.
- Return type:
- Raises:
ValueError – If the column does not exist and error is set to ‘raise’.
Examples
>>> import pandas as pd
>>> from geoprior.utils.validator import is_categorical
>>> df = pd.DataFrame({
...     'fruit': ['Apple', 'Banana', 'Cherry'],
...     'count': [10, 20, 15]
... })
>>> df['fruit'] = df['fruit'].astype('category')
>>> print(is_categorical(df, 'fruit'))
True
>>> print(is_categorical(df, 'count'))
False
>>> print(is_categorical(df, 'non_existent', error='warn'))
Warning: Column 'non_existent' not found in the dataframe.
False
- geoprior.utils.validator.is_frame(arr, df_only=False, raise_exception=False, objname=None, error='raise')[source]
Check if arr is a pandas DataFrame or Series.
If df_only=True, the function checks strictly for a pandas DataFrame. Otherwise, it accepts either a pandas DataFrame or Series. This utility is often used to validate input data before processing, ensuring that the input conforms to expected types.
- Parameters:
arr (object) – The object to examine. Typically a pandas DataFrame or Series, but can be any Python object.
df_only (bool, optional) – If True, only verifies that arr is a DataFrame. If False, checks for either a DataFrame or a Series. Default is False.
raise_exception (bool, optional) – If True, this will override error="raise". This parameter is deprecated and will be removed soon. Default is False.
error (str, optional) – Determines the action when arr is not a valid frame. Can be:
- "raise": Raises a TypeError.
- "warn": Issues a warning.
- "ignore": Does nothing.
Default is "raise".
objname (str or None, optional) – A custom name used in the error message if error is set to "raise". If None, a generic name is used.
- Returns:
True if arr is a DataFrame or Series (or strictly a DataFrame if df_only=True), otherwise False.
- Return type:
- Raises:
TypeError – If error="raise" and arr is not a valid frame. The error message guides the user to provide the correct type (a DataFrame when df_only=True, otherwise a DataFrame or Series).
Notes
This function does not convert or modify arr. It merely checks its compatibility with common DataFrame/Series interfaces by examining attributes such as ‘columns’ or ‘name’. For a DataFrame, arr.columns should exist, and for a Series, a ‘name’ attribute is often present. Both DataFrame and Series implement __array__, making them NumPy array-like.
Examples
>>> import pandas as pd
>>> from geoprior.utils.validator import is_frame
>>> df = pd.DataFrame({'A': [1, 2, 3]})
>>> is_frame(df)
True
>>> s = pd.Series([4, 5, 6], name='S')
>>> is_frame(s)
True
>>> is_frame(s, df_only=True)
False
If error="raise":
>>> is_frame(s, df_only=True, error="raise", objname='Input')
Traceback (most recent call last):
    ...
TypeError: 'Input' parameter expects a DataFrame. Got 'Series'
- geoprior.utils.validator.is_installed(module)[source]
Checks if a module (for example, TensorFlow) is installed.
This function attempts to find the package specification of the given module without importing the package. It’s a lightweight method to verify the presence of a package such as TensorFlow in the environment.
- Returns:
True if the module is installed, False otherwise.
- Return type:
- Parameters:
module (str)
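Looking up a package specification without importing the package is what importlib.util.find_spec provides in the standard library. The sketch below assumes that mechanism; `module_available` is a hypothetical name and whether is_installed is implemented this way is not stated in this page.

```python
import importlib.util


def module_available(name):
    """Return True if a top-level module called `name` can be located."""
    # find_spec returns None when no finder can locate the module.
    # Note: for dotted names, a missing parent package raises ModuleNotFoundError.
    return importlib.util.find_spec(name) is not None


print(module_available("json"))                            # True
print(module_available("surely_not_a_real_module_xyz123")) # False
```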
Examples
>>> from geoprior.utils.validator import is_installed
>>> print(is_installed("tensorflow"))  # True if TensorFlow is installed, False otherwise
True
- geoprior.utils.validator.is_normalized(arr, method='sum')[source]
Checks if the provided array is normalized according to the specified method.
- Parameters:
arr (
array-like) – The array to check for normalization.
method (str, optional) – The normalization method to check against. Use "01" to confirm values are within [0, 1] with minimum 0 and maximum 1, "zscore" to confirm mean 0 and standard deviation 1, or "sum" to confirm the array sums to 1. Default is "sum".
- Returns:
Returns True if the array is normalized according to the specified method, False otherwise.
- Return type:
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import is_normalized
>>> arr = np.array([0.25, 0.25, 0.25, 0.25])
>>> is_normalized(arr, method='sum')
True
>>> arr = np.array([0, 0.5, 1])
>>> is_normalized(arr, method='01')
True
>>> arr = np.array([1, -1, 1, -1])
>>> is_normalized(arr, method='zscore')
True
- geoprior.utils.validator.is_square_matrix(data, data_type=None)[source]
Determine whether the input, either a DataFrame or an array-like structure, forms a square matrix.
Automatically detects the data type unless specified. Supports data inputs that can be converted to a NumPy array.
- Parameters:
data (
DataFrame, array-like, or any object convertible to a numpy array) – The input data to check.
data_type (str, optional) – The expected type of the input data. Valid options are ‘array’ or ‘dataframe’. If not specified, the data type is inferred. Default interpretation is as an ‘array’.
- Returns:
Returns True if the data is a square matrix, otherwise False.
- Return type:
- Raises:
ValueError – If data_type is neither ‘array’ nor ‘dataframe’.
TypeError – If the input data does not match the expected format or cannot be processed.
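Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import is_square_matrix
>>> is_square_matrix(np.array([[1, 2], [3, 4]]))
True
>>> is_square_matrix(pd.DataFrame([[1, 2, 3], [4, 5, 6]]))
False
>>> is_square_matrix([[1, 2], [3, 4]], data_type='array')
True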
Notes
A square matrix has an equal number of rows and columns. This function checks the dimensionality and shape of the data to confirm if it meets this criterion.
- geoprior.utils.validator.is_time_series(data, time_col, check_time_interval=False)[source]
Check if the provided DataFrame is time series data.
- Parameters:
data (
pandas.DataFrame) – The DataFrame to be checked.
time_col (str) – The name of the column in data expected to represent time.
- Returns:
True if data is a time series, False otherwise.
- Return type:
Example
>>> import pandas as pd
>>> from geoprior.utils.validator import is_time_series
>>> df = pd.DataFrame({
...     'Date': ['2021-01-01', '2021-01-02', '2021-01-03',
...              '2021-01-04', '2021-01-05'],
...     'Value': [1, 2, 3, 4, 5]
... })
>>> # Should return True if the Date column
>>> # can be converted to datetime
>>> print(is_time_series(df, 'Date'))
- geoprior.utils.validator.is_valid_policies(nan_policy, allowed_policies=None)[source]
Validates the nan_policy or any policy argument to ensure it is one of the acceptable options (allowed_policies).
Function is used to enforce conformity to predefined NaN handling strategies in data processing tasks.
- Parameters:
nan_policy (
str) – The NaN handling policy to validate. Acceptable values are:
- ‘propagate’: NaN values are propagated, i.e., no action is taken.
- ‘omit’: NaN values are omitted before proceeding with the operation.
- ‘raise’: Raises an error if NaN values are present.
allowed_policies (list of str, optional) – A list of allowable policy options. If None, defaults to [‘propagate’, ‘omit’, ‘raise’].
- Raises:
ValueError – If nan_policy is not one of the valid options in allowed_policies.
- Returns:
The verified nan_policy value, confirming it is within allowed parameters.
- Return type:
Examples
>>> from geoprior.utils.validator import is_valid_policies
>>> is_valid_policies('omit')  # This should pass without an error.
>>> is_valid_policies('ignore')  # This should raise a ValueError.
- geoprior.utils.validator.normalize_array(arr, normalize='auto', method='01')[source]
Checks if an array is normalized according to the specified method and normalizes it if required based on the ‘normalize’ parameter.
- Parameters:
arr (
array-like) – The input array to check and potentially normalize.
normalize (str, optional) – Controls whether normalization is applied. Use "auto" to normalize only when the array is not already normalized for the selected method. Use True to always normalize and False to return the array unchanged. Default is "auto".
method (str, optional) – Normalization method to apply. Use "01" for min-max scaling, "zscore" for standardization, or "sum" to scale values so they sum to 1. Default is "01".
- Returns:
The normalized array, or the original array if no normalization was applied.
- Return type:
np.ndarray
- Raises:
ValueError – If an unknown normalization method is specified or if normalization cannot be performed due to data characteristics (e.g., zero variance).
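The three methods reduce to short formulas: min-max scaling, standardization, and division by the total. The pure-Python sketch below assumes the population standard deviation (ddof=0) for "zscore", which is what reproduces the z-score values printed in the Examples for [1, 2, 3, 4, 5]. `normalize_sketch` is a hypothetical name and works on plain lists rather than NumPy arrays.

```python
import statistics


def normalize_sketch(values, method="01"):
    """Normalize a sequence of numbers with one of the three documented methods."""
    values = [float(v) for v in values]
    if method == "01":
        low, high = min(values), max(values)
        if high == low:
            raise ValueError("min-max scaling is undefined for constant data")
        return [(v - low) / (high - low) for v in values]
    if method == "zscore":
        mean = statistics.fmean(values)
        std = statistics.pstdev(values)  # population standard deviation (assumed)
        if std == 0:
            raise ValueError("z-score is undefined for zero variance")
        return [(v - mean) / std for v in values]
    if method == "sum":
        total = sum(values)
        if total == 0:
            raise ValueError("sum normalization is undefined for a zero total")
        return [v / total for v in values]
    raise ValueError(f"Unknown normalization method: {method!r}")


data = [1, 2, 3, 4, 5]
print(normalize_sketch(data, "01"))  # [0.0, 0.25, 0.5, 0.75, 1.0]
print([round(v, 8) for v in normalize_sketch(data, "zscore")])
print([round(v, 8) for v in normalize_sketch(data, "sum")])
```

The zero-variance and zero-total guards correspond to the "data characteristics" failure case listed under Raises.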
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import normalize_array
>>> data = np.array([1, 2, 3, 4, 5])
>>> normalized_data = normalize_array(data, normalize=True, method='01')
>>> print("Normalized between 0 and 1:", normalized_data)
Normalized between 0 and 1: [0.   0.25 0.5  0.75 1.  ]
>>> zscore_data = normalize_array(data, normalize=True, method='zscore')
>>> print("Standardized (Z-score):", zscore_data)
Standardized (Z-score): [-1.41421356 -0.70710678  0.          0.70710678  1.41421356]
>>> sum_data = normalize_array(data, normalize=True, method='sum')
>>> print("Normalized by sum:", sum_data)
Normalized by sum: [0.06666667 0.13333333 0.2        0.26666667 0.33333333]
- geoprior.utils.validator.parameter_validator(param_name, target_strs, match_method='contains', raise_exception=True, **kws)[source]
Creates a validator function for ensuring a parameter’s value matches one of the allowed target strings, optionally applying normalization.
This higher-order function returns a validator that can be used to check if a given parameter value matches allowed criteria, optionally raising an exception or normalizing the input.
- Parameters:
param_name (
str) – Name of the parameter to be validated. Used in error messages to indicate which parameter failed validation.
target_strs (list of str) – A list of acceptable string values for the parameter.
match_method (str, optional) – The method used to match the input string against the target strings. The default method is ‘contains’, which checks if the input string contains any of the target strings.
raise_exception (bool, optional) – Specifies whether an exception should be raised if validation fails. Defaults to True, raising an exception on failure.
**kws (dict) – Keyword arguments passed to geoprior.core.utils.normalize_string().
- Returns:
A closure that takes a single string argument (the parameter value) and returns a normalized version of it if the parameter matches the target criteria. If the parameter does not match and raise_exception is True, it raises an exception; otherwise, it returns the original value.
- Return type:
function
Examples
>>> from geoprior.utils.validator import parameter_validator
>>> validate_outlier_method = parameter_validator(
...     'outlier_method', ['z_score', 'iqr'])
>>> outlier_method = "z_score"
>>> print(validate_outlier_method(outlier_method))
z_score
>>> validate_fill_missing = parameter_validator(
...     'fill_missing', ['median', 'mean', 'mode'], raise_exception=False)
>>> fill_missing = "average"  # This does not match but won't raise an exception.
>>> print(validate_fill_missing(fill_missing))
average
Notes
The function leverages a custom utility function normalize_string from a module named geoprior.core.utils. This utility is assumed to handle string normalization and matching based on the provided match_method.
If raise_exception is set to False and the input does not match any target string, the input string is returned unchanged. This behavior allows for optional enforcement of the validation rules.
The primary use case for this function is to validate and optionally normalize parameters for configuration settings or function arguments where only specific values are allowed.
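The closure pattern described above can be illustrated with a short, self-contained sketch. `make_validator` is a hypothetical name, and the case-insensitive "contains" matching is an assumption about what the normalization step does; the real function delegates matching to `normalize_string`.

```python
def make_validator(param_name, target_strs, raise_exception=True):
    """Hypothetical validator factory illustrating the closure pattern."""
    targets = [t.lower() for t in target_strs]

    def validator(value):
        lowered = str(value).lower()
        # 'contains'-style match: the input contains one of the targets.
        for target in targets:
            if target in lowered:
                return target
        if raise_exception:
            raise ValueError(
                f"Invalid {param_name!r}: {value!r}; expected one of {target_strs}"
            )
        return value  # returned unchanged when validation is not enforced

    return validator

check_method = make_validator("outlier_method", ["z_score", "iqr"])
matched = check_method("Z_SCORE")

lenient = make_validator("fill_missing", ["median", "mean"], raise_exception=False)
passed_through = lenient("average")
```

Because the allowed values are captured once when the validator is built, the returned function can be reused wherever the same parameter is accepted.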
- geoprior.utils.validator.process_y_pairs(*ys, error='warn', solo_return=False, ops='check_only')[source]
Process and validate paired arrays of ground truth (y_true) and predicted values (y_pred) for machine learning evaluation.
- Parameters:
*ys (ArrayLike) – Variable-length sequence of array-likes containing alternating (y_true, y_pred) pairs. Must contain an even number of inputs.
error ({'raise', 'warn', 'ignore'}, default 'warn') – Handling strategy for validation errors:
'raise': Immediately raise ValueError.
'warn': Issue UserWarning but continue processing.
'ignore': Silently skip invalid pairs.
solo_return (bool, default False) – When processing a single pair, return individual arrays instead of length-1 lists.
ops ({'check_only', 'validate'}, default 'check_only') – Processing mode:
'check_only': Verify pair lengths without modification.
'validate': Clean data (remove NaNs) and validate dtypes.
- Returns:
Processed pairs as a (y_trues, y_preds) tuple. The return type depends on solo_return and the number of valid pairs.
- Return type:
Tuple[List[ArrayLike], List[ArrayLike]] or Tuple[ArrayLike, ArrayLike]
- Raises:
ValueError – If the input count is odd and error='raise'; if a pair has mismatched lengths and error='raise'; or if error or ops has an invalid value.
- Warns:
UserWarning – If the input count is odd and error='warn', or if a pair has mismatched lengths and error='warn'.
Examples
Basic usage with valid pairs:
>>> from geoprior.utils.validator import process_y_pairs
>>> y_true1 = [1.2, 2.3, 3.4]
>>> y_pred1 = [1.1, 2.4, 3.3]
>>> y_true2 = [4.5, 5.6]
>>> y_pred2 = [4.4, 5.7]
>>> process_y_pairs(y_true1, y_pred1, y_true2, y_pred2)
([[1.2, 2.3, 3.4], [4.5, 5.6]], [[1.1, 2.4, 3.3], [4.4, 5.7]])
Handling mismatched pair with warnings:
>>> y_bad = [1, 2, 3]
>>> p_bad = [1, 2]
>>> process_y_pairs(y_bad, p_bad, error='warn')
UserWarning: Length mismatch in pair 0: 3 vs 2
([], [])
Full validation pipeline:
>>> import numpy as np
>>> y_clean, p_clean = process_y_pairs(
...     [1, np.nan, 3], [np.nan, 2.1, 3.2],
...     ops='validate', solo_return=True
... )
>>> y_clean
array([3.])
>>> p_clean
array([3.2])
Notes
Ensures input pairs meet requirements for downstream analysis through:
(41)
\[
\begin{aligned}
&\forall i \in \{0, 2, 4, \ldots\}:\quad (y_{true}^i, y_{pred}^i) \rightarrow (\tilde{y}_{true}^i, \tilde{y}_{pred}^i), \quad \text{where} \\
&\text{len}(\tilde{y}_{true}^i) = \text{len}(\tilde{y}_{pred}^i) \quad \text{and} \quad \tilde{y}_{true}^i \in \mathbb{R}^{n},\ \tilde{y}_{pred}^i \in \mathbb{R}^{n}
\end{aligned}
\]
Uses drop_nan_in for NaN removal and index resetting during validation.
Applies validate_yy for dtype consistency checks and array flattening.
Forward references for ArrayLike allow flexibility: any array-like structure is accepted (list, NumPy array, pandas Series, etc.).
The type and array-handling conventions rely on the Python language reference and NumPy’s array-programming model [36, 47].
See also
drop_nan_in: Core NaN removal and index resetting function
validate_yy: Array validation and dtype consistency checker
sklearn.utils.check_consistent_length: Scikit-learn’s length validation
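The pairing and length-check logic of the 'check_only' mode can be sketched without any GeoPrior imports. `split_pairs_sketch` is a hypothetical helper; dropping an unpaired trailing input in the non-raising modes is an assumption, not documented behavior.

```python
import warnings

def split_pairs_sketch(*ys, error="warn"):
    """Hypothetical helper: split alternating (y_true, y_pred) inputs."""
    if error not in {"raise", "warn", "ignore"}:
        raise ValueError(f"Invalid error mode: {error!r}")

    def report(message):
        if error == "raise":
            raise ValueError(message)
        if error == "warn":
            warnings.warn(message, UserWarning)

    if len(ys) % 2 != 0:
        report("Expected an even number of inputs.")
        ys = ys[:-1]  # assumption: drop the unpaired trailing input

    y_trues, y_preds = [], []
    # Even positions hold y_true, odd positions hold the matching y_pred.
    for idx, (y_true, y_pred) in enumerate(zip(ys[0::2], ys[1::2])):
        if len(y_true) != len(y_pred):
            report(f"Length mismatch in pair {idx}: {len(y_true)} vs {len(y_pred)}")
            continue  # skip the invalid pair
        y_trues.append(y_true)
        y_preds.append(y_pred)
    return y_trues, y_preds

trues, preds = split_pairs_sketch([1.2, 2.3], [1.1, 2.4], [4.5], [4.4])
skipped = split_pairs_sketch([1, 2, 3], [1, 2], error="ignore")
```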
- geoprior.utils.validator.to_dtype_str(arr, return_values=False)[source]
Convert numeric or object dtype to string dtype.
This avoids a TypeError that occurs when an array contains both np.nan and string values. Converting the array to dtype str, rather than keeping it as ‘object’, prevents the error.
- Parameters:
arr (array-like) – Array with any NumPy or pandas dtype.
return_values (bool, default False) – If True, returns the array values in string dtype. This can be useful when a Series with object or numeric dtype is passed.
- Returns:
Array-like with dtype str. If a DataFrame or Series is passed, the object dtype changes only if return_values is set to True; otherwise the same object is returned.
- geoprior.utils.validator.validate_and_adjust_ranges(**kwargs)[source]
Validates and adjusts the provided range tuples to ensure each is composed of two numerical values and is sorted in ascending order.
This function takes multiple range specifications as keyword arguments, each expected to be a tuple of two numerical values (min, max). It validates the format and contents of each range, adjusting them if necessary to ensure that each tuple is ordered as (min, max).
- Parameters:
**kwargs (dict) – Keyword arguments where each key is the name of a range (e.g., ‘lat_range’) and its corresponding value is a tuple of two numerical values representing the minimum and maximum of that range.
- Returns:
A dictionary with the same keys as the input, but with each tuple value adjusted to ensure it is in the format (min, max).
- Return type:
dict
- Raises:
ValueError – If any provided range tuple does not contain exactly two values, contains non-numerical values, or if the min value is not less than the max value.
Examples
>>> from geoprior.utils.validator import validate_and_adjust_ranges
>>> validate_and_adjust_ranges(lat_range=(34.00, 36.00), lon_range=(-118.50, -117.00))
{'lat_range': (34.00, 36.00), 'lon_range': (-118.50, -117.00)}
>>> validate_and_adjust_ranges(time_range=(10.0, 0.01))
{'time_range': (0.01, 10.0)}
>>> validate_and_adjust_ranges(invalid_range=(1, 'a'))
ValueError: invalid_range must contain numerical values.
Notes
This function is particularly useful for preprocessing input ranges for various analyses, ensuring consistency and correctness of range specifications. It automates the adjustment of provided ranges, simplifying the setup process for further data processing or modeling tasks.
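A minimal sketch of the validate-then-sort behavior shown in the examples follows. `adjust_ranges_sketch` is a hypothetical name; rejecting booleans as non-numeric is an assumption made for the sketch.

```python
from numbers import Real

def adjust_ranges_sketch(**ranges):
    """Hypothetical helper: check each range and order it as (min, max)."""
    adjusted = {}
    for name, rng in ranges.items():
        if not isinstance(rng, (tuple, list)) or len(rng) != 2:
            raise ValueError(f"{name} must contain exactly two values.")
        # Assumption: bool is excluded even though it subclasses int.
        if not all(isinstance(v, Real) and not isinstance(v, bool) for v in rng):
            raise ValueError(f"{name} must contain numerical values.")
        adjusted[name] = tuple(sorted(rng))
    return adjusted

result = adjust_ranges_sketch(
    lat_range=(34.0, 36.0),
    time_range=(10.0, 0.01),  # reversed on purpose
)
```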
- geoprior.utils.validator.validate_batch_size(batch_size, n_samples, min_batch_size=1, max_batch_size=None)[source]
Validate the batch size against the number of samples.
This function checks whether the provided batch_size is appropriate given the total number of samples n_samples. It ensures that the batch size meets specified minimum and maximum limits, raising appropriate errors if any constraints are violated.
- Parameters:
batch_size (int) – The size of each batch. This must be a positive integer, as batches must contain at least one sample. A ValueError will be raised if this value is less than the minimum allowed batch size or exceeds the total number of samples.
n_samples (int) – The total number of samples in the dataset. This value must be positive and greater than or equal to the batch_size. If batch_size is greater than n_samples, a ValueError is raised.
min_batch_size (int, optional) – The minimum allowed batch size (default is 1). This parameter defines the smallest permissible batch size. A ValueError will be raised if the batch_size is less than this value.
max_batch_size (int, optional) – The maximum allowed batch size (default is None, meaning no upper limit). This parameter can be used to restrict the size of the batch to a specified maximum value. If max_batch_size is provided, a ValueError will be raised if the batch_size exceeds this limit.
- Returns:
batch_size – The validated batch size.
- Return type:
int
- Raises:
ValueError – If the batch_size is less than the min_batch_size, greater than the n_samples, or exceeds the max_batch_size if specified. Additionally, if batch_size is not a positive integer, a ValueError is raised.
Notes
Let B represent the batch_size and N represent the n_samples. The validation can be expressed mathematically as:
(42)
\[
\text{If } B < \text{min\_batch\_size} \text{ or } B > N \text{ or } B > \text{max\_batch\_size}: \quad \text{raise ValueError}
\]
This function is essential for managing data batching in machine learning workflows, where improper batch sizes can lead to inefficient training or runtime errors. The practical mini-batch constraint follows standard deep-learning training guidance [48].
Examples
>>> from geoprior.utils.validator import validate_batch_size
>>> validate_batch_size(32, 100)  # Valid case
>>> validate_batch_size(0, 100)  # Raises ValueError
>>> validate_batch_size(150, 100)  # Raises ValueError
>>> validate_batch_size(32, 100, max_batch_size=32)  # Valid case
>>> validate_batch_size(40, 100, max_batch_size=32)  # Raises ValueError
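The constraint on B and N stated in the Notes translates directly into a few comparisons. The sketch below uses the hypothetical name `check_batch_size` and illustrative error messages; it mirrors the documented rules rather than the library source.

```python
def check_batch_size(batch_size, n_samples, min_batch_size=1, max_batch_size=None):
    """Hypothetical check mirroring the documented batch-size rules."""
    # bool is a subclass of int, so it is rejected explicitly.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")
    if batch_size < min_batch_size:
        raise ValueError(f"batch_size {batch_size} is below the minimum {min_batch_size}.")
    if batch_size > n_samples:
        raise ValueError(f"batch_size {batch_size} exceeds n_samples {n_samples}.")
    if max_batch_size is not None and batch_size > max_batch_size:
        raise ValueError(f"batch_size {batch_size} exceeds the maximum {max_batch_size}.")
    return batch_size

accepted = check_batch_size(32, 100, max_batch_size=32)

rejected = []
for candidate in (0, 150, 40):
    try:
        check_batch_size(candidate, 100, max_batch_size=32)
    except ValueError:
        rejected.append(candidate)
```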
- geoprior.utils.validator.validate_comparison_data(df, alignment='auto')[source]
Validates a DataFrame to ensure it is a square matrix and that the index and column names match. Optionally aligns the index names to the column names or vice versa based on the alignment parameter.
- Parameters:
df (pandas.DataFrame) – The DataFrame to validate.
alignment (str, default 'auto') – Controls how the DataFrame’s index and columns are aligned if they do not match. Options are ‘auto’, ‘index_to_columns’, and ‘columns_to_index’.
- Returns:
The validated and potentially modified DataFrame.
- Return type:
pandas.DataFrame
- Raises:
ValueError – If the DataFrame is not square or if index and column names do not match and no suitable alignment option is specified.
Examples
>>> import pandas as pd
>>> from geoprior.utils.validator import validate_comparison_data
>>> data = pd.DataFrame({
...     'A': [1, 0.9, 0.8],
...     'B': [0.9, 1, 0.85],
...     'C': [0.8, 0.85, 1]
... }, index=['A', 'B', 'X'])
>>> print(validate_comparison_data(data, alignment='index_to_columns'))
>>> data = pd.DataFrame({
...     1: [1, 0.9, 0.8],
...     2: [0.9, 1, 0.85],
...     3: [0.8, 0.85, 1]
... }, index=[1, 2, 'X'])
>>> print(validate_comparison_data(data, alignment='auto'))
- geoprior.utils.validator.validate_data_types(data, expected_type='numeric', nan_policy='omit', return_data=False, error='raise')[source]
Checks for mixed data types in a pandas Series or DataFrame and handles according to the specified policies. This function is designed to ensure data consistency by verifying that data matches expected type criteria, offering options to manage and report any discrepancies.
- Parameters:
data (pd.Series or pd.DataFrame) – The data to be checked. This can be a pandas Series or DataFrame.
expected_type ({'numeric', 'categoric', 'both'}, default 'numeric') – Specifies the type of data expected:
’numeric’: All data should be of numeric types (int, float).
’categoric’: All data should be categorical, typically strings or pandas Categorical datatype.
’both’: Any mix of numeric and categorical data is considered valid.
nan_policy ({'raise', 'omit', 'propagate'}, default 'omit') – Determines how NaN values are handled:
’raise’: Raises an error if NaN values are found.
’warn’: Issues a warning if NaN values are found but proceeds.
’propagate’: Continues execution without addressing NaNs.
return_data (bool, default False) – If True, returns a DataFrame or Series (depending on the input) that only includes data rows that conform to the expected_type. If False, returns None.
error ({'raise', 'warn'}, default 'raise') – Configures the error handling behavior when data types do not conform to the expected_type:
’raise’: Raises a TypeError if mixed types are detected.
’warn’: Emits a warning but attempts to continue by filtering non-conforming data if return_data is True.
- Returns:
Depending on return_data, this function may return a filtered version of data that conforms to the expected_type or None if return_data is False.
- Return type:
pd.Series or pd.DataFrame or None
- Raises:
ValueError – If NaN values are present and nan_policy is set to ‘raise’.
TypeError – If data types do not conform to expected_type and error is set to ‘raise’.
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import validate_data_types
>>> df = pd.DataFrame({'A': [1, 2, 'a', 3.5, np.nan],
...                    'B': ['x', 'y', 'z', None, 't']})
>>> validate_data_types(df, expected_type='numeric', nan_policy='warn',
...                     return_data=True, error='warn')
UserWarning: NaN values found in the data, but processing will continue.
UserWarning: Expected numeric types but found mixed types. Non-numeric data will be ignored.
     A
0  1.0
1  2.0
3  3.5
Notes
The validate_data_types function is useful in data preprocessing steps, particularly when you need to ensure that data fed into a machine learning algorithm meets certain type requirements. Handling mixed data types early on can prevent issues in model training and evaluation.
- geoprior.utils.validator.validate_dates(start_date, end_date, return_as_date_str=False, date_format='%Y-%m-%d')[source]
Validates and parses start and end years/dates, with options for output formatting.
This function ensures the validity of provided start and end years or dates, checks if they fall within a reasonable range, and allows the option to return the validated years or dates in a specified string format.
- Parameters:
start_date (int, float, or str) – The starting year or date. Can be an integer, float (converted to integer), or string in “YYYY” or “YYYY-MM-DD” format.
end_date (int, float, or str) – The ending year or date, with the same format options as start_date.
return_as_date_str (bool, optional) – If True, returns the start and end dates as strings in the specified format. Default is False, returning years as integers.
date_format (str, optional) – The format string for output dates if return_as_date_str is True. Default format is “%Y-%m-%d”.
- Returns:
A tuple of two elements, either integers (years) or strings (formatted dates), representing the validated start and end years or dates.
- Return type:
tuple
- Raises:
ValueError – If the input years or dates are invalid, out of the acceptable range, or if the start year/date does not precede the end year/date.
Examples
>>> from geoprior.utils.validator import validate_dates
>>> validate_dates(1999, 2001)
(1999, 2001)
>>> validate_dates("1999/01/01", "2001/12/31", return_as_date_str=True)
('1999-01-01', '2001-12-31')
>>> validate_dates("1999", "1998")
ValueError: The start date/time must precede the end date/time.
>>> validate_dates("1899", "2001")
ValueError: Years must be within the valid range: 1900 to [current year].
Notes
The function supports flexible input formats for years and dates, including handling both slash “/” and dash “-” separators in date strings. It enforces logical and chronological order between start and end inputs and allows customization of the output format for date strings.
- geoprior.utils.validator.validate_distribution(distribution, elements=None, kind=None, check_normalization=True)[source]
Validates or generates distributions for given elements, ensuring the sum equals 1 if check_normalization is True.
- Parameters:
distribution (str, tuple, list) – The distribution to be validated or generated. If ‘auto’, generates a random distribution for the specified number of elements. Can also be a tuple or list representing an explicit distribution.
elements (int, list of str, optional) – Defines how many elements the distribution should be generated for when ‘auto’ is used. If a list of strings is provided, its length is used to determine the number of elements.
kind (str, optional) – Specifies the kind of distribution. It can be {"probs"} for probability distributions, where the sum should equal 1 and values must be non-negative.
check_normalization (bool, optional) – If True, ensures that the sum of the distribution equals 1. Default is True.
- Returns:
A tuple representing the validated or generated distribution.
- Return type:
tuple
- Raises:
ValueError – If the provided distribution does not meet the specified conditions.
Examples
>>> from geoprior.utils.validator import validate_distribution
>>> validate_distribution("auto", elements=['positive', 'neutral', 'negative'])
(0.1450318690603951, 0.5660028611331361, 0.2889652698064687)
- geoprior.utils.validator.validate_dtype_selector(dtype_selector)[source]
Validates and categorizes the dtype_selector using regex, including handling cases where ‘only’ is specifically included.
- Parameters:
dtype_selector (str) – Input dtype selector string.
- Returns:
Categorized dtype_selector based on predefined patterns. If "only" is included, the returned category reflects this so it can drive specific data-type handling.
- Return type:
- Raises:
ValueError – If the input dtype_selector does not match any predefined category.
- geoprior.utils.validator.validate_estimator_methods(estimator, methods, msg=None)[source]
Validate that the specified methods exist and are callable on the given estimator.
This utility function is designed to check whether an estimator (or any object) contains the required methods, such as fit or predict, and ensures that those methods are callable. It helps prevent runtime errors by verifying the presence of expected methods.
- Parameters:
estimator (object) – The object (instance or class) to check for the presence of the specified methods. The estimator can be an instance of a class or the class itself, and it should implement the required methods.
methods (list of str) – List of method names (as strings) to validate. Each method name must exist on the estimator and be callable. Examples of methods might include fit, run, or predict.
msg (str, optional) – Custom error message to display if any method is missing or not callable. If None, a default message is generated for each missing or invalid method based on the method name.
- Raises:
AttributeError – If any method in methods is not present or not callable on the estimator.
Examples
>>> from geoprior.utils.validator import validate_estimator_methods
>>> class MyClass:
...     def fit(self):
...         pass
...     def run(self):
...         pass
>>> validate_estimator_methods(MyClass(), ['fit', 'run'])  # No error
>>> class IncompleteClass:
...     def fit(self):
...         pass
>>> validate_estimator_methods(IncompleteClass(), ['fit', 'run'])
# Raises AttributeError for missing `run` method
Notes
This helper is useful when you want to ensure that an object, such as an estimator or a model, exposes several callable methods before proceeding. If any method is missing or not callable, the function raises an AttributeError. Method-callability checks follow the Python documentation and the callable-object discussion in [43, 49].
See also
check_has_run_method: Validate the presence of a single method, defaulting to run.
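The presence-and-callability check described above can be expressed with the built-ins `getattr` and `callable`. This is a minimal sketch under the hypothetical name `require_methods`; the default error message is illustrative only.

```python
def require_methods(obj, methods, msg=None):
    """Hypothetical check that each named attribute exists and is callable."""
    for name in methods:
        # getattr with a default avoids raising for a missing attribute;
        # callable(None) is False, so missing and non-callable are handled alike.
        if not callable(getattr(obj, name, None)):
            owner = getattr(obj, "__name__", type(obj).__name__)
            raise AttributeError(msg or f"{owner} has no callable method {name!r}.")

class Complete:
    def fit(self):
        pass

    def run(self):
        pass

class Incomplete:
    def fit(self):
        pass

require_methods(Complete(), ["fit", "run"])  # passes silently

try:
    require_methods(Incomplete(), ["fit", "run"])
except AttributeError as exc:
    missing_message = str(exc)
```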
- geoprior.utils.validator.validate_fit_weights(y, sample_weight=None, weighted_y=False)[source]
Validate and compute sample weights for fitting.
- Parameters:
- Returns:
sample_weight (array-like of shape (n_samples,)) – Validated sample weights.
weighted_y_values (array-like of shape (n_samples,), optional) – Weighted target values if weighted_y is True.
- Raises:
ValueError – If sample_weight is not None and its length does not match the length of y. If any value in sample_weight is negative.
Notes
This function checks the input sample weights, ensuring they are consistent with the target values y. If sample_weight is None, it returns an array of ones indicating equal weighting. Otherwise, it validates and returns the given sample weights. If weighted_y is True, it also computes and returns the weighted target values.
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import validate_fit_weights
>>> y = np.array([0, 1, 1, 0, 1])
>>> validate_fit_weights(y)
array([1., 1., 1., 1., 1.])
>>> sample_weight = np.array([1, 0.5, 1, 1.5, 1])
>>> validate_fit_weights(y, sample_weight)
array([1. , 0.5, 1. , 1.5, 1. ])
>>> validate_fit_weights(y, sample_weight, weighted_y=True)
(array([1. , 0.5, 1. , 1.5, 1. ]), array([0. , 0.5, 1. , 0. , 1. ]))
>>> validate_fit_weights(y, weighted_y=True)
(array([1., 1., 1., 1., 1.]), array([0., 1., 1., 0., 1.]))
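The behavior described in the Notes (default to unit weights, validate length and sign, optionally return y multiplied by the weights) can be sketched as follows. `fit_weights_sketch` is a hypothetical name and the sketch assumes one-dimensional inputs.

```python
import numpy as np

def fit_weights_sketch(y, sample_weight=None, weighted_y=False):
    """Hypothetical re-statement of the documented weight checks."""
    y = np.asarray(y, dtype=float)
    if sample_weight is None:
        sample_weight = np.ones_like(y)  # equal weighting
    else:
        sample_weight = np.asarray(sample_weight, dtype=float)
        if sample_weight.shape[0] != y.shape[0]:
            raise ValueError("sample_weight and y must have the same length.")
        if np.any(sample_weight < 0):
            raise ValueError("sample_weight must not contain negative values.")
    if weighted_y:
        return sample_weight, y * sample_weight
    return sample_weight

y = np.array([0, 1, 1, 0, 1])
default_weights = fit_weights_sketch(y)
weights, weighted = fit_weights_sketch(y, [1, 0.5, 1, 1.5, 1], weighted_y=True)
```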
- geoprior.utils.validator.validate_length_range(length_range, sorted_values=True, param_name=None)[source]
Validates the review length range ensuring it’s a tuple with two integers where the first value is less than the second.
- Parameters:
length_range (tuple) – A tuple containing two values that represent the minimum and maximum lengths of reviews.
sorted_values (bool, default True) – If True, the function expects the input length range to be sorted in ascending order and will automatically sort it if not. If False, the input length range is not expected to be sorted, and it will remain as provided.
param_name (str, optional) – The name of the parameter being validated. If None, the default name ‘length_range’ will be used in error messages.
- Returns:
The validated length range.
- Return type:
tuple
- Raises:
ValueError – If the length range does not meet the requirements.
Examples
>>> from geoprior.utils.validator import validate_length_range
>>> validate_length_range((202, 25))
(25, 202)
>>> validate_length_range((202,))
ValueError: length_range must be a tuple with two elements.
- geoprior.utils.validator.validate_multiclass_target(y, accept_multioutput=False, return_classes=False)[source]
Validates that the target data is suitable for multiclass classification. Optionally accepts multi-output targets and can return the unique classes.
- Parameters:
y (array-like) – The target data to be validated, expected to contain class labels for multiclass classification. Can be a multi-output array if accept_multioutput is set to True.
accept_multioutput (bool, optional) – Allows the target array to be multi-dimensional (default is False).
return_classes (bool, optional) – If True, returns the unique classes instead of a validation boolean.
- Returns:
If return_classes is False, returns True if the target data is valid for multiclass classification, otherwise raises a ValueError. If return_classes is True, returns the unique classes in the target data.
- Return type:
bool or array
- Raises:
ValueError – If any of the following conditions are not met:
If accept_multioutput is False, the target data must be one-dimensional.
All elements in the target array must be non-negative integers.
The target array must contain at least two distinct classes.
Examples
>>> from geoprior.utils.validator import validate_multiclass_target
>>> validate_multiclass_target([0, 1, 2, 1, 0])
array([0, 1, 2, 1, 0])
>>> validate_multiclass_target([0, 0, 0])
ValueError: Target array must contain at least two distinct classes.
>>> validate_multiclass_target([0.5, 1.2, 2.3])
ValueError: All elements in the target array must be non-negative integers.
>>> validate_multiclass_target([[1, 2], [2, 3]], accept_multioutput=True,
...                            return_classes=True)
(array([1, 2, 2, 3]), 3)
True
- geoprior.utils.validator.validate_multioutput(value, extra='')[source]
Validate the multioutput parameter value and handle special cases.
This function checks if the provided multioutput value is one of the accepted strings (‘raw_values’, ‘uniform_average’, ‘raise’, ‘warn’). It warns or raises an error based on the value if it’s applicable.
- Parameters:
- Returns:
The validated multioutput value in lowercase if it’s one of the accepted values. If the value is ‘warn’ or ‘raise’, the function handles the case accordingly without returning a value.
- Return type:
- Raises:
ValueError – If value is not one of the accepted strings and is not ‘raise’.
Examples
>>> from geoprior.utils.validator import validate_multioutput
>>> validate_multioutput('raw_values')
'raw_values'
>>> validate_multioutput('warn', extra=' for Dice Similarity Coefficient')
# This will warn that multioutput parameter is not applicable for Dice
# Similarity Coefficient.
>>> validate_multioutput('raise', extra=' for Gini Coefficient')
# This will raise a ValueError indicating that multioutput parameter
# is not applicable for Gini Coefficient.
>>> validate_multioutput('average')
# This will raise a ValueError indicating 'average' is an invalid value
# for multioutput parameter.
Note
The function is designed to ensure API consistency across various metrics functions by providing a standard way to handle multioutput parameter values, especially in contexts where multiple outputs are not applicable.
- geoprior.utils.validator.validate_nan_policy(nan_policy, *arrays, sample_weights=None)[source]
Validates and applies a specified nan_policy to input arrays and optionally to sample weights. This utility is essential for pre-processing data prior to statistical analyses or model training, where appropriate handling of NaN values is critical to ensure accurate and reliable outcomes.
- Parameters:
nan_policy ({'propagate', 'raise', 'omit'}) – Defines how to handle NaNs in the input arrays. ‘propagate’ returns the input data without changes. ‘raise’ throws an error if NaNs are detected. ‘omit’ removes rows with NaNs across all input arrays and sample weights.
*arrays (array-like) – Variable number of input arrays to be validated and adjusted based on the specified nan_policy.
sample_weights (array-like, optional) – Sample weights array to be validated and adjusted in tandem with the input arrays according to nan_policy. Defaults to None.
- Returns:
arrays (tuple of np.ndarray) – Adjusted input arrays, with modifications applied based on nan_policy. The order of arrays in the tuple corresponds to the order of input.
sample_weights (np.ndarray or None) – Adjusted sample weights, modified according to nan_policy if provided. Returns None if no sample_weights were provided.
- Raises:
ValueError – If nan_policy is not among the valid options (‘propagate’, ‘raise’, ‘omit’) or if NaNs are detected when nan_policy is set to ‘raise’.
Notes
Handling NaN values is a critical step in data preprocessing, especially in datasets with missing values. The choice of nan_policy can significantly impact subsequent statistical analysis or predictive modeling by either including, excluding, or signaling errors for observations with missing values. This function ensures consistent application of the chosen policy across multiple datasets, facilitating robust and error-free analyses.
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import validate_nan_policy
>>> y_true = np.array([1, np.nan, 3])
>>> y_pred = np.array([1, 2, 3])
>>> sample_weights = np.array([0.5, 0.5, 1.0])
>>> arrays, sw = validate_nan_policy('omit', y_true, y_pred,
...                                  sample_weights=sample_weights)
>>> arrays
(array([1., 3.]), array([1., 3.]))
>>> sw
array([0.5, 1. ])
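The ‘omit’ policy amounts to building one shared boolean mask and applying it to every array, so that rows stay aligned. The sketch below uses the hypothetical name `nan_policy_sketch` and assumes at least one one-dimensional input, with all inputs of equal length.

```python
import numpy as np

def nan_policy_sketch(nan_policy, *arrays, sample_weights=None):
    """Hypothetical illustration of 'propagate', 'raise', and 'omit'."""
    if nan_policy not in ("propagate", "raise", "omit"):
        raise ValueError(f"Invalid nan_policy: {nan_policy!r}")
    arrays = tuple(np.asarray(a, dtype=float) for a in arrays)
    if sample_weights is not None:
        sample_weights = np.asarray(sample_weights, dtype=float)
    if nan_policy == "propagate":
        return arrays, sample_weights

    # A row is flagged if any array (or the weights) has a NaN there.
    checked = arrays if sample_weights is None else arrays + (sample_weights,)
    nan_rows = np.zeros(arrays[0].shape[0], dtype=bool)
    for arr in checked:
        nan_rows |= np.isnan(arr)

    if nan_policy == "raise":
        if nan_rows.any():
            raise ValueError("NaN values detected.")
        return arrays, sample_weights

    keep = ~nan_rows  # 'omit': drop flagged rows everywhere
    arrays = tuple(arr[keep] for arr in arrays)
    if sample_weights is not None:
        sample_weights = sample_weights[keep]
    return arrays, sample_weights

kept, kept_weights = nan_policy_sketch(
    "omit", [1, np.nan, 3], [1, 2, 3], sample_weights=[0.5, 0.5, 1.0]
)
```

Using one mask for all inputs is what keeps y_true, y_pred, and the weights aligned after rows are removed.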
- geoprior.utils.validator.validate_numeric(value, convert_to='float', allow_negative=True, min_value=None, max_value=None, check_mode='soft')[source]
Validates if a given value is numeric. It can accept numeric strings and numpy arrays of single values. Optionally converts the value to either float or integer.
- Parameters:
value (Any) – The value to be validated as numeric. This can be of any type but is expected to be convertible to a numeric type. Accepted types include numeric strings (e.g., "42"), single-element numpy arrays (e.g., np.array([3.14])), integers, and floats.
convert_to (str, optional) – Type to convert the validated numeric value to. Use "float" for floating-point output or "int" for integer output. Defaults to "float".
allow_negative (bool, optional) – Whether to allow negative values. If False, negative values raise a ValueError. Defaults to True.
min_value (float or int, optional) – The minimum value allowed. If None, no minimum value check is applied. Defaults to None.
max_value (float or int, optional) – The maximum value allowed. If None, no maximum value check is applied. Defaults to None.
check_mode (str, optional) – Validation mode. Use "soft" to accept single-element iterables and validate their single value, or "strict" to accept only non-iterable numeric inputs. Defaults to "soft".
- Returns:
The validated and optionally converted numeric value. The type of the return value is determined by the convert_to parameter.
- Return type:
float or int
- Raises:
ValueError – If the value is not numeric or does not meet the specified criteria.
Notes
The function can coerce single-element NumPy arrays, numeric strings, and, in soft mode, single-element iterables before validating the result. The validated value is then converted to float or int and checked against the sign and range constraints. Array coercion details are documented in NumPy developers [50].
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import validate_numeric
>>> validate_numeric("42", convert_to='int')
42
>>> validate_numeric(np.array([3.14]), convert_to='float')
3.14
>>> validate_numeric([123], check_mode='soft')
123.0
>>> validate_numeric([123], check_mode='strict')
Traceback (most recent call last):
...
ValueError: Value '[123]' is not a numeric type.
>>> validate_numeric("-123.45", allow_negative=False)
Traceback (most recent call last):
...
ValueError: Negative values are not allowed: -123.45
See also
numpy.array: NumPy arrays, which can be validated by this function.
- geoprior.utils.validator.validate_performance_data(model_performance_data=None, nan_policy='raise', convert_integers=True, check_performance_range=True, verbose=False)[source]
Validates and preprocesses model performance data to ensure it conforms to the necessary structure and constraints for statistical and machine learning analysis. The function accepts either a dictionary or a DataFrame as input and performs the following tasks:
Converts data to a DataFrame if it is provided as a dictionary.
Converts integer values to floats, ensuring compatibility with statistical processing.
Manages NaN values according to the specified nan_policy.
Validates that performance data falls within a valid range, ensuring values lie within [0, 1].
The function is adaptable, capable of being used directly or as a decorator, with or without configuration parameters.
- Parameters:
model_performance_data (Union[Dict[str, List[float]], pd.DataFrame], optional) – The input model performance data to validate. Can be provided as either a dictionary (with model names as keys and performance metrics as lists) or a DataFrame where each column represents a model.
nan_policy (str, default 'raise') – The policy to handle NaN values:
‘raise’: Raises a ValueError if NaNs are detected.
‘omit’: Drops rows with NaNs.
‘propagate’: Ignores NaNs during performance range checks.
convert_integers (bool, default True) – Converts integer values within the data to floats if set to True, which is useful for consistency when computing metrics.
check_performance_range (bool, default True) – Ensures that performance values lie within the range [0, 1]. If any value falls outside this range, an error is raised unless nan_policy is set to ‘propagate’.
verbose (bool, default False) – If True, displays steps of the data validation process for tracking operations and debugging.
actual_validate_performance_data(data): Validates and processes the data according to specified policies and constraints.
Usage
This function can be utilized in three primary ways:
1. As a function: Provide data directly to perform validation.
>>> from geoprior.utils.validator import validate_performance_data
>>> data = {'model1': [0.85, 0.90, 0.92], 'model2': [0.80, 0.87, 0.88]}
>>> validate_performance_data(data)
2. As a decorator: Use as a decorator to validate the first argument of a function. If used without parentheses, default values will be applied.
>>> @validate_performance_data
... def process_data(validated_data):
...     print(validated_data)
3. As a decorator with parameters: Customize validation by specifying parameters.
>>> @validate_performance_data(nan_policy='omit', verbose=True)
... def process_data(validated_data):
...     print(validated_data)
Notes
The validation process includes statistical pre-checks, using custom modules to convert data and handle NaNs. For integer-to-float conversion, the convert_to_numeric function is utilized, while NaN policies are verified using is_valid_policies. The comparison framing for multiple models follows Demšar [51].
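The three calling styles described above can be supported by one entry point that inspects its first argument. The following is a minimal sketch of that pattern only, not the GeoPrior implementation: `check_performance` is a hypothetical stand-in that works on a plain dictionary of lists, and skipping NaNs during the range check is an assumed reading of the 'propagate' policy.

```python
import functools
import math


def check_performance(data=None, *, nan_policy="raise", check_range=True):
    """Hypothetical stand-in that works as a function or as a decorator."""

    def _validate(scores):
        names = list(scores)
        # Build rows of floats so that integers are converted on the way in.
        rows = list(zip(*([float(v) for v in scores[n]] for n in names)))
        has_nan = [any(math.isnan(v) for v in row) for row in rows]
        if any(has_nan) and nan_policy == "raise":
            raise ValueError("NaN values found in performance data")
        if nan_policy == "omit":
            rows = [row for row, bad in zip(rows, has_nan) if not bad]
        if check_range:
            for row in rows:
                for v in row:
                    # NaNs are skipped here (assumed meaning of 'propagate').
                    if not math.isnan(v) and not 0.0 <= v <= 1.0:
                        raise ValueError(f"score {v} is outside [0, 1]")
        return {n: [row[i] for row in rows] for i, n in enumerate(names)}

    def _decorate(func):
        @functools.wraps(func)
        def wrapper(first, *args, **kwargs):
            # Validate only the first positional argument.
            return func(_validate(first), *args, **kwargs)
        return wrapper

    if callable(data):      # bare decorator: @check_performance
        return _decorate(data)
    if data is None:        # configured decorator: @check_performance(...)
        return _decorate
    return _validate(data)  # plain call: check_performance(scores)


@check_performance
def best_model(scores):
    return max(scores, key=lambda name: sum(scores[name]))


@check_performance(nan_policy="omit")
def row_count(scores):
    return len(next(iter(scores.values())))
```

The dispatch relies on a dictionary not being callable, so the bare-decorator case and the direct-call case cannot be confused.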
See also
DataFrameFormatter: Formatter for handling DataFrame structures.
MultiFrameFormatter: Formatter for handling multiple DataFrames.
- geoprior.utils.validator.validate_positive_integer(value, variable_name, include_zero=False, round_float=None, msg=None)[source]
Validates whether the given value is a positive integer or zero based on the parameter and rounds float values according to the specified method.
- Parameters:
variable_name (str) – The name of the variable, used in error messages.
include_zero (bool, optional) – If True, zero is considered a valid value. Default is False.
round_float (str, optional) – If "ceil", rounds float values up; if "floor", rounds float values down; if None, truncates float values toward zero.
msg (str, optional) – Error message used when the type check fails.
- Returns:
The validated value converted to an integer.
- Return type:
- Raises:
ValueError – If the value is not a positive integer or zero (based on include_zero), or if the round_float parameter is improperly specified.
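The interaction between `include_zero` and `round_float` is easiest to see in code. This is a minimal sketch under stated assumptions: `to_positive_int` is a hypothetical name, and rejecting booleans and non-numeric input is an illustrative choice that the reference above does not specify.

```python
import math


def to_positive_int(value, include_zero=False, round_float=None):
    """Hypothetical sketch of the rounding and positivity rules."""
    if round_float not in (None, "ceil", "floor"):
        raise ValueError("round_float must be 'ceil', 'floor', or None")
    # Assumption: booleans and non-numeric values are rejected outright.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError("value must be a number")
    if isinstance(value, float):
        if round_float == "ceil":
            value = math.ceil(value)
        elif round_float == "floor":
            value = math.floor(value)
        else:
            value = math.trunc(value)  # truncate toward zero
    if value < 0 or (value == 0 and not include_zero):
        raise ValueError("value must be a positive integer")
    return int(value)
```

Note that rounding happens before the positivity check, so `0.4` is rejected by default but accepted as `0` when `include_zero=True`.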
- geoprior.utils.validator.validate_sample_weights(weights, y, normalize=False)[source]
Validates that the sample weights are suitable for use in calculations.
This function checks that the sample weights are non-negative and match the length of the target array y. It raises an error if any conditions are not met. If a single number is provided as weights, it will be converted into an array with repeated values matching the length of y.
- Parameters:
weights (array-like or number) – The sample weights to be validated. Each weight must be non-negative. A single number will be converted to an array with repeated values.
y (array-like) – The target array that the weights should correspond to. The length of weights must match the length of y.
normalize (bool, optional) – If True, weights will be normalized to sum to 1. Default is False.
- Returns:
The validated sample weights as a numpy array.
- Return type:
- Raises:
ValueError – If weights are not one-dimensional, if any weight is negative, or if the length of weights does not match the length of y.
Examples
>>> from geoprior.utils.validator import validate_sample_weights
>>> y = [0, 1, 2, 3]
>>> weights = [0.1, 0.2, 0.3, 0.4]
>>> validate_sample_weights(weights, y)
array([0.1, 0.2, 0.3, 0.4])
>>> weights = [-0.1, 0.2, 0.3, 0.4]
>>> validate_sample_weights(weights, y)
ValueError: Sample weights must be non-negative.
>>> weights = [0.1, 0.2, 0.3]
>>> validate_sample_weights(weights, y)
ValueError: Length of sample weights must match length of y.
- geoprior.utils.validator.validate_sets(data, mode='base', allow_empty=True, element_type=None, key_type=<class 'str'>)[source]
Validates whether the input data is a set in ‘base’ mode or a dictionary of sets in ‘deep’ mode. Provides additional parameters for flexibility and versatility. Returns the data if it passes validation.
- Parameters:
data (Union[set, Dict[str, set]]) – The input data to validate. It can be either a single set or a dictionary where keys are set names and values are sets.
base mode: a single set.
deep mode: a dictionary of sets.
mode (str, optional) – The mode in which to validate the data. Options are 'base' for a single set and 'deep' for a dictionary of sets. Default is 'base'.
allow_empty (bool, optional) – Whether to allow empty sets or dictionaries. Default is True.
element_type (type, optional) – The expected type of elements in the set(s). If provided, the function checks whether all elements are of this type. Default is None (no type check).
key_type (type, optional) – The expected type of keys in the dictionary when in 'deep' mode. Default is str.
- Returns:
The original data if it matches the specified mode and additional criteria. Raises ValueError if validation fails.
- Return type:
Union[set,Dict[str,set]]
Examples
>>> from geoprior.utils.validator import validate_sets
>>> validate_sets({1, 2, 3}, mode='base')
{1, 2, 3}
>>> validate_sets({"Set1": {1, 2, 3}, "Set2": {3, 4, 5}}, mode='deep')
{"Set1": {1, 2, 3}, "Set2": {3, 4, 5}}
>>> validate_sets({"Set1": {1, 2, 3}, "Set2": [3, 4, 5]}, mode='deep')
Traceback (most recent call last):
    ...
ValueError: Data validation failed: expected all values to be sets
>>> validate_sets(set(), mode='base', allow_empty=False)
Traceback (most recent call last):
    ...
ValueError: Data validation failed: empty set is not allowed
>>> validate_sets({"Set1": set()}, mode='deep', allow_empty=False)
Traceback (most recent call last):
    ...
ValueError: Data validation failed: empty dictionary is not allowed
>>> validate_sets({"Set1": {1, 2, 3}}, mode='deep', element_type=int)
{"Set1": {1, 2, 3}}
Notes
This function checks the type of the input data based on the specified mode. In ‘base’ mode, it ensures the data is a set. In ‘deep’ mode, it ensures the data is a dictionary where all values are sets. Additional parameters allow for checking if sets are empty, if elements are of a specific type, and if dictionary keys are of a specific type. The core type test used here is documented in Python Software Foundation [52].
See also
isinstance: Python built-in function to check an object's type.
- geoprior.utils.validator.validate_strategy(strategy=None, error='raise', ops='validate', rename_key=False, **kwargs)[source]
Validate and construct a strategy dictionary for imputing missing data.
This function processes the input strategy to ensure it conforms to the expected format for imputing missing values in numerical and categorical features. It provides flexibility in handling different strategies and error management, making it suitable for integration with scikit-learn's imputation tools.
- Parameters:
strategy (Optional[Union[str, Dict[str, str]]], default None) – Defines the imputation strategy for numerical and categorical features. A string is parsed into a dictionary with keys "numeric" and "categorical", a dictionary is used directly, and None selects the default strategy.
error (str, default 'raise') – Error handling behavior for invalid strategy tokens. Use "raise" to raise a ValueError, "warn" to emit a warning, or "ignore" to skip invalid tokens silently.
ops (str, default 'validate') – Operation mode of the validator. Use "passthrough" to return the input strategy unchanged when it is already a dictionary, "check_only" to validate without modifying it, or "validate" to validate and construct the strategy dictionary from the input.
rename_key (bool, default False) – If True, rename aliases such as "num", "numeric", or "numerical" to "numeric", and aliases such as "cat", "categorical", or "categoric" to "categorical". Other keys remain unchanged.
**kwargs – Additional keyword arguments for future extensions.
- Returns:
Returns the input strategy dictionary for ops='passthrough', True or False for ops='check_only', and the validated or modified strategy dictionary for ops='validate'.
- Return type:
Union[Dict[str, str], bool]
- Raises:
ValueError – If an invalid error or ops parameter is provided, or if the strategy tokens are invalid and error is set to 'raise'.
Notes
The function limits numerical strategies to "median" and "mean", while categorical strategies default to "constant". It also handles key aliasing to keep the returned dictionary consistent.
Examples
>>> from geoprior.utils.validator import validate_strategy
>>> validate_strategy('mean constant')
{'numeric': 'mean', 'categorical': 'constant'}
>>> validate_strategy({'num': 'mean', 'cat': 'constant'}, rename_key=True)
{'numeric': 'mean', 'categorical': 'constant'}
>>> validate_strategy('numeric categorical', ops='check_only')
False
>>> validate_strategy('invalid_strategy', error='warn')
{'numeric': 'median', 'categorical': 'constant'}
See also
sklearn.impute.SimpleImputer: Imputation transformer for completing missing values.
- geoprior.utils.validator.validate_scores(scores, true_labels=None, mode='strict', accept_multi_output=False)[source]
Validates that the scores represent valid probability distributions and checks consistency between scores and true labels in multi-output scenarios.
- Parameters:
scores (list or np.ndarray) – A list of np.ndarrays for multi-output probabilities, or a single np.ndarray for single-output probabilities. Each ndarray should contain probability distributions where each row sums to approximately 1 and has non-negative values.
true_labels (list or np.ndarray, optional) – The true labels corresponding to the scores. This parameter must be provided in multi-output scenarios to check the alignment of labels and scores. Each element or row in true_labels should correspond to the equivalent in scores.
mode (str, default "strict") – Validation mode for checking probability distributions. Use "strict" to require each row to sum to 1 within numerical tolerance, "soft" to require non-negative scores with totals no greater than 1, or "passthrough" to only check that each score lies in the interval [0, 1].
accept_multi_output (bool, default False) – Flag indicating whether scores with multiple outputs are accepted. If False and scores are provided as a list, a ValueError will be raised.
- Returns:
The validated scores as a NumPy array.
- Return type:
np.ndarray
- Raises:
ValueError – If multi-output scores are provided and not accepted. If there is a mismatch in the number of outputs between scores and true_labels. If scores or any subset of scores do not form valid probability distributions. If there is a mismatch in format expectations between scores and true_labels in terms of multi-output handling.
Notes
The function is designed to handle both single and multi-output probability distributions. For multi-output scenarios, both scores and true_labels should be lists of np.ndarrays. This function is particularly useful in scenarios involving machine learning models where output probabilities need to be validated before further processing or metrics calculations.
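The difference between the three `mode` values can be illustrated on plain nested lists. This is a sketch only: `check_rows` and its `tol` argument are hypothetical, and the tolerance value is an assumption, since the reference says only "within numerical tolerance".

```python
import math


def check_rows(rows, mode="strict", tol=1e-6):
    """Hypothetical sketch of the strict / soft / passthrough checks."""
    if mode not in ("strict", "soft", "passthrough"):
        raise ValueError(f"unknown mode {mode!r}")
    for row in rows:
        if any(p < 0 for p in row):
            raise ValueError("scores must be non-negative")
        total = sum(row)
        if mode == "strict" and not math.isclose(total, 1.0, abs_tol=tol):
            raise ValueError("each row must sum to 1")
        if mode == "soft" and total > 1.0 + tol:
            raise ValueError("row total must not exceed 1")
        if mode == "passthrough" and any(p > 1 for p in row):
            raise ValueError("each score must lie in [0, 1]")
    return rows
```

Under these rules a row such as `[0.2, 0.3]` fails in "strict" mode but passes in "soft" mode, while `[0.9, 0.9]` passes only in "passthrough" mode.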
Examples
>>> import numpy as np
>>> from geoprior.utils.validator import validate_scores
>>> scores_single = np.array([[0.1, 0.9], [0.8, 0.2]])
>>> print(validate_scores(scores_single))
[[0.1 0.9]
 [0.8 0.2]]
>>> scores_multi = [np.array([[0.1, 0.9]]), np.array([[0.8, 0.2]])]
>>> true_labels_multi = [np.array([1]), np.array([0])]
>>> print(validate_scores(scores_multi, true_labels_multi, accept_multi_output=True))
[array([[0.1, 0.9]]), array([[0.8, 0.2]])]
- geoprior.utils.validator.validate_square_matrix(data, align=False, align_mode='auto', message='')[source]
Validate that the input data forms a square matrix and optionally aligns its indices and columns if specified.
- Parameters:
data (DataFrame or array-like) – The input data to validate as a square matrix.
align (bool, default False) – Whether to align the DataFrame's index with its columns.
align_mode (str, default 'auto') – Alignment mode if indices and columns do not match. Options are 'auto', 'index_to_columns', and 'columns_to_index'.
message (str, default '') – Additional message to append to the error if validation fails.
- Returns:
The validated or aligned square matrix.
- Return type:
data
- Raises:
ValueError – If the input is not a square matrix.
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from geoprior.utils.validator import validate_square_matrix
>>> validate_square_matrix(np.array([[1, 2], [3, 4]]))
array([[1, 2],
       [3, 4]])
>>> validate_square_matrix(pd.DataFrame([[1, 2], [3, 4, 5]]))
ValueError: Input must be a square matrix.
Notes
A square matrix is defined as having an equal number of rows and columns. This function checks the dimensionality of the data and optionally aligns the index and columns if align is set to True.
- geoprior.utils.validator.validate_weights(weights, min_value=None, max_value=None, normalize=False, allowed_dims=1)[source]
Validates and optionally normalizes the given weights array to ensure all elements meet specified criteria and the structure is suitable for computations.
- Parameters:
weights (array-like) – Weights to be validated. Can be a list, tuple, or numpy array.
min_value (float, optional) – Minimum allowable value for weights (inclusive). If None, weights are expected to be non-negative. Explicitly set to a negative value if negative weights are allowed.
max_value (float or None, optional) – Maximum allowable value for weights (inclusive). If None, no upper limit is enforced.
normalize (bool, optional) – If True, weights will be normalized to sum to 1. Default is False.
allowed_dims (int or tuple, optional) – Specifies the allowed dimensions of the weights array. Default is 1 (one-dimensional). If a tuple is provided, weights must match one of the dimensions specified in the tuple.
- Returns:
A numpy array of the validated and optionally normalized weights.
- Return type:
np.ndarray
- Raises:
ValueError – If weights contain values outside the specified range, or if the format or dimensions are not suitable.
Examples
>>> from geoprior.utils.validator import validate_weights
>>> validate_weights([0.25, 0.75, 0.5], normalize=True)
array([0.16666667, 0.5       , 0.33333333])
>>> validate_weights([-0.1, 0.9], min_value=0)
ValueError: Weights must be non-negative.
>>> validate_weights([0.1, 0.2, 0.7], max_value=0.5)
ValueError: Weights must not exceed 0.5.
>>> validate_weights([1, 2, 3], allowed_dims=(1, 2))
ValueError: Weights dimensions not allowed.
- geoprior.utils.validator.validate_yy(y_true, y_pred, expected_type=None, *, validation_mode='strict', flatten=False)[source]
Validates the shapes and types of actual and predicted target arrays, ensuring they are compatible for further analysis or metrics calculation.
- Parameters:
y_true (array-like) – True target values.
y_pred (array-like) – Predicted target values.
expected_type (str, optional) – The expected sklearn type of the target ('binary', 'multiclass', etc.).
validation_mode (str, optional) – Validation strictness. Currently, only 'strict' is implemented, which requires y_true and y_pred to have the same shape and match the expected_type.
flatten (bool, optional) – If True, both y_true and y_pred are flattened to one-dimensional arrays.
- Raises:
ValueError – If y_true and y_pred do not meet the validation criteria.
- Returns:
The validated y_true and y_pred arrays, potentially flattened.
- Return type:
These modules form the lighter infrastructure layer behind the more visible staged helpers.
NAT/workflow contract helpers#
Public exports for NAT workflow utilities.
- geoprior.utils.nat_utils.build_censor_mask(xb, H, idx, thresh=0.5, *, source='dynamic', reduce_time='any', align='broadcast')[source]
Build a censor mask aligned to the forecast horizon: (B, H, 1).
- Parameters:
source ({"dynamic", "future"}, default "dynamic") – Selects where the censoring flag is read from. "dynamic" reads xb["dynamic_features"][:, :, idx] from the history window, while "future" reads xb["future_features"][:, :, idx] from the forecast window.
reduce_time ({"any", "last", "all"}, default "any") – Reduction applied when source="dynamic" and the censor flag behaves like a per-sample label. "any" marks the sample as censored if any history step is flagged, "last" uses only the last history step, and "all" requires every history step to be flagged.
align ({"broadcast", "crop", "pad_false", "pad_edge", "error"}, default "broadcast") – Policy used when the time axis does not already match the forecast horizon H. "broadcast" repeats a single-step label across all horizon steps, "crop" keeps the last H steps, "pad_false" pads missing steps with False, "pad_edge" repeats the last available step, and "error" raises on mismatch.
xb (dict)
idx (int | None)
thresh (float)
- Return type:
Tensor
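The reduce_time and align policies can be sketched for a single sample using plain Python lists of flags. Both function names below are hypothetical, and raising whenever a policy cannot resolve the mismatch (for example "pad_false" with more steps than the horizon) is an assumption that the reference does not state.

```python
def reduce_history(flags, reduce_time="any"):
    """Collapse per-step history flags into one per-sample label."""
    if reduce_time == "any":
        return any(flags)
    if reduce_time == "last":
        return bool(flags[-1])
    if reduce_time == "all":
        return all(flags)
    raise ValueError(f"unknown reduce_time {reduce_time!r}")


def align_to_horizon(flags, horizon, align="broadcast"):
    """Fit a list of flags to exactly `horizon` steps."""
    flags = [bool(f) for f in flags]
    steps = len(flags)
    if steps == horizon:
        return flags
    if align == "broadcast" and steps == 1:
        return flags * horizon          # repeat the single label
    if align == "crop" and steps > horizon:
        return flags[-horizon:]         # keep the last `horizon` steps
    if align == "pad_false" and steps < horizon:
        return flags + [False] * (horizon - steps)
    if align == "pad_edge" and 0 < steps < horizon:
        return flags + [flags[-1]] * (horizon - steps)
    # Covers align="error" and any case the chosen policy cannot resolve.
    raise ValueError(f"cannot align {steps} steps to {horizon} with {align!r}")
```

With the defaults, a history flagged at any step yields `reduce_history(...) == True`, and `align_to_horizon([True], 3)` repeats that label across a three-step horizon.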
- geoprior.utils.nat_utils.ensure_input_shapes(x, mode, forecast_horizon)[source]
Ensure presence of zero-width static/future placeholders.
Stage-1 exporters sometimes omit static_features or future_features when there are no static/future variables for a particular experiment. Keras, however, expects these inputs to exist so that the input signature remains stable.
This helper:
Copies the input dict to avoid in-place modification.
Ensures static_features is an array of shape (N, 0) if missing.
Ensures future_features is an array of shape (N, T_future, 0) if missing, where:
T_future = dynamic_features.shape[1] when mode == "tft_like" (past+future style).
Otherwise, T_future = forecast_horizon.
- Parameters:
- Returns:
Shallow copy of x with guaranteed static_features and future_features entries.
- Return type:
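The placeholder rule can be checked without any arrays by working on shape tuples. The sketch below is illustrative only: `with_placeholder_shapes` is a hypothetical name, and it computes the shapes the zero-width arrays would have rather than allocating them.

```python
def with_placeholder_shapes(shapes, mode, forecast_horizon):
    """Return a copy of `shapes` with zero-width static/future entries added."""
    out = dict(shapes)  # shallow copy, the caller's mapping is untouched
    n_samples, t_past = shapes["dynamic_features"][:2]
    out.setdefault("static_features", (n_samples, 0))
    # tft_like models reuse the history length for the future window.
    t_future = t_past if mode == "tft_like" else forecast_horizon
    out.setdefault("future_features", (n_samples, t_future, 0))
    return out
```

For a dynamic tensor of shape `(8, 5, 3)` and a horizon of 2, the future placeholder is `(8, 5, 0)` in "tft_like" mode and `(8, 2, 0)` otherwise.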
- geoprior.utils.nat_utils.extract_preds(model, out, *, strict=True, output_names=None)[source]
Extract (subs_pred, gwl_pred) from GeoPrior outputs.
- Supports:
v3.2+ call(): {"subs_pred", "gwl_pred"}
forward_with_aux(): (y_pred, aux)
legacy: {"data_final"} + model.split_data_predictions
predict(): list/tuple mapped via output names
If strict=True, list/tuple outputs must be mappable via output names; otherwise we raise to avoid silent swaps.
This helper normalizes the output interface across two GeoPrior generation families:
New interface (preferred): model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}
Legacy interface (backward compatible): model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.
- Parameters:
model (object) – A Keras-like model instance that may expose split_data_predictions(data_final). The splitter must return a tuple:
subs_pred with shape (B, H, 1) or (B, H, Q, 1)
gwl_pred with shape (B, H, 1) or (B, H, Q, 1)
out (dict) – Output returned by the model call, typically model(inputs, training=False). Supported keys are either:
{"subs_pred", "gwl_pred"} (new interface), or
{"data_final"} (legacy interface).
strict (bool)
- Returns:
subs_pred (Tensor) – Predicted subsidence in model space. Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
gwl_pred (Tensor) – Predicted groundwater/head variable in model space. Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
- Raises:
- Return type:
Notes
This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.
The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].
Examples
New interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(model_inf, out)
Legacy interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(model_inf, out)
See also
subs_point_from_out: Convert subsidence predictions to a point forecast.
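The dispatch that `extract_preds` is documented to perform can be sketched with plain Python objects. This is not the library code: `split_outputs` and `LegacyModel` are hypothetical, strings stand in for tensors, and the `forward_with_aux()` tuple form is deliberately left out because its exact layout is not described here.

```python
def split_outputs(model, out, *, strict=True, output_names=None):
    """Hypothetical sketch: return (subs_pred, gwl_pred) from several layouts."""
    if isinstance(out, dict):
        if "subs_pred" in out and "gwl_pred" in out:   # new interface
            return out["subs_pred"], out["gwl_pred"]
        if "data_final" in out:                        # legacy interface
            return model.split_data_predictions(out["data_final"])
        raise KeyError("expected 'subs_pred'/'gwl_pred' or 'data_final'")
    if isinstance(out, (list, tuple)):                 # predict()-style output
        names = output_names or getattr(model, "output_names", None)
        if names and "subs_pred" in names and "gwl_pred" in names:
            names = list(names)
            return out[names.index("subs_pred")], out[names.index("gwl_pred")]
        if strict:
            # Refuse to guess the order, which would risk a silent swap.
            raise ValueError("cannot map list outputs without output names")
        return out[0], out[1]
    raise TypeError(f"unsupported output type {type(out).__name__}")


class LegacyModel:
    """Stand-in model; strings play the role of tensors."""

    output_names = ["gwl_pred", "subs_pred"]

    def split_data_predictions(self, data_final):
        return data_final[0], data_final[1]
```

Mapping list outputs by name rather than by position is what protects against the silent swap mentioned under strict=True: with `output_names = ["gwl_pred", "subs_pred"]`, the list `["g", "s"]` still resolves to `("s", "g")`.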
- geoprior.utils.nat_utils.load_nat_config(root='nat.com')[source]
High-level helper used by NATCOM scripts.
Example
>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]
- geoprior.utils.nat_utils.load_nat_config_payload(root='nat.com')[source]
Return the full config.json payload, including city, model and __meta__ fields.
This is convenient when you also want to see which hash or city/model are currently active.
- geoprior.utils.nat_utils.load_scaler_info(encoders_block)[source]
Load the scaler_info mapping from an encoders block.
Stage-1 exporters typically store a compact description of the scalers used to normalise the data. In many cases this takes the form:
encoders = {
    "main_scaler": "/path/to/minmax.joblib",
    "coord_scaler": "/path/to/coords.joblib",
    "scaler_info": "/path/to/scaler_info.joblib",
    ...
}
where scaler_info is either a path to a joblib file or an already-loaded dictionary.
This helper returns a dictionary regardless of how it was stored, making downstream formatting/evaluation code simpler.
- geoprior.utils.nat_utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]
Build a tf.data.Dataset using NATCOM conventions.
Steps:
1) ensure_input_shapes(…) for X.
2) map_targets_for_training(…) for y.
3) tf.data pipeline (shuffle/batch/prefetch).
4) optional finite checks (NPZ + tf batches).
- Parameters:
X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.
y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.
batch_size (int) – Number of samples per batch.
shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.
mode (str) – Model mode passed to ensure_input_shapes().
forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().
check_npz_finite (bool) – If True, checks Xin/Yin numpy arrays for NaN/Inf before building ds.
check_finite (bool) – If True, inserts assert_all_finite checks inside the tf.data pipeline.
scan_finite_batches (int) – If >0, eagerly scans first N batches right away (fails early).
dynamic_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
future_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
seed (int)
drop_remainder (bool)
reshuffle_each_iter (bool)
prefetch (bool)
- Returns:
Dataset of (X, y) pairs.
- Return type:
tf.data.Dataset
Notes
TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
- geoprior.utils.nat_utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]
Standardise target dictionaries to the Keras compile keys.
This helper enforces a small convention used throughout the NATCOM training scripts:
Upstream sequence builders typically export raw targets with keys subsidence and gwl.
The GeoPrior model is compiled with targets named subs_pred and gwl_pred.
This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.
- Parameters:
y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).
subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.
gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.
subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.
gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.
- Returns:
New dictionary with keys subs_pred and gwl_pred.
- Return type:
- Raises:
KeyError – If the dictionary does not contain either of the expected key pairs.
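The key convention described above amounts to a small renaming step. The following is a minimal sketch: `map_targets` is a hypothetical name, and preferring the already-standardised keys when both pairs are present is an assumption.

```python
def map_targets(y_dict, subs_key="subsidence", gwl_key="gwl",
                subs_pred_key="subs_pred", gwl_pred_key="gwl_pred"):
    """Return a new dict keyed by the standardised prediction names."""
    # Assumption: already-standardised keys win if both pairs are present.
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        return {subs_pred_key: y_dict[subs_pred_key],
                gwl_pred_key: y_dict[gwl_pred_key]}
    if subs_key in y_dict and gwl_key in y_dict:
        return {subs_pred_key: y_dict[subs_key],
                gwl_pred_key: y_dict[gwl_key]}
    raise KeyError(
        f"expected ({subs_key!r}, {gwl_key!r}) or "
        f"({subs_pred_key!r}, {gwl_pred_key!r}) in targets"
    )
```

Because the function always builds a new dictionary, extra entries in the input are dropped and the caller's mapping is never modified.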
- geoprior.utils.nat_utils.name_of(obj)[source]
Return a human-readable name for an object.
This utility is handy when serialising compile configurations (e.g., turning metric callables into simple strings for JSON logs).
- geoprior.utils.nat_utils.resolve_hybrid_config(manifest_cfg, live_cfg, verbose=True)[source]
Merge Manifest config (Data Authority) with Live config (Physics Authority).
- geoprior.utils.nat_utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
- geoprior.utils.nat_utils.best_epoch_and_metrics(history, monitor='val_loss')[source]
Return the best epoch and metrics at that epoch.
Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:
The epoch index (0-based).
A dictionary mapping each metric name to its value at that epoch.
- Parameters:
- Returns:
- Return type:
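The selection rule is a plain arg-min over the monitored series. The sketch below uses a hypothetical `best_epoch` function on an ordinary dictionary of lists; resolving ties in favour of the earliest epoch is an assumption.

```python
def best_epoch(history, monitor="val_loss"):
    """Return (epoch_index, metrics_at_that_epoch) for the lowest monitor value."""
    values = history[monitor]
    # min() keeps the first index on ties, i.e. the earliest epoch.
    idx = min(range(len(values)), key=values.__getitem__)
    metrics = {name: series[idx]
               for name, series in history.items()
               if idx < len(series)}
    return idx, metrics


history = {"loss": [0.9, 0.5, 0.4], "val_loss": [1.0, 0.6, 0.7]}
epoch, metrics = best_epoch(history)
```

Here the training loss keeps falling, but the validation loss is lowest at the second epoch, so `epoch` is 1 (0-based) and `metrics` holds the values recorded at that epoch.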
- geoprior.utils.nat_utils.subs_point_from_out(model, out, quantiles=None, med_idx=None)[source]
Convert model output into a subsidence point forecast.
This helper produces a subsidence tensor shaped (B, H, 1) in model space, regardless of whether the model emits quantiles or a point prediction.
If quantiles are present and the subsidence prediction is shaped (B, H, Q, 1), the function selects the median quantile slice.
Otherwise, it returns the point prediction directly.
- Parameters:
model (object) – A Keras-like model instance passed to extract_preds().
out (dict) – Output returned by the model call. This can be either the new interface with keys "subs_pred" and "gwl_pred", or the legacy interface with key "data_final".
quantiles (sequence of float or None, default None) – Quantile levels used by the model, such as [0.1, 0.5, 0.9]. If provided, the function may use it to interpret the rank-4 quantile output and select the median. If None, quantile selection is disabled unless med_idx is explicitly provided and the tensor rank indicates quantiles.
med_idx (int or None, default None) – Index along the quantile axis to use as the "point" forecast when quantiles are available. If None and quantiles is provided, the function selects the index closest to 0.5.
- Returns:
subs_point – Subsidence point prediction in model space with shape (B, H, 1).
- Return type:
Tensor
- Raises:
ValueError – If subsidence prediction is missing or None.
ValueError – If a quantile tensor is detected but a valid median index cannot be resolved.
Notes
Quantile outputs are assumed to be shaped (B, H, Q, 1) where the quantile axis is the third dimension (axis=2).
If the model returns point predictions already, the function is effectively a no-op. The quantile interpretation used here follows Koenker and Bassett [25].
Examples
Quantile model:
out = model_inf(xb, training=False)
s_point = subs_point_from_out(model_inf, out, quantiles=[0.1, 0.5, 0.9])
Point model:
out = model_inf(xb, training=False)
s_point = subs_point_from_out(model_inf, out)
See also
extract_preds: Normalize outputs across new and legacy checkpoints.
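The median selection can be illustrated with nested lists in place of tensors. This is a sketch under stated assumptions: `median_index`, `rank`, and `to_point` are hypothetical helpers, and rank is inferred from list nesting instead of a tensor shape.

```python
def median_index(quantiles, med_idx=None):
    """Pick the quantile index to use as the point forecast."""
    if med_idx is not None:
        return med_idx
    if not quantiles:
        raise ValueError("cannot resolve a median index without quantiles")
    # Index of the level closest to 0.5 (first one on ties).
    return min(range(len(quantiles)), key=lambda i: abs(quantiles[i] - 0.5))


def rank(x):
    """Nesting depth of a list, used here as a stand-in for tensor rank."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth


def to_point(pred, quantiles=None, med_idx=None):
    """Reduce (B, H, Q, 1) to (B, H, 1); leave (B, H, 1) untouched."""
    if rank(pred) != 4:
        return pred
    i = median_index(quantiles, med_idx)
    return [[step[i] for step in sample] for sample in pred]


# One sample, two horizon steps, three quantiles.
pred = [[[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]]
point = to_point(pred, quantiles=[0.1, 0.5, 0.9])
```

Selecting along axis 2 keeps the trailing unit dimension, so `point` has the `(B, H, 1)` layout the reference promises.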
- geoprior.utils.nat_utils.serialize_subs_params(params, cfg=None)[source]
Make GeoPrior subnet parameters JSON-friendly.
The training scripts typically pass a dictionary of model construction arguments, e.g. subsmodel_params, which contains objects such as LearnableMV or FixedGammaW that are not directly JSON-serialisable.
This helper replaces those objects by small dictionaries describing their type and scalar value, optionally using values from the NATCOM config dictionary.
- Parameters:
params (dict) – Dictionary of model init parameters (e.g. subsmodel_params in training_NATCOM_GEOPRIOR.py).
cfg (dict, optional) – NATCOM config dictionary. If provided, scalar values are taken from:
GEOPRIOR_INIT_MV
GEOPRIOR_INIT_KAPPA
GEOPRIOR_GAMMA_W
GEOPRIOR_H_REF
and used as the authoritative numbers.
- Returns:
Copy of params where scalar GeoPrior parameters are replaced by JSON-friendly dictionaries.
- Return type:
Notes
This function does not import any of the GeoPrior classes. It only introspects attributes like initial_value or value when the corresponding config entry is missing.
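The replacement of parameter objects by small dictionaries might look like the sketch below. Several pieces are assumptions made for illustration: the `CONFIG_KEYS` mapping from parameter names to config entries, the `{"type", "value"}` output layout, and the local `LearnableMV` stand-in class.

```python
import json

# Hypothetical mapping from init-parameter names to NATCOM config entries.
CONFIG_KEYS = {
    "mv": "GEOPRIOR_INIT_MV",
    "kappa": "GEOPRIOR_INIT_KAPPA",
    "gamma_w": "GEOPRIOR_GAMMA_W",
    "h_ref": "GEOPRIOR_H_REF",
}


class LearnableMV:
    """Stand-in for a non-serialisable parameter object."""

    def __init__(self, initial_value):
        self.initial_value = initial_value


def describe(obj, cfg_value=None):
    """Describe a parameter object by type name and scalar value."""
    value = cfg_value
    if value is None:
        # Fall back to attribute introspection when the config is silent.
        value = getattr(obj, "initial_value", getattr(obj, "value", None))
    return {"type": type(obj).__name__,
            "value": None if value is None else float(value)}


def serialize(params, cfg=None):
    """Return a copy of `params` that json.dumps can handle."""
    cfg = cfg or {}
    out = dict(params)
    for name, cfg_key in CONFIG_KEYS.items():
        plain = (int, float, str, bool, type(None))
        if name in out and not isinstance(out[name], plain):
            out[name] = describe(out[name], cfg.get(cfg_key))
    return out


params = {"mv": LearnableMV(1e-7), "units": 64}
from_cfg = serialize(params, {"GEOPRIOR_INIT_MV": 2e-7})
from_attr = serialize(params)
```

In this sketch the config value takes precedence, matching the "authoritative numbers" wording above, and the object's own attribute is used only when the config entry is absent.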
- geoprior.utils.nat_utils.save_ablation_record(outdir, city, model_name, cfg, eval_dict, phys_diag=None, per_h_mae=None, per_h_r2=None, log_fn=None)[source]
Append a single ablation record to ablation_record.jsonl.
Each training run (e.g., different physics toggles or weights) writes one JSON line containing:
Basic run identifiers (city, model, timestamp).
Physics configuration (PDE_MODE_CONFIG, lambda weights, effective head flags, etc.).
Key performance metrics (R², MSE, MAE, coverage, sharpness).
Optional physics diagnostics (epsilon_prior, epsilon_cons).
Optional per-horizon MAE/R² for more detailed analysis.
- Parameters:
outdir (str) – Base output directory for the current run. The ablation file is created under outdir / "ablation_records".
city (str) – City name (e.g., "nansha" or "zhongshan").
model_name (str) – Model identifier (e.g., "GeoPriorSubsNet").
cfg (dict) – Lightweight configuration dictionary containing at least the physics-related keys used below.
eval_dict (dict or None) – Dictionary of evaluation metrics (R², MSE, MAE, coverage80, sharpness80). If None, metrics fields default to None.
phys_diag (dict or None, optional) – Physics diagnostics (e.g., from evaluate()) with keys such as "epsilon_prior" and "epsilon_cons".
per_h_mae (dict or None, optional) – Per-horizon MAE values (e.g., keyed by year/step).
- Return type:
None
Notes
The output file is a JSON-Lines file, so it can be loaded with load_ablation_jsonl().
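Appending one JSON object per line is what makes the ablation log safe to extend across runs. The sketch below shows that file pattern only: `append_record` and `read_records` are hypothetical helpers, and the `lambda_cons` field is an invented example of a physics weight.

```python
import json
import os
import tempfile
from datetime import datetime, timezone


def append_record(outdir, record):
    """Append one JSON line under <outdir>/ablation_records/."""
    folder = os.path.join(outdir, "ablation_records")
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, "ablation_record.jsonl")
    with open(path, "a", encoding="utf-8") as fh:
        fh.write(json.dumps(record) + "\n")
    return path


def read_records(path):
    """Load every non-empty line of a JSON-Lines file."""
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


with tempfile.TemporaryDirectory() as outdir:
    for weight in (0.0, 1.0):  # two runs with different physics weights
        path = append_record(outdir, {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "city": "nansha",
            "model": "GeoPriorSubsNet",
            "lambda_cons": weight,   # hypothetical field name
            "metrics": {"mae": None},
        })
    records = read_records(path)
    relative_path = os.path.relpath(path, outdir)
```

Opening the file in append mode means earlier records are never rewritten, so a failed run cannot corrupt the history of previous ones.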
- geoprior.utils.nat_utils.load_windows_npz(path)[source]
Load Stage-1 windows as (x, y).
Supported:
Bundle NPZ (contains inputs+targets in one file).
Mapping {'inputs': <npz>, 'targets': <npz>}.
Inputs NPZ only (targets inferred by filename).
Directory containing inputs/targets NPZ.
- geoprior.utils.nat_utils.load_tuned_hps_near_model(model_path, *, prefer='keras', required=True, log_fn=None)[source]
- geoprior.utils.nat_utils.load_trained_hps_near_model(model_path, *, allowed, required=False, log_fn=None)[source]
- geoprior.utils.nat_utils.load_hps_auto_near_model(model_path, *, allowed, prefer='keras', required=False, log_fn=None)[source]
- geoprior.utils.nat_utils.load_or_rebuild_geoprior_model(model_path, manifest, X_sample, out_s_dim, out_g_dim, mode, horizon, quantiles, city_name=None, compile_on_load=True, verbose=1)[source]
Load a tuned or trained GeoPriorSubsNet, with robust rebuild fallback.
- geoprior.utils.nat_utils.compile_for_eval(model, manifest, best_hps, quantiles, *, include_metrics=True)[source]
Recompile a GeoPriorSubsNet instance for evaluation / diagnostics.
This is intended for:
tuned models loaded from a .keras archive, or
models rebuilt from best_hps.
It does NOT change the architecture or weights, only the compile configuration (optimizer, losses, and physics loss weights).
- Parameters:
model (GeoPriorSubsNet) – Loaded or freshly-built GeoPriorSubsNet instance.
manifest (dict) – Stage-1 manifest; training config is taken from manifest['config'].
best_hps (dict or None) – Dictionary of tuned hyperparameters. If empty/None, reasonable defaults are inferred from the manifest.
quantiles (list of float or None) – Quantiles used for probabilistic subsidence/GWL outputs.
include_metrics (bool, default True) – If True, attach MAE/MSE + coverage/sharpness metrics to match the training script; if False, only losses are configured.
- Returns:
The same model instance, compiled in-place.
- Return type:
model
- geoprior.utils.nat_utils.load_best_hps_near_model(model_path, *, model_name='GeoPriorSubsNet', prefer='keras', log_fn=None)[source]
Load best hyperparameters saved near a model artifact.
Supports model names like:
<city>_<model_name>_H{H}_best.keras
<city>_<model_name>_H{H}_best.weights.h5
- Parameters:
model_path (str) – Path to a model file or its run directory.
model_name (str or None, default "GeoPriorSubsNet") – Model name token in filenames.
prefer ({"keras", "weights"}, default "keras") – Which artifact type to infer the prefix from.
log_fn (callable or None, default None) – Logger (e.g. print). None disables logs.
- Returns:
best_hps – Non-empty hyperparameter dictionary.
- Return type:
- Raises:
FileNotFoundError – If no hyperparameter JSON is found.
ValueError – If a candidate JSON exists but is empty/invalid.
- geoprior.utils.nat_utils.pick_npz_for_dataset(manifest, split)[source]
Load (inputs, targets) NPZ arrays for a given dataset split.
This is a public, reusable version of the internal helper that was previously named _pick_npz_for_dataset.
- Parameters:
manifest (dict) – Stage-1 manifest dictionary with the structure:
manifest["artifacts"]["numpy"] = {
    "train_inputs_npz": ...,
    "train_targets_npz": ...,
    "val_inputs_npz": ...,
    "val_targets_npz": ...,
    "test_inputs_npz": ... (optional),
    "test_targets_npz": ... (optional),
}
split ({"train", "val", "test"}) – Which dataset to load.
- Returns:
- Raises:
KeyError – If the manifest does not contain the expected NPZ entries.
ValueError – If split is not one of {"train", "val", "test"}.
- Return type:
- geoprior.utils.nat_utils.ensure_config_json(root='nat.com')[source]
Ensure that nat.com/config.json exists and is consistent with nat.com/config.py.
- Returns:
config (dict) – The configuration dictionary (payload["config"]).
json_path (str) – Absolute path to config.json.
Behaviour
If config.json does not exist, it is created from config.py.
If it exists but the SHA-256 hash of config.py has changed, it is regenerated.
Otherwise the existing JSON file is reused.
- Parameters:
root (str)
- Return type:
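The create / regenerate / reuse rule for ensure_config_json can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the geoprior implementation: the names ensure_json_sketch, build_config and make_cfg are hypothetical, and storing the hash under __meta__["config_py_sha256"] is an assumed layout.

```python
import hashlib
import json
import os
import tempfile


def ensure_json_sketch(config_py, config_json, build_config):
    """Regenerate config_json when the SHA-256 of config_py changes.

    `build_config` is a hypothetical callable that returns the config dict.
    Returns (config, absolute_json_path, was_rewritten).
    """
    with open(config_py, "rb") as fh:
        digest = hashlib.sha256(fh.read()).hexdigest()

    if os.path.exists(config_json):
        with open(config_json, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        # Reuse the cached JSON only if the recorded hash still matches.
        # The "__meta__" / "config_py_sha256" layout is an assumption.
        if payload.get("__meta__", {}).get("config_py_sha256") == digest:
            return payload["config"], os.path.abspath(config_json), False

    # Missing or stale: rebuild the payload and write it out.
    payload = {
        "config": build_config(),
        "__meta__": {"config_py_sha256": digest},
    }
    with open(config_json, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    return payload["config"], os.path.abspath(config_json), True


def make_cfg():
    """Hypothetical stand-in for whatever reads config.py."""
    return {"CITY_NAME": "demo"}


with tempfile.TemporaryDirectory() as tmp:
    py_path = os.path.join(tmp, "config.py")
    json_path = os.path.join(tmp, "config.json")
    with open(py_path, "w", encoding="utf-8") as fh:
        fh.write("CITY_NAME = 'demo'\n")
    first = ensure_json_sketch(py_path, json_path, make_cfg)   # created
    second = ensure_json_sketch(py_path, json_path, make_cfg)  # reused
    with open(py_path, "a", encoding="utf-8") as fh:
        fh.write("TIME_STEPS = 5\n")
    third = ensure_json_sketch(py_path, json_path, make_cfg)   # regenerated
    rewritten_flags = (first[2], second[2], third[2])
```

The third return value exists only in this sketch, to make the three branches observable; the documented function returns the config dictionary and the JSON path.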
- geoprior.utils.nat_utils.get_natcom_dir(root='nat.com')[source]
Directory containing NATCOM scripts and configuration, typically <repo_root>/nat.com.
- Return type:
- geoprior.utils.nat_utils.get_config_paths(root='nat.com')[source]
Return (config_py_path, config_json_path) for NATCOM.
- geoprior.utils.nat_utils.compile_geoprior_for_eval(model, manifest, best_hps, quantiles)[source]
(Re)compile a GeoPriorSubsNet-like model for evaluation.
This helper uses the Stage-1 manifest and tuned hyperparameters to configure:
the pinball losses for subsidence and GWL outputs,
loss weights for the two heads,
physics loss weights (lambda_*),
learning rate and LR multipliers.
TensorFlow and geoprior are imported lazily inside this function so that nat_utils can be imported even in non-TF environments.
- Parameters:
model (GeoPriorSubsNet-like) – An instance of the GeoPriorSubsNet model (or any model exposing the same compile signature).
manifest (dict) – Stage-1 manifest dictionary. The config entry is used to retrieve default loss weights and physics settings.
best_hps (dict) – Hyperparameters loaded from the tuning run (e.g. via load_best_hps_near_model()).
quantiles (list of float or None) – Quantile levels used for probabilistic outputs. If None, mean-squared error is used instead of pinball loss.
- Returns:
The same model instance, compiled in-place.
- Return type:
model
- Raises:
ImportError – If TensorFlow or geoprior’s make_weighted_pinball cannot be imported.
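For readers unfamiliar with the pinball (quantile) loss that this helper configures, the following pure-Python sketch shows the standard definition and the fall-back to mean-squared error when quantiles is None. It is a conceptual illustration only: pinball_loss, mse_loss and pick_losses are hypothetical names, and the sketch does not reproduce make_weighted_pinball or any Keras compile logic.

```python
def pinball_loss(y_true, y_pred, tau):
    """Mean pinball (quantile) loss for one quantile level tau in (0, 1)."""
    total = 0.0
    for y, q in zip(y_true, y_pred):
        err = y - q
        # Under-prediction is weighted by tau, over-prediction by (1 - tau).
        total += max(tau * err, (tau - 1.0) * err)
    return total / len(y_true)


def mse_loss(y_true, y_pred):
    """Mean-squared error, used here when no quantiles are configured."""
    return sum((y - p) ** 2 for y, p in zip(y_true, y_pred)) / len(y_true)


def pick_losses(quantiles):
    """Return one loss per quantile level, or [mse_loss] if quantiles is None."""
    if quantiles is None:
        return [mse_loss]
    # Bind tau as a default argument so each closure keeps its own level.
    return [lambda y, p, t=tau: pinball_loss(y, p, t) for tau in quantiles]


y_true = [1.0, 2.0, 3.0]
y_pred = [0.5, 2.5, 3.0]
loss_q90 = pinball_loss(y_true, y_pred, 0.9)
loss_mse = mse_loss(y_true, y_pred)
```

At tau = 0.9 the under-predicted first sample costs nine times as much as the over-predicted second sample, which is what pushes a Q90 head upward during training.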
This is one of the most important workflow modules in the GeoPrior stack. It includes config loading, dataset building, target mapping, scaling resolution, prediction extraction, artifact serialization, and model-evaluation support.
Subsidence-oriented workflow helpers#
Utility helpers for subsidence data, units, and coordinates.
- geoprior.utils.subsidence_utils.normalize_gwl_alias(df, gwl_col_user, *, prefer_depth_bgs=True, verbose=True)[source]
Normalize common GWL naming aliases.
Naming-only: unit conversion happens later.
- geoprior.utils.subsidence_utils.resolve_gwl_for_physics(df, gwl_col_user, *, prefer_depth_bgs=True, allow_keep_zscore_as_ml=True, verbose=True)[source]
Pick meters GWL for physics; keep z-score ML-only.
- geoprior.utils.subsidence_utils.resolve_head_column(df, *, depth_col, head_col='head_m', z_surf_col=None, use_head_proxy=True)[source]
Resolve a head column, creating one if needed.
- geoprior.utils.subsidence_utils.convert_target_units_df(df, *, base, from_unit='m', to_unit='mm', mode='overwrite', suffix='_mm', columns=None, unit_col=None, copy_df=True, overwrite_cols=False, strict=False)[source]
Convert target-like columns between “m” and “mm”.
Selected columns:
- base
- base + “_*” (quantiles, intervals, …)
If mode=”add”, new columns use suffix.
- geoprior.utils.subsidence_utils.subs_unit_to_si(cfg=None, *, default=0.001, units_prov_key='units_provenance', stage1_key='subs_unit_to_si_applied_stage1', cfg_key='SUBS_UNIT_TO_SI')[source]
Subsidence unit->SI factor (to meters).
Priority:
1) cfg[units_prov_key][stage1_key]
2) cfg[cfg_key]
3) default
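The three-level priority can be sketched as a plain dictionary lookup. This is an illustration under assumptions: resolve_unit_factor is a hypothetical name, the sketch handles only dict-like configs, and treating a None entry as "missing" is an assumed convention.

```python
def resolve_unit_factor(
    cfg=None,
    default=1e-3,
    units_prov_key="units_provenance",
    stage1_key="subs_unit_to_si_applied_stage1",
    cfg_key="SUBS_UNIT_TO_SI",
):
    """Pick the subsidence unit->SI factor using the documented priority."""
    cfg = cfg or {}
    provenance = cfg.get(units_prov_key) or {}
    # 1) The factor recorded by Stage-1 provenance wins when present.
    if provenance.get(stage1_key) is not None:
        return float(provenance[stage1_key])
    # 2) Otherwise use the plain config entry.
    if cfg.get(cfg_key) is not None:
        return float(cfg[cfg_key])
    # 3) Otherwise fall back to the default (1e-3, i.e. millimetres).
    return float(default)


factor_from_provenance = resolve_unit_factor(
    {
        "units_provenance": {"subs_unit_to_si_applied_stage1": 1.0},
        "SUBS_UNIT_TO_SI": 1e-3,
    }
)
factor_from_cfg = resolve_unit_factor({"SUBS_UNIT_TO_SI": 0.01})
factor_default = resolve_unit_factor(None)
```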
- geoprior.utils.subsidence_utils.subs_native_unit(cfg=None, *, default='mm')[source]
Infer the “native” subsidence unit from cfg.
unit_to_si ~= 1e-3 -> “mm”
unit_to_si ~= 1.0 -> “m”
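The "approximately equal" mapping from factor to unit name can be sketched with math.isclose. The function name native_unit_sketch, the relative tolerance, and returning default for any other factor are assumptions made for this example.

```python
import math


def native_unit_sketch(unit_to_si, default="mm", rel_tol=1e-6):
    """Map a unit->SI factor to a unit label; tolerance is an assumption."""
    if math.isclose(unit_to_si, 1e-3, rel_tol=rel_tol):
        return "mm"
    if math.isclose(unit_to_si, 1.0, rel_tol=rel_tol):
        return "m"
    # Any other factor falls back to the caller-supplied default (assumed).
    return default


unit_mm = native_unit_sketch(0.001)
unit_m = native_unit_sketch(1.0)
unit_other = native_unit_sketch(0.01, default="mm")
```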
- geoprior.utils.subsidence_utils.add_subsidence_mm_columns(df, cfg=None, *, base='subsidence', columns=None, suffix='_mm', unit_col=None, copy_df=True, overwrite=False)[source]
Add (or overwrite) subsidence columns in millimeters.
Always assumes the current values are in meters.
- geoprior.utils.subsidence_utils.add_subsidence_native_unit_columns(df, cfg=None, *, base='subsidence', columns=None, suffix='_native', unit_col=None, copy_df=True, overwrite=False)[source]
Add columns in the cfg-inferred native unit.
If cfg says the native unit was meters, this becomes a no-op aside from optional unit_col.
- geoprior.utils.subsidence_utils.finalize_si_scaling_kwargs(scaling_kwargs, *, subs_in_si, head_in_si, thickness_in_si, force_identity_affine_if_si=True, warn=True)[source]
Prevent double SI conversion in GeoPrior scaling_kwargs.
- geoprior.utils.subsidence_utils.finalize_si_affines_and_units(scaling_kwargs, *, subs_in_si, head_in_si, thickness_in_si, force_identity_affine_if_si=True, warn=True)
Prevent double SI conversion in GeoPrior scaling_kwargs.
- geoprior.utils.subsidence_utils.infer_utm_epsg_from_lonlat(lon_deg, lat_deg)[source]
Infer a UTM EPSG code from lon/lat (EPSG:4326).
- Uses standard UTM zoning:
zone = floor((lon + 180)/6) + 1
EPSG = 32600 + zone (north), 32700 + zone (south)
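The zoning formula translates directly into a few lines of Python. This sketch uses the hypothetical name utm_epsg_sketch and assumes that latitude 0 counts as the northern hemisphere; it is not the library function.

```python
import math


def utm_epsg_sketch(lon_deg, lat_deg):
    """Return the WGS 84 / UTM EPSG code for a lon/lat pair in degrees."""
    zone = int(math.floor((lon_deg + 180.0) / 6.0)) + 1
    # 326xx codes are northern-hemisphere zones, 327xx southern.
    base = 32600 if lat_deg >= 0 else 32700  # lat == 0 -> north (assumed)
    return base + zone


epsg_north = utm_epsg_sketch(113.3, 23.1)   # zone 49, northern hemisphere
epsg_south = utm_epsg_sketch(-58.4, -34.6)  # zone 21, southern hemisphere
```

The plain formula does not handle the irregular zones around Norway and Svalbard, and a longitude of exactly 180 degrees would yield zone 61, so inputs near those edges need separate treatment.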
- geoprior.utils.subsidence_utils.lonlat_to_utm_m(lon_deg, lat_deg, *, src_epsg=4326, target_epsg=None)[source]
Convert lon/lat degrees to UTM meters using pyproj.
- class geoprior.utils.subsidence_utils.CoordsPack(coords: 'np.ndarray', coord_mins: 'dict[str, float]', coord_ranges: 'dict[str, float]', meta: 'dict[str, Any]')[source]
Bases: object
- Parameters:
- coords: ndarray
- geoprior.utils.subsidence_utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]
Build coords tensor (t, x, y) with OPTIONAL shifting (translation only).
- This is designed for your “not normalized” workflow:
You keep SI units (years and meters),
but you avoid feeding huge UTM magnitudes (e.g. 3e5, 2.5e6) into coord MLPs by shifting x,y (and optionally t).
Notes
This does NOT min-max scale to [0,1]. It only translates.
Returning coord_mins/coord_ranges is still useful for logging/debug.
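The translation-only idea can be illustrated without NumPy. In this sketch, shift_txy_sketch is a hypothetical name and only the "min" shift is shown; the documented function also accepts explicit shift values and returns array data.

```python
def shift_txy_sketch(t, x, y):
    """Translate t, x, y so each axis starts at zero; units are unchanged."""
    mins = {"t": min(t), "x": min(x), "y": min(y)}
    ranges = {
        "t": max(t) - mins["t"],
        "x": max(x) - mins["x"],
        "y": max(y) - mins["y"],
    }
    # Subtraction only: a 500 m offset stays 500 m, a 1-year gap stays 1 year.
    coords = [
        (ti - mins["t"], xi - mins["x"], yi - mins["y"])
        for ti, xi, yi in zip(t, x, y)
    ]
    return coords, mins, ranges


t = [2020.0, 2021.0, 2022.0]
x = [300000.0, 300500.0, 301000.0]    # UTM easting in metres
y = [2500000.0, 2500250.0, 2500500.0]  # UTM northing in metres
coords, coord_mins, coord_ranges = shift_txy_sketch(t, x, y)
```

Because the values are shifted rather than divided by a range, spatial derivatives computed on the shifted coordinates keep their SI meaning.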
- geoprior.utils.subsidence_utils.detect_subsidence_mode(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), tol_rel=0.05, min_points=3, max_groups=200, random_state=42)[source]
Infer whether subsidence columns represent ‘rate’, ‘cumulative’, or an inconsistent pair.
- geoprior.utils.subsidence_utils.rate_to_cumulative(df, *, rate_col='subsidence', cum_col='subsidence_cum', time_col='year', group_cols=('longitude', 'latitude', 'city'), initial='first_equals_rate_dt', inplace=False)[source]
Build cumulative displacement from a rate series.
- geoprior.utils.subsidence_utils.cumulative_to_rate(df, *, cum_col='subsidence_cum', rate_col='subsidence', time_col='year', group_cols=('longitude', 'latitude', 'city'), first='cum_over_dtref', inplace=False)[source]
Recover a rate series from cumulative displacement.
rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i for i>=1
- first:
‘nan’: first rate is NaN
‘cum_over_dtref’: rate(t0) = cum(t0)/dt_ref (dt_ref median dt)
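The finite-difference rule and the two options for the first sample can be sketched for a single, already sorted series. The name cumulative_to_rate_sketch is hypothetical, and the sketch ignores the per-group handling implied by group_cols.

```python
import math
import statistics


def cumulative_to_rate_sketch(times, cum, first="cum_over_dtref"):
    """Recover rates from one cumulative series (needs at least 2 points)."""
    dts = [t1 - t0 for t0, t1 in zip(times, times[1:])]
    # rate(t_i) = (cum(t_i) - cum(t_{i-1})) / dt_i for i >= 1
    rates = [(c1 - c0) / dt for c0, c1, dt in zip(cum, cum[1:], dts)]
    if first == "nan":
        head = math.nan
    elif first == "cum_over_dtref":
        # dt_ref is the median time step of the series.
        head = cum[0] / statistics.median(dts)
    else:
        raise ValueError(f"unknown first={first!r}")
    return [head] + rates


years = [2020, 2021, 2022]
cumulative = [2.0, 5.0, 9.0]
rates_default = cumulative_to_rate_sketch(years, cumulative)
rates_nan = cumulative_to_rate_sketch(years, cumulative, first="nan")
```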
- geoprior.utils.subsidence_utils.convert_eval_payload_units(payload, cfg=None, *, mode='si', scope='all', savefile=None, fmt='json', indent=2, copy_payload=True)[source]
Convert GeoPriorSubsNet evaluation-payload units for reporting.
This is a post-processing helper meant for stage-2 evaluation JSON payloads (e.g. geoprior_eval_phys_<timestamp>.json).
- Parameters:
payload (Mapping[str, Any]) – The evaluation payload dict. It is expected to contain sections like metrics_evaluate, point_metrics, per_horizon, interval_calibration and censor_stratified.
cfg (mapping or module, optional) – The experiment config (e.g. config module or globals()). The helper reads SUBS_UNIT_TO_SI (or stage-1 provenance) and TIME_UNITS from this object when available.
mode ({"si", "interpretable"}, default "si") – "si" leaves values untouched. "interpretable" converts selected subsidence and physics metrics from SI into the native units implied by SUBS_UNIT_TO_SI.
scope ({"all", "subsidence", "physics"}, default "all") – Which parts to convert when mode="interpretable". "subsidence" converts only subsidence metrics such as MAE, MSE, and sharpness to the native unit. "physics" converts only unambiguous physics residual rates, currently epsilon_cons_raw and epsilon_gw_raw. "all" applies both conversions.
savefile (str, optional) – If provided, write the converted payload to this path.
fmt ({"json"}, default "json") – Output format when savefile is provided.
indent (int, default 2) – JSON indentation.
copy_payload (bool, default True) – If True, operate on a deep copy of payload. If False, convert in-place (dangerous).
- Returns:
Converted payload as a plain dict.
- Return type:
Notes
For subsidence metrics, linear quantities such as MAE and sharpness scale by 1 / SUBS_UNIT_TO_SI, while squared quantities such as MSE scale by (1 / SUBS_UNIT_TO_SI) ** 2.
When physics conversion is enabled, epsilon_cons_raw is treated as a rate in m/s and converted to <subs_native_unit>/<TIME_UNITS> (for example mm/yr), while epsilon_gw_raw is treated as a rate in 1/s and converted to 1/<TIME_UNITS>.
The helper records unit provenance under payload["units"].
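The scaling rules in these notes can be checked with a small worked example. This sketch makes several assumptions that are not stated in the reference: the flat metric keys ("mae", "mse", "sharpness"), the name to_native_sketch, and a year length of 365.25 days for the seconds-to-years factor. Real payloads are nested and are handled by the library function.

```python
# Assumption: one "year" is a Julian year of 365.25 days.
SECONDS_PER_YEAR = 365.25 * 24 * 3600


def to_native_sketch(metrics, subs_unit_to_si=1e-3,
                     seconds_per_time_unit=SECONDS_PER_YEAR):
    """Convert a flat dict of SI metrics to native units (e.g. mm, mm/yr)."""
    k = 1.0 / subs_unit_to_si  # metres -> native subsidence unit
    out = dict(metrics)
    for name in ("mae", "sharpness"):  # linear quantities scale by k
        if name in out:
            out[name] = metrics[name] * k
    if "mse" in out:  # squared quantities scale by k ** 2
        out["mse"] = metrics["mse"] * k ** 2
    if "epsilon_cons_raw" in out:  # m/s -> native unit per time unit
        out["epsilon_cons_raw"] = (
            metrics["epsilon_cons_raw"] * k * seconds_per_time_unit
        )
    if "epsilon_gw_raw" in out:  # 1/s -> 1 per time unit
        out["epsilon_gw_raw"] = (
            metrics["epsilon_gw_raw"] * seconds_per_time_unit
        )
    return out


si_metrics = {
    "mae": 0.002,              # 2 mm expressed in metres
    "mse": 4e-6,               # (2 mm)^2 expressed in m^2
    "epsilon_cons_raw": 1e-10,  # m/s
    "epsilon_gw_raw": 2e-9,     # 1/s
}
native_metrics = to_native_sketch(si_metrics)
```

With SUBS_UNIT_TO_SI = 1e-3, an MAE of 0.002 m becomes 2 mm and an MSE of 4e-6 m² becomes 4 mm², which is why the squared factor matters when reporting MSE next to MAE.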
- geoprior.utils.subsidence_utils.clean_gwl_zsurf(df, *, lon_col='longitude', lat_col='latitude', z_col='z_surf', gwl_col='GWL_depth_bgs', gwl_scale='auto', max_depth_m=50.0, z_outlier_iqr_mult=3.0, knn_k=25, return_report=True, savefile=None)[source]
Fix GWL units + robust z_surf, then recompute head_m.
- geoprior.utils.subsidence_utils.postprocess_eval_json(eval_json, *, cfg=None, scope='all', out_path=None, overwrite=False, add_rmse=True, force=False, indent=2)[source]
Post-hoc convert a Stage-2 evaluation JSON from SI to interpretable units.
This is a safe wrapper around convert_eval_payload_units(…) that:
- loads a JSON from disk (or accepts a payload dict),
- infers unit factors from payload["units"] when cfg is missing,
- avoids double conversion unless force=True,
- optionally adds RMSE fields (sqrt(MSE)),
- writes a converted JSON file if out_path is provided.
- Parameters:
eval_json (str | Mapping[str, Any]) – Either a file path to the saved JSON, or an in-memory mapping.
cfg (Mapping[str, object] | None) – Optional config mapping/module. If missing (or incomplete), this helper will synthesize the minimal keys needed for conversion: - “SUBS_UNIT_TO_SI” - “TIME_UNITS”
scope (Literal['all', 'subsidence', 'physics']) – Forwarded to convert_eval_payload_units(…).
out_path (str | None) – If given, write the converted payload there. If a directory is given, a filename is generated next to the input name (or “geoprior_eval…”).
overwrite (bool) – If False and out_path exists, raise.
add_rmse (bool) – If True, add RMSE fields wherever MSE is present (metrics_evaluate, point_metrics, per_horizon).
force (bool) – If False, skip conversion when payload already declares an interpretable subsidence unit (e.g., “mm”). If True, always convert.
indent (int) – JSON indent used when writing.
- Returns:
The converted payload (always returned as a dict).
- Return type:
This module is the workflow-side bridge back to the subsidence application. It exposes unit conversion helpers, cumulative/rate transforms, groundwater-resolution helpers, head-column selection, and coordinate construction.
Selected workflow helpers#
These functions are surfaced explicitly because they appear frequently in staged workflows, export paths, and figure or reproducibility scripts.
- geoprior.utils.load_nat_config(root='nat.com')[source]
High-level helper used by NATCOM scripts.
Example
>>> from geoprior.utils.nat_utils import load_nat_config
>>> cfg = load_nat_config()
>>> CITY_NAME = cfg["CITY_NAME"]
>>> TIME_STEPS = cfg["TIME_STEPS"]
- geoprior.utils.load_nat_config_payload(root='nat.com')[source]
Return the full config.json payload, including city, model and __meta__ fields.
This is convenient when you also want to see which hash or city/model are currently active.
- geoprior.utils.make_tf_dataset(X_np, y_np, batch_size, shuffle, mode, forecast_horizon, *, seed=42, drop_remainder=False, reshuffle_each_iter=True, prefetch=True, check_npz_finite=False, check_finite=False, scan_finite_batches=0, dynamic_feature_names=None, future_feature_names=None)[source]
Build a tf.data.Dataset using NATCOM conventions.
Steps:
1) ensure_input_shapes(…) for X.
2) map_targets_for_training(…) for y.
3) tf.data pipeline (shuffle/batch/prefetch).
4) optional finite checks (NPZ + tf batches).
- Parameters:
X_np (dict) – Input dictionary, typically obtained from np.load on the Stage-1 *_inputs_npz file.
y_np (dict) – Target dictionary, typically obtained from np.load on the Stage-1 *_targets_npz file.
batch_size (int) – Number of samples per batch.
shuffle (bool) – If True, shuffle the dataset using a fixed seed for reproducibility.
mode (str) – Model mode passed to ensure_input_shapes().
forecast_horizon (int) – Forecast horizon passed to ensure_input_shapes().
check_npz_finite (bool) – If True, checks Xin/Yin numpy arrays for NaN/Inf before building ds.
check_finite (bool) – If True, inserts assert_all_finite checks inside the tf.data pipeline.
scan_finite_batches (int) – If >0, eagerly scans first N batches right away (fails early).
dynamic_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
future_feature_names (list[str] | None) – If provided, used to report bad channels for feature tensors.
seed (int)
drop_remainder (bool)
reshuffle_each_iter (bool)
prefetch (bool)
- Returns:
Dataset of (X, y) pairs.
- Return type:
tf.data.Dataset
Notes
TensorFlow is imported lazily inside the function so that this module remains importable in environments where TF is not installed (for example, for tooling or static analysis).
- geoprior.utils.map_targets_for_training(y_dict, subs_key='subsidence', gwl_key='gwl', subs_pred_key='subs_pred', gwl_pred_key='gwl_pred')[source]
Standardise target dictionaries to the Keras compile keys.
This helper enforces a small convention used throughout the NATCOM training scripts:
Upstream sequence builders typically export raw targets with keys subsidence and gwl.
The GeoPrior model is compiled with targets named subs_pred and gwl_pred.
This function accepts either style and always returns a dict keyed by subs_pred and gwl_pred for use in Keras.
- Parameters:
y_dict (dict) – Dictionary produced by the Stage-1 sequence exporter or by a previous training script. Must contain either (subsidence, gwl) or (subs_pred, gwl_pred).
subs_key (str, default "subsidence") – Name of the raw subsidence key in y_dict.
gwl_key (str, default "gwl") – Name of the raw groundwater-level key in y_dict.
subs_pred_key (str, default "subs_pred") – Standardised key for the subsidence prediction target.
gwl_pred_key (str, default "gwl_pred") – Standardised key for the GWL prediction target.
- Returns:
New dictionary with keys subs_pred and gwl_pred.
- Return type:
- Raises:
KeyError – If the dictionary does not contain either of the expected key pairs.
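The key-standardisation convention can be sketched as a small dictionary remap. The name map_targets_sketch is hypothetical, and checking the already-standardised keys before the raw keys is an assumed order of precedence.

```python
def map_targets_sketch(
    y_dict,
    subs_key="subsidence",
    gwl_key="gwl",
    subs_pred_key="subs_pred",
    gwl_pred_key="gwl_pred",
):
    """Return a new dict keyed by the compile-time target names."""
    # Already standardised: copy the two entries through (assumed priority).
    if subs_pred_key in y_dict and gwl_pred_key in y_dict:
        return {
            subs_pred_key: y_dict[subs_pred_key],
            gwl_pred_key: y_dict[gwl_pred_key],
        }
    # Raw exporter keys: rename them to the compile keys.
    if subs_key in y_dict and gwl_key in y_dict:
        return {
            subs_pred_key: y_dict[subs_key],
            gwl_pred_key: y_dict[gwl_key],
        }
    raise KeyError(
        f"expected ({subs_key!r}, {gwl_key!r}) or "
        f"({subs_pred_key!r}, {gwl_pred_key!r}); got {sorted(y_dict)}"
    )


raw_targets = {"subsidence": [0.1, 0.2], "gwl": [3.0, 3.1]}
mapped = map_targets_sketch(raw_targets)
remapped = map_targets_sketch(mapped)  # idempotent on standardised input
```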
- geoprior.utils.resolve_si_affine(cfg, scaler_info, *, target_name, prefix, unit_factor_key, scale_key, bias_key)[source]
- geoprior.utils.extract_preds(model, out, *, strict=True, output_names=None)[source]
Extract (subs_pred, gwl_pred) from GeoPrior outputs.
- Supports:
v3.2+ call(): {“subs_pred”,”gwl_pred”}
forward_with_aux(): (y_pred, aux)
legacy: {“data_final”} + model.split_data_predictions
predict(): list/tuple mapped via output names
If strict=True, list/tuple outputs must be mappable via output names; otherwise we raise to avoid silent swaps.
This helper normalizes the output interface across two GeoPrior generation families:
New interface (preferred): model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}
Legacy interface (backward compatible): model(inputs) -> {"data_final": ...}, where the caller must split the tensor using model.split_data_predictions.
- Parameters:
model (object) – A Keras-like model instance that may expose split_data_predictions(data_final). The splitter must return a tuple:
subs_pred with shape (B, H, 1) or (B, H, Q, 1)
gwl_pred with shape (B, H, 1) or (B, H, Q, 1)
out (dict) – Output returned by the model call, typically model(inputs, training=False). Supported keys are either:
{"subs_pred", "gwl_pred"} (new interface), or
{"data_final"} (legacy interface).
strict (bool)
- Returns:
subs_pred (Tensor) – Predicted subsidence in model space. Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
gwl_pred (Tensor) – Predicted groundwater/head variable in model space. Expected shapes:
Point mode: (B, H, 1)
Quantile mode: (B, H, Q, 1)
- Raises:
- Return type:
Notes
This function is intended for Stage-2 and Stage-3 scripts where you may load checkpoints from older experiments. It avoids fragile code that slices data_final manually.
The function does not validate tensor dtypes or numerical finiteness. Upstream code should handle NaN and Inf checks as needed. Output normalization follows the Keras model conventions documented in Keras Team [24].
Examples
New interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_stage_outputs(
    model_inf,
    out,
)
Legacy interface:
out = model_inf(xb, training=False)
s_pred, h_pred = extract_stage_outputs(
    model_inf,
    out,
)
See also
subs_point_from_stage_out – Convert subsidence predictions to a point forecast.
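The dispatch between the new and legacy dictionary interfaces can be sketched with plain Python objects. Everything here is illustrative: LegacyModel, its channel-based split, and extract_preds_sketch are hypothetical, nested lists stand in for tensors, and the tuple and list forms listed under "Supports" are deliberately left out.

```python
class LegacyModel:
    """Stand-in for an older checkpoint that only returns `data_final`."""

    def split_data_predictions(self, data_final):
        # Hypothetical split: channel 0 is subsidence, channel 1 is GWL.
        subs = [row[:1] for row in data_final]
        gwl = [row[1:2] for row in data_final]
        return subs, gwl


def extract_preds_sketch(model, out):
    """Return (subs_pred, gwl_pred) from either dict-style output."""
    if isinstance(out, dict):
        # New interface: both heads are already named.
        if "subs_pred" in out and "gwl_pred" in out:
            return out["subs_pred"], out["gwl_pred"]
        # Legacy interface: let the model split its own fused tensor.
        if "data_final" in out:
            return model.split_data_predictions(out["data_final"])
    raise ValueError("unsupported output structure in this sketch")


new_style = extract_preds_sketch(
    object(), {"subs_pred": [[0.1]], "gwl_pred": [[2.0]]}
)
legacy_style = extract_preds_sketch(
    LegacyModel(), {"data_final": [[1.0, 2.0], [3.0, 4.0]]}
)
```

Delegating the legacy split to the model keeps the channel layout in one place, so calling code never slices data_final by hand.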
- geoprior.utils.best_epoch_and_metrics(history, monitor='val_loss')[source]
Return the best epoch and metrics at that epoch.
Given a History.history dictionary produced by model.fit(...), this helper identifies the index of the minimum value for the monitored quantity (by default "val_loss") and returns:
The epoch index (0-based).
A dictionary mapping each metric name to its value at that epoch.
- Parameters:
- Returns:
- Return type:
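The selection logic can be sketched on a plain dictionary shaped like History.history. The name best_epoch_sketch is hypothetical, and skipping series that are shorter than the best index is an assumption about edge-case handling.

```python
def best_epoch_sketch(history, monitor="val_loss"):
    """Return (best_epoch_index, metrics_at_that_epoch) for a history dict."""
    values = history[monitor]
    # Index of the minimum monitored value; ties resolve to the earliest epoch.
    best = min(range(len(values)), key=values.__getitem__)
    metrics = {
        name: series[best]
        for name, series in history.items()
        if len(series) > best  # skip series that ended early (assumed)
    }
    return best, metrics


history = {
    "loss": [1.0, 0.6, 0.5],
    "val_loss": [0.9, 0.4, 0.45],
}
best_epoch, best_metrics = best_epoch_sketch(history)
```

The returned index is 0-based, so a value of 1 here corresponds to what Keras logs as "Epoch 2".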
- geoprior.utils.audit_stage1_scaling(*, df_train, inputs_train, targets_train, coord_scaler=None, coord_ranges=None, coord_mode='auto', coords_in_degrees=False, coord_epsg_used=None, coord_x_col_used='x', coord_y_col_used='y', x_col_used='x', y_col_used='y', time_col_used='t', normalize_coords=True, keep_coords_raw=False, shift_raw_coords=False, subs_model_col=None, gwl_dyn_col=None, gwl_target_col=None, h_field_col=None, dynamic_features=None, static_features=None, future_features=None, scaled_ml_numeric_cols=None, main_scaler_path=None, scaler_info=None, save_dir=None, table_width=110, title_prefix='COORDINATE + FEATURE SCALING AUDIT (Stage-1)', city='Unknown', model_name='Model', sample_rows=5, log_fn=None)[source]
Stage-1 audit:
- raw df_train coord stats (t/x/y) + heuristic units
- model-fed coords stats from inputs_train["coords"] (flattened)
- coord scaler min/max + coord_ranges
- SI channel sanity for physics cols (if present)
- target arrays sanity
- split of features: scaled ML vs __si vs other
Saves a machine-readable JSON if save_dir is provided.
- Parameters:
coord_scaler (Any)
coord_mode (str)
coords_in_degrees (bool)
coord_epsg_used (Any)
coord_x_col_used (str)
coord_y_col_used (str)
x_col_used (str)
y_col_used (str)
time_col_used (str)
normalize_coords (bool)
keep_coords_raw (bool)
shift_raw_coords (bool)
subs_model_col (str | None)
gwl_dyn_col (str | None)
gwl_target_col (str | None)
h_field_col (str | None)
main_scaler_path (str | None)
scaler_info (dict | None)
save_dir (str | None)
table_width (int)
title_prefix (str)
city (str)
model_name (str)
sample_rows (int)
- Return type:
str | None
- geoprior.utils.audit_stage2_handshake(*, X_train, X_val, y_train, y_val, time_steps, forecast_horizon, mode, dyn_names, fut_names, sta_names, coord_scaler=None, sk_final, save_dir, table_width=100, title_prefix='STAGE-2 HANDSHAKE AUDIT', city='Unkown', model_name='Model', log_fn=None)[source]
- geoprior.utils.calibrate_quantile_forecasts(*, df_eval=None, df_future=None, target_name='subsidence', column_map=None, step_col='forecast_step', interval=(0.1, 0.9), target_coverage=0.8, median_q=0.5, use='auto', tol=0.02, f_max=5.0, max_iter=32, keep_original=False, enforce_monotonic='cummax', overall_key='__overall__', calibrated_col='is_calibrated', factor_col='calibration_factor', factors=None, save_eval=None, save_future=None, save_stats=None, verbose=1, logger=None)[source]
Fit and apply post-hoc interval calibration for quantile forecasts.
This is the high-level DataFrame-oriented entry point for interval recalibration in geoprior.utils.calibrate. It can
detect whether evaluation forecasts already appear calibrated,
fit interval-width correction factors from evaluation data,
apply those factors to evaluation and/or future forecasts,
compute before/after summary diagnostics on the evaluation set,
optionally save the outputs to disk.
The function is designed for workflows where quantile forecasts are already available in tabular form and calibration should be handled without retraining the forecasting model.
Conceptually, the function widens or narrows a predictive interval around a median-like forecast so that the empirical interval coverage better matches the requested target. This is a practical post-hoc strategy for uncertainty refinement in multi-horizon forecasting pipelines [1, 2].
- Parameters:
df_eval (pandas.DataFrame or None, default None) – Evaluation forecasts used to fit calibration factors and to compute before/after diagnostics. This table should contain the observed target column in addition to the quantile forecasts.
df_future (pandas.DataFrame or None, default None) – Future or inference forecasts to which the fitted factors should be applied. This table does not need observed targets.
target_name (str, default "subsidence") – Base name used to infer forecast and observation columns when column_map is not explicitly supplied.
column_map (mapping or None, default None) – Optional mapping describing the observed column and the quantile columns. This is helpful when the table does not follow the default naming conventions.
step_col (str, default "forecast_step") – Column used to fit and apply separate factors per forecast horizon.
interval (tuple of float, default (0.1, 0.9)) – Lower and upper quantiles defining the interval to calibrate. The nearest available quantiles are used.
target_coverage (float, default 0.8) – Desired empirical coverage after calibration.
median_q (float, default 0.5) – Central quantile used as the expansion anchor.
use ({"auto", True, False}, default "auto") – Control flag for whether calibration is performed.
False disables calibration and returns inputs unchanged.
"auto" skips calibration when evaluation forecasts already look calibrated.
True forces calibration even if the automatic check would skip it.
tol (float, default 0.02) – Tolerance used by the automatic already-calibrated check.
f_max (float, default 5.0) – Maximum factor allowed during fitting.
max_iter (int, default 32) – Maximum number of bisection iterations used when fitting factors.
keep_original (bool, default False) – If True, raw quantiles are copied into *_raw columns before calibration is applied.
enforce_monotonic ({"cummax", "sort", "none"}, default "cummax") – Strategy used to prevent quantile crossing after recalibration.
overall_key (str or None, default "__overall__") – Reserved label stored in the returned statistics dictionary for overall summary reporting.
calibrated_col (str, default "is_calibrated") – Column name added to calibrated outputs as a Boolean marker.
factor_col (str, default "calibration_factor") – Column name used to store the factor applied to each row.
factors (float or mapping or None, default None) – Optional user-specified calibration factors. If provided, these take precedence over factors fitted from df_eval.
save_eval (str or path-like or None, default None) – Optional CSV path for saving the calibrated evaluation table.
save_future (str or path-like or None, default None) – Optional CSV path for saving the calibrated future table.
save_stats (str or path-like or None, default None) – Optional JSON path for saving the calibration summary.
verbose (int, default 1) – Verbosity level forwarded to logging helpers.
logger (logging.Logger or None, default None) – Optional logger used for progress messages.
- Returns:
df_eval_cal (pandas.DataFrame or None) – Calibrated evaluation DataFrame, or None when no evaluation table was provided.
df_future_cal (pandas.DataFrame or None) – Calibrated future DataFrame, or None when no future table was provided.
stats (dict[str, Any]) – Dictionary describing the calibration workflow. Depending on the path taken, it may contain
the target interval and target coverage,
the fitted or user-specified factors,
skip reasons,
evaluation summaries before and after calibration.
- Return type:
Notes
In use="auto" mode, the function first checks for an explicit calibrated_col and then falls back to a simple empirical coverage-based decision. This makes the wrapper conservative in repeated workflows, where the same tables may pass through the calibration stage more than once.
The returned stats dictionary is designed to be JSON-friendly and therefore suitable for audit trails, experiment manifests, or gallery artifacts.
Examples
>>> import pandas as pd
>>> from geoprior.utils.calibrate import (
...     calibrate_quantile_forecasts,
... )
>>> df_eval = pd.DataFrame(
...     {
...         "forecast_step": [1, 1, 2, 2],
...         "subsidence_actual": [0.4, 0.7, 0.5, 0.9],
...         "subsidence_q10": [0.3, 0.5, 0.4, 0.6],
...         "subsidence_q50": [0.4, 0.7, 0.5, 0.8],
...         "subsidence_q90": [0.5, 0.9, 0.6, 1.0],
...     }
... )
>>> df_eval_cal, df_future_cal, stats = (
...     calibrate_quantile_forecasts(
...         df_eval=df_eval,
...         target_name="subsidence",
...         target_coverage=0.8,
...     )
... )
>>> isinstance(stats, dict)
True
See also
fit_interval_factors_df – Fit per-horizon interval-width correction factors.
apply_interval_factors_df – Apply a scalar or per-horizon factor map to quantile forecasts.
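The idea of widening or narrowing an interval around the median until empirical coverage reaches a target can be sketched in pure Python. This is a conceptual illustration under assumptions, not the algorithm used by calibrate_quantile_forecasts: the names widen, coverage and fit_factor are hypothetical, the search interval [0, f_max] is assumed, and a single global factor is fitted instead of one per forecast step.

```python
def widen(q_lo, q_med, q_hi, factor):
    """Scale both half-widths of an interval around the median by `factor`."""
    return q_med - factor * (q_med - q_lo), q_med + factor * (q_hi - q_med)


def coverage(y, q_lo, q_med, q_hi, factor):
    """Fraction of observations that fall inside the rescaled interval."""
    hits = 0
    for yi, lo, med, hi in zip(y, q_lo, q_med, q_hi):
        new_lo, new_hi = widen(lo, med, hi, factor)
        hits += new_lo <= yi <= new_hi
    return hits / len(y)


def fit_factor(y, q_lo, q_med, q_hi, target=0.8, f_max=5.0, max_iter=32):
    """Bisection for the smallest factor in [0, f_max] that reaches `target`."""
    if coverage(y, q_lo, q_med, q_hi, f_max) < target:
        return f_max  # even the widest allowed interval falls short
    lo, hi = 0.0, f_max
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        # Coverage is non-decreasing in the factor, so bisection is valid.
        if coverage(y, q_lo, q_med, q_hi, mid) >= target:
            hi = mid
        else:
            lo = mid
    return hi


observed = [0.5, -1.5, 2.0, 0.2, -0.8]
q10 = [-1.0] * 5
q50 = [0.0] * 5
q90 = [1.0] * 5
coverage_before = coverage(observed, q10, q50, q90, 1.0)
factor = fit_factor(observed, q10, q50, q90, target=0.8)
coverage_after = coverage(observed, q10, q50, q90, factor)
```

In this toy data the raw Q10 to Q90 interval covers only three of five observations, so the fitted factor is greater than one and the interval is widened; over-dispersed forecasts would yield a factor below one.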
- geoprior.utils.evaluate_forecast(eval_data, *, target_name='subsidence', column_map=None, quantile_interval=(0.1, 0.9), per_horizon=False, extra_metrics=None, extra_metric_kwargs=None, overall_key='__overall__', savefile=None, save_format='.json', time_as_str=True, verbose=1, logger=None)[source]
Evaluate forecast diagnostics from an evaluation DataFrame.
This helper consumes the df_eval output from format_and_forecast() (or a compatible DataFrame) and computes aggregate metrics such as MAE, MSE, \(R^2\), coverage, and sharpness. It can also optionally evaluate metrics per forecast horizon and apply additional user-defined metrics.
By default it expects the following columns:
'sample_idx'
'forecast_step'
'coord_t' (time)
Quantile or point-prediction columns for the target, e.g.:
Quantile mode: f'{target_name}_q10', f'{target_name}_q50', f'{target_name}_q90', …
Point mode: f'{target_name}_pred'.
Actual column: f'{target_name}_actual'.
A flexible column_map allows remapping these logical roles to arbitrary column names, e.g.:
column_map = {
    'coord_t': 'date',
    'actual': 'true_subs',
    'pred': 'subs_predicted',
}
or, for quantile columns:
column_map = {
    'coord_t': 'date',
    'quantiles': {
        0.1: 'subs_q10',
        0.5: 'subs_q50',
        0.9: 'subs_q90',
    },
}
- Parameters:
eval_data (str, path-like, or pandas.DataFrame) – Either a path to a CSV file containing the evaluation DataFrame (as saved by format_and_forecast()) or an in-memory DataFrame.
target_name (str, default 'subsidence') – Base name for the target columns. Used to infer default column names such as f'{target_name}_q10', f'{target_name}_pred', and f'{target_name}_actual'.
column_map (dict, optional) – Optional mapping to override default column names. The following keys are recognized:
'sample_idx': sample index column name (default 'sample_idx').
'forecast_step': horizon index column name (default 'forecast_step').
'coord_t': time coordinate column (default 'coord_t').
'actual': name or list of names for the actual target column(s). Currently a single column is supported; default f'{target_name}_actual'.
'pred': point prediction column for non-quantile mode, default f'{target_name}_pred'.
'quantiles':
If a mapping: {q: col_name} for quantile levels, where q is a float in (0, 1).
If a sequence of column names, the quantile value will be inferred from suffix patterns like f'{target_name}_q{int(q*100):d}'.
quantile_interval (tuple of float, default (0.1, 0.9)) – Interval (lower, upper) used for coverage and sharpness metrics, typically corresponding to an 80% interval between Q10 and Q90.
per_horizon (bool, default False) – If True, compute per-horizon MAE/MSE/R² grouped by the forecast_step column.
extra_metrics (sequence of str or mapping, optional) – Optional additional metrics to compute.
If a sequence of strings (e.g. ['pss', 'pit']), each name is resolved via geoprior.metrics._registry.get_metric(). If the name is not present in the registry, an error is raised, prompting the user to pass a callable instead.
If a mapping {name: func}, each func is called as:
func(y_true, y_pred, **extra_metric_kwargs.get(name, {}))
where y_pred is the median (Q50) or point forecast.
For more complex metrics that require full quantile structure or temporal sequences, pass a suitable wrapper function that internally uses the DataFrame as needed.
extra_metric_kwargs (mapping, optional) – Optional mapping of per-metric keyword arguments. Keys must match the names in extra_metrics. Each value is a dict of kwargs forwarded to the corresponding metric function.
savefile (str, path-like, or bool, optional) – If provided, metrics are saved to disk.
If True: a filename is auto-generated near eval_data (if it is a path) or in the current working directory.
If a string/path without extension: the extension is taken from save_format.
If a string/path with extension: that extension takes precedence over save_format.
save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format when savefile is truthy. JSON preserves nested structure; CSV is flattened into a tall table.
For JSON, the function returns the metrics dictionary.
For CSV, the function returns the metrics DataFrame.
time_as_str (bool, default True) – If True, time keys in the result dictionary are converted to strings (useful for JSON serialization). If there is only a single time value, the result is flattened and the time key is omitted.
verbose (int, default 1) – Verbosity level passed to vlog().
logger (logging.Logger, optional) – Optional logger instance used by vlog().
overall_key (str | None)
- Returns:
results – If save_format is JSON (default), returns a dict:
Single time value:
{
    "overall_mae": ...,
    "overall_mse": ...,
    "overall_r2": ...,
    "coverage80": ...,
    "sharpness80": ...,
    "per_horizon_mae": {1: ..., 2: ..., ...},
    ...
}
Multiple time values:
{
    "2021": { ...metrics... },
    "2022": { ...metrics... },
}
If save_format is CSV, returns a DataFrame with flattened rows:
Columns include: coord_t, metric, horizon, and value.
- Return type:
Notes
Default metrics in quantile mode:
overall_mae, overall_mse, overall_r2
coverage80 and sharpness80 (using the requested interval, e.g., Q10–Q90)
If per_horizon=True, also:
per_horizon_mae, per_horizon_mse, per_horizon_r2 (each a mapping from horizon index to score).
Default metrics in point mode (no quantiles):
mae, mse, r2
And optionally, if per_horizon=True:
per_horizon_mae, per_horizon_mse, per_horizon_r2.
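To make the default quantile-mode metrics concrete, the following sketch computes them for a handful of rows in pure Python. It rests on assumptions: coverage is taken as the fraction of actuals inside [Q10, Q90], sharpness as the mean interval width, and point metrics are scored against the Q50 column. The name interval_metrics_sketch is hypothetical and the library may define or aggregate these quantities differently.

```python
def interval_metrics_sketch(actual, q_lo, q_med, q_hi):
    """Compute MAE, MSE, R^2, coverage and sharpness for one table."""
    n = len(actual)
    errors = [a - m for a, m in zip(actual, q_med)]
    mae = sum(abs(e) for e in errors) / n
    mse = sum(e * e for e in errors) / n
    mean_actual = sum(actual) / n
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    r2 = 1.0 - sum(e * e for e in errors) / ss_tot
    # Coverage: share of actuals inside the interval (assumed definition).
    cov = sum(lo <= a <= hi for a, lo, hi in zip(actual, q_lo, q_hi)) / n
    # Sharpness: mean interval width (assumed definition).
    sharp = sum(hi - lo for lo, hi in zip(q_lo, q_hi)) / n
    return {
        "overall_mae": mae,
        "overall_mse": mse,
        "overall_r2": r2,
        "coverage80": cov,
        "sharpness80": sharp,
    }


metrics = interval_metrics_sketch(
    actual=[0.4, 0.7, 0.5, 0.9],
    q_lo=[0.3, 0.5, 0.4, 0.6],
    q_med=[0.4, 0.7, 0.5, 0.8],
    q_hi=[0.5, 0.9, 0.6, 1.0],
)
```

Coverage and sharpness are best read together: an interval can reach full coverage simply by being very wide, so a lower sharpness value at the same coverage indicates a more informative forecast.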
- geoprior.utils.format_and_forecast(y_pred, y_true, *, coords=None, quantiles=None, target_name='subsidence', output_target_name=None, scaler_target_name=None, target_key_pred='subs_pred', component_index=0, scaler_info=None, coord_scaler=None, coord_columns=('coord_t', 'coord_x', 'coord_y'), train_end_time=None, forecast_start_time=None, forecast_horizon=None, future_time_grid=None, eval_forecast_step=None, eval_export='all', value_mode='rate', input_value_mode='rate', rate_first='cum_over_dtref', absolute_baseline=None, sample_index_offset=0, city_name=None, model_name=None, dataset_name=None, csv_eval_path=None, csv_future_path=None, time_as_datetime=False, time_format=None, calibration=False, calibration_kwargs=None, calibration_save_stats=None, eval_metrics=False, metrics_column_map=None, metrics_quantile_interval=(0.1, 0.9), metrics_per_horizon=False, metrics_extra=None, metrics_extra_kwargs=None, metrics_savefile=None, metrics_save_format='.json', metrics_time_as_str=True, output_unit=None, output_unit_from='m', output_unit_mode='overwrite', output_unit_suffix='_mm', output_unit_col=None, verbose=1, logger=None, **kws)[source]
Format PINN forecasts into evaluation and future DataFrames.
This helper takes the raw model outputs (already split into y_pred['subs_pred'] / y_pred['gwl_pred']), the matching ground-truth dictionary (y_true), and optional coordinate and scaler information, and returns two DataFrames:
df_eval: predictions + actuals for an evaluation year (typically the last training year, e.g. 2022).
df_future: predictions for the future horizon (e.g. 2023–2025), without actuals.
- Parameters:
y_pred (dict) – Dictionary of model predictions, as returned by GeoPriorSubsNet.predict post-processed into {'subs_pred': ..., 'gwl_pred': ...}.
For subsidence, the expected shapes are:
Quantile mode: (B, H, Q, O), where B = batch size, H = horizon steps, Q = number of quantiles, O = output dim.
Point mode: (B, H, O).
y_true – Dictionary of true targets, typically {'subsidence': ..., 'gwl': ...} or {'subs_pred': ..., 'gwl_pred': ...}. If None, the evaluation DataFrame is still created but without the actual-value column.
coords (ndarray, optional) – Optional coordinates array aligned with predictions. Commonly shaped (B, H, 3) with columns [t_scaled, x_scaled, y_scaled]. Only x and y are used when inverse-transforming spatial coordinates; time is overwritten by the provided temporal config if given.
quantiles (list of float or None, optional) – List of quantiles (e.g. [0.1, 0.5, 0.9]) if the model was trained in probabilistic mode. If None, a single prediction column is emitted instead.
target_name (str, default 'subsidence') – Logical target identifier used as the default key for locating the target scaler in scaler_info and as a fallback for resolving truth arrays in y_true.
Column naming is controlled by output_target_name (or the auto-derived output prefix when it is None).
output_target_name (str or None, optional) – Output prefix used when creating DataFrame columns for predictions and actuals.
This controls the column naming only (e.g. the function will emit f"{output_target_name}_q10", f"{output_target_name}_pred", and f"{output_target_name}_actual").
If None (default), the function derives the output prefix from target_name and applies a small convenience rule: if target_name ends with "_cum" or "_cumulative", that suffix is stripped for output naming.
This keeps downstream tooling consistent (many plotting and metrics utilities expect names like subsidence_q10 rather than subsidence_cum_q10), while still allowing the scaler lookup to use the true target key. For example, with target_name="subsidence_cum" and output_target_name=None, output columns become subsidence_q10, subsidence_q50, and subsidence_actual. If output_target_name="subsidence_cum", the output columns keep the suffix, such as subsidence_cum_q10.
scaler_target_name (str or None, optional) – Name used to locate the target scaling block inside scaler_info and to perform inverse-transform for predictions and actuals.
This controls the scaler key and inverse scaling, not the output column naming.
If None (default), the scaler key is assumed to be target_name. This is important when you want clean output columns but the scaler was fitted/stored under the original target name.
A common pattern is to keep target_name="subsidence_cum" so the scaler lookup matches the Stage-1 schema, while letting output_target_name=None produce clean output columns. In that setup, inverse transform still uses the subsidence_cum scaler key, while output columns use the subsidence_ prefix because of the auto-strip rule.
target_key_pred (str, default 'subs_pred') – Key inside y_pred that holds the subsidence forecasts.
component_index (int, default 0) – Index along the output dimension O to use when output_subsidence_dim > 1. For scalar subsidence this is 0.
scaler_info (dict, optional) – Optional Stage-1 scaler_info mapping containing a target scaler under keys such as 'targets' or 'target'. The target block is expected to provide an sklearn-like transformer under 'scaler' together with column names under 'columns' or 'cols'. If present and consistent, subsidence values (predicted and actual) are inverse-transformed for target_name.
coord_scaler (object, optional) – Optional scaler used for coordinates. If provided, it is only used to inverse-transform coord_x and coord_y when coords is given and coord_columns can be matched. Time is not taken from the inverse transform; it is controlled by the temporal config.
coord_columns (tuple of str, default ('coord_t', 'coord_x', 'coord_y')) – Logical names of the time, x, and y coordinate columns. These are used for DataFrame column naming and for mapping into coord_scaler if its block carries column names.
train_end_time (scalar or str or datetime, optional) – Physical time associated with the evaluation year (e.g. 2022). If eval_forecast_step is not given, the last horizon step is assumed to correspond to this time.
forecast_start_time (scalar or str or datetime, optional) – First time in the future forecast horizon (e.g. 2023).
forecast_horizon (int, optional) – Number of forecast steps in the future horizon (e.g. 3). If future_time_grid is not given, this is used together with forecast_start_time to build a regular grid.
future_time_grid (array-like, optional) – Explicit physical times for each forecast step, length H. For yearly data this might be [2023, 2024, 2025]. If provided, it overrides any automatic construction from forecast_start_time and forecast_horizon.
eval_forecast_step (int or None, optional) – Horizon step index (1-based) to use for evaluation. If None, defaults to the last horizon step H.
eval_export ({"all", "last"} or str or int or sequence, optional) – Controls which evaluation rows are exported in df_eval and written to csv_eval_path. By default ("all"), the function exports the multi-horizon evaluation DataFrame (df_eval_all), which contains one row per sample and forecast step (e.g. years 2020, 2021, 2022 for H=3).
Accepted values are:
"all" or "full" or "horizons": export all horizons from df_eval_all.
"last" or "single" or "default": export only the single evaluation step specified by eval_forecast_step (backwards-compatible behaviour).
Other str (e.g. "2022"): interpreted as a time value for coord_t; only rows of df_eval_all whose time column matches this value are exported.
int or scalar non-string: interpreted as a single time value (e.g. 2022).
sequence of values (e.g. [2021, 2022]): interpreted as a set of time values; only rows whose coord_t belongs to this set are exported.
If time_as_datetime=True, the selection values are converted with pandas.to_datetime using time_format before filtering. If df_eval_all is not available (e.g. no ground truth was provided), the function falls back to exporting the single-step df_eval regardless of eval_export.
value_mode ({"rate", "cumulative", "absolute_cumulative"}, optional) – Controls how forecast values are interpreted along the temporal horizon for each sample. The default is "rate", which treats each forecast step as an incremental rate (e.g. annual subsidence rate) and leaves predictions unchanged.
Supported modes are:
"rate": keep per-step predictions as provided by the model (current behaviour).
"cumulative" or "cum": convert per-step rates into relative cumulative values by applying a cumulative sum over forecast_step for each sample_idx. For example, for years 2023–2025, the value at 2024 is the sum of the 2023 and 2024 rates.
"absolute_cumulative" or "abs_cum" or "absolute": same as "cumulative", then add an absolute baseline provided by absolute_baseline (e.g. cumulative subsidence at the end of the training period), yielding absolute cumulative trajectories.
Cumulative transforms are applied consistently to:
the future forecast DataFrame (df_future),
the multi-horizon evaluation DataFrame (df_eval_all),
and the single-step evaluation DataFrame (df_eval, which is regenerated from df_eval_all after the transformation).
When an unsupported string is given, the function logs a warning and falls back to "rate".
absolute_baseline (float or Mapping[int, float], optional) – Baseline value to use when value_mode requests absolute cumulative outputs ("absolute_cumulative", "abs_cum", "absolute"). This baseline is interpreted as the pre-forecast cumulative level for each sample, for example, cumulative subsidence at train_end_time (e.g. end of 2022), and is added after applying the cumulative sum over the forecast horizon.
If a scalar float is provided, the same baseline value is added to all samples. If a mapping is provided, it must map sample_idx (integers) to baseline values, allowing per-sample baselines: absolute_baseline = {sample_idx: baseline_value, ...}
Only prediction columns for target_name are shifted (e.g. "subsidence_q10", "subsidence_q50", "subsidence_q90" or "subsidence_pred"). When df_eval_all is present, the corresponding "<target_name>_actual" column is shifted as well, so evaluation metrics operate on absolute cumulative values.
If value_mode is an absolute cumulative variant but absolute_baseline is None, the function logs a warning and degrades gracefully to relative cumulative mode (i.e. no baseline shift is applied).
sample_index_offset (int, default 0) – Offset added to sample_idx (useful when concatenating multiple tiles).
city_name (str, optional) – Optional metadata used only for logging.
model_name (str, optional) – Optional metadata used only for logging.
dataset_name (str, optional) – Optional metadata used only for logging.
csv_eval_path (str, optional) – If provided, df_eval is written to this path (directories are created if needed).
csv_future_path (str, optional) – If provided, df_future is written to this path.
time_as_datetime (bool, default False) – If True, time values are converted using pandas.to_datetime() with the provided time_format (if any).
time_format (str or None, optional) – Optional format string passed to pandas.to_datetime() when time_as_datetime=True.
eval_metrics (bool, default False) – If True, automatically call evaluate_forecast() on the resulting df_eval to compute diagnostics. Metrics are not returned by this function; they are either written to disk (if metrics_savefile is provided) or discarded. For programmatic access to the metrics dictionary, call evaluate_forecast() directly.
metrics_column_map (mapping, optional) – Optional column mapping forwarded to evaluate_forecast() (see its documentation for details). If None, default column names such as 'coord_t', 'forecast_step', f'{target_name}_q10', and f'{target_name}_actual' are assumed.
metrics_quantile_interval (tuple of float, default (0.1, 0.9)) – Interval used for coverage and sharpness diagnostics in quantile mode, forwarded to evaluate_forecast().
metrics_per_horizon (bool, default False) – If True, per-horizon MAE/MSE/R² are computed by evaluate_forecast() and included in the diagnostics.
metrics_extra (sequence or mapping, optional) – Optional additional metrics to compute, forwarded to evaluate_forecast(). Can be:
A sequence of metric names (resolved via geoprior.metrics._registry.get_metric).
A mapping {name: func} where func is a callable taking (y_true, y_pred, **kwargs).
metrics_extra_kwargs (mapping, optional) – Optional per-metric keyword arguments, forwarded to evaluate_forecast(). Keys must match metric names in metrics_extra.
metrics_savefile (str, path-like, bool, or None) – If truthy, diagnostics from evaluate_forecast() are written to disk. Behavior matches the savefile argument of evaluate_forecast(). When True, a filename is auto-generated near the evaluation CSV (if any) or in the current working directory.
metrics_save_format ({'.json', 'json', '.csv', 'csv'}, default '.json') – Output format for diagnostics written by evaluate_forecast(). JSON preserves the nested metric structure; CSV flattens it into a tall table.
metrics_time_as_str (bool, default True) – If True, time keys in the diagnostics written by evaluate_forecast() are converted to strings (useful for JSON serialization).
verbose (int, default 1) – Verbosity level passed to vlog().
logger (logging.Logger, optional) – Logger instance; if None, a module-level LOG is used.
input_value_mode (str)
rate_first (str)
output_unit (str | None)
output_unit_from (str)
output_unit_mode (str)
output_unit_suffix (str)
output_unit_col (str | None)
- Returns:
df_eval_to_write (pandas.DataFrame) – DataFrame containing predictions and actuals for the evaluation time. Columns include:
'sample_idx'
'forecast_step'
quantile columns (e.g. subsidence_q10) or subsidence_pred
'subsidence_actual' (if y_true given)
coord_t, coord_x, coord_y (names from coord_columns).
df_future (pandas.DataFrame) – DataFrame containing predictions for the future horizon, without actuals. Same structure as df_eval but without the actual-value column.
- Return type:
Notes
This function separates scaler lookup (scaler_target_name) from output column naming (output_target_name). This is useful when the stored scaler key contains suffixes like "_cum" but downstream tools expect canonical names such as columns prefixed with subsidence_.
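The naming rule and the cumulative modes described for output_target_name, value_mode, and absolute_baseline can be sketched in plain Python. This is an illustrative sketch, not the library's implementation: derive_output_prefix and to_cumulative are hypothetical helper names, and a dict of per-sample rate lists stands in for the forecast DataFrame.

```python
from itertools import accumulate


def derive_output_prefix(target_name, output_target_name=None):
    """Return the column prefix; an explicit output_target_name always wins."""
    if output_target_name is not None:
        return output_target_name
    # Auto-strip rule: drop a trailing cumulative suffix for column naming.
    for suffix in ("_cumulative", "_cum"):
        if target_name.endswith(suffix):
            return target_name[: -len(suffix)]
    return target_name


def to_cumulative(rates_by_sample, baseline=None):
    """Cumulative-sum per-step rates for each sample, then add a baseline.

    baseline may be None (relative cumulative), a scalar shared by all
    samples, or a mapping from sample index to baseline value.
    """
    result = {}
    for sample_idx, rates in rates_by_sample.items():
        if baseline is None:
            base = 0.0
        elif isinstance(baseline, dict):
            base = baseline[sample_idx]
        else:
            base = float(baseline)
        result[sample_idx] = [base + total for total in accumulate(rates)]
    return result


# Hypothetical per-step rates for two samples over a three-step horizon.
rates = {0: [1.0, 2.0, 3.0], 1: [0.5, 0.5, 0.5]}
relative = to_cumulative(rates)
absolute = to_cumulative(rates, baseline={0: 10.0, 1: 20.0})

print(derive_output_prefix("subsidence_cum"))  # prefix used for columns
print(relative[0], absolute[0])
```

Keeping the prefix derivation separate from the scaler key mirrors the split between output_target_name and scaler_target_name: the scaler lookup can keep using "subsidence_cum" while columns are emitted with the stripped prefix.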
- geoprior.utils.create_spatial_clusters(df, spatial_cols=None, cluster_col='region', n_clusters=None, algorithm='kmeans', view=True, figsize=(14, 10), s=60, plot_style='seaborn', cmap='tab20', show_grid=True, grid_props=None, auto_scale=True, savefile=None, verbose=1, **kwargs)[source]
Cluster 2D spatial data in df using the chosen algorithm and optionally plot the results.
This function extracts two coordinate columns from df to form clusters via methods such as 'kmeans', 'dbscan', or 'agglo' (agglomerative). It uses filter_valid_kwargs (when relevant) to strip out invalid parameters for certain estimators, and writes cluster labels into cluster_col.
- Parameters:
df (pandas.DataFrame) – Input DataFrame holding spatial coordinates and optional other fields.
spatial_cols (list of str, optional) – Two-column list for x and y coordinates. Defaults to ['longitude', 'latitude'] if None.
cluster_col (str, default 'region') – Name of the column to store the assigned cluster labels.
n_clusters (int, optional) – Number of clusters to form. If not provided for KMeans, it is auto-detected. For DBSCAN or Agglomerative, a warning is issued if not set.
algorithm (str, default 'kmeans') – Choice of clustering algorithm among ['kmeans', 'dbscan', 'agglo'].
view (bool, default True) – If True, displays a scatterplot of the final clusters.
figsize (tuple, default (14, 10)) – Size of the displayed figure for the cluster plot.
s (int, default 60) – Marker size in the scatterplot.
plot_style (str, default 'seaborn') – Matplotlib style used for the plot.
cmap (str, default 'tab20') – Colormap name used to differentiate clusters.
show_grid (bool, default True) – Toggles grid lines on or off.
grid_props (dict, optional) – Additional keyword arguments controlling the grid style.
auto_scale (bool, default True) – If True, standardize coordinates before clustering.
savefile (str, optional) – File path to save the data with an additional cluster_col storing the assigned cluster labels if desired.
verbose (int, default 1) – Controls console logs. Higher values yield more details about scaling and cluster detection.
**kwargs – Additional keyword arguments passed to the chosen algorithm (filtered by filter_valid_kwargs for KMeans, DBSCAN, AgglomerativeClustering).
- Returns:
A copy of df with an additional cluster_col storing the assigned cluster labels.
- Return type:
Notes
If auto_scale is True, it uses a standard scaler to normalize the coordinate columns. The scatterplot is generated using the seaborn library for enhanced styling.
By default, for algorithm='kmeans', the model attempts to minimize:
(43) \[J = \sum_{i=1}^{N} \min_{\mu_j} \lVert x_i - \mu_j \rVert^2\]
where \(x_i\) are the scaled or raw 2D coordinates in df and \(\mu_j\) are the cluster centroids. The function can optionally auto-detect n_clusters using a silhouette and elbow analysis if not provided.
Examples
>>> from geoprior.utils.spatial_utils import create_spatial_clusters
>>> import pandas as pd
>>> df = pd.DataFrame({
...     "longitude": [0.1, 0.2, 2.2, 2.3],
...     "latitude": [1.0, 1.1, 2.1, 2.2]
... })
>>> # KMeans with auto scale and auto-detect k
>>> result = create_spatial_clusters(
...     df=df,
...     algorithm="kmeans",
...     view=True
... )
>>> # DBSCAN with custom arguments
>>> result_db = create_spatial_clusters(
...     df=df,
...     algorithm="dbscan",
...     eps=0.5,
...     min_samples=2
... )
See also
filter_valid_kwargs: Helps discard unsupported keyword arguments for chosen estimators.
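To make the k-means objective J from the Notes concrete, the following sketch evaluates it for fixed candidate centroids on the four points used in the example above. It is only an illustration of the formula: kmeans_objective is a hypothetical helper, the centroids are hand-picked, and no clustering or auto-detection is performed.

```python
def kmeans_objective(points, centroids):
    """Sum over points of the squared distance to the nearest centroid."""
    total = 0.0
    for px, py in points:
        total += min((px - cx) ** 2 + (py - cy) ** 2 for cx, cy in centroids)
    return total


# Same (longitude, latitude) pairs as in the create_spatial_clusters example.
points = [(0.1, 1.0), (0.2, 1.1), (2.2, 2.1), (2.3, 2.2)]

# Hand-picked centroids: one per visible group, versus a single global mean.
j_two = kmeans_objective(points, [(0.15, 1.05), (2.25, 2.15)])
j_one = kmeans_objective(points, [(1.2, 1.6)])

print(j_two, j_one)  # two well-placed centroids give a much smaller J
```

The drop in J when going from one to two centroids is the kind of signal an elbow analysis looks for when n_clusters is not supplied.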
- geoprior.utils.spatial_sampling(data, sample_size=0.01, stratify_by=None, spatial_bins=10, spatial_cols=None, method='abs', min_relative_ratio=0.01, random_state=42, savefile=None, verbose=1)[source]
Sample spatial data intelligently to represent the distribution of the whole area and include different years.
This function performs stratified sampling on spatial data, ensuring that the sample reflects both spatial distribution and temporal aspects of the entire dataset. It combines spatial stratification based on coordinates and additional stratification columns specified by the user.
- Parameters:
data (pandas.DataFrame) – The input DataFrame to sample from. Must contain spatial coordinate columns (e.g., 'longitude', 'latitude') and any columns specified in stratify_by.
sample_size (float or int, optional) – The proportion or absolute number of samples to select. If float, should be between 0.0 and 1.0 and represents the fraction of the dataset to include in the sample. If int, represents the absolute number of samples to select. Default is 0.01 (1% of the data).
stratify_by (list of str, optional) – List of column names to stratify by.
spatial_bins (int or tuple/list of int, optional) – Number of bins to divide the spatial coordinates into. If an integer, the same number of bins is used for all spatial dimensions. If a tuple or list, its length must match the number of spatial columns, specifying the number of bins for each spatial dimension. Default is 10.
spatial_cols (list or tuple of str, optional) – List of spatial coordinate column names. Can accept one or two columns. If None, the function checks for columns named 'longitude' and/or 'latitude' in data. If only one spatial column is provided or found, a warning is issued, suggesting that providing both spatial columns is recommended for more accurate sampling. If more than two columns are provided, an error is raised.
method (str, {'abs', 'relative'}, default 'abs') – Defines how the sample size is determined. 'abs' or 'absolute' uses a fixed sampling proportion based on sample_size. 'relative' scales sampling by dataset stratification so small groups still receive a proportional sample controlled by min_relative_ratio.
min_relative_ratio (float, default 0.01) – Controls the minimum allowable fraction of records that must be sampled when method='relative'. It must be between 0 and 1. For example, min_relative_ratio=0.05 requests at least 5 percent of the total dataset size from each stratification group when possible; if a group is smaller than that minimum, the entire subset is sampled instead.
random_state (int, optional) – Random seed for reproducibility. Default is 42.
verbose (int, default 1) – Controls progress-bar and status output during execution. Larger values produce more detailed messages.
- Returns:
sampled_data – A sampled DataFrame representing the distribution of the whole area and including different years.
- Return type:
Notes
The function performs stratified sampling based on spatial bins and other specified stratification columns. Spatial coordinates are binned using quantile-based discretization (pandas.qcut()), ensuring each bin has approximately the same number of observations.
Let \(N\) be the total number of samples in data, and \(n\) be the desired sample size. The function calculates the number of samples to draw from each stratification group based on the proportion of the group size to the total dataset size:
(44) \[n_i = \left\lceil \frac{N_i}{N} \times n \right\rceil\]
where \(N_i\) is the size of group \(i\), and \(n_i\) is the number of samples to draw from group \(i\).
The function ensures that all specified spatial and stratification columns exist in data, that the number of spatial bins matches the number of spatial columns, and that the sample size is valid. A warning is issued when only one spatial column is used because two spatial columns usually give more reliable spatial sampling.
Examples
>>> from geoprior.utils.spatial_utils import spatial_sampling
>>> import pandas as pd
>>> # Assume 'df' is a pandas DataFrame with columns
>>> # 'longitude', 'latitude', 'year', and other data.
>>> sampled_df = spatial_sampling(
...     data=df,
...     sample_size=0.05,
...     stratify_by=['year', 'geological_category'],
...     spatial_bins=(10, 15),
...     spatial_cols=['longitude', 'latitude'],
...     random_state=42
... )
>>> print(sampled_df.shape)
See also
pandas.qcut: Quantile-based discretization function used for binning.
sklearn.model_selection.StratifiedShuffleSplit: For stratified sampling.
batch_spatial_sampling: Resample spatial data with batching.
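The proportional allocation rule in the Notes can be checked by hand with a few lines of Python. The sketch below only reproduces the formula for n_i; allocate_samples and the group names are hypothetical, and the binning and random draw themselves are not shown.

```python
import math


def allocate_samples(group_sizes, n):
    """Apply n_i = ceil(N_i / N * n) to every stratification group."""
    total = sum(group_sizes.values())
    return {
        group: math.ceil(size * n / total)
        for group, size in group_sizes.items()
    }


# Hypothetical group sizes (N = 1000) and a desired sample size n = 50.
group_sizes = {"bin_a": 500, "bin_b": 300, "bin_c": 200}
allocation = allocate_samples(group_sizes, n=50)
print(allocation)

# Because of the ceiling, the per-group counts can add up to more than n.
uneven = allocate_samples({"a": 334, "b": 333, "c": 333}, n=10)
print(uneven, sum(uneven.values()))
```

The second call shows why the realized sample can be slightly larger than the requested size when group proportions do not divide evenly.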
- geoprior.utils.make_txy_coords(t, x, y, *, time_shift='min', xy_shift='min', time_shift_value=None, x_shift_value=None, y_shift_value=None, dtype='float32')[source]
Build coords tensor (t, x, y) with OPTIONAL shifting (translation only).
- This is designed for your “not normalized” workflow:
You keep SI units (years and meters),
but you avoid feeding huge UTM magnitudes (e.g. 3e5, 2.5e6) into coord MLPs by shifting x,y (and optionally t).
Notes
This does NOT min-max scale to [0,1]. It only translates.
Returning coord_mins/coord_ranges is still useful for logging/debug.
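A minimal sketch of translation-only shifting, assuming that the 'min' mode subtracts the per-axis minimum (the reference does not state the exact rule). shift_to_min is a hypothetical helper and plain lists stand in for the coordinate arrays.

```python
def shift_to_min(values):
    """Translate values so the minimum becomes zero; spacing is unchanged."""
    origin = min(values)
    return [v - origin for v in values], origin


# Hypothetical SI coordinates: years and UTM-like metres.
t = [2020.0, 2021.0, 2022.0]
x = [300000.0, 300250.0, 300500.0]
y = [2500000.0, 2500100.0, 2500300.0]

t_shifted, t_min = shift_to_min(t)
x_shifted, x_min = shift_to_min(x)
y_shifted, y_min = shift_to_min(y)

# Units are still years and metres; only the origin has moved.
print(t_shifted, x_shifted, y_shifted)
print(t_min, x_min, y_min)
```

Because nothing is rescaled, differences between coordinates keep their physical meaning, which is the property the translation-only design relies on.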
Model-facing utility package#
The model-side utility package complements the workflow surface and is closer to data layout, model I/O contracts, sequence generation, and PINN-oriented helper logic.
Public exports for model utility helpers.
- geoprior.models.utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]#
Compute anomaly scores for given true targets using various methods.
This utility function, compute_anomaly_scores, provides a flexible approach to compute anomaly scores outside the XTFT model itself. Anomaly scores serve as indicators of how unusual certain observations are, guiding the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.
- Parameters:
y_true (np.ndarray) – The ground truth target values with shape (B, H, O), where:
B: batch size
H: number of forecast horizons (time steps ahead)
O: output dimension (e.g., number of target variables).
Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.
y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to 'residual', the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.
method (str, optional) – The method used to compute anomaly scores. Supported options:
"statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores.
Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:
(45) \[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\]
where \(\varepsilon\) is a small constant for numerical stability.
"domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.
"isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.
"residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:
(46) \[\text{MSE: } (y_{true} - y_{pred})^2\]
Default is "statistical".
threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.
domain_func (callable, optional) – A user-defined function for the domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the default heuristic is:
(47) \[\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\]
contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.
epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.
estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.
random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.
residual_metric (str, optional) – The metric used to compute anomalies from residuals if method is set to 'residual'. Supported metrics:
"mse": mean squared error per point (residuals**2)
"mae": mean absolute error per point |residuals|
"rmse": root mean squared error sqrt((residuals**2))
Default is "mse".
objective (str, optional) – Specifies the type of objective, for future extensibility. Default is "ts" indicating time series. Could be extended for other tasks in the future.
verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.
- Returns:
anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.
- Return type:
np.ndarray
Notes
Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
Examples
>>> from geoprior.models.utils import compute_anomaly_scores
>>> import numpy as np
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B, H, O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y) * 5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain', domain_func=my_heuristic)
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest', contamination=0.1)
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual', residual_metric='mae')
See also
geoprior.models.losses.objective_loss: For integrating anomaly scores into a multi-objective loss.
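The statistical score and the default domain heuristic documented above are simple enough to reproduce by hand. The sketch below works on a flat list instead of a (B, H, O) array and assumes the population standard deviation; statistical_scores and default_domain_scores are hypothetical names, not the library functions.

```python
import statistics


def statistical_scores(values, epsilon=1e-6):
    """Squared z-score per point: ((y - mu) / (sigma + epsilon)) ** 2."""
    mu = statistics.fmean(values)
    sigma = statistics.pstdev(values)  # assumption: population std (ddof=0)
    return [((v - mu) / (sigma + epsilon)) ** 2 for v in values]


def default_domain_scores(values):
    """Default heuristic: negative values score |y| * 10, others score 0."""
    return [abs(v) * 10 if v < 0 else 0.0 for v in values]


# Hypothetical series with one clear outlier at the end.
series = [1.0, 1.0, 1.0, 1.0, 6.0]
stat_scores = statistical_scores(series)
domain_scores = default_domain_scores([-2.0, 0.0, 3.0])

print(stat_scores)    # the last point has by far the largest score
print(domain_scores)  # only the negative value is flagged
```

Note that the scores are continuous; as the threshold parameter description says, no mask is applied, so any cut-off is left to the caller.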
- geoprior.models.utils.compute_forecast_horizon(data=None, dt_col=None, start_pred=None, end_pred=None, error='raise', verbose=1)[source]#
Compute the forecast horizon for time series forecasting models.
This function calculates the number of future time steps (forecast_horizon) a model should predict based on the provided data or specified prediction dates. It intelligently infers the frequency of the data and computes the horizon accordingly. The function accommodates various datetime formats and handles different input scenarios robustly.
- Parameters:
data (pandas.DataFrame, pandas.Series, list, or numpy.ndarray, optional) – The dataset containing datetime information. If a pandas.DataFrame is provided, the dt_col parameter must be specified to indicate which column contains the datetime data. For pandas.Series, list, or numpy.ndarray, the function attempts to infer the frequency directly.
dt_col (str, optional) – The name of the column in data that contains datetime information. This parameter is required if data is a pandas.DataFrame. Example: dt_col='timestamp'
start_pred (str, int, or datetime-like) – The starting point for forecasting. This can be a date string (e.g., '2023-04-10'), a datetime object, or an integer representing a year (e.g., 2024). If an integer is provided, it is interpreted as a year, and a warning is issued to inform the user of this interpretation.
end_pred (str, int, or datetime-like) – The ending point for forecasting. Similar to start_pred, this can be a date string, a datetime object, or an integer representing a year. The function calculates the forecast horizon based on the difference between start_pred and end_pred.
error ({'raise', 'warn', 'ignore'}, default 'raise') – Defines the error handling behavior when encountering issues such as invalid input types, missing date columns, or unparseable dates.
'raise': Raises a ValueError when an error is encountered.
'warn': Emits a warning and attempts to proceed with default behavior.
verbose (int, default 1) – Controls the level of verbosity for debug information.
0: No output.
1: Minimal output (e.g., starting message).
2: Intermediate output (e.g., detected dates, computed horizons).
3: Detailed output (e.g., types of predictions, inferred frequencies).
- Returns:
The computed forecast_horizon representing the number of steps ahead the model should predict. Returns None if an error occurs and error is set to ‘warn’.
- Return type:
- Raises:
ValueError – If invalid parameters are provided and error is set to ‘raise’.
Examples
>>> from geoprior.models.utils import compute_forecast_horizon
>>> import pandas as pd
>>> import numpy as np
>>> from datetime import datetime, timedelta
>>>
>>> # Example 1: Using a DataFrame with a Date Column
>>> df = pd.DataFrame({
...     'date': pd.date_range(start='2023-01-01', periods=100, freq='D'),
...     'value': np.random.randn(100)
... })
>>> horizon = compute_forecast_horizon(
...     data=df,
...     dt_col='date',
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=3
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 2: Using a List of Datetimes
>>> dates = [datetime(2023, 1, 1) + timedelta(days=i) for i in range(100)]
>>> horizon = compute_forecast_horizon(
...     data=dates,
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='warn',
...     verbose=2
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 3: Handling Integer Years
>>> horizon = compute_forecast_horizon(
...     start_pred=2024,
...     end_pred=2030,
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 7
>>> # Example 4: Without Providing Data (Assuming Frequency Based on Prediction Dates)
>>> horizon = compute_forecast_horizon(
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
Notes
When data is not provided, the function relies solely on the difference between start_pred and end_pred to compute the forecast horizon. In such cases, if the frequency cannot be inferred, the horizon is calculated based on the largest possible time unit (years, months, weeks, days).
If start_pred is after end_pred, the function returns 0 and issues a warning or raises an error based on the error parameter.
The function attempts to infer the frequency of the data using pandas utilities. If the frequency cannot be inferred, it defaults to calculating the horizon based on the time difference in the most significant unit.
See also
pandas.date_range: Generates a fixed-frequency DatetimeIndex.
pandas.infer_freq: Infers the frequency of a DatetimeIndex.
- geoprior.models.utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)[source]#
Create input sequences and corresponding targets for time series forecasting.
The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.
See more in User Guide.
- Parameters:
df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.
sequence_length (int) – The number of past time steps to include in each input sequence.
target_col (str) – The name of the target column.
step (int, default 1) – The step size between the starts of consecutive sequences.
include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.
drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.
forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.
verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).
- Returns:
- A tuple containing:
sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).
- targets:
If forecast_horizon is None: Array of target values with shape (num_sequences,).
If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If the DataFrame df does not contain the target_col.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences

>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })

>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)

>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=forecast_horizon
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)
Notes
Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.
- Forecast Horizon:
If forecast_horizon is None, the function creates targets for a single future time step.
If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.
Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.
Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.
Data Validation: The function uses are_all_frames_valid from geoprior.core.checks to validate the input DataFrame before processing, and exist_features to verify that the target column is present.
The sequences generation can be expressed as:
(48) \[\begin{split}
&\text{For each sequence } i: \\
&\mathbf{X}^{(i)} = \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right] \\
&y^{(i)} = \begin{cases} \mathbf{x}_{i+T} & \text{if } \text{forecast\_horizon} = \text{None} \\ \left[ \mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1} \right] & \text{if } \text{forecast\_horizon} = H \end{cases}
\end{split}\]
- Where:
\(\mathbf{X}^{(i)}\) is the input sequence of length \(T\).
\(y^{(i)}\) is the target value(s) following the sequence.
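The windowing formula can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the library's implementation: `sliding_windows` is a hypothetical helper, the target is read from a separate column array, and the number of windows follows the formula literally, so it may differ from how create_sequences treats the final rows.

```python
import numpy as np


def sliding_windows(values, target, T, H=None, step=1):
    """Build (X, y) pairs following the window formula.

    Hypothetical helper; boundary handling in geoprior may differ.
    """
    horizon = 1 if H is None else H
    X, y = [], []
    # The last valid start index leaves room for T inputs plus the targets.
    for i in range(0, len(values) - T - horizon + 1, step):
        X.append(values[i:i + T])
        future = target[i + T:i + T + horizon]
        y.append(future[0] if H is None else future)
    return np.asarray(X), np.asarray(y)


values = np.arange(20, dtype=float).reshape(10, 2)  # 10 steps, 2 features
target = values[:, 1]                               # treat column 1 as target

X1, y1 = sliding_windows(values, target, T=4)       # single-step targets
X3, y3 = sliding_windows(values, target, T=4, H=3)  # three-step targets
print(X1.shape, y1.shape, X3.shape, y3.shape)
```

With 10 rows and T = 4, the literal formula gives 6 single-step windows and 4 three-step windows; raising `step` thins the windows and reduces their overlap.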
See also
geoprior.models.utils.split_static_dynamic: Function to split sequences into static and dynamic inputs.
- geoprior.models.utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]#
Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.
- Parameters:
dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.
num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.
agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.
errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.
- Returns:
If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.
- Return type:
Union[List[Tuple[Any, ...]], Optional[Tuple[Any, ...]]]
- Raises:
TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).
ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).
RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
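The extract-then-aggregate behaviour described above can be sketched without TensorFlow. The helper `take_batches` below is a hypothetical stand-in that works on any iterable of `(x, y)` tuples rather than a real tf.data.Dataset, and it ignores the `errors` handling and the ‘auto’ option.

```python
from itertools import islice

import numpy as np


def take_batches(batches, n=1, agg=False):
    """Take up to n (x, y) batches from an iterable; optionally concatenate.

    Hypothetical stand-in that works on any iterable of tuples, not on a
    real tf.data.Dataset.
    """
    taken = list(batches) if n in ("all", "*") else list(islice(batches, n))
    if not agg:
        return taken
    if not taken:
        # Nothing extracted: aggregated mode returns None.
        return None
    # Concatenate position-wise along the batch axis.
    return tuple(np.concatenate(parts, axis=0) for parts in zip(*taken))


fake_dataset = [(np.ones((2, 3)) * k, np.full((2,), k)) for k in range(4)]

first_two = take_batches(fake_dataset, n=2)
merged = take_batches(fake_dataset, n="all", agg=True)
empty = take_batches([], n=1, agg=True)
print(len(first_two), merged[0].shape, merged[1].shape, empty)
```

Four batches of two samples each merge into arrays with eight rows, while an empty input yields `None` in aggregated mode, matching the return contract stated above.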
- geoprior.models.utils.extract_callbacks_from(fit_params, return_fit_params=False)[source]#
Extract Keras callbacks from a dictionary of fit parameters.
The function scans the provided fit_params dictionary, looks for keys associated with callback instances, and removes them from fit_params, returning a list of callbacks. Optionally, it can return the updated dictionary without these callbacks.
This function is particularly useful when working with scikit-learn-style estimators that pass parameters through **fit_params. By extracting callbacks directly, the user can seamlessly integrate TensorFlow/Keras callbacks such as EarlyStopping or ModelCheckpoint into their training pipelines.
The function can handle two scenarios:
1. A parameter called 'callbacks' containing a list of callbacks.
2. Individual callback instances passed as keyword arguments.
After extraction, if return_fit_params=True, it returns a tuple (callbacks, fit_params) where callbacks is the extracted list and fit_params is the remaining dictionary. Otherwise, it returns only callbacks.
(49) \[n_{\mathrm{callbacks}} = n_{\mathrm{list\_callbacks}} \;+\; n_{\mathrm{kwarg\_callbacks}}\]
Here, \(n_{\mathrm{callbacks}}\) represents the total number of extracted callbacks, \(n_{\mathrm{list\_callbacks}}\) is the number of callbacks initially found in the 'callbacks' parameter, and \(n_{\mathrm{kwarg\_callbacks}}\) is the number of callbacks discovered among the other keyword arguments.
- Parameters:
fit_params (dict) – The dictionary of parameters to be passed to a model’s training method. May contain one of the following:
'callbacks': a list of callback instances.
Arbitrary keyword arguments that are callback instances.
return_fit_params (bool, optional) – If True, returns a tuple of (callbacks, fit_params) where fit_params no longer contains the extracted callbacks. If False, returns only the callbacks list. Default is False.
- Returns:
callbacks (list of tf.keras.callbacks.Callback) – A list of extracted callback instances.
(callbacks, fit_params) (tuple, only if return_fit_params=True) – A tuple where the first element is a list of extracted callbacks and the second is the updated fit_params dictionary after removing the callbacks.
Examples
>>> from geoprior.models.utils import extract_callbacks_from
>>> from tensorflow.keras.callbacks import EarlyStopping
>>> fit_params = {
...     'epochs': 100,
...     'batch_size': 64,
...     'verbose': 1,
...     'early_stopping': EarlyStopping(patience=5)
... }
>>> callbacks, updated_params = extract_callbacks_from(
...     fit_params, return_fit_params=True)
>>> print(callbacks)
[<keras.src.callbacks.EarlyStopping object at 0x...>]
>>> print(updated_params)
{'epochs': 100, 'batch_size': 64, 'verbose': 1}
Notes
Consider using this function when integrating Keras callbacks within a pipeline or estimator that follows scikit-learn conventions, where parameters are passed as fit_params. This approach enables a clean and modular integration of callbacks into your training loops.
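To make the two extraction scenarios and the count identity concrete, here is a dependency-free sketch. The `Callback` and `EarlyStopping` classes are stand-ins for the Keras classes, and `split_callbacks` is a hypothetical helper; unlike the documented function, it copies the input dictionary instead of modifying it.

```python
class Callback:
    """Stand-in for tf.keras.callbacks.Callback (assumption for this sketch)."""


class EarlyStopping(Callback):
    def __init__(self, patience=0):
        self.patience = patience


def split_callbacks(fit_params):
    """Return (callbacks, remaining_params, (n_list, n_kwarg))."""
    remaining = dict(fit_params)  # work on a copy; the input stays untouched
    # Scenario 1: an explicit 'callbacks' list.
    callbacks = list(remaining.pop("callbacks", None) or [])
    n_list = len(callbacks)
    # Scenario 2: callback instances passed under arbitrary keys.
    for key in [k for k, v in remaining.items() if isinstance(v, Callback)]:
        callbacks.append(remaining.pop(key))
    n_kwarg = len(callbacks) - n_list
    return callbacks, remaining, (n_list, n_kwarg)


params = {
    "epochs": 100,
    "callbacks": [EarlyStopping(patience=3)],
    "early_stopping": EarlyStopping(patience=5),
}
callbacks, remaining, (n_list, n_kwarg) = split_callbacks(params)
print(len(callbacks), n_list, n_kwarg, remaining)  # 2 1 1 {'epochs': 100}
```

The total of two callbacks is the sum of one from the `'callbacks'` list and one found among the other keys, and only non-callback entries remain for the training call.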
See also
tf.keras.callbacks.Callback: Base class used to build new callbacks.
- geoprior.models.utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#
Generate a multi-step forecast using the XTFT model.
This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:
(50) \[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\]
for \(i = 1, \dots, \text{forecast\_horizon}\), where \(f\) is the trained XTFT model.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.
time_steps (int, optional) – The number of historical time steps used as input. Default is 3.
q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.
- Returns:
A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # The first two columns of X_static must be the spatial
>>> # coordinates named in spatial_cols.
>>> X_static = train_df[["longitude", "latitude"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
>>> # Dummy multi-step target with shape (60, forecast_horizon, 1),
>>> # so that it matches the horizon of the model output.
>>> forecast_horizon = 4
>>> y_array = np.random.rand(60, forecast_horizon, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,   # two columns in X_static
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # one future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())
For point forecast
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,   # two columns in X_static
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # one future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None,       # set quantiles to None
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam", loss="mse")
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )

>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())
Notes
In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.
In point mode, a single prediction is generated per forecast step.
The output prediction array is expected to have the shape \((n, \text{forecast\_horizon}, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).
The provided spatial_cols must correspond to the first two columns of the original training data’s X_static.
Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.
The DataFrame is constructed by iterating over each sample and each forecast step.
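The column naming pattern for multi-step output can be sketched as follows. `forecast_columns` is a hypothetical helper that only reproduces the documented name templates; the order in which the library emits its columns is an assumption here.

```python
def forecast_columns(tname, horizon, quantiles=None):
    """List output column names for a multi-step forecast.

    Hypothetical helper that mirrors the documented naming pattern.
    """
    columns = []
    for step in range(1, horizon + 1):
        if quantiles:
            # One column per quantile and step, e.g. subsidence_q10_step1.
            columns += [
                f"{tname}_q{int(round(q * 100))}_step{step}" for q in quantiles
            ]
        else:
            # Point mode: one prediction column per step.
            columns.append(f"{tname}_pred_step{step}")
    return columns


quantile_cols = forecast_columns("subsidence", 2, quantiles=[0.1, 0.5, 0.9])
point_cols = forecast_columns("subsidence", 2)
print(quantile_cols)
print(point_cols)
```

For a prediction array of shape (n, forecast_horizon, m), this yields forecast_horizon × m prediction columns: six in the quantile case above and two in point mode.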
See also
forecast_single_step: Function for single-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to verify quantile ratios.
- geoprior.models.utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#
Generate a single-step forecast using the XTFT model.
This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:
(51) \[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\]
where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data’s X_static.
q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.
apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # dummy spatial features ("longitude", "latitude"), and the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # forecast_single_step expects that, if spatial_cols is provided,
>>> # the first two columns of X_static are the spatial coordinates.
>>> X_static = train_df[["longitude", "latitude"]].values  # shape: (50, 2)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1).
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> #   - X_static with shape (n_samples, static_input_dim)
>>> #   - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=2,   # two columns in X_static
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # one future feature
...     forecast_horizon=1,   # single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",     # the time column name in the output
...     mode="quantile",   # can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())
Notes
In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.
If spatial_cols is provided, it must be the first and second columns of the original training data’s X_static.
The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.
Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.
The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
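For readers unfamiliar with the coverage score mentioned above, the sketch below uses the common definition: the share of actual values that fall inside the interval spanned by the lowest and highest predicted quantiles. `interval_coverage` is a hypothetical helper, and geoprior's coverage_score may differ in details such as inclusive bounds or weighting.

```python
import numpy as np


def interval_coverage(y_true, lower, upper):
    """Fraction of actual values inside [lower, upper] (inclusive)."""
    y_true, lower, upper = map(np.asarray, (y_true, lower, upper))
    return float(np.mean((y_true >= lower) & (y_true <= upper)))


# Predictions shaped (n, 1, m) with m = 3 quantiles: q10, q50, q90.
preds = np.array([
    [[0.0, 0.5, 1.0]],
    [[1.0, 1.5, 2.0]],
    [[2.0, 2.5, 3.0]],
    [[3.0, 3.5, 4.0]],
])
y = np.array([0.4, 2.5, 2.9, 3.0])

# Use the first and last quantile columns as the interval bounds.
coverage = interval_coverage(y, preds[:, 0, 0], preds[:, 0, -1])
print(coverage)  # 0.75
```

Three of the four actual values fall inside their q10 to q90 interval, so the coverage is 0.75; a well-calibrated 10% to 90% interval would be expected to cover about 80% of observations.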
See also
forecast_multi_step: Function for multi-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to validate quantile ratios.
- geoprior.models.utils.format_predictions(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, _logger=None, **kwargs)[source]#
Formats model predictions into a structured pandas DataFrame.
This utility function takes raw model predictions (either directly as an array/tensor or generated by a provided model and its inputs) and transforms them into a long-format pandas DataFrame. It can handle point forecasts, quantile forecasts, single or multi-output predictions, and optionally include actual target values, spatial identifiers, and perform coverage score evaluation for quantile forecasts.
The output DataFrame is structured with ‘sample_idx’ (identifying the original input sequence) and ‘forecast_step’ (from 1 to H, where H is the forecast horizon).
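The long-format layout can be illustrated with a small sketch that unrolls a point-forecast array into one row per sample and forecast step. `to_long_rows` is a hypothetical helper covering only the single-output point case; it returns plain dictionaries rather than a DataFrame to stay dependency-light.

```python
import numpy as np


def to_long_rows(preds, target_name="target"):
    """Flatten a (B, H, 1) point-forecast array into long-format rows.

    Hypothetical helper; it covers only the single-output point case.
    """
    preds = np.asarray(preds)
    if preds.ndim == 2:
        # Treat (B, H) as (B, H, 1) for a single output.
        preds = preds[..., np.newaxis]
    B, H, _ = preds.shape
    return [
        {
            "sample_idx": b,
            "forecast_step": h + 1,  # steps run from 1 to H
            f"{target_name}_pred": float(preds[b, h, 0]),
        }
        for b in range(B)
        for h in range(H)
    ]


rows = to_long_rows(np.arange(6).reshape(2, 3), target_name="value")
print(len(rows), rows[0], rows[-1])
```

Two samples with a horizon of three produce six rows, and each sample's steps stay contiguous, which is the ordering the `head(H)` calls in the examples rely on.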
- Parameters:
predictions (
np.ndarrayortf.Tensor, optional) –The raw prediction tensor or array.
- For point forecasts, expected shapes:
(num_samples, forecast_horizon, output_dim)
(num_samples, forecast_horizon) if output_dim=1 (will be reshaped)
(num_samples, output_dim) if forecast_horizon=1 (will be reshaped)
- For quantile forecasts, expected shapes:
(num_samples, forecast_horizon, num_quantiles * output_dim)
(num_samples, forecast_horizon, num_quantiles, output_dim)
(num_samples, num_quantiles * output_dim) if forecast_horizon=1
If
None, model and inputs must be provided. Default isNone.model (
tf.keras.Model, optional) – A trained Keras model to generate predictions if predictions is not provided. Used in conjunction with inputs. Default isNone.inputs (
List[Optional[Union[np.ndarray,tf.Tensor]]], optional) – A list of input tensors (e.g., [static, dynamic, future]) required by the model to generate predictions. Required if predictions isNoneand model is provided. Default isNone.y_true_sequences (
np.ndarrayortf.Tensor, optional) – The true target values corresponding to the predictions, used for including actuals in the output DataFrame and for evaluation. Expected shape: (num_samples, forecast_horizon, output_dim). Default isNone.target_name (
str, optional) – Base name for the target variable. Used to prefix prediction and actual column names (e.g., “sales_pred”, “sales_q50”, “sales_actual”). Default is “target”.quantiles (
List[float], optional) – A list of quantiles that were predicted by the model (e.g., [0.1, 0.5, 0.9]). Required if the predictions are quantile forecasts. If provided, prediction columns will be named like {target_name}_q10, {target_name}_q50, etc. Default isNone(for point forecasts).forecast_horizon (
int, optional) – The number of time steps into the future that the model predicts. If not provided, it’s inferred from predictions.shape[1] or y_true_sequences.shape[1]. Default isNone.output_dim (
int, optional) – The number of target variables predicted at each time step (e.g., 1 for univariate, >1 for multivariate target). If not provided, it’s inferred from the shape of predictions or y_true_sequences. Default isNone.spatial_data_array (
np.ndarrayortf.Tensororpd.DataFrameorpd.Series, optional) –An array or DataFrame containing static spatial/identifier features for each of the num_samples sequences.
If NumPy/Tensor: Expected shape (num_samples, num_spatial_features). spatial_cols_indices must be provided.
If DataFrame/Series: Must have num_samples rows. spatial_cols must be provided.
These features will be repeated for each forecast step in the output DataFrame. Default is
None.spatial_cols (
List[str], optional) – List of column names to select from spatial_data_array if it’s a DataFrame/Series, or names to assign to columns if spatial_data_array is NumPy/Tensor and spatial_cols_indices are provided. Default isNone.spatial_cols_indices (
List[int], optional) – List of column indices to select from spatial_data_array if it’s a NumPy/Tensor. Length must match spatial_cols if provided. Default isNone.evaluate_coverage (
bool, defaultFalse) – IfTrue, quantiles are provided (at least two), and y_true_sequences is available, calculates the coverage score using the first and last quantiles as interval bounds. Requires geoprior.metrics.coverage_score.scaler (
Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method) used to scale the target variable and potentially other features. If provided along with scaler_feature_names and target_idx_in_scaler, predictions and actuals for the target will be inverse-transformed. Default isNone.scaler_feature_names (
List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required if scaler is provided and targeted inverse transformation is needed. Default isNone.target_idx_in_scaler (
int, optional) – The index of the target_name within the scaler_feature_names list. Required if scaler is provided and targeted inverse transformation is needed. Default isNone.verbose (
int, default0) –Verbosity level for logging during processing.
0: Silent.1: Basic info.3: More detailed steps.5: Very detailed shape information.
**kwargs (Any) – Additional keyword arguments (currently not used but included for future extensibility).
**kwargs
- Returns:
A long-format DataFrame containing sample_idx and forecast_step, optional spatial columns, prediction columns, and actual-value columns when y_true_sequences is provided. Point forecasts use names like {target_name}_pred or {target_name}_{output_idx}_pred. Quantile forecasts use names like {target_name}_qXX or {target_name}_{output_idx}_qXX. Actual values use {target_name}_actual or {target_name}_{output_idx}_actual. Prediction and actual values are inverse-transformed when valid scaler information is provided.
- Return type:
pd.DataFrame
- Raises:
ValueError – If predictions is None and model or inputs is also None. If predictions shape is invalid (not 2D, 3D, or 4D). If quantiles are provided but prediction shape is incompatible for inferring output_dim. If spatial_data_array is provided without necessary name/index parameters.
TypeError – If predictions or other inputs cannot be converted to the expected tensor/array types.
See also
geoprior.models.utils.forecast_multi_step: Higher-level forecasting utility.
geoprior.metrics.coverage_score: For evaluating quantile forecast intervals.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import format_predictions_to_dataframe
>>> B, H, O = 4, 3, 1  # Batch, Horizon, OutputDim
>>> Q = [0.1, 0.5, 0.9]
>>> preds_point = tf.random.normal((B, H, O))
>>> preds_quant = tf.random.normal((B, H, len(Q)))  # For O=1
>>> y_true = tf.random.normal((B, H, O))
>>> # Point forecast
>>> df_point = format_predictions_to_dataframe(
...     predictions=preds_point, y_true_sequences=y_true,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> print(df_point.head(H))  # Show first sample's horizon
   sample_idx  forecast_step  value_pred  value_actual
0           0              1   -0.576731     -0.647362
1           0              2    0.183931      1.198977
2           0              3   -0.766871      0.534040
>>> # Quantile forecast
>>> df_quant = format_predictions_to_dataframe(
...     predictions=preds_quant, y_true_sequences=y_true,
...     target_name="value", quantiles=Q,
...     forecast_horizon=H, output_dim=O
... )
>>> print(df_quant.head(H))
   sample_idx  forecast_step  value_q10  value_q50  value_q90  value_actual
0           0              1  -0.209947   0.263107  -0.308929     -0.647362
1           0              2   0.303091   0.594701  -0.225007      1.198977
2           0              3   0.136699  -1.237739   0.002834      0.534040
>>> # With spatial data (NumPy array)
>>> spatial_np = np.array([[101, 201], [102, 202], [103, 203], [104, 204]])
>>> df_spatial = format_predictions_to_dataframe(
...     predictions=preds_point,
...     spatial_data_array=spatial_np,
...     spatial_cols=['store_id', 'region_id'],
...     spatial_cols_indices=[0, 1]
... )
>>> print(df_spatial[['sample_idx', 'forecast_step', 'store_id']].head(H))
   sample_idx  forecast_step  store_id
0           0              1     101.0
1           0              2     101.0
2           0              3     101.0
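The column-naming convention described under Returns can be sketched in plain Python. The helper below is hypothetical (it is not part of geoprior), and it assumes that the `XX` suffix is the quantile level expressed as a percentage and that `output_idx` starts at 0; the examples above only confirm the `value_q10` / `value_pred` forms for a single output.

```python
from typing import List, Optional, Sequence


def prediction_column_names(
    target_name: str,
    quantiles: Optional[Sequence[float]] = None,
    output_dim: int = 1,
) -> List[str]:
    """Hypothetical helper: list prediction column names for a long-format frame."""
    # A single output keeps the bare target name; otherwise add an output index.
    # Assumption: output indices start at 0.
    if output_dim == 1:
        prefixes = [target_name]
    else:
        prefixes = [f"{target_name}_{idx}" for idx in range(output_dim)]

    # Point forecasts use the "_pred" suffix.
    if quantiles is None:
        return [f"{prefix}_pred" for prefix in prefixes]

    # Assumption: "XX" is the quantile level as a percentage (0.1 -> "q10").
    return [
        f"{prefix}_q{int(round(q * 100))}"
        for prefix in prefixes
        for q in quantiles
    ]


quantile_cols = prediction_column_names("value", quantiles=[0.1, 0.5, 0.9])
point_cols = prediction_column_names("value", output_dim=2)
print(quantile_cols)
print(point_cols)
```

Knowing the names in advance makes it easier to select prediction columns from the returned DataFrame without hard-coding them.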
- geoprior.models.utils.format_predictions_to_dataframe(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, **kwargs)[source]#
Deprecated alias for format_predictions. See format_predictions for the updated, recommended API. All original parameters are forwarded to format_predictions.
- Returns:
The formatted prediction DataFrame from format_predictions.
- Return type:
pd.DataFrame
- geoprior.models.utils.generate_forecast(xtft_model, train_data, dt_col, dynamic_features, future_features=None, static_features=None, test_data=None, mode='quantile', spatial_cols=None, forecast_horizon=4, time_steps=3, q=None, tname=None, forecast_dt=None, savefile=None, verbose=3, **kw)[source]#
Generate forecast using the XTFT model.
This function uses a pre-trained Keras model to forecast future values based on provided historical data. The model receives three inputs: X_static, X_dynamic, and X_future re-built from train_data, and outputs predictions over a specified forecast horizon.
See more in User Guide.
- Parameters:
xtft_model (object) – A validated Keras model instance. It is processed by validate_keras_model.
train_data (pandas.DataFrame) – The training data containing historical records. Must include the dt_col and all required feature columns.
dt_col (str) – Name of the column representing time. It may be a datetime or numeric column (e.g. "year").
dynamic_features (list of str) – List of dynamic feature column names. They are formatted via columns_manager.
future_features (list of str, optional) – List of future feature names. These columns are tiled over the forecast horizon.
static_features (list of str, optional) – List of static feature names. If not provided, a dummy input is used.
test_data (pandas.DataFrame, optional) – DataFrame containing actual values used for evaluation. If provided, it is used to compute the R² and, for mode='quantile', the coverage score.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions for multiple quantiles (default: [0.1, 0.5, 0.9]) are computed.
spatial_cols (list of str, optional) – List of spatial column names for grouping the data. When provided, forecasts are computed per location; otherwise, a global forecast is performed.
forecast_horizon (int, optional) – Number of future periods to forecast. Default is 4.
time_steps (int, optional) – Number of past time steps to use as input for the model. Default is 3.
q (list of float, optional) – List of quantiles for use in quantile mode. Default is [0.1, 0.5, 0.9]. Each quantile is validated by the assert_ratio function.
tname (str, optional) – Target variable name used for constructing forecast result columns. Defaults to "target".
forecast_dt (list or str, optional) – List of forecast dates or "auto" to derive dates from dt_col. In auto mode, if dt_col is datetime, frequency is inferred using pd.infer_freq.
savefile (str, optional) – Path to the CSV file where forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level (0-7). Controls the amount of execution output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, each forecast period includes columns for each quantile; in point mode, a single prediction column is provided.
- Return type:
Examples
Example using training data only
>>> import os
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.losses import combined_quantile_loss
>>> from geoprior.models.utils import generate_forecast
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # dynamic features "feat1", "feat2", static feature "stat1",
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> # Note: The model expects the following input shapes:
>>> # - X_static: (n_samples, static_input_dim)
>>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future: (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=1,   # "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # one future feature
...     forecast_horizon=5,   # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Create dummy input arrays for model fitting.
>>> # For simplicity, assume time_steps = 3 and use random data.
>>> X_static = train_df[["stat1"]].values  # shape: (50, 1)
>>> # Create a dummy dynamic input array of shape (50, 3, 2)
>>> X_dynamic = np.random.rand(50, 3, 2)
>>> # Create a dummy future-feature array of shape (50, 3, 1)
>>> X_future = np.random.rand(50, 3, 1)
>>> # Create dummy target output from "price"
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
Example with test data included
>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values  # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,   # "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # For the provided future feature
...     forecast_horizon=5,   # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> # This example uses test_df for evaluation, which will compute
>>> # metrics like R² Score and Coverage Score.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     test_data=test_df.iloc[:5, :],  # first 5 rows match the 5-step horizon
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
Example of point forecasting
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Create dummy input arrays for model fitting.
>>> # X_static is derived from the static feature "stat1".
>>> X_static = train_df[["stat1"]].values  # shape: (50, 1)
>>>
>>> # X_dynamic is a dummy dynamic array for "feat1" and "feat2".
>>> # For time_steps = 3, its shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # X_future is a dummy array for future features.
>>> # Here, we assume a single future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,   # "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # Provided future feature
...     forecast_horizon=5,   # Forecast 5 periods ahead
...     quantiles=None,       # [0.1, 0.5, 0.9] Not used in point mode
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function in point mode.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="point",
...     verbose=3
... )
>>> print(forecast.head())
Notes
The function groups data by spatial_cols if provided, and formats features via columns_manager. It validates the time column using check_datetime and uses dummy inputs for missing static or future features. The forecast is produced by invoking xtft_model.predict on a list containing static, dynamic, and future inputs. The predictions are generated as follows:

\[\hat{y}_{t+i} = f\Bigl(X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}}\Bigr) \tag{52}\]

where \(i\) denotes the forecast period.
See also
geoprior.models.utils.reshape_xtft_data – Function to reshape data for XTFT models.
geoprior.utils.validator.validate_keras_model – Function to validate Keras model compatibility.
geoprior.core.handlers.columns_manager – Utility to manage and format column names.
geoprior.core.checks.check_datetime – Function to check and validate datetime columns.
geoprior.core.checks.check_spatial_columns – Function to validate spatial columns in data.
geoprior.core.checks.assert_ratio – Function to validate and assert ratio values.
geoprior.metrics_special.coverage_score – Function to compute coverage score for quantile predictions.
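The forecast_dt="auto" behaviour (deriving forecast dates from dt_col) can be illustrated with a small standard-library sketch. The documentation states that the library infers the frequency with pd.infer_freq; the code below deliberately swaps in a simpler fixed-step inference so that it runs without pandas. The helper name `extend_dates` is hypothetical and is not part of geoprior.

```python
from datetime import date, timedelta
from typing import List


def extend_dates(history: List[date], forecast_horizon: int) -> List[date]:
    """Hypothetical helper: continue a regularly spaced date sequence."""
    if len(history) < 2:
        raise ValueError("Need at least two dates to infer a step.")

    # Collect the distinct gaps between consecutive dates.
    steps = {later - earlier for earlier, later in zip(history, history[1:])}
    if len(steps) != 1:
        raise ValueError("Dates are not regularly spaced; pass forecast_dt explicitly.")

    step = steps.pop()
    last = history[-1]
    # One new date per forecast period, starting right after the last record.
    return [last + step * i for i in range(1, forecast_horizon + 1)]


# Fifty daily records starting 2020-01-01, as in the examples above.
history = [date(2020, 1, 1) + timedelta(days=i) for i in range(50)]
future = extend_dates(history, forecast_horizon=5)
print(future)
```

For irregular calendars (for example month-end or business-day data) a fixed step is not enough, which is presumably why the library relies on pandas frequency inference instead.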
- geoprior.models.utils.generate_forecast_with(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kw)[source]#
Generate forecasts using a pre-trained XTFT model based on the forecast horizon.
There are two approaches to generating forecasts with an XTFT model:
1. A monolithic function (e.g. generate_forecast) that handles both single-step and multi-step forecasts within a single implementation. This approach results in a single, large function that internally branches its logic based on the value of the forecast horizon.
2. A modular design where the single-step and multi-step forecasting functionalities are separated into two distinct functions (e.g. forecast_single_step and forecast_multi_step), with a thin wrapper (e.g. generate_xtft_forecast) that dispatches to the appropriate function based on the forecast horizon.

The modular approach (2) is generally preferred because it separates concerns and improves code readability, maintainability, and unit testing. Each function becomes responsible for a well-defined task: one for single-step forecasts and one for multi-step forecasts. The wrapper function, which we propose to name generate_xtft_forecast, simply selects the correct method based on the forecast horizon. Use this approach when your application may need to handle both short- and long-term forecasts, as it keeps the codebase modular and easier to debug.

Below is an implementation of the wrapper function generate_xtft_forecast that calls forecast_single_step when forecast_horizon equals 1 and forecast_multi_step when forecast_horizon is greater than 1.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. A value of 1 triggers a single-step forecast; values greater than 1 trigger a multi-step forecast.
y (numpy.ndarray, optional) – Actual target values for evaluation. If provided, evaluation metrics (e.g., R² Score, and in quantile mode, the coverage score) are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, the output DataFrame includes a column with these values.
mode (str, optional) – Forecast mode, either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data's X_static.
time_steps (int, optional) – The number of historical time steps used as input.
q (list of float, optional) – List of quantile values for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names (e.g., "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking (via mask_by_reference) to adjust predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output.
**kw (dict, optional) – Currently unused; reserved for future extension.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile and forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import generate_forecast_with
>>> import numpy as np
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> my_model = XTFT(
...     static_input_dim=10,
...     dynamic_input_dim=5,
...     future_input_dim=3,
...     forecast_horizon=1,  # This parameter will be updated in the
...                          # wrapper function based on forecast_horizon.
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=32,
...     max_window_size=3,
...     memory_size=100,
...     num_heads=4,
...     dropout_rate=0.1,
...     lstm_units=64,
...     attention_units=64,
...     hidden_units=32
... )
>>> my_model.compile(optimizer='adam')
>>>
>>> # Create dummy input data.
>>> X_static = np.random.rand(100, 10)
>>> X_dynamic = np.random.rand(100, 3, 5)
>>> X_future = np.random.rand(100, 3, 3)
>>> y_array = np.random.rand(100, 1, 1)  # For single-step target output.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Fit the model with dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=32
... )
>>>
>>> # Example for a single-step forecast:
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=1,
...     y=y_array,
...     dt_col="year",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     verbose=3
... )
>>> print(forecast_df.head())
>>>
>>> # Example for a multi-step forecast:
>>> forecast_dates = ["2023", "2024", "2025", "2026"]
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=4,
...     y=y_array,
...     dt_col="year",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     verbose=3
... )
>>> print(forecast_df.head())
See also
forecast_single_step – Generates a single-step forecast.
forecast_multi_step – Generates a multi-step forecast.
validate_keras_model – Validates a Keras model.
coverage_score – Computes the coverage score.
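The dispatch rule and the wide-format column names documented for generate_forecast_with can be sketched without a model. Everything below is illustrative: `dispatch_forecast` and `forecast_column_names` are hypothetical helpers, the two callables stand in for forecast_single_step and forecast_multi_step, and the quantile suffix is assumed to be the level as a percentage (matching `<tname>_q10_step1`).

```python
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def dispatch_forecast(
    forecast_horizon: int,
    single_step: Callable[[], T],
    multi_step: Callable[[], T],
) -> T:
    """Hypothetical wrapper: pick the forecasting routine from the horizon."""
    if forecast_horizon < 1:
        raise ValueError("forecast_horizon must be >= 1.")
    # A horizon of 1 takes the single-step path; anything larger is multi-step.
    return single_step() if forecast_horizon == 1 else multi_step()


def forecast_column_names(
    tname: str,
    forecast_horizon: int,
    mode: str = "quantile",
    q: Optional[Sequence[float]] = None,
) -> List[str]:
    """Hypothetical helper: list the per-step output columns."""
    steps = range(1, forecast_horizon + 1)
    if mode == "point":
        return [f"{tname}_pred_step{s}" for s in steps]
    if mode == "quantile":
        levels = [0.1, 0.5, 0.9] if q is None else list(q)
        # Assumption: the suffix is the quantile level as a percentage.
        return [
            f"{tname}_q{int(round(level * 100))}_step{s}"
            for s in steps
            for level in levels
        ]
    raise ValueError("mode must be 'quantile' or 'point'.")


chosen = dispatch_forecast(4, lambda: "single", lambda: "multi")
columns = forecast_column_names("subsidence", 2, mode="quantile")
print(chosen, columns)
```

Keeping the dispatch this thin is what makes the modular design easy to unit test: each branch can be exercised with a stub callable.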
- geoprior.models.utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)[source]#
Prepares a list of input tensors for a model’s call method.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is 'strict'.
- Parameters:
dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is 'strict', a dummy tensor with 0 static features will be created. Default is None.
future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is 'strict', a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.
model_type ({'strict', 'flexible'}, default 'strict') – Determines how None inputs for static and future features are handled:
'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.
'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.
forecast_horizon (int, optional) – The forecast horizon. Used only if model_type='strict' and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor's time dimension will match dynamic_input's past_time_steps. Default is None.
verbose (int, default 0) – Verbosity level. If > 0, prints information about dummy tensor creation.
0: Silent. 1: Basic info on dummy creation. 2: More details on shapes.
- Returns:
A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type='flexible' and original input was None). All returned tensors are cast to tf.float32.
- Return type:
List[Optional[tf.Tensor]]
- Raises:
ValueError – If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.
TypeError – If inputs cannot be converted to TensorFlow tensors.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
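The dummy-tensor rules for 'strict' and 'flexible' mode can be summarized as a shape calculation. The sketch below works on shape tuples only, so it runs without TensorFlow; `expected_input_shapes` is a hypothetical helper that mirrors the documented behaviour and is not the library's implementation.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_input_shapes(
    dynamic_shape: Shape,
    static_shape: Optional[Shape] = None,
    future_shape: Optional[Shape] = None,
    model_type: str = "strict",
    forecast_horizon: Optional[int] = None,
) -> Tuple[Optional[Shape], Shape, Optional[Shape]]:
    """Hypothetical helper: shapes of the [static, dynamic, future] list."""
    if model_type not in ("strict", "flexible"):
        raise ValueError("model_type must be 'strict' or 'flexible'.")

    batch_size, past_time_steps = dynamic_shape[0], dynamic_shape[1]

    if model_type == "strict":
        # Missing static input becomes a (batch, 0) dummy.
        if static_shape is None:
            static_shape = (batch_size, 0)
        # Missing future input becomes a (batch, span, 0) dummy, where the
        # span is past_time_steps plus forecast_horizon when it is given.
        if future_shape is None:
            span = past_time_steps + (forecast_horizon or 0)
            future_shape = (batch_size, span, 0)

    # In flexible mode, missing inputs stay None.
    return static_shape, dynamic_shape, future_shape


strict_shapes = expected_input_shapes((2, 10, 4), forecast_horizon=3)
flexible_shapes = expected_input_shapes((2, 10, 4), model_type="flexible")
print(strict_shapes)
print(flexible_shapes)
```

A zero-width feature axis keeps the three-element list contract intact for models that always unpack three inputs, while contributing no actual features.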
- geoprior.models.utils.prepare_model_inputs_in(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0)[source]#
Prepares a list of input tensors for a model's call method in graph-compatible mode.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is 'strict'.
- Parameters:
dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is 'strict', a dummy tensor with 0 static features will be created. Default is None.
future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is 'strict', a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.
model_type ({'strict', 'flexible'}, default 'strict') – Determines how None inputs for static and future features are handled:
'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.
'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.
forecast_horizon (int, optional) – The forecast horizon. Used only if model_type='strict' and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor's time dimension will match dynamic_input's past_time_steps. Default is None.
verbose (int, default 0) – Verbosity level. If > 0, prints information about dummy tensor creation.
0: Silent. 1: Basic info on dummy creation. 2: More details on shapes.
- Returns:
A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type='flexible' and original input was None). All returned tensors are cast to tf.float32.
- Return type:
List[Optional[tf.Tensor]]
- Raises:
ValueError – If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.
TypeError – If inputs cannot be converted to TensorFlow tensors.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs_in
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs_in(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=fut_in,
...                                   model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=None,
...                                   model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
- geoprior.models.utils.prepare_spatial_future_data(final_processed_data, feature_columns, dynamic_feature_indices, sequence_length=1, dt_col='date', static_feature_names=None, forecast_horizon=None, future_years=None, encoded_cat_columns=None, scaling_params=None, spatial_cols=None, squeeze_last=False, verbosity=0)[source]#
Prepare future static and dynamic inputs for making predictions.
This function prepares the necessary static and dynamic inputs required for forecasting future values in time series data. It processes the provided dataset by grouping it by location_id, extracting the last sequence of data points based on the specified sequence_length, and generating future inputs for prediction over the defined forecast_horizon.

The function handles both integer and datetime representations of the dt_col, extracting the year from datetime columns when necessary. It also allows for flexibility in specifying static features and encoded categorical variables.

\[\text{scaled\_time} = \frac{\text{future\_time} - \mu}{\sigma} \tag{53}\]

- Parameters:
final_processed_data (pandas.DataFrame) – The processed DataFrame containing all features and targets. Must include the location_id column and the specified dt_col.
feature_columns (List[str]) – List of feature column names to be used for dynamic input preparation.
dynamic_feature_indices (List[int]) – Indices of dynamic features in feature_columns. These features are considered time-dependent and are used to prepare dynamic inputs.
sequence_length (int, optional) – The number of past time steps to include in each input sequence. Default is 1.
dt_col (str, optional) – The name of the time-related column in final_processed_data. Defaults to 'date'.
static_feature_names (List[str], optional) – List of static feature column names. If not provided, defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
forecast_horizon (int, optional) – The number of future time steps to predict. If set to None, the function defaults to predicting the next immediate time step.
future_years (List[int], optional) – List of future years to predict. Its length must equal forecast_horizon if forecast_horizon is provided.
encoded_cat_columns (List[str], optional) – List of encoded categorical column names to be treated as static features.
scaling_params (Dict[str, Dict[str, float]], optional) – Dictionary containing scaling parameters (mean and standard deviation) for features. Example: {'year': {'mean': 2000, 'std': 10}}. If not provided, the function computes the mean and std for the dt_col.
squeeze_last (bool, default False) – Whether to squeeze the last axis, which corresponds to the output dimension y, when it equals 1.
verbosity (int, optional) – Verbosity level from 0 to 7 for debugging and understanding the process. Higher values produce more detailed logs.
- Returns:
A tuple containing:
future_static_inputs (numpy.ndarray) – Array of future static inputs with shape (num_samples, num_static_vars, 1).
future_dynamic_inputs (numpy.ndarray) – Array of future dynamic inputs with shape (num_samples, sequence_length, num_dynamic_vars, 1).
future_years_list (List[int]) – List of future time values corresponding to each sample.
location_ids_list (List[int]) – List of location IDs corresponding to each sample.
longitudes (List[float]) – List of longitude values corresponding to each sample.
latitudes (List[float]) – List of latitude values corresponding to each sample.
- Return type:
Tuple[np.ndarray, np.ndarray, List[int], List[int], List[float], List[float]]
Examples
>>> from geoprior.models.utils import prepare_spatial_future_data
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'location_id': [1, 1, 1, 2, 2, 2],
...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
...     'longitude': [10.0, 10.0, 10.0, 20.0, 20.0, 20.0],
...     'latitude': [50.0, 50.0, 50.0, 60.0, 60.0, 60.0],
...     'temperature': [15, 16, 15.5, 20, 21, 20.5],
...     'rainfall': [100, 110, 105, 200, 210, 205],
...     'encoded_cat': [1, 1, 1, 2, 2, 2]
... })
>>> feature_cols = ['year', 'temperature', 'rainfall', 'encoded_cat']
>>> dynamic_indices = [0, 1, 2]
>>> (future_static, future_dynamic, future_years,
...  loc_ids, longs, lats) = prepare_spatial_future_data(
...     final_processed_data=data,
...     feature_columns=feature_cols,
...     dynamic_feature_indices=dynamic_indices,
...     sequence_length=2,
...     forecast_horizon=1,
...     future_years=[2021],
...     encoded_cat_columns=['encoded_cat'],
...     verbosity=5,
...     dt_col='year'
... )
>>> print(future_static.shape)
(2, 3, 1)
>>> print(future_dynamic.shape)
(2, 2, 3, 1)
Notes
The function handles both integer and datetime representations of the dt_col. If dt_col is a datetime type, the year is extracted for scaling purposes.
If forecast_horizon is set to None, the function defaults to generating data for the next immediate time step based on the last entry in the time column.
Ensure that the length of future_years matches forecast_horizon if forecast_horizon is provided.
The static_feature_names parameter allows for flexibility in specifying which static features to include. If not provided, it defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
See also
prepare_future_data: Main function for preparing future data inputs.
- geoprior.models.utils.set_default_params(quantiles=None, scales=None, multi_scale_agg=None)[source]#
Sets and validates default values for quantiles and scales, and derives the return_sequences flag from multi_scale_agg.
- Parameters:
quantiles (str, list of float, or None, optional) – Specifies the quantiles to be used for probabilistic forecasting. If set to 'auto', it defaults to [0.1, 0.5, 0.9]. If a list is provided, each element must be a float between 0 and 1 exclusive. If None, it remains None and can be used for deterministic forecasting.
scales (str, list of int, or None, optional) – Specifies the scaling factors to be used in multi-scale processing. If set to 'auto' or None, it defaults to [1]. If a list is provided, each element must be a positive integer.
multi_scale_agg (str or None, optional) – Determines the aggregation method for multi-scale features. If set to None, return_sequences is False. Otherwise, return_sequences is True. Expected aggregation methods include 'average', 'concat', 'sum', 'last', and 'auto' (which falls back to 'last'), depending on model requirements.
- Returns:
Tuple containing validated quantiles, validated scales, and the return_sequences flag derived from multi_scale_agg.
- Return type:
Tuple[List[float], List[int], bool]
- Raises:
ValueError – If quantiles is not 'auto', None, or a list of valid floats. If scales is not 'auto', None, or a list of valid positive integers. If multi_scale_agg is provided but not a recognized aggregation method.
Examples
>>> from geoprior.models.utils import set_default_params
>>> # Example 1: Using default 'auto' settings
>>> quantiles, scales, return_sequences = set_default_params(
...     quantiles='auto', scales='auto', multi_scale_agg='auto')
>>> print(quantiles)
[0.1, 0.5, 0.9]
>>> print(scales)
[1]
>>> print(return_sequences)
True

>>> # Example 2: Providing custom quantiles and scales
>>> quantiles, scales, return_sequences = set_default_params(
...     quantiles=[0.05, 0.5, 0.95],
...     scales=[1, 2, 4],
...     multi_scale_agg='concat'
... )
>>> print(quantiles)
[0.05, 0.5, 0.95]
>>> print(scales)
[1, 2, 4]
>>> print(return_sequences)
True

>>> # Example 3: Invalid quantiles input
>>> set_default_params(quantiles=[-0.1, 1.2])
Traceback (most recent call last):
    ...
ValueError: Each quantile must be a float between 0 and 1 (exclusive). Invalid quantiles: [-0.1, 1.2]

>>> # Example 4: Invalid scales input
>>> set_default_params(scales=[0, -2])
Traceback (most recent call last):
    ...
ValueError: Each scale must be a positive integer. Invalid scales: [0, -2]
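The defaulting rules described above can be restated as a small standalone function. This is only a sketch of the documented behavior, not the library source: the name sketch_default_params and the error messages are hypothetical, and the set of accepted aggregation names is taken from the parameter description.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Aggregation names listed in the parameter description above.
_VALID_AGG = {"average", "concat", "sum", "last", "auto"}


def sketch_default_params(
    quantiles: Union[str, Sequence[float], None] = None,
    scales: Union[str, Sequence[int], None] = None,
    multi_scale_agg: Optional[str] = None,
) -> Tuple[Optional[List[float]], List[int], bool]:
    """Hypothetical restatement of the documented rules (not library code)."""
    # quantiles: 'auto' -> default triple, None stays None, lists are validated.
    if isinstance(quantiles, str):
        if quantiles != "auto":
            raise ValueError(f"Unsupported quantiles value: {quantiles!r}")
        quantiles = [0.1, 0.5, 0.9]
    elif quantiles is not None:
        bad = [q for q in quantiles if not 0 < q < 1]
        if bad:
            raise ValueError(f"Invalid quantiles: {bad}")
        quantiles = [float(q) for q in quantiles]

    # scales: 'auto' or None -> [1], lists must hold positive integers.
    if scales is None or scales == "auto":
        scales = [1]
    elif isinstance(scales, str):
        raise ValueError(f"Unsupported scales value: {scales!r}")
    else:
        bad = [s for s in scales if not (isinstance(s, int) and s > 0)]
        if bad:
            raise ValueError(f"Invalid scales: {bad}")
        scales = list(scales)

    # return_sequences is True exactly when an aggregation method is given.
    if multi_scale_agg is not None and multi_scale_agg not in _VALID_AGG:
        raise ValueError(f"Unknown multi_scale_agg: {multi_scale_agg!r}")
    return quantiles, scales, multi_scale_agg is not None


auto_result = sketch_default_params("auto", "auto", "auto")
deterministic_result = sketch_default_params()
print(auto_result)
print(deterministic_result)
```

Calling the sketch with no arguments gives the deterministic configuration: no quantiles, a single scale, and return_sequences set to False.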
- geoprior.models.utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)[source]#
Split sequences into static and dynamic inputs for the model.
The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.
\[\begin{split}\text{Static Inputs} = \mathbf{S} = \mathbf{X}_{t,\ \text{static\_indices}} \\ \text{Dynamic Inputs} = \mathbf{D} = \mathbf{X}_{:,\ \text{dynamic\_indices}}\end{split} \tag{54}\]
- Parameters:
sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).
static_indices (List[int]) – Indices of static features within the feature dimension.
dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.
static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).
reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.
reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.
static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).
dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).
- Returns:
A tuple containing:
- Static inputs with shape as specified.
- Dynamic inputs with shape as specified.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If static_time_step is out of range for the given sequence length.
Examples
>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array with shape
>>> # (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)
Notes
Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.
Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.
Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
See also
geoprior.models.utils.create_sequences: Function to create input sequences and targets for time series forecasting.
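The indexing in the equation above can be written out in plain NumPy. The following is a minimal sketch under the default reshaping (a trailing axis of size 1), not the library implementation; the function name split_static_dynamic_sketch is hypothetical.

```python
import numpy as np


def split_static_dynamic_sketch(sequences, static_indices, dynamic_indices,
                                static_time_step=0):
    """Hypothetical NumPy restatement of the documented split (not library code)."""
    _, seq_len, _ = sequences.shape
    if not 0 <= static_time_step < seq_len:
        raise ValueError("static_time_step is out of range for the sequence length")
    # Static features are read from a single time step: (batch, num_static_vars).
    static = sequences[:, static_time_step, :][:, static_indices]
    # Dynamic features keep the time axis: (batch, sequence_length, num_dynamic_vars).
    dynamic = sequences[:, :, dynamic_indices]
    # Default reshape documented above: append a trailing axis of size 1.
    return static[..., np.newaxis], dynamic[..., np.newaxis]


rng = np.random.default_rng(0)
X = rng.random((100, 10, 5))
S, D = split_static_dynamic_sketch(X, [0, 1], [2, 3, 4])
print(S.shape, D.shape)  # (100, 2, 1) (100, 10, 3, 1)
```

Reading the static block from one time step is only safe when those features really are constant along the sequence, which is the assumption stated in the Notes.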
- geoprior.models.utils.squeeze_last_dim_if(tensors, output_dims)[source]#
Squeeze the last dimension of tensor(s) if it equals 1 based on output_dims.
- output_dims can be:
a single int: apply that rule to every tensor
a list of ints: must match length of tensors list, each applied to corresponding tensor
a dict mapping keys to ints: each key must appear in tensors dict; unmatched keys in output_dims will warn but be ignored
- Parameters:
tensors (tf.Tensor or list of tf.Tensor or dict of tf.Tensor) – If a single tf.Tensor, processed with output_dims if int. If a list, each element is processed with its corresponding entry in output_dims if list, or with the single int if int. If a dict, each value is processed with the matching int in output_dims dict, or with the single int if output_dims is int.
output_dims (int or list of int or dict of int) – If int: all tensors use that output dimension rule. If list: length must equal len(tensors) when tensors is list. If dict: keys map to ints; for any key in tensors missing from output_dims, no squeeze is applied; keys in output_dims absent in tensors produce a warning.
- Returns:
Same structure as tensors, but with trailing size-1 axes removed where output_dim == 1. If output_dim != 1, that tensor is unchanged.
- Return type:
tf.Tensor or list of tf.Tensor or dict of tf.Tensor
- Raises:
TypeError – If tensors is not a tf.Tensor, list, or dict; or if output_dims type is invalid; or if list lengths do not match.
ValueError – If list lengths mismatch.
Examples
>>> import tensorflow as tf
>>> from geoprior.utils.generic_utils import squeeze_last_dim_if
>>>
>>> # Single tensor, single int → squeeze if last dim=1
>>> t = tf.zeros((8, 5, 1))
>>> out = squeeze_last_dim_if(t, output_dims=1)
>>> out.shape
TensorShape([8, 5])
>>>
>>> # List of tensors, list of ints
>>> t_list = [tf.zeros((4, 1)), tf.zeros((4, 2, 1))]
>>> out_list = squeeze_last_dim_if(t_list, output_dims=[1, 2])
>>> [x.shape for x in out_list]
[TensorShape([4]), TensorShape([4, 2, 1])]
>>>
>>> # Dict of tensors, dict of ints
>>> t_dict = {
...     "a": tf.zeros((3, 1, 1)),
...     "b": tf.zeros((3, 1, 2))
... }
>>> out_dict = squeeze_last_dim_if(t_dict, output_dims={"a": 1, "b": 2})
>>> {k: v.shape for k, v in out_dict.items()}
{'a': TensorShape([3, 1]), 'b': TensorShape([3, 1, 2])}
Notes
This function does a shallow traversal:
- If tensors is list, each element must be tf.Tensor.
- If tensors is dict, each value must be tf.Tensor.
For dict mode: missing keys in output_dims → no change; extra keys in output_dims → warning.
See also
tf.squeeze: Remove dimensions of size 1 from the shape of a tensor.
tf.reshape: Manually reshape a tensor if more complex edits are needed.
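The dict-mode rules (missing keys leave a tensor unchanged, extra keys only warn) are easiest to see in a small stand-in. The sketch below uses NumPy arrays in place of tf.Tensor objects so it runs without TensorFlow; squeeze_last_dim_if_sketch is a hypothetical name and this is not the library implementation.

```python
import warnings

import numpy as np


def _squeeze_one(arr, out_dim):
    # Drop the last axis only when the rule says 1 and the axis really has size 1.
    if out_dim == 1 and arr.ndim > 0 and arr.shape[-1] == 1:
        return np.squeeze(arr, axis=-1)
    return arr


def squeeze_last_dim_if_sketch(tensors, output_dims):
    """Hypothetical NumPy analogue of the documented rules (not library code)."""
    if isinstance(tensors, dict):
        if isinstance(output_dims, dict):
            extra = set(output_dims) - set(tensors)
            if extra:
                warnings.warn(f"output_dims keys not in tensors: {sorted(extra)}")
            # Keys missing from output_dims are passed through untouched.
            return {k: _squeeze_one(v, output_dims[k]) if k in output_dims else v
                    for k, v in tensors.items()}
        return {k: _squeeze_one(v, output_dims) for k, v in tensors.items()}
    if isinstance(tensors, list):
        dims = output_dims if isinstance(output_dims, list) else [output_dims] * len(tensors)
        if len(dims) != len(tensors):
            raise ValueError("output_dims and tensors must have the same length")
        return [_squeeze_one(t, d) for t, d in zip(tensors, dims)]
    return _squeeze_one(tensors, output_dims)


arrays = {"a": np.zeros((3, 1, 1)), "b": np.zeros((3, 1, 2)), "c": np.zeros((3, 1))}
out = squeeze_last_dim_if_sketch(arrays, {"a": 1, "b": 2})
print({k: v.shape for k, v in out.items()})
```

Here "c" has no entry in output_dims, so it keeps its trailing axis even though that axis has size 1.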
- geoprior.models.utils.step_to_long(df, tname=None, dt_col=None, spatial_cols=None, mode='quantile', quantiles=None, verbose=3, sort=True)[source]#
Convert a multi-step forecast DataFrame from wide to long format.
This function transforms a DataFrame containing multi-step forecast predictions into a long-format DataFrame. In quantile mode, forecast columns such as subsidence_q10_step1, subsidence_q50_step1, etc. are consolidated into unified columns (e.g. subsidence_q10, subsidence_q50, etc.), while in point mode, a single prediction column (subsidence_pred) is generated. The transformation also carries over additional columns (e.g. spatial coordinates and time) from the original DataFrame.
- Parameters:
df (pandas.DataFrame) – The multi-step forecast DataFrame. Expected to contain forecast prediction columns (e.g. columns with _q or _pred_step in their names) along with other identifiers.
tname (str, optional) – The base name of the target variable (e.g. "subsidence"). If None, the function attempts to auto-detect the target name from the column names.
dt_col (str, optional) – The name of the time column to include in the final DataFrame. If not provided, time sorting is not performed.
spatial_cols (list of str, optional) – A list of spatial coordinate columns (e.g. ["longitude", "latitude"]) to be retained in the final output.
mode ({"quantile", "point"}, default "quantile") – The forecast mode. In "quantile" mode, multiple quantile forecast columns are merged into unified columns. In "point" mode, a single prediction column is produced.
quantiles (list of float, optional) – The quantile values for quantile mode (e.g. [0.1, 0.5, 0.9]). If not provided, defaults are used.
sort (bool, optional) – If True, sorts the final DataFrame by the column specified in dt_col (if present). Default is True.
verbose (int, optional) – Verbosity level for logging output. Higher values (e.g. 5 to 7) provide more detailed debug information.
- Returns:
A long-format DataFrame containing the retained spatial columns, the time column when dt_col is provided, and the merged forecast prediction columns. In quantile mode, the output contains unified columns such as subsidence_q10 and subsidence_q50. In point mode, it contains a single subsidence_pred column.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.utils import step_to_long
>>> # Given a DataFrame `forecast_df` with columns like:
>>> # ['longitude', 'latitude', 'year', 'subsidence_actual',
>>> #  'subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1',
>>> #  'subsidence_q10_step2', ...]
>>> long_df = step_to_long(
...     df=forecast_df,
...     tname="subsidence",
...     dt_col="year",
...     spatial_cols=["longitude", "latitude"],
...     mode="quantile",
...     quantiles=[0.1, 0.5, 0.9],
...     verbose=3,
...     sort=True
... )
>>> print(long_df.head())
Notes
Internally, this function calls:
- check_forecast_mode() to validate the user-specified quantiles.
- validate_consistency_q() and validate_quantiles() to ensure the supplied quantiles match those auto-detected from the DataFrame.
- Depending on mode, either _step_to_long_q() for quantile mode or _step_to_long_pred() for point mode performs the conversion.
Mathematically, let \(X \in \mathbb{R}^{n \times m}\) represent the wide-format DataFrame, where each row corresponds to one sample and each forecast step is stored in separate columns. The function reshapes \(X\) into a long-format DataFrame \(Y \in \mathbb{R}^{(n \cdot s) \times p}\), where \(s\) is the forecast horizon and \(p\) is the number of output columns after merging forecast step values.
See also
_step_to_long_q: Converts multi-step quantile forecasts to long format.
_step_to_long_pred: Converts multi-step point forecasts to long format.
detect_digits: Extracts numeric values from strings for quantile detection.
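The wide-to-long reshaping described in the Notes can be illustrated with plain pandas. This is a minimal sketch of the idea, not the library's conversion routine: the helper name step_to_long_sketch and the forecast_step column name are assumptions made for the example.

```python
import re

import pandas as pd


def step_to_long_sketch(df, tname, id_cols):
    """Hypothetical wide-to-long reshape for '<tname>_q<NN>_step<k>' columns."""
    pattern = re.compile(rf"^({re.escape(tname)}_(?:q\d+|pred))_step(\d+)$")
    by_step = {}
    for col in df.columns:
        match = pattern.match(col)
        if match:
            # Map e.g. 'subsidence_q10_step2' -> 'subsidence_q10' under step 2.
            by_step.setdefault(int(match.group(2)), {})[col] = match.group(1)
    frames = []
    for step in sorted(by_step):
        mapping = by_step[step]
        part = df[id_cols + list(mapping)].rename(columns=mapping)
        part.insert(len(id_cols), "forecast_step", step)
        frames.append(part)
    # n samples and s steps give n * s rows, as in the Notes above.
    return pd.concat(frames, ignore_index=True)


wide = pd.DataFrame({
    "longitude": [113.1, 113.2],
    "latitude": [22.5, 22.6],
    "subsidence_q10_step1": [1.0, 2.0],
    "subsidence_q50_step1": [1.5, 2.5],
    "subsidence_q90_step1": [2.0, 3.0],
    "subsidence_q10_step2": [1.1, 2.1],
    "subsidence_q50_step2": [1.6, 2.6],
    "subsidence_q90_step2": [2.1, 3.1],
})
long_df = step_to_long_sketch(wide, "subsidence", ["longitude", "latitude"])
print(long_df)
```

With two samples and two forecast steps, the result has four rows and one unified column per quantile.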
- geoprior.models.utils.export_keras_losses(history, keys=None, savefile=None, verbose=0, formats=('json', 'csv'))[source]#
Export loss(es) (and any other metric) from a Keras History object.
- Parameters:
history (History) – The History object returned by model.fit().
keys (list[str] | None) – List of history.history keys to export, e.g. ["loss", "val_loss", "physics_loss"]. If None, defaults to all keys ending with "loss".
savefile (str | None) – Path (optionally with extension) where to write the output. If the extension is .json, only JSON is written. If the extension is .csv, only CSV is written. If no extension is given, all formats in formats are written using savefile as the base name.
verbose (int) – If >0, prints status messages.
formats (tuple[str, ...]) – File formats to write when savefile has no extension. Valid entries are "json" and "csv".
- Returns:
result – Dictionary containing "epochs_run" and one list entry per exported history key.
- Return type:
dict
- geoprior.models.utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)[source]#
Safely retrieves the first available tensor from a dictionary using a list of possible keys.
This utility is crucial for handling model inputs within a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of Tensors (e.g., tensor_a or tensor_b), which causes runtime errors, by explicitly checking that each candidate is not None.
- Parameters:
inputs (dict) – The dictionary to search, typically the model's input dictionary (e.g., the inputs provided to call or train_step).
*tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.
default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.
check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.
auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.
- Returns:
The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.
- Return type:
Optional[tf.Tensor]
- Raises:
TypeError – If inputs is not a dictionary.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.], dtype=np.float32)}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
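The boolean-evaluation problem mentioned in the description can be reproduced without TensorFlow, because multi-element NumPy arrays refuse truth-value tests in a similar way. The sketch below uses arrays as stand-ins for tensors; first_available is a hypothetical helper that shows only the "is not None" lookup, not the type checking or conversion done by get_tensor_from.

```python
import numpy as np


def first_available(inputs, *names, default=None):
    """Hypothetical priority lookup that never evaluates a value as a boolean."""
    if not isinstance(inputs, dict):
        raise TypeError("inputs must be a dict")
    for name in names:
        value = inputs.get(name)
        if value is not None:  # explicit check, no truth-value test
            return value
    return default


inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}

# The tempting shortcut fails: `or` asks the array for a single truth value.
try:
    chosen = inputs["soil_thickness"] or inputs["H_field"]
except ValueError as exc:
    shortcut_error = type(exc).__name__
    print("shortcut failed with", shortcut_error)

chosen = first_available(inputs, "H_field", "soil_thickness")
print(chosen)
```

The explicit loop also makes the priority order visible: "H_field" is tried first and skipped only because its value is None.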
- geoprior.models.utils.format_pihalnet_predictions(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, model_name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)[source]#
Formats PIHALNet/GeoPriorSubsNet predictions into a structured pandas DataFrame, handling inversion, quantiles, and coordinates.
This function is the core formatter. It:
1. Gets model outputs (or uses provided ones).
2. Unpacks 'data_final' if model_name is 'geoprior'.
3. Inverse-transforms all prediction and actual arrays using scaler_info.
4. Builds a long-format DataFrame with sample_idx and forecast_step.
5. Appends inverted quantile/point predictions.
6. Appends inverted actual values.
7. Appends inverted coordinates.
8. Appends static/ID columns.
9. Evaluates coverage on the inverted data.
- Parameters:
pihalnet_outputs (dict, optional) – Raw output from model.predict(). If None, model and model_inputs must be provided.
model (tf.keras.Model, optional) – Trained model instance (if pihalnet_outputs is None).
model_inputs (dict, optional) – Inputs for the model to generate predictions (if pihalnet_outputs is None).
y_true_dict (dict, optional) – Dictionary of true target arrays (e.g., {'subs_pred': y_true_s}). Required for including actuals and evaluating coverage.
target_mapping (dict, optional) – Maps prediction keys to base names for DataFrame columns. Default: {'subs_pred': 'subsidence', 'gwl_pred': 'gwl'}.
include_gwl (bool, default True) – Whether to include 'gwl_pred' in the final DataFrame.
include_coords (bool, default True) – Whether to include 'coord_t', 'coord_x', 'coord_y' columns.
quantiles (list[float], optional) – List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided, quantile columns (e.g., 'subsidence_q10') are created.
forecast_horizon (int, optional) – The forecast horizon length (H). If not provided, it is inferred from the prediction array's shape.
output_dims (dict, optional) – Maps prediction keys to their output dimension (O). E.g., {'subs_pred': 1, 'gwl_pred': 1}. Crucial for correctly splitting GeoPrior outputs and reshaping.
ids_data_array (np.ndarray or pd.DataFrame, optional) – Static/ID data (e.g., original coordinates) to merge. Must have the same number of samples (B) as predictions.
ids_cols (list[str], optional) – Column names if ids_data_array is a DataFrame.
ids_cols_indices (list[int], optional) – Column indices if ids_data_array is a NumPy array.
scaler_info (dict, optional) – Dictionary for inverse scaling. Each target entry should provide a fitted scaler, the target index inside that scaler, and the feature-name ordering used when the scaler was fit.
coord_scaler (sklearn.preprocessing.Scaler, optional) – A fitted scaler object for inverse transforming the 'coords' tensor.
evaluate_coverage (bool, default False) – If True, calculates coverage percentage for quantiles.
coverage_quantile_indices (tuple[int, int], default (0, -1)) – Indices of the low and high quantiles in the quantiles list to use for coverage (e.g., 0 and -1 for 10th and 90th).
savefile (str, optional) – If provided, saves the final DataFrame to this path.
model_name (str, optional) – Specifies the model type. If 'geoprior' or 'geopriorsubsnet', triggers unpacking of the 'data_final' output.
apply_mask (bool, default False) – If True, masks predictions based on mask_values in the first target's _actual column.
mask_values (float or int, optional) – The value in the _actual column to trigger masking.
mask_fill_value (float, optional) – The value to replace masked predictions with (e.g., np.nan).
verbose (int, default 0) – Logging verbosity.
_logger (logging.Logger or callable, optional) – Logger object.
stop_check (callable, optional) – Function to check for early stopping.
name (str | None)
- Returns:
A long-format DataFrame with predictions, actuals, and coordinates.
- Return type:
pd.DataFrame
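Steps 4, 5, 6 and 9 of the list above can be sketched in NumPy and pandas. This is a simplified illustration, not the formatter itself: to_long_frame and interval_coverage are hypothetical helpers, the prediction array is assumed to have shape (B, H, Q) with values already in physical units, and forecast_step is assumed to start at 1.

```python
import numpy as np
import pandas as pd


def to_long_frame(pred, y_true, quantiles, base="subsidence"):
    """Hypothetical long-format builder for pred (B, H, Q) and y_true (B, H)."""
    n_samples, horizon, _ = pred.shape
    df = pd.DataFrame({
        "sample_idx": np.repeat(np.arange(n_samples), horizon),
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    })
    for j, q in enumerate(quantiles):
        df[f"{base}_q{int(round(q * 100))}"] = pred[:, :, j].reshape(-1)
    df[f"{base}_actual"] = y_true.reshape(-1)
    return df


def interval_coverage(df, quantiles, base="subsidence", idx=(0, -1)):
    """Share of actual values inside the [low, high] quantile band."""
    lo = f"{base}_q{int(round(quantiles[idx[0]] * 100))}"
    hi = f"{base}_q{int(round(quantiles[idx[1]] * 100))}"
    actual = df[f"{base}_actual"]
    return float(((actual >= df[lo]) & (actual <= df[hi])).mean())


quantiles = [0.1, 0.5, 0.9]
median = np.arange(6, dtype=float).reshape(2, 3)           # B=2 samples, H=3 steps
pred = np.stack([median - 1.0, median, median + 1.0], axis=-1)
y_true = median + np.array([[0.0, 0.5, 2.0], [0.0, -2.0, 0.5]])

df = to_long_frame(pred, y_true, quantiles)
coverage = interval_coverage(df, quantiles)
print(df)
print(round(coverage, 3))
```

Two of the six actual values fall outside the q10 to q90 band, so the computed coverage is 4/6.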
- geoprior.models.utils.normalize_for_pinn(df, time_col, coord_x, coord_y, cols_to_scale='auto', scale_coords=True, exclude_cols=None, protect_si_suffix='__si', shift_time_by_horizon=False, verbose=1, forecast_horizon=None, _logger=None, coord_scaler=None, fit_coord_scaler=True, other_scaler=None, fit_other_scaler=True, **kws)[source]#
Apply Min-Max normalization to spatial-temporal coordinates and optionally to other numeric columns. If cols_to_scale == "auto", automatically select numeric columns excluding categorical and one-hot features.
By default, this function scales the time, longitude, and latitude columns (if scale_coords=True). Then, it either scales explicitly provided columns in cols_to_scale or automatically infers numeric columns (excluding coordinates if scale_coords is False, and excluding one-hot/boolean columns).
The Min-Max scaling for a feature \(x\) is:
\[x' = \frac{x - \min(x)}{\max(x) - \min(x)} \tag{55}\]
- Parameters:
df (
pd.DataFrame) – The input DataFrame containing at least time_col, coord_x, and coord_y columns. The DataFrame should contain temporal and spatial information to be scaled.
time_col (str) – The name of the numeric time column (e.g., year as numeric or datetime). This column will be used to adjust and scale the temporal data.
coord_x (str) – The name of the longitude column in the DataFrame. This column will be scaled along with the latitude and time columns.
coord_y (str) – The name of the latitude column in the DataFrame. This column will be scaled along with the longitude and time columns.
cols_to_scale (list of str or "auto" or None, default "auto") – If a list of column names, scales exactly those columns. If "auto", selects all numeric columns, excluding time_col, coord_x, and coord_y if scale_coords=False, and excluding one-hot encoded columns whose values are only {0, 1}. If None, no extra columns are scaled.
scale_coords (bool, default True) – If True, scales the [time_col, coord_x, coord_y] columns. If False, these columns remain unchanged.
verbose (int, default 1) – Verbosity level for logging. Values higher than 1 provide more detailed logging information.
forecast_horizon (Optional[int], default None) – The number of time steps to shift the time column by. This is added to the time values before scaling if provided.
_logger (Optional[Union[logging.Logger, Callable[[str], None]]], default None) – Logger or function to handle logging messages. If None, the default logging mechanism is used.
kws (dict, optional) – These will be passed on to any other internal function used in the data processing or scaling steps.
protect_si_suffix (str)
shift_time_by_horizon (bool)
coord_scaler (MinMaxScaler | None)
fit_coord_scaler (bool)
other_scaler (MinMaxScaler | None)
fit_other_scaler (bool)
- Returns:
df_scaled (pd.DataFrame) – A new DataFrame with the specified columns normalized.
coord_scaler (MinMaxScaler or None) – The fitted scaler for the [time_col, coord_x, coord_y] columns if scale_coords=True, else None.
other_scaler (MinMaxScaler or None) – The fitted scaler for any additional columns that were scaled (either explicitly provided or auto-selected). None if no columns were scaled beyond the coordinates.
- Raises:
TypeError – If df is not a DataFrame, or cols_to_scale is neither a list nor “auto” nor None, or if any explicitly provided column is not a string.
ValueError – If required columns (time_col, coord_x, coord_y) or any of cols_to_scale do not exist in df, or cannot be converted to numeric.
- Return type:
tuple[DataFrame, MinMaxScaler | None, MinMaxScaler | None]
Examples
>>> import pandas as pd
>>> from geoprior.models.utils import normalize_for_pinn
>>> data = {
...     "year_num": [0.0, 1.0, 2.0],
...     "lon": [100.0, 101.0, 102.0],
...     "lat": [30.0, 31.0, 32.0],
...     "feat1": [10.0, 20.0, 30.0],
...     "one_hot_A": [0, 1, 0]
... }
>>> df = pd.DataFrame(data)
>>> df_scaled, coord_scl, feat_scl = normalize_for_pinn(
...     df,
...     time_col="year_num",
...     coord_x="lon",
...     coord_y="lat",
...     cols_to_scale="auto",
...     scale_coords=True,
...     verbose=2
... )
>>> # 'year_num', 'lon', 'lat', 'feat1' get scaled; 'one_hot_A' excluded
>>> df_scaled["year_num"].tolist()
[0.0, 0.5, 1.0]
>>> df_scaled["feat1"].tolist()
[0.0, 0.5, 1.0]
Notes
When cols_to_scale="auto", numeric columns with only {0, 1} values are assumed to be one-hot and excluded from scaling.
If scale_coords=False, coordinate columns remain unchanged, and auto-selection (if used) will exclude them.
Returned coord_scaler is None if scale_coords=False. Returned other_scaler is None if cols_to_scale is None or results in an empty set after filtering.
See also
sklearn.preprocessing.MinMaxScaler: Scales features to [0, 1].
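The auto-selection rule and equation (55) can be written out in plain pandas. The sketch below is an illustration only: auto_scale_columns and min_max are hypothetical helpers, no scaler objects are fitted or returned, and the coordinate columns are treated as a separate group from the auto-selected features, which is an assumption about how the two scalers divide the work.

```python
import pandas as pd


def auto_scale_columns(df, coord_cols):
    """Hypothetical auto-selection: numeric, not a coordinate, not {0, 1}-only."""
    selected = []
    for col in df.select_dtypes(include="number").columns:
        if col in coord_cols:
            continue
        if set(df[col].dropna().unique()) <= {0, 1}:
            continue  # treated as one-hot and left unscaled
        selected.append(col)
    return selected


def min_max(series):
    """Equation (55); a constant column maps to zeros to avoid dividing by zero."""
    span = series.max() - series.min()
    if span == 0:
        return series * 0.0
    return (series - series.min()) / span


df = pd.DataFrame({
    "year_num": [0.0, 1.0, 2.0],
    "lon": [100.0, 101.0, 102.0],
    "lat": [30.0, 31.0, 32.0],
    "feat1": [10.0, 20.0, 30.0],
    "one_hot_A": [0, 1, 0],
})
coord_cols = ["year_num", "lon", "lat"]
auto_cols = auto_scale_columns(df, coord_cols)

scaled = df.copy()
for col in coord_cols + auto_cols:
    scaled[col] = min_max(df[col])
print(auto_cols)
print(scaled)
```

Only feat1 is auto-selected here; one_hot_A keeps its original 0/1 values.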
- geoprior.models.utils.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)[source]#
- Parameters:
df (DataFrame)
time_col (str)
subsidence_col (str)
gwl_col (str)
h_field_col (str | None)
lon_col (str | None)
lat_col (str | None)
time_steps (int)
forecast_horizon (int)
output_subsidence_dim (int)
output_gwl_dim (int)
datetime_format (str | None)
normalize_coords (bool)
lock_physics_cols (bool)
protect_si_suffix (str)
return_coord_scaler (bool)
coord_scaler (MinMaxScaler | None)
fit_coord_scaler (bool)
mode (str | None)
model (str | None)
savefile (str | None)
verbose (int)
- Return type:
tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
- geoprior.models.utils.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)[source]#
Formats PINN model predictions into a structured pandas DataFrame.
This is a general-purpose utility for transforming raw model outputs (from models like PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export. It handles multi-target outputs (e.g., subsidence and GWL), point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse-scaling of predictions and evaluation of quantile coverage.
- Parameters:
predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model's .predict() method. Keys should match the model's output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.
model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.
model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model's signature, required only if predictions is None. Default is None.
y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, a <target>_actual column will be added to the output DataFrame for comparison. Default is None.
target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.
include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.
include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.
quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is crucial for correctly parsing probabilistic forecasts. Default is None.
forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.
output_dims (dict of str, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it is inferred from the tensor shapes. Default is None.
ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.
ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.
ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.
scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., 'subsidence') and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.
coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.
evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.
coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. Default is (0, -1), which corresponds to the full range.
savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.
name (str or None) – Name of the prediction. It is used to label the output data and the coverage result, if applicable.
model_name (str or None) – Name of the model.
verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).
**kwargs (dict) – Additional keyword arguments for future extensions.
- Returns:
A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.
- Return type:
pd.DataFrame
Notes
The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.
For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.
For point forecasts, the column is named <target_name>_pred.
See also
geoprior.plot.forecast.plot_forecasts: A powerful utility for visualizing the DataFrame produced by this function.
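Two details above are easier to follow with a worked example: how a scaler_info entry of the form {'scaler', 'all_features', 'idx'} can be used to undo scaling for one target, and how the quantile column names are built. The sketch is hypothetical throughout: TinyMinMaxScaler is a stand-in that only mimics the scikit-learn inverse_transform call, and inverse_scale_target and quantile_column are illustrative helpers, not library functions.

```python
import numpy as np


class TinyMinMaxScaler:
    """Hypothetical stand-in exposing a scikit-learn-style inverse_transform."""

    def __init__(self, data_min, data_max):
        self.data_min_ = np.asarray(data_min, dtype=float)
        self.data_max_ = np.asarray(data_max, dtype=float)

    def inverse_transform(self, X):
        X = np.asarray(X, dtype=float)
        return X * (self.data_max_ - self.data_min_) + self.data_min_


def inverse_scale_target(values, info):
    """Undo scaling for one target using its column position in the scaler."""
    values = np.asarray(values, dtype=float)
    # The scaler was fit on all features, so build a dummy matrix of that width,
    # fill only the target column, invert, and read the same column back.
    dummy = np.zeros((values.size, len(info["all_features"])))
    dummy[:, info["idx"]] = values.reshape(-1)
    restored = info["scaler"].inverse_transform(dummy)[:, info["idx"]]
    return restored.reshape(values.shape)


def quantile_column(base, q):
    """Column name following the <target_name>_q<quantile*100> pattern."""
    return f"{base}_q{int(round(q * 100))}"


scaler_info = {
    "subsidence": {
        "scaler": TinyMinMaxScaler(data_min=[0.0, -50.0], data_max=[10.0, 0.0]),
        "all_features": ["gwl", "subsidence"],
        "idx": 1,
    }
}
scaled_pred = np.array([[0.0, 0.5], [1.0, 0.2]])  # (samples, horizon), scaled units
restored = inverse_scale_target(scaled_pred, scaler_info["subsidence"])
names = [quantile_column("subsidence", q) for q in (0.05, 0.5, 0.95)]
print(restored)
print(names)
```

The dummy-matrix step is what makes the 'all_features' and 'idx' entries necessary: a scaler fit on several columns cannot invert a single column on its own.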
- geoprior.models.utils.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]#
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.
- Parameters:
inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. Can be a single tensor or a dictionary with a 'coords' key or separate 't', 'x', 'y' keys.
coord_slice_map (dict, optional) – Maps 't', 'x', 'y' to their index in the last dimension of a coordinate tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
expect_dim ({'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input's dimension. '2d' requires input shaped like (batch, 3). '3d' accepts 3D input and expands 2D input to 3D with a time dimension of 1. '3d_only' requires 3D input and raises an error for 2D input. None accepts both 2D and 3D inputs without changing their rank.
verbose (int, default 0) – Controls logging verbosity.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.
- Return type:
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
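To make the expect_dim rules concrete, here is a NumPy-only sketch that slices a coordinate array with coord_slice_map and applies the four documented modes. The function split_txy is a hypothetical stand-in, and keeping a trailing singleton axis on each output is an assumption; the real function operates on TensorFlow tensors and also accepts dictionaries.

```python
import numpy as np

def split_txy(coords, coord_slice_map=None, expect_dim=None):
    """Hypothetical NumPy stand-in mimicking the documented slicing rules."""
    idx = coord_slice_map or {"t": 0, "x": 1, "y": 2}
    coords = np.asarray(coords)
    if coords.ndim not in (2, 3) or coords.shape[-1] < 3:
        raise ValueError("coords must be (batch, 3) or (batch, time, 3)")
    if expect_dim == "2d" and coords.ndim != 2:
        raise ValueError("expect_dim='2d' requires a 2D input")
    if expect_dim == "3d_only" and coords.ndim != 3:
        raise ValueError("expect_dim='3d_only' requires a 3D input")
    if expect_dim == "3d" and coords.ndim == 2:
        coords = coords[:, None, :]  # add a time axis of length 1
    # Slice with ranges so each output keeps a trailing singleton axis.
    return tuple(coords[..., idx[k]:idx[k] + 1] for k in ("t", "x", "y"))

coords_2d = np.arange(12, dtype=float).reshape(4, 3)
t_a, x_a, y_a = split_txy(coords_2d)                    # rank preserved
t_b, x_b, y_b = split_txy(coords_2d, expect_dim="3d")   # expanded to 3D
```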
- geoprior.models.utils.plot_hydraulic_head(model, t_slice, x_bounds, y_bounds, resolution=100, ax=None, title=None, cmap='viridis', colorbar_label='Hydraulic Head (h)', save_path=None, show_plot=True, **contourf_kwargs)[source]#
Generate and plot a 2D contour map of a hydraulic head solution.
This utility visualizes the output of a Physics-Informed Neural Network (PINN) that solves for the hydraulic head \(h(t, x, y)\). It automates the process of creating a spatial grid, running model predictions, and generating a publication-quality contour plot for a specific slice in time.
- Parameters:
model (tf.keras.Model) – The trained PINN model. It is expected to have a .predict() method that accepts a dictionary of tensors with keys {'t', 'x', 'y'}.
t_slice (float) – The specific point in time \(t\) for which to plot the 2D spatial solution.
x_bounds (tuple of float) – A tuple (x_min, x_max) defining the spatial domain for the x-axis.
y_bounds (tuple of float) – A tuple (y_min, y_max) defining the spatial domain for the y-axis.
resolution (int, optional) – The number of points to sample along each spatial axis, creating a grid of resolution x resolution points for prediction. Higher values result in a smoother plot. Default is 100.
ax (matplotlib.axes.Axes, optional) – A pre-existing Matplotlib Axes object to plot on. If None, a new figure and axes are created internally. This is useful for embedding this plot within a larger figure arrangement. Default is None.
title (str, optional) – A custom title for the plot. If None, a default title is generated using the value of t_slice. Default is None.
cmap (str, optional) – The name of the Matplotlib colormap to use for the contour plot. Default is 'viridis'.
colorbar_label (str, optional) – The text label for the color bar. Default is 'Hydraulic Head (h)'.
save_path (str, optional) – If provided, the path (including filename and extension) where the generated plot will be saved. This is only active when the function creates its own figure (i.e., when ax is None). Default is None.
show_plot (bool, optional) – If True, calls plt.show() to display the plot. This is only active when the function creates its own figure. Default is True.
**contourf_kwargs (any) – Additional keyword arguments that are passed directly to the matplotlib.pyplot.contourf function. This allows for advanced customization (e.g., levels=20, extend='both').
- Returns:
ax (matplotlib.axes.Axes) – The Matplotlib Axes object on which the contour plot was drawn.
contour (matplotlib.cm.ScalarMappable) – The contour plot object, which can be used for further customizations, such as modifying the color bar.
- Return type:
See also
geoprior.models.pinn.PiTGWFlow: The PINN model this function is designed to visualize.
Notes
The core mechanism of this function involves creating a 2D meshgrid of \((x, y)\) coordinates. These grid points are then “flattened” into a long list of points, as the PINN model expects a batch of individual coordinates for prediction, not a grid.
The prediction process is as follows:
A grid of shape (resolution, resolution) is created for \(x\) and \(y\).
These grids are reshaped into column vectors of shape (resolution*resolution, 1).
A time vector of the same shape, filled with t_slice, is created.
The model's .predict() method is called on these flat tensors.
The resulting flat prediction vector is reshaped back to the original (resolution, resolution) grid shape for plotting.
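The grid-flatten-predict-reshape steps above can be sketched with NumPy alone. Here grid_predict and analytic_head are hypothetical names: the callable stands in for a trained model's .predict() method, and no plotting is performed.

```python
import numpy as np

def grid_predict(predict_fn, t_slice, x_bounds, y_bounds, resolution=100):
    """Evaluate ``predict_fn`` on a regular (x, y) grid at a fixed time."""
    x = np.linspace(x_bounds[0], x_bounds[1], resolution)
    y = np.linspace(y_bounds[0], y_bounds[1], resolution)
    X, Y = np.meshgrid(x, y)                 # each (resolution, resolution)
    x_flat = X.reshape(-1, 1)                # (resolution*resolution, 1)
    y_flat = Y.reshape(-1, 1)
    t_flat = np.full_like(x_flat, t_slice)   # same shape, constant time
    h_flat = predict_fn({"t": t_flat, "x": x_flat, "y": y_flat})
    return X, Y, np.asarray(h_flat).reshape(X.shape)

def analytic_head(inputs):
    """Stand-in for model.predict: a smooth analytical field."""
    return (np.sin(np.pi * inputs["x"])
            * np.cos(np.pi * inputs["y"])
            * np.exp(-inputs["t"]))

X, Y, H = grid_predict(analytic_head, 0.5, (-1, 1), (-1, 1), resolution=50)
```

The returned X, Y, and H arrays have matching shapes, so they could be handed to a contour routine such as matplotlib's contourf.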
If a custom ax is provided, the user is responsible for calling plt.show() or saving the parent figure.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> from geoprior.models.utils import plot_hydraulic_head
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
...     def call(self, inputs):
...         # A simple analytical function for demonstration
...         t, x, y = inputs['t'], inputs['x'], inputs['y']
...         return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()
1. Simple Plotting Example
This example creates a single plot and saves it to a file.
>>> ax, contour = plot_hydraulic_head(
...     model=mock_model,
...     t_slice=0.5,
...     x_bounds=(-1, 1),
...     y_bounds=(-1, 1),
...     resolution=50,
...     save_path="hydraulic_head_t0.5.png",
...     show_plot=False  # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png
2. Advanced Example with Subplots
This example shows how to use the ax parameter to draw the solution at two different times side-by-side in one figure.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
...
>>> # Plot solution at t = 0.1
>>> plot_hydraulic_head(
...     model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
...
>>> # Plot solution at t = 1.0
>>> plot_hydraulic_head(
...     model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
...
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()
- geoprior.models.utils.make_dict_to_tuple_fn(feature_keys, target_keys=None, *, allow_missing_optional=True)[source]#
Create a tf.data.Dataset.map function that converts a feature dictionary into a positional tuple expected by a sub-classed Keras model.
- Parameters:
feature_keys (Sequence[str]) – Keys in the exact positional order required by the model that will populate the tuple. For a PINN, for example:
>>> feature_keys = [
...     "coords",            # (B, T, 3)
...     "static_features",   # (B, S)
...     "dynamic_features",  # (B, T, D)
...     "future_features",   # (B, H, F)
... ]
target_keys (Sequence[str] or None, default None) – If None, pass through the dataset's existing target element. If a sequence is provided, extract those keys from the feature dictionary or, when available, from the target dictionary. The function returns a single tensor when len(target_keys) == 1 and a {key: tensor} mapping otherwise.
allow_missing_optional (bool, default True) – Whether to substitute None for missing optional feature or target keys. If False, a missing key raises KeyError.
- Returns:
A map function you can plug into a tf.data.Dataset:
>>> mapper = make_dict_to_tuple_fn(feature_keys, ["subsidence", "gwl"])
>>> train_ds = train_ds.map(mapper, num_parallel_calls=tf.data.AUTOTUNE)
- Return type:
Callable
Notes
Duplicate names in feature_keys are rejected so the positional tuple cannot be ambiguous.
The function is fully AutoGraph-compatible; TensorFlow will stage it as part of the dataset pipeline.
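The dictionary-to-tuple idea can be illustrated without TensorFlow. The sketch below is a pure-Python analogue, not the library's code: make_mapper is a hypothetical name, and looking in the feature dictionary before the target dictionary is an assumed precedence.

```python
def make_mapper(feature_keys, target_keys=None, allow_missing_optional=True):
    """Hypothetical pure-Python analogue of the documented mapper factory."""
    if len(set(feature_keys)) != len(feature_keys):
        raise ValueError("feature_keys must not contain duplicates")

    def lookup(key, *sources):
        # Return the first match; fall back to None or raise, per the flag.
        for src in sources:
            if src is not None and key in src:
                return src[key]
        if allow_missing_optional:
            return None
        raise KeyError(key)

    def mapper(features, targets=None):
        inputs = tuple(lookup(k, features) for k in feature_keys)
        if target_keys is None:
            return inputs, targets            # pass targets through untouched
        target_dict = targets if isinstance(targets, dict) else None
        picked = {k: lookup(k, features, target_dict) for k in target_keys}
        if len(target_keys) == 1:
            return inputs, picked[target_keys[0]]
        return inputs, picked

    return mapper

features = {"coords": "C", "static_features": "S"}
targets = {"subsidence": "Y"}
mapper = make_mapper(["coords", "static_features", "future_features"],
                     ["subsidence"])
inputs, target = mapper(features, targets)
```

Because "future_features" is absent and allow_missing_optional is left at True, its slot in the tuple is filled with None rather than raising.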
- geoprior.models.utils.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]#
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.
- Parameters:
inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. A single tensor or array may be 2D with shape (batch, 3) or 3D with shape (batch, time_steps, 3). A dictionary may contain a 'coords' key with the coordinate tensor, or separate 't', 'x', and 'y' keys.
coord_slice_map (dict, optional) – Maps 't', 'x', 'y' to their index in the last dimension of a coordinate tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
expect_dim ({'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension. '2d' requires input shaped like (batch, 3) or a dictionary of (batch, 1) tensors. '3d' requires input shaped like (batch, time, 3) or a dictionary of (batch, time, 1) tensors. If None, both are accepted and 2D inputs are expanded to 3D.
verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).
- Return type:
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
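The practical difference from extract_txy is that the outputs are always 3D. A minimal NumPy sketch of that rank normalization, for both tensor and dictionary inputs, might look as follows; to_3d_txy is a hypothetical name and the default column order t, x, y is assumed.

```python
import numpy as np

def to_3d_txy(inputs):
    """Hypothetical sketch: return t, x, y as (batch, time, 1) arrays."""
    if isinstance(inputs, dict) and "coords" not in inputs:
        parts = [np.asarray(inputs[k]) for k in ("t", "x", "y")]
    else:
        coords = inputs["coords"] if isinstance(inputs, dict) else inputs
        coords = np.asarray(coords)
        parts = [coords[..., i:i + 1] for i in range(3)]  # keep last axis
    out = []
    for part in parts:
        if part.ndim == 2:
            part = part[:, None, :]          # (batch, 1) -> (batch, 1, 1)
        if part.ndim != 3 or part.shape[-1] != 1:
            raise ValueError("each coordinate must resolve to (batch, time, 1)")
        out.append(part)
    return tuple(out)

t2, x2, y2 = to_3d_txy(np.zeros((4, 3)))          # 2D tensor input
t3, x3, y3 = to_3d_txy(np.zeros((4, 5, 3)))       # 3D tensor input
td, xd, yd = to_3d_txy({"t": np.zeros((4, 1)),
                        "x": np.ones((4, 1)),
                        "y": np.ones((4, 1))})    # separate keys
```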
- geoprior.models.utils.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)[source]#
Normalize and validate PDE mode selection.
- Parameters:
pde_mode (str, sequence of str, or None) – Requested PDE mode(s).
Accepted canonical values are:
- "none"
- "consolidation"
- "gw_flow"
- "both"
Accepted aliases:
- None, "off" -> "none"
- "on" -> "both"
enforce_consolidation (bool, default False) – If True, any resolved mode other than exact ["consolidation"] is coerced to ["consolidation"] and a warning is emitted.
This includes:
- ["none"]
- ["gw_flow"]
- ["consolidation", "gw_flow"]
pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.
solo_return (bool, default False) – If False, return a canonical list of active modes.
If True, return a single canonical label:
- "none"
- "consolidation"
- "gw_flow"
- "both"
- Returns:
Canonical PDE mode(s), either as a list or a single label.
- Return type:
- Raises:
TypeError – If the input type is invalid.
ValueError – If a token is unsupported or the mode selection is ambiguous.
Examples
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
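The normalization rules documented above can be approximated in plain Python. This is a sketch under stated assumptions, not the library's implementation: normalize_pde_modes, ALIASES, and ORDER are hypothetical names, the pde_mode_config override is omitted, and treating "none" combined with an active mode as ambiguous is one possible reading of the ValueError clause.

```python
import warnings

CANONICAL = ("none", "consolidation", "gw_flow", "both")
ALIASES = {None: "none", "off": "none", "on": "both"}  # hypothetical table
ORDER = ("consolidation", "gw_flow")                   # canonical list order

def normalize_pde_modes(pde_mode, enforce_consolidation=False,
                        solo_return=False):
    if pde_mode is None or isinstance(pde_mode, str):
        tokens = [pde_mode]
    elif isinstance(pde_mode, (list, tuple, set)):
        tokens = list(pde_mode)
    else:
        raise TypeError("pde_mode must be a str, a sequence of str, or None")

    active, saw_none = set(), False
    for tok in tokens:
        if tok is not None and not isinstance(tok, str):
            raise TypeError(f"unsupported token type: {type(tok).__name__}")
        key = tok.strip().lower() if isinstance(tok, str) else None
        mode = ALIASES.get(key, key)
        if mode not in CANONICAL:
            raise ValueError(f"unsupported PDE mode: {tok!r}")
        if mode == "none":
            saw_none = True
        elif mode == "both":
            active.update(ORDER)
        else:
            active.add(mode)
    if saw_none and active:
        raise ValueError("ambiguous selection: 'none' mixed with active modes")

    modes = [m for m in ORDER if m in active] or ["none"]
    if enforce_consolidation and modes != ["consolidation"]:
        warnings.warn("coercing PDE mode to ['consolidation']")
        modes = ["consolidation"]
    if solo_return:
        return "both" if len(modes) == 2 else modes[0]
    return modes
```

Resolving everything to a fixed-order list first keeps the later steps (coercion and the solo label) independent of how the caller spelled the request.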
What this package covers#
The public exports of geoprior.models.utils combine
general model helpers from geoprior.models.utils._utils
with PINN-specific helpers from
geoprior.models.utils.pinn.
The current export surface includes functions such as:
prepare_model_inputs and prepare_model_inputs_in for standardizing model input tuples;
create_sequences and split_static_dynamic for sequence preparation;
forecast_single_step and forecast_multi_step for rollout helpers;
format_predictions_to_dataframe and format_pinn_predictions for output formatting;
prepare_pinn_data_sequences and normalize_for_pinn for PINN preparation;
extract_txy and extract_txy_in for coordinate extraction;
process_pde_modes and PDE_MODE_ALIASES for PDE-mode normalization.
General model helper module#
Utility functions for neural network models.
This module provides utility functions to preprocess data for Temporal Fusion Transformer (TFT) models, including splitting sequences into static and dynamic inputs and creating input sequences with corresponding targets for time series forecasting.
- geoprior.models.utils._utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)[source]
Split sequences into static and dynamic inputs for the model.
The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.
(56)#\[\begin{split}\text{Static Inputs} = \mathbf{S} = \mathbf{X}_{t, static\_indices} \\ \text{Dynamic Inputs} = \mathbf{D} = \mathbf{X}_{:, dynamic\_indices}\end{split}\]
- Parameters:
sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).
static_indices (List[int]) – Indices of static features within the feature dimension.
dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.
static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).
reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.
reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.
static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).
dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).
- Returns:
A tuple containing: - Static inputs with shape as specified. - Dynamic inputs with shape as specified.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If static_time_step is out of range for the given sequence length.
Examples
>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array with shape
>>> # (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)
Notes
Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.
Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.
Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
See also
geoprior.models.utils.create_sequences: Function to create input sequences and targets for time series forecasting.
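Equation (56) corresponds to two NumPy indexing operations. The sketch below shows them with the default trailing-axis reshape described in the parameter list; split_sketch is a hypothetical name and the custom reshape options are left out.

```python
import numpy as np

def split_sketch(sequences, static_indices, dynamic_indices,
                 static_time_step=0):
    """Hypothetical re-implementation of the default split behaviour."""
    if not 0 <= static_time_step < sequences.shape[1]:
        raise ValueError("static_time_step is out of range")
    # Static features: one time step, selected feature columns.
    static = sequences[:, static_time_step, :][:, static_indices]
    # Dynamic features: all time steps, selected feature columns.
    dynamic = sequences[:, :, dynamic_indices]
    # Default reshape: append a trailing singleton axis to both outputs.
    return static[..., np.newaxis], dynamic[..., np.newaxis]

rng = np.random.default_rng(0)
sequences = rng.random((100, 10, 5))
static_inputs, dynamic_inputs = split_sketch(sequences, [0, 1], [2, 3, 4])
```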
- geoprior.models.utils._utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)[source]
Create input sequences and corresponding targets for time series forecasting.
The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.
See more in User Guide.
- Parameters:
df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.
sequence_length (int) – The number of past time steps to include in each input sequence.
target_col (str) – The name of the target column.
step (int, default 1) – The step size between the starts of consecutive sequences.
include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.
drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.
forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.
verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).
- Returns:
- A tuple containing:
sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).
- targets:
If forecast_horizon is None: Array of target values with shape (num_sequences,).
If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If the DataFrame df does not contain the target_col.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences
>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })
>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)
>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)
Notes
Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.
- Forecast Horizon:
If forecast_horizon is None, the function creates targets for a single future time step.
If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.
Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.
Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.
Data Validation: The function utilizes are_all_frames_valid from geoprior.core.checks to ensure the integrity of input DataFrame before processing and exist_features to verify the presence of the target column.
The sequences generation can be expressed as:
(57)#\[\begin{split}\text{For each sequence } i, \\ \mathbf{X}^{(i)} = \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right] \\ y^{(i)} = \begin{cases} \mathbf{x}_{i+T} & \text{if } \text{forecast\_horizon} = \text{None} \\ \left[ \mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1} \right] & \text{if } \text{forecast\_horizon} = H \end{cases}\end{split}\]
- Where:
\(\mathbf{X}^{(i)}\) is the input sequence of length \(T\).
\(y^{(i)}\) is the target value(s) following the sequence.
See also
geoprior.models.utils.split_static_dynamic: Function to split sequences into static and dynamic inputs.
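The windowing in equation (57) can be written as a short NumPy loop. This sketch follows the equation literally; sliding_windows is a hypothetical name, only complete windows are kept, and the exact number of windows the library returns at the boundary may differ from this version.

```python
import numpy as np

def sliding_windows(values, target, sequence_length,
                    forecast_horizon=None, step=1):
    """Hypothetical sketch of equation (57): inputs X^(i) and targets y^(i)."""
    values = np.asarray(values)
    target = np.asarray(target)
    horizon = 1 if forecast_horizon is None else forecast_horizon
    last_start = len(values) - sequence_length - horizon
    X, y = [], []
    for i in range(0, last_start + 1, step):
        X.append(values[i:i + sequence_length])
        future = target[i + sequence_length:i + sequence_length + horizon]
        # Single-step targets are scalars; multi-step targets are vectors.
        y.append(future[0] if forecast_horizon is None else future)
    return np.array(X), np.array(y)

values = np.arange(20).reshape(10, 2)   # 10 time steps, 2 features
target = values[:, 1]
X_multi, y_multi = sliding_windows(values, target, 4, forecast_horizon=3)
X_single, y_single = sliding_windows(values, target, 4)
```

A step larger than 1 simply strides the loop, which reduces the overlap between consecutive windows as described in the notes.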
- geoprior.models.utils._utils.compute_forecast_horizon(data=None, dt_col=None, start_pred=None, end_pred=None, error='raise', verbose=1)[source]
Compute the forecast horizon for time series forecasting models.
This function calculates the number of future time steps (forecast_horizon) a model should predict based on the provided data or specified prediction dates. It infers the frequency of the data and computes the horizon accordingly, accepting several datetime formats and input types.
- Parameters:
data (pandas.DataFrame, pandas.Series, list, or numpy.ndarray, optional) – The dataset containing datetime information. If a pandas.DataFrame is provided, the dt_col parameter must be specified to indicate which column contains the datetime data. For pandas.Series, list, or numpy.ndarray, the function attempts to infer the frequency directly.
dt_col (str, optional) – The name of the column in data that contains datetime information. This parameter is required if data is a pandas.DataFrame. Example: dt_col='timestamp'
start_pred (str, int, or datetime-like) – The starting point for forecasting. This can be a date string (e.g., '2023-04-10'), a datetime object, or an integer representing a year (e.g., 2024). If an integer is provided, it is interpreted as a year, and a warning is issued to inform the user of this interpretation.
end_pred (str, int, or datetime-like) – The ending point for forecasting. Similar to start_pred, this can be a date string, a datetime object, or an integer representing a year. The function calculates the forecast horizon based on the difference between start_pred and end_pred.
error ({'raise', 'warn', 'ignore'}, default 'raise') – Defines the error handling behavior when encountering issues such as invalid input types, missing date columns, or unparseable dates.
'raise': Raises a ValueError when an error is encountered.
'warn': Emits a warning and attempts to proceed with default behavior.
verbose (int, default 1) – Controls the level of verbosity for debug information.
0: No output.
1: Minimal output (e.g., starting message).
2: Intermediate output (e.g., detected dates, computed horizons).
3: Detailed output (e.g., types of predictions, inferred frequencies).
- Returns:
The computed forecast_horizon representing the number of steps ahead the model should predict. Returns None if an error occurs and error is set to ‘warn’.
- Return type:
- Raises:
ValueError – If invalid parameters are provided and error is set to ‘raise’.
Examples
>>> from geoprior.models.utils import compute_forecast_horizon
>>> import pandas as pd
>>> import numpy as np
>>> from datetime import datetime, timedelta
>>>
>>> # Example 1: Using a DataFrame with a Date Column
>>> df = pd.DataFrame({
...     'date': pd.date_range(start='2023-01-01', periods=100, freq='D'),
...     'value': np.random.randn(100)
... })
>>> horizon = compute_forecast_horizon(
...     data=df,
...     dt_col='date',
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=3
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 2: Using a List of Datetimes
>>> dates = [datetime(2023, 1, 1) + timedelta(days=i) for i in range(100)]
>>> horizon = compute_forecast_horizon(
...     data=dates,
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='warn',
...     verbose=2
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>> # Example 3: Handling Integer Years
>>> horizon = compute_forecast_horizon(
...     start_pred=2024,
...     end_pred=2030,
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 7
>>> # Example 4: Without Providing Data (Assuming Frequency Based on Prediction Dates)
>>> horizon = compute_forecast_horizon(
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=1
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
Notes
When data is not provided, the function relies solely on the difference between start_pred and end_pred to compute the forecast horizon. In such cases, if the frequency cannot be inferred, the horizon is calculated based on the largest possible time unit (years, months, weeks, days).
If start_pred is after end_pred, the function returns 0 and issues a warning or raises an error based on the error parameter.
The function attempts to infer the frequency of the data using pandas utilities. If the frequency cannot be inferred, it defaults to calculating the horizon based on the time difference in the most significant unit.
See also
pandas.date_range: Generates a fixed frequency DatetimeIndex.
pandas.infer_freq: Infers the frequency of a DatetimeIndex.
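The horizons in the examples above are consistent with an inclusive count of steps between start_pred and end_pred. A minimal pandas sketch of that counting rule follows; horizon_between is a hypothetical helper, and it assumes the frequency is already known rather than inferred from data.

```python
import pandas as pd

def horizon_between(start_pred, end_pred, freq="D"):
    """Count inclusive steps from start to end at a fixed frequency.

    Hypothetical helper: integer inputs are treated as years, and a
    start after the end yields a horizon of 0.
    """
    if isinstance(start_pred, int) and isinstance(end_pred, int):
        return max(end_pred - start_pred + 1, 0)
    start, end = pd.Timestamp(start_pred), pd.Timestamp(end_pred)
    if start > end:
        return 0
    return len(pd.date_range(start=start, end=end, freq=freq))

daily_horizon = horizon_between("2023-04-10", "2023-04-20")
yearly_horizon = horizon_between(2024, 2030)
```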
- geoprior.models.utils._utils.prepare_spatial_future_data(final_processed_data, feature_columns, dynamic_feature_indices, sequence_length=1, dt_col='date', static_feature_names=None, forecast_horizon=None, future_years=None, encoded_cat_columns=None, scaling_params=None, spatial_cols=None, squeeze_last=False, verbosity=0)[source]
Prepare future static and dynamic inputs for making predictions.
This function prepares the necessary static and dynamic inputs required for forecasting future values in time series data. It processes the provided dataset by grouping it by location_id, extracting the last sequence of data points based on the specified sequence_length, and generating future inputs for prediction over the defined forecast_horizon.
The function handles both integer and datetime representations of the dt_col, extracting the year from datetime columns when necessary. It also allows for flexibility in specifying static features and encoded categorical variables.
(58)#\[\text{scaled\_time} = \frac{\text{future\_time} - \mu}{\sigma}\]
- Parameters:
final_processed_data (pandas.DataFrame) – The processed DataFrame containing all features and targets. Must include the location_id column and the specified dt_col.
feature_columns (List[str]) – List of feature column names to be used for dynamic input preparation.
dynamic_feature_indices (List[int]) – Indices of dynamic features in feature_columns. These features are considered time-dependent and are used to prepare dynamic inputs.
sequence_length (int, optional) – The number of past time steps to include in each input sequence. Default is 1.
dt_col (str, optional) – The name of the time-related column in final_processed_data. Defaults to 'date'.
static_feature_names (List[str], optional) – List of static feature column names. If not provided, defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
forecast_horizon (int, optional) – The number of future time steps to predict. If set to None, the function defaults to predicting the next immediate time step.
future_years (List[int], optional) – List of future years to predict. Its length must equal forecast_horizon if forecast_horizon is provided.
encoded_cat_columns (List[str], optional) – List of encoded categorical column names to be treated as static features.
scaling_params (Dict[str, Dict[str, float]], optional) – Dictionary containing scaling parameters (mean and standard deviation) for features. Example: {'year': {'mean': 2000, 'std': 10}}. If not provided, the function computes the mean and std for the dt_col.
squeeze_last (bool, default False) – If True, squeezes the last axis, which corresponds to the output dimension y, when its size is 1.
verbosity (int, optional) – Verbosity level from 0 to 7 for debugging and understanding the process. Higher values produce more detailed logs.
- Returns:
A tuple containing:
future_static_inputs (numpy.ndarray) – Array of future static inputs with shape (num_samples, num_static_vars, 1).
future_dynamic_inputs (numpy.ndarray) – Array of future dynamic inputs with shape (num_samples, sequence_length, num_dynamic_vars, 1).
future_years_list (List[int]) – List of future time values corresponding to each sample.
location_ids_list (List[int]) – List of location IDs corresponding to each sample.
longitudes (List[float]) – List of longitude values corresponding to each sample.
latitudes (List[float]) – List of latitude values corresponding to each sample.
- Return type:
Tuple[np.ndarray,np.ndarray,List[int],List[int],List[float],List[float]]
Examples
>>> from geoprior.models.utils import prepare_spatial_future_data
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'location_id': [1, 1, 1, 2, 2, 2],
...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
...     'longitude': [10.0, 10.0, 10.0, 20.0, 20.0, 20.0],
...     'latitude': [50.0, 50.0, 50.0, 60.0, 60.0, 60.0],
...     'temperature': [15, 16, 15.5, 20, 21, 20.5],
...     'rainfall': [100, 110, 105, 200, 210, 205],
...     'encoded_cat': [1, 1, 1, 2, 2, 2]
... })
>>> feature_cols = ['year', 'temperature', 'rainfall', 'encoded_cat']
>>> dynamic_indices = [0, 1, 2]
>>> (future_static, future_dynamic, future_years,
...  loc_ids, longs, lats) = prepare_spatial_future_data(
...     final_processed_data=data,
...     feature_columns=feature_cols,
...     dynamic_feature_indices=dynamic_indices,
...     sequence_length=2,
...     forecast_horizon=1,
...     future_years=[2021],
...     encoded_cat_columns=['encoded_cat'],
...     verbosity=5,
...     dt_col='year'
... )
>>> print(future_static.shape)
(2, 3, 1)
>>> print(future_dynamic.shape)
(2, 2, 3, 1)
Notes
The function handles both integer and datetime representations of the dt_col. If dt_col is a datetime type, the year is extracted for scaling purposes.
If forecast_horizon is set to None, the function defaults to generating data for the next immediate time step based on the last entry in the time column.
Ensure that the length of future_years matches forecast_horizon if forecast_horizon is provided.
The static_feature_names parameter allows for flexibility in specifying which static features to include. If not provided, it defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
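Two of the steps described for this function, taking the last sequence_length rows per location and scaling future time values as in equation (58), can be sketched with pandas. The variable names below are illustrative only, and the scaling parameters are the example values from the scaling_params description rather than values computed by the library.

```python
import pandas as pd

data = pd.DataFrame({
    "location_id": [1, 1, 1, 2, 2, 2],
    "year": [2018, 2019, 2020, 2018, 2019, 2020],
    "temperature": [15.0, 16.0, 15.5, 20.0, 21.0, 20.5],
})

# Keep the last `sequence_length` rows of each location, in time order.
sequence_length = 2
last_rows = (
    data.sort_values(["location_id", "year"])
        .groupby("location_id")
        .tail(sequence_length)
)

# Equation (58): scaled_time = (future_time - mu) / sigma.
scaling_params = {"year": {"mean": 2000, "std": 10}}
mu = scaling_params["year"]["mean"]
sigma = scaling_params["year"]["std"]
future_years = [2021]
scaled_future = [(year - mu) / sigma for year in future_years]
```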
See also
prepare_future_data: Main function for preparing future data inputs.
- geoprior.models.utils._utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]
Compute anomaly scores for given true targets using various methods.
This utility function, compute_anomaly_scores, provides a flexible approach to compute anomaly scores outside the XTFT model itself. Anomaly scores serve as indicators of how unusual certain observations are, guiding the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.
- Parameters:
y_true (np.ndarray) – The ground truth target values with shape (B, H, O), where:
- B: batch size
- H: number of forecast horizons (time steps ahead)
- O: output dimension (e.g., number of target variables).
Typically, y_true corresponds to the same array passed as the forecast target to the model. All anomaly computations are relative to these true values or, if provided, their predicted counterparts y_pred.
y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to 'residual', the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.
method (str, optional) – The method used to compute anomaly scores. Supported options:
"statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores. Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:
(59)#\[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\]
where \(\varepsilon\) is a small constant for numerical stability.
"domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.
"isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.
"residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:
(60)#\[\text{MSE: }(y_{true} - y_{pred})^2\]
Default is "statistical".
threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.
domain_func (callable, optional) – A user-defined function for the domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the default heuristic is used:
(61)#\[\begin{split}\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\end{split}\]
contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.
epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.
estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.
random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.
residual_metric (str, optional) – The metric used to compute anomalies from residuals if method is set to 'residual'. Supported metrics:
"mse": squared error per point, residuals**2
"mae": absolute error per point, |residuals|
"rmse": root squared error per point, sqrt(residuals**2)
Default is "mse".
objective (str, optional) – Specifies the type of objective, for future extensibility. Default is "ts", indicating time series. Could be extended for other tasks in the future.
verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.
- Returns:
anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.
- Return type:
np.ndarray
Notes
Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
Examples
>>> from geoprior.models.losses import compute_anomaly_scores
>>> import numpy as np
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B, H, O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain', domain_func=my_heuristic)
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest', contamination=0.1)
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual', residual_metric='mae')
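The formulas given for the statistical, domain, and residual methods can be reproduced directly in NumPy. The following is a minimal sketch of those formulas, not the library's implementation: the helper names are hypothetical, and it assumes the mean and standard deviation are taken over the whole array, which the description does not specify.

```python
import numpy as np

def statistical_scores(y_true, epsilon=1e-6):
    """Squared z-score per point (assumption: global mean and std over the whole array)."""
    mu = y_true.mean()
    sigma = y_true.std()
    return ((y_true - mu) / (sigma + epsilon)) ** 2

def default_domain_scores(y_true):
    """Default domain heuristic from the docs: |y| * 10 if y < 0, else 0."""
    return np.where(y_true < 0, np.abs(y_true) * 10.0, 0.0)

def residual_scores(y_true, y_pred, metric="mse"):
    """Per-point residual scores for the three documented metrics."""
    residuals = y_true - y_pred
    if metric == "mse":
        return residuals ** 2
    if metric == "mae":
        return np.abs(residuals)
    if metric == "rmse":
        return np.sqrt(residuals ** 2)
    raise ValueError(f"Unsupported residual metric: {metric!r}")

# Tiny (B=2, H=2, O=1) example.
y_true = np.array([[[1.0], [-2.0]], [[3.0], [0.0]]])
y_pred = y_true + 0.5

stat = statistical_scores(y_true)
dom = default_domain_scores(y_true)
res = residual_scores(y_true, y_pred, metric="mae")
```

All three helpers return arrays with the same shape as y_true, matching the documented return value. Note that per point, the "rmse" expression reduces to the absolute residual.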
See also
geoprior.models.losses.objective_loss
For integrating anomaly scores into a multi-objective loss.
- geoprior.models.utils._utils.generate_forecast(xtft_model, train_data, dt_col, dynamic_features, future_features=None, static_features=None, test_data=None, mode='quantile', spatial_cols=None, forecast_horizon=4, time_steps=3, q=None, tname=None, forecast_dt=None, savefile=None, verbose=3, **kw)[source]
Generate forecast using the XTFT model.
This function uses a pre-trained Keras model to forecast future values based on provided historical data. The model receives three inputs: X_static, X_dynamic, and X_future re-built from train_data, and outputs predictions over a specified forecast horizon.
See more in User Guide.
- Parameters:
xtft_model (object) – A validated Keras model instance. It is checked with validate_keras_model.
train_data (pandas.DataFrame) – The training data containing historical records. Must include the dt_col and all required feature columns.
dt_col (str) – Name of the column representing time. It may be a datetime or numeric column (e.g. "year").
dynamic_features (list of str) – List of dynamic feature column names. They are formatted via columns_manager.
future_features (list of str, optional) – List of future feature names. These columns are tiled over the forecast horizon.
static_features (list of str, optional) – List of static feature names. If not provided, a dummy input is used.
test_data (pandas.DataFrame, optional) – DataFrame containing actual values used for evaluation. If provided, it is used to compute the R² score and, for mode='quantile', the coverage score.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions for multiple quantiles (default: [0.1, 0.5, 0.9]) are computed.
spatial_cols (list of str, optional) – List of spatial column names for grouping the data. When provided, forecasts are computed per location; otherwise, a global forecast is performed.
forecast_horizon (int, optional) – Number of future periods to forecast. Default is 4.
time_steps (int, optional) – Number of past time steps to use as input for the model. Default is 3.
q (list of float, optional) – List of quantiles for use in quantile mode. Default is [0.1, 0.5, 0.9]. Each quantile is validated by the assert_ratio function.
tname (str, optional) – Target variable name used for constructing forecast result columns. Defaults to "target".
forecast_dt (list or str, optional) – List of forecast dates or "auto" to derive dates from dt_col. In auto mode, if dt_col is datetime, the frequency is inferred using pd.infer_freq.
savefile (str, optional) – Path to the CSV file where forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level (0-7). Controls the amount of execution output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, each forecast period includes columns for each quantile; in point mode, a single prediction column is provided.
- Return type:
pandas.DataFrame
Examples
Example using training data only
>>> import os >>> import pandas as pd >>> import numpy as np >>> from geoprior.models.transformers import XTFT >>> from geoprior.models.losses import combined_quantile_loss >>> from geoprior.models.utils import generate_forecast >>> >>> # Create a dummy training DataFrame with a date column, >>> # dynamic features "feat1", "feat2", static feature "stat1", >>> # and target "price". >>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D") >>> train_df = pd.DataFrame({ ... "date": date_rng, ... "feat1": np.random.rand(50), ... "feat2": np.random.rand(50), ... "stat1": np.random.rand(50), ... "price": np.random.rand(50) ... }) >>> >>> # Prepare a dummy XTFT model with example parameters. >>> # Note: The model expects the following input shapes: >>> # - X_static: (n_samples, static_input_dim) >>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim) >>> # - X_future: (n_samples, time_steps, future_input_dim) >>> my_model = XTFT( ... static_input_dim=1, # "stat1" ... dynamic_input_dim=2, # "feat1" and "feat2" ... future_input_dim=1, # features provided for dim1 ... forecast_horizon=5, # Forecasting 5 periods ahead ... quantiles=[0.1, 0.5, 0.9], ... embed_dim=16, ... max_window_size=3, ... memory_size=50, ... num_heads=2, ... dropout_rate=0.1, ... lstm_units=32, ... attention_units=32, ... hidden_units=16 ... ) >>> my_model.compile(optimizer="adam") >>> >>> # Create dummy input arrays for model fitting. >>> # For simplicity, assume time_steps = 3 and use random data. >>> X_static = train_df[["stat1"]].values # shape: (50, 1) >>> # Create a dummy dynamic input array of shape (50, 3, 2) >>> X_dynamic = np.random.rand(50, 3, 2) >>> # Create a dummy features >>> X_future = np.random.rand(50, 3, 1) >>> # Create dummy target output from "price" >>> y_array = train_df["price"].values.reshape(50, 1, 1) >>> >>> # Fit the model on the dummy data. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=1, ... batch_size=8 ... 
) >>> >>> # Generate forecast using the generate_forecast function. >>> forecast = generate_forecast( ... xtft_model=my_model, ... train_data=train_df, ... dt_col="date", ... dynamic_features=["feat1", "feat2"], ... static_features=["stat1"], ... forecast_horizon=5, ... time_steps=3, ... tname="price", ... mode="quantile", ... verbose=3 ... ) >>> print(forecast.head())
Example with test data included
>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values  # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,   # "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # For the provided future feature
...     forecast_horizon=5,   # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> # This example uses test_df for evaluation, which will compute
>>> # metrics like R² Score and Coverage Score.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     test_data=test_df.iloc[:5, :],  # first 5 rows, matching forecast_horizon
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
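When test data is supplied, the evaluation reports an R² score and, in quantile mode, a coverage score. The sketch below shows the textbook definitions of both metrics on a toy quantile forecast. It is an illustration under assumptions: the helper names are hypothetical, R² is computed on the median forecast, and coverage is taken as the fraction of actual values inside the [q10, q90] interval. The library's own coverage_score may differ in detail.

```python
import numpy as np

def r2_from_median(y_true, y_median):
    """Coefficient of determination of the median (q50) forecast."""
    ss_res = np.sum((y_true - y_median) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def interval_coverage(y_true, y_lower, y_upper):
    """Fraction of actual values that fall inside [y_lower, y_upper]."""
    inside = (y_true >= y_lower) & (y_true <= y_upper)
    return inside.mean()

# Toy data: five actual values and their q10 / q50 / q90 forecasts.
y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
q10 = np.array([0.5, 1.5, 3.5, 3.0, 4.0])
q50 = np.array([1.0, 2.0, 3.0, 4.0, 4.0])
q90 = np.array([1.5, 2.5, 4.5, 5.0, 6.0])

r2 = r2_from_median(y_true, q50)
coverage = interval_coverage(y_true, q10, q90)
print(f"R2 = {r2:.2f}, coverage = {coverage:.2f}")  # R2 = 0.90, coverage = 0.80
```

For a nominal 80% interval (q10 to q90), a coverage close to 0.8 suggests well-calibrated quantiles; much lower values indicate intervals that are too narrow.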
Example of Point forecasting
>>> # Create a dummy training DataFrame with a date column, >>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"), >>> # and target "price". >>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D") >>> train_df = pd.DataFrame({ ... "date": date_rng, ... "feat1": np.random.rand(50), ... "feat2": np.random.rand(50), ... "stat1": np.random.rand(50), ... "price": np.random.rand(50) ... }) >>> >>> # Create dummy input arrays for model fitting. >>> # X_static is derived from the static feature "stat1". >>> X_static = train_df[["stat1"]].values # shape: (50, 1) >>> >>> # X_dynamic is a dummy dynamic array for "feat1" and "feat2". >>> # For time_steps = 3, its shape is (50, 3, 2). >>> X_dynamic = np.random.rand(50, 3, 2) >>> >>> # X_future is a dummy array for future features. >>> # Here, we assume a single future feature with shape (50, 3, 1). >>> X_future = np.random.rand(50, 3, 1) >>> >>> # Create dummy target output from "price". >>> y_array = train_df["price"].values.reshape(50, 1, 1) >>> >>> # Instantiate a dummy XTFT model. >>> my_model = XTFT( ... static_input_dim=1, # "stat1" ... dynamic_input_dim=2, # "feat1" and "feat2" ... future_input_dim=1, # Provided future feature ... forecast_horizon=5, # Forecast 5 periods ahead ... quantiles=None, # [0.1, 0.5, 0.9] Not used in point mode ... embed_dim=16, ... max_window_size=3, ... memory_size=50, ... num_heads=2, ... dropout_rate=0.1, ... lstm_units=32, ... attention_units=32, ... hidden_units=16 ... ) >>> my_model.compile(optimizer="adam") >>> >>> # Fit the model on the dummy data. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=1, ... batch_size=8 ... ) >>> >>> # Generate forecast using the generate_forecast function in point mode. >>> forecast = generate_forecast( ... xtft_model=my_model, ... train_data=train_df, ... dt_col="date", ... dynamic_features=["feat1", "feat2"], ... static_features=["stat1"], ... forecast_horizon=5, ... time_steps=3, ... 
tname="price", ... mode="point", ... verbose=3 ... ) >>> print(forecast.head())
Notes
The function groups data by spatial_cols if provided, and formats features via columns_manager. It validates the time column using check_datetime and uses dummy inputs for missing static or future features. The forecast is produced by invoking xtft_model.predict on a list containing static, dynamic, and future inputs. The predictions are generated as follows:
(62) \[\hat{y}_{t+i} = f\Bigl(X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}}\Bigr)\]
where \(i\) denotes the forecast period.
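The forecast_dt="auto" behaviour described in the parameters (inferring the frequency of a datetime dt_col with pd.infer_freq) can be illustrated as follows. This is a hedged sketch, not the function's actual code: derive_forecast_dates is a hypothetical helper, and the numeric branch assumes integer periods with a unit step (e.g. years).

```python
import pandas as pd

def derive_forecast_dates(dt_values, forecast_horizon):
    """Return the next `forecast_horizon` periods after the last observed one."""
    series = pd.Series(dt_values)
    if pd.api.types.is_datetime64_any_dtype(series):
        index = pd.DatetimeIndex(series).sort_values().unique()
        freq = pd.infer_freq(index)  # needs at least 3 regularly spaced dates
        if freq is None:
            raise ValueError("Could not infer a regular frequency from dt_col.")
        # Start at the last observed date, then drop it to keep only future dates.
        future = pd.date_range(start=index[-1], periods=forecast_horizon + 1, freq=freq)
        return list(future[1:])
    # Assumption: numeric time columns hold integer periods with a step of 1.
    last = int(series.max())
    return [last + i for i in range(1, forecast_horizon + 1)]

daily = pd.date_range(start="2020-01-01", periods=50, freq="D")
future_days = derive_forecast_dates(daily, forecast_horizon=5)
future_years = derive_forecast_dates([2018, 2019, 2020], forecast_horizon=2)
```

With the 50-day training range used in the examples, the derived dates run from 2020-02-20 through 2020-02-24.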
See also
geoprior.models.utils.reshape_xtft_data
Function to reshape data for XTFT models.
geoprior.utils.validator.validate_keras_model
Function to validate Keras model compatibility.
geoprior.core.handlers.columns_manager
Utility to manage and format column names.
geoprior.core.checks.check_datetime
Function to check and validate datetime columns.
geoprior.core.checks.check_spatial_columns
Function to validate spatial columns in data.
geoprior.core.checks.assert_ratio
Function to validate and assert ratio values.
geoprior.metrics_special.coverage_score
Function to compute coverage score for quantile predictions.
- geoprior.models.utils._utils.generate_forecast_with(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kw)[source]
Generate forecasts using a pre-trained XTFT model based on the forecast horizon.
There are two approaches to generating forecasts with an XTFT model:
1. A monolithic function (e.g. generate_forecast) that handles both single-step and multi-step forecasts within a single implementation. This results in one large function that branches internally on the value of the forecast horizon.
2. A modular design in which the single-step and multi-step forecasting functionalities are separated into two distinct functions (forecast_single_step and forecast_multi_step), with a thin wrapper that dispatches to the appropriate function based on the forecast horizon.
The modular approach (2) is generally preferred because it separates concerns and improves code readability, maintainability, and unit testing. Each function is responsible for a well-defined task: one for single-step forecasts and one for multi-step forecasts. This function, generate_forecast_with, plays the role of that wrapper (the description proposes the name generate_xtft_forecast for it): it calls forecast_single_step when forecast_horizon equals 1 and forecast_multi_step when forecast_horizon is greater than 1. Use this approach when your application may need to handle both short- and long-term forecasts, as it keeps the codebase modular and easier to debug.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. A value of 1 triggers a single-step forecast; values greater than 1 trigger a multi-step forecast.
y (numpy.ndarray, optional) – Actual target values for evaluation. If provided, evaluation metrics (e.g., R² Score, and in quantile mode, the coverage score) are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, the output DataFrame includes a column with these values.
mode (str, optional) – Forecast mode, either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data's X_static.
time_steps (int, optional) – The number of historical time steps used as input.
q (list of float, optional) – List of quantile values for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names (e.g., "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking (via mask_by_reference) to adjust predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output.
**kw (dict, optional) – Currently unused; reserved for future extension.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile and forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT >>> from geoprior.models.utils import generate_forecast_with >>> import numpy as np >>> >>> # Prepare a dummy XTFT model with example parameters. >>> my_model = XTFT( ... static_input_dim=10, ... dynamic_input_dim=5, ... future_input_dim=3, ... forecast_horizon=1, # This parameter will be updated in the ... # wrapper function based on forecast_horizon. ... quantiles=[0.1, 0.5, 0.9], ... embed_dim=32, ... max_window_size=3, ... memory_size=100, ... num_heads=4, ... dropout_rate=0.1, ... lstm_units=64, ... attention_units=64, ... hidden_units=32 ... ) >>> my_model.compile(optimizer='adam') >>> >>> # Create dummy input data. >>> X_static = np.random.rand(100, 10) >>> X_dynamic = np.random.rand(100, 3, 5) >>> X_future = np.random.rand(100, 3, 3) >>> y_array = np.random.rand(100, 1, 1) # For single-step target output. >>> inputs = [X_static, X_dynamic, X_future] >>> >>> # Fit the model with dummy data. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=1, ... batch_size=32 ... ) >>> >>> # Example for a single-step forecast: >>> forecast_df = generate_forecast_with( ... xtft_model=my_model, ... inputs=inputs, ... forecast_horizon=1, ... y=y_array, ... dt_col="year", ... mode="quantile", ... spatial_cols=["longitude", "latitude"], ... tname="subsidence", ... verbose=3 ... ) >>> print(forecast_df.head()) >>> >>> # Example for a multi-step forecast: >>> forecast_dates = ["2023", "2024", "2025", "2026"] >>> forecast_df = generate_forecast_with( ... xtft_model=my_model, ... inputs=inputs, ... forecast_horizon=4, ... y=y_array, ... dt_col="year", ... mode="point", ... spatial_cols=["longitude", "latitude"], ... tname="subsidence", ... forecast_dt=forecast_dates, ... verbose=3 ... ) >>> print(forecast_df.head())
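The dispatch rule described for this wrapper (single-step when forecast_horizon equals 1, multi-step otherwise) amounts to a few lines of control flow. The sketch below illustrates the pattern only: dispatch_by_horizon and the two stand-in functions are hypothetical, and the real forecast_single_step and forecast_multi_step take the model, inputs, and further keyword arguments.

```python
def dispatch_by_horizon(forecast_horizon, single_step_fn, multi_step_fn, **kwargs):
    """Call the single-step or multi-step function depending on the horizon."""
    if not isinstance(forecast_horizon, int) or forecast_horizon < 1:
        raise ValueError("forecast_horizon must be a positive integer.")
    if forecast_horizon == 1:
        # The single-step path does not need the horizon argument.
        return single_step_fn(**kwargs)
    return multi_step_fn(forecast_horizon=forecast_horizon, **kwargs)

# Stand-ins for the real forecasting functions (hypothetical, for illustration).
def fake_single_step(**kwargs):
    return "single"

def fake_multi_step(forecast_horizon, **kwargs):
    return f"multi:{forecast_horizon}"

one = dispatch_by_horizon(1, fake_single_step, fake_multi_step, tname="subsidence")
four = dispatch_by_horizon(4, fake_single_step, fake_multi_step, tname="subsidence")
```

Keeping the wrapper this thin means each forecasting path can be unit tested on its own, which is the main argument the description gives for the modular design.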
See also
forecast_single_step
Generates a single-step forecast.
forecast_multi_step
Generates a multi-step forecast.
validate_keras_model
Validates a Keras model.
coverage_score
Computes the coverage score.
- geoprior.models.utils._utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]
Generate a multi-step forecast using the XTFT model.
This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:
(63) \[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\]
for \(i = 1, \dots, \text{forecast\_horizon}\), where \(f\) is the trained XTFT model.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data's X_static.
time_steps (int, optional) – The number of historical time steps used as input. Default is 3.
q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.
- Returns:
A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT >>> from geoprior.models.utils import forecast_multi_step >>> from geoprior.models.losses import combined_quantile_loss >>> import pandas as pd >>> import numpy as np >>> >>> # Create a dummy training DataFrame with a date column, >>> # spatial features ("longitude", "latitude"), two dynamic >>> # features ("feat1", "feat2"), a static feature ("stat1"), and >>> # the target variable "subsidence". >>> date_rng = pd.date_range(start="2020-01-01", periods=60, ... freq="D") >>> train_df = pd.DataFrame({ ... "date": date_rng, ... "longitude": np.random.uniform(-180, 180, 60), ... "latitude": np.random.uniform(-90, 90, 60), ... "feat1": np.random.rand(60), ... "feat2": np.random.rand(60), ... "stat1": np.random.rand(60), ... "subsidence": np.random.rand(60) ... }) >>> >>> # Prepare dummy input arrays for model training. >>> # X_static is constructed using "longitude" and "stat1". >>> X_static = train_df[["longitude", "stat1"]].values >>> # X_dynamic for "feat1" and "feat2" with time_steps = 3. >>> X_dynamic = np.random.rand(60, 3, 2) >>> # X_future is a dummy future feature array with shape (60, 3, 1). >>> X_future = np.random.rand(60, 3, 1) >>> # Target output from "subsidence" reshaped to >>> # (60, 1, 1). For multi-step forecast, forecast_horizon is 4. >>> forecast_horizon = 4 >>> y_array = train_df["subsidence"].values.reshape(60, 1, 1) >>> >>> # Instantiate a dummy XTFT model. >>> my_model = XTFT( ... static_input_dim=2, # "longitude" and "stat1" ... dynamic_input_dim=2, # "feat1" and "feat2" ... future_input_dim=1, # One future feature ... forecast_horizon=forecast_horizon, ... quantiles=[0.1, 0.5, 0.9], ... embed_dim=16, ... max_window_size=3, ... memory_size=50, ... num_heads=2, ... dropout_rate=0.1, ... lstm_units=32, ... attention_units=32, ... hidden_units=16 ... ) >>> my_model.compile( ... optimizer="adam", ... loss=combined_quantile_loss(my_model.quantiles) ... 
) >>> >>> # Fit the model on the dummy data for demonstration. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=1, ... batch_size=8 ... ) >>> >>> # Generate forecast datetime values for the forecast horizon. >>> forecast_dates = pd.date_range(start="2020-02-01", ... periods=forecast_horizon, freq="D") >>> >>> # Package inputs as expected by forecast_multi_step. >>> inputs = [X_static, X_dynamic, X_future] >>> >>> # Generate a multi-step forecast in quantile mode. >>> forecast_df_quantile = forecast_multi_step( ... xtft_model=my_model, ... inputs=inputs, ... forecast_horizon=forecast_horizon, ... y=y_array, ... dt_col="date", ... mode="quantile", ... spatial_cols=["longitude", "latitude"], ... q=[0.1, 0.5, 0.9], ... tname="subsidence", ... forecast_dt=forecast_dates, ... apply_mask=False, ... verbose=3 ... ) >>> print("Quantile Forecast:") >>> print(forecast_df_quantile.head()) >>>
For point forecast
>>> # Instantiate a dummy XTFT model. >>> my_model = XTFT( ... static_input_dim=2, # "longitude" and "stat1" ... dynamic_input_dim=2, # "feat1" and "feat2" ... future_input_dim=1, # One future feature ... forecast_horizon=forecast_horizon, ... quantiles=None, # set quantiles to None ... embed_dim=16, ... max_window_size=3, ... memory_size=50, ... num_heads=2, ... dropout_rate=0.1, ... lstm_units=32, ... attention_units=32, ... hidden_units=16 ... ) >>> my_model.compile( ... optimizer="adam", loss="mse", ... ) >>> >>> # Fit the model on the dummy data for demonstration. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=1, ... batch_size=8 ... )
>>> # Generate a multi-step forecast in point mode. >>> forecast_df_point = forecast_multi_step( ... xtft_model=my_model, ... inputs=inputs, ... forecast_horizon=forecast_horizon, ... y=y_array, ... dt_col="date", ... mode="point", ... spatial_cols=["longitude", "latitude"], ... tname="subsidence", ... forecast_dt=forecast_dates, ... apply_mask=False, ... verbose=3 ... ) >>> print("Point Forecast:") >>> print(forecast_df_point.head())
Notes
- In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.
- In point mode, a single prediction is generated per forecast step.
- The output prediction array is expected to have the shape \((n, \text{forecast\_horizon}, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).
- The provided spatial_cols must correspond to the first two columns of the original training data's X_static.
- Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.
- The DataFrame is constructed by iterating over each sample and each forecast step.
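To make the expected prediction shape and the output column naming concrete, the sketch below converts an array of shape (n, forecast_horizon, m) into a wide DataFrame using the documented <tname>_q10_step1 and <tname>_pred_step1 naming patterns. It is an illustration under assumptions: predictions_to_frame is a hypothetical helper, it slices whole columns rather than iterating sample by sample, and it omits the spatial, date, and actual-value columns that the real function adds.

```python
import numpy as np
import pandas as pd

def predictions_to_frame(y_pred, tname="target", q=None):
    """Build one column per (quantile, step) or per step from an (n, H, m) array."""
    n_samples, horizon, n_outputs = y_pred.shape
    if q is not None and len(q) != n_outputs:
        raise ValueError("Number of quantiles must match the last axis of y_pred.")
    columns = {}
    for step in range(horizon):
        if q is None:
            # Point mode: a single output per step.
            columns[f"{tname}_pred_step{step + 1}"] = y_pred[:, step, 0]
        else:
            for j, quantile in enumerate(q):
                label = int(round(quantile * 100))  # 0.1 -> 10, 0.5 -> 50, ...
                columns[f"{tname}_q{label}_step{step + 1}"] = y_pred[:, step, j]
    return pd.DataFrame(columns)

# Two samples, four forecast steps, three quantiles.
y_pred = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
frame = predictions_to_frame(y_pred, tname="subsidence", q=[0.1, 0.5, 0.9])
point_frame = predictions_to_frame(y_pred[:, :, :1], tname="subsidence")
```

Here frame has 12 columns (3 quantiles times 4 steps) and one row per sample, while point_frame has one column per step.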
See also
forecast_single_step
Function for single-step forecasts.
coverage_score
Function to compute the coverage score.
validate_keras_model
Function to validate a Keras model.
assert_ratio
Function to verify quantile ratios.
- geoprior.models.utils._utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]
Generate a single-step forecast using the XTFT model.
This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:
(64) \[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\]
where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data's X_static.
q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.
apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # dummy spatial features ("longitude", "latitude"), and the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # forecast_single_step expects that, if spatial_cols is provided, the
>>> # first two columns of X_static are the spatial coordinates, so place
>>> # "longitude" and "latitude" first, followed by the static feature "stat1".
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values  # shape: (50, 3)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1).
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> # - X_static with shape (n_samples, static_input_dim)
>>> # - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=3,   # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=1,   # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",      # The time column name in the output
...     mode="quantile",    # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())
Notes
In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.
If spatial_cols is provided, it must name the first and second columns of the original training data’s X_static.
The function internally uses validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.
Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.
The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
See also
generate_forecast_multi_step – Function for multi-step forecasts.
coverage_score – Function to compute the coverage score.
validate_keras_model – Function to validate a Keras model.
assert_ratio – Function to validate quantile ratios.
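The \((n, 1, m)\) output convention and the quantile column naming described in the Notes can be illustrated with a small NumPy/pandas sketch. This is not the library's implementation: preds_to_frame is a hypothetical helper, the predictions are random numbers rather than XTFT outputs, and the coverage shown is a plain empirical fraction that is only assumed to resemble what coverage_score reports.

```python
import numpy as np
import pandas as pd


def preds_to_frame(preds, q, tname="target", y=None):
    """Hypothetical helper: flatten (n, 1, m) quantile predictions into columns."""
    preds = np.asarray(preds)
    if preds.ndim != 3 or preds.shape[1] != 1 or preds.shape[2] != len(q):
        raise ValueError("expected predictions of shape (n, 1, len(q))")
    flat = preds[:, 0, :]  # drop the single forecast step -> (n, m)
    cols = [f"{tname}_q{int(round(qi * 100))}" for qi in q]
    df = pd.DataFrame(flat, columns=cols)
    if y is not None:
        df[f"{tname}_actual"] = np.asarray(y).reshape(-1)
    return df


rng = np.random.default_rng(0)
q = [0.1, 0.5, 0.9]
# Sort along the last axis so that q10 <= q50 <= q90 for every sample.
preds = np.sort(rng.normal(size=(5, 1, 3)), axis=-1)
y = rng.normal(size=(5, 1, 1))

df = preds_to_frame(preds, q, tname="subsidence", y=y)

# Empirical coverage: share of actuals inside the [q10, q90] interval.
inside = (df["subsidence_actual"] >= df["subsidence_q10"]) & (
    df["subsidence_actual"] <= df["subsidence_q90"]
)
coverage = float(inside.mean())
print(list(df.columns), coverage)
```

The median column (subsidence_q50 here) is the one the Notes describe as being used for point-style evaluation such as R².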
- geoprior.models.utils._utils.step_to_long(df, tname=None, dt_col=None, spatial_cols=None, mode='quantile', quantiles=None, verbose=3, sort=True)[source]
Convert a multi-step forecast DataFrame from wide to long format.
This function transforms a DataFrame containing multi-step forecast predictions into a long-format DataFrame. In quantile mode, forecast columns such as subsidence_q10_step1, subsidence_q50_step1, etc. are consolidated into unified columns (e.g. subsidence_q10, subsidence_q50, etc.), while in point mode, a single prediction column (subsidence_pred) is generated. The transformation also carries over additional columns (e.g. spatial coordinates and time) from the original DataFrame.
- Parameters:
df (pandas.DataFrame) – The multi-step forecast DataFrame. Expected to contain forecast prediction columns (e.g. columns with _q or _pred_step in their names) along with other identifiers.
tname (str, optional) – The base name of the target variable (e.g. "subsidence"). If None, the function attempts to auto-detect the target name from the column names.
dt_col (str, optional) – The name of the time column to include in the final DataFrame. If not provided, time sorting is not performed.
spatial_cols (list of str, optional) – A list of spatial coordinate columns (e.g. ["longitude", "latitude"]) to be retained in the final output.
mode ({"quantile", "point"}, default "quantile") – The forecast mode. In "quantile" mode, multiple quantile forecast columns are merged into unified columns. In "point" mode, a single prediction column is produced.
quantiles (list of float, optional) – The quantile values for quantile mode (e.g. [0.1, 0.5, 0.9]). If not provided, defaults are used.
sort (bool, optional) – If True, sorts the final DataFrame by the column specified in dt_col (if present). Default is True.
verbose (int, optional) – Verbosity level for logging output. Higher values (e.g. 5 to 7) provide more detailed debug information.
- Returns:
A long-format DataFrame containing the retained spatial columns, the time column when dt_col is provided, and the merged forecast prediction columns. In quantile mode, the output contains unified columns such as subsidence_q10 and subsidence_q50. In point mode, it contains a single subsidence_pred column.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.utils import step_to_long
>>> # Given a DataFrame `forecast_df` with columns like:
>>> # ['longitude', 'latitude', 'year', 'subsidence_actual',
>>> #  'subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1',
>>> #  'subsidence_q10_step2', ...]
>>> long_df = step_to_long(
...     df=forecast_df,
...     tname="subsidence",
...     dt_col="year",
...     spatial_cols=["longitude", "latitude"],
...     mode="quantile",
...     quantiles=[0.1, 0.5, 0.9],
...     verbose=3,
...     sort=True
... )
>>> print(long_df.head())
Notes
Internally, this function calls:
check_forecast_mode() to validate the user-specified quantiles.
validate_consistency_q() and validate_quantiles() to ensure the supplied quantiles match those auto-detected from the DataFrame.
Depending on mode, either _step_to_long_q() for quantile mode or _step_to_long_pred() for point mode performs the conversion.
Mathematically, let \(X \in \mathbb{R}^{n \times m}\) represent the wide-format DataFrame, where each row corresponds to one sample and each forecast step is stored in separate columns. The function reshapes \(X\) into a long-format DataFrame \(Y \in \mathbb{R}^{(n \cdot s) \times p}\), where \(s\) is the forecast horizon and \(p\) is the number of output columns after merging forecast step values.
See also
_step_to_long_q – Converts multi-step quantile forecasts to long format.
_step_to_long_pred – Converts multi-step point forecasts to long format.
detect_digits – Extracts numeric values from strings for quantile detection.
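The wide-to-long reshaping described above, from \(n\) rows to \(n \cdot s\) rows, can be sketched with plain pandas. The helper wide_to_long below is hypothetical and covers only the quantile case with <tname>_qXX_stepK column names. It is not how step_to_long is implemented, and the forecast_step column it adds is an assumption made for illustration.

```python
import re
import pandas as pd


def wide_to_long(df, tname, id_cols):
    """Hypothetical sketch: stack `<tname>_qXX_stepK` columns into long format."""
    pattern = re.compile(rf"^{re.escape(tname)}_(q\d+)_step(\d+)$")
    by_step = {}  # step number -> {wide column name: unified column name}
    for col in df.columns:
        m = pattern.match(col)
        if m:
            by_step.setdefault(int(m.group(2)), {})[col] = f"{tname}_{m.group(1)}"
    parts = []
    for step in sorted(by_step):
        mapping = by_step[step]
        # Keep identifiers, rename this step's columns to the unified names.
        part = df[id_cols + list(mapping)].rename(columns=mapping)
        part.insert(len(id_cols), "forecast_step", step)
        parts.append(part)
    return pd.concat(parts, ignore_index=True)


wide = pd.DataFrame({
    "longitude": [113.1, 113.2],
    "latitude": [22.5, 22.6],
    "subsidence_q10_step1": [1.0, 2.0],
    "subsidence_q50_step1": [1.5, 2.5],
    "subsidence_q90_step1": [2.0, 3.0],
    "subsidence_q10_step2": [1.1, 2.1],
    "subsidence_q50_step2": [1.6, 2.6],
    "subsidence_q90_step2": [2.1, 3.1],
})
long_df = wide_to_long(wide, "subsidence", ["longitude", "latitude"])
print(long_df.shape)
```

With two samples and a horizon of two steps, the result has \(n \cdot s = 4\) rows, one per (sample, step) pair.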
- geoprior.models.utils._utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)[source]
Prepares a list of input tensors for a model’s call method.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is 'strict'.
- Parameters:
dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is 'strict', a dummy tensor with 0 static features will be created. Default is None.
future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is 'strict', a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.
model_type ({'strict', 'flexible'}, default 'strict') – Determines how None inputs for static and future features are handled:
'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.
'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.
forecast_horizon (int, optional) – The forecast horizon. Used only if model_type='strict' and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default is None.
verbose (int, default 0) – Verbosity level. If > 0, prints information about dummy tensor creation.
0: Silent.
1: Basic info on dummy creation.
2: More details on shapes.
- Returns:
A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type='flexible' and the original input was None). All returned tensors are cast to tf.float32.
- Return type:
List[Optional[tf.Tensor]]
- Raises:
ValueError – If dynamic_input is None; if dynamic_input is not at least 2D (needs batch dimension); if static_input (when provided) is not 2D; or if future_input (when provided) is not 3D.
TypeError – If inputs cannot be converted to TensorFlow tensors.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))
>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)
>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)
>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
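The strict and flexible packing rules can also be mirrored in NumPy, which makes the zero-feature dummy shapes easy to inspect without TensorFlow. The function pack_inputs below is a hypothetical analogue written from the parameter descriptions above. It is not the library's code, and it omits the dimensionality checks and verbose output that prepare_model_inputs documents.

```python
import numpy as np


def pack_inputs(dynamic, static=None, future=None, strict=True, horizon=None):
    """Hypothetical NumPy analogue of the strict/flexible packing rule."""
    if dynamic is None:
        raise ValueError("dynamic input is required")
    dynamic = np.asarray(dynamic, dtype=np.float32)
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]
    if static is None and strict:
        static = np.zeros((batch, 0), dtype=np.float32)  # zero static features
    if future is None and strict:
        span = past_steps + (horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)  # zero future features
    static = None if static is None else np.asarray(static, dtype=np.float32)
    future = None if future is None else np.asarray(future, dtype=np.float32)
    return [static, dynamic, future]


dyn = np.ones((2, 10, 4))

# Strict: missing paths become arrays with a feature dimension of 0.
s, d, f = pack_inputs(dyn, strict=True, horizon=3)
print(s.shape, d.shape, f.shape)

# Flexible: missing paths stay None.
s_flex, _, f_flex = pack_inputs(dyn, strict=False)
print(s_flex is None, f_flex is None)
```

Keeping a zero-width array instead of None lets a model that always expects a three-element list run unchanged, since concatenating a zero-feature array along the feature axis adds nothing.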
- geoprior.models.utils._utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]
Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.
- Parameters:
dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.
num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.
agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.
errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.
- Returns:
If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.
- Return type:
Union[List[Tuple[Any,]], Optional[Tuple[Any,]]]
- Raises:
TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).
ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).
RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
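The extraction and aggregation semantics can be sketched without TensorFlow by treating the dataset as any iterable of (x, y) tuples of arrays. The helper take_batches and the generator toy_batches are hypothetical stand-ins. The real function operates on a tf.data.Dataset, also aggregates dictionaries, and applies the errors policy, none of which is reproduced here.

```python
from itertools import islice

import numpy as np


def take_batches(batches, n=1, agg=False):
    """Hypothetical sketch over any iterable of (x, y) tuples of arrays."""
    if n in ("all", "*", "auto"):
        taken = list(batches)
    elif isinstance(n, int) and n >= 0:
        taken = list(islice(batches, n))
    else:
        raise ValueError(f"invalid batch count: {n!r}")
    if not agg:
        return taken
    if not taken:
        return None
    # Concatenate matching tuple positions along the batch axis.
    return tuple(np.concatenate(group, axis=0) for group in zip(*taken))


def toy_batches():
    """Yield three batches of four samples each."""
    for i in range(3):
        yield np.full((4, 2), i, dtype=np.float32), np.full((4, 1), i, dtype=np.float32)


as_list = take_batches(toy_batches(), n=2)
merged = take_batches(toy_batches(), n="all", agg=True)
empty = take_batches(toy_batches(), n=0, agg=True)
print(len(as_list), merged[0].shape, merged[1].shape, empty)
```

The two return shapes match the documented contract: a list of batch tuples when agg is False, and a single concatenated tuple (or None when nothing was extracted) when agg is True.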
- geoprior.models.utils._utils.format_predictions(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, _logger=None, **kwargs)[source]
Formats model predictions into a structured pandas DataFrame.
This utility function takes raw model predictions (either directly as an array/tensor or generated by a provided model and its inputs) and transforms them into a long-format pandas DataFrame. It can handle point forecasts, quantile forecasts, single or multi-output predictions, and optionally include actual target values, spatial identifiers, and perform coverage score evaluation for quantile forecasts.
The output DataFrame is structured with ‘sample_idx’ (identifying the original input sequence) and ‘forecast_step’ (from 1 to H, where H is the forecast horizon).
- Parameters:
predictions (np.ndarray or tf.Tensor, optional) – The raw prediction tensor or array.
- For point forecasts, expected shapes:
(num_samples, forecast_horizon, output_dim)
(num_samples, forecast_horizon) if output_dim=1 (will be reshaped)
(num_samples, output_dim) if forecast_horizon=1 (will be reshaped)
- For quantile forecasts, expected shapes:
(num_samples, forecast_horizon, num_quantiles * output_dim)
(num_samples, forecast_horizon, num_quantiles, output_dim)
(num_samples, num_quantiles * output_dim) if forecast_horizon=1
If None, model and inputs must be provided. Default is None.
model (tf.keras.Model, optional) – A trained Keras model to generate predictions if predictions is not provided. Used in conjunction with inputs. Default is None.
inputs (List[Optional[Union[np.ndarray, tf.Tensor]]], optional) – A list of input tensors (e.g., [static, dynamic, future]) required by the model to generate predictions. Required if predictions is None and model is provided. Default is None.
y_true_sequences (np.ndarray or tf.Tensor, optional) – The true target values corresponding to the predictions, used for including actuals in the output DataFrame and for evaluation. Expected shape: (num_samples, forecast_horizon, output_dim). Default is None.
target_name (str, optional) – Base name for the target variable. Used to prefix prediction and actual column names (e.g., “sales_pred”, “sales_q50”, “sales_actual”). Default is “target”.
quantiles (List[float], optional) – A list of quantiles that were predicted by the model (e.g., [0.1, 0.5, 0.9]). Required if the predictions are quantile forecasts. If provided, prediction columns will be named like {target_name}_q10, {target_name}_q50, etc. Default is None (for point forecasts).
forecast_horizon (int, optional) – The number of time steps into the future that the model predicts. If not provided, it’s inferred from predictions.shape[1] or y_true_sequences.shape[1]. Default is None.
output_dim (int, optional) – The number of target variables predicted at each time step (e.g., 1 for univariate, >1 for multivariate target). If not provided, it’s inferred from the shape of predictions or y_true_sequences. Default is None.
spatial_data_array (np.ndarray or tf.Tensor or pd.DataFrame or pd.Series, optional) – An array or DataFrame containing static spatial/identifier features for each of the num_samples sequences.
If NumPy/Tensor: Expected shape (num_samples, num_spatial_features). spatial_cols_indices must be provided.
If DataFrame/Series: Must have num_samples rows. spatial_cols must be provided.
These features will be repeated for each forecast step in the output DataFrame. Default is None.
spatial_cols (List[str], optional) – List of column names to select from spatial_data_array if it’s a DataFrame/Series, or names to assign to columns if spatial_data_array is NumPy/Tensor and spatial_cols_indices are provided. Default is None.
spatial_cols_indices (List[int], optional) – List of column indices to select from spatial_data_array if it’s a NumPy/Tensor. Length must match spatial_cols if provided. Default is None.
evaluate_coverage (bool, default False) – If True, and both quantiles (at least two) and y_true_sequences are provided, calculates the coverage score using the first and last quantiles as interval bounds. Requires geoprior.metrics.coverage_score.
scaler (Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method) used to scale the target variable and potentially other features. If provided along with scaler_feature_names and target_idx_in_scaler, predictions and actuals for the target will be inverse-transformed. Default is None.
scaler_feature_names (List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required if scaler is provided and targeted inverse transformation is needed. Default is None.
target_idx_in_scaler (int, optional) – The index of the target_name within the scaler_feature_names list. Required if scaler is provided and targeted inverse transformation is needed. Default is None.
verbose (int, default 0) – Verbosity level for logging during processing.
0: Silent.
1: Basic info.
3: More detailed steps.
5: Very detailed shape information.
**kwargs (Any) – Additional keyword arguments (currently not used but included for future extensibility).
- Returns:
A long-format DataFrame containing sample_idx and forecast_step, optional spatial columns, prediction columns, and actual-value columns when y_true_sequences is provided. Point forecasts use names like {target_name}_pred or {target_name}_{output_idx}_pred. Quantile forecasts use names like {target_name}_qXX or {target_name}_{output_idx}_qXX. Actual values use {target_name}_actual or {target_name}_{output_idx}_actual. Prediction and actual values are inverse-transformed when valid scaler information is provided.
- Return type:
pd.DataFrame
- Raises:
ValueError – If predictions is None and model or inputs is also None. If predictions shape is invalid (not 2D, 3D, or 4D). If quantiles are provided but prediction shape is incompatible for inferring output_dim. If spatial_data_array is provided without necessary name/index parameters.
TypeError – If predictions or other inputs cannot be converted to the expected tensor/array types.
See also
geoprior.models.utils.forecast_multi_step – Higher-level forecasting utility.
geoprior.metrics.coverage_score – For evaluating quantile forecast intervals.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import format_predictions
>>> B, H, O = 4, 3, 1  # Batch, Horizon, OutputDim
>>> Q = [0.1, 0.5, 0.9]
>>> preds_point = tf.random.normal((B, H, O))
>>> preds_quant = tf.random.normal((B, H, len(Q)))  # For O=1
>>> y_true = tf.random.normal((B, H, O))
>>> # Point forecast
>>> df_point = format_predictions(
...     predictions=preds_point, y_true_sequences=y_true,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> print(df_point.head(H))  # Show first sample's horizon
   sample_idx  forecast_step  value_pred  value_actual
0           0              1   -0.576731     -0.647362
1           0              2    0.183931      1.198977
2           0              3   -0.766871      0.534040
>>> # Quantile forecast
>>> df_quant = format_predictions(
...     predictions=preds_quant, y_true_sequences=y_true,
...     target_name="value", quantiles=Q,
...     forecast_horizon=H, output_dim=O
... )
>>> print(df_quant.head(H))
   sample_idx  forecast_step  value_q10  value_q50  value_q90  value_actual
0           0              1  -0.209947   0.263107  -0.308929     -0.647362
1           0              2   0.303091   0.594701  -0.225007      1.198977
2           0              3   0.136699  -1.237739   0.002834      0.534040
>>> # With spatial data (NumPy array)
>>> spatial_np = np.array([[101, 201], [102, 202], [103, 203], [104, 204]])
>>> df_spatial = format_predictions(
...     predictions=preds_point,
...     spatial_data_array=spatial_np,
...     spatial_cols=['store_id', 'region_id'],
...     spatial_cols_indices=[0, 1]
... )
>>> print(df_spatial[['sample_idx', 'forecast_step', 'store_id']].head(H))
   sample_idx  forecast_step  store_id
0           0              1     101.0
1           0              2     101.0
2           0              3     101.0
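The long-format layout, with sample_idx repeated per step and forecast_step running from 1 to H, follows directly from flattening the first two axes of the prediction array. The sketch below assumes the (num_samples, forecast_horizon, num_quantiles) layout with output_dim=1. The helper to_long_frame is hypothetical and leaves out actuals, spatial columns, and inverse scaling.

```python
import numpy as np
import pandas as pd


def to_long_frame(preds, quantiles, target_name="target"):
    """Hypothetical sketch for (B, H, Q) quantile predictions with output_dim=1."""
    preds = np.asarray(preds)
    n_samples, horizon, n_q = preds.shape
    if n_q != len(quantiles):
        raise ValueError("last axis must match the number of quantiles")
    data = {
        # Each sample index is repeated once per forecast step.
        "sample_idx": np.repeat(np.arange(n_samples), horizon),
        # Steps run from 1 to H within every sample.
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    }
    for j, q in enumerate(quantiles):
        # Row-major flattening keeps (sample, step) order aligned with the index columns.
        data[f"{target_name}_q{int(round(q * 100))}"] = preds[:, :, j].reshape(-1)
    return pd.DataFrame(data)


preds = np.arange(4 * 3 * 3, dtype=float).reshape(4, 3, 3)
df = to_long_frame(preds, [0.1, 0.5, 0.9], target_name="value")
print(df.head(3))
```

Using np.repeat for the sample index and np.tile for the step index is what keeps each row's identifiers aligned with the flattened prediction values.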
- geoprior.models.utils._utils.export_keras_losses(history, keys=None, savefile=None, verbose=0, formats=('json', 'csv'))[source]
Export loss(es) (and any other metric) from a Keras History object.
- Parameters:
history (History) – The History object returned by model.fit().
keys (list[str] | None) – List of history.history keys to export, e.g. [“loss”, “val_loss”, “physics_loss”]. If None, defaults to all keys ending with “loss”.
savefile (str | None) – Path (optionally with extension) where the output is written. If the extension is .json, only JSON is written. If the extension is .csv, only CSV is written. If no extension is given, all formats in formats are written using savefile as the base name.
verbose (int) – If >0, prints status messages.
formats (tuple[str, ...]) – File formats to write when savefile has no extension. Valid entries are “json” and “csv”.
- Returns:
result – Dictionary containing "epochs_run" and one list entry per exported history key.
- Return type:
dict
- geoprior.models.utils._utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)[source]
Safely retrieves the first available tensor from a dictionary using a list of possible keys.
This utility is intended for handling model inputs within a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of Tensors (e.g., tensor_a or tensor_b), which causes runtime errors, by explicitly checking each candidate with is not None.
- Parameters:
inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).
*tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.
default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.
check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.
auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.
- Returns:
The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.
- Return type:
Optional[tf.Tensor]
- Raises:
TypeError – If inputs is not a dictionary.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils._utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.])}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
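The reason for the explicit is not None check can be shown with NumPy arrays standing in for tensors. This is an assumption made to keep the sketch free of TensorFlow: multi-element arrays also refuse truth-value evaluation, although the exact exception a symbolic Tensor raises differs. The helper first_present is hypothetical and omits the check_type and auto_convert behaviour.

```python
import numpy as np


def first_present(inputs, *names, default=None):
    """Hypothetical sketch: return the first value under `names` that is not None."""
    if not isinstance(inputs, dict):
        raise TypeError("inputs must be a dictionary")
    for name in names:
        value = inputs.get(name)
        if value is not None:  # explicit check; never rely on truthiness
            return value
    return default


inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}

# Using `or` calls bool() on its left operand. For a multi-element array
# that evaluation is ambiguous and raises ValueError.
try:
    _ = inputs["soil_thickness"] or inputs["H_field"]
    ambiguous = False
except ValueError:
    ambiguous = True

found = first_present(inputs, "H_field", "soil_thickness")
missing = first_present(inputs, "missing_key", "another_key")
print(ambiguous, found, missing)
```

Keys are tried in priority order, so a key that is present but holds None is skipped in favour of the next candidate.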
This module contains many of the reusable helpers for sequence forecasting and model-side formatting. Its own module docstring frames the utilities in the context of Temporal Fusion Transformer-style preprocessing and multi-horizon forecasting, which gives the page valuable methodological context.
PINN/model physics helper module#
Physics-Informed Neural Network (PINN) Utility functions.
- geoprior.models.utils.pinn.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)[source]
Normalize and validate PDE mode selection.
- Parameters:
pde_mode (str, sequence of str, or None) – Requested PDE mode(s).
Accepted canonical values are: "none", "consolidation", "gw_flow", "both".
Accepted aliases: None, "off" -> "none"; "on" -> "both".
enforce_consolidation (bool, default False) – If True, any resolved mode other than exactly ["consolidation"] is coerced to ["consolidation"] and a warning is emitted. This includes ["none"], ["gw_flow"], and ["consolidation", "gw_flow"].
pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.
solo_return (bool, default False) – If False, return a canonical list of active modes. If True, return a single canonical label: "none", "consolidation", "gw_flow", or "both".
- Returns:
Canonical PDE mode(s), either as a list or a single label.
- Return type:
list of str or str
- Raises:
TypeError – If the input type is invalid.
ValueError – If a token is unsupported or the mode selection is ambiguous.
Examples
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
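The documented normalization rules are compact enough to restate as a pure-Python sketch. The function normalize_pde_modes below is a hypothetical re-implementation written only from the parameter descriptions and examples above. It leaves out pde_mode_config and the warning, and its ambiguity rule (rejecting "none" combined with an active mode) is an assumption about what the documented ValueError covers.

```python
CANONICAL = ("consolidation", "gw_flow")
ALIASES = {None: "none", "off": "none", "on": "both"}


def normalize_pde_modes(pde_mode, enforce_consolidation=False, solo_return=False):
    """Hypothetical re-implementation of the documented normalization rules."""
    tokens = [pde_mode] if pde_mode is None or isinstance(pde_mode, str) else list(pde_mode)
    active = set()
    saw_none = False
    for tok in tokens:
        tok = ALIASES.get(tok, tok)  # maps None to "none"
        if not isinstance(tok, str):
            raise TypeError(f"unsupported PDE mode type: {type(tok).__name__}")
        tok = tok.strip().lower()
        tok = ALIASES.get(tok, tok)  # resolve "off"/"on" after lowercasing
        if tok == "none":
            saw_none = True
        elif tok == "both":
            active.update(CANONICAL)
        elif tok in CANONICAL:
            active.add(tok)
        else:
            raise ValueError(f"unsupported PDE mode: {tok!r}")
    if saw_none and active:
        # Assumed ambiguity rule; the real function may define this differently.
        raise ValueError("ambiguous selection: 'none' combined with active modes")
    modes = [m for m in CANONICAL if m in active] or ["none"]
    if enforce_consolidation and modes != ["consolidation"]:
        modes = ["consolidation"]  # the documented function also emits a warning here
    if solo_return:
        return "both" if len(modes) == 2 else modes[0]
    return modes


print(normalize_pde_modes("on"), normalize_pde_modes("both", solo_return=True))
```

Returning the list in a fixed canonical order means that downstream comparisons such as modes == ["consolidation"] do not depend on the order in which the caller listed the modes.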
- geoprior.models.utils.pinn.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)[source]
Formats PINN model predictions into a structured pandas DataFrame.
This is a general-purpose utility for transforming raw model outputs (from models like PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export. It handles multi-target outputs (e.g., subsidence and GWL), point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse-scaling of predictions and evaluation of quantile coverage.
- Parameters:
predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model’s .predict() method. Keys should match the model’s output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.
model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.
model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model’s signature, required only if predictions is None. Default is None.
y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, a <target>_actual column will be added to the output DataFrame for comparison. Default is None.
target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.
include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.
include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.
quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is required for correctly parsing probabilistic forecasts. Default is None.
forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.
output_dims (dict of str, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it’s inferred from the tensor shapes. Default is None.
ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.
ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.
ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.
scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., ‘subsidence’) and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.
coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.
evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.
coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. Default is (0, -1), which corresponds to the full range.
savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.
name (str or None) – Name of the prediction, used to format the output data and the coverage result, if applicable.
model_name (str or None) – Name of the model.
verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).
**kwargs (dict) – Additional keyword arguments for future extensions.
- Returns:
A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.
- Return type:
pd.DataFrame
Notes
The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.
For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.
For point forecasts, the column is named <target_name>_pred.
See also
geoprior.plot.forecast.plot_forecasts: A utility for visualizing the DataFrame produced by this function.
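The column-naming pattern and the interval coverage described above can be illustrated with a small NumPy sketch. It does not call GeoPrior; `quantile_column_names` and `interval_coverage` are hypothetical helpers written for this illustration, and rounding the quantile to the nearest integer percent is an assumption.

```python
import numpy as np


def quantile_column_names(target_name, quantiles):
    """Build names following the <target_name>_q<quantile*100> pattern."""
    # Assumption: the percentage is rounded to the nearest integer.
    return [f"{target_name}_q{int(round(q * 100))}" for q in quantiles]


def interval_coverage(y_true, y_lower, y_upper):
    """Fraction of actual values that fall inside [y_lower, y_upper]."""
    y_true = np.asarray(y_true, dtype=float)
    inside = (y_true >= np.asarray(y_lower)) & (y_true <= np.asarray(y_upper))
    return float(inside.mean())


quantiles = [0.05, 0.5, 0.95]
names = quantile_column_names("subsidence", quantiles)
print(names)

# Toy predictions with shape (n_rows, n_quantiles), sorted by quantile.
preds = np.array([
    [1.0, 2.0, 3.0],
    [0.5, 1.0, 1.5],
    [2.0, 2.5, 3.0],
    [0.0, 1.0, 2.0],
])
actual = np.array([2.2, 2.0, 2.1, 1.9])

# Mirrors the default coverage_quantile_indices=(0, -1): outermost quantiles.
lo_idx, hi_idx = 0, -1
coverage = interval_coverage(actual, preds[:, lo_idx], preds[:, hi_idx])
print(coverage)
```

Three of the four toy actuals fall inside their outer interval, so the sketch reports a coverage of 0.75.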
- geoprior.models.utils.pinn.format_pihalnet_predictions(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, model_name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)[source]
Formats PIHALNet/GeoPriorSubsNet predictions into a structured pandas DataFrame, handling inversion, quantiles, and coordinates.
This function is the core formatter. It:
1. Gets model outputs (or uses provided ones).
2. Unpacks ‘data_final’ if model_name is ‘geoprior’.
3. Inverse-transforms all prediction and actual arrays using scaler_info.
4. Builds a long-format DataFrame with sample_idx and forecast_step.
5. Appends inverted quantile/point predictions.
6. Appends inverted actual values.
7. Appends inverted coordinates.
8. Appends static/ID columns.
9. Evaluates coverage on the inverted data.
- Parameters:
pihalnet_outputs (
dict, optional) – Raw output from model.predict(). If None, model and model_inputs must be provided.model (
tf.keras.Model, optional) – Trained model instance (if pihalnet_outputs is None).model_inputs (
dict, optional) – Inputs for the model to generate predictions (if pihalnet_outputs is None).y_true_dict (
dict, optional) – Dictionary of true target arrays (e.g., {‘subs_pred’: y_true_s}). Required for including actuals and evaluating coverage.target_mapping (
dict, optional) – Maps prediction keys to base names for DataFrame columns. Default: {‘subs_pred’: ‘subsidence’, ‘gwl_pred’: ‘gwl’}.include_gwl (
bool, defaultTrue) – Whether to include ‘gwl_pred’ in the final DataFrame.include_coords (
bool, defaultTrue) – Whether to include ‘coord_t’, ‘coord_x’, ‘coord_y’ columns.quantiles (
list[float], optional) – List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided, quantile columns (e.g., ‘subsidence_q10’) are created.forecast_horizon (
int, optional) – The forecast horizon length (H). If not provided, it’s inferred from the prediction array’s shape.output_dims (
dict, optional) – Maps prediction keys to their output dimension (O). E.g., {‘subs_pred’: 1, ‘gwl_pred’: 1}. Crucial for correctly splitting GeoPrior outputs and reshaping.ids_data_array (
np.ndarrayorpd.DataFrame, optional) – Static/ID data (e.g., original coordinates) to merge. Must have the same number of samples (B) as predictions.ids_cols (
list[str], optional) – Column names if ids_data_array is a DataFrame.ids_cols_indices (
list[int], optional) – Column indices if ids_data_array is a NumPy array.scaler_info (
dict, optional) – Dictionary for inverse scaling. Each target entry should provide a fitted scaler, the target index inside that scaler, and the feature-name ordering used when the scaler was fit.coord_scaler (
sklearn.preprocessing.Scaler, optional) – A fitted scaler object for inverse transforming the ‘coords’ tensor.evaluate_coverage (
bool, defaultFalse) – If True, calculates coverage percentage for quantiles.coverage_quantile_indices (
tuple[int,int], default(0,-1)) – Indices of the low and high quantiles in the quantiles list to use for coverage (e.g., 0 and -1 for 10th and 90th).savefile (
str, optional) – If provided, saves the final DataFrame to this path.model_name (
str, optional) – Specifies the model type. If ‘geoprior’ or ‘geopriorsubsnet’, triggers unpacking of the ‘data_final’ output.apply_mask (
bool, defaultFalse) – If True, masks predictions based on mask_values in the first target’s _actual column.mask_values (
floatorint, optional) – The value in the _actual column to trigger masking.mask_fill_value (
float, optional) – The value to replace masked predictions with (e.g., np.nan).verbose (
int, default0) – Logging verbosity._logger (
logging.Loggerorcallable, optional) – Logger object.stop_check (
callable, optional) – Function to check for early stopping.name (str | None)
- Returns:
A long-format DataFrame with predictions, actuals, and coordinates.
- Return type:
pd.DataFrame
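The per-target inverse scaling that scaler_info enables can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: `ToyMinMaxScaler` is a hypothetical stand-in for a fitted scikit-learn-like scaler, `inverse_scale_target` is a hypothetical helper, and the dictionary keys follow the `{'scaler', 'all_features', 'idx'}` layout described in this reference.

```python
import numpy as np


class ToyMinMaxScaler:
    """Hypothetical stand-in exposing a scikit-learn-like inverse_transform."""

    def __init__(self, data_min, data_max):
        self.data_min_ = np.asarray(data_min, dtype=float)
        self.data_max_ = np.asarray(data_max, dtype=float)

    def inverse_transform(self, X):
        # Undo a [0, 1] min-max scaling, column by column.
        return X * (self.data_max_ - self.data_min_) + self.data_min_


def inverse_scale_target(values, info):
    """Invert one target using a scaler that was fitted on several features."""
    flat = np.asarray(values, dtype=float).reshape(-1)
    # Put the scaled target in its own column of a dummy matrix so the
    # multi-feature scaler can be applied, then read that column back.
    dummy = np.zeros((flat.size, len(info["all_features"])))
    dummy[:, info["idx"]] = flat
    restored = info["scaler"].inverse_transform(dummy)[:, info["idx"]]
    return restored.reshape(np.shape(values))


scaler_info = {
    "subsidence": {
        "scaler": ToyMinMaxScaler(data_min=[0.0, -10.0], data_max=[5.0, 10.0]),
        "all_features": ["rainfall", "subsidence"],
        "idx": 1,
    }
}

scaled = np.array([[0.0, 0.5], [1.0, 0.25]])  # (samples, horizon)
restored = inverse_scale_target(scaled, scaler_info["subsidence"])
print(restored)
```

The dummy-matrix step is what makes `all_features` and `idx` necessary: a scaler fitted on several columns cannot invert a single column in isolation.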
- geoprior.models.utils.pinn.format_preds(pihalnet_outputs=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, name=None, apply_mask=False, mask_values=None, mask_fill_value=None, verbose=0, _logger=None, stop_check=None, **kwargs)[source]
Main function orchestrating all helper steps.
- Parameters:
model (Model | None)
include_gwl (bool)
include_coords (bool)
forecast_horizon (int | None)
ids_data_array (Any | None)
coord_scaler (Any | None)
evaluate_coverage (bool)
savefile (str | None)
name (str | None)
apply_mask (bool)
mask_values (Any | None)
mask_fill_value (float | None)
verbose (int)
_logger (Any)
- Return type:
- geoprior.models.utils.pinn.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)[source]
- Parameters:
df (DataFrame)
time_col (str)
subsidence_col (str)
gwl_col (str)
h_field_col (str | None)
lon_col (str | None)
lat_col (str | None)
time_steps (int)
forecast_horizon (int)
output_subsidence_dim (int)
output_gwl_dim (int)
datetime_format (str | None)
normalize_coords (bool)
lock_physics_cols (bool)
protect_si_suffix (str)
return_coord_scaler (bool)
coord_scaler (MinMaxScaler | None)
fit_coord_scaler (bool)
mode (str | None)
model (str | None)
savefile (str | None)
verbose (int)
- Return type:
tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
- geoprior.models.utils.pinn.check_and_rename_keys(inputs, y)[source]
Helper function to check and rename keys in the inputs and target dictionaries.
This function ensures that the necessary keys are present in both the inputs and y dictionaries. If the keys for ‘subsidence’ or ‘gwl’ are not found, it attempts to rename them from possible alternatives like ‘subs_pred’ or ‘gwl_pred’.
- Parameters:
- Raises:
ValueError – If required keys are missing in inputs or y, or if renaming does not result in valid keys for ‘subsidence’ and ‘gwl’.
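The renaming behaviour described above can be sketched in plain Python. `canonicalize_target_keys` is a hypothetical helper written for this illustration, and the alias table is an assumption based on the alternatives named in the description (‘subs_pred’ and ‘gwl_pred’).

```python
def canonicalize_target_keys(y, aliases=None):
    """Return a copy of y whose keys use the canonical target names."""
    # Assumed alias table: alternative name -> canonical name.
    aliases = aliases or {"subs_pred": "subsidence", "gwl_pred": "gwl"}
    out = dict(y)
    for alias, canonical in aliases.items():
        # Rename only when the canonical key is absent and the alias exists.
        if canonical not in out and alias in out:
            out[canonical] = out.pop(alias)
    missing = [key for key in ("subsidence", "gwl") if key not in out]
    if missing:
        raise ValueError(f"Missing required target keys: {missing}")
    return out


targets = canonicalize_target_keys({"subs_pred": [0.1, 0.2], "gwl": [1.0, 1.1]})
print(sorted(targets))
```

Working on a copy keeps the caller's dictionary unchanged, which is a design choice of this sketch rather than a documented property of the library function.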
- geoprior.models.utils.pinn.check_required_input_keys(inputs, y=None, message=None, model_name=None, do_rename=True)[source]
Validate presence of required keys in inputs and y. Optionally canonicalize keys via reverse alias mapping.
This function ensures that the necessary keys are present in both the inputs and y dictionaries. If the keys for ‘subsidence’ or ‘gwl’ are not found, it attempts to rename them from possible alternatives like ‘subs_pred’ or ‘gwl_pred’.
- Parameters:
inputs (
dict) – A dictionary containing the input data. The keys ‘coords’ and ‘dynamic_features’ are expected.y (
dict) – A dictionary containing the target values. The keys ‘subsidence’ and ‘gwl’ are expected, but they could also appear as ‘subs_pred’ or ‘gwl_pred’.message (
str, optional) – Error message to raise when inputs or y is not a dictionary.model_name (str | None)
do_rename (bool)
- Raises:
ValueError – If required keys are missing in inputs or y, or if renaming does not result in valid keys for ‘subsidence’ and ‘gwl’.
- Return type:
- geoprior.models.utils.pinn.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.
- Parameters:
inputs (
tf.Tensor,np.ndarray, ordict) – The input data containing coordinates. A single tensor or array may be 2D with shape(batch, 3)or 3D with shape(batch, time_steps, 3). A dictionary may contain a'coords'key with the coordinate tensor, or separate't','x', and'y'keys.coord_slice_map (
dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.expect_dim (
{'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension.'2d'requires input shaped like(batch, 3)or a dictionary of(batch, 1)tensors.'3d'requires input shaped like(batch, time, 3)or a dictionary of(batch, time, 1)tensors. IfNone, both are accepted and 2D inputs are expanded to 3D.verbose (
int, default0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).
- Return type:
Tuple[tf.Tensor,tf.Tensor,tf.Tensor]- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
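The coordinate-splitting behaviour described above (slice by index, expand 2D inputs to 3D, keep a singleton last dimension) can be sketched with NumPy. `split_txy` is a hypothetical helper for illustration only; it handles a single array rather than the dictionary and tensor inputs the library function accepts.

```python
import numpy as np


def split_txy(coords, coord_slice_map=None):
    """Split a coordinate array into t, x, y with shape (batch, time, 1)."""
    slice_map = coord_slice_map or {"t": 0, "x": 1, "y": 2}
    coords = np.asarray(coords)
    if coords.ndim == 2:
        # (batch, 3) -> (batch, 1, 3): add a time axis of length 1.
        coords = coords[:, np.newaxis, :]
    if coords.ndim != 3 or coords.shape[-1] < 3:
        raise ValueError(f"Unsupported coordinate shape: {coords.shape}")
    # Slicing with i:i+1 keeps the singleton last dimension.
    t, x, y = (
        coords[..., slice_map[key]:slice_map[key] + 1] for key in ("t", "x", "y")
    )
    return t, x, y


t2, x2, y2 = split_txy(np.zeros((4, 3)))     # 2D input
t3, x3, y3 = split_txy(np.zeros((4, 6, 3)))  # 3D input
print(t2.shape, t3.shape)
```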
- geoprior.models.utils.pinn.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.
- Parameters:
inputs (
tf.Tensor,np.ndarray, ordict) – The input data containing coordinates. Can be a single tensor or a dictionary with ‘coords’ or ‘t’, ‘x’, ‘y’ keys.coord_slice_map (
dict, optional) – Mapping for ‘t’, ‘x’, ‘y’ to their index in the last dimension of a coordinate tensor. Defaults to {‘t’: 0, ‘x’: 1, ‘y’: 2}.expect_dim (
{'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input’s dimension.'2d'requires input shaped like(batch, 3).'3d'accepts 3D input and expands 2D input to 3D with a time dimension of 1.'3d_only'requires 3D input and raises an error for 2D input.Noneaccepts both 2D and 3D inputs without changing their rank.verbose (
int, default0) – Controls logging verbosity.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.
- Return type:
Tuple[tf.Tensor,tf.Tensor,tf.Tensor]- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
- geoprior.models.utils.pinn.plot_hydraulic_head(model, t_slice, x_bounds, y_bounds, resolution=100, ax=None, title=None, cmap='viridis', colorbar_label='Hydraulic Head (h)', save_path=None, show_plot=True, **contourf_kwargs)[source]
Generate and plot a 2D contour map of a hydraulic head solution.
This utility visualizes the output of a Physics-Informed Neural Network (PINN) that solves for the hydraulic head \(h(t, x, y)\). It automates the process of creating a spatial grid, running model predictions, and generating a publication-quality contour plot for a specific slice in time.
- Parameters:
model (
tf.keras.Model) – The trained PINN model. It is expected to have a.predict()method that accepts a dictionary of tensors with keys{'t', 'x', 'y'}.t_slice (
float) – The specific point in time \(t\) for which to plot the 2D spatial solution.x_bounds (
tupleoffloat) – A tuple(x_min, x_max)defining the spatial domain for the x-axis.y_bounds (
tupleoffloat) – A tuple(y_min, y_max)defining the spatial domain for the y-axis.resolution (
int, optional) – The number of points to sample along each spatial axis, creating a grid ofresolution x resolutionpoints for prediction. Higher values result in a smoother plot. Default is 100.ax (
matplotlib.axes.Axes, optional) – A pre-existing Matplotlib Axes object to plot on. IfNone, a new figure and axes are created internally. This is useful for embedding this plot within a larger figure arrangement. Default isNone.title (
str, optional) – A custom title for the plot. IfNone, a default title is generated using the value of t_slice. Default isNone.cmap (
str, optional) – The name of the Matplotlib colormap to use for the contour plot. Default is'viridis'.colorbar_label (
str, optional) – The text label for the color bar. Default is'Hydraulic Head (h)'.save_path (
str, optional) – If provided, the path (including filename and extension) where the generated plot will be saved. This is only active when the function creates its own figure (i.e., when ax isNone). Default isNone.show_plot (
bool, optional) – IfTrue, callsplt.show()to display the plot. This is only active when the function creates its own figure. Default isTrue.**contourf_kwargs (
any) – Additional keyword arguments that are passed directly to thematplotlib.pyplot.contourffunction. This allows for advanced customization (e.g.,levels=20,extend='both').
- Returns:
ax (
matplotlib.axes.Axes) – The Matplotlib Axes object on which the contour plot was drawn.contour (
matplotlib.cm.ScalarMappable) – The contour plot object, which can be used for further customizations, such as modifying the color bar.
- Return type:
See also
geoprior.models.pinn.PiTGWFlow: The PINN model this function is designed to visualize.
Notes
The core mechanism of this function involves creating a 2D meshgrid of \((x, y)\) coordinates. These grid points are then “flattened” into a long list of points, as the PINN model expects a batch of individual coordinates for prediction, not a grid.
The prediction process is as follows:
1. A grid of shape (resolution, resolution) is created for \(x\) and \(y\).
2. These grids are reshaped into column vectors of shape (resolution*resolution, 1).
3. A time vector of the same shape, filled with t_slice, is created.
4. The model’s .predict() method is called on these flat tensors.
5. The resulting flat prediction vector is reshaped back to the original (resolution, resolution) grid shape for plotting.
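The grid-flattening steps listed above can be sketched with NumPy alone. `head_on_grid` and `analytic_head` are hypothetical names used for this illustration; a plain Python callable stands in for the model's `.predict()` method, and no plotting is performed.

```python
import numpy as np


def head_on_grid(predict_fn, t_slice, x_bounds, y_bounds, resolution=100):
    """Evaluate predict_fn on a regular (x, y) grid at a fixed time."""
    x = np.linspace(x_bounds[0], x_bounds[1], resolution)
    y = np.linspace(y_bounds[0], y_bounds[1], resolution)
    X, Y = np.meshgrid(x, y)               # each (resolution, resolution)
    x_flat = X.reshape(-1, 1)              # (resolution*resolution, 1)
    y_flat = Y.reshape(-1, 1)
    t_flat = np.full_like(x_flat, t_slice)  # constant time column
    h_flat = predict_fn({"t": t_flat, "x": x_flat, "y": y_flat})
    # Restore the grid layout so the result can be passed to contourf.
    return X, Y, np.asarray(h_flat).reshape(X.shape)


def analytic_head(inputs):
    """Stand-in for a trained model: a simple closed-form field."""
    t, x, y = inputs["t"], inputs["x"], inputs["y"]
    return np.sin(np.pi * x) * np.cos(np.pi * y) * np.exp(-t)


X, Y, H = head_on_grid(analytic_head, 0.5, (-1, 1), (-1, 1), resolution=50)
print(H.shape)
```

The returned `X`, `Y`, and `H` arrays share the same grid shape, which is the layout Matplotlib's `contourf` expects.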
If a custom ax is provided, the user is responsible for calling plt.show() or saving the parent figure.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> from geoprior.models.utils.pinn import plot_hydraulic_head
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
...     def call(self, inputs):
...         # A simple analytical function for demonstration
...         t, x, y = inputs['t'], inputs['x'], inputs['y']
...         return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()
1. Simple Plotting Example
This example creates a single plot and saves it to a file.
>>> ax, contour = plot_hydraulic_head(
...     model=mock_model,
...     t_slice=0.5,
...     x_bounds=(-1, 1),
...     y_bounds=(-1, 1),
...     resolution=50,
...     save_path="hydraulic_head_t0.5.png",
...     show_plot=False  # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png
2. Advanced Example with Subplots
This example shows how to use the ax parameter to draw the solution at two different times side-by-side in one figure.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
...
>>> # Plot solution at t = 0.1
>>> plot_hydraulic_head(
...     model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
...
>>> # Plot solution at t = 1.0
>>> plot_hydraulic_head(
...     model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
...     y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
...
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()
This module is the entry point for PINN-specific utility logic, including PDE mode normalization, coordinate extraction, PINN prediction formatting, and related helpers used near the physics-informed model layer.
Selected model-facing helpers#
- geoprior.models.utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)[source]
Prepares a list of input tensors for a model’s call method.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in
geoprior. It handles cases where static or future inputs might beNone, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.- Parameters:
dynamic_input (
np.ndarrayortf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).static_input (
np.ndarrayortf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). IfNoneand model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default isNone.future_input (
np.ndarrayortf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). IfNoneand model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default isNone.model_type (
{'strict', 'flexible'}, default'strict') –Determines how
Noneinputs for static and future features are handled:'strict': If static_input or future_input isNone, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.'flexible': If static_input or future_input isNone,Noneitself will be placed in the corresponding position in the output list. This is for models that can internally handleNoneinputs for optional feature types.
forecast_horizon (
int, optional) – The forecast horizon. Used only if model_type=’strict’ and future_input isNone, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default isNone.verbose (
int, default0) –Verbosity level. If > 0, prints information about dummy tensor creation.
0: Silent.1: Basic info on dummy creation.2: More details on shapes.
- Returns:
A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or
None(if model_type=’flexible’ and original input wasNone). All returned tensors are cast to tf.float32.- Return type:
List[Optional[tf.Tensor]]- Raises:
ValueError – If dynamic_input is
None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.TypeError – If inputs cannot be converted to TensorFlow tensors.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))

>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)

>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)

>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
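The difference between the ‘strict’ and ‘flexible’ modes comes down to how missing inputs are represented. The following NumPy sketch illustrates the zero-feature dummy tensors described in the parameter list; `pack_inputs` is a hypothetical helper, it uses NumPy arrays instead of TensorFlow tensors, and it assumes a 3D dynamic input.

```python
import numpy as np


def pack_inputs(dynamic, static=None, future=None,
                model_type="strict", forecast_horizon=None):
    """Return [static, dynamic, future], filling gaps as the mode requires."""
    if dynamic is None:
        raise ValueError("dynamic input is required")
    dynamic = np.asarray(dynamic, dtype=np.float32)
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]
    if static is None and model_type == "strict":
        # Zero static features: shape (batch, 0).
        static = np.zeros((batch, 0), dtype=np.float32)
    if future is None and model_type == "strict":
        # Time span is past steps plus the horizon when one is given.
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)
    return [static, dynamic, future]


dyn = np.random.rand(2, 10, 4)
s_strict, d_strict, f_strict = pack_inputs(dyn, forecast_horizon=3)
s_flex, d_flex, f_flex = pack_inputs(dyn, model_type="flexible")
print(s_strict.shape, f_strict.shape, s_flex is None, f_flex is None)
```

A zero-width feature axis keeps the three-element input list and the batch and time dimensions intact, so a model that always indexes three inputs does not need a special case for absent features.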
- geoprior.models.utils.prepare_model_inputs_in(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0)[source]
Prepares a list of input tensors for a model’s call method in graph-compatible mode.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in
geoprior. It handles cases where static or future inputs might beNone, creating appropriate dummy tensors with zero features if the model_type is ‘strict’.- Parameters:
- dynamic_input
np.ndarrayortf.Tensor The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
- static_input
np.ndarrayortf.Tensor,optional The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If
Noneand model_type is ‘strict’, a dummy tensor with 0 static features will be created. Default isNone.- future_input
np.ndarrayortf.Tensor,optional The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If
Noneand model_type is ‘strict’, a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default isNone.- model_type{‘strict’, ‘flexible’},
default‘strict’ Determines how
Noneinputs for static and future features are handled:'strict': If static_input or future_input isNone, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.'flexible': If static_input or future_input isNone,Noneitself will be placed in the corresponding position in the output list. This is for models that can internally handleNoneinputs for optional feature types.
- forecast_horizon
int,optional The forecast horizon. Used only if model_type=’strict’ and future_input is
None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor’s time dimension will match dynamic_input’s past_time_steps. Default isNone.- verbose
int,default0 Verbosity level. If > 0, prints information about dummy tensor creation.
0: Silent.1: Basic info on dummy creation.2: More details on shapes.
- Returns:
List[Optional[tf.Tensor]]A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or
None(if model_type=’flexible’ and original input wasNone). All returned tensors are cast to tf.float32.
- Raises:
ValueErrorIf dynamic_input is
None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.TypeErrorIf inputs cannot be converted to TensorFlow tensors.
- Parameters:
- Return type:
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs_in
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))

>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs_in(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)

>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=fut_in,
...                                   model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)

>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs_in(dyn_in, static_input=None, future_input=None,
...                                   model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
- geoprior.models.utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)[source]
Create input sequences and corresponding targets for time series forecasting.
The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.
See more in User Guide.
- Parameters:
df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.
sequence_length (int) – The number of past time steps to include in each input sequence.
target_col (str) – The name of the target column.
step (int, default 1) – The step size between the starts of consecutive sequences.
include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.
drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.
forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.
verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).
- Returns:
- A tuple containing:
sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).
- targets:
If forecast_horizon is None: Array of target values with shape (num_sequences,).
If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If the DataFrame df does not contain the target_col.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences

>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })

>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)

>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)
Notes
Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.
- Forecast Horizon:
If forecast_horizon is None, the function creates targets for a single future time step.
If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.
Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.
Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.
Data Validation: The function uses are_all_frames_valid from geoprior.core.checks to check the integrity of the input DataFrame before processing, and exist_features to verify that the target column is present.
The sequences generation can be expressed as:
(65)#\[\begin{split}\text{For each sequence } i, \\ \mathbf{X}^{(i)} = \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right] \\ y^{(i)} = \begin{cases} \mathbf{x}_{i+T} & \text{if } \text{forecast\_horizon} = \text{None} \\ \left[ \mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1} \right] & \text{if } \text{forecast\_horizon} = H \end{cases}\end{split}\]
- Where:
\(\mathbf{X}^{(i)}\) is the input sequence of length \(T\).
\(y^{(i)}\) is the target value(s) following the sequence.
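The windowing formula above can be sketched directly with NumPy. `sliding_windows` is a hypothetical helper written for this illustration; it keeps only complete windows, so its window count follows from the formula and may differ from how the library handles edge windows, `include_overlap`, or `drop_last`.

```python
import numpy as np


def sliding_windows(values, target, T, H=None, step=1):
    """Build (X, y) pairs: T past rows, then 1 or H future target values."""
    values = np.asarray(values)
    target = np.asarray(target)
    horizon = 1 if H is None else H
    # Last start index that still leaves room for a full target.
    last_start = len(values) - T - horizon
    X, y = [], []
    for i in range(0, last_start + 1, step):
        X.append(values[i:i + T])
        future = target[i + T:i + T + horizon]
        # Scalar target when no horizon is given, vector otherwise.
        y.append(future[0] if H is None else future)
    return np.array(X), np.array(y)


values = np.arange(20).reshape(10, 2)  # 10 time steps, 2 features
target = values[:, 1]                  # second column acts as the target

X1, y1 = sliding_windows(values, target, T=4)
X3, y3 = sliding_windows(values, target, T=4, H=3)
print(X1.shape, y1.shape, X3.shape, y3.shape)
```

With 10 rows and `T=4`, this sketch yields 6 single-step windows and 4 three-step windows.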
See also
geoprior.models.utils.split_static_dynamic: Function to split sequences into static and dynamic inputs.
- geoprior.models.utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)[source]
Split sequences into static and dynamic inputs for the model.
The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.
(66)#\[\begin{split}\text{Static Inputs} = \mathbf{S} = \mathbf{X}_{t, static\_indices} \\ \text{Dynamic Inputs} = \mathbf{D} = \mathbf{X}_{:, dynamic\_indices}\end{split}\]- Parameters:
sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).
static_indices (List[int]) – Indices of static features within the feature dimension.
dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.
static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).
reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.
reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.
static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).
dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).
- Returns:
A tuple containing: - Static inputs with shape as specified. - Dynamic inputs with shape as specified.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If static_time_step is out of range for the given sequence length.
Examples
>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array:
>>> # (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]  # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)
Notes
Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.
Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.
Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
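The indexing in equation (66) can be reproduced with plain NumPy. The block below is a minimal sketch of the documented semantics and default reshapes, not the library's implementation; the toy array and variable names are illustrative assumptions.

```python
import numpy as np

# Toy batch: 4 sequences, 6 time steps, 5 features.
sequences = np.arange(4 * 6 * 5, dtype=float).reshape(4, 6, 5)
static_indices = [0, 1]
dynamic_indices = [2, 3, 4]
static_time_step = 0

# S = X[:, t, static_indices]: one snapshot per sequence.
static_part = sequences[:, static_time_step, :][:, static_indices]

# D = X[:, :, dynamic_indices]: every time step is kept.
dynamic_part = sequences[:, :, dynamic_indices]

# Documented default reshapes: append a trailing singleton axis.
static_inputs = static_part.reshape(static_part.shape[0], len(static_indices), 1)
dynamic_inputs = dynamic_part[..., np.newaxis]

print(static_inputs.shape)   # (4, 2, 1)
print(dynamic_inputs.shape)  # (4, 6, 3, 1)
```

Taking the static features from a single time step only makes sense when they are constant along the sequence, which is why static_time_step defaults to the first step.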
See also
geoprior.models.utils.create_sequences: Function to create input sequences and targets for time series forecasting.
- geoprior.models.utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]
Generate a single-step forecast using the XTFT model.
This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:
\[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\tag{67}\]
where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data’s X_static.
q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.
apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # and dummy spatial features ("longitude", "latitude"), as well as the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # forecast_single_step expects that, if spatial_cols is provided, the
>>> # first two columns of X_static hold the spatial coordinates, so the
>>> # static input is ordered as "longitude", "latitude", then "stat1".
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values  # shape: (50, 3)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1)
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> # - X_static with shape (n_samples, static_input_dim)
>>> # - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=3,   # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=1,   # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",     # The time column name in the output
...     mode="quantile",   # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())
Notes
In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.
If spatial_cols is provided, its entries must correspond to the first and second columns of the original training data’s X_static.
The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.
Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.
The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
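To make the quantile-mode evaluation concrete, the sketch below slices a prediction array of shape \((n, 1, m)\) into its quantile columns and computes an interval coverage, taken here to be the fraction of observations falling between the lowest and highest quantile. The helper interval_coverage is hypothetical and only illustrates the usual definition; the library's coverage_score may differ in detail.

```python
import numpy as np

def interval_coverage(y_true, y_lower, y_upper):
    """Fraction of observations inside [y_lower, y_upper] (hypothetical helper)."""
    y_true, y_lower, y_upper = map(np.asarray, (y_true, y_lower, y_upper))
    inside = (y_true >= y_lower) & (y_true <= y_upper)
    return float(inside.mean())

# Predictions with shape (n, 1, m): n=4 samples, m=3 quantiles (0.1, 0.5, 0.9).
preds = np.array([
    [[0.0, 1.0, 2.0]],
    [[1.0, 2.0, 3.0]],
    [[2.0, 3.0, 4.0]],
    [[3.0, 4.0, 5.0]],
])
y = np.array([1.5, 3.5, 2.5, 4.0]).reshape(4, 1, 1)

# One 1-D array per quantile, taken from the single forecast step.
q10, q50, q90 = (preds[:, 0, k] for k in range(3))

coverage = interval_coverage(y[:, 0, 0], q10, q90)
print(coverage)  # 0.75 (the second sample, 3.5, lies outside [1.0, 3.0])
```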
See also
forecast_multi_step: Function for multi-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to validate quantile ratios.
- geoprior.models.utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]
Generate a multi-step forecast using the XTFT model.
This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:
\[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr)\tag{68}\]
for \(i = 1, \dots, forecast\_horizon\), where \(f\) is the trained XTFT model.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data’s X_static.
time_steps (int, optional) – The number of historical time steps used as input. Default is 3.
q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.
- Returns:
A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60,
...                          freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # spatial_cols must match the first two columns of X_static, so the
>>> # static input is ordered as "longitude", "latitude", then "stat1".
>>> X_static = train_df[["longitude", "latitude", "stat1"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
>>> # Target output from "subsidence" reshaped to
>>> # (60, 1, 1). For multi-step forecast, forecast_horizon is 4.
>>> forecast_horizon = 4
>>> y_array = train_df["subsidence"].values.reshape(60, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=3,   # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())
For a point forecast:
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=3,   # "longitude", "latitude", and "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None,       # set quantiles to None
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam", loss="mse",
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())
Notes
In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.
In point mode, a single prediction is generated per forecast step.
The output prediction array is expected to have the shape \((n, forecast\_horizon, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).
The provided spatial_cols must correspond to the first two columns of the original training data’s X_static.
Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.
The DataFrame is constructed by iterating over each sample and each forecast step.
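The column naming described above can be illustrated by mapping an \((n, forecast\_horizon, m)\) array onto per-step column names. The helper wide_columns below is hypothetical and returns a plain dictionary rather than a DataFrame; the exact row and column layout produced by forecast_multi_step is not specified here and may differ.

```python
import numpy as np

def wide_columns(preds, tname, quantiles=None):
    """Map an (n, H, m) array to {column_name: (n,) array} (hypothetical helper)."""
    n, horizon, m = preds.shape
    if quantiles is None:
        labels = ["pred"]  # point mode: one output per step
    else:
        labels = [f"q{int(round(q * 100))}" for q in quantiles]
    if len(labels) != m:
        raise ValueError("last axis must match the number of outputs per step")
    columns = {}
    for step in range(horizon):
        for k, label in enumerate(labels):
            columns[f"{tname}_{label}_step{step + 1}"] = preds[:, step, k]
    return columns

preds = np.random.default_rng(0).random((5, 4, 3))  # n=5, horizon=4, 3 quantiles
cols = wide_columns(preds, "subsidence", quantiles=[0.1, 0.5, 0.9])
print(list(cols)[:3])
# ['subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1']

point_cols = wide_columns(preds[..., :1], "subsidence")
print(list(point_cols)[0])  # subsidence_pred_step1
```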
See also
forecast_single_step: Function for single-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to verify quantile ratios.
- geoprior.models.utils.format_predictions_to_dataframe(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, **kwargs)[source]
Deprecated alias for format_predictions. See format_predictions for the updated, recommended API. All original parameters are forwarded to format_predictions.
- Returns:
The formatted prediction DataFrame from format_predictions.
- Return type:
pd.DataFrame
- geoprior.models.utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]
Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.
- Parameters:
dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.
num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.
agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.
errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.
- Returns:
If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.
- Return type:
Union[List[Tuple[Any,]], Optional[Tuple[Any,]]]
- Raises:
TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).
ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).
RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
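The extract-then-aggregate behavior can be pictured without TensorFlow by treating the dataset as a plain iterable of (x, y) NumPy batches. The helper take_batches below is a hypothetical analogue, not the library function; in particular, treating 'all', '*' and 'auto' identically is an assumption, and the errors handling is omitted.

```python
from itertools import islice

import numpy as np

def take_batches(batches, num_batches=1, agg=False):
    """Pull batches from an iterable of (x, y) tuples; optionally concatenate."""
    # Assumption: the three string values all mean "take every batch".
    limit = None if num_batches in ("all", "*", "auto") else int(num_batches)
    taken = list(islice(iter(batches), limit))
    if not agg:
        return taken
    if not taken:
        return None
    # Concatenate corresponding elements along the batch axis.
    return tuple(np.concatenate(parts, axis=0) for parts in zip(*taken))

batches = [(np.ones((8, 3)), np.zeros((8, 1))) for _ in range(5)]

print(len(take_batches(batches, 2)))   # 2
x_all, y_all = take_batches(batches, "all", agg=True)
print(x_all.shape, y_all.shape)        # (40, 3) (40, 1)
print(take_batches([], 3, agg=True))   # None
```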
- geoprior.models.utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)[source]
Safely retrieves the first available tensor from a dictionary using a list of possible keys.
This utility is crucial for handling model inputs within a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of Tensors (e.g., tensor_a or tensor_b), which causes runtime errors, by explicitly checking for is not None.
- Parameters:
inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).
*tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.
default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.
check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.
auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.
- Returns:
The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.
- Return type:
Optional[tf.Tensor]
- Raises:
TypeError – If inputs is not a dictionary.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert (a float32 array keeps its dtype)
>>> inputs_dict_np = {'H_field': np.array([10., 11.], dtype=np.float32)}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
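The reason for the explicit is not None check can be shown with NumPy arrays standing in for tensors: evaluating a multi-element array in a boolean context is ambiguous and raises, much as `tensor_a or tensor_b` fails for tensors. The helper first_present below is a hypothetical, TensorFlow-free sketch of the lookup pattern only; it does not perform the type checking or conversion that get_tensor_from documents.

```python
import numpy as np

def first_present(mapping, *keys, default=None):
    """Return the first non-None value among keys (hypothetical helper)."""
    if not isinstance(mapping, dict):
        raise TypeError("mapping must be a dict")
    for key in keys:
        value = mapping.get(key)
        if value is not None:  # identity check, never bool(value)
            return value
    return default

inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}

# `a or b` calls bool() on the array, which is ambiguous for several elements.
try:
    chosen = inputs["soil_thickness"] or inputs["H_field"]
except ValueError as exc:
    print("ambiguous:", type(exc).__name__)  # ambiguous: ValueError

chosen = first_present(inputs, "H_field", "soil_thickness")
print(chosen)  # [20. 21.]
```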
- geoprior.models.utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)[source]
Compute anomaly scores for given true targets using various methods.
This utility function, compute_anomaly_scores, provides a flexible approach to compute anomaly scores outside the XTFT model itself. Anomaly scores serve as indicators of how unusual certain observations are, guiding the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.
- Parameters:
y_true (np.ndarray) – The ground truth target values with shape (B, H, O), where:
- B: batch size
- H: number of forecast horizons (time steps ahead)
- O: output dimension (e.g., number of target variables).
Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.
y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to ‘residual’, the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.
method (str, optional) – The method used to compute anomaly scores. Supported options:
"statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores. Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:
\[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\tag{69}\]
where \(\varepsilon\) is a small constant for numerical stability.
"domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.
"isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.
"residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:
\[\text{MSE: }(y_{true} - y_{pred})^2\tag{70}\]
Default is "statistical".
threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.
domain_func (callable, optional) – A user-defined function for the domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the default heuristic is:
\[\begin{split}\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\end{split}\tag{71}\]
contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.
epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.
estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.
random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.
residual_metric (str, optional) – The metric used to compute anomalies from residuals if method is set to ‘residual’. Supported metrics:
"mse": mean squared error per point (residuals**2)
"mae": mean absolute error per point |residuals|
"rmse": root mean squared error sqrt((residuals**2))
Default is "mse".
objective (str, optional) – Specifies the type of objective, for future extensibility. Default is "ts", indicating time series. Could be extended for other tasks in the future.
verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.
- Returns:
anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.
- Return type:
np.ndarray
Notes
Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
Examples
>>> from geoprior.models.utils import compute_anomaly_scores
>>> import numpy as np
>>>
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B, H, O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>>
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain', domain_func=my_heuristic)
>>>
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest', contamination=0.1)
>>>
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual', residual_metric='mae')
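The closed-form scores in equations (69) to (71) are simple enough to write out directly. The functions below are a hedged NumPy sketch of those formulas, not the library code; the names are hypothetical, and computing \(\mu\) and \(\sigma\) over the whole array (rather than per axis) is an assumption.

```python
import numpy as np

def statistical_scores(y_true, epsilon=1e-6):
    """Squared z-score, eq. (69); mean/std over the whole array (assumption)."""
    mu, sigma = y_true.mean(), y_true.std()
    return ((y_true - mu) / (sigma + epsilon)) ** 2

def default_domain_scores(y_true):
    """Default domain heuristic, eq. (71): penalize negative values only."""
    return np.where(y_true < 0, np.abs(y_true) * 10, 0.0)

def residual_scores(y_true, y_pred, metric="mse"):
    """Per-point residual scores for the documented metrics."""
    residuals = y_true - y_pred
    if metric == "mse":
        return residuals ** 2
    if metric == "mae":
        return np.abs(residuals)
    if metric == "rmse":
        return np.sqrt(residuals ** 2)
    raise ValueError(f"unsupported residual metric: {metric!r}")

y = np.array([-2.0, 0.0, 2.0]).reshape(1, 3, 1)  # (B, H, O)
y_hat = np.zeros_like(y)

print(statistical_scores(y).shape)                      # (1, 3, 1)
print(default_domain_scores(y).ravel().tolist())        # [20.0, 0.0, 0.0]
print(residual_scores(y, y_hat, "mae").ravel().tolist())  # [2.0, 0.0, 2.0]
```

Note that, applied per point, the "rmse" option coincides with the absolute residual.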
See also
geoprior.models.losses.objective_loss: For integrating anomaly scores into a multi-objective loss.
- geoprior.models.utils.PDE_MODE_ALIASES = frozenset({'', '0', '1', 'all', 'both', 'consolidation', 'disable', 'disabled', 'false', 'gw_flow', 'none', 'off', 'on', 'true'})
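One plausible use of such an alias set is to validate a user-supplied PDE mode before it is interpreted. The sketch below copies the values from the listing above and checks membership after lowercasing and stripping; the helper is_known_pde_mode is hypothetical, and how the library maps each alias to a canonical mode is not documented here.

```python
# Values copied from the PDE_MODE_ALIASES listing above.
PDE_MODE_ALIASES = frozenset({
    '', '0', '1', 'all', 'both', 'consolidation', 'disable', 'disabled',
    'false', 'gw_flow', 'none', 'off', 'on', 'true',
})

def is_known_pde_mode(value):
    """True if str(value), stripped and lowercased, is a listed alias.

    Hypothetical helper: the normalization rule is an assumption.
    """
    token = str(value).strip().lower()
    return token in PDE_MODE_ALIASES

print(is_known_pde_mode(" Both "))     # True
print(is_known_pde_mode(True))         # True  ('true')
print(is_known_pde_mode(None))         # True  ('none')
print(is_known_pde_mode("diffusion"))  # False
```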
- geoprior.models.utils.prepare_pinn_data_sequences(df, time_col, subsidence_col, gwl_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, h_field_col=None, lon_col=None, lat_col=None, group_id_cols=None, time_steps=12, forecast_horizon=3, output_subsidence_dim=1, output_gwl_dim=1, datetime_format=None, normalize_coords=True, cols_to_scale=None, lock_physics_cols=True, protect_si_suffix='__si', return_coord_scaler=False, coord_scaler=None, fit_coord_scaler=True, mode=None, model=None, savefile=None, progress_hook=None, stop_check=None, verbose=0, _logger=None, **kws)[source]
- Parameters:
df (DataFrame)
time_col (str)
subsidence_col (str)
gwl_col (str)
h_field_col (str | None)
lon_col (str | None)
lat_col (str | None)
time_steps (int)
forecast_horizon (int)
output_subsidence_dim (int)
output_gwl_dim (int)
datetime_format (str | None)
normalize_coords (bool)
lock_physics_cols (bool)
protect_si_suffix (str)
return_coord_scaler (bool)
coord_scaler (MinMaxScaler | None)
fit_coord_scaler (bool)
mode (str | None)
model (str | None)
savefile (str | None)
verbose (int)
- Return type:
tuple[dict[str, ndarray], dict[str, ndarray]] | tuple[dict[str, ndarray], dict[str, ndarray], MinMaxScaler | None]
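The time_steps and forecast_horizon parameters suggest a sliding-window layout, in which each sample pairs a block of history with the block of targets that follows it. The sketch below shows that layout for a single series using the signature's default sizes; sliding_windows is a hypothetical helper, and the actual function additionally handles grouping, coordinates, scaling, and dictionary outputs that are not reproduced here.

```python
import numpy as np

def sliding_windows(values, time_steps, forecast_horizon):
    """Split one series into (history, target) windows (hypothetical helper)."""
    values = np.asarray(values)
    n_windows = len(values) - time_steps - forecast_horizon + 1
    if n_windows <= 0:
        raise ValueError("series too short for the requested window sizes")
    history = np.stack(
        [values[i:i + time_steps] for i in range(n_windows)]
    )
    targets = np.stack(
        [values[i + time_steps:i + time_steps + forecast_horizon]
         for i in range(n_windows)]
    )
    return history, targets

series = np.arange(20.0)  # 20 time points for one location
hist, tgt = sliding_windows(series, time_steps=12, forecast_horizon=3)
print(hist.shape, tgt.shape)  # (6, 12) (6, 3)
print(tgt[0].tolist())        # [12.0, 13.0, 14.0]
```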
- geoprior.models.utils.normalize_for_pinn(df, time_col, coord_x, coord_y, cols_to_scale='auto', scale_coords=True, exclude_cols=None, protect_si_suffix='__si', shift_time_by_horizon=False, verbose=1, forecast_horizon=None, _logger=None, coord_scaler=None, fit_coord_scaler=True, other_scaler=None, fit_other_scaler=True, **kws)[source]
Apply Min-Max normalization to spatial-temporal coordinates and optionally to other numeric columns. If cols_to_scale == "auto", automatically select numeric columns excluding categorical and one-hot features.
By default, this function scales the time, longitude, and latitude columns (if scale_coords=True). Then, it either scales explicitly provided columns in cols_to_scale or automatically infers numeric columns (excluding coordinates if scale_coords is False, and excluding one-hot/boolean columns).
The Min-Max scaling for a feature \(x\) is:
\[x' = \frac{x - \min(x)}{\max(x) - \min(x)}\tag{72}\]
- Parameters:
df (pd.DataFrame) – The input DataFrame containing at least time_col, coord_x, and coord_y columns. The DataFrame should contain temporal and spatial information to be scaled.
time_col (str) – The name of the numeric time column (e.g., year as numeric or datetime). This column will be used to adjust and scale the temporal data.
coord_x (str) – The name of the longitude column in the DataFrame. This column will be scaled along with the latitude and time columns.
coord_y (str) – The name of the latitude column in the DataFrame. This column will be scaled along with the longitude and time columns.
cols_to_scale (list of str or "auto" or None, default "auto") – If a list of column names, scales exactly those columns. If "auto", selects all numeric columns, excluding time_col, coord_x, and coord_y if scale_coords=False, and excluding one-hot encoded columns whose values are only {0, 1}. If None, no extra columns are scaled.
scale_coords (bool, default True) – If True, scales the [time_col, coord_x, coord_y] columns. If False, these columns remain unchanged.
verbose (int, default 1) – Verbosity level for logging. Values higher than 1 provide more detailed logging information.
forecast_horizon (Optional[int], default None) – The number of time steps to shift the time column by. This is added to the time values before scaling if provided.
_logger (Optional[Union[logging.Logger, Callable[[str], None]]], default None) – Logger or function to handle logging messages. If None, the default logging mechanism is used.
kws (dict, optional) – Additional keyword arguments passed on to any other internal function used in the data processing or scaling steps.
protect_si_suffix (str)
shift_time_by_horizon (bool)
coord_scaler (MinMaxScaler | None)
fit_coord_scaler (bool)
other_scaler (MinMaxScaler | None)
fit_other_scaler (bool)
- Returns:
df_scaled (pd.DataFrame) – A new DataFrame with the specified columns normalized.
coord_scaler (MinMaxScaler or None) – The fitted scaler for the [time_col, coord_x, coord_y] columns if scale_coords=True, else None.
other_scaler (MinMaxScaler or None) – The fitted scaler for any additional columns that were scaled (either explicitly provided or auto-selected). None if no columns were scaled beyond the coordinates.
- Raises:
TypeError – If df is not a DataFrame, or cols_to_scale is neither a list nor “auto” nor None, or if any explicitly provided column is not a string.
ValueError – If required columns (time_col, coord_x, coord_y) or any of cols_to_scale do not exist in df, or cannot be converted to numeric.
- Return type:
tuple[DataFrame, MinMaxScaler | None, MinMaxScaler | None]
Examples
>>> import pandas as pd
>>> from geoprior.nn.pinn.utils import normalize_for_pinn
>>> data = {
...     "year_num": [0.0, 1.0, 2.0],
...     "lon": [100.0, 101.0, 102.0],
...     "lat": [30.0, 31.0, 32.0],
...     "feat1": [10.0, 20.0, 30.0],
...     "one_hot_A": [0, 1, 0],
... }
>>> df = pd.DataFrame(data)
>>> df_scaled, coord_scl, feat_scl = normalize_for_pinn(
...     df,
...     time_col="year_num",
...     coord_x="lon",
...     coord_y="lat",
...     cols_to_scale="auto",
...     scale_coords=True,
...     verbose=2,
... )
>>> # 'year_num', 'lon', 'lat', 'feat1' get scaled; 'one_hot_A' is excluded
>>> df_scaled["year_num"].tolist()
[0.0, 0.5, 1.0]
>>> df_scaled["feat1"].tolist()
[0.0, 0.5, 1.0]
Notes
When cols_to_scale="auto", numeric columns with only {0, 1} values are assumed to be one-hot and excluded from scaling.
If scale_coords=False, coordinate columns remain unchanged, and auto-selection (if used) will exclude them.
The returned coord_scaler is None if scale_coords=False. The returned other_scaler is None if cols_to_scale is None or results in an empty set after filtering.
See also
sklearn.preprocessing.MinMaxScaler – Scales features to [0, 1].
- geoprior.models.utils.format_pinn_predictions(predictions=None, model=None, model_inputs=None, y_true_dict=None, target_mapping=None, include_gwl=True, include_coords=True, quantiles=None, forecast_horizon=None, output_dims=None, ids_data_array=None, ids_cols=None, ids_cols_indices=None, scaler_info=None, coord_scaler=None, evaluate_coverage=False, coverage_quantile_indices=(0, -1), savefile=None, _logger=None, name=None, model_name=None, stop_check=None, verbose=0, **kwargs)[source]
Formats PINN model predictions into a structured pandas DataFrame.
This is a general-purpose utility for transforming raw model outputs (from models like PIHALNet or TransFlowSubsNet) into a long-format DataFrame suitable for analysis, visualization, or export. It handles multi-target outputs (e.g., subsidence and GWL) and point or quantile forecasts, and can optionally include true values, coordinate information, and other metadata. It also supports inverse-scaling of predictions and evaluation of quantile coverage.
- Parameters:
predictions (dict of Tensors, optional) – The dictionary of prediction tensors, typically returned by a model's .predict() method. Keys should match the model's output layer names (e.g., 'subs_pred', 'gwl_pred'). If None, predictions are generated internally using the model and model_inputs arguments. Default is None.
model (keras.Model, optional) – A compiled Keras model instance used to generate predictions if the predictions dictionary is not provided. Default is None.
model_inputs (dict of Tensors, optional) – A dictionary of input tensors matching the model's signature, required only if predictions is None. Default is None.
y_true_dict (dict, optional) – A dictionary containing the ground-truth target arrays, keyed by their base names (e.g., 'subsidence', 'gwl'). If provided, a <target>_actual column is added to the output DataFrame for comparison. Default is None.
target_mapping (dict, optional) – A custom mapping from model output keys to desired base names in the DataFrame columns. For example: {'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}. Default is None.
include_gwl (bool, default True) – Toggles the inclusion of groundwater level (GWL) predictions in the final DataFrame.
include_coords (bool, default True) – Toggles the inclusion of the spatio-temporal coordinate columns (coord_t, coord_x, coord_y) in the final DataFrame.
quantiles (list of float, optional) – The list of quantile levels (e.g., [0.1, 0.5, 0.9]) that the model predicted. This is required to parse probabilistic forecasts correctly. Default is None.
forecast_horizon (int, optional) – The length of the forecast horizon. If None, it is inferred from the shape of the prediction tensors. Default is None.
output_dims (dict of str, optional) – A dictionary specifying the feature dimension of each target, e.g., {'subs_pred': 1, 'gwl_pred': 1}. If None, it is inferred from the tensor shapes. Default is None.
ids_data_array (np.ndarray or pd.DataFrame, optional) – An array or DataFrame containing static identifiers (e.g., well IDs, site categories) for each sample. Its length must match the number of samples in the prediction. Default is None.
ids_cols (list of str, optional) – A list of column names for the ids_data_array. Required if ids_data_array is a NumPy array. Default is None.
ids_cols_indices (list of int, optional) – A list of column indices to select from ids_data_array if it is a NumPy array. Default is None.
scaler_info (dict, optional) – A dictionary providing the necessary information to perform inverse scaling on a per-target basis. Each key should be a target name (e.g., 'subsidence') and its value a dictionary containing {'scaler': obj, 'all_features': list, 'idx': int}. Default is None.
coord_scaler (object, optional) – A fitted scikit-learn-like scaler object used to perform an inverse transform on the coordinate columns. Default is None.
evaluate_coverage (bool, default False) – If True and quantile predictions are present, calculates the unconditional coverage of the prediction interval.
coverage_quantile_indices (tuple of (int, int), default (0, -1)) – The indices of the lower and upper quantiles in the sorted quantiles list to use for the coverage calculation. The default (0, -1) corresponds to the full range.
savefile (str, optional) – If a file path is provided, the final DataFrame is saved to a CSV file at this location. Default is None.
name (str or None) – Name of the prediction. It is used to format the output data and the coverage result, if applicable.
model_name (str or None) – Name of the model.
verbose (int, default 0) – The verbosity level, from 0 (silent) to 5 (trace every step).
**kwargs (dict) – Additional keyword arguments for future extensions.
- Returns:
A long-format DataFrame where each row represents a single forecast step for a single sample. Columns include sample and step identifiers, coordinates, predictions, and optionally actuals and metadata.
- Return type:
pd.DataFrame
Notes
The function returns a column-aligned DataFrame, which simplifies subsequent analysis and plotting.
For quantile forecasts, prediction columns are named using the pattern <target_name>_q<quantile*100>, e.g., subsidence_q5, subsidence_q50, subsidence_q95.
For point forecasts, the column is named <target_name>_pred.
See also
geoprior.plot.forecast.plot_forecasts – A utility for visualizing the DataFrame produced by this function.
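The quantile column-naming pattern and the coverage evaluation described for format_pinn_predictions can be illustrated with a small standalone sketch. It does not call the library: quantile_column and interval_coverage are hypothetical helper names, and coverage is assumed here to mean the fraction of actual values that fall inside the closed interval between the lower and upper quantile columns.

```python
def quantile_column(target_name, q):
    """Build a name following the <target_name>_q<quantile*100> pattern."""
    return f"{target_name}_q{int(round(q * 100))}"


def interval_coverage(actual, lower, upper):
    """Fraction of actual values inside [lower, upper] (assumed definition)."""
    hits = [lo <= a <= hi for a, lo, hi in zip(actual, lower, upper)]
    return sum(hits) / len(hits)


quantiles = [0.1, 0.5, 0.9]
columns = [quantile_column("subsidence", q) for q in quantiles]

# Toy long-format values: one entry per (sample, forecast step) row.
actual = [1.0, 2.0, 3.0, 4.0]
q_low = [0.5, 2.5, 2.0, 3.0]
q_high = [1.5, 3.5, 4.0, 3.5]
coverage = interval_coverage(actual, q_low, q_high)
```

With coverage_quantile_indices=(0, -1), the lower and upper columns would correspond to the first and last entries of the sorted quantile list, which is what q_low and q_high stand in for here.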
- geoprior.models.utils.extract_txy(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data with flexible dimension validation.
- Parameters:
inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. Can be a single tensor or a dictionary with 'coords' or 't', 'x', 'y' keys.
coord_slice_map (dict, optional) – Mapping for 't', 'x', 'y' to their index in the last dimension of a coordinate tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
expect_dim ({'2d', '3d', '3d_only'}, optional) – Enforces a constraint on the input's dimension. '2d' requires input shaped like (batch, 3). '3d' accepts 3D input and expands 2D input to 3D with a time dimension of 1. '3d_only' requires 3D input and raises an error for 2D input. None accepts both 2D and 3D inputs without changing their rank.
verbose (int, default 0) – Controls logging verbosity.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors. Their rank (2D or 3D) depends on the input and the expect_dim mode.
- Return type:
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
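The expect_dim modes of extract_txy can be sketched with NumPy arrays in place of TensorFlow tensors. This is an illustrative approximation only: split_txy is a hypothetical name, dictionary inputs are not handled, and keeping a singleton last dimension on each output is an assumption rather than documented behavior.

```python
import numpy as np

DEFAULT_SLICE_MAP = {"t": 0, "x": 1, "y": 2}


def split_txy(coords, coord_slice_map=None, expect_dim=None):
    """Split a (batch, 3) or (batch, time, 3) array into t, x, y."""
    slice_map = coord_slice_map or DEFAULT_SLICE_MAP
    coords = np.asarray(coords)
    if coords.ndim not in (2, 3) or coords.shape[-1] < 3:
        raise ValueError(f"unsupported coordinate shape: {coords.shape}")
    if expect_dim == "2d" and coords.ndim != 2:
        raise ValueError("expected 2D input shaped like (batch, 3)")
    if expect_dim == "3d_only" and coords.ndim != 3:
        raise ValueError("expected 3D input shaped like (batch, time, 3)")
    if expect_dim == "3d" and coords.ndim == 2:
        coords = coords[:, np.newaxis, :]  # add a time axis of length 1
    # Slicing with i:i+1 (instead of i) keeps a singleton last dimension.
    return tuple(
        coords[..., slice_map[k]:slice_map[k] + 1] for k in ("t", "x", "y")
    )


coords_2d = np.array([[0.0, 100.0, 30.0], [1.0, 101.0, 31.0]])
t, x, y = split_txy(coords_2d, expect_dim="3d")       # rank promoted to 3
t_flat, x_flat, y_flat = split_txy(coords_2d)         # rank left unchanged
```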
- geoprior.models.utils.extract_txy_in(inputs, coord_slice_map=None, expect_dim=None, verbose=0, _logger=None, **kws)[source]
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single tensor or a dictionary, and handling both 2D (spatial/static) and 3D (spatio-temporal) data. It ensures a consistent 3D output format for robust downstream processing.
- Parameters:
inputs (tf.Tensor, np.ndarray, or dict) – The input data containing coordinates. A single tensor or array may be 2D with shape (batch, 3) or 3D with shape (batch, time_steps, 3). A dictionary may contain a 'coords' key with the coordinate tensor, or separate 't', 'x', and 'y' keys.
coord_slice_map (dict, optional) – Mapping for 't', 'x', 'y' to their index in the last dimension of a coordinate tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
expect_dim ({'2d', '3d'}, optional) – If provided, enforces that the input resolves to the specified dimension. '2d' requires input shaped like (batch, 3) or a dictionary of (batch, 1) tensors. '3d' requires input shaped like (batch, time, 3) or a dictionary of (batch, time, 1) tensors. If None, both are accepted and 2D inputs are expanded to 3D.
verbose (int, default 0) – Controls the verbosity of logging messages. 0 is silent, 1 provides basic info, and higher values provide more detail.
- Returns:
t, x, y – The extracted t, x, and y coordinate tensors, each reshaped to be 3D with a singleton last dimension, e.g., (batch, time_steps, 1).
- Return type:
Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
- Raises:
ValueError – If input format is unsupported, dimensions are inconsistent, or expect_dim constraint is violated.
- geoprior.models.utils.process_pde_modes(pde_mode, enforce_consolidation=False, pde_mode_config=None, solo_return=False)[source]
Normalize and validate PDE mode selection.
- Parameters:
pde_mode (str, sequence of str, or None) – Requested PDE mode(s). Accepted canonical values are "none", "consolidation", "gw_flow", and "both". Accepted aliases are None and "off" (both mapped to "none") and "on" (mapped to "both").
enforce_consolidation (bool, default False) – If True, any resolved mode other than exactly ["consolidation"] is coerced to ["consolidation"] and a warning is emitted. This includes ["none"], ["gw_flow"], and ["consolidation", "gw_flow"].
pde_mode_config (str, sequence of str, or None, optional) – Optional override. If provided, this value takes precedence over pde_mode.
solo_return (bool, default False) – If False, return a canonical list of active modes. If True, return a single canonical label: "none", "consolidation", "gw_flow", or "both".
- Returns:
Canonical PDE mode(s), either as a list or a single label.
- Return type:
list[str] | str
- Raises:
TypeError – If the input type is invalid.
ValueError – If a token is unsupported or the mode selection is ambiguous.
Examples
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
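For readers who want to see the alias and canonicalization rules spelled out, the following is a hedged re-implementation sketch, not the library source. normalize_pde_mode and the EXPANSIONS table are hypothetical; the sketch omits enforce_consolidation and pde_mode_config, and treating "none" mixed with an active mode as the ambiguous case is an assumption.

```python
CANONICAL = ("consolidation", "gw_flow")
# Documented aliases plus the canonical labels, expanded to active modes.
EXPANSIONS = {
    "none": [],
    "off": [],
    "on": list(CANONICAL),
    "both": list(CANONICAL),
    "consolidation": ["consolidation"],
    "gw_flow": ["gw_flow"],
}


def normalize_pde_mode(pde_mode, solo_return=False):
    """Hypothetical sketch of the documented normalization rules."""
    if pde_mode is None:
        tokens = ["none"]
    elif isinstance(pde_mode, str):
        tokens = [pde_mode]
    elif isinstance(pde_mode, (list, tuple)):
        tokens = ["none" if tok is None else tok for tok in pde_mode]
    else:
        raise TypeError(f"unsupported pde_mode type: {type(pde_mode).__name__}")

    active, saw_none = set(), False
    for token in tokens:
        if not isinstance(token, str):
            raise TypeError(f"PDE mode tokens must be str, got {token!r}")
        key = token.strip().lower()
        if key not in EXPANSIONS:
            raise ValueError(f"unsupported PDE mode: {token!r}")
        saw_none = saw_none or not EXPANSIONS[key]
        active.update(EXPANSIONS[key])
    if saw_none and active:
        raise ValueError("ambiguous selection: 'none' mixed with an active mode")

    # Emit modes in a fixed canonical order regardless of input order.
    modes = [m for m in CANONICAL if m in active] or ["none"]
    if not solo_return:
        return modes
    return "both" if len(modes) == 2 else modes[0]
```

Returning the list in a fixed order means that ["gw_flow", "consolidation"] and "on" resolve to the same value, which keeps downstream comparisons simple.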
Subsidence-physics utility layer#
The subsidence-physics utility layer is more specialized than the other two layers. It is not simply another helper bucket; it is a coherence layer for unit systems, scaling metadata, coordinate policies, historical states, and hydrogeological interpretation.
GeoPrior subsidence model utilities.
- geoprior.models.subsidence.utils.enforce_scaling_alias_consistency(scaling_kwargs, *, where='validate')[source]
Enforce that canonical keys and aliases agree.
If both canonical and an alias exist and their values differ, apply the scaling error policy.
- geoprior.models.subsidence.utils.canonicalize_scaling_kwargs(scaling_kwargs, *, copy=True)[source]
Return a canonicalized scaling dict.
If a canonical key is missing, but one of its aliases exists, copy alias -> canonical.
Keeps existing canonical values unchanged.
- geoprior.models.subsidence.utils.load_scaling_kwargs(scaling_kwargs, *, copy=True)[source]
Load scaling kwargs from a dict-like object or JSON.
- geoprior.models.subsidence.utils.get_sk(scaling_kwargs, key, *aliases, default=None, required=False, cast=None)[source]
Fetch a key from scaling_kwargs with aliases + default.
Tries: key -> built-in aliases -> explicit aliases
Treats None and blank strings as “missing” and keeps searching.
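The lookup order described for get_sk (key first, then aliases, skipping None and blank strings) can be sketched in plain Python. lookup_with_aliases is a hypothetical stand-in, the built-in alias table of the real helper is not reproduced, and the dictionary keys below are invented for illustration.

```python
def lookup_with_aliases(mapping, key, *aliases, default=None,
                        required=False, cast=None):
    """Return the first non-missing value among key and its aliases."""
    for name in (key, *aliases):
        value = mapping.get(name)
        # None and blank strings count as missing, so keep searching.
        if value is None or (isinstance(value, str) and not value.strip()):
            continue
        return cast(value) if cast is not None else value
    if required:
        raise KeyError(f"missing required scaling key: {key!r}")
    return default


# Hypothetical scaling dict: the canonical key is blank, the alias is set.
scaling = {"time_units": "  ", "time_unit": "year", "subs_scale_si": "0.001"}

units = lookup_with_aliases(scaling, "time_units", "time_unit")
scale = lookup_with_aliases(scaling, "subs_scale_si", cast=float)
fallback = lookup_with_aliases(scaling, "subs_bias_si", default=0.0)
```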
- geoprior.models.subsidence.utils.validate_scaling_kwargs(scaling_kwargs)[source]
Basic scaling sanity checks.
This includes policy-controlled heuristic checks for common “silent fallback” cases.
- geoprior.models.subsidence.utils.affine_from_cfg(scaling_kwargs, *, scale_key, bias_key, meta_keys=(), unit_key=None)[source]
Return (a,b) for y_si = y_model*a + b.
- geoprior.models.subsidence.utils.to_si_thickness(H_model, scaling_kwargs)[source]
Convert thickness to SI.
- geoprior.models.subsidence.utils.to_si_head(h_model, scaling_kwargs)[source]
Convert head/depth to SI meters.
- geoprior.models.subsidence.utils.to_si_subsidence(s_model, scaling_kwargs)[source]
Convert subsidence to SI meters.
- geoprior.models.subsidence.utils.from_si_subsidence(s_si, scaling_kwargs)[source]
Inverse of to_si_subsidence: s_model = (s_si - b) / a.
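The affine convention y_si = y_model*a + b and its inverse s_model = (s_si - b) / a can be checked with a short round trip. The function names and the values of a and b are illustrative assumptions (here the model unit is taken to be millimetres); the library reads these constants from scaling_kwargs.

```python
def to_si(y_model, a, b):
    """Forward affine map: model units -> SI."""
    return y_model * a + b


def from_si(y_si, a, b):
    """Inverse affine map: SI -> model units (requires a != 0)."""
    return (y_si - b) / a


a, b = 0.001, 0.0          # hypothetical: 1 model unit = 1 mm, no offset
s_model = 25.0             # 25 mm of subsidence in model units
s_si = to_si(s_model, a, b)
s_back = from_si(s_si, a, b)
```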
- geoprior.models.subsidence.utils.deg_to_m(axis, scaling_kwargs)[source]
Meters per degree factor for lon/lat coords.
If coords_in_degrees=True and deg_to_m_lon/lat are missing, we try to compute them from lat0_deg (recommended).
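When the degree-to-meter factors are derived from a reference latitude, a common spherical approximation scales the longitude factor by the cosine of that latitude. The sketch below uses that approximation; it is not confirmed to be the formula used by deg_to_m, and meters_per_degree, the axis labels, and the Earth radius constant are assumptions.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # mean Earth radius, spherical approximation


def meters_per_degree(axis, lat0_deg):
    """Approximate meters per degree along 'lat' or 'lon' at lat0_deg."""
    per_degree = math.pi * EARTH_RADIUS_M / 180.0
    if axis == "lat":
        return per_degree
    if axis == "lon":
        # Meridians converge toward the poles, so longitude degrees shrink.
        return per_degree * math.cos(math.radians(lat0_deg))
    raise ValueError(f"axis must be 'lat' or 'lon', got {axis!r}")


m_lat = meters_per_degree("lat", 30.0)
m_lon_equator = meters_per_degree("lon", 0.0)
m_lon_60 = meters_per_degree("lon", 60.0)
```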
- geoprior.models.subsidence.utils.coord_ranges(scaling_kwargs)[source]
Return (tR,xR,yR) if coords_normalized.
- geoprior.models.subsidence.utils.resolve_gwl_dyn_index(scaling_kwargs)[source]
Resolve GWL channel index for dynamic_features.
- geoprior.models.subsidence.utils.get_gwl_dyn_index_cached(model)[source]
Cache gwl_dyn_index on model after first resolve.
- Return type:
- geoprior.models.subsidence.utils.resolve_subs_dyn_index(scaling_kwargs)[source]
Resolve subsidence channel index for dynamic_features.
This is optional: v3.2 can use historical subsidence as a dynamic driver to provide a physics-friendly initial condition for the mean settlement path.
- geoprior.models.subsidence.utils.get_subs_dyn_index_cached(model)[source]
Cache subs_dyn_index on model after first resolve.
- Return type:
- geoprior.models.subsidence.utils.slice_dynamic_channel(Xh, idx)[source]
Slice (B,T,F) -> (B,T,1) at idx.
- Parameters:
Xh (Tensor)
idx (int)
- Return type:
Tensor
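The (B,T,F) -> (B,T,1) slice keeps the feature axis instead of dropping it. A NumPy sketch of the same idea follows; slice_channel is a hypothetical name, and normalizing negative indices is an added convenience that the library helper may not provide.

```python
import numpy as np


def slice_channel(Xh, idx):
    """Select one feature channel while keeping a trailing axis of size 1."""
    idx = idx % Xh.shape[-1]          # map negative indices to positive ones
    return Xh[:, :, idx:idx + 1]      # range slice preserves the last axis


# Toy history tensor: batch of 2, 3 time steps, 4 dynamic features.
Xh = np.arange(24, dtype=float).reshape(2, 3, 4)
gwl = slice_channel(Xh, 1)
last = slice_channel(Xh, -1)
```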
- geoprior.models.subsidence.utils.assert_dynamic_names_match_tensor(Xh, scaling_kwargs)[source]
Check dynamic_feature_names length matches Xh.
- geoprior.models.subsidence.utils.gwl_to_head_m(v_m, scaling_kwargs, *, inputs=None)[source]
Convert depth-bgs to head if possible.
- geoprior.models.subsidence.utils.get_h_hist_si(model, inputs, *, want_head=True)[source]
Return head (or depth) history in SI meters.
- Parameters:
- Returns:
(B,T,1) tensor in SI meters.
- Return type:
Tensor
- geoprior.models.subsidence.utils.get_s_init_si(model, inputs, like)[source]
Return initial settlement (cumulative subsidence) in SI meters.
Priority:
1) explicit keys in inputs (s_init_si/subs_hist_last_si/…)
2) last historical value from dynamic_features if subs_dyn_index exists
3) zeros (broadcast)
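The three-level priority for the initial settlement can be sketched as a fallback chain. This is a simplified NumPy illustration under stated assumptions: initial_settlement is hypothetical, only the two explicit keys named in the docstring are checked, values are assumed to be in SI meters already, and the channel index is passed in directly instead of being resolved from the model.

```python
import numpy as np


def initial_settlement(inputs, like, subs_dyn_index=None):
    """Pick s_init by priority and broadcast it to the shape of `like`."""
    # 1) explicit keys win when present
    for key in ("s_init_si", "subs_hist_last_si"):
        if inputs.get(key) is not None:
            value = np.asarray(inputs[key], dtype=float)
            return np.broadcast_to(value, like.shape)
    # 2) otherwise use the last historical step of the subsidence channel
    if subs_dyn_index is not None and "dynamic_features" in inputs:
        hist = np.asarray(inputs["dynamic_features"], dtype=float)
        last = hist[:, -1:, subs_dyn_index:subs_dyn_index + 1]  # (B, 1, 1)
        return np.broadcast_to(last, like.shape)
    # 3) fall back to zeros
    return np.zeros_like(like)


like = np.zeros((2, 3, 1))                               # (B, H, 1)
dyn = np.arange(16, dtype=float).reshape(2, 4, 2)        # (B, T, F)
from_hist = initial_settlement({"dynamic_features": dyn}, like, subs_dyn_index=1)
explicit = initial_settlement({"s_init_si": 0.25, "dynamic_features": dyn},
                              like, subs_dyn_index=1)
zeros = initial_settlement({}, like)
```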
- geoprior.models.subsidence.utils.get_h_ref_si(model, inputs, like)[source]
Return h_ref in SI meters, broadcast to like.
- geoprior.models.subsidence.utils.infer_dt_units_from_t(t_BH1, scaling_kwargs, *, eps=1e-12)[source]
Infer per-step dt in time_units from time tensor t(B,H,1).
- geoprior.models.subsidence.utils.policy_gate(step, policy, *, warmup_steps=0, ramp_steps=0, dtype=tf.float32)[source]
Return a scalar gate in [0, 1] based on a policy + step.
- Parameters:
step (Tensor) – Global step counter (typically optimizer.iterations).
policy ({"always_on", "always_off", "warmup_off"}) – Gating behavior. always_on returns 1, always_off returns 0, and warmup_off returns 0 for step < warmup_steps, then either ramps to 1 over ramp_steps when ramp_steps > 0 or switches immediately at warmup_steps otherwise.
warmup_steps (int, default 0) – Number of steps to keep the gate at 0 (only for warmup_off).
ramp_steps (int, default 0) – Number of steps for a linear ramp from 0 to 1 after warmup. If 0, the gate is a hard step.
dtype (dtype, default tf.float32) – Output dtype.
- Return type:
Tensor
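The gating policies can be illustrated with a scalar Python function. gate_value is a hypothetical stand-in for policy_gate that works on plain integers instead of tensors; the exact value at the ramp boundaries is an assumption, since the documentation only states that the ramp is linear.

```python
def gate_value(step, policy, warmup_steps=0, ramp_steps=0):
    """Scalar gate in [0, 1] for the three documented policies."""
    if policy == "always_on":
        return 1.0
    if policy == "always_off":
        return 0.0
    if policy != "warmup_off":
        raise ValueError(f"unknown policy: {policy!r}")
    if step < warmup_steps:
        return 0.0                      # still inside the warmup window
    if ramp_steps <= 0:
        return 1.0                      # hard step right after warmup
    # Linear ramp from 0 to 1 over ramp_steps, clipped at 1.
    return min(1.0, (step - warmup_steps) / ramp_steps)


schedule = [gate_value(s, "warmup_off", warmup_steps=10, ramp_steps=20)
            for s in (0, 10, 20, 30, 40)]
hard_step = gate_value(10, "warmup_off", warmup_steps=10, ramp_steps=0)
```

A gate like this is typically multiplied into a physics-loss weight so that the data terms dominate early training before the PDE residuals are switched on.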
- geoprior.models.subsidence.utils.finalize_scaling_kwargs(sk)[source]
Add derived SI conversion constants to scaling_kwargs.
Adds (when possible):
- seconds_per_time_unit: float
- coord_ranges_si: dict with keys t (seconds), x/y (meters)
- coord_inv_ranges_si: inverse of the above (safe floor).
Notes
This helper is designed to be called once when assembling scaling_kwargs (e.g., in your stage2 script) so the model can reuse those constants without recomputing unit conversions in the hot training loop.
- geoprior.models.subsidence.utils.coord_ranges_si(sk)[source]
Return coordinate spans in SI (t in seconds; x/y in meters).
If coord_ranges_si is present in sk, it is used directly. Otherwise, it is computed from coord_ranges and time_units (and degree-to-meter factors when applicable).
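The derived constants listed for finalize_scaling_kwargs and coord_ranges_si can be sketched as follows. This is an assumption-laden illustration: spans_si, the SECONDS_PER_UNIT table, the 365.25-day year, and the floor value are all hypothetical choices, and the real helpers read their inputs from scaling_kwargs instead of taking them as arguments.

```python
# Hypothetical unit table; a Julian year of 365.25 days is assumed.
SECONDS_PER_UNIT = {
    "second": 1.0,
    "day": 86400.0,
    "year": 365.25 * 86400.0,
}


def spans_si(coord_ranges, time_units, deg_to_m_xy=None, floor=1e-12):
    """Convert (tR, xR, yR) spans to SI and precompute safe inverses."""
    t_range, x_range, y_range = coord_ranges
    seconds = SECONDS_PER_UNIT[time_units]
    # Projected coordinates are already in meters, so the factor is 1.
    fx, fy = deg_to_m_xy if deg_to_m_xy is not None else (1.0, 1.0)
    ranges = {"t": t_range * seconds, "x": x_range * fx, "y": y_range * fy}
    # Floor each span before inverting to avoid division by zero.
    inverse = {k: 1.0 / max(v, floor) for k, v in ranges.items()}
    return {
        "seconds_per_time_unit": seconds,
        "coord_ranges_si": ranges,
        "coord_inv_ranges_si": inverse,
    }


derived = spans_si((2.0, 1000.0, 500.0), "day")
```

Computing these once up front matches the stated design goal of keeping unit conversions out of the training loop.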
What this module covers#
The current implementation includes several strongly related subfamilies of helpers:
scaling canonicalization and policy enforcement such as canonicalize_scaling_kwargs, validate_scaling_kwargs, enforce_scaling_alias_consistency, and finalize_scaling_kwargs;
SI conversion helpers such as to_si_thickness, to_si_head, to_si_subsidence, and from_si_subsidence;
coordinate helpers such as deg_to_m, coord_ranges, and coord_ranges_si;
dynamic-channel resolution such as resolve_gwl_dyn_index, resolve_subs_dyn_index, get_gwl_dyn_index_cached, and get_subs_dyn_index_cached;
groundwater/head reconciliation such as gwl_to_head_m and get_h_ref_si;
history and reference-state extraction such as get_h_hist_si and get_s_init_si.
Why this module matters#
A large part of the difficulty in physics-informed forecasting is not the PDE term alone, but the consistency of units, metadata, and conventions. This module exists to make those conventions explicit. It handles alias drift in scaling keys, coordinates recorded in degrees versus projected meters, depth-versus-head ambiguity for groundwater, and the lookup of historical channels used to initialize or interpret model states.
Selected subsidence-physics helpers#
- geoprior.models.subsidence.utils.canonicalize_scaling_kwargs(scaling_kwargs, *, copy=True)[source]
Return a canonicalized scaling dict.
If a canonical key is missing, but one of its aliases exists, copy alias -> canonical.
Keeps existing canonical values unchanged.
- geoprior.models.subsidence.utils.validate_scaling_kwargs(scaling_kwargs)[source]
Basic scaling sanity checks.
This includes policy-controlled heuristic checks for common “silent fallback” cases.
- geoprior.models.subsidence.utils.enforce_scaling_alias_consistency(scaling_kwargs, *, where='validate')[source]
Enforce that canonical keys and aliases agree.
If both canonical and an alias exist and their values differ, apply the scaling error policy.
- geoprior.models.subsidence.utils.to_si_thickness(H_model, scaling_kwargs)[source]
Convert thickness to SI.
- geoprior.models.subsidence.utils.to_si_head(h_model, scaling_kwargs)[source]
Convert head/depth to SI meters.
- geoprior.models.subsidence.utils.to_si_subsidence(s_model, scaling_kwargs)[source]
Convert subsidence to SI meters.
- geoprior.models.subsidence.utils.from_si_subsidence(s_si, scaling_kwargs)[source]
Inverse of to_si_subsidence: s_model = (s_si - b) / a.
- geoprior.models.subsidence.utils.deg_to_m(axis, scaling_kwargs)[source]
Meters per degree factor for lon/lat coords.
If coords_in_degrees=True and deg_to_m_lon/lat are missing, we try to compute them from lat0_deg (recommended).
- geoprior.models.subsidence.utils.resolve_gwl_dyn_index(scaling_kwargs)[source]
Resolve GWL channel index for dynamic_features.
- geoprior.models.subsidence.utils.resolve_subs_dyn_index(scaling_kwargs)[source]
Resolve subsidence channel index for dynamic_features.
This is optional: v3.2 can use historical subsidence as a dynamic driver to provide a physics-friendly initial condition for the mean settlement path.
- geoprior.models.subsidence.utils.gwl_to_head_m(v_m, scaling_kwargs, *, inputs=None)[source]
Convert depth-bgs to head if possible.
- geoprior.models.subsidence.utils.get_h_hist_si(model, inputs, *, want_head=True)[source]
Return head (or depth) history in SI meters.
- Parameters:
- Returns:
(B,T,1) tensor in SI meters.
- Return type:
Tensor
- geoprior.models.subsidence.utils.get_s_init_si(model, inputs, like)[source]
Return initial settlement (cumulative subsidence) in SI meters.
Priority:
1) explicit keys in inputs (s_init_si/subs_hist_last_si/…)
2) last historical value from dynamic_features if subs_dyn_index exists
3) zeros (broadcast)
- geoprior.models.subsidence.utils.get_h_ref_si(model, inputs, like)[source]
Return h_ref in SI meters, broadcast to like.
- geoprior.models.subsidence.utils.policy_gate(step, policy, *, warmup_steps=0, ramp_steps=0, dtype=tf.float32)[source]
Return a scalar gate in [0, 1] based on a policy + step.
- Parameters:
step (Tensor) – Global step counter (typically optimizer.iterations).
policy ({"always_on", "always_off", "warmup_off"}) – Gating behavior. always_on returns 1, always_off returns 0, and warmup_off returns 0 for step < warmup_steps, then either ramps to 1 over ramp_steps when ramp_steps > 0 or switches immediately at warmup_steps otherwise.
warmup_steps (int, default 0) – Number of steps to keep the gate at 0 (only for warmup_off).
ramp_steps (int, default 0) – Number of steps for a linear ramp from 0 to 1 after warmup. If 0, the gate is a hard step.
dtype (dtype, default tf.float32) – Output dtype.
- Return type:
Tensor
- geoprior.models.subsidence.utils.finalize_scaling_kwargs(sk)[source]
Add derived SI conversion constants to scaling_kwargs.
Adds (when possible):
- seconds_per_time_unit: float
- coord_ranges_si: dict with keys t (seconds), x/y (meters)
- coord_inv_ranges_si: inverse of the above (safe floor).
Notes
This helper is designed to be called once when assembling scaling_kwargs (e.g., in your stage2 script) so the model can reuse those constants without recomputing unit conversions in the hot training loop.
- geoprior.models.subsidence.utils.coord_ranges_si(sk)[source]
Return coordinate spans in SI (t in seconds; x/y in meters).
If coord_ranges_si is present in sk, it is used directly. Otherwise, it is computed from coord_ranges and time_units (and degree-to-meter factors when applicable).
Connections to the scientific stack#
These utility layers help explain how GeoPrior-v3 connects workflow orchestration, forecasting abstractions, and physics-aware subsidence modeling.
In particular:
the workflow layer explains how artifacts, staged runs, evaluation tables, and calibrated outputs are prepared;
the model layer explains how generic forecasting utilities interact with sequence models and PINN-style helper logic;
the subsidence-physics layer explains how scaling metadata, hydraulic-head interpretation, and state initialization stay coherent when models are applied to land-subsidence problems.
For readers coming from a forecasting background, the
geoprior.models.utils._utils docstrings are useful because
they frame several helpers explicitly in the language of sequence
forecasting and Temporal Fusion Transformer workflows. For readers
coming from hydrogeology or subsidence physics, the
geoprior.models.subsidence.utils layer is the most direct
explanation of how physically meaningful quantities are recovered
from model-side tensors and metadata.
Source listings with comments#
These source listings are included intentionally so the inline implementation comments remain visible in the documentation.
Top-level utility exports#
r"""Public exports for GeoPrior utility helpers."""
from .audit_utils import (
audit_stage1_scaling,
audit_stage2_handshake,
should_audit,
)
from .calibrate import calibrate_quantile_forecasts
from .data_utils import (
mask_by_reference,
nan_ops,
widen_temporal_columns,
)
from .forecast_utils import (
evaluate_forecast,
format_and_forecast,
pivot_forecast_dataframe,
)
from .generic_utils import (
default_results_dir,
ensure_directory_exists,
getenv_stripped,
normalize_time_column,
print_config_table,
save_all_figures,
)
from .geo_utils import (
augment_city_spatiotemporal_data,
augment_series_features,
augment_spatiotemporal_data,
generate_dummy_pinn_data,
unpack_frames_from_file,
)
from .holdout_utils import (
compute_group_masks,
filter_df_by_groups,
split_groups_holdout,
)
from .inspect import (
ArtifactRecord,
ablation_config_frame,
ablation_metrics_frame,
ablation_per_horizon_frame,
ablation_record_flags_frame,
ablation_record_runs_frame,
artifact_brief,
as_path,
bool_checks_frame,
build_stage1_feature_split,
calibration_stats_factors_frame,
calibration_stats_overall_frame,
calibration_stats_per_horizon_frame,
clone_artifact,
deep_update,
default_ablation_record_payload,
default_calibration_stats_payload,
default_eval_diagnostics_payload,
default_eval_physics_payload,
default_manifest_payload,
default_model_init_manifest_payload,
default_physics_payload_meta_payload,
default_run_manifest_payload,
default_scaling_kwargs_payload,
default_stage1_audit_payload,
default_stage2_handshake_payload,
default_training_summary_payload,
default_xfer_results_payload,
ensure_parent_dir,
eval_overall_frame,
eval_per_horizon_frame,
eval_physics_calibration_frame,
eval_physics_calibration_per_horizon_frame,
eval_physics_censor_frame,
eval_physics_metrics_frame,
eval_physics_per_horizon_frame,
eval_physics_point_metrics_frame,
eval_physics_units_frame,
eval_years_frame,
flatten_dict,
generate_ablation_record,
generate_calibration_stats,
generate_eval_diagnostics,
generate_eval_physics,
generate_manifest,
generate_model_init_manifest,
generate_physics_payload_meta,
generate_run_manifest,
generate_scaling_kwargs,
generate_stage1_audit,
generate_stage2_handshake,
generate_training_summary,
generate_xfer_results,
infer_artifact_kind,
inspect_ablation_record,
inspect_calibration_stats,
inspect_eval_diagnostics,
inspect_eval_physics,
inspect_manifest,
inspect_model_init_manifest,
inspect_physics_payload_meta,
inspect_run_manifest,
inspect_scaling_kwargs,
inspect_stage1_audit,
inspect_stage2_handshake,
inspect_training_summary,
inspect_xfer_results,
is_number,
json_ready,
load_ablation_record,
load_artifact,
load_calibration_stats,
load_eval_diagnostics,
load_eval_physics,
load_manifest,
load_model_init_manifest,
load_physics_payload_meta,
load_run_manifest,
load_scaling_kwargs,
load_stage1_audit,
load_stage2_handshake,
load_training_summary,
load_xfer_results,
manifest_artifacts_frame,
manifest_config_frame,
manifest_feature_groups_frame,
manifest_holdout_frame,
manifest_identity_frame,
manifest_paths_frame,
manifest_shapes_frame,
manifest_versions_frame,
metrics_frame,
model_init_architecture_frame,
model_init_dims_frame,
model_init_feature_groups_frame,
model_init_geoprior_frame,
model_init_scaling_overview_frame,
nested_get,
numeric_items,
physics_payload_meta_closure_frame,
physics_payload_meta_identity_frame,
physics_payload_meta_metrics_frame,
physics_payload_meta_units_frame,
plot_ablation_boolean_summary,
plot_ablation_lambda_weights,
plot_ablation_metric_by_variant,
plot_ablation_per_horizon_metric,
plot_ablation_run_counts,
plot_ablation_top_variants,
plot_boolean_checks,
plot_calibration_boolean_summary,
plot_calibration_factors,
plot_calibration_overall_metrics,
plot_calibration_per_horizon_coverage,
plot_calibration_per_horizon_sharpness,
plot_eval_boolean_summary,
plot_eval_overall_metrics,
plot_eval_per_horizon_metrics,
plot_eval_physics_boolean_summary,
plot_eval_physics_calibration_factors,
plot_eval_physics_epsilons,
plot_eval_physics_metrics,
plot_eval_physics_per_horizon_metrics,
plot_eval_physics_point_metrics,
plot_eval_year_metric_trend,
plot_manifest_artifact_inventory,
plot_manifest_boolean_summary,
plot_manifest_coord_ranges,
plot_manifest_feature_group_sizes,
plot_manifest_holdout_counts,
plot_metric_bars,
plot_model_init_architecture,
plot_model_init_boolean_summary,
plot_model_init_dims,
plot_model_init_feature_group_sizes,
plot_model_init_geoprior,
plot_physics_payload_meta_boolean_summary,
plot_physics_payload_meta_core_scalars,
plot_physics_payload_meta_payload_metrics,
plot_run_manifest_boolean_summary,
plot_run_manifest_coord_ranges,
plot_run_manifest_feature_group_sizes,
plot_run_manifest_path_inventory,
plot_scaling_kwargs_affine_maps,
plot_scaling_kwargs_boolean_summary,
plot_scaling_kwargs_bounds,
plot_scaling_kwargs_coord_ranges,
plot_scaling_kwargs_feature_group_sizes,
plot_scaling_kwargs_schedule_scalars,
plot_series_map,
plot_stage1_boolean_summary,
plot_stage1_coord_ranges,
plot_stage1_feature_split,
plot_stage1_target_stats,
plot_stage1_variable_stats,
plot_stage2_boolean_summary,
plot_stage2_coord_range_errors,
plot_stage2_coord_stats,
plot_stage2_finite_ratios,
plot_stage2_sample_sizes,
plot_stage2_scaling_summary,
plot_training_best_metrics,
plot_training_boolean_summary,
plot_training_final_metrics,
plot_training_loss_family,
plot_training_metric_deltas,
plot_xfer_boolean_summary,
plot_xfer_direction_metric,
plot_xfer_overall_metrics,
plot_xfer_per_horizon_metrics,
plot_xfer_schema_counts,
read_json,
run_manifest_artifacts_frame,
run_manifest_config_frame,
run_manifest_identity_frame,
run_manifest_paths_frame,
run_manifest_scaling_overview_frame,
scaling_kwargs_affine_frame,
scaling_kwargs_bounds_frame,
scaling_kwargs_coord_frame,
scaling_kwargs_feature_channels_frame,
scaling_kwargs_schedule_frame,
stage1_coord_ranges_frame,
stage1_feature_split_frame,
stage1_stats_frame,
stage2_coord_range_frame,
stage2_coord_stats_frame,
stage2_finite_frame,
stage2_layout_frame,
stage2_scaling_frame,
summarize_ablation_record,
summarize_calibration_stats,
summarize_eval_diagnostics,
summarize_eval_physics,
summarize_manifest,
summarize_model_init_manifest,
summarize_physics_payload_meta,
summarize_run_manifest,
summarize_scaling_kwargs,
summarize_stage1_audit,
summarize_stage2_handshake,
summarize_training_summary,
summarize_xfer_results,
training_compile_frame,
training_env_frame,
training_hp_frame,
training_metrics_frame,
training_paths_frame,
write_json,
xfer_overall_frame,
xfer_per_horizon_frame,
xfer_schema_frame,
xfer_warm_frame,
)
from .io_utils import (
fetch_joblib_data,
save_job,
)
from .nat_utils import (
best_epoch_and_metrics,
build_censor_mask,
ensure_input_shapes,
extract_preds,
load_nat_config,
load_nat_config_payload,
load_scaler_info,
make_tf_dataset,
map_targets_for_training,
name_of,
resolve_hybrid_config,
resolve_si_affine,
save_ablation_record,
serialize_subs_params,
subs_point_from_out,
)
from .parallel_utils import (
apply_gpu_env,
apply_tf_threading,
apply_thread_env,
pick_gpu_id,
resolve_device,
resolve_gpu_ids,
resolve_n_jobs,
threads_per_job,
)
from .scale_metrics import (
evaluate_point_forecast,
inverse_scale_target,
)
from .sequence_utils import build_future_sequences_npz
from .shapes import canonicalize_BHQO
from .spatial_utils import (
create_spatial_clusters,
deg_to_m_from_lat,
spatial_sampling,
)
from .subsidence_utils import (
convert_eval_payload_units,
cumulative_to_rate,
make_txy_coords,
normalize_gwl_alias,
postprocess_eval_json,
rate_to_cumulative,
resolve_gwl_for_physics,
resolve_head_column,
)
__all__ = [
"spatial_sampling",
"create_spatial_clusters",
"augment_city_spatiotemporal_data",
"augment_series_features",
"generate_dummy_pinn_data",
"augment_spatiotemporal_data",
"mask_by_reference",
"nan_ops",
"unpack_frames_from_file",
"widen_temporal_columns",
"pivot_forecast_dataframe",
"fetch_joblib_data",
"save_job",
"normalize_time_column",
"convert_eval_payload_units",
"postprocess_eval_json",
"evaluate_point_forecast",
"inverse_scale_target",
"deg_to_m_from_lat",
"canonicalize_BHQO",
"calibrate_quantile_forecasts",
"audit_stage2_handshake",
"audit_stage1_scaling",
"should_audit",
"format_and_forecast",
"evaluate_forecast",
"default_results_dir",
"ensure_directory_exists",
"getenv_stripped",
"print_config_table",
"save_all_figures",
"build_censor_mask",
"ensure_input_shapes",
"extract_preds",
"load_nat_config",
"load_nat_config_payload",
"load_scaler_info",
"make_tf_dataset",
"map_targets_for_training",
"name_of",
"resolve_hybrid_config",
"resolve_si_affine",
"best_epoch_and_metrics",
"subs_point_from_out",
"serialize_subs_params",
"save_ablation_record",
"cumulative_to_rate",
"normalize_gwl_alias",
"rate_to_cumulative",
"resolve_gwl_for_physics",
"resolve_head_column",
"make_txy_coords",
"build_future_sequences_npz",
"compute_group_masks",
"split_groups_holdout",
"filter_df_by_groups",
"resolve_n_jobs",
"threads_per_job",
"apply_tf_threading",
"apply_thread_env",
"resolve_device",
"resolve_gpu_ids",
"pick_gpu_id",
"apply_gpu_env",
"ArtifactRecord",
"artifact_brief",
"as_path",
"bool_checks_frame",
"clone_artifact",
"deep_update",
"ensure_parent_dir",
"flatten_dict",
"infer_artifact_kind",
"is_number",
"json_ready",
"load_artifact",
"metrics_frame",
"nested_get",
"numeric_items",
"plot_boolean_checks",
"plot_metric_bars",
"plot_series_map",
"read_json",
"write_json",
"build_stage1_feature_split",
"default_stage1_audit_payload",
"generate_stage1_audit",
"inspect_stage1_audit",
"load_stage1_audit",
"plot_stage1_boolean_summary",
"plot_stage1_coord_ranges",
"plot_stage1_feature_split",
"plot_stage1_target_stats",
"plot_stage1_variable_stats",
"stage1_coord_ranges_frame",
"stage1_feature_split_frame",
"stage1_stats_frame",
"summarize_stage1_audit",
"default_stage2_handshake_payload",
"generate_stage2_handshake",
"inspect_stage2_handshake",
"load_stage2_handshake",
"plot_stage2_boolean_summary",
"plot_stage2_coord_range_errors",
"plot_stage2_coord_stats",
"plot_stage2_finite_ratios",
"plot_stage2_sample_sizes",
"plot_stage2_scaling_summary",
"stage2_coord_range_frame",
"stage2_coord_stats_frame",
"stage2_finite_frame",
"stage2_layout_frame",
"stage2_scaling_frame",
"summarize_stage2_handshake",
"default_training_summary_payload",
"generate_training_summary",
"inspect_training_summary",
"load_training_summary",
"plot_training_best_metrics",
"plot_training_boolean_summary",
"plot_training_final_metrics",
"plot_training_loss_family",
"plot_training_metric_deltas",
"training_compile_frame",
"training_env_frame",
"training_hp_frame",
"training_metrics_frame",
"training_paths_frame",
"summarize_training_summary",
"default_eval_diagnostics_payload",
"eval_overall_frame",
"eval_per_horizon_frame",
"eval_years_frame",
"generate_eval_diagnostics",
"inspect_eval_diagnostics",
"load_eval_diagnostics",
"plot_eval_boolean_summary",
"plot_eval_overall_metrics",
"plot_eval_per_horizon_metrics",
"plot_eval_year_metric_trend",
"summarize_eval_diagnostics",
"default_eval_physics_payload",
"eval_physics_calibration_frame",
"eval_physics_calibration_per_horizon_frame",
"eval_physics_censor_frame",
"eval_physics_metrics_frame",
"eval_physics_per_horizon_frame",
"eval_physics_point_metrics_frame",
"eval_physics_units_frame",
"generate_eval_physics",
"inspect_eval_physics",
"load_eval_physics",
"plot_eval_physics_boolean_summary",
"plot_eval_physics_calibration_factors",
"plot_eval_physics_epsilons",
"plot_eval_physics_metrics",
"plot_eval_physics_per_horizon_metrics",
"plot_eval_physics_point_metrics",
"summarize_eval_physics",
"default_physics_payload_meta_payload",
"generate_physics_payload_meta",
"inspect_physics_payload_meta",
"load_physics_payload_meta",
"physics_payload_meta_closure_frame",
"physics_payload_meta_identity_frame",
"physics_payload_meta_metrics_frame",
"physics_payload_meta_units_frame",
"plot_physics_payload_meta_boolean_summary",
"plot_physics_payload_meta_core_scalars",
"plot_physics_payload_meta_payload_metrics",
"summarize_physics_payload_meta",
"default_scaling_kwargs_payload",
"generate_scaling_kwargs",
"inspect_scaling_kwargs",
"load_scaling_kwargs",
"plot_scaling_kwargs_affine_maps",
"plot_scaling_kwargs_boolean_summary",
"plot_scaling_kwargs_bounds",
"plot_scaling_kwargs_coord_ranges",
"plot_scaling_kwargs_feature_group_sizes",
"plot_scaling_kwargs_schedule_scalars",
"scaling_kwargs_affine_frame",
"scaling_kwargs_bounds_frame",
"scaling_kwargs_coord_frame",
"scaling_kwargs_feature_channels_frame",
"scaling_kwargs_schedule_frame",
"summarize_scaling_kwargs",
"default_model_init_manifest_payload",
"generate_model_init_manifest",
"inspect_model_init_manifest",
"load_model_init_manifest",
"model_init_architecture_frame",
"model_init_dims_frame",
"model_init_feature_groups_frame",
"model_init_geoprior_frame",
"model_init_scaling_overview_frame",
"plot_model_init_architecture",
"plot_model_init_boolean_summary",
"plot_model_init_dims",
"plot_model_init_feature_group_sizes",
"plot_model_init_geoprior",
"summarize_model_init_manifest",
"default_run_manifest_payload",
"generate_run_manifest",
"inspect_run_manifest",
"load_run_manifest",
"plot_run_manifest_boolean_summary",
"plot_run_manifest_coord_ranges",
"plot_run_manifest_feature_group_sizes",
"plot_run_manifest_path_inventory",
"run_manifest_artifacts_frame",
"run_manifest_config_frame",
"run_manifest_identity_frame",
"run_manifest_paths_frame",
"run_manifest_scaling_overview_frame",
"summarize_run_manifest",
"default_manifest_payload",
"generate_manifest",
"inspect_manifest",
"load_manifest",
"manifest_artifacts_frame",
"manifest_config_frame",
"manifest_feature_groups_frame",
"manifest_holdout_frame",
"manifest_identity_frame",
"manifest_paths_frame",
"manifest_shapes_frame",
"manifest_versions_frame",
"plot_manifest_artifact_inventory",
"plot_manifest_boolean_summary",
"plot_manifest_coord_ranges",
"plot_manifest_feature_group_sizes",
"plot_manifest_holdout_counts",
"summarize_manifest",
"default_xfer_results_payload",
"generate_xfer_results",
"inspect_xfer_results",
"load_xfer_results",
"plot_xfer_boolean_summary",
"plot_xfer_direction_metric",
"plot_xfer_overall_metrics",
"plot_xfer_per_horizon_metrics",
"plot_xfer_schema_counts",
"summarize_xfer_results",
"xfer_overall_frame",
"xfer_per_horizon_frame",
"xfer_schema_frame",
"xfer_warm_frame",
"calibration_stats_factors_frame",
"calibration_stats_overall_frame",
"calibration_stats_per_horizon_frame",
"default_calibration_stats_payload",
"generate_calibration_stats",
"inspect_calibration_stats",
"load_calibration_stats",
"plot_calibration_boolean_summary",
"plot_calibration_factors",
"plot_calibration_overall_metrics",
"plot_calibration_per_horizon_coverage",
"plot_calibration_per_horizon_sharpness",
"summarize_calibration_stats",
"ablation_config_frame",
"ablation_metrics_frame",
"ablation_per_horizon_frame",
"ablation_record_flags_frame",
"ablation_record_runs_frame",
"default_ablation_record_payload",
"generate_ablation_record",
"inspect_ablation_record",
"load_ablation_record",
"plot_ablation_boolean_summary",
"plot_ablation_lambda_weights",
"plot_ablation_metric_by_variant",
"plot_ablation_per_horizon_metric",
"plot_ablation_run_counts",
"plot_ablation_top_variants",
"summarize_ablation_record",
]
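An export list this long is maintained by hand, so a quick consistency check can be worthwhile. The snippet below is a minimal sketch and not part of GeoPrior: `find_duplicate_exports` is a hypothetical helper, and the sample list is a short illustrative stand-in rather than the real `__all__`.

```python
from collections import Counter


def find_duplicate_exports(names):
    """Return names that occur more than once, in first-seen order."""
    counts = Counter(names)
    seen = set()
    duplicates = []
    for name in names:
        if counts[name] > 1 and name not in seen:
            seen.add(name)
            duplicates.append(name)
    return duplicates


# Illustrative stand-in for an __all__ list (not the real export list).
sample_all = [
    "read_json",
    "write_json",
    "load_manifest",
    "write_json",
]
duplicates = find_duplicate_exports(sample_all)
```

Duplicate entries in `__all__` are harmless at import time, but they make the list harder to audit, which is why a check like this is usually run in a test rather than at runtime.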
Model utility exports#
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
# Website: https://lkouadio.com
r"""Public exports for model utility helpers."""
from ._utils import (
compute_anomaly_scores,
compute_forecast_horizon,
create_sequences,
export_keras_losses,
extract_batches_from_dataset,
extract_callbacks_from,
forecast_multi_step,
forecast_single_step,
format_predictions,
format_predictions_to_dataframe,
generate_forecast,
generate_forecast_with,
get_tensor_from,
make_dict_to_tuple_fn,
prepare_model_inputs,
prepare_model_inputs_in,
prepare_spatial_future_data,
set_default_params,
split_static_dynamic,
squeeze_last_dim_if,
step_to_long,
)
from .pinn import (
PDE_MODE_ALIASES,
extract_txy,
extract_txy_in,
format_pihalnet_predictions,
format_pinn_predictions,
normalize_for_pinn,
plot_hydraulic_head,
prepare_pinn_data_sequences,
process_pde_modes,
)
__all__ = [
"PDE_MODE_ALIASES",
"compute_anomaly_scores",
"compute_forecast_horizon",
"create_sequences",
"extract_batches_from_dataset",
"extract_callbacks_from",
"forecast_multi_step",
"forecast_single_step",
"format_predictions",
"format_predictions_to_dataframe",
"generate_forecast",
"generate_forecast_with",
"prepare_model_inputs",
"prepare_model_inputs_in",
"prepare_spatial_future_data",
"set_default_params",
"split_static_dynamic",
"squeeze_last_dim_if",
"step_to_long",
"export_keras_losses",
"get_tensor_from",
"format_pihalnet_predictions",
"normalize_for_pinn",
"prepare_pinn_data_sequences",
"format_pinn_predictions",
"extract_txy",
"plot_hydraulic_head",
"make_dict_to_tuple_fn",
"extract_physical_parameters",
"extract_txy_in",
"process_pde_modes",
]
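Every name listed in `__all__` must be bound in the module namespace, otherwise `from package import *` raises `AttributeError`. The following is a minimal sketch of such a check, under the assumption that it runs in a test suite; `missing_exports` and the in-memory `demo_utils` module are hypothetical and exist only for illustration.

```python
import types


def missing_exports(module):
    """Return names in module.__all__ that the module does not define."""
    exported = getattr(module, "__all__", [])
    return [name for name in exported if not hasattr(module, name)]


# Hypothetical stand-in module, built in memory for the demonstration.
demo = types.ModuleType("demo_utils")
demo.extract_txy = lambda inputs: inputs
demo.__all__ = ["extract_txy", "not_imported_helper"]

undefined = missing_exports(demo)
```

In practice the same function can be pointed at an imported package instead of the stand-in module.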
Subsidence utility implementation#
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
GeoPrior subsidence model utilities.
"""
from __future__ import annotations
import json
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from warnings import warn
import numpy as np
from .. import KERAS_DEPS
Tensor = KERAS_DEPS.Tensor
tf_float32 = KERAS_DEPS.float32
tf_int32 = KERAS_DEPS.int32
tf_cast = KERAS_DEPS.cast
tf_constant = KERAS_DEPS.constant
tf_debugging = KERAS_DEPS.debugging
tf_equal = KERAS_DEPS.equal
tf_maximum = KERAS_DEPS.maximum
tf_minimum = KERAS_DEPS.minimum
tf_greater_equal = KERAS_DEPS.greater_equal
tf_rank = KERAS_DEPS.rank
tf_cond = KERAS_DEPS.cond
tf_shape = KERAS_DEPS.shape
tf_zeros_like = KERAS_DEPS.zeros_like
tf_ones = KERAS_DEPS.ones
tf_greater = KERAS_DEPS.greater
tf_concat = KERAS_DEPS.concat
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_ones_like = KERAS_DEPS.ones_like
tf_less_equal = KERAS_DEPS.less_equal
tf_abs = KERAS_DEPS.abs
tf_print = KERAS_DEPS.print
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile
_EPSILON = 1e-12
# ---------------------------------------------------------------------
# Scaling kwargs access helpers (alias-safe)
# ---------------------------------------------------------------------
_SK_ALIASES = {
# common naming drift
"time_units": ("time_unit",),
"cons_residual_units": ("cons_residual_unit",),
# policy drift
"scaling_error_policy": (
"error_policy",
"scaling_policy",
),
# coord drift
"coords_normalized": (
"coord_normalized",
"coords_norm",
),
"coords_in_degrees": (
"coord_in_degrees",
"coords_deg",
),
"coord_order": ("coords_order",),
"coord_ranges": ("coord_range",),
# feature-name list drift
"dynamic_feature_names": (
"dynamic_features_names",
"dyn_feature_names",
),
"future_feature_names": (
"future_features_names",
"fut_feature_names",
),
"static_feature_names": (
"static_features_names",
"stat_feature_names",
),
# feature-channel naming drift
"gwl_col": (
"gwl_dyn_name",
"gwl_dyn_col",
"gwl_name",
),
"subs_dyn_name": (
"subs_col",
"subs_dyn_col",
"subsidence_dyn_name",
),
# feature-channel index drift
"gwl_dyn_index": (
"gwl_index",
"gwl_feature_index",
"gwl_channel_index",
),
"subs_dyn_index": (
"subs_index",
"subs_feature_index",
"subs_channel_index",
),
# z_surf drift
"z_surf_col": (
"z_surf_key",
"z_surf_name",
),
# bounds drift (often nested under scaling_kwargs['bounds'])
"log_tau_min": (
"logTau_min",
"logtau_min",
),
"log_tau_max": (
"logTau_max",
"logtau_max",
),
"tau_min": (
"Tau_min",
"tauMin",
"tau_min_sec",
"tau_min_seconds",
),
"tau_max": (
"Tau_max",
"tauMax",
"tau_max_sec",
"tau_max_seconds",
),
"tau_min_units": (
"tau_min_time_units",
"tau_min_in_time_units",
),
"tau_max_units": (
"tau_max_time_units",
"tau_max_in_time_units",
),
"Q_length_in_si": ("Q_in_m_per_s",),
}
_SK_ALIASES.update(
{
"cons_drawdown_mode": (
"drawdown_mode",
"cons_delta_mode",
),
"cons_drawdown_rule": (
"drawdown_rule",
"cons_delta_rule",
),
"cons_stop_grad_ref": (
"stop_grad_ref",
"cons_stopgrad_ref",
),
"cons_drawdown_zero_at_origin": (
"drawdown_zero_at_origin",
"cons_zero_at_origin",
),
"cons_drawdown_clip_max": (
"drawdown_clip_max",
"cons_clip_max",
),
"cons_relu_beta": (
"relu_beta",
"cons_beta",
),
}
)
# MV prior drift (mode/weight/warmup + loss knobs)
_SK_ALIASES.update(
{
"mv_prior_mode": (
"mv_mode",
"mvprior_mode",
"mv_prior_kind",
),
"mv_weight": (
"mv_prior_weight",
"mvprior_weight",
"mv_w",
),
"mv_warmup_steps": (
"mv_prior_warmup_steps",
"mv_warmup_iters",
"mv_warmup_iterations",
),
"mv_alpha_disp": (
"mv_prior_alpha_disp",
"mv_disp_alpha",
"mv_alpha",
),
"mv_huber_delta": (
"mv_prior_huber_delta",
"mv_delta",
"mv_huber",
),
"mv_prior_units": (
"mv_units",
"mv_gamma_units",
"mv_gw_units",
),
}
)
def enforce_scaling_alias_consistency(
scaling_kwargs: dict[str, Any] | None,
*,
where: str = "validate",
) -> None:
"""
Enforce that canonical keys and aliases agree.
If both canonical and an alias exist and their
values differ, apply the scaling error policy.
"""
sk = scaling_kwargs or {}
for key, aliases in _SK_ALIASES.items():
if key not in sk:
continue
v0 = sk.get(key, None)
if v0 is None:
continue
for a in aliases:
if a not in sk:
continue
va = sk.get(a, None)
if va is None:
continue
if va != v0:
msg = (
"Conflicting scaling keys: "
f"{key!r}={v0!r} != {a!r}={va!r}."
)
_handle_scaling_issue(
sk,
msg,
where=where,
)
def canonicalize_scaling_kwargs(
scaling_kwargs: dict[str, Any] | None,
*,
copy: bool = True,
) -> dict[str, Any]:
"""
Return a canonicalized scaling dict.
- If a canonical key is missing, but one of its
aliases exists, copy alias -> canonical.
- Keeps existing canonical values unchanged.
"""
sk0 = scaling_kwargs or {}
sk = dict(sk0) if copy else sk0
for key, aliases in _SK_ALIASES.items():
if key in sk and sk.get(key, None) is not None:
continue
for a in aliases:
if a in sk and sk.get(a, None) is not None:
sk[key] = sk[a]
break
return sk
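The alias fallback performed by `canonicalize_scaling_kwargs` can be sketched in isolation. This is a simplified, dependency-free illustration rather than the library code: `ALIASES` is a two-entry excerpt of `_SK_ALIASES`, `canonicalize` is a hypothetical name, and the configuration values are made up.

```python
# Two-entry excerpt of the alias table, for illustration only.
ALIASES = {
    "time_units": ("time_unit",),
    "gwl_col": ("gwl_dyn_name", "gwl_dyn_col", "gwl_name"),
}


def canonicalize(cfg, aliases=ALIASES):
    """Copy the first non-None alias value onto a missing canonical key."""
    out = dict(cfg or {})  # shallow copy, the input is left untouched
    for key, alts in aliases.items():
        if out.get(key) is not None:
            continue  # an existing canonical value always wins
        for alt in alts:
            if out.get(alt) is not None:
                out[key] = out[alt]
                break
    return out


# Made-up configuration: one alias key and one canonical key set to None.
raw = {"time_unit": "year", "gwl_name": "GWL_depth_bgs", "gwl_col": None}
canon = canonicalize(raw)
```

Note that alias keys are kept alongside the canonical ones, which is what allows a later consistency check to compare the two.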
def load_scaling_kwargs(
scaling_kwargs: Any | None,
*,
copy: bool = True,
) -> dict[str, Any]:
"""
Load scaling kwargs from a dict-like object or JSON.
Supported inputs
----------------
- dict / Mapping:
Returned (copied by default).
- str:
* If it looks like JSON ("{...}" or "[...]"), parse as JSON.
* Else treat as a filesystem path to a JSON file.
- pathlib.Path:
Treated as a filesystem path to a JSON file.
- None:
Returns {}.
Parameters
----------
scaling_kwargs : Any
Scaling configuration input. Can be a dict, JSON string,
path to JSON file, or None.
copy : bool, default=True
If True, returns a shallow copy of the dict.
Returns
-------
dict
Parsed scaling kwargs as a Python dict.
Raises
------
TypeError
If the input type is unsupported.
ValueError
If JSON parsing fails or JSON does not decode to a dict.
FileNotFoundError
If a JSON path is given but does not exist.
"""
if scaling_kwargs is None:
return {}
if isinstance(scaling_kwargs, Mapping):
return (
dict(scaling_kwargs) if copy else scaling_kwargs
)
if isinstance(scaling_kwargs, Path):
path = scaling_kwargs
text = path.read_text(encoding="utf-8")
obj = json.loads(text)
if not isinstance(obj, dict):
raise ValueError(
"Scaling JSON must decode to an object/dict, "
f"got {type(obj).__name__}."
)
return obj
if isinstance(scaling_kwargs, str):
s = scaling_kwargs.strip()
# 1) Inline JSON object/array.
if (s.startswith("{") and s.endswith("}")) or (
s.startswith("[") and s.endswith("]")
):
try:
obj = json.loads(s)
except json.JSONDecodeError as e:
raise ValueError(
"Invalid scaling_kwargs JSON string."
) from e
if not isinstance(obj, dict):
raise ValueError(
"Scaling JSON must decode to an object/dict, "
f"got {type(obj).__name__}."
)
return obj
# 2) Treat as file path to JSON.
path = Path(s).expanduser()
if not path.exists():
raise FileNotFoundError(
f"Scaling kwargs JSON file not found: {str(path)!r}."
)
text = path.read_text(encoding="utf-8")
try:
obj = json.loads(text)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid JSON in scaling kwargs file: {str(path)!r}."
) from e
if not isinstance(obj, dict):
raise ValueError(
"Scaling JSON file must decode to an object/dict, "
f"got {type(obj).__name__}."
)
return obj
try:
obj = dict(scaling_kwargs)
except Exception as e:
raise TypeError(
"scaling_kwargs must be a dict/Mapping, JSON string, "
"Path, or a path string to a JSON file."
) from e
return obj
def get_sk(
scaling_kwargs,
key: str,
*aliases: str,
default=None,
required: bool = False,
cast=None,
):
"""
Fetch a key from `scaling_kwargs` with aliases + default.
- Tries: key -> built-in aliases -> explicit aliases
- Treats None and blank strings as "missing" and keeps searching.
"""
sk = scaling_kwargs or {}
if not isinstance(sk, Mapping):
try:
sk = dict(sk)
except Exception:
sk = {}
cand = [key]
cand.extend(_SK_ALIASES.get(key, ()))
cand.extend([a for a in aliases if a])
for k in cand:
if k in sk:
v = sk[k]
if v is None:
continue
if isinstance(v, str) and not v.strip():
continue
if cast is not None:
try:
v = cast(v)
except Exception as e:
raise ValueError(
f"Invalid scaling_kwargs[{k!r}]={v!r}."
) from e
return v
if required:
alias_txt = (
", ".join(repr(x) for x in cand[1:]) or "none"
)
raise ValueError(
f"Missing required scaling key {key!r} (aliases: {alias_txt})."
)
if cast is not None and default is not None:
try:
return cast(default)
except Exception:
return default
return default
def _norm_policy(policy: str | None) -> str:
"""
Normalize scaling error policy.
Allowed:
- 'ignore'
- 'warn' (default)
- 'raise'
"""
p = (policy or "warn").strip().lower()
if p not in ("ignore", "warn", "raise"):
p = "warn"
return p
def _handle_scaling_issue(
scaling_kwargs: dict[str, Any] | None,
message: str,
*,
where: str = "validate",
) -> None:
"""
Apply scaling error policy.
Notes
-----
Runtime fallback paths must never crash: even when the
policy is 'raise', they still fall back to zeros.
The policy is therefore applied as follows:
- where='validate': obey ignore/warn/raise
- any other value (e.g. 'runtime'): treat 'raise' as 'warn'
"""
sk = scaling_kwargs or {}
policy = _norm_policy(
get_sk(sk, "scaling_error_policy", default="warn")
)
# Runtime must not crash; still fall back later.
if where != "validate" and policy == "raise":
policy = "warn"
if policy == "ignore":
return
if policy == "warn":
warn(
message,
category=RuntimeWarning,
stacklevel=2,
)
return
# validate + raise
raise ValueError(message)
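The policy dispatch above, including the downgrade of `'raise'` to `'warn'` outside validation, can be exercised with the standard library alone. This is a sketch under the assumption that the policy is passed in directly instead of being read from `scaling_kwargs`; `apply_policy` and its return strings are hypothetical.

```python
import warnings


def apply_policy(policy, message, where="validate"):
    """Return the action taken ('ignored' or 'warned'); may raise ValueError."""
    p = (policy or "warn").strip().lower()
    if p not in ("ignore", "warn", "raise"):
        p = "warn"  # unknown policies degrade to a warning
    if where != "validate" and p == "raise":
        p = "warn"  # runtime paths must not crash
    if p == "ignore":
        return "ignored"
    if p == "warn":
        warnings.warn(message, category=RuntimeWarning, stacklevel=2)
        return "warned"
    raise ValueError(message)


# At runtime, a 'raise' policy only produces a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    runtime_action = apply_policy("raise", "missing z_surf", where="runtime")
```

The same call with `where="validate"` raises `ValueError`, so strict configurations fail early during validation while training and inference keep running.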
def _is_deg_mode(mode: str) -> bool:
m = (mode or "").strip().lower()
return m in {
"deg",
"degree",
"degrees",
"lonlat",
"latlon",
}
def _validate_scaling_kwargs(scaling_kwargs):
sk = canonicalize_scaling_kwargs(scaling_kwargs)
enforce_scaling_alias_consistency(sk, where="validate")
mode = str(sk.get("coord_mode", ""))
deg_mode = _is_deg_mode(mode)
deg_flag = bool(sk.get("coords_in_degrees", False))
if deg_mode != deg_flag:
msg = (
"Inconsistent coord flags: "
f"coord_mode={mode!r} but "
f"coords_in_degrees={deg_flag}. "
"Decide: degrees(+deg_to_m_*) or "
"projected meters (coords_in_degrees=False)."
)
_handle_scaling_issue(sk, msg, where="validate")
epsg_used = sk.get("coord_epsg_used", None)
if deg_flag and (epsg_used not in (None, 4326)):
msg = (
"coords_in_degrees=True but "
f"coord_epsg_used={epsg_used!r} "
"looks projected. If you already "
"reprojected, set coords_in_degrees=False."
)
_handle_scaling_issue(sk, msg, where="validate")
def validate_scaling_kwargs(
scaling_kwargs: dict[str, Any] | None,
) -> None:
"""
Basic scaling sanity checks.
This includes policy-controlled heuristic checks
for common "silent fallback" cases.
"""
sk = canonicalize_scaling_kwargs(scaling_kwargs)
enforce_scaling_alias_consistency(sk, where="validate")
# --------------------------------------------------
# Degrees mode requires meters-per-degree factors.
# --------------------------------------------------
if bool(sk.get("coords_in_degrees", False)):
for key in ("deg_to_m_lon", "deg_to_m_lat"):
val = sk.get(key, None)
if val is None:
msg = (
"coords_in_degrees=True but missing "
f"scaling_kwargs[{key!r}]."
)
raise ValueError(msg)
try:
v = float(val)
except (TypeError, ValueError) as e:
raise ValueError(
f"Invalid {key!r}={val!r}."
) from e
if not np.isfinite(v) or v <= 0.0:
raise ValueError(f"Invalid {key!r}={v}.")
# --------------------------------------------------
# Normalized coords require coord_ranges.
# --------------------------------------------------
if bool(
sk.get("coords_normalized", False)
) and not sk.get(
"coord_ranges",
None,
):
raise ValueError(
"coords_normalized=True but coord_ranges missing."
)
# --------------------------------------------------
# Require time units (alias-safe).
# --------------------------------------------------
if get_sk(sk, "time_units", default=None) is None:
raise ValueError(
"time_units missing in scaling_kwargs."
)
# --------------------------------------------------
# Heuristic checks (policy-controlled).
# --------------------------------------------------
names = sk.get("dynamic_feature_names", None)
names = list(names) if names is not None else []
# A) Subsidence init: detect cum subs channel.
has_subs_cum = any(
("subs" in str(n).lower() and "cum" in str(n).lower())
for n in names
)
subs_idx = sk.get("subs_dyn_index", None)
subs_name = get_sk(sk, "subs_dyn_name", default=None)
meta = sk.get("gwl_z_meta", {}) or {}
cols = meta.get("cols", {}) or {}
subs_meta = cols.get("subs_model", None)
if (
has_subs_cum
and subs_idx is None
and subs_name is None
):
if subs_meta is None:
msg = (
"dynamic_feature_names contains a cumulative "
"subsidence channel, but no subs_dyn_index/"
"subs_dyn_name and no gwl_z_meta.cols.subs_model. "
"Initial settlement will fall back to zeros."
)
_handle_scaling_issue(
sk,
msg,
where="validate",
)
# B) Depth->head conversion needs z_surf when proxy=False.
kind = str(sk.get("gwl_kind", "")).lower()
proxy = bool(sk.get("use_head_proxy", True))
if (not proxy) and (
kind not in ("head", "waterhead", "hydraulic_head")
):
z_col = sk.get("z_surf_col", None)
z_col = z_col or meta.get("z_surf_col", None)
z_static = cols.get("z_surf_static", None)
z_idx = sk.get("z_surf_static_index", None)
static_names = get_sk(
sk,
"static_feature_names",
default=None,
)
# If you did not provide a way to locate z_surf
# in static features, conversion may fall back.
if (
z_idx is None
and static_names is None
and z_col is not None
and z_static is not None
and z_col != z_static
):
msg = (
"use_head_proxy=False and gwl_kind is depth-like, "
"but z_surf_col differs from gwl_z_meta.cols."
"z_surf_static, and no static_feature_names/"
"z_surf_static_index provided. Depth->head "
"conversion may fall back to depth."
)
_handle_scaling_issue(
sk,
msg,
where="validate",
)
def affine_from_cfg(
scaling_kwargs: dict[str, Any] | None,
*,
scale_key: str,
bias_key: str,
meta_keys: tuple[str, ...] = (),
unit_key: str | None = None,
) -> tuple[Tensor, Tensor]:
"""Return (a,b) for y_si = y_model*a + b."""
cfg = scaling_kwargs or {}
a = cfg.get(scale_key, None)
b = cfg.get(bias_key, None)
if a is not None or b is not None:
a = 1.0 if a is None else float(a)
b = 0.0 if b is None else float(b)
return tf_constant(a, tf_float32), tf_constant(
b, tf_float32
)
for mk in meta_keys:
meta = cfg.get(mk, None)
if isinstance(meta, dict):
mu = meta.get("mu", meta.get("mean", None))
sig = meta.get("sigma", meta.get("std", None))
if mu is not None and sig is not None:
return (
tf_constant(float(sig), tf_float32),
tf_constant(float(mu), tf_float32),
)
if unit_key is not None:
u = float(cfg.get(unit_key, 1.0))
return tf_constant(u, tf_float32), tf_constant(
0.0, tf_float32
)
return tf_constant(1.0, tf_float32), tf_constant(
0.0, tf_float32
)
def to_si_thickness(
H_model: Tensor,
scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
"""Convert thickness to SI."""
a, b = affine_from_cfg(
scaling_kwargs,
scale_key="H_scale_si",
bias_key="H_bias_si",
meta_keys=("H_z_meta",),
unit_key="thickness_unit_to_si",
)
return tf_cast(H_model, tf_float32) * a + b
def to_si_head(
h_model: Tensor,
scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
"""Convert head/depth to SI meters."""
a, b = affine_from_cfg(
scaling_kwargs,
scale_key="head_scale_si",
bias_key="head_bias_si",
meta_keys=("head_z_meta", "gwl_z_meta"),
unit_key="head_unit_to_si",
)
return tf_cast(h_model, tf_float32) * a + b
def to_si_subsidence(
s_model: Tensor,
scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
"""Convert subsidence to SI meters."""
a, b = affine_from_cfg(
scaling_kwargs,
scale_key="subs_scale_si",
bias_key="subs_bias_si",
meta_keys=("subs_z_meta",),
unit_key="subs_unit_to_si",
)
return tf_cast(s_model, tf_float32) * a + b
def from_si_subsidence(
s_si: Tensor,
scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
"""Inverse of to_si_subsidence: s_model = (s_si - b) / a."""
a, b = affine_from_cfg(
scaling_kwargs,
scale_key="subs_scale_si",
bias_key="subs_bias_si",
meta_keys=("subs_z_meta",),
unit_key="subs_unit_to_si",
)
eps = tf_constant(_EPSILON, tf_float32)
return (tf_cast(s_si, tf_float32) - b) / (a + eps)
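The precedence used to resolve the affine pair `(a, b)` is: explicit scale/bias keys, then z-score metadata, then a unit factor, then the identity. A plain-float sketch of that precedence and of the forward and inverse maps follows. It deliberately avoids tensors; `resolve_affine`, `to_si`, and `from_si` are hypothetical names, and the numeric values are invented for the example.

```python
def resolve_affine(cfg, scale_key, bias_key, meta_keys=(), unit_key=None):
    """Return (a, b) such that y_si = y_model * a + b."""
    cfg = cfg or {}
    a = cfg.get(scale_key)
    b = cfg.get(bias_key)
    if a is not None or b is not None:  # 1) explicit scale and/or bias
        return (1.0 if a is None else float(a), 0.0 if b is None else float(b))
    for mk in meta_keys:  # 2) z-score metadata: a = sigma, b = mu
        meta = cfg.get(mk)
        if isinstance(meta, dict):
            mu = meta.get("mu", meta.get("mean"))
            sig = meta.get("sigma", meta.get("std"))
            if mu is not None and sig is not None:
                return float(sig), float(mu)
    if unit_key is not None:  # 3) pure unit conversion factor
        return float(cfg.get(unit_key, 1.0)), 0.0
    return 1.0, 0.0  # 4) identity


def to_si(y_model, a, b):
    return y_model * a + b


def from_si(y_si, a, b, eps=1e-12):
    return (y_si - b) / (a + eps)


# Invented values: the metadata takes precedence over the unit factor.
cfg = {"subs_z_meta": {"mean": 12.0, "std": 4.0}, "subs_unit_to_si": 0.001}
a, b = resolve_affine(
    cfg,
    "subs_scale_si",
    "subs_bias_si",
    meta_keys=("subs_z_meta",),
    unit_key="subs_unit_to_si",
)
s_si = to_si(0.5, a, b)
s_back = from_si(s_si, a, b)
```

Because of the small epsilon in the denominator, the round trip is exact only up to floating-point tolerance.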
def deg_to_m(
axis: str,
scaling_kwargs: dict[str, Any] | None,
) -> Tensor:
"""
Meters per degree factor for lon/lat coords.
If coords_in_degrees=True and deg_to_m_lon/lat are missing, we try
to compute them from lat0_deg (recommended).
"""
if axis not in ("x", "y"):
raise ValueError(
f"deg_to_m: axis must be 'x' or 'y', got {axis!r}."
)
cfg = scaling_kwargs or {}
if not bool(cfg.get("coords_in_degrees", False)):
return tf_constant(1.0, tf_float32)
key = "deg_to_m_lon" if axis == "x" else "deg_to_m_lat"
val = cfg.get(key, None)
if val is None:
lat0 = cfg.get("lat0_deg", None)
if lat0 is None:
raise ValueError(
"coords_in_degrees=True but missing deg_to_m_lon/deg_to_m_lat "
"and lat0_deg (needed for lon scaling)."
)
lat0 = float(lat0)
if axis == "x":
v = 111320.0 * float(np.cos(np.deg2rad(lat0)))
else:
v = 110574.0
return tf_constant(v, tf_float32)
try:
v = float(val)
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid {key!r}={val!r}.") from e
if not np.isfinite(v) or v <= 0.0:
raise ValueError(f"Invalid {key!r}={v}.")
return tf_constant(v, tf_float32)
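When the explicit factors are missing, the fallback derives meters per degree from a reference latitude using the constants 111320.0 and 110574.0. The sketch below reproduces just that arithmetic with the standard library; `meters_per_degree` is a hypothetical name and the latitudes are arbitrary examples.

```python
import math


def meters_per_degree(axis, lat0_deg):
    """Approximate meters per degree of longitude ('x') or latitude ('y')."""
    if axis not in ("x", "y"):
        raise ValueError(f"axis must be 'x' or 'y', got {axis!r}.")
    if axis == "x":
        # A degree of longitude shrinks with the cosine of the latitude.
        return 111320.0 * math.cos(math.radians(lat0_deg))
    # A degree of latitude is treated as constant.
    return 110574.0


lon_factor_equator = meters_per_degree("x", 0.0)
lon_factor_60 = meters_per_degree("x", 60.0)
lat_factor = meters_per_degree("y", 60.0)
```

At 60 degrees latitude the longitude factor is roughly half its equatorial value, which is why a single factor for both axes would distort spatial derivatives in degree-based coordinates.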
def coord_ranges(
scaling_kwargs: dict[str, Any] | None,
) -> tuple[float | None, float | None, float | None]:
"""Return (tR,xR,yR) if coords_normalized."""
cfg = scaling_kwargs or {}
if not bool(cfg.get("coords_normalized", False)):
return None, None, None
r = cfg.get("coord_ranges", {}) or {}
def get(name: str, *alts: str) -> float | None:
v = r.get(name, None)
if v is None:
for a in alts:
v = cfg.get(a, None)
if v is not None:
break
return None if v is None else float(v)
tR = get("t", "t_range", "coord_range_t")
xR = get("x", "x_range", "coord_range_x")
yR = get("y", "y_range", "coord_range_y")
return tR, xR, yR
def resolve_gwl_dyn_index(
scaling_kwargs: dict[str, Any] | None,
) -> int:
"""Resolve GWL channel index for dynamic_features."""
sk = scaling_kwargs or {}
idx = sk.get("gwl_dyn_index", None)
if idx is not None:
return int(idx)
names = sk.get("dynamic_feature_names", None)
gwl_col = get_sk(sk, "gwl_col", default=None)
if names is not None and gwl_col is not None:
names = list(names)
if gwl_col in names:
return int(names.index(gwl_col))
raise ValueError(
"Cannot resolve GWL channel. Provide gwl_dyn_index "
"or dynamic_feature_names + gwl_col."
)
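Channel resolution follows a simple order: an explicit index wins, otherwise the channel name is looked up in `dynamic_feature_names`. A standalone sketch of that order is shown here; `resolve_channel_index` is a hypothetical generalization, and the feature names are made up.

```python
def resolve_channel_index(cfg, index_key, name_key):
    """Resolve a dynamic-feature channel index from an index or a name."""
    cfg = cfg or {}
    idx = cfg.get(index_key)
    if idx is not None:
        return int(idx)  # an explicit index takes precedence
    names = cfg.get("dynamic_feature_names")
    col = cfg.get(name_key)
    if names is not None and col is not None:
        names = list(names)
        if col in names:
            return names.index(col)
    raise ValueError(
        f"Cannot resolve channel: provide {index_key!r} or "
        f"dynamic_feature_names + {name_key!r}."
    )


# Made-up feature names for illustration.
cfg = {
    "dynamic_feature_names": ["rainfall_mm", "GWL_depth_bgs_m", "subs_cum_mm"],
    "gwl_col": "GWL_depth_bgs_m",
}
gwl_idx = resolve_channel_index(cfg, "gwl_dyn_index", "gwl_col")
```

Unlike this sketch, the library function also resolves `gwl_col` through its aliases via `get_sk`.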
def get_gwl_dyn_index_cached(model) -> int:
"""Cache gwl_dyn_index on model after first resolve."""
idx = getattr(model, "gwl_dyn_index", None)
if idx is None:
idx = resolve_gwl_dyn_index(
getattr(
model,
"scaling_kwargs",
None,
)
)
model.gwl_dyn_index = int(idx)
return int(idx)
def resolve_subs_dyn_index(scaling_kwargs):
"""Resolve subsidence channel index for dynamic_features.
This is optional: v3.2 can use historical subsidence as a dynamic
driver to provide a physics-friendly initial condition for the mean
settlement path.
"""
sk = scaling_kwargs or {}
idx = sk.get("subs_dyn_index", None)
if idx is not None:
return int(idx)
names = sk.get("dynamic_feature_names", None)
subs_col = get_sk(sk, "subs_dyn_name", default=None)
# Fallback: use gwl_z_meta.cols.subs_model.
if subs_col is None:
meta = sk.get("gwl_z_meta", {}) or {}
cols = meta.get("cols", {}) or {}
subs_col = cols.get("subs_model", None)
if names is not None and subs_col is not None:
names = list(names)
if subs_col in names:
return int(names.index(subs_col))
raise ValueError(
"Cannot resolve subsidence channel. Provide subs_dyn_index "
"or dynamic_feature_names + subs_dyn_name (or gwl_z_meta.cols.subs_model)."
)
def get_subs_dyn_index_cached(model) -> int:
"""Cache subs_dyn_index on model after first resolve."""
idx = getattr(model, "subs_dyn_index", None)
if idx is None:
idx = resolve_subs_dyn_index(
getattr(model, "scaling_kwargs", None)
)
model.subs_dyn_index = int(idx)
return int(idx)
def slice_dynamic_channel(Xh: Tensor, idx: int) -> Tensor:
"""Slice (B,T,F) -> (B,T,1) at idx."""
idx_t = tf_cast(idx, tf_int32)
F = tf_shape(Xh)[-1]
tf_debugging.assert_less(
idx_t,
F,
message="dynamic channel index out of range.",
)
return Xh[:, :, idx_t : idx_t + 1]
def assert_dynamic_names_match_tensor(
Xh: Tensor,
scaling_kwargs: dict[str, Any] | None,
) -> None:
"""Check dynamic_feature_names length matches Xh."""
sk = scaling_kwargs or {}
names = sk.get("dynamic_feature_names", None)
if names is None:
return
n = len(list(names))
tf_debugging.assert_equal(
tf_shape(Xh)[-1],
tf_constant(n, tf_int32),
message="dynamic_feature_names != Xh last dim",
)
def gwl_to_head_m(
v_m: Tensor,
scaling_kwargs: dict[str, Any] | None,
*,
inputs: dict[str, Tensor] | None = None,
) -> Tensor:
"""
Convert depth-bgs to head if possible.
Behavior
--------
- If gwl_kind is head-like: return v_m.
- Otherwise treat as depth and try:
head = z_surf - depth.
- If z_surf is missing:
* use_head_proxy=True -> return -depth
* use_head_proxy=False -> return depth
"""
sk = scaling_kwargs or {}
# --------------------------------------------------
# 1) Decide whether v_m is head or depth.
# --------------------------------------------------
kind_raw = sk.get("gwl_kind", None)
if kind_raw is None or str(kind_raw).strip() == "":
gwl_col = str(get_sk(sk, "gwl_col", default=""))
gwl_col = gwl_col.lower()
kind = "depth" if ("depth" in gwl_col) else "head"
else:
kind = str(kind_raw).lower()
if kind in ("head", "waterhead", "hydraulic_head"):
return tf_cast(v_m, tf_float32)
# --------------------------------------------------
# 2) Depth convention + proxy behavior.
# --------------------------------------------------
sign = str(sk.get("gwl_sign", "down_positive")).lower()
proxy = bool(sk.get("use_head_proxy", True))
# --------------------------------------------------
# 3) Collect possible z_surf keys.
# Prefer SI/static key first when available.
# --------------------------------------------------
meta = sk.get("gwl_z_meta", {}) or {}
cols = meta.get("cols", {}) or {}
z_surf_col = sk.get("z_surf_col", None)
z_surf_col = z_surf_col or meta.get("z_surf_col", None)
z_surf_static = cols.get("z_surf_static", None)
z_surf_raw = cols.get("z_surf_raw", None)
z_surf_keys = [
k
for k in (z_surf_static, z_surf_col, z_surf_raw)
if k
]
# Dedupe while preserving order.
seen = set()
z_surf_keys = [
k
for k in z_surf_keys
if not (k in seen or seen.add(k))
]
# --------------------------------------------------
# 4) Convert to positive-down depth.
# --------------------------------------------------
v_m = tf_cast(v_m, tf_float32)
depth_m = v_m if sign == "down_positive" else -v_m
# --------------------------------------------------
# 5) Try direct inputs[z_surf_key] first.
# --------------------------------------------------
z_surf = None
if inputs is not None:
for k in z_surf_keys:
z_surf = inputs.get(k, None)
if z_surf is not None:
z_surf = tf_cast(z_surf, tf_float32)
break
# --------------------------------------------------
# 6) If missing, try static_features lookup.
# --------------------------------------------------
if z_surf is None and inputs is not None:
sf = inputs.get("static_features", None)
if sf is not None:
sf = tf_cast(sf, tf_float32)
idx = sk.get("z_surf_static_index", None)
if idx is None:
names = get_sk(
sk,
"static_feature_names",
default=None,
)
if names is not None:
names = list(names)
for k in z_surf_keys:
if k in names:
idx = int(names.index(k))
break
if idx is not None:
idx_i = int(idx)
tf_debugging.assert_less(
tf_cast(idx_i, tf_int32),
tf_shape(sf)[-1],
message="z_surf_static_index out of range.",
)
r = getattr(sf.shape, "rank", None)
if r == 2:
z_surf = sf[:, idx_i : idx_i + 1]
elif r == 3:
z_surf = sf[:, :, idx_i : idx_i + 1]
else:
rr = tf_rank(sf)
z_surf = tf_cond(
tf_equal(rr, 2),
lambda: sf[:, idx_i : idx_i + 1],
lambda: sf[:, :, idx_i : idx_i + 1],
)
if z_surf is None:
# if bool(sk.get("debug_units", False)):
tf_print(
"[gwl_to_head_m] z_surf missing ->",
"use_head_proxy=",
proxy,
"returning depth-like quantity (NOT true head)",
)
# --------------------------------------------------
# 7) If we have z_surf: head = z_surf - depth.
# --------------------------------------------------
if z_surf is not None:
r = tf_rank(z_surf)
z_surf = tf_cond(
tf_equal(r, 1),
lambda: z_surf[:, None, None],
lambda: tf_cond(
tf_equal(r, 2),
lambda: z_surf[:, None, :],
lambda: z_surf,
),
)
# Broadcast z_surf to match depth_m.
z_surf = z_surf + tf_zeros_like(depth_m)
return z_surf - depth_m
# --------------------------------------------------
# 8) Fallback: proxy head or keep depth.
# --------------------------------------------------
return -depth_m if proxy else depth_m
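The depth branch that ends above comes down to a few lines of arithmetic. The sketch below restates it with plain Python floats instead of tensors, purely as an illustration. The name `depth_to_head_sketch` and its arguments are hypothetical and not part of GeoPrior, and shape broadcasting is left out.

```python
def depth_to_head_sketch(value_m, z_surf=None, sign="down_positive", proxy=True):
    """Scalar restatement of the depth -> head logic (illustrative only)."""
    # Step 4: express the measurement as positive-down depth.
    depth_m = value_m if sign == "down_positive" else -value_m
    # Step 7: with a surface elevation, head = z_surf - depth.
    if z_surf is not None:
        return z_surf - depth_m
    # Step 8: no elevation available, so return a proxy (-depth) or raw depth.
    return -depth_m if proxy else depth_m


# Water table 12 m below a surface at 30 m elevation.
head_a = depth_to_head_sketch(12.0, z_surf=30.0)
# The same level recorded with an up-positive sign convention.
head_b = depth_to_head_sketch(-12.0, z_surf=30.0, sign="up_positive")
# No surface elevation: only the proxy is available.
head_c = depth_to_head_sketch(12.0)
print(head_a, head_b, head_c)
```

When `z_surf` is missing the result is not a true hydraulic head, which is why the function above prints a warning in that case.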
def _reshape_to_b11(v: Tensor) -> Tensor:
    """Coerce a tensor to rank 3: (B,) -> (B,1,1), (B,F) -> (B,1,F)."""
    v = tf_cast(v, tf_float32)
    r = tf_rank(v)
    return tf_cond(
        tf_equal(r, 1),
        lambda: v[:, None, None],
        lambda: tf_cond(
            tf_equal(r, 2),
            lambda: v[:, None, :],
            lambda: v,
        ),
    )
def get_h_hist_si(
model,
inputs: dict[str, Tensor],
*,
want_head: bool = True,
) -> Tensor:
"""Return head (or depth) history in SI meters.
Parameters
----------
model : object
The model instance (provides ``scaling_kwargs`` and cached indices).
inputs : dict
Batch inputs; expects ``dynamic_features`` unless an explicit
head history key is provided.
want_head : bool, default=True
If True, convert depth-bgs to hydraulic head when possible.
Returns
-------
Tensor
(B,T,1) tensor in SI meters.
"""
sk = getattr(model, "scaling_kwargs", None)
# Explicit override (useful for scenario-driven runs)
for k in ("h_hist_si", "head_hist_si", "gwl_hist_si"):
if k in inputs and inputs[k] is not None:
v = tf_cast(inputs[k], tf_float32)
# (B,T) -> (B,T,1)
if tf_equal(tf_rank(v), 2):
v = v[:, :, None]
if want_head:
v = gwl_to_head_m(v, sk, inputs=inputs)
return v
Xh = inputs.get("dynamic_features", None)
if Xh is None:
raise ValueError(
"Cannot build head history: missing inputs['dynamic_features'] "
"and no explicit head history key "
"(h_hist_si/head_hist_si/gwl_hist_si)."
)
Xh = tf_cast(Xh, tf_float32)
assert_dynamic_names_match_tensor(Xh, sk)
gwl_idx = get_gwl_dyn_index_cached(model)
gwl = slice_dynamic_channel(Xh, gwl_idx)
gwl_si = to_si_head(gwl, sk)
return (
gwl_to_head_m(gwl_si, sk, inputs=inputs)
if want_head
else gwl_si
)
def get_s_init_si(
model,
inputs: dict[str, Tensor] | None,
like: Tensor,
) -> Tensor:
"""Return initial settlement (cumulative subsidence) in SI meters.
Priority:
1) explicit keys in inputs (s_init_si/subs_hist_last_si/...)
2) last historical value from dynamic_features if subs_dyn_index exists
3) zeros (broadcast)
"""
sk = getattr(model, "scaling_kwargs", None)
if inputs is not None:
for k in (
"s_init_si",
"subs_init_si",
"subs_hist_last_si",
"s_ref_si",
"subs_ref_si",
"s_init",
"subs_init",
):
if k in inputs and inputs[k] is not None:
return _reshape_to_b11(
inputs[k]
) + tf_zeros_like(like)
Xh = inputs.get("dynamic_features", None)
if Xh is not None:
try:
subs_idx = get_subs_dyn_index_cached(model)
except Exception as e:
_handle_scaling_issue(
getattr(model, "scaling_kwargs", None),
f"Could not resolve subsidence init channel ({e}). "
"Falling back to zeros for s_init_si.",
where="runtime",
)
subs_idx = None
if subs_idx is not None:
Xh = tf_cast(Xh, tf_float32)
assert_dynamic_names_match_tensor(Xh, sk)
s_hist = slice_dynamic_channel(
Xh, int(subs_idx)
)
s_last = s_hist[:, -1:, :]
s_last_si = to_si_subsidence(s_last, sk)
return s_last_si + tf_zeros_like(like)
return tf_zeros_like(like)
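The three-level priority documented for `get_s_init_si` can be sketched without TensorFlow. This is a simplified illustration, not the library code: `resolve_s_init_sketch` is a hypothetical name, the history is a plain list assumed to be in SI meters already, and broadcasting to `like` is omitted.

```python
def resolve_s_init_sketch(inputs, subs_history=None):
    """Pick the initial settlement and report which rule supplied it."""
    inputs = inputs or {}
    keys = (
        "s_init_si", "subs_init_si", "subs_hist_last_si",
        "s_ref_si", "subs_ref_si", "s_init", "subs_init",
    )
    # 1) The first explicit, non-None entry wins, in the order listed.
    for k in keys:
        if inputs.get(k) is not None:
            return inputs[k], f"explicit:{k}"
    # 2) Otherwise use the last historical value, if a history exists.
    if subs_history:
        return subs_history[-1], "history"
    # 3) Otherwise start from zero settlement.
    return 0.0, "zeros"


print(resolve_s_init_sketch({"s_ref_si": 0.12}, [0.01, 0.02]))
print(resolve_s_init_sketch({"s_init_si": None}, [0.01, 0.02]))
print(resolve_s_init_sketch(None))
```

Returning the rule name alongside the value is a choice made for this sketch only; it makes the fallback that fired easy to log.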
def get_h_ref_si(
model,
inputs: dict[str, Tensor] | None,
like: Tensor,
) -> Tensor:
"""Return h_ref in SI meters, broadcast to like."""
# sk = getattr(model, "scaling_kwargs", None)
mode = getattr(
getattr(model, "h_ref_config", None), "mode", "auto"
)
mode = (
"fixed"
if str(mode).lower().strip() == "fixed"
else "auto"
)
if inputs is not None:
for k in (
"h_ref_si",
"head_ref_si",
"h_ref",
"head_ref",
):
if (k in inputs) and (inputs[k] is not None):
h_ref = tf_cast(inputs[k], tf_float32)
r = tf_rank(h_ref)
h_ref = tf_cond(
tf_equal(r, 1),
lambda: h_ref[:, None, None],
lambda: tf_cond(
tf_equal(r, 2),
lambda: h_ref[:, None, :],
lambda: h_ref,
),
)
return h_ref + tf_zeros_like(like)
if (
mode != "fixed"
and inputs is not None
and "dynamic_features" in inputs
and inputs["dynamic_features"] is not None
):
h_hist = get_h_hist_si(model, inputs, want_head=True)
return h_hist[:, -1:, :] + tf_zeros_like(like)
h0 = tf_cast(getattr(model, "h_ref", 0.0), tf_float32)
h0 = h0[None, None, None]
return h0 + tf_zeros_like(like)
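The reference-head resolution above depends on a mode flag as well as on the inputs. The following scalar sketch shows the order of checks under simplifying assumptions: `resolve_h_ref_sketch` is a hypothetical name, `head_history` stands in for the result of `get_h_hist_si`, and tensor broadcasting is omitted.

```python
def resolve_h_ref_sketch(mode, inputs, head_history, model_h_ref=0.0):
    """Return the reference head following the explicit/auto/fixed order."""
    # Anything other than "fixed" (case-insensitive) is treated as "auto".
    mode = "fixed" if str(mode).lower().strip() == "fixed" else "auto"
    inputs = inputs or {}
    # Explicit keys win regardless of the mode.
    for k in ("h_ref_si", "head_ref_si", "h_ref", "head_ref"):
        if inputs.get(k) is not None:
            return inputs[k]
    # In auto mode, the last historical head becomes the reference.
    if mode != "fixed" and head_history:
        return head_history[-1]
    # Otherwise fall back to the scalar stored on the model.
    return model_h_ref


print(resolve_h_ref_sketch("auto", {}, [4.0, 3.5]))
print(resolve_h_ref_sketch(" Fixed ", {}, [4.0, 3.5], model_h_ref=5.0))
print(resolve_h_ref_sketch("fixed", {"h_ref": 7.0}, [4.0, 3.5]))
```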
def infer_dt_units_from_t(
t_BH1: Tensor,
scaling_kwargs: dict[str, Any] | None,
*,
eps: float = 1e-12,
) -> Tensor:
"""
Infer per-step dt in *time_units* from time tensor t(B,H,1).
Shapes
------
t_BH1 : (B,H,1)
returns: (B,H,1)
Notes
-----
- dt uses diffs along H; first step uses the first diff.
- If coords are normalized, dt is multiplied by the de-normalization
time range tR (from coord_ranges()).
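- Non-positive steps are replaced by 1, then the output is floored at
``dt_min_units`` from ``scaling_kwargs`` (default 1e-6); ``eps`` is
currently unused.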
"""
sk = scaling_kwargs or {}
t = tf_convert_to_tensor(t_BH1, dtype=tf_float32)
# t shape: (B,H,1)
H = tf_shape(t)[1]
dt_default = tf_ones_like(t) # (B,H,1), safe in-graph
def _multi_step():
diffs = t[:, 1:, :] - t[:, :-1, :] # (B,H-1,1)
dt_first = diffs[:, :1, :] # (B,1,1)
dt = tf_concat([dt_first, diffs], axis=1) # (B,H,1)
# If coords were normalized, dt is still normalized -> scale back
if bool(sk.get("coords_normalized", False)):
tR, _, _ = coord_ranges(sk)
if tR is None:
raise ValueError(
"coords_normalized=True but coord_ranges missing."
)
dt = dt * tf_constant(float(tR), dtype=tf_float32)
return dt
# if H <= 1: ones; else: diffs
dt = tf_cond(
tf_less_equal(H, 1), lambda: dt_default, _multi_step
)
dt = tf_abs(dt)
dt_pos = tf_greater(dt, tf_constant(0.0, tf_float32))
dt_pos_f = tf_cast(dt_pos, tf_float32)
dt = dt * dt_pos_f + dt_default * (1.0 - dt_pos_f)
dt_eps = float(get_sk(sk, "dt_min_units", default=1e-6))
dt = tf_maximum(dt, tf_constant(dt_eps, tf_float32))
return dt
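The per-step rule implemented above can be followed on a single time sequence. The sketch below is a pure-Python restatement for one list of floats, under the assumption that the optional `t_range` argument plays the role of the de-normalization span `tR`. The name `infer_dt_sketch` is hypothetical.

```python
def infer_dt_sketch(t, t_range=None, dt_min=1e-6):
    """Per-step dt for one time sequence given as a list of floats."""
    # A horizon of length 0 or 1 has no differences: use dt = 1.
    if len(t) <= 1:
        return [1.0] * len(t)
    diffs = [b - a for a, b in zip(t[:-1], t[1:])]
    # The first step reuses the first difference, so len(dt) == len(t).
    dt = [diffs[0]] + diffs
    # Normalized coordinates are scaled back to time units.
    if t_range is not None:
        dt = [d * t_range for d in dt]
    dt = [abs(d) for d in dt]
    # Zero-length steps fall back to 1, then everything is floored.
    dt = [d if d > 0.0 else 1.0 for d in dt]
    return [max(d, dt_min) for d in dt]


print(infer_dt_sketch([2000.0, 2001.0, 2003.0]))
print(infer_dt_sketch([0.0, 0.25, 0.75], t_range=4.0))
print(infer_dt_sketch([0.0, 0.0, 2.0]))
```

The fallback to 1 happens after the rescaling step, so a repeated timestamp yields one time unit rather than one normalized unit.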
# ---------------------------------------------------------------------
# Training strategy gates (Q and subsidence residual)
# ---------------------------------------------------------------------
def policy_gate(
    step: Tensor,
    policy: str,
    *,
    warmup_steps: int = 0,
    ramp_steps: int = 0,
    dtype: Any = tf_float32,
) -> Tensor:
    r"""Return a scalar gate in ``[0,1]`` based on a policy + step.

    Parameters
    ----------
    step : Tensor
        Global step counter (typically ``optimizer.iterations``).
    policy : {"always_on","always_off","warmup_off"}
        Gating behavior. ``always_on`` returns 1, ``always_off``
        returns 0, and ``warmup_off`` returns 0 for
        ``step < warmup_steps`` before ramping to 1 over
        ``ramp_steps`` when ``ramp_steps > 0`` or switching
        immediately at ``warmup_steps`` otherwise.
    warmup_steps : int, default=0
        Number of steps to keep the gate at 0 (only for ``warmup_off``).
    ramp_steps : int, default=0
        Number of steps for a linear ramp from 0->1 after warmup.
        If 0, the gate is a hard step.
    dtype : dtype, default=tf_float32
        Output dtype.
    """
    pol = (policy or "always_on").strip().lower()
    if pol in ("always_on", "on", "true", "1"):
        return tf_constant(1.0, dtype=dtype)
    if pol in ("always_off", "off", "false", "0"):
        return tf_constant(0.0, dtype=dtype)
    w = int(warmup_steps or 0)
    r = int(ramp_steps or 0)
    if w <= 0 and r <= 0:
        return tf_constant(1.0, dtype=dtype)
    step_i = tf_cast(step, tf_int32)
    if r <= 0:
        return tf_cast(
            tf_greater_equal(
                step_i, tf_constant(w, tf_int32)
            ),
            dtype,
        )
    step_f = tf_cast(step_i, dtype)
    w_f = tf_constant(float(w), dtype)
    r_f = tf_constant(float(r), dtype)
    frac = (step_f - w_f) / r_f
    frac = tf_maximum(tf_constant(0.0, dtype), frac)
    frac = tf_minimum(tf_constant(1.0, dtype), frac)
    return frac
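To see the schedule that `policy_gate` produces, the same branching can be written for a plain integer step. This is a scalar sketch for illustration; `policy_gate_sketch` is a hypothetical name and the TensorFlow dtype handling is left out.

```python
def policy_gate_sketch(step, policy, warmup_steps=0, ramp_steps=0):
    """Scalar version of the gate: returns a float in [0, 1]."""
    pol = (policy or "always_on").strip().lower()
    if pol in ("always_on", "on", "true", "1"):
        return 1.0
    if pol in ("always_off", "off", "false", "0"):
        return 0.0
    # Any other policy string is handled as a warmup schedule.
    w = int(warmup_steps or 0)
    r = int(ramp_steps or 0)
    if w <= 0 and r <= 0:
        return 1.0
    if r <= 0:
        # Hard step: off before warmup_steps, on from warmup_steps onward.
        return 1.0 if step >= w else 0.0
    # Linear ramp from 0 at step w to 1 at step w + r, clipped to [0, 1].
    return min(1.0, max(0.0, (step - w) / r))


schedule = [policy_gate_sketch(s, "warmup_off", 10, 5) for s in (0, 10, 12, 15, 20)]
print(schedule)
```

With `warmup_steps=10` and `ramp_steps=5` the gate is still 0 at step 10 and reaches 1 at step 15.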
# ---------------------------------------------------------------------
# Derived SI conversion helpers (optional, but recommended)
# ---------------------------------------------------------------------
def finalize_scaling_kwargs(
sk: dict[str, Any],
) -> dict[str, Any]:
"""Add derived SI conversion constants to ``scaling_kwargs``.
Adds (when possible):
- ``seconds_per_time_unit``: float
- ``coord_ranges_si``: dict with keys ``t`` (seconds), ``x``/``y`` (meters)
- ``coord_inv_ranges_si``: inverse of the above (safe floor).
Notes
-----
This helper is designed to be called *once* when assembling
``scaling_kwargs`` (e.g., in your stage2 script) so the model can
reuse those constants without recomputing unit conversions in the
hot training loop.
"""
if sk is None:
return sk
sk = dict(sk)
tu = (
str(get_sk(sk, "time_units", default="second"))
.strip()
.lower()
)
time_unit_to_seconds = {
"second": 1.0,
"sec": 1.0,
"s": 1.0,
"minute": 60.0,
"min": 60.0,
"m": 60.0,
"hour": 3600.0,
"h": 3600.0,
"day": 86400.0,
"d": 86400.0,
# Mean Gregorian year (365.2425 days) to match prior_maths.py
"year": 31556952.0,
"yr": 31556952.0,
"y": 31556952.0,
}
sec_u = float(time_unit_to_seconds.get(tu, 1.0))
sk.setdefault("seconds_per_time_unit", sec_u)
cr = get_sk(sk, "coord_ranges", default=None)
if isinstance(cr, Mapping) and all(
k in cr for k in ("t", "x", "y")
):
tR = float(cr.get("t", 1.0))
xR = float(cr.get("x", 1.0))
yR = float(cr.get("y", 1.0))
# If coordinates are degrees, convert spans to meters.
if bool(
get_sk(sk, "coords_in_degrees", default=False)
):
deg_to_m_lon = get_sk(
sk, "deg_to_m_lon", default=None
)
deg_to_m_lat = get_sk(
sk, "deg_to_m_lat", default=None
)
if (
deg_to_m_lon is not None
and deg_to_m_lat is not None
):
xR *= float(deg_to_m_lon)
yR *= float(deg_to_m_lat)
# Convert time span to seconds (important if coords_normalized=True).
tR *= sec_u
sk["coord_ranges_si"] = {"t": tR, "x": xR, "y": yR}
eps = 1e-12
sk["coord_inv_ranges_si"] = {
"t": 1.0 / max(tR, eps),
"x": 1.0 / max(xR, eps),
"y": 1.0 / max(yR, eps),
}
return sk
def coord_ranges_si(
sk: dict[str, Any],
) -> tuple[float | None, float | None, float | None]:
"""Return coordinate spans in SI (t in seconds; x/y in meters).
If ``coord_ranges_si`` is present in ``sk``, it is used directly.
Otherwise, this is computed from ``coord_ranges`` and ``time_units``
(and degree-to-meter factors when applicable).
"""
cr_si = get_sk(sk, "coord_ranges_si", default=None)
if isinstance(cr_si, Mapping) and all(
k in cr_si for k in ("t", "x", "y")
):
return (
float(cr_si["t"]),
float(cr_si["x"]),
float(cr_si["y"]),
)
sk2 = finalize_scaling_kwargs(sk)
cr_si = get_sk(sk2, "coord_ranges_si", default=None)
if isinstance(cr_si, Mapping) and all(
k in cr_si for k in ("t", "x", "y")
):
return (
float(cr_si["t"]),
float(cr_si["x"]),
float(cr_si["y"]),
)
return None, None, None
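A small worked example may help when checking the derived constants by hand. The sketch below mirrors the span conversion with a three-entry subset of the unit table. `spans_to_si_sketch` is a hypothetical name, the degree-to-meter factors are placeholders rather than surveyed values, and the `coords_in_degrees` flag is folded into whether both factors are supplied.

```python
# Subset of the unit table; 31556952 s corresponds to 365.2425 days.
SECONDS_PER_UNIT = {"second": 1.0, "day": 86400.0, "year": 31556952.0}


def spans_to_si_sketch(coord_ranges, time_units="second",
                       deg_to_m_lon=None, deg_to_m_lat=None, eps=1e-12):
    """Convert (t, x, y) spans to seconds and meters, plus safe inverses."""
    # Unknown units fall back to 1.0, as in the function above.
    sec_u = SECONDS_PER_UNIT.get(time_units.strip().lower(), 1.0)
    t_r = float(coord_ranges["t"]) * sec_u
    x_r = float(coord_ranges["x"])
    y_r = float(coord_ranges["y"])
    # Degree spans are converted only when both factors are supplied.
    if deg_to_m_lon is not None and deg_to_m_lat is not None:
        x_r *= float(deg_to_m_lon)
        y_r *= float(deg_to_m_lat)
    spans = {"t": t_r, "x": x_r, "y": y_r}
    # The floor eps guards against division by a zero span.
    inverses = {k: 1.0 / max(v, eps) for k, v in spans.items()}
    return spans, inverses


spans, inverses = spans_to_si_sketch(
    {"t": 5.0, "x": 0.5, "y": 0.25},
    time_units="year",
    deg_to_m_lon=100000.0,  # placeholder factor
    deg_to_m_lat=111000.0,  # placeholder factor
)
print(spans)
```

A plural such as `"years"` is not in the table and would silently be treated as seconds, so it is worth normalizing `time_units` before calling the helper.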
NAT package exports#
r"""Public exports for NAT workflow utilities."""
from .nat_utils import (
build_censor_mask,
ensure_config_json,
ensure_input_shapes,
get_config_paths,
get_natcom_dir,
load_nat_config,
load_nat_config_payload,
load_scaler_info,
load_windows_npz,
make_tf_dataset,
map_targets_for_training,
pick_npz_for_dataset,
resolve_hybrid_config,
resolve_si_affine,
sanitize_inputs_np,
)
from .natutils import (
best_epoch_and_metrics,
compile_for_eval,
compile_geoprior_for_eval,
extract_preds,
load_best_hps_near_model,
load_hps_auto_near_model,
load_or_rebuild_geoprior_model,
load_trained_hps_near_model,
load_tuned_hps_near_model,
name_of,
save_ablation_record,
serialize_subs_params,
subs_point_from_out,
)
__all__ = [
"build_censor_mask",
"ensure_input_shapes",
"extract_preds",
"load_nat_config",
"load_nat_config_payload",
"load_scaler_info",
"make_tf_dataset",
"map_targets_for_training",
"name_of",
"resolve_hybrid_config",
"resolve_si_affine",
"best_epoch_and_metrics",
"subs_point_from_out",
"serialize_subs_params",
"save_ablation_record",
"load_windows_npz",
"load_tuned_hps_near_model",
"load_trained_hps_near_model",
"sanitize_inputs_np",
"load_hps_auto_near_model",
"load_or_rebuild_geoprior_model",
"compile_for_eval",
"load_best_hps_near_model",
"pick_npz_for_dataset",
"ensure_config_json",
"get_natcom_dir",
"get_config_paths",
"compile_geoprior_for_eval",
]
Optional: workflow/NAT helper module#
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
# website: https://lkouadio.com
r"""NAT evaluation and artifact helpers for GeoPrior."""
from __future__ import annotations
import datetime as dt
import glob
import json
import os
from collections.abc import Mapping, Sequence
from typing import Any
import numpy as np
import pandas as pd
# --- Optional TensorFlow import for GeoPrior helpers -----------------------
try: # pragma: no cover - defensive import
import tensorflow as tf # noqa
from tensorflow.keras.optimizers import Adam
TF_AVAILABLE = True
except Exception: # pragma: no cover
TF_AVAILABLE = False
tf = None # type: ignore[assignment]
class _AdamStub:
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ImportError(
"TensorFlow is required for NATCOM GeoPrior helpers "
"(e.g. compile_geoprior_for_eval). Please install "
"`tensorflow>=2.12`."
)
Adam = _AdamStub # type: ignore[assignment]
def save_ablation_record(
outdir: str,
city: str,
model_name: str,
cfg: dict,
eval_dict: dict | None,
phys_diag: dict | None = None,
per_h_mae: dict | None = None,
per_h_r2: dict | None = None,
log_fn=None,
) -> None:
"""
Append a single ablation record to ``ablation_record.jsonl``.
Each training run (e.g., different physics toggles or weights)
writes one JSON line containing:
- Basic run identifiers (city, model, timestamp).
- Physics configuration (``PDE_MODE_CONFIG``, lambda weights,
effective head flags, etc.).
- Key performance metrics (R², MSE, MAE, RMSE, coverage, sharpness).
- Optional physics diagnostics (``epsilon_prior``,
``epsilon_cons``).
- Optional per-horizon MAE/R² for more detailed analysis.
Parameters
----------
outdir : str
Base output directory for the current run. The ablation
file is created under ``outdir / "ablation_records"``.
city : str
City name (e.g., ``"nansha"`` or ``"zhongshan"``).
model_name : str
Model identifier (e.g., ``"GeoPriorSubsNet"``).
cfg : dict
Lightweight configuration dictionary containing at least
the physics-related keys used below.
eval_dict : dict or None
Dictionary of evaluation metrics (R², MSE, MAE, RMSE,
coverage80, sharpness80). If ``None``, metrics fields
default to ``None``.
phys_diag : dict or None, optional
Physics diagnostics (e.g., from ``evaluate()``) with keys
such as ``"epsilon_prior"`` and ``"epsilon_cons"``.
per_h_mae : dict or None, optional
Per-horizon MAE values (e.g., keyed by year/step).
per_h_r2 : dict or None, optional
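Per-horizon R² values.
log_fn : callable or None, optional
Logging function used to report the output path. Defaults to
``print``.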
Notes
-----
The output file is a JSON-Lines file, so it can be loaded
with :func:`load_ablation_jsonl`.
"""
if log_fn is None:
log_fn = print
# eval_dict = eval_dict or {}
metrics = dict(eval_dict or {})
rec = {
"timestamp": dt.datetime.now().strftime(
"%Y%m%d-%H%M%S"
),
"city": city,
"model": model_name,
# Physics toggles / weights
"pde_mode": cfg.get("PDE_MODE_CONFIG"),
"use_effective_h": bool(
cfg.get("GEOPRIOR_USE_EFFECTIVE_H", True)
),
"kappa_mode": cfg.get("GEOPRIOR_KAPPA_MODE", "bar"),
"hd_factor": cfg.get("GEOPRIOR_HD_FACTOR", 0.6),
"lambda_cons": cfg.get("LAMBDA_CONS"),
"lambda_gw": cfg.get("LAMBDA_GW"),
"lambda_prior": cfg.get("LAMBDA_PRIOR"),
"lambda_smooth": cfg.get("LAMBDA_SMOOTH"),
"lambda_mv": cfg.get("LAMBDA_MV"),
"lambda_bounds": cfg.get("LAMBDA_BOUNDS"),
"lambda_q": cfg.get("LAMBDA_Q"),
# Key metrics
# "r2": eval_dict.get("r2"),
# "mse": eval_dict.get("mse"),
# "mae": eval_dict.get("mae"),
# "coverage80": eval_dict.get("coverage80"),
# "sharpness80": eval_dict.get("sharpness80"),
"r2": metrics.get("r2"),
"mse": metrics.get("mse"),
"mae": metrics.get("mae"),
"rmse": metrics.get("rmse"),
"coverage80": metrics.get("coverage80"),
"sharpness80": metrics.get("sharpness80"),
}
# Keep the full metrics payload (post-hoc vs evaluate(), units, etc.)
rec["metrics"] = metrics
# Convenience: surface units at top-level if provided.
if isinstance(metrics.get("units"), dict):
rec["units"] = metrics.get("units")
if phys_diag:
rec["epsilon_prior"] = phys_diag.get("epsilon_prior")
rec["epsilon_cons"] = phys_diag.get("epsilon_cons")
if "epsilon_gw" in phys_diag:
rec["epsilon_gw"] = phys_diag.get("epsilon_gw")
if per_h_mae is not None:
rec["per_horizon_mae"] = per_h_mae
if per_h_r2 is not None:
rec["per_horizon_r2"] = per_h_r2
abl_dir = os.path.join(outdir, "ablation_records")
os.makedirs(abl_dir, exist_ok=True)
jpath = os.path.join(abl_dir, "ablation_record.jsonl")
with open(jpath, "a", encoding="utf-8") as f:
f.write(json.dumps(rec) + "\n")
log_fn(f"[Ablation] appended -> {jpath}")
def load_ablation_jsonl(path: str) -> pd.DataFrame:
    """
    Load an ablation JSON-Lines file into a :class:`pandas.DataFrame`.

    This is the companion to :func:`save_ablation_record`. Each
    line is parsed as JSON and turned into one row.

    Parameters
    ----------
    path : str
        Path to ``ablation_record.jsonl``.

    Returns
    -------
    pandas.DataFrame
        DataFrame where each row corresponds to one ablation
        record.

    Examples
    --------
    >>> df_abl = load_ablation_jsonl(
    ...     "ablation_records/ablation_record.jsonl"
    ... )
    >>> df_abl.head()
    """
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return pd.DataFrame(rows)
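The append-then-load cycle shared by `save_ablation_record` and `load_ablation_jsonl` can be tried with the standard library alone. The sketch below uses hypothetical helper names (`append_record`, `read_records`) and made-up placeholder values. It returns a list of dicts instead of a DataFrame so that it runs without pandas.

```python
import json
import os
import tempfile


def append_record(path, rec):
    """Append one JSON object as a single line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")


def read_records(path):
    """Parse every non-empty line as one JSON object."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ablation_record.jsonl")
    # Placeholder values, for illustration only.
    append_record(path, {"city": "nansha", "lambda_cons": 1.0, "r2": 0.91})
    append_record(path, {"city": "nansha", "lambda_cons": 0.0, "r2": 0.88})
    records = read_records(path)

print(len(records), records[0]["lambda_cons"])
```

Because each run appends one line, the file can be read while later runs are still adding to it, and records with different keys simply yield missing values once loaded into a table.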
def name_of(obj: object) -> str:
    """
    Return a human-readable name for an object.

    This utility is handy when serialising compile configurations
    (e.g., turning metric callables into simple strings for JSON
    logs).

    Parameters
    ----------
    obj : object
        Any Python object (function, class instance, etc.).

    Returns
    -------
    str
        ``obj.__name__`` if present, otherwise the class name, and
        finally ``str(obj)`` as a last resort.
    """
    if hasattr(obj, "__name__"):
        return obj.__name__
    if hasattr(obj, "__class__"):
        return obj.__class__.__name__
    return str(obj)
def serialize_subs_params(
params: dict,
cfg: dict | None = None,
) -> dict:
"""
Make GeoPrior subnet parameters JSON-friendly.
The training scripts typically pass a dictionary of model
construction arguments, e.g. ``subsmodel_params``, which
contains objects such as ``LearnableMV`` or ``FixedGammaW``
that are not directly JSON-serialisable.
This helper replaces those objects by small dictionaries
describing their type and scalar value, optionally using
values from the NATCOM config dictionary.
Parameters
----------
params : dict
Dictionary of model init parameters (e.g.
``subsmodel_params`` in ``training_NATCOM_GEOPRIOR.py``).
cfg : dict, optional
NATCOM config dictionary. If provided, scalar values are
taken from:
- ``GEOPRIOR_INIT_MV``
- ``GEOPRIOR_INIT_KAPPA``
- ``GEOPRIOR_GAMMA_W``
- ``GEOPRIOR_H_REF``
and used as the authoritative numbers.
Returns
-------
dict
Copy of ``params`` where scalar GeoPrior parameters are
replaced by JSON-friendly dictionaries.
Notes
-----
This function does **not** import any of the GeoPrior classes.
It only introspects attributes like ``initial_value`` or
``value`` when the corresponding config entry is missing.
"""
out = dict(params)
cfg = cfg or {}
# Helper to extract a scalar from either the config or the
# original object (Learnable*/Fixed*).
def _extract_scalar(obj, cfg_key: str) -> float | None:
if cfg_key in cfg and cfg[cfg_key] is not None:
try:
return float(cfg[cfg_key])
except Exception:
pass
# Fallback: try to read a typical attribute name.
for attr in ("initial_value", "value"):
if hasattr(obj, attr):
try:
return float(getattr(obj, attr))
except Exception:
continue
return None
if "mv" in out:
mv_val = _extract_scalar(
out["mv"], "GEOPRIOR_INIT_MV"
)
out["mv"] = {
"type": "LearnableMV",
"initial_value": mv_val,
}
if "kappa" in out:
kap_val = _extract_scalar(
out["kappa"], "GEOPRIOR_INIT_KAPPA"
)
out["kappa"] = {
"type": "LearnableKappa",
"initial_value": kap_val,
}
if "gamma_w" in out:
gw_val = _extract_scalar(
out["gamma_w"], "GEOPRIOR_GAMMA_W"
)
out["gamma_w"] = {
"type": "FixedGammaW",
"value": gw_val,
}
if "h_ref" in out:
href_val = _extract_scalar(
out["h_ref"], "GEOPRIOR_H_REF"
)
out["h_ref"] = {
"type": "FixedHRef",
"value": href_val,
}
return out
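The replacement rule used by `serialize_subs_params` (config value first, object attribute second) can be illustrated with a stand-in object. In the sketch below, `LearnableMVStub` is a hypothetical dataclass and not the GeoPrior class, and `extract_scalar` is a simplified copy of the inner helper that catches only `TypeError` and `ValueError`.

```python
import json
from dataclasses import dataclass


@dataclass
class LearnableMVStub:  # hypothetical stand-in, not the GeoPrior class
    initial_value: float


def extract_scalar(obj, cfg, cfg_key):
    """Prefer the config value; otherwise probe common attributes."""
    if cfg.get(cfg_key) is not None:
        try:
            return float(cfg[cfg_key])
        except (TypeError, ValueError):
            pass
    for attr in ("initial_value", "value"):
        if hasattr(obj, attr):
            try:
                return float(getattr(obj, attr))
            except (TypeError, ValueError):
                continue
    return None


params = {"mv": LearnableMVStub(1e-7), "embed_dim": 32}
cfg = {"GEOPRIOR_INIT_MV": 5e-7}

out = dict(params)
out["mv"] = {
    "type": "LearnableMV",
    "initial_value": extract_scalar(params["mv"], cfg, "GEOPRIOR_INIT_MV"),
}
# The stub object is gone, so the result can be dumped to JSON.
print(json.dumps(out))
```

Here the config value overrides the attribute stored on the object; with an empty config the attribute would be used instead.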
def best_epoch_and_metrics(
    history: dict,
    monitor: str = "val_loss",
) -> tuple[int | None, dict]:
    """
    Return the best epoch and metrics at that epoch.

    Given a ``History.history`` dictionary produced by
    ``model.fit(...)``, this helper identifies the index of the
    minimum value for the monitored quantity (by default
    ``"val_loss"``) and returns:

    - The epoch index (0-based).
    - A dictionary mapping each metric name to its value at that
      epoch.

    Parameters
    ----------
    history : dict
        The ``history.history`` attribute from Keras training.
    monitor : str, default="val_loss"
        Name of the metric to minimise.

    Returns
    -------
    best_epoch : int or None
        Index of the best epoch, or ``None`` if ``monitor`` is
        not present.
    metrics_at_best : dict
        Mapping from metric name to its value at the best epoch.
        Empty if ``monitor`` is not present.
    """
    if not history or monitor not in history:
        return None, {}
    # nanargmin makes sure NaNs are ignored when searching for the
    # best epoch.
    be = int(np.nanargmin(history[monitor]))
    metrics_at_best = {
        k: float(v[be])
        for k, v in history.items()
        if len(v) > be
    }
    return be, metrics_at_best
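The NaN-ignoring selection can be followed on a short history without NumPy. The sketch below is an approximation with a hypothetical name, `best_epoch_sketch`. It differs from the function above in one respect: when every monitored value is NaN it returns `(None, {})`, whereas `np.nanargmin` raises `ValueError` for an all-NaN input.

```python
import math


def best_epoch_sketch(history, monitor="val_loss"):
    """Return (best_epoch, metrics_at_best), skipping NaN entries."""
    if not history or monitor not in history:
        return None, {}
    # Pair each finite value with its epoch; min() then picks the
    # smallest value and, on ties, the earliest epoch.
    valid = [
        (v, i) for i, v in enumerate(history[monitor]) if not math.isnan(v)
    ]
    if not valid:
        return None, {}
    best = min(valid)[1]
    metrics = {
        k: float(v[best]) for k, v in history.items() if len(v) > best
    }
    return best, metrics


# Placeholder training curves, for illustration only.
history = {
    "loss": [0.9, 0.6, 0.5, 0.45],
    "val_loss": [0.8, float("nan"), 0.55, 0.6],
}
best, metrics = best_epoch_sketch(history)
print(best, metrics)
```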
def load_or_rebuild_geoprior_model(
model_path: str,
manifest: dict,
X_sample: dict,
out_s_dim: int,
out_g_dim: int,
mode: str,
horizon: int,
quantiles: list[float] | None,
city_name: str | None = None,
compile_on_load: bool = True,
verbose: int = 1,
):
"""
Load a tuned *or trained* GeoPriorSubsNet, with robust rebuild fallback.
Strategy
--------
1. Try ``tf.keras.models.load_model(model_path)`` with all required
   custom objects registered.
2. If that fails:

   2a) Try tuned-model reconstruction::

           best_hps = load_best_hps_near_model(...)
           model = build_geoprior_from_hps(...)

   2b) If no best_hps JSON is found, assume a plain *trained* model
       and fall back to the ``*_training_summary.json`` recorded next
       to the checkpoint::

           training_summary = load_training_summary_near_model(...)
           model = build_geoprior_from_training_summary(...)

   In both cases, a minimal ``best_hps``-like dict is returned so
   :func:`compile_for_eval` can recreate the physics weights and
   learning rate.
3. Try to load the best weights checkpoint via
   :func:`infer_best_weights_path`.
Returns
-------
model :
A GeoPriorSubsNet instance ready to be recompiled for evaluation.
best_hps : dict or None
Tuned hyperparameters if present, otherwise a small dict
containing at least ``learning_rate`` and the lambda weights
from the training summary, or ``None``.
"""
label_city = city_name or "GeoPrior"
# --- Lazy imports so nat_utils can be imported without TF/geoprior ---
try:
import tensorflow as tf # type: ignore # noqa
from tensorflow.keras.models import load_model # type: ignore
from tensorflow.keras.utils import custom_object_scope # type: ignore
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"load_or_rebuild_geoprior_model requires TensorFlow. "
"Please install 'tensorflow>=2.12' to use this helper."
) from e
try:
from geoprior.nn.keras_metrics import ( # type: ignore
coverage80_fn,
sharpness80_fn,
)
from geoprior.nn.losses import (
make_weighted_pinball, # type: ignore
)
from geoprior.nn.pinn.models import (
GeoPriorSubsNet, # type: ignore
)
from geoprior.params import ( # type: ignore
FixedGammaW,
FixedHRef,
LearnableKappa,
LearnableMV,
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"load_or_rebuild_geoprior_model requires geoprior components "
"(GeoPriorSubsNet, LearnableMV, etc.). Ensure geoprior is "
"installed and importable."
) from e
custom_objects = {
"GeoPriorSubsNet": GeoPriorSubsNet,
"LearnableMV": LearnableMV,
"LearnableKappa": LearnableKappa,
"FixedGammaW": FixedGammaW,
"FixedHRef": FixedHRef,
"make_weighted_pinball": make_weighted_pinball,
"coverage80_fn": coverage80_fn,
"sharpness80_fn": sharpness80_fn,
}
best_hps: dict | None = None
# ------------------- 1) Try direct load_model -------------------------
with custom_object_scope(custom_objects):
if verbose:
print(
f"[Model] Attempting to load model from: {model_path}"
)
try:
model = load_model(
model_path, compile=compile_on_load
)
if verbose:
print(
f"[Model] Successfully loaded model for {label_city} "
f"from: {model_path}"
)
return model, best_hps
except Exception as e_load:
if verbose:
print(
f"[Warn] load_model('{model_path}') failed: {e_load}\n"
"[Warn] Falling back to config-based reconstruction."
)
# ------------------- 2) Fallback: tuned HPs OR training summary -------
try:
# 2a) Tuned model path
best_hps = load_best_hps_near_model(model_path)
model = build_geoprior_from_hps(
manifest=manifest,
X_sample=X_sample,
best_hps=best_hps,
out_s_dim=out_s_dim,
out_g_dim=out_g_dim,
mode=mode,
horizon=horizon,
quantiles=quantiles,
)
except Exception as e_hps:
# 2b) No best_hps JSON -> treat this as a plain trained model.
if verbose:
print(
"[Fallback] No best_hps JSON found next to model_path; "
"assuming a plain trained model.\n"
f" Reason: {e_hps}"
)
training_summary = load_training_summary_near_model(
model_path, city_name=city_name
)
if training_summary is None:
raise RuntimeError(
"Failed to reconstruct GeoPriorSubsNet: neither tuned "
"hyperparameters nor *_training_summary.json were found "
f"near model_path={model_path!r}."
) from e_hps
model = build_geoprior_from_training_summary(
manifest=manifest,
X_sample=X_sample,
training_summary=training_summary,
out_s_dim=out_s_dim,
out_g_dim=out_g_dim,
mode=mode,
horizon=horizon,
quantiles=quantiles,
)
# Build a minimal best_hps dict so compile_for_eval can recover the
# training-time physics weights and learning rate.
compile_block = (
training_summary.get("compile", {}) or {}
)
phys = (
compile_block.get("physics_loss_weights", {})
or {}
)
lr = compile_block.get("learning_rate", None)
hps_from_train: dict[str, float] = {}
if lr is not None:
try:
hps_from_train["learning_rate"] = float(lr)
except Exception:
pass
for k, v in phys.items():
try:
hps_from_train[k] = float(v)
except Exception:
continue
best_hps = hps_from_train or None
# ------------------- 3) Load weights if checkpoint exists -------------
weights_path = infer_best_weights_path(model_path)
if weights_path is not None:
try:
model.load_weights(weights_path)
if verbose:
print(
"[Fallback] Loaded weights into reconstructed "
f"GeoPriorSubsNet from: {weights_path}"
)
except Exception as e_w:
if verbose:
print(
"[Warn] Could not load weights from checkpoint:\n"
f" {weights_path}\n"
f" Error: {e_w}\n"
" The rebuilt model is using freshly-initialised "
"weights. Predictions will NOT match the original run."
)
else:
if verbose:
print(
"[Warn] No weights checkpoint found near model.\n"
" Using rebuilt model with freshly-initialised "
"weights. Predictions will NOT match the original run."
)
return model, best_hps
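In fallback 2b, the function above rebuilds a small `best_hps`-like dict from the training summary. That step is sketched below in isolation. `hps_from_summary_sketch` is a hypothetical name, and the weight names inside `physics_loss_weights` are placeholders, since the source does not list the actual keys.

```python
def hps_from_summary_sketch(training_summary):
    """Collect learning rate and physics weights as floats, or None."""
    compile_block = training_summary.get("compile") or {}
    phys = compile_block.get("physics_loss_weights") or {}
    hps = {}
    lr = compile_block.get("learning_rate")
    if lr is not None:
        try:
            hps["learning_rate"] = float(lr)
        except (TypeError, ValueError):
            pass
    # Entries that cannot be cast to float are skipped, not fatal.
    for k, v in phys.items():
        try:
            hps[k] = float(v)
        except (TypeError, ValueError):
            continue
    # An empty result is reported as None.
    return hps or None


summary = {
    "compile": {
        "learning_rate": "1e-3",
        # Placeholder weight names and values.
        "physics_loss_weights": {"lambda_cons": 1, "lambda_gw": None},
    }
}
print(hps_from_summary_sketch(summary))
```

When the direct `load_model` call succeeds, none of this runs and `best_hps` stays `None`.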
def build_geoprior_from_training_summary(
manifest: dict,
X_sample: dict,
training_summary: dict,
out_s_dim: int,
out_g_dim: int,
mode: str,
horizon: int,
quantiles: list[float] | None,
) -> Any:
"""
Reconstruct a GeoPriorSubsNet from a training_summary JSON.
This is the fallback path for plain *trained* models (no tuning),
using the architecture recorded under ``hp_init['model_init_params']``.
Parameters
----------
manifest : dict
Stage-1 manifest dictionary (for some defaults).
X_sample : dict
One NPZ inputs dictionary already passed through
:func:`ensure_input_shapes`. Only shapes are used.
training_summary : dict
Parsed ``*_training_summary.json`` for this run.
out_s_dim, out_g_dim, mode, horizon, quantiles :
Same semantics as in :func:`build_geoprior_from_hps`.
Returns
-------
model : GeoPriorSubsNet
Reconstructed model (uncompiled).
"""
try:
from geoprior.nn.pinn.models import (
GeoPriorSubsNet, # type: ignore
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"build_geoprior_from_training_summary requires "
"'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
"Ensure geoprior is installed and importable."
) from e
cfg = manifest.get("config", {}) or {}
# Infer input dims from the NPZ sample
static_dim, dynamic_dim, future_dim = (
infer_input_dims_from_X(X_sample)
)
hp_init = training_summary.get("hp_init", {}) or {}
model_init = hp_init.get("model_init_params", {}) or {}
# Quantiles: prefer explicit argument, then the training summary
q = quantiles or hp_init.get("quantiles")
# Attention stack
attention_levels = model_init.get(
"attention_levels",
hp_init.get(
"attention_levels",
cfg.get(
"ATTENTION_LEVELS",
["cross", "hierarchical", "memory"],
),
),
)
# Physics toggles
censor_cfg = cfg.get("censoring", {}) or {}
use_effective_h = bool(
model_init.get(
"use_effective_h",
hp_init.get(
"use_effective_h",
censor_cfg.get("use_effective_h_field", True),
),
)
)
pde_mode = hp_init.get(
"pde_mode", cfg.get("PDE_MODE_CONFIG", "both")
)
kappa_mode = model_init.get(
"kappa_mode",
cfg.get("GEOPRIOR_KAPPA_MODE", "bar"),
)
scale_pde_residuals = bool(
model_init.get("scale_pde_residuals", True)
)
# Architecture hyperparameters (as used at training time)
embed_dim = int(model_init.get("embed_dim", 32))
hidden_units = int(model_init.get("hidden_units", 96))
lstm_units = int(model_init.get("lstm_units", 96))
attention_units = int(
model_init.get("attention_units", 32)
)
num_heads = int(model_init.get("num_heads", 4))
dropout_rate = float(model_init.get("dropout_rate", 0.1))
use_vsn = bool(
model_init.get(
"use_vsn", hp_init.get("use_vsn", True)
)
)
vsn_units = int(
model_init.get(
"vsn_units", hp_init.get("vsn_units", 32)
)
)
use_batch_norm = bool(
model_init.get(
"use_batch_norm",
hp_init.get(
"use_batch_norm",
cfg.get("USE_BATCH_NORM", True),
),
)
)
# Geomechanical parameters were serialised via `serialize_subs_params`
# so we need to extract scalar initial values again.
def _extract_initial(
spec: Any, cfg_key: str, cfg_default: float
) -> float:
if isinstance(spec, dict):
if "initial_value" in spec:
try:
return float(spec["initial_value"])
except Exception:
pass
if "value" in spec:
try:
return float(spec["value"])
except Exception:
pass
if cfg_key in cfg and cfg[cfg_key] is not None:
try:
return float(cfg[cfg_key])
except Exception:
pass
return float(cfg_default)
mv_spec = model_init.get("mv", {})
kappa_spec = model_init.get("kappa", {})
mv_init = _extract_initial(
mv_spec, "GEOPRIOR_INIT_MV", 5e-7
)
kappa_init = _extract_initial(
kappa_spec, "GEOPRIOR_INIT_KAPPA", 1.0
)
# Pack the remaining architectural knobs into `architecture_config`
known_keys = {
"embed_dim",
"hidden_units",
"lstm_units",
"attention_units",
"num_heads",
"dropout_rate",
"use_vsn",
"vsn_units",
"use_batch_norm",
"mv",
"kappa",
"gamma_w",
"h_ref",
"kappa_mode",
"use_effective_h",
"scale_pde_residuals",
"attention_levels",
"mode",
"time_steps",
}
architecture_config = {
k: v
for k, v in model_init.items()
if k not in known_keys
}
model = GeoPriorSubsNet(
static_input_dim=static_dim,
dynamic_input_dim=dynamic_dim,
future_input_dim=future_dim,
output_subsidence_dim=out_s_dim,
output_gwl_dim=out_g_dim,
forecast_horizon=horizon,
mode=mode,
attention_levels=attention_levels,
quantiles=q,
# physics switches
pde_mode=pde_mode,
scale_pde_residuals=scale_pde_residuals,
kappa_mode=kappa_mode,
use_effective_h=use_effective_h,
# architecture hyperparameters
embed_dim=embed_dim,
hidden_units=hidden_units,
lstm_units=lstm_units,
attention_units=attention_units,
num_heads=num_heads,
dropout_rate=dropout_rate,
use_vsn=use_vsn,
vsn_units=vsn_units,
use_batch_norm=use_batch_norm,
# geomechanical priors
mv=float(mv_init),
kappa=float(kappa_init),
architecture_config=architecture_config,
)
print(
"[Fallback] Reconstructed GeoPriorSubsNet from training_summary with "
f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
)
return model
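The fallback above resolves each geomechanical prior from three sources in turn: the serialized spec, the config, and a hard-coded default. A minimal standalone sketch of that precedence follows. The name `extract_initial` and the explicit `cfg` argument are illustrative assumptions, not part of the module, and the sketch catches only `TypeError`/`ValueError` where the original catches `Exception`.

```python
from typing import Any


def extract_initial(spec: Any, cfg: dict, cfg_key: str, default: float) -> float:
    """Resolve a scalar: serialized spec first, then config, then default."""
    if isinstance(spec, dict):
        # Same key order as the listing: "initial_value" before "value".
        for key in ("initial_value", "value"):
            if key in spec:
                try:
                    return float(spec[key])
                except (TypeError, ValueError):
                    pass  # unusable entry: fall through to the next source
    if cfg.get(cfg_key) is not None:
        try:
            return float(cfg[cfg_key])
        except (TypeError, ValueError):
            pass
    return float(default)


cfg = {"GEOPRIOR_INIT_KAPPA": "2.5"}
mv = extract_initial({"initial_value": 1e-6}, cfg, "GEOPRIOR_INIT_MV", 5e-7)
kappa = extract_initial({}, cfg, "GEOPRIOR_INIT_KAPPA", 1.0)
fallback = extract_initial({"value": "n/a"}, {}, "GEOPRIOR_INIT_MV", 5e-7)
```

Here `mv` comes from the spec, `kappa` from the config string, and `fallback` from the default because the spec value cannot be parsed.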
def load_geoprior_for_inference(
model_path: str,
manifest: dict,
X_sample: dict,
out_s_dim: int,
out_g_dim: int,
mode: str,
horizon: int,
quantiles: list[float] | None,
city_name: str | None = None,
include_metrics: bool = True,
verbose: int = 1,
):
"""
Convenience wrapper: load (tuned or trained) GeoPriorSubsNet and
compile it for evaluation/inference.
Returns
-------
model :
Compiled model ready for ``predict`` / diagnostics.
info : dict
Small dict with the ``best_hps`` (if any) and the resolved
quantiles, useful for logging.
"""
model, best_hps = load_or_rebuild_geoprior_model(
model_path=model_path,
manifest=manifest,
X_sample=X_sample,
out_s_dim=out_s_dim,
out_g_dim=out_g_dim,
mode=mode,
horizon=horizon,
quantiles=quantiles,
city_name=city_name,
compile_on_load=False,
verbose=verbose,
)
model = compile_for_eval(
model=model,
manifest=manifest,
best_hps=best_hps,
quantiles=quantiles,
include_metrics=include_metrics,
)
info = {
"best_hps": best_hps,
"quantiles": quantiles,
}
return model, info
def load_training_summary_near_model(
model_path: str,
city_name: str | None = None,
) -> dict | None:
"""
Locate and load a ``*_training_summary.json`` next to a trained model.
Strategy
--------
1. Prefer ``<city>_GeoPriorSubsNet_training_summary.json`` if
``city_name`` is given.
2. Fallback: first file in the run directory that ends with
``'_training_summary.json'``.
Parameters
----------
model_path : str
Path to the `.keras` archive or any checkpoint inside a
``train_YYYYMMDD-HHMMSS`` directory.
city_name : str or None
Optional city name to build a more specific candidate filename.
Returns
-------
dict or None
Parsed JSON dict if found and loadable, otherwise ``None``.
"""
run_dir = os.path.dirname(os.path.abspath(model_path))
candidates: list[str] = []
if city_name:
candidates.append(
os.path.join(
run_dir,
f"{city_name}_GeoPriorSubsNet_training_summary.json",
)
)
# Generic fallback: any *_training_summary.json in the run dir
try:
for fname in os.listdir(run_dir):
if fname.endswith("_training_summary.json"):
candidates.append(
os.path.join(run_dir, fname)
)
except FileNotFoundError:
return None
seen: set[str] = set()
for path in candidates:
if path in seen:
continue
seen.add(path)
if os.path.exists(path):
try:
with open(path, encoding="utf-8") as f:
ts = json.load(f)
print(
f"[TrainSummary] Loaded training_summary from: {path}"
)
return ts
except Exception as e:  # pragma: no cover - defensive
print(
f"[Warn] Could not read training_summary JSON at {path!r}: {e}"
)
# Try the next candidate
continue
return None
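The candidate ordering used by `load_training_summary_near_model` can be sketched in isolation. The helper name `find_training_summary` and the city names `alpha` and `beta` are hypothetical. The `sorted()` call is also an addition of this sketch: the listing iterates `os.listdir` directly, whose order is not guaranteed.

```python
import json
import os
import tempfile
from typing import Optional


def find_training_summary(run_dir: str, city_name: Optional[str] = None) -> Optional[str]:
    """Return the first existing candidate path, city-specific name first."""
    candidates = []
    if city_name:
        candidates.append(
            os.path.join(run_dir, f"{city_name}_GeoPriorSubsNet_training_summary.json")
        )
    # Generic fallback; sorted() only makes this sketch deterministic.
    for fname in sorted(os.listdir(run_dir)):
        if fname.endswith("_training_summary.json"):
            candidates.append(os.path.join(run_dir, fname))
    for path in dict.fromkeys(candidates):  # de-duplicate while keeping order
        if os.path.exists(path):
            return path
    return None


with tempfile.TemporaryDirectory() as run_dir:
    for city in ("alpha", "beta"):
        name = f"{city}_GeoPriorSubsNet_training_summary.json"
        with open(os.path.join(run_dir, name), "w", encoding="utf-8") as f:
            json.dump({"city": city}, f)
    picked_city = os.path.basename(find_training_summary(run_dir, "beta"))
    picked_generic = os.path.basename(find_training_summary(run_dir))
```

With a city name the specific file wins. Without one, whichever summary the directory scan reaches first is used, so a run directory holding several summaries is worth avoiding.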
def extract_preds(
model: Any,
out: Any,
*,
strict: bool = True,
output_names: Sequence[str] | None = None,
) -> tuple[Any, Any]:
r"""
Extract (subs_pred, gwl_pred) from GeoPrior outputs.
Supports:
1) v3.2+ call(): {"subs_pred","gwl_pred"}
2) forward_with_aux(): (y_pred, aux)
3) legacy: {"data_final"} + model.split_data_predictions
4) predict(): list/tuple mapped via output names
If `strict=True`, list/tuple outputs *must* be mappable via
output names; otherwise we raise to avoid silent swaps.
This helper normalizes the output interface across two
GeoPrior generation families:
1. New interface (preferred)
``model(inputs) -> {"subs_pred": ..., "gwl_pred": ...}``
2. Legacy interface (backward compatible)
``model(inputs) -> {"data_final": ...}``, where the caller
must split the tensor using ``model.split_data_predictions``.
Parameters
----------
model : object
A Keras-like model instance that may expose
``split_data_predictions(data_final)``.
The splitter must return a tuple:
- ``subs_pred`` with shape ``(B, H, 1)`` or ``(B, H, Q, 1)``
- ``gwl_pred`` with shape ``(B, H, 1)`` or ``(B, H, Q, 1)``
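out : Mapping, tuple, or list
Output returned by the model, typically
``model(inputs, training=False)``.
Supported forms are:
- a mapping with keys ``{"subs_pred", "gwl_pred"}`` (new interface);
- a mapping with key ``{"data_final"}`` (legacy interface);
- a single-key mapping wrapping one of the above;
- a ``(y_pred, aux)`` tuple from ``forward_with_aux()``;
- a list/tuple from ``predict()``, mapped via output names.
strict : bool, default=True
If True, list/tuple outputs must be mappable via output names.
If False, the first two entries are assumed to be
``(subs_pred, gwl_pred)`` when no name mapping is possible.
output_names : sequence of str or None, default=None
Names used to map list/tuple outputs. If None,
``model.output_names`` is used when available.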
Returns
-------
subs_pred : Tensor
Predicted subsidence in model space.
Expected shapes:
- Point mode: ``(B, H, 1)``
- Quantile mode: ``(B, H, Q, 1)``
gwl_pred : Tensor
Predicted groundwater/head variable in model space.
Expected shapes:
- Point mode: ``(B, H, 1)``
- Quantile mode: ``(B, H, Q, 1)``
Raises
------
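KeyError
If ``out`` is a mapping that does not contain a supported key set.
TypeError
If ``out`` is a list/tuple that cannot be mapped to output names
and positional fallback is not allowed, or if ``out`` is of an
unsupported type.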
Notes
-----
This function is intended for Stage-2 and Stage-3
scripts where you may load checkpoints from older
experiments. It avoids fragile code that slices
``data_final`` manually.
The function does not validate tensor dtypes or
numerical finiteness. Upstream code should handle
``NaN`` and ``Inf`` checks as needed. Output normalization
follows the Keras model conventions documented in
:cite:t:`KerasDocs`.
Examples
--------
New interface::
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
model_inf,
out,
)
Legacy interface::
out = model_inf(xb, training=False)
s_pred, h_pred = extract_preds(
model_inf,
out,
)
See Also
--------
subs_point_from_out :
Convert subsidence predictions to a point forecast.
"""
# ---------------------------------------------------------
# 0) forward_with_aux() style: (y_pred, aux)
# ---------------------------------------------------------
if isinstance(out, tuple) and len(out) == 2:
y_pred, aux = out
if isinstance(y_pred, Mapping):
out = y_pred
elif isinstance(aux, Mapping):
# fallback: sometimes callers pass aux by mistake
out = aux
# ---------------------------------------------------------
# 1) Mapping outputs
# ---------------------------------------------------------
if isinstance(out, Mapping):
has_new = ("subs_pred" in out) and ("gwl_pred" in out)
if has_new:
return out["subs_pred"], out["gwl_pred"]
# Legacy: data_final -> split
if "data_final" in out and hasattr(
model, "split_data_predictions"
):
return model.split_data_predictions(
out["data_final"]
)
# Single-key wrapper: unwrap one level and retry
if len(out) == 1:
only_val = next(iter(out.values()))
return extract_preds(
model,
only_val,
strict=strict,
output_names=output_names,
)
raise KeyError(
"Unsupported model output keys. Expected "
"{'subs_pred','gwl_pred'} or {'data_final'} "
"or a single-key wrapper. "
f"Got keys={list(out.keys())!r}."
)
# ---------------------------------------------------------
# 2) predict() outputs as list/tuple
# ---------------------------------------------------------
if isinstance(out, list | tuple):
names = None
if output_names is not None:
names = list(output_names)
else:
names = getattr(model, "output_names", None)
if names and len(names) == len(out):
mapped = dict(zip(names, out, strict=False))
return extract_preds(
model,
mapped,
strict=strict,
output_names=names,
)
if not strict and len(out) >= 2:
# last-resort, opt-in only
return out[0], out[1]
raise TypeError(
"Model output is a list/tuple but cannot be mapped "
"to names. Provide `output_names=...` or set "
"`strict=False` to assume order."
)
raise TypeError(
"Expected `out` as Mapping, (y_pred, aux), "
"or list/tuple. "
f"Got type={type(out)!r}."
)
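The reason for the `strict` switch is easiest to see on a reduced version of the dispatcher. The sketch below is a simplification under stated assumptions: `normalize_outputs` is a hypothetical name, strings stand in for tensors, and the `(y_pred, aux)`, `data_final`, and single-key branches of `extract_preds` are omitted.

```python
from collections.abc import Mapping


def normalize_outputs(out, output_names=None, strict=True):
    """Return (subs, gwl) from a mapping or a named list/tuple."""
    if isinstance(out, Mapping):
        if "subs_pred" in out and "gwl_pred" in out:
            return out["subs_pred"], out["gwl_pred"]
        raise KeyError(f"Unsupported keys: {list(out)!r}")
    if isinstance(out, (list, tuple)):
        if output_names and len(output_names) == len(out):
            # Map by name, then reuse the mapping branch.
            return normalize_outputs(dict(zip(output_names, out)))
        if not strict and len(out) >= 2:
            return out[0], out[1]  # opt-in positional guess
        raise TypeError("Cannot map list/tuple output to names.")
    raise TypeError(f"Unsupported output type: {type(out)!r}")


# "S" and "G" are placeholders for the subsidence and groundwater tensors.
by_key = normalize_outputs({"gwl_pred": "G", "subs_pred": "S"})
by_name = normalize_outputs(["G", "S"], output_names=["gwl_pred", "subs_pred"])
by_order = normalize_outputs(["G", "S"], strict=False)
```

When the list arrives in `(gwl, subs)` order, the name-based paths still return `("S", "G")`, while the positional fallback returns `("G", "S")`. That silent swap is what `strict=True` guards against.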
def subs_point_from_out(
model, out, quantiles=None, med_idx=None
):
r"""
Convert model output into a subsidence point forecast.
This helper produces a subsidence tensor shaped ``(B, H, 1)``
in model space, regardless of whether the model emits
quantiles or a point prediction.
- If quantiles are present and the subsidence prediction
is shaped ``(B, H, Q, 1)``, the function selects the
median quantile slice.
- Otherwise, it returns the point prediction directly.
Parameters
----------
model : object
A Keras-like model instance passed to
:func:`extract_preds`.
out : dict
Output returned by the model call.
This can be either the new interface with keys
``"subs_pred"`` and ``"gwl_pred"``, or the legacy
interface with key ``"data_final"``.
quantiles : sequence of float or None, default=None
Quantile levels used by the model, such as
``[0.1, 0.5, 0.9]``.
If provided, the function may use it to interpret
the rank-4 quantile output and select the median.
If ``None``, quantile selection is disabled unless
``med_idx`` is explicitly provided and the tensor
rank indicates quantiles.
med_idx : int or None, default=None
Index along the quantile axis to use as the
"point" forecast when quantiles are available.
If ``None`` and ``quantiles`` is provided, the
function selects the index closest to ``0.5``.
Returns
-------
subs_point : Tensor
Subsidence point prediction in model space with
shape ``(B, H, 1)``.
Raises
------
ValueError
If subsidence prediction is missing or ``None``.
ValueError
If a quantile tensor is detected but a valid
median index cannot be resolved.
Notes
-----
Quantile outputs are assumed to be shaped
``(B, H, Q, 1)`` where the quantile axis is the
third dimension (axis=2).
If the model returns point predictions already,
the function is effectively a no-op. The quantile
interpretation used here follows
:cite:t:`KoenkerBassett1978`.
Examples
--------
Quantile model::
out = model_inf(xb, training=False)
s_point = subs_point_from_out(
model_inf,
out,
quantiles=[0.1, 0.5, 0.9],
)
Point model::
out = model_inf(xb, training=False)
s_point = subs_point_from_out(
model_inf,
out,
)
See Also
--------
extract_preds :
Normalize outputs across new and legacy checkpoints.
"""
subs_pred, _ = extract_preds(model, out)
if subs_pred is None:
raise ValueError("Model output 'subs_pred' is None.")
# has_rank = hasattr(subs_pred, "shape") and (
# getattr(subs_pred.shape, "rank", None) is not None
# )
# is_quantile_tensor = has_rank and (subs_pred.shape.rank == 4)
rank = None
if hasattr(subs_pred, "shape"):
rank = getattr(
subs_pred.shape, "rank", None
) # TF tensors
if rank is None:
try:
rank = len(
subs_pred.shape
) # NumPy arrays / tuples
except Exception:
rank = None
is_quantile_tensor = rank == 4
if not is_quantile_tensor:
return subs_pred
if med_idx is None:
if not quantiles:
raise ValueError(
"Quantile tensor detected but `med_idx` "
"is None and `quantiles` is not provided."
)
q = np.asarray(quantiles, dtype=float)
med_idx = int(np.argmin(np.abs(q - 0.5)))
if med_idx is None or int(med_idx) < 0:
raise ValueError(
"Invalid `med_idx` resolved for quantiles."
)
# return subs_pred[..., int(med_idx), :]
# Quantile outputs assumed (B, H, Q, 1)
return subs_pred[:, :, int(med_idx), :]
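The median selection can be illustrated without TensorFlow or NumPy. In this sketch nested lists stand in for a `(B, H, Q, 1)` tensor, and `median_index` and `select_quantile` are hypothetical helper names. The tie-breaking rule (first closest level) matches what `np.argmin` does in the listing.

```python
def median_index(quantiles):
    """Index of the quantile level closest to 0.5 (first one on ties)."""
    if not quantiles:
        raise ValueError("`quantiles` must be a non-empty sequence.")
    return min(range(len(quantiles)), key=lambda i: abs(float(quantiles[i]) - 0.5))


def select_quantile(preds, idx):
    """Pick index `idx` on axis 2 of a nested list shaped (B, H, Q, 1)."""
    return [[step[idx] for step in sample] for sample in preds]


# B=1, H=2, Q=3, trailing axis of size 1
preds = [[[[0.1], [0.5], [0.9]], [[1.1], [1.5], [1.9]]]]
idx = median_index([0.1, 0.5, 0.9])
point = select_quantile(preds, idx)   # shaped (B, H, 1)
tie = median_index([0.25, 0.75])      # both levels are equally far from 0.5
```

If the quantile list has no exact `0.5`, the closest level is used as the point forecast, so passing `med_idx` explicitly may be clearer in that case.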
def _extract_allowed_hps(
obj: object,
*,
allowed: set[str],
) -> dict:
out: dict = {}
def _walk(x: object) -> None:
if isinstance(x, dict):
for k, v in x.items():
if isinstance(k, str) and k in allowed:
out[k] = v
_walk(v)
elif isinstance(x, list | tuple):
for it in x:
_walk(it)
_walk(obj)
return {k: out[k] for k in allowed if k in out}
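One property of this recursive walk is easy to miss: a matching key is recorded before its value is descended into, so a later or more deeply nested occurrence overwrites an earlier one. The sketch below reproduces the traversal under a hypothetical name, `extract_allowed`, on a made-up summary dict.

```python
def extract_allowed(obj, allowed):
    """Collect allowed keys found anywhere in nested dicts/lists."""
    found = {}

    def walk(node):
        if isinstance(node, dict):
            for key, value in node.items():
                if isinstance(key, str) and key in allowed:
                    found[key] = value  # later matches overwrite earlier ones
                walk(value)
        elif isinstance(node, (list, tuple)):
            for item in node:
                walk(item)

    walk(obj)
    return found


summary = {
    "embed_dim": 16,                       # top-level value, visited first
    "run": {"epochs": 50},                 # not in `allowed`, ignored
    "model_init": {"embed_dim": 32, "dropout_rate": 0.1},
}
hps = extract_allowed(summary, {"embed_dim", "dropout_rate"})
```

The nested `embed_dim` of 32 wins over the top-level 16 because it is visited later.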
def load_tuned_hps_near_model(
model_path: str,
*,
prefer: str = "keras",
required: bool = True,
log_fn=None,
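) -> dict:
"""
Load tuned hyperparameters saved next to a model artifact.
Candidates are tried in order: ``<stem>_best_hps.json`` (stem derived
from the artifact filename according to ``prefer``), then
``tuning_summary.json``, then any ``*tuning_summary*.json`` in the run
directory. The first JSON with a non-empty ``best_hps`` (or ``hps``)
dict wins. Raises ``FileNotFoundError`` if nothing is found and
``required=True``; otherwise returns ``{}``.
"""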
log = log_fn if callable(log_fn) else None
def _msg(s: str) -> None:
if log is not None:
log(s)
def _load_json(p: str) -> dict:
with open(p, encoding="utf-8") as f:
return json.load(f)
mp = os.path.abspath(str(model_path))
run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
base = "" if os.path.isdir(mp) else os.path.basename(mp)
stem = None
if prefer == "keras":
if base.endswith("_best.keras"):
stem = base[: -len("_best.keras")]
elif base.endswith(".keras"):
stem = base[: -len(".keras")]
else:
if base.endswith("_best.weights.h5"):
stem = base[: -len("_best.weights.h5")]
elif base.endswith(".weights.h5"):
stem = base[: -len(".weights.h5")]
cands: list[str] = []
if stem:
cands.append(
os.path.join(run_dir, stem + "_best_hps.json")
)
cands.append(os.path.join(run_dir, "tuning_summary.json"))
cands.extend(
glob.glob(
os.path.join(run_dir, "*tuning_summary*.json")
)
)
for p in cands:
if not os.path.exists(p):
continue
data = _load_json(p)
hps = data.get("best_hps") or data.get("hps") or {}
if isinstance(hps, dict) and hps:
_msg(f"[HP] tuned: {p}")
return hps
if required:
raise FileNotFoundError(
"No tuned hyperparameters found near:\n"
f" model_path={model_path!r}\n"
f" run_dir={run_dir!r}\n"
)
return {}
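The stem derivation depends on testing the longer suffix first. Otherwise `_best.keras` files would be stripped only of `.keras` and the `_best` token would leak into the candidate filename. A compact sketch follows, with `artifact_stem` and the `demo` filenames as illustrative assumptions.

```python
from typing import Optional


def artifact_stem(filename: str, prefer: str = "keras") -> Optional[str]:
    """Strip the artifact suffix, trying the longer '_best' variant first."""
    if prefer == "keras":
        suffixes = ("_best.keras", ".keras")
    else:
        suffixes = ("_best.weights.h5", ".weights.h5")
    for suffix in suffixes:
        if filename.endswith(suffix):
            return filename[: -len(suffix)]
    return None  # filename does not match the preferred artifact type


best = artifact_stem("demo_GeoPriorSubsNet_H3_best.keras")
plain = artifact_stem("demo.keras")
weights = artifact_stem("demo_best.weights.h5", prefer="weights")
mismatch = artifact_stem("demo_best.weights.h5")  # prefer="keras"
```

A `None` stem is not an error in the loaders above. They simply skip the `<stem>_best_hps.json` candidate and move on to the tuning summaries.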
def load_trained_hps_near_model(
model_path: str,
*,
allowed: set[str],
required: bool = False,
log_fn=None,
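) -> dict:
"""
Load hyperparameters of a trained (untuned) model from its run directory.
Scans ``model_init_manifest.json``, ``*training_summary*.json`` and
``*architecture*.json`` in that order and returns the first non-empty
set of ``allowed`` keys found anywhere in the JSON (nested dicts and
lists are searched). Raises ``FileNotFoundError`` if nothing is found
and ``required=True``; otherwise returns ``{}``.
"""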
log = log_fn if callable(log_fn) else None
def _msg(s: str) -> None:
if log is not None:
log(s)
def _load_json(p: str) -> object:
with open(p, encoding="utf-8") as f:
return json.load(f)
mp = os.path.abspath(str(model_path))
run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
pats = [
os.path.join(run_dir, "model_init_manifest.json"),
os.path.join(run_dir, "*training_summary*.json"),
os.path.join(run_dir, "*architecture*.json"),
]
files: list[str] = []
for pat in pats:
files.extend(glob.glob(pat))
for p in files:
if not os.path.exists(p):
continue
data = _load_json(p)
hps = _extract_allowed_hps(data, allowed=allowed)
if hps:
_msg(f"[HP] trained: {p}")
return hps
if required:
raise FileNotFoundError(
"No trained hyperparameters found near:\n"
f" model_path={model_path!r}\n"
f" run_dir={run_dir!r}\n"
)
return {}
def load_hps_auto_near_model(
model_path: str,
*,
allowed: set[str],
prefer: str = "keras",
required: bool = False,
log_fn=None,
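) -> dict:
"""
Dispatch to the tuned or trained hyperparameter loader.
A run is treated as tuned when the run directory contains
``*_best_hps.json`` or ``*tuning_summary*.json``, or when the
substring ``"tuning"`` occurs in the run directory path.
"""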
mp = os.path.abspath(str(model_path))
run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
tuned_hits = []
tuned_hits.extend(
glob.glob(os.path.join(run_dir, "*_best_hps.json"))
)
tuned_hits.extend(
glob.glob(
os.path.join(run_dir, "*tuning_summary*.json")
)
)
is_tuned = bool(tuned_hits) or ("tuning" in run_dir)
if is_tuned:
return load_tuned_hps_near_model(
model_path,
prefer=prefer,
required=required,
log_fn=log_fn,
)
return load_trained_hps_near_model(
model_path,
allowed=allowed,
required=required,
log_fn=log_fn,
)
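The tuned-versus-trained decision can be tried on a throwaway directory. In this sketch `looks_tuned` is a hypothetical name for the inline check, and the file and directory names are invented. Note that the substring test applies to the whole absolute path, so a parent folder containing `tuning` also triggers it.

```python
import glob
import os
import tempfile


def looks_tuned(run_dir: str) -> bool:
    """True if tuning artifacts exist or 'tuning' occurs in the path."""
    hits = glob.glob(os.path.join(run_dir, "*_best_hps.json"))
    hits += glob.glob(os.path.join(run_dir, "*tuning_summary*.json"))
    return bool(hits) or ("tuning" in run_dir)


with tempfile.TemporaryDirectory(prefix="train_") as run_dir:
    before = looks_tuned(run_dir)          # no tuning artifacts yet
    open(os.path.join(run_dir, "demo_best_hps.json"), "w").close()
    after = looks_tuned(run_dir)           # artifact now present

# The directory need not exist: the path substring alone is enough.
by_path = looks_tuned(os.path.join("runs", "tuning_demo", "train_001"))
```

Because an empty `*_best_hps.json` is enough to route a run to the tuned loader, callers that pass `required=True` may see a `FileNotFoundError` from that loader even when trained hyperparameters are available.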
def load_best_hps_near_model(
model_path: str,
*,
model_name: str | None = "GeoPriorSubsNet",
prefer: str = "keras",
log_fn=None,
) -> dict:
"""
Load best hyperparameters saved near a model artifact.
Supports model names like:
<city>_<model_name>_H{H}_best.keras
<city>_<model_name>_H{H}_best.weights.h5
Parameters
----------
model_path : str
Path to a model file or its run directory.
model_name : str or None, default="GeoPriorSubsNet"
Model name token in filenames.
prefer : {"keras", "weights"}, default="keras"
Which artifact type to infer the prefix from.
log_fn : callable or None, default=None
Logger (e.g. print). None disables logs.
Returns
-------
best_hps : dict
Non-empty hyperparameter dictionary.
Raises
------
FileNotFoundError
If no hyperparameter JSON is found.
ValueError
If a candidate JSON exists but is empty/invalid.
"""
log = log_fn if callable(log_fn) else None
if prefer not in ("keras", "weights"):
raise ValueError(
"prefer must be 'keras' or 'weights'."
)
mp = os.path.abspath(str(model_path))
run_dir = mp if os.path.isdir(mp) else os.path.dirname(mp)
base = "" if os.path.isdir(mp) else os.path.basename(mp)
def _msg(s: str) -> None:
if log is not None:
log(s)
def _load_json(p: str) -> dict:
with open(p, encoding="utf-8") as f:
return json.load(f)
def _newest(paths: list[str]) -> str | None:
c = []
for p in paths:
try:
c.append((os.path.getmtime(p), p))
except Exception:
pass
if not c:
return None
c.sort(reverse=True)
return c[0][1]
# -------------------------------------------------
# If a directory is provided, infer "base" by scan.
# -------------------------------------------------
if not base:
pats = []
if prefer == "keras":
if model_name:
pats.append(
os.path.join(
run_dir,
f"*_{model_name}_H*_best.keras",
)
)
pats.append(
os.path.join(run_dir, "*_H*_best.keras")
)
pats.append(os.path.join(run_dir, "*_best.keras"))
else:
if model_name:
pats.append(
os.path.join(
run_dir,
f"*_{model_name}_H*_best.weights.h5",
)
)
pats.append(
os.path.join(
run_dir,
"*_H*_best.weights.h5",
)
)
pats.append(
os.path.join(run_dir, "*_best.weights.h5")
)
hits = []
for pat in pats:
hits.extend(glob.glob(pat))
best = _newest(hits)
if best:
base = os.path.basename(best)
# -------------------------------------------------
# Infer stem/prefix from the artifact filename.
# -------------------------------------------------
stem = None
if prefer == "keras":
if base.endswith("_best.keras"):
stem = base[: -len("_best.keras")]
elif base.endswith(".keras"):
stem = base[: -len(".keras")]
else:
if base.endswith("_best.weights.h5"):
stem = base[: -len("_best.weights.h5")]
elif base.endswith(".weights.h5"):
stem = base[: -len(".weights.h5")]
# Try to parse city / model / horizon from stem.
city = None
mname = None
horizon = None
city_model = None
if stem:
left = stem
if "_H" in left:
a, b = left.rsplit("_H", 1)
digs = []
for ch in b:
if ch.isdigit():
digs.append(ch)
else:
break
if digs:
horizon = int("".join(digs))
left = a
city_model = left
if model_name:
tok = "_" + str(model_name)
if city_model.endswith(tok):
city = city_model[: -len(tok)]
mname = str(model_name)
if city is None:
parts = city_model.split("_")
if len(parts) >= 2:
city = parts[0]
mname = "_".join(parts[1:])
# -------------------------------------------------
# 1) Near-model explicit JSONs.
# -------------------------------------------------
cands = []
if stem:
cands.append(
os.path.join(run_dir, stem + "_best_hps.json")
)
if city and mname:
cands.append(
os.path.join(
run_dir,
f"{city}_{mname}_best_hps.json",
)
)
if horizon is not None:
cands.append(
os.path.join(
run_dir,
f"{city}_{mname}_H{horizon}_best_hps.json",
)
)
if city:
cands.append(
os.path.join(
run_dir,
f"{city}_{mname}_best_hps.json",
)
)
seen = set()
for p in cands:
if p in seen:
continue
seen.add(p)
if os.path.exists(p):
best_hps = _load_json(p)
if isinstance(best_hps, dict) and best_hps:
_msg(f"[HP] Loaded best_hps: {p}")
return best_hps
raise ValueError(
f"{p!r} exists but is empty/invalid."
)
# -------------------------------------------------
# 2) tuning summaries.
# -------------------------------------------------
sum_pats = [
os.path.join(run_dir, "tuning_summary.json"),
os.path.join(run_dir, "*tuning_summary*.json"),
]
sums = []
for pat in sum_pats:
sums.extend(glob.glob(pat))
for p in sums:
if not os.path.exists(p):
continue
s = _load_json(p)
best_hps = s.get("best_hps") or s.get("hps") or {}
if isinstance(best_hps, dict) and best_hps:
_msg(f"[HP] Loaded best_hps: {p}")
return best_hps
# -------------------------------------------------
# 3) training summaries.
# -------------------------------------------------
for p in glob.glob(
os.path.join(run_dir, "*training_summary*.json")
):
s = _load_json(p)
best_hps = (
s.get("best_hps")
or s.get("hps")
or s.get("params")
or {}
)
if isinstance(best_hps, dict) and best_hps:
_msg(f"[HP] Loaded best_hps: {p}")
return best_hps
# -------------------------------------------------
# 4) architecture dumps.
# -------------------------------------------------
for p in glob.glob(
os.path.join(run_dir, "*architecture*.json")
):
a = _load_json(p)
best_hps = (
a.get("best_hps")
or a.get("hps")
or a.get("params")
or {}
)
if isinstance(best_hps, dict) and best_hps:
_msg(f"[HP] Loaded best_hps: {p}")
return best_hps
raise FileNotFoundError(
"Could not find best hyperparameters near:\n"
f" model_path={model_path!r}\n"
f" run_dir={run_dir!r}\n"
f" prefer={prefer!r}\n"
"Looked for *_best_hps.json + summaries."
)
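The filename parsing in step "Infer stem/prefix" can be sketched as a pure function. `parse_stem` and the example stems are hypothetical. Because the listing has lost its indentation, the sketch assumes that the horizon suffix is dropped from the stem only when digits actually follow `_H`.

```python
from typing import Optional, Tuple


def parse_stem(
    stem: str, model_name: Optional[str] = "GeoPriorSubsNet"
) -> Tuple[Optional[str], Optional[str], Optional[int]]:
    """Split '<city>_<model>_H<horizon>' into (city, model, horizon)."""
    horizon = None
    left = stem
    if "_H" in left:
        head, tail = left.rsplit("_H", 1)
        digits = ""
        for ch in tail:
            if not ch.isdigit():
                break
            digits += ch
        if digits:  # assumption: only strip '_H...' when digits follow it
            horizon = int(digits)
            left = head
    city, mname = None, None
    token = "_" + model_name if model_name else None
    if token and left.endswith(token):
        city, mname = left[: -len(token)], model_name
    else:
        # Fallback: first token is the city, the rest is the model name.
        parts = left.split("_")
        if len(parts) >= 2:
            city, mname = parts[0], "_".join(parts[1:])
    return city, mname, horizon


known = parse_stem("demo_city_GeoPriorSubsNet_H12")
other = parse_stem("demo_OtherNet_H3x")
bare = parse_stem("demo", model_name=None)
```

Passing `model_name` matters for cities whose names contain underscores: without the token match, the fallback split would treat only the first token as the city.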
def coerce_quantile_weights(
d: dict | None,
default: dict,
) -> dict:
"""
Normalize a quantile-weight mapping to have float keys and float values.
This helper is useful when reading JSON configs where the quantile
keys are stored as strings (e.g. ``{'0.1': 3.0, '0.5': 1.0}``).
Parameters
----------
d : dict or None
Original dictionary mapping quantile-like keys (str or float) to
numeric weights. If ``None`` or empty, ``default`` is returned.
default : dict
Fallback dictionary to use when ``d`` is ``None`` or empty.
Returns
-------
out : dict
Dictionary with the same keys and values, but with:
- keys coerced to float when possible (otherwise left as-is),
- values coerced to ``float``.
"""
if not d:
return default
out: dict[Any, float] = {}
for k, v in d.items():
try:
q = float(k)
except (TypeError, ValueError):
# Non-numeric key (rare): keep as-is
q = k
out[q] = float(v)
return out
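A short round trip shows why the coercion is needed: JSON object keys are always strings, so quantile weights read back from a config no longer match float quantile levels. The function below mirrors the listing under a hypothetical name, `coerce_weights`, and the JSON payload is invented for the example.

```python
import json


def coerce_weights(d, default):
    """Return float-keyed, float-valued weights, or `default` if `d` is empty."""
    if not d:
        return default
    out = {}
    for k, v in d.items():
        try:
            q = float(k)
        except (TypeError, ValueError):
            q = k  # non-numeric key: keep as-is
        out[q] = float(v)
    return out


raw = json.loads('{"0.1": 3, "0.5": 1.0, "0.9": "3.0"}')
weights = coerce_weights(raw, {0.5: 1.0})
missing = coerce_weights(None, {0.5: 1.0})
```

After coercion, a lookup such as `weights[0.1]` works, whereas `raw[0.1]` would raise `KeyError`.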
def compile_for_eval(
model: Any,
manifest: dict,
best_hps: dict | None,
quantiles: list[float] | None,
*,
include_metrics: bool = True,
) -> Any:
"""
Recompile a GeoPriorSubsNet instance for evaluation / diagnostics.
This is intended for:
- tuned models loaded from a `.keras` archive, or
- models rebuilt from best_hps.
It does NOT change the architecture or weights, only the compile
configuration (optimizer, losses, and physics loss weights).
Parameters
----------
model : GeoPriorSubsNet
Loaded or freshly-built GeoPriorSubsNet instance.
manifest : dict
Stage-1 manifest; training config is taken from
``manifest['config']``.
best_hps : dict or None
Dictionary of tuned hyperparameters. If empty/None, reasonable
defaults are inferred from the manifest.
quantiles : list of float or None
Quantiles used for probabilistic subsidence/GWL outputs.
include_metrics : bool, default=True
If True, attach MAE/MSE + coverage/sharpness metrics to match
the training script; if False, only losses are configured.
Returns
-------
model :
The same model instance, compiled in-place.
"""
if not TF_AVAILABLE:
raise ImportError(
"TensorFlow is required to compile GeoPriorSubsNet. "
"Install `tensorflow>=2.12` to use "
"`compile_for_eval`."
)
# Local imports so nat_utils.py itself stays lightweight
from geoprior.nn.losses import make_weighted_pinball
if include_metrics:
from geoprior.nn.keras_metrics import (
coverage80_fn,
sharpness80_fn,
)
cfg = manifest.get("config", {}) or {}
best_hps = best_hps or {}
# ---- 1. Data loss weights / quantile weights -------------------------
subs_raw = cfg.get(
"SUBS_WEIGHTS",
{0.1: 3.0, 0.5: 1.0, 0.9: 3.0},
)
gwl_raw = cfg.get(
"GWL_WEIGHTS",
{0.1: 1.5, 0.5: 1.0, 0.9: 1.5},
)
subs_w = coerce_quantile_weights(
subs_raw, {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
)
gwl_w = coerce_quantile_weights(
gwl_raw, {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
)
if quantiles:
loss_dict = {
"subs_pred": make_weighted_pinball(
quantiles, subs_w
),
"gwl_pred": make_weighted_pinball(
quantiles, gwl_w
),
}
else:
mse = tf.keras.losses.MeanSquaredError()
loss_dict = {"subs_pred": mse, "gwl_pred": mse}
loss_weights = {"subs_pred": 1.0, "gwl_pred": 0.5}
# ---- 2. Physics weights: prefer best_hps, fall back to config --------
def _hp_or_cfg(
hp_key: str, cfg_key: str, default: float
) -> float:
if (
hp_key in best_hps
and best_hps[hp_key] is not None
):
return float(best_hps[hp_key])
if cfg_key in cfg and cfg[cfg_key] is not None:
return float(cfg[cfg_key])
return float(default)
lr = _hp_or_cfg("learning_rate", "LEARNING_RATE", 1e-4)
physics_kwargs = {
"lambda_gw": _hp_or_cfg(
"lambda_gw", "LAMBDA_GW", 1.0
),
"lambda_cons": _hp_or_cfg(
"lambda_cons", "LAMBDA_CONS", 1.0
),
"lambda_prior": _hp_or_cfg(
"lambda_prior", "LAMBDA_PRIOR", 0.1
),
"lambda_smooth": _hp_or_cfg(
"lambda_smooth", "LAMBDA_SMOOTH", 0.01
),
"lambda_mv": _hp_or_cfg(
"lambda_mv", "LAMBDA_MV", 0.0
),
"mv_lr_mult": _hp_or_cfg(
"mv_lr_mult", "MV_LR_MULT", 1.0
),
"kappa_lr_mult": _hp_or_cfg(
"kappa_lr_mult", "KAPPA_LR_MULT", 1.0
),
}
compile_kwargs: dict[str, Any] = {
"optimizer": Adam(learning_rate=lr),
"loss": loss_dict,
"loss_weights": loss_weights,
**physics_kwargs,
}
if include_metrics:
metrics_dict = {
"subs_pred": ["mae", "mse"]
+ (
[coverage80_fn, sharpness80_fn]
if quantiles
else []
),
"gwl_pred": ["mae", "mse"],
}
compile_kwargs["metrics"] = metrics_dict
model.compile(**compile_kwargs)
return model
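The physics weights in `compile_for_eval` are resolved with a three-level precedence: tuned value, then config value, then default, where an explicit `None` counts as absent. A standalone sketch follows. The name `hp_or_cfg` with explicit dict arguments is an assumption of this example, and the numbers are invented.

```python
def hp_or_cfg(best_hps, cfg, hp_key, cfg_key, default):
    """Tuned value if set, else config value if set, else the default."""
    if best_hps.get(hp_key) is not None:
        return float(best_hps[hp_key])
    if cfg.get(cfg_key) is not None:
        return float(cfg[cfg_key])
    return float(default)


best_hps = {"lambda_gw": 0.3, "lambda_prior": None}
cfg = {"LAMBDA_GW": 1.0, "LAMBDA_PRIOR": 0.2}

lambda_gw = hp_or_cfg(best_hps, cfg, "lambda_gw", "LAMBDA_GW", 1.0)
lambda_prior = hp_or_cfg(best_hps, cfg, "lambda_prior", "LAMBDA_PRIOR", 0.1)
lambda_smooth = hp_or_cfg(best_hps, cfg, "lambda_smooth", "LAMBDA_SMOOTH", 0.01)
```

A plain `best_hps.get(key, fallback)` behaves differently in the `lambda_prior` case: the key exists, so `None` would be returned and `float(None)` would raise `TypeError`.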
def compile_geoprior_for_eval(
model: Any, # type: ignore[override]
manifest: dict,
best_hps: dict,
quantiles: list[float] | None,
) -> Any:
"""
(Re)compile a GeoPriorSubsNet-like model for evaluation.
This helper uses the Stage-1 manifest and tuned hyperparameters to
configure:
- the pinball losses for subsidence and GWL outputs,
- loss weights for the two heads,
- physics loss weights (lambda_*),
- learning rate and LR multipliers.
TensorFlow and geoprior are imported lazily inside this function so
that ``nat_utils`` can be imported even in non-TF environments.
Parameters
----------
model : GeoPriorSubsNet-like
An instance of the GeoPriorSubsNet model (or any model exposing
the same compile signature).
manifest : dict
Stage-1 manifest dictionary. The ``config`` entry is used to
retrieve default loss weights and physics settings.
best_hps : dict
Hyperparameters loaded from the tuning run
(e.g. via :func:`load_best_hps_near_model`).
quantiles : list of float or None
Quantile levels used for probabilistic outputs. If ``None``,
mean-squared error is used instead of pinball loss.
Returns
-------
model
The same model instance, compiled in-place.
Raises
------
ImportError
If TensorFlow or geoprior's ``make_weighted_pinball`` cannot be
imported.
"""
cfg = manifest.get("config", {}) or {}
# Lazy imports so nat_utils.py is importable without TensorFlow
try:
import tensorflow as tf # type: ignore
from tensorflow.keras.optimizers import (
Adam, # type: ignore
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"compile_geoprior_for_eval requires TensorFlow. "
"Please install 'tensorflow>=2.12' to use this helper."
) from e
try:
from geoprior.nn.losses import (
make_weighted_pinball, # type: ignore
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"compile_geoprior_for_eval requires "
"'geoprior.nn.losses.make_weighted_pinball'. "
"Ensure geoprior is installed and importable."
) from e
# Base loss weights between subsidence and GWL heads
loss_weights = {"subs_pred": 1.0, "gwl_pred": 0.5}
# Quantile-specific weights from config (with robust defaults)
subs_raw = cfg.get(
"SUBS_WEIGHTS", {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
)
gwl_raw = cfg.get(
"GWL_WEIGHTS", {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
)
subs_weights = coerce_quantile_weights(
subs_raw, {0.1: 3.0, 0.5: 1.0, 0.9: 3.0}
)
gwl_weights = coerce_quantile_weights(
gwl_raw, {0.1: 1.5, 0.5: 1.0, 0.9: 1.5}
)
if quantiles:
loss_subs = make_weighted_pinball(
quantiles, subs_weights
)
loss_gwl = make_weighted_pinball(
quantiles, gwl_weights
)
else:
loss_subs = tf.keras.losses.MSE
loss_gwl = tf.keras.losses.MSE
loss_dict = {"subs_pred": loss_subs, "gwl_pred": loss_gwl}
# Learning rate: tuned value or config fallback
lr_default = cfg.get("LEARNING_RATE", 5e-5)
lr = float(best_hps.get("learning_rate", lr_default))
optimizer = Adam(learning_rate=lr)
# Physics loss weights and LR multipliers
lambda_gw = float(
best_hps.get("lambda_gw", cfg.get("LAMBDA_GW", 1.0))
)
lambda_cons = float(
best_hps.get(
"lambda_cons", cfg.get("LAMBDA_CONS", 1.0)
)
)
lambda_prior = float(
best_hps.get(
"lambda_prior", cfg.get("LAMBDA_PRIOR", 1.0)
)
)
lambda_smooth = float(
best_hps.get(
"lambda_smooth", cfg.get("LAMBDA_SMOOTH", 1.0)
)
)
lambda_mv = float(
best_hps.get("lambda_mv", cfg.get("LAMBDA_MV", 0.0))
)
mv_lr_mult = float(
best_hps.get("mv_lr_mult", cfg.get("MV_LR_MULT", 1.0))
)
kappa_lr_mult = float(
best_hps.get(
"kappa_lr_mult", cfg.get("KAPPA_LR_MULT", 1.0)
)
)
model.compile(
optimizer=optimizer,
loss=loss_dict,
loss_weights=loss_weights,
# physics loss weights + LR multipliers
lambda_gw=lambda_gw,
lambda_cons=lambda_cons,
lambda_prior=lambda_prior,
lambda_smooth=lambda_smooth,
lambda_mv=lambda_mv,
mv_lr_mult=mv_lr_mult,
kappa_lr_mult=kappa_lr_mult,
)
return model
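For readers unfamiliar with the loss that `make_weighted_pinball` configures, a generic weighted pinball (quantile) loss for one scalar target is sketched below. This is not geoprior's implementation: the function names, the per-quantile weighting, and the averaging over quantile levels are all assumptions made for illustration.

```python
def pinball(y_true: float, y_pred: float, q: float) -> float:
    """Standard pinball loss for quantile level q."""
    err = y_true - y_pred
    return max(q * err, (q - 1.0) * err)


def weighted_pinball(y_true: float, preds_by_q: dict, weights: dict) -> float:
    """Average of per-quantile pinball losses, each scaled by its weight."""
    total = sum(
        weights.get(q, 1.0) * pinball(y_true, pred, q)
        for q, pred in preds_by_q.items()
    )
    return total / len(preds_by_q)


# Under-prediction is penalised by q, over-prediction by (1 - q).
under = pinball(1.0, 0.0, 0.9)
over = pinball(0.0, 1.0, 0.9)

loss = weighted_pinball(
    1.0,
    {0.1: 0.0, 0.5: 1.0, 0.9: 2.0},
    {0.1: 3.0, 0.5: 1.0, 0.9: 3.0},
)
```

Weights above 1 on the outer quantiles, as in the `SUBS_WEIGHTS` default, put more emphasis on the tails of the predictive interval than on the median.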
def build_geoprior_from_hps(
manifest: dict,
X_sample: dict,
best_hps: dict,
out_s_dim: int,
out_g_dim: int,
mode: str,
horizon: int,
quantiles: list[float] | None,
) -> Any:
"""
Reconstruct a GeoPriorSubsNet from Stage-1 metadata + tuned HPs.
This function is primarily intended as a **robust fallback** when
``tf.keras.models.load_model`` cannot deserialize a tuned model.
It reconstructs the network geometry and physics settings from:
- Stage-1 ``manifest['config']`` (for fixed architecture / physics),
- tuned hyperparameters (for variable architecture / physics),
- the Stage-1 input NPZ (for input dimensions).
Parameters
----------
manifest : dict
Stage-1 manifest dictionary.
X_sample : dict
Inputs NPZ dictionary (already sanitized and passed through
:func:`ensure_input_shapes`). Only shapes are used.
best_hps : dict
Hyperparameters loaded via :func:`load_best_hps_near_model`.
out_s_dim : int
Output dimension for the subsidence head.
out_g_dim : int
Output dimension for the GWL head.
mode : str
Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.
horizon : int
Forecast horizon (number of time steps).
quantiles : list of float or None
Quantile levels for probabilistic outputs.
Returns
-------
model : GeoPriorSubsNet
A freshly instantiated and compiled GeoPriorSubsNet instance.
Raises
------
ImportError
If GeoPriorSubsNet cannot be imported from geoprior.
"""
try:
from geoprior.nn.pinn.models import (
GeoPriorSubsNet, # type: ignore
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"build_geoprior_from_hps requires "
"'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
"Ensure geoprior is installed and importable."
) from e
cfg = manifest.get("config", {}) or {}
# Infer input dims directly from NPZ
static_dim, dynamic_dim, future_dim = (
infer_input_dims_from_X(X_sample)
)
# Attention stack: fall back to a sensible default if not present
attention_levels = cfg.get(
"ATTENTION_LEVELS",
["cross", "hierarchical", "memory"],
)
# Whether we used effective H during Stage-2
censor_cfg = cfg.get("censoring", {}) or {}
use_effective_h = censor_cfg.get(
"use_effective_h_field", True
)
# Feature-processing mode controlled by tuned HPs
use_vsn = bool(best_hps.get("use_vsn", True))
feature_processing = "vsn" if use_vsn else "dense"
architecture_config = {
"encoder_type": "hybrid",
"decoder_attention_stack": attention_levels,
"feature_processing": feature_processing,
}
# Instantiate the model core with tuned settings
model = GeoPriorSubsNet(
static_input_dim=static_dim,
dynamic_input_dim=dynamic_dim,
future_input_dim=future_dim,
output_subsidence_dim=out_s_dim,
output_gwl_dim=out_g_dim,
forecast_horizon=horizon,
mode=mode,
attention_levels=attention_levels,
quantiles=quantiles,
# physics switches from best_hps
pde_mode=best_hps.get("pde_mode", "both"),
scale_pde_residuals=bool(
best_hps.get("scale_pde_residuals", True)
),
kappa_mode=best_hps.get("kappa_mode", "bar"),
use_effective_h=use_effective_h,
# architecture hyperparameters
embed_dim=int(best_hps.get("embed_dim", 32)),
hidden_units=int(best_hps.get("hidden_units", 96)),
lstm_units=int(best_hps.get("lstm_units", 96)),
attention_units=int(
best_hps.get("attention_units", 32)
),
num_heads=int(best_hps.get("num_heads", 4)),
dropout_rate=float(best_hps.get("dropout_rate", 0.1)),
use_vsn=use_vsn,
vsn_units=int(best_hps.get("vsn_units", 32)),
use_batch_norm=bool(
best_hps.get("use_batch_norm", True)
),
# geomechanical priors (floats interpreted internally by the model)
mv=float(best_hps.get("mv", 5e-7)),
kappa=float(best_hps.get("kappa", 1.0)),
architecture_config=architecture_config,
)
# Compile using the shared helper (losses + physics weights)
compile_geoprior_for_eval(
model=model,
manifest=manifest,
best_hps=best_hps,
quantiles=quantiles,
)
print(
"[Fallback] Reconstructed GeoPriorSubsNet from best_hps with "
f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
)
return model
def build_geoprior_from_cfg(
manifest: dict,
X_sample: dict,
out_s_dim: int,
out_g_dim: int,
mode: str,
horizon: int,
quantiles: list[float] | None,
) -> Any:
"""
Reconstruct a GeoPriorSubsNet from the NATCOM config only.
This is intended as a fallback for *trained* models (no tuning JSON)
or when no best_hps can be found next to `model_path`.
Parameters
----------
manifest : dict
Stage-1 manifest dictionary. The ``"config"`` entry is used as
the single source of truth for architecture + physics settings.
X_sample : dict
NPZ inputs dict (already sanitised and passed through
:func:`ensure_input_shapes` or equivalent). Only shapes are used.
out_s_dim, out_g_dim : int
Output dims for subsidence and GWL heads.
mode : str
Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.
horizon : int
Forecast horizon (number of time steps).
quantiles : list of float or None
Quantiles for probabilistic outputs.
Returns
-------
model : GeoPriorSubsNet
Compiled model instance ready for prediction/eval.
"""
try:
from geoprior.nn.pinn.models import (
GeoPriorSubsNet, # type: ignore
)
except Exception as e: # pragma: no cover - env dependent
raise ImportError(
"build_geoprior_from_cfg requires "
"'geoprior.nn.pinn.models.GeoPriorSubsNet'. "
"Ensure geoprior is installed and importable."
) from e
cfg = manifest.get("config", {}) or {}
# --- Input dims inferred from X_sample ----------------------------
static_dim, dynamic_dim, future_dim = (
infer_input_dims_from_X(X_sample)
)
# --- Attention stack / effective-H flag ---------------------------
attention_levels = cfg.get(
"ATTENTION_LEVELS",
["cross", "hierarchical", "memory"],
)
censor_cfg = cfg.get("censoring", {}) or {}
use_effective_h = censor_cfg.get(
"use_effective_h_field",
bool(cfg.get("GEOPRIOR_USE_EFFECTIVE_H", True)),
)
# --- Physics switches ---------------------------------------------
pde_mode = cfg.get(
"PDE_MODE_CONFIG", cfg.get("PDE_MODE", "both")
)
scale_pde_residuals = bool(
cfg.get(
"SCALE_PDE_RESIDUALS",
cfg.get("SCALE_PDE_RES", True),
)
)
kappa_mode = cfg.get(
"GEOPRIOR_KAPPA_MODE",
cfg.get("KAPPA_MODE", "bar"),
)
# --- Small helpers to read ints/floats/bools from cfg ------------
def _cfg_int(default: int, *keys: str) -> int:
for k in keys:
if k in cfg and cfg[k] is not None:
try:
return int(cfg[k])
except Exception:
pass
return int(default)
def _cfg_float(default: float, *keys: str) -> float:
for k in keys:
if k in cfg and cfg[k] is not None:
try:
return float(cfg[k])
except Exception:
pass
return float(default)
def _cfg_bool(default: bool, *keys: str) -> bool:
for k in keys:
if k in cfg and cfg[k] is not None:
return bool(cfg[k])
return bool(default)
# --- Architecture hyperparams (config-side defaults) --------------
embed_dim = _cfg_int(
32, "EMBED_DIM", "GEOPRIOR_EMBED_DIM"
)
hidden_units = _cfg_int(
96, "HIDDEN_UNITS", "GEOPRIOR_HIDDEN_UNITS"
)
lstm_units = _cfg_int(
96, "LSTM_UNITS", "GEOPRIOR_LSTM_UNITS"
)
attention_units = _cfg_int(
32, "ATTENTION_UNITS", "GEOPRIOR_ATTENTION_UNITS"
)
num_heads = _cfg_int(
4, "NUM_HEADS", "NUMBER_HEADS", "GEOPRIOR_NUM_HEADS"
)
dropout_rate = _cfg_float(
0.1, "DROPOUT_RATE", "GEOPRIOR_DROPOUT_RATE"
)
use_vsn = _cfg_bool(True, "USE_VSN", "GEOPRIOR_USE_VSN")
vsn_units = _cfg_int(
32, "VSN_UNITS", "GEOPRIOR_VSN_UNITS"
)
use_batch_norm = _cfg_bool(
True, "USE_BATCH_NORM", "GEOPRIOR_USE_BATCH_NORM"
)
# --- Geomechanical priors (Terzaghi-ish) --------------------------
mv = _cfg_float(5e-7, "GEOPRIOR_INIT_MV")
kappa = _cfg_float(1.0, "GEOPRIOR_INIT_KAPPA")
architecture_config = {
"encoder_type": "hybrid",
"decoder_attention_stack": attention_levels,
"feature_processing": "vsn" if use_vsn else "dense",
}
model = GeoPriorSubsNet(
static_input_dim=static_dim,
dynamic_input_dim=dynamic_dim,
future_input_dim=future_dim,
output_subsidence_dim=out_s_dim,
output_gwl_dim=out_g_dim,
forecast_horizon=horizon,
mode=mode,
attention_levels=attention_levels,
quantiles=quantiles,
# physics switches
pde_mode=pde_mode,
scale_pde_residuals=scale_pde_residuals,
kappa_mode=kappa_mode,
use_effective_h=use_effective_h,
# architecture hyperparameters
embed_dim=embed_dim,
hidden_units=hidden_units,
lstm_units=lstm_units,
attention_units=attention_units,
num_heads=num_heads,
dropout_rate=dropout_rate,
use_vsn=use_vsn,
vsn_units=vsn_units,
use_batch_norm=use_batch_norm,
# priors
mv=mv,
kappa=kappa,
architecture_config=architecture_config,
)
# Compile using config-only settings
compile_for_eval(
model=model,
manifest=manifest,
best_hps=None,
quantiles=quantiles,
include_metrics=True,
)
print(
"[Fallback] Reconstructed GeoPriorSubsNet from manifest config with "
f"static_dim={static_dim}, dynamic_dim={dynamic_dim}, "
f"future_dim={future_dim}, horizon={horizon}, mode={mode}"
)
return model
def infer_best_weights_path(model_path: str) -> str | None:
    """
    Infer the best-weights checkpoint path for a tuned GeoPrior model.

    Strategy
    --------
    1. Look for ``tuning_summary.json`` in the same folder as
       ``model_path`` and return the stored ``"best_weights_path"``
       if it exists and the file is present on disk.
    2. Fallback: replace the ``.keras`` suffix of ``model_path`` by
       ``.weights.h5``, assuming the convention::

           <CITY>_GeoPrior_best.keras
               -> <CITY>_GeoPrior_best.weights.h5

    Parameters
    ----------
    model_path : str
        Path to the tuned model archive (usually ``.keras``).

    Returns
    -------
    weights_path : str or None
        Path to the weights file if one exists on disk, otherwise
        ``None``. The path is returned as stored in the summary or as
        derived from ``model_path``; it is not converted to an
        absolute path.
    """
    run_dir = os.path.dirname(os.path.abspath(model_path))

    # 1) Preferred: tuning_summary.json
    summary_path = os.path.join(run_dir, "tuning_summary.json")
    if os.path.exists(summary_path):
        try:
            with open(summary_path, encoding="utf-8") as f:
                summary = json.load(f)
            w = summary.get("best_weights_path")
            if w and os.path.exists(w):
                return w
        except Exception as e:  # pragma: no cover - defensive
            print(
                f"[Warn] Could not read tuning_summary.json for weights: {e}"
            )

    # 2) Name-based guess from the .keras path
    root, _ext = os.path.splitext(model_path)
    guess = root + ".weights.h5"
    if os.path.exists(guess):
        return guess
    return None
def _load_or_rebuild_geoprior_model(
    model_path: str,
    manifest: dict,
    X_sample: dict,
    out_s_dim: int,
    out_g_dim: int,
    mode: str,
    horizon: int,
    quantiles: list[float] | None,
    city_name: str | None = None,
    compile_on_load: bool = True,
    verbose: int = 1,
):
    """
    Load a tuned GeoPriorSubsNet from disk, with robust rebuild fallback.

    This helper centralizes the logic:

    1. Try to load the model from ``model_path`` via
       :func:`tf.keras.models.load_model`, with all required custom
       objects registered (GeoPriorSubsNet, LearnableMV, etc.).
    2. If loading fails (e.g. due to environment or serialization
       changes), attempt a robust fallback:

       - Load the tuned hyperparameters via
         :func:`load_best_hps_near_model`.
       - Rebuild a compatible GeoPriorSubsNet instance using
         :func:`build_geoprior_from_hps`, based on Stage-1
         ``manifest['config']`` and an input sample ``X_sample``.
       - Find the best weights checkpoint via
         :func:`infer_best_weights_path` and load them into the
         rebuilt model, if available.

    Parameters
    ----------
    model_path : str
        Path to the tuned model archive (usually ``.keras``) produced
        by the tuner, e.g.::

            .../tuning/run_YYYYMMDD-HHMMSS/nansha_GeoPrior_best.keras
    manifest : dict
        Stage-1 manifest dictionary; its ``"config"`` entry is used to
        reconstruct the compile/physics configuration when rebuilding.
    X_sample : dict
        One NPZ inputs dictionary (e.g. validation or train NPZ) that
        has already been sanitized and passed through
        :func:`ensure_input_shapes`. Only its shapes are used to infer
        input dimensions.
    out_s_dim : int
        Output dimension for the subsidence head
        (usually from ``M['artifacts']['sequences']['dims']``).
    out_g_dim : int
        Output dimension for the GWL head.
    mode : str
        Sequence mode, e.g. ``"tft_like"`` or ``"pihal_like"``.
    horizon : int
        Forecast horizon (number of time steps).
    quantiles : list of float or None
        Quantile levels used for probabilistic outputs. If ``None``,
        the model is treated as a point-forecast model.
    city_name : str or None, optional
        City name for log messages. If ``None``, a neutral label is
        used in logs.
    compile_on_load : bool, default=True
        Whether to pass ``compile=True`` to :func:`load_model`. If
        ``False``, the model is loaded uncompiled, and only the
        rebuilt branch is compiled via
        :func:`build_geoprior_from_hps`.
    verbose : int, default=1
        Verbosity level for log messages (0 = silent, 1 = info).

    Returns
    -------
    model :
        A GeoPriorSubsNet instance (or compatible model) ready for
        evaluation/prediction.
    best_hps : dict or None
        Dictionary of tuned hyperparameters if they were loaded during
        the fallback path; otherwise ``None``.

    Raises
    ------
    ImportError
        If TensorFlow or required geoprior components cannot be
        imported.
    RuntimeError
        If both direct loading and fallback reconstruction fail.
    """
    label_city = city_name or "GeoPrior"

    # --- Lazy imports so nat_utils can be imported without TF/geoprior ---
    try:
        import tensorflow as tf  # type: ignore # noqa
        from tensorflow.keras.models import load_model  # type: ignore
        from tensorflow.keras.utils import custom_object_scope  # type: ignore
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires TensorFlow. "
            "Please install 'tensorflow>=2.12' to use this helper."
        ) from e

    try:
        from geoprior.nn.keras_metrics import (  # type: ignore
            coverage80_fn,
            sharpness80_fn,
        )
        from geoprior.nn.losses import (
            make_weighted_pinball,  # type: ignore
        )
        from geoprior.nn.pinn.models import (
            GeoPriorSubsNet,  # type: ignore
        )
        from geoprior.params import (  # type: ignore
            FixedGammaW,
            FixedHRef,
            LearnableKappa,
            LearnableMV,
        )
    except Exception as e:  # pragma: no cover - env dependent
        raise ImportError(
            "load_or_rebuild_geoprior_model requires geoprior components "
            "(GeoPriorSubsNet, LearnableMV, etc.). Ensure geoprior is "
            "installed and importable."
        ) from e

    # ------------------- 1) Try direct load_model -------------------------
    custom_objects = {
        "GeoPriorSubsNet": GeoPriorSubsNet,
        "LearnableMV": LearnableMV,
        "LearnableKappa": LearnableKappa,
        "FixedGammaW": FixedGammaW,
        "FixedHRef": FixedHRef,
        # custom loss factory / class
        "make_weighted_pinball": make_weighted_pinball,
        # custom metrics used in compile
        "coverage80_fn": coverage80_fn,
        "sharpness80_fn": sharpness80_fn,
    }
    best_hps: dict | None = None

    with custom_object_scope(custom_objects):
        if verbose:
            print(
                f"[Model] Attempting to load tuned model from: {model_path}"
            )
        # A failed load must fall through to the rebuild branch below;
        # without this try/except that branch is unreachable.
        try:
            model = load_model(
                model_path, compile=compile_on_load
            )
            if verbose:
                print(
                    f"[Model] Successfully loaded tuned model for {label_city} "
                    f"from: {model_path}"
                )
            return model, best_hps
        except Exception as e_load:
            if verbose:
                print(
                    f"[Warn] load_model('{model_path}') failed: {e_load}\n"
                    "[Warn] Attempting robust fallback: rebuild GeoPriorSubsNet "
                    "from tuned hyperparameters."
                )

    # ------------------- 2) Fallback: rebuild + load weights --------------
    # 2.1 Hyperparameters near the tuned model
    try:
        best_hps = load_best_hps_near_model(model_path)
    except Exception as e_hps:
        raise RuntimeError(
            "Failed to load tuned hyperparameters for fallback model "
            f"reconstruction near model_path={model_path!r}: {e_hps}"
        ) from e_hps

    # 2.2 Rebuild architecture + compile using Stage-1 manifest + best_hps
    try:
        model = build_geoprior_from_hps(
            manifest=manifest,
            X_sample=X_sample,
            best_hps=best_hps,
            out_s_dim=out_s_dim,
            out_g_dim=out_g_dim,
            mode=mode,
            horizon=horizon,
            quantiles=quantiles,
        )
    except Exception as e_build:
        raise RuntimeError(
            "Failed to reconstruct GeoPriorSubsNet from best_hps. "
            f"Error: {e_build}"
        ) from e_build

    # 2.3 Load weights into the rebuilt model, if a checkpoint is found
    weights_path = infer_best_weights_path(model_path)
    if weights_path is not None:
        try:
            model.load_weights(weights_path)
            if verbose:
                print(
                    "[Fallback] Loaded weights into rebuilt GeoPriorSubsNet "
                    f"from: {weights_path}"
                )
        except Exception as e_w:
            # We still return the rebuilt model, but warn that it is not
            # weight-identical to the tuned run.
            if verbose:
                print(
                    "[Warn] Could not load weights from checkpoint:\n"
                    f" {weights_path}\n"
                    f" Error: {e_w}\n"
                    " The rebuilt model is using freshly-initialized "
                    "weights. Predictions will NOT match the tuned model."
                )
    else:
        if verbose:
            print(
                "[Warn] No weights checkpoint found near tuned model.\n"
                " Using rebuilt model with freshly-initialized "
                "weights. Predictions will NOT match the tuned model."
            )
    return model, best_hps
def infer_input_dims_from_X(X: dict) -> tuple[int, int, int]:
"""
Infer (static_input_dim, dynamic_input_dim, future_input_dim)
from NPZ inputs.
This is a public, defensive version of the former
``_infer_input_dims_from_X`` helper.
Parameters
----------
X : dict
Dictionary with keys:
- ``'dynamic_features'`` (required, shape (N, T, D_dyn))
- ``'static_features'`` (optional, shape (N, D_static) or None)
- ``'future_features'`` (optional, shape (N, T_future, D_future) or None)
Returns
-------
static_dim : int
Last-dimension size of ``static_features`` (0 if missing or None).
dynamic_dim : int
Last-dimension size of ``dynamic_features``. Raises if missing.
future_dim : int
Last-dimension size of ``future_features`` (0 if missing or None).
Raises
------
KeyError
If ``'dynamic_features'`` is missing in ``X``.
"""
if "dynamic_features" not in X:
raise KeyError(
"X must contain key 'dynamic_features' with shape (N, T, D_dyn)."
)
dyn = np.asarray(X["dynamic_features"])
dynamic_dim = int(dyn.shape[-1])
static = X.get("static_features", None)
static_dim = (
int(np.asarray(static).shape[-1])
if static is not None
else 0
)
fut = X.get("future_features", None)
future_dim = (
int(np.asarray(fut).shape[-1])
if fut is not None
else 0
)
return static_dim, dynamic_dim, future_dim
# -------------------------------------------------------------------------
# Backward-compatible aliases for old private helper names
# -------------------------------------------------------------------------
safe_compile = compile_for_eval
_infer_input_dims_from_X = infer_input_dims_from_X
_load_best_hps_near_model = load_best_hps_near_model
_coerce_quantile_weights = coerce_quantile_weights
_compile_geoprior_for_eval = compile_geoprior_for_eval
_build_geoprior_from_hps = build_geoprior_from_hps
_infer_best_weights_path = infer_best_weights_path
_build_geoprior_from_cfg = build_geoprior_from_cfg
# Back-compat alias (docstrings still mention it)
extract_stage_outputs = extract_preds
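The weights lookup performed by `infer_best_weights_path` in the listing above can be tried in isolation. The following is a minimal standalone sketch, not the GeoPrior implementation: the function name `guess_weights_path` and the `demo_GeoPrior_best` / `checkpoint` file names are hypothetical, and only the two-step strategy (read `tuning_summary.json` first, then fall back to the `.keras` -> `.weights.h5` naming convention) is taken from the listing.

```python
import json
import os
import tempfile


def guess_weights_path(model_path):
    """Hypothetical restatement of the two-step weights lookup."""
    run_dir = os.path.dirname(os.path.abspath(model_path))

    # Step 1: prefer the path recorded in tuning_summary.json, if usable.
    summary_path = os.path.join(run_dir, "tuning_summary.json")
    if os.path.exists(summary_path):
        try:
            with open(summary_path, encoding="utf-8") as f:
                stored = json.load(f).get("best_weights_path")
        except (OSError, ValueError, AttributeError):
            stored = None
        if stored and os.path.exists(stored):
            return stored

    # Step 2: swap the model file extension for ".weights.h5".
    root, _ext = os.path.splitext(model_path)
    guess = root + ".weights.h5"
    return guess if os.path.exists(guess) else None


with tempfile.TemporaryDirectory() as run_dir:
    model_path = os.path.join(run_dir, "demo_GeoPrior_best.keras")
    by_name = os.path.join(run_dir, "demo_GeoPrior_best.weights.h5")
    by_summary = os.path.join(run_dir, "checkpoint.weights.h5")

    # Nothing on disk yet: the lookup finds no checkpoint.
    found_before = guess_weights_path(model_path)

    # Only the name-based file exists: the convention is used.
    with open(by_name, "wb"):
        pass
    uses_convention = guess_weights_path(model_path) == by_name

    # A summary pointing at an existing file takes precedence.
    with open(by_summary, "wb"):
        pass
    summary_file = os.path.join(run_dir, "tuning_summary.json")
    with open(summary_file, "w", encoding="utf-8") as f:
        json.dump({"best_weights_path": by_summary}, f)
    prefers_summary = guess_weights_path(model_path) == by_summary
```

The summary file wins whenever it names a file that exists, so a stale or relative `best_weights_path` silently falls through to the name-based guess.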
Optional: model/PINN helper module#
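The module listed in this section normalizes `pde_mode` inputs through `process_pde_modes`. As a rough orientation, the behaviour can be restated in a few lines. This is a simplified standalone sketch under stated assumptions: `normalize_pde_mode` is a hypothetical name that is not part of GeoPrior, the alias tables are copied from the listing, and options such as `enforce_consolidation` and `solo_return` are left out.

```python
# Alias tables copied from the listing; names here are illustrative only.
NONE_ALIASES = {"none", "off", "false", "0", "disable", "disabled"}
BOTH_ALIASES = {"both", "on", "all", "true", "1"}
TOKEN_ALIASES = {
    "consolidation": "consolidation",
    "cons": "consolidation",
    "gw_flow": "gw_flow",
    "gw": "gw_flow",
    "groundwater": "gw_flow",
    "groundwater_flow": "gw_flow",
}
ORDER = ("consolidation", "gw_flow")


def normalize_pde_mode(value):
    """Hypothetical sketch: return the canonical list of active PDE modes."""
    # Flatten the input into lowercase tokens.
    if value is None:
        tokens = ["none"]
    elif isinstance(value, str):
        tokens = [p.strip().lower() for p in value.split(",") if p.strip()]
    elif isinstance(value, (list, tuple, set)):
        tokens = [
            "none" if v is None else str(v).strip().lower() for v in value
        ]
    else:
        raise TypeError(
            "pde_mode must be a string, a sequence of strings, or None."
        )

    # Map aliases to canonical names and collect the active set.
    active = set()
    saw_none = False
    for tok in tokens:
        if tok == "" or tok in NONE_ALIASES:
            saw_none = True
        elif tok in BOTH_ALIASES:
            active.update(ORDER)
        elif tok in TOKEN_ALIASES:
            active.add(TOKEN_ALIASES[tok])
        else:
            raise ValueError(f"Unsupported pde_mode token: {tok!r}")

    # "none" cannot be mixed with an active mode.
    if saw_none and active:
        raise ValueError("'none' cannot be combined with other modes.")

    # Canonical order, with "none" when nothing is active.
    return [m for m in ORDER if m in active] or ["none"]


print(normalize_pde_mode("gw, cons"))  # ['consolidation', 'gw_flow']
print(normalize_pde_mode("off"))       # ['none']
```

Returning the modes in a fixed order means callers can compare the result against `["consolidation", "gw_flow"]` directly, regardless of how the user ordered the tokens.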
# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
Physics-Informed Neural Network (PINN) Utility functions.
"""
from __future__ import annotations
import logging
import os
import warnings # noqa
from collections.abc import Callable, Sequence
from typing import (
Any,
Final,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from sklearn.preprocessing import MinMaxScaler
from ...api.util import get_table_size
from ...core.checks import (
check_datetime,
exist_features,
)
from ...core.handlers import columns_manager
from ...core.io import SaveFile
from ...decorators import isdf
from ...logging import get_logger
from ...metrics.utils import compute_quantile_diagnostics
from ...utils.data_utils import mask_by_reference
from ...utils.deps_utils import ensure_pkg
from ...utils.forecast_utils import normalize_for_pinn
from ...utils.generic_utils import (
print_box,
rename_dict_keys,
select_mode,
vlog,
)
from ...utils.geo_utils import resolve_spatial_columns
from ...utils.io_utils import save_job
from ...utils.sequence_utils import check_sequence_feasibility
from ...utils.validator import validate_positive_integer
from .. import KERAS_BACKEND, KERAS_DEPS
Model = KERAS_DEPS.Model
Tensor = KERAS_DEPS.Tensor
tf_shape = KERAS_DEPS.shape
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_float32 = KERAS_DEPS.float32
tf_cast = KERAS_DEPS.cast
tf_concat = KERAS_DEPS.concat
tf_expand_dims = KERAS_DEPS.expand_dims
tf_debugging = KERAS_DEPS.debugging
tf_fill = KERAS_DEPS.fill
tf_reshape = KERAS_DEPS.reshape
logger = get_logger(__name__)
_TW = get_table_size()
__all__ = [
"format_pihalnet_predictions",
"prepare_pinn_data_sequences",
"format_pinn_predictions",
"extract_txy_in",
"extract_txy",
"plot_hydraulic_head",
"PDE_MODE_ALIASES",
]
_PDE_NONE_ALIASES = {
"none",
"off",
"false",
"0",
"disable",
"disabled",
"",
}
_PDE_BOTH_ALIASES = {
"both",
"on",
"all",
"true",
"1",
}
_PDE_TOKEN_ALIASES = {
"consolidation": "consolidation",
"cons": "consolidation",
"gw_flow": "gw_flow",
"gw": "gw_flow",
"groundwater": "gw_flow",
"groundwater_flow": "gw_flow",
}
_PDE_ORDER = ("consolidation", "gw_flow")
PDE_CANONICAL_MODES: Final[frozenset[str]] = frozenset(
{
"consolidation",
"gw_flow",
"both",
"none",
}
)
PDE_MODE_ALIASES: Final[frozenset[str]] = frozenset(
PDE_CANONICAL_MODES
| _PDE_NONE_ALIASES
| _PDE_BOTH_ALIASES
)
# Optional: strict public set for places where only canonical
# labels should be accepted after normalization.
PDE_ACTIVE_MODES: Final[frozenset[str]] = frozenset(
{
"consolidation",
"gw_flow",
"none",
}
)
def _flatten_mode_input(
value: str | Sequence[str] | None,
) -> list[str]:
"""
Convert the raw pde_mode input into a flat list of lowercase tokens.
"""
if value is None:
return ["none"]
if isinstance(value, str):
# support comma-separated strings like "consolidation,gw_flow"
parts = [p.strip().lower() for p in value.split(",")]
parts = [p for p in parts if p != ""]
return parts or ["none"]
if isinstance(value, list | tuple | set):
out: list[str] = []
for item in value:
if item is None:
out.append("none")
else:
s = str(item).strip().lower()
if s:
out.append(s)
else:
out.append("none")
return out or ["none"]
raise TypeError(
"`pde_mode` must be a string, a sequence of strings, or None."
)
def _normalize_pde_tokens(tokens: Sequence[str]) -> list[str]:
"""
Normalize aliases into canonical tokens.
Canonical tokens are:
- 'none'
- 'consolidation'
- 'gw_flow'
- 'both' (intermediate token, later expanded)
"""
normalized: list[str] = []
for tok in tokens:
if tok in _PDE_NONE_ALIASES:
normalized.append("none")
elif tok in _PDE_BOTH_ALIASES:
normalized.append("both")
elif tok in _PDE_TOKEN_ALIASES:
normalized.append(_PDE_TOKEN_ALIASES[tok])
else:
raise ValueError(
f"Unsupported pde_mode token: {tok!r}. "
"Allowed values are: "
"'none'/'off', 'consolidation', 'gw_flow', 'both'/'on'."
)
return normalized
def _canonicalize_pde_modes(
tokens: Sequence[str],
) -> list[str]:
"""
Convert normalized tokens into the final active-mode list.
Rules
-----
- 'none' cannot be combined with any active PDE mode.
- 'both' expands to ['consolidation', 'gw_flow'].
- duplicates are removed while preserving canonical order.
"""
toks = list(tokens)
has_none = "none" in toks
has_both = "both" in toks # noqa
# Expand "both" first conceptually, but reject ambiguous combinations.
if has_none:
non_none = [t for t in toks if t != "none"]
if non_none:
raise ValueError(
"Ambiguous pde_mode: 'none/off' cannot be combined with "
f"other modes: {non_none!r}."
)
return ["none"]
active = set()
for tok in toks:
if tok == "both":
active.update(("consolidation", "gw_flow"))
elif tok in ("consolidation", "gw_flow"):
active.add(tok)
else:
raise ValueError(
f"Unexpected normalized token: {tok!r}."
)
if not active:
return ["none"]
return [mode for mode in _PDE_ORDER if mode in active]
def _collapse_pde_modes_to_label(
active_modes: Sequence[str],
) -> str:
"""
Convert a canonical active-mode list back to a single label.
"""
modes = list(active_modes)
if modes == ["none"]:
return "none"
if modes == ["consolidation"]:
return "consolidation"
if modes == ["gw_flow"]:
return "gw_flow"
if modes == ["consolidation", "gw_flow"]:
return "both"
raise ValueError(
f"Cannot collapse unexpected active mode list: {modes!r}"
)
def process_pde_modes(
pde_mode: str | Sequence[str] | None,
enforce_consolidation: bool = False,
pde_mode_config: str | Sequence[str] | None = None,
solo_return: bool = False,
) -> list[str] | str:
r"""
Normalize and validate PDE mode selection.
Parameters
----------
pde_mode : str, sequence of str, or None
Requested PDE mode(s).
Accepted canonical values are:
- ``"none"``
- ``"consolidation"``
- ``"gw_flow"``
- ``"both"``
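Accepted aliases (matched case-insensitively):
- ``None``, ``""``, ``"off"``, ``"false"``, ``"0"``, ``"disable"``,
``"disabled"`` -> ``"none"``
- ``"on"``, ``"all"``, ``"true"``, ``"1"`` -> ``"both"``
- ``"cons"`` -> ``"consolidation"``
- ``"gw"``, ``"groundwater"``, ``"groundwater_flow"`` -> ``"gw_flow"``
A comma-separated string such as ``"consolidation,gw_flow"`` is split
into individual tokens before normalization.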
enforce_consolidation : bool, default=False
If True, any resolved mode other than exact ``["consolidation"]``
is coerced to ``["consolidation"]`` and a warning is emitted.
This includes:
- ``["none"]``
- ``["gw_flow"]``
- ``["consolidation", "gw_flow"]``
pde_mode_config : str, sequence of str, or None, optional
Optional override. If provided, this value takes precedence over
``pde_mode``.
solo_return : bool, default=False
If False, return a canonical list of active modes.
If True, return a single canonical label:
- ``"none"``
- ``"consolidation"``
- ``"gw_flow"``
- ``"both"``
Returns
-------
list of str or str
Canonical PDE mode(s), either as a list or a single label.
Raises
------
TypeError
If the input type is invalid.
ValueError
If a token is unsupported or the mode selection is ambiguous.
Examples
--------
>>> process_pde_modes(None)
['none']
>>> process_pde_modes("off")
['none']
>>> process_pde_modes("on")
['consolidation', 'gw_flow']
>>> process_pde_modes("both", solo_return=True)
'both'
>>> process_pde_modes("gw_flow", enforce_consolidation=True)
['consolidation']
"""
source = (
pde_mode_config
if pde_mode_config is not None
else pde_mode
)
raw_tokens = _flatten_mode_input(source)
normalized_tokens = _normalize_pde_tokens(raw_tokens)
active_modes = _canonicalize_pde_modes(normalized_tokens)
if enforce_consolidation and active_modes != [
"consolidation"
]:
warnings.warn(
"This model requires PDE mode 'consolidation'. "
f"Received {active_modes!r}; coercing to ['consolidation'].",
UserWarning,
stacklevel=2,
)
active_modes = ["consolidation"]
if solo_return:
return _collapse_pde_modes_to_label(active_modes)
return active_modes
def _process_pde_modes(
    pde_mode: str | list | None,
    enforce_consolidation: bool = False,
    pde_mode_config: str | list | None = None,
    solo_return: bool = False,
) -> list | str:
    r"""
    Process and validate the `pde_mode` argument to determine the active PDE modes.

    This function handles `pde_mode` inputs and processes them according to
    the following rules:

    - If the input is 'none' or 'off', only the mode 'none' will be active.
    - If the input is 'both' or 'on', both 'consolidation' and 'gw_flow'
      are active.
    - If 'consolidation' is not among the resolved modes and
      `enforce_consolidation` is True, a warning is issued and only
      'consolidation' is used.
    - If `pde_mode_config` is provided, it overrides any other mode setting.

    Unlike :func:`process_pde_modes`, unsupported tokens are logged and
    returned unchanged instead of raising ``ValueError``.

    Parameters
    ----------
    pde_mode : str, list of str, or None
        The desired PDE modes. Can be:

        - A string (e.g., 'consolidation', 'gw_flow', etc.)
        - A list of strings (e.g., ['consolidation', 'gw_flow'])
        - None (to set no active modes)
    enforce_consolidation : bool, default=False
        If True and 'consolidation' is not among the resolved modes,
        a warning is issued and ``['consolidation']`` is used instead.
    pde_mode_config : str, list of str, or None, optional
        If provided, overrides the `pde_mode` argument.
    solo_return : bool, default=False
        If True, return only the first resolved mode as a string
        instead of the full list.

    Returns
    -------
    list of str or str
        The active PDE modes in lowercase, or the first of them as a
        string when `solo_return` is True.

    Raises
    ------
    TypeError
        If `pde_mode` is neither a string, a list of strings, nor None.

    Warnings
    --------
    If `enforce_consolidation` is True and 'consolidation' is not among
    the resolved modes, a warning will be issued, and 'consolidation'
    will be used as the active mode instead.
    """
    # If pde_mode_config is provided, use it, otherwise fall back to pde_mode
    if pde_mode_config:
        pde_mode = pde_mode_config

    if isinstance(pde_mode, str):
        if pde_mode.lower() == "on":
            pde_mode = "both"
        pde_modes_active = [pde_mode.lower()]
    elif isinstance(pde_mode, list):
        pde_modes_active = [
            str(p_type).lower() for p_type in pde_mode
        ]
    elif pde_mode is None:
        # Explicitly 'none' if None is provided
        pde_modes_active = ["none"]
    else:
        raise TypeError(
            "`pde_mode` must be a string, list of strings, or None."
        )

    # Handle special cases for "none" and "both"
    if "none" in pde_modes_active or "off" in pde_modes_active:
        # If 'none' is present, override others
        pde_modes_active = ["none"]
    if "both" in pde_modes_active or "on" in pde_modes_active:
        # If 'both' is present, use both modes
        pde_modes_active = ["consolidation", "gw_flow"]

    # Enforce consolidation mode if specified
    if (
        enforce_consolidation
        and "consolidation" not in pde_modes_active
    ):
        warnings.warn(
            "You have passed a mode other than 'consolidation'. "
            "Falling back to 'consolidation' as the active mode.",
            UserWarning,
            stacklevel=2,
        )
        pde_modes_active = ["consolidation"]

    # Report tokens outside the supported set. The previous check tested
    # membership the wrong way round and fired for every valid mode.
    supported = {
        "consolidation",
        "gw_flow",
        "both",
        "none",
        "on",
        "off",
    }
    unsupported = [
        m for m in pde_modes_active if m not in supported
    ]
    if unsupported:
        logger.info(
            f"Unsupported pde_mode token(s) {unsupported!r} "
            f"in '{pde_mode}'; they are returned unchanged."
        )

    if solo_return:
        pde_modes_active = pde_modes_active[0]

    # Return the processed pde_modes_active
    return pde_modes_active
@SaveFile
def format_pinn_predictions(
predictions: dict[str, Tensor] | None = None,
model: Model | None = None,
model_inputs: dict[str, Tensor] | None = None,
y_true_dict: dict[str, np.ndarray | Tensor] | None = None,
target_mapping: dict[str, str] | None = None,
include_gwl: bool = True,
include_coords: bool = True,
quantiles: list[float] | None = None,
forecast_horizon: int | None = None,
output_dims: dict[str, int] | None = None,
ids_data_array: np.ndarray | pd.DataFrame | None = None,
ids_cols: list[str] | None = None,
ids_cols_indices: list[int] | None = None,
scaler_info: dict[str, dict[str, Any]] | None = None,
coord_scaler: Any | None = None,
evaluate_coverage: bool = False,
coverage_quantile_indices: tuple[int, int] = (0, -1),
savefile: str | None = None,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
name: str | None = None,
model_name: str | None = None,
stop_check: Callable[[], bool] | None = None,
verbose: int = 0,
**kwargs,
) -> pd.DataFrame:
r"""Formats PINN model predictions into a structured pandas DataFrame.
This is a general-purpose utility for transforming raw model outputs
(from models like PIHALNet or TransFlowSubsNet) into a long-format
DataFrame suitable for analysis, visualization, or export. It handles
multi-target outputs (e.g., subsidence and GWL), point or quantile
forecasts, and can optionally include true values, coordinate
information, and other metadata. It also supports inverse-scaling of
predictions and evaluation of quantile coverage.
Parameters
----------
predictions : dict of Tensors, optional
The dictionary of prediction tensors, typically returned by a
model's ``.predict()`` method. Keys should match the model's
output layer names (e.g., ``'subs_pred'``, ``'gwl_pred'``).
If ``None``, predictions are generated internally using the
`model` and `model_inputs` arguments. Default is ``None``.
model : keras.Model, optional
A compiled Keras model instance used to generate predictions
if the `predictions` dictionary is not provided.
Default is ``None``.
model_inputs : dict of Tensors, optional
A dictionary of input tensors matching the model's signature,
required only if `predictions` is ``None``. Default is ``None``.
y_true_dict : dict, optional
A dictionary containing the ground-truth target arrays, keyed
by their base names (e.g., ``'subsidence'``, ``'gwl'``).
If provided, an ``<target>_actual`` column will be added to
the output DataFrame for comparison. Default is ``None``.
target_mapping : dict, optional
A custom mapping from model output keys to desired base names in
the DataFrame columns. For example:
``{'subs_pred': 'subsidence_mm', 'gwl_pred': 'head_m'}``.
Default is ``None``.
include_gwl : bool, default=True
Toggles the inclusion of groundwater level (GWL) predictions
in the final DataFrame.
include_coords : bool, default=True
Toggles the inclusion of the spatio-temporal coordinate columns
(``coord_t``, ``coord_x``, ``coord_y``) in the final DataFrame.
quantiles : list of float, optional
The list of quantile levels (e.g., ``[0.1, 0.5, 0.9]``) that
the model predicted. This is crucial for correctly parsing
probabilistic forecasts. Default is ``None``.
forecast_horizon : int, optional
The length of the forecast horizon. If ``None``, it is inferred
from the shape of the prediction tensors. Default is ``None``.
output_dims : dict of str, optional
A dictionary specifying the feature dimension of each target,
e.g., ``{'subs_pred': 1, 'gwl_pred': 1}``. If ``None``, it's
inferred from the tensor shapes. Default is ``None``.
ids_data_array : np.ndarray or pd.DataFrame, optional
An array or DataFrame containing static identifiers (e.g., well
IDs, site categories) for each sample. Its length must match
the number of samples in the prediction. Default is ``None``.
ids_cols : list of str, optional
A list of column names for the `ids_data_array`. Required if
`ids_data_array` is a NumPy array. Default is ``None``.
ids_cols_indices : list of int, optional
A list of column indices to select from `ids_data_array` if it
is a NumPy array. Default is ``None``.
scaler_info : dict, optional
A dictionary providing the necessary information to perform
inverse scaling on a per-target basis. Each key should be a
target name (e.g., 'subsidence') and its value a dictionary
containing ``{'scaler': obj, 'all_features': list, 'idx': int}``.
Default is ``None``.
coord_scaler : object, optional
A fitted scikit-learn-like scaler object used to perform an
inverse transform on the coordinate columns. Default is ``None``.
evaluate_coverage : bool, default=False
If ``True`` and quantile predictions are present, calculates the
unconditional coverage of the prediction interval.
coverage_quantile_indices : tuple of (int, int), default=(0, -1)
The indices of the lower and upper quantiles in the sorted
`quantiles` list to use for the coverage calculation. Default
is ``(0, -1)``, which corresponds to the full range.
savefile : str, optional
If a file path is provided, the final DataFrame is saved to a
CSV file at this location. Default is ``None``.
name : str, optional
Name of the prediction, used to label the output data and the
coverage result if applicable. Default is ``None``.
model_name : str, optional
Name of the model. Default is ``None``.
verbose : int, default=0
The verbosity level, from 0 (silent) to 5 (trace every step).
**kwargs : dict
Additional keyword arguments for future extensions.
Returns
-------
pd.DataFrame
A long-format DataFrame where each row represents a single
forecast step for a single sample. Columns include sample and
step identifiers, coordinates, predictions, and optionally
actuals and metadata.
Notes
-----
- The function returns a column-aligned DataFrame, which simplifies
subsequent analysis and plotting.
- For quantile forecasts, prediction columns are named using the
pattern ``<target_name>_q<quantile*100>``, e.g., ``subsidence_q5``,
``subsidence_q50``, ``subsidence_q95``.
- For point forecasts, the column is named ``<target_name>_pred``.
See Also
--------
geoprior.plot.forecast.plot_forecasts : A powerful utility for
visualizing the DataFrame produced by this function.
"""
#
# This acts as a backward-compatible wrapper.
return format_pihalnet_predictions(
pihalnet_outputs=predictions,
model=model,
model_inputs=model_inputs,
y_true_dict=y_true_dict,
target_mapping=target_mapping,
include_gwl=include_gwl,
include_coords=include_coords,
quantiles=quantiles,
forecast_horizon=forecast_horizon,
output_dims=output_dims,
ids_data_array=ids_data_array,
ids_cols=ids_cols,
ids_cols_indices=ids_cols_indices,
scaler_info=scaler_info,
coord_scaler=coord_scaler,
evaluate_coverage=evaluate_coverage,
coverage_quantile_indices=coverage_quantile_indices,
savefile=savefile,
name=name,
model_name=model_name,
verbose=verbose,
_logger=_logger,
stop_check=stop_check,
**kwargs,
)
@SaveFile
def format_pihalnet_predictions(
pihalnet_outputs: dict[str, Tensor] | None = None,
model: Model | None = None,
model_inputs: dict[str, Tensor] | None = None,
y_true_dict: dict[str, np.ndarray | Tensor] | None = None,
target_mapping: dict[str, str] | None = None,
include_gwl: bool = True,
include_coords: bool = True,
quantiles: list[float] | None = None,
forecast_horizon: int | None = None,
output_dims: dict[str, int] | None = None,
ids_data_array: np.ndarray | pd.DataFrame | None = None,
ids_cols: list[str] | None = None,
ids_cols_indices: list[int] | None = None,
scaler_info: dict[str, dict[str, Any]] | None = None,
coord_scaler: Any | None = None,
evaluate_coverage: bool = False,
coverage_quantile_indices: tuple[int, int] = (0, -1),
savefile: str | None = None,
name: str | None = None,
model_name: str | None = None,
apply_mask: bool = False,
mask_values: float | int | None = None,
mask_fill_value: float | None = None,
verbose: int = 0,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
stop_check: Callable[[], bool] | None = None,
**kwargs,
) -> pd.DataFrame:
r"""
Formats PIHALNet/GeoPriorSubsNet predictions into a structured
pandas DataFrame, handling inversion, quantiles, and coordinates.
This function is the core formatter. It:
1. Gets model outputs (or uses provided ones).
2. Unpacks 'data_final' if `model_name` is 'geoprior'.
3. Inverse-transforms all prediction and actual arrays using `scaler_info`.
4. Builds a long-format DataFrame with sample_idx and forecast_step.
5. Appends inverted quantile/point predictions.
6. Appends inverted actual values.
7. Appends inverted coordinates.
8. Appends static/ID columns.
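9. Optionally masks prediction columns against the first target's
actuals (`apply_mask`).
10. Evaluates coverage on the inverted (and, if requested, masked) data.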
Parameters
----------
pihalnet_outputs : dict, optional
Raw output from `model.predict()`. If None, `model` and
`model_inputs` must be provided.
model : tf.keras.Model, optional
Trained model instance (if `pihalnet_outputs` is None).
model_inputs : dict, optional
Inputs for the model to generate predictions
(if `pihalnet_outputs` is None).
y_true_dict : dict, optional
Dictionary of true target arrays (e.g., {'subs_pred': y_true_s}).
Required for including actuals and evaluating coverage.
target_mapping : dict, optional
Maps prediction keys to base names for DataFrame columns.
Default: {'subs_pred': 'subsidence', 'gwl_pred': 'gwl'}.
include_gwl : bool, default True
Whether to include 'gwl_pred' in the final DataFrame.
include_coords : bool, default True
Whether to include 'coord_t', 'coord_x', 'coord_y' columns.
quantiles : list[float], optional
List of quantiles (e.g., [0.1, 0.5, 0.9]). If provided,
quantile columns (e.g., 'subsidence_q10') are created.
forecast_horizon : int, optional
The forecast horizon length (H). If not provided, it's
inferred from the prediction array's shape.
output_dims : dict, optional
Maps prediction keys to their output dimension (O).
E.g., {'subs_pred': 1, 'gwl_pred': 1}. Crucial for
correctly splitting GeoPrior outputs and reshaping.
ids_data_array : np.ndarray or pd.DataFrame, optional
Static/ID data (e.g., original coordinates) to merge.
Must have the same number of samples (B) as predictions.
ids_cols : list[str], optional
Column names if `ids_data_array` is a DataFrame.
ids_cols_indices : list[int], optional
Column indices if `ids_data_array` is a NumPy array.
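scaler_info : dict, optional
Dictionary for inverse scaling, keyed by base target name
(e.g., 'subsidence'). Each entry should provide a fitted scaler
under 'scaler' and the target's column index inside that scaler
under 'idx'; the scaler must expose `min_` and `scale_`, as a
min-max scaler does. An entry may also record the feature-name
ordering used when the scaler was fit. Targets without an entry
are returned in scaled units.
coord_scaler : object, optional
A *fitted* scaler exposing `inverse_transform` (e.g., a
scikit-learn `MinMaxScaler`), used to invert the 'coords' tensor.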
evaluate_coverage : bool, default False
If True, calculates coverage percentage for quantiles.
coverage_quantile_indices : tuple[int, int], default (0, -1)
Indices of the low and high quantiles in the `quantiles` list
to use for coverage (e.g., 0 and -1 for 10th and 90th).
savefile : str, optional
If provided, saves the final DataFrame to this path.
name : str, optional
Name forwarded to the quantile diagnostics when coverage is
evaluated.
model_name : str, optional
Specifies the model type. If 'geoprior' or 'geopriorsubsnet',
triggers unpacking of the 'data_final' output.
apply_mask : bool, default False
If True, masks predictions based on `mask_values` in the
first target's `_actual` column.
mask_values : float or int, optional
The value in the `_actual` column to trigger masking.
mask_fill_value : float, optional
The value to replace masked predictions with (e.g., np.nan).
verbose : int, default 0
Logging verbosity.
_logger : logging.Logger or callable, optional
Logger object.
stop_check : callable, optional
Function to check for early stopping.
Returns
-------
pd.DataFrame
A long-format DataFrame with predictions, actuals, and coordinates.
"""
vlog(
f"Starting model prediction formatting (verbose={verbose}).",
level=3,
verbose=verbose,
logger=_logger,
)
# --- 1. Obtain Model Predictions (if not provided) ---
if pihalnet_outputs is None:
if model is None or model_inputs is None:
raise ValueError(
"If 'pihalnet_outputs' is None, both 'model' and "
"'model_inputs' must be provided."
)
vlog(
" Predictions not provided, generating from model...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
pihalnet_outputs = model.predict(
model_inputs, verbose=0
)
if not isinstance(pihalnet_outputs, dict):
raise ValueError(
"Model output is not a dictionary as expected."
)
vlog(
" Model predictions generated.",
level=5,
verbose=verbose,
logger=_logger,
)
except Exception as e:
raise RuntimeError(
f"Failed to generate predictions from model: {e}"
) from e
if stop_check and stop_check():
raise InterruptedError(
"Model prediction aborted."
)
# --- 2. Unpack GeoPrior Model Output (if needed) ---
is_geoprior = str(model_name).lower().strip() in (
"geoprior",
"geopriorsubsnet",
)
if is_geoprior:
vlog(
f"GeoPrior model detected (model_name='{model_name}'). "
"Unpacking 'data_final' tensor.",
level=3,
verbose=verbose,
logger=_logger,
)
if output_dims is None:
output_dims = {"subs_pred": 1, "gwl_pred": 1}
vlog(
" `output_dims` not provided, defaulting to "
"{'subs_pred': 1, 'gwl_pred': 1}.",
level=4,
verbose=verbose,
logger=_logger,
)
pihalnet_outputs = _unpack_geoprior_outputs(
geoprior_outputs=pihalnet_outputs,
output_dims=output_dims,
quantiles=quantiles,
verbose=verbose,
_logger=_logger,
)
# --- 3. Ensure all data is NumPy & INVERSE TRANSFORM ---
if target_mapping is None:
target_mapping = {
"subs_pred": "subsidence",
"gwl_pred": "gwl",
}
vlog(
" `target_mapping` not provided, using defaults.",
level=4,
verbose=verbose,
logger=_logger,
)
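if not include_gwl:
# Build a filtered copy so the caller's mapping is not mutated.
target_mapping = {
k: v
for k, v in target_mapping.items()
if k != "gwl_pred"
}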
# Process predictions
processed_preds = {}
if pihalnet_outputs:
vlog(
" Converting and inverse-transforming predictions...",
level=4,
verbose=verbose,
logger=_logger,
)
for pred_key, base_name in target_mapping.items():
if pred_key in pihalnet_outputs:
val_tensor = pihalnet_outputs[pred_key]
pred_array = (
val_tensor.numpy()
if hasattr(val_tensor, "numpy")
else np.asarray(val_tensor)
)
inverted_array = _inverse_transform_array(
pred_array,
base_name,
scaler_info,
verbose,
_logger,
)
processed_preds[pred_key] = inverted_array
else:
vlog(
f" [WARN] Prediction key '{pred_key}' not found "
"in model outputs.",
level=2,
verbose=verbose,
logger=_logger,
)
if not processed_preds:
vlog(
" No valid prediction keys found in model output. "
"Returning empty DF.",
level=1,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
# Process actuals (y_true)
processed_actuals = {}
if y_true_dict:
vlog(
" Converting and inverse-transforming actuals (y_true)...",
level=4,
verbose=verbose,
logger=_logger,
)
for true_key, base_name in target_mapping.items():
# Look up the actuals by prediction key ('subs_pred') first,
# then fall back to the base name ('subsidence').
val_tensor = y_true_dict.get(true_key)
if val_tensor is None:
val_tensor = y_true_dict.get(base_name)
if val_tensor is not None:
true_array = (
val_tensor.numpy()
if hasattr(val_tensor, "numpy")
else np.asarray(val_tensor)
)
inverted_array = _inverse_transform_array(
true_array,
base_name,
scaler_info,
verbose,
_logger,
)
# Use true_key ('subs_pred') to align with processed_preds
processed_actuals[true_key] = inverted_array
else:
vlog(
f" [WARN] True data key '{true_key}' (or '{base_name}') "
"not found in y_true_dict.",
level=2,
verbose=verbose,
logger=_logger,
)
# --- 4. Initialize DataFrame (Base) ---
first_pred_key = list(processed_preds.keys())[0]
first_pred_array = processed_preds[first_pred_key]
num_samples = first_pred_array.shape[0]
H_inferred = forecast_horizon or first_pred_array.shape[1]
if num_samples == 0 or H_inferred == 0:
vlog(
" No samples or zero forecast horizon. Returning empty DF.",
level=1,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
vlog(
f" Formatting for {num_samples} samples, Horizon={H_inferred}.",
level=4,
verbose=verbose,
logger=_logger,
)
sample_indices = np.repeat(
np.arange(num_samples), H_inferred
)
forecast_steps = np.tile(
np.arange(1, H_inferred + 1), num_samples
)
all_data_dfs = [
pd.DataFrame(
{
"sample_idx": sample_indices,
"forecast_step": forecast_steps,
}
)
]
# --- 5. Populate DataFrame with INVERTED Data ---
if output_dims is None:
vlog(
" [WARN] `output_dims` is None. Inferring from arrays. "
"This may be unreliable for multi-output models.",
level=2,
verbose=verbose,
logger=_logger,
)
output_dims = {}
for k, v in processed_preds.items():
# The last axis is the output dimension for both point
# (B, H, O) and quantile (B, H, Q, O) arrays.
output_dims[k] = v.shape[-1]
if quantiles is not None:
df_part = _add_quantiles_to_df(
processed_preds,
target_mapping,
quantiles,
output_dims,
num_samples,
H_inferred,
verbose,
_logger,
)
all_data_dfs.append(df_part)
else:
df_part = _add_point_preds_to_df(
processed_preds,
target_mapping,
output_dims,
num_samples,
H_inferred,
verbose,
_logger,
)
all_data_dfs.append(df_part)
if processed_actuals:
df_part = _add_actuals_to_df(
processed_actuals,
target_mapping,
output_dims,
num_samples,
H_inferred,
verbose,
_logger,
)
all_data_dfs.append(df_part)
# --- 6. Add Coordinates (inverted inside helper) ---
df_coord_cols = []
if include_coords and model_inputs:
df_part, df_coord_cols = _add_coords_to_df(
model_inputs=model_inputs,
coord_scaler=coord_scaler,
num_samples=num_samples,
H_inferred=H_inferred,
verbose=verbose,
_logger=_logger,
force_future_from=kwargs.get(
"forecast_start_year"
),
)
all_data_dfs.append(df_part)
# --- 7. Add IDs (no inversion needed) ---
df_part = _add_ids_to_df(
ids_data_array,
ids_cols,
ids_cols_indices,
num_samples,
H_inferred,
verbose,
_logger,
)
all_data_dfs.append(df_part)
if stop_check and stop_check():
raise InterruptedError(
"DataFrame population aborted."
)
# --- 8. Concatenate all DataFrames ---
final_df = pd.concat(all_data_dfs, axis=1)
# --- 8b. Optional Masking (applied before coverage evaluation) ---
if apply_mask:
vlog(
" Applying mask to prediction columns...",
level=4,
verbose=verbose,
logger=_logger,
)
if mask_values is None or mask_fill_value is None:
raise ValueError(
"When apply_mask=True, both mask_values and "
"mask_fill_value must be provided."
)
# we assume you only ever want to mask against your first target:
first_base = list(target_mapping.values())[0]
ref_col = f"{first_base}_actual"
if ref_col not in final_df.columns:
vlog(
f" [WARN] Masking reference column '{ref_col}' not in "
"DataFrame. Skipping masking.",
level=2,
verbose=verbose,
logger=_logger,
)
else:
# collect all the forecast columns you produced
mask_cols = [
c
for c in final_df.columns
if any(
c.startswith(f"{base}_q")
for base in target_mapping.values()
)
or any(
c.startswith(f"{base}_pred")
for base in target_mapping.values()
)
]
if not mask_cols:
vlog(
" [WARN] No maskable forecast columns found in final_df. "
"Skipping masking.",
level=2,
verbose=verbose,
logger=_logger,
)
else:
try:
# Use the imported mask_by_reference function
final_df = mask_by_reference(
data=final_df,
ref_col=ref_col,
values=mask_values,
find_closest=False,
fill_value=mask_fill_value,
mask_columns=mask_cols,
error="ignore",
inplace=False,
)
vlog(
f" Successfully applied mask to {len(mask_cols)} "
f"columns based on '{ref_col}'.",
level=4,
verbose=verbose,
logger=_logger,
)
except Exception as e:
vlog(
f" [WARN] Failed to apply mask: {e}",
level=2,
verbose=verbose,
logger=_logger,
)
# --- 9. Evaluate Coverage (now compares inverted vs. inverted) ---
if evaluate_coverage and quantiles and processed_actuals:
_evaluate_coverage(
df=final_df,
target_mapping=target_mapping,
q_indices=coverage_quantile_indices,
savefile=savefile,
name=name,
verbose=verbose,
_logger=_logger,
)
if stop_check and stop_check():
raise InterruptedError("Metric/Masking aborted.")
vlog(
"Model prediction formatting to DataFrame complete.",
level=3,
verbose=verbose,
logger=_logger,
)
return final_df
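A minimal sketch of the long-format layout built above, assuming a hypothetical point-forecast array of shape (B, H, 1). The array values and the `subsidence_pred` column are illustrative only; the sketch shows why `np.repeat` and `np.tile` line up with a row-major flattening of the predictions.

```python
import numpy as np
import pandas as pd

# Hypothetical point-forecast array: 2 samples, horizon 3, 1 output.
preds = np.arange(6, dtype=float).reshape(2, 3, 1)
num_samples, horizon, _ = preds.shape

# Each sample index is repeated once per step; steps cycle 1..H.
sample_idx = np.repeat(np.arange(num_samples), horizon)
forecast_step = np.tile(np.arange(1, horizon + 1), num_samples)

long_df = pd.DataFrame(
    {
        "sample_idx": sample_idx,
        "forecast_step": forecast_step,
        # Row-major flattening keeps values aligned with both index columns.
        "subsidence_pred": preds.reshape(num_samples * horizon),
    }
)
```

Each row of `long_df` then corresponds to one (sample, step) pair, which is the shape the later concatenation along `axis=1` relies on.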
def _format_target_predictions(
predictions_np: np.ndarray,
num_samples: int,
H: int, # Horizon
O: int, # Output dim for this specific target
base_target_name: str,
quantiles: list[float] | None,
verbose: int = 0,
) -> tuple[list[str], pd.DataFrame]:
"""Helper to format predictions for a single target variable."""
pred_cols_names = []
# Expected input shapes to this helper:
# Point: (N, H, O)
# Quantile: (N, H, Q, O) OR (N, H, Q) if O=1 was pre-squeezed
if quantiles:
num_q = len(quantiles)
# Ensure predictions_np is (N, H, Q, O)
if (
predictions_np.ndim == 3
and predictions_np.shape[-1] == num_q
and O == 1
):
# Case: (N, H, Q), implies O=1
preds_to_process = np.expand_dims(
predictions_np, axis=-1
) # (N,H,Q,1)
elif (
predictions_np.ndim == 3
and predictions_np.shape[-1] == num_q * O
):
# Case: (N, H, Q*O)
preds_to_process = predictions_np.reshape(
(num_samples, H, num_q, O)
)
elif (
predictions_np.ndim == 4
and predictions_np.shape[2] == num_q
and predictions_np.shape[3] == O
):
# Case: (N, H, Q, O) - already correct
preds_to_process = predictions_np
else:
raise ValueError(
f"Unexpected quantile prediction shape for {base_target_name}: "
f"{predictions_np.shape}. Expected compatible with N={num_samples}, "
f"H={H}, Q={num_q}, O={O}"
)
# Now preds_to_process is (N, H, Q, O)
df_data_for_concat = []
for o_idx in range(O):
for q_idx, q_val in enumerate(quantiles):
col_name = f"{base_target_name}"
if O > 1:
col_name += f"_{o_idx}"
# round() avoids truncation such as int(0.29 * 100) == 28.
col_name += f"_q{int(round(q_val * 100))}"
pred_cols_names.append(col_name)
df_data_for_concat.append(
preds_to_process[
:, :, q_idx, o_idx
].reshape(-1)
)
pred_df_part = pd.DataFrame(
dict(
zip(
pred_cols_names,
df_data_for_concat,
strict=False,
)
)
)
else: # Point forecast
# predictions_np should be (N, H, O)
if (
predictions_np.ndim != 3
or predictions_np.shape[-1] != O
):
raise ValueError(
f"Unexpected point prediction shape for {base_target_name}: "
f"{predictions_np.shape}. Expected (N,H,O) with O={O}"
)
df_data_for_concat = []
for o_idx in range(O):
col_name = f"{base_target_name}"
if O > 1:
col_name += f"_{o_idx}"
col_name += "_pred"
pred_cols_names.append(col_name)
df_data_for_concat.append(
predictions_np[:, :, o_idx].reshape(-1)
)
pred_df_part = pd.DataFrame(
dict(
zip(
pred_cols_names,
df_data_for_concat,
strict=False,
)
)
)
return pred_cols_names, pred_df_part
def _unpack_geoprior_outputs(
geoprior_outputs: dict[str, Tensor],
output_dims: dict[str, int],
quantiles: list[float] | None,
verbose: int = 0,
_logger: Any = None,
) -> dict[str, Tensor]:
"""
Unpacks the nested 'data_final' tensor from GeoPriorSubsNet
into a flat dictionary expected by the formatter.
"""
if "data_final" not in geoprior_outputs:
vlog(
" [WARN] 'data_final' key not found in GeoPrior model output. "
"Assuming outputs are already flat.",
level=2,
verbose=verbose,
logger=_logger,
)
return geoprior_outputs # Return as-is
data_final_tensor = geoprior_outputs["data_final"]
# Get the split index for subsidence
s_dim = output_dims.get("subs_pred")
if s_dim is None:
vlog(
" [WARN] 'output_dims' did not contain 'subs_pred'. "
"Assuming subsidence output_dim = 1.",
level=2,
verbose=verbose,
logger=_logger,
)
s_dim = 1 # Default to 1 as a fallback
# Split the tensor based on whether quantiles are present
if quantiles is not None:
# Shape is (B, H, Q, O_total)
vlog(
f" Splitting 4D quantile tensor at final axis index {s_dim}.",
level=4,
verbose=verbose,
logger=_logger,
)
s_pred_tensor = data_final_tensor[..., :s_dim]
h_pred_tensor = data_final_tensor[..., s_dim:]
else:
# Shape is (B, H, O_total)
vlog(
f" Splitting 3D point-forecast tensor at final axis index {s_dim}.",
level=4,
verbose=verbose,
logger=_logger,
)
s_pred_tensor = data_final_tensor[..., :s_dim]
h_pred_tensor = data_final_tensor[..., s_dim:]
# Create the new flat dictionary
unpacked_outputs = {
"subs_pred": s_pred_tensor,
"gwl_pred": h_pred_tensor,
}
# Pass through any other keys that aren't the standard nested ones
for k, v in geoprior_outputs.items():
if k not in [
"data_final",
"data_mean",
"phys_mean_raw",
]:
unpacked_outputs[k] = v
return unpacked_outputs
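The split performed above can be sketched with plain NumPy. The array below is a hypothetical stand-in for the 'data_final' tensor (random values, two channels on the last axis); only the slicing on the final axis mirrors the helper.

```python
import numpy as np

# Hypothetical combined output: batch 4, horizon 3, 3 quantiles,
# and 2 channels on the last axis (1 subsidence + 1 groundwater).
data_final = np.random.default_rng(0).normal(size=(4, 3, 3, 2))
output_dims = {"subs_pred": 1, "gwl_pred": 1}

# The first s_dim channels are subsidence; the remainder is groundwater.
s_dim = output_dims.get("subs_pred", 1)
unpacked = {
    "subs_pred": data_final[..., :s_dim],
    "gwl_pred": data_final[..., s_dim:],
}
```

Because the ellipsis covers every leading axis, the same two slices work for both the 4D quantile layout and the 3D point-forecast layout.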
def _inverse_transform_array(
array: np.ndarray,
base_name: str,
scaler_info: dict[str, dict[str, Any]] | None,
verbose: int = 0,
_logger: Any = None,
) -> np.ndarray:
"""
Helper to inverse-transform a single numpy array using scaler_info.
This function finds the correct scaler and column index from
scaler_info and applies the inverse min-max scaling.
"""
if scaler_info is None or base_name not in scaler_info:
# No scaler info for this target, return original array
vlog(
f" - No scaler_info found for '{base_name}'. "
"Returning original (scaled) array.",
level=4,
verbose=verbose,
logger=_logger,
)
return array
try:
info = scaler_info[base_name]
scaler = info["scaler"]
col_index = info["idx"]
# Get the specific min and scale for this column from the scaler
# scaler.min_ is the 'intercept'
# scaler.scale_ is the 'coefficient'
min_val = scaler.min_[col_index]
scale_val = scaler.scale_[col_index]
# The inverse of X_scaled = X * scale_ + min_ is
# X = (X_scaled - min_) / scale_.
if scale_val == 0:
# Handle case where scale is zero (all original values were the same)
# The inverse is the original constant value, which we get from data_min_
try:
original_min = scaler.data_min_[col_index]
except AttributeError:
# Fallback if data_min_ is not available (should be on a fitted scaler)
original_min = 0 # This is a guess, but division by zero is worse
vlog(
f" [WARN] Scaler for '{base_name}' has scale_=0 but "
"no 'data_min_' attribute. Inverse may be incorrect.",
level=2,
verbose=verbose,
logger=_logger,
)
inverted_array = np.full_like(array, original_min)
else:
inverted_array = (array - min_val) / scale_val
vlog(
f" - Applied inverse transform to '{base_name}' array "
f"(min_attr: {min_val:.4f}, scale_attr: {scale_val:.4f})",
level=5,
verbose=verbose,
logger=_logger,
)
return inverted_array
except AttributeError:
vlog(
f" [WARN] Scaler for '{base_name}' is missing 'min_' or 'scale_'"
" attributes. Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
except IndexError:
vlog(
f" [WARN] Column index '{col_index}' out of bounds for "
f"scaler of '{base_name}'. Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
except Exception as e:
vlog(
f" [WARN] Failed to inverse transform '{base_name}': {e}"
". Returning original array.",
level=2,
verbose=verbose,
logger=_logger,
)
return array
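A minimal sketch of the per-column inverse used above. The `SimpleNamespace` object is a hypothetical stand-in for a fitted min-max scaler (it only carries `min_` and `scale_`), and `invert_column` is an illustrative name, not a library function.

```python
from types import SimpleNamespace

import numpy as np

# Stand-in for a fitted min-max scaler over two columns (hypothetical
# values). Forward transform: X_scaled = X * scale_ + min_.
data_min = np.array([10.0, -5.0])
data_max = np.array([30.0, 5.0])
scale = 1.0 / (data_max - data_min)
scaler = SimpleNamespace(scale_=scale, min_=-data_min * scale)


def invert_column(scaled, scaler, col_index):
    """Invert min-max scaling for one column of the scaler."""
    return (scaled - scaler.min_[col_index]) / scaler.scale_[col_index]


original = np.array([[10.0, 20.0], [25.0, 30.0]])  # any array shape works
scaled = original * scaler.scale_[0] + scaler.min_[0]
recovered = invert_column(scaled, scaler, col_index=0)
```

Because only scalar `min_` and `scale_` entries are used, the inverse applies elementwise to arrays of any rank, which is why (B, H, O) and (B, H, Q, O) inputs need no reshaping.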
def _add_quantiles_to_df(
predictions: dict[str, np.ndarray],
target_mapping: dict[str, str],
quantiles: list[float],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED quantile predictions into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in predictions:
continue
pred_array = predictions[
pred_key
] # Should be (B, H, Q, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = pred_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if pred_array.shape != (
num_samples,
H_inferred,
len(quantiles),
O_target,
):
vlog(
f" [WARN] Quantile array for '{pred_key}' has unexpected shape "
f"{pred_array.shape}. Expected {(num_samples, H_inferred, len(quantiles), O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
# Reshape: (B, H, Q, O) -> (B*H, Q*O). In C order the flattened
# columns run quantile-major (all outputs for q0, then q1, ...).
reshaped_preds = pred_array.reshape(
num_samples * H_inferred, -1
)
# Create column names in the same quantile-major order so that
# labels stay aligned with the data when O_target > 1.
col_names = []
for q in quantiles:
# Zero-padded percent label (e.g., q05, q50, q95); round()
# avoids truncation such as int(0.29 * 100) == 28.
q_str = f"q{int(round(q * 100)):02d}"
for o_idx in range(O_target):
col_name = f"{base_name}_{q_str}"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
if reshaped_preds.shape[1] != len(col_names):
vlog(
f" [WARN] Mismatch in quantile columns for '{base_name}'. "
f"Expected {len(col_names)} columns, found {reshaped_preds.shape[1]}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
df_parts.append(
pd.DataFrame(reshaped_preds, columns=col_names)
)
vlog(
f" - Added quantile columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
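A short sketch of how quantile labels relate to a flattened (B, H, Q, O) array, assuming hypothetical shapes and a helper named `quantile_label` introduced here for illustration. It shows that a C-order reshape of the last two axes is quantile-major, so labels generated in that order stay aligned with the data.

```python
import numpy as np
import pandas as pd


def quantile_label(q):
    """Zero-padded percent label, e.g. 0.05 -> 'q05'."""
    return f"q{int(round(q * 100)):02d}"


quantiles = [0.05, 0.5, 0.95]
B, H, Q, O = 2, 3, len(quantiles), 2
preds = np.arange(B * H * Q * O, dtype=float).reshape(B, H, Q, O)

# C-order flattening of the last two axes is quantile-major:
# flat column index = q_idx * O + o_idx.
flat = preds.reshape(B * H, Q * O)
columns = [
    f"subsidence_{quantile_label(q)}_{o_idx}"
    for q in quantiles
    for o_idx in range(O)
]
quantile_df = pd.DataFrame(flat, columns=columns)
```

With a single output (O = 1) the two possible loop orders give the same labels, so the ordering only matters for multi-output targets.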
def _add_point_preds_to_df(
predictions: dict[str, np.ndarray],
target_mapping: dict[str, str],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED point predictions into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in predictions:
continue
pred_array = predictions[pred_key] # (B, H, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = pred_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if pred_array.shape != (
num_samples,
H_inferred,
O_target,
):
vlog(
f" [WARN] Point pred array for '{pred_key}' has unexpected shape "
f"{pred_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
reshaped_preds = pred_array.reshape(
num_samples * H_inferred, O_target
)
col_names = []
for o_idx in range(O_target):
col_name = f"{base_name}_pred"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
df_parts.append(
pd.DataFrame(reshaped_preds, columns=col_names)
)
vlog(
f" - Added point pred columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
def _add_actuals_to_df(
y_true_dict: dict[str, np.ndarray],
target_mapping: dict[str, str],
output_dims: dict[str, int],
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Formats INVERTED actuals into a DataFrame."""
df_parts = []
for pred_key, base_name in target_mapping.items():
if pred_key not in y_true_dict:
continue
true_array = y_true_dict[pred_key] # (B, H, O)
O_target = output_dims.get(pred_key)
if O_target is None:
O_target = true_array.shape[-1]
vlog(
f" [WARN] output_dim for '{pred_key}' not specified. "
f"Inferred {O_target} from array shape.",
level=2,
verbose=verbose,
logger=_logger,
)
if true_array.shape != (
num_samples,
H_inferred,
O_target,
):
vlog(
f" [WARN] Actuals array for '{pred_key}' has unexpected shape "
f"{true_array.shape}. Expected {(num_samples, H_inferred, O_target)}. "
"Skipping.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
reshaped_true = true_array.reshape(
num_samples * H_inferred, O_target
)
col_names = []
for o_idx in range(O_target):
col_name = f"{base_name}_actual"
if O_target > 1:
col_name += f"_{o_idx}"
col_names.append(col_name)
df_parts.append(
pd.DataFrame(reshaped_true, columns=col_names)
)
vlog(
f" - Added actuals columns for '{base_name}': {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
return (
pd.concat(df_parts, axis=1)
if df_parts
else pd.DataFrame()
)
def _add_coords_to_df(
model_inputs: dict[str, Tensor],
coord_scaler: Any,
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
force_future_from: float | None = None,
) -> tuple[pd.DataFrame, list[str]]:
"""Formats and inverse-transforms coordinates."""
coord_names = ["coord_t", "coord_x", "coord_y"]
if (
"coords" not in model_inputs
or model_inputs["coords"] is None
):
vlog(
" 'coords' not in model_inputs. Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(), []
coords_arr = model_inputs["coords"]
if hasattr(coords_arr, "numpy"):
coords_arr = coords_arr.numpy()
if coords_arr.shape != (num_samples, H_inferred, 3):
vlog(
f" 'coords' shape mismatch ({coords_arr.shape}). Expected "
f"{(num_samples, H_inferred, 3)}. Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(), []
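# Copy so the optional time override below never writes into
# the caller's `model_inputs["coords"]` buffer.
coords_reshaped = coords_arr.reshape(
num_samples * H_inferred, 3
).copy()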
if coord_scaler is not None:
vlog(
" Applying inverse transform to t,x,y coordinates...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
coords_reshaped = coord_scaler.inverse_transform(
coords_reshaped
)
except Exception as e:
vlog(
f" [WARN] Could not inverse transform coordinates: {e}. "
"Using normalized coordinates.",
level=2,
verbose=verbose,
logger=_logger,
)
# --- override time with forecast years, if requested ---
if force_future_from is not None:
# years_for_one_sample: e.g. [2023, 2024, 2025] for H_inferred=3
years = np.arange(
float(force_future_from),
float(force_future_from) + float(H_inferred),
dtype=coords_reshaped.dtype,
)
# Repeat for all samples and flatten to match (num_samples*H, )
years_tiled = np.tile(years, num_samples)
coords_reshaped[:, 0] = years_tiled
vlog(
f" Overriding coord_t with forecast years starting at "
f"{force_future_from} for horizon={H_inferred}.",
level=4,
verbose=verbose,
logger=_logger,
)
vlog(
f" Added coordinate columns: {coord_names}",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(
coords_reshaped, columns=coord_names
), coord_names
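The time override above can be sketched in isolation. The coordinate array and the start year are hypothetical values; the point is that one block of consecutive years is tiled once per sample so it matches the flattened (B*H, 3) layout.

```python
import numpy as np

num_samples, horizon = 2, 3
# Hypothetical (t, x, y) coordinates, already flattened to (B*H, 3).
coords = np.zeros((num_samples * horizon, 3), dtype=float)
forecast_start_year = 2023.0

# One run of consecutive years per sample, repeated for every sample.
years = np.arange(forecast_start_year, forecast_start_year + horizon)
coords[:, 0] = np.tile(years, num_samples)
```

Only the time column is overwritten; the x and y columns keep whatever the inverse transform produced.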
def _add_ids_to_df(
ids_data_array: Any,
ids_cols: list[str] | None,
ids_cols_indices: list[int] | None,
num_samples: int,
H_inferred: int,
verbose: int = 0,
_logger: Any = None,
) -> pd.DataFrame:
"""Extracts and repeats static/ID columns."""
if ids_data_array is None:
vlog(
" No `ids_data_array` provided, skipping static/ID columns.",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
vlog(
" Processing additional static/ID columns...",
level=4,
verbose=verbose,
logger=_logger,
)
ids_np_array = None
ids_cols_to_use = []
if isinstance(ids_data_array, pd.DataFrame):
ids_cols_to_use = (
list(ids_cols)
if ids_cols
else list(ids_data_array.columns)
)
try:
ids_np_array = ids_data_array[
ids_cols_to_use
].values
except KeyError as e:
vlog(
f" [WARN] Columns not in ids_data_array: {e}. Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
elif isinstance(ids_data_array, np.ndarray):
ids_np_array = ids_data_array
if ids_cols_indices:
ids_np_array = ids_np_array[:, ids_cols_indices]
if (
ids_cols
and len(ids_cols) == ids_np_array.shape[1]
):
ids_cols_to_use = list(ids_cols)
else:
ids_cols_to_use = [
f"id_{k}"
for k in range(ids_np_array.shape[1])
]
elif isinstance(ids_data_array, Tensor): # Is a Tensor
ids_np_array = ids_data_array.numpy()
if ids_cols_indices:
ids_np_array = ids_np_array[:, ids_cols_indices]
if (
ids_cols
and len(ids_cols) == ids_np_array.shape[1]
):
ids_cols_to_use = list(ids_cols)
else:
ids_cols_to_use = [
f"id_{k}"
for k in range(ids_np_array.shape[1])
]
else:
vlog(
f" [WARN] Unsupported type for `ids_data_array`: "
f"{type(ids_data_array)}. Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
# Now, repeat the data
if (
ids_np_array is not None
and ids_np_array.shape[0] == num_samples
):
# np.repeat emits each sample row H_inferred times consecutively,
# matching the ordering of `sample_idx`.
expanded_ids_data = np.repeat(
ids_np_array, H_inferred, axis=0
)
vlog(
f" - Added static/ID columns: {ids_cols_to_use}",
level=5,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame(
expanded_ids_data, columns=ids_cols_to_use
)
else:
vlog(
f" [WARN] `ids_data_array` sample size"
f" ({ids_np_array.shape[0] if ids_np_array is not None else 'None'}) "
f"does not match predictions ({num_samples}). Skipping IDs.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
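A small sketch of the ID expansion, using a hypothetical two-sample ID array. The `id_0`, `id_1` names follow the fallback naming used above when no column names are supplied.

```python
import numpy as np
import pandas as pd

# Hypothetical static IDs: one row per sample.
ids = np.array([[101, 1], [202, 2]])
horizon = 3

# Repeat each sample row once per forecast step, keeping rows grouped
# by sample so they align with the sample_idx column.
expanded = np.repeat(ids, horizon, axis=0)
ids_df = pd.DataFrame(
    expanded, columns=[f"id_{k}" for k in range(ids.shape[1])]
)
```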
def _evaluate_coverage(
df: pd.DataFrame,
target_mapping: dict[str, str],
q_indices: tuple[int, int],
savefile: str | None = None,
name: str | None = None,
verbose: int = 0,
_logger: Any = None,
):
"""Calculates and logs quantile coverage from an INVERTED DataFrame."""
vlog(
" Evaluating forecast coverage...",
level=4,
verbose=verbose,
logger=_logger,
)
for base_name in target_mapping.values():
actual_col = f"{base_name}_actual"
# Find quantile columns
q_cols = sorted(
[
c
for c in df.columns
if c.startswith(base_name) and "_q" in c
]
)
if not q_cols or actual_col not in df.columns:
vlog(
f" Skipping coverage for '{base_name}': "
"Missing actuals or quantile columns.",
level=3,
verbose=verbose,
logger=_logger,
)
continue
try:
q_low_col = q_cols[q_indices[0]]
q_high_col = q_cols[q_indices[1]]
except IndexError:
vlog(
f" [WARN] Coverage q_indices {q_indices} are out of bounds "
f"for detected quantile columns: {q_cols}. Skipping coverage.",
level=2,
verbose=verbose,
logger=_logger,
)
continue
q_low = df[q_low_col]
q_high = df[q_high_col]
actual = df[actual_col]
is_covered = (actual >= q_low) & (actual <= q_high)
coverage_percent = is_covered.mean() * 100
vlog(
f" Forecast Coverage ({base_name}, {q_low_col} to {q_high_col}): "
f"{coverage_percent:.2f}%",
level=1,
verbose=verbose,
logger=_logger,
)
# Store in df attributes
df.attrs[f"{base_name}_coverage"] = coverage_percent
# Use compute_quantile_diagnostics
try:
cov_savefile = None
if savefile:
# If savefile has an extension, treat it as a file path.
# Example: ".../subs_eval.csv"
# -> ".../subs_eval_diagnostics_results.json"
root, ext = os.path.splitext(savefile)
if ext:
cov_savefile = (
f"{root}_diagnostics_results.json"
)
else:
# No extension: treat `savefile` as a directory-like path
# Example: "results/nansha_eval"
# -> "results/nansha_eval/diagnostics_results.json"
cov_savefile = os.path.join(
savefile, "diagnostics_results.json"
)
compute_quantile_diagnostics(
df,
target_name=base_name,
quantiles=[
float(c.split("_q")[-1]) / 100.0
for c in q_cols
],
coverage_quantile_indices=q_indices,
savefile=cov_savefile,
name=name,
verbose=verbose,
logger=_logger,
)
except Exception as e:
vlog(
f" [WARN] Failed to compute quantile diagnostics: {e}",
level=2,
verbose=verbose,
logger=_logger,
)
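The coverage figure logged above is the share of rows whose actual value falls inside the chosen quantile band. A minimal sketch with hypothetical numbers, selecting the lowest and highest quantile columns the way the default `(0, -1)` indices do:

```python
import pandas as pd

# Hypothetical long-format frame with an 80% interval and actuals.
df = pd.DataFrame(
    {
        "subsidence_q10": [1.0, 2.0, 3.0, 4.0],
        "subsidence_q90": [3.0, 4.0, 5.0, 6.0],
        "subsidence_actual": [2.0, 5.0, 3.0, 7.0],
    }
)

# Sorting zero-padded labels puts the lowest quantile first.
q_cols = sorted(
    c for c in df.columns if c.startswith("subsidence") and "_q" in c
)
low_col, high_col = q_cols[0], q_cols[-1]

actual = df["subsidence_actual"]
covered = (actual >= df[low_col]) & (actual <= df[high_col])
coverage_percent = covered.mean() * 100
```

Here two of the four actuals lie inside their interval, so the sketch reports 50% coverage. Bounds are inclusive on both sides, as in the helper.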
def _get_model_predictions(
pihalnet_outputs: dict[str, Tensor] | None,
model: Model | None,
model_inputs: dict[str, Tensor] | None,
verbose: int,
stop_check: Callable[[], bool],
_logger: Any,
) -> dict[str, Tensor]:
"""
Obtain model outputs dict, running predict if needed.
"""
if pihalnet_outputs is not None:
return pihalnet_outputs
if model is None or model_inputs is None:
raise ValueError(
"If 'pihalnet_outputs' is None, both 'model' and "
"'model_inputs' must be provided."
)
vlog(
" Predictions not provided, generating from model...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
outputs = model.predict(model_inputs, verbose=0)
if not isinstance(outputs, dict):
raise ValueError(
"Model output is not a dict as expected from PINN models"
)
except Exception as e:
raise RuntimeError(
f"Failed to generate predictions: {e}"
) from e
vlog(
" Model predictions generated.",
level=5,
verbose=verbose,
logger=_logger,
)
if stop_check and stop_check():
raise InterruptedError("Model prediction aborted.")
return outputs
def _convert_outputs_to_numpy(
pihalnet_outputs: dict[str, Any],
verbose: int,
_logger,
) -> dict[str, np.ndarray]:
"""
Ensure prediction tensors are NumPy arrays.
"""
proc: dict[str, np.ndarray] = {}
for key, val in pihalnet_outputs.items():
if key in ("subs_pred", "gwl_pred"):
if hasattr(val, "numpy"):
proc[key] = val.numpy()
elif isinstance(val, np.ndarray):
proc[key] = val
else:
try:
proc[key] = np.array(val)
except Exception as e:
raise TypeError(
f"Could not convert output '{key}' to NumPy. "
f"Type: {type(val)}. Error: {e}"
) from e
elif key == "pde_residual":
# PDE residual can be handled differently if needed
pass # Not typically added to this output DataFrame
if not proc:
vlog(
" No 'subs_pred' or 'gwl_pred'"
" found in outputs. Returning empty dict.",
level=1,
verbose=verbose,
logger=_logger,
)
return {}
return proc
def _define_targets(
processed: dict[str, np.ndarray],
include_gwl: bool,
target_mapping: dict[str, str] | None = None,
stop_check: Callable[[], bool] = None,
) -> dict[str, str]:
"""
Map processed outputs to base target names.
"""
if target_mapping is None:
target_mapping = {
"subs_pred": "subsidence",
"gwl_pred": "gwl",
}
targets: dict[str, str] = {}
if "subs_pred" in processed:
targets["subs_pred"] = target_mapping["subs_pred"]
if include_gwl and "gwl_pred" in processed:
targets["gwl_pred"] = target_mapping["gwl_pred"]
if stop_check and stop_check():
raise InterruptedError(
"Target configuration aborted."
)
return targets
def _infer_dimensions(
first_pred: np.ndarray,
forecast_horizon: int | None,
verbose: int = 0,
_logger: Any = print,
) -> tuple[int, int]:
"""
Infer num_samples and horizon (H) from the array.
"""
N = first_pred.shape[0]
H = forecast_horizon or first_pred.shape[1]
if N == 0 or H == 0:
raise ValueError(
"No samples or zero forecast horizon."
)
vlog(
f" Formatting for {N} samples, Horizon={H}.",
level=4,
verbose=verbose,
logger=_logger,
)
return N, H
def _build_sample_df(
num_samples: int, H: int
) -> pd.DataFrame:
"""
Build DataFrame of sample_idx and forecast_step.
"""
idx = np.repeat(np.arange(num_samples), H)
steps = np.tile(np.arange(1, H + 1), num_samples)
return pd.DataFrame(
{"sample_idx": idx, "forecast_step": steps}
)
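`_build_sample_df` lays out one row per (sample, step) pair: `np.repeat` holds each sample index for `H` consecutive rows while `np.tile` cycles the steps `1..H`. A pure-Python sketch of the same ordering follows. The `index_rows` helper is a hypothetical name, not part of the module.

```python
def index_rows(num_samples: int, horizon: int) -> list[tuple[int, int]]:
    """Return (sample_idx, forecast_step) pairs in sample-major order."""
    return [
        (sample, step)
        for sample in range(num_samples)   # each sample spans `horizon` rows
        for step in range(1, horizon + 1)  # steps are 1-based and restart per sample
    ]
```

With this ordering, row `k` belongs to sample `k // H` and step `k % H + 1`, which is why an `(N, H, ...)` array reshaped to `(N * H, ...)` lines up with the base frame row for row.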
def _add_coordinate_columns(
base_df: pd.DataFrame,
model_inputs: dict[str, Any] | None,
coord_scaler: Any,
include_coords: bool,
stop_check: Callable[[], bool],
_logger: Any,
verbose: int,
) -> pd.DataFrame:
"""
Extract and inverse-transform coords if requested.
"""
if not include_coords:
return pd.DataFrame()
coords_arr = None
if model_inputs and "coords" in model_inputs:
coords_arr = model_inputs["coords"]
if hasattr(coords_arr, "numpy"):
coords_arr = coords_arr.numpy()
if coords_arr is None:
return pd.DataFrame()
N = base_df["sample_idx"].nunique()
H = base_df["forecast_step"].max()
if coords_arr.shape != (N, H, 3):
vlog(
" 'coords' shape mismatch or not found."
" Skipping coordinate columns.",
level=2,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
resh = coords_arr.reshape(N * H, 3)
if coord_scaler:
vlog(
" Applying inverse transform to t,x,y coordinates...",
level=4,
verbose=verbose,
logger=_logger,
)
try:
resh = coord_scaler.inverse_transform(resh)
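except Exception as e:
# Keep the scaled coordinates, but report the failure.
vlog(
f" [WARN] Inverse transform of coords failed: {e}",
level=2,
verbose=verbose,
logger=_logger,
)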
if stop_check and stop_check():
raise InterruptedError(
"Coordinates transformation aborted."
)
return pd.DataFrame(
resh, columns=["coord_t", "coord_x", "coord_y"]
)
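The coordinate block of shape `(N, H, 3)` is flattened to `(N * H, 3)` so that it aligns row for row with the base frame, then optionally passed through `coord_scaler.inverse_transform`. The sketch below shows the arithmetic, assuming the scaler is a min-max scaler fitted to the range `[0, 1]`. The `flatten_coords` and `inverse_min_max` helpers and the bounds used with them are illustrative, not GeoPrior API.

```python
def flatten_coords(coords: list[list[list[float]]]) -> list[list[float]]:
    """Collapse nested (N, H, 3) lists into N * H rows of (t, x, y)."""
    return [step for sample in coords for step in sample]


def inverse_min_max(
    rows: list[list[float]],
    data_min: list[float],
    data_max: list[float],
) -> list[list[float]]:
    """Undo x_scaled = (x - min) / (max - min), column by column."""
    return [
        [lo + v * (hi - lo) for v, lo, hi in zip(row, data_min, data_max)]
        for row in rows
    ]
```

For example, a scaled `x` of `0.5` with bounds `[10, 20]` maps back to `15`.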
def _add_id_columns(
base_df: pd.DataFrame,
ids_data_array: Any,
ids_cols: list[str] | None,
ids_cols_indices: list[int] | None,
stop_check: Callable[[], bool],
_logger: Any,
verbose: int,
) -> pd.DataFrame:
"""
Process and repeat static ID columns per forecast step.
Converts DataFrame, Series, NumPy array, or Tensor to
NumPy, selects requested columns or indices, then uses
`sample_idx` from base_df to expand each ID row
across all forecast steps.
"""
# Skip if no IDs provided
if ids_data_array is None:
vlog(
" No `ids_data_array` provided, skipping ID columns.",
level=4,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
vlog(
" Processing additional static/ID columns...",
level=4,
verbose=verbose,
logger=_logger,
)
# Convert to NumPy array
try:
if isinstance(ids_data_array, pd.DataFrame):
arr = ids_data_array.values
col_names = ids_cols or list(
ids_data_array.columns
)
elif isinstance(ids_data_array, pd.Series):
arr = ids_data_array.to_frame().values
col_names = ids_cols or [
ids_data_array.name or "id_0"
]
if (
arr.ndim == 2
and len(col_names) != arr.shape[1]
):
vlog(
" Series IDs length mismatch, using first col.",
level=5,
verbose=verbose,
logger=_logger,
)
col_names = col_names[:1]
elif hasattr(ids_data_array, "numpy"):
arr = ids_data_array.numpy()
col_names = ids_cols or [
f"id_{i}" for i in range(arr.shape[1])
]
elif isinstance(ids_data_array, np.ndarray):
arr = ids_data_array
if ids_cols_indices is not None:
arr = arr[:, ids_cols_indices]
col_names = ids_cols or [
f"id_{i}" for i in range(arr.shape[1])
]
else:
_logger.warning(
f"Unsupported IDs type {type(ids_data_array)}."
" Skipping ID columns."
)
return pd.DataFrame()
except Exception as e:
_logger.warning(
f"Error converting IDs to array: {e}."
" Skipping ID columns."
)
return pd.DataFrame()
vlog(
f" Converted IDs to NumPy with shape {arr.shape}",
level=5,
verbose=verbose,
logger=_logger,
)
vlog(
f" Selected ID column names: {col_names}",
level=5,
verbose=verbose,
logger=_logger,
)
# Expand rows according to sample_idx
sample_indices = base_df["sample_idx"].values
if arr.shape[0] != sample_indices.max() + 1:
_logger.warning(
f"IDs rows ({arr.shape[0]}) != samples "
f"({sample_indices.max() + 1}). Skipping."
)
return pd.DataFrame()
expanded = arr[sample_indices]
id_df = pd.DataFrame(expanded, columns=col_names)
vlog(
f" Expanded ID DataFrame shape: {id_df.shape}",
level=5,
verbose=verbose,
logger=_logger,
)
# User cancellation check
if stop_check and stop_check():
raise InterruptedError(
"ID columns processing aborted."
)
return id_df
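`_add_id_columns` repeats each static ID row once per forecast step by indexing the ID array with the `sample_idx` column (`arr[sample_indices]`). A list-based sketch of that lookup is shown below. The `expand_ids` helper is hypothetical and only mirrors the row-count check.

```python
from typing import Optional


def expand_ids(
    ids: list[tuple], sample_idx: list[int]
) -> Optional[list[tuple]]:
    """Repeat ID rows to match sample_idx; return None on a row-count mismatch."""
    if not sample_idx or len(ids) != max(sample_idx) + 1:
        return None  # the real helper logs a warning and skips the ID columns
    return [ids[i] for i in sample_idx]
```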
def _process_target_variable(
preds: np.ndarray,
pred_key: str,
base_name: str,
num_samples: int,
H: int,
quantiles: list[float] | None,
y_true_dict: dict[str, Any] | None,
scaler_info: dict[str, Any] | None,
verbose: int,
stop_check: Callable[[], bool],
_logger: Any,
) -> pd.DataFrame:
"""
Format predictions, attach actuals, and prepare scaling.
- Calls _format_target_predictions for quantiles
or point forecasts.
- Reshapes and adds true values if provided.
- Records inverse-scaling actions to apply later.
"""
# 1. Format predictions
cols_pred, df_pred = _format_target_predictions(
preds,
num_samples,
H,
O=(preds.shape[-1] if preds.ndim == 3 else 1),
base_target_name=base_name,
quantiles=quantiles,
verbose=verbose,
)
dfs: list[pd.DataFrame] = [df_pred]
# Initialized here so step 3 still works when no actuals are supplied.
cols_act: list[str] = []
# 2. Add actual values
if y_true_dict:
y_true = y_true_dict.get(pred_key)
if y_true is None:
y_true = y_true_dict.get(base_name)
if hasattr(y_true, "numpy"):
y_true = y_true.numpy()
if y_true is not None:
arr = y_true.reshape(num_samples * H, -1)
cols_act = []
for i in range(arr.shape[1]):
col = base_name
if arr.shape[1] > 1:
col += f"_{i}"
cols_act.append(f"{col}_actual")
dfs.append(pd.DataFrame(arr, columns=cols_act))
# 3. Prepare inverse-scaling info
if scaler_info and base_name in scaler_info:
info = scaler_info[base_name]
to_transform = cols_pred + cols_act
info["_cols_to_inv"] = to_transform
if stop_check and stop_check():
raise InterruptedError(
"Target variable processing aborted."
)
return pd.concat(dfs, axis=1)
def _apply_masking(
df: pd.DataFrame,
targets: dict[str, str],
apply_mask: bool,
mask_values: Any,
mask_fill_value: float,
quantiles: list[float] | None,
) -> pd.DataFrame:
"""
Mask forecasts based on reference column values.
"""
if not apply_mask:
return df
first_base = list(targets.values())[0]
ref = f"{first_base}_actual"
if ref not in df:
raise KeyError(f"Reference col '{ref}' missing")
cols = []
if quantiles:
for base in targets.values():
for q in quantiles:
cols.append(f"{base}_q{int(q * 100)}")
else:
cols = [c for c in df if c.endswith("_pred")]
return mask_by_reference(
data=df,
ref_col=ref,
values=mask_values,
find_closest=False,
fill_value=mask_fill_value,
mask_columns=cols,
error="ignore",
inplace=False,
)
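Quantile columns are addressed by name: `_apply_masking` builds `f"{base}_q{int(q * 100)}"`, and the diagnostics path recovers the level with `float(c.split("_q")[-1]) / 100.0`. A small sketch of that round trip is given below. The function names `quantile_column` and `quantile_level` are hypothetical.

```python
def quantile_column(base: str, q: float) -> str:
    """Build the column name used for quantile q of a target."""
    return f"{base}_q{int(q * 100)}"


def quantile_level(column: str) -> float:
    """Recover the quantile level from a '<base>_qNN' column name."""
    return float(column.split("_q")[-1]) / 100.0
```

Taking the last `_q` segment keeps the parse working even when the base target name itself contains `_q`.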
def _compute_coverage(
df: pd.DataFrame,
targets: dict[str, str],
quantiles: list[float],
y_true_dict: dict[str, Any],
coverage_quantile_indices: tuple[int, int],
savefile: str | None,
name: str | None,
verbose: int,
_logger: Any,
) -> None:
"""
Compute quantile diagnostics per target.
"""
for base in targets.values():
try:
compute_quantile_diagnostics(
df,
target_name=base,
quantiles=quantiles,
coverage_quantile_indices=coverage_quantile_indices,
savefile=(
os.path.join(
os.path.dirname(savefile),
"diagnostics_results.json",
)
if savefile
else None
),
name=name,
verbose=verbose,
logger=_logger,
)
except Exception as e:
vlog(
f"Skipping coverage due to {e}",
level=2,
verbose=verbose,
logger=_logger,
)
# XXX : TODO
# revise format outputs and apply
# mask by reference
def format_preds(
pihalnet_outputs: dict[str, Tensor] | None = None,
model: Model | None = None,
model_inputs: dict[str, Tensor] | None = None,
y_true_dict: dict[str, Any] | None = None,
target_mapping: dict[str, str] | None = None,
include_gwl: bool = True,
include_coords: bool = True,
quantiles: list[float] | None = None,
forecast_horizon: int | None = None,
output_dims: dict[str, int] | None = None,
ids_data_array: Any | None = None,
ids_cols: list[str] | None = None,
ids_cols_indices: list[int] | None = None,
scaler_info: dict[str, Any] | None = None,
coord_scaler: Any | None = None,
evaluate_coverage: bool = False,
coverage_quantile_indices: tuple[int, int] = (0, -1),
savefile: str | None = None,
name: str | None = None,
apply_mask: bool = False,
mask_values: Any | None = None,
mask_fill_value: float | None = None,
verbose: int = 0,
_logger: Any = None,
stop_check: Callable[[], bool] = None,
**kwargs,
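) -> pd.DataFrame:
"""
Format PINN model outputs into a long-format DataFrame.
Produces one row per (sample_idx, forecast_step) and, when available,
appends coordinate, ID, prediction, and actual-value columns.
Optionally masks forecasts and computes quantile diagnostics.
"""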
vlog(
"Starting formatting",
level=3,
verbose=verbose,
logger=_logger,
)
# Step 1: obtain raw outputs
raw_out = _get_model_predictions(
pihalnet_outputs,
model,
model_inputs,
verbose,
stop_check,
_logger,
)
# Step 2: convert to NumPy
proc = _convert_outputs_to_numpy(
raw_out, verbose, _logger
)
# Step 3: define targets
targets = _define_targets(
proc,
include_gwl,
target_mapping,
stop_check,
)
if not targets:
vlog(
" No valid targets to process after"
" filtering. Returning empty DF.",
level=1,
verbose=verbose,
logger=_logger,
)
return pd.DataFrame()
# Step 4: infer dims
num_samp, H = _infer_dimensions(
proc[next(iter(targets))],
forecast_horizon,
verbose,
_logger,
)
# Step 5: build base df
base_df = _build_sample_df(num_samp, H)
# Step 6: add coords
coords_df = _add_coordinate_columns(
base_df,
model_inputs,
coord_scaler,
include_coords,
stop_check,
_logger,
verbose,
)
# Step 7: add IDs
ids_df = _add_id_columns(
base_df,
ids_data_array,
ids_cols,
ids_cols_indices,
stop_check,
_logger,
verbose,
)
# Step 8: process targets
dfs = []
for key, base in targets.items():
dfs.append(
_process_target_variable(
proc[key],
key,
base,
num_samp,
H,
quantiles,
y_true_dict,
scaler_info,
verbose,
stop_check,
_logger,
)
)
# Step 9: concat all
final_df = pd.concat(
[base_df, coords_df, ids_df] + dfs, axis=1
)
# Step 10: masking
final_df = _apply_masking(
final_df,
targets,
apply_mask,
mask_values,
mask_fill_value,
quantiles,
)
# Step 11: coverage
if evaluate_coverage and quantiles:
_compute_coverage(
final_df,
targets,
quantiles,
y_true_dict,
coverage_quantile_indices,
savefile,
name,
verbose,
_logger,
)
vlog(
"Formatting complete",
level=3,
verbose=verbose,
logger=_logger,
)
return final_df
@isdf
def prepare_pinn_data_sequences(
df: pd.DataFrame,
time_col: str,
subsidence_col: str,
gwl_col: str,
dynamic_cols: list[str],
static_cols: list[str] | None = None,
future_cols: list[str] | None = None,
spatial_cols: tuple[str, str] | None = None,
h_field_col: str | None = None,
lon_col: str | None = None,
lat_col: str | None = None,
group_id_cols: list[str] | None = None,
time_steps: int = 12,
forecast_horizon: int = 3,
output_subsidence_dim: int = 1,
output_gwl_dim: int = 1,
datetime_format: str | None = None,
normalize_coords: bool = True,
cols_to_scale: list[str] | str | None = None,
lock_physics_cols: bool = True,
protect_si_suffix: str = "__si",
return_coord_scaler: bool = False,
coord_scaler: MinMaxScaler | None = None,
fit_coord_scaler: bool = True,
mode: str | None = None,
model: str | None = None,
savefile: str | None = None,
progress_hook: Callable[[float], None] | None = None,
stop_check: Callable[[], bool] = None,
verbose: int = 0,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
**kws,
) -> (
tuple[dict[str, np.ndarray], dict[str, np.ndarray]]
| tuple[
dict[str, np.ndarray],
dict[str, np.ndarray],
MinMaxScaler | None,
]
):
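"""
Build rolling-window input and target arrays for PINN models.
For each group, a window of `time_steps` past rows feeds the dynamic
features, and the next `forecast_horizon` rows provide the (t, x, y)
coordinates and the subsidence and GWL targets.
Returns `(inputs_dict, targets_dict)`, plus the coordinate scaler
when `return_coord_scaler=True`.
"""
# -------------------------------------------------------------------------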
_to_range = lambda f, lo, hi: lo + (hi - lo) * f # noqa: E731
vlog(
"Starting PINN data sequence preparation...",
verbose=verbose,
level=1,
logger=_logger,
)
df_proc = df.copy()
# --- 1. Validate Input Parameters and Columns ---
if verbose >= 2:
logger.debug(
"Validating input parameters and columns."
)
vlog(
"Validating essential columns...",
verbose=verbose,
level=2,
logger=_logger,
)
lon_col, lat_col = resolve_spatial_columns(
df_proc,
spatial_cols=spatial_cols,
lon_col=lon_col,
lat_col=lat_col,
)
essential_cols = [
time_col,
lon_col,
lat_col,
subsidence_col,
gwl_col,
]
# --- MODIFICATION 1: Conditional Validation ---
is_geoprior = str(model).lower().strip() in (
"geoprior",
"geopriorsubsnet",
)
if is_geoprior:
if h_field_col is None:
# Check for default names
if "H_field" in df_proc.columns:
h_field_col = "H_field"
elif "soil_thickness" in df_proc.columns:
h_field_col = "soil_thickness"
else:
raise ValueError(
"`model` is 'geoprior' but `h_field_col` was not "
"provided and default names 'H_field' or 'soil_thickness' "
"were not found in the DataFrame."
)
vlog(
f"GeoPrior model detected. Using '{h_field_col}' as H_field.",
verbose=verbose,
level=3,
logger=_logger,
)
essential_cols.append(h_field_col)
vlog(
f"Validated '{h_field_col}' for GeoPriorSubsNet.",
verbose=verbose,
level=3,
logger=_logger,
)
# --- End Modification 1 ---
exist_features(
df_proc,
features=essential_cols,
message="Essential column(s) missing.",
)
vlog(
"Validating time-series dataset...",
verbose=verbose,
level=2,
logger=_logger,
)
check_datetime(
df_proc,
dt_cols=time_col,
ops="check_only",
consider_dt_as="numeric",
accept_dt=True,
allow_int=True,
)
vlog(
"Managing feature column lists...",
verbose=verbose,
level=3,
logger=_logger,
)
dynamic_cols = columns_manager(
dynamic_cols, empty_as_none=False
)
exist_features(
df_proc,
features=dynamic_cols,
name="Dynamic feature column(s)",
)
static_cols = columns_manager(static_cols)
if static_cols:
exist_features(
df_proc,
features=static_cols,
name="Static feature column(s)",
)
future_cols = columns_manager(future_cols)
if future_cols:
exist_features(
df_proc,
features=future_cols,
name="Future feature column(s)",
)
group_id_cols = columns_manager(group_id_cols)
if group_id_cols:
exist_features(
df_proc,
features=group_id_cols,
name="Group ID column(s)",
)
vlog(
"Validating time_steps and forecast_horizon...",
verbose=verbose,
level=3,
logger=_logger,
)
time_steps = validate_positive_integer(
time_steps, "time_steps"
)
forecast_horizon = validate_positive_integer(
forecast_horizon,
"forecast_horizon",
)
vlog(
"Converting time column to numeric values...",
verbose=verbose,
level=4,
logger=_logger,
)
if pd.api.types.is_numeric_dtype(df_proc[time_col]):
numerical_time_col = time_col
if verbose >= 2:
logger.debug(
f"Time column '{time_col}' is already numeric. Using it directly."
)
else:
try:
df_proc[time_col] = pd.to_datetime(
df_proc[time_col], format=datetime_format
)
df_proc[f"{time_col}_numeric"] = df_proc[
time_col
].dt.year + (
df_proc[time_col].dt.dayofyear - 1
) / (
365
+ df_proc[time_col].dt.is_leap_year.astype(
int
)
)
numerical_time_col = f"{time_col}_numeric"
if verbose >= 2:
logger.debug(
f"Converted datetime column '{time_col}'"
f" to numerical '{numerical_time_col}'."
)
vlog(
f"Time column converted to '{numerical_time_col}'",
verbose=verbose,
level=5,
logger=_logger,
)
except Exception as e:
raise ValueError(
f"Failed to convert or process time column '{time_col}'. "
f"Ensure it's datetime-like or specify `datetime_format`."
f" Error: {e}"
) from e
vlog(
"Pre-flight: assessing sliding-window feasibility ...",
verbose=verbose,
level=1,
logger=_logger,
)
ok, _ = check_sequence_feasibility(
df_proc.copy(),
time_col=time_col,
group_id_cols=group_id_cols,
time_steps=time_steps,
forecast_horizon=forecast_horizon,
verbose=verbose,
error="raise", # fail fast on impossibility,
logger=_logger,
)
vlog(
"Feasibility check passed — generating sequences...",
verbose=verbose,
level=1,
logger=_logger,
)
vlog(
"Starting PINN data sequence generation...",
verbose=verbose,
level=1,
logger=_logger,
)
mode = select_mode(mode, default="pihal_like")
vlog(
f"Operating in '{mode}' data preparation mode.",
level=1,
verbose=verbose,
logger=_logger,
)
exclude_cols = []
if lock_physics_cols:
exclude_cols += [subsidence_col, gwl_col]
if h_field_col is not None:
exclude_cols.append(h_field_col)
# Also exclude any "__si" columns automatically (safe for Stage-1 outputs)
if protect_si_suffix:
exclude_cols += [
c
for c in df_proc.columns
if str(c).endswith(protect_si_suffix)
]
# IMPORTANT: do NOT shift time in normalization for PINN
df_proc, coord_scaler, cols_scaler = normalize_for_pinn(
df=df_proc,
time_col=numerical_time_col,
coord_x=lon_col,
coord_y=lat_col,
scale_coords=normalize_coords,
cols_to_scale=cols_to_scale,
forecast_horizon=forecast_horizon,
shift_time_by_horizon=False,
exclude_cols=exclude_cols,
protect_si_suffix=protect_si_suffix,
verbose=verbose,
_logger=_logger,
coord_scaler=coord_scaler,
fit_coord_scaler=fit_coord_scaler,
)
# --- 2. Group and Sort Data ---
vlog(
"Grouping and sorting data...",
verbose=verbose,
level=2,
logger=_logger,
)
if verbose >= 2:
logger.debug(
f"Grouping and sorting data. Group IDs:"
f" {group_id_cols or 'None (single group)'}."
)
sort_by_cols = [numerical_time_col]
if group_id_cols:
sort_by_cols = group_id_cols + sort_by_cols
df_proc = df_proc.sort_values(
by=sort_by_cols
).reset_index(drop=True)
if group_id_cols:
grouped_data = df_proc.groupby(group_id_cols)
group_keys = list(grouped_data.groups.keys())
if verbose >= 1:
logger.info(
f"Data grouped by {group_id_cols} into {len(group_keys)} groups."
)
else:
grouped_data = [(None, df_proc)]
group_keys = [None]
if verbose >= 1:
logger.info(
"Processing entire DataFrame as a single group."
)
# --- 3. First Pass: Calculate Total Number of Sequences ---
total_sequences = 0
min_len_per_group = time_steps + forecast_horizon
valid_group_dfs = []
vlog(
"Counting valid sequences in groups...",
verbose=verbose,
level=4,
logger=_logger,
)
if verbose >= 2:
logger.debug(
"First pass: Calculating total sequences. Min"
f" length per group: {min_len_per_group}."
)
n_groups = len(group_keys) or 1
for g_idx, group_key in enumerate(group_keys):
group_df = (
grouped_data.get_group(group_key)
if group_id_cols
else df_proc
)
if stop_check and stop_check():
raise InterruptedError(
"Sequence generation aborted."
)
key_str = (
group_key
if group_key is not None
else "<Full Dataset>"
)
if progress_hook is not None:
progress_hook(
_to_range((g_idx + 1) / n_groups, 0.0, 0.50)
)
if len(group_df) < min_len_per_group:
if verbose >= 5:
logger.info(
f"Group '{key_str}' has {len(group_df)} points, less than "
f"min required ({min_len_per_group}). Skipping."
)
vlog(
f"Group {key_str} too small:"
f" {len(group_df)} < {min_len_per_group}",
verbose=verbose,
level=6,
logger=_logger,
)
continue
num_seq_in_group = (
len(group_df) - min_len_per_group + 1
)
total_sequences += num_seq_in_group
valid_group_dfs.append(group_df)
if verbose >= 5:
logger.debug(
f"Group '{group_key if group_key else '<Full Dataset>'}' "
f"will yield {num_seq_in_group} sequences."
)
vlog(
f"Group {key_str} yields {num_seq_in_group} seqs.",
verbose=verbose,
level=6,
logger=_logger,
)
if total_sequences == 0:
raise ValueError(
"No group has enough data points to create sequences with "
f"time_steps={time_steps} and forecast_horizon={forecast_horizon}."
)
if verbose >= 1:
logger.info(
"Total valid sequences to be"
f" generated: {total_sequences}."
)
vlog(
f"Total sequences: {total_sequences}",
verbose=verbose,
level=1,
logger=_logger,
)
# --- 4. Pre-allocate NumPy Arrays ---
vlog(
"Pre-allocating arrays...",
verbose=verbose,
level=2,
logger=_logger,
)
num_dynamic_feats = len(dynamic_cols)
num_static_feats = len(static_cols) if static_cols else 0
num_future_feats = len(future_cols) if future_cols else 0
coords_horizon_arr = np.zeros(
(total_sequences, forecast_horizon, 3),
dtype=np.float32,
)
static_features_arr = np.zeros(
(total_sequences, num_static_feats), dtype=np.float32
)
dynamic_features_arr = np.zeros(
(total_sequences, time_steps, num_dynamic_feats),
dtype=np.float32,
)
if mode == "tft_like":
future_window_len = time_steps + forecast_horizon
vlog(
f"Allocating future features array for 'tft_like' mode "
f"with time dimension: {future_window_len}",
level=2,
verbose=verbose,
logger=_logger,
)
else:
future_window_len = forecast_horizon
future_features_arr = np.zeros(
(
total_sequences,
future_window_len,
num_future_feats,
),
dtype=np.float32,
)
target_subsidence_arr = np.zeros(
(
total_sequences,
forecast_horizon,
output_subsidence_dim,
),
dtype=np.float32,
)
target_gwl_arr = np.zeros(
(total_sequences, forecast_horizon, output_gwl_dim),
dtype=np.float32,
)
# --- MODIFICATION 2: Conditional Allocation ---
H_field_arr = None # Initialize as None
if is_geoprior:
H_field_arr = np.zeros(
(total_sequences, forecast_horizon, 1),
dtype=np.float32,
)
vlog(
f"Allocated 'H_field' array with shape {H_field_arr.shape}",
verbose=verbose,
level=4,
logger=_logger,
)
# --- End Modification 2 ---
vlog(
"Arrays shapes set.",
verbose=verbose,
level=5,
logger=_logger,
)
if verbose >= 2:
logger.debug(
"Pre-allocated NumPy arrays for sequence data:"
)
logger.debug(
f" Coords Horizon: {coords_horizon_arr.shape}"
)
logger.debug(
f" Static Features: {static_features_arr.shape}"
)
logger.debug(
f" Dynamic Features: {dynamic_features_arr.shape}"
)
logger.debug(
f" Future Features: {future_features_arr.shape}"
)
logger.debug(
f" Target Subsidence: {target_subsidence_arr.shape}"
)
logger.debug(f" Target GWL: {target_gwl_arr.shape}")
# --- MODIFICATION 2b: Conditional Logging ---
if H_field_arr is not None:
logger.debug(
f" H_field ({h_field_col}): {H_field_arr.shape}"
)
# --- End Modification 2b ---
# --- 5. Second Pass: Populate Arrays with Rolling Windows ---
current_seq_idx = 0
if verbose >= 2:
logger.debug(
"Second pass: Populating sequence arrays..."
)
vlog(
"Populating arrays with data...",
verbose=verbose,
level=2,
logger=_logger,
)
total_seq = (
sum(
len(gdf) - min_len_per_group + 1
for gdf in valid_group_dfs
)
or 1
)
done_seq = 0
for group_df in valid_group_dfs:
group_t_coords = group_df[numerical_time_col].values
group_x_coords = group_df[lon_col].values
group_y_coords = group_df[lat_col].values
t_min_group, t_max_group = (
group_t_coords.min(),
group_t_coords.max(),
)
if verbose >= 3:
print()
time_scale_info = (
f"{t_min_group:.4f}-{t_max_group:.4f}"
)
if normalize_coords and coord_scaler:
time_scale_info += " (normalized)"
else:
time_scale_info += " (original scale)"
print_box(
f"Group window t: {time_scale_info}",
width=_TW,
align="center",
border_char="+",
horizontal_char="-",
vertical_char="|",
padding=1,
)
group_static_vals = None
if static_cols and num_static_feats > 0:
group_static_vals = group_df.iloc[0][
static_cols
].values.astype(np.float32)
# --- MODIFICATION 3: Get static H_val for the group ---
group_H_val = None
if is_geoprior:
# Get the single static soil thickness value for this group
group_H_val = group_df.iloc[0][h_field_col]
# --- End Modification 3 ---
num_seq_in_this_group = (
len(group_df) - min_len_per_group + 1
)
for i in range(num_seq_in_this_group):
done_seq += 1
if progress_hook is not None:
progress_hook(
_to_range(done_seq / total_seq, 0.5, 1.00)
)
if stop_check and stop_check():
raise InterruptedError(
"Sequence generation aborted by user"
)
if static_cols and num_static_feats > 0:
static_features_arr[current_seq_idx] = (
group_static_vals
)
dynamic_start_idx = i
dynamic_end_idx = i + time_steps
dynamic_features_arr[current_seq_idx] = (
group_df.iloc[
dynamic_start_idx:dynamic_end_idx
][dynamic_cols].values.astype(np.float32)
)
horizon_start_idx = i + time_steps
horizon_end_idx = (
i + time_steps + forecast_horizon
)
if future_cols and num_future_feats > 0:
if mode == "tft_like":
future_start_idx = i
future_end_idx = (
i + time_steps + forecast_horizon
)
else: # 'pihal_like' mode
future_start_idx = horizon_start_idx
future_end_idx = horizon_end_idx
future_features_arr[current_seq_idx] = (
group_df.iloc[
future_start_idx:future_end_idx
][future_cols].values.astype(np.float32)
)
t_horizon = group_t_coords[
horizon_start_idx:horizon_end_idx
]
x_horizon = group_x_coords[
horizon_start_idx:horizon_end_idx
]
y_horizon = group_y_coords[
horizon_start_idx:horizon_end_idx
]
coords_horizon_arr[current_seq_idx, :, 0] = (
t_horizon
)
coords_horizon_arr[current_seq_idx, :, 1] = (
x_horizon
)
coords_horizon_arr[current_seq_idx, :, 2] = (
y_horizon
)
target_subsidence_arr[current_seq_idx] = (
group_df.iloc[
horizon_start_idx:horizon_end_idx
][subsidence_col]
.values.reshape(
forecast_horizon, output_subsidence_dim
)
.astype(np.float32)
)
target_gwl_arr[current_seq_idx] = (
group_df.iloc[
horizon_start_idx:horizon_end_idx
][gwl_col]
.values.reshape(
forecast_horizon, output_gwl_dim
)
.astype(np.float32)
)
# --- MODIFICATION 4: Populate H_field array ---
if H_field_arr is not None:
# Tile the static H_val across the forecast horizon
H_field_arr[current_seq_idx, :, 0] = (
group_H_val
)
# --- End Modification 4 ---
if verbose >= 7:
logger.debug(f" Sequence {current_seq_idx}:")
logger.debug(
" Dynamic window:"
f" {dynamic_start_idx}-{dynamic_end_idx - 1}"
)
logger.debug(
" Horizon window:"
f" {horizon_start_idx}-{horizon_end_idx - 1}"
)
logger.debug(
f" Coords (first step):"
f" {coords_horizon_arr[current_seq_idx, 0, :]}"
)
vlog(
f"Seq {current_seq_idx}:"
f" dyn {dynamic_start_idx}-{dynamic_end_idx - 1},"
f"hzn {horizon_start_idx}-{horizon_end_idx - 1}",
verbose=verbose,
level=7,
logger=_logger,
)
current_seq_idx += 1
if verbose >= 1:
logger.info("Successfully populated sequence arrays.")
vlog(
"Data population complete.",
verbose=verbose,
level=1,
logger=_logger,
)
inputs_dict = {
"coords": coords_horizon_arr,
"static_features": static_features_arr
if num_static_feats > 0
else None,
"dynamic_features": dynamic_features_arr,
"future_features": future_features_arr
if num_future_feats > 0
else None,
}
# --- MODIFICATION 5: Conditionally add H_field to inputs_dict ---
if H_field_arr is not None:
inputs_dict["H_field"] = H_field_arr
# --- End Modification 5 ---
inputs_dict = {
k: v for k, v in inputs_dict.items() if v is not None
}
targets_dict = {
"subsidence": target_subsidence_arr,
"gwl": target_gwl_arr,
}
if savefile:
vlog(
f"\nPreparing to save sequence data to '{savefile}'...",
verbose=verbose,
level=3,
logger=_logger,
)
# --- v3.2: allow split between GWL dynamic driver and GWL prediction target ---
gwl_dyn_col = kws.get("gwl_dyn_col", None) or gwl_col
if gwl_dyn_col not in dynamic_cols:
raise ValueError(
f"gwl_dyn_col={gwl_dyn_col!r} must be present in dynamic_cols.\n"
f"Got gwl_col(target)={gwl_col!r} and dynamic_cols={dynamic_cols}.\n"
"GeoPrior v3.2 expects: gwl_dyn_col=<depth__si>, gwl_col=<head__si>."
)
gwl_dyn_index = int(dynamic_cols.index(gwl_dyn_col))
job_dict = {
"static_data": static_features_arr,
"dynamic_data": dynamic_features_arr,
"future_data": future_features_arr,
"subsidence": target_subsidence_arr,
"gwl": target_gwl_arr,
"static_features": static_cols,
"dynamic_features": dynamic_cols,
"future_features": future_cols,
"inputs_dict": inputs_dict,
"targets_dict": targets_dict,
"subsidence_col": subsidence_col,
"spatial_features": spatial_cols,
"lon_col": lon_col,
"lat_col": lat_col,
"time_col": time_col,
"time_steps": time_steps,
"forecast_horizon": forecast_horizon,
"cols_scaler": cols_scaler,
"coord_scaler": coord_scaler,
"normalize_coords_flag": normalize_coords,
"saved_coord_scaler_flag": coord_scaler
is not None,
# --- Conditionally add to savefile ---
"model_type": model,
"h_field_col": h_field_col,
"H_field": H_field_arr
if H_field_arr is not None
else None,
"gwl_col": gwl_col, # target column (head)
"gwl_dyn_col": gwl_dyn_col, # dynamic driver column (depth)
"gwl_dyn_index": gwl_dyn_index, # index of the driver in dynamic_cols
}
try:
job_dict.update(get_versions())
except NameError:
vlog(
"\n `get_versions` not found, version info not saved.",
verbose=verbose,
level=1,
logger=_logger,
)
try:
save_job(
job_dict, savefile, append_versions=False
)
if verbose >= 1:
vlog(
f"Sequence data dictionary successfully "
f"saved to '{savefile}'.",
verbose=verbose,
level=1,
logger=_logger,
)
except Exception as e:
vlog(
f"Failed to save job dictionary to "
f"'{savefile}': {e}",
verbose=verbose,
level=1,
logger=_logger,
)
if verbose >= 1:
logger.info(
"PINN data sequence preparation completed."
)
if verbose >= 3:
for key, arr in inputs_dict.items():
logger.debug(
f" Final input '{key}' shape:"
f" {arr.shape if arr is not None else 'None'}"
)
for key, arr in targets_dict.items():
logger.debug(
f" Final target '{key}' shape: {arr.shape}"
)
vlog(
"PINN data sequence preparation successfully completed.",
verbose=verbose,
level=3,
logger=_logger,
)
if return_coord_scaler:
return inputs_dict, targets_dict, coord_scaler
return inputs_dict, targets_dict
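The second pass slides a window over each group: a group of `n` rows yields `n - (time_steps + forecast_horizon) + 1` sequences, the dynamic window covers rows `[i, i + time_steps)`, and the horizon window covers the next `forecast_horizon` rows. In `tft_like` mode the future-feature window spans both; otherwise it matches the horizon. The sketch below reproduces only the index arithmetic, using a hypothetical `window_bounds` helper.

```python
def window_bounds(
    n_rows: int,
    time_steps: int,
    forecast_horizon: int,
    mode: str = "pihal_like",
) -> list[dict[str, tuple[int, int]]]:
    """Return half-open (start, end) row ranges for every sequence in one group."""
    min_len = time_steps + forecast_horizon
    n_seq = max(n_rows - min_len + 1, 0)  # groups shorter than min_len yield nothing
    out = []
    for i in range(n_seq):
        horizon = (i + time_steps, i + min_len)
        out.append(
            {
                "dynamic": (i, i + time_steps),
                "horizon": horizon,
                # 'tft_like' sees history plus horizon; other modes see the horizon only.
                "future": (i, i + min_len) if mode == "tft_like" else horizon,
            }
        )
    return out
```

For a group of 6 rows with `time_steps=3` and `forecast_horizon=2`, this gives two sequences, the first with dynamic rows `[0, 3)` and horizon rows `[3, 5)`.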
def check_and_rename_keys(
inputs, y
): # TODO: rename to check_input_keys
r"""
Helper function to check and rename keys in the inputs
and target dictionaries.
This function ensures that the necessary keys are present in both the
`inputs` and `y` dictionaries. If the keys for 'subsidence' or 'gwl'
are not found, it attempts to rename them from possible alternatives
like 'subs_pred' or 'gwl_pred'.
Parameters
----------
inputs : dict
A dictionary containing the input data. The keys 'coords' and
'dynamic_features' are expected.
y : dict
A dictionary containing the target values. The keys 'subsidence'
and 'gwl' are expected, but they could also appear as 'subs_pred'
or 'gwl_pred'.
Raises
------
ValueError
If required keys are missing in `inputs` or `y`, or if renaming
does not result in valid keys for 'subsidence' and 'gwl'.
"""
# Check if 'coords' and 'dynamic_features' are in inputs
if "coords" not in inputs or inputs["coords"] is None:
raise ValueError("Input 'coords' is missing or None.")
if (
"dynamic_features" not in inputs
or inputs["dynamic_features"] is None
):
raise ValueError(
"Input 'dynamic_features' is missing or None."
)
# Accept 'subs_pred' as an alias and rename it to 'subsidence'.
if "subsidence" not in y and "subs_pred" in y:
y["subsidence"] = y.pop("subs_pred")
if "subsidence" not in y:
# Be explicit so the user knows which key to supply.
raise ValueError(
"Target 'subsidence' is missing or None."
)
# Accept 'gwl_pred' as an alias and rename it to 'gwl'.
if "gwl" not in y and "gwl_pred" in y:
y["gwl"] = y.pop("gwl_pred")
if "gwl" not in y:
raise ValueError("Target 'gwl' is missing or None.")
return inputs, y
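The docstring above describes accepting `subs_pred` and `gwl_pred` as aliases for the canonical target keys `subsidence` and `gwl`. A non-mutating sketch of that renaming is shown below. The `canonicalize_targets` function and its `aliases` argument are hypothetical and not part of GeoPrior.

```python
from typing import Any, Optional


def canonicalize_targets(
    y: dict[str, Any],
    aliases: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
    """Return a copy of y with alias keys renamed to canonical ones."""
    aliases = aliases or {"subs_pred": "subsidence", "gwl_pred": "gwl"}
    out = dict(y)  # shallow copy so the caller's dict is left untouched
    for alias, canonical in aliases.items():
        if canonical not in out and alias in out:
            out[canonical] = out.pop(alias)
    missing = [name for name in aliases.values() if name not in out]
    if missing:
        raise ValueError(f"Missing target key(s): {missing}")
    return out
```

Working on a copy avoids the side effect of `y.pop(...)` on a dictionary the caller may reuse later.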
def _check_required_input_keys(
inputs,
y=None,
message=None,
):
r"""
Helper function to check that the required keys are present in the
inputs and target dictionaries.
This function validates `inputs` and `y` without renaming anything.
For the targets, either the canonical key ('subsidence', 'gwl') or
its alias ('subs_pred', 'gwl_pred') is accepted.
Parameters
----------
inputs : dict
A dictionary containing the input data. The keys 'coords' and
'dynamic_features' are expected.
y : dict
A dictionary containing the target values. The keys 'subsidence'
and 'gwl' are expected, but they could also appear as 'subs_pred'
or 'gwl_pred'.
message : str, optional
Custom error message used when `inputs` or `y` is not a dictionary.
Raises
------
TypeError
If `inputs` or `y` is provided but is not a dictionary.
ValueError
If required keys are missing in `inputs` or `y`.
"""
if inputs is not None:
if not isinstance(inputs, dict):
message = message or (
"Inputs must be a dictionary containing"
" 'coords' and 'dynamic_features'."
f" Got {type(inputs).__name__!r}"
)
raise TypeError(message)
# Check if 'coords' and 'dynamic_features' are in inputs
if "coords" not in inputs or inputs["coords"] is None:
raise ValueError(
"Input 'coords' is missing or None."
)
if (
"dynamic_features" not in inputs
or inputs["dynamic_features"] is None
):
raise ValueError(
"Input 'dynamic_features' is missing or None."
)
if y is not None:
if not isinstance(y, dict):
message = message or (
"Target `y` must be a dictionary containing"
" 'subs_pred/subsidence' and 'gwl/gwl_pred'."
f" Got {type(y).__name__!r}"
)
raise TypeError(message)
# Check for 'subsidence' in y, allow renaming from 'subs_pred'
if "subsidence" not in y and "subs_pred" not in y:
raise ValueError(
"Target 'subsidence' is missing or None."
" Please provide 'subsidence' or 'subs_pred'."
)
# Check for 'gwl' in y, allow renaming from 'gwl_pred'
if "gwl" not in y and "gwl_pred" not in y:
raise ValueError(
"Target 'gwl' is missing or None."
" Please provide 'gwl' or 'gwl_pred'."
)
return inputs, y
def check_required_input_keys(
inputs: dict[str, Any] | None,
y: dict[str, Any] | None = None,
message: str | None = None,
model_name: str | None = None,
do_rename: bool = True,
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
"""
Validate presence of required keys in `inputs` and `y`.
Optionally canonicalize keys via reverse alias mapping.
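This function ensures that the necessary keys are present in both the
`inputs` and `y` dictionaries. When `do_rename` is True, target aliases
are canonicalized to 'subs_pred' and 'gwl_pred' (from 'subsidence' and
'gwl'), and, for GeoPrior models, input aliases of 'H_field' are
unified before validation.
Parameters
----------
inputs : dict or None
A dictionary containing the input data. The keys 'coords' and
'dynamic_features' are expected. GeoPrior models additionally
require 'H_field' (aliases: 'soil_thickness', 'soil thickness',
'h_field').
y : dict, optional
A dictionary containing the target values. The keys 'subs_pred'
and 'gwl_pred' are expected, but they may also appear as
'subsidence' or 'gwl'.
message : str, optional
Custom error message used when `inputs` or `y` is not a dictionary.
model_name : str, optional
Model name, compared case-insensitively. 'geoprior' and
'geopriorsubsnet' enable the 'H_field' alias handling and
requirement.
do_rename : bool, default True
If True, rename accepted aliases to their canonical keys.
Returns
-------
inputs, y : tuple of (dict or None)
The validated (and possibly renamed) dictionaries.
Raises
------
TypeError
If `inputs` or `y` is provided but is not a dictionary.
ValueError
If required keys are missing in `inputs` or `y`.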
"""
# ---- inputs checks -------------------------------------------------
if inputs is not None:
if not isinstance(inputs, dict):
msg = message or (
"Inputs must be a dict with 'coords' and "
"'dynamic_features'. Got "
f"{type(inputs).__name__!r}."
)
raise TypeError(msg)
# Canonicalize GeoPrior-specific inputs (H field aliases)
if do_rename and model_name:
name = str(model_name).lower()
if name in (
"geoprior",
"geopriorsubsnet",
):
inputs = rename_dict_keys(
inputs,
{
"H_field": (
"H_field",
"soil_thickness",
"soil thickness",
"h_field",
)
},
order="reverse",
)
# Mandatory base inputs
if (
"coords" not in inputs
or inputs.get("coords") is None
):
raise ValueError(
"Input 'coords' is missing or None."
)
if (
"dynamic_features" not in inputs
or inputs.get("dynamic_features") is None
):
raise ValueError(
"Input 'dynamic_features' is missing or None."
)
# GeoPrior requires H_field (after alias unification)
if model_name:
name = str(model_name).lower()
if name in (
"geoprior",
"geopriorsubsnet",
):
if (
"H_field" not in inputs
or inputs.get("H_field") is None
):
raise ValueError(
"GeoPrior requires 'H_field' in inputs "
"(aliases accepted: 'soil_thickness', "
"'soil thickness', 'h_field')."
)
# ---- target checks -------------------------------------------------
if y is not None:
if not isinstance(y, dict):
msg = message or (
"Target `y` must be a dict containing "
"'subs_pred/subsidence' and 'gwl_pred/gwl'. Got "
f"{type(y).__name__!r}."
)
raise TypeError(msg)
# Canonicalize targets to subs_pred / gwl_pred
if do_rename:
y = rename_dict_keys(
y,
{
"subs_pred": ("subs_pred", "subsidence"),
"gwl_pred": ("gwl_pred", "gwl"),
},
order="reverse",
)
# Accept either canonical or legacy if not renaming
has_subs = ("subs_pred" in y) or ("subsidence" in y)
has_gwl = ("gwl_pred" in y) or ("gwl" in y)
if not has_subs:
raise ValueError(
"Target missing subsidence. Provide 'subs_pred' or "
"'subsidence'."
)
if not has_gwl:
raise ValueError(
"Target missing gwl. Provide 'gwl_pred' or 'gwl'."
)
return inputs, y
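The alias handling above can be pictured with a small standalone sketch. It does not call `rename_dict_keys`, whose exact behavior is not shown on this page; `canonicalize_keys`, `require_targets`, and `TARGET_ALIASES` are hypothetical names introduced only for illustration, and the sketch assumes "reverse" mapping means "rename any accepted alias to its canonical key".

```python
from typing import Any

# Hypothetical alias table: canonical key -> accepted spellings.
TARGET_ALIASES = {
    "subs_pred": ("subs_pred", "subsidence"),
    "gwl_pred": ("gwl_pred", "gwl"),
}


def canonicalize_keys(
    data: dict[str, Any],
    aliases: dict[str, tuple[str, ...]],
) -> dict[str, Any]:
    """Return a copy of `data` with every alias renamed to its canonical key."""
    # Invert the table so each alias points at its canonical name.
    alias_to_canonical = {
        alias: canonical
        for canonical, names in aliases.items()
        for alias in names
    }
    out: dict[str, Any] = {}
    for key, value in data.items():
        canonical = alias_to_canonical.get(key, key)
        # Keep the first value seen if two aliases collide.
        out.setdefault(canonical, value)
    return out


def require_targets(y: dict[str, Any]) -> dict[str, Any]:
    """Canonicalize `y`, then fail loudly if a target key is absent."""
    y = canonicalize_keys(y, TARGET_ALIASES)
    for key, names in TARGET_ALIASES.items():
        if key not in y:
            raise ValueError(f"Target missing; provide one of {names}.")
    return y


legacy = {"subsidence": [0.1, 0.2], "gwl": [3.0, 2.9]}
print(require_targets(legacy))
```

Canonicalizing first and validating second keeps the presence check simple: after renaming, only one spelling per target has to be tested.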
def _extract_txy_in(
inputs: Tensor
| np.ndarray
| dict[str, Tensor | np.ndarray],
coord_slice_map: dict[str, int] | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
r"""
Extracts t, x, y tensors from `inputs`, which may be:
- A single 3D tensor of shape (batch, time_steps, 3)
- A dict containing a key 'coords' with such a tensor
- A dict containing separate keys 't', 'x', and 'y'
Parameters
----------
inputs : tf.Tensor or np.ndarray or dict
If tensor/array: expected shape is (batch, time_steps, 3).
If dict:
- If 'coords' in dict: dict['coords'] must be (batch, time_steps, 3).
- Otherwise, dict must have keys 't', 'x', 'y' each of shape
(batch, time_steps, 1) or (batch, time_steps).
coord_slice_map : dict, optional
Mapping from 't', 'x', 'y' to their index in the last dimension of
the coords tensor. Defaults to {'t': 0, 'x': 1, 'y': 2}.
Returns
-------
t : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to time coordinate.
x : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to x coordinate.
y : tf.Tensor
Tensor of shape (batch, time_steps, 1) corresponding to y coordinate.
Raises
------
ValueError
If `inputs` is not in one of the supported formats, or dimensions
are inconsistent.
"""
# Default slice map
if coord_slice_map is None:
coord_slice_map = {"t": 0, "x": 1, "y": 2}
# Helper to ensure output is a tf.Tensor with a final singleton dim
def _ensure_tensor_with_last_dim(
inp: Tensor | np.ndarray,
) -> Tensor:
if isinstance(inp, np.ndarray):
inp = tf_convert_to_tensor(inp)
if not isinstance(inp, Tensor):
raise ValueError(
f"Expected tf.Tensor or np.ndarray, got {type(inp)}"
)
# If shape is (batch, time_steps), add last dimension
if inp.ndim == 2:
inp = tf_expand_dims(inp, axis=-1)
# Now expect (batch, time_steps, 1)
if inp.ndim != 3 or inp.shape[-1] != 1:
raise ValueError(
"Coordinate array must have shape "
f"(batch, time_steps, 1); got {inp.shape}"
)
return inp
# Case 1: inputs is a dict
if isinstance(inputs, dict):
# If 'coords' key is present
if "coords" in inputs:
coords_tensor = inputs["coords"]
if isinstance(coords_tensor, np.ndarray):
coords_tensor = tf_convert_to_tensor(
coords_tensor
)
if not isinstance(coords_tensor, Tensor):
raise ValueError(
f"Expected tensor/array for 'coords';"
f" got {type(coords_tensor)}"
)
# Expect shape (batch, time_steps, 3)
if (
coords_tensor.ndim != 3
or coords_tensor.shape[-1] < 3
):
raise ValueError(
f"'coords' must have shape (batch, time_steps, ≥3);"
f" got {coords_tensor.shape}"
)
# Slice out t, x, y
t = coords_tensor[
...,
coord_slice_map["t"] : coord_slice_map["t"]
+ 1,
]
x = coords_tensor[
...,
coord_slice_map["x"] : coord_slice_map["x"]
+ 1,
]
y = coords_tensor[
...,
coord_slice_map["y"] : coord_slice_map["y"]
+ 1,
]
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
# If keys 't','x','y' exist separately
if all(k in inputs for k in ("t", "x", "y")):
t = _ensure_tensor_with_last_dim(inputs["t"])
x = _ensure_tensor_with_last_dim(inputs["x"])
y = _ensure_tensor_with_last_dim(inputs["y"])
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
raise ValueError(
"Dict `inputs` must contain either key 'coords' or keys 't', 'x', 'y'."
)
# Case 2: inputs is a single tensor/array
if isinstance(inputs, Tensor | np.ndarray):
coords_tensor = inputs
if isinstance(coords_tensor, np.ndarray):
coords_tensor = tf_convert_to_tensor(
coords_tensor
)
# Expect shape (batch, time_steps, 3)
if (
coords_tensor.ndim != 3
or coords_tensor.shape[-1] < 3
):
raise ValueError(
f"Tensor `inputs` must have shape (batch, time_steps, ≥3);"
f" got {coords_tensor.shape}"
)
t = coords_tensor[
...,
coord_slice_map["t"] : coord_slice_map["t"] + 1,
]
x = coords_tensor[
...,
coord_slice_map["x"] : coord_slice_map["x"] + 1,
]
y = coords_tensor[
...,
coord_slice_map["y"] : coord_slice_map["y"] + 1,
]
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
raise ValueError(
f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
)
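The slicing idiom used above (`k : k + 1` on the last axis) is what keeps the trailing singleton dimension. A minimal sketch with NumPy standing in for TensorFlow tensors, under the assumption that basic slicing behaves the same way for this case; `split_txy` is a hypothetical helper, not part of the package:

```python
import numpy as np

# Toy coords array of shape (batch=2, time_steps=3, features=3),
# laid out as (t, x, y) along the last dimension.
coords = np.arange(18, dtype=np.float32).reshape(2, 3, 3)

coord_slice_map = {"t": 0, "x": 1, "y": 2}


def split_txy(coords, slice_map):
    """Slice t, x, y out of the last axis, keeping a singleton last dim."""
    if coords.ndim != 3 or coords.shape[-1] < 3:
        raise ValueError(
            f"expected (batch, time_steps, >=3); got {coords.shape}"
        )
    # A range slice k:k+1 preserves the axis; a plain index k would drop it.
    return tuple(
        coords[..., slice_map[name] : slice_map[name] + 1]
        for name in ("t", "x", "y")
    )


t, x, y = split_txy(coords, coord_slice_map)
print(t.shape, x.shape, y.shape)  # (2, 3, 1) (2, 3, 1) (2, 3, 1)
print(coords[..., 0].shape)       # (2, 3): integer indexing drops the axis
```

Passing a different `coord_slice_map`, for example `{"t": 2, "x": 0, "y": 1}`, reorders which column is read as which coordinate without touching the data.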
def _get_coords(
inputs: dict[str, Any] | Sequence[Any] | Tensor,
*,
check_shape: bool = False,
is_time_dependent: bool = True,
) -> Tensor:
r"""
Extract the **coords** tensor from any input layout.
Parameters
----------
inputs : dict, sequence or tensor
* **Dict** – must contain a ``"coords"`` key.
* **Sequence** – coords are expected at index 0.
* **Tensor** – interpreted as coords directly.
check_shape : bool, default ``False``
If ``True``, validate that the last dimension equals 3 and that
the rank matches *is_time_dependent*.
is_time_dependent : bool, default ``True``
* ``True`` → expect a 3-D shape ``(B, T, 3)``
* ``False`` → allow a 2-D shape ``(B, 3)``; a 3-D shape is also
accepted (useful when the model ignores time).
Returns
-------
tf.Tensor
The extracted coords tensor.
Raises
------
KeyError
If a dict is missing the ``"coords"`` key.
TypeError
If *inputs* is not dict, sequence or tensor.
ValueError
If ``check_shape`` is ``True`` and the tensor shape is invalid.
Notes
-----
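The function is lightweight. When the static rank is known, the
checks run in plain Python; when it is unknown (for example inside a
``tf.function``), the function falls back to ``tf.debugging``
assertions, so it can be used inside the model's ``train_step`` or
data-pipeline helpers.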
"""
# ── 1. locate the tensor
if isinstance(inputs, Tensor):
coords = inputs
elif isinstance(inputs, tuple | list):
coords = inputs[0]
elif isinstance(inputs, dict):
try:
coords = inputs["coords"]
except KeyError as err:
raise KeyError(
"Input dict lacks a 'coords' entry."
) from err
else:
raise TypeError(
"inputs must be a dict, sequence, or Tensor; "
f"got {type(inputs).__name__}"
)
# ── 2. validate shape if requested
if check_shape:
shape = coords.shape # static if known, else dynamic
# last dim must be 3
if shape.rank is not None:
if shape[-1] != 3:
raise ValueError(
"coords[..., -1] must equal 3 (t, x, y); "
f"found {shape[-1]}"
)
# rank check when static
if is_time_dependent and shape.rank != 3:
raise ValueError(
"Expected coords rank 3 (B, T, 3) for time-dependent "
"data; received rank "
f"{shape.rank}"
)
if not is_time_dependent and shape.rank not in (
2,
3,
):
raise ValueError(
"Expected coords rank 2 or 3 for static data; "
f"received rank {shape.rank}"
)
else: # dynamic shapes (inside tf.function)
tf_debugging.assert_equal(
tf_shape(coords)[-1],
3,
message="coords[..., -1] must equal 3 (t, x, y)",
)
if is_time_dependent:
tf_debugging.assert_rank(coords, 3)
else:
tf_debugging.assert_rank_in(coords, (2, 3))
return coords
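The lookup order of `_get_coords` (bare tensor, then sequence, then dict) can be sketched without TensorFlow. The version below is an illustration only: it uses NumPy arrays in place of tensors, covers just the static-shape branch, and `get_coords` is a hypothetical stand-in name.

```python
import numpy as np


def get_coords(inputs, check_shape=False, is_time_dependent=True):
    """Locate the coords array in a dict, a sequence, or a bare array."""
    if isinstance(inputs, np.ndarray):
        coords = inputs
    elif isinstance(inputs, (tuple, list)):
        coords = inputs[0]  # convention: coords sit at index 0
    elif isinstance(inputs, dict):
        if "coords" not in inputs:
            raise KeyError("Input dict lacks a 'coords' entry.")
        coords = inputs["coords"]
    else:
        raise TypeError(
            "inputs must be a dict, sequence, or array; "
            f"got {type(inputs).__name__}"
        )

    if check_shape:
        if coords.shape[-1] != 3:
            raise ValueError(
                f"last dimension must be 3 (t, x, y); found {coords.shape[-1]}"
            )
        # Time-dependent data must be (B, T, 3); static data may be (B, 3).
        allowed_ranks = (3,) if is_time_dependent else (2, 3)
        if coords.ndim not in allowed_ranks:
            raise ValueError(
                f"expected rank in {allowed_ranks}; received rank {coords.ndim}"
            )
    return coords


batch = np.zeros((4, 6, 3))
same = [
    get_coords(batch, check_shape=True),
    get_coords((batch, "other inputs"), check_shape=True),
    get_coords({"coords": batch}, check_shape=True),
]
print(all(c is batch for c in same))  # True
```

All three layouts resolve to the same object, which is why downstream code only ever needs to handle one coords tensor.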
@ensure_pkg(
KERAS_BACKEND or "tensorflow",
extra="TensorFlow is required for this function.",
)
def extract_txy_in(
inputs: Tensor
| np.ndarray
| dict[str, Tensor | np.ndarray],
coord_slice_map: dict[str, int] | None = None,
expect_dim: str | None = None,
verbose: int = 0,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
**kws,
) -> tuple[Tensor, Tensor, Tensor]:
r"""
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single
tensor or a dictionary, and handling both 2D (spatial/static)
and 3D (spatio-temporal) data. It ensures a consistent 3D
output format for robust downstream processing.
Parameters
----------
inputs : tf.Tensor, np.ndarray, or dict
The input data containing coordinates. A single tensor or array
may be 2D with shape ``(batch, 3)`` or 3D with shape
``(batch, time_steps, 3)``. A dictionary may contain a
``'coords'`` key with the coordinate tensor, or separate
``'t'``, ``'x'``, and ``'y'`` keys.
coord_slice_map : dict, optional
Mapping for 't', 'x', 'y' to their index in the last
dimension of a coordinate tensor.
Defaults to `{'t': 0, 'x': 1, 'y': 2}`.
expect_dim : {'2d', '3d'}, optional
If provided, enforces that the input resolves to the
specified dimension. ``'2d'`` requires input shaped like
``(batch, 3)`` or a dictionary of ``(batch, 1)`` tensors.
``'3d'`` requires input shaped like ``(batch, time, 3)`` or a
dictionary of ``(batch, time, 1)`` tensors. If ``None``,
both are accepted and 2D inputs are expanded to 3D.
verbose : int, default 0
Controls the verbosity of logging messages. `0` is silent,
`1` provides basic info, and higher values provide more detail.
Returns
-------
t, x, y : Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
The extracted t, x, and y coordinate tensors, each reshaped
to be 3D with a singleton last dimension, e.g.,
`(batch, time_steps, 1)`.
Raises
------
TypeError
If `inputs` is neither a tensor/array nor a dictionary.
ValueError
If a dictionary lacks the required keys, dimensions are
inconsistent, or the `expect_dim` constraint is violated.
"""
vlog(
"Extracting (t, x, y) coordinates from inputs...",
level=2,
verbose=verbose,
logger=_logger,
)
if expect_dim and expect_dim not in ["2d", "3d"]:
raise ValueError(
"`expect_dim` must be None, '2d', or '3d'."
)
if coord_slice_map is None:
coord_slice_map = {"t": 0, "x": 1, "y": 2}
def _ensure_3d_and_validate(tensor, name):
"""Helper to convert to tensor, ensure 3D, and validate."""
if not isinstance(tensor, Tensor):
tensor = tf_convert_to_tensor(
tensor, dtype=tf_float32
)
if expect_dim == "2d" and tensor.ndim != 2:
raise ValueError(
f"Input '{name}' must be 2D (expect_dim='2d'), "
f"but got rank {tensor.ndim}."
)
elif expect_dim == "3d" and tensor.ndim != 3:
raise ValueError(
f"Input '{name}' must be 3D (expect_dim='3d'), "
f"but got rank {tensor.ndim}."
)
# For consistency, always expand 2D tensors to 3D.
if tensor.ndim == 2:
vlog(
f"Expanding 2D input '{name}' to 3D for consistency.",
level=3,
verbose=verbose,
logger=_logger,
)
return tf_expand_dims(tensor, axis=1)
elif tensor.ndim == 3:
return tensor
else:
raise ValueError(
f"Input '{name}' must be a 2D or 3D tensor, but got "
f"rank {tensor.ndim} with shape {tensor.shape}."
)
# --- Main Logic ---
if isinstance(inputs, dict):
if "coords" in inputs:
coords_tensor = _ensure_3d_and_validate(
inputs["coords"], "coords"
)
elif all(k in inputs for k in ("t", "x", "y")):
t = _ensure_3d_and_validate(inputs["t"], "t")
x = _ensure_3d_and_validate(inputs["x"], "x")
y = _ensure_3d_and_validate(inputs["y"], "y")
# The shapes are now guaranteed to be 3D.
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
else:
raise ValueError(
"Dict `inputs` must contain either 'coords' key "
"or all of 't', 'x', 'y' keys."
)
elif isinstance(inputs, Tensor | np.ndarray):
coords_tensor = _ensure_3d_and_validate(
inputs, "inputs"
)
else:
raise TypeError(
f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
)
# Slice the now-guaranteed 3D coords_tensor
tf_debugging.assert_greater_equal(
tf_shape(coords_tensor)[-1],
3,
message=(
"Coordinate tensor must carry at least three features "
"(t, x, y) in the last dimension,"
f" but got shape {coords_tensor.shape}"
),
)
# if tf_shape(coords_tensor)[-1] < 3:
# raise ValueError(
# "Coordinate tensor must have at least 3 features (t,x,y) "
# f"in the last dimension, but got shape {coords_tensor.shape}"
# )
t = coords_tensor[
..., coord_slice_map["t"] : coord_slice_map["t"] + 1
]
x = coords_tensor[
..., coord_slice_map["x"] : coord_slice_map["x"] + 1
]
y = coords_tensor[
..., coord_slice_map["y"] : coord_slice_map["y"] + 1
]
vlog(
f"Successfully extracted t:{t.shape}, x:{x.shape}, "
f"y:{y.shape}",
level=2,
verbose=verbose,
logger=_logger,
)
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
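The key behavior of `extract_txy_in` is that 2D inputs always gain a length-1 time axis before slicing. A minimal sketch of that expansion, using NumPy as a stand-in for the TensorFlow calls and a hypothetical helper name `to_3d`:

```python
import numpy as np


def to_3d(coords):
    """Insert a length-1 time axis into (batch, 3) coords; pass 3D through."""
    coords = np.asarray(coords, dtype=np.float32)
    if coords.ndim == 2:
        return np.expand_dims(coords, axis=1)  # (batch, 3) -> (batch, 1, 3)
    if coords.ndim == 3:
        return coords
    raise ValueError(f"expected rank 2 or 3; got rank {coords.ndim}")


# Static (time-independent) coordinates: shape (batch=2, 3).
static = np.array([[0.0, 10.0, 20.0], [0.0, 11.0, 21.0]])
coords3d = to_3d(static)
t = coords3d[..., 0:1]
print(coords3d.shape, t.shape)  # (2, 1, 3) (2, 1, 1)
```

Because the time axis is inserted at position 1, the feature axis stays last, so the same `coord_slice_map` slicing works for both static and spatio-temporal inputs.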
@ensure_pkg(
KERAS_BACKEND or "tensorflow",
extra="TensorFlow is required for this function.",
)
def extract_txy(
inputs: Tensor
| np.ndarray
| dict[str, Tensor | np.ndarray],
coord_slice_map: dict[str, int] | None = None,
expect_dim: str | None = None,
verbose: int = 0,
_logger: logging.Logger
| Callable[[str], None]
| None = None,
**kws,
) -> tuple[Tensor, Tensor, Tensor]:
r"""
Extracts t, x, y tensors from various input formats.
This utility standardizes coordinate inputs, accepting a single
tensor or a dictionary, and handling both 2D (spatial/static)
and 3D (spatio-temporal) data with flexible dimension validation.
Parameters
----------
inputs : tf.Tensor, np.ndarray, or dict
The input data containing coordinates. Can be a single tensor
or a dictionary with 'coords' or 't', 'x', 'y' keys.
coord_slice_map : dict, optional
Mapping for 't', 'x', 'y' to their index in the last
dimension of a coordinate tensor.
Defaults to `{'t': 0, 'x': 1, 'y': 2}`.
expect_dim : {'2d', '3d', '3d_only'}, optional
Enforces a constraint on the input's dimension. ``'2d'``
requires input shaped like ``(batch, 3)``. ``'3d'`` accepts
3D input and expands 2D input to 3D with a time dimension of 1.
``'3d_only'`` requires 3D input and raises an error for 2D
input. ``None`` accepts both 2D and 3D inputs without changing
their rank.
verbose : int, default 0
Controls logging verbosity.
Returns
-------
t, x, y : Tuple[tf.Tensor, tf.Tensor, tf.Tensor]
The extracted t, x, and y coordinate tensors. Their rank (2D
or 3D) depends on the input and the `expect_dim` mode.
Raises
------
TypeError
If `inputs` is neither a tensor/array nor a dictionary.
ValueError
If a dictionary lacks the required keys, dimensions are
inconsistent, or the `expect_dim` constraint is violated.
"""
vlog(
"Extracting (t, x, y) coordinates from inputs...",
level=2,
verbose=verbose,
logger=_logger,
)
if expect_dim and expect_dim not in [
"2d",
"3d",
"3d_only",
]:
raise ValueError(
"`expect_dim` must be None, '2d', '3d', or '3d_only'."
)
if coord_slice_map is None:
coord_slice_map = {"t": 0, "x": 1, "y": 2}
def _process_tensor(tensor, name):
"""Helper to convert to tensor and validate dimension."""
if not isinstance(tensor, Tensor):
tensor = tf_convert_to_tensor(
tensor, dtype=tf_float32
)
input_ndim = tensor.ndim
# Validate against expect_dim
if expect_dim == "2d" and input_ndim != 2:
raise ValueError(
f"Input '{name}' must be 2D for expect_dim='2d', "
f"but got rank {input_ndim}."
)
elif expect_dim == "3d_only" and input_ndim != 3:
raise ValueError(
f"Input '{name}' must be 3D for expect_dim='3d_only', "
f"but got rank {input_ndim}."
)
# Handle expansion for '3d' mode
if expect_dim == "3d" and input_ndim == 2:
vlog(
f"Expanding 2D input '{name}' to 3D for expect_dim='3d'.",
level=3,
verbose=verbose,
logger=_logger,
)
return tf_expand_dims(tensor, axis=1)
# For all other cases, including `expect_dim=None`, return as is
# after basic rank validation.
if input_ndim not in [2, 3]:
raise ValueError(
f"Input '{name}' must be a 2D or 3D tensor, but got "
f"rank {input_ndim} with shape {tensor.shape}."
)
return tensor
# --- Main Logic ---
if isinstance(inputs, dict):
if "coords" in inputs:
coords_tensor = _process_tensor(
inputs["coords"], "coords"
)
elif all(k in inputs for k in ("t", "x", "y")):
# When t, x, y are given separately, each one is validated on its
# own and returned as is, without concatenation.
t_p = _process_tensor(inputs["t"], "t")
x_p = _process_tensor(inputs["x"], "x")
y_p = _process_tensor(inputs["y"], "y")
return (
tf_cast(t_p, tf_float32),
tf_cast(x_p, tf_float32),
tf_cast(y_p, tf_float32),
)
else:
raise ValueError(
"Dict `inputs` must contain either 'coords' key "
"or all of 't', 'x', 'y' keys."
)
elif isinstance(inputs, Tensor | np.ndarray):
coords_tensor = _process_tensor(inputs, "inputs")
else:
raise TypeError(
f"`inputs` must be a tensor/array or dict; got {type(inputs)}"
)
# Slice the processed coords_tensor
if tf_shape(coords_tensor)[-1] < 3:
raise ValueError(
"Coordinate tensor must have at least 3 features (t,x,y) "
f"in the last dimension, but got shape {coords_tensor.shape}"
)
# Slicing keeps the original number of dimensions
t = coords_tensor[
..., coord_slice_map["t"] : coord_slice_map["t"] + 1
]
x = coords_tensor[
..., coord_slice_map["x"] : coord_slice_map["x"] + 1
]
y = coords_tensor[
..., coord_slice_map["y"] : coord_slice_map["y"] + 1
]
vlog(
f"Successfully extracted t:{t.shape}, x:{x.shape}, "
f"y:{y.shape}",
level=2,
verbose=verbose,
logger=_logger,
)
return (
tf_cast(t, tf_float32),
tf_cast(x, tf_float32),
tf_cast(y, tf_float32),
)
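The four `expect_dim` modes of `extract_txy` differ only in how they treat the input rank. The sketch below restates that rank logic as a pure function so the modes can be compared side by side; `output_rank` is a hypothetical name, and the function deliberately ignores tensor conversion and slicing.

```python
def output_rank(expect_dim, input_rank):
    """Return the rank of the extracted tensors, or raise ValueError."""
    if expect_dim not in (None, "2d", "3d", "3d_only"):
        raise ValueError("expect_dim must be None, '2d', '3d', or '3d_only'.")
    if expect_dim == "2d" and input_rank != 2:
        raise ValueError("'2d' requires rank-2 input.")
    if expect_dim == "3d_only" and input_rank != 3:
        raise ValueError("'3d_only' requires rank-3 input.")
    if expect_dim == "3d" and input_rank == 2:
        return 3  # 2D input is expanded with a time axis of length 1
    if input_rank not in (2, 3):
        raise ValueError("input must be rank 2 or 3.")
    return input_rank  # None (and matching modes) keep the input rank


for mode in (None, "2d", "3d", "3d_only"):
    for rank in (2, 3):
        try:
            result = output_rank(mode, rank)
        except ValueError:
            result = "error"
        print(f"expect_dim={mode!r:10} input rank {rank} -> {result}")
```

The practical difference from `extract_txy_in` is the `None` case: here rank-2 input stays rank 2, whereas `extract_txy_in` always returns rank-3 tensors.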
@ensure_pkg(
KERAS_BACKEND or "tensorflow",
extra="TensorFlow is required for this function.",
)
def plot_hydraulic_head(
model: Model,
t_slice: float,
x_bounds: tuple[float, float],
y_bounds: tuple[float, float],
resolution: int = 100,
ax: plt.Axes | None = None,
title: str | None = None,
cmap: str = "viridis",
colorbar_label: str = "Hydraulic Head (h)",
save_path: str | None = None,
show_plot: bool = True,
**contourf_kwargs: Any,
) -> tuple[plt.Axes, ScalarMappable]:
r"""Generate and plot a 2D contour map of a hydraulic head solution.
This utility visualizes the output of a Physics-Informed Neural
Network (PINN) that solves for the hydraulic head
:math:`h(t, x, y)`. It automates the process of creating a
spatial grid, running model predictions, and generating a
publication-quality contour plot for a specific slice in time.
Parameters
----------
model : tf.keras.Model
The trained PINN model. It is expected to have a ``.predict()``
method that accepts a dictionary of tensors with keys
``{'t', 'x', 'y'}``.
t_slice : float
The specific point in time :math:`t` for which to plot the
2D spatial solution.
x_bounds : tuple of float
A tuple ``(x_min, x_max)`` defining the spatial domain for
the x-axis.
y_bounds : tuple of float
A tuple ``(y_min, y_max)`` defining the spatial domain for
the y-axis.
resolution : int, optional
The number of points to sample along each spatial axis,
creating a grid of ``resolution x resolution`` points for
prediction. Higher values result in a smoother plot.
Default is 100.
ax : matplotlib.axes.Axes, optional
A pre-existing Matplotlib Axes object to plot on. If ``None``,
a new figure and axes are created internally. This is useful
for embedding this plot within a larger figure arrangement.
Default is ``None``.
title : str, optional
A custom title for the plot. If ``None``, a default title
is generated using the value of `t_slice`. Default is ``None``.
cmap : str, optional
The name of the Matplotlib colormap to use for the contour
plot. Default is ``'viridis'``.
colorbar_label : str, optional
The text label for the color bar. Default is
``'Hydraulic Head (h)'``.
save_path : str, optional
If provided, the path (including filename and extension)
where the generated plot will be saved. This is only active
when the function creates its own figure (i.e., when `ax`
is ``None``). Default is ``None``.
show_plot : bool, optional
If ``True``, calls ``plt.show()`` to display the plot. This
is only active when the function creates its own figure.
Default is ``True``.
**contourf_kwargs : any
Additional keyword arguments that are passed directly to
``matplotlib.axes.Axes.contourf``. This allows for advanced
customization (e.g., ``levels=20``, ``extend='both'``). If
``levels`` is not given, it defaults to 100.
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object on which the contour plot was drawn.
contour : matplotlib.cm.ScalarMappable
The contour plot object, which can be used for further
customizations, such as modifying the color bar.
See Also
--------
geoprior.models.pinn.PiTGWFlow : The PINN model this function is
designed to visualize.
Notes
-----
The core mechanism of this function involves creating a 2D
meshgrid of :math:`(x, y)` coordinates. These grid points are then
"flattened" into a long list of points, as the PINN model expects
a batch of individual coordinates for prediction, not a grid.
The prediction process is as follows:
1. A grid of shape ``(resolution, resolution)`` is created for
:math:`x` and :math:`y`.
2. These grids are reshaped into column vectors of shape
``(resolution*resolution, 1)``.
3. A time vector of the same shape, filled with `t_slice`, is
created.
4. The model's ``.predict()`` method is called on these flat
tensors.
5. The resulting flat prediction vector is reshaped back to the
original ``(resolution, resolution)`` grid shape for plotting.
If a custom `ax` is provided, the user is responsible for calling
``plt.show()`` or saving the parent figure.
Examples
--------
>>> import numpy as np
>>> import tensorflow as tf
>>> import matplotlib.pyplot as plt
>>> # This is a mock model for demonstration purposes.
>>> # In practice, you would use a trained PiTGWFlow model.
>>> class MockPINN(tf.keras.Model):
... def call(self, inputs):
... # A simple analytical function for demonstration
... t, x, y = inputs['t'], inputs['x'], inputs['y']
... return tf.sin(np.pi * x) * tf.cos(np.pi * y) * tf.exp(-t)
...
>>> mock_model = MockPINN()
**1. Simple Plotting Example**
This example creates a single plot and saves it to a file.
>>> ax, contour = plot_hydraulic_head(
... model=mock_model,
... t_slice=0.5,
... x_bounds=(-1, 1),
... y_bounds=(-1, 1),
... resolution=50,
... save_path="hydraulic_head_t0.5.png",
... show_plot=False # Do not display interactively
... )
Plot saved to hydraulic_head_t0.5.png
**2. Advanced Example with Subplots**
This example shows how to use the `ax` parameter to draw the
solution at two different times side-by-side in one figure.
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
>>> fig.suptitle('Hydraulic Head at Different Times', fontsize=16)
...
>>> # Plot solution at t = 0.1
>>> plot_hydraulic_head(
... model=mock_model, t_slice=0.1, x_bounds=(-1, 1),
... y_bounds=(-1, 1), ax=ax1, show_plot=False
... )
...
>>> # Plot solution at t = 1.0
>>> plot_hydraulic_head(
... model=mock_model, t_slice=1.0, x_bounds=(-1, 1),
... y_bounds=(-1, 1), ax=ax2, show_plot=False
... )
...
>>> plt.tight_layout(rect=[0, 0, 1, 0.96])
>>> plt.show()
"""
# ... function implementation follows ...
# --- 1. Handle Matplotlib Figure and Axes ---
if ax is None:
# Create a new figure only if no axes are provided
fig, current_ax = plt.subplots(figsize=(9, 7))
# Flag to indicate we have control over the figure object
fig_created = True
else:
current_ax = ax
fig = current_ax.figure
fig_created = False
# --- 2. Create Prediction Grid ---
x_range = np.linspace(
x_bounds[0], x_bounds[1], resolution
)
y_range = np.linspace(
y_bounds[0], y_bounds[1], resolution
)
X, Y = np.meshgrid(x_range, y_range)
# Flatten the grid to a list of points for model prediction
x_flat = tf_convert_to_tensor(X.ravel(), dtype=tf_float32)
y_flat = tf_convert_to_tensor(Y.ravel(), dtype=tf_float32)
t_flat = tf_fill(x_flat.shape, t_slice)
# Reshape to column vectors (N, 1) as expected by the model
grid_coords = {
"t": tf_reshape(t_flat, (-1, 1)),
"x": tf_reshape(x_flat, (-1, 1)),
"y": tf_reshape(y_flat, (-1, 1)),
}
# --- 3. Run Model Prediction ---
h_pred_flat = model.predict(grid_coords)
# Reshape the flat predictions back to the grid shape for plotting
h_pred_grid = tf_reshape(h_pred_flat, X.shape)
# --- 4. Plotting ---
# Set default contour levels if not provided
if "levels" not in contourf_kwargs:
contourf_kwargs["levels"] = 100
contour = current_ax.contourf(
X, Y, h_pred_grid, cmap=cmap, **contourf_kwargs
)
# Add color bar
fig.colorbar(contour, ax=current_ax, label=colorbar_label)
# Set plot labels and title
if title is None:
title = f"Learned Hydraulic Head Solution at t = {t_slice}"
current_ax.set_title(title, fontsize=14)
current_ax.set_xlabel("x-coordinate")
current_ax.set_ylabel("y-coordinate")
current_ax.set_aspect("equal")
# --- 5. Save and Show ---
if fig_created:
if save_path:
fig.savefig(
save_path, dpi=300, bbox_inches="tight"
)
print(f"Plot saved to {save_path}")
if show_plot:
plt.show()
else:
# If not showing, close the figure to free up memory
plt.close(fig)
return current_ax, contour
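The flatten, predict, and reshape steps described in the Notes of `plot_hydraulic_head` can be checked in isolation. The sketch below is NumPy-only and replaces the trained model with a hypothetical analytic function `head` (the same expression used by the mock model in the docstring example); no plotting is performed.

```python
import numpy as np


def head(t, x, y):
    """Stand-in analytic field used in place of a trained model."""
    return np.sin(np.pi * x) * np.cos(np.pi * y) * np.exp(-t)


resolution, t_slice = 5, 0.5
x_range = np.linspace(-1.0, 1.0, resolution)
y_range = np.linspace(-1.0, 1.0, resolution)
X, Y = np.meshgrid(x_range, y_range)  # both (resolution, resolution)

# Flatten the grid into column vectors of shape (N, 1).
x_col = X.reshape(-1, 1)
y_col = Y.reshape(-1, 1)
t_col = np.full_like(x_col, t_slice)  # constant time for this slice

h_flat = head(t_col, x_col, y_col)    # (N, 1), one value per grid point
h_grid = h_flat.reshape(X.shape)      # back to (resolution, resolution)

print(x_col.shape, h_grid.shape)      # (25, 1) (5, 5)
```

Reshaping with the same (default, row-major) order in both directions is what guarantees that `h_grid[i, j]` corresponds to the point `(X[i, j], Y[i, j])`.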
API notes#
A few practical notes help orient readers:
geoprior.utils is the best starting point for staged workflow code, diagnostics, export, and reproducibility scripts.
geoprior.models.utils is the better starting point for model-input preparation, sequence construction, and PINN-side formatting helpers.
geoprior.models.subsidence.utils is the best place to inspect scaling, SI conversion, depth/head conventions, and reference-state extraction logic.
The three layers are meant to complement one another rather than compete for the same role.