mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-01 20:05:27 +00:00
fix
This commit is contained in:
parent
f393867cff
commit
dc032172c3
@ -173,7 +173,7 @@ class ColoAttention:
|
|||||||
# no padding
|
# no padding
|
||||||
assert is_causal
|
assert is_causal
|
||||||
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
|
||||||
if memory_size < MEMORY_BOUND and not is_causal:
|
if memory_size < MEMORY_BOUND:
|
||||||
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
|
||||||
if s_q != 1:
|
if s_q != 1:
|
||||||
attention_mask.tril_(diagonal=0)
|
attention_mask.tril_(diagonal=0)
|
||||||
|
Loading…
Reference in New Issue
Block a user