From 2eef566612936f04caaf866705ff16ff74c7d5a0 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Wed, 5 Mar 2025 17:51:15 +0000 Subject: [PATCH] Implement runtime tensor desc API This change implements an API on runtime tensors to reflect tensor metadata and contents reflecting the underlying TTNN tensor. This functionality is pybound as well, allowing for easy casting to torch tensors. Please see `runtime/test/python/ttnn/test_runtime_api.py` for an example. NOTE: This is not the most efficient implementation possible, as there is effectively a double copy to get the data in a pybindable row major format. There is probably some trickery that can be done later to avoid this, but I would like to get this functionality out ASAP and avoid premature optimization Closes #1957 --- runtime/include/tt/runtime/detail/ttmetal.h | 6 ++ runtime/include/tt/runtime/detail/ttnn.h | 7 ++ runtime/include/tt/runtime/types.h | 7 ++ runtime/lib/runtime.cpp | 90 ++++++++++++++++++++ runtime/lib/ttmetal/runtime.cpp | 29 +++++++ runtime/lib/ttnn/runtime.cpp | 77 +++++++++++++++++ runtime/test/python/ttnn/test_runtime_api.py | 28 ++++++ runtime/tools/python/ttrt/common/util.py | 6 +- runtime/tools/python/ttrt/runtime/module.cpp | 12 ++- 9 files changed, 260 insertions(+), 2 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index d700cd4fc7..d7ff26fc40 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { } tt::target::DataType getTensorDataType(Tensor tensor); +std::vector getDataBuffer(::tt::runtime::Tensor tensor); +std::vector getShape(::tt::runtime::Tensor tensor); +std::vector getStride(::tt::runtime::Tensor tensor); +std::uint32_t getElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getVolume(::tt::runtime::Tensor tensor); +target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index ac01626459..cfea0d1344 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout, tt::target::DataType getTensorDataType(Tensor tensor); +std::vector getDataBuffer(::tt::runtime::Tensor tensor); +std::vector getShape(::tt::runtime::Tensor tensor); +std::vector getStride(::tt::runtime::Tensor tensor); +std::uint32_t getElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getVolume(::tt::runtime::Tensor tensor); +target::DataType getDtype(::tt::runtime::Tensor tensor); + size_t getNumAvailableDevices(); Device diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index e0f4e7b219..436dfc07b9 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -151,6 +151,13 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr eventHandle, DeviceRuntime runtime) : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), event(eventHandle, runtime) {} + + std::vector getDataBuffer(); + std::uint32_t getElementSize(); + std::uint32_t getVolume(); + std::vector getShape(); + std::vector getStride(); + target::DataType getDtype(); }; struct Layout : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 9023f99087..38560eba62 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -531,4 +531,94 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif LOG_FATAL("runtime is not enabled"); } +std::vector Tensor::getDataBuffer() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getDataBuffer(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getDataBuffer(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::vector Tensor::getShape() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getShape(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getShape(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::vector Tensor::getStride() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getStride(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getStride(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +target::DataType Tensor::getDtype() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getDtype(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getDtype(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::uint32_t Tensor::getElementSize() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getElementSize(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getElementSize(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::uint32_t Tensor::getVolume() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getVolume(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getVolume(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 255ef9da89..fe6050c4e6 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -363,4 +363,33 @@ std::vector getTensorData(Tensor tensor) { return {}; } +std::vector getDataBuffer(::tt::runtime::Tensor tensor) { + LOG_WARNING("getDataBuffer not implemented for metal runtime"); + return {}; +} + +std::vector getShape(::tt::runtime::Tensor tensor) { + LOG_WARNING("getShape not implemented for metal runtime"); + return {}; +} + +std::vector getStride(::tt::runtime::Tensor tensor) { + LOG_WARNING("getStride not implemented for metal runtime"); + return {}; +} + +std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { + LOG_WARNING("getElementSize not implemented for metal runtime"); + return 0; +} + +std::uint32_t getVolume(::tt::runtime::Tensor tensor) { + LOG_WARNING("getVolume not implemented for metal runtime"); + return 0; +} +target::DataType getDtype(::tt::runtime::Tensor tensor) { + LOG_WARNING("getDtype not implemented for metal runtime"); + return {}; +} + } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index f77dff0125..b92aaaf409 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -567,6 +567,83 @@ std::vector getTensorData(Tensor tensor) { static_cast(dataPtr) + nnTensor->volume()); } +std::vector getDataBuffer(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + void *dataPtr = nullptr; + size_t numBytes = getElementSize(tensor) * getVolume(tensor); + std::vector dataVec(numBytes); + + // Need to `memcpy` in each case because the vector will go out of scope if we + // wait until after the switch case + switch (getDtype(tensor)) { + case target::DataType::BFP_BFloat4: + case target::DataType::BFP_BFloat8: + case target::DataType::Float32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::BFloat16: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::Int32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt16: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt8: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + default: + LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning " + "empty data vector"); + return {}; + } +} + +std::vector getShape(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + std::vector shape(ttnnTensor->logical_shape().cbegin(), + ttnnTensor->logical_shape().cend()); + return shape; +} + +std::vector getStride(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + std::vector stride(ttnnTensor->strides().cbegin(), + ttnnTensor->strides().cend()); + return stride; +} + +std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return ttnnTensor->element_size(); +} + +std::uint32_t getVolume(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return ttnnTensor->volume(); +} + +target::DataType getDtype(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return utils::fromTTNNDataType(ttnnTensor->dtype()); +} + std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) { diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py index fc4cc9be79..98b86b0bf8 100644 --- a/runtime/test/python/ttnn/test_runtime_api.py +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -10,6 +10,34 @@ from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_tensor_buffer_api(shape, dtype): + torch_tensor = torch.randn(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + rt_tensor = ttrt.runtime.create_tensor( + torch_tensor.data_ptr(), + list(torch_tensor.shape), + list(torch_tensor.stride()), + torch_tensor.element_size(), + runtime_dtype, + ) + rt_shape = rt_tensor.get_shape() + rt_elem_size = rt_tensor.get_element_size() + rt_vol = rt_tensor.get_volume() + rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype()) + rt_bytes = rt_tensor.get_data_buffer() + + # Various tests that the no underlying stuff has changed over the pybind boundary + assert list(rt_shape) == list(shape) + assert rt_elem_size == torch_tensor.element_size() + assert rt_vol == torch_tensor.numel() + assert rt_dtype == torch_tensor.dtype + assert len(rt_bytes) == rt_vol * rt_elem_size + reconstructed_tensor = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape) + assert torch.equal(torch_tensor, reconstructed_tensor) + + @pytest.mark.parametrize("shape", [(64, 128)]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) def test_to_layout(helper: Helper, shape, dtype, request): diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 1da498f465..dfe838a44f 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype: return torch.uint16 elif dtype == DataType.UInt8: return torch.uint8 + elif dtype == DataType.BFloat16: + return torch.bfloat16 else: - raise ValueError("Only F32 and unsigned integers are supported in the runtime") + raise ValueError( + "Only F32, BF16, and unsigned integers are supported in the runtime" + ) def get_ttrt_metal_home_path(): diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 72a15dabb1..09817037c1 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -39,7 +39,17 @@ PYBIND11_MODULE(_C, m) { .def("get_memory_view", &tt::runtime::detail::getMemoryView, py::arg("device_id") = 0); py::class_(m, "Event"); - py::class_(m, "Tensor"); + py::class_(m, "Tensor") + .def("get_shape", &tt::runtime::Tensor::getShape) + .def("get_stride", &tt::runtime::Tensor::getStride) + .def("get_volume", &tt::runtime::Tensor::getVolume) + .def("get_dtype", &tt::runtime::Tensor::getDtype) + .def("get_element_size", &tt::runtime::Tensor::getElementSize) + .def("get_data_buffer", [](tt::runtime::Tensor self) { + std::vector vec = self.getDataBuffer(); + return py::bytes(reinterpret_cast(vec.data()), + vec.size()); + }); py::class_(m, "Layout"); py::class_(m, "OpContext"); py::class_(m, "CallbackContext");