Skip to content

Commit

Permalink
Refactoring api_impl.cc and api_impl.h code (#116)
Browse files Browse the repository at this point in the history
The `tt-xla` repository currently contains particularly large files,
`api_impl.cc` and `api_impl.h`, that handle multiple functionalities
through different classes. This issue is for refactoring `api_impl.cc`
and `api_impl.h`: splitting these files into smaller, logically organized
files, with each file containing a single class or closely related
classes. Other than that, there are no significant changes to the logic or
features of the specific functions.
  • Loading branch information
ajakovljevicTT authored Dec 30, 2024
1 parent 1c4b2df commit d500c0e
Show file tree
Hide file tree
Showing 27 changed files with 1,943 additions and 1,568 deletions.
61 changes: 61 additions & 0 deletions inc/common/pjrt_implementation/api_bindings.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
//
// This file incorporates work covered by the following copyright and permission
// notice:
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include "common/pjrt_implementation/error_instance.h"

#include <memory>
#include <utility>

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_

namespace tt::pjrt {

// Top-level API bindings.
// Installs the entry points that do not depend on the concrete platform or
// client types into |api|. BindApi<>() calls this first, before it binds the
// type-dependent PJRT_Client_Create entry point.
void BindMonomorphicApi(PJRT_Api *api);

// NOTE(review): defined elsewhere; presumably fills the entries of |api| that
// have no implementation with placeholder handlers - confirm against the
// definition.
void BindUndefineds(PJRT_Api *api);

// Initializes and returns PJRT plugin attributes.
// The result is reported through |args|; the returned PJRT_Error pointer is
// presumably null on success - TODO confirm against the definition.
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args);

// Fills in the PJRT API table |api|.
//
// The type-independent (monomorphic) entry points are bound first through
// BindMonomorphicApi(); afterwards the entry points that depend on the
// concrete platform and client types are installed.
//
// Template parameters:
//   PlatformTy       - default-constructible platform type whose Initialize()
//                      returns a status checked with tt_pjrt_status_is_ok().
//   ClientInstanceTy - client type constructible from the platform
//                      unique_ptr, whose Initialize() returns an error
//                      pointer that is null on success.
template <typename PlatformTy, typename ClientInstanceTy>
void BindApi(PJRT_Api *api) {
  BindMonomorphicApi(api);

  // Type-dependent entry point: client creation. The leading '+' converts the
  // capture-less lambda into a plain function pointer.
  api->PJRT_Client_Create =
      +[](PJRT_Client_Create_Args *args) -> PJRT_Error * {
    DLOG_F(LOG_DEBUG, "PJRT_Client_Create");
    std::unique_ptr<PlatformTy> new_platform = std::make_unique<PlatformTy>();

    // The client create_options are not consumed yet: every option is only
    // logged as unused.
    for (size_t option_index = 0; option_index < args->num_options;
         ++option_index) {
      DLOG_F(WARNING, "Unused config var: %s",
             args->create_options[option_index].name);
    }

    // Bring up the platform; a failed status is converted into a PJRT error.
    if (auto platform_status = new_platform->Initialize();
        !tt_pjrt_status_is_ok(platform_status)) {
      return ErrorInstance::MakeError(platform_status);
    }

    // The client takes ownership of the platform. If client initialization
    // fails, the unique_ptr destroys the client on the way out.
    auto new_client =
        std::make_unique<ClientInstanceTy>(std::move(new_platform));
    if (auto *init_error = new_client->Initialize()) {
      return init_error;
    }

    // Success: ownership of the client is handed to the caller as an opaque
    // PJRT_Client handle.
    args->client = reinterpret_cast<PJRT_Client *>(new_client.release());
    return nullptr;
  };
}

} // namespace tt::pjrt

#endif
89 changes: 89 additions & 0 deletions inc/common/pjrt_implementation/buffer_instance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
//
// This file incorporates work covered by the following copyright and permission
// notice:
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include "tt/runtime/runtime.h"
#include "xla/pjrt/c/pjrt_c_api.h"

#include "common/pjrt_implementation/event_instance.h"
#include "common/status.h"

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_

namespace tt::pjrt {

// Forward declaration: this header only uses references to DeviceInstance.
class DeviceInstance;

// Implementation object behind the opaque PJRT_Buffer handle. A PJRT_Buffer
// pointer passed across the C API is a reinterpret_cast of a BufferInstance
// pointer (see the conversion operator and Unwrap() below).
class BufferInstance {
public:
// Creates a buffer associated with |device| from a runtime tensor and its
// shape and stride. Defined out of line.
BufferInstance(DeviceInstance &device, tt::runtime::Tensor tensor,
std::vector<std::uint32_t> shape,
std::vector<std::uint32_t> stride);
// Creates a buffer associated with |device| only. Defined out of line.
// NOTE(review): presumably leaves tensor_ empty - confirm in the definition.
BufferInstance(DeviceInstance &device);
~BufferInstance();
// Exposes this object as the opaque C API handle.
operator PJRT_Buffer *() { return reinterpret_cast<PJRT_Buffer *>(this); }
// Recovers the implementation object from the opaque C API handle.
static BufferInstance *Unwrap(PJRT_Buffer *buffer) {
return reinterpret_cast<BufferInstance *>(buffer);
}
// Binds the buffer-related entry points into |api|. Defined out of line.
static void BindApi(PJRT_Api *api);

// iree_hal_buffer_view_t* buffer_view() { return buffer_view_.get(); }
// Returns the device this buffer was created with.
DeviceInstance &device() { return device_; }
tt_pjrt_status AsyncDeallocate();
tt_pjrt_status Delete();
// Reports the is_deleted_ flag.
bool is_deleted() { return is_deleted_; }
bool is_on_cpu() {
// TODO: Plumb through an indication if running on CPU and then implement
// the hook to get an unsafe pointer (avoids a copy).
return false;
}
// Returns a copy of the held tensor. Uses std::optional::value(), so this
// throws std::bad_optional_access when no tensor is set.
tt::runtime::Tensor tensor() { return tensor_.value(); }

PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args);
// Gets the required host size in bytes to copy to host.
tt_pjrt_status GetHostSizeInBytes(size_t *host_size);
// Declared here, defined out of line.
// NOTE(review): presumably copies the buffer contents into |dst| (capacity
// |dst_size| bytes) and reports completion through |done_event| - confirm.
tt_pjrt_status CopyToHost(void *dst, size_t dst_size,
EventInstance **done_event);

// Pointer to, and number of, the stored dimensions. The pointer refers to
// storage owned by dims_.
const int64_t *dims() { return dims_.data(); }
size_t num_dims() { return dims_.size(); }
// Explicitly sets / reads the element type; getType() returns the optional
// as stored, which may be empty.
void setType(PJRT_Buffer_Type Type) { DataType = Type; }
std::optional<PJRT_Buffer_Type> getType() { return DataType; }

// Get the data type for a tensor through runtime if DataType is not set.
PJRT_Buffer_Type getRuntimeType();

// Returns the identifier stored in unique_id_.
int unique_id() { return unique_id_; }

private:
// NOTE(review): id_counter_ is presumably the source of unique_id_ values;
// the assignment happens out of line - confirm.
static int id_counter_;
int unique_id_;
void ComputeLayout();

DeviceInstance &device_;
// When the buffer resource gets freed, this is set to true.
bool is_deleted_ = false;

// API elements that must have the same lifetime as BufferInstance.
std::vector<int64_t> dims_;
std::vector<std::uint32_t> stride_;
// Empty until a tensor is associated with the buffer; see tensor().
std::optional<tt::runtime::Tensor> tensor_;

// NOTE(review): presumably filled by ComputeLayout() and reported through
// GetMemoryLayout() - confirm.
std::vector<int64_t> minor_to_major_;
std::vector<int64_t> tile_dims_;
std::vector<size_t> tile_dim_sizes_;

// Underlying datatype of tensor.
std::optional<PJRT_Buffer_Type> DataType;
};

} // namespace tt::pjrt

#endif
100 changes: 100 additions & 0 deletions inc/common/pjrt_implementation/client_instance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
//
// This file incorporates work covered by the following copyright and permission
// notice:
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include "xla/pjrt/c/pjrt_c_api.h"

#include <memory>
#include <vector>

#include "common/module_builder.h"
#include "common/pjrt_implementation/device_instance.h"
#include "common/pjrt_implementation/loaded_executable_instance.h"
#include "common/platform.h"
#include "common/status.h"

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_

namespace tt::pjrt {

//===----------------------------------------------------------------------===//
// ClientInstance
// The root of the runtime hierarchy, these map to an IREE driver and are
// created against an API.
//===----------------------------------------------------------------------===//
// Implementation object behind the opaque PJRT_Client handle (see Unwrap()).
// Owns the platform and the module builder and exposes the device lists.
class ClientInstance {

public:
// Takes ownership of |platform|. Defined out of line.
ClientInstance(std::unique_ptr<Platform> platform);
virtual ~ClientInstance();

// Binds monomorphic entry-points for the client.
static void BindApi(PJRT_Api *api);

// Recovers the implementation object from the opaque C API handle.
static ClientInstance *Unwrap(PJRT_Client *client) {
return reinterpret_cast<ClientInstance *>(client);
}

// Before the client is usable, it must be initialized.
// Returns an error pointer; callers treat a null result as success.
PJRT_Error *Initialize();

// Accessors. platform() dereferences platform_, so it requires a non-null
// platform to have been passed to the constructor.
Platform &platform() { return *platform_; }
const std::vector<DeviceInstance *> &devices() { return devices_; }
const std::vector<DeviceInstance *> &addressable_devices() {
return addressable_devices_;
}
const std::string &cached_platform_name() { return cached_platform_name_; }
const std::string &cached_platform_version() {
return cached_platform_version_;
}

// Checks if the output on the i-th index is a scalar.
// Forwards to the module builder; module_builder_ must be non-null here.
// NOTE(review): presumably created by InitializeCompiler() - confirm.
bool isOutputScalar(const size_t index) const {
return module_builder_->isOutputScalar(index);
}

// Compiles.
// See TODOs in PJRT_Client_Compile.
PJRT_Error *
Compile(const PJRT_Program *program, /*xla::CompileOptions options, */
LoadedExecutableInstance **executable);

// Advances the timeline, returning (current, next) time point values.
std::tuple<uint64_t, uint64_t> AdvanceTimeline();

protected:
// Protected so that derived clients can set them.
// NOTE(review): where these are populated is not visible here - confirm.
std::string cached_platform_name_;
std::string cached_platform_version_;

private:
tt_pjrt_status InitializeCompiler();
tt_pjrt_status PopulateDevices();

std::unique_ptr<Platform> platform_;

// Raw device pointers.
// NOTE(review): ownership and cleanup are not visible in this header;
// presumably filled by PopulateDevices() and released in the destructor -
// confirm.
std::vector<DeviceInstance *> devices_;
std::vector<DeviceInstance *> addressable_devices_;

std::unique_ptr<ModuleBuilder> module_builder_;

// Synchronization.
// We keep one global execution timeline across all devices. The management
// of this is currently somewhat primitive: we increment it by one for each
// invocation. Batch invocations (i.e. across multiple devices), only
// increment by one. In the future, additional parallelism could be plumbed
// up to the framework to allow different kinds of timeline management.
// Waiting on the current value of |execution_timeline_| will drain all
// scheduled work to date.
uint64_t execution_timeline_ = 0ull;
};

} // namespace tt::pjrt

#endif
63 changes: 63 additions & 0 deletions inc/common/pjrt_implementation/device_description.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
//
// This file incorporates work covered by the following copyright and permission
// notice:
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include <sstream>

#include "xla/pjrt/c/pjrt_c_api.h"

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_

namespace tt::pjrt {

// Describes a device to PJRT. Instances cross the C API as opaque
// PJRT_DeviceDescription pointers; see the conversion operator and Unwrap().
class DeviceDescription {
public:
  DeviceDescription(int32_t client_id) : client_id_(client_id) {}
  ~DeviceDescription();

  // Binds the device-description entry points into |api|. Defined out of
  // line.
  static void BindApi(PJRT_Api *api);

  // Recovers the implementation object from the opaque C API handle.
  static DeviceDescription *Unwrap(PJRT_DeviceDescription *device) {
    return reinterpret_cast<DeviceDescription *>(device);
  }

  // Exposes this object as the opaque C API handle.
  operator PJRT_DeviceDescription *() {
    return reinterpret_cast<PJRT_DeviceDescription *>(this);
  }

  // Device kind label. The view refers to kind_string_.
  std::string_view kind_string() { return kind_string_; }

  // Builds "<kind>(id=<device id>, arch=<arch>)", stores it in user_string_
  // and returns a view of that member. Every call overwrites user_string_,
  // so a view obtained from an earlier call must not be used afterwards.
  std::string_view to_string() {
    std::stringstream formatted;
    formatted << kind_string_;
    formatted << "(id=" << device_id();
    formatted << ", arch=" << arch_string_;
    formatted << ")";
    user_string_ = formatted.str();
    return user_string_;
  }

  // Same text as to_string(); it also rewrites user_string_.
  std::string_view debug_string() { return to_string(); }

  // TODO: the device id is currently always reported as 0.
  int64_t device_id() { return 0; }

  int client_id() { return client_id_; }

  // Currently always 0.
  int process_index() { return 0; }

private:
  int client_id_;

  // TODO We should understand better how these are used.
  // See https://github.com/tenstorrent/tt-xla/issues/125
  std::string kind_string_ = "TTDevice";
  std::string arch_string_ = "Wormhole";
  // Backing storage for the views returned by to_string()/debug_string().
  std::string user_string_ = "";
};

} // namespace tt::pjrt

#endif
77 changes: 77 additions & 0 deletions inc/common/pjrt_implementation/device_instance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
//
// This file incorporates work covered by the following copyright and permission
// notice:
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#include "xla/pjrt/c/pjrt_c_api.h"

#include "common/pjrt_implementation/device_description.h"
#include "common/pjrt_implementation/event_instance.h"
#include "common/status.h"

#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_

namespace tt::pjrt {

// Forward declarations: this header only uses references and pointers to
// these types.
class ClientInstance;
class BufferInstance;

// Implementation object behind the opaque PJRT_Device handle (see the
// conversion operator and Unwrap()). Holds a reference to its client and
// embeds its DeviceDescription.
class DeviceInstance {

public:
// Stores a reference to |client| and constructs the embedded description
// with |client_id|.
DeviceInstance(int client_id, ClientInstance &client)
: client_(client), description_(client_id) {}
~DeviceInstance();
// Exposes this object as the opaque C API handle.
operator PJRT_Device *() { return reinterpret_cast<PJRT_Device *>(this); }
// Binds the device-related entry points into |api|. Defined out of line.
static void BindApi(PJRT_Api *api);

// Recovers the implementation object from the opaque C API handle.
static DeviceInstance *Unwrap(PJRT_Device *device) {
return reinterpret_cast<DeviceInstance *>(device);
}

// NOTE(review): this reinterprets a PJRT_DeviceDescription pointer as a
// DeviceInstance. device_description() hands out the address of the
// description_ member, which is not the address of the DeviceInstance
// itself, so this overload is only valid if callers pass a pointer that
// really originates from a DeviceInstance - confirm against the call sites.
static DeviceInstance *Unwrap(PJRT_DeviceDescription *device_description) {
return reinterpret_cast<DeviceInstance *>(device_description);
}
ClientInstance &client() { return client_; }
// Currently hard-coded: always addressable, hardware id always -1.
bool is_addressable() { return true; }
int local_hardware_id() { return -1; }

// Host-to-device transfer entry points; all three are defined out of line.
// NOTE(review): the out parameters presumably receive a "done with host
// buffer" event and the newly created device buffer - confirm.
tt_pjrt_status
HostBufferToDeviceZeroDim(PJRT_Buffer_Type type, const int64_t *dims,
size_t num_dims,
EventInstance **out_done_with_host_buffer_event,
BufferInstance **out_buffer);

tt_pjrt_status
HostBufferToDeviceSplat(const void *data, PJRT_Buffer_Type type,
const int64_t *dims, size_t num_dims,
EventInstance **out_done_with_host_buffer_event,
BufferInstance **out_buffer);

tt_pjrt_status
HostBufferToDevice(const void *data, PJRT_Buffer_Type type,
const int64_t *dims, size_t num_dims,
const int64_t *byte_strides, size_t num_byte_strides,
PJRT_HostBufferSemantics host_buffer_semantics,
EventInstance **out_done_with_host_buffer_event,
BufferInstance **out_buffer);

// Returns a pointer to the embedded description; it is valid only while
// this DeviceInstance is alive.
DeviceDescription *device_description() { return &description_; }

private:
tt_pjrt_status OpenDevice();

ClientInstance &client_;
// NOTE(review): not referenced anywhere in this header; presumably updated
// by the transfer functions - confirm.
uint64_t last_transfer_timepoint_ = 0;
DeviceDescription description_;
};

} // namespace tt::pjrt

#endif
Loading

0 comments on commit d500c0e

Please sign in to comment.