# License: Apache-2.0
# Copyright (c) 2026-present
# Author: LKouadio <etanoyau@gmail.com>
"""
geoprior.utils.split
Group-holdout split for sequence data.
Exports:
train_windows_T{T}_H{H}.npz
val_windows_T{T}_H{H}.npz
test_windows_T{T}_H{H}.npz
future_inputs_T{T}_H{H}.npz
splits_groups.json
Leakage fix (Zhongshan 2 windows/pixel):
split by group_id first, then window inside split.
"""
from __future__ import annotations
import json
import os
import shutil
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import numpy as np
import pandas as pd
from .panel_cache import (
make_group_keys,
min_required_len,
)
from .sequence_utils import (
build_future_sequences_npz,
)
__all__ = [
"SplitCfg",
"split_group_keys",
"subset_by_keys",
"write_splits_json",
"pack_xy_npz",
"build_group_holdout_npzs",
"build_future_inputs_npz",
]
[docs]
@dataclass(frozen=True)
class SplitCfg:
seed: int = 42
ratios: tuple[float, float, float] = (0.7, 0.15, 0.15)
decimals: int = 8
def _check_ratios(r: tuple[float, float, float]) -> None:
if len(r) != 3:
raise ValueError("ratios must be (train,val,test).")
s = float(r[0]) + float(r[1]) + float(r[2])
if not np.isclose(s, 1.0, atol=1e-8):
raise ValueError("ratios must sum to 1.0.")
_SPLITCFG = SplitCfg()
[docs]
def split_group_keys(
keys: np.ndarray,
*,
cfg: SplitCfg = _SPLITCFG,
) -> dict[str, np.ndarray]:
_check_ratios(cfg.ratios)
keys = np.asarray(keys, dtype=object)
if keys.size < 3:
raise ValueError("Need >= 3 groups for splits.")
rng = np.random.default_rng(int(cfg.seed))
keys = keys[rng.permutation(keys.size)]
n = int(keys.size)
n_tr = int(np.floor(cfg.ratios[0] * n))
n_va = int(np.floor(cfg.ratios[1] * n))
n_tr = max(1, min(n_tr, n - 2))
n_va = max(1, min(n_va, n - n_tr - 1))
tr = keys[:n_tr]
va = keys[n_tr : n_tr + n_va]
te = keys[n_tr + n_va :]
return {"train": tr, "val": va, "test": te}
[docs]
def subset_by_keys(
df: pd.DataFrame,
*,
group_cols: Sequence[str],
keys: np.ndarray,
decimals: int = 8,
) -> pd.DataFrame:
if keys is None or len(keys) == 0:
return df.iloc[0:0].copy()
k = make_group_keys(
df,
group_cols=group_cols,
decimals=decimals,
)
keep = k.isin(set(map(str, keys)))
return df.loc[keep].copy()
[docs]
def write_splits_json(
path: str,
*,
group_cols: Sequence[str],
time_steps: int,
horizon: int,
train_end: float | None,
cfg: SplitCfg,
splits: dict[str, np.ndarray],
) -> str:
os.makedirs(os.path.dirname(path), exist_ok=True)
payload: dict[str, Any] = {
"group_cols": list(group_cols),
"decimals": int(cfg.decimals),
"seed": int(cfg.seed),
"ratios": list(map(float, cfg.ratios)),
"time_steps": int(time_steps),
"horizon": int(horizon),
"min_len": int(min_required_len(time_steps, horizon)),
"train_end": train_end,
"counts": {k: int(len(v)) for k, v in splits.items()},
"splits": {
k: list(map(str, v)) for k, v in splits.items()
},
}
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
return path
def _pick_target_key(
d: dict[str, Any],
*names: str,
) -> str:
for n in names:
if n in d:
return n
raise KeyError(f"Missing target key in {list(d)}")
[docs]
def pack_xy_npz(
x: dict[str, Any],
y: dict[str, Any] | None,
) -> dict[str, np.ndarray]:
out: dict[str, np.ndarray] = {
"coords": np.asarray(x["coords"], np.float32),
"dynamic_features": np.asarray(
x["dynamic_features"],
np.float32,
),
"static_features": np.asarray(
x.get("static_features", np.zeros((0, 0))),
np.float32,
),
"future_features": np.asarray(
x.get("future_features", np.zeros((0, 0, 0))),
np.float32,
),
"H_field": np.asarray(x["H_field"], np.float32),
}
if y is None:
return out
s_key = _pick_target_key(y, "subs_pred", "subsidence")
g_key = _pick_target_key(y, "gwl_pred", "gwl")
out["subs_pred"] = np.asarray(y[s_key], np.float32)
out["gwl_pred"] = np.asarray(y[g_key], np.float32)
return out
def _save_npz(
path: str, arrays: dict[str, np.ndarray]
) -> str:
os.makedirs(os.path.dirname(path), exist_ok=True)
np.savez_compressed(path, **arrays)
return path
[docs]
def build_group_holdout_npzs(
*,
df_train: pd.DataFrame,
artifacts_dir: str,
group_cols: Sequence[str],
time_col_used: str,
x_col_used: str,
y_col_used: str,
subs_col: str,
gwl_target_col: str,
gwl_dyn_col: str,
h_field_col: str,
static_cols: Sequence[str],
dynamic_cols: Sequence[str],
future_cols: Sequence[str],
time_steps: int,
horizon: int,
mode: str,
model_name: str,
train_end: float | None,
keys_ok: np.ndarray,
cfg: SplitCfg = _SPLITCFG,
normalize_coords: bool = True,
) -> dict[str, Any]:
"""
Build train/val/test windows using group holdout.
Returns dict containing paths and coord_scaler.
"""
from geoprior.nn.pinn.utils import (
prepare_pinn_data_sequences,
)
splits = split_group_keys(keys_ok, cfg=cfg)
spath = os.path.join(artifacts_dir, "splits_groups.json")
write_splits_json(
spath,
group_cols=group_cols,
time_steps=time_steps,
horizon=horizon,
train_end=train_end,
cfg=cfg,
splits=splits,
)
coord_scaler = None
out: dict[str, Any] = {"splits_groups_json": spath}
def _one(
which: str,
) -> tuple[dict[str, Any], dict[str, Any], Any]:
dfi = subset_by_keys(
df_train,
group_cols=group_cols,
keys=splits[which],
decimals=cfg.decimals,
)
if dfi.empty:
raise RuntimeError(f"{which} split is empty.")
fit_cs = bool(normalize_coords) and (which == "train")
ins, tar, cs = prepare_pinn_data_sequences(
df=dfi,
time_col=time_col_used,
lon_col=x_col_used,
lat_col=y_col_used,
subsidence_col=subs_col,
gwl_col=gwl_target_col,
gwl_dyn_col=gwl_dyn_col,
h_field_col=h_field_col,
dynamic_cols=list(dynamic_cols),
static_cols=list(static_cols),
future_cols=list(future_cols),
group_id_cols=list(group_cols),
time_steps=int(time_steps),
forecast_horizon=int(horizon),
output_subsidence_dim=1,
output_gwl_dim=1,
normalize_coords=bool(normalize_coords),
coord_scaler=coord_scaler,
fit_coord_scaler=bool(fit_cs),
return_coord_scaler=True,
mode=str(mode),
model=str(model_name),
verbose=1,
)
return ins, tar, cs
for which in ("train", "val", "test"):
ins, tar, cs = _one(which)
if which == "train":
coord_scaler = cs
arrays = pack_xy_npz(ins, tar)
fn = f"{which}_windows_T{time_steps}_H{horizon}.npz"
p = os.path.join(artifacts_dir, fn)
out[f"{which}_windows_npz"] = _save_npz(p, arrays)
out["coord_scaler"] = coord_scaler
return out
# from geoprior.utils.panel_cache import (
# ensure_feasible_keys_cache,
# )
# df_min = pd.read_csv(
# used_path,
# usecols=[LON_COL, LAT_COL, TIME_COL],
# )
# paths = ensure_feasible_keys_cache(
# df_min,
# out_dir=ARTIFACTS_DIR,
# city=CITY_NAME,
# group_cols=[LON_COL, LAT_COL],
# time_col=TIME_COL,
# train_end=TRAIN_END_YEAR,
# time_steps=TIME_STEPS,
# horizon=FORECAST_HORIZON_YEARS,
# )
# keys_ok = np.load(paths.keys_npy, allow_pickle=True)
# from geoprior.utils.split import subset_by_keys
# df_clean = subset_by_keys(
# df_clean,
# group_cols=[LON_COL, LAT_COL],
# keys=keys_ok,
# decimals=8,
# )
# from geoprior.utils.split import (
# SplitCfg,
# build_group_holdout_npzs,
# build_future_inputs_npz,
# )
# cfg_split = SplitCfg(seed=42, ratios=(0.7, 0.15, 0.15))
# res = build_group_holdout_npzs(
# df_train=df_train,
# artifacts_dir=ARTIFACTS_DIR,
# group_cols=[LON_COL, LAT_COL],
# time_col_used=TIME_COL_USED,
# x_col_used=X_COL_USED,
# y_col_used=Y_COL_USED,
# subs_col=SUBS_MODEL_COL,
# gwl_target_col=GWL_TARGET_COL,
# gwl_dyn_col=GWL_DYN_COL,
# h_field_col=H_FIELD_COL,
# static_cols=static_features,
# dynamic_cols=dynamic_features,
# future_cols=future_features,
# time_steps=TIME_STEPS,
# horizon=FORECAST_HORIZON_YEARS,
# mode=MODE,
# model_name=MODEL_NAME,
# train_end=TRAIN_END_YEAR,
# keys_ok=keys_ok,
# cfg=cfg_split,
# normalize_coords=normalize_coords,
# )
# future_npz = build_future_inputs_npz(
# df_scaled=df_scaled,
# artifacts_dir=ARTIFACTS_DIR,
# time_col=TIME_COL,
# time_col_num=TIME_COL_NUM,
# lon_col=COORD_X_COL,
# lat_col=COORD_Y_COL,
# subs_col=SUBS_MODEL_COL,
# gwl_col=GWL_TARGET_COL,
# h_field_col=H_FIELD_COL,
# static_features=static_features,
# dynamic_features=dynamic_features,
# future_features=future_features,
# group_cols=[LON_COL, LAT_COL],
# train_end_time=TRAIN_END_YEAR,
# forecast_start_time=FORECAST_START_YEAR,
# horizon=FORECAST_HORIZON_YEARS,
# time_steps=TIME_STEPS,
# mode=MODE,
# model_name=MODEL_NAME,
# normalize_coords=normalize_coords,
# coord_scaler=res["coord_scaler"],
# )