diff --git a/quaternionic/properties.py b/quaternionic/properties.py index 70c5204..dc60153 100644 --- a/quaternionic/properties.py +++ b/quaternionic/properties.py @@ -260,4 +260,43 @@ def rotate(self, v, axis=-1): tensordot_axis, final_axis ) + def rotate_broadcast(self, v, axis=-1): + """Rotate vectors by quaternions in this array with NumPy broadcasting. + + The shape of this array and the vector array must be broadcastable + under standard NumPy rules (with the exception of the final axis of + this array, and the vector dimension axis of the vector array). + + For an Nx4 quaternion array and an Nx3 vector array, this function + will return a Nx3 rotated vector array. Note that this is in contrast + to the `rotate` method which performs the rotations as an outer + product and would return an NxNx3 rotated vector array in this case. + + Parameters + ---------- + v : float array + Three-vectors to be rotated. + axis : int, optional + Axis of the `v` array to use as the vector dimension. This axis of + `v` must have length 3. The default is the last axis. + + Returns + ------- + vprime : float array + The rotated vectors. The shape of this array is the broadcast + shape of self.shape and v.shape. The vector component will be + along the axis specified by the `axis` parameter. Note that this + means for an MxN quaternion array and a 3xN vector array, axis=0 + will give a 3xMxN output but axis=-2 will give a Mx3xN output. + + """ + v = np.asarray(v, dtype=self.dtype) + if v.ndim < 1 or 3 not in v.shape: + raise ValueError("Input `v` does not have at least one dimension of length 3") + if v.shape[axis] != 3: + raise ValueError("Input `v` axis {0} has length {1}, not 3.".format(axis, v.shape[axis])) + vq = self.from_vector_part(np.moveaxis(v, axis, -1)) + vprime = (self * vq * self.inverse).to_vector_part + return np.moveaxis(vprime, -1, axis) + return mixin diff --git a/tests/test_properties.py b/tests/test_properties.py index 8567ad4..53669d2 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -126,3 +126,74 @@ def test_rotate_vectors(Rs): [vprime.vector for vprime in quats * quaternionic.array(0, *vec) * ~quats], rtol=1e-15, atol=1e-15) assert quats.shape[:-1] + vecs.shape == vecsprime.shape, ("Out of shape!", quats.shape, vecs.shape, vecsprime.shape) + + +def test_rotate_broadcast(): + one, x, y, z = tuple(quaternionic.array(np.eye(4))) + zero = 0.0 * one + + with pytest.raises(ValueError): + one.rotate_broadcast(np.array(3.14)) + with pytest.raises(ValueError): + one.rotate_broadcast(np.random.normal(size=(17, 9, 4))) + with pytest.raises(ValueError): + one.rotate_broadcast(np.random.normal(size=(17, 9, 3)), axis=1) + with pytest.raises(ValueError, match="objects cannot be broadcast"): + quaternionic.array.random(10).rotate_broadcast([[1, 0, 0], [0, 1, 0]]) + + # Test (1)*(1) + vecs = np.random.normal(size=(3,)) + quats = z + vecsprime = quats.rotate_broadcast(vecs) + assert vecsprime.shape == vecs.shape + assert np.allclose(vecsprime, quats.rotate(vecs)) + + # Test (1)*(5) (both vector axis positions) + vecs = np.random.normal(size=(5, 3)) + quats = z + vecsprime = quats.rotate_broadcast(vecs) + assert vecsprime.shape == vecs.shape + assert np.allclose(vecsprime, quats.rotate(vecs)) + vecs = np.random.normal(size=(3, 5)) + quats = z + vecsprime = quats.rotate_broadcast(vecs, axis=0) + assert vecsprime.shape == vecs.shape + assert np.allclose(vecsprime, quats.rotate(vecs, axis=0)) + + # Test (5)*(1) + vecs = np.random.normal(size=(3,)) + quats = quaternionic.array.random(5) + vecsprime = quats.rotate_broadcast(vecs) + assert vecsprime.shape == (5, 3) + assert np.allclose(vecsprime, quats.rotate(vecs)) + + # Test (5)*(5) + vecs = np.random.normal(size=(5, 3)) + quats = quaternionic.array.random(5) + vecsprime = quats.rotate_broadcast(vecs) + assert vecsprime.shape == vecs.shape + for i in range(3): + assert np.allclose(vecsprime[:, i], np.diag(quats.rotate(vecs)[..., i])) + + # Test (8,5)*(5) (both vector axis positions). + vecs = np.random.normal(size=(5, 3)) + quats = quaternionic.array.random((8, 5)) + vecsprime = quats.rotate_broadcast(vecs) + assert vecsprime.shape == (8, 5, 3) + for i in range(8): + for j in range(3): + assert np.allclose(vecsprime[i, :, j], np.diag(quats[i].rotate(vecs)[..., j])) + + vecs = np.random.normal(size=(3, 5)) + vecsprime = quats.rotate_broadcast(vecs, axis=0) + assert vecsprime.shape == (3, 8, 5) + for i in range(8): + for j in range(3): + assert np.allclose(vecsprime[j, i, :], np.diag(quats[i].rotate(vecs, axis=0)[:, j])) + + # Same, but giving axis=-2 will yield a different output shape. + vecsprime = quats.rotate_broadcast(vecs, axis=-2) + assert vecsprime.shape == (8, 3, 5) + for i in range(8): + for j in range(3): + assert np.allclose(vecsprime[i, j, :], np.diag(quats[i].rotate(vecs, axis=0)[:, j]))