Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

make this faster) #220

@github-actions

Description

@github-actions

retries[] # 607, 556

comp_graph_size(Dice.tobits(g)) # 16825

BenchmarkTools.Trial: 24 samples with 1 evaluation.

Range (min … max): 104.068 ms … 592.294 ms ┊ GC (min … max): 0.00% … 0.00%

Time (median): 168.251 ms ┊ GC (median): 0.00%

Time (mean ± σ): 186.669 ms ± 97.022 ms ┊ GC (mean ± σ): 0.00% ± 0.00%

▃▃█ ▃▃▃ ▃▃

▇███▁▁███▇██▁▇▇▁▁▇▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁

104 ms Histogram: frequency by time 592 ms <

Memory estimate: 6.51 MiB, allocs estimate: 235681.

Range (min … max): 21.161 ms … 130.908 ms ┊ GC (min … max): 0.00% … 30.08%

Time (median): 39.090 ms ┊ GC (median): 0.00%

Time (mean ± σ): 44.209 ms ± 18.477 ms ┊ GC (mean ± σ): 0.79% ± 2.83%

▁▄ ▂▅█▁▅▇ ▂ ▁

▃▅▁██▆███████▆▅█▃█▆▅█▆▃▃▅▅▃▅▁▁▁▃▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▃▃ ▃

21.2 ms Histogram: frequency by time 117 ms <

Memory estimate: 2.11 MiB, allocs estimate: 62283.

https://github.com/Tractables/Dice.jl/blob/daa0fc51f50e9ffbb8cf0be440618d8af92247cd/qc/benchmarks/stlc_faster_200samples.jl#L37

# We found that sampling from the BDD, like sampling from the computation graph, also took ~160s for 200 well-typed samples.
# NOTE: THE OTHER TIMINGS IN THIS FILE ARE WRONG.
# The part of this file we mainly care about is `sample_one_as_dist_compile`.

using Revise
using Dice
using BenchmarkTools
using ProfileView

"""
    comp_graph_size(roots)

Return how many times `Dice.foreach_down(roots)` invokes its callback, i.e. the
number of nodes it visits when started from `roots`.
"""
function comp_graph_size(roots)
    visited = Ref(0)
    # The visited node itself is irrelevant; only the visit is tallied.
    tally(_) = (visited[] += 1)
    Dice.foreach_down(tally, roots)
    return visited[] # 2040
end


include("benchmarks.jl")

# Generator configuration; passed both to `RunState` and to `generate` below.
generation_params = LangSiblingDerivedGenerator{STLC}(
    root_ty=Expr.t,
    ty_sizes=[Expr.t=>5, Typ.t=>2],
    stack_size=2,
    intwidth=3,
)

# Seed for the MersenneTwister handed to `RunState`.
SEED = 0
out_dir="/tmp"
# The log stream is opened on /dev/null, so whatever is written to it is discarded.
log_path="/dev/null"
rs = RunState(Valuation(), Dict{String,ADNode}(), open(log_path, "w"), out_dir, MersenneTwister(SEED), nothing,generation_params)

# Run the generator once; its `value` field is the distribution `g` used by the rest of the script.
generation::Generation = generate(rs, generation_params)

g::Dist = generation.value

# Sample some tree until it's valid (TODO: make this faster)
# NOTE: `a` is rebound inside the timed sampling block further down.
a = ADComputer(rs.var_vals)

# Number of well-typed samples the sampling loop collects before it stops.
NUM_SAMPLES = 200

"""
    sample_one_as_dist_compile(c::BDDCompiler, a::ADComputer, d::Dist, roots)

Draw one sample of `d` by walking the BDDs that `c` compiles for `roots`.

Every BDD level is decided at most once per call: its flip probability is
obtained with `compute(a, c.level_to_flip[level].prob)` and compared against
`rand()` (the global RNG). The Boolean chosen for each root is stored in a
dictionary keyed by that root, and the result of
`Dice.frombits_as_dist(d, bits)` is returned.
"""
function sample_one_as_dist_compile(c::BDDCompiler, a::ADComputer, d::Dist, roots)
    # Memo tables that live for this single call only.
    node_outcome = Dict{CuddNode,Bool}()
    level_outcome = Dict{Integer, Bool}()
    # The constant nodes map directly to their truth values.
    node_outcome[Dice.constant(c.mgr, true)] = true
    node_outcome[Dice.constant(c.mgr, false)] = false

    # Decide (once per call) which way the flip at `level` goes.
    function flip_at(level::Integer)
        return get!(level_outcome, level) do
            p = compute(a, c.level_to_flip[level].prob)
            rand() < p
        end
    end

    # Follow the high/low edge selected by the sampled flip until a memoized node is hit.
    function resolve(node::CuddNode)
        return get!(node_outcome, node) do
            branch = flip_at(Dice.level(node)) ? Dice.high(node) : Dice.low(node)
            resolve(branch)
        end
    end

    # Anything that is not yet a BDD node is compiled first, then resolved.
    resolve(x::AnyBool) = resolve(compile(c, x))

    bits = Dict()
    for root in roots
        bits[root] = resolve(root)
    end
    return Dice.frombits_as_dist(d, bits)
end


# Return `true` when `typecheck(e)` matches `Some(_)` and `false` when it matches `None()`.
# `e` is asserted to satisfy `isdeterministic`; a non-deterministic `e` fails the assertion.
function wellTyped(e)
    @assert isdeterministic(e)
    # Only whether type checking produced a result matters; the payload of `Some` is ignored.
    @match typecheck(e) [
        Some(_) -> true,
        None() -> false,
    ]
end

# Counts every sampling attempt, including those rejected by `wellTyped`.
retries = Ref(0)
#== @benchmark ==# @time begin
    samples = []
    d = g
    roots = Dice.tobits(d)
    # `c` and `a` are (re)bound here; `c` is rebound again later when compiling `eqs`.
    c = BDDCompiler(roots)
    a = Dice.ADComputer(rs.var_vals)
    # Rejection sampling: keep drawing until NUM_SAMPLES samples have passed `wellTyped`.
    while length(samples) < NUM_SAMPLES
        retries[] += 1
        s = sample_one_as_dist_compile(c, a, d, roots)
        if wellTyped(s)
            push!(samples, s)
        end
    end
end
# 174s, 155s, 281
retries[] # 607, 556

# Build `prob_equals(g, sample)` for every collected sample and time the construction.
@time eqs = [
    prob_equals(g, sample)
    for sample in samples
]
# 27 sec

# Node counts reported by `comp_graph_size` for `eqs` and for the bits of `g`.
comp_graph_size(eqs) # 2040
comp_graph_size(Dice.tobits(g)) # 16825

# @benchmark prob_equals(g, samples[1])
# BenchmarkTools.Trial: 24 samples with 1 evaluation.
#  Range (min … max):  104.068 ms … 592.294 ms  ┊ GC (min … max): 0.00% … 0.00%
#  Time  (median):     168.251 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   186.669 ms ±  97.022 ms  ┊ GC (mean ± σ):  0.00% ± 0.00%
#    ▃▃█  ▃▃▃ ▃▃                                                   
#   ▇███▁▁███▇██▁▇▇▁▁▇▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
#   104 ms           Histogram: frequency by time          592 ms <
#  Memory estimate: 6.51 MiB, allocs estimate: 235681.

# Rebinds the global `c` (bound to a compiler over the bits of `g` in the sampling block)
# to a compiler over `eqs`.
@time c = BDDCompiler(eqs)
# 0.28553 s, 0.23, 0.1

@time w = WMC(c)
# instant

@time l = Dice.LogPrExpander(w)
# instant

# Each sample contributes a 2-element vector; `sum` adds these vectors element-wise and the
# destructuring assignment splits the total into `loss` and `actual_loss`.
@time loss, actual_loss = sum(
    begin
        # NOTE(review): `prob_equals(g, sample)` is rebuilt here instead of reusing `eqs` —
        # confirm that this is intended.
        lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample)))
        # First entry multiplies the expanded term by `compute(a, lpr_eq)`; second is the term itself.
        [lpr_eq * compute(a, lpr_eq), lpr_eq]
    end
    for sample in samples
)
# 6s, 31s, 9s

length(l.cache) # 3473

# Benchmark `differentiate` on `rs.var_vals` with `Derivs([loss => 1.])`.
# NOTE(review): `rs` and `loss` are non-constant globals that are not `$`-interpolated into
# `@benchmark`, so the measurement may include global-variable access overhead.
@benchmark vals, derivs = differentiate(
    rs.var_vals,
    Derivs([loss => 1.])
)
# BenchmarkTools.Trial: 113 samples with 1 evaluation.
#  Range (min … max):  21.161 ms … 130.908 ms  ┊ GC (min … max): 0.00% … 30.08%
#  Time  (median):     39.090 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   44.209 ms ±  18.477 ms  ┊ GC (mean ± σ):  0.79% ±  2.83%
#      ▁▄ ▂▅█▁▅▇     ▂  ▁                                         
#   ▃▅▁██▆███████▆▅█▃█▆▅█▆▃▃▅▅▃▅▁▁▁▃▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▃▃ ▃
#   21.2 ms         Histogram: frequency by time          117 ms <
#  Memory estimate: 2.11 MiB, allocs estimate: 62283.

# Count the nodes `Dice.foreach_down` visits when started from `loss`.
ct = Ref(0)
Dice.foreach_down(loss) do _
    ct[] += 1
end
ct[] # 3872

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions