JaxQSOFit Tutorial: SDSS Example + Broad-Line Measurements
This notebook demonstrates the main features of jaxqsofit using a real SDSS spectrum near:
coord = SkyCoord(184.0307, -2.2383, unit='deg')
Recommended fitting mode: fit_method='optax+nuts'. This runs the staged SVI/Optax MAP initializer, then performs full posterior sampling with NUTS.
It covers:
- Fetching a spectrum with astroquery
- Running fits (nuts, optax, optax+nuts)
- Overriding priors with prior_config
- Measuring broad-line FWHM and luminosity from fitted components
Note that the current default, fit_method="optax", runs only the staged SVI/Optax MAP optimization (a continuum warm start, then the full model) and returns no posterior samples. fit_method="optax+nuts" uses that staged MAP point to initialize NUTS.
[ ]:
import numpy as np
import matplotlib.pyplot as plt
from astroquery.sdss import SDSS
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.cosmology import FlatLambdaCDM
from jaxqsofit import QSOFit, build_default_prior_config
1. Download one SDSS spectrum
[ ]:
coord = SkyCoord(184.0307, -2.2383, unit='deg')
xid = SDSS.query_region(coord, spectro=True, radius=5 * u.arcsec)
sp = SDSS.get_spectra(matches=xid[:1])[0]
tb = sp[1].data
lam = np.asarray(10 ** tb['loglam'], dtype=float)
flux = np.asarray(tb['flux'], dtype=float)
ivar = np.asarray(tb['ivar'], dtype=float)
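# Keep only pixels with a valid inverse variance. SDSS sets ivar = 0 for bad pixels, and
# giving those pixels a tiny placeholder error would let them dominate the fit.
good = np.isfinite(flux) & np.isfinite(ivar) & (ivar > 0)
lam, flux, ivar = lam[good], flux[good], ivar[good]
err = 1.0 / np.sqrt(ivar)  # 1-sigma flux error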
z = float(sp[2].data['z'][0])
ra = float(coord.ra.deg)
dec = float(coord.dec.deg)
print(f'Nspec pixels: {lam.size}')
print(f'z = {z:.5f}')
plate = int(sp[0].header.get('plateid', 0))
mjd = int(sp[0].header.get('mjd', 0))
fiber = int(sp[0].header.get('fiberid', 0))
sdss_filename = f"{plate:04d}-{mjd}-{fiber:04d}"
2. Build a default prior config (auto-scaled from flux)
You can use defaults directly, or modify selected entries.
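The cell below builds the priors from flux * (1 + z) rather than from the observed flux. One plausible reason is that the fit works in the rest frame and follows PyQSOFit's convention: λ_rest = λ_obs / (1 + z) and f_rest = f_obs (1 + z), which leaves the integrated line flux unchanged. This is an assumption, not confirmed from jaxqsofit's source. A minimal sketch of that transformation (both helper names are hypothetical):

```python
import numpy as np


def to_rest_frame(lam_obs, flux_obs, z):
    """Shift wavelengths to the rest frame and rescale f_lambda by (1 + z).

    The rescaling keeps the integral of f_lambda d(lambda) unchanged.
    """
    lam_obs = np.asarray(lam_obs, dtype=float)
    flux_obs = np.asarray(flux_obs, dtype=float)
    return lam_obs / (1.0 + z), flux_obs * (1.0 + z)


def integrate(lam, flux):
    """Trapezoidal integral of flux over wavelength."""
    lam = np.asarray(lam, dtype=float)
    flux = np.asarray(flux, dtype=float)
    return float(np.sum(0.5 * (flux[1:] + flux[:-1]) * np.diff(lam)))
```

With this convention, integrate(lam, flux) and integrate(*to_rest_frame(lam, flux, z)) agree, so line fluxes measured in the rest frame can go straight into the luminosity formula used in section 5.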
[ ]:
flux_for_priors = flux * (1.0 + z)
prior_config = build_default_prior_config(flux_for_priors)
# Example override: slightly tighter PL slope and Fe normalization priors
prior_config['PL_slope'] = {'loc': -1.5, 'scale': 0.3, 'low': -3.5, 'high': 0.5}
prior_config['log_Fe_uv_norm'] = {'loc': np.log(max(1e-3 * np.median(np.abs(flux_for_priors)), 1e-10)), 'scale': 0.4}
prior_config['log_Fe_op_over_uv'] = {'loc': 0.0, 'scale': 0.4}
# Robust line scale multipliers (optional)
prior_config['line_dmu_scale_mult'] = 0.25
prior_config['line_sig_scale_mult'] = 0.25
prior_config['line_amp_scale_mult'] = 0.20
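To read the overrides above, the helpers below translate them into numbers. They assume that each {'loc', 'scale'} entry is a Normal prior, that 'low'/'high' truncate it, and that the log_* entries are natural logarithms (the cell uses np.log). This is how jaxqsofit appears to interpret these keys, but it is not confirmed here. Both helper names are hypothetical.

```python
import numpy as np
from math import erf, sqrt


def normal_mass_within(loc, scale, low, high):
    """Probability that an untruncated Normal(loc, scale) falls inside [low, high]."""
    def cdf(v):
        return 0.5 * (1.0 + erf((v - loc) / (scale * sqrt(2.0))))
    return cdf(high) - cdf(low)


def log_prior_1sigma_range(loc, scale):
    """1-sigma range, in linear units, of a Normal prior placed on ln(parameter)."""
    return float(np.exp(loc - scale)), float(np.exp(loc + scale))


# PL_slope override: Normal(-1.5, 0.3), truncated to [-3.5, 0.5]
print(normal_mass_within(-1.5, 0.3, -3.5, 0.5))
# log_Fe_op_over_uv override: Normal(0.0, 0.4) on ln(optical/UV Fe ratio)
print(log_prior_1sigma_range(0.0, 0.4))
```

Under these assumptions, the PL_slope bounds sit about 6.7 sigma from loc, so the truncation only excludes runaway slopes. The Fe ratio prior keeps the optical-to-UV ratio within roughly a factor of e^0.4 ≈ 1.49 of unity at 1 sigma.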
3. Run a fit
Recommended:
- fit_method='optax+nuts': staged SVI/Optax initialization, then a full NUTS posterior
Other options:
- fit_method='nuts': full posterior, with slower initialization
- fit_method='optax': staged MAP only; fastest, but no posterior samples
[ ]:
q = QSOFit(
    lam=lam,
    flux=flux,
    err=err,
    z=z,
    ra=ra,
    dec=dec,
    filename=sdss_filename,
    output_path='.',
)
q.fit(
    deredden=True,
    fit_method='optax+nuts',
    fit_lines=True,
    decompose_host=True,
    fit_fe=True,
    fit_bc=False,
    fit_poly=True,
    prior_config=prior_config,
    dsps_ssp_fn='../tempdata.h5',  # DSPS SSP template file; adjust the path for your setup
    optax_steps=600,
    optax_lr=1e-2,
    nuts_warmup=50,   # small warmup/sample counts keep this demo fast;
    nuts_samples=50,  # increase them for science-grade posteriors
    nuts_chains=1,
    plot_fig=True,
    save_fig=False,
    save_result=False,
)
Estimate the goodness of fit (chi-squared per pixel, including the fitted jitter terms):
[ ]:
resid = np.asarray(q.flux) - np.asarray(q.model_total)
sigma = np.asarray(q.err)
# Include the fitted jitter terms (recommended), using their posterior medians.
samples = getattr(q, 'numpyro_samples', None)
if samples is not None:
    frac_j = float(np.median(np.asarray(samples.get("frac_jitter", 0.0))))
    add_j = float(np.median(np.asarray(samples.get("add_jitter", 0.0))))
    sigma = np.sqrt(sigma**2 + (frac_j * np.abs(np.asarray(q.model_total)))**2 + add_j**2)
m = np.isfinite(resid) & np.isfinite(sigma) & (sigma > 0)
chi = resid[m] / sigma[m]  # normalized residuals (not named z, so the redshift is not overwritten)
chi2 = float(np.sum(chi**2))
chi2_per_pixel = float(np.mean(chi**2))  # chi2/N: a reduced chi2 would need the ill-defined number of free parameters
wrms = float(np.sqrt(np.mean(chi**2)))
print("chi2:", chi2)
print("chi2_per_pixel:", chi2_per_pixel)
print("wrms (RMS of normalized residuals):", wrms)
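How close to 1 should chi2_per_pixel be? If the model and errors are correct, chi2 summed over N pixels has mean N and variance 2N. Noise alone therefore scatters chi2/N by about sqrt(2/N) around 1. This ignores the fitted parameters, which is consistent with reporting chi2/N above. The helper name below is hypothetical.

```python
import numpy as np


def chi2_per_pixel_band(n_pix, nsigma=1.0):
    """Range of chi2/N expected from noise alone (chi2 has mean N and variance 2N)."""
    half_width = nsigma * np.sqrt(2.0 / n_pix)
    return 1.0 - half_width, 1.0 + half_width
```

For example, chi2_per_pixel_band(m.sum(), nsigma=3) gives the range allowed by noise for the pixels used above. A value well above that range points to template mismatch or underestimated errors rather than noise.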
4. MCMC diagnostics (trace + corner)
[ ]:
q.plot_mcmc_diagnostics(
    # param_names='all',
    do_trace=True,
    do_corner=True,
    max_vector_elems=2,
    corner_bins=25,
    corner_max_points=1500,
)
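With a single chain and nuts_samples=50, check how many effectively independent draws a trace contains. NumPyro's numpyro.diagnostics module provides more careful estimators. The helper below is a simplified, self-contained single-chain sketch (a hypothetical helper, not part of jaxqsofit). It computes ESS = N / (1 + 2 Σ ρ_k), truncating the sum at the first non-positive autocorrelation ρ_k.

```python
import numpy as np


def effective_sample_size(x, max_lag=None):
    """Rough single-chain ESS: N / (1 + 2 * sum of leading positive autocorrelations)."""
    x = np.asarray(x, dtype=float).ravel()
    n = x.size
    if n < 2:
        return float(n)
    x = x - x.mean()
    var = np.dot(x, x) / n
    if var == 0.0:
        return np.nan  # constant trace: ESS is undefined
    max_lag = n - 1 if max_lag is None else min(max_lag, n - 1)
    tau = 1.0  # integrated autocorrelation time
    for k in range(1, max_lag + 1):
        rho = np.dot(x[:-k], x[k:]) / (n * var)  # lag-k autocorrelation
        if rho <= 0.0:
            break
        tau += 2.0 * rho
    return n / tau
```

Apply it to any scalar entry of q.numpyro_samples (the keys are printed in section 6). An ESS far below nuts_samples means the percentiles in section 5 are noisy. In that case, increasing nuts_samples and nuts_warmup is the first thing to try.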
5. Measure broad-line FWHM and luminosity with posterior errors
QSOFit provides:
- line_profile_from_components(line_key): posterior-median profile built from the line's components
- line_profile_from_draw(draw_index, line_key): profile for a single posterior draw
- line_props(profile, wave=None) -> (fwhm_kms, integrated_area)
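The internals of line_props are not shown in this tutorial. As an independent illustration, the hypothetical helper below computes a FWHM in km/s and an integrated area from a sampled profile. It finds the half-maximum crossings by linear interpolation, converts the wavelength width to velocity with v = c Δλ / λ_peak, and integrates with the trapezoid rule. It is not necessarily how line_props computes these quantities.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s


def fwhm_and_area(wave, profile):
    """FWHM (km/s) and trapezoidal integral of a line profile on a wavelength grid.

    Uses the outermost pixels above half maximum, so blended multi-component
    profiles are measured across their full extent.
    """
    wave = np.asarray(wave, dtype=float)
    prof = np.asarray(profile, dtype=float)
    area = float(np.sum(0.5 * (prof[1:] + prof[:-1]) * np.diff(wave)))
    ipk = int(np.nanargmax(prof))
    peak = prof[ipk]
    if not np.isfinite(peak) or peak <= 0:
        return np.nan, area
    half = 0.5 * peak
    above = np.flatnonzero(prof >= half)
    lo, hi = above[0], above[-1]
    if lo == 0 or hi == prof.size - 1:
        return np.nan, area  # profile does not fall below half maximum inside the grid
    # Interpolate the wavelength of the half-maximum crossing on each side.
    w_left = np.interp(half, [prof[lo - 1], prof[lo]], [wave[lo - 1], wave[lo]])
    w_right = np.interp(half, [prof[hi + 1], prof[hi]], [wave[hi + 1], wave[hi]])
    return C_KMS * (w_right - w_left) / wave[ipk], area
```

For a single Gaussian, this returns 2 sqrt(2 ln 2) ≈ 2.355 times the velocity dispersion.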
The integrated area is in units of 10^-17 erg s^-1 cm^-2. Convert it to a luminosity with L = 4π d_L² F, where d_L is the luminosity distance for an assumed cosmology.
[ ]:
cosmo = FlatLambdaCDM(H0=70, Om0=0.3)

def flux_to_luminosity(area_1e17, z):
    """Convert an integrated line flux in 1e-17 erg/s/cm^2 to a luminosity in erg/s."""
    d_l_cm = cosmo.luminosity_distance(z).to(u.cm).value
    return area_1e17 * 1e-17 * 4.0 * np.pi * d_l_cm**2

# Only the number of posterior draws (first axis) is used here.
amp_draws = np.asarray(q.pred_out['line_amp_per_component'])
for line_key in ["CIV_br", "MgII_br", "Hb_br", "Ha_br"]:
    fwhm_samp, logL_samp = [], []
    for i in range(amp_draws.shape[0]):
        prof = q.line_profile_from_draw(i, line_key)
        fwhm, area = q.line_props(prof)
        fwhm_samp.append(fwhm)
        if np.isfinite(area) and area > 0:
            logL_samp.append(np.log10(flux_to_luminosity(area, q.z)))
        else:
            logL_samp.append(np.nan)  # non-positive or undefined area
    fwhm_samp = np.asarray(fwhm_samp, dtype=float)
    logL_samp = np.asarray(logL_samp, dtype=float)
    # Median with 16th/84th percentiles as 1-sigma errors; NaN draws are ignored.
    f16, f50, f84 = np.nanpercentile(fwhm_samp, [16, 50, 84])
    l16, l50, l84 = np.nanpercentile(logL_samp, [16, 50, 84])
    print(
        f"{line_key:8s} "
        f"FWHM={f50:.1f} (+{f84-f50:.1f}/-{f50-f16:.1f}) km/s "
        f"logL={l50:.3f} (+{l84-l50:.3f}/-{l50-l16:.3f})"
    )
Before reporting these FWHM values, subtract the instrumental resolution in quadrature.
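A minimal sketch of that correction, FWHM_intr = sqrt(FWHM_obs² − FWHM_inst²), follows. It assumes a Gaussian instrumental profile. It also assumes the SDSS wdisp column gives the instrumental dispersion σ in units of the 10^-4 log10(λ) pixel, which is our reading of the SDSS data model. One such pixel corresponds to c · ln(10) · 10^-4 ≈ 69 km/s. The helper names are hypothetical.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s
# Velocity width of one SDSS pixel (1e-4 in log10 lambda): c * ln(10) * 1e-4 ≈ 69 km/s
PIX_KMS = C_KMS * np.log(10.0) * 1e-4


def instrumental_fwhm_kms(wdisp_pix):
    """Instrumental FWHM (km/s) from a Gaussian dispersion given in SDSS pixels."""
    return 2.0 * np.sqrt(2.0 * np.log(2.0)) * np.asarray(wdisp_pix, dtype=float) * PIX_KMS


def deconvolve_fwhm(fwhm_obs, fwhm_inst):
    """Subtract instrumental broadening in quadrature; NaN where the line is unresolved."""
    fwhm_obs = np.asarray(fwhm_obs, dtype=float)
    diff = fwhm_obs**2 - np.asarray(fwhm_inst, dtype=float) ** 2
    return np.where(diff > 0, np.sqrt(np.clip(diff, 0.0, None)), np.nan)
```

For Hβ, for example, the instrumental width at the observed (vacuum) wavelength can be estimated as wdisp_hb = np.interp(4862.68 * (1 + q.z), 10**tb['loglam'], tb['wdisp']). Then deconvolve_fwhm(fwhm_samp, instrumental_fwhm_kms(wdisp_hb)) corrects every posterior draw at once. Velocity widths are the same in the observed and rest frames, so no extra (1 + z) factor is needed.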
6. Inspect posterior samples (recommended with optax+nuts)
If you used fit_method='nuts' or 'optax+nuts', posterior samples are stored in:
q.numpyro_samples
[ ]:
if hasattr(q, 'numpyro_samples') and q.numpyro_samples is not None:
    keys = sorted(q.numpyro_samples.keys())
    print('Num posterior params:', len(keys))
    print('First 20 keys:', keys[:20])
else:
    print('No NumPyro samples available (likely fit_method=optax).')
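The hypothetical helper below gives a quick numerical summary of the scalar parameters: the median and 16th/84th-percentile errors, in the same convention as section 5. It assumes each entry of q.numpyro_samples is an array whose first axis indexes posterior draws. Vector-valued sites are skipped.

```python
import numpy as np


def summarize_samples(samples, max_keys=20):
    """Return {key: (median, +err, -err)} for the first max_keys scalar sites."""
    rows = {}
    for key in sorted(samples)[:max_keys]:
        arr = np.asarray(samples[key], dtype=float)
        if arr.ndim != 1:
            continue  # skip vector-valued sites in this quick look
        p16, p50, p84 = np.percentile(arr, [16, 50, 84])
        rows[key] = (float(p50), float(p84 - p50), float(p50 - p16))
    return rows
```

For example, loop over summarize_samples(q.numpyro_samples).items() and print each key with its median and asymmetric errors.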
7. Quick component diagnostics
[ ]:
print('max data :', np.nanmax(q.flux))
print('max total model :', np.nanmax(q.model_total))
print('max PL :', np.nanmax(q.f_pl_model))
print('max host :', np.nanmax(q.host))
print('max FeII :', np.nanmax(q.f_fe_mgii_model + q.f_fe_balmer_model))
print('max Balmer cont :', np.nanmax(q.f_bc_model))
print('max lines :', np.nanmax(q.f_line_model))