Skip to content

Commit

Permalink
Implement runtime tensor desc API
Browse files Browse the repository at this point in the history
This change implements an API on runtime tensors to expose the tensor
metadata and contents of the underlying TTNN tensor. This functionality
is pybound as well, allowing for easy casting to torch tensors. Please
see `runtime/test/python/ttnn/test_runtime_api.py` for an example.

NOTE: This is not the most efficient implementation possible, as there
is effectively a double copy to get the data into a pybindable row-major
format. There is probably some trickery that can be done later to avoid
this, but I would like to get this functionality out ASAP and avoid
premature optimization.

Closes #1957
  • Loading branch information
ctodTT committed Mar 5, 2025
1 parent ea451a3 commit 2eef566
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 2 deletions.
6 changes: 6 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
}

tt::target::DataType getTensorDataType(Tensor tensor);
// Reflection accessors over a runtime Tensor (metal backend).
// NOTE: these are currently unimplemented stubs in
// runtime/lib/ttmetal/runtime.cpp — they log a warning and return
// empty/zero values.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
target::DataType getDtype(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Expand Down
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout,

tt::target::DataType getTensorDataType(Tensor tensor);

// Reflection accessors over a runtime Tensor backed by a ::ttnn::Tensor.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor); // row-major copy of contents
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);  // logical shape
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor); // bytes per element
std::uint32_t getVolume(::tt::runtime::Tensor tensor);      // number of elements
target::DataType getDtype(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Device
Expand Down
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> eventHandle, DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(eventHandle, runtime) {}

  // Reflection API over the tensor's metadata and contents. Each call
  // dispatches to the active runtime's implementation (see
  // runtime/lib/runtime.cpp); the metal backend currently stubs these out.
  std::vector<std::byte> getDataBuffer(); // contents as raw bytes
  std::uint32_t getElementSize();         // bytes per element
  std::uint32_t getVolume();              // number of elements
  std::vector<std::uint32_t> getShape();
  std::vector<std::uint32_t> getStride();
  target::DataType getDtype();
};

struct Layout : public detail::RuntimeCheckedObjectImpl {
Expand Down
90 changes: 90 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,94 @@ Event submit(Device deviceHandle, Binary executableHandle,
#endif
LOG_FATAL("runtime is not enabled");
}
/// Returns a copy of the tensor's contents as raw bytes, dispatching to the
/// backend matching the active runtime. The ttmetal backend currently
/// returns an empty vector (unimplemented stub). Aborts via LOG_FATAL when
/// no runtime is compiled in.
std::vector<std::byte> Tensor::getDataBuffer() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getDataBuffer(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getDataBuffer(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

/// Returns the tensor's shape, one extent per dimension, dispatching to the
/// backend matching the active runtime. Aborts via LOG_FATAL when no runtime
/// is compiled in.
std::vector<std::uint32_t> Tensor::getShape() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getShape(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getShape(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

/// Returns the tensor's strides, one entry per dimension, dispatching to the
/// backend matching the active runtime. Aborts via LOG_FATAL when no runtime
/// is compiled in.
std::vector<std::uint32_t> Tensor::getStride() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getStride(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getStride(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

/// Returns the tensor's element datatype, dispatching to the backend
/// matching the active runtime. NOTE(review): the ttmetal stub returns a
/// value-initialized DataType (numerically 0), which may alias a valid
/// dtype — callers should not rely on it there. Aborts via LOG_FATAL when
/// no runtime is compiled in.
target::DataType Tensor::getDtype() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getDtype(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getDtype(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

/// Returns the size in bytes of one tensor element, dispatching to the
/// backend matching the active runtime (the ttmetal stub returns 0). Aborts
/// via LOG_FATAL when no runtime is compiled in.
std::uint32_t Tensor::getElementSize() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getElementSize(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getElementSize(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

/// Returns the total number of elements in the tensor, dispatching to the
/// backend matching the active runtime (the ttmetal stub returns 0). Aborts
/// via LOG_FATAL when no runtime is compiled in.
std::uint32_t Tensor::getVolume() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {
    return ::tt::runtime::ttnn::getVolume(*this);
  }
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
    return ::tt::runtime::ttmetal::getVolume(*this);
  }
#endif
  LOG_FATAL("runtime is not enabled");
}

} // namespace tt::runtime
29 changes: 29 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,33 @@ std::vector<float> getTensorData(Tensor tensor) {
return {};
}

// Stub: not yet supported on the metal runtime; warns and returns an empty
// buffer.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getDataBuffer not implemented for metal runtime");
  return {};
}

// Stub: not yet supported on the metal runtime; warns and returns an empty
// shape vector.
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getShape not implemented for metal runtime");
  return {};
}

// Stub: not yet supported on the metal runtime; warns and returns an empty
// stride vector.
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getStride not implemented for metal runtime");
  return {};
}

// Stub: not yet supported on the metal runtime; warns and returns 0.
std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getElementSize not implemented for metal runtime");
  return 0;
}

// Stub: not yet supported on the metal runtime; warns and returns 0.
std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getVolume not implemented for metal runtime");
  return 0;
}
// Stub: not yet supported on the metal runtime; warns and returns a
// value-initialized DataType. NOTE(review): `{}` is numerically 0, which may
// coincide with a real DataType enumerator — callers cannot distinguish this
// sentinel from a valid dtype; confirm whether a dedicated "invalid" value
// exists.
target::DataType getDtype(::tt::runtime::Tensor tensor) {
  LOG_WARNING("getDtype not implemented for metal runtime");
  return {};
}

} // namespace tt::runtime::ttmetal
77 changes: 77 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,83 @@ std::vector<float> getTensorData(Tensor tensor) {
static_cast<float *>(dataPtr) + nnTensor->volume());
}

/// Returns a copy of the underlying TTNN tensor's contents as raw bytes in
/// row-major order, or an empty vector for unsupported datatypes.
///
/// `to_vector<T>()` materializes the tensor as a host-side vector; we copy
/// its bytes into the result. BUG FIX: the previous implementation stored
/// `to_vector<T>().data()` in a pointer and performed the `memcpy` in a
/// *separate statement* — by then the temporary vector had been destroyed,
/// so the copy read a dangling pointer (undefined behavior). The helper
/// below performs the copy while the host vector is still alive (a function
/// argument's temporary lives until the end of the full expression).
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const size_t numBytes = getElementSize(tensor) * getVolume(tensor);
  std::vector<std::byte> dataVec(numBytes);

  // Copies `numBytes` from the (still-alive) host vector into dataVec.
  auto copyBytes = [&](const auto &hostVec) {
    assert(hostVec.data() != nullptr || numBytes == 0);
    assert(hostVec.size() * sizeof(hostVec[0]) >= numBytes);
    std::memcpy(dataVec.data(), hostVec.data(), numBytes);
  };

  switch (getDtype(tensor)) {
  // NOTE(review): block-float formats are materialized through float on the
  // host; this assumes getElementSize() reports the host (float) element
  // size for these dtypes — confirm.
  case target::DataType::BFP_BFloat4:
  case target::DataType::BFP_BFloat8:
  case target::DataType::Float32:
    copyBytes(ttnnTensor->to_vector<float>());
    return dataVec;
  case target::DataType::BFloat16:
    copyBytes(ttnnTensor->to_vector<bfloat16>());
    return dataVec;
  case target::DataType::Int32:
    copyBytes(ttnnTensor->to_vector<std::int32_t>());
    return dataVec;
  case target::DataType::UInt32:
    copyBytes(ttnnTensor->to_vector<std::uint32_t>());
    return dataVec;
  case target::DataType::UInt16:
    copyBytes(ttnnTensor->to_vector<std::uint16_t>());
    return dataVec;
  case target::DataType::UInt8:
    copyBytes(ttnnTensor->to_vector<std::uint8_t>());
    return dataVec;
  default:
    LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning "
              "empty data vector");
    return {};
  }
}

// Returns the logical shape of the underlying TTNN tensor, one extent per
// dimension.
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
  const auto *underlying = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const auto &logicalShape = underlying->logical_shape();
  return std::vector<std::uint32_t>(logicalShape.cbegin(),
                                    logicalShape.cend());
}

// Returns the strides of the underlying TTNN tensor (as reported by
// ttnn::Tensor::strides()), one entry per dimension.
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
  const auto *underlying = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const auto &strides = underlying->strides();
  return std::vector<std::uint32_t>(strides.cbegin(), strides.cend());
}

// Returns the size in bytes of a single element of the underlying TTNN
// tensor.
std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
  const auto &underlying =
      *static_cast<::ttnn::Tensor *>(tensor.handle.get());
  return underlying.element_size();
}

// Returns the total number of elements in the underlying TTNN tensor.
std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
  const auto &underlying =
      *static_cast<::ttnn::Tensor *>(tensor.handle.get());
  return underlying.volume();
}

// Returns the runtime target::DataType equivalent of the underlying TTNN
// tensor's dtype.
target::DataType getDtype(::tt::runtime::Tensor tensor) {
  const auto &underlying =
      *static_cast<::ttnn::Tensor *>(tensor.handle.get());
  return utils::fromTTNNDataType(underlying.dtype());
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles) {
Expand Down
28 changes: 28 additions & 0 deletions runtime/test/python/ttnn/test_runtime_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_tensor_buffer_api(shape, dtype):
    """Round-trip a torch tensor through the runtime tensor-desc API.

    Wraps a torch tensor's memory in a runtime tensor, reads back its
    shape/element size/volume/dtype/raw bytes via the new accessors, and
    checks the reconstructed torch tensor is identical to the original.
    """
    torch_tensor = torch.randn(shape, dtype=dtype)
    runtime_dtype = Binary.Program.to_data_type(dtype)
    rt_tensor = ttrt.runtime.create_tensor(
        torch_tensor.data_ptr(),
        list(torch_tensor.shape),
        list(torch_tensor.stride()),
        torch_tensor.element_size(),
        runtime_dtype,
    )
    rt_shape = rt_tensor.get_shape()
    rt_elem_size = rt_tensor.get_element_size()
    rt_vol = rt_tensor.get_volume()
    rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype())
    rt_bytes = rt_tensor.get_data_buffer()

    # Verify nothing about the tensor was altered while crossing the pybind
    # boundary, then rebuild a torch tensor from the raw bytes and compare.
    assert list(rt_shape) == list(shape)
    assert rt_elem_size == torch_tensor.element_size()
    assert rt_vol == torch_tensor.numel()
    assert rt_dtype == torch_tensor.dtype
    assert len(rt_bytes) == rt_vol * rt_elem_size
    reconstructed_tensor = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape)
    assert torch.equal(torch_tensor, reconstructed_tensor)


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_to_layout(helper: Helper, shape, dtype, request):
Expand Down
6 changes: 5 additions & 1 deletion runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype:
return torch.uint16
elif dtype == DataType.UInt8:
return torch.uint8
elif dtype == DataType.BFloat16:
return torch.bfloat16
else:
raise ValueError("Only F32 and unsigned integers are supported in the runtime")
raise ValueError(
"Only F32, BF16, and unsigned integers are supported in the runtime"
)


def get_ttrt_metal_home_path():
Expand Down
12 changes: 11 additions & 1 deletion runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ PYBIND11_MODULE(_C, m) {
.def("get_memory_view", &tt::runtime::detail::getMemoryView,
py::arg("device_id") = 0);
py::class_<tt::runtime::Event>(m, "Event");
py::class_<tt::runtime::Tensor>(m, "Tensor");
py::class_<tt::runtime::Tensor>(m, "Tensor")
.def("get_shape", &tt::runtime::Tensor::getShape)
.def("get_stride", &tt::runtime::Tensor::getStride)
.def("get_volume", &tt::runtime::Tensor::getVolume)
.def("get_dtype", &tt::runtime::Tensor::getDtype)
.def("get_element_size", &tt::runtime::Tensor::getElementSize)
.def("get_data_buffer", [](tt::runtime::Tensor self) {
std::vector<std::byte> vec = self.getDataBuffer();
return py::bytes(reinterpret_cast<const char *>(vec.data()),
vec.size());
});
py::class_<tt::runtime::Layout>(m, "Layout");
py::class_<tt::runtime::OpContext>(m, "OpContext");
py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
Expand Down

0 comments on commit 2eef566

Please sign in to comment.