We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c3dc4b5 commit 4a918a9Copy full SHA for 4a918a9
breze/arch/util.py
@@ -7,6 +7,7 @@
7
import theano
8
import theano.tensor as T
9
import theano.sandbox.cuda
10
+import theano.sandbox.cuda.var
11
import theano.misc.gnumpy_utils as gput
12
13
from breze.utils import dictlist
@@ -399,7 +400,10 @@ def _init_exprs(self):
399
400
pass
401
402
def _lookup(self, container, ident):
- if isinstance(ident, theano.tensor.basic.TensorVariable):
403
+ tensor_types = (theano.tensor.basic.TensorVariable,
404
+ theano.sandbox.cuda.var.CudaNdarrayVariable)
405
+
406
+ if isinstance(ident, tensor_types):
407
res = ident
408
elif isinstance(ident, tuple):
409
res = dictlist.get(container, ident)
0 commit comments