# Source code for geoprior.scripts.config

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/
# geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Shared configuration helpers for GeoPrior scripts."""

from __future__ import annotations

import os
from pathlib import Path

# ---------------------------------------------------------
# Artifact export environment
# ---------------------------------------------------------
# Names of the environment variables that override the
# default export locations. An unset or empty variable
# falls back to the built-in default.
# Read by fig_dir(): figure export directory.
FIG_DIR_ENV = "GEOPRIOR_FIG_DIR"
# Read by out_dir(): tabular export directory.
OUT_DIR_ENV = "GEOPRIOR_OUT_DIR"
# Read by scripts_root(): user-facing artifact root.
ROOT_ENV = "GEOPRIOR_ARTIFACT_ROOT"

# ---------------------------------------------------------
# Filesystem helpers
# ---------------------------------------------------------


def _as_path(value: str | os.PathLike[str]) -> Path:
    """Expand ``~`` in *value* and return the resolved absolute path."""
    expanded = Path(value).expanduser()
    return expanded.resolve()


def _find_project_root(start: Path | None = None) -> Path:
    """Locate the closest enclosing project root.

    Walk upward from *start* (the current working directory
    when omitted) and return the first directory that holds
    a ``pyproject.toml`` or a ``.git`` entry. When no marker
    is found, the resolved starting directory is returned.
    """
    markers = ("pyproject.toml", ".git")
    origin = (start or Path.cwd()).resolve()

    candidates = [origin, *origin.parents]
    for candidate in candidates:
        if any((candidate / name).exists() for name in markers):
            return candidate

    return origin


def scripts_root(*, create: bool = True) -> Path:
    """Return the user-facing artifact root.

    A non-empty value of the environment variable named by
    ``ROOT_ENV`` takes precedence; otherwise the root is the
    ``scripts`` folder under the detected project root.

    Parameters
    ----------
    create : bool, default=True
        Create the directory (and missing parents) before
        returning it.
    """
    override = os.environ.get(ROOT_ENV)
    base = (
        _as_path(override)
        if override
        else _find_project_root() / "scripts"
    )
    if create:
        base.mkdir(parents=True, exist_ok=True)
    return base
def fig_dir(*, create: bool = True) -> Path:
    """Return the figure export directory.

    A non-empty value of the environment variable named by
    ``FIG_DIR_ENV`` takes precedence; otherwise the directory
    is ``figs`` under :func:`scripts_root`.

    Parameters
    ----------
    create : bool, default=True
        Create the directory (and missing parents) before
        returning it. The flag is also forwarded to
        :func:`scripts_root` when the default is used.
    """
    override = os.environ.get(FIG_DIR_ENV)
    if override:
        target = _as_path(override)
    else:
        target = scripts_root(create=create) / "figs"
    if create:
        target.mkdir(parents=True, exist_ok=True)
    return target
def out_dir(*, create: bool = True) -> Path:
    """Return the tabular export directory.

    A non-empty value of the environment variable named by
    ``OUT_DIR_ENV`` takes precedence; otherwise the directory
    is ``out`` under :func:`scripts_root`.

    Parameters
    ----------
    create : bool, default=True
        Create the directory (and missing parents) before
        returning it. The flag is also forwarded to
        :func:`scripts_root` when the default is used.
    """
    configured = os.environ.get(OUT_DIR_ENV)
    if not configured:
        destination = scripts_root(create=create) / "out"
    else:
        destination = _as_path(configured)
    if create:
        destination.mkdir(parents=True, exist_ok=True)
    return destination
def _has_explicit_parent(path: Path) -> bool:
    """Tell whether *path* names a folder in addition to a file name.

    The check is purely lexical: the path's parent is compared
    with the bare ``"."`` path.
    """
    bare_parent = Path(".")
    parent = path.parent
    return parent != bare_parent
def resolve_user_artifact_path(
    value: str | os.PathLike[str],
    *,
    kind: str,
    create_parent: bool = False,
) -> Path:
    """
    Resolve a user-provided export path.

    Rules
    -----
    - absolute path: use as given
    - relative path with parent: respect that folder
      (resolved against the current working directory)
    - bare filename/stem: place under the figure or tabular
      export directory (``fig_dir()`` / ``out_dir()``)

    Parameters
    ----------
    value : str or os.PathLike
        Path supplied by the user; a leading ``~`` is expanded.
    kind : str
        Export family, matched case-insensitively after
        stripping whitespace. ``'fig'``, ``'figs'``,
        ``'figure'`` and ``'figures'`` select the figure
        directory; ``'out'`` selects the tabular directory.
        Only consulted for bare filenames.
    create_parent : bool, default=False
        Create the parent directory of the resolved path.

    Returns
    -------
    pathlib.Path
        The resolved export path.

    Raises
    ------
    ValueError
        If ``value`` is a bare filename and ``kind`` is not
        one of the accepted names.
    """
    path = Path(value).expanduser()
    kind_norm = kind.strip().lower()

    if path.is_absolute():
        out = path
    elif _has_explicit_parent(path):
        out = path.resolve()
    elif kind_norm in {"fig", "figs", "figure", "figures"}:
        # The export directories are not created here; only
        # the final parent is, and only on request below.
        out = fig_dir(create=False) / path
    elif kind_norm == "out":
        out = out_dir(create=False) / path
    else:
        # List every accepted spelling and echo the bad value
        # so the caller can see what was rejected.
        raise ValueError(
            "kind must be one of 'fig', 'figs', 'figure', "
            f"'figures', or 'out'; got {kind!r}"
        )

    if create_parent:
        out.parent.mkdir(parents=True, exist_ok=True)
    return out
def resolve_export_path(
    kind: str,
    *parts: str | os.PathLike[str],
    create_parent: bool = True,
) -> Path:
    """Join *parts* and resolve the result as an export path.

    The joined path is handed to
    :func:`resolve_user_artifact_path`, so the same resolution
    rules apply. Unlike that function, the parent directory is
    created by default.
    """
    relative = Path(*(str(part) for part in parts))
    return resolve_user_artifact_path(
        relative,
        kind=kind,
        create_parent=create_parent,
    )
def export_env_summary() -> dict[str, str]:
    """Map each export environment variable name to its resolved path.

    Directories are looked up with ``create=False``, so building
    the summary does not create anything on disk.
    """
    resolvers = (
        (ROOT_ENV, scripts_root),
        (FIG_DIR_ENV, fig_dir),
        (OUT_DIR_ENV, out_dir),
    )
    return {
        env_name: str(resolver(create=False))
        for env_name, resolver in resolvers
    }
# ---------------------------------------------------------
# Filesystem layout (v3.2)
# ---------------------------------------------------------
# Directory that contains this module.
MODULE_DIR = Path(__file__).resolve().parent
SCRIPTS_DIR = MODULE_DIR

# NOTE: these calls use the default ``create=True``, so the
# export directories are created when the module is
# imported.
ARTIFACT_ROOT = scripts_root()
FIG_DIR = fig_dir()
FIGS_DIR = FIG_DIR  # alias of FIG_DIR
OUT_DIR = out_dir()

# ---------------------------------------------------------
# Common discovery patterns
# (src may be a directory or a file)
# ---------------------------------------------------------
# NOTE:
# - We search recursively under src when src is a
#   directory.
# - For GeoPrior eval JSON we prefer the
#   interpretable one when available.
# - It may carry human units such as "mm".
# Each key maps to a tuple of glob patterns.
PATTERNS = {
    # GeoPrior evaluation JSON
    "phys_json": (
        "geoprior_eval_phys_*_interpretable.json",
        "geoprior_eval_phys_*.json",
    ),
    # Eval diagnostics fallback JSON
    "eval_diag_json": (
        "*GeoPriorSubsNet_eval_diagnostics*_calibrated.json",
        "*GeoPriorSubsNet_eval_diagnostics*.json",
        "eval_diagnostics*.json",
        "diagnostics*.json",
        "*eval*diagnostic*.json",
    ),
    # --- TestSet: EVAL
    # (has subsidence_actual) ---
    "forecast_test_csv": (
        "*forecast*TestSet*eval_calibrated*.csv",
        "*forecast_TestSet*_eval_calibrated.csv",
        "*_forecast_TestSet_*_eval_calibrated.csv",
        "*TestSet*eval_calibrated*.csv",
    ),
    # --- TestSet: FUTURE
    # (no subsidence_actual) ---
    "forecast_test_future_csv": (
        "*forecast*TestSet*future*.csv",
        "*forecast_TestSet*_future.csv",
        "*_forecast_TestSet_*_future.csv",
        "*TestSet*future*.csv",
    ),
    # Typical calibrated validation forecasts
    "forecast_val_csv": (
        "*forecast*Validation*calibrated*.csv",
        "*forecast*Validation*Fallback*calibrated*.csv",
        "*Validation*calibrated*.csv",
        "*_calibrated.csv",
    ),
    # Typical future forecasts
    "forecast_future_csv": (
        "*future*.csv",
        "*H*_future*.csv",
        "*forecast*future*.csv",
    ),
    # Physics payloads
    "physics_payload": (
        "*physics_payload*.npz",
        "*phys_payload*.npz",
        "*payload*.npz",
    ),
    # Optional coords file
    # (used by some physics maps)
    "coords_npz": (
        "*coords*.npz",
        "*xy*.npz",
        "*lonlat*.npz",
        "*val_inputs*.npz",
        "*test_inputs*.npz",
        "*train_inputs*.npz",
        "*oos_time_inputs*.npz",
    ),
    # Ablation records (Supplement S6/S7)
    "ablation_record_jsonl": (
        "ablation_records/ablation_record.updated*.jsonl",
        "ablation_records/ablation_record*.jsonl",
        "ablation_record*.jsonl",
        "ablation_record.updated*.jsonl",
    ),
    # Boundary shapefiles
    "boundary_shp": (
        "*boundary*.shp",
        "*coast*.shp",
        "*admin*.shp",
        "*outline*.shp",
        "*border*.shp",
    ),
}

# ---------------------------------------------------------
# Plot metric metadata
# (titles / labels / format)
# Used by multi-panel figures to avoid duplicated
# strings.
# ---------------------------------------------------------
# "title" and "ylabel" may contain a "{unit}" placeholder;
# entries that use it also carry a default "unit".
PLOT_METRIC_META = {
    "r2": {
        "title": r"$R^2$ (↑)",
        "ylabel": r"$R^2$",
        "fmt": "{:.2f}",
    },
    "mae": {
        "title": "MAE (↓, {unit})",
        "ylabel": "MAE ({unit})",
        "fmt": "{:.2f}",
        "unit": "mm",
    },
    "rmse": {
        "title": "RMSE (↓, {unit})",
        "ylabel": "RMSE ({unit})",
        "fmt": "{:.2f}",
        "unit": "mm",
    },
    "mse": {
        "title": "MSE (↓, {unit})",
        "ylabel": "MSE ({unit})",
        "fmt": "{:.2f}",
        "unit": "mm²",
    },
    "coverage80": {
        "title": "Coverage (80% PI; target 0.80)",
        "ylabel": "Coverage",
        "fmt": "{:.3f}",
    },
    "sharpness80": {
        "title": "Sharpness (80% PI; ↓, {unit})",
        "ylabel": "Sharpness ({unit})",
        "fmt": "{:.3f}",
        "unit": "mm",
    },
}

# ---------------------------------------------------------
# Matplotlib defaults (paper-friendly)
# ---------------------------------------------------------
PAPER_DPI = 600
PAPER_FONT = 8

# ---------------------------------------------------------
# City styling
# (keep consistent across all scripts)
# (Used repeatedly across figures)
# ---------------------------------------------------------
# Lower-case names and short codes -> display name.
CITY_CANON = {
    "nansha": "Nansha",
    "zhongshan": "Zhongshan",
    "ns": "Nansha",
    "zh": "Zhongshan",
}

# Display name -> hex colour.
CITY_COLORS = {
    "Nansha": "#1F78B4",
    "Zhongshan": "#E31A1C",
}

# ---------------------------------------------------------
# Units: payload + drivers + targets
# ---------------------------------------------------------
# Keep *units* separate from *labels* so plots can
# do:
#   label = LABELS[key]
#   unit = UNITS.get(key)
#   ax.set_xlabel(
#       f"{label} ({unit})" if unit else label
#   )
UNITS = {
    # Drivers / coordinates
    "GWL": "m",
    "rainfall_mm": "mm/yr",
    "soil_thickness": "m",
    "normalized_urban_load_proxy": "1",
    "longitude": "deg",
    "latitude": "deg",
    "year": "year",
    # Targets
    "subsidence": "mm/yr",
    "subsidence_cum": "mm",
    # Forecast CSV canonical columns
    "forecast_step": "step",
    "sample_idx": "id",
}

LABELS = {
    "GWL": "Groundwater level",
    "rainfall_mm": "Rainfall",
    "soil_thickness": "Soil thickness",
    "normalized_urban_load_proxy": "Urban load",
    "longitude": "Longitude",
    "latitude": "Latitude",
    "year": "Year",
    "subsidence": "Subsidence",
    "subsidence_cum": "Cumulative subsidence",
    "lithology_class": "Lithology class",
    "city": "City",
    "forecast_step": "Forecast step",
    "sample_idx": "Sample id",
}

# Additional column names, merged into the tables above.
UNITS.update(
    {
        "GWL_depth_bgs_m": "m",
        "GWL_depth_bgs": "m",
        "GWL_depth_bgs_raw": "m",
        "soil_thickness_eff": "m",
        "soil_thickness_imputed": "m",
        "urban_load_global": "1",
        "head_m": "m",
        "z_surf_m": "m",
        "z_surf": "m",
    }
)

LABELS.update(
    {
        "GWL_depth_bgs_m": "Depth to water table",
        "GWL_depth_bgs": "Depth to water table",
        "GWL_depth_bgs_raw": "Depth to water table",
        "soil_thickness_eff": "Effective soil thickness",
        "soil_thickness_imputed": "Imputed soil thickness",
        "urban_load_global": "Urban load",
        "head_m": "Hydraulic head",
        "z_surf_m": "Surface elevation",
        "z_surf": "Surface elevation",
    }
)

# Canonical column name -> accepted column names. Each
# tuple starts with the canonical name itself.
COLUMN_ALIASES = {
    # Drivers
    "GWL_depth_bgs_m": (
        "GWL_depth_bgs_m",
        "GWL_depth_bgs",
        "GWL",
        "GWL_depth_bgs_raw",
    ),
    "rainfall_mm": (
        "rainfall_mm",
        "rainfall",
        "rainfall_m",
    ),
    "soil_thickness_eff": (
        "soil_thickness_eff",
        "soil_thickness_imputed",
        "soil_thickness",
    ),
    "urban_load_global": (
        "urban_load_global",
        "normalized_urban_load_proxy",
    ),
    # Responses
    "subsidence_cum": ("subsidence_cum",),
    "subsidence": ("subsidence",),
    "head_m": ("head_m",),
}

# Every default driver and the default response is a key of
# COLUMN_ALIASES.
DRIVER_RESPONSE_DEFAULT_DRIVERS = (
    "GWL_depth_bgs_m",
    "rainfall_mm",
    "soil_thickness_eff",
    "urban_load_global",
)

DRIVER_RESPONSE_DEFAULT_RESPONSE = "subsidence_cum"

# ---------------------------------------------------------
# Units: GeoPrior physics fields / diagnostics
# ---------------------------------------------------------
# Used for physics maps and sanity plots.
PHYS_UNITS = {
    "K": "m/s",
    "Ss": "1/m",
    "Hd": "m",
    "H": "m",
    "tau": "s",
    "tau_prior": "s",
    "log10_K": "log10(m/s)",
    "log10_Ss": "log10(1/m)",
    "log10_tau": "log10(s)",
    # Diagnostics
    "epsilon_prior": "1",
    "epsilon_cons": "1",
    "epsilon_gw": "1",
    "epsilon_cons_raw": "varies",
    "epsilon_gw_raw": "varies",
}

# Same keys as PHYS_UNITS.
PHYS_LABELS = {
    "K": "Hydraulic conductivity",
    "Ss": "Specific storage",
    "Hd": "Drainage thickness",
    "H": "Effective thickness",
    "tau": "Consolidation time scale",
    "tau_prior": "Prior time scale",
    "log10_K": "log10 K",
    "log10_Ss": "log10 Ss",
    "log10_tau": "log10 τ",
    "epsilon_prior": r"$\epsilon_{\mathrm{prior}}$",
    "epsilon_cons": r"$\epsilon_{\mathrm{cons}}$",
    "epsilon_gw": r"$\epsilon_{\mathrm{gw}}$",
    "epsilon_cons_raw": (r"$\epsilon_{\mathrm{cons,raw}}$"),
    "epsilon_gw_raw": r"$\epsilon_{\mathrm{gw,raw}}$",
}

# ---------------------------------------------------------
# Time + scaling
# (used to harmonize JSON variants)
# ---------------------------------------------------------
# Equals 365.2425 days x 86400 s/day.
SECONDS_PER_YEAR = 31556952.0

# ---------------------------------------------------------
# Closure strings for captions / legends
# (math "single source")
# ---------------------------------------------------------
# Raw strings: the backslashes are literal math markup.
CLOSURES = {
    "tau_prior": (
        r"$\tau_{\mathrm{prior}}"
        r"\approx H_d^2\,S_s"
        r"/(\pi^2\,\kappa_b\,K)$"
    ),
    "s_eq": r"$s_{\mathrm{eq}}\approx S_s\,\Delta h\,H$",
    "R_cons": (
        r"$R_{\mathrm{cons}}="
        r"\partial_t \bar{s}"
        r"-(s_{\mathrm{eq}}(\bar{h})-\bar{s})/\tau$"
    ),
    "R_gw": (
        r"$R_{\mathrm{gw}}="
        r"S_s\,\partial_t \bar{h}"
        r"-\nabla\cdot(K\nabla\bar{h})$"
    ),
}

# ---------------------------------------------------------
# Column canonicalization
# ---------------------------------------------------------
# Canonical forecast-CSV column -> accepted column names.
# Each tuple starts with the canonical name itself.
_BASE_ALIASES = {
    "sample_idx": ("sample_idx", "sample_id", "sample"),
    "forecast_step": ("forecast_step", "forecast_s", "h"),
    "coord_t": ("coord_t", "year", "t"),
    "coord_x": ("coord_x", "lon", "longitude", "x"),
    "coord_y": ("coord_y", "lat", "latitude", "y"),
    "subsidence_actual": ("subsidence_actual", "actual"),
    "subsidence_q50": ("subsidence_q50", "q50", "p50"),
    "subsidence_unit": ("subsidence_unit", "unit"),
}

# Same as _FUT_REQUIRED plus "subsidence_actual".
_CALIB_REQUIRED = [
    "sample_idx",
    "forecast_step",
    "coord_t",
    "coord_x",
    "coord_y",
    "subsidence_actual",
    "subsidence_q50",
]

_FUT_REQUIRED = [
    "sample_idx",
    "forecast_step",
    "coord_t",
    "coord_x",
    "coord_y",
    "subsidence_q50",
]

# The _CAL_* tables below share the keys in _CAL_ORDER.
_CAL_ORDER = ("none", "source", "target")

_CAL_LABEL = {
    "none": "None",
    "source": "Source",
    "target": "Target",
}

_CAL_MARKER = {
    "none": "o",
    "source": "^",
    "target": "s",
}

# The _STRAT_* tables below share the keys in
# _STRAT_DEFAULT.
_STRAT_DEFAULT = ("baseline", "xfer", "warm")

_STRAT_LABEL = {
    "baseline": "Baseline (target)",
    "xfer": "Transfer",
    "warm": "Warm-start",
}

_STRAT_HATCH = {
    "baseline": "//",
    "xfer": "",
    "warm": "xx",
}

_STRAT_EDGE = {
    "baseline": "#111111",
    "xfer": "#111111",
    "warm": "#111111",
}

_STRAT_LINESTYLE = {
    "baseline": ":",
    "xfer": "--",
    "warm": "-",
}

_STRAT_MARKER = {
    "baseline": "o",
    "xfer": "^",
    "warm": "s",
}

# NOTE(review): presumably maps a cross-city transfer
# direction to the matching same-city baseline run —
# confirm against the callers.
_BASELINE_MAP = {
    "A_to_B": "B_to_B",
    "B_to_A": "A_to_A",
}

# NOTE(review): tuples look like (result key, axis label,
# "min"/"max" for which direction is better) — confirm
# against the callers.
_METRIC_DEF = {
    "mae": ("overall_mae", "MAE (mm)", "min"),
    "mse": ("overall_mse", "MSE (mm^2)", "min"),
    "rmse": ("overall_rmse", "RMSE (mm)", "min"),
    "r2": ("overall_r2", r"$R^2$", "max"),
}