# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
"""Shared CLI configuration helpers.

This module centralises the parser options and runtime helpers that
repeat across GeoPrior CLI commands. The goal is to keep command
modules small, consistent, and easy to maintain.

Scope
-----
This module is intentionally limited to:

- repeated parser arguments such as ``--config`` and ``--set``
- light argument aliases and normalisation
- config installation and runtime override persistence
- small path utilities used by many commands

It does **not** own command-specific business logic. Each command keeps
its own artifact resolution and domain-specific validation.

Examples
--------
Build a parser with shared arguments::

    import argparse
    from geoprior.cli._config import (
        add_city_arg,
        add_config_args,
        add_outdir_arg,
        add_results_dir_arg,
    )

    p = argparse.ArgumentParser()
    add_config_args(p)
    add_city_arg(p)
    add_results_dir_arg(p)
    add_outdir_arg(p)

Apply config installation and runtime overrides::

    cfg = bootstrap_runtime_config(
        args,
        field_map={
            "city": "CITY_NAME",
            "model": "MODEL_NAME",
            "results_dir": "RESULTS_DIR",
        },
    )

The returned ``cfg`` is the effective config dictionary after optional
config installation and any ``--set KEY=VALUE`` overrides.
"""
from __future__ import annotations
import argparse
import ast
import json
import shutil
from collections.abc import Callable
from pathlib import Path
from typing import Any
from ..utils.nat_utils import (
ensure_config_json,
get_config_paths,
load_nat_config,
)
# Effective configuration mapping: config key -> value of any type.
ConfigDict = dict[str, Any]
# Callback that takes a config mapping and returns the mapping to use
# afterwards (used to refresh derived fields around override application).
RefreshFn = Callable[[ConfigDict], ConfigDict]
[docs]
def parse_override_value(raw: str) -> Any:
    """Parse a scalar or container value from ``--set``.

    Parameters
    ----------
    raw : str
        Raw string value from the CLI.

    Returns
    -------
    Any
        Parsed Python object when possible, otherwise the stripped
        string.

    Notes
    -----
    The parsing order is conservative:

    1. case-insensitive booleans and ``none``
    2. zero-padded digit strings such as ``007`` are kept as text
    3. integer / float literals
    4. ``ast.literal_eval`` for lists, tuples, dicts, and quoted text
    5. fallback to the stripped input string
    """
    text = str(raw).strip()

    keywords = {"none": None, "true": True, "false": False}
    low = text.lower()
    if low in keywords:
        return keywords[low]

    # Zero-padded digit strings ("007", "0123") are identifiers rather
    # than numbers. Without this early return they skip the int branch
    # below but are then swallowed by ``float`` and come back as 7.0.
    if len(text) > 1 and text.startswith("0") and text.isdigit():
        return text

    # Any other text starting with "0" ("0.5", "0x10", "0e3") is left to
    # the float / literal parsers; only a bare "0" is tried as an int.
    if text == "0" or not text.startswith("0"):
        try:
            return int(text)
        except ValueError:
            pass

    try:
        return float(text)
    except ValueError:
        pass

    try:
        return ast.literal_eval(text)
    except (
        ValueError,
        SyntaxError,
        TypeError,
        MemoryError,
        RecursionError,
    ):
        # Not a Python literal: hand the text back unchanged.
        return text
[docs]
def parse_set_items(
    items: list[str] | tuple[str, ...] | None,
) -> ConfigDict:
    """Parse repeated ``--set KEY=VALUE`` items.

    Parameters
    ----------
    items : list of str, tuple of str, or None
        Raw ``KEY=VALUE`` strings collected from the CLI.

    Returns
    -------
    dict
        Mapping from stripped key to the value produced by
        ``parse_override_value``. A later item with the same key
        replaces an earlier one.

    Raises
    ------
    SystemExit
        If an item has no ``=`` or its key is empty after stripping.
    """
    parsed: ConfigDict = {}
    if not items:
        return parsed

    for entry in items:
        # Split on the first "=" only, so values may contain "=".
        name, sep, raw_value = entry.partition("=")
        if not sep:
            raise SystemExit(
                f"Each --set must be KEY=VALUE. Got: {entry!r}"
            )
        name = name.strip()
        if not name:
            raise SystemExit(f"Invalid empty key in: {entry!r}")
        parsed[name] = parse_override_value(raw_value)
    return parsed
[docs]
def install_user_config(
    config_path: str,
    *,
    config_root: str = "nat.com",
) -> str:
    """Install a user ``config.py`` into the active config root.

    Parameters
    ----------
    config_path : str
        Path to the user-supplied config file. ``~`` is expanded and
        the path is resolved before use.
    config_root : str, default="nat.com"
        Config root passed to ``get_config_paths``.

    Returns
    -------
    str
        Resolved path of the installed ``config.py``.

    Raises
    ------
    FileNotFoundError
        If ``config_path`` does not point to an existing regular file.
    """
    src = Path(config_path).expanduser().resolve()
    # ``is_file`` also rejects directories, which would otherwise fail
    # later inside ``copy2`` with an unrelated error.
    if not src.is_file():
        raise FileNotFoundError(f"Config file not found: {src}")

    config_py, config_json = get_config_paths(root=config_root)
    dst = Path(config_py).expanduser().resolve()
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Copying a file onto itself raises, so skip when already in place.
    if src != dst:
        shutil.copy2(src, dst)

    # Remove the JSON sitting next to the config -- presumably so it is
    # rebuilt from the newly installed config.py. ``missing_ok`` avoids
    # the exists()/unlink() race of a two-step check.
    Path(config_json).unlink(missing_ok=True)
    return str(dst)
[docs]
def persist_runtime_overrides(
    overrides: ConfigDict | None = None,
    *,
    config_root: str = "nat.com",
    refresh_fn: RefreshFn | None = None,
) -> ConfigDict:
    """Persist effective config to ``config.json``.

    Parameters
    ----------
    overrides : dict or None
        Optional config overrides to merge into the active config.
    config_root : str, default="nat.com"
        Config root directory.
    refresh_fn : callable or None
        Optional callback used to refresh derived fields. It runs once
        on the loaded config and again after overrides are applied.

    Returns
    -------
    dict
        The effective config that was written under the ``"config"``
        key of ``config.json``.
    """
    base_cfg, config_json = ensure_config_json(root=config_root)

    # Work on a copy so the mapping returned by the loader is untouched.
    cfg = dict(base_cfg)
    if refresh_fn is not None:
        cfg = refresh_fn(cfg)
    if overrides:
        cfg.update(overrides)
        if refresh_fn is not None:
            cfg = refresh_fn(cfg)

    cfg_json = Path(config_json)
    payload: dict[str, Any] = {}
    if cfg_json.exists():
        # Best effort: an unreadable or corrupt file is replaced rather
        # than aborting the command.
        try:
            loaded = json.loads(cfg_json.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            loaded = None
        # Valid JSON that is not an object (list, string, ...) cannot
        # take the keys below, so treat it like a corrupt file.
        if isinstance(loaded, dict):
            payload = loaded

    payload["city"] = cfg.get("CITY_NAME")
    payload["model"] = cfg.get("MODEL_NAME")
    payload["config"] = cfg
    # Keep any existing metadata; only guarantee that the key exists.
    payload.setdefault("__meta__", {})

    cfg_json.write_text(
        json.dumps(payload, indent=2),
        encoding="utf-8",
    )
    return cfg
[docs]
def args_to_config_overrides(
    args: argparse.Namespace,
    *,
    field_map: dict[str, str] | None = None,
) -> ConfigDict:
    """Translate parsed CLI arguments into config overrides.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI namespace. Its ``sets`` attribute, when present,
        holds the raw ``--set KEY=VALUE`` items.
    field_map : dict[str, str] or None
        Mapping from argument field name to config key.

    Returns
    -------
    dict
        Overrides built from the ``--set`` items first, then from the
        mapped fields, so a mapped field replaces a ``--set`` entry
        with the same config key. Fields that are missing, ``None``,
        or blank strings are ignored.
    """
    overrides = parse_set_items(getattr(args, "sets", None))
    if not field_map:
        return overrides

    for attr_name, cfg_key in field_map.items():
        candidate = getattr(args, attr_name, None)
        if candidate is None:
            continue
        if isinstance(candidate, str):
            candidate = candidate.strip()
            # Whitespace-only text counts as "not provided".
            if candidate == "":
                continue
        overrides[cfg_key] = candidate
    return overrides
[docs]
def bootstrap_runtime_config(
    args: argparse.Namespace,
    *,
    field_map: dict[str, str] | None = None,
    refresh_fn: RefreshFn | None = None,
) -> ConfigDict:
    """Install config, apply overrides, and return effective cfg.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI namespace. ``config_root`` (default ``"nat.com"``)
        and ``config`` are read from it when present.
    field_map : dict[str, str] or None
        Mapping from argument field name to config key, forwarded to
        ``args_to_config_overrides``.
    refresh_fn : callable or None
        Optional callback used to refresh derived fields.

    Returns
    -------
    dict
        A fresh dictionary holding the config reloaded after the
        overrides were persisted (and refreshed when ``refresh_fn`` is
        given).
    """
    root = getattr(args, "config_root", "nat.com")

    # Step 1: optionally install the user-provided config file.
    user_config = getattr(args, "config", None)
    if user_config:
        destination = install_user_config(
            user_config, config_root=root
        )
        print(f"[Config] Using: {destination}")

    # Step 2: merge CLI overrides and write them to disk.
    persist_runtime_overrides(
        args_to_config_overrides(args, field_map=field_map),
        config_root=root,
        refresh_fn=refresh_fn,
    )

    # Step 3: reload what was persisted and hand back a private copy.
    effective = load_nat_config(root=root)
    if refresh_fn is None:
        return dict(effective)
    return dict(refresh_fn(dict(effective)))
[docs]
def ensure_outdir(
    outdir: str | Path,
) -> Path:
    """Create and return an output directory path.

    Parameters
    ----------
    outdir : str or Path
        Directory to create. ``~`` is expanded and the path is
        resolved to an absolute one.

    Returns
    -------
    Path
        The resolved directory. Missing parents are created and an
        already existing directory is accepted.
    """
    target = Path(outdir).expanduser()
    target = target.resolve()
    target.mkdir(exist_ok=True, parents=True)
    return target
[docs]
def find_latest_dir(
    root: str | Path,
    *,
    pattern: str = "*",
    must_contain: str | None = None,
) -> Path | None:
    """Return the newest matching directory under ``root``.

    Parameters
    ----------
    root : str or Path
        Directory to search. ``~`` is expanded and the path resolved.
    pattern : str, default="*"
        Glob pattern evaluated relative to ``root``.
    must_contain : str or None
        When given, only directories that contain an entry with this
        relative name are considered.

    Returns
    -------
    Path or None
        The matching directory with the largest modification time, or
        ``None`` when ``root`` does not exist or nothing matches.
    """
    search_root = Path(root).expanduser().resolve()
    if not search_root.exists():
        return None

    def _is_candidate(entry: Path) -> bool:
        # Keep directories only, optionally requiring a marker entry.
        if not entry.is_dir():
            return False
        return must_contain is None or (entry / must_contain).exists()

    def _modified_at(entry: Path) -> float:
        return entry.stat().st_mtime

    matches = [
        entry
        for entry in search_root.glob(pattern)
        if _is_candidate(entry)
    ]
    if not matches:
        return None
    return max(matches, key=_modified_at)
[docs]
def add_config_args(
    parser: argparse.ArgumentParser,
    *,
    include_root: bool = True,
    include_set: bool = True,
) -> argparse.ArgumentParser:
    """Add shared config installation and override arguments.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    include_root : bool, default=True
        Whether to register ``--config-root``.
    include_set : bool, default=True
        Whether to register the repeatable ``--set KEY=VALUE`` option
        (stored on the namespace as ``sets``).

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    register = parser.add_argument

    config_help = (
        "Optional config.py to install into the active "
        "config root before running."
    )
    register("--config", type=str, default=None, help=config_help)

    if include_root:
        register(
            "--config-root",
            type=str,
            default="nat.com",
            help="Config root directory.",
        )

    if include_set:
        set_help = (
            "Extra config override. Repeat as needed, "
            "for example --set TIME_STEPS=6."
        )
        register(
            "--set",
            dest="sets",
            action="append",
            default=[],
            metavar="KEY=VALUE",
            help=set_help,
        )

    return parser
[docs]
def add_city_arg(
    parser: argparse.ArgumentParser,
    *,
    dest: str = "city",
    default: str | None = None,
    required: bool = False,
    action: str | None = None,
    help: str | None = None,
) -> argparse.ArgumentParser:
    """Add one or repeated ``--city`` arguments.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    dest : str, default="city"
        Namespace attribute that receives the value.
    default : str or None
        Value used when the option is absent.
    required : bool, default=False
        Whether the option must be supplied.
    action : str or None
        Optional argparse action (for example ``"append"``). It is
        only forwarded when given.
    help : str or None
        Custom help text; a generic message is used when empty.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    help_text = help if help else "Override CITY_NAME for this command."
    options: dict[str, object] = dict(
        dest=dest,
        type=str,
        default=default,
        required=required,
        help=help_text,
    )
    # Leave argparse's default action in place unless one was requested.
    if action is not None:
        options["action"] = action
    parser.add_argument("--city", **options)
    return parser
[docs]
def add_model_arg(
    parser: argparse.ArgumentParser,
    *,
    dest: str = "model",
    default: str | None = None,
    required: bool = False,
    help: str | None = None,
) -> argparse.ArgumentParser:
    """Add ``--model`` argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    dest : str, default="model"
        Namespace attribute that receives the value.
    default : str or None
        Value used when the option is absent.
    required : bool, default=False
        Whether the option must be supplied.
    help : str or None
        Custom help text; a generic message is used when empty.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    help_text = help if help else "Override MODEL_NAME for this command."
    options = {
        "dest": dest,
        "type": str,
        "default": default,
        "required": required,
        "help": help_text,
    }
    parser.add_argument("--model", **options)
    return parser
[docs]
def add_results_dir_arg(
    parser: argparse.ArgumentParser,
    *,
    dest: str = "results_dir",
    default: str | None = None,
) -> argparse.ArgumentParser:
    """Add results directory argument with a root alias.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    dest : str, default="results_dir"
        Namespace attribute shared by ``--results-dir`` and its alias
        ``--results-root``.
    default : str or None
        Value used when neither option is given.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    option_names = ("--results-dir", "--results-root")
    help_text = (
        "Results directory or results root. Both option names "
        "map to the same destination."
    )
    parser.add_argument(
        *option_names,
        dest=dest,
        type=str,
        default=default,
        help=help_text,
    )
    return parser
[docs]
def add_manifest_arg(
    parser: argparse.ArgumentParser,
    *,
    dest: str = "manifest",
    option: str = "--manifest",
    help_text: str | None = None,
) -> argparse.ArgumentParser:
    """Add a manifest path argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    dest : str, default="manifest"
        Namespace attribute that receives the value.
    option : str, default="--manifest"
        Option string to register.
    help_text : str or None
        Custom help text; a generic message is used when empty.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    description = help_text if help_text else "Explicit manifest path."
    options = {
        "dest": dest,
        "type": str,
        "default": None,
        "help": description,
    }
    parser.add_argument(option, **options)
    return parser
[docs]
def add_stage1_dir_arg(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Add ``--stage1-dir`` argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining. The value is stored as
        ``stage1_dir`` and defaults to ``None``.
    """
    options = {
        "type": str,
        "default": None,
        "help": "Stage-1 output directory.",
    }
    parser.add_argument("--stage1-dir", **options)
    return parser
[docs]
def add_outdir_arg(
    parser: argparse.ArgumentParser,
    *,
    dest: str = "outdir",
    default: str | None = None,
    required: bool = False,
    help: str | None = None,
) -> argparse.ArgumentParser:
    """Add output directory argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    dest : str, default="outdir"
        Namespace attribute that receives the value.
    default : str or None
        Value used when ``--outdir`` is absent.
    required : bool, default=False
        Whether the option must be supplied.
    help : str or None
        Custom help text; a generic message is used when empty.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    help_text = help if help else "Output directory."
    options = {
        "dest": dest,
        "type": str,
        "default": default,
        "required": required,
        "help": help_text,
    }
    parser.add_argument("--outdir", **options)
    return parser
[docs]
def add_output_stem_arg(
    parser: argparse.ArgumentParser,
    *,
    default: str | None = None,
) -> argparse.ArgumentParser:
    """Add an output stem argument for multi-file commands.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    default : str or None
        Value used when ``--output-stem`` is absent.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining. The value is stored as
        ``output_stem``.
    """
    help_text = (
        "Base file stem used when the command writes multiple "
        "related outputs."
    )
    options = {"type": str, "default": default, "help": help_text}
    parser.add_argument("--output-stem", **options)
    return parser
[docs]
def add_stage2_manifest_arg(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Add ``--stage2-manifest`` argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining. The value is stored as
        ``stage2_manifest`` and defaults to ``None``.
    """
    options = {
        "type": str,
        "default": None,
        "help": "Stage-2 manifest path.",
    }
    parser.add_argument("--stage2-manifest", **options)
    return parser
[docs]
def add_split_arg(
    parser: argparse.ArgumentParser,
    *,
    default: str | None = None,
    choices: tuple[str, ...] | None = None,
) -> argparse.ArgumentParser:
    """Add a reusable dataset split argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    default : str or None
        Value used when ``--split`` is absent.
    choices : tuple of str or None
        Allowed values, forwarded to argparse as ``choices``.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining.
    """
    options = {
        "type": str,
        "default": default,
        "choices": choices,
        "help": "Dataset split name.",
    }
    parser.add_argument("--split", **options)
    return parser
[docs]
def add_validation_csv_arg(
    parser: argparse.ArgumentParser,
    *,
    required: bool = False,
) -> argparse.ArgumentParser:
    """Add external validation CSV argument.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.
    required : bool, default=False
        Whether ``--validation-csv`` must be supplied.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, to allow chaining. The value is stored as
        ``validation_csv`` and defaults to ``None``.
    """
    options = {
        "type": str,
        "default": None,
        "required": required,
        "help": "External validation CSV path.",
    }
    parser.add_argument("--validation-csv", **options)
    return parser