Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-28 19:55:29 +00:00)
fix

This commit is contained in:
parent 6aface9316
commit 698c8b9804
@@ -538,6 +538,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
 
     # sp: all-to-all comminucation when introducing sequence parallel
     if sp_mode == "all_to_all":
+        query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
         key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
         value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
         bsz, q_len, _ = query_states.size()
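
What the hunk shows: in the all_to_all sequence-parallel path, key_states and value_states were already dispatched with all_to_all_comm, but the matching call for query_states was missing; the commit adds it, so all three projections go through the same exchange before bsz and q_len are re-read. Below is a minimal sketch of the sequence<->hidden all-to-all this path relies on, written against plain torch.distributed. The toy_all_to_all helper, its tensor shapes, and the group handling are assumptions for illustration, not ColossalAI's actual all_to_all_comm implementation (which also supports FP8 communication, per the fp8_communication flag in the diff).

# Toy sketch (assumed helper, not ColossalAI's all_to_all_comm): each rank in
# the sequence-parallel group starts with [bsz, seq/sp, hidden] and ends with
# [bsz, seq, hidden/sp].
import torch
import torch.distributed as dist

def toy_all_to_all(x: torch.Tensor, group=None) -> torch.Tensor:
    # x: [bsz, seq/sp, hidden] on each of the sp ranks
    sp = dist.get_world_size(group)
    bsz, sub_seq, hidden = x.shape
    # carve the hidden dim into sp slices; slice j is sent to rank j
    send = x.reshape(bsz, sub_seq, sp, hidden // sp).permute(2, 0, 1, 3).contiguous()
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=group)
    # recv[i] is rank i's sequence chunk carrying this rank's hidden slice;
    # stacking the chunks in rank order rebuilds the full sequence
    return recv.permute(1, 0, 2, 3).reshape(bsz, sub_seq * sp, hidden // sp)

After the exchange each rank holds the full sequence length with only a slice of the hidden (head) dimension, which is why the patched code reads bsz, q_len from query_states only after all_to_all_comm has run: before the communication, the local sequence length is still q_len / sp.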