Implement runtime tensor desc API #2370
base: main
Conversation
This change implements an API on runtime tensors that exposes the metadata and contents of the underlying TTNN tensor. The functionality is also pybound, allowing easy casting to torch tensors. See `runtime/test/python/ttnn/test_runtime_api.py` for an example. NOTE: this is not the most efficient implementation possible, as there is effectively a double copy to get the data into a pybindable row-major format. There is probably some trickery that can be done later to avoid this, but I would like to get this functionality out ASAP and avoid premature optimization. Closes #1957
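A minimal Python sketch of the double copy the note describes (names are hypothetical; the real implementation lives in the C++ runtime): the tensor contents are first materialized as a host-side typed vector, then copied again into a contiguous row-major byte buffer that pybind can expose.

```python
import struct

def get_data_buffer(device_values):
    """Hypothetical sketch of the double-copy path to a pybindable buffer."""
    # Copy 1: materialize the tensor contents as a host-side typed vector
    # (stand-in for something like ttnn::Tensor::to_vector<float>()).
    host_vector = [float(v) for v in device_values]
    # Copy 2: pack the vector into a contiguous row-major byte buffer,
    # which is what gets handed across the pybind boundary.
    return struct.pack(f"{len(host_vector)}f", *host_vector)
```

On the Python side, a buffer like this can then be reinterpreted as a torch tensor without further copies, which is why the row-major byte layout matters.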
Clang-Tidy found issue(s) with the introduced code (1/1)
Fantastic changes! Thanks for the continued push on this. A few comments inline.
Will this break the golden callback functionality? Or do we already handle the types dynamically.
runtime/lib/ttnn/runtime.cpp (Outdated)
@@ -567,6 +567,83 @@ std::vector<float> getTensorData(Tensor tensor) {
                     static_cast<float *>(dataPtr) + nnTensor->volume());
}

std::vector<std::byte> getDataBuffer(::tt::runtime::Tensor tensor) {
this is extremely clean. I love it.
@@ -10,6 +10,34 @@
from utils import TT_MLIR_HOME, Helper, DeviceContext, assert_pcc


@pytest.mark.parametrize("shape", [(64, 128)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_tensor_buffer_api(shape, dtype):
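The test pattern above can be sketched without hardware or torch: parametrize over shape, build predictable row-major data, and check it survives a round trip through a byte buffer. This is a plain-Python stand-in for the real test (only float32 is modeled here).

```python
import struct

def buffer_roundtrip(shape, fmt="f"):
    """Round-trip row-major data through a byte buffer, as the test does."""
    volume = 1
    for dim in shape:
        volume *= dim
    values = [float(i) for i in range(volume)]
    # Serialize to a contiguous buffer and read it back, mimicking
    # get_data_buffer() followed by torch.frombuffer(...).reshape(shape).
    buf = struct.pack(f"{volume}{fmt}", *values)
    return list(struct.unpack(f"{volume}{fmt}", buf)) == values
```

Parametrizing over several shapes (as the pytest version does with `(64, 128)`) then just calls this helper per shape.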
great way to test!
Thank you for adding tests!
std::vector<std::uint32_t> getStride(::tt::runtime::Tensor tensor);
std::uint32_t getElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getVolume(::tt::runtime::Tensor tensor);
target::DataType getDtype(::tt::runtime::Tensor tensor);
We have a TensorDesc type under runtime/include/tt/runtime/types.h:

struct TensorDesc {
  std::vector<std::uint32_t> shape;
  std::vector<std::uint32_t> stride;
  std::uint32_t itemsize;
  ::tt::target::DataType dataType;
};
Was wondering if we could merge these APIs into one and return the TensorDesc directly. It shouldn't be hard to bind this structure, and it'll have all the information in one place.
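A sketch of what the merged API could look like once TensorDesc is bound to Python (field names taken from the struct above; the row-major stride computation is my own illustration, not code from the PR):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class TensorDesc:
    # Mirrors the C++ struct in runtime/include/tt/runtime/types.h.
    shape: List[int]
    stride: List[int]
    itemsize: int
    data_type: str

def row_major_strides(shape):
    # Contiguous row-major strides: innermost dimension varies fastest.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

desc = TensorDesc(shape=[64, 128],
                  stride=row_major_strides([64, 128]),
                  itemsize=4,
                  data_type="Float32")
```

With a single descriptor like this, the separate getShape/getStride/getElementSize/getDtype calls collapse into one bound accessor, which is the consolidation being suggested.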
Also, on a side note, I think getDtype and getTensorDataType do the same thing here, so one is probably redundant.
Good point, yeah I can bind and have it return that object if you want!
Yeah that would be awesome!
There is now getTensorDesc() & get_tensor_desc() for this exact purpose!
runtime/lib/ttnn/runtime.cpp (Outdated)
}

std::uint32_t getVolume(::tt::runtime::Tensor tensor) {
  auto ttnnTensor = static_cast<::ttnn::Tensor *>(tensor.handle.get());
We should be using tensor.as here since it'll also check that the device runtime matches.
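A rough model of why a checked accessor beats a bare static_cast: the handle carries which runtime produced it, and the cast refuses a mismatch instead of silently reinterpreting memory. The names and behavior below are my assumption for illustration, not the actual tt-runtime API.

```python
class RuntimeTensor:
    """Hypothetical tensor handle tagged with its owning runtime."""
    def __init__(self, runtime, handle):
        self.runtime = runtime
        self.handle = handle

    def as_(self, expected_runtime):
        # Checked cast: fail loudly rather than hand back a handle
        # that belongs to a different device runtime.
        if self.runtime != expected_runtime:
            raise TypeError(
                f"tensor belongs to {self.runtime}, not {expected_runtime}")
        return self.handle
```

A raw static_cast skips the tag check entirely, which is exactly the hazard the reviewer is flagging.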
Done
runtime/lib/ttnn/runtime.cpp (Outdated)
case target::DataType::Float32:
  dataPtr = ttnnTensor->to_vector<float>().data();
  assert(dataPtr != nullptr);
  std::memcpy(dataVec.data(), dataPtr, numBytes);
I think numBytes will contain the wrong value here and won't be compatible with BFloat4 and BFloat8.
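A toy illustration of the concern, under a rough model of a block-float format (one mantissa byte per value plus one shared exponent byte per 16-value block; the exact TTNN packing may differ): sizing the copy from the on-device dtype does not match the float32 buffer that a conversion like to_vector<float>() produces.

```python
def packed_block_float_bytes(volume, block=16):
    # Rough model of a block-float layout: 1 mantissa byte per element
    # plus 1 shared exponent byte per block of 16 elements.
    # (Assumption for illustration only, not the actual TTNN packing.)
    return volume + volume // block

volume = 64 * 128
converted_bytes = volume * 4                    # float32 after conversion
device_bytes = packed_block_float_bytes(volume) # packed on-device size
# The two sizes disagree, so numBytes must be derived from the
# converted element type, not the packed on-device representation.
```

In other words, for BFloat4/BFloat8 the memcpy length has to track the destination (converted) buffer, not the source tensor's raw byte count.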
oh shoot, you're right! good catch. let me fix that
Fixed!
Looks clean and the added tests provide a good reference for my implementation of golden checking with this API in tt-torch.
Having an API to return a TensorDesc would be great, so I second Jackson's request, but otherwise no complaints. Thanks!
Now all golden checks utilize the runtime tensor buffer API
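A golden check over this API reduces to comparing the runtime buffer against the golden tensor by Pearson correlation. A self-contained sketch of an assert_pcc-style helper (the threshold and exact signature of the real helper in utils are assumptions):

```python
import math

def compute_pcc(golden, observed):
    # Pearson correlation coefficient between two equal-length vectors.
    n = len(golden)
    mean_g = sum(golden) / n
    mean_o = sum(observed) / n
    cov = sum((g - mean_g) * (o - mean_o)
              for g, o in zip(golden, observed))
    var_g = sum((g - mean_g) ** 2 for g in golden)
    var_o = sum((o - mean_o) ** 2 for o in observed)
    return cov / math.sqrt(var_g * var_o)

def assert_pcc(golden, observed, threshold=0.99):
    pcc = compute_pcc(golden, observed)
    assert pcc >= threshold, f"PCC {pcc} below threshold {threshold}"
```

With the buffer API in place, the golden side and the device side both arrive as flat row-major data, so this comparison needs no layout-specific handling.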
runtime/include/tt/runtime/types.h (Outdated)
std::vector<std::uint32_t> getShape();
std::vector<std::uint32_t> getStride();
target::DataType getDtype();
TensorDesc getTensorDesc();
Let's not add these as member functions, for consistency reasons. I think for now we want to put all user-facing functions in runtime.h. This is open for discussion though; if we prefer member functions over standalone functions, then this should be a larger-scale change that covers all existing functions. Also, since we have getTensorDesc we can probably remove all the other APIs, as getTensorDesc will return all the information anyway.
Thanks for the change Collin!
Big thanks to @tapspatel for helping me figure out the namespace dispatch intricacies here!
Ticket
Closes #1957
Checklist