


@riga (Collaborator) commented May 1, 2025

This PR changes the default value of the prior class member to use a default_factory instead of a single shared Normal(mean=0.0, width=1.0) instance.
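A minimal sketch of the difference between the two kinds of default, using plain dataclasses. Normal, ParameterShared and ParameterFactory are hypothetical stand-ins for illustration, not the library's actual classes:

```python
from dataclasses import dataclass, field


# frozen=True makes instances hashable, which Python 3.11+ requires for a
# plain (non-factory) dataclass default.
@dataclass(frozen=True)
class Normal:
    # Hypothetical stand-in for the library's Normal prior.
    mean: float = 0.0
    width: float = 1.0


@dataclass(frozen=True)
class ParameterShared:
    # Old pattern: one Normal instance is built when the class body runs
    # and is reused as the default for every ParameterShared.
    value: float = 0.0
    prior: Normal = Normal(mean=0.0, width=1.0)


@dataclass(frozen=True)
class ParameterFactory:
    # New pattern: default_factory is called on every __init__, so each
    # instance gets a freshly constructed prior.
    value: float = 0.0
    prior: Normal = field(default_factory=lambda: Normal(mean=0.0, width=1.0))


a, b = ParameterShared(), ParameterShared()
c, d = ParameterFactory(), ParameterFactory()
print(a.prior is b.prior)  # True: the same object is shared
print(c.prior is d.prior)  # False: built per instance
```

With a plain default, the expression is evaluated once and its result is reused; default_factory defers construction to each instantiation.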

With the current choice, that instance is created only once, when the class is defined (i.e. at import time), which is before users have a chance to configure the default precision via jax.config.update("jax_enable_x64", True).
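To illustrate why the evaluation time matters, here is a pure-Python sketch. CONFIG and default_dtype are hypothetical stand-ins for jax.config and JAX's default float dtype; the point is only that a default built at class-definition time cannot see a configuration change made afterwards:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for jax.config: a global flag deciding which float
# precision newly created values get.
CONFIG = {"enable_x64": False}


def default_dtype() -> str:
    return "float64" if CONFIG["enable_x64"] else "float32"


@dataclass(frozen=True)
class Normal:
    # The dtype is captured when the instance is created, mimicking how an
    # array's dtype is fixed at construction.
    mean: float = 0.0
    width: float = 1.0
    dtype: str = field(default_factory=default_dtype)


@dataclass(frozen=True)
class ParameterShared:
    # Evaluated while the class body runs, i.e. at import time.
    prior: Normal = Normal()


@dataclass(frozen=True)
class ParameterFactory:
    # Evaluated each time a ParameterFactory is constructed.
    prior: Normal = field(default_factory=Normal)


# The user enables double precision only after the classes were defined.
CONFIG["enable_x64"] = True

print(ParameterShared().prior.dtype)   # float32: fixed at definition time
print(ParameterFactory().prior.dtype)  # float64: follows the current config
```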

For me, this resulted in mixed-precision operations between f32 and f64 arrays that led to minor numerical differences (e.g. log_probs not being exactly 0 in test cases with default prior settings).
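One plausible way such differences arise (an illustration only, not the library's actual log_prob code): if one term of a log-probability is evaluated in float32 while the reference is float64, the difference at the mean is no longer exactly 0. to_float32 and normal_log_prob are hypothetical helpers; struct is used to emulate float32 rounding without depending on JAX:

```python
import math
import struct


def to_float32(x: float) -> float:
    # Round a float64 to the nearest float32 and back, emulating a value
    # that was computed or stored as a 32-bit array.
    return struct.unpack("f", struct.pack("f", x))[0]


def normal_log_prob(x, mean, width, cast=lambda v: v):
    # Gaussian log-density; `cast` is applied to the normalization constant
    # to emulate it being evaluated in a different precision.
    norm = cast(math.log(width * math.sqrt(2.0 * math.pi)))
    return -0.5 * ((x - mean) / width) ** 2 - norm


# Log-probability at the mean of a Normal(0, 1), relative to a float64 reference.
ref = normal_log_prob(0.0, 0.0, 1.0)
same = normal_log_prob(0.0, 0.0, 1.0) - ref
mixed = normal_log_prob(0.0, 0.0, 1.0, cast=to_float32) - ref

print(same)   # 0.0
print(mixed)  # small but nonzero
```

The residual is on the order of float32 rounding error, which matches the "minor" size of the differences described above.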

@pfackeldey pfackeldey merged commit 4b62f05 into pfackeldey:main May 1, 2025
6 checks passed
@riga riga deleted the fix/prior_init branch May 2, 2025 08:04