Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-29 20:25:29 +00:00)
Compare commits
134 Commits
SHA1:

46ed5d856b, 7ecdf9a211, 44d4053fec, 6d676ee0e9, 56fe130b15, f32861ccc5, b9e60559b8, 7595c453a5,
53834b74b9, 0171884664, 9379cbd668, 24dee8f0b7, f73ae55394, f8b9e88484, d54642a263, d20c8ffd97,
ce0ec40811, 5ff5323538, 014837e725, ec73f1b5e2, 5c09d726a6, 2b415e5999, 17062c83b9, ca0aa2365d,
97e60cbbcb, 5b094a836b, ee81366cac, 479067e9bc, 7fdef9fd6b, a9bedc7a43, af06d162cf, 836992438f,
8b0ed61490, 5f82bfa636, fa9d0318e4, 130229fdcb, aaafb38851, e994c64568, de3d371f65, 8d826a336e,
6280cb18b8, ab856fd308, 8ecff0cb7f, 8fddbab04c, 152162a80e, cf519dac6a, 5caad13055, e0c68ab6d3,
184a653704, 5fa657f0a1, eb69e640e5, b90835bd32, 8e08c27e19, d4a436051d, 5a03d2696d, cc40fe0e6f,
c2fe3137e2, a2596519fd, 30a9443132, 7a60161035, a15ab139ad, 13ffa08cfa, 2f583c1549, c2e8f61592,
89a9a600bc, 4294ae83bb, 80a8ca916a, dee63cc5ef, 6d6cafabe2, b10339df7c, 19baab5fd5, 58d8b8a2dd,
5ddad486ca, 3b1d7d1ae8, 2bcd0b6844, cd61353bae, 62c13e7969, dcd41d0973, 83cf2f84fb, bc7eeade33,
fd92789af2, 6be9862aaf, 3dc08c8a5a, 8ff7d0c780, fe9208feac, 3201377e94, 23199e34cc, d891e50617,
e1e86f9f1f, 4c8e85ee0d, 703bb5c18d, 4e0e99bb6a, 1507a7528f, 0002ae5956, dc2cdaf3e8, efe3042bb2,
6b2c506fc5, 5ecc27e150, f98384aef6, 646b3c5a90, b635dd0669, 3532f77b90, 3fab92166e, f4daf04270,
6705dad41b, 91ed32c256, 6fb1322db1, 65c8297710, cfd9eda628, cbaa104216, dabc2e7430, f9546ba0be,
4fa6b9509c, 63314ce4e4, 10e4f7da72, 37e35230ff, 827ef3ee9a, bdb125f83f, f20b066c59, b582319273,
0ad3129cb9, 0b14a5512e, 696fced0d7, dc032172c3, f393867cff, 6eb8832366, 683179cefd, 0a01e2a453,
216d54e374, fdd84b9087, a35a078f08, 13946c4448, c54c4fcd15, 8fd25d6e09
@@ -1,3 +1,3 @@
-2.2.2-12.1.0
 2.3.0-12.1.0
 2.4.0-12.4.1
+2.5.1-12.4.1
@@ -1,11 +1,11 @@
 {
     "build": [
         {
-            "torch_command": "pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121",
+            "torch_command": "pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu121",
             "cuda_image": "hpcaitech/cuda-conda:12.1"
         },
         {
-            "torch_command": "pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124",
+            "torch_command": "pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124",
             "cuda_image": "hpcaitech/cuda-conda:12.4"
         }
     ]
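A quick local sanity check against the updated torch/CUDA pairs is sketched below; the expected combinations come from the build matrix above, not from any official compatibility API.

```python
# Minimal sketch: print the installed torch / CUDA pair so it can be compared
# against the torch==2.3.0+cu121, torch==2.4.0+cu124 and torch==2.5.1+cu124
# combinations listed in the updated build matrix above.
import torch

print(f"torch {torch.__version__}, built against CUDA {torch.version.cuda}")
print(f"CUDA runtime available: {torch.cuda.is_available()}")
```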
.github/ISSUE_TEMPLATE/bug-report.yml (20 changed lines)

@@ -15,6 +15,26 @@ body:
       options:
         - label: I have searched the existing issues
           required: true
+
+  - type: checkboxes
+    attributes:
+      label: The bug has not been fixed in the latest main branch
+      options:
+        - label: I have checked the latest main branch
+          required: true
+
+  - type: dropdown
+    id: share_script
+    attributes:
+      label: Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)
+      description: If not, please share your setting/training config, and/or point to the line in the repo that throws the error.
+        If the issue is not easily reproducible by us, it will reduce the likelihood of getting responses.
+      options:
+        - Yes, I will share a minimal reproducible script.
+        - No, I prefer not to share.
+    validations:
+      required: true
+
   - type: textarea
     attributes:
       label: 🐛 Describe the bug
.github/workflows/build_on_pr.yml (11 changed lines)

@@ -87,10 +87,10 @@ jobs:
     name: Build and Test Colossal-AI
     needs: detect
     if: needs.detect.outputs.anyLibraryFileChanged == 'true'
-    runs-on: [self-hosted, gpu]
+    runs-on: ubuntu-latest
     container:
-      image: hpcaitech/pytorch-cuda:2.2.2-12.1.0
-      options: --gpus all --rm -v /dev/shm -v /data/scratch:/data/scratch
+      image: image-cloud.luchentech.com/hpcaitech/pytorch-cuda:2.2.2-12.1.0
+      options: --gpus all --shm-size=2g --rm -v /dev/shm -v /data/scratch:/data/scratch
     timeout-minutes: 90
     defaults:
       run:

@@ -117,7 +117,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          DISABLE_URING=1 pip install -v .
+          DISABLE_URING=1 pip install -v --no-cache-dir .

       - name: Store TensorNVMe Cache
         run: |

@@ -166,6 +166,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Collate artifact
         env:

@@ -199,7 +200,7 @@ jobs:
           fi

       - name: Upload test coverage artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: report
           path: report/
.github/workflows/build_on_schedule.yml (1 changed line)

@@ -70,6 +70,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation
@@ -79,3 +79,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

@@ -73,3 +73,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

@@ -67,6 +67,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation
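The recurring addition in these workflow files points the Hugging Face client at a mirror endpoint. A minimal sketch of the same idea outside CI, assuming the `huggingface_hub` package is installed (the model ID is only an example):

```python
# Minimal sketch: route Hugging Face downloads through a mirror endpoint,
# mirroring the HF_ENDPOINT variable added to the CI jobs above.
# The variable should be set before huggingface_hub is imported.
import os

os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

from huggingface_hub import snapshot_download

# Example repo ID; replace with whatever the tests actually pull.
local_dir = snapshot_download(repo_id="hf-internal-testing/tiny-random-LlamaForCausalLM")
print(local_dir)
```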
.github/workflows/doc_check_on_pr.yml (1 changed line)

@@ -58,6 +58,7 @@ jobs:
       # there is no main branch, so it's safe to checkout the main branch from the merged branch
       # docer will rebase the remote main branch to the merged branch, so we have to config user
       - name: Make the merged branch main
+
         run: |
           cd ColossalAI
           git checkout -b main
@@ -49,6 +49,7 @@ jobs:
           # we need to install the requirements.txt first
           # as test-pypi may not contain the distributions for libs listed in the txt file
           pip install -r requirements/requirements.txt
+          pip install -U setuptools==68.2.2 wheel
           pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.python.org/pypi colossalai==$VERSION
         env:
           VERSION: ${{ steps.prep-version.outputs.version }}
.github/workflows/run_chatgpt_examples.yml (4 changed lines)

@@ -31,13 +31,12 @@ jobs:

       - name: Install Colossal-AI
         run: |
-          BUILD_EXT=1 pip install --no-cache-dir -v -e .
+          pip install --no-cache-dir -v -e .

       - name: Install ChatGPT
         run: |
           cd applications/ColossalChat
           pip install --no-cache-dir -v .
-          export BUILD_EXT=1
           pip install --no-cache-dir -r examples/requirements.txt

       - name: Install Transformers

@@ -61,5 +60,6 @@ jobs:
           PRETRAINED_MODEL_PATH: ./models
           SFT_DATASET: ./sft_data
           PROMPT_DATASET: ./prompt_data
+          PROMPT_RLVR_DATASET: ./prompt_data
           PREFERENCE_DATASET: ./preference_data
           KTO_DATASET: ./kto_data
@@ -15,21 +15,21 @@ repos:
         args: ["--profile", "black"] # avoid conflict with black

   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 24.10.0
     hooks:
       - id: black
         name: black formatter
         args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v18.1.8
+    rev: v19.1.5
     hooks:
       - id: clang-format
         name: clang formatter
         types_or: [c++, c]

   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: check-merge-conflict
README.md (30 changed lines)

@@ -9,7 +9,7 @@
    <a href="https://www.colossalai.org/"> Documentation </a> |
    <a href="https://github.com/hpcaitech/ColossalAI/tree/main/examples"> Examples </a> |
    <a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
-   <a href="https://cloud.luchentech.com/">GPU Cloud Playground </a> |
+   <a href="https://colossalai.org/zh-Hans/docs/get_started/bonus/">GPU Cloud Playground </a> |
    <a href="https://hpc-ai.com/blog"> Blog </a></h3>

 [](https://github.com/hpcaitech/ColossalAI/stargazers)

@@ -25,16 +25,34 @@

 </div>

+## Get Started with Colossal-AI Without Setup
+
+Access high-end, on-demand compute for your research instantly—no setup needed.
+
+Sign up now and get $10 in credits!
+
+Limited Academic Bonuses:
+
+* Top up $1,000 and receive 300 credits
+* Top up $500 and receive 100 credits
+
+<div align="center">
+  <a href="https://hpc-ai.com/?utm_source=github&utm_medium=social&utm_campaign=promotion-colossalai">
+  <img src="https://github.com/hpcaitech/public_assets/blob/main/colossalai/img/2-2.gif" width="850" />
+  </a>
+</div>
+
+
 ## Latest News
+* [2025/02] [DeepSeek 671B Fine-Tuning Guide Revealed—Unlock the Upgraded DeepSeek Suite with One Click, AI Players Ecstatic!](https://company.hpc-ai.com/blog/shocking-release-deepseek-671b-fine-tuning-guide-revealed-unlock-the-upgraded-deepseek-suite-with-one-click-ai-players-ecstatic)
+* [2024/12] [The development cost of video generation models has saved by 50%! Open-source solutions are now available with H200 GPU vouchers](https://company.hpc-ai.com/blog/the-development-cost-of-video-generation-models-has-saved-by-50-open-source-solutions-are-now-available-with-h200-gpu-vouchers) [[code]](https://github.com/hpcaitech/Open-Sora/blob/main/scripts/train.py) [[vouchers]](https://colossalai.org/zh-Hans/docs/get_started/bonus/)
+* [2024/10] [How to build a low-cost Sora-like app? Solutions for you](https://company.hpc-ai.com/blog/how-to-build-a-low-cost-sora-like-app-solutions-for-you)
+* [2024/09] [Singapore Startup HPC-AI Tech Secures 50 Million USD in Series A Funding to Build the Video Generation AI Model and GPU Platform](https://company.hpc-ai.com/blog/singapore-startup-hpc-ai-tech-secures-50-million-usd-in-series-a-funding-to-build-the-video-generation-ai-model-and-gpu-platform)
+* [2024/09] [Reducing AI Large Model Training Costs by 30% Requires Just a Single Line of Code From FP8 Mixed Precision Training Upgrades](https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades)
 * [2024/06] [Open-Sora Continues Open Source: Generate Any 16-Second 720p HD Video with One Click, Model Weights Ready to Use](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
 * [2024/05] [Large AI Models Inference Speed Doubled, Colossal-Inference Open Source Release](https://hpc-ai.com/blog/colossal-inference)
 * [2024/04] [Open-Sora Unveils Major Upgrade: Embracing Open Source with Single-Shot 16-Second Video Generation and 720p Resolution](https://hpc-ai.com/blog/open-soras-comprehensive-upgrade-unveiled-embracing-16-second-video-generation-and-720p-resolution-in-open-source)
 * [2024/04] [Most cost-effective solutions for inference, fine-tuning and pretraining, tailored to LLaMA3 series](https://hpc-ai.com/blog/most-cost-effective-solutions-for-inference-fine-tuning-and-pretraining-tailored-to-llama3-series)
-* [2024/03] [314 Billion Parameter Grok-1 Inference Accelerated by 3.8x, Efficient and Easy-to-Use PyTorch+HuggingFace version is Here](https://hpc-ai.com/blog/314-billion-parameter-grok-1-inference-accelerated-by-3.8x-efficient-and-easy-to-use-pytorchhuggingface-version-is-here)
-* [2024/03] [Open-Sora: Revealing Complete Model Parameters, Training Details, and Everything for Sora-like Video Generation Models](https://hpc-ai.com/blog/open-sora-v1.0)
-* [2024/03] [Open-Sora:Sora Replication Solution with 46% Cost Reduction, Sequence Expansion to Nearly a Million](https://hpc-ai.com/blog/open-sora)
-* [2024/01] [Inference Performance Improved by 46%, Open Source Solution Breaks the Length Limit of LLM for Multi-Round Conversations](https://hpc-ai.com/blog/Colossal-AI-SwiftInfer)
-* [2023/07] [HPC-AI Tech Raises 22 Million USD in Series A Funding](https://www.hpc-ai.tech/blog/hpc-ai-tech-raises-22-million-usd-in-series-a-funding-to-fuel-team-expansion-and-business-growth)

 ## Table of Contents
 <ul>
@@ -100,7 +100,7 @@ LLaMA3_Conv = Conversation(
     messages=[],
     offset=0,
     sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
-    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+    seps=["<|begin_of_text|>", "<|eot_id|>"],
 )

 default_conversation = LLaMA3_Conv
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

     assert (
         tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
-    ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
+    ), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

     if ignore_index is None:
         ignore_index = IGNORE_INDEX
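A minimal standalone sketch of the consistency check the assertion above enforces, tied to the `seps` change in the previous hunk. The `seps` values mirror the LLaMA3_Conv update (Llama-3-style chat uses `<|eot_id|>` as the end-of-turn token); the model ID is only illustrative and may require gated access.

```python
# Minimal sketch: the conversation template's separators must line up with the
# tokenizer's special tokens, which is what the assertion above checks.
from transformers import AutoTokenizer


def check_template(tokenizer, seps):
    if tokenizer.bos_token != seps[0] or tokenizer.eos_token != seps[1]:
        raise ValueError(
            f"`bos_token` {tokenizer.bos_token} and `eos_token` {tokenizer.eos_token} "
            f"should match `conversation_template.seps` {seps}."
        )


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # illustrative
check_template(tokenizer, ["<|begin_of_text|>", "<|eot_id|>"])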
@@ -43,6 +43,7 @@ def save_checkpoint(
     step: int,
     batch_size: int,
     coordinator: DistCoordinator,
+    use_lora: bool = False,
 ) -> None:
     """
     Save model checkpoint, optimizer, LR scheduler and intermedidate running states.

@@ -51,7 +52,10 @@ def save_checkpoint(
     save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
     os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

-    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+    if use_lora:
+        booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
+    else:
+        booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

     booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
     booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
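With the new `use_lora` branch, a checkpoint directory may hold only the LoRA adapter rather than full sharded weights. A sketch of reloading such an adapter with plain PEFT is shown below; it assumes `save_lora_as_pretrained` writes a PEFT-compatible adapter directory, and the paths are placeholders, not paths from the diff.

```python
# Minimal sketch (not the repo's own loader): reload an adapter saved via
# booster.save_lora_as_pretrained(...) using plain PEFT, assuming the adapter
# lives under the save_dir/"modeling" layout used above.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base-model", torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "path/to/epoch-0_step-100/modeling")
model = model.merge_and_unload()  # optional: fold the LoRA weights back into the base model
```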
@@ -21,6 +21,7 @@ from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
 from colossal_llama.utils.froze import freeze_non_embeds_parameters
 from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
 from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
+from peft import LoraConfig
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -65,7 +66,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":

@@ -75,7 +76,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":

@@ -101,10 +102,9 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
-            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
             microbatch_size=args.microbatch_size,

@@ -117,11 +117,17 @@ def train(args) -> None:
     # ======================================================
     # Initialize Tokenizer, Dataset, Collator and Dataloader
     # ======================================================
-    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
     if args.pad_token == "eos":
-        tokenizer.pad_token = tokenizer.eos_token
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     elif args.pad_token == "unk":
-        tokenizer.pad_token = tokenizer.unk_token
+        try:
+            tokenizer.pad_token = tokenizer.unk_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False

@@ -164,33 +170,31 @@ def train(args) -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
+    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
     init_ctx = (
         LazyInitContext(default_device=get_current_device())
-        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
         else nullcontext()
     )
     with init_ctx:
-        if args.use_flash_attn:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrained,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)

+    if args.lora_rank > 0:
+        lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
+        model = booster.enable_lora(model, lora_config=lora_config)
+
     # this is essential, otherwise the grad checkpoint will not work.
     model.train()

     if args.use_grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     model_numel = get_model_numel(model)

@@ -327,6 +331,7 @@ def train(args) -> None:
                         step=step + 1,
                         batch_size=args.batch_size,
                         coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
                     )
                     coordinator.print_on_master(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"

@@ -371,44 +376,45 @@ def train(args) -> None:
                 total_loss.fill_(0.0)
                 pbar.update()

             # Save modeling.
             save_model_condition = (
                 args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
-            )
-
-            if not args.skip_save_each_epoch:
-                save_model_condition = save_model_condition or (step + 1) == len(dataloader)
-
-            if save_model_condition and not args.benchmark:
-                coordinator.print_on_master("\nStart saving model checkpoint with running states")
-
-                if args.use_neft:
-                    coordinator.print_on_master("Deactivate NEFTune before saving model.")
-                    deactivate_neftune(model, handle)
-
-                accelerator.empty_cache()
-                save_checkpoint(
-                    save_dir=args.save_dir,
-                    booster=booster,
-                    model=model,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    epoch=epoch,
-                    step=step + 1,
-                    batch_size=args.batch_size,
-                    coordinator=coordinator,
-                )
-                coordinator.print_on_master(
-                    f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
-                )
-
-                if args.use_neft:
-                    coordinator.print_on_master("Activate NEFTune.")
-                    model, handle = activate_neftune(model)
-
-                # Delete cache.
-                # del batch, batch_labels, batch_output, loss
-                accelerator.empty_cache()
+            )
+
+            if not args.skip_save_each_epoch:
+                save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+
+            if save_model_condition and not args.benchmark:
+                coordinator.print_on_master("\nStart saving model checkpoint with running states")
+
+                if args.use_neft:
+                    coordinator.print_on_master("Deactivate NEFTune before saving model.")
+                    deactivate_neftune(model, handle)
+
+                accelerator.empty_cache()
+                save_checkpoint(
+                    save_dir=args.save_dir,
+                    booster=booster,
+                    model=model,
+                    optimizer=optimizer,
+                    lr_scheduler=lr_scheduler,
+                    epoch=epoch,
+                    step=step + 1,
+                    batch_size=args.batch_size,
+                    coordinator=coordinator,
+                    use_lora=(args.lora_rank > 0),
+                )
+                coordinator.print_on_master(
+                    f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+                )
+
+                if args.use_neft:
+                    coordinator.print_on_master("Activate NEFTune.")
+                    model, handle = activate_neftune(model)
+
+                # Delete cache.
+                # del batch, batch_labels, batch_output, loss
+                accelerator.empty_cache()

             # the continue epochs are not resumed, so we need to reset the sampler start index and start step
             dataloader.sampler.set_start_index(start_index=0)

@@ -522,6 +528,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
     )
+    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

     # Additional arguments for benchmark.
     parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
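The training-script changes above thread a new `--lora_rank` argument through a `peft.LoraConfig` into `booster.enable_lora`, and forward the same flag to checkpoint saving. A standalone sketch of what that LoraConfig does to a causal LM, using plain PEFT rather than the booster path from the diff, is shown below; the tiny model ID is only an example so the script stays cheap to run.

```python
# Standalone sketch (plain PEFT, not the booster.enable_lora path in the diff):
# wrap a causal LM with the LoraConfig hyperparameters added above and inspect
# how few parameters remain trainable.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, lora_config)

# Only the injected low-rank A/B matrices are left trainable.
model.print_trainable_parameters()
```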
applications/ColossalChat/.gitignore (1 changed line)

@@ -158,6 +158,7 @@ temp/
 applications/ColossalChat/logs
 applications/ColossalChat/models
 applications/ColossalChat/sft_data
+applications/ColossalChat/kto_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp
@ -7,31 +7,23 @@
|
|||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
- [Table of Contents](#table-of-contents)
|
- [Table of Contents](#table-of-contents)
|
||||||
- [What is ColossalChat and Coati ?](#what-is-colossalchat-and-coati-)
|
- [What is ColossalChat?](#what-is-colossalchat)
|
||||||
- [Online demo](#online-demo)
|
- [Online demo](#online-demo)
|
||||||
- [Install](#install)
|
- [Install](#install)
|
||||||
- [Install the environment](#install-the-environment)
|
- [Install the environment](#install-the-environment)
|
||||||
- [Install the Transformers](#install-the-transformers)
|
- [Install the Transformers](#install-the-transformers)
|
||||||
- [How to use?](#how-to-use)
|
- [Introduction](#introduction)
|
||||||
- [Supervised datasets collection](#step-1-data-collection)
|
- [Supervised datasets collection](#step-1-data-collection)
|
||||||
- [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning)
|
- [RLHF Training Stage1 - Supervised instructs tuning](#rlhf-training-stage1---supervised-instructs-tuning)
|
||||||
- [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model)
|
- [RLHF Training Stage2 - Training reward model](#rlhf-training-stage2---training-reward-model)
|
||||||
- [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization)
|
- [RLHF Training Stage3 - Training model with reinforcement learning by human feedback](#rlhf-training-stage3---proximal-policy-optimization)
|
||||||
|
- [Alternative Option for RLHF: GRPO](#alternative-option-for-rlhf-group-relative-policy-optimization-grpo)
|
||||||
|
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
|
||||||
|
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
|
||||||
|
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
|
||||||
|
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
||||||
|
- [SFT for DeepSeek V3/R1](#sft-for-deepseek-v3)
|
||||||
- [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training)
|
- [Inference Quantization and Serving - After Training](#inference-quantization-and-serving---after-training)
|
||||||
- [Coati7B examples](#coati7b-examples)
|
|
||||||
- [Generation](#generation)
|
|
||||||
- [Open QA](#open-qa)
|
|
||||||
- [Limitation for LLaMA-finetuned models](#limitation)
|
|
||||||
- [Limitation of dataset](#limitation)
|
|
||||||
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
|
|
||||||
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
|
|
||||||
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
|
|
||||||
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
|
||||||
- [FAQ](#faq)
|
|
||||||
- [How to save/load checkpoint](#faq)
|
|
||||||
- [How to train with limited resources](#faq)
|
|
||||||
- [The Plan](#the-plan)
|
|
||||||
- [Real-time progress](#real-time-progress)
|
|
||||||
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
|
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
|
||||||
- [Quick Preview](#quick-preview)
|
- [Quick Preview](#quick-preview)
|
||||||
- [Authors](#authors)
|
- [Authors](#authors)
|
||||||
@ -40,9 +32,9 @@
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## What Is ColossalChat And Coati ?
|
## What is ColossalChat?
|
||||||
|
|
||||||
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat) is the project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) project.
|
[ColossalChat](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ColossalChat) is a project to implement LLM with RLHF, powered by the [Colossal-AI](https://github.com/hpcaitech/ColossalAI).
|
||||||
|
|
||||||
Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project.
|
Coati stands for `ColossalAI Talking Intelligence`. It is the name for the module implemented in this project and is also the name of the large language model developed by the ColossalChat project.
|
||||||
|
|
||||||
@ -53,8 +45,6 @@ The Coati package provides a unified large language model framework that has imp
|
|||||||
- Supervised instructions fine-tuning
|
- Supervised instructions fine-tuning
|
||||||
- Training reward model
|
- Training reward model
|
||||||
- Reinforcement learning with human feedback
|
- Reinforcement learning with human feedback
|
||||||
- Quantization inference
|
|
||||||
- Fast model deploying
|
|
||||||
- Perfectly integrated with the Hugging Face ecosystem, a high degree of model customization
|
- Perfectly integrated with the Hugging Face ecosystem, a high degree of model customization
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
@ -114,77 +104,16 @@ cd $COLOSSAL_AI_ROOT/applications/ColossalChat
|
|||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
## How To Use?
|
## Introduction
|
||||||
|
|
||||||
### RLHF Training Stage1 - Supervised Instructs Tuning
|
### RLHF Training Stage1 - Supervised Instructs Tuning
|
||||||
|
|
||||||
Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. Here's a detailed guide on how to SFT your LLM with ColossalChat. More details can be found in [example guideline](./examples/README.md).
|
Stage1 is supervised instructs fine-tuning (SFT). This step is a crucial part of the RLHF training process, as it involves training a machine learning model using human-provided instructions to learn the initial behavior for the task at hand. More details can be found in [example guideline](./examples/README.md).
|
||||||
|
|
||||||
#### Step 1: Data Collection
|
|
||||||
The first step in Stage 1 is to collect a dataset of human demonstrations of the following format.
|
|
||||||
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{"messages":
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"from": "user",
|
|
||||||
"content": "what are some pranks with a pen i can do?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "assistant",
|
|
||||||
"content": "Are you looking for practical joke ideas?"
|
|
||||||
},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Preprocessing
|
|
||||||
Once you have collected your SFT dataset, you will need to preprocess it. This involves four steps: data cleaning, data deduplication, formatting and tokenization. In this section, we will focus on formatting and tokenization.
|
|
||||||
|
|
||||||
In this code, we provide a flexible way for users to set the conversation template for formatting chat data using Huggingface's newest feature--- chat template. Please follow the [example guideline](./examples/README.md) on how to format and tokenize data.
|
|
||||||
|
|
||||||
#### Step 3: Training
|
|
||||||
Choose a suitable model architecture for your task. Note that your model should be compatible with the tokenizer that you used to tokenize the SFT dataset. You can run [train_sft.sh](./examples/training_scripts/train_sft.sh) to start a supervised instructs fine-tuning. More details can be found in [example guideline](./examples/README.md).
|
|
||||||
|
|
||||||
### RLHF Training Stage2 - Training Reward Model
|
### RLHF Training Stage2 - Training Reward Model
|
||||||
|
|
||||||
Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model.
|
Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model.
|
||||||
|
|
||||||
#### Step 1: Data Collection
|
|
||||||
Below shows the preference dataset format used in training the reward model.
|
|
||||||
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{"context": [
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"content": "Introduce butterflies species in Oregon."
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"chosen": [
|
|
||||||
{
|
|
||||||
"from": "assistant",
|
|
||||||
"content": "About 150 species of butterflies live in Oregon, with about 100 species are moths..."
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"rejected": [
|
|
||||||
{
|
|
||||||
"from": "assistant",
|
|
||||||
"content": "Are you interested in just the common butterflies? There are a few common ones which will be easy to find..."
|
|
||||||
},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Preprocessing
|
|
||||||
Similar to the second step in the previous stage, we format the reward data into the same structured format as used in step 2 of the SFT stage. You can run [prepare_preference_dataset.sh](./examples/data_preparation_scripts/prepare_preference_dataset.sh) to prepare the preference data for reward model training.
|
|
||||||
|
|
||||||
#### Step 3: Training
|
|
||||||
You can run [train_rm.sh](./examples/training_scripts/train_rm.sh) to start the reward model training. More details can be found in [example guideline](./examples/README.md).
|
|
||||||
|
|
||||||
### RLHF Training Stage3 - Proximal Policy Optimization
|
### RLHF Training Stage3 - Proximal Policy Optimization
|
||||||
|
|
||||||
In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimization (PPO), which is the most complex part of the training process:
|
In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimization (PPO), which is the most complex part of the training process:
|
||||||
@ -193,85 +122,25 @@ In stage3 we will use reinforcement learning algorithm--- Proximal Policy Optimi
|
|||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
|
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/stage-3.jpeg" width=800/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
#### Step 1: Data Collection
|
|
||||||
PPO uses two kind of training data--- the prompt data and the sft data (optional). The first dataset is mandatory, data samples within the prompt dataset ends with a line from "human" and thus the "assistant" needs to generate a response to answer to the "human". Note that you can still use conversation that ends with a line from the "assistant", in that case, the last line will be dropped. Here is an example of the prompt dataset format.
|
|
||||||
|
|
||||||
```json
|
### Alternative Option For RLHF: Direct Preference Optimization (DPO)
|
||||||
[
|
|
||||||
{"messages":
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"content": "what are some pranks with a pen i can do?"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Data Preprocessing
|
|
||||||
To prepare the prompt dataset for PPO training, simply run [prepare_prompt_dataset.sh](./examples/data_preparation_scripts/prepare_prompt_dataset.sh)
|
|
||||||
|
|
||||||
#### Step 3: Training
|
|
||||||
You can run the [train_ppo.sh](./examples/training_scripts/train_ppo.sh) to start PPO training. Here are some unique arguments for PPO, please refer to the training configuration section for other training configuration. More detais can be found in [example guideline](./examples/README.md).
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--pretrain $PRETRAINED_MODEL_PATH \
|
|
||||||
--rm_pretrain $PRETRAINED_MODEL_PATH \ # reward model architectual
|
|
||||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
|
||||||
--rm_checkpoint_path $REWARD_MODEL_PATH \ # reward model checkpoint path
|
|
||||||
--prompt_dataset ${prompt_dataset[@]} \ # List of string, the prompt dataset
|
|
||||||
--ptx_dataset ${ptx_dataset[@]} \ # List of string, the SFT data used in the SFT stage
|
|
||||||
--ptx_batch_size 1 \ # batch size for calculate ptx loss
|
|
||||||
--ptx_coef 0.0 \ # none-zero if ptx loss is enable
|
|
||||||
--num_episodes 2000 \ # number of episodes to train
|
|
||||||
--num_collect_steps 1 \
|
|
||||||
--num_update_steps 1 \
|
|
||||||
--experience_batch_size 8 \
|
|
||||||
--train_batch_size 4 \
|
|
||||||
--accumulation_steps 2
|
|
||||||
```
|
|
||||||
|
|
||||||
Each episode has two phases, the collect phase and the update phase. During the collect phase, we will collect experiences (answers generated by actor), store those in ExperienceBuffer. Then data in ExperienceBuffer is used during the update phase to update parameter of actor and critic.
|
|
||||||
|
|
||||||
- Without tensor parallelism,
|
|
||||||
```
|
|
||||||
experience buffer size
|
|
||||||
= num_process * num_collect_steps * experience_batch_size
|
|
||||||
= train_batch_size * accumulation_steps * num_process
|
|
||||||
```
|
|
||||||
|
|
||||||
- With tensor parallelism,
|
|
||||||
```
|
|
||||||
num_tp_group = num_process / tp
|
|
||||||
experience buffer size
|
|
||||||
= num_tp_group * num_collect_steps * experience_batch_size
|
|
||||||
= train_batch_size * accumulation_steps * num_tp_group
|
|
||||||
```
|
|
||||||
|
|
||||||
## Alternative Option For RLHF: Direct Preference Optimization (DPO)
|
|
||||||
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO. Read this [README](./examples/README.md) for more information.
|
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in this [paper](https://arxiv.org/abs/2305.18290), DPO offers an low-cost way to perform RLHF and usually request less computation resources compares to PPO. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
### DPO Training Stage1 - Supervised Instructs Tuning
|
### Alternative Option For RLHF: Simple Preference Optimization (SimPO)
|
||||||
|
|
||||||
Please refer the [sft section](#dpo-training-stage1---supervised-instructs-tuning) in the PPO part.
|
|
||||||
|
|
||||||
### DPO Training Stage2 - DPO Training
|
|
||||||
#### Step 1: Data Collection & Preparation
|
|
||||||
For DPO training, you only need the preference dataset. Please follow the instruction in the [preference dataset preparation section](#rlhf-training-stage2---training-reward-model) to prepare the preference data for DPO training.
|
|
||||||
|
|
||||||
#### Step 2: Training
|
|
||||||
You can run the [train_dpo.sh](./examples/training_scripts/train_dpo.sh) to start DPO training. More detais can be found in [example guideline](./examples/README.md).
|
|
||||||
|
|
||||||
## Alternative Option For RLHF: Simple Preference Optimization (SimPO)
|
|
||||||
Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but it abandons the use of the reference model, which makes the training more efficient. It also adds a reward shaping term called target reward margin to enhance training stability. It also use length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
|
Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2405.14734) is similar to DPO but it abandons the use of the reference model, which makes the training more efficient. It also adds a reward shaping term called target reward margin to enhance training stability. It also use length normalization to better align with the inference process. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
|
### Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
|
||||||
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information.
|
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
|
### Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
|
||||||
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.
|
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
|
### Alternative Option For RLHF: Group Relative Policy Optimization (GRPO)
|
||||||
|
We support the main algorithm used to train DeepSeek R1 model, a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
|
### SFT for DeepSeek V3
|
||||||
|
We support fine-tuning DeepSeek V3/R1 model with LoRA. Read this [README](./examples/README.md) for more information.
|
||||||
|
|
||||||
### Inference Quantization and Serving - After Training
|
### Inference Quantization and Serving - After Training
|
||||||
|
|
||||||
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
|
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
|
||||||
@ -281,167 +150,7 @@ We support 8-bit quantization (RTN), 4-bit quantization (GPTQ), and FP16 inferen
|
|||||||
Online inference server scripts can help you deploy your own services.
|
Online inference server scripts can help you deploy your own services.
|
||||||
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
|
For more details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
|
||||||
|
|
||||||
## Coati7B examples
|
|
||||||
|
|
||||||
### Generation
|
|
||||||
|
|
||||||
<details><summary><b>E-mail</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>coding</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>regex</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Tex</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>writing</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Table</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### Open QA
|
|
||||||
|
|
||||||
<details><summary><b>Game</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Travel</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Physical</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Chemical</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
<details><summary><b>Economy</b></summary>
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild/blob/main/comparison.md).
|
|
||||||
|
|
||||||
### Limitation

<details><summary><b>Limitation for LLaMA-finetuned models</b></summary>

- Both Alpaca and ColossalChat are based on LLaMA. It is hard to compensate for knowledge missing from the pre-training stage.
- Lack of counting ability: cannot count the number of items in a list.
- Lack of logic (reasoning and calculation).
- Tendency to repeat the last sentence (failure to produce the end token).
- Poor multilingual results: LLaMA is mainly trained on English datasets (generation performs better than QA).
</details>

<details><summary><b>Limitation of dataset</b></summary>

- Lack of summarization ability: no such instructions in the finetune datasets.
- Lack of multi-turn chat: no such instructions in the finetune datasets.
- Lack of self-recognition: no such instructions in the finetune datasets.
- Lack of safety:
  - When the input contains fake facts, the model makes up false facts and explanations.
  - Cannot abide by OpenAI's policy: since the prompts were generated through the OpenAI API, which always abides by its policy, there are no violation cases in the datasets.
</details>

## FAQ

<details><summary><b>How to save/load checkpoint</b></summary>

We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.

- Option 1: Save the model weights, model config and generation config (Note: tokenizer will not be saved), which can be loaded using HF's from_pretrained method.

```python
# if use lora, you can choose to merge lora weights before saving
if args.lora_rank > 0 and args.merge_lora_weights:
    from coati.models.lora import LORA_MANAGER

    # NOTE: set model to eval to merge LoRA weights
    LORA_MANAGER.merge_weights = True
    model.eval()
# save model checkpoint after fitting on only rank0
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
```

- Option 2: Save the model weights, model config, generation config, as well as the optimizer, learning rate scheduler and running states (Note: tokenizer will not be saved), which are needed for resuming training.

```python
from coati.utils import save_checkpoint

# save model checkpoint after fitting on only rank0
save_checkpoint(
    save_dir=actor_save_dir,
    booster=actor_booster,
    model=model,
    optimizer=optim,
    lr_scheduler=lr_scheduler,
    epoch=0,
    step=step,
    batch_size=train_batch_size,
    coordinator=coordinator,
)
```

To load the saved checkpoint:

```python
from coati.utils import load_checkpoint

start_epoch, start_step, sampler_start_idx = load_checkpoint(
    load_dir=checkpoint_path,
    booster=booster,
    model=model,
    optimizer=optim,
    lr_scheduler=lr_scheduler,
)
```
</details>

<details><summary><b>How to train with limited resources</b></summary>

Here are some suggestions that can allow you to train a 7B model on a single or multiple consumer-grade GPUs.

`batch_size`, `lora_rank` and `grad_checkpoint` are the most important parameters for training the model successfully. To maintain a decent effective batch size for gradient calculation, consider increasing `accumulation_steps` and reducing `batch_size` on each rank.

If you only have a single 24G GPU, using LoRA together with the "zero2-cpu" strategy will generally be sufficient.

`gemini` and `gemini-auto` can enable a single 24G GPU to train the whole model without LoRA if you have sufficient CPU memory, but these strategies do not support gradient accumulation.

If you have multiple GPUs that each have very limited VRAM, say 8GB, you can try the `3d` plugin option, which supports tensor parallelism; set `--tp` to the number of GPUs you have. A rough sketch of how these options map to booster plugins follows this section.
</details>

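As a rough illustration of the options above, the sketch below shows how the memory-saving strategies map onto Colossal-AI booster plugins. This is a minimal sketch, assuming these constructor arguments; the exact flags and plugin wiring in the training scripts may differ, so treat it as illustrative only.

```python
# Minimal sketch (not from the training scripts): mapping the strategies above to plugins.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin

# "zero2-cpu": ZeRO stage 2 with CPU offload, good for a single 24G GPU plus LoRA
plugin = LowLevelZeroPlugin(stage=2, cpu_offload=True)

# "gemini": heterogeneous memory management; trains the whole model without LoRA
# but does not support gradient accumulation
# plugin = GeminiPlugin(placement_policy="static")

# "3d": tensor parallelism across several small-VRAM GPUs (--tp = number of GPUs)
# plugin = HybridParallelPlugin(tp_size=4, pp_size=1, zero_stage=1)

booster = Booster(plugin=plugin)
```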
### Real-time progress

You can follow our progress on the GitHub [project board](https://github.com/orgs/hpcaitech/projects/17/views/1).

## Invitation to open-source contribution

Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing power, datasets, and models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!

You may contact us or participate in the following ways:
@ -485,25 +194,17 @@ Thanks so much to all of our amazing contributors!

- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
- Keep a sufficiently high running speed

| Model Pair    | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
| :-----------: | :------------------: | :------------------: |
| Better Cases  | 38 ⚔ **41**          | **45** ⚔ 33          |
| Win Rate      | 48% ⚔ **52%**        | **58%** ⚔ 42%        |
| Average Score | 7.06 ⚔ **7.13**      | **7.31** ⚔ 6.82      |

- Our Coati-7B model performs better than Alpaca-7B when using GPT-4 to evaluate model performance. The Coati-7B model we evaluated is an older version trained a few weeks ago; a new version is around the corner.

## Authors

Coati is developed by the ColossalAI Team:

- [ver217](https://github.com/ver217) Leading the project while contributing to the main framework (System Lead).
- [Tong Li](https://github.com/TongLi3701) Leading the project while contributing to the main framework (Algorithm Lead).
- [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with an updated acceleration framework; added support for DPO, SimPO and ORPO.
- [FrankLeeeee](https://github.com/FrankLeeeee) Providing ML infra support and also taking charge of both front-end and back-end development.
- [htzhou](https://github.com/ht-zhou) Contributing to the algorithm and development for RM and PPO training.
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.

The PhD students from the [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project:
- [Zangwei Zheng](https://github.com/zhengzangw)

@ -512,7 +213,6 @@ The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contribute

We also appreciate the valuable suggestions provided by [Jian Hu](https://github.com/hijkzzz) regarding the convergence of the PPO algorithm.

## Citations

```bibtex
@article{Hu2021LoRALA,
    title   = {LoRA: Low-Rank Adaptation of Large Language Models},

@ -583,8 +283,22 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github

    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2403.07691},
}

@misc{shao2024deepseekmathpushinglimitsmathematical,
    title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
    author={Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Xiao Bi and Haowei Zhang and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
    year={2024},
    eprint={2402.03300},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2402.03300},
}

@misc{logic-rl,
    author = {Tian Xie and Qingnan Ren and Yuqian Hong and Zitian Gao and Haoming Luo},
    title = {Logic-RL},
    howpublished = {https://github.com/Unakar/Logic-RL},
    note = {Accessed: 2025-02-03},
    year = {2025}
}
```

## Licenses

Coati is licensed under the [Apache 2.0 License](LICENSE).

@ -141,7 +141,7 @@ def setup_conversation_template(

        pass
    except ValueError as e:
        raise ValueError(e)
    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w", encoding="utf8") as f:
            logger.info(f"Successfully generated a conversation tempalte config, save to {save_path}.")

@ -8,6 +8,7 @@ import os

from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union

import jsonlines
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len

@ -155,13 +156,14 @@ class DataCollatorForPromptDataset(DataCollatorForSupervisedDataset):

            `input_ids`: `torch.Tensor` of shape (bsz, max_len);
            `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
        """
        gt_answer = [ins.get("gt_answer", None) for ins in instances]
        instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
        ret = super().__call__(instances=instances)
        input_ids = F.pad(
            ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
        )
        attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}


@dataclass

@ -344,3 +346,77 @@ class StatefulDistributedSampler(DistributedSampler):

    def set_start_index(self, start_index: int) -> None:
        self.start_index = start_index


def apply_chat_template_and_mask(
    tokenizer: PreTrainedTokenizer,
    chat: List[Dict[str, str]],
    max_length: Optional[int] = None,
    padding: bool = True,
    truncation: bool = True,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    tokens = []
    assistant_mask = []
    for i, msg in enumerate(chat):
        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
        # remove unexpected bos token
        if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
            msg_tokens = msg_tokens[1:]
        tokens.extend(msg_tokens)
        if msg["role"] == "assistant":
            assistant_mask.extend([True] * len(msg_tokens))
        else:
            assistant_mask.extend([False] * len(msg_tokens))
    attention_mask = [1] * len(tokens)
    if max_length is not None:
        if padding and len(tokens) < max_length:
            to_pad = max_length - len(tokens)
            if tokenizer.padding_side == "right":
                tokens.extend([tokenizer.pad_token_id] * to_pad)
                assistant_mask.extend([False] * to_pad)
                attention_mask.extend([0] * to_pad)
            else:
                tokens = [tokenizer.pad_token_id] * to_pad + tokens
                assistant_mask = [False] * to_pad + assistant_mask
                attention_mask = [0] * to_pad + attention_mask
        if truncation and len(tokens) > max_length:
            tokens = tokens[:max_length]
            assistant_mask = assistant_mask[:max_length]
            attention_mask = attention_mask[:max_length]
    input_ids = torch.tensor(tokens, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    labels = input_ids.clone()
    labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


class RawConversationDataset(Dataset):
    """
    Raw conversation dataset.
    Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
        self.tokenizer = tokenizer
        self.raw_texts = []
        with jsonlines.open(input_file) as f:
            for line in f:
                self.raw_texts.append(line)
        self.tokenized_texts = [None] * len(self.raw_texts)
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.raw_texts)

    def __getitem__(self, index: int):
        if self.tokenized_texts[index] is None:
            message = self.raw_texts[index]
            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
            self.tokenized_texts[index] = dict(tokens)
        return self.tokenized_texts[index]
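A rough usage sketch of the new `apply_chat_template_and_mask` helper added above. The checkpoint name is a placeholder assumption; the exact token counts depend on the tokenizer's chat template.

```python
# Hedged usage sketch (not part of the diff); the model name below is hypothetical.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
chat = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
batch = apply_chat_template_and_mask(tokenizer, chat, max_length=64)
# Non-assistant tokens (and padding) are set to -100 in `labels`,
# so the SFT loss is computed on the assistant's reply only.
print(batch["input_ids"].shape, (batch["labels"] != -100).sum())
```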
@ -147,7 +147,6 @@ def tokenize_prompt(

        ignore_index: the ignore index when calculate loss during training
        max_length: the maximum context length
    """
    messages = data_point["messages"]
    template = deepcopy(conversation_template)
    template.messages = []

@ -167,7 +166,6 @@ def tokenize_prompt(

    if len(template.messages) % 2 != 1:
        # exclude the answer if provided. keep only the prompt
        template.messages = template.messages[:-1]

    # Prepare data
    prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
    tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]

@ -185,12 +183,21 @@ def tokenize_prompt(

        )

    # `inputs_decode` can be used to check whether the tokenization method is true.
    if "gt_answer" in data_point:
        return dict(
            input_ids=tokenized,
            inputs_decode=prompt,
            seq_length=len(tokenized),
            seq_category=data_point["category"] if "category" in data_point else "None",
            gt_answer=data_point["gt_answer"],
        )
    else:
        return dict(
            input_ids=tokenized,
            inputs_decode=prompt,
            seq_length=len(tokenized),
            seq_category=data_point["category"] if "category" in data_point else "None",
        )


def apply_rlhf_data_format(template: Conversation, tokenizer: Any):

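To illustrate the optional `gt_answer` field threaded through prompt tokenization above, a hypothetical prompt-dataset entry might look like the following; only the field names come from the diff, the values are made up.

```python
# Hypothetical prompt data point; field names follow tokenize_prompt, values are illustrative.
data_point = {
    "messages": [{"role": "user", "content": "Compute 3 * 7."}],
    "category": "math",
    "gt_answer": "21",  # verifiable ground-truth answer consumed by the RLVR reward model
}
```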
@ -27,6 +27,8 @@ class NaiveExperienceBuffer(ExperienceBuffer):

        self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []
        self.rng_sequence = []
        self.ptr = 0

    @torch.no_grad()
    def append(self, experience: Experience) -> None:

@ -40,6 +42,9 @@ class NaiveExperienceBuffer(ExperienceBuffer):

        if samples_to_remove > 0:
            logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
            self.items = self.items[samples_to_remove:]
        self.rng_sequence = [i for i in range(len(self.items))]
        random.shuffle(self.rng_sequence)
        self.ptr = 0

    def clear(self) -> None:
        self.items.clear()

@ -52,7 +57,10 @@ class NaiveExperienceBuffer(ExperienceBuffer):

        Returns:
            A batch of sampled experiences.
        """
        items = []
        for _ in range(self.sample_batch_size):
            self.ptr = (self.ptr + 1) % len(self.items)
            items.append(self.items[self.rng_sequence[self.ptr]])
        experience = make_experience_batch(items)
        if self.cpu_offload:
            experience.to_device(self.target_device)

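The buffer change above swaps independent random draws for a pre-shuffled index sequence walked by a pointer, so within one sweep every stored experience is visited before any repeats. A toy sketch of the same idea (names are illustrative only):

```python
# Toy sketch (not from the diff) of pointer-based sampling over a pre-shuffled index sequence.
import random

items = ["exp0", "exp1", "exp2", "exp3", "exp4", "exp5"]
rng_sequence = list(range(len(items)))
random.shuffle(rng_sequence)

ptr, sample_batch_size = 0, 2
for _ in range(3):  # three successive sampling calls
    batch = []
    for _ in range(sample_batch_size):
        ptr = (ptr + 1) % len(items)
        batch.append(items[rng_sequence[ptr]])
    print(batch)
```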
@ -2,6 +2,8 @@

experience maker.
"""

from typing import Any

import torch
import torch.nn.functional as F
from coati.dataset.utils import find_first_occurrence_subsequence

@ -38,14 +40,27 @@ class NaiveExperienceMaker(ExperienceMaker):

        kl_coef: float = 0.01,
        gamma: float = 1.0,
        lam: float = 0.95,
        use_grpo: bool = False,
        num_generation: int = 8,
        inference_batch_size: int = None,
        logits_forward_batch_size: int = 2,
    ) -> None:
        super().__init__(actor, critic, reward_model, initial_model)
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef
        self.gamma = gamma
        self.lam = lam
        self.use_grpo = use_grpo
        self.num_generation = num_generation
        self.inference_batch_size = inference_batch_size
        self.logits_forward_batch_size = logits_forward_batch_size
        if not self.use_grpo:
            assert self.critic is not None, "Critic model is required for PPO training."
        else:
            assert self.critic is None, "Critic model is not required for GRPO training."
            assert self.num_generation > 1, "Number of generations should be greater than 1 for GRPO training."

    @torch.inference_mode()
    def calculate_advantage(self, value: torch.Tensor, reward: torch.Tensor, num_actions: int) -> torch.Tensor:
        """
        Calculates the advantage values for each action based on the value and reward tensors.

@ -69,7 +84,9 @@ class NaiveExperienceMaker(ExperienceMaker):

        return advantages

    @torch.no_grad()
    def make_experience(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, gt_answer: Any = None, **generate_kwargs
    ) -> Experience:
        """
        Generates an experience using the given input_ids and attention_mask.

@ -83,98 +100,204 @@ class NaiveExperienceMaker(ExperienceMaker):

        """
        self.actor.eval()
        if self.critic:
            self.critic.eval()
        self.initial_model.eval()
        self.reward_model.eval()
        pad_token_id = self.tokenizer.pad_token_id

        stop_token_ids = generate_kwargs.get("stop_token_ids", None)
        if isinstance(stop_token_ids, int):
            stop_token_ids = [[stop_token_ids]]
        elif isinstance(stop_token_ids[0], int):
            stop_token_ids = [stop_token_ids]
        elif isinstance(stop_token_ids[0], list):
            pass
        else:
            raise ValueError(
                f"stop_token_ids should be a list of list of integers, a list of integers or an integers. got {stop_token_ids}"
            )
        generate_kwargs["stop_token_ids"] = stop_token_ids
        torch.manual_seed(41)  # for tp, gurantee the same input for reward model

        if self.use_grpo and self.num_generation > 1:
            # Generate multiple responses for each prompt
            input_ids = input_ids.repeat_interleave(self.num_generation, dim=0)
            gt_answer_tmp = []
            for t in gt_answer:
                gt_answer_tmp.extend([t] * self.num_generation)
            gt_answer = gt_answer_tmp
        if self.inference_batch_size is None:
            self.inference_batch_size = input_ids.size(0)

        batch_sequences = []
        batch_input_ids_rm = []
        batch_attention_mask_rm = []
        batch_attention_mask = []
        batch_r = []
        batch_action_log_probs = []
        batch_base_action_log_probs = []
        batch_action_mask = []
        num_actions = 0

        for inference_mini_batch_id in range(0, input_ids.size(0), self.inference_batch_size):
            s, e = inference_mini_batch_id, inference_mini_batch_id + self.inference_batch_size
            if input_ids[s:e].size(0) == 0:
                break
            sequences = generate(self.actor, input_ids[s:e], self.tokenizer, **generate_kwargs)
            # pad to max_len, you don't want to get an OOM error after a thousands of steps
            sequences = F.pad(sequences, (0, generate_kwargs["max_length"] - sequences.size(1)), value=pad_token_id)

            # Pad to max length
            sequence_length = sequences.size(1)

            # Calculate auxiliary tensors
            attention_mask = None
            if pad_token_id is not None:
                attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

            input_len = input_ids.size(1)
            if stop_token_ids is None:
                # End the sequence with eos token
                eos_token_id = self.tokenizer.eos_token_id
                if eos_token_id is None:
                    action_mask = torch.ones_like(sequences, dtype=torch.bool)
                else:
                    # Left padding may be applied, only mask action
                    action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
                    action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # include eos token and input
            else:
                # stop_token_ids are given, generation ends with stop_token_ids
                action_mask = torch.ones_like(sequences, dtype=torch.bool)
                for i in range(sequences.size(0)):
                    stop_token_pos = [
                        find_first_occurrence_subsequence(
                            sequences[i][input_len:], torch.tensor(stop_token_id).to(sequences.device)
                        )
                        for stop_token_id in stop_token_ids
                    ]
                    stop_index = min([i for i in stop_token_pos if i != -1], default=-1)
                    stop_token_id = stop_token_ids[stop_token_pos.index(stop_index)]
                    if stop_index == -1:
                        # Sequence does not contain stop_token_ids, this should never happen BTW
                        logger.warning(
                            "Generated sequence does not contain stop_token_ids. Please check your chat template config"
                        )
                        print(self.tokenizer.decode(sequences[i], skip_special_tokens=True))
                    else:
                        # Keep stop tokens
                        stop_index = input_len + stop_index
                        action_mask[i, stop_index + len(stop_token_id) :] = False

            generation_end_index = (action_mask == True).sum(dim=-1) - 1
            action_mask[:, :input_len] = False
            action_mask = action_mask[:, 1:]
            action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
            num_actions = action_mask.size(1)
            torch.cuda.empty_cache()
            with torch.inference_mode():
                actor_output = []
                base_model_output = []
                for i in range(0, sequences.size(0), self.logits_forward_batch_size):
                    actor_output.append(
                        self.actor(
                            input_ids=sequences[i : i + self.logits_forward_batch_size],
                            attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
                            use_cache=False,
                        )["logits"]
                    )
                    base_model_output.append(
                        self.initial_model(
                            input_ids=sequences[i : i + self.logits_forward_batch_size],
                            attention_mask=attention_mask[i : i + self.logits_forward_batch_size],
                            use_cache=False,
                        )["logits"]
                    )
                actor_output = torch.cat(actor_output, dim=0)
                base_model_output = torch.cat(base_model_output, dim=0)
            action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
            base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)

            # Convert to right padding for the reward model and the critic model
            input_ids_rm = torch.zeros_like(sequences, device=sequences.device)
            response_start = []
            response_end = []
            attention_mask_rm = torch.zeros_like(sequences, device=sequences.device)
            for i in range(sequences.size(0)):
                sequence = sequences[i]
                bos_index = (sequence != pad_token_id).nonzero().reshape([-1])[0]
                eos_index = generation_end_index[i] + 1  # include the stop token
                sequence_to_pad = sequence[bos_index:eos_index]
                response_start.append(input_len - bos_index)
                response_end.append(eos_index - bos_index)
                sequence_padded = F.pad(
                    sequence_to_pad, (0, sequence_length - sequence_to_pad.size(0)), value=self.tokenizer.pad_token_id
                )
                input_ids_rm[i] = sequence_padded
                if sequence_length - sequence_to_pad.size(0) > 0:
                    attention_mask_rm[i, : sequence_to_pad.size(0) + 1] = 1
                else:
                    attention_mask_rm[i, :] = 1
            attention_mask_rm = attention_mask_rm.to(dtype=torch.bool)

            r = self.reward_model(
                input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
                attention_mask=attention_mask_rm.to(device=sequences.device),
                response_start=response_start,
                response_end=response_end,
                gt_answer=gt_answer[s:e],
            )

            batch_sequences.append(sequences)
            batch_input_ids_rm.append(input_ids_rm)
            batch_attention_mask_rm.append(attention_mask_rm)
            batch_attention_mask.append(attention_mask)
            batch_r.append(r)
            batch_action_log_probs.append(action_log_probs.cpu())
            batch_base_action_log_probs.append(base_action_log_probs.cpu())
            batch_action_mask.append(action_mask)

        sequences = torch.cat(batch_sequences, dim=0)
        input_ids_rm = torch.cat(batch_input_ids_rm, dim=0)
        attention_mask_rm = torch.cat(batch_attention_mask_rm, dim=0)
        attention_mask = torch.cat(batch_attention_mask, dim=0)
        r = torch.cat(batch_r, dim=0)
        action_log_probs = torch.cat(batch_action_log_probs, dim=0).to(sequences.device)
        base_action_log_probs = torch.cat(batch_base_action_log_probs, dim=0).to(sequences.device)
        action_mask = torch.cat(batch_action_mask, dim=0).to(sequences.device)
        if not self.use_grpo:
            value = self.critic(
                input_ids=input_ids_rm.to(dtype=torch.long, device=sequences.device),
                attention_mask=attention_mask_rm.to(device=sequences.device),
            )
            value = value[:, -num_actions:] * action_mask
            reward, kl = compute_reward(
                r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask
            )
            advantages = self.calculate_advantage(value, reward, num_actions)
            advantages = advantages.detach()
            value = value.detach()
        else:
            # GRPO advantage calculation
            kl = torch.sum(
                -self.kl_coef * (action_log_probs - base_action_log_probs) * action_mask, dim=-1
            ) / torch.sum(
                action_mask, dim=-1
            )  # address numerical instability issue
            r = kl + r
            mean_gr = r.view(-1, self.num_generation).mean(dim=1)
            std_gr = r.view(-1, self.num_generation).std(dim=1)
            mean_gr = mean_gr.repeat_interleave(self.num_generation, dim=0)
            std_gr = std_gr.repeat_interleave(self.num_generation, dim=0)
            advantages = (r - mean_gr) / (std_gr + 1e-4)
            value = r.detach()  # dummy value
        r = r.detach()
        return Experience(
            sequences.cpu(),
            action_log_probs.cpu(),
            value.cpu(),
            r.cpu(),
            kl.cpu(),
            advantages.cpu(),
            attention_mask.cpu(),
            action_mask.cpu(),
        )

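In the GRPO branch above, each prompt's `num_generation` samples are scored against the group mean and standard deviation instead of a learned critic. A small numeric sketch of that normalization, with made-up reward values:

```python
# Numeric sketch of the group-relative advantage used in the GRPO branch above.
import torch

num_generation = 4
# rewards (already including the per-sequence KL term) for 2 prompts x 4 generations
r = torch.tensor([1.0, 0.0, 1.0, 0.0,
                  0.0, 0.0, 1.0, 1.0])
mean_gr = r.view(-1, num_generation).mean(dim=1).repeat_interleave(num_generation)
std_gr = r.view(-1, num_generation).std(dim=1).repeat_interleave(num_generation)
advantages = (r - mean_gr) / (std_gr + 1e-4)
print(advantages.view(-1, num_generation))
```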
@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn

from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .rlvr_reward_model import RLVRRewardModel
from .utils import disable_dropout

__all__ = [
    "BaseModel",
    "Critic",
    "RewardModel",
    "RLVRRewardModel",
    "PolicyLoss",
    "ValueLoss",
    "LogSigLoss",

@ -1,3 +1,4 @@

import copy
from typing import Any, Callable, List, Optional

import torch

@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:

    return model_kwargs


def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
    model_kwargs["input_ids"] = input_ids
    return model_kwargs


def _sample(
    model: Any,
    tokenizer: Any,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = True,

@ -137,8 +139,8 @@ def _sample(

    if max_new_tokens is None:
        max_new_tokens = max_length - context_length
    if context_length + max_new_tokens > max_length or max_new_tokens == 0:
        print("Exeeded length limitation")
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    past = None

@ -183,18 +185,14 @@ def _sample(

        if stop_token_ids is not None:
            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
            for stop_token_id in stop_token_ids:
                tokens_to_check = input_ids[:, -len(stop_token_id) :]
                unfinished_sequences = unfinished_sequences.mul(
                    torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
                )

        # Stop when each sentence is finished if early_stopping=True
        if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
            return input_ids

@ -237,8 +235,10 @@ def generate(

        raise NotImplementedError
    elif is_sample_gen_mode:
        # Run sample
        generation_kwargs = copy.deepcopy(model_kwargs)
        res = _sample(
            model,
            tokenizer,
            input_ids,
            max_length,
            early_stopping=early_stopping,

@ -249,8 +249,9 @@ def generate(

            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **generation_kwargs,
        )
        del generation_kwargs
        return res
    elif is_beam_gen_mode:
        raise NotImplementedError

@ -350,11 +351,17 @@ def _sample_streaming(

        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        if stop_token_ids is not None:
            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
            if isinstance(stop_token_ids[0], int):
                # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
                unfinished_sequences = unfinished_sequences.mul(
                    torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
                )
            else:
                for stop_token_id in stop_token_ids:
                    unfinished_sequences = unfinished_sequences.mul(
                        torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
                    )

        # Stop when each sentence is finished if early_stopping=True
        if (

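The `_sample` change above marks a sequence as finished as soon as its tail matches any one of several candidate stop sequences. A toy version of that per-step check, with made-up token ids:

```python
# Toy sketch (values are made up) of the multi-stop-sequence check performed each decoding step.
import torch

stop_token_ids = [[128009], [2, 3]]           # several possible stop sequences
input_ids = torch.tensor([[5, 6, 7, 2, 3]])   # toy generated ids for one sequence
unfinished = torch.ones(input_ids.size(0), dtype=torch.long)
for stop_token_id in stop_token_ids:
    tail = input_ids[:, -len(stop_token_id):]
    unfinished = unfinished.mul(
        torch.any(tail != torch.LongTensor(stop_token_id), dim=1).long()
    )
print(unfinished)  # tensor([0]) -> the sequence ended with one of the stop sequences
```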
@ -153,10 +153,11 @@ class DpoLoss(nn.Module):

        else:
            # If no reference model is provided
            ref_logratios = 0.0

        pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
        logits = pi_logratios - ref_logratios - self.gamma / self.beta
        losses = -torch.nn.functional.logsigmoid(self.beta * logits)
        loss = losses.mean()
        # Calculate rewards for logging
        if logprob_ref_chosen is not None:
            chosen_rewards = self.beta * (logprob_actor_chosen.sum(-1) - logprob_ref_chosen.sum(-1)).detach()

@ -167,7 +168,7 @@ class DpoLoss(nn.Module):

        else:
            rejected_rewards = self.beta * logprob_actor_reject.sum(-1).detach()

        return loss, chosen_rewards, rejected_rewards


class LogSigLoss(nn.Module):

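For reference, the per-sample objective computed by this loss (with the margin term γ carried over from the SimPO-style option) can be written as

$$
\mathcal{L}(x, y_w, y_l) = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)) - (\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x))\big] - \gamma\Big),
$$

and the `loss` returned above is the batch mean of these values rather than the per-sample vector.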
@ -25,7 +25,9 @@ class RewardModel(BaseModel):

        self.value_head = nn.Linear(self.last_hidden_state_size, 1)
        self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))

    def forward(
        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        outputs = self.model(input_ids, attention_mask=attention_mask)

        last_hidden_states = outputs["last_hidden_state"]

applications/ColossalChat/coati/models/rlvr_reward_model.py (new file, 50 lines)

@ -0,0 +1,50 @@

"""
reward model
"""

from typing import Callable, List, Optional

import torch


class RLVRRewardModel:
    """
    RLVRReward model class. Support varifiable reward.

    Args:
        reward_fn_list List: list of reward functions
        **kwargs: all other kwargs as in reward functions
    """

    def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None:
        self.reward_fn_list = reward_fn_list
        self.kwargs = kwargs

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        response_start: List = None,
        response_end: List = None,
        gt_answer: List = None,
    ) -> torch.Tensor:
        # apply varifiable reward
        bs = input_ids.size(0)
        rewards = torch.zeros(bs, device=input_ids.device)
        for i in range(bs):
            for reward_fn in self.reward_fn_list:
                rewards[i] += reward_fn(
                    input_ids[i],
                    attention_mask[i],
                    response_start=response_start[i],
                    response_end=response_end[i],
                    gt_answer=gt_answer[i],
                    **self.kwargs,
                )
        return rewards

    def to(self, device):
        return self

    def eval(self):
        return self

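A hedged usage sketch for the new verifiable-reward wrapper above. The reward function below is hypothetical; it only follows the call signature `RLVRRewardModel.__call__` expects, and the real reward functions shipped with the examples may differ.

```python
# Hypothetical verification function; only the argument names come from the class above.
import torch

def exact_match_reward(input_ids, attention_mask, response_start=None,
                       response_end=None, gt_answer=None, tokenizer=None, **kwargs):
    # Decode only the generated response span and check the ground-truth answer appears in it.
    response = tokenizer.decode(input_ids[response_start:response_end], skip_special_tokens=True)
    hit = gt_answer is not None and str(gt_answer).strip() in response
    return torch.tensor(1.0 if hit else 0.0)

# reward_model = RLVRRewardModel([exact_match_reward], tokenizer=tokenizer)
```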
@ -50,8 +50,8 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:

        torch.Tensor: The log probabilities corresponding to the labels.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return per_label_logps.squeeze(-1)


def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:

@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module):

    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def repad_to_left(tensor, tokenizer):
    repadded_input_ids = []
    max_non_padded_seq_len = 0
    for i in range(tensor.size(0)):
        non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
        start, end = non_pad_indices.min(), non_pad_indices.max()
        repadded_input_ids.append(tensor[i][start : end + 1])
        max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0))
    repadded_input_ids = [
        F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids
    ]
    return torch.stack(repadded_input_ids)

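A toy illustration of the new `repad_to_left` helper, which strips padding from each row and re-pads on the left to the longest remaining length. A stub stands in for the tokenizer (only `pad_token_id` is needed), and the helper is assumed to be in scope from the module above.

```python
# Toy illustration (not from the diff); pad id is assumed to be 0 here.
import torch
from types import SimpleNamespace

tokenizer = SimpleNamespace(pad_token_id=0)
right_padded = torch.tensor([[5, 6, 0, 0],
                             [7, 8, 9, 0]])
print(repad_to_left(right_padded, tokenizer))
# tensor([[0, 5, 6],
#         [7, 8, 9]])
```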
@ -1,5 +1,6 @@

from .base import OLTrainer, SLTrainer
from .dpo import DPOTrainer
from .grpo import GRPOTrainer
from .kto import KTOTrainer
from .orpo import ORPOTrainer
from .ppo import PPOTrainer

@ -15,4 +16,5 @@ __all__ = [

    "DPOTrainer",
    "ORPOTrainer",
    "KTOTrainer",
    "GRPOTrainer",
]

@ -96,6 +96,7 @@ class OLTrainer(ABC):

        self.sample_buffer = sample_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks
        self.num_train_step = 0

    @contextmanager
    def _fit_ctx(self) -> None:

@ -212,5 +213,6 @@ class OLTrainer(ABC):

            self._update_phase(update_step)
            # NOTE: this is for on-policy algorithms
            self.data_buffer.clear()

            if self.num_train_step > 0 and (self.num_train_step + 1) % (self.save_interval) == 0:
                self._save_checkpoint(self.num_train_step + 1)

@ -6,6 +6,7 @@ import os

from typing import Any, Optional

import torch
import torch.distributed as dist
from coati.models.loss import DpoLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean

@ -13,10 +14,11 @@ from coati.utils import AccumulativeMeanMeter, save_checkpoint

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from transformers import PreTrainedTokenizerBase

from colossalai.booster import Booster, Plugin
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device

@ -96,18 +98,25 @@ class DPOTrainer(SLTrainer):

        self.train_dataloader = train_preference_dataloader
        self.eval_dataloader = eval_preference_dataloader
        self.writer = None

        init_criterion = (
            dist.get_rank() == dist.get_world_size() - 1
            if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1
            else is_rank_0()
        )

        if use_wandb and init_criterion:
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            self.wandb_run = wandb.init(project="Coati-dpo", sync_tensorboard=True)
        if log_dir is not None and init_criterion:
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "DPO")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

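The `init_criterion` logic above decides which rank owns logging: with pipeline parallelism the loss only materializes on the last stage, so the last rank writes wandb/tensorboard logs, otherwise rank 0 does. A paraphrase of that rule (not repo code, with the rank-0 check inlined):

```python
# Paraphrase of the logging-rank rule used above; `plugin` is any booster plugin.
import torch.distributed as dist

def should_log(plugin) -> bool:
    if getattr(plugin, "pp_size", 1) > 1:           # pipeline parallel: last stage logs
        return dist.get_rank() == dist.get_world_size() - 1
    return (not dist.is_initialized()) or dist.get_rank() == 0
```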
@ -117,166 +126,147 @@ class DPOTrainer(SLTrainer):

            epoch int: the number of current epoch
        """
        self.model.train()
        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Step",
                disable=not (dist.get_rank() == dist.get_world_size() - 1),
            )
            for i, batch in enumerate(self.train_dataloader):
                batch = to_device(batch, self.device)
                (
                    chosen_input_ids,
                    chosen_attention_mask,
                    chosen_loss_mask,
                    reject_input_ids,
                    reject_attention_mask,
                    reject_loss_mask,
                ) = (
                    batch["chosen_input_ids"],
                    batch["chosen_attention_mask"],
                    batch["chosen_loss_mask"],
                    batch["reject_input_ids"],
                    batch["reject_attention_mask"],
                    batch["reject_loss_mask"],
                )
                batch_size = chosen_input_ids.size()[0]
                # Calculate logits from reference model.
                if self.ref_model is not None:
                    self.ref_model.eval()
                    with torch.no_grad():
                        ref_all_logits = self.ref_model(
                            input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
                            attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
                        )["logits"]
                        ref_chosen_logits = ref_all_logits[:batch_size]
                        ref_reject_logits = ref_all_logits[batch_size:]
                        logprob_ref_chosen = calc_masked_log_probs(
                            ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
                        )
                        logprob_ref_reject = calc_masked_log_probs(
                            ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                        )
                else:
                    logprob_ref_chosen = None
                    logprob_ref_reject = None

                # Merge chosen and reject
                inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
                attention_mask = torch.stack(
                    [item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
                )
                loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
                logprob_ref = torch.stack([item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup])

                data_iter = iter(
                    [
                        {
                            "input_ids": inputs_ids,
                            "attention_mask": attention_mask,
                            "loss_mask": loss_mask,
                            "logprob_ref": logprob_ref,
                        }
                    ]
                )
                rewards = []

                def _criterion(outputs, inputs):
                    loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                        calc_masked_log_probs(
                            outputs["logits"][0::2],
                            inputs["input_ids"][0::2],
                            inputs["loss_mask"][0::2][:, 1:],
                            self.length_normalization,
                        ),
                        calc_masked_log_probs(
                            outputs["logits"][1::2],
                            inputs["input_ids"][1::2],
                            inputs["loss_mask"][1::2][:, 1:],
                            self.length_normalization,
                        ),
|
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["loss_mask"][0::2][:, 1:],
|
||||||
|
inputs["loss_mask"][1::2][:, 1:],
|
||||||
)
|
)
|
||||||
logprob_ref_reject = calc_masked_log_probs(
|
rewards.append(chosen_rewards)
|
||||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
rewards.append(rejected_rewards)
|
||||||
)
|
return loss
|
||||||
else:
|
|
||||||
logprob_ref_chosen = None
|
|
||||||
logprob_ref_reject = None
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
outputs = self.booster.execute_pipeline(
|
||||||
logprob_actor_chosen,
|
data_iter,
|
||||||
logprob_actor_reject,
|
self.model,
|
||||||
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
criterion=_criterion,
|
||||||
logprob_ref_reject if logprob_ref_reject is not None else None,
|
optimizer=self.optimizer,
|
||||||
chosen_loss_mask[:, 1:],
|
return_loss=True,
|
||||||
reject_loss_mask[:, 1:],
|
)
|
||||||
)
|
loss = outputs["loss"]
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
|
||||||
|
global_loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
step_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"train/loss": global_loss.item(),
|
||||||
|
"train/lr": self.actor_scheduler.get_last_lr()[0],
|
||||||
|
"train/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
|
||||||
|
"train/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
step_bar.update()
|
||||||
|
self.accumulative_meter.add("loss", global_loss.item())
|
||||||
|
self.accumulative_meter.add("chosen_rewards", chosen_rewards.to(torch.float16).mean().item())
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"rejected_rewards", rejected_rewards.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
if self.writer is not None:
|
||||||
|
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), i)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), i
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/rejected_rewards",
|
||||||
|
self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/margin",
|
||||||
|
self.accumulative_meter.get("chosen_rewards")
|
||||||
|
- self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
|
||||||
# DPO Loss
|
|
||||||
loss = losses.mean()
|
|
||||||
|
|
||||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
|
||||||
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.actor_scheduler.step()
|
self.actor_scheduler.step()
|
||||||
|
else:
|
||||||
# sync
|
self.accumulative_meter.reset()
|
||||||
loss_mean = all_reduce_mean(tensor=loss)
|
step_bar = trange(
|
||||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
len(self.train_dataloader) // self.accumulation_steps,
|
||||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
disable=not is_rank_0(),
|
||||||
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
)
|
||||||
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
for i, batch in enumerate(self.train_dataloader):
|
||||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
|
||||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
|
||||||
|
|
||||||
if i % self.accumulation_steps == self.accumulation_steps - 1:
|
|
||||||
self.num_train_step += 1
|
|
||||||
step_bar.update()
|
|
||||||
# logging
|
|
||||||
if self.writer and is_rank_0():
|
|
||||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
|
||||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
|
||||||
self.writer.add_scalar(
|
|
||||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
|
||||||
)
|
|
||||||
self.writer.add_scalar(
|
|
||||||
"train/rejected_rewards",
|
|
||||||
self.accumulative_meter.get("rejected_rewards"),
|
|
||||||
self.num_train_step,
|
|
||||||
)
|
|
||||||
self.writer.add_scalar(
|
|
||||||
"train/margin",
|
|
||||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
|
||||||
self.num_train_step,
|
|
||||||
)
|
|
||||||
self.writer.add_scalar(
|
|
||||||
"train/accuracy",
|
|
||||||
self.accumulative_meter.get("accuracy"),
|
|
||||||
self.num_train_step,
|
|
||||||
)
|
|
||||||
self.accumulative_meter.reset()
|
|
||||||
|
|
||||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
|
||||||
# save checkpoint
|
|
||||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
|
||||||
save_checkpoint(
|
|
||||||
save_dir=self.save_dir,
|
|
||||||
booster=self.booster,
|
|
||||||
model=self.model,
|
|
||||||
optimizer=self.optimizer,
|
|
||||||
lr_scheduler=self.actor_scheduler,
|
|
||||||
epoch=epoch,
|
|
||||||
step=i + 1,
|
|
||||||
batch_size=batch_size,
|
|
||||||
coordinator=self.coordinator,
|
|
||||||
)
|
|
||||||
self.coordinator.print_on_master(
|
|
||||||
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
step_bar.close()
|
|
||||||
|
|
||||||
def _eval(self, epoch: int):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
epoch int: the number of current epoch
|
|
||||||
"""
|
|
||||||
if self.eval_dataloader is None:
|
|
||||||
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
|
||||||
return
|
|
||||||
self.model.eval()
|
|
||||||
self.ref_model.eval()
|
|
||||||
self.coordinator.print_on_master("\nStart evaluation...")
|
|
||||||
|
|
||||||
step_bar = trange(
|
|
||||||
len(self.eval_dataloader),
|
|
||||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
|
||||||
disable=not is_rank_0(),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.accumulative_meter.reset()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for i, batch in enumerate(self.eval_dataloader):
|
|
||||||
batch = to_device(batch, self.device)
|
batch = to_device(batch, self.device)
|
||||||
(
|
(
|
||||||
chosen_input_ids,
|
chosen_input_ids,
|
||||||
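In the rewritten _train above, chosen and rejected samples are interleaved row by row (chosen_0, reject_0, chosen_1, reject_1, ...) before being handed to booster.execute_pipeline, and the pipeline criterion recovers the two halves with stride-2 slicing ([0::2] / [1::2]). A small self-contained sketch of that packing and unpacking, independent of the trainer:

import torch

def interleave_pairs(chosen: torch.Tensor, reject: torch.Tensor) -> torch.Tensor:
    # Stack rows as chosen_0, reject_0, chosen_1, reject_1, ...
    return torch.stack([row for pair in zip(chosen, reject) for row in pair])

chosen = torch.tensor([[1, 1], [2, 2]])
reject = torch.tensor([[9, 9], [8, 8]])
merged = interleave_pairs(chosen, reject)

# Stride-2 slicing inside the criterion recovers the original halves.
assert torch.equal(merged[0::2], chosen)
assert torch.equal(merged[1::2], reject)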
@@ -300,12 +290,11 @@ class DPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0]
|
batch_size = chosen_input_ids.size()[0]
|
||||||
|
|
||||||
actor_all_logits = self.model(
|
actor_all_logits = self.model(
|
||||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
)["logits"]
|
)["logits"]
|
||||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||||
actor_reject_logits = actor_all_logits[batch_size:]
|
actor_reject_logits = actor_all_logits[batch_size:]
|
||||||
|
|
||||||
logprob_actor_chosen = calc_masked_log_probs(
|
logprob_actor_chosen = calc_masked_log_probs(
|
||||||
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
)
|
)
|
||||||
@@ -314,22 +303,26 @@ class DPOTrainer(SLTrainer):
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ref_model.eval()
|
if self.ref_model is not None:
|
||||||
|
self.ref_model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_all_logits = self.ref_model(
|
||||||
|
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
|
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
|
)["logits"]
|
||||||
|
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||||
|
ref_reject_logits = ref_all_logits[batch_size:]
|
||||||
|
logprob_ref_chosen = calc_masked_log_probs(
|
||||||
|
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
logprob_ref_reject = calc_masked_log_probs(
|
||||||
|
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprob_ref_chosen = None
|
||||||
|
logprob_ref_reject = None
|
||||||
|
|
||||||
ref_all_logits = self.ref_model(
|
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
|
||||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
|
||||||
)["logits"]
|
|
||||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
|
||||||
ref_reject_logits = ref_all_logits[batch_size:]
|
|
||||||
logprob_ref_chosen = calc_masked_log_probs(
|
|
||||||
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
|
||||||
)
|
|
||||||
logprob_ref_reject = calc_masked_log_probs(
|
|
||||||
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
|
||||||
)
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
|
||||||
logprob_actor_chosen,
|
logprob_actor_chosen,
|
||||||
logprob_actor_reject,
|
logprob_actor_reject,
|
||||||
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
||||||
@@ -338,7 +331,9 @@ class DPOTrainer(SLTrainer):
reject_loss_mask[:, 1:],
|
reject_loss_mask[:, 1:],
|
||||||
)
|
)
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||||
loss = losses.mean()
|
|
||||||
|
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||||
|
# sync
|
||||||
loss_mean = all_reduce_mean(tensor=loss)
|
loss_mean = all_reduce_mean(tensor=loss)
|
||||||
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||||
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||||
@@ -347,16 +342,302 @@ class DPOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
|
||||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||||
self.accumulative_meter.add(
|
|
||||||
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
|
||||||
)
|
|
||||||
step_bar.update()
|
|
||||||
|
|
||||||
msg = "Evaluation Result:\n"
|
if (self.num_train_step + 1) % self.accumulation_steps == 0:
|
||||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
self.optimizer.step()
|
||||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
self.optimizer.zero_grad()
|
||||||
self.coordinator.print_on_master(msg)
|
self.actor_scheduler.step()
|
||||||
os.makedirs(self.save_dir, exist_ok=True)
|
|
||||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
step_bar.set_postfix(
|
||||||
f.write(msg)
|
{
|
||||||
|
"train/loss": self.accumulative_meter.get("loss"),
|
||||||
|
"train/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
|
||||||
|
"train/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
"train/accuracy": self.accumulative_meter.get("accuracy"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
step_bar.update()
|
||||||
|
if self.writer and is_rank_0():
|
||||||
|
global_step = (self.num_train_step + 1) / self.accumulation_steps
|
||||||
|
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
|
||||||
|
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/rejected_rewards",
|
||||||
|
self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
global_step,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/margin",
|
||||||
|
self.accumulative_meter.get("chosen_rewards")
|
||||||
|
- self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
global_step,
|
||||||
|
)
|
||||||
|
self.writer.add_scalar(
|
||||||
|
"train/accuracy",
|
||||||
|
self.accumulative_meter.get("accuracy"),
|
||||||
|
global_step,
|
||||||
|
)
|
||||||
|
self.accumulative_meter.reset()
|
||||||
|
self.num_train_step += 1
|
||||||
|
|
||||||
|
if self.save_dir is not None and self.num_train_step > 0 and self.num_train_step % self.save_interval == 0:
|
||||||
|
# save checkpoint
|
||||||
|
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=self.save_dir,
|
||||||
|
booster=self.booster,
|
||||||
|
model=self.model,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
lr_scheduler=self.actor_scheduler,
|
||||||
|
epoch=epoch,
|
||||||
|
step=self.num_train_step,
|
||||||
|
batch_size=batch_size,
|
||||||
|
coordinator=self.coordinator,
|
||||||
|
)
|
||||||
|
self.coordinator.print_on_master(
|
||||||
|
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
|
step_bar.close()
|
||||||
|
|
||||||
|
def _eval(self, epoch: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
epoch int: the number of current epoch
|
||||||
|
"""
|
||||||
|
if self.eval_dataloader is None:
|
||||||
|
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
|
||||||
|
return
|
||||||
|
self.model.eval()
|
||||||
|
self.ref_model.eval()
|
||||||
|
self.accumulative_meter.reset()
|
||||||
|
self.coordinator.print_on_master("\nStart evaluation...")
|
||||||
|
|
||||||
|
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||||
|
step_bar = tqdm(
|
||||||
|
range(len(self.eval_dataloader)),
|
||||||
|
desc="Step",
|
||||||
|
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for _, batch in enumerate(self.eval_dataloader):
|
||||||
|
batch = to_device(batch, self.device)
|
||||||
|
(
|
||||||
|
chosen_input_ids,
|
||||||
|
chosen_attention_mask,
|
||||||
|
chosen_loss_mask,
|
||||||
|
reject_input_ids,
|
||||||
|
reject_attention_mask,
|
||||||
|
reject_loss_mask,
|
||||||
|
) = (
|
||||||
|
batch["chosen_input_ids"],
|
||||||
|
batch["chosen_attention_mask"],
|
||||||
|
batch["chosen_loss_mask"],
|
||||||
|
batch["reject_input_ids"],
|
||||||
|
batch["reject_attention_mask"],
|
||||||
|
batch["reject_loss_mask"],
|
||||||
|
)
|
||||||
|
batch_size = chosen_input_ids.size()[0]
|
||||||
|
# Calculate logits from reference model.
|
||||||
|
if self.ref_model is not None:
|
||||||
|
self.ref_model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_all_logits = self.ref_model(
|
||||||
|
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
|
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
|
)["logits"]
|
||||||
|
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||||
|
ref_reject_logits = ref_all_logits[batch_size:]
|
||||||
|
logprob_ref_chosen = calc_masked_log_probs(
|
||||||
|
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
logprob_ref_reject = calc_masked_log_probs(
|
||||||
|
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprob_ref_chosen = None
|
||||||
|
logprob_ref_reject = None
|
||||||
|
|
||||||
|
# Merge chosen and reject
|
||||||
|
inputs_ids = torch.stack([item for tup in zip(chosen_input_ids, reject_input_ids) for item in tup])
|
||||||
|
attention_mask = torch.stack(
|
||||||
|
[item for tup in zip(chosen_attention_mask, reject_attention_mask) for item in tup]
|
||||||
|
)
|
||||||
|
loss_mask = torch.stack([item for tup in zip(chosen_loss_mask, reject_loss_mask) for item in tup])
|
||||||
|
logprob_ref = torch.stack(
|
||||||
|
[item for tup in zip(logprob_ref_chosen, logprob_ref_reject) for item in tup]
|
||||||
|
)
|
||||||
|
|
||||||
|
data_iter = iter(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_ids": inputs_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"loss_mask": loss_mask,
|
||||||
|
"logprob_ref": logprob_ref,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rewards = []
|
||||||
|
|
||||||
|
def _criterion(outputs, inputs):
|
||||||
|
loss, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||||
|
calc_masked_log_probs(
|
||||||
|
outputs["logits"][0::2],
|
||||||
|
inputs["input_ids"][0::2],
|
||||||
|
inputs["loss_mask"][0::2][:, 1:],
|
||||||
|
self.length_normalization,
|
||||||
|
),
|
||||||
|
calc_masked_log_probs(
|
||||||
|
outputs["logits"][1::2],
|
||||||
|
inputs["input_ids"][1::2],
|
||||||
|
inputs["loss_mask"][1::2][:, 1:],
|
||||||
|
self.length_normalization,
|
||||||
|
),
|
||||||
|
inputs["logprob_ref"][0::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["logprob_ref"][1::2] if inputs["logprob_ref"] is not None else None,
|
||||||
|
inputs["loss_mask"][0::2][:, 1:],
|
||||||
|
inputs["loss_mask"][1::2][:, 1:],
|
||||||
|
)
|
||||||
|
rewards.append(chosen_rewards)
|
||||||
|
rewards.append(rejected_rewards)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
outputs = self.booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
self.model,
|
||||||
|
criterion=_criterion,
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
chosen_rewards, rejected_rewards = rewards[0], rewards[1]
|
||||||
|
global_loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
chosen_rewards_mean = all_reduce_mean(chosen_rewards, self.plugin)
|
||||||
|
rejected_rewards_mean = all_reduce_mean(rejected_rewards, self.plugin)
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
step_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"eval/loss": global_loss.item(),
|
||||||
|
"eval/lr": self.actor_scheduler.get_last_lr()[0],
|
||||||
|
"eval/chosen_rewards": chosen_rewards.to(torch.float16).mean().item(),
|
||||||
|
"eval/rejected_rewards": rejected_rewards.to(torch.float16).mean().item(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add("loss", global_loss.to(torch.float16).item())
|
||||||
|
step_bar.update()
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
msg = "\nEvaluation Result:\n"
|
||||||
|
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
|
||||||
|
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
print(msg)
|
||||||
|
else:
|
||||||
|
step_bar = trange(
|
||||||
|
len(self.eval_dataloader),
|
||||||
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
|
disable=not is_rank_0(),
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, batch in enumerate(self.eval_dataloader):
|
||||||
|
batch = to_device(batch, self.device)
|
||||||
|
(
|
||||||
|
chosen_input_ids,
|
||||||
|
chosen_attention_mask,
|
||||||
|
chosen_loss_mask,
|
||||||
|
reject_input_ids,
|
||||||
|
reject_attention_mask,
|
||||||
|
reject_loss_mask,
|
||||||
|
) = (
|
||||||
|
batch["chosen_input_ids"],
|
||||||
|
batch["chosen_attention_mask"],
|
||||||
|
batch["chosen_loss_mask"],
|
||||||
|
batch["reject_input_ids"],
|
||||||
|
batch["reject_attention_mask"],
|
||||||
|
batch["reject_loss_mask"],
|
||||||
|
)
|
||||||
|
if not self.apply_loss_mask:
|
||||||
|
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
|
||||||
|
reject_loss_mask = reject_loss_mask.fill_(1.0)
|
||||||
|
|
||||||
|
batch_size = chosen_input_ids.size()[0]
|
||||||
|
|
||||||
|
actor_all_logits = self.model(
|
||||||
|
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
|
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
|
)["logits"]
|
||||||
|
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||||
|
actor_reject_logits = actor_all_logits[batch_size:]
|
||||||
|
|
||||||
|
logprob_actor_chosen = calc_masked_log_probs(
|
||||||
|
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
|
||||||
|
logprob_actor_reject = calc_masked_log_probs(
|
||||||
|
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
ref_all_logits = self.ref_model(
|
||||||
|
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||||
|
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||||
|
)["logits"]
|
||||||
|
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||||
|
ref_reject_logits = ref_all_logits[batch_size:]
|
||||||
|
logprob_ref_chosen = calc_masked_log_probs(
|
||||||
|
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
logprob_ref_reject = calc_masked_log_probs(
|
||||||
|
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
|
||||||
|
)
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
|
||||||
|
logprob_actor_chosen,
|
||||||
|
logprob_actor_reject,
|
||||||
|
logprob_ref_chosen if logprob_ref_chosen is not None else None,
|
||||||
|
logprob_ref_reject if logprob_ref_reject is not None else None,
|
||||||
|
chosen_loss_mask[:, 1:],
|
||||||
|
reject_loss_mask[:, 1:],
|
||||||
|
)
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
|
||||||
|
loss = losses.mean()
|
||||||
|
loss_mean = all_reduce_mean(tensor=loss)
|
||||||
|
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
|
||||||
|
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards)
|
||||||
|
reward_accuracies_mean = all_reduce_mean(tensor=reward_accuracies)
|
||||||
|
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||||
|
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||||
|
self.accumulative_meter.add(
|
||||||
|
"margin", (chosen_rewards_mean - rejected_rewards_mean).to(torch.float16).mean().item()
|
||||||
|
)
|
||||||
|
step_bar.set_postfix(
|
||||||
|
{
|
||||||
|
"eval/loss": self.accumulative_meter.get("loss"),
|
||||||
|
"eval/chosen_rewards": self.accumulative_meter.get("chosen_rewards"),
|
||||||
|
"eval/rejected_rewards": self.accumulative_meter.get("rejected_rewards"),
|
||||||
|
"eval/accuracy": self.accumulative_meter.get("accuracy"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
step_bar.update()
|
||||||
|
|
||||||
|
msg = "\nEvaluation Result:\n"
|
||||||
|
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
|
||||||
|
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||||
|
self.coordinator.print_on_master(msg)
|
||||||
|
if self.save_dir is not None:
|
||||||
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||||
|
f.write(msg)
|
||||||
step_bar.close()
|
step_bar.close()
|
||||||
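Both the pipeline and non-pipeline paths above funnel logits through calc_masked_log_probs before the DPO loss. That helper lives in coati.models.utils and is not reproduced in this diff; the snippet below is only a hedged sketch of the usual shape of such a function (per-token log-probabilities of the realized next tokens, zeroed where the loss mask is zero):

import torch
import torch.nn.functional as F

def masked_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq_len, vocab], input_ids: [batch, seq_len], loss_mask: [batch, seq_len - 1]
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    target = input_ids[:, 1:].unsqueeze(-1)
    per_token = torch.gather(log_probs, dim=2, index=target).squeeze(-1)
    # The preference loss later aggregates these masked per-token values.
    return per_token * loss_mask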
applications/ColossalChat/coati/trainer/grpo.py (new executable file, 386 lines)
@@ -0,0 +1,386 @@
|
"""
|
||||||
|
GRPO trainer
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import wandb
|
||||||
|
from coati.experience_buffer import NaiveExperienceBuffer
|
||||||
|
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||||
|
from coati.models import RewardModel, RLVRRewardModel
|
||||||
|
from coati.models.loss import GPTLMLoss, PolicyLoss
|
||||||
|
from coati.models.utils import calc_action_log_probs
|
||||||
|
from coati.trainer.callbacks import Callback
|
||||||
|
from coati.trainer.utils import all_reduce_mean
|
||||||
|
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import GeminiPlugin
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
from .base import OLTrainer
|
||||||
|
from .utils import AnnealingScheduler, CycledDataLoader, is_rank_0, to_device
|
||||||
|
|
||||||
|
|
||||||
|
def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
|
||||||
|
"""
|
||||||
|
Set default keyword arguments for generation based on the actor model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor (PreTrainedModel): The actor model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: A dictionary containing the default keyword arguments for generation.
|
||||||
|
"""
|
||||||
|
unwrapped_model = actor.unwrap()
|
||||||
|
new_kwargs = {}
|
||||||
|
# use huggingface models method directly
|
||||||
|
if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
|
||||||
|
new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
|
||||||
|
if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
|
||||||
|
new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
|
||||||
|
return new_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class GRPOTrainer(OLTrainer):
|
||||||
|
"""
|
||||||
|
Trainer for GRPO algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy (Booster): the strategy to use for training
|
||||||
|
actor (Actor): the actor model in ppo algorithm
|
||||||
|
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
|
||||||
|
initial_model (Actor): the initial model in the rlhf algorithm, used to generate reference logits that constrain the update of the actor
|
||||||
|
actor_optim (Optimizer): the optimizer to use for actor model
|
||||||
|
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
|
||||||
|
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||||
|
buffer_limit (int, defaults to 0): the max_size limitation of buffer
|
||||||
|
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
|
||||||
|
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
|
||||||
|
vf_coef (float, defaults to 1.0): the coefficient of value loss
|
||||||
|
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
|
||||||
|
value_clip (float, defaults to 0.4): the clip coefficient of value loss
|
||||||
|
sample_buffer (bool, defaults to False): whether to sample from buffer
|
||||||
|
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||||
|
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
|
||||||
|
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||||
|
generate_kwargs (dict, optional): the kwargs to use during generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor_booster: Booster,
|
||||||
|
actor: PreTrainedModel,
|
||||||
|
reward_model: Union[RewardModel, RLVRRewardModel],
|
||||||
|
initial_model: PreTrainedModel,
|
||||||
|
actor_optim: Optimizer,
|
||||||
|
actor_lr_scheduler: _LRScheduler,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
kl_coef: float = 0.1,
|
||||||
|
ptx_coef: float = 0.9,
|
||||||
|
train_batch_size: int = 8,
|
||||||
|
buffer_limit: int = 0,
|
||||||
|
buffer_cpu_offload: bool = True,
|
||||||
|
eps_clip: float = 0.2,
|
||||||
|
vf_coef: float = 1.0,
|
||||||
|
value_clip: float = 0.2,
|
||||||
|
sample_buffer: bool = False,
|
||||||
|
dataloader_pin_memory: bool = True,
|
||||||
|
offload_inference_models: bool = True,
|
||||||
|
apply_loss_mask: bool = True,
|
||||||
|
accumulation_steps: int = 1,
|
||||||
|
save_interval: int = 0,
|
||||||
|
save_dir: str = None,
|
||||||
|
use_tp: bool = False,
|
||||||
|
num_generation: int = 8,
|
||||||
|
inference_batch_size: int = None,
|
||||||
|
logits_forward_batch_size: int = None,
|
||||||
|
temperature_annealing_config: Optional[Dict] = None,
|
||||||
|
coordinator: DistCoordinator = None,
|
||||||
|
callbacks: List[Callback] = [],
|
||||||
|
**generate_kwargs,
|
||||||
|
) -> None:
|
||||||
|
if isinstance(actor_booster, GeminiPlugin):
|
||||||
|
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
|
||||||
|
|
||||||
|
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
||||||
|
super().__init__(actor_booster, None, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks)
|
||||||
|
self.generate_kwargs = _set_default_generate_kwargs(actor)
|
||||||
|
self.generate_kwargs.update(generate_kwargs)
|
||||||
|
|
||||||
|
self.actor = actor
|
||||||
|
self.actor_booster = actor_booster
|
||||||
|
self.actor_scheduler = actor_lr_scheduler
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.experience_maker = NaiveExperienceMaker(
|
||||||
|
self.actor,
|
||||||
|
None,
|
||||||
|
reward_model,
|
||||||
|
initial_model,
|
||||||
|
self.tokenizer,
|
||||||
|
kl_coef,
|
||||||
|
use_grpo=True,
|
||||||
|
num_generation=num_generation,
|
||||||
|
inference_batch_size=inference_batch_size,
|
||||||
|
logits_forward_batch_size=logits_forward_batch_size,
|
||||||
|
)
|
||||||
|
if temperature_annealing_config:
|
||||||
|
# use annealing
|
||||||
|
self.temperature_annealing_scheduler = AnnealingScheduler(
|
||||||
|
temperature_annealing_config["start_temperature"],
|
||||||
|
temperature_annealing_config["end_temperature"],
|
||||||
|
temperature_annealing_config["annealing_warmup_steps"],
|
||||||
|
temperature_annealing_config["annealing_steps"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.temperature_annealing_scheduler = None
|
||||||
|
|
||||||
|
self.train_batch_size = train_batch_size
|
||||||
|
|
||||||
|
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||||
|
self.vf_coef = vf_coef
|
||||||
|
self.ptx_loss_fn = GPTLMLoss()
|
||||||
|
self.ptx_coef = ptx_coef
|
||||||
|
self.actor_optim = actor_optim
|
||||||
|
self.save_interval = save_interval
|
||||||
|
self.apply_loss_mask = apply_loss_mask
|
||||||
|
self.coordinator = coordinator
|
||||||
|
self.actor_save_dir = os.path.join(save_dir, "actor")
|
||||||
|
self.num_train_step = 0
|
||||||
|
self.accumulation_steps = accumulation_steps
|
||||||
|
self.use_tp = use_tp
|
||||||
|
self.accumulative_meter = AccumulativeMeanMeter()
|
||||||
|
self.offload_inference_models = offload_inference_models
|
||||||
|
self.device = get_current_device()
|
||||||
|
|
||||||
|
def _before_fit(
|
||||||
|
self,
|
||||||
|
prompt_dataloader: DataLoader,
|
||||||
|
pretrain_dataloader: Optional[DataLoader] = None,
|
||||||
|
log_dir: Optional[str] = None,
|
||||||
|
use_wandb: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
prompt_dataloader (DataLoader): the dataloader to use for prompt data
|
||||||
|
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
|
||||||
|
"""
|
||||||
|
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
|
||||||
|
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
|
||||||
|
|
||||||
|
self.writer = None
|
||||||
|
if use_wandb and is_rank_0():
|
||||||
|
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
self.wandb_run = wandb.init(project="Coati-grpo", sync_tensorboard=True)
|
||||||
|
if log_dir is not None and is_rank_0():
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
log_dir = os.path.join(log_dir, "grpo")
|
||||||
|
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
|
||||||
|
self.writer = SummaryWriter(log_dir=log_dir)
|
||||||
|
|
||||||
|
def _setup_update_phrase_dataload(self):
|
||||||
|
"""
|
||||||
|
Why not use a distributed dataloader?
If TP is used, the input on each rank is the same, and the same dataloader feeds the same experience to every rank.
If TP is not used, the input on each rank is different, and each rank is expected to be fed different experiences.
|
||||||
|
"""
|
||||||
|
self.dataloader = DataLoader(
|
||||||
|
self.data_buffer,
|
||||||
|
batch_size=self.train_batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=self.dataloader_pin_memory,
|
||||||
|
collate_fn=self.data_buffer.collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_experience(self, collect_step: int) -> Experience:
|
||||||
|
"""
|
||||||
|
Make experience
|
||||||
|
"""
|
||||||
|
prompts = self.prompt_dataloader.next()
|
||||||
|
if self.offload_inference_models:
|
||||||
|
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
|
||||||
|
self.experience_maker.initial_model.to(self.device)
|
||||||
|
self.experience_maker.reward_model.to(self.device)
|
||||||
|
if self.temperature_annealing_scheduler:
|
||||||
|
self.generate_kwargs["temperature"] = self.temperature_annealing_scheduler.get_temperature()
|
||||||
|
return self.experience_maker.make_experience(
|
||||||
|
input_ids=prompts["input_ids"].to(get_current_device()),
|
||||||
|
attention_mask=prompts["attention_mask"].to(get_current_device()),
|
||||||
|
gt_answer=prompts["gt_answer"],
|
||||||
|
**self.generate_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _training_step(self, experience: Experience):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
experience:
|
||||||
|
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
|
||||||
|
"""
|
||||||
|
self.actor.train()
|
||||||
|
num_actions = experience.action_log_probs.size(1)
|
||||||
|
# policy loss
|
||||||
|
|
||||||
|
actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
|
||||||
|
"logits"
|
||||||
|
] # [batch size, prompt_length + response_length]
|
||||||
|
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
|
||||||
|
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
|
||||||
|
action_log_probs,
|
||||||
|
experience.action_log_probs,
|
||||||
|
experience.advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||||
|
action_mask=experience.action_mask if self.apply_loss_mask else None,
|
||||||
|
)
|
||||||
|
# sequences that do not end properly are not counted in the token cost
|
||||||
|
token_cost = torch.sum(
|
||||||
|
(experience.sequences[:, -num_actions:] != self.tokenizer.pad_token_id).to(torch.float), axis=-1
|
||||||
|
).to(actor_logits.device)
|
||||||
|
end_properly = experience.sequences[:, -1] == self.tokenizer.pad_token_id
|
||||||
|
mean_token_cost = torch.sum(token_cost * end_properly) / torch.sum(end_properly)
|
||||||
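The token-cost bookkeeping above counts response tokens only for sequences that terminated properly, i.e. whose last position is the pad token. A tiny standalone example of the same masking (pad_token_id = 0 is an assumption for the example):

import torch

pad_token_id = 0  # assumption for the example
# three generated responses, right-padded with 0
responses = torch.tensor([
    [5, 6, 7, 0, 0],   # ended properly, 3 real tokens
    [5, 6, 7, 8, 9],   # ran out of budget, not counted
    [4, 0, 0, 0, 0],   # ended properly, 1 real token
])

token_cost = (responses != pad_token_id).float().sum(dim=-1)
end_properly = responses[:, -1] == pad_token_id
mean_token_cost = (token_cost * end_properly).sum() / end_properly.sum()
print(mean_token_cost)  # tensor(2.) -> (3 + 1) / 2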
|
actor_loss = (1 - self.ptx_coef) * actor_loss
|
||||||
|
if not to_skip:
|
||||||
|
self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
|
||||||
|
|
||||||
|
# ptx loss
|
||||||
|
if self.ptx_coef != 0:
|
||||||
|
batch = self.pretrain_dataloader.next()
|
||||||
|
batch = to_device(batch, self.device)
|
||||||
|
outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
|
||||||
|
ptx_loss = outputs.loss
|
||||||
|
ptx_loss = self.ptx_coef * ptx_loss
|
||||||
|
self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
|
||||||
|
|
||||||
|
# sync
|
||||||
|
actor_loss_mean = all_reduce_mean(tensor=actor_loss)
|
||||||
|
max_ratio_mean = all_reduce_mean(tensor=max_ratio)
|
||||||
|
reward_mean = all_reduce_mean(tensor=experience.reward.mean())
|
||||||
|
advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
|
||||||
|
kl_mean = all_reduce_mean(tensor=experience.kl.mean())
|
||||||
|
mean_token_cost = all_reduce_mean(tensor=mean_token_cost)
|
||||||
|
if self.ptx_coef != 0:
|
||||||
|
ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
|
||||||
|
|
||||||
|
self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
|
||||||
|
self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
|
||||||
|
self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
|
||||||
|
self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
|
||||||
|
self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
|
||||||
|
self.accumulative_meter.add("mean_token_cost", mean_token_cost.to(torch.float16).item())
|
||||||
|
self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
|
||||||
|
if self.ptx_coef != 0:
|
||||||
|
self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
|
||||||
|
|
||||||
|
if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
|
||||||
|
self.actor_optim.step()
|
||||||
|
self.actor_optim.zero_grad()
|
||||||
|
self.actor_scheduler.step()
|
||||||
|
|
||||||
|
if self.temperature_annealing_scheduler:
|
||||||
|
self.temperature_annealing_scheduler.step_forward()
|
||||||
|
|
||||||
|
# preparing logging model output and corresponding rewards.
|
||||||
|
if self.num_train_step % 10 == 0:
|
||||||
|
response_text = self.experience_maker.tokenizer.batch_decode(
|
||||||
|
experience.sequences, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
for i in range(len(response_text)):
|
||||||
|
response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
|
||||||
|
|
||||||
|
if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
|
||||||
|
# log output to wandb
|
||||||
|
my_table = wandb.Table(
|
||||||
|
columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.wandb_run.log({"sample_response": my_table})
|
||||||
|
except OSError as e:
|
||||||
|
self.coordinator.print_on_master(e)
|
||||||
|
elif self.writer and is_rank_0():
|
||||||
|
for line in response_text:
|
||||||
|
self.coordinator.print_on_master(line)
|
||||||
|
|
||||||
|
if self.writer and is_rank_0():
|
||||||
|
global_step = (self.num_train_step + 1) / self.accumulation_steps
|
||||||
|
self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), global_step)
|
||||||
|
self.writer.add_scalar("train/skip_ratio", self.accumulative_meter.get("skip_ratio"), global_step)
|
||||||
|
self.writer.add_scalar("train/actor_loss", self.accumulative_meter.get("actor_loss"), global_step)
|
||||||
|
self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], global_step)
|
||||||
|
if self.ptx_coef != 0:
|
||||||
|
self.writer.add_scalar("train/ptx_loss", self.accumulative_meter.get("ptx_loss"), global_step)
|
||||||
|
self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), global_step)
|
||||||
|
self.writer.add_scalar("token_cost", self.accumulative_meter.get("mean_token_cost"), global_step)
|
||||||
|
self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
|
||||||
|
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
|
||||||
|
self.accumulative_meter.reset()
|
||||||
|
self.num_train_step += 1
|
||||||
|
|
||||||
|
def _learn(self, update_step: int):
|
||||||
|
"""
|
||||||
|
Perform the learning step of the PPO algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
update_step (int): The current update step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
if self.offload_inference_models:
|
||||||
|
self.experience_maker.initial_model.to("cpu")
|
||||||
|
self.experience_maker.reward_model.to("cpu")
|
||||||
|
# buffer may be empty at first, we should rebuild at each training
|
||||||
|
if self.sample_buffer:
|
||||||
|
experience = self.data_buffer.sample()
|
||||||
|
self._on_learn_batch_start()
|
||||||
|
experience.to_device(self.device)
|
||||||
|
self._training_step(experience)
|
||||||
|
self._on_learn_batch_end(experience)
|
||||||
|
else:
|
||||||
|
if isinstance(self.dataloader.sampler, DistributedSampler):
|
||||||
|
self.dataloader.sampler.set_epoch(update_step)
|
||||||
|
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
|
||||||
|
for experience in pbar:
|
||||||
|
self._on_learn_batch_start()
|
||||||
|
experience.to_device(self.device)
|
||||||
|
self._training_step(experience)
|
||||||
|
self._on_learn_batch_end(experience)
|
||||||
|
|
||||||
|
def _save_checkpoint(self, num_train_step: int = 0):
|
||||||
|
"""
|
||||||
|
Save the actor checkpoints with running states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_train_step (int): The current num_train_step number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
|
||||||
|
save_checkpoint(
|
||||||
|
save_dir=self.actor_save_dir,
|
||||||
|
booster=self.actor_booster,
|
||||||
|
model=self.actor,
|
||||||
|
optimizer=self.actor_optim,
|
||||||
|
lr_scheduler=self.actor_scheduler,
|
||||||
|
epoch=0,
|
||||||
|
step=num_train_step + 1,
|
||||||
|
batch_size=self.train_batch_size,
|
||||||
|
coordinator=self.coordinator,
|
||||||
|
)
|
||||||
|
self.coordinator.print_on_master(
|
||||||
|
f"Saved actor checkpoint at episode {(num_train_step + 1)} at folder {self.actor_save_dir}"
|
||||||
|
)
|
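GRPOTrainer above optionally anneals the sampling temperature through AnnealingScheduler(start_temperature, end_temperature, annealing_warmup_steps, annealing_steps) imported from .utils. That class is not shown in this diff; the sketch below assumes a simple linear schedule behind the same two calls the trainer uses, get_temperature() and step_forward():

class LinearAnnealingScheduler:
    """Linear annealing sketch; the real AnnealingScheduler may differ."""

    def __init__(self, start: float, end: float, warmup_steps: int, total_steps: int):
        self.start, self.end = start, end
        self.warmup_steps, self.total_steps = warmup_steps, total_steps
        self.step_count = 0

    def get_temperature(self) -> float:
        # Hold the start value during warmup, then move linearly toward the end value.
        if self.step_count < self.warmup_steps:
            return self.start
        span = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (self.step_count - self.warmup_steps) / span)
        return self.start + (self.end - self.start) * progress

    def step_forward(self) -> None:
        self.step_count += 1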
@@ -217,25 +217,25 @@ class KTOTrainer(SLTrainer):
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())

- if i % self.accumulation_steps == self.accumulation_steps - 1:
- self.num_train_step += 1
+ if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
# logging
if self.writer and is_rank_0():
- self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
- self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ global_step = (self.num_train_step + 1) / self.accumulation_steps
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
- "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
+ global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
+ global_step,
)
self.accumulative_meter.reset()

@@ -256,6 +256,7 @@ class KTOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
+ self.num_train_step += 1

step_bar.close()

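The KTO hunks above and the ORPO/reward-model hunks below all switch TensorBoard logging from the raw micro-step counter to a global step derived from the accumulation factor, so curves stay comparable when accumulation_steps changes. A standalone illustration of that bookkeeping (numbers are arbitrary):

accumulation_steps = 4
num_micro_batches = 12

num_train_step = 0
for i in range(num_micro_batches):
    # ... forward/backward on one micro-batch; gradients accumulate ...
    if num_train_step % accumulation_steps == accumulation_steps - 1:
        # optimizer.step(); optimizer.zero_grad(); scheduler.step()
        global_step = (num_train_step + 1) / accumulation_steps
        print(f"log at global_step={global_step}")  # 1.0, 2.0, 3.0
    num_train_step += 1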
@@ -184,35 +184,35 @@ class ORPOTrainer(SLTrainer):
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())

- if i % self.accumulation_steps == self.accumulation_steps - 1:
- self.num_train_step += 1
+ if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
step_bar.update()
+ global_step = (self.num_train_step + 1) / self.accumulation_steps
# logging
if self.writer and is_rank_0():
- self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+ self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
- self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+ self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
self.writer.add_scalar(
- "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
+ "train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), global_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
+ global_step,
)
self.writer.add_scalar(
"train/margin",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
- self.num_train_step,
+ global_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
- self.num_train_step,
+ global_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
- self.num_train_step,
+ global_step,
)
self.accumulative_meter.reset()

@@ -233,6 +233,7 @@ class ORPOTrainer(SLTrainer):
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.save_interval} at folder {self.save_dir}"
)
+ self.num_train_step += 1

step_bar.close()

@@ -3,13 +3,13 @@ PPO trainer
"""

import os
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Union

import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
- from coati.models import Critic, RewardModel
+ from coati.models import Critic, RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
@@ -84,7 +84,7 @@ class PPOTrainer(OLTrainer):
critic_booster: Booster,
actor: PreTrainedModel,
critic: Critic,
- reward_model: RewardModel,
+ reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
critic_optim: Optimizer,
@@ -210,6 +210,7 @@ class PPOTrainer(OLTrainer):
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
+ gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)

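The gt_answer field threaded into make_experience above is what the RLVRRewardModel path consumes: the reward comes from a verifiable check against a ground-truth answer instead of a learned scalar head. A rough sketch of such a rule-based reward; the exact-match rule and the function name are assumptions, not the repository's implementation:

import torch

def verifiable_reward(generated_texts, gt_answers, correct=1.0, wrong=0.0) -> torch.Tensor:
    # Reward 1.0 when the ground-truth answer appears in the generated text, else 0.0.
    rewards = [correct if gt.strip() in gen else wrong for gen, gt in zip(generated_texts, gt_answers)]
    return torch.tensor(rewards)

print(verifiable_reward(["the answer is 42", "no idea"], ["42", "7"]))  # tensor([1., 0.])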
@@ -219,7 +220,6 @@ class PPOTrainer(OLTrainer):
experience:
sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
"""
- self.num_train_step += 1
self.actor.train()
self.critic.train()
num_actions = experience.action_log_probs.size(1)
@@ -293,7 +293,7 @@ class PPOTrainer(OLTrainer):
self.critic_scheduler.step()

# preparing logging model output and corresponding rewards.
- if self.num_train_step % 10 == 1:
+ if self.num_train_step % 10 == 0:
response_text = self.experience_maker.tokenizer.batch_decode(
experience.sequences, skip_special_tokens=True
)
@ -335,6 +335,7 @@ class PPOTrainer(OLTrainer):
|
|||||||
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
|
self.writer.add_scalar("value", self.accumulative_meter.get("value"), self.num_train_step)
|
||||||
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
|
self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), self.num_train_step)
|
||||||
self.accumulative_meter.reset()
|
self.accumulative_meter.reset()
|
||||||
|
self.num_train_step += 1
|
||||||
|
|
||||||
def _learn(self, update_step: int):
|
def _learn(self, update_step: int):
|
||||||
"""
|
"""
|
||||||
|
@ -150,29 +150,29 @@ class RewardModelTrainer(SLTrainer):
|
|||||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||||
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
|
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
|
||||||
|
|
||||||
if (i + 1) % self.accumulation_steps == 0:
|
if (self.num_train_step + 1) % self.accumulation_steps == 0:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.actor_scheduler.step()
|
self.actor_scheduler.step()
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
self.num_train_step += 1
|
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
if self.writer and is_rank_0():
|
if self.writer and is_rank_0():
|
||||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
global_step = (self.num_train_step + 1) / self.accumulation_steps
|
||||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
|
||||||
|
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], global_step)
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"train/dist",
|
"train/dist",
|
||||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||||
self.num_train_step,
|
global_step,
|
||||||
)
|
)
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), global_step
|
||||||
)
|
)
|
||||||
self.writer.add_scalar(
|
self.writer.add_scalar(
|
||||||
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
|
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), global_step
|
||||||
)
|
)
|
||||||
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
|
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), global_step)
|
||||||
|
|
||||||
self.accumulative_meter.reset()
|
self.accumulative_meter.reset()
|
||||||
|
|
||||||
@ -193,6 +193,7 @@ class RewardModelTrainer(SLTrainer):
|
|||||||
self.coordinator.print_on_master(
|
self.coordinator.print_on_master(
|
||||||
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
|
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
|
||||||
)
|
)
|
||||||
|
self.num_train_step += 1
|
||||||
step_bar.close()
|
step_bar.close()
|
||||||
|
|
||||||
def _eval(self, epoch):
|
def _eval(self, epoch):
|
||||||
|
@ -143,18 +143,18 @@ class SFTTrainer(SLTrainer):
|
|||||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||||
|
|
||||||
# Gradient accumulation
|
# Gradient accumulation
|
||||||
if (i + 1) % self.accumulation_steps == 0:
|
if (self.num_train_step + 1) % self.accumulation_steps == 0:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
global_step = (self.num_train_step + 1) / self.accumulation_steps
|
||||||
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
|
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
|
||||||
if self.writer:
|
if self.writer:
|
||||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), global_step)
|
||||||
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
|
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], global_step)
|
||||||
self.num_train_step += 1
|
|
||||||
self.accumulative_meter.reset()
|
self.accumulative_meter.reset()
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
|
self.num_train_step += 1
|
||||||
|
|
||||||
# Save checkpoint
|
# Save checkpoint
|
||||||
if (
|
if (
|
||||||
|
@ -12,6 +12,27 @@ from torch.utils.data import DataLoader
|
|||||||
from colossalai.booster import Plugin
|
from colossalai.booster import Plugin
|
||||||
|
|
||||||
|
|
||||||
|
class AnnealingScheduler:
|
||||||
|
def __init__(self, start, end, warmup_steps=100, annealing_step=2000):
|
||||||
|
self.start = start
|
||||||
|
self.end = end
|
||||||
|
self.warmup_steps = warmup_steps
|
||||||
|
self.step = 0
|
||||||
|
self.annealing_step = annealing_step
|
||||||
|
|
||||||
|
def get_temperature(self):
|
||||||
|
if self.step <= self.warmup_steps:
|
||||||
|
return self.start  # Use the start temperature during warm-up
|
||||||
|
elif self.step >= self.annealing_step:
|
||||||
|
return self.end
|
||||||
|
# Linear annealing
|
||||||
|
temp = self.start - (self.step / self.annealing_step) * (self.start - self.end)
|
||||||
|
return temp
|
||||||
|
|
||||||
|
def step_forward(self):
|
||||||
|
self.step += 1
|
||||||
|
|
||||||
|
|
||||||
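# Usage sketch (illustrative only, not part of the original file): the temperature
# annealing is driven by stepping the scheduler once per training step, e.g.
#   scheduler = AnnealingScheduler(start=1.0, end=0.7)
#   temperature = scheduler.get_temperature()
#   scheduler.step_forward()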
class CycledDataLoader:
|
class CycledDataLoader:
|
||||||
"""
|
"""
|
||||||
A data loader that cycles through the data when it reaches the end.
|
A data loader that cycles through the data when it reaches the end.
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
from .competition import math_competition_reward_fn
|
||||||
|
from .gsm8k import gsm8k_reward_fn
|
||||||
|
|
||||||
|
__all__ = ["gsm8k_reward_fn", "math_competition_reward_fn"]
|
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import extract_solution, validate_response_structure
|
||||||
|
|
||||||
|
|
||||||
|
def math_competition_reward_fn(input_ids, attention_mask, **kwargs):
|
||||||
|
# apply verifiable reward
|
||||||
|
# reward 10 points if the final answer is correct, reward 1 point if format is correct
|
||||||
|
|
||||||
|
gt_answer = kwargs["gt_answer"]
|
||||||
|
tokenizer = kwargs["tokenizer"]
|
||||||
|
s, e = kwargs["response_start"], kwargs["response_end"]
|
||||||
|
reward = torch.tensor(0.0).to(input_ids.device)
|
||||||
|
if gt_answer is None:
|
||||||
|
return reward
|
||||||
|
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||||
|
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||||
|
|
||||||
|
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||||
|
if not format_valid:
|
||||||
|
return reward
|
||||||
|
else:
|
||||||
|
reward += 1.0
|
||||||
|
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
||||||
|
reward = reward + 9.0
|
||||||
|
return reward
|
applications/ColossalChat/coati/utils/reward_score/gsm8k.py (new file, 31 lines)
@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import extract_solution, validate_response_structure
|
||||||
|
|
||||||
|
|
||||||
|
def gsm8k_reward_fn(input_ids, attention_mask, **kwargs):
|
||||||
|
# apply verifiable reward
|
||||||
|
# reward 10 points if the final answer is correct, reward 1 point if format is correct
|
||||||
|
|
||||||
|
gt_answer = kwargs["gt_answer"]
|
||||||
|
tokenizer = kwargs["tokenizer"]
|
||||||
|
s, e = kwargs["response_start"], kwargs["response_end"]
|
||||||
|
reward = torch.tensor(0.0).to(input_ids.device)
|
||||||
|
if gt_answer is None:
|
||||||
|
return reward
|
||||||
|
decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
|
||||||
|
final_answer, processed_str = extract_solution(decoded_final_answer)
|
||||||
|
is_valid = True
|
||||||
|
try:
|
||||||
|
int(final_answer.strip())
|
||||||
|
except Exception:
|
||||||
|
is_valid = False
|
||||||
|
|
||||||
|
format_valid = validate_response_structure(processed_str, kwargs["tags"])
|
||||||
|
if not is_valid or not format_valid:
|
||||||
|
return reward
|
||||||
|
else:
|
||||||
|
reward += 1.0
|
||||||
|
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
|
||||||
|
reward = reward + 9.0
|
||||||
|
return reward
|
applications/ColossalChat/coati/utils/reward_score/utils.py (new file, 76 lines)
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright Unakar
|
||||||
|
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
|
||||||
|
"""Performs comprehensive validation of response structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processed_str: Processed response string from the model
tags: Dict describing each required tag and how many times it must occur
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Boolean indicating whether all formatting requirements are met
|
||||||
|
"""
|
||||||
|
validation_passed = True
|
||||||
|
# Check required tags
|
||||||
|
if tags is None:
|
||||||
|
tags = {
|
||||||
|
"think_start": {"text": "<think>", "num_occur": 1},
|
||||||
|
"think_end": {"text": "</think>", "num_occur": 1},
|
||||||
|
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||||
|
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||||
|
}
|
||||||
|
positions = {}
|
||||||
|
for tag_name, tag_info in tags.items():
|
||||||
|
tag_str = tag_info["text"]
|
||||||
|
expected_count = tag_info["num_occur"]
|
||||||
|
count = processed_str.count(tag_str)
|
||||||
|
positions[tag_name] = pos = processed_str.find(tag_str)
|
||||||
|
if count != expected_count:
|
||||||
|
validation_passed = False
|
||||||
|
# Verify tag order
|
||||||
|
if (
|
||||||
|
positions["think_start"] > positions["think_end"]
|
||||||
|
or positions["think_end"] > positions["answer_start"]
|
||||||
|
or positions["answer_start"] > positions["answer_end"]
|
||||||
|
):
|
||||||
|
validation_passed = False
|
||||||
|
if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
|
||||||
|
validation_passed = False
|
||||||
|
return validation_passed
|
||||||
|
|
||||||
|
|
||||||
|
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
|
||||||
|
"""Extracts the final answer from the model's response string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
solution_str: Raw response string from the language model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple containing (extracted_answer, processed_string)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Extract final answer using XML-style tags
|
||||||
|
answer_pattern = r"<answer>(.*?)</answer>"
|
||||||
|
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return None, solution_str
|
||||||
|
|
||||||
|
final_answer = matches[-1].group(1).strip()
|
||||||
|
return final_answer, solution_str
|
@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"system_message": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
||||||
|
"stop_ids": [
|
||||||
|
122753
|
||||||
|
],
|
||||||
|
"end_of_assistant": "<|im_end|>"
|
||||||
|
}
|
@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"system_message": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n",
|
||||||
|
"stop_ids": [
|
||||||
|
151643
|
||||||
|
],
|
||||||
|
"end_of_assistant": "<|endoftext|>",
|
||||||
|
"response_format_tags": {
|
||||||
|
"think_start": {
|
||||||
|
"text": "<think>",
|
||||||
|
"num_occur": 1
|
||||||
|
},
|
||||||
|
"think_end": {
|
||||||
|
"text": "</think>",
|
||||||
|
"num_occur": 1
|
||||||
|
},
|
||||||
|
"answer_start": {
|
||||||
|
"text": "<answer>",
|
||||||
|
"num_occur": 1
|
||||||
|
},
|
||||||
|
"answer_end": {
|
||||||
|
"text": "</answer>",
|
||||||
|
"num_occur": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
|
|
||||||
- [Examples](#examples)
|
- [Examples](#examples)
|
||||||
- [Table of Contents](#table-of-contents)
|
- [Table of Contents](#table-of-contents)
|
||||||
- [Install Requirements](#install-requirements)
|
- [Install Requirements](#install-requirements)
|
||||||
@ -27,13 +25,14 @@
|
|||||||
- [Reward](#reward)
|
- [Reward](#reward)
|
||||||
- [KL Divergence](#approximate-kl-divergence)
|
- [KL Divergence](#approximate-kl-divergence)
|
||||||
- [Note on PPO Training](#note-on-ppo-training)
|
- [Note on PPO Training](#note-on-ppo-training)
|
||||||
|
- [GRPO Training and DeepSeek R1 reproduction](#grpo-training-and-deepseek-r1-reproduction)
|
||||||
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
|
- [Alternative Option For RLHF: Direct Preference Optimization](#alternative-option-for-rlhf-direct-preference-optimization)
|
||||||
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
|
- [DPO Stage 1: Supervised Instruction Tuning](#dpo-training-stage1---supervised-instructs-tuning)
|
||||||
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
|
- [DPO Stage 2: DPO Training](#dpo-training-stage2---dpo-training)
|
||||||
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
|
- [Alternative Option For RLHF: Simple Preference Optimization](#alternative-option-for-rlhf-simple-preference-optimization)
|
||||||
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
- [Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
|
||||||
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
|
- [Alternative Option For RLHF: Odds Ratio Preference Optimization](#alternative-option-for-rlhf-odds-ratio-preference-optimization)
|
||||||
- [List of Supported Models](#list-of-supported-models)
|
- [SFT for DeepSeek V3](#sft-for-deepseek-v3)
|
||||||
- [Hardware Requirements](#hardware-requirements)
|
- [Hardware Requirements](#hardware-requirements)
|
||||||
- [Inference example](#inference-example)
|
- [Inference example](#inference-example)
|
||||||
- [Attention](#attention)
|
- [Attention](#attention)
|
||||||
@ -725,10 +724,69 @@ Answer: The causes of this problem are two-fold. Check your reward model, make s
|
|||||||
#### Q4: Generation is garbage
|
#### Q4: Generation is garbage
|
||||||
Answer: Yes, this happens and is well documented by other implementations. After training for too many episodes, the actor gradually deviates from its original state, which may lead to a decrease in language modeling capabilities. A way to fix this is to add a supervised loss during PPO: set ptx_coef to a non-zero value (between 0 and 1), which balances the PPO loss and the SFT loss (see the sketch below for one common weighting).
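A hedged sketch of one common way such a weighting is applied; this illustrates the idea and is not necessarily the trainer's exact formula:

```python
def combine_actor_loss(ppo_actor_loss: float, sft_loss: float, ptx_coef: float) -> float:
    # One common convention: convex combination of the SFT (pretraining) loss
    # and the PPO actor loss. Illustrative only; the trainer's exact mix may differ.
    return ptx_coef * sft_loss + (1.0 - ptx_coef) * ppo_actor_loss


print(combine_actor_loss(ppo_actor_loss=0.8, sft_loss=2.0, ptx_coef=0.1))  # 0.92
```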
|
|
||||||
|
## GRPO Training and DeepSeek R1 reproduction
|
||||||
|
We support GRPO (Group Relative Policy Optimization), the reinforcement learning algorithm used in the DeepSeek R1 paper. In this section, we walk through GRPO training with an example that aims to reproduce DeepSeek R1's results on mathematical problem solving.
|
||||||
|
|
||||||
|
**Note: Currently, our PPO and GRPO pipelines are still under extensive development (integration with Ray and the inference engine). The speed is primarily limited by the rollout process, as we are using a naive generation approach without any acceleration. This experiment focuses solely on verifying the correctness of the GRPO algorithm. We will open-source the new version of the code as soon as possible, so please stay tuned.**
|
||||||
|
|
||||||
|
### GRPO Model Selection
|
||||||
|
We finally selected the base version of [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B). We also ran experiments on the instruct version [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct), but the latter fails to explore more diverse outputs. We recommend using base models (without SFT) and applying a few SFT steps (see [SFT section](#rlhf-training-stage1---supervised-instructs-tuning)) to correct the base model's output format before GRPO.
|
||||||
|
|
||||||
|
### Reinforcement Learning with Verifiable Reward
|
||||||
|
Both PPO and GRPO support reinforcement learning with verifiable reward (RLVR). In this experiment on mathematical problem solving, we define the reward function as follows; the format is considered correct if the response contains exactly one pair of <think></think> and <answer></answer> tags, and the tags appear in the correct order (a small sketch of this scheme follows the list).

- reward=0, if the format is incorrect.
- reward=1, if the format is correct but the answer does not exactly match the ground truth answer.
- reward=10, if the format is correct and the answer exactly matches the ground truth answer.
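For illustration, here is a minimal, self-contained sketch of this reward scheme. It is not the trainer's implementation (that lives in `coati/utils/reward_score/`, shown elsewhere in this diff) and uses a simplified format check:

```python
import re


def _normalize(s: str) -> str:
    return s.strip().replace(" ", "").lower()


def simple_rlvr_reward(response: str, gt_answer: str) -> float:
    """Minimal sketch of the reward scheme above: 1 point for format, +9 for a correct answer."""
    thinks = list(re.finditer(r"<think>(.*?)</think>", response, re.DOTALL))
    answers = list(re.finditer(r"<answer>(.*?)</answer>", response, re.DOTALL))
    # Format reward: exactly one pair of each tag, with <think> appearing before <answer>.
    if len(thinks) != 1 or len(answers) != 1 or thinks[0].start() > answers[0].start():
        return 0.0
    reward = 1.0
    if _normalize(answers[0].group(1)) == _normalize(gt_answer):
        reward += 9.0  # the extracted answer matches the ground truth exactly
    return reward


print(simple_rlvr_reward("<think>2 + 2 = 4</think><answer>4</answer>", "4"))  # 10.0
```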
|
|
||||||
|
### Step 1: Data Collection & Preparation
|
||||||
|
For GRPO training, you only need the prompt dataset. Please follow the instructions in the [prompt dataset preparation](#rlhf-training-stage3---proximal-policy-optimization) section to prepare the prompt data for GRPO training. In our reproduction experiment, we use the [qwedsacf/competition_math dataset](https://huggingface.co/datasets/qwedsacf/competition_math), which is available on Hugging Face.
|
||||||
|
|
||||||
|
### Step 2: Training
|
||||||
|
You can run [train_grpo.sh](./training_scripts/train_grpo.sh) to start GRPO training. The script shares most of its arguments with the PPO script (please refer to the [PPO training section](#step-3-training) for more details). Here are some arguments unique to GRPO.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--num_generations 8 \ # number of roll outs to collect for each prompt
|
||||||
|
--inference_batch_size 8 \ # batch size used during roll out
|
||||||
|
--logits_forward_batch_size 1 \ # batch size used to calculate logits for GRPO training
|
||||||
|
--initial_temperature \ # initial temperature for annealing algorithm
|
||||||
|
--final_temperature \ # final temperature for annealing algorithm
|
||||||
|
```
|
||||||
|
|
||||||
|
As GRPO requires collecting a group of responses for each prompt (usually at least 8), the effective batch size must satisfy the following constraints (a small worked example follows the two cases below).
|
||||||
|
|
||||||
|
- Without tensor parallelism,
|
||||||
|
```
|
||||||
|
experience buffer size
|
||||||
|
= num_process * num_collect_steps * experience_batch_size * num_generations
|
||||||
|
= train_batch_size * accumulation_steps * num_process
|
||||||
|
```
|
||||||
|
|
||||||
|
- With tensor parallelism,
|
||||||
|
```
|
||||||
|
num_tp_group = num_process / tp
|
||||||
|
experience buffer size
|
||||||
|
= num_tp_group * num_collect_steps * experience_batch_size * num_generations
|
||||||
|
= train_batch_size * accumulation_steps * num_tp_group
|
||||||
|
```
|
||||||
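As a quick sanity check with illustrative numbers (not recommended settings), the pieces line up like this in the case without tensor parallelism:

```python
# Illustrative numbers only, for the case without tensor parallelism.
num_process, num_collect_steps, experience_batch_size, num_generations = 4, 1, 2, 8
buffer_size = num_process * num_collect_steps * experience_batch_size * num_generations  # 64
train_batch_size, accumulation_steps = 4, 4
assert buffer_size == train_batch_size * accumulation_steps * num_process  # 4 * 4 * 4 == 64
```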
|
|
||||||
|
During rollout, we perform rebatching both before rollout and before calculating logits to prevent out-of-memory errors. Please choose proper settings for `inference_batch_size` and `logits_forward_batch_size` based on your device.
|
||||||
|
|
||||||
|
### GRPO Result
|
||||||
|
#### Reward and Response Length
|
||||||
|
<div style="display: flex; justify-content: space-between;">
|
||||||
|
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/reward.png" style="width: 48%;" />
|
||||||
|
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost.png" style="width: 48%;" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
#### Response Length Distribution (After Training) and Sample response
|
||||||
|
<div style="display: flex; justify-content: space-between;">
|
||||||
|
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/token_cost_eval.png" style="width: 48%;" />
|
||||||
|
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/grpo/sample.png" style="width: 48%;" />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
## Alternative Option For RLHF: Direct Preference Optimization
|
## Alternative Option For RLHF: Direct Preference Optimization
|
||||||
|
|
||||||
|
|
||||||
For those seeking an alternative to Reinforcement Learning from Human Feedback (RLHF), Direct Preference Optimization (DPO) presents a compelling option. DPO, as detailed in the paper available at [https://arxiv.org/abs/2305.18290](https://arxiv.org/abs/2305.18290), offers a low-cost way to perform RLHF and usually requires fewer computational resources than PPO.
|
|
||||||
|
|
||||||
@ -814,8 +872,95 @@ For training, use the [train_kto.sh](./examples/training_scripts/train_orpo.sh)
|
|||||||
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
|
<img width="1000" alt="image" src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/KTO.png">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
## Hardware Requirements
|
|
||||||
|
|
||||||
|
### SFT for DeepSeek V3
|
||||||
|
We add a script to supervised fine-tune the DeepSeek V3/R1 model with LoRA. The script is located at `examples/training_scripts/lora_finetune.py`. The script is similar to the SFT script for Coati7B, but with a few differences. This script is compatible with PEFT.
|
||||||
|
|
||||||
|
#### Dataset preparation
|
||||||
|
|
||||||
|
This script takes a JSONL file as the input dataset. Each line of the dataset should be a list of chat messages. E.g.
|
||||||
|
```json
|
||||||
|
[{"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great. How can I help you today?"}]
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
[{"role": "user", "content": "火烧赤壁 曹操为何不拨打119求救?"}, {"role": "assistant", "content": "因为在三国时期,还没有电话和现代的消防系统,所以曹操无法拨打119求救。"}]
|
||||||
|
```
|
||||||
|
|
||||||
|
The dialogue can span multiple turns and may contain a system prompt. For more details, see the [chat_templating](https://huggingface.co/docs/transformers/main/chat_templating) documentation; a minimal rendering sketch is shown below.
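For reference, a single JSONL line in this format can be rendered into plain text with the tokenizer's chat template, roughly as follows (a minimal sketch; the tokenizer name is only an example and should match your base model):

```python
from transformers import AutoTokenizer

# Example tokenizer; substitute the one that matches your base model.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)

messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
]
# Render one dataset entry into the text the model is fine-tuned on.
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)
```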
|
||||||
|
|
||||||
|
#### Model weights preparation
|
||||||
|
|
||||||
|
We use bf16 weights for fine-tuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use the [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert the weights to bf16 on a GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).
|
||||||
|
|
||||||
|
We have also added details on how to load and run inference with LoRA models.
|
||||||
|
```python
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
)
|
||||||
|
from peft import (
|
||||||
|
PeftModel
|
||||||
|
)
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Set model path
|
||||||
|
model_name = "Qwen/Qwen2.5-3B"
|
||||||
|
lora_adapter = "Qwen2.5-3B_lora" # Your lora model Path
|
||||||
|
merged_model_path = "Qwen2.5-3B_merged"
|
||||||
|
|
||||||
|
######
|
||||||
|
# How to Load lora Model
|
||||||
|
######
|
||||||
|
# 1.Load base model
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2.Load lora model
|
||||||
|
peft_model = PeftModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
lora_adapter,
|
||||||
|
torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3.Merge lora model
|
||||||
|
merged_model = peft_model.merge_and_unload()
|
||||||
|
|
||||||
|
# 4.Load tokenizer
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
pad_token="<|endoftext|>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5.Save merged lora model
|
||||||
|
merged_model.save_pretrained(
|
||||||
|
merged_model_path,
|
||||||
|
safe_serialization=True
|
||||||
|
)
|
||||||
|
tokenizer.save_pretrained(merged_model_path)
|
||||||
|
|
||||||
|
# 6.Run Inference
|
||||||
|
test_input = tokenizer("Instruction: Finding prime numbers up to 100\nAnswer:", return_tensors="pt").to("cuda")
|
||||||
|
output = merged_model.generate(**test_input, max_new_tokens=100)
|
||||||
|
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
After preparing the dataset and model weights, you can run the script with the following command:
|
||||||
|
```bash
|
||||||
|
colossalai run --hostfile path-to-host-file --nproc_per_node 8 lora_finetune.py --pretrained path-to-DeepSeek-R1-bf16 --dataset path-to-dataset.jsonl --plugin moe --lr 2e-5 --max_length 256 -g --ep 8 --pp 3 --batch_size 24 --lora_rank 8 --lora_alpha 16 --num_epochs 2 --warmup_steps 8 --tensorboard_dir logs --save_dir DeepSeek-R1-bf16-lora
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details of each argument, you can run `python lora_finetune.py --help`.
|
||||||
|
|
||||||
|
The sample command does not use CPU offload, in order to achieve better throughput. The minimum hardware requirement for the sample command is 32 Ascend 910B NPUs (with `ep=8,pp=4`) or 24 H100/H800 GPUs (with `ep=8,pp=3`). If you enable CPU offload with `--zero_cpu_offload`, the hardware requirement can be further reduced.
|
||||||
|
|
||||||
|
## Hardware Requirements
|
||||||
For SFT, we recommend using zero2 or zero2-cpu for a 7B model, and tensor parallelism (tp) if your model is extra large. We tested the VRAM consumption on a dummy dataset with a sequence length of 2048. In all experiments, we use H800 GPUs with 80GB VRAM and enable gradient checkpointing and flash attention.
- 2 H800 GPU
|
- 2 H800 GPU
|
||||||
- zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
|
- zero2-cpu, micro batch size=4, VRAM Usage=22457.98 MB
|
||||||
@ -872,35 +1017,9 @@ For KTO, we recommend using zero2-cpu or zero2 plugin, We tested the VRAM consum
|
|||||||
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
|
- zero2_cpu, micro batch size=2, VRAM_USAGE=32443.22 MB
|
||||||
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
|
- zero2, micro batch size=4, VRAM_USAGE=59307.97 MB
|
||||||
|
|
||||||
## List of Supported Models
|
|
||||||
|
|
||||||
For SFT, we support the following models/series:
|
|
||||||
- Colossal-LLaMA-2
|
|
||||||
- ChatGLM2
|
|
||||||
- ChatGLM3 (only with zero2, zero2_cpu plugin)
|
|
||||||
- Baichuan2
|
|
||||||
- LLaMA2
|
|
||||||
- Qwen1.5-7B-Chat (with transformers==4.39.1)
|
|
||||||
- Yi-1.5
|
|
||||||
|
|
||||||
For PPO and DPO, we theoretically support the following models/series (without guarantee):
|
|
||||||
- Colossal-LLaMA-2 (tested)
|
|
||||||
- ChatGLM2
|
|
||||||
- Baichuan2
|
|
||||||
- LLaMA2 (tested)
|
|
||||||
- Qwen1.5-7B-Chat (with transformers==4.39.1)
|
|
||||||
- Yi-1.5
|
|
||||||
|
|
||||||
*-* The zero2, zero2_cpu plugin also support a wide range of chat models not listed above.
|
|
||||||
|
|
||||||
## Inference example
|
## Inference example
|
||||||
|
|
||||||
|
|
||||||
We support different inference options, including int8 and int4 quantization.
|
We support different inference options, including int8 and int4 quantization.
|
||||||
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
|
For details, see [`inference/`](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/inference).
|
||||||
|
|
||||||
|
|
||||||
## Attention
|
## Attention
|
||||||
|
|
||||||
|
|
||||||
The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
|
The examples are demos for the whole training process. You need to change the hyper-parameters to reach great performance.
|
||||||
|
@ -73,8 +73,7 @@ def main():
|
|||||||
"--conversation_template_config",
|
"--conversation_template_config",
|
||||||
type=str,
|
type=str,
|
||||||
default="conversation_template_config",
|
default="conversation_template_config",
|
||||||
help="Path \
|
help="Path to save conversation template config files.",
|
||||||
to save conversation template config files.",
|
|
||||||
)
|
)
|
||||||
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
|
parser.add_argument("--data_cache_dir", type=str, default="cache", help="Data cache directory")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -11,4 +11,4 @@ python prepare_dataset.py --type prompt \
|
|||||||
--data_cache_dir $SAVE_DIR/cache \
|
--data_cache_dir $SAVE_DIR/cache \
|
||||||
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
--data_jsonl_output_dir $SAVE_DIR/jsonl \
|
||||||
--data_arrow_output_dir $SAVE_DIR/arrow \
|
--data_arrow_output_dir $SAVE_DIR/arrow \
|
||||||
--max_length 1024
|
--max_length 300
|
||||||
|
@ -1,181 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import ray
|
|
||||||
from coati.quant import llama_load_quant, low_resource_init
|
|
||||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
|
||||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
|
||||||
from coati.ray.utils import (
|
|
||||||
get_actor_from_args,
|
|
||||||
get_critic_from_args,
|
|
||||||
get_reward_model_from_args,
|
|
||||||
get_strategy_from_args,
|
|
||||||
get_tokenizer_from_args,
|
|
||||||
)
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from transformers import AutoConfig
|
|
||||||
from transformers.modeling_utils import no_init_weights
|
|
||||||
|
|
||||||
|
|
||||||
def get_free_port():
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", 0))
|
|
||||||
return s.getsockname()[1]
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_ip():
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
|
||||||
s.connect(("8.8.8.8", 80))
|
|
||||||
return s.getsockname()[0]
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
master_addr = str(get_local_ip())
|
|
||||||
# trainer_env_info
|
|
||||||
trainer_port = str(get_free_port())
|
|
||||||
env_info_trainers = [
|
|
||||||
{
|
|
||||||
"local_rank": "0",
|
|
||||||
"rank": str(rank),
|
|
||||||
"world_size": str(args.num_trainers),
|
|
||||||
"master_port": trainer_port,
|
|
||||||
"master_addr": master_addr,
|
|
||||||
}
|
|
||||||
for rank in range(args.num_trainers)
|
|
||||||
]
|
|
||||||
|
|
||||||
# maker_env_info
|
|
||||||
maker_port = str(get_free_port())
|
|
||||||
env_info_maker = {
|
|
||||||
"local_rank": "0",
|
|
||||||
"rank": "0",
|
|
||||||
"world_size": "1",
|
|
||||||
"master_port": maker_port,
|
|
||||||
"master_addr": master_addr,
|
|
||||||
}
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
tokenizer = get_tokenizer_from_args(args.model)
|
|
||||||
|
|
||||||
def trainer_model_fn():
|
|
||||||
actor = get_actor_from_args(args.model, args.pretrain).half().cuda()
|
|
||||||
critic = get_critic_from_args(args.model, args.critic_pretrain).half().cuda()
|
|
||||||
return actor, critic
|
|
||||||
|
|
||||||
# configure Trainer
|
|
||||||
trainer_refs = [
|
|
||||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
|
||||||
experience_maker_holder_name_list=["maker1"],
|
|
||||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
|
||||||
model_fn=trainer_model_fn,
|
|
||||||
env_info=env_info_trainer,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
buffer_limit=16,
|
|
||||||
eval_performance=True,
|
|
||||||
debug=args.debug,
|
|
||||||
update_lora_weights=not (args.lora_rank == 0),
|
|
||||||
)
|
|
||||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
|
||||||
]
|
|
||||||
|
|
||||||
def model_fn():
|
|
||||||
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
|
||||||
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
|
||||||
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
|
||||||
if args.initial_model_quant_ckpt is not None and args.model == "llama":
|
|
||||||
# quantize initial model
|
|
||||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
|
||||||
with low_resource_init(), no_init_weights():
|
|
||||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
|
||||||
initial_model.model = (
|
|
||||||
llama_load_quant(
|
|
||||||
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
|
|
||||||
)
|
|
||||||
.cuda()
|
|
||||||
.requires_grad_(False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
|
||||||
return actor, critic, reward_model, initial_model
|
|
||||||
|
|
||||||
# configure Experience Maker
|
|
||||||
experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
|
|
||||||
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
|
|
||||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
|
||||||
model_fn=model_fn,
|
|
||||||
env_info=env_info_maker,
|
|
||||||
experience_batch_size=args.experience_batch_size,
|
|
||||||
kl_coef=0.1,
|
|
||||||
debug=args.debug,
|
|
||||||
update_lora_weights=not (args.lora_rank == 0),
|
|
||||||
# sync_models_from_trainers=True,
|
|
||||||
# generation kwargs:
|
|
||||||
max_length=512,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
eval_performance=True,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# uncomment this function if sync_models_from_trainers is True
|
|
||||||
# ray.get([
|
|
||||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
|
||||||
# for trainer_ref in trainer_refs
|
|
||||||
# ])
|
|
||||||
|
|
||||||
wait_tasks = []
|
|
||||||
|
|
||||||
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
|
|
||||||
for trainer_ref in trainer_refs:
|
|
||||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
|
||||||
|
|
||||||
dataset_size = args.experience_batch_size * 4
|
|
||||||
|
|
||||||
def build_dataloader():
|
|
||||||
def tokenize_fn(texts):
|
|
||||||
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
|
|
||||||
return {k: v.cuda() for k, v in batch.items()}
|
|
||||||
|
|
||||||
dataset = pd.read_csv(args.prompt_path)["prompt"]
|
|
||||||
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
|
|
||||||
|
|
||||||
ray.get(wait_tasks)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--prompt_path", type=str, default=None)
|
|
||||||
parser.add_argument("--num_trainers", type=int, default=1)
|
|
||||||
parser.add_argument(
|
|
||||||
"--trainer_strategy",
|
|
||||||
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
|
|
||||||
default="ddp",
|
|
||||||
)
|
|
||||||
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
|
|
||||||
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
|
||||||
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
|
||||||
parser.add_argument("--pretrain", type=str, default=None)
|
|
||||||
parser.add_argument("--critic_pretrain", type=str, default=None)
|
|
||||||
parser.add_argument("--experience_steps", type=int, default=4)
|
|
||||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
|
||||||
parser.add_argument("--train_epochs", type=int, default=1)
|
|
||||||
parser.add_argument("--update_steps", type=int, default=2)
|
|
||||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
|
|
||||||
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
|
|
||||||
parser.add_argument("--quant_bits", type=int, default=4)
|
|
||||||
parser.add_argument("--quant_group_size", type=int, default=128)
|
|
||||||
parser.add_argument("--debug", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
|
||||||
main(args)
|
|
@ -1,201 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import ray
|
|
||||||
from coati.quant import llama_load_quant, low_resource_init
|
|
||||||
from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
|
|
||||||
from coati.ray.experience_maker_holder import ExperienceMakerHolder
|
|
||||||
from coati.ray.utils import (
|
|
||||||
get_actor_from_args,
|
|
||||||
get_critic_from_args,
|
|
||||||
get_receivers_per_sender,
|
|
||||||
get_reward_model_from_args,
|
|
||||||
get_strategy_from_args,
|
|
||||||
)
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
from transformers.modeling_utils import no_init_weights
|
|
||||||
|
|
||||||
|
|
||||||
def get_free_port():
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", 0))
|
|
||||||
return s.getsockname()[1]
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_ip():
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
|
||||||
s.connect(("8.8.8.8", 80))
|
|
||||||
return s.getsockname()[0]
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
master_addr = str(get_local_ip())
|
|
||||||
# trainer_env_info
|
|
||||||
trainer_port = str(get_free_port())
|
|
||||||
env_info_trainers = [
|
|
||||||
{
|
|
||||||
"local_rank": "0",
|
|
||||||
"rank": str(rank),
|
|
||||||
"world_size": str(args.num_trainers),
|
|
||||||
"master_port": trainer_port,
|
|
||||||
"master_addr": master_addr,
|
|
||||||
}
|
|
||||||
for rank in range(args.num_trainers)
|
|
||||||
]
|
|
||||||
|
|
||||||
# maker_env_info
|
|
||||||
maker_port = str(get_free_port())
|
|
||||||
env_info_makers = [
|
|
||||||
{
|
|
||||||
"local_rank": "0",
|
|
||||||
"rank": str(rank),
|
|
||||||
"world_size": str(args.num_makers),
|
|
||||||
"master_port": maker_port,
|
|
||||||
"master_addr": master_addr,
|
|
||||||
}
|
|
||||||
for rank in range(args.num_makers)
|
|
||||||
]
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
|
|
||||||
def model_fn():
|
|
||||||
actor = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
|
||||||
critic = get_critic_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
|
||||||
reward_model = get_reward_model_from_args(args.model, args.critic_pretrain).requires_grad_(False).half().cuda()
|
|
||||||
if args.initial_model_quant_ckpt is not None and args.model == "llama":
|
|
||||||
# quantize initial model
|
|
||||||
actor_cfg = AutoConfig.from_pretrained(args.pretrain)
|
|
||||||
with low_resource_init(), no_init_weights():
|
|
||||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
|
||||||
initial_model.model = (
|
|
||||||
llama_load_quant(
|
|
||||||
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
|
|
||||||
)
|
|
||||||
.cuda()
|
|
||||||
.requires_grad_(False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
initial_model = get_actor_from_args(args.model, args.pretrain).requires_grad_(False).half().cuda()
|
|
||||||
return actor, critic, reward_model, initial_model
|
|
||||||
|
|
||||||
# configure Experience Maker
|
|
||||||
experience_holder_refs = [
|
|
||||||
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
|
|
||||||
detached_trainer_name_list=[
|
|
||||||
f"trainer{x}"
|
|
||||||
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
|
|
||||||
],
|
|
||||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
|
||||||
model_fn=model_fn,
|
|
||||||
env_info=env_info_maker,
|
|
||||||
kl_coef=0.1,
|
|
||||||
debug=args.debug,
|
|
||||||
update_lora_weights=not (args.lora_rank == 0),
|
|
||||||
# sync_models_from_trainers=True,
|
|
||||||
# generation kwargs:
|
|
||||||
max_length=512,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
eval_performance=True,
|
|
||||||
use_cache=True,
|
|
||||||
)
|
|
||||||
for i, env_info_maker in enumerate(env_info_makers)
|
|
||||||
]
|
|
||||||
|
|
||||||
def trainer_model_fn():
|
|
||||||
actor = get_actor_from_args(args.model, args.pretrain, lora_rank=args.lora_rank).half().cuda()
|
|
||||||
critic = get_critic_from_args(args.model, args.critic_pretrain, lora_rank=args.lora_rank).half().cuda()
|
|
||||||
return actor, critic
|
|
||||||
|
|
||||||
# configure Trainer
|
|
||||||
trainer_refs = [
|
|
||||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
|
||||||
experience_maker_holder_name_list=[
|
|
||||||
f"maker{x}"
|
|
||||||
for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
|
|
||||||
],
|
|
||||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
|
||||||
model_fn=trainer_model_fn,
|
|
||||||
env_info=env_info_trainer,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
buffer_limit=16,
|
|
||||||
eval_performance=True,
|
|
||||||
debug=args.debug,
|
|
||||||
update_lora_weights=not (args.lora_rank == 0),
|
|
||||||
)
|
|
||||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
|
||||||
]
|
|
||||||
|
|
||||||
dataset_size = args.experience_batch_size * 4
|
|
||||||
|
|
||||||
def build_dataloader():
|
|
||||||
def tokenize_fn(texts):
|
|
||||||
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
|
|
||||||
return {k: v.cuda() for k, v in batch.items()}
|
|
||||||
|
|
||||||
dataset = pd.read_csv(args.prompt_path)["prompt"]
|
|
||||||
dataloader = DataLoader(dataset=dataset, batch_size=dataset_size, shuffle=True, collate_fn=tokenize_fn)
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# uncomment this function if sync_models_from_trainers is True
|
|
||||||
# ray.get([
|
|
||||||
# trainer_ref.sync_models_to_remote_makers.remote()
|
|
||||||
# for trainer_ref in trainer_refs
|
|
||||||
# ])
|
|
||||||
|
|
||||||
wait_tasks = []
|
|
||||||
|
|
||||||
for experience_holder_ref in experience_holder_refs:
|
|
||||||
wait_tasks.append(experience_holder_ref.workingloop.remote(build_dataloader, num_steps=args.experience_steps))
|
|
||||||
|
|
||||||
total_steps = (
|
|
||||||
args.experience_batch_size
|
|
||||||
* args.experience_steps
|
|
||||||
* args.num_makers
|
|
||||||
// (args.num_trainers * args.train_batch_size)
|
|
||||||
)
|
|
||||||
for trainer_ref in trainer_refs:
|
|
||||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
|
||||||
|
|
||||||
ray.get(wait_tasks)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--prompt_path", type=str, default=None)
|
|
||||||
parser.add_argument("--num_makers", type=int, default=1)
|
|
||||||
parser.add_argument("--num_trainers", type=int, default=1)
|
|
||||||
parser.add_argument(
|
|
||||||
"--trainer_strategy",
|
|
||||||
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
|
|
||||||
default="ddp",
|
|
||||||
)
|
|
||||||
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
|
|
||||||
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
|
||||||
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
|
||||||
parser.add_argument("--pretrain", type=str, default=None)
|
|
||||||
parser.add_argument("--critic_pretrain", type=str, default=None)
|
|
||||||
parser.add_argument("--experience_steps", type=int, default=4)
|
|
||||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
|
||||||
parser.add_argument("--train_epochs", type=int, default=1)
|
|
||||||
parser.add_argument("--update_steps", type=int, default=2)
|
|
||||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
|
|
||||||
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
|
|
||||||
parser.add_argument("--quant_bits", type=int, default=4)
|
|
||||||
parser.add_argument("--quant_group_size", type=int, default=128)
|
|
||||||
parser.add_argument("--debug", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
|
||||||
main(args)
|
|
@ -1 +0,0 @@
|
|||||||
ray
|
|
@ -1,12 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
set -xe
|
|
||||||
BASE=$(realpath $(dirname $0))
|
|
||||||
|
|
||||||
export RAY_NAMESPACE=admin
|
|
||||||
export DATA=/data/scratch/chatgpt/prompts.csv
|
|
||||||
|
|
||||||
# install requirements
|
|
||||||
pip install -r ${BASE}/requirements.txt
|
|
||||||
|
|
||||||
python ${BASE}/mmmt_prompt.py --prompt_path $DATA --num_makers 2 --num_trainers 2 --trainer_strategy colossalai_gemini --model opt --critic_model opt --pretrain facebook/opt-350m --critic_pretrain facebook/opt-125m --experience_batch_size 4 --train_batch_size 2
|
|
@ -1,4 +1,4 @@
|
|||||||
pandas>=1.4.1
|
pandas>=1.4.1
|
||||||
sentencepiece
|
sentencepiece
|
||||||
colossalai==0.4.0
|
colossalai==0.4.7
|
||||||
prompt_toolkit
|
prompt_toolkit
|
||||||
|
@ -0,0 +1,455 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import resource
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from coati.dataset.loader import RawConversationDataset
|
||||||
|
from peft import LoraConfig
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import (
|
||||||
|
GeminiPlugin,
|
||||||
|
HybridParallelPlugin,
|
||||||
|
LowLevelZeroPlugin,
|
||||||
|
MoeHybridParallelPlugin,
|
||||||
|
Plugin,
|
||||||
|
TorchDDPPlugin,
|
||||||
|
)
|
||||||
|
from colossalai.cluster import DistCoordinator
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
|
||||||
|
loss = loss.data
|
||||||
|
group = getattr(plugin, "dp_group", None)
|
||||||
|
dist.all_reduce(loss, group=group)
|
||||||
|
return loss / dist.get_world_size(group)
|
||||||
|
|
||||||
|
|
||||||
|
def train(args) -> None:
|
||||||
|
# ==============================
|
||||||
|
# Initialize Distributed Training
|
||||||
|
# ==============================
|
||||||
|
colossalai.launch_from_torch()
|
||||||
|
accelerator = get_accelerator()
|
||||||
|
coordinator = DistCoordinator()
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Initialize Booster
|
||||||
|
# ==============================
|
||||||
|
if args.plugin == "ddp":
|
||||||
|
plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
|
||||||
|
elif args.plugin == "gemini":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
|
enable_fused_normalization=get_accelerator().is_available(),
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
|
)
|
||||||
|
elif args.plugin == "gemini_auto":
|
||||||
|
plugin = GeminiPlugin(
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
placement_policy="auto",
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
||||||
|
enable_fused_normalization=get_accelerator().is_available(),
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
|
)
|
||||||
|
elif args.plugin == "zero2":
|
||||||
|
plugin = LowLevelZeroPlugin(
|
||||||
|
stage=2,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "zero2_cpu":
|
||||||
|
plugin = LowLevelZeroPlugin(
|
||||||
|
stage=2,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
initial_scale=2**16,
|
||||||
|
cpu_offload=True,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
)
|
||||||
|
elif args.plugin == "3d":
|
||||||
|
plugin = HybridParallelPlugin(
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
|
zero_stage=args.zero_stage,
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
|
enable_fused_normalization=get_accelerator().is_available(),
|
||||||
|
enable_sequence_parallelism=args.enable_sequence_parallelism,
|
||||||
|
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
microbatch_size=args.microbatch_size,
|
||||||
|
)
|
||||||
|
elif args.plugin == "moe":
|
||||||
|
plugin = MoeHybridParallelPlugin(
|
||||||
|
ep_size=args.ep,
|
||||||
|
tp_size=args.tp,
|
||||||
|
pp_size=args.pp,
|
||||||
|
zero_stage=args.zero_stage,
|
||||||
|
sp_size=args.sp,
|
||||||
|
sequence_parallelism_mode=args.sp_mode,
|
||||||
|
enable_sequence_parallelism=args.sp > 1,
|
||||||
|
enable_fused_normalization=get_accelerator().is_available(),
|
||||||
|
enable_flash_attention=args.use_flash_attn,
|
||||||
|
max_norm=args.grad_clip,
|
||||||
|
precision=args.mixed_precision,
|
||||||
|
microbatch_size=args.microbatch_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown plugin {args.plugin}")
|

    booster = Booster(plugin=plugin)

    def is_master():
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            return coordinator.rank == coordinator.world_size - 1
        return coordinator.is_master()

    # ==============================
    # Initialize Tensorboard and Save Config
    # ==============================
    if is_master():
        if args.tensorboard_dir is not None:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(args.tensorboard_dir, exist_ok=True)
            writer = SummaryWriter(args.tensorboard_dir)

        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)

    coordinator.print_on_master(
        f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
    )

    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    dataset = RawConversationDataset(
        tokenizer,
        args.dataset,
        args.max_length,
    )

    dataloader = plugin.prepare_dataloader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )

    coordinator.print_on_master(
        f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
    init_ctx = (
        LazyInitContext(default_device=get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )
    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"

    config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)

    with init_ctx:
        # from_pretrained is not compatible with LoRA, we load pretrained weights later.
        # model = AutoModelForCausalLM.from_pretrained(
        #     args.pretrained,
        #     torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        #     trust_remote_code=True,
        #     attn_implementation=attn_impl,
        # )
        model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=True,
            attn_implementation=attn_impl,
            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        )

    if args.lora_rank > 0:
        if model.__class__.__name__.startswith("DeepseekV3"):
            lora_config = LoraConfig(
                task_type="CAUSAL_LM",
                r=args.lora_rank,
                lora_alpha=args.lora_alpha,
                target_modules=["gate_proj", "up_proj", "down_proj"],
            )
        else:
            lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
        model = booster.enable_lora(model, lora_config=lora_config)

    # this is essential, otherwise the grad checkpoint will not work.
    model.train()

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    if model.config.__class__.__name__.startswith("DeepseekV3"):
        model.config.use_cache = False
        model.eval()
        # enable grad for moe layers
        for m in model.modules():
            if m.__class__.__name__ == "DeepseekV3MoE":
                m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)

    model_numel = sum(p.numel() for p in model.parameters())
    coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")

    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    if args.warmup_steps is None:
        args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # Flash attention will be disabled because it does NOT support fp32.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    torch.set_default_dtype(torch.float)
    booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)

    coordinator.print_on_master(
        f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0

    num_steps_per_epoch = len(dataloader) // args.accumulation_steps

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            data_iter = iter(dataloader)
            step_bar = tqdm(
                range(len(dataloader)),
                desc="Step",
                disable=not is_master(),
            )
            for step in step_bar:
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if booster.plugin.stage_manager.is_last_stage():
                    global_loss = all_reduce_mean(loss, plugin)

                optimizer.step()

                if booster.plugin.stage_manager.is_last_stage():
                    grad_norm = optimizer.get_grad_norm()
                    step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})

                if args.tensorboard_dir is not None and is_master():
                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                    writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                    writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                lr_scheduler.step()
                optimizer.zero_grad()

        else:
            pbar = tqdm(
                dataloader,
                desc=f"Epoch {epoch}",
                disable=not is_master(),
                initial=start_step // args.accumulation_steps,
            )
            total_loss = torch.tensor(0.0, device=get_current_device())
            for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

                batch_output = model(**batch)

                loss = batch_output.loss / args.accumulation_steps
                total_loss.add_(loss.data)

                booster.backward(loss=loss, optimizer=optimizer)

                if (step + 1) % args.accumulation_steps == 0:
                    all_reduce_mean(total_loss, plugin)

                    optimizer.step()

                    grad_norm = optimizer.get_grad_norm()
                    pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm})
                    if args.tensorboard_dir is not None and is_master():
                        global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                        writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
                        writer.add_scalar(
                            tag="Learning Rate",
                            scalar_value=lr_scheduler.get_last_lr()[0],
                            global_step=global_step,
                        )
                        writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                    lr_scheduler.step()
                    optimizer.zero_grad()

                    total_loss.fill_(0.0)

        # Delete cache.
        # del batch, batch_labels, batch_output, loss
        accelerator.empty_cache()

    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    if args.lora_rank > 0:
        booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
    else:
        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Basic training information.
    parser.add_argument(
        "-m",
        "--pretrained",
        type=str,
        required=True,
        help="Address of the pre-trained model",
    )
    parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw JSONL dataset for training.")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="zero2",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
    # Training parameters
    parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size of each process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "-g",
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "-f",
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )

    # Additional arguments for 3d plugin.
    parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
    parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
    parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
    parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
    parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
    parser.add_argument(
        "--sp_mode",
        type=str,
        default="split_gather",
        choices=["split_gather", "ring", "all_to_all"],
        help="SP mode, used for 3d plugin.",
    )
    parser.add_argument(
        "--enable_sequence_parallelism",
        default=False,
        action="store_true",
        help="Whether to enable SP, used for 3d plugin.",
    )
    parser.add_argument(
        "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
    )
    parser.add_argument(
        "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
    )
    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
    parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")

    args = parser.parse_args()

    if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
        raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.")

    train(args)
@@ -0,0 +1,10 @@
[{"role": "user", "content": "只剩一个心脏了还能活吗?"}, {"role": "assistant", "content": "能,人本来就只有一个心脏。"}]
[{"role": "user", "content": "爸爸再婚,我是不是就有了个新娘?"}, {"role": "assistant", "content": "不是的,你有了一个继母。\"新娘\"是指新婚的女方,而你爸爸再婚,他的新婚妻子对你来说是继母。"}]
[{"role": "user", "content": "马上要上游泳课了,昨天洗的泳裤还没干,怎么办"}, {"role": "assistant", "content": "游泳时泳裤本来就会湿,不用晾干。"}]
[{"role": "user", "content": "为什么没人说ABCD型的成语?🤔"}, {"role": "assistant", "content": "这是因为中文成语一般都是四字成语,每个字都有其特定的含义,四个字合在一起构成一个完整的意思。而ABCD型的成语最常见,所以大家不会刻意强调。"}]
[{"role": "user", "content": "太阳还有五十亿年就没了,那到时候向日葵看哪呢?"}, {"role": "assistant", "content": "实际上,向日葵并不会意识到太阳的存在与否,它们只是通过光敏反应来追踪太阳。如果太阳消失,向日葵也不会存活下来,所以它们不会看向任何地方。"}]
[{"role": "user", "content": "我把我车钥匙落在车里,现在锁不了车门了。 有人能告诉我怎么办吗?"}, {"role": "assistant", "content": "如果你的车钥匙已经锁在车里,那么你实际上已经无法再从外部打开车门了,所以你的车已经被锁住了。你需要做的是找专业的开锁服务或者联系汽车救援服务来帮你打开车门,取回你的车钥匙。"}]
[{"role": "user", "content": "同学说steam好玩,我现在八级烫伤了怎么办? 救救我"}, {"role": "assistant", "content": "\"Steam\"是一个电子游戏分发平台,不是steam蒸汽。建议立即去医院,顺便看看脑子。"}]
[{"role": "user", "content": "人在岩浆里憋气最长时间是多少? 刚才好奇搜了一下,水里憋气时间最长的是克罗地亚54岁在水中闭气24分33秒,但是我搜在岩浆里憋气时间最长,居然发现一个没有,不都是液体吗?为什么没人在岩浆里憋气?"}, {"role": "assistant", "content": "因为岩浆的温度极高,达到了800-1200摄氏度,人如果接触到岩浆,会立刻被烧伤,甚至瞬间死亡,所以没有人能在岩浆里憋气。这并不是一个合理的比较,因为水和岩浆虽然都是液体,但是它们的性质差别极大。"}]
[{"role": "user", "content": "出生证丢了怎么证明自己出生了"}, {"role": "assistant", "content": "这是个很有意思的问题,从现实意义来讲,你既然已经活着,就已经证明你出生了。如果需要法律意义上的证明,可以通过其他官方文件,例如护照、身份证、驾驶证等。此外,医院的出生记录、学校的入学记录、家庭照片等也可以作为间接证据。如果所有这些都没有,可以向出生地的公共卫生部门或者医院申请复制出生证明。"}]
[{"role": "user", "content": "为什么抄袭永远都是今人抄袭古人,而没有古人抄袭今人的案例?"}, {"role": "assistant", "content": "这是因为时间的流逝是单向的,古人无法预知今人的思想和作品,所以古人无法抄袭今人。而今人可以通过学习古人的作品,因此有可能出现抄袭古人的情况。"}]
@@ -13,7 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -29,8 +29,6 @@ def train(args):
     # check lora compatibility
     if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
-    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
-        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
 
     # ==============================
     # Initialize Distributed Training
@@ -46,7 +44,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose acceleration, for debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
@@ -56,14 +54,6 @@ def train(args):
             enable_gradient_accumulation=True,
             enable_flash_attention=args.use_flash_attn,
         )
-    elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(
-            precision=args.mixed_precision,
-            placement_policy="auto",
-            initial_scale=2**16,
-            max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn,
-        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
             stage=2,
@@ -92,20 +82,24 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
 
     booster = Booster(plugin=plugin)
-    ref_booster = Booster(plugin=plugin)
 
-    # ======================================================
-    # Initialize Model, Objective, Optimizer and LR Scheduler
-    # ======================================================
-    # Temp Fix: Disable lazy init due to version conflict
-    # init_ctx = (
-    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
-    # )
+    ref_plugin = HybridParallelPlugin(
+        tp_size=args.ref_tp,
+        pp_size=1,
+        zero_stage=args.zero_stage,
+        enable_flash_attention=args.use_flash_attn,
+        cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
+        parallel_output=False,
+        max_norm=args.grad_clip,
+        precision=args.mixed_precision,
+    )
+    ref_booster = Booster(plugin=ref_plugin)
+
     init_ctx = nullcontext()
     with init_ctx:
@@ -130,6 +124,7 @@ def train(args):
         ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
     else:
         ref_model = None
+
     if args.lora_config is not None:
         model = convert_to_lora_module(model, lora_config=lora_config)
         for name, module in model.named_modules():
@@ -139,7 +134,9 @@ def train(args):
         disable_dropout(ref_model)
 
     if args.grad_checkpoint:
-        # Note, for some models, lora may not be compatible with gradient checkpointing
+        # Make sure gradient checkpointing can be activated.
+        model.train()
+        # Note, for some models, lora may not be compatible with gradient checkpointing.
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
 
@@ -169,7 +166,7 @@ def train(args):
         adamw_mode=True,
     )
 
-    # configure dataset
+    # Configure dataset
     coordinator.print_on_master(f"Load dataset: {args.dataset}")
     mode_map = {"train": "train", "valid": "validation", "test": "test"}
     train_dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train", mode_map=mode_map)
@@ -213,14 +210,15 @@ def train(args):
 
     default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
     torch.set_default_dtype(default_dtype)
 
     model, optim, _, train_dataloader, lr_scheduler = booster.boost(
         model=model,
         optimizer=optim,
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
-    if ref_model is not None:
-        ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_dataloader)
+    ref_model, _, _, _, _ = ref_booster.boost(model=ref_model)
     torch.set_default_dtype(torch.float)
 
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -312,7 +310,7 @@ if __name__ == "__main__":
         "--plugin",
         type=str,
         default="gemini",
-        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
+        choices=["gemini", "zero2", "zero2_cpu", "3d", "ddp"],
         help="Choose which plugin to use",
     )
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
@@ -342,22 +340,35 @@ if __name__ == "__main__":
     parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
     parser.add_argument("--max_epochs", type=int, default=3)
     parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
+    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
+    parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
+    parser.add_argument("--lr", type=float, default=5e-6)
+    parser.add_argument("--accumulation_steps", type=int, default=1)
+    parser.add_argument("--log_dir", default=None, type=str)
+    parser.add_argument("--use_wandb", default=False, action="store_true")
+    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
+    parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument(
+        "--microbatch_size",
+        type=int,
+        default=2,
+        help="Micro batch size for PP training. To activate PP training for DPO-like algorithm, you must keep size even and the size should be equal or greater than 2.",
+    )
+    # Parameter for reference model
     parser.add_argument(
         "--disable_reference_model",
         action="store_true",
         default=False,
         help="Disable the reference model (enabled by default)",
     )
-    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
-    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
-    parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--lr", type=float, default=5e-6)
-    parser.add_argument("--accumulation_steps", type=int, default=8)
-    parser.add_argument("--log_dir", default=None, type=str)
-    parser.add_argument("--use_wandb", default=False, action="store_true")
-    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
-    parser.add_argument("--use_flash_attn", default=False, action="store_true")
+    parser.add_argument(
+        "--ref_tp",
+        type=int,
+        default=1,
+        help="TP size for reference model; used only when reference model is too large.",
+    )
 
     args = parser.parse_args()
 
     # fool proof hyperparameter setup
494  applications/ColossalChat/examples/training_scripts/train_grpo.py  Executable file
@@ -0,0 +1,494 @@
import argparse
import json
import os
import resource
from contextlib import nullcontext

import torch
import torch.distributed as dist
from coati.dataset import (
    DataCollatorForPromptDataset,
    DataCollatorForSupervisedDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
    setup_conversation_template,
)
from coati.models import LoraConfig, RewardModel, RLVRRewardModel, convert_to_lora_module, disable_dropout, lora_manager
from coati.trainer import GRPOTrainer
from coati.utils import load_checkpoint
from coati.utils.reward_score import *
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin  # TorchDDPPlugin is referenced by the "ddp" branch below
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy

logger = get_dist_logger()
# default settings for response format tags, overwrite it in chat_template definition if needed
response_format_tags = {
    "think_start": {"text": "<think>", "num_occur": 1},
    "think_end": {"text": "</think>", "num_occur": 1},
    "answer_start": {"text": "<answer>", "num_occur": 1},
    "answer_end": {"text": "</answer>", "num_occur": 1},
}


def train(args):
    global response_format_tags
    lora_config = None
    if args.lora_config is not None:
        lora_config = LoraConfig.from_file(args.lora_config)
    # check lora compatibility
    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
    if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
        raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Temp Fix: Disable lazy init due to version conflict
    # init_ctx = (
    #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    # )

    init_ctx = nullcontext()
    with init_ctx:
        if args.use_flash_attn:
            actor = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
                local_files_only=True,
                trust_remote_code=True,
            )
            ref_model = AutoModelForCausalLM.from_pretrained(
                args.pretrain,
                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                use_flash_attention_2=True,
                local_files_only=True,
                trust_remote_code=True,
            )
            if args.rm_pretrain:
                reward_model = RewardModel(
                    args.rm_pretrain,
                    torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                    use_flash_attention_2=True,
                    trust_remote_code=True,
                )
            coordinator.print_on_master(msg="Flash-attention enabled successfully")
        else:
            actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
            if args.rm_pretrain:
                reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
            ref_model = AutoModelForCausalLM.from_pretrained(
                args.pretrain, local_files_only=True, trust_remote_code=True
            )

    if args.lora_config is not None:
        actor = convert_to_lora_module(actor, lora_config=lora_config)
        for name, module in actor.named_modules():
            if "norm" in name or "gate" in name:
                module = module.to(torch.float32)
        lora_manager.able_to_merge = False

    # Disable dropout
    disable_dropout(actor)

    if args.grad_checkpoint:
        actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

    # configure tokenizer
    tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
    if os.path.exists(args.conversation_template_config):
        with open(args.conversation_template_config, "r", encoding="utf8") as f:
            conversation_template_config = json.load(f)
        dist.barrier()
        if "response_format_tags" in conversation_template_config:
            logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
            response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
        conversation_template = setup_conversation_template(
            tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
        )
        stop_ids = conversation_template.stop_ids if len(conversation_template.stop_ids) > 0 else None
    else:
        raise ValueError("Conversation template config is not provided or incorrect")
    if hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token") and tokenizer.eos_token is not None:
        try:
            # Some tokenizers don't allow the pad_token to be set manually, e.g., Qwen
            tokenizer.pad_token = tokenizer.eos_token
        except AttributeError as e:
            logger.warning(f"Unable to set pad token to eos token, {str(e)}")
    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
        logger.warning(
            "The tokenizer does not have a pad token, which is required. This may lead to unintended behavior in training; please consider setting it manually."
        )

    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False
    tokenizer.padding_side = "left"  # left padding for generation (online learning)

    # configure generation config
    actor.generation_config.update(
        pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id
    )

    # configure optimizer
    coordinator.print_on_master(f"setting up optimizer for actor: lr={args.lr}, weight_decay={args.weight_decay}")
    actor_optim = HybridAdam(
        model_params=actor.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    if args.warmup_steps is None:
        args.warmup_steps = int(0.025 * args.num_episodes)
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    actor_lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=actor_optim,
        total_steps=args.num_episodes,
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        """
        Default torch ddp plugin without any acceleration, for debugging purpose
        """
        plugin = TorchDDPPlugin(find_unused_parameters=True)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="static",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=True,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
            logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
            args.use_flash_attn = False
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
        if args.rm_pretrain:
            custom_plugin = HybridParallelPlugin(
                tp_size=args.tp,
                pp_size=args.pp,
                sp_size=args.sp,
                sequence_parallelism_mode=args.sp_mode,
                zero_stage=args.zero_stage,
                enable_flash_attention=args.use_flash_attn,
                enable_sequence_parallelism=args.enable_sequence_parallelism,
                cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
                parallel_output=False,
                max_norm=args.grad_clip,
                precision=args.mixed_precision,
                custom_policy=get_autopolicy(reward_model.model),
            )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    if args.plugin != "3d" and args.rm_pretrain:
        custom_plugin = plugin

    # configure dataset
    coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
    mode_map = {"train": "train", "valid": "validation", "test": "test"}
    train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)

    data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)

    train_prompt_dataloader = plugin.prepare_dataloader(
        dataset=train_prompt_dataset,
        batch_size=args.experience_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )

    if len(args.ptx_dataset) > 0:
        train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
        data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
        train_pretrain_dataloader = plugin.prepare_dataloader(
            dataset=train_ptx_dataset,
            batch_size=args.ptx_batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=data_collator,
            distributed_sampler_cls=StatefulDistributedSampler,
        )
    else:
        train_pretrain_dataloader = None

    actor_booster = Booster(plugin=plugin)
    ref_booster = Booster(plugin=plugin)
    if args.rm_pretrain:
        rm_booster = Booster(plugin=custom_plugin)

    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    actor, actor_optim, _, train_prompt_dataloader, actor_lr_scheduler = actor_booster.boost(
        model=actor,
        optimizer=actor_optim,
        lr_scheduler=actor_lr_scheduler,
        dataloader=train_prompt_dataloader,
    )
    if args.rm_pretrain:
        reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
    else:
        if args.reward_functions:
            reward_fn_list = []
            for reward_fn in args.reward_functions:
                """
                To define a custom reward function, you can define your functions under:
                colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
                and use it here by modifying the following line:
                """
                if reward_fn == "gsm8k_reward_fn":
                    reward_fn_list.append(gsm8k_reward_fn)
                elif reward_fn == "math_competition_reward_fn":
                    reward_fn_list.append(math_competition_reward_fn)
                else:
                    raise ValueError(f"Unknown reward function {reward_fn}")
            reward_model = RLVRRewardModel(
                reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
            )

    ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)

    torch.set_default_dtype(torch.float)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    sampler_start_idx = 0
    start_step = 0

    if args.rm_checkpoint_path is not None:
        if "modeling" in args.rm_checkpoint_path:
            rm_booster.load_model(reward_model, args.rm_checkpoint_path)
        else:
            _, _, _ = load_checkpoint(
                load_dir=args.rm_checkpoint_path,
                booster=rm_booster,
                model=reward_model,
                optimizer=None,
                lr_scheduler=None,
            )
        coordinator.print_on_master(f"Loaded reward model checkpoint {args.rm_checkpoint_path}")
    if args.checkpoint_path is not None:
        if "modeling" in args.checkpoint_path:
            actor_booster.load_model(actor, args.checkpoint_path)
            ref_booster.load_model(ref_model, args.checkpoint_path)
            coordinator.print_on_master(f"Loaded actor and reference model {args.checkpoint_path}")
        else:
            _, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.checkpoint_path,
                booster=actor_booster,
                model=actor,
                optimizer=actor_optim,
                lr_scheduler=actor_lr_scheduler,
            )
            _, _, _ = load_checkpoint(load_dir=args.checkpoint_path, booster=ref_booster, model=ref_model)
            assert isinstance(train_prompt_dataloader.sampler, StatefulDistributedSampler)
            train_prompt_dataloader.sampler.set_start_index(start_index=sampler_start_idx)

        coordinator.print_on_master(
            f"Loaded actor and reference model checkpoint {args.checkpoint_path} at episode {start_step}"
        )
        coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

    coordinator.print_on_master(
        f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    # configure trainer
    trainer = GRPOTrainer(
        actor_booster,
        actor,
        reward_model,
        ref_model,
        actor_optim,
        actor_lr_scheduler,
        tokenizer=tokenizer,
        stop_token_ids=[stop_ids],
        kl_coef=args.kl_coef,
        ptx_coef=args.ptx_coef,
        train_batch_size=args.train_batch_size,
        buffer_limit=args.num_collect_steps * args.experience_batch_size * args.num_generations,
        max_length=args.max_length,
        use_cache=True,
        do_sample=True,
        apply_loss_mask=not args.disable_loss_mask,
        accumulation_steps=args.accumulation_steps,
        save_dir=args.save_path,
        save_interval=args.save_interval,
        top_k=50,
        use_tp=args.tp > 1,
        num_generations=args.num_generations,
        inference_batch_size=args.inference_batch_size,
        logits_forward_batch_size=args.logits_forward_batch_size,
        offload_inference_models="gemini" not in args.plugin,
        coordinator=coordinator,
        max_tokens_thinking=args.max_tokens_thinking if args.max_tokens_thinking else args.max_length - 100,
        temperature_annealing_config={
            "start_temperature": args.initial_temperature,
            "end_temperature": args.final_temperature,
            "annealing_warmup_steps": min(100, int(args.num_episodes / 6)),
            "annealing_steps": min(600, int(args.num_episodes / 2)),
        },
        # Hack: some older models' default update_model_kwargs_fn/prepare_inputs_fn may not work due to a version conflict with transformers; you can overwrite them here:
        # update_model_kwargs_fn=update_model_kwargs_fn,
        # prepare_inputs_fn = None
    )

    trainer.fit(
        num_episodes=args.num_episodes,
        num_collect_steps=args.num_collect_steps,
        num_update_steps=args.num_update_steps,
        prompt_dataloader=train_prompt_dataloader,
        pretrain_dataloader=train_pretrain_dataloader,
        log_dir=args.log_dir,
        use_wandb=args.use_wandb,
    )

    if lora_config is not None and lora_config.r > 0:
        # NOTE: set model to eval to merge LoRA weights
        lora_manager.able_to_merge = True
        actor.eval()
    # save model checkpoint after fitting on only rank0
    coordinator.print_on_master("Start saving final actor model checkpoint")
    actor_booster.save_model(actor, os.path.join(trainer.actor_save_dir, "modeling"), shard=True)
    coordinator.print_on_master(
        f"Saved final actor model checkpoint at episodes {args.num_episodes} at folder {args.save_path}"
    )
    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_dataset", nargs="+", default=[])
    parser.add_argument("--ptx_dataset", nargs="+", default=[])
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument(
        "--conversation_template_config",
        type=str,
        default=None,
        help="Path \
        to save conversation template config files.",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument("--tokenizer_dir", type=str, default=None)
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--sp", type=int, default=1)
    parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--rm_pretrain", type=str, default=None)
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
    parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
    parser.add_argument("--num_episodes", type=int, default=1)
    parser.add_argument("--num_collect_steps", type=int, default=2)
    parser.add_argument("--num_update_steps", type=int, default=5)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--inference_batch_size", type=int, default=None)
    parser.add_argument("--save_interval", type=int, default=1000)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--logits_forward_batch_size", type=int, default=1)
    parser.add_argument("--experience_batch_size", type=int, default=16)
    parser.add_argument("--ptx_batch_size", type=int, default=4)
    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
    parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--kl_coef", type=float, default=0.7)
    parser.add_argument("--ptx_coef", type=float, default=0.0)
    parser.add_argument("--disable_loss_mask", default=False, action="store_true")
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--max_tokens_thinking", type=int, default=2000)
    parser.add_argument("--max_seq_len", type=int, default=256)
    parser.add_argument("--initial_temperature", type=float, default=1.0)
    parser.add_argument("--final_temperature", type=float, default=0.9)
    parser.add_argument("--log_dir", default=None, type=str)
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")

    args = parser.parse_args()
    train(args)
86  applications/ColossalChat/examples/training_scripts/train_grpo.sh  Executable file
@@ -0,0 +1,86 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |
        nl -v 0 |
        tee /dev/tty |
        sort -g -k 2 |
        awk '{print $1}' |
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 8

PROJECT_NAME="PPO-RLVR"

PARENT_SAVE_DIR="" # Path to a folder to save checkpoints
PARENT_CONFIG_FILE="" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # local pretrained model path (from RLHF step 1: SFT)
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
CONVERSATION_TEMPLATE_CONFIG_PATH="" # path to the conversation config file
LOGDIR=""

declare -a prompt_dataset=(
    YOUR/PROMPT/DATA/DIR/arrow/part-00000
    YOUR/PROMPT/DATA/DIR/arrow/part-00001
    YOUR/PROMPT/DATA/DIR/arrow/part-00002
    YOUR/PROMPT/DATA/DIR/arrow/part-00003
    YOUR/PROMPT/DATA/DIR/arrow/part-00004
    YOUR/PROMPT/DATA/DIR/arrow/part-00005
    YOUR/PROMPT/DATA/DIR/arrow/part-00006
    YOUR/PROMPT/DATA/DIR/arrow/part-00007
    YOUR/PROMPT/DATA/DIR/arrow/part-00008
    YOUR/PROMPT/DATA/DIR/arrow/part-00009
)

declare -a ptx_dataset=(
    YOUR/SFT/DATA/DIR/arrow/part-00000
    YOUR/SFT/DATA/DIR/arrow/part-00001
    YOUR/SFT/DATA/DIR/arrow/part-00002
    YOUR/SFT/DATA/DIR/arrow/part-00003
    YOUR/SFT/DATA/DIR/arrow/part-00004
    YOUR/SFT/DATA/DIR/arrow/part-00005
    YOUR/SFT/DATA/DIR/arrow/part-00006
    YOUR/SFT/DATA/DIR/arrow/part-00007
    YOUR/SFT/DATA/DIR/arrow/part-00008
    YOUR/SFT/DATA/DIR/arrow/part-00009
)

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}${FULL_PROJECT_NAME}.json"

colossalai run --nproc_per_node 8 --num_nodes 1 --hostfile ./hostfile train_grpo.py \
    --pretrain $PRETRAINED_MODEL_PATH \
    --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
    --prompt_dataset ${prompt_dataset[@]} \
    --conversation_template_config $CONVERSATION_TEMPLATE_CONFIG_PATH \
    --ptx_coef 0.0 \
    --plugin "zero2_cpu" \
    --reward_functions math_competition_reward_fn \
    --save_interval 250 \
    --save_path $SAVE_DIR \
    --num_episodes 100 \
    --num_collect_steps 8 \
    --num_update_steps 1 \
    --experience_batch_size 1 \
    --train_batch_size 4 \
    --inference_batch_size 8 \
    --logits_forward_batch_size 2 \
    --accumulation_steps 4 \
    --lr 1e-6 \
    --mixed_precision "bf16" \
    --grad_clip 0.1 \
    --weight_decay 0.01 \
    --kl_coef 0.01 \
    --warmup_steps 40 \
    --max_length 2000 \
    --max_seq_len 1700 \
    --log_dir $LOGDIR \
    --use_flash_attn \
    --grad_checkpoint
@@ -13,9 +13,18 @@ from coati.dataset import (
     load_tokenized_dataset,
     setup_conversation_template,
 )
-from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
+from coati.models import (
+    Critic,
+    LoraConfig,
+    RewardModel,
+    RLVRRewardModel,
+    convert_to_lora_module,
+    disable_dropout,
+    lora_manager,
+)
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
+from coati.utils.reward_score import *
 from transformers import AutoModelForCausalLM, AutoTokenizer

 import colossalai
@@ -29,8 +38,17 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy

 logger = get_dist_logger()

+# default settings for response format tags, overwrite it in chat_template definition if needed
+response_format_tags = {
+    "think_start": {"text": "<think>", "num_occur": 1},
+    "think_end": {"text": "</think>", "num_occur": 1},
+    "answer_start": {"text": "<answer>", "num_occur": 1},
+    "answer_end": {"text": "</answer>", "num_occur": 1},
+}
+

 def train(args):
+    global response_format_tags
     lora_config = None
     if args.lora_config is not None:
         lora_config = LoraConfig.from_file(args.lora_config)
@@ -61,28 +79,36 @@ def train(args):
             torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
             use_flash_attention_2=True,
             local_files_only=True,
+            trust_remote_code=True,
         )
         ref_model = AutoModelForCausalLM.from_pretrained(
             args.pretrain,
             torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
             use_flash_attention_2=True,
             local_files_only=True,
+            trust_remote_code=True,
         )
-        reward_model = RewardModel(
-            args.rm_pretrain,
-            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-            use_flash_attention_2=True,
-        )
+        if not args.no_neural_reward_model:
+            reward_model = RewardModel(
+                args.rm_pretrain,
+                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+                use_flash_attention_2=True,
+                trust_remote_code=True,
+            )
         critic = Critic(
             args.rm_pretrain,
             torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
             use_flash_attention_2=True,
+            trust_remote_code=True,
         )
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
-        actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
-        ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
-        reward_model = RewardModel(args.rm_pretrain)
+        actor = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True, trust_remote_code=True)
+        ref_model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain, local_files_only=True, trust_remote_code=True
+        )
+        if not args.no_neural_reward_model:
+            reward_model = RewardModel(args.rm_pretrain, trust_remote_code=True)
         critic = Critic(args.rm_pretrain)

     if args.lora_config is not None:
@@ -112,6 +138,9 @@ def train(args):
     with open(args.conversation_template_config, "r", encoding="utf8") as f:
         conversation_template_config = json.load(f)
     dist.barrier()
+    if "response_format_tags" in conversation_template_config:
+        logger.warning(f"Overwrite default response format tags with {args.conversation_template_config}")
+        response_format_tags = conversation_template_config.get("response_format_tags", response_format_tags)
     conversation_template = setup_conversation_template(
         tokenizer, chat_template_config=conversation_template_config, save_path=args.conversation_template_config
     )
@@ -245,7 +274,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=get_autopolicy(reward_model.model),
+            custom_policy=get_autopolicy(critic.model),
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -284,7 +313,8 @@ def train(args):

     actor_booster = Booster(plugin=plugin)
     ref_booster = Booster(plugin=plugin)
-    rm_booster = Booster(plugin=custom_plugin)
+    if not args.no_neural_reward_model:
+        rm_booster = Booster(plugin=custom_plugin)
     critic_booster = Booster(plugin=custom_plugin)

     default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
@@ -302,7 +332,28 @@ def train(args):
         lr_scheduler=critic_lr_scheduler,
         dataloader=train_prompt_dataloader,
     )
-    reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+    if not args.no_neural_reward_model:
+        reward_model, _, _, _, _ = rm_booster.boost(model=reward_model, dataloader=train_prompt_dataloader)
+    else:
+        if args.reward_functions:
+            reward_fn_list = []
+            for reward_fn in args.reward_functions:
+                """
+                To define custom reward function, you can define your functions under:
+                colossalai/applications/ColossalChat/coati/utils/reward_score/__init__.py
+                and use it here by mofiying the following line:
+                """
+                if reward_fn == "gsm8k_reward_fn":
+                    reward_fn_list.append(gsm8k_reward_fn)
+                elif reward_fn == "math_competition_reward_fn":
+                    reward_fn_list.append(math_competition_reward_fn)
+                else:
+                    raise ValueError(f"Unknown reward function {reward_fn}")
+                reward_fn_list.append(eval(reward_fn))
+            reward_model = RLVRRewardModel(
+                reward_fn_list=reward_fn_list, tokenizer=tokenizer, tags=response_format_tags
+            )

     ref_model, _, _, _, _ = ref_booster.boost(model=ref_model, dataloader=train_prompt_dataloader)

     torch.set_default_dtype(torch.float)
@@ -481,9 +532,11 @@ if __name__ == "__main__":
     parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--rm_pretrain", type=str, default=None)
+    parser.add_argument("--no_neural_reward_model", default=False, action="store_true")
     parser.add_argument("--checkpoint_path", type=str, default=None)
     parser.add_argument("--critic_checkpoint_path", type=str, default=None)
     parser.add_argument("--rm_checkpoint_path", type=str, help="Reward model checkpoint path")
+    parser.add_argument("--reward_functions", type=str, nargs="+", default=None, help="Reward functions to use")
     parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
     parser.add_argument("--num_episodes", type=int, default=1)
     parser.add_argument("--num_collect_steps", type=int, default=2)
@@ -68,7 +68,7 @@ def train(args):
         Default torch ddp plugin without any acceleration, for
         debugging purpose acceleration, for debugging purpose
         """
-        plugin = TorchDDPPlugin(find_unused_parameters=True if args.grad_checkpoint is False else False)
+        plugin = TorchDDPPlugin(find_unused_parameters=not args.grad_checkpoint)
     elif args.plugin == "gemini":
         plugin = GeminiPlugin(
             precision=args.mixed_precision,
@@ -2,7 +2,7 @@ transformers==4.39.3
 tqdm
 datasets==2.14.7
 loralib
-colossalai>=0.4.0
+colossalai>=0.4.7
 torch>=2.1.0
 langchain
 tokenizers
@@ -21,3 +21,4 @@ ninja==1.11.1
 sentencepiece==0.1.99
 flash-attn
 tiktoken
+jsonlines
@@ -20,6 +20,15 @@ prompt_seed = {
         },
     ]
 }
+prompt_rlvr_seed = {
+    "messages": [
+        {
+            "from": "user",
+            "content": "What is the degree of the polynomial $(4 +5x^3 +100 +2\pi x^4 + \sqrt{10}x^4 +9)$?",
+        },
+    ],
+    "gt_answer": "4",
+}
 preference_seed = {
     "context": [
         {"from": "user", "content": "What kind of noises did dinosaurs make?"},
@@ -72,6 +81,8 @@ if __name__ == "__main__":
         seed = sft_seed
     elif args.data_type == "prompt":
         seed = prompt_seed
+    elif args.data_type == "prompt_rlvr":
+        seed = prompt_rlvr_seed
     elif args.data_type == "preference":
         seed = preference_seed
     elif args.data_type == "kto":
applications/ColossalChat/tests/prepare_test_env.sh (new executable file, 16 lines)
@@ -0,0 +1,16 @@
+# run under /ColossalAI/applications/ColossalChat
+export NCCL_SHM_DISABLE=1
+export MAX_JOBS=1
+export PRETRAINED_MODEL_PATH=./models
+export SFT_DATASET=./sft_data
+export PROMPT_DATASET=./prompt_data
+export PROMPT_RLVR_DATASET=./prompt_data
+export PREFERENCE_DATASET=./preference_data
+export KTO_DATASET=./kto_data
+mkdir models
+mkdir sft_data
+mkdir prompt_data
+mkdir preference_data
+mkdir kto_data
+# ./tests/test_data_preparation.sh
+# ./tests/test_train.sh
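Usage note (an assumption based only on the file contents above): since this helper merely exports environment variables and creates the data directories, it presumably has to be sourced from the ColossalChat application root, e.g. `source ./tests/prepare_test_env.sh`, rather than run in a subshell, so that the exported dataset paths remain visible to `tests/test_data_preparation.sh` and `tests/test_train.sh` afterwards.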
@@ -24,7 +24,12 @@ if [ -z "$SFT_DATASET" ]; then
 fi

 if [ -z "$PROMPT_DATASET" ]; then
-    echo "Please set \$PROMPT_DATASET to the path to prompts."
+    echo "Please set \$PROMPT_DATASET to the path to prompts dataset."
+    exit 1
+fi
+
+if [ -z "$PROMPT_RLVR_DATASET" ]; then
+    echo "Please set \$PROMPT_RLVR_DATASET to the path to prompts dataset with gt_answer labels."
     exit 1
 fi

@@ -69,6 +74,8 @@ get_data_input_dirs() {
         echo "$SFT_DATASET"
     elif [[ $data_type == "prompt" ]]; then
         echo "$PROMPT_DATASET"
+    elif [[ $data_type == "prompt_rlvr" ]]; then
+        echo "$PROMPT_RLVR_DATASET"
     elif [[ $data_type == "preference" ]]; then
         echo "$PREFERENCE_DATASET"
     elif [[ $data_type == "kto" ]]; then
@@ -123,6 +130,10 @@ python $TEST_DIR/generate_dummy_datasets_for_testing.py \
     --data_dir $(get_data_input_dirs prompt) \
     --data_type "prompt"

+python $TEST_DIR/generate_dummy_datasets_for_testing.py \
+    --data_dir $(get_data_input_dirs prompt_rlvr) \
+    --data_type "prompt_rlvr"
+
 python $TEST_DIR/generate_dummy_datasets_for_testing.py \
     --data_dir $(get_data_input_dirs kto) \
     --data_type "kto"
@@ -266,6 +277,52 @@ for model in ${MODELS[@]}; do
 done


+echo "[Test]: testing prepare_prompt_dataset.py (with verifiable reward)..."
+
+# FIXME: This is a hack to skip tests that are not working
+SKIPPED_TESTS=(
+)
+
+# test prepare_prompt_dataset
+for model in ${MODELS[@]}; do
+    data_type="prompt_rlvr"
+    if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$data_type " ]]; then
+        echo "[Test]: Skipped $model-$data_type"
+        continue
+    fi
+    cache_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/cache
+    jsonl_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/jsonl
+    arrow_dir=$DATA_SAVE_PATH/tokenized_${model}_${data_type}/arrow
+    data_input_dirs=$(get_data_input_dirs $data_type)
+    tokenizer_dir=$(get_tokenizer_dirs $model)
+    conversation_template=$(get_conversation_template_config $model)
+    for i in $(seq $NUM_RETRY); do
+        rm -rf $cache_dir
+        rm -rf $jsonl_dir
+        rm -rf $arrow_dir
+        echo "[Test]: $model-$data_type, attempt $i"
+        python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py \
+            --type prompt \
+            --data_input_dirs $data_input_dirs \
+            --conversation_template_config $conversation_template \
+            --tokenizer_dir $tokenizer_dir \
+            --data_cache_dir $cache_dir \
+            --data_jsonl_output_dir $jsonl_dir \
+            --data_arrow_output_dir $arrow_dir \
+            --max_length 400 \
+            --num_samples_per_datafile 100 \
+            --num_spliced_dataset_bins 1
+        passed=$?
+        if [ $passed -eq 0 ]; then
+            break
+        fi
+    done
+    if [ $passed -ne 0 ]; then
+        echo "[Test]: Failed $model-$data_type"
+        exit 1
+    fi
+done
+
 echo "[Test]: testing prepare_kto_dataset.py ..."

 # FIXME: This is a hack to skip tests that are not working
@@ -4,7 +4,7 @@ BASE_TEMP_DIR=$BASE_DIR/temp
 EXAMPLES_DIR=$BASE_DIR/examples
 TEST_DATA_DIR=$BASE_DIR/tests/test_data
 DATA_SAVE_PATH=$BASE_TEMP_DIR/tests
-CONFIG_DIR=$BASE_DIR/config
+CONFIG_DIR=$BASE_DIR/conversation_template

 # MODELS=("colossal-llama2" "llama2" "mistral" "chatGLM2" "chatGLM3" "deepseek" "Yi" "baichuan") # for local test
 MODELS=("colossal-llama2" "llama2" "chatGLM2" "chatGLM3" "deepseek" "Yi")
@@ -39,23 +39,23 @@ get_pretrain() {
 get_conversation_template_config() {
     local model=$1
     if [[ $model == "colossal-llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/colossal-llama2.json"
+        echo "$CONFIG_DIR/colossal-llama2.json"
     elif [[ $model == "llama2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/llama2.json"
+        echo "$CONFIG_DIR/llama2.json"
     elif [[ $model == "deepseek" ]]; then
-        echo "$CONFIG_DIR/conversation_template/deepseek-ai_DeepSeek-V2-Lite.json"
+        echo "$CONFIG_DIR/deepseek-ai_DeepSeek-V2-Lite.json"
     elif [[ $model == "mistral" ]]; then
-        echo "$CONFIG_DIR/conversation_template/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
+        echo "$CONFIG_DIR/mistralai_Mixtral-8x7B-Instruct-v0.1.json"
     elif [[ $model == "chatGLM2" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm2-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm2-6b.json"
     elif [[ $model == "chatGLM3" ]]; then
-        echo "$CONFIG_DIR/conversation_template/THUDM_chatglm3-6b.json"
+        echo "$CONFIG_DIR/THUDM_chatglm3-6b.json"
     elif [[ $model == "phi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/microsoft_phi-2.json"
+        echo "$CONFIG_DIR/microsoft_phi-2.json"
     elif [[ $model == "Yi" ]]; then
-        echo "$CONFIG_DIR/conversation_template/01-ai_Yi-1.5-9B-Chat.json"
+        echo "$CONFIG_DIR/01-ai_Yi-1.5-9B-Chat.json"
     elif [[ $model == "baichuan" ]]; then
-        echo "$CONFIG_DIR/conversation_template/baichuan-inc_Baichuan2-13B-Chat.json"
+        echo "$CONFIG_DIR/baichuan-inc_Baichuan2-13B-Chat.json"
     else
         echo "Unknown model $model"
         exit 1
@@ -71,6 +71,7 @@ for model in ${MODELS[@]}; do
     rm -rf $SAVE_DIR/arrow
     pretrain=$(get_pretrain $model)
    conversation_template_config=$(get_conversation_template_config $model)
+    echo $conversation_template_config
     python $EXAMPLES_DIR/data_preparation_scripts/prepare_dataset.py --type sft --data_input_dirs $TEST_DATA_DIR/sft \
         --tokenizer_dir $pretrain \
         --conversation_template_config $conversation_template_config \
|
@ -81,8 +81,242 @@ random_choice() {
|
|||||||
echo ${arr[$idx]}
|
echo ${arr[$idx]}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
echo "[Test]: testing grpo ..."
|
||||||
|
|
||||||
|
|
||||||
|
SKIPPED_TESTS=(
|
||||||
|
llama-3d # 3d plugin doesn't support lora
|
||||||
|
llama-gemini # gemini doesn't support lora
|
||||||
|
)
|
||||||
|
|
||||||
|
GRAD_CKPTS=('--grad_checkpoint')
|
||||||
|
REWARD_FLAG=('nn' 'vr')
|
||||||
|
for reward_type in ${REWARD_FLAG[@]}; do
|
||||||
|
for lora_rank in ${LORA_RANK[@]}; do
|
||||||
|
for model in ${MODELS[@]}; do
|
||||||
|
for plugin in ${PLUGINS[@]}; do
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue # gemini_auto plugin doesn't support generation
|
||||||
|
fi
|
||||||
|
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||||
|
continue
|
||||||
|
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
pretrain=$(get_pretrain $model)
|
||||||
|
rm_pretrain="--rm_pretrain $pretrain"
|
||||||
|
reward_fn=""
|
||||||
|
if [[ $reward_type == "vr" ]]; then
|
||||||
|
rm_pretrain=""
|
||||||
|
reward_fn="--reward_functions gsm8k_reward_fn"
|
||||||
|
fi
|
||||||
|
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||||
|
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||||
|
tp='1'
|
||||||
|
bs='2'
|
||||||
|
ebs='1'
|
||||||
|
conversation_template=$(get_conversation_template_config $model)
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
|
if [[ $plugin == "3d" ]]; then
|
||||||
|
tp='2'
|
||||||
|
bs='2'
|
||||||
|
ebs='1'
|
||||||
|
fi
|
||||||
|
grad_accu='2'
|
||||||
|
# gemini_auto and gemini doesn't support gradient accumulation
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
grad_accu='1'
|
||||||
|
fi
|
||||||
|
# gemini_auto and gemini doesn't support generation
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
# gemini-auto doesn't support generation
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
for i in $(seq $NUM_RETRY); do
|
||||||
|
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
|
||||||
|
declare -a prompt_dataset=()
|
||||||
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
|
if [[ $reward_type == "vr" ]]; then
|
||||||
|
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
|
||||||
|
else
|
||||||
|
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
declare -a ptx_dataset=()
|
||||||
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
|
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||||
|
done
|
||||||
|
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_grpo.py \
|
||||||
|
--pretrain $pretrain \
|
||||||
|
$rm_pretrain \
|
||||||
|
--tokenizer_dir $tokenizer_dir \
|
||||||
|
--conversation_template_config $conversation_template \
|
||||||
|
--prompt_dataset ${prompt_dataset[@]} \
|
||||||
|
--ptx_dataset ${ptx_dataset[@]} \
|
||||||
|
--ptx_batch_size 1 \
|
||||||
|
--num_generations 2 \
|
||||||
|
--ptx_coef 0.2 \
|
||||||
|
--save_path $MODEL_SAVE_PATH \
|
||||||
|
$lora_config \
|
||||||
|
--plugin $plugin \
|
||||||
|
--num_episodes 5 \
|
||||||
|
--num_collect_steps 1 \
|
||||||
|
--num_update_steps 1 \
|
||||||
|
--experience_batch_size $ebs \
|
||||||
|
--train_batch_size $bs \
|
||||||
|
--accumulation_steps $grad_accu \
|
||||||
|
--lr 9e-6 \
|
||||||
|
--mixed_precision "bf16" \
|
||||||
|
--grad_clip 1.0 \
|
||||||
|
--tp $tp \
|
||||||
|
--lr 2e-5 \
|
||||||
|
$grad_ckpt \
|
||||||
|
--max_len 200 \ \
|
||||||
|
--max_seq_len 10 \
|
||||||
|
$reward_fn
|
||||||
|
# --use_flash_attn
|
||||||
|
passed=$?
|
||||||
|
if [ $passed -eq 0 ]; then
|
||||||
|
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||||
|
rm -rf ${MODELS_DIR:?}/*
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if [ $passed -ne 0 ]; then
|
||||||
|
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
echo "[Test]: testing ppo ..."
|
||||||
|
|
||||||
|
|
||||||
|
SKIPPED_TESTS=(
|
||||||
|
llama-3d # 3d plugin doesn't support lora
|
||||||
|
llama-gemini # gemini doesn't support lora
|
||||||
|
)
|
||||||
|
|
||||||
|
GRAD_CKPTS=('--grad_checkpoint')
|
||||||
|
REWARD_FLAG=('vr' 'nn')
|
||||||
|
for reward_type in ${REWARD_FLAG[@]}; do
|
||||||
|
for lora_rank in ${LORA_RANK[@]}; do
|
||||||
|
for model in ${MODELS[@]}; do
|
||||||
|
for plugin in ${PLUGINS[@]}; do
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue # gemini_auto plugin doesn't support generation
|
||||||
|
fi
|
||||||
|
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin-$lora_rank"
|
||||||
|
continue
|
||||||
|
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
pretrain=$(get_pretrain $model)
|
||||||
|
reward_fn=""
|
||||||
|
no_nn=""
|
||||||
|
if [[ $reward_type == "vr" ]]; then
|
||||||
|
reward_fn="--reward_functions gsm8k_reward_fn"
|
||||||
|
no_nn="--no_neural_reward_model"
|
||||||
|
fi
|
||||||
|
tokenizer_dir=$(get_tokenizer_dirs $model)
|
||||||
|
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||||
|
tp='1'
|
||||||
|
bs='2'
|
||||||
|
ebs='2'
|
||||||
|
conversation_template=$(get_conversation_template_config $model)
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
|
if [[ $plugin == "3d" ]]; then
|
||||||
|
tp='2'
|
||||||
|
bs='2'
|
||||||
|
ebs='2'
|
||||||
|
fi
|
||||||
|
grad_accu='2'
|
||||||
|
# gemini_auto and gemini doesn't support gradient accumulation
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
grad_accu='1'
|
||||||
|
fi
|
||||||
|
# gemini_auto and gemini doesn't support generation
|
||||||
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
# gemini-auto doesn't support generation
|
||||||
|
echo "[Test]: Skipped $model-$plugin"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
for i in $(seq $NUM_RETRY); do
|
||||||
|
echo "[Test]: $model-$plugin-$lora_rank-$reward_type, attempt $i"
|
||||||
|
declare -a prompt_dataset=()
|
||||||
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
|
if [[ $reward_type == "vr" ]]; then
|
||||||
|
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt_rlvr/arrow/part-$split")
|
||||||
|
else
|
||||||
|
prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
declare -a ptx_dataset=()
|
||||||
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
|
ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||||
|
done
|
||||||
|
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
|
||||||
|
--pretrain $pretrain \
|
||||||
|
--rm_pretrain $pretrain \
|
||||||
|
--tokenizer_dir $tokenizer_dir \
|
||||||
|
--conversation_template_config $conversation_template \
|
||||||
|
--prompt_dataset ${prompt_dataset[@]} \
|
||||||
|
--ptx_dataset ${ptx_dataset[@]} \
|
||||||
|
--ptx_batch_size 1 \
|
||||||
|
--ptx_coef 0.2 \
|
||||||
|
--save_path $MODEL_SAVE_PATH \
|
||||||
|
$lora_config \
|
||||||
|
--plugin $plugin \
|
||||||
|
--num_episodes 5 \
|
||||||
|
--num_collect_steps 1 \
|
||||||
|
--num_update_steps 1 \
|
||||||
|
--experience_batch_size $ebs \
|
||||||
|
--train_batch_size $bs \
|
||||||
|
--accumulation_steps $grad_accu \
|
||||||
|
--lr 9e-6 \
|
||||||
|
--mixed_precision "bf16" \
|
||||||
|
--grad_clip 1.0 \
|
||||||
|
--tp $tp \
|
||||||
|
--lr 2e-5 \
|
||||||
|
$grad_ckpt \
|
||||||
|
--max_len 400 \
|
||||||
|
--max_seq_len 10 \
|
||||||
|
$reward_fn \
|
||||||
|
$no_nn
|
||||||
|
# --use_flash_attn
|
||||||
|
passed=$?
|
||||||
|
if [ $passed -eq 0 ]; then
|
||||||
|
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||||
|
rm -rf ${MODELS_DIR:?}/*
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
if [ $passed -ne 0 ]; then
|
||||||
|
echo "[Test]: Failed $model-$plugin-$lora_rank-$reward_type"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
echo "[Test]: testing sft ..."
|
echo "[Test]: testing sft ..."
|
||||||
|
|
||||||
@@ -316,111 +550,6 @@ for lora_rank in ${LORA_RANK[@]}; do
     done
 done

-
-echo "[Test]: testing ppo ..."
-
-
-SKIPPED_TESTS=(
-    llama-3d # 3d plugin doesn't support lora
-    llama-gemini # gemini doesn't support lora
-)
-
-GRAD_CKPTS=('--grad_checkpoint')
-for lora_rank in ${LORA_RANK[@]}; do
-    for model in ${MODELS[@]}; do
-        for plugin in ${PLUGINS[@]}; do
-            if [[ $plugin == "gemini_auto" ]]; then
-                echo "[Test]: Skipped $model-$plugin"
-                continue # gemini_auto plugin doesn't support generation
-            fi
-            if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
-                echo "[Test]: Skipped $model-$plugin-$lora_rank"
-                continue
-            elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
-                echo "[Test]: Skipped $model-$plugin"
-                continue
-            fi
-            pretrain=$(get_pretrain $model)
-            tokenizer_dir=$(get_tokenizer_dirs $model)
-            grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
-            tp='1'
-            bs='4'
-            ebs='8'
-            conversation_template=$(get_conversation_template_config $model)
-            if [[ $plugin == "zero2" ]]; then
-                lora_config=$LORA_CONFIG_ENABLE
-            else
-                lora_config=""
-            fi
-            if [[ $plugin == "3d" ]]; then
-                tp='2'
-                bs='16'
-                ebs='32'
-            fi
-            grad_accu='2'
-            # gemini_auto and gemini doesn't support gradient accumulation
-            if [[ $plugin == "gemini_auto" ]]; then
-                grad_accu='1'
-            fi
-            # gemini_auto and gemini doesn't support generation
-            if [[ $plugin == "gemini_auto" ]]; then
-                # gemini-auto doesn't support generation
-                echo "[Test]: Skipped $model-$plugin"
-                continue
-            fi
-            for i in $(seq $NUM_RETRY); do
-                echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
-                declare -a prompt_dataset=()
-                for split in $(seq -f "%05g" 0 0); do
-                    prompt_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_prompt/arrow/part-$split")
-                done
-                declare -a ptx_dataset=()
-                for split in $(seq -f "%05g" 0 0); do
-                    ptx_dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
-                done
-                colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_ppo.py \
-                    --pretrain $pretrain \
-                    --rm_pretrain $pretrain \
-                    --tokenizer_dir $tokenizer_dir \
-                    --conversation_template_config $conversation_template \
-                    --prompt_dataset ${prompt_dataset[@]} \
-                    --ptx_dataset ${ptx_dataset[@]} \
-                    --ptx_batch_size 1 \
-                    --ptx_coef 0.2 \
-                    --save_path $MODEL_SAVE_PATH \
-                    $lora_config \
-                    --plugin $plugin \
-                    --num_episodes 5 \
-                    --num_collect_steps 1 \
-                    --num_update_steps 1 \
-                    --experience_batch_size $ebs \
-                    --train_batch_size $bs \
-                    --accumulation_steps $grad_accu \
-                    --lr 9e-6 \
-                    --mixed_precision "bf16" \
-                    --grad_clip 1.0 \
-                    --tp $tp \
-                    --lr 2e-5 \
-                    $grad_ckpt \
-                    --max_len 400 \
-                    --max_seq_len 10 \
-                    # --use_flash_attn
-                passed=$?
-                if [ $passed -eq 0 ]; then
-                    rm -rf ${MODEL_SAVE_PATH:?}/*
-                    rm -rf ${MODELS_DIR:?}/*
-                    break
-                fi
-            done
-            if [ $passed -ne 0 ]; then
-                echo "[Test]: Failed $model-$plugin-$lora_rank"
-                exit 1
-            fi
-        done
-    done
-done
-
-
 echo "[Test]: testing DPO ..."

 SKIPPED_TESTS=(
@@ -446,7 +575,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         bs='2'
         if [[ $plugin == "3d" ]]; then
             tp='2'
-            bs='8'
+            bs='2'
         fi
         if [[ $plugin == "zero2" ]]; then
             lora_config=$LORA_CONFIG_ENABLE
@@ -503,10 +632,10 @@ for lora_rank in ${LORA_RANK[@]}; do
 done


-
 echo "[Test]: testing ORPO ..."

 SKIPPED_TESTS=(
+    llama-3d-0
     llama-3d-20 # 3d plugin doesn't support lora
     llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
     llama-gemini-20 # gemini doesn't support lora
@@ -529,7 +658,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         bs='2'
         if [[ $plugin == "3d" ]]; then
             tp='2'
-            bs='8'
+            bs='2'
         fi
         if [[ $plugin == "zero2" ]]; then
             lora_config=$LORA_CONFIG_ENABLE
@@ -585,11 +714,10 @@ for lora_rank in ${LORA_RANK[@]}; do
 done
 done

-
-
 echo "[Test]: testing KTO ..."

 SKIPPED_TESTS=(
+    llama-3d-0
     llama-3d-20 # 3d plugin doesn't support lora
     llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
     llama-gemini-20 # gemini doesn't support lora
@@ -612,7 +740,7 @@ for lora_rank in ${LORA_RANK[@]}; do
         bs='2'
         if [[ $plugin == "3d" ]]; then
             tp='2'
-            bs='8'
+            bs='2'
         fi
         if [[ $plugin == "zero2" ]]; then
             lora_config=$LORA_CONFIG_ENABLE
@@ -154,7 +154,7 @@ inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32
 }
 ```
@@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields:
 - `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated
 - `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None.
 - `language` (str, compulsory): The language for the subcategory.
-- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length.
+- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss of sentences or not if the dataset is a pretrain dataset. It is usually used for calculate perplexity when you want to evaluate a model with extended context length.
 - `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.

 For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key`all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other dataset such as GAOKAO-bench contains different formats of questions and lacks some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly.
@@ -230,7 +230,7 @@ Example:
 In this step, you will configure your tokenizer and model arguments to infer on the given datasets.

 A config file consists of two parts.
-1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
+1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with vllm offline inference `LLM` class. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
 2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench and few-shot on dataset MMLU, CMMLU AGIEval and GSM8K. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`.

 Once you have all config ready, the program will run inference on all the given datasets on all the given models.
@@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be:
 }
 ```

-Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.
+An example config using model class `vLLMModel` and dataset class `CMMLUDataset` can be:
+```json
+{
+    "model": [
+        {
+            "name": "model name",
+            "model_class": "vLLMModel",
+            "parameters": {
+                "path": "path to model",
+                "model_max_length": 2048,
+                "tokenizer_path": "",
+                "tokenizer_kwargs": {
+                    "trust_remote_code": true
+                },
+                "model_kwargs": {
+                    "trust_remote_code": true
+                },
+                "prompt_template": "plain",
+                "batch_size": 4
+            }
+        }
+    ],
+    "dataset": [
+        {
+            "name": "dataset name",
+            "dataset_class": "CMMLUDataset",
+            "debug": false,
+            "few_shot": true,
+            "path": "path to original dataset",
+            "save_path": "path to save converted dataset"
+        }
+    ]
+}
+```
+
+Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, the `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. For vLLM model, the `tokenizer_kwargs` and `model_kwargs` are loaded together in `LLM` class.`few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.

 > For GSM8K dataset, you can set additional flags `load_train` or `load_reference` for dataset configuration as true and during the inference process, the program will calculate loss summation over all tokens for each data sample. During the evaluation process, you can use metric `loss_over_all_tokens` to calculate the overall loss and use it for data leakage evaluation.

@@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \
     --inference_save_path "path to save inference results"
 ```

-You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size.
+You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size (currently not support for `vLLMModel`).

 ### Evaluation

@@ -530,10 +565,6 @@ class CustomizedModel(BaseModel):

 Once you have successfully added your own model, you can specify your model class in your inference config.

-## To do
-
-- [ ] Add visualization code for evaluation results on public dataset
-- [ ] Improve the way to label target tokens
-
 ## Citations

@@ -47,7 +47,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -70,7 +70,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -81,7 +81,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -12,7 +12,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }

@@ -15,7 +15,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -36,7 +36,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -72,7 +72,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }

@@ -114,7 +114,7 @@ class GSMDataset(BaseDataset):
                 dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                 if forward_only:
-                    dataset[split][subject]["inference_kwargs"]["pretrain"] = True
+                    dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True

                 if split == "test" and few_shot:
                     dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data()
@@ -60,7 +60,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -11,7 +11,7 @@ default_inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -14,7 +14,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 1024,
     "turns": 2,
 }
@@ -28,7 +28,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -28,7 +28,7 @@ default_inference_kwargs = {
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

@@ -1,5 +1,6 @@
 from .base import BaseModel
 from .chatglm import ChatGLM2Model, ChatGLMModel
 from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+from .vllm import vLLMModel

-__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"]
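As a quick sanity check, a minimal sketch (relying only on the export added above, and assuming the package plus the optional vllm dependency are installed) would be:

```python
# Hypothetical check; relies only on the export added in models/__init__.py above.
from colossal_eval.models import vLLMModel

print(vLLMModel)  # should print the class without raising ImportError
```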
@ -28,7 +28,7 @@ class ChatGLMModel(HuggingFaceModel):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
|
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
Calculate loss only on target tokens.
|
Calculate loss only on target tokens.
|
||||||
@ -225,7 +225,7 @@ class ChatGLM2Model(ChatGLMModel):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
|
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
Calculate loss only on target tokens.
|
Calculate loss only on target tokens.
|
||||||
|
@ -105,6 +105,12 @@ class HuggingFaceModel(BaseModel):
|
|||||||
elif hasattr(self.tokenizer, "eod_id"):
|
elif hasattr(self.tokenizer, "eod_id"):
|
||||||
# Qwen has an eod token "<|endoftext|>".
|
# Qwen has an eod token "<|endoftext|>".
|
||||||
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
||||||
|
else:
|
||||||
|
self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
|
||||||
|
raise ValueError(
|
||||||
|
"The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
|
||||||
|
"Please set pad_token_id manually."
|
||||||
|
)
|
||||||
|
|
||||||
def _load_model(
|
def _load_model(
|
||||||
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
|
self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
|
||||||
@ -245,7 +251,7 @@ class HuggingFaceModel(BaseModel):
|
|||||||
return input_ids_list, labels_list, bytes_list
|
return input_ids_list, labels_list, bytes_list
|
||||||
|
|
||||||
def _get_input_ids_and_labels(
|
def _get_input_ids_and_labels(
|
||||||
self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
|
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
|
||||||
) -> Tuple[List[torch.LongTensor]]:
|
) -> Tuple[List[torch.LongTensor]]:
|
||||||
"""
|
"""
|
||||||
Get input_ids and labels for the given data.
|
Get input_ids and labels for the given data.
|
||||||
@ -258,7 +264,7 @@ class HuggingFaceModel(BaseModel):
|
|||||||
Input_ids and labels for the given batch.
|
Input_ids and labels for the given batch.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if pretrain:
|
if calculate_overall_loss:
|
||||||
batch = []
|
batch = []
|
||||||
# Concatenate prompt and target answers.
|
# Concatenate prompt and target answers.
|
||||||
# You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
|
# You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
|
||||||
@ -342,7 +348,7 @@ class HuggingFaceModel(BaseModel):
|
|||||||
calculate_loss = inference_kwargs["calculate_loss"]
|
calculate_loss = inference_kwargs["calculate_loss"]
|
||||||
classes = inference_kwargs["all_classes"]
|
classes = inference_kwargs["all_classes"]
|
||||||
language = inference_kwargs["language"]
|
language = inference_kwargs["language"]
|
||||||
pretrain = inference_kwargs["pretrain"]
|
calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
|
||||||
max_new_tokens = inference_kwargs["max_new_tokens"]
|
max_new_tokens = inference_kwargs["max_new_tokens"]
|
||||||
few_shot_data = inference_kwargs.get("few_shot_data", None)
|
few_shot_data = inference_kwargs.get("few_shot_data", None)
|
||||||
|
|
||||||
@@ -384,12 +390,12 @@ class HuggingFaceModel(BaseModel):
                 self.logger.info("-" * 120)
                 self.logger.info(batch_prompt[0] + batch_target[0][0])
 
-            if not pretrain:
+            if not calculate_overall_loss:
                 batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
 
             if calculate_loss:
                 batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
-                    batch_prompt, batch_target, pretrain
+                    batch_prompt, batch_target, calculate_overall_loss
                 )
 
             probs = []
@@ -409,7 +415,7 @@ class HuggingFaceModel(BaseModel):
                 ]
 
             for j in range(len(batch)):
-                if not pretrain:
+                if not calculate_overall_loss:
                     if isinstance(batch[j]["output"], list):
                         batch[j]["output"].append(batch_decodes[j].strip())
                     else:
@@ -496,7 +502,9 @@ class HuggingFaceModel(BaseModel):
         return decoded_sequences, scores
 
     @torch.no_grad()
-    def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
+    ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
 
@@ -513,13 +521,15 @@ class HuggingFaceModel(BaseModel):
         # We don't need to generate new tokens.
         # Target answer's length is usually << model_max_length, but we still call it in case.
         # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
-        if not pretrain:
+        if not calculate_overall_loss:
             batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
 
         # Get the number of target answers for different questions
         batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
 
-        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(
+            batch_prompt, batch_target, calculate_overall_loss
+        )
 
         # Because of multiple target answers, the final batch size may be greater than self.batch_size.
         # We will generate new batches.
applications/ColossalEval/colossal_eval/models/vllm.py (new file, 498 lines)
@@ -0,0 +1,498 @@
import copy
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0
from torch.utils.data import DataLoader
from tqdm import tqdm
from vllm import LLM, SamplingParams

from colossalai.logging import DistributedLogger

from .huggingface import HuggingFaceModel

IGNORE_INDEX = -100


class vLLMModel(HuggingFaceModel):
    """
    Model wrapper around vLLM models.

    Args:
        path: The path to a vLLM model.
        model_max_length: The maximum sequence length of the model.
        tokenizer_path: The path to the tokenizer.
        tokenizer_kwargs: Keyword arguments for the tokenizer.
        model_kwargs: Keyword arguments for the model.
        prompt_template: The model's prompt template.
        batch_size: Batch size for inference.
        logger: Logger for the model.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
        quantization: The method used to quantize the model weights.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
        enforce_eager: Whether to enforce eager execution.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
        disable_custom_all_reduce: See ParallelConfig.
    """

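As a quick orientation for the wrapper described above, a minimal construction sketch (the import path is assumed from the file location, and the checkpoint path is a placeholder, not a value used by the repository):

from colossal_eval.models.vllm import vLLMModel  # import path assumed from the file location

model = vLLMModel(
    path="/path/to/checkpoint",   # hypothetical local model directory
    model_max_length=2048,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.5,
)
# Evaluation is then driven through model.inference(data_loader, inference_kwargs),
# mirroring the HuggingFaceModel interface this class inherits from.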
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
model_max_length: int = 2048,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
|
tokenizer_kwargs: Dict = None,
|
||||||
|
model_kwargs: Dict = None,
|
||||||
|
prompt_template: Conversation = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
logger: DistributedLogger = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
gpu_memory_utilization: float = 0.5,
|
||||||
|
swap_space: float = 4,
|
||||||
|
cpu_offload_gb: float = 0,
|
||||||
|
enforce_eager: Optional[bool] = None,
|
||||||
|
max_context_len_to_capture: Optional[int] = None,
|
||||||
|
max_seq_len_to_capture: int = 8192,
|
||||||
|
disable_custom_all_reduce: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
path=path,
|
||||||
|
model_max_length=model_max_length,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
batch_size=batch_size,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._load_model(
|
||||||
|
path=path,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
tokenizer_kwargs=tokenizer_kwargs,
|
||||||
|
tokenizer_path=tokenizer_path if tokenizer_path else None,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
quantization=quantization,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
swap_space=swap_space,
|
||||||
|
cpu_offload_gb=cpu_offload_gb,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
max_context_len_to_capture=max_context_len_to_capture,
|
||||||
|
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||||
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_model(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
tokenizer_kwargs: dict,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
|
gpu_memory_utilization: float = 0.9,
|
||||||
|
swap_space: float = 4,
|
||||||
|
cpu_offload_gb: float = 0,
|
||||||
|
enforce_eager: Optional[bool] = None,
|
||||||
|
max_context_len_to_capture: Optional[int] = None,
|
||||||
|
max_seq_len_to_capture: int = 8192,
|
||||||
|
disable_custom_all_reduce: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The path to the model.
|
||||||
|
model_kwargs: Keyword arguments for the model.
|
||||||
|
tokenizer_kwargs: Keyword arguments for the tokenizer.
|
||||||
|
tokenizer_path: The path to the tokenizer.
|
||||||
|
trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.
|
||||||
|
tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism.
|
||||||
|
quantization: The method used to quantize the model weights
|
||||||
|
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache.
|
||||||
|
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||||
|
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights.
|
||||||
|
enforce_eager: Whether to enforce eager execution.
|
||||||
|
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||||
|
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||||
|
disable_custom_all_reduce: See ParallelConfig
|
||||||
|
|
||||||
|
"""
|
||||||
|
if "torch_dtype" in model_kwargs:
|
||||||
|
model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"])
|
||||||
|
model_kwargs.pop("torch_dtype")
|
||||||
|
else:
|
||||||
|
model_kwargs.setdefault("dtype", torch.float16)
|
||||||
|
|
||||||
|
if "trust_remote_code" in model_kwargs:
|
||||||
|
trust_remote_code = model_kwargs["trust_remote_code"]
|
||||||
|
model_kwargs.pop("trust_remote_code")
|
||||||
|
|
||||||
|
if "trust_remote_code" in tokenizer_kwargs:
|
||||||
|
trust_remote_code = tokenizer_kwargs["trust_remote_code"]
|
||||||
|
tokenizer_kwargs.pop("trust_remote_code")
|
||||||
|
|
||||||
|
self.model = LLM(
|
||||||
|
model=path,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
quantization=quantization,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
swap_space=swap_space,
|
||||||
|
cpu_offload_gb=cpu_offload_gb,
|
||||||
|
enforce_eager=enforce_eager,
|
||||||
|
max_context_len_to_capture=max_context_len_to_capture,
|
||||||
|
max_seq_len_to_capture=max_seq_len_to_capture,
|
||||||
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
|
**model_kwargs,
|
||||||
|
**tokenizer_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tokenizer = self.model.get_tokenizer()
|
||||||
|
|
||||||
|
if self.batch_size > 1:
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
self.tokenizer.truncation_side = "left"
|
||||||
|
|
||||||
|
if self.tokenizer.pad_token_id is None:
|
||||||
|
self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
|
||||||
|
if self.tokenizer.eos_token:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
elif hasattr(self.tokenizer, "eod_id"):
|
||||||
|
# Qwen has an eod token "<|endoftext|>".
|
||||||
|
self.tokenizer.pad_token_id = self.tokenizer.eod_id
|
||||||
|
else:
|
||||||
|
self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
|
||||||
|
raise ValueError(
|
||||||
|
"The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
|
||||||
|
"Please set pad_token_id manually."
|
||||||
|
)
|
||||||
|
|
||||||
    def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]:
        """
        Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110

        Args:
            inputs: A batch of input strings.
            labels: A batch of labels.

        Returns:
            A list of loss and a list of label length.

        """
        batch_size = len(inputs)
        sampling_kwargs = SamplingParams(logprobs=1)
        outputs = self.model.generate(inputs, sampling_kwargs)
        ce_loss = []

        if labels is not None:
            lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels]
        else:
            lens = [1] * batch_size

        for i in range(batch_size):
            logprobs = outputs[i].outputs[0].logprobs
            token_ids = outputs[i].outputs[0].token_ids

            logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))]
            logprobs_list = [i.logprob for i in logprobs_list]
            logprobs_list = np.array(logprobs_list)

            if lens is not None:
                logprobs_list = logprobs_list[: lens[i]]

            loss = -logprobs_list.sum(axis=-1) / lens[i]
            ce_loss.append(loss)

        batch_loss = np.array(ce_loss)

        return batch_loss, lens

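To make the arithmetic above concrete: with logprobs=1, vLLM reports the log-probability of each generated token, and the loss for a sample is the length-normalized negative log-likelihood over the first `len` tokens. A tiny worked example:

import numpy as np

logprobs = np.array([-1.0, -2.0, -3.0])    # per-token log-probabilities for a 3-token target
loss = -logprobs.sum() / len(logprobs)     # (1.0 + 2.0 + 3.0) / 3 = 2.0
perplexity = np.exp(loss)                  # about 7.39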
|
def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Infer the given data.
|
||||||
|
This function will call self.generate() to get model outputs and use LogitsProcessor param to get specific logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The data for inference.
|
||||||
|
inference_kwargs: Arguments for inference.
|
||||||
|
debug: Whether to display generated prompt for debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Inference results.
|
||||||
|
|
||||||
|
"""
|
||||||
|
calculate_loss = inference_kwargs["calculate_loss"]
|
||||||
|
classes = inference_kwargs["all_classes"]
|
||||||
|
language = inference_kwargs["language"]
|
||||||
|
calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
|
||||||
|
max_new_tokens = inference_kwargs["max_new_tokens"]
|
||||||
|
few_shot_data = inference_kwargs.get("few_shot_data", None)
|
||||||
|
|
||||||
|
# Some classification questions' options are texts not a single letter such as A, B, C and D.
|
||||||
|
# If the text length is greater than 1, we won't calculate loss over choices.
|
||||||
|
if classes is not None and any(len(c) > 1 for c in classes):
|
||||||
|
classes = None
|
||||||
|
|
||||||
|
self.choices = classes
|
||||||
|
self.indices_for_choices = None
|
||||||
|
if self.choices:
|
||||||
|
# Get indices for each choice
|
||||||
|
self._get_choices_indices(language)
|
||||||
|
|
||||||
|
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
|
||||||
|
|
||||||
|
bar = tqdm(
|
||||||
|
range(len(data_loader)),
|
||||||
|
desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps",
|
||||||
|
disable=not is_rank_0(),
|
||||||
|
)
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
|
||||||
|
for i, batch in enumerate(data_loader):
|
||||||
|
batch_prompt, batch_target = get_batch_prompt(
|
||||||
|
self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_rank_0() and debug and i == 0:
|
||||||
|
self.logger.info(
|
||||||
|
f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}"
|
||||||
|
)
|
||||||
|
self.logger.info("-" * 120)
|
||||||
|
self.logger.info("An example prompt and prompt with target is:")
|
||||||
|
self.logger.info("-" * 120)
|
||||||
|
self.logger.info(batch_prompt[0])
|
||||||
|
self.logger.info("-" * 120)
|
||||||
|
self.logger.info(batch_prompt[0] + batch_target[0][0])
|
||||||
|
|
||||||
|
if not calculate_overall_loss:
|
||||||
|
batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
|
||||||
|
|
||||||
|
if calculate_loss:
|
||||||
|
batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
|
||||||
|
batch_prompt, batch_target, calculate_overall_loss
|
||||||
|
)
|
||||||
|
|
||||||
|
probs = []
|
||||||
|
if self.indices_for_choices:
|
||||||
|
scores = scores.to(torch.float32)
|
||||||
|
# If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample.
|
||||||
|
# Otherwise this will violate the single-choice setting.
|
||||||
|
|
||||||
|
if calculate_loss:
|
||||||
|
labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))]
|
||||||
|
|
||||||
|
loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist()
|
||||||
|
|
||||||
|
probs = scores.numpy().tolist()
|
||||||
|
probs = [
|
||||||
|
{choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs))
|
||||||
|
]
|
||||||
|
|
||||||
|
for j in range(len(batch)):
|
||||||
|
if not calculate_overall_loss:
|
||||||
|
if isinstance(batch[j]["output"], list):
|
||||||
|
batch[j]["output"].append(batch_decodes[j].strip())
|
||||||
|
else:
|
||||||
|
batch[j]["output"] = batch_decodes[j].strip()
|
||||||
|
|
||||||
|
if isinstance(scores, torch.Tensor):
|
||||||
|
batch[j]["logits_over_choices"] = probs[j]
|
||||||
|
|
||||||
|
if calculate_loss:
|
||||||
|
batch[j]["loss_over_choices"] = loss_over_choices[j]
|
||||||
|
|
||||||
|
if calculate_loss:
|
||||||
|
batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist()
|
||||||
|
|
||||||
|
# loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity.
|
||||||
|
# However, loss (which is per sample loss) suffices for most cases.
|
||||||
|
batch[j]["loss_sum"] = batch_losses[j]
|
||||||
|
batch[j]["token_num"] = batch_target_token_nums[j]
|
||||||
|
|
||||||
|
if batch_bytes_nums:
|
||||||
|
batch[j]["byte_num"] = batch_bytes_nums[j]
|
||||||
|
answers.extend(batch)
|
||||||
|
|
||||||
|
bar.update()
|
||||||
|
|
||||||
|
return answers
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
|
||||||
|
"""Generate results given a list of inputs and get logits of the first new token over choices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: A list of strings.
|
||||||
|
max_new_tokens: Max new tokens for generation.
|
||||||
|
kwargs: Key arguments for generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of generated strings and logits over choices.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Currently the function only returns the logits of the first new token.
|
||||||
|
It is used for single choice question.
|
||||||
|
For multiple choices question, please avoid using the loss over choices.
|
||||||
|
You should set argument choices as None in self.inference().
|
||||||
|
|
||||||
|
"""
|
||||||
|
truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens)
|
||||||
|
|
||||||
|
generation_kwargs = kwargs.copy()
|
||||||
|
generation_kwargs.update({"max_tokens": max_new_tokens})
|
||||||
|
logits_processor = GetTokenLogitsProcessor(self.indices_for_choices)
|
||||||
|
|
||||||
|
sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs)
|
||||||
|
|
||||||
|
outputs = self.model.generate(truncated_inputs, sampling_kwargs)
|
||||||
|
output_strs = []
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
output_strs.append(generated_text)
|
||||||
|
scores = logits_processor.get_target_logits()
|
||||||
|
|
||||||
|
return output_strs, scores
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_loss(
|
||||||
|
self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
Calculate loss only on target tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: A batch of prompt without target answer.
|
||||||
|
batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
|
||||||
|
# We don't need to generate new tokens.
|
||||||
|
# Target answer's length is usually << model_max_length, but we still call it in case.
|
||||||
|
# We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
|
||||||
|
if not calculate_overall_loss:
|
||||||
|
batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
|
||||||
|
|
||||||
|
# Get the number of target answers for different questions
|
||||||
|
batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
|
||||||
|
|
||||||
|
if calculate_overall_loss:
|
||||||
|
batch = []
|
||||||
|
bytes_list = []
|
||||||
|
batch_prompt_pretrain = []
|
||||||
|
for p, b in zip(batch_prompt, batch_target):
|
||||||
|
batch.append(p + b[0])
|
||||||
|
|
||||||
|
for input in batch:
|
||||||
|
# Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process.
|
||||||
|
# Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels.
|
||||||
|
# After all, the rest of the original string doesn't need to be tokenized at the first place.
|
||||||
|
# Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process.
|
||||||
|
# Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels.
|
||||||
|
# After all, the rest of the original string doesn't need to be tokenized at the first place.
|
||||||
|
ratio = [16, 8, 4, 2, 1]
|
||||||
|
tokenized = None
|
||||||
|
for r in ratio:
|
||||||
|
tokenized = self.tokenizer(
|
||||||
|
[input[0 : len(input) // r]],
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.model_max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
if tokenized.input_ids.size(1) >= self.model_max_length:
|
||||||
|
break
|
||||||
|
|
||||||
|
string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True)
|
||||||
|
batch_prompt_pretrain.append(string)
|
||||||
|
bytes_list.append(len(string.encode("utf-8")))
|
||||||
|
|
||||||
|
batch_prompt = copy.deepcopy(batch_prompt_pretrain)
|
||||||
|
batch_target = None
|
||||||
|
else:
|
||||||
|
batch_prompt_processed = []
|
||||||
|
batch_target_processed = []
|
||||||
|
for prompt, targets in zip(batch_prompt, batch_target):
|
||||||
|
for target in targets:
|
||||||
|
target_tokenized = self.tokenizer(
|
||||||
|
[target], truncation=True, max_length=self.model_max_length, return_tensors="pt"
|
||||||
|
)
|
||||||
|
max_new_tokens = target_tokenized["input_ids"][0].size(0)
|
||||||
|
prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
|
||||||
|
batch_prompt_processed.append(prompt_with_correct_length)
|
||||||
|
batch_target_processed.append(target)
|
||||||
|
|
||||||
|
batch_prompt = copy.deepcopy(batch_prompt_processed)
|
||||||
|
batch_target = copy.deepcopy(batch_target_processed)
|
||||||
|
bytes_list = None
|
||||||
|
|
||||||
|
# Because of multiple target answers, the final batch size may be greater than self.batch_size.
|
||||||
|
# We will generate new batches.
|
||||||
|
losses = []
|
||||||
|
target_token_nums = []
|
||||||
|
|
||||||
|
losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target)
|
||||||
|
losses.extend(losses_per_batch)
|
||||||
|
target_token_nums.extend(target_token_num_per_batch)
|
||||||
|
|
||||||
|
start_indice = 0
|
||||||
|
losses_per_sample = []
|
||||||
|
|
||||||
|
target_token_nums_per_sample = []
|
||||||
|
bytes_nums_per_sample = []
|
||||||
|
for length in batch_target_nums:
|
||||||
|
losses_per_sample.append(losses[start_indice : start_indice + length])
|
||||||
|
target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
|
||||||
|
|
||||||
|
if bytes_list:
|
||||||
|
bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length])
|
||||||
|
|
||||||
|
start_indice += length
|
||||||
|
|
||||||
|
if bytes_list:
|
||||||
|
return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample
|
||||||
|
|
||||||
|
return losses_per_sample, target_token_nums_per_sample, None
|
||||||
|
|
||||||
|
|
||||||
class GetTokenLogitsProcessor:
    """
    LogitsProcessor to get specific logits

    Args:
        indices_for_choices: token indices of required tokens
        target_logits: store all the target logits
    """

    def __init__(
        self,
        indices_for_choices: List[List[int]],
    ):
        self.indices_for_choices = (indices_for_choices,)
        self.target_logits = []

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        choice_scores = []

        if not input_ids:
            for option_indices in self.indices_for_choices[0]:
                choice_scores.append(logits[option_indices].detach().cpu())

            choice_scores = torch.max(torch.stack(choice_scores), dim=0)[0]
            self.target_logits.append(choice_scores)

        return logits

    def get_target_logits(self) -> torch.Tensor:
        return torch.stack(self.target_logits) if self.target_logits else torch.tensor([])
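For context, the processor above hooks into vLLM's per-step sampling: SamplingParams accepts a list of callables that receive the already-generated token ids and the raw next-token logits, so the first call (empty input_ids) captures the logits over the answer choices. A hedged sketch of how it is wired up (the choice token ids and model path are made-up placeholders):

from vllm import LLM, SamplingParams

indices_for_choices = [[319, 350, 315, 360]]        # hypothetical token ids for "A"/"B"/"C"/"D"
processor = GetTokenLogitsProcessor(indices_for_choices)
params = SamplingParams(max_tokens=4, logits_processors=[processor])
# llm = LLM(model="/path/to/checkpoint")            # hypothetical checkpoint
# outputs = llm.generate(["Question ..."], params)
# scores = processor.get_target_logits()            # first-token logits over the choices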
@@ -69,7 +69,7 @@ def rm_and_merge(
                         os.remove(directory)
                     except Exception as e:
                         print(e)
-                print(len(answers["data"]))
                 all_answers[category] = answers
 
             all_answers_with_dataset_class["inference_results"] = all_answers
@@ -10,3 +10,4 @@ matplotlib
 pandas
 seaborn
 scikit-learn
+vllm==0.5.5
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
         """
         Return autocast function
         """
-        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
+        return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
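This tracks the deprecation of torch.cuda.amp.autocast in favor of the device-agnostic torch.amp API in recent PyTorch releases. A small illustrative call site (model and input are toy placeholders):

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
    y = model(x)   # matmuls run in fp16 where safe, fp32 elsewhere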
@@ -43,7 +43,7 @@ class MixedPrecisionMixin(ABC):
     dtype: torch.dtype
 
     @abstractmethod
-    def pre_backward(self, loss: Tensor) -> Tensor:
+    def pre_backward(self, loss: Tensor, *args, **kwargs) -> Tensor:
         """Called before backward.
 
         Args:
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, inf
@@ -84,14 +84,20 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 self.master_to_working_map[master_p] = p
                 master_params.append(master_p)
             group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None
 
-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
-        loss.backward(*args, **kwargs)
+        loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
-        tensor.backward(grad)
+        torch.autograd.backward(
+            tensors=tensor,
+            grad_tensors=grad,
+            inputs=inputs,
+            retain_graph=retain_graph,
+        )
 
     def zero_grad(self, *args, **kwargs):
         for p in self.working_to_master_map.keys():
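The backward_by_grad rewrite above switches from Tensor.backward(grad) to torch.autograd.backward so that gradient accumulation can be restricted to specific leaf tensors and the graph can optionally be retained. A standalone illustration of that call, using toy tensors:

import torch

w = torch.randn(3, requires_grad=True)
y = w * 2                          # non-scalar output
grad_y = torch.ones_like(y)        # upstream gradient to inject
torch.autograd.backward(tensors=y, grad_tensors=grad_y, inputs=[w], retain_graph=True)
print(w.grad)                      # tensor([2., 2., 2.])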
@@ -187,6 +193,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm
             self._unscale_and_clip_grads(total_norm)
 
         self.optim.step(*args, **kwargs)
@@ -212,3 +219,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
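The new get_grad_norm hook simply returns the total norm cached during the last step() that ran gradient clipping. The cached value corresponds to the pre-clipping total norm; a plain-PyTorch analogue of what is being recorded:

import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(2, 4)).sum()
loss.backward()
# clip_grad_norm_ returns the total norm before clipping, the quantity get_grad_norm() exposes per step
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))
opt.step()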
@@ -288,7 +288,14 @@ class Booster:
 
         return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)
 
-    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
+    def load_model(
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ) -> None:
         """Load model from checkpoint.
 
         Args:
@@ -298,8 +305,12 @@ class Booster:
             strict (bool, optional): whether to strictly enforce that the keys
                 in :attr:`state_dict` match the keys returned by this module's
                 :meth:`~torch.nn.Module.state_dict` function. Defaults to True.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
-        self.checkpoint_io.load_model(model, checkpoint, strict)
+        self.checkpoint_io.load_model(
+            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
     def save_model(
         self,
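In user code, disabling low_cpu_mem_mode trades extra host RAM (a read cache) for faster, multi-threaded loading. A hedged usage sketch, assuming a model and optimizer that have already been boosted with a plugin (setup code omitted, checkpoint paths are placeholders):

booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
# Faster loading at the cost of more CPU memory: RAM cache plus 8 reader threads.
booster.load_model(model, "./ckpt/model", strict=True, low_cpu_mem_mode=False, num_threads=8)
booster.load_optimizer(optimizer, "./ckpt/optimizer", low_cpu_mem_mode=False, num_threads=8)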
@ -310,6 +321,7 @@ class Booster:
|
|||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
use_safetensors: bool = False,
|
use_safetensors: bool = False,
|
||||||
|
use_async: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save model to checkpoint.
|
"""Save model to checkpoint.
|
||||||
|
|
||||||
@ -324,6 +336,7 @@ class Booster:
|
|||||||
names to compose the keys in state_dict. Defaults to None.
|
names to compose the keys in state_dict. Defaults to None.
|
||||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||||
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
|
use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved.
|
||||||
|
use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False.
|
||||||
"""
|
"""
|
||||||
self.checkpoint_io.save_model(
|
self.checkpoint_io.save_model(
|
||||||
model,
|
model,
|
||||||
@ -333,20 +346,28 @@ class Booster:
|
|||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
size_per_shard=size_per_shard,
|
size_per_shard=size_per_shard,
|
||||||
use_safetensors=use_safetensors,
|
use_safetensors=use_safetensors,
|
||||||
|
use_async=use_async,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
|
def load_optimizer(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
checkpoint: str,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
) -> None:
|
||||||
"""Load optimizer from checkpoint.
|
"""Load optimizer from checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): An optimizer boosted by Booster.
|
optimizer (Optimizer): An optimizer boosted by Booster.
|
||||||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||||
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
|
||||||
prefix (str, optional): A prefix added to parameter and buffer
|
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||||
names to compose the keys in state_dict. Defaults to None.
|
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
|
||||||
"""
|
"""
|
||||||
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
|
self.checkpoint_io.load_optimizer(
|
||||||
|
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||||
|
)
|
||||||
|
|
||||||
def save_optimizer(
|
def save_optimizer(
|
||||||
self,
|
self,
|
||||||
@ -356,6 +377,7 @@ class Booster:
|
|||||||
gather_dtensor: bool = True,
|
gather_dtensor: bool = True,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
|
use_async: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Save optimizer to checkpoint.
|
Save optimizer to checkpoint.
|
||||||
@ -371,7 +393,9 @@ class Booster:
|
|||||||
names to compose the keys in state_dict. Defaults to None.
|
names to compose the keys in state_dict. Defaults to None.
|
||||||
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
|
||||||
"""
|
"""
|
||||||
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
|
self.checkpoint_io.save_optimizer(
|
||||||
|
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
|
||||||
|
)
|
||||||
|
|
||||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
||||||
"""Save lr scheduler to checkpoint.
|
"""Save lr scheduler to checkpoint.
|
||||||
|
@@ -46,9 +46,9 @@ class TorchAMPOptimizer(OptimizerWrapper):
             growth_interval=growth_interval,
         )
 
-    def backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs) -> None:
         scaled_loss = self.scale_loss(loss)
-        scaled_loss.backward(*args, **kwargs)
+        scaled_loss.backward(inputs=inputs, retain_graph=retain_graph, **kwargs)
 
     def step(self, *args, **kwargs) -> Optional[float]:
         out = self.scaler.step(self.optim, *args, **kwargs)
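The same inputs/retain_graph plumbing lands here so AMP training can back-propagate into a subset of leaves or keep the graph alive for a second pass. A plain-PyTorch sketch of the scaler pattern this optimizer wraps (assumes a CUDA device and a recent PyTorch with the torch.amp.GradScaler API):

import torch

model = torch.nn.Linear(4, 1).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(torch.randn(2, 4, device="cuda")).sum()
scaler.scale(loss).backward(retain_graph=False)   # scaled loss, same backward kwargs as above
scaler.step(opt)
scaler.update()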
@@ -1,4 +1,3 @@
-import gc
 import os
 import random
 from pathlib import Path
@@ -17,9 +16,11 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    load_shard_state_dict,
+    load_state_dict_shards,
     save_config_file,
     save_state_dict,
     save_state_dict_shards,
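The new imports above support the async checkpointing path added throughout this file: state dict shards are first copied into page-locked (pinned) host buffers cached per model or optimizer, and the actual file writes are handed to background writers that the caller flushes later. A generic sketch of that copy-into-pinned-cache step, independent of ColossalAI's own helpers (whose exact signatures are not shown in this diff):

import torch

_pinned_cache = {}

def copy_to_pinned(key, state_dict):
    # Reuse one pinned buffer per (object, tensor name) so repeated saves do not re-allocate.
    cache = _pinned_cache.setdefault(key, {})
    out = {}
    for name, tensor in state_dict.items():
        if name not in cache:
            cache[name] = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
        cache[name].copy_(tensor)   # device-to-host copy into pinned memory
        out[name] = cache[name]
    return out  # safe to hand to an asynchronous writer thread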
@ -65,7 +66,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
self.coordinator = DistCoordinator()
|
self.coordinator = DistCoordinator()
|
||||||
self.logger = get_dist_logger()
|
self.logger = get_dist_logger()
|
||||||
|
|
||||||
def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
def save_unsharded_model(
|
||||||
|
self,
|
||||||
|
model: GeminiDDP,
|
||||||
|
checkpoint: str,
|
||||||
|
gather_dtensor: bool,
|
||||||
|
use_safetensors: bool,
|
||||||
|
use_async: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save sharded model to checkpoint but only on master process.
|
Save sharded model to checkpoint but only on master process.
|
||||||
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
||||||
@ -74,17 +82,39 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
|
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
|
||||||
state_dict = model.state_dict(only_rank_0=True)
|
state_dict = model.state_dict(only_rank_0=True)
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
if use_async:
|
||||||
|
from colossalai.utils.safetensors import save
|
||||||
|
|
||||||
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
|
if hash(model) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
self.pinned_state_dicts[hash(model)][k].copy_(v)
|
||||||
|
state_dict[k] = self.pinned_state_dicts[hash(model)][k]
|
||||||
|
writer = save(checkpoint, state_dict)
|
||||||
|
self.async_writers.append(writer)
|
||||||
|
else:
|
||||||
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
|
def load_unsharded_model(
|
||||||
|
self,
|
||||||
|
model: GeminiDDP,
|
||||||
|
checkpoint: str,
|
||||||
|
strict: bool = True,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load model from checkpoint with automatic unwrapping.
|
Load model from checkpoint with automatic unwrapping.
|
||||||
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
||||||
"""
|
"""
|
||||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||||
super().load_unsharded_model(model, checkpoint, strict=strict)
|
super().load_unsharded_model(
|
||||||
|
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||||
|
)
|
||||||
|
|
||||||
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
|
def save_unsharded_optimizer(
|
||||||
|
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save unsharded optimizer state dict to checkpoint.
|
Save unsharded optimizer state dict to checkpoint.
|
||||||
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
|
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
|
||||||
@ -94,15 +124,31 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
|
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
|
||||||
state_dict = optimizer.state_dict()
|
state_dict = optimizer.state_dict()
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
if use_async:
|
||||||
|
from colossalai.utils.safetensors import _flatten_optim_state_dict, save
|
||||||
|
|
||||||
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
|
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
|
||||||
|
if id(optimizer) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
|
||||||
|
for k, v in flatten_state_dict.items():
|
||||||
|
self.pinned_state_dicts[id(optimizer)][k].copy_(v)
|
||||||
|
flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
|
||||||
|
writer = save(checkpoint, flatten_state_dict, metadata)
|
||||||
|
self.async_writers.append(writer)
|
||||||
|
else:
|
||||||
|
save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
||||||
|
|
||||||
|
def load_unsharded_optimizer(
|
||||||
|
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Loading unsharded optimizer from checkpoint file.
|
Loading unsharded optimizer from checkpoint file.
|
||||||
For each process, only loading optimizer states of parameters it controls.
|
For each process, only loading optimizer states of parameters it controls.
|
||||||
"""
|
"""
|
||||||
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
|
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
|
||||||
super().load_unsharded_optimizer(optimizer, checkpoint)
|
super().load_unsharded_optimizer(
|
||||||
|
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
|
||||||
|
)
|
||||||
|
|
||||||
def save_sharded_model(
|
def save_sharded_model(
|
||||||
self,
|
self,
|
||||||
@ -112,6 +158,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
max_shard_size: int = 1024,
|
max_shard_size: int = 1024,
|
||||||
use_safetensors: bool = False,
|
use_safetensors: bool = False,
|
||||||
|
use_async: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save sharded model.
|
Save sharded model.
|
||||||
@ -124,20 +171,38 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
|
if use_async and self.coordinator.is_master():
|
||||||
|
if hash(model) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[hash(model)] = {}
|
||||||
|
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
|
||||||
|
else:
|
||||||
|
pinned_state_dicts = None
|
||||||
|
state_dict_shard = model.state_dict_shard(
|
||||||
|
max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
|
||||||
|
)
|
||||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||||
index_file = CheckpointIndexFile(checkpoint_path)
|
index_file = CheckpointIndexFile(checkpoint_path)
|
||||||
|
|
||||||
# Save shards of optimizer states.
|
# Save shards of optimizer states.
|
||||||
is_master = self.coordinator.is_master()
|
is_master = self.coordinator.is_master()
|
||||||
total_size = save_state_dict_shards(
|
if use_async:
|
||||||
sharded_state_dict=state_dict_shard,
|
total_size, writers = async_save_state_dict_shards(
|
||||||
checkpoint=checkpoint_path,
|
sharded_state_dict=state_dict_shard,
|
||||||
index_file=index_file,
|
checkpoint=checkpoint_path,
|
||||||
base_filename=weights_name,
|
index_file=index_file,
|
||||||
is_master=is_master,
|
base_filename=weights_name,
|
||||||
use_safetensors=use_safetensors,
|
is_master=is_master,
|
||||||
)
|
)
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint_path,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=weights_name,
|
||||||
|
is_master=is_master,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
)
|
||||||
|
|
||||||
# only save the index file on the master rank
|
# only save the index file on the master rank
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
@ -152,16 +217,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load_sharded_model(
|
def load_sharded_model(
|
||||||
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
|
self,
|
||||||
|
model: GeminiDDP,
|
||||||
|
checkpoint_index_file: Path,
|
||||||
|
strict: bool = False,
|
||||||
|
use_safetensors: bool = False,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Load shard model, load model from multiple files.
|
Load shard model, load model from multiple files.
|
||||||
"""
|
"""
|
||||||
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
||||||
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
|
return super().load_sharded_model(
|
||||||
|
model,
|
||||||
|
checkpoint_index_file,
|
||||||
|
strict,
|
||||||
|
use_safetensors,
|
||||||
|
load_sub_module=False,
|
||||||
|
low_cpu_mem_mode=low_cpu_mem_mode,
|
||||||
|
num_threads=num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
def save_sharded_optimizer(
|
def save_sharded_optimizer(
|
||||||
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
|
self,
|
||||||
|
optimizer: GeminiOptimizer,
|
||||||
|
checkpoint: Path,
|
||||||
|
gather_dtensor: bool,
|
||||||
|
prefix: str,
|
||||||
|
size_per_shard: int,
|
||||||
|
use_async: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save sharded optimizer state dict to checkpoint folder.
|
Save sharded optimizer state dict to checkpoint folder.
|
||||||
@ -176,7 +261,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Preparing file paths and index file.
|
# Preparing file paths and index file.
|
||||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
|
||||||
index_file = CheckpointIndexFile(checkpoint)
|
index_file = CheckpointIndexFile(checkpoint)
|
||||||
index_file.append_meta_data("param_groups", param_group_file)
|
index_file.append_meta_data("param_groups", param_group_file)
|
||||||
|
|
||||||
@ -187,17 +272,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
torch.save(param_groups, group_file_path)
|
torch.save(param_groups, group_file_path)
|
||||||
|
|
||||||
# States are broken into shards within max_shard_size.
|
# States are broken into shards within max_shard_size.
|
||||||
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
|
if use_async and self.coordinator.is_master():
|
||||||
|
if id(optimizer) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[id(optimizer)] = {}
|
||||||
|
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||||
|
else:
|
||||||
|
pinned_state_dicts = None
|
||||||
|
state_dict_shard = optimizer.state_shard(
|
||||||
|
prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
|
||||||
|
)
|
||||||
|
|
||||||
# Save shards of optimizer states.
|
# Save shards of optimizer states.
|
||||||
total_size = save_state_dict_shards(
|
if use_async:
|
||||||
sharded_state_dict=state_dict_shard,
|
total_size, writers = async_save_state_dict_shards(
|
||||||
checkpoint=checkpoint,
|
sharded_state_dict=state_dict_shard,
|
||||||
index_file=index_file,
|
checkpoint=checkpoint,
|
||||||
base_filename=states_name,
|
index_file=index_file,
|
||||||
is_master=self.coordinator.is_master(),
|
base_filename=states_name,
|
||||||
use_safetensors=False,
|
is_master=self.coordinator.is_master(),
|
||||||
)
|
state_preprocess=True,
|
||||||
|
)
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=states_name,
|
||||||
|
is_master=self.coordinator.is_master(),
|
||||||
|
use_safetensors=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Wrap up index file. Only save it on master rank.
|
# Wrap up index file. Only save it on master rank.
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
@ -210,7 +314,14 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
ranks=[0],
|
ranks=[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
|
def load_sharded_optimizer(
|
||||||
|
self,
|
||||||
|
optimizer: GeminiOptimizer,
|
||||||
|
checkpoint_index_file: Path,
|
||||||
|
prefix: str,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Loading sharded optimizer from checkpoint folder, with index file given.
|
Loading sharded optimizer from checkpoint folder, with index file given.
|
||||||
For each process, only loading optimizer states of parameters it controls.
|
For each process, only loading optimizer states of parameters it controls.
|
||||||
@ -238,11 +349,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
# Load optimizer states from shard files under checkpoint path.
|
# Load optimizer states from shard files under checkpoint path.
|
||||||
# For each file, only load the states managed by current process.
|
# For each file, only load the states managed by current process.
|
||||||
for shard_file in checkpoint_files:
|
for state_dict_shard in load_state_dict_shards(
|
||||||
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
checkpoint_files, True, False, low_cpu_mem_mode=low_cpu_mem_mode
|
||||||
|
):
|
||||||
|
if not low_cpu_mem_mode:
|
||||||
|
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
|
||||||
optimizer.load_param_states(state_dict_shard)
|
optimizer.load_param_states(state_dict_shard)
|
||||||
del state_dict_shard
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
optimizer.optimizer_loading_epilogue()
|
optimizer.optimizer_loading_epilogue()
|
||||||
|
|
||||||
@ -322,8 +434,9 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
|
enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
|
||||||
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
|
||||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
|
||||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
|
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
|
||||||
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
|
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
|
||||||
|
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -364,7 +477,6 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
enable_flash_attention: bool = False,
|
enable_flash_attention: bool = False,
|
||||||
enable_sequence_parallelism: bool = False,
|
enable_sequence_parallelism: bool = False,
|
||||||
enable_jit_fused: bool = False,
|
enable_jit_fused: bool = False,
|
||||||
enable_sequence_overlap: bool = False,
|
|
||||||
enable_async_reduce: bool = True,
|
enable_async_reduce: bool = True,
|
||||||
use_fp8: bool = False,
|
use_fp8: bool = False,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
@ -426,7 +538,6 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
self.enable_flash_attention = enable_flash_attention
|
self.enable_flash_attention = enable_flash_attention
|
||||||
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
|
self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
|
||||||
self.enable_jit_fused = enable_jit_fused
|
self.enable_jit_fused = enable_jit_fused
|
||||||
self.enable_sequence_overlap = enable_sequence_overlap
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
@ -453,7 +564,6 @@ class GeminiPlugin(DPPluginBase):
|
|||||||
enable_flash_attention=self.enable_flash_attention,
|
enable_flash_attention=self.enable_flash_attention,
|
||||||
enable_jit_fused=self.enable_jit_fused,
|
enable_jit_fused=self.enable_jit_fused,
|
||||||
enable_sequence_parallelism=self.enable_sequence_parallelism,
|
enable_sequence_parallelism=self.enable_sequence_parallelism,
|
||||||
enable_sequence_overlap=self.enable_sequence_overlap,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
@@ -10,6 +10,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, T
 import numpy as np
 import torch
 import torch.distributed as dist
+from peft import PeftModel
 from torch import Tensor, inf
 from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
@@ -25,10 +26,11 @@ from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOpt
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
-from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.quantization.fp8_hook import FP8Hook
@@ -216,25 +218,23 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

-    def unwrap(self):
-        module = super().unwrap()
-        if isinstance(module, DDP):
-            module = module.module
-        return module
+    def unwrap(self, unwrap_peft: bool = True):
+        model = self.module
+        if isinstance(model, DDP):
+            model = model.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = PeftUnwrapMixin(model)
+        return model

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
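With the reworked unwrap above, a PEFT-wrapped model is hidden behind PeftUnwrapMixin unless the caller opts out. A small hedged sketch, assuming `model` is the HybridParallelModule returned by booster.boost(...):

# Sketch only; no new APIs beyond the unwrap signature shown above.
base_model = model.unwrap()                   # LoRA adapters are wrapped by PeftUnwrapMixin
peft_model = model.unwrap(unwrap_peft=False)  # keep the raw PeftModel when adapters must be saved separately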
@@ -297,9 +297,10 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -317,7 +318,8 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -326,7 +328,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -343,7 +345,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -367,6 +369,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm

             # Clip the gradients to prevent exploding gradients.
             self._clip_grad_norm(total_norm)
@@ -480,6 +483,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
@@ -523,7 +529,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             max_norm=max_norm,
         )

-    def backward(self, loss: Tensor, *args, **kwargs):
+    def backward(self, loss: Tensor, inputs=None, retain_graph=False, **kwargs):
         r"""
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -540,7 +546,8 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, inputs=inputs, retain_graph=retain_graph, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -549,7 +556,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -565,7 +572,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -683,6 +690,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +720,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
@@ -779,7 +789,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         else:
             return

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, inputs=None, retain_graph=False):
         """
         Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.

@@ -795,7 +805,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, retain_graph)
+        super().backward(loss, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -804,7 +814,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # If gradient synchronization is is not required, return.
             return

-    def backward_by_grad(self, tensor, grad):
+    def backward_by_grad(self, tensor, grad, inputs: Tensor = None, retain_graph: bool = False):
         """
         Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.

@@ -820,7 +830,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             None
         """
         # Call the superclass backward_by_grad method to compute gradients.
-        super().backward_by_grad(tensor, grad)
+        super().backward_by_grad(tensor, grad, inputs=inputs, retain_graph=retain_graph)

         if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
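The optimizer wrappers above now accept inputs/retain_graph on backward and record the last clipped gradient norm. A hedged sketch of a training step using both, assuming `model`, `optimizer`, `criterion`, and `batch` come from an already boosted HybridParallelPlugin setup (only the backward and get_grad_norm signatures are taken from this diff, and get_grad_norm is shown here only on the naive wrapper):

# Sketch only.
outputs = model(**batch)
loss = criterion(outputs, batch["labels"])
optimizer.backward(loss)               # inputs=None, retain_graph=False by default
optimizer.step()
grad_norm = optimizer.get_grad_norm()  # norm recorded during the last clipping pass; None before any step
optimizer.zero_grad()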
@@ -950,7 +960,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -980,7 +989,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-        fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism
+        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
         inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
             It's advisable to not tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
@@ -1000,7 +1010,6 @@ class HybridParallelPlugin(PipelinePluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         sequence_parallelism_mode: str = None,
-        enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
@@ -1025,6 +1034,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -1043,6 +1053,9 @@ class HybridParallelPlugin(PipelinePluginBase):
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"

+        assert (
+            not pp_style == "zbv" or scheduler_nodes is not None
+        ), f"scheduler_nodes must not be None when using zero bubble pipeline."
         if enable_sequence_parallelism:
             self.sequence_parallelism_mode = (
                 sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
@@ -1104,29 +1117,39 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

         self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert (
                 self.zero_stage <= 1
            ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
+            if pp_style == "zbv":
+                self.logger.warning(
+                    """the enable_gradient_checkpointing function must set the use_reentrant to False, such as model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':False})"""
+                )
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
+                use_zbv=(pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )

             if pp_style == "interleaved":
                 assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
-                self.schedule = InterleavedSchedule(
+                self.scheduler = InterleavedSchedule(
                     stage_manager=self.stage_manager,
                     num_model_chunks=num_model_chunks,
                     num_microbatch=num_microbatches,
@@ -1136,13 +1159,21 @@ class HybridParallelPlugin(PipelinePluginBase):
                     fp8_communication=fp8_communication,
                 )
             elif pp_style == "1f1b":
-                self.schedule = OneForwardOneBackwardSchedule(
+                self.scheduler = OneForwardOneBackwardSchedule(
                     stage_manager=self.stage_manager,
                     num_microbatches=num_microbatches,
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                     fp8_communication=fp8_communication,
                 )
+            elif pp_style == "zbv":
+                self.scheduler = ZeroBubbleVPipeScheduler(
+                    stage_manager=self.stage_manager,
+                    schedule=scheduler_nodes,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    microbatch_size=microbatch_size,
+                )
             else:
                 raise NotImplementedError()
         if sequence_parallelism_mode == "ring_attn":
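The new "zbv" pipeline style requires a prebuilt V-schedule. A hedged sketch of wiring it up; how scheduler_nodes is constructed is not part of this diff, so it is assumed to come from the repository's zero-bubble scheduling utilities, and all sizes below are illustrative:

# Sketch only.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    pp_style="zbv",
    num_model_chunks=2,               # zbv requires exactly two chunks per the assert above
    scheduler_nodes=scheduler_nodes,  # assumed to be built elsewhere
    num_microbatches=8,
    precision="bf16",
)
booster = Booster(plugin=plugin)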
@@ -1161,6 +1192,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)

+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -1172,13 +1212,15 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             sequence_parallelism_mode=sequence_parallelism_mode,
-            enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
+            pg_mesh=self.pg_mesh,
+            sp_axis=self.sp_axis,
         )

         self.amp_config = dict(
             initial_scale=initial_scale,
             growth_factor=growth_factor,
@@ -1206,6 +1248,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
@@ -1255,7 +1298,6 @@ class HybridParallelPlugin(PipelinePluginBase):
-
         # Replace with distributed implementation if exists
         optimizer = cast_to_distributed(optimizer)

         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             self.logger.warning(
                 "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
@@ -1269,18 +1311,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1329,7 +1364,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             model,
             use_pipeline=self.enable_pipeline_parallelism,
             param_info=param_info,
-            dp_process_group=dp_group,
+            dp_process_group=self.mixed_dp_group,
             tp_process_group=self.tp_group,
             pp_process_group=self.pp_group,
             verbose=True,
@@ -1371,8 +1406,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx, model._wait_all_gather():
-            outputs = self.schedule.forward_backward_step(
+        with ctx, model._hook_context():
+            outputs = self.scheduler.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )

@@ -1458,7 +1493,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )

     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )

     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
@@ -1476,7 +1513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
-        assert self.pp_size == 1 and self.tp_size == 1
+        assert self.tp_size == 1
         self.lora_enabled = True
         self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
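With the relaxed assertion above, LoRA only requires tp_size == 1 (pipeline parallelism is no longer excluded). A hedged sketch of the usual ordering, enable LoRA before boosting; the LoraConfig values and variable names are illustrative only:

# Sketch, assuming `booster`, `model`, and `optimizer` already exist.
from peft import LoraConfig

lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = booster.enable_lora(model, lora_config=lora_config)  # must happen before booster.boost(...)
model, optimizer, *_ = booster.boost(model, optimizer)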
@@ -20,15 +20,18 @@ from torch.utils.data import DataLoader
 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
+    create_pinned_state_dict,
     get_optimizer_base_filenames,
     get_shard_filename,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
+    load_state_dict,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
     sharded_optimizer_loading_epilogue,
 )
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
@@ -100,17 +103,21 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+

 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.

         Args:
@@ -122,9 +129,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = save_nested(checkpoint, state_dict)
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
+        optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
         self,
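A hedged sketch of the asynchronous optimizer checkpoint path introduced above, calling the plugin's CheckpointIO directly; the paths are placeholders, and flushing the background writers is assumed to be handled by the CheckpointIO object as elsewhere in the codebase (not shown in this diff):

# Sketch, assuming `plugin` is a LowLevelZeroPlugin and `optimizer` has been boosted.
ckpt_io = plugin.get_checkpoint_io()

# Async save: states are staged into pinned host memory and written by a background writer.
ckpt_io.save_unsharded_optimizer(optimizer, "optimizer.safetensors", gather_dtensor=False, use_async=True)

# Reload later; a ".safetensors" suffix makes the loader take the flat-safetensors path.
ckpt_io.load_unsharded_optimizer(optimizer, "optimizer.safetensors", low_cpu_mem_mode=False, num_threads=4)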
@@ -133,6 +167,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -158,10 +193,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(
+            max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts, only_on_master=True
+        )

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)

@@ -181,7 +224,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = save_nested(checkpoint_file_path, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)
@@ -194,7 +244,14 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             ranks=[0],
         )

-    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """Load sharded optimizer with the given path to index file.

         Args:
@@ -219,8 +276,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
@@ -233,14 +289,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                     if padding_size > 0:
                         v = torch.nn.functional.pad(v, [0, padding_size])
                     v_list = v.split(v.numel() // self.coordinator.world_size)
-                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+                    state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+                    if low_cpu_mem_mode:
+                        state_dict[param_idx][k] = state_dict[param_idx][k].clone()
+
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)
         sharded_optimizer_loading_epilogue(optimizer)

-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_unsharded_model(model, checkpoint, strict)
+        super().load_unsharded_model(
+            model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
         model.update_master_params()

     def load_sharded_model(
@@ -250,16 +320,28 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model,
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )
         model.update_master_params()

-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

     def save_sharded_model(
         self,
@@ -269,26 +351,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
         return super().save_sharded_model(
-            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )

-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
-        if os.path.isfile(checkpoint):
-            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
-            return
-        from peft import PeftModel
-
-        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
-        assert isinstance(
-            peft_model, PeftModel
-        ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)


 class LowLevelZeroPlugin(DPPluginBase):
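The new low_cpu_mem_mode / num_threads knobs trade host memory for load speed: disabling low-CPU-memory mode pins the loaded tensors and copies them with multiple threads. A hedged sketch through the plugin's CheckpointIO; paths and the index-file name are placeholders, only the signatures come from the diff above:

# Sketch, assuming `plugin` and a boosted `model` exist.
ckpt_io = plugin.get_checkpoint_io()

# Sharded model checkpoint, written asynchronously by background writers.
ckpt_io.save_sharded_model(model, "ckpt_dir", gather_dtensor=False, prefix=None,
                           max_shard_size=1024, use_safetensors=True, use_async=True)

# Faster but more host-memory-hungry reload: skip low-CPU-memory mode and use 8 copy threads.
ckpt_io.load_sharded_model(model, "ckpt_dir/model.safetensors.index.json",
                           low_cpu_mem_mode=False, num_threads=8)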
@@ -325,6 +399,9 @@ class LowLevelZeroPlugin(DPPluginBase):
         overlap_communication (bool, optional): whether to overlap communication and computation. Defaults to True.
         cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
+        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+        extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
     """

     def __init__(
@@ -350,11 +427,16 @@ class LowLevelZeroPlugin(DPPluginBase):
         cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        extra_dp_size: int = 1,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
         assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
         assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
+        if extra_dp_size > 1:
+            assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
+            inner_dp_size = dist.get_world_size() // extra_dp_size
+            self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
         self.stage = stage
         self.precision = precision
         self.zero_optim_kwargs = dict(
@@ -375,6 +457,9 @@ class LowLevelZeroPlugin(DPPluginBase):
             overlap_allgather=overlap_allgather,
             fp8_communication=fp8_communication,
         )
+        if extra_dp_size > 1:
+            self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
+            self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
@@ -520,7 +605,7 @@ class LowLevelZeroPlugin(DPPluginBase):

         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
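The new extra_dp_size splits the world into an outer "extra" data-parallel axis and an inner ZeRO axis. A hedged sketch of a configuration for 8 ranks; all values are illustrative, and only extra_dp_size, use_fp8, and fp8_communication are the options introduced above:

# Sketch only.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# ZeRO shards optimizer states across the 4 inner ranks; gradients are additionally
# synchronized across the 2 "extra" data-parallel replicas.
plugin = LowLevelZeroPlugin(
    stage=2,
    precision="bf16",
    extra_dp_size=2,   # world_size must be divisible by this
    use_fp8=False,
    fp8_communication=False,
)
booster = Booster(plugin=plugin)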
@@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelPlugin,
     HybridParallelZeroOptimizer,
     get_param_info,
-    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -29,6 +28,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -64,13 +64,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        pg_param_list = {
-            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
-            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
-        }
+        if dp_process_group is moe_dp_group:
+            pg_param_list = {
+                dp_process_group: list(model.parameters()),
+            }
+        else:
+            pg_param_list = {
+                dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
+                moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
+            }

-        if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
-            raise ValueError("No parameters found in dp_process_group or moe_dp_group")
+        if len(pg_param_list[moe_dp_group]) == 0:
+            raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")

         super().__init__(
             model=model,
@@ -135,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
         num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
         microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -165,7 +169,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
+        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """

     def __init__(
@@ -182,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         enable_jit_fused: bool = False,
         enable_sequence_parallelism: bool = False,
         sequence_parallelism_mode: str = None,
-        enable_sequence_overlap: bool = False,
         parallel_output: bool = True,
         num_microbatches: Optional[int] = None,
         microbatch_size: Optional[int] = None,
@@ -207,6 +212,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -280,12 +286,17 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size)

         self.stage_manager = None
-        self.schedule = None
+        self.scheduler = None
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style in ["interleaved", "zbv"] or num_model_chunks == 1
+            ), "num_model_chunks must be 1 when using 1f1b"
+            assert (
+                pp_style in ["1f1b", "interleaved"] or num_model_chunks == 2
+            ), "num_model_chunks must be 2 when using zero bubble pipeline"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -295,14 +306,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
+                use_zbv=(pp_style == "zbv"),
             )

             if pp_style == "interleaved":
                 assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
|
||||||
self.schedule = InterleavedSchedule(
|
self.scheduler = InterleavedSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_model_chunks=num_model_chunks,
|
num_model_chunks=num_model_chunks,
|
||||||
num_microbatch=num_microbatches,
|
num_microbatch=num_microbatches,
|
||||||
@ -311,12 +323,21 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||||||
overlap_p2p=overlap_p2p,
|
overlap_p2p=overlap_p2p,
|
||||||
)
|
)
|
||||||
elif pp_style == "1f1b":
|
elif pp_style == "1f1b":
|
||||||
self.schedule = OneForwardOneBackwardSchedule(
|
self.scheduler = OneForwardOneBackwardSchedule(
|
||||||
stage_manager=self.stage_manager,
|
stage_manager=self.stage_manager,
|
||||||
num_microbatches=num_microbatches,
|
num_microbatches=num_microbatches,
|
||||||
microbatch_size=microbatch_size,
|
microbatch_size=microbatch_size,
|
||||||
enable_metadata_cache=enable_metadata_cache,
|
enable_metadata_cache=enable_metadata_cache,
|
||||||
)
|
)
|
||||||
|
elif pp_style == "zbv":
|
||||||
|
assert num_model_chunks > 1, "number of model chunks must be > 1 when using ZerbubbleV"
|
||||||
|
self.scheduler = ZeroBubbleVPipeScheduler(
|
||||||
|
schedule=scheduler_nodes,
|
||||||
|
stage_manager=self.stage_manager,
|
||||||
|
num_model_chunks=num_model_chunks,
|
||||||
|
num_microbatch=num_microbatches,
|
||||||
|
overlap_p2p=overlap_p2p,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
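Editor's note: the new "zbv" branch above only selects the ZeroBubbleVPipeScheduler; a minimal construction sketch is given below. All concrete sizes and the `scheduler_nodes=None` placeholder are assumptions for illustration, not values taken from this diff.

```python
# Hypothetical usage sketch: enabling the zero-bubble-V ("zbv") pipeline style.
from colossalai.booster.plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    ep_size=2,
    pp_style="zbv",
    num_model_chunks=2,    # the new assert requires exactly 2 chunks for zbv
    scheduler_nodes=None,  # placeholder; zbv expects a precomputed node schedule here
    num_microbatches=4,    # required whenever pp_size > 1
    zero_stage=1,
)
```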
@@ -329,6 +350,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            self.dp_size = dist.get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
 
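Editor's note: with `sequence_parallelism_mode="all_to_all"` the plugin now builds a single merged group spanning the MoE-DP, EP and SP axes and reduces gradients over it. A tiny sketch of the resulting group-size bookkeeping (placeholder sizes, plain arithmetic only):

```python
# With all_to_all sequence parallelism, gradients are synchronized across the
# merged moe_dp * ep * sp axes, so the effective data-parallel size becomes:
moe_dp_size, ep_size, sp_size = 2, 2, 2  # placeholder axis sizes
mixed_dp_size = moe_dp_size * ep_size * sp_size
assert mixed_dp_size == 8
```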
         self.use_fp8 = use_fp8
 
         self.shard_config = ShardConfig(
@@ -344,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             sequence_parallelism_mode=sequence_parallelism_mode,
-            enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
@@ -383,7 +411,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
+            self.mixed_dp_group,
+            self.pp_group,
+            self.tp_group,
+            self.sp_group,
+            self.ep_group,
+            self.moe_dp_group,
+            self.zero_stage,
         )
 
     def configure(
@@ -407,6 +441,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 and self.enable_sequence_parallelism
                 and self.sequence_parallelism_mode == "all_to_all"
             )
 
             if use_ddp:
                 self.logger.warning(
                     f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@@ -414,22 +449,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 )
                 self.ddp_config["find_unused_parameters"] = True
 
-                if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
                     raise ValueError(
-                        f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
+                        f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
                     )
 
-            # sync gradients across DP * SP ranks
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
-
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.mixed_dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
                 use_ddp=use_ddp,
@@ -438,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.ep_size > 1:
-                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
-                # but the optimizer is not aware of ep, so we need to update the optimizer
-                reinitialize_optimizer(optimizer, model)
-
             if self.zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
@@ -459,13 +483,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
                         tp_process_group=self.tp_group,
                     )
             else:
+                is_zero = True
                 if self.dp_size <= 1:
                     self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
@@ -476,9 +501,9 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
-                    use_pipeline=self.enable_pipeline_parallelism,
+                    use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                     param_info=param_info,
-                    dp_process_group=dp_group,
+                    dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     moe_dp_group=self.moe_dp_group,
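Editor's note: the usual entry point into the `configure()` hunk above is the Booster API. A minimal sketch under the assumption that `model` is a ready MoE model and `plugin` is the instance from the earlier sketch:

```python
# Minimal sketch: configure() runs when the model and optimizer are boosted.
import torch
from colossalai.booster import Booster

booster = Booster(plugin=plugin)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)
# configure() wraps the model in HybridParallelModule and the optimizer in the
# matching HybridParallelAMP/Naive or MoeHybridParallelZero optimizer shown above.
```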
@@ -1,14 +1,18 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import torch
 import torch.nn as nn
+from peft import PeftModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface.model import PeftUnwrapMixin
 from colossalai.logging import get_dist_logger
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.utils import get_current_device
@@ -24,35 +28,54 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load model from checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
+        super().load_unsharded_model(
+            model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(
+                model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
+            )
 
-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_unsharded_optimizer(optimizer, checkpoint)
+        super().load_unsharded_optimizer(
+            optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
 
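Editor's note: the new `low_cpu_mem_mode` / `num_threads` keywords are simply forwarded to `GeneralCheckpointIO`. A hedged usage sketch through the plugin's checkpoint IO (paths and variable names are placeholders; `boosted_model` must already be a boosted `ModelWrapper`):

```python
# Sketch: loading checkpoints with the new keyword arguments.
ckpt_io = plugin.get_checkpoint_io()
ckpt_io.load_model(
    boosted_model, "./ckpt",
    strict=True,
    low_cpu_mem_mode=False,  # allow the RAM cache to speed up loading
    num_threads=4,           # only consulted when low_cpu_mem_mode=False
)
ckpt_io.load_optimizer(optimizer, "./ckpt_optim", low_cpu_mem_mode=False, num_threads=4)
```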
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
@@ -69,6 +92,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
@@ -76,7 +100,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
             super().save_sharded_model(
-                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+                model.unwrap(),
+                checkpoint_path,
+                gather_dtensor,
+                prefix,
+                max_shard_size,
+                use_safetensors,
+                use_async=use_async,
             )
 
     def load_sharded_model(
@@ -86,12 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model from sharded checkpoint.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
-        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        super().load_sharded_model(
+            model.unwrap(),
+            checkpoint_index_file,
+            strict,
+            use_safetensors,
+            load_sub_module,
+            low_cpu_mem_mode=low_cpu_mem_mode,
+            num_threads=num_threads,
+        )
 
     def save_sharded_optimizer(
         self,
@@ -100,28 +140,39 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
 
     def load_sharded_optimizer(
         self,
         optimizer: Optimizer,
         index_file_path: str,
         prefix: Optional[str] = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load optimizer from sharded checkpoint.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
-        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
+        super().load_sharded_optimizer(
+            optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+        )
 
     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to checkpoint directory.
@@ -129,12 +180,18 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         from peft import PeftModel
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
         if self.coordinator.is_master():
-            peft_model = model.unwrap()
-            assert isinstance(
-                peft_model, PeftModel
-            ), "The model doesn't have lora adapters, please enable lora before saving."
-            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+            return peft_model.save_pretrained(
                checkpoint,
                safe_serialization=use_safetensors,
                state_dict=state_dict,
            )
 
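Editor's note: the rewritten `save_lora_as_pretrained` gathers the adapter state dict on CPU before writing. A toy, self-contained sketch of the `tree_map` offload step it uses (the dict contents below are made up):

```python
# tree_map walks any nested structure, moving tensor leaves to CPU and
# leaving non-tensor leaves untouched.
import torch
from torch.utils._pytree import tree_map

state = {"lora_A.weight": torch.randn(4, 4), "rank": 8}
cpu_state = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state)
assert cpu_state["rank"] == 8 and cpu_state["lora_A.weight"].device.type == "cpu"
```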
 
 class TorchDDPModel(ModelWrapper):
@@ -142,8 +199,11 @@ class TorchDDPModel(ModelWrapper):
         super().__init__(module)
         self.module = DDP(module, *args, **kwargs)
 
-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = PeftUnwrapMixin(model)
+        return model
 
 
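Editor's note: the new `unwrap_peft` flag decides whether a LoRA-wrapped module comes back behind the `PeftUnwrapMixin` view or as the `PeftModel` itself. A two-line sketch, assuming `wrapped` is a `TorchDDPModel` around a `peft.PeftModel`:

```python
base_view = wrapped.unwrap()                    # default: PeftUnwrapMixin view of the base model
peft_model = wrapped.unwrap(unwrap_peft=False)  # the PeftModel itself, as save_pretrained expects
```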
 class TorchDDPPlugin(DPPluginBase):
@@ -169,6 +229,7 @@ class TorchDDPPlugin(DPPluginBase):
         check_reduction (bool, optional): Whether to check reduction. Defaults to False.
         gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
         static_graph (bool, optional): Whether to use static graph. Defaults to False.
+        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """
 
     def __init__(
 
@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils
+from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.utils.safetensors import load_flat
 
 from .dp_plugin_base import DPPluginBase
 
@@ -41,20 +43,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         self.coordinator = DistCoordinator()
         self.logger = get_dist_logger()
 
-    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
         model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)
 
-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
-        checkpoint = utils.load_state_dict(checkpoint)
+        if checkpoint.endswith(".safetensors"):
+            checkpoint = load_flat(checkpoint, seperator=".")
+        else:
+            checkpoint = utils.load_state_dict(checkpoint)
 
         fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False)
+        start_index = 0
+        id2name = {}
+
+        def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
+            nonlocal start_index
+            start_num = len(id2name)
+            id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
+            end_num = len(id2name)
+            start_index += end_num - start_num
+
+        for g in full_optimizer_state["param_groups"]:
+            get_index_mapping(g)
+
+        new_state = {}
+        for key, value in checkpoint["state"].items():
+            new_state[id2name[int(key)]] = value
+        checkpoint["state"] = new_state
+        for g in checkpoint["param_groups"]:
+            new_group = []
+            for param_id in g["params"]:
+                new_group.append(id2name[param_id])
+            g["params"] = new_group
+
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
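Editor's note: `FSDP.full_optim_state_dict` keys optimizer states by parameter name, while the flattened checkpoint saved by the new code keys them by integer index, so the loader above has to rebuild an index-to-name table before scattering. A toy round trip of that remapping, with no FSDP involved and made-up names:

```python
# Toy illustration of the index <-> name remapping used above.
param_groups = [{"params": ["w1", "b1"]}, {"params": ["w2"]}]

# save side: name -> running index (what pack_group computes)
name2id, idx = {}, 0
for g in param_groups:
    for p in g["params"]:
        name2id.setdefault(p, idx)
        idx += 1

# load side: index -> name (what get_index_mapping computes), used to restore keys
id2name = {i: n for n, i in name2id.items()}
saved_state = {0: "state_of_w1", 2: "state_of_w2"}
restored = {id2name[k]: v for k, v in saved_state.items()}
assert restored == {"w1": "state_of_w1", "w2": "state_of_w2"}
```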
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
@@ -63,16 +99,67 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
             full_model_state = model.state_dict()
-        utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+        if self.coordinator.is_master():
+            if use_async:
+                from colossalai.utils.safetensors import save
 
+                if hash(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
+                for k, v in full_model_state.items():
+                    self.pinned_state_dicts[hash(model)][k].copy_(v)
+                    full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
+                writer = save(checkpoint, full_model_state)
+                self.async_writers.append(writer)
+            else:
+                utils.save_state_dict(
+                    full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors
+                )
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
         assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         fsdp_model = optimizer.unwrap_model()
 
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
-        utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
+
+        if self.coordinator.is_master():
+
+            # Save order indices instead of Tensors
+            name2id: Dict[str, int] = {}
+            start_index = 0
+
+            def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
+                nonlocal start_index
+                packed = {k: v for k, v in group.items() if k != "params"}
+                name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
+                packed["params"] = [name2id[p] for p in group["params"]]
+                start_index += len(packed["params"])
+                return packed
+
+            param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]]
+            full_optimizer_state["param_groups"] = param_groups
+            new_state = {}
+            for key, value in full_optimizer_state["state"].items():
+                new_state[name2id[key]] = value
+            full_optimizer_state["state"] = new_state
+
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
+                flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".")
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata)
+                self.async_writers.append(writer)
+            else:
+                utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
 
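Editor's note: the async branches above stage tensors into pinned CPU buffers so that the device-to-host copy and the disk write can overlap with training. The helpers used there (`create_pinned_state_dict`, the safetensors `save` writer) are ColossalAI internals; the sketch below reproduces only the generic pattern with plain PyTorch and the standard library:

```python
# Generic sketch of the pinned-buffer + background-writer pattern.
import threading
import torch

def async_save(tensor: torch.Tensor, path: str) -> threading.Thread:
    pinned = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
    pinned.copy_(tensor, non_blocking=True)        # D2H copy can overlap with compute
    if tensor.is_cuda:
        torch.cuda.current_stream().synchronize()  # ensure the copy finished before writing
    t = threading.Thread(target=torch.save, args=(pinned, path))
    t.start()
    return t  # caller must join() before reusing or mutating the buffer

# writer = async_save(param.detach(), "/tmp/param.pt"); ...; writer.join()
```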
     def save_sharded_model(
         self,
@@ -82,6 +169,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
@@ -97,20 +185,38 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         ):
             state_dict = model.unwrap().state_dict()
 
-        state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
+        if use_async and self.coordinator.is_master():
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[hash(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = utils.shard_model_checkpoint(
+            state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
+        )
 
         weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)
 
         # In general cases, is_master is set to True to get the right behavior.
-        total_size = utils.save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=self.coordinator.is_master(),
-            use_safetensors=use_safetensors,
-        )
+        if use_async:
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=self.coordinator.is_master(),
+            )
+            self.async_writers.extend(writers)
+        else:
+            total_size = utils.save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=use_safetensors,
+            )
 
         # only save the index file on the master rank
         if self.coordinator.is_master():
@@ -130,6 +236,8 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         Load model to checkpoint but only on master process.
@@ -147,14 +255,20 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
 
         fsdp_state_dict = {}
-        for shard_file in checkpoint_files:
-            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+        for state_dict in utils.load_state_dict_shards(checkpoint_files, False, use_safetensors):
+            fsdp_state_dict.update(state_dict)
 
         with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
             model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.
@@ -177,26 +291,66 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         )
 
         if self.coordinator.is_master():
+
+            # Save order indices instead of Tensors
+            name2id: Dict[str, int] = {}
+            start_index = 0
+
+            def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
+                nonlocal start_index
+                packed = {k: v for k, v in group.items() if k != "params"}
+                name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id})
+                packed["params"] = [name2id[p] for p in group["params"]]
+                start_index += len(packed["params"])
+                return packed
+
+            param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]]
+            fsdp_optim_state["param_groups"] = param_groups
+            new_state = {}
+            for key, value in fsdp_optim_state["state"].items():
+                new_state[name2id[key]] = value
+            fsdp_optim_state["state"] = new_state
+
             # Preparing file paths and index file.
-            states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
+            states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(
+                prefix, use_safetensors=use_async
+            )
             index_file = CheckpointIndexFile(checkpoint)
 
             index_file.append_meta_data("param_groups", param_group_file)
             group_file_path = os.path.join(checkpoint, param_group_file)
             utils.save_param_groups(fsdp_optim_state, group_file_path)
 
-            sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
+            if use_async:
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = {}
+                pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+            else:
+                pinned_state_dicts = None
+            sharded_state = utils.shard_optimizer_checkpoint(
+                fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts
+            )
             # Save shards of optimizer states.
             # In general cases, is_master is set to True to get the right behavior.
-            total_size = utils.save_state_dict_shards(
-                sharded_state_dict=sharded_state,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=states_name,
-                is_master=self.coordinator.is_master(),
-                use_safetensors=False,
-            )
+            if use_async:
+                total_size, writers = async_save_state_dict_shards(
+                    sharded_state_dict=sharded_state,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=self.coordinator.is_master(),
+                    state_preprocess=True,
+                )
+                self.async_writers.extend(writers)
+            else:
+                total_size = utils.save_state_dict_shards(
+                    sharded_state_dict=sharded_state,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=states_name,
+                    is_master=self.coordinator.is_master(),
+                    use_safetensors=False,
+                )
 
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
@@ -206,7 +360,14 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
                 f"index located at {save_index_file}."
             )
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        size_per_shard: int,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer to checkpoint but only on master process.
         """
@@ -227,12 +388,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         # Load param
         fsdp_optim_state = {}
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
-        for shard_file in checkpoint_files:
-            state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict_shard in utils.load_state_dict_shards(checkpoint_files, True, False):
             fsdp_optim_state.update(state_dict_shard)
 
         fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
 
+        fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False)
+        start_index = 0
+        id2name = {}
+
+        def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]:
+            nonlocal start_index
+            start_num = len(id2name)
+            id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name})
+            end_num = len(id2name)
+            start_index += end_num - start_num
+
+        for g in full_optimizer_state["param_groups"]:
+            get_index_mapping(g)
+
+        new_state = {}
+        for key, value in fsdp_optim_dict["state"].items():
+            new_state[id2name[int(key)]] = value
+        fsdp_optim_dict["state"] = new_state
+        for g in fsdp_optim_dict["param_groups"]:
+            new_group = []
+            for param_id in g["params"]:
+                new_group.append(id2name[param_id])
+            g["params"] = new_group
+
         with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
             fsdp_state = FSDP.optim_state_dict_to_load(
                 model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
@@ -252,9 +437,6 @@ class TorchFSDPModel(ModelWrapper):
         super().__init__(module)
         self.module = FSDP(module, *args, **kwargs)
 
-    def unwrap(self):
-        return self.module
-
 
 class FSDPOptimizerWrapper(OptimizerWrapper):
     def __init__(self, optimizer: Optimizer, model: nn.Module):
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -8,6 +8,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.interface import ModelWrapper
+from colossalai.logging import get_dist_logger
 
 from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
 
@@ -61,8 +62,35 @@ class CheckpointIO(ABC):
     # ======================================
     # Public methods
     # ======================================
+    def __init__(self):
+        super().__init__()
+        self.pinned_state_dicts: Dict[int, dict] = {}
+        self.async_writers = []
+
+    def _sync_io(self):
+        for writer in self.async_writers:
+            writer.synchronize()
+        self.async_writers.clear()
+
+    def _sync_d2h(self):
+        for writer in self.async_writers:
+            writer.sync_before_step()
+
+    def synchronize(self):
+        """This method must be called before updating the model weights."""
+        self._sync_d2h()
+
+    def __del__(self):
+        self._sync_d2h()
+        self._sync_io()
 
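Editor's note: the async writers registered above are flushed lazily on the next save or on `__del__`, but the docstring requires an explicit `synchronize()` before the next weight update. A hedged sketch of the intended call order (all names are placeholders):

```python
# Sketch: async save followed by an explicit sync before the optimizer step.
ckpt_io.save_model(model, "ckpt/", shard=True, use_async=True)
# ... forward/backward of the next step may proceed while shards are written ...
ckpt_io.synchronize()   # flush device-to-host copies before the weights change
optimizer.step()
```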
     def load_model(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        strict: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ) -> Union[nn.Module, ModelWrapper]:
         """
         Load model from checkpoint.
@@ -77,6 +105,8 @@ class CheckpointIO(ABC):
                 Distributed tensors cannot be loaded directly unless gathered offline via our CLI.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
         # since we only support loaded sharded and unsharded weight format
         # containing no distributed tensors, dtensor -> full tensor conversion
@@ -88,17 +118,25 @@ class CheckpointIO(ABC):
         origin_model = model
 
         if index_file_exists:
-            self.load_sharded_model(model, index_file_path, strict)
+            self.load_sharded_model(
+                model, index_file_path, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
             path = Path(checkpoint, SAFE_WEIGHTS_NAME)
             if path.is_file():
-                self.load_unsharded_model(model, str(path), strict)
+                self.load_unsharded_model(
+                    model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                )
             else:
                 path = Path(checkpoint, WEIGHTS_NAME)
                 if path.is_file():
-                    self.load_unsharded_model(model, str(path), strict)
+                    self.load_unsharded_model(
+                        model, str(path), strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )
                 else:
-                    self.load_unsharded_model(model, checkpoint, strict)
+                    self.load_unsharded_model(
+                        model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+                    )
 
         return origin_model
 
@@ -111,6 +149,7 @@ class CheckpointIO(ABC):
         prefix: str = None,
         size_per_shard: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint.
@@ -138,13 +177,30 @@ class CheckpointIO(ABC):
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
+        self._sync_io()
+        if use_async and not use_safetensors:
+            logger = get_dist_logger()
+            logger.warning(
+                "Async save is only supported when use_safetensors is set to True. "
+                "Setting use_safetensors to True for async save."
+            )
+            use_safetensors = True
 
         if shard:
-            self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
+            self.save_sharded_model(
+                model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async
+            )
         else:
-            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
 
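Editor's note: as the body above shows, `use_async=True` implies safetensors; leaving `use_safetensors=False` triggers a warning and the flag is flipped. A short usage sketch (directory and shard size are placeholders):

```python
# Sketch: sharded async save; use_safetensors is forced to True when use_async=True.
ckpt_io.save_model(
    model,
    "ckpt/",
    shard=True,
    size_per_shard=2048,   # MB per shard
    use_async=True,        # returns quickly; the shard writers run in the background
)
```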
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
+    def load_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        prefix: str = None,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load optimizer from checkpoint.
 
@@ -153,7 +209,8 @@ class CheckpointIO(ABC):
             checkpoint (str): checkpoint path. This value is made compatibility with the model checkpoints in the
             prefix (str, optional): A prefix added to parameter and buffer
                 names to compose the keys in state_dict. Defaults to None.
-            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
 
         index_file_exists, index_file_path = has_index_file(checkpoint)
@@ -164,9 +221,13 @@ class CheckpointIO(ABC):
 
         if index_file_exists:
             # the existence of index file means it is a sharded checkpoint
-            self.load_sharded_optimizer(optimizer, index_file_path, prefix)
+            self.load_sharded_optimizer(
+                optimizer, index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
         else:
-            self.load_unsharded_optimizer(optimizer, checkpoint)
+            self.load_unsharded_optimizer(
+                optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
+            )
 
     def save_optimizer(
         self,
@@ -176,6 +237,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -192,17 +254,20 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
 
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
 
     # ========================================================
     # Abstract methods for model loading/saving implementation
     # ========================================================
     @abstractmethod
-    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
+    def load_sharded_model(
+        self, model: nn.Module, index_file_path: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from sharded checkpoint.
 
@@ -211,10 +276,14 @@ class CheckpointIO(ABC):
             index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
             strict (bool): whether to strictly enforce that the param name in
                 the checkpoint match the keys returned by this module's.
+            low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
+            num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
         """
 
     @abstractmethod
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self, model: nn.Module, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load model from unsharded checkpoint.
 
@@ -223,6 +292,8 @@ class CheckpointIO(ABC):
            checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||||
strict (bool): whether to strictly enforce that the param name in
|
strict (bool): whether to strictly enforce that the param name in
|
||||||
the checkpoint match the keys returned by this module's.
|
the checkpoint match the keys returned by this module's.
|
||||||
|
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||||
|
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -234,6 +305,7 @@ class CheckpointIO(ABC):
|
|||||||
prefix: Optional[str],
|
prefix: Optional[str],
|
||||||
size_per_shard: int,
|
size_per_shard: int,
|
||||||
use_safetensors: bool,
|
use_safetensors: bool,
|
||||||
|
use_async: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save model to sharded checkpoint.
|
Save model to sharded checkpoint.
|
||||||
@ -248,7 +320,9 @@ class CheckpointIO(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
def save_unsharded_model(
|
||||||
|
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save model to unsharded checkpoint.
|
Save model to unsharded checkpoint.
|
||||||
|
|
||||||
@ -264,7 +338,14 @@ class CheckpointIO(ABC):
|
|||||||
# ========================================================
|
# ========================================================
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
|
def load_sharded_optimizer(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
index_file_path: str,
|
||||||
|
prefix: str,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load optimizer from sharded checkpoint.
|
Load optimizer from sharded checkpoint.
|
||||||
|
|
||||||
@ -272,21 +353,33 @@ class CheckpointIO(ABC):
|
|||||||
optimizer (Optimizer): optimizer to be loaded.
|
optimizer (Optimizer): optimizer to be loaded.
|
||||||
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file.
|
||||||
prefix (str): prefix for the optimizer checkpoint.
|
prefix (str): prefix for the optimizer checkpoint.
|
||||||
|
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||||
|
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
|
def load_unsharded_optimizer(
|
||||||
|
self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load optimizer from unsharded checkpoint.
|
Load optimizer from unsharded checkpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): optimizer to be loaded.
|
optimizer (Optimizer): optimizer to be loaded.
|
||||||
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary.
|
||||||
|
low_cpu_mem_mode (bool): whether to load the model in low cpu memory mode. If false, it will use RAM cache to accelerate loading. Default: True.
|
||||||
|
num_threads (int): number of threads to use when loading the model. Only useful when disabling low cpu mem mode. Default: 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_sharded_optimizer(
|
def save_sharded_optimizer(
|
||||||
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
checkpoint: Path,
|
||||||
|
gather_dtensor: bool,
|
||||||
|
prefix: str,
|
||||||
|
size_per_shard: int,
|
||||||
|
use_async: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save optimizer to sharded checkpoint.
|
Save optimizer to sharded checkpoint.
|
||||||
@ -300,7 +393,9 @@ class CheckpointIO(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
|
def save_unsharded_optimizer(
|
||||||
|
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save optimizer to unsharded checkpoint.
|
Save optimizer to unsharded checkpoint.
|
||||||
|
|
||||||
@ -342,7 +437,11 @@ class CheckpointIO(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_lora_as_pretrained(
|
def save_lora_as_pretrained(
|
||||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
self,
|
||||||
|
model: Union[nn.Module, ModelWrapper],
|
||||||
|
checkpoint: str,
|
||||||
|
use_safetensors: bool = False,
|
||||||
|
state_dict: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
||||||
@ -351,4 +450,5 @@ class CheckpointIO(ABC):
|
|||||||
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
||||||
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
||||||
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
||||||
|
state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
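Taken together, the signature changes above add two orthogonal knobs to the CheckpointIO interface: use_async for asynchronous, safetensors-backed saving, and low_cpu_mem_mode/num_threads for loading through reusable pinned buffers. Below is a minimal usage sketch, not taken from this diff; it assumes GeneralCheckpointIO is importable from colossalai.checkpoint_io and that the public save_model/load_model wrappers expose the same flags that save_optimizer/load_optimizer gain here, which this excerpt does not fully show.

    # Illustrative sketch only, under the assumptions stated above.
    import torch
    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = torch.nn.Linear(8, 8)
    ckpt_io = GeneralCheckpointIO()

    # Async save: the disk write is queued on ckpt_io.async_writers and
    # use_safetensors is forced to True (see the warning added in save_model above).
    ckpt_io.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024, use_async=True)
    ckpt_io._sync_io()  # the same synchronization helper save_model calls internally

    # Loading: disabling low_cpu_mem_mode trades host RAM (pinned buffers) for speed;
    # num_threads controls how many threads fill those buffers.
    ckpt_io.load_model(model, "ckpt_dir", low_cpu_mem_mode=False, num_threads=4)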
@@ -1,4 +1,3 @@
-import gc
 import logging
 import os
 from functools import reduce
@@ -8,16 +7,20 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer

+from colossalai.utils.safetensors import load_flat
+
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
+    async_move_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     is_safetensors_available,
     load_param_groups_into_optimizer,
-    load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
+    load_state_dict_shards,
     load_states_into_optimizer,
     save_config_file,
     save_param_groups,
@@ -36,21 +39,43 @@ class GeneralCheckpointIO(CheckpointIO):
     Checkpoint IO
     """

-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(
+        self,
+        model: nn.Module,
+        checkpoint: str,
+        strict: bool,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         state_dict = model.state_dict()

-        # TODO(FrankLeeeee): add support for gather_dtensor
-        if gather_dtensor:
-            pass
-
-        # save the checkpoint
-        save_state_dict(state_dict, checkpoint, use_safetensors)
+        if use_async:
+            from colossalai.utils.safetensors import move_and_save
+
+            if hash(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
+            self.async_writers.append(writer)
+        else:
+            # save the checkpoint
+            save_state_dict(state_dict, checkpoint, use_safetensors)

-    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: str,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
+    ):
         """
         Load sharded optimizer with the given path to index file.
         """
@@ -69,8 +94,9 @@ class GeneralCheckpointIO(CheckpointIO):

         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+        for state_dict in load_state_dict_shards(checkpoint_files, True, False, low_cpu_mem_mode):
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_states_into_optimizer(optimizer, state_dict, id_map)

         sharded_optimizer_loading_epilogue(optimizer)
@@ -82,6 +108,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -102,7 +129,7 @@ class GeneralCheckpointIO(CheckpointIO):
         sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)

         # Store the information of param groups to param_group_file.
@@ -112,14 +139,28 @@ class GeneralCheckpointIO(CheckpointIO):

         # Save shards of optimizer states.
         # In general cases, is_master is set to True to get the right behavior.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=sharded_state,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=states_name,
-            is_master=True,
-            use_safetensors=False,
-        )
+        if use_async:
+            pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None)
+            total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
+                sharded_state_dict=sharded_state,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=True,
+                pinned_state_dict=pinned_state_dict,
+                state_preprocess=True,
+            )
+            self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=sharded_state,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=True,
+                use_safetensors=False,
+            )

         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)
@@ -130,8 +171,15 @@ class GeneralCheckpointIO(CheckpointIO):
             f"index located at {save_index_file}."
         )

-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = load_state_dict(checkpoint)
+    def load_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
+        if checkpoint.endswith(".safetensors"):
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
         optimizer.load_state_dict(checkpoint)

     def save_unsharded_optimizer(
@@ -139,9 +187,25 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
-        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
+        state_dict = optimizer.state_dict()
+        if use_async:
+            from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save
+
+            flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+            writer = move_and_save(
+                path=checkpoint,
+                state_dict=flatten_state_dict,
+                state_dict_pinned=self.pinned_state_dicts[id(optimizer)],
+                metadata=metadata,
+            )
+            self.async_writers.append(writer)
+        else:
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)

     def save_sharded_model(
         self,
@@ -151,6 +215,7 @@ class GeneralCheckpointIO(CheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         implement this method as it can be supported by Huggingface model,
@@ -168,16 +233,29 @@ class GeneralCheckpointIO(CheckpointIO):
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)

-        # Save shards of optimizer states.
-        # In general cases, is_master is set to True to get the right behavior.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint_path,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-        )
+        if use_async:
+            pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
+            total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                pinned_state_dict=pinned_state_dict,
+            )
+            self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
+            self.async_writers.extend(writers)
+        else:
+            # Save shards of optimizer states.
+            # In general cases, is_master is set to True to get the right behavior.
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=True,
+                use_safetensors=use_safetensors,
+            )

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
@@ -195,6 +273,8 @@ class GeneralCheckpointIO(CheckpointIO):
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
+        low_cpu_mem_mode: bool = True,
+        num_threads: int = 1,
     ):
         """
         load shard model, load model from multiple files
@@ -211,11 +291,10 @@ class GeneralCheckpointIO(CheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
         missing_keys = []

-        for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
+        for state_dict in load_state_dict_shards(checkpoint_files, False, use_safetensors, low_cpu_mem_mode):
+            if not low_cpu_mem_mode:
+                state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
             load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
-            del state_dict
-            gc.collect()

         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
@@ -229,5 +308,7 @@ class GeneralCheckpointIO(CheckpointIO):
                 )
             )

-    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+    def save_lora_as_pretrained(
+        self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None
+    ) -> None:
         raise NotImplementedError
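The move_and_save and async_move_save_state_dict_shards helpers used above come from colossalai.utils.safetensors and the checkpoint_io utils module and are not shown in this diff. The sketch below only illustrates the pattern they rely on (stage tensors into reusable page-locked CPU buffers, then hand the slow disk write to a background thread); all names in it are hypothetical, it is not the library implementation.

    # Hypothetical illustration of the pinned-buffer async save pattern.
    import threading
    import torch
    from safetensors.torch import save_file

    def async_pinned_save(path: str, state_dict: dict, pinned: dict) -> threading.Thread:
        # Fast, synchronous device-to-host copy into buffers that are allocated once
        # and reused across checkpoints (the role of the pinned_state_dicts cache above).
        for name, tensor in state_dict.items():
            if name not in pinned:
                pinned[name] = torch.empty_like(tensor, device="cpu", pin_memory=True)
            pinned[name].copy_(tensor)
        # The slow file write runs off the training critical path.
        writer = threading.Thread(target=save_file, args=(dict(pinned), path))
        writer.start()
        return writer

    pinned_buffers = {}
    writer = async_pinned_save("model.safetensors", torch.nn.Linear(4, 4).state_dict(), pinned_buffers)
    writer.join()  # synchronize before reading the file back or exiting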
@ -1,6 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
@ -10,7 +11,9 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||||
@ -20,20 +23,24 @@ from colossalai.tensor.padded_tensor import (
|
|||||||
to_padded_tensor,
|
to_padded_tensor,
|
||||||
to_unpadded_tensor,
|
to_unpadded_tensor,
|
||||||
)
|
)
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
||||||
|
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
|
||||||
|
|
||||||
from .general_checkpoint_io import GeneralCheckpointIO
|
from .general_checkpoint_io import GeneralCheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
from .utils import (
|
from .utils import (
|
||||||
StateDictSharder,
|
StateDictSharder,
|
||||||
|
async_save_state_dict_shards,
|
||||||
|
create_pinned_state_dict,
|
||||||
gather_distributed_param,
|
gather_distributed_param,
|
||||||
|
gather_state_dict_fast,
|
||||||
|
get_lora_state_dict,
|
||||||
get_model_base_filenames,
|
get_model_base_filenames,
|
||||||
get_optimizer_base_filenames,
|
get_optimizer_base_filenames,
|
||||||
is_safetensors_available,
|
is_safetensors_available,
|
||||||
load_shard_state_dict,
|
load_shard_state_dict,
|
||||||
load_state_dict,
|
load_state_dict,
|
||||||
load_state_dict_into_model,
|
load_state_dict_into_model,
|
||||||
load_states_into_optimizer,
|
|
||||||
save_config_file,
|
save_config_file,
|
||||||
save_param_groups,
|
save_param_groups,
|
||||||
save_state_dict,
|
save_state_dict,
|
||||||
@ -66,6 +73,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
dp_group: ProcessGroup,
|
dp_group: ProcessGroup,
|
||||||
pp_group: ProcessGroup,
|
pp_group: ProcessGroup,
|
||||||
tp_group: ProcessGroup,
|
tp_group: ProcessGroup,
|
||||||
|
sp_group: ProcessGroup,
|
||||||
zero_stage: int,
|
zero_stage: int,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -73,9 +81,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
self.global_dp_group = dp_group
|
self.global_dp_group = dp_group
|
||||||
self.pp_group = pp_group
|
self.pp_group = pp_group
|
||||||
self.tp_group = tp_group
|
self.tp_group = tp_group
|
||||||
|
self.sp_group = sp_group
|
||||||
self.dp_rank = dist.get_rank(self.global_dp_group)
|
self.dp_rank = dist.get_rank(self.global_dp_group)
|
||||||
self.tp_rank = dist.get_rank(self.tp_group)
|
self.tp_rank = dist.get_rank(self.tp_group)
|
||||||
self.pp_rank = dist.get_rank(self.pp_group)
|
self.pp_rank = dist.get_rank(self.pp_group)
|
||||||
|
self.sp_rank = dist.get_rank(self.sp_group)
|
||||||
self.global_dp_size = dist.get_world_size(dp_group)
|
self.global_dp_size = dist.get_world_size(dp_group)
|
||||||
self.pp_size = dist.get_world_size(pp_group)
|
self.pp_size = dist.get_world_size(pp_group)
|
||||||
self.tp_size = dist.get_world_size(tp_group)
|
self.tp_size = dist.get_world_size(tp_group)
|
||||||
@ -85,7 +95,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _model_sharder(
|
def _model_sharder(
|
||||||
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
|
model: nn.Module,
|
||||||
|
prefix: str = "",
|
||||||
|
keep_vars: bool = False,
|
||||||
|
size_per_shard: int = 1024,
|
||||||
|
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||||
# An internel method that breaks state_dict of model into shards within limited size.
|
# An internel method that breaks state_dict of model into shards within limited size.
|
||||||
|
|
||||||
@ -96,17 +110,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
# Gather tensor pieces when using tensor parallel.
|
# Gather tensor pieces when using tensor parallel.
|
||||||
if is_padded_tensor(param):
|
|
||||||
param = to_unpadded_tensor(param)
|
|
||||||
param_ = gather_distributed_param(param, keep_vars=False)
|
param_ = gather_distributed_param(param, keep_vars=False)
|
||||||
|
if is_padded_tensor(param_):
|
||||||
|
param_ = to_unpadded_tensor(param_)
|
||||||
|
if pinned_state_dicts is not None:
|
||||||
|
if (prefix + name) not in pinned_state_dicts:
|
||||||
|
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
|
||||||
|
pinned_state_dicts[prefix + name].copy_(param_)
|
||||||
|
param_ = pinned_state_dicts[prefix + name]
|
||||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||||
if block is not None:
|
if block is not None:
|
||||||
yield block, block_size
|
yield block, block_size
|
||||||
|
|
||||||
# Save buffers.
|
# Save buffers.
|
||||||
|
non_persist_buffers_set = get_non_persistent_buffers_set(model)
|
||||||
for name, buf in model.named_buffers():
|
for name, buf in model.named_buffers():
|
||||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
if buf is not None and name not in non_persist_buffers_set:
|
||||||
buffer = buf if keep_vars else buf.detach()
|
buffer = buf if keep_vars else buf.detach()
|
||||||
|
if pinned_state_dicts is not None:
|
||||||
|
if (prefix + name) not in pinned_state_dicts:
|
||||||
|
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
|
||||||
|
pinned_state_dicts[prefix + name].copy_(buffer)
|
||||||
|
buffer = pinned_state_dicts[prefix + name]
|
||||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||||
if block is not None:
|
if block is not None:
|
||||||
yield block, block_size
|
yield block, block_size
|
||||||
@ -118,6 +143,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
is not torch.nn.Module.get_extra_state
|
is not torch.nn.Module.get_extra_state
|
||||||
):
|
):
|
||||||
extra_state = model.get_extra_state()
|
extra_state = model.get_extra_state()
|
||||||
|
if pinned_state_dicts is not None:
|
||||||
|
if extra_state_key not in pinned_state_dicts:
|
||||||
|
pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu")
|
||||||
|
pinned_state_dicts[extra_state_key].copy_(extra_state)
|
||||||
|
extra_state = pinned_state_dicts[extra_state_key]
|
||||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||||
if block is not None:
|
if block is not None:
|
||||||
yield block, block_size
|
yield block, block_size
|
||||||
@ -132,6 +162,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
dp_group: ProcessGroup,
|
dp_group: ProcessGroup,
|
||||||
tp_group: ProcessGroup,
|
tp_group: ProcessGroup,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
|
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
|
||||||
):
|
):
|
||||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||||
|
|
||||||
@ -149,6 +180,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
working_param = param
|
working_param = param
|
||||||
|
|
||||||
param_id = param_info["param2id"][id(working_param)]
|
param_id = param_info["param2id"][id(working_param)]
|
||||||
|
if pinned_state_dicts is not None:
|
||||||
|
if param_id not in pinned_state_dicts:
|
||||||
|
pinned_state_dicts[param_id] = {}
|
||||||
original_shape = param_info["param2shape"][id(working_param)]
|
original_shape = param_info["param2shape"][id(working_param)]
|
||||||
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||||
state,
|
state,
|
||||||
@ -158,6 +192,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
tp_group=tp_group,
|
tp_group=tp_group,
|
||||||
use_zero=use_zero,
|
use_zero=use_zero,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
|
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||||
@ -175,6 +210,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
use_safetensors: bool = False,
|
use_safetensors: bool = False,
|
||||||
|
use_async: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Save sharded model checkpoint under the given checkpointing path.
|
Save sharded model checkpoint under the given checkpointing path.
|
||||||
@ -192,6 +228,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
prefix (str, optional): Perfix of file to save. Defaults to None.
|
prefix (str, optional): Perfix of file to save. Defaults to None.
|
||||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||||
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||||
@ -210,21 +247,40 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
# Then collect the sharded parameters & buffers along tp_group.
|
# Then collect the sharded parameters & buffers along tp_group.
|
||||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||||
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
control_saving = self.tp_rank == 0 and self.sp_rank == 0
|
||||||
|
if control_saving and use_async:
|
||||||
|
if hash(model) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[hash(model)] = {}
|
||||||
|
pinned_state_dicts = self.pinned_state_dicts[hash(model)]
|
||||||
|
else:
|
||||||
|
pinned_state_dicts = None
|
||||||
|
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
|
||||||
|
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
|
||||||
|
)
|
||||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||||
index_file = CheckpointIndexFile(checkpoint)
|
index_file = CheckpointIndexFile(checkpoint)
|
||||||
control_saving = self.tp_rank == 0
|
|
||||||
|
|
||||||
if self.pp_size == 1:
|
if self.pp_size == 1:
|
||||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||||
total_size = save_state_dict_shards(
|
if use_async:
|
||||||
sharded_state_dict=state_dict_shard,
|
total_size, writers = async_save_state_dict_shards(
|
||||||
checkpoint=checkpoint,
|
sharded_state_dict=state_dict_shard,
|
||||||
index_file=index_file,
|
checkpoint=checkpoint,
|
||||||
base_filename=weights_name,
|
index_file=index_file,
|
||||||
is_master=control_saving,
|
base_filename=weights_name,
|
||||||
use_safetensors=use_safetensors,
|
is_master=control_saving,
|
||||||
)
|
state_preprocess=False,
|
||||||
|
)
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=weights_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
)
|
||||||
if control_saving:
|
if control_saving:
|
||||||
index_file.append_meta_data("total_size", total_size)
|
index_file.append_meta_data("total_size", total_size)
|
||||||
index_file.write_index_file(save_index_file)
|
index_file.write_index_file(save_index_file)
|
||||||
@ -249,16 +305,26 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
|
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
|
||||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||||
|
if use_async:
|
||||||
total_size = save_state_dict_shards(
|
total_size, writers = async_save_state_dict_shards(
|
||||||
sharded_state_dict=state_dict_shard,
|
sharded_state_dict=state_dict_shard,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
index_file=index_file,
|
index_file=index_file,
|
||||||
base_filename=weights_name,
|
base_filename=weights_name,
|
||||||
is_master=control_saving,
|
is_master=control_saving,
|
||||||
use_safetensors=use_safetensors,
|
state_preprocess=False,
|
||||||
use_pp_format=True,
|
)
|
||||||
)
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=weights_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
use_safetensors=use_safetensors,
|
||||||
|
use_pp_format=True,
|
||||||
|
)
|
||||||
|
|
||||||
if control_saving:
|
if control_saving:
|
||||||
assert (
|
assert (
|
||||||
@ -292,7 +358,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
f"index located at {final_index_file_path}."
|
f"index located at {final_index_file_path}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
|
def load_sharded_model(
|
||||||
|
self,
|
||||||
|
model: ModelWrapper,
|
||||||
|
checkpoint_index_file: Path,
|
||||||
|
strict: bool = False,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load sharded model with the given path to index file of checkpoint folder.
|
Load sharded model with the given path to index file of checkpoint folder.
|
||||||
|
|
||||||
@ -340,6 +413,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
file_path = os.path.join(ckpt_root_path, filename)
|
file_path = os.path.join(ckpt_root_path, filename)
|
||||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||||
|
if not low_cpu_mem_mode:
|
||||||
|
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||||
|
|
||||||
load_state_dict_into_model(
|
load_state_dict_into_model(
|
||||||
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
|
||||||
@ -351,9 +426,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
_load(name)
|
_load(name)
|
||||||
|
|
||||||
# Load buffers.
|
# Load buffers.
|
||||||
non_persistent_buffers = set()
|
non_persistent_buffers = get_non_persistent_buffers_set(model)
|
||||||
for n, m in model.named_modules():
|
|
||||||
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
|
|
||||||
for name, buf in model.named_buffers():
|
for name, buf in model.named_buffers():
|
||||||
if buf is not None and name not in non_persistent_buffers:
|
if buf is not None and name not in non_persistent_buffers:
|
||||||
_load(name)
|
_load(name)
|
||||||
@ -400,6 +473,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
gather_dtensor: bool = True,
|
gather_dtensor: bool = True,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
size_per_shard: int = 1024,
|
size_per_shard: int = 1024,
|
||||||
|
use_async: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||||
@ -431,26 +505,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
|
|
||||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||||
|
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
|
||||||
|
|
||||||
|
if use_async and control_saving:
|
||||||
|
if id(optimizer) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[id(optimizer)] = {}
|
||||||
|
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
|
||||||
|
else:
|
||||||
|
pinned_state_dicts = None
|
||||||
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
|
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
|
||||||
optimizer,
|
optimizer,
|
||||||
use_zero=self.use_zero,
|
use_zero=self.use_zero,
|
||||||
dp_group=self.global_dp_group,
|
dp_group=self.global_dp_group,
|
||||||
tp_group=self.tp_group,
|
tp_group=self.tp_group,
|
||||||
size_per_shard=size_per_shard,
|
size_per_shard=size_per_shard,
|
||||||
|
pinned_state_dicts=pinned_state_dicts,
|
||||||
)
|
)
|
||||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
|
||||||
index_file = CheckpointIndexFile(checkpoint)
|
index_file = CheckpointIndexFile(checkpoint)
|
||||||
control_saving = self.dp_rank == 0 and self.tp_rank == 0
|
|
||||||
|
|
||||||
if self.pp_size == 1:
|
if self.pp_size == 1:
|
||||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||||
total_size = save_state_dict_shards(
|
if use_async:
|
||||||
sharded_state_dict=state_dict_shard,
|
total_size, writers = async_save_state_dict_shards(
|
||||||
checkpoint=checkpoint,
|
sharded_state_dict=state_dict_shard,
|
||||||
index_file=index_file,
|
checkpoint=checkpoint,
|
||||||
base_filename=states_name,
|
index_file=index_file,
|
||||||
is_master=control_saving,
|
base_filename=states_name,
|
||||||
)
|
is_master=control_saving,
|
||||||
|
use_pp_format=True,
|
||||||
|
state_preprocess=True,
|
||||||
|
)
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=states_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
)
|
||||||
|
|
||||||
if control_saving:
|
if control_saving:
|
||||||
# Store param groups.
|
# Store param groups.
|
||||||
@ -481,18 +575,33 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
|
if not use_async:
|
||||||
|
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
|
||||||
|
else:
|
||||||
|
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
|
||||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
|
||||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||||
|
|
||||||
total_size = save_state_dict_shards(
|
if use_async:
|
||||||
sharded_state_dict=state_dict_shard,
|
total_size, writers = async_save_state_dict_shards(
|
||||||
checkpoint=checkpoint,
|
sharded_state_dict=state_dict_shard,
|
||||||
index_file=index_file,
|
checkpoint=checkpoint,
|
||||||
base_filename=states_name,
|
index_file=index_file,
|
||||||
is_master=control_saving,
|
base_filename=states_name,
|
||||||
use_pp_format=True,
|
is_master=control_saving,
|
||||||
)
|
use_pp_format=True,
|
||||||
|
state_preprocess=True,
|
||||||
|
)
|
||||||
|
self.async_writers.extend(writers)
|
||||||
|
else:
|
||||||
|
total_size = save_state_dict_shards(
|
||||||
|
sharded_state_dict=state_dict_shard,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
index_file=index_file,
|
||||||
|
base_filename=states_name,
|
||||||
|
is_master=control_saving,
|
||||||
|
use_pp_format=True,
|
||||||
|
)
|
||||||
|
|
||||||
if control_saving:
|
if control_saving:
|
||||||
assert (
|
assert (
|
||||||
@ -535,7 +644,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
f"index located at {final_index_file_path}."
|
f"index located at {final_index_file_path}."
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
def load_sharded_optimizer(
|
||||||
|
self,
|
||||||
|
optimizer: OptimizerWrapper,
|
||||||
|
checkpoint_index_file: str,
|
||||||
|
prefix: str = "",
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||||
|
|
||||||
@ -605,28 +721,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
file_path = os.path.join(ckpt_root_path, filename)
|
file_path = os.path.join(ckpt_root_path, filename)
|
||||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
if file_path.endswith(".safetensors"):
|
||||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
state_dict = load_flat(file_path)
|
||||||
|
else:
|
||||||
|
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||||
|
if not low_cpu_mem_mode:
|
||||||
|
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
|
||||||
|
self.load_states_into_optimizer(optimizer, state_dict, id_map)
|
||||||
loaded_file.add(filename)
|
loaded_file.add(filename)
|
||||||
|
|
||||||
# Then shard the loaded optimizer states if using tp/zero.
|
|
||||||
for param, state in optimizer.optim.state.items():
|
|
||||||
device = param.device
|
|
||||||
if master_to_working_map is not None:
|
|
||||||
working_param = master_to_working_map[id(param)]
|
|
||||||
else:
|
|
||||||
working_param = param
|
|
||||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
|
||||||
sharded_state = self.shard_from_complete_optimizer_state(
|
|
||||||
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
|
|
||||||
)
|
|
||||||
optimizer.optim.state[param] = sharded_state
|
|
||||||
|
|
||||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||||
if self.verbose and self.coordinator.is_master():
|
if self.verbose and self.coordinator.is_master():
|
||||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||||
|
|
||||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
|
||||||
|
state_dict = {int(k): v for k, v in state_dict.items()}
|
||||||
|
new_states = defaultdict(dict)
|
||||||
|
master_to_working_map = optimizer.get_master_to_working_map()
|
||||||
|
for k, state in state_dict.items():
|
||||||
|
if k in id_map:
|
||||||
|
param = id_map[k]
|
||||||
|
device = param.device
|
||||||
|
dtype = param.dtype
|
||||||
|
if master_to_working_map is not None:
|
||||||
|
working_param = master_to_working_map[id(param)]
|
||||||
|
else:
|
||||||
|
working_param = param
|
||||||
|
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||||
|
new_states[param] = self.shard_from_complete_optimizer_state(
|
||||||
|
state,
|
||||||
|
current_shape=working_param.shape,
|
||||||
|
original_shape=original_shape,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
optimizer.optim.state.update(new_states)
|
||||||
|
|
||||||
|
def save_unsharded_model(
|
||||||
|
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save model state dict to a single file with given checkpointing path.
|
Save model state dict to a single file with given checkpointing path.
|
||||||
|
|
||||||
@ -635,6 +769,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
|
checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
|
||||||
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
|
gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
|
||||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||||
|
use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
|
||||||
"""
|
"""
|
||||||
if self.coordinator.is_master():
|
if self.coordinator.is_master():
|
||||||
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
|
||||||
@ -651,7 +786,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
if self.pp_size == 1:
|
if self.pp_size == 1:
|
||||||
# When pipeline is not used, let master rank directly save the collected state_dict.
|
# When pipeline is not used, let master rank directly save the collected state_dict.
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
save_state_dict(state_dict, checkpoint, use_safetensors)
|
if use_async:
|
||||||
|
from colossalai.utils.safetensors import save
|
||||||
|
|
||||||
|
if hash(model) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
self.pinned_state_dicts[hash(model)][name].copy_(param)
|
||||||
|
state_dict[name] = self.pinned_state_dicts[hash(model)][name]
|
||||||
|
writer = save(path=checkpoint, state_dict=state_dict)
|
||||||
|
self.async_writers.append(writer)
|
||||||
|
else:
|
||||||
|
save_state_dict(state_dict, checkpoint, use_safetensors)
|
||||||
else:
|
else:
|
||||||
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
|
||||||
state_dict_list = [None for _ in range(self.pp_size)]
|
state_dict_list = [None for _ in range(self.pp_size)]
|
||||||
@ -662,9 +808,27 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
complete_state_dict = dict()
|
complete_state_dict = dict()
|
||||||
for _state_dict in state_dict_list:
|
for _state_dict in state_dict_list:
|
||||||
complete_state_dict.update(_state_dict)
|
complete_state_dict.update(_state_dict)
|
||||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
if use_async:
|
||||||
|
from colossalai.utils.safetensors import save
|
||||||
|
|
||||||
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False):
|
if hash(model) not in self.pinned_state_dicts:
|
||||||
|
self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
|
||||||
|
for name, param in complete_state_dict.items():
|
||||||
|
self.pinned_state_dicts[hash(model)][name].copy_(param)
|
||||||
|
complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
|
||||||
|
writer = save(path=checkpoint, state_dict=complete_state_dict)
|
||||||
|
self.async_writers.append(writer)
|
||||||
|
else:
|
||||||
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
|
def load_unsharded_model(
|
||||||
|
self,
|
||||||
|
model: ModelWrapper,
|
||||||
|
checkpoint: str,
|
||||||
|
strict: bool = False,
|
||||||
|
low_cpu_mem_mode: bool = True,
|
||||||
|
num_threads: int = 1,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load model from a single file with the given path of checkpoint.
|
Load model from a single file with the given path of checkpoint.
|
||||||
|
|
||||||
@@ -687,12 +851,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
         # model.load_state_dict can be directly called.
         state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
         model.load_state_dict(state_dict, strict=strict)

         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.

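load_unsharded_model gains low_cpu_mem_mode and num_threads: with low_cpu_mem_mode=False, the freshly loaded state dict is first re-materialized in pinned host memory, using several threads for the copies, which speeds up the subsequent host-to-device transfer. The sketch below shows one way such a helper could be written; pin_state_dict is a hypothetical stand-in, not the signature or implementation of ColossalAI's create_pinned_state_dict.

# Hypothetical sketch of a multi-threaded "pin the whole state dict" helper.
# The real API used in this diff is create_pinned_state_dict(state_dict, empty=..., num_threads=...);
# its implementation is not shown here, so this is only an illustration of the idea.
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Tuple

import torch


def pin_state_dict(state_dict: Dict[str, torch.Tensor], num_threads: int = 1) -> Dict[str, torch.Tensor]:
    pin = torch.cuda.is_available()  # pinning needs a CUDA build; otherwise use pageable memory

    def pin_one(item: Tuple[str, torch.Tensor]) -> Tuple[str, torch.Tensor]:
        name, tensor = item
        pinned = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=pin)
        pinned.copy_(tensor)  # copy the loaded value into page-locked host memory
        return name, pinned

    # Per-tensor copies are independent, so they parallelize across a thread pool.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return dict(pool.map(pin_one, state_dict.items()))


if __name__ == "__main__":
    sd = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
    print({k: v.is_pinned() for k, v in pin_state_dict(sd, num_threads=4).items()})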
@@ -723,6 +891,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # gather complete state from tp shards & dp shards
             param_id = optimizer.param_info["param2id"][id(working_param)]
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
+
             local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
                 state,
                 working_param,
@@ -742,7 +911,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             ]
             state_dict = {"param_groups": param_groups, "state": local_states}
             if self.coordinator.is_master():
-                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+                if use_async:
+                    from colossalai.utils.safetensors import save
+
+                    flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                    if id(optimizer) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
+                    for k, v in flatten_state_dict.items():
+                        self.pinned_state_dicts[k].copy_(v)
+                        flatten_state_dict[k] = self.pinned_state_dicts[k]
+                    writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
+                    self.async_writers.append(writer)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors=False)
         else:
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
             states_list = [None for _ in range(self.pp_size)]
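Because safetensors stores only a flat mapping from string names to tensors, the optimizer checkpoint is first flattened by _flatten_optim_state_dict into flat tensor keys plus a metadata record before the async save call. The helper below illustrates the general idea under an assumed key scheme ("state.<param_id>.<field>"); the real flattening function and metadata layout are not shown in this diff and may differ.

# Simplified illustration of flattening an optimizer state dict for a flat-tensor
# format such as safetensors. The key scheme and metadata layout here are assumptions;
# the real _flatten_optim_state_dict may differ.
import json
from typing import Any, Dict, Tuple

import torch


def flatten_optim_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]:
    flat: Dict[str, torch.Tensor] = {}
    non_tensor: Dict[str, Any] = {"param_groups": state_dict["param_groups"], "state": {}}

    for param_id, param_state in state_dict["state"].items():
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                # Encode the nesting in the key, e.g. "state.0.exp_avg".
                flat[f"state.{param_id}.{key}"] = value
            else:
                # Non-tensor entries go into the JSON metadata instead.
                non_tensor["state"].setdefault(str(param_id), {})[key] = value

    metadata = {"non_tensor": json.dumps(non_tensor)}
    return flat, metadata


if __name__ == "__main__":
    lin = torch.nn.Linear(2, 2)
    opt = torch.optim.Adam(lin.parameters())
    lin(torch.randn(1, 2)).sum().backward()
    opt.step()  # populate exp_avg / exp_avg_sq so there is something to flatten
    flat, meta = flatten_optim_state_dict(opt.state_dict())
    print(sorted(flat), list(meta))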
@@ -758,9 +939,23 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 state_dict = {"param_groups": param_groups, "state": dict()}
                 for _states in states_list:
                     state_dict["state"].update(_states)
-                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+                if use_async:
+                    from colossalai.utils.safetensors import save

-    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+                    flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                    if id(optimizer) not in self.pinned_state_dicts:
+                        self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
+                    for k, v in flatten_state_dict.items():
+                        self.pinned_state_dicts[k].copy_(v)
+                        flatten_state_dict[k] = self.pinned_state_dicts[k]
+                    writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
+                    self.async_writers.append(writer)
+                else:
+                    save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
+    ):
         """
         Load optimizer from a file with given path.

@@ -784,7 +979,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

         # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
-        state_dict = load_state_dict(checkpoint)
+        if checkpoint.endswith(".safetensors"):
+            state_dict = load_flat(checkpoint)
+        else:
+            state_dict = load_state_dict(checkpoint)
+        if not low_cpu_mem_mode:
+            state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

         # Load param_groups.
         updated_groups = []
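On load, the new branch picks load_flat for .safetensors files, which has to undo that flattening before the states can be fed back into the optimizer, and falls back to load_state_dict otherwise. A rough inverse of the flattening sketch above, with the same assumed key scheme and a hypothetical helper name (the real load_flat is not shown in this diff):

# Rough inverse of the flattening sketch above: rebuild a nested optimizer state dict
# from flat "state.<param_id>.<field>" tensors plus the JSON metadata record.
import json
from typing import Any, Dict

import torch


def unflatten_optim_state_dict(flat: Dict[str, torch.Tensor], metadata: Dict[str, str]) -> Dict[str, Any]:
    non_tensor = json.loads(metadata["non_tensor"])
    state: Dict[int, Dict[str, Any]] = {}

    # Re-nest the tensor entries under their parameter ids.
    for flat_key, tensor in flat.items():
        _, param_id, field = flat_key.split(".", 2)
        state.setdefault(int(param_id), {})[field] = tensor

    # Merge back the non-tensor entries kept in the metadata.
    for param_id, extras in non_tensor["state"].items():
        state.setdefault(int(param_id), {}).update(extras)

    return {"param_groups": non_tensor["param_groups"], "state": state}

Round-tripping through these two helpers reproduces the nested structure that optimizer.load_state_dict-style consumers expect, which is essentially what the .safetensors load branch has to guarantee.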
@@ -802,22 +1002,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+        self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
-
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            if param is None:
-                continue
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state

         sharded_optimizer_loading_epilogue(optimizer.optim)

@@ -838,6 +1023,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         use_zero: bool,
         inplace: bool,
         device: torch.device = torch.device("cpu"),
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> OrderedDict:
         """
         With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -861,6 +1047,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         state_ = state if inplace else copy.deepcopy(state)

         for k, v in state_.items():
+            if v is None:
+                continue
             if isinstance(v, torch.Tensor) and k != "step":
                 # First gather Zero shards.
                 if use_zero:
@@ -881,7 +1069,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
                     v = to_unpadded_tensor(v)

-            state_[k] = v.detach().clone().to(device)
+            if pinned_state_dicts is not None:
+                if k not in pinned_state_dicts:
+                    pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                pinned_state_dicts[k].copy_(v)
+                state_[k] = pinned_state_dicts[k]
+            else:
+                state_[k] = v.detach().clone().to(device)

         return state_

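gather_from_sharded_optimizer_state now accepts an optional pinned_state_dicts buffer, so gathered states are written into reusable page-locked memory instead of a fresh clone().to(device) on every save. Pinned memory is what allows the device-to-host copy to be issued with non_blocking=True; the generic snippet below (plain PyTorch, nothing ColossalAI-specific, and only meaningful on a CUDA machine) contrasts the two copy paths.

# Generic sketch: device-to-host copies into pinned memory can be issued non-blocking,
# while copies into ordinary pageable memory cannot. Requires a CUDA device.
import time

import torch

if torch.cuda.is_available():
    src = torch.randn(64 * 1024 * 1024 // 4, device="cuda")  # ~64 MB of fp32

    pageable = torch.empty(src.shape, dtype=src.dtype, device="cpu")
    pinned = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    pageable.copy_(src)                    # synchronous D2H copy into pageable memory
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    pinned.copy_(src, non_blocking=True)   # can overlap with other GPU/CPU work
    torch.cuda.synchronize()
    t2 = time.perf_counter()

    print(f"pageable: {t1 - t0:.4f}s  pinned (non_blocking): {t2 - t1:.4f}s")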
@@ -891,6 +1085,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         current_shape: torch.Size,
         original_shape: torch.Size,
         device: torch.device,
+        dtype: torch.dtype,
         inplace: bool,
     ) -> OrderedDict:
         """
@@ -940,11 +1135,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     slice_size = v.numel() // self.global_dp_size
                     v = v.split(slice_size, dim=0)[self.dp_rank]

-            state_[k] = v.detach().clone().to(device)
+            state_[k] = v.detach().clone().to(device=device, dtype=dtype)

         return state_

-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -952,8 +1147,21 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
+        peft_model = model.unwrap(unwrap_peft=False)
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+        if self.pp_size > 1:
+            lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+            gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
+            if self.pp_rank == 0:
+                state_dict.update(gathered_lora_state_dict)
+        state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+        if self.coordinator.is_master():
+            return peft_model.save_pretrained(
+                checkpoint,
+                safe_serialization=use_safetensors,
+                state_dict=state_dict,
+            )
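For LoRA checkpoints under pipeline parallelism, each stage only owns the adapters of its own layers, so the new code first extracts the LoRA entries (get_lora_state_dict), gathers them across the pipeline group (gather_state_dict_fast), merges them on stage 0, and only then lets the master rank call save_pretrained with an explicit state_dict. A minimal stand-in for that gather, written directly against torch.distributed object collectives with illustrative names (not ColossalAI's helpers), could look like:

# Minimal stand-in for gathering per-stage LoRA weights onto one pipeline rank.
# Assumes an initialized torch.distributed process group; names here are illustrative.
from typing import Dict

import torch
import torch.distributed as dist


def gather_lora_state_dict(
    local_state_dict: Dict[str, torch.Tensor], pp_group: dist.ProcessGroup, dst_rank: int = 0
) -> Dict[str, torch.Tensor]:
    # Keep only LoRA adapter weights; base weights are handled separately.
    lora_only = {k: v.cpu() for k, v in local_state_dict.items() if "lora_" in k}

    gathered = [None] * dist.get_world_size(pp_group)
    # Object collective: fine for small adapter tensors, too slow for full model weights.
    dist.all_gather_object(gathered, lora_only, group=pp_group)

    merged: Dict[str, torch.Tensor] = {}
    if dist.get_rank(pp_group) == dst_rank:
        for shard in gathered:
            merged.update(shard)
    return merged

Object collectives keep the sketch short because adapter tensors are small; the real gather_state_dict_fast presumably uses a more efficient scheme, which this diff does not show.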