diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b3..4a23b7cf 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,58 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +ValueChoiceMap +``` + +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has its value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: +```julia +value = choicemap[:a] +``` +A choicemap may also have a non-value choicemap stored at an address. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a valuechoicemap stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` +One can think of `ValueChoiceMap`s at storing being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher. + +### Interface + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -23,7 +68,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. @@ -50,3 +95,35 @@ choicemap set_value! set_submap! ``` + +### Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). \ No newline at end of file diff --git a/src/Gen.jl b/src/Gen.jl index 9f3da9e3..fa239359 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -37,7 +37,7 @@ include("backprop.jl") include("address.jl") # abstract and built-in concrete choice map data types -include("choice_map.jl") +include("choice_map/choice_map.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index b7891b40..00000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1009 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? -function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, (leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/choice_map/array_interface.jl b/src/choice_map/array_interface.jl new file mode 100644 index 00000000..cf9d0bd0 --- /dev/null +++ b/src/choice_map/array_interface.jl @@ -0,0 +1,106 @@ +### interface for to_array and fill_array ### + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: ChoiceMap` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::ValueChoiceMap{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) < start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end +function _fill_array!(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T <: ChoiceMap` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::ValueChoiceMap, arr::Vector, start_idx::Int) + (1, ValueChoiceMap(arr[start_idx])) +end +function _from_array(c::ValueChoiceMap{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, ValueChoiceMap(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/choice_map/choice_map.jl b/src/choice_map/choice_map.jl new file mode 100644 index 00000000..213bc5f8 --- /dev/null +++ b/src/choice_map/choice_map.jl @@ -0,0 +1,283 @@ +######################### +# choice map interface # +######################### + +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + abstract type ChoiceMap end + +Abstract type for maps from hierarchical addresses to values. +""" +abstract type ChoiceMap end + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `ValueChoiceMap`s.) +""" +function get_submaps_shallow end + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +function get_submap end + +# provide _get_submap so when users overwrite get_submap(choices::CustomChoiceMap, addr::Pair) +# they can just call _get_submap for convenience if they want +@inline function _get_submap(choices::ChoiceMap, addr::Pair) + (first, rest) = addr + submap = get_submap(choices, first) + get_submap(submap, rest) +end + +""" + has_value(choices::ChoiceMap) + +Returns true if `choices` is a `ValueChoiceMap`. + + has_value(choices::ChoiceMap, addr) + +Returns true if `choices` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(::ChoiceMap) = false +@inline has_value(c::ChoiceMap, addr) = has_value(get_submap(c, addr)) + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `ValueChoiceMap`; +throws a `ChoiceMapGetValueError` if `choices` is not a `ValueChoiceMap`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `ChoiceMapGetValueError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +""" +schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} + +Return the (top-level) address schema for the given choice map. +""" +function get_address_schema end + +# get_values_shallow and get_nonvalue_submaps_shallow are just filters on get_submaps_shallow +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `ValueChoiceMap`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) +end + +# a choicemap is empty if it has no submaps and no value +Base.isempty(c::ChoiceMap) = all(((addr, submap),) -> isempty(submap), get_submaps_shallow(c)) && !has_value(c) + +""" + EmptyChoiceMap + +A choicemap with no submaps or values. +""" +struct EmptyChoiceMap <: ChoiceMap end + +@inline has_value(::EmptyChoiceMap, addr...) = false +@inline get_value(::EmptyChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_submap(::EmptyChoiceMap, addr) = EmptyChoiceMap() +@inline Base.isempty(::EmptyChoiceMap) = true +@inline get_submaps_shallow(::EmptyChoiceMap) = () +@inline get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() +@inline Base.:(==)(::EmptyChoiceMap, ::EmptyChoiceMap) = true +@inline Base.:(==)(::ChoiceMap, ::EmptyChoiceMap) = false +@inline Base.:(==)(::EmptyChoiceMap, ::ChoiceMap) = false + +""" + ValueChoiceMap + +A leaf-node choicemap. Stores a single value. +""" +struct ValueChoiceMap{T} <: ChoiceMap + val::T +end + +@inline has_value(choices::ValueChoiceMap) = true +@inline get_value(choices::ValueChoiceMap) = choices.val +@inline get_submap(choices::ValueChoiceMap, addr) = EmptyChoiceMap() +@inline get_submaps_shallow(choices::ValueChoiceMap) = () +@inline Base.:(==)(a::ValueChoiceMap, b::ValueChoiceMap) = a.val == b.val +@inline Base.isapprox(a::ValueChoiceMap, b::ValueChoiceMap) = isapprox(a.val, b.val) +@inline get_address_schema(::Type{<:ValueChoiceMap}) = EmptyAddressSchema() + +""" + choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + +Merge two choice maps. + +It is an error if the choice maps both have values at the same address, or if +one choice map has a value at an address that is the prefix of the address of a +value in the other choice map. +""" +function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) + choices = DynamicChoiceMap() + for (key, submap) in get_submaps_shallow(choices1) + set_submap!(choices, key, merge(submap, get_submap(choices2, key))) + end + for (key, submap) in get_submaps_shallow(choices2) + if isempty(get_submap(choices1, key)) + set_submap!(choices, key, submap) + end + end + choices +end +Base.merge(c::ChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ChoiceMap) = c +Base.merge(c::ValueChoiceMap, ::EmptyChoiceMap) = c +Base.merge(::EmptyChoiceMap, c::ValueChoiceMap) = c +Base.merge(::ValueChoiceMap, ::ChoiceMap) = error("ValueChoiceMaps cannot be merged") +Base.merge(::ChoiceMap, ::ValueChoiceMap) = error("ValueChoiceMaps cannot be merged") + +""" +Variadic merge of choice maps. +""" +function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) + reduce(Base.merge, choices_rest; init=choices1) +end + +function Base.:(==)(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if get_submap(b, addr) != submap + return false + end + end + for (addr, submap) in get_submaps_shallow(b) + if get_submap(a, addr) != submap + return false + end + end + return true +end + +function Base.isapprox(a::ChoiceMap, b::ChoiceMap) + for (addr, submap) in get_submaps_shallow(a) + if !isapprox(get_submap(b, addr), submap) + return false + end + end + return true +end + +""" + selected_choices = get_selected(choices::ChoiceMap, selection::Selection) + +Filter the choice map to include only choices in the given selection. + +Returns a new choice map. +""" +function get_selected( + choices::ChoiceMap, selection::Selection) + # TODO: return a `FilteringChoiceMap` which does this filtering lazily! + output = choicemap() + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) && addr in selection + output[addr] = get_value(submap) + else + subselection = selection[addr] + set_submap!(output, addr, get_selected(submap, subselection)) + end + end + output +end + +function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_values = collect(get_values_shallow(choices)) + key_and_submaps = collect(get_nonvalue_submaps_shallow(choices)) + n = length(key_and_values) + length(key_and_submaps) + cur = 1 + for (key, value) in key_and_values + # For strings, `print` is what we want; `Base.show` includes quote marks. + # https://docs.julialang.org/en/v1/base/io-network/#Base.print + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") + cur += 1 + end + for (key, submap) in key_and_submaps + print(io, indent_vert_str) + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) + _show_pretty(io, choices, 0, ()) +end + +export ChoiceMap, ValueChoiceMap, EmptyChoiceMap +export _get_submap, get_submap, get_submaps_shallow +export get_value, has_value +export get_values_shallow, get_nonvalue_submaps_shallow +export get_address_schema, get_selected +export ChoiceMapGetValueError + +include("array_interface.jl") +include("dynamic_choice_map.jl") +include("static_choice_map.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/choice_map/dynamic_choice_map.jl b/src/choice_map/dynamic_choice_map.jl new file mode 100644 index 00000000..0f27c89d --- /dev/null +++ b/src/choice_map/dynamic_choice_map.jl @@ -0,0 +1,149 @@ +####################### +# dynamic assignment # +####################### + +""" + struct DynamicChoiceMap <: ChoiceMap .. end + +A mutable map from arbitrary hierarchical addresses to values. + + choices = DynamicChoiceMap() + +Construct an empty map. + + choices = DynamicChoiceMap(tuples...) + +Construct a map containing each of the given (addr, value) tuples. +""" +struct DynamicChoiceMap <: ChoiceMap + submaps::Dict{Any, ChoiceMap} + function DynamicChoiceMap() + new(Dict()) + end +end + +function DynamicChoiceMap(tuples...) + choices = DynamicChoiceMap() + for (addr, value) in tuples + choices[addr] = value + end + choices +end + +""" + choices = DynamicChoiceMap(other::ChoiceMap) + +Copy a choice map, returning a mutable choice map. +""" +function DynamicChoiceMap(other::ChoiceMap) + choices = DynamicChoiceMap() + for (addr, submap) in get_submaps_shallow(other) + if submap isa ValueChoiceMap + set_submap!(choices, addr, submap) + else + set_submap!(choices, addr, DynamicChoiceMap(submap)) + end + end + choices +end + +DynamicChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a DynamicChoiceMap") + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given address, value tuples. +""" +function choicemap(tuples...) + DynamicChoiceMap(tuples...) +end + +@inline get_submaps_shallow(choices::DynamicChoiceMap) = choices.submaps +@inline get_submap(choices::DynamicChoiceMap, addr) = get(choices.submaps, addr, EmptyChoiceMap()) +@inline get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline Base.isempty(choices::DynamicChoiceMap) = isempty(choices.submaps) + +# mutation (not part of the assignment interface) + +""" + set_value!(choices::DynamicChoiceMap, addr, value) + +Set the given value for the given address. + +Will cause any previous value or sub-assignment at this address to be deleted. +It is an error if there is already a value present at some prefix of the given address. + +The following syntactic sugar is provided: + + choices[addr] = value +""" +function set_value!(choices::DynamicChoiceMap, addr, value) + delete!(choices.submaps, addr) + choices.submaps[addr] = ValueChoiceMap(value) +end + +function set_value!(choices::DynamicChoiceMap, addr::Pair, value) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_value!(choices.submaps[first], rest, value) +end + +""" + set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) + +Replace the sub-assignment rooted at the given address with the given sub-assignment. +Set the given value for the given address. + +Will cause any previous value or sub-assignment at the given address to be deleted. +It is an error if there is already a value present at some prefix of address. +""" +function set_submap!(choices::DynamicChoiceMap, addr, new_node::ChoiceMap) + delete!(choices.submaps, addr) + if !isempty(new_node) + choices.submaps[addr] = new_node + end +end + +function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node::ChoiceMap) + (first, rest) = addr + if !haskey(choices.submaps, first) + choices.submaps[first] = DynamicChoiceMap() + elseif has_value(choices.submaps[first]) + error("Tried to create assignment at $first but there was already a value there.") + end + set_submap!(choices.submaps[first], rest, new_node) +end + +Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) + +function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + choices = DynamicChoiceMap() + keys_sorted = sort(collect(keys(proto_choices.submaps))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.submaps[key], arr, idx) + idx += n_read + choices.submaps[key] = submap + end + (idx - start_idx, choices) +end + +get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() + +export DynamicChoiceMap +export choicemap +export set_value! +export set_submap! \ No newline at end of file diff --git a/src/choice_map/nested_view.jl b/src/choice_map/nested_view.jl new file mode 100644 index 00000000..68add0a0 --- /dev/null +++ b/src/choice_map/nested_view.jl @@ -0,0 +1,80 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::ValueChoiceMap) = get_value(cm) +ChoiceMapNestedView(::EmptyChoiceMap) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/choice_map/static_choice_map.jl b/src/choice_map/static_choice_map.jl new file mode 100644 index 00000000..ff8c01a7 --- /dev/null +++ b/src/choice_map/static_choice_map.jl @@ -0,0 +1,152 @@ +###################### +# static assignment # +###################### + +struct StaticChoiceMap{Addrs, SubmapTypes} <: ChoiceMap + submaps::NamedTuple{Addrs, SubmapTypes} + function StaticChoiceMap(submaps::NamedTuple{Addrs, SubmapTypes}) where {Addrs, SubmapTypes <: NTuple{n, ChoiceMap} where n} + new{Addrs, SubmapTypes}(submaps) + end +end + +function StaticChoiceMap(;addrs_to_vals_and_maps...) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_maps) + maps = Tuple(val_or_map isa ChoiceMap ? val_or_map : ValueChoiceMap(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_maps) + StaticChoiceMap(NamedTuple{addrs}(maps)) +end + +@inline get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.submaps) +@inline get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) +@inline get_submap(choices::StaticChoiceMap, addr::Symbol) = static_get_submap(choices, Val(addr)) + +@generated function static_get_submap(choices::StaticChoiceMap{Addrs, SubmapTypes}, ::Val{A}) where {A, Addrs, SubmapTypes} + if A in Addrs + quote choices.submaps[A] end + else + quote EmptyChoiceMap() end + end +end +@inline static_get_submap(::EmptyChoiceMap, ::Val) = EmptyChoiceMap() + +@inline static_get_value(choices::StaticChoiceMap, v::Val) = get_value(static_get_submap(choices, v)) +@inline static_get_value(::EmptyChoiceMap, ::Val) = throw(ChoiceMapGetValueError()) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +function StaticChoiceMap(other::ChoiceMap) + keys_and_nodes = get_submaps_shallow(other) + if length(keys_and_nodes) > 0 + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) + else + addrs = () + submaps = () + end + StaticChoiceMap(NamedTuple{addrs}(submaps)) +end +StaticChoiceMap(other::ValueChoiceMap) = error("Cannot convert a ValueChoiceMap to a StaticChoiceMap") +StaticChoiceMap(::NamedTuple{(),Tuple{}}) = EmptyChoiceMap() + +# TODO: deep conversion to static choicemap + +""" + choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + +Return an assignment that contains `choices1` as a sub-assignment under `key1` +and `choices2` as a sub-assignment under `key2`. +""" +function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) + StaticChoiceMap(NamedTuple{(key1, key2)}((choices1, choices2))) +end + +""" + (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + +Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any submaps at keys other than `key1` and `key2`. +""" +function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) + if length(collect(get_submaps_shallow(choices))) != 2 + error("Not a pair") + end + (get_submap(choices, key1), get_submap(choices, key2)) +end + +@generated function Base.merge(choices1::StaticChoiceMap{Addrs1, SubmapTypes1}, + choices2::StaticChoiceMap{Addrs2, SubmapTypes2}) where {Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, Type{<:ChoiceMap}}() + addr_to_type2 = Dict{Symbol, Type{<:ChoiceMap}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyChoiceMap) + type2 = get(addr_to_type2, addr, EmptyChoiceMap) + if ((type1 <: ValueChoiceMap && type2 != EmptyChoiceMap) + || (type2 <: ValueChoiceMap && type1 != EmptyChoiceMap)) + error( "One choicemap has a value at address $addr; the other is nonempty at $addr. Cannot merge.") + end + if type1 <: EmptyChoiceMap + push!(submap_exprs, + quote choices2.submaps.$addr end + ) + elseif type2 <: EmptyChoiceMap + push!(submap_exprs, + quote choices1.submaps.$addr end + ) + else + push!(submap_exprs, + quote merge(choices1.submaps.$addr, choices2.submaps.$addr) end + ) + end + end + + quote + StaticChoiceMap(NamedTuple{$merged_addrs}(($(submap_exprs...),))) + end +end + +@generated function _from_array(proto_choices::StaticChoiceMap{Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {T, Addrs, SubmapTypes} + + perm = sortperm(collect(Addrs)) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, $submap_var_name) = _from_array(proto_choices.submaps.$addr, arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticChoiceMap(submaps) + (idx - start_idx, choices) + end +end + +function get_address_schema(::Type{StaticChoiceMap{Addrs, SubmapTypes}}) where {Addrs, SubmapTypes} + StaticAddressSchema(Set(Addrs)) +end + +export StaticChoiceMap +export pair, unpair +export static_get_submap, static_get_value \ No newline at end of file diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index c6f09374..73f22159 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -124,42 +124,33 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end +all_visited(::Selection, ::ValueChoiceMap) = false +all_visited(::AllSelection, ::ValueChoiceMap) = true function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) + if !all_visited(visited[key], submap) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::ValueChoiceMap) = v +get_unvisited(::AllSelection, v::ValueChoiceMap) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(visited[key], submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index df6a5f46..970dac42 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -20,7 +20,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # get return value if constrained diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb..882297e4 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -119,9 +119,6 @@ struct DynamicDSLChoiceMap <: ChoiceMap end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_submap(choices::DynamicDSLChoiceMap, addr) @@ -130,9 +127,10 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) # leaf node, must be a call call = trie[addr] if call.is_choice - throw(KeyError(addr)) + ValueChoiceMap(call.subtrace_or_retval) + else + get_choices(call.subtrace_or_retval) end - get_choices(call.subtrace_or_retval) elseif has_internal_node(trie, addr) # internal node subtrie = get_internal_node(trie, addr) @@ -142,32 +140,12 @@ function get_submap(choices::DynamicDSLChoiceMap, addr) end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) + calls_iter = ( + (key, call.is_choice ? ValueChoiceMap(call.subtrace_or_retval) : get_choices(call.subtrace_or_retval)) for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) + ) + internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) Iterators.flatten((calls_iter, internal_nodes_iter)) end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f2..7acc1630 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -35,7 +35,7 @@ function traceat(state::GFUpdateState, dist::Distribution{T}, # check for constraints at this key constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) + !constrained && check_is_empty(state.constraints, key) # record the previous value as discarded if it is replaced if constrained && has_previous @@ -149,32 +149,22 @@ end function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) + for (key, submap) in get_submaps_shallow(prev_choices) + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it if !(key in visited) - @assert !has_value(discard, key) @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end - for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else subvisited = visited[key] if isempty(subvisited) # none of this submap was visited, so we discard the whole thing - @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) - end + end end end end diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a231f03a..d662dbb7 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ import MacroTools function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != value && error("Check failed: value of observed choice at $key changed") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 23411697..f17d061f 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -14,10 +14,7 @@ function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} end get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(get_submaps_shallow(choices)) > 1 || length(get_values_shallow(choices)) > 0 + if length(get_submaps_shallow(choices)) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 69bb4851..f3875895 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,10 +25,12 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +get_submap(choices::ChoiceAtChoiceMap, addr::Pair) = _get_submap(choices, addr) function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () +get_submap(choices::ChoiceAtChoiceMap, addr) = addr == choices.key ? ValueChoiceMap(choices.value) : EmptyChoiceMap() +get_submaps_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, ValueChoiceMap(choices.value)),) get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 71580073..1f101725 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -84,17 +84,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_submap(choices::RecurseTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if length(get_values_shallow(constraints)) > 0 - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 9b0eb763..3af416ef 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -92,10 +92,6 @@ end end @inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - ############################################ # code shared by vector-shaped combinators # diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 7a0fe384..b352d3ca 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -330,21 +330,22 @@ function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(ValueChoiceMap(trace.$(get_value_fieldname(node)))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(ValueChoiceMap($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) + + quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) + all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) + all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) + quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) end end diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index 713c0863..3c7016a1 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -9,25 +9,8 @@ end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) +@inline get_submap(choices::StaticIRTraceAssmt, key::Symbol) = static_get_submap(choices, Val(key)) @inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) ######################### @@ -36,21 +19,21 @@ end abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() +@inline static_get_value(assmt::StaticIRTraceAssmt, v::Val) = get_value(static_get_submap(assmt, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_getindex(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -153,29 +136,13 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_choices)), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.choice_nodes addr = node.addr value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) + push!(elements, :(($(QuoteNode(addr)), ValueChoiceMap($value)))) end - Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:get_values_shallow)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] for node in ir.call_nodes addr = node.addr subtrace = :(choices.trace.$(get_subtrace_fieldname(node))) @@ -224,30 +191,6 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] end -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, Expr(:(.), Gen, QuoteNode(:static_has_value)), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods -end - function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes @@ -259,13 +202,13 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) :(get_choices(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value + # return a ValueChoiceMap if get_submap is run on an address containing a value for node in ir.choice_nodes push!(methods, Expr(:function, Expr(:call, Expr(:(.), Gen, QuoteNode(:static_get_submap)), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) + Expr(:block, :(ValueChoiceMap(choices.trace.$(get_value_fieldname(node))))))) end methods end @@ -288,20 +231,14 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_submap_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index dc4fddf3..c806bba3 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -454,9 +454,10 @@ function generate_discard!(stmts::Vector{Expr}, end leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + all_keys = (leaf_keys..., internal_keys...) + all_nodes = ([:($(QuoteNode(ValueChoiceMap))($node)) for node in leaf_nodes]..., internal_nodes...) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(all_keys...),)}(($(all_nodes...),))) end push!(stmts, :($discard = $expr)) end diff --git a/test/assignment.jl b/test/assignment.jl index 1bba754a..69890297 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,48 @@ +@testset "ValueChoiceMap" begin + vcm1 = ValueChoiceMap(2) + vcm2 = ValueChoiceMap(2.) + vcm3 = ValueChoiceMap([1,2]) + @test vcm1 isa ValueChoiceMap{Int} + @test vcm2 isa ValueChoiceMap{Float64} + @test vcm3 isa ValueChoiceMap{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == ValueChoiceMap(4) + @test from_array(vcm3, [4, 5]) == ValueChoiceMap([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, ValueChoiceMap(prevfloat(2.))) + @test isapprox(vcm1, ValueChoiceMap(prevfloat(2.))) + @test get_address_schema(typeof(vcm1)) == EmptyAddressSchema() + @test get_address_schema(ValueChoiceMap) == EmptyAddressSchema() + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructor" begin + @test StaticChoiceMap((a=ValueChoiceMap(5), b=ValueChoiceMap(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=ValueChoiceMap(1.), b=ValueChoiceMap([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=ValueChoiceMap(3), d=submap, e=submap)) +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +53,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +83,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == ValueChoiceMap(1.0) + @test get_submap(choices, :d => :b) == ValueChoiceMap([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +112,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +136,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +154,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +165,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == ValueChoiceMap(1) + + # static_get_submap on an address that contains a value returns a ValueChoiceMap + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == ValueChoiceMap(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) + + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) + + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a ValueChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -276,9 +272,7 @@ end choices = choicemap() choices[:x => :y] = 1 choices[:x] = 2 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == ValueChoiceMap(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment @@ -293,17 +287,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin diff --git a/test/dynamic_choicemap_benchmark.jl b/test/dynamic_choicemap_benchmark.jl new file mode 100644 index 00000000..3724e44d --- /dev/null +++ b/test/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file diff --git a/test/dynamic_dsl.jl b/test/dynamic_dsl.jl index 35f81703..5561ae54 100644 --- a/test/dynamic_dsl.jl +++ b/test/dynamic_dsl.jl @@ -119,7 +119,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -127,7 +127,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -242,7 +242,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -332,11 +332,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -431,14 +431,14 @@ end @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130..607eb61f 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b46..4f524138 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -15,7 +15,7 @@ @test isapprox(weight, value ? log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 46954e3b..b440a44f 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index ba748453..0f3a5618 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -28,7 +28,7 @@ x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -55,7 +55,7 @@ @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -77,7 +77,7 @@ beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/runtests.jl b/test/runtests.jl index a67a5f78..74923603 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,4 +74,4 @@ include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/static_choicemap_benchmark.jl b/test/static_choicemap_benchmark.jl new file mode 100644 index 00000000..1e62b9a8 --- /dev/null +++ b/test/static_choicemap_benchmark.jl @@ -0,0 +1,50 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +load_generated_functions() + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end diff --git a/test/static_inference_benchmark.jl b/test/static_inference_benchmark.jl new file mode 100644 index 00000000..b70d08be --- /dev/null +++ b/test/static_inference_benchmark.jl @@ -0,0 +1,23 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +@load_generated_functions + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 91c6c320..1b594d39 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -362,12 +362,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx))