diff --git a/docs/src/ref/choice_maps.md b/docs/src/ref/choice_maps.md index c065b1b32..db30bed77 100644 --- a/docs/src/ref/choice_maps.md +++ b/docs/src/ref/choice_maps.md @@ -8,13 +8,58 @@ ChoiceMap Choice maps are constructed by users to express observations and/or constraints on the traces of generative functions. Choice maps are also returned by certain Gen inference methods, and are used internally by various Gen inference methods. +A choicemap a tree, whose leaf nodes store a single value, and whose internal nodes provide addresses +for sub-choicemaps. Leaf nodes have type: +```@docs +Value +``` + +### Example Usage Overview + +Choicemaps store values nested in a tree where each node posesses an address for each subtree. +A leaf-node choicemap simply contains a value, and has its value looked up via: +```julia +value = choicemap[] +``` +If a choicemap has a value choicemap at address `:a`, the value it stores is looked up via: +```julia +value = choicemap[:a] +``` +A choicemap may also have a non-value choicemap stored at an address. For instance, +if a choicemap has another choicemap stored at address `:a`, and this internal choicemap +has a Value stored at address `:b` and another at `:c`, we could perform the following lookups: +```julia +value1 = choicemap[:a => :b] +value2 = choicemap[:a => :c] +``` +Nesting can be arbitrarily deep, and the keys can be arbitrary values; for instance +choicemaps can be constructed with values at the following nested addresses: +```julia +value = choicemap[:a => :b => :c => 4 => 1.63 => :e] +value = choicemap[:a => :b => :a => 2 => "alphabet" => :e] +``` +To get a sub-choicemap, use `get_submap`: +```julia +value1 = choicemap[:a => :b] +submap = get_submap(choicemap, :a) +value1 == submap[:b] # is true + +value_submap = get_submap(choicemap, :a => :b) +value_submap[] == value1 # is true +``` +One can think of `Value`s at storing being a choicemap which has a value at "nesting level zero", +while other choicemaps have values at "nesting level" one or higher. + +### Interface + Choice maps provide the following methods: ```@docs +get_submap +get_submaps_shallow has_value get_value -get_submap get_values_shallow -get_submaps_shallow +get_nonvalue_submaps_shallow to_array from_array get_selected @@ -23,7 +68,7 @@ Note that none of these methods mutate the choice map. Choice maps also implement: -- `Base.isempty`, which tests of there are no random choices in the choice map +- `Base.isempty`, which returns `false` if the choicemap contains no value or submaps, and `true` otherwise. - `Base.merge`, which takes two choice maps, and returns a new choice map containing all random choices in either choice map. It is an error if the choice maps both have values at the same address, or if one choice map has a value at an address that is the prefix of the address of a value in the other choice map. @@ -50,3 +95,35 @@ choicemap set_value! set_submap! ``` + +### Implementing custom choicemap types + +To implement a custom choicemap, one must implement +`get_submap` and `get_submaps_shallow`. +To avoid method ambiguity with the default +`get_submap(::ChoiceMap, ::Pair)`, one must implement both +```julia +get_submap(::CustomChoiceMap, addr) +``` +and +```julia +get_submap(::CustomChoiceMap, addr::Pair) +``` +To use the default implementation of `get_submap(_, ::Pair)`, +one may define +```julia +get_submap(c::CustomChoiceMap, addr::Pair) = _get_choicemap(c, addr) +``` + +Once `get_submap` and `get_submaps_shallow` are defined, default +implementations are provided for: +- `has_value` +- `get_value` +- `get_values_shallow` +- `get_nonvalue_submaps_shallow` +- `to_array` +- `get_selected` + +If one wishes to support `from_array`, they must implement +`_from_array`, as described in the documentation for +[`from_array`](@ref). \ No newline at end of file diff --git a/docs/src/ref/distributions.md b/docs/src/ref/distributions.md index c81801d43..a928fabf9 100644 --- a/docs/src/ref/distributions.md +++ b/docs/src/ref/distributions.md @@ -1,5 +1,11 @@ # Probability Distributions +In Gen, a probability distribution is a generative function which makes a single random choice +and returns the value of this choice. The choicemap for a probability distribution +is always a [`Value`](@ref). In addition to supporting the regular `GFI` methods, +every distribution supports the methods [`random`](@ref) and [`logpdf`](@ref), described +in the [Distribution API](@ref custom_distributions). + Gen provides a library of built-in probability distributions, and two ways of writing custom distributions, both of which are explained below: diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd480..b1d759b3a 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -110,7 +110,7 @@ Gen's Distribution interface directly, as defined below. Probability distributions are singleton types whose supertype is `Distribution{T}`, where `T` indicates the data type of the random sample. ```julia -abstract type Distribution{T} end +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace} end ``` A new Distribution type must implement the following methods: @@ -146,6 +146,9 @@ has_output_grad logpdf_grad ``` +Any custom distribution will automatically be a `GenerativeFunction` since `Distribution <: GenerativeFunction`; +implementations of all GFI methods are automatically provided in terms of `random` and `logpdf`. + ## Custom generative functions We recommend the following steps for implementing a new type of generative function, and also looking at the implementation for the [`DynamicDSLFunction`](@ref) type as an example. diff --git a/examples/pmmh/model.jl b/examples/pmmh/model.jl index ce5897064..5421dd691 100644 --- a/examples/pmmh/model.jl +++ b/examples/pmmh/model.jl @@ -150,7 +150,7 @@ get_call_record(trace::CollapsedHMMTrace) = trace.vector.call has_choices(trace::CollapsedHMMTrace) = length(trace.vector.call.retval) > 0 get_choices(trace::CollapsedHMMTrace) = CollapsedHMMChoiceMap(get_choices(trace.vector)) -struct CollapsedHMMChoiceMap <: ChoiceMap +struct CollapsedHMMChoiceMap <: AddressTree{Value} y_assignment::VectorDistTraceChoiceMap end diff --git a/src/Gen.jl b/src/Gen.jl index 9f3da9e3a..22623b040 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -33,22 +33,23 @@ export load_generated_functions, @load_generated_functions # built-in extensions to the reverse mode AD include("backprop.jl") -# addresses and address selections -include("address.jl") - -# abstract and built-in concrete choice map data types -include("choice_map.jl") +# address and address trees +# (including choicemaps and selections) +include("address_tree/address_tree.jl") # a homogeneous trie data type (not for use as choice map) include("trie.jl") +# built-in data types for arg-diff and ret-diff values +include("diff.jl") + # generative function interface include("gen_fn_interface.jl") -# built-in data types for arg-diff and ret-diff values -include("diff.jl") +# distribution abstract type +include("distribution.jl") -# built-in probability disributions +# built-in probability disributions; distribution dsl; combinators include("modeling_library/modeling_library.jl") # optimization of trainable parameters diff --git a/src/address.jl b/src/address.jl deleted file mode 100644 index 2d6499a6a..000000000 --- a/src/address.jl +++ /dev/null @@ -1,370 +0,0 @@ -################### -# address schemas # -################### - -abstract type AddressSchema end - -struct StaticAddressSchema <: AddressSchema - keys::Set{Symbol} -end - -Base.keys(schema::StaticAddressSchema) = schema.keys - -struct VectorAddressSchema <: AddressSchema end -struct SingleDynamicKeyAddressSchema <: AddressSchema end -struct DynamicAddressSchema <: AddressSchema end -struct EmptyAddressSchema <: AddressSchema end -struct AllAddressSchema <: AddressSchema end - -export AddressSchema -export StaticAddressSchema # hierarchical -export VectorAddressSchema # hierarchical -export SingleDynamicKeyAddressSchema # hierarchical -export DynamicAddressSchema # hierarchical -export EmptyAddressSchema -export AllAddressSchema - -###################### -# abstract selection # -###################### - -""" - abstract type Selection end - -Abstract type for selections of addresses. - -All selections implement the following methods: - - Base.in(addr, selection) - -Is the address selected? - - Base.getindex(selection, addr) - -Get the subselection at the given address. - - Base.isempty(selection) - -Is the selection guaranteed to be empty? - - get_address_schema(T) - -Return a shallow, compile-time address schema, where `T` is the concrete type of the selection. -""" -abstract type Selection end - -Base.in(addr, ::Selection) = false -Base.getindex(::Selection, addr) = EmptySelection() - -export Selection - -########################## -# hierarchical selection # -########################## - -""" - abstract type HierarchicalSelection <: Selection end - -Abstract type for selections that have a notion of sub-selections. - - get_subselections(selection::HierarchicalSelection) - -Return an iterator over pairs of addresses and subselections at associated addresses. -""" -abstract type HierarchicalSelection <: Selection end - -export HierarchicalSelection -export get_subselections - -################### -# empty selection # -################### - -""" - struct EmptySelection <: Selection end - -A singleton type for a selection that is always empty. -""" -struct EmptySelection <: Selection end -get_address_schema(::Type{EmptySelection}) = EmptyAddressSchema() -Base.isempty(::EmptySelection) = true - -export EmptySelection - -################# -# all selection # -################# - -""" - struct AllSelection <: Selection end - -A singleton type for a selection that contains all choices at or under an address. -""" -struct AllSelection <: Selection end -get_address_schema(::Type{AllSelection}) = AllAddressSchema() -Base.isempty(::AllSelection) = false # it is not guaranteed to be empty -Base.in(addr, ::AllSelection) = true -Base.getindex(::AllSelection, addr) = AllSelection() - -export AllSelection - -######################## -# complement selection # -######################## - -struct ComplementSelection <: Selection - complement::Selection -end -get_address_schema(::Type{ComplementSelection}) = DynamicAddressSchema() -Base.isempty(::ComplementSelection) = false # it is not guaranteed to be empty -Base.in(addr, selection::ComplementSelection) = !(addr in selection.complement) -function Base.getindex(selection::ComplementSelection, addr) - ComplementSelection(selection.complement[addr]) -end - -""" - comp_selection = complement(selection::Selection) - -Return a selection that is the complement of the given selection. - -An address is in the selection if it is not in the complement selection. -""" -function complement(selection::Selection) - ComplementSelection(selection) -end - -export ComplementSelection, complement - -#################### -# static selection # -#################### - -# R is a tuple of symbols.. -# T is a tuple of symbols -# U the tuple type of subselections - -""" - struct StaticSelection{T,U} <: HierarchicalSelection .. end - -A hierarchical selection whose keys are among its type parameters. -""" -struct StaticSelection{T,U} <: HierarchicalSelection - subselections::NamedTuple{T,U} -end - -function Base.isempty(selection::StaticSelection{T,U}) where {T,U} - length(R) == 0 && all(isempty(node) for node in selection.subselections) -end - -function get_address_schema(::Type{StaticSelection{T,U}}) where {T,U} - keys = Set{Symbol}() - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -get_subselections(selection::StaticSelection) = pairs(selection.subselections) - -function static_getindex(selection::StaticSelection, ::Val{A}) where {A} - selection.subselections[A] -end - -# TODO do we no longer need static_in? - -function Base.getindex(selection::StaticSelection, addr::Symbol) - if haskey(selection.subselections, addr) - selection.subselections[addr] - else - EmptySelection() - end -end - -function Base.getindex(selection::StaticSelection, addr::Pair) - (first, rest) = addr - subselection = selection.subselections[first] - subselection[rest] -end - -function Base.in(addr::Symbol, selection::StaticSelection{T,U}) where {T,U} - addr in T && selection.subselections[addr] == AllSelection() -end - -function Base.in(addr::Pair, selection::StaticSelection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - in(subselection, rest) - else - false - end -end - -function StaticSelection(other::HierarchicalSelection) - keys_and_subselections = collect(get_subselections(other)) - if length(keys_and_subselections) > 0 - (keys, subselections) = collect(zip(keys_and_subselections...)) - else - (keys, subselections) = ((), ()) - end - types = map(typeof, subselections) - StaticSelection{keys,Tuple{types...}}(NamedTuple{keys}(subselections)) -end - -export StaticSelection - - -##################### -# dynamic selection # -##################### - -""" - struct DynamicSelection <: HierarchicalSelection .. end - -A hierarchical, mutable, selection with arbitrary addresses. - -Can be mutated with the following methods: - - - Base.push!(selection::DynamicSelection, addr) - -Add the address and all of its sub-addresses to the selection. - -Example: -```julia -selection = select() -@assert !(:x in selection) -push!(selection, :x) -@assert :x in selection -``` - - set_subselection!(selection::DynamicSelection, addr, other::Selection) - -Change the selection status of the given address and its sub-addresses that defined by `other`. - -Example: -```julia -selection = select(:x) -@assert :x in selection -subselection = select(:y) -set_subselection!(selection, :x, subselection) -@assert (:x => :y) in selection -@assert !(:x in selection) -``` - -Note that `set_subselection!` does not copy data in `other`, so `other` may be mutated by a later calls to `set_subselection!` for addresses under `addr`. -""" -struct DynamicSelection <: HierarchicalSelection - # note: only store subselections for which isempty = false - subselections::Dict{Any,Selection} -end - -function Base.isempty(selection::DynamicSelection) - isempty(selection.subselections) -end - -DynamicSelection() = DynamicSelection(Dict{Any,Selection}()) - -get_address_schema(::Type{DynamicSelection}) = DynamicAddressSchema() - -function Base.in(addr, selection::DynamicSelection) - if haskey(selection.subselections, addr) - selection.subselections[addr] == AllSelection() - else - false - end -end - -function Base.in(addr::Pair, selection::DynamicSelection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - @assert !isempty(subselection) - rest in subselection - else - false - end -end - -function Base.getindex(selection::DynamicSelection, addr) - if haskey(selection.subselections, addr) - selection.subselections[addr] - else - EmptySelection() - end -end - -function Base.getindex(selection::DynamicSelection, addr::Pair) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - @assert !isempty(subselection) - getindex(subselection, rest) - else - EmptySelection() - end -end - -function Base.push!(selection::DynamicSelection, addr) - selection.subselections[addr] = AllSelection() -end - -function Base.push!(selection::DynamicSelection, addr::Pair) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - else - subselection = DynamicSelection() - selection.subselections[first] = subselection - end - push!(subselection, rest) -end - -function set_subselection!(selection::DynamicSelection, addr, other::Selection) - selection.subselections[addr] = other -end - -function set_subselection!(selection::DynamicSelection, addr::Pair, other::Selection) - (first, rest) = addr - if haskey(selection.subselections, first) - subselection = selection.subselections[first] - else - subselection = DynamicSelection() - selection.subselections[first] = subselection - end - set_subselection!(subselection, rest, other) -end - -get_subselections(selection::DynamicSelection) = selection.subselections - -""" - selection = select(addrs...) - -Return a selection containing a given set of addresses. - -Examples: -```julia -selection = select(:x, "foo", :y => 1 => :z) -selection = select() -selection = select(:x => 1, :x => 2) -``` -""" -function select(addrs...) - selection = DynamicSelection() - for addr in addrs - push!(selection, addr) - end - selection -end - -""" - selection = selectall() - -Construct a selection that includes all random choices. -""" -function selectall() - AllSelection() -end - -export DynamicSelection -export select, selectall, set_subselection! diff --git a/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 000000000..a11a62ac5 --- /dev/null +++ b/src/address_tree/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,11 @@ +{ + "cells": [], + "metadata": { + "language_info": { + "name": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/address_tree/address_schema.jl b/src/address_tree/address_schema.jl new file mode 100644 index 000000000..409fbe554 --- /dev/null +++ b/src/address_tree/address_schema.jl @@ -0,0 +1,32 @@ +abstract type AddressSchema end +abstract type StaticSchema <: AddressSchema end + +struct StaticAddressSchema <: StaticSchema + keys::Set{Symbol} +end +Base.keys(schema::StaticAddressSchema) = schema.keys + +struct StaticInverseAddressSchema <: StaticSchema + inv::StaticAddressSchema +end +struct InvertedKeys + keys +end +Base.in(key, ik::InvertedKeys) = !(Base.in(key, ik.keys)) +Base.keys(schema::StaticInverseAddressSchema) = InvertedKeys(keys(schema.inv)) + +struct EmptyAddressSchema <: StaticSchema end +struct AllAddressSchema <: StaticSchema end + +struct VectorAddressSchema <: AddressSchema end +struct SingleDynamicKeyAddressSchema <: AddressSchema end +struct DynamicAddressSchema <: AddressSchema end + +export AddressSchema +export StaticAddressSchema # hierarchical +export VectorAddressSchema # hierarchical +export SingleDynamicKeyAddressSchema # hierarchical +export DynamicAddressSchema # hierarchical +export EmptyAddressSchema +export AllAddressSchema +export StaticInverseAddressSchema \ No newline at end of file diff --git a/src/address_tree/address_tree.jl b/src/address_tree/address_tree.jl new file mode 100644 index 000000000..8ee98e09a --- /dev/null +++ b/src/address_tree/address_tree.jl @@ -0,0 +1,245 @@ +include("address_schema.jl") + +""" + AddressTree{LeafType} + +Abstract type for trees where each node's subtrees are labelled with +an address. All leaf nodes are of `LeafType` (or are an `EmptyAddressTree`). +""" +abstract type AddressTree{LeafType} end + +""" + AddressTreeLeaf + +Abstract type for address tree leaf nodes. + +## Note: +When declaring a subtype `T` of `AddressTreeLeaf`, +declare `T <: AddressTreeLeaf{T}` to ensure +`T <: AddressTree{T}`. +""" +abstract type AddressTreeLeaf{Type} <: AddressTree{Type} end + +""" + EmptyAddressTree + +An empty address tree with no subtrees. +""" +struct EmptyAddressTree <: AddressTreeLeaf{EmptyAddressTree} end + +""" + Value{T} + +An address tree leaf node storing a value of type `T`. +""" +struct Value{T} <: AddressTreeLeaf{Value} + val::T +end +@inline get_value(v::Value) = v.val + +# Note that I don't set `Value{T} <: AddressTreeLeaf{Value{T}}`. +# I have run into issues when I do this, +# because then `Value{T} <: AddressTree{Value}` is not true. +# There may be some way to make this work, but I haven't been +# able to figure it out, and I don't currently have any need +# for specifying `AddressTree{Value{T}}` types, so I'm not going +# to worry about it for now. + +""" + SelectionLeaf + +Abstract type for a `Selection` which cannot be naturally decomposed +into "subtrees" as an address tree. (Often this is because infinitely +many addresses should have `get_subtree` return a nonempty tree.) +""" +abstract type SelectionLeaf <: AddressTreeLeaf{SelectionLeaf} end + +""" + AllSelection + +An address tree leaf node representing that all sub-addresses +from this point are selected. +""" +struct AllSelection <: SelectionLeaf end + +""" + CustomUpdateSpec + +Supertype for custom update specifications. +""" +abstract type CustomUpdateSpec <: AddressTreeLeaf{CustomUpdateSpec} end +const UpdateSpec = AddressTree{<:Union{Value, SelectionLeaf, EmptyAddressTree, CustomUpdateSpec}} + +""" + get_subtree(tree::AddressTree{T}, addr)::Union{AddressTree{T}, EmptyAddressTree} + +Get the subtree at address `addr` or return `EmptyAddressTree` +if there is no subtree at this address. + +Invariant: `get_subtree(::AddressTree{LeafType}, addr)` either returns +an object of `LeafType` or an `EmptyAddressTree`. +""" +function get_subtree(::AddressTree{LeafType}, addr)::AddressTree{LeafType} where {LeafType} end + +function _get_subtree(t::AddressTree, addr::Pair) + get_subtree(get_subtree(t, addr.first), addr.second) +end + +""" + get_subtrees_shallow(tree::AddressTree{T}) + +Return an iterator over tuples `(address, subtree::AddressTree{T})` for each +top-level address associated with `tree`. + +The length of this iterator must nonzero if this is not a leaf node. +""" +function get_subtrees_shallow end + +get_leaf_type(T::Type{AddressTree{U}}) where {U} = U + +""" +schema = get_address_schema(::Type{T}) where {T <: AddressTree} + +Return the (top-level) address schema for the given address tree type. +""" +function get_address_schema end +@inline get_address_schema(::Type{EmptyAddressTree}) = EmptyAddressSchema() +@inline get_address_schema(::Type{AllSelection}) = AllAddressSchema() + +@inline get_address_schema(::Type{Value}) = error("I don't think this currently gets called, and it's not part of the user-facing interface. If we need this, set the appropriate value then.") + +Base.isempty(::Value) = false +Base.isempty(::AllSelection) = false +Base.isempty(::EmptyAddressTree) = true +Base.isempty(::AddressTreeLeaf) = error("Not implemented") +Base.isempty(t::AddressTree) = all(((_, subtree),) -> isempty(subtree), get_subtrees_shallow(t)) + +@inline get_subtree(::AddressTreeLeaf, _) = EmptyAddressTree() +@inline get_subtrees_shallow(::AddressTreeLeaf) = () + +@inline get_subtree(::AllSelection, _) = AllSelection() + +function Base.:(==)(a::AddressTree, b::AddressTree) + for (addr, subtree) in get_subtrees_shallow(a) + if get_subtree(b, addr) != subtree + return false + end + end + for (addr, subtree) in get_subtrees_shallow(b) + if get_subtree(a, addr) != subtree + return false + end + end + return true +end +@inline Base.:(==)(a::Value, b::Value) = a.val == b.val +Base.:(==)(a::AddressTreeLeaf, b::AddressTreeLeaf) = false +Base.:(==)(::T, ::T) where {T <: AddressTreeLeaf} = true + +Base.isapprox(a::Value, b::Value) = isapprox(a.val, b.val) +Base.isapprox(::EmptyAddressTree, ::EmptyAddressTree) = true +Base.isapprox(::AllSelection, ::AllSelection) = true +function Base.isapprox(::AddressTreeLeaf{T}, ::AddressTreeLeaf{U}) where {T, U} + if T != U + false + else + error("Not implemented") + end +end +function Base.isapprox(a::AddressTree, b::AddressTree) + for (addr, subtree) in get_subtrees_shallow(a) + if !isapprox(get_subtree(b, addr), subtree) + return false + end + end + for (addr, subtree) in get_subtrees_shallow(b) + if !isapprox(get_subtree(a, addr), subtree) + return false + end + end + return true +end + +""" + Base.merge(a::AddressTree, b::AddressTree) + +Merge two address trees. +""" +function Base.merge(a::AddressTree{T}, b::AddressTree{U}) where {T, U} + tree = DynamicAddressTree{Union{T, U}}() + for (key, subtree) in get_subtrees_shallow(a) + set_subtree!(tree, key, merge(subtree, get_subtree(b, key))) + end + for (key, subtree) in get_subtrees_shallow(b) + if isempty(get_subtree(a, key)) + set_subtree!(tree, key, subtree) + end + end + tree +end +Base.merge(t::AddressTree, ::EmptyAddressTree) = t +Base.merge(::EmptyAddressTree, t::AddressTree) = t +Base.merge(t::AddressTreeLeaf, ::EmptyAddressTree) = t +Base.merge(::EmptyAddressTree, t::AddressTreeLeaf) = t + +Base.merge(::AddressTreeLeaf, ::AddressTree) = error("Not implemented") +Base.merge(::AddressTree, ::AddressTreeLeaf) = error("Not implemented") + +""" +Variadic merge of address trees. +""" +function Base.merge(first::AddressTree, rest::AddressTree...) + reduce(Base.merge, rest; init=first) +end + +function _show_pretty(io::IO, tree::AddressTree, pre, vert_bars::Tuple) + VERT = '\u2502' + PLUS = '\u251C' + HORZ = '\u2500' + LAST = '\u2514' + indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) + indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) + indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) + for i in vert_bars + indent_vert[i] = VERT + indent[i] = VERT + indent_last[i] = VERT + end + indent_vert_str = join(indent_vert) + indent_vert_last_str = join(indent_vert_last) + indent_str = join(indent) + indent_last_str = join(indent_last) + key_and_subtrees = collect(get_subtrees_shallow(tree)) + n = length(key_and_subtrees) + cur = 1 + for (key, subtree) in key_and_subtrees + print(io, indent_vert_str) + if subtree isa AddressTreeLeaf + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $subtree\n") + else + print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") + _show_pretty(io, subtree, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) + end + cur += 1 + end +end + +function Base.show(io::IO, ::MIME"text/plain", tree::AddressTree) + _show_pretty(io, tree, 0, ()) +end +Base.show(io::IO, ::MIME"text/plain", t::AddressTreeLeaf) = print(io, t) + +function nonempty_subtree_itr(itr) + ((addr, subtree) for (addr, subtree) in itr if !isempty(subtree)) +end + +include("dynamic_address_tree.jl") +include("static_address_tree.jl") + +include("choicemap.jl") +include("selection.jl") + +export get_subtree, get_subtrees_shallow +export EmptyAddressTree, Value, AllSelection, SelectionLeaf, CustomUpdateSpec, UpdateSpec +export get_address_schema \ No newline at end of file diff --git a/src/address_tree/array_interface.jl b/src/address_tree/array_interface.jl new file mode 100644 index 000000000..0b6eae2fd --- /dev/null +++ b/src/address_tree/array_interface.jl @@ -0,0 +1,110 @@ +### interface for to_array and fill_array ### + +# NOTE: currently this only works for choicemaps, +# but if we found we needed some sort of "to_array" for other types of +# address trees, I don't think it would be too hard to generalize + +""" + arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} + +Populate an array with values of choices in the given assignment. + +It is an error if each of the values cannot be coerced into a value of the +given type. + +Implementation + +The default implmentation of `fill_array` will populate the array by sorting +the addresses of the choicemap using the `sort` function, then iterating over +each submap in this order and filling the array for that submap. + +To override the default implementation of `to_array`, +a concrete subtype `T <: AddressTree{Value}` should implement the following method: + + n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Populate `arr` with values from the given assignment, starting at `start_idx`, +and return the number of elements in `arr` that were populated. + +(This is for performance; it is more efficient to fill in values in a preallocated array +by implementing `_fill_array!` than to construct discontiguous arrays for each submap and then merge them.) +""" +function to_array(choices::ChoiceMap, ::Type{T}) where {T} + arr = Vector{T}(undef, 32) + n = _fill_array!(choices, arr, 1) + @assert n <= length(arr) + resize!(arr, n) + arr +end + +function _fill_array!(c::Value{<:T}, arr::Vector{T}, start_idx::Int) where {T} + if length(arr) < start_idx + resize!(arr, 2 * start_idx) + end + arr[start_idx] = get_value(c) + 1 +end +function _fill_array!(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + value = get_value(c) + if length(arr) < start_idx + length(value) + resize!(arr, 2 * (start_idx + length(value))) + end + arr[start_idx:start_idx+length(value)-1] = value + length(value) +end + +# default _fill_array! implementation +function _fill_array!(choices::ChoiceMap, arr::Vector{T}, start_idx::Int) where {T} + key_to_submap = collect(get_submaps_shallow(choices)) + sort!(key_to_submap, by = ((key, submap),) -> key) + idx = start_idx + for (key, submap) in key_to_submap + n_written = _fill_array!(submap, arr, idx) + idx += n_written + end + idx - start_idx +end + +""" + choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) + +Return an assignment with the same address structure as a prototype +assignment, but with values read off from the given array. + +It is an error if the number of choices in the prototype assignment +is not equal to the length the array. + +The order in which addresses are populated with values from the array +should match the order in which the array is populated with values +in a call to `to_array(proto_choices, T)`. By default, +this means sorting the top-level addresses for `proto_choices` +and then filling in the submaps depth-first in this order. + +# Implementation + +To support `from_array`, a concrete subtype `T AddressTree{Value}` must implement +the following method: + + (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} + +Return an assignment with the same address structure as a prototype assignment, +but with values read off from `arr`, starting at position `start_idx`. Return the +number of elements read from `arr`. +""" +function from_array(proto_choices::ChoiceMap, arr::Vector) + (n, choices) = _from_array(proto_choices, arr, 1) + if n != length(arr) + error("Dimension mismatch: $n, $(length(arr))") + end + choices +end + +function _from_array(::Value, arr::Vector, start_idx::Int) + (1, Value(arr[start_idx])) +end +function _from_array(c::Value{<:Vector{<:T}}, arr::Vector{T}, start_idx::Int) where {T} + n_read = length(get_value(c)) + (n_read, Value(arr[start_idx:start_idx+n_read-1])) +end + +export to_array, from_array \ No newline at end of file diff --git a/src/address_tree/choicemap.jl b/src/address_tree/choicemap.jl new file mode 100644 index 000000000..c32d41c87 --- /dev/null +++ b/src/address_tree/choicemap.jl @@ -0,0 +1,162 @@ +""" + ChoiceMapGetValueError + +The error returned when a user attempts to call `get_value` +on an choicemap for an address which does not contain a value in that choicemap. +""" +struct ChoiceMapGetValueError <: Exception end +showerror(io::IO, ex::ChoiceMapGetValueError) = (print(io, "ChoiceMapGetValueError: no value was found for the `get_value` call.")) + +""" + ChoiceMap + +Abstract type for maps from hierarchical addresses to values. +""" +const ChoiceMap = AddressTree{<:Union{Value, EmptyAddressTree}} + +""" + get_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for each top-level address associated with `choices`. +(This includes `Value`s.) +""" +@inline get_submaps_shallow(c::ChoiceMap) = get_subtrees_shallow(c) + +""" + get_submap(choices::ChoiceMap, addr) + +Return the submap at the given address, or `EmptyChoiceMap` +if there is no submap at the given address. +""" +@inline get_submap(c::ChoiceMap, addr) = get_subtree(c, addr) + +@inline static_get_submap(c::ChoiceMap, a) = static_get_subtree(c, a) + +""" + has_value(tree::AddressTree) + +Returns true if `tree` is a `Value`. + + has_value(tree::AddressTree, addr) + +Returns true if `tree` has a value stored at address `addr`. +""" +function has_value end +@inline has_value(t::AddressTree, addr) = has_value(get_subtree(t, addr)) +has_value(::Value) = true +has_value(::AddressTree) = false + +""" + get_value(choices::ChoiceMap) + +Returns the value stored on `choices` is `choices` is a `Value`; +throws a `ChoiceMapGetValueError` if `choices` is not a `Value`. + + get_value(choices::ChoiceMap, addr) +Returns the value stored in the submap with address `addr` or throws +a `ChoiceMapGetValueError` if no value exists at this address. + +A syntactic sugar is `Base.getindex`: + + value = choices[addr] +""" +function get_value end +@inline get_value(::ChoiceMap) = throw(ChoiceMapGetValueError()) +@inline get_value(c::ChoiceMap, addr) = get_value(get_submap(c, addr)) +@inline Base.getindex(choices::ChoiceMap, addr...) = get_value(choices, addr...) + +""" + get_values_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, value)` +for each value stored at a top-level address in `choices`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_values_shallow(choices::ChoiceMap) + ( + (addr, get_value(submap)) + for (addr, submap) in get_submaps_shallow(choices) + if has_value(submap) + ) +end + +""" + get_nonvalue_submaps_shallow(choices::ChoiceMap) + +Returns an iterable collection of tuples `(address, submap)` +for every top-level submap stored in `choices` which is +not a `Value`. +(Works by applying a filter to `get_submaps_shallow`, +so this internally requires iterating over every submap.) +""" +function get_nonvalue_submaps_shallow(choices::ChoiceMap) + (addr_to_submap for addr_to_submap in get_submaps_shallow(choices) if !has_value(addr_to_submap[2])) +end + +# support `DynamicChoiceMap` and `StaticChoiceMap` types, and the "legacy" DynamicChoiceMap interface +const DynamicChoiceMap = DynamicAddressTree{Value} +set_submap!(cm::DynamicChoiceMap, addr, submap::ChoiceMap) = set_subtree!(cm, addr, submap) +function set_value!(cm::DynamicChoiceMap, addr, val) + is_empty_or_subtree = let sub = get_subtree(cm, addr); sub isa Value || isempty(sub); end + @assert is_empty_or_subtree "Cannot assign a value to address `$addr` since there is a nonempty, nonvalue subtree at `$addr`" + set_subtree!(cm, addr, Value(val)) +end +Base.setindex!(cm::DynamicChoiceMap, val, addr) = set_value!(cm, addr, val) + +const StaticChoiceMap = StaticAddressTree{Value} +const EmptyChoiceMap = EmptyAddressTree + +""" + choices = choicemap() + +Construct an empty mutable choice map. +""" +function choicemap() + DynamicChoiceMap() +end + +""" + choices = choicemap(tuples...) + +Construct a mutable choice map initialized with given (address, value) tuples. +(Where `value` is the value to be stored, not a `Value` object.) +""" +function choicemap(tuples...) + cm = DynamicChoiceMap() + for (addr, val) in tuples + cm[addr] = val + end + cm +end + +""" + UnderlyingChoices(tree::AddressTree) + +A choicemap exposing all the choices in a given `tree`, removing any leaf +nodes which are not values. +""" +struct UnderlyingChoices <: AddressTree{Value} + tree::AddressTree +end +UnderlyingChoices(t::ChoiceMap) = t +UnderlyingChoices(v::Value) = v +UnderlyingChoices(::AddressTreeLeaf) = EmptyAddressTree() +UnderlyingChoices(::EmptyAddressTree) = EmptyAddressTree() +get_subtree(t::UnderlyingChoices, a) = UnderlyingChoices(get_subtree(t.tree, a)) +get_subtrees_shallow(t::UnderlyingChoices) = ((addr, UnderlyingChoices(subtree)) for (addr, subtree) in get_subtrees_shallow(t.tree) if UnderlyingChoices(subtree) !== EmptyAddressTree()) + +# TODO: we should be able to extract more information +get_address_schema(::Type{UnderlyingChoices}) = DynamicAddressSchema() + +export ChoiceMap, choicemap +export ChoiceMapGetValueError +export get_value, has_value, get_submap +export get_values_shallow, get_submaps_shallow, get_nonvalue_submaps_shallow +export EmptyChoiceMap, StaticChoiceMap, DynamicChoiceMap, UnderlyingChoices +export set_value!, set_submap! +export static_get_submap + +include("array_interface.jl") +include("nested_view.jl") \ No newline at end of file diff --git a/src/address_tree/dynamic_address_tree.jl b/src/address_tree/dynamic_address_tree.jl new file mode 100644 index 000000000..23901cdcc --- /dev/null +++ b/src/address_tree/dynamic_address_tree.jl @@ -0,0 +1,119 @@ +""" + struct DynamicAddressTree <: AddressTree .. end + +A mutable AddressTree. + + tree = DynamicAddressTree() + +Construct an empty address tree. + +""" +struct DynamicAddressTree{LeafType} <: AddressTree{LeafType} + subtrees::Dict{Any, AddressTree{<:LeafType}} +end +DynamicAddressTree{LeafType}() where {LeafType} = DynamicAddressTree{LeafType}(Dict{Any, AddressTree{LeafType}}()) + +""" + tree = address_tree() + +Construct an empty, mutable address tree. +""" +address_tree() = DynamicAddressTree{Any}() + +get_address_schema(::Type{<:DynamicAddressTree}) = DynamicAddressTree + +@inline get_subtrees_shallow(t::DynamicAddressTree) = t.subtrees +@inline get_subtree(t::DynamicAddressTree, addr) = get(t.subtrees, addr, EmptyAddressTree()) +@inline get_subtree(t::DynamicAddressTree, addr::Pair) = _get_subtree(t, addr) +@inline Base.isempty(t::DynamicAddressTree) = isempty(t.subtrees) + +function set_subtree!(t::DynamicAddressTree, addr, new_node::AddressTree) + delete!(t.subtrees, addr) + if !isempty(new_node) + t.subtrees[addr] = new_node + end +end +function set_subtree!(t::DynamicAddressTree{T}, addr::Pair, new_node::AddressTree) where {T} + (first, rest) = addr + if !haskey(t.subtrees, first) + t.subtrees[first] = DynamicAddressTree{T}() + end + set_subtree!(t.subtrees[first], rest, new_node) +end + +""" + tree = shallow_dynamic_copy(other::AddressTree) + +Make a shallow `DynamicAddressTree` copy of the given address tree. +""" +function shallow_dynamic_copy(other::AddressTree{LeafType}) where {LeafType} + tree = DynamicAddressTree{LeafType}() + for (addr, subtree) in get_subtrees_shallow(other) + set_subtree!(tree, addr, subtree) + end + tree +end + +""" + tree = deep_dynamic_copy(other::AddressTree) + +Make a deep copy of the given address tree, where every non-leaf-node +is a `DynamicAddressTree`. +""" +function deep_dynamic_copy(other::AddressTree{LeafType}) where {LeafType} + tree = DynamicAddressTree{LeafType}() + for (addr, subtree) in get_subtrees_shallow(other) + if subtree isa AddressTreeLeaf + set_subtree!(tree, addr, subtree) + else + set_subtree!(tree, addr, deep_dynamic_copy(subtree)) + end + end + tree +end + +""" + tree = DynamicAddressTree(other::AddressTree) + +Shallowly convert an address tree to dynamic. +""" +DynamicAddressTree(t::AddressTree) = shallow_dynamic_copy(t) +DynamicAddressTree(t::DynamicAddressTree) = t +DynamicAddressTree{LeafType}(t::AddressTree) where {LeafType} = DynamicAddressTree(t) + +function _from_array(proto_choices::DynamicAddressTree{LT}, arr::Vector{T}, start_idx::Int) where {T, LT} + choices = DynamicAddressTree{LT}() + keys_sorted = sort(collect(keys(proto_choices.subtrees))) + idx = start_idx + for key in keys_sorted + (n_read, submap) = _from_array(proto_choices.subtrees[key], arr, idx) + idx += n_read + choices.subtrees[key] = submap + end + (idx - start_idx, choices) +end + +""" + merge!(into::DynamicAddressTree{T}, from::DynamicAddressTree{U}) where U <: T + +Merge all the subtrees from `from` into `into`. + +Merges in place as deeply as possible, +until encountering a non-dynamic subtree of `into`, at which point +this resorts to non-mutating `merge` for that subtree. +""" +function Base.merge!(into::DynamicAddressTree{T}, from::DynamicAddressTree{U}) where {T, U <: T} + for (addr, from_subtree) in get_subtrees_shallow(from) + into_subtree = get_subtree(into, addr) + if isempty(into_subtree) + set_subtree!(into, addr, from_subtree) + elseif into_subtree isa DynamicAddressTree + merge!(into_subtree, from_subtree) + else + set_subtree!(into, addr, merge(into_subtree, from_subtree)) + end + end +end +Base.merge!(into::DynamicAddressTree, ::EmptyAddressTree) = into + +export DynamicAddressTree, set_subtree! \ No newline at end of file diff --git a/src/address_tree/nested_view.jl b/src/address_tree/nested_view.jl new file mode 100644 index 000000000..40593c162 --- /dev/null +++ b/src/address_tree/nested_view.jl @@ -0,0 +1,82 @@ +############################################ +# Nested-dict–like accessor for choicemaps # +############################################ + +# TODO: augment this to work in a form for any address tree? + +""" +Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than +the default syntax which looks like a flat dict of full keypaths. + +```jldoctest +julia> using Gen +julia> c = choicemap((:a, 1), + (:b => :c, 2)); +julia> cv = nested_view(c); +julia> c[:a] == cv[:a] +true +julia> c[:b => :c] == cv[:b][:c] +true +julia> length(cv) +2 +julia> length(cv[:b]) +1 +julia> sort(collect(keys(cv))) +[:a, :b] +julia> sort(collect(keys(cv[:b]))) +[:c] +``` +""" +struct ChoiceMapNestedView + choice_map::ChoiceMap +end + +ChoiceMapNestedView(cm::Value) = get_value(cm) +ChoiceMapNestedView(::EmptyAddressTree) = error("Can't convert an emptychoicemap to nested view.") + +function Base.getindex(choices::ChoiceMapNestedView, addr) + ChoiceMapNestedView(get_submap(choices.choice_map, addr)) +end + +function Base.iterate(c::ChoiceMapNestedView) + itr = ((k, ChoiceMapNestedView(s)) for (k, s) in get_submaps_shallow(c.choice_map)) + r = Base.iterate(itr) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +function Base.iterate(c::ChoiceMapNestedView, state) + (itr, st) = state + r = Base.iterate(itr, st) + if r === nothing + return nothing + end + (next_kv, next_inner_state) = r + (next_kv, (itr, next_inner_state)) +end + +# TODO: Allow different implementations of this method depending on the +# concrete type of the `ChoiceMap`, so that an already-existing data structure +# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it +# exists. +Base.keys(cv::ChoiceMapNestedView) = (k for (k, v) in cv) + +Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) = a.choice_map == b.choice_map + +function Base.length(cv::ChoiceMapNestedView) + length(collect(get_submaps_shallow(cv.choice_map))) +end +function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) + Base.show(io, MIME"text/plain"(), c.choice_map) +end + +nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) + +# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling +# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and +# aux data together. + +export nested_view \ No newline at end of file diff --git a/src/address_tree/selection.jl b/src/address_tree/selection.jl new file mode 100644 index 000000000..0da22fe37 --- /dev/null +++ b/src/address_tree/selection.jl @@ -0,0 +1,171 @@ +const Selection = AddressTree{<:Union{SelectionLeaf, EmptyAddressTree}} + +const StaticSelection = StaticAddressTree{SelectionLeaf} +const EmptySelection = EmptyAddressTree + +""" + in(addr, selection::Selection) + +Whether the address is selected in the given selection. +""" +@inline function Base.in(addr, selection::Selection) + get_subtree(selection, addr) === AllSelection() +end + +# indexing returns subtrees for selections +Base.getindex(selection::AddressTree{SelectionLeaf}, addr) = get_subtree(selection, addr) + +# TODO: deprecate indexing syntax and only use this +get_subselection(s::Selection, addr) = get_subtree(s, addr) + +get_subselections(s::Selection) = get_subtrees_shallow(s) + +Base.merge(::AllSelection, ::Selection) = AllSelection() +Base.merge(::Selection, ::AllSelection) = AllSelection() +Base.merge(::AllSelection, ::AllSelection) = AllSelection() +Base.merge(::AllSelection, ::EmptySelection) = AllSelection() +Base.merge(::EmptySelection, ::AllSelection) = AllSelection() + +""" + filtered = SelectionFilteredAddressTree(tree, selection) + +An address tree containing only the nodes in `tree` whose addresses are selected +in `selection.` +""" +struct SelectionFilteredAddressTree{T} <: AddressTree{T} + tree::AddressTree{T} + sel::Selection +end +SelectionFilteredAddressTree(t::AddressTree, ::AllSelection) = t +SelectionFilteredAddressTree(t::AddressTreeLeaf, ::AllSelection) = t +SelectionFilteredAddressTree(::AddressTree, ::EmptyAddressTree) = EmptyAddressTree() +SelectionFilteredAddressTree(::AddressTreeLeaf, ::EmptyAddressTree) = EmptyAddressTree() +SelectionFilteredAddressTree(::AddressTreeLeaf, ::Selection) = EmptyAddressTree() # if we hit a leaf node before a selected value, the node is not selected + +function get_subtree(t::SelectionFilteredAddressTree, addr) + subselection = get_subtree(t.sel, addr) + if subselection === EmptyAddressTree() + EmptyAddressTree() + else + SelectionFilteredAddressTree(get_subtree(t.tree, addr), subselection) + end +end + +function get_subtrees_shallow(t::SelectionFilteredAddressTree) + nonempty_subtree_itr( + (addr, SelectionFilteredAddressTree(subtree, get_subtree(t.sel, addr))) + for (addr, subtree) in get_subtrees_shallow(t.tree) + ) +end + +""" + selected = get_selected(tree::AddressTree, selection::Selection) + +Filter the address tree `tree` to only include leaf nodes at selected +addresses. +""" +get_selected(tree::AddressTree, selection::Selection) = SelectionFilteredAddressTree(tree, selection) + +""" + struct DynamicSelection <: HierarchicalSelection .. end +A hierarchical, mutable, selection with arbitrary addresses. +Can be mutated with the following methods: + Base.push!(selection::DynamicSelection, addr) +Add the address and all of its sub-addresses to the selection. +Example: +```julia +selection = select() +@assert !(:x in selection) +push!(selection, :x) +@assert :x in selection +``` + set_subselection!(selection::DynamicSelection, addr, other::Selection) +Change the selection status of the given address and its sub-addresses that defined by `other`. +Example: +```julia +selection = select(:x) +@assert :x in selection +subselection = select(:y) +set_subselection!(selection, :x, subselection) +@assert (:x => :y) in selection +@assert !(:x in selection) +``` +Note that `set_subselection!` does not copy data in `other`, so `other` may be mutated by a later calls to `set_subselection!` for addresses under `addr`. +""" +const DynamicSelection = DynamicAddressTree{SelectionLeaf} +set_subselection!(s::DynamicSelection, addr, sub::Selection) = set_subtree!(s, addr, sub) + +function Base.push!(s::DynamicSelection, addr) + set_subtree!(s, addr, AllSelection()) +end +function Base.push!(s::DynamicSelection, addr::Pair) + first, rest = addr + subtree = get_subtree(s, first) + if subtree isa DynamicSelection + push!(subtree, rest) + else + new_subtree = select(rest) + merge!(new_subtree, subtree) + set_subtree!(s, first, new_subtree) + end +end + +function select(addrs...) + selection = DynamicSelection() + for addr in addrs + set_subtree!(selection, addr, AllSelection()) + end + selection +end + +""" + AddressSelection(::AddressTree) + +A selection containing all of the addresses in the given address tree with a nonempty leaf node. +""" +struct AddressSelection{T} <: AddressTree{SelectionLeaf} + a::T + AddressSelection(a::T) where {T <: AddressTree} = new{T}(a) +end +AddressSelection(::AddressTreeLeaf) = AllSelection() +AddressSelection(::EmptyAddressTree) = EmptyAddressTree() +get_subtree(a::AddressSelection, addr) = AddressSelection(get_subtree(a.a, addr)) +function get_subtrees_shallow(a::AddressSelection) + nonempty_subtree_itr((addr, AddressSelection(subtree)) for (addr, subtree) in get_subtrees_shallow(a.a)) +end +get_address_schema(::Type{AddressSelection{T}}) where {T} = get_address_schema(T) + +""" + addrs(::AddressTree) + +Returns a selection containing all of the addresses in the tree with a nonempty leaf node. +""" +addrs(a::AddressTree) = AddressSelection(a) + +""" + invert(sel::Selection) + InvertedSelection(sel::Selection) + +"Inverts" `sel` by transforming every `AllSelection` subtree +to an `EmptySelection` and every `EmptySelection` to an `AllSelection`. +""" +invert(sel::Selection) = InvertedSelection(sel) +struct InvertedSelection{SelectionType} <: SelectionLeaf + sel::SelectionType + InvertedSelection(t::T) where {T <: Selection} = new{T}(t) +end +@inline InvertedSelection(::AllSelection) = EmptySelection() +@inline InvertedSelection(::EmptySelection) = AllSelection() +get_subtree(s::InvertedSelection, address) = InvertedSelection(get_subtree(s.sel, address)) +# get_subtrees_shallow uses default implementation for ::AddressTreeLeaf to return () +Base.isempty(::InvertedSelection) = false + +function get_address_schema(::Type{InvertedSelection{T}}) where {T <: StaticAddressTree} + StaticInverseAddressSchema(get_address_schema(T)) +end +get_address_schema(::Type{<:InvertedSelection}) = DynamicAddressSchema() +@inline static_get_subtree(s::InvertedSelection, v::Val) = InvertedSelection(static_get_subtree(s.sel, v)) +StaticAddressTree(t::InvertedSelection) = InvertedSelection(StaticAddressTree(t.sel)) + +export select, get_selected, addrs, get_subselection, get_subselections, invert +export Selection, DynamicSelection, EmptySelection, StaticSelection, InvertedSelection \ No newline at end of file diff --git a/src/address_tree/static_address_tree.jl b/src/address_tree/static_address_tree.jl new file mode 100644 index 000000000..59bd71ad5 --- /dev/null +++ b/src/address_tree/static_address_tree.jl @@ -0,0 +1,195 @@ +struct StaticAddressTree{LeafType, Addrs, SubtreeTypes} <: AddressTree{LeafType} + subtrees::NamedTuple{Addrs, SubtreeTypes} + function StaticAddressTree{LeafType}(nt::NamedTuple{Addrs, Subtrees}) where { + LeafType, Addrs, Subtrees <: Tuple{Vararg{<:Union{AddressTree{<:LeafType}, EmptyAddressTree}}} + } + new{LeafType, Addrs, Subtrees}(nt) + end +end + +# NOTE: this constructor makes it possible to construct `StaticAddressTree`s +# which have `EmptyAddressTree` as a subtree. If we want to avoid this, +# we can have the inner constructor only accept +# Subtrees <: Tuple{Vararg{AddressTree{<:LeafType}}} +# and have a separate outer constructor accepting +# Subtrees <: Tuple{Vararg{<:Union{AddressTree{<:LeafType}, EmptyAddressTree}}} +# that removes the empty subtrees. (Possibly this should be an @generated method +# which does this at compile-time.) + +# NOTE: It is probably better to avoid using this constructor when possible since I suspect it is less performant +# than if we specify `LeafType`. +function StaticAddressTree(subtrees::NamedTuple{Addrs, SubtreeTypes}) where {Addrs, SubtreeTypes <: Tuple{Vararg{AddressTree}}} + uniontype = Union{SubtreeTypes.parameters...} + if @generated + quote StaticAddressTree{$uniontype}(subtrees) end + else + StaticAddressTree{uniontype}(subtrees) + end +end +""" + StaticAddressTree{LeafType}(; a=val, b=tree, ...) + StaticAddressTree(; a=val, b=tree, ...) + +Construct a static address tree with the given address-subtree +or address-value pairs. (The addresses must be top-level symbols; +if the RHS is an AddressTree, this will be the subtree; if not, the +subtree will be a `Value` with the given value.) +""" +StaticAddressTree(;addrs_to_vals_and_trees...) = StaticAddressTree(addrs_subtrees_namedtuple(addrs_to_vals_and_trees)) +StaticAddressTree{LeafType}(; addrs_to_vals_and_trees...) where {LeafType} = StaticAddressTree{LeafType}(addrs_subtrees_namedtuple(addrs_to_vals_and_trees)) + +function addrs_subtrees_namedtuple(addrs_to_vals_and_trees) + addrs = Tuple(addr for (addr, val_or_map) in addrs_to_vals_and_trees) + trees = Tuple(val_or_map isa AddressTree ? val_or_map : Value(val_or_map) for (addr, val_or_map) in addrs_to_vals_and_trees) + NamedTuple{addrs}(trees) +end + +@inline get_subtrees_shallow(t::StaticAddressTree) = pairs(t.subtrees) +@inline get_submap(t::StaticAddressTree, addr::Pair) = _get_subtree(t, addr) + +function get_subtree(t::StaticAddressTree{LeafType, Addrs}, addr::Symbol) where {LeafType, Addrs} + if addr in Addrs + t.subtrees[addr] + else + EmptyAddressTree() + end +end + +@generated function static_get_subtree(t::StaticAddressTree{LeafType, Addrs}, ::Val{A}) where {A, Addrs, LeafType} + if A in Addrs + quote t.subtrees[A] end + else + quote EmptyAddressTree() end + end +end +@inline static_get_subtree(::EmptyAddressTree, ::Val) = EmptyAddressTree() +@inline static_get_subtree(::Value, ::Val) = EmptyAddressTree() +@inline static_get_subtree(::AllSelection, ::Val) = AllSelection() + +@inline static_get_value(choices::StaticAddressTree, v::Val) = get_value(static_get_subtree(choices, v)) +@inline static_get_value(::EmptyAddressTree, ::Val) = throw(ChoiceMapGetValueError()) + +# convert a nonvalue choicemap all of whose top-level-addresses +# are symbols into a staticchoicemap at the top level +StaticAddressTree(t::StaticAddressTree) = t +function StaticAddressTree(other::AddressTree{LeafType}) where {LeafType} + keys_and_nodes = collect(get_subtrees_shallow(other)) + if length(keys_and_nodes) > 0 + addrs = Tuple(key for (key, _) in keys_and_nodes) + submaps = Tuple(submap for (_, submap) in keys_and_nodes) + else + addrs = () + submaps = () + end + StaticAddressTree{LeafType}(NamedTuple{addrs}(submaps)) +end +StaticAddressTree(::AllSelection) = AllSelection() +StaticAddressTree(::EmptyAddressTree) = EmptyAddressTree() +StaticAddressTree(v::Value) = v +StaticAddressTree(::AddressTreeLeaf) = error("Not implemented") +StaticAddressTree{LeafType}(::NamedTuple{(),Tuple{}}) where {LeafType} = EmptyAddressTree() +StaticAddressTree(::NamedTuple{(),Tuple{}}) = EmptyAddressTree() +StaticAddressTree{LeafType}(other::AddressTree{<:LeafType}) where {LeafType} = StaticAddressTree(other) + +# TODO: deep conversion to static choicemap + +""" + tree = pair(tree1::AddressTree, tree2::AddressTree, key1::Symbol, key2::Symbol) + +Return an address tree that contains `tree1` as a subtree under `key1` +and `tree2` as a subtree under `key2`. +""" +function pair(tree1::AddressTree, tree2::AddressTree, key1::Symbol, key2::Symbol) + StaticAddressTree(NamedTuple{(key1, key2)}((tree1, tree2))) +end + +""" + (tree1, tree2) = unpair(tree::AddressTree, key1::Symbol, key2::Symbol) + +Return the two subtrees at `key1` and `key2`, one or both of which may be empty. + +It is an error if there are any subtrees at keys other than `key1` and `key2`. +""" +function unpair(tree::AddressTree, key1::Symbol, key2::Symbol) + if length(collect(get_subtrees_shallow(tree))) != 2 + error("Not a pair") + end + (get_subtree(tree, key1), get_subtree(tree, key2)) +end + +@generated function Base.merge(tree1::StaticAddressTree{T1, Addrs1, SubmapTypes1}, + tree2::StaticAddressTree{T2, Addrs2, SubmapTypes2}) where {T1, T2, Addrs1, Addrs2, SubmapTypes1, SubmapTypes2} + + addr_to_type1 = Dict{Symbol, Type{<:AddressTree}}() + addr_to_type2 = Dict{Symbol, Type{<:AddressTree}}() + for (i, addr) in enumerate(Addrs1) + addr_to_type1[addr] = SubmapTypes1.parameters[i] + end + for (i, addr) in enumerate(Addrs2) + addr_to_type2[addr] = SubmapTypes2.parameters[i] + end + + merged_addrs = Tuple(union(Set(Addrs1), Set(Addrs2))) + submap_exprs = [] + + for addr in merged_addrs + type1 = get(addr_to_type1, addr, EmptyAddressTree) + type2 = get(addr_to_type2, addr, EmptyAddressTree) + + if type1 <: EmptyAddressTree + push!(submap_exprs, + quote tree2.subtrees.$addr end + ) + elseif type2 <: EmptyAddressTree + push!(submap_exprs, + quote tree1.subtrees.$addr end + ) + else + push!(submap_exprs, + quote merge(tree1.subtrees.$addr, tree2.subtrees.$addr) end + ) + end + end + + leaftype = Union{T1, T2} + + quote + StaticAddressTree{$leaftype}(NamedTuple{$merged_addrs}(($(submap_exprs...),))) + end +end + +@generated function _from_array(proto_choices::StaticAddressTree{LT, Addrs, SubmapTypes}, + arr::Vector{T}, start_idx::Int) where {LT, T, Addrs, SubmapTypes} + + perm = sortperm(collect(Addrs)) + sorted_addrs = Addrs[perm] + submap_var_names = Vector{Symbol}(undef, length(sorted_addrs)) + + exprs = [quote idx = start_idx end] + + for (idx, addr) in zip(perm, sorted_addrs) + submap_var_name = gensym(addr) + submap_var_names[idx] = submap_var_name + push!(exprs, + quote + (n_read, $submap_var_name) = _from_array(proto_choices.subtrees.$addr, arr, idx) + idx += n_read + end + ) + end + + quote + $(exprs...) + submaps = NamedTuple{Addrs}(( $(submap_var_names...), )) + choices = StaticAddressTree{LT}(submaps) + (idx - start_idx, choices) + end +end + +function get_address_schema(::Type{StaticAddressTree{LT, Addrs, SubtreeTypes}}) where {LT, Addrs, SubtreeTypes} + StaticAddressSchema(Set(Addrs)) +end + +export StaticAddressTree +export pair, unpair +export static_get_subtree, static_get_value \ No newline at end of file diff --git a/src/backprop.jl b/src/backprop.jl index 8590e989d..ac95220d3 100644 --- a/src/backprop.jl +++ b/src/backprop.jl @@ -17,31 +17,4 @@ increment_deriv!(arg, deriv) = ReverseDiff.increment_deriv!(arg, deriv) seed!(tracked) = ReverseDiff.seed!(tracked) unseed!(tracked) = ReverseDiff.unseed!(tracked) -using ReverseDiff: InstructionTape, TrackedReal, SpecialInstruction, TrackedArray - -######## -# fill # -######## - -function Base.fill(x::TrackedReal{V}, dims::Integer...) where {V} - tp = ReverseDiff.tape(x) - out = ReverseDiff.track(fill(ReverseDiff.value(x), dims...), V, tp) - ReverseDiff.record!(tp, SpecialInstruction, fill, (x, dims), out) - return out -end - -@noinline function ReverseDiff.special_reverse_exec!( - instruction::SpecialInstruction{typeof(fill)}) - x, dims = instruction.input - output = instruction.output - ReverseDiff.istracked(x) && ReverseDiff.increment_deriv!(x, sum(ReverseDiff.deriv(output))) - ReverseDiff.unseed!(output) - return nothing -end - -@noinline function ReverseDiff.special_forward_exec!( - instruction::SpecialInstruction{typeof(fill)}) - x, dims = instruction.input - ReverseDiff.value!(instruction.output, fill(value(x), dims...)) - return nothing -end +using ReverseDiff: InstructionTape, TrackedReal, SpecialInstruction, TrackedArray \ No newline at end of file diff --git a/src/choice_map.jl b/src/choice_map.jl deleted file mode 100644 index b7891b40a..000000000 --- a/src/choice_map.jl +++ /dev/null @@ -1,1009 +0,0 @@ -######################### -# choice map interface # -######################### - -""" - schema = get_address_schema(::Type{T}) where {T <: ChoiceMap} - -Return the (top-level) address schema for the given choice map. -""" -function get_address_schema end - -""" - submap = get_submap(choices::ChoiceMap, addr) - -Return the sub-assignment containing all choices whose address is prefixed by addr. - -It is an error if the assignment contains a value at the given address. If -there are no choices whose address is prefixed by addr then return an -`EmptyChoiceMap`. -""" -function get_submap end - -""" - value = get_value(choices::ChoiceMap, addr) - -Return the value at the given address in the assignment, or throw a KeyError if -no value exists. A syntactic sugar is `Base.getindex`: - - value = choices[addr] -""" -function get_value end - -""" - key_submap_iterable = get_submaps_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, submap::ChoiceMap)` for each top-level key -that has a non-empty sub-assignment. -""" -function get_submaps_shallow end - -""" - has_value(choices::ChoiceMap, addr) - -Return true if there is a value at the given address. -""" -function has_value end - -""" - key_submap_iterable = get_values_shallow(choices::ChoiceMap) - -Return an iterable collection of tuples `(key, value)` for each -top-level key associated with a value. -""" -function get_values_shallow end - -""" - abstract type ChoiceMap end - -Abstract type for maps from hierarchical addresses to values. -""" -abstract type ChoiceMap end - -""" - Base.isempty(choices::ChoiceMap) - -Return true if there are no values in the assignment. -""" -function Base.isempty(::ChoiceMap) - true -end - -@inline get_submap(choices::ChoiceMap, addr) = EmptyChoiceMap() -@inline has_value(choices::ChoiceMap, addr) = false -@inline get_value(choices::ChoiceMap, addr) = throw(KeyError(addr)) -@inline Base.getindex(choices::ChoiceMap, addr) = get_value(choices, addr) - -@inline function _has_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - has_value(submap, rest) -end - -@inline function _get_value(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_value(submap, rest) -end - -@inline function _get_submap(choices::T, addr::Pair) where {T <: ChoiceMap} - (first, rest) = addr - submap = get_submap(choices, first) - get_submap(submap, rest) -end - -function _show_pretty(io::IO, choices::ChoiceMap, pre, vert_bars::Tuple) - VERT = '\u2502' - PLUS = '\u251C' - HORZ = '\u2500' - LAST = '\u2514' - indent_vert = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent_vert_last = vcat(Char[' ' for _ in 1:pre], Char[VERT, '\n']) - indent = vcat(Char[' ' for _ in 1:pre], Char[PLUS, HORZ, HORZ, ' ']) - indent_last = vcat(Char[' ' for _ in 1:pre], Char[LAST, HORZ, HORZ, ' ']) - for i in vert_bars - indent_vert[i] = VERT - indent[i] = VERT - indent_last[i] = VERT - end - indent_vert_str = join(indent_vert) - indent_vert_last_str = join(indent_vert_last) - indent_str = join(indent) - indent_last_str = join(indent_last) - key_and_values = collect(get_values_shallow(choices)) - key_and_submaps = collect(get_submaps_shallow(choices)) - n = length(key_and_values) + length(key_and_submaps) - cur = 1 - for (key, value) in key_and_values - # For strings, `print` is what we want; `Base.show` includes quote marks. - # https://docs.julialang.org/en/v1/base/io-network/#Base.print - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key)) : $value\n") - cur += 1 - end - for (key, submap) in key_and_submaps - print(io, indent_vert_str) - print(io, (cur == n ? indent_last_str : indent_str) * "$(repr(key))\n") - _show_pretty(io, submap, pre + 4, cur == n ? (vert_bars...,) : (vert_bars..., pre+1)) - cur += 1 - end -end - -function Base.show(io::IO, ::MIME"text/plain", choices::ChoiceMap) - _show_pretty(io, choices, 0, ()) -end - -# assignments that have static address schemas should also support faster -# accessors, which make the address explicit in the type (Val(:foo) instaed of -# :foo) -function static_get_value end -function static_get_submap end - -function _fill_array! end -function _from_array end - -""" - arr::Vector{T} = to_array(choices::ChoiceMap, ::Type{T}) where {T} - -Populate an array with values of choices in the given assignment. - -It is an error if each of the values cannot be coerced into a value of the -given type. - -# Implementation - -To support `to_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - n::Int = _fill_array!(choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Populate `arr` with values from the given assignment, starting at `start_idx`, -and return the number of elements in `arr` that were populated. -""" -function to_array(choices::ChoiceMap, ::Type{T}) where {T} - arr = Vector{T}(undef, 32) - n = _fill_array!(choices, arr, 1) - @assert n <= length(arr) - resize!(arr, n) - arr -end - -function _fill_array!(value::T, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx - resize!(arr, 2 * start_idx) - end - arr[start_idx] = value - 1 -end - -function _fill_array!(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - if length(arr) < start_idx + length(value) - resize!(arr, 2 * (start_idx + length(value))) - end - arr[start_idx:start_idx+length(value)-1] = value - length(value) -end - - -""" - choices::ChoiceMap = from_array(proto_choices::ChoiceMap, arr::Vector) - -Return an assignment with the same address structure as a prototype -assignment, but with values read off from the given array. - -The order in which addresses are populated is determined by the prototype -assignment. It is an error if the number of choices in the prototype assignment -is not equal to the length the array. - -# Implementation - -To support `from_array`, a concrete subtype `T <: ChoiceMap` should implement -the following method: - - - (n::Int, choices::T) = _from_array(proto_choices::T, arr::Vector{V}, start_idx::Int) where {V} - -Return an assignment with the same address structure as a prototype assignment, -but with values read off from `arr`, starting at position `start_idx`, and the -number of elements read from `arr`. -""" -function from_array(proto_choices::ChoiceMap, arr::Vector) - (n, choices) = _from_array(proto_choices, arr, 1) - if n != length(arr) - error("Dimension mismatch: $n, $(length(arr))") - end - choices -end - -function _from_array(::T, arr::Vector{T}, start_idx::Int) where {T} - (1, arr[start_idx]) -end - -function _from_array(value::Vector{T}, arr::Vector{T}, start_idx::Int) where {T} - n_read = length(value) - (n_read, arr[start_idx:start_idx+n_read-1]) -end - - -""" - choices = Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - -Merge two choice maps. - -It is an error if the choice maps both have values at the same address, or if -one choice map has a value at an address that is the prefix of the address of a -value in the other choice map. -""" -function Base.merge(choices1::ChoiceMap, choices2::ChoiceMap) - choices = DynamicChoiceMap() - for (key, value) in get_values_shallow(choices1) - choices.leaf_nodes[key] = value - end - for (key, node1) in get_submaps_shallow(choices1) - node2 = get_submap(choices2, key) - node = merge(node1, node2) - choices.internal_nodes[key] = node - end - for (key, value) in get_values_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has leaf node at $key") - end - if haskey(choices.internal_nodes, key) - error("choices1 has internal node at $key and choices2 has leaf node at $key") - end - choices.leaf_nodes[key] = value - end - for (key, node) in get_submaps_shallow(choices2) - if haskey(choices.leaf_nodes, key) - error("choices1 has leaf node at $key and choices2 has internal node at $key") - end - if !haskey(choices.internal_nodes, key) - # otherwise it should already be included - choices.internal_nodes[key] = node - end - end - return choices -end - -""" -Variadic merge of choice maps. -""" -function Base.merge(choices1::ChoiceMap, choices_rest::ChoiceMap...) - reduce(Base.merge, choices_rest; init=choices1) -end - -function Base.:(==)(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || (get_value(b, addr) != value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || (get_value(a, addr) != value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if submap != get_submap(b, addr) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if submap != get_submap(a, addr) - return false - end - end - return true -end - -function Base.isapprox(a::ChoiceMap, b::ChoiceMap) - for (addr, value) in get_values_shallow(a) - if !has_value(b, addr) || !isapprox(get_value(b, addr), value) - return false - end - end - for (addr, value) in get_values_shallow(b) - if !has_value(a, addr) || !isapprox(get_value(a, addr), value) - return false - end - end - for (addr, submap) in get_submaps_shallow(a) - if !isapprox(submap, get_submap(b, addr)) - return false - end - end - for (addr, submap) in get_submaps_shallow(b) - if !isapprox(submap, get_submap(a, addr)) - return false - end - end - return true -end - - -export ChoiceMap -export get_address_schema -export get_submap -export get_value -export has_value -export get_submaps_shallow -export get_values_shallow -export static_get_value -export static_get_submap -export to_array, from_array - - -###################### -# static assignment # -###################### - -struct StaticChoiceMap{R,S,T,U} <: ChoiceMap - leaf_nodes::NamedTuple{R,S} - internal_nodes::NamedTuple{T,U} - isempty::Bool -end - -function StaticChoiceMap{R,S,T,U}(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - -function StaticChoiceMap(leaf_nodes::NamedTuple{R,S}, internal_nodes::NamedTuple{T,U}) where {R,S,T,U} - is_empty = length(leaf_nodes) == 0 && all(isempty(n) for n in internal_nodes) - StaticChoiceMap(leaf_nodes, internal_nodes, is_empty) -end - - -# invariant: all internal_nodes are nonempty - -function get_address_schema(::Type{StaticChoiceMap{R,S,T,U}}) where {R,S,T,U} - keys = Set{Symbol}() - for (key, _) in zip(R, S.parameters) - push!(keys, key) - end - for (key, _) in zip(T, U.parameters) - push!(keys, key) - end - StaticAddressSchema(keys) -end - -function Base.isempty(choices::StaticChoiceMap) - choices.isempty -end - -get_values_shallow(choices::StaticChoiceMap) = pairs(choices.leaf_nodes) -get_submaps_shallow(choices::StaticChoiceMap) = pairs(choices.internal_nodes) -has_value(choices::StaticChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::StaticChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::StaticChoiceMap, addr::Pair) = _get_submap(choices, addr) - -# NOTE: there is no static_has_value because this is known from the static -# address schema - -## has_value ## - -function has_value(choices::StaticChoiceMap, key::Symbol) - haskey(choices.leaf_nodes, key) -end - -## get_submap ## - -function get_submap(choices::StaticChoiceMap, key::Symbol) - if haskey(choices.internal_nodes, key) - choices.internal_nodes[key] - elseif haskey(choices.leaf_nodes, key) - throw(KeyError(key)) - else - EmptyChoiceMap() - end -end - -function static_get_submap(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.internal_nodes[A] -end - -## get_value ## - -function get_value(choices::StaticChoiceMap, key::Symbol) - choices.leaf_nodes[key] -end - -function static_get_value(choices::StaticChoiceMap, ::Val{A}) where {A} - choices.leaf_nodes[A] -end - -# convert from any other schema that has only Val{:foo} addresses -function StaticChoiceMap(other::ChoiceMap) - leaf_keys_and_nodes = collect(get_values_shallow(other)) - internal_keys_and_nodes = collect(get_submaps_shallow(other)) - if length(leaf_keys_and_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(leaf_keys_and_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) - end - if length(internal_keys_and_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(internal_keys_and_nodes...)) - else - (internal_keys, internal_nodes) = ((), ()) - end - StaticChoiceMap( - NamedTuple{leaf_keys}(leaf_nodes), - NamedTuple{internal_keys}(internal_nodes), - isempty(other)) -end - -""" - choices = pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - -Return an assignment that contains `choices1` as a sub-assignment under `key1` -and `choices2` as a sub-assignment under `key2`. -""" -function pair(choices1::ChoiceMap, choices2::ChoiceMap, key1::Symbol, key2::Symbol) - StaticChoiceMap(NamedTuple(), NamedTuple{(key1,key2)}((choices1, choices2)), - isempty(choices1) && isempty(choices2)) -end - -""" - (choices1, choices2) = unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - -Return the two sub-assignments at `key1` and `key2`, one or both of which may be empty. - -It is an error if there are any top-level values, or any non-empty top-level -sub-assignments at keys other than `key1` and `key2`. -""" -function unpair(choices::ChoiceMap, key1::Symbol, key2::Symbol) - if !isempty(get_values_shallow(choices)) || length(collect(get_submaps_shallow(choices))) > 2 - error("Not a pair") - end - a = get_submap(choices, key1) - b = get_submap(choices, key2) - (a, b) -end - -# TODO make a generated function? -function _fill_array!(choices::StaticChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for value in choices.leaf_nodes - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for node in choices.internal_nodes - n_written = _fill_array!(node, arr, idx) - idx += n_written - end - idx - start_idx -end - -@generated function _from_array( - proto_choices::StaticChoiceMap{R,S,T,U}, arr::Vector{V}, start_idx::Int) where {R,S,T,U,V} - leaf_node_keys = proto_choices.parameters[1] - leaf_node_types = proto_choices.parameters[2].parameters - internal_node_keys = proto_choices.parameters[3] - internal_node_types = proto_choices.parameters[4].parameters - - exprs = [quote idx = start_idx end] - leaf_node_names = [] - internal_node_names = [] - - # leaf nodes - for key in leaf_node_keys - value = gensym() - push!(leaf_node_names, value) - push!(exprs, quote - (n_read, $value) = _from_array(proto_choices.leaf_nodes.$key, arr, idx) - idx += n_read - end) - end - - # internal nodes - for key in internal_node_keys - node = gensym() - push!(internal_node_names, node) - push!(exprs, quote - (n_read, $node) = _from_array(proto_choices.internal_nodes.$key, arr, idx) - idx += n_read - end) - end - - quote - $(exprs...) - leaf_nodes_field = NamedTuple{R,S}(($(leaf_node_names...),)) - internal_nodes_field = NamedTuple{T,U}(($(internal_node_names...),)) - choices = StaticChoiceMap{R,S,T,U}(leaf_nodes_field, internal_nodes_field) - (idx - start_idx, choices) - end -end - -@generated function Base.merge(choices1::StaticChoiceMap{R,S,T,U}, - choices2::StaticChoiceMap{W,X,Y,Z}) where {R,S,T,U,W,X,Y,Z} - - # unpack first assignment type parameters - leaf_node_keys1 = choices1.parameters[1] - leaf_node_types1 = choices1.parameters[2].parameters - internal_node_keys1 = choices1.parameters[3] - internal_node_types1 = choices1.parameters[4].parameters - keys1 = (leaf_node_keys1..., internal_node_keys1...,) - - # unpack second assignment type parameters - leaf_node_keys2 = choices2.parameters[1] - leaf_node_types2 = choices2.parameters[2].parameters - internal_node_keys2 = choices2.parameters[3] - internal_node_types2 = choices2.parameters[4].parameters - keys2 = (leaf_node_keys2..., internal_node_keys2...,) - - # leaf vs leaf collision is an error - colliding_leaf_leaf_keys = intersect(leaf_node_keys1, leaf_node_keys2) - if !isempty(colliding_leaf_leaf_keys) - error("choices1 and choices2 both have leaf nodes at key(s): $colliding_leaf_leaf_keys") - end - - # leaf vs internal collision is an error - colliding_leaf_internal_keys = intersect(leaf_node_keys1, internal_node_keys2) - if !isempty(colliding_leaf_internal_keys) - error("choices1 has leaf node and choices2 has internal node at key(s): $colliding_leaf_internal_keys") - end - - # internal vs leaf collision is an error - colliding_internal_leaf_keys = intersect(internal_node_keys1, leaf_node_keys2) - if !isempty(colliding_internal_leaf_keys) - error("choices1 has internal node and choices2 has leaf node at key(s): $colliding_internal_leaf_keys") - end - - # internal vs internal collision is not an error, recursively call merge - colliding_internal_internal_keys = (intersect(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys1_exclusive = (setdiff(internal_node_keys1, internal_node_keys2)...,) - internal_node_keys2_exclusive = (setdiff(internal_node_keys2, internal_node_keys1)...,) - - # leaf nodes named tuple - leaf_node_keys = (leaf_node_keys1..., leaf_node_keys2...,) - leaf_node_types = map(QuoteNode, (leaf_node_types1..., leaf_node_types2...,)) - leaf_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys1]..., - [Expr(:(.), :(choices2.leaf_nodes), QuoteNode(key)) - for key in leaf_node_keys2]...) - leaf_nodes = Expr(:call, - Expr(:curly, :NamedTuple, - QuoteNode(leaf_node_keys), - Expr(:curly, :Tuple, leaf_node_types...)), - leaf_node_values) - - # internal nodes named tuple - internal_node_keys = (internal_node_keys1_exclusive..., - internal_node_keys2_exclusive..., - colliding_internal_internal_keys...) - internal_node_values = Expr(:tuple, - [Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)) - for key in internal_node_keys1_exclusive]..., - [Expr(:(.), :(choices2.internal_nodes), QuoteNode(key)) - for key in internal_node_keys2_exclusive]..., - [Expr(:call, :merge, - Expr(:(.), :(choices1.internal_nodes), QuoteNode(key)), - Expr(:(.), :(choices2.internal_nodes), QuoteNode(key))) - for key in colliding_internal_internal_keys]...) - internal_nodes = Expr(:call, - Expr(:curly, :NamedTuple, QuoteNode(internal_node_keys)), - internal_node_values) - - # construct assignment from named tuples - Expr(:call, :StaticChoiceMap, leaf_nodes, internal_nodes) -end - -export StaticChoiceMap -export pair, unpair - -####################### -# dynamic assignment # -####################### - -struct DynamicChoiceMap <: ChoiceMap - leaf_nodes::Dict{Any,Any} - internal_nodes::Dict{Any,Any} - function DynamicChoiceMap(leaf_nodes::Dict{Any,Any}, internal_nodes::Dict{Any,Any}) - new(leaf_nodes, internal_nodes) - end -end - -# invariant: all internal nodes are nonempty - -""" - struct DynamicChoiceMap <: ChoiceMap .. end - -A mutable map from arbitrary hierarchical addresses to values. - - choices = DynamicChoiceMap() - -Construct an empty map. - - choices = DynamicChoiceMap(tuples...) - -Construct a map containing each of the given (addr, value) tuples. -""" -function DynamicChoiceMap() - DynamicChoiceMap(Dict(), Dict()) -end - -function DynamicChoiceMap(tuples...) - choices = DynamicChoiceMap() - for (addr, value) in tuples - choices[addr] = value - end - choices -end - -""" - choices = DynamicChoiceMap(other::ChoiceMap) - -Copy a choice map, returning a mutable choice map. -""" -function DynamicChoiceMap(other::ChoiceMap) - choices = DynamicChoiceMap() - for (addr, val) in get_values_shallow(other) - choices[addr] = val - end - for (addr, submap) in get_submaps_shallow(other) - set_submap!(choices, addr, DynamicChoiceMap(submap)) - end - choices -end - -""" - choices = choicemap() - -Construct an empty mutable choice map. -""" -function choicemap() - DynamicChoiceMap() -end - -""" - choices = choicemap(tuples...) - -Construct a mutable choice map initialized with given address, value tuples. -""" -function choicemap(tuples...) - DynamicChoiceMap(tuples...) -end - -get_address_schema(::Type{DynamicChoiceMap}) = DynamicAddressSchema() - -get_values_shallow(choices::DynamicChoiceMap) = choices.leaf_nodes - -get_submaps_shallow(choices::DynamicChoiceMap) = choices.internal_nodes - -has_value(choices::DynamicChoiceMap, addr::Pair) = _has_value(choices, addr) - -get_value(choices::DynamicChoiceMap, addr::Pair) = _get_value(choices, addr) - -get_submap(choices::DynamicChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicChoiceMap, addr) - if haskey(choices.internal_nodes, addr) - choices.internal_nodes[addr] - elseif haskey(choices.leaf_nodes, addr) - throw(KeyError(addr)) - else - EmptyChoiceMap() - end -end - -has_value(choices::DynamicChoiceMap, addr) = haskey(choices.leaf_nodes, addr) - -get_value(choices::DynamicChoiceMap, addr) = choices.leaf_nodes[addr] - -function Base.isempty(choices::DynamicChoiceMap) - isempty(choices.leaf_nodes) && isempty(choices.internal_nodes) -end - -# mutation (not part of the assignment interface) - -""" - set_value!(choices::DynamicChoiceMap, addr, value) - -Set the given value for the given address. - -Will cause any previous value or sub-assignment at this address to be deleted. -It is an error if there is already a value present at some prefix of the given address. - -The following syntactic sugar is provided: - - choices[addr] = value -""" -function set_value!(choices::DynamicChoiceMap, addr, value) - delete!(choices.internal_nodes, addr) - choices.leaf_nodes[addr] = value -end - -function set_value!(choices::DynamicChoiceMap, addr::Pair, value) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - node = choices.internal_nodes[first] - set_value!(node, rest, value) -end - -""" - set_submap!(choices::DynamicChoiceMap, addr, submap::ChoiceMap) - -Replace the sub-assignment rooted at the given address with the given sub-assignment. -Set the given value for the given address. - -Will cause any previous value or sub-assignment at the given address to be deleted. -It is an error if there is already a value present at some prefix of address. -""" -function set_submap!(choices::DynamicChoiceMap, addr, new_node) - delete!(choices.leaf_nodes, addr) - delete!(choices.internal_nodes, addr) - if !isempty(new_node) - choices.internal_nodes[addr] = new_node - end -end - -function set_submap!(choices::DynamicChoiceMap, addr::Pair, new_node) - (first, rest) = addr - if haskey(choices.leaf_nodes, first) - # we are not writing to the address directly, so we error instead of - # delete the existing node. - error("Tried to create assignment at $first but there was already a value there.") - end - if haskey(choices.internal_nodes, first) - node = choices.internal_nodes[first] - else - node = DynamicChoiceMap() - choices.internal_nodes[first] = node - end - set_submap!(node, rest, new_node) -end - -Base.setindex!(choices::DynamicChoiceMap, value, addr) = set_value!(choices, addr, value) - -function _fill_array!(choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - leaf_keys_sorted = sort(collect(keys(choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - value = choices.leaf_nodes[key] - n_written = _fill_array!(value, arr, idx) - idx += n_written - end - for key in internal_node_keys_sorted - n_written = _fill_array!(get_submap(choices, key), arr, idx) - idx += n_written - end - idx - start_idx -end - -function _from_array(proto_choices::DynamicChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - @assert length(arr) >= start_idx - choices = DynamicChoiceMap() - leaf_keys_sorted = sort(collect(keys(proto_choices.leaf_nodes))) - internal_node_keys_sorted = sort(collect(keys(proto_choices.internal_nodes))) - idx = start_idx - for key in leaf_keys_sorted - (n_read, value) = _from_array(proto_choices.leaf_nodes[key], arr, idx) - idx += n_read - choices.leaf_nodes[key] = value - end - for key in internal_node_keys_sorted - (n_read, node) = _from_array(get_submap(proto_choices, key), arr, idx) - idx += n_read - choices.internal_nodes[key] = node - end - (idx - start_idx, choices) -end - -export DynamicChoiceMap -export choicemap -export set_value! -export set_submap! - - -####################################### -## vector combinator for assignments # -####################################### - -# TODO implement LeafVectorChoiceMap, which stores a vector of leaf nodes - -struct InternalVectorChoiceMap{T} <: ChoiceMap - internal_nodes::Vector{T} - is_empty::Bool -end - -function vectorize_internal(nodes::Vector{T}) where {T} - is_empty = all(map(isempty, nodes)) - InternalVectorChoiceMap(nodes, is_empty) -end - -# note some internal nodes may be empty - -get_address_schema(::Type{InternalVectorChoiceMap}) = VectorAddressSchema() - -Base.isempty(choices::InternalVectorChoiceMap) = choices.is_empty -has_value(choices::InternalVectorChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::InternalVectorChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::InternalVectorChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::InternalVectorChoiceMap, addr::Int) - if addr > 0 && addr <= length(choices.internal_nodes) - choices.internal_nodes[addr] - else - EmptyChoiceMap() - end -end - -function get_submaps_shallow(choices::InternalVectorChoiceMap) - ((i, choices.internal_nodes[i]) - for i=1:length(choices.internal_nodes) - if !isempty(choices.internal_nodes[i])) -end - -get_values_shallow(::InternalVectorChoiceMap) = () - -function _fill_array!(choices::InternalVectorChoiceMap, arr::Vector{T}, start_idx::Int) where {T} - idx = start_idx - for key=1:length(choices.internal_nodes) - n = _fill_array!(choices.internal_nodes[key], arr, idx) - idx += n - end - idx - start_idx -end - -function _from_array(proto_choices::InternalVectorChoiceMap{U}, arr::Vector{T}, start_idx::Int) where {T,U} - @assert length(arr) >= start_idx - nodes = Vector{U}(undef, length(proto_choices.internal_nodes)) - idx = start_idx - for key=1:length(proto_choices.internal_nodes) - (n_read, nodes[key]) = _from_array(proto_choices.internal_nodes[key], arr, idx) - idx += n_read - end - choices = InternalVectorChoiceMap(nodes, proto_choices.is_empty) - (idx - start_idx, choices) -end - -export InternalVectorChoiceMap -export vectorize_internal - - -#################### -# empty assignment # -#################### - -struct EmptyChoiceMap <: ChoiceMap end - -Base.isempty(::EmptyChoiceMap) = true -get_address_schema(::Type{EmptyChoiceMap}) = EmptyAddressSchema() -get_submaps_shallow(::EmptyChoiceMap) = () -get_values_shallow(::EmptyChoiceMap) = () - -_fill_array!(::EmptyChoiceMap, arr::Vector, start_idx::Int) = 0 -_from_array(::EmptyChoiceMap, arr::Vector, start_idx::Int) = (0, EmptyChoiceMap()) - -export EmptyChoiceMap - -############################################ -# Nested-dict–like accessor for choicemaps # -############################################ - -""" -Wrapper for a `ChoiceMap` that provides nested-dict–like syntax, rather than -the default syntax which looks like a flat dict of full keypaths. - -```jldoctest -julia> using Gen -julia> c = choicemap((:a, 1), - (:b => :c, 2)); -julia> cv = nested_view(c); -julia> c[:a] == cv[:a] -true -julia> c[:b => :c] == cv[:b][:c] -true -julia> length(cv) -2 -julia> length(cv[:b]) -1 -julia> sort(collect(keys(cv))) -[:a, :b] -julia> sort(collect(keys(cv[:b]))) -[:c] -``` -""" -struct ChoiceMapNestedView - choice_map::ChoiceMap -end - -function Base.getindex(choices::ChoiceMapNestedView, addr) - if has_value(choices.choice_map, addr) - return get_value(choices.choice_map, addr) - end - submap = get_submap(choices.choice_map, addr) - if isempty(submap) - throw(KeyError(addr)) - end - ChoiceMapNestedView(submap) -end - -function Base.iterate(c::ChoiceMapNestedView) - inner_iterator = Base.Iterators.flatten(( - get_values_shallow(c.choice_map), - ((k, ChoiceMapNestedView(v)) - for (k, v) in get_submaps_shallow(c.choice_map)))) - r = Base.iterate(inner_iterator) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -function Base.iterate(c::ChoiceMapNestedView, state) - (inner_iterator, inner_state) = state - r = Base.iterate(inner_iterator, inner_state) - if r == nothing - return nothing - end - (next_kv, next_inner_state) = r - (next_kv, (inner_iterator, next_inner_state)) -end - -# TODO: Allow different implementations of this method depending on the -# concrete type of the `ChoiceMap`, so that an already-existing data structure -# with faster key lookup (analogous to `Base.KeySet`) can be exposed if it -# exists. -Base.keys(cv::Gen.ChoiceMapNestedView) = (k for (k, v) in cv) - -function Base.:(==)(a::ChoiceMapNestedView, b::ChoiceMapNestedView) - a.choice_map == b.choice_map -end - -# Length of a `ChoiceMapNestedView` is number of leaf values + number of -# submaps. Motivation: This matches what `length` would return for the -# equivalent nested dict. -function Base.length(cv::ChoiceMapNestedView) - +(get_values_shallow(cv.choice_map) |> collect |> length, - get_submaps_shallow(cv.choice_map) |> collect |> length) -end - -function Base.show(io::IO, ::MIME"text/plain", c::ChoiceMapNestedView) - Base.show(io, MIME"text/plain"(), c.choice_map) -end - -nested_view(c::ChoiceMap) = ChoiceMapNestedView(c) - -# TODO(https://github.com/probcomp/Gen/issues/167): Also allow calling -# `nested_view(::Trace)`, to get a nested-dict–like view of the choicemap and -# aux data together. - -export nested_view - -""" - selected_choices = get_selected(choices::ChoiceMap, selection::Selection) - -Filter the choice map to include only choices in the given selection. - -Returns a new choice map. -""" -function get_selected( - choices::ChoiceMap, selection::Selection) - output = choicemap() - for (key, value) in get_values_shallow(choices) - if (key in selection) - output[key] = value - end - end - for (key, submap) in get_submaps_shallow(choices) - subselection = selection[key] - set_submap!(output, key, get_selected(submap, subselection)) - end - output -end - -export get_selected diff --git a/src/diff.jl b/src/diff.jl index e43e61c36..ead1cc3a6 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -82,6 +82,12 @@ export UnknownChange, NoChange export SetDiff, DictDiff, VectorDiff export IntDiff +function all_nochange(itr) + for i in itr + i !== NoChange() && return false + end + return true +end ############################### ## differencing of Julia code # @@ -97,18 +103,37 @@ struct Diffed{V,DV <: Diff} diff::DV end +function Base.show(io::IO, dv::Diffed) + print(io, "Diffed(") + print(io, strip_diff(dv)) + print(io, ", ") + print(io, get_diff(dv)) + print(io, ")") +end +Base.:(==)(x::Diffed, y::Diffed) = strip_diff(x) == strip_diff(y) && get_diff(x) == get_diff(y) + # obtain the diff part of a Diffed value -get_diff(diffed::Diffed) = diffed.diff +get_diff(diffed::Diffed) = getfield(diffed, :diff) # use `getfield` to avoid infinite recursion with `getproperty` # a value that is not wrapped in Diffed is a constant get_diff(value) = NoChange() -strip_diff(diffed::Diffed) = diffed.value +strip_diff(diffed::Diffed) = getfield(diffed, :value) strip_diff(value) = value export Diffed, diff, strip_diff +# getting properties of `NoChange` or `UnknownChange` diffed objects +Base.getproperty(x::Diffed{<:Any, NoChange}, sym::Symbol) = Diffed(getproperty(strip_diff(x), sym), NoChange()) +Base.getproperty(x::Diffed{<:Any, UnknownChange}, sym::Symbol) = Diffed(getproperty(strip_diff(x), sym), UnknownChange()) + +# simple map +# TODO: handle VectorDiff +Base.map(f::Function, itr::Diffed{<:Any, NoChange}) = Diffed(map(f, strip_diff(itr)), NoChange()) +Base.map(f::Function, itr::Diffed{<:Any, UnknownChange}) = Diffed(map(f, strip_diff(itr)), UnknownChange()) +Base.map(f::Diffed{Function, NoChange}, itr::Diffed) = map(strip_diff(f), itr) + # sets function Base.in(element::Diffed{V,DV}, set::Diffed{T,DT}) where {V, T <: AbstractSet{V}, DV, DT} @@ -221,6 +246,14 @@ function Base.length(vec::Diffed{T,VectorDiff}) where {T <: Union{AbstractVector end end +function Base.collect(T::Type, v::Diffed{U, NoChange}) where {U <: Union{AbstractVector, Tuple}} + Diffed(collect(T, strip_diff(v)), NoChange()) +end +function Base.collect(T::Type, v::Diffed{U, UnknownChange}) where {U <: Union{AbstractVector, Tuple}} + Diffed(collect(T, strip_diff(v)), UnknownChange()) +end +# TODO: vector diff + # TODO: we know that indeices before teh deleted/inserted idnex have not been changed function Base.getindex(vec::Union{AbstractVector,Tuple}, idx::Diffed{U,DU}) where {U <: Integer, DU} diff --git a/src/distribution.jl b/src/distribution.jl new file mode 100644 index 000000000..865375ed2 --- /dev/null +++ b/src/distribution.jl @@ -0,0 +1,135 @@ +############################### +# Core Distribution Interface # +############################### + +struct DistributionTrace{T, Dist} <: Trace + dist::Dist + val::T + args + score::Float64 +end + +abstract type Distribution{T} <: GenerativeFunction{T, DistributionTrace{T}} end +DistributionTrace{T, Dist}(dist::Dist, val::T, args::Tuple) where {T, Dist} = DistributionTrace{T, Dist}(dist, val, args, logpdf(dist, val, args...)) +DistributionTrace(dist::Dist, val::T, args::Tuple) where{Dist, T} = DistributionTrace{T, Dist}(dist, val, args) + +# we need to know the specific distribution in the trace type so the compiler can specialize GFI calls fully +@inline get_trace_type(::Dist) where {T, Dist <: Distribution{T}} = DistributionTrace{T, Dist} + +function Base.convert(::Type{<:DistributionTrace{U, <:Any}}, tr::DistributionTrace{<:Any, Dist}) where {U, Dist} + DistributionTrace{U, Dist}(tr.dist, convert(U, tr.val), tr.args, tr.score) +end + +""" + val::T = random(dist::Distribution{T}, args...) + +Sample a random choice from the given distribution with the given arguments. +""" +function random end + +""" + lpdf = logpdf(dist::Distribution{T}, value::T, args...) + +Evaluate the log probability (density) of the value. +""" +function logpdf end + +""" + has::Bool = has_output_grad(dist::Distribution) + +Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. +""" +function has_output_grad end + +""" + grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) + +Compute the gradient of the logpdf with respect to the value, and each of the arguments. + +If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. +Otherwise, the first element of the tuple is the gradient with respect to the value. +If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. +Otherwise, this element contains the gradient with respect to the `i`th argument. +""" +function logpdf_grad end + +function is_discrete end + +# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl + +get_return_type(::Distribution{T}) where {T} = T + + +############################## +# Distribution GFI Interface # +############################## + +@inline Base.getindex(trace::DistributionTrace) = trace.val +@inline Gen.get_args(trace::DistributionTrace) = trace.args +@inline Gen.get_choices(trace::DistributionTrace) = Value(trace.val) # should be able to get type of val +@inline Gen.get_retval(trace::DistributionTrace) = trace.val +@inline Gen.get_gen_fn(trace::DistributionTrace) = trace.dist +@inline Gen.get_score(trace::DistributionTrace) = trace.score +@inline Gen.project(trace::DistributionTrace, ::EmptySelection) = 0. +@inline Gen.project(trace::DistributionTrace, ::AllSelection) = get_score(trace) + +@inline function Gen.simulate(dist::Distribution, args::Tuple) + val = random(dist, args...) + DistributionTrace(dist, val, args) +end +@inline Gen.generate(dist::Distribution, args::Tuple, ::EmptyChoiceMap) = (simulate(dist, args), 0.) +@inline function Gen.generate(dist::Distribution, args::Tuple, constraints::Value) + tr = DistributionTrace(dist, get_value(constraints), args) + weight = get_score(tr) + (tr, weight) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, spec::Value, ::AllSelection) + new_tr = DistributionTrace(tr.dist, get_value(spec), args) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, spec::Value, ::EmptyAddressTree) + new_tr = DistributionTrace(tr.dist, get_value(spec), args) + (new_tr, get_score(new_tr), UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, ::EmptyAddressTree, ::Selection) + new_tr = DistributionTrace(tr.dist, tr.val, args) + weight = get_score(new_tr) - get_score(tr) + (new_tr, weight, NoChange(), EmptyChoiceMap()) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple{Vararg{NoChange}}, ::EmptyAddressTree, ::Selection) + (tr, 0., NoChange(), EmptyChoiceMap()) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, ::AllSelection, ::EmptyAddressTree) + new_val = random(tr.dist, args...) + new_tr = DistributionTrace(tr.dist, new_val, args) + (new_tr, 0., UnknownChange(), get_choices(tr)) +end +@inline function Gen.update(tr::DistributionTrace, args::Tuple, argdiffs::Tuple, ::AllSelection, ::AllSelection) + new_val = random(tr.dist, args...) + new_tr = DistributionTrace(tr.dist, new_val, args) + (new_tr, -get_score(tr), UnknownChange(), get_choices(tr)) +end +@inline function Gen.propose(dist::Distribution, args::Tuple) + val = random(dist, args...) + score = logpdf(dist, val, args...) + (Value(val), score, val) +end +@inline function Gen.assess(dist::Distribution, args::Tuple, choices::Value) + weight = logpdf(dist, get_value(choices), args...) + (weight, choices.val) +end +@inline function Gen.assess(dist::Distribution, args::Tuple, ::EmptyChoiceMap) + error("Call to `assess` did not provide a value constraint for a call to $dist.") +end + +########### +# Exports # +########### + +export Distribution +export random +export logpdf +export logpdf_grad +export has_output_grad +export is_discrete diff --git a/src/dsl/dsl.jl b/src/dsl/dsl.jl index 2f60b7fb3..6a016af62 100644 --- a/src/dsl/dsl.jl +++ b/src/dsl/dsl.jl @@ -7,6 +7,7 @@ const DSL_ARG_GRAD_ANNOTATION = :grad const DSL_RET_GRAD_ANNOTATION = :grad const DSL_TRACK_DIFFS_ANNOTATION = :diffs const DSL_NO_JULIA_CACHE_ANNOTATION = :nojuliacache +const DSL_MACROS = Set([Symbol("@trace"), Symbol("@param")]) struct Argument name::Symbol @@ -65,6 +66,8 @@ include("dynamic.jl") include("static.jl") function desugar_tildes(expr) + trace_ref = GlobalRef(@__MODULE__, Symbol("@trace")) + line_num = LineNumberNode(1, :none) MacroTools.postwalk(expr) do e # Replace tilde statements with :gentrace expressions if MacroTools.@capture(e, {*} ~ rhs_call) diff --git a/src/dsl/static.jl b/src/dsl/static.jl index 483a2eae1..bf2c86c5f 100644 --- a/src/dsl/static.jl +++ b/src/dsl/static.jl @@ -47,10 +47,6 @@ end split_addr!(keys, addr_expr::QuoteNode) = push!(keys, addr_expr) split_addr!(keys, addr_expr::Symbol) = push!(keys, addr_expr) -"Construct choice-at or call-at combinator depending on type." -choice_or_call_at(gen_fn::GenerativeFunction, addr_typ) = call_at(gen_fn, addr_typ) -choice_or_call_at(dist::Distribution, addr_typ) = choice_at(dist, addr_typ) - "Generate informative node name for a Julia expression." gen_node_name(arg::Any) = gensym(string(arg)) gen_node_name(arg::Expr) = gensym(arg.head) @@ -74,12 +70,12 @@ function parse_trace_expr!(stmts, bindings, fn, args, addr) end addr = keys[1].value # Get top level address if length(keys) > 1 - # For each nesting level, wrap gen_fn_or_dist within choice_at / call_at + # For each nesting level, wrap gen_fn_or_dist within call_at for key in keys[2:end] push!(stmts, :($(esc(gen_fn_or_dist)) = - choice_or_call_at($(esc(gen_fn_or_dist)), Any))) + call_at($(esc(gen_fn_or_dist)), Any))) end - # Append the nested addresses as arguments to choice_at / call_at + # Append the nested addresses as arguments to call_at args = [args; reverse(keys[2:end])] end # Handle arguments to the traced call diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d5079..0bf37a077 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -9,22 +9,6 @@ function GFAssessState(choices, params::Dict{Symbol,Any}) GFAssessState(choices, 0., AddressVisitor(), params) end -function traceat(state::GFAssessState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # get return value - retval = get_value(state.choices, key) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 6a7278a02..3ad34643e 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -74,7 +74,7 @@ function traceat(state::GFBackpropParamsState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) record!(state.tape, ReverseDiff.SpecialInstruction, dist, @@ -275,7 +275,7 @@ function traceat(state::GFBackpropTraceState, dist::Distribution{T}, args_maybe_tracked, key) where {T} local retval::T visit!(state.visitor, key) - retval = get_choice(state.trace, key).retval + retval = get_retval(get_call(state.trace, key).subtrace) args = map(value, args_maybe_tracked) score_tracked = track(logpdf(dist, retval, args...), state.tape) if key in state.selection @@ -317,7 +317,7 @@ function traceat(state::GFBackpropTraceState, gen_fn::GenerativeFunction{T,U}, retval_maybe_tracked = retval @assert !istracked(retval_maybe_tracked) end - selection = state.selection[key] + selection = get_subselection(state.selection, key) record = BackpropTraceRecord(gen_fn, subtrace, selection, state.value_choices, state.gradient_choices, key) record!(state.tape, ReverseDiff.SpecialInstruction, record, (args_maybe_tracked...,), retval_maybe_tracked) diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index 58d0cf6b0..57baa6841 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -112,48 +112,36 @@ function visit!(visitor::AddressVisitor, addr) push!(visitor.visited, addr) end -function all_visited(visited::Selection, choices::ChoiceMap) - allvisited = true - for (key, _) in get_values_shallow(choices) - allvisited = allvisited && (key in visited) - end - for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - allvisited = allvisited && all_visited(subvisited, submap) +all_constraints_visited(::Selection, ::Value) = false +all_constraints_visited(::AllSelection, ::Value) = true +all_constraints_visited(::Selection, ::Selection) = true # we're allowed to not visit selections +all_constraints_visited(::Selection, ::EmptyAddressTree) = true +all_constraints_visited(::AllSelection, ::EmptyAddressTree) = true +function all_constraints_visited(visited::Selection, spec::UpdateSpec) + for (key, subtree) in get_subtrees_shallow(spec) + if !all_constraints_visited(get_subselection(visited, key), subtree) + return false end end - allvisited + return true end +get_unvisited(::Selection, v::Value) = v +get_unvisited(::AllSelection, v::Value) = EmptyChoiceMap() function get_unvisited(visited::Selection, choices::ChoiceMap) unvisited = choicemap() - for (key, _) in get_values_shallow(choices) - if !(key in visited) - set_value!(unvisited, key, get_value(choices, key)) - end - end for (key, submap) in get_submaps_shallow(choices) - if !(key in visited) - subvisited = visited[key] - sub_unvisited = get_unvisited(subvisited, submap) - set_submap!(unvisited, key, sub_unvisited) - end + sub_unvisited = get_unvisited(get_subselection(visited, key), submap) + set_submap!(unvisited, key, sub_unvisited) end unvisited end get_visited(visitor) = visitor.visited -function check_no_submap(constraints::ChoiceMap, addr) +function check_is_empty(constraints::ChoiceMap, addr) if !isempty(get_submap(constraints, addr)) - error("Expected a value at address $addr but found a sub-assignment") - end -end - -function check_no_value(constraints::ChoiceMap, addr) - if has_value(constraints, addr) - error("Expected a sub-assignment at address $addr but found a value") + error("Expected a value or EmptyChoiceMap at address $addr but found a sub-assignment") end end @@ -167,7 +155,6 @@ include("propose.jl") include("assess.jl") include("project.jl") include("update.jl") -include("regenerate.jl") include("backprop.jl") export DynamicDSLFunction diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index df6a5f465..4a5796aae 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -11,38 +11,6 @@ function GFGenerateState(gen_fn, args, constraints, params) GFGenerateState(trace, constraints, 0., AddressVisitor(), params) end -function traceat(state::GFGenerateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) - - # get return value - if constrained - retval = get_value(state.constraints, key) - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - # increment weight - if constrained - state.weight += score - end - - retval -end - function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/project.jl b/src/dynamic/project.jl index 358398e55..334525814 100644 --- a/src/dynamic/project.jl +++ b/src/dynamic/project.jl @@ -1,18 +1,12 @@ -function project_recurse(trie::Trie{Any,ChoiceOrCallRecord}, +function project_recurse(trie::Trie{Any, CallRecord}, selection::Selection) weight = 0. - for (key, choice_or_call) in get_leaf_nodes(trie) - if choice_or_call.is_choice - if key in selection - weight += choice_or_call.score - end - else - subselection = selection[key] - weight += project(choice_or_call.subtrace_or_retval, subselection) - end + for (key, call) in get_leaf_nodes(trie) + subselection = get_subselection(selection, key) + weight += project(call.subtrace, subselection) end for (key, subtrie) in get_internal_nodes(trie) - subselection = selection[key] + subselection = get_subselection(selection, key) weight += project_recurse(subtrie, subselection) end weight diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index e4281f49e..32fc95da2 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -9,25 +9,6 @@ function GFProposeState(params::Dict{Symbol,Any}) GFProposeState(choicemap(), 0., AddressVisitor(), params) end -function traceat(state::GFProposeState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # sample return value - retval = random(dist, args...) - - # update assignment - set_value!(state.choices, key, retval) - - # update weight - state.weight += logpdf(dist, retval, args...) - - retval -end - function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local retval::T diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl deleted file mode 100644 index 81ba8b3c4..000000000 --- a/src/dynamic/regenerate.jl +++ /dev/null @@ -1,142 +0,0 @@ -mutable struct GFRegenerateState - prev_trace::DynamicDSLTrace - trace::DynamicDSLTrace - selection::Selection - weight::Float64 - visitor::AddressVisitor - params::Dict{Symbol,Any} -end - -function GFRegenerateState(gen_fn, args, prev_trace, - selection, params) - visitor = AddressVisitor() - GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection, - 0., visitor, params) -end - -function traceat(state::GFRegenerateState, dist::Distribution{T}, - args, key) where {T} - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check whether the key was selected - in_selection = key in state.selection - - # get return value - if has_previous && in_selection - retval = random(dist, args...) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update weight - if has_previous && !in_selection - state.weight += score - prev_score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - -function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, - args, key) where {T,U} - local prev_retval::T - local trace::U - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check whether the key was selected - subselection = state.selection[key] - - # get subtrace - has_previous = has_call(state.prev_trace, key) - if has_previous - prev_call = get_call(state.prev_trace, key) - prev_subtrace = prev_call.subtrace - get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) - (subtrace, weight, _) = regenerate( - prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) - else - (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap()) - end - - # update weight - state.weight += weight - - # add to the trace - add_call!(state.trace, key, subtrace) - - # get return value - retval = get_retval(subtrace) - - retval -end - -function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval -end - -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::EmptySelection) - noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !choice_or_call.is_choice - noise += choice_or_call.noise - end - end - for (key, subtrie) in get_internal_nodes(prev_trie) - noise += regenerate_delete_recurse(subtrie, EmptySelection()) - end - noise -end - -function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::DynamicSelection) - noise = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - if !(key in visited) && !choice_or_call.is_choice - noise += choice_or_call.noise - end - end - for (key, subtrie) in get_internal_nodes(prev_trie) - subvisited = visited[key] - noise += regenerate_delete_recurse(subtrie, subvisited) - end - noise -end - -function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - gen_fn = trace.gen_fn - state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params) - retval = exec(gen_fn, state, args) - set_retval!(state.trace, retval) - visited = state.visitor.visited - state.weight -= regenerate_delete_recurse(trace.trie, visited) - (state.trace, state.weight, UnknownChange()) -end diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 0addd8bfb..57f709dca 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -9,24 +9,6 @@ function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) GFSimulateState(trace, AddressVisitor(), params) end -function traceat(state::GFSimulateState, dist::Distribution{T}, - args, key) where {T} - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - retval = random(dist, args...) - - # compute logpdf - score = logpdf(dist, retval, args...) - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, args, key) where {T,U} local subtrace::U diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..69d94cce1 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -1,96 +1,40 @@ -struct ChoiceRecord{T} - retval::T - score::Float64 -end - struct CallRecord{T} subtrace::T score::Float64 noise::Float64 end - -struct ChoiceOrCallRecord{T} - subtrace_or_retval::T - score::Float64 - noise::Float64 # if choice then NaN - is_choice::Bool -end - -function ChoiceRecord(record::ChoiceOrCallRecord) - if !record.is_choice - error("Found call but expected choice") - end - ChoiceRecord(record.subtrace_or_retval, record.score) -end - -function CallRecord(record::ChoiceOrCallRecord) - if record.is_choice - error("Found choice but expected call") - end - CallRecord(record.subtrace_or_retval, record.score, record.noise) -end +project(c::CallRecord, ::EmptySelection) = c.noise +project(c::CallRecord, ::AllSelection) = c.score +project(c::CallRecord, s::Selection) = project(c.subtrace, s) mutable struct DynamicDSLTrace{T} <: Trace gen_fn::T - trie::Trie{Any,ChoiceOrCallRecord} - isempty::Bool + trie::Trie{Any,CallRecord} score::Float64 noise::Float64 args::Tuple retval::Any function DynamicDSLTrace{T}(gen_fn::T, args) where {T} - trie = Trie{Any,ChoiceOrCallRecord}() + trie = Trie{Any,CallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new(gen_fn, trie, 0, 0, args) end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) -function has_choice(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && trace.trie[addr].is_choice -end - -function has_call(trace::DynamicDSLTrace, addr) - haskey(trace.trie, addr) && !trace.trie[addr].is_choice -end - -function get_choice(trace::DynamicDSLTrace, addr) - choice = trace.trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - ChoiceRecord(choice) -end - -function get_call(trace::DynamicDSLTrace, addr) - call = trace.trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - CallRecord(call) -end - -function add_choice!(trace::DynamicDSLTrace, addr, retval, score) - if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. - The same address cannot be reused for multiple random choices.") - end - trace.trie[addr] = ChoiceOrCallRecord(retval, score, NaN, true) - trace.score += score - trace.isempty = false -end +has_call(trace::DynamicDSLTrace, addr) = haskey(trace.trie, addr) +get_call(trace::DynamicDSLTrace, addr) = trace.trie[addr] function add_call!(trace::DynamicDSLTrace, addr, subtrace) if haskey(trace.trie, addr) - error("Value or subtrace already present at address $addr. + error("Subtrace already present at address $addr. The same address cannot be reused for multiple random choices.") end score = get_score(subtrace) noise = project(subtrace, EmptySelection()) submap = get_choices(subtrace) - trace.isempty = trace.isempty && isempty(submap) - trace.trie[addr] = ChoiceOrCallRecord(subtrace, score, noise, false) + trace.trie[addr] = CallRecord(subtrace, score, noise) trace.score += score trace.noise += noise end @@ -106,69 +50,28 @@ get_gen_fn(trace::DynamicDSLTrace) = trace.gen_fn ## get_choices ## -function get_choices(trace::DynamicDSLTrace) - if !trace.isempty - DynamicDSLChoiceMap(trace.trie) # see below - else - EmptyChoiceMap() - end -end +get_choices(trace::DynamicDSLTrace) = DynamicDSLChoiceMap(trace.trie) -struct DynamicDSLChoiceMap <: ChoiceMap - trie::Trie{Any,ChoiceOrCallRecord} +struct DynamicDSLChoiceMap <: AddressTree{Value} + trie::Trie{Any,CallRecord} end get_address_schema(::Type{DynamicDSLChoiceMap}) = DynamicAddressSchema() -Base.isempty(::DynamicDSLChoiceMap) = false # TODO not necessarily true -has_value(choices::DynamicDSLChoiceMap, addr::Pair) = _has_value(choices, addr) -get_value(choices::DynamicDSLChoiceMap, addr::Pair) = _get_value(choices, addr) -get_submap(choices::DynamicDSLChoiceMap, addr::Pair) = _get_submap(choices, addr) - -function get_submap(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - if has_leaf_node(trie, addr) - # leaf node, must be a call - call = trie[addr] - if call.is_choice - throw(KeyError(addr)) - end - get_choices(call.subtrace_or_retval) - elseif has_internal_node(trie, addr) - # internal node - subtrie = get_internal_node(trie, addr) - DynamicDSLChoiceMap(subtrie) # see below +get_subtree(choices::DynamicDSLChoiceMap, addr::Pair) = _get_subtree(choices, addr) +function get_subtree(choices::DynamicDSLChoiceMap, addr) + if haskey(choices.trie.leaf_nodes, addr) + get_choices(choices.trie[addr].subtrace) + elseif haskey(choices.trie.internal_nodes, addr) + DynamicDSLChoiceMap(choices.trie.internal_nodes[addr]) else EmptyChoiceMap() end end -function has_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - has_leaf_node(trie, addr) && trie[addr].is_choice -end - -function get_value(choices::DynamicDSLChoiceMap, addr) - trie = choices.trie - choice = trie[addr] - if !choice.is_choice - throw(KeyError(addr)) - end - choice.subtrace_or_retval -end - -function get_values_shallow(choices::DynamicDSLChoiceMap) - ((key, choice.subtrace_or_retval) - for (key, choice) in get_leaf_nodes(choices.trie) - if choice.is_choice) -end - -function get_submaps_shallow(choices::DynamicDSLChoiceMap) - calls_iter = ((key, get_choices(call.subtrace_or_retval)) - for (key, call) in get_leaf_nodes(choices.trie) - if !call.is_choice) - internal_nodes_iter = ((key, DynamicDSLChoiceMap(trie)) - for (key, trie) in get_internal_nodes(choices.trie)) - Iterators.flatten((calls_iter, internal_nodes_iter)) +function get_subtrees_shallow(choices::DynamicDSLChoiceMap) + leafs = ((key, get_choices(record.subtrace)) for (key, record) in get_leaf_nodes(choices.trie)) + internals = ((key, DynamicDSLChoiceMap(trie)) for (key, trie) in get_internal_nodes(choices.trie)) + Iterators.flatten((leafs, internals)) end ## Base.getindex ## @@ -176,13 +79,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr::Pair) (first, rest) = addr if haskey(trie.leaf_nodes, first) - choice_or_call = trie.leaf_nodes[first] - if choice_or_call.is_choice - error("Unknown address $addr; random choice at $first") - else - subtrace = choice_or_call.subtrace_or_retval - return subtrace[rest] - end + return trie.leaf_nodes[first].subtrace[rest] elseif haskey(trie.internal_nodes, first) return _getindex(trace, trie.internal_nodes[first], rest) else @@ -192,14 +89,7 @@ end function _getindex(trace::DynamicDSLTrace, trie::Trie, addr) if haskey(trie.leaf_nodes, addr) - choice_or_call = trie.leaf_nodes[addr] - if choice_or_call.is_choice - # the value of the random choice - return choice_or_call.subtrace_or_retval - else - # the return value of the generative function call - return get_retval(choice_or_call.subtrace_or_retval) - end + return get_retval(trie.leaf_nodes[addr].subtrace) else error("No random choice or generative function call at address $addr") end diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index 24e023f24..e4642beb3 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -1,72 +1,22 @@ mutable struct GFUpdateState prev_trace::DynamicDSLTrace trace::DynamicDSLTrace - constraints::Any + spec::UpdateSpec + externally_constrained_addrs::Selection weight::Float64 visitor::AddressVisitor params::Dict{Symbol,Any} discard::DynamicChoiceMap end -function GFUpdateState(gen_fn, args, prev_trace, constraints, params) +function GFUpdateState(gen_fn, args, prev_trace, constraints, externally_constrained_addrs, params) visitor = AddressVisitor() discard = choicemap() trace = DynamicDSLTrace(gen_fn, args) - GFUpdateState(prev_trace, trace, constraints, + GFUpdateState(prev_trace, trace, constraints, externally_constrained_addrs, 0., visitor, params, discard) end -function traceat(state::GFUpdateState, dist::Distribution{T}, - args::Tuple, key) where {T} - - local prev_retval::T - local retval::T - - # check that key was not already visited, and mark it as visited - visit!(state.visitor, key) - - # check for previous choice at this key - has_previous = has_choice(state.prev_trace, key) - if has_previous - prev_choice = get_choice(state.prev_trace, key) - prev_retval = prev_choice.retval - prev_score = prev_choice.score - end - - # check for constraints at this key - constrained = has_value(state.constraints, key) - !constrained && check_no_submap(state.constraints, key) - - # record the previous value as discarded if it is replaced - if constrained && has_previous - set_value!(state.discard, key, prev_retval) - end - - # get return value - if constrained - retval = get_value(state.constraints, key) - elseif has_previous - retval = prev_retval - else - retval = random(dist, args...) - end - - # compute logpdf - score = logpdf(dist, retval, args...) - - # update the weight - if has_previous - state.weight += score - prev_score - elseif constrained - state.weight += score - end - - # add to the trace - add_choice!(state.trace, key, retval, score) - - retval -end - function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, args::Tuple, key) where {T,U} @@ -77,9 +27,9 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, # check key was not already visited, and mark it as visited visit!(state.visitor, key) - # check for constraints at this key - check_no_value(state.constraints, key) - constraints = get_submap(state.constraints, key) + # updatespec at this key + spec = get_subtree(state.spec, key) + sub_externally_constrained_addrs = get_subtree(state.externally_constrained_addrs, key) # get subtrace has_previous = has_call(state.prev_trace, key) @@ -88,9 +38,9 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, prev_subtrace = prev_call.subtrace get_gen_fn(prev_subtrace) === gen_fn || gen_fn_changed_error(key) (subtrace, weight, _, discard) = update(prev_subtrace, - args, map((_) -> UnknownChange(), args), constraints) + args, map((_) -> UnknownChange(), args), spec, sub_externally_constrained_addrs) else - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate(gen_fn, args, spec) end # update the weight @@ -119,60 +69,44 @@ function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, retval end -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::EmptySelection) - score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) - score += choice_or_call.score - end - for (key, subtrie) in get_internal_nodes(prev_trie) - score += update_delete_recurse(subtrie, EmptySelection()) - end - score -end - -function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, - visited::DynamicSelection) - score = 0. - for (key, choice_or_call) in get_leaf_nodes(prev_trie) +function update_delete_recurse(prev_trie::Trie{Any,CallRecord}, + visited::Selection, externally_constrained_addrs::Selection) + weight = 0. + for (key, call) in get_leaf_nodes(prev_trie) if !(key in visited) - score += choice_or_call.score + # weight += Q[deleted_subtrace | reverse_constraints] / P[deleted_subtrace] . + # where reverse_constraints = get_selected(choices(deleted_subtrace), externally_constrained_addrs) + # (ie. whatever choices in the discard are constrained externally) + sub_externally_constrained_addrs = get_subselection(externally_constrained_addrs, key) + reverse_constraint = get_selected(get_choices(call.subtrace), sub_externally_constrained_addrs) + weight += project(call, addrs(reverse_constraint)) end end for (key, subtrie) in get_internal_nodes(prev_trie) - subvisited = visited[key] - score += update_delete_recurse(subtrie, subvisited) + subvisited = get_subtree(visited, key) + sub_externally_constrained_addrs = get_subtree(externally_constrained_addrs, key) + weight += update_delete_recurse(subtrie, subvisited, sub_externally_constrained_addrs) end - score + weight end function add_unvisited_to_discard!(discard::DynamicChoiceMap, visited::DynamicSelection, prev_choices::ChoiceMap) - for (key, value) in get_values_shallow(prev_choices) - if !(key in visited) - @assert !has_value(discard, key) - @assert isempty(get_submap(discard, key)) - set_value!(discard, key, value) - end - end for (key, submap) in get_submaps_shallow(prev_choices) - @assert !has_value(discard, key) - if key in visited - # the recursive call to update already handled the discard - # for this entire submap - continue - else - subvisited = visited[key] + # if key IS in visited, + # the recursive call to update already handled the discard + # for this entire submap; else we need to handle it + if !(key in visited) + subvisited = get_subselection(visited, key) if isempty(subvisited) # none of this submap was visited, so we discard the whole thing @assert isempty(get_submap(discard, key)) set_submap!(discard, key, submap) else subdiscard = get_submap(discard, key) - add_unvisited_to_discard!( - isempty(subdiscard) ? choicemap() : subdiscard, - subvisited, submap) + subdiscard = isempty(subdiscard) ? choicemap() : subdiscard + add_unvisited_to_discard!(subdiscard, subvisited, submap) set_submap!(discard, key, subdiscard) end end @@ -180,16 +114,16 @@ function add_unvisited_to_discard!(discard::DynamicChoiceMap, end function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, - constraints::ChoiceMap) + spec::UpdateSpec, externally_constrained_addrs::Selection) gen_fn = trace.gen_fn - state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params) + state = GFUpdateState(gen_fn, arg_values, trace, spec, externally_constrained_addrs, gen_fn.params) retval = exec(gen_fn, state, arg_values) set_retval!(state.trace, retval) visited = get_visited(state.visitor) - state.weight -= update_delete_recurse(trace.trie, visited) + state.weight -= update_delete_recurse(trace.trie, visited, externally_constrained_addrs) add_unvisited_to_discard!(state.discard, visited, get_choices(trace)) - if !all_visited(visited, constraints) - error("Did not visit all constraints") + if !all_constraints_visited(visited, spec) + error("Did not visit all addresses in the update specification") end (state.trace, state.weight, UnknownChange(), state.discard) end diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index 40b494a80..c44cc0440 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -186,9 +186,12 @@ function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap) error("Not implemented") end -function generate(gen_fn::GenerativeFunction, args::Tuple) - generate(gen_fn, args, EmptyChoiceMap()) -end +generate(gen_fn::GenerativeFunction, args::Tuple) = generate(gen_fn, args, EmptyChoiceMap()) + +# if we try to generate with an address tree that has some choices, and some selections, +# we should simply generate with the constraint choices, ignoring the selection part of the address tree +generate(gen_fn::GenerativeFunction, args::Tuple, t::AddressTree{Union{Value, SelectionLeaf}}) = generate(gen_fn, args, UnderlyingChoices(t)) +generate(gen_fn::GenerativeFunction, args::Tuple, ::AddressTree{SelectionLeaf}) = generate(gen_fn, args, EmptyAddressTree()) """ weight = project(trace::U, selection::Selection) @@ -203,9 +206,10 @@ let \$u\$ denote the restriction of \$t\$ to \$A\$. Return the weight \\log \\frac{p(r, t; x)}{q(t; u, x) q(r; x, t)} ``` """ -function project(trace, selection::Selection) +function project(trace::Trace, selection::Selection) error("Not implemented") end +project(trace::Trace, ::AllSelection) = get_score(trace) """ (choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple) @@ -238,12 +242,15 @@ return the weight (`weight`): ``` It is an error if \$p(t; x) = 0\$. """ -function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) - (trace, weight) = generate(gen_fn, args, choices) - (weight, get_retval(trace)) -end +function assess end """ + (new_trace, weight, retdiff, reverse_update_spec) = update(trace, args::Tuple, argdiffs::Tuple, + spec::UpdateSpec, externally_constrained_addresses::Selection) + +TODO: Document me. + + (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) @@ -272,9 +279,12 @@ that if the original `trace` was generated using non-default argument values, then for each optional argument that is omitted, the old value will be over-written by the default argument value in the updated trace. """ -function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap) +function update(trace, args::Tuple, argdiffs::Tuple, ::UpdateSpec, ::Selection) error("Not implemented") end +function update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) + update(trace, args, argdiffs, constraints, AllSelection()) +end """ (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, @@ -304,7 +314,7 @@ then for each optional argument that is omitted, the old value will be over-written by the default argument value in the regenerated trace. """ function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection) - error("Not implemented") + update(trace, args, argdiffs, selection, EmptySelection())[1:3] end """ diff --git a/src/inference/kernel_dsl.jl b/src/inference/kernel_dsl.jl index a231f03a7..e4a0ba635 100644 --- a/src/inference/kernel_dsl.jl +++ b/src/inference/kernel_dsl.jl @@ -1,12 +1,13 @@ import MacroTools function check_observations(choices::ChoiceMap, observations::ChoiceMap) - for (key, value) in get_values_shallow(observations) - !has_value(choices, key) && error("Check failed: observed choice at $key not found") - choices[key] != value && error("Check failed: value of observed choice at $key changed") - end for (key, submap) in get_submaps_shallow(observations) - check_observations(get_submap(choices, key), submap) + if has_value(submap) + !has_value(choices, key) && error("Check failed: observed choice at $key not found") + choices[key] != get_value(submap) && error("Check failed: value of observed choice at $key should be $(get_value(submap)) but is $(choices[key])") + else + check_observations(get_submap(choices, key), submap) + end end end diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 234116976..e95610303 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -1,4 +1,4 @@ -struct CallAtChoiceMap{K,T} <: ChoiceMap +struct CallAtChoiceMap{K,T} <: AddressTree{Value} key::K submap::T end @@ -9,15 +9,12 @@ function get_address_schema(::Type{T}) where {T<:CallAtChoiceMap} SingleDynamicKeyAddressSchema() end -function get_submap(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} +function get_subtree(choices::CallAtChoiceMap{K,T}, addr::K) where {K,T} choices.key == addr ? choices.submap : EmptyChoiceMap() end -get_submap(choices::CallAtChoiceMap, addr::Pair) = _get_submap(choices, addr) -get_value(choices::CallAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::CallAtChoiceMap, addr::Pair) = _has_value(choices, addr) -get_submaps_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) -get_values_shallow(::CallAtChoiceMap) = () +get_subtree(choices::CallAtChoiceMap, addr::Pair) = _get_subtree(choices, addr) +get_subtrees_shallow(choices::CallAtChoiceMap) = ((choices.key, choices.submap),) # TODO optimize CallAtTrace using type parameters @@ -69,7 +66,7 @@ unpack_call_at_args(args) = (args[end], args[1:end-1]) function assess(gen_fn::CallAtCombinator, args::Tuple, choices::ChoiceMap) (key, kernel_args) = unpack_call_at_args(args) - if length(get_submaps_shallow(choices)) > 1 || length(get_values_shallow(choices)) > 0 + if length(get_submaps_shallow(choices)) > 1 error("Not all constraints were consumed") end submap = get_submap(choices, key) @@ -99,62 +96,82 @@ function generate(gen_fn::CallAtCombinator{T,U,K}, args::Tuple, end function project(trace::CallAtTrace, selection::Selection) - subselection = selection[trace.key] + subselection = get_subselection(selection, trace.key) project(trace.subtrace, subselection) end function update(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) + spec::UpdateSpec, externally_constrained_addrs::Selection) (key, kernel_args) = unpack_call_at_args(args) key_changed = (key != trace.key) - submap = get_submap(choices, key) - if key_changed - (subtrace, weight) = generate(trace.gen_fn.kernel, kernel_args, submap) - weight -= get_score(trace.subtrace) + subspec = get_subtree(spec, key) + if key_changed # TODO: remove the capacity to change key! + (subtrace, weight) = generate(trace.gen_fn.kernel, kernel_args, subspec) + + sub_ext_const_addrs = get_subselection(externally_constrained_addrs, key) + if sub_ext_const_addrs === AllSelection() + weight -= get_score(trace.subtrace) + else + reverse_external_constraint_addrs = addrs(get_selected(get_choices(trace.subtrace), sub_ext_const_addrs)) + weight -= project(trace.subtrace, reverse_external_constraint_addrs) + end discard = get_choices(trace) retdiff = UnknownChange() else + sub_ext_const_addrs = get_subtree(externally_constrained_addrs, key) (subtrace, weight, retdiff, subdiscard) = update( - trace.subtrace, kernel_args, argdiffs[1:end-1], submap) + trace.subtrace, kernel_args, argdiffs[1:end-1], subspec, sub_ext_const_addrs) discard = CallAtChoiceMap(key, subdiscard) end new_trace = CallAtTrace(trace.gen_fn, subtrace, key) (new_trace, weight, retdiff, discard) end -function regenerate(trace::CallAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - (key, kernel_args) = unpack_call_at_args(args) - key_changed = (key != trace.key) - subselection = selection[key] - if key_changed - if !isempty(subselection) - error("Cannot select addresses under new key $key in regenerate") +function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(get_gen_fn(trace.subtrace)) + error("return value gradient not accepted but one was provided") end - (subtrace, weight) = generate(trace.gen_fn.kernel, kernel_args, EmptyChoiceMap()) - weight -= project(trace.subtrace, EmptySelection()) - retdiff = UnknownChange() + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) + if trace.key in selection + value_choices = CallAtChoiceMap(trace.key, get_choices(trace.subtrace)) + choice_grad = kernel_arg_grads[1] + if choice_grad === nothing + error("gradient not available for selected choice") + end + if retval_grad !== nothing + choice_grad += retval_grad + end + gradient_choices = CallAtChoiceMap(trace.key, Value(choice_grad)) + else + value_choices = EmptyChoiceMap() + gradient_choices = EmptyChoiceMap() + end + input_grads = (kernel_arg_grads[2:end]..., nothing) + return (input_grads, value_choices, gradient_choices) else - (subtrace, weight, retdiff) = regenerate( - trace.subtrace, kernel_args, argdiffs[1:end-1], subselection) + subselection = get_subselection(selection, trace.key) + (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( + trace.subtrace, subselection, retval_grad) + input_grads = (kernel_input_grads..., nothing) + value_choices = CallAtChoiceMap(trace.key, value_submap) + gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) + return (input_grads, value_choices, gradient_choices) end - new_trace = CallAtTrace(trace.gen_fn, subtrace, key) - (new_trace, weight, retdiff) -end - -function choice_gradients(trace::CallAtTrace, selection::Selection, retval_grad) - subselection = selection[trace.key] - (kernel_input_grads, value_submap, gradient_submap) = choice_gradients( - trace.subtrace, subselection, retval_grad) - input_grads = (kernel_input_grads..., nothing) - value_choices = CallAtChoiceMap(trace.key, value_submap) - gradient_choices = CallAtChoiceMap(trace.key, gradient_submap) - (input_grads, value_choices, gradient_choices) end function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) - kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) - (kernel_input_grads..., nothing) + if trace.subtrace isa DistributionTrace + if retval_grad !== nothing && !has_output_grad(trace.gen_fn.dist) + error("return value gradient not accepted but one was provided") + end + kernel_arg_grads = logpdf_grad(get_gen_fn(trace.subtrace), get_retval(trace.subtrace), get_args(trace.subtrace)...) + return (kernel_arg_grads[2:end]..., nothing) + else + kernel_input_grads = accumulate_param_gradients!(trace.subtrace, retval_grad) + return (kernel_input_grads..., nothing) + end + end export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl deleted file mode 100644 index 69bb4851a..000000000 --- a/src/modeling_library/choice_at/choice_at.jl +++ /dev/null @@ -1,175 +0,0 @@ -# TODO optimize ChoiceAtTrace using type parameters - -struct ChoiceAtTrace <: Trace - gen_fn::GenerativeFunction # the ChoiceAtCombinator (not the kernel) - value::Any - key::Any - kernel_args::Tuple - score::Float64 -end - -get_args(trace::ChoiceAtTrace) = (trace.kernel_args..., trace.key) -get_retval(trace::ChoiceAtTrace) = trace.value -get_score(trace::ChoiceAtTrace) = trace.score -get_gen_fn(trace::ChoiceAtTrace) = trace.gen_fn - -struct ChoiceAtChoiceMap{T,K} <: ChoiceMap - key::K - value::T -end - -get_choices(trace::ChoiceAtTrace) = ChoiceAtChoiceMap(trace.key, trace.value) -Base.isempty(::ChoiceAtChoiceMap) = false -function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} - SingleDynamicKeyAddressSchema() -end -get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) -has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) -function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} - choices.key == addr ? choices.value : throw(KeyError(choices, addr)) -end -get_submaps_shallow(choices::ChoiceAtChoiceMap) = () -get_values_shallow(choices::ChoiceAtChoiceMap) = ((choices.key, choices.value),) - -struct ChoiceAtCombinator{T,K} <: GenerativeFunction{T, ChoiceAtTrace} - dist::Distribution{T} -end - -accepts_output_grad(gen_fn::ChoiceAtCombinator) = has_output_grad(gen_fn.dist) - -# TODO -# accepts_output_grad is true if the return value is dependent on the 'gradient source elements' -# if the random choice itself is not a 'gradient source element' then it is independent (false) -# if the random choice is a 'gradient source element', then the return value is dependent (true) -# we will consider the random choice as a gradient source element if the -# distribution has has_output_grad = true) - -function choice_at(dist::Distribution{T}, ::Type{K}) where {T,K} - ChoiceAtCombinator{T,K}(dist) -end - -unpack_choice_at_args(args) = (args[end], args[1:end-1]) - -function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = get_value(choices, key) - weight = logpdf(gen_fn.dist, value, kernel_args...) - (weight, value) -end - -function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - choices = ChoiceAtChoiceMap(key, value) - (choices, score, value) -end - -function simulate(gen_fn::ChoiceAtCombinator, args::Tuple) - (key, kernel_args) = unpack_choice_at_args(args) - value = random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - ChoiceAtTrace(gen_fn, value, key, kernel_args, score) -end - -function generate(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap) where {T,K} - local key::K - local value::T - (key, kernel_args) = unpack_choice_at_args(args) - constrained = has_value(choices, key) - value = constrained ? get_value(choices, key) : random(gen_fn.dist, kernel_args...) - score = logpdf(gen_fn.dist, value, kernel_args...) - trace = ChoiceAtTrace(gen_fn, value, key, kernel_args, score) - weight = constrained ? score : 0. - (trace, weight) -end - -function project(trace::ChoiceAtTrace, selection::Selection) - (trace.key in selection) ? trace.score : 0. -end - -function update(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - constrained = has_value(choices, key) - if key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(trace.key, trace.value) - elseif !key_changed && constrained - new_value = get_value(choices, key) - discard = ChoiceAtChoiceMap(key, trace.value) - elseif !key_changed && !constrained - new_value = trace.value - discard = EmptyChoiceMap() - else - error("New address $key not constrained in update") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - weight = new_score - trace.score - (new_trace, weight, UnknownChange(), discard) -end - -function regenerate(trace::ChoiceAtTrace, args::Tuple, argdiffs::Tuple, - selection::Selection) - (key, kernel_args) = unpack_choice_at_args(args) - key_changed = (key != trace.key) - selected = key in selection - if !key_changed && selected - new_value = random(trace.gen_fn.dist, kernel_args...) - elseif !key_changed && !selected - new_value = trace.value - elseif key_changed && !selected - new_value = random(trace.gen_fn.dist, kernel_args...) - else - error("Cannot select new address $key in regenerate") - end - new_score = logpdf(trace.gen_fn.dist, new_value, kernel_args...) - if !key_changed && selected - weight = 0. - elseif !key_changed && !selected - weight = new_score - trace.score - elseif key_changed && !selected - weight = 0. - end - new_trace = ChoiceAtTrace(trace.gen_fn, new_value, key, kernel_args, new_score) - (new_trace, weight, UnknownChange()) -end - -function choice_gradients(trace::ChoiceAtTrace, selection::Selection, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - if trace.key in selection - value_choices = ChoiceAtChoiceMap(trace.key, trace.value) - choice_grad = kernel_arg_grads[1] - if choice_grad == nothing - error("gradient not available for selected choice") - end - if retval_grad != nothing - choice_grad += retval_grad - end - gradient_choices = ChoiceAtChoiceMap(trace.key, choice_grad) - else - value_choices = EmptyChoiceMap() - gradient_choices = EmptyChoiceMap() - end - input_grads = (kernel_arg_grads[2:end]..., nothing) - (input_grads, value_choices, gradient_choices) -end - -function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) - if retval_grad != nothing && !has_output_grad(trace.gen_fn.dist) - error("return value gradient not accepted but one was provided") - end - kernel_arg_grads = logpdf_grad(trace.gen_fn.dist, trace.value, trace.kernel_args...) - (kernel_arg_grads[2:end]..., nothing) -end - -export choice_at diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 48045ddb1..4c03ce199 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -107,7 +107,7 @@ function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) trace, 0. end -function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) where {T,S} +function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, choices::ChoiceMap, ::Selection) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") end @@ -116,10 +116,6 @@ function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, c (new_trace, 0., retdiff, choicemap()) end -function regenerate(trace::CustomDetermGFTrace, args::Tuple, argdiffs::Tuple, selection::Selection) - update(trace, args, argdiffs, EmptyChoiceMap()) -end - function choice_gradients(trace::CustomDetermGFTrace, selection::Selection, retgrad) arg_grads = gradient_with_state(trace.gen_fn, trace.state, trace.args, retgrad) (arg_grads, EmptyChoiceMap(), EmptyChoiceMap()) diff --git a/src/modeling_library/distributions/uniform_discrete.jl b/src/modeling_library/distributions/uniform_discrete.jl index b125b3b2a..3a2ace125 100644 --- a/src/modeling_library/distributions/uniform_discrete.jl +++ b/src/modeling_library/distributions/uniform_discrete.jl @@ -11,6 +11,7 @@ function logpdf(::UniformDiscrete, x::Int, low::Integer, high::Integer) d = Distributions.DiscreteUniform(low, high) Distributions.logpdf(d, x) end +logpdf(u::UniformDiscrete, x::Real, low::Integer, high::Integer) = logpdf(u, convert(Int, x), low, high) function logpdf_grad(::UniformDiscrete, x::Int, lower::Integer, high::Integer) (nothing, nothing, nothing) diff --git a/src/modeling_library/map/backprop.jl b/src/modeling_library/map/backprop.jl index e14dbd053..4af8a6fd2 100644 --- a/src/modeling_library/map/backprop.jl +++ b/src/modeling_library/map/backprop.jl @@ -20,7 +20,7 @@ function choice_gradients(trace::VectorTrace{MapType,T,U}, selection::Selection, for key=1:len subtrace = trace.subtraces[key] - sub_selection = selection[key] + sub_selection = get_subselection(selection, key) kernel_retval_grad = (retval_grad == nothing) ? nothing : retval_grad[key] (kernel_arg_grad::Tuple, kernel_value_choices, kernel_gradient_choices) = choice_gradients( subtrace, sub_selection, kernel_retval_grad) diff --git a/src/modeling_library/map/generic_update.jl b/src/modeling_library/map/generic_update.jl deleted file mode 100644 index c8865d5ac..000000000 --- a/src/modeling_library/map/generic_update.jl +++ /dev/null @@ -1,68 +0,0 @@ - -function get_kernel_argdiffs(argdiffs::Tuple) - n_args = length(argdiffs) - kernel_argdiffs = Dict{Int,Vector}() - for (i, diff) in enumerate(argdiffs) - if isa(diff, VectorDiff) - for (key, element_diff) in diff.updated - if !haskey(kernel_argdiffs, key) - kernel_argdiff = Vector{Any}(undef, n_args) - fill!(kernel_argdiff, NoChange()) - kernel_argdiffs[key] = kernel_argdiff - end - kernel_argdiffs[key][i] = diff.updated[key] - end - end - end - kernel_argdiffs -end - -function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple, - choices_or_selection, prev_length::Int, new_length::Int, - retained_and_targeted::Set{Int}, state) where {T,U} - kernel_no_change_argdiffs = map((_) -> NoChange(), args) - kernel_unknown_change_argdiffs = map((_) -> UnknownChange(), args) - - if all(diff == NoChange() for diff in argdiffs) - - # only visit retained applications that were targeted - for key in retained_and_targeted - @assert key <= min(new_length, prev_length) - process_retained!(gen_fn, args, choices_or_selection, key, kernel_no_change_argdiffs, state) - end - - elseif any(diff == UnknownChange() for diff in argdiffs) - - # visit every retained application - for key in 1:min(prev_length, new_length) - @assert key <= min(new_length, prev_length) - process_retained!(gen_fn, args, choices_or_selection, key, kernel_unknown_change_argdiffs, state) - end - - else - - key_to_kernel_argdiffs = get_kernel_argdiffs(argdiffs) - - # visit every retained applications that either has an argdiff or was targeted - for key in union(keys(key_to_kernel_argdiffs), retained_and_targeted) - @assert key <= min(new_length, prev_length) - if haskey(key_to_kernel_argdiffs, key) - kernel_argdiffs = tuple(key_to_kernel_argdiffs[key]...) - else - kernel_argdiffs = kernel_no_change_argdiffs - end - process_retained!(gen_fn, args, choices_or_selection, key, kernel_argdiffs, state) - end - - end -end - -""" -Process all new applications. -""" -function process_all_new!(gen_fn::Map{T,U}, args::Tuple, choices_or_selection, - prev_len::Int, new_len::Int, state) where {T,U} - for key=prev_len+1:new_len - process_new!(gen_fn, args, choices_or_selection, key, state) - end -end diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index 1bb695ff7..cea43cf9b 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -47,7 +47,5 @@ include("assess.jl") include("propose.jl") include("simulate.jl") include("generate.jl") -include("generic_update.jl") include("update.jl") -include("regenerate.jl") include("backprop.jl") diff --git a/src/modeling_library/map/regenerate.jl b/src/modeling_library/map/regenerate.jl deleted file mode 100644 index 8634bbc30..000000000 --- a/src/modeling_library/map/regenerate.jl +++ /dev/null @@ -1,103 +0,0 @@ -mutable struct MapRegenerateState{T,U} - weight::Float64 - score::Float64 - noise::Float64 - subtraces::PersistentVector{U} - retval::PersistentVector{T} - num_nonempty::Int - updated_retdiffs::Dict{Int,Diff} -end - -function process_retained!(gen_fn::Map{T,U}, args::Tuple, - selection::Selection, key::Int, kernel_argdiffs::Tuple, - state::MapRegenerateState{T,U}) where {T,U} - local subtrace::U - local prev_subtrace::U - local retval::T - - subselection = selection[key] - kernel_args = get_args_for_key(args, key) - - # get new subtrace with recursive call to regenerate() - prev_subtrace = state.subtraces[key] - (subtrace, weight, retdiff) = regenerate( - prev_subtrace, kernel_args, kernel_argdiffs, subselection) - - # retrieve retdiff - if retdiff != NoChange() - state.updated_retdiffs[key] = retdiff - end - - # update state - state.weight += weight - state.score += (get_score(subtrace) - get_score(prev_subtrace)) - state.noise += (project(subtrace, EmptySelection()) - project(subtrace, EmptySelection())) - state.subtraces = assoc(state.subtraces, key, subtrace) - retval = get_retval(subtrace) - state.retval = assoc(state.retval, key, retval) - subtrace_empty = isempty(get_choices(subtrace)) - prev_subtrace_empty = isempty(get_choices(prev_subtrace)) - if !subtrace_empty && prev_subtrace_empty - state.num_nonempty += 1 - elseif subtrace_empty && !prev_subtrace_empty - state.num_nonempty -= 1 - end -end - -function process_new!(gen_fn::Map{T,U}, args::Tuple, selection::Selection, key::Int, - state::MapRegenerateState{T,U}) where {T,U} - local subtrace::U - local retval::T - if !isempty(selection[key]) - error("Tried to select new address in regenerate at key $key") - end - kernel_args = get_args_for_key(args, key) - - # get subtrace and weight - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, EmptyChoiceMap()) - - # update state - state.weight += weight - state.score += get_score(subtrace) - retval = get_retval(subtrace) - @assert key >= length(state.subtraces) - state.subtraces = push(state.subtraces, subtrace) - state.retval = push(state.retval, retval) - @assert length(state.subtraces) == key - if !isempty(get_choices(subtrace)) - state.num_nonempty += 1 - end -end - - -function regenerate(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple, - selection::Selection) where {T,U} - gen_fn = trace.gen_fn - (new_length, prev_length) = get_prev_and_new_lengths(args, trace) - retained_and_selected = get_retained_and_selected(selection, prev_length, new_length) - - # handle removed applications - (num_nonempty, score_decrement, noise_decrement) = vector_regenerate_delete( - new_length, prev_length, trace) - (subtraces, retval) = vector_remove_deleted_applications( - trace.subtraces, trace.retval, prev_length, new_length) - score = trace.score - score_decrement - noise = trace.noise - noise_decrement - - # handle retained and new applications - state = MapRegenerateState{T,U}(-noise_decrement, score, noise, - subtraces, retval, - num_nonempty, Dict{Int,Diff}()) - process_all_retained!(gen_fn, args, argdiffs, selection, - prev_length, new_length, retained_and_selected, state) - process_all_new!(gen_fn, args, selection, prev_length, new_length, state) - - # retdiff - retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length) - - # new trace - new_trace = VectorTrace{MapType,T,U}(gen_fn, state.subtraces, state.retval, args, - state.score, state.noise, new_length, state.num_nonempty) - - return (new_trace, state.weight, retdiff) -end diff --git a/src/modeling_library/map/update.jl b/src/modeling_library/map/update.jl index fd8ffc777..afb251cf5 100644 --- a/src/modeling_library/map/update.jl +++ b/src/modeling_library/map/update.jl @@ -9,20 +9,89 @@ mutable struct MapUpdateState{T,U} updated_retdiffs::Dict{Int,Diff} end +function get_kernel_argdiffs(argdiffs::Tuple) + n_args = length(argdiffs) + kernel_argdiffs = Dict{Int,Vector}() + for (i, diff) in enumerate(argdiffs) + if isa(diff, VectorDiff) + for (key, element_diff) in diff.updated + if !haskey(kernel_argdiffs, key) + kernel_argdiff = Vector{Any}(undef, n_args) + fill!(kernel_argdiff, NoChange()) + kernel_argdiffs[key] = kernel_argdiff + end + kernel_argdiffs[key][i] = diff.updated[key] + end + end + end + kernel_argdiffs +end + +function process_all_retained!(gen_fn::Map{T,U}, args::Tuple, argdiffs::Tuple, + spec::UpdateSpec, prev_length::Int, new_length::Int, + retained_and_targeted::Set{Int}, externally_constrained_addrs, state) where {T,U} + kernel_no_change_argdiffs = map((_) -> NoChange(), args) + kernel_unknown_change_argdiffs = map((_) -> UnknownChange(), args) + + if all(diff == NoChange() for diff in argdiffs) + + # only visit retained applications that were targeted + for key in retained_and_targeted + @assert key <= min(new_length, prev_length) + process_retained!(gen_fn, args, spec, key, kernel_no_change_argdiffs, externally_constrained_addrs, state) + end + + elseif any(diff == UnknownChange() for diff in argdiffs) + + # visit every retained application + for key in 1:min(prev_length, new_length) + @assert key <= min(new_length, prev_length) + process_retained!(gen_fn, args, spec, key, kernel_unknown_change_argdiffs, externally_constrained_addrs, state) + end + + else + + key_to_kernel_argdiffs = get_kernel_argdiffs(argdiffs) + + # visit every retained applications that either has an argdiff or was targeted + for key in union(keys(key_to_kernel_argdiffs), retained_and_targeted) + @assert key <= min(new_length, prev_length) + if haskey(key_to_kernel_argdiffs, key) + kernel_argdiffs = tuple(key_to_kernel_argdiffs[key]...) + else + kernel_argdiffs = kernel_no_change_argdiffs + end + process_retained!(gen_fn, args, spec, key, kernel_argdiffs, externally_constrained_addrs, state) + end + + end +end + +""" +Process all new applications. +""" +function process_all_new!(gen_fn::Map{T,U}, args::Tuple, spec, + prev_len::Int, new_len::Int, state) where {T,U} + for key=prev_len+1:new_len + process_new!(gen_fn, args, spec, key, state) + end +end + function process_retained!(gen_fn::Map{T,U}, args::Tuple, - choices::ChoiceMap, key::Int, kernel_argdiffs::Tuple, - state::MapUpdateState{T,U}) where {T,U} + spec::UpdateSpec, key::Int, kernel_argdiffs::Tuple, + ext_const_addrs::Selection, state::MapUpdateState{T,U}) where {T,U} local subtrace::U local prev_subtrace::U local retval::T - submap = get_submap(choices, key) + subspec = get_subtree(spec, key) + sub_ext_const_addrs = get_subtree(ext_const_addrs, key) kernel_args = get_args_for_key(args, key) # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace, kernel_args, kernel_argdiffs, submap) + prev_subtrace, kernel_args, kernel_argdiffs, subspec, sub_ext_const_addrs) # retrieve retdiff if retdiff != NoChange() @@ -51,7 +120,7 @@ function process_new!(gen_fn::Map{T,U}, args::Tuple, choices, key::Int, local subtrace::U local retval::T - submap = get_submap(choices, key) + submap = get_subtree(choices, key) kernel_args = get_args_for_key(args, key) # get subtrace and weight @@ -72,26 +141,27 @@ end function update(trace::VectorTrace{MapType,T,U}, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) where {T,U} + spec::UpdateSpec, externally_constrained_addrs::Selection) where {T,U} gen_fn = trace.gen_fn (new_length, prev_length) = get_prev_and_new_lengths(args, trace) - retained_and_constrained = get_retained_and_constrained(choices, prev_length, new_length) + retained_and_constrained = get_retained_and_specd(spec, prev_length, new_length) + # TODO: for performance, don't use a Set for `retained_and_constrained` # handle removed applications - (discard, num_nonempty, score_decrement, noise_decrement) = vector_update_delete( - new_length, prev_length, trace) + (discard, num_nonempty, score_decrement, noise_decrement, weight_decrement) = vector_update_delete( + new_length, prev_length, trace, externally_constrained_addrs) (subtraces, retval) = vector_remove_deleted_applications( trace.subtraces, trace.retval, prev_length, new_length) score = trace.score - score_decrement noise = trace.noise - noise_decrement # handle retained and new applications - state = MapUpdateState{T,U}(-score_decrement, score, noise, + state = MapUpdateState{T,U}(-weight_decrement, score, noise, subtraces, retval, discard, num_nonempty, Dict{Int,Diff}()) - process_all_retained!(gen_fn, args, argdiffs, choices, prev_length, new_length, - retained_and_constrained, state) - process_all_new!(gen_fn, args, choices, prev_length, new_length, state) + process_all_retained!(gen_fn, args, argdiffs, spec, prev_length, new_length, + retained_and_constrained, externally_constrained_addrs, state) + process_all_new!(gen_fn, args, spec, prev_length, new_length, state) # retdiff retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length) diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index d0797426c..09fe899ca 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -5,54 +5,6 @@ import Distributions using SpecialFunctions: loggamma, logbeta, digamma -abstract type Distribution{T} end - -""" - val::T = random(dist::Distribution{T}, args...) - -Sample a random choice from the given distribution with the given arguments. -""" -function random end - -""" - lpdf = logpdf(dist::Distribution{T}, value::T, args...) - -Evaluate the log probability (density) of the value. -""" -function logpdf end - -""" - has::Bool = has_output_grad(dist::Distribution) - -Return true of the gradient if the distribution computes the gradient of the logpdf with respect to the value of the random choice. -""" -function has_output_grad end - -""" - grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...) - -Compute the gradient of the logpdf with respect to the value, and each of the arguments. - -If `has_output_grad` returns false, then the first element of the returned tuple is `nothing`. -Otherwise, the first element of the tuple is the gradient with respect to the value. -If the return value of `has_argument_grads` has a false value for at position `i`, then the `i+1`th element of the returned tuple has value `nothing`. -Otherwise, this element contains the gradient with respect to the `i`th argument. -""" -function logpdf_grad end - -function is_discrete end - -# NOTE: has_argument_grad is documented and exported in gen_fn_interface.jl - -get_return_type(::Distribution{T}) where {T} = T - -export Distribution -export random -export logpdf -export logpdf_grad -export has_output_grad -export is_discrete - # built-in distributions include("distributions/distributions.jl") @@ -67,11 +19,11 @@ include("dist_dsl/dist_dsl.jl") include("vector.jl") # built-in generative function combinators -include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") include("recurse/recurse.jl") +include("set_map/set_map.jl") ############################################################# # abstractions for constructing custom generative functions # diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 715800737..c93903b46 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -51,7 +51,7 @@ project(trace::RecurseTrace, ::EmptySelection) = 0. # recurse assignment wrapper # ############################## -struct RecurseTraceChoiceMap <: ChoiceMap +struct RecurseTraceChoiceMap <: AddressTree{Value} trace::RecurseTrace end @@ -64,7 +64,7 @@ end get_address_schema(::Type{RecurseTraceChoiceMap}) = DynamicAddressSchema() -function get_submap(choices::RecurseTraceChoiceMap, +function get_subtree(choices::RecurseTraceChoiceMap, addr::Tuple{Int,Val{:production}}) idx = addr[1] if !haskey(choices.trace.aggregation_traces, idx) @@ -74,7 +74,7 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, +function get_subtree(choices::RecurseTraceChoiceMap, addr::Tuple{Int,Val{:aggregation}}) idx = addr[1] if !haskey(choices.trace.aggregation_traces, idx) @@ -84,21 +84,11 @@ function get_submap(choices::RecurseTraceChoiceMap, end end -function get_submap(choices::RecurseTraceChoiceMap, addr::Pair) - _get_submap(choices, addr) -end - -function has_value(choices::RecurseTraceChoiceMap, addr::Pair) - _has_value(choices, addr) -end - -function get_value(choices::RecurseTraceChoiceMap, addr::Pair) - _get_value(choices, addr) -end +get_subtree(choices::RecurseTraceChoiceMap, addr::Pair) = _get_subtree(choices, addr) get_values_shallow(choices::RecurseTraceChoiceMap) = () -function get_submaps_shallow(choices::RecurseTraceChoiceMap) +function get_subtrees_shallow(choices::RecurseTraceChoiceMap) production_iter = (((idx, Val(:production)), get_choices(subtrace)) for (idx, subtrace) in choices.trace.production_traces) aggregation_iter = (((idx, Val(:aggregation)), get_choices(subtrace)) @@ -333,6 +323,9 @@ function recurse_unpack_constraints(constraints::ChoiceMap) production_constraints = Dict{Int, Any}() aggregation_constraints = Dict{Int, Any}() for (addr, node) in get_submaps_shallow(constraints) + if has_value(node) + error("Unknown address: $(addr)") + end idx::Int = addr[1] if addr[2] == Val(:production) production_constraints[idx] = node @@ -342,9 +335,6 @@ function recurse_unpack_constraints(constraints::ChoiceMap) error("Unknown address: $addr") end end - if length(get_values_shallow(constraints)) > 0 - error("Unknown address: $(first(get_values_shallow(constraints))[1])") - end return (production_constraints, aggregation_constraints) end diff --git a/src/modeling_library/set_map/multiset.jl b/src/modeling_library/set_map/multiset.jl new file mode 100644 index 000000000..0b66f292d --- /dev/null +++ b/src/modeling_library/set_map/multiset.jl @@ -0,0 +1,75 @@ +import FunctionalCollections + +export MultiSet, remove_one, setmap + +struct MultiSet{T} + counts::PersistentHashMap + len::Int +end +function MultiSet{T}() where {T} + MultiSet{T}(PersistentHashMap{T, Int}(), 0) +end +MultiSet() = MultiSet{Any}() +Base.convert(::Type{MultiSet{T}}, ms::MultiSet) where {T} = MultiSet{T}(ms.counts, ms.len) + +function MultiSet(vals::Vector{T}) where T + ms = MultiSet{T}() + for val in vals + ms = push(ms, val) + end + ms +end + +Base.length(ms::MultiSet) = ms.len +function Base.:(==)(ms1::MultiSet, ms2::MultiSet) + if length(ms1) !== length(ms2); return false; end; + for (key, cnt) in ms1.counts + if !haskey(ms2.counts, key); return false; end; + if ms2.counts[key] !== cnt; return false; end; + end + for (key, _) in ms2.counts + if !haskey(ms1.counts, key); return false; end; + end + return true +end +function FunctionalCollections.push(ms::MultiSet{T}, el::T) where T + if haskey(ms.counts, el) + return MultiSet{T}(assoc(ms.counts, el, ms.counts[el] + 1), ms.len + 1) + else + return MultiSet{T}(assoc(ms.counts, el, 1), ms.len + 1) + end +end +function remove_one(ms::MultiSet{T}, el::T) where T + cnt = ms.counts[el] + if cnt == 1 + return MultiSet{T}(dissoc(ms.counts, el), ms.len - 1) + else + return MultiSet{T}(assoc(ms.counts, el, cnt - 1), ms.len - 1) + end +end +function FunctionalCollections.disj(ms::MultiSet{T}, el::T) where T + cnt = ms.counts[el] + return MultiSet{T}(dissoc(ms.counts, el), ms.len - cnt) +end +Base.in(el::T, ms::MultiSet{T}) where T = haskey(ms.counts, el) +function Base.iterate(ms::MultiSet{T}) where T + i = iterate(ms.counts) + if i === nothing; return nothing; end; + ((key, cnt), st) = i + return Base.iterate(ms, (key, cnt, st)) +end +function Base.iterate(ms::MultiSet{T}, (key, cnt, st)) where T + if cnt == 0 + i = iterate(ms.counts, st) + if i === nothing; return nothing; end; + ((key, cnt), st) = i + return Base.iterate(ms, (key, cnt, st)) + else + return (key, (key, cnt-1, st)) + end +end + +function setmap(f, set) + vals = [f(el) for el in set] + MultiSet(vals) +end \ No newline at end of file diff --git a/src/modeling_library/set_map/set_map.jl b/src/modeling_library/set_map/set_map.jl new file mode 100644 index 000000000..58da71dff --- /dev/null +++ b/src/modeling_library/set_map/set_map.jl @@ -0,0 +1,106 @@ +include("multiset.jl") + +export SetMap + +struct SetTrace{ArgType, RetType, TraceType} <: Trace + gen_fn::GenerativeFunction + subtraces::PersistentHashMap{ArgType, TraceType} + args::Tuple + score::Float64 + noise::Float64 +end +function get_choices(trace::SetTrace) + # TODO: specialized choicemap type + c = choicemap() + for (arg, tr) in trace.subtraces + set_subtree!(c, arg, get_choices(tr)) + end + c +end +get_retval(trace::SetTrace) = setmap(((_, tr),) -> get_retval(tr), trace.subtraces) +get_args(trace::SetTrace) = trace.args +get_score(trace::SetTrace) = trace.score +get_gen_fn(trace::SetTrace) = trace.gen_fn +project(trace::SetTrace, ::EmptyAddressTree) = trace.noise +Base.getindex(tr::SetTrace, address) = tr.subtraces[address][] +Base.getindex(tr::SetTrace, address::Pair) = tr.subtraces[address.first][address] + +struct SetMap{RetType, TraceType} <: GenerativeFunction{MultiSet{RetType}, SetTrace{<:Any, RetType, TraceType}} + kernel::GenerativeFunction{RetType, T} where {T >: TraceType} + function SetMap{RetType, TraceType}(kernel::GenerativeFunction{RetType, T} where {T >: TraceType}) where {RetType, TraceType} + new{RetType, TraceType}(kernel) + end +end +function SetMap(kernel::GenerativeFunction{RetType, TraceType}) where {RetType, TraceType} + SetMap{RetType, get_trace_type(kernel)}(kernel) +end + +has_argument_grads(gf::SetMap) = has_argument_grads(gf.kernel) +accepts_output_grad(gf::SetMap) = accepts_output_grad(gf.kernel) + +function simulate(sm::SetMap{RetType, TraceType}, (set,)::Tuple{<:AbstractSet{ArgType}}) where {RetType, TraceType, ArgType} + subtraces = PersistentHashMap{ArgType, TraceType}() + score = 0. + noise = 0. + for item in set + subtr = simulate(sm.kernel, (item,)) + subtraces = assoc(subtraces, item, subtr) + score += get_score(subtr) + noise += project(subtr, EmptyAddressTree()) + end + return SetTrace{ArgType, RetType, TraceType}(sm, subtraces, (set,), score, noise) +end + +function generate(sm::SetMap{RetType, TraceType}, (set,)::Tuple{<:AbstractSet{ArgType}}, constraints::ChoiceMap) where {ArgType, RetType, TraceType} + subtraces = PersistentHashMap{ArgType, TraceType}() + score = 0. + weight = 0. + noise = 0. + for item in set + constraint = get_subtree(constraints, item) + subtr, wt = generate(sm.kernel, (item,), constraint) + weight += wt + noise += project(subtr, EmptyAddressTree()) + subtraces = assoc(subtraces, item, subtr) + score += get_score(subtr) + end + return (SetTrace{ArgType, RetType, TraceType}(sm, subtraces, (set,), score), weight, noise) +end + +# TODO: handle argdiffs +function update(tr::SetTrace{ArgType, RetType, TraceType}, (set,)::Tuple, ::Tuple{<:Diff}, spec::UpdateSpec, ext_const_addrs::Selection) where {ArgType, RetType, TraceType} + new_subtraces = PersistentHashMap{ArgType, TraceType}() + discard = choicemap() + weight = 0. + score = 0. + noise = 0. + for item in set + if item in keys(tr.subtraces) + (new_tr, wt, retdiff, this_discard) = update( + tr.subtraces[item], (item,), + (UnknownChange(),), + get_subtree(spec, item), + get_subtree(ext_const_addrs, item) + ) + new_subtraces = assoc(new_subtraces, item, new_tr) + score += get_score(new_tr) + noise += project(new_tr, EmptyAddressTree()) + weight += wt + set_subtree!(discard, item, this_discard) + else + tr, weight = generate(tr.gen_fn.kernel, (item,), get_subspec(spec, item)) + score += get_score(tr) + noise += project(tr, EmptyAddressTree()) + new_subtraces = assoc(new_subtraces, item, tr) + end + end + for (item, tr) in tr.subtraces + if !(item in set) + ext_const = get_subtree(ext_const_addrs, item) + weight -= project(tr, addrs(get_selected(get_choices(tr), ext_const))) + set_subtree!(discard, item, get_choices(tr)) + end + end + tr = SetTrace{ArgType, RetType, TraceType}(tr.gen_fn, new_subtraces, (set,), score, noise) + return (tr, weight, UnknownChange(), discard) +end \ No newline at end of file diff --git a/src/modeling_library/unfold/generic_update.jl b/src/modeling_library/unfold/generic_update.jl deleted file mode 100644 index 1d693277e..000000000 --- a/src/modeling_library/unfold/generic_update.jl +++ /dev/null @@ -1,60 +0,0 @@ -function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tuple, - choices_or_selection, prev_length::Int, new_length::Int, - retained_and_targeted::Set{Int}, state) where {T,U} - - len_diff = argdiffs[1] - init_state_diff = argdiffs[2] - param_diffs = argdiffs[3:end] # a tuple of diffs - - if any(diff != NoChange() for diff in param_diffs) - - # visit every retained kernel application - state_diff = init_state_diff - for key=1:min(prev_length,new_length) - state_diff = process_retained!(gen_fn, params, choices_or_selection, - key, (NoChange(), state_diff, param_diffs...), state) - end - - else - # every parameter diff is NoChange() - - # visit only certain retained kernel applications - to_visit::Vector{Int} = sort(collect(retained_and_targeted)) - key = 0 - state_diff = init_state_diff - if state_diff != NoChange() - key = 1 - visit = true - while visit && key <= min(prev_length, new_length) - state_diff = process_retained!(gen_fn, params, choices_or_selection, - key, (NoChange(), state_diff, param_diffs...), state) - key += 1 - visit = (state_diff != NoChange()) - end - end - for i=1:length(to_visit) - if key > to_visit[i] - # we have already visited it - continue - end - key = to_visit[i] - visit = true - while visit && key <= min(prev_length, new_length) - state_diff = process_retained!(gen_fn, params, choices_or_selection, - key, (NoChange(), state_diff, param_diffs...), state) - key += 1 - visit = (state_diff != NoChange()) - end - end - end -end - -""" -Process all new applications. -""" -function process_all_new!(gen_fn::Unfold{T,U}, params::Tuple, choices_or_selection, - prev_len::Int, new_len::Int, state) where {T,U} - for key=prev_len+1:new_len - process_new!(gen_fn, params, choices_or_selection, key, state) - end -end diff --git a/src/modeling_library/unfold/regenerate.jl b/src/modeling_library/unfold/regenerate.jl deleted file mode 100644 index 7e8480ca0..000000000 --- a/src/modeling_library/unfold/regenerate.jl +++ /dev/null @@ -1,113 +0,0 @@ -mutable struct UnfoldRegenerateState{T,U} - init_state::T - weight::Float64 - score::Float64 - noise::Float64 - subtraces::PersistentVector{U} - retval::PersistentVector{T} - num_nonempty::Int - updated_retdiffs::Dict{Int,Diff} -end - -function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, - selection::Selection, key::Int, kernel_argdiffs::Tuple, - state::UnfoldRegenerateState{T,U}) where {T,U} - local subtrace::U - local prev_subtrace::U - local prev_state::T - local new_state::T - - subselection = selection[key] - prev_state = (key == 1) ? state.init_state : state.retval[key-1] - kernel_args = (key, prev_state, params...) - - # get new subtrace with recursive call to regenerate() - prev_subtrace = state.subtraces[key] - (subtrace, weight, retdiff) = regenerate( - prev_subtrace, kernel_args, kernel_argdiffs, subselection) - - # retrieve retdiff - if retdiff != NoChange() - state.updated_retdiffs[key] = retdiff - end - - # update state - state.weight += weight - state.score += (get_score(subtrace) - get_score(prev_subtrace)) - state.noise += (project(subtrace, EmptySelection()) - project(subtrace, EmptySelection())) - state.subtraces = assoc(state.subtraces, key, subtrace) - new_state = get_retval(subtrace) - state.retval = assoc(state.retval, key, new_state) - subtrace_empty = isempty(get_choices(subtrace)) - prev_subtrace_empty = isempty(get_choices(prev_subtrace)) - if !subtrace_empty && prev_subtrace_empty - state.num_nonempty += 1 - elseif subtrace_empty && !prev_subtrace_empty - state.num_nonempty -= 1 - end - - retdiff -end - -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, selection::Selection, key::Int, - state::UnfoldRegenerateState{T,U}) where {T,U} - local subtrace::U - local prev_state::T - local new_state::T - - if !isempty(selection[key]) - error("Cannot select new addresses in regenerate") - end - prev_state = (key == 1) ? state.init_state : state.retval[key-1] - kernel_args = (key, prev_state, params...) - - # get subtrace and weight - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, EmptyChoiceMap()) - - # update state - state.weight += weight - state.score += get_score(subtrace) - new_state = get_retval(subtrace) - @assert key > length(state.subtraces) - state.subtraces = push(state.subtraces, subtrace) - state.retval = push(state.retval, new_state) - @assert length(state.subtraces) == key - if !isempty(get_choices(subtrace)) - state.num_nonempty += 1 - end -end - -function regenerate(trace::VectorTrace{UnfoldType,T,U}, - args::Tuple, argdiffs::Tuple, - selection::Selection) where {T,U} - gen_fn = trace.gen_fn - (new_length, init_state, params) = unpack_args(args) - check_length(new_length) - prev_args = get_args(trace) - prev_length = prev_args[1] - retained_and_selected = get_retained_and_selected(selection, prev_length, new_length) - - # handle removed applications - (num_nonempty, score_decrement, noise_decrement) = vector_regenerate_delete( - new_length, prev_length, trace) - (subtraces, retval) = vector_remove_deleted_applications( - trace.subtraces, trace.retval, prev_length, new_length) - score = trace.score - score_decrement - noise = trace.noise - noise_decrement - - # handle retained and new applications - state = UnfoldRegenerateState{T,U}(init_state, -noise_decrement, score, noise, - subtraces, retval, num_nonempty, Dict{Int,Diff}()) - process_all_retained!(gen_fn, params, argdiffs, selection, prev_length, new_length, - retained_and_selected, state) - process_all_new!(gen_fn, params, selection, prev_length, new_length, state) - - # retdiff - retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length) - - # new trace - new_trace = VectorTrace{UnfoldType,T,U}(gen_fn, state.subtraces, state.retval, args, - state.score, state.noise, new_length, state.num_nonempty) - - (new_trace, state.weight, retdiff) -end diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 44238e3b7..ea22070d8 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -64,7 +64,5 @@ include("simulate.jl") include("generate.jl") include("propose.jl") include("assess.jl") -include("generic_update.jl") include("update.jl") -include("regenerate.jl") include("backprop.jl") diff --git a/src/modeling_library/unfold/update.jl b/src/modeling_library/unfold/update.jl index d49cea302..1fe9a0d02 100644 --- a/src/modeling_library/unfold/update.jl +++ b/src/modeling_library/unfold/update.jl @@ -10,22 +10,84 @@ mutable struct UnfoldUpdateState{T,U} updated_retdiffs::Dict{Int,Diff} end +function process_all_retained!(gen_fn::Unfold{T,U}, params::Tuple, argdiffs::Tuple, + spec, prev_length::Int, new_length::Int, + retained_and_targeted::Set{Int}, externally_constrained_addrs, state) where {T,U} + + len_diff = argdiffs[1] + init_state_diff = argdiffs[2] + param_diffs = argdiffs[3:end] # a tuple of diffs + + if any(diff != NoChange() for diff in param_diffs) + + # visit every retained kernel application + state_diff = init_state_diff + for key = 1:min(prev_length, new_length) + state_diff = process_retained!(gen_fn, params, spec, + key, (NoChange(), state_diff, param_diffs...), externally_constrained_addrs, state) + end + + else + # every parameter diff is NoChange() + + # visit only certain retained kernel applications + to_visit::Vector{Int} = sort(collect(retained_and_targeted)) + key = 0 + state_diff = init_state_diff + if state_diff != NoChange() + key = 1 + visit = true + while visit && key <= min(prev_length, new_length) + state_diff = process_retained!(gen_fn, params, spec, + key, (NoChange(), state_diff, param_diffs...), externally_constrained_addrs, state) + key += 1 + visit = (state_diff != NoChange()) + end + end + for i = 1:length(to_visit) + if key > to_visit[i] + # we have already visited it + continue + end + key = to_visit[i] + visit = true + while visit && key <= min(prev_length, new_length) + state_diff = process_retained!(gen_fn, params, spec, + key, (NoChange(), state_diff, param_diffs...), externally_constrained_addrs, state) + key += 1 + visit = (state_diff != NoChange()) + end + end + end +end + +""" +Process all new applications. +""" +function process_all_new!(gen_fn::Unfold{T,U}, params::Tuple, choices, + prev_len::Int, new_len::Int, state) where {T,U} + for key = prev_len + 1:new_len + process_new!(gen_fn, params, choices, key, state) + end +end + function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, - choices::ChoiceMap, key::Int, kernel_argdiffs::Tuple, - state::UnfoldUpdateState{T,U}) where {T,U} + spec::UpdateSpec, key::Int, kernel_argdiffs::Tuple, + externally_constrained_addrs::Selection, state::UnfoldUpdateState{T,U}) where {T,U} local subtrace::U local prev_subtrace::U local prev_state::T local new_state::T - submap = get_submap(choices, key) - prev_state = (key == 1) ? state.init_state : state.retval[key-1] + subspec = get_subtree(spec, key) + sub_ext_const_addrs = get_subtree(externally_constrained_addrs, key) + prev_state = (key == 1) ? state.init_state : state.retval[key - 1] kernel_args = (key, prev_state, params...) # get new subtrace with recursive call to update() prev_subtrace = state.subtraces[key] (subtrace, weight, retdiff, discard) = update( - prev_subtrace, kernel_args, kernel_argdiffs, submap) + prev_subtrace, kernel_args, kernel_argdiffs, subspec, sub_ext_const_addrs) # retrieve retdiff if retdiff != NoChange() @@ -51,14 +113,14 @@ function process_retained!(gen_fn::Unfold{T,U}, params::Tuple, retdiff end -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices, key::Int, +function process_new!(gen_fn::Unfold{T,U}, params::Tuple, spec, key::Int, state::UnfoldUpdateState{T,U}) where {T,U} local subtrace::U local prev_state::T local new_state::T - submap = get_submap(choices, key) - prev_state = (key == 1) ? state.init_state : state.retval[key-1] + submap = get_subtree(spec, key) + prev_state = (key == 1) ? state.init_state : state.retval[key - 1] kernel_args = (key, prev_state, params...) # get subtrace and weight @@ -79,28 +141,28 @@ end function update(trace::VectorTrace{UnfoldType,T,U}, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) where {T,U} + spec::UpdateSpec, externally_constrained_addrs::Selection) where {T,U} gen_fn = trace.gen_fn (new_length, init_state, params) = unpack_args(args) check_length(new_length) prev_args = get_args(trace) prev_length = prev_args[1] - retained_and_constrained = get_retained_and_constrained(choices, prev_length, new_length) + retained_and_specd = get_retained_and_specd(spec, prev_length, new_length) # handle removed applications - (discard, num_nonempty, score_decrement, noise_decrement) = vector_update_delete( - new_length, prev_length, trace) + (discard, num_nonempty, score_decrement, noise_decrement, weight_decrement) = vector_update_delete( + new_length, prev_length, trace, externally_constrained_addrs) (subtraces, retval) = vector_remove_deleted_applications( trace.subtraces, trace.retval, prev_length, new_length) score = trace.score - score_decrement noise = trace.noise - noise_decrement # handle retained and new applications - state = UnfoldUpdateState{T,U}(init_state, -score_decrement, score, noise, + state = UnfoldUpdateState{T,U}(init_state, -weight_decrement, score, noise, subtraces, retval, discard, num_nonempty, Dict{Int,Diff}()) - process_all_retained!(gen_fn, params, argdiffs, choices, prev_length, new_length, - retained_and_constrained, state) - process_all_new!(gen_fn, params, choices, prev_length, new_length, state) + process_all_retained!(gen_fn, params, argdiffs, spec, prev_length, new_length, + retained_and_specd, externally_constrained_addrs, state) + process_all_new!(gen_fn, params, spec, prev_length, new_length, state) # retdiff retdiff = vector_compute_retdiff(state.updated_retdiffs, new_length, prev_length) diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index 9b0eb763a..307283df8 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -65,49 +65,47 @@ end function project(trace::VectorTrace, selection::Selection) weight = 0. for key=1:trace.len - subselection = selection[key] + subselection = get_subselection(selection, key) weight += project(trace.subtraces[key], subselection) end weight end project(trace::VectorTrace, ::EmptySelection) = trace.noise -struct VectorTraceChoiceMap{GenFnType, T, U} <: ChoiceMap +struct VectorTraceChoiceMap{GenFnType, T, U} <: AddressTree{Value} trace::VectorTrace{GenFnType, T, U} end @inline Base.isempty(assignment::VectorTraceChoiceMap) = assignment.trace.num_nonempty == 0 @inline get_address_schema(::Type{VectorTraceChoiceMap}) = VectorAddressSchema() -@inline function get_submap(choices::VectorTraceChoiceMap, addr::Int) +@inline get_subtree(choices::VectorTraceChoiceMap, addr::Pair) = _get_subtree(choices, addr) +@inline function get_subtree(choices::VectorTraceChoiceMap, addr::Int) if addr <= choices.trace.len get_choices(choices.trace.subtraces[addr]) else EmptyChoiceMap() end end +# keys which are not ints have no sub-choicemap +@inline get_subtree(choices::VectorTraceChoiceMap, addr) = EmptyChoiceMap() -@inline function get_submaps_shallow(choices::VectorTraceChoiceMap) +@inline function get_subtrees_shallow(choices::VectorTraceChoiceMap) ((i, get_choices(choices.trace.subtraces[i])) for i=1:choices.trace.len) end -@inline get_submap(choices::VectorTraceChoiceMap, addr::Pair) = _get_submap(choices, addr) -@inline get_value(choices::VectorTraceChoiceMap, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::VectorTraceChoiceMap, addr::Pair) = _has_value(choices, addr) -@inline get_values_shallow(::VectorTraceChoiceMap) = () - - ############################################ # code shared by vector-shaped combinators # ############################################ -function get_retained_and_constrained(constraints::ChoiceMap, prev_length::Int, new_length::Int) +function get_retained_and_specd(spec::UpdateSpec, prev_length::Int, new_length::Int) keys = Set{Int}() - for (key::Int, _) in get_submaps_shallow(constraints) + for (key::Int, subspec) in get_subtrees_shallow(spec) + isempty(subspec) && continue; if key > 0 && key <= new_length push!(keys, key) else - error("Constrained address does not exist: $key") + error("Update spec included address which does not exist: $key") end end keys @@ -117,7 +115,7 @@ function get_retained_and_selected(selection::EmptySelection, prev_length::Int, Set{Int}() end -function get_retained_and_selected(selection::HierarchicalSelection, prev_length::Int, new_length::Int) +function get_retained_and_selected(selection::Selection, prev_length::Int, new_length::Int) keys = Set{Int}() for (key::Int, _) in get_subselections(selection) if key > 0 && key <= new_length @@ -142,39 +140,36 @@ function vector_compute_retdiff(updated_retdiffs::Dict{Int,Diff}, new_length::In end function vector_update_delete(new_length::Int, prev_length::Int, - prev_trace::VectorTrace) + prev_trace::VectorTrace, externally_constrained_addrs::Selection) num_nonempty = prev_trace.num_nonempty discard = choicemap() score_decrement = 0. noise_decrement = 0. + deletion_weight = 0. for key=new_length+1:prev_length subtrace = prev_trace.subtraces[key] - score_decrement += get_score(subtrace) - noise_decrement += project(subtrace, EmptySelection()) - if !isempty(get_choices(subtrace)) - num_nonempty -= 1 + score_change = get_score(subtrace) + noise_change = project(subtrace, EmptySelection()) + + score_decrement += score_change + noise_decrement += noise_change + + ext_const = get_subselection(externally_constrained_addrs, key) + if isempty(ext_const) + deletion_weight += noise_change + elseif ext_const === AllSelection() + deletion_weight += score_change + else + deletion_weight -= project(subtrace, addrs(get_selected(get_choices(subtrace), ext_const))) end - @assert num_nonempty >= 0 - set_submap!(discard, key, get_choices(subtrace)) - end - return (discard, num_nonempty, score_decrement, noise_decrement) -end -function vector_regenerate_delete(new_length::Int, prev_length::Int, - prev_trace::VectorTrace) - num_nonempty = prev_trace.num_nonempty - score_decrement = 0. - noise_decrement = 0. - for key=new_length+1:prev_length - subtrace = prev_trace.subtraces[key] - score_decrement += get_score(subtrace) - noise_decrement += project(subtrace, EmptySelection()) if !isempty(get_choices(subtrace)) num_nonempty -= 1 end @assert num_nonempty >= 0 + set_submap!(discard, key, get_choices(subtrace)) end - return (num_nonempty, score_decrement, noise_decrement) + return (discard, num_nonempty, score_decrement, noise_decrement, deletion_weight) end function vector_remove_deleted_applications(subtraces, retval, prev_length, new_length) diff --git a/src/static_ir/assess.jl b/src/static_ir/assess.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 99594cb72..131415665 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -36,15 +36,15 @@ function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) - if node in selected_choices - push!(fwd_marked, node) - end -end - function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) - if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) - push!(fwd_marked, node) + if node.generative_function isa Distribution + if node in selected_choices + push!(fwd_marked, node) + end + else + if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) + push!(fwd_marked, node) + end end end @@ -60,20 +60,15 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) - # the logpdf of every random choice is a SINK - for input_node in node.inputs - push!(back_marked, input_node) - end - # the value of every random choice is in back_marked, since it affects its logpdf - push!(back_marked, node) -end - function back_pass!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK for input_node in node.inputs push!(back_marked, input_node) end + if node.generative_function isa Distribution + # the value of every random choice is in back_marked, since it affects its logpdf + push!(back_marked, node) + end end function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) @@ -134,35 +129,35 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - - # every random choice is in back_marked, since it affects it logpdf, but - # also possibly due to other downstream usage of the value - @assert node in back_marked +function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + # for reference by other nodes during back_codegen! + # could performance optimize this away + push!(stmts, :($(node.name) = get_retval(trace.$(get_subtrace_fieldname(node))))) - if node in fwd_marked - # the only way we are fwd_marked is if this choice was selected + # every random choice is in back_marked, since it affects it logpdf, but + # also possibly due to other downstream usage of the value + @assert node in back_marked - # initialize gradient with respect to the value of the random choice to zero - # it will be a runtime error, thrown here, if there is no zero() method - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - end -end + if node in fwd_marked + # the only way we are fwd_marked is if this choice was selected -function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) + # initialize gradient with respect to the value of the random choice to zero + # it will be a runtime error, thrown here, if there is no zero() method + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end + else + # for reference by other nodes during back_codegen! + # could performance optimize this away + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) - # NOTE: we will still potentially run choice_gradients recursively on the generative function, - # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked - # we are fwd_marked if an input was fwd_marked, or if we were selected internally - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + # NOTE: we will still potentially run choice_gradients recursively on the generative function, + # we just might not use its return value gradient. + if node in fwd_marked && node in back_marked + # we are fwd_marked if an input was fwd_marked, or if we were selected internally + push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + end end end @@ -217,19 +212,19 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: end function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) + node::GenerativeFunctionCallNode, logpdf_grad::Symbol) # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = logpdf_grad($(node.generative_function), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - if !has_argument_grads(node.dist)[i] - error("Distribution $(node.dist) does not have logpdf gradient for argument $i") + if !has_argument_grads(node.generative_function)[i] + error("Distribution $(node.generative_function) does not have logpdf gradient for argument $i") end push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) end @@ -242,125 +237,124 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) - logpdf_grad = gensym("logpdf_grad") - - # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) - - # backpropagate to the value (if it was selected) - if node in fwd_marked - if !has_output_grad(node.dist) - error("Distribution $dist does not logpdf gradient for its output value") - end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) - end -end - -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) - logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) -end - function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropTraceMode) + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") + + # backpropagate to the inputs + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + + # backpropagate to the value (if it was selected) + if node in fwd_marked + if !has_output_grad(node.generative_function) + error("Distribution $(node.generative_function) does not logpdf gradient for its output value") + end + push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + end + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end - - if node in fwd_marked - input_grads = gensym("call_input_grads") - value_trie = value_trie_var(node) - gradient_trie = gradient_trie_var(node) - subtrace_fieldname = get_subtrace_fieldname(node) - call_selection = gensym("call_selection") - if node in selected_calls - push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_getindex))(selection, $(QuoteNode(Val(node.addr)))))) - else - push!(stmts, :($call_selection = EmptySelection())) + if node in fwd_marked + input_grads = gensym("call_input_grads") + value_trie = value_trie_var(node) + gradient_trie = gradient_trie_var(node) + subtrace_fieldname = get_subtrace_fieldname(node) + call_selection = gensym("call_selection") + if node in selected_calls + push!(stmts, :($call_selection = $(GlobalRef(Gen, :static_get_subtree))(selection, $(QuoteNode(Val(node.addr)))))) + else + push!(stmts, :($call_selection = EmptySelection())) + end + retval_grad = node in back_marked ? gradient_var(node) : :(nothing) + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + trace.$subtrace_fieldname, $call_selection, $retval_grad))) end - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( - trace.$subtrace_fieldname, $call_selection, $retval_grad))) - end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end - end - # NOTE: the value_trie and gradient_trie are dealt with later + # NOTE: the value_trie and gradient_trie are dealt with later + end end function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropParamsMode) - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end + if node.generative_function isa Distribution + logpdf_grad = gensym("logpdf_grad") + back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + else + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :($(gradient_var(node)) += retval_grad)) + end - if node in fwd_marked - input_grads = gensym("call_input_grads") - subtrace_fieldname = get_subtrace_fieldname(node) - retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) - end + if node in fwd_marked + input_grads = gensym("call_input_grads") + subtrace_fieldname = get_subtrace_fieldname(node) + retval_grad = node in back_marked ? gradient_var(node) : :(nothing) + push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + end - # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked - @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + # increment gradients of input nodes that are in fwd_marked + for (i, input_node) in enumerate(node.inputs) + if input_node in fwd_marked + @assert input_node in back_marked # this ensured its gradient will have been initialized + push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + end end end end -function generate_value_gradient_trie(selected_choices::Set{RandomChoiceNode}, +function generate_value_gradient_trie(selected_choices::Set{GenerativeFunctionCallNode}, selected_calls::Set{GenerativeFunctionCallNode}, value_trie::Symbol, gradient_trie::Symbol) selected_choices_vec = collect(selected_choices) quoted_leaf_keys = map((node) -> QuoteNode(node.addr), selected_choices_vec) - leaf_values = map((node) -> :(trace.$(get_value_fieldname(node))), selected_choices_vec) - leaf_gradients = map((node) -> gradient_var(node), selected_choices_vec) + leaf_value_choicemaps = map((node) -> :(Value(get_retval(trace.$(get_subtrace_fieldname(node))))), selected_choices_vec) + leaf_gradient_choicemaps = map((node) -> :(Value($(gradient_var(node)))), selected_choices_vec) selected_calls_vec = collect(selected_calls) quoted_internal_keys = map((node) -> QuoteNode(node.addr), selected_calls_vec) - internal_values = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), + internal_value_choicemaps = map((node) -> :(get_choices(trace.$(get_subtrace_fieldname(node)))), selected_calls_vec) - internal_gradients = map((node) -> gradient_trie_var(node), selected_calls_vec) + internal_gradient_choicemaps = map((node) -> gradient_trie_var(node), selected_calls_vec) + + quoted_all_keys = Iterators.flatten((quoted_leaf_keys, quoted_internal_keys)) + all_value_choicemaps = Iterators.flatten((leaf_value_choicemaps, internal_value_choicemaps)) + all_gradient_choicemaps = Iterators.flatten((leaf_gradient_choicemaps, internal_gradient_choicemaps)) + quote - $value_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),))) - $gradient_trie = StaticChoiceMap( - NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)), - NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),))) + $value_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_value_choicemaps...),))) + $gradient_trie = StaticChoiceMap(NamedTuple{($(quoted_all_keys...),)}(($(all_gradient_choicemaps...),))) end end function get_selected_choices(::EmptyAddressSchema, ::StaticIR) - Set{RandomChoiceNode}() + Set{GenerativeFunctionCallNode}() end function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{RandomChoiceNodes}(ir.choice_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if node.generative_function isa Distribution]...) end function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) + # TODO: handle inverted selections! selected_choice_addrs = Set(keys(schema)) - selected_choices = Set{RandomChoiceNode}() - for node in ir.choice_nodes - if node.addr in selected_choice_addrs + selected_choices = Set{GenerativeFunctionCallNode}() + for node in ir.call_nodes + if node.generative_function isa Distribution && node.addr in selected_choice_addrs push!(selected_choices, node) end end @@ -372,14 +366,14 @@ function get_selected_calls(::EmptyAddressSchema, ::StaticIR) end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + Set{GenerativeFunctionCallNode}([node for node in ir.call_nodes if !(node.generative_function isa Distribution)]...) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) selected_call_addrs = Set(keys(schema)) selected_calls = Set{GenerativeFunctionCallNode}() for node in ir.call_nodes - if node.addr in selected_call_addrs + if !(node.generative_function isa Distribution) && node.addr in selected_call_addrs push!(selected_calls, node) end end @@ -451,7 +445,7 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, ir = get_ir(gen_fn_type) # unlike choice_gradients we don't take gradients w.r.t. the value of random choices - selected_choices = Set{RandomChoiceNode}() + selected_choices = Set{GenerativeFunctionCallNode}() # we need to guarantee that we visit every generative function call, # because we need to backpropagate to its trainable parameters diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c82658892..de6acf6e7 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -18,14 +18,6 @@ struct JuliaNode <: StaticIRNode typ::Union{Symbol,Expr,QuoteNode} end -struct RandomChoiceNode <: StaticIRNode - dist::Distribution - inputs::Vector{StaticIRNode} - addr::Symbol - name::Symbol - typ::Union{Symbol,Expr,QuoteNode} -end - struct GenerativeFunctionCallNode <: StaticIRNode generative_function::GenerativeFunction inputs::Vector{StaticIRNode} @@ -38,7 +30,6 @@ struct StaticIR nodes::Vector{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::StaticIRNode @@ -50,12 +41,10 @@ mutable struct StaticIRBuilder node_set::Set{StaticIRNode} trainable_param_nodes::Vector{TrainableParameterNode} arg_nodes::Vector{ArgumentNode} - choice_nodes::Vector{RandomChoiceNode} call_nodes::Vector{GenerativeFunctionCallNode} julia_nodes::Vector{JuliaNode} return_node::Union{Nothing,StaticIRNode} vars::Set{Symbol} - addrs_to_choice_nodes::Dict{Symbol,RandomChoiceNode} addrs_to_call_nodes::Dict{Symbol,GenerativeFunctionCallNode} accepts_output_grad::Bool end @@ -65,17 +54,15 @@ function StaticIRBuilder() node_set = Set{StaticIRNode}() trainable_param_nodes = Vector{TrainableParameterNode}() arg_nodes = Vector{ArgumentNode}() - choice_nodes = Vector{RandomChoiceNode}() call_nodes = Vector{GenerativeFunctionCallNode}() julia_nodes = Vector{JuliaNode}() return_node = nothing vars = Set{Symbol}() - addrs_to_choice_nodes = Dict{Symbol,RandomChoiceNode}() addrs_to_call_nodes = Dict{Symbol,GenerativeFunctionCallNode}() accepts_output_grad = false - StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, choice_nodes, call_nodes, + StaticIRBuilder(nodes, node_set, trainable_param_nodes, arg_nodes, call_nodes, julia_nodes, - return_node, vars, addrs_to_choice_nodes, addrs_to_call_nodes, + return_node, vars, addrs_to_call_nodes, accepts_output_grad) end @@ -87,7 +74,6 @@ function build_ir(builder::StaticIRBuilder) builder.nodes, builder.trainable_param_nodes, builder.arg_nodes, - builder.choice_nodes, builder.call_nodes, builder.julia_nodes, builder.return_node, @@ -109,7 +95,7 @@ function check_inputs_exist(builder::StaticIRBuilder, input_nodes) end function check_addr_unique(builder::StaticIRBuilder, addr::Symbol) - if haskey(builder.addrs_to_choice_nodes, addr) || haskey(builder.addrs_to_call_nodes, addr) + if haskey(builder.addrs_to_call_nodes, addr) error("Address $addr was not unique") end end @@ -164,20 +150,6 @@ function add_constant_node!(builder::StaticIRBuilder, val, node end -function add_addr_node!(builder::StaticIRBuilder, dist::Distribution; - inputs::Vector=[], addr::Symbol=gensym(), - name::Symbol=gensym()) - check_unique_var(builder, name) - check_addr_unique(builder, addr) - check_inputs_exist(builder, inputs) - typ = QuoteNode(get_return_type(dist)) - node = RandomChoiceNode(dist, inputs, addr, name, typ) - _add_node!(builder, node) - builder.addrs_to_choice_nodes[addr] = node - push!(builder.choice_nodes, node) - node -end - function add_addr_node!(builder::StaticIRBuilder, gen_fn::GenerativeFunction; inputs::Vector=[], addr::Symbol=gensym(), name::Symbol=gensym()) diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 3557f19b1..a7e9d42ad 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -21,27 +21,6 @@ function process!(state::StaticIRGenerateState, node::JuliaNode, options) end end -function process!(state::StaticIRGenerateState, node::RandomChoiceNode, options) - schema = state.schema - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) - if isa(schema, StaticAddressSchema) && (node.addr in keys(schema)) - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - push!(state.stmts, :($weight += $incr)) - else - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - end - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRGenerateState, node::GenerativeFunctionCallNode, options) schema = state.schema args = map((input_node) -> input_node.name, node.inputs) diff --git a/src/static_ir/project.jl b/src/static_ir/project.jl index f95db7833..0612a5640 100644 --- a/src/static_ir/project.jl +++ b/src/static_ir/project.jl @@ -1,29 +1,19 @@ struct StaticIRProjectState - schema::Union{StaticAddressSchema,EmptyAddressSchema,AllAddressSchema} + schema::StaticSchema stmts::Vector{Any} end function process!(state::StaticIRProjectState, node) end -function process!(state::StaticIRProjectState, node::RandomChoiceNode) - schema = state.schema - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($weight += trace.$(get_score_fieldname(node)))) - end -end - function process!(state::StaticIRProjectState, node::GenerativeFunctionCallNode) schema = state.schema addr = QuoteNode(node.addr) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) + @assert isa(schema, StaticSchema) subtrace = get_subtrace_fieldname(node) subselection = gensym("subselection") - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) + if isa(schema, AllAddressSchema) || (!isa(schema, EmptyAddressSchema) && (node.addr in keys(schema))) + push!(state.stmts, :($subselection = $(GlobalRef(Gen, :static_get_subtree))(selection, Val($addr)))) push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $subselection))) - else - push!(state.stmts, :($weight += $(GlobalRef(Gen, :project))(trace.$subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end end @@ -32,7 +22,7 @@ function codegen_project(trace_type::Type, selection_type::Type) schema = get_address_schema(selection_type) # convert the selection to a static selection if it is not already one - if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) + if !(isa(schema, StaticSchema)) return quote $(GlobalRef(Gen, :project))(trace, $(QuoteNode(StaticSelection))(selection)) end end diff --git a/src/static_ir/render_ir.jl b/src/static_ir/render_ir.jl index 22e7b3625..880fec505 100644 --- a/src/static_ir/render_ir.jl +++ b/src/static_ir/render_ir.jl @@ -1,7 +1,12 @@ label(node::ArgumentNode) = String(node.name) label(node::JuliaNode) = String(node.name) -label(node::RandomChoiceNode) = "$(node.dist) $(node.addr) $(node.name)" -label(node::GenerativeFunctionCallNode) = "$(node.addr) $(node.name)" +function label(node::GenerativeFunctionCallNode) + if node.generative_function isa Distribution + "$(node.generative_function) $(node.addr) $(node.name)" + else + "$(node.addr) $(node.name)" + end +end function draw_graph(ir::StaticIR, graphviz, fname) dot = graphviz.Digraph() @@ -14,7 +19,7 @@ function draw_graph(ir::StaticIR, graphviz, fname) shape = "diamond" color = "white" parents = [] - elseif isa(node, RandomChoiceNode) + elseif isa(node, GenerativeFunctionCallNode) && node.generative_function isa Distribution shape = "ellipse" color = "white" parents = node.inputs diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 48c215a3c..669f40753 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -20,19 +20,6 @@ function process!(state::StaticIRSimulateState, node::JuliaNode, options) end end -function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options) - args = map((input_node) -> input_node.name, node.inputs) - incr = gensym("logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...)))) - push!(state.stmts, :($incr = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(args...)))) - push!(state.stmts, :($(get_value_fieldname(node)) = $(node.name))) - push!(state.stmts, :($(get_score_fieldname(node)) = $incr)) - push!(state.stmts, :($num_nonempty_fieldname += 1)) - push!(state.stmts, :($total_score_fieldname += $incr)) -end - function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode, options) args = map((input_node) -> input_node.name, node.inputs) args_tuple = Expr(:tuple, args...) diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..aafedf00d 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -2,34 +2,16 @@ # assignment wrapper # ###################### -struct StaticIRTraceAssmt{T} <: ChoiceMap +struct StaticIRTraceAssmt{T} <: AddressTree{Value} trace::T end function get_schema end @inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T) - @inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace) - -@inline static_has_value(choices::StaticIRTraceAssmt, key) = false - -@inline function get_value(choices::StaticIRTraceAssmt, key::Symbol) - static_get_value(choices, Val(key)) -end - -@inline function has_value(choices::StaticIRTraceAssmt, key::Symbol) - static_has_value(choices, Val(key)) -end - -@inline function get_submap(choices::StaticIRTraceAssmt, key::Symbol) - static_get_submap(choices, Val(key)) -end -static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() - -@inline get_value(choices::StaticIRTraceAssmt, addr::Pair) = _get_value(choices, addr) -@inline has_value(choices::StaticIRTraceAssmt, addr::Pair) = _has_value(choices, addr) -@inline get_submap(choices::StaticIRTraceAssmt, addr::Pair) = _get_submap(choices, addr) +@inline get_subtree(choices::StaticIRTraceAssmt, key::Symbol) = static_get_subtree(choices, Val(key)) +@inline get_subtree(choices::StaticIRTraceAssmt, addr::Pair) = _get_subtree(choices, addr) ######################### # trace type generation # @@ -37,21 +19,21 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() abstract type StaticIRTrace <: Trace end -@inline function static_get_subtrace(trace::StaticIRTrace, addr) - error("Not implemented") -end +@inline static_get_subtrace(trace::StaticIRTrace, addr) = error("Not implemented") +@inline static_get_subtree(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() +@inline static_get_value(trace::StaticIRTrace, v::Val) = get_value(static_get_subtree(trace, v)) @inline static_haskey(trace::StaticIRTrace, ::Val) = false - Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) +@inline Base.haskey(trace::StaticIRTrace, key) = Gen.static_haskey(trace, Val(key)) -@inline function Base.getindex(trace::StaticIRTrace, addr) - Gen.static_getindex(trace, Val(addr)) -end +@inline Base.getindex(trace::StaticIRTrace, addr) = Gen.static_get_subtree(trace, Val(addr)) @inline function Base.getindex(trace::StaticIRTrace, addr::Pair) first, rest = addr return Gen.static_get_subtrace(trace, Val(first))[rest] end +@inline get_choices(trace::T) where {T <: StaticIRTrace} = StaticIRTraceAssmt{T}(trace) + const arg_prefix = gensym("arg") const choice_value_prefix = gensym("choice_value") const choice_score_prefix = gensym("choice_score") @@ -62,18 +44,10 @@ function get_value_fieldname(node::ArgumentNode) Symbol("$(arg_prefix)_$(node.name)") end -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end - function get_value_fieldname(node::JuliaNode) Symbol("$(julia_prefix)_$(node.name)") end -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - function get_subtrace_fieldname(node::GenerativeFunctionCallNode) Symbol("$(subtrace_prefix)_$(node.addr)") end @@ -94,12 +68,6 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio fieldname = get_value_fieldname(node) push!(fields, TraceField(fieldname, node.typ)) end - for node in ir.choice_nodes - value_fieldname = get_value_fieldname(node) - push!(fields, TraceField(value_fieldname, node.typ)) - score_fieldname = get_score_fieldname(node) - push!(fields, TraceField(score_fieldname, QuoteNode(Float64))) - end for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) @@ -154,28 +122,7 @@ function generate_get_retval(ir::StaticIR, trace_struct_name::Symbol) Expr(:block, :(trace.$return_value_fieldname))) end -function generate_get_choices(trace_struct_name::Symbol) - Expr(:function, - Expr(:call, GlobalRef(Gen, :get_choices), :(trace::$trace_struct_name)), - Expr(:if, :(!isempty(trace)), - :($(QuoteNode(StaticIRTraceAssmt))(trace)), - :($(QuoteNode(EmptyChoiceMap))()))) -end - -function generate_get_values_shallow(ir::StaticIR, trace_struct_name::Symbol) - elements = [] - for node in ir.choice_nodes - addr = node.addr - value = :(choices.trace.$(get_value_fieldname(node))) - push!(elements, :(($(QuoteNode(addr)), $value))) - end - Expr(:function, - Expr(:call, GlobalRef(Gen, :get_values_shallow), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), - Expr(:block, Expr(:tuple, elements...))) -end - -function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) +function generate_get_subtrees_shallow(ir::StaticIR, trace_struct_name::Symbol) elements = [] for node in ir.call_nodes addr = node.addr @@ -183,7 +130,7 @@ function generate_get_submaps_shallow(ir::StaticIR, trace_struct_name::Symbol) push!(elements, :(($(QuoteNode(addr)), $(GlobalRef(Gen, :get_choices))($subtrace)))) end Expr(:function, - Expr(:call, GlobalRef(Gen, :get_submaps_shallow), + Expr(:call, GlobalRef(Gen, :get_subtrees_shallow), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name})), Expr(:block, Expr(:tuple, elements...))) end @@ -204,77 +151,32 @@ function generate_getindex(ir::StaticIR, trace_struct_name::Symbol) for node in ir.call_nodes push!(call_getindex_exprs, quote - function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) + function $(GlobalRef(Gen, :static_get_subtree))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) return $(GlobalRef(Gen, :get_retval))(trace.$(get_subtrace_fieldname(node))) end end ) end - - choice_getindex_exprs = Expr[] - for node in ir.choice_nodes - push!(choice_getindex_exprs, - quote - function $(GlobalRef(Gen, :static_getindex))(trace::$trace_struct_name, ::Val{$(QuoteNode(node.addr))}) - return trace.$(get_value_fieldname(node)) - end - end - ) - end - - return [get_subtrace_exprs; call_getindex_exprs; choice_getindex_exprs] + + return [get_subtrace_exprs; call_getindex_exprs] end -function generate_static_get_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_get_value), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(choices.trace.$(get_value_fieldname(node)))))) - end - methods -end - -function generate_static_has_value(ir::StaticIR, trace_struct_name::Symbol) - methods = Expr[] - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_has_value), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(true)))) - end - methods -end - -function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol) +function generate_static_get_subtree(ir::StaticIR, trace_struct_name::Symbol) methods = Expr[] for node in ir.call_nodes push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_get_submap), + Expr(:call, GlobalRef(Gen, :static_get_subtree), :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), :(::Val{$(QuoteNode(node.addr))})), Expr(:block, :($(GlobalRef(Gen, :get_choices))(choices.trace.$(get_subtrace_fieldname(node))))))) end - # throw a KeyError if get_submap is run on an address containing a value - for node in ir.choice_nodes - push!(methods, Expr(:function, - Expr(:call, GlobalRef(Gen, :static_get_submap), - :(choices::$(QuoteNode(StaticIRTraceAssmt)){$trace_struct_name}), - :(::Val{$(QuoteNode(node.addr))})), - Expr(:block, :(throw(KeyError($(QuoteNode(node.addr)))))))) - end methods end function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol) - choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes] - call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes] - addrs = vcat(choice_addrs, call_addrs) + addrs = [QuoteNode(node.addr) for node in ir.call_nodes] Expr(:function, Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})), Expr(:block, @@ -289,20 +191,14 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) get_retval_expr = generate_get_retval(ir, trace_struct_name) - get_choices_expr = generate_get_choices(trace_struct_name) get_schema_expr = generate_get_schema(ir, trace_struct_name) - get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name) - get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name) - static_get_value_exprs = generate_static_get_value(ir, trace_struct_name) - static_has_value_exprs = generate_static_has_value(ir, trace_struct_name) - static_get_submap_exprs = generate_static_get_submap(ir, trace_struct_name) + get_submaps_shallow_expr = generate_get_subtrees_shallow(ir, trace_struct_name) + static_get_subtree_exprs = generate_static_get_subtree(ir, trace_struct_name) getindex_exprs = generate_getindex(ir, trace_struct_name) exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr, get_args_expr, get_retval_expr, - get_choices_expr, get_schema_expr, get_values_shallow_expr, - get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) + get_schema_expr, get_submaps_shallow_expr, static_get_subtree_exprs..., getindex_exprs...) (exprs, trace_struct_name) end diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 3a491c7a8..530370adb 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -1,7 +1,3 @@ -abstract type AbstractUpdateMode end -struct UpdateMode <: AbstractUpdateMode end -struct RegenerateMode <: AbstractUpdateMode end - const retdiff = gensym("retdiff") const discard = gensym("discard") @@ -9,7 +5,6 @@ const calldiff_prefix = gensym("calldiff") calldiff_var(node::GenerativeFunctionCallNode) = Symbol("$(calldiff_prefix)_$(node.addr)") const choice_discard_prefix = gensym("choice_discard") -choice_discard_var(node::RandomChoiceNode) = Symbol("$(choice_discard_prefix)_$(node.addr)") const call_discard_prefix = gensym("call_discard") call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_prefix)_$(node.addr)") @@ -19,21 +14,18 @@ call_discard_var(node::GenerativeFunctionCallNode) = Symbol("$(call_discard_pref ######################## struct ForwardPassState - input_changed::Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}} + input_changed::Set{GenerativeFunctionCallNode} value_changed::Set{StaticIRNode} - constrained_or_selected_choices::Set{RandomChoiceNode} constrained_or_selected_calls::Set{GenerativeFunctionCallNode} discard_calls::Set{GenerativeFunctionCallNode} end function ForwardPassState() - input_changed = Set{Union{RandomChoiceNode,GenerativeFunctionCallNode}}() + input_changed = Set{GenerativeFunctionCallNode}() value_changed = Set{StaticIRNode}() - constrained_or_selected_choices = Set{RandomChoiceNode}() constrained_or_selected_calls = Set{GenerativeFunctionCallNode}() discard_calls = Set{GenerativeFunctionCallNode}() - ForwardPassState(input_changed, value_changed, constrained_or_selected_choices, - constrained_or_selected_calls, discard_calls) + ForwardPassState(input_changed, value_changed, constrained_or_selected_calls, discard_calls) end function forward_pass_argdiff!(state::ForwardPassState, @@ -46,40 +38,56 @@ function forward_pass_argdiff!(state::ForwardPassState, end end -function process_forward!(::AddressSchema, ::ForwardPassState, ::TrainableParameterNode) end +function process_forward!(::Type{<:UpdateSpec}, ::Type{<:Selection}, ::ForwardPassState, ::TrainableParameterNode) end -function process_forward!(::AddressSchema, ::ForwardPassState, node::ArgumentNode) end +function process_forward!(::Type{<:UpdateSpec}, ::Type{<:Selection}, ::ForwardPassState, node::ArgumentNode) end -function process_forward!(::AddressSchema, state::ForwardPassState, node::JuliaNode) +function process_forward!(::Type{<:UpdateSpec}, ::Type{<:Selection}, state::ForwardPassState, node::JuliaNode) if any(input_node in state.value_changed for input_node in node.inputs) push!(state.value_changed, node) end end -function process_forward!(schema::AddressSchema, state::ForwardPassState, - node::RandomChoiceNode) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) - push!(state.constrained_or_selected_choices, node) - push!(state.value_changed, node) - end - if any(input_node in state.value_changed for input_node in node.inputs) - push!(state.input_changed, node) - end +function cannot_statically_guarantee_nochange_retdiff(spec_type, externally_constrained_addr_type, node, state) + trace_type = get_trace_type(node.generative_function) + argdiff_types = map(input_node -> input_node in state.value_changed ? UnknownChange : NoChange, node.inputs) + argdiff_type = Tuple{argdiff_types...} + # TODO: can we know the arg type statically? + + # TODO: is there a way to do this using `get_subtree` with constant propagation? + # if so, we might be able to avoid always casting to static address trees + subspec_type = Core.Compiler.return_type(static_get_subtree, Tuple{spec_type, Val{node.addr}}) + subext_const_addr_type = Core.Compiler.return_type(static_get_subtree, Tuple{externally_constrained_addr_type, Val{node.addr}}) + + update_rettype = Core.Compiler.return_type( + Gen.update, + Tuple{trace_type, Tuple, argdiff_type, subspec_type, subext_const_addr_type} + ) + has_static_retdiff = update_rettype <: Tuple && update_rettype != Union{} && length(update_rettype.parameters) >= 3 + guaranteed_returns_nochange = has_static_retdiff && update_rettype.parameters[3] == NoChange + + return !guaranteed_returns_nochange end -function process_forward!(schema::AddressSchema, state::ForwardPassState, +function process_forward!(spec_type::Type{<:UpdateSpec}, externally_constrained_addrs_type::Type{<:Selection}, + state::ForwardPassState, node::GenerativeFunctionCallNode) - @assert isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema) - if isa(schema, AllAddressSchema) || (isa(schema, StaticAddressSchema) && (node.addr in keys(schema))) + schema = get_address_schema(spec_type) + will_run_update = false + @assert isa(schema, StaticSchema) + if isa(schema, AllAddressSchema) || (!isa(schema, EmptyAddressSchema) && node.addr in keys(schema)) push!(state.constrained_or_selected_calls, node) - push!(state.value_changed, node) - push!(state.discard_calls, node) + will_run_update = true end if any(input_node in state.value_changed for input_node in node.inputs) push!(state.input_changed, node) - push!(state.value_changed, node) # TODO can check whether the node is satically absorbing + will_run_update = true + end + if will_run_update push!(state.discard_calls, node) + if cannot_statically_guarantee_nochange_retdiff(spec_type, externally_constrained_addrs_type, node, state) + push!(state.value_changed, node) + end end end @@ -113,15 +121,6 @@ function process_backward!(fwd::ForwardPassState, back::BackwardPassState, end end -function process_backward!(fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, options) - if node in fwd.input_changed || node in fwd.constrained_or_selected_choices - for input_node in node.inputs - push!(back.marked, input_node) - end - end -end - function process_backward!(fwd::ForwardPassState, back::BackwardPassState, node::GenerativeFunctionCallNode, options) if node in fwd.input_changed || node in fwd.constrained_or_selected_calls @@ -142,14 +141,14 @@ function arg_values_and_diffs_from_tracked_diffs(input_nodes) end function process_codegen!(stmts, ::ForwardPassState, back::BackwardPassState, - node::TrainableParameterNode, ::AbstractUpdateMode, options) + node::TrainableParameterNode, options) if node in back.marked push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace), $(QuoteNode(node.name))))) end end function process_codegen!(stmts, ::ForwardPassState, ::BackwardPassState, - node::ArgumentNode, ::AbstractUpdateMode, options) + node::ArgumentNode, options) if options.track_diffs push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) else @@ -158,26 +157,28 @@ function process_codegen!(stmts, ::ForwardPassState, ::BackwardPassState, end function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::JuliaNode, ::AbstractUpdateMode, options) - run_it = ((options.cache_julia_nodes && node in fwd.value_changed) || - (!options.cache_julia_nodes && node in back.marked)) + node::JuliaNode, options) if options.track_diffs - - # track diffs - if run_it - arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) - args = map((v, d) -> Expr(:call, (GlobalRef(Gen, :Diffed)), v, d), arg_values, arg_diffs) + arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) + args = map((v, d) -> Expr(:call, (GlobalRef(Gen, :Diffed)), v, d), arg_values, arg_diffs) + if !options.cache_julia_nodes && (node in back.marked) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) - elseif options.cache_julia_nodes + elseif node in fwd.value_changed + push!(stmts, (quote + if !($(Expr(:call, GlobalRef(Gen, :all_nochange), Expr(:tuple, arg_diffs...)))) + $(node.name) = $(QuoteNode(node.fn))($(args...)) + else + $(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()) + end + end).args[2]) + else push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()))) end if options.cache_julia_nodes push!(stmts, :($(get_value_fieldname(node)) = $(GlobalRef(Gen, :strip_diff))($(node.name)))) end else - - # no track diffs - if run_it + if (!options.cache_julia_nodes && node in back.marked) || (node in fwd.value_changed) arg_values = map((n) -> n.name, node.inputs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(arg_values...)))) elseif options.cache_julia_nodes @@ -190,120 +191,7 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, end function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::UpdateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)), $(GlobalRef(Gen, :UnknownChange))()))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - end - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(GlobalRef(Gen, :strip_diff))($(node.name)), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), $(GlobalRef(Gen, :NoChange))()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :static_get_value))(constraints, Val($addr)))) - push!(stmts, :($(choice_discard_var(node)) = trace.$(get_value_fieldname(node)))) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - end - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::RandomChoiceNode, ::RegenerateMode, - options) - if options.track_diffs - - # track diffs - arg_values, _ = arg_values_and_diffs_from_tracked_diffs(node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - output_value = Expr(:call, (GlobalRef(Gen, :strip_diff)), node.name) - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :random))($dist, $(arg_values...)), UnknownChange()))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $output_value, $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $output_value, $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))(trace.$(get_value_fieldname(node)), NoChange()))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - else - - # no track diffs - arg_values = map((n) -> n.name, node.inputs) - new_logpdf = gensym("new_logpdf") - addr = QuoteNode(node.addr) - dist = QuoteNode(node.dist) - if node in fwd.constrained_or_selected_choices || node in fwd.input_changed - if node in fwd.constrained_or_selected_choices - # the choice was selected, it does not contribute to the weight - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(arg_values...)))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - else - # the choice was not selected, and the input to the choice changed - # it does contribute to the weight - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($new_logpdf = $(GlobalRef(Gen, :logpdf))($dist, $(node.name), $(arg_values...)))) - push!(stmts, :($weight += $new_logpdf - trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($total_score_fieldname += $new_logpdf - trace.$(get_score_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = $new_logpdf)) - else - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - push!(stmts, :($(get_score_fieldname(node)) = trace.$(get_score_fieldname(node)))) - end - push!(stmts, :($(get_value_fieldname(node)) = $(node.name))) - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::GenerativeFunctionCallNode, ::UpdateMode, - options) + node::GenerativeFunctionCallNode, options) if options.track_diffs arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) else @@ -315,15 +203,17 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, subtrace = get_subtrace_fieldname(node) prev_subtrace = :(trace.$subtrace) call_weight = gensym("call_weight") - call_constraints = gensym("call_constraints") + call_spec = gensym("call_spec") + ext_const_addrs = gensym("ext_const_addrs") if node in fwd.constrained_or_selected_calls || node in fwd.input_changed if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_constraints = $(GlobalRef(Gen, :static_get_submap))(constraints, Val($addr)))) + push!(stmts, :($call_spec = $(GlobalRef(Gen, :static_get_subtree))(spec, Val($addr)))) else - push!(stmts, :($call_constraints = $(GlobalRef(Gen, :EmptyChoiceMap))())) + push!(stmts, :($call_spec = $(GlobalRef(Gen, :EmptyAddressTree))())) end + push!(stmts, :($ext_const_addrs = $(GlobalRef(Gen, :get_subtree))(externally_constrained_addrs, $addr))) push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node)), $(call_discard_var(node))) = - $(GlobalRef(Gen, :update))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_constraints))) + $(GlobalRef(Gen, :update))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_spec, $ext_const_addrs))) push!(stmts, :($weight += $call_weight)) push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) @@ -336,51 +226,6 @@ function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, else push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) end - else - push!(stmts, :($subtrace = $prev_subtrace)) - if options.track_diffs - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(QuoteNode(NoChange()))))) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) - end - end -end - -function process_codegen!(stmts, fwd::ForwardPassState, back::BackwardPassState, - node::GenerativeFunctionCallNode, ::RegenerateMode, - options) - if options.track_diffs - arg_values, arg_diffs = arg_values_and_diffs_from_tracked_diffs(node.inputs) - else - arg_values = map((n) -> n.name, node.inputs) - arg_diffs = map((n) -> QuoteNode(n in fwd.value_changed ? UnknownChange() : NoChange()), node.inputs) - end - addr = QuoteNode(node.addr) - gen_fn = QuoteNode(node.generative_function) - subtrace = get_subtrace_fieldname(node) - prev_subtrace = :(trace.$subtrace) - call_weight = gensym("call_weight") - call_subselection = gensym("call_subselection") - if node in fwd.constrained_or_selected_calls || node in fwd.input_changed - if node in fwd.constrained_or_selected_calls - push!(stmts, :($call_subselection = $(GlobalRef(Gen, :static_getindex))(selection, Val($addr)))) - else - push!(stmts, :($call_subselection = $(GlobalRef(Gen, :EmptySelection))())) - end - push!(stmts, :(($subtrace, $call_weight, $(calldiff_var(node))) = - $(GlobalRef(Gen, :regenerate))($prev_subtrace, $(Expr(:tuple, arg_values...)), $(Expr(:tuple, arg_diffs...)), $call_subselection))) - push!(stmts, :($weight += $call_weight)) - push!(stmts, :($total_score_fieldname += $(GlobalRef(Gen, :get_score))($subtrace) - $(GlobalRef(Gen, :get_score))($prev_subtrace))) - push!(stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()) - $(GlobalRef(Gen, :project))($prev_subtrace, $(GlobalRef(Gen, :EmptySelection))()))) - push!(stmts, :(if !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) - $num_nonempty_fieldname += 1 end)) - push!(stmts, :(if $(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($subtrace)) && !$(GlobalRef(Gen, :isempty))($(GlobalRef(Gen, :get_choices))($prev_subtrace)) - $num_nonempty_fieldname -= 1 end)) - if options.track_diffs - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :Diffed))($(GlobalRef(Gen, :get_retval))($subtrace), $(calldiff_var(node))))) - else - push!(stmts, :($(node.name) = $(GlobalRef(Gen, :get_retval))($subtrace))) - end else push!(stmts, :($subtrace = $prev_subtrace)) if options.track_diffs @@ -431,32 +276,20 @@ function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) end end -function generate_discard!(stmts::Vector{Expr}, - constrained_choices::Set{RandomChoiceNode}, - discard_calls::Set{GenerativeFunctionCallNode}) - discard_leaf_nodes = Dict{Symbol,Symbol}() - for node in constrained_choices - discard_leaf_nodes[node.addr] = choice_discard_var(node) - end - discard_internal_nodes = Dict{Symbol,Symbol}() +function generate_discard!(stmts::Vector{Expr}, discard_calls::Set{GenerativeFunctionCallNode}) + discard_nodes = Dict{Symbol,Symbol}() for node in discard_calls - discard_internal_nodes[node.addr] = call_discard_var(node) - end - if length(discard_leaf_nodes) > 0 - (leaf_keys, leaf_nodes) = collect(zip(discard_leaf_nodes...)) - else - (leaf_keys, leaf_nodes) = ((), ()) + discard_nodes[node.addr] = call_discard_var(node) end - if length(discard_internal_nodes) > 0 - (internal_keys, internal_nodes) = collect(zip(discard_internal_nodes...)) + + if length(discard_nodes) > 0 + (keys, nodes) = collect(zip(discard_nodes...)) else - (internal_keys, internal_nodes) = ((), ()) + (keys, nodes) = ((), ()) end - leaf_keys = map((key::Symbol) -> QuoteNode(key), leaf_keys) - internal_keys = map((key::Symbol) -> QuoteNode(key), internal_keys) - expr = :($(QuoteNode(StaticChoiceMap))( - $(QuoteNode(NamedTuple)){($(leaf_keys...),)}(($(leaf_nodes...),)), - $(QuoteNode(NamedTuple)){($(internal_keys...),)}(($(internal_nodes...),)))) + keys = map((key::Symbol) -> QuoteNode(key), keys) + expr = quote $(QuoteNode(StaticChoiceMap))( + $(QuoteNode(NamedTuple)){($(keys...),)}(($(nodes...),))) end push!(stmts, :($discard = $expr)) end @@ -465,13 +298,17 @@ end ####################### function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Type, - constraints_type::Type) where {T<:StaticIRTrace} + spec_type::Type, externally_constrained_addrs_type::Type) where {T<:StaticIRTrace} gen_fn_type = get_gen_fn_type(trace_type) - schema = get_address_schema(constraints_type) + spec_schema = get_address_schema(spec_type) + ext_const_addrs_schema = get_address_schema(externally_constrained_addrs_type) - # convert the constraints to a static assignment if it is not already one - if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema)) - return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end + spec_is_static = isa(spec_schema, StaticSchema) + ext_const_addrs_is_static = isa(ext_const_addrs_schema, StaticSchema) + + # convert the spec and ext_const_addrs to static if they are not already + if !(spec_is_static && ext_const_addrs_is_static) + return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticAddressTree))(spec), $(QuoteNode(StaticAddressTree))(externally_constrained_addrs)) end end ir = get_ir(gen_fn_type) @@ -481,7 +318,7 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ fwd_state = ForwardPassState() forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) for node in ir.nodes - process_forward!(schema, fwd_state, node) + process_forward!(spec_type, externally_constrained_addrs_type, fwd_state, node) end # backward marking pass @@ -500,11 +337,11 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ initialize_score_weight_num_nonempty!(stmts) unpack_arguments!(stmts, ir.arg_nodes, options) for node in ir.nodes - process_codegen!(stmts, fwd_state, bwd_state, node, UpdateMode(), options) + process_codegen!(stmts, fwd_state, bwd_state, node, options) end generate_return_value!(stmts, fwd_state, ir.return_node, options) generate_new_trace!(stmts, trace_type, options) - generate_discard!(stmts, fwd_state.constrained_or_selected_choices, fwd_state.discard_calls) + generate_discard!(stmts, fwd_state.discard_calls) # return trace and weight and discard and retdiff push!(stmts, :(return ($trace, $weight, $retdiff, $discard))) @@ -512,61 +349,11 @@ function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Typ Expr(:block, stmts...) end -function codegen_regenerate(trace_type::Type{T}, args_type::Type, argdiffs_type::Type, - selection_type::Type) where {T<:StaticIRTrace} - gen_fn_type = get_gen_fn_type(trace_type) - schema = get_address_schema(selection_type) - - # convert a hierarchical selection to a static selection if it is not alreay one - if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema) || isa(schema, AllAddressSchema)) - return quote $(GlobalRef(Gen, :regenerate))(trace, args, argdiffs, $(QuoteNode(StaticSelection))(selection)) end - end - - ir = get_ir(gen_fn_type) - options = get_options(gen_fn_type) - - # forward marking pass - fwd_state = ForwardPassState() - forward_pass_argdiff!(fwd_state, ir.arg_nodes, argdiffs_type) - for node in ir.nodes - process_forward!(schema, fwd_state, node) - end - - # backward marking pass - bwd_state = BackwardPassState() - push!(bwd_state.marked, ir.return_node) - for node in reverse(ir.nodes) - process_backward!(fwd_state, bwd_state, node, options) - end - - # forward code generation pass - stmts = Expr[] - initialize_score_weight_num_nonempty!(stmts) - unpack_arguments!(stmts, ir.arg_nodes, options) - for node in ir.nodes - process_codegen!(stmts, fwd_state, bwd_state, node, RegenerateMode(), options) - end - generate_return_value!(stmts, fwd_state ,ir.return_node, options) - generate_new_trace!(stmts, trace_type, options) - - # return trace and weight and retdiff - push!(stmts, :(return ($trace, $weight, $retdiff))) - - Expr(:block, stmts...) -end - let T = gensym() push!(generated_functions, quote @generated function $(GlobalRef(Gen, :update))(trace::$T, args::Tuple, argdiffs::Tuple, - constraints::$(QuoteNode(ChoiceMap))) where {$T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_update))(trace, args, argdiffs, constraints) - end - end) - - push!(generated_functions, quote - @generated function $(GlobalRef(Gen, :regenerate))(trace::$T, args::Tuple, argdiffs::Tuple, - selection::$(QuoteNode(Selection))) where {$T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_regenerate))(trace, args, argdiffs, selection) + spec::$(QuoteNode(UpdateSpec)), externally_constrained_addrs::$(QuoteNode(Selection))) where {$T<:$(QuoteNode(StaticIRTrace))} + $(QuoteNode(codegen_update))(trace, args, argdiffs, spec, externally_constrained_addrs) end end) end diff --git a/src/trie.jl b/src/trie.jl index 0d1c2a8a2..345a3ec1c 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -2,7 +2,7 @@ # Trie # ################## -struct Trie{K,V} <: ChoiceMap +struct Trie{K,V} <: AddressTree{Value} leaf_nodes::Dict{K,V} internal_nodes::Dict{K,Trie{K,V}} end diff --git a/test/assignment.jl b/test/assignment.jl index 1bba754af..7983ea282 100644 --- a/test/assignment.jl +++ b/test/assignment.jl @@ -1,6 +1,63 @@ +@testset "Value" begin + vcm1 = Value(2) + vcm2 = Value(2.) + vcm3 = Value([1,2]) + @test vcm1 isa Value{Int} + @test vcm2 isa Value{Float64} + @test vcm3 isa Value{Vector{Int}} + @test vcm1[] == 2 + @test vcm1[] == get_value(vcm1) + + @test !isempty(vcm1) + @test has_value(vcm1) + @test get_value(vcm1) == 2 + @test vcm1 == vcm2 + @test isempty(get_submaps_shallow(vcm1)) + @test isempty(get_values_shallow(vcm1)) + @test isempty(get_nonvalue_submaps_shallow(vcm1)) + @test to_array(vcm1, Int) == [2] + @test from_array(vcm1, [4]) == Value(4) + @test from_array(vcm3, [4, 5]) == Value([4, 5]) + @test_throws Exception merge(vcm1, vcm2) + @test_throws Exception merge(vcm1, choicemap(:a, 5)) + @test merge(vcm1, EmptyChoiceMap()) == vcm1 + @test merge(EmptyChoiceMap(), vcm1) == vcm1 + @test get_submap(vcm1, :addr) == EmptyChoiceMap() + @test_throws ChoiceMapGetValueError get_value(vcm1, :addr) + @test !has_value(vcm1, :addr) + @test isapprox(vcm2, Value(prevfloat(2.))) + @test isapprox(vcm1, Value(prevfloat(2.))) + @test nested_view(vcm1) == 2 +end + +@testset "static choicemap constructors" begin + @test StaticChoiceMap((a=Value(5), b=Value(6))) == StaticChoiceMap(a=5, b=6) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + @test submap == StaticChoiceMap((a=Value(1.), b=Value([2., 2.5]))) + outer = StaticChoiceMap(c=3, d=submap, e=submap) + @test outer == StaticChoiceMap((c=Value(3), d=submap, e=submap)) + + # what if we have an emptychoicemap? + StaticChoiceMap(a=Value(5), b=choicemap((1, "hello")), c=EmptyChoiceMap()) + + # convert dynamic --> static + d1 = choicemap((:a, 1), (:b => :c, 2.), (:d, "yes")) + s1 = StaticChoiceMap(d1) + @test s1 == d1 + # should not be a deep conversion! + @test get_subtree(s1, :b) === get_subtree(d1, :b) + + d2 = choicemap((:a, 3), (:b, -10.2)) + d3 = choicemap((1 => :a, 80), (2 => :a, 100), (3 => :a, 2)) + set_submap!(d2, :c, d3) + s2 = StaticChoiceMap(d2) + @test s2 == d2 + @test get_subtree(s2, :c) == d3 +end + @testset "static assignment to/from array" begin - submap = StaticChoiceMap((a=1., b=[2., 2.5]),NamedTuple()) - outer = StaticChoiceMap((c=3.,), (d=submap, e=submap)) + submap = StaticChoiceMap(a=1., b=[2., 2.5]) + outer = StaticChoiceMap(c=3., d=submap, e=submap) arr = to_array(outer, Float64) @test to_array(outer, Float64) == Float64[3.0, 1.0, 2.0, 2.5, 1.0, 2.0, 2.5] @@ -11,14 +68,16 @@ @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment to/from array" begin @@ -39,14 +98,18 @@ end @test choices[:d => :b] == [3.0, 4.0] @test choices[:e => :a] == 5.0 @test choices[:e => :b] == [6.0, 7.0] - @test length(collect(get_submaps_shallow(choices))) == 2 + @test get_submap(choices, :c) == Value(1.0) + @test get_submap(choices, :d => :b) == Value([3.0, 4.0]) + @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 2 @test length(collect(get_values_shallow(choices))) == 1 submap1 = get_submap(choices, :d) @test length(collect(get_values_shallow(submap1))) == 2 - @test length(collect(get_submaps_shallow(submap1))) == 0 + @test length(collect(get_submaps_shallow(submap1))) == 2 + @test length(collect(get_nonvalue_submaps_shallow(submap1))) == 0 submap2 = get_submap(choices, :e) @test length(collect(get_values_shallow(submap2))) == 2 - @test length(collect(get_submaps_shallow(submap2))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(submap2))) == 0 end @testset "dynamic assignment copy constructor" begin @@ -64,25 +127,6 @@ end @test choices[:u => :w] == 4 end -@testset "internal vector assignment to/from array" begin - inner = choicemap() - set_value!(inner, :a, 1.) - set_value!(inner, :b, 2.) - outer = vectorize_internal([inner, inner, inner]) - - arr = to_array(outer, Float64) - @test to_array(outer, Float64) == Float64[1, 2, 1, 2, 1, 2] - - choices = from_array(outer, Float64[1, 2, 3, 4, 5, 6]) - @test choices[1 => :a] == 1.0 - @test choices[1 => :b] == 2.0 - @test choices[2 => :a] == 3.0 - @test choices[2 => :b] == 4.0 - @test choices[3 => :a] == 5.0 - @test choices[3 => :b] == 6.0 - @test length(collect(get_submaps_shallow(choices))) == 3 -end - @testset "dynamic assignment merge" begin submap = choicemap() set_value!(submap, :x, 1) @@ -107,7 +151,7 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @@ -125,8 +169,8 @@ end set_value!(submap, :x, 1) submap2 = choicemap() set_value!(submap2, :y, 4.) - choices1 = StaticChoiceMap((a=1., b=2.), (c=submap, shared=submap)) - choices2 = StaticChoiceMap((d=3.,), (e=submap, f=submap, shared=submap2)) + choices1 = StaticChoiceMap(a=1., b=2., c=submap, shared=submap) + choices2 = StaticChoiceMap(d=3., e=submap, f=submap, shared=submap2) choices = merge(choices1, choices2) @test choices[:a] == 1. @test choices[:b] == 2. @@ -136,124 +180,91 @@ end @test choices[:f => :x] == 1 @test choices[:shared => :x] == 1 @test choices[:shared => :y] == 4. - @test length(collect(get_submaps_shallow(choices))) == 4 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 4 @test length(collect(get_values_shallow(choices))) == 3 end @testset "static assignment variadic merge" begin - choices1 = StaticChoiceMap((a=1,), NamedTuple()) - choices2 = StaticChoiceMap((b=2,), NamedTuple()) - choices3 = StaticChoiceMap((c=3,), NamedTuple()) - choices_all = StaticChoiceMap((a=1, b=2, c=3), NamedTuple()) + choices1 = StaticChoiceMap(a=1) + choices2 = StaticChoiceMap(b=2) + choices3 = StaticChoiceMap(c=3) + choices_all = StaticChoiceMap(a=1, b=2, c=3) @test merge(choices1) == choices1 @test merge(choices1, choices2, choices3) == choices_all end +# TODO: in changing a lot of these to reflect the new behavior of choicemap, +# they are mostly not error checks, but instead checks for returning `EmptyChoiceMap`; +# should we relabel this testset? @testset "static assignment errors" begin + # get_choices on an address that returns a Value + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x) == Value(1) - # get_choices on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw - - # static_get_submap on an address that contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_choices on an address whose prefix contains a value throws a KeyError - choices = StaticChoiceMap((x=1,), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_submap on an address that contains a value returns a Value + choices = StaticChoiceMap(x=1) + @test static_get_submap(choices, Val(:x)) == Value(1) + + # get_submap on an address whose prefix contains a value returns EmptyChoiceMap + choices = StaticChoiceMap(x=1) + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) + choices = StaticChoiceMap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # static_get_choices on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_submap(choices, Val(:x)) catch KeyError threw = true end - @test threw + # static_get_choices on an address that contains nothing returns an EmptyChoiceMap + choices = StaticChoiceMap() + @test static_get_submap(choices, Val(:x)) == EmptyChoiceMap() - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # static_get_value on an address that contains a submap throws a KeyError + # static_get_value on an address that contains a submap throws a ChoiceMapGetValueError submap = choicemap() submap[:y] = 1 - choices = StaticChoiceMap(NamedTuple(), (x=submap,)) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw - - # get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw - - # static_get_value on an address that contains nothing throws a KeyError - choices = StaticChoiceMap(NamedTuple(), NamedTuple()) - threw = false - try static_get_value(choices, Val(:x)) catch KeyError threw = true end - @test threw + choices = StaticChoiceMap(x=submap) + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) + + # get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) + + # static_get_value on an address that contains nothing throws a ChoiceMapGetValueError + choices = StaticChoiceMap() + @test_throws ChoiceMapGetValueError static_get_value(choices, Val(:x)) end @testset "dynamic assignment errors" begin - - # get_choices on an address that contains a value throws a KeyError + # get_choices on an address that contains a value returns a Value choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x) == Value(1) - # get_choices on an address whose prefix contains a value throws a KeyError + # get_choices on an address whose prefix contains a value returns EmptyChoiceMap choices = choicemap() choices[:x] = 1 - threw = false - try get_submap(choices, :x => :y) catch KeyError threw = true end - @test threw + @test get_submap(choices, :x => :y) == EmptyChoiceMap() # get_choices on an address that contains nothing gives empty assignment choices = choicemap() @test isempty(get_submap(choices, :x)) @test isempty(get_submap(choices, :x => :y)) - # get_value on an address that contains a submap throws a KeyError + # get_value on an address that contains a submap throws a ChoiceMapGetValueError choices = choicemap() choices[:x => :y] = 1 - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) - # get_value on an address that contains nothing throws a KeyError + # get_value on an address that contains nothing throws a ChoiceMapGetValueError choices = choicemap() - threw = false - try get_value(choices, :x) catch KeyError threw = true end - @test threw - threw = false - try get_value(choices, :x => :y) catch KeyError threw = true end - @test threw + @test_throws ChoiceMapGetValueError get_value(choices, :x) + @test_throws ChoiceMapGetValueError get_value(choices, :x => :y) end @testset "dynamic assignment overwrite" begin @@ -261,7 +272,7 @@ end # overwrite value with a value choices = choicemap() choices[:x] = 1 - choices[:x] = 2 + set_subtree!(choices, :x, Value(2)) @test choices[:x] == 2 # overwrite value with a submap @@ -275,10 +286,8 @@ end # overwrite subassignment with a value choices = choicemap() choices[:x => :y] = 1 - choices[:x] = 2 - threw = false - try get_submap(choices, :x) catch KeyError threw = true end - @test threw + set_subtree!(choices, :x, Value(2)) + @test get_submap(choices, :x) == Value(2) @test choices[:x] == 2 # overwrite subassignment with a subassignment @@ -293,17 +302,13 @@ end # illegal set value under existing value choices = choicemap() choices[:x] = 1 - threw = false - try set_value!(choices, :x => :y, 2) catch KeyError threw = true end - @test threw + @test_throws Exception set_value!(choices, :x => :y, 2) # illegal set submap under existing value choices = choicemap() choices[:x] = 1 submap = choicemap(); choices[:z] = 2 - threw = false - try set_submap!(choices, :x => :y, submap) catch KeyError threw = true end - @test threw + @test_throws Exception set_submap!(choices, :x => :y, submap) end @testset "dynamic assignment constructor" begin @@ -350,7 +355,6 @@ end end @testset "filtering choicemaps with selections" begin - c = choicemap((:a, 1), (:b, 2)) filtered = get_selected(c, select(:a)) @@ -367,3 +371,14 @@ end @test filtered[:x => :y] == 1 @test !has_value(filtered, :x => :z) end + +@testset "underlying choicemap" begin + t = DynamicAddressTree{Union{Value, SelectionLeaf}}() + set_subtree!(t, :a, AllSelection()) + set_subtree!(t, :b => :c, invert(select(:d, :e => :f))) + set_subtree!(t, :c => :a, AllSelection()) + set_subtree!(t, :c => :b, Value(5)) + set_subtree!(t, :c => :c, Value(6)) + set_subtree!(t, :d, Value(7)) + @test UnderlyingChoices(t) == choicemap((:c => :b, 5), (:c => :c, 6), (:d, 7)) +end \ No newline at end of file diff --git a/test/benchmarks/dynamic_choicemap_benchmark.jl b/test/benchmarks/dynamic_choicemap_benchmark.jl new file mode 100644 index 000000000..3724e44de --- /dev/null +++ b/test/benchmarks/dynamic_choicemap_benchmark.jl @@ -0,0 +1,27 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +cm = choicemap((:a, 1), (:b => :c, 2)) + +println("dynamic choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(cm) +end + +println("dynamic choicemap nested lookup:") +for _=1:4 + @time many_nested(cm) +end \ No newline at end of file diff --git a/test/benchmarks/dynamic_mh.jl b/test/benchmarks/dynamic_mh.jl new file mode 100644 index 000000000..a31223f94 --- /dev/null +++ b/test/benchmarks/dynamic_mh.jl @@ -0,0 +1,77 @@ +module DynamicMHBenchmark +using Gen +import Random + +include("../../examples/regression/dynamic_model.jl") +include("../../examples/regression/dataset.jl") + +@gen function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen function is_outlier_proposal(trace, i::Int) + prev = trace[:data => i => :z] + @trace(bernoulli(prev ? 0.0 : 1.0), :data => i => :z) +end + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +println("Simple dynamic DSL MH on regression model:") +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +@time do_inference(xs, ys, 20) +@time do_inference(xs, ys, 20) +println() + +end diff --git a/test/benchmarks/run_benchmarks.jl b/test/benchmarks/run_benchmarks.jl new file mode 100644 index 000000000..66e73978d --- /dev/null +++ b/test/benchmarks/run_benchmarks.jl @@ -0,0 +1,5 @@ +include("static_mh.jl") +include("dynamic_mh.jl") +# Note that there are some other benchmarks in the folder +# for things like choicemap lookup speeds +# I am not running here since they are somewhat special-case \ No newline at end of file diff --git a/test/benchmarks/static_choicemap_benchmark.jl b/test/benchmarks/static_choicemap_benchmark.jl new file mode 100644 index 000000000..1e62b9a8e --- /dev/null +++ b/test/benchmarks/static_choicemap_benchmark.jl @@ -0,0 +1,50 @@ +using Gen + +function many_shallow(cm::ChoiceMap) + for _=1:10^5 + cm[:a] + end +end +function many_nested(cm::ChoiceMap) + for _=1:10^5 + cm[:b => :c] + end +end + +# many_shallow(cm) = perform_many_lookups(cm, :a) +# many_nested(cm) = perform_many_lookups(cm, :b => :c) + +scm = StaticChoiceMap(a=1, b=StaticChoiceMap(c=2)) + +println("static choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(scm) +end + +println("static choicemap nested lookup:") +for _=1:4 + @time many_nested(scm) +end + +@gen (static) function inner() + c ~ normal(0, 1) +end +@gen (static) function outer() + a ~ normal(0, 1) + b ~ inner() +end + +load_generated_functions() + +tr, _ = generate(outer, ()) +choices = get_choices(tr) + +println("static gen function choicemap nonnested lookup:") +for _=1:4 + @time many_shallow(choices) +end + +println("static gen function choicemap nested lookup:") +for _=1:4 + @time many_nested(choices) +end diff --git a/test/benchmarks/static_inference_benchmark.jl b/test/benchmarks/static_inference_benchmark.jl new file mode 100644 index 000000000..b70d08be2 --- /dev/null +++ b/test/benchmarks/static_inference_benchmark.jl @@ -0,0 +1,23 @@ +using Gen + +@gen (static, diffs) function foo() + a ~ normal(0, 1) + b ~ normal(a, 1) + c ~ normal(b, 1) +end + +@load_generated_functions + +observations = StaticChoiceMap(choicemap((:b,2), (:c,1.5))) +tr, _ = generate(foo, (), observations) + +function run_inference(trace) + tr = trace + for _=1:10^3 + tr, acc = mh(tr, select(:a)) + end +end + +for _=1:4 + @time run_inference(tr) +end \ No newline at end of file diff --git a/test/benchmarks/static_mh.jl b/test/benchmarks/static_mh.jl new file mode 100644 index 000000000..9fa6e8e76 --- /dev/null +++ b/test/benchmarks/static_mh.jl @@ -0,0 +1,91 @@ +module StaticMHBenchmark +using Gen +import Random +using Profile +using ProfileView + +include("../../examples/regression/static_model.jl") +include("../../examples/regression/dataset.jl") + +@gen (static) function slope_proposal(trace) + slope = trace[:slope] + @trace(normal(slope, 0.5), :slope) +end + +@gen (static) function intercept_proposal(trace) + intercept = trace[:intercept] + @trace(normal(intercept, 0.5), :intercept) +end + +@gen (static) function inlier_std_proposal(trace) + log_inlier_std = trace[:log_inlier_std] + @trace(normal(log_inlier_std, 0.5), :log_inlier_std) +end + +@gen (static) function outlier_std_proposal(trace) + log_outlier_std = trace[:log_outlier_std] + @trace(normal(log_outlier_std, 0.5), :log_outlier_std) +end + +@gen (static) function flip_z(z::Bool) + @trace(bernoulli(z ? 0.0 : 1.0), :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +@gen (static) function is_outlier_proposal(trace, i::Int) + prev_z = trace[:data => i => :z] + @trace(bernoulli(prev_z ? 0.0 : 1.0), :data => i => :z) +end + +Gen.load_generated_functions() + +function do_inference(xs, ys, num_iters) + observations = choicemap() + for (i, y) in enumerate(ys) + observations[:data => i => :y] = y + end + + # initial trace + (trace, _) = generate(model, (xs,), observations) + + scores = Vector{Float64}(undef, num_iters) + for i=1:num_iters + + # steps on the parameters + for j=1:5 + (trace, _) = metropolis_hastings(trace, slope_proposal, ()) + (trace, _) = metropolis_hastings(trace, intercept_proposal, ()) + (trace, _) = metropolis_hastings(trace, inlier_std_proposal, ()) + (trace, _) = metropolis_hastings(trace, outlier_std_proposal, ()) + end + + # step on the outliers + for j=1:length(xs) + (trace, _) = metropolis_hastings(trace, is_outlier_proposal, (j,)) + end + + score = get_score(trace) + scores[i] = score + + # print + slope = trace[:slope] + intercept = trace[:intercept] + inlier_std = exp(trace[:log_inlier_std]) + outlier_std = exp(trace[:log_outlier_std]) + end + return scores +end + +(xs, ys) = make_data_set(200) +do_inference(xs, ys, 10) +println("Simple static DSL (including CallAt nodes) MH on regression model:") +# Profile.clear() +@time do_inference(xs, ys, 50) +@time do_inference(xs, ys, 50) +# ProfileView.view() +println() +end diff --git a/test/diff.jl b/test/diff.jl index 039bacebd..0e3e2e219 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -96,3 +96,19 @@ end @testset "diff sets" begin end + +@testset "diff properties" begin + struct Foo51 + x::Int + end + @test Diffed(Foo51(1), NoChange()).x == Diffed(1, NoChange()) + @test Diffed(Foo51(1), UnknownChange()).x == Diffed(1, UnknownChange()) +end + +@testset "map diff" begin + list = [1, 2, 3, 4, 5] + @test map(x -> 2*x, Diffed(list, NoChange())) == Diffed([2, 4, 6, 8, 10], NoChange()) + @test map(x -> 2*x, Diffed(list, UnknownChange())) == Diffed([2, 4, 6, 8, 10], UnknownChange()) + + # TODO: test propagating VectorDiffs +end \ No newline at end of file diff --git a/test/dynamic_dsl.jl b/test/dynamic_dsl.jl index e8d2526df..536f81a70 100644 --- a/test/dynamic_dsl.jl +++ b/test/dynamic_dsl.jl @@ -1,4 +1,4 @@ -using Gen: AddressVisitor, all_visited, visit!, get_visited +using Gen: AddressVisitor, all_constraints_visited, visit!, get_visited struct DummyReturnType end @@ -49,20 +49,20 @@ end visitor = AddressVisitor() visit!(visitor, :x) visit!(visitor, :y => :z) - @test all_visited(get_visited(visitor), choices) + @test all_constraints_visited(get_visited(visitor), choices) visitor = AddressVisitor() visit!(visitor, :x) - @test !all_visited(get_visited(visitor), choices) + @test !all_constraints_visited(get_visited(visitor), choices) visitor = AddressVisitor() visit!(visitor, :y => :z) - @test !all_visited(get_visited(visitor), choices) + @test !all_constraints_visited(get_visited(visitor), choices) visitor = AddressVisitor() visit!(visitor, :x) visit!(visitor, :y) - @test all_visited(get_visited(visitor), choices) + @test all_constraints_visited(get_visited(visitor), choices) end @testset "simulate" begin @@ -119,7 +119,7 @@ end @test get_value(discard, :x) == x @test get_value(discard, :u => :a) == a @test length(collect(get_values_shallow(discard))) == 2 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 # test new trace new_assignment = get_choices(new_trace) @@ -127,7 +127,7 @@ end @test get_value(new_assignment, :y) == y @test get_value(new_assignment, :v => :b) == b @test length(collect(get_values_shallow(new_assignment))) == 2 - @test length(collect(get_submaps_shallow(new_assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(new_assignment))) == 1 # test score and weight prev_score = ( @@ -242,7 +242,7 @@ end @test !isempty(get_submap(assignment, :v)) end @test length(collect(get_values_shallow(assignment))) == 2 - @test length(collect(get_submaps_shallow(assignment))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(assignment))) == 1 # test weight if assignment[:branch] == prev_assignment[:branch] @@ -270,6 +270,34 @@ end @test retdiff === UnknownChange() end + @gen function bar(mu) + @trace(normal(mu, 1), :a) + end + + @gen function baz(mu) + @trace(normal(mu, 1), :b) + end + + @gen function foo(mu) + if @trace(bernoulli(0.4), :branch) + @trace(normal(mu, 1), :x) + @trace(bar(mu), :u) + else + @trace(normal(mu, 1), :y) + @trace(baz(mu), :v) + end + end + + # test that no errors occur if we select addresses in a subtrace we have to generate + tr, _ = generate(foo, (mu,), choicemap((:branch, true))) + old_score = get_score(tr) + weight = nothing + while tr[:branch] + tr, weight, _ = regenerate(tr, (mu,), (NoChange(),), select(:branch, :x, :y, :v => :b)) + end + # P[new_tr]/P[old_tr] * Q[old_tr|new_tr]/Q[new_tr|old_tr] should be 1 + # since if we switch branches, we totally regenerate everything, so the Q values should equal the P values + @test weight == 0. end @testset "choice_gradients and accumulate_param_gradients!" begin @@ -332,11 +360,11 @@ end @test get_value(choices, :out) == out @test get_value(choices, :bar => :z) == z @test !has_value(choices, :b) # was not selected - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(collect(get_values_shallow(choices))) == 2 # check gradient trie - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(gradients))) == 2 @test !has_value(gradients, :b) # was not selected @test isapprox(get_value(gradients, :bar => :z), @@ -426,19 +454,23 @@ end @test trace[:x => 2] == 2 @test trace[:x => 3 => :z] == 3 + new_trace, _, _ = regenerate(trace, (), (), select(:x => 2)) + @test new_trace[:x => 1] == trace[:x => 1] + @test new_trace[:y] == trace[:y] + choices = get_choices(trace) @test choices[:x => 1] == 1 @test choices[:x => 2] == 2 @test choices[:x => 3 => :z] == 3 @test length(collect(get_values_shallow(choices))) == 1 # :y - @test length(collect(get_submaps_shallow(choices))) == 1 # :x + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # :x submap = get_submap(choices, :x) @test submap[1] == 1 @test submap[2] == 2 @test submap[3 => :z] == 3 @test length(collect(get_values_shallow(submap))) == 2 # 1, 2 - @test length(collect(get_submaps_shallow(submap))) == 1 # 3 + @test length(collect(get_nonvalue_submaps_shallow(submap))) == 1 # 3 bar_submap = get_submap(submap, 3) @test bar_submap[:z] == 3 @@ -545,4 +577,22 @@ end @test trace[:x] == trace[:y] end +@testset "generate with selections in constraints" begin + @gen function foo() + x ~ normal(0, 1) + return x + end + constraints = choicemap((:x, 1.)) + constraints_and_selection = DynamicAddressTree{Union{Value, SelectionLeaf}}() + set_subtree!(constraints_and_selection, :x, Value(1.)) + set_subtree!(constraints_and_selection, :y, AllSelection()) + tr1, w1 = generate(foo, (), constraints) + tr2, w2 = generate(foo, (), constraints_and_selection) + @test get_choices(tr1) == get_choices(tr2) + @test w1 == w2 + _, weight1 = generate(foo, (), EmptyAddressTree()) + _, weight2 = generate(foo, (), AllSelection()) + @test weight1 == weight2 +end + end diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..1985c610a 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -1,4 +1,4 @@ -@testset "call_at combinator" begin +@testset "call_at combinator on non-distribution" begin @gen (grad) function foo((grad)(x::Float64)) return x + @trace(normal(x, 1), :y) @@ -20,7 +20,7 @@ y = choices[3 => :y] @test isapprox(weight, logpdf(normal, y, 0.4, 1)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end @testset "generate" begin @@ -32,7 +32,7 @@ y = choices[3 => :y] @test get_retval(trace) == 0.4 + y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 # with constraints y = 1.234 @@ -44,7 +44,7 @@ @test get_retval(trace) == 0.4 + y @test isapprox(weight, logpdf(normal, y, 0.4, 1.)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 end function get_trace() @@ -71,7 +71,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isempty(discard) @@ -86,12 +86,12 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) # change kernel_args, different key, with constraint @@ -103,12 +103,12 @@ choices = get_choices(new_trace) @test choices[4 => :y] == y_new @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y_new, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y_new @test discard[3 => :y] == y @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) end @@ -121,7 +121,7 @@ choices = get_choices(new_trace) @test choices[3 => :y] == y @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test isapprox(weight, logpdf(normal, y, 0.2, 1) - logpdf(normal, y, 0.4, 1)) @test get_retval(new_trace) == 0.2 + y @test isapprox(get_score(new_trace), logpdf(normal, y, 0.2, 1)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) y_new = choices[3 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -144,7 +144,7 @@ choices = get_choices(new_trace) y_new = choices[4 => :y] @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test weight == 0. @test get_retval(new_trace) == 0.2 + y_new @test isapprox(get_score(new_trace), logpdf(normal, y_new, 0.2, 1)) @@ -171,9 +171,9 @@ @test choices[3 => :y] == y @test isapprox(gradients[3 => :y], logpdf_grad(normal, y, 0.4, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 0 - @test length(collect(get_submaps_shallow(gradients))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 1 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 1 @test length(input_grads) == 2 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.4, 1.0)[2] + retval_grad) @test input_grads[2] == nothing # the key has no gradient diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..69eb52498 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -1,6 +1,6 @@ -@testset "choice_at combinator" begin +@testset "call_at combinator on distribution" begin - at = choice_at(bernoulli, Int) + at = call_at(bernoulli, Int) @testset "assess" begin choices = choicemap() @@ -15,7 +15,7 @@ @test isapprox(weight, value ? log(0.4) : log(0.6)) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end @testset "generate" begin @@ -27,7 +27,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 # with constraints constraints = choicemap() @@ -39,7 +39,7 @@ choices = get_choices(trace) @test choices[3] == value @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 end function get_trace() @@ -65,7 +65,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isempty(discard) @@ -78,12 +78,12 @@ choices = get_choices(new_trace) @test choices[3] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 # change kernel_args, different key, with constraint constraints = choicemap() @@ -93,12 +93,12 @@ choices = get_choices(new_trace) @test choices[4] == false @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(1 - 0.2) - log(0.4)) @test get_retval(new_trace) == false @test discard[3] == true @test length(collect(get_values_shallow(discard))) == 1 - @test length(collect(get_submaps_shallow(discard))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 0 end @testset "regenerate" begin @@ -110,7 +110,7 @@ choices = get_choices(new_trace) @test choices[3] == true @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test isapprox(weight, log(0.2) - log(0.4)) @test get_retval(new_trace) == true @test isapprox(get_score(new_trace), log(0.2)) @@ -122,7 +122,7 @@ choices = get_choices(new_trace) value = choices[3] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -133,7 +133,7 @@ choices = get_choices(new_trace) value = choices[4] @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test weight == 0. @test get_retval(new_trace) == value @test isapprox(get_score(new_trace), log(value ? 0.2 : 1 - 0.2)) @@ -143,7 +143,7 @@ y = 1.2 constraints = choicemap() set_value!(constraints, 3, y) - (trace, _) = generate(choice_at(normal, Int), (0.0, 1.0, 3), constraints) + (trace, _) = generate(call_at(normal, Int), (0.0, 1.0, 3), constraints) # not selected (input_grads, choices, gradients) = choice_gradients( @@ -163,9 +163,9 @@ @test choices[3] == y @test isapprox(gradients[3], logpdf_grad(normal, y, 0.0, 1.0)[1] + retval_grad) @test length(collect(get_values_shallow(gradients))) == 1 - @test length(collect(get_submaps_shallow(gradients))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(gradients))) == 0 @test length(collect(get_values_shallow(choices))) == 1 - @test length(collect(get_submaps_shallow(choices))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 0 @test length(input_grads) == 3 @test isapprox(input_grads[1], logpdf_grad(normal, y, 0.0, 1.0)[2]) @test isapprox(input_grads[2], logpdf_grad(normal, y, 0.0, 1.0)[3]) diff --git a/test/modeling_library/dist_dsl.jl b/test/modeling_library/dist_dsl.jl index ee7ecd091..6b0269836 100644 --- a/test/modeling_library/dist_dsl.jl +++ b/test/modeling_library/dist_dsl.jl @@ -56,3 +56,15 @@ end @test logpdf(mylabel_cat, MyLabel(:a), [MyLabel(:a)], [1.0]) == 0 @test_throws MethodError logpdf(mylabel_cat, :a, [MyLabel(:a)], [1.0]) end + +@testset "dist dsl as generative function" begin + @dist labeled_cat(labels, probs) = labels[categorical(probs)] + + (tr, weight) = generate(labeled_cat, ([1, 2, 3], [.2, .2, .6]), Value(3)) + @test get_score(tr) == weight + @test get_retval(tr) == 3 + + (new_tr, weight, _, _) = update(tr, ([1, 2, 3], [.2, .2, .6]), (NoChange(), NoChange()), Value(2)) + @test isapprox(weight, log(.2) - log(.6)) + @test get_retval(new_tr) == 2 +end \ No newline at end of file diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index bfb13eb4e..1bb6cb02e 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -402,4 +402,19 @@ @test isapprox(get_param_grad(foo, :std), expected_std_grad) end + @testset "map over distribution" begin + flip_coins = Map(bernoulli) + coinflips_tr, weight = generate(flip_coins, (fill(0.4, 100),)) + @test weight == 0. + @test coinflips_tr[20] isa Bool + choices = get_choices(coinflips_tr) + @test get_submap(choices, 42) isa Value{Bool} + val42 = get_value(choices, 42) + new_tr, weight, retdiff, discard = update(coinflips_tr, (fill(0.4, 100),), (NoChange(),), choicemap((42, !val42))) + @test new_tr[42] == !val42 + expected_score_change = logpdf(bernoulli, !val42, 0.4) - logpdf(bernoulli, val42, 0.4) + @test isapprox(get_score(new_tr) - get_score(coinflips_tr), expected_score_change) + @test isapprox(weight, expected_score_change) + end + end diff --git a/test/modeling_library/modeling_library.jl b/test/modeling_library/modeling_library.jl index 616110f84..3eac6e0f8 100644 --- a/test/modeling_library/modeling_library.jl +++ b/test/modeling_library/modeling_library.jl @@ -1,8 +1,9 @@ include("custom_determ.jl") include("distributions.jl") +include("dist_dsl.jl") include("choice_at.jl") include("call_at.jl") include("map.jl") +include("set_map.jl") include("unfold.jl") -include("recurse.jl") -include("dist_dsl.jl") +include("recurse.jl") \ No newline at end of file diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 46954e3be..b440a44fa 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -197,9 +197,9 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false @test discard[(3, Val(:aggregation)) => :prefix] == true - @test length(collect(get_submaps_shallow(discard))) == 1 + @test length(collect(get_nonvalue_submaps_shallow(discard))) == 1 @test length(collect(get_values_shallow(discard))) == 0 - @test length(collect(get_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 + @test length(collect(get_nonvalue_submaps_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 0 @test length(collect(get_values_shallow(get_submap(discard,(3, Val(:aggregation)))))) == 1 @test retdiff == UnknownChange() diff --git a/test/modeling_library/set_map.jl b/test/modeling_library/set_map.jl new file mode 100644 index 000000000..fdd09636c --- /dev/null +++ b/test/modeling_library/set_map.jl @@ -0,0 +1,67 @@ +using FunctionalCollections: push, disj +@testset "multiset" begin + ms = MultiSet() + @test ms isa MultiSet{Any} + + ms = MultiSet{Int}() + ms = push(ms, 1) + @test length(ms) == 1 + ms = push(ms, 1) + @test length(ms) == 2 + @test collect(ms) == [1, 1] + + ms = push(ms, 2) + @test length(ms) == 3 + @test 2 in ms + @test 1 in ms + ms = remove_one(ms, 1) + @test 1 in ms + @test 2 in ms + ms = push(ms, 2) + @test length(ms) == 3 + ms = disj(ms, 2) + @test length(ms) == 1 + + @test push(push(ms, 2), 2) == MultiSet([1, 2, 2]) + @test MultiSet([2, 1, 2, 5]) == MultiSet([2, 5, 2, 1]) + + total = 0 + for el in MultiSet([2, 1, 2, 5]) + total += el + end + @test total == 2+1+2+5 + + @test setmap(x -> x^2, Set([-2, -1, 0, 1])) == MultiSet([4, 1, 1, 0]) +end + +@testset "SetMap" begin + priors = [ + [0.1, 0.3, 0.6], + [0.2, 0.6, 0.2], + [0.6, 0.2, 0.2] + ] + tr, weight = generate(SetMap(categorical), (Set(priors),), choicemap((priors[1], 2))) + @test tr[priors[1]] == 2 + @test tr[priors[2]] in (1, 2, 3) + @test tr[priors[3]] in (1, 2, 3) + @test isapprox(weight, log(0.3)) + + tr = simulate(SetMap(categorical), (Set(priors),)) + exp_score = sum(logpdf(categorical, tr[priors[i]], priors[i]) for i=1:3) + @test isapprox(get_score(tr), exp_score) + + current1 = tr[priors[1]] + new = current1 == 1 ? 2 : 1 + + new_tr, weight, _, discard = update(tr, (Set(priors),), (NoChange(),), choicemap((priors[1], new))) + expected_weight = logpdf(categorical, new, priors[1]) - logpdf(categorical, tr[priors[1]], priors[1]) + @test isapprox(weight, expected_weight) + @test isapprox(get_score(new_tr) - get_score(tr), expected_weight) + @test discard == choicemap((priors[1], tr[priors[1]])) + + new_tr, weight, _, discard = update(tr, (Set(priors[1:2]),), (UnknownChange(),), EmptyAddressTree()) + expected_weight = -logpdf(categorical, tr[priors[3]], priors[3]) + @test isapprox(weight, expected_weight) + @test isapprox(get_score(new_tr), sum(logpdf(categorical, tr[priors[i]], priors[i]) for i=1:2)) + @test discard == choicemap((priors[3], tr[priors[3]])) +end \ No newline at end of file diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index ba748453b..0f3a56180 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -28,7 +28,7 @@ x3 = trace[3 => :x] choices = get_choices(trace) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 expected_score = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x2, x1 * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -55,7 +55,7 @@ @test choices[1 => :x] == x1 @test choices[3 => :x] == x3 @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x2 = choices[2 => :x] expected_weight = (logpdf(normal, x1, x_init * alpha + beta, std) + logpdf(normal, x3, x2 * alpha + beta, std)) @@ -77,7 +77,7 @@ beta = 0.3 (choices, weight, retval) = propose(foo, (3, x_init, alpha, beta)) @test length(collect(get_values_shallow(choices))) == 0 - @test length(collect(get_submaps_shallow(choices))) == 3 + @test length(collect(get_nonvalue_submaps_shallow(choices))) == 3 x1 = choices[1 => :x] x2 = choices[2 => :x] x3 = choices[3 => :x] diff --git a/test/runtests.jl b/test/runtests.jl index f075494c8..7933c959c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -79,4 +79,4 @@ include("static_ir/static_ir.jl") include("static_dsl.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file diff --git a/test/selection.jl b/test/selection.jl index 6436eaec7..55389a316 100644 --- a/test/selection.jl +++ b/test/selection.jl @@ -1,5 +1,4 @@ @testset begin "dynamic selection" - s = select(:x, :y => :z, :y => :w) # test Base.in @@ -14,28 +13,28 @@ # test Base.getindex @test s[:x] == AllSelection() sub = s[:y] - @test isa(sub, DynamicSelection) + @test isa(sub, DynamicAddressTree) @test :z in sub @test :w in sub - @test s[:u] == EmptySelection() + @test s[:u] == EmptyAddressTree() @test s[:y => :z] == AllSelection() # test set_subselection! - set_subselection!(s, :y, select(:z)) + set_subtree!(s, :y, select(:z)) @test (:y => :z) in s @test !((:y => :w) in s) selection = select(:x) @test :x in selection subselection = select(:y) - set_subselection!(selection, :x, subselection) + set_subtree!(selection, :x, subselection) @test (:x => :y) in selection @test !(:x in selection) end @testset begin "all selection" - s = selectall() + s = AllSelection() # test Base.in @test :x in s @@ -46,24 +45,57 @@ end @test s[:x => :y] == AllSelection() end -@testset begin "complement selection" - - @test !(:x in complement(selectall())) - @test :x in complement(select()) +@testset begin "addrs" + choices = choicemap((:a, 1), (:b => :c, 2)) + a = addrs(choices) + @test a isa Selection + @test :a in a + @test (:b => :c) in a + @test !(:d in a) + @test get_subselection(a, :b) == select(:c) + @test length(collect(get_subtrees_shallow(a))) == 2 +end - @test !(:x in complement(select(:x))) - @test :y in complement(select(:x)) +@testset begin "push, merge, merge!" + x = select(:x, :y => :z) + push!(x, :y => :w) + @test !(:y in x) + @test (:y => :z) in x + @test (:y => :w) in x + @test :x in x + + y = select(:a => 1 => :b, 5.1, :y => :k) + z = merge(x, y) + @test 5.1 in z + @test :x in z + @test (:y => :w) in z + @test (:a => :1 => :b) in z + @test (:y => :k) in z - @test :x in complement(select(:x => :y => :z)) - @test (:x => :y) in complement(select(:x => :y => :z)) - @test !((:x => :y => :z) in complement(select(:x => :y => :z))) + merge!(y, x) + @test (:y => :w) in y + @test 5.1 in y + @test (:y => :k) in y +end - @test !(:x in complement(complement(select(:x => :y => :z)))) - @test !((:x => :y) in complement(complement(select(:x => :y => :z)))) - @test (:x => :y => :z) in complement(complement(select(:x => :y => :z))) +@testset begin "inverted selection" + sel = select(:x, :y => :z, :a => :b => 1, :a => :b => 2, :a => :c => 1) + i = invert(sel) + @test !(:x in i) + @test !(:y in i) + @test !((:y => :z) in i) + @test (:y => :a) in i + @test !((:a => :b) in i) + @test (:a => :b => 3) in i + @test (:a => :d) in i + @test :z in i - s = complement(select(:x => :y => :z))[:x] - @test !((:y => :z) in s) - @test :w in s - @test :y in s -end + c = choicemap( + (:x, 1), + (:y, 5), + (:a => :b => 1, 10), + (:a => :b => 3, 2), + (:a => :c => 2, 12) + ) + @test get_selected(c, i) == choicemap((:a => :b => 3, 2), (:a => :c => 2, 12)) +end \ No newline at end of file diff --git a/test/static_dsl.jl b/test/static_dsl.jl index 3a8bc0a40..6b60131ae 100644 --- a/test/static_dsl.jl +++ b/test/static_dsl.jl @@ -40,13 +40,13 @@ end ret = @trace(bernoulli(0.5), :x => i) end -# @trace(choice_at(bernoulli)(0.5, i), :x) +# @trace(call_at(bernoulli)(0.5, i), :x) @gen (static) function at_choice_example_2(i::Int) ret = @trace(bernoulli(0.5), :x => i => :y) end -# @trace(call_at(choice_at(bernoulli))(0.5, i, :y), :x) +# @trace(call_at(call_at(bernoulli))(0.5, i, :y), :x) @gen function foo(mu) @trace(normal(mu, 1), :y) @@ -64,6 +64,12 @@ end ret = @trace(foo(mu), :x => i => :y) end +@gen (static) function foo6() + x ~ exponential(1) + y ~ normal(x, 1) +end +@load_generated_functions() + # Modules to test load_generated_functions module MyModuleA using Gen @@ -117,14 +123,13 @@ params = ir.arg_nodes[2] @test params.compute_grad # choice nodes and call nodes -@test length(ir.choice_nodes) == 2 -@test length(ir.call_nodes) == 0 +@test length(ir.call_nodes) == 2 # is_outlier -is_outlier = ir.choice_nodes[1] +is_outlier = ir.call_nodes[1] @test is_outlier.addr == :z @test is_outlier.typ == QuoteNode(Bool) -@test is_outlier.dist == bernoulli +@test is_outlier.generative_function == bernoulli @test length(is_outlier.inputs) == 1 # std @@ -138,10 +143,10 @@ in2 = std.inputs[2] @test (in1 === is_outlier && in2 === params) || (in2 === is_outlier && in1 === params) # y -y = ir.choice_nodes[2] +y = ir.call_nodes[2] @test y.addr == :y @test y.typ == QuoteNode(Float64) -@test y.dist == normal +@test y.generative_function == normal @test length(y.inputs) == 2 @test y.inputs[2] === std @@ -174,40 +179,39 @@ xs = ir.arg_nodes[1] @test xs.typ == :(Vector{Float64}) @test !xs.compute_grad -# choice nodes and call nodes -@test length(ir.choice_nodes) == 4 -@test length(ir.call_nodes) == 1 +# call nodes +@test length(ir.call_nodes) == 5 # inlier_std -inlier_std = ir.choice_nodes[1] +inlier_std = ir.call_nodes[1] @test inlier_std.addr == :inlier_std @test inlier_std.typ == QuoteNode(Float64) -@test inlier_std.dist == gamma +@test inlier_std.generative_function == gamma @test length(inlier_std.inputs) == 2 # outlier_std -outlier_std = ir.choice_nodes[2] +outlier_std = ir.call_nodes[2] @test outlier_std.addr == :outlier_std @test outlier_std.typ == QuoteNode(Float64) -@test outlier_std.dist == gamma +@test outlier_std.generative_function == gamma @test length(outlier_std.inputs) == 2 # slope -slope = ir.choice_nodes[3] +slope = ir.call_nodes[3] @test slope.addr == :slope @test slope.typ == QuoteNode(Float64) -@test slope.dist == normal +@test slope.generative_function == normal @test length(slope.inputs) == 2 # intercept -intercept = ir.choice_nodes[4] +intercept = ir.call_nodes[4] @test intercept.addr == :intercept @test intercept.typ == QuoteNode(Float64) -@test intercept.dist == normal +@test intercept.generative_function == normal @test length(intercept.inputs) == 2 # data -ys = ir.call_nodes[1] +ys = ir.call_nodes[5] @test ys.addr == :data @test ys.typ == QuoteNode(PersistentVector{Float64}) @test ys.generative_function == data_fn @@ -257,8 +261,8 @@ ret = get_node_by_addr(ir, :x) @test isa(ret.inputs[1], Gen.JuliaNode) # () -> 0.5 @test ret.inputs[2] === i at = ret.generative_function -@test isa(at, Gen.ChoiceAtCombinator) -@test at.dist == bernoulli +@test isa(at, Gen.CallAtCombinator) +@test at.kernel == bernoulli # at_choice_example_2 ir = Gen.get_ir(typeof(at_choice_example_2)) @@ -273,8 +277,8 @@ ret = get_node_by_addr(ir, :x) at = ret.generative_function @test isa(at, Gen.CallAtCombinator) at2 = at.kernel -@test isa(at2, Gen.ChoiceAtCombinator) -@test at2.dist == bernoulli +@test isa(at2, Gen.CallAtCombinator) +@test at2.kernel == bernoulli end @@ -376,7 +380,7 @@ ir2 = Gen.get_ir(typeof(f2)) return_node1 = ir1.return_node return_node2 = ir2.return_node @test isa(return_node2, typeof(return_node1)) -@test return_node2.dist == return_node1.dist +@test return_node2.generative_function == return_node1.generative_function inputs1 = return_node1.inputs inputs2 = return_node2.inputs @@ -584,11 +588,24 @@ tr = simulate(bar1, ()) ch = get_choices(tr) @test has_value(ch, :x) @test !has_value(ch, :y) -@test_throws KeyError get_submap(ch, :x) + @test has_value(get_submap(ch, :a), :b) @test get_submap(ch, :y) == EmptyChoiceMap() -@test length(get_values_shallow(ch)) == 1 -@test length(get_submaps_shallow(ch)) == 1 +@test length(collect(get_values_shallow(ch))) == 1 +@test length(collect(get_submaps_shallow(ch))) == 2 +end + +@testset "inverted selections!" begin + tr = simulate(foo6, ()) + # regenerate y but not x + new_tr, _, _ = regenerate(tr, (), (), invert(select(:x))) + @test new_tr[:x] == tr[:x] + @test new_tr[:y] != tr[:y] + @test isapprox(project(new_tr, invert(select(:y))), logpdf(exponential, new_tr[:x], 1)) + + # regenerate x and y, with y but not x constrained in the reverse direction + new_tr, weight, _, _ = update(tr, (), (), AllSelection(), invert(select(:y))) + @test isapprox(weight, -get_score(tr) + logpdf(normal, tr[:y], tr[:x], 1.)) end @testset "macros in static functions" begin @@ -609,4 +626,4 @@ end @test trace[:x] == trace[:y] end -end # @testset "static DSL" +end # @testset "static DSL" \ No newline at end of file diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index 91c6c3202..b847d29ee 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -362,12 +362,12 @@ end @test get_value(value_trie, :out) == out @test get_value(value_trie, :bar => :z) == z @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(value_trie))) == 1 + @test length(collect(get_values_shallow(value_trie))) == 2 # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 + @test length(collect(get_nonvalue_submaps_shallow(gradient_trie))) == 1 + @test length(collect(get_values_shallow(gradient_trie))) == 2 @test !has_value(gradient_trie, :b) # was not selected @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) @@ -554,6 +554,47 @@ Gen.load_generated_functions() @test counter == 1 end +struct SignFlipDiff <: Gen.Diff end +counter = 0 +Base.abs(v::Diffed{<:Any, SignFlipDiff}) = Diffed(abs(strip_diff(v)), NoChange()) +Base.abs(v::Diffed{<:Any, UnknownChange}) = Diffed(abs(strip_diff(v)), UnknownChange()) + +#= +@gen (static, diffs) function foo105(x) + x1 = abs(x) + x2 = begin x1 + 5; counter += 1; end + return x2 +end +=# + +counter = 0 +builder = StaticIRBuilder() +x = add_argument_node!(builder, name=:x) +x1 = add_julia_node!(builder, abs, inputs=[x], name=:x1) +x2 = add_julia_node!(builder, (a,) -> begin counter += 1; a += 5; end, inputs=[x1], name=:x2) +set_return_node!(builder, x2) +ir = build_ir(builder) +foo105 = eval(generate_generative_function(ir, :foo105, track_diffs=true, cache_julia_nodes=true)) +Gen.load_generated_functions() + +@testset "cached julia nodes with runtime NoChange diffs" begin + counter = 0 + tr = simulate(foo105, (-4,)) + @test counter == 1 + + counter = 0 + update(tr, (4,), (SignFlipDiff(),), EmptyChoiceMap()) + @test counter == 0 + + counter = 0 + update(tr, (-4,), (NoChange(),), EmptyChoiceMap()) + @test counter == 0 + + counter = 0 + update(tr, (5,), (UnknownChange(),), EmptyChoiceMap()) + @test counter == 1 +end + @testset "regression test for https://github.com/probcomp/Gen/issues/168" begin @gen (static) function model(var) mean = @trace(normal(0, 1), :mean)