@@ -12,33 +12,33 @@ import torch
 def multi_head_attention_forward_patched(
-    query: Tensor,
-    key: Tensor,
-    value: Tensor,
-    embed_dim_to_check: int,
-    num_heads: int,
-    in_proj_weight: Optional[Tensor],
-    in_proj_bias: Optional[Tensor],
-    bias_k: Optional[Tensor],
-    bias_v: Optional[Tensor],
-    add_zero_attn: bool,
+    query,
+    key,
+    value,
+    embed_dim_to_check,
+    num_heads,
+    in_proj_weight,
+    in_proj_bias,
+    bias_k,
+    bias_v,
+    add_zero_attn,
     dropout_p: float,
-    out_proj_weight: Tensor,
-    out_proj_bias: Optional[Tensor],
-    training: bool = True,
-    key_padding_mask: Optional[Tensor] = None,
-    need_weights: bool = True,
-    attn_mask: Optional[Tensor] = None,
-    use_separate_proj_weight: bool = False,
-    q_proj_weight: Optional[Tensor] = None,
-    k_proj_weight: Optional[Tensor] = None,
-    v_proj_weight: Optional[Tensor] = None,
-    static_k: Optional[Tensor] = None,
-    static_v: Optional[Tensor] = None,
-    average_attn_weights: bool = True,
-    is_causal: bool = False,
+    out_proj_weight,
+    out_proj_bias,
+    training = True,
+    key_padding_mask = None,
+    need_weights = True,
+    attn_mask = None,
+    use_separate_proj_weight = False,
+    q_proj_weight = None,
+    k_proj_weight = None,
+    v_proj_weight = None,
+    static_k = None,
+    static_v = None,
+    average_attn_weights = True,
+    is_causal = False,
     cache=None,
-) -> Tuple[Tensor, Optional[Tensor]]:
+):
     r"""
     Args:
         query, key, value: map a query and a set of key-value pairs to an output.
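
For reference, this signature (minus the trailing `cache` argument, which the patched version adds) mirrors `torch.nn.functional.multi_head_attention_forward`. A minimal sketch calling the stock torch function with the same positional argument layout the diff preserves; the dimensions, weights, and values below are made up for illustration and are not part of the diff:

```python
import torch
import torch.nn.functional as F

# Toy dimensions for illustration only (not from the diff).
embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2
q = k = v = torch.randn(seq_len, batch, embed_dim)  # (L, N, E) layout

# Packed QKV projection and output projection, randomly initialized.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
out_proj_weight = torch.randn(embed_dim, embed_dim)
out_proj_bias = torch.zeros(embed_dim)

# Same positional order as the patched signature: query, key, value,
# embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k,
# bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, ...
attn_out, attn_weights = F.multi_head_attention_forward(
    q, k, v, embed_dim, num_heads,
    in_proj_weight, in_proj_bias,
    None, None, False, 0.0,
    out_proj_weight, out_proj_bias,
    training=False, need_weights=True,
)
print(attn_out.shape)  # torch.Size([5, 2, 16])
```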