Faster convolve numba#2175
Conversation
613b985 to
cb54a7e
Compare
| @register_canonicalize | ||
| @register_specialize | ||
| @node_rewriter([SpecifyShape]) | ||
| def local_specify_shape_alloc(fgraph, node): |
There was a problem hiding this comment.
this was messing some intermediate graphs I explored
There was a problem hiding this comment.
should this be shape_unsafe or something? I know it's not literally shape unsafe, but it's weird that a ViewOp ends up mutating the inputs.
There was a problem hiding this comment.
Hmm so this could mask an specifyshape(alloc(x, 3), 5), in which we only get the 3 at runtime. If it was static we would know at graph definition time.
In practice it's more like we have a alloc(x, shape(y)), from an elemwise broadcast and y doesn't have static shape, but through rewrites we found at some point that alloc must have length 5, so it's simpler not to rely on the shape of y
say alloc(x, shape(y)) + zeros(5) -> becomes specify_shape(alloc(x, shape(y)), 5) -> alloc(x, 5)
cb54a7e to
e447d97
Compare
|
I remember we talked about batched 1d vs 2d convolution, does that story play in here anywhere? |
Maybe for large kernels -> GPU but jax will likely do that already. For CPU and small kernels I don't think so. Like fft only starts winning at the 1000s. But Convolution is a rabbit hole, this is by no means the solution, just empirically better than what we had before. Also we don't yet have a native Convolve2D so it's not even an option |
|
what do you mean by native? Numba dispatch? |
Yes |
|
regardless of conv2d tricks, this would still apply tho the unbatched 1d case ofc |
|
Yeah i get that im asking a tangent question |
jessegrabowski
left a comment
There was a problem hiding this comment.
looks good broadly, some questions (in particular about overloading SpecifyShape to be actually enforced on the graph, not just as shape information)
| a_static_len = node.inputs[0].type.shape[-1] | ||
| b_static_len = node.inputs[1].type.shape[-1] |
There was a problem hiding this comment.
I guess this is rewrite safe because by the time we get to dispatch we're always done rewriting?
There was a problem hiding this comment.
What would be unsafe about it? The only thing is numba sees a constant in the inner loop. If we can't trust atatic shape after compile we would need to change many other places
There was a problem hiding this comment.
nothing would be unsafe about it. i'm just thinking out loud.
| return valid_convolve1d(x, y) | ||
| return valid_convolve1d(x, y, out=out) | ||
|
|
||
| convolve_1d.handles_out = True |
There was a problem hiding this comment.
Do we need to go back and add this tag everywhere that allows inplace?
There was a problem hiding this comment.
There's nothing else that supports it, it's a new argument/property
| @register_canonicalize | ||
| @register_specialize | ||
| @node_rewriter([SpecifyShape]) | ||
| def local_specify_shape_alloc(fgraph, node): |
There was a problem hiding this comment.
should this be shape_unsafe or something? I know it's not literally shape unsafe, but it's weird that a ViewOp ends up mutating the inputs.
LLVM really wins when it knows the static shape of the kernel, it can vectorize the inner loop.
fullpath unchanged; specialization only rewritesvalid_convolve1dfullpath unchangedshape=(8, None),use_static=Falseshape=(8, None),use_static=FalseOn the defaul/demo MMM model in pymc-marketing this translates to a ~1.2x speedup in the logp+dlogp function.
This PR also develops a mechanism to provide the out_argument for blockwise/non-scalar RV functions, which avoids the useless copy of the inner function buffer to the blockwise batched buffer. We should follow up and start using this in as many places as we can.
For now there's a hacky
.handles_outargument that specifies the behavior, we can think of a better API, but I wouldn't hang too much on it.The speed benefits are smaller (you can see 1-2 us in batched cases where the argument plays a role). It's more dramatic for blockwise of cheap inner graphs, and obviously reduces intermediate memory consumption.