Fit the Selsing Quasar Composite with JAXQSOFit
This notebook downloads the Selsing et al. composite spectrum and fits it with jaxqsofit.QSOFit.
Current default: fit_method="optax" runs staged SVI/Optax MAP optimization (continuum warm start, then full model). The fit in this notebook instead uses fit_method="optax+nuts", which takes that staged MAP point as the starting point for NUTS sampling.
The composite is already rest-frame, so the prior config below pins pl_pivot=3000.0 and disables the optional expanded high-ionization/narrow-line sets for a compact demo fit.
[ ]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from urllib.request import urlretrieve
from jaxqsofit import QSOFit, build_default_prior_config
[ ]:
# Download the Selsing composite table once and cache it under ./data.
url = "https://raw.githubusercontent.com/jselsing/QuasarComposite/master/Selsing2015.dat"
data_dir = Path("./data")
data_dir.mkdir(parents=True, exist_ok=True)
fn = data_dir / "Selsing2015.dat"
if not fn.exists():
    # Fetch into a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated file that a later run
    # would mistake for a valid cache.
    tmp_fn = fn.with_name(fn.name + ".part")
    try:
        urlretrieve(url, tmp_fn)
        tmp_fn.replace(fn)
    finally:
        # No-op after a successful rename; cleans up after a failed download.
        tmp_fn.unlink(missing_ok=True)

# ndmin=2 keeps the column indexing below valid even for a single-row table.
arr = np.loadtxt(fn, ndmin=2)
lam = arr[:, 0].astype(float)
flux = arr[:, 1].astype(float)

# Use the third column as the uncertainty when the table provides one;
# otherwise fall back to a constant floor of 0.1% of the median |flux|.
if arr.shape[1] >= 3:
    err = arr[:, 2].astype(float)
else:
    err = np.full_like(flux, 1e-3 * np.nanmedian(np.abs(flux)))

# Keep only pixels with finite values, positive wavelength and positive error.
m = np.isfinite(lam) & np.isfinite(flux) & np.isfinite(err) & (lam > 0) & (err > 0)
lam, flux, err = lam[m], flux[m], err[m]
print(f"Loaded {lam.size} pixels from {fn}")
[ ]:
# Rest-frame input: the composite wavelengths need no redshift correction.
z = 0.0

# With z = 0 the (1 + z) factor is 1, so this equals `flux`.
flux_for_priors = (1.0 + z) * flux

q = QSOFit(lam=lam, flux=flux, err=err, z=z)

# Prior settings for the compact demo: both optional line sets are switched
# off and the power-law pivot is fixed at 3000.
prior_options = {
    "include_elg_narrow_lines": False,
    "include_high_ionization_lines": False,
    "pl_pivot": 3000.0,
}
prior_config = build_default_prior_config(flux=flux_for_priors, **prior_options)
[ ]:
# All fit settings collected in one place, in the order they are passed.
fit_kwargs = {
    # Model components and preprocessing switches.
    "deredden": False,
    "fit_method": "optax+nuts",
    "fit_lines": True,
    "decompose_host": False,
    "fit_pl": True,
    "fit_fe": True,
    "fit_bc": False,
    "fit_poly": True,
    "mask_lya_forest": False,
    # Priors and external template file.
    # NOTE(review): decompose_host is False; confirm whether the SSP file
    # path is still read in that case.
    "prior_config": prior_config,
    "dsps_ssp_fn": "../tempdata.h5",
    # Short NUTS run: 20 warmup steps, 30 samples, a single chain.
    "nuts_warmup": 20,
    "nuts_samples": 30,
    "nuts_chains": 1,
    "nuts_target_accept": 0.9,
    # Optax optimizer settings.
    "optax_steps": 1200,
    "optax_lr": 1e-2,
    # Output and logging.
    "save_result": True,
    "plot_fig": True,
    "save_fig": True,
    "verbose": True,
}
q.fit(**fit_kwargs)
[ ]:
# Show the spectrum with the fitted power-law (PL) continuum removed and the
# remaining continuum level shifted to zero.
wave = np.asarray(q.wave)
fobs = np.asarray(q.flux)

# The attribute holding the PL component is probed by name: take the first
# existing attribute whose array shape matches the observed flux.
pl_attr_names = ["model_pl", "f_pl_model", "pl_model", "f_pl", "model_powerlaw"]
pl_candidates = (
    np.asarray(getattr(q, attr)) for attr in pl_attr_names if hasattr(q, attr)
)
pl = next((arr for arr in pl_candidates if arr.shape == fobs.shape), None)
if pl is None:
    raise AttributeError(
        "Could not find a PL model on `q`. "
        "Check available attributes and replace with the correct PL component."
    )

# Subtract the PL, then remove the median level measured in the 1700-4000
# window (or over the whole spectrum when that window holds no finite pixel).
f_pl_sub = fobs - pl
cont_mask = np.isfinite(f_pl_sub) & (wave > 1700) & (wave < 4000)
if cont_mask.any():
    offset = np.median(f_pl_sub[cont_mask])
else:
    offset = np.median(f_pl_sub)
f_flat = f_pl_sub - offset

plt.figure(figsize=(10, 4))
plt.plot(wave, f_flat, color="k", linewidth=1.0, label="PL-subtracted spectrum")
plt.axhline(
    0.0,
    color="tab:blue",
    linestyle="--",
    linewidth=1.0,
    alpha=0.9,
    label="flat baseline",
)
plt.xlim(np.min(wave), np.max(wave))
plt.xlabel(r"Rest Wavelength ($\AA$)")
plt.ylabel(r"$f_\lambda - f_{\lambda,\mathrm{PL}}$")
plt.legend(frameon=True)
plt.tight_layout()
[ ]:
# Goodness-of-fit statistics from normalized residuals.
#
# Despite its name, `mask_lyman_forest` is True for the pixels that are KEPT:
# those redward of Ly-alpha (1215.67), i.e. outside the Ly-alpha forest.
mask_lyman_forest = np.asarray(q.wave) > 1215.67
model_total = np.asarray(q.model_total)
resid = np.asarray(q.flux) - model_total
sigma = np.asarray(q.err)

# Fold the fitted jitter terms into the per-pixel uncertainty when posterior
# samples are available; a missing jitter key contributes zero.
if q.numpyro_samples is not None:
    samples = q.numpyro_samples
    frac_j = float(np.median(np.asarray(samples.get("frac_jitter", 0.0))))
    add_j = float(np.median(np.asarray(samples.get("add_jitter", 0.0))))
    sigma = np.sqrt(sigma**2 + (frac_j * np.abs(model_total)) ** 2 + add_j**2)

# Dedicated names are used here so the notebook-level `z` (redshift) and `m`
# (load mask) are not silently rebound to unrelated arrays.
valid = np.isfinite(resid) & np.isfinite(sigma) & (sigma > 0) & mask_lyman_forest
norm_resid = resid[valid] / sigma[valid]

chi2 = float(np.sum(norm_resid**2))
chi2_per_pixel = float(np.mean(norm_resid**2))  # more stable than reduced chi2 here
wrms = float(np.sqrt(np.mean(norm_resid**2)))
print("chi2:", chi2)
print("chi2_per_pixel:", chi2_per_pixel)
print("wrms (normalized residual std):", wrms)
[ ]:
# Print selected summary attributes of the fit, one per line.
for attr in ("SN_ratio_conti", "frac_host_4200", "frac_host_5100", "frac_host_2500"):
    print(f"{attr}:", getattr(q, attr))
[ ]:
# Show sampler diagnostics for the single parameter 'log_Fe_op_over_uv'.
# NOTE(review): presumably this needs the NUTS samples from the
# fit_method="optax+nuts" run above; confirm against the jaxqsofit docs.
q.plot_mcmc_diagnostics(param_names='log_Fe_op_over_uv')
[ ]:
# Overlay the input composite and, when the fit produced one, the total model.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(lam, flux, color="k", linewidth=1.0, alpha=0.8, label="Selsing composite")
if hasattr(q, "model_total"):
    ax.plot(
        q.wave,
        q.model_total,
        color="tab:blue",
        linewidth=1.2,
        label="best-fit model",
    )
ax.set_xlabel(r"Rest Wavelength ($\AA$)")
ax.set_ylabel(r"$f_\lambda$")
ax.legend(frameon=True)
fig.tight_layout()