Source code for symmys.layers.QuaternionRotation

import math

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K

def projected_quaternion_initializer(shape, dtype=None):
    """Initialize a set of quaternions (to be summed over the last dimension) randomly.

    Each quaternion component is split into `quaternion_dim` random
    positive-weighted parts whose sum over the last axis equals that
    component, so summing the returned tensor over its last axis
    yields quaternions drawn from a uniform distribution of rotations.

    :param shape: Weight shape `(num_rotations, 4, quaternion_dim)`
    :param dtype: Floating-point dtype of the returned tensor; float32 is used when None
    :returns: Tensor of the given shape and dtype
    """
    assert shape[1] == 4
    (num_rotations, _, quaternion_dim) = shape

    # Keras hands the weight dtype to the initializer; sampling in any
    # other dtype makes variable creation fail for non-float32 layers.
    dtype = tf.float32 if dtype is None else tf.as_dtype(dtype)

    # http://planning.cs.uiuc.edu/node198.html
    (u1, u2, u3) = tf.random.uniform((3, num_rotations), maxval=1., dtype=dtype)
    r = tf.sqrt(1 - u1)*tf.sin(2*math.pi*u2)
    x = tf.sqrt(1 - u1)*tf.cos(2*math.pi*u2)
    y = tf.sqrt(u1)*tf.sin(2*math.pi*u3)
    z = tf.sqrt(u1)*tf.cos(2*math.pi*u3)

    # Random partition weights for each of the four quaternion components.
    (a, b, c, d) = tf.random.uniform(
        (4, num_rotations, quaternion_dim), 0., 1., dtype=dtype)

    elements = []
    for (u, v) in zip((a, b, c, d), (r, x, y, z)):
        # Rescale the weights of each rotation so that they sum to the
        # target component value v.
        norm = tf.expand_dims(v, axis=-1)/tf.reduce_sum(u, axis=-1, keepdims=True)
        elements.append(u*norm)

    # Four (num_rotations, quaternion_dim) tensors
    # -> (num_rotations, 4, quaternion_dim)
    return tf.stack(elements, axis=1)

@tf.function
def rotate(quat, vec):
    """Apply the rotations encoded by quaternions to 3D vectors.

    The first element along the last axis of `quat` is taken as the
    real part and the remaining elements as the imaginary (vector)
    part. The result is the sum of three terms::

        (w^2 - |u|^2) v + 2 w (u x v) + 2 (u . v) u

    where `w` is the real part, `u` the imaginary part, and `v` is `vec`.

    :param quat: Quaternions, real part first along the last axis
    :param vec: Vectors to rotate; passed to `tf.linalg.cross` together with the imaginary part of `quat`
    :returns: Rotated vectors
    """
    scalar = K.expand_dims(quat[..., 0], -1)
    vector = quat[..., 1:]

    vector_sq_norm = K.sum(vector**2, axis=-1, keepdims=True)
    projection = K.sum(vector*vec, axis=-1, keepdims=True)

    scaled_term = (scalar**2 - vector_sq_norm)*vec
    cross_term = 2*scalar*tf.linalg.cross(vector, vec)
    parallel_term = 2*projection*vector

    return scaled_term + cross_term + parallel_term

class QuaternionRotation(keras.layers.Layer):
    """Perform rotations of a set of input points, parameterized by unit quaternions.

    This layer takes a point cloud as input and produces rotated
    images of all the points in the point cloud. The rotations that
    are applied are parameterized by unit quaternions, which are
    treated as layer weights to be optimized. Quaternions are
    optimized in a higher dimension and then projected down through a
    `sum` operation to improve the speed of the optimization process.

    :param num_rotations: Number of rotation quaternions to use
    :param quaternion_dim: Pre-projection dimension of quaternion parameters
    :param include_reverse: If True, also output points rotated by the conjugate quaternion for each learned quaternion
    """

    def __init__(self, num_rotations, quaternion_dim=6, include_reverse=True,
                 *args, **kwargs):
        self.num_rotations = num_rotations
        self.quaternion_dim = quaternion_dim
        self.include_reverse = include_reverse
        super().__init__(*args, **kwargs)

    def build(self, input_shape):
        """Create the pre-projection quaternion weights."""
        self.quaternion_weight = self.add_weight(
            shape=(self.num_rotations, 4, self.quaternion_dim),
            initializer=projected_quaternion_initializer,
            name='pre_projected_quaternions'
        )

    @property
    def quaternions(self):
        """Learned unit quaternions, shape `(num_rotations, 4)`."""
        # Project the weights down by summing over the last axis, then
        # normalize each quaternion to unit length.
        quaternions = K.sum(self.quaternion_weight, axis=-1)
        quaternions = tf.linalg.normalize(quaternions, axis=-1)[0]
        return quaternions

    @property
    def training_quaternions(self):
        """Quaternions applied in `call`.

        When `include_reverse` is set, the learned quaternions are
        followed by their reversed counterparts along the first axis.
        """
        quaternions = self.quaternions
        if self.include_reverse:
            # Negating the real part gives the negated conjugate, which
            # `rotate` maps to the same rotation as the conjugate. The
            # mask is built in the quaternion dtype so the product also
            # works when the weights are not float32 (e.g. under mixed
            # precision or a float64 layer).
            flip = tf.constant([(-1., 1, 1, 1)], dtype=quaternions.dtype)
            conj = quaternions*flip
            quaternions = tf.concat([quaternions, conj], axis=0)
        return quaternions

    def call(self, inputs):
        """Rotate every input point by every quaternion of the layer.

        :param inputs: Points to rotate, with coordinates along the last axis
        :returns: Rotated points, with one new second-to-last axis indexing the applied quaternions
        """
        # (whatever, 3) -> (whatever, num_rotations, 3)
        replicated = K.expand_dims(inputs, -2)
        shape = K.int_shape(replicated)
        replicas = [1]*len(shape)
        replicas[-2] = self.num_rotations*(2 if self.include_reverse else 1)
        replicated = K.tile(replicated, replicas)

        # (num_rotations, 4, d) -> (num_rotations, 4)
        quaternions = self.training_quaternions

        # (num_rotations, 4) -> (whatever, num_rotations, 4)
        for _ in range(2, len(shape)):
            quaternions = K.expand_dims(quaternions, axis=-3)
        # Leading dimensions may be unknown statically (e.g. the batch
        # size), so the tile counts are read from the symbolic shape.
        symbolic_shape = K.shape(replicated)
        replicas = [symbolic_shape[i] for i in range(len(shape) - 2)] + [1, 1]
        quaternions = K.tile(quaternions, replicas)

        return rotate(quaternions, replicated)

    def get_config(self):
        """Return the layer configuration, including the constructor arguments."""
        config = super().get_config()
        config.update(dict(
            num_rotations=self.num_rotations,
            quaternion_dim=self.quaternion_dim,
            include_reverse=self.include_reverse,
        ))
        return config
class QuaternionRotoinversion(QuaternionRotation):
    """Learn rotoinversions, rather than rotations.

    Otherwise identical to :py:class:`QuaternionRotation`.
    """

    def call(self, inputs):
        # Negate the points first, then let the parent layer rotate them.
        inverted = -inputs
        return super().call(inverted)
# Register the layers under their class names in the Keras custom-object
# registry.
keras.utils.get_custom_objects().update({
    layer_class.__name__: layer_class
    for layer_class in (QuaternionRotation, QuaternionRotoinversion)
})