# Source code for jaxtransform3d.rotations._jacobian

import jax
import jax.numpy as jnp

from ..utils import (
    cross_product_matrix,
    differentiable_norm,
    min_diff_norm,
    norm_vector,
)


def left_jacobian_SO3(axis_angle: jnp.ndarray) -> jnp.ndarray:
    r"""Compute the left Jacobian of SO(3) for compact axis-angle vectors.

    .. math::

        \frac{\partial Exp(\hat{\boldsymbol{\omega}}\theta)}
        {\partial\hat{\boldsymbol{\omega}}\theta}
        =
        \boldsymbol{J}(\hat{\boldsymbol{\omega}}\theta)
        =
        \frac{\sin{\theta}}{\theta} \boldsymbol{I}
        + \left(\frac{1 - \cos{\theta}}{\theta}\right)
        \left[\hat{\boldsymbol{\omega}}\right]
        + \left(1 - \frac{\sin{\theta}}{\theta} \right)
        \hat{\boldsymbol{\omega}} \hat{\boldsymbol{\omega}}^T

    The closed form is evaluated as
    :math:`\boldsymbol{I} + a [\hat{\boldsymbol{\omega}}]
    + b [\hat{\boldsymbol{\omega}}]^2` with
    :math:`a = (1 - \cos\theta) / \theta` and
    :math:`b = 1 - \sin\theta / \theta`. This equals the expression above
    because :math:`[\hat{\boldsymbol{\omega}}]^2 = \hat{\boldsymbol{\omega}}
    \hat{\boldsymbol{\omega}}^T - \boldsymbol{I}` for a unit axis. For
    :math:`\theta < 10^{-3}`, the result of :func:`left_jacobian_SO3_series`
    is returned instead. There, :math:`1 - \sin\theta / \theta` suffers from
    cancellation.

    Parameters
    ----------
    axis_angle : array, shape (..., 3)
        Compact axis-angle representation.

    Returns
    -------
    J : array, shape (..., 3, 3)
        Left Jacobian of SO(3).

    See Also
    --------
    left_jacobian_SO3_series :
        Left Jacobian of SO(3) at theta from Taylor series.
    left_jacobian_SO3_inv :
        Inverse left Jacobian of SO(3) at theta (angle of rotation).
    """
    angle = differentiable_norm(axis_angle, axis=-1)
    # Substitute 1.0 where the angle is (numerically) zero, so the divisions
    # by the angle below stay finite. NOTE(review): this presumes
    # min_diff_norm(...) is below the 1e-3 series threshold, so these
    # entries end up replaced by the series result. Confirm.
    angle_safe = jnp.where(angle > min_diff_norm(angle), angle, 1.0)

    skew = cross_product_matrix(norm_vector(axis_angle, norm=angle))
    identity = jnp.broadcast_to(jnp.eye(3), skew.shape)

    # Coefficients of [w] and [w]^2. Two trailing axes are added so they
    # broadcast against the (..., 3, 3) matrices.
    a = ((1.0 - jnp.cos(angle_safe)) / angle_safe)[..., jnp.newaxis, jnp.newaxis]
    b = (1.0 - jnp.sin(angle_safe) / angle_safe)[..., jnp.newaxis, jnp.newaxis]

    # Same evaluation order as before: (I + a*[w]) + (b*[w]) @ [w].
    closed_form = identity + a * skew + (b * skew) @ skew

    small_angle = angle[..., jnp.newaxis, jnp.newaxis] < 1e-3
    return jnp.where(small_angle, left_jacobian_SO3_series(axis_angle), closed_form)
def left_jacobian_SO3_series(axis_angle: jnp.ndarray) -> jnp.ndarray:
    r"""Left Jacobian of SO(3) at theta from Taylor series with 10 terms.

    Evaluates the truncated power series

    .. math::

        \boldsymbol{J}(\boldsymbol{\phi})
        \approx \sum_{n=0}^{10} \frac{1}{(n + 1)!} [\boldsymbol{\phi}]^n

    Here :math:`\boldsymbol{\phi} = \hat{\boldsymbol{\omega}}\theta` is the
    compact axis-angle vector and :math:`[\cdot]` denotes its cross-product
    matrix. The identity term is followed by 10 higher-order terms.

    Parameters
    ----------
    axis_angle : array-like, shape (..., 3)
        Compact axis-angle representation.

    Returns
    -------
    J : array, shape (..., 3, 3)
        Left Jacobian of SO(3).

    See Also
    --------
    left_jacobian_SO3 : Left Jacobian of SO(3) at theta (angle of rotation).
    """
    # Convert so array-like input (e.g. nested lists) works as documented.
    # ``.shape`` is read on the next line.
    axis_angle = jnp.asarray(axis_angle)
    eye = jnp.broadcast_to(jnp.eye(3), axis_angle.shape + (3,))
    px = cross_product_matrix(axis_angle)
    # Invariant: after iteration n, pxn == [phi]^(n + 1) / (n + 2)!.
    # The factorial is built incrementally instead of being computed explicitly.
    pxn = eye
    J = eye
    for n in range(10):
        pxn = pxn @ px / (n + 2)
        J = J + pxn
    return J
def left_jacobian_SO3_inv(axis_angle: jnp.ndarray) -> jnp.ndarray:
    r"""Inverse left Jacobian of SO(3) at theta (angle of rotation).

    .. math::

        \boldsymbol{J}^{-1}(\theta)
        =
        \frac{\theta}{2 \tan{\frac{\theta}{2}}} \boldsymbol{I}
        - \frac{\theta}{2} \left[\hat{\boldsymbol{\omega}}\right]
        + \left(1 - \frac{\theta}{2 \tan{\frac{\theta}{2}}}\right)
        \hat{\boldsymbol{\omega}} \hat{\boldsymbol{\omega}}^T

    The code evaluates the equivalent form
    :math:`\boldsymbol{I} - \frac{\theta}{2} [\hat{\boldsymbol{\omega}}]
    + \left(1 - \frac{\theta}{2 \tan{\frac{\theta}{2}}}\right)
    [\hat{\boldsymbol{\omega}}]^2`. It uses
    :math:`[\hat{\boldsymbol{\omega}}]^2 = \hat{\boldsymbol{\omega}}
    \hat{\boldsymbol{\omega}}^T - \boldsymbol{I}` for a unit axis. For
    :math:`\theta < 10^{-3}`, the result of
    :func:`left_jacobian_SO3_inv_series` is returned instead. Because of the
    division by :math:`\tan(\theta / 2)`, the closed form becomes
    ill-conditioned as :math:`\theta` approaches a nonzero multiple of
    :math:`2\pi`.

    Parameters
    ----------
    axis_angle : array-like, shape (..., 3)
        Compact axis-angle representation.

    Returns
    -------
    J_inv : array, shape (..., 3, 3)
        Inverse left Jacobian of SO(3).

    See Also
    --------
    left_jacobian_SO3 : Left Jacobian of SO(3) at theta (angle of rotation).
    left_jacobian_SO3_inv_series :
        Inverse left Jacobian of SO(3) at theta from Taylor series.
    """
    # Convert once so array-like input (e.g. nested lists) works with every
    # helper below, including the series fallback.
    axis_angle = jnp.asarray(axis_angle)
    theta = differentiable_norm(axis_angle, axis=-1)
    # Substitute 1.0 where theta is (numerically) zero, so tan(theta_safe / 2)
    # below is never evaluated at 0.
    theta_safe = jnp.where(theta > min_diff_norm(theta), theta, 1.0)
    omega_unit = norm_vector(axis_angle, norm=theta)
    omega_matrix = cross_product_matrix(omega_unit)
    eye = jnp.broadcast_to(jnp.eye(3), omega_matrix.shape)

    # Coefficients of [w] and [w]^2 (see the docstring for the equivalence).
    factor1 = 0.5 * theta
    factor2 = 1.0 - 0.5 * theta / jnp.tan(theta_safe / 2.0)
    J_inv = (
        eye
        - factor1[..., jnp.newaxis, jnp.newaxis] * omega_matrix
        + factor2[..., jnp.newaxis, jnp.newaxis] * omega_matrix @ omega_matrix
    )

    J_inv_taylor = left_jacobian_SO3_inv_series(axis_angle)
    return jnp.where(theta[..., jnp.newaxis, jnp.newaxis] < 1e-3, J_inv_taylor, J_inv)
def left_jacobian_SO3_inv_series(axis_angle: jnp.ndarray) -> jnp.ndarray:
    r"""Inverse left Jacobian of SO(3) at theta from Taylor series with 10 terms.

    Evaluates the truncated series

    .. math::

        \boldsymbol{J}^{-1}(\boldsymbol{\phi})
        \approx \sum_{n=0}^{10} \frac{B_n}{n!} [\boldsymbol{\phi}]^n

    Here :math:`\boldsymbol{\phi} = \hat{\boldsymbol{\omega}}\theta` is the
    compact axis-angle vector, :math:`[\cdot]` denotes its cross-product
    matrix, and :math:`B_n` are the Bernoulli numbers returned by
    :func:`jax.scipy.special.bernoulli`. That function uses the
    :math:`B_1 = -1/2` convention, which yields the
    :math:`-\frac{1}{2}[\boldsymbol{\phi}]` term of the closed form.

    Parameters
    ----------
    axis_angle : array-like, shape (..., 3)
        Compact axis-angle representation.

    Returns
    -------
    J_inv : array, shape (..., 3, 3)
        Inverse left Jacobian of SO(3).

    See Also
    --------
    left_jacobian_SO3_inv :
        Inverse left Jacobian of SO(3) at theta (angle of rotation).
    """
    # Convert so array-like input also works. ``.shape`` is read below.
    axis_angle = jnp.asarray(axis_angle)
    eye = jnp.broadcast_to(jnp.eye(3), axis_angle.shape + (3,))
    px = cross_product_matrix(axis_angle)
    # B_0 ... B_11. The loop reads B_1 ... B_10; B_0 = 1 is the identity term.
    b = jax.scipy.special.bernoulli(11)
    J_inv = eye
    pxn = eye
    for n in range(10):
        # pxn becomes [phi]^(n + 1) / (n + 1)!.
        pxn = pxn @ px / (n + 1)
        J_inv = J_inv + b[n + 1] * pxn
    return J_inv