Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-07-24 20:20:53 +00:00
[ColossalRL] Support ColossalRL on Ascend (#6324)
* [fix] support npu
* [feat] multinode 14B
* [feat] enlarge seqlen
* [fix]
* [fix] ready to updated
* [fix] ready to merge grpo-latest
* [fix] rm comments
* [feat] support msprof-analyze, add analsys result
* [feat] support ColossalaiRL on Ascend
* [feat] rm comments in qwen modeling
* [Doc] Drafted README.md
* [feat] fix ascend readme format
* [fix] fix readme
* [fix] fix Readme, rm irrelevant testcase
* [fix] fix some adapt modification
* [fix] rm comments in modeling qwen
* [fix] rm comm, test and debug print
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
parent de2ad3b206
commit e1ca2d22ae
@@ -1,6 +1,291 @@
# Requirements
# Distributed RL Framework for Language Model Fine-Tuning

This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models with algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization with inference libraries such as vLLM.

---

## 🚀 Features

* **Distributed Training with Ray**: Scales to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.

---

## 🛠 Installation

### Prepare the Development Environment

Install ColossalAI & ColossalChat:
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest-ascend
pip install -e .

cd ./applications/ColossalChat
pip install -e .
```

Install Fuyao Ray. Please update CANN before installing Fuyao Ray:
```bash
# Install CANN
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel

# Clone Fuyao Ray. Fuyao Ray is not an open-source project; it is shipped in the ColossalRL images.
git clone https://gitee.com/openfuyao/ray.git
cd ray
git pull origin pull/5/head

# Install ray
pip install ray==2.43.0 --no-cache-dir

# Create a soft link from fuyao-ray to the ray site-package
cd ..
ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray

# Install Fuyao Ray
cd ray
python python/ray/setup-dev.py
```
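To confirm that the Fuyao Ray build is picked up correctly, a quick sanity check such as the following can help. This is a minimal sketch; it only assumes that `ray` is importable and that a local instance can start on the current node:

```python
import ray

# Start a throwaway local Ray instance and print basic cluster info.
ray.init(ignore_reinit_error=True)
print("Ray version:", ray.__version__)                # expect 2.43.0
print("Cluster resources:", ray.cluster_resources())  # NPUs should appear here once registered
ray.shutdown()
```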
Prepare the model & dataset:
```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```

### Set Distributed Config
Now we need to set the distributed config for multi-node training.

First, set the host IP config. For example, to configure a cluster of 4 nodes, open the hosts file:
```bash
vim /etc/hosts
```
Then write the IP-to-node map to /etc/hosts:
```bash
10.0.0.3 npu-3
10.0.0.4 npu-4
10.0.0.5 npu-5
10.0.0.6 npu-6
```

Set the Ascend multi-node config:

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
```
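Before launching, it is worth verifying on every node that the variables above are actually set in the shell that will start Ray; otherwise failures tend to surface much later as HCCL timeouts. A minimal sketch (the variable list simply mirrors the export block above):

```python
import os

REQUIRED_VARS = [
    "ATB_LLM_HCCL_ENABLE",
    "ATB_LLM_COMM_BACKEND",
    "HCCL_CONNECT_TIMEOUT",
    "WORLD_SIZE",
    "HCCL_EXEC_TIMEOUT",
    "HCCL_SOCKET_IFNAME",
    "RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS",
]

# Report any missing variable up front instead of failing later inside HCCL.
missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
if missing:
    raise SystemExit(f"Missing environment variables: {missing}")
print("All Ascend/HCCL variables are set.")
```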
## 🧠 Data Format

Each data sample in the training or evaluation `.jsonl` file should follow this format:

```json
{
  "messages": {
    "role": "user",
    "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}."
  },
  "gt_answer": "3"
}
```
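A quick way to catch malformed samples before a long run is to scan the file and check the two required fields. A minimal sketch, assuming only the format shown above (the path is a placeholder):

```python
import json

def check_dataset(path: str) -> None:
    """Verify that every line has a user message and a ground-truth answer."""
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            sample = json.loads(line)
            assert "messages" in sample and "gt_answer" in sample, f"line {line_no}: missing field"
            assert sample["messages"]["role"] == "user", f"line {line_no}: expected a user message"

check_dataset("/path/to/train_data.jsonl")  # placeholder path
```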
---

## ⚙️ Hyperparameters & Arguments

| Argument | Description | Example |
| ---------------- | ------------------------------------ | --------------------------- |
| `--model` | Model path or identifier | `/path/to/model` |
| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` |
| `--eval-dataset` | JSON of task:eval_dataset_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` |
| `--project` | Project name | `Project1` |
| `--num-episodes` | Number of training episodes | `1` |

### Distributed Training

| Argument | Description | Example |
| ----------------------------- | ------------------------------------- | ------- |
| `--num-trainers` | Number of trainer processes | `4` |
| `--num-inferencer` | Number of inferencer processes | `4` |
| `--inference-batch-size` | Prompts per inference step | `8` |
| `--inference-microbatch-size` | Per-GPU batch size for inference | `8` |
| `--train-batch-size` | Prompts per trainer step per dp group | `8` |
| `--train-minibatch-size` | Mini-batch size before forward pass | `8` |
| `--train-microbatch-size` | Per-GPU batch size for training | `2` |

### Sampling

| Argument | Description | Example |
| --------------------- | -------------------------------------------------------------------------------- | -------------- |
| `--backend` | Generation backend, choose from `vllm` | `vllm` |
| `--temperature` | Sampling temperature for generation | `1.0` |
| `--top-k` | Top-K sampling parameter for generation | `None` |
| `--top-p` | Top-P sampling parameter for generation | `1.0` |
| `--system-prompt` | System prompt, defaults to the system prompt for the `think_answer_tags` format | `Please reason step by step, and put your final answer within \\boxed{}.` |
| `--max-new-tokens` | Max generation tokens | `3584` |
| `--max-prompt-tokens` | Max prompt tokens | `512` |
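For the `vllm` backend, these sampling flags end up as vLLM sampling settings (the rollout code in this repo imports `SamplingParams` and reads `max_tokens` from the generate config). A rough, illustrative sketch of the correspondence with example values only:

```python
from vllm import SamplingParams

# Illustrative mapping from the CLI flags above to vLLM sampling settings.
sampling_params = SamplingParams(
    temperature=1.0,   # --temperature
    top_k=-1,          # --top-k (None on the CLI is treated as -1, i.e. disabled)
    top_p=1.0,         # --top-p
    max_tokens=3584,   # --max-new-tokens
)
```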
### GRPO Specific

| Argument | Description | Example |
| ----------------- | ---------------------------- | ------------------- |
| `--algo` | Algorithm (`GRPO` or `DAPO`); for more customization refer to [GRPO Settings](#️-grpo-settings) | `GRPO` |
| `--learning-rate` | Learning rate | `1e-6` |
| `--kl-coeff` | KL penalty coefficient | `0.01` |
| `--reward-type` | Reward signal type (choose from `think_answer_tags`, `boxed`) | `think_answer_tags` |
| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `100` |

### Logging and Checkpointing

| Argument | Description | Example |
| -------------------- | ---------------------------------- | ------------ |
| `--save-interval` | Training steps between checkpoints | `20` |
| `--save-dir` | Checkpoint directory | `./model` |
| `--eval-save-dir` | Evaluation save path | `./eval` |
| `--rollout-save-dir` | Rollout logs directory | `./rollouts` |

### Miscellaneous

| Argument | Description | Example |
| ------------------ | ------------------------------------------------------ | ------- |
| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` |
| `--master_address` | Master address of a running Ray cluster | `None` |
| `--master_port` | Master port for torch DDP | `29506` |

---

## ⚙️ GRPO Settings

In addition to the two default training settings we provide (the original `GRPO` and `DAPO`), users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py`.

| Argument Name | Description | Default |
| ----------------------------- | ---------------------------------------------- | ----------------------------------------- |
| `filter_range` | Filters out a rollout group if the success rate within that group is outside this range. | `[0.01, 0.99]` |
| `dynamic_batching` | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `clip_eps_low` | epsilon_low in the clipping equation of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.2` |
| `clip_eps_high` | epsilon_high in the clipping equation of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.28` |
| `skip_threshold` | If the importance ratio is above this threshold, the sample is skipped to avoid instability. | `20.0` |
| `loss_variation` | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `"sample_level"` for the original GRPO loss. | `"token_level"` |
| `soft_over_length_punishment` | Whether to use the soft overlength penalty from the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `cache_length` | `L_cache` parameter for the soft overlength penalty in Eq. 13 of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `min(1024, int(args.max_new_tokens / 4))` |
| `filter_truncated_response` | Mask out truncated responses in the loss calculation. | `True` |
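As a reference, a customized `grpo_config` in `rl_example.py` might look like the sketch below. It simply restates the defaults from the table; the exact set of accepted keys is defined by the trainer code in this repository, so treat it as illustrative rather than exhaustive:

```python
# Illustrative override of the DAPO-style defaults listed above.
grpo_config = {
    "filter_range": [0.01, 0.99],
    "dynamic_batching": True,
    "clip_eps_low": 0.2,
    "clip_eps_high": 0.28,
    "skip_threshold": 20.0,
    "loss_variation": "token_level",
    "soft_over_length_punishment": True,
    "cache_length": min(1024, 3584 // 4),  # i.e. min(1024, max_new_tokens // 4) with --max-new-tokens 3584
    "filter_truncated_response": True,
}
```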
## 🔄 Constraints and Notes

* `num_inferencer + num_trainer == NUM_GPUs`
* `num_inferencer % num_trainer == 0`
* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0`
* `train_batch_size >= train_minibatch_size >= train_microbatch_size`
* `inference_batch_size >= inference_microbatch_size`
* Set microbatch sizes based on **VRAM capacity**.
* To use tensor parallelism on the inferencer:
  * set the backend to `vllm`
  * change `tensor_parallel_size` in `inference_model_config` in rl_example.py
  * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size`
* To set tensor parallelism / pipeline parallelism / zero stage:
  * change the corresponding settings in `plugin_config` in rl_example.py
* Ensure the rollout generation rate matches trainer consumption (see the sketch after this list):

  ```
  num_inferencer * inference_batch_size % (
      num_trainer * train_batch_size /
      train_pipeline_parallelism_size /
      train_tensor_parallelism_size
  ) == 0
  ```
* Model weights sync every:

  ```
  (num_inferencer * inference_batch_size) /
  (num_trainer * train_batch_size /
   train_pipeline_parallelism_size /
   train_tensor_parallelism_size)
  ```
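The divisibility and ordering rules above are easy to get wrong when juggling many flags, so a small pre-flight check can save a failed launch. A minimal sketch under the assumption that the values are collected into plain Python variables (the names mirror the constraint expressions, not the actual CLI parsing):

```python
def check_launch_config(
    num_inferencer: int,
    num_trainer: int,
    num_gpus: int,
    inference_batch_size: int,
    inference_microbatch_size: int,
    train_batch_size: int,
    train_minibatch_size: int,
    train_microbatch_size: int,
    train_pp_size: int = 1,
    train_tp_size: int = 1,
) -> None:
    # Resource split must cover exactly all devices.
    assert num_inferencer + num_trainer == num_gpus
    assert num_inferencer % num_trainer == 0
    # Rollout production must divide evenly into trainer consumption.
    assert (num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0
    # Batch-size hierarchy.
    assert train_batch_size >= train_minibatch_size >= train_microbatch_size
    assert inference_batch_size >= inference_microbatch_size
    # Generation rate vs. consumption rate with parallelism folded in.
    consumed_per_step = num_trainer * train_batch_size / train_pp_size / train_tp_size
    assert (num_inferencer * inference_batch_size) % consumed_per_step == 0
    sync_every = (num_inferencer * inference_batch_size) / consumed_per_step
    print(f"Config OK. Model weights sync every {sync_every:g} trainer steps.")

check_launch_config(4, 4, 8, 8, 8, 8, 8, 2)  # example values from the argument tables above
```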
---

## 🧪 Example: Single-Machine 8-GPU ZeRO-2 Strategy

```bash
python rl_example.py \
--dataset /path/to/train_data.jsonl \
--model /path/to/Qwen2.5-3B/ \
-t 4 -i 4 \
-b vllm \
-rt boxed \
-g 4 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-s "Please reason step by step, and put your final answer within \\boxed{}." \
-p GRPO-Train-Align-Debug
```
## 🧪 Example: Multi-Machine TP+PP Strategy

### Create a Ray cluster across machines
For example, suppose we have 4 nodes with IPs 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6, and we use 10.0.0.3 as the master node. First, start a Ray cluster on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```

Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), join the Ray cluster with the following command:
```bash
ray start --address='10.0.0.3:6379'
```
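Once all nodes have joined, it is worth confirming from the master node that Ray actually sees every machine and its NPUs before launching training. A minimal sketch; it assumes the cluster above is already running and that the devices are registered as the `NPU` resource, as the launch code in this commit expects:

```python
import ray

# Attach to the cluster started with `ray start --head` above.
ray.init(address="auto")

# List every alive node and the NPUs it contributes.
for node in ray.nodes():
    if node["Alive"]:
        npus = node["Resources"].get("NPU", 0)
        print(f'{node["NodeManagerAddress"]}: {npus} NPU(s)')
```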
Modify `plugin_config` in ./applications/ColossalChat/rl_example.py:
```python
plugin_config={
    "tp_size": 4,
    "pp_size": 2,
    "microbatch_size": max(
        1, args.train_microbatch_size // 2
    ),  # microbatch size should be set to train_microbatch_size // pp_size
    "zero_stage": 1,
    "max_norm": 1.0,
},  # for pp, tp
```

```bash
# Hint 1: replace /models/Qwen/Qwen2.5-7B with your model path
#         and replace /datasets/train-alignment.jsonl with your dataset path
python rl_example.py \
-m /path/to/Qwen2.5-Math-7B/ \
-d /path/to/train_data.jsonl \
--master_address '10.0.0.3' \
-t 16 \
-i 16 \
-p GRPO-Train-Align-Debug \
-g 2 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-b vllm \
-e 2 \
-rt boxed \
-s "Please reason step by step, and put your final answer within \\boxed{}."
```

## Acknowledgement

---
@@ -13,7 +13,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
@@ -33,6 +32,7 @@ class BaseConsumer:
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
generate_config: Dict[str, Any],
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
@@ -55,8 +55,9 @@ class BaseConsumer:
self.model_config = model_config
self.plugin_config = plugin_config

self.device = get_current_device()
self.device = "npu"
self.lr_scheduler = None
self.generate_config = generate_config

def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -73,24 +74,28 @@ class BaseConsumer:
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.sp_rank = dist.get_rank(self.plugin.sp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)

self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.sp_size = dist.get_world_size(self.plugin.sp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)

# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
if self.pp_size > 1:
# use hybrid tp + pp
if self.tp_rank == 0 and self.dp_rank == 0:
cc.init_collective_group(
self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
)
else:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
cc.init_collective_group(
self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
)

self.buffer = []
self.recv_cnt = 0

@@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
self.max_length = generate_config["max_tokens"]

@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -80,9 +80,42 @@ def launch_distributed(
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)

procs = []
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
# "num_gpus": node["Resources"].get("GPU", 0),
"num_gpus": node["Resources"].get("NPU", 0),
"address": node["NodeManagerAddress"],
}  # Default to 0 if no GPUs are available
for node in nodes
}
gpu_to_node_id = []
gpu_to_ip_address = []
for node_id in node_info:
for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
gpu_to_node_id.append(node_id)
gpu_to_ip_address.append(node_info[node_id]["address"])

producer_procs = []

for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
for _ in range(num_proc_per_producer):
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")

producer = SimpleProducer.options(
# num_cpus=1,
# num_cpus=num_proc_per_producer,
num_gpus=0,
resources={"NPU": num_proc_per_producer},
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@@ -108,20 +141,35 @@ def launch_distributed(
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
procs.append(producer)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote(
node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
consumer = core_consumer.options(
resources={"NPU": 1},
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
master_addr=master_addr,
master_addr=consumer_master_ip_address,
master_port=master_port,
num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update,
@@ -138,6 +186,6 @@ def launch_distributed(
run_name=run_name,
wandb_group_name=wandb_group_name,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
ray.get([p.loop.remote() for p in procs])
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
@@ -11,15 +11,13 @@ import wandb
from coati.dataset.loader import RawConversationDataset
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send, safe_append_to_jsonl_file
from .utils import safe_append_to_jsonl_file

try:
from vllm import SamplingParams
@@ -150,7 +148,8 @@ class BaseProducer:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()

self.device = "npu"

# init backend
if backend in BACKEND_MAP:
@@ -162,17 +161,15 @@ class BaseProducer:

def setup(self) -> None:
cc.init_collective_group(
world_size=self.num_producers,
rank=self.producer_idx,
backend=Backend.NCCL,
group_name="producer_group",
1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
)
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
cc.init_collective_group(
self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
)
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")

def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@@ -250,7 +247,6 @@ class BaseProducer:
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)
@@ -1,9 +1,9 @@
transformers==4.39.3
transformers==4.47.0
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.7
torch>=2.1.0
torch==2.5.1
langchain
tokenizers
fastapi
@@ -22,3 +22,9 @@ sentencepiece==0.1.99
flash-attn
tiktoken
jsonlines
math-verify==0.7.0

# The following packages be built into the image.
# torch_npu==2.5.1
# fuyao-ray==2.43.0
# vllm-ascend==0.7.3

@@ -151,7 +151,9 @@ if __name__ == "__main__":
args.top_k = -1

inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
train_model_config = dict(
path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

if args.backend == "transformers":
@@ -247,17 +249,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
},  # for zero
# plugin_config={
#   "tp_size": 2,
#   "pp_size": 2,
#   "microbatch_size": max(
#       1, args.train_microbatch_size // 2
#   ),  # microbatch size should be set to train_microbatch_size // pp_size
#   "zero_stage": 0,
#   "max_norm": 1.0,
# },  # for pp, tp
"tp_size": 2,
"pp_size": 2,
"microbatch_size": max(
1, args.train_microbatch_size // 2
),  # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": 1,
"max_norm": 1.0,
},  # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,

@@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
), f"Number of microbatch should be larger than number of stages"

if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
@@ -2,6 +2,7 @@ import math
from typing import List, Optional, Tuple, Union

import torch
import torch_npu
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
@@ -143,14 +144,8 @@ class Qwen2PipelineForwards:
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
(batch_size, 1, seq_length, seq_length_with_past)
attention_mask = None
else:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
@@ -488,6 +483,144 @@ class Qwen2PipelineForwards:
return {"hidden_states": hidden_states}


def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"

bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all comminucation when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()

query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window

past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()

if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)

if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if shard_config.enable_flash_attention:
scale = 1.0 / math.sqrt(query_states.shape[-1])
attn_output = torch_npu.npu_fusion_attention(
query_states,
key_states,
value_states,
head_num=query_states.size(1),
input_layout="BNSD",
sparse_mode=1,
atten_mask=None,
scale=scale,
pre_tockens=65536,
next_tockens=65536,
)
attn_output = attn_output[0]
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

return attn_output, None, past_key_value

return forward


def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self: Qwen2Attention,
@@ -824,7 +957,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
force_sp_output_gather=False,
)

hidden_states = outputs[0]

@@ -19,7 +19,7 @@ from colossalai.shardformer.layer import (
from ..modeling.qwen2 import (
Qwen2PipelineForwards,
get_lm_forward_with_dist_cross_entropy,
get_qwen2_flash_attention_forward,
get_qwen2_flash_attention_npu_forward,
get_qwen2_model_forward_for_flash_attn,
)

@@ -304,7 +304,7 @@ class Qwen2Policy(Policy):
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
"forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,