# Source code for geoprior.cli.run_sensitivity

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
run_sensitivity.py

Driver to run a (lambda_cons, lambda_prior) sensitivity grid
for GeoPriorSubsNet using the existing Stage-2 training script.

This script calls sensitivity.py multiple times with environment
overrides. Each run should write its own ablation record
(entry in ablation_records/ablation_record.jsonl), which your
make_supp_figS6_ablations.py later aggregates.

Core overrides (expected by sensitivity.py)
--------------------------------------------
- EPOCHS_OVERRIDE
- PDE_MODE_OVERRIDE
- LAMBDA_CONS_OVERRIDE
- LAMBDA_PRIOR_OVERRIDE

Optional "deconfounding" overrides (safe to export even if
sensitivity.py ignores some of them; they can be wired in later)
-----------------------------------------------------------
- TRAINING_STRATEGY_OVERRIDE
- Q_POLICY_OVERRIDE
- SUBS_RESID_POLICY_OVERRIDE
- ALLOW_SUBS_RESIDUAL_OVERRIDE
- LAMBDA_Q_OVERRIDE
- PHYSICS_WARMUP_STEPS_OVERRIDE
- PHYSICS_RAMP_STEPS_OVERRIDE
- LAMBDA_GW_OVERRIDE
- LAMBDA_SMOOTH_OVERRIDE
- LAMBDA_BOUNDS_OVERRIDE
- LAMBDA_MV_OVERRIDE

Driver to run a (lambda_cons, lambda_prior) sensitivity grid
for GeoPriorSubsNet using the Stage-2 sensitivity script.

Resume mechanism
----------------
On restart, the script scans existing ablation_record.jsonl files
under the results directory and skips runs that already finished.

A run is considered "done" if an ablation record exists containing
``pde_mode``, ``lambda_cons``, and ``lambda_prior`` (and matching
CITY when available).

Usage
-----
set CITY=zhongshan
python nat.com/run_sensitivity.py --epochs 20


To force a rerun of everything:

python nat.com/run_sensitivity.py --epochs 20 --no-resume

If results live elsewhere:

python nat.com/run_sensitivity.py --epochs 20 \
  --scan-root F:/repositories/geoprior-learn/results/zhongshan
  
python nat.com/run_sensitivity.py --epochs 20 --inprocess --fast

python nat.com/run_sensitivity.py --epochs 20 --gold --eval-max-batches 50 --fast

python nat.com/run_sensitivity.py --epochs 10 --fast --n-jobs -1   

python nat.com/run_sensitivity.py --gold --epochs 10 --fast --threads 20

"""

from __future__ import annotations

import argparse
import ast
import itertools
import json
import os
import runpy
import shutil
import subprocess
import sys
from collections.abc import Iterable
from concurrent.futures import (
    ThreadPoolExecutor,
    as_completed,
)
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
)

try:
    from .sensitivity_lib import (
        build_context,
        cleanup_between_runs,
    )
    from .sensitivity_lib import (
        run_one as run_one_gold,
    )
except ImportError:  # pragma: no cover
    from sensitivity_lib import (
        build_context,
        cleanup_between_runs,
    )
    from sensitivity_lib import (
        run_one as run_one_gold,
    )

from geoprior.utils import (
    apply_gpu_env,
    apply_tf_threading,
    apply_thread_env,
    default_results_dir,
    pick_gpu_id,
    resolve_device,
    resolve_gpu_ids,
    resolve_n_jobs,
    threads_per_job,
)
from geoprior.utils.nat_utils import (
    ensure_config_json,
    get_config_paths,
)

# Stage-2 training script launched for every grid point; it lives
# next to this module.
TRAIN_SCRIPT_DEFAULT = Path(__file__).with_name("_sensitivity.py")

# Both lambda axes sweep the same values by default. Each public list
# is its own object, so editing one never changes the other.
# (A coarser alternative noted earlier: 0, 0.05, 0.2, 1.0.)
_DEFAULT_LAMBDA_GRID = (0.0, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0)
DEFAULT_LCONS: list[float] = list(_DEFAULT_LAMBDA_GRID)
DEFAULT_LPRIOR: list[float] = list(_DEFAULT_LAMBDA_GRID)

# PDE modes swept when --pde-modes is not given; ["none", "both"]
# would add the physics-off baseline.
DEFAULT_PDE_MODES: list[str] = ["both"]


def _fmt_float(x: float) -> str:
    # Stable-ish string key for floats in configs.
    # Uses "g" to match your tag style.
    try:
        return f"{float(x):g}"
    except Exception:
        return str(x)


def _norm_mode(x: str) -> str:
    return str(x).strip().lower()


def _canon_pde_mode(x: str) -> str:
    m = str(x).strip().lower()
    if m in {"both", "on", "true"}:
        return "on"
    if m in {"none", "off", "false"}:
        return "none"
    return m


@dataclass(frozen=True)
class RunSpec:
    """One point of the (pde_mode, lambda_cons, lambda_prior) grid.

    Frozen, so instances are hashable and compare by value.
    """

    # PDE mode as given on the CLI (e.g. "both" or "none").
    pde_mode: str
    # lambda_cons value for this run.
    lambda_cons: float
    # lambda_prior value for this run.
    lambda_prior: float

    def key(self) -> str:
        """Canonical identity used to match finished runs on resume.

        The PDE mode goes through ``_canon_pde_mode`` so aliases such as
        ``"both"`` and ``"on"`` produce the same key.
        """
        fields = (
            ("pde", _canon_pde_mode(self.pde_mode)),
            ("lcons", _fmt_float(self.lambda_cons)),
            ("lprior", _fmt_float(self.lambda_prior)),
        )
        return "|".join(f"{name}={value}" for name, value in fields)

    def tag(self) -> str:
        """Human-readable label (PDE mode shown exactly as given)."""
        lcons_txt = _fmt_float(self.lambda_cons)
        lprior_txt = _fmt_float(self.lambda_prior)
        return (
            f"pde={self.pde_mode!s}, "
            f"lcons={lcons_txt}, lprior={lprior_txt}"
        )

    def run_tag(self) -> str:
        """Short filesystem-friendly tag; dots become ``p``."""
        mode_txt = _norm_mode(self.pde_mode)
        lcons_txt = _fmt_float(self.lambda_cons).replace(".", "p")
        lprior_txt = _fmt_float(self.lambda_prior).replace(".", "p")
        return f"pde_{mode_txt}__lc_{lcons_txt}__lp_{lprior_txt}"
def parse_args(
    argv: list[str] | None = None,
) -> argparse.Namespace:
    """Parse the sensitivity-grid driver options.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Grid definition, execution/device policy, optional
        deconfounding overrides, resume controls and runner controls.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run a lambda_cons / lambda_prior sensitivity grid "
            "using the stage2 sensitivity script."
        )
    )
    add = parser.add_argument

    # --- grid definition and per-run training budget ---
    add(
        "--train-script",
        type=str,
        default=str(TRAIN_SCRIPT_DEFAULT),
        help="Path to training script (sensitivity.py).",
    )
    add(
        "--epochs",
        type=int,
        default=20,
        help="Epochs per run (short sensitivity runs).",
    )
    add(
        "--pde-modes",
        type=str,
        nargs="+",
        default=DEFAULT_PDE_MODES,
        help="PDE modes to sweep (e.g. none both).",
    )
    add(
        "--lcons",
        type=float,
        nargs="+",
        default=DEFAULT_LCONS,
        help="Grid for lambda_cons.",
    )
    add(
        "--lprior",
        type=float,
        nargs="+",
        default=DEFAULT_LPRIOR,
        help="Grid for lambda_prior.",
    )
    add(
        "--no-early-stopping",
        action="store_true",
        help="Disable EarlyStopping in sensitivity.py.",
    )
    add(
        "--fast",
        action="store_true",
        help="Skip plotting + calibration extras in sensitivity.py.",
    )
    add(
        "--eval-max-batches",
        type=int,
        default=None,
        help="Limit eval/export batches per run (speeds grids).",
    )
    add(
        "--batch-size",
        type=int,
        default=None,
        help="Override BATCH_SIZE for stage-2 runs.",
    )

    # --- execution mode and device policy ---
    add(
        "--inprocess",
        action="store_true",
        help="Run sensitivity.py in-process (no subprocess).",
    )
    add(
        "--gold",
        action="store_true",
        help=(
            "Gold mode: run sensitivity in-process via sensitivity_lib "
            "(reuses NPZ + tf.data pipelines). Fastest for grids."
        ),
    )
    add(
        "--n-jobs",
        type=int,
        default=1,
        help="Parallel grid runs; -1=all CPUs.",
    )
    add(
        "--threads",
        type=int,
        default=0,
        help="Threads per run (0=auto).",
    )
    add(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cpu", "gpu"],
        help="Device policy for runs.",
    )
    add(
        "--gpu-ids",
        type=str,
        nargs="*",
        default=None,
        help="Explicit GPU ids, e.g. --gpu-ids 0 1",
    )
    add(
        "--gpu-allow-growth",
        action="store_true",
        help="Enable TF GPU allow-growth.",
    )

    # --- optional deconfounding knobs ---
    add(
        "--strategy",
        type=str,
        default="data_first",
        choices=["data_first", "physics_first"],
        help="Training strategy override.",
    )
    add(
        "--disable-q",
        action="store_true",
        help="Export overrides to force Q always off.",
    )
    add(
        "--disable-subs-resid",
        action="store_true",
        help="Export overrides to disable subs residual.",
    )
    add(
        "--no-physics-ramp",
        action="store_true",
        help="Set physics warmup/ramp steps to 0.",
    )
    add(
        "--physics-warmup-steps",
        type=int,
        default=None,
        help="Override physics warmup steps.",
    )
    add(
        "--physics-ramp-steps",
        type=int,
        default=None,
        help="Override physics ramp steps.",
    )
    for lam in ("gw", "smooth", "bounds", "mv", "q"):
        add(
            f"--lambda-{lam}",
            type=float,
            default=None,
            help=f"Optional override for lambda_{lam}.",
        )

    # --- resume controls ---
    add(
        "--no-resume",
        action="store_true",
        help="Do NOT skip completed runs.",
    )
    add(
        "--scan-root",
        type=str,
        default=None,
        help=(
            "Root directory to scan for prior "
            "ablation_record.jsonl files. "
            "Default: results_dir/CITY."
        ),
    )
    add(
        "--state-file",
        type=str,
        default=None,
        help=(
            "Optional JSON state file to write progress "
            "(default: <scan_root>/lambda_sensitivity_state.json)."
        ),
    )

    # --- runner controls ---
    add(
        "--start",
        type=int,
        default=0,
        help="Start index in the remaining grid.",
    )
    add(
        "--limit",
        type=int,
        default=None,
        help="Max number of runs (after --start).",
    )
    add(
        "--shuffle",
        action="store_true",
        help="Shuffle run order (deterministic with --seed).",
    )
    add(
        "--seed",
        type=int,
        default=42,
        help="Seed for shuffling.",
    )
    add(
        "--continue-on-error",
        action="store_true",
        help="Continue grid even if a run fails.",
    )
    add(
        "--dry-run",
        action="store_true",
        help="Print planned commands without executing.",
    )
    return parser.parse_args(argv)
def build_grid(
    pde_modes: Iterable[str],
    lcons: Iterable[float],
    lprior: Iterable[float],
) -> list[RunSpec]:
    """Expand the sweep axes into a flat list of :class:`RunSpec`.

    The PDE mode is the outer loop. Within each mode every
    ``(lambda_cons, lambda_prior)`` pair from the Cartesian product is
    emitted, with values coerced to ``str``/``float``.
    """
    return [
        RunSpec(
            pde_mode=str(mode),
            lambda_cons=float(lc),
            lambda_prior=float(lp),
        )
        for mode in pde_modes
        for lc, lp in itertools.product(lcons, lprior)
    ]
def maybe_shuffle(
    runs: list[RunSpec],
    *,
    shuffle: bool,
    seed: int,
) -> list[RunSpec]:
    """Return *runs* in a reproducible pseudo-random order when asked.

    A Fisher-Yates shuffle over an index list is driven by a 32-bit
    xorshift generator seeded from ``seed`` (masked to 32 bits). The
    order is therefore reproducible without touching the global
    ``random`` state.

    Parameters
    ----------
    runs : list of RunSpec
        Grid points to reorder. The list itself is not modified.
    shuffle : bool
        When False, *runs* is returned unchanged (same object).
    seed : int
        Shuffle seed. ``0`` is a fixed point of xorshift, so it is
        replaced by the non-zero constant ``0x9E3779B9``.

    Returns
    -------
    list of RunSpec
        *runs* itself when not shuffling or when it has <= 1 element,
        otherwise a new reordered list.
    """
    if not shuffle:
        return runs
    n = len(runs)
    if n <= 1:
        return runs

    mask = 0xFFFFFFFF
    x = int(seed) & mask
    if x == 0:
        # xorshift maps 0 -> 0 forever: every swap index would be 0 and
        # the "shuffle" degenerates into a rotation. Use a fixed
        # non-zero substitute so seed=0 still shuffles deterministically.
        x = 0x9E3779B9

    idx = list(range(n))
    for i in range(n - 1, 0, -1):
        x ^= (x << 13) & mask
        x ^= (x >> 17) & mask
        x ^= (x << 5) & mask
        j = x % (i + 1)
        idx[i], idx[j] = idx[j], idx[i]
    return [runs[k] for k in idx]
def apply_runner_slicing(
    runs: list[RunSpec],
    *,
    start: int,
    limit: int | None,
) -> list[RunSpec]:
    """Select the window of *runs* requested by ``--start``/``--limit``.

    A negative *start* is treated as 0. ``limit=None`` means "everything
    from *start*", and a non-positive *limit* selects nothing.
    """
    first = max(start, 0)
    if limit is None:
        return runs[first:]
    if limit <= 0:
        return []
    return runs[first:first + limit]
def _default_scan_root(city: str) -> Path: # Prefer geoprior's default_results_dir if available. # Fall back to ./results. try: root = Path(default_results_dir()) except Exception: root = Path.cwd() / "results" if city and city != "<unknown>": return root / city return root def _iter_ablation_jsonl_files(root: Path) -> Iterable[Path]: if not root.exists(): return [] # Typical layout: .../ablation_records/ablation_record.jsonl return root.rglob("ablation_record.jsonl") def _load_completed_keys( scan_root: Path, *, city: str, ) -> set[str]: done: set[str] = set() for fp in _iter_ablation_jsonl_files(scan_root): try: with fp.open("r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: rec = json.loads(line) except Exception: continue # Filter by city when present rec_city = rec.get("city", None) if rec_city is not None: if ( str(rec_city).lower() != str(city).lower() ): continue pde = _canon_pde_mode(rec.get("pde_mode")) lc = rec.get("lambda_cons", None) lp = rec.get("lambda_prior", None) if ( pde is None or lc is None or lp is None ): continue k = RunSpec( pde_mode=str(pde), lambda_cons=float(lc), lambda_prior=float(lp), ).key() done.add(k) except Exception: continue return done def _iter_done_json(scan_root: Path) -> Iterable[Path]: if not scan_root.exists(): return [] return scan_root.rglob("DONE.json") def _load_completed_keys_from_done( scan_root: Path, *, city: str, ) -> set[str]: done: set[str] = set() for fp in _iter_done_json(scan_root): try: rec = json.loads(fp.read_text(encoding="utf-8")) except Exception: continue rec_city = rec.get("city", None) if rec_city is not None: if str(rec_city).lower() != str(city).lower(): continue pde = rec.get("pde_mode", None) lc = rec.get("lambda_cons", None) lp = rec.get("lambda_prior", None) if pde is None or lc is None or lp is None: continue k = RunSpec( pde_mode=_canon_pde_mode(pde), lambda_cons=float(lc), lambda_prior=float(lp), ).key() done.add(k) return done def _save_state( 
state_path: Path, *, city: str, scan_root: Path, completed_n: int, last_key: str | None, ) -> None: payload = { "city": city, "scan_root": str(scan_root), "completed_n": int(completed_n), "last_completed_key": last_key, } try: state_path.parent.mkdir(parents=True, exist_ok=True) state_path.write_text( json.dumps(payload, indent=2), encoding="utf-8", ) except: # State is optional: never fail the run. return def _worker_banner( *, mode: str, job_i: int, n_jobs: int, pool: int, run_tag: str, device: str, gpu_id: str | None, ) -> None: d = str(device).lower().strip() gid = "-" if gpu_id is None else str(gpu_id) prefix = f"[{mode}]" if d == "gpu": msg = ( f"{prefix} job {job_i + 1}/{n_jobs} | " f"pool={pool} | RUN_TAG={run_tag} | " f"GPU={gid}" ) else: msg = ( f"{prefix} job {job_i + 1}/{n_jobs} | " f"pool={pool} | RUN_TAG={run_tag} | " f"CPU" ) print(msg, flush=True)
def make_env(
    base_env: dict[str, str],
    *,
    epochs: int,
    spec: RunSpec,
    strategy: str,
    disable_q: bool,
    disable_subs_resid: bool,
    no_physics_ramp: bool,
    physics_warmup_steps: int | None,
    physics_ramp_steps: int | None,
    lambda_gw: float | None,
    lambda_smooth: float | None,
    lambda_bounds: float | None,
    lambda_mv: float | None,
    lambda_q: float | None,
    no_early_stopping: bool,
    fast: bool,
    eval_max_batches: int | None = None,
    batch_size: int | None = None,
) -> dict[str, str]:
    """Build the environment for one sensitivity run.

    Copies *base_env* and adds ``*_OVERRIDE`` variables for the sweep
    point (*spec*), the training budget and the optional deconfounding
    knobs. Optional knobs left as ``None`` are not exported, except
    ``LAMBDA_MV_OVERRIDE``, which defaults to ``"0.0"``.

    ``eval_max_batches`` and ``batch_size`` default to ``None`` so that
    callers omitting them no longer fail with ``TypeError``.

    Returns
    -------
    dict
        A new mapping; *base_env* is not modified.
    """
    env = dict(base_env)

    # Core sweep
    env["PDE_MODE_OVERRIDE"] = str(spec.pde_mode)
    env["EPOCHS_OVERRIDE"] = str(int(epochs))
    env["LAMBDA_CONS_OVERRIDE"] = str(spec.lambda_cons)
    env["LAMBDA_PRIOR_OVERRIDE"] = str(spec.lambda_prior)
    env["SENS_WORKER_BANNER"] = "1"

    # Traceability and run budget
    env["RUN_TAG"] = spec.run_tag()
    env["DISABLE_EARLY_STOPPING"] = "1" if no_early_stopping else "0"
    env["FAST_SENSITIVITY"] = "1" if fast else "0"
    if eval_max_batches is not None:
        env["SENS_EVAL_MAX_BATCHES"] = str(int(eval_max_batches))
    if batch_size is not None:
        env["BATCH_SIZE_OVERRIDE"] = str(int(batch_size))

    # Optional deconfounding controls
    env["TRAINING_STRATEGY_OVERRIDE"] = str(strategy)
    if disable_q:
        # --disable-q wins over any explicit --lambda-q.
        env["Q_POLICY_OVERRIDE"] = "always_off"
        env["LAMBDA_Q_OVERRIDE"] = "0.0"
    elif lambda_q is not None:
        env["LAMBDA_Q_OVERRIDE"] = str(lambda_q)

    if disable_subs_resid:
        env["SUBS_RESID_POLICY_OVERRIDE"] = "always_off"
        env["ALLOW_SUBS_RESIDUAL_OVERRIDE"] = "0"

    if no_physics_ramp:
        # --no-physics-ramp wins over explicit warmup/ramp steps.
        env["PHYSICS_WARMUP_STEPS_OVERRIDE"] = "0"
        env["PHYSICS_RAMP_STEPS_OVERRIDE"] = "0"
    else:
        if physics_warmup_steps is not None:
            env["PHYSICS_WARMUP_STEPS_OVERRIDE"] = str(
                int(physics_warmup_steps)
            )
        if physics_ramp_steps is not None:
            env["PHYSICS_RAMP_STEPS_OVERRIDE"] = str(
                int(physics_ramp_steps)
            )

    if lambda_gw is not None:
        env["LAMBDA_GW_OVERRIDE"] = str(lambda_gw)
    if lambda_smooth is not None:
        env["LAMBDA_SMOOTH_OVERRIDE"] = str(lambda_smooth)
    if lambda_bounds is not None:
        env["LAMBDA_BOUNDS_OVERRIDE"] = str(lambda_bounds)

    # Grid-hygiene defaults applied to every run.
    env["VERBOSE_OVERRIDE"] = "1"
    env["AUDIT_STAGES_OVERRIDE"] = "off"
    env["DEBUG_OVERRIDE"] = "0"
    env["LOG_Q_DIAGNOSTICS_OVERRIDE"] = "0"
    env["MV_WEIGHT_OVERRIDE"] = "0.0"
    env["LAMBDA_MV_OVERRIDE"] = (
        "0.0" if lambda_mv is None else str(lambda_mv)
    )
    return env
def run_one_script(
    train_script: Path,
    *,
    env: dict[str, str],
    dry_run: bool,
    inprocess: bool,
) -> None:
    """Run the Stage-2 training script once with *env*.

    Parameters
    ----------
    train_script : Path
        Script to execute.
    env : dict
        Full environment for the run.
    dry_run : bool
        Only print the command that would be executed.
    inprocess : bool
        Execute via :func:`runpy.run_path` in this interpreter instead of
        a subprocess. ``os.environ``, ``sys.argv`` and ``sys.path`` are
        set like a fresh ``python script.py`` launch, and all three are
        restored afterwards.

    Raises
    ------
    subprocess.CalledProcessError
        When the script exits with a non-zero status. This applies in
        both modes; in-process ``sys.exit(n)`` is translated so that one
        script's exit no longer terminates the whole grid.
    """
    cmd = [sys.executable, str(train_script)]
    if dry_run:
        print("[DryRun] " + " ".join(cmd))
        return
    if not inprocess:
        subprocess.run(cmd, env=env, check=True)
        return

    old_env = os.environ.copy()
    old_argv = sys.argv
    old_path = list(sys.path)
    try:
        os.environ.update(env)
        # Mirror a real interpreter launch: argv[0] is the script and
        # its directory is first on sys.path (runpy does neither).
        sys.argv = [str(train_script)]
        sys.path.insert(0, str(Path(train_script).resolve().parent))
        try:
            runpy.run_path(str(train_script), run_name="__main__")
        except SystemExit as exc:
            code = exc.code
            if code not in (None, 0):
                rc = code if isinstance(code, int) else 1
                raise subprocess.CalledProcessError(rc, cmd) from exc
    finally:
        sys.argv = old_argv
        sys.path[:] = old_path
        os.environ.clear()
        os.environ.update(old_env)
        # Free TF graph state between in-process runs (best effort).
        try:
            import tensorflow as tf

            tf.keras.backend.clear_session()
        except Exception:
            pass
        import gc

        gc.collect()
def _parse_override_value(text: str) -> Any: s = str(text).strip() if not s: return s low = s.lower() if low in {"true", "false"}: return low == "true" if low in {"none", "null"}: return None try: return ast.literal_eval(s) except Exception: return s def _refresh_config_fields( cfg: dict[str, Any], ) -> dict[str, Any]: out = dict(cfg) city = str(out.get("CITY_NAME", "")).strip().lower() variant = str(out.get("DATASET_VARIANT", "")).strip() big_t = out.get("BIG_FN_TEMPLATE") small_t = out.get("SMALL_FN_TEMPLATE") if city and variant and isinstance(big_t, str): out["BIG_FN"] = big_t.format( city=city, variant=variant, ) if city and variant and isinstance(small_t, str): out["SMALL_FN"] = small_t.format( city=city, variant=variant, ) return out def _apply_config( cfg: dict[str, Any], ) -> dict[str, Any]: return _refresh_config_fields(dict(cfg)) def _install_user_config( config_path: str, *, config_root: str = "nat.com", ) -> str: src = Path(config_path).expanduser().resolve() if not src.exists(): raise FileNotFoundError( f"Config file not found: {src}" ) config_py, config_json = get_config_paths( root=config_root ) dst = Path(config_py).expanduser().resolve() dst.parent.mkdir( parents=True, exist_ok=True, ) if src != dst: shutil.copy2(src, dst) json_path = Path(config_json) if json_path.exists(): json_path.unlink() return str(dst) def _persist_runtime_overrides( overrides: dict[str, Any] | None = None, *, config_root: str = "nat.com", ) -> dict[str, Any]: cfg0, config_json = ensure_config_json(root=config_root) cfg = _apply_config(cfg0) if overrides: cfg.update(overrides) cfg = _apply_config(cfg) payload: dict[str, Any] = {} cfg_json = Path(config_json) if cfg_json.exists(): try: payload = json.loads( cfg_json.read_text(encoding="utf-8") ) except Exception: payload = {} payload["city"] = cfg.get("CITY_NAME") payload["model"] = cfg.get("MODEL_NAME") payload["config"] = cfg payload.setdefault("__meta__", {}) cfg_json.write_text( json.dumps(payload, indent=2), 
encoding="utf-8", ) return cfg def _parse_set_items( items: list[str] | None, ) -> dict[str, Any]: out: dict[str, Any] = {} for item in items or []: if "=" not in item: raise SystemExit( f"Each --set must be KEY=VALUE. Got: {item!r}" ) key, value = item.split("=", 1) key = key.strip() if not key: raise SystemExit( f"Invalid empty key in: {item!r}" ) out[key] = _parse_override_value(value) return out def _build_wrapper_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( prog="sensitivity", add_help=False, description=( "Run the lambda sensitivity grid through geoprior-run. " "Unknown arguments are forwarded to the native sensitivity driver." ), ) p.add_argument( "--config", type=str, default=None, help=( "Optional config.py to install into " "nat.com/config.py before running." ), ) p.add_argument( "--config-root", type=str, default="nat.com", help="Config root directory.", ) p.add_argument( "--city", type=str, default=None, help=( "Seed CITY_NAME and CITY for Stage-1 lookup when omitted " "from the active config." ), ) p.add_argument( "--model", type=str, default=None, help=( "Seed MODEL_NAME for sensitivity runs when omitted " "from the active config." ), ) p.add_argument( "--stage1-manifest", type=str, default=None, help=( "Explicit Stage-1 manifest to reuse for sensitivity training." ), ) p.add_argument( "--results-dir", type=str, default=None, help=( "Optional results root used to seed --scan-root when the " "driver was not given one explicitly." ), ) p.add_argument( "--set", dest="sets", action="append", default=[], metavar="KEY=VALUE", help=( "Extra config override. Repeat as needed, for example " "--set EPOCHS=10." 
), ) p.add_argument( "-h", "--help", action="store_true", help="Show the combined wrapper/driver help.", ) return p def _cli_overrides( args: argparse.Namespace, ) -> dict[str, Any]: out = _parse_set_items(args.sets) if args.city: out["CITY_NAME"] = str(args.city).strip().lower() if args.model: out["MODEL_NAME"] = str(args.model).strip() if args.results_dir: out["RESULTS_DIR"] = str(args.results_dir).strip() return out def _has_opt( argv: list[str], *names: str, ) -> bool: for token in argv: for name in names: if token == name: return True if token.startswith(name + "="): return True return False
def sensitivity_main(
    argv: list[str] | None = None,
) -> None:
    """Entry point for ``geoprior-run sensitivity``.

    Steps:

    1. Parse the wrapper-only options.
    2. Optionally install a user ``config.py``.
    3. Persist CLI overrides into the config JSON.
    4. Export the resolved city/model/manifest/results directory as
       environment variables.
    5. Call :func:`main` with every argument the wrapper did not
       recognise.

    ``--scan-root`` is injected when a results directory is known and
    the caller did not pass one. ``os.environ`` is restored afterwards.
    """
    wrapper = _build_wrapper_parser()
    opts, driver_argv = wrapper.parse_known_args(argv)

    if opts.help:
        # Wrapper help first, then the driver's own argparse help
        # (argparse exits via SystemExit after printing it).
        wrapper.print_help()
        print("\nDriver options:\n")
        try:
            parse_args(["--help"])
        except SystemExit:
            pass
        return

    if opts.config:
        _install_user_config(opts.config, config_root=opts.config_root)

    cfg = _persist_runtime_overrides(
        _cli_overrides(opts),
        config_root=opts.config_root,
    )

    # Precedence for each value: persisted config, then CLI, then env.
    city_raw = (
        cfg.get("CITY_NAME") or opts.city or os.environ.get("CITY", "")
    )
    city = str(city_raw).strip().lower()
    model_raw = (
        cfg.get("MODEL_NAME")
        or opts.model
        or os.environ.get("MODEL_NAME_OVERRIDE", "")
    )
    model = str(model_raw).strip()
    results_dir = str(
        cfg.get("RESULTS_DIR") or opts.results_dir or ""
    ).strip()

    forwarded = list(driver_argv)
    if results_dir and not _has_opt(forwarded, "--scan-root"):
        scan_root = (
            Path(results_dir) / city if city else Path(results_dir)
        )
        forwarded.extend(["--scan-root", str(scan_root)])

    env_updates: dict[str, str] = {}
    if city:
        env_updates["CITY"] = city
    if model:
        env_updates["MODEL_NAME_OVERRIDE"] = model
    if opts.stage1_manifest:
        manifest = Path(opts.stage1_manifest).expanduser().resolve()
        env_updates["STAGE1_MANIFEST"] = str(manifest)
    if results_dir:
        env_updates["RESULTS_DIR"] = results_dir
        env_updates["GEOPRIOR_RESULTS_DIR"] = results_dir

    saved_env = os.environ.copy()
    try:
        os.environ.update(env_updates)
        main(forwarded)
    finally:
        os.environ.clear()
        os.environ.update(saved_env)
def _scan_completed(scan_root: Path, *, city: str) -> set[str]:
    """Return keys of finished runs under *scan_root*.

    Tries ``DONE.json`` markers first and falls back to the slower
    ``ablation_record.jsonl`` scan when none are found.
    """
    completed = _load_completed_keys_from_done(scan_root, city=city)
    if not completed:
        # fallback: slower but robust
        completed = _load_completed_keys(scan_root, city=city)
    return completed


def _filter_completed(
    runs: list[RunSpec],
    completed: set[str],
) -> tuple[list[RunSpec], int]:
    """Drop runs whose key is in *completed*; return (remaining, skipped)."""
    remaining = [spec for spec in runs if spec.key() not in completed]
    return remaining, len(runs) - len(remaining)


def _env_for_spec(
    args: argparse.Namespace,
    base_env: dict[str, str],
    spec: RunSpec,
) -> dict[str, str]:
    """Build the script environment for one grid point from CLI args."""
    return make_env(
        base_env,
        epochs=args.epochs,
        spec=spec,
        strategy=args.strategy,
        disable_q=bool(args.disable_q),
        disable_subs_resid=bool(args.disable_subs_resid),
        no_physics_ramp=bool(args.no_physics_ramp),
        physics_warmup_steps=args.physics_warmup_steps,
        physics_ramp_steps=args.physics_ramp_steps,
        lambda_gw=args.lambda_gw,
        lambda_smooth=args.lambda_smooth,
        lambda_bounds=args.lambda_bounds,
        lambda_mv=args.lambda_mv,
        lambda_q=args.lambda_q,
        no_early_stopping=bool(args.no_early_stopping),
        fast=bool(args.fast),
        eval_max_batches=args.eval_max_batches,
        # Previously omitted: make_env requires batch_size, so every
        # non-gold run raised TypeError and --batch-size was ignored.
        batch_size=args.batch_size,
    )


def _gold_overrides(
    args: argparse.Namespace,
    spec: RunSpec,
) -> dict[str, Any]:
    """Direct config overrides for one gold-mode run (mirrors make_env)."""
    overrides: dict[str, Any] = {
        "EPOCHS": int(args.epochs),
        "PDE_MODE_CONFIG": str(spec.pde_mode).strip().lower(),
        "LAMBDA_CONS": float(spec.lambda_cons),
        "LAMBDA_PRIOR": float(spec.lambda_prior),
        "TRAINING_STRATEGY": str(args.strategy).strip().lower(),
        "FAST_SENSITIVITY": bool(args.fast),
        "DISABLE_EARLY_STOPPING": bool(args.no_early_stopping),
        # big speed win: don't reload inference model from disk
        "USE_IN_MEMORY_MODEL": True,
        # grid-hygiene defaults
        "AUDIT_STAGES": "off",
        "DEBUG": False,
        "LOG_Q_DIAGNOSTICS": False,
        # keep the previous driver behavior
        "MV_WEIGHT": 0.0,
    }
    if args.disable_q:
        overrides["Q_POLICY_DATA_FIRST"] = "always_off"
        overrides["Q_POLICY_PHYSICS_FIRST"] = "always_off"
        overrides["LAMBDA_Q"] = 0.0
        overrides["LAMBDA_Q_DATA_FIRST"] = 0.0
        overrides["LAMBDA_Q_PHYSICS_FIRST"] = 0.0
    elif args.lambda_q is not None:
        lam_q = float(args.lambda_q)
        overrides["LAMBDA_Q"] = lam_q
        overrides["LAMBDA_Q_DATA_FIRST"] = lam_q
        overrides["LAMBDA_Q_PHYSICS_FIRST"] = lam_q
    if args.disable_subs_resid:
        overrides["SUBS_RESID_POLICY_DATA_FIRST"] = "always_off"
        overrides["SUBS_RESID_POLICY_PHYSICS_FIRST"] = "always_off"
        overrides["ALLOW_SUBS_RESIDUAL"] = False
    if args.no_physics_ramp:
        overrides["PHYSICS_WARMUP_STEPS"] = 0
        overrides["PHYSICS_RAMP_STEPS"] = 0
    else:
        if args.physics_warmup_steps is not None:
            overrides["PHYSICS_WARMUP_STEPS"] = int(
                args.physics_warmup_steps
            )
        if args.physics_ramp_steps is not None:
            overrides["PHYSICS_RAMP_STEPS"] = int(args.physics_ramp_steps)
    if args.lambda_gw is not None:
        overrides["LAMBDA_GW"] = float(args.lambda_gw)
    if args.lambda_smooth is not None:
        overrides["LAMBDA_SMOOTH"] = float(args.lambda_smooth)
    if args.lambda_bounds is not None:
        overrides["LAMBDA_BOUNDS"] = float(args.lambda_bounds)
    overrides["LAMBDA_MV"] = (
        0.0 if args.lambda_mv is None else float(args.lambda_mv)
    )
    return overrides


def _failure_message(tag: str, exc: BaseException) -> str:
    """Format a run failure; exit codes for CalledProcessError."""
    if isinstance(exc, subprocess.CalledProcessError):
        return f"failed: {tag} (code={exc.returncode})"
    return f"failed: {tag} ({type(exc).__name__}: {exc})"


def main(
    argv: list[str] | None = None,
) -> None:
    """Run the (lambda_cons, lambda_prior) sensitivity grid.

    Resume (default) skips grid points already finished under the scan
    root. ``--start``/``--limit`` then index the *remaining* runs.
    Execution uses one of three runners:

    * gold: cached in-process context via ``sensitivity_lib``;
    * parallel: ``--n-jobs > 1`` subprocesses;
    * sequential: subprocess or ``--inprocess``.

    Raises
    ------
    SystemExit
        If the training script is missing, or (exit code 1) if any run
        failed.
    """
    args = parse_args(argv)
    train_script = Path(args.train_script)
    if not train_script.exists():
        raise SystemExit(
            f"Cannot find training script at: {train_script}"
        )

    base_env = os.environ.copy()
    city = base_env.get("CITY", "<unknown>")

    dev = resolve_device(args.device, env=base_env)
    gpus = []
    if dev == "gpu":
        gpus = resolve_gpu_ids(args.gpu_ids, env=base_env)
    if dev == "gpu" and not gpus:
        print("[Warn] device=gpu but no GPUs found.")
        print(" Falling back to CPU.")
        dev = "cpu"

    # Build full grid
    grid0 = build_grid(args.pde_modes, args.lcons, args.lprior)
    grid1 = maybe_shuffle(grid0, shuffle=args.shuffle, seed=args.seed)
    resume = not bool(args.no_resume)

    # Resolve scan root
    if args.scan_root is not None:
        scan_root = Path(args.scan_root)
    else:
        scan_root = _default_scan_root(city)

    # State file (optional)
    if args.state_file is not None:
        state_path = Path(args.state_file)
    else:
        state_path = scan_root / "lambda_sensitivity_state.json"

    completed: set[str] = set()
    if resume:
        completed = _scan_completed(scan_root, city=city)

    # Filter completed BEFORE slicing (no-op when nothing is completed).
    grid2, skipped = _filter_completed(grid1, completed)
    grid = apply_runner_slicing(
        grid2,
        start=args.start,
        limit=args.limit,
    )

    print("[Sensitivity] Setup")
    print(f" CITY : {city}")
    print(f" train_script : {train_script}")
    print(f" epochs/run : {args.epochs}")
    print(f" pde_modes : {list(args.pde_modes)}")
    print(f" lcons grid : {list(args.lcons)}")
    print(f" lprior grid : {list(args.lprior)}")
    print(f" strategy : {args.strategy}")
    print(f" resume : {resume}")
    print(f" scan_root : {scan_root}")
    print(f" done_found : {len(completed)}")
    print(f" skipped_done : {skipped}")
    print(f" start : {args.start}")
    print(f" limit : {args.limit}")
    print(f" shuffle : {bool(args.shuffle)}")
    print(f" seed : {args.seed}")
    print(f" runs : {len(grid)} / {len(grid0)}")
    print(f" dry_run : {bool(args.dry_run)}")
    print(f" continue_err : {bool(args.continue_on_error)}")

    if not grid:
        print("[Sensitivity] No runs selected. Done.")
        _save_state(
            state_path,
            city=city,
            scan_root=scan_root,
            completed_n=len(completed),
            last_key=None,
        )
        return

    # ---------------------------------------------------------
    # GOLD MODE: cached context + in-process per-point runs
    # ---------------------------------------------------------
    if bool(args.gold):
        # NOTE(review): return value unused; kept in case the call has
        # side effects -- confirm and drop if not.
        resolve_n_jobs(-1)
        t = threads_per_job(
            n_jobs=1,
            threads=int(args.threads or 0),
            reserve=1,
        )
        apply_tf_threading(intra=t, inter=max(1, min(4, t // 2)))

        if dev == "gpu":
            try:
                import tensorflow as tf

                for g in tf.config.list_physical_devices("GPU"):
                    tf.config.experimental.set_memory_growth(g, True)
            except Exception:
                # Best effort (TF missing or GPUs already initialised).
                pass

        if args.batch_size is not None:
            # The gold overrides carry no batch size, so warn instead
            # of silently ignoring the flag.
            print("[Warn] --batch-size is not applied in --gold mode.")

        # Build cached context ONCE
        ctx = build_context(city=city, verbose=1)

        # Without an explicit --scan-root, gold runs live inside the
        # Stage-1 run dir: scan there and keep the state file with them.
        if args.scan_root is None:
            scan_root = Path(ctx.base_output_dir)
            if args.state_file is None:
                state_path = scan_root / "lambda_sensitivity_state.json"

        if resume:
            completed = _scan_completed(scan_root, city=city)
            # Re-filter the FULL grid and re-slice. Filtering the
            # already-sliced grid made --start/--limit index the wrong
            # remaining set.
            remaining, _ = _filter_completed(grid1, completed)
            grid = apply_runner_slicing(
                remaining,
                start=args.start,
                limit=args.limit,
            )
        print(
            f"[Sensitivity GOLD] scan_root={scan_root} "
            f"runs={len(grid)}"
        )

        gold_failures: list[tuple[int, str]] = []
        last_done: str | None = None
        for i, spec in enumerate(grid):
            print("\n" + "=" * 62)
            print(f"[Sensitivity GOLD] Run {i + 1}/{len(grid)}")
            print(f" {spec.tag()}")
            print("=" * 62)

            overrides = _gold_overrides(args, spec)

            if bool(args.dry_run):
                print("[DryRun GOLD] would run:", spec.run_tag())
                continue

            try:
                run_dir = run_one_gold(
                    ctx,
                    overrides=overrides,
                    run_tag=spec.run_tag(),
                    stable_run_dir=True,
                    eval_max_batches=args.eval_max_batches,
                    cal_max_batches=args.eval_max_batches,
                )
                # Mark done in-memory + state file
                k = spec.key()
                completed.add(k)
                last_done = k
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=last_done,
                )
                cleanup_between_runs()
                print("[GOLD] run_dir ->", run_dir)
            except Exception as e:
                msg = _failure_message(spec.tag(), e)
                gold_failures.append((i, msg))
                print("[Sensitivity GOLD] ERROR:", msg)
                if not args.continue_on_error:
                    raise

        print("\n[Sensitivity GOLD] Finished.")
        if gold_failures:
            print("[Sensitivity GOLD] Failures:")
            for _, msg in gold_failures:
                print(" - " + msg)
            raise SystemExit(1)
        return  # IMPORTANT: don't fall through to the script runners

    nj = resolve_n_jobs(args.n_jobs)
    if dev == "gpu":
        # Single GPU => force n_jobs=1
        if len(gpus) <= 1 and nj > 1:
            print("[Warn] Single GPU detected.")
            print(" Forcing --n-jobs 1.")
            nj = 1
        # Multi GPU => cap workers to num GPUs (safe)
        if len(gpus) >= 2 and nj > len(gpus):
            print("[Warn] Capping jobs to GPUs.")
            nj = len(gpus)
    if nj > 1 and (args.gold or args.inprocess):
        print("[Warn] --n-jobs ignored with --gold/--inprocess.")
        nj = 1

    # ---------------------------------------------------------
    # PARALLEL: one subprocess per grid point, nj at a time
    # ---------------------------------------------------------
    if nj > 1:

        def _worker(i: int, spec: RunSpec) -> str:
            env1 = apply_thread_env(
                _env_for_spec(args, base_env, spec),
                n_jobs=nj,
                threads=int(args.threads or 0),
            )
            gid = None
            if dev == "gpu":
                gid = pick_gpu_id(i, gpus)
                env1 = apply_gpu_env(
                    env1,
                    gpu_id=gid,
                    allow_growth=bool(args.gpu_allow_growth),
                )
            _worker_banner(
                mode="Sensitivity",
                job_i=i,
                n_jobs=len(grid),
                pool=nj,
                run_tag=spec.run_tag(),
                device=dev,
                gpu_id=gid,
            )
            run_one_script(
                train_script,
                env=env1,
                dry_run=bool(args.dry_run),
                inprocess=False,
            )
            return spec.key()

        par_failures: list[str] = []
        with ThreadPoolExecutor(max_workers=nj) as ex:
            futs = {
                ex.submit(_worker, i, s): (i, s)
                for i, s in enumerate(grid)
            }
            for fut in as_completed(futs):
                i, spec = futs[fut]
                try:
                    k = fut.result()
                except Exception as e:
                    msg = f"failed: worker={i} {spec.tag()} ({e})"
                    par_failures.append(msg)
                    print("[Sensitivity] ERROR:", msg)
                    if not args.continue_on_error:
                        # Fail fast: don't start queued runs (running
                        # ones finish when the pool shuts down).
                        for other in futs:
                            other.cancel()
                        raise
                    continue
                if args.dry_run:
                    # Nothing ran: don't record fake completions.
                    continue
                completed.add(k)
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=k,
                )
        if par_failures:
            print("[Sensitivity] Failures:")
            for msg in par_failures:
                print(" - " + msg)
            raise SystemExit(1)
        return

    # ---------------------------------------------------------
    # SEQUENTIAL: subprocess (default) or --inprocess
    # ---------------------------------------------------------
    failures: list[tuple[int, str]] = []
    last_key: str | None = None
    for i, spec in enumerate(grid):
        tag = spec.tag()
        print("\n" + "=" * 62)
        print(f"[Sensitivity] Run {i + 1}/{len(grid)}")
        print(f" {tag}")
        print("=" * 62)

        env = _env_for_spec(args, base_env, spec)

        # sequential case: pick first visible GPU for clarity
        gid = pick_gpu_id(0, gpus) if dev == "gpu" else None
        _worker_banner(
            mode="Sensitivity",
            job_i=i,
            n_jobs=len(grid),
            pool=1,
            run_tag=spec.run_tag(),
            device=dev,
            gpu_id=gid,
        )

        try:
            run_one_script(
                train_script,
                env=env,
                dry_run=bool(args.dry_run),
                inprocess=bool(args.inprocess),
            )
        except Exception as e:
            # Covers CalledProcessError and in-process exceptions alike,
            # so --continue-on-error also works with --inprocess.
            msg = _failure_message(tag, e)
            failures.append((i, msg))
            print("[Sensitivity] ERROR: " + msg)
            if not args.continue_on_error:
                _save_state(
                    state_path,
                    city=city,
                    scan_root=scan_root,
                    completed_n=len(completed),
                    last_key=last_key,
                )
                raise
            continue

        if args.dry_run:
            # Nothing ran: don't record fake completions.
            continue

        # Mark done in-memory (useful if rerun same process)
        k = spec.key()
        completed.add(k)
        last_key = k
        _save_state(
            state_path,
            city=city,
            scan_root=scan_root,
            completed_n=len(completed),
            last_key=last_key,
        )

    print("\n[Sensitivity] Finished.")
    if failures:
        print("[Sensitivity] Failures:")
        for _, msg in failures:
            print(" - " + msg)
        raise SystemExit(1)

    print(
        "You can now run make_supp_figS6_ablations.py "
        "over the same --root to build the tidy table "
        "+ Supplement S6 figure."
    )
# Script entry point: go through the geoprior-run wrapper (config
# install/overrides, env seeding), which then calls main() with the
# remaining driver options.
if __name__ == "__main__": sensitivity_main()