Error "failed to run pass manager on module" only on Vector input #1521

@jacobleft

Description

Compilation fails when computing gradients with Enzyme under Reactant.jl if the input is a vector. The error occurs only when the input tensor has shape (2,) (a vector); the same gradient call works with higher-dimensional inputs such as (2,1,16) and (2,1).

Minimal Example

using Lux, Enzyme, Reactant, Random

Reactant.set_default_backend("cpu")
const cdev = cpu_device()
const xdev = reactant_device(; force=true)

model = Chain(
    Dense(2, 16, tanh),
    Dense(16, 16, tanh),
    Dense(16, 1)
)

ps_model, st_model = Lux.setup(Random.default_rng(), model)
ps_model_ra = ps_model |> xdev
st_model_ra = st_model |> xdev
model_stateful = StatefulLuxLayer(model, ps_model_ra, st_model_ra)

m = @jit model_stateful(xdev(randn(Float32, 2, 1, 16)))
m = @jit model_stateful(xdev(randn(Float32, 2, 1)))
m = @jit model_stateful(xdev(randn(Float32, 2)))

function ∂m∂x(smodel, x)
    Enzyme.gradient(Enzyme.Reverse, sum ∘ smodel, x)[1]
end


∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2, 1, 16)))
∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2, 1)))
# This call fails:
∂m∂x_val = @jit ∂m∂x(model_stateful, xdev(randn(Float32, 2)))

The last line gives an error:

loc(callsite("tanh_fast/tanh"("~/.julia/packages/Reactant/IgTfV/ext/ReactantNNlibExt/Implementations.jl":6:0) at "traced_call/call"("~/.julia/packages/Reactant/IgTfV/src/ControlFlow.jl":8:0))): error: 'stablehlo.multiply' op requires compatible types for all operands and results
┌ Error: Compilation failed, MLIR module written to /tmp/reactant_3fOn9J/module_000_a8bc_post_all_pm.mlir
└ @ Reactant.MLIR.IR ~/.julia/packages/Reactant/IgTfV/src/mlir/IR/Pass.jl:119
ERROR: "failed to run pass manager on module"
Stacktrace:
 [1] run!(pm::Reactant.MLIR.IR.PassManager, mod::Reactant.MLIR.IR.Module, key::String)
   @ Reactant.MLIR.IR ~/.julia/packages/Reactant/IgTfV/src/mlir/IR/Pass.jl:163
 [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String, key::String; enable_verifier::Bool)
   @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1315
 [3] run_pass_pipeline!
   @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1310 [inlined]
 [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, kwargs::@Kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1757
 [5] compile_mlir! (repeats 2 times)
   @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:1572 [inlined]
 [6] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3494
 [7] compile_xla
   @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3467 [inlined]
 [8] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{})
   @ Reactant.Compiler ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:3569
 [9] top-level scope
   @ ~/.julia/packages/Reactant/IgTfV/src/Compiler.jl:2644
Some type information was truncated. Use `show(err)` to see complete types.
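
Since the (2,1) input compiles, one possible stopgap is to route the vector through that path and flatten the result afterwards. The following is an untested sketch, not a fix for the underlying bug: `∂m∂x_vec` is a hypothetical helper name (not part of Lux, Enzyme, or Reactant), and it assumes `reshape` and `vec` trace cleanly here.

```julia
using Enzyme

# Same gradient helper as in the minimal example.
∂m∂x(smodel, x) = Enzyme.gradient(Enzyme.Reverse, sum ∘ smodel, x)[1]

# Hypothetical wrapper: reshape the vector to (features, 1) so the gradient
# is taken on the 2-D input that is reported to compile, then flatten back.
function ∂m∂x_vec(smodel, x)
    x_mat = reshape(x, :, 1)          # (2,) -> (2, 1)
    return vec(∂m∂x(smodel, x_mat))   # (2, 1) -> (2,)
end

# Usage, assuming model_stateful and xdev from the minimal example:
# ∂m∂x_val = @jit ∂m∂x_vec(model_stateful, xdev(randn(Float32, 2)))
```

The reshape is done outside `Enzyme.gradient`, so the differentiated call sees the same (2,1) input as the working case.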

Environment

(jl_CeLRAA) pkg> st
Status `/tmp/jl_CeLRAA/Project.toml`
  [7da242da] Enzyme v0.13.87
  [b2108857] Lux v1.24.0
  [3c362404] Reactant v0.2.171
  [9a3f8284] Random v1.11.0

Labels: bug, reactant
