Compare commits

...

61 Commits
v0.4.6 ... main

Author SHA1 Message Date
flybird11111
46ed5d856b
[ci] update ci (#6254)
* fix for async io

* test for upgrading transformers

* add ci machine

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_fp16_torch.py

* Update build_on_pr.yml

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fiux

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-04-18 16:40:53 +08:00
Yanjia0
7ecdf9a211
Update README.md (#6268)
Image Change from H100 to H200
2025-04-17 12:07:25 +08:00
duanjunwen
44d4053fec
[HotFix] update load lora model Readme; (#6240)
* [fix] update load lora model Readme;

* [fix] update lora infer readme

* [fix] remove useless comments
2025-03-07 14:14:26 +08:00
Hongxin Liu
6d676ee0e9
[release] update version (#6236) 2025-03-03 16:15:09 +08:00
Hongxin Liu
56fe130b15
[hotfix] fix lora load (#6231)
* [hotfix] fix lora load

* [hotfix] fix hp load

* accelerate deepseek loading
2025-03-01 19:04:14 +08:00
Hongxin Liu
f32861ccc5
[misc] update torch version (#6206)
* [misc] update torch version

* fix test

* fix test

* fix test

* fix test
2025-02-24 14:35:48 +08:00
YeAnbang
b9e60559b8
Merge pull request #6208 from hpcaitech/grpo_dev
[Chat] fix colossalchat bugs
2025-02-20 21:23:16 +08:00
pre-commit-ci[bot]
7595c453a5 [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
2025-02-20 10:25:19 +00:00
YeAnbang
53834b74b9 fix num_train_step update 2025-02-20 18:24:04 +08:00
YeAnbang
0171884664 fix inference rebatching bug 2025-02-20 17:28:49 +08:00
Hongxin Liu
9379cbd668
[release] update version (#6195)
* [release] update version

* fix test

* fix test
2025-02-20 11:36:18 +08:00
binmakeswell
24dee8f0b7
[doc] DeepSeek V3/R1 news (#6199)
* [doc] DeepSeek V3/R1 news

* [doc] DeepSeek V3/R1 news

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-19 15:07:29 +08:00
Hongxin Liu
f73ae55394
[application] add lora sft example data (#6198) 2025-02-18 20:18:18 +08:00
Tong Li
f8b9e88484
[application] Update README (#6196)
* remove unused ray

* remove unused readme

* update readme

* update readme

* update

* update

* add link

* update readme

* update readme

* fix link

* update code

* update cititaion

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update readme

* update project

* add images

* update link

* update note

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-18 20:17:56 +08:00
Hongxin Liu
d54642a263
[application] add lora sft example (#6192)
* [application] add lora sft example

* update requirements

* update readme

* update comment

* update ci
2025-02-18 13:06:38 +08:00
YeAnbang
d20c8ffd97
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-18 09:43:36 +08:00
flybird11111
ce0ec40811
[checkpointio] fix for async io (#6189) 2025-02-14 17:34:13 +08:00
Hongxin Liu
5ff5323538
[hotfix] fix zero optim save (#6191) 2025-02-14 15:09:50 +08:00
Hongxin Liu
014837e725
[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3

* [checkpointio] fix lora save

* [devops] update ci env

* [booster] optimize lora

* fix test

* fix test
2025-02-14 14:48:54 +08:00
Wenxuan Tan
ec73f1b5e2
[CI] Cleanup Dist Optim tests with shared helper funcs (#6125)
* Refractor and cleanup using common helper funcs. Tests passed

* Update comments

* Fix relative import

* Fix param fetching bug
2025-02-12 13:42:34 +08:00
flybird11111
5c09d726a6
[checkpointio] fix checkpoint for 3d (#6187)
* fix checkpoint io for 3d

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update hybrid_parallel_checkpoint_io.py

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-12 11:54:55 +08:00
Hongxin Liu
2b415e5999
[shardformer] support ep for deepseek v3 (#6185)
* [feature] support ep for deepseek v3

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* [shardformer] fix deepseek v3 init

* [lazy] fit lora for lazy init

* [example] support npu for deepseek v3

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-11 16:10:25 +08:00
flybird11111
17062c83b9
[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* Update hybrid_parallel_plugin.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update build_on_pr.yml

* Update test_zerobubble_pp.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-02-06 17:21:04 +08:00
Wenxuan Tan
ca0aa2365d
[Issue template] Add checkbox asking for details to reproduce error (#6104)
* Add checkbox asking about reproducing error

* update

* Update

* Update checkbox
2025-01-24 14:36:25 +08:00
Lemon Qin
97e60cbbcb
[checkpointio] gather tensor before unpad it if the tensor is both padded and distributed (#6168) 2025-01-21 10:23:15 +08:00
Guangyao Zhang
5b094a836b
[Inference]Fix example in readme (#6178) 2025-01-08 11:51:50 +08:00
Hongxin Liu
ee81366cac
[checkpointio] support load-pin overlap (#6177)
* [checkpointio] support load-pin overlap

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [test] add conftest

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-07 16:16:04 +08:00
Hongxin Liu
479067e9bc
[release] update version (#6174)
* [release] update version

* [devops] fix test pypi ci

* [devops] fix test pypi ci
2025-01-03 11:52:23 +08:00
pre-commit-ci[bot]
7fdef9fd6b
[pre-commit.ci] pre-commit autoupdate (#6113)
updates:
- [github.com/pre-commit/mirrors-clang-format: v19.1.2 → v19.1.5](https://github.com/pre-commit/mirrors-clang-format/compare/v19.1.2...v19.1.5)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 10:23:20 +08:00
duanjunwen
a9bedc7a43
[Sharderformer] Support zbv in Sharderformer Policy (#6150)
* [feat] Sharderformer support zbv

* [feat] support chatglm2, command, deepseek for zbv

* [feat] support zbv in shardformer policy:
falcon,gptj,mistral,opt,qwen2,t5, vit, whisper

* [feat] support GPT2FusedLinearConv1D

* [feat] support GPT2FusedLinear (without tp)

* [fix] debug FusedConvLinear

* [shardfromer] support gpt2 policy for zbv, support GPT2FusedLinearConv
Col and Row.

* [Shardformer] support FusedLinear1D base for zbv

* [shardformer] support zbv in FusedLinear1D base, Col, Row

* [shardformer] support zbv in blip2 and sam policy

* [shardformer] fix bug incorrect number of gradients; add fusedLinear
base testcase;

* [fix] fix incorrect number of gradients ;

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [Shardformer] add en doc for zbv;

* [fix] fix typo in Model compatibility table

* [fix] fix API Reference typo

* [Shardformer] add zh-Han doc for zbv

* [fix] fix Linear name; update en & zh doc

* [fix] fix shardformer doc import err

* [fix] fix shardconfig import in doc

* [fix] fix shardformer doc

* [fix] fix shardconfig doc

* [fix] fix config

* [fix] remove shardconfig

* [fix] fix doc

* [feat] add zbv doc string

* [fix] rm doc

* [fix] fix doc

* [fix] empty zbv doc

* [fix] ifx torch version

* [fix] fix torch version

* [fix] fix torch versions

* [fix] fix torch versions

* [fix] fix pyramid versions

* [fix] fix pyramid, zope version

* [fix] try fix workflow

* [fix] try import ShardConfig in yml

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix workflow

* [fix] fix ci

* [fix] fix zbv doc

* [fix] fix param for qkv linear, gpt2fused linear; fix requirments;

* [fix] fix policy use fused_linear

* [fix] fix weight grad none, err caused by  weight ptr change

* [fix] fix comm in WeightGradStore

* [fix] fix WeightGradStore pop param

* [fix] remove useless param in doc; fix gpt2 qkv test;

* [shardformer] simplify execute_w_pass_grad_accum;

* [fix] rm useless comments

* [shardformer] simplify execute_w_pass_grad_accum & execute_w_pass

* [shardformer] Run meaningful doc test

* [shadformer] fix doc test cmd;

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-01-02 10:22:26 +08:00
Hongxin Liu
af06d162cf
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-25 17:03:25 +08:00
binmakeswell
836992438f
[news] release colossalai for sora (#6166)
* [news] release colossalai for sora

* [news] release colossalai for sora

* [news] release colossalai for sora

* [news] release colossalai for sora
2024-12-23 21:59:39 +08:00
Hongxin Liu
8b0ed61490
[hotfix] improve compatibility (#6165) 2024-12-23 18:57:08 +08:00
binmakeswell
5f82bfa636
[doc] add bonus event (#6164)
* [doc] add bonus event

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-23 17:41:59 +08:00
duanjunwen
fa9d0318e4
[Hotfix] hotfix normalization (#6163)
* [fix] hotfix normalization

* [hotfix] force doc ci test

* [hotfix] fallback doc
2024-12-23 16:29:48 +08:00
flybird11111
130229fdcb
[checkpointio]support asyncio for 3d (#6152)
* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-23 10:24:22 +08:00
flybird11111
aaafb38851
[Device]Support npu (#6159)
* support npu

* support pretrain

support pretrain

fix

* support lora

fix

fix

* support chatglm

fix

fxi

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

fix

fix

* Update train.py

* Update train.py

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-12-17 15:42:39 +08:00
flybird11111
e994c64568
[checkpointio] fix async io (#6155) 2024-12-16 10:36:28 +08:00
Hongxin Liu
de3d371f65
[hotfix] fix zero comm buffer init (#6154) 2024-12-10 16:46:15 +08:00
duanjunwen
8d826a336e
[fix] fix bug caused by perf version (#6156) 2024-12-10 15:03:16 +08:00
Hongxin Liu
6280cb18b8
[checkpointio] support debug log (#6153)
* [checkpointio] support debug log

* [checkpointio] refactor async writer api

* fix test

* fix test
2024-12-02 11:29:19 +08:00
Hongxin Liu
ab856fd308
[checkpointio] fix zero optimizer async save memory (#6151)
* [checkpointio] fix zero optimizer async save memory

* [checkpointio] fit new tensornvme api

* [checkpointio] fit new tensornvme api
2024-11-25 14:46:31 +08:00
Wang Binluo
8ecff0cb7f
Merge pull request #6149 from ver217/hotfix/ckpt
[checkpointio] disable buffering
2024-11-21 16:05:19 +08:00
ver217
8fddbab04c [checkpointio] disable buffering 2024-11-21 14:33:26 +08:00
Sze-qq
152162a80e
[doc] update cloud link (#6148)
Co-authored-by: Siqi <hpc@192.168.1.4>
2024-11-20 22:00:10 +08:00
Hongxin Liu
cf519dac6a
[optim] hotfix adam load (#6146)
* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-11-20 16:36:37 +08:00
Sze-qq
5caad13055
[doc] add hpc cloud intro (#6147)
* update readme

* update readme

---------

Co-authored-by: Siqi <hpc@hpcdeMacBook-Pro.local>
2024-11-20 15:47:30 +08:00
duanjunwen
e0c68ab6d3
[Zerobubble] merge main. (#6142)
* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [feat] add more test;

* [fix] fix require_grad & deallocate call;

* [fix] updatw bwd b&w input; dict --> list[torch.Tensor]

* [fix] fix bwd w input;

* [fix] fix mem assert;

* [fix] fix input_tensors buffer  append input_obj(dict) --> Tuple (microbatch, input_obj) , and all bwd b related cal logic;

* [fix] use tree_flatten replace dict traverse;

* [fix] rm comments;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix detach clone release order;

* [fix] fix ci --> oom in 4096 hidden dim;

* [fix] fix dumb clone;

* [fix] fix detach_output_obj clone;

* [fix] fix stage_indices;

* [fix] fix traverse; traverse dict --> traverse tensor List;

* [fix] fix zerobubble; support shardformer model type;

* [fix] rm comments;

* [fix] fix test_pipeline_utils ci;

* [fix] remove duplicate arg; rm comments;

* [fix] remove chunk 0 stage 0 bwd b; u don't have to cal micrbatch's dx;

* [fix] rm print & comments;

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] zerobubble support moehybridplugin;

* [feat] update optimizer bwd; ä¸

* [fix] fix build ci;

* [zerobubble] rebase main (#6075)

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [fp8] add fp8 comm for low level zero

* [test] add zero fp8 test case

* [Feature] llama shardformer fp8 support (#5938)

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

* [FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>

* [fp8]support all2all fp8 (#5953)

* support all2all fp8

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] add fp8 linear (#5967)

* [fp8] add fp8 linear

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility

* fix (#5976)

* [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test ci]Feature/fp8 comm (#5981)

* fix

* fix

* fix

* [fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark

* [fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig

* [fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix

* [fp8] refactor fp8 linear with compile (#5993)

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test

* [fp8] support asynchronous FP8 communication (#5997)

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)

* [fp8] linear perf enhancement

* [fp8]update reduce-scatter test (#6002)

* fix

* fix

* fix

* fix

* [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

* [fp8] zero support fp8 linear. (#6006)

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train_dpo.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update low_level_zero_plugin.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018)

* remove triton version

* remove torch 2.2

* remove torch 2.1

* debug

* remove 2.1 build tests

* require torch >=2.2

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [plugin] hotfix zero plugin (#6036)

* [plugin] hotfix zero plugin

* [plugin] hotfix zero plugin

* [Colossal-LLaMA] Refactor latest APIs (#6030)

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB infor

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add fused norm (#6038)

* [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)

* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove deprecated install (#6042)

* remove deprecated install

* remove unused folder

* [fp8] optimize all-gather (#6043)

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring

* [fp8] fix linear hook (#6046)

* [fp8] disable all_to_all_fp8 in intranode  (#6045)

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6041)

* [release] update version

* [devops] update comp test

* [devops] update comp test debug

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

* [fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation

* [doc] update sp doc (#6055)

* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix the sp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the attn

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [fp8] fix missing fp8_comm flag in mixtral (#6057)

* fix

* fix

* fix

* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list

* [doc] FP8 training and communication document (#6050)

* Add FP8 training and communication document

* add fp8 docstring for plugins

* fix typo

* fix typo

* fix

* fix

* [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)

* [ColossalEval] support for vllm (#6056)

* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove dupilicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6062)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [fix] fix require_grad & deallocate call;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix zerobubble; support shardformer model type;

* [fix] fix test_pipeline_utils ci;

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [fix] fix mixtral policy;

* [fix] fix mixtral policy;

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [zero bubble] support zero (#6080)

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [fp8] add fp8 comm for low level zero

* [test] add zero fp8 test case

* [Feature] llama shardformer fp8 support (#5938)

* add llama shardformer fp8

* Llama Shardformer Parity

* fix typo

* fix all reduce

* fix pytest failure

* fix reduce op and move function to fp8.py

* fix typo

* [FP8] rebase main (#5963)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* fp8 operators for compressed communication

cast_to_fp8, cast_from_fp8, all_reduce_fp8

* fix scaling algorithm in FP8 casting

* support fp8 communication in pipeline parallelism

* add fp8_communication flag in the script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* shardformer fp8

* fix rebase

* remove all to all

* fix shardformer fp8 communication training degradation

* [fp8] support all-gather flat tensor (#5932)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update low_level_optim.py

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>

* [fp8]support all2all fp8 (#5953)

* support all2all fp8

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] add fp8 linear (#5967)

* [fp8] add fp8 linear

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [test] fix fp8 linear test condition

* [fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility

* fix (#5976)

* [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)

* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* implement communication hook for FSDP params all-gather

* added unit test for fp8 operators

* support fp8 communication in GeminiPlugin

* update training scripts to support fsdp and fp8 communication

* fixed some minor bugs observed in unit test

* add all_gather_into_tensor_flat_fp8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add skip the test if torch < 2.2.0

* add skip the test if torch < 2.2.0

* add fp8_comm flag

* rebase latest fp8 operators

* rebase latest fp8 operators

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [test ci]Feature/fp8 comm (#5981)

* fix

* fix

* fix

* [fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark

* [fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig

* [fp8]Moe support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fi

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix

* [fp8] refactor fp8 linear with compile (#5993)

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test

* [fp8] support asynchronous FP8 communication (#5997)

* fix

* fix

* fix

* support async all2all

* support async op for all gather

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [fp8] update torch.compile for linear_fp8 to >= 2.4.0 (#6004)

* [fp8] linear perf enhancement

* [fp8]update reduce-scatter test (#6002)

* fix

* fix

* fix

* fix

* [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

* [fp8] zero support fp8 linear. (#6006)

* fix

* fix

* fix

* zero fp8

* zero fp8

* Update requirements.txt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

* [fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)

* add SimPO

* fix dataloader

* remove debug code

* add orpo

* fix style

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix colossalai, transformers version

* fix torch colossalai version

* update transformers version

* [shardformer] DeepseekMoE support (#5871)

* [Feature] deepseek moe expert parallel implement

* [misc] fix typo, remove redundant file (#5867)

* [misc] fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] deepseek support & unit test

* [misc] remove debug code & useless print

* [misc] fix typos (#5872)

* [Feature] remove modeling file, use auto config. (#5884)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [Deepseek] remove redundant code (#5888)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [Feature/deepseek] resolve comment. (#5889)

* [misc] fix typos

* [Feature] deepseek support via auto model, remove modeling file

* [misc] delete useless file

* [misc] fix typos

* [misc] remove redundant code

* [misc] mv module replacement into if branch

* [misc] add some warning message and modify some code in unit test

* [misc] fix typos

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)

* Diffusion Model Inference support

* Stable Diffusion 3 Support

* pixartalpha support

* [HotFix] CI,import,requirements-test for #5838 (#5892)

* [Hot Fix] CI,import,requirements-test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Enable PP + SP for llama (#5868)

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use a one cross entropy func for all shardformer models

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)

* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

* fix style

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix eval

* hotfix citation

* [zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api

* fix orpo cross entropy loss

* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.

* Fix documentation of DimSpec's difference method

* [ShardFormer] fix qwen2 sp (#5903)

* [compatibility] support torch 2.2 (#5875)

* Support Pytorch 2.2.2

* keep build_on_pr file and update .compatibility

* fix object_to_tensor usage when torch>=2.3.0 (#5820)

* [misc] support torch2.3 (#5893)

* [misc] support torch2.3

* [devops] update compatibility ci

* [devops] update compatibility ci

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] add debug

* [devops] remove debug

* [devops] remove debug

* [release] update version (#5912)

* [plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel

* add kto

* fix style, add kto data sample

* [Examples] Add lazy init to OPT and GPT examples (#5924)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [ColossalChat] Hotfix for ColossalChat (#5910)

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* add ignore and tiny llama

* fix path issue

* run style

* fix issue

* update bash

* fix ddp issue

* add Qwen 1.5 32B

* refactor tokenization

* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)

* cannot access local variable 'default_conversation' where it is not associated with a value

set default value for 'default_conversation'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix test data

* refactor evaluation

* remove real data path

* remove real data path

* Add n_fused as an input from native_module (#5894)

* [FIX BUG] convert env param to int in (#5934)

* [Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)

* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix style

* fix style

* fix style

* [shardformer] hotfix attn mask (#5945)

* [shardformer] hotfix attn mask (#5947)

* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)

* Distrifusion Support source

* comp comm overlap optimization

* sd3 benchmark

* pixart distrifusion bug fix

* sd3 bug fix and benchmark

* generation bug fix

* naming fix

* add docstring, fix counter and shape error

* add reference

* readme and requirement

* [zero] hotfix update master params (#5951)

* [release] update version (#5952)

* [Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style

* Update README.md (#5958)

* [hotfix] Remove unused plan section (#5957)

* remove readme

* fix readme

* update

* [test] add mixtral for sequence classification

* [test] add mixtral transformer test

* [moe] fix plugin

* [test] mixtra pp shard test

* [chore] handle non member group

* [zero] solve hang

* [test] pass mixtral shardformer test

* [moe] implement transit between non moe tp and ep

* [zero] solve hang

* [misc] solve booster hang by rename the variable

* solve hang when parallel mode = pp + dp

* [moe] implement submesh initialization

* [moe] add mixtral dp grad scaling when not all experts are activated

* [chore] manually revert unintended commit

* [chore] trivial fix

* [chore] arg pass & remove drop token

* [test] add mixtral modelling test

* [moe] implement tp

* [moe] test deepseek

* [moe] clean legacy code

* [Feature] MoE Ulysses Support (#5918)

* moe sp support

* moe sp bug solve

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [chore] minor fix

* [moe] init moe plugin comm setting with sp

* moe sp + ep bug fix

* [moe] finalize test (no pp)

* [moe] full test for deepseek and mixtral (pp + sp to fix)

* [chore] minor fix after rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [chore] solve moe ckpt test failure and some other arg pass failure

* [moe] remove ops

* [test] fix test: test_zero1_2

* [bug] fix: somehow logger hangs the program

* [moe] deepseek moe sp support

* [test] add check

* [deepseek] replace attn (a workaround for bug in transformers)

* [misc] skip redunant test

* [misc] remove debug/print code

* [moe] refactor mesh assignment

* Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.

* [chore] change moe_pg_mesh to private

* [misc] remove incompatible test config

* [misc] fix ci failure: change default value to false in moe plugin

* [misc] remove useless condition

* [chore] docstring

* [moe] remove force_overlap_comm flag and add warning instead

* [doc] add MoeHybridParallelPlugin docstring

* [moe] solve dp axis issue

* [chore] remove redundant test case, print string & reduce test tokens

* [feat] Dist Loader for Eval (#5950)

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support auto distributed data loader

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tp error

* remove unused parameters

* remove unused

* update inference

* update docs

* update inference

---------

Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [lora] lora support hybrid parallel plugin (#5956)

* lora support hybrid plugin

* fix

* fix

* fix

* fix

* Support overall loss, update KTO logging

* [Docs] clarify launch port

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Hotfix] README link (#5966)

* update ignore

* update readme

* run style

* update readme

* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [Chat] fix readme (#5989)

* fix readme

* fix readme, tokenization fully tested

* fix readme, tokenization fully tested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix sync condition (#6000)

* [plugin] add cast inputs option for zero (#6003)

* [pre-commit.ci] pre-commit autoupdate (#5995)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)

* [Feature] Zigzag Ring attention (#5905)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [misc] update compatibility (#6008)

* [misc] update compatibility

* [misc] update requirements

* [devops] disable requirements cache

* [test] fix torch ddp test

* [test] fix rerun on address in use

* [test] fix lazy init

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix the merge

* overlap kv comm with output rescale (#6017)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix the merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the merge

* fix

* fix

* fix the merge

* fix

* [misc] Use dist logger in plugins (#6011)

* use dist logger in plugins

* remove trash

* print on rank 0

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* fix

* fix

* fix

* fix

* fix the merge

* fix

* fix

* fix

* fix

---------

Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train_dpo.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update low_level_zero_plugin.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [CI] Remove triton version for compatibility bug; update req torch >=2.2 (#6018)

* remove triton version

* remove torch 2.2

* remove torch 2.1

* debug

* remove 2.1 build tests

* require torch >=2.2

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [plugin] hotfix zero plugin (#6036)

* [plugin] hotfix zero plugin

* [plugin] hotfix zero plugin

* [Colossal-LLaMA] Refactor latest APIs (#6030)

* refactor latest code

* update api

* add dummy dataset

* update Readme

* add setup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update files

* add PP support

* update arguments

* update argument

* reorg folder

* update version

* remove IB infor

* update utils

* update readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update save for zero

* update save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add apex

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* add fused norm (#6038)

* [FP8] unsqueeze scale to make it compatible with torch.compile (#6040)

* [colossalai/checkpoint_io/...] fix bug in load_state_dict_into_model; format error msg (#6020)

* fix bug in load_state_dict_into_model; format error msg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

to support checking missing_keys

* Update general_checkpoint_io.py

fix bug in missing_keys error message

* retrigger tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove deprecated install (#6042)

* remove deprecated install

* remove unused folder

* [fp8] optimize all-gather (#6043)

* [fp8] optimize all-gather

* [fp8] fix all gather fp8 ring

* [fp8] enable compile

* [fp8] fix all gather fp8 ring

* [fp8] fix linear hook (#6046)

* [fp8] disable all_to_all_fp8 in intranode  (#6045)

* enhance all_to_all_fp8 with internode comm control

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable some fp8 ops due to performance issue

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6041)

* [release] update version

* [devops] update comp test

* [devops] update comp test debug

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [devops] debug comp test

* [Feature] Split cross-entropy computation in SP (#5959)

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* update softmax_lse shape by new interface

* change tester name

* remove buffer clone; support packed seq layout

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* adapt chatglm, command-R, qwen

* debug

* halfway

* fix cross-PP-stage position id length diff bug

* fix typo

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unified cross entropy func for all shardformer models

* remove redundant lines

* add basic ring attn; debug cross entropy

* fwd bwd logic complete

* fwd bwd logic complete; add experimental triton rescale

* precision tests passed

* precision tests passed

* fix typos and remove misc files

* add sp_mode to benchmark; fix varlen interface

* update softmax_lse shape by new interface

* add varlen tests

* fix typo

* all tests passed

* add dkv_group; fix mask

* remove debug statements

* add comments

* q1 index only once

* remove events to simplify stream sync

* simplify forward/backward logic

* 2d ring forward passed

* 2d ring backward passed

* fixes

* fix ring attn loss

* 2D ring backward + llama passed

* merge

* update logger

* fix typo

* rebase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* remove typos

* fixes

* support GPT

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [hotfix] moe hybrid parallelism benchmark & follow-up fix (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark

* [fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation

* [doc] update sp doc (#6055)

* update sp doc

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix the sp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the attn

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* [fp8] fix missing fp8_comm flag in mixtral (#6057)

* fix

* fix

* fix

* [fp8] Disable all_gather intranode. Disable Redundant all_gather fp8 (#6059)

* all_gather only internode, fix pytest

* fix cuda arch <89 compile pytest error

* fix pytest failure

* disable all_gather_into_tensor_flat_fp8

* fix fp8 format

* fix pytest

* fix conversations

* fix chunk tuple to list

* [doc] FP8 training and communication document (#6050)

* Add FP8 training and communication document

* add fp8 docstring for plugins

* fix typo

* fix typo

* fix

* fix

* [moe] add parallel strategy for shared_expert && fix test for deepseek (#6063)

* [ColossalEval] support for vllm (#6056)

* support vllm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify vllm and update readme

* run pre-commit

* remove dupilicated lines and refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update param name

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refine code

* update readme

* refine code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [release] update version (#6062)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [feat] moehybrid support zerobubble;

* [fix] fix zerobubble pp for shardformer type input;

* [fix] fix require_grad & deallocate call;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [fix] fix pipeline util func deallocate --> release_tensor_data; fix bwd_b loss bwd branch;

* [fix] fix zerobubble; support shardformer model type;

* [fix] fix test_pipeline_utils ci;

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] fix poc format

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [feat] update test; rm comments;

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix mem check;

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix mem assert;

* [fix] fix fwd branch, fwd pass both micro_batch & internal_inputs'

* [plugin] hybrid support zero bubble pipeline (#6060)

* hybrid support zbv

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* [zerobubble]Support ZeroBubble Pipeline (#6034)

* [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;

* [feat] add dw test;

* [fix] fix weight not close;

* [update] update text;

* [feat] add test run_fwd_bwd automatic scheduling;

* [feat] split communication and calculation; fix pop empty send_bwd_buffer error;

* [feat] add test for p & p grad;

* [feat] add comments for ZBV func;

* [fix] rm useless assign and comments;

* [fix] fix ci test; add pytest;

* [feat] add run_fwd_bwd_with_microbatch  (replace input) & test; add p&p.grad assert close test & all pass;

* [feat] add apply v_schedule graph; p & p.grad assert err exist;

* [fix] update

* [feat] fix ci; add assert;

* [feat] fix poc format

* [feat] fix func name & ci; add comments;

* [fix] fix poc test; add comments in poc;

* [feat] add optim backward_b_by_grad

* [feat] fix optimizer bwd b & w; support return accum loss & output

* [feat] add fwd_bwd_step, run_fwd_only;

* [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while"--> "for"; add communication dict;

* [fix] fix communication_map;

* [feat] update test; rm comments;

* [fix] rm zbv in hybridplugin

* [fix] fix optim bwd;

* [fix] fix optim bwd;

* [fix] rm output.data after send fwd;

* [fix] fix bwd step if condition; remove useless comments and format info;

* [fix] fix detach output & release output;

* [fix] rm requir_grad for output;

* [fix] fix requir grad position and detach position and input&output local buffer append position;

* [feat] add memory assertation;

* [fix] fix mem check;

* [fix] mem assertation'

* [fix] fix mem assertation

* [fix] fix mem; use a new model shape; only assert mem less and equal than theo;

* [fix] fix model zoo import;

* [fix] fix redundant detach & clone; add buffer assertation in the end;

* [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;

* [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;

* [fix] add testcase with microbatch 4;

* hybrid support zbv

* fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update zero_bubble_pp.py

* fix

* fix-ci

* fix

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zbv support zero

* fix

* fix

* fix

---------

Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;

* [feat] support zbv in mixtral benchmark; (#6083)

* [feat] support zbv in mixtral benchmark;

* [fix] MixtralForCausalLMPolicy get_held_layer support zbv;

* [feat] update MixtralPipelineForwards --> mixtral_model_forward; support zbv;

* [feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forward for zbv

* [fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral & llama policy and modeling;

* [feat] Linear1D_COL/ROW support zbv WeightGradStore;

* [feat] support use_zbv in llama, mixtral modeling; only replace Linear1D_Col/Row policy;

* [fix] fix test case; moe error in second iter

* [feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;

* [fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Row; other layer perform a fully bwd;

* [fix] debug zbv llama test;

* [fix] rm use_zbv flag in Shardconfig; rm debug info;

* [fix] add & fix  llama test

* [feat] support meta cache, meta_grad_send, meta_tensor_send; fix runtime too long in Recv Bwd; benchmark for llama + Hybrid(tp+pp);

* [fix\ fix fail case test_shard_llama

* [fix] fix test_shard_llama

* [fix] fix llama modeling policy;

* [fix] fix test_shard_llama ci;

* [fix] fix test zerobubble

* [fix] fix handle name; rm useless comments;

* [fix] fix send recv signature;

* [fix] fix comment in llama & benchmark

* [feat] support no tensor parallel Linear in shardformer; Add test for use weightGradStore and not use WeightGradStore

* [fix] fix linear (no tp) ops func name;

* [fix] fix fp8 args in HybridParallel

* [fix] fix hybridparall use_fp8 config

* [fix] fix use_fp8 flag

* [fix] fix model zoo init

* [feat] support no_tp Linear for sharderformer.llama

* [fix] fix zbv llama pp4

* [fix] fix send_tensor_metadata & send_grad_metadata;

* [feat] fix testcase;

* [feat] support mixtral policy with zbv tp_Linear & non_tp_Linear

* [feat] update mixtral policy & bert policy for zerobubble

* [fix] fix p2p error in zbv

* [fix] fix attn

* [fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;

* [fix] fix zbv wait_handle

* [fix] rm debug info; update llama policy; update wait handle

* [fix] fix test_lora

* [fix] fix test_lora in llama policy

* [fix] fix wait handle in run_fwd_bwd

* [fix] remove debug info;

* [fix] rm unused comments

* [fix] fix fp8 overlap code

* [fix] fix yml file & v_schedule comments

* [fix] rm fwd only meta cache comments;

---------

Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
Co-authored-by: GuangyaoZhang <xjtu521@qq.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: wangbluo <2538539015@qq.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
2024-11-19 19:00:36 +08:00
ver217
184a653704 [checkpointio] fix pinned state dict 2024-11-19 14:51:39 +08:00
ver217
5fa657f0a1 [checkpointio] fix size compute 2024-11-19 14:51:39 +08:00
flybird11111
eb69e640e5 [async io]supoort async io (#6137)
* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
2024-11-19 14:51:39 +08:00
Hongxin Liu
b90835bd32 [checkpointio] fix performance issue (#6139) 2024-11-19 14:51:39 +08:00
Wang Binluo
8e08c27e19 [ckpt] Add async ckpt api (#6136)
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
2024-11-19 14:51:39 +08:00
Hongxin Liu
d4a436051d [checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-11-19 14:51:39 +08:00
Hongxin Liu
5a03d2696d
[cli] support run as module option (#6135) 2024-11-14 18:10:37 +08:00
Hanks
cc40fe0e6f
[fix] multi-node backward slowdown (#6134)
* remove redundant memcpy during backward

* get back record_stream
2024-11-14 17:45:49 +08:00
duanjunwen
c2fe3137e2
[hotfix] fix flash attn window_size err (#6132)
* [fix] fix flash attn

* [hotfix] fix flash-atten version

* [fix] fix flash_atten version

* [fix] fix flash-atten versions

* [fix] fix flash-attn not enough values to unpack error

* [fix] fix test_ring_attn

* [fix] fix test ring attn
2024-11-14 17:11:35 +08:00
Hongxin Liu
a2596519fd
[zero] support extra dp (#6123)
* [zero] support extra dp

* [zero] update checkpoint

* fix bugs

* fix bugs
2024-11-12 11:20:46 +08:00
Tong Li
30a9443132
[Coati] Refine prompt for better inference (#6117)
* refine prompt

* update prompt

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-11-08 11:00:37 +08:00
Tong Li
7a60161035
update readme (#6116) 2024-11-06 17:24:08 +08:00
Hongxin Liu
a15ab139ad
[plugin] support get_grad_norm (#6115) 2024-11-05 18:12:47 +08:00
199 changed files with 13927 additions and 2827 deletions

View File

@ -1,3 +1,3 @@
2.2.2-12.1.0
2.3.0-12.1.0
2.4.0-12.4.1
2.5.1-12.4.1

View File

@ -1,11 +1,11 @@
{
"build": [
{
"torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
"torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
"cuda_image": "hpcaitech/cuda-conda:12.1"
},
{
"torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
"torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
"cuda_image": "hpcaitech/cuda-conda:12.4"
}
]

View File

@ -15,6 +15,26 @@ body:
options:
- label: I have searched the existing issues
required: true
- type: checkboxes
attributes:
label: The bug has not been fixed in the latest main branch
options:
- label: I have checked the latest main branch
required: true
- type: dropdown
id: share_script
attributes:
label: Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)
description: If not, please share your setting/training config, and/or point to the line in the repo that throws the error.
If the issue is not easily reproducible by us, it will reduce the likelihood of getting responses.
options:
- Yes, I will share a minimal reproducible script.
- No, I prefer not to share.
validations:
required: true
- type: textarea
attributes:
label: 🐛 Describe the bug

View File

@ -87,10 +87,10 @@ jobs:
name: Build and Test Colossal-AI
needs: detect
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
runs-on: [self-hosted, gpu]
runs-on: ubuntu-latest
container:
image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch
timeout-minutes: 90
defaults:
run:
@ -117,7 +117,7 @@ jobs:
cd TensorNVMe
conda install cmake
pip install -r requirements.txt
DISABLE_URING=1 pip install -v .
DISABLE_URING=1 pip install -v --no-cache-dir .
- name: Store TensorNVMe Cache
run: |
@ -166,6 +166,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Collate artifact
env:
@ -199,7 +200,7 @@ jobs:
fi
- name: Upload test coverage artifact
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: report
path: report/

View File

@ -70,6 +70,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Notify Lark
id: message-preparation

View File

@ -79,3 +79,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com

View File

@ -73,3 +73,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com

View File

@ -67,6 +67,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
HF_ENDPOINT: https://hf-mirror.com
- name: Notify Lark
id: message-preparation

View File

@ -58,6 +58,7 @@ jobs:
# there is no main branch, so it's safe to checkout the main branch from the merged branch
# docer will rebase the remote main branch to the merged branch, so we have to config user
- name: Make the merged branch main
run: |
cd ColossalAI
git checkout -b main

View File

@ -49,6 +49,7 @@ jobs:
# we need to install the requirements.txt first
# as test-pypi may not contain the distributions for libs listed in the txt file
pip install -r requirements/requirements.txt
pip install -U setuptools==68.2.2 wheel
pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.python.org/pypi colossalai==$VERSION
env:
VERSION: ${{ steps.prep-version.outputs.version }}

View File

@ -31,13 +31,12 @@ jobs:
- name: Install Colossal-AI
run: |
BUILD_EXT=1 pip install --no-cache-dir -v -e .
pip install --no-cache-dir -v -e .
- name: Install ChatGPT
run: |
cd applications/ColossalChat
pip install --no-cache-dir -v .
export BUILD_EXT=1
pip install --no-cache-dir -r examples/requirements.txt
- name: Install Transformers
@ -61,5 +60,6 @@ jobs:
PRETRAINED_MODEL_PATH: ./models
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PROMPT_RLVR_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data

View File

@ -22,7 +22,7 @@ repos:
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.2
rev: v19.1.5
hooks:
- id: clang-format
name: clang formatter

View File

@ -9,7 +9,7 @@
<a href="https://www.colossalai.org/"> Documentation </a> |
<a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> |
<a href="https://colossalai.org/zh-Hans/docs/get_started/bonus/">GPU Cloud Playground </a> |
<a href="https://hpc-ai.com/blog"> Blog </a></h3>
[![GitHub Repo stars](https://img.shields.io/github/stars/hpcaitech/ColossalAI?style=social)](https://github.com/hpcaitech/ColossalAI/stargazers)
@ -25,29 +25,27 @@
</div>
## GPU Cloud HPC-AI.COM Coming
## Get Started with Colossal-AI Without Setup
For a limited time, you can access an H100 Server for just $1! This is your chance to leverage premium GPU power at an unbeatable price.
Plus, when you refer a friend, youll receive 20% cashback or compute credits equal to 100% of their top-up!
Access high-end, on-demand compute for your research instantly—no setup needed.
Our platform offers on-demand premium compute, ensuring safe, permanent data storage even after stopping your instance.
Dont miss this incredible opportunity to accelerate your AI projects!
Sign up now and get $10 in credits!
Unlock premium GPUs and register now at [HPC-AI.COM](https://hpc-ai.com) to receive $10!
Special Bonuses:
Limited Academic Bonuses:
* Top up $1,000 and receive 300 credits
* Top up $500 and receive 100 credits
<div align="center">
<a href="https://youtu.be/ilMQpU71ddI?si=J4JSPzZ03ycYmlki">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/HPCAICOM241010.jpg" width="700" />
<a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
<img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2-2.gif" width="850" />
</a>
</div>
## Latest News
* [2025/02] [DeepSeek 671B Fine-Tuning Guide Revealed—Unlock the Upgraded DeepSeek Suite with One Click, AI Players Ecstatic!](https://company.hpc-ai.com/blog/shocking-release-deepseek-671b-fine-tuning-guide-revealed-unlock-the-upgraded-deepseek-suite-with-one-click-ai-players-ecstatic)
* [2024/12] [The development cost of video generation models has saved by 50%! Open-source solutions are now available with H200 GPU vouchers](https://company.hpc-ai.com/blog/the-development-cost-of-video-generation-models-has-saved-by-50-open-source-solutions-are-now-available-with-h200-gpu-vouchers) [[code]](https://github.com/hpcaitech/Open-Sora/blob/main/scripts/train.py) [[vouchers]](https://colossalai.org/zh-Hans/docs/get_started/bonus/)
* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)

View File

@ -100,7 +100,7 @@ LLaMA3_Conv = Conversation(
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)
default_conversation = LLaMA3_Conv

View File

@ -88,7 +88,7 @@ def supervised_tokenize_sft(
assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."
if ignore_index is None:
ignore_index = IGNORE_INDEX

View File

@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
@ -51,7 +52,10 @@ def save_checkpoint(
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))

View File

@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@ -101,10 +102,9 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
@ -117,11 +117,17 @@ def train(args) -> None:
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
try:
tokenizer.pad_token = tokenizer.eos_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
elif args.pad_token == "unk":
tokenizer.pad_token = tokenizer.unk_token
try:
tokenizer.pad_token = tokenizer.unk_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
model_numel = get_model_numel(model)
@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@ -522,6 +528,7 @@ if __name__ == "__main__":
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")

View File

@ -158,6 +158,7 @@ temp/
applications/ColossalChat/logs
applications/ColossalChat/models
applications/ColossalChat/sft_data
applications/ColossalChat/kto_data
applications/ColossalChat/prompt_data
applications/ColossalChat/preference_data
applications/ColossalChat/temp

View File

@ -7,31 +7,23 @@
## Table of Contents
- [Table of Contents](#table-of-contents)
- [What is ColossalChat and Coati ?](#what-is-colossalchat-and-coati-)
- [What is ColossalChat?](#what-is-colossalchat)
- [Online demo](#online-demo)
- [Install](#install)
- [Install the environment](#install-the-environment)
- [Install the Transformers](#install-the-transformers)
- [How to use?](#how-to-use)
- [Introduction](#introduction)
- [Supervised datasets collection](#step-1-data-collection)
- [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning)
- [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model)
- [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization)
- [Alternative Option for RLHF: GRPO](#alternative-option-for-rlhf-group-relative-policy-optimization-grpo)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [SFT for DeepSeek V3/R1](#sft-for-deepseek-v3)
- [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training)
- [Coati7B examples](#coati7b-examples)
- [Generation](#generation)
- [Open QA](#open-qa)
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
- [Quick Preview](#quick-preview)
- [Authors](#authors)
@ -40,9 +32,9 @@
---
## What Is ColossalChat And Coati ?
## What is ColossalChat?
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project.
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalChat) is a project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI).
Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project.
@ -53,8 +45,6 @@ The Coati package provides a unified large language model framework that has imp
- Supervised instructions fine-tuning
- Training reward model
- Reinforcement learning with human feedback
- Quantization inference
- Fast model deploying
- Perfectly integrated with the Hugging Face ecosystem, a high degree of model customization
<div align="center">
@ -114,77 +104,16 @@ cd $COLOSSAL_AI_ROOT/applications/ColossalChat
pip install .
```
## How To Use?
## Introduction
### RLHF Training Stage1 - Supervised Instructs Tuning
Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat. More details can be found in [example guideline](./examples/README.md).
#### Step 1: Data Collection
The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
```json
[
{"messages":
[
{
"from": "user",
"content": "what are some pranks with a pen i can do?"
},
{
"from": "assistant",
"content": "Are you looking for practical joke ideas?"
},
]
},
]
```
#### Step 2: Preprocessing
Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization.
In this code, we provide a flexible way for users to set the conversation template for formatting chat data using Huggingface's newest feature--- chat template. Please follow the [example guideline](./examples/README.md) on how to format and tokenize data.
#### Step 3: Training
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. More details can be found in [example guideline](./examples/README.md).
Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. More details can be found in [example guideline](./examples/README.md).
### RLHF Training Stage2 - Training Reward Model
Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model.
#### Step 1: Data Collection
Below shows the preference dataset format used in training the reward model.
```json
[
{"context": [
{
"from": "human",
"content": "Introduce butterflies species in Oregon."
}
],
"chosen": [
{
"from": "assistant",
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
},
],
"rejected": [
{
"from": "assistant",
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
},
]
},
]
```
#### Step 2: Preprocessing
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
#### Step 3: Training
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. More details can be found in [example guideline](./examples/README.md).
### RLHF Training Stage3 - Proximal Policy Optimization
In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimization (PPO), which is the most complex part of the training process:
@ -193,86 +122,26 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
</p>
#### Step 1: Data Collection
PPO uses two kind of training data--- the prompt data and the sft data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
```json
[
{"messages":
[
{
"from": "human",
"content": "what are some pranks with a pen i can do?"
}
]
},
]
```
#### Step 2: Data Preprocessing
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
#### Step 3: Training
You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. More detais can be found in [example guideline](./examples/README.md).
```bash
--pretrain $PRETRAINED_MODEL_PATH \
--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architectual
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
--prompt_dataset ${prompt_dataset[@]} \ # List of string, the prompt dataset
--ptx_dataset ${ptx_dataset[@]} \ # List of string, the SFT data used in the SFT stage
--ptx_batch_size 1 \ # batch size for calculate ptx loss
--ptx_coef 0.0 \ # none-zero if ptx loss is enable
--num_episodes 2000 \ # number of episodes to train
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size 8 \
--train_batch_size 4 \
--accumulation_steps 2
```
Each episode has two phases, the collect phase and the update phase. During the collect phase, we will collect experiences (answers generated by actor), store those in ExperienceBuffer. Then data in ExperienceBuffer is used during the update phase to update parameter of actor and critic.
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size
= train_batch_size * accumulation_steps * num_tp_group
```
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
### Alternative Option For RLHF: Direct Preference Optimization (DPO)
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO. Read this [README](./examples/README.md) for more information.
### DPO Training Stage1 - Supervised Instructs Tuning
Please refer the [sft section](#dpo-training-stage1---supervised-instructs-tuning) in the PPO part.
### DPO Training Stage2 - DPO Training
#### Step 1: Data Collection & Preparation
For DPO training, you only need the preference dataset. Please follow the instruction in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
#### Step 2: Training
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More detais can be found in [example guideline](./examples/README.md).
## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
### Alternative Option For RLHF: Simple Preference Optimization (SimPO)
Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but it abandons the use of the reference model, which makes the training more efficient. It also adds a reward shaping term called target reward margin to enhance training stability. It also use length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
### Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information.
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.
## Inference Quantization and Serving - After Training
### Alternative Option For RLHF: Group Relative Policy Optimization (GRPO)
We support the main algorithm used to train DeepSeek R1 model, a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. Read this [README](./examples/README.md) for more information.
### SFT for DeepSeek V3
We support fine-tuning DeepSeek V3/R1 model with LoRA. Read this [README](./examples/README.md) for more information.
### Inference Quantization and Serving - After Training
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
@ -281,182 +150,7 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
Online inference server scripts can help you deploy your own services.
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
## O1 Journey
### Inference with Self-refined MCTS
We provide the implementation of MCT Self-Refine (MCTSr) algorithm, an innovative integration of Large Language Models with Monte Carlo Tree Search.
To run inference with MCTS, simply use the following script.
```python
from coati.reasoner.guided_search.mcts import MCTS
from coati.reasoner.guided_search.prompt_store.qwen import Qwen32B_prompt_CFG
problem = "How Many R in 'Strawberry'"
search_tree = MCTS(problem=problem, max_simulations=8, cfg=Qwen32B_prompt_CFG)
answer = search_tree.simulate()
print(answer)
```
## Coati7B examples
### Generation
<details><summary><b>E-mail</b></summary>
![phd](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/Phd.png)
</details>
<details><summary><b>coding</b></summary>
![sort](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/quick_sort.png)
</details>
<details><summary><b>regex</b></summary>
![regex](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/regex.png)
</details>
<details><summary><b>Tex</b></summary>
![tex](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/tex.png)
</details>
<details><summary><b>writing</b></summary>
![writing](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/writing.png)
</details>
<details><summary><b>Table</b></summary>
![Table](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/table.png)
</details>
### Open QA
<details><summary><b>Game</b></summary>
![Game](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/game.png)
</details>
<details><summary><b>Travel</b></summary>
![Travel](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/travel.png)
</details>
<details><summary><b>Physical</b></summary>
![Physical](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/physical.png)
</details>
<details><summary><b>Chemical</b></summary>
![Chemical](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/chemical.png)
</details>
<details><summary><b>Economy</b></summary>
![Economy](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/economy.png)
</details>
You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild/blob/main/comparison.md).
### Limitation
<details><summary><b>Limitation for LLaMA-finetuned models</b></summary>
- Both Alpaca and ColossalChat are based on LLaMA. It is hard to compensate for the missing knowledge in the pre-training stage.
- Lack of counting ability: Cannot count the number of items in a list.
- Lack of Logics (reasoning and calculation)
- Tend to repeat the last sentence (fail to produce the end token).
- Poor multilingual results: LLaMA is mainly trained on English datasets (Generation performs better than QA).
</details>
<details><summary><b>Limitation of dataset</b></summary>
- Lack of summarization ability: No such instructions in finetune datasets.
- Lack of multi-turn chat: No such instructions in finetune datasets
- Lack of self-recognition: No such instructions in finetune datasets
- Lack of Safety:
- When the input contains fake facts, the model makes up false facts and explanations.
- Cannot abide by OpenAI's policy: When generating prompts from OpenAI API, it always abides by its policy. So no violation case is in the datasets.
</details>
## FAQ
<details><summary><b>How to save/load checkpoint</b></summary>
We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.
- Option 1: Save the model weights, model config and generation config (Note: tokenizer will not be saved) which can be loaded using HF's from_pretrained method.
```python
# if use lora, you can choose to merge lora weights before saving
if args.lora_rank > 0 and args.merge_lora_weights:
from coati.models.lora import LORA_MANAGER
# NOTE: set model to eval to merge LoRA weights
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
```
- Option 2: Save the model weights, model config, generation config, as well as the optimizer, learning rate scheduler, running states (Note: tokenizer will not be saved) which are needed for resuming training.
```python
from coati.utils import save_checkpoint
# save model checkpoint after fitting on only rank0
save_checkpoint(
save_dir=actor_save_dir,
booster=actor_booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
epoch=0,
step=step,
batch_size=train_batch_size,
coordinator=coordinator,
)
```
To load the saved checkpoint
```python
from coati.utils import load_checkpoint
start_epoch, start_step, sampler_start_idx = load_checkpoint(
load_dir=checkpoint_path,
booster=booster,
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler,
)
```
</details>
<details><summary><b>How to train with limited resources</b></summary>
Here are some suggestions that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.
`batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters to successfully train the model. To maintain a descent batch size for gradient calculation, consider increase the accumulation_step and reduce the batch_size on each rank.
If you only have a single 24G GPU. Generally, using lora and "zero2-cpu" will be sufficient.
`gemini` and `gemini-auto` can enable a single 24G GPU to train the whole model without using LoRA if you have sufficient CPU memory. But that strategy doesn't support gradient accumulation.
If you have multiple GPUs each has very limited VRAM, say 8GB. You can try the `3d` for the plugin option, which supports tensor parellelism, set `--tp` to the number of GPUs that you have.
</details>
### Real-time progress
You will find our progress in github [project broad](https://github.com/orgs/hpcaitech/projects/17/views/1).
## Invitation to open-source contribution
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
You may contact us or participate in the following ways:
@ -500,25 +194,17 @@ Thanks so much to all of our amazing contributors!
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- Keep in a sufficiently high running speed
| Model Pair | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
| :-----------: | :------------------: | :------------------: |
| Better Cases | 38 ⚔ **41** | **45** ⚔ 33 |
| Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% |
| Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 |
- Our Coati-7B model performs better than Alpaca-7B when using GPT-4 to evaluate model performance. The Coati-7B model we evaluate is an old version we trained a few weeks ago and the new version is around the corner.
## Authors
Coati is developed by ColossalAI Team:
- [ver217](https://github.com/ver217) Leading the project while contributing to the main framework.
- [ver217](https://github.com/ver217) Leading the project while contributing to the main framework (System Lead).
- [Tong Li](https://github.com/TongLi3701) Leading the project while contributing to the main framework (Algorithm Lead).
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO, ORPO.
- [FrankLeeeee](https://github.com/FrankLeeeee) Providing ML infra support and also taking charge of both front-end and back-end development.
- [htzhou](https://github.com/ht-zhou) Contributing to the algorithm and development for RM and PPO training.
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO, ORPO.
The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
@ -527,7 +213,6 @@ The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contribute
We also appreciate the valuable suggestions provided by [Jian Hu](https://github.com/hijkzzz) regarding the convergence of the PPO algorithm.
## Citations
```bibtex
@article{Hu2021LoRALA,
title = {LoRA: Low-Rank Adaptation of Large Language Models},
@ -598,8 +283,22 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github
primaryClass={cs.CL},
url={https://arxiv.org/abs/2403.07691},
}
@misc{shao2024deepseekmathpushinglimitsmathematical,
title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
author={Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Xiao Bi and Haowei Zhang and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
year={2024},
eprint={2402.03300},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2402.03300},
}
@misc{logic-rl,
author = {Tian Xie and Qingnan Ren and Yuqian Hong and Zitian Gao and Haoming Luo},
title = {Logic-RL},
howpublished = {https://github.com/Unakar/Logic-RL},
note = {Accessed: 2025-02-03},
year = {2025}
}
```
## Licenses
Coati is licensed under the [Apache 2.0 License](LICENSE).

View File

@ -141,7 +141,7 @@ def setup_conversation_template(
pass
except ValueError as e:
raise ValueError(e)
if not dist.is_initialized() or dist.get_rank() == 0:
if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, "w", encoding="utf8") as f:
logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")

View File

@ -8,6 +8,7 @@ import os
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union
import jsonlines
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
@ -155,13 +156,14 @@ class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):
`input_ids`: `torch.Tensor` of shape (bsz, max_len);
`attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
"""
gt_answer = [ins.get("gt_answer", None) for ins in instances]
instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
ret = super().__call__(instances=instances)
input_ids = F.pad(
ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
)
attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
return {"input_ids": input_ids, "attention_mask": attention_mask}
return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}
@dataclass
@ -344,3 +346,77 @@ class StatefulDistributedSampler(DistributedSampler):
def set_start_index(self, start_index: int) -> None:
self.start_index = start_index
def apply_chat_template_and_mask(
tokenizer: PreTrainedTokenizer,
chat: List[Dict[str, str]],
max_length: Optional[int] = None,
padding: bool = True,
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
tokens.extend(msg_tokens)
if msg["role"] == "assistant":
assistant_mask.extend([True] * len(msg_tokens))
else:
assistant_mask.extend([False] * len(msg_tokens))
attention_mask = [1] * len(tokens)
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
attention_mask = attention_mask[:max_length]
input_ids = torch.tensor(tokens, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class RawConversationDataset(Dataset):
"""
Raw conversation dataset.
Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
"""
def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
for line in f:
self.raw_texts.append(line)
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length
def __len__(self) -> int:
return len(self.raw_texts)
def __getitem__(self, index: int):
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]

View File

@ -147,7 +147,6 @@ def tokenize_prompt(
ignore_index: the ignore index when calculate loss during training
max_length: the maximum context length
"""
messages = data_point["messages"]
template = deepcopy(conversation_template)
template.messages = []
@ -167,7 +166,6 @@ def tokenize_prompt(
if len(template.messages) % 2 != 1:
# exclude the answer if provided. keep only the prompt
template.messages = template.messages[:-1]
# Prepare data
prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
@ -185,12 +183,21 @@ def tokenize_prompt(
)
# `inputs_decode` can be used to check whether the tokenization method is true.
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
if "gt_answer" in data_point:
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
gt_answer=data_point["gt_answer"],
)
else:
return dict(
input_ids=tokenized,
inputs_decode=prompt,
seq_length=len(tokenized),
seq_category=data_point["category"] if "category" in data_point else "None",
)
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):

View File

@ -27,6 +27,8 @@ class NaiveExperienceBuffer(ExperienceBuffer):
self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
self.rng_sequence = []
self.ptr = 0
@torch.no_grad()
def append(self, experience: Experience) -> None:
@ -40,6 +42,9 @@ class NaiveExperienceBuffer(ExperienceBuffer):
if samples_to_remove > 0:
logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]
self.rng_sequence = [i for i in range(len(self.items))]
random.shuffle(self.rng_sequence)
self.ptr = 0
def clear(self) -> None:
self.items.clear()
@ -52,7 +57,10 @@ class NaiveExperienceBuffer(ExperienceBuffer):
Returns:
A batch of sampled experiences.
"""
items = random.sample(self.items, self.sample_batch_size)
items = []
for _ in range(self.sample_batch_size):
self.ptr = (self.ptr + 1) % len(self.items)
items.append(self.items[self.rng_sequence[self.ptr]])
experience = make_experience_batch(items)
if self.cpu_offload:
experience.to_device(self.target_device)

View File

@ -2,6 +2,8 @@
experience maker.
"""
from typing import Any
import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence
@ -38,14 +40,27 @@ class NaiveExperienceMaker(ExperienceMaker):
kl_coef: float = 0.01,
gamma: float = 1.0,
lam: float = 0.95,
use_grpo: bool = False,
num_generation: int = 8,
inference_batch_size: int = None,
logits_forward_batch_size: int = 2,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef
self.gamma = gamma
self.lam = lam
self.use_grpo = use_grpo
self.num_generation = num_generation
self.inference_batch_size = inference_batch_size
self.logits_forward_batch_size = logits_forward_batch_size
if not self.use_grpo:
assert self.critic is not None, "Critic model is required for PPO training."
else:
assert self.critic is None, "Critic model is not required for GRPO training."
assert self.num_generation > 1, "Number of generations should be greater than 1 for GRPO training."
@torch.no_grad()
@torch.inference_mode()
def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
"""
Calculates the advantage values for each action based on the value and reward tensors.
@ -69,7 +84,9 @@ class NaiveExperienceMaker(ExperienceMaker):
return advantages
@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
def make_experience(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, gt_answer: Any = None, **generate_kwargs
) -> Experience:
"""
Generates an experience using the given input_ids and attention_mask.
@ -83,98 +100,204 @@ class NaiveExperienceMaker(ExperienceMaker):
"""
self.actor.eval()
self.critic.eval()
if self.critic:
self.critic.eval()
self.initial_model.eval()
self.reward_model.eval()
pad_token_id = self.tokenizer.pad_token_id
stop_token_ids = generate_kwargs.get("stop_token_ids", None)
if isinstance(stop_token_ids, int):
stop_token_ids = [[stop_token_ids]]
elif isinstance(stop_token_ids[0], int):
stop_token_ids = [stop_token_ids]
elif isinstance(stop_token_ids[0], list):
pass
else:
raise ValueError(
f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}"
)
generate_kwargs["stop_token_ids"] = stop_token_ids
torch.manual_seed(41) # for tp, gurantee the same input for reward model
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)
if self.use_grpo and self.num_generation > 1:
# Generate multiple responses for each prompt
input_ids = input_ids.repeat_interleave(self.num_generation, dim=0)
gt_answer_tmp = []
for t in gt_answer:
gt_answer_tmp.extend([t] * self.num_generation)
gt_answer = gt_answer_tmp
if self.inference_batch_size is None:
self.inference_batch_size = input_ids.size(0)
# Pad to max length
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
sequence_length = sequences.size(1)
batch_sequences = []
batch_input_ids_rm = []
batch_attention_mask_rm = []
batch_attention_mask = []
batch_r = []
batch_action_log_probs = []
batch_base_action_log_probs = []
batch_action_mask = []
num_actions = 0
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
if input_ids[s:e].size(0) == 0:
break
sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
# pad to max_len, you don't want to get an OOM error after a thousands of steps
sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_index = find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_ids).to(sequences.device)
)
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
# Pad to max length
sequence_length = sequences.size(1)
# Calculate auxiliary tensors
attention_mask = None
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
if stop_token_ids is None:
# End the sequence with eos token
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_ids) :] = False
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(input_ids=sequences, attention_mask=attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(input_ids=sequences, attention_mask=attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i]
sequence_to_pad = sequence[bos_index:eos_index]
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
# Left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
# stop_token_ids are given, generation ends with stop_token_ids
action_mask = torch.ones_like(sequences, dtype=torch.bool)
for i in range(sequences.size(0)):
stop_token_pos = [
find_first_occurrence_subsequence(
sequences[i][input_len:], torch.tensor(stop_token_id).to(sequences.device)
)
for stop_token_id in stop_token_ids
]
stop_index = min([i for i in stop_token_pos if i != -1], default=-1)
stop_token_id = stop_token_ids[stop_token_pos.index(stop_index)]
if stop_index == -1:
# Sequence does not contain stop_token_ids, this should never happen BTW
logger.warning(
"Generated sequence does not contain stop_token_ids. Please check your chat template config"
)
print(self.tokenizer.decode(sequences[i], skip_special_tokens=True))
else:
# Keep stop tokens
stop_index = input_len + stop_index
action_mask[i, stop_index + len(stop_token_id) :] = False
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
generation_end_index = (action_mask == True).sum(dim=-1) - 1
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
torch.cuda.empty_cache()
with torch.inference_mode():
actor_output = []
base_model_output = []
for i in range(0, sequences.size(0), self.logits_forward_batch_size):
actor_output.append(
self.actor(
input_ids=sequences[i : i + self.logits_forward_batch_size],
attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
use_cache=False,
)["logits"]
)
base_model_output.append(
self.initial_model(
input_ids=sequences[i : i + self.logits_forward_batch_size],
attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
use_cache=False,
)["logits"]
)
actor_output = torch.cat(actor_output, dim=0)
base_model_output = torch.cat(base_model_output, dim=0)
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
reward, kl = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
value = value[:, -num_actions:] * action_mask
advantages = self.calculate_advantage(value, reward, num_actions)
# Convert to right padding for the reward model and the critic model
input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
response_start = []
response_end = []
attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
for i in range(sequences.size(0)):
sequence = sequences[i]
bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
eos_index = generation_end_index[i] + 1 # include the stop token
sequence_to_pad = sequence[bos_index:eos_index]
response_start.append(input_len - bos_index)
response_end.append(eos_index - bos_index)
sequence_padded = F.pad(
sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
)
input_ids_rm[i] = sequence_padded
if sequence_length - sequence_to_pad.size(0) > 0:
attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
else:
attention_mask_rm[i, :] = 1
attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)
advantages = advantages.detach()
value = value.detach()
r = self.reward_model(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
response_start=response_start,
response_end=response_end,
gt_answer=gt_answer[s:e],
)
batch_sequences.append(sequences)
batch_input_ids_rm.append(input_ids_rm)
batch_attention_mask_rm.append(attention_mask_rm)
batch_attention_mask.append(attention_mask)
batch_r.append(r)
batch_action_log_probs.append(action_log_probs.cpu())
batch_base_action_log_probs.append(base_action_log_probs.cpu())
batch_action_mask.append(action_mask)
sequences = torch.cat(batch_sequences, dim=0)
input_ids_rm = torch.cat(batch_input_ids_rm, dim=0)
attention_mask_rm = torch.cat(batch_attention_mask_rm, dim=0)
attention_mask = torch.cat(batch_attention_mask, dim=0)
r = torch.cat(batch_r, dim=0)
action_log_probs = torch.cat(batch_action_log_probs, dim=0).to(sequences.device)
base_action_log_probs = torch.cat(batch_base_action_log_probs, dim=0).to(sequences.device)
action_mask = torch.cat(batch_action_mask, dim=0).to(sequences.device)
if not self.use_grpo:
value = self.critic(
input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
attention_mask=attention_mask_rm.to(device=sequences.device),
)
value = value[:, -num_actions:] * action_mask
reward, kl = compute_reward(
r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask
)
advantages = self.calculate_advantage(value, reward, num_actions)
advantages = advantages.detach()
value = value.detach()
else:
# GRPO advantage calculation
kl = torch.sum(
-self.kl_coef * (action_log_probs - base_action_log_probs) * action_mask, dim=-1
) / torch.sum(
action_mask, dim=-1
) # address numerical instability issue
r = kl + r
mean_gr = r.view(-1, self.num_generation).mean(dim=1)
std_gr = r.view(-1, self.num_generation).std(dim=1)
mean_gr = mean_gr.repeat_interleave(self.num_generation, dim=0)
std_gr = std_gr.repeat_interleave(self.num_generation, dim=0)
advantages = (r - mean_gr) / (std_gr + 1e-4)
value = r.detach() # dummy value
r = r.detach()
return Experience(sequences, action_log_probs, value, r, kl, advantages, attention_mask, action_mask)
return Experience(
sequences.cpu(),
action_log_probs.cpu(),
value.cpu(),
r.cpu(),
kl.cpu(),
advantages.cpu(),
attention_mask.cpu(),
action_mask.cpu(),
)

View File

@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .rlvr_reward_model import RLVRRewardModel
from .utils import disable_dropout
__all__ = [
"BaseModel",
"Critic",
"RewardModel",
"RLVRRewardModel",
"PolicyLoss",
"ValueLoss",
"LogSigLoss",

View File

@ -1,3 +1,4 @@
import copy
from typing import Any, Callable, List, Optional
import torch
@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
return model_kwargs
def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
model_kwargs["input_ids"] = input_ids
return model_kwargs
def _sample(
model: Any,
tokenizer: Any,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = True,
@ -137,8 +139,8 @@ def _sample(
if max_new_tokens is None:
max_new_tokens = max_length - context_length
if context_length + max_new_tokens > max_length or max_new_tokens == 0:
print("Exeeded length limitation")
return input_ids
logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
past = None
@ -183,18 +185,14 @@ def _sample(
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
for stop_token_id in stop_token_ids:
tokens_to_check = input_ids[:, -len(stop_token_id) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
)
# Stop when each sentence is finished if early_stopping=True
if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
if i == context_length + max_new_tokens - 1:
# Force to end with stop token ids
input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
torch.LongTensor(stop_token_ids).to(input_ids.device).long()
)
return input_ids
@ -237,8 +235,10 @@ def generate(
raise NotImplementedError
elif is_sample_gen_mode:
# Run sample
generation_kwargs = copy.deepcopy(model_kwargs)
res = _sample(
model,
tokenizer,
input_ids,
max_length,
early_stopping=early_stopping,
@ -249,8 +249,9 @@ def generate(
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
**generation_kwargs,
)
del generation_kwargs
return res
elif is_beam_gen_mode:
raise NotImplementedError
@ -350,11 +351,17 @@ def _sample_streaming(
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
if stop_token_ids is not None:
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
tokens_to_check = input_ids[:, -len(stop_token_ids) :]
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
if isinstance(stop_token_ids[0], int):
# If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
)
else:
for stop_token_id in stop_token_ids:
unfinished_sequences = unfinished_sequences.mul(
torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
)
# Stop when each sentence is finished if early_stopping=True
if (

View File

@ -25,7 +25,9 @@ class RewardModel(BaseModel):
self.value_head = nn.Linear(self.last_hidden_state_size, 1)
self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
outputs = self.model(input_ids, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]

View File

@ -0,0 +1,50 @@
"""
reward model
"""
from typing import Callable, List, Optional
import torch
class RLVRRewardModel:
"""
RLVRReward model class. Support varifiable reward.
Args:
reward_fn_list List: list of reward functions
**kwargs: all other kwargs as in reward functions
"""
def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None:
self.reward_fn_list = reward_fn_list
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
response_start: List = None,
response_end: List = None,
gt_answer: List = None,
) -> torch.Tensor:
# apply varifiable reward
bs = input_ids.size(0)
rewards = torch.zeros(bs, device=input_ids.device)
for i in range(bs):
for reward_fn in self.reward_fn_list:
rewards[i] += reward_fn(
input_ids[i],
attention_mask[i],
response_start=response_start[i],
response_end=response_end[i],
gt_answer=gt_answer[i],
**self.kwargs,
)
return rewards
def to(self, device):
return self
def eval(self):
return self

View File

@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module):
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0.0
def repad_to_left(tensor, tokenizer):
repadded_input_ids = []
max_non_padded_seq_len = 0
for i in range(tensor.size(0)):
non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
start, end = non_pad_indices.min(), non_pad_indices.max()
repadded_input_ids.append(tensor[i][start : end + 1])
max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0))
repadded_input_ids = [
F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids
]
return torch.stack(repadded_input_ids)

View File

@ -1,26 +0,0 @@
import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
API_KEY = "Dummy API Key"
def get_client(base_url: str | None = None) -> openai.Client:
return openai.Client(api_key=API_KEY, base_url=base_url)
def chat_completion(
messages: list[ChatCompletionMessageParam],
model: str,
base_url: str | None = None,
temperature: float = 0.8,
**kwargs,
) -> ChatCompletion:
client = get_client(base_url)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
**kwargs,
)
return response

View File

@ -1,250 +0,0 @@
"""
Implementation of MCTS + Self-refine algorithm.
Reference:
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report"
2. https://github.com/BrendanGraham14/mcts-llm/
3. https://github.com/trotsky1997/MathBlackBox/
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py
"""
from __future__ import annotations
import math
from collections import deque
import numpy as np
import tqdm
from coati.reasoner.guided_search.llm import chat_completion
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
from pydantic import BaseModel
class MCTSNode(BaseModel):
"""
Node for MCTS.
"""
answer: str
parent: MCTSNode = None
children: list[MCTSNode] = []
num_visits: int = 0
Q: int = 0
rewards: list[int] = []
def expand_node(self, node) -> None:
self.children.append(node)
def add_reward(self, reward: int) -> None:
self.rewards.append(reward)
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2
class MCTS(BaseModel):
"""
Simulation of MCTS process.
"""
problem: str
max_simulations: int
cfg: PromptCFG
C: float = 1.4
max_children: int = 2
epsilon: float = 1e-5
root: MCTSNode = None
def initialization(self):
"""
Root Initiation.
"""
# Dummy answer as root.
base_answer = self.sample_base_answer()
self.root = MCTSNode(answer=base_answer)
self.self_evaluate(self.root)
def is_fully_expanded(self, node: MCTSNode):
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children)
def select_node(self) -> MCTSNode:
"""
Select next node to explore.
"""
candidates: list[MCTSNode] = []
to_explore = deque([self.root])
while to_explore:
current_node = to_explore.popleft()
if not self.is_fully_expanded(current_node):
candidates.append(current_node)
to_explore.extend(current_node.children)
if not candidates:
return self.root
return max(candidates, key=self.compute_uct)
def self_evaluate(self, node: MCTSNode):
"""
Sample reward of the answer.
"""
reward = self.sample_reward(node)
node.add_reward(reward)
def back_propagation(self, node: MCTSNode):
"""
Back propagate the value of the refined answer.
"""
parent = node.parent
while parent:
best_child_Q = max(child.Q for child in parent.children)
parent.Q = (parent.Q + best_child_Q) / 2
parent.num_visits += 1
parent = parent.parent
def compute_uct(self, node: MCTSNode):
"""
Compute UCT.
"""
if node.parent is None:
return -100
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon))
def simulate(self):
self.initialization()
for _ in tqdm.tqdm(range(self.max_simulations)):
node = self.select_node()
child = self.self_refine(node)
node.expand_node(child)
self.self_evaluate(child)
self.back_propagation(child)
return self.get_best_answer()
def get_best_answer(self):
to_visit = deque([self.root])
best_node = self.root
while to_visit:
current_node = to_visit.popleft()
if current_node.Q > best_node.Q:
best_node = current_node
to_visit.extend(current_node.children)
return best_node.answer
def self_refine(self, node: MCTSNode):
"""
Refine node.
"""
critique_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.critic_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
critique = critique_response.choices[0].message.content
assert critique is not None
refined_answer_response = chat_completion(
messages=[
{
"role": "system",
"content": self.cfg.refine_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<current_answer>\n{node.answer}\n</current_answer>",
f"<critique>\n{critique}\n</critique>",
]
),
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
refined_answer = refined_answer_response.choices[0].message.content
assert refined_answer is not None
return MCTSNode(answer=refined_answer, parent=node)
def sample_base_answer(self):
response = chat_completion(
messages=[
{
"role": "system",
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].",
},
{
"role": "user",
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step",
},
],
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return response.choices[0].message.content
def sample_reward(self, node: MCTSNode):
"""
Calculate reward.
"""
messages = [
{
"role": "system",
"content": self.cfg.evaluate_system_prompt,
},
{
"role": "user",
"content": "\n\n".join(
[
f"<problem>\n{self.problem}\n</problem>",
f"<answer>\n{node.answer}\n</answer>",
]
),
},
]
for attempt in range(3):
try:
response = chat_completion(
messages=messages,
model=self.cfg.model,
base_url=self.cfg.base_url,
max_tokens=self.cfg.max_tokens,
)
assert response.choices[0].message.content is not None
return int(response.choices[0].message.content)
except ValueError:
messages.extend(
[
{
"role": "assistant",
"content": response.choices[0].message.content,
},
{
"role": "user",
"content": "Failed to parse reward as an integer.",
},
]
)
if attempt == 2:
raise

View File

@ -1,10 +0,0 @@
from pydantic import BaseModel
class PromptCFG(BaseModel):
model: str
base_url: str
max_tokens: int = 4096
critic_system_prompt: str
refine_system_prompt: str
evaluate_system_prompt: str

View File

@ -1,20 +0,0 @@
"""
Prompts for Qwen Series.
"""
from coati.reasoner.guided_search.prompt_store.base import PromptCFG
Qwen32B_prompt_CFG = PromptCFG(
base_url="http://0.0.0.0:8008/v1",
model="Qwen2.5-32B-Instruct",
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. "
"Highlight specific areas that need refinement or correction.",
refine_system_prompt="""# Instruction
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer].
""",
evaluate_system_prompt=(
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. "
"Do not give a full score above 95. Make sure the reward score is an integer. "
"Return *ONLY* the score."
),
)

View File

@ -1,5 +1,6 @@
from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .grpo import GRPOTrainer
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer
@ -15,4 +16,5 @@ __all__ = [
"DPOTrainer",
"ORPOTrainer",
"KTOTrainer",
"GRPOTrainer",
]

View File

@ -96,6 +96,7 @@ class OLTrainer(ABC):
self.sample_buffer = sample_buffer
self.dataloader_pin_memory = dataloader_pin_memory
self.callbacks = callbacks
self.num_train_step = 0
@contextmanager
def _fit_ctx(self) -> None:
@ -212,5 +213,6 @@ class OLTrainer(ABC):
self._update_phase(update_step)
# NOTE: this is for on-policy algorithms
self.data_buffer.clear()
if self.save_interval > 0 and (episode + 1) % (self.save_interval) == 0:
self._save_checkpoint(episode + 1)
if self.num_train_step > 0 and (self.num_train_step + 1) % (self.save_interval) == 0:
self._save_checkpoint(self.num_train_step + 1)

View File

@ -343,7 +343,7 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
@ -358,29 +358,30 @@ class DPOTrainer(SLTrainer):
)
step_bar.update()
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards")
- self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
global_step,
)
self.num_train_step += 1
self.accumulative_meter.reset()
self.num_train_step += 1
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
# save checkpoint

View File

@ -0,0 +1,386 @@
"""
GRPO trainer
"""
import os
from typing import Dict, List, Optional, Union
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models import RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import OLTrainer
from .utils import AnnealingScheduler, CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
"""
Set default keyword arguments for generation based on the actor model.
Args:
actor (PreTrainedModel): The actor model.
Returns:
Dict: A dictionary containing the default keyword arguments for generation.
"""
unwrapped_model = actor.unwrap()
new_kwargs = {}
# use huggingface models method directly
if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
return new_kwargs
class GRPOTrainer(OLTrainer):
"""
Trainer for GRPO algorithm.
Args:
strategy (Booster): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of buffer
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(
self,
actor_booster: Booster,
actor: PreTrainedModel,
reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.2,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
use_tp: bool = False,
num_generation: int = 8,
inference_batch_size: int = None,
logits_forward_batch_size: int = None,
temperature_annealing_config: Optional[Dict] = None,
coordinator: DistCoordinator = None,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(actor_booster, GeminiPlugin):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(actor_booster, None, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks)
self.generate_kwargs = _set_default_generate_kwargs(actor)
self.generate_kwargs.update(generate_kwargs)
self.actor = actor
self.actor_booster = actor_booster
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.experience_maker = NaiveExperienceMaker(
self.actor,
None,
reward_model,
initial_model,
self.tokenizer,
kl_coef,
use_grpo=True,
num_generation=num_generation,
inference_batch_size=inference_batch_size,
logits_forward_batch_size=logits_forward_batch_size,
)
if temperature_annealing_config:
# use annealing
self.temperature_annealing_scheduler = AnnealingScheduler(
temperature_annealing_config["start_temperature"],
temperature_annealing_config["end_temperature"],
temperature_annealing_config["annealing_warmup_steps"],
temperature_annealing_config["annealing_steps"],
)
else:
self.temperature_annealing_scheduler = None
self.train_batch_size = train_batch_size
self.actor_loss_fn = PolicyLoss(eps_clip)
self.vf_coef = vf_coef
self.ptx_loss_fn = GPTLMLoss()
self.ptx_coef = ptx_coef
self.actor_optim = actor_optim
self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.use_tp = use_tp
self.accumulative_meter = AccumulativeMeanMeter()
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: Optional[DataLoader] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-grpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "grpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _setup_update_phrase_dataload(self):
"""
why not use distributed_dataloader?
if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks
if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank
"""
self.dataloader = DataLoader(
self.data_buffer,
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=self.data_buffer.collate_fn,
)
def _make_experience(self, collect_step: int) -> Experience:
"""
Make experience
"""
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if self.temperature_annealing_scheduler:
self.generate_kwargs["temperature"] = self.temperature_annealing_scheduler.get_temperature()
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)
def _training_step(self, experience: Experience):
"""
Args:
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.actor.train()
num_actions = experience.action_log_probs.size(1)
# policy loss
actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
"logits"
] # [batch size, prompt_length + response_length]
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs,
experience.action_log_probs,
experience.advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
# sequence that is not end properly are not counted in token cost
token_cost = torch.sum(
(experience.sequences[:, -num_actions:] != self.tokenizer.pad_token_id).to(torch.float), axis=-1
).to(actor_logits.device)
end_properly = experience.sequences[:, -1] == self.tokenizer.pad_token_id
mean_token_cost = torch.sum(token_cost * end_properly) / torch.sum(end_properly)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
ptx_loss = outputs.loss
ptx_loss = self.ptx_coef * ptx_loss
self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
# sync
actor_loss_mean = all_reduce_mean(tensor=actor_loss)
max_ratio_mean = all_reduce_mean(tensor=max_ratio)
reward_mean = all_reduce_mean(tensor=experience.reward.mean())
advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
kl_mean = all_reduce_mean(tensor=experience.kl.mean())
mean_token_cost = all_reduce_mean(tensor=mean_token_cost)
if self.ptx_coef != 0:
ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
self.accumulative_meter.add("mean_token_cost", mean_token_cost.to(torch.float16).item())
self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
if self.ptx_coef != 0:
self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
self.actor_optim.step()
self.actor_optim.zero_grad()
self.actor_scheduler.step()
if self.temperature_annealing_scheduler:
self.temperature_annealing_scheduler.step_forward()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
for i in range(len(response_text)):
response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
# log output to wandb
my_table = wandb.Table(
columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
)
try:
self.wandb_run.log({"sample_response": my_table})
except OSError as e:
self.coordinator.print_on_master(e)
elif self.writer and is_rank_0():
for line in response_text:
self.coordinator.print_on_master(line)
if self.writer and is_rank_0():
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), global_step)
self.writer.add_scalar("train/skip_ratio", self.accumulative_meter.get("skip_ratio"), global_step)
self.writer.add_scalar("train/actor_loss", self.accumulative_meter.get("actor_loss"), global_step)
self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], global_step)
if self.ptx_coef != 0:
self.writer.add_scalar("train/ptx_loss", self.accumulative_meter.get("ptx_loss"), global_step)
self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), global_step)
self.writer.add_scalar("token_cost", self.accumulative_meter.get("mean_token_cost"), global_step)
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
self.accumulative_meter.reset()
self.num_train_step += 1
def _learn(self, update_step: int):
"""
Perform the learning step of the PPO algorithm.
Args:
update_step (int): The current update step.
Returns:
None
"""
if self.offload_inference_models:
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
experience = self.data_buffer.sample()
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
self._training_step(experience)
self._on_learn_batch_end(experience)
def _save_checkpoint(self, num_train_step: int = 0):
"""
Save the actor checkpoints with running states.
Args:
num_train_step (int): The current num_train_step number.
Returns:
None
"""
self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
save_checkpoint(
save_dir=self.actor_save_dir,
booster=self.actor_booster,
model=self.actor,
optimizer=self.actor_optim,
lr_scheduler=self.actor_scheduler,
epoch=0,
step=num_train_step + 1,
batch_size=self.train_batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved actor checkpoint at episode {(num_train_step + 1)} at folder {self.actor_save_dir}"
)

View File

@ -217,25 +217,25 @@ class KTOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.accumulative_meter.reset()
@ -256,6 +256,7 @@ class KTOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()

View File

@ -184,35 +184,35 @@ class ORPOTrainer(SLTrainer):
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
if i % self.accumulation_steps == self.accumulation_steps - 1:
self.num_train_step += 1
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
global_step = (self.num_train_step + 1) / self.accumulation_steps
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
self.num_train_step,
global_step,
)
self.accumulative_meter.reset()
@ -233,6 +233,7 @@ class ORPOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()

View File

@ -3,13 +3,13 @@ PPO trainer
"""
import os
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models import Critic, RewardModel
from coati.models import Critic, RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
@ -84,7 +84,7 @@ class PPOTrainer(OLTrainer):
critic_booster: Booster,
actor: PreTrainedModel,
critic: Critic,
reward_model: RewardModel,
reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
critic_optim: Optimizer,
@ -210,6 +210,7 @@ class PPOTrainer(OLTrainer):
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)
@ -219,7 +220,6 @@ class PPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
self.num_train_step += 1
self.actor.train()
self.critic.train()
num_actions = experience.action_log_probs.size(1)
@ -293,7 +293,7 @@ class PPOTrainer(OLTrainer):
self.critic_scheduler.step()
# preparing logging model output and corresponding rewards.
if self.num_train_step % 10 == 1:
if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@ -335,6 +335,7 @@ class PPOTrainer(OLTrainer):
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
self.accumulative_meter.reset()
self.num_train_step += 1
def _learn(self, update_step: int):
"""

View File

@ -150,29 +150,29 @@ class RewardModelTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.update()
self.num_train_step += 1
# Logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
global_step = (self.num_train_step + 1) / self.accumulation_steps
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
"train/dist",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
global_step,
)
self.writer.add_scalar(
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), global_step
)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), global_step)
self.accumulative_meter.reset()
@ -193,6 +193,7 @@ class RewardModelTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
self.num_train_step += 1
step_bar.close()
def _eval(self, epoch):

View File

@ -143,18 +143,18 @@ class SFTTrainer(SLTrainer):
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
if (self.num_train_step + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
global_step = (self.num_train_step + 1) / self.accumulation_steps
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
self.accumulative_meter.reset()
step_bar.update()
self.num_train_step += 1
# Save checkpoint
if (

View File

@ -12,6 +12,27 @@ from torch.utils.data import DataLoader
from colossalai.booster import Plugin
class AnnealingScheduler:
def __init__(self, start, end, warmup_steps=100, annealing_step=2000):
self.start = start
self.end = end
self.warmup_steps = warmup_steps
self.step = 0
self.annealing_step = annealing_step
def get_temperature(self):
if self.step <= self.warmup_steps:
return self.start # Stop annealing after warm-up steps
elif self.step >= self.annealing_step:
return self.end
# Linear annealing
temp = self.start - (self.step / self.annealing_step) * (self.start - self.end)
return temp
def step_forward(self):
self.step += 1
class CycledDataLoader:
"""
A data loader that cycles through the data when it reaches the end.

View File

@ -0,0 +1,4 @@
from .competition import math_competition_reward_fn
from .gsm8k import gsm8k_reward_fn
__all__ = ["gsm8k_reward_fn", "math_competition_reward_fn"]

View File

@ -0,0 +1,26 @@
import torch
from .utils import extract_solution, validate_response_structure
def math_competition_reward_fn(input_ids, attention_mask, **kwargs):
# apply varifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward

View File

@ -0,0 +1,31 @@
import torch
from .utils import extract_solution, validate_response_structure
def gsm8k_reward_fn(input_ids, attention_mask, **kwargs):
# apply varifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
is_valid = True
try:
int(final_answer.strip())
except Exception:
is_valid = False
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not is_valid or not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward

View File

@ -0,0 +1,76 @@
# Copyright Unakar
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, Optional, Tuple
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
"""Performs comprehensive validation of response structure.
Args:
processed_str: Processed response string from the model
Returns:
Boolean indicating whether all formatting requirements are met
"""
validation_passed = True
# Check required tags
if tags is None:
tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
positions = {}
for tag_name, tag_info in tags.items():
tag_str = tag_info["text"]
expected_count = tag_info["num_occur"]
count = processed_str.count(tag_str)
positions[tag_name] = pos = processed_str.find(tag_str)
if count != expected_count:
validation_passed = False
# Verify tag order
if (
positions["think_start"] > positions["think_end"]
or positions["think_end"] > positions["answer_start"]
or positions["answer_start"] > positions["answer_end"]
):
validation_passed = False
if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
validation_passed = False
return validation_passed
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
"""Extracts the final answer from the model's response string.
Args:
solution_str: Raw response string from the language model
Returns:
Tuple containing (extracted_answer, processed_string)
"""
# Extract final answer using XML-style tags
answer_pattern = r"<answer>(.*?)</answer>"
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
if not matches:
return None, solution_str
final_answer = matches[-1].group(1).strip()
return final_answer, solution_str

View File

@ -0,0 +1,8 @@
{
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
"stop_ids": [
122753
],
"end_of_assistant": "<|im_end|>"
}

View File

@ -0,0 +1,26 @@
{
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"system_message": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n",
"stop_ids": [
151643
],
"end_of_assistant": "<|endoftext|>",
"response_format_tags": {
"think_start": {
"text": "<think>",
"num_occur": 1
},
"think_end": {
"text": "</think>",
"num_occur": 1
},
"answer_start": {
"text": "<answer>",
"num_occur": 1
},
"answer_end": {
"text": "</answer>",
"num_occur": 1
}
}
}

View File

@ -2,8 +2,6 @@
## Table of Contents
- [Examples](#examples)
- [Table of Contents](#table-of-contents)
- [Install Requirements](#install-requirements)
@ -27,13 +25,14 @@
- [Reward](#reward)
- [KL Divergence](#approximate-kl-divergence)
- [Note on PPO Training](#note-on-ppo-training)
- [GRPO Training and DeepSeek R1 reproduction](#grpo-training-and-deepseek-r1-reproduction)
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
- [List of Supported Models](#list-of-supported-models)
- [SFT for DeepSeek V3](#sft-for-deepseek-v3)
- [Hardware Requirements](#hardware-requirements)
- [Inference example](#inference-example)
- [Attention](#attention)
@ -725,10 +724,69 @@ Answer: The causes of this problem are two-fold. Check your reward model, make s
#### Q4: Generation is garbage
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviate from its original state, which may leads to decrease in language modeling capabilities. A way to fix this is to add supervised loss during PPO. Set ptx_coef to an non-zero value (between 0 and 1), which balances PPO loss and sft loss.
## GRPO Training and DeepSeek R1 reproduction
We support GRPO (Group Relative Policy Optimization), which is the reinforcement learning algorithm used in DeepSeek R1 paper. In this section, we will walk through GRPO training with an example trying to reproduce Deepseek R1's results in mathematical problem solving.
**Note: Currently, our PPO and GRPO pipelines are still under extensive development (integration with Ray and the inference engine). The speed is primarily limited by the rollout process, as we are using a naive generation approach without any acceleration. This experiment is focused solely on verifying the correctness of the GRPO algorithm. We will open-source the new version of code as soon as possible, so please stay tuned.**
### GRPO Model Selection
We finally select the base version of [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B). We also did experiments on the instruct version [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) but the later one fails to explore more diversed output. We recommend to use base models (without SFT) and use a few SFT steps (see [SFT section](#rlhf-training-stage1---supervised-instructs-tuning)) to correct the base model's output format before GRPO.
### Reinforcement Learning with Verifiable Reward
Both the PPO and the GRPO support reinforcement learning with verifiable reward (RLVR). In this experiment on mathematical problem solving, we define the reward function as following, in the following definition, forward is correct if there are exactly one pair of <think></think>, <answer></answer> tags in the response and the order of the tags is correct.
- reward=0, if format is incorrect.
- reward=1, if format is correct but the answer doesn't match the ground truth answer exactly.
- reward=10, if format is correct and the answer match the ground truth answer exactly.
### Step 1: Data Collection & Preparation
For GPRO training, you only need the prompt dataset. Please follow the instruction in the [prompt dataset preparation](#rlhf-training-stage3---proximal-policy-optimization) to prepare the prompt data for GPRO training. In our reproduction experiment, we use the [qwedsacf/competition_math dataset](https://huggingface.co/datasets/qwedsacf/competition_math), which is available on Huggingface.
### Step 2: Training
You can run the [train_grpo.sh](./training_scripts/train_grpo.sh) to start GRPO training. The script share most of its arguments with the PPO script (please refer to the [PPO training section](#step-3-training) for more details). Here are some unique arguments for GRPO.
```bash
--num_generations 8 \ # number of roll outs to collect for each prompt
--inference_batch_size 8 \ # batch size used during roll out
--logits_forward_batch_size 1 \ # batch size used to calculate logits for GRPO training
--initial_temperature \ # initial temperature for annealing algorithm
--final_temperature \ # final temperature for annealing algorithm
```
As the GRPO requires to collect a group of response from each prompt (usually greater than 8), the effective batch size will satisfy the following constraints,
- Without tensor parallelism,
```
experience buffer size
= num_process * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_process
```
- With tensor parallelism,
```
num_tp_group = num_process / tp
experience buffer size
= num_tp_group * num_collect_steps * experience_batch_size * num_generations
= train_batch_size * accumulation_steps * num_tp_group
```
During roll out, we perform rebatching to prevent out of memory both before roll out and before calculating logits. Please choose a proper setting for the "inference_batch_size" and the "logits_forward_batch_size" based on your device.
### GRPO Result
#### Reward and Response Length
<div style="display: flex; justify-content: space-between;">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/reward.png" style="width: 48%;" />
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost.png" style="width: 48%;" />
</div>
#### Response Length Distribution (After Training) and Sample response
<div style="display: flex; justify-content: space-between;">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost_eval.png" style="width: 48%;" />
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/sample.png" style="width: 48%;" />
</div>
## Alternative Option For RLHF: Direct Preference Optimization
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper (available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290)), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO.
@ -814,8 +872,95 @@ For training, use the [train_kto.sh](./examples/training_scripts/train_orpo.sh)
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
</p>
## Hardware Requirements
### SFT for DeepSeek V3
We add a script to supervised-fintune the DeepSeek V3/R1 model with LoRA. The script is located in `examples/training_scripts/lora_fintune.py`. The script is similar to the SFT script for Coati7B, but with a few differences. This script is compatible with Peft.
#### Dataset preparation
This script receives JSONL format file as input dataset. Each line of dataset should be a list of chat dialogues. E.g.
```json
[{"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great. How can I help you today?"}]
```
```json
[{"role": "user", "content": "火烧赤壁 曹操为何不拨打119求救"}, {"role": "assistant", "content": "因为在三国时期还没有电话和现代的消防系统所以曹操无法拨打119求救。"}]
```
The dialogues can by multiple turns and it can contain system prompt. For more details, see the [chat_templating](https://huggingface.co/docs/transformers/main/chat_templating).
#### Model weights preparation
We use bf16 weights for finetuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use the [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert the weights to bf16 via GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).
We have also added details on how to load and reason with lora models.
```python
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from peft import (
PeftModel
)
import torch
# Set model path
model_name = "Qwen/Qwen2.5-3B"
lora_adapter = "Qwen2.5-3B_lora" # Your lora model Path
merged_model_path = "Qwen2.5-3B_merged"
######
# How to Load lora Model
######
# 1.Load base model
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
# 2.Load lora model
peft_model = PeftModel.from_pretrained(
base_model,
lora_adapter,
torch_dtype=torch.bfloat16
)
# 3.Merge lora model
merged_model = peft_model.merge_and_unload()
# 4.Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
pad_token="<|endoftext|>"
)
# 5.Save merged lora model
merged_model.save_pretrained(
merged_model_path,
safe_serialization=True
)
tokenizer.save_pretrained(merged_model_path)
# 6.Run Inference
test_input = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to("cuda")
output = merged_model.generate(**test_input, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
#### Usage
After preparing the dataset and model weights, you can run the script with the following command:
```bash
colossalai run --hostfile path-to-host-file --nproc_per_node 8 lora_finetune.py --pretrained path-to-DeepSeek-R1-bf16 --dataset path-to-dataset.jsonl --plugin moe --lr 2e-5 --max_length 256 -g --ep 8 --pp 3 --batch_size 24 --lora_rank 8 --lora_alpha 16 --num_epochs 2 --warmup_steps 8 --tensorboard_dir logs --save_dir DeepSeek-R1-bf16-lora
```
For more details of each argument, you can run `python lora_finetune.py --help`.
The sample command does not use CPU offload to get better throughput. The minimum hardware requirement for sample command is 32 ascend 910B NPUs (with `ep=8,pp=4`) or 24 H100/H800 GPUs (with `ep=8,pp=3`). If you enable CPU offload by `--zero_cpu_offload`, the hardware requirement can be further reduced.
## Hardware Requirements
For SFT, we recommend using zero2 or zero2-cpu for 7B model and tp is your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
- 2 H800 GPU
- zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
@ -872,35 +1017,9 @@ For KTO, we recommend using zero2-cpu or zero2 plugin, We tested the VRAM consum
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
## List of Supported Models
For SFT, we support the following models/series:
- Colossal-LLaMA-2
- ChatGLM2
- ChatGLM3 (only with zero2, zero2_cpu plugin)
- Baichuan2
- LLaMA2
- Qwen1.5-7B-Chat (with transformers==4.39.1)
- Yi-1.5
For PPO and DPO, we theoratically support the following models/series (without guarantee):
- Colossal-LLaMA-2 (tested)
- ChatGLM2
- Baichuan2
- LLaMA2 (tested)
- Qwen1.5-7B-Chat (with transformers==4.39.1)
- Yi-1.5
*-* The zero2, zero2_cpu plugin also support a wide range of chat models not listed above.
## Inference example
We support different inference options, including int8 and int4 quantization.
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
## Attention
The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.

View File

@ -11,4 +11,4 @@ python prepare_dataset.py --type prompt \
--data_cache_dir $SAVE_DIR/cache \
--data_jsonl_output_dir $SAVE_DIR/jsonl \
--data_arrow_output_dir $SAVE_DIR/arrow \
--max_length 1024
--max_length 300

View File

@ -1,181 +0,0 @@
import argparse
import os
import socket
from functools import partial
import pandas as pd
import ray
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_reward_model_from_args,
get_strategy_from_args,
get_tokenizer_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
"local_rank": "0",
"rank": "0",
"world_size": "1",
"master_port": maker_port,
"master_addr": master_addr,
}
# configure tokenizer
tokenizer = get_tokenizer_from_args(args.model)
def trainer_model_fn():
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=["maker1"],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
experience_batch_size=args.experience_batch_size,
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
# uncomment this function if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
dataset_size = args.experience_batch_size * 4
def build_dataloader():
def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
ray.get(wait_tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)

View File

@ -1,201 +0,0 @@
import argparse
import os
import socket
from functools import partial
import pandas as pd
import ray
from coati.quant import llama_load_quant, low_resource_init
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
from coati.ray.experience_maker_holder import ExperienceMakerHolder
from coati.ray.utils import (
get_actor_from_args,
get_critic_from_args,
get_receivers_per_sender,
get_reward_model_from_args,
get_strategy_from_args,
)
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_makers),
"master_port": maker_port,
"master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
def model_fn():
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
kl_coef=0.1,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
# sync_models_from_trainers=True,
# generation kwargs:
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
eval_performance=True,
use_cache=True,
)
for i, env_info_maker in enumerate(env_info_makers)
]
def trainer_model_fn():
actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f"maker{x}"
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
env_info=env_info_trainer,
train_batch_size=args.train_batch_size,
buffer_limit=16,
eval_performance=True,
debug=args.debug,
update_lora_weights=not (args.lora_rank == 0),
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
def build_dataloader():
def tokenize_fn(texts):
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.cuda() for k, v in batch.items()}
dataset = pd.read_csv(args.prompt_path)["prompt"]
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
return dataloader
# uncomment this function if sync_models_from_trainers is True
# ray.get([
# trainer_ref.sync_models_to_remote_makers.remote()
# for trainer_ref in trainer_refs
# ])
wait_tasks = []
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
total_steps = (
args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_path", type=str, default=None)
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)

View File

@ -1,12 +0,0 @@
#!/bin/bash
set -xe
BASE=$(realpath $(dirname $0))
export RAY_NAMESPACE=admin
export DATA=/data/scratch/chatgpt/prompts.csv
# install requirements
pip install -r ${BASE}/requirements.txt
python ${BASE}/mmmt_prompt.py --prompt_path $DATA --num_makers 2 --num_trainers 2 --trainer_strategy colossalai_gemini --model opt --critic_model opt --pretrain facebook/opt-350m --critic_pretrain facebook/opt-125m --experience_batch_size 4 --train_batch_size 2

View File

@ -1,4 +1,4 @@
pandas>=1.4.1
sentencepiece
colossalai==0.4.0
colossalai==0.4.7
prompt_toolkit

View File

@ -0,0 +1,455 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task.
"""
import argparse
import json
import os
import resource
from contextlib import nullcontext
from types import MethodType
import torch
import torch.distributed as dist
from coati.dataset.loader import RawConversationDataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
HybridParallelPlugin,
LowLevelZeroPlugin,
MoeHybridParallelPlugin,
Plugin,
TorchDDPPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
loss = loss.data
group = getattr(plugin, "dp_group", None)
dist.all_reduce(loss, group=group)
return loss / dist.get_world_size(group)
def train(args) -> None:
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
accelerator = get_accelerator()
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
elif args.plugin == "moe":
plugin = MoeHybridParallelPlugin(
ep_size=args.ep,
tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero_stage,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
def is_master():
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
return coordinator.rank == coordinator.world_size - 1
return coordinator.is_master()
# ==============================
# Initialize Tensorboard and Save Config
# ==============================
if is_master():
if args.tensorboard_dir is not None:
from torch.utils.tensorboard import SummaryWriter
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
coordinator.print_on_master(
f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
)
coordinator.print_on_master(f"Load dataset: {args.dataset}")
dataset = RawConversationDataset(
tokenizer,
args.dataset,
args.max_length,
)
dataloader = plugin.prepare_dataloader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
)
coordinator.print_on_master(
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)
with init_ctx:
# from_pretrained is not compatible with LoRA, we load pretrained weights later.
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrained,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# trust_remote_code=True,
# attn_implementation=attn_impl,
# )
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=True,
attn_implementation=attn_impl,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
)
if args.lora_rank > 0:
if model.__class__.__name__.startswith("DeepseekV3"):
lora_config = LoraConfig(
task_type="CAUSAL_LM",
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=["gate_proj", "up_proj", "down_proj"],
)
else:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
model = booster.enable_lora(model, lora_config=lora_config)
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
if model.config.__class__.__name__.startswith("DeepseekV3"):
model.config.use_cache = False
model.eval()
# enable grad for moe layers
for m in model.modules():
if m.__class__.__name__ == "DeepseekV3MoE":
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
model_numel = sum(p.numel() for p in model.parameters())
coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# Flash attention will be disabled because it does NOT support fp32.
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
torch.set_default_dtype(torch.float)
booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)
coordinator.print_on_master(
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
start_epoch = 0
start_step = 0
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch)
if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
data_iter = iter(dataloader)
step_bar = tqdm(
range(len(dataloader)),
desc="Step",
disable=not is_master(),
)
for step in step_bar:
outputs = booster.execute_pipeline(
data_iter,
model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=optimizer,
return_loss=True,
)
loss = outputs["loss"]
if booster.plugin.stage_manager.is_last_stage():
global_loss = all_reduce_mean(loss, plugin)
optimizer.step()
if booster.plugin.stage_manager.is_last_stage():
grad_norm = optimizer.get_grad_norm()
step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})
if args.tensorboard_dir is not None and is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
lr_scheduler.step()
optimizer.zero_grad()
else:
pbar = tqdm(
dataloader,
desc=f"Epoch {epoch}",
disable=not is_master(),
initial=start_step // args.accumulation_steps,
)
total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
batch_output = model(**batch)
loss = batch_output.loss / args.accumulation_steps
total_loss.add_(loss.data)
booster.backward(loss=loss, optimizer=optimizer)
if (step + 1) % args.accumulation_steps == 0:
all_reduce_mean(total_loss, plugin)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm})
if args.tensorboard_dir is not None and is_master():
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
writer.add_scalar(
tag="Learning Rate",
scalar_value=lr_scheduler.get_last_lr()[0],
global_step=global_step,
)
writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)
lr_scheduler.step()
optimizer.zero_grad()
total_loss.fill_(0.0)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Final save.
coordinator.print_on_master("Start saving final model checkpoint")
if args.lora_rank > 0:
booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
else:
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Basic training information.
parser.add_argument(
"-m",
"--pretrained",
type=str,
required=True,
help="Address of the pre-trained model",
)
parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw Jonl dataset for training.")
parser.add_argument(
"-p",
"--plugin",
type=str,
default="zero2",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
help="Choose which plugin to use",
)
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
# Training parameters
parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
parser.add_argument(
"--mixed_precision",
type=str,
default="bf16",
choices=["fp16", "bf16"],
help="Mixed precision",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument(
"-g",
"--use_grad_checkpoint",
action="store_true",
default=False,
help="Use gradient checkpointing",
)
parser.add_argument(
"-f",
"--use_flash_attn",
action="store_true",
default=False,
help="Use flash-attention",
)
# Additional arguments for 3d plugin.
parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
parser.add_argument(
"--sp_mode",
type=str,
default="split_gather",
choices=["split_gather", "ring", "all_to_all"],
help="SP mode, used for 3d plugin.",
)
parser.add_argument(
"--enable_sequence_parallelism",
default=False,
action="store_true",
help="Whether to enable SP, used for 3d plugin.",
)
parser.add_argument(
"--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
)
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")
args = parser.parse_args()
if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.")
train(args)

View File

@ -0,0 +1,10 @@
[{"role": "user", "content": "只剩一个心脏了还能活吗?"}, {"role": "assistant", "content": "能,人本来就只有一个心脏。"}]
[{"role": "user", "content": "爸爸再婚,我是不是就有了个新娘?"}, {"role": "assistant", "content": "不是的,你有了一个继母。\"新娘\"是指新婚的女方,而你爸爸再婚,他的新婚妻子对你来说是继母。"}]
[{"role": "user", "content": "马上要上游泳课了,昨天洗的泳裤还没干,怎么办"}, {"role": "assistant", "content": "游泳时泳裤本来就会湿,不用晾干。"}]
[{"role": "user", "content": "为什么没人说ABCD型的成语🤔"}, {"role": "assistant", "content": "这是因为中文成语一般都是四字成语每个字都有其特定的含义四个字合在一起构成一个完整的意思。而ABCD型的成语最常见所以大家不会刻意强调。"}]
[{"role": "user", "content": "太阳还有五十亿年就没了,那到时候向日葵看哪呢?"}, {"role": "assistant", "content": "实际上,向日葵并不会意识到太阳的存在与否,它们只是通过光敏反应来追踪太阳。如果太阳消失,向日葵也不会存活下来,所以它们不会看向任何地方。"}]
[{"role": "user", "content": "我把我车钥匙落在车里,现在锁不了车门了。 有人能告诉我怎么办吗?"}, {"role": "assistant", "content": "如果你的车钥匙已经锁在车里,那么你实际上已经无法再从外部打开车门了,所以你的车已经被锁住了。你需要做的是找专业的开锁服务或者联系汽车救援服务来帮你打开车门,取回你的车钥匙。"}]
[{"role": "user", "content": "同学说steam好玩我现在八级烫伤了怎么办 救救我"}, {"role": "assistant", "content": "\"Steam\"是一个电子游戏分发平台不是steam蒸汽。建议立即去医院顺便看看脑子。"}]
[{"role": "user", "content": "人在岩浆里憋气最长时间是多少? 刚才好奇搜了一下水里憋气时间最长的是克罗地亚54岁在水中闭气24分33秒但是我搜在岩浆里憋气时间最长居然发现一个没有不都是液体吗?为什么没人在岩浆里憋气?"}, {"role": "assistant", "content": "因为岩浆的温度极高达到了800-1200摄氏度人如果接触到岩浆会立刻被烧伤甚至瞬间死亡所以没有人能在岩浆里憋气。这并不是一个合理的比较因为水和岩浆虽然都是液体但是它们的性质差别极大。"}]
[{"role": "user", "content": "出生证丢了怎么证明自己出生了"}, {"role": "assistant", "content": "这是个很有意思的问题,从现实意义来讲,你既然已经活着,就已经证明你出生了。如果需要法律意义上的证明,可以通过其他官方文件,例如护照、身份证、驾驶证等。此外,医院的出生记录、学校的入学记录、家庭照片等也可以作为间接证据。如果所有这些都没有,可以向出生地的公共卫生部门或者医院申请复制出生证明。"}]
[{"role": "user", "content": "为什么抄袭永远都是今人抄袭古人,而没有古人抄袭今人的案例?"}, {"role": "assistant", "content": "这是因为时间的流逝是单向的,古人无法预知今人的思想和作品,所以古人无法抄袭今人。而今人可以通过学习古人的作品,因此有可能出现抄袭古人的情况。"}]

View File

@ -0,0 +1,494 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext
import torch
import torch.distributed as dist
from coati.dataset import (
DataCollatorForPromptDataset,
DataCollatorForSupervisedDataset,
StatefulDistributedSampler,
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import LoraConfig, RewardModel, RLVRRewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import GRPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
# check lora compatibility
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch()
coordinator = DistCoordinator()
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# Temp Fix: Disable lazy init due to version conflict
# init_ctx = (
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
init_ctx = nullcontext()
with init_ctx:
if args.use_flash_attn:
actor = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
if args.rm_pretrain:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
if args.rm_pretrain:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if args.lora_config is not None:
actor = convert_to_lora_module(actor, lora_config=lora_config)
for name, module in actor.named_modules():
if "norm" in name or "gate" in name:
module = module.to(torch.float32)
lora_manager.able_to_merge = False
# Disable dropout
disable_dropout(actor)
if args.grad_checkpoint:
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
if os.path.exists(args.conversation_template_config):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
else:
raise ValueError("Conversation template config is not provided or incorrect")
if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
try:
# Some tokenizers doesn't allow to set pad_token mannually e.g., Qwen
tokenizer.pad_token = tokenizer.eos_token
except AttributeError as e:
logger.warning(f"Unable to set pad token to eos token, {str(e)}")
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
logger.warning(
"The tokenizer does not have a pad token which is required. May lead to unintended behavior in training, Please consider manually set them."
)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
tokenizer.padding_side = "left" # left padding for generation (online learning)
# configure generation config
actor.generation_config.update(
pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
)
# configure optimizer
coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
actor_optim = HybridAdam(
model_params=actor.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
if args.warmup_steps is None:
args.warmup_steps = int(0.025 * args.num_episodes)
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
actor_lr_scheduler = CosineAnnealingWarmupLR(
optimizer=actor_optim,
total_steps=args.num_episodes,
warmup_steps=args.warmup_steps,
eta_min=0.1 * args.lr,
)
# ==============================
# Initialize Booster
# ==============================
if args.plugin == "ddp":
"""
Default torch ddp plugin without any acceleration, for
debugging purpose acceleration, for debugging purpose
"""
plugin = TorchDDPPlugin(find_unused_parameters=True)
elif args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
args.use_flash_attn = False
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
if args.rm_pretrain:
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
if args.plugin != "3d" and args.rm_pretrain:
custom_plugin = plugin
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = plugin.prepare_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = plugin.prepare_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
train_pretrain_dataloader = None
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
if args.rm_pretrain:
rm_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
torch.set_default_dtype(default_dtype)
actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
dataloader=train_prompt_dataloader,
)
if args.rm_pretrain:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define custom reward function, you can define your functions under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and use it here by mofiying the following line:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
sampler_start_idx = 0
start_step = 0
if args.rm_checkpoint_path is not None:
if "modeling" in args.rm_checkpoint_path:
rm_booster.load_model(reward_model, args.rm_checkpoint_path)
else:
_, _, _ = load_checkpoint(
load_dir=args.rm_checkpoint_path,
booster=rm_booster,
model=reward_model,
optimizer=None,
lr_scheduler=None,
)
coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
if args.checkpoint_path is not None:
if "modeling" in args.checkpoint_path:
actor_booster.load_model(actor, args.checkpoint_path)
ref_booster.load_model(ref_model, args.checkpoint_path)
coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
else:
_, start_step, sampler_start_idx = load_checkpoint(
load_dir=args.checkpoint_path,
booster=actor_booster,
model=actor,
optimizer=actor_optim,
lr_scheduler=actor_lr_scheduler,
)
_, _, _ = load_checkpoint(load_dir=args.checkpoint_path, booster=ref_booster, model=ref_model)
assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)
coordinator.print_on_master(
f"Loaded actor and reference model checkpoint {args.checkpoint_path} at spisode {start_step}"
)
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
coordinator.print_on_master(
f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
)
coordinator.print_on_master(
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
)
# configure trainer
trainer = GRPOTrainer(
actor_booster,
actor,
reward_model,
ref_model,
actor_optim,
actor_lr_scheduler,
tokenizer=tokenizer,
stop_token_ids=[stop_ids],
kl_coef=args.kl_coef,
ptx_coef=args.ptx_coef,
train_batch_size=args.train_batch_size,
buffer_limit=args.num_collect_steps * args.experience_batch_size * args.num_generations,
max_length=args.max_length,
use_cache=True,
do_sample=True,
apply_loss_mask=not args.disable_loss_mask,
accumulation_steps=args.accumulation_steps,
save_dir=args.save_path,
save_interval=args.save_interval,
top_k=50,
use_tp=args.tp > 1,
num_generations=args.num_generations,
inference_batch_size=args.inference_batch_size,
logits_forward_batch_size=args.logits_forward_batch_size,
offload_inference_models="gemini" not in args.plugin,
coordinator=coordinator,
max_tokens_thinking=args.max_tokens_thinking if args.max_tokens_thinking else args.max_length - 100,
temperature_annealing_config={
"start_temperature": args.initial_temperature,
"end_temperature": args.final_temperature,
"annealing_warmup_steps": min(100, int(args.num_episodes / 6)),
"annealing_steps": min(600, int(args.num_episodes / 2)),
},
# Hack: some old model's default update_model_kwargs_fn/prepare_inputs_fn may doesn't work due to version conflict with transformers, you can overwrite them
# update_model_kwargs_fn=update_model_kwargs_fn,
# prepare_inputs_fn = None
)
trainer.fit(
num_episodes=args.num_episodes,
num_collect_steps=args.num_collect_steps,
num_update_steps=args.num_update_steps,
prompt_dataloader=train_prompt_dataloader,
pretrain_dataloader=train_pretrain_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
if lora_config is not None and lora_config.r > 0:
# NOTE: set model to eval to merge LoRA weights
lora_manager.able_to_merge = True
actor.eval()
# save model checkpoint after fitting on only rank0
coordinator.print_on_master("Start saving final actor model checkpoint")
actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
coordinator.print_on_master(
f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
)
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_dataset", nargs="+", default=[])
parser.add_argument("--ptx_dataset", nargs="+", default=[])
parser.add_argument(
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
help="Choose which plugin to use",
)
parser.add_argument(
"--conversation_template_config",
type=str,
default=None,
help="Path \
to save conversation template config files.",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--num_generations", type=int, default=8)
parser.add_argument("--inference_batch_size", type=int, default=None)
parser.add_argument("--save_interval", type=int, default=1000)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--logits_forward_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=16)
parser.add_argument("--ptx_batch_size", type=int, default=4)
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--accumulation_steps", type=int, default=8)
parser.add_argument("--lr", type=float, default=1e-6)
parser.add_argument("--kl_coef", type=float, default=0.7)
parser.add_argument("--ptx_coef", type=float, default=0.0)
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
parser.add_argument("--max_length", type=int, default=2048)
parser.add_argument("--max_tokens_thinking", type=int, default=2000)
parser.add_argument("--max_seq_len", type=int, default=256)
parser.add_argument("--initial_temperature", type=float, default=1.0)
parser.add_argument("--final_temperature", type=float, default=0.9)
parser.add_argument("--log_dir", default=None, type=str)
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
train(args)

View File

@ -0,0 +1,86 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8
PROJECT_NAME="PPO-RLVR"
PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
LOGDIR=""
declare -a prompt_dataset=(
YOUR/PROMPT/DATA/DIR/arrow/part-00000
YOUR/PROMPT/DATA/DIR/arrow/part-00001
YOUR/PROMPT/DATA/DIR/arrow/part-00002
YOUR/PROMPT/DATA/DIR/arrow/part-00003
YOUR/PROMPT/DATA/DIR/arrow/part-00004
YOUR/PROMPT/DATA/DIR/arrow/part-00005
YOUR/PROMPT/DATA/DIR/arrow/part-00006
YOUR/PROMPT/DATA/DIR/arrow/part-00007
YOUR/PROMPT/DATA/DIR/arrow/part-00008
YOUR/PROMPT/DATA/DIR/arrow/part-00009
)
declare -a ptx_dataset=(
YOUR/SFT/DATA/DIR/arrow/part-00000
YOUR/SFT/DATA/DIR/arrow/part-00001
YOUR/SFT/DATA/DIR/arrow/part-00002
YOUR/SFT/DATA/DIR/arrow/part-00003
YOUR/SFT/DATA/DIR/arrow/part-00004
YOUR/SFT/DATA/DIR/arrow/part-00005
YOUR/SFT/DATA/DIR/arrow/part-00006
YOUR/SFT/DATA/DIR/arrow/part-00007
YOUR/SFT/DATA/DIR/arrow/part-00008
YOUR/SFT/DATA/DIR/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"
colossalai run --nproc_per_node 8 --num_nodes 1 --hostfile ./hostfile train_grpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--prompt_dataset ${prompt_dataset[@]} \
--conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
--ptx_coef 0.0 \
--plugin "zero2_cpu" \
--reward_functions math_competition_reward_fn \
--save_interval 250 \
--save_path $SAVE_DIR \
--num_episodes 100 \
--num_collect_steps 8 \
--num_update_steps 1 \
--experience_batch_size 1 \
--train_batch_size 4 \
--inference_batch_size 8 \
--logits_forward_batch_size 2 \
--accumulation_steps 4 \
--lr 1e-6 \
--mixed_precision "bf16" \
--grad_clip 0.1\
--weight_decay 0.01 \
--kl_coef 0.01 \
--warmup_steps 40 \
--max_length 2000 \
--max_seq_len 1700 \
--log_dir $LOGDIR \
--use_flash_attn \
--grad_checkpoint

View File

@ -13,9 +13,18 @@ from coati.dataset import (
load_tokenized_dataset,
setup_conversation_template,
)
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.models import (
Critic,
LoraConfig,
RewardModel,
RLVRRewardModel,
convert_to_lora_module,
disable_dropout,
lora_manager,
)
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer
import colossalai
@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
def train(args):
global response_format_tags
lora_config = None
if args.lora_config is not None:
lora_config = LoraConfig.from_file(args.lora_config)
@ -61,28 +79,36 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
local_files_only=True,
trust_remote_code=True,
)
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
)
if not args.no_neural_reward_model:
reward_model = RewardModel(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
critic = Critic(
args.rm_pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
use_flash_attention_2=True,
trust_remote_code=True,
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
reward_model = RewardModel(args.rm_pretrain)
actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain, local_files_only=True, trust_remote_code=True
)
if not args.no_neural_reward_model:
reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
critic = Critic(args.rm_pretrain)
if args.lora_config is not None:
@ -112,6 +138,9 @@ def train(args):
with open(args.conversation_template_config, "r", encoding="utf8") as f:
conversation_template_config = json.load(f)
dist.barrier()
if "response_format_tags" in conversation_template_config:
logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
conversation_template = setup_conversation_template(
tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
)
@ -245,7 +274,7 @@ def train(args):
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(reward_model.model),
custom_policy=get_autopolicy(critic.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@ -284,7 +313,8 @@ def train(args):
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
if not args.no_neural_reward_model:
rm_booster = Booster(plugin=custom_plugin)
critic_booster = Booster(plugin=custom_plugin)
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@ -302,7 +332,28 @@ def train(args):
lr_scheduler=critic_lr_scheduler,
dataloader=train_prompt_dataloader,
)
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
if not args.no_neural_reward_model:
reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
else:
if args.reward_functions:
reward_fn_list = []
for reward_fn in args.reward_functions:
"""
To define custom reward function, you can define your functions under:
colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
and use it here by mofiying the following line:
"""
if reward_fn == "gsm8k_reward_fn":
reward_fn_list.append(gsm8k_reward_fn)
elif reward_fn == "math_competition_reward_fn":
reward_fn_list.append(math_competition_reward_fn)
else:
raise ValueError(f"Unknown reward function {reward_fn}")
reward_fn_list.append(eval(reward_fn))
reward_model = RLVRRewardModel(
reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
)
ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)
torch.set_default_dtype(torch.float)
@ -481,9 +532,11 @@ if __name__ == "__main__":
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
parser.add_argument("--checkpoint_path", type=str, default=None)
parser.add_argument("--critic_checkpoint_path", type=str, default=None)
parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--num_episodes", type=int, default=1)
parser.add_argument("--num_collect_steps", type=int, default=2)

View File

@ -2,7 +2,7 @@ transformers==4.39.3
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.0
colossalai>=0.4.7
torch>=2.1.0
langchain
tokenizers
@ -21,3 +21,4 @@ ninja==1.11.1
sentencepiece==0.1.99
flash-attn
tiktoken
jsonlines

View File

@ -20,6 +20,15 @@ prompt_seed = {
},
]
}
prompt_rlvr_seed = {
"messages": [
{
"from": "user",
"content": "What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?",
},
],
"gt_answer": "4",
}
preference_seed = {
"context": [
{"from": "user", "content": "What kind of noises did dinosaurs make?"},
@ -72,6 +81,8 @@ if __name__ == "__main__":
seed = sft_seed
elif args.data_type == "prompt":
seed = prompt_seed
elif args.data_type == "prompt_rlvr":
seed = prompt_rlvr_seed
elif args.data_type == "preference":
seed = preference_seed
elif args.data_type == "kto":

View File

@ -0,0 +1,16 @@
# run under /ColossalAI/applications/ColossalChat
export NCCL_SHM_DISABLE=1
export MAX_JOBS=1
export PRETRAINED_MODEL_PATH=./models
export SFT_DATASET=./sft_data
export PROMPT_DATASET=./prompt_data
export PROMPT_RLVR_DATASET=./prompt_data
export PREFERENCE_DATASET=./preference_data
export KTO_DATASET=./kto_data
mkdir models
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
# ./tests/test_data_preparation.sh
# ./tests/test_train.sh

View File

@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then
fi
if [ -z "$PROMPT_DATASET" ]; then
echo "Please set \$PROMPT_DATASET to the path to prompts."
echo "Please set \$PROMPT_DATASET to the path to prompts dataset."
exit 1
fi
if [ -z "$PROMPT_RLVR_DATASET" ]; then
echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels."
exit 1
fi
@ -69,6 +74,8 @@ get_data_input_dirs() {
echo "$SFT_DATASET"
elif [[ $data_type == "prompt" ]]; then
echo "$PROMPT_DATASET"
elif [[ $data_type == "prompt_rlvr" ]]; then
echo "$PROMPT_RLVR_DATASET"
elif [[ $data_type == "preference" ]]; then
echo "$PREFERENCE_DATASET"
elif [[ $data_type == "kto" ]]; then
@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt) \
--data_type "prompt"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs prompt_rlvr) \
--data_type "prompt_rlvr"
python $TEST_DIR/generate_dummy_datasets_for_testing.py \
--data_dir $(get_data_input_dirs kto) \
--data_type "kto"
@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do
done
echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..."
# FIXME: This is a hack to skip tests that are not working
SKIPPED_TESTS=(
)
# test prepare_prompt_dataset
for model in ${MODELS[@]}; do
data_type="prompt_rlvr"
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
echo "[Test]: Skipped $model-$data_type"
continue
fi
cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
data_input_dirs=$(get_data_input_dirs $data_type)
tokenizer_dir=$(get_tokenizer_dirs $model)
conversation_template=$(get_conversation_template_config $model)
for i in $(seq $NUM_RETRY); do
rm -rf $cache_dir
rm -rf $jsonl_dir
rm -rf $arrow_dir
echo "[Test]: $model-$data_type, attempt $i"
python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
--type prompt \
--data_input_dirs $data_input_dirs \
--conversation_template_config $conversation_template \
--tokenizer_dir $tokenizer_dir \
--data_cache_dir $cache_dir \
--data_jsonl_output_dir $jsonl_dir \
--data_arrow_output_dir $arrow_dir \
--max_length 400 \
--num_samples_per_datafile 100 \
--num_spliced_dataset_bins 1
passed=$?
if [ $passed -eq 0 ]; then
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$data_type"
exit 1
fi
done
echo "[Test]: testing prepare_kto_dataset.py ..."
# FIXME: This is a hack to skip tests that are not working

View File

@ -81,8 +81,242 @@ random_choice() {
echo ${arr[$idx]}
}
echo "[Test]: testing grpo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('nn' 'vr')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
rm_pretrain="--rm_pretrain $pretrain"
reward_fn=""
if [[ $reward_type == "vr" ]]; then
rm_pretrain=""
reward_fn="--reward_functions gsm8k_reward_fn"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='1'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='1'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
--pretrain $pretrain \
$rm_pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--num_generations 2 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 200 \ \
--max_seq_len 10 \
$reward_fn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
REWARD_FLAG=('vr' 'nn')
for reward_type in ${REWARD_FLAG[@]}; do
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
reward_fn=""
no_nn=""
if [[ $reward_type == "vr" ]]; then
reward_fn="--reward_functions gsm8k_reward_fn"
no_nn="--no_neural_reward_model"
fi
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
ebs='2'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='2'
ebs='2'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
if [[ $reward_type == "vr" ]]; then
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
else
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
fi
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
$reward_fn \
$no_nn
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
exit 1
fi
done
done
done
done
echo "[Test]: testing sft ..."
@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing ppo ..."
SKIPPED_TESTS=(
llama-3d # 3d plugin doesn't support lora
llama-gemini # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue # gemini_auto plugin doesn't support generation
fi
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='4'
ebs='8'
conversation_template=$(get_conversation_template_config $model)
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
else
lora_config=""
fi
if [[ $plugin == "3d" ]]; then
tp='2'
bs='16'
ebs='32'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto and gemini doesn't support generation
if [[ $plugin == "gemini_auto" ]]; then
# gemini-auto doesn't support generation
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a prompt_dataset=()
for split in $(seq -f "%05g" 0 0); do
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
done
declare -a ptx_dataset=()
for split in $(seq -f "%05g" 0 0); do
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
--pretrain $pretrain \
--rm_pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--conversation_template_config $conversation_template \
--prompt_dataset ${prompt_dataset[@]} \
--ptx_dataset ${ptx_dataset[@]} \
--ptx_batch_size 1 \
--ptx_coef 0.2 \
--save_path $MODEL_SAVE_PATH \
$lora_config \
--plugin $plugin \
--num_episodes 5 \
--num_collect_steps 1 \
--num_update_steps 1 \
--experience_batch_size $ebs \
--train_batch_size $bs \
--accumulation_steps $grad_accu \
--lr 9e-6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--max_seq_len 10 \
# --use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/*
rm -rf ${MODELS_DIR:?}/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing DPO ..."
SKIPPED_TESTS=(
@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
echo "[Test]: testing ORPO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE
@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
done
done
echo "[Test]: testing KTO ..."
SKIPPED_TESTS=(
llama-3d-0
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
bs='2'
if [[ $plugin == "3d" ]]; then
tp='2'
bs='8'
bs='2'
fi
if [[ $plugin == "zero2" ]]; then
lora_config=$LORA_CONFIG_ENABLE

View File

@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
dtype: torch.dtype
@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
"""Called before backward.
Args:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from torch import Tensor, inf
@ -84,14 +84,20 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group["params"] = master_params
self._current_grad_norm: Optional[float] = None
def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)
def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
@ -187,6 +193,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
@ -212,3 +219,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm

View File

@ -288,7 +288,14 @@ class Booster:
return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
def load_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load model from checkpoint.
Args:
@ -298,8 +305,12 @@ class Booster:
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
self.checkpoint_io.load_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_model(
self,
@ -310,6 +321,7 @@ class Booster:
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""Save model to checkpoint.
@ -324,6 +336,7 @@ class Booster:
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
"""
self.checkpoint_io.save_model(
model,
@ -333,20 +346,28 @@ class Booster:
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
use_async=use_async,
)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load optimizer from checkpoint.
Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
self.checkpoint_io.load_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_optimizer(
self,
@ -356,6 +377,7 @@ class Booster:
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
) -> None:
"""
Save optimizer to checkpoint.
@ -371,7 +393,9 @@ class Booster:
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
self.checkpoint_io.save_optimizer(
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.

View File

@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
growth_interval=growth_interval,
)
def backward(self, loss: Tensor, *args, **kwargs) -> None:
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
scaled_loss = self.scale_loss(loss)
scaled_loss.backward(*args, **kwargs)
scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def step(self, *args, **kwargs) -> Optional[float]:
out = self.scaler.step(self.optim, *args, **kwargs)

View File

@ -1,4 +1,3 @@
import gc
import os
import random
from pathlib import Path
@ -17,9 +16,11 @@ from torch.utils.data.distributed import DistributedSampler
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
load_state_dict_shards,
save_config_file,
save_state_dict,
save_state_dict_shards,
@ -65,7 +66,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
@ -74,17 +82,39 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
from colossalai.utils.safetensors import save
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for k, v in state_dict.items():
self.pinned_state_dicts[hash(model)][k].copy_(v)
state_dict[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
super().load_unsharded_model(
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@ -94,15 +124,31 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, flatten_state_dict, metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_sharded_model(
self,
@ -112,6 +158,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save sharded model.
@ -124,20 +171,38 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
if use_async and self.coordinator.is_master():
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = model.state_dict_shard(
max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=is_master,
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
@ -152,16 +217,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
)
def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
return super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module=False,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: GeminiOptimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer state dict to checkpoint folder.
@ -176,7 +261,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
@ -187,17 +272,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = optimizer.state_shard(
prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
)
# Save shards of optimizer states.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
# Wrap up index file. Only save it on master rank.
if self.coordinator.is_master():
@ -210,7 +314,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(
self,
optimizer: GeminiOptimizer,
checkpoint_index_file: Path,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
@ -238,11 +349,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict_shard in load_state_dict_shards(
checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
):
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()
optimizer.optimizer_loading_epilogue()

View File

@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
import numpy as np
import torch
import torch.distributed as dist
from peft import PeftModel
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
@ -25,10 +26,11 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
@ -219,11 +221,13 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
with self._hook_context():
return super().forward(*args, **kwargs)
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def unwrap(self, unwrap_peft: bool = True):
model = self.module
if isinstance(model, DDP):
model = model.module
if unwrap_peft and isinstance(model, PeftModel):
model = PeftUnwrapMixin(model)
return model
def _force_wait_all_gather(self):
for p in self.module.parameters():
@ -293,9 +297,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
self._current_grad_norm: Optional[float] = None
super().__init__(optim)
def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -314,7 +319,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
# Call the superclass backward method to compute gradients.
with self.model._hook_context():
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -323,7 +328,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
# If gradient synchronization is is not required, return.
return
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -340,7 +345,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -364,6 +369,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm
# Clip the gradients to prevent exploding gradients.
self._clip_grad_norm(total_norm)
@ -477,6 +483,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
def get_master_to_working_map(self):
return None
def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
@ -520,7 +529,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
max_norm=max_norm,
)
def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
r"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -538,7 +547,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
"""
# Call the superclass backward method to compute gradients.
with self.model._hook_context():
super().backward(loss, *args, **kwargs)
super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -547,7 +556,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# If gradient synchronization is is not required, return.
return
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -563,7 +572,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -780,7 +789,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
else:
return
def backward(self, loss, retain_graph=False):
def backward(self, loss, inputs=None, retain_graph=False):
"""
Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
@ -796,7 +805,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, retain_graph)
super().backward(loss, inputs=inputs, retain_graph=retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -805,7 +814,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# If gradient synchronization is is not required, return.
return
def backward_by_grad(self, tensor, grad):
def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
"""
Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
@ -821,7 +830,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
None
"""
# Call the superclass backward_by_grad method to compute gradients.
super().backward_by_grad(tensor, grad)
super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)
if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
# If gradient synchronization is required, sync sequence parallelism gradients.
@ -1025,6 +1034,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@ -1043,6 +1053,9 @@ class HybridParallelPlugin(PipelinePluginBase):
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
assert (
not pp_style == "zbv" or scheduler_nodes is not None
), f"scheduler_nodes must not be None when using zero bubble pipeline."
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@ -1104,29 +1117,39 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
if pp_style == "zbv":
self.logger.warning(
"""the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
)
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
use_zbv=(pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@ -1136,13 +1159,21 @@ class HybridParallelPlugin(PipelinePluginBase):
fp8_communication=fp8_communication,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
fp8_communication=fp8_communication,
)
elif pp_style == "zbv":
self.scheduler = ZeroBubbleVPipeScheduler(
stage_manager=self.stage_manager,
schedule=scheduler_nodes,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
)
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
@ -1161,6 +1192,15 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
# sync gradients across DP * SP ranks
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(self.mixed_dp_group)
else:
self.mixed_dp_group = self.dp_group
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
@ -1258,7 +1298,6 @@ class HybridParallelPlugin(PipelinePluginBase):
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)
if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@ -1272,18 +1311,11 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@ -1332,7 +1364,7 @@ class HybridParallelPlugin(PipelinePluginBase):
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=dp_group,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
@ -1375,7 +1407,7 @@ class HybridParallelPlugin(PipelinePluginBase):
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx, model._hook_context():
outputs = self.schedule.forward_backward_step(
outputs = self.scheduler.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
@ -1461,7 +1493,9 @@ class HybridParallelPlugin(PipelinePluginBase):
)
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return HybridParallelCheckpointIO(
self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
)
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
@ -1479,7 +1513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
from peft import PeftModel, get_peft_model
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1
assert self.tp_size == 1
self.lora_enabled = True
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

View File

@ -20,15 +20,18 @@ from torch.utils.data import DataLoader
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
create_pinned_state_dict,
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_shards,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
@ -112,7 +115,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
):
"""Save optimizer to checkpoint but only on master process.
Args:
@ -124,9 +129,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save_nested
f_writer = save_nested(checkpoint, state_dict)
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)
def save_sharded_optimizer(
self,
@ -135,6 +167,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@ -160,10 +193,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
if use_async and self.coordinator.is_master():
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(
max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
@ -183,7 +224,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save_nested
f_writer = save_nested(checkpoint_file_path, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@ -196,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
ranks=[0],
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""Load sharded optimizer with the given path to index file.
Args:
@ -221,8 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
@ -235,14 +289,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
if low_cpu_mem_mode:
state_dict[param_idx][k] = state_dict[param_idx][k].clone()
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
super().load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
model.update_master_params()
def load_sharded_model(
@ -252,16 +320,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
model.update_master_params()
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
def save_sharded_model(
self,
@ -271,30 +351,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
)
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)
class LowLevelZeroPlugin(DPPluginBase):
@ -333,6 +401,7 @@ class LowLevelZeroPlugin(DPPluginBase):
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
"""
def __init__(
@ -358,11 +427,16 @@ class LowLevelZeroPlugin(DPPluginBase):
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
if extra_dp_size > 1:
assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
inner_dp_size = dist.get_world_size() // extra_dp_size
self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
self.stage = stage
self.precision = precision
self.zero_optim_kwargs = dict(
@ -383,6 +457,9 @@ class LowLevelZeroPlugin(DPPluginBase):
overlap_allgather=overlap_allgather,
fp8_communication=fp8_communication,
)
if extra_dp_size > 1:
self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
self.lora_enabled = False
self.verbose = verbose
self.logger = get_dist_logger()

View File

@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelPlugin,
HybridParallelZeroOptimizer,
get_param_info,
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@ -29,6 +28,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@ -212,6 +212,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
scheduler_nodes: List = None,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
@ -285,12 +286,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)
self.stage_manager = None
self.schedule = None
self.scheduler = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
assert (
pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
), "num_model_chunks must be 1 when using 1f1b"
assert (
pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
), "num_model_chunks must be 2 when using zero bubble pipeline"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@ -300,14 +306,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
use_zbv=(pp_style == "zbv"),
)
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
self.scheduler = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
@ -316,12 +323,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_p2p=overlap_p2p,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.scheduler = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "zbv":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
self.scheduler = ZeroBubbleVPipeScheduler(
schedule=scheduler_nodes,
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
overlap_p2p=overlap_p2p,
)
else:
raise NotImplementedError()
@ -334,6 +350,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
self.dp_size = dist.get_world_size(self.mixed_dp_group)
else:
self.mixed_dp_group = self.dp_group
self.use_fp8 = use_fp8
self.shard_config = ShardConfig(
@ -387,7 +411,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
self.mixed_dp_group,
self.pp_group,
self.tp_group,
self.sp_group,
self.ep_group,
self.moe_dp_group,
self.zero_stage,
)
def configure(
@ -412,12 +442,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
and self.sequence_parallelism_mode == "all_to_all"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
else:
dp_group = self.dp_group
if use_ddp:
self.logger.warning(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@ -425,7 +449,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
)
@ -434,7 +458,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@ -443,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
# if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
# but the optimizer is not aware of ep, so we need to update the optimizer
reinitialize_optimizer(optimizer, model)
if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
@ -464,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = HybridParallelNaiveOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
@ -482,9 +501,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
dp_process_group=dp_group,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,

View File

@ -2,6 +2,7 @@ from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from peft import PeftModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@ -11,6 +12,7 @@ from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
@ -26,35 +28,54 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
super().load_unsharded_model(
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
@ -71,6 +92,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
@ -78,7 +100,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
)
def load_sharded_model(
@ -88,12 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model.unwrap(),
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
def save_sharded_optimizer(
self,
@ -102,28 +140,39 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(
optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
super().load_sharded_optimizer(
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
use_safetensors: bool = False,
state_dict: Optional[dict] = None,
) -> None:
"""
Save the lora adapters and adapter configuration file to checkpoint directory.
@ -131,15 +180,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
peft_model = model.unwrap(unwrap_peft=False)
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
if state_dict is None:
state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
if self.coordinator.is_master():
peft_model = model.unwrap()
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
state_dict=state_dict,
)
@ -148,8 +199,11 @@ class TorchDDPModel(ModelWrapper):
super().__init__(module)
self.module = DDP(module, *args, **kwargs)
def unwrap(self):
return self.module.module
def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
model = self.module.module
if unwrap_peft and isinstance(model, PeftModel):
model = PeftUnwrapMixin(model)
return model
class TorchDDPPlugin(DPPluginBase):

View File

@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
import torch.nn as nn
@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.utils.safetensors import load_flat
from .dp_plugin_base import DPPluginBase
@ -41,20 +43,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
checkpoint = utils.load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator=".")
else:
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num
for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)
new_state = {}
for key, value in checkpoint["state"].items():
new_state[id2name[int(key)]] = value
checkpoint["state"] = new_state
for g in checkpoint["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
@ -63,16 +99,67 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import save
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
self.pinned_state_dicts[hash(model)][k].copy_(v)
full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
utils.save_state_dict(
full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
)
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
if self.coordinator.is_master():
# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
full_optimizer_state["param_groups"] = param_groups
new_state = {}
for key, value in full_optimizer_state["state"].items():
new_state[name2id[key]] = value
full_optimizer_state["state"] = new_state
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
def save_sharded_model(
self,
@ -82,6 +169,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
@ -97,20 +185,38 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
):
state_dict = model.unwrap().state_dict()
state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
if use_async and self.coordinator.is_master():
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = utils.shard_model_checkpoint(
state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# In general cases, is_master is set to True to get the right behavior.
total_size = utils.save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=self.coordinator.is_master(),
use_safetensors=use_safetensors,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=self.coordinator.is_master(),
)
self.async_writers.extend(writers)
else:
total_size = utils.save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=self.coordinator.is_master(),
use_safetensors=use_safetensors,
)
# only save the index file on the master rank
if self.coordinator.is_master():
@ -130,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model to checkpoint but only on master process.
@ -147,14 +255,20 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
fsdp_state_dict = {}
for shard_file in checkpoint_files:
fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
fsdp_state_dict.update(state_dict)
with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to checkpoint but only on master process.
@ -177,26 +291,66 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
)
if self.coordinator.is_master():
# Save order indices instead of Tensors
name2id: Dict[str, int] = {}
start_index = 0
def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
packed = {k: v for k, v in group.items() if k != "params"}
name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
packed["params"] = [name2id[p] for p in group["params"]]
start_index += len(packed["params"])
return packed
param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
fsdp_optim_state["param_groups"] = param_groups
new_state = {}
for key, value in fsdp_optim_state["state"].items():
new_state[name2id[key]] = value
fsdp_optim_state["state"] = new_state
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
prefix, use_safetensors=use_async
)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
utils.save_param_groups(fsdp_optim_state, group_file_path)
sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = utils.shard_optimizer_checkpoint(
fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = utils.save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = utils.save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=self.coordinator.is_master(),
use_safetensors=False,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
@ -206,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
size_per_shard: int,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer to checkpoint but only on master process.
"""
@ -227,12 +388,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
# Load param
fsdp_optim_state = {}
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
fsdp_optim_state.update(state_dict_shard)
fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
start_index = 0
id2name = {}
def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
nonlocal start_index
start_num = len(id2name)
id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
end_num = len(id2name)
start_index += end_num - start_num
for g in full_optimizer_state["param_groups"]:
get_index_mapping(g)
new_state = {}
for key, value in fsdp_optim_dict["state"].items():
new_state[id2name[int(key)]] = value
fsdp_optim_dict["state"] = new_state
for g in fsdp_optim_dict["param_groups"]:
new_group = []
for param_id in g["params"]:
new_group.append(id2name[param_id])
g["params"] = new_group
with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
fsdp_state = FSDP.optim_state_dict_to_load(
model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
@ -252,9 +437,6 @@ class TorchFSDPModel(ModelWrapper):
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)
def unwrap(self):
return self.module
class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
@ -8,6 +8,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
@ -61,8 +62,35 @@ class CheckpointIO(ABC):
# ======================================
# Public methods
# ======================================
def __init__(self):
super().__init__()
self.pinned_state_dicts: Dict[int, dict] = {}
self.async_writers = []
def _sync_io(self):
for writer in self.async_writers:
writer.synchronize()
self.async_writers.clear()
def _sync_d2h(self):
for writer in self.async_writers:
writer.sync_before_step()
def synchronize(self):
"""This method must be called before updating the model weights."""
self._sync_d2h()
def __del__(self):
self._sync_d2h()
self._sync_io()
def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> Union[nn.Module, ModelWrapper]:
"""
Load model from checkpoint.
@ -77,6 +105,8 @@ class CheckpointIO(ABC):
Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
# since we only support loaded sharded and unsharded weight format
# containing no distributed tensors, dtensor -> full tensor conversion
@ -88,17 +118,25 @@ class CheckpointIO(ABC):
origin_model = model
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
self.load_sharded_model(
model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
path = Path(checkpoint, WEIGHTS_NAME)
if path.is_file():
self.load_unsharded_model(model, str(path), strict)
self.load_unsharded_model(
model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_model(model, checkpoint, strict)
self.load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
return origin_model
@ -111,6 +149,7 @@ class CheckpointIO(ABC):
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint.
@ -138,13 +177,30 @@ class CheckpointIO(ABC):
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
self._sync_io()
if use_async and not use_safetensors:
logger = get_dist_logger()
logger.warning(
"Async save is only supported when use_safetensors is set to True. "
"Setting use_safetensors to True for async save."
)
use_safetensors = True
if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
prefix: str = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from checkpoint.
@ -153,7 +209,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
prefix (str, optional): A prefix added to parameter and buffer
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
index_file_exists, index_file_path = has_index_file(checkpoint)
@ -164,9 +221,13 @@ class CheckpointIO(ABC):
if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path, prefix)
self.load_sharded_optimizer(
optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)
self.load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
def save_optimizer(
self,
@ -176,6 +237,7 @@ class CheckpointIO(ABC):
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@ -192,17 +254,20 @@ class CheckpointIO(ABC):
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""
if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
self.save_sharded_optimizer(
optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
else:
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
# ========================================================
# Abstract methods for model loading/saving implementation
# ========================================================
@abstractmethod
def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
def load_sharded_model(
self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from sharded checkpoint.
@ -211,10 +276,14 @@ class CheckpointIO(ABC):
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load model from unsharded checkpoint.
@ -223,6 +292,8 @@ class CheckpointIO(ABC):
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
strict (bool): whether to strictly enforce that the param name in
the checkpoint match the keys returned by this module's.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
@ -234,6 +305,7 @@ class CheckpointIO(ABC):
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save model to sharded checkpoint.
@ -248,7 +320,9 @@ class CheckpointIO(ABC):
"""
@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to unsharded checkpoint.
@ -264,7 +338,14 @@ class CheckpointIO(ABC):
# ========================================================
@abstractmethod
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
@ -272,21 +353,33 @@ class CheckpointIO(ABC):
optimizer (Optimizer): optimizer to be loaded.
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
prefix (str): prefix for the optimizer checkpoint.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from unsharded checkpoint.
Args:
optimizer (Optimizer): optimizer to be loaded.
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
"""
@abstractmethod
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint.
@ -300,7 +393,9 @@ class CheckpointIO(ABC):
"""
@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to unsharded checkpoint.
@ -342,7 +437,11 @@ class CheckpointIO(ABC):
@abstractmethod
def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
use_safetensors: bool = False,
state_dict: Optional[dict] = None,
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
@ -351,4 +450,5 @@ class CheckpointIO(ABC):
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
"""

View File

@ -1,4 +1,3 @@
import gc
import logging
import os
from functools import reduce
@ -8,16 +7,20 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils.safetensors import load_flat
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_move_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_state_dict_shards,
load_states_into_optimizer,
save_config_file,
save_param_groups,
@ -36,21 +39,43 @@ class GeneralCheckpointIO(CheckpointIO):
Checkpoint IO
"""
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(
self,
model: nn.Module,
checkpoint: str,
strict: bool,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
model.load_state_dict(checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
state_dict = model.state_dict()
# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor:
pass
if use_async:
from colossalai.utils.safetensors import move_and_save
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file.
"""
@ -69,8 +94,9 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
@ -82,6 +108,7 @@ class GeneralCheckpointIO(CheckpointIO):
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -102,7 +129,7 @@ class GeneralCheckpointIO(CheckpointIO):
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
@ -112,14 +139,28 @@ class GeneralCheckpointIO(CheckpointIO):
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
state_preprocess=True,
)
self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=sharded_state,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=True,
use_safetensors=False,
)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
@ -130,8 +171,15 @@ class GeneralCheckpointIO(CheckpointIO):
f"index located at {save_index_file}."
)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
def load_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)
def save_unsharded_optimizer(
@ -139,9 +187,25 @@ class GeneralCheckpointIO(CheckpointIO):
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
state_dict = optimizer.state_dict()
if use_async:
from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
writer = move_and_save(
path=checkpoint,
state_dict=flatten_state_dict,
state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
metadata=metadata,
)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_model(
self,
@ -151,6 +215,7 @@ class GeneralCheckpointIO(CheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
implement this method as it can be supported by Huggingface model,
@ -168,16 +233,29 @@ class GeneralCheckpointIO(CheckpointIO):
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
)
self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
@ -195,6 +273,8 @@ class GeneralCheckpointIO(CheckpointIO):
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
load shard model, load model from multiple files
@ -211,11 +291,10 @@ class GeneralCheckpointIO(CheckpointIO):
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []
for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
@ -229,5 +308,7 @@ class GeneralCheckpointIO(CheckpointIO):
)
)
def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
def save_lora_as_pretrained(
self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None
) -> None:
raise NotImplementedError

View File

@ -1,6 +1,7 @@
import copy
import logging
import os
from collections import defaultdict
from functools import reduce
from pathlib import Path
from shutil import rmtree
@ -10,6 +11,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@ -22,19 +24,23 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
gather_state_dict_fast,
get_lora_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
@ -67,6 +73,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
@ -74,9 +81,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@ -86,7 +95,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.
@ -97,9 +110,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
if is_padded_tensor(param_):
param_ = to_unpadded_tensor(param_)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param_)
param_ = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
@ -109,6 +127,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@ -120,6 +143,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
@ -134,6 +162,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.
@ -151,6 +180,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = param
param_id = param_info["param2id"][id(working_param)]
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
@ -160,6 +192,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
@ -177,6 +210,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
@ -194,6 +228,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str, optional): Perfix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
@ -212,21 +247,40 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
@ -251,16 +305,26 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert (
@ -294,7 +358,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded model with the given path to index file of checkpoint folder.
@ -342,6 +413,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@ -400,6 +473,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -431,26 +505,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
pinned_state_dicts=pinned_state_dicts,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
@ -481,18 +575,33 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
@ -535,7 +644,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint_index_file: str,
prefix: str = "",
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
@ -605,28 +721,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
if file_path.endswith(".safetensors"):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
self.load_states_into_optimizer(optimizer, state_dict, id_map)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
state_dict = {int(k): v for k, v in state_dict.items()}
new_states = defaultdict(dict)
master_to_working_map = optimizer.get_master_to_working_map()
for k, state in state_dict.items():
if k in id_map:
param = id_map[k]
device = param.device
dtype = param.dtype
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
new_states[param] = self.shard_from_complete_optimizer_state(
state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
dtype=dtype,
inplace=True,
)
optimizer.optim.state.update(new_states)
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model state dict to a single file with given checkpointing path.
@ -635,6 +769,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
"""
if self.coordinator.is_master():
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
@ -651,7 +786,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.pp_size == 1:
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
from colossalai.utils.safetensors import save
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[hash(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
state_dict_list = [None for _ in range(self.pp_size)]
@ -662,9 +808,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
complete_state_dict = dict()
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
if use_async:
from colossalai.utils.safetensors import save
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
if hash(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[hash(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from a single file with the given path of checkpoint.
@ -687,12 +851,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.
state_dict = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
model.load_state_dict(state_dict, strict=strict)
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer state dict to a file with given path.
@ -723,6 +891,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
original_shape = optimizer.param_info["param2shape"][id(working_param)]
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
@ -742,7 +911,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
]
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
@ -758,9 +939,23 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from a file with given path.
@ -784,7 +979,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
# Load param_groups.
updated_groups = []
@ -802,22 +1002,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
sharded_optimizer_loading_epilogue(optimizer.optim)
@ -838,6 +1023,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_zero: bool,
inplace: bool,
device: torch.device = torch.device("cpu"),
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@ -861,6 +1047,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if v is None:
continue
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
@ -881,7 +1069,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
state_[k] = v.detach().clone().to(device)
if pinned_state_dicts is not None:
if k not in pinned_state_dicts:
pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[k].copy_(v)
state_[k] = pinned_state_dicts[k]
else:
state_[k] = v.detach().clone().to(device)
return state_
@ -891,6 +1085,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
dtype: torch.dtype,
inplace: bool,
) -> OrderedDict:
"""
@ -940,11 +1135,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
state_[k] = v.detach().clone().to(device=device, dtype=dtype)
return state_
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -952,12 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap()
peft_model = model.unwrap(unwrap_peft=False)
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
)
if state_dict is None:
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
if self.pp_size > 1:
lora_state_dict = get_lora_state_dict(peft_model, state_dict)
gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
if self.pp_rank == 0:
state_dict.update(gathered_lora_state_dict)
state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
if self.coordinator.is_master():
return peft_model.save_pretrained(
checkpoint,
safe_serialization=use_safetensors,
state_dict=state_dict,
)

View File

@ -10,6 +10,7 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import get_global_rank
from torch.utils._pytree import tree_map
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
@ -17,6 +18,8 @@ from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
gather_state_dict_fast,
get_lora_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
@ -44,12 +47,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
global_dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
ep_group: ProcessGroup,
moe_dp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose)
super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose)
self.global_dp_group = global_dp_group
self.global_dp_rank = dist.get_rank(global_dp_group)
self.global_dp_size = dist.get_world_size(global_dp_group)
@ -117,6 +121,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
@ -157,28 +162,31 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
dist.barrier()
else:
@ -365,6 +373,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
@ -410,7 +419,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
# e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather
# rank 0 saves moe & non-moe params; rank 1 only saves moe params
# rank 3 & 4 save nothing
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0
control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0
if self.pp_size == 1 and self.ep_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
@ -504,7 +513,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint_index_file: str,
prefix: str = "",
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
@ -685,15 +701,18 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
dist.gather(param, all_param, group=ep_group)
dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()
if self.pp_size > 1:
if self.dp_rank == 0:
out = [None for _ in range(self.pp_size)]
dist.gather_object(state_dict, out, group=self.pp_group)
if self.pp_rank == 0:
out = [None for _ in range(self.pp_size)]
else:
out = None
dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
@ -708,14 +727,30 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
checkpoint: str,
gather_dtensor: bool,
use_safetensors: bool,
use_async: bool = False,
):
state_dict = self.pre_save_model(model)
if dist.get_rank() == 0:
torch.save(state_dict, checkpoint)
if use_async:
super().save_unsharded_model(
model=model,
checkpoint=checkpoint,
gather_dtensor=gather_dtensor,
use_safetensors=use_safetensors,
use_async=use_async,
)
else:
torch.save(state_dict, checkpoint)
dist.barrier()
# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool,
use_async: bool = False,
):
"""
Save optimizer state dict to a file with given path.
@ -773,7 +808,14 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
dist.barrier()
# Copied from colossalai.moe
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
def load_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
strict: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from a file with given path.
@ -853,3 +895,26 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict=None):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
from peft import PeftModel
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
peft_model = model.unwrap(unwrap_peft=False)
assert isinstance(
peft_model, PeftModel
), "The model doesn't have lora adapters, please enable lora before saving."
if state_dict is None:
state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
if self.ep_size > 1:
lora_state_dict = get_lora_state_dict(peft_model, state_dict)
moe_params = set(n for n, p in peft_model.named_parameters() if is_moe_tensor(p))
expert_state_dict = {n: p for n, p in lora_state_dict.items() if n in moe_params}
gathered_expert_state_dict = gather_state_dict_fast(expert_state_dict, self.ep_group)
if self.ep_rank == 0:
state_dict.update(gathered_expert_state_dict)
return super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict)

View File

@ -1,31 +1,43 @@
# coding=utf-8
import concurrent.futures
import os
import re
import warnings
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
from typing import Dict, Generator, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from packaging.version import Version
from peft import PeftModel, PeftType
from peft.utils.other import EMBEDDING_LAYER_NAMES, check_file_exists_on_hf_hub
from peft.utils.save_and_load import get_embedding_layer_name, has_valid_embedding_base_layer
from torch.optim import Optimizer
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from colossalai.accelerator import get_accelerator
from colossalai.interface.model import PeftUnwrapMixin
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_STATE_NAME = "optimizer.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"
# ======================================
@ -263,7 +275,138 @@ def save_state_dict_shards(
return total_size
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def async_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
"""
from colossalai.utils.safetensors import save
total_size = 0
shard_filenames = []
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess:
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".")
else:
state_dict = shard
metadata = None
# Only save on master rank.
writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleted unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size, writers
def async_move_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
use_pp_format: bool = False,
state_preprocess: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
"""
from colossalai.utils.safetensors import move_and_save
total_size = 0
shard_filenames = []
if pinned_state_dict is None:
returned_state_dict = {}
else:
returned_state_dict = pinned_state_dict
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Just loop over the sharder and gather to other ranks if not master
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if state_preprocess:
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard)
else:
state_dict = shard
metadata = None
if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(state_dict)
returned_state_dict.update(sub_pinned_state_dict)
# Only save on master rank.
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata)
writers.append(writer)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleted unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size, returned_state_dict, writers
def shard_model_checkpoint(
state_dict: torch.Tensor,
max_shard_size: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@ -272,6 +415,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
if not is_distributed_tensor(weight):
if pinned_state_dicts is not None:
if key not in pinned_state_dicts:
pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu")
pinned_state_dicts[key].copy_(weight)
weight = pinned_state_dicts[key]
block, block_size = state_dict_sharder.append_param(key, weight)
if block != None:
@ -281,7 +429,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
def shard_optimizer_checkpoint(
state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None
) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@ -292,6 +442,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
for k, v in state.items():
if k not in pinned_state_dicts[param_id]:
pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[param_id][k].copy_(v)
state[k] = pinned_state_dicts[param_id][k]
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block != None:
yield block, block_size
@ -305,7 +464,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
def save_state_dict(
state_dict: dict,
checkpoint_file_path: str,
use_safetensors: bool,
) -> None:
"""
Save state dict to checkpoint.
@ -371,7 +534,6 @@ def clean_folder(
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in shard_filenames
and reg.fullmatch(filename_no_suffix) is not None
):
@ -393,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
except ImportError:
return
if isinstance(model, PeftUnwrapMixin):
model = model.base_model
if not isinstance(model, PreTrainedModel):
return
@ -515,14 +679,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
raise Exception("load the model using `safetensors`, but no file endwith .safetensors")
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file, map_location=torch.device("cpu"))
@ -538,6 +695,9 @@ def load_state_dict_into_model(
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if isinstance(model, PeftUnwrapMixin):
state_dict = model.patch_state_dict(state_dict)
model = model.base_model
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))
@ -647,7 +807,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
if key != "step":
if param.is_floating_point():
value = value.to(param.dtype)
value = value.to(param.device)
value = value.to(param.device, non_blocking=True)
return value
elif isinstance(value, dict):
return {k: cast(param, v, key=k) for k, v in value.items()}
@ -667,6 +827,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
elif not strict:
new_states[k] = v
get_accelerator().synchronize()
optimizer.state.update(new_states)
@ -776,14 +937,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
return weights_name, save_index_file
def get_optimizer_base_filenames(prefix: str = None):
def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
states_name = add_prefix(states_name, prefix)
save_index_file = STATES_INDEX_NAME
save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)
param_group_file = GROUP_FILE_NAME
@ -799,3 +960,196 @@ def get_shard_filename(weights_name: str, idx: int):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file
def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
if empty:
return torch.empty_like(tensor, pin_memory=True, device="cpu")
return tensor.pin_memory()
def create_pinned_state_dict(
state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
empty: bool = True,
num_threads: int = 1,
) -> Dict[str, torch.Tensor]:
if num_threads == 1:
return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
elems, spec = tree_flatten(state_dict)
future_to_idx = {}
for i, elem in enumerate(elems):
if isinstance(elem, torch.Tensor):
future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
for future in concurrent.futures.as_completed(future_to_idx):
idx = future_to_idx[future]
elems[idx] = future.result()
return tree_unflatten(elems, spec)
def load_optim_or_model_shard(path: str, is_optim: bool, use_safetensors: bool) -> dict:
if is_optim:
if path.endswith(".safetensors"):
state_dict = load_flat(path)
else:
state_dict = load_shard_state_dict(Path(path), use_safetensors=False)
else:
state_dict = load_shard_state_dict(Path(path), use_safetensors)
return state_dict
def load_state_dict_shards(
checkpoint_files: List[str],
is_optim: bool,
use_safetensors: bool,
low_cpu_mem_mode: bool = True,
prefetch: int = 3,
) -> Generator[dict, None, None]:
if low_cpu_mem_mode:
for shard_file in checkpoint_files:
state_dict = load_optim_or_model_shard(shard_file, is_optim, use_safetensors)
yield state_dict
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=prefetch) as executor:
futures = []
for shard_file in checkpoint_files:
future = executor.submit(load_optim_or_model_shard, shard_file, is_optim, use_safetensors)
futures.append(future)
for future in concurrent.futures.as_completed(futures):
yield future.result()
# adapted from `peft/utils/save_and_load.py`
def get_lora_state_dict(
model: PeftModel, state_dict: dict, adapter_name="default", save_embedding_layers="auto"
) -> dict:
config = model.peft_config[adapter_name]
if config.peft_type != PeftType.LORA:
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
elif bias == "all":
to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
for k in state_dict:
if "lora_" in k:
to_return[k] = state_dict[k]
bias_name = k.split("lora_")[0] + "bias"
if bias_name in state_dict:
to_return[bias_name] = state_dict[bias_name]
else:
raise NotImplementedError
to_return = {k: v for k, v in to_return.items() if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.use_dora:
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
# we want the state_dict format not to change, we remove the "weight" part.
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
def renamed_dora_weights(k):
if k.endswith(new_dora_suffix):
k = k[:-7] # remove ".weight"
return k
to_return = {renamed_dora_weights(k): v for k, v in to_return.items()}
# DEAL WITH EMBEDDINGS
# check the common embedding layers in `target_modules` to reset `save_embedding_layers` if necessary
is_embedding_in_target_modules = False
if (
save_embedding_layers == "auto"
and hasattr(config, "target_modules")
and any(k in config.target_modules for k in EMBEDDING_LAYER_NAMES)
):
warnings.warn("Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`.")
save_embedding_layers = is_embedding_in_target_modules = True
elif save_embedding_layers == "auto":
vocab_size = getattr(getattr(model, "config", None), "vocab_size", None)
model_id = getattr(config, "base_model_name_or_path", None)
# For some models e.g. diffusers the text config file is stored in a subfolder
# we need to make sure we can download that config.
has_base_config = False
# ensure that this check is not performed in HF offline mode, see #1452
if model_id is not None:
local_config_exists = os.path.exists(os.path.join(model_id, "config.json"))
exists = local_config_exists or check_file_exists_on_hf_hub(model_id, "config.json")
if exists is None:
# check failed, could not determine if it exists or not
warnings.warn(
f"Could not find a config file in {model_id} - will assume that the vocabulary was not modified."
)
has_base_config = False
else:
has_base_config = exists
# check if the vocab size of the base model is different from the vocab size of the finetuned model
if (
vocab_size
and model_id
and has_base_config
and (vocab_size != model.config.__class__.from_pretrained(model_id).vocab_size)
):
warnings.warn(
"Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning."
)
save_embedding_layers = True
else:
save_embedding_layers = False
if save_embedding_layers and hasattr(model, "get_input_embeddings"):
for layer in [model.get_input_embeddings(), model.get_output_embeddings()]:
if not is_embedding_in_target_modules or has_valid_embedding_base_layer(layer):
# support from version >= 0.6.2
embedding_module_name = get_embedding_layer_name(model, layer, is_embedding_in_target_modules)
if embedding_module_name:
to_return.update({k: v for k, v in state_dict.items() if embedding_module_name in k})
elif save_embedding_layers:
warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
return to_return
def gather_state_dict_fast(
state_dict: Dict[str, torch.Tensor],
group: dist.ProcessGroup,
device: Optional[Union[torch.device, str]] = None,
dst: int = 0,
) -> Optional[Dict[str, torch.Tensor]]:
if device is None:
device = get_current_device()
rank = dist.get_rank(group)
world_size = dist.get_world_size(group)
metadata = [(k, v.shape, v.dtype) for k, v in state_dict.items()]
all_meta_data = [None] * world_size
if rank == dst:
returned_state_dict = state_dict.copy()
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
for i, target_metadata in enumerate(all_meta_data):
if i == dst:
continue
ops = []
for k, shape, dtype in target_metadata:
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
returned_state_dict[k] = buffer
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
reqs = dist.batch_isend_irecv(ops)
for req, (k, *_) in zip(reqs, target_metadata):
req.wait()
returned_state_dict[k] = returned_state_dict[k].to(device)
return returned_state_dict
else:
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
ops = []
for k, *_ in metadata:
ops.append(dist.P2POp(dist.isend, state_dict[k], dist.get_global_rank(group, dst), group))
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()

View File

@ -64,7 +64,8 @@ from .run import launch_multi_processes
"This will be converted to --arg1=1 --arg2=2 during execution",
)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)")
@click.argument("user_script", type=str, required=False, default=None)
@click.argument("user_args", nargs=-1)
def run(
host: str,
@ -77,8 +78,9 @@ def run(
master_port: int,
extra_launch_args: str,
ssh_port: int,
m: str,
user_script: str,
user_args: str,
user_args: tuple,
) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
@ -102,9 +104,24 @@ def run(
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
if m is not None:
if m.endswith(".py"):
click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help")
exit()
if user_script is not None:
user_args = (user_script,) + user_args
user_script = m
m = True
else:
if user_script is None:
click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help")
exit()
if not user_script.endswith(".py"):
click.echo(
f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help"
)
exit()
m = False
args_dict = locals()
args = Config(args_dict)

View File

@ -113,6 +113,7 @@ def get_launch_command(
user_args: List[str],
node_rank: int,
num_nodes: int,
run_as_module: bool,
extra_launch_args: str = None,
) -> str:
"""
@ -155,6 +156,8 @@ def get_launch_command(
torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.major < 2 and run_as_module:
raise ValueError("Torch version < 2.0 does not support running as module")
if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
@ -198,7 +201,10 @@ def get_launch_command(
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd += _arg_dict_to_list(extra_launch_args)
if run_as_module:
cmd.append("-m")
cmd += [user_script] + user_args
cmd = " ".join(cmd)
return cmd
@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None:
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
run_as_module=args.m,
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd)

View File

@ -64,7 +64,10 @@ class ProcessGroupMesh:
system resources.
"""
for group in self._ranks_to_group.values():
dist.destroy_process_group(group)
try:
dist.destroy_process_group(group)
except ValueError:
pass
# Manually clear all process groups to save memory
gc.collect()

View File

@ -62,7 +62,7 @@ engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
# Step 4: try inference
prompts = ['Who is the best player in the history of NBA?']
response = engine.generate(prompts)
response = engine.generate(prompts=prompts)
pprint(response)
```

View File

@ -1,4 +1,102 @@
import re
from typing import Dict, Set
import torch
import torch.nn as nn
from peft import PeftModel, PeftType
def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
config = model.peft_config[adapter_name]
if config.peft_type != PeftType.LORA:
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none":
to_return = {k for k in names if "lora_" in k}
elif bias == "all":
to_return = {k for k in names if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = set()
for k in names:
if "lora_" in k:
to_return.add(k)
bias_name = k.split("lora_")[0] + "bias"
if bias_name in names:
to_return.add(bias_name)
else:
raise NotImplementedError
to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.use_dora:
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
# we want the state_dict format not to change, we remove the "weight" part.
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"
def renamed_dora_weights(k):
if k.endswith(new_dora_suffix):
k = k[:-7] # remove ".weight"
return k
to_return = {renamed_dora_weights(k) for k in to_return}
to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
return to_return
class PeftUnwrapMixin:
def __init__(self, peft_model: PeftModel):
self.base_model = peft_model.get_base_model()
# peft does not affect buffers
self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
potential_lora_weights = set()
for n in self.lora_layers:
potential_lora_weights.add(f"{n}.weight")
potential_lora_weights.add(f"{n}.bias")
self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}
def named_parameters(self):
for n, p in self.base_model.named_parameters():
if n in self.lora_param_to_origin_param:
n = self.lora_param_to_origin_param[n]
yield n, p
def named_buffers(self):
return self.base_model.named_buffers()
@property
def _modules(self):
return self.base_model._modules
@property
def _non_persistent_buffers_set(self):
return self.base_model._non_persistent_buffers_set
def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
new_state_dict = {}
for k, v in state_dict.items():
if k in self.origin_param_to_lora_param:
k = self.origin_param_to_lora_param[k]
new_state_dict[k] = v
return new_state_dict
def state_dict(self):
state_dict = {}
for k, v in self.base_model.state_dict().items():
if k in self.lora_param_to_origin_param:
k = self.lora_param_to_origin_param[k]
state_dict[k] = v
return state_dict
def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
state_dict = self.patch_state_dict(state_dict)
self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)
def __hash__(self):
return hash(self.base_model)
class ModelWrapper(nn.Module):
@ -13,13 +111,17 @@ class ModelWrapper(nn.Module):
super().__init__()
self.module = module
def unwrap(self):
def unwrap(self, unwrap_peft: bool = True):
"""
Unwrap the model to return the original model for checkpoint saving/loading.
"""
if isinstance(self.module, ModelWrapper):
return self.module.unwrap()
return self.module
model = self.module.unwrap()
else:
model = self.module
if unwrap_peft and isinstance(model, PeftModel):
model = PeftUnwrapMixin(model)
return model
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)

View File

@ -49,14 +49,31 @@ class OptimizerWrapper:
"""
self.optim.zero_grad(*args, **kwargs)
def backward(self, loss: Tensor, *args, **kwargs):
def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
"""
Performs a backward pass on the loss.
"""
loss.backward(*args, **kwargs)
loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
torch.autograd.backward(tensor, grad)
def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
"""
Performs a backward pass for dx or dw,
for dx, we only calculate dx = w*dy here
for dw, we only calculate dw = x*dy here
Args:
tensor (Tensor): y or loss of current chunk;
grad_tensors (Tensor): dy of current chunk;
input_obj (Tensor): for dx, input_obj is x of current chunk;
for dw, input_obj is w of current chunk;
retain_graph (bool): default to be True, we retain graph in backward_b
"""
torch.autograd.backward(
tensors=tensor,
grad_tensors=grad,
inputs=inputs,
retain_graph=retain_graph,
)
def state_dict(self):
"""
@ -135,6 +152,18 @@ class OptimizerWrapper:
"""
return self.optim
def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
"""
Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().
Args:
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
Returns:
Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
"""
raise NotImplementedError("The method get_grad_norm is not implemented yet.")
class DistributedOptim(Optimizer):
def setup_distributed(

View File

@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
return tensor.data.tolist()
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.
The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the converted tensor
"""
requires_grad = target.requires_grad if requires_grad is None else requires_grad
cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
if cls_to_become is Parameter:
# to fit UninitializedParameter
delattr(tensor, "_is_param")
tensor.data = target
tensor.requires_grad = target.requires_grad
tensor.requires_grad = requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(_data_tolist, tensor)
@ -212,9 +213,10 @@ class LazyTensor(torch.Tensor):
Returns:
torch.Tensor: The materialized tensor (self).
"""
requires_grad = self.requires_grad
target = self._materialize_data()
self.clean()
return _convert_cls(self, target)
return _convert_cls(self, target, requires_grad=requires_grad)
def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
@ -509,9 +511,9 @@ class LazyInitContext:
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
device = kwargs.pop("device", orig_t.device)
dtype = kwargs.pop("dtype", orig_t.dtype)
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)
return wrapper, target

View File

@ -171,7 +171,7 @@ def _communicate(
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
get_accelerator().synchronize()
if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):

View File

@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
# if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
for group in self.param_groups:
for p in group["params"]:
state = self.state[p]
if "step" in state and isinstance(state["step"], torch.Tensor):
state["step"] = int(state["step"].item())
def torch_adam_update(
self,
data,

View File

@ -1,11 +1,12 @@
from .p2p import PipelineP2PCommunication
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule
from .schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, PipelineSchedule, ZeroBubbleVPipeScheduler
from .stage_manager import PipelineStageManager
__all__ = [
"PipelineSchedule",
"OneForwardOneBackwardSchedule",
"InterleavedSchedule",
"ZeroBubbleVPipeScheduler",
"PipelineP2PCommunication",
"PipelineStageManager",
]

View File

@ -14,6 +14,8 @@ from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten
from colossalai.accelerator import get_accelerator
from .stage_manager import PipelineStageManager
@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())
my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
@ -139,14 +141,14 @@ def _broadcast_object_list(
# unconsistence in device
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())
object_list[i] = unpickle_object
def _check_for_nccl_backend(group):
def _check_for_nccl_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
@ -154,14 +156,16 @@ def _check_for_nccl_backend(group):
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and (
pg.name() == c10d.Backend.NCCL or pg.name() == c10d.Backend.HCCL
)
def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device(get_accelerator().current_device())
return current_device, is_nccl_backend
@ -348,8 +352,11 @@ def _send_recv_serialization_object(
unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())
if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())
return unpickle_object
@ -432,7 +439,6 @@ def _communicate(
overlap_p2p=overlap_p2p,
send_first=send_first if send_first != None else True,
)
if metadata_recv is not None:
assert isinstance(metadata_recv, P2PMetadata)
tree_spec = metadata_recv.tree_spec
@ -475,9 +481,11 @@ def _p2p_comm(
recv_prev_shape = None
if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)
ops = []
if send_next_shape is not None:
@ -502,7 +510,7 @@ def _p2p_comm(
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)
ops = []
if tensor_send_next is not None:

View File

@ -1,9 +1,11 @@
from .base import PipelineSchedule
from .interleaved_pp import InterleavedSchedule
from .one_f_one_b import OneForwardOneBackwardSchedule
from .zero_bubble_pp import ZeroBubbleVPipeScheduler
__all__ = [
"PipelineSchedule",
"OneForwardOneBackwardSchedule",
"InterleavedSchedule",
"ZeroBubbleVPipeScheduler",
]

View File

@ -137,6 +137,16 @@ def retain_grad(x: Any) -> None:
x.retain_grad()
def require_grad(x: Any) -> None:
"""Call require_grad on a tensor.
Args:
x (Any): Object to be called.
"""
if isinstance(x, torch.Tensor) and not x.requires_grad:
x.requires_grad_()
def detach(x: Any) -> Any:
"""Call detach() on a tensor.
@ -151,6 +161,34 @@ def detach(x: Any) -> Any:
return x
def clone(x: Any) -> Any:
"""Call clone() on a tensor.
Args:
x (Any): Object to be called.
Returns:
Any: The cloned object.
"""
if isinstance(x, torch.Tensor):
return x.clone()
return x
def release_tensor_data(x: Any) -> Any:
"""Call untyped_storage().resize_(0) on a tensor. Use to release tensor.data and keep grad_fn.
Args:
x (Any): Object to be called.
Returns:
Any: The deallocate .data object.
"""
if isinstance(x, torch.Tensor):
return x.data.untyped_storage().resize_(0)
return x
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.

View File

@ -2,7 +2,6 @@ from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@ -18,7 +17,7 @@ from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_
from .base import PipelineSchedule
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()

View File

@ -2,7 +2,6 @@ from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

View File

@ -0,0 +1,449 @@
# Refer from Zero Bubble Pipeline Parallelism.
# Github: https://github.com/sail-sg/zero-bubble-pipeline-parallelism
# Paper: https://arxiv.org/abs/2401.10241
# The following applies to all files unless otherwise noted:
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import deque
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
type: str
chunk: int
stage: int
minibatch: int
start_time: int = 0
completion_time: int = 0
rollback: bool = False
class PipelineGraph(object):
"""PipelineGraph"""
def __init__(
self,
n_stage,
n_micro,
f_cost,
b_cost,
w_cost,
c_cost,
f_mem,
b_mem,
w_mem,
max_mem=None,
):
self.n_node = 6 * n_stage * n_micro
self.n_stage = n_stage
self.n_micro = n_micro
self.f_cost = f_cost
self.b_cost = b_cost
self.w_cost = w_cost
self.c_cost = c_cost
self.f_mem = f_mem
self.b_mem = b_mem
self.w_mem = w_mem
self.fbw_cost = [f_cost, b_cost, w_cost]
self.fbw_mem = [f_mem, b_mem, w_mem]
self.max_mem = max_mem or f_mem * self.n_stage * 2
def get_id(self, cat, chunk, stage, micro):
return (
cat * 2 * self.n_stage * self.n_micro + chunk * self.n_stage * self.n_micro + stage * self.n_micro + micro
)
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
count = []
for i in range(self.n_stage):
count.append([0] * 6)
end_time = [-1] * self.n_node
cur_time = [0] * self.n_stage
mem = [0] * self.n_stage
stage_bubble = [0] * self.n_stage
pending_w = [deque() for _ in range(self.n_stage)]
schedule = [[] for _ in range(self.n_stage)]
stage_str = [" " * i for i in range(self.n_stage)]
if approved_bubble is None:
approved_bubble = [-1] * self.n_stage
max_approved_bubble = max(approved_bubble)
def get_max_stage_bubble(stage=-1):
max_stage_bubble = 0
for bb in stage_bubble:
max_stage_bubble = max(max_stage_bubble, bb)
if stage >= 0:
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
return max_stage_bubble
def put_w(stage):
assert len(pending_w[stage]) > 0
_, chunk_, _ = pending_w[stage].popleft()
put(2, chunk_, stage)
def put(cat, chunk, stage, assert_cnt=True):
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
_cnt = count[stage][cat * 2 + chunk]
# assert _cnt < self.n_micro
if _cnt >= self.n_micro:
if not assert_cnt:
stage_str[stage] += " "
cur_time[stage] = _tmp # TODO
return
assert False
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
if cat > 0 or chunk > 0:
last_id = cat * 2 + chunk - 1
if cat < 2:
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
else:
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
if chunk == 1 and cat < 2:
if stage < self.n_stage - 1:
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
assert end_time[_fa_id] >= 0
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
if chunk == 0 and cat < 2:
if stage > 0:
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
_id = self.get_id(cat, chunk, stage, _cnt)
if count[stage][0] > 0:
stage_bubble[stage] += _tmp - _no_bubble
end_time[_id] = _tmp
cur_time[stage] = _tmp
mem[stage] += self.fbw_mem[cat]
# noinspection PyTypeChecker
schedule[stage].append((cat, chunk, _cnt))
if cat == 1:
pending_w[stage].append((2, chunk, _cnt))
count[stage][cat * 2 + chunk] += 1
for i in range(self.n_stage):
put(0, 0, i)
for i in range(self.n_stage - 1, -1, -1):
if i == self.n_stage - 1:
put(0, 1, i)
continue
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
while (
mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem
and cur_time[i] + self.fbw_cost[0] <= tmp
and count[i][0] < self.n_micro
):
for j in range(i + 1):
put(0, 0, j)
put(0, 1, i)
iter_chunk_ = 0
end_tmp = 0
for i in range(self.n_stage):
if i == 0:
end_tmp = cur_time[0] + self.fbw_cost[1]
continue
tmp = end_tmp + self.c_cost
while (
count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]
or count[i][1] <= count[i - 1][1] < self.n_micro
):
for j in range(self.n_stage - 1, i - 1, -1):
if count[j][iter_chunk_] < self.n_micro:
put(0, iter_chunk_, j)
iter_chunk_ = 1 - iter_chunk_
for _ in range(2 * self.n_micro):
# check mem before putting b
for i in range(self.n_stage):
while mem[i] + self.fbw_mem[1] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
b0_ranks, b1_ranks = [], []
for i in range(self.n_stage):
if count[i][3] >= count[i][2]:
b0_ranks.append(i)
elif i == self.n_stage - 1:
b1_ranks.append(i)
else:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
b1_ranks.append(i)
else:
b0_ranks.append(i)
b_ranks = []
# put b1
for i in reversed(b1_ranks):
b_ranks.append((i, 1))
# put b0
for i in b0_ranks:
b_ranks.append((i, 0))
for i, _chunk_ in b_ranks:
fa_id = -1
if _chunk_ == 1 and i < self.n_stage - 1:
fa_id = self.get_id(1, 1, i + 1, count[i][3])
if _chunk_ == 0 and i > 0:
fa_id = self.get_id(1, 0, i - 1, count[i][2])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if _chunk_ == 1:
put_w(i)
elif fill_b:
put_w(i)
put(1, _chunk_, i)
# put f
for i in range(self.n_stage):
if count[i][1] >= self.n_micro:
continue
put_item = None
if count[i][1] >= count[i][0]:
put_item = 0
elif i == self.n_stage - 1:
put_item = 1
else:
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
put_item = 1
elif count[i][0] < self.n_micro:
if i == 0:
put_item = 0
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
put_item = 0
if put_item is None:
continue
# check mem before putting f
while mem[i] + self.fbw_mem[0] > self.max_mem:
assert len(pending_w[i]) > 0
put_w(i)
fa_id = -1
if put_item == 0 and i > 0:
fa_id = self.get_id(0, 0, i - 1, count[i][0])
if put_item == 1 and i < self.n_stage - 1:
fa_id = self.get_id(0, 1, i + 1, count[i][1])
while (
len(pending_w[i]) > 0
and fa_id >= 0
and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]
):
# fill the bubble
put_w(i)
if (
len(pending_w[i]) > 0
and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]
):
if fill_f:
put_w(i)
put(0, put_item, i)
for i in range(self.n_stage):
while len(pending_w[i]) > 0:
put_w(i)
max_bubble = get_max_stage_bubble()
expected_time = sum(self.fbw_cost) * self.n_micro * 2
max_bubble / expected_time
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
_schedule, _end_time, _max_bubble = self.try_v_schedule(
fill_f=fill_f,
fill_b=fill_b,
approved_bubble=stage_bubble,
)
if _max_bubble < max_bubble:
return _schedule, _end_time, _max_bubble
return schedule, end_time, max_bubble
def print_details(self, end_time, print_scaling=1):
for stage in range(self.n_stage):
stage_str = ["."] * int(max(end_time) / print_scaling)
for _cat in range(3):
for _chunk in range(2):
for _micro in range(self.n_micro):
_id = self.get_id(_cat, _chunk, stage, _micro)
if end_time[_id] < 0:
continue
end = int(end_time[_id] / print_scaling)
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
for j in range(start, end):
if j == start or j == end - 1:
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
elif j == start + 1:
if _micro >= 10:
stage_str[j] = str(_micro // 10)
else:
stage_str[j] = str(_micro)
elif j == start + 2 and _micro >= 10:
stage_str[j] = str(_micro % 10)
else:
stage_str[j] = "-"
_str = ""
for _c in stage_str:
_str += _c
print(_str)
def get_v_schedule(self, only_run_time=False):
schedule, end_time, max_bubble = None, None, None
expected_time = sum(self.fbw_cost) * self.n_micro * 2
for fill_b in [True, False]:
for fill_f in [True, False]:
_schedule, _end_time, _max_bubble = self.try_v_schedule(fill_b=fill_b, fill_f=fill_f)
if max_bubble is None or _max_bubble < max_bubble:
max_bubble = _max_bubble
schedule = _schedule
end_time = _end_time
if only_run_time:
return max_bubble + expected_time
max_bubble / (expected_time + max_bubble)
local_order = [[] for _ in range(self.n_stage)]
comm_id = {}
comm_id_counter = 0
post_validation_time = 0
for i in range(self.n_stage - 1, -1, -1):
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
post_validation_time = max(
post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost
)
# post_validation_time = 0
for it in ["RECV_", "SEND_", ""]:
if i == 0 and it == "SEND_":
continue
if i == self.n_stage - 1 and it == "RECV_":
continue
# stage_ = i - 1 if it == "RECV_" else i
stage_ = i
local_order[stage_].append(
ScheduledNode(
type=it + "POST_VALIDATION",
chunk=0,
stage=stage_,
minibatch=0,
start_time=post_validation_time,
completion_time=post_validation_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
comm_id_counter += 1
for i in range(self.n_stage):
for _cat_, _chunk_, _micro_ in schedule[i]:
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
local_order[i].append(
ScheduledNode(
type="FBW"[_cat_],
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=i,
minibatch=_micro_,
start_time=complete_time - self.fbw_cost[_cat_],
completion_time=complete_time,
)
)
if _cat_ == 2: # no communication for W
continue
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
def communicate(send_recv, stage_):
# noinspection PyTypeChecker
local_order[stage_].append(
ScheduledNode(
type=send_recv + cat_str,
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
stage=stage_,
minibatch=_micro_,
start_time=complete_time,
completion_time=complete_time,
)
)
comm_id[local_order[stage_][-1]] = comm_id_counter
if _chunk_ == 1 and i > 0:
communicate("SEND_", i)
communicate("RECV_", i - 1)
if _chunk_ == 0 and i < self.n_stage - 1:
communicate("SEND_", i)
communicate("RECV_", i + 1)
comm_id_counter += 1
for rank in range(self.n_stage):
# For nodes with the same timestamp on the same stage, communication will be prioritized.
def even_breaker(x: ScheduledNode):
# Compute nodes are always delayed.
if x.type in ["F", "B", "W"]:
return comm_id_counter
# For comm nodes, order by their unique comm id
return comm_id[x]
local_order[rank] = list(sorted(local_order[rank], key=lambda x: (x.start_time, even_breaker(x))))
# If a recv with intersects with previous computation, reorder them so that recv
# is executed before computation and hence can be overlapped.
for i in range(len(local_order[rank])):
if (
i > 0
and local_order[rank][i - 1].type in {"F", "B", "W"}
and local_order[rank][i].type.startswith("RECV")
and "POST_VALIDATION" not in local_order[rank][i].type
and local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time
):
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
local_order_with_rollback = [[] for _ in range(self.n_stage)]
for rank in range(self.n_stage):
rollback_comm = set()
if rank > 0:
for node in local_order[rank - 1]:
if node.type == "POST_VALIDATION":
break
if node.type == "SEND_FORWARD":
assert node.chunk == 0
rollback_comm.add(node.minibatch)
for node in local_order[rank]:
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
rollback = True
rollback_comm.remove(node.minibatch)
else:
rollback = False
local_order_with_rollback[rank].append(
ScheduledNode(
type=node.type,
chunk=node.chunk,
stage=node.stage,
minibatch=node.minibatch,
start_time=node.start_time,
completion_time=node.completion_time,
rollback=rollback,
)
)
assert len(rollback_comm) == 0
return local_order_with_rollback

View File

@ -0,0 +1,971 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_flatten, tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.schedule.v_schedule import ScheduledNode
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.pipeline.weight_grad_store import WeightGradStore
from ._utils import (
clone,
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
release_tensor_data,
require_grad,
retain_grad,
to_device,
)
from .base import PipelineSchedule
AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
class ZeroBubbleVPipeScheduler(PipelineSchedule):
r"""
ZeroBubbleVPipeScheduler
Args:
stage_manager (PipelineStageManager): If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. Defaults to None, which means not using pipeline parallelism.
schedule (List[ScheduledNode]): Schedule for ZeroBubbleVPipe.
num_model_chunks (int) : The number of model chunk in a device.
num_microbatch (Optional[int]): The number of microbatch.
microbatch_size (Optional[int]): The size per microbatch.
enable_metadata_cache (bool): whether to enable metadata cache to acclerate communication.
overlap_p2p (bool): whether to use overlap_p2p.
"""
def __init__(
self,
stage_manager: PipelineStageManager,
schedule: List[ScheduledNode],
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = True,
):
super().__init__(stage_manager)
# batch info
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
self.num_model_chunks = num_model_chunks
self.batch: Any
self.batch_size: int
self.last_batch_size: Optional[int] = None
self.microbatch_offset: List[int]
self.schedules = schedule
# TODO: optim post valid
self.do_post_validation = False
# P2PMeta cache
self.enable_metadata_cache = enable_metadata_cache
# check send_tensor_metadata, send_grad_metadata
# pp4 as sample, we should follow this meta strategy
# send_tensor_meta(fwd) send_grad_meta(bwd)
# chunk0 | chunk1 chunk0 | chunk 1
# stage 0 T | F F | T
# stage 1 T | T T | T
# stage 2 T | T T | T
# stage 3 F | T F | T
if stage_manager.is_first_stage(ignore_chunk=True):
self.send_tensor_metadata = [True, False]
self.send_grad_metadata = [False, True]
elif stage_manager.is_last_stage(ignore_chunk=True):
self.send_tensor_metadata = [False, True]
self.send_grad_metadata = [True, False]
else:
self.send_tensor_metadata = [True, True]
self.send_grad_metadata = [True, True]
# meta cache buffer
self.tensor_metadata_recv = [None, None] # [chunk 0 meta, chunk 1 meta]
self.grad_metadata_recv = [None, None]
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
# init communication map
self.communication_map = {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}
# init buffer
self._free_buffers()
def _free_buffers(self):
# free local buffer
# two dim array, first dim is the model chunk, second dim is the microbatch queue
# x & y buffer for schedule b
self.input_tensors = [[], []]
self.output_tensors = [[], []]
# y & dy buffer for schedule w
self.output_tensors_dw = [[], []]
self.output_tensors_grad_dw = [[], []]
# buffer for communication
self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_forward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]]
self.recv_backward_buffer = [
[],
[],
] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]]
# y buffer for local send fwd
self.local_send_forward_buffer = []
# dy buffer for local send bwd
self.local_send_backward_buffer = []
# wait pp buffer
self.wait_handles = []
def assert_buffer_empty(self):
# assert buffer is empty at end
assert len(self.input_tensors[0]) == 0
assert len(self.input_tensors[1]) == 0
assert len(self.output_tensors[0]) == 0
assert len(self.output_tensors[1]) == 0
assert len(self.output_tensors_dw[0]) == 0
assert len(self.output_tensors_dw[1]) == 0
assert len(self.output_tensors_grad_dw[0]) == 0
assert len(self.output_tensors_grad_dw[1]) == 0
assert len(self.send_forward_buffer[0]) == 0
assert len(self.send_forward_buffer[1]) == 0
assert len(self.recv_forward_buffer[0]) == 0
assert len(self.recv_forward_buffer[1]) == 0
assert len(self.send_backward_buffer[0]) == 0
assert len(self.send_backward_buffer[1]) == 0
assert len(self.recv_backward_buffer[0]) == 0
assert len(self.recv_backward_buffer[1]) == 0
assert len(self.local_send_forward_buffer) == 0
assert len(self.local_send_backward_buffer) == 0
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
self.batch = batch
self.batch_size = get_batch_size(batch)
if self.microbatch_size is None:
assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch"
self.microbatch_size = self.batch_size // self.num_microbatch
if self.num_microbatch is None:
assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size"
self.num_microbatch = self.batch_size // self.microbatch_size
if not self.forward_only:
assert self.last_batch_size is None or self.last_batch_size == self.batch_size
assert self.batch_size == self.microbatch_size * self.num_microbatch
assert (
self.num_microbatch % self.stage_manager.num_stages == 0
), "Number of microbatch should be an integer multiple of number of pipeline parallel devices"
if self.forward_only:
self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1
self.last_batch_size = self.batch_size
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
microbatch_id (int): the current model chunk idx.
Returns:
Any: Micro batch.
"""
assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
forward (bool): if is the forward process
Returns:
int: The model chunk idx of the input microbatch_id
"""
assert (
microbatch_id < self.num_microbatch * self.num_model_chunks
), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})"
microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not is_forward:
# Reverse order
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 & is_first_stage
# do nothing; cause u are chunk 0 in first rank, u have no prev rank;
#################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 0 & not is_first_stage
# Recv y from PREV_rank as input
#################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor, wait_handles = self.comm.recv_forward(
prev_rank=prev_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
else:
################
# chunk = 1 & is_last_stage
# do nothing; cause u get y from local_send_forward_buffer in schedule f
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
# return None, []
return []
################
# chunk = 1 & not is_last_stage
# recv y from NEXT_rank as input
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor, wait_handles = self.comm.recv_forward(
next_rank, metadata_recv=self.tensor_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None:
self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor)
self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles))
return wait_handles
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 & is_last_stage
# do nothing; Already get dy from local_send_backward_buffer in schedule b
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
return []
################
# chunk = 0 & not is_last_stage
# Recv bwd from next stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
else:
# bwd chunk1 is left V;
################
# chunk = 1 & is_first_stage
# do nothing; get loss from local
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
return []
################
# chunk = 1 & not first stage
# recv_backward recv bwd from prev stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor_grad, wait_handles = self.comm.recv_backward(
next_rank=prev_rank, metadata_recv=self.grad_metadata_recv[model_chunk_id]
)
if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None:
self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad)
self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles))
return wait_handles
def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List:
"""Sends the input tensor to the next stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the recipient of the tensor.
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
################
# chunk = 0 && is_last_stage
# do nothing; hold y on local_send_forward_buffer
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 0 && not is_last_stage
# self.comm.send_forward send y to NEXT stage
################
else:
next_rank = self.stage_manager.get_next_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_object=output_tensor,
next_rank=next_rank,
send_metadata=self.send_tensor_metadata[model_chunk_id],
)
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
else:
################
# chunk = 1 && is_first_stage
# do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 1 && not is_first_stage
# self.comm.send_forward send y to PREV stage
################
else:
prev_rank = self.stage_manager.get_prev_rank()
output_tensor = self.send_forward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_forward(
output_tensor, prev_rank, send_metadata=self.send_tensor_metadata[model_chunk_id]
)
self.send_tensor_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List:
"""Sends the gradient tensor to the previous stage in pipeline.
For ZBV.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the recipient of the tensor
Returns:
Any: The wait handles for the communication.
"""
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if model_chunk_id == 0:
# bwd chunk0 is right V;
################
# chunk = 0 && is_first_stage
# do nothing; cause u are the first chunk in first stage; bwd end
################
if self.stage_manager.is_first_stage(ignore_chunk=True):
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 0 && not is_first_stage
# Send dx to PREV stage;
################
else:
prev_rank = self.stage_manager.get_prev_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, prev_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
# bwd chunk1 is left V;
else:
################
# chunk = 1 && is_last_stage
# do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b;
################
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return []
################
# chunk = 1 && not is_last_stage
# Send dx to NEXT stage;
################
else:
next_rank = self.stage_manager.get_next_rank()
input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0)
send_handles = self.comm.send_backward(
input_tensor_grad, next_rank, send_metadata=self.send_grad_metadata[model_chunk_id]
)
self.send_grad_metadata[model_chunk_id] = not self.enable_metadata_cache
return send_handles
def forward_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
micro_batch: Optional[dict],
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
input_obj (Optional[dict]): x;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
# Load input ids, attention mask and labels
# for the first stage, input_obj is None; So,we use micro_batch as input_obj
# for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used
with self.stage_manager.switch_model_chunk_id(model_chunk_id):
# fwd calculate
internal_inputs = {} if input_obj is None else input_obj
internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id]
output_obj = model_forward(model_chunk, micro_batch, internal_inputs)
# last layer in model
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
loss = criterion(output_obj, micro_batch) / self.num_microbatch
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_b_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
# micro_batch: Optional[dict],
input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
) -> Optional[dict]:
"""Backward dx step of the pipeline; we calculate "dx = w*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj)
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Optional[dict]: dx.
"""
# calculate bwd b step ; only dx = w*dy;
# Retain the grad on the input_obj. No need retain_grad microbatch
if input_obj is not None:
tree_map(retain_grad, input_obj)
# x, y, dy list for backward_by_grad; Type: list[tensor];
input_obj_ = []
output_obj_ = []
output_obj_grad_ = []
# For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx.
# For loss backward; output_obj is loss; output_obj_grad should be None
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
assert output_obj_grad is None
input_obj_, _ = tree_flatten(input_obj)
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(output_obj_grad) # None
# For other chunk stage, use input_obj as input_obj_;
else:
input_obj_, _ = tree_flatten(input_obj)
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# filter item which is not torch.Tensor
input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
try:
ctx = optimizer.no_sync()
except AttributeError:
ctx = model_chunk.no_sync()
with ctx:
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
# inputs=input_obj_,
retain_graph=False,
)
# Format output_obj_grad
input_obj_grad = dict()
if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
def backward_w_step(
self,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
output_obj: Union[dict, torch.Tensor],
output_obj_grad: Optional[dict],
):
"""Backward dw step of the pipeline; we calculate "dw = x*dy" here;
Args:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
optimizer (OptimizerWrapper): Optimizer to update the model
output_obj (Union[dict, torch.Tensor]): y.
output_obj_grad (dict): dy.
Returns:
Nothing need to return; we only calculate dw then update w;
"""
# calculate bwd w step ; only dw = x*dy;
# y, dy list for w backward_by_grad; Type: list[tensor];
output_obj_ = []
output_obj_grad_ = []
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# loss backward; output_obj is loss;
output_obj_.append(output_obj) # LOSS
output_obj_grad_.append(None) # None
else:
output_obj_, _ = tree_flatten(output_obj) # y
output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy
# filter item which is not torch.Tensor
output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None]
output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None]
optimizer.backward_by_grad(
tensor=output_obj_,
grad=output_obj_grad_,
inputs=list(model_chunk.parameters()),
retain_graph=False,
)
def schedule_f(
self,
scheduled_node,
model_chunk: torch.nn.ModuleList,
model_chunk_id: int,
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None,
):
"""A complete forward schedule; Include recv fwd --> cal fwd --> send fwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
criterion (Callable): loss function;
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Nothing.
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# Step1: recv fwd
if model_chunk_id == 0:
# is first stage; get input from microbatch
if self.stage_manager.is_first_stage(ignore_chunk=True):
input_obj = None # (tensor, wait_handle)
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
else:
# is last stage; recv from local
if self.stage_manager.is_last_stage(ignore_chunk=True):
input_obj = self.local_send_forward_buffer.pop(0)
# not last stage; recv from next
else:
input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
for h in input_obj[1]:
h.wait()
input_obj = input_obj[0]
# Here, let input_obj.requires_grad_()
# if input_obj is not None:
if not isinstance(input_obj, torch.Tensor):
tree_map(require_grad, input_obj)
# Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd,
# tree_map(torch.Tensor.requires_grad_, micro_batch)
# Step2: fwd step
output_obj = self.forward_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
micro_batch=micro_batch,
input_obj=input_obj,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# Step3:
# 3-1:detach output; detach output for send fwd;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not detach bwd LOSS
pass
else:
# detach output
detached_output_obj = tree_map(detach, output_obj)
# 3-2 clone detached_output_obj
detached_output_obj = tree_map(clone, detached_output_obj)
# 3-3 release cloned output.data; release_tensor_data output for bwd b & w; (do not detach output)
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
# We should not release_tensor_data bwd LOSS
pass
else:
# release_tensor_data output
tree_map(release_tensor_data, output_obj)
# add input and output object for backward b
self.input_tensors[model_chunk_id].append(input_obj)
# for bwd b&w, we only need the graph(grad_fn) of output_obj
# Do not release_tensor_data loss, release_tensor_data other output_obj;
if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True):
self.output_tensors[model_chunk_id].append(output_obj)
else:
self.output_tensors[model_chunk_id].append(output_obj)
# add output to send_fwd_buffer
if model_chunk_id == 0: # chunk 0
# is last stage; send to local_send_forward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_forward_buffer.append(detached_output_obj)
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
else: # chunk 1
# is first stage; end of fwd; do nothing
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
else:
self.send_forward_buffer[model_chunk_id].append(detached_output_obj)
def schedule_b(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
):
"""A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd;
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
# Step1: recv bwd
if model_chunk_id == 0:
# chunk0 is last stage; recv output_grad from local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
output_tensor_grad = self.local_send_backward_buffer.pop(0)
# chunk0 not last stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
else:
# chunk1, is first stage; recv LOSS from local send bwd buffer
if self.stage_manager.is_first_stage(ignore_chunk=True):
output_tensor_grad = None
# chunk1, not first stage; recv output_grad from recv_backward_buffer
else:
output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0)
for h in output_tensor_grad[1]:
h.wait()
output_tensor_grad = output_tensor_grad[0]
# get input and output object from buffer;
input_obj = self.input_tensors[model_chunk_id].pop(0)
output_obj = self.output_tensors[model_chunk_id].pop(0)
input_object_grad = self.backward_b_step(
model_chunk=model_chunk,
model_chunk_id=model_chunk_id,
optimizer=optimizer,
input_obj=input_obj,
output_obj=output_obj,
output_obj_grad=output_tensor_grad,
)
# Step3: send bwd
if model_chunk_id == 0:
# do nothing; end of bwd;
if self.stage_manager.is_first_stage(ignore_chunk=True):
pass
# save input_object_grad to send_backward_buffer
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
else:
# send to local_send_backward_buffer
if self.stage_manager.is_last_stage(ignore_chunk=True):
self.local_send_backward_buffer.append(input_object_grad)
# send to next
else:
self.send_backward_buffer[model_chunk_id].append(input_object_grad)
WeightGradStore.flush(chunk=model_chunk_id)
def schedule_w(
self,
scheduled_node,
model_chunk: Union[ModuleList, Module],
model_chunk_id: int,
optimizer: OptimizerWrapper,
):
"""A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w);
Args:
scheduled_node:
model_chunk (ModuleList or Module): Model Chunk to be run;
model_chunk_id (int): The current model chunk idx;
Returns:
Nothing.
"""
WeightGradStore.pop(chunk=model_chunk_id)
def run_forward_only(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
assert self.forward_only
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# reset accum loss at fwd end;
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
for it in range(len(self.schedules)):
scheduled_node = self.schedules[it]
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> Dict:
"""
Runs Zerobubble schedule, with communication between pipeline stages.
"""
# prepare batch
self.load_batch(data_iter)
# prepare accum loss & output
accum_loss = None
# reset accum loss at fwd end;
if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
# while we still have schedules_node in self.schedules
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk)
# We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
criterion=criterion,
accum_loss=accum_loss,
outputs=outputs,
)
elif scheduled_node.type == "B":
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
elif scheduled_node.type == "W":
self.schedule_w(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
# wait here to ensure all communication is done
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(
self,
model_chunk: Union[ModuleList, Module],
data_iter: Iterable,
criterion: Callable[..., Any],
optimizer: Optional[OptimizerWrapper] = None,
return_loss: bool = False,
return_outputs: bool = False,
) -> dict:
"""
Args:
model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
assert self.forward_only, "Optimizer should be passed when doing backward."
if self.forward_only:
result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
else:
result = self.run_forward_backward(
model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
)
self.assert_buffer_empty()
return result

View File

@ -26,6 +26,7 @@ class PipelineStageManager:
pg_mesh: ProcessGroupMesh,
pipeline_axis: int,
enable_interleave: bool = False,
use_zbv: bool = False,
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
) -> None:
@ -49,6 +50,7 @@ class PipelineStageManager:
next_coord = coord[: self.pipeline_axis] + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1 :]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode="wrap")
self.is_interleave = enable_interleave
self.use_zbv = use_zbv
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
@ -85,6 +87,16 @@ class PipelineStageManager:
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
stage_indices = []
if self.use_zbv:
stage_indices.append([num_layers_per_stage_accumulated[stage], num_layers_per_stage_accumulated[stage + 1]])
stage_indices.append(
[
num_layers_per_stage_accumulated[2 * num_stages - stage - 1],
num_layers_per_stage_accumulated[2 * num_stages - stage],
]
)
return stage_indices
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
@ -124,7 +136,11 @@ class PipelineStageManager:
if not self.is_interleave or ignore_chunk:
return self.stage == self.num_stages - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
# use zero bubble pipeline
if self.use_zbv:
return self.stage == 0 and self.model_chunk_id == self.num_model_chunks - 1
else:
return self.stage == self.num_stages - 1 and self.model_chunk_id == self.num_model_chunks - 1
@property
def num_stages(self) -> int:
@ -207,7 +223,6 @@ class PipelineStageManager:
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2

View File

@ -0,0 +1,42 @@
import queue
class WeightGradStore:
cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]
@classmethod
def put(cls, total_input, grad_output, weight, func):
cls.cache.append((total_input, grad_output, weight, func))
@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []
@classmethod
def pop(cls, chunk=0):
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if isinstance(weight, tuple):
# In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
# View will lead to weight ptr change
# weight_cal & weight_origin in tuple, weight_cal use to cal dw, weight_origin use to update
_, weight_origin = weight
if weight_origin.grad is not None:
func(total_input, grad_output, weight_origin.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight_origin.grad = grad_weight
else:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")

Some files were not shown because too many files have changed in this diff Show More