15 changes: 9 additions & 6 deletions examples/toy_generation.py
@@ -65,12 +65,15 @@ def make_step(
 params, opt_state = make_step(params, opt_state, hists, observation)
 
 diffable, static = evm.parameter.partition(params)
-fast_covariance_matrix = eqx.filter_jit(evm.sample.compute_covariance_matrix)
-covariance_matrix = fast_covariance_matrix(
-    loss=loss,
-    params=diffable,
-    args=(static, hists, observation),
-)
+fast_covariance_matrix = eqx.filter_jit(evm.loss.compute_covariance)
+
+
+# partial it to only depend on `params`
+def loss_fn(params):
+    return loss(params, static, hists, observation)
+
+
+covariance_matrix = fast_covariance_matrix(loss_fn, diffable)
 
 
 # generate new expectation based on the postfit toy parameters
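The new `evm.loss.compute_covariance` API no longer accepts `args`/`kwargs`; everything except `params` is closed over in a small wrapper, exactly as the diff above does with `loss_fn`. A minimal sketch of the same pattern with a made-up quadratic loss (the names `loss`, `static`, `hists`, and `observation` belong to the example script; `toy_loss` and `mu` here are hypothetical):

import equinox as eqx
import jax.numpy as jnp

import evermore as evm


# hypothetical stand-in for the example's loss function
def toy_loss(params, observation):
    return jnp.sum((params["mu"].value - observation) ** 2)


observation = jnp.array([1.2])
params = {"mu": evm.Parameter(value=jnp.array([1.0]), prior=None, lower=0.0, upper=2.0)}


# close over everything except `params`
def loss_fn(p):
    return toy_loss(p, observation)


covariance_matrix = eqx.filter_jit(evm.loss.compute_covariance)(loss_fn, params)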
52 changes: 34 additions & 18 deletions src/evermore/binned/effect.py
@@ -27,8 +27,8 @@ def __dir__():
 
 
 class OffsetAndScale(eqx.Module):
-    offset: Float[Array, ...] = eqx.field(converter=float_array, default=0.0)
-    scale: Float[Array, ...] = eqx.field(converter=float_array, default=1.0)
+    offset: Float[Array, "..."] = eqx.field(converter=float_array, default=0.0)  # noqa: UP037
+    scale: Float[Array, "..."] = eqx.field(converter=float_array, default=1.0)  # noqa: UP037
 
     def broadcast(self) -> OffsetAndScale:
         shape = jnp.broadcast_shapes(self.offset.shape, self.scale.shape)
@@ -41,29 +41,35 @@ def broadcast(self) -> OffsetAndScale:
 class Effect(eqx.Module, SupportsTreescope):
     @abc.abstractmethod
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale: ...
 
 
 class Identity(Effect):
     @jax.named_scope("evm.effect.Identity")
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale:
         return OffsetAndScale(offset=0.0, scale=1.0)  # type: ignore[arg-type]
 
 
 class Lambda(Effect):
     fun: Callable[
-        [PyTree[Parameter], Float[Array, ...]], OffsetAndScale | Float[Array, ...]
+        [PyTree[Parameter], Float[Array, "..."]], OffsetAndScale | Float[Array, "..."]  # noqa: UP037
     ]
     normalize_by: Literal["offset", "scale"] | None = eqx.field(
         static=True, default=None
     )
 
     @jax.named_scope("evm.effect.Lambda")
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale:
         res = self.fun(parameter, hist)
         if isinstance(res, OffsetAndScale) and self.normalize_by is None:
@@ -77,12 +83,14 @@ def __call__(
 
 
 class Linear(Effect):
-    offset: Float[Array, ...] = eqx.field(converter=float_array)
-    slope: Float[Array, ...] = eqx.field(converter=float_array)
+    offset: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
+    slope: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
 
     @jax.named_scope("evm.effect.Linear")
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale:
         assert isinstance(parameter, Parameter)
         sf = parameter.value * self.slope + self.offset
@@ -93,12 +101,16 @@ def __call__(
 
 
 class VerticalTemplateMorphing(Effect):
-    up_template: Float[Array, ...] = eqx.field(converter=float_array)  # + 1 sigma
-    down_template: Float[Array, ...] = eqx.field(converter=float_array)  # - 1 sigma
+    # + 1 sigma
+    up_template: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
+    # - 1 sigma
+    down_template: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
 
     def vshift(
-        self, x: Float[Array, ...], hist: Float[Array, ...]
-    ) -> Float[Array, ...]:
+        self,
+        x: Float[Array, "..."],  # noqa: UP037
+        hist: Float[Array, "..."],  # noqa: UP037
+    ) -> Float[Array, "..."]:  # noqa: UP037
         dx_sum = self.up_template + self.down_template - 2 * hist
         dx_diff = self.up_template - self.down_template
 
@@ -118,18 +130,20 @@ def vshift(
 
     @jax.named_scope("evm.effect.VerticalTemplateMorphing")
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale:
         assert isinstance(parameter, Parameter)
         offset = self.vshift(parameter.value, hist=hist)
         return OffsetAndScale(offset=offset, scale=jnp.ones_like(hist))  # type: ignore[arg-type]
 
 
 class AsymmetricExponential(Effect):
-    up: Float[Array, ...] = eqx.field(converter=float_array)
-    down: Float[Array, ...] = eqx.field(converter=float_array)
+    up: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
+    down: Float[Array, "..."] = eqx.field(converter=float_array)  # noqa: UP037
 
-    def interpolate(self, x: Float[Array, ...]) -> Float[Array, ...]:
+    def interpolate(self, x: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
         # https://github.com/cms-analysis/HiggsAnalysis-CombinedLimit/blob/be488af288361ef101859a398ae618131373cad7/src/ProcessNormalization.cc#L112-L129
         lo, hi = self.down, self.up
         hi = jnp.log(hi)
@@ -146,7 +160,9 @@ def interpolate(self, x: Float[Array, ...]) -> Float[Array, ...]:
 
     @jax.named_scope("evm.effect.AsymmetricExponential")
     def __call__(
-        self, parameter: PyTree[Parameter], hist: Float[Array, ...]
+        self,
+        parameter: PyTree[Parameter],
+        hist: Float[Array, "..."],  # noqa: UP037
     ) -> OffsetAndScale:
         assert isinstance(parameter, Parameter)
         interp = self.interpolate(parameter.value)
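A note on the recurring `# noqa: UP037`: jaxtyping encodes array shapes as strings, so `Float[Array, "..."]` (any shape) is the intended spelling, while ruff's UP037 rule ("remove quotes from type annotation") would strip the quotes back to a bare `...`. The inline noqa comments keep the two tools from fighting. A minimal sketch of the pattern, assuming only jaxtyping is installed:

from jaxtyping import Array, Float


# the quotes in "..." are load-bearing: they are a jaxtyping shape
# specification, not a forward reference, so UP037 must be silenced
def double(hist: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
    return 2.0 * hist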
28 changes: 16 additions & 12 deletions src/evermore/binned/modifier.py
@@ -36,27 +36,31 @@ def __dir__():
 
 @runtime_checkable
 class ModifierLike(Protocol):
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale: ...
-    def __call__(self, hist: Float[Array, ...]) -> Float[Array, ...]: ...
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale: ...  # noqa: UP037
+    def __call__(self, hist: Float[Array, "..."]) -> Float[Array, "..."]: ...  # noqa: UP037
     def __matmul__(self, other: ModifierLike) -> Compose: ...
 
 
 class AbstractModifier(eqx.Module):
     @abc.abstractmethod
     def offset_and_scale(
-        self: ModifierLike, hist: Float[Array, ...]
+        self: ModifierLike,
+        hist: Float[Array, "..."],  # noqa: UP037
    ) -> OffsetAndScale: ...
 
     @abc.abstractmethod
-    def __call__(self: ModifierLike, hist: Float[Array, ...]) -> Float[Array, ...]: ...
+    def __call__(
+        self: ModifierLike,
+        hist: Float[Array, "..."],  # noqa: UP037
+    ) -> Float[Array, "..."]: ...  # noqa: UP037
 
     @abc.abstractmethod
     def __matmul__(self: ModifierLike, other: ModifierLike) -> Compose: ...
 
 
 class ApplyFn(AbstractModifier):
     @jax.named_scope("evm.modifier.ApplyFn")
-    def __call__(self: ModifierLike, hist: Float[Array, ...]) -> Float[Array, ...]:
+    def __call__(self: ModifierLike, hist: Float[Array, "..."]) -> Float[Array, "..."]:  # noqa: UP037
         os = self.offset_and_scale(hist=hist)
         return os.scale * (hist + os.offset)
 
@@ -168,7 +172,7 @@ def __init__(
         self.parameter = parameter
         self.effect = effect
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         return self.effect(parameter=self.parameter, hist=hist)
 
 
@@ -211,7 +215,7 @@ class Where(ModifierBase):
     modifier_true: ModifierLike
     modifier_false: ModifierLike
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         true_os = self.modifier_true.offset_and_scale(hist)
         false_os = self.modifier_false.offset_and_scale(hist)
 
@@ -254,7 +258,7 @@ class BooleanMask(ModifierBase):
     mask: Bool[Array, ...]
     modifier: ModifierLike
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         os = self.modifier.offset_and_scale(hist)
 
         def _mask(
@@ -301,7 +305,7 @@ class Transform(ModifierBase):
     transform_fn: Callable = eqx.field(static=True)
     modifier: ModifierLike
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         os = self.modifier.offset_and_scale(hist)
         return jax.tree.map(self.transform_fn, os)
 
@@ -310,7 +314,7 @@ class TransformOffset(ModifierBase):
     transform_fn: Callable = eqx.field(static=True)
     modifier: ModifierLike
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         os = self.modifier.offset_and_scale(hist)
         return OffsetAndScale(offset=self.transform_fn(os.offset), scale=os.scale)
 
@@ -319,7 +323,7 @@ class TransformScale(ModifierBase):
     transform_fn: Callable = eqx.field(static=True)
     modifier: ModifierLike
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         os = self.modifier.offset_and_scale(hist)
         return OffsetAndScale(offset=os.offset, scale=self.transform_fn(os.scale))
 
@@ -392,7 +396,7 @@ def unroll_modifiers(self) -> list[ModifierLike]:
     def __len__(self) -> int:
         return len(self.unroll_modifiers())
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         from collections import defaultdict
 
         # initial scale and offset
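Every `offset_and_scale` implementation above feeds one application rule, defined by `ApplyFn.__call__`: the offset shifts the histogram first, then the scale multiplies it. A small numeric sketch of that contract, with made-up values:

import jax.numpy as jnp

hist = jnp.array([10.0, 20.0, 30.0])
offset = jnp.array([1.0, 1.0, 1.0])  # additive part, applied first
scale = jnp.array([1.1, 1.0, 0.9])   # multiplicative part, applied second

# what ApplyFn.__call__ computes from an OffsetAndScale
modified = scale * (hist + offset)
# -> [12.1, 21.0, 27.9]

Chaining modifiers with `@` (via `__matmul__`, which returns a `Compose`) appears to accumulate the individual offsets and scales before this rule is applied, per the `defaultdict` bookkeeping in `Compose.offset_and_scale`.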
10 changes: 5 additions & 5 deletions src/evermore/binned/staterror.py
@@ -58,15 +58,15 @@ class StatErrors(ModifierBase):
     """
 
     eps: Float[Scalar, ""]
-    n_entries: Float[Array, ...]
+    n_entries: Float[Array, "..."]  # noqa: UP037
     non_empty_mask: Bool[Array, " nbins"]
-    relative_error: Float[Array, ...]
+    relative_error: Float[Array, "..."]  # noqa: UP037
     parameter: NormalParameter
 
     def __init__(
         self,
-        hist: Float[Array, ...],
-        variance: Float[Array, ...],
+        hist: Float[Array, "..."],  # noqa: UP037
+        variance: Float[Array, "..."],  # noqa: UP037
     ):
         # make sure they are of dtype float
         hist, variance = jax.tree.map(float_array, (hist, variance))
@@ -87,7 +87,7 @@ def __init__(
         )
         self.parameter = NormalParameter(value=jnp.zeros_like(self.n_entries))
 
-    def offset_and_scale(self, hist: Float[Array, ...]) -> OffsetAndScale:
+    def offset_and_scale(self, hist: Float[Array, "..."]) -> OffsetAndScale:  # noqa: UP037
         modifier = Where(
             self.non_empty_mask,
             self.parameter.scale(slope=self.relative_error, offset=1.0),
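For orientation, a hedged usage sketch of `StatErrors`: one `NormalParameter` per bin scales non-empty bins by `1 + value * relative_error`, while empty bins take the `Where` modifier's other branch. The import path is read off the file header above and may differ from the public re-export:

import jax.numpy as jnp

# import path assumed from the diff's file header
from evermore.binned.staterror import StatErrors

hist = jnp.array([10.0, 0.0, 25.0])     # one empty bin
variance = jnp.array([9.0, 0.0, 16.0])  # e.g. sum of squared event weights

staterrors = StatErrors(hist=hist, variance=variance)

# non-empty bins get scale = 1 + parameter.value * relative_error
modified = staterrors(hist)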
86 changes: 50 additions & 36 deletions src/evermore/loss.py
@@ -3,9 +3,15 @@
 import jax
 import jax.flatten_util
 import jax.numpy as jnp
-from jaxtyping import Array, Float, Scalar
-
-from evermore.parameters.parameter import Parameter, _params_map, _ParamsTree
+from jaxtyping import Array, Float
+
+from evermore.parameters.parameter import (
+    Parameter,
+    _params_map,
+    _ParamsTree,
+    _replace_parameter_value,
+    is_parameter,
+)
 from evermore.pdf import PDF, ImplementsFromUnitNormalConversion, Normal
 
 __all__ = [
@@ -76,54 +82,62 @@ def _constraint(param: Parameter) -> Float[Array, "..."]:
 def compute_covariance(
     loss_fn: tp.Callable,
     params: _ParamsTree,
-    args: tuple[tp.Any, ...] = (),
-    kwargs: dict[str, tp.Any] | None = None,
-) -> Array:
-    """
-    Computes the covariance (correlation) matrix between parameters in a PyTree, evaluated with its
-    parameter values at a given loss function. The covariance is computed using the inverted Hessian
-    under the Laplace assumption of normality, followed by a normalization step.
+) -> Float[Array, "nparams nparams"]:
+    r"""
+    Computes the covariance matrix of the parameters under the Laplace approximation,
+    by inverting the Hessian of the loss function at the current parameter values.
+
+    See ``examples/toy_generation.py`` for an example usage.
+
+    Args:
+        loss_fn (Callable): The loss function. Should accept (params) as arguments.
+            All other arguments have to be "partial'd" into the loss function.
+        params (_ParamsTree): A PyTree of parameters.
+
+    Returns:
+        Float[Array, "nparams nparams"]: The covariance matrix of the parameters.
 
     Example:
 
     .. code-block:: python
 
-        import jax.numpy as jnp
         import evermore as evm
-
-        params = {"a": jnp.array(2.0), "b": jnp.array(3.0), "c": jnp.array(4.0)}
+        import jax
+        import jax.numpy as jnp
 
 
         def loss_fn(params):
-            # some loss function depending on params["a"], params["b"] and params["c"]
-            return ...
+            x = params["a"].value
+            y = params["b"].value
+            return jnp.sum((x - 1.0) ** 2 + (y - 2.0) ** 2)
 
 
-        # compute the covariance matrix
-        cov = evm.loss.compute_covariance(loss_fn, params)
-
-    Args:
-        loss_fn (Callable): A callable whose gradients are evaluated for the computation.
-        params (PyTree): A PyTree containing parameters to compute the covariance for.
-        args (tuple): Additional positional arguments to pass to the loss function.
-        kwargs (dict): Additional keyword arguments to pass to the loss function.
-
-    Returns:
-        Array: A square matrix representing the correlation between parameters.
+        params = {
+            "a": evm.Parameter(value=jnp.array([1.0]), prior=None, lower=0.0, upper=2.0),
+            "b": evm.Parameter(value=jnp.array([2.0]), prior=None, lower=1.0, upper=3.0),
+        }
+
+        cov = evm.loss.compute_covariance(loss_fn, params)
+        cov.shape
+        # (2, 2)
     """
-    # default kwargs
-    if kwargs is None:
-        kwargs = {}
-
-    # create a flattened version of the parameters and the loss
-    flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
+    # first, compute the hessian at the current point
+    values = _params_map(lambda p: p.value, params)
+    flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values)
 
-    def flat_loss_fn(flat_params: Float[Array, "..."]) -> Float[Scalar, ""]:
-        return loss_fn(unravel_fn(flat_params), *args, **kwargs)
+    def _flat_loss(flat_values: Float[Array, "..."]) -> Float[Array, ""]:
+        param_values = unravel_fn(flat_values)
+
+        _params = jax.tree.map(
+            _replace_parameter_value, params, param_values, is_leaf=is_parameter
+        )
+        return loss_fn(_params)
 
-    # compute the hessian at the current parameters
-    h = jax.hessian(flat_loss_fn)(flat_params)
+    # calculate hessian
+    hessian = jax.hessian(_flat_loss)(flat_values)
 
-    # get the unnormalized covariance matrix
-    cov = jnp.linalg.inv(h)
+    # invert to get the correlation matrix under the Laplace assumption of normality
+    cov = jnp.linalg.inv(hessian)
 
     # normalize via D^-1 @ cov @ D^-1 with D being the diagonal standard deviation matrix
     d = jnp.sqrt(jnp.diag(cov))
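The diff is truncated just after the normalization comment: the remaining step divides the inverted Hessian by its diagonal standard deviations. A self-contained sketch of the full Laplace recipe on a toy quadratic loss with an analytic Hessian; the final `corr` line follows the `D^-1 @ cov @ D^-1` comment and is an assumption about the truncated tail:

import jax
import jax.numpy as jnp


def loss(x):
    # quadratic loss with hessian diag(2, 8)
    return x[0] ** 2 + 4.0 * x[1] ** 2


x0 = jnp.array([1.0, 2.0])

hessian = jax.hessian(loss)(x0)  # [[2, 0], [0, 8]]
cov = jnp.linalg.inv(hessian)    # [[0.5, 0], [0, 0.125]]

# normalize via D^-1 @ cov @ D^-1, with D the diagonal standard-deviation matrix
d = jnp.sqrt(jnp.diag(cov))
corr = cov / jnp.outer(d, d)     # unit diagonal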