From 3635dc0e1c51b2c9cf81e856c3f7a1e888063236 Mon Sep 17 00:00:00 2001
From: David Ma
Date: Fri, 7 Mar 2025 23:15:51 +0000
Subject: [PATCH] #0: Add missing MeshBuffer APIs

---
 tt_metal/api/tt-metalium/buffer.hpp      |   8 +
 tt_metal/api/tt-metalium/mesh_buffer.hpp |  44 +++-
 tt_metal/distributed/mesh_buffer.cpp     | 249 +++++++++++++++++++++--
 tt_metal/impl/buffers/buffer.cpp         |  90 ++++----
 4 files changed, 330 insertions(+), 61 deletions(-)

diff --git a/tt_metal/api/tt-metalium/buffer.hpp b/tt_metal/api/tt-metalium/buffer.hpp
index 895360e4383..7aae407fd0e 100644
--- a/tt_metal/api/tt-metalium/buffer.hpp
+++ b/tt_metal/api/tt-metalium/buffer.hpp
@@ -311,6 +311,14 @@ class Buffer final {
 
 }  // namespace v0
 
+std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uint32_t, 2>>> core_to_host_pages(
+    const uint32_t total_pages,
+    const uint32_t pages_per_shard,
+    const uint32_t num_shards,
+    const TensorMemoryLayout layout,
+    const std::array<uint32_t, 2>& page_shape,
+    const std::array<uint32_t, 2>& shard_shape,
+    const std::array<uint32_t, 2>& tensor2d_size);
 BufferPageMapping generate_buffer_page_mapping(const Buffer &buffer);
 
 inline namespace v0 {
diff --git a/tt_metal/api/tt-metalium/mesh_buffer.hpp b/tt_metal/api/tt-metalium/mesh_buffer.hpp
index 2e68e1cf637..9c5f216c70c 100644
--- a/tt_metal/api/tt-metalium/mesh_buffer.hpp
+++ b/tt_metal/api/tt-metalium/mesh_buffer.hpp
@@ -88,6 +88,7 @@ class MeshBuffer {
     // Throws an exception if the corresponding MeshDevice is already deallocated
     MeshDevice* device() const;
+    Allocator* allocator() const { return allocator_; }
     DeviceAddr size() const;
     DeviceAddr device_local_size() const { return device_local_size_; }
     DeviceAddr address() const { return address_; };
@@ -102,8 +103,39 @@ class MeshBuffer {
     uint32_t datum_size_bytes() const;
     Shape2D physical_shard_shape() const;
     std::pair<bool, bool> replicated_dims() const;
-    uint32_t page_size() const { return device_local_config_.page_size; }
+    DeviceAddr page_size() const { return device_local_config_.page_size; }
+    void set_page_size(DeviceAddr page_size);
     uint32_t num_pages() const { return page_size() == 0 ? 0 : device_local_size_ / page_size(); }
+    uint32_t num_dev_pages() const;
+
+    BufferType buffer_type() const { return device_local_config_.buffer_type; }
+    CoreType core_type() const;
+
+    bool is_l1() const;
+    bool is_dram() const;
+    bool is_trace() const;
+
+    bool is_valid_region(const BufferRegion& region) const;
+    bool is_valid_partial_region(const BufferRegion& region) const;
+
+    TensorMemoryLayout buffer_layout() const { return device_local_config_.buffer_layout; }
+
+    bool bottom_up() const { return device_local_config_.bottom_up.value(); }
+
+    DeviceAddr page_address(uint32_t bank_id, uint32_t page_index) const;
+    DeviceAddr bank_local_page_address(uint32_t bank_id, uint32_t page_index) const;
+    uint32_t alignment() const;
+    DeviceAddr aligned_page_size() const;
+    DeviceAddr aligned_size() const;
+    DeviceAddr aligned_size_per_bank() const;
+
+    DeviceAddr sharded_page_address(uint32_t bank_id, uint32_t page_index) const;
+    ShardSpecBuffer shard_spec() const;
+    void set_shard_spec(const ShardSpecBuffer& shard_spec);
+    std::optional<uint32_t> num_cores() const;
+    const std::shared_ptr<const BufferPageMapping>& get_buffer_page_mapping();
+    std::optional<SubDeviceId> sub_device_id() const;
+    size_t unique_id() const { return unique_id_; }
 
 private:
     // Creates an owning `MeshBuffer`, backed by an allocation made through `backing_buffer`.
@@ -112,6 +144,7 @@ class MeshBuffer {
         const DeviceLocalBufferConfig& device_local_config,
         DeviceAddr device_local_size,
         MeshDevice* mesh_device,
+        size_t unique_id,
         std::shared_ptr<Buffer> backing_buffer) :
         buffers_(MeshShape(mesh_device->shape()), nullptr),
         config_(config),
@@ -119,6 +152,7 @@ class MeshBuffer {
         mesh_device_(mesh_device->shared_from_this()),
         address_(backing_buffer->address()),
         device_local_size_(device_local_size),
+        unique_id_(unique_id),
         state_(OwnedBufferState{std::move(backing_buffer)}) {}
 
     // Creates a non-owning `MeshBuffer` as "view" over an existing `address`.
@@ -127,16 +161,19 @@ class MeshBuffer {
         const DeviceLocalBufferConfig& device_local_config,
         DeviceAddr address,
         DeviceAddr device_local_size,
-        MeshDevice* mesh_device) :
+        MeshDevice* mesh_device,
+        size_t unique_id) :
         buffers_(MeshShape(mesh_device->shape()), /*fill_value=*/nullptr),
         config_(config),
         device_local_config_(device_local_config),
         mesh_device_(mesh_device->shared_from_this()),
         address_(address),
         device_local_size_(device_local_size),
+        unique_id_(unique_id),
         state_(ExternallyOwnedState{}) {}
 
     void initialize_device_buffers();
+    bool is_sharded() const;
     MeshBufferConfig config_;
     DeviceLocalBufferConfig device_local_config_;
     std::weak_ptr<MeshDevice> mesh_device_;
@@ -156,6 +193,9 @@ class MeshBuffer {
     struct DeallocatedState {};
     using MeshBufferState = std::variant<OwnedBufferState, ExternallyOwnedState, DeallocatedState>;
     MeshBufferState state_;
+    size_t unique_id_;
+    Allocator* allocator_;
+    std::shared_ptr<const BufferPageMapping> buffer_page_mapping_;
 };
 
 }  // namespace tt::tt_metal::distributed
diff --git a/tt_metal/distributed/mesh_buffer.cpp b/tt_metal/distributed/mesh_buffer.cpp
index 9eb540c5efd..3dc9dd8f8fc 100644
--- a/tt_metal/distributed/mesh_buffer.cpp
+++ b/tt_metal/distributed/mesh_buffer.cpp
@@ -54,6 +54,108 @@ void validate_mesh_buffer_config(const MeshBufferConfig& config, const MeshDevic
         mesh_device.num_devices());
 }
 
+size_t generate_unique_mesh_id() {
+    static std::atomic<size_t> next_id{0};
+    return next_id++;
+}
+
+// Helper function to verify all Buffers in the MeshBuffer have the same value, and return it
+template <typename F>
+decltype(auto) validate_and_get_reference_value(
+    const MeshContainer<std::shared_ptr<Buffer>>& buffers,
+    F&& func,
+    const std::source_location& loc = std::source_location::current()) {
+    // Get reference to first device's value
+    decltype(auto) reference_value = std::forward<F>(func)(buffers.begin()->value());
+
+    // Validate all other buffers match
+    for (auto& [coord, buffer] : buffers) {
+        const auto& current_value = std::forward<F>(func)(buffer);
+        if (current_value != reference_value) {
+            TT_THROW(
+                "{} [{}:{}] failed: Buffer {} returned value that differs from reference. "
" + "Expected: {}, Actual: {}", + loc.function_name(), + loc.file_name(), + loc.line(), + buffer->unique_id(), + reference_value, + current_value); + } + } + return reference_value; +} + +BufferPageMapping generate_buffer_page_mapping(const MeshBuffer& buffer) { + BufferPageMapping buffer_page_mapping; + + if (buffer.size() == 0) { + return buffer_page_mapping; + } + auto shard_spec = buffer.shard_spec(); + + bool row_major = shard_spec.orientation() == ShardOrientation::ROW_MAJOR; + uint32_t num_cores = buffer.num_cores().value(); + + buffer_page_mapping.all_cores_ = corerange_to_cores(shard_spec.grid(), num_cores, row_major); + TT_FATAL( + num_cores == buffer_page_mapping.all_cores_.size(), + "Buffer has {} cores, but page mapping expects {} cores", + num_cores, + buffer_page_mapping.all_cores_.size()); + uint32_t core_id = 0; + for (const auto& core : buffer_page_mapping.all_cores_) { + buffer_page_mapping.core_to_core_id_.insert({core, core_id}); + core_id++; + } + + uint32_t num_dev_pages = buffer.num_dev_pages(); + auto [core_host_page_indices, shard_shape] = core_to_host_pages( + num_dev_pages, + shard_spec.num_pages(), + num_cores, + buffer.buffer_layout(), + shard_spec.page_shape, + shard_spec.shape(), + shard_spec.tensor2d_shape_in_pages); + + buffer_page_mapping.core_host_page_indices_ = std::vector>(num_cores); + + buffer_page_mapping.dev_page_to_host_page_mapping_ = + std::vector>(num_dev_pages, std::nullopt); + buffer_page_mapping.dev_page_to_core_mapping_ = std::vector(num_dev_pages); + + buffer_page_mapping.host_page_to_local_shard_page_mapping_ = std::vector(buffer.num_pages()); + buffer_page_mapping.host_page_to_dev_page_mapping_ = std::vector(buffer.num_pages()); + buffer_page_mapping.core_shard_shape_ = std::move(shard_shape); + uint32_t dev_page_index = 0; + + auto shape_in_pages = shard_spec.shape_in_pages(); + for (uint32_t core_index = 0; core_index < core_host_page_indices.size(); core_index++) { + uint32_t valid_shard_page = 0; + buffer_page_mapping.core_host_page_indices_[core_index].reserve(shard_spec.num_pages()); + uint32_t shard_page_id = 0; + for (uint32_t shard_page_x = 0; shard_page_x < shape_in_pages[0]; shard_page_x++) { + for (uint32_t shard_page_y = 0; shard_page_y < shape_in_pages[1]; shard_page_y++) { + buffer_page_mapping.dev_page_to_core_mapping_[dev_page_index] = core_index; + if (shard_page_x < buffer_page_mapping.core_shard_shape_[core_index][0] and + shard_page_y < buffer_page_mapping.core_shard_shape_[core_index][1]) { + uint32_t host_page = core_host_page_indices[core_index][valid_shard_page]; + buffer_page_mapping.dev_page_to_host_page_mapping_[dev_page_index] = host_page; + buffer_page_mapping.core_host_page_indices_[core_index].push_back(host_page); + buffer_page_mapping.host_page_to_local_shard_page_mapping_[host_page] = shard_page_id; + buffer_page_mapping.host_page_to_dev_page_mapping_[host_page] = dev_page_index; + valid_shard_page++; + } + dev_page_index++; + shard_page_id++; + } + } + } + + return buffer_page_mapping; +} + } // namespace uint32_t ShardedBufferConfig::compute_datum_size_bytes() const { @@ -100,10 +202,20 @@ std::shared_ptr MeshBuffer::create( device_local_config.bottom_up); mesh_buffer = std::shared_ptr(new MeshBuffer( - mesh_buffer_config, device_local_config, device_local_size, mesh_device, std::move(backing_buffer))); + mesh_buffer_config, + device_local_config, + device_local_size, + mesh_device, + generate_unique_mesh_id(), + std::move(backing_buffer))); } else { - mesh_buffer = std::shared_ptr( - new 
+        mesh_buffer = std::shared_ptr<MeshBuffer>(new MeshBuffer(
+            mesh_buffer_config,
+            device_local_config,
+            address.value(),
+            device_local_size,
+            mesh_device,
+            generate_unique_mesh_id()));
     }
 
     mesh_buffer->initialize_device_buffers();
@@ -128,6 +240,13 @@ void MeshBuffer::initialize_device_buffers() {
     for (auto& [coord, device_buffer] : buffers_) {
         device_buffer = init_device_buffer_at_address(coord);
     }
+
+    auto mesh_device = mesh_device_.lock();
+    if (sub_device_id().has_value()) {
+        allocator_ = mesh_device->allocator(sub_device_id().value()).get();
+    } else {
+        allocator_ = mesh_device->allocator().get();
+    }
 }
 
 bool MeshBuffer::is_allocated() const { return not std::holds_alternative<DeallocatedState>(state_); }
@@ -173,32 +292,134 @@ MeshBufferLayout MeshBuffer::global_layout() const {
 }
 
 const ShardedBufferConfig& MeshBuffer::global_shard_spec() const {
-    TT_FATAL(
-        global_layout() == MeshBufferLayout::SHARDED, "Can only query the global shard spec for a sharded MeshBuffer");
+    TT_FATAL(is_sharded(), "Can only query the global shard spec for a sharded MeshBuffer");
     return std::get<ShardedBufferConfig>(config_);
 }
 
 uint32_t MeshBuffer::datum_size_bytes() const {
     // Limitation for now.
-    TT_FATAL(
-        this->global_layout() == MeshBufferLayout::SHARDED,
-        "Can only query datum size for buffers sharded across the Mesh");
+    TT_FATAL(is_sharded(), "Can only query datum size for buffers sharded across the Mesh");
     return this->global_shard_spec().compute_datum_size_bytes();
 }
 
 Shape2D MeshBuffer::physical_shard_shape() const {
-    TT_FATAL(
-        this->global_layout() == MeshBufferLayout::SHARDED,
-        "Can only query physical shard shape for buffers sharded across the Mesh");
+    TT_FATAL(is_sharded(), "Can only query physical shard shape for buffers sharded across the Mesh");
     auto sharded_config = std::get<ShardedBufferConfig>(config_);
     return sharded_config.physical_shard_shape();
 }
 
 std::pair<bool, bool> MeshBuffer::replicated_dims() const {
-    TT_FATAL(
-        this->global_layout() == MeshBufferLayout::SHARDED,
-        "Can only query replicated dims for buffers sharded across the Mesh");
+    TT_FATAL(is_sharded(), "Can only query replicated dims for buffers sharded across the Mesh");
     return this->global_shard_spec().replicated_dims();
 }
 
+void MeshBuffer::set_page_size(DeviceAddr page_size) {
+    TT_FATAL(page_size == 0 ? size() == 0 : size() % page_size == 0, "buffer size must be divisible by new page size");
+    device_local_config_.page_size = page_size;
+    for (auto& [coord, device_buffer] : buffers_) {
+        device_buffer->set_page_size(page_size);
+    }
+}
+
+uint32_t MeshBuffer::num_dev_pages() const {
+    if (!is_sharded()) {
+        return this->num_pages();
+    }
+
+    return this->shard_spec().num_pages() * this->num_cores().value();
+}
+
+CoreType MeshBuffer::core_type() const {
+    switch (buffer_type()) {
+        case BufferType::DRAM: return CoreType::DRAM;
+        case BufferType::L1:
+        case BufferType::L1_SMALL: return CoreType::WORKER;
+        default: TT_THROW("Unknown CoreType {} for buffer", buffer_type());
+    }
+}
+
+bool MeshBuffer::is_l1() const { return buffer_type() == BufferType::L1 or buffer_type() == BufferType::L1_SMALL; }
+bool MeshBuffer::is_dram() const { return buffer_type() == BufferType::DRAM || buffer_type() == BufferType::TRACE; }
+bool MeshBuffer::is_trace() const { return buffer_type() == BufferType::TRACE; }
+
+bool MeshBuffer::is_valid_region(const BufferRegion& region) const {
+    return region.offset + region.size <= this->size();
+}
+
+bool MeshBuffer::is_valid_partial_region(const BufferRegion& region) const {
+    return this->is_valid_region(region) && (region.offset > 0 || region.size != this->size());
+}
+
+DeviceAddr MeshBuffer::page_address(uint32_t bank_id, uint32_t page_index) const {
+    return validate_and_get_reference_value(this->buffers_, [bank_id, page_index](const auto& buffer) {
+        return buffer->page_address(bank_id, page_index);
+    });
+}
+
+DeviceAddr MeshBuffer::bank_local_page_address(uint32_t bank_id, uint32_t page_index) const {
+    return validate_and_get_reference_value(this->buffers_, [bank_id, page_index](const auto& buffer) {
+        return buffer->bank_local_page_address(bank_id, page_index);
+    });
+}
+
+uint32_t MeshBuffer::alignment() const {
+    return validate_and_get_reference_value(this->buffers_, [](const auto& buffer) { return buffer->alignment(); });
+}
+
+DeviceAddr MeshBuffer::aligned_page_size() const {
+    return validate_and_get_reference_value(
+        this->buffers_, [](const auto& buffer) { return buffer->aligned_page_size(); });
+}
+
+DeviceAddr MeshBuffer::aligned_size() const {
+    return validate_and_get_reference_value(this->buffers_, [](const auto& buffer) { return buffer->aligned_size(); });
+}
+
+DeviceAddr MeshBuffer::aligned_size_per_bank() const {
+    return validate_and_get_reference_value(
+        this->buffers_, [](const auto& buffer) { return buffer->aligned_size_per_bank(); });
+}
+
+DeviceAddr MeshBuffer::sharded_page_address(uint32_t bank_id, uint32_t page_index) const {
+    TT_FATAL(is_sharded(), "Can only query shard spec for buffers sharded across the Mesh");
+    return validate_and_get_reference_value(this->buffers_, [bank_id, page_index](const auto& buffer) {
+        return buffer->sharded_page_address(bank_id, page_index);
+    });
+}
+
+ShardSpecBuffer MeshBuffer::shard_spec() const {
+    TT_FATAL(is_sharded(), "Can only query shard spec for buffers sharded across the Mesh");
+    return device_local_config_.shard_parameters.value();
+}
+
+void MeshBuffer::set_shard_spec(const ShardSpecBuffer& shard_spec) {
+    TT_FATAL(is_sharded(), "Can only set shard spec for buffers sharded across the Mesh");
+    device_local_config_.shard_parameters = shard_spec;
+    for (auto& [coord, device_buffer] : buffers_) {
+        device_buffer->set_shard_spec(shard_spec);
+    }
+}
+
+std::optional<uint32_t> MeshBuffer::num_cores() const {
+    if (!is_sharded()) {
+        return std::nullopt;
+    }
+
+    return this->shard_spec().tensor_shard_spec.grid.num_cores();
+}
+
+bool MeshBuffer::is_sharded() const { return this->global_layout() == MeshBufferLayout::SHARDED; }
+
+const std::shared_ptr<const BufferPageMapping>& MeshBuffer::get_buffer_page_mapping() {
+    TT_FATAL(is_sharded(), "Can only get page mapping for buffers sharded across the Mesh");
+    if (!this->buffer_page_mapping_) {
+        this->buffer_page_mapping_ = std::make_shared<BufferPageMapping>(generate_buffer_page_mapping(*this));
+    }
+    return this->buffer_page_mapping_;
+}
+
+std::optional<SubDeviceId> MeshBuffer::sub_device_id() const {
+    return validate_and_get_reference_value(this->buffers_, [](const auto& buffer) { return buffer->sub_device_id(); });
+}
+
 }  // namespace tt::tt_metal::distributed
diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp
index 411e205042c..7324f312725 100644
--- a/tt_metal/impl/buffers/buffer.cpp
+++ b/tt_metal/impl/buffers/buffer.cpp
@@ -76,6 +76,51 @@ void validate_buffer_size_and_page_size(
     }
 }
 
+void validate_sub_device_id(
+    std::optional<SubDeviceId> sub_device_id,
+    IDevice* device,
+    BufferType buffer_type,
+    const std::optional<ShardSpecBuffer>& shard_parameters) {
+    // No need to validate if we're using the global allocator or not sharding
+    if (!sub_device_id.has_value()) {
+        return;
+    }
+    TT_FATAL(shard_parameters.has_value(), "Specifying sub-device for buffer requires buffer to be sharded");
+    TT_FATAL(is_l1_impl(buffer_type), "Specifying sub-device for buffer requires buffer to be L1");
+    const auto& sub_device_cores = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id.value());
+    const auto& shard_cores = shard_parameters->grid();
+    TT_FATAL(
+        sub_device_cores.contains(shard_cores),
+        "Shard cores specified {} do not match sub-device cores {}",
+        shard_cores,
+        sub_device_cores);
+}
+
+void validate_sub_device_manager_id(std::optional<SubDeviceManagerId> sub_device_manager_id, IDevice* device) {
+    if (sub_device_manager_id.has_value()) {
+        TT_FATAL(
+            sub_device_manager_id.value() == device->get_active_sub_device_manager_id(),
+            "Sub-device manager id mismatch. Buffer sub-device manager id: {}, Device active sub-device manager id: {}",
+            sub_device_manager_id.value(),
+            device->get_active_sub_device_manager_id());
+    }
+}
+
+}  // namespace
+
+std::atomic<size_t> Buffer::next_unique_id = 0;
+
+std::ostream& operator<<(std::ostream& os, const ShardSpec& spec) {
+    tt::stl::reflection::operator<<(os, spec);
+    return os;
+}
+
+bool is_sharded(const TensorMemoryLayout& layout) {
+    return (
+        layout == TensorMemoryLayout::HEIGHT_SHARDED || layout == TensorMemoryLayout::WIDTH_SHARDED ||
+        layout == TensorMemoryLayout::BLOCK_SHARDED);
+}
+
 std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uint32_t, 2>>> core_to_host_pages(
     const uint32_t total_pages,
     const uint32_t pages_per_shard,
     const uint32_t num_shards,
@@ -142,51 +187,6 @@ std::tuple<std::vector<std::vector<uint32_t>>, std::vector<std::array<uint32_t, 2>>> core_to_host_pages(
-void validate_sub_device_id(
-    std::optional<SubDeviceId> sub_device_id,
-    IDevice* device,
-    BufferType buffer_type,
-    const std::optional<ShardSpecBuffer>& shard_parameters) {
-    // No need to validate if we're using the global allocator or not sharding
-    if (!sub_device_id.has_value()) {
-        return;
-    }
-    TT_FATAL(shard_parameters.has_value(), "Specifying sub-device for buffer requires buffer to be sharded");
-    TT_FATAL(is_l1_impl(buffer_type), "Specifying sub-device for buffer requires buffer to be L1");
-    const auto& sub_device_cores = device->worker_cores(HalProgrammableCoreType::TENSIX, sub_device_id.value());
-    const auto& shard_cores = shard_parameters->grid();
-    TT_FATAL(
-        sub_device_cores.contains(shard_cores),
-        "Shard cores specified {} do not match sub-device cores {}",
-        shard_cores,
-        sub_device_cores);
-}
-
-void validate_sub_device_manager_id(std::optional<SubDeviceManagerId> sub_device_manager_id, IDevice* device) {
-    if (sub_device_manager_id.has_value()) {
-        TT_FATAL(
-            sub_device_manager_id.value() == device->get_active_sub_device_manager_id(),
-            "Sub-device manager id mismatch. Buffer sub-device manager id: {}, Device active sub-device manager id: {}",
-            sub_device_manager_id.value(),
-            device->get_active_sub_device_manager_id());
-    }
-}
-
-}  // namespace
-
-std::atomic<size_t> Buffer::next_unique_id = 0;
-
-std::ostream& operator<<(std::ostream& os, const ShardSpec& spec) {
-    tt::stl::reflection::operator<<(os, spec);
-    return os;
-}
-
-bool is_sharded(const TensorMemoryLayout& layout) {
-    return (
-        layout == TensorMemoryLayout::HEIGHT_SHARDED || layout == TensorMemoryLayout::WIDTH_SHARDED ||
-        layout == TensorMemoryLayout::BLOCK_SHARDED);
-}
-
 BufferPageMapping generate_buffer_page_mapping(const Buffer& buffer) {
     BufferPageMapping buffer_page_mapping;