# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""
run_sensitivity.py
Driver to run a (lambda_cons, lambda_prior) sensitivity grid
for GeoPriorSubsNet using the existing Stage-2 training script.
This script calls sensitivity.py multiple times with environment
overrides. Each run should write its own ablation record
(entry in ablation_records/ablation_record.jsonl), which your
make_supp_figS6_ablations.py later aggregates.
Core overrides (expected by sensitivity.py)
--------------------------------------------
- EPOCHS_OVERRIDE
- PDE_MODE_OVERRIDE
- LAMBDA_CONS_OVERRIDE
- LAMBDA_PRIOR_OVERRIDE
Optional "deconfounding" overrides (safe to export even if
sensitivity.py ignores some of them; you can wire them later)
-----------------------------------------------------------
- TRAINING_STRATEGY_OVERRIDE
- Q_POLICY_OVERRIDE
- SUBS_RESID_POLICY_OVERRIDE
- ALLOW_SUBS_RESIDUAL_OVERRIDE
- LAMBDA_Q_OVERRIDE
- PHYSICS_WARMUP_STEPS_OVERRIDE
- PHYSICS_RAMP_STEPS_OVERRIDE
- LAMBDA_GW_OVERRIDE
- LAMBDA_SMOOTH_OVERRIDE
- LAMBDA_BOUNDS_OVERRIDE
- LAMBDA_MV_OVERRIDE
Driver to run a (lambda_cons, lambda_prior) sensitivity grid
for GeoPriorSubsNet using the Stage-2 sensitivity script.
Resume mechanism
----------------
On restart, the script scans existing ablation_record.jsonl files
under the results directory and skips runs that already finished.
A run is considered "done" if an ablation record exists containing
``pde_mode``, ``lambda_cons``, and ``lambda_prior`` (and matching
CITY when available).
Usage
------
set CITY=zhongshan
python nat.com/run_sensitivity.py --epochs 20
to force rerun everything:
python nat.com/run_sensitivity.py --epochs 20 --no-resume
results live elsewhere:
python nat.com/run_sensitivity.py --epochs 20 \
--scan-root F:/repositories/geoprior-learn/results/zhongshan
python nat.com/run_sensitivity.py --epochs 20 --inprocess --fast
python nat.com/run_sensitivity.py --epochs 20 --gold --eval-max-batches 50 --fast
python nat.com/run_sensitivity.py --epochs 10 --fast --n-jobs -1
python nat.com/run_sensitivity.py --gold --epochs 10 --fast --threads 20
"""
from __future__ import annotations
import argparse
import ast
import itertools
import json
import os
import runpy
import shutil
import subprocess
import sys
from collections.abc import Iterable
from concurrent.futures import (
ThreadPoolExecutor,
as_completed,
)
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
)
try:
from .sensitivity_lib import (
build_context,
cleanup_between_runs,
)
from .sensitivity_lib import (
run_one as run_one_gold,
)
except ImportError: # pragma: no cover
from sensitivity_lib import (
build_context,
cleanup_between_runs,
)
from sensitivity_lib import (
run_one as run_one_gold,
)
from geoprior.utils import (
apply_gpu_env,
apply_tf_threading,
apply_thread_env,
default_results_dir,
pick_gpu_id,
resolve_device,
resolve_gpu_ids,
resolve_n_jobs,
threads_per_job,
)
from geoprior.utils.nat_utils import (
ensure_config_json,
get_config_paths,
)
# Default training script: the "_sensitivity.py" file that sits next to
# this module (used as the default of --train-script).
TRAIN_SCRIPT_DEFAULT = Path(__file__).with_name(
    "_sensitivity.py"
)
# 0 0.05 0.2 1.0
# Default lambda_cons values swept when --lcons is not given.
DEFAULT_LCONS: list[float] = [
    0.0,
    0.01,
    0.05,
    0.1,
    0.2,
    0.5,
    1.0,
]
# Default lambda_prior values swept when --lprior is not given.
DEFAULT_LPRIOR: list[float] = [
    0.0,
    0.01,
    0.05,
    0.1,
    0.2,
    0.5,
    1.0,
]
# Default PDE modes swept when --pde-modes is not given.
DEFAULT_PDE_MODES: list[str] = ["both"]  # ["none", "both" ]
def _fmt_float(x: float) -> str:
# Stable-ish string key for floats in configs.
# Uses "g" to match your tag style.
try:
return f"{float(x):g}"
except Exception:
return str(x)
def _norm_mode(x: str) -> str:
return str(x).strip().lower()
def _canon_pde_mode(x: str) -> str:
m = str(x).strip().lower()
if m in {"both", "on", "true"}:
return "on"
if m in {"none", "off", "false"}:
return "none"
return m
@dataclass(frozen=True)
class RunSpec:
    """One point of the sensitivity grid.

    Holds a PDE mode together with a ``(lambda_cons, lambda_prior)`` pair
    and derives the string forms the driver uses for resume matching,
    logging, and the ``RUN_TAG`` value.
    """

    pde_mode: str
    lambda_cons: float
    lambda_prior: float

    def key(self) -> str:
        """Return the resume key, built from the canonical PDE mode."""
        fields = (
            ("pde", _canon_pde_mode(self.pde_mode)),
            ("lcons", _fmt_float(self.lambda_cons)),
            ("lprior", _fmt_float(self.lambda_prior)),
        )
        return "|".join(f"{name}={value}" for name, value in fields)

    def tag(self) -> str:
        """Return a human-readable label (PDE mode shown as given)."""
        fields = (
            ("pde", str(self.pde_mode)),
            ("lcons", _fmt_float(self.lambda_cons)),
            ("lprior", _fmt_float(self.lambda_prior)),
        )
        return ", ".join(f"{name}={value}" for name, value in fields)

    def run_tag(self) -> str:
        """Return a short filesystem-friendly tag ("." replaced by "p")."""
        mode = _norm_mode(self.pde_mode)
        cons = _fmt_float(self.lambda_cons).replace(".", "p")
        prior = _fmt_float(self.lambda_prior).replace(".", "p")
        return "__".join((f"pde_{mode}", f"lc_{cons}", f"lp_{prior}"))
def parse_args(
    argv: list[str] | None = None,
) -> argparse.Namespace:
    """Parse the command line of the grid driver.

    Parameters
    ----------
    argv : list[str] | None
        Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options (grid definition, execution mode, optional
        overrides, resume and runner controls).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run a lambda_cons / lambda_prior sensitivity "
            "grid using the stage2 sensitivity script."
        )
    )
    add = parser.add_argument

    def flag(name: str, help_text: str) -> None:
        # Boolean switch: False unless present on the command line.
        add(name, action="store_true", help=help_text)

    # Grid definition
    add(
        "--train-script",
        type=str,
        default=str(TRAIN_SCRIPT_DEFAULT),
        help="Path to training script (sensitivity.py).",
    )
    add(
        "--epochs",
        type=int,
        default=20,
        help="Epochs per run (short sensitivity runs).",
    )
    add(
        "--pde-modes",
        type=str,
        nargs="+",
        default=DEFAULT_PDE_MODES,
        help="PDE modes to sweep (e.g. none both).",
    )
    add(
        "--lcons",
        type=float,
        nargs="+",
        default=DEFAULT_LCONS,
        help="Grid for lambda_cons.",
    )
    add(
        "--lprior",
        type=float,
        nargs="+",
        default=DEFAULT_LPRIOR,
        help="Grid for lambda_prior.",
    )

    # Per-run behaviour
    flag(
        "--no-early-stopping",
        "Disable EarlyStopping in sensitivity.py.",
    )
    flag(
        "--fast",
        "Skip plotting + calibration extras in sensitivity.py.",
    )
    add(
        "--eval-max-batches",
        type=int,
        default=None,
        help="Limit eval/export batches per run (speeds grids).",
    )
    add(
        "--batch-size",
        type=int,
        default=None,
        help="Override BATCH_SIZE for stage-2 runs.",
    )

    # Execution mode and resources
    flag(
        "--inprocess",
        "Run sensitivity.py in-process (no subprocess).",
    )
    flag(
        "--gold",
        "Gold mode: run sensitivity in-process via sensitivity_lib "
        "(reuses NPZ + tf.data pipelines). Fastest for grids.",
    )
    add(
        "--n-jobs",
        type=int,
        default=1,
        help="Parallel grid runs; -1=all CPUs.",
    )
    add(
        "--threads",
        type=int,
        default=0,
        help="Threads per run (0=auto).",
    )
    add(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "gpu"],
        help="Device policy for runs.",
    )
    add(
        "--gpu-ids",
        type=str,
        nargs="*",
        default=None,
        help="Explicit GPU ids, e.g. --gpu-ids 0 1",
    )
    flag("--gpu-allow-growth", "Enable TF GPU allow-growth.")

    # Optional deconfounding knobs
    add(
        "--strategy",
        type=str,
        default="data_first",
        choices=["data_first", "physics_first"],
        help="Training strategy override.",
    )
    flag("--disable-q", "Export overrides to force Q always off.")
    flag(
        "--disable-subs-resid",
        "Export overrides to disable subs residual.",
    )
    flag("--no-physics-ramp", "Set physics warmup/ramp steps to 0.")
    add(
        "--physics-warmup-steps",
        type=int,
        default=None,
        help="Override physics warmup steps.",
    )
    add(
        "--physics-ramp-steps",
        type=int,
        default=None,
        help="Override physics ramp steps.",
    )
    # The five optional lambda overrides share one shape.
    for term in ("gw", "smooth", "bounds", "mv", "q"):
        add(
            f"--lambda-{term}",
            type=float,
            default=None,
            help=f"Optional override for lambda_{term}.",
        )

    # Resume controls
    flag("--no-resume", "Do NOT skip completed runs.")
    add(
        "--scan-root",
        type=str,
        default=None,
        help=(
            "Root directory to scan for prior "
            "ablation_record.jsonl files. "
            "Default: results_dir/CITY."
        ),
    )
    add(
        "--state-file",
        type=str,
        default=None,
        help=(
            "Optional JSON state file to write progress "
            "(default: <scan_root>/lambda_sensitivity_state.json)."
        ),
    )

    # Runner controls
    add(
        "--start",
        type=int,
        default=0,
        help="Start index in the remaining grid.",
    )
    add(
        "--limit",
        type=int,
        default=None,
        help="Max number of runs (after --start).",
    )
    flag(
        "--shuffle",
        "Shuffle run order (deterministic with --seed).",
    )
    add(
        "--seed",
        type=int,
        default=42,
        help="Seed for shuffling.",
    )
    flag("--continue-on-error", "Continue grid even if a run fails.")
    flag("--dry-run", "Print planned commands without executing.")

    return parser.parse_args(argv)
def build_grid(
    pde_modes: Iterable[str],
    lcons: Iterable[float],
    lprior: Iterable[float],
) -> list[RunSpec]:
    """Expand the sweep axes into a flat list of run specs.

    The order is PDE mode (outermost), then ``lcons``, then ``lprior``
    (innermost). Values are coerced to ``str`` / ``float``.
    """
    return [
        RunSpec(
            pde_mode=str(mode),
            lambda_cons=float(cons),
            lambda_prior=float(prior),
        )
        for mode in pde_modes
        for cons, prior in itertools.product(lcons, lprior)
    ]
def maybe_shuffle(
    runs: list[RunSpec],
    *,
    shuffle: bool,
    seed: int,
) -> list[RunSpec]:
    """Optionally reorder ``runs`` with a seeded, deterministic shuffle.

    When ``shuffle`` is false, or there are fewer than two runs, the input
    list object itself is returned. Otherwise a new list is built by
    swapping indices from the end towards the front, with swap positions
    drawn from a 32-bit xorshift sequence seeded by ``seed``.
    """
    if not (shuffle and len(runs) > 1):
        return runs

    mask = 0xFFFFFFFF
    order = list(range(len(runs)))
    state = int(seed) & mask
    for pos in reversed(range(1, len(order))):
        # One xorshift step (13 / 17 / 5), kept within 32 bits.
        state ^= (state << 13) & mask
        state ^= (state >> 17) & mask
        state ^= (state << 5) & mask
        pick = state % (pos + 1)
        order[pos], order[pick] = order[pick], order[pos]
    return [runs[k] for k in order]
def apply_runner_slicing(
    runs: list[RunSpec],
    *,
    start: int,
    limit: int | None,
) -> list[RunSpec]:
    """Select the ``[start, start + limit)`` window of ``runs``.

    A negative ``start`` is treated as 0. ``limit=None`` keeps everything
    after ``start``; a ``limit`` of zero or less selects nothing.
    """
    tail = runs[max(start, 0):]
    if limit is None:
        return tail
    return tail[:limit] if limit > 0 else []
def _default_scan_root(city: str) -> Path:
# Prefer geoprior's default_results_dir if available.
# Fall back to ./results.
try:
root = Path(default_results_dir())
except Exception:
root = Path.cwd() / "results"
if city and city != "<unknown>":
return root / city
return root
def _iter_ablation_jsonl_files(root: Path) -> Iterable[Path]:
if not root.exists():
return []
# Typical layout: .../ablation_records/ablation_record.jsonl
return root.rglob("ablation_record.jsonl")
def _load_completed_keys(
    scan_root: Path,
    *,
    city: str,
) -> set[str]:
    """Collect resume keys from ``ablation_record.jsonl`` files.

    Parameters
    ----------
    scan_root : Path
        Directory searched recursively for ablation record files.
    city : str
        Records carrying a ``city`` field are kept only when it matches
        this value case-insensitively; records without one are accepted.

    Returns
    -------
    set[str]
        ``RunSpec.key()`` of every usable record.

    Notes
    -----
    Scanning is best-effort: unreadable files, blank or malformed lines,
    non-object JSON values, and records with missing or non-numeric
    fields are skipped one at a time, so a single bad record never hides
    the remaining records of the same file.
    """
    wanted_city = str(city).lower()
    done: set[str] = set()
    for fp in _iter_ablation_jsonl_files(scan_root):
        try:
            with fp.open("r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        rec = json.loads(line)
                    except ValueError:
                        continue
                    if not isinstance(rec, dict):
                        continue
                    # Filter by city when present
                    rec_city = rec.get("city")
                    if (
                        rec_city is not None
                        and str(rec_city).lower() != wanted_city
                    ):
                        continue
                    # Check the RAW pde_mode for None before
                    # canonicalizing: _canon_pde_mode(None) yields
                    # "none", which would otherwise mark records that
                    # lack pde_mode as completed "none" runs.
                    pde = rec.get("pde_mode")
                    lc = rec.get("lambda_cons")
                    lp = rec.get("lambda_prior")
                    if pde is None or lc is None or lp is None:
                        continue
                    try:
                        spec = RunSpec(
                            pde_mode=_canon_pde_mode(pde),
                            lambda_cons=float(lc),
                            lambda_prior=float(lp),
                        )
                    except (TypeError, ValueError):
                        continue
                    done.add(spec.key())
        except (OSError, UnicodeError):
            # Unreadable / undecodable file: skip it, keep scanning.
            continue
    return done
def _iter_done_json(scan_root: Path) -> Iterable[Path]:
if not scan_root.exists():
return []
return scan_root.rglob("DONE.json")
def _load_completed_keys_from_done(
    scan_root: Path,
    *,
    city: str,
) -> set[str]:
    """Collect resume keys from ``DONE.json`` marker files.

    Parameters
    ----------
    scan_root : Path
        Directory searched recursively for ``DONE.json`` files.
    city : str
        Markers carrying a ``city`` field are kept only when it matches
        this value case-insensitively; markers without one are accepted.

    Returns
    -------
    set[str]
        ``RunSpec.key()`` of every usable marker.

    Notes
    -----
    Scanning is best-effort: unreadable or malformed files, non-object
    JSON values, and markers with missing or non-numeric fields are
    skipped instead of aborting the resume scan.
    """
    wanted_city = str(city).lower()
    done: set[str] = set()
    for fp in _iter_done_json(scan_root):
        try:
            rec = json.loads(fp.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers both JSON and text-decoding errors.
            continue
        if not isinstance(rec, dict):
            continue
        rec_city = rec.get("city")
        if rec_city is not None and str(rec_city).lower() != wanted_city:
            continue
        pde = rec.get("pde_mode")
        lc = rec.get("lambda_cons")
        lp = rec.get("lambda_prior")
        if pde is None or lc is None or lp is None:
            continue
        try:
            spec = RunSpec(
                pde_mode=_canon_pde_mode(pde),
                lambda_cons=float(lc),
                lambda_prior=float(lp),
            )
        except (TypeError, ValueError):
            # Non-numeric lambda value: not a usable marker.
            continue
        done.add(spec.key())
    return done
def _save_state(
state_path: Path,
*,
city: str,
scan_root: Path,
completed_n: int,
last_key: str | None,
) -> None:
payload = {
"city": city,
"scan_root": str(scan_root),
"completed_n": int(completed_n),
"last_completed_key": last_key,
}
try:
state_path.parent.mkdir(parents=True, exist_ok=True)
state_path.write_text(
json.dumps(payload, indent=2),
encoding="utf-8",
)
except:
# State is optional: never fail the run.
return
def _worker_banner(
*,
mode: str,
job_i: int,
n_jobs: int,
pool: int,
run_tag: str,
device: str,
gpu_id: str | None,
) -> None:
d = str(device).lower().strip()
gid = "-" if gpu_id is None else str(gpu_id)
prefix = f"[{mode}]"
if d == "gpu":
msg = (
f"{prefix} job {job_i + 1}/{n_jobs} | "
f"pool={pool} | RUN_TAG={run_tag} | "
f"GPU={gid}"
)
else:
msg = (
f"{prefix} job {job_i + 1}/{n_jobs} | "
f"pool={pool} | RUN_TAG={run_tag} | "
f"CPU"
)
print(msg, flush=True)
def make_env(
    base_env: dict[str, str],
    *,
    epochs: int,
    spec: RunSpec,
    strategy: str,
    disable_q: bool,
    disable_subs_resid: bool,
    no_physics_ramp: bool,
    physics_warmup_steps: int | None,
    physics_ramp_steps: int | None,
    lambda_gw: float | None,
    lambda_smooth: float | None,
    lambda_bounds: float | None,
    lambda_mv: float | None,
    lambda_q: float | None,
    no_early_stopping: bool,
    fast: bool,
    eval_max_batches: int | None,
    batch_size: int | None = None,
) -> dict[str, str]:
    """Build the environment (override variables) for one grid run.

    Parameters
    ----------
    base_env : dict[str, str]
        Environment to start from; it is copied, never mutated.
    epochs : int
        Exported as ``EPOCHS_OVERRIDE``.
    spec : RunSpec
        Grid point; supplies the PDE mode, both lambdas and ``RUN_TAG``.
    strategy : str
        Exported as ``TRAINING_STRATEGY_OVERRIDE``.
    disable_q : bool
        Forces the Q policy to ``always_off`` and ``LAMBDA_Q`` to 0
        (takes precedence over ``lambda_q``).
    disable_subs_resid : bool
        Turns the subsidence-residual policy off.
    no_physics_ramp : bool
        Sets warmup and ramp steps to 0 (takes precedence over the two
        explicit step values).
    physics_warmup_steps, physics_ramp_steps : int | None
        Exported only when not ``None``.
    lambda_gw, lambda_smooth, lambda_bounds, lambda_q : float | None
        Exported only when not ``None``.
    lambda_mv : float | None
        Exported when given; otherwise ``LAMBDA_MV_OVERRIDE`` is "0.0".
    no_early_stopping, fast : bool
        Exported as "1"/"0" flags.
    eval_max_batches : int | None
        Exported as ``SENS_EVAL_MAX_BATCHES`` when not ``None``.
    batch_size : int | None, optional
        Exported as ``BATCH_SIZE_OVERRIDE`` when not ``None``. It defaults
        to ``None`` so callers that do not forward a batch size remain
        valid instead of failing with a missing-argument ``TypeError``.

    Returns
    -------
    dict[str, str]
        New environment mapping with all overrides applied.
    """
    env = dict(base_env)

    # Core sweep
    env["PDE_MODE_OVERRIDE"] = str(spec.pde_mode)
    env["EPOCHS_OVERRIDE"] = str(int(epochs))
    env["LAMBDA_CONS_OVERRIDE"] = str(spec.lambda_cons)
    env["LAMBDA_PRIOR_OVERRIDE"] = str(spec.lambda_prior)
    env["SENS_WORKER_BANNER"] = "1"

    # Traceability
    env["RUN_TAG"] = spec.run_tag()

    env["DISABLE_EARLY_STOPPING"] = "1" if no_early_stopping else "0"
    env["FAST_SENSITIVITY"] = "1" if fast else "0"
    if eval_max_batches is not None:
        env["SENS_EVAL_MAX_BATCHES"] = str(int(eval_max_batches))
    if batch_size is not None:
        env["BATCH_SIZE_OVERRIDE"] = str(int(batch_size))

    # Optional controls
    env["TRAINING_STRATEGY_OVERRIDE"] = str(strategy)
    if disable_q:
        env["Q_POLICY_OVERRIDE"] = "always_off"
        env["LAMBDA_Q_OVERRIDE"] = "0.0"
    elif lambda_q is not None:
        env["LAMBDA_Q_OVERRIDE"] = str(lambda_q)

    if disable_subs_resid:
        env["SUBS_RESID_POLICY_OVERRIDE"] = "always_off"
        env["ALLOW_SUBS_RESIDUAL_OVERRIDE"] = "0"

    if no_physics_ramp:
        env["PHYSICS_WARMUP_STEPS_OVERRIDE"] = "0"
        env["PHYSICS_RAMP_STEPS_OVERRIDE"] = "0"
    else:
        if physics_warmup_steps is not None:
            env["PHYSICS_WARMUP_STEPS_OVERRIDE"] = str(
                int(physics_warmup_steps)
            )
        if physics_ramp_steps is not None:
            env["PHYSICS_RAMP_STEPS_OVERRIDE"] = str(
                int(physics_ramp_steps)
            )

    if lambda_gw is not None:
        env["LAMBDA_GW_OVERRIDE"] = str(lambda_gw)
    if lambda_smooth is not None:
        env["LAMBDA_SMOOTH_OVERRIDE"] = str(lambda_smooth)
    if lambda_bounds is not None:
        env["LAMBDA_BOUNDS_OVERRIDE"] = str(lambda_bounds)

    # The MV weight is always zeroed; lambda_mv is zeroed unless the
    # caller supplied an explicit value.
    env["MV_WEIGHT_OVERRIDE"] = "0.0"
    env["LAMBDA_MV_OVERRIDE"] = (
        "0.0" if lambda_mv is None else str(lambda_mv)
    )

    # Fixed logging / diagnostics settings for grid runs.
    env["VERBOSE_OVERRIDE"] = "1"
    env["AUDIT_STAGES_OVERRIDE"] = "off"
    env["DEBUG_OVERRIDE"] = "0"
    env["LOG_Q_DIAGNOSTICS_OVERRIDE"] = "0"
    return env
def run_one_script(
    train_script: Path,
    *,
    env: dict[str, str],
    dry_run: bool,
    inprocess: bool,
) -> None:
    """Execute the training script once with the given environment.

    * ``dry_run``: only print the command that would be executed.
    * subprocess mode (default): run the script with the current Python
      interpreter and ``env``; a non-zero exit raises
      ``subprocess.CalledProcessError``.
    * ``inprocess``: run the script via ``runpy`` as ``__main__`` with
      ``env`` merged into ``os.environ``; the previous environment is
      restored afterwards, followed by a best-effort Keras session reset
      and a garbage-collection pass.
    """
    command = [sys.executable, str(train_script)]
    if dry_run:
        print("[DryRun] " + " ".join(command))
        return

    if inprocess:
        saved_environ = os.environ.copy()
        try:
            os.environ.update(env)
            runpy.run_path(str(train_script), run_name="__main__")
        finally:
            # Restore the exact pre-run environment.
            os.environ.clear()
            os.environ.update(saved_environ)
            try:
                import tensorflow as tf

                tf.keras.backend.clear_session()
            except Exception:
                pass
            import gc

            gc.collect()
        return

    subprocess.run(command, env=env, check=True)
def _parse_override_value(text: str) -> Any:
s = str(text).strip()
if not s:
return s
low = s.lower()
if low in {"true", "false"}:
return low == "true"
if low in {"none", "null"}:
return None
try:
return ast.literal_eval(s)
except Exception:
return s
def _refresh_config_fields(
cfg: dict[str, Any],
) -> dict[str, Any]:
out = dict(cfg)
city = str(out.get("CITY_NAME", "")).strip().lower()
variant = str(out.get("DATASET_VARIANT", "")).strip()
big_t = out.get("BIG_FN_TEMPLATE")
small_t = out.get("SMALL_FN_TEMPLATE")
if city and variant and isinstance(big_t, str):
out["BIG_FN"] = big_t.format(
city=city,
variant=variant,
)
if city and variant and isinstance(small_t, str):
out["SMALL_FN"] = small_t.format(
city=city,
variant=variant,
)
return out
def _apply_config(
    cfg: dict[str, Any],
) -> dict[str, Any]:
    """Return a refreshed copy of ``cfg`` (see ``_refresh_config_fields``)."""
    snapshot = dict(cfg)
    return _refresh_config_fields(snapshot)
def _install_user_config(
config_path: str,
*,
config_root: str = "nat.com",
) -> str:
src = Path(config_path).expanduser().resolve()
if not src.exists():
raise FileNotFoundError(
f"Config file not found: {src}"
)
config_py, config_json = get_config_paths(
root=config_root
)
dst = Path(config_py).expanduser().resolve()
dst.parent.mkdir(
parents=True,
exist_ok=True,
)
if src != dst:
shutil.copy2(src, dst)
json_path = Path(config_json)
if json_path.exists():
json_path.unlink()
return str(dst)
def _persist_runtime_overrides(
    overrides: dict[str, Any] | None = None,
    *,
    config_root: str = "nat.com",
) -> dict[str, Any]:
    """Merge ``overrides`` into the active config and write it to the JSON.

    The config comes from ``ensure_config_json(root=config_root)``. It is
    refreshed, updated with ``overrides`` (if any) and refreshed again.
    The existing JSON payload is reused when it can be parsed; its
    ``city``, ``model`` and ``config`` entries are replaced and a
    ``__meta__`` entry is ensured before the file is rewritten.

    Returns
    -------
    dict[str, Any]
        The effective configuration.
    """
    initial_cfg, config_json = ensure_config_json(root=config_root)
    effective = _apply_config(initial_cfg)
    if overrides:
        effective.update(overrides)
        effective = _apply_config(effective)

    json_file = Path(config_json)
    document: dict[str, Any] = {}
    if json_file.exists():
        try:
            document = json.loads(
                json_file.read_text(encoding="utf-8")
            )
        except Exception:
            # Unreadable payload: start from an empty document.
            document = {}

    document["city"] = effective.get("CITY_NAME")
    document["model"] = effective.get("MODEL_NAME")
    document["config"] = effective
    document.setdefault("__meta__", {})

    json_file.write_text(
        json.dumps(document, indent=2),
        encoding="utf-8",
    )
    return effective
def _parse_set_items(
items: list[str] | None,
) -> dict[str, Any]:
out: dict[str, Any] = {}
for item in items or []:
if "=" not in item:
raise SystemExit(
f"Each --set must be KEY=VALUE. Got: {item!r}"
)
key, value = item.split("=", 1)
key = key.strip()
if not key:
raise SystemExit(
f"Invalid empty key in: {item!r}"
)
out[key] = _parse_override_value(value)
return out
def _build_wrapper_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="sensitivity",
add_help=False,
description=(
"Run the lambda sensitivity grid through geoprior-run. "
"Unknown arguments are forwarded to the native sensitivity driver."
),
)
p.add_argument(
"--config",
type=str,
default=None,
help=(
"Optional config.py to install into "
"nat.com/config.py before running."
),
)
p.add_argument(
"--config-root",
type=str,
default="nat.com",
help="Config root directory.",
)
p.add_argument(
"--city",
type=str,
default=None,
help=(
"Seed CITY_NAME and CITY for Stage-1 lookup when omitted "
"from the active config."
),
)
p.add_argument(
"--model",
type=str,
default=None,
help=(
"Seed MODEL_NAME for sensitivity runs when omitted "
"from the active config."
),
)
p.add_argument(
"--stage1-manifest",
type=str,
default=None,
help=(
"Explicit Stage-1 manifest to reuse for sensitivity training."
),
)
p.add_argument(
"--results-dir",
type=str,
default=None,
help=(
"Optional results root used to seed --scan-root when the "
"driver was not given one explicitly."
),
)
p.add_argument(
"--set",
dest="sets",
action="append",
default=[],
metavar="KEY=VALUE",
help=(
"Extra config override. Repeat as needed, for example "
"--set EPOCHS=10."
),
)
p.add_argument(
"-h",
"--help",
action="store_true",
help="Show the combined wrapper/driver help.",
)
return p
def _cli_overrides(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Build config overrides from the wrapper command line.

    Starts from the ``--set`` items, then lets ``--city``, ``--model``
    and ``--results-dir`` (when given) set ``CITY_NAME`` (lower-cased),
    ``MODEL_NAME`` and ``RESULTS_DIR``.
    """
    overrides = _parse_set_items(args.sets)
    dedicated = (
        ("CITY_NAME", args.city, True),
        ("MODEL_NAME", args.model, False),
        ("RESULTS_DIR", args.results_dir, False),
    )
    for config_key, raw, lower in dedicated:
        if not raw:
            continue
        value = str(raw).strip()
        overrides[config_key] = value.lower() if lower else value
    return overrides
def _has_opt(
argv: list[str],
*names: str,
) -> bool:
for token in argv:
for name in names:
if token == name:
return True
if token.startswith(name + "="):
return True
return False
def sensitivity_main(
    argv: list[str] | None = None,
) -> None:
    """Wrapper entry point around the grid driver ``main``.

    Steps:

    1. Parse the wrapper options; everything else is forwarded.
    2. With ``-h/--help``, print wrapper help followed by driver help.
    3. Optionally install ``--config`` and persist the CLI overrides into
       the config JSON.
    4. Resolve the effective city, model and results directory; add a
       ``--scan-root`` to the forwarded arguments when a results
       directory is known and none was passed.
    5. Call ``main`` with ``CITY``, ``MODEL_NAME_OVERRIDE``,
       ``STAGE1_MANIFEST`` and the results-dir variables exported, then
       restore the original process environment.
    """
    wrapper = _build_wrapper_parser()
    args, passthrough = wrapper.parse_known_args(argv)

    if args.help:
        wrapper.print_help()
        print("\nDriver options:\n")
        try:
            # argparse prints the driver help, then raises SystemExit.
            parse_args(["--help"])
        except SystemExit:
            pass
        return

    if args.config:
        _install_user_config(
            args.config,
            config_root=args.config_root,
        )

    cfg = _persist_runtime_overrides(
        _cli_overrides(args),
        config_root=args.config_root,
    )

    # Precedence: persisted config, then CLI flag, then environment.
    city_raw = (
        cfg.get("CITY_NAME")
        or args.city
        or os.environ.get("CITY", "")
    )
    effective_city = str(city_raw).strip().lower()

    model_raw = (
        cfg.get("MODEL_NAME")
        or args.model
        or os.environ.get("MODEL_NAME_OVERRIDE", "")
    )
    effective_model = str(model_raw).strip()

    results_raw = cfg.get("RESULTS_DIR") or args.results_dir or ""
    effective_results_dir = str(results_raw).strip()

    forwarded = list(passthrough)
    if effective_results_dir and not _has_opt(forwarded, "--scan-root"):
        seeded_root = Path(effective_results_dir)
        if effective_city:
            seeded_root = seeded_root / effective_city
        forwarded.extend(["--scan-root", str(seeded_root)])

    env_updates: dict[str, str] = {}
    if effective_city:
        env_updates["CITY"] = effective_city
    if effective_model:
        env_updates["MODEL_NAME_OVERRIDE"] = effective_model
    if args.stage1_manifest:
        manifest = Path(args.stage1_manifest).expanduser().resolve()
        env_updates["STAGE1_MANIFEST"] = str(manifest)
    if effective_results_dir:
        env_updates["RESULTS_DIR"] = effective_results_dir
        env_updates["GEOPRIOR_RESULTS_DIR"] = effective_results_dir

    saved_environ = os.environ.copy()
    try:
        os.environ.update(env_updates)
        main(forwarded)
    finally:
        os.environ.clear()
        os.environ.update(saved_environ)
def _env_for_spec(
    base_env: dict[str, str],
    args: argparse.Namespace,
    spec: RunSpec,
) -> dict[str, str]:
    """Build the override environment of one script run from the CLI args."""
    return make_env(
        base_env,
        epochs=args.epochs,
        spec=spec,
        strategy=args.strategy,
        disable_q=bool(args.disable_q),
        disable_subs_resid=bool(args.disable_subs_resid),
        no_physics_ramp=bool(args.no_physics_ramp),
        physics_warmup_steps=args.physics_warmup_steps,
        physics_ramp_steps=args.physics_ramp_steps,
        lambda_gw=args.lambda_gw,
        lambda_smooth=args.lambda_smooth,
        lambda_bounds=args.lambda_bounds,
        lambda_mv=args.lambda_mv,
        lambda_q=args.lambda_q,
        no_early_stopping=bool(args.no_early_stopping),
        fast=bool(args.fast),
        eval_max_batches=args.eval_max_batches,
        # make_env takes batch_size as a keyword; forwarding it here is
        # what makes --batch-size effective for script runs.
        batch_size=args.batch_size,
    )


def _gold_overrides(
    args: argparse.Namespace,
    spec: RunSpec,
) -> dict[str, Any]:
    """Build the direct config overrides of one gold-mode run."""
    overrides: dict[str, Any] = {
        "EPOCHS": int(args.epochs),
        "PDE_MODE_CONFIG": str(spec.pde_mode).strip().lower(),
        "LAMBDA_CONS": float(spec.lambda_cons),
        "LAMBDA_PRIOR": float(spec.lambda_prior),
        "TRAINING_STRATEGY": str(args.strategy).strip().lower(),
        "FAST_SENSITIVITY": bool(args.fast),
        "DISABLE_EARLY_STOPPING": bool(args.no_early_stopping),
        # big speed win: don’t reload inference model from disk
        "USE_IN_MEMORY_MODEL": True,
        # grid hygiene defaults
        "AUDIT_STAGES": "off",
        "DEBUG": False,
        "LOG_Q_DIAGNOSTICS": False,
        "MV_WEIGHT": 0.0,
    }

    # Optional knobs from CLI (mirror make_env behavior)
    if args.disable_q:
        overrides["Q_POLICY_DATA_FIRST"] = "always_off"
        overrides["Q_POLICY_PHYSICS_FIRST"] = "always_off"
        overrides["LAMBDA_Q"] = 0.0
        overrides["LAMBDA_Q_DATA_FIRST"] = 0.0
        overrides["LAMBDA_Q_PHYSICS_FIRST"] = 0.0
    elif args.lambda_q is not None:
        overrides["LAMBDA_Q"] = float(args.lambda_q)
        overrides["LAMBDA_Q_DATA_FIRST"] = float(args.lambda_q)
        overrides["LAMBDA_Q_PHYSICS_FIRST"] = float(args.lambda_q)

    if args.disable_subs_resid:
        overrides["SUBS_RESID_POLICY_DATA_FIRST"] = "always_off"
        overrides["SUBS_RESID_POLICY_PHYSICS_FIRST"] = "always_off"
        overrides["ALLOW_SUBS_RESIDUAL"] = False

    if args.no_physics_ramp:
        overrides["PHYSICS_WARMUP_STEPS"] = 0
        overrides["PHYSICS_RAMP_STEPS"] = 0
    else:
        if args.physics_warmup_steps is not None:
            overrides["PHYSICS_WARMUP_STEPS"] = int(
                args.physics_warmup_steps
            )
        if args.physics_ramp_steps is not None:
            overrides["PHYSICS_RAMP_STEPS"] = int(
                args.physics_ramp_steps
            )

    if args.lambda_gw is not None:
        overrides["LAMBDA_GW"] = float(args.lambda_gw)
    if args.lambda_smooth is not None:
        overrides["LAMBDA_SMOOTH"] = float(args.lambda_smooth)
    if args.lambda_bounds is not None:
        overrides["LAMBDA_BOUNDS"] = float(args.lambda_bounds)
    if args.lambda_mv is not None:
        overrides["LAMBDA_MV"] = float(args.lambda_mv)
    else:
        # same default as the script-based runner
        overrides["LAMBDA_MV"] = 0.0
    return overrides


def main(
    argv: list[str] | None = None,
) -> None:
    """Run the (pde_mode, lambda_cons, lambda_prior) sensitivity grid.

    The grid is built, optionally shuffled, filtered against completed
    runs found under the scan root (unless ``--no-resume``), sliced with
    ``--start`` / ``--limit`` and then executed in one of three ways:

    * gold mode (``--gold``): in-process runs that share one cached
      context built once;
    * parallel mode (``--n-jobs`` > 1): one subprocess per run, driven by
      a thread pool;
    * sequential mode: one run after another, as a subprocess or
      in-process (``--inprocess``).

    Progress is written to the optional state file after each real
    (non dry-run) completed run.

    Parameters
    ----------
    argv : list[str] | None
        Command-line arguments; ``None`` uses ``sys.argv[1:]``.

    Raises
    ------
    SystemExit
        If the training script is missing, or with status 1 when at
        least one run failed under ``--continue-on-error``.
    Exception
        The first run failure is re-raised when ``--continue-on-error``
        is not set.
    """
    args = parse_args(argv)

    train_script = Path(args.train_script)
    if not train_script.exists():
        raise SystemExit(
            f"Cannot find training script at: {train_script}"
        )

    base_env = os.environ.copy()
    city = base_env.get("CITY", "<unknown>")

    dev = resolve_device(args.device, env=base_env)
    gpus = []
    if dev == "gpu":
        gpus = resolve_gpu_ids(args.gpu_ids, env=base_env)
    if dev == "gpu" and not gpus:
        print("[Warn] device=gpu but no GPUs found.")
        print(" Falling back to CPU.")
        dev = "cpu"

    # Build full grid
    grid0 = build_grid(args.pde_modes, args.lcons, args.lprior)
    grid1 = maybe_shuffle(
        grid0, shuffle=args.shuffle, seed=args.seed
    )

    resume = not bool(args.no_resume)
    dry_run = bool(args.dry_run)

    # Resolve scan root
    if args.scan_root is not None:
        scan_root = Path(args.scan_root)
    else:
        scan_root = _default_scan_root(city)

    # State file (optional)
    if args.state_file is not None:
        state_path = Path(args.state_file)
    else:
        state_path = scan_root / "lambda_sensitivity_state.json"

    completed: set[str] = set()
    if resume:
        completed = _load_completed_keys_from_done(
            scan_root,
            city=city,
        )
        if not completed:
            # fallback: slower but robust
            completed = _load_completed_keys(
                scan_root,
                city=city,
            )

    # Filter completed BEFORE slicing
    if resume and completed:
        grid2 = [s for s in grid1 if s.key() not in completed]
        skipped = len(grid1) - len(grid2)
    else:
        grid2 = list(grid1)
        skipped = 0

    grid = apply_runner_slicing(
        grid2,
        start=args.start,
        limit=args.limit,
    )

    print("[Sensitivity] Setup")
    print(f" CITY : {city}")
    print(f" train_script : {train_script}")
    print(f" epochs/run : {args.epochs}")
    print(f" pde_modes : {list(args.pde_modes)}")
    print(f" lcons grid : {list(args.lcons)}")
    print(f" lprior grid : {list(args.lprior)}")
    print(f" strategy : {args.strategy}")
    print(f" resume : {resume}")
    print(f" scan_root : {scan_root}")
    print(f" done_found : {len(completed)}")
    print(f" skipped_done : {skipped}")
    print(f" start : {args.start}")
    print(f" limit : {args.limit}")
    print(f" shuffle : {bool(args.shuffle)}")
    print(f" seed : {args.seed}")
    print(f" runs : {len(grid)} / {len(grid0)}")
    print(f" dry_run : {dry_run}")
    print(f" continue_err : {bool(args.continue_on_error)}")

    if not grid:
        print("[Sensitivity] No runs selected. Done.")
        _save_state(
            state_path,
            city=city,
            scan_root=scan_root,
            completed_n=len(completed),
            last_key=None,
        )
        return

    # ---------------------------------------------------------
    # GOLD MODE: cached context + in-process per-point runs
    # ---------------------------------------------------------
    if bool(args.gold):
        # NOTE(review): the return value is unused; the call is kept in
        # case the helper has side effects — TODO confirm and drop.
        resolve_n_jobs(-1)
        t = threads_per_job(
            n_jobs=1,
            threads=int(args.threads or 0),
            reserve=1,
        )
        apply_tf_threading(
            intra=t, inter=max(1, min(4, t // 2))
        )
        if dev == "gpu":
            try:
                import tensorflow as tf

                for g in tf.config.list_physical_devices("GPU"):
                    tf.config.experimental.set_memory_growth(
                        g, True
                    )
            except Exception:
                # Best effort only; Exception (not a bare except) so
                # Ctrl-C / SystemExit still propagate.
                pass

        # Build cached context ONCE
        ctx = build_context(city=city, verbose=1)

        # If user didn’t pass scan_root, scan where gold runs live:
        # (inside Stage-1 run_dir)
        if args.scan_root is None:
            scan_root = Path(ctx.base_output_dir)
        else:
            scan_root = Path(args.scan_root)

        # Resume detection against the (possibly updated) scan root
        completed = set()
        if resume:
            completed = _load_completed_keys_from_done(
                scan_root,
                city=city,
            )
            if not completed:
                completed = _load_completed_keys(
                    scan_root,
                    city=city,
                )
        if resume and completed:
            grid = [s for s in grid if s.key() not in completed]

        failures: list[tuple[int, str]] = []
        last_done: str | None = None
        for i, spec in enumerate(grid):
            print("\n" + "=" * 62)
            print(f"[Sensitivity GOLD] Run {i + 1}/{len(grid)}")
            print(f" {spec.tag()}")
            print("=" * 62)

            # Direct cfg overrides; no env needed
            overrides = _gold_overrides(args, spec)

            if dry_run:
                print("[DryRun GOLD] would run:", spec.run_tag())
                continue

            try:
                run_dir = run_one_gold(
                    ctx,
                    overrides=overrides,
                    run_tag=spec.run_tag(),
                    stable_run_dir=True,
                    eval_max_batches=args.eval_max_batches,
                    cal_max_batches=args.eval_max_batches,
                )
                # Mark done in-memory + state file
                k = spec.key()
                completed.add(k)
                last_done = k
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=last_done,
                )
                cleanup_between_runs()
                print("[GOLD] run_dir ->", run_dir)
            except Exception as e:
                msg = f"failed: {spec.tag()} ({type(e).__name__}: {e})"
                failures.append((i, msg))
                print("[Sensitivity GOLD] ERROR:", msg)
                if not args.continue_on_error:
                    raise

        print("\n[Sensitivity GOLD] Finished.")
        if failures:
            print("[Sensitivity GOLD] Failures:")
            for _, msg in failures:
                print(" - " + msg)
            raise SystemExit(1)
        return  # IMPORTANT: don’t fall through to the script runner

    nj = resolve_n_jobs(args.n_jobs)
    if dev == "gpu":
        # Single GPU => force n_jobs=1
        if len(gpus) <= 1 and nj > 1:
            print("[Warn] Single GPU detected.")
            print(" Forcing --n-jobs 1.")
            nj = 1
        # Multi GPU => cap workers to num GPUs (safe)
        if len(gpus) >= 2:
            if nj > len(gpus):
                print("[Warn] Capping jobs to GPUs.")
                nj = len(gpus)
    if nj > 1 and (args.gold or args.inprocess):
        print(
            "[Warn] --n-jobs ignored with --gold/--inprocess."
        )
        nj = 1

    # ---------------------------------------------------------
    # PARALLEL MODE: one subprocess per run, thread pool driver
    # ---------------------------------------------------------
    if nj > 1:

        def _worker(i: int, spec: RunSpec) -> str:
            env1 = apply_thread_env(
                _env_for_spec(base_env, args, spec),
                n_jobs=nj,
                threads=int(args.threads or 0),
            )
            if dev == "gpu":
                gid = pick_gpu_id(i, gpus)
                env1 = apply_gpu_env(
                    env1,
                    gpu_id=gid,
                    allow_growth=bool(args.gpu_allow_growth),
                )
            else:
                gid = None
            _worker_banner(
                mode="Sensitivity",
                job_i=i,
                n_jobs=len(grid),
                pool=nj,
                run_tag=spec.run_tag(),
                device=dev,
                gpu_id=gid,
            )
            run_one_script(
                train_script,
                env=env1,
                dry_run=dry_run,
                inprocess=False,
            )
            return spec.key()

        pool_failures: list[str] = []
        with ThreadPoolExecutor(max_workers=nj) as ex:
            futs = {
                ex.submit(_worker, i, s): (i, s)
                for i, s in enumerate(grid)
            }
            for fut in as_completed(futs):
                i, spec = futs[fut]
                try:
                    k = fut.result()
                except Exception as e:
                    msg = f"failed: worker={i} {spec.tag()} ({e})"
                    pool_failures.append(msg)
                    print("[Sensitivity] ERROR:", msg)
                    if not args.continue_on_error:
                        raise
                    continue
                if dry_run:
                    # Nothing was executed: do not record completion.
                    continue
                completed.add(k)
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=k,
                )
        if pool_failures:
            raise SystemExit(1)
        return

    # ---------------------------------------------------------
    # SEQUENTIAL MODE
    # ---------------------------------------------------------
    failures = []
    last_done = None
    for i, spec in enumerate(grid):
        tag = spec.tag()
        print("\n" + "=" * 62)
        print(f"[Sensitivity] Run {i + 1}/{len(grid)}")
        print(f" {tag}")
        print("=" * 62)

        env = _env_for_spec(base_env, args, spec)

        gid = None
        if dev == "gpu":
            # sequential case: pick first visible GPU for clarity
            gid = pick_gpu_id(0, gpus)
        _worker_banner(
            mode="Sensitivity",
            job_i=i,
            n_jobs=len(grid),
            pool=1,
            run_tag=spec.run_tag(),
            device=dev,
            gpu_id=gid,
        )

        try:
            run_one_script(
                train_script,
                env=env,
                dry_run=dry_run,
                inprocess=bool(args.inprocess),
            )
        except Exception as e:
            # Subprocess failures carry an exit code; in-process runs can
            # raise anything, and must honour --continue-on-error too.
            if isinstance(e, subprocess.CalledProcessError):
                msg = f"failed: {tag} (code={e.returncode})"
            else:
                msg = f"failed: {tag} ({type(e).__name__}: {e})"
            failures.append((i, msg))
            print("[Sensitivity] ERROR: " + msg)
            if not args.continue_on_error:
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=last_done,
                )
                raise
            continue

        if dry_run:
            # Nothing was executed: do not record completion.
            continue

        # Mark done in-memory (useful if rerun same process)
        k = spec.key()
        completed.add(k)
        last_done = k
        _save_state(
            state_path,
            city=city,
            scan_root=scan_root,
            completed_n=len(completed),
            last_key=last_done,
        )

    print("\n[Sensitivity] Finished.")
    if failures:
        print("[Sensitivity] Failures:")
        for _, msg in failures:
            print(" - " + msg)
        raise SystemExit(1)
    print(
        "You can now run make_supp_figS6_ablations.py "
        "over the same --root to build the tidy table "
        "+ Supplement S6 figure."
    )
# Script entry point: run through the wrapper so wrapper-level options
# (--config, --city, --set, ...) are handled before the grid driver.
if __name__ == "__main__":
    sensitivity_main()