From 827ef3ee9a176de774422f7361cc95efae57e3f1 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Sat, 14 Sep 2024 10:40:35 +0000
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 7157fbed8..2f8e4d677 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -195,10 +195,6 @@ class ColoAttention:
                 b,
                 s_kv,
             ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            if memory_size < MEMORY_BOUND and not is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
-            else:
-                attention_mask = torch.empty((0,), dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -210,15 +206,18 @@ class ColoAttention:
                 }
             )
             if is_causal:
-                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
                 if memory_size < MEMORY_BOUND:
                     if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
                         attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
                     else:
                         attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+                if memory_size < MEMORY_BOUND:
+                    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+
         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
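
Note (illustration, not part of the patch): the standalone sketch below shows the mask construction the moved lines perform, i.e. broadcasting a key/value padding mask over the query dimension and, in the causal branch, zeroing out positions above the diagonal with a lower-triangular mask. The batch size, sequence lengths, and example padding mask are made up for this illustration; it only assumes PyTorch.

    import torch

    # Illustrative shapes (not taken from the patch)
    b, s_q, s_kv = 2, 4, 4
    dtype = torch.float16

    # 1 = real token, 0 = padding; the second sequence has one padded position
    kv_padding_mask = torch.tensor(
        [[1, 1, 1, 1],
         [1, 1, 1, 0]]
    )

    # Broadcast the padding mask over the query dimension: (b, s_kv) -> (b, s_q, s_kv)
    attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype)

    # Causal branch: keep only keys j <= i for each query i
    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)

    print(attention_mask)

In the patched code this expand/tril step is only performed when the mask will actually be materialized (memory_size < MEMORY_BOUND and, in the causal case, s_q != 1); otherwise an empty tensor is used as a placeholder.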