Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

# Ignore environment files
.env
.venv
venv/
env/
pip-lockfile.txt
Expand All @@ -26,4 +27,5 @@ config/*.ini
.vscode/
.idea/
*.sublime-project
*.sublime-workspace
*.sublime-workspace
.mypy_cache
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ to do:

## Get Started
### Requirments
tested in Pytorch 1.12, CUDA-11.1, gcc-6.3.0.

To do: check if it compiles with:
- [ ] Pytorch 2.0
- [ ] CUDA 11.7
tested in Pytorch 2.5.1+cu124, CUDA-12.4, gcc-9.2.0

### Install
```
python setup.py install
pip install -r requirements.txt
pip install . --no-build-isolation
```
And make sure that there's no lock file left over in torch_extensions/Cache/py312_cu124/_quaternion_cuda before running a program.
Signed, several hours of my life
### Test
```
python examples.py
Expand Down
16 changes: 11 additions & 5 deletions dqtorch/backend.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,37 @@
import os
import torch
from torch.utils.cpp_extension import load

major, minor = torch.cuda.get_device_capability()
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
'-O3', '-std=c++14',
'-O3', '-std=c++17',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
c_flags = ['-O3', '-std=c++17']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']

# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
paths = sorted(glob.glob(
r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]

# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
raise RuntimeError(
"Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path

_backend = load(name='_quaternion_cuda',
Expand All @@ -37,4 +43,4 @@ def find_cl_path():
]],
)

__all__ = ['_backend']
__all__ = ['_backend']
95 changes: 51 additions & 44 deletions dqtorch/dqtorch.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import torch
from typing import Tuple, Optional, Union
from typing import Tuple

from .quaternion_cuda import quaternion_mul as _quaternion_mul_cuda
from .quaternion_cuda import quaternion_conjugate as _quaternion_conjugate_cuda

from enum import Enum, unique



Quaternion = torch.Tensor
DualQuaternions = Tuple[Quaternion, Quaternion]
QuaternionTranslation = Tuple[Quaternion, torch.Tensor]



'''
quaternion library from pytorch3d
https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
Expand All @@ -33,7 +28,7 @@ def quaternion_conjugate(q: Quaternion) -> Quaternion:
# return _quaternion_conjugate_cuda(q.contiguous().view(-1,4)).view(out_shape)
if q.is_cuda:
out_shape = q.shape
return _quaternion_conjugate_cuda(q.contiguous().view(-1,4)).view(out_shape)
return _quaternion_conjugate_cuda(q.contiguous().view(-1, 4)).view(out_shape) # type:ignore # nopep8
else:
return _quaternion_conjugate_pytorch(q)

Expand All @@ -53,6 +48,8 @@ def standardize_quaternion(quaternions: Quaternion) -> Quaternion:
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)

# @torch.jit.script


def _quaternion_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Multiply two quaternions.
Expand All @@ -73,25 +70,28 @@ def _quaternion_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
oz = aw * bz + ax * by - ay * bx + az * bw
return torch.stack((ow, ox, oy, oz), -1)

def _quaternion_4D_mul_3D(a:torch.Tensor, b_xyz:torch.Tensor) -> torch.Tensor:

def _quaternion_4D_mul_3D(a: torch.Tensor, b_xyz: torch.Tensor) -> torch.Tensor:
aw, ax, ay, az = torch.unbind(a, -1)
bx, by, bz = torch.unbind(b_xyz, -1)
ow = - ax * bx - ay * by - az * bz
ox = aw * bx + ay * bz - az * by
oy = aw * by - ax * bz + az * bx
oz = aw * bz + ax * by - ay * bx
return torch.stack((ow, ox, oy, oz), -1)
return torch.stack((ow, ox, oy, oz), -1)


def _quaternion_3D_mul_4D(a_xyz:torch.Tensor, b:torch.Tensor) -> torch.Tensor:
def _quaternion_3D_mul_4D(a_xyz: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
ax, ay, az = torch.unbind(a_xyz, -1)
bw, bx, by, bz = torch.unbind(b, -1)
ow = - ax * bx - ay * by - az * bz
ox = ax * bw + ay * bz - az * by
oy = - ax * bz + ay * bw + az * bx
oz = ax * by - ay * bx + az * bw
ow = - ax * bx - ay * by - az * bz
ox = ax * bw + ay * bz - az * by
oy = - ax * bz + ay * bw + az * bx
oz = ax * by - ay * bx + az * bw
return torch.stack((ow, ox, oy, oz), -1)

def _quaternion_mul_pytorch(a: torch.Tensor, b:torch.Tensor):

def _quaternion_mul_pytorch(a: torch.Tensor, b: torch.Tensor):
'''
native pytorch implementation, only used as a baseline.
'''
Expand All @@ -103,14 +103,12 @@ def _quaternion_mul_pytorch(a: torch.Tensor, b:torch.Tensor):
return _quaternion_4D_mul_3D(a, b)
else:
raise ValueError(f"Invalid input shapes.")




def quaternion_mul(a: Quaternion, b: Quaternion) -> Quaternion:
if a.is_cuda:
ouput_shape = list(a.shape[:-1]) + [4]
return _quaternion_mul_cuda(a.view(-1, a.shape[-1]), b.view(-1, b.shape[-1])).view(ouput_shape)
return _quaternion_mul_cuda(a.view(-1, a.shape[-1]), b.view(-1, b.shape[-1])).view(ouput_shape) # type:ignore # nopep8
else:
return _quaternion_mul_pytorch(a, b)

Expand Down Expand Up @@ -159,7 +157,7 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to rotation matrices.


Args:
quaternions: quaternions with real part first,
Expand All @@ -183,7 +181,7 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
o2 = two_s * (ij - kr)
o3 = two_s * (ik + jr)
o4 = two_s * (ij + kr)

o5 = 1 - two_s * (ii + kk)
o6 = two_s * (jk - ir)
o7 = two_s * (ik - jr)
Expand All @@ -195,6 +193,7 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:

return o.view(quaternions.shape[:-1] + (3, 3))


def quaternion_apply(quaternion: Quaternion, point: torch.Tensor) -> torch.Tensor:
"""
Apply the rotation given by a quaternion to a 3D point.
Expand All @@ -213,38 +212,42 @@ def quaternion_apply(quaternion: Quaternion, point: torch.Tensor) -> torch.Tenso
)
return out[..., 1:].contiguous()

def quaternion_translation_apply(q:Quaternion, t:torch.Tensor, point:torch.Tensor) -> torch.Tensor:

def quaternion_translation_apply(q: Quaternion, t: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
p = quaternion_apply(q, point)
return p + t

def quaternion_translation_compose(qt1:QuaternionTranslation, qt2:QuaternionTranslation) -> QuaternionTranslation:


def quaternion_translation_compose(qt1: QuaternionTranslation, qt2: QuaternionTranslation) -> QuaternionTranslation:
qr = quaternion_mul(qt1[0], qt2[0])
t = quaternion_apply(qt1[0], qt2[1]) + qt1[1]
return (qr, t)

def quaternion_translation_inverse(q:Quaternion, t:torch.Tensor) -> Tuple[Quaternion, torch.Tensor]:

def quaternion_translation_inverse(q: Quaternion, t: torch.Tensor) -> Tuple[Quaternion, torch.Tensor]:
q_inv = quaternion_conjugate(q)
t_inv = quaternion_apply(q_inv, -t)
return q_inv, t_inv


def quaternion_translation_to_dual_quaternion(
q:torch.Tensor, t:torch.Tensor) -> DualQuaternions:
q: torch.Tensor, t: torch.Tensor) -> DualQuaternions:
'''
https://cs.gmu.edu/~jmlien/teaching/cs451/uploads/Main/dual-quaternion.pdf
'''
q_d = 0.5* quaternion_mul(t, q)
q_d = 0.5 * quaternion_mul(t, q)
return (q, q_d)


def dual_quaternion_to_quaternion_translation(dq:DualQuaternions) -> DualQuaternions:
def dual_quaternion_to_quaternion_translation(dq: DualQuaternions) -> DualQuaternions:
q_r = dq[0]
q_d = dq[1]
t = 2*quaternion_mul(q_d, quaternion_conjugate(q_r))[..., 1:]

return q_r, t


def dual_quaternion_mul(dq1:DualQuaternions, dq2:DualQuaternions) -> DualQuaternions:
def dual_quaternion_mul(dq1: DualQuaternions, dq2: DualQuaternions) -> DualQuaternions:
q_r1 = dq1[0]
q_d1 = dq1[1]
q_r2 = dq2[0]
Expand All @@ -253,35 +256,37 @@ def dual_quaternion_mul(dq1:DualQuaternions, dq2:DualQuaternions) -> DualQuatern
r_d = quaternion_mul(q_r1, q_d2) + quaternion_mul(q_d1, q_r2)
return (r_r, r_d)

def dual_quaternion_apply(dq:DualQuaternions, point:torch.Tensor) -> torch.Tensor:

def dual_quaternion_apply(dq: DualQuaternions, point: torch.Tensor) -> torch.Tensor:
'''
assuming the input dual quaternion is normalized.
'''
q, t = dual_quaternion_to_quaternion_translation(dq)
return quaternion_translation_apply(q, t, point)


def dual_quaternion_q_conjugate(dq:DualQuaternions) -> DualQuaternions:
def dual_quaternion_q_conjugate(dq: DualQuaternions) -> DualQuaternions:
r = quaternion_conjugate(dq[0])
d = quaternion_conjugate(dq[1])
return (r, d)


def dual_quaternion_d_conjugate(dq:DualQuaternions) -> DualQuaternions:
def dual_quaternion_d_conjugate(dq: DualQuaternions) -> DualQuaternions:
return (dq[0], -dq[1])


def dual_quaternion_3rd_conjugate(dq:DualQuaternions) -> DualQuaternions:
def dual_quaternion_3rd_conjugate(dq: DualQuaternions) -> DualQuaternions:
return dual_quaternion_d_conjugate(
dual_quaternion_q_conjugate(dq) )
dual_quaternion_q_conjugate(dq))


# def dual_quaternion_inverse(dq:DualQuaternions) -> DualQuaternions:
# return dual_quaternion_q_conjugate(dq)

dual_quaternion_inverse = dual_quaternion_q_conjugate

def dual_quaternion_rectify(dq:DualQuaternions) -> DualQuaternions:

def dual_quaternion_rectify(dq: DualQuaternions) -> DualQuaternions:
'''
input: (unit quaternion, 4D vector w') -> dual quaternion, which satisfies (r, 0.5* t r)
solve: min | q - w' | s.t. w^T r = 0
Expand All @@ -291,6 +296,7 @@ def dual_quaternion_rectify(dq:DualQuaternions) -> DualQuaternions:

return (q_r, q_d)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
Expand Down Expand Up @@ -336,10 +342,14 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
torch.stack([q_abs[..., 0] ** 2, m21 - m12,
m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2,
m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]
** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12,
q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
Expand All @@ -353,9 +363,6 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# forall i; we pick the best-conditioned one (with the largest denominator)

return quat_candidates[
torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
# pyre-ignore[16]
torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))




Loading