fix CI bugs

Author: yuehuayingxueluo
Date: 2024-01-09 15:18:28 +08:00
Committed by: FrankLeeeee
Parent: 2a73e828eb
Commit: fab294c7f4
5 changed files with 21 additions and 9 deletions


@@ -217,6 +217,8 @@ class PagedAttention:
        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
        padding_mask = None
        if attn_mask is not None:
            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
@@ -279,11 +281,12 @@ class PagedAttention:
        if attn_weights.size() != (bsz, num_heads, 1, seq_len):
            raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
        padding_mask = None
        if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
        attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
        )
        if padding_mask is not None:
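
For context, the change replaces query_length with q_length, the variable already used in the surrounding _make_causal_mask call, when building the padding and causal masks. Below is a minimal sketch of how those two masks compose during a single decoding step. It assumes the AttentionMaskConverter referenced above is the helper from transformers.modeling_attn_mask_utils (the import is not visible in the diff); all tensor sizes are made-up illustration values, not taken from the repository.

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# Made-up sizes: one new query token per sequence, 8 tokens total in the KV cache.
bsz, q_length, seq_len = 2, 1, 8
dtype = torch.float32
device = torch.device("cpu")

# 2D padding mask from the tokenizer: 1 = real token, 0 = padding.
attn_mask = torch.ones(bsz, seq_len, dtype=torch.long)
attn_mask[1, :3] = 0  # second sequence is left-padded by three tokens

# Expand to an additive [bsz, 1, q_length, seq_len] bias; the target length must be
# q_length (the number of new query positions), which is what the fix passes.
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, dtype, q_length)

# Causal bias for the q_length new positions, offset by the tokens already cached.
causal_mask = AttentionMaskConverter._make_causal_mask(
    (bsz, q_length), dtype, device, past_key_values_length=seq_len - q_length
)

# Merge the two: padded positions are pushed to the most negative representable value,
# mirroring the "if padding_mask is not None" branch that follows in the diff.
causal_mask = causal_mask.masked_fill(padding_mask.bool(), torch.finfo(dtype).min)
print(causal_mask.shape)  # torch.Size([2, 1, 1, 8])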