mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[shardformer] hotfix attn mask (#5945)
This commit is contained in:
@@ -91,7 +91,7 @@ class MistralForwards:
|
||||
|
||||
if shard_config.enable_flash_attention:
|
||||
# in this case, attention_mask is a dict rather than a tensor
|
||||
- mask_shape = (batch_size, 1, seq_length, seq_length)
|
||||
+ mask_shape = (batch_size, 1, seq_length, seq_length + past_key_values_length)
|
||||
attention_mask = ColoAttention.prepare_attn_kwargs(
|
||||
mask_shape,
|
||||
hidden_states.dtype,
|
||||
|
Reference in New Issue
Block a user