Implement runtime tensor desc API #2370

Open · wants to merge 8 commits into base: main
8 changes: 6 additions & 2 deletions runtime/include/tt/runtime/detail/ttmetal.h
@@ -30,6 +30,12 @@ inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
}

tt::target::DataType getTensorDataType(Tensor tensor);
std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor);
std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor);
TensorDesc getTensorDesc(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

@@ -65,8 +71,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

using InputBuffer =
std::tuple<std::uint32_t, std::shared_ptr<::tt::tt_metal::Buffer>,
std::shared_ptr<::tt::tt_metal::Event>>;
9 changes: 7 additions & 2 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -94,6 +94,13 @@ inline Tensor createTensor(Device device, Layout layout,

tt::target::DataType getTensorDataType(Tensor tensor);

std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor);
std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor);
std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor);
std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor);
TensorDesc getTensorDesc(::tt::runtime::Tensor tensor);

size_t getNumAvailableDevices();

Device
@@ -137,8 +144,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);
8 changes: 6 additions & 2 deletions runtime/include/tt/runtime/runtime.h
@@ -89,6 +89,12 @@ inline Tensor createTensor(Device device, Layout layout,
}

tt::target::DataType getTensorDataType(Tensor tensor);
std::vector<std::byte> getTensorDataBuffer(Tensor tensor);
std::uint32_t getTensorElementSize(Tensor tensor);
std::uint32_t getTensorVolume(Tensor tensor);
std::vector<std::uint32_t> getTensorShape(Tensor tensor);
std::vector<std::uint32_t> getTensorStride(Tensor tensor);
TensorDesc getTensorDesc(Tensor tensor);

size_t getNumAvailableDevices();

@@ -126,8 +132,6 @@ std::string getOpLocInfo(OpContext opContextHandle);
Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);
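Taken together, these declarations form the new public surface in runtime.h. Below is a minimal caller-side sketch — hypothetical, not part of this diff — of how the desc API could be used, assuming `tensor` is a valid ::tt::runtime::Tensor produced by the existing createTensor path:

// Hypothetical usage sketch, not part of this diff.
#include "tt/runtime/runtime.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

void inspectTensor(::tt::runtime::Tensor tensor) {
  // One call gathers dtype, element size, shape, and stride together.
  ::tt::runtime::TensorDesc desc = ::tt::runtime::getTensorDesc(tensor);
  std::cout << "rank=" << desc.shape.size() << " itemsize=" << desc.itemsize
            << " bytes\n";

  // The same fields are also available piecewise, along with the raw bytes.
  std::uint32_t volume = ::tt::runtime::getTensorVolume(tensor);
  std::vector<std::byte> bytes = ::tt::runtime::getTensorDataBuffer(tensor);
  std::cout << "buffer=" << bytes.size() << " bytes for " << volume
            << " elements\n";
}
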
121 changes: 105 additions & 16 deletions runtime/lib/runtime.cpp
@@ -477,22 +477,6 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
LOG_FATAL("runtime is not enabled");
}

std::vector<float> getTensorData(Tensor tensor) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorData(tensor);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorData(tensor);
}
#endif

LOG_FATAL("runtime is not enabled");
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputHandles) {
@@ -531,4 +515,109 @@ Event submit(Device deviceHandle, Binary executableHandle,
#endif
LOG_FATAL("runtime is not enabled");
}
std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorDataBuffer(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorDataBuffer(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorShape(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorShape(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorStride(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorStride(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

target::DataType getDtype(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorDataType(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorDataType(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

std::uint32_t getTensorElementSize(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorElementSize(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorElementSize(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

std::uint32_t getTensorVolume(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorVolume(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorVolume(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

TensorDesc getTensorDesc(::tt::runtime::Tensor t) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::getTensorDesc(t);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
return ::tt::runtime::ttmetal::getTensorDesc(t);
}
#endif
LOG_FATAL("runtime is not enabled");
}

} // namespace tt::runtime
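Each of the new wrappers repeats the same two-branch dispatch. If this surface keeps growing, a small helper could collapse the boilerplate; a hypothetical sketch, not part of this PR, reusing the same TT_RUNTIME_ENABLE_* guards, getCurrentRuntime, and LOG_FATAL from this file:

// Hypothetical refactor sketch, not part of this PR: per-backend dispatch
// macros that fall through to LOG_FATAL when no runtime is enabled.
#if defined(TT_RUNTIME_ENABLE_TTNN)
#define DISPATCH_TTNN(FN, ...)                                                 \
  if (getCurrentRuntime() == DeviceRuntime::TTNN) {                            \
    return ::tt::runtime::ttnn::FN(__VA_ARGS__);                               \
  }
#else
#define DISPATCH_TTNN(FN, ...)
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
#define DISPATCH_TTMETAL(FN, ...)                                              \
  if (getCurrentRuntime() == DeviceRuntime::TTMetal) {                         \
    return ::tt::runtime::ttmetal::FN(__VA_ARGS__);                            \
  }
#else
#define DISPATCH_TTMETAL(FN, ...)
#endif

// Example: the getTensorVolume wrapper above would reduce to three lines.
std::uint32_t getTensorVolume(::tt::runtime::Tensor t) {
  DISPATCH_TTNN(getTensorVolume, t);
  DISPATCH_TTMETAL(getTensorVolume, t);
  LOG_FATAL("runtime is not enabled");
}
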
30 changes: 27 additions & 3 deletions runtime/lib/ttmetal/runtime.cpp
@@ -357,9 +357,33 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
return createNullTensor();
}

std::vector<float> getTensorData(Tensor tensor) {
  // Not implemented
  LOG_WARNING("obtaining tensor data for metal runtime not implemented");
  return {};
}

std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor) {
LOG_WARNING("getDataBuffer not implemented for metal runtime");
return {};
}

std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor) {
LOG_WARNING("getShape not implemented for metal runtime");
return {};
}

std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor) {
LOG_WARNING("getStride not implemented for metal runtime");
return {};
}

std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) {
LOG_WARNING("getElementSize not implemented for metal runtime");
return 0;
}

std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) {
LOG_WARNING("getVolume not implemented for metal runtime");
return 0;
}

TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) {
LOG_WARNING("getTensorDesc not implemented for metal runtime");
return {};
}

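Since the TTMetal versions only warn and return empty or zero values, code that may run under either backend cannot distinguish "zero-element tensor" from "not implemented". A hypothetical caller-side guard until the metal paths land:

// Hypothetical guard, not part of this PR: the TTMetal stubs above return
// empty/zero values, so an empty buffer can also mean "not implemented".
#include "tt/runtime/runtime.h"

#include <cstddef>
#include <vector>

bool tryGetTensorBytes(::tt::runtime::Tensor tensor,
                       std::vector<std::byte> &out) {
  out = ::tt::runtime::getTensorDataBuffer(tensor);
  // Callers should branch on this rather than silently treat the empty
  // buffer as real tensor data.
  return !out.empty();
}
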
118 changes: 111 additions & 7 deletions runtime/lib/ttnn/runtime.cpp
@@ -555,16 +555,120 @@ Tensor getOpOutputTensor(OpContext opContextHandle,
DeviceRuntime::TTNN);
}

std::vector<float> getTensorData(Tensor tensor) {
  const ::ttnn::Tensor *nnTensor =
      static_cast<::ttnn::Tensor *>(tensor.handle.get());
  if (nnTensor == nullptr) {
    return {};
  }
  void *dataPtr = ::tt::tt_metal::get_raw_host_data_ptr(*nnTensor);
  return std::vector<float>(static_cast<float *>(dataPtr),
                            static_cast<float *>(dataPtr) + nnTensor->volume());
}

std::vector<std::byte> getTensorDataBuffer(::tt::runtime::Tensor tensor) {
const ::ttnn::Tensor &ttnnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
void *dataPtr = nullptr;
std::vector<std::byte> dataVec(getTensorElementSize(tensor) *
getTensorVolume(tensor));

// Need to `memcpy` in each case because the vector will go out of scope if we
// wait until after the switch case
switch (getTensorDataType(tensor)) {
case target::DataType::BFP_BFloat4: {
dataVec.resize(sizeof(float) * getTensorVolume(tensor));
auto vec = ttnnTensor.to_vector<float>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::BFP_BFloat8: {
dataVec.resize(sizeof(float) * getTensorVolume(tensor));
auto vec = ttnnTensor.to_vector<float>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::Float32: {
auto vec = ttnnTensor.to_vector<float>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::BFloat16: {
auto vec = ttnnTensor.to_vector<bfloat16>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::Int32: {
auto vec = ttnnTensor.to_vector<std::int32_t>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::UInt32: {
auto vec = ttnnTensor.to_vector<std::uint32_t>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::UInt16: {
auto vec = ttnnTensor.to_vector<std::uint16_t>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
case target::DataType::UInt8: {
auto vec = ttnnTensor.to_vector<std::uint8_t>();
dataPtr = vec.data();
LOG_ASSERT(dataPtr != nullptr);
std::memcpy(dataVec.data(), dataPtr, dataVec.size());
return dataVec;
}
default:
LOG_ERROR("Unsupported datatype for underlying TTNN tensor, returning "
"empty data vector");
return {};
}
}

std::vector<std::uint32_t> getTensorShape(::tt::runtime::Tensor tensor) {
const ::ttnn::Tensor &ttnnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
std::vector<std::uint32_t> shape;
for (size_t i = 0; i < ttnnTensor.logical_shape().size(); ++i) {
shape.push_back(ttnnTensor.logical_shape()[i]);
}
return shape;
}

std::vector<std::uint32_t> getTensorStride(::tt::runtime::Tensor tensor) {
const ::ttnn::Tensor &ttnnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
std::vector<std::uint32_t> stride;
for (size_t i = 0; i < ttnnTensor.strides().size(); ++i) {
stride.push_back(ttnnTensor.strides()[i]);
}
return stride;
}

std::uint32_t getTensorElementSize(::tt::runtime::Tensor tensor) {
const ::ttnn::Tensor &ttnnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
return ttnnTensor.element_size();
}

std::uint32_t getTensorVolume(::tt::runtime::Tensor tensor) {
const ::ttnn::Tensor &ttnnTensor =
tensor.as<::ttnn::Tensor>(DeviceRuntime::TTNN);
return ttnnTensor.volume();
}

TensorDesc getTensorDesc(::tt::runtime::Tensor tensor) {
TensorDesc desc;
desc.dataType = getTensorDataType(tensor);
desc.itemsize = getTensorElementSize(tensor);
desc.stride = getTensorStride(tensor);
desc.shape = getTensorShape(tensor);
return desc;
}

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
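The per-dtype branches in getTensorDataBuffer differ only in the element type handed to to_vector. A templated helper — hypothetical, not in this PR, assuming the same ::ttnn::Tensor::to_vector<T>() used in the switch above — would collapse the repetition:

// Hypothetical refactor sketch, not part of this PR: one helper per element
// type replaces the near-identical memcpy blocks in getTensorDataBuffer.
#include <cstddef>
#include <cstring>
#include <vector>

template <typename T>
static std::vector<std::byte> copyToBytes(const ::ttnn::Tensor &ttnnTensor) {
  // to_vector<T>() materializes host data, exactly as the switch cases do.
  std::vector<T> vec = ttnnTensor.to_vector<T>();
  std::vector<std::byte> bytes(vec.size() * sizeof(T));
  std::memcpy(bytes.data(), vec.data(), bytes.size());
  return bytes;
}

// Example: the Float32 and UInt16 cases would then become
//   case target::DataType::Float32: return copyToBytes<float>(ttnnTensor);
//   case target::DataType::UInt16:  return copyToBytes<std::uint16_t>(ttnnTensor);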