Source code for jaxqsofit.plotting

"""Plotting helpers for :class:`jaxqsofit.JAXQSOFit`.

The public functions in this module take a fitted ``JAXQSOFit`` instance as
first argument. ``JAXQSOFit`` keeps method wrappers around them so notebook code
can continue to call ``fitter.plot_spectrum()`` and related methods.
"""

from __future__ import annotations

import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from .core import (
    _filter_wave_to_angstrom_array,
    _filter_wave_to_angstrom_scalar,
    _get_sdss_filters,
)

# Public names of this module; this list controls what
# ``from jaxqsofit.plotting import *`` exports.
__all__ = [
    "posterior_series",
    "filter_half_width_angstrom",
    "plot_filter_metadata",
    "style_axis",
    "synthetic_photometry_for_plot",
    "observed_photometry_for_plot",
    "plot_trace",
    "plot_corner",
    "plot_mcmc_diagnostics",
    "plot_spectrum",
    "plot_fig",
]


def posterior_series(fitter, param_names=None, max_vector_elems=2):
    """Flatten posterior samples into labeled 1D series for diagnostics.

    Parameters
    ----------
    fitter : object
        Object whose ``numpyro_samples`` attribute maps parameter names to
        arrays of posterior draws (first axis = draw index).
    param_names : list[str] | str | None, optional
        Parameter selector. Use ``'all'`` for all posterior keys (sorted).
        Any other single string is treated as one parameter name. ``None``
        selects a built-in list of commonly inspected parameters.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key. ``None`` or a
        negative value expands every element.

    Returns
    -------
    list[tuple[str, numpy.ndarray]]
        ``(label, values)`` pairs. Multi-dimensional parameters are flattened
        per draw and labeled ``name[i]``. Names missing from the samples are
        skipped. An empty list is returned when no samples are available.
    """
    samples = getattr(fitter, 'numpyro_samples', None)
    if samples is None:
        return []

    if isinstance(param_names, str):
        # A bare string must not be iterated character by character:
        # 'all' selects every key, anything else names a single parameter.
        if param_names == 'all':
            param_names = sorted(samples.keys())
        else:
            param_names = [param_names]
    elif param_names is None:
        param_names = [
            'cont_norm',
            'log_frac_host',
            'PL_norm',
            'PL_slope',
            'Fe_uv_norm',
            'log_Fe_op_over_uv',
            'Fe_uv_FWHM',
            'Fe_op_FWHM',
            'Balmer_norm',
            'Balmer_Tau',
            'Balmer_vel',
            'gal_v_kms',
            'gal_sigma_kms',
            'frac_jitter',
            'add_jitter',
        ]

    out = []
    for name in param_names:
        if name not in samples:
            continue
        arr = np.asarray(samples[name])
        if arr.ndim == 1:
            out.append((name, arr))
        elif arr.ndim >= 2:
            # Collapse all trailing axes so each flattened element is a series.
            arr2 = arr.reshape(arr.shape[0], -1)
            nflat = arr2.shape[1]
            if max_vector_elems is None or int(max_vector_elems) < 0:
                ncomp = nflat
            else:
                ncomp = min(nflat, int(max_vector_elems))
            out.extend((f'{name}[{i}]', arr2[:, i]) for i in range(ncomp))
        # 0-d entries carry no per-draw series and are skipped.
    return out
def filter_half_width_angstrom(filt):
    """Estimate half the wavelength extent of a filter's significant response.

    The extent is measured over the wavelength samples whose transmission
    exceeds 1% of the peak transmission. ``0.0`` is returned when fewer than
    two samples qualify.
    """
    wavelengths = _filter_wave_to_angstrom_array(filt.wave)
    throughput = np.asarray(filt.transmission, dtype=float)
    threshold = 0.01 * np.nanmax(throughput)
    significant = wavelengths[throughput > threshold]
    if significant.size < 2:
        return 0.0
    return 0.5 * float(significant.max() - significant.min())


def plot_filter_metadata(fitter, bands):
    """Look up plotting metadata for the requested photometric bands.

    Returns
    -------
    tuple
        ``(valid, eff_wave_obs, half_width_obs)``. ``valid`` is a boolean
        array with one entry per requested band, marking bands found in the
        filter set. The other two arrays hold one value per *found* band
        (effective wavelength and half-width), in request order.
    """
    registry = _get_sdss_filters()
    looked_up = [registry.get(str(band)) for band in bands]
    valid = np.asarray([entry is not None for entry in looked_up], dtype=bool)
    found = [entry for entry in looked_up if entry is not None]
    if not found:
        return valid, np.array([], dtype=float), np.array([], dtype=float)

    eff_wave_obs = np.asarray(
        [_filter_wave_to_angstrom_scalar(entry.effective_wavelength) for entry in found],
        dtype=float,
    )
    # Delegates to the fitter's method wrapper for the half-width estimate.
    half_width_obs = np.asarray(
        [fitter._filter_half_width_angstrom(entry) for entry in found],
        dtype=float,
    )
    return valid, eff_wave_obs, half_width_obs


def style_axis(ax, spine_lw=1.5):
    """Apply the shared axis look: no grid, inward ticks on all sides.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to style.
    spine_lw : float, optional
        Line width applied to every spine.
    """
    tick_style = {
        'which': 'both',
        'direction': 'in',
        'top': True,
        'right': True,
        'length': 6,
        'width': 1.2,
    }
    ax.grid(False)
    ax.tick_params(**tick_style)
    for border in ax.spines.values():
        border.set_linewidth(spine_lw)
def synthetic_photometry_for_plot(fitter, model_attr='model_total'):
    """Return rest-frame synthetic photometry points for plotting, if available.

    The model spectrum stored on ``fitter`` under ``model_attr`` is shifted to
    the observed frame, integrated through each configured band, converted to
    a magnitude with a -48.60 zero point, and turned back into a flux density
    at the band's effective wavelength.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(x, xerr, y, yerr)`` in the rest frame, or ``None`` when PSF
        photometry is disabled, inputs are missing or inconsistent, or no
        band yields a finite positive synthetic flux.
    """
    if not bool(getattr(fitter, 'use_psf_phot', False)):
        return None

    band_names = list(getattr(fitter, 'psf_bands', []) or [])
    band_mag_errs = np.asarray(getattr(fitter, 'psf_mag_errs', []), dtype=float)
    if len(band_names) == 0 or band_mag_errs.size != len(band_names):
        return None
    if not hasattr(fitter, 'wave') or len(getattr(fitter, 'wave', [])) == 0:
        return None

    model_flux_rf = np.asarray(getattr(fitter, model_attr, []), dtype=float)
    if model_flux_rf.size != len(fitter.wave) or not np.any(np.isfinite(model_flux_rf)):
        return None

    rest_wave = np.asarray(fitter.wave, dtype=float)
    z = float(getattr(fitter, 'z', 0.0))
    wave_obs = rest_wave * (1.0 + z)
    flam_obs = model_flux_rf / max(1.0 + z, 1e-8)
    # Same scaling for every band (flux is stored in units of 1e-17).
    flam_obs_cgs = 1e-17 * flam_obs

    filters = _get_sdss_filters()
    c_ang_s = 2.99792458e18

    points = []
    for band, mag_err in zip(band_names, band_mag_errs):
        filt = filters.get(str(band))
        if filt is None or not np.isfinite(mag_err) or mag_err <= 0:
            continue

        filt_wave = _filter_wave_to_angstrom_array(filt.wave)
        filt_trans = np.asarray(filt.transmission, dtype=float)
        # Transmission resampled on the spectrum grid; zero outside the filter.
        response = np.clip(
            np.interp(wave_obs, filt_wave, filt_trans, left=0.0, right=0.0),
            0.0,
            None,
        )

        num = float(np.trapezoid(flam_obs_cgs * response * wave_obs, wave_obs))
        den = float(np.trapezoid(response * c_ang_s / np.clip(wave_obs, 1e-8, None), wave_obs))
        if not (np.isfinite(num) and np.isfinite(den) and num > 0 and den > 0):
            continue

        mag_syn = -2.5 * np.log10(num / den) - 48.60
        if not np.isfinite(mag_syn):
            continue

        eff_wave_obs = _filter_wave_to_angstrom_scalar(filt.effective_wavelength)
        significant = filt_wave[filt_trans > 0.01 * np.nanmax(filt_trans)]
        half_width_obs = 0.0
        if significant.size >= 2:
            half_width_obs = 0.5 * float(significant.max() - significant.min())

        # Magnitude -> f_nu -> f_lambda at the effective wavelength.
        fnu_obs = 10.0 ** (-0.4 * (mag_syn + 48.60))
        flam_band_obs_cgs = fnu_obs * c_ang_s / max(eff_wave_obs ** 2, 1e-30)
        flam_band_obs = flam_band_obs_cgs / 1e-17
        flam_band_rf = flam_band_obs * (1.0 + z)
        # First-order propagation of the magnitude error into flux.
        flam_band_rf_err = flam_band_rf * (0.4 * np.log(10.0) * float(mag_err))

        points.append(
            (
                eff_wave_obs / (1.0 + z),
                half_width_obs / (1.0 + z),
                flam_band_rf,
                flam_band_rf_err,
            )
        )

    if not points:
        return None

    x_rf, xerr_rf, y_rf, yerr_rf = zip(*points)
    return (
        np.asarray(x_rf, dtype=float),
        np.asarray(xerr_rf, dtype=float),
        np.asarray(y_rf, dtype=float),
        np.asarray(yerr_rf, dtype=float),
    )
def observed_photometry_for_plot(fitter):
    """Return rest-frame observed PSF photometry points for plotting, if available.

    Returns
    -------
    tuple of numpy.ndarray or None
        ``(x, xerr, y, yerr)`` in the rest frame for every band that has both
        a known filter and finite photometry with a positive error. ``None``
        when PSF photometry is disabled, the band/magnitude arrays disagree
        in length, or no band qualifies.
    """
    if not bool(getattr(fitter, 'use_psf_phot', False)):
        return None

    band_names = list(getattr(fitter, 'psf_bands', []) or [])
    n_bands = len(band_names)
    band_mags = np.asarray(getattr(fitter, 'psf_mags', []), dtype=float)
    band_mag_errs = np.asarray(getattr(fitter, 'psf_mag_errs', []), dtype=float)
    if n_bands == 0 or band_mags.size != n_bands or band_mag_errs.size != n_bands:
        return None

    z = float(getattr(fitter, 'z', 0.0))
    c_ang_s = 2.99792458e18

    has_filter, eff_wave_obs, half_width_obs = fitter._plot_filter_metadata(band_names)
    has_phot = np.isfinite(band_mags) & np.isfinite(band_mag_errs) & (band_mag_errs > 0)
    keep = has_filter & has_phot
    if not np.any(keep):
        return None

    # The metadata arrays only cover bands with a filter, so the photometry
    # mask has to be restricted to those bands before indexing them.
    keep_among_filters = has_phot[has_filter]
    band_mags = band_mags[keep]
    band_mag_errs = band_mag_errs[keep]
    eff_wave_obs = eff_wave_obs[keep_among_filters]
    half_width_obs = half_width_obs[keep_among_filters]

    # Magnitude -> f_nu -> f_lambda at the effective wavelength.
    fnu_obs = 10.0 ** (-0.4 * (band_mags + 48.60))
    flam_band_obs_cgs = fnu_obs * c_ang_s / np.clip(eff_wave_obs ** 2, 1e-30, None)
    flam_band_obs = flam_band_obs_cgs / 1e-17
    flam_band_rf = flam_band_obs * (1.0 + z)
    flam_band_rf_err = flam_band_rf * (0.4 * np.log(10.0) * band_mag_errs)

    x_rf = eff_wave_obs / (1.0 + z)
    xerr_rf = half_width_obs / (1.0 + z)
    return x_rf, xerr_rf, flam_band_rf, flam_band_rf_err
def plot_trace(
    fitter,
    param_names=None,
    max_vector_elems=2,
    save_fig_path=None,
    save_fig_name=None,
    show_plot=False,
):
    """Draw one trace panel per selected posterior series.

    Parameters
    ----------
    param_names : list[str] | str | None, optional
        Parameter selector. Use ``'all'`` to include all posterior keys.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``fitter.output_path`` (or ``'.'`` when unset).
    save_fig_name : str or None, optional
        Output filename override.
    show_plot : bool, optional
        If True, display the figure interactively with ``plt.show()``.
        Defaults to False so diagnostics are safe to call in headless
        terminals.

    Returns
    -------
    matplotlib.figure.Figure or None
        The trace figure (also stored on ``fitter.trace_fig``), or ``None``
        when there are no posterior series to draw.
    """
    traces = fitter._posterior_series(param_names=param_names, max_vector_elems=max_vector_elems)
    if len(traces) == 0:
        return None

    n_panels = len(traces)
    fig, axes = plt.subplots(n_panels, 1, figsize=(10, max(2.2 * n_panels, 4)), sharex=True)
    # A single-row layout returns a bare Axes rather than an array.
    panels = [axes] if n_panels == 1 else axes

    for panel, (label, draws) in zip(panels, traces):
        panel.plot(np.arange(len(draws)), draws, color='tab:blue', lw=0.8)
        panel.set_ylabel(label, fontsize=9)
        fitter._style_axis(panel)
    panels[-1].set_xlabel('Sample', fontsize=10)
    fig.tight_layout()

    if show_plot:
        plt.show()

    if fitter.save_fig:
        if save_fig_name is None:
            out_name = f'{fitter.filename}_trace.pdf'
        else:
            out_name = save_fig_name
        if save_fig_path is None:
            save_dir = fitter.output_path
        else:
            save_dir = save_fig_path
        if save_dir is None:
            save_dir = '.'
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, out_name)
        fig.savefig(out_file)
        print(f"Saved trace plot: {out_file}")
        plt.close(fig)

    fitter.trace_fig = fig
    return fig
def plot_corner(
    fitter,
    param_names=None,
    max_vector_elems=2,
    bins=30,
    max_points=5000,
    save_fig_path=None,
    save_fig_name=None,
    show_plot=False,
):
    """Draw a corner plot of the selected posterior series via ``corner.corner``.

    Parameters
    ----------
    param_names : list[str] | str | None, optional
        Parameter selector. Use ``'all'`` to include all posterior keys.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    bins : int, optional
        Histogram bin count.
    max_points : int, optional
        Maximum posterior draws to plot (evenly subsampled if needed).
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``fitter.output_path`` (or ``'.'`` when unset).
    save_fig_name : str or None, optional
        Output filename override.
    show_plot : bool, optional
        If True, display the figure interactively with ``plt.show()``.
        Defaults to False so diagnostics are safe to call in headless
        terminals.

    Returns
    -------
    matplotlib.figure.Figure or None
        The corner figure (also stored on ``fitter.corner_fig``), or ``None``
        when there are no posterior series to draw.

    Raises
    ------
    ImportError
        If the optional ``corner`` package is not installed.
    """
    columns = fitter._posterior_series(param_names=param_names, max_vector_elems=max_vector_elems)
    if len(columns) == 0:
        return None

    # Imported lazily so the rest of the module works without ``corner``.
    try:
        import corner
    except ImportError as exc:
        raise ImportError("plot_corner requires the 'corner' package to be installed.") from exc

    axis_labels = [entry[0] for entry in columns]
    draws = np.column_stack([entry[1] for entry in columns])
    n_keep = int(max_points)
    if draws.shape[0] > n_keep:
        keep_rows = np.linspace(0, draws.shape[0] - 1, n_keep, dtype=int)
        draws = draws[keep_rows]

    corner_options = {
        'labels': axis_labels,
        'bins': bins,
        'show_titles': True,
        'color': "black",
        'quantiles': [0.16, 0.5, 0.84],
        'plot_datapoints': False,
        'plot_contours': True,
        'hist2d_kwargs': {"bins": max(8, int(bins // 2)), "levels": [0.393, 0.865, 0.989]},
        'fill_contours': False,
        'no_fill_contours': True,
        'smooth': 0.8,
        'smooth1d': 0.8,
        'max_n_ticks': 3,
        'quiet': True,
        'labelpad': 0.3,
        'label_kwargs': {"fontsize": 9},
        'title_kwargs': {"fontsize": 9},
        'use_math_text': False,
        'title_fmt': ".3g",
    }
    fig = corner.corner(draws, **corner_options)

    for panel in fig.axes:
        panel.tick_params(axis='both', which='major', labelsize=8)
        fitter._style_axis(panel)
    fig.subplots_adjust(left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.06, hspace=0.06)

    if show_plot:
        plt.show()

    if fitter.save_fig:
        if save_fig_name is None:
            out_name = f'{fitter.filename}_corner.pdf'
        else:
            out_name = save_fig_name
        if save_fig_path is None:
            save_dir = fitter.output_path
        else:
            save_dir = save_fig_path
        if save_dir is None:
            save_dir = '.'
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, out_name)
        fig.savefig(out_file)
        print(f"Saved corner plot: {out_file}")
        plt.close(fig)

    fitter.corner_fig = fig
    return fig
def plot_mcmc_diagnostics(fitter, do_trace=True, do_corner=True, param_names=None,
                          max_vector_elems=2, corner_bins=30, corner_max_points=2000,
                          save_fig_path=None, show_plot=False):
    """Render the trace and/or corner diagnostics through the fitter's methods.

    Parameters
    ----------
    do_trace : bool, optional
        If True, render trace plot.
    do_corner : bool, optional
        If True, render corner plot.
    param_names : list[str] | str | None
        Parameter selector shared by both trace and corner plots. Use
        ``'all'`` to include all posterior parameters.
    max_vector_elems : int or None, optional
        Maximum number of vector elements to expand per key.
    corner_bins : int, optional
        Histogram bin count for corner plot.
    corner_max_points : int, optional
        Maximum posterior draws to use in corner plot.
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``fitter.output_path`` (or ``'.'`` when unset).
    show_plot : bool, optional
        If True, display each enabled diagnostics figure interactively with
        ``plt.show()``. Defaults to False so diagnostics are safe to call in
        headless terminals.

    Returns
    -------
    None
        The figures produced by the delegated calls are not returned here.
    """
    # Options common to both diagnostics.
    shared_options = {
        'param_names': param_names,
        'max_vector_elems': max_vector_elems,
        'save_fig_path': save_fig_path,
        'show_plot': show_plot,
    }
    if do_trace:
        fitter.plot_trace(**shared_options)
    if do_corner:
        fitter.plot_corner(bins=corner_bins, max_points=corner_max_points, **shared_options)
def plot_spectrum(fitter, **kwargs):
    """Plot the fitted spectrum, model components, and residuals.

    Preferred public entry point for spectrum plots. All keyword arguments
    are forwarded unchanged to ``fitter.plot_fig``, which is kept for
    compatibility with older notebooks, and its result is returned as-is.
    """
    result = fitter.plot_fig(**kwargs)
    return result
def plot_fig(fitter, save_fig_path=None, broad_fwhm=1200, plot_legend=True, ylims=None,
             plot_residual=True, show_title=True, plot_1sigma=True, sigma_alpha=0.12,
             show_plot=True, plot_psf_space=False, plot_intrinsic_powerlaw=False):
    """Plot data, model components, line decomposition, and residuals.

    Parameters
    ----------
    save_fig_path : str or None, optional
        Output directory when saving figures. If ``None``, uses
        ``fitter.output_path`` (or ``'.'`` when unset).
    broad_fwhm : float, optional
        Reserved broad-line threshold (kept for compatibility; unused).
    plot_legend : bool, optional
        If True, draw a legend.
    ylims : tuple[float, float] or None, optional
        Optional y-axis limits for the spectrum panel. When ``None`` the
        limits span the 1st-99th percentile of the plotted values, padded
        by 15%.
    plot_residual : bool, optional
        If True, draw residual panel below the spectrum.
    show_title : bool, optional
        Reserved title toggle kept for compatibility (unused).
    plot_1sigma : bool, optional
        If True, draw 16-84% posterior bands for available components.
    sigma_alpha : float, optional
        Alpha transparency of 1-sigma bands.
    show_plot : bool, optional
        If True, call ``plt.show()``. If False, figure is created/saved
        without display.
    plot_psf_space : bool, optional
        If True, plot the PSF-space model/components. In this mode the
        residual panel is suppressed because the observed spectrum remains
        on the fiber scale. Falls back to the regular view when
        ``fitter.psf_model`` is missing, has the wrong length, or is
        entirely non-finite.
    plot_intrinsic_powerlaw : bool, optional
        If True, overlay the intrinsic AGN power law before multiplicative
        polynomial tilt and additive edge corrections.

    Returns
    -------
    None
        The figure is stored on ``fitter.fig``.

    Notes
    -----
    Components whose peak absolute amplitude is below 0.5% of the 95th
    percentile of ``|flux|`` are still drawn but get no legend entry.
    The function changes matplotlib's global ``xtick``/``ytick`` label size.
    When ``fitter.save_fig`` is true the figure is written to
    ``<fitter.filename>.pdf`` and then closed.
    """
    if bool(getattr(fitter, "_resumed_from_samples", False)):
        fitter._ensure_hydrated_from_samples()

    # NOTE: global rc change; it persists after this function returns.
    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)

    psf_total_model = np.asarray(getattr(fitter, 'psf_model', []), dtype=float)
    use_psf_space = (
        bool(plot_psf_space)
        and psf_total_model.size == len(getattr(fitter, 'wave', []))
        and np.any(np.isfinite(psf_total_model))
    )
    residual_enabled = bool(plot_residual) and not use_psf_space

    if residual_enabled:
        fig, (ax, ax_resid) = plt.subplots(
            2,
            1,
            figsize=(15, 8),
            sharex=True,
            gridspec_kw={'height_ratios': [4.0, 1.2], 'hspace': 0.05},
        )
    else:
        fig, ax = plt.subplots(1, 1, figsize=(15, 6))
        ax_resid = None

    # Amplitude floor below which a component is considered negligible.
    finite_flux = np.isfinite(fitter.flux)
    if np.any(finite_flux):
        flux_ref = float(np.nanpercentile(np.abs(fitter.flux[finite_flux]), 95))
    else:
        flux_ref = 1.0
    comp_floor = max(1e-8, 0.005 * flux_ref)

    psf_scale = float(getattr(fitter, 'scale_psf', np.nan))
    psf_scale = psf_scale if np.isfinite(psf_scale) else 1.0

    total_model_plot = fitter.psf_model if use_psf_space else fitter.model_total
    host_plot = fitter.host_psf if use_psf_space else fitter.host
    pl_plot = psf_scale * fitter.f_pl_model if use_psf_space else fitter.f_pl_model
    pl_intrinsic = np.asarray(getattr(fitter, 'f_pl_model_intrinsic', []), dtype=float)
    if use_psf_space and pl_intrinsic.size == len(fitter.wave):
        pl_intrinsic_plot = psf_scale * pl_intrinsic
    else:
        pl_intrinsic_plot = pl_intrinsic
    fe_total_model = fitter.f_fe_mgii_model + fitter.f_fe_balmer_model
    if use_psf_space:
        fe_total_model = psf_scale * fe_total_model
    bc_plot = psf_scale * fitter.f_bc_model if use_psf_space else fitter.f_bc_model
    line_plot = fitter.line_psf if use_psf_space else fitter.f_line_model

    total_model_label = 'total model (PSF)' if use_psf_space else 'total model'
    host_label = 'host galaxy (PSF)' if use_psf_space else 'host galaxy'
    powerlaw_label = 'power law (PSF)' if use_psf_space else 'power law'
    intrinsic_powerlaw_label = 'intrinsic power law (PSF)' if use_psf_space else 'intrinsic power law'
    fe_label = 'Fe II (PSF)' if use_psf_space else 'Fe II'
    bc_label = 'Balmer continuum (PSF)' if use_psf_space else 'Balmer continuum'
    line_label = 'total lines (PSF)' if use_psf_space else 'total lines'

    custom_component_colors = [
        'darkorange',
        'crimson',
        'slateblue',
        'seagreen',
        'saddlebrown',
        'deeppink',
    ]
    custom_components = list(getattr(fitter, 'custom_components', {}).items())

    def _show_component(arr):
        """Return True when a component has finite amplitude worth plotting."""
        arr = np.asarray(arr, dtype=float)
        arr = arr[np.isfinite(arr)]
        if arr.size == 0:
            return False
        return float(np.nanmax(np.abs(arr))) >= comp_floor

    def _finite_component_values(*arrays):
        """Collect finite values from one or more component arrays."""
        vals = []
        for arr in arrays:
            if arr is None:
                continue
            arr = np.asarray(arr, dtype=float)
            arr = arr[np.isfinite(arr)]
            if arr.size > 0:
                vals.append(arr)
        return vals

    def _fill_band(lo, hi, color):
        """Shade a posterior band behind the curves."""
        ax.fill_between(
            fitter.wave,
            lo,
            hi,
            color=color,
            alpha=sigma_alpha,
            linewidth=0,
            zorder=0,
            rasterized=True,
        )

    def _plot_component(values, label, **style):
        """Draw a component curve; label it only when it is above the floor.

        Returns True when the curve received a legend label.
        """
        shown_label = label if _show_component(values) else None
        ax.plot(fitter.wave, values, label=shown_label, rasterized=True, **style)
        return shown_label is not None

    if plot_1sigma and hasattr(fitter, 'pred_bands') and not use_psf_space:
        band_colors = {
            'total_model': 'b',
            'host': 'purple',
            'PL': 'orange',
            'FeII': 'teal',
            'Balmer_cont': 'y',
            'lines': 'lightskyblue',
        }
        for key, color in band_colors.items():
            if key not in fitter.pred_bands:
                continue
            lo, hi = fitter.pred_bands[key]
            if len(lo) == len(fitter.wave) and _show_component(0.5 * (np.asarray(lo) + np.asarray(hi))):
                _fill_band(lo, hi, color)
        for idx, (name, model) in enumerate(custom_components):
            if name not in fitter.pred_bands:
                continue
            lo, hi = fitter.pred_bands[name]
            if len(lo) != len(fitter.wave) or not _show_component(model):
                continue
            _fill_band(lo, hi, custom_component_colors[idx % len(custom_component_colors)])

    if bool(plot_intrinsic_powerlaw) and hasattr(fitter, 'pred_bands') and not use_psf_space:
        if 'PL_intrinsic' in fitter.pred_bands:
            lo, hi = fitter.pred_bands['PL_intrinsic']
            if len(lo) == len(fitter.wave) and _show_component(0.5 * (np.asarray(lo) + np.asarray(hi))):
                _fill_band(lo, hi, 'darkorange')

    # The data stay on the fiber scale, so they are de-emphasized in PSF space.
    ax.plot(
        fitter.wave_prereduced,
        fitter.flux_prereduced,
        color='k' if not use_psf_space else 'gray',
        lw=1,
        label='data' if not use_psf_space else 'fiber data',
        zorder=2,
        alpha=1.0 if not use_psf_space else 0.6,
        rasterized=True,
    )
    ax.plot(fitter.wave, total_model_plot, color='b', lw=1.8, label=total_model_label, zorder=6, rasterized=True)

    _plot_component(host_plot, host_label, color='purple', lw=1.8, zorder=4)
    _plot_component(pl_plot, powerlaw_label, color='orange', lw=1.5, zorder=5)
    if bool(plot_intrinsic_powerlaw) and pl_intrinsic_plot.size == len(fitter.wave):
        _plot_component(
            pl_intrinsic_plot,
            intrinsic_powerlaw_label,
            color='darkorange',
            lw=1.4,
            ls='--',
            zorder=5,
        )
    _plot_component(fe_total_model, fe_label, color='teal', lw=1.2, zorder=5)
    _plot_component(bc_plot, bc_label, color='y', lw=1.2, zorder=5)

    # All 'bal_*' components share a single 'BAL' legend entry.
    bal_legend_drawn = False
    for idx, (name, model) in enumerate(custom_components):
        color = custom_component_colors[idx % len(custom_component_colors)]
        is_bal_component = str(name).startswith('bal_')
        if is_bal_component:
            label = None if bal_legend_drawn else 'BAL'
        else:
            label = str(name).replace('_', ' ')
        labeled = _plot_component(model, label, color=color, lw=1.4, zorder=5)
        if is_bal_component and labeled:
            bal_legend_drawn = True

    if len(line_plot) == len(fitter.wave):
        # A negligible line total is drawn without a legend entry, matching
        # the treatment of every other component.
        _plot_component(line_plot, line_label, color='lightskyblue', lw=1.5, zorder=5)

    obs_phot_points = fitter._observed_photometry_for_plot()
    if obs_phot_points is not None:
        phot_x, phot_xerr, phot_y, phot_yerr = obs_phot_points
        ax.errorbar(
            phot_x,
            phot_y,
            yerr=phot_yerr,
            xerr=phot_xerr,
            fmt='s',
            ms=7,
            color='black',
            mfc='white',
            mec='black',
            mew=0.8,
            elinewidth=1.0,
            capsize=3,
            alpha=0.95,
            zorder=8,
            label='PSF photometry',
        )

    syn_phot_points = fitter._synthetic_photometry_for_plot(
        model_attr='psf_model' if use_psf_space else 'model_total'
    )
    if syn_phot_points is not None:
        phot_x, phot_xerr, phot_y, phot_yerr = syn_phot_points
        ax.errorbar(
            phot_x,
            phot_y,
            yerr=phot_yerr,
            xerr=phot_xerr,
            fmt='o',
            ms=7,
            color='crimson',
            mec='white',
            mew=0.8,
            elinewidth=1.0,
            capsize=3,
            alpha=0.95,
            zorder=8,
            label='synthetic photometry',
        )

    # Plot individual Gaussian line components: broad (*_br) in red, narrow in green.
    if (hasattr(fitter, 'line_component_amp_median')
            and hasattr(fitter, 'line_component_mu_median')
            and hasattr(fitter, 'line_component_sig_median')
            and hasattr(fitter, 'tied_line_meta')
            and len(fitter.line_component_amp_median) > 0):
        n_comp = len(fitter.line_component_amp_median)
        # Component profiles are Gaussians in log-wavelength.
        lnwave = np.log(fitter.wave)
        comp_labels = fitter.tied_line_meta.get('names', [''] * n_comp)
        drew_broad_label = False
        drew_narrow_label = False
        show_line_leg = _show_component(line_plot)
        # Keep component plotting consistent with polynomial correction if enabled.
        apply_poly = hasattr(fitter, 'f_poly_model') and len(fitter.f_poly_model) == len(lnwave)

        for i in range(n_comp):
            amp = float(fitter.line_component_amp_median[i])
            mu = float(fitter.line_component_mu_median[i])
            sig = float(fitter.line_component_sig_median[i])
            if not np.isfinite(amp) or not np.isfinite(mu) or not np.isfinite(sig) or sig <= 0:
                continue
            prof = amp * np.exp(-0.5 * ((lnwave - mu) / sig) ** 2)
            if apply_poly:
                prof = prof * fitter.f_poly_model

            cname = str(comp_labels[i]).lower()
            is_broad = '_br' in cname
            if use_psf_space:
                # Narrow components carry an extra ``eta_psf`` factor.
                line_scale = psf_scale if is_broad else psf_scale * float(getattr(fitter, 'eta_psf', 1.0))
                prof = line_scale * prof

            if is_broad:
                lbl = 'broad components' if (show_line_leg and not drew_broad_label) else None
                ax.plot(
                    fitter.wave,
                    prof,
                    color='red',
                    lw=0.7,
                    alpha=0.35,
                    zorder=3,
                    label=lbl,
                    rasterized=True,
                )
                drew_broad_label = True
            else:
                lbl = 'narrow components' if (show_line_leg and not drew_narrow_label) else None
                ax.plot(
                    fitter.wave,
                    prof,
                    color='green',
                    lw=0.7,
                    alpha=0.25,
                    zorder=3,
                    label=lbl,
                    rasterized=True,
                )
                drew_narrow_label = True

    ax.set_xlim(fitter.wave.min(), fitter.wave.max())

    if ylims is None:
        yvals = _finite_component_values(
            fitter.flux,
            total_model_plot,
            host_plot,
            pl_plot,
            pl_intrinsic_plot,
            fe_total_model,
            bc_plot,
            line_plot,
            *[model for _, model in custom_components],
        )
        if yvals:
            yplot = np.concatenate(yvals)
            ymin, ymax = np.nanpercentile(yplot, [1, 99])
            if np.isfinite(ymin) and np.isfinite(ymax) and ymax > ymin:
                pad = 0.15 * (ymax - ymin)
                ax.set_ylim(float(ymin - pad), float(ymax + pad))
    else:
        ax.set_ylim(ylims[0], ylims[1])

    # Mark common broad-line AGN transitions on the spectrum panel.
    broad_line_markers = [
        ("Ly$\\alpha$", 1215.67),
        ("CIV", 1549.06),
        ("CIII]", 1908.73),
        ("MgII", 2798.75),
        ("H$\\beta$", 4862.68),
        ("H$\\alpha$", 6564.61),
    ]
    xlo, xhi = ax.get_xlim()
    y_top = ax.get_ylim()[1]
    text_x_offset = 0.01 * (xhi - xlo)
    for label, lam0 in broad_line_markers:
        if xlo <= lam0 <= xhi:
            ax.axvline(lam0, color="gray", ls="--", lw=0.8, alpha=0.35, zorder=1)
            ax.text(
                lam0 - text_x_offset,
                y_top * 0.985,
                label,
                rotation=90,
                va="top",
                ha="center",
                fontsize=12,
                color="dimgray",
                alpha=0.9,
                zorder=7,
            )

    if residual_enabled and len(total_model_plot) == len(fitter.wave) and ax_resid is not None:
        resid = fitter.flux - total_model_plot
        ax_resid.plot(
            fitter.wave,
            resid,
            color='gray',
            ls='dotted',
            lw=1.0,
            zorder=2,
            rasterized=True,
        )
        ax_resid.axhline(0.0, color='k', ls='--', lw=0.8, zorder=1)
        r = resid[np.isfinite(resid)]
        if r.size > 0:
            # Symmetric limits from the 99th percentile of |residual|.
            rlim = np.nanpercentile(np.abs(r), 99)
            if np.isfinite(rlim) and rlim > 0:
                ax_resid.set_ylim(-1.15 * rlim, 1.15 * rlim)
        ax_resid.set_ylabel('resid', fontsize=20)
        fitter._style_axis(ax_resid)

    # The wavelength label goes on whichever panel is at the bottom.
    if residual_enabled and ax_resid is not None:
        ax_resid.set_xlabel('Rest Wavelength (Å)', fontsize=20)
    else:
        ax.set_xlabel('Rest Wavelength (Å)', fontsize=20)
    ax.set_ylabel(r'$f_{\lambda}$ (10$^{-17}$ erg s$^{-1}$ cm$^{-2}$ Å$^{-1}$)', fontsize=20)
    fitter._style_axis(ax)

    if plot_legend:
        ax.legend(loc="upper right", frameon=True, framealpha=0.9, fontsize=12, ncol=2)

    if show_plot:
        plt.show()

    if fitter.save_fig:
        save_dir = fitter.output_path if save_fig_path is None else save_fig_path
        if save_dir is None:
            save_dir = '.'
        os.makedirs(save_dir, exist_ok=True)
        out_file = os.path.join(save_dir, fitter.filename + '.pdf')
        fig.savefig(out_file, dpi=150)
        print(f"Saved spectrum plot: {out_file}")
        plt.close(fig)

    fitter.fig = fig
    return