Plot Polar Decomposition

Robust polar decomposition orthonormalizes a set of basis vectors, that is, it maps an approximately orthonormal matrix to a rotation matrix. It is more expensive than standard Gram-Schmidt orthonormalization, but it spreads the correction more evenly over all basis vectors, whereas Gram-Schmidt keeps the direction of one basis vector fixed and pushes the whole correction onto the others. The top row of these plots shows unnormalized bases, each obtained by randomly rotating one of the columns of the identity matrix. The middle row shows the result of Gram-Schmidt orthonormalization and the bottom row the result of robust polar decomposition. For comparison, the last two rows also show the unnormalized basis with dashed lines.
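To make "spreads the correction more evenly" concrete, here is a minimal sketch in plain `jax.numpy` that does not call jaxtransform3d. Both helpers are hypothetical: `gram_schmidt` is a textbook left-to-right Gram-Schmidt (jaxtransform3d's `norm_matrix` may order or construct the columns differently), and `polar_rotation` computes the exact rotation factor of the polar decomposition with an SVD, which we assume is the rotation that the iterative robust polar decomposition approximates. The test input mirrors the plots: the identity with one column rotated.

```python
import jax.numpy as jnp


def gram_schmidt(A):
    """Hypothetical helper: classical Gram-Schmidt on the columns of a 3x3 matrix, left to right."""
    q0 = A[:, 0] / jnp.linalg.norm(A[:, 0])
    v1 = A[:, 1] - jnp.dot(q0, A[:, 1]) * q0
    q1 = v1 / jnp.linalg.norm(v1)
    v2 = A[:, 2] - jnp.dot(q0, A[:, 2]) * q0 - jnp.dot(q1, A[:, 2]) * q1
    q2 = v2 / jnp.linalg.norm(v2)
    return jnp.stack([q0, q1, q2], axis=1)


def polar_rotation(A):
    """Hypothetical helper: rotation factor R of the polar decomposition A = R S, via SVD."""
    U, _, Vt = jnp.linalg.svd(A)
    # Flip the last singular direction if necessary so that det(R) = +1.
    d = jnp.sign(jnp.linalg.det(U @ Vt))
    return U @ jnp.diag(jnp.array([1.0, 1.0, d])) @ Vt


# Identity basis whose first column is rotated by theta about the z-axis.
theta = 0.6
A = jnp.eye(3).at[:, 0].set(jnp.array([jnp.cos(theta), jnp.sin(theta), 0.0]))

R_gs = gram_schmidt(A)
R_polar = polar_rotation(A)

# How far each basis vector (column) moved during orthonormalization.
gs_changes = jnp.linalg.norm(R_gs - A, axis=0)
polar_changes = jnp.linalg.norm(R_polar - A, axis=0)
print("Gram-Schmidt:", gs_changes)
print("Polar:       ", polar_changes)
```

With `theta = 0.6`, Gram-Schmidt leaves the first column untouched and moves the second by about 0.59, whereas the polar factor moves each of the first two columns by about 0.30 and has the smaller total (Frobenius) change.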

[Figure: Unnormalized Bases (top row), Gram-Schmidt (middle row), Polar Decomposition (bottom row)]
JIT-compiled Gram-Schmidt orthogonalization: 0.33549 s
JIT-compiled robust polar decomposition: 0.52874 s
Gram-Schmidt orthogonalization: 0.13418 s
Robust polar decomposition: 1.31573 s
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.05720187, -0.21547284, -0.08660988],
       [-0.21547284,  0.81166131,  0.32624941],
       [-0.08660988,  0.32624941,  1.13113681]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.1810075 , -0.14921235, -0.35493586],
       [-0.14921235,  1.12300223,  0.29258907],
       [-0.35493586,  0.29258907,  0.69599032]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.02770255, -0.00322832,  0.16408746],
       [-0.00322832,  1.00037621, -0.01912193],
       [ 0.16408746, -0.01912193,  0.97192121]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.07023019, -0.04169216, -0.25211044],
       [-0.04169216,  1.02475056,  0.14966539],
       [-0.25211044,  0.14966539,  0.90501925]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 0.8676452 , -0.3266027 , -0.09037511],
       [-0.3266027 ,  1.12294118,  0.03401938],
       [-0.09037511,  0.03401938,  1.00941359]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.00451664, -0.06697986,  0.00315241],
       [-0.06697986,  0.99328314, -0.04674896],
       [ 0.00315241, -0.04674896,  1.00220024]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.06777937, -0.25127454, -0.00681401],
       [-0.25127454,  0.93153558,  0.02526118],
       [-0.00681401,  0.02526118,  1.00068503]])
  warnings.warn(error_msg)
/home/afabisch/Data/anaconda3/envs/jaxtransform3d/lib/python3.12/site-packages/pytransform3d/rotations/_utils.py:486: UserWarning: Expected rotation matrix, but it failed the test for inversion by transposition. np.dot(R, R.T) gives array([[ 1.01181409, -0.08639439,  0.0648886 ],
       [-0.08639439,  0.63178698, -0.4745189 ],
       [ 0.0648886 , -0.4745189 ,  1.3563989 ]])
  warnings.warn(error_msg)
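These warnings come from pytransform3d's check that a rotation matrix satisfies R R^T = I, which the unnormalized inputs fail by design. The same check can be written for a whole batch in `jax.numpy`. The sketch below uses a hypothetical `orthonormality_error` helper and a hypothetical SVD-based `polar_factor` (neither is part of jaxtransform3d or pytransform3d) to show that the error drops to floating-point level after orthonormalization.

```python
import jax.numpy as jnp


def orthonormality_error(R):
    """Hypothetical helper: largest absolute entry of R R^T - I per matrix in a (..., 3, 3) batch."""
    eye = jnp.eye(R.shape[-1], dtype=R.dtype)
    RRt = R @ jnp.swapaxes(R, -1, -2)
    return jnp.max(jnp.abs(RRt - eye), axis=(-2, -1))


def polar_factor(A):
    """Hypothetical helper: batched orthogonal polar factor U V^T (determinant not corrected)."""
    U, _, Vt = jnp.linalg.svd(A)
    return U @ Vt


# Batch of two non-orthonormal matrices: identities with one column tilted.
A = jnp.stack([
    jnp.array([[1.0, 0.3, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
    jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.2, -0.1, 1.0]]),
])

err_before = orthonormality_error(A)               # about [0.3, 0.2]
err_after = orthonormality_error(polar_factor(A))  # close to zero
print("before:", err_before)
print("after: ", err_after)
```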

import time
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pytransform3d.rotations as pr

import jaxtransform3d.rotations as jr

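# Batched (vmap over the leading axis) and JIT-compiled versions of both
# orthonormalization methods.
gram_schmidt = jax.jit(jax.vmap(jr.norm_matrix, in_axes=0, out_axes=0))
robust_polar_decomposition = jax.jit(
    jax.vmap(partial(jr.robust_polar_decomposition, n_iter=5), in_axes=0, out_axes=0)
)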

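n_cases = 8

# Warm-up calls: the first call of a jitted function traces and compiles it.
# Use the same batch shape as the timed calls below, because jax.jit compiles
# a separate executable for every input shape.
warmup_batch = jnp.tile(jnp.eye(3), (n_cases, 1, 1))
start = time.time()
gram_schmidt(warmup_batch).block_until_ready()
gs_jit_time = time.time() - start
start = time.time()
robust_polar_decomposition(warmup_batch).block_until_ready()
rpd_jit_time = time.time() - start
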
fig, axes = plt.subplots(3, n_cases, subplot_kw={"projection": "3d"}, figsize=(14, 8))
ax_s = 1.0
plot_center = jnp.array([-0.2, -0.2, -0.2])
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.set_xlim(-ax_s, ax_s)
    ax.set_ylim(-ax_s, ax_s)
    ax.set_zlim(-ax_s, ax_s)

titles = ["Unnormalized Bases", "Gram-Schmidt", "Polar Decomposition"]
for ax, title in zip(axes[:, 0], titles, strict=False):
    ax.set_title(title)

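# Build n_cases non-orthonormal bases: start from identity matrices and, in each
# one, replace a randomly chosen column by a randomly rotated copy of itself.
rng = np.random.default_rng(46)
R_unnormalized = jnp.array([jnp.eye(3) for _ in range(n_cases)])
for i in range(n_cases):
    random_axis = rng.integers(0, 3)
    R_unnormalized = R_unnormalized.at[i, :, random_axis].set(
        jnp.dot(
            pr.random_matrix(rng, cov=0.1 * jnp.eye(3)),
            R_unnormalized[i, :, random_axis],
        )
    )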

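# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# measured time includes the actual computation.
start = time.time()
R_gs = gram_schmidt(R_unnormalized).block_until_ready()
gs_time = time.time() - start

start = time.time()
R_rpd = robust_polar_decomposition(R_unnormalized).block_until_ready()
rpd_time = time.time() - start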

print(f"JIT-compiled Gram-Schmidt orthogonalization: {gs_jit_time:.5f} s")
print(f"JIT-compiled robust polar decomposition: {rpd_jit_time:.5f} s")
print(f"Gram-Schmidt orthogonalization: {gs_time:.5f} s")
print(f"Robust polar decomposition: {rpd_time:.5f} s")

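# strict_check=False makes pytransform3d warn instead of raising an error for
# the non-orthonormal inputs (these are the UserWarnings printed above).
for i in range(n_cases):
    pr.plot_basis(axes[0, i], R_unnormalized[i], p=plot_center, strict_check=False)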

    pr.plot_basis(
        axes[1, i], R_unnormalized[i], p=plot_center, strict_check=False, ls="--"
    )
    pr.plot_basis(axes[1, i], R_gs[i], p=plot_center)

    pr.plot_basis(
        axes[2, i], R_unnormalized[i], p=plot_center, strict_check=False, ls="--"
    )
    pr.plot_basis(axes[2, i], R_rpd[i], p=plot_center)

plt.tight_layout()
plt.show()
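The timings printed by this script depend on two properties of `jax.jit`: a jitted function is traced and compiled once per combination of input shape and dtype, and calls return before the computation has finished (asynchronous dispatch). The following sketch uses a hypothetical `normalize_columns` workload as a stand-in for the orthonormalization functions. It counts traces with a Python side effect to show when recompilation happens, and calls `block_until_ready()` so that each timing covers the actual computation.

```python
import time

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def normalize_columns(R):
    """Hypothetical workload: scale each column of a batch of matrices to unit length."""
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not in the compiled code
    return R / jnp.linalg.norm(R, axis=-2, keepdims=True)


def timed(f, x):
    """Call f(x) and return (result, seconds), waiting for the device to finish."""
    start = time.perf_counter()
    out = f(x).block_until_ready()
    return out, time.perf_counter() - start


batch = jnp.tile(jnp.eye(3), (8, 1, 1))  # shape (8, 3, 3)

_, t_first = timed(normalize_columns, batch)          # trace + compile + run
_, t_cached = timed(normalize_columns, batch)         # same shape: cached executable
_, t_new_shape = timed(normalize_columns, batch[:1])  # shape (1, 3, 3): recompiles

print(f"first call:      {t_first:.5f} s")
print(f"same shape:      {t_cached:.5f} s")
print(f"new batch shape: {t_new_shape:.5f} s")
print("number of traces:", trace_count)
```

Only the first call and the call with a new batch shape trigger tracing, so those two timings include compilation; exact numbers vary by machine.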

Total running time of the script: (0 minutes 3.327 seconds)

Gallery generated by Sphinx-Gallery