|
|
|
@ -677,7 +677,7 @@ class Text2SemanticDecoder(nn.Module):
|
|
|
|
|
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
|
|
|
|
for i in removed_idx_of_batch_for_y:
|
|
|
|
|
batch_index = batch_idx_map[i]
|
|
|
|
|
idx_list[batch_index] = idx - 1
|
|
|
|
|
idx_list[batch_index] = idx
|
|
|
|
|
y_list[batch_index] = y[i, :-1]
|
|
|
|
|
|
|
|
|
|
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
|
|
|
@ -857,7 +857,7 @@ class Text2SemanticDecoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
if ref_free:
|
|
|
|
|
return y[:, :-1], 0
|
|
|
|
|
return y[:, :-1], idx - 1
|
|
|
|
|
return y[:, :-1], idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_panel(
|
|
|
|
|