Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-31 08:34:14 +00:00
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up the construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
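The DimSpec entry (#5446) describes a caching pattern. The difference table is identical for every DimSpec, so it is built once and shared instead of being rebuilt (or deep-copied) for each object. The following is a minimal sketch of that pattern. `DimSpecSketch`, its spec strings, and its toy cost (the number of mesh axes on which two specs differ) are hypothetical and do not reproduce ColossalAI's actual `DimSpec` or its cost values.

```python
from typing import ClassVar, Dict, FrozenSet, Optional, Tuple


def _axes(spec: str) -> FrozenSet[str]:
    """Mesh axes a toy spec is sharded over: "R" -> none, "S01" -> {"0", "1"}."""
    return frozenset() if spec == "R" else frozenset(spec[1:])


class DimSpecSketch:
    """Hypothetical stand-in for a sharding dim spec (not ColossalAI's DimSpec)."""

    SPECS: ClassVar[Tuple[str, ...]] = ("R", "S0", "S1", "S01")
    _difference_dict: ClassVar[Optional[Dict[Tuple[str, str], int]]] = None
    build_count: ClassVar[int] = 0  # how many times the shared table was built

    def __init__(self, spec: str) -> None:
        if spec not in self.SPECS:
            raise ValueError(f"unknown spec {spec!r}")
        self.spec = spec
        # Build the table for the first instance only; later instances reuse it.
        if DimSpecSketch._difference_dict is None:
            DimSpecSketch._difference_dict = DimSpecSketch._build_difference_dict()

    @classmethod
    def _build_difference_dict(cls) -> Dict[Tuple[str, str], int]:
        cls.build_count += 1
        # Toy cost: number of mesh axes on which the two specs disagree.
        return {(a, b): len(_axes(a) ^ _axes(b)) for a in cls.SPECS for b in cls.SPECS}

    def difference(self, other: "DimSpecSketch") -> int:
        return DimSpecSketch._difference_dict[(self.spec, other.spec)]
```

The speedup comes from storing the table on the class rather than on each instance. The trade-off is that the shared table must never be mutated through a single instance.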
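The FP8 entries above (`cast_to_fp8`, `cast_from_fp8`, `all_reduce_fp8`, and the scaling fix) depend on per-tensor scaling. Before the cast, values are multiplied by a scale that maps the largest magnitude onto the top of the FP8 range, and after the cast they are divided by that same scale. Below is a minimal pure-Python sketch of the idea. It assumes the E4M3 format (3 mantissa bits, largest finite value 448) and saturates on overflow. The `*_sketch` helpers and `quantize_e4m3` are hypothetical: they simulate FP8 rounding with ordinary floats and do not reproduce ColossalAI's kernels. The all-reduce is likewise simulated as a plain sum over ranks.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite E4M3 value
E4M3_MANTISSA_BITS = 3
E4M3_MIN_NORMAL_EXP = -6  # below this exponent, values are subnormal


def quantize_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    _, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    exp = max(e - 1, E4M3_MIN_NORMAL_EXP)  # unbiased exponent, clamped for subnormals
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)  # spacing of representable values
    q = round(x / step) * step
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def cast_to_fp8_sketch(values: List[float]) -> Tuple[List[float], float]:
    """Scale so the largest magnitude maps to E4M3_MAX, then quantize."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [quantize_e4m3(v * scale) for v in values], scale


def cast_from_fp8_sketch(q: List[float], scale: float) -> List[float]:
    """Undo the per-tensor scaling."""
    return [v / scale for v in q]


def all_reduce_fp8_sketch(per_rank: List[List[float]]) -> List[float]:
    """Simulated sum all-reduce: each rank ships an FP8 payload plus its scale."""
    total = [0.0] * len(per_rank[0])
    for values in per_rank:
        q, scale = cast_to_fp8_sketch(values)
        for i, v in enumerate(cast_from_fp8_sketch(q, scale)):
            total[i] += v
    return total
```

Because each rank sends its scale along with the 8-bit payload, the receiver can undo the scaling. Only mantissa precision is lost: at most a relative error of 2**-4 for values that land in the normal range.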
191 lines
6.3 KiB
Python
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1

CHECKED_CONFIG = [  # FOR WORLD=4
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    dp_size = plugin.dp_size

    booster = Booster(plugin=plugin)

    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )

    # init model with the same seed
    seed_all(10086)

    torch_model = MixtralModel(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)

    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # generate random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # inputs on pp stages other than the first are unused, but must be replicated for the torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # replicate input across the tp group
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # replicate input across the sp group

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x.last_hidden_state.mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )
            if booster.plugin.stage_manager.is_last_stage():
                parallel_output = sharded_output["loss"]
            else:
                # placeholder; overwritten by the broadcast below
                parallel_output = torch.tensor(12345.0, device="cuda")

            # broadcast along pp axis
            dist.broadcast(
                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
            )
        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run the plain torch model on every (different) dp input
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # average grads over dp to match the zero optimizer
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_mixtral"
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()
    booster.save_model(parallel_model, model_dir, shard=True)
    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if rank == world_size - 1:
        shutil.rmtree(model_dir)

    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mixtral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral(world_size=4)
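The test unpacks each config tuple as `(stage, ep_size, pp_size, tp_size, sp_size)` and then reads `plugin.dp_size`. That value determines how many inputs the reference loop gathers. Below is a sketch of the arithmetic, built on two assumptions that the file does not state. First, data parallelism takes whatever remains of `world_size` after `pp_size * tp_size * sp_size`. Second, expert parallelism shards experts across those data-parallel ranks, so `ep_size` must divide `dp_size`. `derive_dp_size` and the `moe_dp_size` key are hypothetical names, not the `MoeHybridParallelPlugin` API.

```python
from typing import Dict, Tuple

# Same tuple layout the test unpacks: (stage, ep_size, pp_size, tp_size, sp_size)
ParallelConfig = Tuple[int, int, int, int, int]


def derive_dp_size(config: ParallelConfig, world_size: int) -> Dict[str, int]:
    """Hypothetical helper: split world_size into the axes used by the test.

    Assumes data parallelism takes whatever remains after pp x tp x sp, and
    that expert parallelism shards experts across those data-parallel ranks.
    """
    stage, ep_size, pp_size, tp_size, sp_size = config
    non_dp = pp_size * tp_size * sp_size
    if world_size % non_dp != 0:
        raise ValueError(f"world_size={world_size} is not divisible by pp*tp*sp={non_dp}")
    dp_size = world_size // non_dp
    if dp_size % ep_size != 0:
        raise ValueError(f"ep_size={ep_size} must divide dp_size={dp_size}")
    # moe_dp_size: hypothetical name for the dp replicas of each expert shard
    return {"zero_stage": stage, "dp_size": dp_size, "moe_dp_size": dp_size // ep_size}


# Each of the three parameterized configs yields dp_size=2 on 4 ranks.
for cfg in [(1, 2, 2, 1, 1), (1, 2, 1, 2, 1), (1, 2, 1, 1, 2)]:
    print(cfg, derive_dp_size(cfg, world_size=4))
```

Under these assumptions, every entry in `CHECKED_CONFIG` also decomposes cleanly on 4 ranks.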
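The reference path in `run_zero_with_original_model` works in three steps. It runs the plain model on every gathered dp input and lets `.backward()` accumulate gradients. It then divides those gradients by `dp_size`. Finally, it compares the summed losses with the dp all-reduced `parallel_output`. The pure-Python sketch below checks the identity this comparison relies on, using a hypothetical toy model (loss = 0.5 * (w . x)^2) rather than Mixtral. Averaging per-rank gradients, which the test's comment says the ZeRO optimizer does, gives the same result as accumulating over all shards and dividing by `dp_size`, and both equal the gradient of the mean loss.

```python
from typing import List, Sequence, Tuple


def toy_loss_and_grad(w: Sequence[float], x: Sequence[float]) -> Tuple[float, List[float]]:
    """Toy model: loss = 0.5 * (w . x)^2, so dloss/dw_j = (w . x) * x_j."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    return 0.5 * z * z, [z * xi for xi in x]


def zero_style_grads(w: Sequence[float], shards: List[List[float]]) -> List[float]:
    """Each simulated dp rank differentiates its own shard; grads are then averaged."""
    per_rank = [toy_loss_and_grad(w, x)[1] for x in shards]
    return [sum(g[j] for g in per_rank) / len(shards) for j in range(len(w))]


def reference_grads(w: Sequence[float], shards: List[List[float]]) -> Tuple[float, List[float]]:
    """One process, as in the test: backward on every shard, then divide by dp_size."""
    acc = [0.0] * len(w)
    loss_sum = 0.0  # plays the role of torch_output_sum
    for x in shards:
        loss, g = toy_loss_and_grad(w, x)
        loss_sum += loss
        acc = [a + gj for a, gj in zip(acc, g)]
    return loss_sum, [a / len(shards) for a in acc]


def finite_diff_grad_of_mean_loss(w: Sequence[float], shards: List[List[float]], eps: float = 1e-6) -> List[float]:
    """Central differences of the mean loss across shards."""

    def mean_loss(v: Sequence[float]) -> float:
        return sum(toy_loss_and_grad(v, x)[0] for x in shards) / len(shards)

    grads = []
    for j in range(len(w)):
        plus, minus = list(w), list(w)
        plus[j] += eps
        minus[j] -= eps
        grads.append((mean_loss(plus) - mean_loss(minus)) / (2 * eps))
    return grads


w = [0.3, -0.2, 0.5]
shards = [[1.0, 2.0, 0.5], [0.1, -1.0, 3.0]]  # dp_size = 2, one input per simulated rank
loss_sum, ref = reference_grads(w, shards)
```

Because the test's loss is a mean over the batch, this identity holds only if the parallel side averages gradients, not sums them, across dp ranks. That is why the reference model divides by `dp_size`.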