diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eadcbdf..1b0fc94 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,19 +20,19 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.12.1" + rev: "v0.12.7" hooks: - id: ruff args: ["--fix", "--show-fixes"] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.12.1 + rev: v0.12.7 hooks: - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.16.1" + rev: "v1.17.1" hooks: - id: mypy files: src|tests @@ -59,12 +59,12 @@ repos: exclude: ^(LICENSE$) - repo: https://github.com/henryiii/validate-pyproject-schema-store - rev: 2025.06.23 + rev: 2025.07.28 hooks: - id: validate-pyproject - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.33.1 + rev: 0.33.2 hooks: - id: check-readthedocs - id: check-github-workflows diff --git a/.readthedocs.yaml b/.readthedocs.yaml index aeb6d1f..69207fe 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,15 +1,20 @@ # https://docs.readthedocs.com/platform/stable/build-customization.html#install-dependencies-with-uv + version: 2 +sphinx: + configuration: docs/conf.py + build: - os: ubuntu-24.04 - tools: - python: "3.13" - jobs: - create_environment: - - asdf plugin add uv - - asdf install uv latest - - asdf global uv latest - build: - html: - - uv run sphinx-build -T -b html docs $READTHEDOCS_OUTPUT/html + os: ubuntu-24.04 + tools: + python: "3.13" + jobs: + pre_create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + create_environment: + - uv venv "${READTHEDOCS_VIRTUALENV_PATH}" + install: + - UV_PROJECT_ENVIRONMENT="${READTHEDOCS_VIRTUALENV_PATH}" uv sync --group docs diff --git a/README.md b/README.md index 3336309..0c40f76 100644 --- a/README.md +++ b/README.md @@ -34,32 +34,36 @@ See more in `examples/` _evermore_ in a nutshell: ```python3 -from typing import NamedTuple +from typing import NamedTuple, TypeAlias import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, PyTree +import wadler_lindig as wl +from jaxtyping import Array, Float, Scalar import evermore as evm jax.config.update("jax_enable_x64", True) +Hist1D: TypeAlias = Float[Array, "..."] +Hists1D: TypeAlias = dict[str, Hist1D] + # define a simple model with two processes and two parameters -def model(params: PyTree, hists: dict[str, Array]) -> Array: +def model(params: evm.PT, hists: Hists1D) -> Array: mu_modifier = params.mu.scale() syst_modifier = params.syst.scale_log(up=1.1, down=0.9) return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"]) def loss( - dynamic: PyTree, - static: PyTree, - hists: dict[str, Array], - observation: Array, -) -> Array: - params = evm.parameter.combine(dynamic, static) + dynamic: evm.PT, + static: evm.PT, + hists: Hists1D, + observation: Hist1D, +) -> Float[Scalar, ""]: + params = evm.tree.combine(dynamic, static) expectation = model(params, hists) # Poisson NLL of the expectation and observation log_likelihood = ( @@ -72,27 +76,27 @@ def loss( # setup data -hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])} -observation = jnp.array([15]) +hists: Hists1D = {"signal": jnp.array([3.0]), "bkg": jnp.array([10.0])} +observation: Hist1D = jnp.array([15.0]) # define parameters, can be any PyTree of evm.Parameters class Params(NamedTuple): - mu: evm.Parameter - syst: evm.NormalParameter + mu: evm.Parameter[Float[Scalar, ""]] + syst: evm.NormalParameter[Float[Scalar, ""]] params 
= Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))

 # split tree of parameters into a differentiable part and a static part
-dynamic, static = evm.parameter.partition(params)
+dynamic, static = evm.tree.partition(params)

 # Calculate negative log-likelihood/loss
 loss_val = loss(dynamic, static, hists, observation)

 # gradients of negative log-likelihood w.r.t. dynamic parameters
 grads = eqx.filter_grad(loss)(dynamic, static, hists, observation)
-print(f"{grads.mu.value=}, {grads.syst.value=}")
-# -> grads.mu.value=Array(-0.46153846, dtype=float64), grads.syst.value=Array(-0.15436207, dtype=float64)
+wl.pprint(evm.tree.pure(grads), short_arrays=False)
+# -> Params(mu=Array(-0.46153846, dtype=float64), syst=Array(-0.15436207, dtype=float64))
 ```

 ## Contributing
diff --git a/docs/api/filter.md b/docs/api/filter.md
new file mode 100644
index 0000000..251468f
--- /dev/null
+++ b/docs/api/filter.md
@@ -0,0 +1,7 @@
+# PyTree Filter
+
+```{eval-rst}
+.. automodule:: evermore.parameters.filter
+    :show-inheritance:
+    :members:
+```
diff --git a/docs/api/index.md b/docs/api/index.md
index 45b6f23..3c6fb22 100644
--- a/docs/api/index.md
+++ b/docs/api/index.md
@@ -2,6 +2,8 @@

 ```{toctree}
 parameter.md
+tree.md
+filter.md
 transform.md
 sample.md
 effect.md
diff --git a/docs/api/tree.md b/docs/api/tree.md
new file mode 100644
index 0000000..eb8992c
--- /dev/null
+++ b/docs/api/tree.md
@@ -0,0 +1,7 @@
+# Tree Manipulation for Parameters
+
+```{eval-rst}
+.. automodule:: evermore.parameters.tree
+    :show-inheritance:
+    :members:
+```
diff --git a/docs/binned_likelihood.md b/docs/binned_likelihood.md
index 301ee62..15921de 100644
--- a/docs/binned_likelihood.md
+++ b/docs/binned_likelihood.md
@@ -10,9 +10,9 @@ of histograms. It is defined as follows:

 where {math}`\lambda_i(\phi)` is the model prediction for bin {math}`i`,
 {math}`d_i` is the observed data in bin {math}`i`, and
-{math}`\pi_j\left(\phi_j\right)` is the prior probability density function (PDF)
+{math}`\pi_j\left(\phi_j\right)` is the prior probability density function (an `AbstractPDF`)
 for parameter {math}`j`. The first product is a Poisson per bin, and the second
-product is the constraint from each prior PDF.
+product is the constraint from each prior `AbstractPDF`.

 Key to constructing this likelihood is the definition of the model
 {math}`\lambda(\phi)` as a function of parameters {math}`\phi`. evermore
@@ -21,8 +21,8 @@ provides building blocks to define these in a modular way. These building
 blocks include:

 - **evm.Parameter**: A class that represents a parameter with a value, name,
-  bounds, and prior PDF used as constraint.
-- **evm.Effect**: Effects describe how data, e.g., histogram bins, may be
+  bounds, and prior `AbstractPDF` used as constraint.
+- **evm.AbstractEffect**: Effects describe how data, e.g., histogram bins, may be
   varied.
 - **evm.Modifier**: Modifiers combine **evm.Effects** and **evm.Parameters** to
   modify data.
@@ -37,7 +37,7 @@ import evermore as evm

 # -- parameter definition -- #
 params: PyTree[evm.Parameter] = ...
-# dynamic_params, static_params = evm.parameter.partition(params)
+# dynamic_params, static_params = evm.tree.partition(params)


 # -- model definition --
diff --git a/docs/building_blocks.md b/docs/building_blocks.md
index 8f8f0f7..93087be 100644
--- a/docs/building_blocks.md
+++ b/docs/building_blocks.md
@@ -63,8 +63,7 @@ PDFs
 :::{tip}
- You can use _any_ JAX-compatible PDF that satisfies the `evm.custom_types.PDFLike` protocol that requires a `.log_prob` and a `.sample` method to be present.
- Examples would be PDFs from `distrax` or from the JAX-substrate of TensorFlow Probability.
+ You can use _any_ JAX-compatible PDF that satisfies the `evm.pdf.AbstractPDF` interface.


 Parameter Boundaries
@@ -77,7 +76,7 @@ More information can be found in .

 Freeze a Parameter
 : For the minimization of a likelihood it is necessary to differentiate with respect to the _differentiable_ part, i.e., the `.value` attributes, of a PyTree of `evm.Parameters`.
-  Splitting this tree into the differentiable and non-differentiable part is done with `evm.parameter.partition`. You can freeze a `evm.Parameter` by setting `frozen=True`, this will
+  Splitting this tree into the differentiable and non-differentiable part is done with `evm.tree.partition`. You can freeze an `evm.Parameter` by setting `frozen=True`; this will
   put the frozen parameter in the non-differentiable part.

 Correlate a Parameter
@@ -93,7 +92,7 @@
     p3 = evm.Parameter(value=0.5)

     # correlate them
-    p1, p2, p3 = evm.parameter.correlate(*parameters)
+    p1, p2, p3 = evm.parameter.correlate(p1, p2, p3)

     # now p1, p2, p3 are correlated, i.e., they share the same value
     assert p1.value == p2.value == p3.value
@@ -102,6 +101,7 @@
   A more general case of correlating any PyTree of parameters is implemented as follows:
   ```{code-block} python
   from typing import NamedTuple
+
   import jax


   class Params(NamedTuple):
@@ -129,6 +129,7 @@ Inspect a (PyTree of) `evm.Parameters` with [treescope](https://treescope.readth
 You can even add custom visualizers, such as:

 ```{code-block} python
+import treescope
 import evermore as evm


@@ -170,7 +171,7 @@ Asymmetric Exponential Scaling

 The mathematical description can be found [here](https://cms-analysis.github.io/HiggsAnalysis-CombinedLimit/latest/what_combine_does/model_and_likelihood/#normalization-effects).

-Custom effects can be either implemented by inheriting from `evm.effect.Effect` or - more conveniently - be defined with `evm.effect.Lambda`.
+Custom effects can either be implemented by inheriting from `evm.effect.AbstractEffect` or - more conveniently - defined with `evm.effect.Lambda`.
As an example, a custom effect that varies a 3-bin histogram by a constant absolute {math}`1\sigma` uncertainty of `[1.0, 1.5, 2.0]` and returns an additive (`normalize_by="offset"`) variation:

 ```{code-block} python
@@ -228,7 +229,7 @@ modify = evm.Modifier(parameter=param, effect=evm.effect.Linear(offset=0, slope=

 # apply the modifier
 modify(jnp.array([10, 20, 30]))
-# -> Array([11., 22., 33.], dtype=float32, weak_type=True),
+# -> Array([11., 22., 33.], dtype=float32)
 ```

 For the most common types of modifiers evermore provides shorthands that construct a modifier directly from parameters, two examples:
@@ -245,7 +246,7 @@ Modifier that scales a histogram with its value (no constraint):

     # apply the modifier
     modify(jnp.array([10, 20, 30]))
-    # -> Array([11., 22., 33.], dtype=float32, weak_type=True),
+    # -> Array([11., 22., 33.], dtype=float32)
 ```

@@ -255,12 +256,12 @@ Modifier that scales a histogram based on vertical template morphing (Normal con

     import evermore as evm

-    param = evm.NormalParameter(value=1.2)
+    param = evm.NormalParameter(value=0.2)

     # create the modifier
     modify = param.morphing(
-        up_template=[12, 23, 35],
-        down_template=[9, 17, 26],
+        up_template=jnp.array([12., 23., 35.]),
+        down_template=jnp.array([9., 17., 26.]),
     )

     # apply the modifier
@@ -282,14 +283,14 @@ jax.config.update("jax_enable_x64", True)

 param = evm.NormalParameter(value=0.1)

 modifier1 = param.morphing(
-    up_template=[12, 23, 35],
-    down_template=[9, 17, 26],
+    up_template=jnp.array([12., 23., 35.]),
+    down_template=jnp.array([9., 17., 26.]),
 )
 modifier2 = param.scale_log(up=1.1, down=0.9)

 # apply the composed modifier
-(modifier1 @ modifier2)(jnp.array([10, 20, 30]))
+(modifier1 @ modifier2)(jnp.array([10., 20., 30.]))
 # -> Array([10.259877, 20.500944, 30.760822], dtype=float32)

 with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
diff --git a/docs/evermore_for_CMS.md b/docs/evermore_for_CMS.md
index baedcd7..50830b6 100644
--- a/docs/evermore_for_CMS.md
+++ b/docs/evermore_for_CMS.md
@@ -184,10 +184,6 @@ bin1 autoMCStats 10 [include-signal = 0] [hist-mode = 1]

 Please note that evermore treats statistical uncertainties through Gaussian
 rather than Poisson modifiers, comparable to the behavior of
 [pyhf](https://pyhf.readthedocs.io/en/latest/likelihood.html#mc-statistical-uncertainty-staterror).
-However, unlike pyhf, evermore allows to define a threshold on the number of true,
-simulated events per bin below which the statistical uncertainty is modelled per
-process. The threshold value is negative by default, meaning that the per-process
-treatment is always applied.

 ```{code-block} python
 import jax
diff --git a/docs/index.md b/docs/index.md
index 0bea2c5..99d612a 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -19,35 +19,41 @@ python -m pip install .
 ## evermore Quickstart

 ```{code-block} python
-from typing import NamedTuple
+from typing import NamedTuple, TypeAlias

 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import Array, PyTree
+import wadler_lindig as wl
+from jaxtyping import Array, Float, Scalar

 import evermore as evm

 jax.config.update("jax_enable_x64", True)

+Hist1D: TypeAlias = Float[Array, "..."]
+Hists1D: TypeAlias = dict[str, Hist1D]
+

 # define a simple model with two processes and two parameters
-def model(params: PyTree, hists: dict[str, Array]) -> Array:
+def model(params: evm.PT, hists: Hists1D) -> Array:
     mu_modifier = params.mu.scale()
     syst_modifier = params.syst.scale_log(up=1.1, down=0.9)
     return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


 def loss(
-    dynamic: PyTree,
-    static: PyTree,
-    hists: dict[str, Array],
-    observation: Array,
-) -> Array:
-    params = evm.parameter.combine(dynamic, static)
+    dynamic: evm.PT,
+    static: evm.PT,
+    hists: Hists1D,
+    observation: Hist1D,
+) -> Float[Scalar, ""]:
+    params = evm.tree.combine(dynamic, static)
     expectation = model(params, hists)
     # Poisson NLL of the expectation and observation
-    log_likelihood = evm.pdf.PoissonContinuous(lamb=expectation).log_prob(observation).sum()
+    log_likelihood = (
+        evm.pdf.PoissonContinuous(lamb=expectation).log_prob(observation).sum()
+    )
     # Add parameter constraints from logpdfs
     constraints = evm.loss.get_log_probs(params)
     log_likelihood += evm.util.sum_over_leaves(constraints)
@@ -55,27 +61,28 @@ def loss(


 # setup data
-hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
-observation = jnp.array([15])
+hists: Hists1D = {"signal": jnp.array([3.0]), "bkg": jnp.array([10.0])}
+observation: Hist1D = jnp.array([15.0])


 # define parameters, can be any PyTree of evm.Parameters
 class Params(NamedTuple):
-    mu: evm.Parameter
-    syst: evm.NormalParameter
+    mu: evm.Parameter[Float[Scalar, ""]]
+    syst: evm.NormalParameter[Float[Scalar, ""]]


 params = Params(mu=evm.Parameter(1.0), syst=evm.NormalParameter(0.0))

 # split tree of parameters into a differentiable part and a static part
-dynamic, static = evm.parameter.partition(params)
+dynamic, static = evm.tree.partition(params)

 # Calculate negative log-likelihood/loss
 loss_val = loss(dynamic, static, hists, observation)

 # gradients of negative log-likelihood w.r.t. dynamic parameters
 grads = eqx.filter_grad(loss)(dynamic, static, hists, observation)
-print(f"{grads.mu.value=}, {grads.syst.value=}")
-# -> grads.mu.value=Array(-0.46153846, dtype=float64), grads.syst.value=Array(-0.15436207, dtype=float64)
+wl.pprint(evm.tree.pure(grads), short_arrays=False)
+# -> Params(mu=Array(-0.46153846, dtype=float64), syst=Array(-0.15436207, dtype=float64))
+
 ```

 Check out the other
 [Examples](https://github.com/pfackeldey/evermore/tree/main/examples).
diff --git a/docs/tips_and_tricks.md b/docs/tips_and_tricks.md
index 92c05af..d35bc6d 100644
--- a/docs/tips_and_tricks.md
+++ b/docs/tips_and_tricks.md
@@ -71,7 +71,7 @@ This can be useful for example to ensure that the parameter values are within a
 evermore provides two predefined transformations: [`evm.transform.MinuitTransform`](#evermore.parameters.transform.MinuitTransform) (for bounded parameters) and [`evm.transform.SoftPlusTransform`](#evermore.parameters.transform.SoftPlusTransform) (for positive parameters).

-```{code-cell} ipython3
+```{code-cell} python
 import evermore as evm
 import wadler_lindig as wl
@@ -96,7 +96,7 @@ Transformations always transform into the unconstrained real space (using [`evm.
 Typically, you would transform your parameters as a first step inside your loss (or model) function. Then, a minimizer can optimize the transformed parameters in the unconstrained space. Finally, you can transform them back to the constrained space for further processing.

-Custom transformations can be defined by subclassing [`evm.transform.ParameterTransformation`](#evermore.parameters.transform.ParameterTransformation) and implementing the [`wrap`](#evermore.parameters.transform.ParameterTransformation.wrap) and [`unwrap`](#evermore.parameters.transform.ParameterTransformation.unwrap) methods.
+Custom transformations can be defined by subclassing [`evm.transform.AbstractParameterTransformation`](#evermore.parameters.transform.AbstractParameterTransformation) and implementing the [`wrap`](#evermore.parameters.transform.AbstractParameterTransformation.wrap) and [`unwrap`](#evermore.parameters.transform.AbstractParameterTransformation.unwrap) methods.

 ## Parameter Partitioning
@@ -109,6 +109,7 @@ w.r.t. both parts. Gradient calculation is performed only w.r.t. the differentia

 ```{code-block} python
 from jaxtyping import Array, PyTree
+import equinox as eqx
 import evermore as evm

 # define a PyTree of parameters
@@ -117,14 +118,14 @@ params = {
     "b": evm.Parameter(),
 }

-dynamic, static = evm.parameter.partition(params)
+dynamic, static = evm.tree.partition(params)
 print(f"{dynamic=}")
 print(f"{static=}")

 # loss's first argument is only the dynamic part of the parameter Pytree!
 def loss(dynamic: PyTree[evm.Parameter], static: PyTree[evm.Parameter], hists: PyTree[Array]) -> Array:
     # combine the dynamic and static parts of the parameter PyTree
-    parameters = evm.parameter.combine(dynamic, static)
+    parameters = evm.tree.combine(dynamic, static)
     assert parameters == params
     # use the parameters to calculate the loss as usual
     ...
@@ -133,7 +134,50 @@ grad_loss = eqx.filter_grad(loss)(dynamic, static, ...)
 ```

 If you need to further exclude parameters from being optimized, you can set `frozen=True`.
-For a more precise handling of the partitioning and combining step, have a look at `eqx.partition`, `eqx.combine`, and `evm.parameter.value_filter_spec`.
+For a more precise handling of the partitioning and combining step, have a look at `eqx.partition`, `eqx.combine`, and `evm.tree.value_filter_spec`.
+
+
+(tree-manipulations)=
+## PyTree Manipulations
+
+`evermore` provides (similarly to `nnx`) a little filter DSL to allow more powerful manipulations of PyTrees of `evm.Parameters`.
+The following example highlights the `evm.tree.only` function using various filters:
+
+```{code-cell} python
+import evermore as evm
+import wadler_lindig as wl
+
+tags = frozenset({"theory"})
+
+# some pytree of parameters and something else
+tree = {
+    "mu": evm.Parameter(name="mu"),
+    "xsecs": {
+        "dy": evm.Parameter(tags=tags),
+        "tt": evm.Parameter(frozen=True, tags=tags),
+    },
+    "not_a_parameter": 0.0,
+}
+
+# parameter-only pytree
+params = evm.tree.only(tree, evm.filter.is_parameter)
+print("evm.filter.is_parameter:")
+wl.pprint(params, width=200)
+
+print("\nevm.filter.is_frozen:")
+wl.pprint(evm.tree.only(params, evm.filter.is_frozen), width=200)
+
+print("\nevm.filter.is_not_frozen:")
+wl.pprint(evm.tree.only(params, evm.filter.is_not_frozen), width=200)
+
+print("\nevm.filter.HasName('mu'):")
+wl.pprint(evm.tree.only(params, evm.filter.HasName("mu")), width=200)
+
+print("\nevm.filter.HasTags({'theory'}):")
+wl.pprint(evm.tree.only(params, evm.filter.HasTags(tags)), width=200)
+```
+
+`evm.tree.partition` also accepts a `filter` argument, which lets you partition any PyTree however you like.

 ## JAX Transformations
diff --git a/examples/bin_by_bin_uncs.py b/examples/bin_by_bin_uncs.py
index f12cefa..48ec781 100644
--- a/examples/bin_by_bin_uncs.py
+++ b/examples/bin_by_bin_uncs.py
@@ -1,5 +1,7 @@
 from __future__ import annotations

+import typing as tp
+
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float, PyTree
@@ -8,10 +10,12 @@

 jax.config.update("jax_enable_x64", True)

+Hist1D = tp.TypeVar("Hist1D", bound=Float[Array, " nbins"])
+

 def model(
-    staterrors: evm.staterror.StatErrors, hists: PyTree[Float[Array, " nbins"]]
-) -> PyTree[Float[Array, " nbins"]]:
+    staterrors: evm.staterror.StatErrors, hists: PyTree[Hist1D]
+) -> PyTree[Hist1D]:
     expectations = {}

     # signal process
diff --git a/examples/dnn_weights_constraint.py b/examples/dnn_weights_constraint.py
index 455ca11..a54cc47 100644
--- a/examples/dnn_weights_constraint.py
+++ b/examples/dnn_weights_constraint.py
@@ -1,37 +1,42 @@
+import typing as tp
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from jaxtyping import Array, Float

 import evermore as evm

+In: tp.TypeAlias = Float[Array, "in_size"]
+W: tp.TypeAlias = Float[Array, "out_size in_size"]
+B: tp.TypeAlias = Float[Array, "out_size"]
+Out: tp.TypeAlias = B
+

 class LinearConstrained(eqx.Module):
-    weights: evm.Parameter
-    biases: jax.Array
+    weights: evm.Parameter[W]
+    biases: B

     def __init__(self, in_size, out_size, key):
         wkey, bkey = jax.random.split(key)
         # weights
-        normal = evm.pdf.Normal(
-            mean=jnp.zeros((out_size, in_size)),
-            width=jnp.full((out_size, in_size), 0.5),
-        )
         self.weights = evm.Parameter(
             value=jax.random.normal(wkey, (out_size, in_size)),
-            lower=-jnp.inf,  # type: ignore[arg-type]
-            upper=jnp.inf,  # type: ignore[arg-type]
-            prior=normal,
+            prior=evm.pdf.Normal(
+                mean=jnp.zeros((out_size, in_size)),
+                width=jnp.full((out_size, in_size), 0.5),
+            ),
         )
         # biases
         self.biases = jax.random.normal(bkey, (out_size,))

-    def __call__(self, x: jax.Array):
+    def __call__(self, x: In) -> Out:
         return self.weights.value @ x + self.biases


 @eqx.filter_jit
-def loss_fn(model, x, y):
+def loss_fn(model, x: In, y: Out) -> Float[Array, ""]:
     pred_y = jax.vmap(model)(x)
     mse = jax.numpy.mean((y - pred_y) ** 2)
     constraints = evm.loss.get_log_probs(model)
diff --git a/examples/grad_nll.py b/examples/grad_nll.py
index 4b0a5a6..e4fbef9 100644
--- a/examples/grad_nll.py
+++ b/examples/grad_nll.py
@@ -1,17 +1,16 @@
 import equinox as eqx
-import jax
import wadler_lindig as wl from model import hists, loss, observation, params import evermore as evm if __name__ == "__main__": - dynamic, static = evm.parameter.partition(params) + dynamic, static = evm.tree.partition(params) loss_val = loss(dynamic, static, hists, observation) print(f"{loss_val=}") grads = eqx.filter_grad(loss)(dynamic, static, hists, observation) print("Gradients:") wl.pprint( - jax.tree.map(lambda p: p.value, grads, is_leaf=evm.parameter.is_parameter), + evm.tree.pure(grads), short_arrays=False, ) diff --git a/examples/model.py b/examples/model.py index 3c39266..6d7dded 100644 --- a/examples/model.py +++ b/examples/model.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as tp + import equinox as eqx import jax import jax.numpy as jnp @@ -10,6 +12,9 @@ jax.config.update("jax_enable_x64", True) +Hist1D: tp.TypeAlias = Float[Array, " nbins"] # type: ignore[name-defined] + + # dataclass like container for parameters class Params(eqx.Module): mu: evm.Parameter @@ -20,8 +25,8 @@ class Params(eqx.Module): def model( params: PyTree[evm.Parameter], - hists: PyTree[Float[Array, " nbins"]], -) -> PyTree[Float[Array, " nbins"]]: + hists: PyTree[Hist1D], +) -> PyTree[Hist1D]: expectations = {} # signal process @@ -69,7 +74,7 @@ def model( } params = Params( - mu=evm.Parameter(), # type: ignore[arg-type] + mu=evm.Parameter(), norm1=evm.NormalParameter(), norm2=evm.NormalParameter(), shape1=evm.NormalParameter(), @@ -83,10 +88,10 @@ def model( def loss( dynamic: PyTree[evm.Parameter], static: PyTree[evm.Parameter], - hists: PyTree[Float[Array, " nbins"]], - observation: Float[Array, " nbins"], + hists: PyTree[Hist1D], + observation: Hist1D, ) -> Float[Array, ""]: - params = evm.parameter.combine(dynamic, static) + params = evm.tree.combine(dynamic, static) expectations = model(params, hists) constraints = evm.loss.get_log_probs(params) loss_val = ( diff --git a/examples/nll_fit_iminuit.py b/examples/nll_fit_iminuit.py index 3e24062..f7854c2 100644 --- a/examples/nll_fit_iminuit.py +++ b/examples/nll_fit_iminuit.py @@ -1,3 +1,5 @@ +import typing as tp + import equinox as eqx import iminuit import jax @@ -8,31 +10,31 @@ import evermore as evm +FlatV: tp.TypeAlias = Float[Array, " nparams"] # type: ignore[name-defined] + def fit(params, hists, observation): # partition into dynamic and static parts - dynamic, static = evm.parameter.partition(params) + dynamic, static = evm.tree.partition(params) # flatten parameter.value for iminuit - values = jax.tree.map( - lambda p: p.value, dynamic, is_leaf=evm.parameter.is_parameter - ) + values = evm.tree.pure(dynamic) flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values) def update( params: PyTree[evm.Parameter], - values: Float[Array, " nparams"], + values: FlatV, ) -> PyTree[evm.Parameter]: return jax.tree.map( evm.parameter.replace_value, params, unravel_fn(values), - is_leaf=evm.parameter.is_parameter, + is_leaf=evm.filter.is_parameter, ) # wrap loss that works on flat array @eqx.filter_jit - def flat_loss(flat_values: Float[Array, " nparams"]) -> Float[Array, ""]: + def flat_loss(flat_values: FlatV) -> Float[Array, ""]: _dynamic = update(dynamic, flat_values) return loss(_dynamic, static, hists, observation) @@ -49,11 +51,11 @@ def flat_loss(flat_values: Float[Array, " nparams"]) -> Float[Array, ""]: bestfit_dynamic = update(dynamic, bestfit_values) # combine with static pytree - return evm.parameter.combine(bestfit_dynamic, static) + return evm.tree.combine(bestfit_dynamic, static) if __name__ == "__main__": 
    bestfit_params = fit(params, hists, observation)

     print("Bestfit parameter:")
-    wl.pprint(bestfit_params, short_arrays=False)
+    wl.pprint(evm.tree.pure(bestfit_params), short_arrays=False)
diff --git a/examples/nll_fit_mutable_arrays.py b/examples/nll_fit_mutable_arrays.py
new file mode 100644
index 0000000..da343bd
--- /dev/null
+++ b/examples/nll_fit_mutable_arrays.py
@@ -0,0 +1,51 @@
+import equinox as eqx
+import jax
+import wadler_lindig as wl
+from model import hists, loss, observation, params
+
+import evermore as evm
+
+evmm = evm.mutable
+evmt = evm.tree
+evmf = evm.filter
+
+
+def fit(params, hists, observation):
+    dynamic, static = evmt.partition(params)
+
+    # make dynamic part mutable
+    dynamic_ref = evmm.to_refs(dynamic)
+
+    @jax.jit
+    def minimize_step(dynamic_ref, static, hists, observation):
+        loss_grad = eqx.filter_value_and_grad(loss)
+        loss_val, grads = loss_grad(
+            evmm.freeze(dynamic_ref),
+            static,
+            hists,
+            observation,
+        )
+
+        # gradient descent step (in-place update `p.value`)
+        def gd(p: evm.Parameter, g: evm.Parameter, lr: float = 1e-2) -> None:
+            p.value[...] -= lr * g.value
+
+        # apply the gradient descent step to each parameter in the dynamic part
+        jax.tree.map(gd, dynamic_ref, grads, is_leaf=evmf.is_parameter)
+        return loss_val
+
+    # minimize with 5000 steps
+    for step in range(5000):
+        loss_val = minimize_step(dynamic_ref, static, hists, observation)
+        if step % 500 == 0:
+            print(f"{step=} - {loss_val=:.6f}")
+
+    # return best fit values (immutable)
+    return evmt.pure(evmt.combine(evmm.to_arrays(dynamic_ref), static))
+
+
+if __name__ == "__main__":
+    bestfit_params = fit(params, hists, observation)
+
+    print("Bestfit parameter:")
+    wl.pprint(bestfit_params, short_arrays=False)
diff --git a/examples/nll_fit_optax.py b/examples/nll_fit_optax.py
index dd9a98a..5c00a3c 100644
--- a/examples/nll_fit_optax.py
+++ b/examples/nll_fit_optax.py
@@ -1,8 +1,8 @@
 import equinox as eqx
 import optax
 import wadler_lindig as wl
-from jaxtyping import Array, Float, PyTree
-from model import hists, loss, observation, params
+from jaxtyping import PyTree
+from model import Hist1D, hists, loss, observation, params

 import evermore as evm

@@ -14,8 +14,8 @@ def make_step(
     dynamic: PyTree[evm.Parameter],
     static: PyTree[evm.Parameter],
     opt_state: PyTree,
-    hists: PyTree[Float[Array, " nbins"]],
-    observation: Float[Array, " nbins"],
+    hists: PyTree[Hist1D],
+    observation: Hist1D,
 ) -> tuple[PyTree[evm.Parameter], PyTree]:
     grads = eqx.filter_grad(loss)(dynamic, static, hists, observation)
     updates, opt_state = optim.update(grads, opt_state)
@@ -25,7 +25,7 @@ def make_step(


 def fit(params, hists, observation):
-    dynamic, static = evm.parameter.partition(params)
+    dynamic, static = evm.tree.partition(params)

     # initialize optimizer state
     opt_state = optim.init(eqx.filter(dynamic, eqx.is_inexact_array))
@@ -38,11 +38,11 @@ def fit(params, hists, observation):
         dynamic, opt_state = make_step(dynamic, static, opt_state, hists, observation)

     # combine optimized dynamic part with the static pytree
-    return evm.parameter.combine(dynamic, static)
+    return evm.tree.combine(dynamic, static)


 if __name__ == "__main__":
     bestfit_params = fit(params, hists, observation)

     print("Bestfit parameter:")
-    wl.pprint(bestfit_params, short_arrays=False)
+    wl.pprint(evm.tree.pure(bestfit_params), short_arrays=False)
diff --git a/examples/nll_fit_optimistix.py b/examples/nll_fit_optimistix.py
index 3193c26..497f5b8 100644
--- a/examples/nll_fit_optimistix.py
+++ b/examples/nll_fit_optimistix.py
@@ -8,7
+8,7 @@ def fit(params, hists, observation): solver = optx.BFGS(rtol=1e-5, atol=1e-7) - dynamic, static = evm.parameter.partition(params) + dynamic, static = evm.tree.partition(params) def optx_loss(dynamic, args): return loss(dynamic, *args) @@ -23,11 +23,11 @@ def optx_loss(dynamic, args): max_steps=10_000, throw=True, ) - return evm.parameter.combine(fitresult.value, static) + return evm.tree.combine(fitresult.value, static) if __name__ == "__main__": bestfit_params = fit(params, hists, observation) print("Bestfit parameter:") - wl.pprint(bestfit_params, short_arrays=False) + wl.pprint(evm.tree.pure(bestfit_params), short_arrays=False) diff --git a/examples/nll_profiling.py b/examples/nll_profiling.py index 0e86222..d1e2a2b 100644 --- a/examples/nll_profiling.py +++ b/examples/nll_profiling.py @@ -1,46 +1,44 @@ import equinox as eqx import jax import jax.numpy as jnp -import optax -from jaxtyping import Array, Float, PyTree +import optimistix as optx +from jaxtyping import Array, Float import evermore as evm +solver = optx.BFGS(rtol=1e-5, atol=1e-7) + def fixed_mu_fit(mu: Float[Array, ""]) -> Float[Array, ""]: from model import hists, loss, observation, params - optim = optax.sgd(learning_rate=1e-2) - opt_state = optim.init(eqx.filter(params, eqx.is_inexact_array)) - # Fix `mu` and freeze the parameter - params = eqx.tree_at(lambda t: t.mu.value, params, mu) params = eqx.tree_at(lambda t: t.mu.frozen, params, True) - - # twice_nll = 2 * loss - def twice_nll(dynamic, static, hists, observation): - return 2.0 * loss(dynamic, static, hists, observation) - - @eqx.filter_jit - def make_step( - dynamic: PyTree[evm.Parameter], - static: PyTree[evm.Parameter], - hists: PyTree[Float[Array, " nbins"]], - observation: Float[Array, " nbins"], - opt_state: PyTree, - ) -> tuple[PyTree[evm.Parameter], PyTree]: - grads = eqx.filter_grad(twice_nll)(dynamic, static, hists, observation) - updates, opt_state = optim.update(grads, opt_state) - # apply parameter updates - dynamic = eqx.apply_updates(dynamic, updates) - return dynamic, opt_state - - dynamic, static = evm.parameter.partition(params) - - # minimize params with 1000 steps - for _ in range(1000): - dynamic, opt_state = make_step(dynamic, static, hists, observation, opt_state) - return twice_nll(dynamic, static, hists, observation) + dynamic, static = evm.tree.partition(params) + + # Update the `mu` value in the static part, either: + # 1) using `evm.parameter.to_value` and `eqx.tree_at` + static = eqx.tree_at(lambda t: t.mu.raw_value, static, evm.parameter.to_value(mu)) + # or 2) using mutable arrays + # static_ref = evm.mutable.to_refs(static) + # static_ref.mu.raw_value[...] 
= evm.parameter.to_value(mu) + # static = evm.mutable.to_arrays(static_ref) + + def twice_nll(dynamic, args): + return 2.0 * loss(dynamic, *args) + + fitresult = optx.minimise( + twice_nll, + solver, + dynamic, + has_aux=False, + args=(static, hists, observation), + options={}, + max_steps=10_000, + throw=True, + ) + + return twice_nll(fitresult.value, (static, hists, observation)) if __name__ == "__main__": diff --git a/examples/toy_generation.py b/examples/toy_generation.py index 87705ef..586fbe1 100644 --- a/examples/toy_generation.py +++ b/examples/toy_generation.py @@ -5,7 +5,7 @@ import jax.numpy as jnp import optimistix as optx from jaxtyping import Array, Float, PRNGKeyArray, PyTree -from model import hists, loss, model, observation, params +from model import Hist1D, hists, loss, model, observation, params import evermore as evm @@ -29,7 +29,7 @@ def optx_loss(dynamic, args): def fit(params, hists, observation): solver = optx.BFGS(rtol=1e-5, atol=1e-7) - dynamic, static = evm.parameter.partition(params) + dynamic, static = evm.tree.partition(params) fitresult = optx.minimise( optx_loss, @@ -41,7 +41,7 @@ def fit(params, hists, observation): max_steps=10_000, throw=True, ) - return evm.parameter.combine(fitresult.value, static) + return evm.tree.combine(fitresult.value, static) # generate new expectation based on the postfit toy parameters @@ -52,14 +52,14 @@ def postfit_toy_expectation( static: PyTree[evm.Parameter], covariance_matrix: Float[Array, "x x"], n_samples: int = 1, -) -> Float[Array, " nbins"]: +) -> Hist1D: toy_dynamic = evm.sample.sample_from_covariance_matrix( key=key, params=dynamic, covariance_matrix=covariance_matrix, n_samples=n_samples, ) - toy_params = evm.parameter.combine(toy_dynamic, static) + toy_params = evm.tree.combine(toy_dynamic, static) expectations = model(toy_params, hists) return evm.util.sum_over_leaves(expectations) @@ -77,7 +77,7 @@ def prefit_toy_expectation(params, key): # --- Postfit sampling --- bestfit_params = fit(params, hists, observation) - dynamic, static = evm.parameter.partition(bestfit_params) + dynamic, static = evm.tree.partition(bestfit_params) # partial it to only depend on `params` loss_fn = partial(optx_loss, args=(static, hists, observation)) @@ -107,10 +107,18 @@ def prefit_toy_expectation(params, key): print("Mean of 10.000 toys (prefit):", jnp.mean(expectations, axis=0)) print("Std of 10.000 toys (prefit):", jnp.std(expectations, axis=0)) - # # just sample observations with poisson - # poisson_obs = evm.pdf.PoissonDiscrete(observation) - # sampled_observation = poisson_obs.sample(key) + # just sample observations with poisson + poisson_obs = evm.pdf.PoissonDiscrete(observation) + sampled_observation = poisson_obs.sample(key, shape=(1,)) - # # vectorized sampling (generically with `vmap`) - # keys = jax.random.split(key, 10_000) - # sampled_observations = jax.vmap(poisson_obs.sample)(keys) + N = 10_000 + # vectorized sampling (standard way) + sampled_observations = poisson_obs.sample(key, shape=(N, 1)) + + # vectorized sampling (generically with `vmap`) + keys = jax.random.split(key, N) + + def sample_obs(k): + return poisson_obs.sample(k, shape=(1,)) + + sampled_observations = jax.vmap(sample_obs)(keys) diff --git a/pyproject.toml b/pyproject.toml index f1b1335..2ac5efe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "evermore" authors = [ - { name = "Peter Fackeldey", email = "peter.fackeldey@rwth-aachen.de" }, + { name = "Peter Fackeldey", email = 
"fackeldey.peter@gmail.com" }, ] description = "Differentiable (binned) likelihoods in JAX." license = "BSD-3-Clause" diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py index 943d173..ecd4211 100644 --- a/src/evermore/__init__.py +++ b/src/evermore/__init__.py @@ -13,25 +13,31 @@ __contact__ = "https://github.com/pfackeldey/evermore" __license__ = "BSD-3-Clause" __status__ = "Development" -__version__ = "0.3.2" +__version__ = "0.3.3" # expose public API __all__ = [ + "PT", + "AbstractParameter", "Modifier", "NormalParameter", # explicitly expose some classes "Parameter", + "V", "__version__", "effect", + "filter", "loss", "modifier", + "mutable", "parameter", "pdf", "sample", "staterror", "transform", + "tree", "util", "visualization", ] @@ -54,11 +60,17 @@ def __dir__(): ) from evermore.binned.modifier import Modifier # noqa: E402 from evermore.parameters import ( # noqa: E402 + filter, + mutable, parameter, sample, transform, + tree, ) from evermore.parameters.parameter import ( # noqa: E402 + AbstractParameter, NormalParameter, Parameter, + V, ) +from evermore.parameters.tree import PT # noqa: E402 diff --git a/src/evermore/binned/effect.py b/src/evermore/binned/effect.py index 443a1de..e74eaf8 100644 --- a/src/evermore/binned/effect.py +++ b/src/evermore/binned/effect.py @@ -2,20 +2,21 @@ import abc from collections.abc import Callable -from typing import Literal +from typing import Generic, Literal, TypeVar import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, Float, PyTree +from jaxtyping import Array, Float -from evermore.parameters.parameter import Parameter -from evermore.util import float_array +from evermore.parameters.parameter import AbstractParameter +from evermore.parameters.tree import PT +from evermore.util import maybe_float_array from evermore.visualization import SupportsTreescope __all__ = [ + "AbstractEffect", "AsymmetricExponential", - "Effect", "Identity", "Linear", "VerticalTemplateMorphing", @@ -26,11 +27,14 @@ def __dir__(): return __all__ -class OffsetAndScale(eqx.Module): - offset: Float[Array, "..."] = eqx.field(converter=float_array, default=0.0) # noqa: UP037 - scale: Float[Array, "..."] = eqx.field(converter=float_array, default=1.0) # noqa: UP037 +H = TypeVar("H", bound=Float[Array, "..."]) - def broadcast(self) -> OffsetAndScale: + +class OffsetAndScale(eqx.Module, Generic[H]): + offset: H = eqx.field(converter=maybe_float_array, default=0.0) + scale: H = eqx.field(converter=maybe_float_array, default=1.0) + + def broadcast(self) -> OffsetAndScale[H]: shape = jnp.broadcast_shapes(self.offset.shape, self.scale.shape) return type(self)( offset=jnp.broadcast_to(self.offset, shape), @@ -38,79 +42,61 @@ def broadcast(self) -> OffsetAndScale: ) -class Effect(eqx.Module, SupportsTreescope): +class AbstractEffect(eqx.Module, Generic[H], SupportsTreescope): @abc.abstractmethod - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: ... + def __call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: ... 
-class Identity(Effect): +class Identity(AbstractEffect[H]): @jax.named_scope("evm.effect.Identity") - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: - return OffsetAndScale(offset=0.0, scale=1.0) # type: ignore[arg-type] - - -class Lambda(Effect): - fun: Callable[ - [PyTree[Parameter], Float[Array, "..."]], OffsetAndScale | Float[Array, "..."] # noqa: UP037 - ] + def __call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: + return OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + + +class Lambda(AbstractEffect[H]): + fun: Callable[[PT, H], OffsetAndScale[H] | H] normalize_by: Literal["offset", "scale"] | None = eqx.field( static=True, default=None ) @jax.named_scope("evm.effect.Lambda") - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: + def __call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: + assert isinstance(parameter, AbstractParameter) res = self.fun(parameter, hist) if isinstance(res, OffsetAndScale) and self.normalize_by is None: return res if self.normalize_by == "offset": - return OffsetAndScale(offset=(res - hist), scale=1.0) # type: ignore[arg-type] + return OffsetAndScale( + offset=(res - hist), scale=jnp.ones_like(hist) + ).broadcast() if self.normalize_by == "scale": - return OffsetAndScale(offset=0.0, scale=(res / hist)) # type: ignore[arg-type] + return OffsetAndScale( + offset=jnp.zeros_like(hist), scale=(res / hist) + ).broadcast() msg = f"Unknown normalization type '{self.normalize_by}' for '{res}'" raise ValueError(msg) -class Linear(Effect): - offset: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 - slope: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 +class Linear(AbstractEffect[H]): + offset: H = eqx.field(converter=maybe_float_array) + slope: H = eqx.field(converter=maybe_float_array) @jax.named_scope("evm.effect.Linear") - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: - assert isinstance(parameter, Parameter) + def __call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: + assert isinstance(parameter, AbstractParameter) sf = parameter.value * self.slope + self.offset - return OffsetAndScale(offset=0.0, scale=sf) # type: ignore[arg-type] - - -DEFAULT_EFFECT: Linear = Linear(offset=0.0, slope=1.0) # type: ignore[arg-type] + return OffsetAndScale(offset=jnp.zeros_like(hist), scale=sf).broadcast() -class VerticalTemplateMorphing(Effect): +class VerticalTemplateMorphing(AbstractEffect[H]): # + 1 sigma - up_template: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 + up_template: H = eqx.field(converter=maybe_float_array) # - 1 sigma - down_template: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 + down_template: H = eqx.field(converter=maybe_float_array) - def vshift( - self, - x: Float[Array, "..."], # noqa: UP037 - hist: Float[Array, "..."], # noqa: UP037 - ) -> Float[Array, "..."]: # noqa: UP037 + def vshift(self, x: H, hist: H) -> H: dx_sum = self.up_template + self.down_template - 2 * hist dx_diff = self.up_template - self.down_template @@ -129,21 +115,17 @@ def vshift( ) @jax.named_scope("evm.effect.VerticalTemplateMorphing") - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: - assert isinstance(parameter, Parameter) + def 
__call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: + assert isinstance(parameter, AbstractParameter) offset = self.vshift(parameter.value, hist=hist) - return OffsetAndScale(offset=offset, scale=jnp.ones_like(hist)) # type: ignore[arg-type] + return OffsetAndScale(offset=offset, scale=jnp.ones_like(hist)).broadcast() -class AsymmetricExponential(Effect): - up: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 - down: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 +class AsymmetricExponential(AbstractEffect[H]): + up: H = eqx.field(converter=maybe_float_array) + down: H = eqx.field(converter=maybe_float_array) - def interpolate(self, x: Float[Array, "..."]) -> Float[Array, "..."]: # noqa: UP037 + def interpolate(self, x: H) -> H: # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129 lo, hi = self.down, self.up hi = jnp.log(hi) @@ -159,11 +141,9 @@ def interpolate(self, x: Float[Array, "..."]) -> Float[Array, "..."]: # noqa: U ) @jax.named_scope("evm.effect.AsymmetricExponential") - def __call__( - self, - parameter: PyTree[Parameter], - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: - assert isinstance(parameter, Parameter) + def __call__(self, parameter: PT, hist: H) -> OffsetAndScale[H]: + assert isinstance(parameter, AbstractParameter) interp = self.interpolate(parameter.value) - return OffsetAndScale(offset=0.0, scale=jnp.exp(parameter.value * interp)) # type: ignore[arg-type] + return OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.exp(parameter.value * interp) + ).broadcast() diff --git a/src/evermore/binned/modifier.py b/src/evermore/binned/modifier.py index 7415d50..1eedad5 100644 --- a/src/evermore/binned/modifier.py +++ b/src/evermore/binned/modifier.py @@ -2,20 +2,20 @@ import abc from collections.abc import Callable, Iterable, Iterator -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Generic import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float, PyTree +from jaxtyping import Array, Bool -from evermore.binned.effect import DEFAULT_EFFECT, OffsetAndScale -from evermore.parameters.parameter import Parameter +from evermore.binned.effect import H, OffsetAndScale +from evermore.parameters.tree import PT from evermore.util import tree_stack from evermore.visualization import SupportsTreescope if TYPE_CHECKING: - from evermore.binned.effect import Effect + from evermore.binned.effect import AbstractEffect __all__ = [ "BooleanMask", @@ -33,43 +33,30 @@ def __dir__(): return __all__ -@runtime_checkable -class ModifierLike(Protocol): - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: ... # noqa: UP037 - def __call__(self, hist: Float[Array, "..."]) -> Float[Array, "..."]: ... # noqa: UP037 - def __matmul__(self, other: ModifierLike) -> Compose: ... - - -class AbstractModifier(eqx.Module): +class AbstractModifier(eqx.Module, Generic[PT], SupportsTreescope): @abc.abstractmethod - def offset_and_scale( - self: ModifierLike, - hist: Float[Array, "..."], # noqa: UP037 - ) -> OffsetAndScale: ... + def offset_and_scale(self: ModifierBase, hist: H) -> OffsetAndScale[H]: ... @abc.abstractmethod - def __call__( - self: ModifierLike, - hist: Float[Array, "..."], # noqa: UP037 - ) -> Float[Array, "..."]: ... # noqa: UP037 + def __call__(self: ModifierBase, hist: H) -> H: ... 
    @abc.abstractmethod
-    def __matmul__(self: ModifierLike, other: ModifierLike) -> Compose: ...
+    def __matmul__(self: ModifierBase, other: ModifierBase) -> Compose: ...


-class ApplyFn(AbstractModifier):
+class ApplyFn(AbstractModifier[PT]):
     @jax.named_scope("evm.modifier.ApplyFn")
-    def __call__(self: ModifierLike, hist: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
+    def __call__(self: ModifierBase, hist: H) -> H:
         os = self.offset_and_scale(hist=hist)
         return os.scale * (hist + os.offset)


-class MatMulCompose(AbstractModifier):
-    def __matmul__(self: ModifierLike, other: ModifierLike) -> Compose:
+class MatMulCompose(AbstractModifier[PT]):
+    def __matmul__(self: ModifierBase, other: ModifierBase) -> Compose:
         return Compose(self, other)


-class ModifierBase(ApplyFn, MatMulCompose, SupportsTreescope):
+class ModifierBase(ApplyFn[PT], MatMulCompose[PT]):
     """
     This serves as a base class for all modifiers.
     It automatically implements the __call__ method to apply the scale factors to the hist array
@@ -83,17 +70,19 @@ class ModifierBase(ApplyFn, MatMulCompose, SupportsTreescope):

         import equinox as eqx
+        import jax
         import jax.numpy as jnp
-        from jaxtyping import Array
+        from jaxtyping import Array, Float

         import evermore as evm


-        class Clip(evm.modifier.ModifierBase):
-            modifier: evm.custom_types.ModifierLike
+        class Clip(evm.modifier.ModifierBase[evm.PT]):
+            modifier: evm.modifier.ModifierBase[evm.PT]
             min_sf: float = eqx.field(static=True)
             max_sf: float = eqx.field(static=True)

-            def offset_and_scale(self, hist: Array) -> evm.custom_types.OffsetAndScale:
+            def offset_and_scale(
+                self, hist: Float[Array, "..."]
+            ) -> evm.effect.OffsetAndScale:
                 os = self.modifier.offset_and_scale(hist)
                 return jax.tree.map(lambda x: jnp.clip(x, self.min_sf, self.max_sf), os)

@@ -112,7 +101,7 @@ def offset_and_scale(self, hist: Array) -> evm.custom_types.OffsetAndScale:
     """


-class Modifier(ModifierBase):
+class Modifier(ModifierBase[PT]):
     """
     Create a new modifier for a given parameter and penalty.
@@ -162,20 +151,18 @@ class Modifier(ModifierBase):
         modify = norm.morphing(up_template=up_template, down_template=down_template)
     """

-    parameter: PyTree[Parameter]
-    effect: Effect
+    parameter: PT
+    effect: AbstractEffect

-    def __init__(
-        self, parameter: PyTree[Parameter], effect: Effect = DEFAULT_EFFECT
-    ) -> None:
+    def __init__(self, parameter: PT, effect: AbstractEffect[H]) -> None:
         self.parameter = parameter
         self.effect = effect

-    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
+    def offset_and_scale(self, hist: H) -> OffsetAndScale[H]:
         return self.effect(parameter=self.parameter, hist=hist)


-class Where(ModifierBase):
+class Where(ModifierBase[PT]):
     """
     Combine two modifiers based on a condition.
@@ -211,10 +198,10 @@ class Where(ModifierBase):
     """

     condition: Bool[Array, ...]
-    modifier_true: ModifierLike
-    modifier_false: ModifierLike
+    modifier_true: ModifierBase[PT]
+    modifier_false: ModifierBase[PT]

-    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
+    def offset_and_scale(self, hist: H) -> OffsetAndScale[H]:
         true_os = self.modifier_true.offset_and_scale(hist)
         false_os = self.modifier_false.offset_and_scale(hist)
@@ -227,7 +214,7 @@ def _where(
         return jax.tree.map(_where, true_os, false_os)


-class BooleanMask(ModifierBase):
+class BooleanMask(ModifierBase[PT]):
     """
     Mask a modifier for specific bins.
@@ -255,9 +242,9 @@ class BooleanMask(ModifierBase):
     """

     mask: Bool[Array, ...]
- modifier: ModifierLike + modifier: ModifierBase[PT] - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 + def offset_and_scale(self, hist: H) -> OffsetAndScale[H]: os = self.modifier.offset_and_scale(hist) def _mask( @@ -266,13 +253,13 @@ def _mask( ) -> Bool[Array, ...]: return jnp.where(self.mask, true, false) - return OffsetAndScale( + return OffsetAndScale[H]( offset=_mask(os.offset, 0.0), scale=_mask(os.scale, 1.0), - ) + ).broadcast() -class Transform(ModifierBase): +class Transform(ModifierBase[PT]): """ Transform the scale factors of a modifier. @@ -302,32 +289,36 @@ class Transform(ModifierBase): """ transform_fn: Callable = eqx.field(static=True) - modifier: ModifierLike + modifier: ModifierBase[PT] - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 + def offset_and_scale(self, hist: H) -> OffsetAndScale[H]: os = self.modifier.offset_and_scale(hist) return jax.tree.map(self.transform_fn, os) -class TransformOffset(ModifierBase): +class TransformOffset(ModifierBase[PT]): transform_fn: Callable = eqx.field(static=True) - modifier: ModifierLike + modifier: ModifierBase[PT] - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 + def offset_and_scale(self, hist: H) -> OffsetAndScale[H]: os = self.modifier.offset_and_scale(hist) - return OffsetAndScale(offset=self.transform_fn(os.offset), scale=os.scale) + return OffsetAndScale( + offset=self.transform_fn(os.offset), scale=os.scale + ).broadcast() -class TransformScale(ModifierBase): +class TransformScale(ModifierBase[PT]): transform_fn: Callable = eqx.field(static=True) - modifier: ModifierLike + modifier: ModifierBase[PT] - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 + def offset_and_scale(self, hist: H) -> OffsetAndScale[H]: os = self.modifier.offset_and_scale(hist) - return OffsetAndScale(offset=os.offset, scale=self.transform_fn(os.scale)) + return OffsetAndScale( + offset=os.offset, scale=self.transform_fn(os.scale) + ).broadcast() -class Compose(ModifierBase): +class Compose(ModifierBase[PT]): """ Composition of multiple modifiers, in order to correctly apply them *together*. It behaves like a single modifier, but it is composed of multiple modifiers; it can be arbitrarily nested. @@ -376,9 +367,9 @@ class Compose(ModifierBase): eqx.filter_jit(composition)(hist) """ - modifiers: list[ModifierLike] + modifiers: list[ModifierBase[PT]] - def __init__(self, *modifiers: ModifierLike) -> None: + def __init__(self, *modifiers: ModifierBase[PT]) -> None: if not modifiers: msg = "At least one modifier must be provided to Compose." 
raise ValueError(msg) @@ -391,7 +382,7 @@ def __init__(self, *modifiers: ModifierLike) -> None: def __len__(self) -> int: return len(self.modifiers) - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 + def offset_and_scale(self, hist: H) -> OffsetAndScale[H]: from collections import defaultdict # initial scale and offset @@ -431,19 +422,19 @@ def calc_sf(_hist, _dynamic_stack, _static_stack): scale *= jnp.prod(os.scale, axis=0) offset += jnp.sum(os.offset, axis=0) - return OffsetAndScale(offset=offset, scale=scale) + return OffsetAndScale(offset=offset, scale=scale).broadcast() -def unroll(modifiers: Iterable) -> Iterator[ModifierLike]: +def unroll(modifiers: Iterable[ModifierBase[PT]]) -> Iterator[ModifierBase[PT]]: # Helper to recursively flatten nested Compose instances into a single list for mod in modifiers: if isinstance(mod, Compose): # recursively yield from the modifiers of the Compose instance yield from unroll(mod.modifiers) - elif isinstance(mod, ModifierLike): - # yield the modifier if it is a ModifierLike instance + elif isinstance(mod, ModifierBase): + # yield the modifier if it is a ModifierBase instance yield mod else: - # raise an error if the modifier is not a ModifierLike instance - msg = f"Modifier {mod} is not a ModifierLike instance." + # raise an error if the modifier is not a ModifierBase instance + msg = f"Modifier {mod} is not a ModifierBase instance." # type: ignore[unreachable] raise TypeError(msg) diff --git a/src/evermore/binned/staterror.py b/src/evermore/binned/staterror.py index 76e21fe..809572a 100644 --- a/src/evermore/binned/staterror.py +++ b/src/evermore/binned/staterror.py @@ -1,13 +1,16 @@ from __future__ import annotations +from typing import TypeVar + import jax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float, Scalar +from jaxtyping import Array, Bool, Float, Scalar, Shaped from evermore.binned.effect import Identity, OffsetAndScale from evermore.binned.modifier import Modifier, ModifierBase, Where from evermore.parameters.parameter import NormalParameter -from evermore.util import float_array +from evermore.parameters.tree import PT +from evermore.util import maybe_float_array __all__ = [ "StatErrors", @@ -18,7 +21,10 @@ def __dir__(): return __all__ -class StatErrors(ModifierBase): +N = TypeVar("N", bound=Shaped[Array, "..."]) + + +class StatErrors(ModifierBase[PT]): """ Create staterror (barlow-beeston) parameters. 
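For orientation between the hunks below: `StatErrors` is constructed from a histogram and its per-bin variances, and is itself a modifier. A minimal usage sketch (the numbers are made up):

```python
import jax.numpy as jnp

import evermore as evm

# a histogram and its per-bin (MC) variances
hist = jnp.array([10.0, 20.0, 30.0])
variance = jnp.array([4.0, 9.0, 16.0])

# builds a NormalParameter with one value per bin, scaling each bin
# by its relative statistical error
staterrors = evm.staterror.StatErrors(hist, variance)

# apply the Barlow-Beeston-style modifier; at the initial parameter
# values (zeros) this returns the histogram unchanged
staterrors(hist)
```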
@@ -58,18 +64,18 @@ class StatErrors(ModifierBase): """ eps: Float[Scalar, ""] - n_entries: Float[Array, "..."] # noqa: UP037 - non_empty_mask: Bool[Array, " nbins"] - relative_error: Float[Array, "..."] # noqa: UP037 - parameter: NormalParameter + n_entries: Float[N] + non_empty_mask: Bool[N] + relative_error: Float[N] + parameter: NormalParameter[Float[N]] def __init__( self, - hist: Float[Array, "..."], # noqa: UP037 - variance: Float[Array, "..."], # noqa: UP037 + hist: Float[N], + variance: Float[N], ): # make sure they are of dtype float - hist, variance = jax.tree.map(float_array, (hist, variance)) + hist, variance = jax.tree.map(maybe_float_array, (hist, variance)) self.eps = jnp.finfo(variance.dtype).eps @@ -85,10 +91,10 @@ def __init__( / jnp.sqrt(self.n_entries + jnp.where(self.non_empty_mask, 0.0, self.eps)), 1.0, ) - self.parameter = NormalParameter(value=jnp.zeros_like(self.n_entries)) + self.parameter = NormalParameter(jnp.zeros_like(self.n_entries)) - def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: # noqa: UP037 - modifier = Where( + def offset_and_scale(self, hist: Float[N]) -> OffsetAndScale[Float[N]]: + modifier: Where[PT] = Where( self.non_empty_mask, self.parameter.scale(slope=self.relative_error, offset=1.0), Modifier(parameter=self.parameter, effect=Identity()), diff --git a/src/evermore/loss.py b/src/evermore/loss.py index 6ebb8e4..def3a55 100644 --- a/src/evermore/loss.py +++ b/src/evermore/loss.py @@ -5,14 +5,14 @@ import jax.numpy as jnp from jaxtyping import Array, Float +from evermore.parameters.filter import is_parameter from evermore.parameters.parameter import ( - Parameter, - _params_map, - _ParamsTree, - is_parameter, + AbstractParameter, + V, replace_value, ) -from evermore.pdf import PDF, ImplementsFromUnitNormalConversion, Normal +from evermore.parameters.tree import PT, combine, only, partition, pure +from evermore.pdf import AbstractPDF, ImplementsFromUnitNormalConversion, Normal __all__ = [ "compute_covariance", @@ -24,29 +24,29 @@ def __dir__(): return __all__ -def get_log_probs(params: _ParamsTree) -> _ParamsTree: +def get_log_probs(tree: PT) -> PT: """ Compute the log probabilities for all parameters. - This function iterates over all parameters in the provided PyTree params, + This function iterates over all parameters in the provided PyTree tree, applies their associated prior distributions (if any), and computes the log probability for each parameter. If a parameter does not have a prior distribution, a default log probability of 0.0 is returned. Args: - params (PyTree): A PyTree containing parameters to compute log probabilities for. + tree (PyTree): A PyTree containing parameters to compute log probabilities for. Returns: PyTree: A PyTree with the same structure as the input, where each parameter is replaced by its corresponding log probability. 
""" - def _constraint(param: Parameter) -> Float[Array, "..."]: - prior: PDF | None = param.prior + def _constraint(param: AbstractParameter[V]) -> V: + prior: AbstractPDF | None = param.prior # unconstrained case is easy: if prior is None: - return jnp.array([0.0]) + return jnp.zeros_like(param.value) # all constrained parameters are 'moving' on a 'unit_normal' distribution (mean=0, width=1), i.e.: # - param.value=0: no shift, no constrain @@ -66,9 +66,9 @@ def _constraint(param: Parameter) -> Float[Array, "..."]: # this is the fast-path x = prior.__evermore_from_unit_normal__(param.value) else: - # this is a general implementation to translate from a unit normal to any target PDF + # this is a general implementation to translate from a unit normal to any target AbstractPDF # the only requirement is that the target pdf implements `.inv_cdf`. - unit_normal = Normal( + unit_normal: Normal[V] = Normal( mean=jnp.zeros_like(param.value), width=jnp.ones_like(param.value) ) cdf = unit_normal.cdf(param.value) @@ -76,12 +76,12 @@ def _constraint(param: Parameter) -> Float[Array, "..."]: return prior.log_prob(x) # constraints from pdfs - return _params_map(_constraint, params) + return jax.tree.map(_constraint, only(tree, is_parameter), is_leaf=is_parameter) def compute_covariance( loss_fn: tp.Callable, - params: _ParamsTree, + tree: PT, ) -> Float[Array, "nparams nparams"]: r""" Computes the covariance matrix of the parameters under the Laplace approximation, @@ -90,9 +90,9 @@ def compute_covariance( See ``examples/toy_generation.py`` for an example usage. Args: - loss_fn (Callable): The loss function. Should accept (params) as arguments. + loss_fn (Callable): The loss function. Should accept (tree) as arguments. All other arguments have to be "partial'd" into the loss function. - params (_ParamsTree): A PyTree of parameters. + tree (PT): A PyTree of parameters. Returns: Float[Array, "nparams nparams"]: The covariance matrix of the parameters. @@ -122,16 +122,26 @@ def loss_fn(params): # (2, 2) """ # first, compute the hessian at the current point - values = _params_map(lambda p: p.value, params) + values = pure(tree) flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values) def _flat_loss(flat_values: Float[Array, "..."]) -> Float[Array, ""]: param_values = unravel_fn(flat_values) - _params = jax.tree.map( + # update the parameters with the new values + # and call the loss function + # 1. partition the tree of parameters and other things + # 2. update the parameters with the new values + # 3. combine the updated parameters with the rest of the tree + # 4. 
call the loss function with the updated tree + params, rest = partition(tree, filter=is_parameter) + + updated_params = jax.tree.map( replace_value, params, param_values, is_leaf=is_parameter ) - return loss_fn(_params) + + updated_tree = combine(updated_params, rest) + return loss_fn(updated_tree) # calculate hessian hessian = jax.hessian(_flat_loss)(flat_values) diff --git a/src/evermore/parameters/filter.py b/src/evermore/parameters/filter.py new file mode 100644 index 0000000..9cecbd8 --- /dev/null +++ b/src/evermore/parameters/filter.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import abc +import dataclasses +from collections.abc import Hashable +from typing import Any + +from evermore.parameters.parameter import AbstractParameter, ValueAttr + +__all__ = [ + "Filter", + "HasName", + "HasTags", + "IsFrozen", + "IsParam", + "OfType", + "ParameterFilter", + "is_frozen", + "is_not_frozen", + "is_parameter", + "is_value", +] + + +@dataclasses.dataclass(frozen=True) +class Filter(abc.ABC): + """ + Base class for filters that can be used to filter PyTrees with the evm.tree.* module. + """ + + @abc.abstractmethod + def __call__(self, x: Any) -> bool: ... + + @abc.abstractmethod + def is_leaf(self, x: Any) -> bool: ... + + +@dataclasses.dataclass(frozen=True) +class Not(Filter): + """ + A filter class that inverts the result of another filter. + + Attributes: + filter (Filter): The filter whose result will be negated. + """ + + filter: Filter + + def __call__(self, x: Any) -> bool: + """ + Makes the instance callable and returns the negation of the filter result for the given input. + + Args: + x (Any): The input to be evaluated by the filter. + + Returns: + bool: True if the filter returns False for the input, otherwise False. + """ + return not self.filter(x) + + def is_leaf(self, x: Any) -> bool: + return self.filter.is_leaf(x) + + +@dataclasses.dataclass(frozen=True) +class OfType(Filter): + """ + A filter that checks if a value is of a specified type. + + Attributes: + type (type): The type to check against. + """ + + type: type + + def __call__(self, x: Any): + """ + Check if the input object is an instance of the specified type. + + Args: + x (Any): The object to check. + + Returns: + bool: True if x is an instance of self.type, False otherwise. + """ + return isinstance(x, self.type) + + def is_leaf(self, x: Any) -> bool: + return self(x) + + +@dataclasses.dataclass(frozen=True) +class ParameterFilter(Filter): + def is_leaf(self, x: Any) -> bool: + return is_parameter(x) + + +@dataclasses.dataclass(frozen=True) +class IsParam(ParameterFilter): + """ + A parameter filter that matches a specific parameter instance. + + Attributes: + param (AbstractParameter): The parameter instance to match. + """ + + param: AbstractParameter + + def __post_init__(self): + if not isinstance(self.param, AbstractParameter): + msg = f"Expected an AbstractParameter, got {type(self.param).__name__}" # type: ignore[unreachable] + raise TypeError(msg) + + def __call__(self, x: AbstractParameter) -> bool: + """ + Checks if the given parameter is the same instance as the stored parameter. + + Args: + x (AbstractParameter): The parameter to compare. + + Returns: + bool: True if `x` is the same instance as `self.param`, False otherwise. + """ + return x is self.param + + +@dataclasses.dataclass(frozen=True) +class HasName(ParameterFilter): + """ + A filter that matches parameters by their name. + + Attributes: + name (str): The name to match against the parameter's name. 
+ """ + + name: str + + def __call__(self, x: AbstractParameter): + """ + Compares the name attribute of this object with that of the given AbstractParameter. + + Args: + x (AbstractParameter): The parameter to compare against. + + Returns: + bool: True if the names are equal, False otherwise. + """ + return self.name == x.name + + +@dataclasses.dataclass(frozen=True) +class HasTags(ParameterFilter): + """ + A filter that checks if a parameter has all specified tags. + + Attributes: + tags (frozenset[Hashable]): The set of tags to check for. + """ + + tags: frozenset[Hashable] + + def __call__(self, x: AbstractParameter) -> bool: + """ + Determines if the tags of this filter are a subset of the tags of the given AbstractParameter. + + Args: + x (AbstractParameter): The parameter to check against. + + Returns: + bool: True if all tags in this filter are present in x.tags, False otherwise. + """ + return self.tags <= x.tags + + +@dataclasses.dataclass(frozen=True) +class IsFrozen(ParameterFilter): + """ + A filter that checks if a parameter is frozen. + """ + + def __call__(self, x: AbstractParameter) -> bool: + """ + Checks if the given AbstractParameter instance is frozen. + + Args: + x (AbstractParameter): The parameter to check. + + Returns: + bool: True if the parameter is frozen, False otherwise. + """ + return x.frozen + + +is_parameter = OfType(type=AbstractParameter) +""" +A filter that checks if a value is an instance of AbstractParameter. + +Example: + + .. code-block:: python + + import evermore as evm + + params = { + "a": evm.Parameter(value=1.0), + "b": 42, + "c": evm.Parameter(value=2.0), + } + + filtered_params = evm.tree.only(params, filter=evm.filter.is_parameter) +""" + +is_value = OfType(type=ValueAttr) +""" +A filter that checks if a value is an instance of ValueAttr. + +Example: + + .. code-block:: python + + import evermore as evm + + params = { + "a": evm.Parameter(value=1.0), + "b": 42, + "c": evm.Parameter(value=2.0), + } + + filtered_params = evm.tree.only(params, filter=evm.filter.is_value) +""" + +is_frozen = IsFrozen() +""" +A filter that checks if a parameter is frozen. + +Example: + + .. code-block:: python + + import evermore as evm + + params = { + "a": evm.Parameter(value=1.0, frozen=True), + "b": 42, + "c": evm.Parameter(value=2.0), + } + + filtered_params = evm.tree.only(params, filter=evm.filter.is_frozen) +""" + +is_not_frozen = Not(is_frozen) +""" +A filter that checks if a parameter is not frozen. + +Example: + + .. code-block:: python + + import evermore as evm + + params = { + "a": evm.Parameter(value=1.0, frozen=True), + "b": 42, + "c": evm.Parameter(value=2.0), + } + + filtered_params = evm.tree.only(params, filter=evm.filter.is_not_frozen) +""" diff --git a/src/evermore/parameters/mutable.py b/src/evermore/parameters/mutable.py new file mode 100644 index 0000000..44b6afa --- /dev/null +++ b/src/evermore/parameters/mutable.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import typing as tp + +import equinox as eqx +import jax +from jax._src.state.types import AbstractRef # noqa: PLC2701 +from jax.experimental import MutableArray, mutable_array +from jaxtyping import PyTree + + +def _is_mutable_array(x) -> tp.TypeGuard[MutableArray]: + return isinstance(x, jax.Array | AbstractRef | MutableArray) and isinstance( + jax.typeof(x), AbstractRef | MutableArray + ) + + +def to_refs(vals: PyTree) -> PyTree: + """ + Recursively converts all mutable array-like elements within a parameter tree to their mutable counterparts. 
+ + Args: + vals (PyTree): The parameter tree to process. + + Returns: + PyTree: A new parameter tree where all mutable array-like elements have been converted using `mutable_array`. + """ + return jax.tree.map(lambda x: mutable_array(x) if eqx.is_array(x) else x, vals) + + +def to_arrays(vals: PyTree) -> PyTree: + """ + Converts all mutable arrays within the given parameter tree to their immutable counterparts. + + This function traverses the input parameter tree and, for each element, checks if it is a mutable array. + If so, it creates an immutable copy of the array; otherwise, the element is left unchanged. + + Args: + vals (PyTree): The parameter tree potentially containing mutable arrays. + + Returns: + PyTree: A new parameter tree where all mutable arrays have been replaced with immutable copies. + """ + return jax.tree.map(lambda x: x[...] if _is_mutable_array(x) else x, vals) diff --git a/src/evermore/parameters/parameter.py b/src/evermore/parameters/parameter.py index e86cf46..bf0a0de 100644 --- a/src/evermore/parameters/parameter.py +++ b/src/evermore/parameters/parameter.py @@ -1,28 +1,30 @@ from __future__ import annotations -from functools import partial -from typing import TYPE_CHECKING, Any, TypeVar +from collections.abc import Hashable, Iterator +from typing import ( + TYPE_CHECKING, + Generic, + TypeVar, +) import equinox as eqx -import jax -from jaxtyping import Array, ArrayLike, Float, PyTree +from jaxtyping import Array, ArrayLike, Float -from evermore.pdf import PDF, Normal -from evermore.util import _missing, filter_tree_map, float_array +from evermore.util import _missing, maybe_float_array from evermore.visualization import SupportsTreescope if TYPE_CHECKING: - from evermore.binned.modifier import Modifier - from evermore.parameters.transform import ParameterTransformation + from evermore.binned.modifier import H, Modifier + from evermore.parameters.transform import AbstractParameterTransformation + from evermore.pdf import AbstractPDF, Normal + __all__ = [ + "AbstractParameter", "NormalParameter", "Parameter", "correlate", - "is_parameter", - "partition", "replace_value", - "value_filter_spec", ] @@ -30,46 +32,119 @@ def __dir__(): return __all__ -class Parameter(eqx.Module, SupportsTreescope): - """ - A general Parameter class for defining the parameters of a statistical model. +def _numeric_methods(name): + bname = f"__{name}__" - Attributes: - value (Array): The actual value of the parameter. - name (str | None): An optional name for the parameter. - lower (Array | None): The lower boundary of the parameter. - upper (Array | None): The upper boundary of the parameter. - prior (PDF | None): The prior distribution of the parameter. - frozen (bool): Indicates if the parameter is frozen during optimization. - transform (ParameterTransformation | None): An optional transformation applied to the parameter. + def _binary(self, other): + other_val = other.value if isinstance(other, AbstractParameter) else other + return getattr(self.value, bname)(other_val) - Usage: + rname = f"__r{name}__" - .. 
code-block:: python + def _reflected(self, other): + other_val = other.value if isinstance(other, AbstractParameter) else other + return getattr(self.value, rname)(other_val) - import evermore as evm + iname = f"__i{name}__" - simple_param = evm.Parameter(value=1.0) - bounded_param = evm.Parameter(value=1.0, lower=0.0, upper=2.0) - constrained_parameter = evm.Parameter( - value=1.0, prior=evm.pdf.Normal(mean=1.0, width=0.1) - ) - frozen_parameter = evm.Parameter(value=1.0, frozen=True) - """ + def _inplace(self, other): + other_val = other.value if isinstance(other, AbstractParameter) else other + return getattr(self.value, iname)(other_val) - value: Float[Array, "..."] = eqx.field(converter=float_array, default=0.0) # noqa: UP037 - name: str | None = eqx.field(default=None) - lower: Float[Array, "..."] | None = eqx.field(default=None) # noqa: UP037 - upper: Float[Array, "..."] | None = eqx.field(default=None) # noqa: UP037 - prior: PDF | None = eqx.field(default=None) - frozen: bool = eqx.field(default=False) - transform: ParameterTransformation | None = eqx.field(default=None) + _binary.__name__ = bname + _reflected.__name__ = rname + _inplace.__name__ = iname + return _binary, _reflected, _inplace + + +def _unary_method(name): + uname = f"__{name}__" + + def _unary(self): + return getattr(self.value, uname)() + + _unary.__name__ = uname + return _unary + + +V = TypeVar("V", bound=Float[Array, "..."]) + + +class ValueAttr(eqx.Module, Generic[V], SupportsTreescope): + value: V - def __check_init__(self): - # runtime check to be sure - if self.prior is not None and not isinstance(self.prior, PDF): - msg = f"Prior must be a PDF object for a constrained Parameter (or 'None' for an unconstrained one), got {self.prior=} ({type(self.prior)=})" # type: ignore[unreachable] - raise ValueError(msg) + +def to_value(x: ArrayLike) -> ValueAttr: + if isinstance(x, ValueAttr): + return x + return ValueAttr(value=maybe_float_array(x, passthrough=False)) + + +class AbstractParameter( + eqx.Module, + Generic[V], + SupportsTreescope, +): + raw_value: eqx.AbstractVar[V] + name: eqx.AbstractVar[str | None] + lower: eqx.AbstractVar[V | None] + upper: eqx.AbstractVar[V | None] + prior: eqx.AbstractVar[AbstractPDF | None] + frozen: eqx.AbstractVar[bool] + transform: eqx.AbstractVar[AbstractParameterTransformation | None] + tags: eqx.AbstractVar[frozenset[Hashable]] + + @property + def value(self) -> V: + """ + Returns the value of the parameter. + + This property is used to access the actual value of the parameter, which can be a JAX array or any other type. + It is defined as a property to allow for lazy evaluation and potential transformations. 
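To illustrate the new `raw_value`/`value` split, a small sketch (`ValueAttr` is imported from its module path here, as it is not part of `__all__`):

```python3
import jax.numpy as jnp
import evermore as evm
from evermore.parameters.parameter import ValueAttr

p = evm.Parameter(value=2.0)

# the array now lives one level deeper, wrapped in a ValueAttr node ...
assert isinstance(p.raw_value, ValueAttr)
# ... while `.value` stays the user-facing accessor
assert p.value == jnp.asarray(2.0)
```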
+ """ + return self.raw_value.value + + def __jax_array__(self): + return self.value + + def __len__(self) -> int: + return len(self.value) + + def __iter__(self) -> Iterator: + return iter(self.value) + + def __contains__(self, item) -> bool: + return item in self.value + + __add__, __radd__, __iadd__ = _numeric_methods("add") + __sub__, __rsub__, __isub__ = _numeric_methods("sub") + __mul__, __rmul__, __imul__ = _numeric_methods("mul") + __matmul__, __rmatmul__, __imatmul__ = _numeric_methods("matmul") + __truediv__, __rtruediv__, __itruediv__ = _numeric_methods("truediv") + __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods("floordiv") + __mod__, __rmod__, __imod__ = _numeric_methods("mod") + __divmod__, __rdivmod__, __idivmod__ = _numeric_methods("divmod") + __pow__, __rpow__, __ipow__ = _numeric_methods("pow") + __lshift__, __rlshift__, __ilshift__ = _numeric_methods("lshift") + __rshift__, __rrshift__, __irshift__ = _numeric_methods("rshift") + __and__, __rand__, __iand__ = _numeric_methods("and") + __xor__, __rxor__, __ixor__ = _numeric_methods("xor") + __or__, __ror__, __ior__ = _numeric_methods("or") + + __neg__ = _unary_method("neg") + __pos__ = _unary_method("pos") + __abs__ = _unary_method("abs") + __invert__ = _unary_method("invert") + __complex__ = _unary_method("complex") + __int__ = _unary_method("int") + __float__ = _unary_method("float") + __index__ = _unary_method("index") + __trunc__ = _unary_method("trunc") + __floor__ = _unary_method("floor") + __ceil__ = _unary_method("ceil") + + def __round__(self, ndigits: int) -> V: + return self.value.__round__(ndigits) def scale(self, slope: ArrayLike = 1.0, offset: ArrayLike = 0.0) -> Modifier: """ @@ -87,11 +162,84 @@ def scale(self, slope: ArrayLike = 1.0, offset: ArrayLike = 0.0) -> Modifier: return Modifier( parameter=self, - effect=Linear(slope=slope, offset=offset), # type: ignore[arg-type] + effect=Linear(slope=slope, offset=offset), ) -class NormalParameter(Parameter): +class Parameter(AbstractParameter[V]): + """ + A general Parameter class for defining the parameters of a statistical model. + + Attributes: + value (V): The actual value of the parameter. + name (str | None): An optional name for the parameter. + lower (V | None): The lower boundary of the parameter. + upper (V | None): The upper boundary of the parameter. + prior (AbstractPDF | None): The prior distribution of the parameter. + frozen (bool): Indicates if the parameter is frozen during optimization. + transform (AbstractParameterTransformation | None): An optional transformation applied to the parameter. + tags (frozenset[Hashable]): A set of tags associated with the parameter for additional metadata. + + Usage: + + .. 
code-block:: python + + import evermore as evm + + simple_param = evm.Parameter(value=1.0) + bounded_param = evm.Parameter(value=1.0, lower=0.0, upper=2.0) + constrained_parameter = evm.Parameter( + value=1.0, prior=evm.pdf.Normal(mean=1.0, width=0.1) + ) + frozen_parameter = evm.Parameter(value=1.0, frozen=True) + """ + + raw_value: V + name: str | None + lower: V | None + upper: V | None + prior: AbstractPDF | None + frozen: bool + transform: AbstractParameterTransformation | None + tags: frozenset[Hashable] = eqx.field(static=True) + + def __init__( + self, + value: V | ArrayLike = 0.0, + name: str | None = None, + lower: V | ArrayLike | None = None, + upper: V | ArrayLike | None = None, + prior: AbstractPDF | None = None, + frozen: bool = False, + transform: AbstractParameterTransformation | None = None, + tags: frozenset[Hashable] = frozenset(), + ) -> None: + self.raw_value = to_value(value) + self.name = name + + # boundaries + self.lower = maybe_float_array(lower) + self.upper = maybe_float_array(upper) + + # prior + self.prior = prior + + # frozen: if True, the parameter is not updated during optimization + self.frozen = frozen + self.transform = transform + + self.tags = tags + + def __check_init__(self): + from evermore.pdf import AbstractPDF + + # runtime check to be sure + if self.prior is not None and not isinstance(self.prior, AbstractPDF): + msg = f"Prior must be a AbstractPDF object for a constrained AbstractParameter (or 'None' for an unconstrained one), got {self.prior=} ({type(self.prior)=})" # type: ignore[unreachable] + raise ValueError(msg) + + +class NormalParameter(AbstractParameter[V]): """ A specialized Parameter class with a Normal prior distribution. @@ -99,18 +247,61 @@ class NormalParameter(Parameter): It also provides additional methods for scaling and morphing the parameter. Attributes: - prior (PDF | None): The prior distribution of the parameter, defaulting to a Normal distribution with mean 0.0 and width 1.0. + value (V): The actual value of the parameter. + name (str | None): An optional name for the parameter. + lower (V | None): The lower boundary of the parameter. + upper (V | None): The upper boundary of the parameter. + prior (Normal): The prior distribution of the parameter, set to a Normal distribution by default. + frozen (bool): Indicates if the parameter is frozen during optimization. + transform (AbstractParameterTransformation | None): An optional transformation applied to the parameter. + tags (frozenset[Hashable]): A set of tags associated with the parameter for additional metadata. 
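The generated dunder methods forward Python's numeric protocol to `.value`, so parameters can be dropped into array expressions directly. A small sketch (the values in the comments are illustrative):

```python3
import jax.numpy as jnp
import evermore as evm

mu = evm.Parameter(value=1.5)

print(mu + 1.0)  # binary op on .value  -> 2.5
print(2.0 * mu)  # reflected op         -> 3.0

# `__jax_array__` lets jnp functions consume the parameter directly
print(jnp.exp(mu))
```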
+ """ - prior: PDF | None = eqx.field(default_factory=lambda: Normal(mean=0.0, width=1.0)) # type: ignore[arg-type] + raw_value: V + name: str | None + lower: V | None + upper: V | None + prior: Normal + frozen: bool + transform: AbstractParameterTransformation | None + tags: frozenset[Hashable] = eqx.field(static=True) + + def __init__( + self, + value: V | ArrayLike = 0.0, + name: str | None = None, + lower: V | ArrayLike | None = None, + upper: V | ArrayLike | None = None, + frozen: bool = False, + transform: AbstractParameterTransformation | None = None, + tags: frozenset[Hashable] = frozenset(), + ) -> None: + from evermore.pdf import Normal + + self.raw_value = to_value(value) + self.name = name + + # boundaries + self.lower = maybe_float_array(lower) + self.upper = maybe_float_array(upper) + + # prior + self.prior = Normal(mean=0.0, width=1.0) + + # frozen: if True, the parameter is not updated during optimization + self.frozen = frozen + self.transform = transform + + self.tags = tags - def scale_log(self, up: Float[Array, "..."], down: Float[Array, "..."]) -> Modifier: # noqa: UP037 + def scale_log(self, up: ArrayLike, down: ArrayLike) -> Modifier: """ Applies an asymmetric exponential scaling to the parameter. Args: - up (Float[Array, "..."]): The scaling factor for the upward direction. - down (Float[Array, "..."]): The scaling factor for the downward direction. + up (ArrayLike): The scaling factor for the upward direction. + down (ArrayLike): The scaling factor for the downward direction. Returns: Modifier: A Modifier instance with the asymmetric exponential effect applied. @@ -118,19 +309,19 @@ def scale_log(self, up: Float[Array, "..."], down: Float[Array, "..."]) -> Modif from evermore.binned.effect import AsymmetricExponential from evermore.binned.modifier import Modifier - return Modifier(parameter=self, effect=AsymmetricExponential(up=up, down=down)) # type: ignore[arg-type] + return Modifier(parameter=self, effect=AsymmetricExponential(up=up, down=down)) def morphing( self, - up_template: Float[Array, "..."], # noqa: UP037 - down_template: Float[Array, "..."], # noqa: UP037 + up_template: H, + down_template: H, ) -> Modifier: """ Applies vertical template morphing to the parameter. Args: - up_template (Float[Array, "..."]): The template for the upward shift. - down_template (Float[Array, "..."]): The template for the downward shift. + up_template (H): The template for the upward shift. + down_template (H): The template for the downward shift. Returns: Modifier: A Modifier instance with the vertical template morphing effect applied. @@ -146,183 +337,49 @@ def morphing( ) -def is_parameter(leaf: Any) -> bool: +def replace_value( + param: AbstractParameter, + value: V, +) -> AbstractParameter: """ - Checks if the given leaf is an instance of the Parameter class. + Replaces the `value` attribute of a given `AbstractParameter` instance with a new value. Args: - leaf (Any): The object to check. + param (AbstractParameter): The parameter object whose value is to be replaced. + value (V): The new value to assign to the parameter. Returns: - bool: True if the leaf is an instance of Parameter, False otherwise. - """ - return isinstance(leaf, Parameter) - - -_params_map = partial(filter_tree_map, filter=is_parameter) + AbstractParameter: A new `AbstractParameter` instance with the updated value. 
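A short sketch of the functional update performed by `replace_value` (import path taken from this diff):

```python3
import jax.numpy as jnp
import evermore as evm
from evermore.parameters.parameter import replace_value

p = evm.Parameter(value=1.0)
p_new = replace_value(p, jnp.array(3.0))

# `p` is untouched; only the returned copy carries the new value
assert p.value == 1.0
assert p_new.value == 3.0
```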
- -_ParamsTree = TypeVar("_ParamsTree", bound=PyTree[Parameter]) - - -def replace_value( - param: Parameter, - value: Float[Array, "..."], # noqa: UP037 -) -> Parameter: + Notes: + This function uses `eqx.tree_at` to perform a functional update, returning a new object + rather than modifying the original `param` in place. + """ return eqx.tree_at( - lambda p: p.value, + lambda p: p.raw_value, param, - value, + to_value(value), is_leaf=lambda leaf: leaf is _missing, ) -def value_filter_spec(tree: _ParamsTree) -> _ParamsTree: - """ - Splits a PyTree of `evm.Parameter` instances into two PyTrees: one containing the values of the parameters - and the other containing the rest of the PyTree. This is useful for defining which components are to be optimized - and which to keep fixed during optimization. - - Args: - tree (_ParamsTree): A PyTree of `evm.Parameter` instances to be split. - - Returns: - _ParamsTree: A PyTree with the same structure as the input, but with boolean values indicating - which parts of the tree are dynamic (True) and which are static (False). - - Usage: - - .. code-block:: python - - from jaxtyping import Array - import evermore as evm - - # define a PyTree of parameters - params = { - "a": evm.Parameter(value=1.0), - "b": evm.Parameter(value=2.0), - } - - # split the PyTree into dynamic and the static parts - filter_spec = evm.parameter.value_filter_spec(params) - dynamic, static = eqx.partition(params, filter_spec) - - # model's first argument is only the dynamic part of the parameter PyTree!! - def model(dynamic, static, hists) -> Array: - # combine the dynamic and static parts of the parameter PyTree - parameters = eqx.combine(dynamic, static) - assert parameters == params - # use the parameters to calculate the model as usual - ... - """ - # 1. set the filter_spec to False for all (non-static) leaves - filter_spec = jax.tree.map(lambda _: False, tree) - - # 2. set the filter_spec to True for each parameter value, - # and _only_ the .value, because we don't want do optimize against anything else! - def _replace_value(filter_leaf: Any, tree_leaf: Any) -> Any: - if isinstance(filter_leaf, Parameter): - filter_leaf = eqx.tree_at( - lambda fl: fl.value, - filter_leaf, - not tree_leaf.frozen, - is_leaf=lambda leaf: leaf is _missing, - ) - return filter_leaf - - return jax.tree.map(_replace_value, filter_spec, tree, is_leaf=is_parameter) - - -def partition(tree: _ParamsTree) -> tuple[_ParamsTree, _ParamsTree]: - """ - Partitions a PyTree of parameters into two separate PyTrees: one containing the dynamic (optimizable) parts - and the other containing the static parts. - - This function serves as a shorthand for manually creating a filter specification and then using `eqx.partition` - to split the parameters. - - Args: - tree (_ParamsTree): A PyTree of parameters to be partitioned. - - Returns: - tuple[_ParamsTree, _ParamsTree]: A tuple containing two PyTrees. The first PyTree contains the dynamic parts - of the parameters, and the second PyTree contains the static parts. - - Example: - - .. 
code-block:: python - - import evermore as evm - - params = {"a": evm.Parameter(1.0), "b": evm.Parameter(2.0, frozen=True)} - - # Verbose: - filter_spec = evm.parameter.value_filter_spec(params) - dynamic, static = eqx.partition(params, filter_spec, replace=evm.util._missing) - print(dynamic) - # >> {'a': Parameter(value=f32[1]), 'b': Parameter(value=--, frozen=True)} - - print(static) - # >> {'a': Parameter(value=--), 'b': Parameter(value=f32[1], frozen=True)} - - # Short hand: - dynamic, static = evm.parameter.partition(params) - """ - return eqx.partition(tree, filter_spec=value_filter_spec(tree), replace=_missing) - - -def combine(*trees: tuple[_ParamsTree]) -> _ParamsTree: - """ - Combines multiple PyTrees of parameters into a single PyTree. - - For each leaf position, returns the first non-_missing value found among the input trees. - If all values _missing at a given position, returns _missing for that position. - - Args: - *trees (_ParamsTree): One or more PyTrees to be combined. - - Returns: - _ParamsTree: A PyTree with the same structure as the inputs, where each leaf is the first non-_missing value found at that position. - - Example: - - .. code-block:: python - - import evermore as evm - - params = {"a": evm.Parameter(1.0), "b": evm.Parameter(2.0, frozen=True)} - - dynamic, static = evm.parameter.partition(params) - reconstructed_params = evm.parameter.combine(dynamic, static) # inverse of `partition` - print(reconstructed_params) - # >> {"a": evm.Parameter(1.0), "b": evm.Parameter(2.0)} - """ - - def _combine(*args): - for arg in args: - if arg is not _missing: - return arg - return _missing - - return jax.tree.map(_combine, *trees, is_leaf=lambda x: x is _missing) - - -def correlate(*parameters: Parameter) -> tuple[Parameter, ...]: +def correlate(*parameters: AbstractParameter) -> tuple[AbstractParameter, ...]: """ Correlate parameters by sharing the value of the *first* given parameter. It is preferred to just use the same parameter if possible, this function should be used if that is not doable. Args: - *parameters (Parameter): A variable number of Parameter instances to be correlated. + *parameters (AbstractParameter): A variable number of AbstractParameter instances to be correlated. Returns: - tuple[Parameter, ...]: A tuple of correlated Parameter instances. + tuple[AbstractParameter, ...]: A tuple of correlated AbstractParameter instances. Example: .. 
code-block:: python + import jax from jaxtyping import PyTree import evermore as evm @@ -355,7 +412,7 @@ class Params(NamedTuple): def model(params: Params): - flat_params, tree_def = jax.tree.flatten(params, evm.parameter.is_parameter) + flat_params, tree_def = jax.tree.flatten(params, evm.filter.is_parameter) # correlate the parameters correlated_flat_params = evm.parameter.correlate(*flat_params) @@ -371,13 +428,15 @@ def model(params: Params): first, *rest = parameters - def _correlate(parameter: Parameter) -> tuple[Parameter, Parameter]: + def _correlate( + parameter: AbstractParameter[V], + ) -> tuple[AbstractParameter[V], AbstractParameter[V]]: ps = (first, parameter) - def where(ps: tuple[Parameter, Parameter]) -> Float[Array, "..."]: # noqa: UP037 + def where(ps: tuple[AbstractParameter[V], AbstractParameter[V]]) -> V: return ps[1].value - def get(ps: tuple[Parameter, Parameter]) -> Float[Array, "..."]: # noqa: UP037 + def get(ps: tuple[AbstractParameter[V], AbstractParameter[V]]) -> V: return ps[0].value shared = eqx.nn.Shared(ps, where, get) diff --git a/src/evermore/parameters/sample.py b/src/evermore/parameters/sample.py index 0572075..6e63b9d 100644 --- a/src/evermore/parameters/sample.py +++ b/src/evermore/parameters/sample.py @@ -4,14 +4,14 @@ import jax.numpy as jnp from jaxtyping import Array, Float, PRNGKeyArray +from evermore.parameters.filter import is_parameter from evermore.parameters.parameter import ( - Parameter, - _params_map, - _ParamsTree, - is_parameter, + AbstractParameter, + V, replace_value, ) -from evermore.pdf import PDF, PoissonBase +from evermore.parameters.tree import PT, only, pure +from evermore.pdf import AbstractPDF, PoissonBase from evermore.util import _missing __all__ = [ @@ -26,11 +26,11 @@ def __dir__(): def sample_from_covariance_matrix( key: jax.random.PRNGKey, - params: _ParamsTree, + params: PT, *, covariance_matrix: Float[Array, "nparams nparams"], n_samples: int = 1, -) -> _ParamsTree: +) -> PT: """ Samples parameter sets from a multivariate normal distribution defined by the given covariance matrix, centered around the current parameter values. @@ -39,12 +39,12 @@ def sample_from_covariance_matrix( Args: key (jax.random.PRNGKey): A JAX random key used for generating random samples. - params (_ParamsTree): A PyTree of parameters whose values will be used as the mean of the distribution. + params (PT): A PyTree of parameters whose values will be used as the mean of the distribution. covariance_matrix (Float[Array, "nparams nparams"]): The covariance matrix for the multivariate normal distribution. n_samples (int, optional): The number of samples to draw. Defaults to 1. Returns: - _ParamsTree: A PyTree with the same structure as `params`, where each parameter value is replaced + PT: A PyTree with the same structure as `params`, where each parameter value is replaced by a sampled value. If `n_samples > 1`, the parameter values will have a batch dimension as the first axis. 
Example: @@ -69,9 +69,8 @@ def sample_from_covariance_matrix( # (3, 1) """ # get the value & make sure it has at least 1d so we insert a batch dim later - values = _params_map( - lambda p: _missing if p.value is _missing else jnp.atleast_1d(p.value), params - ) + params_ = only(params, is_parameter) + values = jax.tree.map(jnp.atleast_1d, pure(params_)) flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values) # sample parameter sets from the correlation matrix (centered around `flat_values`) @@ -91,7 +90,7 @@ def sample_from_covariance_matrix( ) -def sample_from_priors(params: _ParamsTree, key: PRNGKeyArray) -> _ParamsTree: +def sample_from_priors(params: PT, key: PRNGKeyArray) -> PT: """ Samples from the individual prior distributions of the parameters in the given PyTree. Note that no correlations between parameters are taken into account during sampling. @@ -99,11 +98,11 @@ def sample_from_priors(params: _ParamsTree, key: PRNGKeyArray) -> _ParamsTree: See ``examples/toy_generation.py`` for an example usage. Args: - params (_ParamsTree): A PyTree of parameters from which to sample. + params (PT): A PyTree of parameters from which to sample. key (PRNGKeyArray): A JAX random key used for generating random samples. Returns: - _ParamsTree: A new PyTree with the parameters sampled from their respective prior distributions. + PT: A new PyTree with the parameters sampled from their respective prior distributions. Example: @@ -120,9 +119,9 @@ def sample_from_priors(params: _ParamsTree, key: PRNGKeyArray) -> _ParamsTree: key = jax.random.PRNGKey(0) sampled = evm.sample.sample_from_priors(params, key) sampled["a"].value.shape - # (1,) + # () sampled["b"].value.shape - # (1,) + # () """ flat_params, treedef = jax.tree.flatten(params, is_leaf=is_parameter) n_params = len(flat_params) @@ -131,12 +130,12 @@ def sample_from_priors(params: _ParamsTree, key: PRNGKeyArray) -> _ParamsTree: keys = jax.random.split(key, n_params) keys_tree = jax.tree.unflatten(treedef, keys) - def _sample_from_prior(param: Parameter, key) -> Array: - if isinstance(param.prior, PDF) and param.value is not _missing: + def _sample_from_prior(param: AbstractParameter[V], key) -> V: + if isinstance(param.prior, AbstractPDF) and param.value is not _missing: pdf = param.prior # Sample new value from the prior pdf - sampled_value = pdf.sample(key) + sampled_value = pdf.sample(key, shape=(1,)) # TODO: this is not correct I assume if isinstance(pdf, PoissonBase): diff --git a/src/evermore/parameters/transform.py b/src/evermore/parameters/transform.py index 601d556..8cb1478 100644 --- a/src/evermore/parameters/transform.py +++ b/src/evermore/parameters/transform.py @@ -5,14 +5,18 @@ import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import ArrayLike -from evermore.parameters.parameter import Parameter, _params_map, _ParamsTree -from evermore.util import _missing +from evermore.parameters.filter import is_parameter +from evermore.parameters.parameter import ( + AbstractParameter, + V, + replace_value, +) +from evermore.parameters.tree import PT, only __all__ = [ + "AbstractParameterTransformation", "MinuitTransform", - "ParameterTransformation", "SoftPlusTransform", "unwrap", "wrap", @@ -23,7 +27,7 @@ def __dir__(): return __all__ -def unwrap(params: _ParamsTree) -> _ParamsTree: +def unwrap(params: PT) -> PT: """ Unwraps the parameters in the given PyTree by applying their respective transformations. @@ -31,21 +35,21 @@ def unwrap(params: _ParamsTree) -> _ParamsTree: transformation, if it exists. 
If a parameter does not have a transformation, it remains unchanged. Args: - params (_ParamsTree): A PyTree of parameters to be unwrapped. + params (PT): A PyTree of parameters to be unwrapped. Returns: - _ParamsTree: A new PyTree with the parameters unwrapped. + PT: A new PyTree with the parameters unwrapped. """ - def _unwrap(param: Parameter) -> Parameter: + def _unwrap(param: AbstractParameter[V]) -> AbstractParameter[V]: if param.transform is None: return param return param.transform.unwrap(param) - return _params_map(_unwrap, params) + return jax.tree.map(_unwrap, only(params, is_parameter), is_leaf=is_parameter) -def wrap(params: _ParamsTree) -> _ParamsTree: +def wrap(params: PT) -> PT: """ Wraps the parameters in the given PyTree by applying their respective transformations. This is the inverse operation of `unwrap`. @@ -54,21 +58,21 @@ def wrap(params: _ParamsTree) -> _ParamsTree: transformation, if it exists. If a parameter does not have a transformation, it remains unchanged. Args: - params (_ParamsTree): A PyTree of parameters to be wrapped. + params (PT): A PyTree of parameters to be wrapped. Returns: - _ParamsTree: A new PyTree with the parameters wrapped. + PT: A new PyTree with the parameters wrapped. """ - def _wrap(param: Parameter) -> Parameter: + def _wrap(param: AbstractParameter[V]) -> AbstractParameter[V]: if param.transform is None: return param return param.transform.wrap(param) - return _params_map(_wrap, params) + return jax.tree.map(_wrap, only(params, is_parameter), is_leaf=is_parameter) -class ParameterTransformation(eqx.Module): +class AbstractParameterTransformation(eqx.Module): """ Abstract base class for parameter transformations. @@ -78,31 +82,31 @@ class ParameterTransformation(eqx.Module): """ @abc.abstractmethod - def unwrap(self, parameter: Parameter) -> Parameter: + def unwrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: """ Transform a parameter from its meaningful (e.g. bounded) space to the real unconstrained space. Args: - parameter (Parameter): The parameter to be transformed. + parameter (AbstractParameter): The parameter to be transformed. Returns: - Parameter: The transformed parameter. + AbstractParameter: The transformed parameter. """ @abc.abstractmethod - def wrap(self, parameter: Parameter) -> Parameter: + def wrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: """ Transform a parameter from the real unconstrained space back to its meaningful (e.g. bounded) space. (Inverse of `unwrap`) Args: - parameter (Parameter): The parameter to be transformed. + parameter (AbstractParameter): The parameter to be transformed. Returns: - Parameter: The parameter transformed back to its original space. + AbstractParameter: The parameter transformed back to its original space. """ -class MinuitTransform(ParameterTransformation): +class MinuitTransform(AbstractParameterTransformation): """ Transform parameters based on Minuit's conventions. This transformation is used to map parameters with finite lower and upper boundaries to an unconstrained space. 
Both lower and upper boundaries @@ -133,33 +137,33 @@ class MinuitTransform(ParameterTransformation): # wrap back (or "inverse transform") pytree_tt = wrap(pytree_t) - wl.pprint(pytree, width=150, short_arrays=False) + wl.pprint(pytree, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([2.], dtype=float32), lower=-0.1, upper=2.2, transform=MinuitTransform()), - # 'b': Parameter(value=Array([0.1], dtype=float32), lower=0.0, upper=1.1, transform=MinuitTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(2., dtype=float32)), name=None, lower=Array(-0.1, dtype=float32), upper=Array(2.2, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(0.1, dtype=float32)), name=None, lower=Array(0., dtype=float32), upper=Array(1.1, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()) # } - wl.pprint(pytree_t, width=150, short_arrays=False) + wl.pprint(pytree_t, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([0.9721281], dtype=float32), lower=-0.1, upper=2.2, transform=MinuitTransform()), - # 'b': Parameter(value=Array([-0.95824164], dtype=float32), lower=0.0, upper=1.1, transform=MinuitTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(0.9721281, dtype=float32)), name=None, lower=Array(-0.1, dtype=float32), upper=Array(2.2, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(-0.95824164, dtype=float32)), name=None, lower=Array(0., dtype=float32), upper=Array(1.1, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()) # } - wl.pprint(pytree_tt, width=150, short_arrays=False) + wl.pprint(pytree_tt, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([1.9999999], dtype=float32), lower=-0.1, upper=2.2, transform=MinuitTransform()), - # 'b': Parameter(value=Array([0.09999997], dtype=float32), lower=0.0, upper=1.1, transform=MinuitTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(1.9999999, dtype=float32)), name=None, lower=Array(-0.1, dtype=float32), upper=Array(2.2, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(0.09999997, dtype=float32)), name=None, lower=Array(0., dtype=float32), upper=Array(1.1, dtype=float32), prior=None, frozen=False, transform=MinuitTransform(), tags=frozenset()) # } """ - def _check(self, parameter: Parameter) -> Parameter: + def _check(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: if (parameter.lower is None and parameter.upper is not None) or ( parameter.lower is not None and parameter.upper is None ): msg = f"{parameter} must have both lower and upper boundaries set, or none of them." raise ValueError(msg) - lower: ArrayLike = parameter.lower # type: ignore[assignment] - upper: ArrayLike = parameter.upper # type: ignore[assignment] + lower = parameter.lower + upper = parameter.upper # check for finite boundaries error_msg = f"Bounds of {parameter} must be finite, got {parameter.lower=}, {parameter.upper=}." 
parameter = eqx.error_if( @@ -173,7 +177,7 @@ def _check(self, parameter: Parameter) -> Parameter: error_msg, ) - def unwrap(self, parameter: Parameter) -> Parameter: + def unwrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: # short-cut if parameter.lower is None and parameter.upper is None: return parameter @@ -195,30 +199,26 @@ def unwrap(self, parameter: Parameter) -> Parameter: # this formula turns user-provided "external" parameter values into "internal" values value_t = jnp.arcsin( 2.0 - * (parameter.value - parameter.lower) # type: ignore[operator] - / (parameter.upper - parameter.lower) # type: ignore[operator] + * (parameter.value - parameter.lower) + / (parameter.upper - parameter.lower) - 1.0 ) - return eqx.tree_at( - lambda p: p.value, parameter, value_t, is_leaf=lambda leaf: leaf is _missing - ) + return replace_value(parameter, value_t) - def wrap(self, parameter: Parameter) -> Parameter: + def wrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: # short-cut if parameter.lower is None and parameter.upper is None: return parameter parameter = self._check(parameter) # this formula turns "internal" parameter values into "external" values - value_t = parameter.lower + (parameter.upper - parameter.lower) / 2 * ( # type: ignore[operator] + value_t = parameter.lower + (parameter.upper - parameter.lower) / 2 * ( jnp.sin(parameter.value) + 1 ) - return eqx.tree_at( - lambda p: p.value, parameter, value_t, is_leaf=lambda leaf: leaf is _missing - ) + return replace_value(parameter, value_t) -class SoftPlusTransform(ParameterTransformation): +class SoftPlusTransform(AbstractParameterTransformation): """ Applies the softplus transformation to parameters, projecting them from real space (R) to positive space (R+). This transformation is useful for enforcing the positivity of parameters and does not require lower or upper boundaries. 
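Both transformations are typically used through `unwrap`/`wrap` around an optimization loop; a minimal sketch following the docstring examples above:

```python3
import wadler_lindig as wl
import evermore as evm
from evermore.parameters.transform import MinuitTransform, SoftPlusTransform, unwrap, wrap

params = {
    "bounded": evm.Parameter(2.0, lower=-0.1, upper=2.2, transform=MinuitTransform()),
    "positive": evm.Parameter(0.1, transform=SoftPlusTransform()),
}

# optimize in the unconstrained space ...
unconstrained = unwrap(params)
# ... and map the fitted values back to the constrained space
constrained = wrap(unconstrained)
wl.pprint(constrained, short_arrays=False, width=250)
```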
@@ -245,26 +245,26 @@ class SoftPlusTransform(ParameterTransformation): # wrap back (or "inverse transform") pytree_tt = wrap(pytree_t) - wl.pprint(pytree, width=150, short_arrays=False) + wl.pprint(pytree, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([2.], dtype=float32), transform=SoftPlusTransform()), - # 'b': Parameter(value=Array([0.1], dtype=float32), transform=SoftPlusTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(2., dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(0.1, dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()) # } - wl.pprint(pytree_t, width=150, short_arrays=False) + wl.pprint(pytree_t, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([1.8545866], dtype=float32), transform=SoftPlusTransform()), - # 'b': Parameter(value=Array([-2.2521687], dtype=float32), transform=SoftPlusTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(1.8545866, dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(-2.2521687, dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()) # } - wl.pprint(pytree_tt, width=150, short_arrays=False) + wl.pprint(pytree_tt, width=250, short_arrays=False) # { - # 'a': Parameter(value=Array([2.], dtype=float32), transform=SoftPlusTransform()), - # 'b': Parameter(value=Array([0.09999998], dtype=float32), transform=SoftPlusTransform()) + # 'a': Parameter(raw_value=ValueAttr(value=Array(2., dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=Array(0.09999998, dtype=float32)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=SoftPlusTransform(), tags=frozenset()) # } """ - def unwrap(self, parameter: Parameter) -> Parameter: + def unwrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: # from: https://github.com/danielward27/paramax/blob/main/paramax/utils.py """The inverse of the softplus function, checking for positive inputs.""" parameter = eqx.error_if( @@ -273,8 +273,8 @@ def unwrap(self, parameter: Parameter) -> Parameter: "Expected positive inputs to inv_softplus.", ) value_t = jnp.log(-jnp.expm1(-parameter.value)) + parameter.value - return eqx.tree_at(lambda p: p.value, parameter, value_t) + return replace_value(parameter, value_t) - def wrap(self, parameter: Parameter) -> Parameter: + def wrap(self, parameter: AbstractParameter[V]) -> AbstractParameter[V]: value_t = jax.nn.softplus(parameter.value) - return eqx.tree_at(lambda p: p.value, parameter, value_t) + return replace_value(parameter, value_t) diff --git a/src/evermore/parameters/tree.py b/src/evermore/parameters/tree.py new file mode 100644 index 0000000..4725edb --- /dev/null +++ b/src/evermore/parameters/tree.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + TypeVar, +) + +import equinox as eqx +import jax +from jaxtyping import PyTree + +from evermore.parameters.filter import ( + Filter, + is_not_frozen, + is_parameter, + is_value, +) +from evermore.parameters.parameter import AbstractParameter, V +from evermore.util import _missing + 
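The new `tree` module pairs with the filters introduced earlier; tag-based selection, for example, could look like this (a sketch, assuming `evm.tree` and `evm.filter` are exposed as in the docstring examples):

```python3
import evermore as evm

params = {
    "mu": evm.Parameter(1.0, tags=frozenset({"poi"})),
    "syst": evm.NormalParameter(0.0, tags=frozenset({"nuisance"})),
}

# keep only parameters whose tags include "nuisance";
# every other leaf is replaced by the `_missing` sentinel
nuisances = evm.tree.only(params, evm.filter.HasTags(tags=frozenset({"nuisance"})))
```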
+if TYPE_CHECKING:
+    pass
+
+
+__all__ = [
+    "combine",
+    "only",
+    "partition",
+    "pure",
+    "value_filter_spec",
+]
+
+
+PT = TypeVar("PT", bound=PyTree[AbstractParameter[V]])
+
+
+def only(tree: PT, filter: Filter) -> PT:
+    """
+    Filters a PyTree down to the leaves selected by the given filter.
+
+    Args:
+        tree (PT): A PyTree containing various objects, some of which may be selected by the filter.
+        filter (Filter): The filter that decides which leaves are kept.
+
+    Returns:
+        PT: A new PyTree containing only the leaves selected by the filter; all other leaves are replaced by the `_missing` sentinel.
+
+    Example:
+
+    .. code-block:: python
+
+
+        import equinox as eqx
+        import wadler_lindig as wl
+        import evermore as evm
+
+        params = {
+            "a": evm.Parameter(1.0),
+            "b": 42,
+            "c": evm.Parameter(2.0),
+        }
+
+        filtered = evm.tree.only(params, evm.filter.is_parameter)
+        wl.pprint(filtered, width=150)
+        # {
+        #     'a': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=None, tags=frozenset()),
+        #     'b': --,
+        #     'c': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=None, tags=frozenset())
+        # }
+    """
+    return eqx.filter(tree, filter, replace=_missing, is_leaf=filter.is_leaf)
+
+
+def pure(tree: PT) -> PT:
+    """
+    Extracts the raw values from a parameter tree.
+
+    Args:
+        tree (PT): A tree structure containing parameter objects.
+
+    Returns:
+        PT: A tree structure with the same shape as `tree`, but with each parameter replaced by its underlying value.
+
+    Example:
+
+    .. code-block:: python
+
+        import equinox as eqx
+        import wadler_lindig as wl
+        import evermore as evm
+
+        params = {
+            "a": evm.Parameter(1.0),
+            "b": 42,
+            "c": evm.Parameter(2.0),
+        }
+
+        pure_values = evm.tree.pure(params)
+        wl.pprint(pure_values, short_arrays=False, width=150)
+        # {'a': Array(1., dtype=float32), 'b': --, 'c': Array(2., dtype=float32)}
+    """
+    parameters = only(tree, is_parameter)
+    return jax.tree.map(lambda p: p.value, parameters, is_leaf=is_parameter)
+
+
+def value_filter_spec(tree: PT, filter: Filter) -> PT:
+    """
+    Splits a PyTree of `AbstractParameter` instances into two PyTrees: one containing the values of the parameters
+    and the other containing the rest of the PyTree. This is useful for defining which components are to be optimized
+    and which to keep fixed during optimization.
+
+    Args:
+        tree (PT): A PyTree of `AbstractParameter` instances to be split.
+        filter (Filter): A filter that selects the parameters to be treated as dynamic (optimizable);
+            every other leaf of the tree is marked as static.
+
+    Returns:
+        PT: A PyTree with the same structure as the input, but with boolean values indicating
+        which parts of the tree are dynamic (True) and which are static (False).
+
+    Usage:
+
+    .. code-block:: python
+
+        from jaxtyping import Array
+        import equinox as eqx
+        import evermore as evm
+
+        # define a PyTree of parameters
+        params = {
+            "a": evm.Parameter(value=1.0),
+            "b": evm.Parameter(value=2.0),
+        }
+
+        # split the PyTree into dynamic and the static parts
+        filter_spec = evm.tree.value_filter_spec(params, filter=evm.filter.is_not_frozen)
+        dynamic, static = eqx.partition(params, filter_spec)
+
+        # model's first argument is only the dynamic part of the parameter PyTree!!
+            def model(dynamic, static, hists) -> Array:
+                # combine the dynamic and static parts of the parameter PyTree
+                parameters = evm.tree.combine(dynamic, static)
+                assert eqx.tree_equal(params, parameters)
+                # use the parameters to calculate the model as usual
+                ...
+    """
+    if not isinstance(filter, Filter):
+        msg = f"Expected a Filter, got {filter} ({type(filter)=})"  # type: ignore[unreachable]
+        raise ValueError(msg)
+
+    # 1. split by the filter
+    left_tree, right_tree = eqx.partition(
+        tree,
+        filter_spec=filter,
+        is_leaf=filter.is_leaf,
+    )
+
+    # 2. set the .raw_value attr to True for each parameter from the `left_tree`, rest is False
+    value_tree = jax.tree.map(is_value, left_tree, is_leaf=is_value.is_leaf)
+    false_tree = jax.tree.map(lambda _: False, right_tree, is_leaf=is_value.is_leaf)
+
+    # 3. combine the two trees to get the final filter spec
+    return eqx.combine(value_tree, false_tree, is_leaf=filter.is_leaf)
+
+
+def partition(tree: PT, filter: Filter | None = None) -> tuple[PT, PT]:
+    """
+    Partitions a PyTree of parameters into two separate PyTrees: one containing the dynamic (optimizable) parts
+    and the other containing the static parts.
+
+    This function serves as a shorthand for manually creating a filter specification and then using `eqx.partition`
+    to split the parameters.
+
+    Args:
+        tree (PT): A PyTree of parameters to be partitioned.
+        filter (Filter | None, optional): A filter that selects the parameters to be treated as dynamic
+            (optimizable); every other leaf is marked as static. Defaults to `is_not_frozen`,
+            i.e. all non-frozen parameters are optimized.
+
+    Returns:
+        tuple[PT, PT]: A tuple containing two PyTrees. The first PyTree contains the dynamic parts
+        of the parameters, and the second PyTree contains the static parts.
+
+    Example:
+
+    .. code-block:: python
+
+        import equinox as eqx
+        import wadler_lindig as wl
+        import evermore as evm
+
+        params = {"a": evm.Parameter(1.0), "b": evm.Parameter(2.0, frozen=True)}
+
+        # Verbose:
+        filter_spec = evm.tree.value_filter_spec(params, filter=evm.filter.is_not_frozen)
+        dynamic, static = eqx.partition(params, filter_spec, replace=evm.util._missing)
+        wl.pprint(dynamic, width=150)
+        # {
+        #     'a': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=--, transform=None, tags=frozenset()),
+        #     'b': Parameter(raw_value=ValueAttr(value=--), name=None, lower=None, upper=None, prior=None, frozen=--, transform=None, tags=frozenset())
+        # }
+
+        wl.pprint(static, width=150)
+        # {
+        #     'a': Parameter(raw_value=ValueAttr(value=--), name=None, lower=None, upper=None, prior=None, frozen=False, transform=None, tags=frozenset()),
+        #     'b': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=True, transform=None, tags=frozenset())
+        # }
+
+        # Short hand:
+        dynamic, static = evm.tree.partition(params)
+    """
+    if filter is None:
+        # If no filter is provided, we assume all parameters are dynamic,
+        # except those that are marked as frozen.
+        filter = is_not_frozen
+    return eqx.partition(
+        tree,
+        filter_spec=value_filter_spec(tree, filter=filter),
+        replace=_missing,
+    )
+
+
+def combine(*trees: PT) -> PT:
+    """
+    Combines multiple PyTrees of parameters into a single PyTree.
+
+    For each leaf position, returns the first non-_missing value found among the input trees.
+    If all values are _missing at a given position, returns _missing for that position.
+
+    Args:
+        *trees (PT): One or more PyTrees to be combined.
+ + Returns: + PT: A PyTree with the same structure as the inputs, where each leaf is the first non-_missing value found at that position. + + Example: + + .. code-block:: python + + import equinox as eqx + import wadler_lindig as wl + import evermore as evm + + params = {"a": evm.Parameter(1.0), "b": evm.Parameter(2.0, frozen=True)} + + dynamic, static = evm.tree.partition(params) + reconstructed_params = evm.tree.combine(dynamic, static) # inverse of `partition` + wl.pprint(reconstructed_params, width=150) + # { + # 'a': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=False, transform=None, tags=frozenset()), + # 'b': Parameter(raw_value=ValueAttr(value=f32[](jax)), name=None, lower=None, upper=None, prior=None, frozen=True, transform=None, tags=frozenset()) + # } + + assert eqx.tree_equal(params, reconstructed_params) + """ + + def _combine(*args): + for arg in args: + if arg is not _missing: + return arg + return _missing + + return jax.tree.map(_combine, *trees, is_leaf=lambda x: x is _missing) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index a69fab3..e6355dd 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -import typing as tp +from typing import Generic, Protocol, runtime_checkable import equinox as eqx import jax @@ -10,11 +10,12 @@ from jax.scipy.special import digamma, gammaln, xlogy from jaxtyping import Array, Float, PRNGKeyArray -from evermore.util import float_array +from evermore.parameters.parameter import V +from evermore.util import maybe_float_array from evermore.visualization import SupportsTreescope __all__ = [ - "PDF", + "AbstractPDF", "Normal", "PoissonBase", "PoissonContinuous", @@ -26,68 +27,58 @@ def __dir__(): return __all__ -@tp.runtime_checkable -class ImplementsFromUnitNormalConversion(tp.Protocol): - def __evermore_from_unit_normal__( - self, - x: Float[Array, "..."], # noqa: UP037 - ) -> Float[Array, "..."]: ... # noqa: UP037 +@runtime_checkable +class ImplementsFromUnitNormalConversion(Protocol[V]): + def __evermore_from_unit_normal__(self, x: V) -> V: ... -class PDF(eqx.Module, SupportsTreescope): +class AbstractPDF(eqx.Module, Generic[V], SupportsTreescope): @abc.abstractmethod - def log_prob(self, x: Float[Array, "..."]) -> Float[Array, "..."]: ... # noqa: UP037 + def log_prob(self, x: V) -> V: ... @abc.abstractmethod - def cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]: ... # noqa: UP037 + def cdf(self, x: V) -> V: ... @abc.abstractmethod - def inv_cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]: ... # noqa: UP037 + def inv_cdf(self, x: V) -> V: ... @abc.abstractmethod - def sample( - self, key: PRNGKeyArray, shape: Shape | None = None - ) -> Float[Array, "..."]: ... # noqa: UP037 + def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]: ... 
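The `Normal` implementation below satisfies this interface; exercising it directly (a sketch):

```python3
import jax
import jax.numpy as jnp
import evermore as evm

pdf = evm.pdf.Normal(mean=0.0, width=1.0)

x = jnp.array(0.5)
print(pdf.log_prob(x))          # normalized such that log_prob(mean) == 0.0
print(pdf.inv_cdf(pdf.cdf(x)))  # round-trips back to ~0.5
print(pdf.sample(jax.random.PRNGKey(0), shape=(3,)))  # `shape` is now required
```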
- def prob(self, x: Float[Array, "..."], **kwargs) -> Float[Array, "..."]: # noqa: UP037 + def prob(self, x: V, **kwargs) -> V: return jnp.exp(self.log_prob(x, **kwargs)) -class Normal(PDF): - mean: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 - width: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 +class Normal(AbstractPDF[V]): + mean: V = eqx.field(converter=maybe_float_array) + width: V = eqx.field(converter=maybe_float_array) - def log_prob(self, x: Float[Array, "..."]) -> Float[Array, "..."]: # noqa: UP037 + def log_prob(self, x: V) -> V: logpdf_max = jax.scipy.stats.norm.logpdf( self.mean, loc=self.mean, scale=self.width ) unnormalized = jax.scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.width) return unnormalized - logpdf_max - def cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]: # noqa: UP037 + def cdf(self, x: V) -> V: return jax.scipy.stats.norm.cdf(x, loc=self.mean, scale=self.width) - def inv_cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]: # noqa: UP037 + def inv_cdf(self, x: V) -> V: return jax.scipy.stats.norm.ppf(x, loc=self.mean, scale=self.width) - def __evermore_from_unit_normal__(self, x: Array) -> Array: + def __evermore_from_unit_normal__(self, x: V) -> V: return self.mean + self.width * x - def sample( - self, key: PRNGKeyArray, shape: Shape | None = None - ) -> Float[Array, "..."]: # noqa: UP037 - # jax.random.normal does not accept None shape - if shape is None: - shape = () + def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]: # sample parameter from pdf return self.__evermore_from_unit_normal__(jax.random.normal(key, shape=shape)) -class PoissonBase(PDF): - lamb: Float[Array, "..."] = eqx.field(converter=float_array) # noqa: UP037 +class PoissonBase(AbstractPDF[V]): + lamb: V = eqx.field(converter=maybe_float_array) -class PoissonDiscrete(PoissonBase): +class PoissonDiscrete(PoissonBase[V]): """ Poisson distribution with discrete support. Float inputs are floored to the nearest integer. See https://root.cern.ch/doc/master/RooPoisson_8cxx_source.html#l00057 for reference. 
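A quick sketch of the discrete Poisson behavior described above (assuming `PoissonDiscrete` stays exported under `evm.pdf`):

```python3
import jax
import jax.numpy as jnp
import evermore as evm

pois = evm.pdf.PoissonDiscrete(lamb=5.0)

k = jnp.array(3.7)       # floored to 3 internally
print(pois.log_prob(k))  # normalized against logpmf(x, x) by default
print(pois.sample(jax.random.PRNGKey(0), shape=(4,)))
```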
@@ -95,9 +86,9 @@ class PoissonDiscrete(PoissonBase):
 
     def log_prob(
         self,
-        x: Float[Array, "..."],  # noqa: UP037
+        x: V,
         normalize: bool = True,
-    ) -> Float[Array, "..."]:  # noqa: UP037
+    ) -> V:
         x = jnp.floor(x)
         unnormalized = jax.scipy.stats.poisson.logpmf(x, self.lamb)
@@ -107,10 +98,10 @@ def log_prob(
         logpdf_max = jax.scipy.stats.poisson.logpmf(x, x)
         return unnormalized - logpdf_max
 
-    def cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
+    def cdf(self, x: V) -> V:
         return jax.scipy.stats.poisson.cdf(x, self.lamb)
 
-    def inv_cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
+    def inv_cdf(self, x: V) -> V:
         # perform an iterative search
         # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson
         def cond_fn(val):
@@ -129,22 +120,17 @@ def body_fn(val):
         # since we check for cdf < value, n will always refer to the next value
         return jnp.clip(n - 1, min=0)
 
-    def sample(
-        self, key: PRNGKeyArray, shape: Shape | None = None
-    ) -> Float[Array, "..."]:  # noqa: UP037
-        # jax.random.poisson does not accept empty tuple shape
-        if shape == ():
-            shape = None
+    def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, "..."]:
         return jax.random.poisson(key, self.lamb, shape=shape)
 
 
-class PoissonContinuous(PoissonBase):
+class PoissonContinuous(PoissonBase[V]):
     def log_prob(
         self,
-        x: Float[Array, "..."],  # noqa: UP037
+        x: V,
         normalize: bool = True,
         shift_mode: bool = False,
-    ) -> Float[Array, "..."]:  # noqa: UP037
+    ) -> V:
         # optionally adjust lambda to a higher value such that the new mode is the current lambda
         lamb = jnp.exp(digamma(self.lamb + 1)) if shift_mode else self.lamb
@@ -160,16 +146,16 @@ def _log_prob(x, lamb):
             logpdf_max = _log_prob(*args)
         return unnormalized - logpdf_max
 
-    def cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
+    def cdf(self, x: V) -> V:
         err = f"{self.__class__.__name__} does not support cdf"
         raise Exception(err)
 
-    def inv_cdf(self, x: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
+    def inv_cdf(self, x: V) -> V:
         err = f"{self.__class__.__name__} does not support inv_cdf"
         raise Exception(err)
 
     def sample(
         self, key: PRNGKeyArray, shape: Shape | None = None
-    ) -> Float[Array, "..."]:  # noqa: UP037
+    ) -> Float[Array, "..."]:
         msg = f"{self.__class__.__name__} does not support sampling, use PoissonDiscrete instead"
         raise Exception(msg)
diff --git a/src/evermore/util.py b/src/evermore/util.py
index ea8b91f..15157e9 100644
--- a/src/evermore/util.py
+++ b/src/evermore/util.py
@@ -15,7 +15,7 @@
     "dump_hlo_graph",
     "dump_jaxpr",
     "filter_tree_map",
-    "float_array",
+    "maybe_float_array",
     "sum_over_leaves",
     "tree_stack",
 ]
@@ -25,8 +25,13 @@ def __dir__():
     return __all__
 
 
-def float_array(x: Any) -> Float[Array, "..."]:  # noqa: UP037
-    return jnp.asarray(x, jnp.result_type(float))
+def maybe_float_array(x: Any, passthrough: bool = True) -> Float[Array, "..."]:  # noqa: UP037
+    if eqx.is_array_like(x):
+        return jnp.asarray(x, jnp.result_type(float))
+    if passthrough:
+        return x
+    msg = f"Expected an array-like object, got {type(x).__name__} instead."
+ raise ValueError(msg) @jax.tree_util.register_static @@ -46,10 +51,10 @@ def filter_tree_map( tree: PyTree, filter: Callable, ) -> PyTree: - params = eqx.filter(tree, filter, is_leaf=filter) + filtered = eqx.filter(tree, filter, is_leaf=filter) return jax.tree.map( fun, - params, + filtered, is_leaf=filter, ) @@ -171,4 +176,4 @@ def f(x: jax.Array) -> jax.Array: filepath = pathlib.Path("graph.gv") filepath.write_text(dump_hlo_graph(f, x), encoding="ascii") """ - return jax.jit(fun).lower(*args, **kwargs).compiler_ir("hlo").as_hlo_dot_graph() # type: ignore[union-attr] + return jax.jit(fun).lower(*args, **kwargs).compiler_ir("hlo").as_hlo_dot_graph() diff --git a/src/evermore/visualization.py b/src/evermore/visualization.py index f878b75..62e07f8 100644 --- a/src/evermore/visualization.py +++ b/src/evermore/visualization.py @@ -1,121 +1,27 @@ from __future__ import annotations -import dataclasses from collections.abc import Callable from typing import Any -from treescope import ( - dataclass_util, - formatting_util, - renderers, - rendering_parts, -) +import treescope class SupportsTreescope: def __treescope_repr__( self, path: str, - subtree_renderer: Callable[[Any, str | None], rendering_parts.Rendering], - ) -> rendering_parts.Rendering: - return handle_evermore_classes(self, path, subtree_renderer) - - -def handle_evermore_classes( - node: Any, - path: str | None, - subtree_renderer: renderers.TreescopeSubtreeRenderer, -) -> rendering_parts.RenderableTreePart | rendering_parts.Rendering: - """Renders evermore classes. - Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py - - Args: - node: The node to render. - path: The path to the node. (Optional) - subtree_renderer: A recursive renderer for subtrees. - - Returns: - A rendering of evermore classes. - """ - - # get prefix, e.g. "Parameter(" - prefix = render_evermore_constructor(node) - - # get fields of the dataclass, e.g. value=1.0 - fields = dataclasses.fields(node) - - # get children of the tree - children = rendering_parts.build_field_children( - node, - path, - subtree_renderer, - fields_or_attribute_names=fields, - attr_style_fn=evermore_attr_style_fn_for_fields(fields), - ) - - # get colors for the background of the tree node - def _treescope_color(node) -> str: - """Returns the color of the tree node.""" - - type_string = type(node).__module__ + "." + type(node).__qualname__ - return formatting_util.color_from_string(type_string) - - background_color, background_pattern = ( - formatting_util.parse_simple_color_and_pattern_spec( - _treescope_color(node), type(node).__name__ + subtree_renderer: Callable[ + [Any, str | None], treescope.rendering_parts.Rendering + ], + ) -> treescope.rendering_parts.Rendering: + object_type = type(self) + return treescope.repr_lib.render_object_constructor( + object_type=object_type, + attributes=dict(self.__dict__), + path=path, + subtree_renderer=subtree_renderer, + # Pass `roundtrippable=True` only if you can rebuild your object by + # calling `__init__` with these attributes! 
+ roundtrippable=True, + color=treescope.formatting_util.color_from_string(object_type.__qualname__), ) - ) - - return rendering_parts.build_foldable_tree_node_from_children( - prefix=prefix, - children=children, - suffix=")", - background_color=background_color, - background_pattern=background_pattern, - ) - - -def evermore_attr_style_fn_for_fields( - fields, -) -> Callable[[str], rendering_parts.RenderableTreePart]: - """Builds a function to render attributes of an evermore class. - - The resulting function will render pytree node fields in a different style. - E.g. the field "value" of a Parameter class will be rendered in a different style. - - Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py - - Args: - fields: The fields of the evermore class. - - Returns: - A function that takes a field name and returns a RenderableTreePart.""" - fields_by_name = {field.name: field for field in fields} - - def attr_style_fn(field_name): - field = fields_by_name[field_name] - if field.metadata.get("pytree_node", True): - return rendering_parts.custom_style( - rendering_parts.text(field_name), - css_style="font-style: italic; color: #00255f;", - ) - return rendering_parts.text(field_name) - - return attr_style_fn - - -def render_evermore_constructor(node: Any) -> rendering_parts.RenderableTreePart: - """Renders the constructor of an evermore class, with an open parenthesis. - Taken from: https://github.com/google-deepmind/penzai/blob/b1bd577dc34f0e7b8f7fef3bbeb2cd571c2f8fcd/penzai/core/_treescope_handlers/struct_handler.py - """ - if dataclass_util.init_takes_fields(type(node)): - return rendering_parts.siblings( - rendering_parts.maybe_qualified_type_name(type(node)), "(" - ) - - return rendering_parts.siblings( - rendering_parts.maybe_qualified_type_name(type(node)), - rendering_parts.roundtrip_condition( - roundtrip=rendering_parts.text(".from_attributes") - ), - ) diff --git a/tests/test_effect.py b/tests/test_effect.py index ed7eca6..19a2913 100644 --- a/tests/test_effect.py +++ b/tests/test_effect.py @@ -15,70 +15,124 @@ def test_Identity(): - effect = Identity() + effect: Identity = Identity() - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) - assert effect( - parameter=Parameter(value=1.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) - assert effect( - parameter=(Parameter(), Parameter()), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) + hist = jnp.array([1.0, 2.0, 3.0]) + + assert ( + effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=(Parameter(), Parameter()), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) def test_Linear(): - effect = Linear(slope=1.0, offset=0.0) - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=0.0) - assert effect( - parameter=Parameter(value=1.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) + effect: Linear = Linear(slope=1.0, offset=0.0) + + hist = jnp.array([1.0, 2.0, 3.0]) + + assert ( + 
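+        # Linear(slope=1.0, offset=0.0) at parameter value 0.0 yields zero offset and zero scale: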
effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.zeros_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) effect = Linear(slope=0.0, offset=1.0) - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) - assert effect( - parameter=Parameter(value=1.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) + assert ( + effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) effect = Linear(slope=1.0, offset=1.0) - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=1.0) - assert effect( - parameter=Parameter(value=1.0), hist=jnp.array([1, 2, 3]) - ) == OffsetAndScale(offset=0.0, scale=2.0) + assert ( + effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.full_like(hist, 2.0) + ).broadcast() + ) def test_AsymmetricExponential(): - effect = AsymmetricExponential(up=1.2, down=0.9) + effect: AsymmetricExponential = AsymmetricExponential(up=1.2, down=0.9) + + hist = jnp.array([1.0]) - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([1]) - ) == OffsetAndScale(offset=0.0, scale=1.0) - assert effect( - parameter=Parameter(value=+1.0), hist=jnp.array([1]) - ) == OffsetAndScale(offset=0.0, scale=1.2) - assert effect( - parameter=Parameter(value=-1.0), hist=jnp.array([1]) - ) == OffsetAndScale(offset=0.0, scale=0.9) + assert ( + effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=+1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.full_like(hist, 1.2) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=-1.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.full_like(hist, 0.9) + ).broadcast() + ) def test_VerticalTemplateMorphing(): - effect = VerticalTemplateMorphing( + effect: VerticalTemplateMorphing = VerticalTemplateMorphing( up_template=jnp.array([12]), down_template=jnp.array([7]) ) - assert effect( - parameter=Parameter(value=0.0), hist=jnp.array([10]) - ) == OffsetAndScale(offset=[0.0], scale=[1.0]) - assert effect( - parameter=Parameter(value=+1.0), hist=jnp.array([10]) - ) == OffsetAndScale(offset=[2.0], scale=[1.0]) - assert effect( - parameter=Parameter(value=-1.0), hist=jnp.array([10]) - ) == OffsetAndScale(offset=[-3.0], scale=[1.0]) + hist = jnp.array([10.0]) + + assert ( + effect(parameter=Parameter(value=0.0), hist=hist) + == OffsetAndScale( + offset=jnp.zeros_like(hist), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=+1.0), hist=hist) + == OffsetAndScale( + offset=jnp.full_like(hist, 2.0), scale=jnp.ones_like(hist) + ).broadcast() + ) + assert ( + effect(parameter=Parameter(value=-1.0), hist=hist) + == 
OffsetAndScale( + offset=jnp.full_like(hist, -3.0), scale=jnp.ones_like(hist) + ).broadcast() + ) diff --git a/tests/test_loss.py b/tests/test_loss.py index 14cd9a6..01d4d92 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1,15 +1,21 @@ from __future__ import annotations +import typing as tp + import jax.numpy as jnp import numpy as np import pytest +from jaxtyping import Float, PyTree, Scalar import evermore as evm +ScalarParam: tp.TypeAlias = evm.Parameter[Float[Scalar, ""]] +ScalarParamTree: tp.TypeAlias = PyTree[ScalarParam] + def test_get_log_probs(): - params = { - "a": evm.NormalParameter(value=0.5), # type: ignore[arg-type] + params: ScalarParamTree = { + "a": evm.NormalParameter(value=0.5), "b": evm.NormalParameter(), "c": evm.Parameter(), } @@ -21,14 +27,18 @@ def test_get_log_probs(): def test_compute_covariance(): - def loss_fn(params): + def loss_fn(params: ScalarParamTree) -> Float[Scalar, ""]: return ( params["a"].value ** 2 + 2 * params["b"].value ** 2 + (params["a"].value + params["c"].value) ** 2 ) - params = {"a": evm.Parameter(2.0), "b": evm.Parameter(3.0), "c": evm.Parameter(4.0)} + params: ScalarParamTree = { + "a": evm.Parameter(2.0), + "b": evm.Parameter(3.0), + "c": evm.Parameter(4.0), + } cov = evm.loss.compute_covariance(loss_fn, params) diff --git a/tests/test_modifier.py b/tests/test_modifier.py index e5f76c6..ebabb17 100644 --- a/tests/test_modifier.py +++ b/tests/test_modifier.py @@ -1,13 +1,19 @@ from __future__ import annotations +import typing as tp + import jax.numpy as jnp import numpy as np +from jaxtyping import Float, PyTree, Scalar import evermore as evm +ScalarParam: tp.TypeAlias = evm.Parameter[Float[Scalar, ""]] +ScalarParamTree: tp.TypeAlias = PyTree[ScalarParam] + def test_Modifier(): - param = evm.Parameter(value=1.1) # type: ignore[arg-type] + param: ScalarParam = evm.Parameter(value=1.1) modifier = param.scale() hist = jnp.array([1, 2, 3]) @@ -17,48 +23,56 @@ def test_Modifier(): def test_Where(): - param1 = evm.Parameter(value=1.0) # type: ignore[arg-type] - param2 = evm.Parameter(value=1.1) # type: ignore[arg-type] + param1: ScalarParam = evm.Parameter(value=1.0) + param2: ScalarParam = evm.Parameter(value=1.1) modifier1 = param1.scale() modifier2 = param2.scale() hist = jnp.array([1, 2, 3]) - where_mod = evm.modifier.Where(hist > 1.5, modifier2, modifier1) + where_mod: evm.modifier.Where[ScalarParamTree] = evm.modifier.Where( + hist > 1.5, modifier2, modifier1 + ) np.testing.assert_allclose(where_mod(hist), jnp.array([1, 2.2, 3.3])) def test_BooleanMask(): - param = evm.Parameter(value=1.1) # type: ignore[arg-type] + param: ScalarParam = evm.Parameter(value=1.1) modifier = param.scale() hist = jnp.array([1, 2, 3]) - masked_mod = evm.modifier.BooleanMask(jnp.array([True, False, True]), modifier) + masked_mod: evm.modifier.BooleanMask[ScalarParamTree] = evm.modifier.BooleanMask( + jnp.array([True, False, True]), modifier + ) np.testing.assert_allclose(masked_mod(hist), jnp.array([1.1, 2, 3.3])) def test_Transform(): - param = evm.Parameter(value=1.1) # type: ignore[arg-type] + param: ScalarParam = evm.Parameter(value=1.1) modifier = param.scale() hist = jnp.array([1, 2, 3]) - sqrt_modifier = evm.modifier.Transform(jnp.sqrt, modifier) + sqrt_modifier: evm.modifier.Transform[ScalarParamTree] = evm.modifier.Transform( + jnp.sqrt, modifier + ) np.testing.assert_allclose( sqrt_modifier(hist), jnp.array([1.0488088, 2.0976176, 3.1464264]) ) def test_mix_modifiers(): - param = evm.Parameter(value=1.1) # type: ignore[arg-type] + param: 
ScalarParam = evm.Parameter(value=1.1) modifier = param.scale() hist = jnp.array([1, 2, 3]) - sqrt_modifier = evm.modifier.Transform(jnp.sqrt, modifier) - sqrt_masked_modifier = evm.modifier.BooleanMask( - jnp.array([True, False, True]), sqrt_modifier + sqrt_modifier: evm.modifier.Transform[ScalarParamTree] = evm.modifier.Transform( + jnp.sqrt, modifier + ) + sqrt_masked_modifier: evm.modifier.BooleanMask[ScalarParamTree] = ( + evm.modifier.BooleanMask(jnp.array([True, False, True]), sqrt_modifier) ) np.testing.assert_allclose( sqrt_masked_modifier(hist), jnp.array([1.0488088, 2, 3.1464264]) @@ -66,8 +80,8 @@ def test_mix_modifiers(): def test_Compose(): - param1 = evm.Parameter(value=1.0) # type: ignore[arg-type] - param2 = evm.Parameter(value=1.1) # type: ignore[arg-type] + param1: ScalarParam = evm.Parameter(value=1.0) + param2: ScalarParam = evm.Parameter(value=1.1) modifier1 = param1.scale() modifier2 = param2.scale() diff --git a/tests/test_parameter.py b/tests/test_parameter.py index be786f9..263836c 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,11 +1,17 @@ from __future__ import annotations +import typing as tp + +from jaxtyping import Float, Scalar + import evermore as evm from evermore.pdf import Normal +ScalarParam: tp.TypeAlias = evm.Parameter[Float[Scalar, ""]] + def test_Parameter(): - p = evm.Parameter(value=1.0, lower=0.0, upper=2.0) # type: ignore[arg-type] + p: ScalarParam = evm.Parameter(value=1.0, lower=0.0, upper=2.0) assert p.value == 1.0 assert p.lower == 0.0 assert p.upper == 2.0 @@ -13,7 +19,7 @@ def test_Parameter(): def test_NormalParameter(): - p = evm.NormalParameter(value=1.0, lower=0.0, upper=2.0) # type: ignore[arg-type] + p: ScalarParam = evm.NormalParameter(value=1.0, lower=0.0, upper=2.0) assert p.value == 1.0 assert p.lower == 0.0 assert p.upper == 2.0 diff --git a/tests/test_pdf.py b/tests/test_pdf.py index e025478..e15d3c8 100644 --- a/tests/test_pdf.py +++ b/tests/test_pdf.py @@ -2,23 +2,24 @@ import jax.numpy as jnp import pytest +from jaxtyping import Float, Scalar from evermore.pdf import Normal, PoissonContinuous, PoissonDiscrete def test_Normal(): - pdf = Normal(mean=jnp.array(0.0), width=jnp.array(1.0)) + pdf: Normal[Float[Scalar, ""]] = Normal(mean=jnp.array(0.0), width=jnp.array(1.0)) assert pdf.log_prob(jnp.array(0.0)) == pytest.approx(0.0) def test_PoissonDiscrete(): - pdf = PoissonDiscrete(lamb=jnp.array(10)) + pdf: PoissonDiscrete[Float[Scalar, ""]] = PoissonDiscrete(lamb=jnp.array(10)) assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636) def test_PoissonContinuous(): - pdf = PoissonContinuous(lamb=jnp.array(10)) + pdf: PoissonContinuous[Float[Scalar, ""]] = PoissonContinuous(lamb=jnp.array(10)) assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636)
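Finally, a small sketch of the new `maybe_float_array` semantics from `src/evermore/util.py` above (the `object()` sentinel is just an illustrative non-array-like value):

```python
import jax.numpy as jnp

from evermore.util import maybe_float_array

# array-like inputs are promoted to float arrays, as before
assert maybe_float_array(1).dtype == jnp.result_type(float)

# non-array-like leaves now pass through unchanged by default ...
sentinel = object()
assert maybe_float_array(sentinel) is sentinel

# ... unless passthrough is disabled, in which case a ValueError is raised
try:
    maybe_float_array(sentinel, passthrough=False)
except ValueError as err:
    print(err)  # Expected an array-like object, got object instead.
```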