# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""Drop-in `run_one()` for nat.com/sensitivity_lib.py.
This implementation is adapted from nat.com/sensitivity.py,
with the heavy steps moved to `build_context()`:
- Stage-1 manifest lookup
- NPZ loading
- tf.data pipeline building
`run_one()` assumes `ctx` contains:
- ``ctx.manifest``, ``ctx.manifest_path``
- ``ctx.cfg_base``
- ``ctx.base_output_dir``
- ``ctx.scaler_info``
- ``ctx.X_train / ctx.y_train / ctx.X_val / ctx.y_val``
- optional ``ctx.X_test / ctx.y_test``
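- ``ctx.ds_train / ctx.ds_val`` / optional ``ctx.ds_test``
- ``ctx.dyn_names / ctx.fut_names / ctx.sta_names``
- optional ``ctx.batch_size`` (the batch size the cached datasets were
  built with; when present, a ``BATCH_SIZE`` override is ignored)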
It writes all artifacts under `run_dir`.
"""
from __future__ import annotations
import copy
import datetime as dt
import gc
import json
import os
import platform
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import tensorflow as tf
from geoprior.backends.devices import configure_tf_from_cfg
from geoprior.registry import _find_stage1_manifest
from geoprior.utils import (
default_results_dir,
ensure_directory_exists,
getenv_stripped,
load_nat_config,
load_nat_config_payload,
load_scaler_info,
make_tf_dataset,
resolve_hybrid_config,
)
def _sanitize_tag(s: str | None) -> str | None:
if not s:
return None
s = re.sub(r"[^0-9A-Za-z._-]+", "_", s.strip())
s = s.strip("._-")
s = s[:120]
return s or None
@dataclass
class SensitivityContext:
    """Cached Stage-1 state shared across sensitivity trials.

    Produced once by ``build_context()`` so that each trial can skip
    manifest lookup, NPZ loading and ``tf.data`` pipeline construction.
    """

    # Stage-1 manifest (path + parsed JSON) and the resolved hybrid config
    manifest_path: str
    manifest: dict[str, Any]
    cfg_base: dict[str, Any]
    # stage-1 paths / outputs (build_context fills this from the
    # manifest's ``paths.run_dir``)
    base_output_dir: str
    # scalers / encoders, as returned by ``load_scaler_info``
    scaler_info: dict[str, Any]
    # npz arrays, keyed by the array names stored in each archive
    X_train: dict[str, np.ndarray]
    y_train: dict[str, np.ndarray]
    X_val: dict[str, np.ndarray]
    y_val: dict[str, np.ndarray]
    # None when the manifest lists no test split
    X_test: dict[str, np.ndarray] | None
    y_test: dict[str, np.ndarray] | None
    # datasets (cached); ds_test is None when there is no test split
    ds_train: tf.data.Dataset
    ds_val: tf.data.Dataset
    ds_test: tf.data.Dataset | None
    # reusable "static" knobs the cached datasets were built with
    dyn_names: tuple[str, ...]
    fut_names: tuple[str, ...]
    sta_names: tuple[str, ...]
    mode: str
    horizon: int
    batch_size: int
def build_context(
    *,
    city: str | None = None,
    stage1_manifest: str | None = None,
    verbose: int = 1,
) -> SensitivityContext:
    """One-time heavy setup shared by every sensitivity trial.

    Steps:

    - locate the Stage-1 manifest;
    - resolve the hybrid config (manifest config + live NAT config) and
      configure TensorFlow from it;
    - load the train/val (and optional test) NPZ arrays;
    - build the cached ``tf.data`` datasets (train/val/test).

    Parameters
    ----------
    city : str, optional
        City hint for the manifest search. Precedence is: this argument,
        then the ``CITY`` environment variable, then the config-file city.
    stage1_manifest : str, optional
        Explicit manifest path. It takes precedence over the
        ``STAGE1_MANIFEST`` environment variable.
    verbose : int, default 1
        Non-zero enables verbose output from the manifest lookup and the
        config resolution.

    Returns
    -------
    SensitivityContext
        Manifest, config, arrays and datasets ready for ``run_one()``.
    """
    results_dir = default_results_dir()

    # City hint: explicit argument > CITY env var > config file. This
    # matches the stage1_manifest precedence below, so a caller looping
    # over cities is not silently overridden by an exported CITY.
    payload = load_nat_config_payload()
    cfg_city = (payload.get("city") or "").strip().lower() or None
    city_hint = city or getenv_stripped("CITY") or cfg_city

    manual = stage1_manifest or getenv_stripped("STAGE1_MANIFEST")
    manifest_path = _find_stage1_manifest(
        manual=manual,
        base_dir=results_dir,
        city_hint=city_hint,
        model_hint=None,
        prefer="timestamp",
        required_keys=(
            "model",
            "stage",
            "artifacts",
            "config",
            "paths",
        ),
        filter_fn=None,
        verbose=1 if verbose else 0,
    )
    with open(manifest_path, encoding="utf-8") as f:
        M = json.load(f)

    cfg = resolve_hybrid_config(
        manifest_cfg=M.get("config", {}) or {},
        live_cfg=load_nat_config(),
        verbose=bool(verbose),
    )
    # Device setup only; the returned device info is not needed here.
    configure_tf_from_cfg(cfg)

    features = cfg.get("features", {}) or {}
    dyn = tuple(features.get("dynamic", []) or [])
    fut = tuple(features.get("future", []) or [])
    sta = tuple(features.get("static", []) or [])
    mode = str(cfg.get("MODE", cfg.get("mode", "sequence")))
    horizon = int(cfg.get("FORECAST_HORIZON_YEARS", 3))
    batch_size = int(cfg.get("BATCH_SIZE", 32))

    scaler_info = load_scaler_info(M["artifacts"]["encoders"])

    def _load_npz(path) -> dict[str, np.ndarray]:
        # np.load on an .npz returns an NpzFile that keeps the zip archive
        # open. The context manager closes it once dict() has materialized
        # every array, instead of leaking the handle until GC.
        with np.load(path) as data:
            return dict(data)

    npz = M["artifacts"]["numpy"]
    X_train = _load_npz(npz["train_inputs_npz"])
    y_train = _load_npz(npz["train_targets_npz"])
    X_val = _load_npz(npz["val_inputs_npz"])
    y_val = _load_npz(npz["val_targets_npz"])
    X_test = None
    y_test = None
    if npz.get("test_inputs_npz") and npz.get("test_targets_npz"):
        X_test = _load_npz(npz["test_inputs_npz"])
        y_test = _load_npz(npz["test_targets_npz"])

    def _make_ds(X, y, *, shuffle: bool) -> tf.data.Dataset:
        # All splits share the same pipeline settings; only the training
        # split is shuffled.
        return make_tf_dataset(
            X,
            y,
            batch_size=batch_size,
            shuffle=shuffle,
            mode=mode,
            forecast_horizon=horizon,
            check_npz_finite=True,
            check_finite=True,
            scan_finite_batches=None,
            dynamic_feature_names=list(dyn),
            future_feature_names=list(fut),
        )

    ds_train = _make_ds(X_train, y_train, shuffle=True)
    ds_val = _make_ds(X_val, y_val, shuffle=False)
    ds_test = None
    if X_test is not None and y_test is not None:
        ds_test = _make_ds(X_test, y_test, shuffle=False)

    base_out = M["paths"]["run_dir"]
    return SensitivityContext(
        manifest_path=str(manifest_path),
        manifest=M,
        cfg_base=cfg,
        base_output_dir=str(base_out),
        scaler_info=scaler_info,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        X_test=X_test,
        y_test=y_test,
        ds_train=ds_train,
        ds_val=ds_val,
        ds_test=ds_test,
        dyn_names=dyn,
        fut_names=fut,
        sta_names=sta,
        mode=mode,
        horizon=horizon,
        batch_size=batch_size,
    )
def _make_run_dir(
    base_output_dir: str,
    run_tag: str | None,
    stable: bool,
) -> str:
    """Create (if needed) and return the output directory for one trial.

    The directory name depends on the tag and the ``stable`` flag:

    - ``stable`` and a usable tag: ``sens__<tag>``. This name is
      deterministic, so re-running the same tag reuses the directory.
    - otherwise: ``train_<YYYYmmdd-HHMMSS>``, with ``__<tag>`` appended
      when a usable tag exists.

    The tag is first passed through ``_sanitize_tag``.
    """
    clean_tag = _sanitize_tag(run_tag)
    if clean_tag and stable:
        dir_name = "sens__" + clean_tag
    else:
        timestamp = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
        suffix = f"__{clean_tag}" if clean_tag else ""
        dir_name = f"train_{timestamp}" + suffix
    run_path = str(Path(base_output_dir).joinpath(dir_name))
    ensure_directory_exists(run_path)
    return run_path
def cleanup_between_runs() -> None:
    """Release TF/Keras state left over from the previous trial.

    Calls ``tf.keras.backend.clear_session()`` and then runs the Python
    garbage collector. This should be called between successive
    ``run_one()`` calls so that models and graphs do not accumulate.
    """
    tf.keras.backend.clear_session()
    gc.collect()
def run_one(
ctx,
*,
overrides: dict[str, Any],
run_tag: str | None,
stable_run_dir: bool = True,
eval_max_batches: int | None = None,
cal_max_batches: int | None = None,
) -> str:
"""Run one physics sensitivity trial.
Notes
-----
- Reuses ``ctx.ds_train/ctx.ds_val/ctx.ds_test``.
- All artifacts are written to the returned run_dir.
"""
# --- imports (local to keep module light) -----------------
from tensorflow.keras.callbacks import (
CSVLogger,
EarlyStopping,
ModelCheckpoint,
TerminateOnNaN,
)
from geoprior.api.util import get_table_size
from geoprior.backends.devices import (
configure_tf_from_cfg,
)
from geoprior.compat import (
load_inference_model,
load_model_from_tfv2,
normalize_predict_output,
save_manifest,
save_model,
)
from geoprior.deps import with_progress
from geoprior.models import (
MAEQ50,
MSEQ50,
Coverage80,
GeoPriorSubsNet,
LambdaOffsetScheduler,
PoroElasticSubsNet,
Sharpness80,
_logs_to_py,
_to_py,
apply_calibrator_to_subs,
autoplot_geoprior_history,
coverage80_fn,
debug_model_reload,
# debug_tensor_interval,
debug_val_interval,
extract_physical_parameters,
finalize_scaling_kwargs,
fit_interval_calibrator_on_val,
load_physics_payload,
make_weighted_pinball,
override_scaling_kwargs,
plot_history_in,
plot_physics_values_in,
sharpness80_fn,
)
from geoprior.params import (
FixedGammaW,
FixedHRef,
LearnableKappa,
LearnableMV,
)
from geoprior.plot import plot_eval_future
from geoprior.utils import (
audit_stage2_handshake,
best_epoch_and_metrics,
build_censor_mask,
calibrate_quantile_forecasts,
convert_eval_payload_units,
# default_results_dir,
deg_to_m_from_lat,
ensure_directory_exists,
ensure_input_shapes,
evaluate_forecast,
evaluate_point_forecast,
format_and_forecast,
inverse_scale_target,
load_scaler_info,
map_targets_for_training,
name_of,
postprocess_eval_json,
print_config_table,
resolve_si_affine,
save_ablation_record,
save_all_figures,
serialize_subs_params,
should_audit,
)
# ----------------------------------------------------------
cfg = copy.deepcopy(ctx.cfg_base)
cfg.update(overrides or {})
# Use caller-provided tag unless cfg has RUN_TAG.
run_tag2 = _sanitize_tag(run_tag or cfg.get("RUN_TAG"))
run_dir = _make_run_dir(
ctx.base_output_dir,
run_tag=run_tag2,
stable=stable_run_dir,
)
ensure_directory_exists(run_dir)
# ---------------------------
# Basic runtime knobs
# ---------------------------
FAST_SENS = bool(cfg.get("FAST_SENSITIVITY", False))
DEBUG = bool(cfg.get("DEBUG", False))
# Unit post-processing for eval JSON.
units_mode = str(
cfg.get("EVAL_JSON_UNITS_MODE", "si") or "si"
)
units_mode = units_mode.strip().lower()
units_scope = str(
cfg.get("EVAL_JSON_UNITS_SCOPE", "all") or "all"
)
units_scope = units_scope.strip().lower()
# ---------------------------
# Manifest + names
# ---------------------------
M = ctx.manifest
CITY_NAME = M.get("city", cfg.get("CITY_NAME", "nansha"))
MODEL_NAME = cfg.get("MODEL_NAME", cfg.get("model", None))
if not MODEL_NAME:
MODEL_NAME = "GeoPriorSubsNet"
# ---------------------------
# Cached tensors/datasets
# ---------------------------
X_train = ctx.X_train
y_train = ctx.y_train
X_val = ctx.X_val
y_val = ctx.y_val
X_test = ctx.X_test
y_test = ctx.y_test
train_dataset = ctx.ds_train
val_dataset = ctx.ds_val
DYN_NAMES = list(ctx.dyn_names)
FUT_NAMES = list(ctx.fut_names)
STA_NAMES = list(ctx.sta_names)
# ---------------------------
# Config essentials
# ---------------------------
TIME_STEPS = int(cfg["TIME_STEPS"])
FORECAST_H = int(cfg["FORECAST_HORIZON_YEARS"])
MODE = str(cfg["MODE"])
FORECAST_START_YEAR = int(cfg.get("FORECAST_START_YEAR"))
# Architecture
ATTENTION_LEVELS = cfg.get(
"ATTENTION_LEVELS",
["cross", "hierarchical", "memory"],
)
EMBED_DIM = int(cfg.get("EMBED_DIM", 32))
HIDDEN_UNITS = int(cfg.get("HIDDEN_UNITS", 64))
LSTM_UNITS = int(cfg.get("LSTM_UNITS", 64))
ATTENTION_UNITS = int(cfg.get("ATTENTION_UNITS", 64))
NUMBER_HEADS = int(cfg.get("NUMBER_HEADS", 2))
DROPOUT_RATE = float(cfg.get("DROPOUT_RATE", 0.10))
MEMORY_SIZE = int(cfg.get("MEMORY_SIZE", 50))
SCALES = cfg.get("SCALES", [1, 2])
USE_RESIDUALS = bool(cfg.get("USE_RESIDUALS", True))
USE_BATCH_NORM = bool(cfg.get("USE_BATCH_NORM", False))
USE_VSN = bool(cfg.get("USE_VSN", True))
VSN_UNITS = int(cfg.get("VSN_UNITS", 32))
# Training
EPOCHS = int(cfg.get("EPOCHS", 50))
BATCH_SIZE_CFG = int(cfg.get("BATCH_SIZE", 32))
if hasattr(ctx, "batch_size"):
if int(ctx.batch_size) != int(BATCH_SIZE_CFG):
print(
"[Warn] Gold mode uses cached tf.data datasets "
f"built with BATCH_SIZE={int(ctx.batch_size)}; "
f"ignoring override BATCH_SIZE={int(BATCH_SIZE_CFG)}. "
"Rebuild context to change batch size."
)
BATCH_SIZE = int(ctx.batch_size)
cfg["BATCH_SIZE"] = BATCH_SIZE
else:
BATCH_SIZE = int(BATCH_SIZE_CFG)
LEARNING_RATE = float(cfg.get("LEARNING_RATE", 1e-4))
VERBOSE = int(cfg.get("VERBOSE", 1))
LEARNING_RATE = float(cfg.get("LEARNING_RATE", 1e-4))
VERBOSE = int(cfg.get("VERBOSE", 1))
# Quantiles
QUANTILES = cfg.get("QUANTILES", [0.1, 0.5, 0.9])
def _coerce_qw(d: dict, default: dict) -> dict:
if not d:
return default
out = {}
for k, v in d.items():
try:
q = float(k)
except Exception:
q = k
out[q] = float(v)
return out
SUBS_WEIGHTS = _coerce_qw(
cfg.get("SUBS_WEIGHTS", None)
or {0.1: 3.0, 0.5: 1.0, 0.9: 3.0},
{0.1: 3.0, 0.5: 1.0, 0.9: 3.0},
)
GWL_WEIGHTS = _coerce_qw(
cfg.get("GWL_WEIGHTS", None)
or {0.1: 1.5, 0.5: 1.0, 0.9: 1.5},
{0.1: 1.5, 0.5: 1.0, 0.9: 1.5},
)
# Physics
TIME_UNITS = str(
cfg.get("TIME_UNITS", "year") or "year"
).lower()
SCALE_PDE_RESIDUALS = bool(
cfg.get("SCALE_PDE_RESIDUALS", True)
)
CONS_RESID_METHOD = str(
cfg.get(
"CONSOLIDATION_STEP_RESIDUAL_METHOD",
"exact",
)
)
PDE_MODE = str(cfg.get("PDE_MODE_CONFIG", "off") or "off")
PDE_MODE = PDE_MODE.strip().lower()
if PDE_MODE in ("off", "none"):
PDE_MODE = "none"
LAMBDA_CONS = float(cfg.get("LAMBDA_CONS", 0.10))
LAMBDA_GW = float(cfg.get("LAMBDA_GW", 0.01))
LAMBDA_PRIOR = float(cfg.get("LAMBDA_PRIOR", 0.10))
LAMBDA_SMOOTH = float(cfg.get("LAMBDA_SMOOTH", 0.01))
LAMBDA_MV = float(cfg.get("LAMBDA_MV", 0.01))
LAMBDA_BOUNDS = float(cfg.get("LAMBDA_BOUNDS", 0.0))
LOSS_WEIGHT_GWL = float(cfg.get("LOSS_WEIGHT_GWL", 0.5))
LAMBDA_Q = float(cfg.get("LAMBDA_Q", 0.0))
LOG_Q_DIAGNOSTICS = bool(
cfg.get("LOG_Q_DIAGNOSTICS", False)
)
MV_LR_MULT = float(cfg.get("MV_LR_MULT", 1.0))
KAPPA_LR_MULT = float(cfg.get("KAPPA_LR_MULT", 5.0))
OFFSET_MODE = str(cfg.get("OFFSET_MODE", "mul"))
LAMBDA_OFFSET = float(cfg.get("LAMBDA_OFFSET", 1.0))
USE_LAMBDA_OFFSET_SCHED = bool(
cfg.get("USE_LAMBDA_OFFSET_SCHEDULER", False)
)
LAMBDA_OFFSET_UNIT = cfg.get(
"LAMBDA_OFFSET_UNIT", "epoch"
)
LAMBDA_OFFSET_WHEN = cfg.get(
"LAMBDA_OFFSET_WHEN", "begin"
)
LAMBDA_OFFSET_WARMUP = int(
cfg.get("LAMBDA_OFFSET_WARMUP", 10)
)
LAMBDA_OFFSET_START = cfg.get("LAMBDA_OFFSET_START", None)
LAMBDA_OFFSET_END = cfg.get("LAMBDA_OFFSET_END", None)
LAMBDA_OFFSET_SCHEDULE = cfg.get(
"LAMBDA_OFFSET_SCHEDULE", None
)
# Identifiability
ident = cfg.get("IDENTIFIABILITY_REGIME", None)
if isinstance(ident, str):
s = ident.strip()
if not s or s.lower() in (
"none",
"off",
"false",
"0",
):
ident = None
else:
ident = s
# Model flavour tweaks
if MODEL_NAME == "HybridAttn-NoPhysics":
PDE_MODE = "none"
LAMBDA_CONS = 0.0
LAMBDA_GW = 0.0
LAMBDA_PRIOR = 0.0
LAMBDA_SMOOTH = 0.0
LAMBDA_BOUNDS = 0.0
LAMBDA_MV = 0.0
MV_LR_MULT = 0.0
KAPPA_LR_MULT = 0.0
if MODEL_NAME == "PoroElasticSubsNet":
PDE_MODE = "consolidation"
LAMBDA_GW = 0.0
# Censoring
cfg.get("features", {}) or {}
CENSOR = (
cfg.get("censoring", {})
or cfg.get("censor", {})
or {}
)
CENSOR_SPECS = CENSOR.get("specs", []) or []
CENSOR_THRESH = float(CENSOR.get("flag_threshold", 0.5))
# Determine censor flag index.
CENSOR_FLAG_IDX_DYN = None
CENSOR_FLAG_IDX_FUT = None
CENSOR_FLAG_NAME = None
for sp in CENSOR_SPECS:
cand = sp.get("flag_col")
if not cand:
base = sp.get("col")
if base:
cand = base + sp.get(
"flag_suffix", "_censored"
)
if not cand:
continue
if cand in FUT_NAMES and CENSOR_FLAG_IDX_FUT is None:
CENSOR_FLAG_IDX_FUT = FUT_NAMES.index(cand)
CENSOR_FLAG_NAME = cand
if cand in DYN_NAMES and CENSOR_FLAG_IDX_DYN is None:
CENSOR_FLAG_IDX_DYN = DYN_NAMES.index(cand)
CENSOR_FLAG_NAME = cand
if CENSOR_FLAG_IDX_FUT is not None:
CENSOR_MASK_SOURCE = "future"
CENSOR_FLAG_IDX = CENSOR_FLAG_IDX_FUT
elif CENSOR_FLAG_IDX_DYN is not None:
CENSOR_MASK_SOURCE = "dynamic"
CENSOR_FLAG_IDX = CENSOR_FLAG_IDX_DYN
else:
CENSOR_MASK_SOURCE = None
CENSOR_FLAG_IDX = None
# Column naming
cols_cfg = cfg.get("cols", {}) or {}
SUBSIDENCE_COL = cols_cfg.get("subsidence", "subsidence")
GWL_COL = cols_cfg.get("gwl", "GWL")
# Stage-1 dims
dims = (
M.get("artifacts", {})
.get("sequences", {})
.get("dims", {})
or {}
)
OUT_S_DIM = int(dims.get("output_subsidence_dim", 1))
OUT_G_DIM = int(dims.get("output_gwl_dim", 1))
# Scalers
encoders = M["artifacts"]["encoders"]
scaler_info_dict = ctx.scaler_info
if not scaler_info_dict:
scaler_info_dict = load_scaler_info(encoders)
# Attach loaded scaler objects if missing.
if isinstance(scaler_info_dict, dict):
for _, v in scaler_info_dict.items():
if not isinstance(v, dict):
continue
if "scaler" in v:
continue
p = v.get("scaler_path")
if p and os.path.exists(p):
try:
v["scaler"] = joblib.load(p)
except Exception:
pass
def _pick_scaler_key(info, preferred, fallbacks=()):
if not info:
return preferred
if preferred in info:
return preferred
for k in fallbacks:
if k in info:
return k
low = {k.lower(): k for k in info.keys()}
for token in ("subs", "subsidence"):
for lk, orig in low.items():
if token in lk:
return orig
return preferred
SUBS_SCALER_KEY = _pick_scaler_key(
scaler_info_dict,
preferred=SUBSIDENCE_COL,
fallbacks=("subsidence", "subs_pred"),
)
GWL_SCALER_KEY = _pick_scaler_key(
scaler_info_dict,
preferred=GWL_COL,
fallbacks=("gwl", "gwl_pred"),
)
# Coord scaler (optional)
coord_scaler = None
cs_path = encoders.get("coord_scaler")
if cs_path and os.path.exists(cs_path):
try:
coord_scaler = joblib.load(cs_path)
except Exception:
coord_scaler = None
# scaling kwargs (Stage-1 source of truth)
sk_stage1 = (cfg.get("scaling_kwargs") or {}).copy()
# Resolve driver channel indices.
if (
"gwl_dyn_index" in sk_stage1
and sk_stage1["gwl_dyn_index"] is not None
):
GWL_DYN_INDEX = int(sk_stage1["gwl_dyn_index"])
gwl_dyn_name = DYN_NAMES[GWL_DYN_INDEX]
else:
gwl_dyn_name = (
sk_stage1.get("gwl_dyn_name") or GWL_COL
)
if gwl_dyn_name not in DYN_NAMES:
for cand in (GWL_COL, "z_GWL", "gwl", "GWL"):
if cand in DYN_NAMES:
gwl_dyn_name = cand
break
GWL_DYN_INDEX = int(DYN_NAMES.index(gwl_dyn_name))
Z_SURF_STATIC_INDEX = sk_stage1.get("z_surf_static_index")
SUBS_DYN_INDEX = sk_stage1.get("subs_dyn_index")
sub_dyn_name = sk_stage1.get("subs_dyn_name")
# Coords normalization and degrees
coords_normalized = bool(
sk_stage1.get(
"coords_normalized",
sk_stage1.get("normalize_coords", False),
)
)
coord_ranges = sk_stage1.get("coord_ranges") or None
coords_in_degrees = bool(
sk_stage1.get("coords_in_degrees", False)
)
deg_to_m_lon = sk_stage1.get("deg_to_m_lon", None)
deg_to_m_lat = sk_stage1.get("deg_to_m_lat", None)
coord_order = sk_stage1.get(
"coord_order", ["t", "x", "y"]
)
if coords_in_degrees and (
deg_to_m_lon is None or deg_to_m_lat is None
):
lat_ref_deg = sk_stage1.get("lat_ref_deg", None)
if lat_ref_deg is not None and np.isfinite(
float(lat_ref_deg)
):
deg_to_m_lon, deg_to_m_lat = deg_to_m_from_lat(
float(lat_ref_deg)
)
sk_stage1["deg_to_m_lon"] = float(deg_to_m_lon)
sk_stage1["deg_to_m_lat"] = float(deg_to_m_lat)
# Thickness SI affine
H_scale_si = sk_stage1.get("H_scale_si", None)
H_bias_si = sk_stage1.get("H_bias_si", None)
# Resolve SI affine maps for subs/head
subs_scale_si = sk_stage1.get("subs_scale_si")
subs_bias_si = sk_stage1.get("subs_bias_si")
head_scale_si = sk_stage1.get("head_scale_si")
head_bias_si = sk_stage1.get("head_bias_si")
if subs_scale_si is None or subs_bias_si is None:
subs_scale_si, subs_bias_si = resolve_si_affine(
cfg,
scaler_info_dict,
target_name=SUBSIDENCE_COL,
prefix="SUBS",
unit_factor_key="SUBS_UNIT_TO_SI",
scale_key="SUBS_SCALE_SI",
bias_key="SUBS_BIAS_SI",
)
if head_scale_si is None or head_bias_si is None:
head_scale_si, head_bias_si = resolve_si_affine(
cfg,
scaler_info_dict,
target_name=GWL_COL,
prefix="HEAD",
unit_factor_key="HEAD_UNIT_TO_SI",
scale_key="HEAD_SCALE_SI",
bias_key="HEAD_BIAS_SI",
)
# Physics bounds
PHYS_BOUNDS_CFG = cfg.get("PHYSICS_BOUNDS", {}) or {}
BOUNDS_MODE = str(
cfg.get("PHYSICS_BOUNDS_MODE", "soft") or "soft"
)
BOUNDS_MODE = BOUNDS_MODE.strip().lower()
if BOUNDS_MODE in ("off", "none"):
BOUNDS_MODE = "off"
default_phys_bounds = {
"H_min": 5.0,
"H_max": 80.0,
"K_min": 1e-8,
"K_max": 1e-3,
"Ss_min": 1e-7,
"Ss_max": 1e-3,
"tau_min": 7.0 * 86400.0,
"tau_max": 300.0 * 31556952.0,
}
phys_bounds = dict(default_phys_bounds)
phys_bounds.update(PHYS_BOUNDS_CFG)
bounds_for_scaling = {
"H_min": float(phys_bounds["H_min"]),
"H_max": float(phys_bounds["H_max"]),
"K_min": float(phys_bounds["K_min"]),
"K_max": float(phys_bounds["K_max"]),
"Ss_min": float(phys_bounds["Ss_min"]),
"Ss_max": float(phys_bounds["Ss_max"]),
"tau_min": float(phys_bounds["tau_min"]),
"tau_max": float(phys_bounds["tau_max"]),
"logK_min": float(np.log(phys_bounds["K_min"])),
"logK_max": float(np.log(phys_bounds["K_max"])),
"logSs_min": float(np.log(phys_bounds["Ss_min"])),
"logSs_max": float(np.log(phys_bounds["Ss_max"])),
"logTau_min": float(np.log(phys_bounds["tau_min"])),
"logTau_max": float(np.log(phys_bounds["tau_max"])),
}
# GeoPrior parameters
GEOPRIOR_INIT_MV = float(
cfg.get("GEOPRIOR_INIT_MV", 1e-7)
)
GEOPRIOR_INIT_KAPPA = float(
cfg.get("GEOPRIOR_INIT_KAPPA", 1.0)
)
GEOPRIOR_GAMMA_W = float(
cfg.get("GEOPRIOR_GAMMA_W", 9810.0)
)
GEOPRIOR_H_REF = cfg.get("GEOPRIOR_H_REF", 0.0)
GEOPRIOR_KAPPA_MODE = cfg.get(
"GEOPRIOR_KAPPA_MODE", "bar"
)
GEOPRIOR_USE_EFF_H = bool(
cfg.get("GEOPRIOR_USE_EFFECTIVE_H", True)
)
GEOPRIOR_HD_FACTOR = float(
cfg.get("GEOPRIOR_HD_FACTOR", 0.6)
)
GEOPRIOR_H_REF_VALUE = 0.0
GEOPRIOR_H_REF_MODE = None
if isinstance(GEOPRIOR_H_REF, int | float):
GEOPRIOR_H_REF_VALUE = float(GEOPRIOR_H_REF)
else:
GEOPRIOR_H_REF_MODE = GEOPRIOR_H_REF
# Training strategy gates
TRAINING_STRATEGY = (
str(cfg.get("TRAINING_STRATEGY", "data_first"))
.strip()
.lower()
)
# Compute steps_per_epoch
X_train_norm = ensure_input_shapes(
X_train,
mode=MODE,
forecast_horizon=FORECAST_H,
)
n_train = int(X_train_norm["static_features"].shape[0])
steps_per_epoch = int(
np.ceil(n_train / float(BATCH_SIZE))
)
# Gate policies (reusing sensitivity.py logic)
q_policy = "always_on"
q_warmup_epochs = 0
q_ramp_epochs = 0
subs_resid_policy = "always_on"
subs_resid_warmup_epochs = 0
subs_resid_ramp_epochs = 0
if TRAINING_STRATEGY == "physics_first":
q_policy = (
str(
cfg.get(
"Q_POLICY_PHYSICS_FIRST", "warmup_off"
)
)
.strip()
.lower()
)
q_warmup_epochs = int(
cfg.get("Q_WARMUP_EPOCHS_PHYSICS_FIRST", 5)
)
q_ramp_epochs = int(
cfg.get("Q_RAMP_EPOCHS_PHYSICS_FIRST", 0)
)
subs_resid_policy = (
str(
cfg.get(
"SUBS_RESID_POLICY_PHYSICS_FIRST",
"warmup_off",
)
)
.strip()
.lower()
)
subs_resid_warmup_epochs = int(
cfg.get(
"SUBS_RESID_WARMUP_EPOCHS_PHYSICS_FIRST",
5,
)
)
subs_resid_ramp_epochs = int(
cfg.get(
"SUBS_RESID_RAMP_EPOCHS_PHYSICS_FIRST",
0,
)
)
LAMBDA_Q = float(
cfg.get("LAMBDA_Q_PHYSICS_FIRST", LAMBDA_Q)
)
LOSS_WEIGHT_GWL = float(
cfg.get(
"LOSS_WEIGHT_GWL_PHYSICS_FIRST",
LOSS_WEIGHT_GWL,
)
)
else:
LOSS_WEIGHT_GWL = float(
cfg.get(
"LOSS_WEIGHT_GWL_DATA_FIRST", LOSS_WEIGHT_GWL
)
)
LAMBDA_Q = float(
cfg.get("LAMBDA_Q_DATA_FIRST", LAMBDA_Q)
)
q_policy = (
str(cfg.get("Q_POLICY_DATA_FIRST", "always_on"))
.strip()
.lower()
)
q_warmup_epochs = int(
cfg.get("Q_WARMUP_EPOCHS_DATA_FIRST", 0)
)
q_ramp_epochs = int(
cfg.get("Q_RAMP_EPOCHS_DATA_FIRST", 0)
)
subs_resid_policy = (
str(
cfg.get(
"SUBS_RESID_POLICY_DATA_FIRST",
"always_on",
)
)
.strip()
.lower()
)
subs_resid_warmup_epochs = int(
cfg.get("SUBS_RESID_WARMUP_EPOCHS_DATA_FIRST", 0)
)
subs_resid_ramp_epochs = int(
cfg.get("SUBS_RESID_RAMP_EPOCHS_DATA_FIRST", 0)
)
if q_policy == "always_off":
LAMBDA_Q = 0.0
q_warmup_steps = max(0, q_warmup_epochs) * steps_per_epoch
q_ramp_steps = max(0, q_ramp_epochs) * steps_per_epoch
subs_resid_warmup_steps = (
max(0, subs_resid_warmup_epochs) * steps_per_epoch
)
subs_resid_ramp_steps = (
max(0, subs_resid_ramp_epochs) * steps_per_epoch
)
# MV prior schedule
MV_PRIOR_MODE = str(
sk_stage1.get(
"mv_prior_mode",
cfg.get("MV_PRIOR_MODE", "calibrate"),
)
)
MV_WEIGHT = float(
sk_stage1.get("mv_weight", cfg.get("MV_WEIGHT", 1e-3))
)
MV_SCHEDULE_UNIT = (
str(
sk_stage1.get(
"mv_schedule_unit",
cfg.get("MV_SCHEDULE_UNIT", "epoch"),
)
)
.strip()
.lower()
)
MV_DELAY_EPOCHS = int(
sk_stage1.get(
"mv_delay_epochs", cfg.get("MV_DELAY_EPOCHS", 1)
)
)
MV_WARMUP_EPOCHS = int(
sk_stage1.get(
"mv_warmup_epochs",
cfg.get("MV_WARMUP_EPOCHS", 2),
)
)
mv_delay_steps = sk_stage1.get("mv_delay_steps", None)
mv_warmup_steps = sk_stage1.get("mv_warmup_steps", None)
if mv_delay_steps is None:
mv_delay_steps = (
max(0, MV_DELAY_EPOCHS) * steps_per_epoch
)
if mv_warmup_steps is None:
mv_warmup_steps = (
max(0, MV_WARMUP_EPOCHS) * steps_per_epoch
)
# ---------------------------
# Logging header
# ---------------------------
device_info = configure_tf_from_cfg(cfg)
config_sections = [
(
"Run",
{
"CITY_NAME": CITY_NAME,
"MODEL_NAME": MODEL_NAME,
"MANIFEST_PATH": ctx.manifest_path,
"RUN_OUTPUT_PATH": run_dir,
},
),
(
"Architecture",
{
"TIME_STEPS": TIME_STEPS,
"FORECAST_HORIZON_YEARS": FORECAST_H,
"MODE": MODE,
"ATTENTION_LEVELS": ATTENTION_LEVELS,
"EMBED_DIM": EMBED_DIM,
"HIDDEN_UNITS": HIDDEN_UNITS,
"LSTM_UNITS": LSTM_UNITS,
"ATTENTION_UNITS": ATTENTION_UNITS,
"NUMBER_HEADS": NUMBER_HEADS,
"DROPOUT_RATE": DROPOUT_RATE,
"MEMORY_SIZE": MEMORY_SIZE,
"SCALES": SCALES,
},
),
(
"Physics",
{
"PDE_MODE_CONFIG": PDE_MODE,
"TIME_UNITS": TIME_UNITS,
"LAMBDA_CONS": LAMBDA_CONS,
"LAMBDA_GW": LAMBDA_GW,
"LAMBDA_PRIOR": LAMBDA_PRIOR,
"LAMBDA_SMOOTH": LAMBDA_SMOOTH,
"LAMBDA_BOUNDS": LAMBDA_BOUNDS,
"LAMBDA_MV": LAMBDA_MV,
"LAMBDA_Q": LAMBDA_Q,
"LOSS_WEIGHT_GWL": LOSS_WEIGHT_GWL,
"MV_LR_MULT": MV_LR_MULT,
"KAPPA_LR_MULT": KAPPA_LR_MULT,
"PHYSICS_BOUNDS": phys_bounds,
},
),
(
"Training",
{
"EPOCHS": EPOCHS,
"BATCH_SIZE": BATCH_SIZE,
"LEARNING_RATE": LEARNING_RATE,
"QUANTILES": QUANTILES,
},
),
]
print_config_table(
config_sections,
table_width=get_table_size(),
title=(f"{CITY_NAME.upper()} {MODEL_NAME} SENS RUN"),
)
print("\nTraining outputs ->", run_dir)
# ---------------------------
# Build & compile model
# ---------------------------
s_dim_model = int(
X_train_norm["static_features"].shape[-1]
)
d_dim_model = int(
X_train_norm["dynamic_features"].shape[-1]
)
f_dim_model = int(
X_train_norm["future_features"].shape[-1]
)
MODEL_CLASS_REGISTRY = {
"GeoPriorSubsNet": GeoPriorSubsNet,
"PoroElasticSubsNet": PoroElasticSubsNet,
"HybridAttn-NoPhysics": GeoPriorSubsNet,
}
model_cls = MODEL_CLASS_REGISTRY.get(
MODEL_NAME, GeoPriorSubsNet
)
sk_model = dict(sk_stage1)
sk_model.update(
{
"bounds": bounds_for_scaling,
"time_units": TIME_UNITS,
"coords_normalized": coords_normalized,
"coord_ranges": coord_ranges or {},
"coord_order": coord_order,
"coords_in_degrees": coords_in_degrees,
"deg_to_m_lon": (
float(deg_to_m_lon)
if deg_to_m_lon is not None
else None
),
"deg_to_m_lat": (
float(deg_to_m_lat)
if deg_to_m_lat is not None
else None
),
"H_scale_si": (
float(H_scale_si)
if H_scale_si is not None
else 1.0
),
"H_bias_si": (
float(H_bias_si)
if H_bias_si is not None
else 0.0
),
"dynamic_feature_names": list(DYN_NAMES),
"future_feature_names": list(FUT_NAMES),
"static_feature_names": list(STA_NAMES),
"gwl_dyn_name": gwl_dyn_name,
"gwl_dyn_index": int(GWL_DYN_INDEX),
"z_surf_static_index": (
int(Z_SURF_STATIC_INDEX)
if Z_SURF_STATIC_INDEX is not None
else None
),
"subs_dyn_index": (
int(SUBS_DYN_INDEX)
if SUBS_DYN_INDEX is not None
else None
),
"subs_dyn_name": (
sub_dyn_name
if sub_dyn_name is not None
else SUBSIDENCE_COL
),
"subs_scale_si": subs_scale_si,
"subs_bias_si": subs_bias_si,
"head_scale_si": head_scale_si,
"head_bias_si": head_bias_si,
"training_strategy": TRAINING_STRATEGY,
"q_policy": q_policy,
"q_warmup_epochs": int(q_warmup_epochs),
"q_ramp_epochs": int(q_ramp_epochs),
"q_warmup_steps": int(q_warmup_steps),
"q_ramp_steps": int(q_ramp_steps),
"log_q_diagnostics": bool(LOG_Q_DIAGNOSTICS),
"subs_resid_policy": subs_resid_policy,
"subs_resid_warmup_epochs": int(
subs_resid_warmup_epochs
),
"subs_resid_ramp_epochs": int(
subs_resid_ramp_epochs
),
"subs_resid_warmup_steps": int(
subs_resid_warmup_steps
),
"subs_resid_ramp_steps": int(
subs_resid_ramp_steps
),
"loss_weight_gwl": float(LOSS_WEIGHT_GWL),
"lambda_q": float(LAMBDA_Q),
"mv_prior_mode": MV_PRIOR_MODE,
"mv_weight": float(MV_WEIGHT),
"mv_schedule_unit": MV_SCHEDULE_UNIT,
"mv_delay_epochs": int(MV_DELAY_EPOCHS),
"mv_warmup_epochs": int(MV_WARMUP_EPOCHS),
"mv_delay_steps": int(mv_delay_steps),
"mv_warmup_steps": int(mv_warmup_steps),
"mv_steps_per_epoch": int(steps_per_epoch),
"identifiability_regime": ident,
}
)
sk_model = finalize_scaling_kwargs(sk_model)
sk_model = override_scaling_kwargs(
sk_model,
cfg,
finalize=finalize_scaling_kwargs,
dyn_names=DYN_NAMES,
gwl_dyn_index=GWL_DYN_INDEX,
base_dir=os.path.dirname(__file__),
strict=True,
log_fn=print,
)
# Keep it on disk for audit.
with open(
os.path.join(run_dir, "scaling_kwargs.json"),
"w",
encoding="utf-8",
) as f:
json.dump(sk_model, f, indent=2)
subsmodel_params = {
"embed_dim": EMBED_DIM,
"hidden_units": HIDDEN_UNITS,
"lstm_units": LSTM_UNITS,
"attention_units": ATTENTION_UNITS,
"num_heads": NUMBER_HEADS,
"dropout_rate": DROPOUT_RATE,
"max_window_size": TIME_STEPS,
"memory_size": MEMORY_SIZE,
"scales": SCALES,
"multi_scale_agg": "last",
"final_agg": "last",
"use_residuals": USE_RESIDUALS,
"use_batch_norm": USE_BATCH_NORM,
"use_vsn": USE_VSN,
"vsn_units": VSN_UNITS,
"mode": MODE,
"attention_levels": ATTENTION_LEVELS,
"scale_pde_residuals": SCALE_PDE_RESIDUALS,
"scaling_kwargs": sk_model,
"bounds_mode": BOUNDS_MODE,
"mv": LearnableMV(initial_value=GEOPRIOR_INIT_MV),
"kappa": LearnableKappa(
initial_value=GEOPRIOR_INIT_KAPPA
),
"gamma_w": FixedGammaW(value=GEOPRIOR_GAMMA_W),
"h_ref": FixedHRef(
value=GEOPRIOR_H_REF_VALUE,
mode=GEOPRIOR_H_REF_MODE,
),
"kappa_mode": GEOPRIOR_KAPPA_MODE,
"use_effective_h": GEOPRIOR_USE_EFF_H,
"hd_factor": GEOPRIOR_HD_FACTOR,
"offset_mode": OFFSET_MODE,
"residual_method": CONS_RESID_METHOD,
"time_units": TIME_UNITS,
}
if should_audit(cfg.get("AUDIT_STAGES"), stage="stage2"):
_ = audit_stage2_handshake(
X_train=X_train,
X_val=X_val,
y_train=y_train,
y_val=y_val,
time_steps=TIME_STEPS,
forecast_horizon=FORECAST_H,
mode=MODE,
dyn_names=list(DYN_NAMES),
fut_names=list(FUT_NAMES),
sta_names=list(STA_NAMES),
coord_scaler=coord_scaler,
sk_final=sk_model,
save_dir=run_dir,
table_width=get_table_size(),
title_prefix="STAGE-2 HANDSHAKE AUDIT",
city=CITY_NAME,
model_name=MODEL_NAME,
)
subs_model_inst = model_cls(
static_input_dim=s_dim_model,
dynamic_input_dim=d_dim_model,
future_input_dim=f_dim_model,
output_subsidence_dim=OUT_S_DIM,
output_gwl_dim=OUT_G_DIM,
forecast_horizon=FORECAST_H,
quantiles=QUANTILES,
pde_mode=PDE_MODE,
identifiability_regime=ident,
verbose=0,
**subsmodel_params,
)
# Build outputs once.
for xb, _ in train_dataset.take(1):
subs_model_inst(xb)
break
# Losses
loss_dict = {
"subs_pred": (
make_weighted_pinball(QUANTILES, SUBS_WEIGHTS)
if QUANTILES
else tf.keras.losses.MSE
),
"gwl_pred": (
make_weighted_pinball(QUANTILES, GWL_WEIGHTS)
if QUANTILES
else tf.keras.losses.MSE
),
}
# Compile-time metrics (only if not tracked internally)
TRACK_AUX_METRICS = bool(
cfg.get(
"TRACK_AUX_METRICS",
cfg.get("TRACK_ADD_ON_METRICS", True),
)
)
if TRACK_AUX_METRICS:
metrics_arg = None
else:
if QUANTILES:
metrics_arg = {
"subs_pred": [
MAEQ50(name="mae_q50"),
MSEQ50(name="mse_q50"),
Coverage80(name="coverage80"),
Sharpness80(name="sharpness80"),
],
"gwl_pred": [
MAEQ50(name="mae_q50"),
MSEQ50(name="mse_q50"),
],
}
else:
metrics_arg = {
"subs_pred": ["mae", "mse"],
"gwl_pred": ["mae", "mse"],
}
physics_loss_weights = {
"lambda_cons": LAMBDA_CONS,
"lambda_gw": LAMBDA_GW,
"lambda_prior": LAMBDA_PRIOR,
"lambda_smooth": LAMBDA_SMOOTH,
"lambda_bounds": LAMBDA_BOUNDS,
"lambda_mv": LAMBDA_MV,
"mv_lr_mult": MV_LR_MULT,
"lambda_offset": LAMBDA_OFFSET,
"kappa_lr_mult": KAPPA_LR_MULT,
"lambda_q": float(LAMBDA_Q),
}
loss_weights_dict = {
"subs_pred": 1.0,
"gwl_pred": float(LOSS_WEIGHT_GWL),
}
out_names = list(
getattr(subs_model_inst, "output_names", [])
) or ["subs_pred", "gwl_pred"]
import keras
is_keras2 = keras.__version__.startswith("2.")
if is_keras2:
loss_arg = [loss_dict[k] for k in out_names]
lossw_arg = [
loss_weights_dict.get(k, 1.0) for k in out_names
]
metrics_compile = (
None
if metrics_arg is None
else [metrics_arg.get(k, []) for k in out_names]
)
else:
loss_arg = loss_dict
lossw_arg = loss_weights_dict
metrics_compile = metrics_arg
subs_model_inst.compile(
optimizer=tf.keras.optimizers.Adam(
learning_rate=LEARNING_RATE,
clipnorm=1.0,
),
loss=loss_arg,
loss_weights=lossw_arg,
metrics=metrics_compile,
**physics_loss_weights,
)
# ---------------------------
# Training
# ---------------------------
bundle_prefix = f"{CITY_NAME}_{MODEL_NAME}_H{FORECAST_H}"
best_keras_path = os.path.join(
run_dir,
f"{bundle_prefix}_best.keras",
)
best_weights_path = os.path.join(
run_dir,
f"{bundle_prefix}_best.weights.h5",
)
best_tf_dir = os.path.join(
run_dir,
f"{bundle_prefix}_best_savedmodel",
)
model_init_manifest_path = os.path.join(
run_dir,
"model_init_manifest.json",
)
model_init_manifest = {
"model_class": model_cls.__name__,
"dims": {
"static_input_dim": int(s_dim_model),
"dynamic_input_dim": int(d_dim_model),
"future_input_dim": int(f_dim_model),
"output_subsidence_dim": int(OUT_S_DIM),
"output_gwl_dim": int(OUT_G_DIM),
"forecast_horizon": int(FORECAST_H),
},
"config": {
"quantiles": list(QUANTILES)
if QUANTILES
else None,
"pde_mode": PDE_MODE,
"mode": MODE,
"time_units": TIME_UNITS,
"embed_dim": EMBED_DIM,
"hidden_units": HIDDEN_UNITS,
"lstm_units": LSTM_UNITS,
"attention_units": ATTENTION_UNITS,
"num_heads": NUMBER_HEADS,
"dropout_rate": DROPOUT_RATE,
"memory_size": MEMORY_SIZE,
"scales": list(SCALES) if SCALES else None,
"use_residuals": bool(USE_RESIDUALS),
"use_batch_norm": bool(USE_BATCH_NORM),
"use_vsn": bool(USE_VSN),
"vsn_units": int(VSN_UNITS),
"geoprior": {
"init_mv": float(GEOPRIOR_INIT_MV),
"init_kappa": float(GEOPRIOR_INIT_KAPPA),
"gamma_w": float(GEOPRIOR_GAMMA_W),
"h_ref_value": float(GEOPRIOR_H_REF_VALUE),
"h_ref_mode": GEOPRIOR_H_REF_MODE,
"kappa_mode": GEOPRIOR_KAPPA_MODE,
"use_effective_h": bool(GEOPRIOR_USE_EFF_H),
"hd_factor": float(GEOPRIOR_HD_FACTOR),
"offset_mode": OFFSET_MODE,
},
"scaling_kwargs": sk_model,
"identifiability_regime": ident,
},
}
save_manifest(
model_init_manifest_path, model_init_manifest
)
callbacks = [
ModelCheckpoint(
filepath=best_keras_path,
monitor="val_loss",
save_best_only=True,
save_weights_only=False,
verbose=1,
),
ModelCheckpoint(
filepath=best_weights_path,
monitor="val_loss",
save_best_only=True,
save_weights_only=True,
verbose=1,
),
]
disable_es = bool(
cfg.get("DISABLE_EARLY_STOPPING", False)
)
if not disable_es:
callbacks.append(
EarlyStopping(
monitor="val_loss",
patience=15,
restore_best_weights=True,
verbose=1,
)
)
csvlog_path = os.path.join(
run_dir,
f"{CITY_NAME}_{MODEL_NAME}_train_log.csv",
)
callbacks.append(CSVLogger(csvlog_path, append=False))
callbacks.append(TerminateOnNaN())
if USE_LAMBDA_OFFSET_SCHED and (
not subs_model_inst._physics_off()
):
callbacks.append(
LambdaOffsetScheduler(
schedule=LAMBDA_OFFSET_SCHEDULE,
unit=LAMBDA_OFFSET_UNIT,
when=LAMBDA_OFFSET_WHEN,
warmup=LAMBDA_OFFSET_WARMUP,
start=LAMBDA_OFFSET_START,
end=LAMBDA_OFFSET_END,
clamp_positive=True,
verbose=1,
)
)
print("\nTraining...")
history = subs_model_inst.fit(
train_dataset,
validation_data=val_dataset,
epochs=EPOCHS,
callbacks=callbacks,
verbose=VERBOSE,
)
# Save training summary
best_epoch, metrics_at_best = best_epoch_and_metrics(
history.history
)
training_summary = {
"timestamp": dt.datetime.now().strftime(
"%Y%m%d-%H%M%S"
),
"city": CITY_NAME,
"model": MODEL_NAME,
"horizon": int(FORECAST_H),
"best_epoch": (
int(best_epoch)
if best_epoch is not None
else None
),
"metrics_at_best": metrics_at_best,
"final_epoch_metrics": {
k: float(v[-1])
for k, v in history.history.items()
if len(v)
},
"env": {
"python": sys.version.split()[0],
"tensorflow": tf.__version__,
"numpy": np.__version__,
"platform": platform.platform(),
"device": device_info,
},
"compile": {
"optimizer": "Adam",
"learning_rate": float(LEARNING_RATE),
"loss_weights": loss_weights_dict,
"metrics": (
{
k: [name_of(m) for m in v]
for k, v in metrics_arg.items()
}
if metrics_arg
else {}
),
"physics_loss_weights": physics_loss_weights,
"lambda_offset": LAMBDA_OFFSET,
},
"hp_init": {
"quantiles": QUANTILES,
"subs_weights": SUBS_WEIGHTS,
"gwl_weights": GWL_WEIGHTS,
"attention_levels": ATTENTION_LEVELS,
"pde_mode": PDE_MODE,
"time_steps": int(TIME_STEPS),
"mode": MODE,
"model_init_params": serialize_subs_params(
subsmodel_params,
cfg,
),
"offset_mode": OFFSET_MODE,
"scaling_kwargs": {
"bounds": bounds_for_scaling,
"time_units": TIME_UNITS,
"coords_normalized": coords_normalized,
"coord_ranges": coord_ranges or {},
},
"identifiability_regime": ident,
},
"paths": {
"run_dir": run_dir,
"csv_log": csvlog_path,
"best_keras": best_keras_path,
"best_weights": best_weights_path,
"model_init_manifest": model_init_manifest_path,
},
}
summary_json_path = os.path.join(
run_dir,
f"{CITY_NAME}_{MODEL_NAME}_training_summary.json",
)
with open(summary_json_path, "w", encoding="utf-8") as f:
json.dump(training_summary, f, indent=2)
# Save final model
final_model_path = os.path.join(
run_dir,
f"{CITY_NAME}_{MODEL_NAME}_H{FORECAST_H}_final.keras",
)
try:
subs_model_inst.save(final_model_path)
training_summary["paths"]["final_keras"] = (
final_model_path
)
except Exception as e:
print("[Warn] Final save failed:", e)
# Plots (optional)
if not FAST_SENS:
history_groups = {
"Total Loss": ["total_loss"],
"Data vs Physics": [
"data_loss",
"physics_loss_scaled",
"physics_loss",
],
"Offset Controls": [
"lambda_offset",
"physics_mult",
],
"Physics Components": [
"consolidation_loss",
"gw_flow_loss",
"prior_loss",
"smooth_loss",
"mv_prior_loss",
"bounds_loss",
],
"Subsidence MAE": ["subs_pred_mae"],
"GWL MAE": ["gwl_pred_mae"],
}
yscales = {
"Total Loss": "log",
"Data vs Physics": "log",
"Physics Components": "log",
"Offset Controls": "linear",
"Subsidence MAE": "linear",
"GWL MAE": "linear",
}
plot_history_in(
history.history,
metrics=history_groups,
title=f"{MODEL_NAME} Training History",
yscale_settings=yscales,
layout="subplots",
savefig=os.path.join(
run_dir,
f"{CITY_NAME}_{MODEL_NAME.lower()}_history.png",
),
)
autoplot_geoprior_history(
history,
outdir=run_dir,
prefix=bundle_prefix,
style="default",
log_fn=print,
)
# Physical parameters CSV
phys_model_tag = "geoprior"
if MODEL_NAME == "PoroElasticSubsNet":
phys_model_tag = "poroelastic"
elif MODEL_NAME.startswith("HybridAttn"):
phys_model_tag = "hybridattn"
extract_physical_parameters(
subs_model_inst,
to_csv=True,
filename=(
f"{CITY_NAME}_{MODEL_NAME.lower()}_"
"physical_parameters.csv"
),
save_dir=run_dir,
model_name=phys_model_tag,
)
# ---------------------------
# Inference model
# ---------------------------
USE_IN_MEMORY_MODEL = bool(
cfg.get("USE_IN_MEMORY_MODEL", True)
)
USE_TF_SAVEDMODEL = bool(
cfg.get("USE_TF_SAVEDMODEL", False)
)
build_inputs = None
for xb, _ in val_dataset.take(1):
build_inputs = xb
break
def builder(manifest: dict):
    """Construct a fresh model instance from an init manifest.

    This is the ``builder`` callback handed to ``load_inference_model``.
    Values are looked up in the manifest first, and any missing section or
    key falls back to the value used by the current run:

    - ``manifest["dims"]``: input/output dimensions and forecast horizon.
    - ``manifest["config"]``: ``quantiles`` and ``pde_mode``.
    - ``manifest["config"]["geoprior"]``: initial values for the physics
      parameter objects (mv, kappa, gamma_w, h_ref).

    All other constructor kwargs are copied from ``subsmodel_params``. The
    identifiability regime is always the run's ``ident``.

    Parameters
    ----------
    manifest : dict
        Init manifest with the same layout as ``model_init_manifest``.
        ``None`` and empty sections are tolerated.

    Returns
    -------
    A new ``model_cls`` instance created with ``verbose=0``.
    """
    spec = manifest or {}
    dim_spec = spec.get("dims", {}) or {}
    cfg_spec = spec.get("config", {}) or {}
    geo_spec = cfg_spec.get("geoprior", {}) or {}

    def _geo(key, fallback):
        # Physics initial value: manifest entry wins, else run value.
        return float(geo_spec.get(key, fallback))

    def _dim(key, fallback):
        # Integer dimension: manifest entry wins, else run value.
        return int(dim_spec.get(key, fallback))

    # Shallow copy so the shared training kwargs are not mutated; the
    # physics parameter objects are re-created for every new model.
    init_kwargs = dict(subsmodel_params)
    init_kwargs["mv"] = LearnableMV(
        initial_value=_geo("init_mv", GEOPRIOR_INIT_MV)
    )
    init_kwargs["kappa"] = LearnableKappa(
        initial_value=_geo("init_kappa", GEOPRIOR_INIT_KAPPA)
    )
    init_kwargs["gamma_w"] = FixedGammaW(
        value=_geo("gamma_w", GEOPRIOR_GAMMA_W)
    )
    init_kwargs["h_ref"] = FixedHRef(
        value=_geo("h_ref_value", GEOPRIOR_H_REF_VALUE),
        mode=geo_spec.get("h_ref_mode", GEOPRIOR_H_REF_MODE),
    )

    return model_cls(
        static_input_dim=_dim("static_input_dim", s_dim_model),
        dynamic_input_dim=_dim("dynamic_input_dim", d_dim_model),
        future_input_dim=_dim("future_input_dim", f_dim_model),
        output_subsidence_dim=_dim(
            "output_subsidence_dim", OUT_S_DIM
        ),
        output_gwl_dim=_dim("output_gwl_dim", OUT_G_DIM),
        forecast_horizon=_dim("forecast_horizon", FORECAST_H),
        quantiles=cfg_spec.get("quantiles", QUANTILES),
        pde_mode=cfg_spec.get("pde_mode", PDE_MODE),
        identifiability_regime=ident,
        verbose=0,
        **init_kwargs,
    )
save_model(
model=subs_model_inst,
keras_path=(
best_tf_dir
if USE_TF_SAVEDMODEL
else best_keras_path
),
weights_path=best_weights_path,
manifest_path=model_init_manifest_path,
manifest=model_init_manifest,
overwrite=True,
use_tf_format=USE_TF_SAVEDMODEL,
)
if USE_IN_MEMORY_MODEL:
model_inf = subs_model_inst
print("[Info] Using in-memory model.")
elif USE_TF_SAVEDMODEL:
model_inf = load_model_from_tfv2(
best_tf_dir,
endpoint="serve",
custom_objects={
"GeoPriorSubsNet": GeoPriorSubsNet,
"PoroElasticSubsNet": PoroElasticSubsNet,
"LearnableMV": LearnableMV,
"LearnableKappa": LearnableKappa,
"FixedGammaW": FixedGammaW,
"FixedHRef": FixedHRef,
"make_weighted_pinball": make_weighted_pinball,
},
)
else:
model_inf = load_inference_model(
keras_path=best_keras_path,
weights_path=best_weights_path,
manifest_path=model_init_manifest_path,
custom_objects={
"GeoPriorSubsNet": GeoPriorSubsNet,
"PoroElasticSubsNet": PoroElasticSubsNet,
"LearnableMV": LearnableMV,
"LearnableKappa": LearnableKappa,
"FixedGammaW": FixedGammaW,
"FixedHRef": FixedHRef,
"make_weighted_pinball": make_weighted_pinball,
},
compile=False,
builder=builder,
build_inputs=build_inputs,
prefer_full_model=True,
log_fn=print,
use_in_memory_model=False,
)
if DEBUG and (model_inf is not subs_model_inst):
_ = debug_model_reload(
subs_model_inst,
model_inf,
val_dataset,
pred_key="subs_pred",
also_check=["subs_pred", "gwl_pred"],
top_weights=30,
log_fn=print,
)
# ---------------------------
# Interval calibrator
# ---------------------------
cal80 = None
val_for_cal = val_dataset
if cal_max_batches is not None:
val_for_cal = val_dataset.take(int(cal_max_batches))
if QUANTILES:
print("\nFit interval calibrator (80%)...")
cal80 = fit_interval_calibrator_on_val(
model_inf,
val_for_cal,
target=0.80,
q_values=QUANTILES,
)
np.save(
os.path.join(run_dir, "interval_factors_80.npy"),
cal80.factors_,
)
# ---------------------------
# Forecasting split
# ---------------------------
dataset_name_for_forecast = "ValidationSet_Fallback"
if X_test is not None and y_test is not None:
X_fore, y_fore = X_test, y_test
dataset_name_for_forecast = "TestSet"
else:
X_fore, y_fore = X_val, y_val
X_fore = ensure_input_shapes(
X_fore,
mode=MODE,
forecast_horizon=FORECAST_H,
)
# y_fore_fmt = map_targets_for_training(y_fore)
def _slice_any(x: Any, n: int) -> Any:
"""Slice nested inputs/targets along the sample axis."""
if x is None:
return None
if isinstance(x, dict):
return {k: _slice_any(v, n) for k, v in x.items()}
if isinstance(x, list | tuple):
return type(x)(_slice_any(v, n) for v in x)
try:
return x[:n]
except Exception:
return x
# Keep forecasting tensors and ds_eval consistent when we
# limit evaluation/export to a subset of batches.
if eval_max_batches is not None:
n_take = int(eval_max_batches) * int(BATCH_SIZE)
X_fore = _slice_any(X_fore, n_take)
y_fore = _slice_any(y_fore, n_take)
dataset_name_for_forecast += (
f"_Take{int(eval_max_batches)}B"
)
y_fore_fmt = map_targets_for_training(y_fore)
print(f"\nPredicting on {dataset_name_for_forecast}...")
pred_out = model_inf.predict(X_fore, verbose=0)
pred_dict = normalize_predict_output(
model_inf,
x=X_fore,
pred_out=pred_out,
required=("subs_pred", "gwl_pred"),
batch_n=32,
log_fn=print if DEBUG else None,
)
s_pred = pred_dict["subs_pred"]
h_pred = pred_dict["gwl_pred"]
if QUANTILES:
if cal80 is not None:
s_pred_cal = apply_calibrator_to_subs(
cal80,
s_pred,
q_values=QUANTILES,
)
else:
s_pred_cal = s_pred
predictions_for_formatter = {
"subs_pred": s_pred_cal,
"gwl_pred": h_pred,
}
else:
predictions_for_formatter = {
"subs_pred": s_pred,
"gwl_pred": h_pred,
}
# Point metrics
ev_point = evaluate_point_forecast(
model_inf,
predictions_for_formatter,
y_true_subs=y_fore_fmt["subs_pred"],
y_true_gwl=y_fore_fmt.get("gwl_pred"),
n_q=(len(QUANTILES) if QUANTILES else 3),
quantiles=(QUANTILES if QUANTILES else None),
use_physical=True,
scaler_info=scaler_info_dict,
subs_target_name=SUBS_SCALER_KEY,
gwl_target_name=GWL_SCALER_KEY,
)
metrics_point = ev_point.get("subs_metrics", {})
per_h_mae_dict = ev_point.get("subs_mae_h", {})
per_h_r2_dict = ev_point.get("subs_r2_h", {})
# CSV outputs
csv_eval = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_forecast_"
f"{dataset_name_for_forecast}_H"
f"{FORECAST_H}_eval.csv"
),
)
csv_future = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_forecast_"
f"{dataset_name_for_forecast}_H"
f"{FORECAST_H}_future.csv"
),
)
csv_eval_cal = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_forecast_"
f"{dataset_name_for_forecast}_H"
f"{FORECAST_H}_eval_calibrated.csv"
),
)
csv_future_cal = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_forecast_"
f"{dataset_name_for_forecast}_H"
f"{FORECAST_H}_future_calibrated.csv"
),
)
cal_stats_path = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_calibration_stats_"
f"{dataset_name_for_forecast}_H"
f"{FORECAST_H}.json"
),
)
metrics_json = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_eval_diagnostics_"
f"{dataset_name_for_forecast}_H{FORECAST_H}.json"
),
)
future_grid = np.arange(
FORECAST_START_YEAR,
FORECAST_START_YEAR + FORECAST_H,
dtype=float,
)
y_true_for_format = {
"subsidence": y_fore_fmt["subs_pred"],
"gwl": y_fore_fmt.get("gwl_pred"),
}
df_eval, df_future = format_and_forecast(
y_pred=predictions_for_formatter,
y_true=y_true_for_format,
coords=X_fore.get("coords", None),
quantiles=QUANTILES if QUANTILES else None,
target_name=SUBSIDENCE_COL,
scaler_target_name=SUBS_SCALER_KEY,
output_target_name="subsidence",
target_key_pred="subs_pred",
component_index=0,
scaler_info=scaler_info_dict,
coord_scaler=coord_scaler,
coord_columns=("coord_t", "coord_x", "coord_y"),
train_end_time=cfg.get("TRAIN_END_YEAR"),
forecast_start_time=FORECAST_START_YEAR,
forecast_horizon=FORECAST_H,
future_time_grid=future_grid,
eval_forecast_step=None,
sample_index_offset=0,
city_name=CITY_NAME,
model_name=MODEL_NAME,
dataset_name=dataset_name_for_forecast,
csv_eval_path=csv_eval,
csv_future_path=csv_future,
time_as_datetime=False,
time_format=None,
verbose=1,
eval_metrics=True,
metrics_quantile_interval=(0.1, 0.9),
metrics_per_horizon=True,
metrics_extra=["pss"],
metrics_savefile=metrics_json,
metrics_save_format=".json",
metrics_time_as_str=True,
value_mode="cumulative",
input_value_mode="cumulative",
output_unit="mm",
output_unit_from="m",
output_unit_mode="overwrite",
output_unit_col="subsidence_unit",
)
cal_stats = {
"skipped": True,
"reason": "fast_sensitivity",
}
if not FAST_SENS:
df_eval_cal2, df_future_cal2, cal_stats = (
calibrate_quantile_forecasts(
df_eval=df_eval,
df_future=df_future,
target_name="subsidence",
interval=(0.1, 0.9),
target_coverage=0.8,
use="auto",
tol=0.02,
f_max=5.0,
enforce_monotonic="cummax",
save_eval=csv_eval_cal,
save_future=csv_future_cal,
save_stats=cal_stats_path,
verbose=2,
)
)
if df_eval_cal2 is not None:
df_eval = df_eval_cal2
if df_future_cal2 is not None:
df_future = df_future_cal2
# Optional diagnostics JSON from DataFrame
diag_suffix = (
"calibrated" if not FAST_SENS else "uncalibrated"
)
if df_eval is not None and not df_eval.empty:
diag_path = os.path.join(
run_dir,
(
f"{CITY_NAME}_{MODEL_NAME}_eval_diagnostics_"
f"{dataset_name_for_forecast}_H{FORECAST_H}_"
f"{diag_suffix}.json"
),
)
_ = evaluate_forecast(
df_eval,
target_name="subsidence",
quantile_interval=(0.1, 0.9),
per_horizon=True,
extra_metrics=["pss"],
savefile=diag_path,
verbose=1,
)
# ---------------------------
# Evaluate() + physics payload
# ---------------------------
# ds_eval_full = test_dataset or val_dataset
# ds_eval = ds_eval_full
# if eval_max_batches is not None:
# ds_eval = ds_eval_full.take(int(eval_max_batches))
ds_eval = make_tf_dataset(
X_fore,
y_fore,
batch_size=BATCH_SIZE,
shuffle=False,
mode=MODE,
forecast_horizon=FORECAST_H,
check_npz_finite=True,
check_finite=True,
scan_finite_batches=None,
dynamic_feature_names=DYN_NAMES,
future_feature_names=FUT_NAMES,
)
if DEBUG and (not FAST_SENS):
_ = debug_val_interval(
model_inf,
ds_eval,
n_q=len(QUANTILES),
max_batches=2,
verbose=1,
)
if not USE_IN_MEMORY_MODEL:
model_inf.compile(
optimizer=tf.keras.optimizers.SGD(
learning_rate=0.0
),
loss=loss_arg,
loss_weights=lossw_arg,
metrics=metrics_compile,
**physics_loss_weights,
)
eval_results = {}
phys = {}
try:
eval_raw = model_inf.evaluate(
ds_eval,
return_dict=True,
verbose=1,
)
eval_results = _logs_to_py(eval_raw)
for k in (
"epsilon_prior",
"epsilon_cons",
"epsilon_gw",
):
if k in eval_results:
phys[k] = float(_to_py(eval_raw[k]))
except Exception as e:
print("[Warn] evaluate() failed:", e)
phys_npz_path = os.path.join(
run_dir,
f"{CITY_NAME}_phys_payload_run_val.npz",
)
try:
_ = model_inf.export_physics_payload(
ds_eval,
# max_batches=eval_max_batches,
save_path=phys_npz_path,
format="npz",
overwrite=True,
metadata={
"city": CITY_NAME,
"split": dataset_name_for_forecast,
"time_units": TIME_UNITS,
},
)
except Exception as e:
print("[Warn] export_physics_payload failed:", e)
# ---------------------------
# Interval metrics (scaled + phys)
# ---------------------------
cov80_uncal = None
cov80_cal = None
sharp80_uncal = None
sharp80_cal = None
cov80_uncal_phys = None
cov80_cal_phys = None
sharp80_uncal_phys = None
sharp80_cal_phys = None
censor_metrics = None
y_true = tf.convert_to_tensor(
y_fore_fmt["subs_pred"],
dtype=tf.float32,
)
if QUANTILES:
s_q = tf.convert_to_tensor(s_pred, dtype=tf.float32)
s_q_cal = tf.convert_to_tensor(
predictions_for_formatter["subs_pred"],
dtype=tf.float32,
)
cov80_uncal = float(
coverage80_fn(y_true, s_q).numpy()
)
sharp80_uncal = float(
sharpness80_fn(y_true, s_q).numpy()
)
cov80_cal = float(
coverage80_fn(y_true, s_q_cal).numpy()
)
sharp80_cal = float(
sharpness80_fn(y_true, s_q_cal).numpy()
)
# Physical space
y_true_np = y_true.numpy()
s_q_np = s_q.numpy()
s_q_cal_np = s_q_cal.numpy()
y_true_phys_np = inverse_scale_target(
y_true_np,
scaler_info=scaler_info_dict,
target_name=SUBS_SCALER_KEY,
)
s_q_phys_np = inverse_scale_target(
s_q_np,
scaler_info=scaler_info_dict,
target_name=SUBS_SCALER_KEY,
)
s_q_cal_phys_np = inverse_scale_target(
s_q_cal_np,
scaler_info=scaler_info_dict,
target_name=SUBS_SCALER_KEY,
)
y_true_phys_tf = tf.convert_to_tensor(
y_true_phys_np,
dtype=tf.float32,
)
s_q_phys_tf = tf.convert_to_tensor(
s_q_phys_np,
dtype=tf.float32,
)
s_q_cal_phys_tf = tf.convert_to_tensor(
s_q_cal_phys_np,
dtype=tf.float32,
)
cov80_uncal_phys = float(
coverage80_fn(y_true_phys_tf, s_q_phys_tf).numpy()
)
sharp80_uncal_phys = float(
sharpness80_fn(
y_true_phys_tf, s_q_phys_tf
).numpy()
)
cov80_cal_phys = float(
coverage80_fn(
y_true_phys_tf,
s_q_cal_phys_tf,
).numpy()
)
sharp80_cal_phys = float(
sharpness80_fn(
y_true_phys_tf,
s_q_cal_phys_tf,
).numpy()
)
# Censor mask (no model calls)
mask = None
if CENSOR_FLAG_IDX is not None:
mask_list = []
for xb, yb in with_progress(
ds_eval,
desc="Censor mask (no preds)",
):
H = tf.shape(yb["subs_pred"])[1]
mask_b = build_censor_mask(
xb,
H,
CENSOR_FLAG_IDX,
CENSOR_THRESH,
source=(CENSOR_MASK_SOURCE or "dynamic"),
reduce_time="any",
align="broadcast",
)
mask_list.append(mask_b)
mask = (
tf.concat(mask_list, axis=0)
if mask_list
else None
)
if (mask is not None) and QUANTILES:
med_idx = int(
np.argmin(
np.abs(np.asarray(QUANTILES, float) - 0.5)
)
)
# Median prediction in model (scaled) space
s_med = tf.convert_to_tensor(
predictions_for_formatter["subs_pred"],
dtype=tf.float32,
)[..., med_idx, :]
# --- IMPORTANT: align to ds_eval subset length -----------
n_mask = tf.shape(mask)[0]
y_true_sub = y_true[:n_mask]
s_med_sub = s_med[:n_mask]
mask_sub = mask[:n_mask] # (N,H,1) bool
# Inverse-scale only the subset (faster + consistent)
y_true_phys_np = inverse_scale_target(
y_true_sub.numpy(),
scaler_info=scaler_info_dict,
target_name=SUBS_SCALER_KEY,
)
s_med_phys_np = inverse_scale_target(
s_med_sub.numpy(),
scaler_info=scaler_info_dict,
target_name=SUBS_SCALER_KEY,
)
y_true_phys = tf.convert_to_tensor(
y_true_phys_np,
dtype=tf.float32,
)
s_med_phys = tf.convert_to_tensor(
s_med_phys_np,
dtype=tf.float32,
)
mask_f = tf.cast(mask_sub, tf.float32)
num_cens = tf.reduce_sum(mask_f) + 1e-8
num_unc = tf.reduce_sum(1.0 - mask_f) + 1e-8
abs_err = tf.abs(y_true_phys - s_med_phys)
mae_cens = tf.reduce_sum(abs_err * mask_f) / num_cens
mae_unc = (
tf.reduce_sum(abs_err * (1.0 - mask_f)) / num_unc
)
censor_metrics = {
"flag_name": CENSOR_FLAG_NAME,
"threshold": float(CENSOR_THRESH),
"mae_censored": float(mae_cens.numpy()),
"mae_uncensored": float(mae_unc.numpy()),
"subset_batches": int(eval_max_batches)
if eval_max_batches is not None
else None,
}
# if (mask is not None) and QUANTILES:
# med_idx = int(
# np.argmin(
# np.abs(np.asarray(QUANTILES, float) - 0.5)
# )
# )
# s_med = tf.convert_to_tensor(
# predictions_for_formatter["subs_pred"],
# dtype=tf.float32,
# )[..., med_idx, :]
# y_true_phys_np = inverse_scale_target(
# y_true.numpy(),
# scaler_info=scaler_info_dict,
# target_name=SUBS_SCALER_KEY,
# )
# s_med_phys_np = inverse_scale_target(
# s_med.numpy(),
# scaler_info=scaler_info_dict,
# target_name=SUBS_SCALER_KEY,
# )
# y_true_phys = tf.convert_to_tensor(
# y_true_phys_np,
# dtype=tf.float32,
# )
# s_med_phys = tf.convert_to_tensor(
# s_med_phys_np,
# dtype=tf.float32,
# )
# mask_f = tf.cast(mask, tf.float32)
# num_cens = tf.reduce_sum(mask_f) + 1e-8
# num_unc = tf.reduce_sum(1.0 - mask_f) + 1e-8
# abs_err = tf.abs(y_true_phys - s_med_phys)
# mae_cens = tf.reduce_sum(abs_err * mask_f) / num_cens
mae_unc = (
tf.reduce_sum(abs_err * (1.0 - mask_f)) / num_unc
)
censor_metrics = {
"flag_name": CENSOR_FLAG_NAME,
"threshold": float(CENSOR_THRESH),
"mae_censored": float(mae_cens.numpy()),
"mae_uncensored": float(mae_unc.numpy()),
}
# ---------------------------
# Build + save eval payload
# ---------------------------
stamp = dt.datetime.now().strftime("%Y%m%d-%H%M%S")
payload = {
"timestamp": stamp,
"tf_version": tf.__version__,
"numpy_version": np.__version__,
"quantiles": QUANTILES,
"horizon": int(FORECAST_H),
"batch_size": int(BATCH_SIZE),
"metrics_evaluate": {
k: _to_py(v)
for k, v in (eval_results or {}).items()
},
"physics_diagnostics": phys,
}
if QUANTILES:
payload["interval_calibration"] = {
"target": 0.80,
"factors_per_horizon": (
getattr(cal80, "factors_", None).tolist()
if (
cal80 is not None
and hasattr(cal80, "factors_")
)
else None
),
"factors_per_horizon_from_cal_stats": cal_stats,
"coverage80_uncalibrated": cov80_uncal,
"coverage80_calibrated": cov80_cal,
"sharpness80_uncalibrated": sharp80_uncal,
"sharpness80_calibrated": sharp80_cal,
"coverage80_uncalibrated_phys": cov80_uncal_phys,
"coverage80_calibrated_phys": cov80_cal_phys,
"sharpness80_uncalibrated_phys": sharp80_uncal_phys,
"sharpness80_calibrated_phys": sharp80_cal_phys,
}
if censor_metrics is not None:
payload["censor_stratified"] = censor_metrics
if metrics_point:
payload["point_metrics"] = {
"mae": metrics_point.get("mae"),
"mse": metrics_point.get("mse"),
"r2": metrics_point.get("r2"),
}
if per_h_mae_dict:
payload.setdefault("per_horizon", {})
payload["per_horizon"]["mae"] = per_h_mae_dict
if per_h_r2_dict:
payload.setdefault("per_horizon", {})
payload["per_horizon"]["r2"] = per_h_r2_dict
try:
payload = convert_eval_payload_units(
payload,
cfg,
mode=units_mode,
scope=units_scope,
)
except Exception as e:
print("[Warn] unit conversion skipped:", e)
json_out = os.path.join(
run_dir,
f"geoprior_eval_phys_{stamp}.json",
)
with open(json_out, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
json_out_interp = os.path.join(
run_dir,
f"geoprior_eval_phys_{stamp}_interpretable.json",
)
try:
_ = postprocess_eval_json(
json_out,
scope="all",
out_path=json_out_interp,
overwrite=True,
add_rmse=True,
)
except Exception:
pass
# ---------------------------
# Ablation record
# ---------------------------
ABLCFG = {
"PDE_MODE_CONFIG": PDE_MODE,
"GEOPRIOR_USE_EFFECTIVE_H": GEOPRIOR_USE_EFF_H,
"GEOPRIOR_KAPPA_MODE": GEOPRIOR_KAPPA_MODE,
"GEOPRIOR_HD_FACTOR": GEOPRIOR_HD_FACTOR,
"LAMBDA_CONS": LAMBDA_CONS,
"LAMBDA_GW": LAMBDA_GW,
"LAMBDA_PRIOR": LAMBDA_PRIOR,
"LAMBDA_SMOOTH": LAMBDA_SMOOTH,
"LAMBDA_BOUNDS": LAMBDA_BOUNDS,
"LAMBDA_MV": LAMBDA_MV,
"LAMBDA_Q": LAMBDA_Q,
}
ival = payload.get("interval_calibration", {}) or {}
cal_stats2 = (
ival.get("factors_per_horizon_from_cal_stats", {})
or {}
)
ev_after = cal_stats2.get("eval_after", {}) or {}
def _pick_first(*vals):
for v in vals:
if v is not None:
return v
return None
abl_cov = _pick_first(
ev_after.get("coverage"),
ival.get("coverage80_calibrated_phys"),
ival.get("coverage80_calibrated"),
ival.get("coverage80_uncalibrated_phys"),
ival.get("coverage80_uncalibrated"),
)
abl_sharp = _pick_first(
ev_after.get("sharpness"),
ival.get("sharpness80_calibrated_phys"),
ival.get("sharpness80_calibrated"),
ival.get("sharpness80_uncalibrated_phys"),
ival.get("sharpness80_uncalibrated"),
)
eval_mae = (payload.get("point_metrics", {}) or {}).get(
"mae"
)
eval_mse = (payload.get("point_metrics", {}) or {}).get(
"mse"
)
eval_r2 = (payload.get("point_metrics", {}) or {}).get(
"r2"
)
eval_rmse = None
if eval_mse is not None:
try:
eval_rmse = float(np.sqrt(float(eval_mse)))
except Exception:
eval_rmse = None
save_ablation_record(
outdir=run_dir,
city=CITY_NAME,
model_name=MODEL_NAME,
cfg=ABLCFG,
eval_dict={
"r2": (
float(eval_r2)
if eval_r2 is not None
else None
),
"mse": (
float(eval_mse)
if eval_mse is not None
else None
),
"mae": (
float(eval_mae)
if eval_mae is not None
else None
),
"rmse": (
float(eval_rmse)
if eval_rmse is not None
else None
),
"coverage80": (
float(abl_cov)
if abl_cov is not None
else None
),
"sharpness80": (
float(abl_sharp)
if abl_sharp is not None
else None
),
"units": payload.get("units", {})
if isinstance(payload, dict)
else {},
},
phys_diag=(phys or {}),
per_h_mae=per_h_mae_dict,
per_h_r2=per_h_r2_dict,
)
# DONE marker
def _write_done_marker(
run_path: Path,
*,
city: str,
pde_mode: str,
lambda_cons: float,
lambda_prior: float,
tag: str | None = None,
) -> None:
run_path.mkdir(parents=True, exist_ok=True)
(run_path / "DONE").write_text("", encoding="utf-8")
d = {
"city": city,
"pde_mode": pde_mode,
"lambda_cons": float(lambda_cons),
"lambda_prior": float(lambda_prior),
"run_tag": tag,
}
tmp = run_path / "DONE.json.tmp"
outp = run_path / "DONE.json"
tmp.write_text(json.dumps(d), encoding="utf-8")
tmp.replace(outp)
_write_done_marker(
Path(run_dir),
city=CITY_NAME,
pde_mode=PDE_MODE,
lambda_cons=LAMBDA_CONS,
lambda_prior=LAMBDA_PRIOR,
tag=run_tag2,
)
# Physics plots (optional)
if not FAST_SENS:
try:
phys_payload, _ = load_physics_payload(
phys_npz_path
)
plot_physics_values_in(
phys_payload,
dataset=ds_eval,
keys=[
"cons_res_vals",
"log10_tau",
"log10_tau_prior",
"K",
"Ss",
"Hd",
],
mode="map",
transform=None,
savefig=os.path.join(
run_dir, "phys_maps.png"
),
)
except Exception:
pass
try:
plot_eval_future(
df_eval=df_eval,
df_future=df_future,
target_name=SUBSIDENCE_COL,
quantiles=QUANTILES,
spatial_cols=("coord_x", "coord_y"),
time_col="coord_t",
eval_years=[FORECAST_START_YEAR - 1],
future_years=future_grid,
eval_view_quantiles=[0.5],
future_view_quantiles=QUANTILES,
spatial_mode="hexbin",
hexbin_gridsize=40,
savefig_prefix=os.path.join(
run_dir,
f"{CITY_NAME}_subsidence_view",
),
save_fmts=[".png", ".pdf"],
show=False,
verbose=1,
)
except Exception:
pass
try:
save_all_figures(
output_dir=run_dir,
prefix=f"{CITY_NAME}_{MODEL_NAME}_plot_",
fmts=[".png", ".pdf"],
)
except Exception:
pass
print(
f"\n---- {CITY_NAME.upper()} {MODEL_NAME} RUN DONE ----\n"
f"Artifacts -> {run_dir}\n"
)
return run_dir