1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-05-07 07:58:27 +00:00
ColossalAI/colossalai/shardformer/layer/_operation.py
duanjunwen e0c68ab6d3
[Zerobubble] merge main. ()
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [feat] add more test;

* [fix] fix require_grad & deallocate call;

* [fix] updatw bwd b&w input; dict --> list[torch.Tensor]

* [fix] fix bwd w input;

* [fix] fix mem assert;

* [fix] fix input_tensors buffer  append input_obj(dict) --> Tuple (microbatch, input_obj) , and all bwd b related cal logic;

* [fix] use tree_flatten replace dict traverse;

* [fix] rm comments;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix detach clone release order;

* [fix] fix ci --> oom in 4096 hidden dim;

* [fix] fix dumb clone;

* [fix] fix detach_output_obj clone;

* [fix] fix stage_indices;

* [fix] fix traverse; traverse dict --> traverse tensor List;

* [fix] fix zerobubble; support shardformer model type;

* [fix] rm comments;

* [fix] fix test_pipeline_utils ci;

* [fix] remove duplicate arg; rm comments;

* [fix] remove chunk 0 stage 0 bwd b; u don't have to cal micrbatch's dx;

* [fix] rm print & comments;

* [plugin] hybrid support zero bubble pipeline ()

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] zerobubble support moehybridplugin;

* [feat] update optimizer bwd; ä¸

* [fix] fix build ci;

* [zerobubble] rebase main ()

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor ()

* [fp8] add fp8 comm for low level zero

* [test] add zero fp8 test case

* [Feature] llama shardformer fp8 support ()

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

* [FP8] rebase main ()

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support ()

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file ()

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos ()

* [Feature] remove modeling file, use auto config. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support ()

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for  ()

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama ()

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM ()

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap ()

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% ()

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp ()

* [compatibility] support torch 2.2 ()

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 ()

* [misc] support torch2.3 ()

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version ()

* [plugin] support all-gather overlap for hybrid parallel ()

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat ()

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value ()

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module ()

* [FIX BUG] convert env param to int in ()

* [Hotfix] Fix ZeRO typo 

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends ()

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask ()

* [shardformer] hotfix attn mask ()

* [Feat] Distrifusion Acceleration Support for Diffusion Inference ()

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params ()

* [release] update version ()

* [Chat] Fix lora ()

* fix merging

* remove filepath

* fix style

* Update README.md ()

* [hotfix] Remove unused plan section ()

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support ()

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval ()

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin ()

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor ()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>

* [fp8]support all2all fp8 ()

* support all2all fp8

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] add fp8 linear ()

* [fp8] add fp8 linear

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [fp8] support fp8 amp for hybrid parallel plugin ()

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility

* fix ()

* [Feature]: support FP8 communication in DDP, FSDP, Gemini ()

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test ci]Feature/fp8 comm ()

* fix

* fix

* fix

* [fp8] support gemini plugin ()

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark

* [fp8] use torch compile (torch >= 2.3.0) ()

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig

* [fp8]Moe support fp8 communication ()

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] support hybrid parallel plugin ()

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix

* [fp8] refactor fp8 linear with compile ()

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test

* [fp8] support asynchronous FP8 communication ()

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 ()

* [fp8] linear perf enhancement

* [fp8]update reduce-scatter test ()

* fix

* fix

* fix

* fix

* [fp8] add use_fp8 option for MoeHybridParallelPlugin ()

* [fp8] zero support fp8 linear. ()

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* [fp8] Merge feature/fp8_comm to main branch of Colossalai ()

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support ()

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file ()

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos ()

* [Feature] remove modeling file, use auto config. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support ()

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for  ()

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama ()

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM ()

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap ()

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% ()

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp ()

* [compatibility] support torch 2.2 ()

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 ()

* [misc] support torch2.3 ()

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version ()

* [plugin] support all-gather overlap for hybrid parallel ()

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat ()

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value ()

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module ()

* [FIX BUG] convert env param to int in ()

* [Hotfix] Fix ZeRO typo 

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends ()

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask ()

* [shardformer] hotfix attn mask ()

* [Feat] Distrifusion Acceleration Support for Diffusion Inference ()

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params ()

* [release] update version ()

* [Chat] Fix lora ()

* fix merging

* remove filepath

* fix style

* Update README.md ()

* [hotfix] Remove unused plan section ()

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support ()

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval ()

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin ()

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link ()

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme ()

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition ()

* [plugin] add cast inputs option for zero ()

* [pre-commit.ci] pre-commit autoupdate ()

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem ()

* [Feature] Zigzag Ring attention ()

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility ()

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins ()

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train_dpo.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update low_level_zero_plugin.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [CI] Remove triton version for compatibility bug; update req torch >=2.2 ()

* remove triton version

* remove torch 2.2

* remove torch 2.1

* debug

* remove 2.1 build tests

* require torch >=2.2

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [plugin] hotfix zero plugin ()

* [plugin] hotfix zero plugin

* [plugin] hotfix zero plugin

* [Colossal-LLaMA] Refactor latest APIs ()

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB infor

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add fused norm ()

* [FP8] unsqueeze scale to make it compatible with torch.compile ()

* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg ()

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove deprecated install ()

* remove deprecated install

* remove unused folder

* [fp8] optimize all-gather ()

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring

* [fp8] fix linear hook ()

* [fp8] disable all_to_all_fp8 in intranode  ()

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version ()

* [release] update version

* [devops] update comp test

* [devops] update comp test debug

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [Feature] Split cross-entropy computation in SP ()

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] moe hybrid parallelism benchmark & follow-up fix ()

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

* [fp8] hotfix backward hook ()

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation

* [doc] update sp doc ()

* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix the sp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the attn

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [fp8] fix missing fp8_comm flag in mixtral ()

* fix

* fix

* fix

* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 ()

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list

* [doc] FP8 training and communication document ()

* Add FP8 training and communication document

* add fp8 docstring for plugins

* fix typo

* fix typo

* fix

* fix

* [moe] add parallel strategy for shared_expert && fix test for deepseek ()

* [ColossalEval] support for vllm ()

* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove dupilicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [fix] fix require_grad & deallocate call;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix zerobubble; support shardformer model type;

* [fix] fix test_pipeline_utils ci;

* [plugin] hybrid support zero bubble pipeline ()

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [plugin] hybrid support zero bubble pipeline ()

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [fix] fix mixtral policy;

* [fix] fix mixtral policy;

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [zero bubble] support zero ()

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor ()

* [fp8] add fp8 comm for low level zero

* [test] add zero fp8 test case

* [Feature] llama shardformer fp8 support ()

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

* [FP8] rebase main ()

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support ()

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file ()

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos ()

* [Feature] remove modeling file, use auto config. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support ()

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for  ()

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama ()

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM ()

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap ()

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% ()

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp ()

* [compatibility] support torch 2.2 ()

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 ()

* [misc] support torch2.3 ()

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version ()

* [plugin] support all-gather overlap for hybrid parallel ()

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat ()

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value ()

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module ()

* [FIX BUG] convert env param to int in ()

* [Hotfix] Fix ZeRO typo 

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends ()

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask ()

* [shardformer] hotfix attn mask ()

* [Feat] Distrifusion Acceleration Support for Diffusion Inference ()

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params ()

* [release] update version ()

* [Chat] Fix lora ()

* fix merging

* remove filepath

* fix style

* Update README.md ()

* [hotfix] Remove unused plan section ()

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support ()

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval ()

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin ()

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor ()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>

* [fp8]support all2all fp8 ()

* support all2all fp8

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] add fp8 linear ()

* [fp8] add fp8 linear

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [fp8] support fp8 amp for hybrid parallel plugin ()

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility

* fix ()

* [Feature]: support FP8 communication in DDP, FSDP, Gemini ()

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test ci]Feature/fp8 comm ()

* fix

* fix

* fix

* [fp8] support gemini plugin ()

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark

* [fp8] use torch compile (torch >= 2.3.0) ()

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig

* [fp8]Moe support fp8 communication ()

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] support hybrid parallel plugin ()

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix

* [fp8] refactor fp8 linear with compile ()

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test

* [fp8] support asynchronous FP8 communication ()

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 ()

* [fp8] linear perf enhancement

* [fp8]update reduce-scatter test ()

* fix

* fix

* fix

* fix

* [fp8] add use_fp8 option for MoeHybridParallelPlugin ()

* [fp8] zero support fp8 linear. ()

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* [fp8] Merge feature/fp8_comm to main branch of Colossalai ()

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support ()

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file ()

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos ()

* [Feature] remove modeling file, use auto config. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. ()

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support ()

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for  ()

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama ()

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM ()

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap ()

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% ()

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp ()

* [compatibility] support torch 2.2 ()

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 ()

* [misc] support torch2.3 ()

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version ()

* [plugin] support all-gather overlap for hybrid parallel ()

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat ()

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value ()

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module ()

* [FIX BUG] convert env param to int in ()

* [Hotfix] Fix ZeRO typo 

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends ()

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask ()

* [shardformer] hotfix attn mask ()

* [Feat] Distrifusion Acceleration Support for Diffusion Inference ()

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params ()

* [release] update version ()

* [Chat] Fix lora ()

* fix merging

* remove filepath

* fix style

* Update README.md ()

* [hotfix] Remove unused plan section ()

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support ()

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval ()

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin ()

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link ()

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme ()

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition ()

* [plugin] add cast inputs option for zero ()

* [pre-commit.ci] pre-commit autoupdate ()

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem ()

* [Feature] Zigzag Ring attention ()

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility ()

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale ()

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins ()

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train_dpo.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update low_level_zero_plugin.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [CI] Remove triton version for compatibility bug; update req torch >=2.2 ()

* remove triton version

* remove torch 2.2

* remove torch 2.1

* debug

* remove 2.1 build tests

* require torch >=2.2

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [plugin] hotfix zero plugin ()

* [plugin] hotfix zero plugin

* [plugin] hotfix zero plugin

* [Colossal-LLaMA] Refactor latest APIs ()

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB infor

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add fused norm ()

* [FP8] unsqueeze scale to make it compatible with torch.compile ()

* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg ()

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove deprecated install ()

* remove deprecated install

* remove unused folder

* [fp8] optimize all-gather ()

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring

* [fp8] fix linear hook ()

* [fp8] disable all_to_all_fp8 in intranode  ()

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version ()

* [release] update version

* [devops] update comp test

* [devops] update comp test debug

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [Feature] Split cross-entropy computation in SP ()

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] moe hybrid parallelism benchmark & follow-up fix ()

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

* [fp8] hotfix backward hook ()

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation

* [doc] update sp doc ()

* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix the sp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the attn

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [fp8] fix missing fp8_comm flag in mixtral ()

* fix

* fix

* fix

* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 ()

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list

* [doc] FP8 training and communication document ()

* Add FP8 training and communication document

* add fp8 docstring for plugins

* fix typo

* fix typo

* fix

* fix

* [moe] add parallel strategy for shared_expert && fix test for deepseek ()

* [ColossalEval] support for vllm ()

* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove dupilicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [fix] fix require_grad & deallocate call;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix zerobubble; support shardformer model type;

* [fix] fix test_pipeline_utils ci;

* [plugin] hybrid support zero bubble pipeline ()

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [plugin] hybrid support zero bubble pipeline ()

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline ()

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zbv support zero

* fix

* fix

* fix

---------

Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;

* [feat] support zbv in mixtral benchmark; ()

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;

* [fix] fix fp8 args in HybridParallel

* [fix] fix hybridparall use_fp8 config

* [fix] fix use_fp8 flag

* [fix] fix model zoo init

* [feat] support no_tp Linear for sharderformer.llama

* [fix] fix zbv llama pp4

* [fix] fix send_tensor_metadata & send_grad_metadata;

* [feat] fix testcase;

* [feat] support mixtral policy with zbv tp_Linear & non_tp_Linear

* [feat] update mixtral policy & bert policy for zerobubble

* [fix] fix p2p error in zbv

* [fix] fix attn

* [fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;

* [fix] fix zbv wait_handle

* [fix] rm debug info; update llama policy; update wait handle

* [fix] fix test_lora

* [fix] fix test_lora in llama policy

* [fix] fix wait handle in run_fwd_bwd

* [fix] remove debug info;

* [fix] rm unused comments

* [fix] fix fp8 overlap code

* [fix] fix yml file & v_schedule comments

* [fix] rm fwd only meta cache comments;

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
2024-11-19 19:00:36 +08:00

1193 lines
45 KiB
Python

import functools
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.pipeline.weight_grad_store import WeightGradStore
from .utils import is_share_sp_tp
try:
import fused_mix_prec_layer_norm_cuda
except:
fused_mix_prec_layer_norm_cuda = None
try:
import fused_weight_gradient_mlp_cuda
_grad_accum_fusion_available = True
except ImportError:
_grad_accum_fusion_available = False
from colossalai.quantization.fp8 import (
all_gather_fp8,
all_reduce_fp8,
all_to_all_fp8,
all_to_all_single_fp8,
reduce_scatter_fp8,
)
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""Layernorm
Args:
input: input matrix.
weight: weight matrix.
bias: bias matrix.
normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
eps: a value added to the denominator for numerical stability
"""
@staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
)
return grad_input, grad_weight, grad_bias, None, None
class MatmulWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
output = torch.matmul(input_, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if fp8_communication or not ctx.async_grad_allreduce:
_reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
elif ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
if fp8_communication:
all_reduce_fp8(grad_input, group=ctx.process_group)
else:
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
class LinearWithGradAccum(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""
@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
use_zbv = ctx.use_zbv
def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)
def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)
total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None, None
def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=None, gather_dim=1, keep_item=False):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# output_tensors = [torch.empty((input_shape[0], input_shape[1], weight_shape[0])) for _ in range(group_size)]
# initialization of ring communication
recv_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
send_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
recv_tensors = {}
send_tensors = {}
for k, v in input_to_gather.items():
recv_tensors[k] = torch.empty_like(v)
send_tensors[k] = v.clone()
def communicate_step():
comm_ops = []
for k in recv_tensors:
comm_ops.append(dist.P2POp(dist.irecv, recv_tensors[k], recv_rank, group=process_group))
comm_ops.append(dist.P2POp(dist.isend, send_tensors[k], send_rank, group=process_group))
return dist.batch_isend_irecv(comm_ops)
def switch_step():
for k in recv_tensors:
send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]
input_tensors = []
output_tensors = []
handles = communicate_step()
# first round: special case, retrive from local tensor
input_tensors.append(input_to_gather)
output_tensors.append(func(**input_to_gather, **input_local))
for i in range(group_size - 2):
for handle in handles:
handle.wait()
switch_step()
handles = communicate_step()
# actual computation
input_tensors.append(send_tensors)
output_tensors.append(func(**send_tensors, **input_local))
# final round: special case, no need to send/recv again
for handle in handles:
handle.wait()
input_tensors.append(send_tensors)
output_tensors.append(func(**recv_tensors, **input_local))
gathered_input = {}
for k in input_to_gather:
input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
gathered_input[k] = torch.cat(input_shards, dim=gather_dim)
gathered_output = torch.cat(
output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
)
return gathered_output, gathered_input
class _GatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, process_group, dim, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.fp8_communication = fp8_communication
return _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
# do reduce-scatter
new_shape = list(grad_output.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
grad_list = [
item.contiguous() for item in torch.chunk(grad_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(new_shape, dtype=grad_output.dtype, device=grad_output.device)
if fp8_communication:
reduce_scatter_fp8(output, grad_list, group=process_group, fp8_format="e5m2")
else:
dist.reduce_scatter(output, grad_list, group=process_group)
return output, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whether to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
if ring is True:
input_to_gather = {"input": input_}
input_local = {"weight": weight}
output, input_dict = _ring_as_gather(
F.linear,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
)
ctx.gathered_input = input_dict["input"]
if bias is not None:
output += bias
else:
input_parallel = _gather(input_, dim, process_group)
ctx.gathered_input = input_parallel
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py
if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
def _ring_as_reducescatter(
func, input_to_reducescatter=None, input_local=None, process_group=None, reducescatter_dim=1
):
# currently only support one single tensor as output
group_size = dist.get_world_size(process_group)
cur_rank = dist.get_rank(process_group)
# initialization of ring communication
recv_rank = cur_rank - 1 if cur_rank > 0 else group_size - 1
send_rank = cur_rank + 1 if cur_rank + 1 < group_size else 0
rank_map = list(dist.get_process_group_ranks(process_group))
recv_rank = rank_map[recv_rank]
send_rank = rank_map[send_rank]
input_tensors = []
for _ in range(group_size):
input_tensors.append({})
for k, v in input_to_reducescatter.items():
input_shape = v.shape
assert input_shape[reducescatter_dim] % group_size == 0
_input_tensors = list(torch.split(v, input_shape[reducescatter_dim] // group_size, dim=reducescatter_dim))
for i in range(group_size):
input_tensors[i][k] = _input_tensors[i]
input_tensors = input_tensors[cur_rank:] + input_tensors[:cur_rank]
input_tensors.reverse()
output_tensor = func(**input_tensors[0], **input_local)
recv_tensor = torch.empty_like(output_tensor)
send_tensor = output_tensor.clone()
def communicate_step():
recv_op = dist.P2POp(dist.irecv, recv_tensor, recv_rank, group=process_group)
send_op = dist.P2POp(dist.isend, send_tensor, send_rank, group=process_group)
return dist.batch_isend_irecv([recv_op, send_op])
handles = communicate_step()
# first round: special case, retrive from local tensor
for i in range(group_size - 2):
# actual computation
output_tensor = func(**input_tensors[i + 1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
tmp_tensor = send_tensor
send_tensor = output_tensor
output_tensor = tmp_tensor
handles = communicate_step()
# final round: special case, no need to send/recv again
output_tensor = func(**input_tensors[-1], **input_local)
for handle in handles:
handle.wait()
output_tensor += recv_tensor
return output_tensor
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward with ring
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whther to overlap the all_gather op and gradient calculate in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, dim, ring):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.dim = dim
if ring is True:
input_to_reducescatter = {"input": input_}
input_local = {"weight": weight}
if bias is not None:
input_to_reducescatter["bias"] = bias
output = _ring_as_reducescatter(
F.linear,
input_to_reducescatter=input_to_reducescatter,
input_local=input_local,
process_group=process_group,
)
else:
if bias is not None:
partial_output = F.linear(input_, weight, bias)
else:
partial_output = F.linear(input_, weight)
output_shape = list(partial_output.shape)
assert (
output_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({output_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
output_shape[dim] = output_shape[dim] // dist.get_world_size(process_group)
output_list = [
item.contiguous() for item in torch.chunk(partial_output, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(output_shape, dtype=partial_output.dtype, device=partial_output.device).contiguous()
dist.reduce_scatter(output, output_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
if use_bias:
bias = bias.view(bias.shape)
grad_output = _gather(grad_output, dim, process_group)
# TODO Need to fully optimize
total_input = input_
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.reshape(-1, total_input.shape[-1])
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
return grad_input, grad_weight, grad_bias, None, None, None
class _ReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Reduce-scatter input from sequence parallel in forward and gather gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim, fp8_communication=False):
ctx.dim = dim
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
# do reduce-scatter
new_shape = list(input_.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
if fp8_communication:
reduce_scatter_fp8(output, input_list, group=process_group, fp8_format="e4m3")
else:
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
fp8_communication = ctx.fp8_communication
return _gather(grad_output, dim, process_group, fp8_communication, fp8_format="e5m2"), None, None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.fp8_communication = fp8_communication
if ring is True:
input_to_gather = {"input": input_}
input_local = {"other": weight}
output, input_dict = _ring_as_gather(
torch.matmul,
input_to_gather=input_to_gather,
input_local=input_local,
process_group=process_group,
gather_dim=dim,
)
ctx.gathered_input = input_dict["input"]
else:
input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")
ctx.gathered_input = input_parallel
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
weight = weight.view(weight.shape)
if use_bias:
bias = bias.view(bias.shape)
input_parallel = ctx.gathered_input
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
return output, grad_weight, grad_bias, None, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the corresponding chuck to the rank.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
ctx.fp8_communication = fp8_communication
return _split(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return (
_gather(grad_output, ctx.dim, ctx.process_group, ctx.fp8_communication, fp8_format="e5m2"),
None,
None,
None,
None,
)
class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_: input matrix.
process_group: communication group.
"""
@staticmethod
def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
ctx.grad_scale = grad_scale
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return grad_output, None, None, None
class _ReduceBackward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_: input matrix.
parallel_mode: parallel mode.
"""
@staticmethod
def forward(ctx, input_, process_group, fp8_communication=False):
ctx.process_group = process_group
ctx.fp8_communication = fp8_communication
return input_
@staticmethod
def backward(ctx, grad_output):
fp8_communication = ctx.fp8_communication
return _reduce(grad_output, ctx.process_group, fp8_communication, fp8_format="e5m2"), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
"""
@staticmethod
def forward(ctx, input_, dim, process_group, grad_scale=None, fp8_communication=False):
ctx.process_group = process_group
ctx.dim = dim
ctx.grad_scale = grad_scale
return _gather(input_, dim, process_group, fp8_communication=fp8_communication, fp8_format="e4m3")
@staticmethod
def backward(ctx, grad_output):
if ctx.grad_scale is not None:
grad_output = grad_output * ctx.grad_scale
return _split(grad_output, ctx.dim, ctx.process_group), None, None, None, None
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim, fp8_communication=False):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.fp8_communication = fp8_communication
world_size = dist.get_world_size(process_group)
bsz = input_.shape[0]
# using all_to_all_single when batch size is 1
if bsz == 1:
return _all_to_all_single(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e4m3",
)
else:
return _all_to_all(
input_,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e4m3",
)
@staticmethod
def backward(ctx, grad_output):
process_group = ctx.process_group
scatter_dim = ctx.gather_dim
gather_dim = ctx.scatter_dim
fp8_communication = ctx.fp8_communication
world_size = dist.get_world_size(process_group)
bsz = grad_output.shape[0]
if bsz == 1:
return_grad = _all_to_all_single(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
else:
return_grad = _all_to_all(
grad_output,
world_size,
process_group,
scatter_dim,
gather_dim,
fp8_communication=fp8_communication,
fp8_format="e5m2",
)
return (return_grad, None, None, None, None)
class HookParameter(torch.autograd.Function):
"""In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(weight, bias)
output = input
return output
@staticmethod
def backward(ctx, grad_output):
weight, bias = ctx.saved_tensors
if weight is not None:
weight = weight.view(weight.shape)
if bias is not None:
bias = bias.view(bias.shape)
return grad_output, None, None
def hook_parameter_in_backward(input, weight=None, bias=None):
return HookParameter.apply(input, weight, bias)
def _reduce(input_, process_group, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
return input_
else:
if fp8_communication:
all_reduce_fp8(input_, group=process_group, fp8_format=fp8_format)
else:
dist.all_reduce(input_, group=process_group)
return input_
def _split(input_, dim=-1, process_group=None):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, (
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
f"cannot split tensor evenly"
)
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = dist.get_rank(process_group)
output = tensor_list[rank].clone().contiguous()
return output
def _gather(input_, dim=-1, process_group=None, fp8_communication=False, fp8_format="e5m2"):
# skip if only one rank involved
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
if fp8_communication:
all_gather_fp8(tensor_list, input_, fp8_format=fp8_format, group=process_group)
else:
dist.all_gather(tensor_list, input_, group=process_group)
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
def _reduce_scatter(input_, dim=1, process_group=None):
"""Do reduce-scatter operation.
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
dim (int): The dimension to perform reduce-scatter.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
# reduce-scatter
new_shape = list(input_.shape)
assert (
new_shape[dim] % dist.get_world_size(process_group) == 0
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_, group=process_group)
return output
def _all_to_all(input_, world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"):
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
if fp8_communication:
all_to_all_fp8(output_list, input_list, group=group, fp8_format=fp8_format)
else:
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
def _all_to_all_single(
input_, seq_world_size, group, scatter_dim, gather_dim, fp8_communication=False, fp8_format="e5m2"
):
inp_shape = list(input_.shape)
inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
if scatter_dim < 2:
input_t = input_.reshape([seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]).contiguous()
else:
input_t = (
input_.reshape([-1, seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :])
.transpose(0, 1)
.contiguous()
)
output = torch.empty_like(input_t)
if fp8_communication:
all_to_all_single_fp8(output, input_t, group=group, fp8_format=fp8_format)
else:
dist.all_to_all_single(output, input_t, group=group)
if scatter_dim < 2:
output = output.transpose(0, 1).contiguous()
return output.reshape(
inp_shape[:gather_dim]
+ [
inp_shape[gather_dim] * seq_world_size,
]
+ inp_shape[gather_dim + 1 :]
).contiguous()
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
return MatmulWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
)
def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)
def linear_with_grad_accum(input_, weight, bias, async_grad_allreduce, use_zbv=False):
return LinearWithGradAccum.apply(input_, weight, bias, async_grad_allreduce, use_zbv)
def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
return _LinearWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
)
def gather_forward_reducescatter_backward(input_, process_group, dim, fp8_communication=False):
return _GatherForwardReduceScatterBackward.apply(input_, process_group, dim, fp8_communication)
def reducescatter_forward_gather_backward(input_, process_group, dim, fp8_communication=False):
return _ReduceScatterForwardGatherBackward.apply(input_, process_group, dim, fp8_communication)
def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, process_group=None, dim=1, ring=False):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, weight, bias, process_group, dim, ring)
def matmul_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
return _MatmulWithGatherForwardReduceScatterBackward.apply(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
)
def gather_forward_split_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _GatherForwardSplitBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, fp8_communication=False):
return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication)
def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False):
return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication)
def reduce_backward(input_, process_group, fp8_communication=False):
return _ReduceBackward.apply(input_, process_group, fp8_communication)
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1, fp8_communication=False):
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim, fp8_communication)
def gather_sp_output(hidden_states, shard_config, sp_dim=1):
"""
Gather the output of the last layer for cross entropy computation
"""
sp_group = shard_config.sequence_parallel_process_group
sp_mode = shard_config.sequence_parallelism_mode
fp8_comm = shard_config.fp8_communication
if dist.get_world_size(sp_group) == 1:
return hidden_states
# Rescale grad (HybridParallelPlugin applies ZeRO grad averaging on the DP * SP group)
scale = None if is_share_sp_tp(sp_mode) else dist.get_world_size(sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, sp_dim, sp_group, grad_scale=scale, fp8_communication=fp8_comm
)
return hidden_states