# Source code for geoprior.scripts.rebuild_confusion_tables

# SPDX-License-Identifier: Apache-2.0
# GeoPrior-v3 - https://github.com/earthai-tech/geoprior-v3
# Copyright (c) 2026-present
# Author: LKouadio <https://lkouadio.com>
r"""Script helpers for rebuilding confusion tables."""

from __future__ import annotations

import argparse
import math
from dataclasses import dataclass
from pathlib import Path

import pandas as pd


@dataclass(frozen=True)
class RowSpec:
    """One model's reported classification metrics, as read from the metrics table.

    Instances are immutable (``frozen=True``). The numeric fields are used by
    the reconstruction helpers in this module as percentages (they are divided
    by ``100.0`` there).
    """

    # Model label; taken from the "Model" column of the metrics CSV.
    name: str
    # Reported accuracy.
    acc: float
    # Reported F1 score.
    f1: float
    # Reported precision.
    prec: float
    # Reported recall.
    rec: float
def _safe_float(x: object) -> float | None:
    """Convert *x* to ``float``, returning ``None`` when that is not possible.

    ``None`` and float NaN inputs also map to ``None``.
    """
    if x is None:
        return None
    if isinstance(x, float) and math.isnan(x):
        return None
    try:
        return float(x)
    except Exception:
        # Best-effort helper: any conversion failure maps to None.
        return None


def _round2(x: float) -> float:
    """Round *x* to two decimals via fixed-point string formatting."""
    return float(f"{x:.2f}")


def _acc(tp: int, tn: int, fp: int, fn: int) -> float:
    """Accuracy in percent; ``0.0`` when there are no samples at all."""
    den = tp + tn + fp + fn
    return 100.0 * (tp + tn) / den if den else 0.0


def _prec(tp: int, fp: int) -> float | None:
    """Precision in percent, or ``None`` when nothing was predicted positive."""
    den = tp + fp
    return 100.0 * tp / den if den else None


def _rec(tp: int, fn: int) -> float | None:
    """Recall in percent, or ``None`` when there are no actual positives."""
    den = tp + fn
    return 100.0 * tp / den if den else None


def _f1(tp: int, fp: int, fn: int) -> float | None:
    """F1 score in percent, or ``None`` when ``2*TP + FP + FN`` is zero."""
    den = 2 * tp + fp + fn
    return 100.0 * (2 * tp) / den if den else None


def _spec(tn: int, fp: int) -> float | None:
    """Specificity in percent, or ``None`` when there are no actual negatives."""
    den = tn + fp
    return 100.0 * tn / den if den else None


def _npv(tn: int, fn: int) -> float | None:
    """Negative predictive value in percent, or ``None`` when nothing was predicted negative."""
    den = tn + fn
    return 100.0 * tn / den if den else None


def _mcc(tp: int, tn: int, fp: int, fn: int) -> float:
    """Matthews correlation coefficient; ``0.0`` when any marginal is empty."""
    a = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if a <= 0:
        return 0.0
    num = tp * tn - fp * fn
    return num / math.sqrt(a)


def _close(a: float | None, b: float, tol: float) -> bool:
    """Return ``True`` when *a* is defined and within *tol* of *b*."""
    if a is None:
        return False
    return abs(a - b) <= tol


def _close_pct(value: float | None, target: float, tol: float) -> bool:
    """Compare a computed metric (rounded to 2 decimals) with a reported one.

    An undefined metric (``None``) never matches. A computed ``0.0`` is a
    real value and is compared normally; it must not be mistaken for
    "missing" through truthiness.
    """
    if value is None:
        return False
    return _close(_round2(value), target, tol)


def _pct_or_nan(value: float | None) -> float:
    """Round a metric to 2 decimals, mapping only ``None`` (undefined) to NaN."""
    return float("nan") if value is None else _round2(value)


def _reconstruct_counts(
    *,
    p: int,
    n: int,
    acc: float,
    prec: float,
    rec: float,
    f1: float,
    tol: float,
) -> tuple[int, int, int, int] | None:
    """
    Reconstruct integer TP/FP/FN/TN given (P,N) and
    reported metrics (Acc, Prec, Rec, F1).

    Strategy:
    - Enumerate TP candidates consistent with Recall.
    - For each TP, enumerate FP candidates.
    - Derive FN, TN from P,N.
    - Validate all metrics within tolerance.

    Parameters
    ----------
    p, n
        Number of actual positives / negatives in the evaluated set.
    acc, prec, rec, f1
        Reported metrics, in percent.
    tol
        Absolute tolerance (percentage points) applied to every metric
        after rounding the recomputed value to two decimals.

    Returns
    -------
    ``(tp, fp, fn, tn)`` for the first candidate (ascending TP, then
    ascending FP) that reproduces all four metrics, or ``None`` when no
    candidate does.
    """
    # Candidate TP around Recall * P
    tp0 = int(round((rec / 100.0) * p))

    for tp in range(tp0 - 2, tp0 + 3):
        if tp < 0 or tp > p:
            continue

        fn = p - tp

        if prec > 0.0:
            # From precision: prec = TP/(TP+FP)
            # => FP = TP*(100/prec - 1)
            fp0 = int(round(tp * (100.0 / prec - 1.0)))
        else:
            # A reported precision of 0 cannot locate FP (and would divide
            # by zero above). Fall back to accuracy:
            #   acc = (TP+TN)/(P+N)  =>  TN = acc*(P+N) - TP,  FP = N - TN
            tn0 = int(round((acc / 100.0) * (p + n))) - tp
            fp0 = n - tn0

        for fp in range(fp0 - 3, fp0 + 4):
            if fp < 0 or fp > n:
                continue

            tn = n - fp

            if not _close(_round2(_acc(tp, tn, fp, fn)), acc, tol):
                continue
            if not _close_pct(_prec(tp, fp), prec, tol):
                continue
            if not _close_pct(_rec(tp, fn), rec, tol):
                continue
            if not _close_pct(_f1(tp, fp, fn), f1, tol):
                continue

            return tp, fp, fn, tn

    return None


def _infer_pn(
    rows: list[RowSpec],
    *,
    n_test: int,
    tol: float,
) -> tuple[int, int]:
    """
    Find (P,N) such that every row can be reconstructed.

    Candidates ``P = 0 .. n_test`` (with ``N = n_test - P``) are tried in
    ascending order and the first split for which all rows reconstruct is
    returned. With an empty *rows* list the first candidate ``(0, n_test)``
    trivially qualifies.

    Raises
    ------
    RuntimeError
        If no split reproduces every row; the message reports the split
        that matched the most rows.
    """
    best: tuple[int, int, int] = (-1, -1, -1)

    for p in range(0, n_test + 1):
        n = n_test - p
        ok = 0

        for r in rows:
            out = _reconstruct_counts(
                p=p,
                n=n,
                acc=r.acc,
                prec=r.prec,
                rec=r.rec,
                f1=r.f1,
                tol=tol,
            )
            if out is not None:
                ok += 1

        if ok > best[2]:
            best = (p, n, ok)

        if ok == len(rows):
            return p, n

    p, n, ok = best
    raise RuntimeError(
        "Could not find a single (P,N) that "
        "matches all rows. Best match: "
        f"P={p}, N={n}, rows_ok={ok}/{len(rows)}. "
        "Try increasing --tol."
    )


def _load_metric_rows(csv_path: Path) -> list[RowSpec]:
    """Read the metrics CSV and return one ``RowSpec`` per data row.

    The CSV must contain the columns ``Model``, ``Acc``, ``F1``, ``Prec``
    and ``Rec``.

    Raises
    ------
    ValueError
        If any required column is missing (the first missing one is named).
    """
    df = pd.read_csv(csv_path)

    need = ["Model", "Acc", "F1", "Prec", "Rec"]
    for c in need:
        if c not in df.columns:
            raise ValueError(f"Missing column: {c}")

    return [
        RowSpec(
            name=str(r["Model"]),
            acc=float(r["Acc"]),
            f1=float(r["F1"]),
            prec=float(r["Prec"]),
            rec=float(r["Rec"]),
        )
        for _, r in df.iterrows()
    ]


def _counts_to_df(
    rows: list[RowSpec],
    *,
    p: int,
    n: int,
    tol: float,
) -> pd.DataFrame:
    """Reconstruct counts for every row and tabulate them with derived metrics.

    Each output record holds the model name, the TP/FP/FN/TN counts, the
    metrics recomputed from those counts (rounded to two decimals, MCC to
    four) and the predicted-positive / predicted-negative totals. Metrics
    that are undefined for the counts are written as NaN; a genuine ``0.0``
    is kept as ``0.0``.

    Raises
    ------
    RuntimeError
        If a row cannot be reconstructed for the given ``(p, n)`` and *tol*.
    """
    out_rows: list[dict[str, object]] = []

    for r in rows:
        counts = _reconstruct_counts(
            p=p,
            n=n,
            acc=r.acc,
            prec=r.prec,
            rec=r.rec,
            f1=r.f1,
            tol=tol,
        )
        if counts is None:
            raise RuntimeError(
                f"Failed reconstruction for {r.name}. "
                "Increase --tol or check inputs."
            )

        tp, fp, fn, tn = counts
        out_rows.append(
            {
                "Model": r.name,
                "TP": tp,
                "FP": fp,
                "FN": fn,
                "TN": tn,
                "Acc%": _round2(_acc(tp, tn, fp, fn)),
                "Prec%": _pct_or_nan(_prec(tp, fp)),
                "Rec%": _pct_or_nan(_rec(tp, fn)),
                "F1%": _pct_or_nan(_f1(tp, fp, fn)),
                "Specificity%": _pct_or_nan(_spec(tn, fp)),
                "NPV%": _pct_or_nan(_npv(tn, fn)),
                "MCC": float(f"{_mcc(tp, tn, fp, fn):.4f}"),
                "PredPos": tp + fp,
                "PredNeg": tn + fn,
            }
        )

    return pd.DataFrame(out_rows)
def main(*, prog: str | None = None) -> None:
    """Command-line entry point: rebuild confusion tables from reported metrics.

    Reads the metrics CSV, determines the test-set size (from ``--n-test`` or
    from the row count of ``--data-csv`` times ``--test-ratio``), infers the
    positive/negative split, and writes two CSV files into ``--outdir``:
    ``<tag>_confusion_full.csv`` and ``<tag>_metrics_paper.csv``.

    Parameters
    ----------
    prog
        Program name shown in usage/help output. ``None`` keeps argparse's
        default.

    Raises
    ------
    ValueError
        If neither ``--n-test`` nor ``--data-csv`` is provided.
    """
    ap = argparse.ArgumentParser(
        prog=prog,
        description="Rebuild confusion tables from reported metrics.",
    )
    ap.add_argument(
        "--metrics-csv",
        required=True,
        help="CSV with columns Model, Acc, F1, Prec, Rec.",
    )
    ap.add_argument(
        "--data-csv",
        default=None,
        help="Dataset CSV; its row count is used when --n-test is not given.",
    )
    ap.add_argument(
        "--target",
        default="cath",
        help="Target column name (accepted but currently not used).",
    )
    ap.add_argument(
        "--test-ratio",
        type=float,
        default=0.30,
        help="Test fraction applied to the --data-csv row count.",
    )
    ap.add_argument(
        "--n-test",
        type=int,
        default=None,
        help="Test-set size; takes precedence over --data-csv.",
    )
    ap.add_argument(
        "--tol",
        type=float,
        default=0.02,
        help="Absolute tolerance used when matching reported metrics.",
    )
    ap.add_argument(
        "--outdir",
        default="scripts/out",
        help="Directory the CSV files are written to.",
    )
    ap.add_argument(
        "--tag",
        default="table",
        help="Prefix of the output file names.",
    )
    args = ap.parse_args()

    rows = _load_metric_rows(Path(args.metrics_csv))
    tol = float(args.tol)

    # Determine n_test
    if args.n_test is not None:
        n_test = int(args.n_test)
    elif args.data_csv is not None:
        df_data = pd.read_csv(args.data_csv)
        n_total = int(df_data.shape[0])
        n_test = int(round(n_total * float(args.test_ratio)))
    else:
        raise ValueError("Provide --n-test or --data-csv.")

    p, n = _infer_pn(rows, n_test=n_test, tol=tol)

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    df = _counts_to_df(rows, p=p, n=n, tol=tol)

    # Full export
    full_path = outdir / f"{args.tag}_confusion_full.csv"
    df.to_csv(full_path, index=False)

    # Paper-ready export (rename Rec% as Recall%)
    paper = df[
        [
            "Model",
            "Acc%",
            "Prec%",
            "Rec%",
            "Specificity%",
            "NPV%",
            "F1%",
            "MCC",
        ]
    ].copy()
    paper = paper.rename(columns={"Rec%": "Recall%"})

    paper_path = outdir / f"{args.tag}_metrics_paper.csv"
    paper.to_csv(paper_path, index=False)

    # Print summary
    print(f"n_test={n_test}, P={p}, N={n}")
    print(f"Wrote: {full_path}")
    print(f"Wrote: {paper_path}")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__": main()