From dc032172c34538abcdad101997b9637b70ef0552 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Fri, 13 Sep 2024 06:00:58 +0000
Subject: [PATCH] fix dead causal-mask branch in ColoAttention

The no-padding branch asserts `is_causal`, so the old condition
`memory_size < MEMORY_BOUND and not is_causal` could never be true and
the causal attention mask was never materialized there. Drop the
redundant `not is_causal` check so the mask is built whenever
`memory_size < MEMORY_BOUND`.

---
 colossalai/shardformer/layer/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 8890da242..a2ea761bf 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -173,7 +173,7 @@ class ColoAttention:
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            if memory_size < MEMORY_BOUND and not is_causal:
+            if memory_size < MEMORY_BOUND:
                 attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
                 if s_q != 1:
                     attention_mask.tril_(diagonal=0)
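
For context, a minimal standalone sketch (not ColossalAI code; the
`build_causal_mask` helper and the `MEMORY_BOUND` value below are
illustrative assumptions) of why the removed `not is_causal` check was
dead and what the branch does after this patch:

    import torch

    MEMORY_BOUND = 2**30  # illustrative threshold, not the real value

    def build_causal_mask(s_q, s_kv, memory_size, is_causal,
                          dtype=torch.float16, device="cpu"):
        # The no-padding branch only runs for causal attention.
        assert is_causal
        # Old: `if memory_size < MEMORY_BOUND and not is_causal:` was
        # always False after the assert above, so no mask was built.
        if memory_size < MEMORY_BOUND:
            mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
            if s_q != 1:
                mask.tril_(diagonal=0)  # keep only the lower triangle
            return mask
        return None  # mask too large to materialize; caller falls back

With the patch applied, a small causal mask (e.g.
`build_causal_mask(4, 4, memory_size=16, is_causal=True)`) is actually
returned instead of the branch silently doing nothing.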