mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
remove chatgpt (#3284)
This commit is contained in:
parent
b0ce5a1032
commit
bb6196e71a
146
applications/ChatGPT/.gitignore
vendored
146
applications/ChatGPT/.gitignore
vendored
@ -1,146 +0,0 @@
|
|||||||
# Byte-compiled / optimized / DLL files
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
develop-eggs/
|
|
||||||
dist/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
pip-wheel-metadata/
|
|
||||||
share/python-wheels/
|
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
*.egg
|
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
# Usually these files are written by a python script from a template
|
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
|
|
||||||
# Translations
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
docs/.build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# IPython
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
.python-version
|
|
||||||
|
|
||||||
# pipenv
|
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
||||||
# install all needed dependencies.
|
|
||||||
#Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# Celery stuff
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
|
||||||
.pyre/
|
|
||||||
|
|
||||||
# IDE
|
|
||||||
.idea/
|
|
||||||
.vscode/
|
|
||||||
|
|
||||||
# macos
|
|
||||||
*.DS_Store
|
|
||||||
#data/
|
|
||||||
|
|
||||||
docs/.build
|
|
||||||
|
|
||||||
# pytorch checkpoint
|
|
||||||
*.pt
|
|
||||||
|
|
||||||
# ignore version.py generated by setup.py
|
|
||||||
colossalai/version.py
|
|
@ -1,202 +0,0 @@
|
|||||||
Copyright 2021- HPC-AI Technology Inc. All rights reserved.
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright 2021- HPC-AI Technology Inc.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
@ -1,209 +0,0 @@
|
|||||||
# RLHF - Colossal-AI
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [What is RLHF - Colossal-AI?](#intro)
|
|
||||||
- [How to Install?](#install)
|
|
||||||
- [The Plan](#the-plan)
|
|
||||||
- [How can you partcipate in open source?](#invitation-to-open-source-contribution)
|
|
||||||
---
|
|
||||||
## Intro
|
|
||||||
Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by Colossal-AI. It supports distributed training and offloading, which can fit extremly large models. More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).
|
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/chatgpt.png" width=700/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
## Training process (step 3)
|
|
||||||
<p align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/experience.jpg" width=500/>
|
|
||||||
</p>
|
|
||||||
<p align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/train.jpg" width=500/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
|
|
||||||
## Install
|
|
||||||
```shell
|
|
||||||
pip install .
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
The main entrypoint is `Trainer`. We only support PPO trainer now. We support many training strategies:
|
|
||||||
|
|
||||||
- NaiveStrategy: simplest strategy. Train on single GPU.
|
|
||||||
- DDPStrategy: use `torch.nn.parallel.DistributedDataParallel`. Train on multi GPUs.
|
|
||||||
- ColossalAIStrategy: use Gemini and Zero of ColossalAI. It eliminates model duplication on each GPU and supports offload. It's very useful when training large models on multi GPUs.
|
|
||||||
|
|
||||||
Simplest usage:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from chatgpt.trainer import PPOTrainer
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy
|
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from copy import deepcopy
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
strategy = ColossalAIStrategy()
|
|
||||||
|
|
||||||
with strategy.model_init_context():
|
|
||||||
# init your model here
|
|
||||||
# load pretrained gpt2
|
|
||||||
actor = GPTActor(pretrained='gpt2')
|
|
||||||
critic = GPTCritic()
|
|
||||||
initial_model = deepcopy(actor).cuda()
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
|
|
||||||
|
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
|
||||||
|
|
||||||
# prepare models and optimizers
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
|
|
||||||
# load saved model checkpoint after preparing
|
|
||||||
strategy.load_model(actor, 'actor_checkpoint.pt', strict=False)
|
|
||||||
# load saved optimizer checkpoint after preparing
|
|
||||||
strategy.load_optimizer(actor_optim, 'actor_optim_checkpoint.pt')
|
|
||||||
|
|
||||||
trainer = PPOTrainer(strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
reward_model,
|
|
||||||
initial_model,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
...)
|
|
||||||
|
|
||||||
trainer.fit(dataset, ...)
|
|
||||||
|
|
||||||
# save model checkpoint after fitting on only rank0
|
|
||||||
strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt', only_rank0=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
For more details, see `examples/`.
|
|
||||||
|
|
||||||
We also support training reward model with true-world data. See `examples/train_reward_model.py`.
|
|
||||||
|
|
||||||
## FAQ
|
|
||||||
|
|
||||||
### How to save/load checkpoint
|
|
||||||
|
|
||||||
To load pretrained model, you can simply use huggingface pretrained models:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# load OPT-350m pretrained model
|
|
||||||
actor = OPTActor(pretrained='facebook/opt-350m')
|
|
||||||
```
|
|
||||||
|
|
||||||
To save model checkpoint:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# save model checkpoint on only rank0
|
|
||||||
strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
|
|
||||||
```
|
|
||||||
|
|
||||||
This function must be called after `strategy.prepare()`.
|
|
||||||
|
|
||||||
For DDP strategy, model weights are replicated on all ranks. And for ColossalAI strategy, model weights may be sharded, but all-gather will be applied before returning state dict. You can set `only_rank0=True` for both of them, which only saves checkpoint on rank0, to save disk space usage. The checkpoint is float32.
|
|
||||||
|
|
||||||
To save optimizer checkpoint:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt', only_rank0=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
For DDP strategy, optimizer states are replicated on all ranks. You can set `only_rank0=True`. But for ColossalAI strategy, optimizer states are sharded over all ranks, and no all-gather will be applied. So for ColossalAI strategy, you can only set `only_rank0=False`. That is to say, each rank will save a cehckpoint. When loading, each rank should load the corresponding part.
|
|
||||||
|
|
||||||
Note that different stategy may have different shapes of optimizer checkpoint.
|
|
||||||
|
|
||||||
To load model checkpoint:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# load saved model checkpoint after preparing
|
|
||||||
strategy.load_model(actor, 'actor_checkpoint.pt', strict=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
To load optimizer checkpoint:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# load saved optimizer checkpoint after preparing
|
|
||||||
strategy.load_optimizer(actor_optim, 'actor_optim_checkpoint.pt')
|
|
||||||
```
|
|
||||||
|
|
||||||
## The Plan
|
|
||||||
|
|
||||||
- [x] implement PPO fine-tuning
|
|
||||||
- [x] implement training reward model
|
|
||||||
- [x] support LoRA
|
|
||||||
- [x] support inference
|
|
||||||
- [ ] open source the reward model weight
|
|
||||||
- [ ] support llama from [facebook](https://github.com/facebookresearch/llama)
|
|
||||||
- [ ] support BoN(best of N sample)
|
|
||||||
- [ ] implement PPO-ptx fine-tuning
|
|
||||||
- [ ] integrate with Ray
|
|
||||||
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
|
|
||||||
- [ ] support chain of throught by [langchain](https://github.com/hwchase17/langchain)
|
|
||||||
|
|
||||||
### Real-time progress
|
|
||||||
You will find our progress in github project broad
|
|
||||||
|
|
||||||
[Open ChatGPT](https://github.com/orgs/hpcaitech/projects/17/views/1)
|
|
||||||
|
|
||||||
## Invitation to open-source contribution
|
|
||||||
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
|
|
||||||
|
|
||||||
You may contact us or participate in the following ways:
|
|
||||||
1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
|
|
||||||
2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
|
|
||||||
3. Join the Colossal-AI community on
|
|
||||||
[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
|
|
||||||
and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
|
|
||||||
4. Send your official proposal to email contact@hpcaitech.com
|
|
||||||
|
|
||||||
Thanks so much to all of our amazing contributors!
|
|
||||||
|
|
||||||
## Quick Preview
|
|
||||||
<p id="ChatGPT_scaling" align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
|
|
||||||
|
|
||||||
<p id="ChatGPT-1GPU" align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
- Up to 10.3x growth in model capacity on one GPU
|
|
||||||
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
|
|
||||||
|
|
||||||
<p id="inference" align="center">
|
|
||||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
|
|
||||||
- Keep in a sufficiently high running speed
|
|
||||||
|
|
||||||
## Citations
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@article{Hu2021LoRALA,
|
|
||||||
title = {LoRA: Low-Rank Adaptation of Large Language Models},
|
|
||||||
author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
|
|
||||||
journal = {ArXiv},
|
|
||||||
year = {2021},
|
|
||||||
volume = {abs/2106.09685}
|
|
||||||
}
|
|
||||||
|
|
||||||
@article{ouyang2022training,
|
|
||||||
title={Training language models to follow instructions with human feedback},
|
|
||||||
author={Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
|
|
||||||
journal={arXiv preprint arXiv:2203.02155},
|
|
||||||
year={2022}
|
|
||||||
}
|
|
||||||
```
|
|
@ -1,94 +0,0 @@
|
|||||||
# Benchmarks
|
|
||||||
|
|
||||||
## Benchmark GPT on dummy prompt data
|
|
||||||
|
|
||||||
We provide various GPT models (string in parentheses is the corresponding model name used in this script):
|
|
||||||
|
|
||||||
- GPT2-S (s)
|
|
||||||
- GPT2-M (m)
|
|
||||||
- GPT2-L (l)
|
|
||||||
- GPT2-XL (xl)
|
|
||||||
- GPT2-4B (4b)
|
|
||||||
- GPT2-6B (6b)
|
|
||||||
- GPT2-8B (8b)
|
|
||||||
- GPT2-10B (10b)
|
|
||||||
- GPT2-12B (12b)
|
|
||||||
- GPT2-15B (15b)
|
|
||||||
- GPT2-18B (18b)
|
|
||||||
- GPT2-20B (20b)
|
|
||||||
- GPT2-24B (24b)
|
|
||||||
- GPT2-28B (28b)
|
|
||||||
- GPT2-32B (32b)
|
|
||||||
- GPT2-36B (36b)
|
|
||||||
- GPT2-40B (40b)
|
|
||||||
- GPT3 (175b)
|
|
||||||
|
|
||||||
We also provide various training strategies:
|
|
||||||
|
|
||||||
- ddp: torch DDP
|
|
||||||
- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
|
|
||||||
- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
|
|
||||||
- colossalai_zero2: ColossalAI zero2
|
|
||||||
- colossalai_zero2_cpu: ColossalAI zero2-offload
|
|
||||||
- colossalai_zero1: ColossalAI zero1
|
|
||||||
- colossalai_zero1_cpu: ColossalAI zero1-offload
|
|
||||||
|
|
||||||
We only support `torchrun` to launch now. E.g.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# run GPT2-S on single-node single-GPU with min batch size
|
|
||||||
torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
|
|
||||||
# run GPT2-XL on single-node 4-GPU
|
|
||||||
torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
|
|
||||||
# run GPT3 on 8-node 8-GPU
|
|
||||||
torchrun --nnodes 8 --nproc_per_node 8 \
|
|
||||||
--rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
|
|
||||||
benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
|
|
||||||
```
|
|
||||||
|
|
||||||
> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
|
|
||||||
|
|
||||||
In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
|
|
||||||
|
|
||||||
We also provide a simple shell script to run a set of benchmarks. But it only supports benchmark on single node. However, it's easy to run on multi-nodes by modifying launch command in this script.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
|
||||||
./benchmark_gpt_dummy.sh
|
|
||||||
# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
|
||||||
./benchmark_gpt_dummy.sh 2
|
|
||||||
# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
|
||||||
./benchmark_gpt_dummy.sh 2 ddp
|
|
||||||
# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
|
|
||||||
./benchmark_gpt_dummy.sh 2 ddp l
|
|
||||||
```
|
|
||||||
|
|
||||||
## Benchmark OPT with LoRA on dummy prompt data
|
|
||||||
|
|
||||||
We provide various OPT models (string in parentheses is the corresponding model name used in this script):
|
|
||||||
|
|
||||||
- OPT-125M (125m)
|
|
||||||
- OPT-350M (350m)
|
|
||||||
- OPT-700M (700m)
|
|
||||||
- OPT-1.3B (1.3b)
|
|
||||||
- OPT-2.7B (2.7b)
|
|
||||||
- OPT-3.5B (3.5b)
|
|
||||||
- OPT-5.5B (5.5b)
|
|
||||||
- OPT-6.7B (6.7b)
|
|
||||||
- OPT-10B (10b)
|
|
||||||
- OPT-13B (13b)
|
|
||||||
|
|
||||||
We only support `torchrun` to launch now. E.g.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
|
|
||||||
torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
|
|
||||||
# run OPT-350M with lora_rank=4 on single-node 4-GPU
|
|
||||||
torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
|
|
||||||
```
|
|
||||||
|
|
||||||
> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
|
|
||||||
|
|
||||||
In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
|
|
@ -1,184 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
|
||||||
from chatgpt.trainer import PPOTrainer
|
|
||||||
from chatgpt.trainer.callbacks import PerformanceEvaluator
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
|
|
||||||
from torch.optim import Adam
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
|
||||||
numel = sum(p.numel() for p in model.parameters())
|
|
||||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
|
||||||
numel *= dist.get_world_size()
|
|
||||||
return numel
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_batch(samples) -> dict:
|
|
||||||
input_ids = torch.stack(samples)
|
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
|
||||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(*args, **kwargs) -> None:
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def print_model_numel(model_dict: dict) -> None:
|
|
||||||
B = 1024**3
|
|
||||||
M = 1024**2
|
|
||||||
K = 1024
|
|
||||||
outputs = ''
|
|
||||||
for name, numel in model_dict.items():
|
|
||||||
outputs += f'{name}: '
|
|
||||||
if numel >= B:
|
|
||||||
outputs += f'{numel / B:.2f} B\n'
|
|
||||||
elif numel >= M:
|
|
||||||
outputs += f'{numel / M:.2f} M\n'
|
|
||||||
elif numel >= K:
|
|
||||||
outputs += f'{numel / K:.2f} K\n'
|
|
||||||
else:
|
|
||||||
outputs += f'{numel}\n'
|
|
||||||
print_rank_0(outputs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_gpt_config(model_name: str) -> GPT2Config:
|
|
||||||
model_map = {
|
|
||||||
's': GPT2Config(),
|
|
||||||
'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
|
|
||||||
'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
|
|
||||||
'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
|
|
||||||
'2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
|
|
||||||
'4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
|
|
||||||
'6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
|
|
||||||
'8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
|
|
||||||
'10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
|
|
||||||
'12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
|
|
||||||
'15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
|
|
||||||
'18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
|
|
||||||
'20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
|
|
||||||
'24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
|
|
||||||
'28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
|
|
||||||
'32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
|
|
||||||
'36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
|
|
||||||
'40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
|
|
||||||
'175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
return model_map[model_name]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f'Unknown model "{model_name}"')
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
if args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_gemini_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero2_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
|
||||||
elif args.strategy == 'colossalai_zero1':
|
|
||||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero1_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
model_config = get_gpt_config(args.model)
|
|
||||||
|
|
||||||
with strategy.model_init_context():
|
|
||||||
actor = GPTActor(config=model_config).cuda()
|
|
||||||
critic = GPTCritic(config=model_config).cuda()
|
|
||||||
|
|
||||||
initial_model = deepcopy(actor).cuda()
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
|
|
||||||
|
|
||||||
actor_numel = get_model_numel(actor, strategy)
|
|
||||||
critic_numel = get_model_numel(critic, strategy)
|
|
||||||
initial_model_numel = get_model_numel(initial_model, strategy)
|
|
||||||
reward_model_numel = get_model_numel(reward_model, strategy)
|
|
||||||
print_model_numel({
|
|
||||||
'Actor': actor_numel,
|
|
||||||
'Critic': critic_numel,
|
|
||||||
'Initial model': initial_model_numel,
|
|
||||||
'Reward model': reward_model_numel
|
|
||||||
})
|
|
||||||
performance_evaluator = PerformanceEvaluator(actor_numel,
|
|
||||||
critic_numel,
|
|
||||||
initial_model_numel,
|
|
||||||
reward_model_numel,
|
|
||||||
enable_grad_checkpoint=False,
|
|
||||||
ignore_episodes=1)
|
|
||||||
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
|
||||||
else:
|
|
||||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
|
||||||
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
|
|
||||||
trainer = PPOTrainer(strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
reward_model,
|
|
||||||
initial_model,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
max_epochs=args.max_epochs,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
experience_batch_size=args.experience_batch_size,
|
|
||||||
tokenizer=preprocess_batch,
|
|
||||||
max_length=512,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
callbacks=[performance_evaluator])
|
|
||||||
|
|
||||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
|
||||||
trainer.fit(random_prompts,
|
|
||||||
num_episodes=args.num_episodes,
|
|
||||||
max_timesteps=args.max_timesteps,
|
|
||||||
update_timesteps=args.update_timesteps)
|
|
||||||
|
|
||||||
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--model', default='s')
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=[
|
|
||||||
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
|
|
||||||
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
|
|
||||||
],
|
|
||||||
default='ddp')
|
|
||||||
parser.add_argument('--num_episodes', type=int, default=3)
|
|
||||||
parser.add_argument('--max_timesteps', type=int, default=8)
|
|
||||||
parser.add_argument('--update_timesteps', type=int, default=8)
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=3)
|
|
||||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
@ -1,45 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
# Usage: $0 <?number-of-gpus> <?strategy> <?model>
|
|
||||||
set -xu
|
|
||||||
|
|
||||||
BASE=$(realpath $(dirname $0))
|
|
||||||
|
|
||||||
|
|
||||||
PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
|
|
||||||
export OMP_NUM_THREADS=8
|
|
||||||
|
|
||||||
function tune_batch_size() {
|
|
||||||
# we found when experience batch size is equal to train batch size
|
|
||||||
# peak CUDA memory usage of making experience phase is less than or equal to that of training phase
|
|
||||||
# thus, experience batch size can be larger than or equal to train batch size
|
|
||||||
for bs in 1 2 4 8 16 32 64 128 256; do
|
|
||||||
torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
|
|
||||||
done
|
|
||||||
}
|
|
||||||
|
|
||||||
if [ $# -eq 0 ]; then
|
|
||||||
num_gpus=(1 2 4 8)
|
|
||||||
else
|
|
||||||
num_gpus=($1)
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $# -le 1 ]; then
|
|
||||||
strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
|
|
||||||
else
|
|
||||||
strategies=($2)
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $# -le 2 ]; then
|
|
||||||
models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
|
|
||||||
else
|
|
||||||
models=($3)
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
for num_gpu in ${num_gpus[@]}; do
|
|
||||||
for strategy in ${strategies[@]}; do
|
|
||||||
for model in ${models[@]}; do
|
|
||||||
tune_batch_size $num_gpu $model $strategy || break
|
|
||||||
done
|
|
||||||
done
|
|
||||||
done
|
|
@ -1,179 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.opt import OPTActor, OPTCritic
|
|
||||||
from chatgpt.trainer import PPOTrainer
|
|
||||||
from chatgpt.trainer.callbacks import PerformanceEvaluator
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
|
|
||||||
from torch.optim import Adam
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from transformers.models.opt.configuration_opt import OPTConfig
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
|
||||||
numel = sum(p.numel() for p in model.parameters())
|
|
||||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
|
||||||
numel *= dist.get_world_size()
|
|
||||||
return numel
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_batch(samples) -> dict:
|
|
||||||
input_ids = torch.stack(samples)
|
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
|
||||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(*args, **kwargs) -> None:
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def print_model_numel(model_dict: dict) -> None:
|
|
||||||
B = 1024**3
|
|
||||||
M = 1024**2
|
|
||||||
K = 1024
|
|
||||||
outputs = ''
|
|
||||||
for name, numel in model_dict.items():
|
|
||||||
outputs += f'{name}: '
|
|
||||||
if numel >= B:
|
|
||||||
outputs += f'{numel / B:.2f} B\n'
|
|
||||||
elif numel >= M:
|
|
||||||
outputs += f'{numel / M:.2f} M\n'
|
|
||||||
elif numel >= K:
|
|
||||||
outputs += f'{numel / K:.2f} K\n'
|
|
||||||
else:
|
|
||||||
outputs += f'{numel}\n'
|
|
||||||
print_rank_0(outputs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_gpt_config(model_name: str) -> OPTConfig:
|
|
||||||
model_map = {
|
|
||||||
'125m': OPTConfig.from_pretrained('facebook/opt-125m'),
|
|
||||||
'350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
|
|
||||||
'700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
|
|
||||||
'1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
|
|
||||||
'2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
|
|
||||||
'3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
|
|
||||||
'5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
|
|
||||||
'6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
|
|
||||||
'10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
|
|
||||||
'13b': OPTConfig.from_pretrained('facebook/opt-13b'),
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
return model_map[model_name]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f'Unknown model "{model_name}"')
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
if args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_gemini_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero2_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
|
||||||
elif args.strategy == 'colossalai_zero1':
|
|
||||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero1_cpu':
|
|
||||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)
|
|
||||||
|
|
||||||
model_config = get_gpt_config(args.model)
|
|
||||||
|
|
||||||
with strategy.model_init_context():
|
|
||||||
actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
|
|
||||||
critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()
|
|
||||||
|
|
||||||
initial_model = deepcopy(actor).cuda()
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
|
|
||||||
|
|
||||||
actor_numel = get_model_numel(actor, strategy)
|
|
||||||
critic_numel = get_model_numel(critic, strategy)
|
|
||||||
initial_model_numel = get_model_numel(initial_model, strategy)
|
|
||||||
reward_model_numel = get_model_numel(reward_model, strategy)
|
|
||||||
print_model_numel({
|
|
||||||
'Actor': actor_numel,
|
|
||||||
'Critic': critic_numel,
|
|
||||||
'Initial model': initial_model_numel,
|
|
||||||
'Reward model': reward_model_numel
|
|
||||||
})
|
|
||||||
performance_evaluator = PerformanceEvaluator(actor_numel,
|
|
||||||
critic_numel,
|
|
||||||
initial_model_numel,
|
|
||||||
reward_model_numel,
|
|
||||||
enable_grad_checkpoint=False,
|
|
||||||
ignore_episodes=1)
|
|
||||||
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
|
||||||
else:
|
|
||||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
|
|
||||||
trainer = PPOTrainer(strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
reward_model,
|
|
||||||
initial_model,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
max_epochs=args.max_epochs,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
experience_batch_size=args.experience_batch_size,
|
|
||||||
tokenizer=preprocess_batch,
|
|
||||||
max_length=512,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
callbacks=[performance_evaluator])
|
|
||||||
|
|
||||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
|
||||||
trainer.fit(random_prompts,
|
|
||||||
num_episodes=args.num_episodes,
|
|
||||||
max_timesteps=args.max_timesteps,
|
|
||||||
update_timesteps=args.update_timesteps)
|
|
||||||
|
|
||||||
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--model', default='125m')
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=[
|
|
||||||
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
|
|
||||||
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
|
|
||||||
],
|
|
||||||
default='ddp')
|
|
||||||
parser.add_argument('--num_episodes', type=int, default=3)
|
|
||||||
parser.add_argument('--max_timesteps', type=int, default=8)
|
|
||||||
parser.add_argument('--update_timesteps', type=int, default=8)
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=3)
|
|
||||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--lora_rank', type=int, default=4)
|
|
||||||
parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
@ -1,5 +0,0 @@
|
|||||||
from .reward_dataset import RmStaticDataset, HhRlhfDataset
|
|
||||||
from .utils import is_rank_0
|
|
||||||
from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
|
|
||||||
|
|
||||||
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
|
|
@ -1,109 +0,0 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .utils import is_rank_0
|
|
||||||
|
|
||||||
# Dahaos/rm-static
|
|
||||||
class RmStaticDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Dataset for reward model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: dataset for reward model
|
|
||||||
tokenizer: tokenizer for reward model
|
|
||||||
max_length: max length of input
|
|
||||||
special_token: special token at the end of sentence
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.chosen = []
|
|
||||||
self.reject = []
|
|
||||||
if special_token is None:
|
|
||||||
self.end_token = tokenizer.eos_token
|
|
||||||
else:
|
|
||||||
self.end_token = special_token
|
|
||||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
|
||||||
prompt = data['prompt']
|
|
||||||
|
|
||||||
chosen = prompt + data['chosen'] + self.end_token
|
|
||||||
chosen_token = tokenizer(chosen,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt")
|
|
||||||
self.chosen.append({
|
|
||||||
"input_ids": chosen_token['input_ids'],
|
|
||||||
"attention_mask": chosen_token['attention_mask']
|
|
||||||
})
|
|
||||||
|
|
||||||
reject = prompt + data['rejected'] + self.end_token
|
|
||||||
reject_token = tokenizer(reject,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt")
|
|
||||||
self.reject.append({
|
|
||||||
"input_ids": reject_token['input_ids'],
|
|
||||||
"attention_mask": reject_token['attention_mask']
|
|
||||||
})
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
length = len(self.chosen)
|
|
||||||
return length
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
|
||||||
"input_ids"], self.reject[idx]["attention_mask"]
|
|
||||||
|
|
||||||
# Anthropic/hh-rlhf
|
|
||||||
class HhRlhfDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Dataset for reward model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: dataset for reward model
|
|
||||||
tokenizer: tokenizer for reward model
|
|
||||||
max_length: max length of input
|
|
||||||
special_token: special token at the end of sentence
|
|
||||||
"""
|
|
||||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.chosen = []
|
|
||||||
self.reject = []
|
|
||||||
if special_token is None:
|
|
||||||
self.end_token = tokenizer.eos_token
|
|
||||||
else:
|
|
||||||
self.end_token = special_token
|
|
||||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
|
||||||
chosen = data['chosen'] + self.end_token
|
|
||||||
chosen_token = tokenizer(chosen,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt")
|
|
||||||
self.chosen.append({
|
|
||||||
"input_ids": chosen_token['input_ids'],
|
|
||||||
"attention_mask": chosen_token['attention_mask']
|
|
||||||
})
|
|
||||||
|
|
||||||
reject = data['rejected'] + self.end_token
|
|
||||||
reject_token = tokenizer(reject,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt")
|
|
||||||
self.reject.append({
|
|
||||||
"input_ids": reject_token['input_ids'],
|
|
||||||
"attention_mask": reject_token['attention_mask']
|
|
||||||
})
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
length = len(self.chosen)
|
|
||||||
return length
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
|
||||||
"input_ids"], self.reject[idx]["attention_mask"]
|
|
@ -1,168 +0,0 @@
|
|||||||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import copy
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Callable, Dict, Sequence
|
|
||||||
import random
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import torch.distributed as dist
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .utils import is_rank_0, jload
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
|
||||||
PROMPT_DICT = {
|
|
||||||
"prompt_input": (
|
|
||||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
|
|
||||||
),
|
|
||||||
"prompt_no_input": (
|
|
||||||
"Below is an instruction that describes a task. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Response:"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
class SFTDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Dataset for sft model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: dataset for supervised model
|
|
||||||
tokenizer: tokenizer for supervised model
|
|
||||||
max_length: max length of input
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dataset, tokenizer: Callable, max_length: int=512) -> None:
|
|
||||||
super().__init__()
|
|
||||||
# self.prompts = []
|
|
||||||
self.input_ids = []
|
|
||||||
|
|
||||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
|
||||||
prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
|
|
||||||
prompt_token = tokenizer(prompt,
|
|
||||||
max_length=max_length,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt")
|
|
||||||
|
|
||||||
# self.prompts.append(prompt_token)s
|
|
||||||
self.input_ids.append(prompt_token)
|
|
||||||
self.labels = copy.deepcopy(self.input_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
length = len(self.prompts)
|
|
||||||
return length
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
# dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
|
||||||
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
|
||||||
# return dict(self.prompts[idx], self.prompts[idx])
|
|
||||||
|
|
||||||
|
|
||||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
|
||||||
"""Tokenize a list of strings."""
|
|
||||||
tokenized_list = [
|
|
||||||
tokenizer(
|
|
||||||
text,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding="longest",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
)
|
|
||||||
for text in strings
|
|
||||||
]
|
|
||||||
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
|
||||||
input_ids_lens = labels_lens = [
|
|
||||||
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
|
||||||
]
|
|
||||||
return dict(
|
|
||||||
input_ids=input_ids,
|
|
||||||
labels=labels,
|
|
||||||
input_ids_lens=input_ids_lens,
|
|
||||||
labels_lens=labels_lens,
|
|
||||||
)
|
|
||||||
|
|
||||||
def preprocess(
|
|
||||||
sources: Sequence[str],
|
|
||||||
targets: Sequence[str],
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
|
||||||
) -> Dict:
|
|
||||||
"""Preprocess the data by tokenizing."""
|
|
||||||
examples = [s + t for s, t in zip(sources, targets)]
|
|
||||||
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
|
||||||
input_ids = examples_tokenized["input_ids"]
|
|
||||||
labels = copy.deepcopy(input_ids)
|
|
||||||
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
|
||||||
label[:source_len] = IGNORE_INDEX
|
|
||||||
return dict(input_ids=input_ids, labels=labels)
|
|
||||||
|
|
||||||
class AlpacaDataset(Dataset):
|
|
||||||
"""Dataset for supervised fine-tuning."""
|
|
||||||
|
|
||||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int=None):
|
|
||||||
super(AlpacaDataset, self).__init__()
|
|
||||||
logger.info("Loading data...")
|
|
||||||
list_data_dict = jload(data_path)
|
|
||||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
|
||||||
|
|
||||||
if max_length is not None:
|
|
||||||
logger.info(f"Truncating data to max length {max_length}...")
|
|
||||||
list_data_dict = [example for example in list_data_dict if len(example["input"]) <= max_length]
|
|
||||||
|
|
||||||
logger.info("Formatting inputs...")
|
|
||||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
|
||||||
sources = [
|
|
||||||
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
|
||||||
for example in list_data_dict
|
|
||||||
]
|
|
||||||
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
|
||||||
|
|
||||||
logger.info("Tokenizing inputs... This may take some time...")
|
|
||||||
data_dict = preprocess(sources, targets, tokenizer)
|
|
||||||
|
|
||||||
self.input_ids = data_dict["input_ids"]
|
|
||||||
self.labels = data_dict["labels"]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.input_ids)
|
|
||||||
|
|
||||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
|
||||||
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AlpacaDataCollator(object):
|
|
||||||
"""Collate examples for supervised fine-tuning."""
|
|
||||||
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer
|
|
||||||
|
|
||||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
|
||||||
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
|
|
||||||
)
|
|
||||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
|
||||||
return dict(
|
|
||||||
input_ids=input_ids,
|
|
||||||
labels=labels,
|
|
||||||
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
|
||||||
)
|
|
@ -1,20 +0,0 @@
|
|||||||
import io
|
|
||||||
import json
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
def is_rank_0() -> bool:
|
|
||||||
return not dist.is_initialized() or dist.get_rank() == 0
|
|
||||||
|
|
||||||
def _make_r_io_base(f, mode: str):
|
|
||||||
if not isinstance(f, io.IOBase):
|
|
||||||
f = open(f, mode=mode)
|
|
||||||
return f
|
|
||||||
|
|
||||||
def jload(f, mode="r"):
|
|
||||||
"""Load a .json file into a dictionary."""
|
|
||||||
f = _make_r_io_base(f, mode)
|
|
||||||
jdict = json.load(f)
|
|
||||||
f.close()
|
|
||||||
return jdict
|
|
@ -1,4 +0,0 @@
|
|||||||
from .base import Experience, ExperienceMaker
|
|
||||||
from .naive import NaiveExperienceMaker
|
|
||||||
|
|
||||||
__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
|
|
@ -1,77 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.models.base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Experience:
|
|
||||||
"""Experience is a batch of data.
|
|
||||||
These data should have the the sequence length and number of actions.
|
|
||||||
Left padding for sequences is applied.
|
|
||||||
|
|
||||||
Shapes of each tensor:
|
|
||||||
sequences: (B, S)
|
|
||||||
action_log_probs: (B, A)
|
|
||||||
values: (B)
|
|
||||||
reward: (B)
|
|
||||||
advatanges: (B)
|
|
||||||
attention_mask: (B, S)
|
|
||||||
action_mask: (B, A)
|
|
||||||
|
|
||||||
"A" is the number of actions.
|
|
||||||
"""
|
|
||||||
sequences: torch.Tensor
|
|
||||||
action_log_probs: torch.Tensor
|
|
||||||
values: torch.Tensor
|
|
||||||
reward: torch.Tensor
|
|
||||||
advantages: torch.Tensor
|
|
||||||
attention_mask: Optional[torch.LongTensor]
|
|
||||||
action_mask: Optional[torch.BoolTensor]
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def to_device(self, device: torch.device) -> None:
|
|
||||||
self.sequences = self.sequences.to(device)
|
|
||||||
self.action_log_probs = self.action_log_probs.to(device)
|
|
||||||
self.values = self.values.to(device)
|
|
||||||
self.reward = self.reward.to(device)
|
|
||||||
self.advantages = self.advantages.to(device)
|
|
||||||
if self.attention_mask is not None:
|
|
||||||
self.attention_mask = self.attention_mask.to(device)
|
|
||||||
if self.action_mask is not None:
|
|
||||||
self.action_mask = self.action_mask.to(device)
|
|
||||||
|
|
||||||
def pin_memory(self):
|
|
||||||
self.sequences = self.sequences.pin_memory()
|
|
||||||
self.action_log_probs = self.action_log_probs.pin_memory()
|
|
||||||
self.values = self.values.pin_memory()
|
|
||||||
self.reward = self.reward.pin_memory()
|
|
||||||
self.advantages = self.advantages.pin_memory()
|
|
||||||
if self.attention_mask is not None:
|
|
||||||
self.attention_mask = self.attention_mask.pin_memory()
|
|
||||||
if self.action_mask is not None:
|
|
||||||
self.action_mask = self.action_mask.pin_memory()
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ExperienceMaker(ABC):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
actor: Actor,
|
|
||||||
critic: nn.Module,
|
|
||||||
reward_model: nn.Module,
|
|
||||||
initial_model: Actor,
|
|
||||||
kl_coef: float = 0.1) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.actor = actor
|
|
||||||
self.critic = critic
|
|
||||||
self.reward_model = reward_model
|
|
||||||
self.initial_model = initial_model
|
|
||||||
self.kl_coef = kl_coef
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
|
||||||
pass
|
|
@ -1,36 +0,0 @@
|
|||||||
import torch
|
|
||||||
from chatgpt.models.utils import compute_reward, normalize
|
|
||||||
|
|
||||||
from .base import Experience, ExperienceMaker
|
|
||||||
|
|
||||||
|
|
||||||
class NaiveExperienceMaker(ExperienceMaker):
|
|
||||||
"""
|
|
||||||
Naive experience maker.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
|
||||||
self.actor.eval()
|
|
||||||
self.critic.eval()
|
|
||||||
self.initial_model.eval()
|
|
||||||
self.reward_model.eval()
|
|
||||||
|
|
||||||
sequences, attention_mask, action_mask = self.actor.generate(input_ids,
|
|
||||||
return_action_mask=True,
|
|
||||||
**generate_kwargs)
|
|
||||||
num_actions = action_mask.size(1)
|
|
||||||
|
|
||||||
action_log_probs = self.actor(sequences, num_actions, attention_mask)
|
|
||||||
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
|
|
||||||
value = self.critic(sequences, action_mask, attention_mask)
|
|
||||||
r = self.reward_model(sequences, attention_mask)
|
|
||||||
|
|
||||||
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
|
|
||||||
|
|
||||||
advantage = reward - value
|
|
||||||
# TODO(ver217): maybe normalize adv
|
|
||||||
if advantage.ndim == 1:
|
|
||||||
advantage = advantage.unsqueeze(-1)
|
|
||||||
|
|
||||||
return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
|
|
@ -1,4 +0,0 @@
|
|||||||
from .base import Actor, Critic, RewardModel
|
|
||||||
from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss
|
|
||||||
|
|
||||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
|
|
@ -1,6 +0,0 @@
|
|||||||
from .actor import Actor
|
|
||||||
from .critic import Critic
|
|
||||||
from .reward_model import RewardModel
|
|
||||||
from .lm import LM
|
|
||||||
|
|
||||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
|
|
@ -1,65 +0,0 @@
|
|||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from ..generation import generate
|
|
||||||
from ..lora import LoRAModule
|
|
||||||
from ..utils import log_probs_from_logits
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(LoRAModule):
|
|
||||||
"""
|
|
||||||
Actor model base class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): Actor Model.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
|
||||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
|
||||||
self.model = model
|
|
||||||
self.convert_to_lora()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
return_action_mask: bool = True,
|
|
||||||
**kwargs
|
|
||||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
|
||||||
sequences = generate(self.model, input_ids, **kwargs)
|
|
||||||
attention_mask = None
|
|
||||||
pad_token_id = kwargs.get('pad_token_id', None)
|
|
||||||
if pad_token_id is not None:
|
|
||||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
|
||||||
if not return_action_mask:
|
|
||||||
return sequences, attention_mask, None
|
|
||||||
input_len = input_ids.size(1)
|
|
||||||
eos_token_id = kwargs.get('eos_token_id', None)
|
|
||||||
if eos_token_id is None:
|
|
||||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
|
||||||
else:
|
|
||||||
# left padding may be applied, only mask action
|
|
||||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
|
||||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
|
||||||
action_mask[:, :input_len] = False
|
|
||||||
action_mask = action_mask[:, 1:]
|
|
||||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
sequences: torch.LongTensor,
|
|
||||||
num_actions: int,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
"""Returns action log probs
|
|
||||||
"""
|
|
||||||
output = self.model(sequences, attention_mask=attention_mask)
|
|
||||||
logits = output['logits']
|
|
||||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
|
||||||
return log_probs[:, -num_actions:]
|
|
||||||
|
|
||||||
def get_base_model(self):
|
|
||||||
return self.model
|
|
@ -1,54 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ..lora import LoRAModule
|
|
||||||
from ..utils import masked_mean
|
|
||||||
|
|
||||||
|
|
||||||
class Critic(LoRAModule):
|
|
||||||
"""
|
|
||||||
Critic model base class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): Critic model.
|
|
||||||
value_head (nn.Module): Value head to get value.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
value_head: nn.Module,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none',
|
|
||||||
use_action_mask: bool = False,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
|
||||||
self.model = model
|
|
||||||
self.value_head = value_head
|
|
||||||
self.use_action_mask = use_action_mask
|
|
||||||
self.convert_to_lora()
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
sequences: torch.LongTensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
|
||||||
last_hidden_states = outputs['last_hidden_state']
|
|
||||||
|
|
||||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
|
||||||
|
|
||||||
if action_mask is not None and self.use_action_mask:
|
|
||||||
num_actions = action_mask.size(1)
|
|
||||||
prompt_mask = attention_mask[:, :-num_actions]
|
|
||||||
values = values[:, :-num_actions]
|
|
||||||
value = masked_mean(values, prompt_mask, dim=1)
|
|
||||||
return value
|
|
||||||
|
|
||||||
values = values[:, :-1]
|
|
||||||
value = values.mean(dim=1)
|
|
||||||
return value
|
|
@ -1,33 +0,0 @@
|
|||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from ..generation import generate
|
|
||||||
from .actor import Actor
|
|
||||||
|
|
||||||
|
|
||||||
class LM(Actor):
|
|
||||||
"""
|
|
||||||
Language model base class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): Language Model.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
|
||||||
super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
sequences: torch.LongTensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
"""Returns output log probs
|
|
||||||
"""
|
|
||||||
output = self.model(sequences, attention_mask=attention_mask)
|
|
||||||
logits = output['logits']
|
|
||||||
log_probs = F.log_softmax(logits, dim=-1)
|
|
||||||
return log_probs
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ..lora import LoRAModule
|
|
||||||
|
|
||||||
|
|
||||||
class RewardModel(LoRAModule):
|
|
||||||
"""
|
|
||||||
Reward model base class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): Reward model.
|
|
||||||
value_head (nn.Module): Value head to get reward score.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
model: nn.Module,
|
|
||||||
value_head: Optional[nn.Module] = None,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
|
||||||
self.model = model
|
|
||||||
self.convert_to_lora()
|
|
||||||
|
|
||||||
if value_head is not None:
|
|
||||||
if value_head.out_features != 1:
|
|
||||||
raise ValueError("The value head of reward model's output dim should be 1!")
|
|
||||||
self.value_head = value_head
|
|
||||||
else:
|
|
||||||
self.value_head = nn.Linear(model.config.n_embd, 1)
|
|
||||||
|
|
||||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
|
||||||
last_hidden_states = outputs['last_hidden_state']
|
|
||||||
values = self.value_head(last_hidden_states)[:, :-1]
|
|
||||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
|
||||||
return value
|
|
@ -1,6 +0,0 @@
|
|||||||
from .bloom_actor import BLOOMActor
|
|
||||||
from .bloom_critic import BLOOMCritic
|
|
||||||
from .bloom_rm import BLOOMRM
|
|
||||||
from .bloom_lm import BLOOMLM
|
|
||||||
|
|
||||||
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
|
|
@ -1,35 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
|
||||||
|
|
||||||
from ..base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOMActor(Actor):
|
|
||||||
"""
|
|
||||||
BLOOM Actor model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (BloomConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: str = None,
|
|
||||||
config: Optional[BloomConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = BloomForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = BloomForCausalLM(BloomConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
|
||||||
|
|
||||||
from ..base import Critic
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOMCritic(Critic):
|
|
||||||
"""
|
|
||||||
BLOOM Critic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (BloomConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: str = None,
|
|
||||||
config: Optional[BloomConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none',
|
|
||||||
**kwargs) -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = BloomModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = BloomModel(config)
|
|
||||||
else:
|
|
||||||
model = BloomModel(BloomConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -1,36 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
|
||||||
|
|
||||||
from ..base import LM
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOMLM(LM):
|
|
||||||
"""
|
|
||||||
BLOOM language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (BloomConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: str = None,
|
|
||||||
config: Optional[BloomConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = BloomForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = BloomForCausalLM(BloomConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
|
||||||
|
|
||||||
from ..base import RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
class BLOOMRM(RewardModel):
|
|
||||||
"""
|
|
||||||
BLOOM Reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (BloomConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: str = None,
|
|
||||||
config: Optional[BloomConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = BloomModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = BloomModel(config)
|
|
||||||
else:
|
|
||||||
model = BloomModel(BloomConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,4 +0,0 @@
|
|||||||
from .deberta_critic import DebertaCritic
|
|
||||||
from .deberta_rm import DebertaRM
|
|
||||||
|
|
||||||
__all__ = ['DebertaCritic', 'DebertaRM']
|
|
@ -1,36 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import DebertaV2Config, DebertaV2Model
|
|
||||||
|
|
||||||
from ..base import Critic
|
|
||||||
|
|
||||||
|
|
||||||
class DebertaCritic(Critic):
|
|
||||||
"""
|
|
||||||
Deberta Critic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (DebertaV2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the LO-RA decomposition.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[DebertaV2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = DebertaV2Model.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = DebertaV2Model(config)
|
|
||||||
else:
|
|
||||||
model = DebertaV2Model(DebertaV2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,37 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import DebertaV2Config, DebertaV2Model
|
|
||||||
|
|
||||||
from ..base import RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
class DebertaRM(RewardModel):
|
|
||||||
"""
|
|
||||||
Deberta Reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (DebertaV2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the LO-RA decomposition.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: str = None,
|
|
||||||
config: Optional[DebertaV2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = DebertaV2Model.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = DebertaV2Model(config)
|
|
||||||
else:
|
|
||||||
model = DebertaV2Model(DebertaV2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,146 +0,0 @@
|
|||||||
from typing import Any, Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers.generation_logits_process import (
|
|
||||||
LogitsProcessorList,
|
|
||||||
TemperatureLogitsWarper,
|
|
||||||
TopKLogitsWarper,
|
|
||||||
TopPLogitsWarper,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
|
||||||
processor_list = LogitsProcessorList()
|
|
||||||
if temperature is not None and temperature != 1.0:
|
|
||||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
|
||||||
if top_k is not None and top_k != 0:
|
|
||||||
processor_list.append(TopKLogitsWarper(top_k))
|
|
||||||
if top_p is not None and top_p < 1.0:
|
|
||||||
processor_list.append(TopPLogitsWarper(top_p))
|
|
||||||
return processor_list
|
|
||||||
|
|
||||||
|
|
||||||
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
# consider DP
|
|
||||||
unfinished_sequences = unfinished_sequences.clone()
|
|
||||||
dist.all_reduce(unfinished_sequences)
|
|
||||||
return unfinished_sequences.max() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def sample(model: nn.Module,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
max_length: int,
|
|
||||||
early_stopping: bool = False,
|
|
||||||
eos_token_id: Optional[int] = None,
|
|
||||||
pad_token_id: Optional[int] = None,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
|
||||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
|
||||||
**model_kwargs) -> torch.Tensor:
|
|
||||||
if input_ids.size(1) >= max_length:
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
|
||||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
|
||||||
|
|
||||||
for _ in range(input_ids.size(1), max_length):
|
|
||||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
|
|
||||||
'input_ids': input_ids
|
|
||||||
}
|
|
||||||
outputs = model(**model_inputs)
|
|
||||||
|
|
||||||
next_token_logits = outputs['logits'][:, -1, :]
|
|
||||||
# pre-process distribution
|
|
||||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
|
||||||
# sample
|
|
||||||
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
|
||||||
if eos_token_id is not None:
|
|
||||||
if pad_token_id is None:
|
|
||||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
|
||||||
|
|
||||||
# update generated ids, model inputs for next step
|
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
|
||||||
if update_model_kwargs_fn is not None:
|
|
||||||
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
|
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
|
||||||
if eos_token_id is not None:
|
|
||||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
|
||||||
|
|
||||||
# stop when each sentence is finished if early_stopping=True
|
|
||||||
if early_stopping and _is_sequence_finished(unfinished_sequences):
|
|
||||||
break
|
|
||||||
|
|
||||||
return input_ids
|
|
||||||
|
|
||||||
|
|
||||||
def generate(model: nn.Module,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
max_length: int,
|
|
||||||
num_beams: int = 1,
|
|
||||||
do_sample: bool = True,
|
|
||||||
early_stopping: bool = False,
|
|
||||||
eos_token_id: Optional[int] = None,
|
|
||||||
pad_token_id: Optional[int] = None,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
|
||||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
|
||||||
**model_kwargs) -> torch.Tensor:
|
|
||||||
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): model
|
|
||||||
input_ids (torch.Tensor): input sequence
|
|
||||||
max_length (int): max length of the returned sequence
|
|
||||||
num_beams (int, optional): number of beams. Defaults to 1.
|
|
||||||
do_sample (bool, optional): whether to do sample. Defaults to True.
|
|
||||||
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
|
|
||||||
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
|
|
||||||
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
|
|
||||||
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
|
||||||
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
|
|
||||||
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
|
|
||||||
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
|
||||||
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
|
||||||
"""
|
|
||||||
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
|
|
||||||
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
|
|
||||||
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
|
|
||||||
if is_greedy_gen_mode:
|
|
||||||
# run greedy search
|
|
||||||
raise NotImplementedError
|
|
||||||
elif is_sample_gen_mode:
|
|
||||||
# run sample
|
|
||||||
return sample(model,
|
|
||||||
input_ids,
|
|
||||||
max_length,
|
|
||||||
early_stopping=early_stopping,
|
|
||||||
eos_token_id=eos_token_id,
|
|
||||||
pad_token_id=pad_token_id,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
temperature=temperature,
|
|
||||||
prepare_inputs_fn=prepare_inputs_fn,
|
|
||||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
|
||||||
**model_kwargs)
|
|
||||||
elif is_beam_gen_mode:
|
|
||||||
raise NotImplementedError
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported generation mode")
|
|
@ -1,92 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
|
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
|
||||||
if past:
|
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
|
||||||
if token_type_ids is not None:
|
|
||||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
|
||||||
|
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
|
||||||
position_ids = kwargs.get("position_ids", None)
|
|
||||||
|
|
||||||
if attention_mask is not None and position_ids is None:
|
|
||||||
# create position_ids on the fly for batch generation
|
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
||||||
if past:
|
|
||||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
|
||||||
else:
|
|
||||||
position_ids = None
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"past_key_values": past,
|
|
||||||
"use_cache": kwargs.get("use_cache"),
|
|
||||||
"position_ids": position_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"token_type_ids": token_type_ids,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
|
|
||||||
if "past_key_values" in outputs:
|
|
||||||
model_kwargs["past"] = outputs["past_key_values"]
|
|
||||||
else:
|
|
||||||
model_kwargs["past"] = None
|
|
||||||
|
|
||||||
# update token_type_ids with last value
|
|
||||||
if "token_type_ids" in model_kwargs:
|
|
||||||
token_type_ids = model_kwargs["token_type_ids"]
|
|
||||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
|
||||||
|
|
||||||
# update attention mask
|
|
||||||
if "attention_mask" in model_kwargs:
|
|
||||||
attention_mask = model_kwargs["attention_mask"]
|
|
||||||
model_kwargs["attention_mask"] = torch.cat(
|
|
||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
|
||||||
|
|
||||||
return model_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def opt_prepare_inputs_fn(input_ids: torch.Tensor,
|
|
||||||
past: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
**kwargs) -> dict:
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
|
||||||
|
|
||||||
if past:
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
# first step, decoder_cached_states are empty
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"past_key_values": past,
|
|
||||||
"use_cache": use_cache,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
|
|
||||||
past: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
**kwargs) -> dict:
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
|
||||||
|
|
||||||
if past:
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
# first step, decoder_cached_states are empty
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"past_key_values": past,
|
|
||||||
"use_cache": use_cache,
|
|
||||||
}
|
|
@ -1,6 +0,0 @@
|
|||||||
from .gpt_actor import GPTActor
|
|
||||||
from .gpt_critic import GPTCritic
|
|
||||||
from .gpt_rm import GPTRM
|
|
||||||
from .gpt_lm import GPTLM
|
|
||||||
|
|
||||||
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
|
|
@ -1,35 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
|
||||||
|
|
||||||
from ..base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
class GPTActor(Actor):
|
|
||||||
"""
|
|
||||||
GPT Actor model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (GPT2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the LoRa layer.
|
|
||||||
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[GPT2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = GPT2LMHeadModel(config)
|
|
||||||
else:
|
|
||||||
model = GPT2LMHeadModel(GPT2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -1,37 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
|
||||||
|
|
||||||
from ..base import Critic
|
|
||||||
|
|
||||||
|
|
||||||
class GPTCritic(Critic):
|
|
||||||
"""
|
|
||||||
GPT Critic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (GPT2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the LO-RA decomposition.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[GPT2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = GPT2Model.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = GPT2Model(config)
|
|
||||||
else:
|
|
||||||
model = GPT2Model(GPT2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.n_embd, 1)
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,36 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
|
||||||
|
|
||||||
from ..base import LM
|
|
||||||
|
|
||||||
|
|
||||||
class GPTLM(LM):
|
|
||||||
"""
|
|
||||||
GPT language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (GPT2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the LoRa layer.
|
|
||||||
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[GPT2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = GPT2LMHeadModel(config)
|
|
||||||
else:
|
|
||||||
model = GPT2LMHeadModel(GPT2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
|
||||||
|
|
||||||
from ..base import RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
class GPTRM(RewardModel):
|
|
||||||
"""
|
|
||||||
GPT Reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (GPT2Config): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the low-rank approximation.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[GPT2Config] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = GPT2Model.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = GPT2Model(config)
|
|
||||||
else:
|
|
||||||
model = GPT2Model(GPT2Config())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.n_embd, 1)
|
|
||||||
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,6 +0,0 @@
|
|||||||
from .llama_actor import LlamaActor
|
|
||||||
from .llama_critic import LlamaCritic
|
|
||||||
from .llama_rm import LlamaRM
|
|
||||||
from .llama_lm import LlamaLM
|
|
||||||
|
|
||||||
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
|
||||||
|
|
||||||
from ..base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaActor(Actor):
|
|
||||||
"""
|
|
||||||
Llama Actor model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (LlamaConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[LlamaConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
|
|
||||||
if pretrained is not None:
|
|
||||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = LlamaForCausalLM(LlamaConfig())
|
|
||||||
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -1,42 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
|
||||||
|
|
||||||
from ..base import Critic
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaCritic(Critic):
|
|
||||||
"""
|
|
||||||
Llama Critic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (LlamaConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[LlamaConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none',
|
|
||||||
**kwargs) -> None:
|
|
||||||
|
|
||||||
if pretrained is not None:
|
|
||||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = LlamaForCausalLM(LlamaConfig())
|
|
||||||
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -1,40 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM
|
|
||||||
|
|
||||||
from ..base import LM
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaLM(LM):
|
|
||||||
"""
|
|
||||||
Llama language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (LlamaConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[LlamaConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
|
|
||||||
if pretrained is not None:
|
|
||||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = LlamaForCausalLM(LlamaConfig())
|
|
||||||
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
|
||||||
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
|
@ -1,41 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM
|
|
||||||
|
|
||||||
from ..base import RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRM(RewardModel):
|
|
||||||
"""
|
|
||||||
Llama Reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (LlamaConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): LoRA rank.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[LlamaConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
|
|
||||||
if pretrained is not None:
|
|
||||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = LlamaForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = LlamaForCausalLM(LlamaConfig())
|
|
||||||
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
|
||||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
|
||||||
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -1,130 +0,0 @@
|
|||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
|
||||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
weight: nn.Parameter,
|
|
||||||
bias: Optional[nn.Parameter],
|
|
||||||
r: int = 0,
|
|
||||||
lora_alpha: int = 1,
|
|
||||||
lora_dropout: float = 0.,
|
|
||||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
|
||||||
merge_weights: bool = True,
|
|
||||||
):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
lora.LoRALayer.__init__(self,
|
|
||||||
r=r,
|
|
||||||
lora_alpha=lora_alpha,
|
|
||||||
lora_dropout=lora_dropout,
|
|
||||||
merge_weights=merge_weights)
|
|
||||||
self.weight = weight
|
|
||||||
self.bias = bias
|
|
||||||
|
|
||||||
out_features, in_features = weight.shape
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
|
|
||||||
self.fan_in_fan_out = fan_in_fan_out
|
|
||||||
# Actual trainable parameters
|
|
||||||
if r > 0:
|
|
||||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
|
||||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
|
||||||
self.scaling = self.lora_alpha / self.r
|
|
||||||
# Freezing the pre-trained weight matrix
|
|
||||||
self.weight.requires_grad = False
|
|
||||||
self.reset_parameters()
|
|
||||||
if fan_in_fan_out:
|
|
||||||
self.weight.data = self.weight.data.T
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
if hasattr(self, 'lora_A'):
|
|
||||||
# initialize A the same way as the default for nn.Linear and B to zero
|
|
||||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
|
||||||
nn.init.zeros_(self.lora_B)
|
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
|
||||||
|
|
||||||
def T(w):
|
|
||||||
return w.T if self.fan_in_fan_out else w
|
|
||||||
|
|
||||||
nn.Module.train(self, mode)
|
|
||||||
if self.merge_weights and self.merged:
|
|
||||||
# Make sure that the weights are not merged
|
|
||||||
if self.r > 0:
|
|
||||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
self.merged = False
|
|
||||||
|
|
||||||
def eval(self):
|
|
||||||
|
|
||||||
def T(w):
|
|
||||||
return w.T if self.fan_in_fan_out else w
|
|
||||||
|
|
||||||
nn.Module.eval(self)
|
|
||||||
if self.merge_weights and not self.merged:
|
|
||||||
# Merge the weights and mark it
|
|
||||||
if self.r > 0:
|
|
||||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
delattr(self, 'lora_A')
|
|
||||||
delattr(self, 'lora_B')
|
|
||||||
self.merged = True
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
|
|
||||||
def T(w):
|
|
||||||
return w.T if self.fan_in_fan_out else w
|
|
||||||
|
|
||||||
if self.r > 0 and not self.merged:
|
|
||||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
|
||||||
if self.r > 0:
|
|
||||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
return F.linear(x, T(self.weight), bias=self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
|
||||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
|
||||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
|
||||||
return lora_linear
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
|
||||||
for name, child in module.named_children():
|
|
||||||
if isinstance(child, nn.Linear):
|
|
||||||
setattr(module, name, lora_linear_wrapper(child, lora_rank))
|
|
||||||
else:
|
|
||||||
convert_to_lora_recursively(child, lora_rank)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(nn.Module):
|
|
||||||
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
|
|
||||||
This calss will convert all torch.nn.Linear layer to LoraLinear layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
|
|
||||||
lora_train_bias (str, optional): Whether LoRA train biases.
|
|
||||||
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
|
|
||||||
Defaults to 'none'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.lora_rank = lora_rank
|
|
||||||
self.lora_train_bias = lora_train_bias
|
|
||||||
|
|
||||||
def convert_to_lora(self) -> None:
|
|
||||||
if self.lora_rank <= 0:
|
|
||||||
return
|
|
||||||
convert_to_lora_recursively(self, self.lora_rank)
|
|
||||||
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
|
|
||||||
|
|
@ -1,115 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from .utils import masked_mean
|
|
||||||
|
|
||||||
|
|
||||||
class GPTLMLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
GPT Language Model Loss
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.loss = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Policy Loss for PPO
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, clip_eps: float = 0.2) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.clip_eps = clip_eps
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
log_probs: torch.Tensor,
|
|
||||||
old_log_probs: torch.Tensor,
|
|
||||||
advantages: torch.Tensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
ratio = (log_probs - old_log_probs).exp()
|
|
||||||
surr1 = ratio * advantages
|
|
||||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
|
||||||
loss = -torch.min(surr1, surr2)
|
|
||||||
if action_mask is not None:
|
|
||||||
loss = masked_mean(loss, action_mask)
|
|
||||||
loss = loss.mean()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class ValueLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Value Loss for PPO
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, clip_eps: float = 0.4) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.clip_eps = clip_eps
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
values: torch.Tensor,
|
|
||||||
old_values: torch.Tensor,
|
|
||||||
reward: torch.Tensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
|
|
||||||
surr1 = (values_clipped - reward)**2
|
|
||||||
surr2 = (values - reward)**2
|
|
||||||
loss = torch.max(surr1, surr2)
|
|
||||||
loss = loss.mean()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class PPOPtxActorLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
To Do:
|
|
||||||
|
|
||||||
PPO-ptx Actor Loss
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.pretrain_coef = pretrain_coef
|
|
||||||
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
|
|
||||||
self.pretrain_loss_fn = pretrain_loss_fn
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
log_probs: torch.Tensor,
|
|
||||||
old_log_probs: torch.Tensor,
|
|
||||||
advantages: torch.Tensor,
|
|
||||||
lm_logits: torch.Tensor,
|
|
||||||
lm_input_ids: torch.Tensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
|
|
||||||
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
|
|
||||||
return policy_loss + self.pretrain_coef * lm_loss
|
|
||||||
|
|
||||||
|
|
||||||
class LogSigLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Pairwise Loss for Reward Model
|
|
||||||
Details: https://arxiv.org/abs/2203.02155
|
|
||||||
"""
|
|
||||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
|
||||||
probs = torch.sigmoid(chosen_reward - reject_reward)
|
|
||||||
log_probs = torch.log(probs)
|
|
||||||
loss = -log_probs.mean()
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
class LogExpLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Pairwise Loss for Reward Model
|
|
||||||
Details: https://arxiv.org/abs/2204.05862
|
|
||||||
"""
|
|
||||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
|
||||||
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
|
||||||
return loss
|
|
@ -1,6 +0,0 @@
|
|||||||
from .opt_actor import OPTActor
|
|
||||||
from .opt_critic import OPTCritic
|
|
||||||
from .opt_rm import OPTRM
|
|
||||||
from .opt_lm import OPTLM
|
|
||||||
|
|
||||||
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
|
|
@ -1,35 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.models.opt.configuration_opt import OPTConfig
|
|
||||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
|
||||||
|
|
||||||
from ..base import Actor
|
|
||||||
|
|
||||||
|
|
||||||
class OPTActor(Actor):
|
|
||||||
"""
|
|
||||||
OPT Actor model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (OPTConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the low-rank approximation.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[OPTConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = OPTForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = OPTForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = OPTForCausalLM(OPTConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers.models.opt.configuration_opt import OPTConfig
|
|
||||||
from transformers.models.opt.modeling_opt import OPTModel
|
|
||||||
|
|
||||||
from ..base import Critic
|
|
||||||
|
|
||||||
|
|
||||||
class OPTCritic(Critic):
|
|
||||||
"""
|
|
||||||
OPT Critic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (OPTConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the low-rank approximation.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[OPTConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none',
|
|
||||||
**kwargs) -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = OPTModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = OPTModel(config)
|
|
||||||
else:
|
|
||||||
model = OPTModel(OPTConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -1,36 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers.models.opt.configuration_opt import OPTConfig
|
|
||||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
|
||||||
|
|
||||||
from ..base import LM
|
|
||||||
|
|
||||||
|
|
||||||
class OPTLM(LM):
|
|
||||||
"""
|
|
||||||
OPT language model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (OPTConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the low-rank approximation.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[OPTConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = OPTForCausalLM.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = OPTForCausalLM(config)
|
|
||||||
else:
|
|
||||||
model = OPTForCausalLM(OPTConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from transformers import OPTConfig, OPTModel
|
|
||||||
|
|
||||||
from ..base import RewardModel
|
|
||||||
|
|
||||||
|
|
||||||
class OPTRM(RewardModel):
|
|
||||||
"""
|
|
||||||
OPT Reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pretrained (str): Pretrained model name or path.
|
|
||||||
config (OPTConfig): Model config.
|
|
||||||
checkpoint (bool): Enable gradient checkpointing.
|
|
||||||
lora_rank (int): Rank of the low-rank approximation.
|
|
||||||
lora_train_bias (str): LoRA bias training mode.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
pretrained: Optional[str] = None,
|
|
||||||
config: Optional[OPTConfig] = None,
|
|
||||||
checkpoint: bool = False,
|
|
||||||
lora_rank: int = 0,
|
|
||||||
lora_train_bias: str = 'none') -> None:
|
|
||||||
if pretrained is not None:
|
|
||||||
model = OPTModel.from_pretrained(pretrained)
|
|
||||||
elif config is not None:
|
|
||||||
model = OPTModel(config)
|
|
||||||
else:
|
|
||||||
model = OPTModel(OPTConfig())
|
|
||||||
if checkpoint:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
|
|
||||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
|
||||||
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
|
|
||||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -1,92 +0,0 @@
|
|||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def compute_approx_kl(log_probs: torch.Tensor,
|
|
||||||
log_probs_base: torch.Tensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute the approximate KL divergence between two distributions.
|
|
||||||
Schulman blog: http://joschu.net/blog/kl-approx.html
|
|
||||||
|
|
||||||
Args:
|
|
||||||
log_probs: Log probabilities of the new distribution.
|
|
||||||
log_probs_base: Log probabilities of the base distribution.
|
|
||||||
action_mask: Mask for actions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
log_ratio = log_probs - log_probs_base
|
|
||||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
|
||||||
if action_mask is not None:
|
|
||||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
|
||||||
return approx_kl
|
|
||||||
approx_kl = approx_kl.mean(dim=1)
|
|
||||||
return approx_kl
|
|
||||||
|
|
||||||
|
|
||||||
def compute_reward(r: Union[torch.Tensor, float],
|
|
||||||
kl_coef: float,
|
|
||||||
log_probs: torch.Tensor,
|
|
||||||
log_probs_base: torch.Tensor,
|
|
||||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
if kl_coef <= 0.0:
|
|
||||||
return r
|
|
||||||
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
|
||||||
reward = r - kl_coef * kl
|
|
||||||
return reward
|
|
||||||
|
|
||||||
|
|
||||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
|
||||||
log_probs = F.log_softmax(logits, dim=-1)
|
|
||||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
|
||||||
return log_probs_labels.squeeze(-1)
|
|
||||||
|
|
||||||
|
|
||||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
|
||||||
tensor = tensor * mask
|
|
||||||
tensor = tensor.sum(dim=dim)
|
|
||||||
mask_sum = mask.sum(dim=dim)
|
|
||||||
mean = tensor / (mask_sum + 1e-8)
|
|
||||||
return mean
|
|
||||||
|
|
||||||
|
|
||||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
|
|
||||||
tensor = tensor * mask
|
|
||||||
mean = masked_mean(tensor, mask, dim=dim)
|
|
||||||
mean_centered = tensor - mean
|
|
||||||
var = masked_mean(mean_centered**2, mask, dim=dim)
|
|
||||||
return mean_centered * var.clamp(min=eps).rsqrt()
|
|
||||||
|
|
||||||
|
|
||||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
|
|
||||||
mean = tensor.mean(dim)
|
|
||||||
mean_centered = tensor - mean
|
|
||||||
var = (mean_centered**2).mean(dim)
|
|
||||||
norm = mean_centered * var.clamp(min=eps).rsqrt()
|
|
||||||
return norm
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_lora(model: nn.Module,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
lora_rank: int = 16,
|
|
||||||
lora_alpha: int = 1,
|
|
||||||
lora_dropout: float = 0.,
|
|
||||||
fan_in_fan_out: bool = False,
|
|
||||||
merge_weights: bool = True):
|
|
||||||
if lora_rank > min(input_size, output_size):
|
|
||||||
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
|
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
module._modules[name] = lora.Linear(input_size,
|
|
||||||
output_size,
|
|
||||||
r=lora_rank,
|
|
||||||
lora_alpha=lora_alpha,
|
|
||||||
lora_dropout=lora_dropout,
|
|
||||||
fan_in_fan_out=fan_in_fan_out,
|
|
||||||
merge_weights=merge_weights)
|
|
@ -1,4 +0,0 @@
|
|||||||
from .base import ReplayBuffer
|
|
||||||
from .naive import NaiveReplayBuffer
|
|
||||||
|
|
||||||
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
|
|
@ -1,43 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from chatgpt.experience_maker.base import Experience
|
|
||||||
|
|
||||||
|
|
||||||
class ReplayBuffer(ABC):
|
|
||||||
"""Replay buffer base class. It stores experience.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_batch_size (int): Batch size when sampling.
|
|
||||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.sample_batch_size = sample_batch_size
|
|
||||||
# limit <= 0 means unlimited
|
|
||||||
self.limit = limit
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def append(self, experience: Experience) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def sample(self) -> Experience:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __len__(self) -> int:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __getitem__(self, idx: int) -> Any:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collate_fn(self, batch: Any) -> Experience:
|
|
||||||
pass
|
|
@ -1,57 +0,0 @@
|
|||||||
import random
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from chatgpt.experience_maker.base import Experience
|
|
||||||
|
|
||||||
from .base import ReplayBuffer
|
|
||||||
from .utils import BufferItem, make_experience_batch, split_experience_batch
|
|
||||||
|
|
||||||
|
|
||||||
class NaiveReplayBuffer(ReplayBuffer):
|
|
||||||
"""Naive replay buffer class. It stores experience.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sample_batch_size (int): Batch size when sampling.
|
|
||||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
|
||||||
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
|
|
||||||
super().__init__(sample_batch_size, limit)
|
|
||||||
self.cpu_offload = cpu_offload
|
|
||||||
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
|
|
||||||
# TODO(ver217): add prefetch
|
|
||||||
self.items: List[BufferItem] = []
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def append(self, experience: Experience) -> None:
|
|
||||||
if self.cpu_offload:
|
|
||||||
experience.to_device(torch.device('cpu'))
|
|
||||||
items = split_experience_batch(experience)
|
|
||||||
self.items.extend(items)
|
|
||||||
if self.limit > 0:
|
|
||||||
samples_to_remove = len(self.items) - self.limit
|
|
||||||
if samples_to_remove > 0:
|
|
||||||
self.items = self.items[samples_to_remove:]
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self.items.clear()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample(self) -> Experience:
|
|
||||||
items = random.sample(self.items, self.sample_batch_size)
|
|
||||||
experience = make_experience_batch(items)
|
|
||||||
if self.cpu_offload:
|
|
||||||
experience.to_device(self.target_device)
|
|
||||||
return experience
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.items)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> BufferItem:
|
|
||||||
return self.items[idx]
|
|
||||||
|
|
||||||
def collate_fn(self, batch) -> Experience:
|
|
||||||
experience = make_experience_batch(batch)
|
|
||||||
return experience
|
|
@ -1,73 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from chatgpt.experience_maker.base import Experience
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BufferItem:
|
|
||||||
"""BufferItem is an item of experience data.
|
|
||||||
|
|
||||||
Shapes of each tensor:
|
|
||||||
sequences: (S)
|
|
||||||
action_log_probs: (A)
|
|
||||||
values: (1)
|
|
||||||
reward: (1)
|
|
||||||
advatanges: (1)
|
|
||||||
attention_mask: (S)
|
|
||||||
action_mask: (A)
|
|
||||||
|
|
||||||
"A" is the number of actions.
|
|
||||||
"""
|
|
||||||
sequences: torch.Tensor
|
|
||||||
action_log_probs: torch.Tensor
|
|
||||||
values: torch.Tensor
|
|
||||||
reward: torch.Tensor
|
|
||||||
advantages: torch.Tensor
|
|
||||||
attention_mask: Optional[torch.LongTensor]
|
|
||||||
action_mask: Optional[torch.BoolTensor]
|
|
||||||
|
|
||||||
|
|
||||||
def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
|
||||||
batch_size = experience.sequences.size(0)
|
|
||||||
batch_kwargs = [{} for _ in range(batch_size)]
|
|
||||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
|
||||||
for key in keys:
|
|
||||||
value = getattr(experience, key)
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
vals = torch.unbind(value)
|
|
||||||
else:
|
|
||||||
# None
|
|
||||||
vals = [value for _ in range(batch_size)]
|
|
||||||
assert batch_size == len(vals)
|
|
||||||
for i, v in enumerate(vals):
|
|
||||||
batch_kwargs[i][key] = v
|
|
||||||
items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
|
||||||
assert side in ('left', 'right')
|
|
||||||
max_len = max(seq.size(0) for seq in sequences)
|
|
||||||
padded_sequences = []
|
|
||||||
for seq in sequences:
|
|
||||||
pad_len = max_len - seq.size(0)
|
|
||||||
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
|
|
||||||
padded_sequences.append(F.pad(seq, padding))
|
|
||||||
return torch.stack(padded_sequences, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def make_experience_batch(items: List[BufferItem]) -> Experience:
|
|
||||||
kwargs = {}
|
|
||||||
to_pad_keys = set(('action_log_probs', 'action_mask'))
|
|
||||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
|
||||||
for key in keys:
|
|
||||||
vals = [getattr(item, key) for item in items]
|
|
||||||
if key in to_pad_keys:
|
|
||||||
batch_data = zero_pad_sequences(vals)
|
|
||||||
else:
|
|
||||||
batch_data = torch.stack(vals, dim=0)
|
|
||||||
kwargs[key] = batch_data
|
|
||||||
return Experience(**kwargs)
|
|
@ -1,6 +0,0 @@
|
|||||||
from .base import Trainer
|
|
||||||
from .ppo import PPOTrainer
|
|
||||||
from .rm import RewardModelTrainer
|
|
||||||
from .sft import SFTTrainer
|
|
||||||
|
|
||||||
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
|
|
@ -1,162 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from chatgpt.experience_maker import Experience, ExperienceMaker
|
|
||||||
from chatgpt.replay_buffer import ReplayBuffer
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.utils.data import DistributedSampler
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .callbacks import Callback
|
|
||||||
from .strategies import Strategy
|
|
||||||
from .utils import is_rank_0
|
|
||||||
|
|
||||||
|
|
||||||
class Trainer(ABC):
|
|
||||||
"""
|
|
||||||
Base class for rlhf trainers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
strategy (Strategy):the strategy to use for training
|
|
||||||
experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
|
|
||||||
replay_buffer (ReplayBuffer): the replay buffer to use for training
|
|
||||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
|
||||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
|
||||||
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
|
|
||||||
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
|
|
||||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
|
||||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
|
||||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
strategy: Strategy,
|
|
||||||
experience_maker: ExperienceMaker,
|
|
||||||
replay_buffer: ReplayBuffer,
|
|
||||||
experience_batch_size: int = 8,
|
|
||||||
max_epochs: int = 1,
|
|
||||||
tokenizer: Optional[Callable[[Any], dict]] = None,
|
|
||||||
sample_replay_buffer: bool = False,
|
|
||||||
dataloader_pin_memory: bool = True,
|
|
||||||
callbacks: List[Callback] = [],
|
|
||||||
**generate_kwargs) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.strategy = strategy
|
|
||||||
self.experience_maker = experience_maker
|
|
||||||
self.replay_buffer = replay_buffer
|
|
||||||
self.experience_batch_size = experience_batch_size
|
|
||||||
self.max_epochs = max_epochs
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.generate_kwargs = generate_kwargs
|
|
||||||
self.sample_replay_buffer = sample_replay_buffer
|
|
||||||
self.dataloader_pin_memory = dataloader_pin_memory
|
|
||||||
self.callbacks = callbacks
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
|
|
||||||
if isinstance(inputs, Tensor):
|
|
||||||
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
|
|
||||||
elif isinstance(inputs, dict):
|
|
||||||
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported input type "{type(inputs)}"')
|
|
||||||
|
|
||||||
def _sample_prompts(self, prompts) -> list:
|
|
||||||
indices = list(range(len(prompts)))
|
|
||||||
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
|
|
||||||
return [prompts[i] for i in sampled_indices]
|
|
||||||
|
|
||||||
def _learn(self):
|
|
||||||
# replay buffer may be empty at first, we should rebuild at each training
|
|
||||||
if not self.sample_replay_buffer:
|
|
||||||
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
|
|
||||||
device = torch.cuda.current_device()
|
|
||||||
if self.sample_replay_buffer:
|
|
||||||
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
|
|
||||||
for _ in pbar:
|
|
||||||
experience = self.replay_buffer.sample()
|
|
||||||
metrics = self.training_step(experience)
|
|
||||||
pbar.set_postfix(metrics)
|
|
||||||
else:
|
|
||||||
for epoch in range(self.max_epochs):
|
|
||||||
self._on_learn_epoch_start(epoch)
|
|
||||||
if isinstance(dataloader.sampler, DistributedSampler):
|
|
||||||
dataloader.sampler.set_epoch(epoch)
|
|
||||||
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
|
|
||||||
for experience in pbar:
|
|
||||||
self._on_learn_batch_start()
|
|
||||||
experience.to_device(device)
|
|
||||||
metrics = self.training_step(experience)
|
|
||||||
self._on_learn_batch_end(metrics, experience)
|
|
||||||
pbar.set_postfix(metrics)
|
|
||||||
self._on_learn_epoch_end(epoch)
|
|
||||||
|
|
||||||
def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
|
|
||||||
time = 0
|
|
||||||
sampler = self.strategy.setup_sampler(prompts)
|
|
||||||
self._on_fit_start()
|
|
||||||
for episode in range(num_episodes):
|
|
||||||
self._on_episode_start(episode)
|
|
||||||
for timestep in tqdm(range(max_timesteps),
|
|
||||||
desc=f'Episode [{episode+1}/{num_episodes}]',
|
|
||||||
disable=not is_rank_0()):
|
|
||||||
time += 1
|
|
||||||
rand_prompts = sampler.sample(self.experience_batch_size)
|
|
||||||
if self.tokenizer is not None:
|
|
||||||
inputs = self.tokenizer(rand_prompts)
|
|
||||||
else:
|
|
||||||
inputs = rand_prompts
|
|
||||||
self._on_make_experience_start()
|
|
||||||
experience = self._make_experience(inputs)
|
|
||||||
self._on_make_experience_end(experience)
|
|
||||||
self.replay_buffer.append(experience)
|
|
||||||
if time % update_timesteps == 0:
|
|
||||||
self._learn()
|
|
||||||
self.replay_buffer.clear()
|
|
||||||
self._on_episode_end(episode)
|
|
||||||
self._on_fit_end()
|
|
||||||
|
|
||||||
# TODO(ver217): maybe simplify these code using context
|
|
||||||
def _on_fit_start(self) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_fit_start()
|
|
||||||
|
|
||||||
def _on_fit_end(self) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_fit_end()
|
|
||||||
|
|
||||||
def _on_episode_start(self, episode: int) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_episode_start(episode)
|
|
||||||
|
|
||||||
def _on_episode_end(self, episode: int) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_episode_end(episode)
|
|
||||||
|
|
||||||
def _on_make_experience_start(self) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_make_experience_start()
|
|
||||||
|
|
||||||
def _on_make_experience_end(self, experience: Experience) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_make_experience_end(experience)
|
|
||||||
|
|
||||||
def _on_learn_epoch_start(self, epoch: int) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_learn_epoch_start(epoch)
|
|
||||||
|
|
||||||
def _on_learn_epoch_end(self, epoch: int) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_learn_epoch_end(epoch)
|
|
||||||
|
|
||||||
def _on_learn_batch_start(self) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_learn_batch_start()
|
|
||||||
|
|
||||||
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_learn_batch_end(metrics, experience)
|
|
@ -1,5 +0,0 @@
|
|||||||
from .base import Callback
|
|
||||||
from .performance_evaluator import PerformanceEvaluator
|
|
||||||
from .save_checkpoint import SaveCheckpoint
|
|
||||||
|
|
||||||
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
|
|
@ -1,39 +0,0 @@
|
|||||||
from abc import ABC
|
|
||||||
|
|
||||||
from chatgpt.experience_maker import Experience
|
|
||||||
|
|
||||||
|
|
||||||
class Callback(ABC):
|
|
||||||
"""
|
|
||||||
Base callback class. It defines the interface for callbacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_fit_start(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_fit_end(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_episode_start(self, episode: int) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_episode_end(self, episode: int) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_make_experience_start(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_make_experience_end(self, experience: Experience) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_learn_epoch_start(self, epoch: int) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_learn_epoch_end(self, epoch: int) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_learn_batch_start(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
|
||||||
pass
|
|
@ -1,133 +0,0 @@
|
|||||||
from time import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from chatgpt.experience_maker import Experience
|
|
||||||
|
|
||||||
from .base import Callback
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size() -> int:
|
|
||||||
if dist.is_initialized():
|
|
||||||
return dist.get_world_size()
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
def print_rank_0(*args, **kwargs) -> None:
|
|
||||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
|
||||||
print(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
|
||||||
if world_size == 1:
|
|
||||||
return x
|
|
||||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
|
||||||
dist.all_reduce(tensor)
|
|
||||||
tensor = tensor / world_size
|
|
||||||
return tensor.item()
|
|
||||||
|
|
||||||
|
|
||||||
class PerformanceEvaluator(Callback):
|
|
||||||
"""
|
|
||||||
Callback for valuate the performance of the model.
|
|
||||||
Args:
|
|
||||||
actor_num_params: The number of parameters of the actor model.
|
|
||||||
critic_num_params: The number of parameters of the critic model.
|
|
||||||
initial_model_num_params: The number of parameters of the initial model.
|
|
||||||
reward_model_num_params: The number of parameters of the reward model.
|
|
||||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
|
||||||
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
actor_num_params: int,
|
|
||||||
critic_num_params: int,
|
|
||||||
initial_model_num_params: int,
|
|
||||||
reward_model_num_params: int,
|
|
||||||
enable_grad_checkpoint: bool = False,
|
|
||||||
ignore_episodes: int = 0) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.world_size = get_world_size()
|
|
||||||
self.actor_num_params = actor_num_params
|
|
||||||
self.critic_num_params = critic_num_params
|
|
||||||
self.initial_model_num_params = initial_model_num_params
|
|
||||||
self.reward_model_num_params = reward_model_num_params
|
|
||||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
|
||||||
self.ignore_episodes = ignore_episodes
|
|
||||||
self.disable: bool = False
|
|
||||||
|
|
||||||
self.make_experience_duration: float = 0.
|
|
||||||
self.make_experience_start_time: Optional[float] = None
|
|
||||||
self.make_experience_num_samples: int = 0
|
|
||||||
self.make_experience_flop: int = 0
|
|
||||||
self.learn_duration: float = 0.
|
|
||||||
self.learn_start_time: Optional[float] = None
|
|
||||||
self.learn_num_samples: int = 0
|
|
||||||
self.learn_flop: int = 0
|
|
||||||
|
|
||||||
def on_episode_start(self, episode: int) -> None:
|
|
||||||
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
|
|
||||||
|
|
||||||
def on_make_experience_start(self) -> None:
|
|
||||||
if self.disable:
|
|
||||||
return
|
|
||||||
self.make_experience_start_time = time()
|
|
||||||
|
|
||||||
def on_make_experience_end(self, experience: Experience) -> None:
|
|
||||||
if self.disable:
|
|
||||||
return
|
|
||||||
self.make_experience_duration += time() - self.make_experience_start_time
|
|
||||||
|
|
||||||
batch_size, seq_len = experience.sequences.shape
|
|
||||||
|
|
||||||
self.make_experience_num_samples += batch_size
|
|
||||||
|
|
||||||
# actor generate
|
|
||||||
num_actions = experience.action_mask.size(1)
|
|
||||||
input_len = seq_len - num_actions
|
|
||||||
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
|
|
||||||
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
|
|
||||||
# actor forward
|
|
||||||
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
|
|
||||||
# critic forward
|
|
||||||
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
|
|
||||||
# initial model forward
|
|
||||||
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
|
|
||||||
# reward model forward
|
|
||||||
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
|
|
||||||
|
|
||||||
def on_learn_batch_start(self) -> None:
|
|
||||||
if self.disable:
|
|
||||||
return
|
|
||||||
self.learn_start_time = time()
|
|
||||||
|
|
||||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
|
||||||
if self.disable:
|
|
||||||
return
|
|
||||||
self.learn_duration += time() - self.learn_start_time
|
|
||||||
|
|
||||||
batch_size, seq_len = experience.sequences.shape
|
|
||||||
|
|
||||||
self.learn_num_samples += batch_size
|
|
||||||
|
|
||||||
# actor forward-backward, 3 means forward(1) + backward(2)
|
|
||||||
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
|
||||||
# critic foward-backward
|
|
||||||
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
|
||||||
|
|
||||||
def on_fit_end(self) -> None:
|
|
||||||
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
|
|
||||||
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
|
|
||||||
|
|
||||||
avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
|
|
||||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
|
||||||
|
|
||||||
avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
|
|
||||||
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
|
|
||||||
|
|
||||||
print_rank_0(
|
|
||||||
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
|
|
||||||
)
|
|
||||||
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
|
|
@ -1,75 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
|
|
||||||
from chatgpt.trainer.utils import is_rank_0
|
|
||||||
from torch import nn
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
from .base import Callback
|
|
||||||
|
|
||||||
|
|
||||||
class SaveCheckpoint(Callback):
|
|
||||||
"""
|
|
||||||
The callback for saving checkpoint for chatgpt.
|
|
||||||
|
|
||||||
Only support saving actor and critic model.
|
|
||||||
A typical architecture of the saved checkpoint would be:
|
|
||||||
- checkpoint
|
|
||||||
- episode_x
|
|
||||||
- actor.pt
|
|
||||||
- actor-optim-rank-0.pt
|
|
||||||
- actor-optim-rank-1.pt
|
|
||||||
- critic.pt
|
|
||||||
- critic-optim-rank-0.pt
|
|
||||||
- critic-optim-rank-1.pt
|
|
||||||
- ...
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
|
|
||||||
interval(int): the interval episode of saving checkpoint
|
|
||||||
strategy(Strategy): the strategy used to train
|
|
||||||
actor(nn.Module): the actor model
|
|
||||||
critic(nn.Module): the critic model
|
|
||||||
actor_optim(Optimizer): the optimizer of actor
|
|
||||||
critic_optim(Optimizer): the optimizer of critic
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
path: str,
|
|
||||||
interval: int,
|
|
||||||
strategy: Strategy,
|
|
||||||
actor: nn.Module = None,
|
|
||||||
critic: nn.Module = None,
|
|
||||||
actor_optim: Optimizer = None,
|
|
||||||
critic_optim: Optimizer = None) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.path = os.path.join(path, 'checkpoint')
|
|
||||||
self.interval = interval
|
|
||||||
self.strategy = strategy
|
|
||||||
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
|
|
||||||
|
|
||||||
def on_episode_end(self, episode: int) -> None:
|
|
||||||
if (episode + 1) % self.interval != 0:
|
|
||||||
return
|
|
||||||
base_path = os.path.join(self.path, f'episode_{episode}')
|
|
||||||
if not os.path.exists(base_path):
|
|
||||||
os.makedirs(base_path)
|
|
||||||
|
|
||||||
for model in self.model_dict.keys():
|
|
||||||
|
|
||||||
# save model
|
|
||||||
if self.model_dict[model][0] is None:
|
|
||||||
# saving only optimizer states is meaningless, so it would be skipped
|
|
||||||
continue
|
|
||||||
model_path = os.path.join(base_path, f'{model}.pt')
|
|
||||||
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
|
|
||||||
|
|
||||||
# save optimizer
|
|
||||||
if self.model_dict[model][1] is None:
|
|
||||||
continue
|
|
||||||
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
|
|
||||||
rank = 0 if is_rank_0() else dist.get_rank()
|
|
||||||
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
|
|
||||||
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
|
|
@ -1,116 +0,0 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.experience_maker import Experience, NaiveExperienceMaker
|
|
||||||
from chatgpt.models.base import Actor, Critic
|
|
||||||
from chatgpt.models.generation_utils import update_model_kwargs_fn
|
|
||||||
from chatgpt.models.loss import PolicyLoss, ValueLoss
|
|
||||||
from chatgpt.replay_buffer import NaiveReplayBuffer
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
from .base import Trainer
|
|
||||||
from .callbacks import Callback
|
|
||||||
from .strategies import Strategy
|
|
||||||
|
|
||||||
|
|
||||||
class PPOTrainer(Trainer):
|
|
||||||
"""
|
|
||||||
Trainer for PPO algorithm.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
strategy (Strategy): the strategy to use for training
|
|
||||||
actor (Actor): the actor model in ppo algorithm
|
|
||||||
critic (Critic): the critic model in ppo algorithm
|
|
||||||
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
|
|
||||||
initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
|
|
||||||
actor_optim (Optimizer): the optimizer to use for actor model
|
|
||||||
critic_optim (Optimizer): the optimizer to use for critic model
|
|
||||||
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
|
|
||||||
train_batch_size (int, defaults to 8): the batch size to use for training
|
|
||||||
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
|
|
||||||
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
|
|
||||||
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
|
|
||||||
value_clip (float, defaults to 0.4): the clip coefficient of value loss
|
|
||||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
|
||||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
|
||||||
tokenier (Callable, optional): the tokenizer to use for tokenizing the input
|
|
||||||
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
|
|
||||||
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
|
||||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
|
||||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
strategy: Strategy,
|
|
||||||
actor: Actor,
|
|
||||||
critic: Critic,
|
|
||||||
reward_model: nn.Module,
|
|
||||||
initial_model: Actor,
|
|
||||||
actor_optim: Optimizer,
|
|
||||||
critic_optim: Optimizer,
|
|
||||||
kl_coef: float = 0.1,
|
|
||||||
train_batch_size: int = 8,
|
|
||||||
buffer_limit: int = 0,
|
|
||||||
buffer_cpu_offload: bool = True,
|
|
||||||
eps_clip: float = 0.2,
|
|
||||||
value_clip: float = 0.4,
|
|
||||||
experience_batch_size: int = 8,
|
|
||||||
max_epochs: int = 1,
|
|
||||||
tokenizer: Optional[Callable[[Any], dict]] = None,
|
|
||||||
sample_replay_buffer: bool = False,
|
|
||||||
dataloader_pin_memory: bool = True,
|
|
||||||
callbacks: List[Callback] = [],
|
|
||||||
**generate_kwargs) -> None:
|
|
||||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
|
|
||||||
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
|
||||||
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
|
||||||
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
|
|
||||||
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
|
|
||||||
self.actor = actor
|
|
||||||
self.critic = critic
|
|
||||||
|
|
||||||
self.actor_loss_fn = PolicyLoss(eps_clip)
|
|
||||||
self.critic_loss_fn = ValueLoss(value_clip)
|
|
||||||
|
|
||||||
self.actor_optim = actor_optim
|
|
||||||
self.critic_optim = critic_optim
|
|
||||||
|
|
||||||
def training_step(self, experience: Experience) -> Dict[str, float]:
|
|
||||||
self.actor.train()
|
|
||||||
self.critic.train()
|
|
||||||
|
|
||||||
num_actions = experience.action_mask.size(1)
|
|
||||||
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
|
|
||||||
actor_loss = self.actor_loss_fn(action_log_probs,
|
|
||||||
experience.action_log_probs,
|
|
||||||
experience.advantages,
|
|
||||||
action_mask=experience.action_mask)
|
|
||||||
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
|
|
||||||
self.strategy.optimizer_step(self.actor_optim)
|
|
||||||
self.actor_optim.zero_grad()
|
|
||||||
|
|
||||||
values = self.critic(experience.sequences,
|
|
||||||
action_mask=experience.action_mask,
|
|
||||||
attention_mask=experience.attention_mask)
|
|
||||||
critic_loss = self.critic_loss_fn(values,
|
|
||||||
experience.values,
|
|
||||||
experience.reward,
|
|
||||||
action_mask=experience.action_mask)
|
|
||||||
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
|
|
||||||
self.strategy.optimizer_step(self.critic_optim)
|
|
||||||
self.critic_optim.zero_grad()
|
|
||||||
|
|
||||||
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
|
|
||||||
|
|
||||||
|
|
||||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
|
|
||||||
origin_model = strategy._unwrap_actor(actor)
|
|
||||||
new_kwargs = {**generate_kwargs}
|
|
||||||
# use huggingface models method directly
|
|
||||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
|
||||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
|
||||||
|
|
||||||
if 'update_model_kwargs_fn' not in generate_kwargs:
|
|
||||||
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
|
|
||||||
|
|
||||||
return new_kwargs
|
|
@ -1,120 +0,0 @@
|
|||||||
from abc import ABC
|
|
||||||
import pandas as pd
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
from datetime import datetime
|
|
||||||
from torch.optim import Optimizer, lr_scheduler
|
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .strategies import Strategy
|
|
||||||
from .utils import is_rank_0
|
|
||||||
|
|
||||||
|
|
||||||
class RewardModelTrainer(ABC):
|
|
||||||
"""
|
|
||||||
Trainer to use while training reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (torch.nn.Module): the model to train
|
|
||||||
strategy (Strategy): the strategy to use for training
|
|
||||||
optim(Optimizer): the optimizer to use for training
|
|
||||||
loss_fn (callable): the loss function to use for training
|
|
||||||
train_dataset (Dataset): the dataset to use for training
|
|
||||||
valid_dataset (Dataset): the dataset to use for validation
|
|
||||||
eval_dataset (Dataset): the dataset to use for evaluation
|
|
||||||
batch_size (int, defaults to 1): the batch size while training
|
|
||||||
max_epochs (int, defaults to 2): the number of epochs to train
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
strategy: Strategy,
|
|
||||||
optim: Optimizer,
|
|
||||||
loss_fn,
|
|
||||||
train_dataset: Dataset,
|
|
||||||
valid_dataset: Dataset,
|
|
||||||
eval_dataset: Dataset,
|
|
||||||
batch_size: int = 1,
|
|
||||||
max_epochs: int = 1,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.strategy = strategy
|
|
||||||
self.epochs = max_epochs
|
|
||||||
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
|
||||||
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
|
|
||||||
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
|
|
||||||
|
|
||||||
self.model = strategy.setup_model(model)
|
|
||||||
self.loss_fn = loss_fn
|
|
||||||
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
|
||||||
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_acc(self, dataloader):
|
|
||||||
dist = 0
|
|
||||||
on = 0
|
|
||||||
cnt = 0
|
|
||||||
self.model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
|
|
||||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
|
||||||
for i in range(len(chosen_reward)):
|
|
||||||
cnt += 1
|
|
||||||
if chosen_reward[i] > reject_reward[i]:
|
|
||||||
on += 1
|
|
||||||
dist += (chosen_reward - reject_reward).mean().item()
|
|
||||||
dist_mean = dist / len(dataloader)
|
|
||||||
acc = on / cnt
|
|
||||||
self.model.train()
|
|
||||||
return dist_mean, acc
|
|
||||||
|
|
||||||
|
|
||||||
def fit(self):
|
|
||||||
time = datetime.now()
|
|
||||||
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
|
|
||||||
for epoch in range(self.epochs):
|
|
||||||
step_bar = tqdm(range(self.train_dataloader.__len__()),
|
|
||||||
desc='Train step of epoch %d' % epoch,
|
|
||||||
disable=not is_rank_0())
|
|
||||||
# train
|
|
||||||
self.model.train()
|
|
||||||
cnt = 0
|
|
||||||
acc = 0
|
|
||||||
dist = 0
|
|
||||||
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
|
||||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
|
||||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
|
||||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
|
||||||
loss = self.loss_fn(chosen_reward, reject_reward)
|
|
||||||
self.strategy.backward(loss, self.model, self.optimizer)
|
|
||||||
self.strategy.optimizer_step(self.optimizer)
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
cnt += 1
|
|
||||||
if cnt == 100:
|
|
||||||
self.scheduler.step()
|
|
||||||
dist, acc = self.eval_acc(self.valid_dataloader)
|
|
||||||
cnt = 0
|
|
||||||
if is_rank_0():
|
|
||||||
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
|
|
||||||
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
|
|
||||||
step_bar.update()
|
|
||||||
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
|
||||||
|
|
||||||
# eval
|
|
||||||
dist, acc = self.eval_acc(self.eval_dataloader)
|
|
||||||
if is_rank_0():
|
|
||||||
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
|
|
||||||
log.to_csv('log.csv', mode='a', header=False, index=False)
|
|
||||||
epoch_bar.update()
|
|
||||||
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
|
||||||
step_bar.close()
|
|
@ -1,106 +0,0 @@
|
|||||||
from abc import ABC
|
|
||||||
from typing import Optional
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
from chatgpt.models.loss import GPTLMLoss
|
|
||||||
from torch.optim import Adam, Optimizer
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch.distributed as dist
|
|
||||||
from .strategies import Strategy
|
|
||||||
from .utils import is_rank_0
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
|
|
||||||
class SFTTrainer(ABC):
|
|
||||||
"""
|
|
||||||
Trainer to use while training reward model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (torch.nn.Module): the model to train
|
|
||||||
strategy (Strategy): the strategy to use for training
|
|
||||||
optim(Optimizer): the optimizer to use for training
|
|
||||||
train_dataloader: the dataloader to use for training
|
|
||||||
eval_dataloader: the dataloader to use for evaluation
|
|
||||||
batch_size (int, defaults to 1): the batch size while training
|
|
||||||
max_epochs (int, defaults to 2): the number of epochs to train
|
|
||||||
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
strategy: Strategy,
|
|
||||||
optim: Optimizer,
|
|
||||||
train_dataloader: DataLoader,
|
|
||||||
eval_dataloader: DataLoader = None,
|
|
||||||
sampler: Optional[DistributedSampler] = None,
|
|
||||||
batch_size: int = 1,
|
|
||||||
max_epochs: int = 2,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.strategy = strategy
|
|
||||||
self.epochs = max_epochs
|
|
||||||
self.sampler = sampler
|
|
||||||
|
|
||||||
self.train_dataloader = train_dataloader
|
|
||||||
self.eval_dataloader = eval_dataloader
|
|
||||||
|
|
||||||
self.model = strategy.setup_model(model)
|
|
||||||
if "DDP" in str(self.strategy):
|
|
||||||
self.model = self.model.module
|
|
||||||
self.loss_fn = GPTLMLoss()
|
|
||||||
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
|
||||||
|
|
||||||
def fit(self, logger, use_lora, log_interval=10):
|
|
||||||
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
|
|
||||||
for epoch in range(self.epochs):
|
|
||||||
if isinstance(self.sampler, DistributedSampler):
|
|
||||||
self.sampler.set_epoch(epoch)
|
|
||||||
# train
|
|
||||||
self.model.train()
|
|
||||||
for batch_id, batch in enumerate(self.train_dataloader):
|
|
||||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
|
||||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
|
||||||
labels = batch["labels"].to(torch.cuda.current_device())
|
|
||||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
|
||||||
# p_mask = p_mask.squeeze(1).cuda()
|
|
||||||
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
|
||||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
|
||||||
loss = outputs.loss
|
|
||||||
prompt_logits = outputs.logits
|
|
||||||
|
|
||||||
# loss = self.loss_fn(prompt_logits, labels)
|
|
||||||
self.strategy.backward(loss, self.model, self.optimizer)
|
|
||||||
self.strategy.optimizer_step(self.optimizer)
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
if batch_id % log_interval == 0:
|
|
||||||
logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
|
|
||||||
|
|
||||||
# eval
|
|
||||||
if self.eval_dataloader is not None:
|
|
||||||
self.model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
loss_sum = 0
|
|
||||||
num_seen = 0
|
|
||||||
for batch in self.eval_dataloader:
|
|
||||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
|
||||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
|
||||||
labels = batch["labels"].to(torch.cuda.current_device())
|
|
||||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
|
||||||
# p_mask = p_mask.squeeze(1).cuda()
|
|
||||||
|
|
||||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
|
||||||
loss = outputs.loss
|
|
||||||
# prompt_logits = outputs.logits
|
|
||||||
|
|
||||||
loss_sum += loss.item()
|
|
||||||
num_seen += prompt_ids.size(0)
|
|
||||||
|
|
||||||
loss_mean = loss_sum / num_seen
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
|
|
||||||
|
|
||||||
epoch_bar.update()
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
|||||||
from .base import Strategy
|
|
||||||
from .colossalai import ColossalAIStrategy
|
|
||||||
from .ddp import DDPStrategy
|
|
||||||
from .naive import NaiveStrategy
|
|
||||||
|
|
||||||
__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
|
|
@ -1,131 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from typing import Any, List, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.models.base import Actor, Critic, RewardModel
|
|
||||||
from chatgpt.replay_buffer import ReplayBuffer
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .sampler import DistributedSampler
|
|
||||||
|
|
||||||
ModelOptimPair = Tuple[nn.Module, Optimizer]
|
|
||||||
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
|
|
||||||
|
|
||||||
|
|
||||||
class Strategy(ABC):
|
|
||||||
"""
|
|
||||||
Base class for training strategies.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.setup_distributed()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def setup_distributed(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def model_init_context(self):
|
|
||||||
return nullcontext()
|
|
||||||
|
|
||||||
def prepare(
|
|
||||||
self, *models_or_model_optim_pairs: ModelOrModelOptimPair
|
|
||||||
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
|
|
||||||
"""Prepare models or model-optimizer-pairs based on each strategy.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
>>> # when fine-tuning actor and critic
|
|
||||||
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
>>> # or when training reward model
|
|
||||||
>>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
|
|
||||||
>>> # or just inference
|
|
||||||
>>> actor, critic = strategy.prepare(actor, critic)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def prepare_model(model: nn.Module):
|
|
||||||
if isinstance(model, Actor):
|
|
||||||
return Actor(self.setup_model(self._unwrap_model(model)))
|
|
||||||
return self.setup_model(self._unwrap_model(model))
|
|
||||||
|
|
||||||
rets = []
|
|
||||||
for arg in models_or_model_optim_pairs:
|
|
||||||
if isinstance(arg, tuple):
|
|
||||||
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
|
|
||||||
model, optimizer = arg
|
|
||||||
model = prepare_model(model)
|
|
||||||
optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
|
|
||||||
rets.append((model, optimizer))
|
|
||||||
elif isinstance(arg, nn.Module):
|
|
||||||
rets.append(prepare_model(arg))
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
|
|
||||||
|
|
||||||
if len(rets) == 1:
|
|
||||||
return rets[0]
|
|
||||||
return rets
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_model(model: nn.Module) -> nn.Module:
|
|
||||||
"""Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (nn.Module): an actor or a critic
|
|
||||||
"""
|
|
||||||
if isinstance(model, Actor):
|
|
||||||
return model.model
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
|
||||||
"""Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actor (Actor): a wrapped actor
|
|
||||||
"""
|
|
||||||
return Strategy._unwrap_model(actor)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
|
||||||
return DistributedSampler(dataset, 1, 0)
|
|
@ -1,190 +0,0 @@
|
|||||||
import warnings
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
from chatgpt.models.base import Actor
|
|
||||||
from chatgpt.models.lora import LoraLinear
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
import colossalai
|
|
||||||
from colossalai.nn.optimizer import CPUAdam, HybridAdam
|
|
||||||
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
|
|
||||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
|
||||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
|
||||||
from colossalai.utils import get_current_device
|
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
||||||
|
|
||||||
from .base import Strategy
|
|
||||||
from .ddp import DDPStrategy
|
|
||||||
|
|
||||||
|
|
||||||
class ColossalAIStrategy(DDPStrategy):
|
|
||||||
"""
|
|
||||||
The strategy for training with ColossalAI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
|
|
||||||
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
|
|
||||||
seed(int): The seed for the random number generator.
|
|
||||||
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
|
|
||||||
This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
|
|
||||||
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
|
|
||||||
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
|
|
||||||
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
|
|
||||||
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
|
|
||||||
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
|
|
||||||
search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
|
|
||||||
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
|
|
||||||
min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
|
|
||||||
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
|
|
||||||
reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
|
|
||||||
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
|
|
||||||
initial_scale(float): The initial scale for the optimizer.
|
|
||||||
growth_factor(float): The growth factor for the optimizer.
|
|
||||||
backoff_factor(float): The backoff factor for the optimizer.
|
|
||||||
growth_interval(int): The growth interval for the optimizer.
|
|
||||||
hysteresis(int): The hysteresis for the optimizer.
|
|
||||||
min_scale(float): The minimum scale for the optimizer.
|
|
||||||
max_scale(float): The maximum scale for the optimizer.
|
|
||||||
max_norm(float): The maximum norm for the optimizer.
|
|
||||||
norm_type(float): The norm type for the optimizer.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stage: int = 3,
|
|
||||||
precision: str = 'fp16',
|
|
||||||
seed: int = 42,
|
|
||||||
shard_init: bool = False, # only for stage 3
|
|
||||||
placement_policy: str = 'cuda',
|
|
||||||
pin_memory: bool = True, # only for stage 3
|
|
||||||
force_outputs_fp32: bool = False, # only for stage 3
|
|
||||||
search_range_mb: int = 32, # only for stage 3
|
|
||||||
hidden_dim: Optional[int] = None, # only for stage 3
|
|
||||||
min_chunk_size_mb: float = 32, # only for stage 3
|
|
||||||
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
|
|
||||||
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
|
|
||||||
overlap_communication: bool = True, # only for stage 1&2
|
|
||||||
initial_scale: float = 2**16,
|
|
||||||
growth_factor: float = 2,
|
|
||||||
backoff_factor: float = 0.5,
|
|
||||||
growth_interval: int = 1000,
|
|
||||||
hysteresis: int = 2,
|
|
||||||
min_scale: float = 1,
|
|
||||||
max_scale: float = 2**32,
|
|
||||||
max_norm: float = 0.0,
|
|
||||||
norm_type: float = 2.0) -> None:
|
|
||||||
super().__init__(seed)
|
|
||||||
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
|
|
||||||
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
|
|
||||||
self.stage = stage
|
|
||||||
# TODO(ver217): support shard_init when using from_pretrained()
|
|
||||||
if shard_init:
|
|
||||||
warnings.warn(
|
|
||||||
f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
|
|
||||||
)
|
|
||||||
if stage == 3 and precision == 'fp32':
|
|
||||||
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
|
|
||||||
precision = 'fp16'
|
|
||||||
self.precision = precision
|
|
||||||
self.shard_init = shard_init
|
|
||||||
self.gemini_config = dict(device=get_current_device(),
|
|
||||||
placement_policy=placement_policy,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
force_outputs_fp32=force_outputs_fp32,
|
|
||||||
strict_ddp_mode=shard_init,
|
|
||||||
search_range_mb=search_range_mb,
|
|
||||||
hidden_dim=hidden_dim,
|
|
||||||
min_chunk_size_mb=min_chunk_size_mb)
|
|
||||||
if stage == 3:
|
|
||||||
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
|
|
||||||
else:
|
|
||||||
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
|
|
||||||
overlap_communication=overlap_communication,
|
|
||||||
cpu_offload=(placement_policy == 'cpu'))
|
|
||||||
self.optim_kwargs = dict(initial_scale=initial_scale,
|
|
||||||
growth_factor=growth_factor,
|
|
||||||
backoff_factor=backoff_factor,
|
|
||||||
growth_interval=growth_interval,
|
|
||||||
hysteresis=hysteresis,
|
|
||||||
min_scale=min_scale,
|
|
||||||
max_scale=max_scale,
|
|
||||||
max_norm=max_norm,
|
|
||||||
norm_type=norm_type)
|
|
||||||
|
|
||||||
def setup_distributed(self) -> None:
|
|
||||||
colossalai.launch_from_torch({}, seed=self.seed)
|
|
||||||
|
|
||||||
def model_init_context(self):
|
|
||||||
if self.stage == 3:
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
|
|
||||||
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
|
|
||||||
return ColoInitContext(device=get_current_device(),
|
|
||||||
dtype=torch.half,
|
|
||||||
default_pg=shard_pg,
|
|
||||||
default_dist_spec=default_dist_spec)
|
|
||||||
return super().model_init_context()
|
|
||||||
|
|
||||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
|
|
||||||
if self.stage != 3 and self.precision == 'fp16':
|
|
||||||
model = model.half()
|
|
||||||
return model
|
|
||||||
|
|
||||||
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
|
|
||||||
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
|
|
||||||
return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
|
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
|
|
||||||
optimizer.backward(loss)
|
|
||||||
|
|
||||||
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
|
||||||
model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
|
|
||||||
if isinstance(model, ZeroDDP):
|
|
||||||
return model.module
|
|
||||||
return model
|
|
||||||
|
|
||||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
|
||||||
unwrapped_model = self._unwrap_model(model)
|
|
||||||
# TODO : better way to get torch model from gemini model
|
|
||||||
# to get torch model from gemini model
|
|
||||||
if isinstance(unwrapped_model, ZeroDDP):
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
unwrapped_model = get_static_torch_model(unwrapped_model)
|
|
||||||
if only_rank0 and dist.get_rank() != 0:
|
|
||||||
return
|
|
||||||
unwrapped_model.load_state_dict(state_dict)
|
|
||||||
# merge lora_weights into weights
|
|
||||||
for module in unwrapped_model.modules():
|
|
||||||
if isinstance(module, LoraLinear):
|
|
||||||
module.merge_weights = True
|
|
||||||
module.eval()
|
|
||||||
# get state_dict and save
|
|
||||||
|
|
||||||
if not isinstance(self.model, PreTrainedModel):
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
if only_rank0 and dist.get_rank() != 0:
|
|
||||||
return
|
|
||||||
torch.save(state_dict, path)
|
|
||||||
else:
|
|
||||||
self.model.save_pretrained(path)
|
|
||||||
if tokenizer is not None:
|
|
||||||
tokenizer.save_pretrained(path)
|
|
||||||
|
|
||||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
|
||||||
if only_rank0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
|
|
||||||
torch.save(optimizer.state_dict(), path)
|
|
@ -1,93 +0,0 @@
|
|||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
from chatgpt.models.base import Actor
|
|
||||||
from chatgpt.models.lora import LoraLinear
|
|
||||||
from chatgpt.replay_buffer import ReplayBuffer
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .base import Strategy
|
|
||||||
from .naive import NaiveStrategy
|
|
||||||
from .sampler import DistributedSampler
|
|
||||||
|
|
||||||
|
|
||||||
class DDPStrategy(NaiveStrategy):
|
|
||||||
"""
|
|
||||||
Strategy for distributed training using torch.distributed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, seed: int = 42) -> None:
|
|
||||||
self.seed = seed
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def setup_distributed(self) -> None:
|
|
||||||
try:
|
|
||||||
rank = int(os.environ['RANK'])
|
|
||||||
local_rank = int(os.environ['LOCAL_RANK'])
|
|
||||||
world_size = int(os.environ['WORLD_SIZE'])
|
|
||||||
host = os.environ['MASTER_ADDR']
|
|
||||||
port = int(os.environ['MASTER_PORT'])
|
|
||||||
except KeyError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
|
|
||||||
)
|
|
||||||
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
|
|
||||||
self.set_seed(self.seed)
|
|
||||||
torch.cuda.set_device(local_rank)
|
|
||||||
|
|
||||||
def set_seed(self, seed: int) -> None:
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
device = torch.cuda.current_device()
|
|
||||||
return DDP(model, device_ids=[device])
|
|
||||||
|
|
||||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
|
||||||
# DDP only mode, replay buffers on each rank are different.
|
|
||||||
# sampler = DistributedSampler(replay_buffer,
|
|
||||||
# num_replicas=dist.get_world_size(),
|
|
||||||
# rank=dist.get_rank(),
|
|
||||||
# shuffle=True,
|
|
||||||
# seed=self.seed,
|
|
||||||
# drop_last=True)
|
|
||||||
return DataLoader(
|
|
||||||
replay_buffer,
|
|
||||||
batch_size=replay_buffer.sample_batch_size,
|
|
||||||
# sampler=sampler,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=replay_buffer.collate_fn)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
|
||||||
model: DDP = Strategy._unwrap_actor(actor)
|
|
||||||
return model.module
|
|
||||||
|
|
||||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, LoraLinear):
|
|
||||||
module.merge_weights=True
|
|
||||||
module.eval()
|
|
||||||
|
|
||||||
if only_rank0 and dist.get_rank() != 0:
|
|
||||||
return
|
|
||||||
model = model.model.module
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
torch.save(state_dict, path)
|
|
||||||
|
|
||||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
|
||||||
if only_rank0 and dist.get_rank() != 0:
|
|
||||||
return
|
|
||||||
super().save_optimizer(optimizer, path, only_rank0)
|
|
||||||
|
|
||||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
|
||||||
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
|
|
@ -1,55 +0,0 @@
|
|||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
from chatgpt.replay_buffer import ReplayBuffer
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .base import Strategy
|
|
||||||
|
|
||||||
|
|
||||||
class NaiveStrategy(Strategy):
|
|
||||||
"""
|
|
||||||
Strategy for single GPU. No parallelism is used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
def setup_distributed(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
return model
|
|
||||||
|
|
||||||
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
|
||||||
return DataLoader(replay_buffer,
|
|
||||||
batch_size=replay_buffer.sample_batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=replay_buffer.collate_fn)
|
|
||||||
|
|
||||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
|
|
||||||
unwrapped_model = self._unwrap_model(model)
|
|
||||||
torch.save(unwrapped_model.state_dict(), path)
|
|
||||||
|
|
||||||
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
|
|
||||||
unwrapped_model = self._unwrap_model(model)
|
|
||||||
state_dict = torch.load(path, map_location=map_location)
|
|
||||||
unwrapped_model.load_state_dict(state_dict, strict=strict)
|
|
||||||
|
|
||||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
|
||||||
torch.save(optimizer.state_dict(), path)
|
|
||||||
|
|
||||||
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
|
|
||||||
state_dict = torch.load(path, map_location=map_location)
|
|
||||||
optimizer.load_state_dict(state_dict)
|
|
@ -1,32 +0,0 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedSampler:
|
|
||||||
|
|
||||||
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
|
|
||||||
self.dataset = dataset
|
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.rank = rank
|
|
||||||
|
|
||||||
if len(self.dataset) % self.num_replicas != 0:
|
|
||||||
self.num_samples = math.ceil(
|
|
||||||
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
|
||||||
|
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
|
||||||
|
|
||||||
indices = list(range(len(self.dataset)))
|
|
||||||
indices = indices[:self.total_size]
|
|
||||||
assert len(indices) == self.total_size
|
|
||||||
# subsample
|
|
||||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
|
||||||
assert len(indices) == self.num_samples
|
|
||||||
self.indices = indices
|
|
||||||
|
|
||||||
def sample(self, batch_size: int) -> list:
|
|
||||||
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
|
|
||||||
return [self.dataset[idx] for idx in sampled_indices]
|
|
@ -1,5 +0,0 @@
|
|||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
def is_rank_0() -> bool:
|
|
||||||
return not dist.is_initialized() or dist.get_rank() == 0
|
|
@ -1,3 +0,0 @@
|
|||||||
from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding
|
|
||||||
|
|
||||||
__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
|
|
@ -1,80 +0,0 @@
|
|||||||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from ..models.llama.llama_lm import LlamaLM
|
|
||||||
|
|
||||||
DEFAULT_PAD_TOKEN = "[PAD]"
|
|
||||||
DEFAULT_EOS_TOKEN = "</s>"
|
|
||||||
DEFAULT_BOS_TOKEN = "</s>"
|
|
||||||
DEFAULT_UNK_TOKEN = "</s>"
|
|
||||||
|
|
||||||
def prepare_llama_tokenizer_and_embedding(
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
|
||||||
model: transformers.PreTrainedModel,
|
|
||||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
|
||||||
):
|
|
||||||
"""prepare llama tokenizer and embedding.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
if tokenizer.pad_token is None:
|
|
||||||
smart_tokenizer_and_embedding_resize(
|
|
||||||
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
|
||||||
{
|
|
||||||
"eos_token": DEFAULT_EOS_TOKEN,
|
|
||||||
"bos_token": DEFAULT_BOS_TOKEN,
|
|
||||||
"unk_token": DEFAULT_UNK_TOKEN,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def smart_tokenizer_and_embedding_resize(
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
|
||||||
model: transformers.PreTrainedModel,
|
|
||||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
|
||||||
):
|
|
||||||
"""Resize tokenizer and embedding.
|
|
||||||
|
|
||||||
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if tokenizer.pad_token is None:
|
|
||||||
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
|
||||||
|
|
||||||
if isinstance(model, LlamaLM):
|
|
||||||
model = model.get_base_model()
|
|
||||||
|
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
|
||||||
|
|
||||||
if num_new_tokens > 0:
|
|
||||||
input_embeddings = model.get_input_embeddings().weight.data
|
|
||||||
output_embeddings = model.get_output_embeddings().weight.data
|
|
||||||
|
|
||||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
|
||||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
|
||||||
|
|
||||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
|
||||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
|||||||
# Examples
|
|
||||||
|
|
||||||
## Install requirements
|
|
||||||
|
|
||||||
```shell
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
## Train the reward model (Stage 2)
|
|
||||||
Use these code to train your reward model.
|
|
||||||
```shell
|
|
||||||
# Take naive reward model training with opt-350m as example
|
|
||||||
python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
|
|
||||||
# use colossalai_zero2
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
|
|
||||||
```
|
|
||||||
|
|
||||||
### Features and tricks in RM training
|
|
||||||
- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
|
|
||||||
- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
|
|
||||||
- We change the loss to valid_acc and pair_dist to monitor progress during training.
|
|
||||||
- We add special token to the end of the sequence to get better result.
|
|
||||||
- We use cosine-reducing lr-scheduler for RM training.
|
|
||||||
- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution.
|
|
||||||
- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2204.05862).
|
|
||||||
|
|
||||||
### Experiment result
|
|
||||||
Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
|
|
||||||
|
|
||||||
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">
|
|
||||||
|
|
||||||
<div align=left>Our training & test result of bloom-560m for 1 epoch:
|
|
||||||
|
|
||||||
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">
|
|
||||||
|
|
||||||
<div align=left>
|
|
||||||
|
|
||||||
## Train with dummy prompt data (Stage 3)
|
|
||||||
|
|
||||||
This script supports 4 kinds of strategies:
|
|
||||||
|
|
||||||
- naive
|
|
||||||
- ddp
|
|
||||||
- colossalai_zero2
|
|
||||||
- colossalai_gemini
|
|
||||||
|
|
||||||
It uses random generated prompt data.
|
|
||||||
|
|
||||||
Naive strategy only support single GPU training:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python train_dummy.py --strategy naive
|
|
||||||
# display cli help
|
|
||||||
python train_dummy.py -h
|
|
||||||
```
|
|
||||||
|
|
||||||
DDP strategy and ColossalAI strategy support multi GPUs training:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# run DDP on 2 GPUs
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
|
|
||||||
# run ColossalAI on 2 GPUs
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
|
|
||||||
```
|
|
||||||
|
|
||||||
## Train with real prompt data (Stage 3)
|
|
||||||
|
|
||||||
We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as example dataset. It is a small dataset with hundreds of prompts.
|
|
||||||
|
|
||||||
You should download `prompts.csv` first.
|
|
||||||
|
|
||||||
This script also supports 4 strategies.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
# display cli help
|
|
||||||
python train_dummy.py -h
|
|
||||||
# run naive on 1 GPU
|
|
||||||
python train_prompts.py prompts.csv --strategy naive
|
|
||||||
# run DDP on 2 GPUs
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
|
|
||||||
# run ColossalAI on 2 GPUs
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
|
|
||||||
```
|
|
||||||
|
|
||||||
## Inference example(After Stage3)
|
|
||||||
We support naive inference demo after training.
|
|
||||||
```shell
|
|
||||||
# inference, using pretrain path to configure model
|
|
||||||
python inference.py --model_path <your actor model path> --model <your model type> --pretrain <your pretrain model name/path>
|
|
||||||
# example
|
|
||||||
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
|
|
||||||
```
|
|
||||||
|
|
||||||
## Attention
|
|
||||||
The examples is just a demo for testing our progress of RM and PPO training.
|
|
||||||
|
|
||||||
|
|
||||||
#### data
|
|
||||||
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
|
||||||
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
|
||||||
- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
|
|
||||||
- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
|
|
||||||
- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses)
|
|
||||||
|
|
||||||
## Support Model
|
|
||||||
|
|
||||||
### GPT
|
|
||||||
- [x] GPT2-S (s)
|
|
||||||
- [x] GPT2-M (m)
|
|
||||||
- [x] GPT2-L (l)
|
|
||||||
- [ ] GPT2-XL (xl)
|
|
||||||
- [x] GPT2-4B (4b)
|
|
||||||
- [ ] GPT2-6B (6b)
|
|
||||||
- [ ] GPT2-8B (8b)
|
|
||||||
- [ ] GPT2-10B (10b)
|
|
||||||
- [ ] GPT2-12B (12b)
|
|
||||||
- [ ] GPT2-15B (15b)
|
|
||||||
- [ ] GPT2-18B (18b)
|
|
||||||
- [ ] GPT2-20B (20b)
|
|
||||||
- [ ] GPT2-24B (24b)
|
|
||||||
- [ ] GPT2-28B (28b)
|
|
||||||
- [ ] GPT2-32B (32b)
|
|
||||||
- [ ] GPT2-36B (36b)
|
|
||||||
- [ ] GPT2-40B (40b)
|
|
||||||
- [ ] GPT3 (175b)
|
|
||||||
|
|
||||||
### BLOOM
|
|
||||||
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
|
|
||||||
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
|
|
||||||
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
|
|
||||||
- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1)
|
|
||||||
- [ ] BLOOM-175b
|
|
||||||
|
|
||||||
### OPT
|
|
||||||
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
|
|
||||||
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
|
|
||||||
- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
|
|
||||||
- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
|
|
||||||
- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
|
|
||||||
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
|
|
||||||
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
|
|
@ -1,59 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from chatgpt.models.bloom import BLOOMActor
|
|
||||||
from chatgpt.models.gpt import GPTActor
|
|
||||||
from chatgpt.models.opt import OPTActor
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def eval(args):
|
|
||||||
# configure model
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'opt':
|
|
||||||
actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
state_dict = torch.load(args.model_path)
|
|
||||||
actor.model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'opt':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
actor.eval()
|
|
||||||
input = args.input
|
|
||||||
input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
|
|
||||||
outputs = actor.generate(input_ids,
|
|
||||||
max_length=args.max_length,
|
|
||||||
do_sample=True,
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.95,
|
|
||||||
num_return_sequences=1)
|
|
||||||
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
|
|
||||||
print(output)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
|
||||||
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
|
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
|
||||||
parser.add_argument('--model_path', type=str, default=None)
|
|
||||||
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
|
|
||||||
parser.add_argument('--max_length', type=int, default=100)
|
|
||||||
args = parser.parse_args()
|
|
||||||
eval(args)
|
|
@ -1,2 +0,0 @@
|
|||||||
pandas>=1.4.1
|
|
||||||
sentencepiece
|
|
@ -1,97 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
set -xue
|
|
||||||
|
|
||||||
if [ -z "$PROMPT_PATH" ]; then
|
|
||||||
echo "Please set \$PROMPT_PATH to the path to prompts csv."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
BASE=$(realpath $(dirname $0))
|
|
||||||
|
|
||||||
export OMP_NUM_THREADS=8
|
|
||||||
|
|
||||||
# install requirements
|
|
||||||
pip install -r ${BASE}/requirements.txt
|
|
||||||
|
|
||||||
# train dummy
|
|
||||||
python ${BASE}/train_dummy.py --strategy naive --num_episodes 1 \
|
|
||||||
--max_timesteps 2 --update_timesteps 2 \
|
|
||||||
--max_epochs 1 --train_batch_size 2 --lora_rank 4
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
|
||||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
|
||||||
--strategy ddp --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
|
||||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2
|
|
||||||
|
|
||||||
rm -rf ${BASE}/actor_checkpoint_dummy.pt
|
|
||||||
|
|
||||||
# train prompts
|
|
||||||
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 1 \
|
|
||||||
--max_timesteps 2 --update_timesteps 2 \
|
|
||||||
--max_epochs 1 --train_batch_size 2 --lora_rank 4
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
|
||||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'facebook/opt-350m' --model opt
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
|
||||||
--strategy ddp --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
|
||||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
|
||||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
|
||||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
|
||||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
|
||||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
|
||||||
|
|
||||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
|
||||||
|
|
||||||
# train rm
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
|
||||||
--pretrain 'facebook/opt-350m' --model 'opt' \
|
|
||||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
|
||||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
|
||||||
--test True --lora_rank 4
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
|
||||||
--pretrain 'gpt2' --model 'gpt2' \
|
|
||||||
--strategy colossalai_gemini --loss_fn 'log_exp'\
|
|
||||||
--dataset 'Dahoas/rm-static' --test True --lora_rank 4
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
|
||||||
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
|
|
||||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
|
||||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
|
||||||
--test True --lora_rank 4
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
|
||||||
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
|
|
||||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
|
||||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
|
||||||
--test True --lora_rank 4
|
|
||||||
|
|
||||||
rm -rf ${BASE}/rm_ckpt.pt
|
|
@ -1,148 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
|
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
|
||||||
from chatgpt.models.opt import OPTActor, OPTCritic
|
|
||||||
from chatgpt.trainer import PPOTrainer
|
|
||||||
from chatgpt.trainer.callbacks import SaveCheckpoint
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
|
||||||
from torch.optim import Adam
|
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_batch(samples):
|
|
||||||
input_ids = torch.stack(samples)
|
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
|
||||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
# configure strategy
|
|
||||||
if args.strategy == 'naive':
|
|
||||||
strategy = NaiveStrategy()
|
|
||||||
elif args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
# configure model
|
|
||||||
with strategy.model_init_context():
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'opt':
|
|
||||||
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
initial_model = deepcopy(actor).to(torch.cuda.current_device())
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
|
|
||||||
|
|
||||||
# configure optimizer
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
|
||||||
else:
|
|
||||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'opt':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
|
|
||||||
callbacks = []
|
|
||||||
if args.save_ckpt_path:
|
|
||||||
ckpt_callback = SaveCheckpoint(
|
|
||||||
args.save_ckpt_path,
|
|
||||||
args.save_ckpt_interval,
|
|
||||||
strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
)
|
|
||||||
callbacks.append(ckpt_callback)
|
|
||||||
|
|
||||||
# configure trainer
|
|
||||||
|
|
||||||
trainer = PPOTrainer(strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
reward_model,
|
|
||||||
initial_model,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
max_epochs=args.max_epochs,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
tokenizer=preprocess_batch,
|
|
||||||
max_length=128,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
callbacks=callbacks)
|
|
||||||
|
|
||||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
|
|
||||||
trainer.fit(random_prompts,
|
|
||||||
num_episodes=args.num_episodes,
|
|
||||||
max_timesteps=args.max_timesteps,
|
|
||||||
update_timesteps=args.update_timesteps)
|
|
||||||
|
|
||||||
# save model checkpoint after fitting
|
|
||||||
strategy.save_model(actor, args.save_path, only_rank0=True)
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
if args.need_optim_ckpt:
|
|
||||||
strategy.save_optimizer(actor_optim,
|
|
||||||
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
|
|
||||||
only_rank0=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
|
||||||
default='naive')
|
|
||||||
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
|
||||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
|
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
|
||||||
parser.add_argument('--num_episodes', type=int, default=50)
|
|
||||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
|
||||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=5)
|
|
||||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
parser.add_argument('--save_ckpt_path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="path to save checkpoint, None means not to save")
|
|
||||||
parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
@ -1,18 +0,0 @@
|
|||||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|
||||||
local n=${1:-"9999"}
|
|
||||||
echo "GPU Memory Usage:"
|
|
||||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
|
||||||
| tail -n +2 \
|
|
||||||
| nl -v 0 \
|
|
||||||
| tee /dev/tty \
|
|
||||||
| sort -g -k 2 \
|
|
||||||
| awk '{print $1}' \
|
|
||||||
| head -n $n)
|
|
||||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
|
||||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
|
||||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
|
|
@ -1,132 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
|
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
|
||||||
from chatgpt.models.opt import OPTActor, OPTCritic
|
|
||||||
from chatgpt.trainer import PPOTrainer
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
|
||||||
from torch.optim import Adam
|
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
|
||||||
# configure strategy
|
|
||||||
if args.strategy == 'naive':
|
|
||||||
strategy = NaiveStrategy()
|
|
||||||
elif args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
# configure model
|
|
||||||
with strategy.model_init_context():
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'opt':
|
|
||||||
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
initial_model = deepcopy(actor)
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
|
|
||||||
|
|
||||||
# configure optimizer
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
|
||||||
else:
|
|
||||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
|
||||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'opt':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
dataset = pd.read_csv(args.prompt_path)['prompt']
|
|
||||||
|
|
||||||
def tokenize_fn(texts):
|
|
||||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
|
||||||
# Different length may lead to hang when using gemini, as different generation steps
|
|
||||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
|
||||||
return {k: v.cuda() for k, v in batch.items()}
|
|
||||||
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
|
||||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
|
||||||
|
|
||||||
# configure trainer
|
|
||||||
trainer = PPOTrainer(
|
|
||||||
strategy,
|
|
||||||
actor,
|
|
||||||
critic,
|
|
||||||
reward_model,
|
|
||||||
initial_model,
|
|
||||||
actor_optim,
|
|
||||||
critic_optim,
|
|
||||||
max_epochs=args.max_epochs,
|
|
||||||
train_batch_size=args.train_batch_size,
|
|
||||||
experience_batch_size=args.experience_batch_size,
|
|
||||||
tokenizer=tokenize_fn,
|
|
||||||
max_length=128,
|
|
||||||
do_sample=True,
|
|
||||||
temperature=1.0,
|
|
||||||
top_k=50,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.fit(dataset,
|
|
||||||
num_episodes=args.num_episodes,
|
|
||||||
max_timesteps=args.max_timesteps,
|
|
||||||
update_timesteps=args.update_timesteps)
|
|
||||||
# save model checkpoint after fitting
|
|
||||||
strategy.save_model(actor, args.save_path, only_rank0=True)
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
if args.need_optim_ckpt:
|
|
||||||
strategy.save_optimizer(actor_optim,
|
|
||||||
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
|
||||||
only_rank0=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('prompt_path')
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
|
||||||
default='naive')
|
|
||||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
|
||||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
|
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
|
||||||
parser.add_argument('--num_episodes', type=int, default=10)
|
|
||||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
|
||||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=5)
|
|
||||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
|
||||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
@ -1,18 +0,0 @@
|
|||||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|
||||||
local n=${1:-"9999"}
|
|
||||||
echo "GPU Memory Usage:"
|
|
||||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
|
||||||
| tail -n +2 \
|
|
||||||
| nl -v 0 \
|
|
||||||
| tee /dev/tty \
|
|
||||||
| sort -g -k 2 \
|
|
||||||
| awk '{print $1}' \
|
|
||||||
| head -n $n)
|
|
||||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
|
||||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
|
||||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
|
||||||
|
|
||||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
|
|
@ -1,143 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
from chatgpt.dataset import HhRlhfDataset, RmStaticDataset
|
|
||||||
from chatgpt.models import LogSigLoss, LogExpLoss
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.bloom import BLOOMRM
|
|
||||||
from chatgpt.models.gpt import GPTRM
|
|
||||||
from chatgpt.models.opt import OPTRM
|
|
||||||
from chatgpt.models.deberta import DebertaRM
|
|
||||||
from chatgpt.trainer import RewardModelTrainer
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
|
||||||
from datasets import load_dataset
|
|
||||||
from random import randint
|
|
||||||
from torch.optim import Adam
|
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
|
|
||||||
def train(args):
|
|
||||||
# configure strategy
|
|
||||||
if args.strategy == 'naive':
|
|
||||||
strategy = NaiveStrategy()
|
|
||||||
elif args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
# configure model
|
|
||||||
with strategy.model_init_context():
|
|
||||||
if args.model == 'bloom':
|
|
||||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'opt':
|
|
||||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'gpt2':
|
|
||||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
elif args.model == 'deberta':
|
|
||||||
model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
if args.model_path is not None:
|
|
||||||
state_dict = torch.load(args.model_path)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
|
||||||
elif args.model == 'opt':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
|
||||||
elif args.model == 'deberta':
|
|
||||||
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
max_len = args.max_len
|
|
||||||
|
|
||||||
# configure optimizer
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
optim = HybridAdam(model.parameters(), lr=1.5e-5)
|
|
||||||
else:
|
|
||||||
optim = Adam(model.parameters(), lr=1.5e-5)
|
|
||||||
|
|
||||||
# configure loss function
|
|
||||||
if args.loss_fn == 'log_sig':
|
|
||||||
loss_fn = LogSigLoss()
|
|
||||||
elif args.loss_fn == 'log_exp':
|
|
||||||
loss_fn = LogExpLoss()
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
|
|
||||||
|
|
||||||
# prepare for data and dataset
|
|
||||||
if args.subset is not None:
|
|
||||||
data = load_dataset(args.dataset, data_dir=args.subset)
|
|
||||||
else:
|
|
||||||
data = load_dataset(args.dataset)
|
|
||||||
|
|
||||||
if args.test:
|
|
||||||
train_data = data['train'].select(range(100))
|
|
||||||
eval_data = data['test'].select(range(10))
|
|
||||||
else:
|
|
||||||
train_data = data['train']
|
|
||||||
eval_data = data['test']
|
|
||||||
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10)))
|
|
||||||
|
|
||||||
if args.dataset == 'Dahoas/rm-static':
|
|
||||||
train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
|
|
||||||
valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
|
|
||||||
eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
|
|
||||||
elif args.dataset == 'Anthropic/hh-rlhf':
|
|
||||||
train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
|
|
||||||
valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
|
|
||||||
eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported dataset "{args.dataset}"')
|
|
||||||
|
|
||||||
trainer = RewardModelTrainer(model=model,
|
|
||||||
strategy=strategy,
|
|
||||||
optim=optim,
|
|
||||||
loss_fn = loss_fn,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
valid_dataset=valid_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
max_epochs=args.max_epochs)
|
|
||||||
|
|
||||||
trainer.fit()
|
|
||||||
# save model checkpoint after fitting on only rank0
|
|
||||||
strategy.save_model(trainer.model, args.save_path, only_rank0=True)
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
if args.need_optim_ckpt:
|
|
||||||
strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
|
||||||
default='naive')
|
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
|
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
|
||||||
parser.add_argument('--model_path', type=str, default=None)
|
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
|
||||||
parser.add_argument('--dataset', type=str,
|
|
||||||
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
|
|
||||||
default='Dahoas/rm-static')
|
|
||||||
parser.add_argument('--subset', type=str, default=None)
|
|
||||||
parser.add_argument('--save_path', type=str, default='rm_ckpt.pt')
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=1)
|
|
||||||
parser.add_argument('--batch_size', type=int, default=1)
|
|
||||||
parser.add_argument('--max_len', type=int, default=512)
|
|
||||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
|
|
||||||
parser.add_argument('--test', type=bool, default=False)
|
|
||||||
args = parser.parse_args()
|
|
||||||
train(args)
|
|
@ -1,8 +0,0 @@
|
|||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 1
|
|
||||||
|
|
||||||
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
|
|
||||||
--model 'deberta' \
|
|
||||||
--strategy naive \
|
|
||||||
--loss_fn 'log_exp'\
|
|
||||||
--save_path 'rmstatic.pt' \
|
|
||||||
--test True
|
|
@ -1,143 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import loralib as lora
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.bloom import BLOOMLM
|
|
||||||
from chatgpt.models.gpt import GPTLM
|
|
||||||
from chatgpt.models.opt import OPTLM
|
|
||||||
from chatgpt.models.llama import LlamaLM
|
|
||||||
from chatgpt.trainer import SFTTrainer
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
|
||||||
from chatgpt.utils import prepare_llama_tokenizer_and_embedding
|
|
||||||
from datasets import load_dataset
|
|
||||||
from torch.optim import Adam
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
from colossalai.logging import get_dist_logger
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
|
||||||
# configure strategy
|
|
||||||
if args.strategy == 'naive':
|
|
||||||
strategy = NaiveStrategy()
|
|
||||||
elif args.strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif args.strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
|
||||||
elif args.strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
|
||||||
|
|
||||||
# configure model
|
|
||||||
with strategy.model_init_context():
|
|
||||||
if args.model == 'bloom':
|
|
||||||
model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
|
||||||
elif args.model == 'opt':
|
|
||||||
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
|
||||||
elif args.model == 'gpt2':
|
|
||||||
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
|
||||||
elif args.model == 'llama':
|
|
||||||
model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
# configure tokenizer
|
|
||||||
if args.model == 'gpt2':
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'bloom':
|
|
||||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
elif args.model == 'opt':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
|
||||||
elif args.model == 'llama':
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
args.pretrain,
|
|
||||||
padding_side="right",
|
|
||||||
use_fast=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
|
||||||
|
|
||||||
if args.model == 'llama':
|
|
||||||
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
|
|
||||||
else:
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
|
|
||||||
max_len = 512
|
|
||||||
|
|
||||||
# configure optimizer
|
|
||||||
if args.strategy.startswith('colossalai'):
|
|
||||||
optim = HybridAdam(model.parameters(), lr=5e-5)
|
|
||||||
else:
|
|
||||||
optim = Adam(model.parameters(), lr=5e-5)
|
|
||||||
|
|
||||||
logger = get_dist_logger()
|
|
||||||
|
|
||||||
# configure dataset
|
|
||||||
if args.dataset == 'yizhongw/self_instruct':
|
|
||||||
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
|
|
||||||
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
|
|
||||||
|
|
||||||
train_dataset = SFTDataset(train_data, tokenizer, max_len)
|
|
||||||
eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
|
|
||||||
|
|
||||||
elif 'alpaca' in args.dataset:
|
|
||||||
train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
|
|
||||||
eval_dataset = None
|
|
||||||
data_collator = AlpacaDataCollator(tokenizer=tokenizer)
|
|
||||||
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
|
|
||||||
if eval_dataset is not None:
|
|
||||||
eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
|
|
||||||
else:
|
|
||||||
train_sampler = None
|
|
||||||
eval_sampler = None
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
|
|
||||||
if eval_dataset is not None:
|
|
||||||
eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
|
|
||||||
else:
|
|
||||||
eval_dataloader = None
|
|
||||||
|
|
||||||
trainer = SFTTrainer(model=model,
|
|
||||||
strategy=strategy,
|
|
||||||
optim=optim,
|
|
||||||
train_dataloader=train_dataloader,
|
|
||||||
eval_dataloader=eval_dataloader,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
max_epochs=args.max_epochs)
|
|
||||||
|
|
||||||
trainer.fit(logger=logger, use_lora=args.lora_rank, log_interval=args.log_interval)
|
|
||||||
|
|
||||||
# save model checkpoint after fitting on only rank0
|
|
||||||
strategy.save_model(model, 'sft_checkpoint.pt', only_rank0=True)
|
|
||||||
# save optimizer checkpoint on all ranks
|
|
||||||
strategy.save_optimizer(optim, 'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--strategy',
|
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
|
||||||
default='naive')
|
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
|
||||||
parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
|
|
||||||
parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
|
|
||||||
parser.add_argument('--max_epochs', type=int, default=1)
|
|
||||||
parser.add_argument('--batch_size', type=int, default=4)
|
|
||||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
|
|
||||||
args = parser.parse_args()
|
|
||||||
train(args)
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
|||||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|
||||||
local n=${1:-"9999"}
|
|
||||||
echo "GPU Memory Usage:"
|
|
||||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
|
||||||
| tail -n +2 \
|
|
||||||
| nl -v 0 \
|
|
||||||
| tee /dev/tty \
|
|
||||||
| sort -g -k 2 \
|
|
||||||
| awk '{print $1}' \
|
|
||||||
| head -n $n)
|
|
||||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
|
||||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
|
||||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
|
||||||
}
|
|
||||||
|
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
|
||||||
|
|
||||||
#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
|
|
||||||
#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
|
|
||||||
torchrun --standalone --nproc_per_node=8 train_sft.py \
|
|
||||||
--pretrain "/data/personal/nus-mql/LLAMA-7B" \
|
|
||||||
--model 'llama' \
|
|
||||||
--strategy colossalai_zero2 \
|
|
||||||
--log_interval 10 \
|
|
||||||
--save_path /data/personal/nus-mql/Coati-7B \
|
|
||||||
--dataset /data/personal/nus-mql/stanford_alpaca/alpaca_data.json
|
|
@ -1,6 +0,0 @@
|
|||||||
[pytest]
|
|
||||||
markers =
|
|
||||||
cpu: tests which can run on CPU
|
|
||||||
gpu: tests which requires a single GPU
|
|
||||||
dist: tests which are run in a multi-GPU or multi-machine environment
|
|
||||||
experiment: tests for experimental features
|
|
@ -1 +0,0 @@
|
|||||||
pytest
|
|
@ -1,7 +0,0 @@
|
|||||||
transformers>=4.20.1
|
|
||||||
tqdm
|
|
||||||
datasets
|
|
||||||
loralib
|
|
||||||
colossalai>=0.2.4
|
|
||||||
torch==1.12.1
|
|
||||||
langchain
|
|
@ -1,41 +0,0 @@
|
|||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_requirements(path):
|
|
||||||
with open(path, 'r') as fd:
|
|
||||||
return [r.strip() for r in fd.readlines()]
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_readme():
|
|
||||||
with open('README.md', encoding='utf-8') as f:
|
|
||||||
return f.read()
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_version():
|
|
||||||
with open('version.txt', 'r') as f:
|
|
||||||
return f.read().strip()
|
|
||||||
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name='chatgpt',
|
|
||||||
version=fetch_version(),
|
|
||||||
packages=find_packages(exclude=(
|
|
||||||
'tests',
|
|
||||||
'benchmarks',
|
|
||||||
'*.egg-info',
|
|
||||||
)),
|
|
||||||
description='A RLFH implementation (ChatGPT) powered by ColossalAI',
|
|
||||||
long_description=fetch_readme(),
|
|
||||||
long_description_content_type='text/markdown',
|
|
||||||
license='Apache Software License 2.0',
|
|
||||||
url='https://github.com/hpcaitech/ChatGPT',
|
|
||||||
install_requires=fetch_requirements('requirements.txt'),
|
|
||||||
python_requires='>=3.6',
|
|
||||||
classifiers=[
|
|
||||||
'Programming Language :: Python :: 3',
|
|
||||||
'License :: OSI Approved :: Apache Software License',
|
|
||||||
'Environment :: GPU :: NVIDIA CUDA',
|
|
||||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
|
||||||
'Topic :: System :: Distributed Computing',
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,98 +0,0 @@
|
|||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from chatgpt.models.gpt import GPTActor
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
|
||||||
|
|
||||||
|
|
||||||
def get_data(batch_size: int, seq_len: int = 10) -> dict:
|
|
||||||
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
|
|
||||||
def run_test_checkpoint(strategy):
|
|
||||||
BATCH_SIZE = 2
|
|
||||||
|
|
||||||
if strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif strategy == 'colossalai_gemini':
|
|
||||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
|
||||||
elif strategy == 'colossalai_zero2':
|
|
||||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
|
||||||
|
|
||||||
with strategy.model_init_context():
|
|
||||||
actor = GPTActor(config=GPT_CONFIG).cuda()
|
|
||||||
|
|
||||||
actor_optim = HybridAdam(actor.parameters())
|
|
||||||
|
|
||||||
actor, actor_optim = strategy.prepare((actor, actor_optim))
|
|
||||||
|
|
||||||
def run_step():
|
|
||||||
data = get_data(BATCH_SIZE)
|
|
||||||
action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
|
|
||||||
action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
|
|
||||||
loss = action_log_probs.sum()
|
|
||||||
strategy.backward(loss, actor, actor_optim)
|
|
||||||
strategy.optimizer_step(actor_optim)
|
|
||||||
|
|
||||||
run_step()
|
|
||||||
|
|
||||||
ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
|
|
||||||
|
|
||||||
with ctx as dirname:
|
|
||||||
rank0_dirname = [dirname]
|
|
||||||
dist.broadcast_object_list(rank0_dirname)
|
|
||||||
rank0_dirname = rank0_dirname[0]
|
|
||||||
|
|
||||||
model_path = os.path.join(rank0_dirname, 'model.pt')
|
|
||||||
optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
|
|
||||||
|
|
||||||
strategy.save_model(actor, model_path, only_rank0=True)
|
|
||||||
strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
|
|
||||||
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
strategy.load_model(actor, model_path, strict=False)
|
|
||||||
strategy.load_optimizer(actor_optim, optim_path)
|
|
||||||
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
run_step()
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, strategy):
|
|
||||||
os.environ['RANK'] = str(rank)
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
|
||||||
os.environ['MASTER_ADDR'] = 'localhost'
|
|
||||||
os.environ['MASTER_PORT'] = str(port)
|
|
||||||
run_test_checkpoint(strategy)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
|
||||||
@pytest.mark.parametrize('world_size', [2])
|
|
||||||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_checkpoint(world_size, strategy):
|
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_checkpoint(2, 'colossalai_zero2')
|
|
@ -1,122 +0,0 @@
|
|||||||
import os
|
|
||||||
from copy import deepcopy
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from chatgpt.experience_maker import NaiveExperienceMaker
|
|
||||||
from chatgpt.models.base import RewardModel
|
|
||||||
from chatgpt.models.gpt import GPTActor, GPTCritic
|
|
||||||
from chatgpt.replay_buffer import NaiveReplayBuffer
|
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
|
|
||||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
|
||||||
|
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
|
||||||
from colossalai.utils import free_port
|
|
||||||
|
|
||||||
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
|
|
||||||
|
|
||||||
|
|
||||||
def get_data(batch_size: int, seq_len: int = 10) -> dict:
|
|
||||||
input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
|
|
||||||
def gather_and_equal(tensor: torch.Tensor) -> bool:
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
outputs = [torch.empty_like(tensor) for _ in range(world_size)]
|
|
||||||
dist.all_gather(outputs, tensor.contiguous())
|
|
||||||
for t in outputs[1:]:
|
|
||||||
if not torch.equal(outputs[0], t):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def run_test_data(strategy):
|
|
||||||
EXPERINCE_BATCH_SIZE = 4
|
|
||||||
SAMPLE_BATCH_SIZE = 2
|
|
||||||
|
|
||||||
if strategy == 'ddp':
|
|
||||||
strategy = DDPStrategy()
|
|
||||||
elif strategy == 'colossalai':
|
|
||||||
strategy = ColossalAIStrategy(placement_policy='cuda')
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unsupported strategy "{strategy}"')
|
|
||||||
|
|
||||||
actor = GPTActor(config=GPT_CONFIG).cuda()
|
|
||||||
critic = GPTCritic(config=GPT_CONFIG).cuda()
|
|
||||||
|
|
||||||
initial_model = deepcopy(actor)
|
|
||||||
reward_model = RewardModel(deepcopy(critic.model)).cuda()
|
|
||||||
|
|
||||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
|
|
||||||
replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
|
|
||||||
|
|
||||||
# experience of all ranks should be the same
|
|
||||||
for _ in range(2):
|
|
||||||
data = get_data(EXPERINCE_BATCH_SIZE)
|
|
||||||
assert gather_and_equal(data['input_ids'])
|
|
||||||
assert gather_and_equal(data['attention_mask'])
|
|
||||||
experience = experience_maker.make_experience(**data,
|
|
||||||
do_sample=True,
|
|
||||||
max_length=16,
|
|
||||||
eos_token_id=50256,
|
|
||||||
pad_token_id=50256)
|
|
||||||
assert gather_and_equal(experience.sequences)
|
|
||||||
assert gather_and_equal(experience.action_log_probs)
|
|
||||||
assert gather_and_equal(experience.values)
|
|
||||||
assert gather_and_equal(experience.reward)
|
|
||||||
assert gather_and_equal(experience.advantages)
|
|
||||||
assert gather_and_equal(experience.action_mask)
|
|
||||||
assert gather_and_equal(experience.attention_mask)
|
|
||||||
replay_buffer.append(experience)
|
|
||||||
|
|
||||||
# replay buffer's data should be the same
|
|
||||||
buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
|
|
||||||
assert gather_and_equal(buffer_size)
|
|
||||||
for item in replay_buffer.items:
|
|
||||||
assert gather_and_equal(item.sequences)
|
|
||||||
assert gather_and_equal(item.action_log_probs)
|
|
||||||
assert gather_and_equal(item.values)
|
|
||||||
assert gather_and_equal(item.reward)
|
|
||||||
assert gather_and_equal(item.advantages)
|
|
||||||
assert gather_and_equal(item.action_mask)
|
|
||||||
assert gather_and_equal(item.attention_mask)
|
|
||||||
|
|
||||||
# dataloader of each rank should have the same size and different batch
|
|
||||||
dataloader = strategy.setup_dataloader(replay_buffer)
|
|
||||||
dataloader_size = torch.tensor([len(dataloader)], device='cuda')
|
|
||||||
assert gather_and_equal(dataloader_size)
|
|
||||||
for experience in dataloader:
|
|
||||||
assert not gather_and_equal(experience.sequences)
|
|
||||||
assert not gather_and_equal(experience.action_log_probs)
|
|
||||||
assert not gather_and_equal(experience.values)
|
|
||||||
assert not gather_and_equal(experience.reward)
|
|
||||||
assert not gather_and_equal(experience.advantages)
|
|
||||||
# action mask and attention mask may be same
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port, strategy):
|
|
||||||
os.environ['RANK'] = str(rank)
|
|
||||||
os.environ['LOCAL_RANK'] = str(rank)
|
|
||||||
os.environ['WORLD_SIZE'] = str(world_size)
|
|
||||||
os.environ['MASTER_ADDR'] = 'localhost'
|
|
||||||
os.environ['MASTER_PORT'] = str(port)
|
|
||||||
run_test_data(strategy)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
@pytest.mark.dist
|
|
||||||
@pytest.mark.parametrize('world_size', [2])
|
|
||||||
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
|
|
||||||
@rerun_if_address_is_in_use()
|
|
||||||
def test_data(world_size, strategy):
|
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
|
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_data(2, 'colossalai')
|
|
@ -1 +0,0 @@
|
|||||||
1.0.0
|
|
Loading…
Reference in New Issue
Block a user