From 9498fc775bd6a2c385265423ca818987692484e8 Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:28:06 +0800 Subject: [PATCH] bsroformer support fp16 inference bsroformer support fp16 inference --- tools/uvr5/bs_roformer/bs_roformer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tools/uvr5/bs_roformer/bs_roformer.py b/tools/uvr5/bs_roformer/bs_roformer.py index a159211..88af3ca 100644 --- a/tools/uvr5/bs_roformer/bs_roformer.py +++ b/tools/uvr5/bs_roformer/bs_roformer.py @@ -479,27 +479,31 @@ class BSRoformer(Module): 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting x = rearrange(stft_repr, 'b f t c -> b t (f c)') - + # print("460:", x.dtype)#fp32 x = self.band_split(x) # axial / hierarchical attention + # print("487:",x.dtype)#fp16 for transformer_block in self.layers: if len(transformer_block) == 3: linear_transformer, time_transformer, freq_transformer = transformer_block x, ft_ps = pack([x], 'b * d') + # print("494:", x.dtype)#fp16 x = linear_transformer(x) + # print("496:", x.dtype)#fp16 x, = unpack(x, ft_ps, 'b * d') else: time_transformer, freq_transformer = transformer_block + # print("501:", x.dtype)#fp16 x = rearrange(x, 'b t f d -> b f t d') x, ps = pack([x], '* t d') x = time_transformer(x) - + # print("505:", x.dtype)#fp16 x, = unpack(x, ps, '* t d') x = rearrange(x, 'b f t d -> b t f d') x, ps = pack([x], '* f d') @@ -508,10 +512,11 @@ class BSRoformer(Module): x, = unpack(x, ps, '* f d') + # print("515:", x.dtype)######fp16 x = self.final_norm(x) num_stems = len(self.mask_estimators) - + # print("519:", x.dtype)#fp32 mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)