Skip to content

Commit 4bf2ac5

Browse files
committed
Simplify Libtask extension
1 parent 6dff5f8 commit 4bf2ac5

File tree

3 files changed

+39
-21
lines changed

3 files changed

+39
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ AdvancedPSLibtaskExt = "Libtask"
2121
[compat]
2222
AbstractMCMC = "2, 3, 4, 5"
2323
Distributions = "0.23, 0.24, 0.25"
24-
Libtask = "0.9"
24+
Libtask = "0.9.2"
2525
Random = "<0.0.1, 1"
2626
Random123 = "1.3"
2727
Requires = "1.0"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ else
1717
end
1818

1919
# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20-
# and sometimes both an RNG and other information. In Turing.jl this other information
20+
# and sometimes both an RNG and other information. In Turing.jl the other information
2121
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
2222
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
2323
struct TapedGlobals{RngType}
@@ -49,6 +49,29 @@ end
4949

5050
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
5151

52+
"""Get the RNG from a `LibtaskTrace`."""
53+
function get_rng(trace::LibtaskTrace)
54+
return trace.model.ctask.taped_globals.rng
55+
end
56+
57+
"""Set the RNG for a `LibtaskTrace`."""
58+
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG)
59+
taped_globals = trace.model.ctask.taped_globals
60+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
61+
trace.rng = rng
62+
return trace
63+
end
64+
65+
"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
66+
function set_other_global!(trace::LibtaskTrace, other)
67+
rng = get_rng(trace)
68+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
69+
return trace
70+
end
71+
72+
"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
73+
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other
74+
5275
function AdvancedPS.Trace(
5376
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
5477
)
@@ -58,30 +81,25 @@ end
5881
# step to the next observe statement and
5982
# return the log probability of the transition (or nothing if done)
6083
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
61-
taped_globals = trace.model.ctask.taped_globals
62-
rng = taped_globals.rng
84+
rng = get_rng(trace)
6385
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
6486
AdvancedPS.inc_counter!(rng)
65-
66-
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
67-
trace.rng = rng
68-
87+
set_rng!(trace, rng)
6988
# Move to next step
7089
return Libtask.consume(trace.model.ctask)
7190
end
7291

73-
# create a backward reference in task_local_storage
74-
function AdvancedPS.addreference!(task::Libtask.TapedTask, trace::LibtaskTrace)
75-
rng = task.taped_globals.rng
76-
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
77-
return task
92+
"""
93+
Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the
94+
taped globals.
95+
"""
96+
function AdvancedPS.addreference!(trace::LibtaskTrace)
97+
set_other_global!(trace, trace)
98+
return trace
7899
end
79100

80101
function AdvancedPS.update_rng!(trace::LibtaskTrace)
81-
taped_globals = trace.model.ctask.taped_globals
82-
new_rng = deepcopy(taped_globals.rng)
83-
trace.rng = new_rng
84-
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
102+
set_rng!(trace, deepcopy(get_rng(trace)))
85103
return trace
86104
end
87105

@@ -91,19 +109,19 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
91109
AdvancedPS.update_rng!(newtrace)
92110
isref && AdvancedPS.delete_retained!(newtrace.model.f)
93111
isref && delete_seeds!(newtrace)
112+
AdvancedPS.addreference!(newtrace)
94113
return newtrace
95114
end
96115

97116
# PG requires keeping all randomness for the reference particle
98117
# Create new task and copy randomness
99118
function AdvancedPS.forkr(trace::LibtaskTrace)
100-
taped_globals = trace.model.ctask.taped_globals
101-
rng = taped_globals.rng
119+
rng = get_rng(trace)
102120
newf = AdvancedPS.reset_model(trace.model.f)
103121
Random123.set_counter!(rng, 1)
104122
trace.rng = rng
105123

106-
ctask = Libtask.TapedTask(TapedGlobals(rng, taped_globals.other), newf)
124+
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
107125
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
108126

109127
# add backward reference

test/container.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@
159159
end
160160

161161
trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
162-
AdvancedPS.addreference!(trace.model.ctask, trace)
162+
AdvancedPS.addreference!(trace)
163163

164164
@test AdvancedPS.advance!(trace, false) === objectid(trace)
165165
end

0 commit comments

Comments
 (0)