[ColossalRL] Support ColossalRL on Ascend (#6324)

* [fix] support npu

* [feat] multinode 14B

* [feat] enlarge seqlen

* [fix]

* [fix] ready to update

* [fix] ready to merge grpo-latest

* [fix] rm comments

* [feat] support msprof-analyze, add analysis result

* [feat] support ColossalaiRL on Ascend

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feat] rm comments in qwen modeling

* [Doc] Drafted README.md

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [feat] fix ascend readme format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix readme

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [fix] fix Readme, rm irrelevant testcase

* [fix] fix some adapt modification

* [fix] rm comments in modeling qwen

* [fix] rm comm, test and debug print

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
duanjunwen 2025-05-28 10:43:13 +08:00 committed by GitHub
parent de2ad3b206
commit e1ca2d22ae
10 changed files with 527 additions and 55 deletions


@ -1,6 +1,291 @@
# Requirements
# Distributed RL Framework for Language Model Fine-Tuning
This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models with algorithms such as **GRPO** and **DAPO**. It supports multi-node, multi-GPU setups, scalable rollout generation, and policy optimization using inference libraries such as vLLM.
---
## 🚀 Features
* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.
---
## 🛠 Installation
### Prepare Develop Environment
Install ColossalAI & ColossalChat
```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest-ascend
pip install -e .
cd ./applications/ColossalChat
pip install -e .
```
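As a quick, optional sanity check (assuming the editable installs above succeeded; `coati` is the Python package provided by ColossalChat), you can verify that both packages import:
```bash
python -c "import colossalai; print(colossalai.__version__)"
python -c "import coati; print('ColossalChat (coati) imported successfully')"
```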
Install Fuyao Ray.
Please update CANN before installing Fuyao Ray.
```bash
# Install CANN
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel
# Clone Fuyao Ray. Fuyao Ray is not an open-source project; it will be included in the ColossalRL images.
git clone https://gitee.com/openfuyao/ray.git
cd ray
git pull origin pull/5/head
# Install ray
pip install ray==2.43.0 --no-cache-dir
# Create a soft link from Fuyao Ray to the ray site-packages directory
cd ..
ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray
# Install Fuyao Ray
cd ray
python python/ray/setup-dev.py
```
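Optionally, confirm that Python resolves the Ray installation (and the soft link above) to the expected version and location:
```bash
python -c "import ray; print(ray.__version__, ray.__file__)"
```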
Prepare the model & dataset
```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```
### Set Distributed Config
Now we need to set the distributed configuration for multi-node training.
First, set the host IP configuration.
For example, to configure a cluster of 4 nodes, edit the hosts file:
```bash
vim /etc/hosts
```
Then write the IP-to-node mapping to /etc/hosts:
```bash
10.0.0.3 npu-3
10.0.0.4 npu-4
10.0.0.5 npu-5
10.0.0.6 npu-6
```
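Optionally, verify that the hostnames resolve and the nodes are reachable from each machine:
```bash
for host in npu-3 npu-4 npu-5 npu-6; do ping -c 1 "$host"; done
```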
Set Ascend Multi-Node Config
```bash
pip install cupy-cuda12x
python -m cupyx.tools.install_library --cuda 12.x --library nccl
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
```
## 🧠 Data Format
Each data sample in the training or evaluation `.jsonl` file should follow this format:
```json
{
"messages": {
"role": "user",
"content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}."
},
"gt_answer": "3"
}
```
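As a minimal sketch (the file path below is a placeholder), each line of such a file can be parsed and checked against this format like so:
```python
import json

# Minimal sketch: check that every line of a training .jsonl file follows the
# {"messages": {...}, "gt_answer": ...} format described above.
with open("/path/to/train_data.jsonl", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        sample = json.loads(line)
        assert "messages" in sample and "gt_answer" in sample, f"missing keys at line {line_no}"
        assert sample["messages"]["role"] == "user", f"unexpected role at line {line_no}"
        assert isinstance(sample["messages"]["content"], str)
```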
---
## ⚙️ Hyperparameters & Arguments
| Argument | Description | Example |
| ---------------- | --------------------------------------- | ----------------- |
| `--model` | Model path or identifier | `/path/to/model` |
| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` |
| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` |
| `--project` | Project name | `Project1` |
| `--num-episodes` | Number of training episodes | `1` |
### Distributed Training
| Argument | Description | Example |
| ----------------------------- | ------------------------------------- | ------- |
| `--num-trainers` | Number of trainer processes | `4` |
| `--num-inferencer` | Number of inferencer processes | `4` |
| `--inference-batch-size` | Prompts per inference step | `8` |
| `--inference-microbatch-size` | Per-GPU batch size for inference | `8` |
| `--train-batch-size` | Prompts per trainer step per dp group | `8` |
| `--train-minibatch-size` | Mini-batch size before forward pass | `8` |
| `--train-microbatch-size` | Per-GPU batch size for training | `2` |
### Sampling
| Argument | Description | Example |
| --------------------- | --------------------- | -------------- |
| `--backend` | Generation backend, choose from `vllm` | `vllm` |
| `--temperature` | Sampling temperature for generation | `1.0` |
| `--top-k` | Top-K sampling parameter for generation | `None` |
| `--top-p` | Top-P sampling parameter for generation | `1.0` |
| `--system-prompt` | System prompt; defaults to the system prompt for the `think_answer_tags` format | `Please reason step by step, and put your final answer within \\boxed{}.` |
| `--max-new-tokens` | Max generation tokens | `3584` |
| `--max-prompt-tokens` | Max prompt tokens | `512` |
### GRPO Specific
| Argument | Description | Example |
| ----------------- | ---------------------------- | ------------------- |
| `--algo` | Algorithm (`GRPO` or `DAPO`), for more customization refer to [GRPO Settings](#-grpo-settings) | `GRPO` |
| `--learning-rate` | Learning rate | `1e-6` |
| `--kl-coeff` | KL penalty coefficient | `0.01` |
| `--reward-type` | Reward signal type (choose from 'think_answer_tags', 'boxed') | `think_answer_tags` |
| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `100` |
### Logging and Checkpointing
| Argument | Description | Example |
| -------------------- | ------------------------- | ------------ |
| `--save-interval` | Training steps between checkpoints | `20` |
| `--save-dir` | Checkpoint directory | `./model` |
| `--eval-save-dir` | Evaluation save path | `./eval` |
| `--rollout-save-dir` | Rollout logs directory | `./rollouts` |
### Miscellaneous
| Argument | Description | Example |
| ------------------ | --------------------------------------- | ------- |
| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` |
| `--master_address` | Master address of a running Ray cluster | `None` |
| `--master_port` | Master port for torch DDP | `29506` |
---
## ⚙️ GRPO Settings
In addition to the two default training settings we provide (original `GRPO` and `DAPO`), users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py` (see the sketch after the table).
| Argument Name | Description | Default |
| ----------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `filter_range` | Filters out rollout group if the success rate within that group is out of this range.| `[0.01, 0.99]` |
| `dynamic_batching` | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `clip_eps_low` | `epsilon_low` in the DAPO clipping objective; see the equation in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `0.2` |
| `clip_eps_high` | `epsilon_high` in the DAPO clipping objective; see the equation in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `0.28` |
| `skip_threshold` | If the importance ratio is above this threshold, the sample is skipped to avoid instability. | `20.0` |
| `loss_variation` | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `"sample_level"` for the original GRPO loss. | `"token_level"` |
| `soft_over_length_punishment` | Whether to use the soft overlength penalty from the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` |
| `cache_length` | `L_cache` parameter for the soft overlength penalty in Eq. 13 of the [DAPO paper](https://arxiv.org/abs/2503.14476). | `min(1024, int(args.max_new_tokens / 4))` |
| `filter_truncated_response` | Mask out truncated responses in loss calculation. | `True` |
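For example, a DAPO-style customization might look like the sketch below; the key names mirror the table above, but treat this as an illustrative assumption and check `grpo_config` in `rl_example.py` for the exact keys:
```python
# Illustrative sketch only: overriding GRPO/DAPO hyperparameters in rl_example.py.
# Key names mirror the table above; verify them against the actual grpo_config.
grpo_config = {
    "filter_range": [0.01, 0.99],      # drop rollout groups whose success rate falls outside this range
    "dynamic_batching": True,          # dynamic batching as described in the DAPO paper
    "clip_eps_low": 0.2,
    "clip_eps_high": 0.28,
    "skip_threshold": 20.0,            # skip samples whose importance ratio exceeds this value
    "loss_variation": "token_level",   # or "sample_level" for the original GRPO loss
    "soft_over_length_punishment": True,
    "cache_length": 896,               # i.e. min(1024, int(max_new_tokens / 4)) for max_new_tokens=3584
    "filter_truncated_response": True, # mask out truncated responses in the loss
}
```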
## 🔄 Constraints and Notes
* `num_inferencer + num_trainer == NUM_GPUs`
* `num_inferencer % num_trainer == 0`
* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0`
* `train_batch_size >= train_minibatch_size >= train_microbatch_size`
* `inference_batch_size >= inference_microbatch_size`
* Set microbatch sizes based on **VRAM capacity**
* To use tensor parallelism on the inferencer
  * set backend to `vllm`
  * change `tensor_parallel_size` in `inference_model_config` in `rl_example.py`
  * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size`
* To set tensor parallelism / pipeline parallelism / zero stage
  * change the corresponding settings in `plugin_config` in `rl_example.py`
* Ensure the rollout generation rate matches trainer consumption (a sanity-check sketch follows this list):
```
num_inferencer * inference_batch_size % (
num_trainer * train_batch_size /
train_pipeline_parallelism_size /
train_tensor_parallelism_size
) == 0
```
* Model weights are synced every:
```
(num_inferencer * inference_batch_size) /
(num_trainer * train_batch_size /
train_pipeline_parallelism_size /
train_tensor_parallelism_size)
```
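A small helper like the following (a sketch with illustrative numbers) can check these constraints before launching a run:
```python
# Sketch: verify the rollout/consumption constraint and compute the weight-sync interval.
# The numbers are illustrative; substitute your own launch arguments.
num_inferencer, inference_batch_size = 16, 1
num_trainer, train_batch_size = 16, 2
train_pp_size, train_tp_size = 2, 2

consumed_per_step = num_trainer * train_batch_size / train_pp_size / train_tp_size
produced_per_step = num_inferencer * inference_batch_size

assert produced_per_step % consumed_per_step == 0, "rollout rate does not match trainer consumption"
print(f"Model weights sync every {int(produced_per_step / consumed_per_step)} training steps")
```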
---
## 🧪 Example: single machine 8-GPU Zero2 Strategy
```bash
python rl_example.py \
--dataset /path/to/train_data.jsonl \
--model /path/to/Qwen2.5-3B/ \
-t 4 -i 4 \
-b vllm \
-rt boxed \
-g 4 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-s "Please reason step by step, and put your final answer within \\boxed{}." \
-p GRPO-Train-Align-Debug
```
## 🧪 Example: multi-machine TP+PP Strategy
### Create ray cluster on multi-machine
For example, suppose we have 4 nodes whose IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6.
We use 10.0.0.3 as the master node. First, start a Ray cluster on 10.0.0.3:
```bash
ray start --head --node-ip-address=10.0.0.3
```
Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), join the Ray cluster with the following command:
```bash
ray start --address='10.0.0.3:6379'
```
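Once every node has joined, you can confirm from the master node that the cluster sees all machines and NPUs:
```bash
ray status
```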
Modify `plugin_config` in `./applications/ColossalChat/rl_example.py`:
```python
plugin_config={
    "tp_size": 4,
    "pp_size": 2,
    "microbatch_size": max(
        1, args.train_microbatch_size // 2
    ),  # microbatch size should be set to train_microbatch_size // pp_size
    "zero_stage": 1,
    "max_norm": 1.0,
},  # for pp, tp
```
```bash
# Hint 1: replace /models/Qwen/Qwen2.5-7B with your model path
# Hint 2: replace /datasets/train-alignment.jsonl with your dataset path
python rl_example.py \
-m /path/to/Qwen2.5-Math-7B/ \
-d /path/to/train_data.jsonl \
--master_address '10.0.0.3' \
-t 16 \
-i 16 \
-p GRPO-Train-Align-Debug \
-g 2 \
-ibs 1 \
-tbs 2 \
-tMbs 1 \
-tmbs 2 \
-imbs 1 \
-b vllm \
-e 2 \
-rt boxed \
-s "Please reason step by step, and put your final answer within \\boxed{}."
```
## Acknowledgement
---


@ -13,7 +13,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
@ -33,6 +32,7 @@ class BaseConsumer:
batch_size: int,
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
generate_config: Dict[str, Any],
minibatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
@ -55,8 +55,9 @@ class BaseConsumer:
self.model_config = model_config
self.plugin_config = plugin_config
self.device = get_current_device()
self.device = "npu"
self.lr_scheduler = None
self.generate_config = generate_config
def setup(self) -> None:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@ -73,24 +74,28 @@ class BaseConsumer:
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.sp_rank = dist.get_rank(self.plugin.sp_group)
self.pp_rank = dist.get_rank(self.plugin.pp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.tp_size = dist.get_world_size(self.plugin.tp_group)
self.sp_size = dist.get_world_size(self.plugin.sp_group)
self.pp_size = dist.get_world_size(self.plugin.pp_group)
# Init Hybrid ray process group
for i in range(self.num_producers):
cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
if self.pp_size > 1:
# use hybrid tp + pp
if self.tp_rank == 0 and self.dp_rank == 0:
cc.init_collective_group(
self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
)
else:
if self.rank == 0:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
cc.init_collective_group(
self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
)
self.buffer = []
self.recv_cnt = 0


@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
self.model_config = model_config
self.tokenizer = tokenizer
self.num_generations = num_generations
self.max_length = generate_config["max_tokens"]
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:


@ -80,9 +80,42 @@ def launch_distributed(
f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
)
procs = []
nodes = ray.nodes()
node_info = {
node["NodeID"]: {
# "num_gpus": node["Resources"].get("GPU", 0),
"num_gpus": node["Resources"].get("NPU", 0),
"address": node["NodeManagerAddress"],
} # Default to 0 if no GPUs are available
for node in nodes
}
gpu_to_node_id = []
gpu_to_ip_address = []
for node_id in node_info:
for idx in range(int(node_info[node_id]["num_gpus"])): # use num_gpus instead of num_npus
gpu_to_node_id.append(node_id)
gpu_to_ip_address.append(node_info[node_id]["address"])
producer_procs = []
for i in range(num_producers):
producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
node_id = gpu_to_node_id[0]
producer_ip_address = gpu_to_ip_address[0]
for _ in range(num_proc_per_producer):
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
producer = SimpleProducer.options(
# num_cpus=1,
# num_cpus=num_proc_per_producer,
num_gpus=0,
resources={"NPU": num_proc_per_producer},
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
producer_idx=i,
num_producers=num_producers,
num_consumer_procs=num_consumer_procs,
@ -108,20 +141,35 @@ def launch_distributed(
log_rollout_interval=log_rollout_interval,
rollout_log_file=rollout_log_file,
)
procs.append(producer)
producer_procs.append(producer)
ray.get([p.setup.remote() for p in producer_procs])
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
consumer_master_ip_address = gpu_to_ip_address[0]
print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
consumer_procs = []
for i in range(num_consumer_procs):
consumer = core_consumer.options(num_gpus=1).remote(
node_id = gpu_to_node_id[0]
consumer_ip_address = gpu_to_ip_address[0]
gpu_to_node_id.pop(0)
gpu_to_ip_address.pop(0)
print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
consumer = core_consumer.options(
resources={"NPU": 1},
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
),
).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
world_size=num_consumer_procs,
master_addr=master_addr,
master_addr=consumer_master_ip_address,
master_port=master_port,
num_update_per_episode=num_update_per_episode,
num_recv_per_update=num_recv_per_update,
@ -138,6 +186,6 @@ def launch_distributed(
run_name=run_name,
wandb_group_name=wandb_group_name,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])
ray.get([p.loop.remote() for p in procs])
consumer_procs.append(consumer)
ray.get([p.setup.remote() for p in consumer_procs])
ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])


@ -11,15 +11,13 @@ import wandb
from coati.dataset.loader import RawConversationDataset
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from ray.util.collective import allreduce
from ray.util.collective.types import Backend, ReduceOp
from ray.util.collective.types import ReduceOp
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer
from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict
from .inference_backend import BACKEND_MAP
from .utils import pre_send, safe_append_to_jsonl_file
from .utils import safe_append_to_jsonl_file
try:
from vllm import SamplingParams
@ -150,7 +148,8 @@ class BaseProducer:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()
self.device = "npu"
# init backend
if backend in BACKEND_MAP:
@ -162,17 +161,15 @@ class BaseProducer:
def setup(self) -> None:
cc.init_collective_group(
world_size=self.num_producers,
rank=self.producer_idx,
backend=Backend.NCCL,
group_name="producer_group",
1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
)
cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
if self.consumer_pp_size > 1:
for i in range(self.consumer_pp_size):
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
cc.init_collective_group(
self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
)
else:
cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")
def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
raise NotImplementedError
@ -250,7 +247,6 @@ class BaseProducer:
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
)


@ -1,9 +1,9 @@
transformers==4.39.3
transformers==4.47.0
tqdm
datasets==2.14.7
loralib
colossalai>=0.4.7
torch>=2.1.0
torch==2.5.1
langchain
tokenizers
fastapi
@ -22,3 +22,9 @@ sentencepiece==0.1.99
flash-attn
tiktoken
jsonlines
math-verify==0.7.0
# The following packages are built into the image.
# torch_npu==2.5.1
# fuyao-ray==2.43.0
# vllm-ascend==0.7.3


@ -151,7 +151,9 @@ if __name__ == "__main__":
args.top_k = -1
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
train_model_config = dict(
path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":
@ -247,17 +249,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 2,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp, tp
"tp_size": 2,
"pp_size": 2,
"microbatch_size": max(
1, args.train_microbatch_size // 2
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": 1,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,


@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert (
self.num_microbatches >= self.stage_manager.num_stages
), "Number of microbatch should be larger than number of stages"
), f"Number of microbatch should be larger than number of stages"
if self.forward_only:
self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1


@ -2,6 +2,7 @@ import math
from typing import List, Optional, Tuple, Union
import torch
import torch_npu
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
@ -143,14 +144,8 @@ class Qwen2PipelineForwards:
# for the other stages, hidden_states is the output of the previous stage
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
(batch_size, 1, seq_length, seq_length_with_past)
attention_mask = None
else:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
@ -488,6 +483,144 @@ class Qwen2PipelineForwards:
return {"hidden_states": hidden_states}
def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self: Qwen2Attention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if sp_mode is not None:
assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
assert (sp_size is not None) and (
sp_group is not None
), "Must specify sp_size and sp_group for sequence parallel"
bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
q_len *= sp_size
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if shard_config.enable_flash_attention:
scale = 1.0 / math.sqrt(query_states.shape[-1])
attn_output = torch_npu.npu_fusion_attention(
query_states,
key_states,
value_states,
head_num=query_states.size(1),
input_layout="BNSD",
sparse_mode=1,
atten_mask=None,
scale=scale,
pre_tockens=65536,
next_tockens=65536,
)
attn_output = attn_output[0]
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, -1)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
def forward(
self: Qwen2Attention,
@ -824,7 +957,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
force_sp_output_gather=False,
)
hidden_states = outputs[0]


@ -19,7 +19,7 @@ from colossalai.shardformer.layer import (
from ..modeling.qwen2 import (
Qwen2PipelineForwards,
get_lm_forward_with_dist_cross_entropy,
get_qwen2_flash_attention_forward,
get_qwen2_flash_attention_npu_forward,
get_qwen2_model_forward_for_flash_attn,
)
@ -304,7 +304,7 @@ class Qwen2Policy(Policy):
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
self.append_or_create_method_replacement(
description={
"forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
"forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
},
policy=policy,
target_key=attn_cls,