Local SDSS FITS + BAL Tutorial

This notebook shows how to:

  • load a local SDSS spectrum from a FITS file you provide,

  • run jaxqsofit on it,

  • enable the built-in BAL trough models with fit_bal=True.

Important limitation:

  • the built-in BAL option uses simple additive negative-Gaussian troughs for common BAL lines,

  • so it is useful as a pragmatic trough model, but it is not a radiative-transfer BAL model.

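Current default: fit_method="optax" runs staged SVI/Optax MAP optimization (continuum warm start, then full model). fit_method="optax+nuts" uses that staged MAP point to initialize NUTS. This tutorial uses "optax+nuts" with very short NUTS runs (nuts_warmup=10, nuts_samples=10) so that it finishes quickly; increase these before trusting the posterior samples.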

[ ]:
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from scipy.ndimage import gaussian_filter1d

from jaxqsofit import QSOFit, build_default_prior_config

[ ]:
# The built-in BAL models are enabled later via fit_bal=True.

1. Point to your local SDSS FITS file

Replace local_fits_path with the spectrum you want to fit.

[ ]:
# Rest wavelength [A] of the C IV line; marked on the BAL-region plot in section 5.
bal_line_rest = 1549.0

# Spectrum to fit. Swap in the commented alternative (or your own file) as needed.
local_fits_path = Path("spec-7850-56956-0924.fits")
# local_fits_path = Path("spec-4216-55477-0778.fits")

# Fail early with a clear message instead of a FITS I/O error further down.
if not local_fits_path.exists():
    raise AssertionError(f"Missing FITS file: {local_fits_path}")

2. Read the SDSS spectrum from disk

This assumes a standard SDSS spectral FITS layout with flux/loglam/ivar in HDU 1 and redshift in HDU 2.

[ ]:
# Error assigned to pixels without a usable inverse variance. It must be LARGE, not tiny:
# assuming the fitter weights pixels by 1/err**2 (a Gaussian likelihood), a large error
# gives these pixels negligible weight, while a tiny error would make them dominate the fit.
# Kept finite so no inf/NaN can leak into the likelihood.
BAD_PIXEL_ERR = 1e10

with fits.open(local_fits_path) as hdul:
    tb = hdul[1].data
    hdr0 = hdul[0].header
    hdr1 = hdul[1].header

    # The table stores log10(wavelength / Angstrom).
    lam = np.asarray(10 ** tb["loglam"], dtype=float)
    flux = np.asarray(tb["flux"], dtype=float)
    ivar = np.asarray(tb["ivar"], dtype=float)

    # Usable pixels need a positive, finite inverse variance and a finite flux.
    good = np.isfinite(ivar) & (ivar > 0) & np.isfinite(flux)
    err = np.full_like(flux, BAD_PIXEL_ERR)
    err[good] = 1.0 / np.sqrt(ivar[good])
    # Non-finite flux would poison the likelihood even with a huge error; zero it out
    # (those pixels already carry ~zero weight via BAD_PIXEL_ERR).
    flux = np.where(np.isfinite(flux), flux, 0.0)

    # Redshift: prefer a "Z"/"z" column in the HDU 2 table, then the primary/HDU 1 headers.
    z = np.nan
    if len(hdul) > 2 and hdul[2].data is not None and len(hdul[2].data) > 0:
        names = set(getattr(hdul[2].data, "names", None) or [])
        row = hdul[2].data[0]
        for z_key in ("Z", "z"):
            if z_key in names:
                z = float(row[z_key])
                break

    if not np.isfinite(z):
        z = float(hdr0.get("Z", hdr1.get("Z", np.nan)))

    ra = float(hdr0.get("PLUG_RA", hdr0.get("RA", np.nan)))
    dec = float(hdr0.get("PLUG_DEC", hdr0.get("DEC", np.nan)))

# Every later step (rest-frame wavelengths, priors, the fit) needs a valid redshift;
# stop here rather than propagating NaN.
assert np.isfinite(z), f"No finite redshift found in {local_fits_path}"

# Name tag passed to QSOFit; derived from the FITS file name (e.g. "spec-7850-56956-0924").
sdss_filename = local_fits_path.stem

print(f"Loaded: {local_fits_path}")
print(f"Nspec pixels: {lam.size}")
print(f"Down-weighted pixels (no valid ivar/flux): {int((~good).sum())}")
print(f"z = {z:.5f}")
print(f"filename tag = {sdss_filename}")
[ ]:
# Quick look at the raw input spectrum, shifted to the rest frame, before any fitting.
wave_rest = lam / (1.0 + z)

fig, ax = plt.subplots(figsize=(11, 4))
ax.plot(wave_rest, flux, color="k", lw=0.8)
ax.set_xlim(np.nanmin(wave_rest), np.nanmax(wave_rest))
ax.set_xlabel("Rest wavelength [A]")
ax.set_ylabel("Flux [1e-17 cgs]")
ax.set_title("Input local SDSS spectrum")
ax.grid(alpha=0.2)
plt.show()

3. Build priors

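build_default_prior_config builds a default prior configuration from the flux array (scaled here by 1 + z). The BAL troughs need no extra prior setup in this step: passing fit_bal=True to the fit adds the common BAL trough models automatically.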

[ ]:
# Default priors are derived from the input flux scaled by (1 + z).
# NOTE(review): presumably this mirrors a rest-frame flux scaling applied inside QSOFit — confirm.
flux_for_priors = (1.0 + z) * flux
prior_config = build_default_prior_config(flux_for_priors)

# Optional override to keep the continuum flexible but not too loose, e.g.:
# prior_config["PL_slope"] = {"dist": "Normal", "loc": -1.5, "scale": 0.35}

4. Fit the spectrum

Tips:

  • keep fit_lines=True so the normal emission-line model remains active,

  • set fit_bal=True to include the built-in BAL trough models,

  • use a rest-frame window that covers your BAL region.

After fitting, BAL components are reported through q.custom_components; the total and continuum models are available as q.model_total and q.f_conti_model.

[ ]:
# First fit: host-galaxy decomposition enabled, BAL troughs on, staged MAP + short NUTS run.
# NOTE(review): SDSS/BOSS spectra end near 10400 A observed, so an upper bound of 32000 A
# never clips the rest-frame range; 3200.0 may have been intended — confirm.
# NOTE(review): nuts_warmup/nuts_samples = 10 is a demo setting, far too short for reliable posteriors.
q = QSOFit(
    lam=lam, flux=flux, err=err,
    z=z, ra=ra, dec=dec,
    filename=sdss_filename,
    output_path=".",
)

first_fit_options = dict(
    # preprocessing and fitted range
    deredden=True,
    wave_range=(1250.0, 32000.0),
    # inference
    fit_method="optax+nuts",
    # model components
    fit_lines=True,
    decompose_host=True,
    fit_pl=True,
    fit_fe=True,
    fit_bc=False,
    fit_bal=True,
    fit_poly=True,
    fit_reddening=True,
    prior_config=prior_config,
    dsps_ssp_fn="../tempdata.h5",
    # optimizer / sampler settings
    optax_steps=600,
    optax_lr=1e-2,
    nuts_warmup=10,
    nuts_samples=10,
    nuts_chains=1,
    # outputs
    plot_fig=True,
    save_fig=False,
    save_result=False,
    show_plot=False,
)
q.fit(**first_fit_options)

[ ]:
# Overlay the fitted Fe template on the fitted spectrum in a rest-frame UV window, together
# with a copy of the template smoothed by an extra Gaussian velocity kernel.
FE_EXTRA_SIGMA_KMS = 3000.0  # sigma of the extra Gaussian broadening [km/s]
FE_PLOT_SCALE = 50.0  # display-only multiplier so the template shape is visible next to the data
C_KMS = 299792.458  # speed of light [km/s]
fe_wave_min, fe_wave_max = 1500.0, 3000.0  # rest-frame window [A]

fe_wave = np.asarray(q.wave, dtype=float)

# Fe template = sum of the two Fe components (attributes f_fe_mgii_model and
# f_fe_balmer_model) when the fit exposes both; otherwise an all-zero template on the fit grid.
if hasattr(q, "f_fe_mgii_model") and hasattr(q, "f_fe_balmer_model"):
    fe_model = np.asarray(q.f_fe_mgii_model, dtype=float) + np.asarray(q.f_fe_balmer_model, dtype=float)
else:
    fe_model = np.zeros_like(np.asarray(q.model_total, dtype=float))

# A constant velocity width is a constant number of pixels on a log-lambda grid:
#   sigma_pix = sigma_v / (c * d ln(lambda)).
# SDSS spectra are sampled uniformly in log10(lambda); this assumes q.wave keeps that
# sampling. The median step makes the estimate insensitive to small gaps.
dlnlam = np.median(np.diff(np.log(fe_wave)))
sigma_pix = FE_EXTRA_SIGMA_KMS / (C_KMS * dlnlam)
fe_model_broadened = gaussian_filter1d(fe_model, sigma=sigma_pix)

mask = (fe_wave >= fe_wave_min) & (fe_wave <= fe_wave_max)

fig, ax = plt.subplots(figsize=(11, 5))
# Plot q.flux against q.wave (the fit's own grid) rather than the raw input flux, so data
# and models are sampled on the same wavelengths.
ax.plot(fe_wave[mask], np.asarray(q.flux)[mask], color="k", lw=0.8, label="data")
ax.plot(
    fe_wave[mask], FE_PLOT_SCALE * fe_model[mask],
    color="tab:orange", lw=1.2, label=f"Fe template (x{FE_PLOT_SCALE:g}, display only)",
)
ax.plot(
    fe_wave[mask], FE_PLOT_SCALE * fe_model_broadened[mask],
    color="tab:red", lw=2.0,
    label=f"Fe template + {FE_EXTRA_SIGMA_KMS:.0f} km/s broadening (x{FE_PLOT_SCALE:g})",
)
ax.set_xlim(fe_wave_min, fe_wave_max)
ax.set_xlabel("Rest wavelength [A]")
ax.set_ylabel("Flux [1e-17 cgs]")
ax.set_title(f"Iron (Fe) UV template ({fe_wave_min:.0f}-{fe_wave_max:.0f} A)")
ax.legend(frameon=False)
ax.grid(alpha=0.2)
plt.show()
[ ]:
# Second fit: same setup but without host-galaxy decomposition and without a wave_range
# restriction. This replaces `q`, so the BAL inspection below uses THIS fit.
# To restrict the range instead, add e.g. wave_range=(1200.0, 3200.0) to the options.
q = QSOFit(
    lam=lam, flux=flux, err=err,
    z=z, ra=ra, dec=dec,
    filename=sdss_filename,
    output_path=".",
)

refit_options = dict(
    deredden=True,
    fit_method="optax+nuts",
    # model components (note: no host decomposition here)
    fit_lines=True,
    decompose_host=False,
    fit_pl=True,
    fit_fe=True,
    fit_bc=False,
    fit_bal=True,
    fit_poly=True,
    fit_reddening=True,
    prior_config=prior_config,
    dsps_ssp_fn="../tempdata.h5",
    # optimizer / sampler settings (NUTS lengths are demo-sized)
    optax_steps=600,
    optax_lr=1e-2,
    nuts_warmup=10,
    nuts_samples=10,
    nuts_chains=1,
    # outputs
    plot_fig=True,
    save_fig=False,
    save_result=False,
    show_plot=False,
)
q.fit(**refit_options)

5. Inspect the fitted BAL components

The built-in BAL troughs appear in q.custom_components, for example q.custom_components['bal_civ'].

[ ]:
# Compare data, full model, continuum and the fitted C IV BAL trough around C IV.
# Treat a missing or None `custom_components` as empty so the listing below never fails.
custom_components = getattr(q, "custom_components", None) or {}
print("Available custom components:", list(custom_components.keys()))

cont_model = np.asarray(q.f_conti_model)
total_model = np.asarray(q.model_total)
# Plot the trough only if the fit produced it, instead of failing with a bare KeyError.
bal_model = np.asarray(custom_components["bal_civ"]) if "bal_civ" in custom_components else None

fig, ax = plt.subplots(figsize=(11, 5))
ax.plot(q.wave, q.flux, color="0.5", lw=0.8, label="data")
ax.plot(q.wave, total_model, color="k", lw=1.5, label="full model")
ax.plot(q.wave, cont_model, color="tab:blue", lw=1.2, label="continuum")
if bal_model is not None:
    ax.plot(q.wave, bal_model, color="tab:red", lw=2.0, label="BAL custom component")
else:
    print("No 'bal_civ' component in q.custom_components; plotting without the BAL curve.")
ax.axvline(bal_line_rest, color="tab:red", ls="--", alpha=0.5)
ax.set_xlim(1300.0, 1800.0)
ax.set_xlabel("Rest wavelength [A]")
ax.set_ylabel("Flux [1e-17 cgs]")
ax.set_title("BAL-region fit")
ax.legend(frameon=False)
ax.grid(alpha=0.2)
plt.show()

[ ]:
# Print the C IV BAL trough attributes, skipping any that the fit object does not define.
bal_civ_attr_names = ("custom_bal_civ_depth", "custom_bal_civ_center", "custom_bal_civ_sigma")
for attr_name in bal_civ_attr_names:
    if not hasattr(q, attr_name):
        continue
    print(attr_name, getattr(q, attr_name))

6. Notes

  • The built-in BAL set includes N V, Si IV, C IV, C III], and Mg II.

  • If you only want a subset, use the lower-level custom component API instead of fit_bal=True.

  • If the trough is much broader, you may still want a custom BAL setup with different priors.

  • For strong BAL quasars, consider masking unrelated bad regions and narrowing the wavelength range to the science region of interest.