wangbluo
2024-10-15 11:56:49 +08:00
parent 3dc08c8a5a
commit 6be9862aaf
6 changed files with 23 additions and 6 deletions


@@ -869,6 +869,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
                 scale=scale,
                 inner_ring_size=shard_config.inner_ring_size,
                 pg_mesh=shard_config.pg_mesh,
+                sp_axis=shard_config.sp_axis,
             )
         else:
             attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
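For context, this hunk threads the new `sp_axis` field of `ShardConfig` into the ring-attention call (the branch above the shown context, which already receives `inner_ring_size` and `pg_mesh`), so the kernel knows which axis of the process-group mesh holds the sequence-parallel ranks. The snippet below is a minimal, self-contained sketch of what such an axis selects; the mesh shape, `numpy` usage, and variable names are illustrative assumptions, not ColossalAI's implementation.

```python
# Minimal sketch (not ColossalAI code): what an sp_axis index selects.
# Given a mesh of global ranks, the axis picks which mesh dimension
# forms the sequence-parallel groups that ring attention rotates KV around.
import numpy as np

world_size = 8
mesh = np.arange(world_size).reshape(2, 4)  # hypothetical axes: (dp, sp)
sp_axis = 1  # what the diff now passes via shard_config.sp_axis

# Collect the rank groups along sp_axis: each row is one ring of devices.
sp_groups = np.moveaxis(mesh, sp_axis, -1).reshape(-1, mesh.shape[sp_axis])
print(sp_groups.tolist())  # -> [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With `sp_axis=0` instead, the same mesh would yield rings `[[0, 4], [1, 5], [2, 6], [3, 7]]`; passing the axis explicitly lets one `pg_mesh` serve several parallelism dimensions without hard-coding which one is sequence-parallel.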