import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
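
# Prefer the prebuilt _shencoder extension; otherwise fall back to the
# backend module bundled with this package.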
try:
    import _shencoder as _backend
except ImportError:
    from .backend import _backend
class _sh_encoder(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
    def forward(ctx, inputs, degree, calc_grad_inputs=False):
        # inputs: [B, input_dim], float, normalized to [-1, 1]
        # return: [B, degree^2], float

        inputs = inputs.contiguous()
        B, input_dim = inputs.shape # batch size, coord dim
        output_dim = degree ** 2 # SH encoding of degree L yields L^2 coefficients

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        if calc_grad_inputs:
            # cache d(outputs)/d(inputs) so backward can reach the input coordinates
            dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device)
        else:
            dy_dx = None

        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)

        ctx.save_for_backward(inputs, dy_dx)
        ctx.dims = [B, input_dim, degree]

        return outputs

    @staticmethod
    #@once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, degree^2]

        inputs, dy_dx = ctx.saved_tensors

        # dy_dx was only cached when calc_grad_inputs=True; without it there is
        # no gradient to propagate to the input coordinates.
        if dy_dx is not None:
            grad = grad.contiguous()
            B, input_dim, degree = ctx.dims
            grad_inputs = torch.zeros_like(inputs)
            _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)
            # one return value per forward argument: inputs, degree, calc_grad_inputs
            return grad_inputs, None, None
        else:
            return None, None, None
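
# Functional entry point; call as sh_encode(inputs, degree, calc_grad_inputs).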
sh_encode = _sh_encoder.apply

class SHEncoder(nn.Module):
    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim # coord dims, must be 3
        self.degree = degree # 1 ~ 8
        self.output_dim = degree ** 2

        assert self.input_dim == 3, "SH encoder only supports input dim == 3"
        assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]"

    def __repr__(self):
        return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}"

    def forward(self, inputs, size=1):
        # inputs: [..., input_dim], normalized real-world positions in [-size, size]
        # return: [..., degree^2]

        inputs = inputs / size # map to [-1, 1]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)
        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs
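
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumes the compiled
# _shencoder CUDA extension is importable and a CUDA device is available,
# since the backend kernels run on the GPU.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = SHEncoder(input_dim=3, degree=4)
    # eight positions in [-1, 1]^3; requires_grad exercises the dy_dx path
    x = (torch.rand(8, 3, device="cuda") * 2 - 1).requires_grad_(True)
    y = encoder(x) # [8, 16] spherical harmonic features
    y.sum().backward() # gradients flow back through sh_encode to x
    print(y.shape, x.grad.shape)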