
Implement runtime tensor desc API #2370

Open · wants to merge 8 commits into main from ctod/issue-1957
Conversation

ctodTT
Contributor

@ctodTT ctodTT commented Mar 5, 2025

This change implements an API on runtime tensors to expose the metadata and contents of the underlying TTNN tensor. The functionality is also pybound, allowing easy casting to torch tensors. Please see runtime/test/python/ttnn/test_runtime_api.py for an example.

NOTE: This is not the most efficient implementation possible, as there is effectively a double copy to get the data into a pybindable row-major format. There is probably some trickery that can be done later to avoid this, but I would like to get this functionality out ASAP and avoid premature optimization.
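The double copy described above can be illustrated in plain Python (a sketch only; the real path is C++ inside the TTNN runtime, and `tensor_to_byte_buffer` is a hypothetical name):

```python
import struct

def tensor_to_byte_buffer(values, fmt="f"):
    """Sketch of the two-copy path: first materialize the tensor's
    elements as a row-major vector (copy 1), then pack that vector
    into a flat byte buffer for the Python binding (copy 2)."""
    row_major = list(values)  # copy 1: device/tile layout -> row-major vector
    buffer = struct.pack(f"{len(row_major)}{fmt}", *row_major)  # copy 2: vector -> bytes
    return buffer

buf = tensor_to_byte_buffer([1.0, 2.0, 3.0, 4.0])
assert len(buf) == 4 * 4  # four float32 elements, 4 bytes each
```

On the Python side, a flat byte buffer like this is what makes zero-friction casting to torch/NumPy tensors possible.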

Big thanks to @tapspatel for helping me figure out the namespace dispatch intricacies here!

Ticket

Closes #1957

Checklist

  • New/Existing tests provide coverage for changes

Contributor

@github-actions github-actions bot left a comment


⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

Collaborator

@tapspatel tapspatel left a comment


Fantastic changes! Thanks for the continued push on this. A few comments inline.

Will this break the golden callback functionality? Or do we already handle the types dynamically?

@@ -567,6 +567,83 @@ std::vector<float> getTensorData(Tensor tensor) {
static_cast<float *>(dataPtr) + nnTensor->volume());
}

std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
Collaborator


this is extremely clean. I love it.

@@ -10,6 +10,34 @@
from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_tensor_buffer_api(shape, dtype):
Collaborator


great way to test!

Contributor


Thank you for adding tests!

std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
target::DataType getDtype(::tt::runtime::Tensor tensor);
Contributor

@jnie-TT jnie-TT Mar 5, 2025


We have a TensorDesc type under runtime/include/tt/runtime/types.h:

struct TensorDesc {
  std::vector<std::uint32_t> shape;
  std::vector<std::uint32_t> stride;
  std::uint32_t itemsize;
  ::tt::target::DataType dataType;
};

Was wondering if we could merge these APIs into one and return the TensorDesc directly. It shouldn't be hard to bind this structure, and it'll have all the information in one place.

Also, on a side note, I think getDtype and getTensorDataType do the same thing here, so one is probably redundant.
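For reference, the TensorDesc quoted above can be mirrored in Python (a sketch, not the actual binding; field names follow the C++ struct in runtime/include/tt/runtime/types.h, and the stride derivation assumes a row-major layout):

```python
from dataclasses import dataclass, field
from math import prod

@dataclass
class TensorDesc:
    """Python mirror of the C++ TensorDesc: shape, row-major stride
    (in elements), element size in bytes, and a dtype tag."""
    shape: list
    itemsize: int
    dtype: str
    stride: list = field(default=None)

    def __post_init__(self):
        if self.stride is None:
            # Row-major stride: each dim strides over the product of the dims after it.
            self.stride = [prod(self.shape[i + 1:]) for i in range(len(self.shape))]

desc = TensorDesc(shape=[64, 128], itemsize=4, dtype="Float32")
assert desc.stride == [128, 1]
```

Bundling all four fields into one bound object is what makes the "merge these APIs into one" suggestion attractive.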

Contributor Author


Good point, yeah I can bind and have it return that object if you want!

Contributor


Yeah that would be awesome!

Contributor Author

@ctodTT ctodTT Mar 6, 2025


There is now getTensorDesc() & get_tensor_desc() for this exact purpose!

}

std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
Contributor

@jnie-TT jnie-TT Mar 5, 2025


We should be using tensor.as here, since it'll also check that the device runtime matches.

Contributor Author


Done
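In Python terms, the volume computed by the snippet above is just the product of the shape's dimensions (with the empty product, 1, for a scalar):

```python
from math import prod

def get_volume(shape):
    """Total element count of a tensor with the given shape."""
    return prod(shape)

assert get_volume([64, 128]) == 64 * 128  # 8192
assert get_volume([]) == 1  # scalar: empty product
```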

case target::DataType::Float32:
dataPtr = ttnnTensor->to_vector<float>().data();
assert(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, numBytes);
Contributor


I think numBytes will contain the wrong value here and won't be compatible with BFloat4 and BFloat8.

Contributor Author


oh shoot, you're right! good catch. let me fix that

Contributor Author


Fixed!
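The pitfall flagged above is that computing numBytes as volume times a whole-byte element size breaks once elements are narrower than a byte. A minimal sketch of the arithmetic for densely packed sub-byte elements (block-float formats such as TTNN's BFloat4/BFloat8 additionally store shared exponents per block, which this deliberately does not model):

```python
def packed_num_bytes(volume, bits_per_element):
    """Bytes needed to densely pack `volume` elements of
    `bits_per_element` bits each, rounding up to a whole byte.
    Ignores shared-exponent overhead of block-float formats."""
    return (volume * bits_per_element + 7) // 8

assert packed_num_bytes(64 * 128, 32) == 32768  # float32: 4 bytes per element
assert packed_num_bytes(64 * 128, 4) == 4096    # 4-bit elements: half a byte each
```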


@jameszianxuTT jameszianxuTT left a comment


Looks clean and the added tests provide a good reference for my implementation of golden checking with this API in tt-torch.

Having an API to return a TensorDesc would be great, so I second Jackson's request, but otherwise no complaints. Thanks!

@ctodTT ctodTT force-pushed the ctod/issue-1957 branch from 5ecf5cb to 490e719 Compare March 6, 2025 19:10
std::vector<std::uint32_t> getShape();
std::vector<std::uint32_t> getStride();
target::DataType getDtype();
TensorDesc getTensorDesc();
Contributor

@jnie-TT jnie-TT Mar 7, 2025


Let's not add these as member functions, for consistency reasons. I think for now we want to put all user-facing functions in runtime.h. This is open for discussion, though; if we prefer member functions over standalone functions, then this should be a larger-scale change that covers all existing functions.

Also since we have getTensorDesc we can probably remove all other APIs since getTensorDesc will return all the information anyway.

Contributor

@jnie-TT jnie-TT left a comment


Thanks for the change Collin!


Successfully merging this pull request may close these issues.

Provide API to get runtime Tensor buffer, dtype, and shape.
5 participants