
Bad avals planted into remat #69

@jkramar

Description

When planting into a `jax.remat`-wrapped function, if the planted value has a different aval (shape or dtype) than the sown value it replaces, the result is a value whose trace-time aval and runtime aval don't match.

For example, this prints `((3,), dtype('int32'))` at trace time, followed by `(2,) float64` at runtime:

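```python
import jax
import numpy as np
from oryx.core.interpreters import harvest  # assumed: Oryx's harvest module

def f(x):
  x = harvest.sow(x, tag='tag', name='x')
  print((x.shape, x.dtype))  # trace-time aval of the sown value
  jax.debug.callback(lambda x: print(x.shape, x.dtype), x)  # runtime value
  return x

harvest.plant(jax.remat(f), tag='tag')({'x': np.zeros(2)}, np.zeros(3, int))
```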
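One way a mismatch like this could be surfaced instead of passing silently is to compare each planted value's shape and dtype against the aval of the sown value it replaces before substituting it. The sketch below shows that comparison in isolation. The helper name `find_aval_mismatches` and the `expected` mapping of sow names to `(shape, dtype)` pairs are hypothetical and not part of any harvest API; how the expected avals would be obtained inside `plant` is not shown.

```python
import numpy as np

def find_aval_mismatches(expected, plants):
    """Return a description of every planted value whose shape/dtype differs
    from the expected aval of the sown value with the same name.

    Hypothetical helper: `expected` maps sow names to (shape, dtype) pairs.
    """
    mismatches = []
    for name, value in plants.items():
        if name not in expected:
            continue  # plants for names that were never sown are skipped here
        exp_shape, exp_dtype = expected[name]
        arr = np.asarray(value)
        exp_shape, exp_dtype = tuple(exp_shape), np.dtype(exp_dtype)
        if arr.shape != exp_shape or arr.dtype != exp_dtype:
            mismatches.append(
                f"{name!r}: planted {arr.shape} {arr.dtype}, "
                f"expected {exp_shape} {exp_dtype}")
    return mismatches

# The situation from the report: sown x is int32[3], planted x is float64[2].
for msg in find_aval_mismatches({'x': ((3,), np.int32)}, {'x': np.zeros(2)}):
    print(msg)
```

A plant check built this way could raise an error on a non-empty result rather than letting the trace-time and runtime avals diverge.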