Custom Components Tutorial

jaxqsofit custom components are fully generic additive terms: any user-defined function of wavelength can be added to the model. This notebook shows how to build runnable continuum and emission-line extensions from such functions.

The current default, fit_method="optax", runs a staged SVI/Optax MAP optimization (a continuum warm start followed by the full model). fit_method="optax+nuts" uses the resulting MAP point to initialize NUTS.

1. Imports

The generic APIs are make_custom_component(...) for continuum-like additive terms and make_custom_line_component(...) for emission-line terms.

[ ]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jaxqsofit import (
    QSOFit,
    build_default_prior_config,
    make_custom_component,
    make_custom_line_component,
    make_template_component,
)
# Private helper (leading underscore), used only in section 6 below.
from jaxqsofit.model import _fe_template_component

2. Build A Small Synthetic Spectrum

This keeps the examples executable without relying on external downloads.

[ ]:
rng = np.random.default_rng(0)

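z = 0.1
# Rest-frame grid and the matching observed-frame wavelengths.
wave_rest = np.linspace(2000.0, 5500.0, 800)
lam_obs = wave_rest * (1.0 + z)

# Power-law continuum plus a single Gaussian H-beta line (sigma = 18 A).
powerlaw = 1.5 * (wave_rest / 3000.0) ** (-1.2)
hbeta = 0.6 * np.exp(-0.5 * ((wave_rest - 4862.68) / 18.0) ** 2)
# Constant Gaussian noise, with matching per-pixel uncertainties.
noise_sigma = 0.05
err = np.full_like(wave_rest, noise_sigma)
flux = powerlaw + hbeta + rng.normal(0.0, noise_sigma, size=wave_rest.size)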

plt.figure(figsize=(9, 3))
plt.plot(lam_obs, flux, lw=1, label='observed synthetic spectrum')
plt.xlabel('Observed Wavelength [A]')
plt.ylabel('Flux Density')
plt.legend()
plt.tight_layout()
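
As a quick sanity check on the numbers above, the following sketch converts the Gaussian width of the synthetic Hβ line into a velocity FWHM and an integrated flux, using only standard Gaussian identities. The variable names are illustrative and nothing here calls jaxqsofit.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s

# Same Gaussian parameters as the synthetic H-beta line above.
amp, center, sigma = 0.6, 4862.68, 18.0

# Gaussian identities: FWHM = 2*sqrt(2 ln 2)*sigma, flux = amp*sigma*sqrt(2 pi).
fwhm_angstrom = 2.0 * np.sqrt(2.0 * np.log(2.0)) * sigma
fwhm_kms = fwhm_angstrom / center * C_KMS
line_flux = amp * sigma * np.sqrt(2.0 * np.pi)

# Cross-check the analytic flux with a Riemann sum on the tutorial grid.
wave_rest = np.linspace(2000.0, 5500.0, 800)
profile = amp * np.exp(-0.5 * ((wave_rest - center) / sigma) ** 2)
numeric_flux = profile.sum() * (wave_rest[1] - wave_rest[0])

print(f"FWHM: {fwhm_angstrom:.1f} A = {fwhm_kms:.0f} km/s")
print(f"line flux: analytic {line_flux:.2f}, on-grid {numeric_flux:.2f}")
```

The line is therefore a moderately broad feature (roughly 2600 km/s FWHM) that the 800-point grid samples well enough for the on-grid sum to match the analytic flux.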

3. Generic Continuum Component

A custom continuum component is just a function evaluate(wave, params, metadata) plus a prior dictionary: params carries the fitted parameters named in the priors, and metadata carries fixed constants.

[ ]:
def gaussian_bump_component(wave, params, metadata):
    mu = metadata['mu']
    sigma = metadata['sigma']
    x = (wave - mu) / sigma
    return params['amp'] * jnp.exp(-0.5 * x**2)

custom_components = [
    make_custom_component(
        name='gaussian_bump',
        parameter_priors={
            'amp': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.5},
        },
        evaluate=gaussian_bump_component,
        metadata={'mu': 2800.0, 'sigma': 120.0},
    ),
]

flux_for_priors = flux * (1.0 + z)
prior_config = build_default_prior_config(flux_for_priors)
custom_components[0]
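
The cell above multiplies the flux by (1 + z) before building the default priors. The source does not say why, but the factor is consistent with the usual rest-frame convention for flux density per unit wavelength, in which wavelengths shrink by (1 + z) and flux density grows by the same factor so that integrated flux is unchanged. A minimal NumPy sketch of that convention, with made-up array names:

```python
import numpy as np

z = 0.1
lam_obs = np.linspace(2200.0, 6050.0, 800)
flux_obs = 1.5 * (lam_obs / 3300.0) ** (-1.2)  # arbitrary smooth F_lambda

# Standard rest-frame convention for flux density per unit wavelength.
lam_rest = lam_obs / (1.0 + z)
flux_rest = flux_obs * (1.0 + z)


def integrate(x, y):
    """Trapezoidal integral, written out to avoid NumPy version differences."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


total_obs = integrate(lam_obs, flux_obs)
total_rest = integrate(lam_rest, flux_rest)
print(total_obs, total_rest)  # the two integrals agree up to rounding
```

If that reading is right, priors built from flux * (1 + z) are expressed in the same units as the rest-frame model the fitter evaluates.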

4. Additive Polynomial As A Normal Custom Component

There is no special polynomial helper. If you want one, write it as a regular custom component.

[ ]:
def additive_poly_component(wave, params, metadata):
    pivot = metadata['pivot']
    x = (wave - pivot) / pivot
    return params['c0'] + params['c1'] * x + params['c2'] * x**2

# Scale the coefficient priors to the typical flux level of the spectrum.
flux_scale = np.nanmedian(np.abs(flux_for_priors))

poly_component = make_custom_component(
    name='blue_excess',
    parameter_priors={
        'c0': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.2 * flux_scale},
        'c1': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.05 * flux_scale},
        'c2': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.05 * flux_scale},
    },
    evaluate=additive_poly_component,
    metadata={'pivot': 3000.0},
)

poly_component
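
The pivot normalization in additive_poly_component keeps the polynomial variable of order unity, so each coefficient has flux units and the prior scales can be read directly as flux amplitudes. The sketch below, which only uses NumPy, shows the range of x on the tutorial grid and how large a one-sigma coefficient makes each term at the red edge (in units of the median flux).

```python
import numpy as np

wave_rest = np.linspace(2000.0, 5500.0, 800)
pivot = 3000.0

# Same dimensionless variable as in additive_poly_component.
x = (wave_rest - pivot) / pivot
print(f"x spans [{x.min():.3f}, {x.max():.3f}]")

# Relative prior scales from the cell above: 0.2 for c0, 0.05 for c1 and c2.
x_red = x.max()
term_size = {
    "c0": 0.2 * x_red**0,
    "c1": 0.05 * x_red**1,
    "c2": 0.05 * x_red**2,
}
for name, size in term_size.items():
    print(f"{name}: one-sigma term at red edge = {size:.3f} x median flux")
```

Because |x| stays below 1 across the grid, the higher-order terms cannot exceed their prior scale anywhere in the fitted range.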

5. Template Convenience Wrapper

make_template_component(...) is still available for the common broadened-template case, but it is only a convenience wrapper around the generic API.

[ ]:
template_wave = np.linspace(2200.0, 3200.0, 300)
template_flux = np.exp(-0.5 * ((template_wave - 2700.0) / 140.0) ** 2)

template_component = make_template_component(
    'alt_feii',
    wave=template_wave,
    flux=template_flux,
    fit_fwhm=True,
    fit_shift=True,
)

template_component

6. Fully User-Controlled Template Physics

If you want to own all the broadening or shifting behavior yourself, put that directly in the evaluator.

[ ]:
def my_feii_component(wave, params, metadata):
    return _fe_template_component(
        wave=wave,
        wave_template=metadata['wave_template'],
        flux_template=metadata['flux_template'],
        norm=params['norm'],
        fwhm_kms=params['fwhm_kms'],
        shift_frac=params['shift_frac'],
        base_fwhm_kms=metadata['base_fwhm_kms'],
    )

manual_template_component = make_custom_component(
    name='manual_alt_feii',
    parameter_priors={
        'norm': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.6},
        'fwhm_kms': {'dist': 'LogNormal', 'loc': np.log(2500.0), 'scale': 0.3},
        'shift_frac': {'dist': 'Normal', 'loc': 0.0, 'scale': 1e-3},
    },
    evaluate=my_feii_component,
    metadata={
        'wave_template': template_wave,
        'flux_template': template_flux,
        'base_fwhm_kms': 900.0,
    },
)

manual_template_component
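
The evaluator above delegates to the private _fe_template_component helper, whose internals are not shown here. As a rough illustration of what owning the broadening and shifting yourself could involve, the following NumPy-only sketch broadens a template with a Gaussian kernel in velocity space and applies a fractional wavelength shift. The function broaden_and_shift and its algorithm are assumptions made for illustration, not the library's implementation; an evaluator used in an actual fit would presumably need the jax.numpy equivalents so that it stays differentiable.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s


def broaden_and_shift(wave, wave_template, flux_template, norm,
                      fwhm_kms, shift_frac, base_fwhm_kms):
    """Hypothetical helper: Gaussian-broaden a template, then shift it."""
    # Extra width needed beyond the template's intrinsic width (quadrature).
    extra_fwhm = np.sqrt(max(fwhm_kms**2 - base_fwhm_kms**2, 0.0))
    sigma_kms = extra_fwhm / (2.0 * np.sqrt(2.0 * np.log(2.0)))

    # Resample onto a grid uniform in ln(wavelength): constant velocity step.
    log_wave = np.linspace(np.log(wave_template[0]), np.log(wave_template[-1]),
                           wave_template.size)
    flux_log = np.interp(log_wave, np.log(wave_template), flux_template)
    step_kms = (log_wave[1] - log_wave[0]) * C_KMS

    if sigma_kms > 0.0:
        # Truncate the kernel at 5 sigma and keep it shorter than the template.
        half_width = int(np.ceil(5.0 * sigma_kms / step_kms))
        half_width = min(half_width, (flux_log.size - 1) // 2)
        offsets = np.arange(-half_width, half_width + 1) * step_kms
        kernel = np.exp(-0.5 * (offsets / sigma_kms) ** 2)
        kernel /= kernel.sum()
        flux_log = np.convolve(flux_log, kernel, mode="same")

    # Fractional wavelength shift, then interpolate onto the requested grid.
    shifted_wave = np.exp(log_wave) * (1.0 + shift_frac)
    return norm * np.interp(wave, shifted_wave, flux_log, left=0.0, right=0.0)


# Same toy template and wavelength grid as the tutorial cells.
template_wave = np.linspace(2200.0, 3200.0, 300)
template_flux = np.exp(-0.5 * ((template_wave - 2700.0) / 140.0) ** 2)
wave = np.linspace(2000.0, 5500.0, 800)

narrow = broaden_and_shift(wave, template_wave, template_flux, 1.0, 900.0, 0.0, 900.0)
broad = broaden_and_shift(wave, template_wave, template_flux, 1.0, 6000.0, 0.0, 900.0)
shifted = broaden_and_shift(wave, template_wave, template_flux, 1.0, 900.0, 0.01, 900.0)
print(narrow.max(), broad.max(), wave[np.argmax(shifted)])
```

Broadening lowers the peak while approximately preserving the summed flux, and a positive shift_frac moves the feature redward. Outside the template's wavelength coverage the sketch returns zero.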

7. Custom Line Components

Custom emission-line components are attached to the line model and tagged as either broad or narrow.

[ ]:
def exponential_line_profile(wave, params, metadata):
    center = metadata['center']
    tau = params['tau']
    amp = params['amp']
    x = jnp.abs(wave - center) / tau
    return amp * jnp.exp(-x)

custom_line_components = [
    make_custom_line_component(
        name='exp_hbeta_wing',
        parameter_priors={
            'amp': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.5},
            'tau': {'dist': 'LogNormal', 'loc': np.log(15.0), 'scale': 0.3},
        },
        evaluate=exponential_line_profile,
        line_kind='broad',
        metadata={'center': 4862.68},
    ),
]

custom_line_components[0]

8. Simulate A Fake Exponential Wing And Fit It

Here we inject a broad exponential wing near Hβ into the synthetic spectrum, then fit that spectrum with the matching custom line component from section 7.

Custom line components still work when fit_lines=False; that setting disables the built-in tied-line table while leaving the explicitly supplied custom line component active.

[ ]:
wing_amp_true = 0.35
wing_tau_true = 35.0
wing_flux = np.asarray(
    exponential_line_profile(
        wave_rest,
        params={'amp': wing_amp_true, 'tau': wing_tau_true},
        metadata={'center': 4862.68},
    )
)

flux_with_wing = flux + wing_flux

plt.figure(figsize=(9, 3))
plt.plot(wave_rest, flux_with_wing, lw=1, label='synthetic spectrum with exponential wing')
plt.plot(wave_rest, wing_flux, lw=2, label='injected wing')
plt.xlim(4700, 5050)
plt.xlabel('Rest Wavelength [A]')
plt.ylabel('Flux Density')
plt.legend()
plt.tight_layout()
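
For reference, the two-sided exponential profile amp * exp(-|wave - center| / tau) has simple closed forms: an integrated flux of 2 * amp * tau and a FWHM of 2 * tau * ln 2. The sketch below checks both numerically for the injected values; the fine grid and variable names are illustrative choices.

```python
import numpy as np

amp, tau, center = 0.35, 35.0, 4862.68

# Closed forms for amp * exp(-|wave - center| / tau).
analytic_flux = 2.0 * amp * tau
analytic_fwhm = 2.0 * tau * np.log(2.0)

# Numerical check on a fine grid spanning +/- 20 e-folding lengths.
wave = np.linspace(center - 20.0 * tau, center + 20.0 * tau, 200_001)
profile = amp * np.exp(-np.abs(wave - center) / tau)
numeric_flux = float(np.sum(0.5 * (profile[1:] + profile[:-1]) * np.diff(wave)))

# Width of the region where the profile is at least half its peak value.
above_half = wave[profile >= 0.5 * amp]
numeric_fwhm = float(above_half[-1] - above_half[0])

print(f"flux: analytic {analytic_flux:.2f}, numeric {numeric_flux:.2f}")
print(f"FWHM: analytic {analytic_fwhm:.2f} A, numeric {numeric_fwhm:.2f} A")
```

With amp = 0.35 and tau = 35, the wing carries an integrated flux of 24.5, comparable to the Gaussian Hβ line it sits on, so it is a substantial feature rather than a small perturbation.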

[ ]:
flux_with_wing_for_priors = flux_with_wing * (1.0 + z)
wing_prior_config = build_default_prior_config(flux_with_wing_for_priors)
wing_prior_config['custom_line_exp_hbeta_wing_amp'] = {
    'dist': 'LogNormal',
    'loc': np.log(0.2),
    'scale': 0.7,
}
wing_prior_config['custom_line_exp_hbeta_wing_tau'] = {
    'dist': 'LogNormal',
    'loc': np.log(25.0),
    'scale': 0.5,
}

q_wing = QSOFit(lam=lam_obs, flux=flux_with_wing, err=err, z=z)
q_wing.fit(
    deredden=False,
    prior_config=wing_prior_config,
    custom_line_components=custom_line_components,
    fit_fe=False,
    fit_bc=False,
    fit_poly=False,
    fit_lines=False,
    # Use MAP-only here so the synthetic custom-line example runs quickly.
    fit_method='optax',
    decompose_host=False,
    optax_steps=2000,
    save_result=False,
    plot_fig=True,
    save_fig=False,
    verbose=False,
)
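
Before trusting a fit like this, it can help to confirm that the overridden priors actually cover the injected truth. The sketch below assumes the LogNormal entries follow the common convention in which loc and scale are the mean and standard deviation of the log of the parameter (which the np.log(...) values of loc suggest); under that assumption it reports each prior's median, its central one-sigma interval, and how far the true value sits from the median in log space.

```python
import numpy as np

# Injected values and the overridden priors from the cells above.
truth = {"amp": 0.35, "tau": 35.0}
priors = {
    "amp": {"loc": np.log(0.2), "scale": 0.7},
    "tau": {"loc": np.log(25.0), "scale": 0.5},
}

offsets = {}
for name, prior in priors.items():
    # Assumed convention: log(parameter) ~ Normal(loc, scale).
    median = np.exp(prior["loc"])
    lower = np.exp(prior["loc"] - prior["scale"])
    upper = np.exp(prior["loc"] + prior["scale"])
    offsets[name] = (np.log(truth[name]) - prior["loc"]) / prior["scale"]
    print(f"{name}: median {median:.3g}, 1-sigma [{lower:.3g}, {upper:.3g}], "
          f"truth at {offsets[name]:+.2f} sigma")
```

Both injected values fall within one sigma of the prior medians, so the priors are informative without excluding the truth.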

9. Prior Naming Convention

Each local parameter gets expanded into a global prior/sample key:

  • continuum component gaussian_bump, parameter amp -> custom_gaussian_bump_amp

  • line component exp_hbeta_wing, parameter tau -> custom_line_exp_hbeta_wing_tau
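
The two examples above follow a simple pattern: a prefix, the component name, and the local parameter name joined by underscores. A small helper can build these keys when overriding priors programmatically. The function global_prior_key is hypothetical and not part of jaxqsofit; it only reproduces the pattern inferred from the two examples.

```python
def global_prior_key(component_name, param_name, kind="continuum"):
    """Hypothetical helper reproducing the key pattern shown above."""
    # Prefixes inferred from the two documented examples.
    prefix = {"continuum": "custom", "line": "custom_line"}[kind]
    return f"{prefix}_{component_name}_{param_name}"


continuum_key = global_prior_key("gaussian_bump", "amp")
line_key = global_prior_key("exp_hbeta_wing", "tau", kind="line")
print(continuum_key)
print(line_key)
```

These are the same keys used in section 8 to override the wing priors in wing_prior_config.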

10. Save/Load Note

If you want QSOFit.load_from_samples(...) to restore custom components cleanly, define evaluators as importable top-level Python functions.

Avoid lambdas, nested functions, and local closures for production use.
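
How load_from_samples locates evaluators is not described here, but one plausible reason for this advice is that a function can only be found again if its module and qualified name resolve back to the same object. The sketch below tests exactly that property. The helper is_importable is a hypothetical name, and the check is an assumption about what "importable" needs to mean, not a jaxqsofit API.

```python
import importlib
import json


def is_importable(fn):
    """Return True if fn can be recovered from its module by qualified name."""
    module_name = getattr(fn, "__module__", None)
    qualname = getattr(fn, "__qualname__", "")
    # Lambdas and nested functions carry these markers in their qualified name.
    if not module_name or "<lambda>" in qualname or "<locals>" in qualname:
        return False
    try:
        obj = importlib.import_module(module_name)
        for part in qualname.split("."):
            obj = getattr(obj, part)
    except (ImportError, AttributeError):
        return False
    return obj is fn


def make_bump_evaluator(mu, sigma):
    """Factory returning a closure; convenient, but not importable by name."""
    def bump(wave, params, metadata):
        return params["amp"] * ((wave - mu) / sigma) ** 0
    return bump


print(is_importable(json.dumps))                          # module-level function
print(is_importable(lambda wave, params, metadata: 0.0))  # lambda
print(is_importable(make_bump_evaluator(2800.0, 120.0)))  # closure
```

Functions defined in a notebook cell live in the session's __main__ module, so moving evaluators into a small importable module is likely the safer choice when results need to be reloaded in a fresh process.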