|
|
|
@ -402,7 +402,7 @@ class Text2SemanticDecoder(nn.Module):
|
|
|
|
|
if(idx==0):###第一次跑不能EOS否则没有了
|
|
|
|
|
logits = logits[:, :-1] ###刨除1024终止符号的概率
|
|
|
|
|
samples = sample(
|
|
|
|
|
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.05, temperature=temperature
|
|
|
|
|
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
|
|
|
|
|
)[0].unsqueeze(0)
|
|
|
|
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
|
|
|
|
print("use early stop num:", early_stop_num)
|
|
|
|
|