-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathtfdeploy.py
2238 lines (1714 loc) · 54.7 KB
/
tfdeploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-

"""
Deploy tensorflow graphs for fast evaluation and export to tensorflow-less environments running
numpy.
"""


__author__ = "Marcel Rieger"
__copyright__ = "Copyright 2016-2025, Marcel Rieger"
__credits__ = ["Marcel Rieger"]
__contact__ = "https://github.com/riga/tfdeploy"
__license__ = "BSD-3-Clause"
__status__ = "Development"
__version__ = "0.4.2"

# public API exported by "from tfdeploy import *"; all exceptions raised by public classes are
# listed so that callers can catch them
__all__ = ["Model", "Tensor", "Operation", "Ensemble",
           "UnknownOperationException", "OperationMismatchException",
           "InvalidImplementationException", "UnknownImplementationException",
           "UnknownEnsembleMethodException", "EnsembleMismatchException",
           "ScipyOperationException",
           "reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op",
           "IMPL_NUMPY", "IMPL_SCIPY", "IMPLS",
           "METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS",
           "HAS_SCIPY"]
# imports for core code
import os
import re
from uuid import uuid4
from functools import reduce
try:
# python 2
import cPickle as pickle
except ImportError:
# python 3
import pickle
# third-party imports
import numpy as np
# metaclass decorator from six package, credits to Benjamin Peterson
def add_metaclass(metaclass):
    """
    Class decorator that re-creates the decorated class with *metaclass* as its metaclass, which
    works with both python 2 and python 3 despite their different metaclass syntax. The namespace
    of the decorated class is copied and passed to *metaclass* together with the class name and
    bases.
    """
    def wrapper(cls):
        orig_vars = cls.__dict__.copy()

        # slot descriptors and the __dict__ / __weakref__ descriptors belong to the old class and
        # are not carried over into the new one
        slots = orig_vars.get("__slots__")
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop("__dict__", None)
        orig_vars.pop("__weakref__", None)

        # in python 3, __qualname__ is not part of the class __dict__, so pass it explicitly to
        # keep the qualified name of nested classes (python 2 classes have no __qualname__)
        if hasattr(cls, "__qualname__"):
            orig_vars["__qualname__"] = cls.__qualname__

        return metaclass(cls.__name__, cls.__bases__, orig_vars)

    return wrapper
class Model(object):
    """
    A trained model that contains one or more converted tensorflow graphs. When *path* is set, a
    previously saved model is loaded from that path. Usage:

    .. code-block:: python

        import tensorflow as tf
        import tfdeploy as td

        # build your graph, use names for input and output tensors
        sess = tf.Session()
        x = tf.placeholder("float", shape=[None, 784], name="input")
        W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
        b = tf.Variable(tf.zeros([100]))
        y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
        sess.run(tf.initialize_all_variables())

        # ... training ...

        # create a model and save it to disk
        model = td.Model()
        model.add(y, sess)
        model.save("/path/to/model.pkl")

    And then in another file:

    .. code-block:: python

        import tfdeploy as td
        import numpy as np

        model = td.Model("/path/to/model.pkl")
        inp, outp = model.get("input", "output")

        batch = np.random.rand(10000, 784)
        result = outp.eval({inp: batch})

    .. py:attribute:: roots

       Contained root tensors in a dict mapped to a key.
    """

    # matches a trailing value index such as ":0" in "scope/name:0"; a raw string is required so
    # that "\d" reaches the regex engine instead of being an invalid string escape
    value_index_cre = re.compile(r":\d+$")

    # value index appended to tensor names that do not specify one
    default_value_index = 0

    def __init__(self, path=None):
        super(Model, self).__init__()

        self.roots = {}

        # load when desired
        if path is not None:
            self.load(path)

    def get(self, *names, **kwargs):
        """ get(*names, key=None)
        Returns one or more :py:class:`Tensor` instances given by *names* using a deep lookup within
        the model. If *key* is not *None*, only the root tensor with that *key* is traversed (a
        *KeyError* is raised when there is no such root). *None* is returned when no tensor was
        found. In case a tensor is passed, its name is used for the lookup. For a single name, the
        result is returned directly, otherwise as a tuple.
        """
        tensors = tuple(self._get(name, **kwargs) for name in names)
        return tensors[0] if len(names) == 1 else tensors

    def _get(self, name, key=None):
        if isinstance(name, Tensor):
            name = name.name

        # append the default value_index if there's none
        if not self.value_index_cre.search(name):
            name += ":%d" % self.default_value_index

        # return the first occurrence of a tensor with that name
        if key is not None:
            return self.roots[key].get(name)
        else:
            return reduce(lambda t1, t2: t1 or t2.get(name), self.roots.values(), None)

    def __getitem__(self, name):
        return self.get(name)

    def __contains__(self, name):
        return self.get(name) is not None

    def add(self, tensor, tf_sess=None, key=None, **kwargs):
        """
        Adds a new root *tensor* for a *key* which, if *None*, defaults to a consecutive number.
        When *tensor* is not an instance of :py:class:`Tensor` but an instance of
        ``tensorflow.Tensor``, it is converted first. In that case, *tf_sess* should be a valid
        tensorflow session and *kwargs* are forwarded to the :py:class:`Tensor` constructor.
        """
        if not isinstance(tensor, Tensor):
            tensor = Tensor(tensor, tf_sess, **kwargs)

        if key is None:
            if len(self.roots) == 0:
                key = 0
            else:
                key = max(self.roots.keys()) + 1

        self.roots[key] = tensor

    def load(self, path):
        """
        Loads all tensors from a file defined by *path* and adds them to the root set.

        .. warning::

           The file is read with :py:mod:`pickle`, which can execute arbitrary code while loading.
           Only load model files from trusted sources.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        # SECURITY: unpickling untrusted data is unsafe, see the warning above
        with open(path, "rb") as f:
            roots = pickle.load(f)

        for key, tensor in roots.items():
            self.add(tensor, key=key)

    def save(self, path):
        """
        Saves all tensors of the root set to a file defined by *path*.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        with open(path, "wb") as f:
            pickle.dump(self.roots, f)
class TensorRegister(type):
    """
    Metaclass of :py:class:`Tensor` that caches instances, keyed by the tensorflow tensor they
    were created from. Creating a :py:class:`Tensor` for an already converted tensorflow tensor
    returns the cached instance, and further constructor arguments are ignored in that case.
    """

    # tensorflow tensor -> converted instance
    instances = {}

    def __call__(cls, tf_tensor, *args, **kwargs):
        cache = cls.instances
        if tf_tensor in cache:
            return cache[tf_tensor]

        # not converted yet, build a new instance and remember it
        instance = super(TensorRegister, cls).__call__(tf_tensor, *args, **kwargs)
        cache[tf_tensor] = instance
        return instance
@add_metaclass(TensorRegister)
class Tensor(object):
    """
    Building block of a model. In *graph* terms, tensors represent the connections between nodes
    (ops) of a graph, and each tensor knows the op it results from. The conversion uses the
    tensorflow instances *tf_tensor* and *tf_sess*; *tf_feed_dict* can be set to evaluate the
    tensor's current value.

    .. py:attribute:: name

       The name of the tensor.

    .. py:attribute:: value_index

       The integer value index of this tensor, i.e., the position in the op's output list.

    .. py:attribute:: op

       The op instance that defines the value of this tensor. When created from a
       ``tensorflow.Placeholder``, a constant or a ``tensorflow.Variable/V2``, op will be *None*.

    .. py:attribute:: value

       The value of this tensor. When created from a ``tensorflow.Variable/V2`` or a constant,
       this is its evaluated value. For a placeholder, it is the value found in *tf_feed_dict*, if
       any. Otherwise, it is *None* until the tensor is evaluated the first time.
    """

    def __init__(self, tf_tensor, tf_sess, tf_feed_dict=None):
        super(Tensor, self).__init__()

        if not tf_sess:
            raise ValueError("bad tensorflow session: %s" % tf_sess)

        self.name = tf_tensor.name
        self.value_index = tf_tensor.value_index
        self.op = None
        self.value = None
        self.last_uuid = None

        op_type = tf_tensor.op.type
        if op_type in ("Variable", "VariableV2", "Const"):
            # variables and constants carry their value, so read it right away
            self.value = tf_tensor.eval(session=tf_sess, feed_dict=tf_feed_dict)
        elif op_type == "Placeholder":
            # placeholders only have a value when it is fed
            if tf_feed_dict is not None and tf_tensor in tf_feed_dict:
                self.value = tf_feed_dict[tf_tensor]
        else:
            # every other tensor is produced by an op that must be converted as well
            self.op = Operation.new(tf_tensor.op, tf_sess, tf_feed_dict=tf_feed_dict)

    def get(self, *names):
        """
        Returns one or more tensors given by *names* using a deep lookup within the inputs of the
        op. *This* tensor is returned when its own name matches. *None* is returned for names
        without a matching tensor.
        """
        found = tuple(self._get(name) for name in names)
        if len(names) == 1:
            return found[0]
        return found

    def _get(self, name):
        if self.name == name:
            return self
        if self.op is None:
            return None
        return self.op.get(name)

    def eval(self, feed_dict=None, _uuid=None):
        """ eval(feed_dict=None)
        Returns the value of this tensor based on the evaluation of all dependent ops and tensors.
        Values of dependent tensors can be overwritten via *feed_dict*, a mapping of tensors to
        numpy arrays, which is passed down the evaluation chain.
        """
        # one uuid identifies one evaluation pass and is shared with all dependencies
        if _uuid is None:
            _uuid = uuid4()

        # a tensor used by several ops within the graph is only computed once per pass
        if _uuid == self.last_uuid:
            return self.value
        self.last_uuid = _uuid

        if feed_dict is None:
            feed_dict = {}

        # a fed value takes precedence over evaluating the op
        if self in feed_dict:
            self.value = feed_dict[self]
        elif self.op is not None:
            outputs = self.op.eval(feed_dict=feed_dict, _uuid=_uuid)
            self.value = outputs[self.value_index]

        return self.value

    def __call__(self, *args, **kwargs):
        # shorthand for eval
        return self.eval(*args, **kwargs)
class OperationRegister(type):
    """
    Metaclass of :py:class:`Operation`. It caches op instances, keyed by the tensorflow op they
    were created from. It also registers every class it creates under each entry of the class's
    *types*, so the op class for a tensorflow op type can be looked up quickly.
    """

    # tensorflow op type -> op class
    classes = {}
    # tensorflow op -> op instance
    instances = {}

    def __new__(metacls, classname, bases, classdict):
        # a class without explicit types represents the op type named like the class itself
        if "types" not in classdict:
            classdict["types"] = (classname,)

        cls = super(OperationRegister, metacls).__new__(metacls, classname, bases, classdict)

        metacls.classes.update((op_type, cls) for op_type in cls.types)

        return cls

    def __call__(cls, tf_op, *args, **kwargs):
        cache = cls.instances
        if tf_op not in cache:
            cache[tf_op] = super(OperationRegister, cls).__call__(tf_op, *args, **kwargs)
        return cache[tf_op]
# implementation types; the numeric type is the index into IMPL_NAMES
IMPLS = range(2)
IMPL_NUMPY, IMPL_SCIPY = IMPLS
IMPL_NAMES = ["numpy", "scipy"]
@add_metaclass(OperationRegister)
class Operation(object):
    """
    Building block of a model. In *graph* terms, operations (ops) represent nodes that are connected
    via tensors. It contains information on its input tensors. The conversion uses the
    (tensorflow) instance *tf_op*, all *args* and *kwargs* are forwarded to the :py:class:`Tensor`
    constructor for this op's input tensors. Op instances can have multiple implementations, i.e.,
    different methods that lead to equivalent results but might use additional third-party software
    such as *scipy*. To select a specific implementation, invoke :py:func:`use_impl`:

    .. code-block:: python

        # tell SomeOp to use the scipy implementation of its op logic
        SomeOp.use_impl(IMPL_SCIPY)

    See :py:func:`add_impl` for more info about adding new implementations.

    .. py:attribute:: types

       classmember

       A tuple containing the types of tensorflow ops that this op can represent.

    .. py:attribute:: unpack

       classmember

       If *True* (default), the values of evaluated input tensors are forwarded to *func* as single
       arguments, or, otherwise, as a list.

    .. py:attribute:: attrs

       classmember

       Names of the configuration attributes of the original tensorflow op in a tuple.

    .. py:attribute:: output_dtypes

       classmember / member

       On class level, a flag that decides whether the numpy dtypes of the op's outputs are passed
       to *func* as the last argument. On instance level, the list of these dtypes.

    .. py:attribute:: name

       The name of the op.

    .. py:attribute:: inputs

       Tuple of tensors that are input to this op. Their order is important as they are forwarded to
       *func* for evaluation.

    .. py:attribute:: kwargs

       List of configuration values, one per entry in *attrs*, that are passed to *func* after the
       input values.
    """

    impl = None
    impls = []

    types = ()
    unpack = True
    attrs = ()
    output_dtypes = False

    def __init__(self, tf_op, *args, **kwargs):
        super(Operation, self).__init__()

        # compare types as a cross check
        if tf_op.type not in self.types:
            raise OperationMismatchException("operation types do not match: %s, %s"
                                             % (self.types, tf_op.type))

        self.name = tf_op.name
        self.inputs = tuple(Tensor(tf_tensor, *args, **kwargs) for tf_tensor in tf_op.inputs)

        self.value = None
        self.last_uuid = None

        # store attribute values for calls to eval, attributes unknown to the tensorflow op
        # (get_attr raising a ValueError) are stored as None
        self.kwargs = []
        for attr in self.attrs:
            try:
                value = tf_op.get_attr(attr)
            except ValueError:
                value = None
            self.kwargs.append(value)

        # store the numpy output dtypes, they are passed to func in eval when the class-level
        # output_dtypes flag is True
        self.output_dtypes = [dtype_map[dtype] for dtype in tf_op._output_types]

    @classmethod
    def new(cls, tf_op, *args, **kwargs):
        """
        Factory function that takes a tensorflow op *tf_op* and returns an instance of the
        appropriate op class. *args* and *kwargs* are forwarded to the op constructor. Raises an
        exception of type :py:exc:`UnknownOperationException` in case the requested op type is not
        known.
        """
        if tf_op.type not in cls.classes:
            raise UnknownOperationException("unknown operation: %s" % tf_op.type)

        return cls.classes[tf_op.type](tf_op, *args, **kwargs)

    def set_attr(self, attr, value):
        """
        Overwrites the value of an attribute *attr* with a new *value*. Raises an *AttributeError*
        when *attr* is not one of the op's *attrs*.
        """
        if attr not in self.attrs:
            raise AttributeError("no attribute '%s' in op '%s'" % (attr, self.name))

        self.kwargs[self.attrs.index(attr)] = value

    def get(self, *names):
        """
        Returns one or more tensors given by *names* using a deep lookup within this op. *None* is
        returned when no tensor was found.
        """
        tensors = tuple(self._get(name) for name in names)
        return tensors[0] if len(names) == 1 else tensors

    def _get(self, name):
        # first match among the inputs, searched depth-first in input order
        return reduce(lambda t1, t2: t1 or t2.get(name), self.inputs, None)

    def eval(self, feed_dict=None, _uuid=None):
        """ eval(feed_dict=None)
        Returns the value of all output tensors in a tuple. See :py:meth:`Tensor.eval` for more
        info.
        """
        # set a cache uuid for this eval call
        if _uuid is None:
            _uuid = uuid4()

        # already cached? ops with several consumed outputs are evaluated only once per pass
        if _uuid == self.last_uuid:
            return self.value
        else:
            self.last_uuid = _uuid

        args = [t.eval(feed_dict=feed_dict, _uuid=_uuid) for t in self.inputs]
        if self.unpack:
            args.extend(self.kwargs)
        else:
            args = [args] + self.kwargs
        if self.__class__.output_dtypes:
            args.append(self.output_dtypes)

        self.value = self.func(*args)

        return self.value

    @classmethod
    def func(cls, *args):
        """
        The actual op logic. By default, the method call is forwarded to the
        implementation-specific version which is determined using *impl*. Overwrite this method in
        inheriting classes to disable this feature. Must return a tuple. Raises an
        :py:exc:`InvalidImplementationException` when *impl* is not a known implementation type.
        """
        if cls.impl == IMPL_NUMPY:
            return cls.func_numpy(*args)
        elif cls.impl == IMPL_SCIPY:
            return cls.func_scipy(*args)
        else:
            raise InvalidImplementationException(cls.impl)

    @staticmethod
    def func_numpy(*args):
        """
        Numpy implementation of the op logic. Returns a tuple.
        """
        raise NotImplementedError

    @staticmethod
    def func_scipy(*args):
        """
        Scipy implementation of the op logic. Returns a tuple.
        """
        raise NotImplementedError

    @classmethod
    def factory(cls, func=None, impl=IMPL_NUMPY, **kwargs):
        """ factory(func=None, impl=IMPL_NUMPY, **kwargs)
        Returns a new op class whose static function will be set to *func*. The name of *func* will
        also be the op class name. *impl* is the default implementation type of the op. *kwargs* are
        used to update the class dict of the newly created op class. When *func* is *None*, a
        decorator is returned that creates the op class from the function it decorates.
        """
        if impl not in IMPLS:
            raise InvalidImplementationException(impl)

        def wrapper(func):
            # each op class gets its own list of implementations
            classdict = {"impls": [], "func_" + IMPL_NAMES[impl]: staticmethod(func)}
            classdict.update(kwargs)

            cls = Operation.__class__(func.__name__, (Operation,), classdict)
            cls.__doc__ = func.__doc__
            cls.impls.append(impl)
            cls.use_impl(impl)

            return cls

        return wrapper if func is None else wrapper(func)

    @classmethod
    def use_impl(cls, impl):
        """
        Switches the implementation type to *impl*. Returns the previous type. Raises an
        :py:exc:`UnknownImplementationException` when *impl* was not added to this op.
        """
        if impl not in cls.impls:
            raise UnknownImplementationException(impl)

        prev = cls.impl
        cls.impl = impl
        return prev

    @classmethod
    def add_impl(cls, impl):
        """
        Decorator to add an additional implementation to this op. Example:

        .. code-block:: python

            # initial implementation using factory, defaults to numpy
            @Operation.factory
            def MyOp(a, b):
                # use numpy only
                return ...

            # also add a scipy implementation
            @MyOp.add_impl(IMPL_SCIPY)
            def MyOp(a, b):
                # also use scipy
                return ...
        """
        if impl not in IMPLS:
            raise InvalidImplementationException(impl)

        def wrapper(func):
            setattr(cls, "func_" + IMPL_NAMES[impl], staticmethod(func))
            # give the class its own list when it only inherits one, otherwise the parent's list
            # (e.g. Operation.impls, shared by all subclasses) would be modified
            if "impls" not in cls.__dict__:
                cls.impls = list(cls.impls)
            if impl not in cls.impls:
                cls.impls.append(impl)
            return cls

        return wrapper
# ensemble method types; the numeric type is the index into METHOD_NAMES
METHODS = range(4)
METHOD_MEAN, METHOD_MAX, METHOD_MIN, METHOD_CUSTOM = METHODS
METHOD_NAMES = ["mean", "max", "min", "custom"]
class Ensemble(object):
    """
    An ensemble is a wrapper around multiple models to compute ensemble values. It can be
    initialized with a list of model paths and an ensembling method that decides how to compute the
    merged value.

    .. code-block:: python

        # create the ensemble
        ensemble = Ensemble(["model1.pkl", "model2.pkl", ...], METHOD_MEAN)

        # get input and output tensors (which actually are TensorEnsemble instances)
        input, output = ensemble.get("input", "output")

        # evaluate the ensemble just like a normal model
        batch = ...
        value = output.eval({input: batch})

    If you want to use another method than ``METHOD_MEAN``, ``METHOD_MAX`` or ``METHOD_MIN``, use
    ``METHOD_CUSTOM`` and overwrite the ``func_custom`` method of the :py:class:`TensorEnsemble`
    instance.

    .. py:attribute:: models

       A list that contains all read models.

    .. py:attribute:: method

       The ensembling method.
    """

    def __init__(self, paths=None, method=METHOD_MEAN):
        """ __init__(paths=None, method=METHOD_MEAN)
        Raises an :py:exc:`UnknownEnsembleMethodException` when *method* is not in *METHODS*.
        """
        super(Ensemble, self).__init__()

        # check method
        if method not in METHODS:
            raise UnknownEnsembleMethodException(method)
        self.method = method

        # loaded models
        self.models = []

        # load when desired
        if paths is not None:
            self.load(paths)

    def get(self, *names, **kwargs):
        """ get(*names, key=None)
        Returns one or more :py:class:`TensorEnsemble` instances given by *names* using a deep
        lookup within all read models. Each returned tensor ensemble will have ``len(models)``
        tensors. If a model does not contain a specific tensor defined by a specific *name*, the
        associated ensemble tensor will contain a *None* for that model in its tensors. If *key* is
        not *None*, only the root tensors with that *key* are traversed.
        """
        # create empty tensor ensembles with our method
        tensor_ensembles = [TensorEnsemble([], self.method) for _ in names]

        # loop over models, collect and add tensors
        for model in self.models:
            tensors = model.get(*names, **kwargs)
            # Model.get returns the bare result (not a tuple) for a single name
            if not isinstance(tensors, tuple):
                tensors = (tensors,)
            for tensor_ensemble, tensor in zip(tensor_ensembles, tensors):
                tensor_ensemble.tensors.append(tensor)

        return tensor_ensembles[0] if len(names) == 1 else tuple(tensor_ensembles)

    def load(self, paths):
        """
        Loads models from a list of *paths*.
        """
        for path in paths:
            self.models.append(Model(path))
class TensorEnsemble(object):
    """
    A tensor ensemble basically contains a list of tensors that correspond to models of an
    :py:class:`Ensemble` instance.

    .. py:attribute:: tensors

       The list of contained tensors. Tensor *i* corresponds to model *i*.

    .. py:attribute:: method

       The ensembling method.
    """

    def __init__(self, tensors, method=METHOD_MEAN):
        super(TensorEnsemble, self).__init__()

        # check method
        if method not in METHODS:
            raise UnknownEnsembleMethodException(method)
        self.method = method

        self.tensors = list(tensors)

    def eval(self, feed_dict=None):
        """
        Evaluates all contained tensors using a *feed_dict* and returns the ensemble value. The keys
        of *feed_dict* must be tensor ensembles. Its values can be batches, i.e., numpy arrays, or
        lists or tuples of batches. In the latter case, these lists or tuples must have the same
        length as the list of stored tensors as they will be mapped. When *feed_dict* is *None*,
        the tensors are evaluated without fed values. Raises an
        :py:exc:`EnsembleMismatchException` when a key of *feed_dict* has a different number of
        tensors.
        """
        # the default of None must not reach the iteration below
        if feed_dict is None:
            feed_dict = {}

        # first, check that the length of all feed_dict keys match our own length
        for tensor_ensemble in feed_dict:
            if len(tensor_ensemble.tensors) != len(self.tensors):
                raise EnsembleMismatchException("incompatible lengths of tensors: %d, %d"
                                                % (len(self.tensors), len(tensor_ensemble.tensors)))

        # create a joined uuid, shared by the evaluation of all models
        _uuid = uuid4()

        # prepare one feed_dict per model, models that lack a fed tensor (None) are skipped
        feed_dicts = [{} for _ in range(len(self.tensors))]
        for tensor_ensemble, value in feed_dict.items():
            for i, tensor in enumerate(tensor_ensemble.tensors):
                if tensor is not None:
                    feed_dicts[i][tensor] = value[i] if isinstance(value, (list, tuple)) else value

        # eval all tensors
        values = [t.eval(feed_dict=d, _uuid=_uuid) for t, d in zip(self.tensors, feed_dicts)]

        # return the computed ensemble value
        return self.func(values)

    def __call__(self, *args, **kwargs):
        return self.eval(*args, **kwargs)

    def func(self, values):
        """
        The actual ensembling logic that combines multiple *values*. The method call is forwarded
        to the ensemble method-specific variant which is determined using *method*.
        """
        if self.method == METHOD_MEAN:
            return self.func_mean(values)
        elif self.method == METHOD_MAX:
            return self.func_max(values)
        elif self.method == METHOD_MIN:
            return self.func_min(values)
        elif self.method == METHOD_CUSTOM:
            return self.func_custom(values)
        else:
            raise UnknownEnsembleMethodException(self.method)

    @staticmethod
    def func_mean(values):
        return np.mean(np.stack(values), axis=0)

    @staticmethod
    def func_max(values):
        return np.amax(np.stack(values), axis=0)

    @staticmethod
    def func_min(values):
        return np.amin(np.stack(values), axis=0)

    @staticmethod
    def func_custom(values):
        raise NotImplementedError
class UnknownOperationException(Exception):
    """
    An exception which is raised when trying to convert a tensorflow operation whose type is not
    known, i.e., for which no :py:class:`Operation` class is registered.
    """
class OperationMismatchException(Exception):
    """
    Raised when an op is instantiated from a tensorflow op whose type is not among the types that
    the op class represents.
    """
    pass
class InvalidImplementationException(Exception):
    """
    Raised when an implementation type that is not one of the known types is registered for, or
    selected by, an :py:class:`Operation` class.
    """
    pass
class UnknownImplementationException(Exception):
    """
    Raised when an :py:class:`Operation` class is asked to switch to an implementation type that
    has not been added to it yet.
    """
    pass
class UnknownEnsembleMethodException(Exception):
    """
    An exception which is raised when an :py:class:`Ensemble` or :py:class:`TensorEnsemble`
    instance is initialized with, or evaluated using, an unknown ensemble method.
    """
class EnsembleMismatchException(Exception):
    """
    An exception which is raised when a :py:class:`TensorEnsemble` instance is evaluated with a
    *feed_dict* whose keys, i.e. also :py:class:`TensorEnsemble` instances, do not match the tensor
    to evaluate. An example would be that a tensor ensemble with *n* tensors is evaluated with a
    tensor ensemble in its *feed_dict* that contains *m* tensors.
    """
class ScipyOperationException(Exception):
    """
    Raised when an op tries to use scipy internally while scipy is not available. *attr* is the
    name of the scipy attribute whose access failed.
    """

    def __init__(self, attr):
        template = ("trying to access 'scipy.%s', but scipy is not installed on your system, "
                    "install scipy to use this operation or use an other implementation")
        super(ScipyOperationException, self).__init__(template % attr)
def _parse_tf_version(v):
    """
    Parses a tensorflow version string *v* and returns a tuple ``(major, minor, patch[, suffix])``,
    e.g. ``"0.12.0-rc1"`` -> ``(0, 12, 0, "rc1")`` and ``"1.15.0"`` -> ``(1, 15, 0)``. A suffix may
    be separated by ``-``, ``+`` or ``.``, or be attached directly (``"2.0.0rc0"`` ->
    ``(2, 0, 0, "rc0")``). A missing patch number counts as 0. Raises a *ValueError* when *v*
    does not start with ``<major>.<minor>``.
    """
    m = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?[-+.]?(.*)$", v.strip())
    if not m:
        raise ValueError("cannot parse tensorflow version: %s" % v)

    major, minor, patch, suffix = m.groups()
    parts = [int(major), int(minor), int(patch or 0)]
    if suffix:
        parts.append(suffix)

    return tuple(parts)


# default (last) tf version
_tf_version_string = "0.12.0-rc1"
_tf_version = _parse_tf_version(_tf_version_string)
def setup(tf, order=None):
    """
    Adapts global settings (currently only the tensorflow version) to peculiarities of the
    tensorflow module *tf*, which must therefore be passed. Call this before creating a
    :py:class:`Model`; it is not needed for evaluation:

    .. code-block:: python

        import tensorflow as tf
        import tfdeploy as td

        td.setup(tf)

        # ...

    For convenience, an *order* that is not *None* is passed on to :py:func:`optimize`.
    """
    global _tf_version_string, _tf_version

    version_string = tf.__version__
    _tf_version_string = version_string
    _tf_version = _parse_tf_version(version_string)

    if order is None:
        return
    optimize(order)
def reset():
    """
    Empties the instance caches of :py:class:`TensorRegister` and :py:class:`OperationRegister`,
    so that later conversions create new :py:class:`Tensor` and :py:class:`Operation` instances.
    Registered op classes are kept.
    """
    for register in (TensorRegister, OperationRegister):
        register.instances.clear()
def optimize(order):
    """ optimize(order)
    Tries to set the implementation type of all registered :py:class:`Operation` classes to
    *order*. This has no effect on op classes that do not implement that type. *order* can also be
    a list or tuple of valid implementation types representing a preferred order; each op class
    then uses the first type in *order* that it implements. The behavior is equivalent to:

    .. code-block:: python

        for op in all_direct_and_indirect_subclasses_of(Operation):
            for impl in order:
                if impl in op.impls:
                    op.use_impl(impl)
                    break
    """
    if not isinstance(order, (list, tuple)):
        order = [order]

    # walk all direct and indirect subclasses of Operation, visiting each class only once
    seen = set()
    pending = list(Operation.__subclasses__())
    while pending:
        op = pending.pop()
        if op in seen:
            continue
        seen.add(op)
        pending.extend(op.__subclasses__())

        # use the first implementation in the preferred order that the op provides
        for impl in order:
            if impl in op.impls:
                op.use_impl(impl)
                break
def print_tensor(td_tensor, indent="| ", max_depth=-1, depth=0):
    """ print_tensor(td_tensor, indent="| ", max_depth=-1)
    Prints the dependency graph of a :py:class:`Tensor` *td_tensor*, where each new level is
    indented by *indent*. When *max_depth* is not negative, the graph is truncated at that depth,
    where each tensor and each op count as a level. The shape of the tensor's value is printed
    when the value is set.
    """
    offset = depth * indent
    line = "td tensor: %s" % td_tensor.name
    if td_tensor.value is not None:
        # np.shape also works for values that are not numpy arrays, e.g. fed lists or scalars
        line += " (%s)" % (",".join(str(i) for i in np.shape(td_tensor.value)),)

    print(offset + line)

    if td_tensor.op and (max_depth < 0 or max_depth > depth):
        print_op(td_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
def print_op(td_op, indent="| ", max_depth=-1, depth=0):
    """ print_op(td_op, indent="| ", max_depth=-1)
    Prints the dependency graph of an :py:class:`Operation` *td_op*, where each new level is
    indented by *indent*. When *max_depth* is not negative, the graph is truncated at that depth,
    where each tensor and each op count as a level. Each op is printed with its name and the
    tensorflow op types its class represents.
    """
    offset = depth * indent
    line = "td op: %s (%s)" % (td_op.name, ",".join(td_op.types))

    print(offset + line)

    if max_depth < 0 or max_depth > depth:
        for td_tensor in td_op.inputs:
            print_tensor(td_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
def print_tf_tensor(tf_tensor, indent="| ", max_depth=-1, depth=0):
    """ print_tf_tensor(tf_tensor, indent="| ", max_depth=-1)
    Prints the dependency graph of a tensorflow tensor *tf_tensor*, where each new level is indented
    by *indent*. When *max_depth* is not negative, the graph is truncated at that depth, where each
    tensor and each op count as a level. Unknown dimensions of the static shape are printed as
    ``?``.
    """
    offset = depth * indent

    dims = []
    for dim in tf_tensor.get_shape():
        # tensorflow 1 yields Dimension objects that expose their size as *value*, tensorflow 2
        # yields plain ints; in both cases an unknown size is None and must not be passed to int()
        value = getattr(dim, "value", dim)
        dims.append("?" if value is None else str(int(value)))
    line = "tf tensor: %s (%s)" % (tf_tensor.name, ",".join(dims))

    print(offset + line)

    if tf_tensor.op and (max_depth < 0 or max_depth > depth):
        print_tf_op(tf_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
def print_tf_op(tf_op, indent="| ", max_depth=-1, depth=0):
    """ print_tf_op(tf_op, indent="| ", max_depth=-1)
    Prints the dependency graph of a tensorflow operation *tf_op*, where each new level is indented
    by *indent*. When *max_depth* is not negative, the graph is truncated at that depth, where each
    tensor and each op count as a level. Each op is printed with its name and type.
    """
    offset = depth * indent
    line = "tf op: %s (%s)" % (tf_op.name, tf_op.type)

    print(offset + line)

    if max_depth < 0 or max_depth > depth:
        for tf_tensor in tf_op.inputs:
            print_tf_tensor(tf_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
# imports exclusively for ops
from operator import mul
from itertools import product
from collections import defaultdict
# optional import of scipy, which can be disabled by setting TD_REFUSE_SCIPY to 1, true or yes;
# without scipy, sp is a placeholder whose attribute access raises ScipyOperationException
try:
    if os.getenv("TD_REFUSE_SCIPY", "").lower() in ("1", "true", "yes"):
        raise ImportError
    import scipy as sp
    import scipy.special
    HAS_SCIPY = True
except ImportError:
    class ScipyDummy(object):
        # stand-in for the scipy module that fails on every attribute access
        def __getattr__(self, name):
            raise ScipyOperationException(name)

    sp = ScipyDummy()
    HAS_SCIPY = False
# mapping of tf dtypes (values of tensorflow's DataType enum) to np dtypes, keys above 100 are the
# corresponding reference types; the builtins object and bool are used directly since the former
# aliases np.object and np.bool were removed in numpy 1.24
# NOTE(review): 14 is bfloat16, which numpy lacks; presumably mapped to uint16 as a same-width
# stand-in
dtype_map = {
    1: np.float32,
    2: np.float64,
    3: np.int32,
    4: np.uint8,
    5: np.int16,
    6: np.int8,
    7: object,
    8: np.complex64,
    9: np.int64,
    10: bool,
    14: np.uint16,
    17: np.uint16,
    18: np.complex128,
    19: np.float16,
    22: np.uint32,
    23: np.uint64,
    101: np.float32,
    102: np.float64,
    103: np.int32,
    104: np.uint8,
    105: np.int16,
    106: np.int8,
    107: object,
    108: np.complex64,
    109: np.int64,
    110: bool,
    114: np.uint16,
    117: np.uint16,
    118: np.complex128,
    119: np.float16,
    122: np.uint32,
    123: np.uint64,
}
import math

# element-wise versions of stdlib math functions; np.math was merely an alias of the stdlib math
# module and was removed in numpy 2.0, so math is used directly
lgamma_vec = np.vectorize(math.lgamma)
erf_vec = np.vectorize(math.erf)
erfc_vec = np.vectorize(math.erfc)
def _transpose(a, dim=2):
    """
    Transposes the array *a*. For ``dim <= 0``, the order of all axes is reversed. Otherwise, the
    axis at position ``-dim`` is moved to the end while all other axes keep their order, so the
    default of 2 swaps the last two axes.
    """
    if dim <= 0:
        return np.transpose(a)

    order = list(range(a.ndim))
    order.append(order.pop(-dim))
    return np.transpose(a, axes=order)
def _adjoint(a, dim=2):
    """
    Returns the adjoint (conjugate transpose) of *a*, where :py:func:`_transpose` with *dim*
    decides how the axes are reordered.
    """
    transposed = _transpose(a, dim=dim)
    return np.conj(transposed)
#
# sequences
#