[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a single cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up the construction of DimSpec objects. The difference_dict is identical for every DimSpec object, so a single shared copy is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Wang Binluo
2024-08-22 09:21:34 +08:00
committed by GitHub
parent 0a51319113
commit eea37da6fa
92 changed files with 2239 additions and 480 deletions
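
One of the features merged here is zigzag ring attention (#5905). As a rough illustration of why a zigzag split helps under a causal mask, the sketch below assumes the commonly described layout: the sequence is cut into `2 * sp_size` chunks and rank `r` keeps chunk `r` together with its mirror chunk `2 * sp_size - 1 - r`. The helper names `zigzag_positions` and `causal_work` are hypothetical and written in plain Python for clarity; this is not the code of `split_batch_zigzag` in ColossalAI.

```python
from typing import List


def zigzag_positions(seq_len: int, sp_size: int, rank: int) -> List[int]:
    """Token positions kept by `rank` under the assumed zigzag layout."""
    num_chunks = 2 * sp_size
    if seq_len % num_chunks != 0:
        raise ValueError("seq_len must be divisible by 2 * sp_size")
    chunk = seq_len // num_chunks
    mirror = num_chunks - 1 - rank  # chunk paired with this rank's front chunk
    front = range(rank * chunk, (rank + 1) * chunk)
    back = range(mirror * chunk, (mirror + 1) * chunk)
    return list(front) + list(back)


def causal_work(positions: List[int]) -> int:
    """With a causal mask, the query at position p attends to p + 1 keys."""
    return sum(p + 1 for p in positions)


if __name__ == "__main__":
    seq_len, sp_size = 16, 4
    for rank in range(sp_size):
        pos = zigzag_positions(seq_len, sp_size, rank)
        print(rank, pos, causal_work(pos))
```

Pairing an early chunk with a late one gives every rank the same number of query-key pairs under a causal mask, whereas a contiguous split would leave the last rank with far more work than the first.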

View File

@@ -22,9 +22,9 @@ COMMON_MODELS = [
"transformers_bloom_for_causal_lm",
"transformers_falcon_for_causal_lm",
"transformers_chatglm_for_conditional_generation",
"transformers_llama_for_casual_lm",
"transformers_llama_for_causal_lm",
"transformers_vit_for_masked_image_modeling",
"transformers_mistral_for_casual_lm",
"transformers_mistral_for_causal_lm",
]
IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"

View File

@@ -32,8 +32,8 @@ if HAS_COMMAND:
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
def data_gen_for_casual_lm():
# label is needed for causal lm
def data_gen_for_causal_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
@@ -44,7 +44,7 @@ if HAS_COMMAND:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = CohereConfig(
@@ -70,10 +70,10 @@ if HAS_COMMAND:
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_command_for_casual_lm",
name="transformers_command_for_causal_lm",
model_fn=lambda: transformers.CohereForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@@ -33,20 +33,21 @@ if HAS_LLAMA:
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
]
).long()
attention_mask = torch.Tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
).long()
attention_mask = torch.ones_like(input_ids)
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
def data_gen_for_casual_lm():
# label is needed for causal lm
def data_gen_for_causal_lm():
data = data_gen()
# Test padded sequence
padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
ignore_idx = -100
labels = data["input_ids"].clone()
labels[~data["attention_mask"].bool()] = ignore_idx
data["labels"] = labels
return data
@@ -55,7 +56,7 @@ if HAS_LLAMA:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
@@ -70,9 +71,17 @@ if HAS_LLAMA:
config.pad_token_id = config.eos_token_id
# register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaModel,
# transformers.LlamaForSequenceClassification,
model_zoo.register(
name="transformers_llama_for_causal_lm",
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_llama",
model_fn=lambda: transformers.LlamaModel(config),
@@ -81,14 +90,6 @@ if HAS_LLAMA:
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_llama_for_casual_lm",
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_llama_for_sequence_classification",
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
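
The padded-sequence case added to `data_gen_for_causal_lm` above masks the labels of padding tokens with `-100`, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so padded positions do not contribute to the loss. The following is a minimal list-based sketch of the same idea for a single sequence; `pad_and_label` is a hypothetical helper for illustration, not part of the test kit.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_and_label(
    input_ids: List[int], pad_len: int, pad_id: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Right-pad one sequence and build labels that skip the padding."""
    padded_ids = input_ids + [pad_id] * pad_len
    attention_mask = [1] * len(input_ids) + [0] * pad_len
    # Keep the token id where the mask is 1, otherwise mark it as ignored.
    labels = [
        tok if keep else IGNORE_INDEX
        for tok, keep in zip(padded_ids, attention_mask)
    ]
    return padded_ids, attention_mask, labels


if __name__ == "__main__":
    ids = [1, 15043, 29892, 590]
    # Mirror the diff: pad by half of the original sequence length.
    print(pad_and_label(ids, pad_len=len(ids) // 2))
```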

View File

@@ -64,7 +64,7 @@ model_zoo.register(
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_casual_lm",
name="transformers_mistral_for_causal_lm",
model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,

View File

@@ -53,6 +53,8 @@ config = MixtralConfig(
num_attention_heads=8,
num_hidden_layers=2,
vocab_size=1000,
attn_implementation="flash_attention_2",
torch_dtype="float16",
output_router_logits=True,
)

View File

@@ -33,8 +33,8 @@ if HAS_QWEN2:
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
def data_gen_for_casual_lm():
# label is needed for causal lm
def data_gen_for_causal_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
@@ -45,7 +45,7 @@ if HAS_QWEN2:
# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = Qwen2Config(
@@ -72,11 +72,11 @@ if HAS_QWEN2:
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_qwen2_for_casual_lm",
name="transformers_qwen2_for_causal_lm",
model_fn=lambda: transformers.Qwen2ForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(

View File

@@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
# TODO(ver217): add more models
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
"transformers_llama_for_casual_lm"
"transformers_llama_for_causal_lm"
).items():
err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)

View File

@@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
sub_model_zoo = model_zoo.get_sub_registry(model_name)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"

View File

@@ -47,7 +47,7 @@ def check_torch_ddp_plugin():
registry = model_zoo
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
if name == "dlrm_interactionarch" or name.startswith("simple_"):
if name in ("dlrm_interactionarch", "transformers_mixtral") or name.startswith("simple_"):
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache()

View File

@@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])

View File

@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize("shard", [False, True])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()

View File

@@ -39,7 +39,7 @@ else:
@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()

View File

@@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
if name != "transformers_llama":
continue
task_type = None
if name == "transformers_llama_for_casual_lm":
if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"

View File

@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize("model_name", ["transformers_llama_for_casual_lm"])
@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("plugin_type", ["ddp", "zero", "gemini"])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(

View File

@@ -18,9 +18,17 @@ def test_models_lazy_init(subset, default_device):
sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
("transformers_vit", "transformers_blip2", "transformers_whisper")
):
if name in (
"torchaudio_wav2vec2_base",
"torchaudio_hubert_base",
"timm_beit",
"timm_vision_transformer",
"timm_deit",
"timm_beitv2",
"timm_deit3",
"timm_convit",
"timm_tnt_b_patch16_224",
) or name.startswith(("transformers_vit", "transformers_blip2", "transformers_whisper")):
continue
check_lazy_init(entry, verbose=True, default_device=default_device)

View File

@@ -91,7 +91,7 @@ def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
task_type = None
if name == "transformers_llama_for_casual_lm":
if name == "transformers_llama_for_causal_lm":
task_type = "CAUSAL_LM"
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"

View File

@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(
# check loss
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
assert_close(torch_loss, pp_ret["loss"])
# check gradients
for i in range(num_model_chunk):
idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
# check updated param
for i in range(num_model_chunk):
idx = world_size * i + rank
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage(ignore_chunk=True):
assert torch.allclose(torch_loss, pp_ret["loss"])
assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
@pytest.mark.dist

View File

@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check loss
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"])
assert_close(torch_loss, pp_ret["loss"])
# check gradients
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
# step
torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
# check updated param
for i in range(len(sharded_model)):
idx = rank * num_local_layer + i
assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)
# forward only
with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
)
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret["loss"])
assert_close(torch_loss, pp_ret["loss"])
for layer in sharded_model:
if layer.weight.grad is None:
assert layer.weight.grad is None and layer.bias.grad is None
else:
assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))
def run_dist(
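
The two pipeline tests above replace `assert torch.allclose(...)` with `torch.testing.assert_close(...)`. Both rely on an elementwise check of the form `|actual - expected| <= atol + rtol * |expected|`, but a bare `assert` on a boolean reports nothing about which element failed, while `assert_close` raises with a diagnostic message. The sketch below imitates that behaviour on plain Python lists; `assert_close_sketch` is a hypothetical stand-in, not PyTorch's implementation, and its default tolerances are assumed to match the ones PyTorch documents for `float32`.

```python
from typing import Sequence


def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    # Elementwise criterion: |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


def assert_close_sketch(
    actual: Sequence[float],
    expected: Sequence[float],
    rtol: float = 1.3e-6,  # assumed float32 default of torch.testing.assert_close
    atol: float = 1e-5,
) -> None:
    """Raise an AssertionError that says how many elements differ and where."""
    if len(actual) != len(expected):
        raise AssertionError(f"length mismatch: {len(actual)} vs {len(expected)}")
    bad = [
        i
        for i, (a, e) in enumerate(zip(actual, expected))
        if not is_close(a, e, rtol, atol)
    ]
    if bad:
        worst = max(bad, key=lambda i: abs(actual[i] - expected[i]))
        diff = abs(actual[worst] - expected[worst])
        raise AssertionError(
            f"{len(bad)} / {len(actual)} elements mismatch; "
            f"greatest absolute difference {diff:.3e} at index {worst}"
        )
```

When a distributed test fails on only some ranks, the mismatch count and the offending index usually say more than a plain `AssertionError` with no message.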

View File

@@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
padding_mask = padding_mask[:, None, :, None].logical_not()
ref_output = ref_output.masked_fill(padding_mask, 0)
output = output.masked_fill(padding_mask, 0)
assert_close(output, ref_output, **tols)
output.mean().backward()
ref_output.mean().backward()
@@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
attn_kwargs, padding_mask = gen_kwargs_func(dtype)
for attn_func, name, need_postprocess in attn_funcs:
print(f"{dtype}, {name}, {mask_type}")
if mask_type == "padded":
pass
if need_postprocess:
check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
else:

View File

@@ -0,0 +1,186 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
@parameterize("seq_len", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
torch.cuda.manual_seed(2)
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
# Some outliers may seem large, but our errors are still lower than
# Megatron-LM context parallel's
# (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
# and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
atol = rtol = 7e-3
# Setup inputs
qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
local_qkv = split_batch_zigzag(qkv, sp_group)
q, k, v = local_qkv.unbind(dim=-3)
q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)] # (B, nHeads, Sq, D)
q.requires_grad = k.requires_grad = v.requires_grad = True
# Ring attention vs single GPU
ring_out, ring_lse = RingAttention.attention(
q,
k,
v,
sp_group,
AttnMaskType.CAUSAL,
return_softmax=True,
inner_ring_size=max(2, sp_size // 2),
# inner_ring_size=4
)
ring_out = ring_out.transpose(1, 2)
out, lse, _ = flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
)
# Check output and softmax denominator
local_out = split_batch_zigzag(out, sp_group)
local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1]) # (B, nHeads, Sq) -> (T, nHeads)
assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
assert_close(ring_out, local_out, atol=atol, rtol=rtol)
# Check grads
ring_out.sum().backward()
out.sum().backward()
ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
dqkv = qkv.grad
local_dqkv = split_batch_zigzag(dqkv, sp_group)
assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
if dist.get_rank() == 0:
print(
f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
)
@parameterize("seqlen", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_packed_seq(seqlen, bs, nheads, d, dtype):
device = get_current_device()
sp_group = dist.group.WORLD
sp_size = dist.get_world_size()
atol = rtol = 7e-3
torch.cuda.manual_seed(2)
# Prepare varlen attention mask
padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
padding_mask[:, seqlen // 2 :] = 0
input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)
# Forward
# out = ColoAttention.attention(q, k, v, **mask_info)
flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
qkv = torch.stack([flat_input] * 3, dim=1)
qkv.retain_grad()
input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
out, lse, _ = flash_attn_varlen_qkvpacked_func(
qkv,
mask_info["cu_seqlens"] * sp_size,
mask_info["max_seqlen"] * sp_size,
return_attn_probs=True,
causal=True,
# deterministic=True
)
# Test the splitting function
local_input = split_varlen_zigzag(
flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
del local_input, flat_input
q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
q_ring.retain_grad()
k_ring.retain_grad()
v_ring.retain_grad()
ring_out, ring_lse = RingAttention.attention(
q_ring,
k_ring,
v_ring,
sp_group,
**mask_info,
pad_output=False,
return_softmax=True,
# deterministic=True
)
ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
# Check output
lse = lse.transpose(0, 1)
out, lse = split_varlen_zigzag(
[out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
assert_close(lse, ring_lse, atol=atol, rtol=rtol)
assert_close(out, ring_out, atol=atol, rtol=rtol)
# Check grads
labels = torch.ones(out.shape[0], dtype=dtype, device=device)
F.mse_loss(out.sum((-2, -1)), labels).backward()
F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
dq, dk, dv = [
split_varlen_zigzag(
qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
)
for i in range(3)
]
dq_ring, dk_ring, dv_ring = [
x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
for x in (q_ring.grad, k_ring.grad, v_ring.grad)
]
assert_close(dq, dq_ring, atol=atol, rtol=rtol)
assert_close(dk, dk_ring, atol=atol, rtol=rtol)
assert_close(dv, dv_ring, atol=atol, rtol=rtol)
def launch_single_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_packed_seq()
check_ring_attn()
def launch_double_ring(rank, world_size, port):
colossalai.launch(rank, world_size, "localhost", port)
check_ring_attn()
@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
spawn(launch_single_ring, nprocs=world_size)
@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
spawn(launch_double_ring, nprocs=world_size)
if __name__ == "__main__":
test_ring_attn()
test_double_ring()
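
The tests above compare ring attention against a single-GPU reference by slicing the reference tensors with `split_batch_zigzag`. As a rough illustration of why a zigzag layout is used with causal masks, the sketch below assumes the scheme of the zigzag implementation referenced in the test comments: the sequence is cut into `2 * sp_size` chunks and rank `r` keeps chunk `r` and its mirror chunk `2 * sp_size - 1 - r`. The helper names `zigzag_indices` and `causal_work` are hypothetical and are not part of ColossalAI.

```python
def zigzag_indices(seq_len, sp_size, rank):
    """Token positions held by `rank` under the assumed zigzag split."""
    num_chunks = 2 * sp_size
    assert seq_len % num_chunks == 0, "seq_len must be divisible by 2 * sp_size"
    chunk = seq_len // num_chunks
    mirror = num_chunks - 1 - rank  # chunk paired with `rank` from the other end
    first = range(rank * chunk, (rank + 1) * chunk)
    second = range(mirror * chunk, (mirror + 1) * chunk)
    return list(first) + list(second)


def causal_work(indices):
    """Number of (query, key) pairs under a causal mask: position p sees p + 1 keys."""
    return sum(p + 1 for p in indices)


seq_len, sp_size = 16, 4
layout = [zigzag_indices(seq_len, sp_size, r) for r in range(sp_size)]
work = [causal_work(idx) for idx in layout]
for rank, (idx, w) in enumerate(zip(layout, work)):
    print(f"rank {rank}: positions {idx}, causal pairs {w}")
```

In this toy setting every rank ends up with the same number of causal query-key pairs, whereas a contiguous split would give the last rank far more work than the first.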


@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
from transformers.modeling_outputs import BaseModelOutputWithPast
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
@@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
org_output = org_model(**unshard_test_data)
org_loss = criterion(org_output)
org_loss.backward()
return org_loss, org_output, sharded_loss, sharded_output
@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(
def check_output_hidden_state(
org_output: Tensor,
sharded_output: Tensor,
org_output: BaseModelOutputWithPast,
sharded_output: BaseModelOutputWithPast,
stage_manager: Optional[PipelineStageManager] = None,
atol: float = 1e-5,
rtol: float = 1e-3,
shard_config: Optional[ShardConfig] = None,
):
org_hidden_state = org_output.last_hidden_state
@@ -315,6 +316,14 @@ def check_output_hidden_state(
else:
sharded_hidden_state = sharded_output.last_hidden_state
# Check if the output sequence is gathered before cross entropy
if shard_config is not None:
seq_dim = 1
sp_group = shard_config.sequence_parallel_process_group
sp_size = shard_config.sequence_parallel_size
if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[: org_grad.shape[0], :]
try:
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[: org_grad.shape[0], :]
except:
pass
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
@@ -404,9 +416,6 @@ def check_grad(
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
# if verbose and dist.get_rank() == 0:
# print("shard_weight", shard_weight)
# print("org_grad", org_grad)
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for suffix, check_info in check_tensors.items():
for idx, (suffix, check_info) in enumerate(check_tensors.items()):
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]


@@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
],
)
def run_command_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -321,7 +321,7 @@ def run_command_test(test_config):
],
)
def run_command_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)


@@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
):
master2working = sharded_optimizer.get_master_to_working_map()
for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
for (name, p1), p2 in zip(
llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
):
working_p = master2working[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
@@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
try:
assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
except Exception as e:
raise RuntimeError(f"Failed to check grad for {name}") from e
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
@@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "LlamaModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_output_hidden_state(
org_output,
sharded_output,
stage_manager,
atol=atol,
rtol=rtol,
shard_config=booster.plugin.shard_config,
)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
try:
check_weight(
llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
except Exception as e:
print(f"Failed config: {test_config}")
raise e
check_weight(
llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{ # Ulysess + Flash attention
# Double Ring Attention
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 4,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
"inner_ring_size": 2,
},
# Ring Attention + PP
{
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
# Ring Attention + TP
{
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring_attn",
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysses + TP
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
},
{ # Ulysses + PP
"tp_size": 1,
"pp_size": 2,
"sp_size": 2,
"num_microbatches": 2,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"enable_flash_attention": True,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 0,
"precision": "fp16",
"initial_scale": 1,
},
{ # Test ring + Flash attention
"tp_size": 2,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 1,
"sp_size": 2,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "all_to_all",
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "split_gather",
"enable_flash_attention": False,
"enable_flash_attention": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 1,
"sp_size": 1,
"num_microbatches": 1,
"enable_sequence_parallelism": True,
"sequence_parallelism_mode": "ring",
"enable_flash_attention": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
@@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_llama_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
continue
try:
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
except Exception as e:
print(f"Failed config: {test_config}")
print(f"Failed config: {test_config}, model name: {name}")
raise e
clear_layout_converter()
Randomizer.reset_index()