When planting into a remat, if the planted object has a different aval than the target, the result is an object whose trace-time and runtime avals don't match.
For example, this prints ((3,), dtype('int32')) followed by (2,) float64:
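import jax
import numpy as np
from oryx.core.interpreters import harvest  # assumed import path for harvest

def f(x):
    x = harvest.sow(x, tag='tag', name='x')
    print((x.shape, x.dtype))  # trace-time aval
    jax.debug.callback(lambda x: print(x.shape, x.dtype), x)  # runtime aval
    return x

harvest.plant(jax.remat(f), tag='tag')({'x': np.zeros(2)}, np.zeros(3, int))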
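Until the mismatch is handled inside plant, a caller could guard against it by comparing shapes and dtypes before planting. The following is only a minimal sketch of that idea: `aval_mismatches` is a hypothetical helper, not part of harvest or JAX, and it assumes the caller already knows the expected shape and dtype for each sown name.

```python
def aval_mismatches(planted, expected):
    """Return {name: (planted_aval, expected_aval)} for entries that differ.

    Both arguments map sow names to any object exposing `.shape` and `.dtype`
    (for example NumPy arrays, JAX arrays, or jax.ShapeDtypeStruct).
    Hypothetical helper for illustration; not a harvest API.
    """
    def aval(v):
        # Normalize to a plain (shape, dtype-name) pair so different array
        # types can be compared with each other.
        return (tuple(v.shape), str(v.dtype))

    return {
        name: (aval(value), aval(expected[name]))
        for name, value in planted.items()
        if name in expected and aval(value) != aval(expected[name])
    }
```

In the example above, calling `aval_mismatches({'x': np.zeros(2)}, {'x': np.zeros(3, int)})` would flag `x`, since both the shape and the dtype differ. An empty result means every planted value matches its target.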