@@ -132,6 +132,28 @@ def model_with_dims():
132
132
return pmodel , compute_graph , plates
133
133
134
134
135
+ def model_unnamed_observed_node ():
136
+ """
137
+ Model at the source of the following issue: https://github.com/pymc-devs/pymc/issues/5892
138
+ """
139
+ data = [- 1 , 0 , 0.5 , 1 ]
140
+
141
+ with pm .Model () as model :
142
+ mu = pm .Normal (name = "mu" , mu = 0.0 , sigma = 5.0 )
143
+ y = pm .Normal (name = "y" , mu = mu , sigma = 3.0 , observed = data )
144
+
145
+ compute_graph = {
146
+ "mu" : set (),
147
+ "y" : {"mu" },
148
+ }
149
+ plates = {
150
+ "" : {"mu" },
151
+ "4" : {"y" },
152
+ }
153
+
154
+ return model , compute_graph , plates
155
+
156
+
135
157
class BaseModelGraphTest (SeededTest ):
136
158
model_func = None
137
159
@@ -202,21 +224,16 @@ def model_with_different_descendants():
202
224
return pmodel2
203
225
204
226
205
- class TestParents :
206
- @pytest .mark .parametrize (
207
- "var_name, parent_names" ,
208
- [
209
- ("L" , {"pred" }),
210
- ("pred" , {"intermediate" }),
211
- ("intermediate" , {"a" , "b" }),
212
- ("c" , {"a" , "b" }),
213
- ("a" , set ()),
214
- ("b" , set ()),
215
- ],
216
- )
217
- def test_get_parent_names (self , var_name , parent_names ):
218
- mg = ModelGraph (model_with_different_descendants ())
219
- mg .get_parent_names (mg .model [var_name ]) == parent_names
227
+ class TestImputationModel (BaseModelGraphTest ):
228
+ model_func = model_with_imputations
229
+
230
+
231
+ class TestModelWithDims (BaseModelGraphTest ):
232
+ model_func = model_with_dims
233
+
234
+
235
+ class TestUnnamedObservedNodes (BaseModelGraphTest ):
236
+ model_func = model_unnamed_observed_node
220
237
221
238
222
239
class TestVariableSelection :
@@ -260,11 +277,3 @@ def test_subgraph(self, var_names, vars_to_plot, compute_graph):
260
277
mg = ModelGraph (model_with_different_descendants ())
261
278
assert set (mg .vars_to_plot (var_names = var_names )) == set (vars_to_plot )
262
279
assert mg .make_compute_graph (var_names = var_names ) == compute_graph
263
-
264
-
265
- class TestImputationModel (BaseModelGraphTest ):
266
- model_func = model_with_imputations
267
-
268
-
269
- class TestModelWithDims (BaseModelGraphTest ):
270
- model_func = model_with_dims
0 commit comments