You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
1.9 KiB
Python

"""This module contains functions for geometry transform and camera projection"""
import torch
import torch.nn as nn
import numpy as np
def euler2rot(euler_angle):
    """Convert a batch of Euler angles into 3x3 rotation matrices.

    Args:
        euler_angle: tensor indexed as ``euler_angle[:, k]``; columns 0, 1
            and 2 hold the angles (theta, phi, psi) in radians, one row per
            batch element.

    Returns:
        Tensor of shape ``(B, 3, 3)`` equal to ``Rx @ (Ry @ Rz)`` where,
        written row by row (t = theta, p = phi, s = psi)::

            Rx = [[1, 0, 0], [0, cos t, -sin t], [0, sin t, cos t]]
            Ry = [[cos p, 0, sin p], [0, 1, 0], [-sin p, 0, cos p]]
            Rz = [[cos s, sin s, 0], [-sin s, cos s, 0], [0, 0, 1]]

        ``Rz`` uses the opposite sign convention from ``Rx``/``Ry``; it equals
        the conventional z-rotation evaluated at ``-psi``.
    """
    n = euler_angle.shape[0]
    device = euler_angle.device
    # Constant entries are float32 on the input's device; torch.cat combines
    # them with the angle-derived entries (dtype promotion applies).
    one = torch.ones((n, 1, 1), dtype=torch.float32, device=device)
    zero = torch.zeros((n, 1, 1), dtype=torch.float32, device=device)

    def _from_rows(rows):
        # Every entry is (n, 1, 1): joining a row along dim 2 gives (n, 1, 3);
        # joining the three rows along dim 1 gives (n, 3, 3) with [b, i, j] = rows[i][j].
        return torch.cat([torch.cat(row, 2) for row in rows], 1)

    t, p, s = (euler_angle[:, k].reshape(-1, 1, 1) for k in range(3))
    ct, st = t.cos(), t.sin()
    cp, sp = p.cos(), p.sin()
    cs, ss = s.cos(), s.sin()

    rx = _from_rows([[one, zero, zero], [zero, ct, -st], [zero, st, ct]])
    ry = _from_rows([[cp, zero, sp], [zero, one, zero], [-sp, zero, cp]])
    rz = _from_rows([[cs, ss, zero], [-ss, cs, zero], [zero, zero, one]])
    return torch.bmm(rx, torch.bmm(ry, rz))
def rot_trans_geo(geometry, rot, trans):
    """Rotate then translate a batch of point sets.

    Each point ``p`` (a row of ``geometry`` along its last axis) is mapped to
    ``rot @ p + trans``.

    Args:
        geometry: batched points, presumably ``(B, N, 3)`` with one point per
            row (TODO confirm against callers); it is transposed so ``rot``
            acts on the last axis.
        rot: batched matrices applied on the left with ``torch.bmm``, e.g. the
            ``(B, 3, 3)`` output of ``euler2rot``.
        trans: per-batch translation, reshaped with ``view(-1, 3, 1)``.

    Returns:
        Transformed points in the same row-per-point layout as ``geometry``,
        returned as a transposed (non-contiguous) view.
    """
    as_columns = geometry.transpose(1, 2)
    moved = torch.bmm(rot, as_columns) + trans.view(-1, 3, 1)
    return moved.transpose(1, 2)
def euler_trans_geo(geometry, euler, trans):
    """Rotate ``geometry`` by Euler angles ``euler`` and translate by ``trans``.

    Composition of ``euler2rot`` (angles -> batched rotation matrices) and
    ``rot_trans_geo`` (apply ``rot @ p + trans`` to every point).
    """
    return rot_trans_geo(geometry, euler2rot(euler), trans)
def proj_geo(rott_geo, camera_para):
    """Project camera-space points with a single-focal-length pinhole model.

    Args:
        rott_geo: points indexed as ``rott_geo[:, :, k]``; k = 0, 1, 2 give
            the X, Y and Z coordinates (batch, point, coordinate).
        camera_para: per-batch intrinsics read as column 0 = focal length
            (used for both axes), column 1 = principal point x, column 2 =
            principal point y.

    Returns:
        For 3-D ``rott_geo``, a ``(B, N, 3)`` tensor whose last axis is
        ``(u, v, Z)`` with ``u = cx - f * X / Z`` and ``v = cy + f * Y / Z``.
        The x axis is mirrored (negated) while y is not, and the depth is
        passed through unchanged. There is no guard against ``Z == 0``.
    """
    # Broadcast per-batch intrinsics over the point axis.
    focal = camera_para[:, 0].unsqueeze(1)
    cx = camera_para[:, 1].unsqueeze(1)
    cy = camera_para[:, 2].unsqueeze(1)

    x = rott_geo[:, :, 0]
    y = rott_geo[:, :, 1]
    z = rott_geo[:, :, 2]

    u = -(focal * x) / z + cx
    v = (focal * y) / z + cy
    return torch.stack((u, v, z), dim=2)