"""
Provides common helper functions and for validation,
comparison, and other generic operations
"""
from __future__ import annotations
import inspect
import logging
import os
import re
import shutil
import textwrap
import warnings
from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime
from itertools import chain
from numbers import Real
from pathlib import Path
from typing import Any, Literal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
PathLike = str | os.PathLike
_SENTINEL = object()
__all__ = [
"ExistenceChecker",
"ensure_directory_exists",
"verify_identical_items",
"vlog",
"detect_dt_format",
"get_actual_column_name",
"transform_contributions",
"exclude_duplicate_kwargs",
"reorder_columns",
"find_id_column",
"check_group_column_validity",
"save_all_figures",
"rename_dict_keys",
"normalize_time_column",
"select_mode",
"normalize_model_inputs",
"print_config_table",
]
[docs]
class ExistenceChecker:
"""
A utility class for checking and ensuring the existence of files
and directories on the filesystem.
This class provides static methods to verify whether a given path
exists and to create directories or files if necessary. It raises
informative exceptions when paths are invalid or cannot be created.
Methods
-------
ensure_directory(path)
Ensure a directory exists at the specified path.
ensure_file(path, create_parent_dirs=False)
Ensure a file exists at the specified path, optionally creating
parent directories.
Examples
--------
>>> from geoprior.utils.generic_utils import ExistenceChecker
>>> # Ensure a directory exists
>>> dir_path = ExistenceChecker.ensure_directory("data/output")
>>> isinstance(dir_path, Path)
True
>>> # Ensure a file exists, creating parent directories
>>> file_path = ExistenceChecker.ensure_file(
... "data/output/results.txt", create_parent_dirs=True
... )
>>> file_path.exists()
True
Notes
-----
- Uses `pathlib.Path.mkdir(..., parents=True, exist_ok=True)` under
the hood to create directories.
- Creating a file will produce an empty file if it does not exist.
- Raises TypeError if the given path is not a str or pathlib.Path,
and appropriate OSError/FileExistsError for filesystem errors.
See Also
--------
pathlib.Path.mkdir : Method to create directories.
pathlib.Path.touch : Method to create an empty file.
os.makedirs : Legacy function for creating directories recursively.
os.path.exists : Check if a path exists.
"""
[docs]
@staticmethod
def ensure_directory(path):
"""
Ensure that a directory exists at the given path, creating it if
needed.
Parameters
----------
path : str or pathlib.Path
The filesystem path for which to ensure directory existence.
Can be either a string or a `pathlib.Path` object.
Returns
-------
pathlib.Path
A `Path` object pointing to the existing (or newly created)
directory.
Raises
------
TypeError
If `path` is not a string or `pathlib.Path`.
FileExistsError
If a file (not a directory) already exists at `path`.
OSError
If the directory cannot be created for any other reason
(e.g., insufficient permissions).
"""
# Validate type
if not isinstance(path, str | Path):
raise TypeError(
"`path` must be a str or pathlib.Path, got "
f"{type(path).__name__!r}"
)
dir_path = Path(path)
# If it exists and is a directory, return immediately
if dir_path.exists():
if dir_path.is_dir():
return dir_path
else:
# A non-directory file exists at this path
raise FileExistsError(
f"A non-directory file exists at: "
f"{dir_path}"
)
# Attempt to create the directory (including parents)
try:
dir_path.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise OSError(
f"Unable to create directory {dir_path}: {exc}"
) from exc
return dir_path
[docs]
@staticmethod
def ensure_file(path, create_parent_dirs=False):
"""
Ensure that a file exists at the given path, creating it if needed.
If `create_parent_dirs` is True, any missing parent directories will
be created automatically.
Parameters
----------
path : str or pathlib.Path
The filesystem path for the file that must exist.
create_parent_dirs : bool, optional
If True, create any missing parent directories. Default is False.
Returns
-------
pathlib.Path
A `Path` object pointing to the existing (or newly created)
file.
Raises
------
TypeError
If `path` is not a string or `pathlib.Path`.
FileExistsError
If a directory (not a file) already exists at `path`.
OSError
If the file or parent directories cannot be created due to
filesystem errors.
"""
# Validate type
if not isinstance(path, str | Path):
raise TypeError(
"`path` must be a str or pathlib.Path, got "
f"{type(path).__name__!r}"
)
file_path = Path(path)
# If it exists and is a file, return immediately
if file_path.exists():
if file_path.is_file():
return file_path
else:
# A directory exists at this path
raise FileExistsError(
f"A directory exists at: {file_path}"
)
# Create parent directories if requested
parent_dir = file_path.parent
if create_parent_dirs and not parent_dir.exists():
try:
parent_dir.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise OSError(
f"Unable to create parent directories {parent_dir}: "
f"{exc}"
) from exc
# Attempt to create the file
try:
file_path.touch(exist_ok=True)
except OSError as exc:
raise OSError(
f"Unable to create file {file_path}: {exc}"
) from exc
return file_path
[docs]
def check_group_column_validity(
df: pd.DataFrame,
group_col: str,
ops: str = "check_only",
max_unique: int = 10,
auto_bin: bool = False,
bins: int = 4,
error: str = "warn",
bin_labels: list[str] | None = None,
verbose: bool = True,
) -> pd.DataFrame | bool:
"""
Validate a grouping column for categorical-style use and optionally bin it.
Parameters
----------
df : pandas.DataFrame
Input DataFrame holding the grouping column.
group_col : str
Name of the candidate grouping column in ``df``.
ops : {'check_only', 'binning', 'validate'}, optional
Operation mode. Use ``"check_only"`` to return a boolean,
``"binning"`` to bin the column when needed and return a modified
DataFrame, or ``"validate"`` to check validity while honoring
``error``.
max_unique : int, optional
Maximum number of unique numeric values allowed before the column is
treated as too continuous for categorical use.
auto_bin : bool, optional
Whether to auto-bin a numeric column when ``ops='binning'``.
bins : int, optional
Number of bins to create when binning is applied.
error : {'warn', 'raise', 'ignore'}, optional
Policy used when validation fails.
bin_labels : list of str or None, optional
Custom labels for generated bins.
verbose : bool, optional
Whether to emit informational messages.
Returns
-------
bool or pandas.DataFrame
Returns a boolean for ``ops='check_only'``. Otherwise returns a
DataFrame, possibly with a transformed ``group_col``.
Notes
-----
When quantile binning is used, interval boundaries are derived from the
numeric distribution of ``group_col``.
"""
# Check if group_col is in the DataFrame.
if group_col not in df.columns:
raise ValueError(
f"[ERROR] Column '{group_col}' not found "
f"in DataFrame."
)
# Extract data from the column and check if numeric.
col_data = df[group_col]
is_numeric = pd.api.types.is_numeric_dtype(col_data)
# If numeric, ensure we haven't exceeded max_unique.
is_valid_group = (
not is_numeric or col_data.nunique() <= max_unique
)
# Validate ops argument.
if ops not in {"check_only", "binning", "validate"}:
raise ValueError(
f"Unknown 'ops' value: {ops}. Choose from "
f"'check_only', 'binning', 'validate'."
)
# If user only wants a check, return boolean.
if ops == "check_only":
return is_valid_group
# If 'validate', decide whether to warn/raise
# or fallback to binning if auto_bin is True.
elif ops == "validate":
if is_valid_group:
return df
msg = (
f"Column '{group_col}' is numeric with "
f"{col_data.nunique()} unique values. "
"Not suitable for grouping in plots."
)
if auto_bin:
ops = "binning" # fallback
else:
if error == "raise":
raise ValueError(msg)
elif error == "warn":
warnings.warn(f"{msg}", stacklevel=2)
return df
# If 'binning', we attempt to transform the column.
if ops == "binning":
# If already valid, no change needed.
if is_valid_group:
return df
# If auto_bin is True, we do quantile binning.
if auto_bin:
if bin_labels is None:
bin_labels = [
f"Q{i + 1}" for i in range(bins)
]
df[group_col] = pd.qcut(
col_data, q=bins, labels=bin_labels
)
if verbose:
print(
f"[INFO] Auto-binned '{group_col}' "
f"into {bins} quantile-based categories."
)
return df
else:
# If auto_bin is False and data is invalid,
# we handle per 'error'.
if error == "raise":
raise ValueError(
"Auto-binning disabled, and group "
"column is not suitable."
)
elif error == "warn":
warnings.warn(
"Group column is not categorical "
"and was not binned.",
stacklevel=2,
)
return df
[docs]
def find_id_column(
df: pd.DataFrame,
strategy: Literal[
"naive", "exact", "dtype", "regex", "prefix_suffix"
] = "naive",
regex_pattern: str | None = None,
uniqueness_threshold: float = 0.95,
errors: Literal["raise", "warn", "ignore"] = "raise",
empty_as_none: bool = True,
as_list: bool = False,
case_sensitive: bool = False,
as_frame: bool = False,
) -> str | list[str] | pd.DataFrame | None:
"""
Identify potential ID column(s) in a pandas DataFrame
using multiple heuristic strategies.
The function examines column names and/or data properties
to detect columns likely to serve as unique identifiers.
This is particularly useful for large datasets where the
ID field is not explicitly labeled, and for quick scanning
of possible key columns.
Parameters
----------
df : pandas.DataFrame
The input DataFrame in which to search for
potential ID columns.
strategy : {'naive', 'exact', 'dtype', 'regex','prefix_suffix'},\
default 'naive'
Defines the logic for detecting ID columns:
- `exact`: Checks for a column name that exactly
matches `id` (case sensitivity controlled by
``case_sensitive``).
- `naive`: Searches for columns where `id` is
part of the name (e.g., `location_id`) subject
to case sensitivity.
- `prefix_suffix`: Considers columns prefixed
or suffixed with `id` or `_id`.
- `dtype`: Examines columns having data types
commonly used for IDs (integer, string, or
object) and checks if they show high uniqueness
via :math:`\\text{uniqueness\\_ratio}
\\geq \\text{uniqueness\\_threshold}`.
- `<regex>`: Uses a custom regular expression
``<regex_pattern>`` to find matches in column
names.
regex_pattern : str, optional
Required if `strategy` is `'regex'`. The
pattern is compiled via `re.compile`, with case
sensitivity determined by `<case_sensitive>`.
uniqueness_threshold : float, default 0.95
For `<dtype>` strategy, columns are flagged as ID
candidates if the ratio:
.. math::
r = \\frac{
\\text{unique\\_values}
}{
\\text{non\\_NA\\_rows}
}
satisfies :math:`r \\geq \\text{uniqueness\\_threshold}`,
or if the number of unique values equals the number
of non-null rows.
errors : {'raise', 'warn', 'ignore'}, default 'raise'
How to handle no-match cases:
- `raise`: Raises a `ValueError`.
- `warn`: Issues a `UserWarning` and returns
based on `<as_frame>` or `<empty_as_none>`.
- `ignore`: Returns an empty result based on
the same parameters without warning.
empty_as_none : bool, default True
*Applies only if `as_frame` is False.* Defines
whether to return `None` (if True) or an empty list
(if False) when no ID column is found and
`<errors>` is `'warn'` or `'ignore'`.
as_list : bool, default False
If True, return all matched columns. If False,
return only the first match. Affects both name
returns and DataFrame returns.
case_sensitive : bool, default False
If False, comparisons (including regex) are
performed in a case-insensitive manner.
as_frame : bool, default False
If True, return the matched columns as a
pandas DataFrame. If `as_list` is True,
it may include multiple columns. If no column
is found, returns an empty DataFrame (if
`<errors>` is `'warn'` or `'ignore'`).
Returns
-------
str or List[str] or pandas.DataFrame or None
Depends on `as_frame`, `as_list`, and the
number of matching columns:
- `<as_frame>`=False, `as_list`=False:
returns the first match as a string,
or `None`/`[]`.
- `as_frame`=False, `as_list`=True:
returns all matching column names as
a list of strings.
- `as_frame`=True, `as_list`=False:
returns a DataFrame with the first matched
column. If no match is found, an empty
DataFrame may be returned.
- `as_frame`=True, `as_list`=True:
returns a DataFrame with all matched columns
included.
Notes
-----
- For `<dtype>` strategy, integer, string, and
object columns are inspected. The function
calculates a uniqueness ratio and compares it
against `<uniqueness_threshold>`.
- Negative or zero thresholds are invalid, as are
values above 1.
- If the DataFrame has no columns or is empty, the
behavior is determined by `<errors>`.
- The relational-model motivation for schema-oriented
column handling goes back to
:cite:t:`Codd1970Relational`.
Examples
--------
>>> from geoprior.utils.generic_utils import find_id_column
>>> import pandas as pd
>>> data = pd.DataFrame({
... 'ID_code': [101, 102, 103],
... 'Name': ['Alice', 'Bob', 'Charlie'],
... 'value': [10, 20, 30]
... })
>>> # Example using the 'naive' strategy
>>> col = find_id_column(data, strategy='naive')
>>> print(col) # Might return 'ID_code'
>>> # Example with as_list=True
>>> cols = find_id_column(data, strategy='naive',
... as_list=True)
>>> print(cols) # ['ID_code']
See Also
--------
re.compile : The regex compilation method used
when `strategy`='regex'.
pandas.api.types.is_integer_dtype : Checks integer type.
pandas.api.types.is_string_dtype : Checksstring type.
pandas.api.types.is_object_dtype : Checksobject type.
"""
# Validate that df is a DataFrame
if not isinstance(df, pd.DataFrame):
raise ValueError(
"Input 'df' must be a pandas DataFrame."
)
valid_strategies = [
"naive",
"exact",
"dtype",
"regex",
"prefix_suffix",
]
if strategy not in valid_strategies:
raise ValueError(
f"Invalid strategy '{strategy}'. Must be one "
f"of {valid_strategies}."
)
valid_errors = ["raise", "warn", "ignore"]
if errors not in valid_errors:
raise ValueError(
f"Invalid errors value '{errors}'. Must be one "
f"of {valid_errors}."
)
# If strategy='regex', ensure regex_pattern is valid
if strategy == "regex":
if not regex_pattern or not isinstance(
regex_pattern, str
):
raise ValueError(
"Parameter 'regex_pattern' must be a "
"non-empty string when strategy is "
"'regex'."
)
# If strategy='dtype', ensure threshold in [0,1]
if strategy == "dtype":
if not (0.0 <= uniqueness_threshold <= 1.0):
raise ValueError(
"Parameter 'uniqueness_threshold' must be "
"between 0.0 and 1.0."
)
# If DataFrame is empty, handle accordingly
if df.empty or len(df.columns) == 0:
msg = (
"DataFrame is empty or has no columns. "
"Cannot find ID column."
)
if errors == "raise":
raise ValueError(msg)
elif errors == "warn":
warnings.warn(msg, UserWarning, stacklevel=2)
if as_frame:
return pd.DataFrame()
else:
return [] if not empty_as_none else None
# Prepare columns for matching
original_columns = df.columns.tolist()
if not case_sensitive:
col_map = {}
seen_lower = set()
for c in original_columns:
lc = c.lower()
if lc not in seen_lower:
col_map[lc] = c
seen_lower.add(lc)
match_columns_keys = list(col_map.keys())
else:
col_map = {c: c for c in original_columns}
match_columns_keys = original_columns
found_matches_keys = []
# -------------- Strategy: 'exact' --------------
if strategy == "exact":
target = "id" if not case_sensitive else "id"
if target in match_columns_keys:
found_matches_keys.append(target)
# -------------- Strategy: 'naive' --------------
elif strategy == "naive":
target = "id" if not case_sensitive else "id"
found_matches_keys = [
mc_key
for mc_key in match_columns_keys
if target in mc_key
]
# -------------- Strategy: 'prefix_suffix' -------
elif strategy == "prefix_suffix":
targets = ["id", "_id"]
if not case_sensitive:
targets = [t.lower() for t in targets]
for mc_key in match_columns_keys:
match = False
for t in targets:
if mc_key.startswith(t) or mc_key.endswith(t):
match = True
break
if match:
found_matches_keys.append(mc_key)
# -------------- Strategy: 'regex' --------------
elif strategy == "regex":
try:
flags = 0 if case_sensitive else re.IGNORECASE
pattern = re.compile(regex_pattern, flags)
original_regex_matches = [
c
for c in original_columns
if pattern.search(c)
]
if not case_sensitive:
found_matches_keys = [
k
for k, v in col_map.items()
if v in original_regex_matches
]
else:
found_matches_keys = [
k
for k in original_regex_matches
if k in match_columns_keys
]
except re.error as e:
raise ValueError(
f"Invalid regex pattern provided: {e}"
) from e
# -------------- Strategy: 'dtype' --------------
elif strategy == "dtype":
for col_name in original_columns:
col_series = df[col_name]
# Check if type is integer, string, or object
is_potential_type = (
pd.api.types.is_integer_dtype(col_series)
or pd.api.types.is_string_dtype(col_series)
or pd.api.types.is_object_dtype(col_series)
)
if is_potential_type:
non_na = col_series.dropna()
if len(non_na) > 0:
if pd.api.types.is_object_dtype(non_na):
num_unique = non_na.astype(
str
).nunique()
else:
num_unique = non_na.nunique()
uniqueness_ratio = num_unique / len(
non_na
)
is_perfectly_unique = num_unique == len(
non_na
)
if (
uniqueness_ratio
>= uniqueness_threshold
or is_perfectly_unique
):
key_to_add = (
col_name.lower()
if not case_sensitive
else col_name
)
if key_to_add in col_map:
if key_to_add not in (
found_matches_keys
):
found_matches_keys.append(
key_to_add
)
# Remove duplicates while preserving order
ordered_unique_matches_keys = sorted(
list(set(found_matches_keys)),
key=found_matches_keys.index,
)
original_matches = [
col_map[mk] for mk in ordered_unique_matches_keys
]
# -------------- Handle Results --------------
if original_matches:
if as_frame:
# Decide columns to select
cols_to_select = (
original_matches
if as_list
else [original_matches[0]]
)
valid_cols = [
c for c in cols_to_select if c in df.columns
]
if not valid_cols:
msg = (
"Internal error: Matched columns "
"not found in DataFrame."
)
if errors == "raise":
raise ValueError(msg)
elif errors == "warn":
warnings.warn(
f"{msg} Returning empty DataFrame.",
UserWarning,
stacklevel=2,
)
return pd.DataFrame()
return df[valid_cols]
elif as_list:
return original_matches
else:
return original_matches[0]
else:
msg = (
f"No ID column found in DataFrame using "
f"strategy='{strategy}'"
)
if regex_pattern and strategy == "regex":
msg += f" with pattern='{regex_pattern}'"
if strategy == "dtype":
msg += f" and threshold={uniqueness_threshold}"
if not case_sensitive:
msg += " (case-insensitive)."
else:
msg += " (case-sensitive)."
if errors == "raise":
raise ValueError(msg)
elif errors == "warn":
warnings.warn(msg, UserWarning, stacklevel=2)
# Return empty or None if no match
if as_frame:
return pd.DataFrame()
else:
return [] if not empty_as_none else None
[docs]
def verify_identical_items(
list1,
list2,
mode: str = "unique",
ops: str = "check_only",
error: str = "raise",
objname: str = None,
) -> bool | list:
"""
Check if two lists contain identical elements according
to the specified mode.
In "unique" mode, the function compares the unique elements
in each list.
In "ascending" mode, it compares elements pairwise in order.
Parameters
----------
list1 : list
The first list of items.
list2` : list
The second list of items.
mode : {'unique', 'ascending'}, default="unique"
The mode of comparison:
- "unique": Compare unique elements (order-insensitive).
- "ascending": Compare each element pairwise in order.
ops : {'check_only', 'validate'}, default="check_only"
If "check_only", returns True/False indicating a match.
If "validate", returns the validated list.
error : {'raise', 'warn', 'ignore'}, default="raise"
Specifies how to handle mismatches.
objname : str, optional
A name to include in error messages.
Returns
-------
bool or list
Depending on `ops`, returns True/False or the validated list.
Examples
--------
>>> from geoprior.utils.generic_utils import verify_identical_items
>>> list1 = [0.1, 0.5, 0.9]
>>> list2 = [0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="unique", ops="validate")
[0.1, 0.5, 0.9]
>>> verify_identical_items(list1, list2, mode="ascending", ops="check_only")
True
Notes
-----
In "ascending" mode, both lists must have the same length, and the
function compares each corresponding pair of elements.
In "unique" mode, the function uses the set of unique values for
comparison. If the lists contain mixed types, the function attempts
to compare their string representations.
"""
# Validate mode.
if mode not in ("unique", "ascending"):
raise ValueError(
"mode must be either 'unique' or 'ascending'"
)
if ops not in ("check_only", "validate"):
raise ValueError(
"ops must be either 'check_only' or 'validate'"
)
if error not in ("raise", "warn", "ignore"):
raise ValueError(
"error must be one of 'raise', 'warn', or 'ignore'"
)
# Ascending mode: compare each element in order.
if mode == "ascending":
if len(list1) != len(list2):
msg = (
f"Length mismatch in {objname or 'object lists'}: "
f"{len(list1)} vs {len(list2)}."
)
if error == "raise":
raise ValueError(msg)
elif error == "warn":
import warnings
warnings.warn(msg, UserWarning, stacklevel=2)
return False
differences = []
for idx, (a, b) in enumerate(
zip(list1, list2, strict=False)
):
if a != b:
differences.append((idx, a, b))
if differences:
msg = f"Differences in {objname or 'object lists'}: {differences}."
if error == "raise":
raise ValueError(msg)
elif error == "warn":
import warnings
warnings.warn(msg, UserWarning, stacklevel=2)
return False
return True if ops == "check_only" else list1
# Unique mode: compare the unique elements of each list.
else:
try:
unique1 = sorted(set(list1))
unique2 = sorted(set(list2))
except Exception:
unique1 = sorted({str(x) for x in list1})
unique2 = sorted({str(x) for x in list2})
if unique1 != unique2:
msg = (
f"Inconsistent unique elements in {objname or 'object lists'}: "
f"{unique1} vs {unique2}."
)
if error == "raise":
raise ValueError(msg)
elif error == "warn":
import warnings
warnings.warn(msg, UserWarning, stacklevel=2)
return False
return True if ops == "check_only" else unique1
[docs]
def vlog(
message,
verbose=None,
level: int = 3,
depth: int | str = "auto",
mode: str | None = None,
vp: bool = True,
logger=None,
**kws,
):
r"""
Log or naive messages with optional indentation and
bracketed tags.
This function, `vlog`, allows conditional logging
or printing of messages based on a global or passed
in <parameter inline> `verbose` level. By default,
it behaves differently depending on whether
``mode`` is ``'log'`` or ``'naive'``. When
:math:`mode = 'log'`, the message is printed only if
:math:`\text{verbose} \geq \text{level}`. Otherwise,
for :math:`mode` in [``None``, ``'naive'``], the
verbosity threshold leads to various bracketed
prefixes (e.g. [INFO], [DEBUG], [TRACE]) unless the
message already contains such a prefix.
.. math::
\text{indentation} = 2 \times \text{depth}
where :math:`\text{depth}` is either manually
specified or auto-derived based on `<parameter inline>`
`level` (1 = ERROR, 2 = WARNING, 3 = INFO, 4/5 =
DEBUG, 6/7 = TRACE).
Parameters
----------
message : str
The text to be printed or logged.
verbose : int, optional
Overall verbosity threshold. If ``None``, it
looks for a global variable named ``verbose``.
Default is ``None``.
level : int, default=3
Severity or importance level of the message.
Commonly:
* 1 = ERROR
* 2 = WARNING
* 3 = INFO
* 4,5 = DEBUG
* 6,7 = TRACE
depth : int or str, default="auto"
Indentation level used for the printed message.
If ``"auto"``, the depth is computed from
`<parameter inline> level`.
mode : str, optional
Determines logging mode. If set to ``'log'``,
prints messages only if
:math:`\text{verbose} \geq \text{level}`.
Otherwise (if ``None`` or ``'naive'``), it
follows a custom logic driven by `<parameter
inline> verbose`.
vp : bool, default=True
If ``True``, the function automatically prepends
bracketed tags (e.g. [INFO]) unless the message
already contains one of [INFO], [DEBUG], [ERROR],
[WARNING], or [TRACE].
logger : logging.Logger or Callable[[str], None], optional
Custom sink that receives the *already-formatted* message string.
* If you pass a standard :pyclass:`logging.Logger` instance,
the message is routed through ``logger.info``.
* If you supply any ``callable`` that accepts a single ``str``
(e.g. a GUI text-append function), that callable is invoked
directly.
* Defaults to :pyfunc:`print`, which writes to *stdout*.
kws: Logging instance, optional
For future extensions.
Returns
-------
None
This function does not return anything. It either
prints the message to stdout or omits it,
depending on `<parameter inline> verbose`, `<parameter
inline> level`, and ``mode``.
Notes
-----
This function is helpful for selectively displaying
or logging messages in applications that adapt to
the user's required verbosity. By default, each
level has a specific bracketed tag and an auto
indentation depth.
Examples
--------
>>> from geoprior.utils.generic_utils import vlog
>>> # Example with mode='log'
>>> # This prints only if global or passed-in
>>> # verbose >= 4.
>>> vlog("Check debugging details.", verbose=3,
... level=4, mode='log')
>>> # Example with mode='naive'
>>> # If verbose=2, it displays as [INFO] prefixed.
>>> vlog("Loading data...", verbose=2, mode='naive')
See Also
--------
globals : Used to retrieve the fallback `verbose`
value if not explicitly passed.
"""
logger = logger or print
if isinstance(logger, logging.Logger):
_emit = (
logger.info
) # unify interface → Callable[[str], None]
else:
_emit = logger # assume it's already a callable
verbosity_labels = {
1: "[ERROR]",
2: "[WARNING]",
3: "[INFO]",
4: "[DEBUG]",
5: "[DEBUG]",
6: "[TRACE]",
7: "[TRACE]",
}
# When depth="auto", assign indentation based on `level`.
if depth == "auto":
if level == 1:
depth = 0
elif level == 2:
depth = 2
elif level == 3:
depth = 0
elif level in (4, 5):
depth = 2
elif level in (6, 7):
depth = 4
else:
depth = 0
# Fallback to a global 'verbose' if none given.
actual_verbose = (
verbose
if verbose is not None
else globals().get("verbose", 0)
)
# If mode == 'log', keep original approach:
if mode == "log":
if actual_verbose >= level:
# Indent and prefix with the label from `level`.
indent = " " * (depth * 2)
_emit(
rf"{indent}{verbosity_labels[level]} {message}"
)
# Nothing else for mode='log' if verbosity is too low.
return
# If mode == 'naive' or None, override with the new rules:
if mode in [None, "naive"]:
# If actual_verbose < 1, do not print anything.
if actual_verbose < 1:
return
# Otherwise, figure out the prefix based on the threshold:
# We check if the message already includes [INFO], [DEBUG], [WARNING],
# [ERROR], or [TRACE]. If so, skip adding any prefix.
prefix_tags = (
"[INFO]",
"[DEBUG]",
"[ERROR]",
"[WARNING]",
"[TRACE]",
)
already_tagged = any(
tag in message for tag in prefix_tags
)
# indent as per the computed `depth`
indent = " " * (depth * 2)
# If >=3 => prefix with [INFO] if vp is True and not already tagged
if actual_verbose <= 3:
if vp and not already_tagged:
_emit(rf"{indent}[INFO] {message}")
else:
_emit(rf"{indent}{message}")
return
# If 3 < verbose < 5 => prefix with [DEBUG] if vp is True and not already tagged
if 3 < actual_verbose < 5:
if vp and not already_tagged:
_emit(rf"{indent}[DEBUG] {message}")
else:
_emit(rf"{indent}{message}")
return
# If verbose >= 5 => prefix with [TRACE] if vp is True and not already tagged
if actual_verbose >= 5:
if vp and not already_tagged:
_emit(rf"{indent}[TRACE] {message}")
else:
_emit(rf"{indent}{message}")
return
[docs]
def get_actual_column_name(
df: pd.DataFrame,
tname: str = None,
actual_name: str = None,
error: str = "raise",
default_to=None,
) -> str:
"""
Determines the actual target column name in the given DataFrame.
Parameters
----------
df : pandas.DataFrame
The DataFrame containing the target column.
tname : str, optional
The base target name (e.g., "subsidence"). If not found in the DataFrame,
it will attempt to find a matching column using "<tname>_actual" format.
actual_name : str, optional
If provided, this name will be returned as the actual target column name.
error : {'raise', 'warn', 'ignore'}, default='raise'
Specifies how to handle the case when no valid column is found:
- 'raise': Raises a `ValueError`.
- 'warn': Issues a warning and returns `None`.
- 'ignore': Silently returns `None`.
Returns
-------
str or None
The determined actual column name, or None if no match is found
and `error='warn'` or `error='ignore'`.
Raises
------
ValueError
If no valid target column is found and `error='raise'`.
Examples
--------
>>> from geoprior.utils.generic_utils import get_actual_column_name
>>> df = pd.DataFrame({'subsidence_actual': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence_actual'
>>> df = pd.DataFrame({'subsidence': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence")
'subsidence'
>>> df = pd.DataFrame({'actual': [1, 2, 3]})
>>> get_actual_column_name(df)
'actual'
>>> df = pd.DataFrame({'measurement': [1, 2, 3]})
>>> get_actual_column_name(df, tname="subsidence", error="warn")
Warning: Could not determine the actual target column in the DataFrame.
None
"""
# Validate input DataFrame
if not isinstance(df, pd.DataFrame):
raise TypeError("`df` must be a pandas DataFrame.")
# If `actual_name` is provided, check if it exists in df
if actual_name and actual_name in df.columns:
return actual_name
# If `tname` exists in df, return it
if tname and tname in df.columns:
return tname
# If `<tname>_actual` exists, return it
if tname and f"{tname}_actual" in df.columns:
return f"{tname}_actual"
# If "actual" column exists, return it
if "actual" in df.columns:
return "actual"
# Handle the case when no valid column is found
msg = "Could not determine the actual target column in the DataFrame."
if error == "raise":
raise ValueError(msg)
elif error == "warn":
warnings.warn(msg, UserWarning, stacklevel=2)
if default_to == "tname":
return tname
return None # If `error='ignore'`, return None silently
[docs]
def exclude_duplicate_kwargs(
func: callable,
existing_kwargs: dict[str, Any] | list[str],
user_kwargs: dict[str, Any],
) -> dict[str, Any]:
"""
Prevents the user from overriding existing parameters
in a target function. The method `exclude_duplicate_kwargs`
checks both developer-specified and function-level
parameter names to exclude them from `user_kwargs`.
.. math::
\text{final\_kwargs} =
\{\,(k, v) \in \text{user\_kwargs} \,\mid\,
k \notin \text{protected\_params}\,\}
Parameters
----------
func : callable
The target function whose valid parameters are
checked. It uses Python's introspection to gather
the acceptable parameter names.
existing_kwargs : dict or list
Developer-defined parameters to protect. Can be:
* A dictionary of parameter-value pairs (e.g.,
``{'ax': ax_obj, 'data': df}``) whose keys
are excluded from user overrides.
* A list of parameter names (e.g., ``['ax',
'data']``) to protect from user overrides.
user_kwargs : dict
The user-supplied keyword arguments that are
candidates for merging with `existing_kwargs`.
This dictionary is filtered to remove collisions
with protected parameters.
Returns
-------
dict
A filtered dictionary of user-defined arguments
that do not overlap with protected parameters.
Examples
--------
>>> from geoprior.utils.generic_utils import exclude_duplicate_kwargs
>>> import seaborn as sns
>>> # Developer has some base kwargs
... base_kwargs = {
... 'x': 'species',
... 'y': 'sepal_length',
... 'palette': 'viridis'
... }
>>> # User tries to override 'x' with new param
... user_args = {
... 'x': 'petal_width',
... 'color': 'red'
... }
>>> # Filter out duplicates
... safe_args = exclude_duplicate_kwargs(
... sns.scatterplot,
... base_kwargs,
... user_args
... )
>>> safe_args
{'color': 'red'}
Notes
-----
By default, if `existing_kwargs` is a dictionary,
its keys are treated as protected parameter names.
If it's a list, those items are protected. The
function signature of `func` is also used to
verify that only recognized parameters are
protected. Keyword-filtering patterns like this
are covered in :cite:t:`BeazleyJones2013PythonCookbook`.
See Also
--------
inspect.signature : Used to introspect function
parameters.
filter_valid_kwargs : Another inline function that
discards user params not valid for a given
function.
"""
# Validate existing_kwargs
if not isinstance(existing_kwargs, dict | list):
raise TypeError(
"existing_kwargs must be a dict or list"
)
# Build set of protected parameters
protected_params = (
existing_kwargs.keys()
if isinstance(existing_kwargs, dict)
else existing_kwargs
)
# Inspect func signature to get valid names
sig = inspect.signature(func)
valid_params = {
name: param
for name, param in sig.parameters.items()
if param.kind
not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
}
# Restrict protection to valid parameters only
exclude = [
param
for param in protected_params
if param in valid_params
]
# Filter out excluded params from user_kwargs
safe_kwargs = {
k: v
for k, v in user_kwargs.items()
if k not in exclude
}
return safe_kwargs
[docs]
def reorder_columns(
df: pd.DataFrame,
columns: str | list[str],
pos: str | int | float = "end",
):
"""
Reorder columns in a DataFrame by moving specified
columns to a chosen position.
This function locates `<columns>` in the original
DataFrame `<df>` and rearranges them based on the
parameter ``pos``. If ``pos`` is `"end"`, columns
are appended to the end. If `"begin"` or `"start"`,
they are placed at the front. If `"center"`, they
are inserted at the midpoint:
Parameters
----------
df : pandas.DataFrame
The input DataFrame to be modified.
columns : str or iterable of str
A single column name or multiple column
names to reposition. If a single string is
given, it is converted to a list with one
element.
pos : str, int, or float, default ``"end"``
Determines the target placement:
- ``"end"``: Append after all other
columns.
- ``"begin"`` or ``"start"``: Prepend at
the start.
- ``"center"``: Insert at the midpoint of
remaining columns.
- integer or float: Insert at zero-based
index among the remaining columns. If
out of bounds, the original DataFrame is
returned unchanged.
Returns
-------
pandas.DataFrame
A new DataFrame with `<columns>` moved as
specified by ``pos``.
Methods
-------
`reorder_columns_in`
This method rearranges columns without
altering values or data order beyond column
placement.
Notes
-----
- The function checks if `<columns>` exist in
`<df>`, ignoring columns not present.
- A warning is issued if the position is beyond
the range of valid indices.
- Negative indices for integer ``pos`` are
converted to positive by adding the total
number of remaining columns.
.. math::
i_{\\text{center}} = \\left\\lfloor
\\frac{|R|}{2}
\\right\\rfloor,
where :math:`|R|` is the number of remaining
columns after removing the target columns.
For integer or float ``pos``, the target columns
are inserted at index :math:`\\lfloor pos
\\rfloor` among the remaining columns. Column-order
management follows common DataFrame practices
discussed in :cite:t:`McKinney2017PythonDataAnalysis`.
Examples
--------
>>> from geoprior.utils.generic_utils import reorder_columns
>>> import pandas as pd
>>> data = pd.DataFrame({
... 'id': [1, 2, 3],
... 'latitude': [10.1, 10.2, 10.3],
... 'landslide': [0, 1, 0],
... 'longitude': [20.1, 20.2, 20.3]
... })
>>> # Move 'landslide' to the end (default)
>>> reorder_columns(data, 'landslide', pos="end")
id latitude longitude landslide
0 1 10.1 20.1 0
1 2 10.2 20.2 1
2 3 10.3 20.3 0
See Also
--------
pandas.DataFrame.reindex : Pandas method for
reindexing or
reordering columns
more generally.
"""
# Validate that the input 'df' is a pandas DataFrame
if not isinstance(df, pd.DataFrame):
raise TypeError(
f"Expect dataframe. Got {type(df).__name__!r}"
)
# Normalize 'columns' to a list of strings
if isinstance(columns, str):
columns = [columns]
elif not isinstance(columns, list | tuple | set):
raise TypeError(
"columns must be a string or an "
"iterable of strings."
)
# Filter out columns that are actually in df
valid_targets = [
col for col in columns if col in df.columns
]
if not valid_targets:
raise ValueError(
"None of the specified columns were "
"found in the DataFrame."
)
# Create a list of remaining columns while
# preserving their original order
remaining_columns = [
col for col in df.columns if col not in valid_targets
]
# Maintain the original target columns' order
target_order = [
col for col in df.columns if col in valid_targets
]
# Compute new order based on 'pos'
if isinstance(pos, str):
pos_lower = pos.lower()
if pos_lower == "end":
new_order = remaining_columns + target_order
elif pos_lower in ("begin", "start"):
new_order = target_order + remaining_columns
elif pos_lower == "center":
center_index = len(remaining_columns) // 2
new_order = (
remaining_columns[:center_index]
+ target_order
+ remaining_columns[center_index:]
)
else:
try:
pos = int(pos)
except Exception as e:
raise ValueError(
f"Unrecognized string value for pos:"
f" {pos}"
) from e
# If 'pos' is numeric, insert the columns at
# that position
if isinstance(pos, int | float):
pos_int = int(pos)
# Check if pos_int is within bounds
if pos_int > len(remaining_columns) or pos_int < -len(
remaining_columns
):
warnings.warn(
"Position is out of bounds. Skipping "
"moving columns.",
stacklevel=2,
)
return df
# Convert negative index to positive
if pos_int < 0:
pos_int = len(remaining_columns) + pos_int
new_order = (
remaining_columns[:pos_int]
+ target_order
+ remaining_columns[pos_int:]
)
else:
# If it's not a recognized type, raise an error
if not isinstance(pos, str):
raise TypeError(
"pos must be a string, integer, or float."
)
# Return the DataFrame with the new column order
return df[new_order]
def map_scales_choice(
scales_choice_str: str,
) -> list[int] | None:
"""
Maps a string choice for scales to an actual list of scale values
or None if no scales are provided.
This function interprets specific string inputs and converts them
to corresponding lists of integers. If no matching string is found,
it returns None as a default fallback.
Parameters
----------
scales_choice_str : str
A string representing the choice of scale. It can be one of
'default_scales', 'alt_scales', or 'no_scales'.
Returns
-------
Optional[List[int]]
A list of integers corresponding to the chosen scales, or None
if no scales are provided.
Examples
--------
>>> from geoprior.utils.generic_utils import map_scales_choice
>>> map_scales_choice('default_scales')
[1, 3, 7]
>>> map_scales_choice('alt_scales')
[1, 5, 10]
>>> map_scales_choice('no_scales')
None
>>> map_scales_choice('unknown_scales')
None
"""
if not isinstance(scales_choice_str, str):
raise ValueError(
"scales_choice_str must be a string."
)
if scales_choice_str == "default_scales":
return [1, 3, 7]
elif scales_choice_str == "alt_scales":
return [1, 5, 10]
elif scales_choice_str == "no_scales":
return None
return None # Default fallback for unrecognized inputs
def cast_hp_to_bool(
params: dict[str, Any],
param_name: str,
default_value: bool = False,
) -> None:
"""
Casts a hyperparameter value in the `params` dictionary to a boolean.
The function ensures that hyperparameters that are supposed to be
booleans are correctly cast to Python booleans. This is particularly
useful in cases where Keras Tuner or other libraries might return
0 or 1 as the value for boolean choices.
Parameters
----------
params : Dict[str, Any]
A dictionary of hyperparameters to be validated and updated in place.
param_name : str
The key of the hyperparameter to be checked and casted to boolean.
default_value : bool, optional
The default value to assign if the parameter is not found or
has an invalid value. The default is False.
Returns
-------
None
This function modifies the `params` dictionary in-place.
Examples
--------
>>> from geoprior.utils.generic_utils import cast_hp_to_bool
>>> params = {'use_batch_norm': 1}
>>> cast_hp_to_bool(params, 'use_batch_norm', default_value=False)
>>> print(params['use_batch_norm'])
True
>>> params = {'use_residuals': 0}
>>> cast_hp_to_bool(params, 'use_residuals', default_value=True)
>>> print(params['use_residuals'])
False
>>> params = {'use_dropout': 'yes'}
>>> cast_hp_to_bool(params, 'use_dropout', default_value=False)
>>> print(params['use_dropout'])
False
"""
if not isinstance(params, dict):
raise ValueError("params must be a dictionary.")
if param_name in params:
value = params[param_name]
if isinstance(
value, int | float
): # Handles 0, 1, 0.0, 1.0
params[param_name] = bool(value)
elif not isinstance(value, bool):
# Warn if the value is not a valid boolean type (int, float, bool)
warnings.warn(
f"Hyperparameter '{param_name}' received unexpected value "
f"'{value}' (type: {type(value)}). Expected bool or 0/1. "
f"Defaulting to {default_value}. Please check param_space definition.",
stacklevel=2,
)
params[param_name] = default_value
def cast_multiple_bool_params(
params: dict[str, Any],
bool_params_to_cast: list[tuple[str, bool]],
) -> None:
"""
Casts a list of boolean hyperparameters to ensure they are Python booleans.
This function iterates over a list of parameter names and default values,
ensuring that each parameter is properly cast to a boolean type. If a
parameter does not exist or has an invalid value, it is set to the
provided default.
Parameters
----------
params : Dict[str, Any]
A dictionary of hyperparameters to be validated and updated in place.
bool_params_to_cast : List[Tuple[str, bool]]
A list of tuples where each tuple consists of a hyperparameter name
and its default boolean value. The function will cast the corresponding
parameter to a boolean.
Returns
-------
None
This function modifies the `params` dictionary in-place.
Examples
--------
>>> from geoprior.utils.generic_utils import cast_multiple_bool_params
>>> params = {'use_batch_norm': 1, 'use_residuals': 0}
>>> cast_multiple_bool_params(params, [('use_batch_norm', False), ('use_residuals', True)])
>>> print(params)
{'use_batch_norm': True, 'use_residuals': False}
"""
if not isinstance(params, dict):
raise ValueError("params must be a dictionary.")
if not isinstance(bool_params_to_cast, list):
raise ValueError(
"bool_params_to_cast must be a list of tuples."
)
for param_name, default_value in bool_params_to_cast:
cast_hp_to_bool(params, param_name, default_value)
def print_box(
msg: str | list[str],
width: int = 80,
align: str = "center",
border_char: str = "+",
horizontal_char: str = "-",
vertical_char: str = "|",
padding: int = 1,
) -> None:
"""
Print a boxed message with customizable styling.
Parameters
----------
msg : str or list of str
The message to display. If a list is provided, each element
is treated as a separate line.
width : int, default=80
Total width of the box including borders.
align : {'center','left','right'}, default='center'
Text alignment within the box.
border_char : str, default='+'
Character used for the four corners of the box.
horizontal_char : str, default='-'
Character used for the top/bottom border lines.
vertical_char : str, default='|'
Character used for the left/right border lines.
padding : int, default=1
Number of spaces between text and vertical borders.
Returns
-------
None
Prints the styled box directly to stdout.
Examples
--------
>>> from geoprior.utils.generic_utils import print_box
>>> print_box("Hello, world!", width=40)
+--------------------------------------+
| Hello, world! |
+--------------------------------------+
>>> print_box(
... ["Line one", "This is a longer line that will wrap"],
... width=50, align='left', border_char='*'
... )
**************************************************
* Line one *
* This is a longer line that will *
* wrap *
**************************************************
"""
# Ensure msg is a list of lines
lines = msg if isinstance(msg, list) else msg.split("\n")
# Calculate inner width for text (excluding borders & padding)
inner_width = width - 2 - 2 * padding
if inner_width < 10:
raise ValueError(
"Width too small for the given padding."
)
# Build top and bottom border
border_line = (
border_char
+ horizontal_char * (width - 2)
+ border_char
)
# Function to align a single line
def align_line(text: str) -> str:
if len(text) > inner_width:
text = text[:inner_width]
if align == "left":
return text.ljust(inner_width)
elif align == "right":
return text.rjust(inner_width)
else: # center
return text.center(inner_width)
# Print the box
print(border_line)
for raw_line in lines:
# Wrap long lines
wrapped = textwrap.wrap(raw_line, inner_width) or [""]
for wline in wrapped:
print(
vertical_char
+ " " * padding
+ align_line(wline)
+ " " * padding
+ vertical_char
)
print(border_line)
def handle_emptiness(
obj: Any,
ops: str = "validate",
empty_as_none: bool = True,
) -> Any | bool:
"""
Smart helper to check or normalize empty/None values.
Parameters
----------
obj : Any
The object to inspect. Can be None, numpy array,
pandas Series/DataFrame, list, tuple, etc.
ops : {'validate','check_only'}, default='validate'
- 'check_only': return True if obj is None or empty.
- 'validate': return normalized obj or placeholder.
empty_as_none : bool, default=True
Only used when ops='validate':
- If True, empty or None -> return None
- If False, empty or None -> return [] (empty list)
Returns
-------
obj or bool
- If `ops=='check_only'`, returns a bool indicating if
`obj` is None or empty.
- If `ops=='validate'`,
- Non-empty object -> returned unchanged
- Empty/None -> None or [] based on `empty_as_none`
Examples
--------
>>> import numpy as np
>>> from geoprior.utils.generic_utils import \
handle_emptiness
>>> handle_emptiness(None, ops='check_only')
True
>>> handle_emptiness([], ops='check_only')
True
>>> handle_emptiness(np.array([]), ops='check_only')
True
>>> handle_emptiness(pd.DataFrame(), ops='check_only')
True
>>> handle_emptiness([1,2,3], ops='check_only')
False
>>> handle_emptiness([], ops='validate', empty_as_none=True)
None
>>> handle_emptiness([], ops='validate', empty_as_none=False)
[]
"""
def _is_empty(o: Any) -> bool:
"""Determine if o is None or contains no elements."""
if o is None:
return True
if isinstance(o, np.ndarray):
return o.size == 0
if isinstance(o, pd.Series | pd.DataFrame):
return o.empty
try:
return len(o) == 0
except Exception:
return False
if ops == "check_only":
return _is_empty(obj)
if ops == "validate":
if _is_empty(obj):
return None if empty_as_none else []
return obj
raise ValueError(
"`ops` must be 'validate' or 'check_only'"
)
def _report_condition(
policy: Literal["raise", "warn", "ignore"],
default_message: str,
custom_message: str | None,
exception_type: type = ValueError,
warning_type: type = UserWarning,
) -> None:
"""Helper to raise error, issue warning, or ignore based on policy."""
message_to_use = (
custom_message
if custom_message is not None
else default_message
)
if policy == "raise":
raise exception_type(message_to_use)
elif policy == "warn":
warnings.warn(
message_to_use, warning_type, stacklevel=2
)
# If 'ignore', do nothing.
def are_all_values_in_bounds(
values: Any,
bounds: tuple[Real, Real] | list[Real] = (0, 1),
closed: Literal[
"both", "left", "right", "neither"
] = "neither",
nan_policy: Literal[
"raise", "propagate", "omit"
] = "propagate",
empty_policy: Literal[
"allow_true", "treat_as_false"
] = "allow_true",
error: Literal["raise", "warn", "ignore"] = "raise",
message: str | None = None,
) -> bool:
"""
Check if all evaluable input values are within specified numeric bounds.
Supports various input types including scalars, lists, tuples,
NumPy arrays, and pandas Series/Index. It attempts to convert
inputs to a numeric format (float) for comparison. Non-numeric
entries that cannot be converted will typically result in False
or trigger the `error` policy if conversion itself fails.
Parameters
----------
values : Any
Input data to check.
bounds : tuple[Real, Real] or list[Real, Real], default=(0, 1)
A sequence of two numeric values (lower_bound, upper_bound).
`bounds[0]` must be less than or equal to `bounds[1]`.
closed : {'both', 'left', 'right', 'neither'}, default='neither'
Defines whether the interval is closed or open:
- ``'both'``: `lower_bound <= value <= upper_bound`
- ``'left'``: `lower_bound <= value < upper_bound`
- ``'right'``: `lower_bound < value <= upper_bound`
- ``'neither'``: `lower_bound < value < upper_bound` (strict)
nan_policy : {'raise', 'propagate', 'omit'}, default='propagate'
How to handle NaN (Not a Number) values:
- ``'raise'``: Behavior depends on `error` policy. If `error`
is 'raise', a ValueError is raised. If 'warn', a warning
is issued and False is returned. If 'ignore', False is
returned (as NaNs are not in bounds).
- ``'propagate'``: If any NaN is found, return False.
- ``'omit'``: NaNs are removed. The behavior for an array
that becomes empty after omission is governed by
`empty_array_policy`.
empty_array_policy : {'allow_true', 'treat_as_false'}, default='allow_true'
Policy for handling an empty array of values to check. An array
can be empty if the original input was empty (e.g., `[]`) or
if all values were NaNs and `nan_policy='omit'` was used.
- ``'allow_true'``: An empty array is considered to have all
its (zero) elements within bounds (vacuously true).
- ``'treat_as_false'``: An empty array results in False.
error : {'raise', 'warn', 'ignore'}, default='raise'
Policy for handling invalid parameters (e.g., `bounds`,
`closed`), non-convertible inputs, NaNs when
`nan_policy='raise'`, or when values are out of bounds:
- ``'raise'``: Raise a ValueError.
- ``'warn'``: Issue a UserWarning. The function will then
return False if the condition (e.g., out of bounds, NaN
present with nan_policy='raise') makes the check fail.
- ``'ignore'``: Suppress the error/warning. The function will
return False if the condition makes the check fail.
message : str, optional
Custom message to use when `error` policy is 'raise' or 'warn'.
If None, a default message is used for the specific condition.
Returns
-------
bool
True if all evaluable values are within bounds (considering
`nan_policy` and `empty_array_policy`), False otherwise.
Raises
------
ValueError
If `error='raise'` and an invalid parameter is provided,
input is non-convertible, NaNs are present with
`nan_policy='raise'`, or values are out of bounds.
UserWarning
If `error='warn'` for the same conditions as ValueError.
Examples
--------
>>> from geoprior.utils.generic_utils import are_all_values_in_bounds
>>> are_all_values_in_bounds(0.5)
True
>>> are_all_values_in_bounds([]) # Empty list, default empty_policy='allow_true'
True
>>> are_all_values_in_bounds([], empty_policy='treat_as_false')
False
>>> are_all_values_in_bounds(pd.Series([np.nan]), nan_policy='omit',
empty_array_policy='treat_as_false')
False
>>> try:
... are_all_values_in_bounds([0, 2], bounds=(0,1),
... error='raise', closed='neither')
... except ValueError as e:
... print(e)
One or more values are out of the specified bounds.
"""
arr: np.ndarray
# --- 1. Input Conversion to NumPy float array ---
if np.isscalar(values):
if not isinstance(values, int | float | np.number):
try:
val_float = float(values)
arr = np.array([val_float])
except (ValueError, TypeError):
_report_condition(
error,
f"Scalar input '{values}' of type {type(values).__name__} "
"is not convertible to a numeric value.",
message,
)
return False
else:
arr = np.array([float(values)])
elif isinstance(values, pd.Series | pd.Index):
temp_arr = values.to_numpy()
try:
if temp_arr.dtype == object or not np.issubdtype(
temp_arr.dtype, np.number
):
arr = pd.to_numeric(
temp_arr, errors="raise"
).astype(float)
else:
arr = temp_arr.astype(float)
except (ValueError, TypeError):
_report_condition(
error,
"Pandas input contains non-convertible non-numeric values.",
message,
)
return False
else:
try:
arr = np.asarray(values, dtype=float)
except (ValueError, TypeError):
_report_condition(
error,
"Input sequence contains non-convertible non-numeric values.",
message,
)
return False
# --- 2. Validate `bounds` and `closed` Parameters ---
# These checks happen early. If they fail and error is 'raise'/'warn',
# the function returns False or raises, so subsequent logic is safe.
if not (
isinstance(bounds, list | tuple) and len(bounds) == 2
):
_report_condition(
error,
"`bounds` must be a list or tuple of two numbers.",
message,
)
return False
try:
lower_b = float(bounds[0])
upper_b = float(bounds[1])
except (IndexError, ValueError, TypeError):
_report_condition(
error,
"Elements of `bounds` must be numeric and `bounds` must "
"have length 2.",
message,
)
return False
if lower_b > upper_b:
_report_condition(
error,
f"Lower bound {lower_b} cannot be greater than upper "
f"bound {upper_b}.",
message,
)
return False
valid_closed_options = {
"both",
"left",
"right",
"neither",
}
if closed not in valid_closed_options:
_report_condition(
error,
f"Invalid 'closed' parameter: {closed}. Must be one of "
f"{valid_closed_options}.",
message,
)
return False
# --- 3. Handle NaNs based on `nan_policy` ---
# This section modifies `arr` if nan_policy is 'omit'.
# For other policies, it might return early.
has_nans = np.isnan(arr).any()
if has_nans:
if nan_policy == "raise":
_report_condition(
error, "Input contains NaNs.", message
)
return False # Returns False if error policy is 'warn' or 'ignore'
elif nan_policy == "propagate":
return False
elif nan_policy == "omit":
arr = arr[~np.isnan(arr)]
# `arr` might be empty now. This is handled in step 4.
# --- 4. Handle Empty Array (original or after NaN omission) ---
if arr.size == 0:
if empty_policy == "allow_true":
return True
elif empty_policy == "treat_as_false":
return False
else: # Should be caught by Literal type hint
_report_condition(
error,
f"Internal error: Invalid 'empty_policy' "
f"value '{empty_policy}'.",
message,
)
return False # Fallback for invalid policy if not raised
# --- 5. Perform Bounds Check on (Potentially Modified) `arr` ---
# At this point, `arr` contains only non-NaN numeric values,
# and is guaranteed to be non-empty.
in_bounds_mask: np.ndarray
if closed == "both":
in_bounds_mask = (arr >= lower_b) & (arr <= upper_b)
elif closed == "left":
in_bounds_mask = (arr >= lower_b) & (arr < upper_b)
elif closed == "right":
in_bounds_mask = (arr > lower_b) & (arr <= upper_b)
elif closed == "neither":
in_bounds_mask = (arr > lower_b) & (arr < upper_b)
else:
# This case should have been caught by parameter validation.
# If error='ignore' for invalid 'closed', this is a fallback.
_report_condition(
error,
f"Internal error: Invalid 'closed' value '{closed}' "
"reached bounds check.",
message,
)
return False
all_values_are_in_bounds = bool(np.all(in_bounds_mask))
if not all_values_are_in_bounds:
_report_condition(
error,
"One or more values are out of the specified bounds.",
message,
)
return False # Returns False if error policy is 'warn' or 'ignore'
else:
return True
[docs]
def rename_dict_keys(
data: dict,
param_to_rename: dict | None = None,
order: str = "forward",
):
"""
Renames keys in the `data` dictionary based on
the provided `param_to_rename` dictionary.
This function will check if the key exists in the `data` dictionary.
If the key is present, it will be renamed according to the mapping
provided in the `param_to_rename` dictionary. If the key is not found
in `data` and a mapping exists in `param_to_rename`, the function will
apply the rename. If no rename is required, the function will return the
original dictionary.
Parameters
----------
data : dict
The dictionary whose keys may be renamed. The function will iterate
over the keys of this dictionary and rename them according to the
mapping provided in `param_to_rename`.
param_to_rename : dict, optional
A dictionary mapping old keys to new keys. Each key in this dictionary
represents an old key that may be found in `data`, and the corresponding
value is the new key. If `None`, no renaming is performed. If a key in
`data` matches an old key in `param_to_rename`, that key will be renamed.
order: str, {'forward', 'reverse'}:
Order for renaming keys in a flat dict::
forward (default):
param_to_rename = {old_key: new_key}
reverse:
param_to_rename = {
canonical_key: alias or (alias1, alias2, ...)
}
The first alias found in `data` is moved under the
canonical key. If the canonical key already exists,
nothing is changed for that mapping.
Returns
-------
dict
The updated dictionary with keys renamed as per the `param_to_rename`
mapping. If no keys need renaming, the original dictionary is returned.
Raises
------
ValueError
If `param_to_rename` is not a dictionary, a `ValueError` will be raised.
Examples
--------
>>> from geoprior.utils.generic_utils import
Example 1: Renaming a key in the dictionary:
>>> data = {"subsidence": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}
Example 2: When the key is already valid (no change needed):
>>> data = {"subs_pred": 100}
>>> param_to_rename = {"subsidence": "subs_pred"}
>>> rename_dict_keys(data, param_to_rename)
{'subs_pred': 100}
Example 3: When `param_to_rename` is `None`, no renaming is performed:
>>> data = {"subsidence": 100}
>>> rename_dict_keys(data)
{'subsidence': 100}
Notes
-----
- If `param_to_rename` is `None`, no renaming occurs, and the `data`
dictionary is returned as is.
- This function raises an error if `param_to_rename` is not a dictionary.
Ensure that the parameter is a valid dictionary of old-to-new key mappings.
"""
if param_to_rename is None:
return (
data # If no param_to_rename, return data as is
)
# Ensure param_to_rename is a dictionary
if not isinstance(param_to_rename, dict):
raise ValueError(
"param_to_rename must be a dictionary."
)
if not isinstance(data, dict):
raise ValueError(
f"data must be a dictionary. Got {type(data).__name__!r}"
)
if order not in ("forward", "reverse"):
raise ValueError(
"order must be 'forward' or 'reverse'."
)
# Create a copy of data to avoid modifying the original
updated_data = data.copy()
if order == "forward":
# Rename keys based on param_to_rename mapping
for old_key, new_key in param_to_rename.items():
if old_key == new_key:
continue
if old_key in updated_data:
# do not clobber an existing canonical value
if (
new_key in updated_data
and new_key != old_key
):
# keep existing new_key; drop old_key
updated_data.pop(old_key)
else:
updated_data[new_key] = updated_data.pop(
old_key
)
return updated_data
# reverse mode: canonical -> aliases
for canonical, aliases in param_to_rename.items():
# normalize aliases to tuple
if isinstance(aliases, list | tuple):
alias_iter = tuple(aliases)
elif isinstance(aliases, str):
alias_iter = (aliases,)
else:
raise ValueError(
"reverse mode requires alias str or sequence."
)
# if canonical already present, prefer it
if canonical in updated_data:
continue
# move first alias found → canonical
for a in alias_iter:
if a in updated_data:
updated_data[canonical] = updated_data.pop(a)
break
# Rename keys based on param_to_rename mapping
# for old_key, new_key in param_to_rename.items():
# if old_key in updated_data:
# updated_data[new_key] = updated_data.pop(old_key)
return updated_data
[docs]
def ensure_directory_exists(path):
"""
Ensure that a directory exists at the given path, creating it if needed.
This function checks whether the provided `path` exists and is a
directory. If the path does not exist, it attempts to create the
directory (including any necessary parent directories). If a file
with the same name already exists, or if creation fails, an exception
is raised.
Parameters
----------
path : str or pathlib.Path
The filesystem path for which to ensure directory existence. Can be
either a string or a `pathlib.Path` object.
Returns
-------
pathlib.Path
A `Path` object pointing to the existing (or newly created)
directory.
Raises
------
TypeError
If `path` is not a string or `pathlib.Path`.
FileExistsError
If a file (not a directory) already exists at `path`.
OSError
If the directory cannot be created for any other reason (e.g.,
insufficient permissions).
Examples
--------
>>> from pathlib import Path
>>> from geoprior.utils.generic_utils import ensure_directory_exists
>>> output_dir = ensure_directory_exists("data/output")
>>> isinstance(output_dir, Path)
True
>>> # The directory "data/output" now exists on disk.
Notes
-----
- Uses `pathlib.Path.mkdir(..., parents=True, exist_ok=True)` under the
hood for cross-platform compatibility.
- If `path` already exists as a directory, this function returns
immediately without modifying it.
See Also
--------
pathlib.Path.mkdir : Method to create a directory.
os.makedirs : Legacy function for creating directories recursively.
"""
# Validate type
if not isinstance(path, str | Path):
raise TypeError(
"`path` must be a str or pathlib.Path, got "
f"{type(path).__name__!r}"
)
# Convert to Path
dir_path = Path(path)
# If it exists and is a directory, return immediately
if dir_path.exists():
if dir_path.is_dir():
return dir_path
else:
# A non-directory file exists at this path
raise FileExistsError(
f"A non-directory file exists at: {dir_path}"
)
# Attempt to create the directory (including parents)
try:
dir_path.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise OSError(
f"Unable to create directory {dir_path}: {exc}"
) from exc
return dir_path
[docs]
def normalize_time_column(
df: pd.DataFrame,
time_col: str,
datetime_col: str = "datetime_temp",
year_col: str = "year_int",
drop_orig: bool = False,
) -> pd.DataFrame:
"""
Normalize a time column into a datetime column and an integer year.
The input column may contain integer years, strings, or existing pandas
Datetime values. The function creates ``datetime_col`` with parsed
timestamps and ``year_col`` with the extracted integer year. When
``drop_orig=True``, the original ``time_col`` is removed and
``datetime_col`` is renamed back to ``time_col``.
Parameters
----------
df : pandas.DataFrame
Input DataFrame containing a time column named ``time_col``.
time_col : str
Name of the column to normalize.
datetime_col : str, default='datetime_temp'
Name of the parsed datetime column.
year_col : str, default='year_int'
Name of the extracted integer year column.
drop_orig : bool, default=False
If ``True``, drop the original ``time_col`` after parsing and rename
``datetime_col`` back to ``time_col``.
Returns
-------
pandas.DataFrame
A copy of ``df`` with the parsed datetime column and integer year
column.
Raises
------
ValueError
If ``time_col`` is missing or parsing fails for any entry.
TypeError
If ``df`` is not a pandas DataFrame.
"""
if not isinstance(df, pd.DataFrame):
raise TypeError(
f"`df` must be a pandas DataFrame, got {type(df).__name__}."
)
if time_col not in df.columns:
raise ValueError(
f"Time column '{time_col}' not found in DataFrame."
)
df_norm = df.copy()
# Create empty column for parsed datetime values
df_norm[datetime_col] = pd.NaT
# 1) Identify rows where value is already a Timestamp or datetime64
mask_ts = df_norm[time_col].apply(
lambda x: pd.api.types.is_datetime64_any_dtype(
type(x)
)
)
if mask_ts.all():
# If entire column is datetime dtype, just copy it
df_norm[datetime_col] = df_norm[time_col]
else:
# 2) Identify pure-integer years
mask_int = df_norm[time_col].apply(
lambda x: isinstance(x, int | np.integer)
)
# Parse integer‐year entries via format="%Y"
if mask_int.any():
try:
df_norm.loc[mask_int, datetime_col] = (
pd.to_datetime(
df_norm.loc[mask_int, time_col],
format="%Y",
)
)
except Exception as exc:
raise ValueError(
f"Could not parse integer years in '{time_col}': {exc}"
)
# 3) Parse all remaining (non-integer) entries generically
mask_remain = ~mask_int
if mask_remain.any():
try:
df_norm.loc[mask_remain, datetime_col] = (
pd.to_datetime(
df_norm.loc[mask_remain, time_col],
errors="raise",
)
)
except Exception as exc:
raise ValueError(
f"Could not parse '{time_col}' into datetime for "
f"non-integer entries: {exc}"
)
# Finally, ensure datetime_col is dtype datetime64[ns]
if not pd.api.types.is_datetime64_any_dtype(
df_norm[datetime_col]
):
try:
df_norm[datetime_col] = pd.to_datetime(
df_norm[datetime_col], errors="raise"
)
except Exception as exc:
raise ValueError(
f"Failed to coerce '{datetime_col}' to datetime: {exc}"
)
# 4) Extract integer year into year_col
try:
df_norm[year_col] = df_norm[
datetime_col
].dt.year.astype(int)
except Exception as exc:
raise ValueError(
f"Failed to extract year from '{datetime_col}': {exc}"
)
# --- Drop original time_col and adjust columns if requested ---
if drop_orig:
# If the original time_col and year_col share the same name,
# we need to preserve the integer year before dropping.
if year_col == time_col:
# Temporarily stash original values
temp_col = f"{time_col}_xtemp"
df_norm[temp_col] = df_norm[time_col]
df_norm.drop(columns=[time_col], inplace=True)
# Rename the stashed column back to time_col
df_norm.rename(
columns={temp_col: time_col}, inplace=True
)
else:
# Simply drop the original time_col
df_norm.drop(columns=[time_col], inplace=True)
return df_norm
[docs]
def select_mode(
mode: str | None = None,
default: str = "pihal_like",
canonical: dict[str, Any] | list[Any] | None = None,
) -> str:
"""
Resolve a user-supplied mode alias to a canonical value.
Parameters
----------
mode : str or None, optional
Case-insensitive mode alias. Accepted values include ``'pihal'``,
``'pihal_like'``, ``'tft'``, ``'tft_like'``, or ``None`` to fall
back to ``default``.
default : {'pihal', 'tft'}, optional
Canonical value returned when ``mode`` is ``None``.
canonical : dict or list or None, optional
Custom alias mapping. A dictionary maps input strings to canonical
values. A list is treated as an identity mapping for its items.
Returns
-------
str
Canonical string corresponding to the resolved mode.
Raises
------
ValueError
If ``mode`` does not match any accepted alias.
"""
default_canonical = {
"pihal": "pihal_like",
"pihal_like": "pihal_like",
"tft": "tft_like",
"tft_like": "tft_like",
"tft-like": "tft_like",
"pihal-like": "pihal_like",
}
# Use provided canonical map or fall back to the default
if canonical is None:
canonical = default_canonical
elif isinstance(canonical, list):
# Convert list to dictionary where each element is both key and value
canonical = {
str(mode).lower(): str(mode).lower()
for mode in canonical
}
if mode is None:
return canonical.get(default, default)
try:
return canonical[str(mode).lower().strip()]
except KeyError: # unknown keyword
valid = ", ".join(sorted(canonical.keys()))
raise ValueError(
f"Invalid mode '{mode}'. Choose one of: {valid} or None."
) from None
def _coerce_dt_kw(
*,
dt_col: Any = _SENTINEL,
time_col: Any = _SENTINEL,
_time_default: str | None = None,
**kwargs,
) -> dict[str, Any]:
r"""
Harmonise the interchangeable ``dt_col`` / ``time_col`` keywords.
Exactly **one** of *dt_col* or *time_col* must be supplied *after
accounting for defaults*. The helper returns a new mapping where the
canonical key ``"dt_col"`` holds the resolved column name and all other
keyword arguments are forwarded unchanged.
Parameters
----------
dt_col : str or None, optional
User-supplied datetime column. If omitted, *time_col* is used
instead.
time_col : str or None, optional
Alias for *dt_col*. Many public APIs historically expose this
keyword—this helper allows both to coexist.
_time_default : str or None, default ``None``
Pass the *function-level* default for *time_col* (for example
``'coord_t'``). If *time_col* equals this default and *dt_col* is
supplied, the default is silently ignored **without** triggering the
“both supplied” error.
**kwargs
Arbitrary additional keyword arguments that will be preserved in the
returned dictionary.
Returns
-------
dict
Shallow copy of *kwargs* plus the resolved key::
{"dt_col": <chosen_name>, ...}
Raises
------
ValueError
* If neither *dt_col* nor *time_col* is supplied.
* If **both** are supplied **and** *time_col* differs from
*_time_default* (i.e. the user explicitly provided two conflicting
names).
Notes
-----
Use this helper at the very top of public API functions, e.g.::
def forecast_view(..., time_col='coord_t', dt_col=None, ...):
kw = _coerce_dt_kw(dt_col=dt_col,
time_col=time_col,
_time_default='coord_t')
time_col = kw.pop("dt_col")
Doing so keeps internal code agnostic to which alias the caller picked.
Examples
--------
>>> from geoprior.utils.generic_utils import _coerce_dt_kw
>>> _coerce_dt_kw(dt_col='date')
{'dt_col': 'date'}
>>> _coerce_dt_kw(time_col='timestamp')
{'dt_col': 'timestamp'}
>>> _coerce_dt_kw(dt_col='d', time_col='coord_t', _time_default='coord_t')
{'dt_col': 'd'} # default 'coord_t' safely ignored
>>> _coerce_dt_kw(dt_col='d', time_col='t', _time_default='coord_t')
Traceback (most recent call last):
...
ValueError: Supply **exactly one** of 'dt_col' or 'time_col'.
"""
# ──────────────────────── detect “omitted” vs “provided None” ──
dt_provided = (dt_col is not _SENTINEL) and (
dt_col is not None
)
time_provided = time_col is not _SENTINEL and (
time_col is not None
)
# ──────────────────────── three legal scenarios ───────────────
if not dt_provided and not time_provided:
raise ValueError(
"You must supply either 'dt_col' or its alias 'time_col'."
)
if dt_provided and time_provided:
# allow the benign case where time_col is *just* the default
if str(time_col) != str(_time_default):
raise ValueError(
"Supply **exactly one** of 'dt_col' or 'time_col', not both "
f"({'dt_col=' + str(dt_col)!r}, {'time_col=' + str(time_col)!r})."
)
# user gave dt_col; time_col was default -> ignore default
chosen = dt_col
else:
chosen = dt_col if dt_provided else time_col
# ──────────────────────── assemble output mapping ──────────────
out = dict(kwargs)
out["dt_col"] = chosen
return out
def save_figure(
figure: plt.Figure,
savefile: str | None = None,
save_fmts: str | list[str] = ".png",
overwrite: bool = True,
verbose: int = 1,
_logger=None,
**kwargs,
):
"""
Save the given matplotlib figure to disk in one or more formats.
This function robustly handles file naming, directory creation,
and multiple format exporting.
Parameters
----------
figure : plt.Figure
The matplotlib figure object to save.
savefile : str, optional
The target file path or name. If an extension is provided
(e.g., '.png', '.pdf'), it is used as the base name. If no
extension is provided or it's not a recognized image
format, the whole string is used as the base name. If None,
a default timestamped name is generated.
save_fmts : str or list of str, optional
The format(s) to save the figure in. Default is '.png'.
overwrite : bool, default=True
If False, raises a FileExistsError if the output file
already exists.
verbose : int, default=1
Verbosity level for logging.
_logger : logging.Logger, optional
A logger instance to use for messages. Defaults to `print`.
**kwargs :
Additional keyword arguments passed to `plt.savefig()`
(e.g., `dpi`, `bbox_inches`).
Returns
-------
str or list[str] or None
The path to the saved file if one format was requested, a list
of paths if multiple formats were requested, or None if saving
failed.
"""
# Define a set of common, valid image extensions for checking.
VALID_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".pdf",
".svg",
".eps",
".tiff",
}
# --- Step 1: Determine the base output path and name ---
if savefile is None:
# Generate a default timestamped filename if none is provided.
timestamp = datetime.datetime.now().strftime(
"%Y%m%d_%H%M%S"
)
base_name = f"figure_{timestamp}"
output_dir = Path.cwd()
else:
p = Path(savefile)
output_dir = p.parent
# Check if the provided savefile has a valid image extension.
if p.suffix.lower() in VALID_EXTENSIONS:
# If yes, use the part before the extension as the base name.
base_name = p.stem # str(p.with_suffix(''))#
else:
# If not, treat the entire string as the base name.
base_name = p.name # str(savefile) # p.name
# --- Step 2: Prepare the list of formats to save ---
if isinstance(save_fmts, str):
save_fmts = [
save_fmts
] # Ensure it's a list for iteration.
# Clean up formats to ensure they start with a dot.
formats_to_save = [
f".{fmt.lstrip('.')}" for fmt in save_fmts
]
# --- Step 3: Save the figure in each requested format ---
saved_paths = []
try:
# Ensure the output directory exists before saving.
ExistenceChecker.ensure_directory(output_dir)
for fmt in set(
formats_to_save
): # Use set to avoid duplicates
# Construct the final, full path for the current format.
final_path = (
output_dir / f"{base_name}{fmt}"
) # .with_suffix(fmt)
# Handle overwrite logic.
if not overwrite and final_path.exists():
raise FileExistsError(
f"File '{final_path}' already exists. "
"Set 'overwrite=True' to overwrite."
)
# Save the figure with any extra keyword arguments.
figure.savefig(final_path, **kwargs)
if verbose > 0:
vlog(
f"Saved figure to {final_path}",
verbose=verbose,
level=2,
logger=_logger,
)
saved_paths.append(str(final_path))
# Return a single path if only one file was saved, else the list.
return (
saved_paths[0]
if len(saved_paths) == 1
else saved_paths
)
except Exception as e:
if verbose > 0:
vlog(
f"Error saving figure: {e}",
verbose=verbose,
level=0,
logger=_logger,
)
return None
def ensure_cols_exist(
df: pd.DataFrame,
*cols: str | Sequence[str],
strict: bool = False,
error: str = "warn",
return_as_flatten: bool = True,
) -> list[str] | list[list[str]]:
"""
Validate that *every* requested column is present in a DataFrame.
Parameters
----------
df : pandas.DataFrame
The DataFrame you intend to slice.
*cols : str | Sequence[str]
A loose collection of column names (strings **or**
iterables of strings). Examples::
ensure_cols_exist(df, "foo", "bar")
ensure_cols_exist(df, ["foo", "bar"], "baz")
strict : bool, default ``False``
* ``True`` – all requested columns **must** exist or the
function triggers `error='raise'` automatically.
* ``False`` – missing columns can be tolerated depending on
*error*.
error : {'warn', 'raise', 'ignore'}, default ``'warn'``
Behaviour when *some* columns are missing:
* ``'warn'`` – emit a :class:`UserWarning`
* ``'raise'`` – raise :class:`ValueError`
* ``'ignore'`` – stay silent
If *strict* is ``True`` the function behaves as
``error='raise'`` regardless of this argument.
return_as_flatten : bool, default ``True``
* ``True`` → return **one** flat list with every *valid* column.
* ``False`` → keep the original grouping:
``[["foo", "bar"], ["baz"]]``
Returns
-------
list
Either a flat list or a list-of-lists containing only the
columns that really exist in *df* (order preserved).
Raises
------
ValueError
If *strict* is ``True`` **or** ``error='raise'`` **and**
at least one column is missing.
Examples
--------
>>> ensure_cols_exist(df, "lon", ["lat", "depth"])
['lon', 'lat', 'depth']
>>> ensure_cols_exist(df, "x", strict=True)
ValueError: missing columns: ['x']
"""
# ---- 1 • normalise cols
def _to_list(obj: str | Iterable[str]) -> list[str]:
return [obj] if isinstance(obj, str) else list(obj)
grouped: list[list[str]] = [_to_list(c) for c in cols]
flat: list[str] = list(chain.from_iterable(grouped))
if not flat: # nothing to check
return [] if return_as_flatten else grouped
# membership test
present = [c for c in flat if c in df.columns]
missing = [c for c in flat if c not in df.columns]
if missing:
msg = f"Missing columns: {missing} (DataFrame has {len(df.columns)})"
if strict or error == "raise":
raise ValueError(msg)
if error == "warn":
warnings.warn(msg, stacklevel=2)
# return
if return_as_flatten:
# keep original order but drop the missing ones
return list(set([c for c in flat if c in present]))
# keep original grouping, dropping any missing names
filtered_grouped: list[list[str]] = [
[c for c in grp if c in present] for grp in grouped
]
return filtered_grouped
def apply_affix(
value: Any,
label: str | None = None,
*,
mode: str = "suffix",
affix_prefix: str = "",
separator: str = "",
include_date: bool = False,
date_format: str = "%Y%m%d",
version: str | None = None,
version_prefix: str = "v",
part_separator: str = "",
) -> str:
"""
Apply either a prefix or suffix (with optional version, date, and custom label)
to the string form of `value`.
Parameters
----------
value : Any
The base value, converted to str().
label : str, optional
Custom label to include.
mode : {'suffix', 'prefix'}, default 'suffix'
Whether to append ('suffix') or prepend ('prefix') the affix block.
affix_prefix : str, default ""
Characters to prepend to the combined affix block (e.g. ".", "_").
separator : str, default ""
Characters to insert between base and affix block.
include_date : bool, default False
If True, include current date in the affix block.
date_format : str, default "%Y%m%d"
strftime format for the date.
version : str, optional
Version string to include (e.g. "1.2.3").
version_prefix : str, default "v"
Characters to prepend to the version.
part_separator : str, default ""
Separator for joining version, date, and label parts.
Returns
-------
str
The resulting string with prefix or suffix as specified.
Raises
------
ValueError
If `mode` is not 'suffix' or 'prefix'.
Examples
--------
>>> from geoprior.utils.generic_utils import apply_affix
>>> apply_affix("file", "txt")
'filetxt'
>>> apply_affix("file", "txt", affix_prefix=".", separator="")
'file.txt'
>>> apply_affix("data", None, include_date=True, separator="_")
'data_20250627'
>>> apply_affix("app", None, version="1.0.0", part_separator=".")
'appv1.0.0'
>>> apply_affix(
... "log",
... "backup",
... mode="prefix",
... affix_prefix="_",
... separator="-",
... include_date=True,
... date_format="%Y-%m-%d",
... version="2.5",
... version_prefix="v",
... part_separator="."
... )
'_v2.5.2025-06-27-backup-log'
"""
base = str(value)
parts = []
if version is not None:
parts.append(f"{version_prefix}{version}")
if include_date:
parts.append(datetime.now().strftime(date_format))
if label:
parts.append(label)
if not parts:
combined = ""
else:
combined = part_separator.join(parts)
if affix_prefix:
combined = f"{affix_prefix}{combined}"
if not combined:
return base
if mode == "suffix":
return (
f"{base}{separator}{combined}"
if separator
else f"{base}{combined}"
)
elif mode == "prefix":
return (
f"{combined}{separator}{base}"
if separator
else f"{combined}{base}"
)
else:
raise ValueError("mode must be 'suffix' or 'prefix'")
def insert_affix_in(
filepath: str | Path,
affix: str,
*,
separator: str = "",
) -> str:
"""
Insert an affix between the base name and extension of a filename.
Parameters
----------
filepath : str or Path
Original file path or name (e.g. "diagnostic.json" or "/tmp/report.txt").
affix : str
Text to insert before the extension (e.g. "test").
separator : str, default ""
Text to place between the base name and the affix (e.g. "_").
Returns
-------
str
New filename with the affix inserted (path preserved).
Examples
--------
>>> insert_affix_in_filename("diagnostic.json", "test")
'diagnostictest.json'
>>> insert_affix_in_filename("diagnostic.json", "v1", separator="_")
'diagnostic_v1.json'
>>> insert_affix_in_filename("/tmp/log.tar.gz", "backup", separator=".")
'/tmp/log.tar.backup.gz'
"""
p = Path(filepath)
base = p.stem # name without last suffix
ext = p.suffix # last suffix, including the dot
if affix is None or str(affix) == "":
separator = ""
affix = ""
new_name = f"{base}{separator}{affix}{ext}"
return str(p.with_name(new_name))
def split_train_test_by_time(
df: pd.DataFrame,
time_col: str,
cutoff: int | float | str | pd.Timestamp,
*,
parse_time_col: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Split a DataFrame into train/test based on a time cutoff,
robust to different time formats.
Parameters
----------
df : pd.DataFrame
Input data containing `time_col`.
time_col : str
Name of the column to use for splitting.
cutoff : int or str or pd.Timestamp
If int: treated as a year (e.g. 2022).
If str: parsed as date (e.g. "2022-06-30") or, if 4-digit, as year.
If pd.Timestamp: used directly.
parse_time_col : bool, default True
If True and `df[time_col]` is non-datetime, attempt to parse it.
Returns
-------
train_df, test_df : Tuple[pd.DataFrame, pd.DataFrame]
Rows where time <= cutoff go to train, >= cutoff to test.
Raises
------
ValueError
If types are incompatible or parsing fails.
Example
--------
>>> import numpy as np ; import pandas as pd
>>> from geoprior.utils.generic_utils import split_train_test_by_time
>>> # 1. datetime column vs integer year cutoff
>>> df = pd.DataFrame({
'year': ['2017-01-01','2018-01-01','2019-01-01'],
'value': [1,2,3]
})
>>> train, test = split_train_test_by_time(df, 'year', 2018)
>>> # 2. datetime column vs date-string cutoff
>>> train, test = split_train_test_by_time(df, 'year', '2018-06-30')
>>> # 3. integer column vs integer cutoff
>>> df2 = pd.DataFrame({'year': [2015,2016,2017,2018], 'v':[4,5,6,7]})
>>> train, test = split_train_test_by_time(df2, 'year', 2017)
>>> # 4. disable parsing (if already datetime)
>>> df['year'] = pd.to_datetime(df['year'])
>>> train, test = split_train_test_by_time(df, 'year', pd.Timestamp('2019-01-01'))
"""
if not isinstance(df, pd.DataFrame):
raise TypeError(
f"Expect dataframe for 'df'. Got {type(df).__name__!r}"
)
if time_col not in df.columns:
raise ValueError(
f"Time column is missing in df: {time_col}"
)
# 2. Rewrap into DataFrame copy
working = df.copy()
series = working[time_col]
# 1. Detect a “numeric-year” column:
is_int = pd.api.types.is_integer_dtype(series)
is_float = pd.api.types.is_float_dtype(series) and np.all(
series.dropna() % 1 == 0
)
numeric_year_col = is_int or is_float
# 2. Branch on column type
if numeric_year_col:
# Convert floats→int if needed
years = series.astype(int)
# Cutoff must also be numeric
if not isinstance(cutoff, int | float):
raise ValueError(
"Cannot compare numeric-year column"
" to non-numeric cutoff"
)
cut_val = int(cutoff)
train_mask = years <= cut_val
test_mask = years >= cut_val
else:
# Expect a datetime-like column
if (
parse_time_col
and not pd.api.types.is_datetime64_any_dtype(
series
)
):
try:
series = pd.to_datetime(
series, errors="coerce"
)
except Exception as e:
raise ValueError(
f"Failed to parse `{time_col}` to datetime: {e}"
)
if not pd.api.types.is_datetime64_any_dtype(series):
raise ValueError(
f"Column `{time_col}` is not datetime or numeric-year."
)
# Now handle cutoff as year or full date
if isinstance(cutoff, int | float):
# Year-based split
years = series.dt.year
cut_val = int(cutoff)
train_mask = years <= cut_val
test_mask = years >= cut_val
else:
# Parse cutoff to Timestamp if needed
if not isinstance(cutoff, pd.Timestamp):
try:
cutoff = pd.to_datetime(cutoff)
except Exception as e:
raise ValueError(
f"Failed to parse cutoff {cutoff!r}: {e}"
)
train_mask = series <= cutoff
test_mask = series >= cutoff
# 3. Return splits
train_df = working.loc[train_mask].copy()
test_df = working.loc[test_mask].copy()
return train_df, test_df
[docs]
def getenv_stripped(
name: str,
default: str | None = None,
allow_empty: bool = False,
) -> str | None:
"""
Read an environment variable and strip whitespace robustly.
Parameters
----------
name : str
Environment variable name to read.
default : str or None, optional
Value returned when the environment variable is not set.
allow_empty : bool, default=False
If False, empty strings are treated as missing and `default`
is returned instead. If True, an empty string is returned
unchanged.
Returns
-------
out : str or None
The stripped string value, the empty string (if allowed), or
`default` when unset / empty.
"""
val = os.getenv(name)
if val is None:
return default
val = val.strip()
if not val and not allow_empty:
return default
return val
[docs]
def default_results_dir(
start: PathLike | None = None,
env_var: str = "RESULTS_DIR",
folder_name: str = "results",
create: bool = False,
) -> str:
"""
Resolve the canonical 'results' directory with robust fallbacks.
Resolution order
----------------
1. If the environment variable `env_var` is set, use it.
2. Else search upward from `start` (or `Path.cwd()`) for a folder
named `folder_name`.
3. Else return `Path.cwd() / folder_name`.
Parameters
----------
start : path-like, optional
Starting path for the upward search. If None, uses the current
working directory.
env_var : str, default="RESULTS_DIR"
Environment variable that, when set, overrides any discovery.
folder_name : str, default="results"
Directory name to look for when walking upward.
create : bool, default=False
If True, create the directory when it does not exist.
Returns
-------
path_str : str
Absolute path to the resolved results directory.
Notes
-----
This helper is importable in Stage-1/2/3 scripts to keep path
resolution consistent across training/tuning/inference. It avoids
relying on `__file__` (which may be inside the package install
path) and instead uses `Path.cwd()` by default, which is what users
expect when launching scripts from various locations.
Examples
--------
>>> default_results_dir()
'.../your/project/results'
>>> default_results_dir(start='/work/exp/run42', create=True)
'.../your/project/results'
"""
# 1) Environment variable takes precedence
env = os.getenv(env_var)
if env:
p = Path(env).expanduser().resolve()
if create:
p.mkdir(parents=True, exist_ok=True)
return str(p)
# 2) Search upward for `folder_name`
base = (
Path(start).resolve()
if start
else Path.cwd().resolve()
)
for parent in (base, *base.parents):
cand = parent / folder_name
if cand.exists():
return str(cand.resolve())
# 3) Fallback to cwd/folder_name
fallback = base / folder_name
if create:
fallback.mkdir(parents=True, exist_ok=True)
return str(fallback.resolve())
[docs]
def print_config_table(
sections: dict[str, Any]
| Sequence[tuple[str, dict[str, Any]]],
title: str | None = None,
table_width: int | None = None,
sort_keys: bool = True,
key_col_fraction: float = 0.35,
max_value_length: int = 200,
log_fn=None,
) -> str:
"""
Pretty-print configuration or hyperparameters as a key/value table.
This helper is intended for CLI scripts (Stage-1, training, tuning)
so that the user can quickly inspect which parameters are actually
in effect.
Parameters
----------
sections : dict or sequence of (str, dict)
If a single dict is passed, all key/value pairs are printed in
one block.
If a sequence is passed, it must contain ``(name, params)``
tuples, where ``name`` is a section label (e.g. ``"Physics"``)
and ``params`` is a dict mapping parameter names to values.
title : str, optional
Optional title displayed above the table (centered).
table_width : int, optional
Total width of the printed table. If ``None``, the function
tries to use :func:`geoprior.api.util.get_table_size`. If that
fails, it falls back to the terminal width (via
:mod:`shutil.get_terminal_size`) or 80 characters.
sort_keys : bool, default=True
Whether to sort parameter names alphabetically within each
section.
key_col_fraction : float, default=0.35
Fraction of the table width allocated to the parameter-name
column. The remainder is used for the value column.
max_value_length : int, default=200
Maximum number of characters kept from the stringified value.
Longer values are truncated with an ellipsis (``"..."``) before
being wrapped onto multiple lines.
log_fn : callable, optional
Function used to emit lines (defaults to :func:`print`). This
allows capturing the table in logs if needed.
Returns
-------
str
The full rendered table as a single string. It is always
printed via ``print_fn`` as a side effect.
Notes
-----
* Nested containers (lists, tuples, dicts) are rendered in a compact
one-line form and then wrapped to fill the value column.
* This function is intentionally lightweight and does not depend on
external tabulation libraries, so it can be safely used in
lightweight Stage-1 / Stage-2 scripts.
"""
# 1) Resolve table width
if table_width is None:
# Fallback to terminal width or 80 columns
try:
table_width = shutil.get_terminal_size(
(80, 20)
).columns
except Exception: # pragma: no cover
table_width = 80
# Keep width within reasonable bounds
table_width = int(max(40, min(table_width, 140)))
# 2) Normalize `sections` into a sequence of (name, dict)
if isinstance(sections, dict):
sections_list: Sequence[
tuple[str, dict[str, Any]]
] = [("", sections)]
else:
norm_sections: list[tuple[str, dict[str, Any]]] = []
for idx, sec in enumerate(sections):
if isinstance(sec, dict):
norm_sections.append(
(f"Section {idx + 1}", sec)
)
elif (
isinstance(sec, tuple | list)
and len(sec) == 2
and isinstance(sec[1], dict)
):
name, params = sec
norm_sections.append((str(name), params))
else:
raise TypeError(
"sections must be a dict or a sequence of "
"(name, dict) tuples."
)
sections_list = norm_sections
# 3) Determine column widths
all_keys: list[str] = []
for _, params in sections_list:
all_keys.extend(str(k) for k in params.keys())
if all_keys:
key_width = max(len(k) for k in all_keys) + 2
key_width = int(
min(key_width, table_width * key_col_fraction)
)
else:
key_width = int(table_width * key_col_fraction)
key_width = max(8, key_width)
value_width = max(
10, table_width - key_width - 3
) # "key : value"
# 4) Helpers for value formatting and wrapping
def _format_value(v: Any) -> str:
"""Compact, human-readable representation for table."""
if isinstance(v, float):
s = f"{v:.6g}"
elif isinstance(v, list | tuple | set):
inner = ", ".join(repr(x) for x in v)
open_br, close_br = (
("[", "]")
if isinstance(v, list)
else ("(", ")")
if isinstance(v, tuple)
else ("{", "}")
)
s = f"{open_br}{inner}{close_br}"
elif isinstance(v, dict):
inner = ", ".join(
f"{k}={val!r}" for k, val in v.items()
)
s = "{" + inner + "}"
else:
s = repr(v)
if len(s) > max_value_length:
s = s[: max_value_length - 3] + "..."
return s
lines: list[str] = []
border = "-" * table_width
# 5) Build lines
lines.append(border)
if title:
t = str(title)[:table_width]
lines.append(t.center(table_width))
lines.append(border)
for sec_name, params in sections_list:
if not params:
continue
if sec_name:
lines.append(f"[{sec_name}]")
items = params.items()
if sort_keys:
items = sorted(items, key=lambda kv: str(kv[0]))
for key, value in items:
key_str = str(key)
val_str = _format_value(value)
wrapped = textwrap.wrap(
val_str, width=value_width
) or [""]
# First line with key
first = wrapped[0]
lines.append(f"{key_str:<{key_width}} : {first}")
# Continuation lines for long values
for cont in wrapped[1:]:
lines.append(" " * (key_width + 3) + cont)
lines.append(border)
table_text = "\n".join(lines)
if log_fn is None:
log_fn = print # pragma: no cover
log_fn(table_text)
return table_text
def as_tuple(
obj: Any,
names: Sequence[str] | None = None,
*,
ctx: str = "value",
strict: bool = True,
allow_unwrap_singleton: bool = True,
max_unwrap_depth: int = 4,
missing_value: Any = None,
) -> tuple[Any, ...]:
"""
Convert model I/O structures (dict/list/tuple/tensor-like) into a tuple.
This is designed for Keras multi-output training where `y_true` and
`y_pred` may be dictionaries keyed by output names, but Keras metrics
are more stable when fed a tuple/list ordered like `model.output_names`.
Parameters
----------
obj : Any
The object to convert. Common cases:
- Mapping (e.g., {"subs_pred": t1, "gwl_pred": t2})
- list/tuple (e.g., [t1, t2])
- tensor-like scalar/single output (e.g., t1)
- nested single-key mapping wrapper
(e.g., {"data_final": {"subs_pred": t1, "gwl_pred": t2}})
names : sequence of str, optional
Expected output names (ordering source). Typically `self.output_names`.
If provided and `obj` is a mapping, values are extracted in this order.
ctx : str, default="value"
Context label used in error messages ("targets" / "y_pred", etc.).
strict : bool, default=True
If True, enforce:
- all `names` keys exist when `obj` is a mapping
- sequence length matches `len(names)` when `obj` is list/tuple
- multi-output requires structured input (not a single tensor-like)
If False, best-effort conversion is attempted and missing keys are
filled with `missing_value` (only for mapping + names).
allow_unwrap_singleton : bool, default=True
If True, unwrap nested single-key mapping layers when the key is not
in `names` (e.g., {"data_final": ...}) up to `max_unwrap_depth`.
max_unwrap_depth : int, default=4
Maximum unwrap depth when `allow_unwrap_singleton=True`.
missing_value : Any, default=None
Value used when `strict=False` and a key in `names` is missing.
Returns
-------
tuple
Tuple of outputs ordered to match `names` (if provided), otherwise
preserves list/tuple order or mapping insertion order.
Raises
------
KeyError
If `obj` is a mapping, `names` is provided, and a required key is
missing (strict=True).
ValueError
If `obj` is a list/tuple and its length doesn't match `len(names)`
(strict=True).
TypeError
If `obj` is tensor-like (non-iterable, non-mapping) but multiple
outputs are expected (len(names) > 1) and strict=True.
"""
# Normalize names early (avoid weird generators / None).
if names is not None:
names = tuple(str(n) for n in names)
# ------------------------------------------------------------
# 0) Optional unwrap: {"data_final": {...}} -> {...}
# ------------------------------------------------------------
if allow_unwrap_singleton:
depth = 0
cur = obj
while (
depth < int(max_unwrap_depth)
and isinstance(cur, Mapping)
and len(cur) == 1
):
(k,) = cur.keys()
# Unwrap if this singleton key is clearly a wrapper (not an output).
if names is not None and str(k) in names:
break
cur = next(iter(cur.values()))
depth += 1
obj = cur
# ------------------------------------------------------------
# 1) Mapping case: extract by names (preferred for multi-output)
# ------------------------------------------------------------
if isinstance(obj, Mapping):
if names is None:
# Fall back to insertion order (Py3.7+ preserves insertion order).
return tuple(obj[k] for k in obj.keys())
missing = [n for n in names if n not in obj]
if missing and strict:
available = ", ".join(map(str, obj.keys()))
need = ", ".join(map(str, missing))
raise KeyError(
f"{ctx} is a mapping but missing required key(s): {need}. "
f"Available keys: {available}."
)
# If not strict, fill missing keys with missing_value.
return tuple(obj.get(n, missing_value) for n in names)
# ------------------------------------------------------------
# 2) Sequence case: validate length against names if provided
# ------------------------------------------------------------
if isinstance(obj, tuple | list):
tup = tuple(obj)
if (
names is not None
and strict
and len(tup) != len(names)
):
raise ValueError(
f"{ctx} has length {len(tup)} but expected {len(names)} "
f"to match output_names={list(names)}."
)
return tup
# ------------------------------------------------------------
# 3) Single (tensor-like / scalar) case
# ------------------------------------------------------------
if names is not None and len(names) > 1:
if strict:
raise TypeError(
f"{ctx} is not a mapping/sequence but multiple outputs are "
f"expected (len(output_names)={len(names)}). "
f"Provide a dict keyed by {list(names)} or a tuple/list of "
f"that length."
)
# Best-effort: replicate into a 1-tuple (caller may handle).
return (obj,)
return (obj,)