Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 01:55:12 +00:00
[fp8] Merge feature/fp8_comm to main branch of ColossalAI (#6016)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up the construction of DimSpec objects: the difference_dict is identical for every DimSpec instance, so a single shared copy is enough (see the sketch below).
* Fix documentation of DimSpec's difference method
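A minimal sketch of the caching idea described above, assuming a simplified DimSpec; the names and table contents are illustrative, not ColossalAI's actual API:

    # Hypothetical illustration: build the shared lookup table once at class level
    # instead of copying it into every instance.
    class DimSpec:
        _difference_dict = None  # shared by all instances, built lazily

        def __init__(self, spec: str):
            self.spec = spec
            if DimSpec._difference_dict is None:
                DimSpec._difference_dict = DimSpec._build_difference_dict()

        @staticmethod
        def _build_difference_dict():
            # Stand-in for the real (expensive) table construction.
            specs = ["R", "S0", "S1", "S01"]
            return {(a, b): (0 if a == b else 1) for a in specs for b in specs}

        def difference(self, other: "DimSpec") -> int:
            # Every instance reads the same shared table; no per-instance copy.
            return DimSpec._difference_dict[(self.spec, other.spec)]

    print(DimSpec("R").difference(DimSpec("S0")))  # -> 1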
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* set a default value for 'default_conversation' so the variable is always bound before use
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy (see the zigzag split sketch after this change list)
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
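A hedged sketch of the zigzag split used by this ring attention scheme (my illustration, not the actual split_batch_zigzag implementation): the sequence is cut into 2 * sp_size chunks and rank r keeps chunks r and 2 * sp_size - 1 - r, so causal masking leaves every rank with a similar amount of attention work.

    import torch

    def zigzag_split(batch: torch.Tensor, sp_size: int, rank: int, seq_dim: int = 1) -> torch.Tensor:
        # Return the two zigzag chunks of `batch` owned by `rank`; the sequence lies on `seq_dim`.
        chunks = batch.chunk(2 * sp_size, dim=seq_dim)
        return torch.cat([chunks[rank], chunks[2 * sp_size - 1 - rank]], dim=seq_dim)

    # Example: sequence length 8, sp_size=2 -> rank 0 holds tokens [0, 1] and [6, 7].
    x = torch.arange(8).unsqueeze(0)  # shape (1, 8)
    print(zigzag_split(x, sp_size=2, rank=0))  # tensor([[0, 1, 6, 7]])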
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
@@ -22,9 +22,9 @@ COMMON_MODELS = [
    "transformers_bloom_for_causal_lm",
    "transformers_falcon_for_causal_lm",
    "transformers_chatglm_for_conditional_generation",
-    "transformers_llama_for_casual_lm",
+    "transformers_llama_for_causal_lm",
    "transformers_vit_for_masked_image_modeling",
-    "transformers_mistral_for_casual_lm",
+    "transformers_mistral_for_causal_lm",
]

IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
@@ -32,8 +32,8 @@ if HAS_COMMAND:
        return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
@@ -44,7 +44,7 @@ if HAS_COMMAND:

    # function to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = CohereConfig(
@@ -70,10 +70,10 @@ if HAS_COMMAND:
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
-        name="transformers_command_for_casual_lm",
+        name="transformers_command_for_causal_lm",
        model_fn=lambda: transformers.CohereForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
+        data_gen_fn=data_gen_for_causal_lm,
        output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
+        loss_fn=loss_fn_for_causal_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
@@ -33,20 +33,21 @@ if HAS_LLAMA:
                [1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],
            ]
        ).long()

-        attention_mask = torch.Tensor(
-            [
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-            ]
-        ).long()
+        attention_mask = torch.ones_like(input_ids)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
        data = data_gen()
+
+        # Test padded sequence
+        padding = torch.zeros(2, data["input_ids"].shape[1] // 2, dtype=torch.long)
+        data["input_ids"] = torch.cat([data["input_ids"], padding], dim=1)
+        data["attention_mask"] = torch.cat([data["attention_mask"], padding], dim=1)
+
+        ignore_idx = -100
        labels = data["input_ids"].clone()
+        labels[~data["attention_mask"].bool()] = ignore_idx
        data["labels"] = labels
        return data
@@ -55,7 +56,7 @@ if HAS_LLAMA:

    # function to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = LlamaConfig(
@@ -70,9 +71,17 @@ if HAS_LLAMA:
    config.pad_token_id = config.eos_token_id

    # register the following models
-    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
+    # transformers.LlamaModel,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(
+        name="transformers_llama_for_causal_lm",
+        model_fn=lambda: transformers.LlamaForCausalLM(config),
+        data_gen_fn=data_gen_for_causal_lm,
+        output_transform_fn=output_transform_fn,
+        loss_fn=loss_fn_for_causal_lm,
+        model_attribute=ModelAttribute(has_control_flow=True),
+    )
+    model_zoo.register(
        name="transformers_llama",
        model_fn=lambda: transformers.LlamaModel(config),
@@ -81,14 +90,6 @@
        loss_fn=loss_fn,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
-    model_zoo.register(
-        name="transformers_llama_for_casual_lm",
-        model_fn=lambda: transformers.LlamaForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
-        output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
-        model_attribute=ModelAttribute(has_control_flow=True),
-    )
    model_zoo.register(
        name="transformers_llama_for_sequence_classification",
        model_fn=lambda: transformers.LlamaForSequenceClassification(config),
@@ -64,7 +64,7 @@ model_zoo.register(
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
-    name="transformers_mistral_for_casual_lm",
+    name="transformers_mistral_for_causal_lm",
    model_fn=lambda: transformers.MistralForCausalLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
@@ -53,6 +53,8 @@ config = MixtralConfig(
    num_attention_heads=8,
    num_hidden_layers=2,
    vocab_size=1000,
+    attn_implementation="flash_attention_2",
+    torch_dtype="float16",
    output_router_logits=True,
)
@@ -33,8 +33,8 @@ if HAS_QWEN2:
        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
        return dict(input_ids=input_ids, attention_mask=attention_mask)

-    # label is needed for casual lm
-    def data_gen_for_casual_lm():
+    # label is needed for causal lm
+    def data_gen_for_causal_lm():
        data = data_gen()
        labels = data["input_ids"].clone()
        data["labels"] = labels
@@ -45,7 +45,7 @@ if HAS_QWEN2:

    # function to get the loss
    loss_fn = lambda output: output["last_hidden_state"].mean()
-    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_causal_lm = lambda output: output["loss"]
    loss_fn_for_seq_classification = lambda output: output["logits"].mean()

    config = Qwen2Config(
@@ -72,11 +72,11 @@ if HAS_QWEN2:
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
-        name="transformers_qwen2_for_casual_lm",
+        name="transformers_qwen2_for_causal_lm",
        model_fn=lambda: transformers.Qwen2ForCausalLM(config),
-        data_gen_fn=data_gen_for_casual_lm,
+        data_gen_fn=data_gen_for_causal_lm,
        output_transform_fn=output_transform_fn,
-        loss_fn=loss_fn_for_casual_lm,
+        loss_fn=loss_fn_for_causal_lm,
        model_attribute=ModelAttribute(has_control_flow=True),
    )
    model_zoo.register(
@@ -97,7 +97,7 @@ def check_3d_plugin(init_method: str = "none", early_stop: bool = True):

    # TODO(ver217): add more models
    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.get_sub_registry(
-        "transformers_llama_for_casual_lm"
+        "transformers_llama_for_causal_lm"
    ).items():
        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
@@ -105,7 +105,7 @@ def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
    sub_model_zoo = model_zoo.get_sub_registry(model_name)
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
@@ -47,7 +47,7 @@ def check_torch_ddp_plugin():
    registry = model_zoo

    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in registry.items():
-        if name == "dlrm_interactionarch" or name.startswith("simple_"):
+        if name in ("dlrm_interactionarch", "transformers_mixtral") or name.startswith("simple_"):
            continue
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()
@@ -74,7 +74,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
@parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("tp_size", [1, 2])
@parameterize("zero_size", [2])
@@ -20,7 +20,7 @@ from tests.kit.model_zoo import model_zoo

@clear_cache_before_run()
@parameterize("shard", [False, True])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
def exam_torch_load_from_gemini(shard: bool, model_name: str):
    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = lambda x: x.mean()
@@ -39,7 +39,7 @@ else:


@parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("size_per_shard", [32])
@parameterize("test_config", TEST_CONFIGS)
@clear_cache_before_run()
@@ -149,7 +149,7 @@ def check_low_level_zero_lora_checkpointIO(
        if name != "transformers_llama":
            continue
        task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
@@ -18,7 +18,7 @@ from tests.kit.model_zoo import model_zoo


@clear_cache_before_run()
-@parameterize("model_name", ["transformers_llama_for_casual_lm"])
+@parameterize("model_name", ["transformers_llama_for_causal_lm"])
@parameterize("plugin_type", ["ddp", "zero", "gemini"])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
    (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
@@ -18,9 +18,17 @@ def test_models_lazy_init(subset, default_device):
    sub_model_zoo = model_zoo.get_sub_registry(subset, allow_empty=True)
    for name, entry in sub_model_zoo.items():
        # TODO(ver217): lazy init does not support weight norm, skip these models
-        if name in ("torchaudio_wav2vec2_base", "torchaudio_hubert_base") or name.startswith(
-            ("transformers_vit", "transformers_blip2", "transformers_whisper")
-        ):
+        if name in (
+            "torchaudio_wav2vec2_base",
+            "torchaudio_hubert_base",
+            "timm_beit",
+            "timm_vision_transformer",
+            "timm_deit",
+            "timm_beitv2",
+            "timm_deit3",
+            "timm_convit",
+            "timm_tnt_b_patch16_224",
+        ) or name.startswith(("transformers_vit", "transformers_blip2", "transformers_whisper")):
            continue
        check_lazy_init(entry, verbose=True, default_device=default_device)
@@ -91,7 +91,7 @@ def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
-        if name == "transformers_llama_for_casual_lm":
+        if name == "transformers_llama_for_causal_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -107,13 +108,13 @@ def run_pp(

    # check loss
    if stage_manager.is_last_stage(ignore_chunk=True):
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(num_model_chunk):
        idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
@@ -123,8 +124,8 @@ def run_pp(
    # check updated param
    for i in range(num_model_chunk):
        idx = world_size * i + rank
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
@@ -135,14 +136,14 @@ def run_pp(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
        )
        if stage_manager.is_last_stage(ignore_chunk=True):
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.weight.grad is None and layer.bias.grad is None
            else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


@pytest.mark.dist
@@ -6,6 +6,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch.testing import assert_close

import colossalai
from colossalai.cluster import ProcessGroupMesh
@@ -103,13 +104,13 @@ def examine_pp(num_microbatch: int, batch_size: int):

    # check loss
    if stage_manager.is_last_stage():
-        assert torch.allclose(torch_loss, pp_ret["loss"])
+        assert_close(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(len(sharded_model)):
        idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
-        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
+        assert_close(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert_close(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
@@ -119,8 +120,8 @@ def examine_pp(num_microbatch: int, batch_size: int):
    # check updated param
    for i in range(len(sharded_model)):
        idx = rank * num_local_layer + i
-        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
-        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+        assert_close(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert_close(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
@@ -131,14 +132,14 @@ def examine_pp(num_microbatch: int, batch_size: int):
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
        )
        if stage_manager.is_last_stage():
-            assert torch.allclose(torch_loss, pp_ret["loss"])
+            assert_close(torch_loss, pp_ret["loss"])

        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.weight.grad is None and layer.bias.grad is None
            else:
-                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
-                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
+                assert_close(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert_close(layer.bias.grad, torch.zeros_like(layer.bias.grad))


def run_dist(
@@ -88,6 +88,7 @@ def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_ma
        padding_mask = padding_mask[:, None, :, None].logical_not()
        ref_output = ref_output.masked_fill(padding_mask, 0)
        output = output.masked_fill(padding_mask, 0)

    assert_close(output, ref_output, **tols)
    output.mean().backward()
    ref_output.mean().backward()
@@ -128,6 +129,8 @@ def test_flash_attn_func(dtype: torch.dtype):
        attn_kwargs, padding_mask = gen_kwargs_func(dtype)
        for attn_func, name, need_postprocess in attn_funcs:
            print(f"{dtype}, {name}, {mask_type}")
+            if mask_type == "padded":
+                pass
            if need_postprocess:
                check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
            else:
tests/test_shardformer/test_layer/test_ring_attn.py (new file, 186 lines)
@@ -0,0 +1,186 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer import AttnMaskType
from colossalai.shardformer.layer.attn import AttnMaskType, RingAttention
from colossalai.shardformer.layer.utils import split_batch_zigzag, split_varlen_zigzag
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


@parameterize("seq_len", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_ring_attn(seq_len, bs, nheads, d, dtype):
    torch.cuda.manual_seed(2)
    device = get_current_device()
    sp_group = dist.group.WORLD
    sp_size = dist.get_world_size()
    # Some outliers may seem large, but our errors are still lower than
    # than Megatron-LM context parallel's
    # (https://github.com/NVIDIA/TransformerEngine/blob/33a3d02f81c56e6f7b542c09bfa86657078d57fb/tests/pytorch/fused_attn/run_fused_attn_with_cp.py#L215)
    # and the original zigzag implementation's (https://github.com/zhuzilin/ring-flash-attention/tree/main)
    atol = rtol = 7e-3

    # Setup inputs
    qkv = torch.randn(bs, seq_len, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
    local_qkv = split_batch_zigzag(qkv, sp_group)
    q, k, v = local_qkv.unbind(dim=-3)
    q, k, v = [x.squeeze(2).detach().clone().transpose(1, 2) for x in (q, k, v)]  # (B, nHeads, Sq, D)
    q.requires_grad = k.requires_grad = v.requires_grad = True

    # Ring attention vs single GPU
    ring_out, ring_lse = RingAttention.attention(
        q,
        k,
        v,
        sp_group,
        AttnMaskType.CAUSAL,
        return_softmax=True,
        inner_ring_size=max(2, sp_size // 2),
        # inner_ring_size=4
    )
    ring_out = ring_out.transpose(1, 2)
    out, lse, _ = flash_attn_qkvpacked_func(
        qkv, dropout_p=0.0, causal=True, window_size=(-1, -1), alibi_slopes=None, return_attn_probs=True
    )

    # Checkout out and softmax denominator
    local_out = split_batch_zigzag(out, sp_group)
    local_lse = split_batch_zigzag(lse, sp_group, seq_dim=-1)
    local_lse = local_lse.transpose(1, 2).contiguous().view(-1, ring_lse.shape[-1])  # (B, nHeads, Sq) -> (T, nHeads)
    assert_close(ring_lse, local_lse, atol=atol, rtol=rtol)
    assert_close(ring_out, local_out, atol=atol, rtol=rtol)

    # Check grads
    ring_out.sum().backward()
    out.sum().backward()
    ring_dq, ring_dk, ring_dv = [x.transpose(1, 2) for x in (q.grad, k.grad, v.grad)]
    dqkv = qkv.grad
    local_dqkv = split_batch_zigzag(dqkv, sp_group)

    assert_close(ring_dq, local_dqkv[:, :, 0], atol=atol, rtol=rtol)
    assert_close(ring_dk, local_dqkv[:, :, 1], atol=atol, rtol=rtol)
    assert_close(ring_dv, local_dqkv[:, :, 2], atol=atol, rtol=rtol)
    if dist.get_rank() == 0:
        print(
            f"sp_size {dist.get_world_size()}, inner ring size {dist.get_world_size(RingAttention.INNER_RING_GROUP)} passed."
        )


@parameterize("seqlen", [4096])
@parameterize("bs", [2])
@parameterize("nheads", [5])
@parameterize("d", [128])
@parameterize("dtype", [torch.bfloat16, torch.float16])
def check_packed_seq(seqlen, bs, nheads, d, dtype):
    device = get_current_device()
    sp_group = dist.group.WORLD
    sp_size = dist.get_world_size()
    atol = rtol = 7e-3
    torch.cuda.manual_seed(2)
    # Prepare varlen attention mask
    padding_mask = torch.ones((bs, seqlen), dtype=torch.int, device=device)
    padding_mask[: bs // 2, (seqlen // 4) * 3 :] = 0
    padding_mask[:, seqlen // 2 :] = 0

    input_embeds = torch.randn(bs, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True)

    # Forward
    # out = ColoAttention.attention(q, k, v, **mask_info)
    flat_input = input_embeds.view(-1, nheads, d)[padding_mask.flatten().nonzero().squeeze()]
    qkv = torch.stack([flat_input] * 3, dim=1)
    qkv.retain_grad()

    input_embeds, mask_info, _ = RingAttention.prepare_varlen_batch(padding_mask, sp_group, input_embeds)
    out, lse, _ = flash_attn_varlen_qkvpacked_func(
        qkv,
        mask_info["cu_seqlens"] * sp_size,
        mask_info["max_seqlen"] * sp_size,
        return_attn_probs=True,
        causal=True,
        # deterministic=True
    )
    # Test the splitting function
    local_input = split_varlen_zigzag(
        flat_input, mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
    )
    assert (local_input == input_embeds.view(-1, nheads, d)[mask_info["valid_indices"]]).all()
    del local_input, flat_input

    q_ring, k_ring, v_ring = [input_embeds.clone().transpose(1, 2) for _ in range(3)]
    q_ring.retain_grad()
    k_ring.retain_grad()
    v_ring.retain_grad()

    ring_out, ring_lse = RingAttention.attention(
        q_ring,
        k_ring,
        v_ring,
        sp_group,
        **mask_info,
        pad_output=False,
        return_softmax=True,
        # deterministic=True
    )
    ring_out = ring_out.transpose(1, 2).reshape(-1, nheads, d)
    # Check output
    lse = lse.transpose(0, 1)
    out, lse = split_varlen_zigzag(
        [out, lse], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
    )
    assert_close(lse, ring_lse, atol=atol, rtol=rtol)
    assert_close(out, ring_out, atol=atol, rtol=rtol)

    # Check grads
    labels = torch.ones(out.shape[0], dtype=dtype, device=device)
    F.mse_loss(out.sum((-2, -1)), labels).backward()
    F.mse_loss(ring_out.sum((-2, -1)), labels[: ring_out.shape[0]]).backward()
    dq, dk, dv = [
        split_varlen_zigzag(
            qkv.grad[:, i], mask_info["cu_seqlens"] * sp_size, sp_group, mask_info["max_seqlen"] * sp_size
        )
        for i in range(3)
    ]
    dq_ring, dk_ring, dv_ring = [
        x.transpose(1, 2).reshape(-1, nheads, d)[mask_info["valid_indices"]]
        for x in (q_ring.grad, k_ring.grad, v_ring.grad)
    ]

    assert_close(dq, dq_ring, atol=atol, rtol=rtol)
    assert_close(dk, dk_ring, atol=atol, rtol=rtol)
    assert_close(dv, dv_ring, atol=atol, rtol=rtol)


def launch_single_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_packed_seq()
    check_ring_attn()


def launch_double_ring(rank, world_size, port):
    colossalai.launch(rank, world_size, "localhost", port)
    check_ring_attn()


@rerun_if_address_is_in_use()
@parameterize("world_size", [2])
def test_ring_attn(world_size):
    spawn(launch_single_ring, nprocs=world_size)


@rerun_if_address_is_in_use()
@parameterize("world_size", [4])
def test_double_ring(world_size):
    spawn(launch_double_ring, nprocs=world_size)


if __name__ == "__main__":
    test_ring_attn()
    test_double_ring()
@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Adam, Optimizer
from torch.testing import assert_close
+from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
@@ -259,7 +260,6 @@ def run_forward_backward_with_hybrid_plugin(
    org_output = org_model(**unshard_test_data)
    org_loss = criterion(org_output)
    org_loss.backward()

    return org_loss, org_output, sharded_loss, sharded_output
@@ -302,11 +302,12 @@ def run_forward_backward_with_low_level_zero_plugin(


def check_output_hidden_state(
-    org_output: Tensor,
-    sharded_output: Tensor,
+    org_output: BaseModelOutputWithPast,
+    sharded_output: BaseModelOutputWithPast,
    stage_manager: Optional[PipelineStageManager] = None,
    atol: float = 1e-5,
    rtol: float = 1e-3,
+    shard_config: Optional[ShardConfig] = None,
):
    org_hidden_state = org_output.last_hidden_state
@@ -315,6 +316,14 @@ def check_output_hidden_state(
    else:
        sharded_hidden_state = sharded_output.last_hidden_state

+    # Check if the output sequence is gathered before cross entropy
+    if shard_config is not None:
+        seq_dim = 1
+        sp_group = shard_config.sequence_parallel_process_group
+        sp_size = shard_config.sequence_parallel_size
+        if org_hidden_state.shape[seq_dim] == sharded_hidden_state.shape[seq_dim] * sp_size:
+            org_hidden_state = org_hidden_state.chunk(sp_size, dim=seq_dim)[dist.get_rank(sp_group)]
+
    assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol)
@@ -374,8 +383,11 @@ def get_grad_tensors_for_check(
        shard_grad = torch.cat(shard_grad_list, dim=dim)

        # embedding may be resized when using tensor parallel
-        if shard_grad.shape[0] > org_grad.shape[0]:
-            shard_grad = shard_grad[: org_grad.shape[0], :]
+        try:
+            if shard_grad.shape[0] > org_grad.shape[0]:
+                shard_grad = shard_grad[: org_grad.shape[0], :]
+        except:
+            pass
        if verbose and dist.get_rank() == 0:
            print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
@@ -404,9 +416,6 @@ def check_grad(
        org_grad = getattr_(org_model, suffix).weight.grad
        shard_grad = getattr_(sharded_model, suffix).weight.grad
        shard_weight = getattr_(sharded_model, suffix).weight
-        # if verbose and dist.get_rank() == 0:
-        #     print("shard_weight", shard_weight)
-        #     print("org_grad", org_grad)
        if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
            shard_grad_list = [torch.zeros_like(shard_grad).to("cuda") for _ in range(dist.get_world_size(tp_group))]
            dist.all_gather(shard_grad_list, shard_grad, tp_group)
@@ -440,7 +449,7 @@ def check_all_grad_tensors(check_tensors):
        "org_grad": tensor to be compared from the original model
        "shard_grad": tensor to be compared from the sharded model
    """
-    for suffix, check_info in check_tensors.items():
+    for idx, (suffix, check_info) in enumerate(check_tensors.items()):
        org_grad = check_info["org_grad"]
        shard_grad = check_info["shard_grad"]
        rtol = check_info["rtol"]
@@ -271,7 +271,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    ],
)
def run_command_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -321,7 +321,7 @@ def run_command_test(test_config):
    ],
)
def run_command_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -63,7 +63,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
    ):
        master2working = sharded_optimizer.get_master_to_working_map()
-        for p1, p2 in zip(llama_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
+        for (name, p1), p2 in zip(
+            llama_model.named_parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]
+        ):
            working_p = master2working[id(p2)]
            grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
            grad_index = (
@@ -73,7 +75,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            )
            grad = grads[grad_index]
            sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
-            assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+            try:
+                assert_close(sharded_grad, grad[: sharded_grad.shape[0]], atol=5e-3, rtol=5e-3, check_dtype=False)
+            except Exception as e:
+                raise RuntimeError(f"Failed to check grad for {name}") from e

    # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
    grads_to_check = {}
@@ -114,89 +119,130 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            atol, rtol = 5e-3, 5e-3

        if org_model.__class__.__name__ == "LlamaModel":
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+            check_output_hidden_state(
+                org_output,
+                sharded_output,
+                stage_manager,
+                atol=atol,
+                rtol=rtol,
+                shard_config=booster.plugin.shard_config,
+            )
        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

    # check weights
    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
        if test_config["precision"] == "fp32":
            atol, rtol = 1e-4, 1e-3
        else:
            atol, rtol = 5e-3, 5e-3
-        try:
-            check_weight(
-                llama_model,
-                shard_llama_model,
-                col_layer_for_check,
-                tp_group,
-                atol=atol,
-                rtol=rtol,
-                dim=1,
-                verbose=False,
-            )
-        except Exception as e:
-            print(f"Failed config: {test_config}")
-            raise e
+        check_weight(
+            llama_model,
+            shard_llama_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )

    # check grads
    check_all_grad_tensors(grads_to_check)

    torch.cuda.empty_cache()


@parameterize(
    "test_config",
    [
        { # Ulysess + Flash attention
        # Double Ring Attention
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 4,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring_attn",
            "use_lazy_init": True,
            "zero_stage": 0,
            "precision": "fp16",
            "initial_scale": 1,
            "inner_ring_size": 2,
        },
        # Ring Attention + PP
        {
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring_attn",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        # Ring Attention + TP
        {
            "tp_size": 2,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring_attn",
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        { # Ulysess + TP
            "tp_size": 2,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 0,
            "precision": "fp16",
            "initial_scale": 1,
        },
        { # Ulysess + PP
            "tp_size": 1,
            "pp_size": 2,
            "sp_size": 2,
            "num_microbatches": 2,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "enable_flash_attention": True,
            "enable_all_optimization": True,
            "use_lazy_init": True,
            "zero_stage": 0,
            "precision": "fp16",
            "initial_scale": 1,
        },
        { # Test ring + Flash attention
            "tp_size": 2,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 1,
            "pp_size": 1,
            "sp_size": 2,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "all_to_all",
            "use_lazy_init": True,
            "zero_stage": 1,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 4,
            "pp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "split_gather",
-            "enable_flash_attention": False,
+            "enable_flash_attention": True,
            "use_lazy_init": True,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 1,
            "sp_size": 1,
            "num_microbatches": 1,
            "enable_sequence_parallelism": True,
            "sequence_parallelism_mode": "ring",
            "enable_flash_attention": True,
            "use_lazy_init": True,
            "zero_stage": 2,
            "precision": "fp16",
            "initial_scale": 1,
        },
        {
            "tp_size": 2,
            "pp_size": 2,
@@ -240,12 +286,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
)
def run_llama_test(test_config):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")

    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
+            continue
        try:
            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
        except Exception as e:
-            print(f"Failed config: {test_config}")
+            print(f"Failed config: {test_config}, model name: {name}")
            raise e
    clear_layout_converter()
    Randomizer.reset_index()