@@ -5,7 +5,7 @@ from typing import List, Optional
 import torch
 from tqdm import tqdm
 
-from AR.models.utils import make_pad_mask
+from AR.models.utils import make_pad_mask, make_pad_mask_left
 from AR.models.utils import (
     topk_sampling,
     sample,
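The switch from right padding to left padding for batched inference is the heart of this change, so a left-padding helper is now imported next to make_pad_mask. As a minimal sketch, assuming the same contract as make_pad_mask (per-item lengths in, a (batch, max_len) bool mask out, with True on padded slots), such a helper can look like the following; the body is illustrative, not necessarily the repo's exact implementation:

    import torch

    def make_pad_mask_left_sketch(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
        # lengths: (batch,) valid lengths; True in the result marks PAD positions.
        max_len = max(max_len, int(lengths.max()))
        seq = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
        # With left padding, the first (max_len - length) slots of each row are pads.
        return seq < (max_len - lengths).unsqueeze(1)

    # lengths=[2, 3], max_len=4 -> [[True, True, False, False], [True, False, False, False]]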
@@ -162,7 +162,7 @@ class T2SBlock:
         )
         return x, k_cache, v_cache
 
-    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
+    def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:torch.Tensor=None, torch_sdpa:bool=True):
         q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
 
         k_cache = torch.cat([k_cache, k], dim=1)
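The single-step decode path now threads the attention mask through. The step itself is the usual KV-cache update: project the one new token to q/k/v, append k and v to the caches, then attend over the whole cache. A standalone sketch of that pattern, with hypothetical shapes and without the block's multi-head reshaping:

    import torch
    import torch.nn.functional as F

    def cached_decode_step(x, k_cache, v_cache, qkv_w, qkv_b):
        # x: (batch, 1, dim) -- exactly one new token per sequence.
        # qkv_w/qkv_b: fused (3*dim, dim) projection, as in the block's F.linear call.
        q, k, v = F.linear(x, qkv_w, qkv_b).chunk(3, dim=-1)
        k_cache = torch.cat([k_cache, k], dim=1)   # caches grow along the time axis
        v_cache = torch.cat([v_cache, v], dim=1)
        attn = F.scaled_dot_product_attention(q, k_cache, v_cache)
        return attn, k_cache, v_cache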
@@ -178,7 +178,7 @@ class T2SBlock:
         if torch_sdpa:
-            attn = F.scaled_dot_product_attention(q, k, v)
+            attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
         else:
             attn = scaled_dot_product_attention(q, k, v, attn_mask)
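Note the polarity flip. The masks built in this file follow the padding convention True = "masked out", while a boolean mask passed to torch's F.scaled_dot_product_attention means True = "take part in attention"; hence ~attn_mask on the torch path, with None forwarded untouched. A quick check of the convention:

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(1, 2, 4)
    masked = torch.tensor([[[False, True],    # True = "do not attend" (codebase convention)
                            [False, False]]])
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=~masked)  # flip for SDPA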
@@ -223,7 +223,7 @@ class T2STransformer:
         self, x:torch.Tensor,
         k_cache: List[torch.Tensor],
         v_cache: List[torch.Tensor],
-        attn_mask : Optional[torch.Tensor]=None,
+        attn_mask : torch.Tensor=None,
         torch_sdpa:bool=True
     ):
         for i in range(self.num_blocks):
@@ -573,71 +573,61 @@ class Text2SemanticDecoder(nn.Module):
             x_item = self.ar_text_embedding(x_item.unsqueeze(0))
             x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
             x_item = self.ar_text_position(x_item).squeeze(0)
-            x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
+            # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
+            x_item = F.pad(x_item,(0,0,max_len-x_item.shape[0],0),value=0) if x_item.shape[0]<max_len else x_item ### padding left
             x_list.append(x_item)
-        x = torch.stack(x_list, dim=0)
+        x:torch.Tensor = torch.stack(x_list, dim=0)
 
         # AR Decoder
         y = prompts
 
         x_len = x.shape[1]
         x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
         stop = False
 
         k_cache = None
         v_cache = None
         ################### first step ##########################
-        if y is not None:
-            y_emb = self.ar_audio_embedding(y)
-            y_len = y_emb.shape[1]
-            prefix_len = y.shape[1]
-            y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
-            y_pos = self.ar_audio_position(y_emb)
-            xy_pos = torch.concat([x, y_pos], dim=1)
-            ref_free = False
-        else:
-            y_emb = None
-            y_len = 0
-            prefix_len = 0
-            y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
-            y_pos = None
-            xy_pos = x
-            y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
-            ref_free = True
+        assert y is not None, "Error: Prompt free is not supported batch_infer!"
+        ref_free = False
+
+        y_emb = self.ar_audio_embedding(y)
+        y_len = y_emb.shape[1]
+        prefix_len = y.shape[1]
+        y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
+        y_pos = self.ar_audio_position(y_emb)
+        xy_pos = torch.concat([x, y_pos], dim=1)
 
         ##### create mask #####
         bsz = x.shape[0]
         src_len = x_len + y_len
-        y_paddind_mask = make_pad_mask(y_lens, y_len)
-        x_paddind_mask = make_pad_mask(x_lens, max_len)
+        y_paddind_mask = make_pad_mask_left(y_lens, y_len)
+        x_paddind_mask = make_pad_mask_left(x_lens, max_len)
 
         # (bsz, x_len + y_len)
-        xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
+        padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
 
         x_mask = F.pad(
-            torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
-            (0, y_len),
+            x_attn_mask,
+            (0, y_len), ### extend the all-zero xx block with an all-one xy block, (x, x+y)
             value=True,
         )
         y_mask = F.pad( ### extend yy's upper-right ones leftward with zeros over xy, (y, x+y)
-            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
+            torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
             (x_len, 0),
             value=False,
         )
 
-        xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
-        _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
+        causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
+        padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
 
-        for i in range(bsz):
-            l = x_lens[i]
-            _xy_padding_mask[i,l:max_len,:]=True
-
-        xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
-        xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
-        xy_attn_mask = xy_attn_mask.bool()
-        xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
+        attn_mask:torch.Tensor = causal_mask.logical_or(padding_mask)
+        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
+        # padding_mask = padding_mask.view(bsz, src_len, 1)
 
         ###### decode #####
         y_list = [None]*y.shape[0]
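Why left padding: new tokens are appended at the right end of every row of the batch, so with right padding the pad slots would sit between a short prompt and the tokens generated after it; left padding keeps each row's real content contiguous at its right edge. The hunk then builds one (bsz, num_head, src_len, src_len) boolean mask as causal part OR padding part. A toy reconstruction of that assembly, small enough to print; shapes and names are illustrative, mirroring the hunk's semantics including its elementwise `*` combination of the padding mask:

    import torch

    bsz, x_len, y_len = 2, 3, 2
    src_len = x_len + y_len
    x_lens = torch.tensor([2, 3])   # real text lengths, rows left-padded to x_len

    # causal part: text rows see only text; audio rows see text plus their own past
    x_mask = torch.zeros(x_len, src_len, dtype=torch.bool)
    x_mask[:, x_len:] = True
    y_mask = torch.zeros(y_len, src_len, dtype=torch.bool)
    y_mask[:, x_len:] = torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
    causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0).expand(bsz, -1, -1)

    # padding part: True on the (x_len - length) leftmost slots of each row
    pad = torch.arange(x_len).unsqueeze(0) < (x_len - x_lens).unsqueeze(1)
    pad = torch.cat([pad, torch.zeros(bsz, y_len, dtype=torch.bool)], dim=1)
    attn_mask = causal_mask | (pad.unsqueeze(1) & pad.unsqueeze(2))  # (bsz, src_len, src_len)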
@@ -645,18 +635,18 @@ class Text2SemanticDecoder(nn.Module):
         idx_list = [None]*y.shape[0]
         for idx in tqdm(range(1500)):
             if idx == 0:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
             else:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
             logits = self.ar_predict_layer(
                 xy_dec[:, -1]
             )
 
             if idx == 0:
-                xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
+                attn_mask = F.pad(attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
                 logits = logits[:, :-1]
             else:
-                xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
+                attn_mask = F.pad(attn_mask,(0,1),value=False)
 
             samples = sample(
                 logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
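After the prefill step only the newest token is fed per iteration, so only the last query row of the mask is kept: attn_mask[:, :, -1] is that row of shape (bsz, heads, src_len), unsqueeze(-2) restores the query axis, and each F.pad(..., (0, 1), value=False) appends one always-visible key column for the token just produced, keeping the mask in step with the growing KV cache. In shapes:

    import torch
    import torch.nn.functional as F

    bsz, heads, src_len = 2, 4, 5
    attn_mask = torch.zeros(bsz, heads, src_len, src_len, dtype=torch.bool)

    step_mask = attn_mask[:, :, -1].unsqueeze(-2)      # (bsz, heads, 1, src_len)
    step_mask = F.pad(step_mask, (0, 1), value=False)  # (bsz, heads, 1, src_len + 1)
    step_mask = F.pad(step_mask, (0, 1), value=False)  # one more column per decoded token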
@@ -686,7 +676,7 @@ class Text2SemanticDecoder(nn.Module):
             if reserved_idx_of_batch_for_y is not None:
                 # index = torch.LongTensor(batch_idx_map).to(y.device)
                 y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
-                xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
+                attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
                 if k_cache is not None :
                     for i in range(len(k_cache)):
                         k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
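When some rows of the batch finish early, everything that carries per-sequence state has to be gathered with the same index tensor: y, the attention mask, and every layer's k/v cache. A sketch of that pruning pattern (hypothetical helper; the loop above does the same thing inline):

    import torch

    def prune_batch(keep, y, attn_mask, k_cache, v_cache):
        # keep: 1-D LongTensor of batch indices that are still decoding
        y = torch.index_select(y, dim=0, index=keep)
        attn_mask = torch.index_select(attn_mask, dim=0, index=keep)
        k_cache = [torch.index_select(k, dim=0, index=keep) for k in k_cache]
        v_cache = [torch.index_select(v, dim=0, index=keep) for v in v_cache]
        return y, attn_mask, k_cache, v_cache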