from __future__ import annotations
import os
import glob
import extinction
import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy import units as u
from astropy.coordinates import SkyCoord
from speclite import filters as speclite_filters
import jax
import jax.numpy as jnp
import optax
from numpyro.infer import MCMC, NUTS, Predictive, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoDelta
from numpyro.optim import optax_to_numpyro
from .custom_components import (
CustomComponentSpec,
CustomLineComponentSpec,
custom_component_site_names,
custom_line_component_site_names,
inject_default_custom_component_priors,
inject_default_custom_line_component_priors,
normalize_custom_components,
normalize_custom_line_components,
)
from .defaults import build_default_bal_components, build_default_prior_config
from .model import (
C_KMS,
_continuum_output_waves_from_prior_config,
_extract_line_table_from_prior_config,
_get_sfd_query,
_normalize_template_flux,
_np_to_jnp,
_spectrum_center_pivot,
build_fsps_template_grid,
build_tied_line_meta_from_linelist,
qso_fsps_joint_model,
reconstruct_posterior_components,
unred,
)
# SDSS PSF photometry band letters, ordered u, g, r, i, z.
_SDSS_PSF_BANDS = tuple("ugriz")
# Lazily filled by _get_sdss_filters() with a band-letter -> filter mapping.
_SDSS_FILTER_CACHE = None
def _get_sdss_filters():
    """Return a ``{band: filter}`` mapping of SDSS response curves.

    The speclite ``sdss2010-*`` curves are loaded on the first call and kept
    in the module-level ``_SDSS_FILTER_CACHE`` for all later calls.
    """
    global _SDSS_FILTER_CACHE
    if _SDSS_FILTER_CACHE is not None:
        return _SDSS_FILTER_CACHE
    curve_names = [f"sdss2010-{band}" for band in _SDSS_PSF_BANDS]
    loaded_curves = speclite_filters.load_filters(*curve_names)
    _SDSS_FILTER_CACHE = dict(zip(_SDSS_PSF_BANDS, loaded_curves))
    return _SDSS_FILTER_CACHE
def _filter_wave_to_angstrom_array(value):
    """Convert a wavelength grid to a float64 ndarray in Angstrom.

    Objects exposing ``to_value`` (astropy quantities) are converted to
    ``u.AA`` first. Anything else is assumed to already be in Angstrom.
    """
    raw = value.to_value(u.AA) if hasattr(value, "to_value") else value
    return np.asarray(raw, dtype=np.float64)
def _filter_wave_to_angstrom_scalar(value):
    """Convert a scalar wavelength to a Python float in Angstrom.

    Objects exposing ``to_value`` (astropy quantities) are converted to
    ``u.AA`` first. Anything else goes straight through ``float()``.
    """
    raw = value.to_value(u.AA) if hasattr(value, "to_value") else value
    return float(raw)
def _ab_mag_to_fnu(mag):
    """Convert AB magnitude(s) to flux density in cgs units.

    Uses the AB definition ``f_nu = 10 ** (-0.4 * (m + 48.60))``.
    """
    exponent = -0.4 * (np.asarray(mag, dtype=np.float64) + 48.60)
    return np.power(10.0, exponent)
def _fnu_to_ab_mag(fnu):
    """Convert flux density in cgs units to AB magnitude(s).

    Non-positive fluxes are clipped to ``1e-300`` so the logarithm stays finite.
    """
    floored = np.clip(np.asarray(fnu, dtype=np.float64), 1e-300, None)
    return -2.5 * np.log10(floored) - 48.60
def _mw_band_attenuation_factor(wave_obs, filt_trans, ebv, r_v=3.1):
    """Return the Galactic attenuation factor averaged through a filter.

    The Fitzpatrick (1999) extinction curve with ``A_V = r_v * ebv`` is
    averaged over the band with weight ``T(lambda) / lambda``, the AB
    (per-unit-frequency) weighting. Returns ``1.0`` (no attenuation) when
    ``ebv`` is zero or non-finite, or when either band integral is
    non-finite or non-positive.
    """
    wave = np.asarray(wave_obs, dtype=np.float64)
    response = np.clip(np.asarray(filt_trans, dtype=np.float64), 0.0, None)
    if (not np.isfinite(ebv)) or ebv == 0.0:
        return 1.0
    rv = float(r_v)
    a_lambda = extinction.fitzpatrick99(wave, a_v=rv * float(ebv), r_v=rv)
    transmitted = 10.0 ** (-0.4 * np.asarray(a_lambda, dtype=np.float64))
    # Clip to avoid division by zero at non-positive wavelengths.
    inv_wave = 1.0 / np.clip(wave, 1e-8, None)

    def _band_integral(integrand):
        return float(np.trapezoid(integrand, wave))

    denom = _band_integral(response * inv_wave)
    if not (np.isfinite(denom) and denom > 0.0):
        return 1.0
    numer = _band_integral(response * transmitted * inv_wave)
    if not (np.isfinite(numer) and numer > 0.0):
        return 1.0
    return numer / denom
[docs]
class QSOFit:
_POSTERIOR_BUNDLE_SUFFIX = ".h5"
def __init__(self, lam, flux, err=None, z=0.0, ra=-999, dec=-999, filename=None, output_path=None,
             wdisp=None, psf_mags=None, psf_mag_errs=None, psf_bands=None):
    """Initialize a spectral fitting object with observed-frame inputs.
    Parameters
    ----------
    lam : array-like
        Observed-frame wavelength array in Angstrom.
    flux : array-like
        Observed-frame flux density array.
    err : array-like or float or None, optional
        Per-pixel 1-sigma uncertainty. A scalar is broadcast to every pixel.
        If ``None``, a default of ``1e-6`` is used.
    z : float, optional
        Source redshift.
    ra, dec : float, optional
        Sky coordinates in degrees for Galactic dereddening.
    filename : str or None, optional
        Basename used for saving result tables/figures. If ``None``,
        a filename is auto-generated from ``ra`` and ``dec``.
    output_path : str or None, optional
        Output directory for saved artifacts.
    wdisp : array-like or None, optional
        Optional wavelength dispersion vector (stored only).
    psf_mags, psf_mag_errs : array-like or None, optional
        Optional PSF photometry magnitudes and 1-sigma errors.
    psf_bands : sequence of str or None, optional
        Band labels associated with ``psf_mags``. If omitted, defaults to
        the first N SDSS bands (``u, g, r, i, z``).
    Notes
    -----
    Providing per-pixel `err` is strongly recommended for robust inference.
    If `err` is not provided, a small default uncertainty (`1e-6`) is used;
    in that case, fitted intrinsic-scatter terms absorb much of the noise model.
    Use keyword arguments to avoid ambiguity (for example `z=...`).

    PSF magnitudes and errors are stored twice: as ``psf_mags`` /
    ``psf_mag_errs`` (working values) and as ``psf_mags_raw`` /
    ``psf_mag_errs_raw`` (values as supplied). The dereddened copies start
    as ``None``.
    """
    self.lam_in = np.asarray(lam, dtype=np.float64)
    self.flux_in = np.asarray(flux, dtype=np.float64)
    if err is None:
        self.err_in = np.full_like(self.flux_in, 1e-6, dtype=np.float64)
    else:
        err_arr = np.asarray(err, dtype=np.float64)
        if err_arr.ndim == 0:
            # A scalar uncertainty applies to every pixel.
            self.err_in = np.full_like(self.flux_in, float(err_arr), dtype=np.float64)
        else:
            self.err_in = err_arr
    self.z = z
    self.wdisp = wdisp
    self.ra = ra
    self.dec = dec
    self.install_path = os.path.dirname(os.path.abspath(__file__))
    self.output_path = output_path
    self.filename = self._resolve_filename(filename=filename, ra=ra, dec=dec)
    self.psf_mags = None if psf_mags is None else np.asarray(psf_mags, dtype=np.float64)
    self.psf_mag_errs = None if psf_mag_errs is None else np.asarray(psf_mag_errs, dtype=np.float64)
    self.psf_mags_raw = None if psf_mags is None else np.asarray(psf_mags, dtype=np.float64)
    self.psf_mag_errs_raw = None if psf_mag_errs is None else np.asarray(psf_mag_errs, dtype=np.float64)
    self.psf_mags_dered = None
    self.psf_mag_errs_dered = None
    self.psf_bands = None if psf_bands is None else list(psf_bands)
    if self.psf_bands is None and self.psf_mags is not None:
        # Use the shared module constant so this default matches the one in
        # _prepare_psf_photometry.
        self.psf_bands = list(_SDSS_PSF_BANDS[:len(self.psf_mags)])
    self.psf_filter_curves = None
    self.use_psf_phot = False
    # Galactic E(B-V); NaN until it is set elsewhere.
    self.ebv_mw = np.nan
@staticmethod
def _resolve_filename(filename=None, ra=-999, dec=-999):
    """Resolve the basename used for output files.

    A non-blank ``filename`` wins, with surrounding whitespace stripped.
    Otherwise the name ``ra{ra:.5f}_dec{dec:.5f}`` is built from the
    coordinates. ``"result"`` is returned when the coordinates cannot be
    converted to float, are non-finite, or equal the ``-999`` placeholder.
    """
    if filename is not None and str(filename).strip() != "":
        return str(filename).strip()
    try:
        ra_f = float(ra)
        dec_f = float(dec)
    except (TypeError, ValueError, OverflowError):
        # Unconvertible coordinates (None, non-numeric strings, huge ints,
        # multi-element arrays): fall back to the generic name.
        return "result"
    if np.isfinite(ra_f) and np.isfinite(dec_f) and (ra_f != -999) and (dec_f != -999):
        return f"ra{ra_f:.5f}_dec{dec_f:.5f}"
    return "result"
def _predictive_return_sites(self, custom_components=None, custom_line_components=None):
    """List the ``Predictive`` sites needed for summaries and plots.

    The list holds the fixed spectral/photometric component sites, then one
    ``log_lambda_Llambda_<wave>_agn`` site per continuum output wavelength
    configured in ``self._fit_prior_config``, then any sites contributed by
    custom continuum and custom line components.
    """
    spectral_sites = (
        'f_pl_model', 'f_fe_mgii_model', 'f_fe_balmer_model', 'f_bc_model',
        'f_poly_model', 'agn_model', 'gal_model', 'line_model_broad',
        'line_model_narrow', 'line_model', 'continuum_model', 'model',
    )
    parameter_sites = (
        'fsps_weights', 'line_amp_per_component',
        'line_mu_per_component', 'line_sig_per_component',
    )
    photometry_sites = (
        'delta_m_psf', 'eta_psf', 'scale_psf', 'agn_model_psf', 'gal_model_psf',
        'line_model_broad_psf', 'line_model_narrow_psf', 'line_model_psf',
        'psf_model',
    )
    sites = [*spectral_sites, *parameter_sites, *photometry_sites]
    luminosity_waves = _continuum_output_waves_from_prior_config(
        getattr(self, "_fit_prior_config", None)
    )
    sites.extend(
        f"log_lambda_Llambda_{self._format_wave_label(wave)}_agn" for wave in luminosity_waves
    )
    sites.extend(custom_component_site_names(custom_components))
    sites.extend(custom_line_component_site_names(custom_line_components))
    return sites
def _prepare_psf_photometry(
    self,
    wave_obs,
    psf_mags=None,
    psf_mag_errs=None,
    psf_bands=None,
    use_psf_phot=False,
    min_filter_coverage=0.97,
):
    """Validate PSF photometry and interpolate filters onto the observed wavelength grid.

    Parameters
    ----------
    wave_obs : array-like
        Observed-frame wavelength grid (Angstrom) onto which the SDSS filter
        responses are interpolated (zero outside each filter's range).
    psf_mags, psf_mag_errs : array-like or None, optional
        Replacement AB magnitudes / 1-sigma errors. When given, they also
        become the new raw (not dereddened) photometry.
    psf_bands : sequence of str or None, optional
        Band labels. Defaults to the first N SDSS bands.
    use_psf_phot : bool, optional
        If False, or if magnitudes/errors are missing, photometry is disabled.
    min_filter_coverage : float, optional
        Minimum fraction of a filter's integrated response that must fall on
        ``wave_obs`` for the band to be kept.

    Returns
    -------
    tuple
        ``(mags, errs, bands, {"trans": trans}, True)`` when at least one
        band is usable, otherwise ``(None, None, None, None, False)``.

    Raises
    ------
    ValueError
        If magnitudes, errors and bands differ in length, or a kept band is
        not an SDSS band.

    Notes
    -----
    Magnitudes are read from ``psf_mags_raw`` / ``psf_mag_errs_raw``, falling
    back to ``psf_mags`` / ``psf_mag_errs``. On success ``psf_mags`` is
    overwritten with the dereddened values. Reading it back would therefore
    apply Galactic dereddening twice if this method runs again on the same
    object, including one restored by ``load_from_samples``.
    """
    if psf_mags is not None:
        self.psf_mags = np.asarray(psf_mags, dtype=np.float64)
        self.psf_mags_raw = np.asarray(psf_mags, dtype=np.float64)
    if psf_mag_errs is not None:
        self.psf_mag_errs = np.asarray(psf_mag_errs, dtype=np.float64)
        self.psf_mag_errs_raw = np.asarray(psf_mag_errs, dtype=np.float64)
    if psf_bands is not None:
        self.psf_bands = list(psf_bands)
    # Always start from the observed (not yet dereddened) photometry.
    src_mags = getattr(self, "psf_mags_raw", None)
    if src_mags is None:
        src_mags = self.psf_mags
    src_errs = getattr(self, "psf_mag_errs_raw", None)
    if src_errs is None:
        src_errs = self.psf_mag_errs
    if self.psf_bands is None and src_mags is not None:
        self.psf_bands = list(_SDSS_PSF_BANDS[:len(src_mags)])
    if (not use_psf_phot) or src_mags is None or src_errs is None:
        self.use_psf_phot = False
        self.psf_filter_curves = None
        self.psf_mags_dered = None
        self.psf_mag_errs_dered = None
        return None, None, None, None, False
    mags = np.asarray(src_mags, dtype=np.float64)
    errs = np.asarray(src_errs, dtype=np.float64)
    bands = list(self.psf_bands) if self.psf_bands is not None else list(_SDSS_PSF_BANDS[:len(mags)])
    if len(mags) != len(errs) or len(mags) != len(bands):
        raise ValueError("psf_mags, psf_mag_errs, and psf_bands must have the same length.")
    valid = np.isfinite(mags) & np.isfinite(errs) & (errs > 0)
    wave_obs = np.asarray(wave_obs, dtype=np.float64)
    filters = _get_sdss_filters()
    keep_mags = []
    keep_errs = []
    keep_bands = []
    keep_trans = []
    keep_coverage = []
    for band, mag, err, is_valid in zip(bands, mags, errs, valid):
        if not is_valid:
            continue
        if band not in filters:
            raise ValueError(f"Unsupported PSF photometry band '{band}'. Supported bands: {_SDSS_PSF_BANDS}.")
        filt = filters[band]
        filt_wave = _filter_wave_to_angstrom_array(filt.wavelength)
        filt_trans = np.asarray(filt.response, dtype=np.float64)
        trans_on_wave = np.interp(wave_obs, filt_wave, filt_trans, left=0.0, right=0.0)
        # Fraction of the filter's total response captured by the spectrum grid.
        full_norm = float(np.trapezoid(np.clip(filt_trans, 0.0, None), filt_wave))
        covered_norm = float(np.trapezoid(np.clip(trans_on_wave, 0.0, None), wave_obs))
        coverage = (covered_norm / full_norm) if full_norm > 0 else 0.0
        if coverage < float(min_filter_coverage):
            continue
        keep_mags.append(float(mag))
        keep_errs.append(float(err))
        keep_bands.append(str(band))
        keep_trans.append(np.asarray(trans_on_wave, dtype=np.float64))
        keep_coverage.append(float(coverage))
    if len(keep_bands) == 0:
        self.use_psf_phot = False
        self.psf_filter_curves = None
        self.psf_mags = None
        self.psf_mag_errs = None
        self.psf_mags_raw = None
        self.psf_mag_errs_raw = None
        self.psf_mags_dered = None
        self.psf_mag_errs_dered = None
        self.psf_bands = None
        return None, None, None, None, False
    raw_mags = np.asarray(keep_mags, dtype=np.float64)
    raw_errs = np.asarray(keep_errs, dtype=np.float64)
    dered_mags = raw_mags.copy()
    apply_dered = bool(getattr(self, "_fit_deredden", False)) and np.isfinite(getattr(self, "ebv_mw", np.nan))
    if apply_dered and float(self.ebv_mw) != 0.0:
        band_atten = np.asarray(
            [_mw_band_attenuation_factor(wave_obs, trans, self.ebv_mw) for trans in keep_trans],
            dtype=np.float64,
        )
        # Undo the band-averaged Galactic attenuation in flux space.
        fnu_obs = _ab_mag_to_fnu(raw_mags)
        fnu_dered = fnu_obs / np.clip(band_atten, 1e-30, None)
        dered_mags = _fnu_to_ab_mag(fnu_dered)
    self.psf_mags_raw = raw_mags
    self.psf_mag_errs_raw = raw_errs
    self.psf_mags_dered = dered_mags
    self.psf_mag_errs_dered = raw_errs.copy()
    self.psf_mags = dered_mags
    self.psf_mag_errs = raw_errs
    self.psf_bands = keep_bands
    self.psf_filter_curves = {
        "bands": tuple(keep_bands),
        "trans": np.asarray(keep_trans, dtype=np.float64),
        "coverage": np.asarray(keep_coverage, dtype=np.float64),
    }
    self.use_psf_phot = True
    return (
        self.psf_mags,
        self.psf_mag_errs,
        self.psf_bands,
        {"trans": self.psf_filter_curves["trans"]},
        True,
    )
def _posterior_bundle_path(self, save_name=None, save_path=None):
    """Build the on-disk path of a posterior bundle, creating its directory.

    The name defaults to ``{self.filename}_samples`` and is passed through
    ``_normalize_posterior_bundle_name``. The directory is ``save_path``,
    else ``self.output_path``, else ``'.'``.
    """
    requested = f"{self.filename}_samples" if save_name is None else save_name
    bundle_name = self._normalize_posterior_bundle_name(requested)
    target_dir = save_path if save_path is not None else self.output_path
    if target_dir is None:
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)
    return os.path.join(target_dir, bundle_name)
def _intrinsic_powerlaw_draws(self, wave_out=None, apply_psf_scale=False):
    """Return posterior draws for the intrinsic AGN power law on ``wave_out``.

    Parameters
    ----------
    wave_out : array-like or None, optional
        1-D wavelength grid to evaluate on. Defaults to ``self.wave``.
    apply_psf_scale : bool, optional
        If True and ``self.scale_psf`` converts to a finite float, multiply
        every draw by it.

    Returns
    -------
    ndarray or None
        Array of shape ``(n_draws, len(wave_out))`` equal to
        ``PL_norm * (wave / pivot) ** PL_slope``. ``pivot`` is
        ``_fit_prior_config['PL_pivot']`` or the midpoint of the grid ends.
        Returns ``None`` when ``PL_norm``/``PL_slope`` samples are missing
        or empty, or the grid is not a finite, non-empty 1-D array.
    """
    samples = getattr(self, 'numpyro_samples', None)
    # Both power-law sites are required to evaluate the draws.
    if samples is None or 'PL_slope' not in samples or 'PL_norm' not in samples:
        return None
    wave_eval = np.asarray(self.wave if wave_out is None else wave_out, dtype=float)
    if wave_eval.ndim != 1 or wave_eval.size == 0 or not np.all(np.isfinite(wave_eval)):
        return None
    pl_norm = np.asarray(samples['PL_norm'], dtype=float).reshape(-1)
    pl_slope = np.asarray(samples['PL_slope'], dtype=float).reshape(-1)
    # Pair draws index by index, truncating to the shorter chain.
    n_draws = min(pl_norm.size, pl_slope.size)
    if n_draws == 0:
        return None
    pl_norm = pl_norm[:n_draws]
    pl_slope = pl_slope[:n_draws]
    prior_config = getattr(self, '_fit_prior_config', None) or {}
    pivot = prior_config.get('PL_pivot', None)
    if pivot is None:
        # NOTE(review): midpoint of the evaluation grid; this may differ from
        # the pivot used during the fit when wave_out != self.wave — confirm.
        pivot = 0.5 * (wave_eval[0] + wave_eval[-1])
    pivot = max(float(pivot), 1e-8)
    x = np.clip(wave_eval / pivot, 1e-8, None)
    draws = pl_norm[:, None] * (x[None, :] ** pl_slope[:, None])
    if apply_psf_scale:
        psf_scale = float(getattr(self, 'scale_psf', np.nan))
        if np.isfinite(psf_scale):
            draws = psf_scale * draws
    return draws
@classmethod
def _normalize_posterior_bundle_name(cls, name):
    """Return ``str(name)`` with the bundle suffix appended if it is missing."""
    text = str(name)
    suffix = cls._POSTERIOR_BUNDLE_SUFFIX
    return text if text.endswith(suffix) else f"{text}{suffix}"
@staticmethod
def _bundle_excluded_keys():
    """Return the attribute names never written to a posterior bundle.

    These are solver objects, figures, template grids and raw predictive
    output.
    """
    return set(
        "numpyro_mcmc svi svi_state fig trace_fig corner_fig "
        "fsps_grid fe_uv fe_op pred_out".split()
    )
@staticmethod
def _is_matplotlib_state(value):
    """Return True when ``value`` is a matplotlib ``Figure`` or ``Axes``.

    Each class is looked up defensively. If neither can be resolved from the
    ``matplotlib`` module, the result is False.
    """
    candidates = (
        getattr(getattr(matplotlib, "figure", None), "Figure", None),
        getattr(getattr(matplotlib, "axes", None), "Axes", None),
    )
    resolved = tuple(klass for klass in candidates if isinstance(klass, type))
    if not resolved:
        return False
    return isinstance(value, resolved)
@classmethod
def _exclude_from_posterior_bundle(cls, key, value):
    """Decide whether attribute ``key``/``value`` is skipped when saving.

    Skipped attributes are explicitly excluded names, any ``_pred_*``
    attribute, and matplotlib figure/axes objects.
    """
    return bool(
        key in cls._bundle_excluded_keys()
        or key.startswith("_pred_")
        or cls._is_matplotlib_state(value)
    )
@staticmethod
def _serialize_for_hdf5(value):
    """Recursively convert model state into HDF5-serializable objects.

    Conversions:

    - Custom component specs become their ``to_state()`` payloads.
    - Dicts get string keys.
    - Lists and tuples are rebuilt element-wise as plain ``list``/``tuple``.
    - Object-dtype arrays become a tagged dict (``__ndarray_object__``,
      ``shape``, ``items``).
    - numpy scalars/arrays and array-likes with ``shape``/``dtype`` (for
      example jax arrays) become ndarrays.
    - Anything else is returned unchanged.
    """
    if isinstance(value, CustomComponentSpec):
        return value.to_state()
    if isinstance(value, CustomLineComponentSpec):
        return value.to_state()
    if isinstance(value, dict):
        return {str(k): QSOFit._serialize_for_hdf5(v) for k, v in value.items()}
    # Rebuild sequences as plain list/tuple. Subclasses such as namedtuples
    # cannot be constructed from a single iterable, and the HDF5 writer
    # only distinguishes list from tuple anyway.
    if isinstance(value, list):
        return [QSOFit._serialize_for_hdf5(v) for v in value]
    if isinstance(value, tuple):
        return tuple(QSOFit._serialize_for_hdf5(v) for v in value)
    if isinstance(value, np.ndarray) and value.dtype == object:
        return {
            "__ndarray_object__": True,
            "shape": tuple(int(x) for x in value.shape),
            "items": [QSOFit._serialize_for_hdf5(v) for v in value.ravel(order="C").tolist()],
        }
    if isinstance(value, (np.ndarray, np.generic)):
        return np.asarray(value)
    if hasattr(value, "shape") and hasattr(value, "dtype"):
        return np.asarray(value)
    return value
@staticmethod
def _deserialize_from_hdf5(value):
    """Rebuild custom serialized objects after reading HDF5 state.

    This reverses ``_serialize_for_hdf5``:

    - Dicts flagged ``__custom_component__`` / ``__custom_line_component__``
      become component specs.
    - Dicts flagged ``__ndarray_object__`` become object-dtype arrays of the
      recorded shape.
    - Other dicts, lists and tuples are rebuilt recursively.
    - All other values pass through unchanged.
    """

    def _rebuild(node):
        if isinstance(node, dict):
            if node.get("__custom_component__", False):
                return CustomComponentSpec.from_state(node)
            if node.get("__custom_line_component__", False):
                return CustomLineComponentSpec.from_state(node)
            if node.get("__ndarray_object__", False):
                items = [_rebuild(v) for v in node["items"]]
                # Fill element by element: np.asarray(items, dtype=object)
                # would add a dimension when the items are equal-length
                # sequences, and the reshape below would then fail.
                arr = np.empty(len(items), dtype=object)
                for idx, item in enumerate(items):
                    arr[idx] = item
                return arr.reshape(tuple(int(s) for s in node["shape"]))
            return {k: _rebuild(v) for k, v in node.items()}
        if isinstance(node, list):
            return [_rebuild(v) for v in node]
        if isinstance(node, tuple):
            return tuple(_rebuild(v) for v in node)
        return node

    return _rebuild(value)
@staticmethod
def _hdf5_scalar_string_dtype():
    """Return the h5py variable-length UTF-8 string dtype used for ``str`` values."""
    return h5py.string_dtype("utf-8")
@classmethod
def _write_hdf5_node(cls, parent, name, value):
    """Write ``value`` under ``parent[name]`` as a tagged HDF5 group or dataset.

    ``value`` is first passed through ``_serialize_for_hdf5``. Each node
    records its kind in the ``node_type`` attribute:

    - ``None``: empty group tagged ``"none"``.
    - dict: group tagged ``"dict"``, holding one ``item_XXXXXXXX`` subgroup
      per entry with ``key`` and ``value`` children.
    - list/tuple: group tagged ``"list"``/``"tuple"``, holding one
      ``item_XXXXXXXX`` child per element.
    - ndarray: dataset tagged ``"ndarray"`` (gzip + shuffle when ndim > 0).
    - bool/int/float/str: scalar datasets tagged ``"scalar_*"``.

    Raises
    ------
    TypeError
        If a value has no supported representation.
    """
    payload = cls._serialize_for_hdf5(value)
    if payload is None:
        parent.create_group(name).attrs["node_type"] = "none"
        return
    if isinstance(payload, dict):
        group = parent.create_group(name)
        group.attrs["node_type"] = "dict"
        for position, (key, item) in enumerate(payload.items()):
            entry = group.create_group(f"item_{position:08d}")
            cls._write_hdf5_node(entry, "key", str(key))
            cls._write_hdf5_node(entry, "value", item)
        return
    if isinstance(payload, (list, tuple)):
        group = parent.create_group(name)
        group.attrs["node_type"] = "list" if isinstance(payload, list) else "tuple"
        for position, item in enumerate(payload):
            cls._write_hdf5_node(group, f"item_{position:08d}", item)
        return
    if isinstance(payload, np.ndarray):
        # Compression filters are not applicable to scalar (0-d) datasets.
        options = {"compression": "gzip", "shuffle": True} if payload.ndim > 0 else {}
        parent.create_dataset(name, data=payload, **options).attrs["node_type"] = "ndarray"
        return
    # bool must be tested before int because bool subclasses int.
    scalar_kinds = (
        (bool, np.bool_, "scalar_bool"),
        (int, np.int64, "scalar_int"),
        (float, np.float64, "scalar_float"),
    )
    for py_type, np_type, tag in scalar_kinds:
        if isinstance(payload, py_type):
            parent.create_dataset(name, data=np_type(payload)).attrs["node_type"] = tag
            return
    if isinstance(payload, str):
        text = np.array(payload, dtype=cls._hdf5_scalar_string_dtype())
        parent.create_dataset(name, data=text).attrs["node_type"] = "scalar_str"
        return
    raise TypeError(f"Unsupported value type in posterior bundle: {type(payload)!r}")
@classmethod
def _read_hdf5_node(cls, parent, name):
    """Read ``parent[name]`` back into Python objects.

    This is the inverse of ``_write_hdf5_node``. Datasets are decoded by
    their ``node_type`` tag (default ``"ndarray"``). Groups are rebuilt as
    ``None``, dict, list or tuple, with children visited in sorted name
    order.

    Raises
    ------
    TypeError
        If a group carries an unknown ``node_type``.
    """
    node = parent[name]

    def _node_tag(obj, default):
        tag = obj.attrs.get("node_type", default)
        return tag.decode("utf-8") if isinstance(tag, bytes) else tag

    if isinstance(node, h5py.Dataset):
        kind = _node_tag(node, "ndarray")
        if kind == "scalar_str":
            return node.asstr()[()]
        raw = node[()]
        scalar_casts = {"scalar_bool": bool, "scalar_int": int, "scalar_float": float}
        cast = scalar_casts.get(kind)
        return cast(raw) if cast is not None else np.asarray(raw)
    kind = _node_tag(node, "")
    if kind == "none":
        return None
    if kind == "dict":
        rebuilt = {}
        for entry_name in sorted(node.keys()):
            entry = node[entry_name]
            entry_key = cls._read_hdf5_node(entry, "key")
            rebuilt[str(entry_key)] = cls._read_hdf5_node(entry, "value")
        return rebuilt
    if kind in ("list", "tuple"):
        children = [cls._read_hdf5_node(node, child) for child in sorted(node.keys())]
        return children if kind == "list" else tuple(children)
    raise TypeError(f"Unsupported HDF5 node type in posterior bundle: {kind!r}")
@staticmethod
def _sample_bundle_meta_keys():
    """Return the attribute names persisted alongside samples in a bundle.

    The set covers the input spectrum and fitted grid, FeII templates, PSF
    photometry state, runtime flags, and the ``_fit_*`` options recorded by
    the fit.
    """
    spectrum_keys = (
        "lam_in", "flux_in", "err_in", "z", "ra", "dec", "filename",
        "output_path", "wdisp", "wave", "flux", "err",
        "wave_prereduced", "flux_prereduced",
    )
    template_keys = ("fe_uv_wave", "fe_uv_flux", "fe_op_wave", "fe_op_flux")
    photometry_keys = (
        "psf_mags", "psf_mag_errs", "psf_mags_raw", "psf_mag_errs_raw",
        "psf_mags_dered", "psf_mag_errs_dered", "psf_bands",
        "psf_filter_curves", "use_psf_phot",
    )
    runtime_keys = ("verbose", "save_fig")
    fit_option_keys = tuple(
        "_fit_" + option
        for option in (
            "deredden", "decompose_host", "fit_lines", "fit_pl", "fit_fe",
            "fit_bc", "fit_bal", "fit_poly", "fit_reddening",
            "fit_poly_order", "mask_lya_forest", "method",
            "fsps_age_grid", "fsps_logzsol_grid", "prior_config",
            "dsps_ssp_fn", "use_psf_phot", "custom_components",
            "custom_line_components",
        )
    )
    return {
        *spectrum_keys,
        *template_keys,
        *photometry_keys,
        *runtime_keys,
        *fit_option_keys,
    }
def _collect_sample_bundle_meta(self):
    """Gather the serialized metadata stored next to samples in a bundle.

    Only attributes that exist on the instance, are listed by
    ``_sample_bundle_meta_keys`` and are not excluded by
    ``_exclude_from_posterior_bundle`` are included. Each is passed through
    ``_serialize_for_hdf5``.

    Raises
    ------
    RuntimeError
        If no posterior samples are available.
    """
    if getattr(self, "numpyro_samples", None) is None:
        raise RuntimeError("No posterior samples available. Run fit() before saving a posterior bundle.")
    attrs = self.__dict__
    return {
        key: self._serialize_for_hdf5(attrs[key])
        for key in self._sample_bundle_meta_keys()
        if key in attrs and not self._exclude_from_posterior_bundle(key, attrs[key])
    }
@staticmethod
def _empty_tied_line_meta():
    """Return a tied-line metadata payload that describes zero lines.

    All counts are 0, group/bound arrays are empty numpy arrays (int for the
    group indices, float otherwise), ``ln_lambda0`` goes through
    ``_np_to_jnp``, and the name lists are empty.
    """
    def _empty(dtype):
        return np.array([], dtype=dtype)

    meta = {
        'n_lines': 0,
        'n_vgroups': 0,
        'n_wgroups': 0,
        'n_fgroups': 0,
        'ln_lambda0': _np_to_jnp(_empty(float)),
    }
    meta.update((key, _empty(int)) for key in ('vgroup', 'wgroup', 'fgroup'))
    meta.update(
        (key, _empty(float))
        for key in (
            'flux_ratio',
            'dmu_init_group', 'dmu_min_group', 'dmu_max_group',
            'sig_init_group', 'sig_min_group', 'sig_max_group',
            'amp_init_group', 'amp_min_group', 'amp_max_group',
        )
    )
    meta['names'] = []
    meta['compnames'] = []
    meta['line_lambda'] = _empty(float)
    return meta
@staticmethod
def _require_posterior_bundle_fsps_metadata(state):
    """Extract and validate the FSPS settings recorded in a bundle state.

    Returns
    -------
    tuple
        ``(age_grid_gyr, logzsol_grid, dsps_ssp_fn)``. Both grids are
        tuples of floats.

    Raises
    ------
    ValueError
        If a required key is missing or ``None``, a grid is empty, or
        ``dsps_ssp_fn`` is not a non-empty string.
    """
    required = ("_fit_fsps_age_grid", "_fit_fsps_logzsol_grid", "_fit_dsps_ssp_fn")
    missing = [key for key in required if state.get(key) is None]
    if missing:
        raise ValueError(
            "Posterior bundle is missing required FSPS metadata for hydration: "
            f"{', '.join(missing)}."
        )

    def _as_float_tuple(raw):
        return tuple(np.asarray(raw, dtype=float).tolist())

    age_grid_gyr = _as_float_tuple(state["_fit_fsps_age_grid"])
    logzsol_grid = _as_float_tuple(state["_fit_fsps_logzsol_grid"])
    dsps_ssp_fn = state["_fit_dsps_ssp_fn"]
    if not age_grid_gyr or not logzsol_grid:
        raise ValueError("Posterior bundle FSPS metadata must define non-empty age and metallicity grids.")
    if not (isinstance(dsps_ssp_fn, str) and dsps_ssp_fn):
        raise ValueError("Posterior bundle FSPS metadata must include a non-empty dsps_ssp_fn.")
    return age_grid_gyr, logzsol_grid, dsps_ssp_fn
@staticmethod
def _validate_fsps_weights_shape(pred_out, expected_templates, context):
    """Check that ``pred_out['fsps_weights']`` is 2-D with the expected width.

    Returns the weights as a float ndarray. Raises ``ValueError`` (prefixed
    with ``context``) when the entry is absent, not 2-D, or its second axis
    does not equal ``expected_templates``.
    """
    has_weights = pred_out is not None and "fsps_weights" in pred_out
    if not has_weights:
        raise ValueError(f"{context} requires pred_out['fsps_weights'] to be present.")
    weights = np.asarray(pred_out["fsps_weights"], dtype=float)
    if weights.ndim != 2:
        raise ValueError(
            f"{context} requires pred_out['fsps_weights'] to be a 2D array; "
            f"got shape {weights.shape}."
        )
    found_width = weights.shape[1]
    if found_width != int(expected_templates):
        raise ValueError(
            f"{context} requires pred_out['fsps_weights'] width {expected_templates}, "
            f"got {found_width}."
        )
    return weights
def _ensure_hydrated_from_samples(self):
    """Rebuild posterior-derived component products from saved samples.

    Returns immediately when ``_posterior_hydrated`` is truthy. It also
    returns early, after setting that flag, when the derived attributes
    ``model_total``, ``f_conti_model``, ``f_line_model``, ``host`` and
    ``pred_bands`` already exist.

    Otherwise it replays ``qso_fsps_joint_model`` through
    ``numpyro.infer.Predictive``, using ``self.numpyro_samples`` and the fit
    options stored in the ``_fit_*`` attributes. It then checks the width of
    the resulting ``fsps_weights`` and passes the outputs to
    ``self._consume_posterior_outputs``.

    Raises
    ------
    RuntimeError
        If no posterior samples are available, ``wave``/``flux``/``err`` are
        missing, the wavelength grid is not 1-D with at least two points, or
        line fitting is enabled but neither a line table nor custom line
        components are available.
    ValueError
        From ``_require_posterior_bundle_fsps_metadata`` (missing or invalid
        FSPS metadata) or ``_validate_fsps_weights_shape`` (bad weight shape).
    """
    if bool(getattr(self, "_posterior_hydrated", False)):
        return
    # Fast path: the derived products are already attached to the object.
    has_cached = (
        hasattr(self, "model_total")
        and hasattr(self, "f_conti_model")
        and hasattr(self, "f_line_model")
        and hasattr(self, "host")
        and hasattr(self, "pred_bands")
    )
    if has_cached:
        self._posterior_hydrated = True
        return
    if not hasattr(self, "numpyro_samples") or self.numpyro_samples is None:
        raise RuntimeError("No posterior samples available for hydration.")
    if not hasattr(self, "wave") or not hasattr(self, "flux") or not hasattr(self, "err"):
        raise RuntimeError("Missing fitted spectrum context (wave/flux/err) for hydration.")
    wave = np.asarray(self.wave, dtype=float)
    flux = np.asarray(self.flux, dtype=float)
    err = np.asarray(self.err, dtype=float)
    if wave.ndim != 1 or wave.size < 2:
        raise RuntimeError("Invalid fitted wavelength grid for hydration.")
    # Use the stored prior config; if none was saved, build defaults from the fitted flux.
    prior_config = getattr(self, "_fit_prior_config", None)
    if prior_config is None:
        prior_config = build_default_prior_config(flux)
    custom_components = normalize_custom_components(getattr(self, "_fit_custom_components", ()))
    custom_line_components = normalize_custom_line_components(getattr(self, "_fit_custom_line_components", ()))
    # Presumably adds default prior entries for the custom components — see .custom_components.
    prior_config = inject_default_custom_component_priors(prior_config, flux, custom_components)
    prior_config = inject_default_custom_line_component_priors(prior_config, flux, custom_line_components)
    conti_priors = prior_config.get("conti_priors", {})
    use_lines = bool(getattr(self, "_fit_fit_lines", True))
    # Tied-line metadata comes from the line table embedded in the prior
    # config; without a table an empty payload is used.
    line_table = _extract_line_table_from_prior_config(prior_config)
    if line_table is not None:
        tied_line_meta = build_tied_line_meta_from_linelist(line_table, wave)
    else:
        tied_line_meta = self._empty_tied_line_meta()
    if use_lines and line_table is None and len(custom_line_components) == 0:
        raise RuntimeError("Hydration requires line priors/table when fit_lines=True.")
    age_grid_gyr, logzsol_grid, dsps_ssp_fn = self._require_posterior_bundle_fsps_metadata(self.__dict__)
    decompose_host = bool(getattr(self, "_fit_decompose_host", True))
    # When decompose_host is False this returns a zero-filled placeholder grid.
    fsps_grid = self._build_fsps_grid_for_fit(
        wave=wave,
        age_grid_gyr=age_grid_gyr,
        logzsol_grid=logzsol_grid,
        dsps_ssp_fn=dsps_ssp_fn,
        decompose_host=decompose_host,
    )
    self.tied_line_meta = tied_line_meta
    pred = Predictive(
        qso_fsps_joint_model,
        posterior_samples={k: jnp.asarray(v) for k, v in self.numpyro_samples.items()},
        return_sites=self._predictive_return_sites(
            custom_components=custom_components,
            custom_line_components=custom_line_components,
        ),
    )
    # Fixed seed, so repeated hydration of the same samples uses the same RNG key.
    rng_key = jax.random.PRNGKey(0)
    # NOTE(review): flux=None presumably leaves the likelihood unconditioned so
    # only predictive outputs are produced — confirm in qso_fsps_joint_model.
    pred_out = pred(
        rng_key,
        wave=wave,
        flux=None,
        err=err,
        conti_priors=conti_priors,
        tied_line_meta=tied_line_meta,
        fsps_grid=fsps_grid,
        fe_uv_wave=self.fe_uv_wave,
        fe_uv_flux=self.fe_uv_flux,
        fe_op_wave=self.fe_op_wave,
        fe_op_flux=self.fe_op_flux,
        use_lines=use_lines,
        prior_config=prior_config,
        decompose_host=decompose_host,
        fit_pl=bool(getattr(self, "_fit_fit_pl", True)),
        fit_fe=bool(getattr(self, "_fit_fit_fe", True)),
        fit_bc=bool(getattr(self, "_fit_fit_bc", True)),
        fit_poly=bool(getattr(self, "_fit_fit_poly", False)),
        fit_reddening=bool(getattr(self, "_fit_fit_reddening", False)),
        fit_poly_order=int(getattr(self, "_fit_fit_poly_order", 2)),
        z_qso=float(getattr(self, "z", 0.0)),
        psf_mags=getattr(self, "psf_mags", None),
        psf_mag_errs=getattr(self, "psf_mag_errs", None),
        psf_filter_curves=getattr(self, "psf_filter_curves", None),
        # Prefers _fit_use_psf_phot and falls back to use_psf_phot only when it
        # is absent. NOTE(review): if photometry was dropped during the fit
        # (psf_filter_curves None), confirm these two flags cannot disagree.
        use_psf_phot=bool(getattr(self, "_fit_use_psf_phot", getattr(self, "use_psf_phot", False))),
        custom_components=custom_components,
        custom_line_components=custom_line_components,
    )
    self._validate_fsps_weights_shape(
        pred_out,
        expected_templates=fsps_grid.templates.shape[1],
        context="Hydrated posterior state",
    )
    self._consume_posterior_outputs(
        samples=self.numpyro_samples,
        pred_out=pred_out,
        fsps_grid=fsps_grid,
        tied_line_meta=tied_line_meta,
        use_lines=use_lines,
        decompose_host=decompose_host,
    )
[docs]
def save_posterior_bundle(self, save_name=None, save_path=None):
    """Persist posterior samples plus minimal metadata for compact reloads.

    The HDF5 file holds a ``samples`` group with one gzip-compressed dataset
    per posterior site, and a ``meta`` group written by
    ``_write_hdf5_node``. The file is written to a temporary sibling path and
    then moved into place with ``os.replace``. A failure while serializing
    therefore never truncates or corrupts an existing bundle at the target
    path.

    Parameters
    ----------
    save_name : str or None, optional
        Bundle name. Defaults to ``{filename}_samples``; ``.h5`` is
        appended if missing.
    save_path : str or None, optional
        Target directory. Defaults to ``output_path``, else ``'.'``.

    Returns
    -------
    str
        Path of the written bundle.

    Raises
    ------
    RuntimeError
        If no posterior samples are available.
    TypeError
        If a metadata value cannot be serialized.
    """
    if not hasattr(self, "numpyro_samples") or self.numpyro_samples is None:
        raise RuntimeError("No posterior samples available. Run fit() before saving a posterior bundle.")
    meta = self._collect_sample_bundle_meta()
    out_file = self._posterior_bundle_path(save_name=save_name, save_path=save_path)
    # Same directory as the target, so os.replace stays on one filesystem.
    tmp_file = out_file + ".tmp"
    try:
        with h5py.File(tmp_file, "w") as h5f:
            h5f.attrs["posterior_bundle_format"] = "jaxqsofit_samples_meta_v1"
            samples_grp = h5f.create_group("samples")
            for name, draws in self.numpyro_samples.items():
                arr = np.asarray(draws)
                ds_kwargs = {}
                if arr.ndim > 0:
                    ds_kwargs["compression"] = "gzip"
                    ds_kwargs["shuffle"] = True
                samples_grp.create_dataset(str(name), data=arr, **ds_kwargs)
            meta_grp = h5f.create_group("meta")
            for key, value in meta.items():
                self._write_hdf5_node(meta_grp, str(key), value)
        # The file is closed at this point, so it is safe to move into place.
        os.replace(tmp_file, out_file)
    finally:
        # After a successful replace the temp file no longer exists. After a
        # failure, remove the partial file and let the exception propagate.
        if os.path.exists(tmp_file):
            try:
                os.remove(tmp_file)
            except OSError:
                pass
    print(f"Saved posterior bundle: {out_file}")
    return out_file
@staticmethod
def _build_fsps_grid_for_fit(wave, age_grid_gyr, logzsol_grid, dsps_ssp_fn, decompose_host):
    """Return the host template grid, or a zero-filled placeholder.

    With ``decompose_host`` set, this delegates to ``build_fsps_template_grid``.
    Otherwise it returns a lightweight object with the same attributes:

    - ``wave``.
    - ``templates``: all zeros, shape ``(len(wave), n_age * n_logz)``.
    - ``template_meta``: one entry per template, metallicity outer loop and
      age inner loop.
    - ``age_grid_gyr`` and ``logzsol_grid``.
    """
    if not decompose_host:
        class _DummyFSPSGrid:
            pass

        wave_arr = np.asarray(wave, dtype=float)
        ages = np.asarray(age_grid_gyr, dtype=float)
        metallicities = np.asarray(logzsol_grid, dtype=float)
        placeholder = _DummyFSPSGrid()
        placeholder.wave = wave_arr
        placeholder.templates = np.zeros(
            (len(wave_arr), int(len(ages) * len(metallicities))), dtype=float
        )
        placeholder.template_meta = [
            {
                'tage_gyr': float(age),
                'logzsol': float(logz),
                'norm': 1.0,
                'dsps_lgmet': np.nan,
                'dsps_lg_age_gyr': np.nan,
            }
            for logz in metallicities
            for age in ages
        ]
        placeholder.age_grid_gyr = ages
        placeholder.logzsol_grid = metallicities
        return placeholder
    return build_fsps_template_grid(
        wave_out=wave,
        age_grid_gyr=age_grid_gyr,
        logzsol_grid=logzsol_grid,
        dsps_ssp_fn=dsps_ssp_fn,
    )
[docs]
@classmethod
def load_from_samples(
    cls,
    filename=None,
    output_path=None,
    save_name=None,
    plot_fig=True,
    plot_diagnostics=True,
    kwargs_plot=None,
    diagnostics_kwargs=None,
):
    """Load a compressed HDF5 posterior bundle and return a QSOFit object.

    Parameters
    ----------
    filename : str or None, optional
        Object basename. ``{filename}_samples.h5`` is read from
        ``output_path`` when ``save_name`` is not given.
    output_path : str or None, optional
        Directory searched for the bundle (default ``'.'``). When given, it
        also replaces the ``output_path`` stored in the bundle on the
        returned object.
    save_name : str or None, optional
        Explicit bundle name (``.h5`` appended if missing). Takes precedence
        over ``filename``.
    plot_fig, plot_diagnostics : bool, optional
        Call ``plot_fig`` / ``plot_mcmc_diagnostics`` after hydration.
    kwargs_plot, diagnostics_kwargs : dict or None, optional
        Keyword arguments forwarded to those methods. ``show_plot``
        defaults to ``False`` for ``plot_fig``.

    Returns
    -------
    QSOFit
        Object with restored state and hydrated posterior products.

    Raises
    ------
    FileNotFoundError
        If the bundle does not exist, or zero or several candidate bundles
        are found when neither ``filename`` nor ``save_name`` is given.
    ValueError
        If the file follows an unsupported schema or lacks FSPS metadata.
    """
    if save_name is not None:
        bundle_name = cls._normalize_posterior_bundle_name(save_name)
        bundle_dir = '.' if output_path is None else output_path
        bundle_path = os.path.join(bundle_dir, bundle_name)
        resolved_name = cls._resolve_filename(filename=filename)
    elif filename is not None:
        resolved_name = cls._resolve_filename(filename=filename)
        bundle_name = cls._normalize_posterior_bundle_name(f"{resolved_name}_samples")
        bundle_dir = '.' if output_path is None else output_path
        bundle_path = os.path.join(bundle_dir, bundle_name)
    else:
        # No explicit name: require exactly one *_samples.h5 in the directory.
        bundle_dir = '.' if output_path is None else output_path
        matches = sorted(glob.glob(os.path.join(bundle_dir, f"*_samples{cls._POSTERIOR_BUNDLE_SUFFIX}")))
        if len(matches) == 0:
            raise FileNotFoundError(
                f"No compressed posterior bundle (*.h5) found under: {bundle_dir}. "
                "Pass filename=..., output_path=..., or save_name=... explicitly."
            )
        if len(matches) > 1:
            raise FileNotFoundError(
                f"Multiple compressed posterior bundles (*.h5) found under: {bundle_dir}. "
                "Pass filename=... or save_name=... explicitly."
            )
        bundle_path = matches[0]
        bundle_name = os.path.basename(bundle_path)
        suffix = f"_samples{cls._POSTERIOR_BUNDLE_SUFFIX}"
        resolved_name = bundle_name[: -len(suffix)] if bundle_name.endswith(suffix) else bundle_name
    if not os.path.exists(bundle_path):
        raise FileNotFoundError(f"Posterior bundle not found: {bundle_path}")
    with h5py.File(bundle_path, "r") as h5f:
        if "samples" in h5f and "meta" in h5f:
            samples = {k: np.asarray(h5f["samples"][k][()]) for k in h5f["samples"].keys()}
            meta = {k: cls._read_hdf5_node(h5f["meta"], k) for k in h5f["meta"].keys()}
            meta = cls._deserialize_from_hdf5(meta)
            cls._require_posterior_bundle_fsps_metadata(meta)
            state = dict(meta)
            state["numpyro_samples"] = samples
            state["_posterior_hydrated"] = False
        elif "state" in h5f:
            # Backward-compatible read for older .h5 bundles.
            state = cls._read_hdf5_node(h5f, "state")
            state = cls._deserialize_from_hdf5(state)
        else:
            raise ValueError(f"Unsupported posterior bundle schema: {bundle_path}")
    obj = cls(
        lam=state["lam_in"],
        flux=state["flux_in"],
        err=state.get("err_in"),
        z=state.get("z", 0.0),
        ra=state.get("ra", -999),
        dec=state.get("dec", -999),
        filename=state.get("filename", resolved_name),
        output_path=output_path if output_path is not None else state.get("output_path"),
        wdisp=state.get("wdisp"),
        psf_mags=state.get("psf_mags_raw", state.get("psf_mags")),
        psf_mag_errs=state.get("psf_mag_errs_raw", state.get("psf_mag_errs")),
        psf_bands=state.get("psf_bands"),
    )
    obj.__dict__.update(state)
    if output_path is not None:
        # The restored state carries the output_path used at save time. An
        # explicit argument must win, as it does in the constructor call above.
        obj.output_path = output_path
    obj._resumed_from_samples = True
    obj.install_path = os.path.dirname(os.path.abspath(__file__))
    if not hasattr(obj, "verbose"):
        obj.verbose = False
    if not hasattr(obj, "save_fig"):
        obj.save_fig = False
    if not hasattr(obj, "SN_ratio_conti"):
        obj.SN_ratio_conti = np.nan
    obj._ensure_hydrated_from_samples()
    if plot_fig:
        plot_kwargs = {} if kwargs_plot is None else dict(kwargs_plot)
        if "show_plot" not in plot_kwargs:
            plot_kwargs["show_plot"] = False
        obj.plot_fig(**plot_kwargs)
    if plot_diagnostics:
        diag_kwargs = {} if diagnostics_kwargs is None else dict(diagnostics_kwargs)
        obj.plot_mcmc_diagnostics(**diag_kwargs)
    return obj
[docs]
def fit(self, name=None, deredden=True,
        wave_range=None, wave_mask=None, save_fits_name=None,
        fit_lines=True, save_result=True, plot_fig=True, save_fig=True,
        show_plot=False,
        decompose_host=True,
        fit_pl=True,
        fit_fe=True,
        fit_bc=False,
        fit_bal=False,
        fit_poly=True,
        fit_reddening=True,
        fit_poly_order=2,
        mask_lya_forest=True,
        fit_method='optax+nuts',
        verbose=True,
        fsps_age_grid=(0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0),
        fsps_logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
        prior_config=None,
        dsps_ssp_fn='tempdata.h5',
        nuts_warmup=50,
        nuts_samples=50,
        nuts_chains=1,
        nuts_target_accept=0.9,
        optax_steps=600,
        optax_lr=1e-2,
        psf_mags=None,
        psf_mag_errs=None,
        psf_bands=None,
        use_psf_phot=False,
        custom_components=None,
        custom_line_components=None,
        kwargs_plot=None):
    """Run end-to-end preprocessing, fitting, and optional plotting/saving.

    Parameters
    ----------
    name : str or None, optional
        Optional override basename used for figure/output naming.
    deredden : bool, optional
        If True, apply Galactic dereddening with dustmaps SFD.
    wave_range : tuple[float, float] or None, optional
        Rest-frame wavelength range to keep.
    wave_mask : array-like or None, optional
        Rest-frame wavelength intervals to mask.
    save_fits_name : str or None, optional
        Basename for saved result table. Defaults to ``self.filename``.
    fit_lines : bool, optional
        Enable/disable emission-line model.
    save_result : bool, optional
        If True, write summary results via :meth:`save_result` and then
        write the posterior bundle via :meth:`save_posterior_bundle`.
    plot_fig : bool, optional
        If True, render decomposition figure.
    save_fig : bool, optional
        If True, save rendered figures.
    show_plot : bool, optional
        If True, call ``plt.show()`` for the decomposition figure.
        Default is ``False`` to avoid interactive pop-ups in notebook/pipeline runs.
        Ignored when ``kwargs_plot`` already contains ``'show_plot'``.
    decompose_host : bool, optional
        Enable/disable host SPS decomposition.
    fit_pl : bool, optional
        Enable/disable AGN power-law component.
    fit_fe : bool, optional
        Enable/disable FeII components.
    fit_bc : bool, optional
        Enable/disable Balmer continuum.
    fit_bal : bool, optional
        If True, append built-in BAL trough custom components for
        N V, Si IV, C IV, C III], FeLoBAL troughs near 2000/2200 A,
        and Mg II.
    fit_poly : bool, optional
        Enable/disable multiplicative polynomial tilt.
    fit_poly_order : int, optional
        Polynomial order for the multiplicative continuum tilt. Uses
        coefficients ``poly_c1`` through ``poly_cN``.
    fit_reddening : bool, optional
        If True, apply a built-in smooth SMC-like reddening curve to the
        AGN power-law continuum.
    mask_lya_forest : bool, optional
        If True, mask pixels with rest-frame wavelength below Ly-alpha
        (1215.67 Angstrom) before fitting.
    fit_method : {'nuts', 'optax', 'optax+nuts'}, optional
        Fitting backend. ``'optax'`` runs staged MAP optimization with
        Optax/Adam and stores a point estimate, but does not produce
        posterior samples. ``'nuts'`` runs NumPyro NUTS directly and stores
        posterior samples. ``'optax+nuts'`` first runs the staged Optax
        warm start and then initializes NUTS from that solution; this is
        the default and is usually the most robust posterior-fitting mode.
    verbose : bool, optional
        Verbose optimizer output where applicable.
    fsps_age_grid : sequence of float, optional
        SSP age grid in Gyr.
    fsps_logzsol_grid : sequence of float, optional
        SSP metallicity grid in log(Z/Zsun).
    prior_config : dict or None, optional
        Prior/config dictionary. If None, defaults are auto-built. A
        supplied dict is shallow-copied before ``PL_pivot`` is filled in,
        so the caller's top-level dict is not modified by this method.
    dsps_ssp_fn : str, optional
        Path to DSPS SSP HDF5 template file.
    nuts_warmup, nuts_samples : int, optional
        NUTS warmup and posterior sample counts.
    nuts_chains : int, optional
        Number of MCMC chains.
    nuts_target_accept : float, optional
        Target accept probability for NUTS.
    optax_steps : int, optional
        Number of SVI/Optax warm-start steps.
    optax_lr : float, optional
        Learning rate for Optax Adam.
    psf_mags, psf_mag_errs : array-like or None, optional
        Optional PSF-aperture photometry magnitudes and 1-sigma errors.
    psf_bands : sequence of str or None, optional
        Band labels for ``psf_mags``. Defaults to SDSS ``ugriz`` order.
    use_psf_phot : bool, optional
        If True, add a PSF-photometry likelihood term and infer PSF/fiber
        scaling plus host leakage.
    custom_components : sequence[CustomComponentSpec] or None, optional
        Optional additive continuum components. Use
        ``make_custom_component`` for general user-defined components or
        ``make_template_component`` for a template convenience wrapper.
    custom_line_components : sequence[CustomLineComponentSpec] or None, optional
        Optional additive emission-line components. Each component provides
        its own evaluator and is tagged as ``broad`` or ``narrow``.
    kwargs_plot : dict or None, optional
        Extra keyword arguments passed to :meth:`plot_fig`. The dict is
        copied; ``show_plot`` is only added to the copy.

    Raises
    ------
    ValueError
        If ``fit_method`` is not ``'nuts'``, ``'optax'`` or ``'optax+nuts'``.
        This is checked first, before any file I/O, preprocessing, or
        modification of ``self``. ``ValueError`` raised by the fitting
        backends (e.g. ``fit_lines=True`` without any line model) propagates.
    """
    # Fail fast: an unknown backend used to be detected only after template
    # loading, dereddening and PSF preparation had already run.
    if fit_method not in ('nuts', 'optax', 'optax+nuts'):
        raise ValueError(f"Unknown fit_method='{fit_method}'. Use 'nuts', 'optax', or 'optax+nuts'.")

    # Work on a copy so the caller's dict is never mutated.
    kwargs_plot = {} if kwargs_plot is None else dict(kwargs_plot)
    kwargs_plot.setdefault('show_plot', show_plot)

    # Persist fit configuration so posterior reconstructions can be built on
    # alternate wavelength grids after fitting.
    self._fit_deredden = bool(deredden)
    self._fit_decompose_host = bool(decompose_host)
    self._fit_fit_lines = bool(fit_lines)
    self._fit_fit_pl = bool(fit_pl)
    self._fit_fit_fe = bool(fit_fe)
    self._fit_fit_bc = bool(fit_bc)
    self._fit_fit_bal = bool(fit_bal)
    self._fit_fit_poly = bool(fit_poly)
    self._fit_fit_reddening = bool(fit_reddening)
    self._fit_fit_poly_order = int(fit_poly_order)
    self._fit_mask_lya_forest = bool(mask_lya_forest)
    self._fit_method = str(fit_method)
    self._fit_fsps_age_grid = tuple(fsps_age_grid)
    self._fit_fsps_logzsol_grid = tuple(fsps_logzsol_grid)
    self._fit_prior_config = prior_config
    self._fit_dsps_ssp_fn = str(dsps_ssp_fn)
    self._fit_use_psf_phot = bool(use_psf_phot)
    requested_custom_components = normalize_custom_components(custom_components)
    self._fit_custom_components = requested_custom_components
    self._fit_custom_line_components = normalize_custom_line_components(custom_line_components)
    self.wave_range = wave_range
    self.wave_mask = wave_mask
    self.linefit = fit_lines
    self.save_fig = save_fig
    self.verbose = verbose
    if name is not None and str(name).strip() != "":
        self.filename = str(name).strip()

    prior_config_input = prior_config
    # Shallow copy: "PL_pivot" is written into this dict further down, which
    # previously leaked into (and was then reused from) the caller's config.
    prior_config = {} if prior_config is None else dict(prior_config)

    # Fe II templates; first column is log10(wavelength).
    data_dir = os.path.join(self.install_path, 'data')
    self.fe_uv = np.genfromtxt(os.path.join(data_dir, 'fe_uv.txt'))
    self.fe_op = np.genfromtxt(os.path.join(data_dir, 'fe_optical.txt'))
    self.fe_uv_wave = 10 ** self.fe_uv[:, 0]
    # Normalize non-negative template amplitudes to O(1) so Fe norms are in data-flux units.
    self.fe_uv_flux = _normalize_template_flux(np.maximum(self.fe_uv[:, 1], 0.0), target_amp=1.0)
    fe_op_wave = 10 ** self.fe_op[:, 0]
    fe_op_flux = _normalize_template_flux(np.maximum(self.fe_op[:, 1], 0.0), target_amp=1.0)
    in_optical_window = (fe_op_wave > 3686.) & (fe_op_wave < 7484.)
    self.fe_op_wave = fe_op_wave[in_optical_window]
    self.fe_op_flux = fe_op_flux[in_optical_window]

    if save_fits_name is None:
        save_fits_name = self.filename

    # Keep pixels with a positive finite error and a finite, non-zero flux.
    ind_gooderror = (
        (self.err_in > 0)
        & np.isfinite(self.err_in)
        & (self.flux_in != 0)
        & np.isfinite(self.flux_in)
    )
    self.err = self.err_in[ind_gooderror]
    self.flux = self.flux_in[ind_gooderror]
    self.lam = self.lam_in[ind_gooderror]
    if wave_range is not None:
        self._wave_trim(self.lam, self.flux, self.err, self.z)
    if wave_mask is not None:
        self._wave_msk(self.lam, self.flux, self.err, self.z)
    if mask_lya_forest:
        self._mask_lya_forest(self.lam, self.flux, self.err, self.z)
    if deredden:
        self._validate_deredden_coordinates(self.ra, self.dec)
        self._de_redden(self.lam, self.flux, self.err, self.ra, self.dec)
    self._rest_frame(self.lam, self.flux, self.err, self.z)
    self._calculate_sn(self.wave, self.flux)
    self._orignial_spec(self.wave, self.flux, self.err)

    bal_components = build_default_bal_components(self.flux) if bool(fit_bal) else ()
    self._fit_custom_components = normalize_custom_components(
        tuple(requested_custom_components) + tuple(bal_components)
    )
    if prior_config_input is None:
        prior_config = build_default_prior_config(self.flux)
    prior_config = inject_default_custom_component_priors(
        prior_config=prior_config,
        flux=self.flux,
        custom_components=self._fit_custom_components,
    )
    prior_config = inject_default_custom_line_component_priors(
        prior_config=prior_config,
        flux=self.flux,
        custom_line_components=self._fit_custom_line_components,
    )
    out_params = prior_config.get('out_params', {})
    self.L_conti_wave = np.asarray(out_params.get('cont_loc', []), dtype=float)
    pl_pivot = prior_config.get("PL_pivot", None)
    if pl_pivot is None:
        pl_pivot = _spectrum_center_pivot(self.wave)
    prior_config["PL_pivot"] = float(np.asarray(pl_pivot, dtype=float))
    self._fit_prior_config = prior_config

    psf_mags_use, psf_mag_errs_use, _psf_bands_use, psf_filter_curves_use, use_psf_phot_use = self._prepare_psf_photometry(
        wave_obs=self.lam,
        psf_mags=psf_mags,
        psf_mag_errs=psf_mag_errs,
        psf_bands=psf_bands,
        use_psf_phot=use_psf_phot,
    )

    # Model/configuration arguments common to every backend.
    shared_fit_kwargs = dict(
        age_grid_gyr=fsps_age_grid,
        logzsol_grid=fsps_logzsol_grid,
        prior_config=prior_config,
        dsps_ssp_fn=dsps_ssp_fn,
        use_lines=fit_lines,
        decompose_host=decompose_host,
        fit_pl=fit_pl,
        fit_fe=fit_fe,
        fit_bc=fit_bc,
        fit_poly=fit_poly,
        fit_reddening=fit_reddening,
        fit_poly_order=fit_poly_order,
        psf_mags=psf_mags_use,
        psf_mag_errs=psf_mag_errs_use,
        psf_filter_curves=psf_filter_curves_use,
        use_psf_phot=use_psf_phot_use,
        custom_components=self._fit_custom_components,
        custom_line_components=self._fit_custom_line_components,
    )
    if fit_method == 'nuts':
        self.run_fsps_numpyro_fit(
            num_warmup=nuts_warmup,
            num_samples=nuts_samples,
            num_chains=nuts_chains,
            target_accept_prob=nuts_target_accept,
            **shared_fit_kwargs,
        )
    elif fit_method == 'optax':
        self.run_fsps_optax_fit(
            num_steps=optax_steps,
            learning_rate=optax_lr,
            **shared_fit_kwargs,
        )
    else:  # 'optax+nuts' -- the only remaining value after entry validation.
        self.run_fsps_optax_nuts_fit(
            optax_steps=optax_steps,
            optax_learning_rate=optax_lr,
            num_warmup=nuts_warmup,
            num_samples=nuts_samples,
            num_chains=nuts_chains,
            target_accept_prob=nuts_target_accept,
            **shared_fit_kwargs,
        )

    if save_result:
        self.save_result(self.conti_result, self.conti_result_type, self.conti_result_name,
                         self.line_result, self.line_result_type, self.line_result_name,
                         save_fits_name)
        # NOTE(review): indentation was lost in the copy this was rebuilt
        # from; the bundle save is assumed to be gated by save_result.
        self.save_posterior_bundle()
    if plot_fig:
        self.plot_fig(**kwargs_plot)
[docs]
def run_fsps_numpyro_fit(self, num_warmup=500, num_samples=1000, num_chains=1,
                         target_accept_prob=0.9,
                         age_grid_gyr=(0.1, 0.3, 1.0, 3.0, 10.0),
                         logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
                         prior_config=None,
                         dsps_ssp_fn='tempdata.h5',
                         use_lines=True,
                         decompose_host=True,
                         fit_pl=True,
                         fit_fe=True,
                         fit_bc=True,
                         fit_poly=False,
                         fit_reddening=False,
                         fit_poly_order=2,
                         psf_mags=None,
                         psf_mag_errs=None,
                         psf_filter_curves=None,
                         use_psf_phot=False,
                         custom_components=None,
                         custom_line_components=None,
                         init_values=None):
    """Fit the full model using NUTS MCMC and store posterior summaries.

    Parameters
    ----------
    num_warmup, num_samples : int, optional
        MCMC warmup and posterior sample counts.
    num_chains : int, optional
        Number of MCMC chains.
    target_accept_prob : float, optional
        Target acceptance probability for NUTS.
    age_grid_gyr : sequence of float, optional
        SSP age grid in Gyr.
    logzsol_grid : sequence of float, optional
        SSP metallicity grid in log(Z/Zsun).
    prior_config : dict or None, optional
        Prior/config dictionary for model blocks. If None, a default config
        is built from ``self.flux``; default priors for the custom
        components are then injected.
    dsps_ssp_fn : str, optional
        DSPS SSP template HDF5 path.
    use_lines, decompose_host, fit_pl, fit_fe, fit_bc, fit_poly, fit_reddening : bool, optional
        Component toggles for model blocks.
    fit_poly_order : int, optional
        Polynomial order for the multiplicative continuum tilt.
    psf_mags, psf_mag_errs, psf_filter_curves : optional
        PSF-photometry inputs forwarded unchanged to the model.
    use_psf_phot : bool, optional
        Forwarded to the model to enable the PSF-photometry term.
    custom_components, custom_line_components : sequence or None, optional
        Additive continuum / emission-line component specs; normalized
        before use.
    init_values : dict or None, optional
        Optional initial values for ``init_to_value``. Defaults to
        ``{'gal_v_kms': 0.0, 'gal_sigma_kms': 150.0}``.

    Raises
    ------
    ValueError
        If ``use_lines`` is True but there is neither a line table in
        ``prior_config`` nor any custom line component.

    Notes
    -----
    Sets ``self.tied_line_meta`` and ``self.numpyro_mcmc``, then hands the
    samples and predictive draws to ``_consume_posterior_outputs``.
    """
    wave = np.asarray(self.wave, dtype=float)
    flux = np.asarray(self.flux, dtype=float)
    err = np.asarray(self.err, dtype=float)
    custom_components = normalize_custom_components(custom_components)
    custom_line_components = normalize_custom_line_components(custom_line_components)
    if prior_config is None:
        prior_config = build_default_prior_config(flux)
    prior_config = inject_default_custom_component_priors(prior_config, flux, custom_components)
    prior_config = inject_default_custom_line_component_priors(prior_config, flux, custom_line_components)
    conti_priors = prior_config.get('conti_priors', {})
    line_table = _extract_line_table_from_prior_config(prior_config)
    if use_lines and line_table is None and len(custom_line_components) == 0:
        raise ValueError(
            "fit_lines=True requires either line priors/table in prior_config "
            "or at least one custom_line_component."
        )
    if line_table is not None:
        tied_line_meta = build_tied_line_meta_from_linelist(line_table, wave)
    else:
        # Empty line metadata: the model sees zero tied lines.
        tied_line_meta = {
            'n_lines': 0,
            'n_vgroups': 0,
            'n_wgroups': 0,
            'n_fgroups': 0,
            'ln_lambda0': _np_to_jnp(np.array([], dtype=float)),
            'vgroup': np.array([], dtype=int),
            'wgroup': np.array([], dtype=int),
            'fgroup': np.array([], dtype=int),
            'flux_ratio': np.array([], dtype=float),
            'dmu_init_group': np.array([], dtype=float),
            'dmu_min_group': np.array([], dtype=float),
            'dmu_max_group': np.array([], dtype=float),
            'sig_init_group': np.array([], dtype=float),
            'sig_min_group': np.array([], dtype=float),
            'sig_max_group': np.array([], dtype=float),
            'amp_init_group': np.array([], dtype=float),
            'amp_min_group': np.array([], dtype=float),
            'amp_max_group': np.array([], dtype=float),
            'names': [],
            'compnames': [],
            'line_lambda': np.array([], dtype=float),
        }
    fsps_grid = self._build_fsps_grid_for_fit(
        wave=wave,
        age_grid_gyr=age_grid_gyr,
        logzsol_grid=logzsol_grid,
        dsps_ssp_fn=dsps_ssp_fn,
        decompose_host=decompose_host,
    )
    self.tied_line_meta = tied_line_meta

    # Model arguments shared by the NUTS run and the posterior-predictive
    # pass; keeping them in one place guarantees both see the same model.
    model_kwargs = dict(
        wave=wave,
        err=err,
        conti_priors=conti_priors,
        tied_line_meta=tied_line_meta,
        fsps_grid=fsps_grid,
        fe_uv_wave=self.fe_uv_wave,
        fe_uv_flux=self.fe_uv_flux,
        fe_op_wave=self.fe_op_wave,
        fe_op_flux=self.fe_op_flux,
        use_lines=use_lines,
        prior_config=prior_config,
        decompose_host=decompose_host,
        fit_pl=fit_pl,
        fit_fe=fit_fe,
        fit_bc=fit_bc,
        fit_poly=fit_poly,
        fit_reddening=fit_reddening,
        fit_poly_order=fit_poly_order,
        z_qso=self.z,
        psf_mags=psf_mags,
        psf_mag_errs=psf_mag_errs,
        psf_filter_curves=psf_filter_curves,
        use_psf_phot=use_psf_phot,
        custom_components=custom_components,
        custom_line_components=custom_line_components,
    )

    init_vals = {'gal_v_kms': 0.0, 'gal_sigma_kms': 150.0} if init_values is None else init_values
    init_strategy = init_to_value(values=init_vals)
    kernel = NUTS(qso_fsps_joint_model, init_strategy=init_strategy, target_accept_prob=target_accept_prob, dense_mass=True, max_tree_depth=8)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, progress_bar=True, jit_model_args=False)
    rng_key = jax.random.PRNGKey(0)
    mcmc.run(
        rng_key,
        flux=flux,
        return_line_components=False,
        emit_deterministics=False,
        **model_kwargs,
    )
    samples = mcmc.get_samples()

    pred = Predictive(
        qso_fsps_joint_model,
        posterior_samples=samples,
        return_sites=self._predictive_return_sites(custom_components=custom_components, custom_line_components=custom_line_components),
    )
    # Derive a distinct key instead of re-consuming the MCMC key; JAX PRNG
    # keys must not be reused. The MCMC draws themselves are unchanged.
    pred_key = jax.random.fold_in(rng_key, 1)
    # flux=None: the observation site is not conditioned during prediction.
    pred_out = pred(pred_key, flux=None, **model_kwargs)

    self.numpyro_mcmc = mcmc
    self._consume_posterior_outputs(
        samples=samples,
        pred_out=pred_out,
        fsps_grid=fsps_grid,
        tied_line_meta=tied_line_meta,
        use_lines=use_lines,
        decompose_host=decompose_host,
    )
[docs]
def run_fsps_optax_fit(self, num_steps=2000, learning_rate=1e-2,
                       age_grid_gyr=(0.1, 0.3, 1.0, 3.0, 10.0),
                       logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
                       prior_config=None,
                       dsps_ssp_fn='tempdata.h5',
                       use_lines=True,
                       decompose_host=True,
                       fit_pl=True,
                       fit_fe=True,
                       fit_bc=True,
                       fit_poly=False,
                       fit_reddening=False,
                       fit_poly_order=2,
                       psf_mags=None,
                       psf_mag_errs=None,
                       psf_filter_curves=None,
                       use_psf_phot=False,
                       custom_components=None,
                       custom_line_components=None):
    """Fit a MAP approximation using staged SVI with an Optax optimizer.

    Stage 1 fits a reduced model (no lines, Fe II, Balmer continuum,
    polynomial tilt or reddening) with an ``AutoDelta`` guide. Stage 2 fits
    the full requested model, initialized from the stage-1 MAP point.

    Parameters
    ----------
    num_steps : int, optional
        SVI step budget. Stage 1 runs ``max(100, num_steps // 3)`` steps and
        stage 2 runs ``max(100, num_steps - n1)``. Because each stage is
        floored at 100 steps, the total exceeds ``num_steps`` when
        ``num_steps < 200``.
    learning_rate : float, optional
        Adam learning rate (used for both stages).
    age_grid_gyr : sequence of float, optional
        SSP age grid in Gyr.
    logzsol_grid : sequence of float, optional
        SSP metallicity grid in log(Z/Zsun).
    prior_config : dict or None, optional
        Prior/config dictionary for model blocks. If None, a default config
        is built from ``self.flux``; default priors for the custom
        components are then injected.
    dsps_ssp_fn : str, optional
        DSPS SSP template HDF5 path.
    use_lines, decompose_host, fit_pl, fit_fe, fit_bc, fit_poly, fit_reddening : bool, optional
        Component toggles for model blocks.
    fit_poly_order : int, optional
        Polynomial order for the multiplicative continuum tilt.
    psf_mags, psf_mag_errs, psf_filter_curves : optional
        PSF-photometry inputs forwarded unchanged to the model.
    use_psf_phot : bool, optional
        Forwarded to the model to enable the PSF-photometry term.
    custom_components, custom_line_components : sequence or None, optional
        Additive continuum / emission-line component specs; normalized
        before use.

    Raises
    ------
    ValueError
        If ``use_lines`` is True but there is neither a line table in
        ``prior_config`` nor any custom line component.

    Notes
    -----
    Sets ``self.svi``, ``self.svi_state``, ``self.svi_params``,
    ``self.optax_losses`` (stage-1 and stage-2 losses concatenated),
    ``self.optax_map_point`` and ``self.numpyro_mcmc = None``. The MAP point
    is stored as a single-draw sample set and passed to
    ``_consume_posterior_outputs``.
    """
    wave = np.asarray(self.wave, dtype=float)
    flux = np.asarray(self.flux, dtype=float)
    err = np.asarray(self.err, dtype=float)
    custom_components = normalize_custom_components(custom_components)
    custom_line_components = normalize_custom_line_components(custom_line_components)
    if prior_config is None:
        prior_config = build_default_prior_config(flux)
    prior_config = inject_default_custom_component_priors(prior_config, flux, custom_components)
    prior_config = inject_default_custom_line_component_priors(prior_config, flux, custom_line_components)
    conti_priors = prior_config.get('conti_priors', {})
    line_table = _extract_line_table_from_prior_config(prior_config)
    if use_lines and line_table is None and len(custom_line_components) == 0:
        raise ValueError(
            "fit_lines=True requires either line priors/table in prior_config "
            "or at least one custom_line_component."
        )
    if line_table is not None:
        tied_line_meta = build_tied_line_meta_from_linelist(line_table, wave)
    else:
        # Empty line metadata: the model sees zero tied lines.
        tied_line_meta = {
            'n_lines': 0,
            'n_vgroups': 0,
            'n_wgroups': 0,
            'n_fgroups': 0,
            'ln_lambda0': _np_to_jnp(np.array([], dtype=float)),
            'vgroup': np.array([], dtype=int),
            'wgroup': np.array([], dtype=int),
            'fgroup': np.array([], dtype=int),
            'flux_ratio': np.array([], dtype=float),
            'dmu_init_group': np.array([], dtype=float),
            'dmu_min_group': np.array([], dtype=float),
            'dmu_max_group': np.array([], dtype=float),
            'sig_init_group': np.array([], dtype=float),
            'sig_min_group': np.array([], dtype=float),
            'sig_max_group': np.array([], dtype=float),
            'amp_init_group': np.array([], dtype=float),
            'amp_min_group': np.array([], dtype=float),
            'amp_max_group': np.array([], dtype=float),
            'names': [],
            'compnames': [],
            'line_lambda': np.array([], dtype=float),
        }
    fsps_grid = self._build_fsps_grid_for_fit(
        wave=wave,
        age_grid_gyr=age_grid_gyr,
        logzsol_grid=logzsol_grid,
        dsps_ssp_fn=dsps_ssp_fn,
        decompose_host=decompose_host,
    )
    self.tied_line_meta = tied_line_meta

    # Model arguments that do not change between stages or at prediction time.
    base_kwargs = dict(
        wave=wave,
        err=err,
        conti_priors=conti_priors,
        tied_line_meta=tied_line_meta,
        fsps_grid=fsps_grid,
        fe_uv_wave=self.fe_uv_wave,
        fe_uv_flux=self.fe_uv_flux,
        fe_op_wave=self.fe_op_wave,
        fe_op_flux=self.fe_op_flux,
        prior_config=prior_config,
        z_qso=self.z,
        psf_mags=psf_mags,
        psf_mag_errs=psf_mag_errs,
        psf_filter_curves=psf_filter_curves,
        use_psf_phot=use_psf_phot,
        custom_components=custom_components,
        custom_line_components=custom_line_components,
    )
    # Component toggles of the full requested model (stage 2 + prediction).
    full_toggles = dict(
        use_lines=use_lines,
        decompose_host=decompose_host,
        fit_pl=fit_pl,
        fit_fe=fit_fe,
        fit_bc=fit_bc,
        fit_poly=fit_poly,
        fit_reddening=fit_reddening,
        fit_poly_order=fit_poly_order,
    )

    def _run_svi(guide, steps, toggles):
        """Run one SVI stage with the given component toggles; return (svi, result)."""
        optimizer = optax_to_numpyro(optax.adam(learning_rate))
        svi = SVI(qso_fsps_joint_model, guide, optimizer, loss=Trace_ELBO())
        result = svi.run(
            jax.random.PRNGKey(0),
            int(steps),
            flux=flux,
            return_line_components=False,
            emit_deterministics=False,
            progress_bar=self.verbose,
            **base_kwargs,
            **toggles,
        )
        return svi, result

    # Stage 1: warm start on simpler landscape (continuum/host only).
    n1 = max(100, int(num_steps // 3))
    guide1 = AutoDelta(
        qso_fsps_joint_model,
        init_loc_fn=init_to_value(values={'gal_v_kms': 0.0, 'gal_sigma_kms': 150.0}),
    )
    stage1_toggles = dict(
        use_lines=False,
        decompose_host=decompose_host,
        fit_pl=fit_pl,
        fit_fe=False,
        fit_bc=False,
        fit_poly=False,
        fit_reddening=False,
        fit_poly_order=2,
    )
    _, res1 = _run_svi(guide1, n1, stage1_toggles)
    map1 = guide1.median(res1.params)

    # Stage 2: full model initialized from stage-1 MAP for overlapping parameters.
    n2 = max(100, int(num_steps - n1))
    guide2 = AutoDelta(
        qso_fsps_joint_model,
        init_loc_fn=init_to_value(values=map1),
    )
    svi, res2 = _run_svi(guide2, n2, full_toggles)

    svi_state = res2.state
    svi_params = res2.params
    losses = np.concatenate([np.asarray(res1.losses), np.asarray(res2.losses)])
    map_point = guide2.median(svi_params)
    # Present the MAP point as a one-draw "posterior" (leading sample axis of 1).
    samples = {k: np.asarray(v)[None, ...] for k, v in map_point.items()}

    pred = Predictive(
        qso_fsps_joint_model,
        posterior_samples={k: jnp.asarray(v) for k, v in samples.items()},
        return_sites=self._predictive_return_sites(custom_components=custom_components, custom_line_components=custom_line_components),
    )
    # Distinct key from the one the SVI stages consumed.
    pred_key = jax.random.fold_in(jax.random.PRNGKey(0), 1)
    # flux=None: the observation site is not conditioned during prediction.
    pred_out = pred(pred_key, flux=None, **base_kwargs, **full_toggles)

    self.numpyro_mcmc = None
    self.svi = svi
    self.svi_state = svi_state
    self.svi_params = svi_params
    self.optax_losses = losses
    self.optax_map_point = map_point
    self._consume_posterior_outputs(
        samples=samples,
        pred_out=pred_out,
        fsps_grid=fsps_grid,
        tied_line_meta=tied_line_meta,
        use_lines=use_lines,
        decompose_host=decompose_host,
    )
[docs]
def run_fsps_optax_nuts_fit(self, optax_steps=2000, optax_learning_rate=1e-2,
                            num_warmup=500, num_samples=1000, num_chains=1,
                            target_accept_prob=0.9,
                            age_grid_gyr=(0.1, 0.3, 1.0, 3.0, 10.0),
                            logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
                            prior_config=None,
                            dsps_ssp_fn='tempdata.h5',
                            use_lines=True,
                            decompose_host=True,
                            fit_pl=True,
                            fit_fe=True,
                            fit_bc=True,
                            fit_poly=False,
                            fit_reddening=False,
                            fit_poly_order=2,
                            psf_mags=None,
                            psf_mag_errs=None,
                            psf_filter_curves=None,
                            use_psf_phot=False,
                            custom_components=None,
                            custom_line_components=None):
    """Run an Optax MAP warm start, then NUTS initialized from that MAP point.

    Parameters
    ----------
    optax_steps : int, optional
        Step budget for the SVI/Optax warm start (``run_fsps_optax_fit``).
    optax_learning_rate : float, optional
        Adam learning rate for the warm start.
    num_warmup, num_samples : int, optional
        NUTS warmup and posterior sample counts.
    num_chains : int, optional
        Number of MCMC chains.
    target_accept_prob : float, optional
        Target acceptance probability for NUTS.
    age_grid_gyr : sequence of float, optional
        SSP age grid in Gyr.
    logzsol_grid : sequence of float, optional
        SSP metallicity grid in log(Z/Zsun).
    prior_config : dict or None, optional
        Prior/config dictionary for model blocks, passed to both stages.
    dsps_ssp_fn : str, optional
        DSPS SSP template HDF5 path.
    use_lines, decompose_host, fit_pl, fit_fe, fit_bc, fit_poly, fit_reddening : bool, optional
        Component toggles for model blocks, passed to both stages.
    fit_poly_order : int, optional
        Polynomial order for the multiplicative continuum tilt.
    psf_mags, psf_mag_errs, psf_filter_curves, use_psf_phot : optional
        PSF-photometry inputs, passed to both stages.
    custom_components, custom_line_components : sequence or None, optional
        Custom component specs, passed to both stages.

    Notes
    -----
    The NUTS stage receives ``self.optax_map_point`` (or None if that
    attribute is absent) as ``init_values``.
    """
    # Everything except the optimizer/sampler budgets is identical for both
    # stages, so forward one shared set of keyword arguments.
    shared = {
        'age_grid_gyr': age_grid_gyr,
        'logzsol_grid': logzsol_grid,
        'prior_config': prior_config,
        'dsps_ssp_fn': dsps_ssp_fn,
        'use_lines': use_lines,
        'decompose_host': decompose_host,
        'fit_pl': fit_pl,
        'fit_fe': fit_fe,
        'fit_bc': fit_bc,
        'fit_poly': fit_poly,
        'fit_reddening': fit_reddening,
        'fit_poly_order': fit_poly_order,
        'psf_mags': psf_mags,
        'psf_mag_errs': psf_mag_errs,
        'psf_filter_curves': psf_filter_curves,
        'use_psf_phot': use_psf_phot,
        'custom_components': custom_components,
        'custom_line_components': custom_line_components,
    }

    # Stage 1: staged SVI/Optax MAP estimate.
    self.run_fsps_optax_fit(
        num_steps=optax_steps,
        learning_rate=optax_learning_rate,
        **shared,
    )

    # Stage 2: NUTS seeded with the MAP point produced by stage 1.
    warm_start = getattr(self, 'optax_map_point', None)
    self.run_fsps_numpyro_fit(
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        target_accept_prob=target_accept_prob,
        init_values=warm_start,
        **shared,
    )
def _consume_posterior_outputs(self, samples, pred_out, fsps_grid, tied_line_meta, use_lines, decompose_host):
"""Populate model components, uncertainty bands, and summary tables.
Parameters
----------
samples : dict
Posterior samples keyed by parameter name.
pred_out : dict
Posterior predictive outputs from ``Predictive``.
fsps_grid : FSPSTemplateGrid
Host SSP template grid metadata.
tied_line_meta : dict
Emission-line grouping metadata.
use_lines : bool
Whether line model was enabled.
decompose_host : bool
Whether host model was enabled.
"""
flux = np.asarray(self.flux, dtype=float)
self.numpyro_samples = samples
self.fsps_grid = fsps_grid
self.pred_out = pred_out
self._pred_host_draws = np.asarray(pred_out['gal_model'])
self._pred_bc_draws = np.asarray(pred_out['f_bc_model'])
self._pred_cont_draws = np.asarray(pred_out['continuum_model'])
self._pred_total_draws = np.asarray(pred_out['model'])
self._pred_line_draws = np.asarray(pred_out['line_model'])
self._pred_psf_draws = np.asarray(pred_out['psf_model']) if 'psf_model' in pred_out else None
self.custom_components = {}
self._pred_custom_draws = {}
self.custom_line_components = {}
self._pred_custom_line_draws = {}
self.bi = np.nan
self.bi_err = np.nan
self.f_pl_model = np.median(np.asarray(pred_out['f_pl_model']), axis=0)
intrinsic_pl_draws = self._intrinsic_powerlaw_draws()
if intrinsic_pl_draws is not None and intrinsic_pl_draws.shape[1] == len(self.wave):
self.f_pl_model_intrinsic = np.median(intrinsic_pl_draws, axis=0)
else:
self.f_pl_model_intrinsic = np.full_like(self.f_pl_model, np.nan)
self.f_fe_mgii_model = np.median(np.asarray(pred_out['f_fe_mgii_model']), axis=0)
self.f_fe_balmer_model = np.median(np.asarray(pred_out['f_fe_balmer_model']), axis=0)
self.f_bc_model = np.median(np.asarray(pred_out['f_bc_model']), axis=0)
self.f_poly_model = np.median(np.asarray(pred_out['f_poly_model']), axis=0)
self.qso = np.median(np.asarray(pred_out['agn_model']), axis=0)
self.host = np.median(np.asarray(pred_out['gal_model']), axis=0)
self.line_broad = np.median(np.asarray(pred_out['line_model_broad']), axis=0) if 'line_model_broad' in pred_out else np.full_like(self.qso, np.nan)
self.line_narrow = np.median(np.asarray(pred_out['line_model_narrow']), axis=0) if 'line_model_narrow' in pred_out else np.full_like(self.qso, np.nan)
self.f_line_model = np.median(np.asarray(pred_out['line_model']), axis=0)
self.f_conti_model = np.median(np.asarray(pred_out['continuum_model']), axis=0)
self.model_total = np.median(np.asarray(pred_out['model']), axis=0)
self.qso_psf = np.median(np.asarray(pred_out['agn_model_psf']), axis=0) if 'agn_model_psf' in pred_out else np.full_like(self.model_total, np.nan)
self.host_psf = np.median(np.asarray(pred_out['gal_model_psf']), axis=0) if 'gal_model_psf' in pred_out else np.full_like(self.model_total, np.nan)
self.line_broad_psf = np.median(np.asarray(pred_out['line_model_broad_psf']), axis=0) if 'line_model_broad_psf' in pred_out else np.full_like(self.model_total, np.nan)
self.line_narrow_psf = np.median(np.asarray(pred_out['line_model_narrow_psf']), axis=0) if 'line_model_narrow_psf' in pred_out else np.full_like(self.model_total, np.nan)
self.line_psf = np.median(np.asarray(pred_out['line_model_psf']), axis=0) if 'line_model_psf' in pred_out else np.full_like(self.model_total, np.nan)
self.psf_model = np.median(np.asarray(pred_out['psf_model']), axis=0) if 'psf_model' in pred_out else np.full_like(self.model_total, np.nan)
self.fsps_weights_median = np.median(np.asarray(pred_out['fsps_weights']), axis=0)
for comp in normalize_custom_components(getattr(self, '_fit_custom_components', ())):
if comp.deterministic_site_name in pred_out:
draws = np.asarray(pred_out[comp.deterministic_site_name])
self._pred_custom_draws[comp.output_name] = draws
self.custom_components[comp.output_name] = np.median(draws, axis=0)
for comp in normalize_custom_line_components(getattr(self, '_fit_custom_line_components', ())):
if comp.deterministic_site_name in pred_out:
draws = np.asarray(pred_out[comp.deterministic_site_name])
self._pred_custom_line_draws[comp.output_name] = draws
self.custom_line_components[comp.output_name] = np.median(draws, axis=0)
self.line_flux = flux - self.f_conti_model
self.decomposed = True
if 'delta_m_psf_raw' in samples:
delta_m_draws = np.asarray(samples['delta_m_psf_raw'], dtype=float)
elif 'delta_m_psf' in pred_out:
delta_m_draws = np.asarray(pred_out['delta_m_psf'], dtype=float)
else:
delta_m_draws = np.array([np.nan], dtype=float)
if 'eta_psf_raw' in samples:
eta_psf_draws = np.asarray(samples['eta_psf_raw'], dtype=float)
elif 'eta_psf' in pred_out:
eta_psf_draws = np.asarray(pred_out['eta_psf'], dtype=float)
else:
eta_psf_draws = np.array([np.nan], dtype=float)
self.delta_m_psf = float(np.nanmedian(delta_m_draws)) if delta_m_draws.size > 0 else np.nan
self.delta_m_psf_err = float(np.nanstd(delta_m_draws)) if delta_m_draws.size > 0 else np.nan
self.eta_psf = float(np.nanmedian(eta_psf_draws)) if eta_psf_draws.size > 0 else np.nan
self.eta_psf_err = float(np.nanstd(eta_psf_draws)) if eta_psf_draws.size > 0 else np.nan
self.scale_psf = 10.0 ** (-0.4 * self.delta_m_psf) if np.isfinite(self.delta_m_psf) else np.nan
if 'host_redshift_prior_weight' in pred_out:
host_prior_weight_draws = np.asarray(pred_out['host_redshift_prior_weight'], dtype=float)
else:
host_prior_weight_draws = np.array([np.nan], dtype=float)
if 'host_redshift_prior_loc_eff' in pred_out:
host_prior_loc_draws = np.asarray(pred_out['host_redshift_prior_loc_eff'], dtype=float)
else:
host_prior_loc_draws = np.array([np.nan], dtype=float)
if 'host_redshift_prior_scale_eff' in pred_out:
host_prior_scale_draws = np.asarray(pred_out['host_redshift_prior_scale_eff'], dtype=float)
else:
host_prior_scale_draws = np.array([np.nan], dtype=float)
if 'host_redshift_prior_df_eff' in pred_out:
host_prior_df_draws = np.asarray(pred_out['host_redshift_prior_df_eff'], dtype=float)
else:
host_prior_df_draws = np.array([np.nan], dtype=float)
self.host_redshift_prior_weight = float(np.nanmedian(host_prior_weight_draws)) if host_prior_weight_draws.size > 0 else np.nan
self.host_redshift_prior_weight_err = float(np.nanstd(host_prior_weight_draws)) if host_prior_weight_draws.size > 0 else np.nan
self.host_redshift_prior_loc_eff = float(np.nanmedian(host_prior_loc_draws)) if host_prior_loc_draws.size > 0 else np.nan
self.host_redshift_prior_loc_eff_err = float(np.nanstd(host_prior_loc_draws)) if host_prior_loc_draws.size > 0 else np.nan
self.host_redshift_prior_scale_eff = float(np.nanmedian(host_prior_scale_draws)) if host_prior_scale_draws.size > 0 else np.nan
self.host_redshift_prior_scale_eff_err = float(np.nanstd(host_prior_scale_draws)) if host_prior_scale_draws.size > 0 else np.nan
self.host_redshift_prior_df_eff = float(np.nanmedian(host_prior_df_draws)) if host_prior_df_draws.size > 0 else np.nan
self.host_redshift_prior_df_eff_err = float(np.nanstd(host_prior_df_draws)) if host_prior_df_draws.size > 0 else np.nan
def _band(x):
"""Compute 16th/84th percentile uncertainty band across samples."""
a = np.asarray(x)
return np.percentile(a, 16, axis=0), np.percentile(a, 84, axis=0)
cont_plus_lines = np.asarray(pred_out['continuum_model']) + np.asarray(pred_out['line_model'])
fe_total = np.asarray(pred_out['f_fe_mgii_model']) + np.asarray(pred_out['f_fe_balmer_model'])
self.pred_bands = {
'total_model': _band(pred_out['model']),
'host': _band(pred_out['gal_model']),
'PL': _band(pred_out['f_pl_model']),
'FeII': _band(fe_total),
'Balmer_cont': _band(pred_out['f_bc_model']),
'continuum': _band(pred_out['continuum_model']),
'lines': _band(pred_out['line_model']),
'conti_plus_lines': _band(cont_plus_lines),
}
if intrinsic_pl_draws is not None and intrinsic_pl_draws.shape[1] == len(self.wave):
self.pred_bands['PL_intrinsic'] = _band(intrinsic_pl_draws)
for name, draws in self._pred_custom_draws.items():
self.pred_bands[name] = _band(draws)
for name, draws in self._pred_custom_line_draws.items():
self.pred_bands[name] = _band(draws)
if bool(getattr(self, '_fit_fit_bal', False)):
bi, bi_err = self.balnicity_index()
self.bi = float(bi)
self.bi_err = float(bi_err) if np.isfinite(bi_err) else np.nan
if self.verbose:
print("max data :", np.nanmax(self.flux))
print("max total model :", np.nanmax(self.model_total))
print("max PL :", np.nanmax(self.f_pl_model))
print("max host :", np.nanmax(self.host))
print("max FeII UV :", np.nanmax(self.f_fe_mgii_model))
print("max FeII opt :", np.nanmax(self.f_fe_balmer_model))
print("max Balmer cont :", np.nanmax(self.f_bc_model))
print("max lines :", np.nanmax(self.f_line_model))
for name, model in self.custom_components.items():
print(f"max {name:<11}:", np.nanmax(model))
for name, model in self.custom_line_components.items():
print(f"max {name:<11}:", np.nanmax(model))
if decompose_host and 'gal_v_kms' in samples and 'gal_sigma_kms' in samples:
gal_v = float(np.median(np.asarray(samples['gal_v_kms'])))
gal_v_err = float(np.std(np.asarray(samples['gal_v_kms'])))
gal_sig = float(np.median(np.asarray(samples['gal_sigma_kms'])))
gal_sig_err = float(np.std(np.asarray(samples['gal_sigma_kms'])))
else:
gal_v, gal_v_err, gal_sig, gal_sig_err = 0.0, 0.0, 0.0, 0.0
ages = np.array([m['tage_gyr'] for m in fsps_grid.template_meta], dtype=float)
mets = np.array([m['logzsol'] for m in fsps_grid.template_meta], dtype=float)
wsum = np.sum(self.fsps_weights_median)
age_weighted = float(np.sum(self.fsps_weights_median * ages) / wsum) if wsum > 0 else -1.0
metal_weighted = float(np.sum(self.fsps_weights_median * mets) / wsum) if wsum > 0 else -99.0
cont_waves = np.asarray(
_continuum_output_waves_from_prior_config(
getattr(self, "_fit_prior_config", None)
),
dtype=float,
)
self.L_conti_wave = cont_waves
pivot_wave = float(np.asarray(_spectrum_center_pivot(self.wave), dtype=float))
frac_host_vals = []
frac_host_psf_vals = []
frac_bc_vals = []
log_lambda_llambda_vals = []
log_lambda_llambda_errs = []
frac_host_names = []
frac_host_psf_names = []
frac_bc_names = []
log_lambda_llambda_names = []
log_lambda_llambda_err_names = []
for w0 in cont_waves:
wave_label = self._format_wave_label(w0)
frac_host = self._host_fraction_at_wave(w0)
frac_host_psf = self._host_fraction_psf_at_wave(w0)
frac_bc = self._bc_fraction_at_wave(w0)
lum_key = f'log_lambda_Llambda_{wave_label}_agn'
lum_draws = (
np.asarray(pred_out[lum_key], dtype=float)
if lum_key in pred_out
else np.array([np.nan], dtype=float)
)
log_lambda_llambda = float(np.nanmedian(lum_draws)) if lum_draws.size > 0 else np.nan
log_lambda_llambda_err = float(np.nanstd(lum_draws)) if lum_draws.size > 0 else np.nan
setattr(self, f'frac_host_{wave_label}', frac_host)
setattr(self, f'frac_host_psf_{wave_label}', frac_host_psf)
setattr(self, f'frac_bc_{wave_label}', frac_bc)
setattr(self, lum_key, log_lambda_llambda)
setattr(self, f'{lum_key}_err', log_lambda_llambda_err)
frac_host_vals.append(frac_host)
frac_host_psf_vals.append(frac_host_psf)
frac_bc_vals.append(frac_bc)
log_lambda_llambda_vals.append(log_lambda_llambda)
log_lambda_llambda_errs.append(log_lambda_llambda_err)
frac_host_names.append(f'frac_host_{wave_label}')
frac_host_psf_names.append(f'frac_host_psf_{wave_label}')
frac_bc_names.append(f'frac_bc_{wave_label}')
log_lambda_llambda_names.append(lum_key)
log_lambda_llambda_err_names.append(f'{lum_key}_err')
# Preserve the legacy fixed-wavelength attributes for downstream compatibility.
self.pivot_wave = pivot_wave
self.frac_host_pivot = self._host_fraction_at_wave(pivot_wave)
self.frac_host_psf_pivot = self._host_fraction_psf_at_wave(pivot_wave)
self.frac_bc_pivot = self._bc_fraction_at_wave(pivot_wave)
self.frac_host_4200 = self._host_fraction_at_wave(4200.0)
self.frac_host_5100 = self._host_fraction_at_wave(5100.0)
self.frac_host_2500 = self._host_fraction_at_wave(2500.0)
self.frac_host_psf_4200 = self._host_fraction_psf_at_wave(4200.0)
self.frac_host_psf_5100 = self._host_fraction_psf_at_wave(5100.0)
self.frac_host_psf_2500 = self._host_fraction_psf_at_wave(2500.0)
self.frac_bc_2500 = self._bc_fraction_at_wave(2500.0)
n_samp = int(np.asarray(next(iter(samples.values()))).shape[0]) if len(samples) > 0 else 1
if 'PL_norm' in samples:
pl_norm_samp = np.asarray(samples['PL_norm'])
else:
pl_norm_samp = np.full((n_samp,), np.nan)
if 'PL_slope' in samples:
pl_slope_med = float(np.nanmedian(np.asarray(samples['PL_slope'])))
pl_slope_err = float(np.nanstd(np.asarray(samples['PL_slope'])))
else:
pl_slope_med = np.nan
pl_slope_err = np.nan
if 'reddening_ebv' in samples:
reddening_ebv_med = float(np.nanmedian(np.asarray(samples['reddening_ebv'])))
reddening_ebv_err = float(np.nanstd(np.asarray(samples['reddening_ebv'])))
else:
reddening_ebv_med = np.nan
reddening_ebv_err = np.nan
conti_entries = [
('ra', self.ra, 'float'),
('dec', self.dec, 'float'),
('filename', str(self.filename), 'str'),
('redshift', self.z, 'float'),
('SN_ratio_conti', self.SN_ratio_conti, 'float'),
('PL_norm', float(np.nanmedian(pl_norm_samp)), 'float'),
('PL_norm_err', float(np.nanstd(pl_norm_samp)), 'float'),
('PL_slope', pl_slope_med, 'float'),
('PL_slope_err', pl_slope_err, 'float'),
('reddening_ebv', reddening_ebv_med, 'float'),
('reddening_ebv_err', reddening_ebv_err, 'float'),
('pivot_wave', self.pivot_wave, 'float'),
('frac_host_pivot', self.frac_host_pivot, 'float'),
('frac_host_psf_pivot', self.frac_host_psf_pivot, 'float'),
('frac_bc_pivot', self.frac_bc_pivot, 'float'),
('sigma', gal_sig, 'float'),
('sigma_err', gal_sig_err, 'float'),
('v_off', gal_v, 'float'),
('v_off_err', gal_v_err, 'float'),
]
conti_entries += [(name, value, 'float') for name, value in zip(frac_host_names, frac_host_vals)]
conti_entries += [(name, value, 'float') for name, value in zip(frac_host_psf_names, frac_host_psf_vals)]
conti_entries += [(name, value, 'float') for name, value in zip(frac_bc_names, frac_bc_vals)]
conti_entries += [
(name, value, 'float')
for name, value in zip(log_lambda_llambda_names, log_lambda_llambda_vals)
]
conti_entries += [
(name, value, 'float')
for name, value in zip(log_lambda_llambda_err_names, log_lambda_llambda_errs)
]
conti_entries += [
('fsps_age_weighted_gyr', age_weighted, 'float'),
('fsps_logzsol_weighted', metal_weighted, 'float'),
('host_redshift_prior_weight', self.host_redshift_prior_weight, 'float'),
('host_redshift_prior_weight_err', self.host_redshift_prior_weight_err, 'float'),
('host_redshift_prior_loc_eff', self.host_redshift_prior_loc_eff, 'float'),
('host_redshift_prior_loc_eff_err', self.host_redshift_prior_loc_eff_err, 'float'),
('host_redshift_prior_scale_eff', self.host_redshift_prior_scale_eff, 'float'),
('host_redshift_prior_scale_eff_err', self.host_redshift_prior_scale_eff_err, 'float'),
('host_redshift_prior_df_eff', self.host_redshift_prior_df_eff, 'float'),
('host_redshift_prior_df_eff_err', self.host_redshift_prior_df_eff_err, 'float'),
('delta_m_psf', self.delta_m_psf, 'float'),
('delta_m_psf_err', self.delta_m_psf_err, 'float'),
('eta_psf', self.eta_psf, 'float'),
('eta_psf_err', self.eta_psf_err, 'float'),
]
self.conti_result, self.conti_result_type, self.conti_result_name = self._build_result_arrays(conti_entries)
if use_lines and tied_line_meta['n_lines'] > 0:
amp_comp = np.asarray(pred_out['line_amp_per_component'])
mu_comp = np.asarray(pred_out['line_mu_per_component'])
sig_comp = np.asarray(pred_out['line_sig_per_component'])
amp_med = np.median(amp_comp, axis=0)
amp_err = np.std(amp_comp, axis=0)
mu_med = np.median(mu_comp, axis=0)
mu_err = np.std(mu_comp, axis=0)
sig_med = np.median(sig_comp, axis=0)
sig_err = np.std(sig_comp, axis=0)
vals, names, types = [], [], []
for i, nm in enumerate(tied_line_meta['names']):
vals.extend([amp_med[i], amp_err[i], mu_med[i], mu_err[i], sig_med[i], sig_err[i]])
names.extend([f'{nm}_scale', f'{nm}_scale_err', f'{nm}_centerwave', f'{nm}_centerwave_err', f'{nm}_sigma', f'{nm}_sigma_err'])
types.extend(['float'] * 6)
self.line_result = np.array(vals, dtype=object)
self.line_result_type = np.array(types, dtype=object)
self.line_result_name = np.array(names, dtype=object)
self.gauss_result = self.line_result
self.gauss_result_name = self.line_result_name
self.line_component_amp_median = amp_med
self.line_component_mu_median = mu_med
self.line_component_sig_median = sig_med
else:
self.line_result = np.array([])
self.line_result_type = np.array([])
self.line_result_name = np.array([])
self.gauss_result = np.array([])
self.gauss_result_name = np.array([])
self.line_component_amp_median = np.array([])
self.line_component_mu_median = np.array([])
self.line_component_sig_median = np.array([])
self._posterior_hydrated = True
def _wave_trim(self, lam, flux, err, z):
"""Apply rest-frame wavelength range trimming.
Parameters
----------
lam, flux, err : ndarray
Observed-frame wavelength, flux, and uncertainty arrays.
z : float
Redshift used for rest-frame conversion.
"""
ind_trim = np.where((lam / (1 + z) > self.wave_range[0]) & (lam / (1 + z) < self.wave_range[1]), True, False)
self.lam, self.flux, self.err = lam[ind_trim], flux[ind_trim], err[ind_trim]
if len(self.lam) < 100:
raise RuntimeError('No enough pixels in the input wave_range!')
return self.lam, self.flux, self.err
def _wave_msk(self, lam, flux, err, z):
"""Mask user-provided rest-frame wavelength intervals.
Parameters
----------
lam, flux, err : ndarray
Observed-frame wavelength, flux, and uncertainty arrays.
z : float
Redshift used for rest-frame conversion.
"""
for msk in range(len(self.wave_mask)):
ind_not_mask = ~np.where((lam / (1 + z) > self.wave_mask[msk, 0]) & (lam / (1 + z) < self.wave_mask[msk, 1]), True, False)
self.lam, self.flux, self.err = lam[ind_not_mask], flux[ind_not_mask], err[ind_not_mask]
lam, flux, err = self.lam, self.flux, self.err
return self.lam, self.flux, self.err
def _mask_lya_forest(self, lam, flux, err, z, lya_rest=1215.67):
"""Mask observed pixels blueward of rest-frame Ly-alpha.
Parameters
----------
lam, flux, err : ndarray
Observed-frame wavelength, flux, and uncertainty arrays.
z : float
Redshift used for rest-frame conversion.
lya_rest : float, optional
Rest-frame Ly-alpha cutoff in Angstrom.
"""
keep = (lam / (1 + z)) >= float(lya_rest)
self.lam, self.flux, self.err = lam[keep], flux[keep], err[keep]
if len(self.lam) < 10:
raise RuntimeError('Not enough pixels after Ly-alpha forest masking.')
return self.lam, self.flux, self.err
def _de_redden(self, lam, flux, err, ra, dec):
    """Correct observed flux and uncertainty for Galactic extinction.

    E(B-V) at ``(ra, dec)`` is read from the dust-map query returned by
    ``_get_sfd_query()`` and stored as ``self.ebv_mw``. ``unred`` applies the
    correction to the flux, and the uncertainty is scaled by the same
    per-pixel factor. Pixels whose input flux is exactly zero stay zero.

    Parameters
    ----------
    lam, flux, err : ndarray
        Observed-frame wavelength, flux, and uncertainty arrays. ``flux`` is
        not modified.
    ra, dec : float
        Sky coordinates in degrees (ICRS frame).

    Returns
    -------
    ndarray
        The corrected flux, also stored as ``self.flux``. The corrected
        uncertainty is stored as ``self.err``.
    """
    sfd_query = _get_sfd_query()
    coord = SkyCoord(float(ra) * u.deg, float(dec) * u.deg, frame='icrs')
    ebv = float(np.asarray(sfd_query(coord)))
    self.ebv_mw = ebv
    # Work on a private float copy. The caller's array is then never
    # modified, and an integer-typed flux cannot truncate the placeholder to 0.
    flux_work = np.array(flux, dtype=float)
    zero_flux = flux_work == 0
    # Temporary non-zero placeholder so the correction ratio below is defined.
    flux_work[zero_flux] = 1e-10
    flux_unred = unred(lam, flux_work, ebv)
    # Scale the error by the same per-pixel correction factor as the flux.
    err_unred = err * flux_unred / flux_work
    flux_unred[zero_flux] = 0
    self.flux = flux_unred
    self.err = err_unred
    return self.flux
@staticmethod
def _validate_deredden_coordinates(ra, dec):
"""Validate sky coordinates before Galactic dereddening."""
ra_f = float(ra)
dec_f = float(dec)
invalid_placeholder = (ra_f == -999.0 and dec_f == -999.0)
invalid_range = (not np.isfinite(ra_f)) or (not np.isfinite(dec_f)) or (dec_f < -90.0) or (dec_f > 90.0)
if invalid_placeholder or invalid_range:
raise ValueError(
"Galactic dereddening requires valid sky coordinates: "
f"received ra={ra_f}, dec={dec_f}. "
"Pass real source coordinates to `QSOFit(...)` or call `fit(deredden=False)` "
"for synthetic data or spectra without sky positions."
)
return ra_f, dec_f
def _rest_frame(self, lam, flux, err, z):
"""Convert observed-frame spectra to rest-frame convention.
Parameters
----------
lam, flux, err : ndarray
Observed-frame wavelength, flux, and uncertainty arrays.
z : float
Source redshift.
"""
self.wave = lam / (1 + z)
self.flux = flux * (1 + z)
self.err = err * (1 + z)
return self.wave, self.flux, self.err
def _orignial_spec(self, wave, flux, err):
"""Cache the pre-modeling spectrum for plotting/debugging.
Parameters
----------
wave, flux, err : ndarray
Rest-frame wavelength, flux, and uncertainty arrays.
"""
self.wave_prereduced = wave
self.flux_prereduced = flux
self.err_prereduced = err
def _calculate_sn(self, wave, flux, alter=True):
"""Estimate continuum S/N from standard windows or robust fallback.
Parameters
----------
wave, flux : ndarray
Rest-frame wavelength and flux arrays.
alter : bool, optional
If True and standard windows are unavailable, use robust fallback.
"""
ind5100 = np.where((wave > 5080) & (wave < 5130), True, False)
ind3000 = np.where((wave > 3000) & (wave < 3050), True, False)
ind1350 = np.where((wave > 1325) & (wave < 1375), True, False)
if np.all(np.array([np.sum(ind5100), np.sum(ind3000), np.sum(ind1350)]) < 10):
if alter is False:
self.SN_ratio_conti = -1.
return self.SN_ratio_conti
input_data = np.array(flux)
input_data = np.array(input_data[np.where(input_data != 0.0)])
n = len(input_data)
if n > 4:
signal = np.median(input_data)
noise = 0.6052697 * np.median(np.abs(2.0 * input_data[2:n - 2] - input_data[0:n - 4] - input_data[4:n]))
self.SN_ratio_conti = float(signal / noise)
else:
self.SN_ratio_conti = -1.
else:
tmp_SN = np.array([flux[ind5100].mean() / flux[ind5100].std(), flux[ind3000].mean() / flux[ind3000].std(), flux[ind1350].mean() / flux[ind1350].std()])
tmp_SN = tmp_SN[np.array([np.sum(ind5100), np.sum(ind3000), np.sum(ind1350)]) > 10]
self.SN_ratio_conti = np.nanmean(tmp_SN) if not np.all(np.isnan(tmp_SN)) else -1.
return self.SN_ratio_conti
def _host_fraction_at_wave(self, w0):
"""Return host/continuum flux fraction at wavelength ``w0``.
Parameters
----------
w0 : float
Rest-frame wavelength in Angstrom.
"""
return self._component_fraction_at_wave(self.host, w0)
def _host_fraction_psf_at_wave(self, w0):
"""Return PSF-space host fraction at wavelength ``w0``."""
qso_psf = np.asarray(getattr(self, 'qso_psf', []), dtype=float)
host_psf = np.asarray(getattr(self, 'host_psf', []), dtype=float)
if qso_psf.size != len(getattr(self, 'wave', [])) or host_psf.size != len(getattr(self, 'wave', [])):
return -1.0
return self._component_fraction_at_wave(host_psf, w0, reference=qso_psf + host_psf)
def _bc_fraction_at_wave(self, w0):
"""Return Balmer-continuum/continuum flux fraction at wavelength ``w0``.
Parameters
----------
w0 : float
Rest-frame wavelength in Angstrom.
"""
return self._component_fraction_at_wave(self.f_bc_model, w0)
def _component_fraction_at_wave(self, component, w0, reference=None):
"""Return component fraction relative to fitted continuum at ``w0``.
Parameters
----------
component : ndarray
Component flux array evaluated on ``self.wave``.
w0 : float
Rest-frame wavelength in Angstrom.
reference : ndarray or None, optional
Reference flux array. If ``None``, uses ``self.f_conti_model``.
"""
if len(self.wave) == 0:
return -1.
comp = np.interp(w0, self.wave, component, left=np.nan, right=np.nan)
ref_arr = self.f_conti_model if reference is None else np.asarray(reference, dtype=float)
if len(ref_arr) != len(self.wave):
return -1.
total = np.interp(w0, self.wave, ref_arr, left=np.nan, right=np.nan)
if not np.isfinite(comp) or not np.isfinite(total) or total == 0:
return -1.
return float(comp / total)
[docs]
def reconstruct_posterior_spectrum(
    self,
    wave_out=None,
    wave_min=2500.0,
    wave_max=None,
    n_draws=None,
    return_components=True,
):
    """Rebuild posterior component draws on a requested rest-frame grid.

    Parameters
    ----------
    wave_out : array-like or None, optional
        Explicit rest-frame wavelength grid. If ``None``, build a grid from
        ``min(wave_min, self.wave.min())`` to ``wave_max or self.wave.max()``
        using the median native wavelength spacing.
    wave_min, wave_max : float or None, optional
        Bounds for the auto-generated grid when ``wave_out`` is ``None``.
        ``wave_min`` can only extend the grid blueward of the fitted range,
        because the lower edge is ``min(wave_min, self.wave.min())``.
    n_draws : int or None, optional
        If provided, use at most the first ``n_draws`` posterior samples.
    return_components : bool, optional
        If True, include per-component draws and medians in the return value.
        This includes any fitted custom components.

    Returns
    -------
    object
        The value returned by ``reconstruct_posterior_components``.
        ``component_fraction_at_wave`` reads its ``'wave'`` and ``'draws'``
        entries.

    Raises
    ------
    RuntimeError
        Raised in any of these cases:
        - there are no posterior samples;
        - no template-grid metadata attributes are present;
        - the fitted wavelength grid has fewer than two points;
        - the median native spacing is not finite and positive.
    ValueError
        If the auto-generated grid would have a non-positive span, or an
        explicit ``wave_out`` is not a finite 1D array with at least two points.
    """
    if not hasattr(self, 'numpyro_samples') or self.numpyro_samples is None:
        raise RuntimeError("No posterior samples available. Run fit() first.")
    # Only attribute presence is checked here. The metadata actually used is
    # fetched below via _require_posterior_bundle_fsps_metadata.
    has_age_grid = hasattr(self, '_fit_fsps_age_grid')
    has_logz_grid = hasattr(self, '_fit_fsps_logzsol_grid')
    has_fsps_grid = hasattr(self, 'fsps_grid')
    if not (has_age_grid and has_logz_grid) and not has_fsps_grid:
        raise RuntimeError("No template-grid metadata available for reconstruction.")
    if not hasattr(self, 'wave') or len(self.wave) < 2:
        raise RuntimeError("No fitted rest-frame wavelength grid available.")
    wave_native = np.asarray(self.wave, dtype=float)
    # Median native pixel spacing, reused as the step of the auto-generated grid.
    dw = float(np.nanmedian(np.diff(wave_native)))
    if not np.isfinite(dw) or dw <= 0:
        raise RuntimeError("Unable to infer wavelength spacing for reconstruction.")
    if wave_out is None:
        wmin = min(float(wave_min), float(np.nanmin(wave_native)))
        wmax = float(np.nanmax(wave_native) if wave_max is None else wave_max)
        if wmin >= wmax:
            raise ValueError("Requested reconstruction grid has non-positive span.")
        # np.arange excludes its stop value. The extra half step lets the grid reach wmax.
        wave_out = np.arange(wmin, wmax + 0.5 * dw, dw, dtype=float)
    else:
        # NOTE(review): rank, length and finiteness are checked, but monotonic
        # ordering is not. Confirm whether reconstruct_posterior_components needs it.
        wave_out = np.asarray(wave_out, dtype=float)
        if wave_out.ndim != 1 or wave_out.size < 2 or not np.all(np.isfinite(wave_out)):
            raise ValueError("wave_out must be a finite 1D wavelength grid.")
    prior_config = getattr(self, '_fit_prior_config', None)
    if prior_config is None:
        # No stored fit configuration: fall back to defaults built from the observed flux.
        prior_config = build_default_prior_config(np.asarray(self.flux, dtype=float))
    # Presumably (age grid in Gyr, logzsol grid, SSP evaluation function), going by
    # the names. Confirm against the helper, which may raise if metadata is missing.
    age_grid_gyr, logzsol_grid, dsps_ssp_fn = self._require_posterior_bundle_fsps_metadata(self.__dict__)
    # One template per (age, metallicity) pair.
    expected_templates = int(len(age_grid_gyr) * len(logzsol_grid))
    # Presumably checks that posterior FSPS weights match expected_templates. Verify in helper.
    self._validate_fsps_weights_shape(
        getattr(self, 'pred_out', None),
        expected_templates=expected_templates,
        context="Posterior reconstruction",
    )
    # NOTE(review): custom *line* components are not forwarded here. Confirm
    # whether reconstruct_posterior_components is meant to support them.
    return reconstruct_posterior_components(
        wave_out=wave_out,
        samples=self.numpyro_samples,
        pred_out=getattr(self, 'pred_out', None),
        age_grid_gyr=age_grid_gyr,
        logzsol_grid=logzsol_grid,
        dsps_ssp_fn=dsps_ssp_fn,
        prior_config=prior_config,
        fit_poly=bool(getattr(self, '_fit_fit_poly', False)),
        fit_reddening=bool(getattr(self, '_fit_fit_reddening', False)),
        fit_poly_order=int(getattr(self, '_fit_fit_poly_order', 2)),
        fe_uv_wave=self.fe_uv_wave,
        fe_uv_flux=self.fe_uv_flux,
        fe_op_wave=self.fe_op_wave,
        fe_op_flux=self.fe_op_flux,
        custom_components=getattr(self, '_fit_custom_components', ()),
        n_draws=n_draws,
        return_components=return_components,
        decompose_host=bool(getattr(self, '_fit_decompose_host', True)),
    )
[docs]
def component_fraction_at_wave(self, component='host', wave0=2500.0, reference='continuum', reconstruct=False, n_draws=None):
    """Return ``(fraction, uncertainty)`` of ``component / reference`` at ``wave0``.

    Parameters
    ----------
    component, reference : str, optional
        Component names.
        - Without reconstruction, the fitted arrays ``host``, ``PL``,
          ``Fe_uv``, ``Fe_op``, ``Balmer_cont`` and ``continuum`` are
          accepted, plus any fitted custom component names. An unknown
          ``reference`` falls back to the fitted continuum.
        - With reconstruction, both names must be keys of the reconstructed
          draws.
    wave0 : float, optional
        Rest-frame wavelength in Angstrom.
    reconstruct : bool, optional
        If True, rebuild posterior components on a grid that covers ``wave0``.
        The native grid is extended blueward or redward as needed.
    n_draws : int or None, optional
        Maximum number of posterior draws to use in the reconstruction.

    Returns
    -------
    tuple of float
        Without reconstruction: ``(fraction, nan)``, or ``(-1.0, nan)`` when
        the fraction cannot be evaluated (missing array, length mismatch,
        ``wave0`` off-grid, non-finite values, or a zero reference).
        With reconstruction: the posterior median and half the 16-84
        percentile width, or ``(nan, nan)`` when no draw gives a finite ratio.

    Raises
    ------
    ValueError
        With ``reconstruct=True``, if either name is not a reconstructed component.
    """
    if not reconstruct:
        component_map = {
            'host': getattr(self, 'host', None),
            'PL': getattr(self, 'f_pl_model', None),
            'Fe_uv': getattr(self, 'f_fe_mgii_model', None),
            'Fe_op': getattr(self, 'f_fe_balmer_model', None),
            'Balmer_cont': getattr(self, 'f_bc_model', None),
            'continuum': getattr(self, 'f_conti_model', None),
        }
        # Custom components override built-in names on collision.
        component_map.update(getattr(self, 'custom_components', {}))
        comp_arr = component_map.get(component)
        ref_arr = component_map.get(reference, getattr(self, 'f_conti_model', None))
        if comp_arr is None or ref_arr is None or len(self.wave) == 0:
            return -1.0, np.nan
        wave = np.asarray(self.wave, dtype=float)
        comp_arr = np.asarray(comp_arr, dtype=float)
        ref_arr = np.asarray(ref_arr, dtype=float)
        # np.interp would raise on a length mismatch. Report it as unavailable instead.
        if comp_arr.shape != wave.shape or ref_arr.shape != wave.shape:
            return -1.0, np.nan
        comp = np.interp(wave0, wave, comp_arr, left=np.nan, right=np.nan)
        ref = np.interp(wave0, wave, ref_arr, left=np.nan, right=np.nan)
        if not np.isfinite(comp) or not np.isfinite(ref) or ref == 0:
            return -1.0, np.nan
        return float(comp / ref), np.nan
    wave_native = np.asarray(self.wave, dtype=float)
    target = float(wave0)
    recon = self.reconstruct_posterior_spectrum(
        wave_min=min(target, float(np.nanmin(wave_native))),
        # Also extend redward. Otherwise a wave0 beyond the fitted range would be
        # silently evaluated at the red-edge pixel by the argmin below.
        wave_max=max(target, float(np.nanmax(wave_native))),
        n_draws=n_draws,
    )
    wave = recon['wave']
    idx = int(np.argmin(np.abs(wave - target)))
    if component not in recon['draws'] or reference not in recon['draws']:
        raise ValueError(f"Unknown reconstructed component/reference: {component}, {reference}")
    num = np.asarray(recon['draws'][component], dtype=float)[:, idx]
    den = np.asarray(recon['draws'][reference], dtype=float)[:, idx]
    frac = np.divide(num, den, out=np.full_like(num, np.nan), where=np.isfinite(den) & (den != 0))
    good = np.isfinite(frac)
    if not np.any(good):
        return np.nan, np.nan
    p16, p50, p84 = np.percentile(frac[good], [16.0, 50.0, 84.0])
    return float(p50), float(0.5 * (p84 - p16))
@staticmethod
def _balnicity_index_from_arrays(
wave: np.ndarray,
bal_sum: np.ndarray,
reference: np.ndarray,
line_center: float,
vmin: float,
vmax: float,
min_width: float,
depth_threshold: float,
) -> tuple[float, list[tuple[float, float]]]:
"""Compute a simple BI-like integral from a BAL model and reference model."""
wave = np.asarray(wave, dtype=float)
bal_sum = np.asarray(bal_sum, dtype=float)
reference = np.asarray(reference, dtype=float)
if wave.ndim != 1 or bal_sum.shape != wave.shape or reference.shape != wave.shape or wave.size < 2:
return 0.0, []
finite = np.isfinite(wave) & np.isfinite(bal_sum) & np.isfinite(reference) & (reference > 0)
if not np.any(finite):
return 0.0, []
vel = C_KMS * (float(line_center) / wave - 1.0)
sel = finite & (vel >= float(vmin)) & (vel <= float(vmax))
if np.count_nonzero(sel) < 2:
return 0.0, []
vel_sel = vel[sel]
bal_sel = bal_sum[sel]
ref_sel = reference[sel]
order = np.argsort(vel_sel)
vel_sel = vel_sel[order]
bal_sel = bal_sel[order]
ref_sel = ref_sel[order]
flux_norm = 1.0 + bal_sel / ref_sel
integrand = 1.0 - flux_norm / 0.9
active = np.isfinite(integrand) & (integrand > 0.0) & ((-bal_sel / ref_sel) >= float(depth_threshold))
if not np.any(active):
return 0.0, []
bi_total = 0.0
troughs: list[tuple[float, float]] = []
idx = np.flatnonzero(active)
start = idx[0]
prev = idx[0]
for cur in idx[1:]:
if cur != prev + 1:
v0 = float(vel_sel[start])
v1 = float(vel_sel[prev])
if (v1 - v0) >= float(min_width):
bi_total += float(np.trapezoid(integrand[start:prev + 1], vel_sel[start:prev + 1]))
troughs.append((v0, v1))
start = cur
prev = cur
v0 = float(vel_sel[start])
v1 = float(vel_sel[prev])
if (v1 - v0) >= float(min_width):
bi_total += float(np.trapezoid(integrand[start:prev + 1], vel_sel[start:prev + 1]))
troughs.append((v0, v1))
return float(max(bi_total, 0.0)), troughs
[docs]
def balnicity_index(
    self,
    component_names=None,
    line_center: float = 1549.06,
    vmin: float = 3000.0,
    vmax: float = 25000.0,
    min_width: float = 2000.0,
    depth_threshold: float = 0.1,
    include_line_emission: bool = True,
    return_details: bool = False,
):
    """Return a simple BALnicity-style index from the summed fitted BAL model.

    The BAL model is the sum of the selected custom components. By default
    these are the names beginning with ``bal_``; otherwise a name or an
    iterable of names, keeping only those that exist. The reference model is
    ``qso - bal_sum``, optionally plus the fitted emission-line model. The BI
    itself comes from ``_balnicity_index_from_arrays``, which integrates
    ``1 - f_norm / 0.9`` over contiguous troughs at least ``min_width`` wide
    and deeper than ``depth_threshold``.

    Posterior draws are used when they are available and consistently shaped:
    - ``pred_out['agn_model']``;
    - the per-component BAL draws;
    - ``pred_out['line_model']`` when ``include_line_emission`` is set.
    In that case ``bi`` is the median over draws and ``bi_err`` is half the
    16-84 percentile width. Otherwise ``bi`` comes from the median model and
    ``bi_err`` is NaN. ``troughs_kms`` always describes the median model.

    Returns
    -------
    tuple or dict
        ``(bi, bi_err)``, or, with ``return_details=True``, a dict that also
        holds the selected names, troughs and the input settings.

    Raises
    ------
    RuntimeError
        If no fitted spectrum is available.
    """
    self._ensure_hydrated_from_samples()
    if not hasattr(self, 'wave') or len(self.wave) == 0:
        raise RuntimeError("No fitted spectrum available. Run fit() first.")
    wave = np.asarray(self.wave, dtype=float)
    settings = {
        'line_center': float(line_center),
        'vmin': float(vmin),
        'vmax': float(vmax),
        'min_width': float(min_width),
        'depth_threshold': float(depth_threshold),
    }
    custom_models = getattr(self, 'custom_components', {})
    if component_names is None:
        selected_names = [name for name in custom_models if str(name).startswith('bal_')]
    elif isinstance(component_names, str):
        selected_names = [component_names]
    else:
        selected_names = [str(name) for name in component_names]
    selected_names = [name for name in selected_names if name in custom_models]
    if len(selected_names) == 0:
        result = {
            'bi': 0.0,
            'bi_err': np.nan,
            'component_names': [],
            'troughs_kms': [],
            **settings,
        }
        return result if return_details else (0.0, np.nan)
    bal_sum = np.sum([np.asarray(custom_models[name], dtype=float) for name in selected_names], axis=0)
    qso_model = np.asarray(getattr(self, 'qso', np.zeros_like(self.wave)), dtype=float)
    if include_line_emission:
        line_model = np.asarray(getattr(self, 'f_line_model', np.zeros_like(self.wave)), dtype=float)
    else:
        line_model = np.zeros_like(self.wave, dtype=float)
    reference = qso_model - bal_sum + line_model
    bi_med, troughs = self._balnicity_index_from_arrays(
        wave=wave, bal_sum=bal_sum, reference=reference, **settings,
    )
    bi_err = np.nan
    pred_out = getattr(self, 'pred_out', None)
    custom_draws = getattr(self, '_pred_custom_draws', None)
    if pred_out is not None and custom_draws is not None:
        draw_list = [np.asarray(custom_draws[name], dtype=float) for name in selected_names if name in custom_draws]
        qso_draws = np.asarray(pred_out.get('agn_model', []), dtype=float)
        draws_usable = (
            len(draw_list) == len(selected_names)
            and qso_draws.ndim == 2
            and qso_draws.shape[1] == wave.size
        )
        if draws_usable:
            bal_draws = np.sum(draw_list, axis=0)
            if include_line_emission:
                line_draws = np.asarray(pred_out.get('line_model', []), dtype=float)
            else:
                line_draws = np.zeros_like(qso_draws, dtype=float)
            # A missing or mis-shaped array would fail to broadcast. Fall back to
            # the median-model estimate rather than crashing.
            draws_usable = bal_draws.shape == qso_draws.shape and line_draws.shape == qso_draws.shape
        if draws_usable:
            ref_draws = qso_draws - bal_draws + line_draws
            bi_draws = np.asarray(
                [
                    self._balnicity_index_from_arrays(
                        wave=wave, bal_sum=bal_draws[i], reference=ref_draws[i], **settings,
                    )[0]
                    for i in range(qso_draws.shape[0])
                ],
                dtype=float,
            )
            good = np.isfinite(bi_draws)
            if np.any(good):
                p16, p50, p84 = np.percentile(bi_draws[good], [16.0, 50.0, 84.0])
                bi_med = float(p50)
                bi_err = float(0.5 * (p84 - p16))
    result = {
        'bi': float(bi_med),
        'bi_err': float(bi_err) if np.isfinite(bi_err) else np.nan,
        'component_names': selected_names,
        'troughs_kms': troughs,
        **settings,
    }
    return result if return_details else (result['bi'], result['bi_err'])
def _line_profile_from_params(
self,
line_key: str,
amp: np.ndarray,
mu: np.ndarray,
sig: np.ndarray,
) -> np.ndarray:
"""Build a line profile from explicit Gaussian parameter arrays.
Parameters
----------
line_key : str
Line-name prefix (for example ``'Hb_br'``).
amp, mu, sig : ndarray
Gaussian amplitudes, centers (ln lambda), and widths.
"""
if not hasattr(self, 'wave') or len(self.wave) == 0:
return np.array([], dtype=float)
if not hasattr(self, 'tied_line_meta'):
return np.zeros_like(self.wave, dtype=float)
names = np.asarray(self.tied_line_meta.get('names', []))
amp = np.asarray(amp, dtype=float)
mu = np.asarray(mu, dtype=float)
sig = np.asarray(sig, dtype=float)
if names.size == 0 or amp.size == 0 or mu.size == 0 or sig.size == 0:
return np.zeros_like(self.wave, dtype=float)
keep = np.array([str(n).startswith(f'{line_key}_') for n in names], dtype=bool)
if not np.any(keep):
return np.zeros_like(self.wave, dtype=float)
lnw = np.log(np.asarray(self.wave, dtype=float))
prof = np.zeros_like(lnw)
for a, m, s in zip(amp[keep], mu[keep], sig[keep]):
if np.isfinite(a) and np.isfinite(m) and np.isfinite(s) and s > 0:
prof += a * np.exp(-0.5 * ((lnw - m) / s) ** 2)
return prof
[docs]
def line_profile_from_components(self, line_key: str) -> np.ndarray:
    """Build a line-only profile from posterior-median Gaussian components.

    Uses ``line_component_{amp,mu,sig}_median``. Returns zeros on
    ``self.wave`` when the medians have not been computed yet.

    Parameters
    ----------
    line_key : str
        Line-name prefix (for example ``'Hb_br'``).
    """
    if hasattr(self, 'line_component_amp_median'):
        medians = {
            param: np.asarray(getattr(self, f'line_component_{param}_median', []), dtype=float)
            for param in ('amp', 'mu', 'sig')
        }
        return self._line_profile_from_params(line_key=line_key, **medians)
    return np.zeros_like(self.wave, dtype=float)
[docs]
def line_profile_from_draw(self, draw_index: int, line_key: str) -> np.ndarray:
    """Build a line-only profile for one posterior draw index.

    Reads the per-draw ``line_{amp,mu,sig}_per_component`` arrays from
    ``self.pred_out``. Returns zeros on ``self.wave`` when predictions are
    missing or not 2D.

    Parameters
    ----------
    draw_index : int
        Posterior draw index. Negative indices are rejected rather than
        counted from the end.
    line_key : str
        Line-name prefix (for example ``'Hb_br'``).

    Raises
    ------
    IndexError
        If ``draw_index`` is outside ``[0, n_draws)``.
    """
    self._ensure_hydrated_from_samples()
    pred = getattr(self, 'pred_out', None)
    if pred is None or 'line_amp_per_component' not in pred:
        return np.zeros_like(self.wave, dtype=float)
    amp_draws, mu_draws, sig_draws = (
        np.asarray(pred[key])
        for key in ('line_amp_per_component', 'line_mu_per_component', 'line_sig_per_component')
    )
    if any(arr.ndim != 2 for arr in (amp_draws, mu_draws, sig_draws)):
        return np.zeros_like(self.wave, dtype=float)
    idx = int(draw_index)
    n_draws = amp_draws.shape[0]
    if not 0 <= idx < n_draws:
        raise IndexError(f'draw_index {idx} is out of bounds for {n_draws} posterior draws')
    return self._line_profile_from_params(
        line_key=line_key,
        amp=amp_draws[idx],
        mu=mu_draws[idx],
        sig=sig_draws[idx],
    )
[docs]
def line_props(self, profile: np.ndarray, wave: np.ndarray | None = None) -> tuple[float, float]:
"""Return ``(fwhm_kms, integrated_area)`` from a line profile.
Parameters
----------
profile : ndarray
Line profile values.
wave : ndarray or None, optional
Wavelength array. If ``None``, ``self.wave`` is used.
"""
p = np.asarray(profile, dtype=float)
w = np.asarray(self.wave if wave is None else wave, dtype=float)
if p.size == 0 or w.size == 0 or p.size != w.size:
return np.nan, np.nan
if not np.any(np.isfinite(p)) or np.nanmax(p) <= 0:
return np.nan, np.nan
ipeak = int(np.nanargmax(p))
peak_lam = w[ipeak]
half = 0.5 * p[ipeak]
idx = np.where(p >= half)[0]
area = float(np.trapezoid(np.clip(p, 0.0, None), w))
if idx.size < 2 or not np.isfinite(peak_lam) or peak_lam <= 0:
return np.nan, area
fwhm_a = w[idx[-1]] - w[idx[0]]
fwhm_kms = C_KMS * fwhm_a / peak_lam
return float(fwhm_kms), area
[docs]
def line_props_from_profile(self, wave: np.ndarray, profile: np.ndarray) -> tuple[float, float]:
    """Backward-compatible alias for :meth:`line_props` with ``(wave, profile)`` order.

    Parameters
    ----------
    wave : ndarray
        Wavelength array.
    profile : ndarray
        Line profile values.

    Returns
    -------
    tuple of float
        Whatever :meth:`line_props` returns: ``(fwhm_kms, integrated_area)``.
    """
    props = self.line_props(wave=wave, profile=profile)
    return props
[docs]
def save_result(self, conti_result, conti_result_type, conti_result_name, line_result, line_result_type, line_result_name, save_fits_name):
    """Write the continuum and line summary as a single-row CSV file.

    The concatenated arrays are kept on ``self`` as ``all_result``,
    ``all_result_type`` and ``all_result_name``. The CSV is written to
    ``self.output_path``, or the current directory when that is ``None``.
    The directory is created if needed.

    Parameters
    ----------
    conti_result, line_result : ndarray
        Continuum and line result values.
    conti_result_type, line_result_type : ndarray
        Legacy dtype tags (stored but not enforced).
    conti_result_name, line_result_name : ndarray
        Column names for continuum and line outputs.
    save_fits_name : str
        Output basename; ``'.csv'`` is appended.
    """
    merged = [
        np.concatenate(pair)
        for pair in (
            (conti_result, line_result),
            (conti_result_type, line_result_type),
            (conti_result_name, line_result_name),
        )
    ]
    self.all_result, self.all_result_type, self.all_result_name = merged
    table = pd.DataFrame([self.all_result], columns=self.all_result_name)
    target_dir = '.' if self.output_path is None else self.output_path
    os.makedirs(target_dir, exist_ok=True)
    csv_path = os.path.join(target_dir, save_fits_name + '.csv')
    table.to_csv(csv_path, index=False)
    print(f"Saved results table: {csv_path}")
def _posterior_series(self, param_names=None, max_vector_elems=2):
    """Flatten posterior samples into labeled 1-D series for diagnostics.

    Parameters
    ----------
    param_names : list[str] | str | None, optional
        Parameter selector. The accepted forms are:

        - ``'all'``: every key of ``self.numpyro_samples``, sorted.
        - ``None``: a default set of continuum and jitter parameters.
        - Any other string: a single parameter name.
        - Otherwise: an iterable of names.

        Names that are absent from the samples are skipped.
    max_vector_elems : int or None, optional
        Maximum number of flattened elements to expand per multi-dimensional
        key. ``None`` or a negative value expands every element.

    Returns
    -------
    list[tuple[str, ndarray]]
        ``(label, draws)`` pairs. 1-D keys keep their name. Higher-rank keys
        are reshaped to ``(n_draws, -1)`` and labeled ``name[i]``. The list is
        empty when no posterior samples are stored. Scalar (0-d) entries are
        skipped.
    """
    samples = getattr(self, 'numpyro_samples', None)
    if samples is None:
        return []
    if param_names is None:
        param_names = [
            'cont_norm', 'log_frac_host', 'PL_slope', 'Fe_uv_norm', 'Fe_op_norm',
            'Balmer_norm', 'Balmer_Te', 'Balmer_Tau',
            'gal_v_kms', 'gal_sigma_kms',
            'frac_jitter', 'add_jitter',
        ]
    elif isinstance(param_names, str):
        # Treat a single string as one name, not as an iterable of characters.
        param_names = sorted(samples.keys()) if param_names == 'all' else [param_names]
    # A limit of None or < 0 means "expand every element".
    limit = None if max_vector_elems is None or int(max_vector_elems) < 0 else int(max_vector_elems)
    out = []
    for name in param_names:
        if name not in samples:
            continue
        arr = np.asarray(samples[name])
        if arr.ndim == 1:
            out.append((name, arr))
        elif arr.ndim >= 2:
            # The leading axis indexes posterior draws; flatten the rest.
            flat = arr.reshape(arr.shape[0], -1)
            nflat = flat.shape[1]
            ncomp = nflat if limit is None else min(nflat, limit)
            out.extend((f'{name}[{i}]', flat[:, i]) for i in range(ncomp))
    return out
@staticmethod
def _format_wave_label(w0):
    """Render a continuum wavelength as a token for attribute/column names.

    Finite values within 1e-6 of an integer are printed without a decimal
    point (``5100.0`` -> ``'5100'``). Any other float uses ``str`` with
    ``'.'`` replaced by ``'p'`` (``1350.5`` -> ``'1350p5'``). Inputs that
    ``float()`` rejects are returned as ``str(w0)``.
    """
    try:
        value = float(w0)
    except Exception:
        return str(w0)
    is_integral = np.isfinite(value) and abs(value - round(value)) < 1e-6
    if not is_integral:
        return str(value).replace('.', 'p')
    return str(int(round(value)))
@staticmethod
def _build_result_arrays(entries):
    """Split ``(name, value, type)`` records into the legacy result arrays.

    Parameters
    ----------
    entries : iterable of (name, value, type)
        Result records, each unpacked into exactly three fields. Any
        iterable is accepted. It is materialized once, so one-shot
        iterators such as generators fill all three arrays instead of only
        the first.

    Returns
    -------
    tuple of ndarray
        ``(values, types, names)`` as object arrays of equal length, in
        record order.
    """
    rows = [(name, value, dtype) for name, value, dtype in entries]
    return (
        np.array([value for _, value, _ in rows], dtype=object),
        np.array([dtype for _, _, dtype in rows], dtype=object),
        np.array([name for name, _, _ in rows], dtype=object),
    )
@staticmethod
def _filter_half_width_angstrom(filt):
    """Estimate half the wavelength extent of a filter's transmission curve.

    The extent covers the wavelength samples whose response exceeds 1% of
    the peak (NaN-ignoring) response. If fewer than two samples qualify,
    ``0.0`` is returned.

    Parameters
    ----------
    filt : object
        Filter exposing ``wavelength`` and ``response`` arrays, for example
        an entry of the SDSS filter mapping.

    Returns
    -------
    float
        Approximate half-width in Angstrom.
    """
    wave_aa = _filter_wave_to_angstrom_array(filt.wavelength)
    response = np.asarray(filt.response, dtype=float)
    threshold = 0.01 * np.nanmax(response)
    above = wave_aa[response > threshold]
    if above.size < 2:
        return 0.0
    return 0.5 * float(above.max() - above.min())
def _plot_filter_metadata(self, bands):
    """Look up effective wavelength and half-width for the requested SDSS bands.

    Parameters
    ----------
    bands : iterable
        Band identifiers. Each is converted with ``str`` and looked up in
        the SDSS filter mapping.

    Returns
    -------
    valid : ndarray of bool
        One flag per requested band, True when a filter was found.
    eff_wave_obs : ndarray of float
        Effective wavelengths in Angstrom for the found filters only, in
        request order (i.e. already compressed by ``valid``).
    half_width_obs : ndarray of float
        Filter half-widths in Angstrom, aligned with ``eff_wave_obs``.
    """
    lookup = _get_sdss_filters()
    matched = [lookup.get(str(band)) for band in bands]
    valid = np.array([entry is not None for entry in matched], dtype=bool)
    found = [entry for entry in matched if entry is not None]
    if not found:
        return valid, np.array([], dtype=float), np.array([], dtype=float)
    eff_wave_obs = np.array(
        [_filter_wave_to_angstrom_scalar(entry.effective_wavelength) for entry in found],
        dtype=float,
    )
    half_width_obs = np.array(
        [self._filter_half_width_angstrom(entry) for entry in found],
        dtype=float,
    )
    return valid, eff_wave_obs, half_width_obs
@staticmethod
def _style_axis(ax, spine_lw=1.5):
    """Apply the house style to an axis in place.

    The grid is disabled, ticks point inward on all four sides, and every
    spine gets the same line width.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to modify.
    spine_lw : float, optional
        Line width applied to every spine (default 1.5).
    """
    tick_style = dict(
        which='both',
        direction='in',
        top=True,
        right=True,
        length=6,
        width=1.2,
    )
    ax.grid(False)
    ax.tick_params(**tick_style)
    for edge in ax.spines.values():
        edge.set_linewidth(spine_lw)
def _synthetic_photometry_for_plot(self, model_attr='model_total'):
    """Return rest-frame synthetic PSF photometry of a model spectrum for plotting.

    The model in ``getattr(self, model_attr)`` is treated as a rest-frame
    f_lambda sampled on ``self.wave`` in units of 1e-17 erg s^-1 cm^-2 A^-1
    (it is multiplied by 1e-17 to reach cgs). The method then:

    1. Shifts the model to the observed frame.
    2. Integrates it through each SDSS filter named in ``self.psf_bands``
       to get an AB magnitude.
    3. Converts that magnitude back to a rest-frame f_lambda at the
       filter's effective wavelength.

    Parameters
    ----------
    model_attr : str, optional
        Attribute holding the model. Defaults to ``'model_total'``; the
        spectrum plot passes ``'psf_model'`` in PSF-space mode.

    Returns
    -------
    tuple of ndarray or None
        ``(x, xerr, y, yerr)``, where:

        - ``x``: rest-frame effective wavelengths (A).
        - ``xerr``: rest-frame filter half-widths (A).
        - ``y``: f_lambda values.
        - ``yerr``: errors propagated from ``self.psf_mag_errs``.

        Returns ``None`` when PSF photometry is disabled, inputs are missing
        or inconsistent, or no band gives a finite positive synthetic flux.
    """
    if not bool(getattr(self, 'use_psf_phot', False)):
        return None
    bands = list(getattr(self, 'psf_bands', []) or [])
    mag_errs = np.asarray(getattr(self, 'psf_mag_errs', []), dtype=float)
    if len(bands) == 0 or mag_errs.size != len(bands):
        return None
    if len(getattr(self, 'wave', [])) == 0:
        return None
    model_rf = np.asarray(getattr(self, model_attr, []), dtype=float)
    if model_rf.size != len(self.wave) or not np.any(np.isfinite(model_rf)):
        return None
    wave_rf = np.asarray(self.wave, dtype=float)
    z = float(getattr(self, 'z', 0.0))
    wave_obs = wave_rf * (1.0 + z)
    # Loop invariants: observed-frame cgs f_lambda and a division-safe wavelength grid.
    flam_obs_cgs = 1e-17 * (model_rf / max(1.0 + z, 1e-8))
    wave_obs_safe = np.clip(wave_obs, 1e-8, None)
    filters = _get_sdss_filters()
    c_ang_s = 2.99792458e18
    x_rf, xerr_rf, y_rf, yerr_rf = [], [], [], []
    for band, mag_err in zip(bands, mag_errs):
        filt = filters.get(str(band))
        if filt is None or not np.isfinite(mag_err) or mag_err <= 0:
            continue
        filt_wave = _filter_wave_to_angstrom_array(filt.wavelength)
        filt_trans = np.asarray(filt.response, dtype=float)
        trans = np.interp(wave_obs, filt_wave, filt_trans, left=0.0, right=0.0)
        trans = np.clip(trans, 0.0, None)
        # Photon-weighted mean f_nu: int(f_lambda T lambda dlambda) / int(T c / lambda dlambda).
        num = float(np.trapezoid(flam_obs_cgs * trans * wave_obs, wave_obs))
        den = float(np.trapezoid(trans * c_ang_s / wave_obs_safe, wave_obs))
        if not np.isfinite(num) or not np.isfinite(den) or num <= 0 or den <= 0:
            continue
        mag_syn = -2.5 * np.log10(num / den) - 48.60
        if not np.isfinite(mag_syn):
            continue
        eff_wave_obs = _filter_wave_to_angstrom_scalar(filt.effective_wavelength)
        half_width_obs = self._filter_half_width_angstrom(filt)
        fnu_obs = 10.0 ** (-0.4 * (mag_syn + 48.60))
        flam_band_obs_cgs = fnu_obs * c_ang_s / max(eff_wave_obs ** 2, 1e-30)
        flam_band_obs = flam_band_obs_cgs / 1e-17
        flam_band_rf = flam_band_obs * (1.0 + z)
        # d(flux)/flux = 0.4 ln(10) d(mag)
        flam_band_rf_err = flam_band_rf * (0.4 * np.log(10.0) * float(mag_err))
        x_rf.append(eff_wave_obs / (1.0 + z))
        xerr_rf.append(half_width_obs / (1.0 + z))
        y_rf.append(flam_band_rf)
        yerr_rf.append(flam_band_rf_err)
    if len(x_rf) == 0:
        return None
    return (
        np.asarray(x_rf, dtype=float),
        np.asarray(xerr_rf, dtype=float),
        np.asarray(y_rf, dtype=float),
        np.asarray(yerr_rf, dtype=float),
    )
def _observed_photometry_for_plot(self):
    """Convert stored PSF magnitudes into rest-frame f_lambda points for plotting.

    AB magnitudes in ``self.psf_mags`` are turned into f_lambda (in units
    of 1e-17 erg s^-1 cm^-2 A^-1) at each filter's effective wavelength,
    then moved to the rest frame using ``self.z``. Bands without a known
    filter, or with a non-finite magnitude or non-positive or non-finite
    error, are dropped.

    Returns
    -------
    tuple of ndarray or None
        ``(x, xerr, y, yerr)``, where:

        - ``x``: rest-frame effective wavelengths (A).
        - ``xerr``: rest-frame filter half-widths (A).
        - ``y``: rest-frame f_lambda values.
        - ``yerr``: their errors.

        Returns ``None`` when photometry is disabled, the arrays are
        inconsistent, or no band survives the filtering.
    """
    if not bool(getattr(self, 'use_psf_phot', False)):
        return None
    bands = list(getattr(self, 'psf_bands', []) or [])
    mags = np.asarray(getattr(self, 'psf_mags', []), dtype=float)
    mag_errs = np.asarray(getattr(self, 'psf_mag_errs', []), dtype=float)
    nband = len(bands)
    if nband == 0 or nband != mags.size or nband != mag_errs.size:
        return None
    redshift = float(getattr(self, 'z', 0.0))
    speed_of_light_aa = 2.99792458e18
    has_filter, eff_wave_obs, half_width_obs = self._plot_filter_metadata(bands)
    good_phot = np.isfinite(mags) & np.isfinite(mag_errs) & (mag_errs > 0)
    keep = has_filter & good_phot
    if not keep.any():
        return None
    # Filter metadata only covers bands that have a filter, so compress the
    # photometric mask to that subset before indexing it.
    meta_keep = good_phot[has_filter]
    eff_wave_obs = eff_wave_obs[meta_keep]
    half_width_obs = half_width_obs[meta_keep]
    mags = mags[keep]
    mag_errs = mag_errs[keep]
    fnu_obs = 10.0 ** (-0.4 * (mags + 48.60))
    flam_obs = fnu_obs * speed_of_light_aa / np.clip(eff_wave_obs ** 2, 1e-30, None) / 1e-17
    flam_rf = flam_obs * (1.0 + redshift)
    flam_rf_err = flam_rf * (0.4 * np.log(10.0) * mag_errs)
    stretch = 1.0 + redshift
    return (
        eff_wave_obs / stretch,
        half_width_obs / stretch,
        flam_rf,
        flam_rf_err,
    )
[docs]
def plot_trace(
    self,
    param_names=None,
    max_vector_elems=2,
    save_fig_path=None,
    save_fig_name=None,
    show_plot=False,
):
    """Draw one stacked trace panel per posterior series.

    Parameters
    ----------
    param_names : list[str] | str | None, optional
        Parameter selector forwarded to ``_posterior_series``. Use ``'all'``
        for every posterior key.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``self.output_path`` (or ``'.'`` when that is unset).
    save_fig_name : str or None, optional
        Output filename override. Defaults to ``<self.filename>_trace.pdf``.
    show_plot : bool, optional
        If True, display the figure with ``plt.show()``. Defaults to False so
        the call is safe in headless terminals.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure, also stored on ``self.trace_fig``. Returns ``None`` when
        there are no series to plot. The figure is closed after saving when
        ``self.save_fig`` is true.
    """
    series = self._posterior_series(param_names=param_names, max_vector_elems=max_vector_elems)
    if not series:
        return None
    npanel = len(series)
    fig, grid = plt.subplots(npanel, 1, figsize=(10, max(2.2 * npanel, 4)), sharex=True, squeeze=False)
    panels = grid[:, 0]
    for panel, (label, draws) in zip(panels, series):
        panel.plot(np.arange(len(draws)), draws, color='tab:blue', lw=0.8)
        panel.set_ylabel(label, fontsize=9)
        self._style_axis(panel)
    panels[-1].set_xlabel('Sample', fontsize=10)
    fig.tight_layout()
    if show_plot:
        plt.show()
    if self.save_fig:
        out_name = save_fig_name if save_fig_name is not None else f'{self.filename}_trace.pdf'
        save_dir = save_fig_path if save_fig_path is not None else self.output_path
        save_dir = '.' if save_dir is None else save_dir
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, out_name)
        fig.savefig(out_file)
        print(f"Saved trace plot: {out_file}")
        plt.close(fig)
    self.trace_fig = fig
    return fig
[docs]
def plot_corner(
    self,
    param_names=None,
    max_vector_elems=2,
    bins=30,
    max_points=5000,
    save_fig_path=None,
    save_fig_name=None,
    show_plot=False,
):
    """Plot posterior projections with ``corner.corner``.

    Parameters
    ----------
    param_names : list[str] | str | None, optional
        Parameter selector. Use ``'all'`` to include all posterior keys.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    bins : int, optional
        Bin count for the 1-D marginal histograms. The 2-D panels use the
        coarser ``max(8, bins // 2)``.
    max_points : int or None, optional
        Maximum posterior draws to plot. Longer chains are thinned evenly.
        ``None`` or a non-positive value keeps every draw.
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``self.output_path`` (or ``'.'`` when unset).
    save_fig_name : str or None, optional
        Output filename override. Defaults to ``<self.filename>_corner.pdf``.
    show_plot : bool, optional
        If True, display the figure interactively with ``plt.show()``.
        Defaults to False so diagnostics are safe to call in headless
        terminals.

    Returns
    -------
    matplotlib.figure.Figure or None
        The corner figure, also stored on ``self.corner_fig``. Returns
        ``None`` when there are no posterior series.

    Raises
    ------
    ImportError
        If the optional ``corner`` package is not installed.
    """
    series = self._posterior_series(param_names=param_names, max_vector_elems=max_vector_elems)
    if len(series) == 0:
        return None
    try:
        import corner
    except ImportError as exc:
        raise ImportError("plot_corner requires the 'corner' package to be installed.") from exc
    labels = [label for label, _ in series]
    data = np.column_stack([draws for _, draws in series])
    if max_points is not None and int(max_points) > 0 and data.shape[0] > int(max_points):
        # Evenly thin long chains instead of taking a random subset.
        idx = np.linspace(0, data.shape[0] - 1, int(max_points), dtype=int)
        data = data[idx]
    # corner.corner has no nested ``hist2d_kwargs`` argument: its extra keywords go
    # straight to corner.hist2d, whose **kwargs would swallow such a dict unused.
    # So pass ``levels`` (1/2/3-sigma 2-D contours) at top level. Set the coarser
    # 2-D bin count via ``bins``, and restore ``bins`` for 1-D via ``hist_bin_factor``.
    n_bins_1d = int(bins)
    n_bins_2d = max(8, n_bins_1d // 2)
    fig = corner.corner(
        data,
        labels=labels,
        bins=n_bins_2d,
        hist_bin_factor=n_bins_1d / n_bins_2d,
        levels=(0.393, 0.865, 0.989),
        show_titles=True,
        color="black",
        quantiles=[0.16, 0.5, 0.84],
        plot_datapoints=False,
        plot_contours=True,
        fill_contours=False,
        no_fill_contours=True,
        smooth=0.8,
        smooth1d=0.8,
        max_n_ticks=3,
        quiet=True,
        labelpad=0.3,
        label_kwargs={"fontsize": 9},
        title_kwargs={"fontsize": 9},
        use_math_text=False,
        title_fmt=".3g",
    )
    for ax in fig.axes:
        ax.tick_params(axis='both', which='major', labelsize=8)
        self._style_axis(ax)
    fig.subplots_adjust(left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.06, hspace=0.06)
    if show_plot:
        plt.show()
    if self.save_fig:
        out_name = f'{self.filename}_corner.pdf' if save_fig_name is None else save_fig_name
        save_dir = self.output_path if save_fig_path is None else save_fig_path
        if save_dir is None:
            save_dir = '.'
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, out_name)
        fig.savefig(out_file)
        print(f"Saved corner plot: {out_file}")
        plt.close(fig)
    self.corner_fig = fig
    return fig
[docs]
def plot_mcmc_diagnostics(self, do_trace=True, do_corner=True,
                          param_names=None,
                          max_vector_elems=2,
                          corner_bins=30, corner_max_points=2000,
                          save_fig_path=None,
                          show_plot=False):
    """Render trace and/or corner diagnostics with one call.

    Parameters
    ----------
    do_trace : bool, optional
        If True, call :meth:`plot_trace`.
    do_corner : bool, optional
        If True, call :meth:`plot_corner`.
    param_names : list[str] | str | None
        Parameter selector shared by both plots. Use ``'all'`` to include
        every posterior parameter.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    corner_bins : int, optional
        Forwarded to :meth:`plot_corner` as ``bins``.
    corner_max_points : int, optional
        Forwarded to :meth:`plot_corner` as ``max_points``.
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``self.output_path`` (or ``'.'`` when unset).
    show_plot : bool, optional
        If True, display each enabled figure with ``plt.show()``. Defaults to
        False so the call is safe in headless terminals.

    Returns
    -------
    None
        The figures remain available on ``self.trace_fig`` /
        ``self.corner_fig`` as set by the individual plot methods.
    """
    shared = dict(
        param_names=param_names,
        max_vector_elems=max_vector_elems,
        save_fig_path=save_fig_path,
        show_plot=show_plot,
    )
    if do_trace:
        self.plot_trace(**shared)
    if do_corner:
        self.plot_corner(bins=corner_bins, max_points=corner_max_points, **shared)
[docs]
def plot_fig(self, save_fig_path=None, broad_fwhm=1200, plot_legend=True, ylims=None, plot_residual=True, show_title=True,
             plot_1sigma=True, sigma_alpha=0.12, show_plot=True, plot_psf_space=False, plot_intrinsic_powerlaw=False):
    """Plot data, model components, line decomposition, and residuals.

    Parameters
    ----------
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses ``self.output_path``
        (or ``'.'`` when unset).
    broad_fwhm : float, optional
        Reserved broad-line threshold (kept for compatibility; unused).
    plot_legend : bool, optional
        If True, draw a legend.
    ylims : tuple[float, float] or None, optional
        Optional y-axis limits for the spectrum panel. When ``None``, the
        limits span the 1st-99th percentile (padded by 15%) of ``self.flux``
        and of the model arrays that are actually drawn.
    plot_residual : bool, optional
        If True, draw residual panel below the spectrum.
    show_title : bool, optional
        Reserved title toggle kept for compatibility (unused).
    plot_1sigma : bool, optional
        If True, draw 16-84% posterior bands for available components.
    sigma_alpha : float, optional
        Alpha transparency of 1-sigma bands.
    show_plot : bool, optional
        If True, call ``plt.show()``. If False, figure is created/saved without display.
    plot_psf_space : bool, optional
        If True, plot the PSF-space model/components. In this mode the
        residual panel is suppressed because the observed spectrum remains
        on the fiber scale.
    plot_intrinsic_powerlaw : bool, optional
        If True, overlay the intrinsic AGN power law before multiplicative
        polynomial tilt and additive edge corrections.

    Notes
    -----
    Components whose peak absolute value is below 0.5% of the data's
    95th-percentile absolute flux are still drawn but get no legend entry.
    Tick label sizes are set globally through ``matplotlib.rc``. The figure
    is stored on ``self.fig`` and closed after saving when ``self.save_fig``
    is true. The method returns ``None``.
    """
    if bool(getattr(self, "_resumed_from_samples", False)):
        self._ensure_hydrated_from_samples()
    # Global rcParams change: persists for figures created after this call.
    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)
    psf_total_model = np.asarray(getattr(self, 'psf_model', []), dtype=float)
    use_psf_space = bool(plot_psf_space) and psf_total_model.size == len(getattr(self, 'wave', [])) and np.any(np.isfinite(psf_total_model))
    # The observed spectrum stays on the fiber scale, so PSF-space residuals are not meaningful.
    residual_enabled = bool(plot_residual) and not use_psf_space
    if residual_enabled:
        fig, (ax, ax_resid) = plt.subplots(
            2,
            1,
            figsize=(15, 8),
            sharex=True,
            gridspec_kw={'height_ratios': [4.0, 1.2], 'hspace': 0.05},
        )
    else:
        fig, ax = plt.subplots(1, 1, figsize=(15, 6))
        ax_resid = None
    # Legend threshold: components below 0.5% of the 95th-percentile |flux| are considered negligible.
    flux_ref = float(np.nanpercentile(np.abs(self.flux[np.isfinite(self.flux)]), 95)) if np.any(np.isfinite(self.flux)) else 1.0
    comp_floor = max(1e-8, 0.005 * flux_ref)
    psf_scale = float(getattr(self, 'scale_psf', np.nan))
    psf_scale = psf_scale if np.isfinite(psf_scale) else 1.0
    total_model_plot = self.psf_model if use_psf_space else self.model_total
    host_plot = self.host_psf if use_psf_space else self.host
    pl_plot = psf_scale * self.f_pl_model if use_psf_space else self.f_pl_model
    pl_intrinsic = np.asarray(getattr(self, 'f_pl_model_intrinsic', []), dtype=float)
    pl_intrinsic_plot = psf_scale * pl_intrinsic if use_psf_space and pl_intrinsic.size == len(self.wave) else pl_intrinsic
    fe_total_model = psf_scale * (self.f_fe_mgii_model + self.f_fe_balmer_model) if use_psf_space else (self.f_fe_mgii_model + self.f_fe_balmer_model)
    bc_plot = psf_scale * self.f_bc_model if use_psf_space else self.f_bc_model
    line_plot = self.line_psf if use_psf_space else self.f_line_model
    total_model_label = 'total model (PSF)' if use_psf_space else 'total model'
    host_label = 'host galaxy (PSF)' if use_psf_space else 'host galaxy'
    powerlaw_label = 'power law (PSF)' if use_psf_space else 'power law'
    intrinsic_powerlaw_label = 'intrinsic power law (PSF)' if use_psf_space else 'intrinsic power law'
    fe_label = 'Fe II (PSF)' if use_psf_space else 'Fe II'
    bc_label = 'Balmer continuum (PSF)' if use_psf_space else 'Balmer continuum'
    line_label = 'total lines (PSF)' if use_psf_space else 'total lines'
    custom_component_colors = [
        'darkorange',
        'crimson',
        'slateblue',
        'seagreen',
        'saddlebrown',
        'deeppink',
    ]
    custom_components = list(getattr(self, 'custom_components', {}).items())

    def _show_component(arr):
        # True when the finite part of ``arr`` reaches the legend threshold.
        arr = np.asarray(arr, dtype=float)
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return False
        return float(np.nanmax(np.abs(arr))) >= comp_floor

    def _finite_component_values(*arrays):
        # Collect the finite values of every non-None array (for y-limit percentiles).
        vals = []
        for arr in arrays:
            if arr is None:
                continue
            arr = np.asarray(arr, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size > 0:
                vals.append(arr)
        return vals

    if plot_1sigma and hasattr(self, 'pred_bands') and not use_psf_space:
        band_colors = {
            'total_model': 'b',
            'host': 'purple',
            'PL': 'orange',
            'FeII': 'teal',
            'Balmer_cont': 'y',
            'lines': 'lightskyblue',
        }
        for key, color in band_colors.items():
            if key not in self.pred_bands:
                continue
            lo, hi = self.pred_bands[key]
            if len(lo) == len(self.wave) and _show_component(0.5 * (np.asarray(lo) + np.asarray(hi))):
                ax.fill_between(
                    self.wave,
                    lo,
                    hi,
                    color=color,
                    alpha=sigma_alpha,
                    linewidth=0,
                    zorder=0,
                    rasterized=True,
                )
        for idx, (name, model) in enumerate(custom_components):
            if name not in self.pred_bands:
                continue
            lo, hi = self.pred_bands[name]
            if len(lo) != len(self.wave) or not _show_component(model):
                continue
            ax.fill_between(
                self.wave,
                lo,
                hi,
                color=custom_component_colors[idx % len(custom_component_colors)],
                alpha=sigma_alpha,
                linewidth=0,
                zorder=0,
                rasterized=True,
            )
    if bool(plot_intrinsic_powerlaw) and hasattr(self, 'pred_bands') and not use_psf_space:
        if 'PL_intrinsic' in self.pred_bands:
            lo, hi = self.pred_bands['PL_intrinsic']
            if len(lo) == len(self.wave) and _show_component(0.5 * (np.asarray(lo) + np.asarray(hi))):
                ax.fill_between(
                    self.wave,
                    lo,
                    hi,
                    color='darkorange',
                    alpha=sigma_alpha,
                    linewidth=0,
                    zorder=0,
                    rasterized=True,
                )
    ax.plot(
        self.wave_prereduced,
        self.flux_prereduced,
        color='k' if not use_psf_space else 'gray',
        lw=1,
        label='data' if not use_psf_space else 'fiber data',
        zorder=2,
        alpha=1.0 if not use_psf_space else 0.6,
        rasterized=True,
    )
    ax.plot(self.wave, total_model_plot, color='b', lw=1.8, label=total_model_label, zorder=6, rasterized=True)
    if _show_component(host_plot):
        ax.plot(self.wave, host_plot, color='purple', lw=1.8, label=host_label, zorder=4, rasterized=True)
    else:
        ax.plot(self.wave, host_plot, color='purple', lw=1.8, zorder=4, rasterized=True)
    if _show_component(pl_plot):
        ax.plot(self.wave, pl_plot, color='orange', lw=1.5, label=powerlaw_label, zorder=5, rasterized=True)
    else:
        ax.plot(self.wave, pl_plot, color='orange', lw=1.5, zorder=5, rasterized=True)
    intrinsic_drawn = bool(plot_intrinsic_powerlaw) and pl_intrinsic_plot.size == len(self.wave)
    if intrinsic_drawn:
        if _show_component(pl_intrinsic_plot):
            ax.plot(
                self.wave,
                pl_intrinsic_plot,
                color='darkorange',
                lw=1.4,
                ls='--',
                label=intrinsic_powerlaw_label,
                zorder=5,
                rasterized=True,
            )
        else:
            ax.plot(self.wave, pl_intrinsic_plot, color='darkorange', lw=1.4, ls='--', zorder=5, rasterized=True)
    if _show_component(fe_total_model):
        ax.plot(self.wave, fe_total_model, color='teal', lw=1.2, label=fe_label, zorder=5, rasterized=True)
    else:
        ax.plot(self.wave, fe_total_model, color='teal', lw=1.2, zorder=5, rasterized=True)
    if _show_component(bc_plot):
        ax.plot(self.wave, bc_plot, color='y', lw=1.2, label=bc_label, zorder=5, rasterized=True)
    else:
        ax.plot(self.wave, bc_plot, color='y', lw=1.2, zorder=5, rasterized=True)
    # All ``bal_*`` custom components share a single "BAL" legend entry.
    bal_legend_drawn = False
    for idx, (name, model) in enumerate(custom_components):
        color = custom_component_colors[idx % len(custom_component_colors)]
        is_bal_component = str(name).startswith('bal_')
        if is_bal_component:
            label = 'BAL'
        else:
            label = name.replace('_', ' ')
        if _show_component(model):
            draw_label = label
            if is_bal_component and bal_legend_drawn:
                draw_label = None
            ax.plot(self.wave, model, color=color, lw=1.4, label=draw_label, zorder=5, rasterized=True)
            if is_bal_component and draw_label is not None:
                bal_legend_drawn = True
        else:
            ax.plot(self.wave, model, color=color, lw=1.4, zorder=5, rasterized=True)
    lines_drawn = len(line_plot) == len(self.wave)
    if lines_drawn:
        # Like the other components, a negligible line model is drawn without a legend entry.
        ax.plot(
            self.wave,
            line_plot,
            color='lightskyblue',
            lw=1.5,
            label=line_label if _show_component(line_plot) else None,
            zorder=5,
            rasterized=True,
        )
    obs_phot_points = self._observed_photometry_for_plot()
    if obs_phot_points is not None:
        phot_x, phot_xerr, phot_y, phot_yerr = obs_phot_points
        ax.errorbar(
            phot_x,
            phot_y,
            yerr=phot_yerr,
            xerr=phot_xerr,
            fmt='s',
            ms=7,
            color='black',
            mfc='white',
            mec='black',
            mew=0.8,
            elinewidth=1.0,
            capsize=3,
            alpha=0.95,
            zorder=8,
            label='PSF photometry',
        )
    syn_phot_points = self._synthetic_photometry_for_plot(
        model_attr='psf_model' if use_psf_space else 'model_total'
    )
    if syn_phot_points is not None:
        phot_x, phot_xerr, phot_y, phot_yerr = syn_phot_points
        ax.errorbar(
            phot_x,
            phot_y,
            yerr=phot_yerr,
            xerr=phot_xerr,
            fmt='o',
            ms=7,
            color='crimson',
            mec='white',
            mew=0.8,
            elinewidth=1.0,
            capsize=3,
            alpha=0.95,
            zorder=8,
            label='synthetic photometry',
        )
    # Plot individual Gaussian line components: broad (names containing '_br') in red, narrow in green.
    if (hasattr(self, 'line_component_amp_median')
            and hasattr(self, 'line_component_mu_median')
            and hasattr(self, 'line_component_sig_median')
            and hasattr(self, 'tied_line_meta')
            and len(self.line_component_amp_median) > 0):
        # Components are Gaussians in ln(wavelength).
        lnwave = np.log(self.wave)
        n_comp = len(self.line_component_amp_median)
        comp_labels = self.tied_line_meta.get('names', [''] * n_comp)
        drew_broad_label = False
        drew_narrow_label = False
        show_line_leg = _show_component(line_plot)
        for i in range(n_comp):
            amp = float(self.line_component_amp_median[i])
            mu = float(self.line_component_mu_median[i])
            sig = float(self.line_component_sig_median[i])
            if not np.isfinite(amp) or not np.isfinite(mu) or not np.isfinite(sig) or sig <= 0:
                continue
            prof = amp * np.exp(-0.5 * ((lnwave - mu) / sig) ** 2)
            # Keep component plotting consistent with polynomial correction if enabled.
            if hasattr(self, 'f_poly_model') and len(self.f_poly_model) == len(prof):
                prof = prof * self.f_poly_model
            # Components without a name (short metadata) are treated as narrow.
            cname = str(comp_labels[i]).lower() if i < len(comp_labels) else ''
            is_broad = '_br' in cname
            if use_psf_space:
                line_scale = psf_scale if is_broad else psf_scale * float(getattr(self, 'eta_psf', 1.0))
                prof = line_scale * prof
            if is_broad:
                lbl = 'broad components' if (show_line_leg and not drew_broad_label) else None
                ax.plot(
                    self.wave,
                    prof,
                    color='red',
                    lw=0.7,
                    alpha=0.35,
                    zorder=3,
                    label=lbl,
                    rasterized=True,
                )
                drew_broad_label = True
            else:
                lbl = 'narrow components' if (show_line_leg and not drew_narrow_label) else None
                ax.plot(
                    self.wave,
                    prof,
                    color='green',
                    lw=0.7,
                    alpha=0.25,
                    zorder=3,
                    label=lbl,
                    rasterized=True,
                )
                drew_narrow_label = True
    ax.set_xlim(self.wave.min(), self.wave.max())
    if ylims is None:
        # Only arrays that are actually drawn should drive the automatic limits.
        yvals = _finite_component_values(
            self.flux,
            total_model_plot,
            host_plot,
            pl_plot,
            pl_intrinsic_plot if intrinsic_drawn else None,
            fe_total_model,
            bc_plot,
            line_plot if lines_drawn else None,
            *[model for _, model in custom_components],
        )
        if yvals:
            yplot = np.concatenate(yvals)
            ymin, ymax = np.nanpercentile(yplot, [1, 99])
            if np.isfinite(ymin) and np.isfinite(ymax) and ymax > ymin:
                pad = 0.15 * (ymax - ymin)
                ax.set_ylim(float(ymin - pad), float(ymax + pad))
    else:
        ax.set_ylim(ylims[0], ylims[1])
    # Mark common broad-line AGN transitions on the spectrum panel.
    broad_line_markers = [
        ("Ly$\\alpha$", 1215.67),
        ("CIV", 1549.06),
        ("CIII]", 1908.73),
        ("MgII", 2798.75),
        ("H$\\beta$", 4862.68),
        ("H$\\alpha$", 6564.61),
    ]
    xlo, xhi = ax.get_xlim()
    y_top = ax.get_ylim()[1]
    text_x_offset = 0.01 * (xhi - xlo)
    for label, lam0 in broad_line_markers:
        if xlo <= lam0 <= xhi:
            ax.axvline(lam0, color="gray", ls="--", lw=0.8, alpha=0.35, zorder=1)
            ax.text(
                lam0 - text_x_offset,
                y_top * 0.985,
                label,
                rotation=90,
                va="top",
                ha="center",
                fontsize=12,
                color="dimgray",
                alpha=0.9,
                zorder=7,
            )
    if residual_enabled and len(total_model_plot) == len(self.wave) and ax_resid is not None:
        resid = self.flux - total_model_plot
        ax_resid.plot(
            self.wave,
            resid,
            color='gray',
            ls='dotted',
            lw=1.0,
            zorder=2,
            rasterized=True,
        )
        ax_resid.axhline(0.0, color='k', ls='--', lw=0.8, zorder=1)
        r = resid[np.isfinite(resid)]
        if r.size > 0:
            # Symmetric limits from the 99th percentile of |residual|, so single outliers don't dominate.
            rlim = np.nanpercentile(np.abs(r), 99)
            if np.isfinite(rlim) and rlim > 0:
                ax_resid.set_ylim(-1.15 * rlim, 1.15 * rlim)
        ax_resid.set_ylabel('resid', fontsize=20)
        self._style_axis(ax_resid)
    if residual_enabled and ax_resid is not None:
        ax_resid.set_xlabel('Rest Wavelength (Å)', fontsize=20)
    else:
        ax.set_xlabel('Rest Wavelength (Å)', fontsize=20)
    ax.set_ylabel(r'$f_{\lambda}$ (10$^{-17}$ erg s$^{-1}$ cm$^{-2}$ Å$^{-1}$)', fontsize=20)
    self._style_axis(ax)
    if plot_legend:
        ax.legend(loc="upper right", frameon=True, framealpha=0.9, fontsize=12, ncol=2)
    if show_plot:
        plt.show()
    if self.save_fig:
        save_dir = self.output_path if save_fig_path is None else save_fig_path
        if save_dir is None:
            save_dir = '.'
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, self.filename + '.pdf')
        fig.savefig(out_file, dpi=150)
        print(f"Saved spectrum plot: {out_file}")
        plt.close(fig)
    self.fig = fig
    return