diff --git a/Project.toml b/Project.toml index 538b118e9..31ae909cd 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] diff --git a/docs/src/ref/extending.md b/docs/src/ref/extending.md index 7f9dfd480..8e28fc27c 100644 --- a/docs/src/ref/extending.md +++ b/docs/src/ref/extending.md @@ -228,6 +228,17 @@ If your generative function has trainable parameters, then implement: - [`accumulate_param_gradients!`](@ref) +#### Supporting trace serialization +To support trace serialization, a trace type of type `T` for a generative function of type `G` must convertable into a `SerializableTrace` object, and must be recoverable from a `SerializableTrace` object and the generative function. +```@docs +SerializableTrace +to_serializable_trace +from_serializable_trace +``` +A user must implement `to_serializable_Trace(::T)`, and `from_serializable_Trace(::ST, ::G)` for some concrete type `ST <: SerializableTrace`. This may be a custom type, or the user may use the built-in type +```@docs +GenericSerializableTrace +``` ## Custom modeling languages diff --git a/docs/src/ref/gfi.md b/docs/src/ref/gfi.md index dde929722..d3da5ff26 100644 --- a/docs/src/ref/gfi.md +++ b/docs/src/ref/gfi.md @@ -350,6 +350,17 @@ The set of elements (either arguments, random choices, or trainable parameters) If the return value of the function is conditionally dependent on any element in the gradient source set given the arguments and values of all other random choices, for all possible traces of the function, then the generative function requires a *return value gradient* to compute gradients with respect to elements of the gradient source set. This static property of the generative function is reported by [`accepts_output_grad`](@ref). +## Serialization +To serialize a trace `tr` for a generative function `gf` +(stave the trace to disk), a user may call +```julia +serialize_trace(filename_or_io::Union{IO, AbstractString}, tr) +``` +To recover the trace, a user may call +```julia +deserialized_tr = deserialize_trace(filename_or_io, gf) +``` + ## Generative function interface The complete set of methods in the generative function interface (GFI) is: diff --git a/src/Gen.jl b/src/Gen.jl index a3a40e38b..dc7b43363 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -49,6 +49,9 @@ include("trie.jl") # generative function interface include("gen_fn_interface.jl") +# serialization/deserialization for traces +include("serialization.jl") + # built-in data types for arg-diff and ret-diff values include("diff.jl") diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl index d83055444..2411213bf 100644 --- a/src/dynamic/dynamic.jl +++ b/src/dynamic/dynamic.jl @@ -175,6 +175,7 @@ function gen_fn_changed_error(addr) error("Generative function changed at address: $addr") end +include("serialization.jl") include("simulate.jl") include("generate.jl") include("propose.jl") diff --git a/src/dynamic/serialization.jl b/src/dynamic/serialization.jl new file mode 100644 index 000000000..3d57b9432 --- /dev/null +++ b/src/dynamic/serialization.jl @@ -0,0 +1,48 @@ +function _record_to_serializable(r::ChoiceOrCallRecord{T}) where {T <: Trace} + @assert !r.is_choice + return ChoiceOrCallRecord(to_serializable_trace(r.subtrace_or_retval), r.score, r.noise, r.is_choice) +end +function _record_to_serializable(r::ChoiceOrCallRecord) + @assert r.is_choice + return r +end +function _record_from_serializable(r::ChoiceOrCallRecord{T}, gf::GenerativeFunction) where {T <: SerializableTrace} + @assert !r.is_choice + return ChoiceOrCallRecord(from_serializable_trace(r.subtrace_or_retval, gf), r.score, r.noise, r.is_choice) +end +function _record_from_serializable(r::ChoiceOrCallRecord, dist::Distribution) + @assert r.is_choice + return r +end +function _trie_to_serializable(trie::Trie) + triemap(trie, identity, _record_to_serializable) +end +function to_serializable_trace(tr::DynamicDSLTrace) + return GenericSerializableTrace( + _trie_to_serializable(tr.trie), + (tr.isempty, tr.score, tr.noise, tr.args, tr.retval) + ) +end + +# since a Dynamic Gen Function doesn't store +# what sub-generative-function is at which address, +# we have to run the generative function to get access to this! +mutable struct GFDeserializeState + trace::DynamicDSLTrace + serialized::GenericSerializableTrace +end +function from_serializable_trace(st::GenericSerializableTrace, gen_fn::DynamicDSLFunction{T}) where T + trace = DynamicDSLTrace{T}(gen_fn, Trie{Any, ChoiceOrCallRecord}(), st.properties...) + state = GFDeserializeState(trace, st) + exec(gen_fn, state, trace.args) + return trace +end +function traceat(state::GFDeserializeState, dist_or_gen_fn, args, key) + record = _record_from_serializable(state.serialized.subtraces[key], dist_or_gen_fn) + state.trace.trie[key] = record + return record.is_choice ? record.subtrace_or_retval : get_retval(record.subtrace_or_retval) +end +function splice(state::GFDeserializeState, gf::DynamicDSLFunction, args::Tuple) + return exec(gf, state, args) +end +read_param(::GFDeserializeState, ::Symbol) = nothing \ No newline at end of file diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..a72e475d9 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -43,6 +43,9 @@ mutable struct DynamicDSLTrace{T} <: Trace # retval is not known yet new(gen_fn, trie, true, 0, 0, args) end + function DynamicDSLTrace{T}(gen_fn::T, trie, isempty, score, noise, args, retval) where {T} + new(gen_fn, trie, isempty, score, noise, args, retval) + end end set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..177aa645e 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -157,4 +157,11 @@ function accumulate_param_gradients!(trace::CallAtTrace, retval_grad) (kernel_input_grads..., nothing) end +function to_serializable_trace(tr::CallAtTrace) + return GenericSerializableTrace(to_serializable_trace(tr.subtrace), tr.key) +end +function from_serializable_trace(st::GenericSerializableTrace, gf::CallAtCombinator) + return get_trace_type(gf)(gf, from_serializable_trace(st.subtraces, gf.kernel), st.properties) +end + export call_at diff --git a/src/modeling_library/choice_at/choice_at.jl b/src/modeling_library/choice_at/choice_at.jl index 09cb922fa..b89afa73c 100644 --- a/src/modeling_library/choice_at/choice_at.jl +++ b/src/modeling_library/choice_at/choice_at.jl @@ -25,6 +25,7 @@ function get_address_schema(::Type{T}) where {T<:ChoiceAtChoiceMap} end get_value(choices::ChoiceAtChoiceMap, addr::Pair) = _get_value(choices, addr) has_value(choices::ChoiceAtChoiceMap, addr::Pair) = _has_value(choices, addr) +has_value(choices::ChoiceAtChoiceMap, addr) = addr == choices.key function get_value(choices::ChoiceAtChoiceMap{T,K}, addr::K) where {T,K} choices.key == addr ? choices.value : throw(KeyError(choices, addr)) end @@ -172,4 +173,11 @@ function accumulate_param_gradients!(trace::ChoiceAtTrace, retval_grad) (kernel_arg_grads[2:end]..., nothing) end +function to_serializable_trace(tr::ChoiceAtTrace) + return GenericSerializableTrace(nothing, (tr.value, tr.key, tr.kernel_args, tr.score)) +end +function from_serializable_trace(st::GenericSerializableTrace, gf::ChoiceAtCombinator) + return get_trace_type(gf)(gf, st.properties...) +end + export choice_at diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 24d6d90f2..fdb13bc3a 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -204,4 +204,11 @@ has_argument_grads(gen_fn::CustomUpdateGF) = tuple(fill(nothing, num_args(gen_fn apply_with_state(gen_fn::CustomUpdateGF, args) = error("not implemented") +function to_serializable_trace(tr::CustomDetermGFTrace) + return GenericSerializableTrace(nothing, (tr.retval, tr.state, tr.args)) +end +function from_serializable_trace(st::GenericSerializableTrace, gf::CustomDetermGF) + return get_trace_type(gf)(st.properties..., gf) +end + export CustomUpdateGF, num_args diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index 1bb695ff7..5aaa63ab6 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -43,6 +43,8 @@ function get_prev_and_new_lengths(args::Tuple, prev_trace) (new_length, prev_length) end +_gen_fn_at_addr(gf::Map, _) = gf.kernel + include("assess.jl") include("propose.jl") include("simulate.jl") diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 7b4d1d23d..73d1326fa 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -193,6 +193,27 @@ function get_aggregation_constraints(constraints::ChoiceMap, cur::Int) get_submap(constraints, (cur, Val(:aggregation))) end +function to_serializable_trace(tr::RecurseTrace) + return GenericSerializableTrace( + ( + Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.production_traces), + Dict(k => to_serializable_trace(subtr) for (k, subtr) in tr.aggregation_traces) + ), + (tr.max_branch, tr.score, tr.root_idx, tr.num_has_choices) + ) +end +function from_serializable_trace(st::GenericSerializableTrace, gf::Recurse{S, T}) where {S, T} + production_traces = PersistentHashMap{Int, S}() + for (k, subst) in st.subtraces[1] + production_traces = assoc(production_traces, k, from_serializable_trace(subst, gf.production_kern)) + end + aggregation_traces = PersistentHashMap{Int, T}() + for (k, subst) in st.subtraces[2] + aggregation_traces = assoc(aggregation_traces, k, from_serializable_trace(subst, gf.aggregation_kern)) + end + return get_trace_type(gf)(gf, production_traces, aggregation_traces, st.properties...) +end + ############ # simulate # ############ diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 9ef5752e1..05323f589 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -30,6 +30,19 @@ function (gen_fn::Switch{C})(index::C, args...) where C end include("trace.jl") + +function to_serializable_trace(tr::SwitchTrace) + GenericSerializableTrace(to_serializable_trace(tr.branch), (tr.index, tr.retval, tr.args, tr.score, tr.noise)) +end +function from_serializable_trace(c::GenericSerializableTrace, gf::Switch) + (index, retval, args, score, noise) = c.properties + GenericSerializableTrace( + gf, index, + from_serializable_trace(c.subtraces, gf.branches[index]), + retval, args, score, noise + ) +end + include("assess.jl") include("propose.jl") include("simulate.jl") diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 44238e3b7..ae077a5eb 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -60,6 +60,8 @@ function check_length(len::Int) end end +_gen_fn_at_addr(gf::Unfold, _) = gf.kernel + include("simulate.jl") include("generate.jl") include("propose.jl") diff --git a/src/modeling_library/vector.jl b/src/modeling_library/vector.jl index f35360b50..2b34f2bbd 100644 --- a/src/modeling_library/vector.jl +++ b/src/modeling_library/vector.jl @@ -184,3 +184,20 @@ function vector_remove_deleted_applications(subtraces, retval, prev_length, new_ end (subtraces, retval) end + +################# +# Serialization # +################# +function to_serializable_trace(trace::VectorTrace) + GenericSerializableTrace( + [to_serializable_trace(st) for st in trace.subtraces], + (trace.retval, trace.args, trace.len, trace.num_nonempty, trace.score, trace.noise) + ) +end +function from_serializable_trace(st::GenericSerializableTrace, gf::GenerativeFunction{<:Any, VectorTrace{GenFnType, T, U}}) where {GenFnType, T, U} + subtraces = PersistentVector{U}( + [from_serializable_trace(serialized_subtrace, _gen_fn_at_addr(gf, i)) + for (i, serialized_subtrace) in enumerate(st.subtraces)] + ) + get_trace_type(gf)(gf, subtraces, st.properties...) +end \ No newline at end of file diff --git a/src/serialization.jl b/src/serialization.jl new file mode 100644 index 000000000..a71ccdd3e --- /dev/null +++ b/src/serialization.jl @@ -0,0 +1,62 @@ +using Serialization: serialize, deserialize + +""" + SerializableTrace + +A representation of a `Trace` which can be serialized. Obtainable via `to_serializable_trace`. +This does not need to contain the `GenerativeFunction` which produced the trace; +to deserialize (using `from_serializable_trace`), the `GenerativeFunction` must be provided. +""" +abstract type SerializableTrace end + +""" + to_serializable_trace(trace::Trace) + +Get a SerializableTrace representing the `trace` in a serializable manner. +""" +function to_serializable_trace(trace::Trace) + error("Not implemented") +end + +""" + from_serializable_trace(st::SerializableTrace, fn::GenerativeFunction) + +Get the trace of the given generative function encoded by the serializable trace object. +""" +function from_serializable_trace(::SerializableTrace, ::GenerativeFunction) + error("Not implemented.") +end + +""" + serialize_trace(stream::IO, trace::Trace) + serialize_trace(filename::AbstractString, trace::Trace) + +Serialize the given trace to the given stream or file, by converting to a `SerializableTrace`. +""" +function serialize_trace(filename_or_io::Union{IO, AbstractString}, trace::Trace) + return serialize(filename_or_io, to_serializable_trace(trace)) +end + +""" + deserialize_trace(stream::IO, gen_fn::GenerativeFunction) + deserialize_trace(filename::AbstractString, gen_fn::GenerativeFunction) + +Deserialize the trace for the given generative function stored in the given stream or file +(as saved via `serialize_trace`). +""" +function deserialize_trace(filename_or_io::Union{IO, AbstractString}, gf::GenerativeFunction) + return from_serializable_trace(deserialize(filename_or_io), gf) +end + +""" + GenericSerializableTrace <: SerializableTrace + +A SerializableTrace which contains some subtraces which have been recursively converted +to `SerializableTrace`s, and some properties which are directly serializable. +""" +struct GenericSerializableTrace{S, P} <: SerializableTrace + subtraces::S + properties::P +end + +export to_serializable_trace, from_serializable_trace, serialize_trace, deserialize_trace \ No newline at end of file diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index d5d7e237f..1eda20e8b 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -37,10 +37,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol; track_diffs=fa end function generate_generative_function(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions) + gen_fn_type_name = gensym("StaticGenFunction_$name") - (trace_defns, trace_struct_name) = generate_trace_type_and_methods(ir, name, options) + (trace_defns, trace_struct_name, tracefields) = generate_trace_type_and_methods(ir, name, options) - gen_fn_type_name = gensym("StaticGenFunction_$name") return_type = ir.return_node.typ trace_type = trace_struct_name has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...) @@ -61,7 +61,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) end - Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + + serialization_code = generate_serialization_methods(ir, trace_struct_name, gen_fn_type_name, tracefields) + + Expr(:block, trace_defns, gen_fn_defn, serialization_code, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) end include("print_ir.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..4d1f80d92 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -86,7 +86,9 @@ const return_value_fieldname = gensym("retval") struct TraceField fieldname::Symbol typ::Union{Symbol,Expr,QuoteNode} + holds_subtrace::Bool end +TraceField(f, t) = TraceField(f, t, false) function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptions) fields = TraceField[] @@ -103,7 +105,7 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio for node in ir.call_nodes subtrace_fieldname = get_subtrace_fieldname(node) subtrace_type = QuoteNode(get_trace_type(node.generative_function)) - push!(fields, TraceField(subtrace_fieldname, subtrace_type)) + push!(fields, TraceField(subtrace_fieldname, subtrace_type, true)) end if options.cache_julia_nodes for node in ir.julia_nodes @@ -124,8 +126,54 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options: mutable = false fields = get_trace_fields(ir, options) field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields) - Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), + return ( + fields, + Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) + ) +end + +function generate_serialization_methods(ir::StaticIR, trace_struct_name::Symbol, gen_fn_typename::Symbol, fields) + to_subtraces_exprs = [ + :($(GlobalRef(Gen, :to_serializable_trace))(tr.$(field.fieldname))) + for field in fields if field.holds_subtrace + ] + to_properties_exprs = [:(tr.$(field.fieldname)) for field in fields if !field.holds_subtrace] + + # fields will have a bunch of properties, then the subtraces, then more properties + num_initial_props = 0 + for field in fields + if !field.holds_subtrace + num_initial_props += 1 + else + break; + end + end + + gen_fns = [node.generative_function for node in ir.call_nodes] + + quote + function $(GlobalRef(Gen, :to_serializable_trace))(tr::$trace_struct_name) + return $(GlobalRef(Gen, :GenericSerializableTrace))( + $(Expr(:tuple, to_subtraces_exprs...)), + $(Expr(:tuple, to_properties_exprs...)) + ) + end + function $(GlobalRef(Gen, :from_serializable_trace))( + st::$(GlobalRef(Gen, :GenericSerializableTrace)), + gf::$gen_fn_typename + ) + return $trace_struct_name( + st.properties[1:$num_initial_props]..., + ( + $(GlobalRef(Gen, :from_serializable_trace))(args...) + for args in zip(st.subtraces, $gen_fns) + )..., + st.properties[$(num_initial_props + 1):end]..., + gf + ) + end + end end function generate_isempty(trace_struct_name::Symbol) @@ -284,7 +332,7 @@ end function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions) trace_struct_name = gensym("StaticIRTrace_$name") - trace_struct_expr = generate_trace_struct(ir, trace_struct_name, options) + (fields, trace_struct_expr) = generate_trace_struct(ir, trace_struct_name, options) isempty_expr = generate_isempty(trace_struct_name) get_score_expr = generate_get_score(trace_struct_name) get_args_expr = generate_get_args(ir, trace_struct_name) @@ -302,8 +350,10 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St get_args_expr, get_retval_expr, get_choices_expr, get_schema_expr, get_values_shallow_expr, get_submaps_shallow_expr, static_get_value_exprs..., - static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...) - (exprs, trace_struct_name) + static_has_value_exprs..., static_get_submap_exprs..., + getindex_exprs... + ) + (exprs, trace_struct_name, fields) end export StaticIRTrace diff --git a/src/trie.jl b/src/trie.jl index 0d1c2a8a2..e623d167b 100644 --- a/src/trie.jl +++ b/src/trie.jl @@ -2,7 +2,7 @@ # Trie # ################## -struct Trie{K,V} <: ChoiceMap +struct Trie{K,V} leaf_nodes::Dict{K,V} internal_nodes::Dict{K,Trie{K,V}} end @@ -32,6 +32,7 @@ Base.isempty(trie::Trie) = isempty(trie.leaf_nodes) && isempty(trie.internal_nod get_leaf_nodes(trie::Trie) = trie.leaf_nodes get_internal_nodes(trie::Trie) = trie.internal_nodes +Base.:(==)(t1::Trie, t2::Trie) = get_leaf_nodes(t1) == get_leaf_nodes(t2) && get_internal_nodes(t1) == get_internal_nodes(t2) function Base.values(trie::Trie) iterators = convert(Vector{Any}, collect(map(values, values(trie.internal_nodes)))) push!(iterators, values(trie.leaf_nodes)) @@ -179,6 +180,28 @@ Base.haskey(trie::Trie, key) = has_leaf_node(trie, key) Base.getindex(trie::Trie, key) = get_leaf_node(trie, key) +""" + triemap(trie::Trie, key_converter, leaf_converter) + +Get a new trie by applying the function `key_converter` to every key in the trie +and applying the function `leaf_converter` to every leaf node in the trie. +""" +function triemap(trie::Trie{K, V}, key_converter, leaf_converter) where {K, V} + new_keytype = Core.Compiler.return_type(key_converter, Tuple{K}) + KT = Union{Base.return_types(key_converter, (K,))...} + LT = Union{Base.return_types(leaf_converter, (V,))...} + converted_leafs = Dict{KT, LT}( + key_converter(k) => leaf_converter(v) for (k, v) in trie.leaf_nodes + ) + converted_internals = Dict{KT, Trie{KT, LT}}( + key_converter(k) => convert_to_serializable_trie(subtrie, key_converter, leaf_converter, new_keytype, new_leaftype) + for (k, subtrie) in trie.internal_nodes + ) + + return Trie{KT, LT}(converted_leafs, converted_internals) +end + + export Trie export set_internal_node! export delete_internal_node! diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 25e3a50f4..a4672b31a 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -534,4 +534,26 @@ end end +@testset "serialization" begin + @gen function bar() + @trace(normal(0, 1), :a) + end + + @gen function baz() + @trace(normal(0, 1), :b) + end + + @gen function foo() + if @trace(bernoulli(0.4), :branch) + @trace(normal(0, 1), :x) + @trace(bar(), :u) + else + @trace(normal(0, 1), :y) + @trace(baz(), :v) + end + end + tr = simulate(foo, ()) + @test serialize_loop_successful(tr) +end + end diff --git a/test/dsl/static_dsl.jl b/test/dsl/static_dsl.jl index e062d06f9..077f4bdbe 100644 --- a/test/dsl/static_dsl.jl +++ b/test/dsl/static_dsl.jl @@ -603,4 +603,14 @@ ch = get_choices(tr) @test length(get_submaps_shallow(ch)) == 1 end +@testset "serialization" begin + tr = simulate(model, ([1., 2., 3., 4.],)) + @test Gen.to_serializable_trace(tr) isa Gen.GenericSerializableTrace + io = IOBuffer() + serialize_trace(io, tr) + seek(io, 0) + deserialized_tr = deserialize_trace(io, model) + @test get_choices(deserialized_tr) == get_choices(tr) +end + end # @testset "static DSL" diff --git a/test/modeling_library/call_at.jl b/test/modeling_library/call_at.jl index b27f0130d..a79f12e3b 100644 --- a/test/modeling_library/call_at.jl +++ b/test/modeling_library/call_at.jl @@ -55,6 +55,10 @@ (trace, y) end + @testset "serialization" begin + @test serialize_loop_successful(get_trace()[1]) + end + @testset "project" begin (trace, y) = get_trace() @test isapprox(project(trace, EmptySelection()), 0.) diff --git a/test/modeling_library/choice_at.jl b/test/modeling_library/choice_at.jl index 080b1b461..d241d5195 100644 --- a/test/modeling_library/choice_at.jl +++ b/test/modeling_library/choice_at.jl @@ -49,6 +49,10 @@ trace end + @testset "serialization" begin + @test serialize_loop_successful(get_trace()) + end + @testset "project" begin trace = get_trace() @test isapprox(project(trace, EmptySelection()), 0.) diff --git a/test/modeling_library/custom_determ.jl b/test/modeling_library/custom_determ.jl index a0d473c14..8c36ab0f7 100644 --- a/test/modeling_library/custom_determ.jl +++ b/test/modeling_library/custom_determ.jl @@ -84,6 +84,9 @@ @test w == 0. @test get_retval(trace) == 1 + 2 + 3 + # serialization + @test serialize_loop_successful(trace) + # update (UnknownChange) trace = simulate(MyDeterministicGF(), ([1, 2, 3],)) new_trace, w, retdiff = update(trace, ([1, 2, 4],), (UnknownChange(),), EmptyChoiceMap()) diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3c1f820fe..32b590070 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -38,6 +38,11 @@ @test isapprox(weight, logpdf(normal, z1, 4., 1.)) end + @testset "serialization" begin + (trace, _) = generate(bar, (xs[1:2], ys[1:2])) + @test serialize_loop_successful(trace) + end + @testset "propose" begin (choices, weight) = propose(bar, (xs[1:2], ys[1:2])) z1 = choices[1 => :z] diff --git a/test/modeling_library/recurse.jl b/test/modeling_library/recurse.jl index 7fe7a9592..fdfbb63d2 100644 --- a/test/modeling_library/recurse.jl +++ b/test/modeling_library/recurse.jl @@ -177,6 +177,8 @@ end @test choices[(4, Val(:production)) => :rule] == 4 @test choices[(4, Val(:aggregation)) => :prefix] == false + @test serialize_loop_successful(trace) + # update non-structure choice new_constraints = choicemap() new_constraints[(3, Val(:aggregation)) => :prefix] = false diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index dde53f7f2..a46d8bf34 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -13,6 +13,8 @@ @test swtr[:z] == tr[:z] @test project(swtr, AllSelection()) == project(swtr.branch, AllSelection()) @test project(swtr, EmptySelection()) == swtr.noise + + @test serialize_loop_successful(tr) end # ------------ Bare combinator ------------ # diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index 3d8c1b54a..cdab51eb2 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -15,6 +15,11 @@ @test length(foo(5, 0., 1.0, 1.0)) == 5 end + @testset "serialization" begin + tr = simulate(foo, (3, 0.1, 0.2, 0.3)) + @test serialize_loop_successful(tr) + end + @testset "simulate" begin x_init = 0.1 alpha = 0.2 diff --git a/test/runtests.jl b/test/runtests.jl index 6fddf20b4..0ff95ab0f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,6 +76,24 @@ end const dx = 1e-6 +""" +Attempts to serialize then deserialize the given trace, and returns +whether the pre-serialization and post-serialization traces are equal. +""" +function serialize_loop_successful(tr) + io = IOBuffer() + serialize_trace(io, tr) + seek(io, 0) + des_tr = deserialize_trace(io, get_gen_fn(tr)) + + if get_choices(des_tr) != get_choices(tr) + display(tr) + display(des_tr) + end + + return get_choices(des_tr) == get_choices(tr) +end + include("autodiff.jl") include("diff.jl") include("selection.jl") @@ -86,4 +104,4 @@ include("optional_args.jl") include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") -include("modeling_library/modeling_library.jl") +include("modeling_library/modeling_library.jl") \ No newline at end of file