forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgen_autograd_functions.py
698 lines (637 loc) · 22.6 KB
/
gen_autograd_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
# Functions.h/cpp: subclasses of autograd::Node
# python_functions.h/cpp: Python bindings for the above classes
#
from .gen_inplace_or_view_type import VIEW_FUNCTIONS
from typing import List, Sequence, Tuple
from torchgen.api.autograd import (
Derivative,
DifferentiabilityInfo,
SavedAttribute,
uses_retain_variables,
uses_single_grad,
)
from torchgen.api.types import (
Binding,
BaseCType,
OptionalCType,
tensorT,
longT,
doubleT,
scalarT,
stringT,
boolT,
intArrayRefT,
tensorListT,
MutRefCType,
ListCType,
ArrayRefCType,
optionalIntArrayRefT,
)
from torchgen.code_template import CodeTemplate
from torchgen.utils import FileManager
from torchgen.model import Argument
# C++ declaration of one generated autograd node struct.
# Placeholders (all supplied by process_function): ${op}, ${superclass},
# ${thread_lock}, ${release_variables}, ${will_release_variables},
# ${saved_variables}, ${saved_list_sizes}.
FUNCTION_DECLARATION = CodeTemplate(
"""\
struct TORCH_API ${op} : public ${superclass} {
using ${superclass}::${superclass};
variable_list apply(variable_list&& grads) override;
std::string name() const override { return "${op}"; }
void release_variables() override {
${thread_lock}
${release_variables}
}
${will_release_variables}
${saved_variables}
${saved_list_sizes}
};
"""
)
# Extra struct members substituted for ${will_release_variables} when
# uses_retain_variables(info) is true: a `retain_variables` flag that the
# will_release_variables() override flips to false.
WILL_RELEASE_VARIABLES = CodeTemplate(
"""\
bool retain_variables = true;
void will_release_variables() override {
retain_variables = false;
}
"""
)
# C++ definition of ${op}::apply. ${asserts} and ${thread_lock} run first,
# ${compute_index_ranges} declares one `<arg>_ix` range per differentiable
# argument, and ${body} holds the unpack statements and derivative snippets.
FUNCTION_DEFINITION = CodeTemplate(
"""\
variable_list ${op}::apply(variable_list&& grads) {
${thread_lock}
${asserts}
IndexRangeGenerator gen;
${compute_index_ranges}
variable_list grad_inputs(gen.size());
${body}
return grad_inputs;
}
"""
)
# Declares `grad_input_mask` as a std::array<bool, ${n}>; ${masks} is one
# should_compute_output(...) entry per variable. Only emitted when the
# derivative formula mentions "grad_input_mask". The trailing backslash keeps
# the template from ending in a newline.
GRAD_INPUT_MASK = CodeTemplate(
"""\
auto grad_input_mask = std::array<bool, ${n}>{
${masks}
};\
"""
)
# Snippet for a derivative formula that produces the gradient of exactly one
# input: evaluate ${derivative} and copy it into that input's index range.
DERIVATIVE_SINGLE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
auto grad_result = ${derivative};
copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
# Per-input piece of a multi-output derivative: copies element ${i} of the
# `grad_result` tuple into the index range of input ${name}.
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
"""\
if (should_compute_output({ ${name}_ix })) {
copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
}
"""
)
# Snippet for a derivative formula that produces gradients for several inputs
# at once. ${copy_ranges} is a list of DERIVATIVE_MULTI_COPY_RANGE snippets and
# ${grad_input_mask} is either empty or a GRAD_INPUT_MASK snippet.
DERIVATIVE_MULTI = CodeTemplate(
"""\
if (should_compute_output({ ${idx_ranges} })) {
${grad_input_mask}
auto grad_result = ${derivative};
${copy_ranges}
}
"""
)
# Generates python bindings
#
# This generates the definitions for:
# (1) The PyTypeObject for each backward grad_fn subclassing Node
# (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
# We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
# Each PyGetSetDef has a function ptr to a getter, also defined here (3).
# (3) Getters for each of grad_fn's saved inputs and outputs.
#
# (1) Declares the static PyTypeObject for ${op} and registers it through
# addClass together with its ${op}_properties table.
PY_FUNCTION_DEFINITION = CodeTemplate(
"""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}", ${op}_properties);
"""
)
# (2) + (3) Emits every getter function for ${op}, followed by the
# nullptr-terminated PyGetSetDef table that refers to them.
# ${all_getsetdef_structs} is either empty or a comma-terminated list of
# entries, so the sentinel row can follow it directly.
PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
"""\
${all_getter_definitions}
static struct PyGetSetDef ${op}_properties[] = {
THP_FUNCTION_DEFAULT_PROPERTIES,
${all_getsetdef_structs}
{nullptr} /* sentinel */
};
"""
)
# One PyGetSetDef row exposing attribute `_saved_${name}` through
# THP${op}_${name}_getter (no setter, doc or closure).
PY_GETSETDEF_STRUCT = CodeTemplate(
"""\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)
# One PyGetSetDef row exposing attribute `_raw_saved_${name}` through
# THP${op}_${name}_raw_getter; appended only for SavedVariable-backed members.
PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
"""\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)
# Getter templates
#
# Each template wraps a ${body} snippet (one of the GETTER_BODY_* strings
# below) in a C++ getter function. Every variant makes a local named `prop`
# available to the body.

# Plain member: copies node member `${name}` into `prop`.
GETTER_DEFINITION = CodeTemplate(
"""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
auto prop = static_cast<${op}*>(self->cdata.get())->${name};
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# SavedVariable member: binds `prop` by const reference to `${name}_`
# (note the trailing underscore used for SavedVariable members).
GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
"""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# Same as GETTER_DEFINITION_SAVEDVAR, but the generated function is named
# `..._raw_getter` and backs the `_raw_saved_${name}` attribute.
GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
"""\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# std::vector<SavedVariable> member: before running ${body}, raises a Python
# RuntimeError (ERR_BACKWARD_TWICE) if the node's `${name}_released_` flag is
# set, i.e. after release_variables() has cleared the vector.
GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
"""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
const auto *node = static_cast<${op}*>(self->cdata.get());
const auto& prop = node->${name}_;
if (node->${name}_released_) {
PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
return nullptr;
}
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# Raw-getter counterpart of GETTER_DEFINITION_VEC_SAVEDVAR (same released
# check, function named `..._raw_getter`).
GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
"""\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
const auto *node = static_cast<${op}*>(self->cdata.get());
const auto& prop = node->${name}_;
if (node->${name}_released_) {
PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
return nullptr;
}
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# Optional member: returns None when the optional is empty, otherwise
# unwraps it into `prop` before running ${body}.
GETTER_DEFINITION_OPT = CodeTemplate(
"""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
if (!opt_prop.has_value()) {
Py_RETURN_NONE;
}
auto prop = opt_prop.value();
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# Optional-array member (saved as c10::OptionalArray<T>): like
# GETTER_DEFINITION_OPT, except the optional lives in the `.list` field.
GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
"""\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
HANDLE_TH_ERRORS
auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
if (!opt_prop.list.has_value()) {
Py_RETURN_NONE;
}
auto prop = opt_prop.list.value();
${body}
END_HANDLE_TH_ERRORS
}
"""
)
# Getter body
#
# C++ snippets substituted for ${body} in the GETTER_DEFINITION* templates.
# Each one converts the local `prop` into a new PyObject* and returns it.

# SavedVariable -> Python tensor: unpacks with self->cdata as the argument.
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""
# SavedVariable -> pybind11 object that references (does not copy) `prop`.
GETTER_BODY_RAW_SAVEDVAR = """\
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
return obj.release().ptr();
"""
# std::vector<SavedVariable> -> tuple of unpacked Python tensors.
# NOTE(review): the PyTuple_New result is used without a nullptr check.
GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""
# std::vector<SavedVariable> -> tuple of pybind11 objects referencing each
# element of `prop`.
GETTER_BODY_RAW_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
}
return tup;
"""
# Getter body for int64_t sequences (saved as std::vector<int64_t> or
# c10::OptionalArray<int64_t>): builds a Python tuple of ints.
# The elements are signed, so they are converted with PyLong_FromLongLong;
# converting through an unsigned type would turn negative entries into
# huge positive Python ints.
GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLongLong((long long) prop[i]));
}
return tup;
"""
# Getter body for double sequences (used for c10::OptionalArray<double>):
# builds a Python tuple of floats, one PyFloat per element of `prop`.
GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""
# Getter body for a saved int64_t (plain or unwrapped optional): returns a
# Python int. The value is signed, so PyLong_FromLongLong is used; an
# unsigned conversion would report e.g. -1 as 2**64 - 1.
GETTER_BODY_INT64_T = """\
return PyLong_FromLongLong((long long) prop);
"""
# Getter body for a saved double: returns a Python float.
GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""
# Getter body for a saved bool: returns Python True or False.
GETTER_BODY_BOOL = """\
if (prop) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
"""
# Getter body for a saved string: returns a Python str built from the
# string's data pointer and size.
GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""
# Getter body for a saved Scalar: dispatches on the scalar's dynamic kind and
# returns a Python complex, float, int or bool; any other kind raises a
# Python RuntimeError("Unknown scalar type").
# The integral branch uses PyLong_FromLongLong because the payload is an
# int64_t, while PyLong_FromLong takes a C `long`, which is only 32 bits wide
# on some platforms.
GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
auto cprop = prop.to<c10::complex<double>>();
return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
return PyLong_FromLongLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
if (prop.to<bool>()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
} else {
PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
return nullptr;
}
"""
# Lookup table for saved-attribute types that need no special storage:
# maps a C++ type to the (getter definition template, getter body) pair used
# to expose it to Python. process_function consults it in its fallback
# branch; types missing from this table get no Python getter.
MISC_GETTER_DEFS = {
OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}
# These functions have backwards which cannot be traced, and so must have
# their backward functions traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backwards, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
# Nodes for ops named here derive from "Node" instead of "TraceableFunction"
# (see the superclass selection in process_function).
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write Functions.h and Functions.cpp into `out`.

    Both files are rendered from the same-named templates under
    `template_path` and carry the auto-generated torch::autograd::Node
    subclasses: the struct declarations and the `apply` definitions.

    Args:
        out: install directory handed to the FileManager.
        differentiability_infos: candidate ops; entries whose
            `args_with_derivatives` is empty are skipped.
        template_path: template directory handed to the FileManager.
    """
    # A node is only worth generating when some argument has a derivative.
    differentiable = [
        info for info in differentiability_infos if info.args_with_derivatives
    ]
    declarations = [
        process_function(info, FUNCTION_DECLARATION) for info in differentiable
    ]
    definitions = [
        process_function(info, FUNCTION_DEFINITION) for info in differentiable
    ]

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = "Functions" + suffix
        # The same environment (declarations and definitions) is offered to
        # both templates; each template picks the keys it needs.
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@" + f"generated from {fm.template_dir}/" + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write python_functions.h and the sharded python_functions.cpp files.

    The header forward-declares and calls one
    `initialize_autogenerated_functions_<i>()` per shard. The .cpp shards hold
    the PyTypeObject registration and the property getters for every op that
    has at least one argument with a derivative.

    Args:
        out: install directory handed to the FileManager.
        differentiability_infos: candidate ops; entries whose
            `args_with_derivatives` is empty are skipped.
        template_path: template directory handed to the FileManager.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5

    def header_env():
        # Environment for python_functions.h: one declaration and one call
        # statement per shard initializer.
        shard_ids = range(num_shards)
        return {
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}();" for i in shard_ids
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}();" for i in shard_ids
            ],
        }

    def shard_key(info):
        # Items are assigned to shards by op name.
        return info.name

    def shard_env(info):
        # Per-op contribution to a shard of python_functions.cpp.
        return {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        }

    fm.write("python_functions.h", header_env)

    differentiable = [
        info for info in differentiability_infos if info.args_with_derivatives
    ]
    fm.write_sharded(
        "python_functions.cpp",
        differentiable,
        key_fn=shard_key,
        base_env={
            "generated_comment": f"@generated from {fm.template_dir}/python_functions.cpp",
        },
        env_callable=shard_env,
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    """Render `template` for the autograd node described by `info`.

    Collects, from `info`, the C++ fragments for saved members, release
    statements, index ranges, the body of `apply`, and the Python getters and
    PyGetSetDef rows. All of them are passed to `template.substitute`, so the
    same function serves the declaration, definition and Python-binding
    templates.

    Args:
        info: differentiability description of one op.
        template: the CodeTemplate to fill in.

    Returns:
        The rendered template text.
    """
    saved_variables: List[str] = []
    release_variables: List[str] = []
    saved_list_sizes: List[str] = []
    unpack: List[str] = []
    asserts: List[str] = []
    compute_index_ranges: List[str] = []
    getter_definitions: List[str] = []
    py_getsetdef_structs: List[str] = []

    # List-typed arguments occupy a run of gradient slots whose length is
    # stored on the node; every other argument occupies exactly one slot.
    list_arg_types = (
        "at::TensorList",
        "const c10::List<c10::optional<at::Tensor>> &",
    )
    for arg in info.args_with_derivatives:
        if arg.type in list_arg_types:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        """Emit the member, release, unpack and getter code for one saved attribute."""
        name = var.nctype.name
        ctype = var.nctype.type
        expose_getter = True
        expose_raw_getter = False

        def add_getter(definition: CodeTemplate, getter_body: str) -> None:
            getter_definitions.append(
                definition.substitute(op=info.op, name=name, body=getter_body)
            )

        def save_variable_list(unpack_fn: str) -> None:
            # Shared by both list-of-tensor flavours; they differ only in the
            # helper used to unpack the vector.
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # clear() alone is enough: dropping each SavedVariable also drops
            # the tensor and grad_fn it owns, so no per-element reset is needed.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = {unpack_fn}({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            add_getter(GETTER_DEFINITION_VEC_SAVEDVAR, GETTER_BODY_VEC_SAVEDVAR)
            add_getter(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR, GETTER_BODY_RAW_VEC_SAVEDVAR
            )

        def save_optional_array(element_type: str, getter_body: str) -> None:
            saved_variables.append(f"c10::OptionalArray<{element_type}> {name};")
            add_getter(GETTER_DEFINITION_OPT_ARRAYREF, getter_body)

        stored_as_saved_variable = ctype in (
            BaseCType(tensorT),
            OptionalCType(BaseCType(tensorT)),
            MutRefCType(OptionalCType(BaseCType(tensorT))),
        ) or (ctype == BaseCType(scalarT) and is_output)

        if stored_as_saved_variable:
            saved_variables.append(f"SavedVariable {name}_;")
            release_variables.append(f"{name}_.reset_data();")
            # Saved outputs are unpacked with the node itself as argument.
            ptr = "shared_from_this()" if is_output else ""
            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
            add_getter(GETTER_DEFINITION_SAVEDVAR, GETTER_BODY_SAVEDVAR)
            add_getter(GETTER_DEFINITION_RAW_SAVEDVAR, GETTER_BODY_RAW_SAVEDVAR)
            expose_raw_getter = True
        elif ctype == BaseCType(tensorListT):
            save_variable_list("unpack_list")
            expose_raw_getter = True
        elif ctype == ListCType(OptionalCType(BaseCType(tensorT))):
            save_variable_list("unpack_opt_list")
            expose_raw_getter = True
        elif ctype == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            add_getter(GETTER_DEFINITION, GETTER_BODY_ARRAYREF_LONG)
        elif ctype in (
            BaseCType(optionalIntArrayRefT),
            OptionalCType(BaseCType(intArrayRefT)),
        ):
            save_optional_array("int64_t", GETTER_BODY_ARRAYREF_LONG)
        elif ctype == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            save_optional_array("double", GETTER_BODY_ARRAYREF_DOUBLE)
        elif ctype == BaseCType(longT):
            saved_variables.append(f"{ctype.cpp_type()} {name} = 0;")
            add_getter(GETTER_DEFINITION, GETTER_BODY_INT64_T)
        elif ctype == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            add_getter(GETTER_DEFINITION, GETTER_BODY_STRING)
        elif ctype == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"c10::optional<std::string> {name};")
            add_getter(GETTER_DEFINITION_OPT, GETTER_BODY_STRING)
        else:
            saved_variables.append(f"{ctype.cpp_type()} {name};")
            if ctype in MISC_GETTER_DEFS:
                getter_def, getter_body = MISC_GETTER_DEFS[ctype]
                add_getter(getter_def, getter_body)
            else:
                # No Python binding for these types yet:
                # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                # std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                expose_getter = False

        if expose_getter:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )
        if expose_raw_getter:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )

    for saved_input in info.all_saved_inputs:
        save_var(saved_input, is_output=False)
    for saved_output in info.all_saved_outputs:
        save_var(saved_output, is_output=True)

    # A mutex guard is emitted in release_variables() and apply() whenever
    # the node has something to release.
    # See Note [Thread Safety on Autograd Node].
    thread_lock = (
        "std::lock_guard<std::mutex> lock(mutex_);" if release_variables else ""
    )

    will_release_variables = (
        WILL_RELEASE_VARIABLES.substitute() if uses_retain_variables(info) else ""
    )

    body: List[str] = []
    if uses_single_grad(info):
        body.append("const auto& grad = grads[0];")
    else:
        # Alias each used gradient under the name of the returned value it
        # belongs to.
        for grad_name in info.used_named_gradients:
            grad_index = info.available_named_gradients.index(grad_name)
            body.append(f"const auto& {grad_name} = grads[{grad_index}];")

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> Tuple[bool, str]:
        """Return (uses `any_grad_defined`, C++ snippet) for one derivative."""
        formula = derivative.formula
        var_names = derivative.var_names

        if len(var_names) != 1:
            # One formula yields a tuple holding the gradients of several inputs.
            grad_input_mask = ""
            if "grad_input_mask" in formula:
                masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            copy_ranges: List[str] = [
                DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i)
                for i, n in enumerate(var_names)
            ]
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=", ".join(f"{n}_ix" for n in var_names),
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

        target = var_names[0]
        checks_any_grad_defined = False
        if "not_implemented" not in formula:
            matching_args = [
                arg for arg in args_with_derivatives if arg.name == target
            ]
            if len(matching_args) == 1:
                # Undefined-grad support is only added when the input is a
                # Tensor (or optional Tensor).
                arg = matching_args[0]
                if isinstance(arg.argument, Argument) and str(
                    arg.argument.type
                ) in ("Tensor", "Tensor?"):
                    formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                    checks_any_grad_defined = True
        return (
            checks_any_grad_defined,
            DERIVATIVE_SINGLE.substitute(name=target, derivative=formula),
        )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined

    if need_any_grad_defined_var:
        # `body` now ends with one snippet per derivative, so this negative
        # index puts the single shared check right before the first of them.
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )

    superclass = "Node" if info.name in UNTRACEABLE_FUNCTIONS else "TraceableFunction"

    # Rows are comma-separated and comma-terminated so the sentinel row in
    # the properties table can follow directly.
    if py_getsetdef_structs:
        all_getsetdef_structs = ",\n".join(py_getsetdef_structs) + ","
    else:
        all_getsetdef_structs = ""
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
    )