[FP8] rebase main (#5963)

* add SimPO
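
The SimPO objective added here follows the SimPO paper: it compares length-normalized log-probabilities of the chosen and rejected responses against a target margin, with no reference model. A minimal sketch (illustrative only, not the ColossalChat implementation):

```python
# Sketch of the SimPO loss; assumes summed per-response log-probs are available.
import torch
import torch.nn.functional as F

def simpo_loss(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor,
               chosen_len: torch.Tensor, rejected_len: torch.Tensor,
               beta: float = 2.0, gamma: float = 0.5) -> torch.Tensor:
    # length-normalized implicit rewards (no reference model is needed)
    chosen_reward = beta * chosen_logps / chosen_len
    rejected_reward = beta * rejected_logps / rejected_len
    # Bradley-Terry style loss with a target reward margin gamma
    return -F.logsigmoid(chosen_reward - rejected_reward - gamma).mean()
```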

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI, import, requirements-test for #5838 (#5892)

* [Hot Fix] CI, import, requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross-entropy function for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
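
Ulysses-style sequence parallelism, as referenced above, boils down to an all-to-all that swaps the sharded dimension: before attention each rank holds a slice of the sequence, and after the exchange it holds the full sequence for a subset of heads. A rough sketch of that exchange (illustrative only, not the shardformer implementation):

```python
# Minimal sketch of the Ulysses sequence->head all-to-all (assumes the head
# count is divisible by the sequence-parallel world size).
import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x: torch.Tensor, sp_group) -> torch.Tensor:
    # x: [batch, seq_len // sp, num_heads, head_dim] on each rank
    sp = dist.get_world_size(sp_group)
    inputs = [t.contiguous() for t in x.chunk(sp, dim=2)]   # split over heads
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=sp_group)
    # concatenate the received sequence slices: [batch, seq_len, num_heads // sp, head_dim]
    return torch.cat(outputs, dim=1)
```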

* add benchmarks for SFT, DPO, SimPO, and ORPO; add benchmarking results; support LoRA with gradient checkpointing

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
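
The overlap added in #5898 hides the ZeRO parameter all-gather behind computation: the collective is launched asynchronously for the next bucket while the current one is in use, and the handle is awaited only when the gathered parameters are actually needed. A minimal sketch of the pattern (generic PyTorch, not the ColossalAI optimizer's API):

```python
# Illustrative prefetch of a flat parameter shard with an async all-gather.
import torch
import torch.distributed as dist

def prefetch_shard(flat_shard: torch.Tensor, group=None):
    world_size = dist.get_world_size(group)
    gathered = torch.empty(flat_shard.numel() * world_size,
                           dtype=flat_shard.dtype, device=flat_shard.device)
    handle = dist.all_gather_into_tensor(gathered, flat_shard, group=group, async_op=True)
    return gathered, handle

# usage: start the gather for bucket i+1, compute with bucket i, then handle.wait()
```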

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction of DimSpec objects: the difference_dict is identical for every DimSpec instance, so a single shared copy is enough.
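
A sketch of the pattern described above, with the expensive table built once and shared across instances (names are illustrative, not the actual colossalai.tensor code):

```python
class DimSpec:
    _difference_dict = None  # shared by all instances, built lazily once

    def __init__(self, shard_list):
        self.shard_list = shard_list
        if DimSpec._difference_dict is None:
            DimSpec._difference_dict = self._build_difference_dict()

    @staticmethod
    def _build_difference_dict():
        # placeholder for the real (and expensive) construction of the
        # source-spec -> target-spec conversion-cost table
        return {}
```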

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

Set a default value for 'default_conversation' so the variable is always bound before use.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint is saved at the end of each epoch (#5941)

* Add a switch to control whether the model checkpoint is saved at the end of each epoch
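
The switch amounts to a boolean consulted at the end of every epoch; a minimal sketch of the idea (the flag and helper names here are hypothetical, not necessarily those added in #5941):

```python
def fit(booster, model, dataloader, num_epochs: int, train_one_epoch,
        save_checkpoint_per_epoch: bool = True):
    """Train and, if enabled, save a sharded checkpoint after every epoch."""
    for epoch in range(num_epochs):
        train_one_epoch(model, dataloader)  # assumed training helper
        if save_checkpoint_per_epoch:
            # booster.save_model is the ColossalAI Booster API used elsewhere in this commit
            booster.save_model(model, f"ckpt_epoch_{epoch}", shard=True)
```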

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by renaming the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for a bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader
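
A distributed evaluation loader of this kind typically wraps the dataset in a DistributedSampler so that each rank scores a disjoint shard; a generic sketch (plain PyTorch, not the ColossalChat implementation):

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_dist_eval_loader(dataset, batch_size: int) -> DataLoader:
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False,    # keep evaluation order deterministic
        drop_last=False,  # keep every sample; the last shard may repeat a few
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```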

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8
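
These operators compress collective traffic by casting tensors to 8-bit floats with a per-tensor scale before communicating (the exact scaling algorithm is what the next commit adjusts). A rough sketch of the idea, not the actual ColossalAI kernels:

```python
# Illustrative fp8-compressed all-reduce: gather fp8 payloads + scales, sum in fp32.
# Requires torch >= 2.1 for the float8_e4m3fn dtype.
import torch
import torch.distributed as dist

FP8_MAX = 448.0  # largest value representable in float8_e4m3fn

def cast_to_fp8(x: torch.Tensor):
    scale = (x.abs().max().float().clamp(min=1e-12) / FP8_MAX).reshape(1)
    return (x / scale).to(torch.float8_e4m3fn), scale

def cast_from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    return x_fp8.to(dtype) * scale

def all_reduce_fp8(x: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    x_fp8, scale = cast_to_fp8(x)
    payload = x_fp8.view(torch.uint8)  # communicate bytes; NCCL has no fp8 type
    bufs = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(bufs, payload, group=group)
    dist.all_gather(scales, scale, group=group)
    total = sum(cast_from_fp8(b.view(torch.float8_e4m3fn), s) for b, s in zip(bufs, scales))
    x.copy_(total.to(x.dtype))
    return x
```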

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
Author: flybird11111
Date: 2024-08-06 16:29:37 +08:00
Committed by: GitHub
Parent: 53cb9606bd
Commit: 0c10afd372
208 changed files with 10962 additions and 2892 deletions

View File

@@ -3,28 +3,17 @@ from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm2 import *
from .command import *
from .deepseek import *
from .falcon import *
from .gpt import *
from .gptj import *
from .llama import *
from .mistral import *
from .mixtral import *
from .opt import *
from .qwen2 import *
from .sam import *
from .t5 import *
from .vit import *
from .whisper import *
try:
from .mistral import *
except ImportError:
print("This version of transformers doesn't support mistral.")
try:
from .qwen2 import *
except ImportError:
print("This version of transformers doesn't support qwen2.")
try:
from .command import *
except ImportError:
print("This version of transformers doesn't support Command-R.")

View File

@@ -0,0 +1,83 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import AutoConfig
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence Mixtral
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
def init_deepseek():
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
hidden_size=32,
intermediate_size=32,
moe_intermediate_size=32,
num_hidden_layers=2,
num_attention_heads=8,
num_key_value_heads=8,
# vocab_size=2200,
first_k_dense_replace=1,
attn_implementation="flash_attention_2",
torch_dtype="float16",
n_routed_experts=8,
trust_remote_code=True,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
model = transformers.AutoModel.from_config(config, trust_remote_code=True)
return model
model_zoo.register(
name="transformers_deepseek",
model_fn=init_deepseek,
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mixtral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@@ -0,0 +1,85 @@
# modified from tests/kit/model_zoo/transformers/mistral.py
import torch
import transformers
from transformers import MixtralConfig
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence Mixtral
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mixtralai/Mixtral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 22, 55, 77, 532, 349, 43, 22]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mixtral_model = lambda x: x[0].mean()
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MixtralConfig(
hidden_size=32,
intermediate_size=32,
num_attention_heads=8,
num_hidden_layers=2,
vocab_size=1000,
output_router_logits=True,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
model_zoo.register(
name="transformers_mixtral",
model_fn=lambda: transformers.MixtralModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mixtral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
# model_zoo.register(
# name="transformers_mixtral_for_casual_lm",
# model_fn=lambda: transformers.MixtralForCausalLM(config),
# data_gen_fn=data_gen_for_lm,
# output_transform_fn=output_transform_fn,
# loss_fn=loss_fn,
# model_attribute=ModelAttribute(has_control_flow=True),
# )
# model_zoo.register(
# name="transformers_mixtral_for_sequence_classification",
# model_fn=lambda: transformers.MixtralForSequenceClassification(config),
# data_gen_fn=data_gen_for_sequence_classification,
# output_transform_fn=output_transform_fn,
# loss_fn=loss_fn_for_seq_classification,
# model_attribute=ModelAttribute(has_control_flow=True),
# )

View File

@@ -0,0 +1,136 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import get_moe_epsize_param_dict
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "ep_group"):
delattr(param, "ep_group")
class MoeModel(nn.Module):
def __init__(self, ep_group: ProcessGroup = None):
super().__init__()
self.test_embed = nn.Linear(4, 16, bias=False)
self.w1 = torch.nn.Parameter(torch.randn(16, 8))
if ep_group:
set_moe_tensor_ep_group(self.w1, ep_group)
def forward(self, x):
x = self.test_embed(x)
x = torch.matmul(x, self.w1)
return x
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group and
moe model parallel. A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer=None):
super().__init__(model, optimizer)
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe model parallel group
"""
if dist.get_world_size() > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
if 1 in epsize_param_dict:
bucket_allreduce(param_list=epsize_param_dict[1])
for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
bucket_allreduce(
param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
)
def assert_not_equal_in_group(tensor, process_group=None):
# all gather tensors from different ranks
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=process_group)
# check if they are equal one by one
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(a, b), (
f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
)
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)

View File

@@ -5,7 +5,7 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe.layers import SparseMLP
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.utils import sync_moe_model_param
# from colossalai.shardformer.layer.moe import MLPExperts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn

View File

@@ -6,7 +6,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeModel

View File

@@ -6,7 +6,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.moe.manager import MOE_MANAGER
from colossalai.legacy.moe.manager import MOE_MANAGER
# from colossalai.shardformer.layer.moe import apply_load_balance
from colossalai.tensor.moe_tensor.api import is_moe_tensor

View File

@@ -9,7 +9,8 @@ from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
model = model_fn()
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
test_configs = [
{
"lora_config": lora_config,
@@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
# test fwd bwd correctness
test_model = model_load
if isinstance(model_load, HybridParallelModule):
model_load = model_load.module.module
model_copy = copy.deepcopy(model_load)
data = data_gen_fn()

View File

@@ -1,142 +1,8 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
# from colossalai.shardformer.layer.moe import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size, set_moe_tensor_ep_group
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "ep_group"):
delattr(param, "ep_group")
class MoeModel(nn.Module):
def __init__(self, ep_group: ProcessGroup = None):
super().__init__()
self.test_embed = nn.Linear(4, 16, bias=False)
self.w1 = torch.nn.Parameter(torch.randn(16, 8))
if ep_group:
set_moe_tensor_ep_group(self.w1, ep_group)
def forward(self, x):
x = self.test_embed(x)
x = torch.matmul(x, self.w1)
return x
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group and
moe model parallel. A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
Args:
model (Module): Model where the gradients accumulate.
optimizer (Optimizer): Optimizer for updating the parameters.
"""
def __init__(self, model, optimizer=None):
super().__init__(model, optimizer)
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe model parallel group
"""
if dist.get_world_size() > 1:
epsize_param_dict = get_moe_epsize_param_dict(self._model)
# epsize is 1, indicating the params are replicated among processes in data parallelism
# use the ParallelMode.DATA to get data parallel group
# reduce gradients for all parameters in data parallelism
if 1 in epsize_param_dict:
bucket_allreduce(param_list=epsize_param_dict[1])
for ep_size in epsize_param_dict:
if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
bucket_allreduce(
param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
)
def assert_not_equal_in_group(tensor, process_group=None):
# all gather tensors from different ranks
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=process_group)
# check if they are equal one by one
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(a, b), (
f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
)
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
assert loose_close(a, b, dtype), f"{name} not close {a.mean()} {b.mean()}"
def loose_close(a, b, dtype: torch.dtype = torch.float32):
@@ -148,8 +14,18 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
else:
assert dtype is torch.float32
rtol = 1e-05
atol = 1e-08
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
return torch.allclose(a, b, rtol=rtol, atol=atol)
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
assert_loose_close(p1, p2, p1.dtype)

View File

@@ -0,0 +1,78 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.testing import assert_close
from transformers import AutoConfig, AutoModel
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
from colossalai.testing.utils import spawn
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_deepseek_moe_layer():
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=1,
ep_size=dist.get_world_size(),
)
config = AutoConfig.from_pretrained(
"deepseek-ai/deepseek-moe-16b-base",
num_hidden_layers=1,
n_routed_experts=n_experts,
num_experts_per_tok=top_k,
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
first_k_dense_replace=0,
num_attention_heads=2,
trust_remote_code=True,
)
torch.manual_seed(0)
# get the moe layer in auto model
orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output = orig_model(x)
model = deepcopy(orig_model)
model = EPDeepseekMoE.from_native_module(
model,
ep_group=plugin.ep_group,
moe_dp_group=plugin.moe_dp_group,
tp_group=plugin.tp_group,
)
ep_output = model(x)
assert_close(orig_output, ep_output)
orig_loss = orig_output.mean()
orig_loss.backward()
ep_loss = ep_output.mean()
ep_loss.backward()
assert_close(orig_loss, ep_loss)
name_to_p = {n: p for n, p in orig_model.named_parameters()}
for n, ep_p in model.named_parameters():
p = name_to_p[n]
if ep_p.grad is not None:
assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_deepseek_moe_layer()
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.parametrize("world_size", [2])
def test_deepseek_moe_layer(world_size: int):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_deepseek_moe_layer(2)

View File

@@ -4,8 +4,6 @@ import pytest
import torch
from colossalai.accelerator import get_accelerator
# from colossalai.moe import SparseMLP
from colossalai.moe._operation import MoeCombine, MoeDispatch, moe_cumsum
NUM_EXPERTS = 4

View File

@@ -23,6 +23,7 @@ def check_mixtral_moe_layer():
precision="bf16",
tp_size=1,
pp_size=1,
zero_stage=1,
ep_size=dist.get_world_size(),
)
config = MixtralConfig(
@@ -36,7 +37,12 @@ def check_mixtral_moe_layer():
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
orig_output, orig_logits = orig_model(x)
model = deepcopy(orig_model)
model = EPMixtralSparseMoeBlock.from_native_module(model, ep_group=plugin.ep_group)
model = EPMixtralSparseMoeBlock.from_native_module(
model,
ep_group=plugin.ep_group,
tp_group=plugin.tp_group,
moe_dp_group=plugin.moe_dp_group,
)
ep_output, ep_logits = model(x)
assert_close(orig_logits, ep_logits)
assert_close(orig_output, ep_output)
@@ -57,7 +63,8 @@ def run_dist(rank: int, world_size: int, port: int):
check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.parametrize("world_size", [2])
def test_mixtral_moe_layer(world_size: int):
spawn(run_dist, world_size)

View File

@@ -6,30 +6,23 @@ from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from torch.optim import SGD, Adam
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, spawn
from colossalai.testing.random import seed_all
from colossalai.testing.utils import spawn
from tests.test_moe.moe_utils import check_model_equal
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
if not torch.equal(p1.half(), p2.half()):
print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
raise AssertionError(f"Model parameter {name} is not equal")
def get_optimizer_snapshot(optim):
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
param_groups = []
@@ -77,36 +70,44 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
raise AssertionError(f"A total of {count} optim states are not equal")
def check_mixtral_moe_layer():
@parameterize(
"test_config",
[
[
MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
num_hidden_layers=2,
),
MixtralForCausalLM,
],
],
)
def check_moe_checkpoint(test_config):
dtype, precision = torch.float16, "fp16"
config, model_cls = test_config
torch.cuda.set_device(dist.get_rank())
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
with context as f:
torch.cuda.set_device(dist.get_rank())
if dist.get_rank() == 0:
broadcast_objects = [f] # any picklable object
else:
broadcast_objects = [None]
dist.broadcast_object_list(broadcast_objects, src=0)
config = MixtralConfig(
hidden_size=hidden_size,
intermediate_size=hidden_size * 2,
num_local_experts=n_experts,
num_experts_per_tok=top_k,
num_attention_heads=2,
num_key_value_heads=2,
)
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
orig_model = MixtralForCausalLM(config).cuda()
orig_model = model_cls(config).cuda().to(dtype)
seed_all(10086)
model = deepcopy(orig_model)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer = SGD(model.parameters(), lr=1e-3)
plugin = MoeHybridParallelPlugin(
pp_size=2,
ep_size=2,
tp_size=1,
checkpoint_io=MoECheckpointIO,
microbatch_size=1,
zero_stage=1,
pp_size=2, ep_size=2, tp_size=1, microbatch_size=1, zero_stage=1, precision=precision
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
@@ -120,7 +121,6 @@ def check_mixtral_moe_layer():
lambda outputs, inputs: outputs.loss,
optimizer,
)
tmpdirname = broadcast_objects[0]
model_dir = os.path.join(tmpdirname, "mixtral_model")
hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
@@ -129,13 +129,12 @@ def check_mixtral_moe_layer():
booster.save_model(model, model_dir, shard=True)
dist.barrier()
if dist.get_rank() == 0:
saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
check_model_equal(orig_model, saved_model)
# check_model_equal(model, saved_model)
saved_model.save_pretrained(hf_model_dir)
dist.barrier()
# check load model
new_model = MixtralForCausalLM(config).cuda()
new_model = model_cls(config).cuda().to(dtype)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
booster.load_model(new_model, hf_model_dir)
@@ -163,7 +162,7 @@ def check_mixtral_moe_layer():
def run_dist(rank: int, world_size: int, port: int):
colossalai.launch(rank, world_size, "localhost", port)
check_mixtral_moe_layer()
check_moe_checkpoint()
# Test EP + ZeRO + PP

View File

@@ -1,238 +1,132 @@
import os
import warnings
from typing import Dict
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import sync_moe_model_param
from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close
# from colossalai.shardformer.layer import SparseMLP
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from tests.test_moe.moe_utils import MoeGradientHandler
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 2
def sync_tp_from_local(tp_model, local_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from local model
@parameterize("stage", [1])
@parameterize("ep_size", [2])
def run_zero_with_original_model(stage: int, ep_size: int):
tp_size = dist.get_world_size() // ep_size
dtype = torch.bfloat16
Args:
tp_model (MoeModule)
local_model (MoeModule)
"""
for (tp_name, tp_param), (local_name, local_param) in zip(
tp_model.named_parameters(), local_model.named_parameters()
):
assert tp_name == local_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, local_param)
assert torch.allclose(tp_param.grad, local_param.grad)
else:
tp_param.data.copy_(local_param.data)
continue
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
tp_rank = get_ep_rank(tp_param)
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
seed_all(10086)
if assert_grad_flag:
assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
else:
tp_param.data.copy_(local_param[tuple(tp_slice)].data)
def sync_tp_from_ep(tp_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
tp_model (MoeModule)
ep_model (MoeModule)
"""
for (tp_name, tp_param), (ep_name, ep_param) in zip(tp_model.named_parameters(), ep_model.named_parameters()):
assert tp_name == ep_name
if not is_moe_tensor(tp_param):
if assert_grad_flag:
assert torch.allclose(tp_param, ep_param)
assert torch.allclose(tp_param.grad, ep_param.grad)
else:
tp_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
# get tp param
tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
tp_rank = get_ep_rank(tp_param)
tp_slice = [slice(None)] * tp_dim + [
slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
]
new_tp_param = all_param[tuple(tp_slice)]
if assert_grad_flag:
new_grad = all_grad[tuple(tp_slice)]
if assert_grad_flag:
assert torch.allclose(tp_param, new_tp_param)
assert torch.allclose(tp_param.grad, new_grad)
else:
tp_param.data.copy_(new_tp_param.data)
def sync_local_from_ep(local_model, ep_model, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name == ep_name
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param)
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
assert batch_size % world_size == 0
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
if enable_hierarchical_comm:
os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
ep_model = SparseMLP(
num_experts=num_experts,
hidden_size=dim,
intermediate_size=dim * 2,
enable_hierarchical_comm=enable_hierarchical_comm,
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="TP")
tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
ep_model = ep_model.to(get_accelerator().get_current_device())
tp_model = tp_model.to(get_accelerator().get_current_device())
local_model = local_model.to(get_accelerator().get_current_device())
torch_model = MixtralModel(config).to(dtype).cuda()
# sync ep param
sync_moe_model_param(ep_model)
dist_dict = MOE_MANAGER.parallel_info_dict
assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
ep_grad_handler = MoeGradientHandler(ep_model)
# sync local param
sync_local_from_ep(local_model, ep_model)
# sync tp param
sync_tp_from_ep(tp_model, ep_model)
tp_grad_handler = MoeGradientHandler(tp_model)
rank = dist.get_rank()
input_data = torch.randn(batch_size, dim, device=get_accelerator().get_current_device())
micro_batch_size = batch_size // world_size
index = rank * micro_batch_size
# NOTE: ep & tp takes in sharded data for each process
shard_data = input_data.detach()[index : index + micro_batch_size]
out_local = local_model(input_data)
MOE_MANAGER.reset_loss()
out_tp = tp_model(shard_data)
MOE_MANAGER.reset_loss()
out_ep = ep_model(shard_data)
MOE_MANAGER.reset_loss()
assert torch.allclose(
out_tp, out_ep, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
try:
out_local_slice = out_local[index : index + micro_batch_size]
assert torch.allclose(
out_ep, out_local_slice, atol=1e-6
), f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
except AssertionError:
"""
e.g., in local model, tokens = 4, capacity = 2, experts = 2, topk = 1
router yields [01] --> [0], [23] --> [1], this is valid as capacity is 2
However, in ep mode, there are 2 separate routers dealing with sharded data.
Assume router 0 handles token [01] and router 1 handles token [23].
Note that for each router the capacity is only 1 !!!
Thus, router 0 may yields [0] --> [0] or [1] --> [0], but not both.
The same thing happens on router 1. And finally some tokens are dropped due to the sharded nature.
"""
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
moe_booster = Booster(
plugin=MoeHybridParallelPlugin(
tp_size=tp_size,
moe_tp_size=tp_size,
pp_size=1,
ep_size=ep_size,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
)
)
zero_model, zero_optimizer, _, _, _ = moe_booster.boost(zero_model, zero_optimizer)
out_local.mean().backward()
out_tp.mean().backward()
tp_grad_handler.handle_gradient()
out_ep.mean().backward()
ep_grad_handler.handle_gradient()
assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
try:
sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
except AssertionError:
warnings.warn(
"EP & TP may result in different behavior from local model. " "Please check the comments for details."
hybird_booster = Booster(
plugin=HybridParallelPlugin(
tp_size=tp_size,
pp_size=1,
zero_stage=stage,
overlap_communication=False,
initial_scale=1,
)
)
hybrid_model, hybrid_optimizer, _, _, _ = hybird_booster.boost(
torch_model, torch.optim.SGD(torch_model.parameters(), lr=1)
)
# create different input
seed_all(1453 + rank)
hybrid_model.train()
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
# torch-ddp forward
hybrid_output = hybrid_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
assert_loose_close(zero_output, hybrid_output, dtype=dtype)
# torch-ddp backward
hybrid_optimizer.backward(hybrid_output)
# check grad
name_to_p = {n: p for n, p in hybrid_model.named_parameters()}
for n, p in zero_model.named_parameters():
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n])
continue
if zero_grad.shape != name_to_p[n].grad.shape: # TODO check sharded and sliced moe
continue
assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
hybrid_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
if p.data.shape != name_to_p[n].data.shape: # TODO check sharded and sliced moe
continue
assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
@pytest.mark.skip(reason="moe need to be refactored")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize(
"config",
[
{"enable_hierarchical_comm": False},
{"enable_hierarchical_comm": True},
],
)
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
def test_moe_ep_tp(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
test_moe_ep_tp(world_size=4)

View File

@@ -0,0 +1,119 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close
NUM_BATCH = 4
NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 2
TOP_K = 1
@parameterize("stage", [1])
@parameterize("ep_size", [2, 4])
def run_zero_with_original_model(stage: int, ep_size: int):
dtype = torch.bfloat16
rank = torch.distributed.get_rank()
torch.cuda.set_device(dist.get_rank())
plugin = MoeHybridParallelPlugin(
pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
)
booster = Booster(plugin=plugin)
seed_all(10086)
config = MixtralConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=2,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
num_local_experts=NUM_EXPERTS,
num_experts_per_tok=TOP_K,
)
torch_model = MixtralModel(config).to(dtype).cuda()
zero_model = deepcopy(torch_model).to(dtype)
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
ddp_model = DDP(
torch_model.cuda(),
process_group=plugin.dp_group,
find_unused_parameters=True, # important for torch ddp, not all experts are routed
).cuda()
ddp_optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)
# create different input
seed_all(1453 + rank)
ddp_model.train()
zero_model.train()
for _ in range(2):
# zero-dp forward
input_data = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
# zero-dp backward
zero_optimizer.backward(zero_output)
# torch-ddp forward
ddp_output = ddp_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
assert_loose_close(zero_output, ddp_output, dtype=dtype)
# torch-ddp backward
ddp_output.backward()
# check grad
name_to_p = {n: p for n, p in ddp_model.named_parameters()}
for n, p in zero_model.named_parameters():
zero_grad = zero_optimizer.get_param_grad(p)
if name_to_p[n].grad is None:
name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
continue
assert_loose_close(zero_grad, name_to_p[n].grad, dtype=dtype, name=n)
# zero-dp step
zero_optimizer.step()
# original model step
ddp_optimizer.step()
# check updated param
for n, p in zero_model.named_parameters():
assert_loose_close(p.data, name_to_p[n].data, dtype=dtype, name=n)
print(f"{dist.get_rank()} test passed")
def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_zero_with_original_model()
@pytest.mark.skip("tested in corresponding sharderformer")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_ep_zero(world_size):
spawn(run_dist, world_size)
if __name__ == "__main__":
test_moe_ep_zero(world_size=4)

View File

@@ -1,132 +0,0 @@
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
from tests.test_moe.moe_utils import loose_close

tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2


def split_grad(grad, world_size):
    with torch.no_grad():
        grad = grad.clone().detach().flatten()
        padding_size = (world_size - grad.numel() % world_size) % world_size
        if padding_size > 0:
            grad = torch.nn.functional.pad(grad, [0, padding_size])
        splited_grad = grad.split(grad.numel() // world_size)
    return splited_grad


@parameterize("dtype", [torch.float16, torch.bfloat16])
@parameterize("master_weights", [True, False])
@parameterize("stage", [1, 2])
def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.dtype, stage: int):
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(dist.get_rank())
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=dist.get_world_size() // 2,
    )

    seed_all(10086)
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
    )

    orig_model = MixtralSparseMoeBlock(config).to(dtype).cuda()

    ori_model = DDP(orig_model.cuda(), static_graph=True).cuda()

    zero_model = deepcopy(orig_model).to(dtype)
    zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)

    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
    pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
    for p in zero_model.parameters():
        if is_moe_tensor(p):
            pg_param_list[plugin.moe_dp_group].append(p)
        else:
            pg_param_list[plugin.global_dp_group].append(p)

    zero_optimizer = LowLevelZeroOptimizer(
        zero_optimizer,
        pg_to_param_list=pg_param_list,
        master_weights=master_weights,
        initial_scale=1,
        overlap_communication=False,
        partition_grad=True,
    )

    ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)

    # create
    seed_all(1453 + rank)

    for _ in range(2):
        # zero-dp forward
        input_data = torch.rand(1, tokens, hidden_size).cuda()
        zero_output, zero_logits = zero_model(input_data.to(dtype))

        # torch-ddp forward
        ori_output, ori_logits = ori_model(input_data.to(dtype))
        loose_close(zero_output, ori_output, dtype=dtype)

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float())

        # torch-ddp backward
        ori_output.mean().backward()

        # check grad
        name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
        for n, p in zero_model.named_parameters():
            zero_grad = zero_optimizer.get_param_grad(p)
            if name_to_p[n].grad is None:
                assert zero_grad is None
                continue
            loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)

        # zero-dp step
        zero_optimizer.step()

        # original model step
        ori_optimizer.step()

        # check updated param
        for n, p in zero_model.named_parameters():
            loose_close(p.data, name_to_p[n].data, dtype=dtype)


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model(world_size=world_size)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_moe_zero_model(world_size=4)

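Aside (not part of the diff): the deleted test above exercises the pattern of routing expert parameters and dense parameters to different reduction groups before handing them to LowLevelZeroOptimizer. A minimal sketch of that grouping, reusing only names that appear in the file above; the helper name itself is made up:

# Illustrative sketch only; the grouping logic mirrors the deleted test above.
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero import LowLevelZeroOptimizer


def wrap_with_low_level_zero(model, inner_optimizer, plugin, **zero_kwargs):
    # Expert (MoE) tensors are reduced over the MoE data-parallel group,
    # all other tensors over the global data-parallel group.
    pg_param_list = {plugin.global_dp_group: [], plugin.moe_dp_group: []}
    for p in model.parameters():
        group = plugin.moe_dp_group if is_moe_tensor(p) else plugin.global_dp_group
        pg_param_list[group].append(p)
    return LowLevelZeroOptimizer(inner_optimizer, pg_to_param_list=pg_param_list, **zero_kwargs)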
View File

@@ -1,6 +1,6 @@
import copy
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Type

import torch
import torch.distributed as dist
@@ -117,7 +117,12 @@ def check_state_dict(org_model: Module, sharded_model: Module, name: str = ""):


def build_model_from_hybrid_plugin(
    model_fn: Callable, loss_fn: Callable, test_config: Dict[str, Any], optim_class=Adam, sharded_optim_class=Adam
    model_fn: Callable,
    loss_fn: Callable,
    test_config: Dict[str, Any],
    optim_class=Adam,
    sharded_optim_class=Adam,
    pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin,
):
    use_lazy_init = False
    if "use_lazy_init" in test_config:
@@ -149,9 +154,10 @@ def build_model_from_hybrid_plugin(
    else:
        org_optimizer = optim_class(org_model.parameters(), lr=1e-3)
        sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
    criterion = loss_fn

    plugin = HybridParallelPlugin(**test_config)
    plugin = pluggin_cls(**test_config)
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)

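An aside, not part of the diff: the new pluggin_cls keyword (spelling as in the hunk above) turns the plugin class into a parameter so MoE tests can reuse the same model-building helper. A minimal sketch of the pattern under that assumption; make_booster and the config values are made up, and these plugins can only be constructed inside a colossalai.launch-ed distributed process:

from typing import Type

from colossalai.booster.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin


def make_booster(test_config: dict, pluggin_cls: Type[HybridParallelPlugin] = HybridParallelPlugin) -> Booster:
    # Same idea as build_model_from_hybrid_plugin above: the default keeps the old
    # behaviour, while MoE tests pass pluggin_cls=MoeHybridParallelPlugin.
    plugin = pluggin_cls(**test_config)
    return Booster(plugin=plugin)


# Hypothetical usage inside a launched (dist-initialized) test process:
# booster = make_booster({"tp_size": 2, "pp_size": 1, "precision": "fp16"})
# moe_booster = make_booster({"tp_size": 1, "pp_size": 1, "ep_size": 2, "precision": "fp16"}, MoeHybridParallelPlugin)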
View File

@@ -136,6 +136,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
View File

@@ -58,6 +58,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    # Check the grad when using ZeRO-1 and ZeRO-2
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.pipeline_stage_manager is None
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
@@ -154,6 +155,45 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,

View File

@@ -0,0 +1,196 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from transformers import AutoConfig, AutoModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 2

CHECKED_CONFIG = [  # FOR_WORLD=4
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        enable_flash_attention=sp_size > 1,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    dp_size = plugin.dp_size

    booster = Booster(plugin=plugin)

    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-moe-16b-base",
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        moe_intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=4,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        first_k_dense_replace=1,
        attn_implementation="flash_attention_2",
        torch_dtype="float16",
        n_routed_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        trust_remote_code=True,
    )

    # init model with the same seed
    seed_all(10086)

    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)

    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # gen random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x[0].mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )

            if booster.plugin.stage_manager.is_last_stage():
                parallel_output = sharded_output["loss"]
            else:
                parallel_output = torch.tensor(12345.0, device="cuda")

            # broadcast along pp axis
            dist.broadcast(
                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
            )

        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run normal model with all dp(different) inputs
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads follows zero optimizer
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_deepseek"
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()
    booster.save_model(parallel_model, model_dir, shard=True)
    dist.barrier()

    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if rank == world_size - 1:
        shutil.rmtree(model_dir)

    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_deepseek(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_deepseek(world_size=4)

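A note outside the diff: both new MoE parity tests unpack their parameterized tuples as (stage, ep_size, pp_size, tp_size, sp_size), so the CHECKED_CONFIG entries read in that order. A small, purely hypothetical helper that spells the ordering out:

# Hypothetical helper (not in the PR): name the fields of the
# (stage, ep_size, pp_size, tp_size, sp_size) tuples used by the MoE tests above.
from typing import NamedTuple


class MoeTestConfig(NamedTuple):
    zero_stage: int  # 0 = ZeRO disabled, 1 = ZeRO-1
    ep_size: int     # expert parallel size
    pp_size: int     # pipeline parallel size
    tp_size: int     # tensor parallel size
    sp_size: int     # sequence parallel size (all_to_all mode when > 1)


# e.g. the first parameterized case above, (1, 2, 2, 1, 1):
cfg = MoeTestConfig(1, 2, 2, 1, 1)
assert cfg.ep_size == 2 and cfg.pp_size == 2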
View File

@@ -59,10 +59,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    if (
        booster.plugin.zero_stage in [1, 2]
        and booster.plugin.shard_config.enable_sequence_parallelism
        and booster.plugin.shard_config.pipeline_stage_manager is None
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        master2working = sharded_optimizer.get_master_to_working_map()
        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
            working_p = sharded_optimizer.master_to_working_param[id(p2)]
            working_p = master2working[id(p2)]
            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
                0
@@ -146,6 +148,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
    "test_config",
    [
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 0,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {  # Test ring + Flash attention
            "tp_size": 2,
            "pp_size": 1,
@@ -159,19 +174,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
@@ -245,7 +247,6 @@ def run_llama_test(test_config):
        except Exception as e:
            print(f"Failed config: {test_config}")
            raise e
    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()

View File

@@ -0,0 +1,190 @@
import os
import shutil
from copy import deepcopy
from typing import Tuple

import pytest
import torch
import torch.distributed
import torch.distributed as dist
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import assert_loose_close, check_model_equal

NUM_BATCH = 8
NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
NUM_LAYERS = 4
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
TOP_K = 1

CHECKED_CONFIG = [  # FOR WORLD=4
    (0, 1, 4, 1, 1),
    (0, 1, 1, 4, 1),
    (0, 1, 1, 1, 4),
    (1, 4, 1, 1, 1),
    (1, 1, 4, 1, 1),
    (1, 1, 1, 4, 1),
    (1, 1, 1, 1, 4),
    (1, 2, 1, 1, 1),
]


@parameterize(
    "config",
    [
        (1, 2, 2, 1, 1),
        (1, 2, 1, 2, 1),
        (1, 2, 1, 1, 2),
    ],
)
def run_zero_with_original_model(config: Tuple[int, ...]):
    stage, ep_size, pp_size, tp_size, sp_size = config
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dtype, precision = torch.float16, "fp16"
    torch.cuda.set_device(dist.get_rank())

    plugin = MoeHybridParallelPlugin(
        pp_size=pp_size,
        num_microbatches=pp_size,
        tp_size=tp_size,
        sp_size=sp_size,
        ep_size=ep_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
        find_unused_parameters=True,
    )
    dp_size = plugin.dp_size

    booster = Booster(plugin=plugin)

    assert pp_size <= NUM_LAYERS, "pp_size should be less than or equal to NUM_LAYERS"
    config = MixtralConfig(
        hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
        intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
        num_hidden_layers=NUM_LAYERS,
        num_attention_heads=NUM_HEADS,
        num_key_value_heads=NUM_HEADS,
        num_local_experts=NUM_EXPERTS,
        num_experts_per_tok=TOP_K,
        attn_implementation="flash_attention_2",
    )

    # init model with the same seed
    seed_all(10086)

    torch_model = MixtralModel(config).to(dtype).cuda()
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    parallel_model = deepcopy(torch_model)
    parallel_optimizer = torch.optim.SGD(parallel_model.parameters(), lr=1)
    parallel_model, parallel_optimizer, _, _, _ = booster.boost(parallel_model, parallel_optimizer)

    # create different input along dp axis
    seed_all(1453 + rank)

    torch_model.train()
    parallel_model.train()
    for _ in range(2):
        # gen random input
        input_embeddings = torch.rand(
            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
        ).cuda()
        dist.all_reduce(
            input_embeddings, group=plugin.pp_group
        )  # pp inputs except the first stage doesn't matter, but need to be replicate for torch model check

        dist.all_reduce(input_embeddings, group=plugin.tp_group)  # tp group duplicate input
        dist.all_reduce(input_embeddings, group=plugin.sp_group)  # sp group duplicate input

        # run the model with hybrid parallel
        if booster.plugin.stage_manager is not None:
            # for test with pp
            data_iter = iter([{"inputs_embeds": input_embeddings}])
            sharded_output = booster.execute_pipeline(
                data_iter,
                parallel_model,
                lambda x, y: x.last_hidden_state.mean(),
                parallel_optimizer,
                return_loss=True,
                return_outputs=True,
            )

            if booster.plugin.stage_manager.is_last_stage():
                parallel_output = sharded_output["loss"]
            else:
                parallel_output = torch.tensor(12345.0, device="cuda")

            # broadcast along pp axis
            dist.broadcast(
                parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[-1], group=plugin.pp_group
            )

        else:
            # for test without pp
            parallel_output = parallel_model(inputs_embeds=input_embeddings.to(dtype)).last_hidden_state.mean()
            parallel_optimizer.backward(parallel_output)
        parallel_optimizer.step()
        parallel_optimizer.zero_grad()
        dist.all_reduce(parallel_output, group=plugin.dp_group)

        # ===================================================================================
        # run normal model with all dp(different) inputs
        all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
        dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
        torch_output_sum = 0
        for input_data_ in all_inputs:
            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
            torch_output.backward()
            torch_output_sum += torch_output.detach()
        # avg dp grads follows zero optimizer
        for p in torch_model.parameters():
            if p.grad is not None:
                p.grad /= dp_size
        torch_optimizer.step()
        torch_optimizer.zero_grad()

        assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

    # use checkpoint to load sharded zero model
    model_dir = "./test_mixtral"
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)

    dist.barrier()
    booster.save_model(parallel_model, model_dir, shard=True)
    dist.barrier()

    saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
    check_model_equal(torch_model, saved_model)
    dist.barrier()

    if rank == world_size - 1:
        shutil.rmtree(model_dir)

    print(f"rank {dist.get_rank()} test passed")


def run_dist(rank, world_size, port):
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_zero_with_original_model()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_mixtral(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral(world_size=4)

View File

@@ -135,6 +135,68 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {  # Ulysess + Flash attention
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
            "enable_flash_attention": False,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 2,
@@ -151,8 +213,11 @@ def run_qwen2_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
        try:
            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
        except Exception as e:
            print(f"Failed config: {test_config}")
            raise e

    clear_layout_converter()
    Randomizer.reset_index()
    torch.cuda.empty_cache()
@@ -197,7 +262,11 @@ def run_qwen2_3d_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen2")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
        try:
            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
        except Exception as e:
            print(f"Failed config: {test_config}")
            raise e

    clear_layout_converter()
    Randomizer.reset_index()

View File

@@ -64,8 +64,12 @@ def exam_zero_1_2_grad_acc():
    zero1_optimizer.step()
    zero2_optimizer.step()

    zero1_optimizer._force_wait_all_gather()
    zero2_optimizer._force_wait_all_gather()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        assert not hasattr(z1p, "_all_gather_handle")
        assert torch.equal(z1p.data, z2p.data)

View File

@@ -190,6 +190,8 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool):
    # torch ddp step
    torch_optimizer.step()

    zero_optimizer._force_wait_all_gather()

    # check updated param
    for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
        loose_close(p, z1p, dtype=dtype)