Source code for geoprior.utils.holdout_utils

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Utility helpers for holdout and split workflows."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import numpy as np
import pandas as pd

HoldoutStrategy = Literal["random", "spatial_block"]


[docs] @dataclass(frozen=True) class GroupMasks: """Group-level validity masks for early filtering.""" required_train_years: list[int] required_forecast_years: list[int] valid_for_train: pd.DataFrame valid_for_forecast: pd.DataFrame @property def keep_for_processing(self) -> pd.DataFrame: """Union(valid_for_train, valid_for_forecast).""" if self.valid_for_train.empty: return self.valid_for_forecast.copy() if self.valid_for_forecast.empty: return self.valid_for_train.copy() a = self.valid_for_train.copy() b = self.valid_for_forecast.copy() out = pd.concat([a, b], axis=0, ignore_index=True) out = out.drop_duplicates() return out
def _year_span(end_year: int, n_years: int) -> list[int]: end_year = int(end_year) n_years = int(n_years) start = end_year - n_years + 1 return list(range(start, end_year + 1)) def _groups_with_all_years( df: pd.DataFrame, *, group_cols: list[str], time_col: str, required_years: list[int], ) -> pd.DataFrame: """ Return unique group rows (group_cols) that contain ALL required_years. Efficient approach: - drop duplicate (group,time) - filter to required years - count unique years per group """ if not required_years: return df[group_cols].drop_duplicates().copy() d0 = df[group_cols + [time_col]].drop_duplicates() d0 = d0[d0[time_col].isin(required_years)] if d0.empty: return d0[group_cols].drop_duplicates().copy() need = len(required_years) cnt = d0.groupby(group_cols, sort=False)[ time_col ].nunique() ok = cnt[cnt == need].index if isinstance(ok, pd.MultiIndex): out = ok.to_frame(index=False) out.columns = group_cols return out.reset_index(drop=True) # single group col return pd.DataFrame({group_cols[0]: ok}).reset_index( drop=True )
[docs] def compute_group_masks( df: pd.DataFrame, *, group_cols: list[str], time_col: str, train_end_year: int, time_steps: int, horizon: int, ) -> GroupMasks: """ Build: - valid_for_train: groups containing all years for last (T+H) - valid_for_forecast: groups containing all years for last T This assumes annual steps and integer years in time_col. """ req_train = _year_span( train_end_year, time_steps + horizon ) req_fcst = _year_span(train_end_year, time_steps) v_train = _groups_with_all_years( df, group_cols=group_cols, time_col=time_col, required_years=req_train, ) v_fcst = _groups_with_all_years( df, group_cols=group_cols, time_col=time_col, required_years=req_fcst, ) return GroupMasks( required_train_years=req_train, required_forecast_years=req_fcst, valid_for_train=v_train, valid_for_forecast=v_fcst, )
def _hash_groups( df: pd.DataFrame, group_cols: list[str] ) -> pd.Series: """ Stable-ish 64-bit hash per row for group membership filtering. Collisions are theoretically possible but extremely unlikely. """ return pd.util.hash_pandas_object( df[group_cols], index=False, ).astype("uint64")
[docs] def filter_df_by_groups( df: pd.DataFrame, *, group_cols: list[str], groups: pd.DataFrame, ) -> pd.DataFrame: """ Keep only rows in df whose (group_cols) exist in groups DataFrame. """ if groups is None or groups.empty: return df.iloc[0:0].copy() h_df = _hash_groups(df, group_cols) h_g = _hash_groups(groups.drop_duplicates(), group_cols) keep = h_df.isin(h_g.to_numpy()) return df.loc[keep].copy()
[docs] @dataclass(frozen=True) class HoldoutSplit: """Pixel holdout split (disjoint groups).""" train_groups: pd.DataFrame val_groups: pd.DataFrame test_groups: pd.DataFrame
[docs] def check_disjoint(self) -> None: g = ["__h__"] a = self.train_groups.copy() b = self.val_groups.copy() c = self.test_groups.copy() a[g[0]] = _hash_groups(a, list(a.columns)) b[g[0]] = _hash_groups(b, list(b.columns)) c[g[0]] = _hash_groups(c, list(c.columns)) sa = set(a[g[0]].to_numpy()) sb = set(b[g[0]].to_numpy()) sc = set(c[g[0]].to_numpy()) if sa & sb or sa & sc or sb & sc: raise RuntimeError( "Holdout groups are not disjoint." )
[docs] def split_groups_holdout( groups: pd.DataFrame, *, seed: int = 42, val_frac: float = 0.2, test_frac: float = 0.1, strategy: HoldoutStrategy = "random", x_col: str | None = None, y_col: str | None = None, block_size: float | None = None, ) -> HoldoutSplit: """ Split unique groups into train/val/test (pixel-level holdout). strategy: - "random": shuffle groups - "spatial_block": shuffle spatial blocks (needs x_col,y_col,block) """ if groups is None or groups.empty: z = groups.iloc[0:0].copy() return HoldoutSplit(z, z, z) if ( val_frac < 0 or test_frac < 0 or (val_frac + test_frac) >= 1.0 ): raise ValueError("Bad val/test fractions.") g = groups.drop_duplicates().reset_index(drop=True) rng = np.random.default_rng(int(seed)) if strategy == "spatial_block": if x_col is None or y_col is None: raise ValueError( "spatial_block needs x_col and y_col." ) if block_size is None or float(block_size) <= 0: raise ValueError( "spatial_block needs block_size > 0." ) bx = np.floor( g[x_col].to_numpy(float) / float(block_size) ) by = np.floor( g[y_col].to_numpy(float) / float(block_size) ) blocks = pd.DataFrame( {"bx": bx.astype(int), "by": by.astype(int)} ) b_id = pd.util.hash_pandas_object(blocks, index=False) uniq = pd.DataFrame({"b": b_id}).drop_duplicates() perm = rng.permutation(len(uniq)) n = len(uniq) n_test = max(1, int(round(test_frac * n))) n_val = max(1, int(round(val_frac * n))) n_train = max(1, n - n_val - n_test) b_train = set( uniq.iloc[perm[:n_train]]["b"].to_numpy() ) b_val = set( uniq.iloc[perm[n_train : n_train + n_val]][ "b" ].to_numpy() ) b_test = set( uniq.iloc[perm[n_train + n_val :]]["b"].to_numpy() ) in_train = b_id.isin(b_train).to_numpy() in_val = b_id.isin(b_val).to_numpy() in_test = b_id.isin(b_test).to_numpy() tr = g.loc[in_train].copy() va = g.loc[in_val].copy() te = g.loc[in_test].copy() else: perm = rng.permutation(len(g)) n = len(g) n_test = max(1, int(round(test_frac * n))) n_val = max(1, int(round(val_frac * n))) n_train = max(1, n - n_val - n_test) tr = g.iloc[perm[:n_train]].copy() va = g.iloc[perm[n_train : n_train + n_val]].copy() te = g.iloc[perm[n_train + n_val :]].copy() # hard safety: ensure disjoint (by hash) out = HoldoutSplit(tr, va, te) out.check_disjoint() return out