diff --git a/runtime/include/tt/runtime/detail/ttmetal.h b/runtime/include/tt/runtime/detail/ttmetal.h
index d700cd4fc7..d7ff26fc40 100644
--- a/runtime/include/tt/runtime/detail/ttmetal.h
+++ b/runtime/include/tt/runtime/detail/ttmetal.h
@@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
 }
 
 tt::target::DataType getTensorDataType(Tensor tensor);
+std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
+std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
+std::uint32_t getVolume(::tt::runtime::Tensor tensor);
+target::DataType getDtype(::tt::runtime::Tensor tensor);
 
 size_t getNumAvailableDevices();
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index ac01626459..cfea0d1344 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout,
 
 tt::target::DataType getTensorDataType(Tensor tensor);
 
+std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
+std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
+std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
+std::uint32_t getVolume(::tt::runtime::Tensor tensor);
+target::DataType getDtype(::tt::runtime::Tensor tensor);
+
 size_t getNumAvailableDevices();
 
 Device
diff --git a/runtime/include/tt/runtime/types.h b/runtime/include/tt/runtime/types.h
index e0f4e7b219..436dfc07b9 100644
--- a/runtime/include/tt/runtime/types.h
+++ b/runtime/include/tt/runtime/types.h
@@ -151,6 +151,13 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl {
          std::shared_ptr<void> eventHandle, DeviceRuntime runtime)
       : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
         event(eventHandle, runtime) {}
+
+  std::vector<std::byte> getDataBuffer();
+  std::uint32_t getElementSize();
+  std::uint32_t getVolume();
+  std::vector<std::uint32_t> getShape();
+  std::vector<std::uint32_t> getStride();
+  target::DataType getDtype();
 };
 
 struct Layout : public detail::RuntimeCheckedObjectImpl {
diff --git a/runtime/lib/runtime.cpp b/runtime/lib/runtime.cpp
index 9023f99087..38560eba62 100644
--- a/runtime/lib/runtime.cpp
+++ b/runtime/lib/runtime.cpp
@@ -531,4 +531,94 @@ Event submit(Device deviceHandle, Binary executableHandle,
 #endif
   LOG_FATAL("runtime is not enabled");
 }
+
+std::vector<std::byte> Tensor::getDataBuffer() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getDataBuffer(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getDataBuffer(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::vector<std::uint32_t> Tensor::getShape() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getShape(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getShape(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::vector<std::uint32_t> Tensor::getStride() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getStride(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getStride(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+target::DataType Tensor::getDtype() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getDtype(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getDtype(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::uint32_t Tensor::getElementSize() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getElementSize(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getElementSize(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
+std::uint32_t Tensor::getVolume() {
+#if defined(TT_RUNTIME_ENABLE_TTNN)
+  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
+    return ::tt::runtime::ttnn::getVolume(*this);
+  }
+#endif
+
+#if defined(TT_RUNTIME_ENABLE_TTMETAL)
+  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
+    return ::tt::runtime::ttmetal::getVolume(*this);
+  }
+#endif
+  LOG_FATAL("runtime is not enabled");
+}
+
 } // namespace tt::runtime
diff --git a/runtime/lib/ttmetal/runtime.cpp b/runtime/lib/ttmetal/runtime.cpp
index 255ef9da89..fe6050c4e6 100644
--- a/runtime/lib/ttmetal/runtime.cpp
+++ b/runtime/lib/ttmetal/runtime.cpp
@@ -363,4 +363,33 @@ std::vector<float> getTensorData(Tensor tensor) {
   return {};
 }
 
+std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getDataBuffer not implemented for metal runtime");
+  return {};
+}
+
+std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getShape not implemented for metal runtime");
+  return {};
+}
+
+std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getStride not implemented for metal runtime");
+  return {};
+}
+
+std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getElementSize not implemented for metal runtime");
+  return 0;
+}
+
+std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getVolume not implemented for metal runtime");
+  return 0;
+}
+
+target::DataType getDtype(::tt::runtime::Tensor tensor) {
+  LOG_WARNING("getDtype not implemented for metal runtime");
+  return {};
+}
+
 } // namespace tt::runtime::ttmetal
diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp
index f77dff0125..b92aaaf409 100644
--- a/runtime/lib/ttnn/runtime.cpp
+++ b/runtime/lib/ttnn/runtime.cpp
@@ -567,6 +567,83 @@ std::vector<float> getTensorData(Tensor tensor) {
                             static_cast<float *>(dataPtr) + nnTensor->volume());
 }
 
+std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  void *dataPtr = nullptr;
+  size_t numBytes = getElementSize(tensor) * getVolume(tensor);
+  std::vector<std::byte> dataVec(numBytes);
+
+  // Need to `memcpy` in each case because the vector will go out of scope if
+  // we wait until after the switch case
+  switch (getDtype(tensor)) {
+  case target::DataType::BFP_BFloat4:
+  case target::DataType::BFP_BFloat8:
+  case target::DataType::Float32:
+    dataPtr = ttnnTensor->to_vector<float>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  case target::DataType::BFloat16:
+    dataPtr = ttnnTensor->to_vector<bfloat16>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  case target::DataType::Int32:
+    dataPtr = ttnnTensor->to_vector<std::int32_t>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  case target::DataType::UInt32:
+    dataPtr = ttnnTensor->to_vector<std::uint32_t>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  case target::DataType::UInt16:
+    dataPtr = ttnnTensor->to_vector<std::uint16_t>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  case target::DataType::UInt8:
+    dataPtr = ttnnTensor->to_vector<std::uint8_t>().data();
+    assert(dataPtr != nullptr);
+    std::memcpy(dataVec.data(), dataPtr, numBytes);
+    return dataVec;
+  default:
+    LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning "
+              "empty data vector");
+    return {};
+  }
+}
+
+std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  std::vector<std::uint32_t> shape(ttnnTensor->logical_shape().cbegin(),
+                                   ttnnTensor->logical_shape().cend());
+  return shape;
+}
+
+std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  std::vector<std::uint32_t> stride(ttnnTensor->strides().cbegin(),
+                                    ttnnTensor->strides().cend());
+  return stride;
+}
+
+std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  return ttnnTensor->element_size();
+}
+
+std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  return ttnnTensor->volume();
+}
+
+target::DataType getDtype(::tt::runtime::Tensor tensor) {
+  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
+  return utils::fromTTNNDataType(ttnnTensor->dtype());
+}
+
 std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
                            std::uint32_t programIndex,
                            std::vector<Tensor> const &inputHandles) {
diff --git a/runtime/test/python/ttnn/test_runtime_api.py b/runtime/test/python/ttnn/test_runtime_api.py
index fc4cc9be79..98b86b0bf8 100644
--- a/runtime/test/python/ttnn/test_runtime_api.py
+++ b/runtime/test/python/ttnn/test_runtime_api.py
@@ -10,6 +10,34 @@
 from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc
 
 
+@pytest.mark.parametrize("shape", [(64, 128)])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
+def test_tensor_buffer_api(shape, dtype):
+    torch_tensor = torch.randn(shape, dtype=dtype)
+    runtime_dtype = Binary.Program.to_data_type(dtype)
+    rt_tensor = ttrt.runtime.create_tensor(
+        torch_tensor.data_ptr(),
+        list(torch_tensor.shape),
+        list(torch_tensor.stride()),
+        torch_tensor.element_size(),
+        runtime_dtype,
+    )
+    rt_shape = rt_tensor.get_shape()
+    rt_elem_size = rt_tensor.get_element_size()
+    rt_vol = rt_tensor.get_volume()
+    rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype())
+    rt_bytes = rt_tensor.get_data_buffer()
+
+    # Check that nothing about the underlying data changed across the pybind boundary
+    assert list(rt_shape) == list(shape)
+    assert rt_elem_size == torch_tensor.element_size()
+    assert rt_vol == torch_tensor.numel()
+    assert rt_dtype == torch_tensor.dtype
+    assert len(rt_bytes) == rt_vol * rt_elem_size
+    reconstructed_tensor = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape)
+    assert torch.equal(torch_tensor, reconstructed_tensor)
+
+
 @pytest.mark.parametrize("shape", [(64, 128)])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
 def test_to_layout(helper: Helper, shape, dtype, request):
diff --git a/runtime/tools/python/ttrt/common/util.py b/runtime/tools/python/ttrt/common/util.py
index 1da498f465..dfe838a44f 100644
--- a/runtime/tools/python/ttrt/common/util.py
+++ b/runtime/tools/python/ttrt/common/util.py
@@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype:
         return torch.uint16
     elif dtype == DataType.UInt8:
         return torch.uint8
+    elif dtype == DataType.BFloat16:
+        return torch.bfloat16
     else:
-        raise ValueError("Only F32 and unsigned integers are supported in the runtime")
+        raise ValueError(
+            "Only F32, BF16, and unsigned integers are supported in the runtime"
+        )
 
 
 def get_ttrt_metal_home_path():
diff --git a/runtime/tools/python/ttrt/runtime/module.cpp b/runtime/tools/python/ttrt/runtime/module.cpp
index 72a15dabb1..09817037c1 100644
--- a/runtime/tools/python/ttrt/runtime/module.cpp
+++ b/runtime/tools/python/ttrt/runtime/module.cpp
@@ -39,7 +39,17 @@ PYBIND11_MODULE(_C, m) {
       .def("get_memory_view", &tt::runtime::detail::getMemoryView,
           py::arg("device_id") = 0);
   py::class_<tt::runtime::Event>(m, "Event");
-  py::class_<tt::runtime::Tensor>(m, "Tensor");
+  py::class_<tt::runtime::Tensor>(m, "Tensor")
+      .def("get_shape", &tt::runtime::Tensor::getShape)
+      .def("get_stride", &tt::runtime::Tensor::getStride)
+      .def("get_volume", &tt::runtime::Tensor::getVolume)
+      .def("get_dtype", &tt::runtime::Tensor::getDtype)
+      .def("get_element_size", &tt::runtime::Tensor::getElementSize)
+      .def("get_data_buffer", [](tt::runtime::Tensor self) {
+        std::vector<std::byte> vec = self.getDataBuffer();
+        return py::bytes(reinterpret_cast<const char *>(vec.data()),
+                         vec.size());
+      });
   py::class_<tt::runtime::Layout>(m, "Layout");
   py::class_<tt::runtime::OpContext>(m, "OpContext");
   py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
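
Note on the getDataBuffer switch in runtime/lib/ttnn/runtime.cpp: to_vector<T>() returns a std::vector<T> by value, so taking .data() from that temporary and copying on the following statement relies on the temporary's lifetime, which ends at the end of the assignment expression. Below is a minimal sketch of a lifetime-safe variant; it is not part of the patch, the copyTensorBytes helper name is hypothetical, and it assumes the same translation unit, includes (<cassert>, <cstring>, <vector>, <cstddef>), and types as runtime/lib/ttnn/runtime.cpp.

// Sketch only: bind the result of to_vector<T>() to a named local so the
// host-side buffer stays alive for the entire memcpy.
template <typename T>
static void copyTensorBytes(::ttnn::Tensor *ttnnTensor,
                            std::vector<std::byte> &dataVec, size_t numBytes) {
  std::vector<T> typed = ttnnTensor->to_vector<T>(); // owns the host copy
  assert(typed.data() != nullptr);
  std::memcpy(dataVec.data(), typed.data(), numBytes);
}

With a helper like this, each case of the switch would reduce to something like copyTensorBytes<float>(ttnnTensor, dataVec, numBytes); return dataVec;, keeping the per-dtype dispatch while avoiding any question about the temporary's lifetime.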