@@ -101,11 +101,12 @@ def test_truncation_specialized_op(shape_info):
101
101
102
102
@pytest .mark .parametrize ("lower, upper" , [(- 1 , np .inf ), (- 1 , 1.5 ), (- np .inf , 1.5 )])
103
103
@pytest .mark .parametrize ("op_type" , ["icdf" , "rejection" ])
104
- def test_truncation_continuous_random (op_type , lower , upper ):
104
+ @pytest .mark .parametrize ("scalar" , [True , False ])
105
+ def test_truncation_continuous_random (op_type , lower , upper , scalar ):
105
106
loc = 0.15
106
107
scale = 10
107
108
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
108
- x = normal_op (loc , scale , name = "x" , size = 100 )
109
+ x = normal_op (loc , scale , name = "x" , size = () if scalar else ( 100 ,) )
109
110
110
111
xt = Truncated .dist (x , lower = lower , upper = upper )
111
112
assert isinstance (xt .owner .op , TruncatedRV )
@@ -134,7 +135,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
134
135
assert np .unique (xt_draws ).size == xt_draws .size
135
136
else :
136
137
with pytest .raises (TruncationError , match = "^Truncation did not converge" ):
137
- draw (xt )
138
+ draw (xt , draws = 100 if scalar else 1 )
138
139
139
140
140
141
@pytest .mark .parametrize ("lower, upper" , [(- 1 , np .inf ), (- 1 , 1.5 ), (- np .inf , 1.5 )])
0 commit comments