You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

77 lines
2.2 KiB
Python

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend
class _freq_encoder(Function):
    """Autograd ``Function`` wrapping the ``_backend`` frequency-encoding kernels.

    ``forward`` calls ``_backend.freq_encode_forward`` and ``backward`` calls
    ``_backend.freq_encode_backward``. In both cases the result buffer is
    allocated here and handed to the backend as the last argument
    (presumably filled in place by the kernel -- confirm against the backend).
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        """Encode ``inputs`` and stash what ``backward`` needs on ``ctx``.

        Args:
            ctx: autograd context; receives the saved tensors, ``dims`` and
                ``input_device``.
            inputs: 2-D float tensor of shape ``[B, input_dim]``. A tensor
                that is not on CUDA is copied to the current CUDA device.
            degree: forwarded unchanged to the backend kernels.
            output_dim: number of columns of the encoded output.

        Returns:
            Tensor of shape ``[B, output_dim]`` with the dtype and (CUDA)
            device of the tensor actually passed to the kernel.
        """
        # Remember where the caller's tensor lives *before* the implicit
        # host->CUDA copy, so backward can hand the gradient back on that
        # device (autograd rejects a gradient whose device differs from the
        # input's).
        ctx.input_device = inputs.device

        if not inputs.is_cuda:
            inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape  # batch size, coord dim

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)
        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    # NOTE: @once_differentiable is disabled in the original source and is
    # deliberately left off here to keep behavior unchanged.
    @custom_bwd
    def backward(ctx, grad):
        """Compute the gradient with respect to ``inputs``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad: gradient with respect to the forward output, shape
                ``[B, output_dim]``.

        Returns:
            ``(grad_inputs, None, None)``: ``grad_inputs`` has shape
            ``[B, input_dim]`` and lives on the device of the tensor the
            caller originally passed in; ``degree`` and ``output_dim`` are
            non-tensor arguments and receive no gradient.
        """
        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        # Zero-initialised buffer handed to the backend kernel.
        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

        # `.to` is a no-op when the input was already on this CUDA device; it
        # only copies when forward had to move a non-CUDA input.
        return grad_inputs.to(ctx.input_device), None, None
# Functional entry point: freq_encode(inputs, degree, output_dim) runs
# _freq_encoder through autograd via Function.apply.
freq_encode = _freq_encoder.apply
class FreqEncoder(nn.Module):
    """``nn.Module`` front-end for ``freq_encode``.

    Accepts inputs with any number of leading dimensions, flattens them to
    ``[-1, input_dim]`` for the encoder call, and restores the leading
    dimensions on the result. The encoded width is
    ``input_dim + input_dim * 2 * degree``.
    """

    def __init__(self, input_dim=3, degree=4):
        """Store the configuration and derive the encoded width.

        Args:
            input_dim: size of the last dimension of the tensors to encode.
            degree: passed to ``freq_encode`` on every forward call.
        """
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        # One copy of the raw coordinates plus two columns per degree for
        # each coordinate.
        self.output_dim = input_dim * (2 * degree + 1)

    def __repr__(self):
        """Return a one-line summary of the encoder configuration."""
        return (
            f"FreqEncoder: input_dim={self.input_dim} "
            f"degree={self.degree} "
            f"output_dim={self.output_dim}"
        )

    def forward(self, inputs, **kwargs):
        """Encode ``inputs`` along its last dimension.

        Args:
            inputs: tensor of shape ``[..., input_dim]``.
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape ``[..., output_dim]``.
        """
        leading_dims = tuple(inputs.shape[:-1])

        # The encoder works on a flat batch of rows.
        flat = inputs.reshape(-1, self.input_dim)
        encoded = freq_encode(flat, self.degree, self.output_dim)

        return encoded.reshape(leading_dims + (self.output_dim,))