bsroformer support fp16 inference

main
RVC-Boss 1 year ago committed by GitHub
parent e62e965323
commit 9498fc775b

@@ -479,27 +479,31 @@ class BSRoformer(Module):
                              'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
        # print("460:", x.dtype)#fp32
        x = self.band_split(x)
        # axial / hierarchical attention
        # print("487:",x.dtype)#fp16
        for transformer_block in self.layers:
            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block
                x, ft_ps = pack([x], 'b * d')
                # print("494:", x.dtype)#fp16
                x = linear_transformer(x)
                # print("496:", x.dtype)#fp16
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block
            # print("501:", x.dtype)#fp16
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')
            x = time_transformer(x)
            # print("505:", x.dtype)#fp16
            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')
@@ -508,10 +512,11 @@ class BSRoformer(Module):
            x, = unpack(x, ps, '* f d')
        # print("515:", x.dtype)######fp16
        x = self.final_norm(x)
        num_stems = len(self.mask_estimators)
        # print("519:", x.dtype)#fp32
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
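
For reference, a minimal sketch of how the fp16 path touched by this commit could be exercised from user code, assuming a constructed BSRoformer instance and a stereo waveform; the import path, constructor arguments, input shape, and the torch.autocast wrapper below are illustrative assumptions rather than the repository's actual inference entry point.

import torch
from bs_roformer import BSRoformer  # assumed import path, for illustration only

# Hypothetical usage sketch (not part of this commit): run BSRoformer under
# fp16 autocast on CUDA. The constructor arguments and the
# (batch, channels, samples) input shape are assumptions.
model = BSRoformer(dim=512, depth=12, stereo=True).cuda().eval()
audio = torch.randn(1, 2, 44100 * 10, device='cuda')  # ~10 s of stereo audio

with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    # Under autocast the matmul-heavy transformer blocks run in fp16, while
    # ops such as the STFT and the final norm may stay in fp32, matching the
    # mixed dtypes recorded by the debug prints in the diff above.
    separated = model(audio)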
