diff --git a/dfdx-core/src/tensor_ops/unstack/mod.rs b/dfdx-core/src/tensor_ops/unstack/mod.rs
index 0155e555..21204f99 100644
--- a/dfdx-core/src/tensor_ops/unstack/mod.rs
+++ b/dfdx-core/src/tensor_ops/unstack/mod.rs
@@ -46,6 +46,7 @@ impl<S: Shape, E: Dtype, D: UnstackKernel<E>, T, const N: usize> TryUnstack<Const<N>>
     for Tensor<S, E, D, T>
 where
     S: SubDim<Head = Const<N>>,
+    D: super::reshape_to::ReshapeKernel<E>,
     T: Tape<E, D>,
 {
     type Unstacked = ([Tensor<S::Tail, E, D, T>; N], T);
@@ -57,6 +58,7 @@ impl<S: Shape, E: Dtype, D: UnstackKernel<E>, T> TryUnstack<usize>
     for Tensor<S, E, D, T>
 where
     S: SubDim<Head = usize>,
+    D: super::reshape_to::ReshapeKernel<E>,
     T: Tape<E, D>,
 {
     type Unstacked = (Vec<Tensor<S::Tail, E, D, T>>, T);
@@ -136,6 +138,7 @@ fn try_unstack<S: Shape, E: Dtype, D: UnstackKernel<E>, T, OptionalItems, Items, Tapes>(
 where
     S: SubDim,
     T: Tape<E, D>,
+    D: super::reshape_to::ReshapeKernel<E>,
     OptionalItems: Array<Option<Tensor<S::Tail, E, D, NoneTape>>, Dim = S::Head>
         + std::ops::IndexMut<usize, Output = Option<Tensor<S::Tail, E, D, NoneTape>>>,
     Items: Array<Tensor<S::Tail, E, D, T>, Dim = S::Head>,
@@ -144,10 +147,17 @@ where
     let (head, _tail) = stack.shape().sub_dim();
     let (stack, stack_tape) = stack.split_tape();
 
+    // TODO: remove this overhead, and panic on a non-contiguous condition
+    let stack = {
+        use super::reshape_to::ReshapeTo;
+        stack.try_contiguous()?
+    };
+
     let stack_ghost = stack.ghost();
 
     // list of optional tensors (all are Some)
-    let mut unstacks = device.forward::<_, OptionalItems>(stack)?;
+
+    let mut unstacks = UnstackKernel::forward::<_, OptionalItems>(&device, stack)?;
 
     // tensors from unstacks must get tapes inserted into them.
     // to do this, from_fn is re-utilized, but this time without optionals
@@ -163,7 +173,7 @@ where
             grads.try_alloc_for(&stack_ghost)?;
             grads.try_alloc_for(&unstack_ghost)?;
             let (grad_stack, grad_unstack) = grads.mut_and_ref(&stack_ghost, &unstack_ghost);
-            device.backward(grad_stack, grad_unstack, i)
+            UnstackKernel::backward(&device, grad_stack, grad_unstack, i)
         });
         unstack.put_tape(unstack_tape)
     },