diff --git a/tools/uvr5/bs_roformer/bs_roformer.py b/tools/uvr5/bs_roformer/bs_roformer.py index a159211..88af3ca 100644 --- a/tools/uvr5/bs_roformer/bs_roformer.py +++ b/tools/uvr5/bs_roformer/bs_roformer.py @@ -479,27 +479,31 @@ class BSRoformer(Module): 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting x = rearrange(stft_repr, 'b f t c -> b t (f c)') - + # print("460:", x.dtype)#fp32 x = self.band_split(x) # axial / hierarchical attention + # print("487:",x.dtype)#fp16 for transformer_block in self.layers: if len(transformer_block) == 3: linear_transformer, time_transformer, freq_transformer = transformer_block x, ft_ps = pack([x], 'b * d') + # print("494:", x.dtype)#fp16 x = linear_transformer(x) + # print("496:", x.dtype)#fp16 x, = unpack(x, ft_ps, 'b * d') else: time_transformer, freq_transformer = transformer_block + # print("501:", x.dtype)#fp16 x = rearrange(x, 'b t f d -> b f t d') x, ps = pack([x], '* t d') x = time_transformer(x) - + # print("505:", x.dtype)#fp16 x, = unpack(x, ps, '* t d') x = rearrange(x, 'b f t d -> b t f d') x, ps = pack([x], '* f d') @@ -508,10 +512,11 @@ class BSRoformer(Module): x, = unpack(x, ps, '* f d') + # print("515:", x.dtype)######fp16 x = self.final_norm(x) num_stems = len(self.mask_estimators) - + # print("519:", x.dtype)#fp32 mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)