# Source code for jaxqsofit.results
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Mapping
import numpy as np
def median_mapping(values: Mapping[str, Any] | None) -> dict[str, Any]:
    """Collapse each entry of a sample-like mapping to its median.

    Every value is converted with ``np.asarray`` and then reduced:

    * zero-dimensional input is returned as a plain Python scalar,
    * an array with no elements is returned unchanged (as an array),
    * anything else is reduced with ``np.nanmedian`` along axis 0, so NaN
      entries are ignored.

    Parameters
    ----------
    values
        Mapping from name to array-like samples.  ``None`` or an empty
        mapping yields an empty result.

    Returns
    -------
    dict
        A new dictionary with the same keys, in the same order.
    """
    if not values:
        return {}

    def _collapse(raw: Any) -> Any:
        samples = np.asarray(raw)
        if samples.ndim == 0:
            # Scalars carry no sample axis; unwrap to a Python object.
            return samples.item()
        if samples.size == 0:
            # Nothing to reduce; hand the empty array back as-is.
            return samples
        return np.nanmedian(samples, axis=0)

    return {name: _collapse(raw) for name, raw in values.items()}
@dataclass
class _PosteriorState:
    """Private, mutable container for the posterior state of a jaxqsofit run.

    Every field is optional and defaults to "nothing recorded yet"
    (``None`` for payloads and handles, ``False`` for flags), so an empty
    state can be created with no arguments and filled in afterwards.
    """

    # Label of the inference method; None until one is recorded.
    method: str | None = None

    # Posterior payloads, each a name -> value mapping when present.
    samples: Mapping[str, Any] | None = None
    predictive: Mapping[str, Any] | None = None
    bands: Mapping[str, Any] | None = None

    # Filesystem location associated with this state, if any.
    path: Path | None = None

    # Optional plot handles (types are left open on purpose).
    figure: Any = None
    trace_figure: Any = None
    corner_figure: Any = None

    # Book-keeping flags.
    # NOTE(review): exact semantics are defined by the fitter -- confirm there.
    hydrated: bool = False
    resumed_from_samples: bool = False
# [docs]
@dataclass
class PredictionResult:
    """Dict-like posterior prediction or reconstruction result.

    The object wraps a mapping of predictive quantities and exposes the
    read-only part of the mapping interface (``[]``, ``in``, iteration,
    ``keys``/``values``/``items``/``get``) by delegating to ``data``.

    Attributes
    ----------
    data
        Mapping from quantity name to its posterior draws.
    fitter
        The fitter object that produced ``data``.
    """

    data: Mapping[str, Any]
    fitter: Any
    # Lazily filled cache for ``median``.  It is excluded from ``__eq__``
    # (compare=False) so that merely reading ``median`` on one of two
    # otherwise identical results does not make them compare unequal.
    _median: dict[str, Any] | None = field(
        default=None, init=False, repr=False, compare=False
    )

    @property
    def median(self) -> dict[str, Any]:
        """Median predictive values over the leading posterior axis.

        Computed on first access with ``median_mapping`` and cached.
        """
        if self._median is None:
            self._median = median_mapping(self.data)
        return self._median

    def __getitem__(self, key: str) -> Any:
        """Return ``data[key]``; raises ``KeyError`` for a missing key."""
        return self.data[key]

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` is present in ``data``.

        Without this, ``key in result`` would fall back to calling
        ``__getitem__`` with 0, 1, 2, ... and raise ``KeyError``.
        """
        return key in self.data

    def __iter__(self):
        """Iterate over the keys of ``data``, like a dictionary does."""
        return iter(self.data)

    # NOTE: ``__len__`` is intentionally not defined; adding it would make
    # a result with empty ``data`` falsy and change ``if result:`` checks.

    def keys(self):
        """Return the keys view of ``data``."""
        return self.data.keys()

    def values(self):
        """Return the values view of ``data``."""
        return self.data.values()

    def items(self):
        """Return the ``(key, value)`` view of ``data``."""
        return self.data.items()

    def get(self, key: str, default: Any = None) -> Any:
        """Return ``data[key]`` if present, otherwise ``default``."""
        return self.data.get(key, default)
# [docs]
@dataclass
class FitResult:
    """High-level result object returned by ``jaxqsofit`` fits.

    The result stores the fit outputs and forwards prediction, saving and
    plotting requests to the ``fitter`` that produced it.

    Attributes
    ----------
    fitter
        Object that performed the fit; every method here delegates to it.
    samples
        Posterior samples as a name -> value mapping, or ``None``.
    median
        Name -> value mapping of median parameter values.
    method
        Label of the inference method.
    summary
        Optional summary mapping.
    path
        Location of the most recent ``save``, if any.
    figure
        Optional figure handle.
    """

    fitter: Any
    samples: Mapping[str, Any] | None
    median: Mapping[str, Any]
    method: str
    summary: Mapping[str, Any] | None = None
    path: Path | None = None
    figure: Any = None
    # Internal posterior state; hidden from repr and ignored by ``==``.
    _state: _PosteriorState | None = field(default=None, repr=False, compare=False)

    def predict(self, **kwargs) -> PredictionResult:
        """Reconstruct posterior spectral components for this fit.

        Keyword arguments go to ``fitter.reconstruct_posterior_spectrum``.
        If this result carries an internal state and the caller did not
        pass ``_state``, the stored state is supplied.
        """
        state = self._state
        if state is not None and "_state" not in kwargs:
            kwargs["_state"] = state
        reconstruction = self.fitter.reconstruct_posterior_spectrum(**kwargs)
        return PredictionResult(reconstruction, fitter=self.fitter)

    def save(self, path: str | Path | None = None, **kwargs) -> Path:
        """Save the result with the fitter's native persistence format.

        ``path`` and the keyword arguments are handed to ``fitter.save``;
        its return value is wrapped in ``Path``, recorded on this result
        (and on the internal state, when present) and returned.
        """
        state = self._state
        if state is not None and "_state" not in kwargs:
            kwargs["_state"] = state
        written = Path(self.fitter.save(path, **kwargs))
        self.path = written
        # Re-read the attribute so the state that is current after saving
        # is the one that gets the path.
        if self._state is not None:
            self._state.path = written
        return written

    def plot_corner(self, **kwargs):
        """Plot posterior samples with the fitter's corner-plot helper."""
        corner = self.fitter.plot_corner(**kwargs)
        return corner

    def plot_trace(self, **kwargs):
        """Plot posterior samples with the fitter's trace-plot helper."""
        trace = self.fitter.plot_trace(**kwargs)
        return trace