Skip to content

Faster convolve numba#2175

Open
ricardoV94 wants to merge 4 commits into
pymc-devs:mainfrom
ricardoV94:faster_convolve_numba
Open

Faster convolve numba#2175
ricardoV94 wants to merge 4 commits into
pymc-devs:mainfrom
ricardoV94:faster_convolve_numba

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented May 27, 2026

LLVM really wins when it knows the static shape of the kernel, it can vectorize the inner loop.

Benchmark Main (μs) Branch (μs) Speedup Specialization?
numba batch=False mode=valid 2.50 1.47 1.70x Yes — static shapes (183,), (6,)
numba batch=True mode=valid 9.23 2.15 4.29x Yes — static shapes (7,183), (7,6)
numba batch=False mode=full 2.49 2.47 1.01x No effect — full path unchanged; specialization only rewrites valid_convolve1d
numba batch=True mode=full 9.26 8.76 1.06x No effect — full path unchanged
numba grad full 82.28 81.00 ~1.0x No — inputs have shape=(8, None), use_static=False
numba grad valid 80.72 80.27 ~1.0x No — inputs have shape=(8, None), use_static=False

On the defaul/demo MMM model in pymc-marketing this translates to a ~1.2x speedup in the logp+dlogp function.

This PR also develops a mechanism to provide the out_argument for blockwise/non-scalar RV functions, which avoids the useless copy of the inner function buffer to the blockwise batched buffer. We should follow up and start using this in as many places as we can.

For now there's a hacky .handles_out argument that specifies the behavior, we can think of a better API, but I wouldn't hang too much on it.

The speed benefits are smaller (you can see 1-2 us in batched cases where the argument plays a role). It's more dramatic for blockwise of cheap inner graphs, and obviously reduces intermediate memory consumption.

@register_canonicalize
@register_specialize
@node_rewriter([SpecifyShape])
def local_specify_shape_alloc(fgraph, node):
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was messing some intermediate graphs I explored

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be shape_unsafe or something? I know it's not literally shape unsafe, but it's weird that a ViewOp ends up mutating the inputs.

Copy link
Copy Markdown
Member Author

@ricardoV94 ricardoV94 May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm so this could mask an specifyshape(alloc(x, 3), 5), in which we only get the 3 at runtime. If it was static we would know at graph definition time.

In practice it's more like we have a alloc(x, shape(y)), from an elemwise broadcast and y doesn't have static shape, but through rewrites we found at some point that alloc must have length 5, so it's simpler not to rely on the shape of y

say alloc(x, shape(y)) + zeros(5) -> becomes specify_shape(alloc(x, shape(y)), 5) -> alloc(x, 5)

@ricardoV94 ricardoV94 force-pushed the faster_convolve_numba branch from cb54a7e to e447d97 Compare May 27, 2026 12:47
@jessegrabowski
Copy link
Copy Markdown
Member

I remember we talked about batched 1d vs 2d convolution, does that story play in here anywhere?

@ricardoV94
Copy link
Copy Markdown
Member Author

ricardoV94 commented May 27, 2026

I remember we talked about batched 1d vs 2d convolution, does that story play in here anywhere?

Maybe for large kernels -> GPU but jax will likely do that already. For CPU and small kernels I don't think so. Like fft only starts winning at the 1000s. But Convolution is a rabbit hole, this is by no means the solution, just empirically better than what we had before.

Also we don't yet have a native Convolve2D so it's not even an option

@jessegrabowski
Copy link
Copy Markdown
Member

what do you mean by native? Numba dispatch?

@ricardoV94
Copy link
Copy Markdown
Member Author

what do you mean by native? Numba dispatch?

Yes

@ricardoV94
Copy link
Copy Markdown
Member Author

regardless of conv2d tricks, this would still apply tho the unbatched 1d case ofc

@jessegrabowski
Copy link
Copy Markdown
Member

Yeah i get that im asking a tangent question

Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good broadly, some questions (in particular about overloading SpecifyShape to be actually enforced on the graph, not just as shape information)

Comment on lines +18 to +19
a_static_len = node.inputs[0].type.shape[-1]
b_static_len = node.inputs[1].type.shape[-1]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is rewrite safe because by the time we get to dispatch we're always done rewriting?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be unsafe about it? The only thing is numba sees a constant in the inner loop. If we can't trust atatic shape after compile we would need to change many other places

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nothing would be unsafe about it. i'm just thinking out loud.

return valid_convolve1d(x, y)
return valid_convolve1d(x, y, out=out)

convolve_1d.handles_out = True
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to go back and add this tag everywhere that allows inplace?

Copy link
Copy Markdown
Member Author

@ricardoV94 ricardoV94 May 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's nothing else that supports it, it's a new argument/property

@register_canonicalize
@register_specialize
@node_rewriter([SpecifyShape])
def local_specify_shape_alloc(fgraph, node):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be shape_unsafe or something? I know it's not literally shape unsafe, but it's weird that a ViewOp ends up mutating the inputs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants