1
+ #pragma once
2
+
3
+ #include < ATen/OpaqueTensorImpl.h>
4
+ #include " ttnn/tensor/tensor.hpp"
5
+ #include < iostream>
6
+ #include < string.h>
7
+
8
+ template <typename Arg, typename ... Args>
9
+ void doPrint (std::ostream& out, const std::string_view& filename, int lineno, const std::string_view& fn, Arg&& arg, Args&&... args)
10
+ {
11
+ out << std::format (" {}({})({}): " , filename, lineno, fn);
12
+ out << std::forward<Arg>(arg);
13
+ ((out << std::forward<Args>(args)), ...);
14
+ out << std::endl;
15
+ }
16
+ #define LOGGING (...) doPrint(std::cout, __FILE_NAME__, __LINE__, __FUNCTION__, __VA_ARGS__)
17
+
18
+ namespace at {
19
+
20
+ struct TtnnTensorImpl : public TensorImpl {
21
+ TtnnTensorImpl (
22
+ at::DispatchKeySet key_set,
23
+ const caffe2::TypeMeta data_type,
24
+ c10::Device device,
25
+ ttnn::Tensor& ttnn_tensor,
26
+ c10::intrusive_ptr<c10::StorageImpl> storage) : TensorImpl(key_set, data_type, device), ttnn_tensor_(ttnn_tensor), ttnn_tensor_string_(ttnn_tensor.write_to_string()) {
27
+ storage_ = std::move (storage);
28
+ auto view = ttnn_tensor_.get_logical_shape ().view ();
29
+ std::vector<int64_t > view_int64;
30
+ std::copy (view.begin (), view.end (), std::back_inserter (view_int64));
31
+ IntArrayRef int_array_ref (&(*view_int64.begin ()), &(*view_int64.end ()));
32
+ sizes_and_strides_.set_sizes (int_array_ref);
33
+ }
34
+
35
+ TtnnTensorImpl (
36
+ at::DispatchKeySet key_set,
37
+ const caffe2::TypeMeta data_type,
38
+ c10::Device device,
39
+ const ttnn::Tensor& ttnn_tensor,
40
+ const Storage& storage) : TensorImpl(key_set, data_type, device), ttnn_tensor_(ttnn_tensor), ttnn_tensor_string_(ttnn_tensor.write_to_string()) {
41
+ storage_ = std::move (storage);
42
+ auto view = ttnn_tensor_.get_logical_shape ().view ();
43
+ std::vector<int64_t > view_int64;
44
+ std::copy (view.begin (), view.end (), std::back_inserter (view_int64));
45
+ IntArrayRef int_array_ref (&(*view_int64.begin ()), &(*view_int64.end ()));
46
+ sizes_and_strides_.set_sizes (int_array_ref);
47
+ }
48
+
49
+ void set_sizes_and_strides (const IntArrayRef& int_array_ref) {
50
+ sizes_and_strides_.set_sizes (int_array_ref);
51
+ }
52
+
53
+ void set_sizes_and_strides_as (const at::Tensor& the_template) {
54
+ sizes_and_strides_.set_sizes (the_template.sizes ());
55
+ }
56
+
57
+ ttnn::Tensor get_ttnn_tensor () {
58
+ // LOGGING(ttnn_tensor_string_);
59
+ LOGGING (ttnn_tensor_.write_to_string ());
60
+ return ttnn_tensor_;
61
+ }
62
+
63
+ void set_ttnn_tensor (const ttnn::Tensor& tensor) {
64
+ ttnn_tensor_ = tensor;
65
+ }
66
+
67
+ /* *
68
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
69
+ *
70
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
71
+ * see NOTE [ TensorImpl Shallow-Copying ].
72
+ */
73
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach (
74
+ const c10::VariableVersion& version_counter,
75
+ bool allow_tensor_metadata_change) const override {
76
+ auto impl = c10::make_intrusive<TtnnTensorImpl>(
77
+ key_set (),
78
+ dtype (),
79
+ device (),
80
+ ttnn_tensor_,
81
+ storage_);
82
+ copy_tensor_metadata (
83
+ /* src_opaque_impl=*/ this ,
84
+ /* dest_opaque_impl=*/ impl.get (),
85
+ /* version_counter=*/ version_counter,
86
+ /* allow_tensor_metadata_change=*/ allow_tensor_metadata_change);
87
+ impl->refresh_numel ();
88
+ return impl;
89
+ }
90
+
91
+ /* *
92
+ * Return a TensorImpl that is a shallow-copy of this TensorImpl.
93
+ *
94
+ * For usage of `version_counter` and `allow_tensor_metadata_change`,
95
+ * see NOTE [ TensorImpl Shallow-Copying ].
96
+ */
97
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach (
98
+ c10::VariableVersion&& version_counter,
99
+ bool allow_tensor_metadata_change) const override {
100
+ auto impl = c10::make_intrusive<TtnnTensorImpl>(
101
+ key_set (),
102
+ dtype (),
103
+ device (),
104
+ ttnn_tensor_,
105
+ storage_);
106
+ copy_tensor_metadata (
107
+ /* src_opaque_impl=*/ this ,
108
+ /* dest_opaque_impl=*/ impl.get (),
109
+ /* version_counter=*/ std::move (version_counter),
110
+ /* allow_tensor_metadata_change=*/ allow_tensor_metadata_change);
111
+ impl->refresh_numel ();
112
+ return impl;
113
+ }
114
+
115
+ /* *
116
+ * Shallow-copies data from another TensorImpl into this TensorImpl.
117
+ *
118
+ * For why this function doesn't check this TensorImpl's
119
+ * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
120
+ */
121
+ void shallow_copy_from (const c10::intrusive_ptr<TensorImpl>& impl) override {
122
+ AT_ASSERT (has_compatible_shallow_copy_type (impl->key_set ()));
123
+ auto ttnn_impl =
124
+ static_cast <const TtnnTensorImpl*>(impl.get ());
125
+ copy_tensor_metadata (
126
+ /* src_impl=*/ ttnn_impl,
127
+ /* dest_impl=*/ this ,
128
+ /* version_counter=*/ version_counter (),
129
+ /* allow_tensor_metadata_change=*/ allow_tensor_metadata_change ());
130
+ refresh_numel ();
131
+ }
132
+
133
+ // protected:
134
+ // static void copy_tensor_metadata(
135
+ // const TtnnTensorImpl* src_impl,
136
+ // TtnnTensorImpl* dest_impl,
137
+ // const c10::VariableVersion& version_counter,
138
+ // bool allow_tensor_metadata_change) {
139
+ // TensorImpl::copy_tensor_metadata(
140
+ // src_impl,
141
+ // dest_impl,
142
+ // version_counter,
143
+ // allow_tensor_metadata_change);
144
+
145
+ // // TtnnTensorImpl-specific fields.
146
+ // dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
147
+ // dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
148
+ // }
149
+
150
+ // static void copy_tensor_metadata(
151
+ // const TtnnTensorImpl* src_impl,
152
+ // TtnnTensorImpl* dest_impl,
153
+ // c10::VariableVersion&& version_counter,
154
+ // bool allow_tensor_metadata_change) {
155
+ // TensorImpl::copy_tensor_metadata(
156
+ // src_impl,
157
+ // dest_impl,
158
+ // std::move(version_counter),
159
+ // allow_tensor_metadata_change);
160
+
161
+ // // TtnnTensorImpl-specific fields.
162
+ // dest_impl->ttnn_tensor_ = src_impl->ttnn_tensor_;
163
+ // dest_impl->ttnn_tensor_string_ = src_impl->ttnn_tensor_string_;
164
+ // }
165
+
166
+ private:
167
+ ttnn::Tensor ttnn_tensor_;
168
+ std::string ttnn_tensor_string_;
169
+ };
170
+
171
+ } // namespace at
0 commit comments