Reactant get_device with sharding throws error inside of MLDataDevices, impossible to use with TrainState API #1520

@bvdmitri

Description

MWE

julia> using Lux, Reactant

julia> devices = Reactant.devices()

julia> mesh = Sharding.Mesh(reshape(devices, length(devices)), (:batch, ))
Reactant.Sharding.Mesh{1, Vector{Int64}}([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], (:batch,), (8,))

julia> setup_device = Lux.reactant_device(; sharding = Sharding.Replicated(mesh))
(::ReactantDevice{Missing, Missing, Reactant.Sharding.Replicated{Reactant.Sharding.Mesh{1, Vector{Int64}}}, Missing}) (generic function with 1 method)

julia> a1 = [ 1.0, 2.0, 3.0 ]
3-element Vector{Float64}:
 1.0
 2.0
 3.0

julia> a2 = [ 3.0, 2.0, 1.0 ]
3-element Vector{Float64}:
 3.0
 2.0
 1.0

julia> b1 = setup_device(a1)
3-element ConcreteIFRTArray{Float64,1} with "mhlo.sharding = {replicated}":
 1.0
 2.0
 3.0

julia> b2 = setup_device(a2)
3-element ConcreteIFRTArray{Float64,1} with "mhlo.sharding = {replicated}":
 3.0
 2.0
 1.0

julia> Lux.MLDataDevices.get_device(b1) == Lux.MLDataDevices.get_device(b2)
ERROR: type Nothing has no field device
Stacktrace:
 [1] getproperty
   @ ./Base.jl:49 [inlined]
 [2] ==(x::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing}, y::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing})
   @ MLDataDevices ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:90
 [3] top-level scope
   @ REPL[36]:1
Some type information was truncated. Use `show(err)` to see complete types.

The culprit is the == method for ReactantDevice in MLDataDevices (src/public.jl:90 in the stack trace above). It expects an unset device field to be missing, while in reality the field is set to nothing:

julia> Lux.MLDataDevices.get_device(b1) |> dump
ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}, Missing}(Reactant.XLA.IFRT.Client(Ptr{Nothing} @0x00000000222eea30), nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}(ConcreteIFRTArray{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}}}}, Nothing}([1.0, 2.0, 3.0]) => Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}(Reactant.Sharding.Mesh{1, Vector{Int64}}([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], (:batch,), (8,)), Vector{Union{Nothing, Symbol}}[[nothing]], (true,), (-1,), Vector{Union{Nothing, Tuple{Int64, Int64}}}[[nothing]]))) (function of type ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}, Missing})
  client: Reactant.XLA.IFRT.Client
    client: Ptr{Nothing} @0x00000000222eea30
  device: Nothing nothing
  sharding: IdDict{Union{ConcreteIFRTArray, ConcreteIFRTNumber, ConcretePJRTArray, ConcretePJRTNumber}, Reactant.Sharding.AbstractSharding}
    ht: Memory{Any}
      length: Int64 32
      ptr: Ptr{Nothing} @0x000075b7b9c38380
    count: Int64 1
    ndel: Int64 0
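
To illustrate the failure mode in isolation, here is a minimal sketch that does not depend on Reactant or MLDataDevices. The names ToyHandle, ToyDevice, fragile_equal, robust_equal and isunset are hypothetical and are not the library's actual code. The sketch only assumes that the comparison treats missing as the sole "unset" marker and then reads a nested device field, which is what the error message suggests.

```julia
# Hypothetical stand-ins, NOT the real ReactantDevice or its == method.
struct ToyHandle
    device::Int
end

struct ToyDevice{D}
    device::D   # a ToyHandle, `missing`, or `nothing`
end

# Treats only `missing` as "unset"; a `nothing` field falls through to the
# nested field access and throws ("type Nothing has no field device").
function fragile_equal(x::ToyDevice, y::ToyDevice)
    if x.device !== missing
        y.device === missing && return false
        return x.device.device == y.device.device
    end
    return y.device === missing
end

# Treats both `missing` and `nothing` as "unset" (an assumed design choice).
isunset(d) = d === missing || d === nothing

function robust_equal(x::ToyDevice, y::ToyDevice)
    isunset(x.device) && return isunset(y.device)
    isunset(y.device) && return false
    return x.device.device == y.device.device
end

a = ToyDevice(nothing)
b = ToyDevice(nothing)

# Record whether the fragile comparison throws, without depending on the
# exact exception type (it differs between Julia versions).
threw = try
    fragile_equal(a, b)
    false
catch
    true
end

println("fragile comparison threw: ", threw)
println("robust comparison result: ", robust_equal(a, b))
```

Whether nothing and missing should compare as equivalent "unset" values, or whether get_device should return missing in the first place, is a decision for the maintainers. The sketch only shows that guarding against both values avoids the field access on nothing.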

This mismatch makes the TrainState API unusable with sharding: the constructor calls get_device on the parameters and throws as soon as it tries to combine the resulting devices, e.g.

julia> using Random, Optimisers

julia> model = Dense(32 => 64)
Dense(32 => 64)              # 2_112 parameters

julia> rng = Random.default_rng()
TaskLocalRNG()

julia> optimizer = OptimiserChain(ClipNorm(0.5f0), Adam(0.0005f0))
OptimiserChain(ClipNorm{Float32, Float64}(0.5, 2.0, true), Adam(eta=0.0005, beta=(0.9, 0.999), epsilon=1.0e-8))

julia> ps, st = Lux.setup(rng, model) |> setup_device
((weight = ConcreteIFRTArray{Float32, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}, Nothing}(Float32[-0.23557188 0.28043154  0.14598301 0.01099818; -0.14897338 0.099628106  0.2255169 -0.28783646;  ; -0.20580283 -0.013407671  -0.086372495 0.2402338; 0.28477693 -0.034549087  0.22345623 -0.17259249]), bias = ConcreteIFRTArray{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{1, Reactant.Sharding.Mesh{1, Vector{Int64}}}, Vector{Tuple{UnitRange{Int64}}}}, Nothing}(Float32[-0.019760212, -0.00045006513, -0.13889635, 0.12555468, -0.0103946775, 0.16893348, 0.1503944, -0.10110072, -0.07670705, 0.1038366    -0.1260984, -0.12039028, -0.057218667, -0.15630175, 0.14630395, 0.04193333, 0.023286724, 0.030515056, 0.061150294, 0.028806655])), NamedTuple())

julia> train_state = Training.TrainState(model, ps, st, optimizer)
ERROR: type Nothing has no field device
Stacktrace:
  [1] getproperty
    @ ./Base.jl:49 [inlined]
  [2] ==
    @ ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:90 [inlined]
  [3] combine_devices(dev1::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing}, dev2::ReactantDevice{Reactant.XLA.IFRT.Client, Nothing, IdDict{…}, Missing})
    @ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:150
  [4] macro expansion
    @ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:262 [inlined]
  [5] unrolled_mapreduce
    @ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:249 [inlined]
  [6] unrolled_mapreduce(f::typeof(get_device), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{ConcreteIFRTArray{…}, ConcreteIFRTArray{…}})
    @ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:240
  [7] get_device
    @ ~/.julia/packages/MLDataDevices/NeohJ/src/internal.jl:214 [inlined]
  [8] get_device
    @ ~/.julia/packages/MLDataDevices/NeohJ/src/public.jl:481 [inlined]
  [9] Lux.Training.TrainState(model::Dense{…}, ps::@NamedTuple{}, st::@NamedTuple{}, optimizer::OptimiserChain{…})
    @ Lux.Training ~/.julia/packages/Lux/IjVGV/src/helpers/training.jl:67
 [10] top-level scope
    @ REPL[46]:1
Some type information was truncated. Use `show(err)` to see complete types.
