Enable batch support for windowed_mean|variance #1600
nicolaspi wants to merge 14 commits into tensorflow:main
Conversation
@axch I made changes in code you authored; could you kindly have a look at this PR?
@nicolaspi thanks for the contribution! I am no longer an active maintainer of TFP, so I'm not really in a position to review your PR in detail (@jburnim please suggest someone?). On a quick look, though, I see a couple of potential code-style issues:
Thanks for your feedback!
We specifically need the
There are two motivations for this case. First, for backward compatibility, it is equivalent to the legacy non-batched usage. Second, it is the only case I can think of where the broadcast is unambiguous when
In any case, I modified the unit tests to test against non-static shapes.
I made use of
I'll take a look at this.
import numpy as np
import tensorflow.compat.v2 as tf

if NUMPY_MODE:
We'll need to do something different about take_along_axis.

- (preferred) Somehow rewrite the logic using tf.gather/tf.gather_nd
- Expose tf.experimental.numpy.take_along_axis in https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/internal/backend/numpy

As is, this is problematic since we really dislike using JAX_/NUMPY_MODE in library code.
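For the preferred option, `take_along_axis` can be expressed as a gather over explicit index grids. A minimal NumPy sketch of the idea (the helper name `take_along_axis_via_gather` is hypothetical, not the PR's code; a TF version would stack the grids and call `tf.gather_nd`):

```python
import numpy as np

def take_along_axis_via_gather(x, indices, axis):
    # Build positional index grids for every axis of `indices`, then
    # substitute the user-provided indices on `axis`. Advanced indexing
    # with the grid tuple is the NumPy analogue of tf.gather_nd.
    grids = list(np.indices(indices.shape))
    grids[axis] = indices
    return x[tuple(grids)]

x = np.arange(12).reshape(3, 4)
idx = np.array([[0, 3], [1, 2], [2, 0]])
out = take_along_axis_via_gather(x, idx, axis=1)
assert np.array_equal(out, np.take_along_axis(x, idx, axis=1))
```

This assumes, as `np.take_along_axis` does, that `indices` has the same rank as `x`.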
Thanks for the review!

- I don't feel comfortable rewriting take_along_axis, as it would duplicate already existing logic; I feel it would produce an unnecessary maintenance burden.
- What about mapping tensorflow.experimental.numpy to the numpy and jax.numpy backends?
must be between 0 and N+1, and the shape of the output will be
`Bx + [M] + E`. Batch shape in the indices is not currently supported.
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
What is F? Why isn't it a scalar?
Please check my comment below.
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.

If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
I don't think this paragraph adds anything; it's just an implementation detail.
We specify the implicit rules we use for broadcasting. I updated the formulation.
Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.

The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
This contradicts the next paragraph, no?
In general, consider the non-batched version of this:
x shape: [N] + E
idx shape: [M]
output shape: [M] + E
The batching would introduce a batch dimension on the left of those shapes:
x shape: Bx + [N] + E
idx shape: Bi + [M]
output shape: broadcast(Bx, Bi) + [M] + E
Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.
> This contradicts the next paragraph, no?

Yes, I reformulated.

> The batching would introduce a batch dimension on the left of those shapes:
> Thus, the only broadcasting requirements are that Bx and Bi broadcast. I don't know where F came from.

Maybe the term 'batch' is not proper. This contribution adds the possibility of the more general case where idx shape is `Bi + [M] + F`. F could be seen as 'inner batch dimensions', but here 'batch' carries a different semantic than the standard machine learning one, where it is represented by outer dims.
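To make the shape rule `output = broadcast(Bx, Bi) + [M]` concrete, here is a hedged NumPy sketch (trailing event dims `E` omitted for brevity; `batched_windowed_mean` is a hypothetical reference helper, not TFP's implementation):

```python
import numpy as np

def batched_windowed_mean(x, low, high):
    # Shapes, per the discussion above:
    #   x:        Bx + [N]
    #   low/high: Bi + [M]   (integer window bounds, low <= high)
    #   output:   broadcast(Bx, Bi) + [M]
    bx, n = x.shape[:-1], x.shape[-1]
    bi, m = low.shape[:-1], low.shape[-1]
    bcast = np.broadcast_shapes(bx, bi)
    x = np.broadcast_to(x, bcast + (n,))
    low = np.broadcast_to(low, bcast + (m,))
    high = np.broadcast_to(high, bcast + (m,))
    # Each window mean is a difference of prefix sums divided by the length.
    csum = np.concatenate(
        [np.zeros(bcast + (1,)), np.cumsum(x, axis=-1)], axis=-1)
    total = (np.take_along_axis(csum, high, axis=-1)
             - np.take_along_axis(csum, low, axis=-1))
    return total / (high - low)

x = np.arange(8.0).reshape(2, 4)   # Bx = [2], N = 4
low = np.array([[0, 1]])           # Bi = [1], M = 2
high = np.array([[2, 4]])
out = batched_windowed_mean(x, low, high)
print(out.shape)  # (2, 2)
```

With `Bx = [2]` and `Bi = [1]`, the result shape is `(2, 2)`, i.e. `broadcast([2], [1]) + [2]`.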
@test_util.test_all_tf_execution_regimes
class WindowedStatsTest(test_util.TestCase):

  def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis):
These two functions are as complex as the thing we're testing. Is there any way we can write this via np.vectorize?
I refactored using np.vectorize, but I am not sure it is easier to read.
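For reference, an `np.vectorize`-based oracle for the window mean can be quite compact. A sketch (names are illustrative, not the PR's actual test helpers; it assumes the reference averages `x[low:high]` along the last axis):

```python
import numpy as np

# A gufunc-style signature lets np.vectorize map a scalar window-mean
# over arbitrary batch dimensions of `low`/`high`.
_window_mean = np.vectorize(
    lambda seq, lo, hi: seq[int(lo):int(hi)].mean(),
    signature='(n),(),()->()')

x = np.arange(6.0)
low = np.array([0, 2])
high = np.array([3, 6])
result = _window_mean(x, low, high)  # means of x[0:3] and x[2:6]
```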
Add test cases
Some `tensorflow` to `prefer_static` replacement
Parametrize tests
4002d8b to c90e961
Hi @SiegeLordEx, I have addressed your comments; can you have a look? Thanks
This PR makes the functions windowed_mean and windowed_variance accept indices with batch dimensions.

Example:

Now gives:

Was previously failing with: