geoprior.models.utils._utils
Utility functions for neural network models.
This module provides utility functions to preprocess data for Temporal Fusion Transformer (TFT) models, including splitting sequences into static and dynamic inputs and creating input sequences with corresponding targets for time series forecasting.
Functions
|
Compute anomaly scores for given true targets using various methods. |
|
Compute the forecast horizon for time series forecasting models. |
|
Create input sequences and corresponding targets for time series forecasting. |
|
Export loss(es) (and any other metric) from a Keras History object. |
|
Extracts a specified number of batches from a tf.data.Dataset. |
|
Extract Keras callbacks from a dictionary of fit parameters. |
|
Generate a multi-step forecast using the XTFT model. |
|
Generate a single-step forecast using the XTFT model. |
|
Formats model predictions into a structured pandas DataFrame. |
|
Deprecated alias for format_predictions. |
|
Generate forecast using the XTFT model. |
|
|
|
Safely retrieves the first available tensor from a dictionary using a list of possible keys. |
|
Create a tf.data.Dataset.map function that converts a feature dictionary into a positional tuple expected by a sub-classed Keras model. |
|
Prepares a list of input tensors for a model's call method. |
|
Prepares a list of input tensors for a model's call method in graph |
|
Prepare future static and dynamic inputs for making predictions. |
|
Sets and validates default values for quantiles, scales, and return_sequences parameters. |
|
Split sequences into static and dynamic inputs for the model. |
|
Squeeze the last dimension of tensor(s) if it equals 1 based on output_dims. |
|
Convert a multi-step forecast DataFrame from wide to long format. |
- geoprior.models.utils._utils.split_static_dynamic(sequences, static_indices, dynamic_indices, static_time_step=0, reshape_static=True, reshape_dynamic=True, static_reshape_shape=None, dynamic_reshape_shape=None)
Split sequences into static and dynamic inputs for the model.
The split_static_dynamic function divides input sequences into static and dynamic components based on specified feature indices. Static features are typically location-specific and do not change over time, while dynamic features vary across different time steps.
(1) \[\begin{split}\text{Static Inputs} = \mathbf{S} = \mathbf{X}_{t,\ \text{static\_indices}} \\ \text{Dynamic Inputs} = \mathbf{D} = \mathbf{X}_{:,\ \text{dynamic\_indices}}\end{split}\]
- Parameters:
sequences (numpy.ndarray) – Array of input sequences with shape (batch_size, sequence_length, num_features).
static_indices (List[int]) – Indices of static features within the feature dimension.
dynamic_indices (List[int]) – Indices of dynamic features within the feature dimension.
static_time_step (int, default 0) – Time step from which to extract static features (default is the first time step).
reshape_static (bool, default True) – Whether to reshape static inputs. If False, returns without reshaping.
reshape_dynamic (bool, default True) – Whether to reshape dynamic inputs. If False, returns without reshaping.
static_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for static inputs after reshaping. If None, defaults to (batch_size, num_static_vars, 1).
dynamic_reshape_shape (Optional[Tuple[int, ...]], default None) – Desired shape for dynamic inputs after reshaping. If None, defaults to (batch_size, sequence_length, num_dynamic_vars, 1).
- Returns:
A tuple containing:
- Static inputs with shape as specified.
- Dynamic inputs with shape as specified.
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If static_time_step is out of range for the given sequence length.
Examples
>>> import numpy as np
>>> from geoprior.models.utils import split_static_dynamic
>>>
>>> # Create a dummy sequence array with shape
>>> # (batch_size=100, sequence_length=10, num_features=5)
>>> sequences = np.random.rand(100, 10, 5)
>>>
>>> # Define static and dynamic feature indices
>>> static_indices = [0, 1]      # e.g., longitude and latitude
>>> dynamic_indices = [2, 3, 4]  # e.g., year, GWL, density
>>>
>>> # Split the sequences
>>> static_inputs, dynamic_inputs = split_static_dynamic(
...     sequences,
...     static_indices=static_indices,
...     dynamic_indices=dynamic_indices,
...     static_time_step=0
... )
>>>
>>> print(static_inputs.shape)
(100, 2, 1)
>>> print(dynamic_inputs.shape)
(100, 10, 3, 1)
Notes
Static Features: These are typically location-specific features such as geographical coordinates or categorical attributes that remain constant over time.
Dynamic Features: These features vary over different time steps and are essential for capturing temporal dependencies in the data.
Reshaping: The function provides flexibility in reshaping the static and dynamic inputs to match the input requirements of various models, including Temporal Fusion Transformers.
See also
geoprior.models.utils.create_sequences: Function to create input sequences and targets for time series forecasting.
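The split in equation (1) can be sketched with plain NumPy indexing. This is a minimal illustration, not the library's implementation: the helper name `split_static_dynamic_sketch` is hypothetical, and it only reproduces the default reshaping (a trailing singleton axis) described in the parameters above.

```python
import numpy as np


def split_static_dynamic_sketch(sequences, static_indices, dynamic_indices,
                                static_time_step=0):
    """Hypothetical helper mirroring equation (1); not the geoprior code."""
    if not 0 <= static_time_step < sequences.shape[1]:
        raise ValueError("static_time_step is out of range for the sequence length")
    # S = X[:, t, static_indices] -> (batch_size, num_static_vars)
    static = sequences[:, static_time_step, static_indices]
    # D = X[:, :, dynamic_indices] -> (batch_size, sequence_length, num_dynamic_vars)
    dynamic = sequences[:, :, dynamic_indices]
    # Default reshaping assumed here: append a trailing axis of size 1.
    return static[..., np.newaxis], dynamic[..., np.newaxis]


rng = np.random.default_rng(0)
X = rng.random((100, 10, 5))
S, D = split_static_dynamic_sketch(X, [0, 1], [2, 3, 4])
print(S.shape, D.shape)
```

Static features are read from a single time step because they are assumed constant over the sequence, so any valid `static_time_step` should give the same values for truly static columns.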
- geoprior.models.utils._utils.create_sequences(df, sequence_length, target_col, step=1, include_overlap=True, drop_last=True, forecast_horizon=None, verbose=3)
Create input sequences and corresponding targets for time series forecasting.
The create_sequences function generates sequences of features and their corresponding targets from a time series dataset. This is essential for training sequence models like Temporal Fusion Transformers, LSTMs, and others that rely on temporal dependencies.
See more in User Guide.
- Parameters:
df (pandas.DataFrame) – The processed DataFrame containing features and the target variable.
sequence_length (int) – The number of past time steps to include in each input sequence.
target_col (str) – The name of the target column.
step (int, default 1) – The step size between the starts of consecutive sequences.
include_overlap (bool, default True) – Whether to include overlapping sequences based on the step size.
drop_last (bool, default True) – Whether to drop the last sequence if it does not have enough data points.
forecast_horizon (int, optional, default None) – The number of future time steps to predict. If set to None, the function will create targets for a single future time step. If provided, targets will consist of the next forecast_horizon time steps.
verbose (int, default 3) – Controls the verbosity of logging. Ranges from 0 (no logs) to 7 (maximal logs).
- Returns:
- A tuple containing:
sequences: Array of input sequences with shape (num_sequences, sequence_length, num_features).
- targets:
If forecast_horizon is None: Array of target values with shape (num_sequences,).
If forecast_horizon is an integer: Array of target sequences with shape (num_sequences, forecast_horizon).
- Return type:
Tuple[numpy.ndarray, numpy.ndarray]
- Raises:
ValueError – If the DataFrame df does not contain the target_col.
Examples
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.utils import create_sequences
>>>
>>> # Create a dummy DataFrame
>>> data = pd.DataFrame({
...     'feature1': np.random.rand(100),
...     'feature2': np.random.rand(100),
...     'feature3': np.random.rand(100),
...     'target': np.random.rand(100)
... })
>>>
>>> # Create sequences for single-step forecasting
>>> sequence_length = 4
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=sequence_length,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=None
... )
>>> print(sequences.shape)
(95, 4, 4)
>>> print(targets.shape)
(95,)
>>>
>>> # Create sequences for multi-step forecasting (e.g., 3 steps)
>>> forecast_horizon = 3
>>> sequences, targets = create_sequences(
...     df=data,
...     sequence_length=4,
...     target_col='target',
...     step=1,
...     include_overlap=True,
...     drop_last=True,
...     forecast_horizon=3
... )
>>> print(sequences.shape)
(92, 4, 4)
>>> print(targets.shape)
(92, 3)
Notes
Sequence Creation: The function slides a window of size sequence_length across the DataFrame to create input sequences. Each sequence is associated with a target value or sequence of values that immediately follow the input sequence.
- Forecast Horizon:
If forecast_horizon is None, the function creates targets for a single future time step.
If forecast_horizon is an integer H, the function creates targets consisting of the next H time steps.
Step Size: The step parameter controls the stride of the sliding window. A step of 1 results in overlapping sequences, while a larger step reduces overlap.
Handling Incomplete Sequences: If drop_last is set to False, the function includes the last sequence even if it doesn’t have enough data points to form a complete sequence or target.
Data Validation: The function utilizes are_all_frames_valid from geoprior.core.checks to ensure the integrity of input DataFrame before processing and exist_features to verify the presence of the target column.
The sequences generation can be expressed as:
(2) \[\begin{split}\text{For each sequence } i: \\ \mathbf{X}^{(i)} = \left[ \mathbf{x}_{i}, \mathbf{x}_{i+1}, \dots, \mathbf{x}_{i+T-1} \right] \\ y^{(i)} = \begin{cases} \mathbf{x}_{i+T} & \text{if } \text{forecast\_horizon} = \text{None} \\ \left[ \mathbf{x}_{i+T}, \mathbf{x}_{i+T+1}, \dots, \mathbf{x}_{i+T+H-1} \right] & \text{if } \text{forecast\_horizon} = H \end{cases}\end{split}\]
- Where:
\(\mathbf{X}^{(i)}\) is the input sequence of length \(T\).
\(y^{(i)}\) is the target value(s) following the sequence.
See also
geoprior.models.utils.split_static_dynamic: Function to split sequences into static and dynamic inputs.
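The sliding-window rule in equation (2) can be sketched as follows. This is an illustrative stand-in under stated assumptions, not the library's code: the helper name `sliding_windows` is hypothetical, it keeps only complete windows, and the number of windows it yields (N - T - H + 1 for step 1) may differ from what `create_sequences` returns, for instance depending on how `drop_last` is handled.

```python
import numpy as np
import pandas as pd


def sliding_windows(df, sequence_length, target_col, step=1, forecast_horizon=None):
    """Hypothetical sketch of equation (2): windows of length T plus their targets."""
    horizon = 1 if forecast_horizon is None else forecast_horizon
    values = df.to_numpy(dtype=float)           # all columns, target included
    target = df[target_col].to_numpy(dtype=float)
    # Last start index that still leaves room for a full window and full target.
    last_start = len(df) - sequence_length - horizon
    sequences, targets = [], []
    for i in range(0, last_start + 1, step):
        end = i + sequence_length
        sequences.append(values[i:end])          # rows i .. i+T-1
        if forecast_horizon is None:
            targets.append(target[end])          # single value at i+T
        else:
            targets.append(target[end:end + horizon])  # values i+T .. i+T+H-1
    return np.asarray(sequences), np.asarray(targets)


df = pd.DataFrame({"feature1": np.arange(10.0), "target": np.arange(10.0) * 10})
X1, y1 = sliding_windows(df, sequence_length=4, target_col="target")
X3, y3 = sliding_windows(df, sequence_length=4, target_col="target", forecast_horizon=3)
print(X1.shape, y1.shape, X3.shape, y3.shape)
```

Using a deterministic ramp for the toy data makes it easy to check by eye that each target starts immediately after the last row of its window.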
- geoprior.models.utils._utils.compute_forecast_horizon(data=None, dt_col=None, start_pred=None, end_pred=None, error='raise', verbose=1)
Compute the forecast horizon for time series forecasting models.
This function calculates the number of future time steps (forecast_horizon) a model should predict based on the provided data or specified prediction dates. It infers the frequency of the data and computes the horizon accordingly, and it accepts several datetime formats and input types.
- Parameters:
data (pandas.DataFrame, pandas.Series, list, or numpy.ndarray, optional) – The dataset containing datetime information. If a pandas.DataFrame is provided, the dt_col parameter must be specified to indicate which column contains the datetime data. For pandas.Series, list, or numpy.ndarray, the function attempts to infer the frequency directly.
dt_col (str, optional) – The name of the column in data that contains datetime information. This parameter is required if data is a pandas.DataFrame. Example: dt_col='timestamp'
start_pred (str, int, or datetime-like) – The starting point for forecasting. This can be a date string (e.g., '2023-04-10'), a datetime object, or an integer representing a year (e.g., 2024). If an integer is provided, it is interpreted as a year, and a warning is issued to inform the user of this interpretation.
end_pred (str, int, or datetime-like) – The ending point for forecasting. Similar to start_pred, this can be a date string, a datetime object, or an integer representing a year. The function calculates the forecast horizon based on the difference between start_pred and end_pred.
error ({'raise', 'warn', 'ignore'}, default 'raise') – Defines the error handling behavior when encountering issues such as invalid input types, missing date columns, or unparseable dates.
'raise': Raises a ValueError when an error is encountered.
'warn': Emits a warning and attempts to proceed with default behavior.
verbose (int, default 1) – Controls the level of verbosity for debug information.
0: No output.
1: Minimal output (e.g., starting message).
2: Intermediate output (e.g., detected dates, computed horizons).
3: Detailed output (e.g., types of predictions, inferred frequencies).
- Returns:
The computed forecast_horizon representing the number of steps ahead the model should predict. Returns None if an error occurs and error is set to ‘warn’.
- Return type:
- Raises:
ValueError – If invalid parameters are provided and error is set to ‘raise’.
Examples
>>> from geoprior.models.utils import compute_forecast_horizon
>>> import pandas as pd
>>> import numpy as np
>>> from datetime import datetime, timedelta
>>>
>>> # Example 1: Using a DataFrame with a Date Column
>>> df = pd.DataFrame({
...     'date': pd.date_range(start='2023-01-01', periods=100, freq='D'),
...     'value': np.random.randn(100)
... })
>>> horizon = compute_forecast_horizon(
...     data=df,
...     dt_col='date',
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=3
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>>
>>> # Example 2: Using a List of Datetimes
>>> dates = [datetime(2023, 1, 1) + timedelta(days=i) for i in range(100)]
>>> horizon = compute_forecast_horizon(
...     data=dates,
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='warn',
...     verbose=2
... )
>>> print(f"Forecast Horizon: {horizon}")
Forecast Horizon: 11
>>>
>>> # Example 3: Handling Integer Years
>>> horizon = compute_forecast_horizon(
...     start_pred=2024,
...     end_pred=2030,
...     error='raise',
...     verbose=1
... )
Forecast Horizon: 7
>>>
>>> # Example 4: Without Providing Data (Assuming Frequency Based on Prediction Dates)
>>> horizon = compute_forecast_horizon(
...     start_pred='2023-04-10',
...     end_pred='2023-04-20',
...     error='raise',
...     verbose=1
... )
Forecast Horizon: 11
Notes
When data is not provided, the function relies solely on the difference between start_pred and end_pred to compute the forecast horizon. In such cases, if the frequency cannot be inferred, the horizon is calculated based on the largest possible time unit (years, months, weeks, days).
If start_pred is after end_pred, the function returns 0 and issues a warning or raises an error based on the error parameter.
The function attempts to infer the frequency of the data using pandas utilities. If the frequency cannot be inferred, it defaults to calculating the horizon based on the time difference in the most significant unit.
See also
pandas.date_range: Generates a fixed frequency DatetimeIndex.
pandas.infer_freq: Infers the frequency of a DatetimeIndex.
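The horizon values in the examples above are consistent with counting the steps from start_pred to end_pred inclusively at the inferred frequency. The sketch below illustrates that counting with the two pandas utilities just listed; the helper names `horizon_from_dates` and `horizon_from_years` are hypothetical, and the actual function may resolve frequencies and edge cases differently.

```python
import pandas as pd


def horizon_from_dates(dt_values, start_pred, end_pred):
    """Hypothetical sketch: count steps between two dates at the inferred frequency."""
    freq = pd.infer_freq(pd.DatetimeIndex(dt_values))
    if freq is None:
        raise ValueError("could not infer a regular frequency from the data")
    # date_range includes both endpoints; it is empty when start is after end.
    return len(pd.date_range(start=start_pred, end=end_pred, freq=freq))


def horizon_from_years(start_year, end_year):
    """Hypothetical sketch for integer years, counted inclusively."""
    return max(end_year - start_year + 1, 0)


history = pd.date_range(start="2023-01-01", periods=100, freq="D")
daily_horizon = horizon_from_dates(history, "2023-04-10", "2023-04-20")
yearly_horizon = horizon_from_years(2024, 2030)
reversed_horizon = horizon_from_dates(history, "2023-04-20", "2023-04-10")
print(daily_horizon, yearly_horizon, reversed_horizon)
```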
- geoprior.models.utils._utils.prepare_spatial_future_data(final_processed_data, feature_columns, dynamic_feature_indices, sequence_length=1, dt_col='date', static_feature_names=None, forecast_horizon=None, future_years=None, encoded_cat_columns=None, scaling_params=None, spatial_cols=None, squeeze_last=False, verbosity=0)
Prepare future static and dynamic inputs for making predictions.
This function prepares the necessary static and dynamic inputs required for forecasting future values in time series data. It processes the provided dataset by grouping it by location_id, extracting the last sequence of data points based on the specified sequence_length, and generating future inputs for prediction over the defined forecast_horizon.
The function handles both integer and datetime representations of the dt_col, extracting the year from datetime columns when necessary. It also allows for flexibility in specifying static features and encoded categorical variables.
(3) \[\text{scaled\_time} = \frac{\text{future\_time} - \mu}{\sigma}\]
- Parameters:
final_processed_data (pandas.DataFrame) – The processed DataFrame containing all features and targets. Must include the location_id column and the specified dt_col.
feature_columns (List[str]) – List of feature column names to be used for dynamic input preparation.
dynamic_feature_indices (List[int]) – Indices of dynamic features in feature_columns. These features are considered time-dependent and are used to prepare dynamic inputs.
sequence_length (int, optional) – The number of past time steps to include in each input sequence. Default is 1.
dt_col (str, optional) – The name of the time-related column in final_processed_data. Defaults to 'date'.
static_feature_names (List[str], optional) – List of static feature column names. If not provided, defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
forecast_horizon (int, optional) – The number of future time steps to predict. If set to None, the function defaults to predicting the next immediate time step.
future_years (List[int], optional) – List of future years to predict. Its length must equal forecast_horizon if forecast_horizon is provided.
encoded_cat_columns (List[str], optional) – List of encoded categorical column names to be treated as static features.
scaling_params (Dict[str, Dict[str, float]], optional) – Dictionary containing scaling parameters (mean and standard deviation) for features. Example: {'year': {'mean': 2000, 'std': 10}}. If not provided, the function computes the mean and std for the dt_col.
squeeze_last (bool, default False) – Whether to squeeze the last axis, which corresponds to the output dimension y, when it equals 1.
verbosity (int, optional) – Verbosity level from 0 to 7 for debugging and understanding the process. Higher values produce more detailed logs.
- Returns:
A tuple containing:
future_static_inputs (numpy.ndarray) – Array of future static inputs with shape (num_samples, num_static_vars, 1).
future_dynamic_inputs (numpy.ndarray) – Array of future dynamic inputs with shape (num_samples, sequence_length, num_dynamic_vars, 1).
future_years_list (List[int]) – List of future time values corresponding to each sample.
location_ids_list (List[int]) – List of location IDs corresponding to each sample.
longitudes (List[float]) – List of longitude values corresponding to each sample.
latitudes (List[float]) – List of latitude values corresponding to each sample.
- Return type:
Tuple[np.ndarray, np.ndarray, List[int], List[int], List[float], List[float]]
Examples
>>> from geoprior.models.utils import prepare_spatial_future_data
>>> import pandas as pd
>>> data = pd.DataFrame({
...     'location_id': [1, 1, 1, 2, 2, 2],
...     'year': [2018, 2019, 2020, 2018, 2019, 2020],
...     'longitude': [10.0, 10.0, 10.0, 20.0, 20.0, 20.0],
...     'latitude': [50.0, 50.0, 50.0, 60.0, 60.0, 60.0],
...     'temperature': [15, 16, 15.5, 20, 21, 20.5],
...     'rainfall': [100, 110, 105, 200, 210, 205],
...     'encoded_cat': [1, 1, 1, 2, 2, 2]
... })
>>> feature_cols = ['year', 'temperature', 'rainfall', 'encoded_cat']
>>> dynamic_indices = [0, 1, 2]
>>> (future_static, future_dynamic, future_years,
...  loc_ids, longs, lats) = prepare_spatial_future_data(
...     final_processed_data=data,
...     feature_columns=feature_cols,
...     dynamic_feature_indices=dynamic_indices,
...     sequence_length=2,
...     forecast_horizon=1,
...     future_years=[2021],
...     encoded_cat_columns=['encoded_cat'],
...     verbosity=5,
...     dt_col='year'
... )
>>> print(future_static.shape)
(2, 3, 1)
>>> print(future_dynamic.shape)
(2, 2, 3, 1)
Notes
The function handles both integer and datetime representations of the dt_col. If dt_col is a datetime type, the year is extracted for scaling purposes.
If forecast_horizon is set to None, the function defaults to generating data for the next immediate time step based on the last entry in the time column.
Ensure that the length of future_years matches forecast_horizon if forecast_horizon is provided.
The static_feature_names parameter allows for flexibility in specifying which static features to include. If not provided, it defaults to ['longitude', 'latitude'] plus any encoded_cat_columns.
See also
prepare_future_data: Main function for preparing future data inputs.
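Two steps described for this function, taking the last sequence_length rows per location and scaling the future time value as in equation (3), can be sketched with pandas. This is only an illustration under stated assumptions: the variable names are hypothetical, and the standard deviation here uses the pandas default (ddof=1), which may not match the convention the library uses when scaling_params is not supplied.

```python
import pandas as pd

data = pd.DataFrame({
    "location_id": [1, 1, 1, 2, 2, 2],
    "year": [2018, 2019, 2020, 2018, 2019, 2020],
    "temperature": [15.0, 16.0, 15.5, 20.0, 21.0, 20.5],
})
sequence_length = 2
future_years = [2021]

# Scaling parameters for the time column (assumption: sample std, ddof=1).
mu = data["year"].mean()
sigma = data["year"].std()

# Last `sequence_length` rows of each location, in chronological order.
last_windows = (
    data.sort_values(["location_id", "year"])
        .groupby("location_id")
        .tail(sequence_length)
)

# Equation (3): scaled_time = (future_time - mu) / sigma
scaled_future = [(year - mu) / sigma for year in future_years]

print(last_windows)
print(scaled_future)
```

Passing explicit scaling_params to the real function avoids any ambiguity about which mean and standard deviation are applied to the future time values.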
- geoprior.models.utils._utils.compute_anomaly_scores(y_true, y_pred=None, method='statistical', threshold=3.0, domain_func=None, contamination=0.05, epsilon=1e-06, estimator=None, random_state=None, residual_metric='mse', objective='ts', verbose=1)
Compute anomaly scores for given true targets using various methods.
This utility function, compute_anomaly_scores, provides a flexible approach to compute anomaly scores outside the XTFT model itself. Anomaly scores serve as indicators of how unusual certain observations are, guiding the model towards more robust and stable forecasts. By detecting and quantifying anomalies, practitioners can adjust forecasting strategies, improve predictive performance, and handle irregular patterns more effectively.
- Parameters:
y_true (np.ndarray) – The ground truth target values with shape (B, H, O), where:
- B: batch size
- H: number of forecast horizons (time steps ahead)
- O: output dimension (e.g., number of target variables).
Typically, y_true corresponds to the same array passed as the forecast target to the model. All computations of anomalies are relative to these true values or, if provided, their predicted counterparts y_pred.
y_pred (np.ndarray, optional) – The predicted values with shape (B, H, O). If provided and the method is set to 'residual', the anomaly scores are derived from the residuals between y_true and y_pred. In this scenario, anomalies reflect discrepancies indicating unusual conditions or model underperformance.
method (str, optional) – The method used to compute anomaly scores. Supported options:
"statistical" or "stats": Uses mean and standard deviation of y_true to measure deviation from the mean. Points far from the mean by a certain factor (controlled by threshold) yield higher anomaly scores. Formally, let \(\mu\) be the mean of y_true and \(\sigma\) its standard deviation. The anomaly score for a point \(y\) is:
(4) \[\left(\frac{y - \mu}{\sigma + \varepsilon}\right)^2\]
where \(\varepsilon\) is a small constant for numerical stability.
"domain": Uses a domain-specific heuristic (provided by domain_func) to compute scores. If no domain_func is provided, a default heuristic marks negative values as anomalies.
"isolation_forest" or "if": Employs the IsolationForest algorithm to detect outliers. The model learns a structure to isolate anomalies more quickly than normal points. Higher contamination rates allow more points to be considered anomalous.
"residual": If y_pred is provided, anomalies are derived from residuals: the difference (y_true - y_pred). By default, mean squared error (mse) is used. Other metrics include mae and rmse, offering flexibility in quantifying deviations:
(5) \[\text{MSE: }(y_{true} - y_{pred})^2\]
Default is "statistical".
threshold (float, optional) – Threshold factor for the statistical method. Defines how far beyond mean ± (threshold * std) is considered anomalous. Though not directly applied as a mask here, it can guide interpretation of scores. Default is 3.0.
domain_func (callable, optional) – A user-defined function for domain method. It takes y_true as input and returns an array of anomaly scores with the same shape. If none is provided, the default heuristic is:
(6) \[\begin{split}\text{anomaly}(y) = \begin{cases} |y| \times 10 & \text{if } y < 0 \\ 0 & \text{otherwise} \end{cases}\end{split}\]
contamination (float, optional) – Used in the isolation_forest method. Specifies the proportion of outliers in the dataset. Default is 0.05.
epsilon (float, optional) – A small constant \(\varepsilon\) for numerical stability in calculations, especially during statistical normalization. Default is 1e-6.
estimator (object, optional) – A pre-fitted IsolationForest estimator for the isolation_forest method. If not provided, a new estimator will be created and fitted to y_true.
random_state (int, optional) – Sets a random state for reproducibility in the isolation_forest method.
residual_metric (str, optional) – The metric used to compute anomalies from residuals if method is set to 'residual'. Supported metrics:
"mse": mean squared error per point (residuals**2)
"mae": mean absolute error per point |residuals|
"rmse": root mean squared error sqrt((residuals**2))
Default is "mse".
objective (str, optional) – Specifies the type of objective, for future extensibility. Default is "ts" indicating time series. Could be extended for other tasks in the future.
verbose (int, optional) – Controls verbosity. If verbose=1, some messages or warnings may be printed. Higher values might produce more detailed logs.
- Returns:
anomaly_scores – An array of anomaly scores with the same shape as y_true. Higher values indicate more unusual or anomalous points.
- Return type:
np.ndarray
Notes
Choosing an appropriate method depends on the data characteristics, domain requirements, and model complexity. Statistical methods are quick and interpretable but may oversimplify anomalies. Domain heuristics leverage expert knowledge, while isolation forest applies a more robust, data-driven approach. Residual-based anomalies help assess model performance and highlight periods where the model struggles.
Examples
>>> from geoprior.models.losses import compute_anomaly_scores
>>> import numpy as np
>>>
>>> # Statistical method example
>>> y_true = np.random.randn(32, 20, 1)  # (B, H, O)
>>> scores = compute_anomaly_scores(y_true, method='statistical', threshold=3)
>>> scores.shape
(32, 20, 1)
>>>
>>> # Domain-specific example
>>> def my_heuristic(y):
...     return np.where(y < -1, np.abs(y)*5, 0.0)
>>> scores = compute_anomaly_scores(y_true, method='domain', domain_func=my_heuristic)
>>>
>>> # Isolation Forest example
>>> scores = compute_anomaly_scores(y_true, method='isolation_forest', contamination=0.1)
>>>
>>> # Residual-based example
>>> y_pred = y_true + np.random.normal(0, 1, y_true.shape)  # Introduce noise
>>> scores = compute_anomaly_scores(y_true, y_pred=y_pred, method='residual', residual_metric='mae')
See also
geoprior.models.losses.objective_loss: For integrating anomaly scores into a multi-objective loss.
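The closed-form scores in equations (4) to (6) are simple enough to sketch directly in NumPy. The functions below are hypothetical stand-ins written from the formulas in the parameter descriptions, not the library's code; in particular, computing the mean and standard deviation over the whole array (rather than per axis) is an assumption.

```python
import numpy as np


def statistical_scores(y_true, epsilon=1e-6):
    """Equation (4): squared z-score, using a global mean/std (assumption)."""
    mu, sigma = y_true.mean(), y_true.std()
    return ((y_true - mu) / (sigma + epsilon)) ** 2


def default_domain_scores(y_true):
    """Equation (6): negative values score |y| * 10, everything else 0."""
    return np.where(y_true < 0, np.abs(y_true) * 10, 0.0)


def residual_scores(y_true, y_pred, metric="mse"):
    """Per-point residual scores for the 'mse', 'mae', and 'rmse' metrics."""
    residuals = y_true - y_pred
    if metric == "mse":
        return residuals ** 2
    if metric == "mae":
        return np.abs(residuals)
    if metric == "rmse":
        return np.sqrt(residuals ** 2)
    raise ValueError(f"unsupported residual metric: {metric!r}")


y_true = np.array([[[-2.0], [0.0], [2.0]]])   # shape (B=1, H=3, O=1)
y_pred = np.zeros_like(y_true)

stat = statistical_scores(y_true)
dom = default_domain_scores(y_true)
res = residual_scores(y_true, y_pred, metric="mae")
print(stat.shape, dom.ravel(), res.ravel())
```

Note that per point the "rmse" and "mae" scores coincide, since the square root of a squared residual is its absolute value.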
- geoprior.models.utils._utils.generate_forecast(xtft_model, train_data, dt_col, dynamic_features, future_features=None, static_features=None, test_data=None, mode='quantile', spatial_cols=None, forecast_horizon=4, time_steps=3, q=None, tname=None, forecast_dt=None, savefile=None, verbose=3, **kw)
Generate forecast using the XTFT model.
This function uses a pre-trained Keras model to forecast future values based on provided historical data. The model receives three inputs: X_static, X_dynamic, and X_future re-built from train_data, and outputs predictions over a specified forecast horizon.
See more in User Guide.
- Parameters:
xtft_model (object) – A validated Keras model instance. It is processed by the validate_keras_model method.
train_data (pandas.DataFrame) – The training data containing historical records. Must include the dt_col and all required feature columns.
dt_col (str) – Name of the column representing time. It may be a datetime or numeric column (e.g. "year").
dynamic_features (list of str) – List of dynamic feature column names. They are formatted via columns_manager.
future_features (list of str, optional) – List of future feature names. These columns are tiled over the forecast horizon.
static_features (list of str, optional) – List of static feature names. If not provided, a dummy input is used.
test_data (pandas.DataFrame, optional) – DataFrame containing actual values used for evaluation. If provided, it is used to compute the R² and coverage score for mode='quantile'.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions for multiple quantiles (default: [0.1, 0.5, 0.9]) are computed.
spatial_cols (list of str, optional) – List of spatial column names for grouping the data. When provided, forecasts are computed per location; otherwise, a global forecast is performed.
forecast_horizon (int, optional) – Number of future periods to forecast. Default is 4.
time_steps (int, optional) – Number of past time steps to use as input for the model. Default is 3.
q (list of float, optional) – List of quantiles for use in quantile mode. Default is [0.1, 0.5, 0.9]. Each quantile is validated by the assert_ratio function.
tname (str, optional) – Target variable name used for constructing forecast result columns. Defaults to "target".
forecast_dt (list or str, optional) – List of forecast dates or "auto" to derive dates from dt_col. In auto mode, if dt_col is datetime, frequency is inferred using pd.infer_freq.
savefile (str, optional) – Path to the CSV file where forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level (0-7). Controls the amount of execution output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, each forecast period includes columns for each quantile; in point mode, a single prediction column is provided.
- Return type:
Examples
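As a small illustration of the forecast_dt="auto" behaviour described in the parameters, the sketch below derives forecast dates from a datetime column with pd.infer_freq. The helper name `auto_forecast_dates` is hypothetical and the logic is an assumption about what "auto" amounts to for datetime columns, not the function's actual code.

```python
import pandas as pd


def auto_forecast_dates(dt_values, forecast_horizon):
    """Hypothetical sketch: next `forecast_horizon` dates at the inferred frequency."""
    dt_index = pd.DatetimeIndex(dt_values)
    freq = pd.infer_freq(dt_index)
    if freq is None:
        raise ValueError("could not infer a regular frequency from dt_col")
    # Generate horizon + 1 dates starting at the last observation, then drop
    # the first one so the result begins one step after the training data.
    return pd.date_range(start=dt_index[-1], periods=forecast_horizon + 1, freq=freq)[1:]


train_dates = pd.date_range(start="2020-01-01", periods=50, freq="D")
future_dates = auto_forecast_dates(train_dates, forecast_horizon=5)
print(future_dates)
```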
Example using training data only
>>> import os
>>> import pandas as pd
>>> import numpy as np
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.losses import combined_quantile_loss
>>> from geoprior.models.utils import generate_forecast
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # dynamic features "feat1", "feat2", static feature "stat1",
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> # Note: The model expects the following input shapes:
>>> # - X_static: (n_samples, static_input_dim)
>>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future: (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=1,    # "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # For the provided future feature
...     forecast_horizon=5,    # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Create dummy input arrays for model fitting.
>>> # For simplicity, assume time_steps = 3 and use random data.
>>> X_static = train_df[["stat1"]].values  # shape: (50, 1)
>>> # Create a dummy dynamic input array of shape (50, 3, 2)
>>> X_dynamic = np.random.rand(50, 3, 2)
>>> # Create a dummy future-feature array of shape (50, 3, 1)
>>> X_future = np.random.rand(50, 3, 1)
>>> # Create dummy target output from "price"
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
Example with test data included
>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values  # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,    # "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # For the provided future feature
...     forecast_horizon=5,    # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Define the callbacks passed to fit (illustrative settings;
>>> # the checkpoint filename is a placeholder).
>>> from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
>>> early_stopping = EarlyStopping(monitor="loss", patience=3)
>>> model_checkpoint = ModelCheckpoint(
...     "xtft_best.weights.h5",
...     monitor="loss",
...     save_best_only=True,
...     save_weights_only=True
... )
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8,
...     callbacks=[early_stopping, model_checkpoint]
... )
>>>
>>> # Generate forecast using the generate_forecast function.
>>> # This example uses test_df for evaluation, which will compute
>>> # metrics like R² Score and Coverage Score.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     test_data=test_df.iloc[:5, :],  # to fit the first horizon forecasting.
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="quantile",
...     verbose=3
... )
>>> print(forecast.head())
Example of Point forecasting
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>>
>>> # Create dummy input arrays for model fitting.
>>> # X_static is derived from the static feature "stat1".
>>> X_static = train_df[["stat1"]].values  # shape: (50, 1)
>>>
>>> # X_dynamic is a dummy dynamic array for "feat1" and "feat2".
>>> # For time_steps = 3, its shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # X_future is a dummy array for future features.
>>> # Here, we assume a single future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,   # "stat1"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # Provided future feature
...     forecast_horizon=5,   # Forecast 5 periods ahead
...     quantiles=None,       # Quantiles are not used in point mode
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast using the generate_forecast function in point mode.
>>> forecast = generate_forecast(
...     xtft_model=my_model,
...     train_data=train_df,
...     dt_col="date",
...     dynamic_features=["feat1", "feat2"],
...     static_features=["stat1"],
...     forecast_horizon=5,
...     time_steps=3,
...     tname="price",
...     mode="point",
...     verbose=3
... )
>>> print(forecast.head())
Notes
The function groups data by spatial_cols if provided, and formats features via columns_manager. It validates the time column using check_datetime and uses dummy inputs for missing static or future features. The forecast is produced by invoking xtft_model.predict on a list containing static, dynamic, and future inputs. The predictions are generated as follows:

\[\hat{y}_{t+i} = f\Bigl(X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}}\Bigr) \tag{7}\]

where \(i\) denotes the forecast period.
See also
geoprior.models.utils.reshape_xtft_data: Function to reshape data for XTFT models.
geoprior.utils.validator.validate_keras_model: Function to validate Keras model compatibility.
geoprior.core.handlers.columns_manager: Utility to manage and format column names.
geoprior.core.checks.check_datetime: Function to check and validate datetime columns.
geoprior.core.checks.check_spatial_columns: Function to validate spatial columns in data.
geoprior.core.checks.assert_ratio: Function to validate and assert ratio values.
geoprior.metrics_special.coverage_score: Function to compute coverage score for quantile predictions.
- geoprior.models.utils._utils.generate_forecast_with(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kw)[source]#
Generate forecasts using a pre-trained XTFT model based on the forecast horizon.
There are two approaches to generating forecasts with an XTFT model:
1. A monolithic function (e.g. generate_forecast) that handles both single-step and multi-step forecasts within a single implementation. This approach results in a single, large function that internally branches its logic based on the value of the forecast horizon.
2. A modular design where the single-step and multi-step forecasting functionalities are separated into two distinct functions (e.g. forecast_single_step and forecast_multi_step), with a thin wrapper (e.g. generate_xtft_forecast) that dispatches to the appropriate function based on the forecast horizon.

The modular approach (2) is generally preferred because it separates concerns and improves code readability, maintainability, and unit testing. Each function becomes responsible for a well-defined task: one for single-step forecasts and one for multi-step forecasts. The wrapper simply selects the correct method based on the forecast horizon. Use this approach when your application may need to handle both short- and long-term forecasts, as it keeps the codebase modular and easier to debug.

This function plays the role of that wrapper (called generate_xtft_forecast in the discussion above): a forecast_horizon of 1 is handled as a single-step forecast (forecast_single_step) and a forecast_horizon greater than 1 as a multi-step forecast (forecast_multi_step).

- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. A value of 1 triggers a single-step forecast; values greater than 1 trigger a multi-step forecast.
y (numpy.ndarray, optional) – Actual target values for evaluation. If provided, evaluation metrics (e.g., R² Score, and in quantile mode, the coverage score) are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, the output DataFrame includes a column with these values.
mode (str, optional) – Forecast mode, either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data's X_static.
time_steps (int, optional) – The number of historical time steps used as input.
q (list of float, optional) – List of quantile values for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names (e.g., "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking (via mask_by_reference) to adjust predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output.
**kw (dict, optional) – Currently unused; reserved for future extension.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile and forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import generate_forecast_with
>>> import numpy as np
>>>
>>> # Prepare a dummy XTFT model with example parameters.
>>> my_model = XTFT(
...     static_input_dim=10,
...     dynamic_input_dim=5,
...     future_input_dim=3,
...     forecast_horizon=1,  # This parameter will be updated in the
...                          # wrapper function based on forecast_horizon.
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=32,
...     max_window_size=3,
...     memory_size=100,
...     num_heads=4,
...     dropout_rate=0.1,
...     lstm_units=64,
...     attention_units=64,
...     hidden_units=32
... )
>>> my_model.compile(optimizer='adam')
>>>
>>> # Create dummy input data.
>>> X_static = np.random.rand(100, 10)
>>> X_dynamic = np.random.rand(100, 3, 5)
>>> X_future = np.random.rand(100, 3, 3)
>>> y_array = np.random.rand(100, 1, 1)  # For single-step target output.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Fit the model with dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=32
... )
>>>
>>> # Example for a single-step forecast:
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=1,
...     y=y_array,
...     dt_col="year",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     verbose=3
... )
>>> print(forecast_df.head())
>>>
>>> # Example for a multi-step forecast:
>>> forecast_dates = ["2023", "2024", "2025", "2026"]
>>> forecast_df = generate_forecast_with(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=4,
...     y=y_array,
...     dt_col="year",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     verbose=3
... )
>>> print(forecast_df.head())
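The horizon-based dispatch described for this wrapper can be illustrated with a minimal sketch. It is not the library's implementation: `dispatch_forecast`, `fake_single`, and `fake_multi` are hypothetical names, and the two stand-in functions only mimic the call shape of forecast_single_step and forecast_multi_step.

```python
from typing import Callable, Sequence

import numpy as np


def dispatch_forecast(
    inputs: Sequence[np.ndarray],
    forecast_horizon: int,
    single_step: Callable[..., object],
    multi_step: Callable[..., object],
    **kwargs,
):
    """Route to a single- or multi-step forecaster based on the horizon."""
    if not isinstance(forecast_horizon, int) or forecast_horizon < 1:
        raise ValueError("forecast_horizon must be a positive integer")
    if forecast_horizon == 1:
        # Single-step path: the horizon is implied, so it is not forwarded.
        return single_step(inputs, **kwargs)
    # Multi-step path: the horizon is forwarded explicitly.
    return multi_step(inputs, forecast_horizon=forecast_horizon, **kwargs)


# Hypothetical stand-ins for forecast_single_step / forecast_multi_step.
def fake_single(inputs, **kw):
    return ("single", 1)


def fake_multi(inputs, forecast_horizon, **kw):
    return ("multi", forecast_horizon)


dummy_inputs = [np.zeros((4, 2)), np.zeros((4, 3, 2)), np.zeros((4, 3, 1))]
one = dispatch_forecast(dummy_inputs, 1, fake_single, fake_multi)
four = dispatch_forecast(dummy_inputs, 4, fake_single, fake_multi)
```

Passing the two forecasters as arguments keeps the sketch testable without a trained model; a real wrapper would call the library functions directly.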
See also
forecast_single_step: Generates a single-step forecast.
forecast_multi_step: Generates a multi-step forecast.
validate_keras_model: Validates a Keras model.
coverage_score: Computes the coverage score.
- geoprior.models.utils._utils.forecast_multi_step(xtft_model, inputs, forecast_horizon, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#
Generate a multi-step forecast using the XTFT model.
This function generates forecasts for multiple future time steps using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces predictions according to the formulation:
\[\hat{y}_{t+i} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{8}\]

for \(i = 1, \dots, \text{forecast\_horizon}\), where \(f\) is the trained XTFT model.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first two columns of X_static correspond to the first and second spatial coordinates of the original training data.
forecast_horizon (int) – The number of future time steps to forecast. For example, if forecast_horizon is 4, the model will generate predictions for 4 steps ahead.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and, in quantile mode, the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: [0.1, 0.5, 0.9]); in point mode, a single prediction is generated.
spatial_cols (list of str, optional) – A list of spatial column names. If provided, it must contain at least two elements corresponding to the first and second columns of the original training data's X_static.
time_steps (int, optional) – The number of historical time steps used as input. Default is 3.
q (list of float, optional) – List of quantile values for quantile forecasting. The default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name used to construct output column names. For instance, if tname is "subsidence", then output columns may be named "subsidence_q10_step1", "subsidence_q50_step2", etc. Default is "target".
forecast_dt (any, optional) – Forecast datetime information. If provided and its length matches forecast_horizon, its values are added to the output DataFrame.
apply_mask (bool, optional) – If True, applies masking via mask_by_reference to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – The reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – The value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – File path to save the forecast results as a CSV file. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values produce more detailed messages.
- Returns:
A DataFrame containing the multi-step forecast results. In quantile mode, the DataFrame includes columns for each quantile and each forecast step (e.g. <tname>_q10_step1, <tname>_q50_step2, etc.); in point mode, it contains a single prediction column per forecast step (e.g. <tname>_pred_step1). If y is provided, an additional column (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_multi_step
>>> from geoprior.models.losses import combined_quantile_loss
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # spatial features ("longitude", "latitude"), two dynamic
>>> # features ("feat1", "feat2"), a static feature ("stat1"), and
>>> # the target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60,
...                          freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 60),
...     "latitude": np.random.uniform(-90, 90, 60),
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "subsidence": np.random.rand(60)
... })
>>>
>>> # Prepare dummy input arrays for model training.
>>> # Because spatial_cols is passed below, the first two columns of
>>> # X_static must be the spatial coordinates.
>>> X_static = train_df[["longitude", "latitude"]].values
>>> # X_dynamic for "feat1" and "feat2" with time_steps = 3.
>>> X_dynamic = np.random.rand(60, 3, 2)
>>> # X_future is a dummy future feature array with shape (60, 3, 1).
>>> X_future = np.random.rand(60, 3, 1)
>>> # Target output from "subsidence" reshaped to
>>> # (60, 1, 1). For multi-step forecast, forecast_horizon is 4.
>>> forecast_horizon = 4
>>> y_array = train_df["subsidence"].values.reshape(60, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,   # "longitude" and "latitude"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(
...     optimizer="adam",
...     loss=combined_quantile_loss(my_model.quantiles)
... )
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate forecast datetime values for the forecast horizon.
>>> forecast_dates = pd.date_range(start="2020-02-01",
...                                periods=forecast_horizon, freq="D")
>>>
>>> # Package inputs as expected by forecast_multi_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a multi-step forecast in quantile mode.
>>> forecast_df_quantile = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="quantile",
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Quantile Forecast:")
>>> print(forecast_df_quantile.head())
For point forecast
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=2,   # Two columns in X_static
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=forecast_horizon,
...     quantiles=None,       # Set quantiles to None for point mode
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam", loss="mse")
>>>
>>> # Fit the model on the dummy data for demonstration.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Generate a multi-step forecast in point mode.
>>> forecast_df_point = forecast_multi_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     forecast_horizon=forecast_horizon,
...     y=y_array,
...     dt_col="date",
...     mode="point",
...     spatial_cols=["longitude", "latitude"],
...     tname="subsidence",
...     forecast_dt=forecast_dates,
...     apply_mask=False,
...     verbose=3
... )
>>> print("Point Forecast:")
>>> print(forecast_df_point.head())
Notes
- In quantile mode, predictions are generated for each specified quantile for every forecast step, and the median (0.5) is used for evaluation.
- In point mode, a single prediction is generated per forecast step.
- The output prediction array is expected to have the shape \((n, \text{forecast\_horizon}, m)\), where \(n\) is the number of samples and \(m\) is the number of outputs per step (e.g., number of quantiles in quantile mode or 1 in point mode).
- The provided spatial_cols must correspond to the first two columns of the original training data's X_static.
- Evaluation metrics such as R² Score and Coverage Score (in quantile mode) are computed if actual target values (y) are provided.
- The DataFrame is constructed by iterating over each sample and each forecast step.
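To make the \((n, \text{forecast\_horizon}, m)\) layout and the step-wise column names concrete, here is a minimal sketch that flattens such an array into a wide DataFrame. `predictions_to_wide` is a hypothetical helper written for illustration, not part of geoprior; it vectorizes over samples rather than looping over them, and it covers only the prediction columns (no spatial, time, or actual-value columns).

```python
import numpy as np
import pandas as pd


def predictions_to_wide(y_pred, tname="target", q=None):
    """Flatten an (n, horizon, m) array into one column per step/output."""
    y_pred = np.asarray(y_pred)
    if y_pred.ndim != 3:
        raise ValueError("expected an array of shape (n, horizon, m)")
    n, horizon, m = y_pred.shape
    if q is not None and len(q) != m:
        raise ValueError("len(q) must equal the last dimension of y_pred")
    columns = {}
    for step in range(horizon):
        if q is None:
            # Point mode: one prediction column per step.
            columns[f"{tname}_pred_step{step + 1}"] = y_pred[:, step, 0]
        else:
            # Quantile mode: one column per (quantile, step) pair.
            for j, quantile in enumerate(q):
                label = int(round(quantile * 100))
                columns[f"{tname}_q{label}_step{step + 1}"] = y_pred[:, step, j]
    return pd.DataFrame(columns)


rng = np.random.default_rng(0)
wide_q = predictions_to_wide(rng.random((5, 4, 3)), "subsidence",
                             q=[0.1, 0.5, 0.9])
wide_p = predictions_to_wide(rng.random((5, 4, 1)), "subsidence")
```

With 5 samples, a horizon of 4, and 3 quantiles, `wide_q` has 12 prediction columns, while `wide_p` has one column per step.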
See also
forecast_single_step: Function for single-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to verify quantile ratios.
- geoprior.models.utils._utils.forecast_single_step(xtft_model, inputs, y=None, dt_col=None, mode='quantile', spatial_cols=None, q=None, tname=None, forecast_dt=None, apply_mask=False, mask_values=None, mask_fill_value=None, savefile=None, verbose=3, **kws)[source]#
Generate a single-step forecast using the XTFT model.
This function generates a forecast for a single future time step using a pre-trained XTFT deep learning model. The model takes three inputs: X_static, X_dynamic, and X_future, and produces a prediction according to the formulation:
\[\hat{y}_{t+1} = f\Bigl( X_{\text{static}},\; X_{\text{dynamic}},\; X_{\text{future}} \Bigr) \tag{9}\]

where \(f\) is the trained XTFT model. The predictions can be either quantile-based or point-based, as determined by the mode parameter.
- Parameters:
xtft_model (object) – A validated Keras model instance. The model is expected to be verified via validate_keras_model.
inputs (list or tuple of numpy.ndarray) – A list containing three elements: X_static, X_dynamic, and X_future. If spatial_cols is provided, it is assumed that the first column of X_static corresponds to the first spatial coordinate and the second column to the second spatial coordinate of the original training data.
y (numpy.ndarray, optional) – Actual target values. If provided, evaluation metrics such as R² Score and (in quantile mode) the coverage score are computed.
dt_col (str, optional) – Name of the time column (e.g. "year"). If provided, a column with this name is added to the output DataFrame. The actual time values must be supplied externally.
mode (str, optional) – Forecast mode. Must be either "quantile" or "point". In quantile mode, predictions are generated for multiple quantiles (default: 0.1, 0.5, and 0.9).
spatial_cols (list of str, optional) – List of spatial column names. If provided, it must contain at least two elements and correspond to the first and second columns of the original training data's X_static.
q (list of float, optional) – List of quantiles for quantile forecasting. Default is [0.1, 0.5, 0.9] when mode is "quantile".
tname (str, optional) – Target variable name for predictions. This name is used to construct output column names (e.g. "subsidence"). Default is "target".
forecast_dt (any, optional) – Forecast datetime information. Not used in this function but may be provided for compatibility.
apply_mask (bool, optional) – If True, applies a masking function (mask_by_reference) to replace predictions in non-subsiding areas. Requires that both mask_values and mask_fill_value are provided.
mask_values (scalar, optional) – Reference value(s) used for masking. Must be provided if apply_mask is True.
mask_fill_value (scalar, optional) – Value used to fill masked predictions. Must be provided if apply_mask is True.
savefile (str, optional) – Path to a CSV file where the forecast results will be saved. If not provided, a default filename is generated.
verbose (int, optional) – Verbosity level controlling printed output. Higher values result in more detailed output.
- Returns:
A DataFrame containing the forecast results. In quantile mode, the output includes columns for each quantile (e.g. <tname>_q10, <tname>_q50, <tname>_q90). In point mode, a single prediction column (<tname>_pred) is provided. If y is provided, an additional column with the actual target values (<tname>_actual) is included.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.transformers import XTFT
>>> from geoprior.models.utils import forecast_single_step
>>> import pandas as pd
>>> import numpy as np
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), a static feature ("stat1"),
>>> # and dummy spatial features ("longitude", "latitude"), as well as the
>>> # target variable "subsidence".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "longitude": np.random.uniform(-180, 180, 50),
...     "latitude": np.random.uniform(-90, 90, 50),
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "subsidence": np.random.rand(50)
... })
>>>
>>> # Prepare dummy inputs for the model.
>>> # For the static input, use the spatial features "longitude" and
>>> # "latitude": when spatial_cols is provided, forecast_single_step
>>> # expects the first two columns of X_static to be the spatial
>>> # coordinates.
>>> X_static = train_df[["longitude", "latitude"]].values  # shape: (50, 2)
>>>
>>> # Create a dummy dynamic input array for "feat1" and "feat2".
>>> # Assume time_steps = 3, so the shape is (50, 3, 2).
>>> X_dynamic = np.random.rand(50, 3, 2)
>>>
>>> # Create a dummy future input array.
>>> # For this example, assume one future feature with shape (50, 3, 1).
>>> X_future = np.random.rand(50, 3, 1)
>>>
>>> # Create dummy target output from "subsidence", reshaped to (50, 1, 1).
>>> y_array = train_df["subsidence"].values.reshape(50, 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> # The model expects:
>>> # - X_static with shape (n_samples, static_input_dim)
>>> # - X_dynamic with shape (n_samples, time_steps, dynamic_input_dim)
>>> # - X_future with shape (n_samples, time_steps, future_input_dim)
>>> my_model = XTFT(
...     static_input_dim=2,   # "longitude" and "latitude"
...     dynamic_input_dim=2,  # "feat1" and "feat2"
...     future_input_dim=1,   # One future feature
...     forecast_horizon=1,   # Single-step forecast
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> my_model.compile(optimizer="adam")
>>>
>>> # Fit the model on the dummy data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=1,
...     batch_size=8
... )
>>>
>>> # Package the inputs as expected by forecast_single_step.
>>> inputs = [X_static, X_dynamic, X_future]
>>>
>>> # Generate a single-step quantile forecast.
>>> forecast_df = forecast_single_step(
...     xtft_model=my_model,
...     inputs=inputs,
...     y=y_array,
...     dt_col="date",    # The time column name in the output
...     mode="quantile",  # Can be "quantile" or "point"
...     spatial_cols=["longitude", "latitude"],
...     q=[0.1, 0.5, 0.9],
...     tname="subsidence",
...     apply_mask=True,
...     mask_values=0,
...     mask_fill_value=0,
...     verbose=3
... )
>>> print(forecast_df.head())
Notes
- In quantile mode, the function computes predictions for multiple quantiles and uses the median (0.5) for evaluation.
- If spatial_cols is provided, it must be the first and second columns of the original training data's X_static.
- The function internally utilizes validate_keras_model for model validation, assert_ratio for quantile verification, and mask_by_reference for masking operations.
- Evaluation metrics such as R² Score and Coverage Score are computed if actual target values (y) are provided.
- The prediction output is expected to have the shape \((n, 1, m)\), where \(m\) is the number of outputs (e.g., the number of quantiles in quantile mode, or 1 in point mode).
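The evaluation step mentioned in these notes can be sketched with plain NumPy. This is an illustration only: `r2_np` and `coverage_np` are hypothetical helpers, and the coverage definition used here (the fraction of actual values falling inside the lowest-to-highest quantile band) is an assumption about what coverage_score measures, not a statement of the library's implementation.

```python
import numpy as np


def r2_np(y_true, y_pred):
    """Coefficient of determination, 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def coverage_np(y_true, lower, upper):
    """Fraction of actual values inside [lower, upper] (assumed definition)."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    lower = np.asarray(lower, dtype=float).ravel()
    upper = np.asarray(upper, dtype=float).ravel()
    return float(np.mean((y_true >= lower) & (y_true <= upper)))


y_true = np.array([1.0, 2.0, 3.0, 4.0]).reshape(4, 1, 1)
median = np.array([1.1, 1.9, 3.2, 5.0])
# Predictions of shape (n, 1, m) with m = 3 outputs for q = [0.1, 0.5, 0.9].
y_pred = np.stack([median - 0.5, median, median + 0.5],
                  axis=-1)[:, np.newaxis, :]

# The median (index 1) is scored against the actual values.
r2 = r2_np(y_true, y_pred[:, 0, 1])
# The q10..q90 band (indices 0 and 2) is used for coverage.
cov = coverage_np(y_true, y_pred[:, 0, 0], y_pred[:, 0, 2])
```

In this toy data the last actual value lies outside its predicted band, so three of the four samples are covered.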
See also
forecast_multi_step: Function for multi-step forecasts.
coverage_score: Function to compute the coverage score.
validate_keras_model: Function to validate a Keras model.
assert_ratio: Function to validate quantile ratios.
- geoprior.models.utils._utils.step_to_long(df, tname=None, dt_col=None, spatial_cols=None, mode='quantile', quantiles=None, verbose=3, sort=True)[source]#
Convert a multi-step forecast DataFrame from wide to long format.
This function transforms a DataFrame containing multi-step forecast predictions into a long-format DataFrame. In quantile mode, forecast columns such as subsidence_q10_step1, subsidence_q50_step1, etc. are consolidated into unified columns (e.g. subsidence_q10, subsidence_q50, etc.), while in point mode, a single prediction column (subsidence_pred) is generated. The transformation also carries over additional columns (e.g. spatial coordinates and time) from the original DataFrame.

- Parameters:
df (pandas.DataFrame) – The multi-step forecast DataFrame. Expected to contain forecast prediction columns (e.g. columns with _q or _pred_step in their names) along with other identifiers.
tname (str, optional) – The base name of the target variable (e.g. "subsidence"). If None, the function attempts to auto-detect the target name from the column names.
dt_col (str, optional) – The name of the time column to include in the final DataFrame. If not provided, time sorting is not performed.
spatial_cols (list of str, optional) – A list of spatial coordinate columns (e.g. ["longitude", "latitude"]) to be retained in the final output.
mode ({"quantile", "point"}, default "quantile") – The forecast mode. In "quantile" mode, multiple quantile forecast columns are merged into unified columns. In "point" mode, a single prediction column is produced.
quantiles (list of float, optional) – The quantile values for quantile mode (e.g. [0.1, 0.5, 0.9]). If not provided, defaults are used.
sort (bool, optional) – If True, sorts the final DataFrame by the column specified in dt_col (if present). Default is True.
verbose (int, optional) – Verbosity level for logging output. Higher values (e.g. 5 to 7) provide more detailed debug information.
- Returns:
A long-format DataFrame containing the retained spatial columns, the time column when dt_col is provided, and the merged forecast prediction columns. In quantile mode, the output contains unified columns such as subsidence_q10 and subsidence_q50. In point mode, it contains a single subsidence_pred column.
- Return type:
pandas.DataFrame
Examples
>>> from geoprior.models.utils import step_to_long
>>> # Given a DataFrame `forecast_df` with columns like:
>>> # ['longitude', 'latitude', 'year', 'subsidence_actual',
>>> #  'subsidence_q10_step1', 'subsidence_q50_step1', 'subsidence_q90_step1',
>>> #  'subsidence_q10_step2', ...]
>>> long_df = step_to_long(
...     df=forecast_df,
...     tname="subsidence",
...     dt_col="year",
...     spatial_cols=["longitude", "latitude"],
...     mode="quantile",
...     quantiles=[0.1, 0.5, 0.9],
...     verbose=3,
...     sort=True
... )
>>> print(long_df.head())
Notes
Internally, this function calls:

- check_forecast_mode() to validate the user-specified quantiles.
- validate_consistency_q() and validate_quantiles() to ensure the supplied quantiles match those auto-detected from the DataFrame.
- Depending on mode, either _step_to_long_q() for quantile mode or _step_to_long_pred() for point mode performs the conversion.
Mathematically, let \(X \in \mathbb{R}^{n \times m}\) represent the wide-format DataFrame, where each row corresponds to one sample and each forecast step is stored in separate columns. The function reshapes \(X\) into a long-format DataFrame \(Y \in \mathbb{R}^{(n \cdot s) \times p}\), where \(s\) is the forecast horizon and \(p\) is the number of output columns after merging forecast step values.
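The reshaping from \(n\) wide rows to \(n \cdot s\) long rows can be sketched with pandas as follows. This is a simplified illustration, not the code behind step_to_long: `wide_to_long` is a hypothetical helper, the `step` column it adds is an assumption made for readability, and it skips the quantile validation and sorting that the real function performs.

```python
import re

import pandas as pd


def wide_to_long(df, tname, id_cols=()):
    """Stack '<tname>_<kind>_step<k>' columns into one row per sample and step."""
    pattern = re.compile(rf"^{re.escape(tname)}_(q\d+|pred)_step(\d+)$")
    matches = {c: pattern.match(c) for c in df.columns}
    steps = sorted({int(m.group(2)) for m in matches.values() if m})
    frames = []
    for step in steps:
        # Select this step's columns and strip the '_step<k>' suffix.
        rename = {
            col: f"{tname}_{m.group(1)}"
            for col, m in matches.items()
            if m and int(m.group(2)) == step
        }
        part = df[list(id_cols) + list(rename)].rename(columns=rename)
        part.insert(len(id_cols), "step", step)
        frames.append(part)
    return pd.concat(frames, ignore_index=True)


wide = pd.DataFrame({
    "longitude": [10.0, 11.0],
    "latitude": [50.0, 51.0],
    "subsidence_q10_step1": [0.1, 0.2],
    "subsidence_q50_step1": [0.3, 0.4],
    "subsidence_q10_step2": [0.5, 0.6],
    "subsidence_q50_step2": [0.7, 0.8],
})
long_df = wide_to_long(wide, "subsidence", id_cols=["longitude", "latitude"])
```

Here \(n = 2\) and \(s = 2\), so `long_df` has four rows, with the step-specific columns merged into `subsidence_q10` and `subsidence_q50`.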
See also
_step_to_long_q: Converts multi-step quantile forecasts to long format.
_step_to_long_pred: Converts multi-step point forecasts to long format.
detect_digits: Extracts numeric values from strings for quantile detection.
- geoprior.models.utils._utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)[source]#
Prepares a list of input tensors for a model’s call method.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in geoprior. It handles cases where static or future inputs might be None, creating appropriate dummy tensors with zero features if the model_type is 'strict'.

- Parameters:
dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is 'strict', a dummy tensor with 0 static features will be created. Default is None.
future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is 'strict', a dummy tensor with 0 future features will be created. The time span for this dummy future tensor will be past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.
model_type ({'strict', 'flexible'}, default 'strict') – Determines how None inputs for static and future features are handled:
  - 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 will be created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.
  - 'flexible': If static_input or future_input is None, None itself will be placed in the corresponding position in the output list. This is for models that can internally handle None inputs for optional feature types.
forecast_horizon (int, optional) – The forecast horizon. Used only if model_type='strict' and future_input is None, to determine the time dimension of the dummy future tensor (as past_time_steps + forecast_horizon). If not provided in this scenario, the dummy future tensor's time dimension will match dynamic_input's past_time_steps. Default is None.
verbose (int, default 0) – Verbosity level. If > 0, prints information about dummy tensor creation.
  - 0: Silent.
  - 1: Basic info on dummy creation.
  - 2: More details on shapes.
- Returns:
A list containing three elements in the order: [processed_static_input, processed_dynamic_input, processed_future_input]. Elements can be TensorFlow tensors or None (if model_type=’flexible’ and original input was None). All returned tensors are cast to tf.float32.
- Return type:
List[Optional[tf.Tensor]]
- Raises:
ValueError – If dynamic_input is None. If dynamic_input is not at least 2D (needs batch dimension). If static_input (when provided) is not 2D. If future_input (when provided) is not 3D.
TypeError – If inputs cannot be converted to TensorFlow tensors.
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))

>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)

>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)

>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
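The strict and flexible handling described for this function can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not the library's implementation: sketch_prepare_inputs is a hypothetical name, and NumPy arrays stand in for the tf.float32 tensors the real function returns.

```python
import numpy as np


def sketch_prepare_inputs(dynamic, static=None, future=None,
                          model_type="strict", forecast_horizon=None):
    """Hypothetical NumPy stand-in for the strict/flexible input handling."""
    if dynamic is None:
        raise ValueError("dynamic input is required")
    dynamic = np.asarray(dynamic, dtype=np.float32)
    batch, past_steps = dynamic.shape[0], dynamic.shape[1]

    if static is not None:
        static = np.asarray(static, dtype=np.float32)
    elif model_type == "strict":
        # Zero-feature placeholder with shape (batch, 0).
        static = np.zeros((batch, 0), dtype=np.float32)

    if future is not None:
        future = np.asarray(future, dtype=np.float32)
    elif model_type == "strict":
        # Time span is past steps plus the horizon when one is given.
        span = past_steps + (forecast_horizon or 0)
        future = np.zeros((batch, span, 0), dtype=np.float32)

    # In flexible mode, missing inputs stay None.
    return [static, dynamic, future]


dyn = np.ones((2, 10, 4))
s, d, f = sketch_prepare_inputs(dyn, forecast_horizon=3)
print(s.shape, d.shape, f.shape)   # (2, 0) (2, 10, 4) (2, 13, 0)
s, d, f = sketch_prepare_inputs(dyn, model_type="flexible")
print(s is None, f is None)        # True True
```

A zero-width feature axis keeps the three-element list structure intact for models that always unpack [static, dynamic, future], while contributing no values.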
- geoprior.models.utils._utils.extract_batches_from_dataset(dataset, num_batches_to_extract=1, agg=False, errors='warn')[source]#
Extracts a specified number of batches from a tf.data.Dataset. Optionally aggregates the extracted batches.
- Parameters:
dataset (tf.data.Dataset) – The TensorFlow dataset to extract batches from.
num_batches_to_extract (Union[int, str], default 1) – Number of batches: int, or ‘all’, ‘*’, ‘auto’.
agg (bool, default False) – If True, attempts to aggregate the extracted batches into a single tuple structure by concatenating corresponding tensors/arrays or aggregating dictionaries.
errors (str, default 'warn') – Error handling: ‘raise’, ‘warn’, ‘ignore’.
- Returns:
If agg is False, returns a list of batch tuples. If agg is True, returns one aggregated tuple or None if no batches were extracted. When zero batches are requested or the dataset is empty, the function returns an empty list for agg=False and None for agg=True.
- Return type:
Union[List[Tuple[Any,]], Optional[Tuple[Any,]]]
- Raises:
TypeError – If dataset is not a tf.data.Dataset or num_batches_to_extract is invalid (and errors=’raise’).
ValueError – If num_batches_to_extract is negative, or fewer batches are available than requested (and errors=’raise’ and not taking all).
RuntimeError – For unexpected errors during dataset iteration (and errors=’raise’).
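The extraction and aggregation behaviour described above can be sketched on plain Python iterables. This is an illustration under stated assumptions, not the library's code: take_batches is a hypothetical helper, lists stand in for tensors, and list concatenation stands in for tensor concatenation.

```python
import itertools
import warnings


def take_batches(batches, n=1, agg=False, errors="warn"):
    """Hypothetical helper: pull `n` batches from any iterable of tuples."""
    if n in ("all", "*", "auto"):
        taken = list(batches)
    else:
        if n < 0:
            raise ValueError("n must be non-negative")
        taken = list(itertools.islice(batches, n))
        if len(taken) < n:
            msg = f"requested {n} batches, got {len(taken)}"
            if errors == "raise":
                raise ValueError(msg)
            if errors == "warn":
                warnings.warn(msg)
    if not agg:
        return taken
    if not taken:
        return None
    # Concatenate element-wise: all first elements, all second elements, ...
    return tuple(
        [row for part in parts for row in part] for parts in zip(*taken)
    )


data = [([1, 2], [10, 20]), ([3, 4], [30, 40]), ([5], [50])]
print(take_batches(data, n=2))
# [([1, 2], [10, 20]), ([3, 4], [30, 40])]
print(take_batches(data, n="all", agg=True))
# ([1, 2, 3, 4, 5], [10, 20, 30, 40, 50])
```

With agg=False the caller receives the batches unchanged, which is convenient for inspecting shapes; agg=True yields one tuple that can be passed on as a single large batch.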
- geoprior.models.utils._utils.format_predictions(predictions=None, model=None, inputs=None, y_true_sequences=None, target_name='target', quantiles=None, forecast_horizon=None, output_dim=None, spatial_data_array=None, spatial_cols=None, spatial_cols_indices=None, evaluate_coverage=False, scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, verbose=0, _logger=None, **kwargs)[source]#
Formats model predictions into a structured pandas DataFrame.
This utility function takes raw model predictions (either directly as an array/tensor or generated by a provided model and its inputs) and transforms them into a long-format pandas DataFrame. It can handle point forecasts, quantile forecasts, single or multi-output predictions, and optionally include actual target values, spatial identifiers, and perform coverage score evaluation for quantile forecasts.
The output DataFrame is structured with ‘sample_idx’ (identifying the original input sequence) and ‘forecast_step’ (from 1 to H, where H is the forecast horizon).
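The long-format layout can be pictured with a small NumPy sketch. It assumes quantile predictions of shape (num_samples, forecast_horizon, num_quantiles) with a single output, and it builds plain tuples rather than the pandas DataFrame the real function returns; all variable names are illustrative.

```python
import numpy as np

B, H = 2, 3                      # samples, forecast horizon
quantiles = [0.1, 0.5, 0.9]
target_name = "value"

# Stand-in predictions with one column per quantile: shape (B, H, Q).
preds = np.arange(B * H * len(quantiles), dtype=float).reshape(B, H, -1)

# Each sample index repeats H times; steps cycle 1..H for every sample.
sample_idx = np.repeat(np.arange(B), H)
forecast_step = np.tile(np.arange(1, H + 1), B)

# Column names follow the documented {target_name}_qXX pattern.
columns = [f"{target_name}_q{int(round(q * 100))}" for q in quantiles]

# Flatten (B, H, Q) to (B * H, Q) so each row is one sample/step pair.
flat = preds.reshape(B * H, len(quantiles))
rows = [
    (int(s), int(t), *vals)
    for s, t, vals in zip(sample_idx, forecast_step, flat.tolist())
]

print(["sample_idx", "forecast_step", *columns])
for row in rows[:H]:
    print(row)
```

Because the reshape is row-major, the first H rows belong to sample 0, the next H to sample 1, and so on, which matches the repeat/tile index construction.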
- Parameters:
predictions (np.ndarray or tf.Tensor, optional) – The raw prediction tensor or array.
- For point forecasts, expected shapes:
(num_samples, forecast_horizon, output_dim)
(num_samples, forecast_horizon) if output_dim=1 (will be reshaped)
(num_samples, output_dim) if forecast_horizon=1 (will be reshaped)
- For quantile forecasts, expected shapes:
(num_samples, forecast_horizon, num_quantiles * output_dim)
(num_samples, forecast_horizon, num_quantiles, output_dim)
(num_samples, num_quantiles * output_dim) if forecast_horizon=1
If None, model and inputs must be provided. Default is None.
model (tf.keras.Model, optional) – A trained Keras model to generate predictions if predictions is not provided. Used in conjunction with inputs. Default is None.
inputs (List[Optional[Union[np.ndarray, tf.Tensor]]], optional) – A list of input tensors (e.g., [static, dynamic, future]) required by the model to generate predictions. Required if predictions is None and model is provided. Default is None.
y_true_sequences (np.ndarray or tf.Tensor, optional) – The true target values corresponding to the predictions, used for including actuals in the output DataFrame and for evaluation. Expected shape: (num_samples, forecast_horizon, output_dim). Default is None.
target_name (str, optional) – Base name for the target variable. Used to prefix prediction and actual column names (e.g., “sales_pred”, “sales_q50”, “sales_actual”). Default is “target”.
quantiles (List[float], optional) – A list of quantiles that were predicted by the model (e.g., [0.1, 0.5, 0.9]). Required if the predictions are quantile forecasts. If provided, prediction columns will be named like {target_name}_q10, {target_name}_q50, etc. Default is None (for point forecasts).
forecast_horizon (int, optional) – The number of time steps into the future that the model predicts. If not provided, it’s inferred from predictions.shape[1] or y_true_sequences.shape[1]. Default is None.
output_dim (int, optional) – The number of target variables predicted at each time step (e.g., 1 for univariate, >1 for multivariate target). If not provided, it’s inferred from the shape of predictions or y_true_sequences. Default is None.
spatial_data_array (np.ndarray or tf.Tensor or pd.DataFrame or pd.Series, optional) – An array or DataFrame containing static spatial/identifier features for each of the num_samples sequences.
If NumPy/Tensor: Expected shape (num_samples, num_spatial_features). spatial_cols_indices must be provided.
If DataFrame/Series: Must have num_samples rows. spatial_cols must be provided.
These features will be repeated for each forecast step in the output DataFrame. Default is None.
spatial_cols (List[str], optional) – List of column names to select from spatial_data_array if it’s a DataFrame/Series, or names to assign to columns if spatial_data_array is NumPy/Tensor and spatial_cols_indices are provided. Default is None.
spatial_cols_indices (List[int], optional) – List of column indices to select from spatial_data_array if it’s a NumPy/Tensor. Length must match spatial_cols if provided. Default is None.
evaluate_coverage (bool, default False) – If True, quantiles are provided (at least two), and y_true_sequences is available, calculates the coverage score using the first and last quantiles as interval bounds. Requires geoprior.metrics.coverage_score.
scaler (Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method) used to scale the target variable and potentially other features. If provided along with scaler_feature_names and target_idx_in_scaler, predictions and actuals for the target will be inverse-transformed. Default is None.
scaler_feature_names (List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required if scaler is provided and targeted inverse transformation is needed. Default is None.
target_idx_in_scaler (int, optional) – The index of the target_name within the scaler_feature_names list. Required if scaler is provided and targeted inverse transformation is needed. Default is None.
verbose (int, default 0) – Verbosity level for logging during processing.
0: Silent.
1: Basic info.
3: More detailed steps.
5: Very detailed shape information.
**kwargs (Any) – Additional keyword arguments (currently not used but included for future extensibility).
- Returns:
A long-format DataFrame containing sample_idx and forecast_step, optional spatial columns, prediction columns, and actual-value columns when y_true_sequences is provided. Point forecasts use names like {target_name}_pred or {target_name}_{output_idx}_pred. Quantile forecasts use names like {target_name}_qXX or {target_name}_{output_idx}_qXX. Actual values use {target_name}_actual or {target_name}_{output_idx}_actual. Prediction and actual values are inverse-transformed when valid scaler information is provided.
- Return type:
pd.DataFrame
- Raises:
ValueError – If predictions is None and model or inputs is also None. If predictions shape is invalid (not 2D, 3D, or 4D). If quantiles are provided but prediction shape is incompatible for inferring output_dim. If spatial_data_array is provided without necessary name/index parameters.
TypeError – If predictions or other inputs cannot be converted to the expected tensor/array types.
See also
geoprior.models.utils.forecast_multi_step: Higher-level forecasting utility.
geoprior.metrics.coverage_score: For evaluating quantile forecast intervals.
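The coverage evaluation enabled by evaluate_coverage treats the first and last quantiles as interval bounds. A minimal pure-Python sketch of that idea follows; interval_coverage is a hypothetical function, and the inclusive bounds are an assumption rather than a statement about how geoprior.metrics.coverage_score is implemented.

```python
def interval_coverage(y_true, lower, upper):
    """Fraction of true values inside [lower, upper] (bounds inclusive)."""
    inside = [lo <= y <= hi for y, lo, hi in zip(y_true, lower, upper)]
    return sum(inside) / len(inside)


y = [1.0, 2.0, 3.0, 4.0]
q10 = [0.5, 2.5, 2.0, 3.0]   # lowest predicted quantile
q90 = [1.5, 3.5, 4.0, 3.5]   # highest predicted quantile
print(interval_coverage(y, q10, q90))  # 0.5
```

For a well-calibrated 10% to 90% interval, the coverage over many samples should approach 0.8; values far below that suggest intervals that are too narrow.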
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from geoprior.models.utils import format_predictions_to_dataframe

>>> B, H, O = 4, 3, 1 # Batch, Horizon, OutputDim
>>> Q = [0.1, 0.5, 0.9]
>>> preds_point = tf.random.normal((B, H, O))
>>> preds_quant = tf.random.normal((B, H, len(Q))) # For O=1
>>> y_true = tf.random.normal((B, H, O))

>>> # Point forecast
>>> df_point = format_predictions_to_dataframe(
...     predictions=preds_point, y_true_sequences=y_true,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> print(df_point.head(H)) # Show first sample's horizon
   sample_idx  forecast_step  value_pred  value_actual
0           0              1   -0.576731     -0.647362
1           0              2    0.183931      1.198977
2           0              3   -0.766871      0.534040

>>> # Quantile forecast
>>> df_quant = format_predictions_to_dataframe(
...     predictions=preds_quant, y_true_sequences=y_true,
...     target_name="value", quantiles=Q,
...     forecast_horizon=H, output_dim=O
... )
>>> print(df_quant.head(H))
   sample_idx  forecast_step  value_q10  value_q50  value_q90  value_actual
0           0              1  -0.209947   0.263107  -0.308929     -0.647362
1           0              2   0.303091   0.594701  -0.225007      1.198977
2           0              3   0.136699  -1.237739   0.002834      0.534040

>>> # With spatial data (NumPy array)
>>> spatial_np = np.array([[101, 201], [102, 202], [103, 203], [104, 204]])
>>> df_spatial = format_predictions_to_dataframe(
...     predictions=preds_point,
...     spatial_data_array=spatial_np,
...     spatial_cols=['store_id', 'region_id'],
...     spatial_cols_indices=[0, 1]
... )
>>> print(df_spatial[['sample_idx', 'forecast_step', 'store_id']].head(H))
   sample_idx  forecast_step  store_id
0           0              1     101.0
1           0              2     101.0
2           0              3     101.0
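The scaler, scaler_feature_names and target_idx_in_scaler parameters together allow a single target column to be inverse-transformed even though the scaler was fit on several features. One common way to do this is the placeholder-array pattern sketched below. It is an assumption that format_predictions works this way; TinyStandardScaler and inverse_target are hypothetical names used only for illustration.

```python
import numpy as np


class TinyStandardScaler:
    """Hypothetical stand-in exposing inverse_transform like scikit-learn scalers."""

    def __init__(self, mean, scale):
        self.mean_ = np.asarray(mean, dtype=float)
        self.scale_ = np.asarray(scale, dtype=float)

    def inverse_transform(self, X):
        return np.asarray(X, dtype=float) * self.scale_ + self.mean_


def inverse_target(values, scaler, n_features, target_idx):
    """Inverse-transform one column of a scaler fit on n_features columns."""
    values = np.asarray(values, dtype=float)
    # Placeholder matrix: only the target column carries real data.
    dummy = np.zeros((values.size, n_features))
    dummy[:, target_idx] = values.ravel()
    restored = scaler.inverse_transform(dummy)[:, target_idx]
    return restored.reshape(values.shape)


scaler = TinyStandardScaler(mean=[5.0, 100.0, 0.0], scale=[2.0, 10.0, 1.0])
scaled_preds = np.array([[0.0, 1.0], [-1.0, 2.0]])
print(inverse_target(scaled_preds, scaler, n_features=3, target_idx=1))
# [[100. 110.]
#  [ 90. 120.]]
```

The pattern works for scalers that transform each column independently, which holds for standard and min-max scaling.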
- geoprior.models.utils._utils.export_keras_losses(history, keys=None, savefile=None, verbose=0, formats=('json', 'csv'))[source]#
Export loss(es) (and any other metric) from a Keras History object.
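The savefile and formats handling documented for this function can be sketched with the standard library alone. The example below works on a plain dictionary shaped like history.history; export_history_sketch is a hypothetical name, and storing "epochs_run" as an integer count is an assumption.

```python
import csv
import json
import os
import tempfile


def export_history_sketch(history_dict, keys=None, savefile=None,
                          formats=("json", "csv")):
    """Hypothetical re-creation of the documented export rules."""
    if keys is None:
        # Default: every key ending with "loss".
        keys = [k for k in history_dict if k.endswith("loss")]
    result = {k: list(history_dict[k]) for k in keys}
    result["epochs_run"] = max((len(history_dict[k]) for k in keys), default=0)

    if savefile:
        base, ext = os.path.splitext(savefile)
        # An explicit extension selects one format; otherwise write all.
        targets = [ext.lstrip(".")] if ext else list(formats)
        for fmt in targets:
            path = f"{base}.{fmt}"
            if fmt == "json":
                with open(path, "w") as fh:
                    json.dump(result, fh, indent=2)
            elif fmt == "csv":
                with open(path, "w", newline="") as fh:
                    writer = csv.writer(fh)
                    writer.writerow(["epoch", *keys])
                    columns = zip(*(result[k] for k in keys))
                    for epoch, row in enumerate(columns, start=1):
                        writer.writerow([epoch, *row])
    return result


history = {"loss": [0.9, 0.5], "val_loss": [1.0, 0.7], "mae": [0.4, 0.3]}
with tempfile.TemporaryDirectory() as tmp:
    out = export_history_sketch(history, savefile=os.path.join(tmp, "losses"))
    written = sorted(os.listdir(tmp))
print(out)      # {'loss': [0.9, 0.5], 'val_loss': [1.0, 0.7], 'epochs_run': 2}
print(written)  # ['losses.csv', 'losses.json']
```

Passing savefile with a .json or .csv extension would write only that one file, mirroring the rule described in the parameter list.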
- Parameters:
history (History) – The History object returned by model.fit().
keys (list[str] | None) – List of history.history keys to export, e.g. [“loss”, “val_loss”, “physics_loss”]. If None, defaults to all keys ending with “loss”.
savefile (str | None) – Path (optionally with extension) where to write the output. If the extension is .json, only JSON is written. If the extension is .csv, only CSV is written. If no extension is given, all formats in formats are written using savefile as the base name.
verbose (int) – If >0, prints status messages.
formats (tuple[str, ...]) – File formats to write when savefile has no extension. Valid entries are “json” and “csv”.
- Returns:
result – Dictionary containing "epochs_run" and one list entry per exported history key.
- Return type:
dict
- geoprior.models.utils._utils.get_tensor_from(inputs, *tensor_names, default=None, check_type=True, auto_convert=True)[source]#
Safely retrieves the first available tensor from a dictionary using a list of possible keys.
This utility is crucial for handling model inputs within a TensorFlow graph (e.g., in train_step). It avoids the ambiguous boolean evaluation of Tensors (e.g., tensor_a or tensor_b), which causes runtime errors, by explicitly testing each candidate with is not None.
- Parameters:
inputs (dict) – The dictionary to search, typically the model’s input dictionary (e.g., the inputs provided to call or train_step).
*tensor_names (str) – One or more string keys to check for in the inputs dictionary, in order of priority.
default (Any, optional) – A default value to return if no keys are found or if no found value is a valid tensor. Defaults to None.
check_type (bool, default True) – If True, only returns a value if it is (or can be converted to) a Tensor or Variable. If False, returns the first non-None value regardless of its type.
auto_convert (bool, default True) – If True and check_type is True, this function will attempt to convert a found non-Tensor value (like a NumPy array or a list) into a TensorFlow tensor using tf.convert_to_tensor.
- Returns:
The first found tf.Tensor or tf.Variable associated with one of the tensor_names. If auto_convert is True, this can also be a newly converted tensor. Returns default (typically None) if no valid tensor is found.
- Return type:
Optional[tf.Tensor]
- Raises:
TypeError – If inputs is not a dictionary.
Examples
>>> import numpy as np
>>> import tensorflow as tf
>>> from geoprior.models.utils._utils import get_tensor_from
>>> inputs_dict = {
...     'some_other_key': [1, 2, 3],
...     'soil_thickness': tf.constant([20., 21.], dtype=tf.float32)
... }
>>>
>>> # Correctly finds 'soil_thickness'
>>> get_tensor_from(inputs_dict, 'H_field', 'soil_thickness')
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([20., 21.], ...)>
>>>
>>> # Returns None safely if nothing is found
>>> print(get_tensor_from(inputs_dict, 'missing_key', 'another_key'))
None
>>>
>>> # Demonstrating auto_convert
>>> inputs_dict_np = {'H_field': np.array([10., 11.], dtype=np.float32)}
>>> get_tensor_from(inputs_dict_np, 'H_field', auto_convert=True)
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([10., 11.], ...)>
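The reason for checking is not None instead of chaining with or can be shown with NumPy arrays, which raise on truth-value evaluation much as multi-element tensors do. This is a simplified sketch: first_present is a hypothetical helper that omits the check_type and auto_convert behaviour of get_tensor_from.

```python
import numpy as np


def first_present(inputs, *names, default=None):
    """Return the first non-None value found under any of `names`."""
    if not isinstance(inputs, dict):
        raise TypeError("inputs must be a dict")
    for name in names:
        value = inputs.get(name)
        # Identity check: never asks the array for a truth value.
        if value is not None:
            return value
    return default


inputs = {"H_field": None, "soil_thickness": np.array([20.0, 21.0])}

# `a or b` calls bool() on the array, which is ambiguous for 2+ elements.
try:
    picked = inputs["soil_thickness"] or inputs["H_field"]
    ambiguous = False
except ValueError:
    ambiguous = True
print(ambiguous)  # True

print(first_present(inputs, "H_field", "soil_thickness"))  # [20. 21.]
print(first_present(inputs, "missing_key", default="n/a"))  # n/a
```

Listing the keys in priority order lets a model accept several aliases for the same input without branching on tensor truthiness.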