Skip to content

Commit d500c0e

Browse files
Refactoring api_impl.cc and api_impl.h code (#116)
The `tt-xla` repository currently contains two particularly large files, `api_impl.cc` and `api_impl.h`, that handle multiple functionalities through different classes. This issue is for refactoring `api_impl.cc` and `api_impl.h`: splitting them into smaller, logically organized files, with each file containing a single class or closely related classes. Other than that, there are no significant changes to the logic or features of the specific functions.
1 parent 1c4b2df commit d500c0e

27 files changed

+1943
-1568
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// This file incorporates work covered by the following copyright and permission
6+
// notice:
7+
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
// https://llvm.org/LICENSE.txt
10+
11+
#include "common/pjrt_implementation/error_instance.h"
12+
13+
#include <memory>
14+
#include <utility>
15+
16+
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_
17+
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_API_BINDINGS_H_
18+
19+
namespace tt::pjrt {
20+
21+
// Top-level API bindings.
22+
void BindMonomorphicApi(PJRT_Api *api);
23+
24+
void BindUndefineds(PJRT_Api *api);
25+
26+
// Initializes and returns PJRT plugin attributes.
27+
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args);
28+
29+
template <typename PlatformTy, typename ClientInstanceTy>
30+
void BindApi(PJRT_Api *api) {
31+
BindMonomorphicApi(api);
32+
33+
// Bind polymorphic entry-points.
34+
api->PJRT_Client_Create = +[](PJRT_Client_Create_Args *args) -> PJRT_Error * {
35+
DLOG_F(LOG_DEBUG, "PJRT_Client_Create");
36+
auto platform = std::make_unique<PlatformTy>();
37+
38+
// Populate config_vars() from the client create_options.
39+
for (size_t i = 0; i < args->num_options; ++i) {
40+
DLOG_F(WARNING, "Unused config var: %s", args->create_options[i].name);
41+
}
42+
43+
auto status = platform->Initialize();
44+
if (!tt_pjrt_status_is_ok(status)) {
45+
return ErrorInstance::MakeError(status);
46+
}
47+
48+
auto client = std::make_unique<ClientInstanceTy>(std::move(platform));
49+
auto *error = client->Initialize();
50+
if (error)
51+
return error;
52+
53+
// Successful return.
54+
args->client = reinterpret_cast<PJRT_Client *>(client.release());
55+
return nullptr;
56+
};
57+
}
58+
59+
} // namespace tt::pjrt
60+
61+
#endif
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// This file incorporates work covered by the following copyright and permission
6+
// notice:
7+
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
// https://llvm.org/LICENSE.txt
10+
11+
#include "tt/runtime/runtime.h"
12+
#include "xla/pjrt/c/pjrt_c_api.h"
13+
14+
#include "common/pjrt_implementation/event_instance.h"
15+
#include "common/status.h"
16+
17+
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_
18+
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_BUFFER_INSTANCE_H_
19+
20+
namespace tt::pjrt {
21+
22+
class DeviceInstance;
23+
24+
class BufferInstance {
25+
public:
26+
BufferInstance(DeviceInstance &device, tt::runtime::Tensor tensor,
27+
std::vector<std::uint32_t> shape,
28+
std::vector<std::uint32_t> stride);
29+
BufferInstance(DeviceInstance &device);
30+
~BufferInstance();
31+
operator PJRT_Buffer *() { return reinterpret_cast<PJRT_Buffer *>(this); }
32+
static BufferInstance *Unwrap(PJRT_Buffer *buffer) {
33+
return reinterpret_cast<BufferInstance *>(buffer);
34+
}
35+
static void BindApi(PJRT_Api *api);
36+
37+
// iree_hal_buffer_view_t* buffer_view() { return buffer_view_.get(); }
38+
DeviceInstance &device() { return device_; }
39+
tt_pjrt_status AsyncDeallocate();
40+
tt_pjrt_status Delete();
41+
bool is_deleted() { return is_deleted_; }
42+
bool is_on_cpu() {
43+
// TODO: Plumb through an indication if running on CPU and then implement
44+
// the hook to get an unsafe pointer (avoids a copy).
45+
return false;
46+
}
47+
tt::runtime::Tensor tensor() { return tensor_.value(); }
48+
49+
PJRT_Error *GetMemoryLayout(PJRT_Buffer_GetMemoryLayout_Args *args);
50+
// Gets the required host size in bytes to copy to host.
51+
tt_pjrt_status GetHostSizeInBytes(size_t *host_size);
52+
tt_pjrt_status CopyToHost(void *dst, size_t dst_size,
53+
EventInstance **done_event);
54+
55+
const int64_t *dims() { return dims_.data(); }
56+
size_t num_dims() { return dims_.size(); }
57+
void setType(PJRT_Buffer_Type Type) { DataType = Type; }
58+
std::optional<PJRT_Buffer_Type> getType() { return DataType; }
59+
60+
// Get the data type for a tensor through runtime if DataType is not set.
61+
PJRT_Buffer_Type getRuntimeType();
62+
63+
int unique_id() { return unique_id_; }
64+
65+
private:
66+
static int id_counter_;
67+
int unique_id_;
68+
void ComputeLayout();
69+
70+
DeviceInstance &device_;
71+
// When the buffer resource gets freed, this is set to true.
72+
bool is_deleted_ = false;
73+
74+
// API elements that must have the same lifetime as BufferInstance.
75+
std::vector<int64_t> dims_;
76+
std::vector<std::uint32_t> stride_;
77+
std::optional<tt::runtime::Tensor> tensor_;
78+
79+
std::vector<int64_t> minor_to_major_;
80+
std::vector<int64_t> tile_dims_;
81+
std::vector<size_t> tile_dim_sizes_;
82+
83+
// Underlying datatype of tensor.
84+
std::optional<PJRT_Buffer_Type> DataType;
85+
};
86+
87+
} // namespace tt::pjrt
88+
89+
#endif
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// This file incorporates work covered by the following copyright and permission
6+
// notice:
7+
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
// https://llvm.org/LICENSE.txt
10+
11+
#include "xla/pjrt/c/pjrt_c_api.h"
12+
13+
#include <memory>
14+
#include <vector>
15+
16+
#include "common/module_builder.h"
17+
#include "common/pjrt_implementation/device_instance.h"
18+
#include "common/pjrt_implementation/loaded_executable_instance.h"
19+
#include "common/platform.h"
20+
#include "common/status.h"
21+
22+
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_
23+
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_CLIENT_INSTANCE_H_
24+
25+
namespace tt::pjrt {
26+
27+
//===----------------------------------------------------------------------===//
28+
// ClientInstance
29+
// The root of the runtime hierarchy, these map to an IREE driver and are
30+
// created against an API.
31+
//===----------------------------------------------------------------------===//
32+
class ClientInstance {
33+
34+
public:
35+
ClientInstance(std::unique_ptr<Platform> platform);
36+
virtual ~ClientInstance();
37+
38+
// Binds monomorphic entry-points for the client.
39+
static void BindApi(PJRT_Api *api);
40+
41+
static ClientInstance *Unwrap(PJRT_Client *client) {
42+
return reinterpret_cast<ClientInstance *>(client);
43+
}
44+
45+
// Before the client is usable, it must be initialized.
46+
PJRT_Error *Initialize();
47+
48+
Platform &platform() { return *platform_; }
49+
const std::vector<DeviceInstance *> &devices() { return devices_; }
50+
const std::vector<DeviceInstance *> &addressable_devices() {
51+
return addressable_devices_;
52+
}
53+
const std::string &cached_platform_name() { return cached_platform_name_; }
54+
const std::string &cached_platform_version() {
55+
return cached_platform_version_;
56+
}
57+
58+
// Checks if the output on the i-th index is a scalar.
59+
bool isOutputScalar(const size_t index) const {
60+
return module_builder_->isOutputScalar(index);
61+
}
62+
63+
// Compiles.
64+
// See TODOs in PJRT_Client_Compile.
65+
PJRT_Error *
66+
Compile(const PJRT_Program *program, /*xla::CompileOptions options, */
67+
LoadedExecutableInstance **executable);
68+
69+
// Advances the timeline, returning (current, next) time point values.
70+
std::tuple<uint64_t, uint64_t> AdvanceTimeline();
71+
72+
protected:
73+
std::string cached_platform_name_;
74+
std::string cached_platform_version_;
75+
76+
private:
77+
tt_pjrt_status InitializeCompiler();
78+
tt_pjrt_status PopulateDevices();
79+
80+
std::unique_ptr<Platform> platform_;
81+
82+
std::vector<DeviceInstance *> devices_;
83+
std::vector<DeviceInstance *> addressable_devices_;
84+
85+
std::unique_ptr<ModuleBuilder> module_builder_;
86+
87+
// Synchronization.
88+
// We keep one global execution timeline across all devices. The management
89+
// of this is currently somewhat primitive: we increment it by one for each
90+
// invocation. Batch invocations (i.e. across multiple devices), only
91+
// increment by one. In the future, additional parallelism could be plumbed
92+
// up to the framework to allow different kinds of timeline management.
93+
// Waiting on the current value of |execution_timeline_| will drain all
94+
// scheduled work to date.
95+
uint64_t execution_timeline_ = 0ull;
96+
};
97+
98+
} // namespace tt::pjrt
99+
100+
#endif
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// This file incorporates work covered by the following copyright and permission
6+
// notice:
7+
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
// https://llvm.org/LICENSE.txt
10+
11+
#include <sstream>
12+
13+
#include "xla/pjrt/c/pjrt_c_api.h"
14+
15+
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_
16+
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_DESCRIPTION_H_
17+
18+
namespace tt::pjrt {
19+
20+
class DeviceDescription {
21+
22+
public:
23+
DeviceDescription(int32_t client_id) : client_id_(client_id) {};
24+
~DeviceDescription();
25+
operator PJRT_DeviceDescription *() {
26+
return reinterpret_cast<PJRT_DeviceDescription *>(this);
27+
}
28+
static void BindApi(PJRT_Api *api);
29+
30+
static DeviceDescription *Unwrap(PJRT_DeviceDescription *device) {
31+
return reinterpret_cast<DeviceDescription *>(device);
32+
}
33+
34+
std::string_view kind_string() { return kind_string_; }
35+
std::string_view debug_string() { return to_string(); }
36+
std::string_view to_string() {
37+
std::stringstream ss;
38+
ss << kind_string_ << "(id=" << device_id() << ", arch=" << arch_string_
39+
<< ")";
40+
user_string_ = ss.str();
41+
return user_string_;
42+
}
43+
44+
// TODO
45+
int64_t device_id() { return 0; }
46+
47+
int client_id() { return client_id_; }
48+
49+
int process_index() { return 0; }
50+
51+
private:
52+
int client_id_;
53+
54+
// TODO We should understand better how these are used.
55+
// See https://github.com/tenstorrent/tt-xla/issues/125
56+
std::string kind_string_ = "TTDevice";
57+
std::string arch_string_ = "Wormhole";
58+
std::string user_string_ = "";
59+
};
60+
61+
} // namespace tt::pjrt
62+
63+
#endif
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// This file incorporates work covered by the following copyright and permission
6+
// notice:
7+
// SPDX-FileCopyrightText: Copyright 2023 The IREE Authors
8+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9+
// https://llvm.org/LICENSE.txt
10+
11+
#include "xla/pjrt/c/pjrt_c_api.h"
12+
13+
#include "common/pjrt_implementation/device_description.h"
14+
#include "common/pjrt_implementation/event_instance.h"
15+
#include "common/status.h"
16+
17+
#ifndef TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_
18+
#define TT_XLA_INC_COMMON_PJRT_IMPLEMENTATION_DEVICE_INSTANCE_H_
19+
20+
namespace tt::pjrt {
21+
22+
class ClientInstance;
23+
class BufferInstance;
24+
25+
class DeviceInstance {
26+
27+
public:
28+
DeviceInstance(int client_id, ClientInstance &client)
29+
: client_(client), description_(client_id) {}
30+
~DeviceInstance();
31+
operator PJRT_Device *() { return reinterpret_cast<PJRT_Device *>(this); }
32+
static void BindApi(PJRT_Api *api);
33+
34+
static DeviceInstance *Unwrap(PJRT_Device *device) {
35+
return reinterpret_cast<DeviceInstance *>(device);
36+
}
37+
38+
static DeviceInstance *Unwrap(PJRT_DeviceDescription *device_description) {
39+
return reinterpret_cast<DeviceInstance *>(device_description);
40+
}
41+
ClientInstance &client() { return client_; }
42+
bool is_addressable() { return true; }
43+
int local_hardware_id() { return -1; }
44+
45+
tt_pjrt_status
46+
HostBufferToDeviceZeroDim(PJRT_Buffer_Type type, const int64_t *dims,
47+
size_t num_dims,
48+
EventInstance **out_done_with_host_buffer_event,
49+
BufferInstance **out_buffer);
50+
51+
tt_pjrt_status
52+
HostBufferToDeviceSplat(const void *data, PJRT_Buffer_Type type,
53+
const int64_t *dims, size_t num_dims,
54+
EventInstance **out_done_with_host_buffer_event,
55+
BufferInstance **out_buffer);
56+
57+
tt_pjrt_status
58+
HostBufferToDevice(const void *data, PJRT_Buffer_Type type,
59+
const int64_t *dims, size_t num_dims,
60+
const int64_t *byte_strides, size_t num_byte_strides,
61+
PJRT_HostBufferSemantics host_buffer_semantics,
62+
EventInstance **out_done_with_host_buffer_event,
63+
BufferInstance **out_buffer);
64+
65+
DeviceDescription *device_description() { return &description_; }
66+
67+
private:
68+
tt_pjrt_status OpenDevice();
69+
70+
ClientInstance &client_;
71+
uint64_t last_transfer_timepoint_ = 0;
72+
DeviceDescription description_;
73+
};
74+
75+
} // namespace tt::pjrt
76+
77+
#endif

0 commit comments

Comments
 (0)