[ColossalRL] Support ColossalRL on Ascend (#6324)
* [fix] support npu
* [feat] multinode 14B
* [feat] enlarge seqlen
* [fix]
* [fix] ready to updated
* [fix] ready to merge grpo-latest
* [fix] rm comments
* [feat] support msprof-analyze, add analysis result
* [feat] support ColossalaiRL on Ascend
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [feat] rm comments in qwen modeling
* [Doc] Drafted README.md
* [feat] fix ascend readme format
* [fix] fix readme
* [fix] fix Readme, rm irrelevant testcase
* [fix] fix some adapt modification
* [fix] rm comments in modeling qwen
* [fix] rm comm, test and debug print

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
parent
8880b83791
commit
d4ef7f57be
@@ -1,5 +1,81 @@

# Distributed RL Framework for Language Model Fine-Tuning

This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization using libraries like VLLM.

---

## 🚀 Features

* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.

---

## 🛠 Installation

### Prepare the Development Environment

Install ColossalAI & ColossalChat:

```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest-ascend
pip install -e .

cd ./applications/ColossalChat
pip install -e .
```

Install Fuyao Ray. Please update CANN before installing Fuyao Ray.

```bash
# Install CANN
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel

# Clone Fuyao Ray. Fuyao Ray is not an open-source project; it will be included in the ColossalRL images.
git clone https://gitee.com/openfuyao/ray.git
cd ray
git pull origin pull/5/head

# Install ray
pip install ray==2.43.0 --no-cache-dir

# Create a soft link from fuyao-ray to the ray site-package
cd ..
ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray

# Install Fuyao Ray
cd ray
python python/ray/setup-dev.py
```

Prepare the model & dataset:

```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```

### Set Distributed Config

Now we need to set the distributed configuration for multi-node training.

First, set the host IP configuration. For example, to configure a cluster of 4 nodes, edit the hosts file:

```bash
vim /etc/hosts
```

Then write the IP-to-node map to /etc/hosts:

```bash
10.0.0.3 npu-3
10.0.0.4 npu-4
10.0.0.5 npu-5
10.0.0.6 npu-6
```

Set Ascend Multi-Node Config.

This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization using libraries like VLLM. Currently, we support two Reinforcement Learning with Verifiable Reward (RLVR) tasks: solving math problems and code generation.

**Please note that we are still under intensive development, stay tuned.**

@@ -43,73 +119,29 @@ pip install ray

Install Other Dependencies

```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
```

To support long input/output sequence lengths (e.g., 32K), you may need to manually change the default setting (180 seconds) for the `timeout_s` variable in your Ray installation to a larger value, as shown in the screenshot below.

<div align="center">
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/change_ray_timeout.png" width=700/>
</p>
</div>

Prepare the model & dataset:

```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```

## Architecture Design

<div align="center">
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chat/producer-consumer-pattern.png" width=700/>
</p>
</div>
Producer-Consumer Pattern: a classic software design pattern used for managing resources, data, or tasks between two different processes or threads.

* Producer: the inference engine, which rolls out examples and saves them into a shared buffer.
* Consumer: the training framework, which takes training examples from the shared buffer and trains the policy model.

Key features of the Producer-Consumer Pattern (a minimal sketch follows this list):
* Buffer: Acts as a shared queue where the producer adds data and the consumer removes data.
* Concurrency: Rollout and training can work concurrently.
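
Below is a minimal, illustrative sketch of the pattern using only Python's standard library. It is not the framework's actual implementation (which uses Ray actors and collective communication); the rollout and training steps are placeholders.

```python
# Minimal producer-consumer sketch: a bounded shared buffer with concurrent
# rollout generation and training. Purely illustrative.
import queue
import threading
import time

buffer: "queue.Queue[int]" = queue.Queue(maxsize=8)  # shared rollout buffer

def producer(num_batches: int) -> None:
    for step in range(num_batches):
        rollout = step            # placeholder for a generated rollout batch
        buffer.put(rollout)       # blocks when the buffer is full
        time.sleep(0.01)          # stands in for generation latency

def consumer(num_batches: int) -> None:
    for _ in range(num_batches):
        rollout = buffer.get()    # blocks until data is available
        # placeholder for a policy-optimization step on `rollout`
        buffer.task_done()

t_prod = threading.Thread(target=producer, args=(32,))
t_cons = threading.Thread(target=consumer, args=(32,))
t_prod.start(); t_cons.start()
t_prod.join(); t_cons.join()
```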

## 🧠 Data Format

Samples in the training or evaluation `.jsonl` file should follow a format specific to the type of task; we currently support two RLVR tasks: solving math problems and code generation. Each data sample should follow the corresponding format below:

### Math Data Format
```json
{
    "messages": {
        "role": "user",
        "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}."
    },
    "gt_answer": "3"
}
```
### Code Data Format
We support the [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Inputs and outputs in test cases should be two lists containing only strings and matching in the number of elements. Your prompt must properly instruct the LLM to generate code that reads test cases from stdin and outputs results to stdout; a small script for writing such samples is sketched after the example.
```json
{
    "messages": {
        "role": "user",
        "content": "Solve the following coding problem using the programming language python:\n\nMikhail walks on a Cartesian plane. He starts at the point $(0, 0)$, and in one move he can go to any of eight adjacent points. For example, ..."
    },
    "test_cases": {
        "inputs": [
            "3\n2 2 3\n4 3 7\n10 1 9\n"
        ],
        "outputs": [
            "1\n6\n-1\n"
        ]
    }
}
```
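
For reference, here is a small, illustrative script (not part of the repository) showing how samples in the two formats above can be written to a `.jsonl` file. The file name and sample contents are placeholders; in practice, math and code samples would normally go into separate dataset files.

```python
# Illustrative only: build a .jsonl dataset file with one JSON object per line.
import json

math_sample = {
    "messages": {
        "role": "user",
        "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. "
                   "Let's think step by step and output the final answer within \\boxed{}.",
    },
    "gt_answer": "3",
}

code_sample = {
    "messages": {
        "role": "user",
        "content": "Solve the following coding problem using the programming language python: ...",
    },
    "test_cases": {
        "inputs": ["3\n2 2 3\n4 3 7\n10 1 9\n"],
        "outputs": ["1\n6\n-1\n"],
    },
}

# Inputs and outputs must be string lists with the same number of elements.
assert len(code_sample["test_cases"]["inputs"]) == len(code_sample["test_cases"]["outputs"])

with open("train_data.jsonl", "w", encoding="utf-8") as f:
    for sample in (math_sample, code_sample):
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```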

---

## ⚙️ Hyperparameters & Arguments

@@ -118,7 +150,7 @@ We support [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Input

| ---------------- | --------------------------------------- | ----------------- |
| `--model` | Model path or identifier | `/path/to/model` |
| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` |
| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{"eval_1":"/path/to/eval_1.jsonl"}` |
| `--project` | Project name | `Project1` |
| `--num-episodes` | Number of training episodes | `1` |

@@ -142,7 +174,7 @@ We support [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Input

| `--temperature` | Sampling temperature for generation | `1.0` |
| `--top-k` | Top-K sampling parameter for generation | `None` |
| `--top-p` | Top-P sampling parameter for generation | `1.0` |
-| `--system-prompt` | System prompt, optional, defaults to the default system prompt for each reward type. For more information, refer to the [**reward type**](#-constraints-and-notes) section | `Please reason step by step, and put your final answer within \\boxed{}.` |
+| `--system-prompt` | System prompt, defaults to the system prompt for the `think_answer_tags` format | `Please reason step by step, and put your final answer within \\boxed{}.` |
| `--max-new-tokens` | Max generation tokens | `3584` |
| `--max-prompt-tokens` | Max prompt tokens | `512` |
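
For orientation, the generation arguments above roughly correspond to vLLM sampling settings when the `vllm` backend is used; the snippet below is an illustrative mapping using the table's default values, not the repository's exact wiring.

```python
# Illustrative mapping of the documented generation arguments onto vLLM's SamplingParams.
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=1.0,  # --temperature
    top_p=1.0,        # --top-p
    top_k=-1,         # --top-k (None in the table; vLLM uses -1 to disable top-k)
    max_tokens=3584,  # --max-new-tokens
)
print(sampling_params)
```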

@@ -152,9 +184,9 @@ We support [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Input

| ----------------- | ---------------------------- | ------------------- |
| `--algo` | Algorithm (`GRPO` or `DAPO`), for more customization refer to [GRPO Settings](#️-grpo-settings) | `GRPO` |
| `--learning-rate` | Learning rate | `1e-6` |
-| `--kl-coeff` | KL penalty coefficient, if nonzero, a reference model will be used | `0.01` |
-| `--reward-type` | Reward signal type (choose from 'think_answer_tags', 'boxed', 'code') For more information, refer to the [**reward type**](#-constraints-and-notes) section | `think_answer_tags` |
-| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `10` |
+| `--kl-coeff` | KL penalty coefficient | `0.01` |
+| `--reward-type` | Reward signal type (choose from 'think_answer_tags', 'boxed') | `think_answer_tags` |
+| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `100` |

### Logging and Checkpointing

@@ -177,7 +209,7 @@ We support [Prime code dataset format](https://github.com/PRIME-RL/PRIME). Input

## ⚙️ GRPO Settings

In addition to the two default training settings provided—original `GRPO` and `DAPO`—users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py`.

| Argument Name | Description | Default |
| ----------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
@@ -224,29 +256,6 @@ In addition to the two default training settings provided—`GRPO` and `DAPO`—
```
      train_pipeline_parallelism_size /
      train_tensor_parallelism_size)
```

* Reward Type

We currently support three reward types: `think_answer_tags`, `boxed`, and `code`. Each differs in details such as how the answer is extracted and how the reward is calculated. Please select `think_answer_tags` or `boxed` for math problem solving, and use `code` for code generation. The default system prompt for each reward type is shown below (an illustrative extraction sketch follows this list). Please make sure your system prompt provides the information needed for the answer to be correctly extracted from model responses.

* think_answer_tags

Answer extraction: extract the content between the last `<answer>`, `</answer>` tags.

```
You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n
```
* boxed

Answer extraction: extract the last content marked by `\\boxed{}`.
```
Please reason step by step, and put your final answer within \\boxed{}.
```
* code

Answer extraction: extract the code inside ` ```python\n...``` `.
```
You are a helpful assistant.
```
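
To make the extraction rules concrete, here is a rough, illustrative sketch of the three strategies. These helpers are simplified stand-ins, not the framework's actual reward functions (`math_reward_fn`, `boxed_math_reward_fn` and `code_reward_fn` in `coati.distributed.reward.reward_fn`).

```python
# Illustrative answer extraction for the three reward types described above.
import re
from typing import Optional

def extract_think_answer_tags(response: str) -> Optional[str]:
    """Return the content of the last <answer>...</answer> block, if any."""
    matches = re.findall(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    return matches[-1].strip() if matches else None

def extract_boxed(response: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} (non-nested braces only)."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", response)
    return matches[-1].strip() if matches else None

def extract_code(response: str) -> Optional[str]:
    """Return the body of the last python code fence in the response."""
    matches = re.findall(r"```python\n(.*?)```", response, flags=re.DOTALL)
    return matches[-1] if matches else None

assert extract_boxed(r"The answer is \boxed{3}.") == "3"
assert extract_think_answer_tags("<think>...</think><answer> 123 </answer>") == "123"
```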

---

## 🧪 Example: single machine 8-GPU Zero2 Strategy

@@ -279,7 +288,7 @@ We use 10.0.0.3 as master node. First we start a ray cluster on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```

Then, for each slave node (10.0.0.4/10.0.0.5/10.0.0.6), we add it to the ray cluster with the following command:
```bash
ray start --address='10.0.0.3:6379'
```
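
Optionally, you can sanity-check that every machine has joined before launching training. The short script below is not part of the repository; it simply queries the running Ray cluster from the head node and assumes the same Python environment is active there.

```python
# Optional sanity check: run on the head node after all `ray start` commands.
import ray

ray.init(address="auto")  # attach to the existing cluster instead of starting a new one

alive = [node for node in ray.nodes() if node["Alive"]]
for node in alive:
    # Each entry reports the node address and its registered resources (e.g. NPU/GPU slots).
    print(node["NodeManagerAddress"], node["Resources"])

assert len(alive) == 4, "expected the 4 nodes configured in /etc/hosts"
ray.shutdown()
```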

@@ -320,4 +329,5 @@ python rl_example.py

## Acknowledgement
Colossal-RL is a distributed version of ColossalChat and is inspired by several awesome open-source projects. We would like to express our gratitude to the following projects and algorithms: GRPO, DAPO, TRL, Verl, OpenRLHF, StreamRL, Qwen, Logic-RL.

---
@@ -13,7 +13,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
@@ -33,6 +32,7 @@ class BaseConsumer:
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
+       generate_config: Dict[str, Any],
        minibatch_size: int = 1,
        save_interval: int = 100,
        save_dir: str = "./model",
@@ -55,8 +55,9 @@ class BaseConsumer:
        self.model_config = model_config
        self.plugin_config = plugin_config

-       self.device = get_current_device()
+       self.device = "npu"
        self.lr_scheduler = None
+       self.generate_config = generate_config

    def setup(self) -> None:
        launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -73,24 +74,28 @@ class BaseConsumer:
        self.booster = Booster(plugin=self.plugin)
        self.dp_rank = dist.get_rank(self.plugin.dp_group)
        self.tp_rank = dist.get_rank(self.plugin.tp_group)
+       self.sp_rank = dist.get_rank(self.plugin.sp_group)
        self.pp_rank = dist.get_rank(self.plugin.pp_group)

        self.dp_size = dist.get_world_size(self.plugin.dp_group)
        self.tp_size = dist.get_world_size(self.plugin.tp_group)
+       self.sp_size = dist.get_world_size(self.plugin.sp_group)
        self.pp_size = dist.get_world_size(self.plugin.pp_group)

        # Init Hybrid ray process group
        for i in range(self.num_producers):
-           cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+           cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
        if self.pp_size > 1:
            # use hybrid tp + pp
            if self.tp_rank == 0 and self.dp_rank == 0:
                cc.init_collective_group(
-                   self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                   self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
                )
        else:
            if self.rank == 0:
-               cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
+               cc.init_collective_group(
+                   self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
+               )

        self.buffer = []
        self.recv_cnt = 0
@@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
        self.model_config = model_config
        self.tokenizer = tokenizer
        self.num_generations = num_generations
        self.max_length = generate_config["max_tokens"]

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -79,15 +79,11 @@ def launch_distributed(
        f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
    )

-   # Attention: Ray use complex schedualing method that consider various factors including load-balancing.
-   # when requesting resources, it is not guaranteed that the resource comes from a node with lower node it
-   # this go against the design principle of our implementation, and we need to manually force the schedualing,
-   # allocating the producer to nodes with lower node id and the consumer to the resouces from nodes with higher
-   # node id. See the reference here: https://docs.ray.io/en/latest/ray-core/scheduling/index.html#nodeaffinityschedulingstrategy
    nodes = ray.nodes()
    node_info = {
        node["NodeID"]: {
-           "num_gpus": node["Resources"].get("GPU", 0),
+           # "num_gpus": node["Resources"].get("GPU", 0),
+           "num_gpus": node["Resources"].get("NPU", 0),
            "address": node["NodeManagerAddress"],
        }  # Default to 0 if no GPUs are available
        for node in nodes
@@ -95,12 +91,12 @@ def launch_distributed(
    gpu_to_node_id = []
    gpu_to_ip_address = []
    for node_id in node_info:
-       for idx in range(int(node_info[node_id]["num_gpus"])):
+       for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
            gpu_to_node_id.append(node_id)
            gpu_to_ip_address.append(node_info[node_id]["address"])
    print(node_info)

    producer_procs = []

    for i in range(num_producers):
        node_id = gpu_to_node_id[0]
        producer_ip_address = gpu_to_ip_address[0]
@@ -108,7 +104,17 @@ def launch_distributed(
        gpu_to_node_id.pop(0)
        gpu_to_ip_address.pop(0)
        print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
-       producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+
+       producer = SimpleProducer.options(
+           # num_cpus=1,
+           # num_cpus=num_proc_per_producer,
+           num_gpus=0,
+           resources={"NPU": num_proc_per_producer},
+           scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+               node_id=node_id,
+               soft=False,
+           ),
+       ).remote(
            producer_idx=i,
            num_producers=num_producers,
            num_consumer_procs=num_consumer_procs,
@@ -150,7 +156,13 @@ def launch_distributed(
        gpu_to_node_id.pop(0)
        gpu_to_ip_address.pop(0)
        print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
-       consumer = core_consumer.options(num_gpus=1).remote(
+       consumer = core_consumer.options(
+           resources={"NPU": 1},
+           scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+               node_id=node_id,
+               soft=False,
+           ),
+       ).remote(
            num_producers=num_producers,
            num_episodes=num_episodes,
            rank=i,
@@ -12,15 +12,13 @@ from coati.dataset.loader import RawConversationDataset, collate_fn_grpo
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, code_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from ray.util.collective import allreduce
-from ray.util.collective.types import Backend, ReduceOp
+from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer

-from colossalai.utils import get_current_device
-
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_append_to_jsonl_file
+from .utils import safe_append_to_jsonl_file

try:
    from vllm import SamplingParams
@@ -161,13 +159,8 @@ class BaseProducer:
            )
        else:
            print("No eval dataset provided, skip eval")
-       self.device = get_current_device()
-       self.reward_model = VerifiableReward(
-           reward_fns=[self.evaluation_function],  # multiple reward functions can be added here
-           tokenizer=self.tokenizer,
-           tags=self.response_format_tags,
-           **reward_model_kwargs,
-       )
+
+       self.device = "npu"

        # init backend
        if backend in BACKEND_MAP:
@@ -179,17 +172,15 @@ class BaseProducer:

    def setup(self) -> None:
        cc.init_collective_group(
-           world_size=self.num_producers,
-           rank=self.producer_idx,
-           backend=Backend.NCCL,
-           group_name="producer_group",
+           1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
        )
-       cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
        if self.consumer_pp_size > 1:
            for i in range(self.consumer_pp_size):
-               cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+               cc.init_collective_group(
+                   self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
+               )
        else:
-           cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+           cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")

    def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
bs, num_gen = outputs["input_ids"].size(0), outputs["input_ids"].size(1)
|
||||
if self.grpo_config["reward_fn_type"] == "code":
|
||||
test_cases = []
|
||||
for prompt_id in range(bs):
|
||||
test_cases.extend([outputs["test_cases"][prompt_id]] * num_gen)
|
||||
reward_model_output = self.reward_model(
|
||||
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
|
||||
test_cases=test_cases,
|
||||
response_idx=outputs["response_idx"].view((-1, 2)),
|
||||
)
|
||||
else:
|
||||
gt_answer = []
|
||||
for prompt_id in range(bs):
|
||||
gt_answer.extend([outputs["gt_answer"][prompt_id]] * num_gen)
|
||||
reward_model_output = self.reward_model(
|
||||
outputs["input_ids"].view((-1, outputs["input_ids"].size(-1))),
|
||||
gt_answer=gt_answer,
|
||||
response_idx=outputs["response_idx"].view((-1, 2)),
|
||||
)
|
||||
outputs["reward"] = (
|
||||
torch.tensor([value[0] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
outputs["format_acc"] = (
|
||||
torch.tensor([value[1] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
outputs["ans_acc"] = (
|
||||
torch.tensor([value[2] for value in reward_model_output])
|
||||
.to(outputs["input_ids"].device)
|
||||
.view((bs, num_gen, 1))
|
||||
)
|
||||
if "gt_answer" in outputs:
|
||||
outputs.pop("gt_answer")
|
||||
if "test_cases" in outputs:
|
||||
outputs.pop("test_cases")
|
||||
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
outputs = pre_send(outputs)
|
||||
ray_broadcast_tensor_dict(
|
||||
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
|
||||
)
|
||||
|
@@ -1,9 +1,9 @@
-transformers==4.39.3
+transformers==4.47.0
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.7
-torch>=2.1.0
+torch==2.5.1
langchain
tokenizers
fastapi
@@ -22,6 +22,9 @@ sentencepiece==0.1.99
flash-attn
tiktoken
jsonlines
-math_verify
-latex2sympy2_extended
-pyext
+math-verify==0.7.0
+
+# The following packages be built into the image.
+# torch_npu==2.5.1
+# fuyao-ray==2.43.0
+# vllm-ascend==0.7.3
@@ -206,7 +206,9 @@ if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism to avoid deadlock

    inference_model_config = dict(path=args.model)
-   train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+   train_model_config = dict(
+       path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
+   )
    generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

    if args.backend == "transformers":
@@ -325,12 +327,12 @@ if __name__ == "__main__":
        train_model_config=train_model_config,
        grpo_config=grpo_config,
        plugin_config={
-           "tp_size": args.tensor_parallel_size,
-           "pp_size": args.pipeline_parallel_size,
+           "tp_size": 2,
+           "pp_size": 2,
            "microbatch_size": max(
-               1, args.train_microbatch_size // args.pipeline_parallel_size
+               1, args.train_microbatch_size // 2
            ),  # microbatch size should be set to train_microbatch_size // pp_size
-           "zero_stage": args.zero_stage,
+           "zero_stage": 1,
            "max_norm": 1.0,
        },  # for pp, tp
        inference_backend=args.backend,
@@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

        assert (
            self.num_microbatches >= self.stage_manager.num_stages
-       ), "Number of microbatch should be larger than number of stages"
+       ), f"Number of microbatch should be larger than number of stages"

        if self.forward_only:
            self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1
@@ -2,6 +2,7 @@ import math
from typing import List, Optional, Tuple, Union

import torch
+import torch_npu
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
@@ -143,14 +144,8 @@ class Qwen2PipelineForwards:
        # for the other stages, hidden_states is the output of the previous stage
        if shard_config.enable_flash_attention:
            # in this case, attention_mask is a dict rather than a tensor
-           mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-           attention_mask = ColoAttention.prepare_attn_kwargs(
-               mask_shape,
-               hidden_states.dtype,
-               hidden_states.device,
-               q_padding_mask=attention_mask,
-               is_causal=True,
-           )
+           (batch_size, 1, seq_length, seq_length_with_past)
+           attention_mask = None
        else:
            if self._attn_implementation == "flash_attention_2":
                # 2d mask is passed through the layers
@@ -488,6 +483,144 @@ class Qwen2PipelineForwards:
        return {"hidden_states": hidden_states}


def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self: Qwen2Attention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if sp_mode is not None:
            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
            assert (sp_size is not None) and (
                sp_group is not None
            ), "Must specify sp_size and sp_group for sequence parallel"

        bsz, q_len, _ = hidden_states.size()
        # sp: modify sp_len when sequence parallel mode is ring
        if sp_mode in ["split_gather", "ring"]:
            q_len *= sp_size

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # sp: all-to-all comminucation when introducing sequence parallel
        if sp_mode == "all_to_all":
            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
            bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Because the input can be padded, the absolute sequence length depends on the max position id.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if shard_config.enable_flash_attention:
            scale = 1.0 / math.sqrt(query_states.shape[-1])
            attn_output = torch_npu.npu_fusion_attention(
                query_states,
                key_states,
                value_states,
                head_num=query_states.size(1),
                input_layout="BNSD",
                sparse_mode=1,
                atten_mask=None,
                scale=scale,
                pre_tockens=65536,
                next_tockens=65536,
            )
            attn_output = attn_output[0]
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
        attn_output = attn_output.transpose(1, 2).contiguous()
        if sp_mode == "all_to_all":
            attn_output = attn_output.reshape(bsz, q_len, -1)
            attn_output = all_to_all_comm(
                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
            )
        else:
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    return forward


def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
    def forward(
        self: Qwen2Attention,
@@ -824,7 +957,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
-           force_sp_output_gather=False,
        )

        hidden_states = outputs[0]
@@ -19,7 +19,7 @@ from colossalai.shardformer.layer import (
from ..modeling.qwen2 import (
    Qwen2PipelineForwards,
    get_lm_forward_with_dist_cross_entropy,
-   get_qwen2_flash_attention_forward,
+   get_qwen2_flash_attention_npu_forward,
    get_qwen2_model_forward_for_flash_attn,
)

@@ -304,7 +304,7 @@ class Qwen2Policy(Policy):
        if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
            self.append_or_create_method_replacement(
                description={
-                   "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+                   "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
                },
                policy=policy,
                target_key=attn_cls,