jaxtransform3d.rotations.quaternion_from_compact_axis_angle

jaxtransform3d.rotations.quaternion_from_compact_axis_angle(axis_angle: Array | ndarray | bool_ | number | bool | int | float | complex) -> Array

Compute quaternion from compact axis-angle.

This operation is called the exponential map.

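Given a compact axis-angle representation (rotation vector) \(\hat{\boldsymbol{\omega}} \theta \in \mathbb{R}^3\), where \(\hat{\boldsymbol{\omega}}\) is the unit rotation axis and \(\theta\) is the rotation angle, we compute the unit quaternion \(\boldsymbol{q} \in S^3\) as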

\[\boldsymbol{q}(\hat{\boldsymbol{\omega}} \theta) = \operatorname{Exp}(\hat{\boldsymbol{\omega}} \theta) = \begin{pmatrix} \cos{\frac{\theta}{2}} \\ \hat{\boldsymbol{\omega}} \sin{\frac{\theta}{2}} \end{pmatrix}.\]
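Transcribed literally into JAX, the formula looks like the sketch below. It is a hypothetical illustration (the name `exp_map_naive` is not part of jaxtransform3d), and because it normalizes the rotation vector by \(\theta\) to obtain \(\hat{\boldsymbol{\omega}}\), it is only valid for nonzero angles.

```python
import jax.numpy as jnp


def exp_map_naive(axis_angle):
    """Hypothetical sketch: rotation vector (..., 3) -> quaternion (w, x, y, z).

    Valid only for nonzero angles; at theta = 0 the division yields NaN.
    """
    axis_angle = jnp.asarray(axis_angle)
    # The rotation angle theta is the length of the rotation vector.
    theta = jnp.linalg.norm(axis_angle, axis=-1, keepdims=True)
    # Unit rotation axis omega_hat (undefined when theta == 0).
    omega_hat = axis_angle / theta
    # Scalar part cos(theta / 2), vector part omega_hat * sin(theta / 2).
    w = jnp.cos(0.5 * theta)
    xyz = omega_hat * jnp.sin(0.5 * theta)
    return jnp.concatenate([w, xyz], axis=-1)


# 90 degree rotation about the z-axis.
q = exp_map_naive(jnp.array([0.0, 0.0, jnp.pi / 2]))  # approximately [0.7071, 0, 0, 0.7071]
```

Calling it on `jnp.zeros(3)` returns NaN in the vector part, which is the failure mode the small-angle treatment has to avoid.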

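For small angles, a Taylor series approximation is used instead, because evaluating \(\hat{\boldsymbol{\omega}} \sin{\frac{\theta}{2}}\) directly requires dividing the rotation vector by \(\theta\), which is numerically unstable as \(\theta \to 0\).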
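One way to implement the small-angle case is sketched below, assuming the Taylor series is applied to \(\cos{\frac{\theta}{2}}\) and to the factor \(\frac{\sin(\theta/2)}{\theta}\) in \(\hat{\boldsymbol{\omega}} \sin{\frac{\theta}{2}} = (\hat{\boldsymbol{\omega}} \theta) \frac{\sin(\theta/2)}{\theta}\). The function name `exp_map_taylor` and the threshold `eps` are hypothetical; jaxtransform3d's actual expansion order and cutoff may differ. Substituting a dummy value inside `jnp.where` is a common JAX idiom for keeping gradients finite at \(\theta = 0\).

```python
import jax
import jax.numpy as jnp


def exp_map_taylor(axis_angle, eps=1e-3):
    """Hypothetical sketch: rotation vector (..., 3) -> quaternion (w, x, y, z)."""
    axis_angle = jnp.asarray(axis_angle)
    # Squared angle; unlike the norm, it is smooth at the origin.
    theta_sq = jnp.sum(axis_angle**2, axis=-1, keepdims=True)
    small = theta_sq < eps**2
    # Dummy value in the small-angle case so that sqrt and the division never
    # see zero; this keeps values and gradients of the exact branch finite.
    safe_theta = jnp.sqrt(jnp.where(small, 1.0, theta_sq))
    exact_w = jnp.cos(0.5 * safe_theta)
    exact_factor = jnp.sin(0.5 * safe_theta) / safe_theta
    # Taylor series in theta^2:
    #   cos(theta / 2)         = 1   - theta^2 / 8  + theta^4 / 384  - ...
    #   sin(theta / 2) / theta = 1/2 - theta^2 / 48 + theta^4 / 3840 - ...
    taylor_w = 1.0 - theta_sq / 8.0 + theta_sq**2 / 384.0
    taylor_factor = 0.5 - theta_sq / 48.0 + theta_sq**2 / 3840.0
    w = jnp.where(small, taylor_w, exact_w)
    factor = jnp.where(small, taylor_factor, exact_factor)
    # Vector part: (omega_hat * theta) * sin(theta / 2) / theta.
    return jnp.concatenate([w, factor * axis_angle], axis=-1)


# Differentiating at the identity rotation stays finite.
jac = jax.jacobian(exp_map_taylor)(jnp.zeros(3))
# -> [[0, 0, 0], [0.5, 0, 0], [0, 0.5, 0], [0, 0, 0.5]]
```

Working with \(\theta^2\) rather than \(\theta = \lVert \hat{\boldsymbol{\omega}} \theta \rVert\) matters for differentiation: the gradient of the Euclidean norm is undefined at the origin, whereas the gradient of the sum of squares is zero there.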

Parameters:
axis_angle : array-like, shape (…, 3)

Axis of rotation and rotation angle in compact form (also known as a rotation vector): angle * (x, y, z) or \(\hat{\boldsymbol{\omega}} \theta\), where (x, y, z) is a unit vector.
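If the axis and angle are available separately, the compact form can be built by normalizing the axis and scaling it by the angle. The helpers below are a hypothetical sketch, not part of jaxtransform3d; the inverse helper only works for nonzero rotation vectors.

```python
import jax.numpy as jnp


def compact_from_axis_and_angle(axis, angle):
    """Hypothetical helper: rotation vector angle * unit_axis, shape (..., 3)."""
    axis = jnp.asarray(axis, dtype=jnp.float32)
    angle = jnp.asarray(angle, dtype=jnp.float32)
    # Normalize the axis so that the vector's length encodes only the angle.
    unit_axis = axis / jnp.linalg.norm(axis, axis=-1, keepdims=True)
    return angle[..., None] * unit_axis


def axis_and_angle_from_compact(axis_angle):
    """Hypothetical inverse helper; only valid for nonzero rotation vectors."""
    axis_angle = jnp.asarray(axis_angle)
    angle = jnp.linalg.norm(axis_angle, axis=-1)
    return axis_angle / angle[..., None], angle


# 90 degrees about the z-axis; the input axis does not need to be normalized.
rotvec = compact_from_axis_and_angle([0.0, 0.0, 2.0], jnp.pi / 2)  # [0, 0, pi / 2]
```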

Returns:
q : array, shape (…, 4)

Unit quaternion representing the rotation, in (w, x, y, z) order (scalar first).
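Some other libraries store quaternions scalar-last, as (x, y, z, w); for example, SciPy's `scipy.spatial.transform.Rotation.from_quat` expects that order by default. A minimal conversion sketch, with hypothetical helper names, only moves the scalar component along the last axis:

```python
import jax.numpy as jnp


def wxyz_to_xyzw(q):
    """Hypothetical helper: move the scalar part from the front to the back."""
    return jnp.roll(jnp.asarray(q), shift=-1, axis=-1)


def xyzw_to_wxyz(q):
    """Hypothetical helper: move the scalar part from the back to the front."""
    return jnp.roll(jnp.asarray(q), shift=1, axis=-1)


# Identity rotation in both conventions.
q_xyzw = wxyz_to_xyzw(jnp.array([1.0, 0.0, 0.0, 0.0]))  # [0., 0., 0., 1.]
```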

See also

compact_axis_angle_from_quaternion

Logarithmic map.

matrix_from_compact_axis_angle

Exponential map for rotation matrices.
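The exponential and logarithmic maps are mutual inverses for rotation angles in \([0, \pi)\). The sketch below illustrates that relationship with simplified, hypothetical `exp_map` and `log_map` functions that skip the small-angle handling; it does not call `compact_axis_angle_from_quaternion`, whose exact conventions may differ. The sign flip relies on the fact that \(\boldsymbol{q}\) and \(-\boldsymbol{q}\) represent the same rotation.

```python
import jax.numpy as jnp


def exp_map(axis_angle):
    """Hypothetical sketch: rotation vector -> (w, x, y, z); nonzero angles only."""
    theta = jnp.linalg.norm(axis_angle, axis=-1, keepdims=True)
    w = jnp.cos(0.5 * theta)
    xyz = axis_angle / theta * jnp.sin(0.5 * theta)
    return jnp.concatenate([w, xyz], axis=-1)


def log_map(q):
    """Hypothetical sketch: (w, x, y, z) -> rotation vector; nonzero angles only."""
    # q and -q encode the same rotation; choose w >= 0 so that theta <= pi.
    q = jnp.where(q[..., :1] < 0, -q, q)
    w, xyz = q[..., :1], q[..., 1:]
    sin_half = jnp.linalg.norm(xyz, axis=-1, keepdims=True)
    # With w >= 0 and sin_half >= 0, atan2 returns theta / 2 in [0, pi / 2].
    theta = 2.0 * jnp.arctan2(sin_half, w)
    return xyz / sin_half * theta


v = jnp.array([0.3, -0.2, 0.5])
v_roundtrip = log_map(exp_map(v))  # approximately equal to v
```

For angles above \(\pi\), the round trip returns the equivalent rotation vector with angle \(2\pi - \theta\) about the opposite axis, e.g. \(1.5\pi\) about \(z\) becomes \(0.5\pi\) about \(-z\).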

Examples

>>> import jax.numpy as jnp
>>> from jaxtransform3d.rotations import quaternion_from_compact_axis_angle
>>> quaternion_from_compact_axis_angle(jnp.zeros(3))
Array([1., 0., 0., 0.], dtype=...)
>>> import jax
>>> a = jax.random.normal(jax.random.PRNGKey(42), shape=(2, 3))
>>> a
Array([[-0.0283...,  0.4671...,  0.2957...],
       [ 0.1535..., -0.1240...,  0.2169...]], dtype=...)
>>> quaternion_from_compact_axis_angle(a)
Array([[ 0.9619..., -0.0139...,  0.2305...,  0.1459...],
       [ 0.9892...,  0.0764..., -0.0617...,  0.1080...]], ...)