mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 20:23:41 +00:00
fix
This commit is contained in:
parent
b582319273
commit
827ef3ee9a
@ -195,10 +195,6 @@ class ColoAttention:
|
|||||||
b,
|
b,
|
||||||
s_kv,
|
s_kv,
|
||||||
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
|
), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
|
||||||
if memory_size < MEMORY_BOUND and not is_causal:
|
|
||||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
|
||||||
else:
|
|
||||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
|
||||||
outputs.update(
|
outputs.update(
|
||||||
{
|
{
|
||||||
"cu_seqlens_q": cu_seqlens_q,
|
"cu_seqlens_q": cu_seqlens_q,
|
||||||
@ -210,15 +206,18 @@ class ColoAttention:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_causal:
|
if is_causal:
|
||||||
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
|
||||||
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
|
||||||
if memory_size < MEMORY_BOUND:
|
if memory_size < MEMORY_BOUND:
|
||||||
if s_q != 1:
|
if s_q != 1:
|
||||||
|
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||||
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
|
||||||
else:
|
else:
|
||||||
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
attention_mask = torch.empty((0,), dtype=dtype, device=device)
|
||||||
else:
|
else:
|
||||||
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
outputs["attention_mask_type"] = AttnMaskType.PADDED
|
||||||
|
if memory_size < MEMORY_BOUND:
|
||||||
|
attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
|
||||||
|
|
||||||
if invert:
|
if invert:
|
||||||
attention_mask = invert_mask(attention_mask).unsqueeze(1)
|
attention_mask = invert_mask(attention_mask).unsqueeze(1)
|
||||||
outputs["attention_mask"] = attention_mask
|
outputs["attention_mask"] = attention_mask
|
||||||
|
Loading…
Reference in New Issue
Block a user