ColossalAI/applications/ColossalEval/colossal_eval/models/huggingface.py
flybird11111 295dd2d9fe [zerobubble] rebase main (#6075)
* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [fp8] add fp8 comm for low level zero

* [test] add zero fp8 test case

* [Feature] llama shardformer fp8 support (#5938)

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

* [FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single shared copy is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>

* [fp8]support all2all fp8 (#5953)

* support all2all fp8

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] add fp8 linear (#5967)

* [fp8] add fp8 linear

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility

* fix (#5976)

* [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test ci]Feature/fp8 comm (#5981)

* fix

* fix

* fix

* [fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark

* [fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig

* [fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opt,sam,t5,vit,whisper

* fix

* fix

* fix

* [fp8] refactor fp8 linear with compile (#5993)

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test

* [fp8] support asynchronous FP8 communication (#5997)

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)

* [fp8] linear perf enhancement

* [fp8]update reduce-scatter test (#6002)

* fix

* fix

* fix

* fix

* [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

* [fp8] zero support fp8 linear. (#6006)

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction of DimSpec objects. The difference_dict is the same for every DimSpec object, so a single shared copy is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtral pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redundant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train_dpo.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update low_level_zero_plugin.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018)

* remove triton version

* remove torch 2.2

* remove torch 2.1

* debug

* remove 2.1 build tests

* require torch >=2.2

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [plugin] hotfix zero plugin (#6036)

* [plugin] hotfix zero plugin

* [plugin] hotfix zero plugin

* [Colossal-LLaMA] Refactor latest APIs (#6030)

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB info

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add fused norm (#6038)

* [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)

* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove deprecated install (#6042)

* remove deprecated install

* remove unused folder

* [fp8] optimize all-gather (#6043)

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring

* [fp8] fix linear hook (#6046)

* [fp8] disable all_to_all_fp8 in intranode  (#6045)

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6041)

* [release] update version

* [devops] update comp test

* [devops] update comp test debug

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

* [fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation

* [doc] update sp doc (#6055)

* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix the sp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the attn

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [fp8] fix missing fp8_comm flag in mixtral (#6057)

* fix

* fix

* fix

* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list

* [doc] FP8 training and communication document (#6050)

* Add FP8 training and communication document

* add fp8 docstring for plugins

* fix typo

* fix typo

* fix

* fix

* [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)

* [ColossalEval] support for vllm (#6056)

* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove dupilicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6062)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [fix] fix require_grad & deallocate call;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix zerobubble; support shardformer model type;

* [fix] fix test_pipeline_utils ci;

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
2024-10-08 15:58:00 +08:00

622 lines
26 KiB
Python

import copy
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer

from colossalai.logging import DistributedLogger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device

from .base import BaseModel

IGNORE_INDEX = -100


class HuggingFaceModel(BaseModel):
    """
    Model wrapper around HuggingFace AutoModel models.

    Args:
        path: The path to a HuggingFace model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer.
        peft_path: The name or path of the HuggingFace PEFT model.
        model_kwargs: Keyword arguments for the model.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        shard_config: Shard config for tensor parallel.

    """

    def __init__(
        self,
        path: str,
        model_max_length: int = 2048,
        tokenizer_path: Optional[str] = None,
        tokenizer_kwargs: Optional[dict] = None,
        peft_path: Optional[str] = None,
        model_kwargs: Optional[Dict] = None,
        prompt_template: Conversation = None,
        batch_size: int = 1,
        logger: DistributedLogger = None,
        shard_config: ShardConfig = None,
    ):
        super().__init__(
            path=path,
            model_max_length=model_max_length,
            prompt_template=prompt_template,
            batch_size=batch_size,
            logger=logger,
        )
        # Use fresh dicts: _load_tokenizer mutates tokenizer_kwargs, so a shared
        # mutable default would leak settings between instances, and
        # _load_model runs membership tests that fail on model_kwargs=None.
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        model_kwargs = {} if model_kwargs is None else model_kwargs
        self._load_tokenizer(path=path, tokenizer_path=tokenizer_path, tokenizer_kwargs=tokenizer_kwargs)

        self._load_model(path=path, model_kwargs=model_kwargs, peft_path=peft_path, shard_config=shard_config)

    def _get_choices_indices(self, language: str):
        """
        Get indices for each choice.

        Some tokenizers, such as Llama-2's, insert BOS unless add_special_tokens=False is specified.
        The indices for choices may differ depending on the context. For example, with the Llama-2 tokenizer, a Chinese context like "答案:{choice}" gives indices 29909, 29933, 29907 and 29928 for choices A, B, C and D, while an English context like "Answer: {choice}" gives 319, 350, 315 and 360.
        Run print(self.tokenizer("答案A")) and print(self.tokenizer("Answer: A")) to see the difference.

        """

        # A trick to get "all" token ids related to the given choices.
        self.indices_for_choices = [[] for _ in range(2)]
        for choice in self.choices:
            self.indices_for_choices[0].append(
                self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1]
            )
            self.indices_for_choices[1].append(
                self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1]
            )

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kwargs: dict):
        """
        Load tokenizer.

        Args:
            path: The path to the model. Usually it also serves as the path to the tokenizer.
            tokenizer_path: The path to the tokenizer.
            tokenizer_kwargs: Keyword arguments for the tokenizer.

        """

        if self.batch_size > 1:
            tokenizer_kwargs.update({"padding_side": "left"})
            tokenizer_kwargs.update({"truncation_side": "left"})

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path if tokenizer_path else path, **tokenizer_kwargs)

        if self.tokenizer.pad_token_id is None:
            self.logger.warning("pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.")
            if self.tokenizer.eos_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            elif hasattr(self.tokenizer, "eod_id"):
                # Qwen has an eod token "<|endoftext|>".
                self.tokenizer.pad_token_id = self.tokenizer.eod_id
            else:
                self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
                raise ValueError(
                    "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
                    "Please set pad_token_id manually."
                )

    def _load_model(
        self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
    ):
        """
        Load model.

        Args:
            path: The path to the model.
            model_kwargs: Keyword arguments for the model.
            peft_path: The path to the peft model.
            shard_config: Shard config for tensor parallel.

        """
        if "torch_dtype" in model_kwargs:
            # torch_dtype is expected to be a string expression (e.g. "torch.float16");
            # eval() turns it into the corresponding torch dtype object.
            model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
        else:
            model_kwargs.setdefault("torch_dtype", torch.float16)

        if "config" in model_kwargs:
            model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])

        if shard_config is not None:
            self.model = AutoModel.from_pretrained(path, **model_kwargs)
            shard_former = ShardFormer(shard_config)
            self.model, _ = shard_former.optimize(self.model)
            self.model.to(get_current_device())

            if peft_path is not None:
                raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
        else:
            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
            if peft_path is not None:
                self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
        self.model.eval()

    def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> Tuple[List]:
        """
        Calculate loss only on target tokens.
        Hugging Face generate() function can't return per sample loss.
        It will only return the mean of the loss in a batch.
        In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.

        Args:
            input_ids_list: A batch of input token ids.
            labels: A batch of labels.

        Returns:
            A list of loss.

        """
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
        ).to(get_current_device())
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
            get_current_device()
        )
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())

        outputs = self.model(input_ids, attention_mask=attention_mask)[0]

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

        lens = (labels[..., 1:] != IGNORE_INDEX).sum(-1).cpu().numpy()

        loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
        return loss_sum.tolist(), lens.tolist()

    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        """
        Truncate the input sequence to fit model_max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions).
        https://github.com/THUDM/LongBench/blob/main/pred.py#L16

        Args:
            inputs: A batch of input prompts.
            max_new_tokens: Max new tokens for model to generate.

        Returns:
            Truncated prompts.

        """

        truncated_inputs = copy.deepcopy(inputs)
        for i, input in enumerate(inputs):
            tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0]
            if len(tokenized_prompt) > self.model_max_length - max_new_tokens:
                half = (self.model_max_length - max_new_tokens) // 2
                prompt = self.tokenizer.decode(
                    tokenized_prompt[:half], skip_special_tokens=True
                ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
                truncated_inputs[i] = prompt

        return truncated_inputs

    def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[List[torch.LongTensor]]:
        """
        Get input_ids and labels for pretrain data.
        We only need batch_prompt because for a pretrain dataset, we don't need to predict new tokens.

        Args:
            batch_prompt: A batch of prompts.

        Returns:
            Input_ids, labels and byte counts for the given batch.

        """
        input_ids_list = []
        labels_list = []
        bytes_list = []

        for input in batch_prompt:
            # Pretrain data tends to be very long, sometimes much longer than model_max_length, so we tokenize only 1/ratio of the data first to accelerate tokenization.
            # Once the length of the result is greater than or equal to model_max_length, we stop iterating over ratios and use the result as input_ids and labels.
            # After all, the rest of the original string doesn't need to be tokenized in the first place.
            ratio = [16, 8, 4, 2, 1]
            tokenized = None
            for r in ratio:
                tokenized = self.tokenizer(
                    [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                )
                if tokenized.input_ids.size(1) >= self.model_max_length:
                    break

            input_ids = copy.deepcopy(tokenized["input_ids"])[0]
            target_ids = copy.deepcopy(input_ids)

            string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)

            bytes_list.append(len(string.encode("utf-8")))

            input_ids_list.append(input_ids)
            labels_list.append(target_ids)

        return input_ids_list, labels_list, bytes_list

    def _get_input_ids_and_labels(
        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
    ) -> Tuple[List[torch.LongTensor]]:
        """
        Get input_ids and labels for the given data.

        Args:
            batch_prompt: A batch of prompts.
            batch_target: A batch of targets.
            calculate_overall_loss: If True, concatenate each prompt with its first target and treat the result as pretrain data.

        Returns:
            Input_ids and labels for the given batch.

        """
        if calculate_overall_loss:
            batch = []
            # Concatenate prompt and target answers.
            # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
            for p, b in zip(batch_prompt, batch_target):
                batch.append(p + b[0])

            return self._get_input_ids_and_labels_pretrain(batch)

        input_ids_list = []
        labels_list = []

        for input, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # TODO: Improve the labeling process. Should annotate the border by adding special tokens.
                target_tokenized = self.tokenizer(
                    [target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
                )

                # Get prompt with length model_max_length - len(target_tokenized).
                # Reserve some space for target answer tokens using max_new_tokens.
                # This will generate the correct start_idx and end_idx.
                max_new_tokens = target_tokenized["input_ids"][0].size(0)
                prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens)[0]
                input_tokenized = self.tokenizer(
                    [prompt_with_correct_length],
                    truncation=True,
                    max_length=self.model_max_length - max_new_tokens,
                    return_tensors="pt",
                )

                target_tokenized = self.tokenizer(
                    [prompt_with_correct_length + target],
                    truncation=True,
                    max_length=self.model_max_length,
                    return_tensors="pt",
                )

                start_idx = input_tokenized["input_ids"][0].size(0)
                end_idx = target_tokenized["input_ids"][0].size(0)

                # Sometimes, if the target is only an option such as A, B, C or D, the length of input_tokenized equals the length of target_tokenized, so we need -1.
                # This is caused by the different behavior of tokenizers.
                # For example, the tokenizers for Baichuan and Llama cause this problem in a plain prompt setting.
                # The length of the tokenized sequences for prompt "Answer: " and "Answer: A" is the same.
                # Baichuan: [29394, 31143, 31106] [29394, 31143, 703]
                # Llama: [673, 29901, 29871] [673, 29901, 319]
                # The length for sequence "prompt" and "prompt + A" is equal.
                # For ChatGLM, the length of the tokenized sequences is different.
                # ChatGLM: [16583, 12] [16583, 12, 167]

                if start_idx == end_idx:
                    start_idx -= 1

                input_ids = copy.deepcopy(target_tokenized["input_ids"])[0]
                target_ids = copy.deepcopy(input_ids)

                mask = torch.zeros_like(target_ids, dtype=torch.bool)
                mask[start_idx:end_idx] = True

                target_ids[~mask] = IGNORE_INDEX

                input_ids_list.append(input_ids)
                labels_list.append(target_ids)

        return input_ids_list, labels_list, None

def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
"""
Infer the given data.
This function will call self.generate() to get model outputs and also self.model() to get logits.
Args:
data: The data for inference.
inference_kwargs: Arguments for inference.
debug: Whether to display generated prompt for debugging.
Returns:
Inference results.
"""
calculate_loss = inference_kwargs["calculate_loss"]
classes = inference_kwargs["all_classes"]
language = inference_kwargs["language"]
calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
max_new_tokens = inference_kwargs["max_new_tokens"]
few_shot_data = inference_kwargs.get("few_shot_data", None)
# For some classification questions, the options are full texts rather than single letters such as A, B, C and D.
# If any option is longer than one character, we do not calculate the loss over choices.
if classes is not None and any(len(c) > 1 for c in classes):
classes = None
self.choices = classes
self.indices_for_choices = None
if self.choices:
# Get indices for each choice
self._get_choices_indices(language)
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
bar = tqdm(
range(len(data_loader)),
desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
answers = []
for i, batch in enumerate(data_loader):
batch_prompt, batch_target = get_batch_prompt(
self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
)
if is_rank_0() and debug and i == 0:
self.logger.info(
f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} are:\n{inference_kwargs}"
)
self.logger.info("-" * 120)
self.logger.info("An example prompt and the same prompt with its target are:")
self.logger.info("-" * 120)
self.logger.info(batch_prompt[0])
self.logger.info("-" * 120)
self.logger.info(batch_prompt[0] + batch_target[0][0])
if not calculate_overall_loss:
batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
if calculate_loss:
batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
batch_prompt, batch_target, calculate_overall_loss
)
probs = []
if self.indices_for_choices:
scores = scores.to(torch.float32)
# If indices_for_choices is set (the question must be single-choice), each data sample has exactly one target answer.
# Anything else would violate the single-choice setting.
if calculate_loss:
labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
probs = scores.numpy().tolist()
probs = [
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
]
for j in range(len(batch)):
if not calculate_overall_loss:
if isinstance(batch[j]["output"], list):
batch[j]["output"].append(batch_decodes[j].strip())
else:
batch[j]["output"] = batch_decodes[j].strip()
if isinstance(scores, torch.Tensor):
batch[j]["logits_over_choices"] = probs[j]
if calculate_loss:
batch[j]["loss_over_choices"] = loss_over_choices[j]
if calculate_loss:
batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
# loss_sum is used specifically for pretraining datasets to calculate per-byte perplexity.
# For most other cases, loss (the per-sample loss) is sufficient.
batch[j]["loss_sum"] = batch_losses[j]
batch[j]["token_num"] = batch_target_token_nums[j]
if batch_bytes_nums:
batch[j]["byte_num"] = batch_bytes_nums[j]
answers.extend(batch)
bar.update()
return answers
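# A minimal sketch of what the loss over choices in inference() computes for a single sample:
# cross entropy over the per-choice logits with no reduction. It uses only the standard library
# instead of torch.nn.CrossEntropyLoss, and the function name is hypothetical.

```python
import math
from typing import List


def loss_over_choices(choice_logits: List[float], label: int) -> float:
    """Cross entropy of one sample: -log(softmax(choice_logits)[label])."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(choice_logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in choice_logits))
    return log_sum_exp - choice_logits[label]


# Same mapping as self.str_label_map for the choices A, B, C and D.
str_label_map = {choice: idx for idx, choice in enumerate(["A", "B", "C", "D"])}

# With equal logits the loss is log(4); a large logit on the target lowers it.
uniform = loss_over_choices([0.0, 0.0, 0.0, 0.0], str_label_map["B"])
confident = loss_over_choices([0.0, 10.0, 0.0, 0.0], str_label_map["B"])
```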
@torch.no_grad()
def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
"""Generate results given a list of inputs and get logits of the first new token over choices.
Args:
inputs: A list of strings.
max_new_tokens: Max new tokens for generation.
kwargs: Keyword arguments for generation.
Returns:
A list of generated strings and logits over choices.
Note:
Currently the function only returns the logits of the first new token.
These logits are meant for single-choice questions.
For multiple-choice questions, avoid using the loss over choices
by setting the choices to None in self.inference().
"""
truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
encoded_inputs = self.tokenizer(
truncated_inputs,
padding=True,
truncation=True,
return_tensors="pt",
return_token_type_ids=False,
max_length=self.model_max_length - max_new_tokens,
).to(get_current_device())
# Set output_scores=True to get prediction scores.
outputs = self.model.generate(
**encoded_inputs,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=True,
do_sample=False,
use_cache=True,
**kwargs,
)
# We only need to decode predicted tokens.
sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]
scores = []
if self.indices_for_choices:
# If the question is a single-choice question, we return the scores at specific indices for the first predicted token.
# The indices are the tokenization results of the options for the single-choice question.
# For example, if the options of the question are A, B, C and D, we only return the scores at the indices of A, B, C and D.
for option_indices in self.indices_for_choices:
scores.append(outputs.scores[0][:, option_indices].detach().cpu())
scores = torch.max(torch.stack(scores), dim=0)[0]
decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
return decoded_sequences, scores
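# A torch-free sketch of the score selection in generate() for one sample: gather the first-token
# logits at each list of option indices, then take the element-wise maximum across the lists.
# This assumes each inner list holds one token id per choice (how the lists are built is not shown
# in this section); the function name and token ids are hypothetical.

```python
from typing import List


def scores_over_choices(first_token_logits: List[float], indices_for_choices: List[List[int]]) -> List[float]:
    """Return one score per choice: the largest logit among that choice's candidate token ids."""
    # One row per index list, one column per choice.
    per_variant = [
        [first_token_logits[idx] for idx in option_indices] for option_indices in indices_for_choices
    ]
    # Element-wise max across rows, like torch.max(torch.stack(scores), dim=0)[0].
    return [max(column) for column in zip(*per_variant)]


vocab_logits = [0.1, 2.0, -1.0, 0.5, 3.0, 0.0]
# Hypothetical token ids: the first list maps A -> 1, B -> 2; the second maps A -> 3, B -> 4.
result = scores_over_choices(vocab_logits, [[1, 2], [3, 4]])
```

# Here result is [2.0, 3.0]: choice A takes its score from the first list and B from the second.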
@torch.no_grad()
def get_loss(
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
) -> List[List[float]]:
"""
Calculate loss only on target tokens.
Args:
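batch_prompt: A batch of prompts without target answers.
batch_target: A batch of target answers. One question can have multiple target answers.
calculate_overall_loss: Passed through to self._get_input_ids_and_labels; when False, the targets are truncated first.
Returns:
Per-sample lists of losses and target token counts, plus per-sample byte counts (None when no byte counts are available).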
"""
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
# We don't need to generate new tokens.
# The target answer's length is usually << model_max_length, but we still truncate it just in case.
# We don't call self._get_truncated_prompts on batch_prompt here because we first need the target answer's length to reserve space for its tokens.
if not calculate_overall_loss:
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
# Get the number of target answers for different questions
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(
batch_prompt, batch_target, calculate_overall_loss
)
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
# We will generate new batches.
losses = []
target_token_nums = []
batched_input_ids = [
input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
]
batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]
for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
losses.extend(losses_per_batch)
target_token_nums.extend(target_token_num_per_batch)
start_indice = 0
losses_per_sample = []
target_token_nums_per_sample = []
bytes_nums_per_sample = []
for length in batch_target_nums:
losses_per_sample.append(losses[start_indice : start_indice + length])
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
if bytes_list:
bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length])
start_indice += length
if bytes_list:
return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
return losses_per_sample, target_token_nums_per_sample, None
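# A small sketch of the two list operations in get_loss(): splitting the flattened
# (prompt, target) pairs into batches of at most batch_size, then regrouping the flat results
# by the number of targets each question has. The helper names chunk and regroup are hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def chunk(items: List[T], batch_size: int) -> List[List[T]]:
    """Split a flat list into consecutive batches of at most batch_size items."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


def regroup(flat: List[T], counts: List[int]) -> List[List[T]]:
    """Slice a flat list back into groups whose sizes are given by counts."""
    grouped, start = [], 0
    for length in counts:
        grouped.append(flat[start : start + length])
        start += length
    return grouped


# Three questions with 2, 1 and 2 target answers give five flattened losses.
flat_losses = [0.5, 0.7, 0.2, 0.9, 0.4]
batches = chunk(flat_losses, 2)
per_sample = regroup(flat_losses, [2, 1, 2])
```

# The batches cut across question boundaries, which is why the results are regrouped afterwards.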
class HuggingFaceCausalLM(HuggingFaceModel):
"""
Model wrapper around HuggingFace AutoModelForCausalLM models.
Args:
path: The path to a HuggingFace model.
model_max_length: The maximum sequence length of the model.
tokenizer_path: The path to the tokenizer.
tokenizer_kwargs: Keyword arguments for the tokenizer.
peft_path: The name or path of the HuggingFace PEFT model.
model_kwargs: Keyword arguments for the model.
prompt_template: The model's prompt template.
batch_size: Batch size for inference.
logger: Logger for the model.
shard_config: Shard config for tensor parallel.
"""
def _load_model(
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: Optional[ShardConfig] = None
):
"""
Load model.
Args:
path: The path to the model.
model_kwargs: Keyword arguments for the model.
peft_path: The path to the peft model.
shard_config: Shard config for tensor parallel.
"""
if "torch_dtype" in model_kwargs:
model_kwargs["torch_dtype"] = eval(model_kwargs["torch_dtype"])
else:
model_kwargs.setdefault("torch_dtype", torch.float16)
if "config" in model_kwargs:
model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs["config"])
if shard_config is not None:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
shard_former = ShardFormer(shard_config)
self.model, _ = shard_former.optimize(self.model)
self.model.to(get_current_device())
if peft_path is not None:
raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
else:
self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
if peft_path is not None:
self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
self.model.eval()
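# In _load_model(), the torch_dtype setting is presumably a string such as "torch.float16" that has
# to be turned into a Python object. The sketch below shows one way to resolve such a dotted name
# without evaluating arbitrary code. It is an alternative illustration, not what this file does,
# and resolve_dotted_name is a hypothetical helper. It is demonstrated on a standard-library name
# so that it runs without torch installed.

```python
import importlib
from typing import Any


def resolve_dotted_name(name: str) -> Any:
    """Resolve "module.attr" (e.g. "torch.float16") to the attribute object."""
    module_name, _, attr = name.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted name like 'module.attr', got {name!r}")
    # Import the module part and look up the attribute on it.
    module = importlib.import_module(module_name)
    return getattr(module, attr)


pi = resolve_dotted_name("math.pi")
```

# Under that assumption, resolve_dotted_name("torch.float16") would return the dtype object when torch is installed.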