diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h
index d700cd4fc7..05f968a900 100644
--- a/runtime/include/tt/runtime/detail/ttmetal.h
+++ b/runtime/include/tt/runtime/detail/ttmetal.h
@@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
 }

 tt::target::DataType getTensorDataType(Tensor tensor);
+std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor);
+std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor);
+std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor);
+TensorDesc getTensorDesc(::tt::runtime::Tensor tensor);

 size_t getNumAvailableDevices();

@@ -65,8 +71,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
 Tensor getOpOutputTensor(OpContext opContextHandle,
                          CallbackContext programContextHandle);

-std::vector<float> getTensorData(Tensor tensor);
-
 using InputBuffer =
     std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
                std::shared_ptr<::tt::tt_metal::Event>>;
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index ac01626459..82de11cca9 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout,

 tt::target::DataType getTensorDataType(Tensor tensor);

+std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor);
+std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor);
+std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor);
+TensorDesc getTensorDesc(::tt::runtime::Tensor tensor);
+
 size_t getNumAvailableDevices();

 Device
@@ -137,8 +144,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
 Tensor getOpOutputTensor(OpContext opContextHandle,
                          CallbackContext programContextHandle);

-std::vector<float> getTensorData(Tensor tensor);
-
 std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
                            std::uint32_t programIndex,
                            std::vector<Tensor> const &inputs);
diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h
index 5d4b3bafba..58d8c4f6d3 100644
--- a/runtime/include/tt/runtime/runtime.h
+++ b/runtime/include/tt/runtime/runtime.h
@@ -89,6 +89,12 @@ inline Tensor createTensor(Device device, Layout layout,
 }

 tt::target::DataType getTensorDataType(Tensor tensor);
+std::vector<std::byte> getTensorDataBuffer(Tensor tensor);
+std::uint32_t getTensorElementSize(Tensor tensor);
+std::uint32_t getTensorVolume(Tensor tensor);
+std::vector<std::uint32_t> getTensorShape(Tensor tensor);
+std::vector<std::uint32_t> getTensorStride(Tensor tensor);
+TensorDesc getTensorDesc(Tensor tensor);

 size_t getNumAvailableDevices();

@@ -126,8 +132,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
 Tensor getOpOutputTensor(OpContext opContextHandle,
                          CallbackContext programContextHandle);

-std::vector<float> getTensorData(Tensor tensor);
-
 std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
                            std::uint32_t programIndex,
                            std::vector<Tensor> const &inputs);
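For orientation, a minimal caller-side sketch (not part of this patch) of how the new accessors compose for host-side inspection; the helper name inspectTensor is hypothetical:

// Hypothetical caller-side helper, not part of this patch: host-side
// inspection via the new public accessors declared above.
#include "tt/runtime/runtime.h"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

void inspectTensor(tt::runtime::Tensor tensor) {
  std::vector<std::byte> data = tt::runtime::getTensorDataBuffer(tensor);
  std::vector<std::uint32_t> shape = tt::runtime::getTensorShape(tensor);
  std::uint32_t itemsize = tt::runtime::getTensorElementSize(tensor);
  std::uint32_t volume = tt::runtime::getTensorVolume(tensor);

  // Most dtypes read back volume * itemsize bytes; the TTNN implementation
  // widens block-float formats to f32, so allow that size as well.
  assert(data.size() == static_cast<std::size_t>(volume) * itemsize ||
         data.size() == static_cast<std::size_t>(volume) * sizeof(float));

  // getTensorDesc bundles the same metadata into one struct.
  tt::runtime::TensorDesc desc = tt::runtime::getTensorDesc(tensor);
  assert(desc.shape == shape && desc.itemsize == itemsize);
}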
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index 9023f99087..1d8682e0a8 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -477,22 +477,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
   LOG_FATAL("runtime is not enabled");
 }

-std::vector<float> getTensorData(Tensor tensor) {
-#if defined(TT_RUNTIME_ENABLE_TTNN)
-  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
-    return ::tt::runtime::ttnn::getTensorData(tensor);
-  }
-#endif
-
-#if defined(TT_RUNTIME_ENABLE_TTMETAL)
-  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
-    return ::tt::runtime::ttmetal::getTensorData(tensor);
-  }
-#endif
-
-  LOG_FATAL("runtime is not enabled");
-}
-
 std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
                            std::uint32_t programIndex,
                            std::vector<Tensor> const &inputHandles) {
@@ -531,4 +515,109 @@ Event submit(Device deviceHandle, Binary executableHandle,
 #endif
   LOG_FATAL("runtime is not enabled");
 }
+std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorDataBuffer(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorDataBuffer(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorShape(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorShape(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorStride(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorStride(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+target::DataType getDtype(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorDataType(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorDataType(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::uint32_t getTensorElementSize(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorElementSize(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorElementSize(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::uint32_t getTensorVolume(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorVolume(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorVolume(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+TensorDesc getTensorDesc(::tt::runtime::Tensor t) {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getTensorDesc(t);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getTensorDesc(t);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
 } // namespace tt::runtime
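The seven new dispatch functions repeat the same #if / getCurrentRuntime() ladder. A possible follow-up, sketched here and not part of this patch, would collapse the boilerplate into one generic dispatcher; this assumes LOG_FATAL aborts, so falling off the end is unreachable:

// Hypothetical helper, not in this patch: generic runtime dispatch.
template <typename TTNNFn, typename TTMetalFn>
static auto dispatchToRuntime(TTNNFn &&ttnnFn, TTMetalFn &&ttmetalFn) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ttnnFn();
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ttmetalFn();
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

// Usage, e.g. inside getTensorShape:
//   return dispatchToRuntime(
//       [&] { return ::tt::runtime::ttnn::getTensorShape(t); },
//       [&] { return ::tt::runtime::ttmetal::getTensorShape(t); });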
diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp
index 255ef9da89..bc838efda7 100644
--- a/runtime/lib/ttmetal/runtime.cpp
+++ b/runtime/lib/ttmetal/runtime.cpp
@@ -357,9 +357,33 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
   return createNullTensor();
 }

-std::vector<float> getTensorData(Tensor tensor) {
-  // Not implemented
-  LOG_WARNING("obtaining tensor data for metal runtime not implemented");
+std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getDataBuffer not implemented for metal runtime");
+  return {};
+}
+
+std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getShape not implemented for metal runtime");
+  return {};
+}
+
+std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getStride not implemented for metal runtime");
+  return {};
+}
+
+std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getElementSize not implemented for metal runtime");
+  return 0;
+}
+
+std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getVolume not implemented for metal runtime");
+  return 0;
+}
+
+TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getTensorDesc not implemented for metal runtime");
   return {};
 }
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index f77dff0125..7e97cdc035 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -555,16 +555,120 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
                          DeviceRuntime::TTNN);
 }

-std::vector<float> getTensorData(Tensor tensor) {
-  const ::ttnn::Tensor *nnTensor =
-      static_cast<::ttnn::Tensor *>(tensor.handle.get());
-  if (nnTensor == nullptr) {
+std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor) {
+  const ::ttnn::Tensor &ttnnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  void *dataPtr = nullptr;
+  std::vector<std::byte> dataVec(getTensorElementSize(tensor) *
+                                 getTensorVolume(tensor));
+
+  // Need to `memcpy` in each case because the vector will go out of scope
+  // if we wait until after the switch case
+  switch (getTensorDataType(tensor)) {
+  case target::DataType::BFP_BFloat4: {
+    dataVec.resize(sizeof(float) * getTensorVolume(tensor));
+    auto vec = ttnnTensor.to_vector<float>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::BFP_BFloat8: {
+    dataVec.resize(sizeof(float) * getTensorVolume(tensor));
+    auto vec = ttnnTensor.to_vector<float>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::Float32: {
+    auto vec = ttnnTensor.to_vector<float>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::BFloat16: {
+    auto vec = ttnnTensor.to_vector<bfloat16>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::Int32: {
+    auto vec = ttnnTensor.to_vector<std::int32_t>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::UInt32: {
+    auto vec = ttnnTensor.to_vector<std::uint32_t>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::UInt16: {
+    auto vec = ttnnTensor.to_vector<std::uint16_t>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  case target::DataType::UInt8: {
+    auto vec = ttnnTensor.to_vector<std::uint8_t>();
+    dataPtr = vec.data();
+    LOG_ASSERT(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, dataVec.size());
+    return dataVec;
+  }
+  default:
+    LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning "
+              "empty data vector");
     return {};
   }
+}
+
+std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor) {
+  const ::ttnn::Tensor &ttnnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  std::vector<std::uint32_t> shape;
+  for (size_t i = 0; i < ttnnTensor.logical_shape().size(); ++i) {
+    shape.push_back(ttnnTensor.logical_shape()[i]);
+  }
+  return shape;
+}
+
+std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor) {
+  const ::ttnn::Tensor &ttnnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  std::vector<std::uint32_t> stride;
+  for (size_t i = 0; i < ttnnTensor.strides().size(); ++i) {
+    stride.push_back(ttnnTensor.strides()[i]);
+  }
+  return stride;
+}
+
+std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) {
+  const ::ttnn::Tensor &ttnnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  return ttnnTensor.element_size();
+}
+
+std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) {
+  const ::ttnn::Tensor &ttnnTensor =
+      tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
+  return ttnnTensor.volume();
+}

-  void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor);
-  return std::vector<float>(static_cast<float *>(dataPtr),
-                            static_cast<float *>(dataPtr) + nnTensor->volume());
+TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) {
+  TensorDesc desc;
+  desc.dataType = getTensorDataType(tensor);
+  desc.itemsize = getTensorElementSize(tensor);
+  desc.stride = getTensorStride(tensor);
+  desc.shape = getTensorShape(tensor);
+  return desc;
 }

 std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
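The per-dtype cases in getTensorDataBuffer differ only in the element type handed to to_vector<T>(). Sketch only, not part of this patch, of a hypothetical helper that could fold them, assuming the includes already present in this file:

// Hypothetical helper: one typed host copy, then a byte-wise copy out.
template <typename T>
static std::vector<std::byte> copyToBytes(const ::ttnn::Tensor &tensor) {
  std::vector<T> vec = tensor.to_vector<T>(); // typed host copy
  std::vector<std::byte> bytes(vec.size() * sizeof(T));
  std::memcpy(bytes.data(), vec.data(), bytes.size());
  return bytes;
}
// The block-float formats (BFP_BFloat4/8) would route through
// copyToBytes<float>, matching the f32 widening above.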
diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py
index fc4cc9be79..6372b7d28a 100644
--- a/runtime/test/python/ttnn/test_runtime_api.py
+++ b/runtime/test/python/ttnn/test_runtime_api.py
@@ -10,6 +10,43 @@
 from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc


+@pytest.mark.parametrize("shape", [(64, 128)])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_tensor_buffer_api(shape, dtype):
+    torch_tensor = torch.randn(shape, dtype=dtype)
+    runtime_dtype = Binary.Program.to_data_type(dtype)
+    rt_tensor = ttrt.runtime.create_tensor(
+        torch_tensor.data_ptr(),
+        list(torch_tensor.shape),
+        list(torch_tensor.stride()),
+        torch_tensor.element_size(),
+        runtime_dtype,
+    )
+    rt_shape = rt_tensor.get_shape()
+    rt_stride = rt_tensor.get_stride()
+    rt_elem_size = rt_tensor.get_element_size()
+    rt_vol = rt_tensor.get_volume()
+    rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype())
+    rt_bytes = rt_tensor.get_data_buffer()
+    rt_desc = rt_tensor.get_tensor_desc()
+
+    # Checks that the `TensorDesc` binding works. Might belong in its own test?
+    assert rt_desc.shape == rt_shape
+    assert rt_desc.stride == rt_stride
+    assert ttrt_datatype_to_torch_dtype(rt_desc.dtype) == rt_dtype
+    assert rt_desc.item_size == rt_elem_size
+
+    # Various checks that nothing underlying has changed over the pybind boundary
+    assert list(rt_shape) == list(shape)
+    assert list(rt_stride) == list(torch_tensor.stride())
+    assert rt_elem_size == torch_tensor.element_size()
+    assert rt_vol == torch_tensor.numel()
+    assert rt_dtype == torch_tensor.dtype
+    assert len(rt_bytes) == rt_vol * rt_elem_size
+    reconstructed_tensor = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape)
+    assert torch.equal(torch_tensor, reconstructed_tensor)
+
+
 @pytest.mark.parametrize("shape", [(64, 128)])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
 def test_to_layout(helper: Helper, shape, dtype, request):
diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py
index 93a9af267b..782e9094d2 100644
--- a/runtime/tools/python/ttrt/common/callback.py
+++ b/runtime/tools/python/ttrt/common/callback.py
@@ -212,15 +212,17 @@ def golden(callback_runtime_config, binary, program_context, op_context):

     op_output_tensor = ttrt.runtime.get_op_output_tensor(op_context, program_context)

-    if len(op_output_tensor) == 0:
+    if op_output_tensor is None:
         logging.debug("Output tensor is empty - skipping golden comparison")
         return

+    rt_buffer = op_output_tensor.get_data_buffer()
     dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype)
+    assert ttrt_datatype_to_torch_dtype(op_output_tensor.get_dtype()) == dtype

     golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten()
-    output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten()
+    output_tensor_torch = torch.frombuffer(rt_buffer, dtype=dtype).flatten()

     if callback_runtime_config.save_golden_tensors:
         torch.save(
diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py
index 1da498f465..dfe838a44f 100644
--- a/runtime/tools/python/ttrt/common/util.py
+++ b/runtime/tools/python/ttrt/common/util.py
@@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype:
         return torch.uint16
     elif dtype == DataType.UInt8:
         return torch.uint8
+    elif dtype == DataType.BFloat16:
+        return torch.bfloat16
     else:
-        raise ValueError("Only F32 and unsigned integers are supported in the runtime")
+        raise ValueError(
+            "Only F32, BF16, and unsigned integers are supported in the runtime"
+        )


 def get_ttrt_metal_home_path():
diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py
index a12f474708..68f2fca842 100644
--- a/runtime/tools/python/ttrt/runtime/__init__.py
+++ b/runtime/tools/python/ttrt/runtime/__init__.py
@@ -7,6 +7,7 @@
     Device,
     Event,
     Tensor,
+    TensorDesc,
     MemoryBufferType,
     DataType,
     DeviceRuntime,
diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp
index 72a15dabb1..7763da910b 100644
--- a/runtime/tools/python/ttrt/runtime/module.cpp
+++ b/runtime/tools/python/ttrt/runtime/module.cpp
@@ -39,7 +39,44 @@ PYBIND11_MODULE(_C, m) {
       .def("get_memory_view", &tt::runtime::detail::getMemoryView,
           py::arg("device_id") = 0);
   py::class_<tt::runtime::Event>(m, "Event");
-  py::class_<tt::runtime::Tensor>(m, "Tensor");
+  py::class_<tt::runtime::TensorDesc>(m, "TensorDesc")
+      .def_readonly("shape", &tt::runtime::TensorDesc::shape)
+      .def_readonly("stride", &tt::runtime::TensorDesc::stride)
+      .def_readonly("item_size", &tt::runtime::TensorDesc::itemsize)
+      .def_readonly("dtype", &tt::runtime::TensorDesc::dataType);
+  py::class_<tt::runtime::Tensor>(m, "Tensor")
+      .def("get_shape",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorShape(self);
+           })
+      .def("get_stride",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorStride(self);
+           })
+      .def("get_volume",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorVolume(self);
+           })
+      .def("get_dtype",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorDataType(self);
+           })
+      .def("get_element_size",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorElementSize(self);
+           })
+      .def("get_tensor_desc",
+           [](tt::runtime::Tensor self) {
+             return tt::runtime::getTensorDesc(self);
+           })
+      .def(
+          "get_data_buffer",
+          [](tt::runtime::Tensor self) {
+            std::vector<std::byte> vec = tt::runtime::getTensorDataBuffer(self);
+            return py::bytes(reinterpret_cast<const char *>(vec.data()),
+                             vec.size());
+          },
+          py::return_value_policy::take_ownership);
   py::class_<tt::runtime::Layout>(m, "Layout");
   py::class_<tt::runtime::OpContext>(m, "OpContext");
   py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
@@ -177,9 +214,11 @@ PYBIND11_MODULE(_C, m) {
          tt::runtime::CallbackContext &programContextHandle) {
         tt::runtime::Tensor tensor = tt::runtime::getOpOutputTensor(
             opContextHandle, programContextHandle);
-        return tt::runtime::getTensorData(tensor);
+        return tensor.handle.get() == nullptr
+                   ? std::nullopt
+                   : std::optional<tt::runtime::Tensor>(tensor);
       },
-      "Get the input tensor of the op");
+      "Get the output tensor of the op");
   m.def("get_op_debug_str", &tt::runtime::getOpDebugString,
         "Get the debug string of the op");
   m.def("get_op_loc_info", &tt::runtime::getOpLocInfo,
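One binding-level note, sketched under the assumption that pybind11's documented behavior applies here: return-value policies are ignored for returns that are already Python objects, and py::bytes copies the buffer on construction, so the take_ownership policy in get_data_buffer should be a no-op. A trimmed-down equivalent, not part of the patch:

// Equivalent binding without the redundant return_value_policy:
// py::bytes owns its own copy of the data, so no policy is needed.
py::class_<tt::runtime::Tensor>(m, "Tensor")
    .def("get_data_buffer", [](tt::runtime::Tensor self) {
      std::vector<std::byte> vec = tt::runtime::getTensorDataBuffer(self);
      return py::bytes(reinterpret_cast<const char *>(vec.data()),
                       vec.size());
    });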