ColossalAI/colossalai/zero/gemini/gemini_optimizer.py
linsj20 fcf776ff1b
[Feature] LoRA rebased to main branch (#5622)
* [Inference]ADD Bench Chatglm2 script (#4963)

* add bench chatglm

* fix bug and make utils

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix benchmark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bugs of ppinfer

* polish readme

* fix typo

* skip infer test

* updated c++17 compiler flags (#4983)

* [Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support generating with different lengths

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention  (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* fix ColossalEval (#4992)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc]Update doc for colossal-inference (#4989)

* update doc

* Update README.md

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c139.

* [hotfix] fix the bug of repeatedly storing param group (#4951)

* [doc] add supported feature diagram for hybrid parallel plugin (#4996)

* [Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo

* [release] update version (#4995)

* [release] update version

* [hotfix] fix ci

* [moe] merge moe into main (#4978)

* update moe module
* support openmoe

* [hotfix] fix grad accumulation plus clipping for gemini (#5002)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>

* [format] applied code formatting on changed files in pull request 4926 (#5007)

Co-authored-by: github-actions <github-actions@github.com>

* [Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)

* fix bug

* fix

* fix multiquery

* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [misc] add code owners (#5024)

* [moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine

* Support mtbench (#5025)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [moe]: fix ep/tp tests, add hierarchical all2all (#4982)

* fix: add warning for EP different behavior

* fix: use shard_data in ep & tp model

* to: add used_capacity

* fix: fix router test

* feat: add create_ep_node_group

* feat: add create_ep_hierarchical_group fn

* feat: add HierarchicalAllToAll

* test: add hierarchical all2all test

* fix: fix test errors

* fix: simplify create_ep_hierarchical_group

* fix: add hierarchical_alltoall arg

* fix: fix environ typo

* revert: revert process mesh order

* to: add todo mark

* fix: skip hierarchical_comm if torch < 1.13.1
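  A guard of the sort sketched below (purely illustrative, not the actual ColossalAI code) is one way to express that torch-version check:

```python
# Hypothetical sketch: gate hierarchical all-to-all on the installed torch version.
import torch
from packaging.version import parse


def hierarchical_comm_supported(min_version: str = "1.13.1") -> bool:
    # Hierarchical all-to-all is only enabled when torch is new enough.
    return parse(torch.__version__) >= parse(min_version)


if not hierarchical_comm_supported():
    print("torch < 1.13.1: falling back to flat all-to-all")
```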

* [shardformer] Fix serialization error with Tensor Parallel state saving (#5018)

* Fix serialization error with Tensor Parallel state saving

* Refactor state_dict CPU transfer using tree_map

* [gemini] gemini support tensor parallelism. (#4942)

* [colossalai]fix typo

* [inference] Add smoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exceptions for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license

* Update flash_attention_patch.py

To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
https://github.com/huggingface/transformers/pull/25598
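
As a minimal, hypothetical sketch (names invented here, not the actual flash_attention_patch.py code), staying compatible simply means the patched forward must also accept the new `padding_mask` keyword:

```python
# Hypothetical sketch: accept the `padding_mask` kwarg that newer Transformers
# versions pass to attention forward, so older and newer callers both work.
from typing import Optional

import torch


class TinyAttention(torch.nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,  # new kwarg (transformers PR 25598)
        **kwargs,
    ) -> torch.Tensor:
        # `padding_mask` is accepted here only for signature compatibility;
        # a real patch would pass it on to the attention kernel as needed.
        return self.proj(hidden_states)


x = torch.randn(2, 4, 8)
attn = TinyAttention()
out_old = attn(x, attention_mask=None)                                   # older call style
out_new = attn(x, attention_mask=None, padding_mask=torch.ones(2, 4))    # newer call style
```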

* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test

* [format] applied code formatting on changed files in pull request 4908 (#4918)

Co-authored-by: github-actions <github-actions@github.com>

* [gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case

* [hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit

* [test] add no master test for low level zero plugin (#4934)

* [format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>

* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)

* [Refactor] Integrated some lightllm kernels into token-attention  (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test

* [inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>

* [Inference]ADD Bench Chatglm2 script (#4963)

* add bench chatglm

* fix bug and make utils

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix benchmark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bugs of ppinfer

* polish readme

* fix typo

* skip infer test

* updated c++17 compiler flags (#4983)

* [Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support generating with different lengths

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention  (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* fix ColossalEval (#4992)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc]Update doc for colossal-inference (#4989)

* update doc

* Update README.md

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c139.

* [hotfix] fix the bug of repeatedly storing param group (#4951)

* [doc] add supported feature diagram for hybrid parallel plugin (#4996)

* [Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo

* [release] update version (#4995)

* [release] update version

* [hotfix] fix ci

* [gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

* fix

fix

fix

* update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

* support fused layernorm

support fused layernorm

support fused layernorm

* update fusedlayernorm

update fusedlayernorm

update fusedlayernorm

* add sequence parallel to gemini

add sequence parallel to gemini

* fix

* fix comments

fix comments

fix comments

* fix

* fix t5

* clear cache

* fix

* activate ci

* activate ci

* fix

* fix

* fix

* fix

* revert

* modify tp gather method

modify tp gather method

modify tp gather method

modify tp gather method

* fix test

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>

* [hotfix] Support extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig

* fix wrong EOS token in ColossalChat

* [Kernels]Update triton kernels into 2.1.0 (#5046)

* update flash-context-attention

* adding kernels

* fix

* reset

* add build script

* add building process

* add llama2 example

* add colossal-llama2 test

* clean

* fall back test setting

* fix test file

* clean

* clean

* clean

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when `strict=False`, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)

* Use p2p

* Cannot do bidirectional send in p2p

* Refactor tensor creation and serialization in P2P communication

* Fix llama forward args in flash attention

* Add flop estimate from megatron

* Support loading weight not in weight_map when strict=False in hybrid_parallel

* Use send_forward_recv_backward, etc in 1f1b

* Use dataclass for metadata
Remove torch.cuda.synchronize() as suggested
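
As an illustration only (hypothetical names, not the actual `p2p.py` code), the per-tensor metadata exchanged before the payload could be grouped into a dataclass like this:

```python
# Hypothetical sketch: group P2P tensor metadata in a dataclass instead of
# passing shape/dtype around as loose tuples.
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class TensorMetadata:
    shape: Tuple[int, ...]
    dtype: torch.dtype
    requires_grad: bool = False


def make_recv_buffer(meta: TensorMetadata, device: str = "cpu") -> torch.Tensor:
    # The receiver allocates a buffer matching the sender's metadata
    # before the tensor payload itself arrives.
    return torch.empty(meta.shape, dtype=meta.dtype, device=device, requires_grad=meta.requires_grad)


meta = TensorMetadata(shape=(2, 3), dtype=torch.float16)
buf = make_recv_buffer(meta)
```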

* Add comment about the torch.cuda.synchronize for potential error

* Typo

* Update hybrid_parallel_checkpoint_io.py

* Update p2p.py

* Update one_f_one_b.py

* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>

* [gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix

* [shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py

* [hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank

* [example] fix llama example's loss error when using gemini plugin (#5060)

fix llama example

* [inference] Refactor inference architecture (#5057)

* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>

* [Kernels]added flash-decoding of triton (#5063)

* added flash-decoding of triton based on lightllm kernel

* add req

* clean

* clean

* delete build.sh

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [misc] remove outdated submodule (#5070)

* [npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support

* [hotfix/hybridengine] fix bug when tp*pp size = 1 (#5069)

* [inference] update examples and engine (#5073)

* update examples and engine

* fix choices

* update example

* [format] applied code formatting on changed files in pull request 5067 (#5072)

Co-authored-by: github-actions <github-actions@github.com>

* [hotfix/hybridengine] Fix init model with random parameters in benchmark (#5074)

* fix init model with random parameters

* fix example

* [inference] refactor examples and fix schedule (#5077)

* [setup] refactor infer setup

* [hotfix] fix inference behavior on 1 1 gpu

* [example] refactor inference examples

* fix thrust-transform-reduce error (#5078)

* [nfc] fix typo in docs/ (#4972)

* [nfc] fix typo and author name (#5089)

* [gemini]fix gemini optimizer: saving Shardformer in Gemini got 'list assignment index out of range' (#5085)

* [Hotfix] Fix model policy matching strategy in ShardFormer (#5064)

* hotfix/Fix get model policy strategy in ShardFormer

* fix bug in auto policy

* [shardformer]fix flash attention, when mask is causal, just don't unpad it (#5084)

* fix flash attn

* fix

fix

* [npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast

* [Feature] Add document retrieval QA (#5020)

* add langchain

* add langchain

* Add files via upload

* add langchain

* fix style

* fix style: remove extra space

* add pytest; modified retriever

* add pytest; modified retriever

* add tests to build_on_pr.yml

* fix build_on_pr.yml

* fix build on pr; fix environ vars

* separate unit tests for colossalqa from build_on_pr

* fix container setting; fix environ vars

* commented dev code

* add incremental update

* remove stale code

* fix style

* change to sha3 224

* fix retriever; fix style; add unit test for document loader

* fix ci workflow config

* fix ci workflow config

* add set cuda visible device script in ci

* fix doc string

* fix style; update readme; refactored

* add force log info

* change build on pr, ignore colossalqa

* fix docstring, capitalize all initial letters

* fix indexing; fix text-splitter

* remove debug code, update reference

* reset previous commit

* update LICENSE, update README, add key-value mode, fix bugs

* add files back

* revert force push

* remove junk file

* add test files

* fix retriever bug, add intent classification

* change conversation chain design

* rewrite prompt and conversation chain

* add ui v1

* ui v1

* fix avatar

* add header

* Refactor the RAG Code and support Pangu

* Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo.

* resolved conversation. tested scripts under examples. web demo still buggy

* fix ci tests

* Some modifications to add ChatGPT api

* modify llm.py and remove unnecessary files

* Delete applications/ColossalQA/examples/ui/test_frontend_input.json

* Remove OpenAI api key

* add colossalqa

* move files

* move files

* move files

* move files

* fix style

* Add Readme and fix some bugs.

* Add something to readme and modify some code

* modify a directory name for clarity

* remove redundant directory

* Correct a typo in llm.py

* fix AI prefix

* fix test_memory.py

* fix conversation

* fix some errors and typos

* Fix a missing import in RAG_ChatBot.py

* add colossalcloud LLM wrapper, correct issues in code review

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>

* remove duplicate import (#5100)

* fix typo change lazy_iniy to lazy_init (#5099)

* [nfc] fix typo change directoty to directory (#5111)

* [FEATURE] Add Safety Eval Datasets to ColossalEval (#5095)

* add safetybench and cvalues(responsibility) eval dataset

* Modify code according to review suggestions

---------

Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>

* [hotfix] fixed memory usage of shardformer module replacement (#5122)

* [shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)

* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------

Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>

* [doc] add moe news (#5128)

* [doc] add moe news

* [doc] add moe news

* [doc] add moe news

* [doc] updated paper citation (#5131)

* fix typo change JOSNL TO JSONL etc. (#5116)

* [format] applied code formatting on changed files in pull request 5088 (#5127)

Co-authored-by: github-actions <github-actions@github.com>

* [format] applied code formatting on changed files in pull request 5124 (#5125)

Co-authored-by: github-actions <github-actions@github.com>

* [format] applied code formatting on changed files in pull request 5115 (#5118)

Co-authored-by: github-actions <github-actions@github.com>

* [accelerator] init the accelerator module (#5129)

* [accelerator] init the accelerator module

* polish code

* polish code

* polish code

* polish code

* [npu] support triangle attention for llama (#5130)

* update fused attn

* update spda

* tri attn

* update triangle

* import

* fix

* fix

* [plugin]fix 3d checkpoint load when booster boost without optimizer. (#5135)

* fix 3d checkpoint load when booster boost without optimizer

fix 3d checkpoint load when booster boost without optimizer

* test ci

* revert ci

* fix

fix

* [ColossalQA] refactor server and webui & add new feature (#5138)

* refactor server and webui & add new feature

* add requirements

* modify readme and ui

* [doc] fix colossalqa document (#5146)

* fix doc

* modify doc

* fix (#5158)

fix

* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [gemini]  hotfix NaN loss while using Gemini + tensor_parallel (#5150)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* [colossalqa] fix pangu api (#5170)

* fix pangu api

* add comment

* [ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)

* Support GSM, Data Leakage Evaluation and Tensor Parallel

* remove redundant code and update inference.py in examples/gpt_evaluation

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [shardformer] llama support DistCrossEntropy (#5176)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* fix ci

* fix ci

---------

Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* Fix ColossalEval (#5186)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc] update pytorch version in documents. (#5177)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* update pytorch version in documents

* polish readme in application/chat (#5194)

* [pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as p2p op list need to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin

* Improve logic for selecting metrics (#5196)

Co-authored-by: Xu <yuanchen.xu00@gmail.com>

* [doc] Update required third-party library list for testing and torch compatibility checking (#5207)

* doc/update requirements-test.txt

* update torch-cuda compatibility check

* support linear accumulation fusion (#5199)

support linear accumulation fusion

support linear accumulation fusion

fix

* [pipeline]: support arbitrary batch size in forward_only mode (#5201)

* fix: remove drop last in val & test dataloader

* feat: add run_forward_only, support arbitrary bs

* chore: modify ci script

* [pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix deadlock comm in interleaved pp

* test: modify p2p test

* [devops] update torch version in ci (#5217)

* fix-test (#5210)

fix-test

fix-test

* fix flash attn (#5209)

* [nfc] fix typo colossalai/shardformer/ (#5133)

* [Colossal-LLaMA-2] Release Colossal-LLaMA-2-13b-base model (#5224)

* update readme

* update readme

* update link

* update

* update readme

* update

* update

* update

* update title

* update example

* update example

* fix content

* add conclusion

* add license

* update

* update

* update version

* fix minor

* [doc] Update README.md of Colossal-LLAMA2 (#5233)

* Update README.md

* Update README.md

* [doc] Make leaderboard format more uniform and good-looking (#5231)

* Make leaderboard format more unified and good-looking

* Update README.md

* Update README.md

* [doc] add Colossal-LLaMA-2-13B (#5234)

* [doc] add Colossal-LLaMA-2-13B

* [doc] add Colossal-LLaMA-2-13B

* [doc] add Colossal-LLaMA-2-13B

* [format] applied code formatting on changed files in pull request 5234 (#5235)

Co-authored-by: github-actions <github-actions@github.com>

* [doc] SwiftInfer release (#5236)

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* do not use warning

* update

* update

* [pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>

* [hotfix] removed unused flag (#5242)

* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish

* fix typo in  applications/ColossalEval/README.md (#5250)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] fix doc typo (#5256)

* [doc] fix annotation display

* [doc] fix llama2 doc

* [hotfix]: add pp sanity check and fix mbs arg (#5268)

* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check

* [workflow] fixed incomplete bash command (#5272)

* [workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish

* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* fix ci

fix

* fix test

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* fix

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix

* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

* fix auto loading gpt2 tokenizer (#5279)

* [doc] add llama2-13B display (#5285)

* Update README.md

* fix 13b typo

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* fix llama pretrain (#5287)

* [hotfix] fix 3d plugin test (#5292)

* fix bug for mixture (#5299)

* [NFC] polish applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py code style (#5228)

* fix some typo (#5307)

* [feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* [workflow] updated CI image (#5318)

* [accelerator] fixed npu api

* [tests] fix t5 test. (#5322)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* fix t5 test

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] added docs for extensions (#5324)

* [doc] added docs for extensions

* polish

* polish

* fix typo under extensions/ (#5330)

* fix typo change dosen't to doesn't (#5308)

* [extension] fixed exception catch (#5342)

* [Chat] fix sft loss nan (#5345)

* fix script

* fix script

* fix chat nan

* fix chat nan

* [checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code

* [fix] remove unnecessary dp_size assert  (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning

* [gemini] fix param op hook when output is tuple (#5355)

* [gemini] fix param op hook when output is tuple

* [gemini] fix param op hook

* [llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script

* [llama] update training script (#5360)

* [llama] update training script

* [doc] polish docstr

* [llama] add flash attn patch for npu (#5362)

* [llama] fix neftune & pbar with start_step (#5364)

* [eval] update llama npu eval (#5366)

* [llama] polish training script and fix optim ckpt (#5368)

* [lr-scheduler] fix load state dict and add test (#5369)

* [llama] fix memory issue (#5371)

* [llama] fix memory issue

* [llama] add comment

* [moe] init mixtral impl

* [moe] update capacity computing (#5253)

* [moe] top2 allow uneven input

* [moe] update capacity computing

* [moe] remove debug info

* [moe] update capacity computing

* [moe] update capacity computing

* [moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [moe] support save sharded mixtral

* [moe] support load sharded mixtral

* [moe] support save sharded optim

* [moe] integrate moe manager into plugin

* [moe] fix optimizer load

* [moe] fix mixtral layer

* [moe] fix mixtral checkpoint io (#5314)

* [moe] fix mixtral forward default value (#5329)

* [moe] fix mixtral optim checkpoint (#5344)

* [moe] fix tests

* [release] update version (#5380)

* [llama] fix training and inference scripts (#5384)

* [llama] refactor inference example to fit sft

* [llama] fix training script to fit gemini

* [llama] fix inference script

* [doc] Fix typo (#5361)

* [doc] updated installation command (#5389)

* [hotfix] fix variable type for top_p (#5313)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [hotfix] Fix wrong import in meta_registry (#5392)

* [extension] hotfix jit extension setup (#5402)

* [example] reuse flash attn patch (#5400)

* [fsdp] impl save/load shard model/optimizer (#5357)

* [setup] fixed nightly release (#5388)

* [shardformer]gather llama logits (#5398)

* gather llama logits

* fix

* update requirements (#5407)

* [workflow] added pypi channel (#5412)

* [doc] fix blog link

* [doc] fix blog link

* fix sft single turn inference example (#5416)

* [example]add gpt2 benchmark example script. (#5295)

* benchmark gpt2

* fix

fix

fix

fix

* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish

* fix typo in  applications/ColossalEval/README.md (#5250)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] fix doc typo (#5256)

* [doc] fix annotation display

* [doc] fix llama2 doc

* [hotfix]: add pp sanity check and fix mbs arg (#5268)

* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check

* [workflow] fixed incomplete bash command (#5272)

* [workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish

* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* fix ci

fix

* fix test

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* fix

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix

* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

* fix auto loading gpt2 tokenizer (#5279)

* [doc] add llama2-13B display (#5285)

* Update README.md

* fix 13b typo

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* fix llama pretrain (#5287)

* fix

* fix

* fix

fix

* fix

fix

fix

* fix

fix

* benchmark gpt2

* fix

fix

fix

fix

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* fix

fix

* fix

fix

fix

* fix

* fix

fix

fix

fix

fix

* fix

* Update shardformer.py

---------

Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>

* [doc] sora release (#5425)

* [doc] sora release

* [doc] sora release

* [doc] sora release

* [doc] sora release

* [devops] fix extention building (#5427)

* [hotfix] fix sd vit import error (#5420)

* fix import error

* Update dpt_depth.py

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [hotfix] fix typo of openmoe model source (#5403)

* [doc] update some translations with README-zh-Hans.md (#5382)

* [hotfix] fix typo change _descrption to _description (#5331)

* [hotfix] fix typo change enabel to enable under colossalai/shardformer/ (#5317)

* [eval-hotfix] set few_shot_data to None when few shot is disabled (#5422)

* [hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [doc] Fix typo s/infered/inferred/ (#5288)

Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>

* [hotfix] fix stable diffusion inference bug. (#5289)

* Update train_ddp.yaml

delete  "strategy" to fix DDP config loading bug in "main.py"

* Update train_ddp.yaml

fix the config file loading bug when running inference with scripts/txt2img.py.

* Update README.md

add pretrained model test code.

* [colossal-llama2] add stream chat example for chat version model (#5428)

* add stream chat for chat version

* remove os.system clear

* modify function name

* [release] update version (#5411)

* fix tensor data update for gemini loss calculation (#5442)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loosen tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [doc] fix ColossalMoE readme (#5599)

* fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements

* [shardformer] refactor embedding resize (#5603)

* [branch rebase] rebase main to Feature/resize_embedding (#5554)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [CI] run pre-commit (#5577)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* run pre-commit

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [rebase] rebase main to resize-embedding (#5581)

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]enable padding vocabulary size. (#5489)

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* padding vocab

* padding vocab

* fix

* fix

* fix

* test ci

* fix

fix

fix

fix

* fix

fix

* fix

* fix

* Update hybrid_parallel_plugin.py

fix

fix

fix

* fix

fix

* fix

fix

* fix

* resolve super init

resolve super init

resolve super init

resolve super init

* resolve comments

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vocab checkpointio

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

fix

fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* padding vocab

* fix

* fix

fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* cherry-pick

* revert moe modify

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

fix

fix

fix

fix

fix

fix

fix

* resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

* ptensor

ptensor

resolve comments

fix

fix

fix

fix

fix

resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix rebase

* fix rebase

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606)

* fix no pad token bug

* fixed some auto parallel codegen bug, but might not run on torch 2.1

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt

* [lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add license and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements

* [LowLevelZero] low level zero support lora (#5153)

* low level zero support lora

low level zero support lora

* add checkpoint test

* add checkpoint test

* fix

* fix

* fix

* fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* test ci

* git # This is a combination of 3 commits.

Update low_level_zero_plugin.py

Update low_level_zero_plugin.py

fix

fix

fix

* fix naming

fix naming

fix naming

fix

* [feature] qlora support

* qlora follow commit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* migrate quantization folder to colossalai/

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* gptj sp fix

* remove redundancies from pre-commit

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Jun Gao <imgaojun@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxamil.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Elsa Granger <zeyugao@outlook.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: JIMMY ZHAO <knightyzhao@gmail.com>
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
Co-authored-by: 李文军 <40464906+liwenjuna@users.noreply.github.com>
Co-authored-by: yixiaoer <miyaku@yixiaoer.sg>
Co-authored-by: CZYCW <czyczf@163.com>
Co-authored-by: Stephan Kölker <stephankoe@users.noreply.github.com>
Co-authored-by: QinLuo <eric.x.sun@gmail.com>
Co-authored-by: MickeyCHAN <76671016+danyow-cheung@users.noreply.github.com>
Co-authored-by: Luo Yihang <luo_yihang@outlook.com>
Co-authored-by: Dongruixuan Li <dongruixuan@hotmail.com>
Co-authored-by: hugo-syn <61210734+hugo-syn@users.noreply.github.com>
Co-authored-by: Youngon <Youngon_wyl@163.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-23 17:57:44 +08:00

# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import math
import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
get_global_shape,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import disposable, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(
self,
module: GeminiDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.module = module
def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0
def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
class GeminiOptimizer(OptimizerWrapper):
"""A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
Note:
You must use ``GeminiDDP`` with ``GeminiOptimizer``.
Note:
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
if you set ``gpu_margin_mem_ratio > 0``.
Args:
optim (Optimizer): An Optimizer instance.
module (GeminiDDP): A ``GeminiDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): Growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): Backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (int, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (int, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (float, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
is supported in GeminiOptimizer. Defaults to 2.0.
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
"""
def __init__(
self,
optim: Optimizer,
module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
tp_group: ProcessGroup = None,
params_info=None,
verbose: bool = False,
**defaults: Any,
):
super().__init__(optim)
assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, (
"You should use an optimizer in the available list:\n" f"{_AVAIL_OPTIM_LIST}"
)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk16: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
self.tp_group = tp_group
self.params_info = params_info
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
self.verbose = verbose
self.param_groups_backup = list()
# Mapping from integer id to real/fake param tensor, used for checkpointing.
self.id_to_real_params: Dict[int, Parameter] = dict()
self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag:
assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
ddp_param_list = []
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(
f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!"
)
else:
ddp_param_list.append(param)
for p in ddp_param_list:
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
self.chunk16_set.add(chunk_16)
self.__init__optimizer()
if module.mixed_precision is torch.float16:
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(
module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif module.mixed_precision is torch.bfloat16:
self.mix_precision_mixin = BF16MixedPrecisionMixin()
else:
raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, "gpu_margin_mem_ratio must be >= 0.0 and <= 1.0"
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
# and it must set `num_fp32_shards_per_param` correctly
self._should_move_fp32_params_h2d: bool = (
self.gemini_manager.is_cuda_margin_mem_avail
and self.gpu_margin_mem_ratio > 0.0
and getattr(optim, "num_fp32_shards_per_param", 0) >= 2
)
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
self._logger.warning('gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
self._register_states = disposable(self._register_states_)
def _set_grad_ptr(self):
for group in self.param_groups:
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
begin, end = self.param_to_range[fake_param]
grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
fake_param.data = grad_chunk16.payload[begin:end]
fake_param.grad = fake_param.data
to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16
fake_param.data = to_update_chunk.payload[begin:end]
def _update_fp16_params(self):
none_tensor = torch.empty([0])
for group in self.param_groups:
for fake_param in group["params"]:
assert fake_param.grad is None
fake_param.data = none_tensor.to(fake_param.device)
for chunk16 in self.chunk16_set:
chunk16.optim_update()
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
assert grad_chunk.l2_norm is not None
if grad_chunk.is_gathered:
norm_sqr += grad_chunk.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if grad_chunk.torch_pg not in group_to_norm:
group_to_norm[grad_chunk.torch_pg] = 0.0
group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm
grad_chunk.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
norm_sqr += comm_buffer.item()
global_norm = math.sqrt(norm_sqr)
return global_norm
def _get_combined_scale(self):
div_scale = self.mix_precision_mixin.get_grad_div_scale()
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
return -1 if div_scale == 1.0 else div_scale
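# Worked example with illustrative numbers: with an fp16 loss scale of 1024 and
# max_norm=1.0, a scaled-gradient norm of 5120 gives
# clip = (5120 / 1024 + 1e-6) / 1.0 ≈ 5.0 > 1, so div_scale becomes 5.0 * 1024 = 5120.
# Dividing gradients by this combined scale unscales them and clips the global norm
# to max_norm in one pass; the sentinel -1 means no division is needed at all.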
def zero_grad(self, *args, **kwargs):
self.mix_precision_mixin.pre_zero_grad()
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
if self.module.master_weights:
self._maybe_move_fp32_params()
self._set_grad_ptr()
if self.mix_precision_mixin.should_skip_step():
if self.verbose:
self._logger.info(f"Found overflow. Skip step")
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
if self.module.reuse_fp16_chunk:
self._update_fp16_params()
return
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
# This function is called by every pipeline stage except the last one
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self):
if self._should_move_fp32_params_h2d:
self._should_move_fp32_params_h2d = False
available_cuda_margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
fp32_params_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
fp32_params_used_cuda_margin_mem = 0
for group in self.param_groups:
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
# stores grad now
self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
for group in self.param_groups:
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(get_accelerator().get_current_device())
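# Illustrative sizing example (numbers are assumptions): with 8 GiB of CUDA margin
# memory, gpu_margin_mem_ratio=0.5 and an inner optimizer reporting
# num_fp32_shards_per_param=2, up to 8 GiB * 0.5 / 2 = 2 GiB of fp32 chunks (together
# with their paired grad chunks and optimizer states) are moved onto the accelerator.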
def _register_states_(self):
for group in self.optim.param_groups:
for p in group["params"]:
state = self.optim.state[p]
for val in state.values():
if isinstance(val, torch.Tensor):
self.chunk_manager.add_extern_static_tensor(val)
def __init__optimizer(self):
def get_range_pair(local_chunk: Chunk, local_param: Parameter):
param_info = local_chunk.tensors_info[local_param]
if local_chunk.keep_gathered:
return param_info.offset, param_info.end
begin = max(0, param_info.offset - local_chunk.shard_begin)
end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin)
return begin, end
param_id = -1
for group in self.optim.param_groups:
fake_params_list = list()
group_backup = {k: v for k, v in group.items() if k != "params"}
group_ids = []
for param in group["params"]:
# Record the mapping of id to current param.
param_id += 1
self.id_to_real_params[param_id] = param
group_ids.append(param_id)
# If current param is controlled by current process, add it to fake_param.
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:
continue
grad_device = self.module.grads_device[param]
fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_chunk16[fake_param] = chunk16
self.param_to_range[fake_param] = range_pair
self.id_to_fake_params[param_id] = fake_param
fake_params_list.append(fake_param)
# Update self.optim.param_groups as well as backup group.
group["params"] = fake_params_list
group_backup["params"] = group_ids
self.param_groups_backup.append(group_backup)
def get_offsets(self, param_id: int) -> tuple:
"""
Args:
param_id(int): The id of parameter.
Returns:
chunk_offset(int): Offset of parameter inside the chunk.
shard_offset(int): Offset of its optimizer state shard
relative to the whole optimizer state.
shard_size(int): Length of parameter shard owned by current process.
"""
if param_id not in self.id_to_fake_params:
return -1, -1, -1
fake_param = self.id_to_fake_params[param_id]
chunk = self.param_to_chunk16[fake_param]
param = self.id_to_real_params[param_id]
param_info = chunk.tensors_info[param]
begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
chunk_offset = begin_in_chunk
if chunk.keep_gathered:
shard_offset = 0
else:
shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
shard_size = end_in_chunk - begin_in_chunk
assert chunk_offset >= 0 and shard_offset >= 0
return chunk_offset, shard_offset, shard_size
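# Worked example with hypothetical numbers: if a parameter occupies positions
# [100, 300) of its chunk and this rank's chunk shard covers positions [150, 250),
# then param_to_range is (0, 100) and get_offsets returns chunk_offset=0,
# shard_offset=150-100=50, shard_size=100, i.e. this rank owns elements [50, 150)
# of the parameter's flattened optimizer state.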
def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
"""
Args:
param_id (int): id of the parameter whose state is to be gathered at master rank.
only_rank_0(bool): if True, states will be collected only on master rank, otherwise collected on every rank.
Returns:
collected_states(dict): the gathered optimizer state of parameter with given id
if this method is called by master rank, otherwise an empty dict.
This method can work only when called by all processes simultaneously.
"""
# Get param & chunk & process group.
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
zero_group = chunk.torch_pg
rank = dist.get_rank(zero_group)
master_rank = 0
collected_states = {}
# Fetch names of states through all_gather.
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
state_names = None
for names in gathered_state_names:
if names is not None:
# Assume different devices share the same set of state names if they have any.
state_names = copy.deepcopy(names)
break
# Directly return if this parameter doesn't have optimizer states.
# e.g. the parameter is frozen or its layer is dropped
if state_names is None:
return collected_states
# Boolean variable is_collector indicates whether the current rank
# needs to gather the whole optimizer states.
# Only master rank is collector when only_rank_0 is True.
# Every rank is collector when only_rank_0 is False.
is_collector = (rank == master_rank) or (not only_rank_0)
# get tensor parallelism information
is_dtensor = is_distributed_tensor(param)
is_customized_distributed = is_customized_distributed_tensor(param)
shard_spec = get_sharding_spec(param) if is_dtensor else None
device_mesh = get_device_mesh(param) if is_dtensor else None
global_shape = self.params_info["id2shape"][param_id]
# If the chunk is kept gathered,
# its parameters are treated the same as parameters in strict DDP during training.
# So states can be directly fetched from current device.
if chunk.keep_gathered:
assert param_id in self.id_to_fake_params
if is_collector:
states = self.optim.state[fake_param]
for state_name in state_names:
if state_name == "step":
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(
states["step"], dtype=torch.float32, requires_grad=False
).cpu()
else:
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor:
global_shape = get_global_shape(param)
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(
state_tensor,
device_mesh=device_mesh,
sharding_spec=shard_spec,
global_shape=global_shape,
)
elif is_customized_distributed:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
state_tensor = state_tensor.reshape(global_shape)
if is_padded_tensor(param):
state_tensor = init_as_padded_tensor(
state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)
collected_states[state_name] = state_tensor
return collected_states
# Check whether the param with given id is managed by current process.
own_param = param_id in self.id_to_fake_params
# Collector gets prepared for state collecting.
if is_collector:
for state_name in state_names:
if state_name == "step":
# To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu()
else:
collected_states[state_name] = torch.zeros(
param.numel(), dtype=torch.float32, requires_grad=False
).cpu()
# Materials for gathering, including compacted state tensors, and the offset of shard inside each state.
compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
_, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering.
gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
if is_collector:
for state_shard in gathered_state_shards:
compacted_states = state_shard[0]
shard_offset = state_shard[1]
shard_size = state_shard[2]
if compacted_states is None:
continue
self.load_from_compacted_states(
compacted_states, collected_states, state_names, shard_offset, shard_size
)
# Reshape tensors
if is_collector:
for state_name, state_tensor in collected_states.items():
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor:
global_shape = get_global_shape(param)
state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
)
elif is_customized_distributed:
state_tensor = state_tensor.to(param.device)
init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
if is_padded_tensor(param):
state_tensor = init_as_padded_tensor(
state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)
return collected_states
def pack_optimizer_states_to_tensor(
self,
param_id: int,
state_names: list,
device: torch.device = get_accelerator().get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
With param id given, pack its optimizer states into a compact tensor and return.
"""
if param_id not in self.id_to_fake_params:
return None
fake_param = self.id_to_fake_params[param_id]
param_range = self.param_to_range[fake_param]
states = self.optim.state[fake_param]
shard_size = param_range[1] - param_range[0]
compacted_size = 0
for name in state_names:
if name == "step":
compacted_size += 1
else:
compacted_size += shard_size
compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False)
next_state_offset = 0
for state_name, state_tensor in states.items():
# State 'step' needs special operation.
if state_name == "step":
if isinstance(state_tensor, torch.Tensor):
compacted_states[next_state_offset] = state_tensor[0].item()
else:
assert isinstance(state_tensor, int)
compacted_states[next_state_offset] = state_tensor
next_state_offset += 1
else:
assert state_tensor.numel() == shard_size
compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor)
next_state_offset += shard_size
return compacted_states
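# Packed layout sketch, assuming Adam-style states ("step", "exp_avg", "exp_avg_sq")
# and a local shard of N elements: the returned tensor is laid out as
#   [step, exp_avg[0:N], exp_avg_sq[0:N]]
# i.e. one scalar slot for "step" followed by each state's shard, in the iteration
# order of the optimizer's state dict for this fake parameter.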
def load_from_compacted_states(
self,
compacted_states: torch.Tensor,
collected_states: dict,
state_names: list,
shard_start: int,
shard_size: int,
):
"""
Given a tensor carrying compacted optimizer states,
update these states to collected_states.
"""
shard_end = shard_start + shard_size
next_state_offset = 0
for state_name in state_names:
if state_name == "step":
collected_states["step"].data = torch.tensor(
compacted_states[next_state_offset].item(), dtype=torch.float32, requires_grad=False
).cpu()
next_state_offset += 1
else:
target_segment = collected_states[state_name][shard_start:shard_end]
target_segment.copy_(compacted_states[next_state_offset : next_state_offset + shard_size])
next_state_offset += shard_size
def get_param_groups_for_saving(self) -> list:
"""
Return the param_groups in Pytorch format when saving to checkpoint.
"""
param_groups = [
{**group, "params": group_info["params"]}
for group, group_info in zip(self.optim.param_groups, self.param_groups_backup)
]
# To be compatible with pytorch checkpointing,
# store extra hyperparameters used by pytorch Adam optimizer.
torch_special_hyperparameters = {
"amsgrad": False,
"maximize": False,
"foreach": None,
"capturable": False,
"differentiable": False,
"fused": False,
}
for group in param_groups:
for k, v in torch_special_hyperparameters.items():
if k not in group:
group[k] = v
return param_groups
def state_dict(self, only_rank_0: bool = True) -> dict:
"""
Args:
only_rank_0 (bool): a boolean value indicating whether the state_dict is collected
only on rank 0, default to True.
Returns:
The complete state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups where each
parameter group is a dict.
Warning: This method will gather and return the whole optimizer state_dict,
so it should be called only when memory resources are abundant.
"""
state_dict = {}
state_dict["param_groups"] = self.get_param_groups_for_saving()
# Collect optimizer states.
state_dict["state"] = dict()
for param_id in self.id_to_real_params.keys():
dist.barrier()
state_dict["state"][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
return state_dict
def load_param_groups(self, saved_param_groups: list):
"""
Load saved_param_groups into
self.param_groups and self.param_groups_backup
"""
self.param_groups_backup = copy.deepcopy(saved_param_groups)
# discard the older param_groups
self.optim.param_groups = []
for group in saved_param_groups:
fake_params_list = list()
updated_group = {k: v for k, v in group.items() if k != "params"}
for param_id in group["params"]:
if param_id not in self.id_to_fake_params:
continue
fake_param = self.id_to_fake_params[param_id]
fake_params_list.append(fake_param)
updated_group["params"] = fake_params_list
self.optim.param_groups.append(updated_group)
def load_single_param_states(self, param_id: int, saved_states: dict):
"""
Load saved optimizer states into parameter with given id.
"""
def cast(param, state_range, value, global_shape, origin_shape, key=None):
"""
Make a copy of the needed segment of value and cast it to device of param.
"""
assert isinstance(value, torch.Tensor)
ret_val = value
if key == "step":
assert value.numel() == 1
ret_val = int(value.item())
else:
state_start, state_end = state_range
ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
)
if is_dtensor:
global_shape = get_global_shape(real_param)
if is_padded_tensor(real_param):
value = torch.reshape(value, origin_shape)
padding_dim = real_param._padding_dim
value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
if is_dtensor:
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
elif is_customized_distributed:
value = torch.reshape(value, global_shape)
value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn)
ret_val.copy_(value.flatten()[state_start:state_end])
return ret_val
assert param_id in self.id_to_fake_params
fake_param = self.id_to_fake_params[param_id]
_, state_offset, param_size = self.get_offsets(param_id)
state_range = (state_offset, state_offset + param_size)
# Copy states assigned to param (and cast tensors to appropriate types).
updated_states = dict()
# get tensor parallelism information
real_param = self.id_to_real_params[param_id]
is_dtensor = is_distributed_tensor(real_param)
is_customized_distributed = is_customized_distributed_tensor(real_param)
shard_spec = get_sharding_spec(real_param) if is_dtensor else None
device_mesh = get_device_mesh(real_param) if is_dtensor else None
global_shape = self.params_info["id2shape"][param_id]
origin_shape = global_shape
for k, v in saved_states.items():
updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
def load_param_states(self, param_states: dict):
"""Loads param states from a state_dict. The param_states can be complete or sharded.
During loading, filter out the part of states not considered by current process.
Args:
param_states (dict): A mapping from param_id to its states.
"""
for param_id, states in param_states.items():
if param_id in self.id_to_fake_params:
self.load_single_param_states(param_id, states)
def optimizer_loading_epilogue(self):
# Epilogue when loading state_dict to pytorch optimizer.
if Version(torch.__version__) >= Version("2.0.0"):
self.optim._patch_step_function() # To support multiprocessing pickle/unpickle
else:
self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle.
self.optim.defaults.setdefault("differentiable", False)
def load_state_dict(self, state_dict: dict):
"""Loads optimizer state from complete optimizer state_dict.
During loading, filter out the part of states not considered by current process.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
assert "param_groups" in state_dict
assert "state" in state_dict
self.load_param_groups(state_dict["param_groups"])
self.load_param_states(state_dict["state"])
self.optimizer_loading_epilogue()
def state_shard(
self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
) -> Iterator[Tuple[OrderedDict, int]]:
"""Returns dictionaries containing shards of optimizer states one by one.
The max size of each dictionary shard is specified by ``max_shard_size``.
Args:
prefix (str, optional): the prefix for states. Defaults to ''.
max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
only_rank_0 (bool, optional): a boolean value indicating whether the state_dict is collected
only on rank 0, default to True.
Yields:
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
yield sharder.current_block, sharder.current_block_size
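# A minimal sketch of consuming this generator to write sharded checkpoint files
# (the file-name pattern is an assumption, not a ColossalAI convention):
#
#   for idx, (shard, shard_size) in enumerate(optimizer.state_shard(max_shard_size=1024)):
#       if dist.get_rank() == 0:
#           torch.save(shard, f"optimizer_shard_{idx}.bin")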
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError("Gemini does not support clip_grad_by_value")
def clip_grad_by_norm(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs,
) -> torch.Tensor:
warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm")
class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)
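# End-to-end usage sketch for the convenience wrapper above (the model and
# hyperparameters are illustrative assumptions):
#
#   model = GeminiDDP(MyModel(), mixed_precision=torch.float16)
#   optimizer = GeminiAdamOptimizer(model, lr=1e-3)
#   for batch in dataloader:
#       loss = model(**batch).loss
#       optimizer.backward(loss)
#       optimizer.step()  # step() already zeroes the gradients itself, see step() above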