Here's an example:
def f(x):
    """Sow the constant 0 under name 'x' with tag 'tag', then return `x` unchanged.

    The return value of `harvest.sow` is discarded; the sown value is what the
    `harvest.reap(..., tag='tag')` calls in this example are meant to collect.
    """
    harvest.sow(0, name='x', tag='tag')
    return x
# Reap the value sown inside `f`, both when `f` is called directly and when it
# is wrapped in JAX transformations / control-flow primitives.
# NOTE(review): assumes `jax` and oryx's `harvest` module are already imported
# (presumably `from oryx.core.interpreters import harvest`) — confirm.
print(harvest.reap(f, tag='tag')(1))
# `jax.remat(f)` and `jax.jit(f)` return transformed functions; they must be
# applied to `x`, otherwise the lambda returns a callable instead of a value.
print(harvest.reap(lambda x: jax.remat(f)(x), tag='tag')(1))
print(harvest.reap(lambda x: jax.jit(f)(x), tag='tag')(1))
print(harvest.reap(lambda x: jax.lax.cond(x == 1, f, f, x), tag='tag')(1))
# Loop bounds are Python constants.
print(
    harvest.reap(
        lambda x: jax.lax.fori_loop(0, 1, lambda i, xx: f(xx), x), tag='tag'
    )(1)
)
# Upper bound is the reaped function's argument `x`.
print(
    harvest.reap(
        lambda x: jax.lax.fori_loop(0, x, lambda i, xx: f(xx), x), tag='tag'
    )(1)
)
Only the first of these `reap` calls returns a result containing `'x': 0`; the others do not.
Here's an example:
The first of these reaps contains `'x': 0`, but the others don't.