Source code for geoprior.scripts.extend_forecast

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for extending forecast outputs in time."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Any

from . import config as cfg
from . import extend_utils as ex
from . import utils


def _city_ns() -> str:
    return cfg.CITY_CANON.get("ns", "Nansha")


def _city_zh() -> str:
    return cfg.CITY_CANON.get("zh", "Zhongshan")


def _slug_city(name: str) -> str:
    return str(name).strip().lower().replace(" ", "_")


def _resolve_cities(args) -> list[str]:
    picked: list[str] = []

    if getattr(args, "use_ns", False):
        picked.append(_city_ns())
    if getattr(args, "use_zh", False):
        picked.append(_city_zh())

    if picked:
        return picked

    raw = str(getattr(args, "cities", "") or "")
    parts = [p.strip().lower() for p in raw.split(",")]

    out: list[str] = []
    for p in parts:
        if not p:
            continue
        out.append(cfg.CITY_CANON.get(p, p.title()))
    return out


def _pick_paths(
    art: utils.Artifacts,
    split: str,
) -> tuple[Path | None, Path | None, str]:
    if split == "val":
        return (
            art.forecast_val_csv,
            art.forecast_future_csv,
            "val",
        )
    if split == "test":
        return (
            art.forecast_test_csv,
            art.forecast_test_future_csv,
            "test",
        )

    if (
        art.forecast_test_csv is not None
        and art.forecast_test_future_csv is not None
    ):
        return (
            art.forecast_test_csv,
            art.forecast_test_future_csv,
            "test",
        )

    return (
        art.forecast_val_csv,
        art.forecast_future_csv,
        "val",
    )


def _resolve_one_city(
    *,
    city: str,
    src: str | None,
    eval_csv: str | None,
    future_csv: str | None,
    split: str,
) -> dict[str, Any]:
    out: dict[str, Any] = {"city": city}

    if eval_csv and future_csv:
        out["eval_csv"] = utils.as_path(eval_csv)
        out["future_csv"] = utils.as_path(future_csv)
        out["split"] = split
        out["src_note"] = "manual"
        return out

    if not src:
        raise ValueError(
            f"{city}: provide --*-src or both "
            "--*-eval and --*-future."
        )

    art = utils.detect_artifacts(src)
    ev, fu, lab = _pick_paths(art, split)

    if ev is None:
        raise FileNotFoundError(
            f"{city}: no eval CSV under {src}"
        )
    if fu is None:
        raise FileNotFoundError(
            f"{city}: no future CSV under {src}"
        )

    out["eval_csv"] = ev
    out["future_csv"] = fu
    out["split"] = lab
    out["src_note"] = f"{ev.name}+{fu.name} ({lab})"
    return out


def _out_path(
    out_arg: str,
    *,
    city: str,
    multi: bool,
) -> Path:
    p = Path(out_arg).expanduser()

    if p.suffix.lower() != ".csv":
        stem = str(p)
        if multi:
            stem = f"{stem}_{_slug_city(city)}"
        p = Path(stem + ".csv")
    else:
        if multi:
            stem = p.with_suffix("")
            p = Path(f"{stem}_{_slug_city(city)}.csv")

    return utils.resolve_out_out(str(p))


[docs] def build_parser( *, prog: str | None = None ) -> argparse.ArgumentParser: ap = argparse.ArgumentParser( prog=prog or "extend-forecast", description=( "Extend future forecast CSV by " "1-2+ years (extrapolation)." ), ) utils.add_city_flags(ap, default_both=True) ap.add_argument("--ns-src", type=str, default=None) ap.add_argument("--zh-src", type=str, default=None) ap.add_argument("--ns-eval", type=str, default=None) ap.add_argument("--zh-eval", type=str, default=None) ap.add_argument("--ns-future", type=str, default=None) ap.add_argument("--zh-future", type=str, default=None) ap.add_argument( "--split", type=str, choices=["auto", "val", "test"], default="auto", help="When using --*-src, pick val/test.", ) ap.add_argument( "--subsidence-kind", type=str, default="cumulative", choices=["cumulative", "rate", "increment"], help="Input meaning (default cumulative).", ) ap.add_argument( "--out-kind", type=str, default="same", choices=["same", "cumulative", "rate"], help="Output kind (default same).", ) ap.add_argument( "--method", type=str, default="linear_fit", choices=["linear_fit", "linear_last"], help="Trend method for extrapolation.", ) ap.add_argument( "--window", type=int, default=3, help="Fit window (years) for trend.", ) ap.add_argument( "--years", nargs="*", type=int, default=None, help="Explicit years to add (e.g. 2026 2027).", ) ap.add_argument( "--add-years", type=int, default=1, help="If --years omitted, add N years.", ) ap.add_argument( "--unc-growth", type=str, default="sqrt", choices=["hold", "sqrt", "linear"], help="Uncertainty widening across added years.", ) ap.add_argument( "--unc-scale", type=float, default=1.0, help="Multiplier for uncertainty widening.", ) ap.add_argument( "--out", "-o", type=str, default="future_extended", help="Output stem/path (scripts/out if rel).", ) return ap
[docs] def extend_forecast_main( argv: list[str] | None = None, *, prog: str | None = None, ) -> None: args = build_parser(prog=prog).parse_args(argv) utils.ensure_script_dirs() cities = _resolve_cities(args) if not cities: cities = [_city_ns(), _city_zh()] want_ns = _city_ns() in cities want_zh = _city_zh() in cities if not want_ns and not want_zh: want_ns = True want_zh = True jobs: list[dict[str, Any]] = [] if want_ns: jobs.append( _resolve_one_city( city=_city_ns(), src=args.ns_src, eval_csv=args.ns_eval, future_csv=args.ns_future, split=args.split, ) ) if want_zh: jobs.append( _resolve_one_city( city=_city_zh(), src=args.zh_src, eval_csv=args.zh_eval, future_csv=args.zh_future, split=args.split, ) ) multi = len(jobs) > 1 cc = ex.ExtendCfg( kind=str(args.subsidence_kind), out_kind=str(args.out_kind), method=str(args.method), window=int(args.window), unc_growth=str(args.unc_growth), unc_scale=float(args.unc_scale), ) for j in jobs: city = str(j["city"]) ev = ex.load_eval_csv(j["eval_csv"], cc=cc) fu = ex.load_future_csv(j["future_csv"], cc=cc) out_df = ex.extend_future_df( fu, eval_df=ev, add_years=int(args.add_years), years=list(args.years) if args.years else None, cc=cc, ) out_p = _out_path( str(args.out), city=city, multi=multi ) out_p.parent.mkdir(parents=True, exist_ok=True) out_df.to_csv(out_p, index=False) print( f"[OK] {city}: wrote {out_p} " f"({j.get('src_note', '')})" )
[docs] def main(argv: list[str] | None = None) -> None: extend_forecast_main(argv)
if __name__ == "__main__": main()