mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 03:43:01 +00:00
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when using sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megatron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatibility * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatibility * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loosen tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 
2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several models * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
261 lines
8.2 KiB
Python
261 lines
8.2 KiB
Python
import pytest
|
|
import torch
|
|
|
|
import colossalai
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.shardformer.layer.utils import Randomizer
|
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
|
from tests.kit.model_zoo import model_zoo
|
|
from tests.test_shardformer.test_model._utils import (
|
|
build_model_from_hybrid_plugin,
|
|
check_all_grad_tensors,
|
|
check_loss,
|
|
check_output_hidden_state,
|
|
check_weight,
|
|
get_grad_tensors_for_check,
|
|
run_forward_backward_with_hybrid_plugin,
|
|
unwrap_model,
|
|
)
|
|
|
|
|
|
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
    """Compare one training step of an unsharded BERT against its hybrid-parallel sharded copy.

    Both models are built by ``build_model_from_hybrid_plugin`` and driven through
    ``run_forward_backward_with_hybrid_plugin``. Selected gradients are snapshotted before the
    optimizer step (first pipeline stage with ``zero_stage == 0`` only). Both optimizers then
    step. Finally the hidden state / loss (last stage), selected weights (first stage) and the
    snapshotted gradients are compared with precision-dependent tolerances.

    Args:
        model_fn: Model factory, forwarded to ``build_model_from_hybrid_plugin``.
        data_gen_fn: Forwarded to ``run_forward_backward_with_hybrid_plugin``.
        output_transform_fn: Forwarded to ``run_forward_backward_with_hybrid_plugin``.
        loss_fn: Forwarded to ``build_model_from_hybrid_plugin``.
        test_config: Plugin/test settings; must contain ``"precision"``. The caller's dict is
            never handed to the helpers directly -- they receive a shallow copy, so the caller
            can safely reuse the same dict for several models.
    """
    # NOTE(review): build_model_from_hybrid_plugin may consume keys (e.g. "use_lazy_init") from the
    # dict it receives -- confirm in tests/test_shardformer/test_model/_utils.py. The run_* drivers
    # pass the *same* dict for every model in the registry, so a consumed key would silently be
    # missing for all but the first model. Work on a private shallow copy instead.
    test_config = dict(test_config)
    # Read precision up front, before any helper has a chance to touch the dict.
    is_fp32 = test_config["precision"] == "fp32"
    # (atol, rtol) used for every comparison in non-fp32 runs.
    half_precision_tol = (5e-3, 5e-3)

    org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
        model_fn, loss_fn, test_config
    )

    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
    )

    stage_manager = booster.plugin.stage_manager
    tp_group = booster.plugin.tp_group

    # Unwrap both models to their "BertModel" part (attribute "bert") so the layer paths below apply to both.
    bert = unwrap_model(org_model, "BertModel", "bert")
    sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")

    norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
    col_layer_for_check = ["encoder.layer[0].output.dense"]
    row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
    weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
    atol, rtol = (1e-4, 1e-3) if is_fp32 else half_precision_tol
    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
        # `dim` tells the helper which axis to use for the sharded tensor
        # (presumably its tensor-parallel split axis -- see _utils).
        col_layer_grads = get_grad_tensors_for_check(
            bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )
        row_layer_grads = get_grad_tensors_for_check(
            bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
        )
        norm_layer_grads = get_grad_tensors_for_check(
            bert, sharded_bert, norm_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
        )

        grads_to_check.update(col_layer_grads)
        grads_to_check.update(row_layer_grads)
        grads_to_check.update(norm_layer_grads)

    # optimizer executes step
    org_optimizer.step()
    sharded_optimizer.step()

    # check last hidden state & loss
    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
        atol, rtol = (1e-5, 1e-3) if is_fp32 else half_precision_tol
        # Hidden states are only compared when the model under test is the bare BertModel.
        if org_model.__class__.__name__ == "BertModel":
            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    atol, rtol = (5e-3, 1e-3) if is_fp32 else half_precision_tol
    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()
|
|
|
|
|
|
@parameterize(
    "test_config",
    [
        # tp=4 with ring-style sequence parallelism, fp32
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp32",
            "initial_scale": 1,
        },
        # tp=4 with split/gather sequence parallelism, fp16
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp32",
        },
        # pure pipeline parallelism
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 4,
            "use_lazy_init": True,
            "precision": "fp32",
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": False,
            "precision": "fp32",
        },
        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
        # tp=2 combined with ZeRO stage 2
        {
            "tp_size": 2,
            "pp_size": 1,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        # pp=2 combined with ZeRO stage 1
        {
            "tp_size": 1,
            "pp_size": 2,
            "num_microbatches": 2,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
    ],
)
def run_bert_test(test_config):
    """Run ``check_forward_backward`` for every model in the ``transformers_bert`` registry.

    Called once per entry of the ``parameterize`` list above. Once every model has been checked,
    the layout converter and Randomizer index are reset and cached CUDA memory is released.
    """
    bert_registry = model_zoo.get_sub_registry("transformers_bert")
    for _model_name, zoo_entry in bert_registry.items():
        model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = zoo_entry
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    # Reset global state before the next configuration runs.
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
|
|
|
|
|
|
@parameterize(
    "test_config",
    [
        # tp=2 x pp=2, fp32, no lazy init
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
        },
        # tp=2 x pp=2, fp16 with ZeRO stage 1
        {
            "tp_size": 2,
            "pp_size": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
        },
        # interleaved pipeline schedule with two model chunks per stage
        {
            "tp_size": 2,
            "pp_size": 2,
            "pp_style": "interleaved",
            "num_model_chunks": 2,
            "num_microbatches": 4,
            "enable_all_optimization": False,
            "precision": "fp16",
            "zero_stage": 1,
            "initial_scale": 1,
        },
    ],
)
def run_bert_3d_test(test_config):
    """Run ``check_forward_backward`` on every ``transformers_bert`` model under a 3D-parallel config.

    Called once per entry of the ``parameterize`` list above. Once every model has been checked,
    the layout converter and Randomizer index are reset and cached CUDA memory is released.
    """
    bert_registry = model_zoo.get_sub_registry("transformers_bert")

    for _model_name, zoo_entry in bert_registry.items():
        model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = zoo_entry
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

    # Reset global state before the next configuration runs.
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
|
|
|
|
|
|
def check_bert(rank, world_size, port):
    """Per-process body spawned by ``test_bert``: launch colossalai over NCCL, then run ``run_bert_test``."""
    disable_existing_loggers()
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    run_bert_test()
|
|
|
|
|
|
def check_bert_3d(rank, world_size, port):
    """Per-process body spawned by ``test_bert_3d``: launch colossalai over NCCL, then run ``run_bert_3d_test``."""
    disable_existing_loggers()
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    run_bert_3d_test()
|
|
|
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
    """Spawn four worker processes, each executing ``check_bert``."""
    num_workers = 4
    spawn(check_bert, num_workers)
|
|
|
|
|
|
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert_3d():
    """Spawn eight worker processes, each executing ``check_bert_3d``."""
    num_workers = 8
    spawn(check_bert_3d, num_workers)
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, execute both distributed suites in order.
    for run_suite in (test_bert, test_bert_3d):
        run_suite()
|