Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 22:10:37 +00:00
fix CI bugs

committed by FrankLeeeee
parent 2a73e828eb
commit fab294c7f4
```diff
@@ -217,6 +217,8 @@ class PagedAttention:
         attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_size)
 
+        padding_mask = None
+
         if attn_mask is not None:
             padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, seq_len)
```
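The key change in this hunk is the unconditional `padding_mask = None` initialization: without it, a later `if padding_mask is not None:` check (visible in the second hunk below) raises `UnboundLocalError` whenever `attn_mask` is `None`, because the variable is only bound inside the `if` branch. A minimal standalone sketch of that failure mode; `broken` and `fixed` are hypothetical names for illustration, not functions from the ColossalAI codebase:

```python
# Minimal sketch of the UnboundLocalError this hunk guards against.
# `broken` and `fixed` are hypothetical names for illustration only.
def broken(attn_mask=None):
    if attn_mask is not None:
        padding_mask = attn_mask  # only bound on this branch
    # Raises UnboundLocalError when attn_mask is None:
    return padding_mask

def fixed(attn_mask=None):
    padding_mask = None  # bound unconditionally, as in the diff
    if attn_mask is not None:
        padding_mask = attn_mask
    return padding_mask

assert fixed(None) is None
# broken(None)  # UnboundLocalError: local variable 'padding_mask' referenced before assignment
```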
```diff
@@ -279,11 +281,12 @@ class PagedAttention:
         if attn_weights.size() != (bsz, num_heads, 1, seq_len):
             raise ValueError(f"Got wrong attn_weights, should be in shape {(bsz,num_heads,1,seq_len)}.")
 
+        padding_mask = None
         if attn_mask is not None:
-            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, query_length)
+            padding_mask = AttentionMaskConverter._expand_mask(attn_mask, q.dtype, q_length)
 
         attn_mask = AttentionMaskConverter._make_causal_mask(
-            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - query_length
+            (bsz, q_length), q.dtype, q.device, past_key_values_length=seq_len - q_length
         )
 
         if padding_mask is not None:
```
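For context, `_expand_mask` and `_make_causal_mask` are private static helpers on `AttentionMaskConverter` in Hugging Face transformers (`transformers.modeling_attn_mask_utils`). The sketch below shows how the two calls in this hunk fit together for a single decoding step; it assumes a transformers version whose `AttentionMaskConverter` still exposes these helpers with the signatures used in the diff (being private, they can change or disappear between releases):

```python
# Sketch of how the two private transformers helpers used in this diff
# combine a padding mask with a causal mask during decoding.
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

bsz, q_length, seq_len = 2, 1, 8  # decoding step: 1 new token, 8 total positions

# 1 = attend, 0 = padded, shape [bsz, seq_len]
attn_mask = torch.ones(bsz, seq_len, dtype=torch.long)
attn_mask[1, :3] = 0  # second sequence has 3 padding positions

# [bsz, 1, q_length, seq_len]: 0 where visible, large-negative where padded
padding_mask = AttentionMaskConverter._expand_mask(attn_mask, torch.float32, q_length)

# [bsz, 1, q_length, seq_len]: causal structure over past + current tokens
causal_mask = AttentionMaskConverter._make_causal_mask(
    (bsz, q_length), torch.float32, torch.device("cpu"),
    past_key_values_length=seq_len - q_length,
)

print(padding_mask.shape, causal_mask.shape)  # both torch.Size([2, 1, 1, 8])
```

Both results are additive float masks of shape `[bsz, 1, q_length, seq_len]`, so they can be combined elementwise before the softmax; passing the defined `q_length` instead of the undefined `query_length` is exactly what this hunk corrects.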