ColossalAI/colossalai/zero/gemini/chunk/chunk.py
linsj20 fcf776ff1b
[Feature] LoRA rebased to main branch (#5622)
* [Inference]ADD Bench Chatglm2 script (#4963)

* add bench chatglm

* fix bug and make utils

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix bench mark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test

* updated c++17 compiler flags (#4983)

* [Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention  (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* fix ColossalEval (#4992)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc]Update doc for colossal-inference (#4989)

* update doc

* Update README.md

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c139.

* [hotfix] fix the bug of repeatedly storing param group (#4951)

* [doc] add supported feature diagram for hybrid parallel plugin (#4996)

* [Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo

* [release] update version (#4995)

* [release] update version

* [hotfix] fix ci

* [moe] merge moe into main (#4978)

* update moe module
* support openmoe

* [hotfix] fix grad accumulation plus clipping for gemini (#5002)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel (#4926)

* [hotfix] Add layer norm gradients all-reduce for sequence parallel. (#4915)

* Add layer norm gradients all-reduce for sequence parallel.

* skip pipeline inference test

* [hotfix] fixing policies of sequence parallel (#4922)

* Add layer norm gradients all-reduce for sequence parallel.

* fix parameter passing when calling get_autopolicy

---------

Co-authored-by: littsk <1214689160@qq.com>

* Hotfix/add grad all reduce for sequence parallel (#4927)

* Add layer norm gradients all-reduce for sequence parallel.


* fix parameter passing when calling get_autopolicy

* fix bug using wrong variables

---------

Co-authored-by: littsk <1214689160@qq.com>

* fix policy initialization

* fix bloom and chatglm policies

* polish code of handling layernorm

* fix moe module

* polish code of class initializing

---------

Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>

* [format] applied code formatting on changed files in pull request 4926 (#5007)

Co-authored-by: github-actions <github-actions@github.com>

* [Inference] Fix bug in ChatGLM2 Tensor Parallelism (#5014)

* fix bug

* fix

* fix multiquery

* fix multiquery

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [misc] add code owners (#5024)

* [moe] support optimizer checkpoint (#5015)

* Refactor MoE Manager setup method

* unshard optim ckpt

* optim io

* update transformer version

* update requirements

* update ckpt

* update ckpt

* update ckpt

* fix engine

* fix engine

* Support mtbench (#5025)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [moe]: fix ep/tp tests, add hierarchical all2all (#4982)

* fix: add warning for EP different behavior

* fix: use shard_data in ep & tp model

* to: add used_capacity

* fix: fix router test

* feat: add create_ep_node_group

* feat: add create_ep_hierarchical_group fn

* feat: add HierarchicalAllToAll

* test: add hierarchical all2all test

* fix: fix test errors

* fix: simplify create_ep_hierarchical_group

* fix: add hierarchical_alltoall arg

* fix: fix environ typo

* revert: revert process mesh order

* to: add todo mark

* fix: skip hierarchical_comm if torch < 1.13.1

* [shardformer] Fix serialization error with Tensor Parallel state saving (#5018)

* Fix serialization error with Tensor Parallel state saving

* Refactor state_dict CPU transfer using tree_map

* [gemini] gemini support tensor parallelism. (#4942)

* [colossalai]fix typo

* [inference] Add smoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exception for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license

* Update flash_attention_patch.py

To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer.
https://github.com/huggingface/transformers/pull/25598

* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test

* [format] applied code formatting on changed files in pull request 4908 (#4918)

Co-authored-by: github-actions <github-actions@github.com>

* [gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case

* [hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit

* [test] add no master test for low level zero plugin (#4934)

* [format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>

* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)

* [Refactor] Integrated some lightllm kernels into token-attention  (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix a arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test

* [inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxamil.com>

* [gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

[gemini] gemini support tp

* fix

fix

fix

* update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

update checkpointIO

* support fused layernorm

support fused layernorm

support fused layernorm

* update fusedlayernorm

update fusedlayernorm

update fusedlayernorm

* add sequence parallel to gemini

add sequence parallel to gemini

* fix

* fix comments

fix comments

fix comments

* fix

* fix t5

* clear cache

* fix

* activate ci

* activate ci

* fix

* fix

* fix

* fix

* revert

* modify tp gather method

modify tp gather method

modify tp gather method

modify tp gather method

* fix test

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxamil.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>

* [hotfix] Support extra_kwargs in ShardConfig (#5031)

* [refactor]: replace inference args with extra_kwargs in ShardConfig

* modify shardconfig

* polish code

* fix policy bug in llama

* fix bug in auto policy

* remove setattr in ShardConfig

* fix wrong EOS token in ColossalChat

* [Kernels]Update triton kernels into 2.1.0 (#5046)

* update flash-context-attention

* adding kernels

* fix

* reset

* add build script

* add building process

* add llama2 example

* add colossal-llama2 test

* clean

* fall back test setting

* fix test file

* clean

* clean

* clean

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [pipeline,shardformer] Fix p2p efficiency in pipeline, allow skipping loading weight not in weight_map when `strict=False`, fix llama flash attention forward, add flop estimation by megatron in llama benchmark (#5017)

* Use p2p

* Cannot bidirectional send p2p

* Refactor tensor creation and serialization in P2P communication

* Fix llama forward args in flash attention

* Add flop estimate from megatron

* Support loading weight not in weight_map when strict=False in hybrid_parallel

* Use send_forward_recv_backward, etc in 1f1b

* Use dataclass for metadata
Remove torch.cuda.synchronize() as suggested

* Add comment about the torch.cuda.synchronize for potential error

* Typo

* Update hybrid_parallel_checkpoint_io.py

* Update p2p.py

* Update one_f_one_b.py

* Update p2p.py

---------

Co-authored-by: flybird11111 <1829166702@qq.com>

* [gemini] gemini support extra-dp (#5043)

* support ddp

* fix

* fix

* fix

fix

* support ddp

* fix

* fix

* fix

fix

* simplify tests

* fix

* fix

* fix

fix

fix

* fix

* [shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama

* Update llama.py

* [hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank

* [example] fix llama example's loss error when using gemini plugin (#5060)

fix llama example

* [inference] Refactor inference architecture (#5057)

* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>

* [Kernels]added flash-decoding of triton (#5063)

* added flash-decoding of triton based on lightllm kernel

* add req

* clean

* clean

* delete build.sh

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [misc] remove outdated submodule (#5070)

* [npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support

* [hotfix/hybridengine] fix bug when tp*pp size = 1 (#5069)

* [inference] update examples and engine (#5073)

* update examples and engine

* fix choices

* update example

* [format] applied code formatting on changed files in pull request 5067 (#5072)

Co-authored-by: github-actions <github-actions@github.com>

* [hotfix/hybridengine] Fix init model with random parameters in benchmark (#5074)

* fix init model with random parameters

* fix example

* [inference] refactor examples and fix schedule (#5077)

* [setup] refactor infer setup

* [hotfix] fix inference behavior on 1 1 gpu

* [example] refactor inference examples

* fix thrust-transform-reduce error (#5078)

* [nfc] fix typo in docs/ (#4972)

* [nfc] fix typo and author name (#5089)

* [gemini]fix gemini optimizer, saving Shardformer in Gemini got list assignment index out of range (#5085)

* [Hotfix] Fix model policy matching strategy in ShardFormer (#5064)

* hotfix/Fix get model policy strategy in ShardFormer

* fix bug in auto policy

* [shardformer]fix flash attention, when mask is causal, just don't unpad it (#5084)

* fix flash attn

* fix

fix

* [npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast

* [Feature] Add document retrieval QA (#5020)

* add langchain

* add langchain

* Add files via upload

* add langchain

* fix style

* fix style: remove extra space

* add pytest; modified retriever

* add pytest; modified retriever

* add tests to build_on_pr.yml

* fix build_on_pr.yml

* fix build on pr; fix environ vars

* separate unit tests for colossalqa from build from pr

* fix container setting; fix environ vars

* commented dev code

* add incremental update

* remove stale code

* fix style

* change to sha3 224

* fix retriever; fix style; add unit test for document loader

* fix ci workflow config

* fix ci workflow config

* add set cuda visible device script in ci

* fix doc string

* fix style; update readme; refactored

* add force log info

* change build on pr, ignore colossalqa

* fix docstring, capitalize all initial letters

* fix indexing; fix text-splitter

* remove debug code, update reference

* reset previous commit

* update LICENSE update README add key-value mode, fix bugs

* add files back

* revert force push

* remove junk file

* add test files

* fix retriever bug, add intent classification

* change conversation chain design

* rewrite prompt and conversation chain

* add ui v1

* ui v1

* fix avatar

* add header

* Refactor the RAG Code and support Pangu

* Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo.

* resolved conversation. tested scripts under examples. web demo still buggy

* fix ci tests

* Some modifications to add ChatGPT api

* modify llm.py and remove unnecessary files

* Delete applications/ColossalQA/examples/ui/test_frontend_input.json

* Remove OpenAI api key

* add colossalqa

* move files

* move files

* move files

* move files

* fix style

* Add Readme and fix some bugs.

* Add something to readme and modify some code

* modify a directory name for clarity

* remove redundant directory

* Correct a type in  llm.py

* fix AI prefix

* fix test_memory.py

* fix conversation

* fix some errors and typos

* Fix a missing import in RAG_ChatBot.py

* add colossalcloud LLM wrapper, correct issues in code review

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>

* remove duplicate import (#5100)

* fix typo change lazy_iniy to lazy_init (#5099)

* [nfc] fix typo change directoty to directory (#5111)

* [FEATURE] Add Safety Eval Datasets to ColossalEval (#5095)

* add safetybench and cvalues(responsibility) eval dataset

* Modify code according to review suggestions

---------

Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>

* [hotfix] fixed memory usage of shardformer module replacement (#5122)

* [shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)

* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------

Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>

* [doc] add moe news (#5128)

* [doc] add moe news

* [doc] add moe news

* [doc] add moe news

* [doc] updated paper citation (#5131)

* fix typo change JOSNL TO JSONL etc. (#5116)

* [format] applied code formatting on changed files in pull request 5088 (#5127)

Co-authored-by: github-actions <github-actions@github.com>

* [format] applied code formatting on changed files in pull request 5124 (#5125)

Co-authored-by: github-actions <github-actions@github.com>

* [format] applied code formatting on changed files in pull request 5115 (#5118)

Co-authored-by: github-actions <github-actions@github.com>

* [accelerator] init the accelerator module (#5129)

* [accelerator] init the accelerator module

* polish code

* polish code

* polish code

* polish code

* [npu] support triangle attention for llama (#5130)

* update fused attn

* update spda

* tri attn

* update triangle

* import

* fix

* fix

* [plugin]fix 3d checkpoint load when booster boost without optimizer. (#5135)

* fix 3d checkpoint load when booster boost without optimizer

fix 3d checkpoint load when booster boost without optimizer

* test ci

* revert ci

* fix

fix

* [ColossalQA] refactor server and webui & add new feature (#5138)

* refactor server and webui & add new feature

* add requirements

* modify readme and ui

* [doc] fix colossalqa document (#5146)

* fix doc

* modify doc

* fix (#5158)

fix

* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [gemini]  hotfix NaN loss while using Gemini + tensor_parallel (#5150)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* [colossalqa] fix pangu api (#5170)

* fix pangu api

* add comment

* [ColossalEval] Support GSM, Data Leakage Evaluation and Tensor Parallel (#5169)

* Support GSM, Data Leakage Evaluation and Tensor Parallel

* remove redundant code and update inference.py in examples/gpt_evaluation

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [shardformer] llama support DistCrossEntropy (#5176)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* [Colossal-Llama-2] Add finetuning Colossal-Llama-2 example (#4878)

* Add finetuning Colossal-Llama-2 example

* Add finetuning Colossal-Llama-2 example 2

* Add finetuning Colossal-Llama-2 example and support NEFTuning

* Add inference example and refine neftune

* Modify readme file

* update the imports

---------

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* llama support dist-cross

fix

fix

fix

fix

fix

fix

fix

fix

* fix

* fix

* fix

fix

* test ci

* test ci

* fix

* fix ci

* fix ci

---------

Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* Fix ColossalEval (#5186)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc] update pytorch version in documents. (#5177)

* fix

aaa

fix

fix

fix

* fix

* fix

* test ci

* fix ci

fix

* update pytorch version in documents

* polish readme in application/chat (#5194)

* [pipeline]: fix p2p comm, add metadata cache and support llama interleaved pp (#5134)

* test: add more p2p tests

* fix: remove send_forward_recv_forward as p2p op list need to use the same group

* fix: make send and receive atomic

* feat: update P2PComm fn

* feat: add metadata cache in 1f1b

* feat: add metadata cache in interleaved pp

* feat: modify is_xx_stage fn

* revert: add _broadcast_object_list

* feat: add interleaved pp in llama policy

* feat: set NCCL_BUFFSIZE in HybridParallelPlugin

* Improve logic for selecting metrics (#5196)

Co-authored-by: Xu <yuanchen.xu00@gmail.com>

* [doc] Update required third-party library list for testing and torch compatibility checking (#5207)

* doc/update requirements-test.txt

* update torch-cuda compatibility check

* support linear accumulation fusion (#5199)

support linear accumulation fusion

support linear accumulation fusion

fix

* [pipeline]: support arbitrary batch size in forward_only mode (#5201)

* fix: remove drop last in val & test dataloader

* feat: add run_forward_only, support arbitrary bs

* chore: modify ci script

* [pipeline]: add p2p fallback order and fix interleaved pp deadlock (#5214)

* fix: add fallback order option and update 1f1b

* fix: fix deadlock comm in interleaved pp

* test: modify p2p test

* [devops] update torch version in ci (#5217)

* fix-test (#5210)

fix-test

fix-test

* fix flash attn (#5209)

* [nfc] fix typo colossalai/shardformer/ (#5133)

* [Colossal-LLaMA-2] Release Colossal-LLaMA-2-13b-base model (#5224)

* update readme

* update readme

* update link

* update

* update readme

* update

* update

* update

* update title

* update example

* update example

* fix content

* add conclusion

* add license

* update

* update

* update version

* fix minor

* [doc] Update README.md of Colossal-LLAMA2 (#5233)

* Update README.md

* Update README.md

* [doc] Make leaderboard format more uniform and good-looking (#5231)

* Make leaderboard format more unified and good-looking

* Update README.md

* Update README.md

* [doc] add Colossal-LLaMA-2-13B (#5234)

* [doc] add Colossal-LLaMA-2-13B

* [doc] add Colossal-LLaMA-2-13B

* [doc] add Colossal-LLaMA-2-13B

* [format] applied code formatting on changed files in pull request 5234 (#5235)

Co-authored-by: github-actions <github-actions@github.com>

* [doc] SwiftInfer release (#5236)

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [doc] SwiftInfer release

* [npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* not use warning

* update

* update

* [pipeline] A more general _communicate in p2p (#5062)

* A more general _communicate

* feat: finish tree_flatten version p2p

* fix: update p2p api calls

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>

* [hotfix] removed unused flag (#5242)

* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish

* fix typo in  applications/ColossalEval/README.md (#5250)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] fix doc typo (#5256)

* [doc] fix annotation display

* [doc] fix llama2 doc

* [hotfix]: add pp sanity check and fix mbs arg (#5268)

* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check

* [workflow] fixed incomplete bash command (#5272)

* [workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish

* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* fix ci

fix

* fix test

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* fix

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix

* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

* fix auto loading gpt2 tokenizer (#5279)

* [doc] add llama2-13B display (#5285)

* Update README.md

* fix 13b typo

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* fix llama pretrain (#5287)

* [hotfix] fix 3d plugin test (#5292)

* fix bug for mefture (#5299)

* [NFC] polish applications/Colossal-LLaMA-2/colossal_llama2/tokenizer/init_tokenizer.py code style (#5228)

* fix some typo (#5307)

* [feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* [workflow] updated CI image (#5318)

* [accelerator] fixed npu api

* [tests] fix t5 test. (#5322)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* fix t5 test

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] added docs for extensions (#5324)

* [doc] added docs for extensions

* polish

* polish

* fix typo under extensions/ (#5330)

* fix typo change dosen't to doesn't (#5308)

* [extension] fixed exception catch (#5342)

* [Chat] fix sft loss nan (#5345)

* fix script

* fix script

* fix chat nan

* fix chat nan

* [checkpointio] fix gemini and hybrid parallel optim checkpoint (#5347)

* [checkpointio] fix hybrid parallel optim checkpoint

* [extension] fix cuda extension

* [checkpointio] fix gemini optimizer checkpoint

* polish code

* [fix] remove unnecessary dp_size assert  (#5351)

* fix: remove unnecessary assert

* test: add more 3d plugin tests

* fix: add warning

* [gemini] fix param op hook when output is tuple (#5355)

* [gemini] fix param op hook when output is tuple

* [gemini] fix param op hook

* [llama] fix dataloader for hybrid parallel (#5358)

* [plugin] refactor prepare dataloader

* [plugin] update train script

* [llama] update training script (#5360)

* [llama] update training script

* [doc] polish docstr

* [llama] add flash attn patch for npu (#5362)

* [llama] fix neftune & pbar with start_step (#5364)

* [eval] update llama npu eval (#5366)

* [llama] polish training script and fix optim ckpt (#5368)

* [lr-scheduler] fix load state dict and add test (#5369)

* [llama] fix memory issue (#5371)

* [llama] fix memory issue

* [llama] add comment

* [moe] init mixtral impl

* [moe] update capacity computing (#5253)

* [moe] top2 allow uneven input

* [moe] update capacity computing

* [moe] remove debug info

* [moe] update capacity computing

* [moe] update capacity computing

* [moe] support mixtral (#5309)

* [moe] add mixtral block for single expert

* [moe] mixtral block fwd support uneven ep

* [moe] mixtral block bwd support uneven ep

* [moe] add mixtral moe layer

* [moe] simplify replace

* [moe] support save sharded mixtral

* [moe] support load sharded mixtral

* [moe] support save sharded optim

* [moe] integrate moe manager into plug

* [moe] fix optimizer load

* [moe] fix mixtral layer

* [moe] fix mixtral checkpoint io (#5314)

* [moe] fix mixtral forward default value (#5329)

* [moe] fix mixtral optim checkpoint (#5344)

* [moe] fix tests

* [release] update version (#5380)

* [llama] fix training and inference scripts (#5384)

* [llama] refactor inference example to fit sft

* [llama] fix training script to fit gemini

* [llama] fix inference script

* [doc] Fix typo (#5361)

* [doc] updated installation command (#5389)

* [hotfix] fix variable type for top_p (#5313)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [hotfix] Fix wrong import in meta_registry (#5392)

* [extension] hotfix jit extension setup (#5402)

* [example] reuse flash attn patch (#5400)

* [fsdp] impl save/load shard model/optimizer (#5357)

* [setup] fixed nightly release (#5388)

* [shardformer]gather llama logits (#5398)

* gather llama logits

* fix

* update requirements (#5407)

* [workflow] added pypi channel (#5412)

* [doc] fix blog link

* [doc] fix blog link

* fix sft single turn inference example (#5416)

* [example]add gpt2 benchmark example script. (#5295)

* benchmark gpt2

* fix

fix

fix

fix

* [doc] fix typo in Colossal-LLaMA-2/README.md (#5247)

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed ddp test (#5254)

* [ci] fixed ddp test

* polish

* fix typo in  applications/ColossalEval/README.md (#5250)

* [ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [doc] fix doc typo (#5256)

* [doc] fix annotation display

* [doc] fix llama2 doc

* [hotfix]: add pp sanity check and fix mbs arg (#5268)

* fix: fix misleading mbs arg

* feat: add pp sanity check

* fix: fix 1f1b sanity check

* [workflow] fixed incomplete bash command (#5272)

* [workflow] fixed oom tests (#5275)

* [workflow] fixed oom tests

* polish

* polish

* polish

* [ci] fix test_hybrid_parallel_plugin_checkpoint_io.py (#5276)

* fix ci

fix

* fix test

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

* fix

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix

* [hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

* fix auto loading gpt2 tokenizer (#5279)

* [doc] add llama2-13B display (#5285)

* Update README.md

* fix 13b typo

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* fix llama pretrain (#5287)

* fix

* fix

* fix

fix

* fix

fix

fix

* fix

fix

* benchmark gpt2

* fix

fix

fix

fix

* [workflow] fixed build CI (#5240)

* [workflow] fixed build CI

* polish

* polish

* polish

* polish

* polish

* [ci] fixed booster test (#5251)

* [ci] fixed booster test

* [ci] fixed booster test

* [ci] fixed booster test

* fix

fix

* fix

fix

fix

* fix

* fix

fix

fix

fix

fix

* fix

* Update shardformer.py

---------

Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>

* [doc] sora release (#5425)

* [doc] sora release

* [doc] sora release

* [doc] sora release

* [doc] sora release

* [devops] fix extension building (#5427)

* [hotfix] fix sd vit import error (#5420)

* fix import error

* Update dpt_depth.py

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [hotfix] fix typo of openmoe model source (#5403)

* [doc] update some translations with README-zh-Hans.md (#5382)

* [hotfix] fix typo change _descrption to _description (#5331)

* [hotfix] fix typo change enabel to enable under colossalai/shardformer/ (#5317)

* [eval-hotfix] set few_shot_data to None when few shot is disabled (#5422)

* [hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)

Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [doc] Fix typo s/infered/inferred/ (#5288)

Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>

* [hotfix] fix stable diffusion inference bug. (#5289)

* Update train_ddp.yaml

delete  "strategy" to fix DDP config loading bug in "main.py"

* Update train_ddp.yaml

fix inference with scripts/txt2img.py config file load bug.

* Update README.md

add pretrain model test code.

* [colossal-llama2] add stream chat example for chat version model (#5428)

* add stream chat for chat version

* remove os.system clear

* modify function name

* [release] update version (#5411)

* fix tensor data update for gemini loss calculation (#5442)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [doc] fix ColossalMoE readme (#5599)

* fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [zero] support multiple (partial) backward passes (#5596)

* [zero] support multiple (partial) backward passes

* [misc] update requirements

* [shardformer] refactor embedding resize (#5603)

* [branch rebase] rebase main to Feature/resize_embedding (#5554)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [CI] run pre-commit (#5577)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* run pre-commit

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [rebase] rebase main to resize-embedding (#5581)

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megtron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatability

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]enable padding vocabulary size. (#5489)

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* padding vocab

* padding vocabe

* fix

* fix

* fxi

* test ci

* fix

fix

fix

fix

* fix

fix

* fix

* fix

* Update hybrid_parallel_plugin.py

fix

fix

fix

* fix

fix

* fix

fix

* fix

* resolve super init

resolve super init

resolve super init

resolve super init

* resolve comments

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vocab checkpointio

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

fix

fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* padding vocab

* fix

* fix

fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* cherry-pick

* revert moe modify

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

fix

fix

fix

fix

fix

fix

fix

* resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

* ptensor

ptensor

resolve comments

fix

fix

fix

fix

fix

resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix rebase

* fix rebase

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606)

* fix no pad token bug

* fixed some auto parallel codegen bug, but might not run on torch 2.1

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt

* [lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add liscense and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded codes

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements

* [LowLevelZero] low level zero support lora (#5153)

* low level zero support lora

low level zero support lora

* add checkpoint test

* add checkpoint test

* fix

* fix

* fix

* fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* fix

fix

fix

fix

fix

fix

fix

* fix

* test ci

* git # This is a combination of 3 commits.

Update low_level_zero_plugin.py

Update low_level_zero_plugin.py

fix

fix

fix

* fix naming

fix naming

fix naming

fix

* [feature] qlora support

* qlora follow commit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* migrate qutization folder to colossalai/

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* gptj sp fix

* remove redundancies from pre-commit

* minor fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: hugo-syn <hugo.vincent@synacktiv.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Jun Gao <imgaojun@gmail.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxamil.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Elsa Granger <zeyugao@outlook.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Michelle <97082656+MichelleMa8@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: JIMMY ZHAO <knightyzhao@gmail.com>
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: Desperado-Jia <502205863@qq.com>
Co-authored-by: 李文军 <40464906+liwenjuna@users.noreply.github.com>
Co-authored-by: yixiaoer <miyaku@yixiaoer.sg>
Co-authored-by: CZYCW <czyczf@163.com>
Co-authored-by: Stephan Kölker <stephankoe@users.noreply.github.com>
Co-authored-by: QinLuo <eric.x.sun@gmail.com>
Co-authored-by: MickeyCHAN <76671016+danyow-cheung@users.noreply.github.com>
Co-authored-by: Luo Yihang <luo_yihang@outlook.com>
Co-authored-by: Dongruixuan Li <dongruixuan@hotmail.com>
Co-authored-by: hugo-syn <61210734+hugo-syn@users.noreply.github.com>
Co-authored-by: Youngon <Youngon_wyl@163.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-23 17:57:44 +08:00

653 lines
24 KiB
Python

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.accelerator import get_accelerator


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)


@dataclass
class TensorInfo:
    state: TensorState
    offset: int
    end: int


class ChunkFullError(Exception):
    pass


def is_storage_empty(tensor: torch.Tensor) -> bool:
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())


class Chunk:
    _total_number = 0

    def __init__(
        self,
        chunk_size: int,
        zero_group: ProcessGroup,
        dtype: torch.dtype,
        init_device: Optional[torch.device] = None,
        cpu_shard_init: bool = False,
        keep_gathered: bool = False,
        pin_memory: bool = False,
        extra_dp_group: Optional[ProcessGroup] = None,
    ) -> None:
        """
        Chunk: A container owning a piece of contiguous memory space for tensors.
        Here we use the all-gather operation to gather the whole chunk.
        Currently, Chunk is exclusively used for DDP and ZeRO DDP and it doesn't support unused parameters.
        It is designed to make full use of communication and PCIe bandwidth.

        Args:
            chunk_size (int): the number of elements in the chunk
            zero_group (ProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, the device on which the chunk is stored while it is being built.
                The default value is None, which means the current device (e.g. the current GPU)
            cpu_shard_init (bool): a flag indicating that the local chunk shard resides on CPU
            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
            pin_memory (bool): optional, if True, this chunk always has a shard copied in pinned CPU memory
            extra_dp_group (ProcessGroup): optional, an extra data-parallel group;
                if given, reduced gradients are additionally all-reduced across it
        """
        self.count_id = Chunk._total_number
        Chunk._total_number += 1

        self.chunk_size = chunk_size
        self.utilized_size = 0

        self.torch_pg = zero_group
        self.pg_size = dist.get_world_size(self.torch_pg)
        self.pg_rank = dist.get_rank(self.torch_pg)
        self.extra_dp_group = extra_dp_group
        self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1

        # the chunk size should be divisible by the dp degree
        if not keep_gathered:
            assert chunk_size % self.pg_size == 0
        self.shard_size = chunk_size // self.pg_size
        self.shard_begin = self.shard_size * self.pg_rank
        self.shard_end = self.shard_begin + self.shard_size
        self.valid_end = self.shard_size

        self.dtype = dtype
        device = init_device or get_accelerator().get_current_device()

        # chunk_temp is a global chunk, which only exists while the chunk is being built.
        self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device)  # keep all zero

        self.cuda_global_chunk = None  # we force cuda_global_chunk to be located in CUDA

        # cuda local chunk, which is sharded on GPUs
        self.cuda_shard = None
        # cpu local chunk, which is sharded on CPUs
        self.cpu_shard = None
        # whether the chunk is gathered, which means it is duplicated on each process,
        # and we should use the cuda_global_chunk.
        self.is_gathered = True

        # configure the init device of the shard
        # no-offload default: fp16, fp32 -> CUDA
        # offload default: fp16, fp32 -> CPU
        self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()

        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
        self.shard_mem = self.chunk_mem // self.pg_size

        # each tensor is associated with a TensorInfo to track its meta info
        # (state, offset, end)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of tensors in the chunk
        self.num_tensors = 0

        # Record the number of tensors in different states
        self.tensor_state_cnter: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensor_state_cnter[state] = 0

        # If a chunk is kept gathered,
        # its parameters are treated the same as DDP parameters during training.
        self.keep_gathered = keep_gathered
        if self.keep_gathered:
            pin_memory = False  # since this chunk is gathered, it doesn't need to pin

        # if pin_memory is True, we allocate a piece of pinned CPU memory
        # for it all the time
        self.pin_memory = pin_memory

        # we introduce the paired chunk here
        # it refers to another chunk having the same parameters
        # but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # if the cpu_shard has been visited during the training step, the flag is True
        self.cpu_vis_flag = False

        # whether to record l2 norm for the gradient clipping calculation
        self.l2_norm_flag = False
        self.l2_norm = None

        self.grad_chunk = None

    @property
    def memory_usage(self) -> Dict[str, int]:
        cuda_memory = 0
        cpu_memory = 0

        if self.chunk_temp is not None:
            # this chunk is not closed
            if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
                cuda_memory += self.chunk_mem
            else:
                cpu_memory += self.chunk_mem
        else:
            if self.is_gathered:
                cuda_memory += self.chunk_mem
            if self.cuda_shard is not None:
                cuda_memory += self.shard_mem
            if self.cpu_shard is not None:
                cpu_memory += self.shard_mem

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def device_type(self) -> str:
        if self.chunk_temp is not None:
            return self.chunk_temp.device.type
        elif self.is_gathered or self.cuda_shard is not None:
            return get_accelerator().name
        else:
            return "cpu"

    @property
    def payload(self) -> torch.Tensor:
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.cuda_global_chunk
        elif self.cuda_shard is not None:
            return self.cuda_shard
        else:
            return self.cpu_shard

    @property
    def payload_mem(self) -> int:
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            return self.chunk_mem
        else:
            return self.shard_mem

    @property
    def can_move(self) -> bool:
        return not self.is_gathered

    @property
    def can_release(self) -> bool:
        if self.keep_gathered:
            return False
        else:
            return (
                self.tensor_state_cnter[TensorState.HOLD] + self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD]
                == self.num_tensors
            )

    @property
    def can_reduce(self):
        return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors

    @property
    def has_inf_or_nan(self) -> bool:
        """Check if the chunk has inf or nan values on CUDA."""
        if self.is_gathered:
            valid_tensor = self.cuda_global_chunk[: self.utilized_size]
        else:
            assert self.cuda_shard is not None  # only check on CUDA
            valid_tensor = self.cuda_shard[: self.valid_end]

        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()

    def set_l2_norm(self) -> None:
        """Record the squared l2 norm of this chunk on CUDA."""
        assert self.l2_norm is None, "you are calculating the l2 norm twice"
        if self.is_gathered:
            valid_tensor = self.cuda_global_chunk[: self.utilized_size]
        else:
            assert self.cuda_shard is not None  # calculate on CUDA
            valid_tensor = self.cuda_shard[: self.valid_end]
        chunk_l2_norm = valid_tensor.data.float().norm(2)
        self.l2_norm = chunk_l2_norm.item() ** 2

    def append_tensor(self, tensor: torch.Tensor):
        """Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk
        """
        # sanity check
        assert self.chunk_temp is not None
        assert tensor.dtype == self.dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size : new_utilized_size].copy_(tensor.data.flatten())
        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
        tensor.data = self.chunk_temp[self.utilized_size : new_utilized_size].view(tensor.shape)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.tensor_state_cnter[tensor_state] += 1
        self.utilized_size = new_utilized_size

    def close_chunk(self):
        """Close the chunk. No tensor can be appended to a closed chunk."""
        # sanity check
        assert self.chunk_temp is not None

        # calculate the valid end for each shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        if self.chunk_temp.device.type == "cpu":
            self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
            self.__update_tensors_ptr()
        else:
            self.cuda_global_chunk = self.chunk_temp
        self.chunk_temp = None

        self.__scatter()
        # a chunk that is kept gathered never has a shard
        if self.keep_gathered:
            return

        if self.pin_memory or self.shard_device.type == "cpu":
            self.cpu_shard = torch.empty(self.shard_size, dtype=self.dtype, pin_memory=self.pin_memory)
            self.cpu_shard.copy_(self.cuda_shard)
            self.cpu_vis_flag = True  # cpu_shard has been visited

        if self.shard_device.type == "cpu":
            self.cuda_shard = None

    def shard_move(self, device: torch.device, force_copy: bool = False):
        """Move the shard tensor in the chunk.

        Args:
            device: the device to which the shard will move
            force_copy: if True, the shard is always copied, even if the CPU copy has already been visited
        """
        # sanity check
        assert not self.is_gathered
        # when the current chunk is not synchronized with the optimizer
        # just use another way for the movement
        if not self.optim_sync_flag:
            assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
            self.__paired_shard_move()
            self.optim_sync_flag = True
            return

        if device.type == "cuda" or device.type == "npu":
            assert device == get_accelerator().get_current_device(), "can't move chunk to another device"
            # the shard is already on the device; compare against None because
            # the truth value of a multi-element tensor is ambiguous
            if self.cuda_shard is not None:
                return

            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())

            if not self.pin_memory:
                self.cpu_shard = None
        elif device.type == "cpu":
            if self.cuda_shard is None:
                return

            if self.pin_memory:
                if force_copy or not self.cpu_vis_flag:
                    self.cpu_shard.copy_(self.cuda_shard)
                # if cpu_shard has been visited
                # the copy operation is not needed
            else:
                self.cpu_shard = self.cuda_shard.cpu()
            self.cpu_vis_flag = True
            self.cuda_shard = None
        else:
            raise NotImplementedError

    def access_chunk(self):
        """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
        # sanity check
        assert self.chunk_temp is None

        if not self.is_gathered:
            self.__gather()
            self.__update_tensors_ptr()

    def release_chunk(self):
        """Release the usable chunk. It's an operation done in CUDA."""
        # sanity check
        assert self.chunk_temp is None

        if self.is_gathered:
            self.__scatter()

    def reduce(self):
        """Reduce scatter all the gradients. It's an operation done in CUDA."""
        # sanity check
        assert self.is_gathered

        if self.pg_size == 1:
            # tricky code here
            # just move cuda_global_chunk to cuda_shard
            # the communication is not necessary
            self.__scatter()
            if self.extra_dp_group is not None:
                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
        elif self.keep_gathered:
            # we use all-reduce here
            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
            if self.extra_dp_group is not None:
                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
        else:
            self.cuda_shard = torch.empty(
                self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
            )

            input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
            if self.extra_dp_group is not None:
                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)

            free_storage(self.cuda_global_chunk)
            self.is_gathered = False
        self.__update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward,
        # the tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd.
        # The second one is invalid, so we just ignore ready_for_reduce -> hold_after_bwd.
        # This function only applies valid state transitions;
        # invalid calls are ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            return
        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(
        self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True
    ) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
            update_ptr (bool): if True, point tensor.data at the copied slice in the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten())
        if update_ptr:
            tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)

    def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Add data slice to the memory space indexed by the input tensor in the chunk.
        Only used when accumulating gradient chunks.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be added to the chunk
        """
        # sanity check
        assert self.is_gathered

        tensor_info = self.tensors_info[tensor]
        self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())

    def get_valid_length(self) -> int:
        """Get the valid length of the chunk's payload."""
        if self.keep_gathered:
            return self.utilized_size
        else:
            return self.valid_end

    def init_pair(self, friend_chunk: "Chunk") -> None:
        """Initialize the paired chunk."""
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer."""
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk = self.paired_chunk
        if self.is_gathered is True:
            assert friend_chunk.is_gathered is True
            self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
            self.optim_sync_flag = True
        elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
            self.cuda_shard.copy_(friend_chunk.cuda_shard)
            self.optim_sync_flag = True
            self.cpu_vis_flag = False
        else:
            # optim_sync_flag is set to False
            # see shard_move function for more details
            assert friend_chunk.device_type == "cpu"
            assert self.device_type == "cpu"
            self.optim_sync_flag = False
            self.cpu_vis_flag = False

    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    def __gather(self):
        if not self.is_gathered:
            # sanity check
            assert self.cuda_shard is not None

            alloc_storage(self.cuda_global_chunk)
            gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)

            self.cuda_shard = None
            self.is_gathered = True

    def __scatter(self):
        if self.keep_gathered:
            return

        if self.is_gathered:
            # sanity check
            assert self.cuda_shard is None

            self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=self.cuda_global_chunk.device)

            self.cuda_shard.copy_(self.cuda_global_chunk[self.shard_begin : self.shard_end])

            free_storage(self.cuda_global_chunk)
            self.is_gathered = False

    def __paired_shard_move(self):
        assert self.paired_chunk is not None, "chunks should be paired before training"
        optim_chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # only called when the optimizer state is in CPU memory
        # the grad and param should be on the same device
        assert self.cuda_shard is None
        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
        # avoid converting FP32 on CPU
        self.cuda_shard = temp.to(self.dtype)

        if not self.pin_memory:
            self.cpu_shard = None

    def __update_tensors_ptr(self) -> None:
        # sanity check
        assert self.is_gathered
        assert type(self.cuda_global_chunk) == torch.Tensor

        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        self.tensor_state_cnter[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensor_state_cnter[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def __repr__(self, detailed: bool = True):
        output = [
            "Chunk Information:\n",
            "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(
                self.chunk_size, self.dtype, self.pg_size
            ),
            "\t# of tensors: {}, utilized size: {}, utilized percentage: {:.2f}\n".format(
                self.num_tensors, self.utilized_size, self.utilized_size / self.chunk_size
            ),
        ]

        def print_tensor(tensor, prefix=""):
            output.append(
                "{}shape: {}, dtype: {}, device: {}\n".format(prefix, tensor.shape, tensor.dtype, tensor.device)
            )

        if self.chunk_temp is not None:
            output.append("\tchunk temp:\n")
            print_tensor(tensor=self.chunk_temp, prefix="\t\t")

        if self.cuda_global_chunk is not None and self.cuda_global_chunk.storage().size() > 0:
            output.append("\tchunk total:\n")
            print_tensor(tensor=self.cuda_global_chunk, prefix="\t\t")

        if self.cuda_shard is not None:
            output.append("\tcuda shard:\n")
            print_tensor(tensor=self.cuda_shard, prefix="\t\t")

        if self.cpu_shard is not None:
            output.append("\tcpu shard:\n")
            print_tensor(tensor=self.cpu_shard, prefix="\t\t")

        memory_info = self.memory_usage
        output.append("\tmemory usage: cuda {}, cpu {}\n".format(memory_info["cuda"], memory_info["cpu"]))

        if detailed:
            output.append("\ttensor state monitor:\n")
            for st in TensorState:
                output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st]))

        return "".join(output)

    def init_grad_chunk(self) -> "Chunk":
        """Init grad chunk. This should be called in grad handler.

        Returns:
            Chunk: Grad chunk
        """
        if self.grad_chunk is None:
            # grad chunk is not initialized
            grad_chunk = Chunk(
                chunk_size=self.chunk_size,
                zero_group=self.torch_pg,
                dtype=self.dtype,
                keep_gathered=self.keep_gathered,
                pin_memory=self.pin_memory,
                extra_dp_group=self.extra_dp_group,
            )
            grad_chunk.num_tensors = self.num_tensors
            grad_chunk.utilized_size = self.utilized_size
            grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors
            for tensor, state in self.tensors_info.items():
                grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end)

            grad_chunk.valid_end = self.valid_end

            if grad_chunk.chunk_temp.device.type == "cpu":
                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
            else:
                grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
            grad_chunk.chunk_temp = None

            if grad_chunk.pin_memory:
                grad_chunk.cpu_shard = torch.empty(
                    grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory
                )

            self.grad_chunk = grad_chunk
        else:
            # grad chunk is initialized, just reallocate cuda global chunk
            self.grad_chunk.cuda_shard = None
            self.grad_chunk.is_gathered = True
            self.grad_chunk.l2_norm = None
            alloc_storage(self.grad_chunk.cuda_global_chunk)

        return self.grad_chunk
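
The shard bookkeeping that `Chunk.__init__` and `close_chunk` perform (`shard_begin`, `shard_end`, `valid_end`) can be followed without any distributed setup. The sketch below restates that arithmetic in plain Python. `shard_bounds` and `valid_end` are hypothetical helper names introduced only for illustration and are not part of ColossalAI; the sketch also assumes the chunk is not kept gathered, so the chunk size must be divisible by the group size.

```python
from typing import Tuple


def shard_bounds(chunk_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return the [begin, end) element range of the shard owned by `rank`.

    Hypothetical helper mirroring shard_begin / shard_end in Chunk.__init__.
    """
    assert chunk_size % world_size == 0, "chunk size must be divisible by the group size"
    shard_size = chunk_size // world_size
    begin = shard_size * rank
    return begin, begin + shard_size


def valid_end(utilized_size: int, begin: int, end: int) -> int:
    """Return how many leading elements of a shard hold real tensor data.

    Hypothetical helper mirroring the valid_end computation in close_chunk.
    """
    if utilized_size <= begin:
        return 0  # the shard lies entirely in the unused tail of the chunk
    if utilized_size < end:
        return utilized_size - begin  # the shard is only partly filled
    return end - begin  # the shard is completely filled


if __name__ == "__main__":
    # a 16-element chunk split over 4 ranks, with only 10 elements in use
    for rank in range(4):
        begin, end = shard_bounds(16, 4, rank)
        print(rank, (begin, end), valid_end(10, begin, end))
```

With these numbers, ranks 0 and 1 own full shards, rank 2 owns a shard whose first two elements are valid, and rank 3 owns only padding. This is why `has_inf_or_nan` and `set_l2_norm` slice the shard with `[: self.valid_end]` before inspecting it.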