|
20 | 20 | ) |
21 | 21 | from pytensor.link.numba.dispatch.string_codegen import create_tuple_string |
22 | 22 | from pytensor.link.numba.dispatch.vectorize_codegen import ( |
| 23 | + _jit_options, |
23 | 24 | _vectorized, |
| 25 | + _vectorized_with_gather, |
24 | 26 | encode_literals, |
25 | 27 | store_core_outputs, |
26 | 28 | ) |
|
35 | 37 | Mul, |
36 | 38 | Sub, |
37 | 39 | TrueDiv, |
38 | | - get_scalar_type, |
39 | 40 | maximum, |
40 | 41 | ) |
41 | 42 | from pytensor.scalar.basic import add as add_as |
42 | 43 | from pytensor.tensor.blas import BatchedDot |
43 | | -from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
| 44 | +from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, ElemwiseWithGather |
44 | 45 | from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum |
45 | 46 | from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad |
46 | 47 |
|
@@ -312,8 +313,7 @@ def axis_apply_fn(x): |
312 | 313 |
|
313 | 314 | @register_funcify_and_cache_key(Elemwise) |
314 | 315 | def numba_funcify_Elemwise(op, node, **kwargs): |
315 | | - scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs] |
316 | | - scalar_node = op.scalar_op.make_node(*scalar_inputs) |
| 316 | + scalar_node = op.make_scalar_node(*node.inputs) |
317 | 317 | scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( |
318 | 318 | op.scalar_op, |
319 | 319 | node=scalar_node, |
@@ -390,6 +390,106 @@ def impl(*inputs): |
390 | 390 | return elemwise, elemwise_key |
391 | 391 |
|
392 | 392 |
|
@register_funcify_and_cache_key(ElemwiseWithGather)
def numba_funcify_ElemwiseWithGather(op, node, **kwargs):
    """Generate a fused gather + elemwise Numba implementation.

    Inspects the op's inner fgraph to locate its single ``Elemwise`` node
    and to determine which of that node's inputs are fed through an
    ``AdvancedSubtensor1`` gather.  A single vectorized loop is then
    emitted in which gathered inputs are read with indirect indexing,
    instead of materializing the gathered arrays first.

    Outer inputs are laid out as::

        [elemwise_input_0_or_source, ..., elemwise_input_N_or_source,
         idx_0, idx_1, ...]

    with one trailing index array per *distinct* gather (not necessarily a
    single one).  ``indexed_inputs`` tells the intrinsic, for each index
    array, which elemwise input positions are read through it.

    Returns
    -------
    (callable, str | None)
        The funcified implementation and its cache key (``None`` when the
        inner scalar op is not cacheable).
    """
    # Local import to avoid a circular import at module load time.
    from pytensor.tensor.subtensor import AdvancedSubtensor1

    # The inner graph is expected to contain exactly one Elemwise apply
    # node; unpacking raises if that invariant is violated.
    [elemwise_node] = [
        n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)
    ]

    # Map each distinct index array (keyed by identity) to the elemwise
    # input positions it gathers.  indexed_inputs is encoded as
    # ((inp_a, inp_b), (inp_c,), ...) — one tuple per index array.
    positions_by_index = {}
    for pos, var in enumerate(elemwise_node.inputs):
        if var.owner is None or not isinstance(var.owner.op, AdvancedSubtensor1):
            continue
        index_array = var.owner.inputs[1]
        positions_by_index.setdefault(id(index_array), []).append(pos)

    indexed_inputs = tuple(map(tuple, positions_by_index.values()))
    gathered_positions = {pos for group in indexed_inputs for pos in group}

    # Funcify the scalar core of the inner Elemwise.
    scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs)
    scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
        elemwise_node.op.scalar_op, node=scalar_node, **kwargs
    )

    n_core_inputs = len(elemwise_node.inputs)
    n_core_outputs = len(elemwise_node.outputs)
    core_op_fn = store_core_outputs(
        scalar_op_fn, nin=n_core_inputs, nout=n_core_outputs
    )

    # Gathered inputs are iterated via their gather *source*, so use the
    # source's broadcast pattern; direct inputs keep their own.
    input_bc_patterns = tuple(
        var.owner.inputs[0].type.broadcastable
        if pos in gathered_positions
        else var.type.broadcastable
        for pos, var in enumerate(elemwise_node.inputs)
    )
    output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
    output_dtypes = tuple(out.type.dtype for out in node.outputs)
    inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items())
    # The core op is scalar, so every core output shape is ().
    core_output_shapes = tuple(() for _ in range(n_core_outputs))

    # Encode compile-time constants as literals usable inside the overload.
    input_bc_patterns_enc = encode_literals(input_bc_patterns)
    output_bc_patterns_enc = encode_literals(output_bc_patterns)
    output_dtypes_enc = encode_literals(output_dtypes)
    inplace_pattern_enc = encode_literals(inplace_pattern)
    indexed_inputs_enc = encode_literals(indexed_inputs)

    def elemwise_with_gather(*outer_inputs):
        raise NotImplementedError(
            "ElemwiseWithGather cannot be evaluated in Python (non-JIT) mode."
        )

    @overload(elemwise_with_gather, jit_options=_jit_options)
    def ov_elemwise_with_gather(*outer_inputs):
        def impl(*outer_inputs):
            return _vectorized_with_gather(
                core_op_fn,
                input_bc_patterns_enc,
                output_bc_patterns_enc,
                output_dtypes_enc,
                inplace_pattern_enc,
                True,  # allow_core_scalar
                (),  # constant_inputs
                outer_inputs,
                core_output_shapes,
                None,  # size
                indexed_inputs_enc,
            )

        return impl

    # A cache key is only meaningful when the inner scalar op has one.
    if scalar_cache_key is None:
        cache_key = None
    else:
        fingerprint = str(
            (
                type(op),
                "ElemwiseWithGather",
                inplace_pattern,
                input_bc_patterns,
                indexed_inputs,
                scalar_cache_key,
            )
        )
        cache_key = sha256(fingerprint.encode()).hexdigest()

    return elemwise_with_gather, cache_key
| 491 | + |
| 492 | + |
393 | 493 | @register_funcify_and_cache_key(CAReduce) |
394 | 494 | def numba_funcify_CAReduce(op, node, **kwargs): |
395 | 495 | axes = op.axis |
|
0 commit comments