17
17
end
18
18
19
19
# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20
- # and sometimes both an RNG and other information. In Turing.jl this other information
20
+ # and sometimes both an RNG and other information. In Turing.jl the other information
21
21
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22
22
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23
23
struct TapedGlobals{RngType}
49
49
50
50
const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
51
51
52
+ """ Get the RNG from a `LibtaskTrace`."""
53
+ function get_rng (trace:: LibtaskTrace )
54
+ return trace. model. ctask. taped_globals. rng
55
+ end
56
+
57
+ """ Set the RNG for a `LibtaskTrace`."""
58
+ function set_rng! (trace:: LibtaskTrace , rng:: Random.AbstractRNG )
59
+ taped_globals = trace. model. ctask. taped_globals
60
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, taped_globals. other))
61
+ trace. rng = rng
62
+ return trace
63
+ end
64
+
65
+ """ Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
66
+ function set_other_global! (trace:: LibtaskTrace , other)
67
+ rng = get_rng (trace)
68
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
69
+ return trace
70
+ end
71
+
72
+ """ Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
73
+ get_other_global (trace:: LibtaskTrace ) = trace. model. ctask. taped_globals. other
74
+
52
75
function AdvancedPS. Trace (
53
76
model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
54
77
)
58
81
# step to the next observe statement and
59
82
# return the log probability of the transition (or nothing if done)
60
83
function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
61
- taped_globals = trace. model. ctask. taped_globals
62
- rng = taped_globals. rng
84
+ rng = get_rng (trace)
63
85
isref ? AdvancedPS. load_state! (rng) : AdvancedPS. save_state! (rng)
64
86
AdvancedPS. inc_counter! (rng)
65
-
66
- Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, taped_globals. other))
67
- trace. rng = rng
68
-
87
+ set_rng! (trace, rng)
69
88
# Move to next step
70
89
return Libtask. consume (trace. model. ctask)
71
90
end
72
91
73
- # create a backward reference in task_local_storage
74
- function AdvancedPS. addreference! (task:: Libtask.TapedTask , trace:: LibtaskTrace )
75
- rng = task. taped_globals. rng
76
- Libtask. set_taped_globals! (task, TapedGlobals (rng, trace))
77
- return task
92
+ """
93
+ Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the
94
+ taped globals.
95
+ """
96
+ function AdvancedPS. addreference! (trace:: LibtaskTrace )
97
+ set_other_global! (trace, trace)
98
+ return trace
78
99
end
79
100
80
101
function AdvancedPS. update_rng! (trace:: LibtaskTrace )
81
- taped_globals = trace. model. ctask. taped_globals
82
- new_rng = deepcopy (taped_globals. rng)
83
- trace. rng = new_rng
84
- Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (new_rng, taped_globals. other))
102
+ set_rng! (trace, deepcopy (get_rng (trace)))
85
103
return trace
86
104
end
87
105
@@ -91,19 +109,19 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
91
109
AdvancedPS. update_rng! (newtrace)
92
110
isref && AdvancedPS. delete_retained! (newtrace. model. f)
93
111
isref && delete_seeds! (newtrace)
112
+ AdvancedPS. addreference! (newtrace)
94
113
return newtrace
95
114
end
96
115
97
116
# PG requires keeping all randomness for the reference particle
98
117
# Create new task and copy randomness
99
118
function AdvancedPS. forkr (trace:: LibtaskTrace )
100
- taped_globals = trace. model. ctask. taped_globals
101
- rng = taped_globals. rng
119
+ rng = get_rng (trace)
102
120
newf = AdvancedPS. reset_model (trace. model. f)
103
121
Random123. set_counter! (rng, 1 )
104
122
trace. rng = rng
105
123
106
- ctask = Libtask. TapedTask (TapedGlobals (rng, taped_globals . other ), newf)
124
+ ctask = Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace) ), newf)
107
125
new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
108
126
109
127
# add backward reference
0 commit comments