-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
47 lines (32 loc) · 1.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# utils.py
# author: Playinf
# email: [email protected]
import tensorflow as tf
from tensorflow.python.util import nest
def function(inputs, outputs, updates=None):
    """Compile graph tensors into a plain Python callable backed by ``Session.run``.

    Args:
        inputs: A feedable tensor (typically a placeholder) or a list/tuple
            of them. A value that is not a list/tuple is wrapped in a
            one-element list.
        outputs: A tensor or a (possibly nested) list/tuple of tensors to
            fetch. A value that is not a list/tuple is wrapped in a
            one-element list.
        updates: Optional extra fetch (e.g. an op, a tensor, or a list of
            them) that is run together with ``outputs`` on every call; its
            fetched value is discarded. ``None`` or an empty list/tuple
            means "no updates".

    Returns:
        A callable ``func(*values, session=None)``. The positional ``values``
        are flattened with ``nest.flatten`` and fed, in order, to the
        flattened ``inputs``. ``session`` selects the session to run in;
        when it is omitted or ``None`` the default session is used. Other
        keyword arguments are ignored. The callable returns the fetched
        outputs packed into the structure of ``outputs``; when that
        structure has exactly one top-level element, the element itself is
        returned instead of a one-element list.

    Raises (from the returned callable):
        ValueError: if more values are supplied (after flattening) than
            there are inputs, or if no session was given and no default
            session is registered.
    """
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    # Decide once whether there are updates. A plain ``if updates:`` would
    # call bool() on a single tf.Tensor, which TensorFlow rejects with a
    # TypeError; empty lists/tuples still count as "no updates".
    if updates is None:
        has_updates = False
    elif isinstance(updates, (list, tuple)):
        has_updates = len(updates) > 0
    else:
        has_updates = True

    fetches = nest.flatten(outputs)
    if has_updates:
        # Appended as ONE (possibly nested) fetch so its result occupies a
        # single trailing slot that is sliced off after the run.
        fetches.append(updates)

    # The input structure is fixed, so flatten it once rather than per call.
    flat_inputs = nest.flatten(inputs)

    def func(*values, **option):
        flat_values = nest.flatten(values)
        if len(flat_values) > len(flat_inputs):
            # zip() would silently drop the surplus, hiding a misaligned call
            # (e.g. a Python list argument flattened into several values).
            raise ValueError(
                "Expected at most %d input value(s), got %d after flattening"
                % (len(flat_inputs), len(flat_values)))
        # Fewer values than inputs is tolerated: unfed inputs are left to the
        # graph (a required placeholder left unfed will make the run fail).
        feed_dict = dict(zip(flat_inputs, flat_values))

        session = option.get("session")
        sess = session if session is not None else tf.get_default_session()
        if sess is None:
            raise ValueError(
                "No session was passed via session= and no default session "
                "is registered")

        results = sess.run(fetches, feed_dict=feed_dict)
        if has_updates:
            # Drop the trailing slot holding the update fetch's result.
            results = results[:-1]
        results = nest.pack_sequence_as(outputs, results)
        if len(results) == 1:
            return results[0]
        return results

    return func