forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcodegen.cpp
686 lines (629 loc) · 23.7 KB
/
codegen.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
#include <torch/csrc/jit/codegen/fuser/codegen.h>
#include <ATen/ATen.h>
#include <ATen/code_template.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/codegen/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <limits>
#include <locale>
#include <sstream>
#include <tuple>
#include <vector>
namespace torch::jit::fuser {
// Template for computing the offset into the tensor to access a value
//
// Instantiated once per dimension (innermost first) by emitIndexingFor.
// Placeholders:
//   ${tensor}       - name of the TensorInfo formal, e.g. t0
//   ${d}            - the dimension being peeled off
//   ${mod_sizes}    - "% <tensor>.sizes[d]", or empty for dimension 0
//   ${times_stride} - "* <tensor>.strides[d]", or empty for the innermost
//                     dimension when it is contiguous
// The printf line is emitted as a comment in the generated kernel and only
// serves as a debugging aid.
static auto dim_calc = at::jit::CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
// Spells an IR value as a kernel-local identifier: "n" followed by the
// value's unique id.
static std::string valueName(const Value* n) {
  std::string identifier{"n"};
  identifier += std::to_string(n->unique());
  return identifier;
}
// Integer constants are emitted as plain decimal literals.
static std::string scalarValue(const int64_t v) {
  std::string literal = std::to_string(v);
  return literal;
}
// Boolean constants are emitted as the integer literals 1 / 0.
static std::string scalarValue(const bool v) {
  return v ? "1" : "0";
}
// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
//
// Finite values are printed with max_digits10 significant digits so that the
// emitted literal parses back to exactly `v`. Sixteen digits is not enough
// for every double; e.g. 0.1 + 0.2 would be emitted as "0.3". The stream is
// imbued with the classic "C" locale so that a process-wide locale using ','
// as the decimal separator cannot produce an invalid C/CUDA literal.
static std::string scalarValue(const double v) {
  if (std::isnan(v)) {
    return "NAN";
  }
  if (std::isinf(v)) {
    return v < 0 ? "NEG_INFINITY" : "POS_INFINITY";
  }
  std::ostringstream out;
  out.imbue(std::locale::classic());
  out << std::setprecision(std::numeric_limits<double>::max_digits10) << v;
  return out.str();
}
// Note: Half is special-cased to avoid returning at::Half
//
// Returns the C/CUDA spelling used to declare tensor elements of `type` in
// generated code (TensorInfo element types and the vec4 staging buffers).
// BFloat16 uses the CUDA-side spelling from the CUDA resource strings. Every
// other type is looked up in the AT_FORALL_* macro, which stringizes each
// ctype. Throws std::runtime_error for types the macro does not cover.
static const char* scalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "half";
  }
  if (type == at::ScalarType::BFloat16) {
    return cuda::bfloat16_type_string;
  }
  switch (type) {
#define DEFINE_CASE(ctype, name) \
  case at::ScalarType::name:     \
    return #ctype;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}
// Returns the type used for arithmetic on elements of `type`. Half and
// BFloat16 values are converted to float when loaded and back when stored,
// so computation on them happens in float. Every other type computes in its
// own storage type.
static const char* calcScalarTypeName(const at::ScalarType type) {
  const bool reduced_precision_float =
      type == at::ScalarType::Half || type == at::ScalarType::BFloat16;
  return reduced_precision_float ? "float" : scalarTypeName(type);
}
// Returns the C++ type used to declare a kernel local holding a value of IR
// type `t`:
// - Int/Float/Bool scalars map to int64_t/double/bool.
// - Tensors map to the computation type of their element type.
// Throws std::runtime_error when a tensor carries no scalar type.
static std::string variableType(const c10::Type& t) {
  const TypeKind kind = t.kind();
  const char* scalar_spelling = nullptr;
  if (kind == TypeKind::IntType) {
    scalar_spelling = "int64_t";
  } else if (kind == TypeKind::FloatType) {
    scalar_spelling = "double";
  } else if (kind == TypeKind::BoolType) {
    scalar_spelling = "bool";
  }
  if (scalar_spelling != nullptr) {
    return scalar_spelling;
  }

  const auto element_type = t.expectRef<TensorType>().scalarType();
  if (!element_type) {
    // something went wrong with the type analysis during shape propagation
    throw std::runtime_error(
        "unknown scalar type during JIT fusion code generation");
  }
  return calcScalarTypeName(*element_type);
}
// Returns the expression for operand `vn` (of IR type `t`) as it should
// appear in an op producing `outtype`: either `vn` itself, or `vn` wrapped in
// a C-style cast to outtype's computation type.
// Throws std::runtime_error when a tensor operand carries no scalar type.
static std::string typeCastedValueName(
    const c10::Type& t,
    const at::ScalarType outtype,
    const std::string& vn) {
  const auto casted = [&]() {
    return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
  };
  const TypeKind kind = t.kind();

  if (kind == TypeKind::IntType || kind == TypeKind::BoolType) {
    // Integral scalars only need a cast when the result is not integral.
    return isIntegralType(outtype, /*includeBool=*/false) ? vn : casted();
  }
  if (kind == TypeKind::FloatType) {
    // We don't guard this on anything because in our type system for scalars,
    // there is not a distinction between `float` and `double`, however there
    // *is* a distinction in tensor scalar types. We conservatively insert a
    // cast here, which may end up being a no-op if the tensor's scalar type
    // is `double`.
    return casted();
  }
  if (kind == TypeKind::NoneType) {
    // Support None value for optional arguments like memory format
    return vn;
  }
  if (const auto element_type = t.expectRef<TensorType>().scalarType()) {
    return *element_type == outtype ? vn : casted();
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
// Writes RHS of special handling "simple mappable" ops
//
// Only aten::clamp is handled here. Its `min` (input 1) and/or `max`
// (input 2) may be constant None, in which case that bound is dropped.
// Throws std::runtime_error for any other op, or when both bounds are None.
static std::string encodeSpecialRHS(const Node* n, at::jit::TemplateEnv& env) {
  if (n->kind() != aten::clamp) {
    throw std::runtime_error("Cannot encode RHS of the node, op not supported");
  }

  // special case for clamp fusion on missing min/max inputs
  // Note: It may seem unusual to have the bounds as the first case below,
  // this is so that if min or max is NaN, they are "ignored"
  // and when the input is NaN, the output is, too
  const auto min = n->input(1);
  const auto max = n->input(2);
  const bool has_min = !min->node()->mustBeNone();
  const bool has_max = !max->node()->mustBeNone();

  // Must be checked first. Previously a both-None clamp fell into the
  // "max only" branch and referenced the None constant, which codegen never
  // declares, so the failure only surfaced when compiling the kernel.
  if (!has_min && !has_max) {
    throw std::runtime_error(
        "At least one of 'min' or 'max' must not be None");
  }

  env.s("0", valueName(n->input(0)));
  if (has_min && has_max) {
    env.s("1", valueName(min));
    env.s("2", valueName(max));
    return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
  }
  if (has_max) {
    env.s("1", valueName(max));
    return format("(${0} > ${1} ? ${1} : ${0})", env);
  }
  env.s("1", valueName(min));
  return format("(${0} < ${1} ? ${1} : ${0})", env);
}
// Code templates used to dispatch a specific aten:: operator, selected by
// the type of the value the result is assigned to. `for_float` is used for
// float outputs and `for_double` otherwise. For example, aten::log assigned
// to a float value emits logf(), while assigned to a double it emits log().
struct RHSTemplate {
  // One template shared by both precisions. Intentionally implicit so table
  // entries can be written as {op, "expr"}.
  RHSTemplate(const char* shared) : RHSTemplate(shared, shared) {}
  RHSTemplate(const char* float_variant, const char* double_variant)
      : for_float(float_variant), for_double(double_variant) {}

  const char* for_float;
  const char* for_double;
};
// Writes "simple mappable" ops
//
// Returns the C/CUDA expression computing node `n`'s output from its inputs.
// In a template:
// - ${i} is input i cast to the output's computation type (see
//   typeCastedValueName).
// - ${i_nocast} is the raw input, used only by the comparison operators.
// The float template is used when the output tensor's scalar type is
// float; every other output type uses the double template. Ops missing from
// the table are forwarded to encodeSpecialRHS, which throws for unsupported
// ops.
static std::string encodeRHS(const Node* n) {
  static const std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
      // unary
      {aten::_cast_Float, "static_cast<float>(${0})"},
      {aten::abs, "fabs(${0})"},
      {aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
      {aten::threshold,
       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
      {aten::log, {"logf(${0})", "log(${0})"}},
      {aten::log10, {"log10f(${0})", "log10(${0})"}},
      {aten::log1p, {"log1pf(${0})", "log1p(${0})"}},
      {aten::log2, {"log2f(${0})", "log2(${0})"}},
      {aten::lgamma, {"lgammaf(${0})", "lgamma(${0})"}},
      {aten::exp, {"expf(${0})", "exp(${0})"}},
      {aten::expm1, {"expm1f(${0})", "expm1(${0})"}},
      {aten::erf, {"erff(${0})", "erf(${0})"}},
      {aten::erfc, {"erfcf(${0})", "erfc(${0})"}},
      {aten::cos, {"cosf(${0})", "cos(${0})"}},
      {aten::acos, {"acosf(${0})", "acos(${0})"}},
      {aten::cosh, {"coshf(${0})", "cosh(${0})"}},
      {aten::sin, {"sinf(${0})", "sin(${0})"}},
      {aten::asin, {"asinf(${0})", "asin(${0})"}},
      {aten::sinh, {"sinhf(${0})", "sinh(${0})"}},
      {aten::tan, {"tanf(${0})", "tan(${0})"}},
      {aten::atan, {"atanf(${0})", "atan(${0})"}},
      {aten::tanh, {"tanhf(${0})", "tanh(${0})"}},
      {aten::sqrt, {"sqrtf(${0})", "sqrt(${0})"}},
      {aten::rsqrt, {"rsqrtf(${0})", "rsqrt(${0})"}},
      {aten::ceil, {"ceilf(${0})", "ceil(${0})"}},
      {aten::floor, {"floorf(${0})", "floor(${0})"}},
      {aten::round, {"roundf(${0})", "round(${0})"}},
      {aten::trunc, {"truncf(${0})", "trunc(${0})"}},
      {aten::frac, {"${0} - truncf(${0})", "${0} - trunc(${0})"}},
      {aten::reciprocal, {"1.f/(${0})", "1./(${0})"}},
      {aten::neg, "-${0}"},
      // simple binary
      {aten::atan2, "atan2(${0}, ${1})"},
      {aten::min,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${0} : ${1}))"},
      {aten::max,
       "isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${1} : ${0}))"},
      // binary with other
      // TODO: some of these ops will not get generated because
      // we only work on float inputs/outputs, but they are here to record
      // that they are valid mappable ops once we handle more type
      // __and__/__or__ are bitwise in PyTorch (like __xor__). For bool
      // operands, & and | agree with && and ||, but for integer operands the
      // logical forms would give wrong results.
      {aten::__and__, "${0} & ${1}"},
      {aten::__lshift__, "${0} << ${1}"},
      {aten::__or__, "${0} | ${1}"},
      {aten::__rshift__, "${0} >> ${1}"},
      {aten::__xor__, "${0} ^ ${1}"},
      {aten::addcmul, "${0} + ${3} * ${1} * ${2}"},
      {aten::div, "${0} / ${1}"},
      {aten::eq, "${0_nocast} == ${1_nocast}"},
      // fmodf for float outputs only; double outputs previously also used
      // fmodf and were silently truncated to float precision.
      {aten::fmod, {"fmodf(${0}, ${1})", "fmod(${0}, ${1})"}},
      {aten::ge, "(${0_nocast} >= ${1_nocast})"},
      {aten::gt, "${0_nocast} > ${1_nocast}"},
      {aten::le, "(${0_nocast} <= ${1_nocast})"},
      {aten::lt, "${0_nocast} < ${1_nocast}"},
      {aten::lerp, "${0} + ${2} * (${1} - ${0})"},
      {aten::type_as, "(${0})"},
      {aten::mul, "${0} * ${1}"},
      {aten::ne, "${0_nocast} != ${1_nocast}"},
      {aten::remainder, "fmod((${1} + fmod(${0}, ${1})), ${1})"},
      {aten::pow, {"powf(${0}, ${1})", "pow(${0}, ${1})"}},
      // alpha
      {aten::add, "${0} + ${2}*${1}"},
      {aten::sub, "(${0} - ${2}*${1})"},
      {aten::rand_like, "uniform(rnd())"},
      // where
      {aten::where, "(${0} ? ${1} : ${2})"},
  };

  at::jit::TemplateEnv env;
  const auto templ_it = simple_map_ops.find(n->kind());
  if (templ_it == simple_map_ops.end()) {
    return encodeSpecialRHS(n, env);
  }

  const auto outtype =
      n->output()->type()->expectRef<TensorType>().scalarType();
  TORCH_INTERNAL_ASSERT(outtype);
  size_t i = 0;
  for (const auto in : n->inputs()) {
    // PyTorch converts (scalar) argument types to result before applying the
    // operator e.g. 1.4-torch.tensor(3) = -2
    env.s(
        std::to_string(i),
        typeCastedValueName(*in->type(), *outtype, valueName(in)));
    // Uncasted operands only used for comparison operators
    env.s(std::to_string(i) + "_nocast", valueName(in));
    i++;
  }

  const RHSTemplate& templ = templ_it->second;
  const char* str =
      (*outtype == at::kFloat) ? templ.for_float : templ.for_double;
  AT_ASSERT(str);
  return format(str, env);
}
// Writes to `out` the code that turns the flat `linearIndex` into the element
// offset `${tensor}_offset` for a tensor with `ndim` dimensions. Dimensions
// are peeled off a private copy of the index from the innermost outward. When
// `last_is_cont` is true, the innermost dimension's stride multiply is
// omitted.
static void emitIndexingFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const bool last_is_cont) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  out << format("IndexType ${tensor}_offset = 0;\n", env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);

  const int innermost = ndim - 1;
  for (int dim = innermost; dim >= 0; --dim) {
    env.d("d", dim);
    const bool has_outer_dims = dim > 0;

    // Dimension 0 consumes whatever index remains, so it needs no modulo.
    std::string mod_sizes;
    if (has_outer_dims) {
      mod_sizes = format("% ${tensor}.sizes[${d}]", env);
    }
    env.s("mod_sizes", mod_sizes);

    std::string times_stride;
    if (dim != innermost || !last_is_cont) {
      times_stride = format("* ${tensor}.strides[${d}]", env);
    }
    env.s("times_stride", times_stride);

    out << dim_calc.format(env);
    if (has_outer_dims) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
    }
  }
}
// Writes to `out` the declaration of a 4-element staging buffer for `tensor`,
// followed by checks that clear `flag_vec4` when the tensor cannot use the
// vectorized path.
// - Compile-time disqualifiers: a non-contiguous innermost dimension, or
//   elements wider than 4 bytes (disabled for performance).
// - Runtime checks: innermost size % 4, outer strides % 4, and base-pointer
//   alignment to 4 elements.
static void emitCheckFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const TensorDesc& desc) {
  at::jit::TemplateEnv env;
  env.s("tensor", tensor);
  env.s("scalar_type", scalarTypeName(desc.scalar_type));
  out << format("${scalar_type} ${tensor}_buf[4];\n", env);

  if (!desc.lastIsContiguous() || at::elementSize(desc.scalar_type) > 4) {
    out << "flag_vec4 = false;\n";
    return;
  }

  // Innermost dimension: its stride is already known to be contiguous, so
  // only its size has to be a multiple of 4.
  if (ndim > 0) {
    env.d("d", ndim - 1);
    out << format(
        "if(${tensor}.sizes[${d}] % 4 != 0) flag_vec4 = false;\n", env);
  }
  // Every outer dimension: its stride has to be a multiple of 4.
  for (int dim = ndim - 2; dim >= 0; --dim) {
    env.d("d", dim);
    out << format(
        "if(${tensor}.strides[${d}] % 4 != 0) flag_vec4 = false;\n", env);
  }

  out << format(
      "if(((uint64_t) ${tensor}.data) % (4 * sizeof(${scalar_type})) != 0) flag_vec4 = false;\n",
      env);
}
// Returns the vector type used to move four elements of `ele_size` bytes in
// one access (float = 4 bytes, float2 = 8, float4 = 16). Returns nullptr
// when no such type applies, in which case an element-wise loop is emitted.
static const char* vec4WordType(const size_t ele_size) {
  switch (ele_size) {
    case 1:
      return "float";
    case 2:
      return "float2";
    case 4:
      return "float4";
    default:
      return nullptr;
  }
}

// Builds the statement (without trailing ';') that copies four consecutive
// elements between `<tensor>.data + <tensor>_offset` and `<tensor>_buf`.
// `is_load` selects the direction: memory to buffer when true, buffer to
// memory when false.
static std::string vec4Transfer(
    const std::string& tensor,
    const size_t ele_size,
    const bool is_load) {
  const std::string buf = tensor + "_buf";
  if (const char* word = vec4WordType(ele_size)) {
    const std::string cast = std::string("*(reinterpret_cast<") + word + "*>(";
    const std::string buf_ref = cast + buf + "))";
    const std::string mem_ref =
        cast + tensor + ".data + " + tensor + "_offset))";
    return is_load ? buf_ref + " = " + mem_ref : mem_ref + " = " + buf_ref;
  }
  const std::string buf_elem = buf + "[i]";
  const std::string mem_elem = tensor + ".data[" + tensor + "_offset + i]";
  return "for(int i = 0; i<4; i++) " +
      (is_load ? buf_elem + " = " + mem_elem : mem_elem + " = " + buf_elem);
}

// TODO: handle cases where we need to generate > 2^32 element tensors
//
// Generates the source of fused kernel `name` for `graph`, using the CUDA
// templates when `use_cuda` is true and the CPU templates otherwise.
// - `inputs` pairs each graph input with its TensorDesc; an empty optional
//   marks a scalar input.
// - `outputs` pairs each written value with its TensorDesc.
// Both a scalar body and a 4-element vectorized body (with matching loads
// and stores) are emitted. Half/BFloat16 tensors and aten::rand_like are
// only supported for CUDA, which is enforced with AT_ASSERT.
std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const std::optional<TensorDesc>>>&
        inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda) {
  at::jit::TemplateEnv env;
  env.s("kernelName", name);
  env.s(
      "IndexType",
      "unsigned int"); // Note: not uint32_t to avoid including cstdint
  std::stringstream tensorChecks;
  std::stringstream body;
  std::stringstream body_vec4;
  std::stringstream load;
  std::stringstream store;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  std::vector<std::string> argument_loads;

  // Lambda for writing tensor arguments
  auto emitFormal = [&](const Value* /*n*/, const TensorDesc& desc) {
    // + 1 because the first argument is the linearIndex
    env.d("formal_index", formals.size() + 1);
    // can't be unique() because Param may be an output
    const std::string tensor = "t" + std::to_string(formals.size());
    const auto nDim = desc.nDim();
    emitCheckFor(tensorChecks, tensor, nDim, desc);
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor", tensor);
    env.d("nDim", nDim);
    env.s("scalar_type", scalarTypeName(desc.scalar_type));
    formals.push_back(
        format("const TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
    argument_loads.push_back(format(
        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
        env));
  };
  // Lambda for writing scalar arguments
  auto emitScalarFormal = [&](const Value* n) {
    // + 1 because the first argument is the linearIndex
    env.d("formal_index", formals.size() + 1);
    // can't be unique() because Param may be an output
    const std::string scalar = "s" + std::to_string(formals.size());
    env.s("scalar", scalar);
    env.s("scalar_type", variableType(*n->type()));
    formals.push_back(format("${scalar_type} ${scalar}", env));
    argument_loads.push_back(
        format("*static_cast<${scalar_type}*>(args[${formal_index}])", env));
  };

  // Writes input parameters
  for (const auto& input : inputs) {
    if (input.second.has_value()) {
      emitFormal(input.first, *input.second);
    } else {
      emitScalarFormal(input.first);
    }
  }
  // Writes output parameters
  for (const auto& output : outputs) {
    emitFormal(output.first, output.second);
  }

  // Acquires input values
  bool has_half_tensor = false;
  bool has_bfloat_tensor = false;
  size_t formal_count = 0;
  for (const auto& input : inputs) {
    env.s("node", valueName(input.first));
    env.d("formal", formal_count++);
    // Acquires and converts (if needed) inputs
    // Note: conversion from half is only supported for CUDA kernels.
    // The conversion immediately converts fp16 inputs to float.
    // Access for other types is common to CUDA and CPU kernels.
    if (input.second.has_value()) {
      const auto scalar_type = input.second->scalar_type;
      if (scalar_type == at::ScalarType::Half) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format("__half2float(t${formal}.data[t${formal}_offset])", env));
        env.s("access_vec4", format("__half2float(t${formal}_buf[i])", env));
        has_half_tensor = true;
      } else if (scalar_type == at::ScalarType::BFloat16) {
        AT_ASSERT(use_cuda);
        env.s(
            "access",
            format(
                "__bfloat162float(t${formal}.data[t${formal}_offset])", env));
        env.s(
            "access_vec4", format("__bfloat162float(t${formal}_buf[i])", env));
        has_bfloat_tensor = true;
      } else if (use_cuda) {
        // No __ldg overload for bool
        if (scalar_type == at::ScalarType::Bool) {
          env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        } else {
          env.s(
              "access",
              format("__ldg(&t${formal}.data[t${formal}_offset])", env));
        }
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      } else {
        env.s("access", format("t${formal}.data[t${formal}_offset]", env));
        env.s("access_vec4", format("t${formal}_buf[i]", env));
      }
      env.s("lhs_type", calcScalarTypeName(scalar_type));
      // load input in vectorized code path
      env.s(
          "load4",
          vec4Transfer(
              format("t${formal}", env),
              at::elementSize(scalar_type),
              /*is_load=*/true));
      load << format("${load4};\n", env);
    } else {
      env.s("access", format("s${formal}", env));
      env.s("access_vec4", format("s${formal}", env));
      env.s("lhs_type", variableType(*input.first->type()));
    }
    body << format("${lhs_type} ${node} = ${access};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${access_vec4};\n", env);
  }

  bool has_random = false;
  // Generates code for intermediate nodes
  // Note: Concat and Chunk are implicitly generated
  // Note: Random number generation is only supported for CUDA kernels.
  // Note: Constant None node is ignored and we will handle it in the
  // places where the constant None node is used
  // Note: No need to iterate over reference as n is a pointer
  for (const auto n : graph.nodes()) {
    static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
    // Note: FusedConcat nodes work by narrowing the output Tensors before the
    // kernel runs
    if (n->kind() == prim::FusedConcat)
      continue;
    if (n->kind() == prim::ConstantChunk)
      continue;
    if (n->mustBeNone())
      continue;
    if (n->kind() == aten::rand_like) {
      AT_ASSERT(use_cuda);
      has_random = true;
    }
    env.s("node", valueName(n->output()));
    // Always emit double for prim::Constant. This will be narrowed later based
    // on either:
    // - Tensor-Scalar operator type rules
    // - Math function rules
    if (n->kind() == prim::Constant) {
      const auto val = toIValue(n->output()).value();
      std::string rhs;
      if (val.isDouble()) {
        rhs = scalarValue(val.toDouble());
      } else if (val.isBool()) {
        rhs = scalarValue(val.toBool());
      } else {
        AT_ASSERT(val.isInt());
        rhs = scalarValue(val.toInt());
      }
      env.s("rhs", rhs);
    } else {
      env.s("rhs", encodeRHS(n));
    }
    env.s("lhs_type", variableType(*n->output()->type()));
    body << format("${lhs_type} ${node} = ${rhs};\n", env);
    body_vec4 << format("${lhs_type} ${node} = ${rhs};\n", env);
  }

  // Generates writes to output tensors
  for (const auto& output : outputs) {
    env.d("formal", formal_count++);
    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    env.s("access_vec4", format("t${formal}_buf[i]", env));
    env.s("node", valueName(output.first));
    // Acquires and converts (if needed) outputs
    // Note: conversion to half is only supported for CUDA kernels.
    const auto scalar_type = output.second.scalar_type;
    if (scalar_type == at::ScalarType::Half) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2half(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2half(${node});\n", env);
      has_half_tensor = true;
    } else if (scalar_type == at::ScalarType::BFloat16) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2bfloat16(${node});\n", env);
      body_vec4 << format("${access_vec4} = __float2bfloat16(${node});\n", env);
      has_bfloat_tensor = true;
    } else {
      body << format("${access} = ${node};\n", env);
      body_vec4 << format("${access_vec4} = ${node};\n", env);
    }
    // store output in vectorized code path
    env.s(
        "store4",
        vec4Transfer(
            format("t${formal}", env),
            at::elementSize(scalar_type),
            /*is_load=*/false));
    store << format("${store4};\n", env);
  }

  // Includes headers
  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
  if (has_half_tensor) {
    env.s("HalfHeader", cuda::half_support_literal);
  } else {
    env.s("HalfHeader", "");
  }
  if (has_bfloat_tensor) {
    env.s("BFloat16Header", cuda::bfloat16_support_literal);
  } else {
    env.s("BFloat16Header", "");
  }
  if (has_random) {
    env.s("RandHeader", cuda::rand_support_literal);
    env.s("RandParam", cuda::rand_param);
    env.s("RandInit", cuda::rand_init);
  } else {
    env.s("RandHeader", "");
    env.s("RandParam", "");
    env.s("RandInit", "");
  }

  // Instantiates the CUDA or CPU-specific templates
  env.s("tensorOffsets", tensorOffsets.str());
  env.s("tensorChecks", tensorChecks.str());
  env.s("kernelBody", body.str());
  env.s("kernelBody_vec4", body_vec4.str());
  env.s("kernelLoad", load.str());
  env.s("kernelStore", store.str());
  env.v("formals", formals);
  env.v("argument_loads", argument_loads);
  std::string code_string;
  if (use_cuda) {
    env.s("type_declarations", cuda::type_declarations_template.format(env));
    code_string = cuda::cuda_compilation_unit_template.format(env);
  } else {
    env.s("type_declarations", cpu::type_declarations_template.format(env));
    code_string = cpu::cpu_compilation_unit_template.format(env);
  }
  if (debugFuser()) {
    std::cerr << "fusion code:" << code_string << '\n';
  }
  return code_string;
}
} // namespace torch::jit::fuser