JAXQSOFit Tutorial: SDSS Example + PSF Photometric Recalibration#

This notebook mirrors the main SDSS tutorial, but also:

  • queries SDSS PSF magnitudes for the same object

  • fits the spectrum without PSF photometry

  • fits the spectrum again with use_psf_phot=True

  • compares the fiber-only and PSF-constrained results

This notebook demonstrates a joint spectral-plus-photometric calibration scheme in JAXQSOFit that connects the SDSS fiber spectrum to broadband PSF photometry. The spectrum is fit on its native fiber scale, while a second likelihood compares synthetic PSF magnitudes from a PSF-space version of the model to the observed SDSS PSF magnitudes. In that PSF-space model, a gray offset term delta_m_psf captures overall spectrophotometric or epoch-to-epoch normalization differences, while a host-suppression term eta_psf reduces the contribution of spatially extended components in the PSF aperture.

The key idea is that the AGN continuum, host galaxy, broad-line emission, and narrow-line emission do not contribute in the same way to fiber and PSF measurements. The AGN continuum and broad lines are treated as compact, unresolved components, while the stellar host and narrow lines are treated as more extended. Because the AGN and host also have different spectral shapes, changing the relative contribution of the extended components changes the broadband colors in a characteristic way.

By fitting the fiber spectrum and the PSF photometry simultaneously, the model can use those color differences to separate AGN and host light more cleanly, infer PSF-consistent components, and estimate more realistic host fractions in the photometric aperture.

We model the observed SDSS fiber spectrum and the SDSS PSF photometry jointly.

For the rest-frame spectral model, we write

\[f_{\mathrm{fiber}}(\lambda) = f_{\mathrm{AGN}}(\lambda) + f_{\mathrm{host}}(\lambda) + f_{\mathrm{lines,br}}(\lambda) + f_{\mathrm{lines,na}}(\lambda),\]

where the AGN term includes the power law, Fe II, and Balmer continuum components, and the emission-line model is split into broad and narrow components.

The fiber-spectrum likelihood is

\[F_i^{\mathrm{obs}} \sim \mathcal{N}\!\left( f_{\mathrm{fiber}}(\lambda_i), \, \sigma_{\mathrm{tot},i}^2 \right),\]

with total per-pixel uncertainty

\[\sigma_{\mathrm{tot},i}^2 = \sigma_i^2 + \left(f_{\mathrm{frac}}\,|f_{\mathrm{fiber}}(\lambda_i)|\right)^2 + \sigma_{\mathrm{add}}^2,\]

where \(\sigma_i\) is the measured spectral error, and \(f_{\mathrm{frac}}\) and \(\sigma_{\mathrm{add}}\) are fractional and additive jitter terms.
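
As an illustration of the noise model above, the following is a minimal NumPy sketch, not JAXQSOFit's internal implementation. The function names `total_sigma` and `normalized_residuals` and the toy numbers are hypothetical; `frac_jitter` and `add_jitter` stand in for \(f_{\mathrm{frac}}\) and \(\sigma_{\mathrm{add}}\).

```python
import numpy as np


def total_sigma(sigma, model, frac_jitter, add_jitter):
    """Measured error, fractional jitter, and additive jitter added in quadrature."""
    return np.sqrt(sigma**2 + (frac_jitter * np.abs(model))**2 + add_jitter**2)


def normalized_residuals(flux_obs, model, sigma, frac_jitter, add_jitter):
    """(data - model) / sigma_tot, the quantity the Gaussian likelihood penalizes."""
    return (flux_obs - model) / total_sigma(sigma, model, frac_jitter, add_jitter)


# Toy values (hypothetical, for illustration only)
model = np.array([10.0, 20.0, 5.0])
flux_obs = np.array([10.5, 19.0, 5.2])
sigma = np.array([0.5, 0.5, 0.5])

sig_tot = total_sigma(sigma, model, frac_jitter=0.02, add_jitter=0.1)
zres = normalized_residuals(flux_obs, model, sigma, frac_jitter=0.02, add_jitter=0.1)
print(sig_tot)
print(zres)
```

Because the fractional term scales with the model flux, brighter pixels receive a larger error inflation than faint ones, while the additive term sets a floor that is the same for every pixel.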

To compare with PSF photometry, we construct a PSF-space model

\[f_{\mathrm{PSF}}(\lambda) = s_{\mathrm{PSF}} \left[ f_{\mathrm{AGN}}(\lambda) + \eta_{\mathrm{PSF}}\,f_{\mathrm{host}}(\lambda) + f_{\mathrm{lines,br}}(\lambda) + \eta_{\mathrm{PSF}}\,f_{\mathrm{lines,na}}(\lambda) \right],\]

where

\[s_{\mathrm{PSF}} = 10^{-0.4\,\Delta m_{\mathrm{PSF}}},\]

\(\Delta m_{\mathrm{PSF}}\) is a gray magnitude offset between the fiber spectral scale and the PSF photometric scale, and \(\eta_{\mathrm{PSF}} \in [0,1]\) controls how much of the extended host-like components contribute in the PSF aperture. In this version of the model, broad lines are treated like the unresolved AGN, while narrow lines are treated like the extended host.
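
The PSF-space construction can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions, not the library's code: the function name `psf_space_model` and the power-law component shapes are hypothetical, chosen only so that the AGN is bluer than the host.

```python
import numpy as np


def psf_space_model(f_agn, f_host, f_lines_br, f_lines_na, delta_m_psf, eta_psf):
    """Scale the extended components by eta_psf, then apply the gray factor s_psf."""
    s_psf = 10.0 ** (-0.4 * delta_m_psf)
    return s_psf * (f_agn + eta_psf * f_host + f_lines_br + eta_psf * f_lines_na)


# Toy components on a coarse wavelength grid (hypothetical shapes)
wave = np.linspace(3000.0, 7000.0, 5)
f_agn = 10.0 * (wave / 5000.0) ** -1.5   # blue power-law continuum
f_host = 4.0 * (wave / 5000.0) ** 1.0    # redder host-like continuum
f_br = np.zeros_like(wave)               # broad lines omitted in this toy case
f_na = np.zeros_like(wave)               # narrow lines omitted in this toy case

f_fiber = f_agn + f_host + f_br + f_na
f_psf_same = psf_space_model(f_agn, f_host, f_br, f_na, delta_m_psf=0.0, eta_psf=1.0)
f_psf_nohost = psf_space_model(f_agn, f_host, f_br, f_na, delta_m_psf=0.0, eta_psf=0.0)
f_psf_faint = psf_space_model(f_agn, f_host, f_br, f_na, delta_m_psf=2.5, eta_psf=1.0)
```

With `delta_m_psf=0` and `eta_psf=1` the PSF-space model reduces to the fiber model, `eta_psf=0` leaves only the unresolved components, and a gray offset of 2.5 mag scales every component by a factor of 0.1 without changing the spectral shape.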

For each photometric band \(b\), we compute a synthetic AB magnitude from the PSF-space model,

\[m_b^{\mathrm{syn}} = -2.5\log_{10}\!\left(f_{\nu,b}\right)-48.60,\]

with

\[f_{\nu,b} = \frac{ \int f_{\lambda}^{\mathrm{PSF}}(\lambda)\,T_b(\lambda)\,\lambda\,d\lambda }{ \int T_b(\lambda)\,c/\lambda\,d\lambda },\]

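where \(T_b(\lambda)\) is the filter transmission curve and \(c\) is the speed of light. The factors of \(\lambda\) in the numerator and \(c/\lambda\) in the denominator give a photon-weighted band average, and the zero point of 48.60 assumes \(f_{\nu,b}\) is expressed in \(\mathrm{erg\,s^{-1}\,cm^{-2}\,Hz^{-1}}\).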
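
The synthetic magnitude integral can be checked numerically with a short sketch. The code below is an illustration under assumptions, not the routine JAXQSOFit uses: it replaces the real SDSS transmission curves with a hypothetical top-hat filter, uses a hand-written trapezoid rule, and expects wavelengths in Angstrom and \(f_\lambda\) in \(\mathrm{erg\,s^{-1}\,cm^{-2}\,\AA^{-1}}\) (SDSS spectra are stored in units of \(10^{-17}\) of that, so they would need rescaling first).

```python
import numpy as np

C_ANGSTROM_PER_S = 2.99792458e18  # speed of light in Angstrom per second


def trapezoid(y, x):
    """Trapezoid-rule integral of y(x) on a possibly non-uniform grid."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def synthetic_ab_mag(wave, f_lambda, trans):
    """Photon-weighted band-averaged f_nu converted to an AB magnitude."""
    num = trapezoid(f_lambda * trans * wave, wave)
    den = trapezoid(trans * C_ANGSTROM_PER_S / wave, wave)
    f_nu = num / den
    return -2.5 * np.log10(f_nu) - 48.60


# Hypothetical top-hat filter between 5500 and 6500 Angstrom
wave = np.linspace(5000.0, 7000.0, 2001)
trans = ((wave > 5500.0) & (wave < 6500.0)).astype(float)

# Reference source with constant f_nu = 3631 Jy, converted to f_lambda
f_nu_ref = 3631.0e-23  # erg / s / cm^2 / Hz
f_lambda_ref = f_nu_ref * C_ANGSTROM_PER_S / wave**2

m_ref = synthetic_ab_mag(wave, f_lambda_ref, trans)
m_faint = synthetic_ab_mag(wave, 0.01 * f_lambda_ref, trans)
print(m_ref, m_faint)
```

A source with constant \(f_\nu = 3631\) Jy should come out at an AB magnitude very close to zero in any band, and a source 100 times fainter should be 5 magnitudes fainter, which makes this a convenient sanity check for unit handling.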

The PSF-photometry likelihood is then

\[m_b^{\mathrm{obs}} \sim \mathcal{N}\!\left( m_b^{\mathrm{syn}}, \, \sigma_{m,b}^2 + \sigma_{\mathrm{phot,extra}}^2 \right),\]

where \(\sigma_{m,b}\) is the catalog magnitude uncertainty and \(\sigma_{\mathrm{phot,extra}}\) is an additional photometric scatter term.

The full joint likelihood is

\[\mathcal{L} = \prod_i p\!\left(F_i^{\mathrm{obs}} \mid f_{\mathrm{fiber}}(\lambda_i)\right) \; \prod_b p\!\left(m_b^{\mathrm{obs}} \mid m_b^{\mathrm{syn}}\right).\]
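
In practice the product of likelihoods is evaluated as a sum of log-likelihoods. The sketch below shows that structure with plain NumPy. It is a hedged illustration, not the NumPyro model used by the package, and all function names, argument names, and toy values are hypothetical.

```python
import numpy as np


def gaussian_logpdf(x, mean, sigma):
    """Elementwise log-density of a normal distribution."""
    return -0.5 * ((x - mean) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)


def spec_loglike(flux_obs, f_fiber, sigma_tot):
    """Fiber-spectrum term: one Gaussian factor per pixel."""
    return float(np.sum(gaussian_logpdf(flux_obs, f_fiber, sigma_tot)))


def phot_loglike(mag_obs, mag_syn, mag_err, sigma_phot_extra):
    """PSF-photometry term: one Gaussian factor per band, with extra scatter in quadrature."""
    sigma_mag = np.sqrt(mag_err**2 + sigma_phot_extra**2)
    return float(np.sum(gaussian_logpdf(mag_obs, mag_syn, sigma_mag)))


# Toy spectrum (3 pixels) and photometry (5 bands); values are illustrative only
flux_obs = np.array([1.0, 2.0, 3.0])
f_fiber = np.array([1.1, 1.9, 3.0])
sigma_tot = np.array([0.2, 0.2, 0.2])
mag_obs = np.array([19.00, 18.80, 18.70, 18.60, 18.50])
mag_syn = np.array([19.05, 18.75, 18.70, 18.62, 18.45])
mag_err = np.full(5, 0.02)

ll_spec = spec_loglike(flux_obs, f_fiber, sigma_tot)
ll_phot = phot_loglike(mag_obs, mag_syn, mag_err, sigma_phot_extra=0.03)
ll_joint = ll_spec + ll_phot  # log of the product of the two likelihoods
print(ll_spec, ll_phot, ll_joint)
```

Because a spectrum has thousands of pixels and the photometry has at most five bands, the relative pull of the photometric term depends strongly on the per-band uncertainties, including the extra scatter term.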

A useful derived quantity is the PSF-space host fraction at wavelength \(\lambda\),

\[f_{\mathrm{host,PSF}}(\lambda) = \frac{f_{\mathrm{host}}^{\mathrm{PSF}}(\lambda)} {f_{\mathrm{AGN}}^{\mathrm{PSF}}(\lambda)+f_{\mathrm{host}}^{\mathrm{PSF}}(\lambda)} = \frac{\eta_{\mathrm{PSF}}\,f_{\mathrm{host}}(\lambda)} {f_{\mathrm{AGN}}(\lambda)+\eta_{\mathrm{PSF}}\,f_{\mathrm{host}}(\lambda)},\]

since the overall gray scale factor \(s_{\mathrm{PSF}}\) cancels in the ratio.
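
A short numerical check of this cancellation, written as a sketch with hypothetical names and toy flux values at a single wavelength:

```python
import numpy as np


def host_fraction_psf(f_agn, f_host, eta_psf):
    """PSF-space host fraction; the gray factor s_psf is omitted because it cancels."""
    host_psf = eta_psf * np.asarray(f_host, dtype=float)
    return host_psf / (np.asarray(f_agn, dtype=float) + host_psf)


# Toy fluxes at one wavelength (illustrative values only)
f_agn, f_host = 6.0, 4.0

frac_fiber = host_fraction_psf(f_agn, f_host, eta_psf=1.0)  # fiber-space value, 4 / 10
frac_psf = host_fraction_psf(f_agn, f_host, eta_psf=0.5)    # 2 / 8

# Same ratio with an explicit gray factor applied to every component
s_psf = 10.0 ** (-0.4 * 0.3)
frac_scaled = (s_psf * 0.5 * f_host) / (s_psf * f_agn + s_psf * 0.5 * f_host)
print(float(frac_fiber), float(frac_psf), float(frac_scaled))
```

Here the fiber-space host fraction of 0.4 drops to 0.25 when only half of the host light falls in the PSF aperture, and the value is unchanged when the gray factor is included explicitly.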

Likewise, the PSF-space narrow-line fraction is reduced by the same \(\eta_{\mathrm{PSF}}\) factor, while the broad-line contribution follows the unresolved AGN scaling.

Current default: fit_method="optax" runs staged SVI/Optax MAP optimization (continuum warm start, then full model). fit_method="optax+nuts" uses that staged MAP point to initialize NUTS.

The psf_bands argument is passed directly to QSOFit.fit with the SDSS PSF magnitudes and errors so the synthetic-photometry likelihood uses the intended band order.

[ ]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from astroquery.sdss import SDSS
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.cosmology import FlatLambdaCDM
from speclite import filters as speclite_filters

from jaxqsofit import QSOFit, build_default_prior_config

1. Download one SDSS spectrum and query PSF magnitudes#

[ ]:
# Alternative targets (uncomment one and comment out the active target to use it):
# coord = SkyCoord(184.0307, -2.2383, unit='deg')  # High host example
# coord = SkyCoord.from_name('SDSS J212805.25-005145.7')  # Low host example
# coord = SkyCoord.from_name('SDSS J231250.89+001719.0')
# z0.623_002350.17+010809.6

# Active target: the only coordinate assignment that takes effect
coord = SkyCoord.from_name('SDSS J010514.10-010222.9')

psf_bands_all = ['u', 'g', 'r', 'i', 'z']

photo_fields = [
    'ra', 'dec',
    *[f'psfMag_{b}' for b in psf_bands_all],
    *[f'psfMagErr_{b}' for b in psf_bands_all],
]
spec_fields = ['plate', 'mjd', 'fiberID', 'run2d']

xid = SDSS.query_region(
    coord,
    spectro=True,
    radius=5 * u.arcsec,
    photoobj_fields=photo_fields,
    specobj_fields=spec_fields,
)
if xid is None or len(xid) == 0:
    raise RuntimeError('No SDSS spectrum / photometry match found near the target coordinates.')

# The selection below relies on the MJD column, so fail early if it is missing
if 'mjd' not in xid.colnames:
    raise KeyError(f"'mjd' not found in SDSS query result columns: {xid.colnames}")

print(xid)

# Sort by MJD so the most recent spectroscopic visit ends up in the last row
xid.sort('mjd')      # ascending
print()

print('Matches by MJD (oldest -> newest):')
print(xid[['plate', 'mjd', 'fiberID', 'run2d']])

# Keep only the most recent spectroscopic visit; a single match is sufficient
xid = xid[::-1][:1]

print(
    f"Using latest match: plate={int(xid['plate'][0])}, "
    f"mjd={int(xid['mjd'][0])}, fiberID={int(xid['fiberID'][0])}"
)

sp = SDSS.get_spectra(matches=xid[:1])[0]

tb = sp[1].data
lam = np.asarray(10 ** tb['loglam'], dtype=float)
flux = np.asarray(tb['flux'], dtype=float)
ivar = np.asarray(tb['ivar'], dtype=float)

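# Convert inverse variance to 1-sigma errors. Pixels with missing or non-positive
# ivar keep the 1e-6 placeholder; note that such a small error gives those pixels
# a very high weight in a Gaussian likelihood unless they are masked downstream.
err = np.full_like(flux, 1e-6)
m = np.isfinite(ivar) & (ivar > 0)
err[m] = 1.0 / np.sqrt(ivar[m])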

z = float(sp[2].data['z'][0])
ra = float(coord.ra.deg)
dec = float(coord.dec.deg)

plate = int(sp[0].header.get('plateid', 0))
mjd = int(sp[0].header.get('mjd', 0))
fiber = int(sp[0].header.get('fiberid', 0))
sdss_filename = f"{plate:04d}-{mjd}-{fiber:04d}"

row = xid[0]


def get_row_value(row, name):
    for candidate in (name, name.lower(), name.upper()):
        if candidate in row.colnames:
            return row[candidate]
    raise KeyError(f'Column {name!r} not found in SDSS query result. Available columns: {row.colnames}')


psf_mags_all = np.array([float(get_row_value(row, f'psfMag_{b}')) for b in psf_bands_all], dtype=float)
psf_mag_errs_all = np.array([float(get_row_value(row, f'psfMagErr_{b}')) for b in psf_bands_all], dtype=float)

phot_table = pd.DataFrame({
    'band': psf_bands_all,
    'psf_mag': psf_mags_all,
    'psf_mag_err': psf_mag_errs_all,
})

print(f'Nspec pixels: {lam.size}')
print(f'z = {z:.5f}')
print(f'SDSS filename key: {sdss_filename}')
phot_table

[ ]:
xid
[ ]:
SDSS.get_spectra(matches=xid)

2. Build a default prior config#

[ ]:
def make_prior_config(flux, z):
    flux_for_priors = flux * (1.0 + z)
    prior_config = build_default_prior_config(flux_for_priors)

    # Example override: slightly tighter PL slope and Fe normalization priors
    prior_config['PL_slope'] = {'loc': -1.5, 'scale': 0.3, 'low': -3.5, 'high': 0.5}
    prior_config['log_Fe_uv_norm'] = {
        'loc': np.log(max(1e-3 * np.median(np.abs(flux_for_priors)), 1e-10)),
        'scale': 0.4,
    }
    prior_config['log_Fe_op_over_uv'] = {'loc': 0.0, 'scale': 0.4}

    # Robust line scale multipliers (optional)
    prior_config['line_dmu_scale_mult'] = 0.25
    prior_config['line_sig_scale_mult'] = 0.25
    prior_config['line_amp_scale_mult'] = 0.20
    prior_config['edge_rbf_frac_max'] = 0.20

    return prior_config


prior_config = make_prior_config(flux, z)

3. Run a baseline fit without PSF recalibration#

[ ]:
q_nocal = QSOFit(
    lam=lam,
    flux=flux,
    err=err,
    z=z,
    ra=ra,
    dec=dec,
    filename=sdss_filename + '_fiber_only',
    output_path='.',
)

q_nocal.fit(
    deredden=True,
    fit_method='optax+nuts',
    fit_lines=True,
    decompose_host=True,
    fit_fe=True,
    fit_bc=False,
    fit_poly=True,
    fsps_age_grid=(0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0),
    fsps_logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
    prior_config=make_prior_config(flux, z),
    dsps_ssp_fn='../tempdata.h5',
    optax_steps=600,
    optax_lr=1e-2,
    nuts_warmup=50,
    nuts_samples=50,
    nuts_chains=1,
    plot_fig=True,
    save_fig=False,
    save_result=False,
    use_psf_phot=False,
)

4. Run a fit with SDSS PSF photometry#

[ ]:
q_psf = QSOFit(
    lam=lam,
    flux=flux,
    err=err,
    z=z,
    ra=ra,
    dec=dec,
    filename=sdss_filename + '_psf_recal',
    output_path='.',
)

q_psf.fit(
    deredden=True,
    fit_method='optax+nuts',
    fit_lines=True,
    decompose_host=True,
    fit_fe=True,
    fit_bc=False,
    fit_poly=True,
    fsps_age_grid=(0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0),
    fsps_logzsol_grid=(-1.0, -0.5, 0.0, 0.2),
    prior_config=make_prior_config(flux, z),
    dsps_ssp_fn='../tempdata.h5',
    optax_steps=600,
    optax_lr=1e-2,
    nuts_warmup=50,
    nuts_samples=50,
    nuts_chains=1,
    plot_fig=True,
    save_fig=False,
    save_result=False,
    psf_mags=psf_mags_all,
    psf_mag_errs=psf_mag_errs_all,
    psf_bands=psf_bands_all,
    use_psf_phot=True,
    kwargs_plot={'plot_residual': True, 'plot_psf_space': True},
)

print('Bands kept by the PSF photometry likelihood:', getattr(q_psf, 'psf_bands', None))

The next cell plots the fiber-space fit, for comparison with the PSF-space plot produced by the fit above.

[ ]:
q_psf.plot_fig(plot_residual=False, plot_psf_space=False)

5. Compare fit-quality and host-decomposition summaries#

[ ]:
def fit_stats(q):
    resid = np.asarray(q.flux) - np.asarray(q.model_total)
    sigma = np.asarray(q.err)

    if getattr(q, 'numpyro_samples', None) is not None:
        s = q.numpyro_samples
        frac_j = float(np.median(np.asarray(s.get('frac_jitter', 0.0))))
        add_j = float(np.median(np.asarray(s.get('add_jitter', 0.0))))
        sigma = np.sqrt(sigma**2 + (frac_j * np.abs(np.asarray(q.model_total)))**2 + add_j**2)

    m = np.isfinite(resid) & np.isfinite(sigma) & (sigma > 0)
    zres = resid[m] / sigma[m]

    return {
        'chi2': float(np.sum(zres**2)),
        'chi2_per_pixel': float(np.mean(zres**2)),
        'wrms': float(np.sqrt(np.mean(zres**2))),
        'frac_host_2500': float(getattr(q, 'frac_host_2500', np.nan)),
        'frac_host_4200': float(getattr(q, 'frac_host_4200', np.nan)),
        'delta_m_psf': float(getattr(q, 'delta_m_psf', np.nan)),
        'eta_psf': float(getattr(q, 'eta_psf', np.nan)),
        'scale_psf': float(getattr(q, 'scale_psf', np.nan)),
    }


summary = pd.DataFrame({
    'fiber_only': fit_stats(q_nocal),
    'with_psf_phot': fit_stats(q_psf),
}).T

summary

6. Compare the fitted fiber-space models directly#

[ ]:
fig, ax = plt.subplots(figsize=(14, 6))
ax.plot(q_nocal.wave, q_nocal.flux, color='k', lw=1.0, alpha=0.5, label='data')
ax.plot(q_nocal.wave, q_nocal.model_total, color='tab:blue', lw=1.5, label='fiber-only model')
ax.plot(q_psf.wave, q_psf.model_total, color='tab:orange', lw=1.5, label='PSF-constrained fiber model')
ax.set_xlabel(r'Rest Wavelength ($\AA$)')
ax.set_ylabel(r'$f_\lambda$')
ax.legend(frameon=False)
plt.tight_layout()
plt.show()

7. Inspect PSF-specific posterior quantities#

[ ]:
for name in ['delta_m_psf', 'delta_m_psf_err', 'eta_psf', 'eta_psf_err', 'scale_psf']:
    print(f'{name:18s}:', getattr(q_psf, name, np.nan))
