@@ -591,13 +591,13 @@ auto MakeTensorView(Context const *ctx, Order order, common::Span<T, ext> data,
591591
592592template <typename T, typename ... S>
593593auto MakeTensorView (Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
594- auto span = ctx->IsCUDA () ? data->DeviceSpan () : data->HostSpan ();
594+ auto span = ctx->IsCPU () ? data->HostSpan () : data->DeviceSpan ();
595595 return MakeTensorView (ctx->Device (), span, std::forward<S>(shape)...);
596596}
597597
598598template <typename T, typename ... S>
599599auto MakeTensorView (Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
600- auto span = ctx->IsCUDA () ? data->ConstDeviceSpan () : data->ConstHostSpan ();
600+ auto span = ctx->IsCPU () ? data->ConstHostSpan () : data->ConstDeviceSpan ();
601601 return MakeTensorView (ctx->Device (), span, std::forward<S>(shape)...);
602602}
603603
@@ -647,13 +647,13 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
647647
648648template <typename T>
649649auto MakeVec (HostDeviceVector<T> *data) {
650- return MakeVec (data->Device ().IsCUDA () ? data->DevicePointer () : data->HostPointer (),
650+ return MakeVec (data->Device ().IsCPU () ? data->HostPointer () : data->DevicePointer (),
651651 data->Size (), data->Device ());
652652}
653653
654654template <typename T>
655655auto MakeVec (HostDeviceVector<T> const *data) {
656- return MakeVec (data->Device ().IsCUDA () ? data->ConstDevicePointer () : data->ConstHostPointer (),
656+ return MakeVec (data->Device ().IsCPU () ? data->ConstHostPointer () : data->ConstDevicePointer (),
657657 data->Size (), data->Device ());
658658}
659659
@@ -759,7 +759,7 @@ class Tensor {
759759 for (auto i = D; i < kDim ; ++i) {
760760 shape_[i] = 1 ;
761761 }
762- if (device.IsCUDA ()) {
762+ if (! device.IsCPU ()) {
763763 data_.SetDevice (device);
764764 data_.ConstDevicePointer (); // Pull to device;
765765 }
@@ -788,11 +788,11 @@ class Tensor {
788788 shape_[i] = 1 ;
789789 }
790790 auto size = detail::CalcSize (shape_);
791- if (device.IsCUDA ()) {
791+ if (! device.IsCPU ()) {
792792 data_.SetDevice (device);
793793 }
794794 data_.Resize (size);
795- if (device.IsCUDA ()) {
795+ if (! device.IsCPU ()) {
796796 data_.DevicePointer (); // Pull to device
797797 }
798798 }
0 commit comments