From 2eef566612936f04caaf866705ff16ff74c7d5a0 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Wed, 5 Mar 2025 17:51:15 +0000 Subject: [PATCH 1/8] Implement runtime tensor desc API This change implements an API on runtime tensors to reflect tensor metadata and contents reflecting the underlying TTNN tensor. This functionality is pybound as well, allowing for easy casting to torch tensors. Please see `runtime/test/python/ttnn/test_runtime_api.py` for an example. NOTE: This is not the most efficient implementation possible, as there is effectively a double copy to get the data in a pybindable row major format. There is probably some trickery that can be done later to avoid this, but I would like to get this functionality out ASAP and avoid premature optimization Closes #1957 --- runtime/include/tt/runtime/detail/ttmetal.h | 6 ++ runtime/include/tt/runtime/detail/ttnn.h | 7 ++ runtime/include/tt/runtime/types.h | 7 ++ runtime/lib/runtime.cpp | 90 ++++++++++++++++++++ runtime/lib/ttmetal/runtime.cpp | 29 +++++++ runtime/lib/ttnn/runtime.cpp | 77 +++++++++++++++++ runtime/test/python/ttnn/test_runtime_api.py | 28 ++++++ runtime/tools/python/ttrt/common/util.py | 6 +- runtime/tools/python/ttrt/runtime/module.cpp | 12 ++- 9 files changed, 260 insertions(+), 2 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index d700cd4fc7..d7ff26fc40 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { } tt::target::DataType getTensorDataType(Tensor tensor); +std::vector getDataBuffer(::tt::runtime::Tensor tensor); +std::vector getShape(::tt::runtime::Tensor tensor); +std::vector getStride(::tt::runtime::Tensor tensor); +std::uint32_t getElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getVolume(::tt::runtime::Tensor tensor); +target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index ac01626459..cfea0d1344 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout, tt::target::DataType getTensorDataType(Tensor tensor); +std::vector getDataBuffer(::tt::runtime::Tensor tensor); +std::vector getShape(::tt::runtime::Tensor tensor); +std::vector getStride(::tt::runtime::Tensor tensor); +std::uint32_t getElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getVolume(::tt::runtime::Tensor tensor); +target::DataType getDtype(::tt::runtime::Tensor tensor); + size_t getNumAvailableDevices(); Device diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index e0f4e7b219..436dfc07b9 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -151,6 +151,13 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr eventHandle, DeviceRuntime runtime) : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), event(eventHandle, runtime) {} + + std::vector getDataBuffer(); + std::uint32_t getElementSize(); + std::uint32_t getVolume(); + std::vector getShape(); + std::vector getStride(); + target::DataType getDtype(); }; struct Layout : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 9023f99087..38560eba62 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -531,4 +531,94 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif LOG_FATAL("runtime is not enabled"); } +std::vector Tensor::getDataBuffer() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getDataBuffer(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getDataBuffer(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::vector Tensor::getShape() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getShape(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getShape(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::vector Tensor::getStride() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getStride(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getStride(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +target::DataType Tensor::getDtype() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getDtype(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getDtype(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::uint32_t Tensor::getElementSize() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getElementSize(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getElementSize(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + +std::uint32_t Tensor::getVolume() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getVolume(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getVolume(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 255ef9da89..fe6050c4e6 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -363,4 +363,33 @@ std::vector getTensorData(Tensor tensor) { return {}; } +std::vector getDataBuffer(::tt::runtime::Tensor tensor) { + LOG_WARNING("getDataBuffer not implemented for metal runtime"); + return {}; +} + +std::vector getShape(::tt::runtime::Tensor tensor) { + LOG_WARNING("getShape not implemented for metal runtime"); + return {}; +} + +std::vector getStride(::tt::runtime::Tensor tensor) { + LOG_WARNING("getStride not implemented for metal runtime"); + return {}; +} + +std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { + LOG_WARNING("getElementSize not implemented for metal runtime"); + return 0; +} + +std::uint32_t getVolume(::tt::runtime::Tensor tensor) { + LOG_WARNING("getVolume not implemented for metal runtime"); + return 0; +} +target::DataType getDtype(::tt::runtime::Tensor tensor) { + LOG_WARNING("getDtype not implemented for metal runtime"); + return {}; +} + } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index f77dff0125..b92aaaf409 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -567,6 +567,83 @@ std::vector getTensorData(Tensor tensor) { static_cast(dataPtr) + nnTensor->volume()); } +std::vector getDataBuffer(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + void *dataPtr = nullptr; + size_t numBytes = getElementSize(tensor) * getVolume(tensor); + std::vector dataVec(numBytes); + + // Need to `memcpy` in each case because the vector will go out of scope if we + // wait until after the switch case + switch (getDtype(tensor)) { + case target::DataType::BFP_BFloat4: + case target::DataType::BFP_BFloat8: + case target::DataType::Float32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::BFloat16: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::Int32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt32: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt16: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + case target::DataType::UInt8: + dataPtr = ttnnTensor->to_vector().data(); + assert(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + default: + LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning " + "empty data vector"); + return {}; + } +} + +std::vector getShape(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + std::vector shape(ttnnTensor->logical_shape().cbegin(), + ttnnTensor->logical_shape().cend()); + return shape; +} + +std::vector getStride(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + std::vector stride(ttnnTensor->strides().cbegin(), + ttnnTensor->strides().cend()); + return stride; +} + +std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return ttnnTensor->element_size(); +} + +std::uint32_t getVolume(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return ttnnTensor->volume(); +} + +target::DataType getDtype(::tt::runtime::Tensor tensor) { + auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + return utils::fromTTNNDataType(ttnnTensor->dtype()); +} + std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) { diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py index fc4cc9be79..98b86b0bf8 100644 --- a/runtime/test/python/ttnn/test_runtime_api.py +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -10,6 +10,34 @@ from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc +@pytest.mark.parametrize("shape", [(64, 128)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +def test_tensor_buffer_api(shape, dtype): + torch_tensor = torch.randn(shape, dtype=dtype) + runtime_dtype = Binary.Program.to_data_type(dtype) + rt_tensor = ttrt.runtime.create_tensor( + torch_tensor.data_ptr(), + list(torch_tensor.shape), + list(torch_tensor.stride()), + torch_tensor.element_size(), + runtime_dtype, + ) + rt_shape = rt_tensor.get_shape() + rt_elem_size = rt_tensor.get_element_size() + rt_vol = rt_tensor.get_volume() + rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype()) + rt_bytes = rt_tensor.get_data_buffer() + + # Various tests that the no underlying stuff has changed over the pybind boundary + assert list(rt_shape) == list(shape) + assert rt_elem_size == torch_tensor.element_size() + assert rt_vol == torch_tensor.numel() + assert rt_dtype == torch_tensor.dtype + assert len(rt_bytes) == rt_vol * rt_elem_size + reconstructed_tensor = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape) + assert torch.equal(torch_tensor, reconstructed_tensor) + + @pytest.mark.parametrize("shape", [(64, 128)]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) def test_to_layout(helper: Helper, shape, dtype, request): diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py index 1da498f465..dfe838a44f 100644 --- a/runtime/tools/python/ttrt/common/util.py +++ b/runtime/tools/python/ttrt/common/util.py @@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype: return torch.uint16 elif dtype == DataType.UInt8: return torch.uint8 + elif dtype == DataType.BFloat16: + return torch.bfloat16 else: - raise ValueError("Only F32 and unsigned integers are supported in the runtime") + raise ValueError( + "Only F32, BF16, and unsigned integers are supported in the runtime" + ) def get_ttrt_metal_home_path(): diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 72a15dabb1..09817037c1 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -39,7 +39,17 @@ PYBIND11_MODULE(_C, m) { .def("get_memory_view", &tt::runtime::detail::getMemoryView, py::arg("device_id") = 0); py::class_(m, "Event"); - py::class_(m, "Tensor"); + py::class_(m, "Tensor") + .def("get_shape", &tt::runtime::Tensor::getShape) + .def("get_stride", &tt::runtime::Tensor::getStride) + .def("get_volume", &tt::runtime::Tensor::getVolume) + .def("get_dtype", &tt::runtime::Tensor::getDtype) + .def("get_element_size", &tt::runtime::Tensor::getElementSize) + .def("get_data_buffer", [](tt::runtime::Tensor self) { + std::vector vec = self.getDataBuffer(); + return py::bytes(reinterpret_cast(vec.data()), + vec.size()); + }); py::class_(m, "Layout"); py::class_(m, "OpContext"); py::class_(m, "CallbackContext"); From f66be99b5cf7156e3dc15e4e0599a86cfeb68b76 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Wed, 5 Mar 2025 18:38:29 +0000 Subject: [PATCH 2/8] Fix nits --- runtime/include/tt/runtime/detail/ttmetal.h | 1 - runtime/include/tt/runtime/detail/ttnn.h | 1 - runtime/lib/runtime.cpp | 4 +- runtime/lib/ttmetal/runtime.cpp | 4 -- runtime/lib/ttnn/runtime.cpp | 58 ++++++++++----------- 5 files changed, 31 insertions(+), 37 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index d7ff26fc40..74e81c66f6 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -35,7 +35,6 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); -target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index cfea0d1344..a2f0ff6d17 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -99,7 +99,6 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); -target::DataType getDtype(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 38560eba62..bcf8b7786c 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -579,13 +579,13 @@ std::vector Tensor::getStride() { target::DataType Tensor::getDtype() { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getDtype(*this); + return ::tt::runtime::ttnn::getTensorDataType(*this); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getDtype(*this); + return ::tt::runtime::ttmetal::getTensorDataType(*this); } #endif LOG_FATAL("runtime is not enabled"); diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index fe6050c4e6..cd14e16a9b 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -387,9 +387,5 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { LOG_WARNING("getVolume not implemented for metal runtime"); return 0; } -target::DataType getDtype(::tt::runtime::Tensor tensor) { - LOG_WARNING("getDtype not implemented for metal runtime"); - return {}; -} } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index b92aaaf409..ac23691295 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -568,44 +568,45 @@ std::vector getTensorData(Tensor tensor) { } std::vector getDataBuffer(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); void *dataPtr = nullptr; size_t numBytes = getElementSize(tensor) * getVolume(tensor); std::vector dataVec(numBytes); // Need to `memcpy` in each case because the vector will go out of scope if we // wait until after the switch case - switch (getDtype(tensor)) { + switch (getTensorDataType(tensor)) { case target::DataType::BFP_BFloat4: case target::DataType::BFP_BFloat8: case target::DataType::Float32: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; case target::DataType::BFloat16: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; case target::DataType::Int32: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; case target::DataType::UInt32: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; case target::DataType::UInt16: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; case target::DataType::UInt8: - dataPtr = ttnnTensor->to_vector().data(); - assert(dataPtr != nullptr); + dataPtr = ttnnTensor.to_vector().data(); + LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; default: @@ -616,32 +617,31 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { } std::vector getShape(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - std::vector shape(ttnnTensor->logical_shape().cbegin(), - ttnnTensor->logical_shape().cend()); + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + std::vector shape(ttnnTensor.logical_shape().cbegin(), + ttnnTensor.logical_shape().cend()); return shape; } std::vector getStride(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - std::vector stride(ttnnTensor->strides().cbegin(), - ttnnTensor->strides().cend()); + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + std::vector stride(ttnnTensor.strides().cbegin(), + ttnnTensor.strides().cend()); return stride; } std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - return ttnnTensor->element_size(); + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + return ttnnTensor.element_size(); } std::uint32_t getVolume(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - return ttnnTensor->volume(); -} - -target::DataType getDtype(::tt::runtime::Tensor tensor) { - auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get()); - return utils::fromTTNNDataType(ttnnTensor->dtype()); + const ::ttnn::Tensor &ttnnTensor = + tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); + return ttnnTensor.volume(); } std::vector submit(Device deviceHandle, Binary executableHandle, From 197a1a3f78aaac6a1ede9495ece91c8f5d5902c7 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Wed, 5 Mar 2025 20:56:27 +0000 Subject: [PATCH 3/8] use tensordesc --- runtime/include/tt/runtime/detail/ttmetal.h | 1 + runtime/include/tt/runtime/detail/ttnn.h | 1 + runtime/include/tt/runtime/types.h | 1 + runtime/lib/runtime.cpp | 15 +++++++++++++ runtime/lib/ttmetal/runtime.cpp | 5 +++++ runtime/lib/ttnn/runtime.cpp | 21 +++++++++++++++---- runtime/test/python/ttnn/test_runtime_api.py | 9 ++++++++ runtime/tools/python/ttrt/runtime/__init__.py | 1 + runtime/tools/python/ttrt/runtime/module.cpp | 19 ++++++++++++----- 9 files changed, 64 insertions(+), 9 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 74e81c66f6..f7fc5b563a 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -35,6 +35,7 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); +TensorDesc getTensorDesc(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index a2f0ff6d17..64ce245255 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -99,6 +99,7 @@ std::vector getShape(::tt::runtime::Tensor tensor); std::vector getStride(::tt::runtime::Tensor tensor); std::uint32_t getElementSize(::tt::runtime::Tensor tensor); std::uint32_t getVolume(::tt::runtime::Tensor tensor); +TensorDesc getTensorDesc(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index 436dfc07b9..7ab6a7af81 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -158,6 +158,7 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { std::vector getShape(); std::vector getStride(); target::DataType getDtype(); + TensorDesc getTensorDesc(); }; struct Layout : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index bcf8b7786c..a1861516d9 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -621,4 +621,19 @@ std::uint32_t Tensor::getVolume() { LOG_FATAL("runtime is not enabled"); } +TensorDesc Tensor::getTensorDesc() { +#if defined(TT_RUNTIME_ENABLE_TTNN) + if (getCurrentRuntime() == DeviceRuntime::TTNN) { + return ::tt::runtime::ttnn::getTensorDesc(*this); + } +#endif + +#if defined(TT_RUNTIME_ENABLE_TTMETAL) + if (getCurrentRuntime() == DeviceRuntime::TTMetal) { + return ::tt::runtime::ttmetal::getTensorDesc(*this); + } +#endif + LOG_FATAL("runtime is not enabled"); +} + } // namespace tt::runtime diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index cd14e16a9b..15f7c92658 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -388,4 +388,9 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { return 0; } +TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) { + LOG_WARNING("getTensorDesc not implemented for metal runtime"); + return {}; +} + } // namespace tt::runtime::ttmetal diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index ac23691295..2698eb409e 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -619,16 +619,20 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { std::vector getShape(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); - std::vector shape(ttnnTensor.logical_shape().cbegin(), - ttnnTensor.logical_shape().cend()); + std::vector shape; + for (size_t i = 0; i < ttnnTensor.logical_shape().size(); ++i) { + shape.push_back(ttnnTensor.logical_shape()[i]); + } return shape; } std::vector getStride(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); - std::vector stride(ttnnTensor.strides().cbegin(), - ttnnTensor.strides().cend()); + std::vector stride; + for (size_t i = 0; i < ttnnTensor.strides().size(); ++i) { + stride.push_back(ttnnTensor.strides()[i]); + } return stride; } @@ -644,6 +648,15 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { return ttnnTensor.volume(); } +TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) { + TensorDesc desc; + desc.dataType = getTensorDataType(tensor); + desc.itemsize = getElementSize(tensor); + desc.stride = getStride(tensor); + desc.shape = getShape(tensor); + return desc; +} + std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) { diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py index 98b86b0bf8..6372b7d28a 100644 --- a/runtime/test/python/ttnn/test_runtime_api.py +++ b/runtime/test/python/ttnn/test_runtime_api.py @@ -23,13 +23,22 @@ def test_tensor_buffer_api(shape, dtype): runtime_dtype, ) rt_shape = rt_tensor.get_shape() + rt_stride = rt_tensor.get_stride() rt_elem_size = rt_tensor.get_element_size() rt_vol = rt_tensor.get_volume() rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype()) rt_bytes = rt_tensor.get_data_buffer() + rt_desc = rt_tensor.get_tensor_desc() + + # Tests to make sure the binding of `TensorDesc` works. Might belong in its own test? + assert rt_desc.shape == rt_shape + assert rt_desc.stride == rt_stride + assert ttrt_datatype_to_torch_dtype(rt_desc.dtype) == rt_dtype + assert rt_desc.item_size == rt_elem_size # Various tests that the no underlying stuff has changed over the pybind boundary assert list(rt_shape) == list(shape) + assert list(rt_stride) == list(torch_tensor.stride()) assert rt_elem_size == torch_tensor.element_size() assert rt_vol == torch_tensor.numel() assert rt_dtype == torch_tensor.dtype diff --git a/runtime/tools/python/ttrt/runtime/__init__.py b/runtime/tools/python/ttrt/runtime/__init__.py index a12f474708..68f2fca842 100644 --- a/runtime/tools/python/ttrt/runtime/__init__.py +++ b/runtime/tools/python/ttrt/runtime/__init__.py @@ -7,6 +7,7 @@ Device, Event, Tensor, + TensorDesc, MemoryBufferType, DataType, DeviceRuntime, diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 09817037c1..724f29e333 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -39,17 +39,26 @@ PYBIND11_MODULE(_C, m) { .def("get_memory_view", &tt::runtime::detail::getMemoryView, py::arg("device_id") = 0); py::class_(m, "Event"); + py::class_(m, "TensorDesc") + .def_readonly("shape", &tt::runtime::TensorDesc::shape) + .def_readonly("stride", &tt::runtime::TensorDesc::stride) + .def_readonly("item_size", &tt::runtime::TensorDesc::itemsize) + .def_readonly("dtype", &tt::runtime::TensorDesc::dataType); py::class_(m, "Tensor") .def("get_shape", &tt::runtime::Tensor::getShape) .def("get_stride", &tt::runtime::Tensor::getStride) .def("get_volume", &tt::runtime::Tensor::getVolume) .def("get_dtype", &tt::runtime::Tensor::getDtype) .def("get_element_size", &tt::runtime::Tensor::getElementSize) - .def("get_data_buffer", [](tt::runtime::Tensor self) { - std::vector vec = self.getDataBuffer(); - return py::bytes(reinterpret_cast(vec.data()), - vec.size()); - }); + .def("get_tensor_desc", &tt::runtime::Tensor::getTensorDesc) + .def( + "get_data_buffer", + [](tt::runtime::Tensor self) { + std::vector vec = self.getDataBuffer(); + return py::bytes(reinterpret_cast(vec.data()), + vec.size()); + }, + py::return_value_policy::take_ownership); py::class_(m, "Layout"); py::class_(m, "OpContext"); py::class_(m, "CallbackContext"); From 5b8177de7be7d1499eeb9395cd9e3e7eee6fc44c Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Thu, 6 Mar 2025 17:42:14 +0000 Subject: [PATCH 4/8] Fix block-type sizes and dangling pointer --- runtime/lib/ttnn/runtime.cpp | 54 ++++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 2698eb409e..e7c3526ff6 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -577,38 +577,64 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { // Need to `memcpy` in each case because the vector will go out of scope if we // wait until after the switch case switch (getTensorDataType(tensor)) { - case target::DataType::BFP_BFloat4: - case target::DataType::BFP_BFloat8: - case target::DataType::Float32: - dataPtr = ttnnTensor.to_vector().data(); + case target::DataType::BFP_BFloat4: { + dataVec.resize(sizeof(float) * getVolume(tensor)); + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; - case target::DataType::BFloat16: - dataPtr = ttnnTensor.to_vector().data(); + } + case target::DataType::BFP_BFloat8: { + dataVec.resize(sizeof(float) * getVolume(tensor)); + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); + LOG_ASSERT(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + } + case target::DataType::Float32: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); + LOG_ASSERT(dataPtr != nullptr); + std::memcpy(dataVec.data(), dataPtr, numBytes); + return dataVec; + } + case target::DataType::BFloat16: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; - case target::DataType::Int32: - dataPtr = ttnnTensor.to_vector().data(); + } + case target::DataType::Int32: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; - case target::DataType::UInt32: - dataPtr = ttnnTensor.to_vector().data(); + } + case target::DataType::UInt32: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; - case target::DataType::UInt16: - dataPtr = ttnnTensor.to_vector().data(); + } + case target::DataType::UInt16: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; - case target::DataType::UInt8: - dataPtr = ttnnTensor.to_vector().data(); + } + case target::DataType::UInt8: { + auto vec = ttnnTensor.to_vector(); + dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); std::memcpy(dataVec.data(), dataPtr, numBytes); return dataVec; + } default: LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning " "empty data vector"); From 490e719818a6b99da3b4719ab041b1bb8da481a8 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Thu, 6 Mar 2025 19:09:41 +0000 Subject: [PATCH 5/8] Remove now obsolete `getTensorData` Now all golden checks utilize the runtime tensor buffer API --- runtime/include/tt/runtime/detail/ttmetal.h | 2 -- runtime/include/tt/runtime/detail/ttnn.h | 2 -- runtime/include/tt/runtime/runtime.h | 2 -- runtime/lib/runtime.cpp | 16 ---------------- runtime/lib/ttmetal/runtime.cpp | 6 ------ runtime/lib/ttnn/runtime.cpp | 12 ------------ runtime/tools/python/ttrt/common/callback.py | 6 ++++-- runtime/tools/python/ttrt/runtime/module.cpp | 6 ++++-- 8 files changed, 8 insertions(+), 44 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index f7fc5b563a..3ee36b861a 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -71,8 +71,6 @@ std::string getOpLocInfo(OpContext opContextHandle); Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); - using InputBuffer = std::tuple, std::shared_ptr<::tt::tt_metal::Event>>; diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 64ce245255..9b0b808da1 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -144,8 +144,6 @@ std::string getOpLocInfo(OpContext opContextHandle); Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputs); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 5d4b3bafba..e76a3bf63e 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -126,8 +126,6 @@ std::string getOpLocInfo(OpContext opContextHandle); Tensor getOpOutputTensor(OpContext opContextHandle, CallbackContext programContextHandle); -std::vector getTensorData(Tensor tensor); - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputs); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index a1861516d9..53d0d961bc 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -477,22 +477,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle, LOG_FATAL("runtime is not enabled"); } -std::vector getTensorData(Tensor tensor) { -#if defined(TT_RUNTIME_ENABLE_TTNN) - if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getTensorData(tensor); - } -#endif - -#if defined(TT_RUNTIME_ENABLE_TTMETAL) - if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getTensorData(tensor); - } -#endif - - LOG_FATAL("runtime is not enabled"); -} - std::vector submit(Device deviceHandle, Binary executableHandle, std::uint32_t programIndex, std::vector const &inputHandles) { diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index 15f7c92658..a660219259 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -357,12 +357,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle, return createNullTensor(); } -std::vector getTensorData(Tensor tensor) { - // Not implemented - LOG_WARNING("obtaining tensor data for metal runtime not implemented"); - return {}; -} - std::vector getDataBuffer(::tt::runtime::Tensor tensor) { LOG_WARNING("getDataBuffer not implemented for metal runtime"); return {}; diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index e7c3526ff6..242640f8b9 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -555,18 +555,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle, DeviceRuntime::TTNN); } -std::vector getTensorData(Tensor tensor) { - const ::ttnn::Tensor *nnTensor = - static_cast<::ttnn::Tensor *>(tensor.handle.get()); - if (nnTensor == nullptr) { - return {}; - } - - void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor); - return std::vector(static_cast(dataPtr), - static_cast(dataPtr) + nnTensor->volume()); -} - std::vector getDataBuffer(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); diff --git a/runtime/tools/python/ttrt/common/callback.py b/runtime/tools/python/ttrt/common/callback.py index 93a9af267b..782e9094d2 100644 --- a/runtime/tools/python/ttrt/common/callback.py +++ b/runtime/tools/python/ttrt/common/callback.py @@ -212,15 +212,17 @@ def golden(callback_runtime_config, binary, program_context, op_context): op_output_tensor = ttrt.runtime.get_op_output_tensor(op_context, program_context) - if len(op_output_tensor) == 0: + if op_output_tensor is None: logging.debug("Output tensor is empty - skipping golden comparison") return + rt_buffer = op_output_tensor.get_data_buffer() dtype = ttrt_datatype_to_torch_dtype(op_golden_tensor.dtype) + assert ttrt_datatype_to_torch_dtype(op_output_tensor.get_dtype()) == dtype golden_tensor_torch = torch.frombuffer(op_golden_tensor, dtype=dtype).flatten() - output_tensor_torch = torch.tensor(op_output_tensor, dtype=dtype).flatten() + output_tensor_torch = torch.frombuffer(rt_buffer, dtype=dtype).flatten() if callback_runtime_config.save_golden_tensors: torch.save( diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 724f29e333..2a54d4b411 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -196,9 +196,11 @@ PYBIND11_MODULE(_C, m) { tt::runtime::CallbackContext &programContextHandle) { tt::runtime::Tensor tensor = tt::runtime::getOpOutputTensor( opContextHandle, programContextHandle); - return tt::runtime::getTensorData(tensor); + return tensor.handle.get() == nullptr + ? std::nullopt + : std::optional(tensor); }, - "Get the input tensor of the op"); + "Get the output tensor of the op"); m.def("get_op_debug_str", &tt::runtime::getOpDebugString, "Get the debug string of the op"); m.def("get_op_loc_info", &tt::runtime::getOpLocInfo, From a90028d39d15063eea533c3ba4a29023fa2cbf53 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Thu, 6 Mar 2025 19:32:02 +0000 Subject: [PATCH 6/8] use correct sizes for memcpy --- runtime/lib/ttnn/runtime.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 242640f8b9..dce7dc768d 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -559,8 +559,7 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); void *dataPtr = nullptr; - size_t numBytes = getElementSize(tensor) * getVolume(tensor); - std::vector dataVec(numBytes); + std::vector dataVec(getElementSize(tensor) * getVolume(tensor)); // Need to `memcpy` in each case because the vector will go out of scope if we // wait until after the switch case @@ -570,7 +569,7 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::BFP_BFloat8: { @@ -578,49 +577,49 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::Float32: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::BFloat16: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::Int32: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::UInt32: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::UInt16: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } case target::DataType::UInt8: { auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); - std::memcpy(dataVec.data(), dataPtr, numBytes); + std::memcpy(dataVec.data(), dataPtr, dataVec.size()); return dataVec; } default: From ef8fd1d81e700cfad62c1a58e630dc4845df9f22 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Fri, 7 Mar 2025 21:52:35 +0000 Subject: [PATCH 7/8] Make API consistent by removing methods --- runtime/include/tt/runtime/runtime.h | 6 +++ runtime/include/tt/runtime/types.h | 8 ---- runtime/lib/runtime.cpp | 42 ++++++++++---------- runtime/tools/python/ttrt/runtime/module.cpp | 28 +++++++++---- 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index e76a3bf63e..042d7516d3 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -89,6 +89,12 @@ inline Tensor createTensor(Device device, Layout layout, } tt::target::DataType getTensorDataType(Tensor tensor); +std::vector getDataBuffer(Tensor tensor); +std::uint32_t getElementSize(Tensor tensor); +std::uint32_t getVolume(Tensor tensor); +std::vector getShape(Tensor tensor); +std::vector getStride(Tensor tensor); +TensorDesc getTensorDesc(Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h index 7ab6a7af81..e0f4e7b219 100644 --- a/runtime/include/tt/runtime/types.h +++ b/runtime/include/tt/runtime/types.h @@ -151,14 +151,6 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl { std::shared_ptr eventHandle, DeviceRuntime runtime) : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data), event(eventHandle, runtime) {} - - std::vector getDataBuffer(); - std::uint32_t getElementSize(); - std::uint32_t getVolume(); - std::vector getShape(); - std::vector getStride(); - target::DataType getDtype(); - TensorDesc getTensorDesc(); }; struct Layout : public detail::RuntimeCheckedObjectImpl { diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index 53d0d961bc..c20e43fdfa 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -515,106 +515,106 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif LOG_FATAL("runtime is not enabled"); } -std::vector Tensor::getDataBuffer() { +std::vector getDataBuffer(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getDataBuffer(*this); + return ::tt::runtime::ttnn::getDataBuffer(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getDataBuffer(*this); + return ::tt::runtime::ttmetal::getDataBuffer(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::vector Tensor::getShape() { +std::vector getShape(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getShape(*this); + return ::tt::runtime::ttnn::getShape(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getShape(*this); + return ::tt::runtime::ttmetal::getShape(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::vector Tensor::getStride() { +std::vector getStride(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getStride(*this); + return ::tt::runtime::ttnn::getStride(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getStride(*this); + return ::tt::runtime::ttmetal::getStride(t); } #endif LOG_FATAL("runtime is not enabled"); } -target::DataType Tensor::getDtype() { +target::DataType getDtype(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getTensorDataType(*this); + return ::tt::runtime::ttnn::getTensorDataType(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getTensorDataType(*this); + return ::tt::runtime::ttmetal::getTensorDataType(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::uint32_t Tensor::getElementSize() { +std::uint32_t getElementSize(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getElementSize(*this); + return ::tt::runtime::ttnn::getElementSize(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getElementSize(*this); + return ::tt::runtime::ttmetal::getElementSize(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::uint32_t Tensor::getVolume() { +std::uint32_t getVolume(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getVolume(*this); + return ::tt::runtime::ttnn::getVolume(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getVolume(*this); + return ::tt::runtime::ttmetal::getVolume(t); } #endif LOG_FATAL("runtime is not enabled"); } -TensorDesc Tensor::getTensorDesc() { +TensorDesc getTensorDesc(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getTensorDesc(*this); + return ::tt::runtime::ttnn::getTensorDesc(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getTensorDesc(*this); + return ::tt::runtime::ttmetal::getTensorDesc(t); } #endif LOG_FATAL("runtime is not enabled"); diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 2a54d4b411..7b11c36ea7 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -45,16 +45,30 @@ PYBIND11_MODULE(_C, m) { .def_readonly("item_size", &tt::runtime::TensorDesc::itemsize) .def_readonly("dtype", &tt::runtime::TensorDesc::dataType); py::class_(m, "Tensor") - .def("get_shape", &tt::runtime::Tensor::getShape) - .def("get_stride", &tt::runtime::Tensor::getStride) - .def("get_volume", &tt::runtime::Tensor::getVolume) - .def("get_dtype", &tt::runtime::Tensor::getDtype) - .def("get_element_size", &tt::runtime::Tensor::getElementSize) - .def("get_tensor_desc", &tt::runtime::Tensor::getTensorDesc) + .def("get_shape", + [](tt::runtime::Tensor self) { return tt::runtime::getShape(self); }) + .def( + "get_stride", + [](tt::runtime::Tensor self) { return tt::runtime::getStride(self); }) + .def( + "get_volume", + [](tt::runtime::Tensor self) { return tt::runtime::getVolume(self); }) + .def("get_dtype", + [](tt::runtime::Tensor self) { + return tt::runtime::getTensorDataType(self); + }) + .def("get_element_size", + [](tt::runtime::Tensor self) { + return tt::runtime::getElementSize(self); + }) + .def("get_tensor_desc", + [](tt::runtime::Tensor self) { + return tt::runtime::getTensorDesc(self); + }) .def( "get_data_buffer", [](tt::runtime::Tensor self) { - std::vector vec = self.getDataBuffer(); + std::vector vec = tt::runtime::getDataBuffer(self); return py::bytes(reinterpret_cast(vec.data()), vec.size()); }, From 804531039e5d1e3bfbf9725077f6636c02bff829 Mon Sep 17 00:00:00 2001 From: Collin Tod Date: Fri, 7 Mar 2025 22:04:08 +0000 Subject: [PATCH 8/8] Change API names in C++ for consistency --- runtime/include/tt/runtime/detail/ttmetal.h | 10 +++---- runtime/include/tt/runtime/detail/ttnn.h | 10 +++---- runtime/include/tt/runtime/runtime.h | 10 +++---- runtime/lib/runtime.cpp | 30 ++++++++++---------- runtime/lib/ttmetal/runtime.cpp | 10 +++---- runtime/lib/ttnn/runtime.cpp | 23 ++++++++------- runtime/tools/python/ttrt/runtime/module.cpp | 22 ++++++++------ 7 files changed, 60 insertions(+), 55 deletions(-) diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h index 3ee36b861a..05f968a900 100644 --- a/runtime/include/tt/runtime/detail/ttmetal.h +++ b/runtime/include/tt/runtime/detail/ttmetal.h @@ -30,11 +30,11 @@ inline Tensor createTensor(std::shared_ptr data, TensorDesc const &desc) { } tt::target::DataType getTensorDataType(Tensor tensor); -std::vector getDataBuffer(::tt::runtime::Tensor tensor); -std::vector getShape(::tt::runtime::Tensor tensor); -std::vector getStride(::tt::runtime::Tensor tensor); -std::uint32_t getElementSize(::tt::runtime::Tensor tensor); -std::uint32_t getVolume(::tt::runtime::Tensor tensor); +std::vector getTensorDataBuffer(::tt::runtime::Tensor tensor); +std::vector getTensorShape(::tt::runtime::Tensor tensor); +std::vector getTensorStride(::tt::runtime::Tensor tensor); +std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor); TensorDesc getTensorDesc(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h index 9b0b808da1..82de11cca9 100644 --- a/runtime/include/tt/runtime/detail/ttnn.h +++ b/runtime/include/tt/runtime/detail/ttnn.h @@ -94,11 +94,11 @@ inline Tensor createTensor(Device device, Layout layout, tt::target::DataType getTensorDataType(Tensor tensor); -std::vector getDataBuffer(::tt::runtime::Tensor tensor); -std::vector getShape(::tt::runtime::Tensor tensor); -std::vector getStride(::tt::runtime::Tensor tensor); -std::uint32_t getElementSize(::tt::runtime::Tensor tensor); -std::uint32_t getVolume(::tt::runtime::Tensor tensor); +std::vector getTensorDataBuffer(::tt::runtime::Tensor tensor); +std::vector getTensorShape(::tt::runtime::Tensor tensor); +std::vector getTensorStride(::tt::runtime::Tensor tensor); +std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor); +std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor); TensorDesc getTensorDesc(::tt::runtime::Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/include/tt/runtime/runtime.h b/runtime/include/tt/runtime/runtime.h index 042d7516d3..58d8c4f6d3 100644 --- a/runtime/include/tt/runtime/runtime.h +++ b/runtime/include/tt/runtime/runtime.h @@ -89,11 +89,11 @@ inline Tensor createTensor(Device device, Layout layout, } tt::target::DataType getTensorDataType(Tensor tensor); -std::vector getDataBuffer(Tensor tensor); -std::uint32_t getElementSize(Tensor tensor); -std::uint32_t getVolume(Tensor tensor); -std::vector getShape(Tensor tensor); -std::vector getStride(Tensor tensor); +std::vector getTensorDataBuffer(Tensor tensor); +std::uint32_t getTensorElementSize(Tensor tensor); +std::uint32_t getTensorVolume(Tensor tensor); +std::vector getTensorShape(Tensor tensor); +std::vector getTensorStride(Tensor tensor); TensorDesc getTensorDesc(Tensor tensor); size_t getNumAvailableDevices(); diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp index c20e43fdfa..1d8682e0a8 100644 --- a/runtime/lib/runtime.cpp +++ b/runtime/lib/runtime.cpp @@ -515,46 +515,46 @@ Event submit(Device deviceHandle, Binary executableHandle, #endif LOG_FATAL("runtime is not enabled"); } -std::vector getDataBuffer(::tt::runtime::Tensor t) { +std::vector getTensorDataBuffer(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getDataBuffer(t); + return ::tt::runtime::ttnn::getTensorDataBuffer(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getDataBuffer(t); + return ::tt::runtime::ttmetal::getTensorDataBuffer(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::vector getShape(::tt::runtime::Tensor t) { +std::vector getTensorShape(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getShape(t); + return ::tt::runtime::ttnn::getTensorShape(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getShape(t); + return ::tt::runtime::ttmetal::getTensorShape(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::vector getStride(::tt::runtime::Tensor t) { +std::vector getTensorStride(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getStride(t); + return ::tt::runtime::ttnn::getTensorStride(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getStride(t); + return ::tt::runtime::ttmetal::getTensorStride(t); } #endif LOG_FATAL("runtime is not enabled"); @@ -575,31 +575,31 @@ target::DataType getDtype(::tt::runtime::Tensor t) { LOG_FATAL("runtime is not enabled"); } -std::uint32_t getElementSize(::tt::runtime::Tensor t) { +std::uint32_t getTensorElementSize(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getElementSize(t); + return ::tt::runtime::ttnn::getTensorElementSize(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getElementSize(t); + return ::tt::runtime::ttmetal::getTensorElementSize(t); } #endif LOG_FATAL("runtime is not enabled"); } -std::uint32_t getVolume(::tt::runtime::Tensor t) { +std::uint32_t getTensorVolume(::tt::runtime::Tensor t) { #if defined(TT_RUNTIME_ENABLE_TTNN) if (getCurrentRuntime() == DeviceRuntime::TTNN) { - return ::tt::runtime::ttnn::getVolume(t); + return ::tt::runtime::ttnn::getTensorVolume(t); } #endif #if defined(TT_RUNTIME_ENABLE_TTMETAL) if (getCurrentRuntime() == DeviceRuntime::TTMetal) { - return ::tt::runtime::ttmetal::getVolume(t); + return ::tt::runtime::ttmetal::getTensorVolume(t); } #endif LOG_FATAL("runtime is not enabled"); diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp index a660219259..bc838efda7 100644 --- a/runtime/lib/ttmetal/runtime.cpp +++ b/runtime/lib/ttmetal/runtime.cpp @@ -357,27 +357,27 @@ Tensor getOpOutputTensor(OpContext opContextHandle, return createNullTensor(); } -std::vector getDataBuffer(::tt::runtime::Tensor tensor) { +std::vector getTensorDataBuffer(::tt::runtime::Tensor tensor) { LOG_WARNING("getDataBuffer not implemented for metal runtime"); return {}; } -std::vector getShape(::tt::runtime::Tensor tensor) { +std::vector getTensorShape(::tt::runtime::Tensor tensor) { LOG_WARNING("getShape not implemented for metal runtime"); return {}; } -std::vector getStride(::tt::runtime::Tensor tensor) { +std::vector getTensorStride(::tt::runtime::Tensor tensor) { LOG_WARNING("getStride not implemented for metal runtime"); return {}; } -std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { +std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) { LOG_WARNING("getElementSize not implemented for metal runtime"); return 0; } -std::uint32_t getVolume(::tt::runtime::Tensor tensor) { +std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) { LOG_WARNING("getVolume not implemented for metal runtime"); return 0; } diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index dce7dc768d..7e97cdc035 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -555,17 +555,18 @@ Tensor getOpOutputTensor(OpContext opContextHandle, DeviceRuntime::TTNN); } -std::vector getDataBuffer(::tt::runtime::Tensor tensor) { +std::vector getTensorDataBuffer(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); void *dataPtr = nullptr; - std::vector dataVec(getElementSize(tensor) * getVolume(tensor)); + std::vector dataVec(getTensorElementSize(tensor) * + getTensorVolume(tensor)); // Need to `memcpy` in each case because the vector will go out of scope if we // wait until after the switch case switch (getTensorDataType(tensor)) { case target::DataType::BFP_BFloat4: { - dataVec.resize(sizeof(float) * getVolume(tensor)); + dataVec.resize(sizeof(float) * getTensorVolume(tensor)); auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); @@ -573,7 +574,7 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { return dataVec; } case target::DataType::BFP_BFloat8: { - dataVec.resize(sizeof(float) * getVolume(tensor)); + dataVec.resize(sizeof(float) * getTensorVolume(tensor)); auto vec = ttnnTensor.to_vector(); dataPtr = vec.data(); LOG_ASSERT(dataPtr != nullptr); @@ -629,7 +630,7 @@ std::vector getDataBuffer(::tt::runtime::Tensor tensor) { } } -std::vector getShape(::tt::runtime::Tensor tensor) { +std::vector getTensorShape(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); std::vector shape; @@ -639,7 +640,7 @@ std::vector getShape(::tt::runtime::Tensor tensor) { return shape; } -std::vector getStride(::tt::runtime::Tensor tensor) { +std::vector getTensorStride(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); std::vector stride; @@ -649,13 +650,13 @@ std::vector getStride(::tt::runtime::Tensor tensor) { return stride; } -std::uint32_t getElementSize(::tt::runtime::Tensor tensor) { +std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); return ttnnTensor.element_size(); } -std::uint32_t getVolume(::tt::runtime::Tensor tensor) { +std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) { const ::ttnn::Tensor &ttnnTensor = tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN); return ttnnTensor.volume(); @@ -664,9 +665,9 @@ std::uint32_t getVolume(::tt::runtime::Tensor tensor) { TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) { TensorDesc desc; desc.dataType = getTensorDataType(tensor); - desc.itemsize = getElementSize(tensor); - desc.stride = getStride(tensor); - desc.shape = getShape(tensor); + desc.itemsize = getTensorElementSize(tensor); + desc.stride = getTensorStride(tensor); + desc.shape = getTensorShape(tensor); return desc; } diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp index 7b11c36ea7..7763da910b 100644 --- a/runtime/tools/python/ttrt/runtime/module.cpp +++ b/runtime/tools/python/ttrt/runtime/module.cpp @@ -46,20 +46,24 @@ PYBIND11_MODULE(_C, m) { .def_readonly("dtype", &tt::runtime::TensorDesc::dataType); py::class_(m, "Tensor") .def("get_shape", - [](tt::runtime::Tensor self) { return tt::runtime::getShape(self); }) - .def( - "get_stride", - [](tt::runtime::Tensor self) { return tt::runtime::getStride(self); }) - .def( - "get_volume", - [](tt::runtime::Tensor self) { return tt::runtime::getVolume(self); }) + [](tt::runtime::Tensor self) { + return tt::runtime::getTensorShape(self); + }) + .def("get_stride", + [](tt::runtime::Tensor self) { + return tt::runtime::getTensorStride(self); + }) + .def("get_volume", + [](tt::runtime::Tensor self) { + return tt::runtime::getTensorVolume(self); + }) .def("get_dtype", [](tt::runtime::Tensor self) { return tt::runtime::getTensorDataType(self); }) .def("get_element_size", [](tt::runtime::Tensor self) { - return tt::runtime::getElementSize(self); + return tt::runtime::getTensorElementSize(self); }) .def("get_tensor_desc", [](tt::runtime::Tensor self) { @@ -68,7 +72,7 @@ PYBIND11_MODULE(_C, m) { .def( "get_data_buffer", [](tt::runtime::Tensor self) { - std::vector vec = tt::runtime::getDataBuffer(self); + std::vector vec = tt::runtime::getTensorDataBuffer(self); return py::bytes(reinterpret_cast(vec.data()), vec.size()); },