I'm attempting to use the Apple Metal PJRT plugin directly, feeding it StableHLO code, but it doesn't seem to understand the current StableHLO specification the way the CPU and CUDA plugins do. Even a trivial StableHLO module like the one below fails:
func.func @main(%arg0: tensor<i1>) -> tensor<i1> {
  %0 = "stablehlo.not"(%arg0) : (tensor<i1>) -> tensor<i1>
  "stablehlo.return"(%0) : (tensor<i1>) -> ()
}
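For comparison, the StableHLO that JAX itself emits for the same operation can be printed with a short script like the one below. This is a sketch that assumes a recent `jax` install; it only lowers the function and never compiles or runs it, and the printed module may differ from the one above in attributes and in how the function terminator is spelled, depending on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Lower (without compiling or running) a NOT on a scalar boolean, i.e. tensor<i1>.
# lax.bitwise_not on a bool input lowers to the stablehlo.not op.
lowered = jax.jit(lax.bitwise_not).lower(jnp.array(True))

# as_text() returns the lowered module as MLIR text (StableHLO in recent JAX versions).
stablehlo_text = lowered.as_text()
print(stablehlo_text)
```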
I'm using the current Apple Metal PJRT plugin, v0.1.1 (developer.apple.com/metal/jax/), distributed on PyPI as pypi.org/project/jax-metal/.
Please: