Skip to content

Commit 6d875d8

Browse files
committed
Numba: fuse AdvancedSubtensor+Elemwise
1 parent 80c81e4 commit 6d875d8

6 files changed

Lines changed: 805 additions & 152 deletions

File tree

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
)
2121
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
2222
from pytensor.link.numba.dispatch.vectorize_codegen import (
23+
_jit_options,
2324
_vectorized,
25+
_vectorized_with_gather,
2426
encode_literals,
2527
store_core_outputs,
2628
)
@@ -35,12 +37,11 @@
3537
Mul,
3638
Sub,
3739
TrueDiv,
38-
get_scalar_type,
3940
maximum,
4041
)
4142
from pytensor.scalar.basic import add as add_as
4243
from pytensor.tensor.blas import BatchedDot
43-
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
44+
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, ElemwiseWithGather
4445
from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
4546
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
4647

@@ -312,8 +313,7 @@ def axis_apply_fn(x):
312313

313314
@register_funcify_and_cache_key(Elemwise)
314315
def numba_funcify_Elemwise(op, node, **kwargs):
315-
scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
316-
scalar_node = op.scalar_op.make_node(*scalar_inputs)
316+
scalar_node = op.make_scalar_node(*node.inputs)
317317
scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
318318
op.scalar_op,
319319
node=scalar_node,
@@ -390,6 +390,106 @@ def impl(*inputs):
390390
return elemwise, elemwise_key
391391

392392

393+
@register_funcify_and_cache_key(ElemwiseWithGather)
def numba_funcify_ElemwiseWithGather(op, node, **kwargs):
    """Generate fused gather+elemwise Numba code.

    Analyzes the inner fgraph to find the Elemwise node and which inputs
    are gathered via AdvancedSubtensor1, then generates a single vectorized
    loop with indirect indexing for gathered inputs.

    Outer inputs are expected in order:
    [elemwise_input_0_or_source, ..., elemwise_input_N_or_source, idx_0, ...]
    The distinct gather index arrays follow the elemwise inputs, one per
    distinct index variable found in the inner graph. ``indexed_inputs``
    tells the intrinsic which elemwise input positions use indirect
    indexing via which index array.

    Returns
    -------
    tuple
        ``(python_fn, cache_key)`` where ``python_fn`` is the object-mode
        stub (raises in non-JIT mode; real implementation is supplied via
        ``numba.extending.overload``) and ``cache_key`` is a sha256 digest
        string, or ``None`` when the scalar op itself is uncacheable.
    """
    # Local import to avoid a circular dependency at module import time
    # (presumably -- TODO confirm against the module's import graph).
    from pytensor.tensor.subtensor import AdvancedSubtensor1

    # The inner fgraph is expected to contain exactly one Elemwise node;
    # the single-element unpacking raises if that invariant is violated.
    [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)]

    # Group gathered inputs by their index array.
    # indexed_inputs encodes as ((inp_a, inp_b), (inp_c,), ...) — one tuple
    # per distinct index array, listing the elemwise input positions it gathers.
    # Outer inputs are [elemwise_inputs (sources substituted)..., idx_0, idx_1, ...]
    index_groups = {}  # id(idx_var) -> list of elemwise input positions
    for i, inp in enumerate(elemwise_node.inputs):
        if inp.owner and isinstance(inp.owner.op, AdvancedSubtensor1):
            # AdvancedSubtensor1 inputs are (source, index); keep the index.
            idx_var = inp.owner.inputs[1]
            # Keyed by identity: distinct Variable objects are distinct
            # index arrays even if symbolically equal.
            index_groups.setdefault(id(idx_var), []).append(i)

    indexed_inputs = tuple(tuple(positions) for positions in index_groups.values())

    # Flat set of all gathered elemwise-input positions, for quick lookup.
    indexed_set = {p for positions in indexed_inputs for p in positions}

    # Compile the scalar core of the inner Elemwise.
    scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs)
    scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
        elemwise_node.op.scalar_op, node=scalar_node, **kwargs
    )

    nin_elemwise = len(elemwise_node.inputs)
    nout = len(elemwise_node.outputs)
    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin_elemwise, nout=nout)

    # Gathered inputs use SOURCE's broadcastable; direct inputs use their own
    input_bc_patterns = tuple(
        inp.owner.inputs[0].type.broadcastable
        if i in indexed_set
        else inp.type.broadcastable
        for i, inp in enumerate(elemwise_node.inputs)
    )
    output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
    output_dtypes = tuple(out.type.dtype for out in node.outputs)
    inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items())
    # Elemwise cores are scalar, so every core output has shape ().
    core_output_shapes = tuple(() for _ in range(nout))

    # Encode metadata as literals so Numba can specialize the generated loop.
    input_bc_patterns_enc = encode_literals(input_bc_patterns)
    output_bc_patterns_enc = encode_literals(output_bc_patterns)
    output_dtypes_enc = encode_literals(output_dtypes)
    inplace_pattern_enc = encode_literals(inplace_pattern)
    indexed_inputs_enc = encode_literals(indexed_inputs)

    def elemwise_with_gather(*outer_inputs):
        # Pure-Python stub: only the @overload implementation below is
        # executable; calling this outside of JIT compilation is an error.
        raise NotImplementedError(
            "ElemwiseWithGather cannot be evaluated in Python (non-JIT) mode."
        )

    @overload(elemwise_with_gather, jit_options=_jit_options)
    def ov_elemwise_with_gather(*outer_inputs):
        def impl(*outer_inputs):
            return _vectorized_with_gather(
                core_op_fn,
                input_bc_patterns_enc,
                output_bc_patterns_enc,
                output_dtypes_enc,
                inplace_pattern_enc,
                True,  # allow_core_scalar
                (),  # constant_inputs
                outer_inputs,
                core_output_shapes,
                None,  # size
                indexed_inputs_enc,
            )

        return impl

    # Propagate uncacheability: if the scalar op has no stable key,
    # neither does the fused op.
    if scalar_cache_key is None:
        key = None
    else:
        key = str(
            (
                type(op),
                "ElemwiseWithGather",
                inplace_pattern,
                input_bc_patterns,
                indexed_inputs,
                scalar_cache_key,
            )
        )
        key = sha256(key.encode()).hexdigest()

    return elemwise_with_gather, key
491+
492+
393493
@register_funcify_and_cache_key(CAReduce)
394494
def numba_funcify_CAReduce(op, node, **kwargs):
395495
axes = op.axis

0 commit comments

Comments
 (0)