# Source code for jaxqsofit.results

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Mapping

import numpy as np


def median_mapping(values: Mapping[str, Any] | None) -> dict[str, Any]:
    """Collapse every entry of a sample-like mapping to its median.

    Each value is converted with ``np.asarray`` and then reduced as follows:

    * zero-dimensional values are unwrapped to a Python scalar via ``.item()``;
    * empty arrays are returned as the (empty) array, since there is nothing
      to take a median of;
    * everything else is reduced with ``np.nanmedian`` along axis 0, so NaN
      entries are ignored.

    ``None`` or an empty mapping yields an empty dict.
    """
    if not values:
        return {}

    medians: dict[str, Any] = {}
    for name, raw in values.items():
        samples = np.asarray(raw)
        if samples.ndim == 0:
            # Scalars carry no sample axis; hand back the plain Python value.
            medians[name] = samples.item()
            continue
        if samples.size == 0:
            # No samples to reduce; pass the empty array through.
            medians[name] = samples
            continue
        medians[name] = np.nanmedian(samples, axis=0)
    return medians


@dataclass
class _PosteriorState:
    """Internal mutable posterior state produced by a jaxqsofit run.

    Every field has a default (``None`` or ``False``), so an empty state can
    be created with no arguments and filled in afterwards.

    ``FitResult`` passes this object to the fitter as the ``_state`` keyword
    when predicting and saving, and writes the saved location back to
    ``path``.

    NOTE(review): comments below marked "presumably" are inferred from the
    field names only; confirm them against the code that populates the state.
    """

    method: str | None = None  # presumably the name of the inference method
    samples: Mapping[str, Any] | None = None  # presumably posterior samples by name
    predictive: Mapping[str, Any] | None = None  # presumably posterior predictive draws
    bands: Mapping[str, Any] | None = None  # presumably credible bands -- TODO confirm
    path: Path | None = None  # overwritten by FitResult.save with the saved path
    figure: Any = None  # presumably a cached plot handle
    trace_figure: Any = None  # presumably a cached trace-plot handle
    corner_figure: Any = None  # presumably a cached corner-plot handle
    hydrated: bool = False  # NOTE(review): semantics not visible here
    resumed_from_samples: bool = False  # NOTE(review): semantics not visible here


@dataclass
class PredictionResult:
    """Dict-like posterior prediction or reconstruction result.

    Wraps a mapping of predictive quantities and exposes read-only,
    dictionary-style access to it (``result[key]``, ``key in result``,
    iteration over keys, ``keys()``, ``values()``, ``items()``, ``get()``),
    plus lazily computed medians via :attr:`median`.
    """

    # Underlying mapping; every lookup below delegates to it.
    data: Mapping[str, Any]
    # Fitter associated with ``data``; stored for callers, unused by this class.
    fitter: Any
    # Cache for ``median``. Kept out of ``__init__``/``repr`` and out of
    # equality, so two results over equal data compare the same whether or not
    # their medians have been computed yet.
    _median: dict[str, Any] | None = field(
        default=None, init=False, repr=False, compare=False
    )

    @property
    def median(self) -> dict[str, Any]:
        """Median predictive values over the leading posterior axis.

        Computed with ``median_mapping`` on first access and cached; later
        changes to ``data`` are not reflected in the cached value.
        """
        if self._median is None:
            self._median = median_mapping(self.data)
        return self._median

    def __getitem__(self, key: str) -> Any:
        """Return ``data[key]``; raises whatever the mapping raises if absent."""
        return self.data[key]

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` is present in ``data``."""
        # Without this, ``key in result`` falls back to probing ``__getitem__``
        # with 0, 1, 2, ... and fails with ``KeyError: 0`` on a str-keyed dict.
        return key in self.data

    def __iter__(self):
        """Iterate over the keys of ``data``, like a mapping."""
        # Same legacy integer-index fallback as ``in`` applies to ``for`` loops.
        # ``__len__`` is intentionally not defined so that the truthiness of an
        # empty result stays as it was (always true).
        return iter(self.data)

    def keys(self):
        """Return the keys view of ``data``."""
        return self.data.keys()

    def values(self):
        """Return the values view of ``data``."""
        return self.data.values()

    def items(self):
        """Return the ``(key, value)`` view of ``data``."""
        return self.data.items()

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``data[key]`` if present, otherwise ``default``."""
        return self.data.get(key, default)
@dataclass
class FitResult:
    """High-level result object returned by ``jaxqsofit`` fits.

    A thin facade over the fitter: prediction, saving and plotting all
    delegate to ``fitter``. When a posterior state is attached (``_state``),
    it is forwarded to the fitter's ``reconstruct_posterior_spectrum`` and
    ``save`` as the ``_state`` keyword, and its ``path`` is kept in sync with
    :attr:`path` after saving.
    """

    # NOTE(review): the content of ``samples``/``median``/``summary``/``figure``
    # is defined by the fitter that builds this object; not visible here.
    fitter: Any
    samples: Mapping[str, Any] | None
    median: Mapping[str, Any]
    method: str
    summary: Mapping[str, Any] | None = None
    path: Path | None = None  # updated by ``save``
    figure: Any = None
    # Internal posterior state; hidden from ``repr`` and ignored by equality.
    _state: _PosteriorState | None = field(default=None, repr=False, compare=False)

    def predict(self, **kwargs) -> PredictionResult:
        """Reconstruct posterior spectral components for this fit.

        All keyword arguments are forwarded to
        ``fitter.reconstruct_posterior_spectrum``. If a posterior state is
        attached it is supplied as ``_state`` unless the caller already passed
        one. The returned mapping is wrapped in a ``PredictionResult``.
        """
        if self._state is not None:
            kwargs.setdefault("_state", self._state)
        return PredictionResult(
            self.fitter.reconstruct_posterior_spectrum(**kwargs),
            fitter=self.fitter,
        )

    def save(self, path: str | Path | None = None, **kwargs) -> Path:
        """Save the result with the fitter's native persistence format.

        ``path`` and the keyword arguments are passed to ``fitter.save``; an
        attached posterior state is supplied as ``_state`` unless the caller
        already passed one. The fitter's return value is converted to a
        ``Path``, stored on ``self.path`` (and on the attached state, if any)
        and returned.
        """
        if self._state is not None:
            kwargs.setdefault("_state", self._state)
        self.path = Path(self.fitter.save(path, **kwargs))
        if self._state is not None:
            # Keep the state's record of the saved location in step with ours.
            self._state.path = self.path
        return self.path

    def plot_corner(self, **kwargs):
        """Plot posterior samples with the fitter's corner-plot helper.

        Keyword arguments are forwarded unchanged to ``fitter.plot_corner`` and
        its return value is returned as-is.
        """
        return self.fitter.plot_corner(**kwargs)

    def plot_trace(self, **kwargs):
        """Plot posterior samples with the fitter's trace-plot helper.

        Keyword arguments are forwarded unchanged to ``fitter.plot_trace`` and
        its return value is returned as-is.
        """
        return self.fitter.plot_trace(**kwargs)