# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""GeoPrior scaling config helpers (Keras-serializable)."""
from __future__ import annotations
import json
import os
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from ...logging import get_logger
from .. import KERAS_DEPS, dependency_message
from .utils import (
canonicalize_scaling_kwargs,
enforce_scaling_alias_consistency,
load_scaling_kwargs,
validate_scaling_kwargs,
)
# Keras dependency shim exposed by the parent package (``KERAS_DEPS``).
K = KERAS_DEPS
# Decorator used to register classes in this module for Keras
# (de)serialization under a package/name key.
register_keras_serializable = K.register_keras_serializable
# Dependency message for this module. It is not referenced in the code
# visible here; presumably consumed elsewhere (NOTE(review): confirm).
DEP_MSG = dependency_message("models.subsidence.scaling")
# Module-level logger used by the config helpers below.
logger = get_logger(__name__)
def _jsonify(x):
    r"""
    Convert nested objects into JSON-serializable Python types.

    This helper walks common container types and converts values
    into plain Python objects suitable for storage in a Keras
    configuration dictionary. It is intended for defensive
    serialization, where values may include NumPy scalars or arrays,
    tuples, sets, or mapping-like objects.

    Parameters
    ----------
    x : object
        Input object to convert. This may be a mapping, list, tuple,
        set/frozenset, NumPy array, NumPy scalar, or any other object.

    Returns
    -------
    out : object
        A JSON-serializable representation of ``x`` when possible.
        Containers are converted recursively. Objects that do not
        require conversion are returned unchanged.

    Notes
    -----
    - Mapping keys are cast to ``str`` to avoid non-JSON keys.
    - Sets and frozensets are converted to sorted lists for stable,
      deterministic configs. If the elements cannot be ordered with
      each other (e.g. ``{1, "a"}``), they are ordered by
      ``(type name, repr)`` instead of raising.
    - NumPy arrays are converted with ``.tolist()`` (recursively).
    - NumPy scalar types are converted using ``.item()``.

    Examples
    --------
    >>> _jsonify({"a": 1})
    {'a': 1}
    >>> import numpy as np
    >>> _jsonify({"v": np.float32(2.0)})
    {'v': 2.0}
    >>> _jsonify(frozenset({2, 1}))
    [1, 2]

    See Also
    --------
    GeoPriorScalingConfig.get_config :
        Uses this function to serialize configuration safely.
    """
    # Dict-like: ensure keys are strings (JSON object keys must be).
    if isinstance(x, Mapping):
        return {str(k): _jsonify(v) for k, v in x.items()}

    # List/tuple: keep ordering. A tuple of types (instead of the
    # PEP 604 ``list | tuple`` union) also works before Python 3.10.
    if isinstance(x, (list, tuple)):
        return [_jsonify(v) for v in x]

    # Set/frozenset: stable ordering for deterministic configs.
    # ``frozenset`` is not a ``set`` subclass, so it is listed explicitly.
    if isinstance(x, (set, frozenset)):
        items = [_jsonify(v) for v in x]
        try:
            return sorted(items)
        except TypeError:
            # Mutually unorderable elements: use a deterministic key
            # rather than failing during serialization.
            return sorted(
                items,
                key=lambda v: (type(v).__name__, repr(v)),
            )

    # NumPy array: convert to (nested) lists of Python scalars.
    if isinstance(x, np.ndarray):
        return _jsonify(x.tolist())

    # NumPy scalar: convert to the matching Python scalar.
    if isinstance(x, np.generic):
        return x.item()

    # Fall back: return as-is.
    return x
@register_keras_serializable(
    "geoprior.nn.pinn.geoprior",
    name="GeoPriorScalingConfig",
)
@dataclass
class GeoPriorScalingConfig:
    r"""
    Serializable container for GeoPrior scaling configuration.

    This dataclass stores a raw ``payload`` dictionary holding the
    physics scaling and slicing controls used by GeoPriorSubsNet. These
    controls govern how coordinates, time units, groundwater variables,
    and physics residuals are interpreted and non-dimensionalized.

    The class is registered as a Keras-serializable object, so the
    configuration round-trips through :meth:`get_config` and
    :meth:`from_config`. If it were not faithfully serialized, a
    reloaded model could be reconstructed with a different effective
    physics behavior.

    Instances are usually built with :meth:`from_any`. It accepts
    ``None``, a ``dict``-like mapping, a file path (``str`` or
    :class:`os.PathLike`), or an existing
    :class:`~GeoPriorScalingConfig`. The canonical, validated scaling
    dictionary is produced by :meth:`resolve`, which runs the same
    loading, canonicalization, alias-consistency, and validation
    pipeline used during training.

    Parameters
    ----------
    payload : dict, optional
        Raw scaling configuration payload. This may be incomplete or
        contain aliases prior to canonicalization. Defaults to an empty
        dict.
    source : str or None, optional
        Optional provenance string, typically the file path used to
        load the payload. Stored for traceability only.
    schema_version : str, optional
        Version label for the payload schema, reserved for migrations
        when the scaling format evolves. Defaults to ``"1"``.

    Notes
    -----
    - The dictionary returned by :meth:`resolve` is the one to pass to
      the model internals.
    - :meth:`get_config` passes the payload through :func:`_jsonify`,
      so NumPy scalars/arrays, tuples, and sets are stored as plain
      Python types.
    - The container combines Keras serialization patterns with the
      standard-library dataclass model
      :cite:p:`KerasSerializationDocs,PythonDataclassesDocs`.

    Examples
    --------
    >>> cfg = GeoPriorScalingConfig.from_any(
    ...     {"coords_normalized": True}
    ... )
    >>> sk = cfg.resolve()

    See Also
    --------
    load_scaling_kwargs :
        Load scaling configuration from mapping or file.
    canonicalize_scaling_kwargs :
        Normalize keys and fill defaults consistently.
    enforce_scaling_alias_consistency :
        Ensure alias keys agree and do not conflict.
    validate_scaling_kwargs :
        Validate schema and value ranges.
    """

    # Raw payload (may be incomplete or aliased).
    payload: dict = field(default_factory=dict)
    # Optional provenance (e.g., file path).
    source: str | None = None
    # Schema version tag (for future migrations).
    schema_version: str = "1"

    @classmethod
    def from_any(cls, obj, *, copy=True):
        r"""
        Create a scaling config from common input types.

        This factory method normalizes user input into a
        :class:`~GeoPriorScalingConfig` instance.

        Accepted inputs
        ---------------
        - ``None``: create an empty config.
        - :class:`~GeoPriorScalingConfig`: returned as-is (``copy`` is
          not applied).
        - ``str`` or :class:`os.PathLike`: treated as a file path and
          loaded via :func:`load_scaling_kwargs`. Path objects are
          converted to ``str`` first.
        - ``Mapping``: used as the payload, copied into a new ``dict``
          when ``copy=True``.

        Parameters
        ----------
        obj : object
            Scaling configuration input to normalize.
        copy : bool, optional
            If ``True`` (default), copy mapping payloads into a new
            ``dict``. This helps avoid accidental mutation of user
            state. Also forwarded to :func:`load_scaling_kwargs` for
            path inputs.

        Returns
        -------
        cfg : GeoPriorScalingConfig
            A normalized config container.

        Raises
        ------
        TypeError
            If ``obj`` is not ``None``, ``str``, ``os.PathLike``,
            ``Mapping``, or a :class:`~GeoPriorScalingConfig` instance.

        Notes
        -----
        - When ``obj`` is a file path, the path (as ``str``) is stored
          in the ``source`` attribute for traceability.
        - Canonicalization and validation happen in :meth:`resolve`,
          not in this constructor.

        Examples
        --------
        >>> GeoPriorScalingConfig.from_any(None)
        GeoPriorScalingConfig(payload={}, source=None, schema_version='1')
        >>> GeoPriorScalingConfig.from_any({"a": 1}).payload["a"]
        1
        """
        # ``None`` -> empty payload.
        if obj is None:
            logger.debug(
                "GeoPriorScalingConfig.from_any: obj=None",
            )
            return cls(payload={})

        # Already a config object: reuse it unchanged.
        if isinstance(obj, cls):
            logger.debug(
                "GeoPriorScalingConfig.from_any: "
                "received GeoPriorScalingConfig",
            )
            return obj

        # Normalize path objects (e.g. ``pathlib.Path``) to ``str`` so
        # they share the loading branch below and ``source`` stays str.
        if isinstance(obj, os.PathLike):
            obj = os.fsdecode(obj)

        # Path: load via the shared loader.
        if isinstance(obj, str):
            logger.info(
                "GeoPriorScalingConfig.from_any: "
                "loading scaling kwargs from path=%r",
                obj,
            )
            payload = load_scaling_kwargs(
                obj,
                copy=copy,
            )
            logger.debug(
                "GeoPriorScalingConfig.from_any: "
                "loaded keys=%d source=%r",
                len(payload),
                obj,
            )
            return cls(
                payload=payload,
                source=obj,
            )

        # Mapping-like: accept dict-like payload.
        if isinstance(obj, Mapping):
            logger.debug(
                "GeoPriorScalingConfig.from_any: "
                "received Mapping keys=%d copy=%s",
                len(obj),
                bool(copy),
            )
            payload = dict(obj) if copy else obj
            return cls(payload=payload)

        # Unsupported type.
        msg = f"Unsupported scaling_kwargs type: {type(obj)!r}"
        logger.error(
            "GeoPriorScalingConfig.from_any: %s",
            msg,
        )
        raise TypeError(msg)

    def resolve(self):
        r"""
        Resolve the payload into a canonical, validated scaling dict.

        The pipeline is:

        1) load the payload with :func:`load_scaling_kwargs`
           (``copy=True``, so the stored payload is not mutated),
        2) canonicalize keys and fill defaults
           (:func:`canonicalize_scaling_kwargs`),
        3) enforce alias consistency
           (:func:`enforce_scaling_alias_consistency`),
        4) validate values (:func:`validate_scaling_kwargs`).

        Returns
        -------
        scaling_kwargs : dict
            Canonical and validated scaling configuration, as returned
            by :func:`canonicalize_scaling_kwargs` after the checks.

        Raises
        ------
        Exception
            Any error raised by the pipeline helpers is propagated
            unchanged (for example, validation errors from
            :func:`validate_scaling_kwargs`).

        Examples
        --------
        >>> cfg = GeoPriorScalingConfig.from_any(
        ...     {"coords_normalized": True}
        ... )
        >>> sk = cfg.resolve()

        See Also
        --------
        canonicalize_scaling_kwargs :
            Normalizes scaling keys and defaults.
        validate_scaling_kwargs :
            Enforces schema and constraints.
        enforce_scaling_alias_consistency :
            Prevents conflicting aliases.
        """
        logger.debug(
            "GeoPriorScalingConfig.resolve: start "
            "(source=%r, schema_version=%r)",
            self.source,
            self.schema_version,
        )
        # Load payload defensively (copy).
        sk = load_scaling_kwargs(
            self.payload,
            copy=True,
        )
        logger.debug(
            "GeoPriorScalingConfig.resolve: loaded "
            "payload keys=%d",
            len(sk),
        )
        # Normalize keys and fill defaults.
        sk = canonicalize_scaling_kwargs(sk)
        logger.debug(
            "GeoPriorScalingConfig.resolve: "
            "canonicalized keys=%d",
            len(sk),
        )
        # Enforce alias agreement (no conflicts).
        enforce_scaling_alias_consistency(sk)
        logger.debug(
            "GeoPriorScalingConfig.resolve: "
            "alias consistency OK",
        )
        # Validate schema and value ranges.
        validate_scaling_kwargs(sk)
        logger.info(
            "GeoPriorScalingConfig.resolve: OK "
            "(keys=%d, source=%r)",
            len(sk),
            self.source,
        )
        return sk

    def get_config(self):
        r"""
        Return a JSON-safe Keras configuration dictionary.

        Keras uses this method to serialize the object. The payload is
        passed through :func:`_jsonify` so nested NumPy scalars/arrays,
        tuples, and sets become plain Python types.

        Returns
        -------
        config : dict
            Dictionary with the keys:

            - ``payload``: JSON-safe payload mapping,
            - ``source``: provenance hint (may be ``None``),
            - ``schema_version``: schema version label.

        Notes
        -----
        ``source`` is stored for traceability; :meth:`resolve` does not
        read it except for logging.

        See Also
        --------
        GeoPriorScalingConfig.from_config :
            Recreate a config instance from this dictionary.
        """
        cfg = {
            "payload": _jsonify(self.payload),
            "source": self.source,
            "schema_version": self.schema_version,
        }
        logger.debug(
            "GeoPriorScalingConfig.get_config: "
            "payload keys=%d source=%r",
            len(cfg.get("payload", {})),
            self.source,
        )
        return cfg

    @classmethod
    def from_config(cls, config):
        r"""
        Recreate an instance from a Keras configuration dictionary.

        Used by Keras deserialization to rebuild the object from the
        dictionary returned by :meth:`get_config`.

        Parameters
        ----------
        config : dict or None
            Configuration dictionary produced by :meth:`get_config`.
            Keys that are not dataclass fields (e.g. written by a newer
            schema) are ignored with a warning instead of failing. A
            missing or ``None`` payload becomes an empty dict.

        Returns
        -------
        cfg : GeoPriorScalingConfig
            Reconstructed config instance.

        Notes
        -----
        This method does not call :meth:`resolve`. Resolution is
        deferred to the consumer so that reconstruction remains
        explicit and testable.

        See Also
        --------
        GeoPriorScalingConfig.get_config :
            Produces the configuration dictionary.
        """
        from dataclasses import fields

        # Work on a copy so the caller's dict is never modified.
        config = dict(config or {})
        logger.debug(
            "GeoPriorScalingConfig.from_config: keys=%s",
            sorted(config),
        )

        known = {f.name for f in fields(cls)}
        unknown = sorted(set(config) - known)
        if unknown:
            # Previously ``cls(**config)`` raised TypeError here; tolerate
            # forward-compatible extra keys but make them visible.
            logger.warning(
                "GeoPriorScalingConfig.from_config: "
                "ignoring unknown keys=%s",
                unknown,
            )

        kwargs = {k: v for k, v in config.items() if k in known}
        if kwargs.get("payload") is None:
            kwargs["payload"] = {}
        return cls(**kwargs)
def _deep_update(base: dict, over: dict) -> dict:
    """
    Deep-merge ``over`` into ``base`` and return ``base``.

    The override wins. The top-level ``base`` dict is updated in place.
    Where ``over`` holds a ``dict`` value, the value stored in ``base``
    is a freshly built dict. It is a merged copy when ``base`` already
    held a dict under that key, and a copy of the override dict
    otherwise. As a result, nested dicts reachable from either argument
    are never mutated or aliased. Non-dict values are assigned by
    reference.
    """
    if not over:
        return base
    for key, val in over.items():
        if isinstance(val, dict):
            prev = base.get(key)
            # Copy-on-write: never mutate a nested dict that may be
            # shared with the caller's input.
            seed = dict(prev) if isinstance(prev, dict) else {}
            base[key] = _deep_update(seed, val)
        else:
            base[key] = val
    return base


def _resolve_json_path(
    path: str,
    base_dir: str | None,
) -> str:
    """
    Resolve a JSON path to an absolute path.

    Expands environment variables and ``~``. Absolute results are
    returned unchanged. Relative paths are joined with ``base_dir``
    (or the current working directory when ``base_dir`` is falsy) and
    normalized with ``os.path.abspath``.
    """
    p = os.path.expandvars(os.path.expanduser(str(path)))
    if os.path.isabs(p):
        return p
    b = base_dir or os.getcwd()
    return os.path.abspath(os.path.join(b, p))


def override_scaling_kwargs(
    sk: Mapping[str, Any],
    cfg: Mapping[str, Any] | None,
    *,
    finalize: Callable[[dict], dict] | None = None,
    dyn_names: Sequence[str] | None = None,
    gwl_dyn_index: int | None = None,
    base_dir: str | None = None,
    path_key: str = "SCALING_KWARGS_JSON_PATH",
    strict: bool = True,
    add_path: bool = True,
    log_fn: Callable[[str], Any] | None = None,
) -> dict:
    r"""
    Override ``scaling_kwargs`` from a JSON file or dict.

    This helper applies an optional, precedence-based override to an
    existing ``scaling_kwargs`` mapping. The override source is read
    from ``cfg[path_key]``. If the key is missing or its value is
    ``None``, ``""`` or ``False``, a copy of ``sk`` is returned
    (finalized when ``finalize`` is given).

    The override can be provided as:

    - a file path to a JSON object (mapping), or
    - a Python dict-like mapping embedded in ``cfg``.

    Overrides are applied via a deep merge:

    - for nested dict values, keys are merged recursively,
    - for non-dict values, the override replaces the base value.

    Optionally, the result is passed through ``finalize`` to recompute
    derived or canonical fields.

    Parameters
    ----------
    sk : Mapping[str, Any]
        Base scaling configuration (``scaling_kwargs``). Its nested
        dicts are copied before any modification, so neither the merge
        nor ``finalize`` touches the caller's dicts. Non-dict values
        (lists, arrays, ...) are shared by reference.
    cfg : Mapping[str, Any] or None
        Configuration mapping that may contain the override source
        under ``path_key``. If ``None``, no override is applied.
    finalize : callable or None, optional
        Function applied to enforce canonical structure or compute
        derived fields. It is applied to the base dict before the
        merge and to the merged dict after it. It must accept a dict
        and return a dict.
    dyn_names : Sequence[str] or None, optional
        Expected dynamic feature names. If truthy and the override
        contains a non-``None`` ``dynamic_feature_names``, the two
        sequences are compared as lists.
    gwl_dyn_index : int or None, optional
        Expected dynamic index of the groundwater-level feature. If
        given and the override contains a non-``None``
        ``gwl_dyn_index``, the values are compared as ints.
    base_dir : str or None, optional
        Base directory used to resolve relative JSON paths. If
        ``None``, the current working directory is used.
    path_key : str, default="SCALING_KWARGS_JSON_PATH"
        Key in ``cfg`` that holds the override (mapping or path).
    strict : bool, default=True
        When ``True``, safety-check mismatches raise ``ValueError``.
        When ``False``, mismatches are reported through ``log_fn``
        (if given) and the override still proceeds.
    add_path : bool, default=True
        If ``True``, store the override source in the output under
        ``scaling_kwargs_override_path`` (``"<dict>"`` for mappings).
    log_fn : callable or None, optional
        Optional message sink (e.g. ``print`` or ``logger.info``).
        It receives mismatch messages (when ``strict=False``) and a
        final ``"[INFO] scaling_kwargs override: ..."`` line.

    Returns
    -------
    out : dict
        Final scaling dict after optional override and finalization.
        It is a new top-level dict, and its nested dicts are not shared
        with ``sk`` or with an inline override in ``cfg``. ``finalize``
        may still return objects of its own choosing.

    Raises
    ------
    FileNotFoundError
        If ``cfg[path_key]`` is a path and the file does not exist.
    ValueError
        If the file does not contain valid UTF-8 JSON, or if a safety
        check fails while ``strict=True``.
    OSError
        If the file exists but cannot be read (propagated as-is).
    TypeError
        If the loaded override is not a JSON object (dict).

    Notes
    -----
    Path resolution
        A string path is resolved by expanding environment variables
        and ``~``. If the result is relative, it is joined with
        ``base_dir`` (or the CWD).

    Safety checks
        The checks prevent using an override produced for a different
        dataset or feature layout (``dynamic_feature_names`` and
        ``gwl_dyn_index`` equality when known). Keys whose override
        value is ``None`` are not compared.

    Finalization
        Applying ``finalize`` both before and after the merge helps
        when the override only supplies partial information.

    Examples
    --------
    Stage-2: override computed scaling with a file

    >>> sk = subsmodel_params["scaling_kwargs"]
    >>> sk = override_scaling_kwargs(
    ...     sk,
    ...     cfg,
    ...     finalize=finalize_scaling_kwargs,
    ...     dyn_names=DYN_NAMES,
    ...     gwl_dyn_index=GWL_DYN_INDEX,
    ...     base_dir=os.path.dirname(__file__),
    ...     strict=True,
    ...     log_fn=print,
    ... )
    >>> subsmodel_params["scaling_kwargs"] = sk

    Stage-3: override Stage-1 scaling prior to enforcing bounds

    >>> sk_model = dict(cfg.get("scaling_kwargs", {}) or {})
    >>> sk_model = override_scaling_kwargs(
    ...     sk_model,
    ...     cfg,
    ...     dyn_names=sk_model.get("dynamic_feature_names"),
    ...     gwl_dyn_index=sk_model.get("gwl_dyn_index"),
    ...     base_dir=os.path.dirname(__file__),
    ... )

    Inline dict override (no JSON file)

    >>> cfg = {
    ...     "SCALING_KWARGS_JSON_PATH": {
    ...         "coords_normalized": True,
    ...         "coord_ranges": {"t": 7.0, "x": 1000.0, "y": 900.0},
    ...     }
    ... }
    >>> out = override_scaling_kwargs({}, cfg)

    See Also
    --------
    finalize_scaling_kwargs :
        Canonicalize and complete ``scaling_kwargs`` entries.
    compute_scaling_kwargs :
        Build a base scaling dict from data and pipeline settings.
    """

    def _mismatch(msg: str) -> None:
        # Raise in strict mode; otherwise report and continue.
        if strict:
            raise ValueError(msg)
        if log_fn:
            log_fn(msg)

    # Copy the dict tree of ``sk`` (via merge into an empty dict) so the
    # caller's nested dicts are never mutated below.
    base = _deep_update({}, dict(sk)) if sk is not None else {}
    if finalize is not None:
        base = finalize(base)

    cfg = cfg or {}
    raw = cfg.get(path_key, None)
    if raw in (None, "", False):
        return base

    if isinstance(raw, Mapping):
        over = dict(raw)
        over_path = "<dict>"
    else:
        over_path = _resolve_json_path(str(raw), base_dir)
        if not os.path.isfile(over_path):
            raise FileNotFoundError(
                f"{path_key} not found: {over_path}"
            )
        try:
            with open(over_path, encoding="utf-8") as f:
                over = json.load(f)
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both
            # ValueError subclasses; genuine I/O errors propagate.
            raise ValueError(
                f"Invalid JSON: {over_path}"
            ) from e

    if not isinstance(over, dict):
        raise TypeError("Override must be a JSON object.")

    ov_names = over.get("dynamic_feature_names")
    if dyn_names and ov_names is not None:
        if list(ov_names) != list(dyn_names):
            _mismatch("Override mismatch: dynamic_feature_names.")

    ov_index = over.get("gwl_dyn_index")
    if gwl_dyn_index is not None and ov_index is not None:
        if int(ov_index) != int(gwl_dyn_index):
            _mismatch("Override mismatch: gwl_dyn_index.")

    # ``_deep_update`` copies nested override dicts, so ``out`` does not
    # alias dicts owned by ``cfg``.
    out = _deep_update(base, over)
    if finalize is not None:
        out = finalize(out)

    if add_path:
        out["scaling_kwargs_override_path"] = over_path
    if log_fn:
        log_fn(f"[INFO] scaling_kwargs override: {over_path}")
    return out