Source code for jaxtransform3d.transformations._transform

import chex
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from ..rotations import (
    apply_matrix,
    compact_axis_angle_from_matrix,
    left_jacobian_SO3_inv,
    matrix_inverse,
)


def transform_inverse(T: ArrayLike) -> jax.Array:
    r"""Compute the inverse of a transformation matrix.

    The rotation block is inverted and the translation is mapped through
    the inverted rotation with its sign flipped:

    .. math::

        \boldsymbol{T}^{-1} =
        \left( \begin{array}{cc}
        \boldsymbol{R} & \boldsymbol{t}\\
        \boldsymbol{0} & 1
        \end{array} \right)^{-1}
        =
        \left( \begin{array}{cc}
        \boldsymbol{R}^T & -\boldsymbol{R}^T \boldsymbol{t}\\
        \boldsymbol{0} & 1
        \end{array} \right)

    Parameters
    ----------
    T : array-like, shape (..., 4, 4)
        Transformation matrix.

    Returns
    -------
    T_inv : array, shape (..., 4, 4)
        Inverted transformation matrix.
    """
    transform = jnp.asarray(T)

    # Split the homogeneous matrix into its rotation block and its
    # translation column.
    rotation = transform[..., :3, :3]
    translation = transform[..., :3, 3]

    rotation_inv = matrix_inverse(rotation)
    # The inverse translation is the original translation pushed through
    # the inverted rotation and negated.
    return create_transform(
        rotation_inv, -apply_matrix(rotation_inv, translation)
    )
def apply_transform(T: ArrayLike, v: ArrayLike) -> jax.Array:
    r"""Apply transformation matrix to vector.

    .. math::

        \boldsymbol{w} = \boldsymbol{R} \boldsymbol{v} + \boldsymbol{t}

    Inputs whose dtype is not a floating-point type (e.g. integers) are
    converted to JAX's default floating-point dtype before the
    transformation is applied.

    Parameters
    ----------
    T : array-like, shape (..., 4, 4) or (4, 4)
        Transformation matrix.

    v : array-like, shape (..., 3) or (3,)
        3d vector.

    Returns
    -------
    w : array, shape (..., 3) or (3,)
        3d vector.

    Raises
    ------
    AssertionError
        If the last two axes of ``T`` are not of size 4 or the last axis
        of ``v`` is not of size 3.
    """
    T = jnp.asarray(T)
    v = jnp.asarray(v)

    # Cast with the Python ``float`` type instead of ``jnp.float64``: JAX
    # resolves it to its default float dtype (float64 with x64 enabled,
    # float32 otherwise). The resulting dtype is the same as before, but
    # an explicit float64 request is reported as unavailable and truncated
    # with a warning when 64-bit mode is disabled.
    if not jnp.issubdtype(T.dtype, jnp.floating):
        T = T.astype(float)
    if not jnp.issubdtype(v.dtype, jnp.floating):
        v = v.astype(float)

    chex.assert_axis_dimension(T, axis=-2, expected=4)
    chex.assert_axis_dimension(T, axis=-1, expected=4)
    chex.assert_axis_dimension(v, axis=-1, expected=3)

    # Rotate with the upper-left 3x3 block, then add the translation column.
    return apply_matrix(T[..., :3, :3], v) + T[..., :3, 3]
def compose_transforms(T1: ArrayLike, T2: ArrayLike) -> jax.Array:
    """Compose transformation matrices.

    Computes the matrix product ``T1 @ T2``. Both inputs are flattened to
    a stack of 4x4 matrices, multiplied, and the result is reshaped to the
    shape of the batched operand. A single matrix on either side is
    broadcast over the batch of the other side. Empty batches are
    supported and produce an empty result.

    Parameters
    ----------
    T1 : array-like, shape (..., 4, 4) or (4, 4)
        Transformation matrix.

    T2 : array-like, shape (..., 4, 4) or (4, 4)
        Transformation matrix.

    Returns
    -------
    T1_T2 : array, shape (..., 4, 4) or (4, 4)
        Composed transformation matrix.
    """
    T1 = jnp.asarray(T1)
    T2 = jnp.asarray(T2)

    flat1 = T1.reshape(-1, 4, 4)
    flat2 = T2.reshape(-1, 4, 4)

    # Pick the output shape from the number of matrices, not from the
    # element count: comparing sizes selects the wrong operand when one
    # batch is empty (0 elements is "smaller" than a single 4x4 matrix).
    # T1 determines the shape only when T2 is a single matrix broadcast
    # over a T1 batch; in every other valid case T2's shape is used, which
    # matches the previous behaviour for non-empty inputs.
    if flat2.shape[0] == 1 and flat1.shape[0] != 1:
        out_shape = T1.shape
    else:
        out_shape = T2.shape

    return (flat1 @ flat2).reshape(out_shape)
def create_transform(R: ArrayLike, t: ArrayLike) -> jax.Array:
    r"""Make transformation from rotation matrix and translation.

    .. math::

        \boldsymbol{T} =
        \left( \begin{array}{cc}
        \boldsymbol{R} & \boldsymbol{t}\\
        \boldsymbol{0} & 1
        \end{array} \right) \in SE(3)

    The result has the dtype of ``R`` when ``R`` is floating-point.
    Otherwise (e.g. an integer ``R``) the dtype is promoted together with
    ``t`` so that a floating-point translation is not truncated.

    Parameters
    ----------
    R : array-like, shape (..., 3, 3)
        Rotation matrix.

    t : array-like, shape (..., 3)
        Translation.

    Returns
    -------
    T : array, shape (..., 4, 4)
        Transformation matrix.

    Raises
    ------
    AssertionError
        If the leading ``R.ndim - 1`` dimensions of ``R`` and ``t`` differ
        or the last axis of ``R`` is not of size 3.
    """
    R = jnp.asarray(R)
    t = jnp.asarray(t)

    chex.assert_equal_shape_prefix((R, t), prefix_len=R.ndim - 1)
    chex.assert_axis_dimension(R, axis=-1, expected=3)

    # Using R's dtype unconditionally would silently cast a fractional
    # translation into an integer matrix when R is given as integers.
    # Floating-point R keeps its dtype, as before.
    if jnp.issubdtype(R.dtype, jnp.floating):
        dtype = R.dtype
    else:
        dtype = jnp.result_type(R, t)

    # Start from zeros so the bottom row only needs its last entry set.
    T = jnp.zeros(R.shape[:-2] + (4, 4), dtype=dtype)
    T = T.at[..., :3, :3].set(R)
    T = T.at[..., :3, 3].set(t)
    T = T.at[..., 3, 3].set(1)
    return T
def exponential_coordinates_from_transform(T: ArrayLike) -> jax.Array:
    r"""Compute exponential coordinates from transformation matrix.

    This is the logarithm map.

    .. math::

        Log: \boldsymbol{T} \in SE(3)
        \rightarrow \mathcal{S} \theta \in \mathbb{R}^6

    .. math::

        Log(\boldsymbol{T}) =
        Log\left(
        \begin{array}{cc}
        \boldsymbol{R} & \boldsymbol{p}\\
        \boldsymbol{0} & 1
        \end{array}
        \right)
        =
        \left(
        \begin{array}{c}
        Log(\boldsymbol{R})\\
        \boldsymbol{J}^{-1}(\theta) \boldsymbol{p}
        \end{array}
        \right)
        =
        \left(
        \begin{array}{c}
        \hat{\boldsymbol{\omega}}\\
        \boldsymbol{v}
        \end{array}
        \right)
        \theta
        =
        \mathcal{S}\theta,

    where :math:`\boldsymbol{J}^{-1}(\theta)` is the inverse left Jacobian
    of :math:`SO(3)`.

    Parameters
    ----------
    T : array-like, shape (..., 4, 4)
        Transformation matrix.

    Returns
    -------
    exp_coords : array, shape (..., 6)
        Exponential coordinates of transformation:
        S * theta = (omega_1, omega_2, omega_3, v_1, v_2, v_3) * theta,
        where S is the screw axis, the first 3 components are related to
        rotation and the last 3 components are related to translation.
        Theta is the rotation angle and h * theta the translation.

    See also
    --------
    transform_from_exponential_coordinates :
        Exponential map.

    exponential_coordinates_from_dual_quaternion :
        Logarithmic map for dual quaternions.
    """
    transform = jnp.asarray(T)
    # Both trailing axes must have size 4.
    for axis in (-2, -1):
        chex.assert_axis_dimension(transform, axis=axis, expected=4)

    # Rotational part: compact axis-angle of the upper-left 3x3 block.
    omega_theta = compact_axis_angle_from_matrix(transform[..., :3, :3])

    # Translational part: inverse left Jacobian applied to the translation
    # column. The translation is lifted to a column vector for the matrix
    # product and the extra axis is dropped again afterwards.
    translation_column = jnp.expand_dims(transform[..., :3, 3], axis=-1)
    jacobian_inv = left_jacobian_SO3_inv(omega_theta)
    v_theta = jnp.matmul(jacobian_inv, translation_column)[..., 0]

    return jnp.concatenate([omega_theta, v_theta], axis=-1)