# Source code for geoprior.cli.init_config

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""CLI for initializing GeoPrior configuration files."""

from __future__ import annotations

import argparse
from pathlib import Path
from string import Template

from ..utils.nat_utils import (
    ensure_config_json,
    get_config_paths,
    get_natcom_dir,
)

# Built-in fallback template for the generated config module.
#
# The text is rendered with ``string.Template`` (see ``main``), so
# every ``$NAME`` / ``${NAME}`` token needs a matching key in the
# dict returned by ``_build_values``. The ``{city}`` / ``{variant}``
# fields are NOT Template placeholders: they pass through rendering
# untouched and are resolved by ``str.format`` inside the generated
# file. ``_resolve_template`` returns this text when no template
# file on disk qualifies.
_DEFAULT_TEMPLATE = """# SPDX-License-Identifier: Apache-2.0
# Auto-generated NATCOM config for GeoPrior-v3.
# Edit values here, then rerun:
#   geoprior-init
# or
#   python -m geoprior.cli init-config

# ---------------------------------------------------------
# 1) Dataset identity and file naming
# ---------------------------------------------------------
CITY_NAME = "$CITY_NAME"
MODEL_NAME = "$MODEL_NAME"
DATA_DIR = "$DATA_DIR"
DATASET_VARIANT = "$DATASET_VARIANT"

# Input file templates. These are resolved into BIG_FN and
# SMALL_FN by geoprior.utils.nat_utils.
BIG_FN_TEMPLATE = "{city}_${BIG_STEM}_{variant}.csv"
SMALL_FN_TEMPLATE = "{city}_${SMALL_STEM}_{variant}.csv"

BIG_FN = BIG_FN_TEMPLATE.format(
    city=CITY_NAME,
    variant=DATASET_VARIANT,
)
SMALL_FN = SMALL_FN_TEMPLATE.format(
    city=CITY_NAME,
    variant=DATASET_VARIANT,
)

ALL_CITIES_PARQUET = "$ALL_CITIES_PARQUET"

# ---------------------------------------------------------
# 2) Time layout
# ---------------------------------------------------------
TRAIN_END_YEAR = $TRAIN_END_YEAR
FORECAST_START_YEAR = $FORECAST_START_YEAR
FORECAST_HORIZON_YEARS = $FORECAST_HORIZON_YEARS
TIME_STEPS = $TIME_STEPS
MODE = "$MODE"

# ---------------------------------------------------------
# 3) Required columns
# ---------------------------------------------------------
TIME_COL = "$TIME_COL"
LON_COL = "$LON_COL"
LAT_COL = "$LAT_COL"
SUBSIDENCE_COL = "$SUBSIDENCE_COL"
GWL_COL = "$GWL_COL"

# Optional extra columns used by some workflows
H_FIELD_COL_NAME = "$H_FIELD_COL_NAME"
Z_SURF_COL = "$Z_SURF_COL"
HEAD_COL = "$HEAD_COL"
INCLUDE_Z_SURF_AS_STATIC = $INCLUDE_Z_SURF_AS_STATIC

# ---------------------------------------------------------
# 4) Physical interpretation of GWL
# ---------------------------------------------------------
GWL_KIND = "$GWL_KIND"
GWL_SIGN = "$GWL_SIGN"
USE_HEAD_PROXY = $USE_HEAD_PROXY

# ---------------------------------------------------------
# 5) Feature groups
# ---------------------------------------------------------
FUTURE_DRIVER_FEATURES = [
$FUTURE_DRIVER_FEATURES
]

OPTIONAL_NUMERIC_FEATURES = [
$OPTIONAL_NUMERIC_FEATURES
]

# ---------------------------------------------------------
# 6) Forecast / model defaults
# ---------------------------------------------------------
QUANTILES = [0.10, 0.50, 0.90]
PDE_MODE_CONFIG = "$PDE_MODE_CONFIG"
NORMALIZE_COORDS = $NORMALIZE_COORDS
KEEP_COORDS_RAW = $KEEP_COORDS_RAW
SCALE_GWL = $SCALE_GWL
SCALE_H_FIELD = $SCALE_H_FIELD
SCALE_Z_SURF = $SCALE_Z_SURF

# ---------------------------------------------------------
# 7) Training defaults
# ---------------------------------------------------------
BATCH_SIZE = $BATCH_SIZE
EPOCHS = $EPOCHS
LEARNING_RATE = $LEARNING_RATE
HIDDEN_UNITS = $HIDDEN_UNITS
LSTM_UNITS = [$LSTM_UNITS]
ATTENTION_UNITS = $ATTENTION_UNITS
NUMBER_HEADS = $NUMBER_HEADS
DROPOUT_RATE = $DROPOUT_RATE
USE_BATCH_NORM = $USE_BATCH_NORM
USE_VSN = $USE_VSN

# ---------------------------------------------------------
# 8) Scaling / bookkeeping
# ---------------------------------------------------------
SCALES = {
    "subsidence": "$SUBS_SCALE",
    "gwl": "$GWL_SCALE",
}
AUDIT_STAGES = ["stage1", "stage2", "stage3"]
VERBOSE = 1
"""

_PROMPTS = (
    ("CITY_NAME", "City name", "nansha"),
    (
        "MODEL_NAME",
        "Model name",
        "GeoPriorSubsNet",
    ),
    ("DATA_DIR", "Data dir", "."),
    (
        "DATASET_VARIANT",
        "Dataset variant",
        "with_zsurf",
    ),
    (
        "ALL_CITIES_PARQUET",
        "All-cities parquet",
        "",
    ),
    (
        "TRAIN_END_YEAR",
        "Train end year",
        "2022",
    ),
    (
        "FORECAST_START_YEAR",
        "Forecast start year",
        "2023",
    ),
    (
        "FORECAST_HORIZON_YEARS",
        "Forecast horizon years",
        "3",
    ),
    ("TIME_STEPS", "Input time steps", "5"),
    ("MODE", "Mode", "tft_like"),
    ("TIME_COL", "Time column", "year"),
    ("LON_COL", "Longitude column", "longitude"),
    ("LAT_COL", "Latitude column", "latitude"),
    (
        "SUBSIDENCE_COL",
        "Subsidence column",
        "subsidence_cum",
    ),
    (
        "GWL_COL",
        "GWL column",
        "GWL_depth_bgs_m",
    ),
    (
        "H_FIELD_COL_NAME",
        "H-field thickness column",
        "soil_thickness",
    ),
    ("Z_SURF_COL", "Surface elev. col", "z_surf_m"),
    ("HEAD_COL", "Head column", "head_m"),
    (
        "INCLUDE_Z_SURF_AS_STATIC",
        "Include z_surf as static [True/False]",
        "True",
    ),
    ("GWL_KIND", "GWL kind", "depth_bgs"),
    (
        "GWL_SIGN",
        "GWL sign",
        "down_positive",
    ),
    (
        "USE_HEAD_PROXY",
        "Use head proxy [True/False]",
        "False",
    ),
    (
        "PDE_MODE_CONFIG",
        "PDE mode",
        "on",
    ),
    (
        "NORMALIZE_COORDS",
        "Normalize coords [True/False]",
        "True",
    ),
    (
        "KEEP_COORDS_RAW",
        "Keep raw coords [True/False]",
        "True",
    ),
    (
        "SCALE_GWL",
        "Scale GWL [True/False]",
        "False",
    ),
    (
        "SCALE_H_FIELD",
        "Scale H-field [True/False]",
        "False",
    ),
    (
        "SCALE_Z_SURF",
        "Scale z_surf [True/False]",
        "False",
    ),
    ("BATCH_SIZE", "Batch size", "32"),
    ("EPOCHS", "Epochs", "100"),
    (
        "LEARNING_RATE",
        "Learning rate",
        "0.001",
    ),
    ("HIDDEN_UNITS", "Hidden units", "128"),
    (
        "LSTM_UNITS",
        "LSTM units (csv)",
        "128,64",
    ),
    (
        "ATTENTION_UNITS",
        "Attention units",
        "128",
    ),
    ("NUMBER_HEADS", "Attention heads", "4"),
    (
        "DROPOUT_RATE",
        "Dropout rate",
        "0.10",
    ),
    (
        "USE_BATCH_NORM",
        "Use batch norm [True/False]",
        "False",
    ),
    ("USE_VSN", "Use VSN [True/False]", "True"),
    (
        "SUBS_SCALE",
        "Subsidence scale",
        "auto",
    ),
    ("GWL_SCALE", "GWL scale", "auto"),
)

_DRIVER_DEFAULTS = (
    "rainfall",
    "u_star",
    "h_eff",
)

_OPTIONAL_NUMERIC_DEFAULTS = (
    "building_density",
    "seismic_count",
    "seismic_risk_score",
)


def _parse_args(
    argv: list[str] | None = None,
) -> argparse.Namespace:
    p = argparse.ArgumentParser(
        prog="init-config",
        description=(
            "Create nat.com/config.py "
            "for the GeoPrior pipeline."
        ),
    )
    p.add_argument(
        "--root",
        default="nat.com",
        help="Config folder relative to project root.",
    )
    p.add_argument(
        "--force",
        action="store_true",
        help="Overwrite an existing config.py.",
    )
    p.add_argument(
        "--yes",
        action="store_true",
        help="Use defaults for missing answers.",
    )
    p.add_argument(
        "--template",
        default=None,
        help=(
            "Optional template file. "
            "Falls back to built-in template."
        ),
    )
    p.add_argument("--city", default=None)
    p.add_argument("--model-name", default=None)
    p.add_argument("--data-dir", default=None)
    p.add_argument(
        "--dataset-variant",
        default=None,
    )
    p.add_argument(
        "--train-end-year",
        type=int,
        default=None,
    )
    p.add_argument(
        "--forecast-start-year",
        type=int,
        default=None,
    )
    p.add_argument(
        "--forecast-horizon-years",
        type=int,
        default=None,
    )
    p.add_argument(
        "--time-steps",
        type=int,
        default=None,
    )
    return p.parse_args(argv)


def _ask(
    label: str,
    default: str,
    *,
    yes: bool,
) -> str:
    if yes:
        return default

    raw = input(f"{label} [{default}]: ").strip()
    return raw or default


def _as_bool_literal(value: str) -> str:
    return "True" if str(value).lower() == "true" else "False"


def _quote_list(items: tuple[str, ...]) -> str:
    return ",\n".join(f'    "{item}"' for item in items)


def _normalize_cli(
    args: argparse.Namespace,
    values: dict[str, str],
) -> None:
    cli_map = {
        "city": "CITY_NAME",
        "model_name": "MODEL_NAME",
        "data_dir": "DATA_DIR",
        "dataset_variant": "DATASET_VARIANT",
        "train_end_year": "TRAIN_END_YEAR",
        "forecast_start_year": "FORECAST_START_YEAR",
        "forecast_horizon_years": ("FORECAST_HORIZON_YEARS"),
        "time_steps": "TIME_STEPS",
    }

    for src, dst in cli_map.items():
        value = getattr(args, src, None)
        if value is not None:
            values[dst] = str(value)


def _resolve_template(
    args: argparse.Namespace,
    project_root: Path,
) -> str:
    candidates: list[Path] = []

    if args.template:
        candidates.append(Path(args.template))

    candidates.extend(
        [
            project_root
            / "scripts"
            / "templates"
            / "natcom_config_template.py",
            project_root
            / "geoprior"
            / "resources"
            / "natcom_config_template.py",
            project_root / "nat.com" / "config.template.py",
        ]
    )

    for path in candidates:
        if path.is_file():
            text = path.read_text(encoding="utf-8")
            if "$CITY_NAME" in text:
                return text

    return _DEFAULT_TEMPLATE


def _build_values(
    args: argparse.Namespace,
) -> dict[str, str]:
    """Assemble the substitution mapping for the config template.

    Values come from, in decreasing priority: command-line
    overrides, interactive answers, and the defaults in
    ``_PROMPTS`` (used for empty answers or when ``args.yes`` is
    set). A key supplied on the command line is not prompted for,
    because its answer would be discarded anyway.

    Post-processing:

    * ``CITY_NAME`` is stripped and lower-cased;
    * boolean keys are normalised with ``_as_bool_literal``;
    * ``BIG_STEM`` / ``SMALL_STEM`` are fixed to ``"big"`` and
      ``"small"``;
    * a whitespace-only ``ALL_CITIES_PARQUET`` becomes ``""``;
    * the two feature lists are rendered from module defaults.

    Parameters
    ----------
    args
        Parsed command line; ``args.yes`` and the override
        attributes are read.

    Returns
    -------
    dict[str, str]
        Template key to ready-to-substitute string.
    """
    # Resolve CLI overrides first so the prompt loop can skip them.
    overrides: dict[str, str] = {}
    _normalize_cli(args, overrides)

    values: dict[str, str] = {}
    for key, label, default in _PROMPTS:
        if key in overrides:
            values[key] = overrides[key]
        else:
            values[key] = _ask(
                label,
                default,
                yes=args.yes,
            )

    # Covers any override key that has no row in _PROMPTS.
    values.update(overrides)

    # NOTE(review): string answers are inserted verbatim. A value
    # with a double quote or backslash (e.g. a Windows path) will
    # presumably produce a broken or altered literal in the
    # generated file. Confirm, and consider escaping.
    values["CITY_NAME"] = values["CITY_NAME"].strip().lower()

    bool_keys = (
        "USE_HEAD_PROXY",
        "INCLUDE_Z_SURF_AS_STATIC",
        "NORMALIZE_COORDS",
        "KEEP_COORDS_RAW",
        "SCALE_GWL",
        "SCALE_H_FIELD",
        "SCALE_Z_SURF",
        "USE_BATCH_NORM",
        "USE_VSN",
    )
    for key in bool_keys:
        values[key] = _as_bool_literal(values[key])

    values["BIG_STEM"] = "big"
    values["SMALL_STEM"] = "small"

    if not values["ALL_CITIES_PARQUET"].strip():
        values["ALL_CITIES_PARQUET"] = ""

    values["FUTURE_DRIVER_FEATURES"] = _quote_list(
        _DRIVER_DEFAULTS
    )
    values["OPTIONAL_NUMERIC_FEATURES"] = _quote_list(
        _OPTIONAL_NUMERIC_DEFAULTS
    )

    return values


def _write_config(
    text: str,
    *,
    root: str,
    force: bool,
) -> Path:
    """Write the rendered config text to the config path.

    The directory reported by ``get_natcom_dir`` is created first
    if needed. The target file is the first element returned by
    ``get_config_paths``.

    Parameters
    ----------
    text
        Fully rendered config source.
    root
        Config folder name, forwarded to the path helpers.
    force
        Allow overwriting an existing file.

    Returns
    -------
    Path
        Path of the file that was written.

    Raises
    ------
    FileExistsError
        If the target exists and *force* is false.
    """
    target_dir = Path(get_natcom_dir(root=root))
    target_dir.mkdir(parents=True, exist_ok=True)

    config_location, _ = get_config_paths(root=root)
    target = Path(config_location)

    already_there = target.exists()
    if already_there and not force:
        message = (
            f"{target} already exists. "
            "Use --force to overwrite it."
        )
        raise FileExistsError(message)

    target.write_text(text, encoding="utf-8")
    return target


def _reuse_existing_config(reason: object, *, root: str) -> None:
    """Report that config.py was kept and run the JSON step on it."""
    print(f"[OK] {reason}")
    _, json_path = ensure_config_json(root=root)
    print(
        f"[OK] Reused existing config and refreshed: "
        f"{json_path}"
    )


def main(argv: list[str] | None = None) -> None:
    """Entry point: create the config file and its JSON companion.

    Steps:

    1. Parse *argv*.
    2. Without ``--force``, if the config file already exists,
       report it, call ``ensure_config_json`` and return. This
       check runs before any prompting so answers are never
       collected only to be thrown away.
    3. Collect values, pick a template and render it with
       ``string.Template.substitute``.
    4. Write the file, call ``ensure_config_json`` and print the
       follow-up commands.

    Parameters
    ----------
    argv
        Argument list; ``None`` means ``sys.argv[1:]``.

    Raises
    ------
    KeyError, ValueError
        Propagated from ``Template.substitute`` when a template
        uses an unknown or malformed placeholder.
    """
    args = _parse_args(argv)

    if not args.force:
        cfg_path_str, _ = get_config_paths(root=args.root)
        existing = Path(cfg_path_str)
        if existing.exists():
            _reuse_existing_config(
                f"{existing} already exists. "
                "Use --force to overwrite it.",
                root=args.root,
            )
            return

    nat_dir = Path(get_natcom_dir(root=args.root))
    project_root = nat_dir.parent

    values = _build_values(args)
    template_text = _resolve_template(
        args,
        project_root,
    )
    rendered = Template(template_text).substitute(values)

    try:
        cfg_path = _write_config(
            rendered,
            root=args.root,
            force=args.force,
        )
    except FileExistsError as exc:
        # Still reachable if the file appears between the early
        # check above and the write.
        _reuse_existing_config(exc, root=args.root)
        return

    _, json_path = ensure_config_json(root=args.root)

    print(f"[OK] Created: {cfg_path}")
    print(f"[OK] Created: {json_path}")
    print("")
    print("Next commands:")
    print(" geoprior-run preprocess")
    print(" geoprior-run train")
    print(" geoprior-run tune")
    print(" geoprior-run infer --help")
    print(" geoprior-run transfer --help")
    print("")
    print(" # or")
    print(" python -m geoprior.cli preprocess")
# Run the initializer when this module is executed as a script.
if __name__ == "__main__":
    main()