EnzymeAD/Enzyme-JAX #1511 (Closed)
Description
Compilation fails when computing gradients with Enzyme through Reactant.jl if the input is a vector. The failure occurs only for an input of shape (2,). The same gradient compiles and runs for higher-dimensional inputs such as (2, 1, 16) and (2, 1).
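The failing op in the log is a `stablehlo.multiply` emitted inside the `tanh` call. StableHLO elementwise ops require operands of identical type, while Julia broadcasting will combine a `(2,)` array with a `(2, 1)` array without complaint. The plain-Julia sketch below uses hypothetical helpers `strict_mul` and `tanh_pullback` and no Reactant code. It shows how a cotangent that picked up an extra unit dimension would violate that rule in the tanh reverse pass. Whether this is the actual cause inside Reactant/Enzyme-JAX is an assumption; the log does not establish it.

```julia
# Hypothetical helpers for illustration; not Reactant, Enzyme, or StableHLO API.

# Mimics StableHLO's rule for elementwise ops: operand shapes must be identical,
# with no implicit broadcasting.
function strict_mul(a::AbstractArray, b::AbstractArray)
    size(a) == size(b) ||
        throw(DimensionMismatch("multiply: operand shapes $(size(a)) and $(size(b)) differ"))
    return a .* b
end

# Reverse-mode rule for y = tanh.(x): dx = dy .* (1 .- y .^ 2)
tanh_pullback(dy, y) = strict_mul(dy, 1 .- y .^ 2)

x = Float32[0.5, -1.0]
y = tanh.(x)

dy_vec = ones(Float32, 2)       # cotangent with shape (2,), matching y
dy_col = ones(Float32, 2, 1)    # same values, shape (2, 1)

dx = tanh_pullback(dy_vec, y)   # shapes agree, so this succeeds

# Julia broadcasting silently combines (2, 1) with (2,) into a (2, 1) result ...
promoted = dy_col .* (1 .- y .^ 2)

# ... whereas the strict version rejects the mismatch, as the MLIR verifier does
rejected = try
    tanh_pullback(dy_col, y)
    false
catch err
    err isa DimensionMismatch
end
```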
Minimal Example
```julia
using Lux, Enzyme, Reactant, Random

Reactant.set_default_backend("cpu")

const cdev = cpu_device()
const xdev = reactant_device(; force=true)

model = Chain(
    Dense(2, 16, tanh),
    Dense(16, 16, tanh),
    Dense(16, 1)
)

ps_model, st_model = Lux.setup(Random.default_rng(), model)
ps_model_ra = ps_model |> xdev
st_model_ra = st_model |> xdev
model_stateful = StatefulLuxLayer(model, ps_model_ra, st_model_ra)

# Forward passes compile for all three input shapes
m = @jit model_stateful(xdev(randn(Float32, 2, 1, 16)))
m = @jit model_stateful(xdev(randn(Float32, 2, 1)))
m = @jit model_stateful(xdev(randn(Float32, 2)))

# Gradient of the summed model output with respect to the input x
function ∂m∂x(smodel, x)
    Enzyme.gradient(Enzyme.Reverse, sum ∘ smodel, x)[1]
end

∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2, 1, 16)))  # works
∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2, 1)))      # works
# failed
∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2)))
```

The last line gives an error:
```
loc(callsite("tanh_fast/tanh"("~/.julia/packages/Reactant/IgTfV/ext/ReactantNNlibExt/Implementations.jl":6:0) at "traced_call/call"("~/.julia/packages/Reactant/IgTfV/src/ControlFlow.jl":8:0))): error: 'stablehlo.multiply' op requires compatible types for all operands and results
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_3fOn9J/module_000_a8bc_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/IgTfV/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/IgTfV/src/mlir/IR/Pass.jl:163
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1315
  [3] run_pass_pipeline!
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1310 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1757
  [5] compile_mlir! (repeats 2 times)
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1572 [inlined]
  [6] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3494
  [7] compile_xla
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3467 [inlined]
  [8] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3569
  [9] top-level scope
    @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:2644
Some type information was truncated. Use `show(err)` to see complete types.
```

Environment
```
(jl_CeLRAA) pkg> st
Status `/tmp/jl_CeLRAA/Project.toml`
  [7da242da] Enzyme v0.13.87
  [b2108857] Lux v1.24.0
  [3c362404] Reactant v0.2.171
  [9a3f8284] Random v1.11.0
```
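Because the (2, 1) input compiles, one possible stop-gap is to view the vector as a one-column matrix before differentiating and then flatten the result. This gives the same gradient mathematically, because a `Dense` chain produces the same numbers for an n-vector and for the matching n×1 matrix. The sketch below is a minimal, untested-against-Reactant illustration. It uses a hypothetical helper `via_column`, which is not part of Lux, Enzyme, or Reactant, and checks it on a small matrix function with a closed-form gradient.

```julia
# Hypothetical helper: evaluate a function written for matrices on a vector by
# viewing the vector as a one-column matrix and flattening the result.
# Not part of Lux, Enzyme, or Reactant.
function via_column(f, x::AbstractVector)
    X = reshape(x, :, 1)      # (n,) -> (n, 1), shares memory with x
    return vec(f(X))          # (n, 1) -> (n,)
end

# Self-check with a matrix function whose gradient is known in closed form:
# for g(X) = sum(tanh.(W * X .+ b)), dg/dX = W' * (1 .- tanh.(W * X .+ b) .^ 2)
W = Float32[1 2; 3 4; 5 6]
b = Float32[0.1, 0.2, 0.3]
grad_g(X) = W' * (1 .- tanh.(W * X .+ b) .^ 2)

x = Float32[0.3, -0.7]
gx = via_column(grad_g, x)    # gradient for the vector input, shape (2,)
```

Applied to the minimal example, this would read `via_column(X -> ∂m∂x(model_stateful, X), x)` inside the traced function. Equivalently, you could define `∂m∂x_vec(smodel, x) = vec(∂m∂x(smodel, reshape(x, :, 1)))` and call `@jit ∂m∂x_vec(model_stateful, xdev(randn(Float32, 2)))`. This relies on the (2, 1) path continuing to compile and has not been verified with the package versions listed above.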