vectorize_codegen.py
from __future__ import annotations

import base64
import pickle
from collections.abc import Callable, Sequence
from hashlib import sha256
from textwrap import indent
from typing import Any, cast

import numba
import numpy as np
from llvmlite import ir
from numba import TypingError, types
from numba.core import cgutils
from numba.core.base import BaseContext
from numba.core.types.misc import NoneType
from numba.np import arrayobj

from pytensor.graph.op import HasInnerGraph
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.super_utils import compile_function_src2
from pytensor.scalar import ScalarOp


def encode_literals(literals: Sequence) -> str:
    return base64.encodebytes(pickle.dumps(literals)).decode()
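

# Illustrative only: the strings produced by `encode_literals` are decoded inside
# `_vectorized` below with the inverse transformation. A minimal sketch of that
# round trip; the helper `_decode_literals` is hypothetical and is not used by the
# codegen in this module.
def _decode_literals(encoded: str) -> Any:
    # Inverse of encode_literals: base64 text -> pickle bytes -> Python object.
    return pickle.loads(base64.decodebytes(encoded.encode()))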


def store_core_outputs(
    core_op_fn: Callable, core_op: ScalarOp, nin: int, nout: int
) -> Callable:
    """Create a Numba function that wraps a core function and stores its vectorized outputs.

    @njit
    def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
        to0, to1, ..., ton = core_op_fn(i0, i1, ..., in)
        o0[...] = to0
        o1[...] = to1
        ...
        on[...] = ton

    """
    inputs = [f"i{i}" for i in range(nin)]
    outputs = [f"o{i}" for i in range(nout)]
    inner_outputs = [f"t{output}" for output in outputs]

    inp_signature = ", ".join(inputs)
    out_signature = ", ".join(outputs)
    inner_out_signature = ", ".join(inner_outputs)
    store_outputs = "\n".join(
        f"{output}[...] = {inner_output}"
        for output, inner_output in zip(outputs, inner_outputs, strict=True)
    )
    func_src = f"""
def store_core_outputs({inp_signature}, {out_signature}):
    {inner_out_signature} = core_op_fn({inp_signature})
{indent(store_outputs, " " * 4)}
"""
    global_env = {"core_op_fn": core_op_fn}
    # func = compile_function_src(
    #     func_src, "store_core_outputs", {**globals(), **global_env},
    # )
    if isinstance(core_op, HasInnerGraph):
        key = sha256(core_op.c_code_template.encode()).hexdigest()
    else:
        key = str(core_op)
    func = compile_function_src2(key, func_src, "store_core_outputs", global_env)
    return cast(Callable, numba_basic.numba_njit(func))
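

# Illustrative only: for nin=2 and nout=1 the source assembled above is, up to
# whitespace, the snippet below; `core_op_fn` is resolved from `global_env` when
# the source is compiled.
#
#     def store_core_outputs(i0, i1, o0):
#         to0 = core_op_fn(i0, i1)
#         o0[...] = to0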


_jit_options = {
    "fastmath": {
        "arcp",  # Allow Reciprocal
        "contract",  # Allow floating-point contraction
        "afn",  # Approximate functions
        "reassoc",
        "nsz",  # TODO Do we want this one?
    },
    "no_cpython_wrapper": True,
    "no_cfunc_wrapper": True,
}


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    constant_inputs_types,
    input_types,
    output_core_shape_types,
    size_type,
):
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        constant_inputs_types,
        input_types,
        output_core_shape_types,
        size_type,
    ]

    if not isinstance(input_bc_patterns, types.Literal):
        raise TypingError("input_bc_patterns must be literal.")
    input_bc_patterns = input_bc_patterns.literal_value
    input_bc_patterns = pickle.loads(base64.decodebytes(input_bc_patterns.encode()))

    if not isinstance(output_bc_patterns, types.Literal):
        raise TypeError("output_bc_patterns must be literal.")
    output_bc_patterns = output_bc_patterns.literal_value
    output_bc_patterns = pickle.loads(base64.decodebytes(output_bc_patterns.encode()))

    if not isinstance(output_dtypes, types.Literal):
        raise TypeError("output_dtypes must be literal.")
    output_dtypes = output_dtypes.literal_value
    output_dtypes = pickle.loads(base64.decodebytes(output_dtypes.encode()))

    if not isinstance(inplace_pattern, types.Literal):
        raise TypeError("inplace_pattern must be literal.")
    inplace_pattern = inplace_pattern.literal_value
    inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))

    batch_ndim = len(input_bc_patterns[0])
    nin = len(constant_inputs_types) + len(input_types)
    nout = len(output_bc_patterns)

    if nin == 0:
        raise TypingError("Empty argument list to vectorized op.")

    if nout == 0:
        raise TypingError("Empty list of outputs for vectorized op.")

    if not all(isinstance(input, types.Array) for input in input_types):
        raise TypingError("Vectorized inputs must be arrays.")

    if not all(
        len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
    ):
        raise TypingError(
            "Vectorized broadcastable patterns must have the same length."
        )

    core_input_types = []
    for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
        core_ndim = input_type.ndim - len(bc_pattern)
        # TODO: Reconsider this
        if core_ndim == 0:
            core_input_type = input_type.dtype
        else:
            core_input_type = types.Array(
                dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
            )
        core_input_types.append(core_input_type)

    core_out_types = [
        types.Array(numba.from_dtype(np.dtype(dtype)), len(output_core_shape), "C")
        for dtype, output_core_shape in zip(
            output_dtypes, output_core_shape_types, strict=True
        )
    ]

    out_types = [
        types.Array(
            numba.from_dtype(np.dtype(dtype)), batch_ndim + len(output_core_shape), "C"
        )
        for dtype, output_core_shape in zip(
            output_dtypes, output_core_shape_types, strict=True
        )
    ]

    for output_idx, input_idx in inplace_pattern:
        output_type = input_types[input_idx]
        core_out_types[output_idx] = types.Array(
            dtype=output_type.dtype,
            ndim=output_type.ndim - batch_ndim,
            layout=input_type.layout,
        )
        out_types[output_idx] = output_type

    core_signature = typingctx.resolve_function_type(
        scalar_func,
        [
            *constant_inputs_types,
            *core_input_types,
            *core_out_types,
        ],
        {},
    )

    ret_type = types.Tuple(out_types)

    if len(output_dtypes) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    # So we can access the constant values in codegen...
    input_bc_patterns_val = input_bc_patterns
    output_bc_patterns_val = output_bc_patterns
    output_dtypes_val = output_dtypes
    inplace_pattern_val = inplace_pattern
    input_types = input_types
    size_is_none = isinstance(size_type, NoneType)

    def codegen(
        ctx,
        builder,
        sig,
        args,
    ):
        [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args

        constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
        inputs = cgutils.unpack_tuple(builder, inputs)
        output_core_shapes = [
            cgutils.unpack_tuple(builder, shape)
            for shape in cgutils.unpack_tuple(builder, output_core_shapes)
        ]
        size = None if size_is_none else cgutils.unpack_tuple(builder, size)

        inputs = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, inputs, strict=True)
        ]

        in_shapes = [cgutils.unpack_tuple(builder, obj.shape) for obj in inputs]

        iter_shape = compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns_val,
            size,
        )

        outputs, output_types = make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns_val,
            output_dtypes_val,
            inplace_pattern_val,
            inputs,
            input_types,
            output_core_shapes,
        )

        make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            core_signature,
            iter_shape,
            constant_inputs,
            inputs,
            outputs,
            input_bc_patterns_val,
            output_bc_patterns_val,
            input_types,
            output_types,
        )

        if len(outputs) == 1:
            if inplace_pattern:
                assert inplace_pattern[0][0] == 0
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        for inplace_idx in dict(inplace_pattern):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    return sig, codegen
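

# Illustrative only: a rough sketch of how this intrinsic is meant to be invoked
# from jitted code. The pattern/dtype/inplace arguments must be the literal strings
# produced by `encode_literals`; the remaining arguments are passed as tuples.
# Names such as `core_fn` and `apply` are placeholders, not part of this module.
#
#     input_bc = encode_literals(((False,),))   # one input with one non-broadcast batch dim
#     output_bc = encode_literals(((False,),))
#     dtypes = encode_literals(("float64",))
#     inplace = encode_literals(())
#
#     @numba.njit
#     def apply(x):
#         return _vectorized(
#             core_fn,     # compiled core function
#             input_bc, output_bc, dtypes, inplace,
#             (),          # constant (non-batched) inputs
#             (x,),        # batched inputs
#             ((),),       # output core shapes
#             None,        # size
#         )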


def compute_itershape(
    ctx: BaseContext,
    builder: ir.IRBuilder,
    in_shapes: list[list[ir.Instruction]],
    broadcast_pattern: tuple[tuple[bool, ...], ...],
    size: list[ir.Instruction] | None,
):
    one = ir.IntType(64)(1)
    batch_ndim = len(broadcast_pattern[0])
    shape = [None] * batch_ndim
    if size is not None:
        shape = size
        for i in range(batch_ndim):
            for j, (bc, in_shape) in enumerate(
                zip(broadcast_pattern, in_shapes, strict=True)
            ):
                length = in_shape[i]
                if bc[i]:
                    with builder.if_then(
                        builder.icmp_unsigned("!=", length, one), likely=False
                    ):
                        msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}"
                        ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
                else:
                    with builder.if_then(
                        builder.icmp_unsigned("!=", length, shape[i]), likely=False
                    ):
                        with builder.if_else(
                            builder.icmp_unsigned("==", length, one)
                        ) as (
                            then,
                            otherwise,
                        ):
                            with then:
                                msg = (
                                    f"Incompatible vectorized shapes for input {j} and axis {i}. "
                                    f"Input {j} has shape 1, but is not statically "
                                    "known to have shape 1, and thus not broadcastable."
                                )
                                ctx.call_conv.return_user_exc(
                                    builder, ValueError, (msg,)
                                )
                            with otherwise:
                                msg = f"Vectorized input {j} has an incompatible shape in axis {i}."
                                ctx.call_conv.return_user_exc(
                                    builder, ValueError, (msg,)
                                )
    else:
        # Size is implied by the broadcast pattern
        for i in range(batch_ndim):
            for j, (bc, in_shape) in enumerate(
                zip(broadcast_pattern, in_shapes, strict=True)
            ):
                length = in_shape[i]
                if bc[i]:
                    with builder.if_then(
                        builder.icmp_unsigned("!=", length, one), likely=False
                    ):
                        msg = f"Vectorized input {j} is expected to have shape 1 in axis {i}"
                        ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
                elif shape[i] is not None:
                    with builder.if_then(
                        builder.icmp_unsigned("!=", length, shape[i]), likely=False
                    ):
                        with builder.if_else(
                            builder.icmp_unsigned("==", length, one)
                        ) as (
                            then,
                            otherwise,
                        ):
                            with then:
                                msg = (
                                    f"Incompatible vectorized shapes for input {j} and axis {i}. "
                                    f"Input {j} has shape 1, but is not statically "
                                    "known to have shape 1, and thus not broadcastable."
                                )
                                ctx.call_conv.return_user_exc(
                                    builder, ValueError, (msg,)
                                )
                            with otherwise:
                                msg = f"Vectorized input {j} has an incompatible shape in axis {i}."
                                ctx.call_conv.return_user_exc(
                                    builder, ValueError, (msg,)
                                )
                else:
                    shape[i] = length

    for i in range(batch_ndim):
        if shape[i] is None:
            shape[i] = one
    return shape
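

# Illustrative only: `compute_itershape` emits LLVM IR, but the shape rule it
# enforces can be sketched in plain Python roughly as below (declared-broadcast
# dimensions must be exactly 1; other dimensions must agree across inputs and, if
# given, with `size`). The helper `_itershape_sketch` is hypothetical and is not
# used by the codegen above.
def _itershape_sketch(in_shapes, broadcast_pattern, size=None):
    batch_ndim = len(broadcast_pattern[0])
    # When an explicit size is given, every non-broadcast dim is checked against it.
    shape = list(size) if size is not None else [None] * batch_ndim
    for i in range(batch_ndim):
        for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
            length = in_shape[i]
            if bc[i]:
                if length != 1:
                    raise ValueError(f"Vectorized input {j} is expected to have shape 1 in axis {i}")
            elif shape[i] is not None:
                if length != shape[i]:
                    raise ValueError(f"Vectorized input {j} has an incompatible shape in axis {i}.")
            else:
                shape[i] = length
    # Dimensions where every input is broadcast fall back to length 1.
    return [s if s is not None else 1 for s in shape]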


def make_outputs(
    ctx: numba.core.base.BaseContext,
    builder: ir.IRBuilder,
    iter_shape: tuple[ir.Instruction, ...],
    out_bc: tuple[tuple[bool, ...], ...],
    dtypes: tuple[Any, ...],
    inplace: tuple[tuple[int, int], ...],
    inputs: tuple[Any, ...],
    input_types: tuple[Any, ...],
    output_core_shapes: tuple,
) -> tuple[list[ir.Value], list[types.Array]]:
    output_arrays = []
    output_arry_types = []
    one = ir.IntType(64)(1)
    inplace_dict = dict(inplace)
    for i, (core_shape, bc, dtype) in enumerate(
        zip(output_core_shapes, out_bc, dtypes, strict=True)
    ):
        if i in inplace_dict:
            output_arrays.append(inputs[inplace_dict[i]])
            output_arry_types.append(input_types[inplace_dict[i]])
            # We need to incref once we return the inplace objects
            continue
        dtype = numba.from_dtype(np.dtype(dtype))
        output_ndim = len(iter_shape) + len(core_shape)
        arrtype = types.Array(dtype, output_ndim, "C")
        output_arry_types.append(arrtype)
        # This is actually an internal numba function, I guess we could
        # call `numba.nd.unsafe.ndarray` instead?
        batch_shape = [
            length if not bc_dim else one
            for length, bc_dim in zip(iter_shape, bc, strict=True)
        ]
        shape = batch_shape + core_shape
        array = arrayobj._empty_nd_impl(ctx, builder, arrtype, shape)
        output_arrays.append(array)

    # If there is no inplace operation, we know that all output arrays
    # don't alias. Informing llvm can make it easier to vectorize.
    if not inplace:
        # The first argument is the output pointer
        arg = builder.function.args[0]
        arg.add_attribute("noalias")
    return output_arrays, output_arry_types
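

# Illustrative only: at the NumPy level, `make_outputs` behaves roughly like the
# sketch below; broadcast batch dimensions are allocated with length 1, and inplace
# outputs reuse the corresponding input buffer instead of allocating a new one.
# The helper `_make_outputs_sketch` is hypothetical and not used by the codegen above.
def _make_outputs_sketch(iter_shape, out_bc, dtypes, inplace, inputs, output_core_shapes):
    inplace_dict = dict(inplace)
    outputs = []
    for i, (core_shape, bc, dtype) in enumerate(zip(output_core_shapes, out_bc, dtypes)):
        if i in inplace_dict:
            # Inplace output: reuse the input array it maps to.
            outputs.append(inputs[inplace_dict[i]])
            continue
        batch_shape = [1 if bc_dim else length for length, bc_dim in zip(iter_shape, bc)]
        outputs.append(np.empty((*batch_shape, *core_shape), dtype=dtype))
    return outputs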


def make_loop_call(
    typingctx,
    context: numba.core.base.BaseContext,
    builder: ir.IRBuilder,
    scalar_func: Any,
    scalar_signature: types.FunctionType,
    iter_shape: tuple[ir.Instruction, ...],
    constant_inputs: tuple[ir.Instruction, ...],
    inputs: tuple[ir.Instruction, ...],
    outputs: tuple[ir.Instruction, ...],
    input_bc: tuple[tuple[bool, ...], ...],
    output_bc: tuple[tuple[bool, ...], ...],
    input_types: tuple[Any, ...],
    output_types: tuple[Any, ...],
):
    safe = (False, False)
    n_outputs = len(outputs)

    # TODO I think this is better than the noalias attribute
    # for the input, but self_ref isn't supported in a released
    # llvmlite version yet
    # mod = builder.module
    # domain = mod.add_metadata([], self_ref=True)
    # input_scope = mod.add_metadata([domain], self_ref=True)
    # output_scope = mod.add_metadata([domain], self_ref=True)
    # input_scope_set = mod.add_metadata([input_scope, output_scope])
    # output_scope_set = mod.add_metadata([input_scope, output_scope])

    zero = ir.Constant(ir.IntType(64), 0)

    # Setup loops and initialize accumulators for outputs
    # This part corresponds to opening the loops
    loop_stack = []
    loops = []
    output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs
    for dim, length in enumerate(iter_shape):
        # Find outputs that only have accumulations left
        for output in range(n_outputs):
            if output_accumulator[output][0] is not None:
                continue
            if all(output_bc[output][dim:]):
                value = outputs[output][0].type.pointee(0)
                accu = cgutils.alloca_once_value(builder, value)
                output_accumulator[output] = (accu, dim)

        loop = cgutils.for_range(builder, length)
        loop_stack.append(loop)
        loops.append(loop.__enter__())

    # Code in the inner most loop...
    idxs = [loopval.index for loopval in loops]

    # Load values from input arrays
    input_vals = []
    for input, input_type, bc in zip(inputs, input_types, input_bc, strict=True):
        core_ndim = input_type.ndim - len(bc)

        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
            zero
        ] * core_ndim
        ptr = cgutils.get_item_pointer2(
            context,
            builder,
            input.data,
            cgutils.unpack_tuple(builder, input.shape),
            cgutils.unpack_tuple(builder, input.strides),
            input_type.layout,
            idxs_bc,
            *safe,
        )
        if core_ndim == 0:
            # Retrieve scalar item at index
            val = builder.load(ptr)
            # val.set_metadata("alias.scope", input_scope_set)
            # val.set_metadata("noalias", output_scope_set)
        else:
            # Retrieve array item at index
            # This is a streamlined version of Numba's `GUArrayArg.load`
            # TODO check layout arg!
            core_arry_type = types.Array(
                dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
            )
            core_array = context.make_array(core_arry_type)(context, builder)
            core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
            core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
            itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
            context.populate_array(
                core_array,
                # TODO why do we need to bitcast?
                data=builder.bitcast(ptr, core_array.data.type),
                shape=cgutils.pack_array(builder, core_shape),
                strides=cgutils.pack_array(builder, core_strides),
                itemsize=context.get_constant(types.intp, itemsize),
                # TODO what is meminfo about?
                meminfo=None,
            )
            val = core_array._getvalue()
        input_vals.append(val)

    # Create output slices to pass to inner func
    output_slices = []
    for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
        core_ndim = output_type.ndim - len(bc)
        size_type = output.shape.type.element  # type: ignore
        output_shape = cgutils.unpack_tuple(builder, output.shape)  # type: ignore
        output_strides = cgutils.unpack_tuple(builder, output.strides)  # type: ignore

        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
            zero
        ] * core_ndim
        ptr = cgutils.get_item_pointer2(
            context,
            builder,
            output.data,  # type:ignore
            output_shape,
            output_strides,
            output_type.layout,
            idxs_bc,
            *safe,
        )

        # Retrieve array item at index
        # This is a streamlined version of Numba's `GUArrayArg.load`
        core_arry_type = types.Array(
            dtype=output_type.dtype, ndim=core_ndim, layout=output_type.layout
        )
        core_array = context.make_array(core_arry_type)(context, builder)
        core_shape = output_shape[-core_ndim:] if core_ndim > 0 else []
        core_strides = output_strides[-core_ndim:] if core_ndim > 0 else []
        itemsize = context.get_abi_sizeof(context.get_data_type(output_type.dtype))
        context.populate_array(
            core_array,
            # TODO why do we need to bitcast?
            data=builder.bitcast(ptr, core_array.data.type),
            shape=cgutils.pack_array(builder, core_shape, ty=size_type),
            strides=cgutils.pack_array(builder, core_strides, ty=size_type),
            itemsize=context.get_constant(types.intp, itemsize),
            # TODO what is meminfo about?
            meminfo=None,
        )
        val = core_array._getvalue()
        output_slices.append(val)

    inner_codegen = context.get_function(scalar_func, scalar_signature)

    if isinstance(scalar_signature.args[0], types.StarArgTuple | types.StarArgUniTuple):
        input_vals = [context.make_tuple(builder, scalar_signature.args[0], input_vals)]

    inner_codegen(builder, [*constant_inputs, *input_vals, *output_slices])

    # Close the loops
    for depth, loop in enumerate(loop_stack[::-1]):
        loop.__exit__(None, None, None)

    return
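

# Illustrative only: the IR emitted by `make_loop_call` corresponds roughly to the
# Python loop nest below, which indexes broadcast batch dimensions with 0 and hands
# the remaining core slices to the core function. For clarity the sketch uses a
# returning core function, whereas the real codegen passes writable output slices
# into `scalar_func` (see `store_core_outputs`); it also ignores accumulators and
# aliasing metadata. The helper `_loop_call_sketch` is hypothetical and not used by
# the codegen above.
def _loop_call_sketch(core_fn, iter_shape, inputs, outputs, input_bc, output_bc):
    for idx in np.ndindex(*iter_shape):
        # Broadcast dims are read at index 0; the leftover axes form the core slice.
        in_slices = [
            inp[tuple(0 if bc_dim else i for i, bc_dim in zip(idx, bc))]
            for inp, bc in zip(inputs, input_bc)
        ]
        results = core_fn(*in_slices)
        if not isinstance(results, tuple):
            results = (results,)
        for out, bc, res in zip(outputs, output_bc, results):
            out[tuple(0 if bc_dim else i for i, bc_dim in zip(idx, bc))] = res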