Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement runtime tensor desc API #2370

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
}

tt::target::DataType getTensorDataType(Tensor tensor);
// Tensor-descriptor accessors for the TTMetal runtime backend.
// NOTE: currently stubbed in runtime/lib/ttmetal/runtime.cpp — each logs a
// warning and returns an empty/zero value.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
// NOTE(review): appears to overlap with getTensorDataType above — consider
// consolidating the two APIs.
target::DataType getDtype(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Expand Down
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout,

tt::target::DataType getTensorDataType(Tensor tensor);

// Tensor-descriptor accessors for the TTNN runtime backend: raw host bytes,
// logical shape, strides, element size (bytes), element count, and dtype.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
// NOTE(review): appears to overlap with getTensorDataType above — consider
// consolidating the two APIs.
target::DataType getDtype(::tt::runtime::Tensor tensor);
Copy link
Contributor

@jnie-TT jnie-TT Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a TensorDesc type under runtime/include/tt/runtime/types.h:

struct TensorDesc {
  std::vector<std::uint32_t> shape;
  std::vector<std::uint32_t> stride;
  std::uint32_t itemsize;
  ::tt::target::DataType dataType;
};

Was wondering if we could merge these APIs into one and return the TensorDesc directly. It shouldn't be hard to bind this structure, and it'll have all the information in one place.

Also on a side note I think getDtype and getTensorDataType do the same thing here so is probably redundant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, yeah I can bind and have it return that object if you want!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would be awesome!

Copy link
Contributor Author

@ctodTT ctodTT Mar 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is now getTensorDesc() & get_tensor_desc() for this exact purpose!


size_t getNumAvailableDevices();

Device
Expand Down
7 changes: 7 additions & 0 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> eventHandle, DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(eventHandle, runtime) {}

// Descriptor accessors; each dispatches to the active device runtime
// (TTNN or TTMetal). Implementations live in runtime/lib/runtime.cpp.
std::vector<std::byte> getDataBuffer();
std::uint32_t getElementSize();
std::uint32_t getVolume();
std::vector<std::uint32_t> getShape();
std::vector<std::uint32_t> getStride();
target::DataType getDtype();
};

struct Layout : public detail::RuntimeCheckedObjectImpl {
Expand Down
90 changes: 90 additions & 0 deletions runtime/lib/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,94 @@ Event submit(Device deviceHandle, Binary executableHandle,
#endif
LOG_FATAL("runtime is not enabled");
}
// Raw host bytes of this tensor, fetched from the active runtime backend.
// Backends are compiled in only when their TT_RUNTIME_ENABLE_* macro is
// defined; fatal if no enabled runtime matches.
std::vector<std::byte> Tensor::getDataBuffer() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getDataBuffer(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getDataBuffer(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

// Logical shape of this tensor, dispatched to the active runtime backend.
// Fatal if no enabled runtime matches.
std::vector<std::uint32_t> Tensor::getShape() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getShape(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getShape(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

// Per-dimension strides of this tensor, dispatched to the active runtime
// backend. Fatal if no enabled runtime matches.
std::vector<std::uint32_t> Tensor::getStride() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getStride(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getStride(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

// DataType of this tensor, dispatched to the active runtime backend.
// Fatal if no enabled runtime matches.
target::DataType Tensor::getDtype() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getDtype(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getDtype(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

// Size in bytes of one element of this tensor, dispatched to the active
// runtime backend. Fatal if no enabled runtime matches.
std::uint32_t Tensor::getElementSize() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getElementSize(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getElementSize(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

// Total number of elements in this tensor, dispatched to the active runtime
// backend. Fatal if no enabled runtime matches.
std::uint32_t Tensor::getVolume() {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getVolume(*this);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getVolume(*this);
}
#endif
LOG_FATAL("runtime is not enabled");
}

} // namespace tt::runtime
29 changes: 29 additions & 0 deletions runtime/lib/ttmetal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,33 @@ std::vector<float> getTensorData(Tensor tensor) {
return {};
}

// TODO: not implemented for the metal runtime — warns and returns an empty
// buffer.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
LOG_WARNING("getDataBuffer not implemented for metal runtime");
return {};
}

// TODO: not implemented for the metal runtime — warns and returns an empty
// shape.
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
LOG_WARNING("getShape not implemented for metal runtime");
return {};
}

// TODO: not implemented for the metal runtime — warns and returns empty
// strides.
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
LOG_WARNING("getStride not implemented for metal runtime");
return {};
}

// TODO: not implemented for the metal runtime — warns and returns 0.
std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
LOG_WARNING("getElementSize not implemented for metal runtime");
return 0;
}

// TODO: not implemented for the metal runtime — warns and returns 0.
std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
LOG_WARNING("getVolume not implemented for metal runtime");
return 0;
}
// TODO: not implemented for the metal runtime — warns and returns a
// value-initialized DataType.
target::DataType getDtype(::tt::runtime::Tensor tensor) {
LOG_WARNING("getDtype not implemented for metal runtime");
return {};
}

} // namespace tt::runtime::ttmetal
77 changes: 77 additions & 0 deletions runtime/lib/ttnn/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,83 @@ std::vector<float> getTensorData(Tensor tensor) {
static_cast<float *>(dataPtr) + nnTensor->volume());
}

// Copies the tensor's host-side data into a byte vector of
// getElementSize(tensor) * getVolume(tensor) bytes. Returns an empty vector
// (after logging an error) for unsupported dtypes.
std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const size_t numBytes = getElementSize(tensor) * getVolume(tensor);

  // Copy from a named reference parameter so the vector returned by
  // to_vector<T>() stays alive for the duration of the memcpy. The previous
  // version stored `to_vector<T>().data()` in a raw pointer: that temporary
  // vector was destroyed at the end of the assignment statement, so the
  // later memcpy read freed memory (undefined behavior).
  auto copyToBytes = [numBytes](const auto &hostVec) {
    std::vector<std::byte> dataVec(numBytes);
    const size_t availableBytes = hostVec.size() * sizeof(hostVec[0]);
    // Guard against overread: for block-float formats (BFP_BFloat4/8) the
    // unpacked float buffer size need not equal elementSize * volume.
    std::memcpy(dataVec.data(), hostVec.data(),
                std::min(numBytes, availableBytes));
    return dataVec;
  };

  switch (getDtype(tensor)) {
  // Block-float formats are unpacked to 32-bit floats on host.
  case target::DataType::BFP_BFloat4:
  case target::DataType::BFP_BFloat8:
  case target::DataType::Float32:
    return copyToBytes(ttnnTensor->to_vector<float>());
  case target::DataType::BFloat16:
    return copyToBytes(ttnnTensor->to_vector<bfloat16>());
  case target::DataType::Int32:
    return copyToBytes(ttnnTensor->to_vector<std::int32_t>());
  case target::DataType::UInt32:
    return copyToBytes(ttnnTensor->to_vector<std::uint32_t>());
  case target::DataType::UInt16:
    return copyToBytes(ttnnTensor->to_vector<std::uint16_t>());
  case target::DataType::UInt8:
    return copyToBytes(ttnnTensor->to_vector<std::uint8_t>());
  default:
    LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning "
              "empty data vector");
    return {};
  }
}

// Logical shape of the underlying TTNN tensor, one uint32 per dimension.
std::vector<std::uint32_t> getShape(::tt::runtime::Tensor tensor) {
  auto *ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const auto &logicalShape = ttnnTensor->logical_shape();
  return {logicalShape.cbegin(), logicalShape.cend()};
}

// Per-dimension strides of the underlying TTNN tensor, one uint32 per
// dimension.
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor) {
  auto *ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
  const auto &strides = ttnnTensor->strides();
  return {strides.cbegin(), strides.cend()};
}

// Size in bytes of a single element of the underlying TTNN tensor.
std::uint32_t getElementSize(::tt::runtime::Tensor tensor) {
  return static_cast<::ttnn::Tensor *>(tensor.handle.get())->element_size();
}

// Total number of elements in the underlying TTNN tensor.
// NOTE(review): tensor.as<::ttnn::Tensor>() was suggested here so the device
// runtime is also validated — TODO confirm and apply consistently.
std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
  return static_cast<::ttnn::Tensor *>(tensor.handle.get())->volume();
}

// DataType of the underlying TTNN tensor, converted to the target enum.
target::DataType getDtype(::tt::runtime::Tensor tensor) {
  return utils::fromTTNNDataType(
      static_cast<::ttnn::Tensor *>(tensor.handle.get())->dtype());
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles) {
Expand Down
28 changes: 28 additions & 0 deletions runtime/test/python/ttnn/test_runtime_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,34 @@
from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_tensor_buffer_api(shape, dtype):
    """Round-trip a torch tensor through the runtime tensor-descriptor API.

    Builds a runtime tensor from a torch tensor, then checks that shape,
    element size, volume, dtype, and raw bytes survive the pybind boundary
    unchanged.
    """
    torch_tensor = torch.randn(shape, dtype=dtype)
    runtime_dtype = Binary.Program.to_data_type(dtype)
    rt_tensor = ttrt.runtime.create_tensor(
        torch_tensor.data_ptr(),
        list(torch_tensor.shape),
        list(torch_tensor.stride()),
        torch_tensor.element_size(),
        runtime_dtype,
    )

    rt_shape = rt_tensor.get_shape()
    rt_elem_size = rt_tensor.get_element_size()
    rt_vol = rt_tensor.get_volume()
    rt_dtype = ttrt_datatype_to_torch_dtype(rt_tensor.get_dtype())
    rt_bytes = rt_tensor.get_data_buffer()

    # Nothing about the underlying tensor should have changed across the
    # pybind boundary.
    assert list(rt_shape) == list(shape)
    assert rt_elem_size == torch_tensor.element_size()
    assert rt_vol == torch_tensor.numel()
    assert rt_dtype == torch_tensor.dtype
    assert len(rt_bytes) == rt_vol * rt_elem_size

    # Rebuilding a tensor from the raw bytes must reproduce the original.
    rebuilt = torch.frombuffer(rt_bytes, dtype=rt_dtype).reshape(rt_shape)
    assert torch.equal(torch_tensor, rebuilt)


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_to_layout(helper: Helper, shape, dtype, request):
Expand Down
6 changes: 5 additions & 1 deletion runtime/tools/python/ttrt/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def ttrt_datatype_to_torch_dtype(dtype) -> torch.dtype:
return torch.uint16
elif dtype == DataType.UInt8:
return torch.uint8
elif dtype == DataType.BFloat16:
return torch.bfloat16
else:
raise ValueError("Only F32 and unsigned integers are supported in the runtime")
raise ValueError(
"Only F32, BF16, and unsigned integers are supported in the runtime"
)


def get_ttrt_metal_home_path():
Expand Down
12 changes: 11 additions & 1 deletion runtime/tools/python/ttrt/runtime/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@ PYBIND11_MODULE(_C, m) {
.def("get_memory_view", &tt::runtime::detail::getMemoryView,
py::arg("device_id") = 0);
py::class_<tt::runtime::Event>(m, "Event");
py::class_<tt::runtime::Tensor>(m, "Tensor");
// Expose the tensor-descriptor accessors to Python.
py::class_<tt::runtime::Tensor>(m, "Tensor")
.def("get_shape", &tt::runtime::Tensor::getShape)
.def("get_stride", &tt::runtime::Tensor::getStride)
.def("get_volume", &tt::runtime::Tensor::getVolume)
.def("get_dtype", &tt::runtime::Tensor::getDtype)
.def("get_element_size", &tt::runtime::Tensor::getElementSize)
// Return the buffer as Python `bytes`; py::bytes copies the data, so the
// local vector may safely go out of scope.
.def("get_data_buffer", [](tt::runtime::Tensor self) {
std::vector<std::byte> vec = self.getDataBuffer();
return py::bytes(reinterpret_cast<const char *>(vec.data()),
vec.size());
});
py::class_<tt::runtime::Layout>(m, "Layout");
py::class_<tt::runtime::OpContext>(m, "OpContext");
py::class_<tt::runtime::CallbackContext>(m, "CallbackContext");
Expand Down
Loading