# Source code for jaxqsofit.custom_components

from __future__ import annotations

import copy
import importlib
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable, Mapping, Sequence

import jax.numpy as jnp
import numpy as np


def _normalize_template_flux(flux: np.ndarray, target_amp: float = 1.0) -> np.ndarray:
    """Scale a template so its robust peak amplitude becomes ``target_amp``.

    The amplitude estimate is the 99th percentile of ``|flux|`` with NaNs
    ignored. If that estimate is non-finite or not strictly positive, a unit
    amplitude is assumed, so the template is simply multiplied by
    ``target_amp``.

    Parameters
    ----------
    flux
        Template flux values; converted to a float array.
    target_amp
        Amplitude the robust peak is rescaled to.

    Returns
    -------
    numpy.ndarray
        The rescaled flux as a float array.
    """
    values = np.asarray(flux, dtype=float)
    amplitude = np.nanpercentile(np.abs(values), 99)
    # Fall back to unit amplitude for all-NaN, infinite, or all-zero templates.
    usable = np.isfinite(amplitude) and amplitude > 0
    scale = target_amp / (amplitude if usable else 1.0)
    return values * scale


def _sanitize_component_name(name: str) -> str:
    """Turn an arbitrary label into a lowercase ASCII identifier.

    Every run of characters outside ``[0-9a-zA-Z_]`` becomes a single
    underscore, repeated underscores are collapsed, leading/trailing
    underscores are dropped, and the result is lowercased. A result that
    starts with a digit is prefixed with ``c_``.

    Raises
    ------
    ValueError
        If nothing is left after sanitizing.
    """
    replaced = re.sub(r"[^0-9a-zA-Z_]+", "_", str(name).strip())
    identifier = re.sub(r"_+", "_", replaced).strip("_").lower()
    if identifier == "":
        raise ValueError("Custom component name must contain at least one alphanumeric character.")
    return f"c_{identifier}" if identifier[0].isdigit() else identifier


def _callable_to_ref(func: Callable[..., Any]) -> str:
    """Encode a callable as a ``module:qualname`` string.

    Only callables that expose both ``__module__`` and ``__qualname__`` and
    are neither lambdas nor defined inside another function can be encoded.

    Raises
    ------
    ValueError
        If the callable cannot be referenced by ``module:qualname``.
    """
    module = getattr(func, "__module__", None)
    qualname = getattr(func, "__qualname__", None)

    # Lambdas and nested functions carry these markers in their qualname.
    importable = bool(module) and bool(qualname)
    if importable:
        importable = not any(marker in qualname for marker in ("<locals>", "<lambda>"))

    if importable:
        return f"{module}:{qualname}"
    raise ValueError(
        "Custom component evaluators must be importable top-level functions "
        "to support save/load of posterior bundles."
    )


def _callable_from_ref(ref: str) -> Callable[..., Any]:
    """Resolve a ``module:qualname`` reference back to a callable.

    Parameters
    ----------
    ref
        Reference string of the form ``"package.module:Outer.inner"``. The
        part before the first ``:`` is imported as a module; the part after
        it is walked attribute by attribute.

    Returns
    -------
    Callable
        The object found at the reference.

    Raises
    ------
    ValueError
        If ``ref`` is not of the form ``module:qualname`` (missing ``:``, or
        an empty module or qualname part).
    TypeError
        If the resolved object is not callable.

    Notes
    -----
    Importing a module runs its top-level code, so references should only be
    resolved from saved state that comes from a trusted source.
    """
    text = str(ref)
    module_name, sep, qualname = text.partition(":")
    # Validate up front so a malformed reference yields a readable message
    # instead of a tuple-unpacking error or a failure inside getattr.
    if not sep or not module_name or not qualname:
        raise ValueError(
            f"Custom component reference must have the form 'module:qualname', got {text!r}."
        )

    obj: Any = importlib.import_module(module_name)
    for part in qualname.split("."):
        obj = getattr(obj, part)
    if not callable(obj):
        raise TypeError(f"Resolved custom component reference is not callable: {ref}")
    return obj


@dataclass(frozen=True)
class CustomComponentSpec:
    """Generic additive continuum component.

    A component is described by two things:

    - ``parameter_priors``: maps each local parameter name to a prior
      configuration dictionary.
    - ``evaluate``: a callable invoked as ``evaluate(wave, params, metadata)``.

    Any shifting, broadening, template interpolation, or other transformation
    is the evaluator's job. On construction the name is sanitized into an
    identifier and the prior/metadata mappings are copied into plain dicts.
    """

    name: str
    parameter_priors: Mapping[str, Mapping[str, Any]]
    evaluate: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], Any]
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # The dataclass is frozen, so normalized values have to be written
        # through object.__setattr__.
        assign = object.__setattr__
        assign(self, "name", _sanitize_component_name(self.name))
        if not callable(self.evaluate):
            raise TypeError("Custom component evaluate must be callable.")

        normalized_priors = {}
        for key, prior_cfg in dict(self.parameter_priors).items():
            normalized_priors[str(key)] = dict(prior_cfg)
        assign(self, "parameter_priors", normalized_priors)
        assign(self, "metadata", dict(self.metadata))

    @property
    def prefix(self) -> str:
        """Prefix shared by this component's parameter sites in samples/priors."""
        return f"custom_{self.name}"

    @property
    def output_name(self) -> str:
        """Key under which the component appears in public outputs."""
        return self.name

    @property
    def deterministic_site_name(self) -> str:
        """Name of the Predictive deterministic site for this component."""
        return f"{self.prefix}_model"

    def site_name(self, param_name: str) -> str:
        """Return the full NumPyro sample-site name for one local parameter."""
        return f"{self.prefix}_{param_name}"

    def to_state(self) -> dict[str, Any]:
        """Return a pickle-friendly dictionary describing this spec.

        The evaluator is stored as a ``module:qualname`` reference; priors and
        metadata are deep-copied.
        """
        state: dict[str, Any] = {"__custom_component__": True, "name": self.name}
        state["parameter_priors"] = copy.deepcopy(dict(self.parameter_priors))
        state["evaluate_ref"] = _callable_to_ref(self.evaluate)
        state["metadata"] = copy.deepcopy(dict(self.metadata))
        return state

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "CustomComponentSpec":
        """Rebuild a spec from the dictionary produced by :meth:`to_state`."""
        name = str(state["name"])
        priors = dict(state["parameter_priors"])
        evaluator = _callable_from_ref(str(state["evaluate_ref"]))
        # Missing metadata is treated as empty.
        metadata = dict(state.get("metadata", {}))
        return cls(
            name=name,
            parameter_priors=priors,
            evaluate=evaluator,
            metadata=metadata,
        )
@dataclass(frozen=True)
class CustomLineComponentSpec:
    """Generic additive emission-line component.

    Like a continuum custom component, the line component is described by a
    mapping of local parameter names to prior configuration dictionaries and
    an evaluator invoked as ``evaluate(wave, params, metadata)``. In addition
    it carries a ``line_kind`` which must be ``"broad"`` or ``"narrow"``
    (compared case-insensitively after stripping whitespace).
    """

    name: str
    parameter_priors: Mapping[str, Mapping[str, Any]]
    evaluate: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], Any]
    line_kind: str = "broad"
    metadata: Mapping[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # The dataclass is frozen, so normalized values have to be written
        # through object.__setattr__.
        assign = object.__setattr__
        assign(self, "name", _sanitize_component_name(self.name))
        if not callable(self.evaluate):
            raise TypeError("Custom line component evaluate must be callable.")

        normalized_priors = {}
        for key, prior_cfg in dict(self.parameter_priors).items():
            normalized_priors[str(key)] = dict(prior_cfg)
        assign(self, "parameter_priors", normalized_priors)

        normalized_kind = str(self.line_kind).strip().lower()
        if normalized_kind not in {"broad", "narrow"}:
            raise ValueError("Custom line component line_kind must be 'broad' or 'narrow'.")
        assign(self, "line_kind", normalized_kind)
        assign(self, "metadata", dict(self.metadata))

    @property
    def prefix(self) -> str:
        """Prefix shared by this line component's parameter sites."""
        return f"custom_line_{self.name}"

    @property
    def output_name(self) -> str:
        """Key under which the line component appears in public outputs."""
        return self.name

    @property
    def deterministic_site_name(self) -> str:
        """Name of the Predictive deterministic site for this line component."""
        return f"{self.prefix}_model"

    def site_name(self, param_name: str) -> str:
        """Return the full sample-site name for one local parameter."""
        return f"{self.prefix}_{param_name}"

    def to_state(self) -> dict[str, Any]:
        """Return a pickle-friendly dictionary describing this spec.

        The evaluator is stored as a ``module:qualname`` reference; priors and
        metadata are deep-copied.
        """
        state: dict[str, Any] = {"__custom_line_component__": True, "name": self.name}
        state["parameter_priors"] = copy.deepcopy(dict(self.parameter_priors))
        state["evaluate_ref"] = _callable_to_ref(self.evaluate)
        state["line_kind"] = self.line_kind
        state["metadata"] = copy.deepcopy(dict(self.metadata))
        return state

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "CustomLineComponentSpec":
        """Rebuild a spec from the dictionary produced by :meth:`to_state`."""
        name = str(state["name"])
        priors = dict(state["parameter_priors"])
        evaluator = _callable_from_ref(str(state["evaluate_ref"]))
        # States without these keys fall back to a broad line and no metadata.
        kind = str(state.get("line_kind", "broad"))
        metadata = dict(state.get("metadata", {}))
        return cls(
            name=name,
            parameter_priors=priors,
            evaluate=evaluator,
            line_kind=kind,
            metadata=metadata,
        )
def make_custom_component(
    name: str,
    parameter_priors: Mapping[str, Mapping[str, Any]],
    evaluate: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], Any],
    *,
    metadata: Mapping[str, Any] | None = None,
) -> CustomComponentSpec:
    """Create a :class:`CustomComponentSpec` for a generic additive component.

    Parameters
    ----------
    name
        Component label.
    parameter_priors
        Local parameter names mapped to prior configuration dictionaries.
    evaluate
        Callable invoked as ``evaluate(wave, params, metadata)``.
    metadata
        Optional extra data handed to the evaluator; ``None`` means no
        metadata. A provided mapping is copied into a new dict.
    """
    meta = dict(metadata) if metadata is not None else {}
    return CustomComponentSpec(
        name=name,
        parameter_priors=parameter_priors,
        evaluate=evaluate,
        metadata=meta,
    )
def make_custom_line_component(
    name: str,
    parameter_priors: Mapping[str, Mapping[str, Any]],
    evaluate: Callable[[Any, Mapping[str, Any], Mapping[str, Any]], Any],
    *,
    line_kind: str = "broad",
    metadata: Mapping[str, Any] | None = None,
) -> CustomLineComponentSpec:
    """Create a :class:`CustomLineComponentSpec` for a generic additive line component.

    Parameters
    ----------
    name
        Component label.
    parameter_priors
        Local parameter names mapped to prior configuration dictionaries.
    evaluate
        Callable invoked as ``evaluate(wave, params, metadata)``.
    line_kind
        Line category forwarded to the spec.
    metadata
        Optional extra data handed to the evaluator; ``None`` means no
        metadata. A provided mapping is copied into a new dict.
    """
    meta = dict(metadata) if metadata is not None else {}
    return CustomLineComponentSpec(
        name=name,
        parameter_priors=parameter_priors,
        evaluate=evaluate,
        line_kind=line_kind,
        metadata=meta,
    )
def template_component_evaluator(wave, params, metadata):
    """Convenience evaluator for broadened/shifted template components.

    The computation is delegated to ``_fe_template_component`` from the
    sibling ``model`` module.

    Parameters
    ----------
    wave
        Wavelength grid passed through unchanged to the template routine.
    params
        Mapping of local parameter values. ``"norm"`` is required; ``"fwhm"``
        and ``"shift"`` are optional and default to
        ``metadata["default_fwhm_kms"]`` (3000.0 if absent) and ``0.0``.
    metadata
        Mapping with the template definition. ``"wave"`` is required. The
        flux is read from ``"flux_model"`` when present and from ``"flux"``
        otherwise. ``"base_fwhm_kms"`` defaults to 900.0.

    Returns
    -------
    Whatever ``_fe_template_component`` returns for these inputs.
    """
    # Imported lazily, as in the original; presumably to avoid an import
    # cycle with the model module -- TODO confirm.
    from .model import _fe_template_component

    wave_template = jnp.asarray(metadata["wave"], dtype=jnp.float64)

    # Only fall back to the raw flux when no model flux is stored. Looking up
    # metadata["flux"] unconditionally would raise KeyError for metadata that
    # carries "flux_model" alone.
    flux_source = metadata["flux_model"] if "flux_model" in metadata else metadata["flux"]
    flux_template = jnp.asarray(flux_source, dtype=jnp.float64)

    if "fwhm" in params:
        fwhm = params["fwhm"]
    else:
        fwhm = jnp.asarray(metadata.get("default_fwhm_kms", 3000.0), dtype=jnp.float64)

    return _fe_template_component(
        wave,
        wave_template,
        flux_template,
        params["norm"],
        fwhm,
        params.get("shift", 0.0),
        base_fwhm_kms=float(metadata.get("base_fwhm_kms", 900.0)),
    )
def make_template_component(
    name: str,
    wave: Sequence[float],
    flux: Sequence[float],
    *,
    fit_fwhm: bool = False,
    fit_shift: bool = False,
    base_fwhm_kms: float = 900.0,
    default_fwhm_kms: float = 3000.0,
    normalize_template: bool = True,
    target_amp: float = 1.0,
) -> CustomComponentSpec:
    """Build a broadened/shifted additive template component.

    Parameters
    ----------
    name
        Component label.
    wave, flux
        Template grid and values: same-length, finite, one-dimensional, at
        least two samples, with ``wave`` strictly increasing.
    fit_fwhm
        When true, add a LogNormal ``fwhm`` prior centred on
        ``log(default_fwhm_kms)``.
    fit_shift
        When true, add a Normal ``shift`` prior centred on zero.
    base_fwhm_kms, default_fwhm_kms
        Stored as floats in the component metadata.
    normalize_template
        When true, the flux handed to the model is rescaled via
        ``_normalize_template_flux``; the unscaled flux is kept alongside it.
    target_amp
        Target amplitude used for that rescaling.

    Raises
    ------
    ValueError
        If the template arrays fail any of the checks above.
    """
    wave_arr = np.asarray(wave, dtype=float)
    flux_arr = np.asarray(flux, dtype=float)

    shapes_ok = (
        wave_arr.ndim == 1
        and flux_arr.ndim == 1
        and wave_arr.size == flux_arr.size
        and wave_arr.size >= 2
    )
    if not shapes_ok:
        raise ValueError("Template custom component wave/flux must be same-length 1D arrays.")
    if not (np.all(np.isfinite(wave_arr)) and np.all(np.isfinite(flux_arr))):
        raise ValueError("Template custom component wave/flux must be finite.")
    if np.any(np.diff(wave_arr) <= 0):
        raise ValueError("Template custom component wavelength grid must be strictly increasing.")

    # The normalization is always fitted; width and shift only on request.
    priors: dict[str, dict[str, Any]] = {}
    priors["norm"] = {"dist": "LogNormal", "loc": np.log(1e-3), "scale": 0.5}
    if fit_fwhm:
        priors["fwhm"] = {"dist": "LogNormal", "loc": np.log(float(default_fwhm_kms)), "scale": 0.3}
    if fit_shift:
        priors["shift"] = {"dist": "Normal", "loc": 0.0, "scale": 1e-3}

    if bool(normalize_template):
        flux_model = _normalize_template_flux(flux_arr, target_amp=float(target_amp))
    else:
        flux_model = flux_arr

    template_metadata = {
        "wave": wave_arr,
        "flux": flux_arr,
        "flux_model": flux_model,
        "normalize_template": bool(normalize_template),
        "target_amp": float(target_amp),
        "base_fwhm_kms": float(base_fwhm_kms),
        "default_fwhm_kms": float(default_fwhm_kms),
    }
    return make_custom_component(
        name=name,
        parameter_priors=priors,
        evaluate=template_component_evaluator,
        metadata=template_metadata,
    )
def _validated_spec_tuple(specs, spec_type, arg_name: str, label: str) -> tuple:
    """Type-check each entry of ``specs`` and reject repeated names."""
    seen = set()
    normalized = []
    for comp in specs:
        if not isinstance(comp, spec_type):
            raise TypeError(f"{arg_name} entries must be {spec_type.__name__} objects.")
        if comp.name in seen:
            raise ValueError(f"Duplicate {label} name: {comp.name}")
        seen.add(comp.name)
        normalized.append(comp)
    return tuple(normalized)


def _typical_flux_scale(flux) -> float:
    """Return the median absolute finite flux, or 1.0 when that is unusable."""
    f = np.asarray(flux, dtype=float)
    finite = np.isfinite(f)
    fscale = float(np.nanmedian(np.abs(f[finite]))) if np.any(finite) else 1.0
    if (not np.isfinite(fscale)) or fscale <= 0:
        fscale = 1.0
    return fscale


def _wants_default_lognormal_loc(param_cfg: Mapping[str, Any]) -> bool:
    """Tell whether a LogNormal prior has a ``loc`` at or below ``log(1e-8)``."""
    loc = float(param_cfg.get("loc", 0.0))
    return str(param_cfg.get("dist", "Normal")).lower() == "lognormal" and loc <= np.log(1e-8)


def _default_lognormal_loc(fscale: float) -> float:
    """Return the flux-scaled replacement ``loc`` for a LogNormal prior."""
    return np.log(max(1e-3 * fscale, 1e-10))


def normalize_custom_components(
    custom_components: Iterable[CustomComponentSpec] | None,
) -> tuple[CustomComponentSpec, ...]:
    """Validate and normalize custom component definitions.

    ``None`` yields an empty tuple. Otherwise every entry must be a
    ``CustomComponentSpec`` (``TypeError``) and names must be unique
    (``ValueError``); the entries are returned as a tuple in input order.
    """
    if custom_components is None:
        return ()
    return _validated_spec_tuple(
        custom_components, CustomComponentSpec, "custom_components", "custom component"
    )


def normalize_custom_line_components(
    custom_line_components: Iterable[CustomLineComponentSpec] | None,
) -> tuple[CustomLineComponentSpec, ...]:
    """Validate and normalize custom line-component definitions.

    ``None`` yields an empty tuple. Otherwise every entry must be a
    ``CustomLineComponentSpec`` (``TypeError``) and names must be unique
    (``ValueError``); the entries are returned as a tuple in input order.
    """
    if custom_line_components is None:
        return ()
    return _validated_spec_tuple(
        custom_line_components,
        CustomLineComponentSpec,
        "custom_line_components",
        "custom line component",
    )


def inject_default_custom_component_priors(
    prior_config: dict[str, Any],
    flux: np.ndarray,
    custom_components: Iterable[CustomComponentSpec] | None,
) -> dict[str, Any]:
    """Return a prior config with default keys added for custom components.

    With no components the input ``prior_config`` object is returned as is.
    Otherwise a deep copy is extended with one entry per component parameter,
    keyed by the component's site name; keys already present are left
    untouched. Defaults are scaled by the median absolute finite flux:

    - a LogNormal prior whose ``loc`` is at or below ``log(1e-8)`` gets
      ``loc = log(max(1e-3 * fscale, 1e-10))``;
    - otherwise a parameter named ``c0`` with ``scale == 0`` gets
      ``scale = 0.2 * fscale``;
    - otherwise a parameter whose name starts with ``c`` with ``scale == 0``
      gets ``scale = 0.05 * fscale``.
    """
    comps = normalize_custom_components(custom_components)
    if not comps:
        return prior_config

    cfg = copy.deepcopy(prior_config)
    fscale = _typical_flux_scale(flux)
    for comp in comps:
        for param_name, param_cfg in comp.parameter_priors.items():
            out_cfg = dict(param_cfg)
            scale = float(out_cfg.get("scale", 1.0))
            if _wants_default_lognormal_loc(out_cfg):
                out_cfg["loc"] = _default_lognormal_loc(fscale)
            elif str(param_name) == "c0" and scale == 0.0:
                out_cfg["scale"] = 0.2 * fscale
            elif str(param_name).startswith("c") and scale == 0.0:
                out_cfg["scale"] = 0.05 * fscale
            # Explicit user-supplied priors take precedence over defaults.
            cfg.setdefault(comp.site_name(param_name), out_cfg)
    return cfg


def inject_default_custom_line_component_priors(
    prior_config: dict[str, Any],
    flux: np.ndarray,
    custom_line_components: Iterable[CustomLineComponentSpec] | None,
) -> dict[str, Any]:
    """Return a prior config with default keys added for custom line components.

    With no components the input ``prior_config`` object is returned as is.
    Otherwise a deep copy is extended with one entry per component parameter,
    keyed by the component's site name; keys already present are left
    untouched. Defaults are scaled by the median absolute finite flux:

    - a LogNormal prior whose ``loc`` is at or below ``log(1e-8)`` gets
      ``loc = log(max(1e-3 * fscale, 1e-10))``;
    - otherwise any parameter with ``scale == 0`` gets
      ``scale = 0.05 * fscale``.
    """
    comps = normalize_custom_line_components(custom_line_components)
    if not comps:
        return prior_config

    cfg = copy.deepcopy(prior_config)
    fscale = _typical_flux_scale(flux)
    for comp in comps:
        for param_name, param_cfg in comp.parameter_priors.items():
            out_cfg = dict(param_cfg)
            scale = float(out_cfg.get("scale", 1.0))
            if _wants_default_lognormal_loc(out_cfg):
                out_cfg["loc"] = _default_lognormal_loc(fscale)
            elif scale == 0.0:
                out_cfg["scale"] = 0.05 * fscale
            # Explicit user-supplied priors take precedence over defaults.
            cfg.setdefault(comp.site_name(param_name), out_cfg)
    return cfg


def custom_component_site_names(
    custom_components: Iterable[CustomComponentSpec] | None,
) -> list[str]:
    """Return deterministic-site names used by Predictive for custom components."""
    return [comp.deterministic_site_name for comp in normalize_custom_components(custom_components)]


def custom_line_component_site_names(
    custom_line_components: Iterable[CustomLineComponentSpec] | None,
) -> list[str]:
    """Return deterministic-site names used by Predictive for custom line components."""
    return [
        comp.deterministic_site_name
        for comp in normalize_custom_line_components(custom_line_components)
    ]