From e0497d50a32e7fe0bb31cf7f6227b04dbd90c475 Mon Sep 17 00:00:00 2001 From: Marcel Rieger Date: Thu, 14 Aug 2025 16:31:12 +0200 Subject: [PATCH 1/5] feat: Accelerate discrete inv_cdf search (#51) * Add discrete_inv_cdf_search, use in PoissonDiscrete. * Typo. * Typos. * Prefer | over logical_or. * Add rounding choice check. * avoid shape manipulation & polish --------- Co-authored-by: Peter Fackeldey Co-authored-by: pfackeldey --- src/evermore/pdf.py | 181 ++++++++++++++++++++++++++++++++++++++------ tests/test_pdf.py | 34 ++++++++- 2 files changed, 190 insertions(+), 25 deletions(-) diff --git a/src/evermore/pdf.py b/src/evermore/pdf.py index e6355dd..67abd6c 100644 --- a/src/evermore/pdf.py +++ b/src/evermore/pdf.py @@ -1,7 +1,8 @@ from __future__ import annotations import abc -from typing import Generic, Protocol, runtime_checkable +from collections.abc import Callable +from typing import Generic, Literal, Protocol, runtime_checkable import equinox as eqx import jax @@ -80,8 +81,8 @@ class PoissonBase(AbstractPDF[V]): class PoissonDiscrete(PoissonBase[V]): """ - Poisson distribution with discrete support. Float inputs are floored to the nearest integer. - See https://root.cern.ch/doc/master/RooPoisson_8cxx_source.html#l00057 for reference. + Poisson distribution with discrete support. Float inputs are floored to the nearest integer, + similar to the behavior implemented in other libraries like SciPy or RooFit. """ def log_prob( @@ -89,36 +90,38 @@ def log_prob( x: V, normalize: bool = True, ) -> V: - x = jnp.floor(x) + k = jnp.floor(x) - unnormalized = jax.scipy.stats.poisson.logpmf(x, self.lamb) + # plain evaluation of the pmf + unnormalized = jax.scipy.stats.poisson.logpmf(k, self.lamb) if not normalize: return unnormalized - logpdf_max = jax.scipy.stats.poisson.logpmf(x, x) + # when normalizing, divide (subtract in log space) by maximum over k range + logpdf_max = jax.scipy.stats.poisson.logpmf(k, k) return unnormalized - logpdf_max def cdf(self, x: V) -> V: + # no need to round x to k, already done by cdf library function return jax.scipy.stats.poisson.cdf(x, self.lamb) - def inv_cdf(self, x: V) -> V: - # perform an iterative search - # see: https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html?highlight=poisson%20inverse#5.3-Example:-Left-truncated-Poisson - def cond_fn(val): - _, cdf = val - return jnp.any(cdf < x) - - def body_fn(val): - n, cdf = val - n_new = jnp.where(cdf < x, n + 1, n) - return n_new, jax.scipy.stats.poisson.cdf(n_new, self.lamb) - - start_n = jnp.zeros_like(x, dtype=jnp.result_type(int)) - start_cdf = jnp.zeros_like(x, dtype=jnp.result_type(float)) - n, _ = jax.lax.while_loop(cond_fn, body_fn, (start_n, start_cdf)) - - # since we check for cdf < value, n will always refer to the next value - return jnp.clip(n - 1, min=0) + def inv_cdf(self, x: V, rounding: DiscreteRounding = "floor") -> V: + # define starting point for search from normal approximation + def start_fn(x: V) -> V: + return jnp.floor( + self.lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(self.lamb) + ) + + # define the cdf function + def cdf_fn(k: V) -> V: + return jax.scipy.stats.poisson.cdf(k, self.lamb) + + return discrete_inv_cdf_search( + x, + cdf_fn=cdf_fn, + start_fn=start_fn, + rounding=rounding, + ) def sample(self, key: PRNGKeyArray, shape: Shape) -> Float[Array, ...]: return jax.random.poisson(key, self.lamb, shape=shape) @@ -138,10 +141,13 @@ def _log_prob(x, lamb): x = jnp.array(x, jnp.result_type(float)) return xlogy(x, lamb) - lamb - gammaln(x + 1) + # plain 
evaluation of the pdf
         unnormalized = _log_prob(x, lamb)
         if not normalize:
             return unnormalized
 
+        # when normalizing, divide (subtract in log space) by maximum over a range
+        # that depends on whether the mode is shifted
         args = (self.lamb, lamb) if shift_mode else (x, x)
         logpdf_max = _log_prob(*args)
         return unnormalized - logpdf_max
@@ -159,3 +165,130 @@ def sample(
     ) -> Float[Array, ...]:
         msg = f"{self.__class__.__name__} does not support sampling, use PoissonDiscrete instead"
         raise Exception(msg)
+
+
+# alias for rounding literals
+DiscreteRounding = Literal["floor", "ceil", "closest"]
+known_roundings = frozenset(DiscreteRounding.__args__)  # type: ignore[attr-defined]
+
+
+def discrete_inv_cdf_search(
+    x: V,
+    cdf_fn: Callable[[V], V],
+    start_fn: Callable[[V], V],
+    rounding: DiscreteRounding,
+) -> V:
+    """
+    Computes the inverse CDF (percent point function) at probability values *x* for a discrete
+    distribution whose CDF is given by *cdf_fn*, using an iterative search strategy. The search
+    starts at values provided by *start_fn* and progresses in integer steps towards the target
+    values.
+
+    .. code-block:: python
+
+        # this example mimics the PoissonDiscrete.inv_cdf implementation
+
+        import jax
+        import jax.numpy as jnp
+
+        from evermore.pdf import discrete_inv_cdf_search
+
+        # parameter of the Poisson distribution
+        lamb = 5.0
+
+
+        # the normal approximation is a good starting point
+        def start_fn(x):
+            return jnp.floor(lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb))
+
+
+        # define the cdf function
+        def cdf_fn(k):
+            return jax.scipy.stats.poisson.cdf(k, lamb)
+
+
+        k = discrete_inv_cdf_search(jnp.array(0.9), cdf_fn, start_fn, "floor")
+        # -> 7.0
+
+    Args:
+        x (V): Probability values in the unit interval to compute the inverse CDF for.
+        cdf_fn (Callable): A callable representing the discrete CDF function. It is called with a
+            single argument and supposed to return the CDF value for that argument.
+        start_fn (Callable): A callable that provides integral starting points for the search.
+            It is called with *x* and should return values of the same shape.
+        rounding (DiscreteRounding): One of "floor", "ceil" or "closest".
+
+    Returns:
+        V: The computed inverse CDF values in the same shape as *x*.
+    """
+    # store masks for injecting exact values for known edge cases later on
+    # inject 0 for x == 0
+    zero_mask = x == 0.0
+    # inject inf for x == 1
+    inf_mask = x == 1.0
+    # inject nan for ~(0 < x < 1) or non-finite values
+    nan_mask = (x < 0.0) | (x > 1.0) | ~jnp.isfinite(x)
+
+    # set up the stopping condition and iteration body for the iterative search
+    # note: functions are defined for scalar values and vectorized via jnp.vectorize below
+    def cond_fn(val):
+        *_, stop = val
+        return ~jnp.any(stop)
+
+    def body_fn(val):
+        k, target_itg, prev_itg, stop = val
+        # compute the current integral
+        itg = cdf_fn(k)
+        # special case: itg is the exact solution
+        stop |= itg == target_itg
+        # if no previous integral is available or if we have not yet "cornered" the target value
+        # with the current and previous integrals, make a step in the right direction
+        make_step = (
+            (prev_itg < 0)
+            | ((prev_itg < itg) & (itg < target_itg))
+            | ((target_itg < itg) & (itg < prev_itg))
+        )
+        step = jnp.where(~stop & make_step, jnp.sign(target_itg - itg), 0)
+        k += step
+        # if target_itg is between the computed integrals we can now find the correct k
+        # note: k might be subject to a shift by +1 or -1, depending on the stride and rounding
+        k_found = ~stop & ~make_step
+
+        # we're using python >=3.11 :)
+        match rounding:
+            case "floor":
+                k_shift = jnp.where(k_found & (itg > target_itg), -1, 0)
+            case "ceil":
+                k_shift = jnp.where(k_found & (prev_itg > target_itg), 1, 0)
+            case "closest":
+                k_shift = jnp.where(
+                    k_found & (abs(itg - target_itg) > abs(prev_itg - target_itg)),
+                    jnp.sign(prev_itg - itg),
+                    0,
+                )
+            case _:
+                msg = f"unknown rounding mode '{rounding}', expected one of {', '.join(known_roundings)}"  # type: ignore[unreachable]
+                raise ValueError(msg)
+
+        k += k_shift
+        # update the stop flag and end
+        stop |= k_found
+        return (k, target_itg, itg, stop)
+
+    def search(start_k, target_itg, stop):
+        prev_itg = -jnp.ones_like(target_itg)
+        val = (start_k, target_itg, prev_itg, stop)
+        return jax.lax.while_loop(cond_fn, body_fn, val)[0]
+
+    # jnp.vectorize is auto-vmapping over all axes of its arguments,
+    # see: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.vectorize.html#jax.numpy.vectorize
+    vsearch = jnp.vectorize(search)
+
+    # define starting point and stop flag (eagerly skipping edge cases), then search
+    start_k = start_fn(x)
+    stop = zero_mask | inf_mask | nan_mask
+    k = vsearch(start_k, x, stop)
+
+    # inject known values for edge cases
+    k = jnp.where(zero_mask, 0.0, k)
+    k = jnp.where(inf_mask, jnp.inf, k)
+    k = jnp.where(nan_mask, jnp.nan, k)
+
+    return k  # noqa: RET504
diff --git a/tests/test_pdf.py b/tests/test_pdf.py
index e15d3c8..96ed1a9 100644
--- a/tests/test_pdf.py
+++ b/tests/test_pdf.py
@@ -1,10 +1,17 @@
 from __future__ import annotations
 
+import jax
 import jax.numpy as jnp
+import numpy as np
 import pytest
 from jaxtyping import Float, Scalar
 
-from evermore.pdf import Normal, PoissonContinuous, PoissonDiscrete
+from evermore.pdf import (
+    Normal,
+    PoissonContinuous,
+    PoissonDiscrete,
+    discrete_inv_cdf_search,
+)
 
 
 def test_Normal():
@@ -23,3 +30,28 @@ def test_PoissonContinuous():
     pdf: PoissonContinuous[Float[Scalar, ""]] = PoissonContinuous(lamb=jnp.array(10))
 
     assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636)
+
+
+def test_discrete_inv_cdf_search():
+    lamb = 5.0
+
+    def start_fn(x):
+        return jnp.floor(lamb + jax.scipy.stats.norm.ppf(x) * jnp.sqrt(lamb))
+
+    def cdf_fn(k):
+        return jax.scipy.stats.poisson.cdf(k, lamb)
+
+    # test correct algorithmic behavior
+    assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "floor") == 7
+    assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "ceil") == 8
+    assert discrete_inv_cdf_search(jnp.array([0.9]), cdf_fn, start_fn, "closest") == 8
+
+    # test individual solutions in vmapped mode plus shape preservation
+    k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "floor")
+    np.testing.assert_allclose(k, jnp.array([7.0, 8.0, 10.0]))
+    k = discrete_inv_cdf_search(jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "ceil")
+    np.testing.assert_allclose(k, jnp.array([8.0, 9.0, 11.0]))
+    k = discrete_inv_cdf_search(
+        jnp.array([0.9, 0.95, 0.99]), cdf_fn, start_fn, "closest"
+    )
+    np.testing.assert_allclose(k, jnp.array([8.0, 8.0, 10.0]))
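A note on the helper added above: discrete_inv_cdf_search is not tied to the Poisson distribution; any discrete distribution with a monotonic CDF and a sensible starting point works. A minimal sketch with a geometric distribution, whose closed-form CDF makes the "floor" result easy to verify by hand (the import path is the one used in the tests above; the start function here is just the continuous inversion of the CDF):

    import jax.numpy as jnp

    from evermore.pdf import discrete_inv_cdf_search

    p = 0.3  # success probability

    # closed-form CDF of a geometric distribution counting the number of
    # failures before the first success: P(K <= k) = 1 - (1 - p)^(k + 1)
    def cdf_fn(k):
        return 1.0 - (1.0 - p) ** (k + 1.0)

    # starting point from inverting the continuous relaxation of the CDF
    def start_fn(x):
        return jnp.floor(jnp.log1p(-x) / jnp.log1p(-p) - 1.0)

    k = discrete_inv_cdf_search(jnp.array(0.9), cdf_fn, start_fn, "floor")
    # -> 5.0, since cdf_fn(5) ~ 0.882 <= 0.9 < cdf_fn(6) ~ 0.918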
From 8578be7df9870aa933f35c37b9da4a2b9e4fe01f Mon Sep 17 00:00:00 2001
From: Peter Fackeldey
Date: Tue, 19 Aug 2025 11:49:38 -0400
Subject: [PATCH 2/5] add fisher info matrix & cramer rao uncertainties (#79)

---
 examples/toy_generation.py |   2 +-
 src/evermore/loss.py       | 187 ++++++++++++++++++++++++++++++++-----
 tests/test_loss.py         |  51 +++++++---
 3 files changed, 201 insertions(+), 39 deletions(-)

diff --git a/examples/toy_generation.py b/examples/toy_generation.py
index 586fbe1..380e283 100644
--- a/examples/toy_generation.py
+++ b/examples/toy_generation.py
@@ -82,7 +82,7 @@ def prefit_toy_expectation(params, key):
     # partial it to only depend on `params`
     loss_fn = partial(optx_loss, args=(static, hists, observation))
 
-    fast_covariance_matrix = eqx.filter_jit(evm.loss.compute_covariance)
+    fast_covariance_matrix = eqx.filter_jit(evm.loss.covariance_matrix)
     covariance_matrix = fast_covariance_matrix(loss_fn, dynamic)
 
     # create 1 toy
diff --git a/src/evermore/loss.py b/src/evermore/loss.py
index def3a55..366f49e 100644
--- a/src/evermore/loss.py
+++ b/src/evermore/loss.py
@@ -15,8 +15,11 @@
 from evermore.pdf import AbstractPDF, ImplementsFromUnitNormalConversion, Normal
 
 __all__ = [
-    "compute_covariance",
+    "covariance_matrix",
+    "cramer_rao_uncertainty",
+    "fisher_information_matrix",
     "get_log_probs",
+    "hessian_matrix",
 ]
 
 
@@ -79,15 +82,27 @@ def _constraint(param: AbstractParameter[V]) -> V:
     return jax.tree.map(_constraint, only(tree, is_parameter), is_leaf=is_parameter)
 
 
-def compute_covariance(
+def _ravel_pure_tree(tree: PT) -> tuple[Float[Array, " nparams"], tp.Callable]:
+    """Flattens a PyTree of parameters into a 1D array of parameter values.
+
+    Args:
+        tree (PT): A PyTree of parameters.
+
+    Returns:
+        tuple[Float[Array, "nparams"], Callable]: A tuple containing the flattened
+        array of parameter values and a function to unflatten the array back into
+        the original PyTree structure.
+    """
+    values = pure(tree)
+    flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values)
+    return flat_values, unravel_fn
+
+
+def hessian_matrix(
     loss_fn: tp.Callable,
     tree: PT,
 ) -> Float[Array, "nparams nparams"]:
-    r"""
-    Computes the covariance matrix of the parameters under the Laplace approximation,
-    by inverting the Hessian of the loss function at the current parameter values.
-
-    See ``examples/toy_generation.py`` for an example usage.
+    """
+    Computes the Hessian matrix of the loss function at the current parameter values.
 
     Args:
         loss_fn (Callable): The loss function. Should accept (tree) as arguments.
@@ -95,14 +110,12 @@
         tree (PT): A PyTree of parameters.
 
     Returns:
-        Float[Array, "nparams nparams"]: The covariance matrix of the parameters.
+        Float[Array, "nparams nparams"]: The Hessian matrix of the loss function.
 
     Example:
 
         .. code-block:: python
 
             import evermore as evm
 
             import jax
             import jax.numpy as jnp
 
 
             def loss_fn(params):
                 x = params["a"].value
                 y = params["b"].value
                 return jnp.sum((x - 1.0) ** 2 + (y - 2.0) ** 2)
 
 
@@ -113,17 +126,15 @@
             params = {
-                "a": evm.Parameter(value=jnp.array([1.0]), prior=None, lower=0.0, upper=2.0),
-                "b": evm.Parameter(value=jnp.array([2.0]), prior=None, lower=1.0, upper=3.0),
+                "a": evm.Parameter(value=jnp.array([1.0])),
+                "b": evm.Parameter(value=jnp.array([2.0])),
             }
 
-            cov = evm.loss.compute_covariance(loss_fn, params)
-            cov.shape
+            hessian = evm.loss.hessian_matrix(loss_fn, params)
+            hessian.shape
             # (2, 2)
     """
-    # first, compute the hessian at the current point
-    values = pure(tree)
-    flat_values, unravel_fn = jax.flatten_util.ravel_pytree(values)
+    flat_values, unravel_fn = _ravel_pure_tree(tree)
 
     def _flat_loss(flat_values: Float[Array, "..."]) -> Float[Array, ""]:
         param_values = unravel_fn(flat_values)
@@ -144,14 +155,144 @@ def _flat_loss(flat_values: Float[Array, "..."]) -> Float[Array, ""]:
         return loss_fn(updated_tree)
 
     # calculate hessian
-    hessian = jax.hessian(_flat_loss)(flat_values)
+    return jax.hessian(_flat_loss)(flat_values)
+
+
+def fisher_information_matrix(
+    loss_fn: tp.Callable,
+    tree: PT,
+) -> Float[Array, "nparams nparams"]:
+    """
+    Computes the Fisher information matrix of the parameters under the Laplace approximation,
+    by inverting the Hessian of the loss function at the current parameter values.
+
+    Args:
+        loss_fn (Callable): The loss function. Should accept (tree) as arguments.
+            All other arguments have to be "partial'd" into the loss function.
+        tree (PT): A PyTree of parameters.
+
+    Returns:
+        Float[Array, "nparams nparams"]: The Fisher information matrix of the parameters.
+
+    Example:
+
+        .. code-block:: python
+
+            import evermore as evm
+            import jax.numpy as jnp
+
+
+            def loss_fn(params):
+                x = params["a"].value
+                y = params["b"].value
+                return jnp.sum((x - 1.0) ** 2 + (y - 2.0) ** 2)
+
+
+            params = {
+                "a": evm.Parameter(value=jnp.array([1.0])),
+                "b": evm.Parameter(value=jnp.array([2.0])),
+            }
+
+            fisher = evm.loss.fisher_information_matrix(loss_fn, params)
+            fisher.shape
+            # (2, 2)
+    """
+    # calculate hessian
+    hessian = hessian_matrix(loss_fn, tree)
+    # invert to get the fisher information matrix under the Laplace assumption of normality
+    return jnp.linalg.inv(hessian)
+
+
+def covariance_matrix(
+    loss_fn: tp.Callable,
+    tree: PT,
+) -> Float[Array, "nparams nparams"]:
+    """
+    Computes the covariance matrix of the parameters under the Laplace approximation,
+    by inverting the Hessian of the loss function at the current parameter values.
+    The result is normalized by the parameter standard deviations, so its diagonal
+    entries are one.
+
+    See ``examples/toy_generation.py`` for an example usage.
+
+    Args:
+        loss_fn (Callable): The loss function. Should accept (tree) as arguments.
+            All other arguments have to be "partial'd" into the loss function.
+        tree (PT): A PyTree of parameters.
+
+    Returns:
+        Float[Array, "nparams nparams"]: The covariance matrix of the parameters.
+
+    Example:
+
+        .. code-block:: python
+
+            import evermore as evm
+            import jax
+            import jax.numpy as jnp
+
+
+            def loss_fn(params):
+                x = params["a"].value
+                y = params["b"].value
+                return jnp.sum((x - 1.0) ** 2 + (y - 2.0) ** 2)
+
+
+            params = {
+                "a": evm.Parameter(value=jnp.array([1.0])),
+                "b": evm.Parameter(value=jnp.array([2.0])),
+            }
+
+            cov = evm.loss.covariance_matrix(loss_fn, params)
+            cov.shape
+            # (2, 2)
+    """
+    # calculate fisher information matrix
+    fisher = fisher_information_matrix(loss_fn, tree)
+
+    # normalize via D^-1 @ fisher @ D^-1 with D being the diagonal standard deviation matrix
+    d = jnp.sqrt(jnp.diag(fisher))
+    cov = fisher / jnp.outer(d, d)
 
     # to avoid numerical issues, fix the diagonal to 1
     return jnp.fill_diagonal(cov, 1.0, inplace=False)
+
+
+def cramer_rao_uncertainty(
+    loss_fn: tp.Callable,
+    tree: PT,
+) -> PT:
+    """
+    Computes the Cramér-Rao uncertainty of the parameters under the Laplace approximation,
+    by computing the square root of the diagonal of the Fisher information matrix.
+
+    Args:
+        loss_fn (Callable): The loss function. Should accept (tree) as arguments.
+            All other arguments have to be "partial'd" into the loss function.
+        tree (PT): A PyTree of parameters.
+
+    Returns:
+        PT: The Cramér-Rao uncertainty of the parameters.
+
+    Example:
+
+        .. code-block:: python
+
+            import evermore as evm
+            import jax.numpy as jnp
+
+
+            def loss_fn(params):
+                x = params["a"].value
+                y = params["b"].value
+                return jnp.sum((x - 1.0) ** 2 + (y - 2.0) ** 2)
+
+
+            params = {
+                "a": evm.Parameter(value=jnp.array([1.0])),
+                "b": evm.Parameter(value=jnp.array([2.0])),
+            }
+
+            uncertainties = evm.loss.cramer_rao_uncertainty(loss_fn, params)
+    """
+    _, unravel_fn = _ravel_pure_tree(tree)
+
+    # calculate fisher information matrix
+    fisher_info = fisher_information_matrix(loss_fn, tree)
+    return unravel_fn(jnp.sqrt(jnp.diag(fisher_info)))
diff --git a/tests/test_loss.py b/tests/test_loss.py
index 01d4d92..b2fe5f1 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -13,7 +13,23 @@
 ScalarParamTree: tp.TypeAlias = PyTree[ScalarParam]
 
 
+def loss_fn(params: ScalarParamTree) -> Float[Scalar, ""]:
+    return (
+        params["a"].value ** 2
+        + 2 * params["b"].value ** 2
+        + (params["a"].value + params["c"].value) ** 2
+    )
+
+
+params: ScalarParamTree = {
+    "a": evm.Parameter(2.0),
+    "b": evm.Parameter(3.0),
+    "c": evm.Parameter(4.0),
+}
+
+
 def test_get_log_probs():
+    # use some parameters with constraints
     params: ScalarParamTree = {
         "a": evm.NormalParameter(value=0.5),
         "b": evm.NormalParameter(),
@@ -26,24 +42,29 @@
 
     assert log_probs["c"] == pytest.approx(0.0)
 
 
-def test_compute_covariance():
-    def loss_fn(params: ScalarParamTree) -> Float[Scalar, ""]:
-        return (
-            params["a"].value ** 2
-            + 2 * params["b"].value ** 2
-            + (params["a"].value + params["c"].value) ** 2
-        )
-
-    params: ScalarParamTree = {
-        "a": evm.Parameter(2.0),
-        "b": evm.Parameter(3.0),
-        "c": evm.Parameter(4.0),
-    }
-
-    cov = evm.loss.compute_covariance(loss_fn, params)
+def test_covariance_matrix():
+    cov = evm.loss.covariance_matrix(loss_fn, params)
 
     assert cov.shape == (3, 3)
     np.testing.assert_allclose(
         cov,
         jnp.array([[1.0, 0.0, -0.7071067], [0.0, 1.0, 0.0], [-0.7071067, 0.0, 1.0]]),
     )
+
+
+def test_hessian_matrix():
+    h = evm.loss.hessian_matrix(loss_fn, params)
+
+    assert h.shape == (3, 3)
+    np.testing.assert_allclose(
+        h,
+        jnp.array([[4.0, 0.0, 2.0], [0.0, 4.0, 0.0], [2.0, 0.0, 2.0]]),
+    )
+
+
+def test_cramer_rao_uncertainty():
+    uncertainty = evm.loss.cramer_rao_uncertainty(loss_fn, params)
+
+    assert uncertainty["a"] == pytest.approx(0.70710677)
+    assert uncertainty["b"] == pytest.approx(0.5)
+    assert uncertainty["c"] == pytest.approx(1.0)
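To see how the four new helpers relate, here is a minimal sketch that mirrors the new tests (same toy loss and parameters; the expected numbers follow from the 3x3 Hessian asserted above, with fisher_information_matrix returning the inverse of that Hessian per this patch's convention):

    import evermore as evm
    import jax.numpy as jnp


    def loss_fn(params):
        return (
            params["a"].value ** 2
            + 2 * params["b"].value ** 2
            + (params["a"].value + params["c"].value) ** 2
        )


    params = {
        "a": evm.Parameter(2.0),
        "b": evm.Parameter(3.0),
        "c": evm.Parameter(4.0),
    }

    # second derivatives of the loss at the current parameter values
    hess = evm.loss.hessian_matrix(loss_fn, params)  # shape (3, 3)

    # inverse of the Hessian under the Laplace approximation
    fisher = evm.loss.fisher_information_matrix(loss_fn, params)

    # normalized matrix with unit diagonal
    cov = evm.loss.covariance_matrix(loss_fn, params)

    # per-parameter uncertainties, returned in the same PyTree layout as params
    errs = evm.loss.cramer_rao_uncertainty(loss_fn, params)
    # errs["a"] ~ 0.7071, errs["b"] == 0.5, errs["c"] == 1.0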
From c51bdf53012b9d836ba46193b575e566b645f2bf Mon Sep 17 00:00:00 2001
From: Peter Fackeldey
Date: Tue, 19 Aug 2025 11:56:43 -0400
Subject: [PATCH 3/5] fix roundtrippable treescope rendering (#80)

---
 src/evermore/visualization.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/evermore/visualization.py b/src/evermore/visualization.py
index 62e07f8..edcc005 100644
--- a/src/evermore/visualization.py
+++ b/src/evermore/visualization.py
@@ -22,6 +22,7 @@ def __treescope_repr__(
             subtree_renderer=subtree_renderer,
             # Pass `roundtrippable=True` only if you can rebuild your object by
             # calling `__init__` with these attributes!
-            roundtrippable=True,
+            # This is `False` because of `evm.Parameter.raw_value`
+            roundtrippable=False,
             color=treescope.formatting_util.color_from_string(object_type.__qualname__),
         )

From 8bc9ba2bc3ee357d4219fbce6d20f4fec1c13a6b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 19 Aug 2025 11:58:27 -0400
Subject: [PATCH 4/5] chore: update pre-commit hooks (#78)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.12.8 → v0.12.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.8...v0.12.9)
- [github.com/astral-sh/ruff-pre-commit: v0.12.8 → v0.12.9](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.8...v0.12.9)
- [github.com/henryiii/validate-pyproject-schema-store: 2025.08.07 → 2025.08.15](https://github.com/henryiii/validate-pyproject-schema-store/compare/2025.08.07...2025.08.15)
- [github.com/python-jsonschema/check-jsonschema: 0.33.2 → 0.33.3](https://github.com/python-jsonschema/check-jsonschema/compare/0.33.2...0.33.3)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Peter Fackeldey
---
 .pre-commit-config.yaml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 111ba38..c399ed7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,14 +20,14 @@ repos:
   - id: trailing-whitespace
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: "v0.12.8"
+  rev: "v0.12.9"
   hooks:
   - id: ruff
     args: ["--fix", "--show-fixes"]
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
-  rev: v0.12.8
+  rev: v0.12.9
   hooks:
   - id: ruff-format
 
@@ -59,12 +59,12 @@
     exclude: ^(LICENSE$)
 
 - repo: https://github.com/henryiii/validate-pyproject-schema-store
-  rev: 2025.08.07
+  rev: 2025.08.15
   hooks:
   - id: validate-pyproject
 
- repo: https://github.com/python-jsonschema/check-jsonschema
-  rev: 0.33.2
+  rev: 0.33.3
   hooks:
   - id: check-readthedocs
   - id: check-github-workflows

From 60244dd5b5e3cf01d72d4998c63184e9eda5acf6 Mon Sep 17 00:00:00 2001
From: pfackeldey
Date: Thu, 21 Aug 2025 10:35:42 -0400
Subject: [PATCH 5/5] increase version

---
 src/evermore/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/evermore/__init__.py b/src/evermore/__init__.py
index 97c01b1..a3ed985 100644
--- a/src/evermore/__init__.py
+++ b/src/evermore/__init__.py
@@ -13,7 +13,7 @@
 __contact__ = "https://github.com/pfackeldey/evermore"
 __license__ = "BSD-3-Clause"
 __status__ = "Development"
-__version__ = "0.3.4"
+__version__ = "0.3.5"
 
 # expose public API