Fit the Selsing Quasar Composite with JAXQSOFit

Fit the Selsing Quasar Composite with JAXQSOFit

This notebook downloads the Selsing et al. composite spectrum and fits it with jaxqsofit.QSOFit.

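Fit methods: the default, fit_method="optax", runs a staged SVI/Optax MAP optimization (continuum warm start, then the full model). fit_method="optax+nuts" uses that staged MAP point to initialize NUTS sampling. This notebook uses "optax+nuts" so that posterior samples are available for the diagnostics below.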

The composite is already in the rest frame, so the fit uses z = 0. For a compact demo fit, the prior config below pins pl_pivot=3000.0 and disables the optional expanded high-ionization and narrow-line sets.

[ ]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from urllib.request import urlretrieve

from jaxqsofit import QSOFit, build_default_prior_config

[ ]:
# Download the Selsing et al. composite spectrum once and cache it under ./data.
url = "https://raw.githubusercontent.com/jselsing/QuasarComposite/master/Selsing2015.dat"
data_dir = Path("./data")
data_dir.mkdir(parents=True, exist_ok=True)
fn = data_dir / "Selsing2015.dat"
if not fn.exists():
    # Download to a temporary sibling file and rename only after success. An
    # interrupted download must not leave a truncated file behind, because the
    # exists() check above would then treat it as a valid cache on every rerun.
    tmp_fn = fn.with_name(fn.name + ".part")
    try:
        urlretrieve(url, tmp_fn)
        tmp_fn.replace(fn)
    finally:
        if tmp_fn.exists():
            tmp_fn.unlink()

# ndmin=2 keeps the column indexing below valid even for a single-row table.
arr = np.loadtxt(fn, ndmin=2)
lam = arr[:, 0].astype(float)
flux = arr[:, 1].astype(float)

# Composite table usually has a third column uncertainty; if not, use a small
# constant floor of 0.1% of the median |flux|.
if arr.shape[1] >= 3:
    err = arr[:, 2].astype(float)
else:
    err = np.full_like(flux, 1e-3 * np.nanmedian(np.abs(flux)))

# Keep only pixels with a finite positive wavelength, finite flux, and a finite,
# strictly positive error.
good_pix = (
    np.isfinite(lam)
    & np.isfinite(flux)
    & np.isfinite(err)
    & (lam > 0)
    & (err > 0)
)
lam, flux, err = lam[good_pix], flux[good_pix], err[good_pix]
if lam.size == 0:
    raise ValueError(
        f"No usable pixels in {fn}; check the file contents and the error column."
    )

print(f"Loaded {lam.size} pixels from {fn}")

[ ]:
# The Selsing composite is tabulated on a rest-frame wavelength grid, so the
# fitter is given a redshift of zero.
z = 0.0

# Flux handed to the prior builder is scaled by (1 + z); with z = 0 this is
# numerically the same array as `flux`.
flux_for_priors = flux * (1.0 + z)

q = QSOFit(lam=lam, flux=flux, err=err, z=z)

# Compact demo configuration: fixed power-law pivot and the optional extended
# narrow-line / high-ionization line sets switched off.
prior_config = build_default_prior_config(
    flux=flux_for_priors,
    include_elg_narrow_lines=False,
    include_high_ionization_lines=False,
    pl_pivot=3000.0,
)

[ ]:
# Staged Optax MAP fit followed by NUTS sampling initialized at the MAP point.
# NOTE(review): 20 warmup / 30 samples / 1 chain is a quick demo setting and is
# likely too short for well-converged posteriors; increase it for real analyses.
fit_options = dict(
    # Input handling
    deredden=False,
    mask_lya_forest=False,
    # Model components
    fit_lines=True,
    decompose_host=False,
    fit_pl=True,
    fit_fe=True,
    fit_bc=False,
    fit_poly=True,
    prior_config=prior_config,
    # Presumably only consulted for host decomposition (disabled above) — confirm.
    dsps_ssp_fn="../tempdata.h5",
    # Inference method and sampler / optimizer settings
    fit_method="optax+nuts",
    nuts_warmup=20,
    nuts_samples=30,
    nuts_chains=1,
    nuts_target_accept=0.9,
    optax_steps=1200,
    optax_lr=1e-2,
    # Output
    save_result=True,
    plot_fig=True,
    save_fig=True,
    verbose=True,
)
q.fit(**fit_options)

[ ]:
# Plot the power-law (PL) subtracted spectrum, shifted so the continuum sits near zero.
wave = np.asarray(q.wave)
fobs = np.asarray(q.flux)

# The attribute holding the fitted PL component is not fixed here. Try common
# names and take the first one whose shape matches the observed flux.
pl = None
for name in ("model_pl", "f_pl_model", "pl_model", "f_pl", "model_powerlaw"):
    cand = getattr(q, name, None)
    if cand is None:
        continue
    cand = np.asarray(cand)
    if cand.shape == fobs.shape:
        pl = cand
        break

if pl is None:
    raise AttributeError(
        "Could not find a PL model on `q`. "
        "Check available attributes and replace with the correct PL component."
    )

# Subtract the PL, then remove the residual median offset measured in the
# 1700-4000 A window so the continuum is visually flat around zero.
f_pl_sub = fobs - pl
cont_mask = (wave > 1700) & (wave < 4000) & np.isfinite(f_pl_sub)
# nanmedian in the fallback: an unfiltered plain median would return NaN (and
# blank the plot) if any pixel of f_pl_sub is non-finite.
offset = (
    np.median(f_pl_sub[cont_mask])
    if np.any(cont_mask)
    else np.nanmedian(f_pl_sub)
)
f_flat = f_pl_sub - offset

plt.figure(figsize=(10, 4))
plt.plot(wave, f_flat, color="k", lw=1.0, label="PL-subtracted spectrum")
plt.axhline(0.0, color="tab:blue", ls="--", lw=1.0, alpha=0.9, label="flat baseline")
plt.xlim(np.nanmin(wave), np.nanmax(wave))
plt.xlabel(r"Rest Wavelength ($\AA$)")
plt.ylabel(r"$f_\lambda - f_{\lambda,\mathrm{PL}}$")
plt.legend(frameon=True)
plt.tight_layout()
[ ]:
# Goodness-of-fit summary, restricted to pixels redward of Ly-alpha (1215.67 A).
wave_arr = np.asarray(q.wave)
model_arr = np.asarray(q.model_total)
mask_lyman_forest = wave_arr > 1215.67  # True = pixel kept (redward of Ly-alpha)

resid = np.asarray(q.flux) - model_arr
sigma = np.asarray(q.err)

# Include fitted jitter (recommended). Posterior medians of the jitter terms are
# used when present in the samples; a missing key contributes zero.
samples = getattr(q, "numpyro_samples", None)
if samples is not None:
    frac_j = float(np.median(np.asarray(samples.get("frac_jitter", 0.0))))
    add_j = float(np.median(np.asarray(samples.get("add_jitter", 0.0))))
    sigma = np.sqrt(sigma**2 + (frac_j * np.abs(model_arr))**2 + add_j**2)

good = np.isfinite(resid) & np.isfinite(sigma) & (sigma > 0) & mask_lyman_forest
if not np.any(good):
    raise ValueError("No valid pixels redward of Ly-alpha to compute fit statistics.")

# Named norm_resid (not `z`) so the notebook's redshift variable is not overwritten.
norm_resid = resid[good] / sigma[good]

chi2 = float(np.sum(norm_resid**2))
chi2_per_pixel = float(np.mean(norm_resid**2))  # more stable than reduced chi2 here
# RMS of the normalized residuals (= sqrt(chi2_per_pixel)); no mean is subtracted.
wrms = float(np.sqrt(np.mean(norm_resid**2)))

print("chi2:", chi2)
print("chi2_per_pixel:", chi2_per_pixel)
print("wrms (normalized residual std):", wrms)

[ ]:
# Print continuum S/N and the host-galaxy fraction estimates stored on the fit.
# NOTE(review): host decomposition is disabled in the fit call above; confirm
# what values these attributes carry in that case.
for attr_name in ("SN_ratio_conti", "frac_host_4200", "frac_host_5100", "frac_host_2500"):
    print(f"{attr_name}:", getattr(q, attr_name))

[ ]:
# MCMC diagnostic plots for one sampled parameter (the name suggests a log
# optical-to-UV Fe ratio — confirm against the jaxqsofit docs).
diag_param = "log_Fe_op_over_uv"
q.plot_mcmc_diagnostics(param_names=diag_param)
[ ]:
# Overlay the (masked) input composite with the best-fit total model, when the
# fit object exposes one.
plt.figure(figsize=(10, 4))
plt.plot(lam, flux, color="k", lw=1.0, alpha=0.8, label="Selsing composite")

has_model = hasattr(q, "model_total")
if has_model:
    plt.plot(q.wave, q.model_total, color="tab:blue", lw=1.2, label="best-fit model")

plt.xlabel(r"Rest Wavelength ($\AA$)")
plt.ylabel(r"$f_\lambda$")
plt.legend(frameon=True)
plt.tight_layout()