Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-17 07:26:29 +00:00)
fix
parent 686982764c
commit e891501c55
@@ -1056,8 +1056,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         assert (
             not pp_style == "zbv" or scheduler_nodes is not None
         ), f"scheduler_nodes must not be None when using zero bubble pipeline."
-        if sp_size is None or sp_size <= 1:
-            enable_sequence_parallelism = False
+        # if sp_size is None or sp_size <= 1:
+        #     enable_sequence_parallelism = False
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
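Note: this hunk removes the guard that silently disabled sequence parallelism whenever sp_size was unset or 1. A minimal standalone sketch of the resulting control flow (illustration only, not ColossalAI code; the helper name resolve_sp_mode is made up):

# Hypothetical sketch of the changed logic in HybridParallelPlugin.__init__.
def resolve_sp_mode(enable_sequence_parallelism, sequence_parallelism_mode, sp_size):
    # Guard removed by this commit:
    # if sp_size is None or sp_size <= 1:
    #     enable_sequence_parallelism = False
    if enable_sequence_parallelism:
        # Mode still defaults to "all_to_all" when not specified, as in the hunk above.
        return sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
    return None

# With the guard removed, an explicit request for sequence parallelism is kept
# even when sp_size is 1; previously it would have resolved to None here.
assert resolve_sp_mode(True, None, sp_size=1) == "all_to_all"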
@@ -607,7 +607,8 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
                 attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
             )
         else:
-            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            # attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
 
         attn_output = self.o_proj(attn_output)
 
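Note: the attention output is now flattened with explicit bsz and q_len instead of the tracked input_shape. A small shape sketch with made-up dimensions; the (bsz, q_len, num_heads, head_dim) kernel output layout is an assumption, not shown in this diff:

import torch

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
# Assumed attention-kernel output layout: (bsz, q_len, num_heads, head_dim).
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# New form from the hunk above: flatten the head dimensions explicitly.
merged = attn_output.reshape(bsz, q_len, -1).contiguous()
assert merged.shape == (bsz, q_len, num_heads * head_dim)  # ready for self.o_proj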
@@ -162,9 +162,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         # Double Ring Attention
         {
-            "tp_size": 2,
+            "tp_size": 1,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 4,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring_attn",
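Note: the double ring attention test case now uses tp_size=1 with sp_size=4, spending the whole 4-rank group on sequence parallelism. A simplified sketch of what sp_size=4 means for the sequence dimension (a plain contiguous split; the actual ring-attention implementation may use a load-balanced split):

import torch

sp_size = 4
hidden_states = torch.randn(1, 128, 64)        # (batch, seq_len, hidden)
shards = hidden_states.chunk(sp_size, dim=1)   # one shard per sequence-parallel rank
assert len(shards) == sp_size
assert shards[0].shape == (1, 128 // sp_size, 64)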
@@ -226,12 +226,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "initial_scale": 1,
         },
         {
-            "tp_size": 1,
+            "tp_size": 2,
             "pp_size": 1,
-            "sp_size": 2,
+            "sp_size": 1,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
+            "sequence_parallelism_mode": "ring",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 2,
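Note: this test case switches to tp_size=2, sp_size=1 with the "ring" sequence-parallelism mode. A hedged sketch of how such a test config is typically turned into a plugin; popping use_lazy_init before construction and the need for an already initialized distributed environment are assumptions about the test harness, not part of this diff:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    "tp_size": 2,
    "pp_size": 1,
    "sp_size": 1,
    "num_microbatches": 1,
    "enable_sequence_parallelism": True,
    "sequence_parallelism_mode": "ring",
    "enable_flash_attention": True,
    "use_lazy_init": True,
    "zero_stage": 2,
    "initial_scale": 1,
}

config = dict(test_config)
use_lazy_init = config.pop("use_lazy_init", False)  # assumed: lazy init is handled outside the plugin
plugin = HybridParallelPlugin(**config)             # assumes colossalai.launch(...) has set up the process group
booster = Booster(plugin=plugin)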