Source code for jaxqsofit.config

from __future__ import annotations

from collections.abc import MutableMapping
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Mapping, Sequence

import numpy as np


[docs] @dataclass class Observation: """Observation-level metadata for one quasar spectrum.""" object_id: str = "result" redshift: float = 0.0 ra: float | None = None dec: float | None = None apply_mw_deredden: bool = True
[docs] @dataclass class SpectroscopyData: """Observed spectral measurements on an observed-frame wavelength grid.""" wave_obs: Sequence[float] fluxes: Sequence[float] errors: Sequence[float] | float | None = None wavelength_dispersion: Sequence[float] | None = None mask: Sequence[bool] | None = None
[docs] def validate(self) -> None: """Validate spectroscopy payload array lengths.""" n = len(self.wave_obs) if len(self.fluxes) != n: raise ValueError("Spectroscopy fluxes must have the same length as wave_obs.") if self.errors is not None and not np.isscalar(self.errors) and len(self.errors) != n: raise ValueError("Spectroscopy errors must be scalar, None, or match wave_obs length.") if self.wavelength_dispersion is not None and len(self.wavelength_dispersion) != n: raise ValueError("wavelength_dispersion must match wave_obs length.") if self.mask is not None and len(self.mask) != n: raise ValueError("spectroscopy mask must match wave_obs length.")
[docs] @dataclass class PSFPhotometryData: """Optional PSF-aperture photometry used for spectral recalibration. JAXQSOFit is a spectral fitter, so these data are only used as an extra calibration constraint on the fitted spectrum. Use bands whose transmission curves overlap the observed spectral wavelength coverage. For full joint spectrum + broadband SED modeling, use ``jaxsedfit`` instead. """ magnitudes: Sequence[float] magnitude_errors: Sequence[float] filter_names: Sequence[str] = ("u", "g", "r", "i", "z")
[docs] def validate(self) -> None: """Validate PSF photometry vector lengths.""" n = len(self.magnitudes) if len(self.magnitude_errors) != n or len(self.filter_names) != n: raise ValueError("PSF magnitudes, errors, and filter_names must have the same length.")
[docs] @dataclass class PreprocessingConfig: """Spectrum preprocessing options applied before fitting.""" wave_range: tuple[float, float] | None = None wave_mask: Sequence[Sequence[float]] | None = None mask_lya_forest: bool = True
[docs] @dataclass class ContinuumConfig: """Continuum and spectral component switches.""" fit_power_law: bool = True fit_feii: bool = True fit_balmer_continuum: bool = False fit_bal_absorption: bool = False fit_polynomial_tilt: bool = True fit_reddening: bool = True polynomial_order: int = 2
[docs] @dataclass class HostConfig: """Host-galaxy spectral decomposition configuration.""" enabled: bool = True sfh_model: str = "delayed" dsps_ssp_fn: str = "tempdata.h5" age_grid_gyr: Sequence[float] = (0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0) logzsol_grid: Sequence[float] = (-1.0, -0.5, 0.0, 0.2)
[docs] @dataclass class LineConfig: """Emission-line model configuration.""" enabled: bool = True custom_components: Sequence[Any] | None = None custom_line_components: Sequence[Any] | None = None
[docs] @dataclass class InferenceConfig: """Inference defaults for Optax and NUTS.""" method: str = "optax+nuts" map_steps: int = 600 learning_rate: float = 1.0e-2 num_warmup: int = 50 num_samples: int = 50 num_chains: int = 1 target_accept_prob: float = 0.9 plot_init: bool = False
[docs] @dataclass class OutputConfig: """Plotting and persistence defaults.""" output_path: str | None = None save_name: str | None = None save_result: bool = True plot_fig: bool = True save_fig: bool = True show_plot: bool = False
[docs] @dataclass class ContinuumPriorConfig: """Semantic continuum-prior options.""" power_law_pivot: float | None = None polynomial_pivot: float | None = None output_wavelengths: Sequence[float] | None = None overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Convert semantic continuum prior settings into model-site keys.""" out = dict(self.overrides) if self.power_law_pivot is not None: out["PL_pivot"] = float(self.power_law_pivot) if self.polynomial_pivot is not None: out["poly_pivot"] = float(self.polynomial_pivot) if self.output_wavelengths is not None: out["out_params"] = {"cont_lum_waves": list(self.output_wavelengths)} return out
[docs] @dataclass class HostPriorConfig: """Semantic host-galaxy prior options.""" redshift_weight_enabled: bool | None = None overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Convert semantic host prior settings into model-site keys.""" out = dict(self.overrides) if self.redshift_weight_enabled is not None: host_z = dict(out.get("host_redshift_prior", {})) host_z["enabled"] = bool(self.redshift_weight_enabled) out["host_redshift_prior"] = host_z return out
[docs] @dataclass class LinePriorConfig: """Semantic emission-line prior options.""" table: Sequence[Mapping[str, Any]] | None = None dmu_scale_mult: float | None = None sig_scale_mult: float | None = None amp_scale_mult: float | None = None overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Convert semantic emission-line prior settings into model-site keys.""" out = dict(self.overrides) if self.table is not None: out["line"] = {"table": list(self.table)} if self.dmu_scale_mult is not None: out["line_dmu_scale_mult"] = float(self.dmu_scale_mult) if self.sig_scale_mult is not None: out["line_sig_scale_mult"] = float(self.sig_scale_mult) if self.amp_scale_mult is not None: out["line_amp_scale_mult"] = float(self.amp_scale_mult) return out
[docs] @dataclass class FeIIPriorConfig: """Semantic Fe II prior options.""" uv_fwhm: Mapping[str, Any] | None = None optical_fwhm: Mapping[str, Any] | None = None overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Convert semantic Fe II prior settings into model-site keys.""" out = dict(self.overrides) if self.uv_fwhm is not None: out["log_Fe_uv_FWHM"] = dict(self.uv_fwhm) if self.optical_fwhm is not None: out["log_Fe_op_FWHM"] = dict(self.optical_fwhm) return out
[docs] @dataclass class PSFPriorConfig: """Semantic PSF recalibration prior options.""" overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Return low-level PSF recalibration prior overrides.""" return dict(self.overrides)
[docs] @dataclass class PriorConfig(MutableMapping[str, Any]): """Object-oriented prior configuration with dict-compatible overrides.""" continuum: ContinuumPriorConfig = field(default_factory=ContinuumPriorConfig) host: HostPriorConfig = field(default_factory=HostPriorConfig) lines: LinePriorConfig = field(default_factory=LinePriorConfig) feii: FeIIPriorConfig = field(default_factory=FeIIPriorConfig) psf: PSFPriorConfig = field(default_factory=PSFPriorConfig) overrides: dict[str, Any] = field(default_factory=dict)
[docs] def to_mapping(self) -> dict[str, Any]: """Return the flat prior mapping consumed by low-level model code.""" out: dict[str, Any] = dict(self.overrides) out.update(self.continuum.to_mapping()) out.update(self.host.to_mapping()) out.update(self.lines.to_mapping()) out.update(self.feii.to_mapping()) out.update(self.psf.to_mapping()) return out
def __getitem__(self, key: str) -> Any: """Return one prior entry from the flat mapping view.""" return self.to_mapping()[key] def __setitem__(self, key: str, value: Any) -> None: """Store a low-level prior override by model-site key.""" self.overrides[str(key)] = value def __delitem__(self, key: str) -> None: """Remove a low-level override entry.""" if key in self.overrides: del self.overrides[key] return raise KeyError(key) def __iter__(self): """Iterate over keys in the flat mapping view.""" return iter(self.to_mapping()) def __len__(self) -> int: """Return the number of entries in the flat mapping view.""" return len(self.to_mapping()) def __contains__(self, key: object) -> bool: """Return True when a key exists in the flat mapping view.""" return key in self.to_mapping()
[docs] def get(self, key: str, default: Any = None) -> Any: """Return a prior entry or default from the flat mapping view.""" return self.to_mapping().get(key, default)
[docs] def setdefault(self, key: str, default: Any = None) -> Any: """Set and return a low-level override when the key is absent.""" if key not in self: self[key] = default return self[key]
[docs] @dataclass class FitConfig: """Top-level configuration bundle for one JAXQSOFit spectral fit.""" observation: Observation spectroscopy: SpectroscopyData psf_photometry: PSFPhotometryData | None = None preprocessing: PreprocessingConfig = field(default_factory=PreprocessingConfig) continuum: ContinuumConfig = field(default_factory=ContinuumConfig) host: HostConfig = field(default_factory=HostConfig) lines: LineConfig = field(default_factory=LineConfig) inference: InferenceConfig = field(default_factory=InferenceConfig) output: OutputConfig = field(default_factory=OutputConfig) prior_config: PriorConfig | None = None def __post_init__(self) -> None: """Coerce mapping-style prior configs into :class:`PriorConfig`.""" if self.prior_config is not None: self.prior_config = _coerce_prior_config(self.prior_config)
[docs] def validate(self) -> None: """Validate required nested data payloads.""" self.spectroscopy.validate() if self.psf_photometry is not None: self.psf_photometry.validate()
[docs] def to_dict(self) -> dict[str, Any]: """Convert the dataclass tree into a plain dictionary.""" return asdict(self)
def _coerce_dataclass(cls, value: Any): """Convert an existing instance or mapping into the requested dataclass.""" if isinstance(value, cls): return value if isinstance(value, Mapping): kwargs = {} for field_name in cls.__dataclass_fields__: if field_name in value: kwargs[field_name] = value[field_name] return cls(**kwargs) raise TypeError(f"Cannot coerce {type(value)!r} to {cls.__name__}") def _coerce_prior_config(value: Any) -> PriorConfig: """Coerce flat and nested prior mappings into :class:`PriorConfig`.""" if isinstance(value, PriorConfig): return value if value is None: return PriorConfig() if not isinstance(value, Mapping): return _coerce_dataclass(PriorConfig, value) data = dict(value) nested_keys = {"continuum", "host", "lines", "feii", "psf", "overrides"} if any(key in data for key in nested_keys): cfg = PriorConfig( continuum=_coerce_dataclass(ContinuumPriorConfig, data.get("continuum", {})), host=_coerce_dataclass(HostPriorConfig, data.get("host", {})), lines=_coerce_dataclass(LinePriorConfig, data.get("lines", {})), feii=_coerce_dataclass(FeIIPriorConfig, data.get("feii", {})), psf=_coerce_dataclass(PSFPriorConfig, data.get("psf", {})), overrides=dict(data.get("overrides", {})), ) for key, item in data.items(): if key not in nested_keys: cfg.overrides[key] = item return cfg return PriorConfig(overrides=data)
[docs] def fit_config_from_mapping(data: Mapping[str, Any]) -> FitConfig: """Build a validated FitConfig from a nested mapping.""" psf_raw = data.get("psf_photometry") psf_obj = None if psf_raw is None else _coerce_dataclass(PSFPhotometryData, psf_raw) cfg = FitConfig( observation=_coerce_dataclass(Observation, data.get("observation", {})), spectroscopy=_coerce_dataclass(SpectroscopyData, data["spectroscopy"]), psf_photometry=psf_obj, preprocessing=_coerce_dataclass(PreprocessingConfig, data.get("preprocessing", {})), continuum=_coerce_dataclass(ContinuumConfig, data.get("continuum", {})), host=_coerce_dataclass(HostConfig, data.get("host", {})), lines=_coerce_dataclass(LineConfig, data.get("lines", {})), inference=_coerce_dataclass(InferenceConfig, data.get("inference", {})), output=_coerce_dataclass(OutputConfig, data.get("output", {})), prior_config=None if data.get("prior_config") is None else _coerce_prior_config(data.get("prior_config", {})), ) cfg.validate() return cfg
def serialize_config(value: Any) -> Any: """Convert config-like objects into JSON-serializable Python values.""" if is_dataclass(value): return {k: serialize_config(v) for k, v in asdict(value).items()} if isinstance(value, dict): return {k: serialize_config(v) for k, v in value.items()} if isinstance(value, (list, tuple)): return [serialize_config(v) for v in value] if isinstance(value, np.ndarray): return value.tolist() return value