mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
* [moe] removed openmoe-coupled code and rectify mixstral code (#5471) * [Feauture] MoE refractor; Intergration with Mixtral (#5682) * cherry pick from refractor-moe branch * tests passed * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * support ep + zero --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * add mixtral auto policy & move pipeline forward code to modeling folder * [moe refactor] modify kernel test without Route Class * [moe refactor] add moe tensor test path environment variable to github workflow * fix typos * fix moe test bug due to the code rebase * [moe refactor] fix moe zero test, and little bug in low level zero * fix typo * add moe tensor path to github workflow * remove some useless code * fix typo & unify global variable XX_AXIS logic without using -1 * fix typo & prettifier the code * remove print code & support zero 2 test * remove useless code * reanme function * fix typo * fix typo * Further improve the test code * remove print code * [moe refactor] change test model from fake moe model to mixtral moe layer and remove useless test * [moe refactor] skip some unit test which will be refactored later * [moe refactor] fix unit import error * [moe refactor] fix circular import issues * [moe refactor] remove debug code * [moe refactor] update github workflow * [moe/zero] refactor low level optimizer (#5767) * [zero] refactor low level optimizer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] MoE refactor with newest version of ZeRO (#5801) * [zero] remove redundant members in BucketStore (#5802) * [zero] align api with previous version * [Moe/Zero] Update MoeHybridParallelPlugin with refactored ZeRO and Fix Zero bug (#5819) * [moe 
refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * [hotfix]Solve the compatibility issue of zero refactor (#5823) * [moe refactor] update unit test with the refactored ZeRO and remove useless test * move moe checkpoint to checkpoint folder and exchange global axis to class member * update moe hybrid parallel plugin with newest version of zero & fix zero working/master params bug * fix zero unit test * Add an assertion to prevent users from using it incorrectly * Modify function parameter names to resolve compatibility issues * [zero] fix missing hook removal (#5824) * [MoE] Resolve .github conflict (#5829) * [Fix/Example] Fix Llama Inference Loading Data Type (#5763) * [fix/example] fix llama inference loading dtype * revise loading dtype of benchmark llama3 * [release] update version (#5752) * [release] update version * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [devops] update compatibility test * [test] fix ddp plugin test * [test] fix gptj and rpc test * [devops] fix cuda ext compatibility * [inference] fix flash decoding test * [inference] fix flash decoding test * fix (#5765) * [test] Fix/fix testcase (#5770) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [Hotfix] Add missing init file in inference.executor (#5774) * [CI/tests] simplify some test case to reduce testing time (#5755) * [ci/tests] simplify some test case to reduce testing time * [ci/tests] continue to remove test case to reduce ci time cost * restore some test config * [ci/tests] continue to reduce ci time cost * [misc] update dockerfile 
(#5776) * [misc] update dockerfile * [misc] update dockerfile * [devops] fix docker ci (#5780) * [Inference]Add Streaming LLM (#5745) * Add Streaming LLM * add some parameters to llama_generation.py * verify streamingllm config * add test_streamingllm.py * modified according to the opinions of review * add Citation * change _block_tables tolist * [hotfix] fix llama flash attention forward (#5777) * [misc] Accelerate CI for zero and dist optim (#5758) * remove fp16 from lamb * remove d2h copy in checking states --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [Test/CI] remove test cases to reduce CI duration (#5753) * [test] smaller gpt2 test case * [test] reduce test cases: tests/test_zero/test_gemini/test_zeroddp_state_dict.py * [test] reduce test cases: tests/test_zero/test_gemini/test_grad_accum.py * [test] reduce test cases tests/test_zero/test_gemini/test_optim.py * Revert "[test] smaller gpt2 test case" Some tests might depend on the size of model (num of chunks) This reverts commitdf705a5210
. * [test] reduce test cases: tests/test_checkpoint_io/test_gemini_checkpoint_io.py * [CI] smaller test model for two mwo the two modifid cases * [CI] hardcode gpt model for tests/test_zero/test_gemini/test_search.py since we need a fixed answer there * [hotfix] fix testcase in test_fx/test_tracer (#5779) * [fix] branch for fix testcase; * [fix] fix test_analyzer & test_auto_parallel; * [fix] remove local change about moe; * [fix] rm local change moe; * [fix] fix test_deepfm_model & test_dlrf_model; * [fix] fix test_hf_albert & test_hf_gpt; * [gemini] optimize reduce scatter d2h copy (#5760) * [gemini] optimize reduce scatter d2h copy * [fix] fix missing reduce variable * [refactor] remove legacy async reduce scatter code * [gemini] missing sync * Revert "[refactor] remove legacy async reduce scatter code" This reverts commit58ad76d466
. * [gemini] further optimize with async all reduce * [fix] pass flag from manager to chunk * Allow building cuda extension without a device. (#5535) Added FORCE_CUDA environment variable support, to enable building extensions where a GPU device is not present but cuda libraries are. * [misc] fix dist logger (#5782) * [install]fix setup (#5786) * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [misc] update requirements (#5787) * [shardformer] fix import (#5788) * upgrade colossal-chat support tp_group>1, add sp for sft * upgrade ppo dpo rm script * run pre-commit * moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy * fix training script * fix ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix transformers version * remove duplicated test * fix datasets version * remove models that require huggingface auth from ci * remove local data path * update ci * remove baichuan from template test due to transformer version conflict * merge * Refactor modeling by adding attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix tests and naming Signed-off-by: char-1ee <xingjianli59@gmail.com> * Pass inference model shard configs for module init Signed-off-by: char-1ee <xingjianli59@gmail.com> * Clean up Signed-off-by: char-1ee <xingjianli59@gmail.com> * replace the customized dataloader setup with the build-in one * replace the customized dataloader setup with the build-in one * Remove flash attention backend Signed-off-by: char-1ee <xingjianli59@gmail.com> * fix readme * Fix test import Signed-off-by: char-1ee <xingjianli59@gmail.com> * update sft trainning script * [Inference]refactor baichuan (#5791) * refactor baichuan * remove unused code and add TODO for lazyinit * [test] fix chatglm test kit (#5793) 
* [shardformer] fix modeling of bloom and falcon (#5796) * [test] fix qwen2 pytest distLarge (#5797) * [Inference] Fix flash-attn import and add model test (#5794) * Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> * [Gemini] Use async stream to prefetch and h2d data moving (#5781) * use async stream to prefetch and h2d data moving * Remove redundant code * [gemini] quick fix on possible async operation (#5803) * [gemini] quick fix on possible async operation * [gemini] quick fix on possible async operation * [shardformer] upgrade transformers to 4.39.3 (#5815) * [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807) * [shardformer] fix modeling of gpt2 and gptj * [shardformer] fix whisper modeling * [misc] update requirements --------- Co-authored-by: ver217 <lhx0217@gmail.com> * [shardformer]upgrade transformers for mistral (#5808) * upgrade transformers for mistral * fix * fix * [shardformer]upgrade transformers for llama (#5809) * update transformers fix * fix * fix * [inference] upgrade transformers (#5810) * update transformers fix * fix * fix * fix * fix * [gemini] update transformers for gemini (#5814) --------- Co-authored-by: ver217 <lhx0217@gmail.com> * Support 4d parallel + flash attention (#5789) * support tp + sp + pp * remove comments --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> 
Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com> * [zero] fix hook bug * [zero] add low level optimizer back (#5839) * [zero] fix param & refactor * [zero] add back original low level opt * [zero] remove moe related * [zero] pass zero tests * [zero] refactor * [chore] add del func back * [zero] comments and naming (#5840) * [zero] modify api (#5843) * [zero] modify api * [test] remove _grad_store access in tests * [test] fix (#5857) * [CI] skip openmoe CI check * [CI] fox pre-commit * [zero] remove redundant memebr init (#5862) * [misc] remove useless code, modify the pg mesh implementation * [misc] remove useless code, modify the pg mesh implementation * [misc] use tempfile * resolve conflict with main branch * [misc] use tempfile in test_moe_checkpoint.py * [misc] remove useless code, add assertion about sequence parallel, move logger into function * [misc] remove useless code --------- Signed-off-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: botbw <wang1570@e.ntu.edu.sg> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> 
Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: duanjunwen <935724073@qq.com> Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Charles Coulombe <ccoulombe@users.noreply.github.com> Co-authored-by: YeAnbang <anbangy2@outlook.com> Co-authored-by: char-1ee <xingjianli59@gmail.com> Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
302 lines
10 KiB
Python
302 lines
10 KiB
Python
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.testing import assert_close
|
|
|
|
import colossalai
|
|
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
|
from colossalai.logging import disable_existing_loggers
|
|
from colossalai.nn.optimizer import DistributedLamb, Lamb
|
|
from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
|
|
from colossalai.tensor.d_tensor.api import clear_layout_converter
|
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|
from colossalai.testing.random import seed_all
|
|
from colossalai.zero import LowLevelZeroOptimizer
|
|
from tests.kit.model_zoo import model_zoo
|
|
from tests.test_optimizer._utils import check_optim_states, run_bert_test
|
|
|
|
# (param dtype, grad dtype) combinations exercised by the parameterized tests.
_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),  # pure fp32
    (torch.float, torch.bfloat16),  # bfloat16 amp
]

# Model / run configuration shared by the tests below.
_IN_DIM, _HID_DIM = 32, 128
_N_STEP = 3
_SEED = 1024

# Assigned a DistCoordinator inside each spawned worker (see check_dist_lamb);
# the runners use it for print_on_master.
coordinator = None

# Reference MLP (plus its input generator) and its tensor-parallel counterpart,
# taken from the first entry of the corresponding model-zoo sub-registries.
Net, data_gen, *_ = next(iter(model_zoo.get_sub_registry("simple_mlp").values()))
TPNet, *_ = next(iter(model_zoo.get_sub_registry("simple_tp_mlp").values()))
|
|
|
|
|
|
def assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group):
    """Check that every (possibly TP-sharded) parameter matches its unsharded reference.

    Parameters of ``tp_model`` and ``torch_model`` are paired positionally. When a
    parameter is reported as a distributed tensor, the reference parameter is split
    into ``tp_size`` chunks along its shard dimension and this rank's chunk is compared.

    Args:
        tp_model: Model whose parameters may be tensor-parallel shards.
        torch_model: Unsharded reference model.
        rtol: Relative tolerance forwarded to ``assert_close``.
        atol: Absolute tolerance forwarded to ``assert_close``.
        tp_group: Tensor-parallel process group (provides rank and world size).

    Raises:
        AssertionError: If a parameter contains NaN or differs from its reference.
    """
    rank = dist.get_rank(tp_group)
    tp_size = dist.get_world_size(tp_group)

    for (name, p), torch_p in zip(tp_model.named_parameters(), torch_model.parameters()):
        # if overflow, the weight won't be updated. so there will be no nan in p
        assert not torch.isnan(p).any(), f"nan found in param {name}"
        try:
            if is_distributed_tensor(p):
                split_dim = get_shard_dim_1d(p)
                torch_p = torch_p.chunk(tp_size, dim=split_dim)[rank]

            assert_close(p.float(), torch_p, rtol=rtol, atol=atol)
        except AssertionError:
            # Parameter values (not gradients) are compared here.
            print(f"param mismatch in {name}")
            raise
|
|
|
|
|
|
def setup_param_groups(bert_model: nn.Module) -> list:
    """Build optimizer param groups with and without weight decay.

    Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` go into a
    group with ``weight_decay=0.0``; every other parameter goes into a group with
    ``weight_decay=0.1``. The decayed group is returned first.
    """
    skip_decay_keys = ("bias", "LayerNorm.weight")
    decayed, undecayed = [], []
    for param_name, param in bert_model.named_parameters():
        exempt = any(key in param_name for key in skip_decay_keys)
        (undecayed if exempt else decayed).append(param)

    return [
        {"params": decayed, "weight_decay": 0.1},
        {"params": undecayed, "weight_decay": 0.0},
    ]
|
|
|
|
|
|
def force_assign_grad(p, g_dtype, grad=None):
    """Assign a gradient to ``p`` whose dtype may differ from ``p``'s own dtype.

    Assigning such a gradient directly raises an inconsistent grad/param dtype
    error, so ``p.data`` is temporarily swapped to the gradient tensor, the
    gradient is assigned, and the original data is restored.

    Args:
        p: Parameter that receives the gradient.
        g_dtype: dtype of the random gradient generated when ``grad`` is None.
        grad: Gradient tensor to assign. If None, a random tensor shaped like
            ``p`` on ``p``'s device with dtype ``g_dtype`` is used.
    """
    orig_p = p.data
    # Identity check: ``grad`` is normally a tensor, for which ``==`` is not a None test.
    p.data = torch.randn_like(p, device=orig_p.device, dtype=g_dtype) if grad is None else grad
    try:
        p.grad = p.data
    finally:
        # Always restore the parameter's real data, even if the grad assignment fails.
        p.data = orig_p
|
|
|
|
|
|
def set_dist_grad(
    dist_module: nn.Module,
    torch_model: nn.Module,
    g_dtype: torch.dtype,
    group: dist.ProcessGroup,
) -> None:
    """
    Set grads chunks for Tensor Parallel or ZeRO DP.

    For each positionally-paired parameter, a random ``g_dtype`` gradient is
    assigned to the reference parameter (or added to its existing gradient).
    The distributed parameter then receives either this rank's chunk of that
    gradient (for distributed tensors) or the full gradient.

    We do not need a separate treatment for ZeRO,
    as the LowLevelOptimizer takes care of reduce-scattering grads.

    Args:
        dist_module: Model whose parameters may be sharded across ``group``.
        torch_model: Unsharded reference model.
        g_dtype: dtype of the generated gradients.
        group: Process group used to pick this rank's gradient chunk.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    for p, torch_p in zip(dist_module.parameters(), torch_model.parameters()):
        if torch_p.grad is None:
            # avoid inconsistent grad and param dtype error
            force_assign_grad(torch_p, g_dtype)
        else:
            torch_p.grad += torch.randn_like(torch_p, device=torch_p.device, dtype=g_dtype)

        # p's gradient is always derived from the reference gradient below, so no
        # separate random gradient is generated for it (it would only be overwritten).
        if is_distributed_tensor(p):
            split_dim = get_shard_dim_1d(p)
            # Add grads only to the correctly split chunk
            force_assign_grad(p, g_dtype, torch_p.grad.chunk(world_size, dim=split_dim)[rank])
        else:
            force_assign_grad(p, g_dtype, torch_p.grad)
|
|
|
|
|
|
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
def run_dist_lamb_basic(
    bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
    """Test without forward: DistributedLamb on a TP model vs. Lamb on the full model.

    Random gradients are injected with ``set_dist_grad`` for ``_N_STEP`` steps.
    After each step the (possibly sharded) parameters are compared against the
    reference model's parameters.

    Args:
        bias_correction: Forwarded to both optimizers.
        p_g_dtype: (param dtype, grad dtype). The grad dtype is used for the
            injected gradients; both dtypes select the comparison tolerance.
        tp_zero_size: (tensor-parallel size, ZeRO size) used to build the mesh.
    """
    p_dtype, g_dtype = p_g_dtype
    tp_size, zero_size = tp_zero_size

    # Set distributed groups
    rank = dist.get_rank()
    clear_layout_converter()  # Ensure correct sharding
    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
    tp_group = proc_mesh.get_group_along_axis(0)

    tp_rank = dist.get_rank(tp_group)
    seed_all(_SEED)  # Fix model init
    torch_model = Net(in_dim=_IN_DIM, hid_dim=_HID_DIM, identity=True).to(rank)
    tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank)
    # Ensure equal weight init: the TP fc1 holds a row slice, and fc2 a column slice,
    # of the reference weights.
    assert_close(
        torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
        tp_model.fc1.weight,
    )
    assert_close(
        torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
        tp_model.fc2.weight,
    )

    # Set up optimizers
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim = Lamb(
        setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction
    )
    optim = DistributedLamb(
        setup_param_groups(tp_model),
        lr=lr,
        betas=(beta1, beta2),
        eps=eps,
        bias_correction=bias_correction,
    )
    optim.setup_distributed(tp_group)

    # Loosen tolerances when a reduced-precision dtype is involved.
    rtol, atol = 8e-7, 8e-7
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-6, 1e-6
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 2e-6, 2e-6

    for i in range(_N_STEP):
        seed_all(_SEED + i)  # NOTE: having only one manual_seed above doesn't work?
        set_dist_grad(tp_model, torch_model, g_dtype, tp_group)

        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        try:
            assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
        except Exception:
            coordinator.print_on_master(
                f"step {i + 1}: bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
            )
            raise
|
|
|
|
|
|
@parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
@parameterize("bias_correction", [False, True])
@parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
def run_dist_lamb_fwd_bwd(
    bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
) -> None:
    """One forward/backward/step of a TP model (optionally wrapped in ZeRO) vs. the full model.

    The TP model is optimized with DistributedLamb. When ``zero_size > 1``, that
    optimizer is wrapped in LowLevelZeroOptimizer over the data-parallel group.
    The reference model uses Lamb. The test compares forward outputs, then the
    updated parameters and optimizer states.

    Args:
        bias_correction: Forwarded to both optimizers.
        p_g_dtype: (param dtype, grad dtype). The param dtype is applied to the
            input; both dtypes select the comparison tolerance.
        tp_zero_size: (tensor-parallel size, ZeRO size) used to build the mesh.
    """
    p_dtype, g_dtype = p_g_dtype
    tp_size, zero_size = tp_zero_size

    # Set distributed groups
    rank = dist.get_rank()
    proc_mesh = ProcessGroupMesh(tp_size, zero_size)
    tp_group = proc_mesh.get_group_along_axis(0)
    dp_group = proc_mesh.get_group_along_axis(1)
    tp_rank = dist.get_rank(tp_group)

    seed_all(_SEED)
    clear_layout_converter()  # Ensure correct sharding
    torch_model = Net(_IN_DIM, _HID_DIM).to(rank)
    tp_model = TPNet(torch_model.fc0, torch_model.fc1, torch_model.fc2, tp_group).to(rank)

    # Ensure equal weight init: the TP fc1 holds a row slice, and fc2 a column slice,
    # of the reference weights.
    assert_close(
        torch_model.fc1.weight[tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
        tp_model.fc1.weight,
    )
    assert_close(
        torch_model.fc2.weight[:, tp_rank * _HID_DIM // tp_size : (tp_rank + 1) * _HID_DIM // tp_size],
        tp_model.fc2.weight,
    )

    # Set up optimizers
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim = Lamb(
        setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps, bias_correction=bias_correction
    )
    optim = DistributedLamb(
        setup_param_groups(tp_model),
        lr=lr,
        betas=(beta1, beta2),
        eps=eps,
        bias_correction=bias_correction,
    )

    # Setup distributed optimizer
    if zero_size > 1:
        optim = LowLevelZeroOptimizer(
            optim,
            overlap_communication=True,
            initial_scale=128,
            partition_grad=True,
            dp_process_group=dp_group,
            verbose=True,
        )
        shard_to_param = optim.master_to_working_param
        optim.optim.setup_distributed(tp_group, dp_group, shard_to_param, is_zero=True)
    else:
        optim.setup_distributed(tp_group)

    # Loosen tolerances when a reduced-precision dtype is involved.
    rtol, atol = 8e-7, 8e-7
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-6, 1e-6
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 2e-6, 2e-6

    seed_all(_SEED)  # NOTE: having only one manual_seed above doesn't work?
    x = data_gen()
    x = x.cuda().to(dtype=p_dtype)

    out_tp = tp_model(x)
    out = torch_model(x)
    try:
        assert_close(out, out_tp, rtol=rtol, atol=atol)
    except Exception:
        coordinator.print_on_master(
            f"bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
        )
        raise

    # The ZeRO wrapper must drive the backward pass itself.
    if zero_size > 1:
        optim.backward(out_tp.sum())
    else:
        out_tp.sum().backward()
    out.sum().backward()

    torch_optim.step()
    optim.step()
    torch_optim.zero_grad()
    optim.zero_grad()
    try:
        assert_distributed_close(tp_model, torch_model, rtol, atol, tp_group)
        # Unwrap the ZeRO wrapper (if any) before comparing optimizer states.
        check_optim_states(getattr(torch_optim, "optim", torch_optim), getattr(optim, "optim", optim))
    except Exception:
        coordinator.print_on_master(
            f"bias_correction: {bias_correction}, p_g_dtype: {p_g_dtype}, tp_zero_size: {tp_zero_size}"
        )
        raise
|
|
|
|
|
|
def check_dist_lamb(rank, world_size, port):
    """Per-process entry point: launch the distributed backend and run every LAMB check.

    Args:
        rank: Rank of this process, passed to ``colossalai.launch``.
        world_size: Total number of processes.
        port: Port used for the ``localhost`` rendezvous.
    """
    disable_existing_loggers()
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    # Publish the coordinator so the parameterized runners can log via print_on_master.
    global coordinator
    coordinator = DistCoordinator()

    stages = (
        (run_dist_lamb_basic, "Basic tests passed"),
        (run_dist_lamb_fwd_bwd, "Forward-backward tests passed"),
    )
    for run_stage, done_msg in stages:
        run_stage()
        coordinator.print_on_master(done_msg)

    run_bert_test(optim_class=Lamb, sharded_optim_class=Lamb)
    print(f"rank {rank} tests passed :)")
|
|
|
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lamb():
    """Spawn ``check_dist_lamb`` on 4 processes.

    Every ``tp_zero_size`` configuration used in this file multiplies to 4.
    """
    world_size = 4
    spawn(check_dist_lamb, nprocs=world_size)


if __name__ == "__main__":
    # Allow running this test directly, outside of pytest.
    test_dist_lamb()