Fix GPT loss computation issue (#2537)

* Fix GPT loss computation issue

* fallback tts config
Branch: main
ChasonJiang committed 2 weeks ago
parent b9211657d8, commit b5a67e6247

@@ -356,7 +356,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
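
Note: switching the text mask to make_pad_mask_left means the text sequence is now padded on the left, so the last real text token of every sample in the batch lands at the same index, x_len - 1. That is what makes the fixed-index slice in the next hunk valid for the whole batch. A minimal sketch of such a helper, assuming the usual "True = padding" convention (the name comes from the diff; the body is an illustration, not the repository's actual implementation):

    import torch

    def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
        # With left padding, the padded slots are the FIRST
        # (max_len - length) positions of each row.
        max_len = max(max_len, int(lengths.max()))
        pos = torch.arange(max_len, device=lengths.device).expand(lengths.size(0), max_len)
        return pos < (max_len - lengths).unsqueeze(1)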
@@ -420,7 +420,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=xy_attn_mask,
         )
         x_len = x_lens.max()
-        logits = self.ar_predict_layer(xy_dec[:, x_len:])
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:])
         ###### DPO #############
         reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
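
Note: this is the core of the loss fix. The decoder state at index x_len - 1 (the last text token, pinned to that index by left padding) is the state that predicts the first semantic token; slicing from x_len discarded that prediction, so the first semantic token was never supervised. Slicing from x_len - 1 yields one extra prediction step, matching the un-shifted targets now returned by pad_y_eos (last hunk below). A toy shape check under these assumptions:

    import torch

    B, x_len, y_len, D = 2, 5, 7, 16                      # toy sizes
    xy_dec = torch.randn(B, x_len + y_len, D)             # states over [text ; y_input]
    old_steps = xy_dec[:, x_len:].shape[1]                # y_len: first prediction dropped
    new_steps = xy_dec[:, x_len - 1:].shape[1]            # y_len + 1: matches y plus EOS
    assert (old_steps, new_steps) == (y_len, y_len + 1)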
@@ -432,7 +432,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=reject_xy_attn_mask,
         )
         x_len = x_lens.max()
-        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len-1:])
         # loss
         # from feiteng: the longer the duration, the larger the gradient update should be, hence "sum"
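
Note: the DPO rejected branch receives the identical one-position shift, so reject_logits stays aligned with reject_targets under the new target convention and both branches of the preference loss see the same sequence length.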
@@ -455,7 +455,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
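
Note: the same make_pad_mask_left substitution as in the first hunk, applied to the other training forward path so both entry points use the same left-padding convention.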
@@ -502,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
             (xy_pos, None),
             mask=xy_attn_mask,
         )
-        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:]).permute(0, 2, 1)
         # loss
         # from feiteng: the longer the duration, the larger the gradient update should be, hence "sum"
         loss = F.cross_entropy(logits, targets, reduction="sum")
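
Note: F.cross_entropy expects class logits in the channel dimension, (N, C, T), hence the permute(0, 2, 1); with the longer slice, logits and targets now both span y_len + 1 steps. A toy example of the call (sizes are illustrative, not the model's real dimensions):

    import torch
    import torch.nn.functional as F

    B, T, vocab = 2, 8, 1025                                # illustrative sizes only
    logits = torch.randn(B, T, vocab).permute(0, 2, 1)      # -> (B, vocab, T)
    targets = torch.randint(0, vocab, (B, T))
    # reduction="sum": longer sequences contribute proportionally larger gradients
    loss = F.cross_entropy(logits, targets, reduction="sum")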
@@ -578,7 +578,7 @@ class Text2SemanticDecoder(nn.Module):
     def pad_y_eos(self, y, y_mask_int, eos_id):
         targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # shift by one position (teacher-forcing offset)
-        return targets[:, :-1], targets[:, 1:]
+        return targets[:, :-1], targets

     def infer_panel_batch_infer(
         self,
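
Note: pad_y_eos used to return the targets shifted by one (targets[:, 1:]), which paired with the old x_len: slice; it now returns the full, un-shifted sequence (y plus the appended EOS), pairing with the new x_len - 1: slice above. A concrete toy run, assuming eos_id = 1024 (the id is illustrative):

    import torch
    import torch.nn.functional as F

    y = torch.tensor([[7, 8, 9]])
    y_mask_int = torch.zeros_like(y)      # no padding in this toy batch
    eos_id = 1024                         # illustrative EOS id
    targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
    y_input = targets[:, :-1]             # [[7, 8, 9]]        decoder input
    old_tgt = targets[:, 1:]              # [[8, 9, 1024]]     old: token 7 never supervised
    new_tgt = targets                     # [[7, 8, 9, 1024]]  new: every token supervised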
