geoprior.utils.split
Group-holdout split for sequence data.
- Exports:
  - train_windows_T{T}_H{H}.npz
  - val_windows_T{T}_H{H}.npz
  - test_windows_T{T}_H{H}.npz
  - future_inputs_T{T}_H{H}.npz
  - splits_groups.json
- Leakage fix (Zhongshan, 2 windows/pixel): split by group_id first, then build windows inside each split, so that windows from the same group never land in different splits.
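The split-then-window order can be illustrated with a minimal, standard-library sketch. The helpers `split_groups` and `make_windows`, the toy data, and the rounding rule for split sizes are all assumptions made for illustration; they are not the module's implementation.

```python
import random


def split_groups(group_ids, ratios=(0.7, 0.15, 0.15), seed=42):
    """Partition unique group ids into train/val/test (hypothetical helper)."""
    keys = sorted(set(group_ids))
    random.Random(seed).shuffle(keys)  # seeded, so the split is reproducible
    n_train = int(round(ratios[0] * len(keys)))
    n_val = int(round(ratios[1] * len(keys)))
    return {
        "train": keys[:n_train],
        "val": keys[n_train:n_train + n_val],
        "test": keys[n_train + n_val:],
    }


def make_windows(series, time_steps, horizon):
    """Slide a (time_steps + horizon) window over one group's series."""
    span = time_steps + horizon
    return [
        (series[i:i + time_steps], series[i + time_steps:i + span])
        for i in range(len(series) - span + 1)
    ]


# Toy data: 20 groups (pixels) with 6 time steps each.
# With T=3 and H=2 every group yields 2 windows.
data = {g: [g * 10 + t for t in range(6)] for g in range(20)}
T, H = 3, 2

# Step 1: split the groups. Step 2: build windows inside each split only.
splits = split_groups(data.keys())
windows = {
    name: [(g, x, y) for g in keys for x, y in make_windows(data[g], T, H)]
    for name, keys in splits.items()
}
file_names = {name: f"{name}_windows_T{T}_H{H}.npz" for name in windows}
```

Because the groups are partitioned before any window is cut, both windows of a given pixel always end up in the same split. Windowing first and splitting the windows afterwards could place them on opposite sides of the train/test boundary.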
Functions
build_group_holdout_npzs: Build train/val/test windows using group holdout.
Classes
- class geoprior.utils.split.SplitCfg(seed: int = 42, ratios: tuple[float, float, float] = (0.7, 0.15, 0.15), decimals: int = 8)
Bases: object
- geoprior.utils.split.split_group_keys(keys, *, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8))
- geoprior.utils.split.write_splits_json(path, *, group_cols, time_steps, horizon, train_end, cfg, splits)
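The on-disk layout of splits_groups.json is not documented here. As a rough sketch only, the metadata named in the signature could be serialized as shown below. The function `write_splits_json_sketch`, its JSON field layout, and the flattening of `cfg` into `seed` and `ratios` are hypothetical and may differ from what `write_splits_json` actually writes.

```python
import json
import os
import tempfile


def write_splits_json_sketch(path, *, group_cols, time_steps, horizon,
                             train_end, seed, ratios, splits):
    """Write split metadata to JSON (hypothetical layout, not the library's)."""
    payload = {
        "group_cols": list(group_cols),
        "time_steps": int(time_steps),
        "horizon": int(horizon),
        "train_end": train_end,
        "cfg": {"seed": int(seed), "ratios": list(ratios)},
        # Store each split's group keys as plain lists so they serialize cleanly.
        "splits": {name: [list(k) for k in keys] for name, keys in splits.items()},
    }
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)
    return path


# Toy group keys: (x, y) coordinate pairs, one per pixel.
toy_splits = {
    "train": [(113.2, 22.4), (113.3, 22.5)],
    "val": [(113.4, 22.6)],
    "test": [(113.5, 22.7)],
}

with tempfile.TemporaryDirectory() as tmp:
    out_path = write_splits_json_sketch(
        os.path.join(tmp, "splits_groups.json"),
        group_cols=["x", "y"], time_steps=3, horizon=2, train_end=2022.0,
        seed=42, ratios=(0.7, 0.15, 0.15), splits=toy_splits,
    )
    with open(out_path, encoding="utf-8") as fh:
        loaded = json.load(fh)
```

Recording the seed, ratios, and group keys next to the window files makes it possible to check later which groups were held out for a given run.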
- geoprior.utils.split.build_group_holdout_npzs(*, df_train, artifacts_dir, group_cols, time_col_used, x_col_used, y_col_used, subs_col, gwl_target_col, gwl_dyn_col, h_field_col, static_cols, dynamic_cols, future_cols, time_steps, horizon, mode, model_name, train_end, keys_ok, cfg=SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15), decimals=8), normalize_coords=True)
Build train/val/test windows using group holdout.
Returns a dict containing the output paths and coord_scaler.
- Parameters:
df_train (DataFrame)
artifacts_dir (str)
time_col_used (str)
x_col_used (str)
y_col_used (str)
subs_col (str)
gwl_target_col (str)
gwl_dyn_col (str)
h_field_col (str)
time_steps (int)
horizon (int)
mode (str)
model_name (str)
train_end (float | None)
keys_ok (ndarray)
cfg (SplitCfg)
normalize_coords (bool)
- Return type: dict
- geoprior.utils.split.build_future_inputs_npz(*, df_scaled, artifacts_dir, time_col, time_col_num, lon_col, lat_col, subs_col, gwl_col, h_field_col, static_features, dynamic_features, future_features, group_cols, train_end_time, forecast_start_time, horizon, time_steps, mode, model_name, normalize_coords, coord_scaler=None)
- Parameters:
df_scaled (DataFrame)
artifacts_dir (str)
time_col (str)
time_col_num (str | None)
lon_col (str)
lat_col (str)
subs_col (str)
gwl_col (str)
h_field_col (str)
train_end_time (Any)
forecast_start_time (Any)
horizon (int)
time_steps (int)
mode (str)
model_name (str)
normalize_coords (bool)
coord_scaler (Any)
- Return type:
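The signatures suggest that the coord_scaler returned by build_group_holdout_npzs can be passed on to build_future_inputs_npz, so that future inputs are normalized the same way as the training windows. The sketch below illustrates that idea with a hand-written min-max scaler. The class `MinMaxCoordScaler`, the (t, x, y) column order, and the choice of min-max scaling are assumptions; the type of the real coord_scaler is not stated in this page.

```python
class MinMaxCoordScaler:
    """Per-column min-max scaler for (t, x, y) rows (hypothetical stand-in)."""

    def fit(self, coords):
        cols = list(zip(*coords))
        self.mins = [min(c) for c in cols]
        self.maxs = [max(c) for c in cols]
        return self

    def transform(self, coords):
        out = []
        for row in coords:
            out.append(tuple(
                # Guard against a constant column to avoid division by zero.
                (v - lo) / (hi - lo) if hi > lo else 0.0
                for v, lo, hi in zip(row, self.mins, self.maxs)
            ))
        return out


# Fit on training coordinates only: (t, x, y) per row.
train_coords = [(2015.0, 113.2, 22.4), (2020.0, 113.6, 22.8)]
scaler = MinMaxCoordScaler().fit(train_coords)

# Reuse the SAME fitted scaler for future inputs instead of refitting,
# so both sets share one coordinate frame.
future_coords = [(2022.5, 113.4, 22.6)]
future_scaled = scaler.transform(future_coords)
```

With a scaler fitted on training data, future time values can fall outside the [0, 1] range. That is expected when forecasting beyond the training period.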