@@ -479,27 +479,31 @@ class BSRoformer(Module):
                              'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
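        # fold the complex (real/imag) components into the feature axis, giving per-frame features of size f * 2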
        x = rearrange(stft_repr, 'b f t c -> b t (f c)')

        # print("460:", x.dtype)#fp32
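        # band_split projects each predefined frequency band to the model dimension: (b, t, f * 2) -> (b, t, num_bands, d)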
        x = self.band_split(x)

        # axial / hierarchical attention
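        # each block below attends over time within each band, then across bands within each time step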
        # print("487:",x.dtype)#fp16

        for transformer_block in self.layers:

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block
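                # 3-element blocks also carry a linear transformer that runs over the flattened (time, band) sequence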
                x, ft_ps = pack([x], 'b * d')
                # print("494:", x.dtype)#fp16
                x = linear_transformer(x)
                # print("496:", x.dtype)#fp16
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            # print("501:", x.dtype)#fp16
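            # time attention: move bands next to the batch axis and pack them so each band's frames are attended independently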
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = time_transformer(x)

            # print("505:", x.dtype)#fp16
            x, = unpack(x, ps, '* t d')
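            # frequency attention: restore (b, t, f, d) and pack batch with time so bands are attended within each frame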
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')
@@ -508,10 +512,11 @@ class BSRoformer(Module):
            x, = unpack(x, ps, '* f d')

        # print("515:", x.dtype)######fp16
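        # final normalization over the model dimension before per-stem mask estimation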
        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        # print("519:", x.dtype)#fp32
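        # run one mask estimator per stem and stack the outputs along a new stem dimension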
        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
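        # masks now match the complex STFT layout: (batch, stem, freq, time, real/imag)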