-
Notifications
You must be signed in to change notification settings - Fork 80
Open
Description
MWE
julia> using Lux, Reactant
julia> devices = Reactant.devices()
julia> mesh = Sharding.Mesh(reshape(devices, length(devices)), (:batch, ))
Reactant.Sharding.Mesh{1, Vector{Int64}}([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], (:batch,), (8,))
julia> setup_device = Lux.reactant_device(; sharding = Sharding.Replicated(mesh))
(::ReactantDevice{Missing, Missing, Reactant.Sharding.Replicated{Reactant.Sharding.Mesh{1, Vector{Int64}}}, Missing}) (generic function with 1 method)
julia> a1 = [ 1.0, 2.0, 3.0 ]
3-element Vector{Float64}:
1.0
2.0
3.0
julia> a2 = [ 3.0, 2.0, 1.0 ]
3-element Vector{Float64}:
3.0
2.0
1.0
julia> b1 = setup_device(a1)
3-element ConcreteIFRTArray{Float64,1} with "mhlo.sharding = {replicated}":
1.0
2.0
3.0
julia> b2 = setup_device(a2)
3-element ConcreteIFRTArray{Float64,1} with "mhlo.sharding = {replicated}":
3.0
2.0
1.0
julia> Lux.MLDataDevices.get_device(b1) == Lux.MLDataDevices.get_device(b2)
ERROR: type Nothing has no field device
Stacktrace:
[1] getproperty
@ ./Base.jl:49 [inlined]
[2] ==(x::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing}, y::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing})
@ MLDataDevices ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:90
[3] top-level scope
@ REPL[36]:1
Some type information was truncated. Use `show(err)` to see complete types.

The culprit is this code, which expects the `device` field to be set to `missing`, while in reality it is set to `nothing`:
julia> Lux.MLDataDevices.get_device(b1) |> dump
ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}, Missing}(Reactant.XLA.IFRT.Client(Ptr{Nothing} @0x00000000222eea30), nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}(ConcreteIFRTArray{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}}}}, Nothing}([1.0, 2.0, 3.0]) => Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}(Reactant.Sharding.Mesh{1, Vector{Int64}}([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], (:batch,), (8,)), Vector{Union{Nothing, Symbol}}[[nothing]], (true,), (-1,), Vector{Union{Nothing, Tuple{Int64, Int64}}}[[nothing]]))) (function of type ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}, Missing})
client: Reactant.XLA.IFRT.Client
client: Ptr{Nothing} @0x00000000222eea30
device: Nothing nothing
sharding: IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}
ht: Memory{Any}
length: Int64 32
ptr: Ptr{Nothing} @0x000075b7b9c38380
count: Int64 1
ndel: Int64 0

This causes problems when using the `TrainState` API, as it throws immediately when trying to combine the devices, e.g.:
julia> using Random, Optimisers
julia> model = Dense(32 => 64)
Dense(32 => 64) # 2_112 parameters
julia> rng = Random.default_rng()
TaskLocalRNG()
julia> optimizer = OptimiserChain(ClipNorm(0.5f0), Adam(0.0005f0))
OptimiserChain(ClipNorm{Float32, Float64}(0.5, 2.0, true), Adam(eta=0.0005, beta=(0.9, 0.999), epsilon=1.0e-8))
julia> ps, st = Lux.setup(rng, model) |> setup_device
((weight = ConcreteIFRTArray{Float32, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}, Nothing}(Float32[-0.23557188 0.28043154 … 0.14598301 0.01099818; -0.14897338 0.099628106 … 0.2255169 -0.28783646; … ; -0.20580283 -0.013407671 … -0.086372495 0.2402338; 0.28477693 -0.034549087 … 0.22345623 -0.17259249]), bias = ConcreteIFRTArray{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}}}}, Nothing}(Float32[-0.019760212, -0.00045006513, -0.13889635, 0.12555468, -0.0103946775, 0.16893348, 0.1503944, -0.10110072, -0.07670705, 0.1038366 … -0.1260984, -0.12039028, -0.057218667, -0.15630175, 0.14630395, 0.04193333, 0.023286724, 0.030515056, 0.061150294, 0.028806655])), NamedTuple())
julia> train_state = Training.TrainState(model, ps, st, optimizer)
ERROR: type Nothing has no field device
Stacktrace:
[1] getproperty
@ ./Base.jl:49 [inlined]
[2] ==
@ ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:90 [inlined]
[3] combine_devices(dev1::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing}, dev2::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing})
@ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:150
[4] macro expansion
@ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:262 [inlined]
[5] unrolled_mapreduce
@ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:249 [inlined]
[6] unrolled_mapreduce(f::typeof(get_device), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{ConcreteIFRTArray{…}, ConcreteIFRTArray{…}})
@ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:240
[7] get_device
@ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:214 [inlined]
[8] get_device
@ ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:481 [inlined]
[9] Lux.Training.TrainState(model::Dense{…}, ps::@NamedTuple{…}, st::@NamedTuple{}, optimizer::OptimiserChain{…})
@ Lux.Training ~/.julia/packages/Lux/IjVGV/src/helpers/training.jl:67
[10] top-level scope
@ REPL[46]:1
Some type information was truncated. Use `show(err)` to see complete types.

Metadata
Assignees
Labels
No labels