
@inailuig (Collaborator) commented Jul 24, 2025

To flatten, reshape, or batch the samples, we need to know how many batch dimensions they have.
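The sketch below is a minimal NumPy illustration of why that count matters. It assumes each sub-state is an array of shape `(*batch, *sample_shape)`. `flatten_batch` is a hypothetical helper, not a NetKet function. Merging the batch axes requires `n_batch`, and the shape alone cannot tell you where the batch axes end:

```python
import numpy as np

def flatten_batch(leaf: np.ndarray, n_batch: int) -> np.ndarray:
    """Merge the leading `n_batch` axes of `leaf` into a single batch axis."""
    sample_shape = leaf.shape[n_batch:]
    return leaf.reshape((-1, *sample_shape))

# Hypothetical samples: 2 chains x 3 samples per chain, for two sub-systems
# whose single-sample shapes are (4,) and (1,).
sub_states = (np.zeros((2, 3, 4), np.int8), np.zeros((2, 3, 1), np.int8))

flat = tuple(flatten_batch(s, n_batch=2) for s in sub_states)
print([f.shape for f in flat])  # [(6, 4), (6, 1)]

# The shape alone is ambiguous: with n_batch=1 the same leaf would be read
# as 2 samples of shape (3, 4), and nothing gets merged.
print(flatten_batch(sub_states[0], n_batch=1).shape)  # (2, 3, 4)
```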

Since I couldn't figure out how to automatically increment that number for the output of a `jax.lax.scan` (which stacks its outputs along a new leading axis), I instead keep track of the number of non-batch dims and compute the number of batch dims from it.
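Here is a NumPy sketch of the idea. It assumes the wrapper stores the per-leaf non-batch shapes as static metadata, much as the `_structure` field of `SampleWrapperExample` holds `ShapeDtypeStruct`s. `SampleWrapperSketch` and `stack_like_scan` are hypothetical names. The latter only mimics how `jax.lax.scan` stacks its per-step outputs along a new leading axis while reusing the pytree's static metadata. A stored batch-dim count would go stale after such stacking, whereas one derived from the non-batch shapes stays correct:

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class SampleWrapperSketch:
    """Hypothetical stand-in for a samples wrapper (not the PR's class)."""
    sub_states: tuple      # arrays, each of shape (*batch, *sample_shape)
    sample_shapes: tuple   # static per-leaf non-batch shapes, e.g. ((4,), (1,))

    @property
    def n_batch_dims(self) -> int:
        # Derived from the data, so it stays correct when axes are prepended.
        counts = {s.ndim - len(shape)
                  for s, shape in zip(self.sub_states, self.sample_shapes)}
        assert len(counts) == 1, "all sub-states must share the same batch dims"
        return counts.pop()

def stack_like_scan(xs):
    """Stack per-step outputs along a new leading axis, reusing the static
    metadata of the first step unchanged (as a scan over a pytree would)."""
    stacked = tuple(np.stack(leaves) for leaves in zip(*(x.sub_states for x in xs)))
    return SampleWrapperSketch(stacked, xs[0].sample_shapes)

x = SampleWrapperSketch((np.zeros((23, 4)), np.zeros((23, 1))), ((4,), (1,)))
print(x.n_batch_dims)  # 1

y = stack_like_scan([x] * 5)   # e.g. 5 scan steps
print(y.n_batch_dims)          # 2
print(y.sub_states[0].shape)   # (5, 23, 4)
```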

That information is also available from the Hilbert space, so in principle all the tree maps etc. could be moved from the wrapper class into NetKet itself, which would then support pytrees directly.
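The sketch below shows roughly what such a NetKet-side helper could look like if the per-sample shapes came from the Hilbert space rather than the wrapper. It assumes samples are nested tuples of arrays. `tree_map_with_shapes` and `reshape_batch` are hypothetical names, and a real version would use `jax.tree.map` on arbitrary pytrees:

```python
import numpy as np

def tree_map_with_shapes(f, tree, shapes):
    """Minimal tree map over nested tuples (a stand-in for jax.tree.map),
    pairing each array leaf with its per-sample shape."""
    if isinstance(tree, tuple):
        return tuple(tree_map_with_shapes(f, t, s) for t, s in zip(tree, shapes))
    return f(tree, shapes)

def reshape_batch(samples, sample_shapes, batch_shape):
    """Hypothetical NetKet-side helper: give every leaf the batch shape
    `batch_shape` while keeping its trailing per-sample shape."""
    return tree_map_with_shapes(
        lambda a, s: a.reshape((*batch_shape, *s)), samples, sample_shapes
    )

# Per-sample shapes as a Hilbert space could report them (nested on purpose).
sample_shapes = (((4,), (1,)), (2,))
samples = ((np.zeros((2, 3, 4)), np.zeros((2, 3, 1))), np.zeros((2, 3, 2)))

flat = reshape_batch(samples, sample_shapes, (-1,))
print(flat[0][0].shape, flat[0][1].shape, flat[1].shape)  # (6, 4) (6, 1) (6, 2)
```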

Example:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

import netket as nk
from netket.utils.samples_pytree import SampleWrapperExample

# Composite Hilbert space: 4 Fock sites times 1 spin-1/2 site
g1 = nk.graph.Chain(4)
g2 = nk.graph.Chain(1, pbc=False)
hi1 = nk.hilbert.Fock(n_max=8, N=g1.n_nodes)
hi2 = nk.hilbert.Spin(s=1/2, N=g2.n_nodes)
hi = hi1 * hi2

# A batch of 23 random states, returned as a SampleWrapperExample pytree
k = jax.random.key(123)
x = hi.random_state(k, (23,))
print(jax.tree.map(lambda x: x.shape, x))

# One RBM per sub-space (module classes, instantiated inside ProdModel)
ma1 = nk.models.RBM
ma2 = nk.models.RBM

class ProdModel(nn.Module):
    models: tuple[type[nn.Module], ...]

    @nn.compact
    def __call__(self, x):
        assert isinstance(x, SampleWrapperExample)
        assert len(x.sub_states) == len(self.models)
        # Log-amplitudes add, so the wavefunction is a product over sub-spaces
        return sum(mi()(xi) for mi, xi in zip(self.models, x.sub_states))

ma = ProdModel((ma1, ma2))

sa = nk.sampler.MetropolisLocal(hi)
vs = nk.vqs.MCState(sa, ma)
vs.sample()

# Quantum geometric tensor, applied to the parameter pytree
S = vs.quantum_geometric_tensor()
S @ vs.parameters
```

Output of the `print` call:

```
SampleWrapperExample(sub_states=((23, 4), (23, 1)), _structure=(ShapeDtypeStruct(shape=(4,), dtype=int8), ShapeDtypeStruct(shape=(1,), dtype=int8)))
```

@gcarleo

@PhilipVinc (Member)

(I am in favour of this. Maybe we should have a short discussion about the class.)
