Custom Components Tutorial
jaxqsofit custom components are fully generic additive terms: each one adds the output of a user-defined function to the model flux. This notebook shows how to build runnable continuum and emission-line extensions from such functions.
The default, fit_method="optax", runs a staged SVI/Optax MAP optimization (a continuum warm start, then the full model). fit_method="optax+nuts" uses that staged MAP point to initialize NUTS sampling.
1. Imports
The generic APIs are make_custom_component(...) for continuum-like additive terms and make_custom_line_component(...) for emission-line terms.
[ ]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxqsofit import (
    QSOFit,
    build_default_prior_config,
    make_custom_component,
    make_custom_line_component,
    make_template_component,
)
from jaxqsofit.model import _fe_template_component
2. Build A Small Synthetic Spectrum
A synthetic spectrum (a power-law continuum plus a Gaussian Hβ line and Gaussian noise) keeps the examples runnable without external downloads.
[ ]:
rng = np.random.default_rng(0)
z = 0.1
wave_rest = np.linspace(2000.0, 5500.0, 800)
lam_obs = wave_rest * (1.0 + z)
powerlaw = 1.5 * (wave_rest / 3000.0) ** (-1.2)
hbeta = 0.6 * np.exp(-0.5 * ((wave_rest - 4862.68) / 18.0) ** 2)
noise_sigma = 0.05
err = np.full_like(wave_rest, noise_sigma)
flux = powerlaw + hbeta + rng.normal(0.0, noise_sigma, size=wave_rest.size)
plt.figure(figsize=(9, 3))
plt.plot(lam_obs, flux, lw=1, label='observed synthetic spectrum')
plt.xlabel('Observed Wavelength [A]')
plt.ylabel('Flux Density')
plt.legend()
plt.tight_layout()
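3. Generic Continuum Component
A custom continuum component is a function evaluate(wave, params, metadata) plus a dictionary of priors, with one prior for each entry in params. metadata carries fixed, non-fitted values such as the bump centre and width.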
[ ]:
def gaussian_bump_component(wave, params, metadata):
    # params holds fitted values (amp); metadata holds fixed settings (mu, sigma)
    mu = metadata['mu']
    sigma = metadata['sigma']
    x = (wave - mu) / sigma
    return params['amp'] * jnp.exp(-0.5 * x**2)


custom_components = [
    make_custom_component(
        name='gaussian_bump',
        parameter_priors={
            'amp': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.5},
        },
        evaluate=gaussian_bump_component,
        metadata={'mu': 2800.0, 'sigma': 120.0},
    ),
]

flux_for_priors = flux * (1.0 + z)
prior_config = build_default_prior_config(flux_for_priors)
custom_components[0]
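An evaluator is a plain function of (wave, params, metadata), so you can check it outside the fitter before wrapping it. The sketch below uses a NumPy copy of gaussian_bump_component (np.exp replaces jnp.exp so it runs without JAX). The name gaussian_bump_numpy is hypothetical. The check confirms that the output has one value per wavelength and that it peaks near metadata['mu'] with a height close to params['amp'].

```python
import numpy as np


def gaussian_bump_numpy(wave, params, metadata):
    # NumPy copy of gaussian_bump_component: same arithmetic, np.exp instead of jnp.exp
    x = (wave - metadata['mu']) / metadata['sigma']
    return params['amp'] * np.exp(-0.5 * x**2)


wave_check = np.linspace(2000.0, 5500.0, 800)
bump = gaussian_bump_numpy(wave_check, {'amp': 0.02}, {'mu': 2800.0, 'sigma': 120.0})

# One model value per wavelength; the maximum sits on the grid point nearest mu
peak_wave = wave_check[np.argmax(bump)]
print(bump.shape, round(float(peak_wave), 1), round(float(bump.max()), 5))
```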
4. Additive Polynomial As A Normal Custom Component
There is no special helper for an additive polynomial. If you want one, write it as a regular custom component.
[ ]:
def additive_poly_component(wave, params, metadata):
    # Quadratic in a dimensionless coordinate centred on the pivot wavelength
    pivot = metadata['pivot']
    x = (wave - pivot) / pivot
    return params['c0'] + params['c1'] * x + params['c2'] * x**2


# Scale the prior widths to the typical flux level of the spectrum
flux_scale = np.nanmedian(np.abs(flux_for_priors))
poly_component = make_custom_component(
    name='blue_excess',
    parameter_priors={
        'c0': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.2 * flux_scale},
        'c1': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.05 * flux_scale},
        'c2': {'dist': 'Normal', 'loc': 0.0, 'scale': 0.05 * flux_scale},
    },
    evaluate=additive_poly_component,
    metadata={'pivot': 3000.0},
)
poly_component
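Dividing by the pivot keeps x dimensionless and of order unity. Over the 2000 to 5500 Å grid used here it runs from about -0.33 to 0.83, so c0, c1 and c2 all carry flux units and can share flux-scaled priors. The NumPy sketch below checks this. It uses a copy of the evaluator with the hypothetical name additive_poly_numpy and illustrative coefficients, not a jaxqsofit call. It confirms that the polynomial equals c0 at the pivot and evaluates it at the grid edges.

```python
import numpy as np


def additive_poly_numpy(wave, params, metadata):
    # NumPy copy of additive_poly_component
    pivot = metadata['pivot']
    x = (wave - pivot) / pivot
    return params['c0'] + params['c1'] * x + params['c2'] * x**2


coeffs = {'c0': 0.1, 'c1': -0.05, 'c2': 0.02}   # illustrative values
meta = {'pivot': 3000.0}
edges_and_pivot = np.array([2000.0, 3000.0, 5500.0])

# x = -1/3, 0 and 5/6 at these wavelengths, so the middle value is exactly c0
values = additive_poly_numpy(edges_and_pivot, coeffs, meta)
print(values)
```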
5. Template Convenience Wrapper
make_template_component(...) remains available for the common broadened-template case, but it is only a convenience wrapper around the generic API.
[ ]:
template_wave = np.linspace(2200.0, 3200.0, 300)
template_flux = np.exp(-0.5 * ((template_wave - 2700.0) / 140.0) ** 2)
template_component = make_template_component(
    'alt_feii',
    wave=template_wave,
    flux=template_flux,
    fit_fwhm=True,
    fit_shift=True,
)
template_component
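6. Fully User-Controlled Template Physics
To control the broadening and shifting yourself, implement them directly in the evaluator. This example calls the internal helper _fe_template_component (imported above). It fits norm, fwhm_kms and shift_frac, and passes the template arrays and base_fwhm_kms as fixed metadata.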
[ ]:
def my_feii_component(wave, params, metadata):
    # Delegate to jaxqsofit's internal helper (leading underscore: not public API)
    return _fe_template_component(
        wave=wave,
        wave_template=metadata['wave_template'],
        flux_template=metadata['flux_template'],
        norm=params['norm'],
        fwhm_kms=params['fwhm_kms'],
        shift_frac=params['shift_frac'],
        base_fwhm_kms=metadata['base_fwhm_kms'],
    )


manual_template_component = make_custom_component(
    name='manual_alt_feii',
    parameter_priors={
        'norm': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.6},
        'fwhm_kms': {'dist': 'LogNormal', 'loc': np.log(2500.0), 'scale': 0.3},
        'shift_frac': {'dist': 'Normal', 'loc': 0.0, 'scale': 1e-3},
    },
    evaluate=my_feii_component,
    metadata={
        'wave_template': template_wave,
        'flux_template': template_flux,
        'base_fwhm_kms': 900.0,
    },
)
manual_template_component
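If you prefer not to depend on the private _fe_template_component, the evaluator can implement the broadening itself. The sketch below shows one common approach, not necessarily what jaxqsofit does internally:

- Resample the template onto a uniform log-wavelength grid, where a fixed velocity width is a fixed number of pixels.
- Convolve it with a Gaussian whose FWHM is the quadrature difference between fwhm_kms and base_fwhm_kms.
- Scale the wavelengths by (1 + shift_frac).

The function name broadened_template_numpy and the fixed kernel half-width are assumptions. The sketch is written in NumPy without Python branches on the fitted parameters. Replacing np with jnp should therefore make it traceable inside an evaluator, but this has not been tested with jaxqsofit.

```python
import numpy as np

C_KMS = 299792.458                                   # speed of light in km/s
FWHM_TO_SIGMA = 1.0 / (2.0 * np.sqrt(2.0 * np.log(2.0)))


def broadened_template_numpy(wave, wave_template, flux_template, norm,
                             fwhm_kms, shift_frac, base_fwhm_kms, half_width=25):
    # Uniform grid in ln(wavelength): every pixel spans the same velocity interval
    log_grid = np.linspace(np.log(wave_template[0]), np.log(wave_template[-1]),
                           wave_template.size)
    flux_log = np.interp(log_grid, np.log(wave_template), flux_template)
    dv_pix = (log_grid[1] - log_grid[0]) * C_KMS     # km/s per pixel

    # Extra broadening beyond the template's intrinsic width, removed in quadrature;
    # the floor on sigma_pix turns the kernel into a delta when fwhm_kms <= base_fwhm_kms
    fwhm_extra = np.sqrt(np.maximum(fwhm_kms**2 - base_fwhm_kms**2, 0.0))
    sigma_pix = np.maximum(fwhm_extra * FWHM_TO_SIGMA / dv_pix, 1e-3)

    # Fixed-size, unit-sum Gaussian kernel (no Python branching on fitted values)
    offsets = np.arange(-half_width, half_width + 1)
    kernel = np.exp(-0.5 * (offsets / sigma_pix) ** 2)
    kernel = kernel / kernel.sum()
    flux_broad = np.convolve(flux_log, kernel, mode='same')

    # Apply the fractional wavelength shift and sample at the requested wavelengths;
    # wavelengths outside the template coverage get zero flux
    shifted_wave = np.exp(log_grid) * (1.0 + shift_frac)
    return norm * np.interp(wave, shifted_wave, flux_broad, left=0.0, right=0.0)


sketch_wave = np.linspace(2200.0, 3200.0, 300)
sketch_flux = np.exp(-0.5 * ((sketch_wave - 2700.0) / 140.0) ** 2)
wave_eval = np.linspace(2000.0, 3500.0, 500)
model = broadened_template_numpy(wave_eval, sketch_wave, sketch_flux,
                                 norm=1.0, fwhm_kms=2500.0, shift_frac=0.0,
                                 base_fwhm_kms=900.0)
print(round(float(model.max()), 4))
```

The fixed kernel size is a design choice that keeps the array shapes independent of the fitted parameters, which tracing under JAX requires. half_width must cover several sigma_pix for the widest FWHM you expect to fit.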
7. Custom Line Components
Custom emission-line components are attached to the line model and tagged as either broad or narrow through the line_kind argument.
[ ]:
def exponential_line_profile(wave, params, metadata):
    # Symmetric exponential profile: amp * exp(-|wave - center| / tau)
    center = metadata['center']
    tau = params['tau']
    amp = params['amp']
    x = jnp.abs(wave - center) / tau
    return amp * jnp.exp(-x)


custom_line_components = [
    make_custom_line_component(
        name='exp_hbeta_wing',
        parameter_priors={
            'amp': {'dist': 'LogNormal', 'loc': np.log(0.02), 'scale': 0.5},
            'tau': {'dist': 'LogNormal', 'loc': np.log(15.0), 'scale': 0.3},
        },
        evaluate=exponential_line_profile,
        line_kind='broad',
        metadata={'center': 4862.68},
    ),
]
custom_line_components[0]
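Two closed-form properties of this profile, amp·exp(-|λ - center|/τ), help when choosing the amp and tau priors. Its integrated flux is 2·amp·τ, and its FWHM is 2·τ·ln 2, because the profile falls to half maximum at |λ - center| = τ ln 2. The NumPy sketch below checks both numerically for the values injected in the next section (amp = 0.35, tau = 35 Å).

```python
import numpy as np

amp, tau, center = 0.35, 35.0, 4862.68

# Fine uniform grid; at +/-1000 A the profile has decayed by exp(-1000/35)
wave_fine = np.linspace(center - 1000.0, center + 1000.0, 200001)
dx = wave_fine[1] - wave_fine[0]
profile = amp * np.exp(-np.abs(wave_fine - center) / tau)

# Trapezoidal integral written out explicitly, compared with 2 * amp * tau
numeric_flux = dx * (profile.sum() - 0.5 * (profile[0] + profile[-1]))
analytic_flux = 2.0 * amp * tau

# Width of the region at or above half maximum, compared with 2 * tau * ln 2
above_half = wave_fine[profile >= 0.5 * amp]
numeric_fwhm = above_half[-1] - above_half[0]
analytic_fwhm = 2.0 * tau * np.log(2.0)

print(numeric_flux, analytic_flux)
print(numeric_fwhm, analytic_fwhm)
```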
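8. Simulate A Fake Exponential Wing And Fit It
Here we inject a broad exponential wing near Hβ into the synthetic spectrum, then fit that spectrum with the matching component passed through custom_line_components.
Custom line components still work when fit_lines=False. That setting disables the built-in tied-line table but leaves explicitly supplied custom line components active. The Gaussian Hβ line from section 2 is still present in flux_with_wing, and no built-in line is available to model it. The fitted wing parameters will therefore likely absorb part of that emission rather than reproduce wing_amp_true and wing_tau_true exactly.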
[ ]:
wing_amp_true = 0.35
wing_tau_true = 35.0
wing_flux = np.asarray(
    exponential_line_profile(
        wave_rest,
        params={'amp': wing_amp_true, 'tau': wing_tau_true},
        metadata={'center': 4862.68},
    )
)
flux_with_wing = flux + wing_flux

plt.figure(figsize=(9, 3))
plt.plot(wave_rest, flux_with_wing, lw=1, label='synthetic spectrum with exponential wing')
plt.plot(wave_rest, wing_flux, lw=2, label='injected wing')
plt.xlim(4700, 5050)
plt.xlabel('Rest Wavelength [A]')
plt.ylabel('Flux Density')
plt.legend()
plt.tight_layout()
[ ]:
flux_with_wing_for_priors = flux_with_wing * (1.0 + z)
wing_prior_config = build_default_prior_config(flux_with_wing_for_priors)

# The component's own amp prior, LogNormal(log 0.02, 0.5), sits far below the injected
# amplitude (0.35), so widen the amp and tau priors for this fit via their global keys.
wing_prior_config['custom_line_exp_hbeta_wing_amp'] = {
    'dist': 'LogNormal',
    'loc': np.log(0.2),
    'scale': 0.7,
}
wing_prior_config['custom_line_exp_hbeta_wing_tau'] = {
    'dist': 'LogNormal',
    'loc': np.log(25.0),
    'scale': 0.5,
}

q_wing = QSOFit(lam=lam_obs, flux=flux_with_wing, err=err, z=z)
q_wing.fit(
    deredden=False,
    prior_config=wing_prior_config,
    custom_line_components=custom_line_components,
    fit_fe=False,
    fit_bc=False,
    fit_poly=False,
    # Disables the built-in tied-line table; the custom wing component stays active.
    fit_lines=False,
    # Use MAP-only here so the synthetic custom-line example runs quickly.
    fit_method='optax',
    decompose_host=False,
    optax_steps=2000,
    save_result=False,
    plot_fig=True,
    save_fig=False,
    verbose=False,
)
9. Prior Naming Convention
Each local parameter is expanded into a global prior/sample key: custom_<name>_<param> for continuum components and custom_line_<name>_<param> for line components. For example:
- continuum component gaussian_bump, parameter amp -> custom_gaussian_bump_amp
- line component exp_hbeta_wing, parameter tau -> custom_line_exp_hbeta_wing_tau
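The convention can be captured in a small helper. global_prior_key below is hypothetical and is not part of jaxqsofit. It encodes the naming pattern illustrated above: custom_<name>_<param> for continuum components and custom_line_<name>_<param> for line components. It reproduces the keys overridden in section 8.

```python
def global_prior_key(component_name, param_name, is_line=False):
    # Hypothetical helper mirroring the documented naming convention
    prefix = 'custom_line' if is_line else 'custom'
    return f'{prefix}_{component_name}_{param_name}'


print(global_prior_key('gaussian_bump', 'amp'))
print(global_prior_key('exp_hbeta_wing', 'tau', is_line=True))
```

You could then write, for example, prior_config[global_prior_key('gaussian_bump', 'amp')] = {...} instead of spelling out the key by hand. This assumes the convention holds for every component, which the examples above suggest but do not prove.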
10. Save/Load Note
For QSOFit.load_from_samples(...) to restore custom components cleanly, define evaluators as top-level functions in an importable Python module.
Avoid lambdas, nested functions, and local closures in production code.
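A likely reason: Python's pickle module stores a function by reference, as its module plus qualified name, and re-imports it on load. That lookup fails for lambdas and for functions defined inside other functions. Whether load_from_samples relies on pickle specifically is an assumption here. The sketch below only demonstrates standard-library behavior, using hypothetical evaluator names.

```python
import pickle


def top_level_evaluator(wave, params, metadata):
    # Defined at module level: pickle can store it as module + name
    return params['amp'] * wave


def make_closure_evaluator(scale):
    def inner(wave, params, metadata):
        # Nested function: its qualified name contains '<locals>', so it cannot be re-imported
        return scale * params['amp'] * wave
    return inner


lambda_evaluator = lambda wave, params, metadata: params['amp'] * wave  # noqa: E731


def round_trips(func):
    # True if the function survives a pickle dump/load cycle as the same object
    try:
        return pickle.loads(pickle.dumps(func)) is func
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


for f in (top_level_evaluator, make_closure_evaluator(2.0), lambda_evaluator):
    print(f.__qualname__, round_trips(f))
```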