[shardformer] hotfix attn mask (#5945)

This commit is contained in:
Hongxin Liu
2024-07-29 13:58:27 +08:00
committed by GitHub
parent c8332b9cb5
commit 9664b1bc19
4 changed files with 9 additions and 5 deletions


@@ -91,7 +91,7 @@ class MistralForwards:
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
-mask_shape = (batch_size, 1, seq_length, seq_length)
+mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
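
For context on the shape fix: with a KV cache, the query covers only the `seq_length` newly fed tokens, while each query attends to `past_key_values_length + seq_length` keys, so the mask's last dimension must include the cached length. Below is a minimal illustrative sketch of that invariant; `build_causal_mask` is a hypothetical helper for demonstration, not ColossalAI's `ColoAttention.prepare_attn_kwargs` API.

```python
import torch

def build_causal_mask(batch_size: int, seq_length: int, past_key_values_length: int) -> torch.Tensor:
    """Illustrative boolean causal mask for cached decoding (hypothetical helper).

    Queries span the `seq_length` new positions; keys span the cached positions
    plus the new ones, so the last dim is `seq_length + past_key_values_length`,
    matching the corrected `mask_shape` in this commit.
    """
    kv_length = seq_length + past_key_values_length
    # Query position i (offset by the cache length) may attend to key position j iff j <= i.
    q_pos = torch.arange(past_key_values_length, kv_length).unsqueeze(-1)  # (seq_length, 1)
    k_pos = torch.arange(kv_length).unsqueeze(0)                           # (1, kv_length)
    mask = (k_pos <= q_pos).expand(batch_size, 1, seq_length, kv_length)
    assert mask.shape == (batch_size, 1, seq_length, kv_length)
    return mask

# e.g. decoding 1 new token with 8 cached tokens:
print(build_causal_mask(2, 1, 8).shape)  # torch.Size([2, 1, 1, 9])
```

With the old shape `(batch_size, 1, seq_length, seq_length)`, the mask would cover only the new tokens and drop the cached keys, which breaks attention during incremental decoding.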