-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactoring api_impl.cc and api_impl.h code (#116)
The `tt-xla` repository currently contains particularly large files `api_impl.cc` and `api_impl.h` that handle multiple functionalities through different classes. This issue is for refactoring `api_impl.cc` and `api_impl.h`. Splitting this file into smaller, logically organized files, with each file containing a single class or closely related classes. Other than that, no significant changed to the logic or features of the specific functions.
- Loading branch information
1 parent
1c4b2df
commit d500c0e
Showing
27 changed files
with
1,943 additions
and
1,568 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
// This file incorporates work covered by the following copyright and permission | ||
// notice: | ||
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// https://llvm.org/LICENSE.txt | ||
|
||
#include "common/pjrt_implementation/error_instance.h" | ||
|
||
#include <memory> | ||
#include <utility> | ||
|
||
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_ | ||
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_ | ||
|
||
namespace tt::pjrt { | ||
|
||
// Top-level API bindings. | ||
void BindMonomorphicApi(PJRT_Api *api); | ||
|
||
void BindUndefineds(PJRT_Api *api); | ||
|
||
// Initializes and returns PJRT plugin attributes. | ||
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args); | ||
|
||
template <typename PlatformTy, typename ClientInstanceTy> | ||
void BindApi(PJRT_Api *api) { | ||
BindMonomorphicApi(api); | ||
|
||
// Bind polymorphic entry-points. | ||
api->PJRT_Client_Create = +[](PJRT_Client_Create_Args *args) -> PJRT_Error * { | ||
DLOG_F(LOG_DEBUG, "PJRT_Client_Create"); | ||
auto platform = std::make_unique<PlatformTy>(); | ||
|
||
// Populate config_vars() from the client create_options. | ||
for (size_t i = 0; i < args->num_options; ++i) { | ||
DLOG_F(WARNING, "Unused config var: %s", args->create_options[i].name); | ||
} | ||
|
||
auto status = platform->Initialize(); | ||
if (!tt_pjrt_status_is_ok(status)) { | ||
return ErrorInstance::MakeError(status); | ||
} | ||
|
||
auto client = std::make_unique<ClientInstanceTy>(std::move(platform)); | ||
auto *error = client->Initialize(); | ||
if (error) | ||
return error; | ||
|
||
// Successful return. | ||
args->client = reinterpret_cast<PJRT_Client *>(client.release()); | ||
return nullptr; | ||
}; | ||
} | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
// This file incorporates work covered by the following copyright and permission | ||
// notice: | ||
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// https://llvm.org/LICENSE.txt | ||
|
||
#include "tt/runtime/runtime.h" | ||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
#include "common/pjrt_implementation/event_instance.h" | ||
#include "common/status.h" | ||
|
||
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_ | ||
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_ | ||
|
||
namespace tt::pjrt { | ||
|
||
class DeviceInstance; | ||
|
||
class BufferInstance { | ||
public: | ||
BufferInstance(DeviceInstance &device, tt::runtime::Tensor tensor, | ||
std::vector<std::uint32_t> shape, | ||
std::vector<std::uint32_t> stride); | ||
BufferInstance(DeviceInstance &device); | ||
~BufferInstance(); | ||
operator PJRT_Buffer *() { return reinterpret_cast<PJRT_Buffer *>(this); } | ||
static BufferInstance *Unwrap(PJRT_Buffer *buffer) { | ||
return reinterpret_cast<BufferInstance *>(buffer); | ||
} | ||
static void BindApi(PJRT_Api *api); | ||
|
||
// iree_hal_buffer_view_t* buffer_view() { return buffer_view_.get(); } | ||
DeviceInstance &device() { return device_; } | ||
tt_pjrt_status AsyncDeallocate(); | ||
tt_pjrt_status Delete(); | ||
bool is_deleted() { return is_deleted_; } | ||
bool is_on_cpu() { | ||
// TODO: Plumb through an indication if running on CPU and then implement | ||
// the hook to get an unsafe pointer (avoids a copy). | ||
return false; | ||
} | ||
tt::runtime::Tensor tensor() { return tensor_.value(); } | ||
|
||
PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args); | ||
// Gets the required host size in bytes to copy to host. | ||
tt_pjrt_status GetHostSizeInBytes(size_t *host_size); | ||
tt_pjrt_status CopyToHost(void *dst, size_t dst_size, | ||
EventInstance **done_event); | ||
|
||
const int64_t *dims() { return dims_.data(); } | ||
size_t num_dims() { return dims_.size(); } | ||
void setType(PJRT_Buffer_Type Type) { DataType = Type; } | ||
std::optional<PJRT_Buffer_Type> getType() { return DataType; } | ||
|
||
// Get the data type for a tensor through runtime if DataType is not set. | ||
PJRT_Buffer_Type getRuntimeType(); | ||
|
||
int unique_id() { return unique_id_; } | ||
|
||
private: | ||
static int id_counter_; | ||
int unique_id_; | ||
void ComputeLayout(); | ||
|
||
DeviceInstance &device_; | ||
// When the buffer resource gets freed, this is set to true. | ||
bool is_deleted_ = false; | ||
|
||
// API elements that must have the same lifetime as BufferInstance. | ||
std::vector<int64_t> dims_; | ||
std::vector<std::uint32_t> stride_; | ||
std::optional<tt::runtime::Tensor> tensor_; | ||
|
||
std::vector<int64_t> minor_to_major_; | ||
std::vector<int64_t> tile_dims_; | ||
std::vector<size_t> tile_dim_sizes_; | ||
|
||
// Underlying datatype of tensor. | ||
std::optional<PJRT_Buffer_Type> DataType; | ||
}; | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
// This file incorporates work covered by the following copyright and permission | ||
// notice: | ||
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// https://llvm.org/LICENSE.txt | ||
|
||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
#include <memory> | ||
#include <vector> | ||
|
||
#include "common/module_builder.h" | ||
#include "common/pjrt_implementation/device_instance.h" | ||
#include "common/pjrt_implementation/loaded_executable_instance.h" | ||
#include "common/platform.h" | ||
#include "common/status.h" | ||
|
||
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_ | ||
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_ | ||
|
||
namespace tt::pjrt { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ClientInstance | ||
// The root of the runtime hierarchy, these map to an IREE driver and are | ||
// created against an API. | ||
//===----------------------------------------------------------------------===// | ||
class ClientInstance { | ||
|
||
public: | ||
ClientInstance(std::unique_ptr<Platform> platform); | ||
virtual ~ClientInstance(); | ||
|
||
// Binds monomorphic entry-points for the client. | ||
static void BindApi(PJRT_Api *api); | ||
|
||
static ClientInstance *Unwrap(PJRT_Client *client) { | ||
return reinterpret_cast<ClientInstance *>(client); | ||
} | ||
|
||
// Before the client is usable, it must be initialized. | ||
PJRT_Error *Initialize(); | ||
|
||
Platform &platform() { return *platform_; } | ||
const std::vector<DeviceInstance *> &devices() { return devices_; } | ||
const std::vector<DeviceInstance *> &addressable_devices() { | ||
return addressable_devices_; | ||
} | ||
const std::string &cached_platform_name() { return cached_platform_name_; } | ||
const std::string &cached_platform_version() { | ||
return cached_platform_version_; | ||
} | ||
|
||
// Checks if the output on the i-th index is a scalar. | ||
bool isOutputScalar(const size_t index) const { | ||
return module_builder_->isOutputScalar(index); | ||
} | ||
|
||
// Compiles. | ||
// See TODOs in PJRT_Client_Compile. | ||
PJRT_Error * | ||
Compile(const PJRT_Program *program, /*xla::CompileOptions options, */ | ||
LoadedExecutableInstance **executable); | ||
|
||
// Advances the timeline, returning (current, next) time point values. | ||
std::tuple<uint64_t, uint64_t> AdvanceTimeline(); | ||
|
||
protected: | ||
std::string cached_platform_name_; | ||
std::string cached_platform_version_; | ||
|
||
private: | ||
tt_pjrt_status InitializeCompiler(); | ||
tt_pjrt_status PopulateDevices(); | ||
|
||
std::unique_ptr<Platform> platform_; | ||
|
||
std::vector<DeviceInstance *> devices_; | ||
std::vector<DeviceInstance *> addressable_devices_; | ||
|
||
std::unique_ptr<ModuleBuilder> module_builder_; | ||
|
||
// Synchronization. | ||
// We keep one global execution timeline across all devices. The management | ||
// of this is currently somewhat primitive: we increment it by one for each | ||
// invocation. Batch invocations (i.e. across multiple devices), only | ||
// increment by one. In the future, additional parallelism could be plumbed | ||
// up to the framework to allow different kinds of timeline management. | ||
// Waiting on the current value of |execution_timeline_| will drain all | ||
// scheduled work to date. | ||
uint64_t execution_timeline_ = 0ull; | ||
}; | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
// This file incorporates work covered by the following copyright and permission | ||
// notice: | ||
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// https://llvm.org/LICENSE.txt | ||
|
||
#include <sstream> | ||
|
||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_ | ||
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_ | ||
|
||
namespace tt::pjrt { | ||
|
||
class DeviceDescription { | ||
|
||
public: | ||
DeviceDescription(int32_t client_id) : client_id_(client_id) {}; | ||
~DeviceDescription(); | ||
operator PJRT_DeviceDescription *() { | ||
return reinterpret_cast<PJRT_DeviceDescription *>(this); | ||
} | ||
static void BindApi(PJRT_Api *api); | ||
|
||
static DeviceDescription *Unwrap(PJRT_DeviceDescription *device) { | ||
return reinterpret_cast<DeviceDescription *>(device); | ||
} | ||
|
||
std::string_view kind_string() { return kind_string_; } | ||
std::string_view debug_string() { return to_string(); } | ||
std::string_view to_string() { | ||
std::stringstream ss; | ||
ss << kind_string_ << "(id=" << device_id() << ", arch=" << arch_string_ | ||
<< ")"; | ||
user_string_ = ss.str(); | ||
return user_string_; | ||
} | ||
|
||
// TODO | ||
int64_t device_id() { return 0; } | ||
|
||
int client_id() { return client_id_; } | ||
|
||
int process_index() { return 0; } | ||
|
||
private: | ||
int client_id_; | ||
|
||
// TODO We should understand better how these are used. | ||
// See https://github.com/tenstorrent/tt-xla/issues/125 | ||
std::string kind_string_ = "TTDevice"; | ||
std::string arch_string_ = "Wormhole"; | ||
std::string user_string_ = ""; | ||
}; | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
// This file incorporates work covered by the following copyright and permission | ||
// notice: | ||
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// https://llvm.org/LICENSE.txt | ||
|
||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
#include "common/pjrt_implementation/device_description.h" | ||
#include "common/pjrt_implementation/event_instance.h" | ||
#include "common/status.h" | ||
|
||
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_ | ||
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_ | ||
|
||
namespace tt::pjrt { | ||
|
||
class ClientInstance; | ||
class BufferInstance; | ||
|
||
class DeviceInstance { | ||
|
||
public: | ||
DeviceInstance(int client_id, ClientInstance &client) | ||
: client_(client), description_(client_id) {} | ||
~DeviceInstance(); | ||
operator PJRT_Device *() { return reinterpret_cast<PJRT_Device *>(this); } | ||
static void BindApi(PJRT_Api *api); | ||
|
||
static DeviceInstance *Unwrap(PJRT_Device *device) { | ||
return reinterpret_cast<DeviceInstance *>(device); | ||
} | ||
|
||
static DeviceInstance *Unwrap(PJRT_DeviceDescription *device_description) { | ||
return reinterpret_cast<DeviceInstance *>(device_description); | ||
} | ||
ClientInstance &client() { return client_; } | ||
bool is_addressable() { return true; } | ||
int local_hardware_id() { return -1; } | ||
|
||
tt_pjrt_status | ||
HostBufferToDeviceZeroDim(PJRT_Buffer_Type type, const int64_t *dims, | ||
size_t num_dims, | ||
EventInstance **out_done_with_host_buffer_event, | ||
BufferInstance **out_buffer); | ||
|
||
tt_pjrt_status | ||
HostBufferToDeviceSplat(const void *data, PJRT_Buffer_Type type, | ||
const int64_t *dims, size_t num_dims, | ||
EventInstance **out_done_with_host_buffer_event, | ||
BufferInstance **out_buffer); | ||
|
||
tt_pjrt_status | ||
HostBufferToDevice(const void *data, PJRT_Buffer_Type type, | ||
const int64_t *dims, size_t num_dims, | ||
const int64_t *byte_strides, size_t num_byte_strides, | ||
PJRT_HostBufferSemantics host_buffer_semantics, | ||
EventInstance **out_done_with_host_buffer_event, | ||
BufferInstance **out_buffer); | ||
|
||
DeviceDescription *device_description() { return &description_; } | ||
|
||
private: | ||
tt_pjrt_status OpenDevice(); | ||
|
||
ClientInstance &client_; | ||
uint64_t last_transfer_timepoint_ = 0; | ||
DeviceDescription description_; | ||
}; | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif |
Oops, something went wrong.