Source code for geoprior.models.subsidence.batch_io

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 — https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>

"""
Batch.io
"""

from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np

from ...utils.generic_utils import rename_dict_keys
from .. import KERAS_DEPS

Tensor = KERAS_DEPS.Tensor

tf_float32 = KERAS_DEPS.float32
tf_int32 = KERAS_DEPS.int32

tf_cast = KERAS_DEPS.cast
tf_constant = KERAS_DEPS.constant
tf_debugging = KERAS_DEPS.debugging
tf_equal = KERAS_DEPS.equal
tf_maximum = KERAS_DEPS.maximum
tf_minimum = KERAS_DEPS.minimum
tf_greater_equal = KERAS_DEPS.greater_equal
tf_rank = KERAS_DEPS.rank
tf_cond = KERAS_DEPS.cond
tf_shape = KERAS_DEPS.shape
tf_zeros_like = KERAS_DEPS.zeros_like
tf_ones = KERAS_DEPS.ones
tf_greater = KERAS_DEPS.greater
tf_cond = KERAS_DEPS.cond
tf_concat = KERAS_DEPS.concat
tf_convert_to_tensor = KERAS_DEPS.convert_to_tensor
tf_ones_like = KERAS_DEPS.ones_like
tf_less_equal = KERAS_DEPS.less_equal
tf_abs = KERAS_DEPS.abs
tf_print = KERAS_DEPS.print
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile


def _get_coords(
    inputs: dict[str, Any] | Sequence[Any] | Tensor,
    *,
    check_shape: bool = False,
    is_time_dependent: bool = True,
) -> Tensor:
    r"""
    Extract the **coords** tensor from any input layout.

    Parameters
    ----------
    inputs : dict, sequence or tensor
        * **Dict** – must contain a ``"coords"`` key.
        * **Sequence** – coords are expected at index 0.
        * **Tensor** – interpreted as coords directly.
    check_shape : bool, default ``False``
        If ``True``, validate that the last dimension equals 3 and that
        the rank matches *is_time_dependent*.
    is_time_dependent : bool, default ``True``
        * ``True``  → expect a 3-D shape ``(B, T, 3)``
        * ``False`` → allow a 2-D shape ``(B, 3)``; a 3-D shape is also
          accepted (useful when the model ignores time).

    Returns
    -------
    tf.Tensor
        The extracted coords tensor.

    Raises
    ------
    KeyError
        If a dict is missing the ``"coords"`` key.
    TypeError
        If *inputs* is not dict, sequence or tensor.
    ValueError
        If ``check_shape`` is ``True`` and the tensor shape is invalid.

    Notes
    -----
    The function is lightweight and executes outside any
    ``tf.function`` context; use it freely inside the model’s
    ``train_step`` or data-pipe helpers.
    """
    # ── 1. locate the tensor
    if isinstance(inputs, Tensor):
        coords = inputs
    elif isinstance(inputs, tuple | list):
        coords = inputs[0]
    elif isinstance(inputs, dict):
        try:
            coords = inputs["coords"]
        except KeyError as err:
            raise KeyError(
                "Input dict lacks a 'coords' entry."
            ) from err
    else:
        raise TypeError(
            "inputs must be a dict, sequence, or Tensor; "
            f"got {type(inputs).__name__}"
        )

    # ── 2. validate shape if requested
    if check_shape:
        shape = coords.shape  # static if known, else dynamic
        # last dim must be 3
        if shape.rank is not None:
            if shape[-1] != 3:
                raise ValueError(
                    "coords[..., -1] must equal 3 (t, x, y); "
                    f"found {shape[-1]}"
                )
            # rank check when static
            if is_time_dependent and shape.rank != 3:
                raise ValueError(
                    "Expected coords rank 3 (B, T, 3) for time-dependent "
                    "data; received rank "
                    f"{shape.rank}"
                )
            if not is_time_dependent and shape.rank not in (
                2,
                3,
            ):
                raise ValueError(
                    "Expected coords rank 2 or 3 for static data; "
                    f"received rank {shape.rank}"
                )
        else:  # dynamic shapes (inside tf.function)
            tf_debugging.assert_equal(
                tf_shape(coords)[-1],
                3,
                message="coords[..., -1] must equal 3 (t, x, y)",
            )
            if is_time_dependent:
                tf_debugging.assert_rank(coords, 3)
            else:
                tf_debugging.assert_rank_in(coords, (2, 3))

    return coords


def _canonicalize_targets(targets):
    """
    Return targets as a dict matching compiled output names.

    Accepts:
    - dict targets: {"subsidence": ..., "gwl": ...} or already
      {"subs_pred": ..., "gwl_pred": ...}
    - tuple/list targets: (subs, gwl) -> {"subs_pred": subs, "gwl_pred": gwl}

    This avoids Keras 3 structure-mismatch errors when y_pred is a dict.
    """

    if isinstance(targets, Mapping):
        tgt = dict(targets)
        # already canonical?
        if ("subs_pred" in tgt) and ("gwl_pred" in tgt):
            return tgt

        return rename_dict_keys(
            tgt,
            param_to_rename={
                "subsidence": "subs_pred",
                "gwl": "gwl_pred",
            },
        )

    if isinstance(targets, tuple | list):
        if len(targets) == 2:
            return {
                "subs_pred": targets[0],
                "gwl_pred": targets[1],
            }
        if len(targets) == 1:
            return {"subs_pred": targets[0]}

    return targets


# ---------------------------------------------------------------------
# Quantile helpers (shape-safe)
# ---------------------------------------------------------------------
def _q_index(quantiles, q=0.5):
    """Return the index of q in quantiles (python int), else None."""
    if quantiles is None:
        return None
    try:
        qs = [float(x) for x in list(quantiles)]
    except Exception:
        return None
    # exact match first
    for i, qi in enumerate(qs):
        if abs(qi - float(q)) < 1e-12:
            return i
    # fallback: closest
    qv = float(q)
    i0 = int(np.argmin([abs(qi - qv) for qi in qs]))
    return i0


[docs] def select_q(pred, quantiles=None, q=0.5, fallback="mean"): """ Select q-quantile from pred. pred: - (B,H,Q,O) -> returns (B,H,O) - (B,H,Q,1) -> returns (B,H,1) - otherwise returned as-is. """ r = getattr(pred, "shape", None) rank = pred.shape.rank if r is not None else None if rank != 4: return pred idx = _q_index(quantiles, q=q) if idx is not None: return pred[:, :, idx, :] # fallback if quantiles list is missing/odd if str(fallback).lower() == "mean": return tf_reduce_mean(pred, axis=2) # "middle" fallback qn = tf_shape(pred)[2] mid = qn // 2 return pred[:, :, mid, :]
[docs] def tile_true_to_quantiles(y_true, y_pred): """ Make y_true compatible with y_pred when y_pred has quantile axis. y_true: (B,H,O) and y_pred: (B,H,Q,O) -> return y_true_q: (B,H,Q,O) Else return y_true unchanged. """ rt = y_true.shape.rank rp = y_pred.shape.rank if (rp == 4) and (rt == 3): y4 = tf_expand_dims(y_true, axis=2) # (B,H,1,O) qn = tf_shape(y_pred)[2] return tf_tile(y4, [1, 1, qn, 1]) return y_true
def _align_true_for_loss(y_true, y_pred): """ Align y_true to y_pred for quantile-aware losses/metrics. Never tiles y_true across Q. Only reduces y_true if it accidentally arrives with a quantile axis. """ y_true = tf_convert_to_tensor(y_true) y_pred = tf_convert_to_tensor(y_pred) # (B,H) -> (B,H,1) if y_true.shape.rank == 2: y_true = tf_expand_dims(y_true, axis=-1) rt = y_true.shape.rank rp = y_pred.shape.rank # (B,H,Q,O) -> (B,H,O) if rt == 4: return y_true[:, :, 0, :] # (B,H,Q) -> (B,H,1) (rare but happens if upstream tiled targets) if rt == 3 and rp in (3, 4): if rp == 4: q_dim = y_pred.shape[2] if (q_dim is None) or (y_true.shape[-1] == q_dim): return tf_expand_dims( y_true[:, :, 0], axis=-1 ) else: return tf_expand_dims(y_true[:, :, 0], axis=-1) return y_true