mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-23 22:19:47 +00:00)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
215 lines
6.2 KiB
Python
import copy

import pytest
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(123, 253)
        self.linear_drop = nn.Linear(253, 253)
        self.linear2 = nn.Linear(253, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


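The tolerances picked in `loose_close` are handed to `torch.testing.assert_close`, which accepts a pair of values when |actual - expected| <= atol + rtol * |expected|. The following is a minimal pure-Python sketch of that rule for scalars. The names `is_close`, `FP16_TOL` and `BF16_TOL` are hypothetical and introduced only for illustration, and the sketch skips the cast to `dtype` that `loose_close` performs before comparing.

```python
def is_close(actual, expected, rtol, atol):
    """Scalar version of the closeness rule: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Tolerances copied from loose_close above; the dict names are hypothetical.
FP16_TOL = {"rtol": 5e-2, "atol": 5e-4}
BF16_TOL = {"rtol": 4e-3, "atol": 4e-3}

# A 4% relative error passes the fp16 tolerances (0.04 <= 0.0505) ...
print(is_close(1.04, 1.0, **FP16_TOL))  # True
# ... but fails the bf16 tolerances used here (0.04 > 0.008).
print(is_close(1.04, 1.0, **BF16_TOL))  # False
```

Note that the fp16 branch is loose on the relative term and tight on the absolute term, so values near zero are held to a stricter standard than values near one.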
def split_ddp_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


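The padding arithmetic in `split_ddp_grad` can be checked without a GPU. The sketch below mirrors the same formula on a plain Python list; `pad_and_split` is a hypothetical helper written for this illustration, not part of ColossalAI, and it pads with zeros as `torch.nn.functional.pad` does by default.

```python
def pad_and_split(values, world_size):
    """Zero-pad a flat list so its length divides evenly, then cut one shard per rank."""
    # Same formula as split_ddp_grad: 0 when already divisible, otherwise the shortfall.
    padding_size = (world_size - len(values) % world_size) % world_size
    padded = list(values) + [0.0] * padding_size
    shard_len = len(padded) // world_size
    return [padded[i * shard_len:(i + 1) * shard_len] for i in range(world_size)]


# Five elements over two ranks: one zero is appended, giving two shards of three.
print(pad_and_split([1.0, 2.0, 3.0, 4.0, 5.0], world_size=2))
# [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
```

The outer `% world_size` is what keeps the padding at zero, rather than a full extra shard, when the length is already a multiple of `world_size`.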
@parameterize("fp8_communication", [True, False])
def exam_zero_1_2(fp8_communication: bool):
    """
    In this test, we check whether zero stage 1 and 2
    deliver the same numerical results despite their different
    communication patterns.

    The variable prefixes differentiate the zero stages:
    zero1: partition optimizer states
    zero2: partition gradients and optimizer states
    """
    local_rank = torch.distributed.get_rank()
    seed_all(2001)

    # create model
    zero1_model = MlpModel().cuda()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
    zero1_optimizer = LowLevelZeroOptimizer(
        zero1_optimizer,
        overlap_communication=True,
        initial_scale=128,
        verbose=True,
        fp8_communication=fp8_communication,
    )
    zero2_optimizer = LowLevelZeroOptimizer(
        zero2_optimizer,
        overlap_communication=True,
        partition_grad=True,
        initial_scale=128,
        fp8_communication=fp8_communication,
    )
    # create data
    seed_all(2001 + local_rank)
    input_data = torch.randn(32, 123).cuda()

    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    assert torch.equal(zero1_output, zero2_output)

    # zero-dp backward
    zero1_optimizer.backward(zero1_output.mean().float())
    zero2_optimizer.backward(zero2_output.mean().float())

    # check grad
    for p1, p2 in zip(zero1_model.parameters(), zero2_model.parameters()):
        g1 = zero1_optimizer.get_param_grad(p1)
        g2 = zero2_optimizer.get_param_grad(p2)
        if g1 is None or g2 is None:
            assert g1 is None and g2 is None
            continue
        if fp8_communication:
            loose_close(g1, g2, dtype=torch.float16)
        else:
            assert torch.allclose(g1, g2)

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        if not fp8_communication:
            assert torch.allclose(z1p, z2p)


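The claim tested by `exam_zero_1_2`, that the two stages agree numerically even though they communicate differently, can be pictured with a toy single-process model. This sketch assumes stage 1 reduces the full gradient and then keeps its own slice, while stage 2 reduces only its own slice (reduce-scatter style). That split is an assumption made for illustration and is not taken from the `LowLevelZeroOptimizer` source; every function name below is hypothetical.

```python
def mean_across_ranks(per_rank_values):
    """Element-wise mean over ranks, standing in for an averaging collective."""
    world_size = len(per_rank_values)
    return [sum(vals) / world_size for vals in zip(*per_rank_values)]


def shard(values, rank, world_size):
    """Return the contiguous slice of `values` owned by `rank`."""
    n = len(values) // world_size
    return values[rank * n:(rank + 1) * n]


def stage1_local_grad(per_rank_grads, rank):
    # Assumed stage-1 pattern: reduce everything, then slice.
    world_size = len(per_rank_grads)
    return shard(mean_across_ranks(per_rank_grads), rank, world_size)


def stage2_local_grad(per_rank_grads, rank):
    # Assumed stage-2 pattern: slice first, then reduce only the owned part.
    world_size = len(per_rank_grads)
    return mean_across_ranks([shard(g, rank, world_size) for g in per_rank_grads])


# Two ranks, each holding a different local gradient of four elements.
grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 6.0, 5.0, 8.0]]
for rank in range(2):
    assert stage1_local_grad(grads, rank) == stage2_local_grad(grads, rank)
print(stage1_local_grad(grads, 0), stage1_local_grad(grads, 1))  # [2.0, 4.0] [4.0, 6.0]
```

In this toy model the two orders of operations give identical shards, which is the property the real test asserts with `torch.allclose` (or with looser tolerances when `fp8_communication` is enabled).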
@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and parameters cast to `dtype`
    2. torch: use torch DDP and parameters cast to the same `dtype`

    We feed these two sets of models with the same input and check if the
    differences in model output and updated parameters are within tolerance.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda().to(dtype)
    zero_model = copy.deepcopy(torch_model).to(dtype)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer,
        overlap_communication=True,
        initial_scale=1,
        reduce_bucket_size=1024 * 1024,
        master_weights=master_weights,
    )

    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)

    for _ in range(2):
        # create
        input_data = torch.rand(32, 123).cuda().to(dtype)

        # zero-dp forward
        zero_output = zero_model(input_data)

        # torch-ddp forward
        torch_output = torch_model(input_data)
        loose_close(zero_output, torch_output, dtype=dtype)

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean())

        # torch-ddp backward
        torch_output.mean().backward()

        # check grad
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            zero_grad = zero_optimizer.get_param_grad(z1p)
            if p.grad is None:
                assert zero_grad is None
                continue
            loose_close(p.grad, zero_grad, dtype=dtype)

        # zero-dp step
        zero_optimizer.step()

        # torch ddp step
        torch_optimizer.step()

        zero_optimizer._force_wait_all_gather()

        # check updated param
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            loose_close(p, z1p, dtype=dtype)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")

    exam_zero_1_torch_ddp(world_size=world_size)
    exam_zero_1_2()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_zero_1_2()