Support 4d parallel + flash attention (#5789)
* support tp + sp + pp

* remove comments

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
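The configs added below combine tensor (tp_size), pipeline (pp_size), sequence (sp_size), and ZeRO data parallelism, the "4d parallel" of the title, with flash attention. As a rough sketch of how one of these test dicts maps onto the user-facing API, here is the "Ulysess + Flash attention" config expressed as a plugin instantiation; this assumes, as the shardformer test harness broadly does, that the keys pass straight through to HybridParallelPlugin:

# Minimal sketch, not code from this commit: assumes HybridParallelPlugin
# accepts these keys unchanged and that a distributed launcher (e.g. torchrun)
# has already initialized the process group.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,                               # tensor parallel degree
    pp_size=2,                               # pipeline parallel degree
    sp_size=2,                               # sequence parallel degree
    num_microbatches=2,                      # pipeline microbatches
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",  # Ulysses-style sequence parallelism
    enable_flash_attention=True,
    zero_stage=1,                            # ZeRO-1 over the data-parallel group
    precision="fp16",
    initial_scale=1,                         # initial loss scale for fp16
)
booster = Booster(plugin=plugin)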
@@ -120,9 +120,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 1e-4, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_weight(
-            llama_model, shard_llama_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
-        )
+        try:
+            check_weight(
+                llama_model,
+                shard_llama_model,
+                col_layer_for_check,
+                tp_group,
+                atol=atol,
+                rtol=rtol,
+                dim=1,
+                verbose=False,
+            )
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     # check grads
     check_all_grad_tensors(grads_to_check)
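check_weight compares each column-parallel weight of the sharded model against the unsharded baseline, where dim=1 is the dimension the tensor-parallel group splits. Conceptually the check amounts to the following simplified sketch (not ColossalAI's actual check_weight; gather_and_compare is a made-up helper):

import torch
import torch.distributed as dist

def gather_and_compare(full_weight, shard, tp_group, dim=1, atol=5e-3, rtol=5e-3):
    # One shard per tensor-parallel rank.
    shards = [torch.empty_like(shard) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, shard.contiguous(), group=tp_group)
    # Reassemble along the sharded dimension (dim=1 for column-parallel layers).
    merged = torch.cat(shards, dim=dim)
    assert torch.allclose(full_weight, merged, atol=atol, rtol=rtol)

The try/except added here does not change the assertion; it only prints the offending test_config before re-raising, so a failure in the parameterized sweep identifies which parallel layout broke.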
@@ -133,9 +144,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {
+        {  # Test ring + Flash attention
             "tp_size": 2,
             "pp_size": 1,
+            "sp_size": 2,
             "num_microbatches": 1,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "ring",
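For orientation, sequence_parallelism_mode selects the algorithm: roughly, "ring" is ring attention, where each rank holds a chunk of the sequence and key/value blocks circulate around the group; "all_to_all" is the Ulysses scheme, which trades sequence sharding for attention-head sharding via an all-to-all; and "split_gather" is Megatron-style sequence parallelism that splits activations along the sequence dimension and gathers them around attention. The diff itself only toggles the modes.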
@@ -145,14 +157,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
+            "pp_size": 2,
+            "sp_size": 2,
+            "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "sequence_parallelism_mode": "all_to_all",
+            "enable_flash_attention": True,
             "use_lazy_init": True,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
@@ -164,7 +178,18 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
             "use_lazy_init": True,
-            "zero_stage": 2,
+            "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "num_microbatches": 1,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",
+            "enable_flash_attention": False,
+            "use_lazy_init": True,
+            "precision": "fp16",
+            "initial_scale": 1,
+        },
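Each of these dicts is one point in the sweep driven by @parameterize. Unlike pytest's parametrize, colossalai.testing.parameterize simply re-invokes the decorated function once per value; a rough stand-in, simplified and not the library's implementation, makes the control flow concrete:

import functools

def parameterize(argument, values):
    # Simplified stand-in: call the wrapped test once per config,
    # binding each dict to the named keyword argument.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, argument: value})
        return wrapper
    return decorator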
@@ -213,7 +238,11 @@ def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
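The pattern is the same as in check_forward_backward above: without the wrapper, a parameterized sweep reports an assertion failure but not which of the many configs produced it. If the repetition grew, the same intent could be factored into a helper such as this hypothetical context manager (report_failed_config is not part of the codebase):

from contextlib import contextmanager

@contextmanager
def report_failed_config(test_config):
    # Tag any failure with the parallel config that produced it, then re-raise.
    try:
        yield
    except Exception:
        print(f"Failed config: {test_config}")
        raise

# usage:
#   with report_failed_config(test_config):
#       check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)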
@@ -263,7 +292,11 @@ def run_llama_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()