diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 8934068d6..c8a311df7 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -51,7 +51,8 @@ def get_flash_core_attention_forward():
             attn_mask_type = AttnMaskType.causal
         else:
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(
             embed_dim=self.hidden_size_per_partition,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 21f063930..8f4563537 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 4bfef4529..8006bb3c0 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -465,7 +465,8 @@ def get_llama_flash_attention_forward():
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
         flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-        attn_mask_type = AttnMaskType.paddedcausal
+        if not torch.all(flash_attention_mask):
+            attn_mask_type = AttnMaskType.paddedcausal
 
     attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
     attn_output = attention(
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index e0978d38e..71f2ca335 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -581,7 +581,8 @@ def get_opt_flash_attention_forward():
                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
             )
         flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-        attn_mask_type = AttnMaskType.paddedcausal
+        if not torch.all(flash_attention_mask):
+            attn_mask_type = AttnMaskType.paddedcausal
 
     attention = ColoAttention(
         embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index ef59dbcee..9827d4801 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -106,7 +106,10 @@ def get_whisper_flash_attention_forward():
                 f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
            )
         flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-        attn_type = AttnMaskType.paddedcausal
+        if not torch.all(flash_attention_mask):
+            attn_type = AttnMaskType.paddedcausal
+        else:
+            attn_type = AttnMaskType.causal
 
     attention = ColoAttention(
         embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index 6cc73b626..bb10f7a00 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
 
 def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor = tensor.data
     tensor.div_(dist.get_world_size())
     return tensor
 
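
For context, a minimal standalone sketch (not part of the diff) of the mask-type selection that the hunks above converge on. pick_mask_type and the toy AttnMaskType enum are illustrative stand-ins, not ColossalAI's API; the real forwards keep this logic inline and pass both the mask and the mask type on to ColoAttention.

# illustrative_mask_type.py -- hypothetical helper, for explanation only
from enum import Enum

import torch


class AttnMaskType(Enum):
    causal = 1
    paddedcausal = 2


def pick_mask_type(attention_mask: torch.Tensor) -> AttnMaskType:
    # attention_mask: (bsz, 1, tgt_len, src_len) extended mask, 0 where
    # attention is allowed and a large negative value where it is masked.
    # The last query row reflects key padding only, so it tells us whether
    # the batch actually contains padded positions.
    flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
    # Any False entry means a padded key position exists, so the kernel needs
    # the paddedcausal path; otherwise the plain causal path is enough.
    if not torch.all(flash_attention_mask):
        return AttnMaskType.paddedcausal
    return AttnMaskType.causal


if __name__ == "__main__":
    no_pad = torch.zeros(2, 1, 4, 4)
    padded = no_pad.clone()
    padded[1, :, :, -1] = torch.finfo(torch.float32).min  # pad last key of sample 1
    print(pick_mask_type(no_pad))  # AttnMaskType.causal
    print(pick_mask_type(padded))  # AttnMaskType.paddedcausal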