[ColossalRL] Support ColossalRL on Ascend (#6324)

* [fix] support npu

* [feat] multinode 14B

* [feat] enlarge seqlen

* [fix]

* [fix] ready to be updated

* [fix] ready to merge grpo-latest

* [fix] rm comments

* [feat] support msprof-analyze, add analysis result

* [feat] support ColossalRL on Ascend

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feat] rm comments in qwen modeling

* [Doc] Drafted README.md

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feat] fix ascend readme format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix Readme, rm irrelevant testcase

* [fix] fix some adapt modification

* [fix] rm comments in modeling qwen

* [fix] rm comm, test and debug print

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
duanjunwen 2025-05-28 10:43:13 +08:00 committed by GitHub
parent de2ad3b206
commit e1ca2d22ae
10 changed files with 527 additions and 55 deletions

View File

@@ -1,6 +1,291 @@
# Distributed RL Framework for Language Model Fine-Tuning
This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models with algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization with inference libraries such as vLLM.
---
## 🚀 Features
* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.
---
## 🛠 Installation
### Prepare the Development Environment
Install ColossalAI & ColossalChat:
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest-ascend
pip install -e .
cd ./applications/ColossalChat
pip install -e .
```
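A quick sanity check of the editable installs is to import both packages (a minimal sketch; ColossalChat installs the `coati` package):
```python
import colossalai
import coati  # installed from ./applications/ColossalChat

print("colossalai:", colossalai.__version__)
```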
Install Fuyao Ray.
Please update CANN before installing Fuyao Ray:
```bash
# Install CANN
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel
# Clone Fuyao Ray. Fuyao Ray is not an open-source project; it will be included in the ColossalRL images.
git clone https://gitee.com/openfuyao/ray.git
cd ray
git pull origin pull/5/head
# Install ray
pip install ray==2.43.0 --no-cache-dir
# Create soft-link from fuyao-ray to ray site-package
cd ..
ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray
# Install Fuyao Ray
cd ray
python python/ray/setup-dev.py
```
Prepare the model & dataset:
```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```
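To confirm the download is usable before training, you can load the config and tokenizer locally (a minimal sketch; the path matches the command above):
```python
from transformers import AutoConfig, AutoTokenizer

model_dir = "/models/Qwen/Qwen2.5-7B"
config = AutoConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
print(config.model_type, "vocab size:", tokenizer.vocab_size)
```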
### Set Distributed Config
Now we need to set the distributed config for multi-node training.
First, set the host IP config. For example, to configure a cluster of 4 nodes, edit the hosts file:
```bash
vim /etc/hosts
```
Then write the IP-to-node mapping to `/etc/hosts`:
```bash
10.0.0.3 npu-3
10.0.0.4 npu-4
10.0.0.5 npu-5
10.0.0.6 npu-6
```
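After editing `/etc/hosts`, a quick check that every node name resolves (a small sketch using the hostnames above):
```python
import socket

for host in ["npu-3", "npu-4", "npu-5", "npu-6"]:
    print(host, "->", socket.gethostbyname(host))
```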
Set the Ascend multi-node config:
```bash
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
```
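Assuming these variables need to be present in the environment of every node before Ray is started, the following sketch checks that they are set in the current shell:
```python
import os

required = [
    "ATB_LLM_HCCL_ENABLE",
    "ATB_LLM_COMM_BACKEND",
    "HCCL_CONNECT_TIMEOUT",
    "WORLD_SIZE",
    "HCCL_EXEC_TIMEOUT",
    "HCCL_SOCKET_IFNAME",
    "RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS",
]
missing = [name for name in required if name not in os.environ]
assert not missing, f"Missing environment variables: {missing}"
```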
## 🧠 Data Format
Each data sample in the training or evaluation `.jsonl` file should follow this format:
```json
{
"messages": {
"role": "user",
"content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}."
},
"gt_answer": "3"
}
```
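For example, a training `.jsonl` file can be assembled from (question, answer) pairs with a few lines of Python (a sketch following the format above; the output file name is illustrative):
```python
import json

samples = [
    {
        "messages": {
            "role": "user",
            "content": (
                "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. "
                "Let's think step by step and output the final answer within \\boxed{}."
            ),
        },
        "gt_answer": "3",
    },
]

# Write one JSON object per line, as expected by the training/eval loaders.
with open("train_data.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```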
---
## ⚙️ Hyperparameters & Arguments
| Argument | Description | Example |
| ---------------- | --------------------------------------- | ----------------- |
| `--model` | Model path or identifier | `/path/to/model` |
| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` |
| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` |
| `--project` | Project name | `Project1` |
| `--num-episodes` | Number of training episodes | `1` |
### Distributed Training
| Argument | Description | Example |
| ----------------------------- | ------------------------------------- | ------- |
| `--num-trainers` | Number of trainer processes | `4` |
| `--num-inferencer` | Number of inferencer processes | `4` |
| `--inference-batch-size` | Prompts per inference step | `8` |
| `--inference-microbatch-size` | Per-GPU batch size for inference | `8` |
| `--train-batch-size` | Prompts per trainer step per dp group | `8` |
| `--train-minibatch-size` | Mini-batch size before forward pass | `8` |
| `--train-microbatch-size` | Per-GPU batch size for training | `2` |
### Sampling
| Argument | Description | Example |
| --------------------- | --------------------- | -------------- |
| `--backend` | Generation backend, choose from `vllm` | `vllm` |
| `--temperature` | Sampling temperature for generation | `1.0` |
| `--top-k` | Top-K sampling parameter for generation | `None` |
| `--top-p` | Top-P sampling parameter for generation | `1.0` |
| `--system-prompt` | System prompt; defaults to the prompt for the `think_answer_tags` format | `Please reason step by step, and put your final answer within \\boxed{}.` |
| `--max-new-tokens` | Max generation tokens | `3584` |
| `--max-prompt-tokens` | Max prompt tokens | `512` |
### GRPO Specific
| Argument | Description | Example |
| ----------------- | ---------------------------- | ------------------- |
| `--algo` | Algorithm (`GRPO` or `DAPO`), for more customization refer to [GRPO Settings](#-grpo-settings) | `GRPO` |
| `--learning-rate` | Learning rate | `1e-6` |
| `--kl-coeff` | KL penalty coefficient | `0.01` |
| `--reward-type` | Reward signal type (choose from 'think_answer_tags', 'boxed') | `think_answer_tags` |
| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `100` |
### Logging and Checkpointing
| Argument | Description | Example |
| -------------------- | ------------------------- | ------------ |
| `--save-interval` | Training steps between checkpoints | `20` |
| `--save-dir` | Checkpoint directory | `./model` |
| `--eval-save-dir` | Evaluation save path | `./eval` |
| `--rollout-save-dir` | Rollout logs directory | `./rollouts` |
### Miscellaneous
| Argument | Description | Example |
| ------------------ | --------------------------------------- | ------- |
| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` |
| `--master_address` | Master address of a running Ray cluster | `None` |
| `--master_port` | Master port for torch DDP | `29506` |
---
## ⚙️ GRPO Settings
In addition to the two default training settings provided (original `GRPO` and `DAPO`), users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py` (see the sketch after the table).
| Argument Name | Description | Default |
| ----------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `filter_range` | Filters out rollout group if the success rate within that group is out of this range.| `[0.01, 0.99]` |
| `dynamic_batching` | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `clip_eps_low` | `epsilon_low` in the DAPO clipped objective (see the [DAPO paper](https://arxiv.org/abs/2503.14476)) | `0.2` |
| `clip_eps_high` | `epsilon_high` in the DAPO clipped objective (see the [DAPO paper](https://arxiv.org/abs/2503.14476)) | `0.28` |
| `skip_threshold` | If ratio is above this threshold, the sample is skipped to avoid instability. | `20.0` |
| `loss_variation` | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `sample_level` for original GRPO loss. | `"token_level"` |
| `soft_over_length_punishment` | Whether to use soft overlength penalty in [DAPO paper](https://arxiv.org/abs/2503.14476) or not. | `True` |
| `cache_length` | `L_cache` parameter for the soft overlength penalty in Eq. 13 of the [DAPO paper](https://arxiv.org/abs/2503.14476) | `min(1024, int(args.max_new_tokens / 4))` |
| `filter_truncated_response` | Mask out truncated responses in loss calculation. | `True` |
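
As an illustration, a DAPO-style customization could look like the sketch below. The keys follow the table above, the values are illustrative, and the dict is meant to be merged into `grpo_config` in `rl_example.py`:
```python
max_new_tokens = 3584  # matches --max-new-tokens in the sampling arguments above

custom_grpo_settings = {
    "filter_range": [0.01, 0.99],          # drop groups whose success rate falls outside this range
    "dynamic_batching": True,              # DAPO-style dynamic batching
    "clip_eps_low": 0.2,                   # epsilon_low of the clipped objective
    "clip_eps_high": 0.28,                 # epsilon_high of the clipped objective
    "skip_threshold": 20.0,                # skip samples whose importance ratio exceeds this value
    "loss_variation": "token_level",       # or "sample_level" for the original GRPO loss
    "soft_over_length_punishment": True,   # soft overlength penalty (DAPO Eq. 13)
    "cache_length": min(1024, max_new_tokens // 4),
    "filter_truncated_response": True,     # mask out truncated responses in the loss
}
# e.g. grpo_config.update(custom_grpo_settings) inside rl_example.py
```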
## 🔄 Constraints and Notes
* `num_inferencer + num_trainer == NUM_GPUs`
* `num_inferencer % num_trainer == 0`
* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0`
* `train_batch_size >= train_minibatch_size >= train_microbatch_size`
* `inference_batch_size >= inference_microbatch_size`
* Set microbatch sizes based on **VRAM capacity**
* To use tensor parallelism on the inferencer
  * set backend to `vllm`
  * change `tensor_parallel_size` in `inference_model_config` in `rl_example.py`
  * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size`
* To set tensor parallelism / pipeline parallelism / ZeRO stage
  * change the corresponding settings in `plugin_config` in `rl_example.py`
* Ensure rollout generation rate matches trainer consumption:
```
num_inferencer * inference_batch_size % (
num_trainer * train_batch_size /
train_pipeline_parallelism_size /
train_tensor_parallelism_size
) == 0
```
* Model weights are synced every:
```
(num_inferencer * inference_batch_size) /
(num_trainer * train_batch_size /
train_pipeline_parallelism_size /
train_tensor_parallelism_size)
```
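The constraints above can be checked before launching with a small helper (a hypothetical sketch, not part of the repo; the default values mirror the examples in the argument tables):
```python
def check_rl_config(
    num_gpus=8,
    num_inferencer=4,
    num_trainer=4,
    inference_batch_size=8,
    inference_microbatch_size=8,
    train_batch_size=8,
    train_minibatch_size=8,
    train_microbatch_size=2,
    train_pp_size=1,
    train_tp_size=1,
):
    assert num_inferencer + num_trainer == num_gpus
    assert num_inferencer % num_trainer == 0
    assert (num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0
    assert train_batch_size >= train_minibatch_size >= train_microbatch_size
    assert inference_batch_size >= inference_microbatch_size

    # rollout generation rate must match trainer consumption
    consumed_per_step = num_trainer * train_batch_size // train_pp_size // train_tp_size
    produced_per_step = num_inferencer * inference_batch_size
    assert produced_per_step % consumed_per_step == 0
    print("model weights sync every", produced_per_step // consumed_per_step, "rollout step(s)")


check_rl_config()
```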
---
## 🧪 Example: Single-Machine 8-GPU ZeRO-2 Strategy
```bash
python rl_example.py \
   --dataset /path/to/train_data.jsonl \
   --model /path/to/Qwen2.5-3B/ \
   -t 4 -i 4 \
   -b vllm \
   -rt boxed \
   -g 4 \
   -ibs 1 \
   -tbs 2 \
   -tMbs 1 \
   -tmbs 2 \
   -imbs 1 \
   -s "Please reason step by step, and put your final answer within \\boxed{}." \
   -p GRPO-Train-Align-Debug
```
## 🧪 Example: Multi-Machine TP+PP Strategy
### Create a Ray Cluster Across Machines
For example, suppose we have 4 nodes whose IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6, and we use 10.0.0.3 as the master node. First, start a Ray cluster on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```
Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), join the Ray cluster with:
```bash
ray start --address='10.0.0.3:6379'
```
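Once every node has joined, you can verify that all nodes and their NPUs are visible to Ray (a minimal check, run from any node in the cluster):
```python
import ray

ray.init(address="auto")  # attach to the running cluster started above
for node in ray.nodes():
    if node["Alive"]:
        print(node["NodeManagerAddress"], "NPUs:", node["Resources"].get("NPU", 0))
```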
Modify `plugin_config` in `./applications/ColossalChat/rl_example.py`:
```python
plugin_config={
"tp_size": 4,
"pp_size": 2,
"microbatch_size": max(
1, args.train_microbatch_size // 2
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": 1,
"max_norm": 1.0,
}, # for pp, tp
```
```bash
# Hint 1: replace /models/Qwen/Qwen2.5-7B with your model path
# Hint 2: replace /datasets/train-alignment.jsonl with your dataset path
python rl_example.py \
   -m /path/to/Qwen2.5-Math-7B/ \
   -d /path/to/train_data.jsonl \
   --master_address '10.0.0.3' \
-t 16 \
-i 16 \
-p GRPO-Train-Align-Debug \
-g 2 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-b vllm \
-e 2 \
-rt boxed \
-s "Please reason step by step, and put your final answer within \\boxed{}."
```
## Acknowledgement
---

View File

@@ -13,7 +13,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch
@@ -33,6 +32,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
@@ -55,8 +55,9 @@ class BaseConsumer:
         self.model_config = model_config
         self.plugin_config = plugin_config

-        self.device = get_current_device()
+        self.device = "npu"
         self.lr_scheduler = None
+        self.generate_config = generate_config

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -73,24 +74,28 @@ class BaseConsumer:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.sp_rank = dist.get_rank(self.plugin.sp_group)
         self.pp_rank = dist.get_rank(self.plugin.pp_group)

         self.dp_size = dist.get_world_size(self.plugin.dp_group)
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.sp_size = dist.get_world_size(self.plugin.sp_group)
         self.pp_size = dist.get_world_size(self.plugin.pp_group)

         # Init Hybrid ray process group
         for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
         if self.pp_size > 1:
             # use hybrid tp + pp
             if self.tp_rank == 0 and self.dp_rank == 0:
                 cc.init_collective_group(
-                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                    self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
                 )
         else:
             if self.rank == 0:
-                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
+                cc.init_collective_group(
+                    self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
+                )
         self.buffer = []

         self.recv_cnt = 0

View File

@@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
+        self.max_length = generate_config["max_tokens"]

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

View File

@@ -80,9 +80,42 @@ def launch_distributed(
         f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
     )

-    procs = []
+    nodes = ray.nodes()
+    node_info = {
+        node["NodeID"]: {
+            # "num_gpus": node["Resources"].get("GPU", 0),
+            "num_gpus": node["Resources"].get("NPU", 0),
+            "address": node["NodeManagerAddress"],
+        }  # Default to 0 if no GPUs are available
+        for node in nodes
+    }
+    gpu_to_node_id = []
+    gpu_to_ip_address = []
+    for node_id in node_info:
+        for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
+            gpu_to_node_id.append(node_id)
+            gpu_to_ip_address.append(node_info[node_id]["address"])
+
+    producer_procs = []
+
     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        node_id = gpu_to_node_id[0]
+        producer_ip_address = gpu_to_ip_address[0]
+        for _ in range(num_proc_per_producer):
+            gpu_to_node_id.pop(0)
+            gpu_to_ip_address.pop(0)
+        print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
+        producer = SimpleProducer.options(
+            # num_cpus=1,
+            # num_cpus=num_proc_per_producer,
+            num_gpus=0,
+            resources={"NPU": num_proc_per_producer},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
@@ -108,20 +141,35 @@ def launch_distributed(
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
+    ray.get([p.setup.remote() for p in producer_procs])
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
+    consumer_master_ip_address = gpu_to_ip_address[0]
+    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+    consumer_procs = []
     for i in range(num_consumer_procs):
-        consumer = core_consumer.options(num_gpus=1).remote(
+        node_id = gpu_to_node_id[0]
+        consumer_ip_address = gpu_to_ip_address[0]
+        gpu_to_node_id.pop(0)
+        gpu_to_ip_address.pop(0)
+        print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
+        consumer = core_consumer.options(
+            resources={"NPU": 1},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
             world_size=num_consumer_procs,
-            master_addr=master_addr,
+            master_addr=consumer_master_ip_address,
             master_port=master_port,
             num_update_per_episode=num_update_per_episode,
             num_recv_per_update=num_recv_per_update,
@@ -138,6 +186,6 @@ def launch_distributed(
             run_name=run_name,
             wandb_group_name=wandb_group_name,
         )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
-    ray.get([p.loop.remote() for p in procs])
+        consumer_procs.append(consumer)
+    ray.get([p.setup.remote() for p in consumer_procs])
+    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])

View File

@@ -11,15 +11,13 @@ import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from ray.util.collective import allreduce
-from ray.util.collective.types import Backend, ReduceOp
+from ray.util.collective.types import ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

-from colossalai.utils import get_current_device
-
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_append_to_jsonl_file
+from .utils import safe_append_to_jsonl_file

 try:
     from vllm import SamplingParams
@@ -150,7 +148,8 @@ class BaseProducer:
                 raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
         else:
             print("No eval dataset provided, skip eval")
-        self.device = get_current_device()
+
+        self.device = "npu"

         # init backend
         if backend in BACKEND_MAP:
@@ -162,17 +161,15 @@ class BaseProducer:
     def setup(self) -> None:
         cc.init_collective_group(
-            world_size=self.num_producers,
-            rank=self.producer_idx,
-            backend=Backend.NCCL,
-            group_name="producer_group",
+            1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
         )
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
-                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+                cc.init_collective_group(
+                    self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
+                )
         else:
-            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -250,7 +247,6 @@ class BaseProducer:
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
-                outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )

View File

@@ -1,9 +1,9 @@
-transformers==4.39.3
+transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
 colossalai>=0.4.7
-torch>=2.1.0
+torch==2.5.1
 langchain
 tokenizers
 fastapi
@@ -22,3 +22,9 @@ sentencepiece==0.1.99
 flash-attn
 tiktoken
 jsonlines
+math-verify==0.7.0
+
+# The following packages be built into the image.
+# torch_npu==2.5.1
+# fuyao-ray==2.43.0
+# vllm-ascend==0.7.3

View File

@@ -151,7 +151,9 @@ if __name__ == "__main__":
         args.top_k = -1

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":
@@ -247,17 +249,14 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "zero_stage": 2,
-        },  # for zero
-        # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # },  # for pp, tp
+            "tp_size": 2,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 1,
+            "max_norm": 1.0,
+        },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,

View File

@@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         assert (
             self.num_microbatches >= self.stage_manager.num_stages
-        ), "Number of microbatch should be larger than number of stages"
+        ), f"Number of microbatch should be larger than number of stages"

         if self.forward_only:
             self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1

View File

@@ -2,6 +2,7 @@ import math
 from typing import List, Optional, Tuple, Union

 import torch
+import torch_npu
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
@@ -143,14 +144,8 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            (batch_size, 1, seq_length, seq_length_with_past)
+            attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
@@ -488,6 +483,144 @@ class Qwen2PipelineForwards:
         return {"hidden_states": hidden_states}


+def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+    def forward(
+        self: Qwen2Attention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
+        bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # sp: all-to-all comminucation when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            bsz, q_len, _ = query_states.size()
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        # Because the input can be padded, the absolute sequence length depends on the max position id.
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # Activate slicing cache only if the config has a value `sliding_windows` attribute
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
+                slicing_tokens = 1 - self.config.sliding_window
+
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
+
+                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+                if past_key.shape[-2] != self.config.sliding_window - 1:
+                    raise ValueError(
+                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                        f" {past_key.shape}"
+                    )
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, slicing_tokens:]
+                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if shard_config.enable_flash_attention:
+            scale = 1.0 / math.sqrt(query_states.shape[-1])
+            attn_output = torch_npu.npu_fusion_attention(
+                query_states,
+                key_states,
+                value_states,
+                head_num=query_states.size(1),
+                input_layout="BNSD",
+                sparse_mode=1,
+                atten_mask=None,
+                scale=scale,
+                pre_tockens=65536,
+                next_tockens=65536,
+            )
+            attn_output = attn_output[0]
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, -1)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    return forward
+
+
 def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self: Qwen2Attention,
@@ -824,7 +957,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            force_sp_output_gather=False,
         )

         hidden_states = outputs[0]

View File

@@ -19,7 +19,7 @@ from colossalai.shardformer.layer import (
 from ..modeling.qwen2 import (
     Qwen2PipelineForwards,
     get_lm_forward_with_dist_cross_entropy,
-    get_qwen2_flash_attention_forward,
+    get_qwen2_flash_attention_npu_forward,
     get_qwen2_model_forward_for_flash_attn,
 )

@@ -304,7 +304,7 @@ class Qwen2Policy(Policy):
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+                    "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,