|
|
@ -57,7 +57,7 @@ def logits_to_probs(
|
|
|
|
logits = logits / max(temperature, 1e-5)
|
|
|
|
logits = logits / max(temperature, 1e-5)
|
|
|
|
|
|
|
|
|
|
|
|
if top_k is not None:
|
|
|
|
if top_k is not None:
|
|
|
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
|
|
|
v, _ = torch.topk(logits, top_k)
|
|
|
|
pivot = v.select(-1, -1).unsqueeze(-1)
|
|
|
|
pivot = v.select(-1, -1).unsqueeze(-1)
|
|
|
|
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
|
|
|
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
|
|
|
|
|
|
|
|
|
|
|