13
13
import pickle
14
14
import shutil
15
15
import sys
16
- from collections import OrderedDict
17
16
from tempfile import mkdtemp
18
17
19
18
import numpy as np
@@ -764,11 +763,9 @@ def test_output_padding(self):
764
763
b = shared (np .random .default_rng (utt .fetch_seed ()).random ((5 , 4 )))
765
764
766
765
def inner_func (a ):
767
- return a + 1 , OrderedDict ([( b , 2 * b )])
766
+ return a + 1 , { b : 2 * b }
768
767
769
- out , updates = scan (
770
- inner_func , outputs_info = [OrderedDict ([("initial" , init_a )])], n_steps = 1
771
- )
768
+ out , updates = scan (inner_func , outputs_info = [{"initial" : init_a }], n_steps = 1 )
772
769
out = out [- 1 ]
773
770
assert out .type .ndim == a .type .ndim
774
771
assert updates [b ].type .ndim == b .type .ndim
@@ -934,7 +931,7 @@ def test_only_shared_no_input_no_output(self):
934
931
state = shared (v_state , "vstate" )
935
932
936
933
def f_2 ():
937
- return OrderedDict ([( state , 2 * state )])
934
+ return { state : 2 * state }
938
935
939
936
n_steps = iscalar ("nstep" )
940
937
output , updates = scan (
@@ -968,7 +965,7 @@ def test_shared_updates(self):
968
965
X = shared (np .array (1 ))
969
966
970
967
out , updates = scan (
971
- lambda : OrderedDict ([( X , (X + 1 ))]) ,
968
+ lambda : { X : (X + 1 )} ,
972
969
outputs_info = [],
973
970
non_sequences = [],
974
971
sequences = [],
@@ -984,7 +981,7 @@ def test_shared_memory_aliasing_updates(self):
984
981
y = shared (np .array (1 ))
985
982
986
983
out , updates = scan (
987
- lambda : OrderedDict ([( x , x + 1 ), ( y , x )]) ,
984
+ lambda : { x : x + 1 , y : x } ,
988
985
outputs_info = [],
989
986
non_sequences = [],
990
987
sequences = [],
@@ -1914,7 +1911,7 @@ def test_grad_numeric_shared(self):
1914
1911
shared_var = shared (np .float32 (1.0 ))
1915
1912
1916
1913
def inner_fn ():
1917
- return [], OrderedDict ([( shared_var , shared_var + np .float32 (1.0 ))])
1914
+ return [], { shared_var : shared_var + np .float32 (1.0 )}
1918
1915
1919
1916
_ , updates = scan (
1920
1917
inner_fn , n_steps = 10 , truncate_gradient = - 1 , go_backwards = False
@@ -2746,7 +2743,7 @@ def one_step(x_t, h_tm1, W):
2746
2743
2747
2744
v1 = shared (np .ones (5 , dtype = config .floatX ))
2748
2745
v2 = shared (np .ones ((5 , 5 ), dtype = config .floatX ))
2749
- shapef = function ([W ], expr , givens = OrderedDict ([( initial , v1 ), ( inpt , v2 )]) )
2746
+ shapef = function ([W ], expr , givens = { initial : v1 , inpt : v2 } )
2750
2747
# First execution to cache n_steps
2751
2748
shapef (np .ones ((5 , 5 ), dtype = config .floatX ))
2752
2749
@@ -2755,7 +2752,7 @@ def one_step(x_t, h_tm1, W):
2755
2752
f = function (
2756
2753
[W , inpt ],
2757
2754
d_cost_wrt_W ,
2758
- givens = OrderedDict ([( initial , shared (np .zeros (5 )))]) ,
2755
+ givens = { initial : shared (np .zeros (5 ))} ,
2759
2756
)
2760
2757
2761
2758
rval = np .asarray ([[5187989 ] * 5 ] * 5 , dtype = config .floatX )
@@ -2956,7 +2953,7 @@ def onestep(x, x_tm4):
2956
2953
2957
2954
seq = matrix ()
2958
2955
initial_value = shared (np .zeros ((4 , 1 ), dtype = config .floatX ))
2959
- outputs_info = [OrderedDict ([( "initial" , initial_value ), ( "taps" , [- 4 ])]) , None ]
2956
+ outputs_info = [{ "initial" : initial_value , "taps" : [- 4 ]} , None ]
2960
2957
results , updates = scan (fn = onestep , sequences = seq , outputs_info = outputs_info )
2961
2958
2962
2959
f = function ([seq ], results [1 ])
@@ -2979,10 +2976,10 @@ def onestep(x, x_tm4):
2979
2976
2980
2977
seq = matrix ()
2981
2978
initial_value = shared (np .zeros ((4 , 1 ), dtype = config .floatX ))
2982
- outputs_info = [OrderedDict ([( "initial" , initial_value ), ( "taps" , [- 4 ])]) , None ]
2979
+ outputs_info = [{ "initial" : initial_value , "taps" : [- 4 ]} , None ]
2983
2980
results , _ = scan (fn = onestep , sequences = seq , outputs_info = outputs_info )
2984
2981
sharedvar = shared (np .zeros ((1 , 1 ), dtype = config .floatX ))
2985
- updates = OrderedDict ([( sharedvar , results [0 ][- 1 :])])
2982
+ updates = { sharedvar : results [0 ][- 1 :]}
2986
2983
2987
2984
f = function ([seq ], results [1 ], updates = updates )
2988
2985
0 commit comments