# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""CLI for building merged full-input NPZ artifacts."""
from __future__ import annotations
import argparse
import ast
import json
import shutil
from pathlib import Path
from typing import Any
import numpy as np
from geoprior.utils.nat_utils import (
ensure_config_json,
get_config_paths,
)
def _parse_override_value(text: str) -> Any:
s = str(text).strip()
if not s:
return s
low = s.lower()
if low in {"true", "false"}:
return low == "true"
if low in {"none", "null"}:
return None
try:
return ast.literal_eval(s)
except Exception:
return s
def _parse_set_items(
items: list[str] | None,
) -> dict[str, Any]:
out: dict[str, Any] = {}
for item in items or []:
if "=" not in item:
raise SystemExit(
f"Each --set must be KEY=VALUE. Got: {item!r}"
)
key, value = item.split("=", 1)
key = key.strip()
if not key:
raise SystemExit(
f"Invalid empty key in: {item!r}"
)
out[key] = _parse_override_value(value)
return out
def _install_user_config(
config_path: str,
*,
config_root: str = "nat.com",
) -> str:
src = Path(config_path).expanduser().resolve()
if not src.exists():
raise FileNotFoundError(
f"Config file not found: {src}"
)
config_py, config_json = get_config_paths(
root=config_root
)
dst = Path(config_py).expanduser().resolve()
dst.parent.mkdir(
parents=True,
exist_ok=True,
)
if src != dst:
shutil.copy2(src, dst)
json_path = Path(config_json)
if json_path.exists():
json_path.unlink()
return str(dst)
def _persist_runtime_overrides(
overrides: dict[str, Any] | None = None,
*,
config_root: str = "nat.com",
) -> dict[str, Any]:
cfg0, config_json = ensure_config_json(root=config_root)
cfg = dict(cfg0)
if not overrides:
return cfg
cfg.update(overrides)
payload: dict[str, Any] = {}
cfg_json = Path(config_json)
if cfg_json.exists():
try:
payload = json.loads(
cfg_json.read_text(encoding="utf-8")
)
except Exception:
payload = {}
payload["city"] = cfg.get("CITY_NAME")
payload["model"] = cfg.get("MODEL_NAME")
payload["config"] = cfg
payload.setdefault("__meta__", {})
cfg_json.write_text(
json.dumps(payload, indent=2),
encoding="utf-8",
)
return cfg
def _load_npz_dict(path: str | Path) -> dict[str, np.ndarray]:
with np.load(path, allow_pickle=False) as z:
return {k: z[k] for k in z.files}
def _save_npz(
path: str | Path,
data: dict[str, np.ndarray],
) -> None:
np.savez_compressed(path, **data)
def _choose_output_path(
manifest_path: Path,
output: str | None,
*,
output_name: str,
) -> Path:
if output:
out = Path(output).expanduser().resolve()
out.parent.mkdir(
parents=True,
exist_ok=True,
)
return out
artifacts_dir = manifest_path.parent / "artifacts"
artifacts_dir.mkdir(
parents=True,
exist_ok=True,
)
return artifacts_dir / output_name
def _resolve_manifest_path(
*,
manifest: str | None,
stage1_dir: str | None,
results_dir: str | None,
city: str | None,
model: str | None,
) -> Path:
candidates: list[Path] = []
if manifest:
p = Path(manifest).expanduser().resolve()
if not p.exists():
raise FileNotFoundError(
f"Manifest not found: {p}"
)
return p
if stage1_dir:
p = (
Path(stage1_dir).expanduser().resolve()
/ "manifest.json"
)
if p.exists():
return p
candidates.append(p)
if results_dir and city and model:
p = (
Path(results_dir).expanduser().resolve()
/ f"{str(city).strip().lower()}_{model}_stage1"
/ "manifest.json"
)
if p.exists():
return p
candidates.append(p)
if results_dir:
root = Path(results_dir).expanduser().resolve()
found = sorted(
root.glob("*_stage1/manifest.json"),
key=lambda x: x.stat().st_mtime,
reverse=True,
)
if len(found) == 1:
return found[0]
if found:
preview = "\n - ".join(str(p) for p in found[:8])
raise SystemExit(
"Multiple Stage-1 manifests found. "
"Pass --manifest or --stage1-dir.\n"
f" - {preview}"
)
tried = "\n - ".join(str(p) for p in candidates)
msg = (
"Could not resolve a Stage-1 manifest. "
"Pass --manifest or --stage1-dir explicitly."
)
if tried:
msg += f"\nTried:\n - {tried}"
raise SystemExit(msg)
def _extract_split_paths(
manifest_payload: dict[str, Any],
splits: list[str],
) -> dict[str, Path]:
npz_art = manifest_payload.get("artifacts", {}).get(
"numpy", {}
)
out: dict[str, Path] = {}
for split in splits:
key = f"{split}_inputs_npz"
path = npz_art.get(key)
if not path:
raise KeyError(
f"Missing manifest artifact key: {key}"
)
p = Path(path).expanduser().resolve()
if not p.exists():
raise FileNotFoundError(
f"Missing NPZ for split={split!r}: {p}"
)
out[split] = p
return out
def _merge_inputs(
split_arrays: dict[str, dict[str, np.ndarray]],
*,
strict_keys: bool = True,
) -> dict[str, np.ndarray]:
if not split_arrays:
raise ValueError("No split arrays were provided.")
split_names = list(split_arrays)
first = split_arrays[split_names[0]]
first_keys = list(first)
if strict_keys:
ref = set(first_keys)
for split in split_names[1:]:
keys = set(split_arrays[split])
if keys != ref:
missing = sorted(ref - keys)
extra = sorted(keys - ref)
raise ValueError(
"Input-key mismatch across splits: "
f"split={split!r}, missing={missing}, "
f"extra={extra}"
)
keys = sorted(ref)
else:
keys = []
seen: set[str] = set()
for split in split_names:
for key in split_arrays[split]:
if key not in seen:
seen.add(key)
keys.append(key)
merged: dict[str, np.ndarray] = {}
for key in keys:
arrays: list[np.ndarray] = []
for split in split_names:
arr = split_arrays[split].get(key)
if arr is None:
if strict_keys:
raise KeyError(
f"Missing key {key!r} in split {split!r}."
)
continue
arrays.append(arr)
if not arrays:
continue
if len(arrays) == 1:
merged[key] = arrays[0]
else:
merged[key] = np.concatenate(
arrays,
axis=0,
)
return merged
def _build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="full-inputs-npz",
description=(
"Build a merged input NPZ from Stage-1 split input "
"artifacts (for example train+val+test)."
),
)
p.add_argument(
"--manifest",
type=str,
default=None,
help="Path to a Stage-1 manifest.json.",
)
p.add_argument(
"--stage1-dir",
type=str,
default=None,
help=(
"Stage-1 output directory containing manifest.json."
),
)
p.add_argument(
"--results-dir",
type=str,
default=None,
help=(
"Results root used to auto-resolve "
"<city>_<model>_stage1/manifest.json."
),
)
p.add_argument(
"--city",
type=str,
default=None,
help="City or dataset label used for manifest auto-resolution.",
)
p.add_argument(
"--model",
type=str,
default=None,
help="Model name used for manifest auto-resolution.",
)
p.add_argument(
"--splits",
nargs="+",
default=["train", "val", "test"],
help=(
"Input splits to concatenate in order. "
"Example: --splits train val test"
),
)
p.add_argument(
"--output",
type=str,
default=None,
help=(
"Output NPZ path. If omitted, the file is written under "
"<stage1_dir>/artifacts/."
),
)
p.add_argument(
"--output-name",
type=str,
default="full_inputs.npz",
help=(
"Default file name used when --output is omitted."
),
)
p.add_argument(
"--allow-missing-keys",
action="store_true",
help=(
"Use the union of keys across splits instead of requiring "
"all splits to share the same keys."
),
)
p.add_argument(
"--config",
type=str,
default=None,
help=(
"Optional config.py to install into nat.com/config.py "
"before resolving defaults."
),
)
p.add_argument(
"--config-root",
type=str,
default="nat.com",
help="Configuration root directory.",
)
p.add_argument(
"--set",
dest="set_items",
action="append",
default=None,
metavar="KEY=VALUE",
help=(
"One-off config override persisted into config.json. "
"May be repeated."
),
)
return p
[docs]
def build_full_inputs_main(
argv: list[str] | None = None,
) -> None:
args = _build_parser().parse_args(argv)
if args.config:
_install_user_config(
args.config,
config_root=args.config_root,
)
overrides = _parse_set_items(args.set_items)
cfg = _persist_runtime_overrides(
overrides,
config_root=args.config_root,
)
city = args.city or cfg.get("CITY_NAME")
model = args.model or cfg.get("MODEL_NAME")
results_dir = (
args.results_dir
or cfg.get("RESULTS_DIR")
or "results"
)
out = build_full_inputs_npz(
manifest=args.manifest,
stage1_dir=args.stage1_dir,
output=args.output,
output_name=args.output_name,
splits=list(args.splits),
strict_keys=not args.allow_missing_keys,
results_dir=results_dir,
city=city,
model=model,
)
print(f"Saved: {out}")
data = _load_npz_dict(out)
for key, value in sorted(data.items()):
print(f"{key}: {tuple(value.shape)}")
[docs]
def main(argv: list[str] | None = None) -> None:
build_full_inputs_main(argv)
if __name__ == "__main__":
main()