Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-25 20:46:13 +00:00)
[ColossalRL] Support ColossalRL on Ascend (#6324)
* [fix] support npu
* [feat] multinode 14B
* [feat] enlarge seqlen
* [fix]
* [fix] ready to updated
* [fix] ready to merge grpo-latest
* [fix] rm comments
* [feat] support msprof-analyze, add analsys result
* [feat] support ColossalaiRL on Ascend
* [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
* [feat] rm comments in qwen modeling
* [Doc] Drafted README.md
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* [feat] fix ascend readme format
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* [fix] fix readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* [fix] fix readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* [fix] fix Readme, rm irrelevant testcase
* [fix] fix some adapt modification
* [fix] rm comments in modeling qwen
* [fix] rm comm, test and debug print
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
parent de2ad3b206
commit e1ca2d22ae

# Distributed RL Framework for Language Model Fine-Tuning

This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization with libraries such as vLLM.

---

## 🚀 Features

* **Distributed Training with Ray**: Scalable to multiple machines and GPUs.
* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm.
* **Model Backends**: Supports `vllm` as the inference backend.
* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through a parallel inferencer-trainer architecture.
* **Evaluation Integration**: Easily plug in task-specific eval datasets.
* **Checkpoints and Logging**: Configurable intervals and directories.

---

## 🛠 Installation

### Prepare the Development Environment

Install ColossalAI & ColossalChat:

```bash
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
git checkout grpo-latest-ascend
pip install -e .

cd ./applications/ColossalChat
pip install -e .
```

Install Fuyao Ray. Please update CANN before installing Fuyao Ray.

```bash
# Install CANN
source /usr/local/Ascend/ascend-toolkit/set_env.sh
./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel

# Clone Fuyao Ray. Fuyao Ray is not an open-source project; it will be included in the ColossalRL images.
git clone https://gitee.com/openfuyao/ray.git
cd ray
git pull origin pull/5/head

# Install ray
pip install ray==2.43.0 --no-cache-dir

# Create a soft link from fuyao-ray to the ray site-packages directory
cd ..
ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray

# Install Fuyao Ray
cd ray
python python/ray/setup-dev.py
```
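
As a quick sanity check (a minimal sketch; the `coati` package name comes from ColossalChat, and the interpreter path may differ on your system), confirm that the installed packages resolve:

```bash
python -c "import colossalai, coati, ray; print('ray', ray.__version__)"
ray --version
```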

Prepare the model & dataset:

```bash
huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B
```

### Set Distributed Config

Now we need to set the distributed config for multi-node training.

First, set the host IP config. For example, to configure a cluster of 4 nodes, edit the hosts file:

```bash
vim /etc/hosts
```

Then write the IP-to-node map to `/etc/hosts`:

```bash
10.0.0.3 npu-3
10.0.0.4 npu-4
10.0.0.5 npu-5
10.0.0.6 npu-6
```
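
Optionally verify that every node can reach the others by hostname (a simple check based on the `/etc/hosts` entries above):

```bash
for host in npu-3 npu-4 npu-5 npu-6; do
    ping -c 1 "$host" > /dev/null && echo "$host reachable" || echo "$host unreachable"
done
```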

Set the Ascend multi-node config:

```bash
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
```
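
These variables must be visible to the Ray processes on every node. One way to do that (a sketch; the file name `ascend_env.sh` is only an example) is to keep them in a script that is sourced on each machine before starting Ray:

```bash
cat > ascend_env.sh <<'EOF'
export ATB_LLM_HCCL_ENABLE=1
export ATB_LLM_COMM_BACKEND="hccl"
export HCCL_CONNECT_TIMEOUT=7200
export WORLD_SIZE=32                # total number of NPUs across the cluster (32 in the 4-node example below)
export HCCL_EXEC_TIMEOUT=7200
export HCCL_SOCKET_IFNAME=eno0      # network interface used by HCCL; adjust to your NIC name
export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200
EOF
source ascend_env.sh
```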

## 🧠 Data Format

Each data sample in the training or evaluation `.jsonl` file should follow this format:

```json
{
    "messages": {
        "role": "user",
        "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}."
    },
    "gt_answer": "3"
}
```
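
For reference, a minimal loading/validation sketch (not part of the repository) that checks each record carries the expected fields:

```python
import json

def load_rl_dataset(path: str):
    """Read a .jsonl dataset and lightly validate the sample schema described above."""
    samples = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            sample = json.loads(line)
            assert "messages" in sample and "gt_answer" in sample, f"line {line_no}: missing field"
            assert sample["messages"].get("role") == "user", f"line {line_no}: expected a user message"
            samples.append(sample)
    return samples

# Example usage (path is illustrative):
# data = load_rl_dataset("/path/to/train_data.jsonl")
# print(len(data), data[0]["gt_answer"])
```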

---

## ⚙️ Hyperparameters & Arguments

| Argument         | Description                             | Example                              |
| ---------------- | --------------------------------------- | ------------------------------------ |
| `--model`        | Model path or identifier                | `/path/to/model`                     |
| `--dataset`      | Path to training `.jsonl`               | `/path/to/train_data.jsonl`          |
| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` |
| `--project`      | Project name                            | `Project1`                           |
| `--num-episodes` | Number of training episodes             | `1`                                  |

### Distributed Training

| Argument                      | Description                           | Example |
| ----------------------------- | ------------------------------------- | ------- |
| `--num-trainers`              | Number of trainer processes           | `4`     |
| `--num-inferencer`            | Number of inferencer processes        | `4`     |
| `--inference-batch-size`      | Prompts per inference step            | `8`     |
| `--inference-microbatch-size` | Per-GPU batch size for inference      | `8`     |
| `--train-batch-size`          | Prompts per trainer step per dp group | `8`     |
| `--train-minibatch-size`      | Mini-batch size before forward pass   | `8`     |
| `--train-microbatch-size`     | Per-GPU batch size for training       | `2`     |

### Sampling

| Argument              | Description                                                                       | Example                                                                    |
| --------------------- | --------------------------------------------------------------------------------- | -------------------------------------------------------------------------- |
| `--backend`           | Generation backend, choose from `vllm`                                            | `vllm`                                                                     |
| `--temperature`       | Sampling temperature for generation                                               | `1.0`                                                                      |
| `--top-k`             | Top-K sampling parameter for generation                                           | `None`                                                                     |
| `--top-p`             | Top-P sampling parameter for generation                                           | `1.0`                                                                      |
| `--system-prompt`     | System prompt, defaults to the system prompt for the `think_answer_tags` format   | `Please reason step by step, and put your final answer within \\boxed{}.`  |
| `--max-new-tokens`    | Max generation tokens                                                             | `3584`                                                                     |
| `--max-prompt-tokens` | Max prompt tokens                                                                 | `512`                                                                      |

### GRPO Specific

| Argument          | Description                                                                                 | Example             |
| ----------------- | -------------------------------------------------------------------------------------------- | ------------------- |
| `--algo`          | Algorithm (`GRPO` or `DAPO`); for more customization refer to [GRPO Settings](#️-grpo-settings) | `GRPO`              |
| `--learning-rate` | Learning rate                                                                                  | `1e-6`              |
| `--kl-coeff`      | KL penalty coefficient                                                                         | `0.01`              |
| `--reward-type`   | Reward signal type (choose from `think_answer_tags`, `boxed`)                                  | `think_answer_tags` |
| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation)          | `100`               |

### Logging and Checkpointing

| Argument             | Description                        | Example      |
| -------------------- | ---------------------------------- | ------------ |
| `--save-interval`    | Training steps between checkpoints | `20`         |
| `--save-dir`         | Checkpoint directory               | `./model`    |
| `--eval-save-dir`    | Evaluation save path               | `./eval`     |
| `--rollout-save-dir` | Rollout logs directory             | `./rollouts` |
### Miscellaneous
|
||||||
|
|
||||||
|
| Argument | Description | Example |
|
||||||
|
| ------------------ | --------------------------------------- | ------- |
|
||||||
|
| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` |
|
||||||
|
| `--master_address` | Master address of a running Ray cluster | `None` |
|
||||||
|
| `--master_port` | Master port for torch DDP | `29506` |
|
||||||
|
|
||||||
|
---
|
||||||

## ⚙️ GRPO Settings

In addition to the two default training settings we provide (original `GRPO` and `DAPO`), users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py`, as sketched after the table.

| Argument Name                 | Description                                                                                                                            | Default                                   |
| ----------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------- |
| `filter_range`                | Filters out a rollout group if the success rate within that group is out of this range.                                                   | `[0.01, 0.99]`                            |
| `dynamic_batching`            | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476).                                               | `True`                                    |
| `clip_eps_low`                | epsilon_low in the DAPO objective in the [DAPO paper](https://arxiv.org/abs/2503.14476)                                                    | `0.2`                                     |
| `clip_eps_high`               | epsilon_high in the DAPO objective in the [DAPO paper](https://arxiv.org/abs/2503.14476)                                                   | `0.28`                                    |
| `skip_threshold`              | If the ratio is above this threshold, the sample is skipped to avoid instability.                                                          | `20.0`                                    |
| `loss_variation`              | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `"sample_level"` for the original GRPO loss.      | `"token_level"`                           |
| `soft_over_length_punishment` | Whether to use the soft overlength penalty from the [DAPO paper](https://arxiv.org/abs/2503.14476).                                        | `True`                                    |
| `cache_length`                | `L_cache` parameter for the soft overlength penalty in Eq. 13 of the [DAPO paper](https://arxiv.org/abs/2503.14476)                        | `min(1024, int(args.max_new_tokens / 4))` |
| `filter_truncated_response`   | Mask out truncated responses in loss calculation.                                                                                          | `True`                                    |
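
A minimal sketch of such a customization (the construction of `grpo_config` inside `rl_example.py` is assumed; key names follow the table above):

```python
# Inside rl_example.py, after grpo_config has been built for the chosen --algo:
grpo_config.update(
    {
        "filter_range": [0.05, 0.95],  # drop rollout groups whose success rate falls outside this range
        "clip_eps_low": 0.2,
        "clip_eps_high": 0.28,
        "loss_variation": "token_level",  # or "sample_level" for the original GRPO loss
        "soft_over_length_punishment": True,
        "cache_length": min(1024, int(args.max_new_tokens / 4)),
    }
)
```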

## 🔄 Constraints and Notes

* `num_inferencer + num_trainer == NUM_GPUs`
* `num_inferencer % num_trainer == 0`
* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0`
* `train_batch_size >= train_minibatch_size >= train_microbatch_size`
* `inference_batch_size >= inference_microbatch_size`
* Set microbatch sizes based on **VRAM capacity**
* To use tensor parallelism on the inferencer
  * set backend to `vllm`
  * change `tensor_parallel_size` in `inference_model_config` in rl_example.py
  * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size`
* To set tensor parallelism / pipeline parallelism / zero stage
  * change the corresponding settings in `plugin_config` in rl_example.py
* Ensure the rollout generation rate matches trainer consumption (a sanity-check sketch follows this list):

  ```
  num_inferencer * inference_batch_size % (
    num_trainer * train_batch_size /
    train_pipeline_parallelism_size /
    train_tensor_parallelism_size
  ) == 0
  ```

* Model weights sync every:

  ```
  (num_inferencer * inference_batch_size) /
  (num_trainer * train_batch_size /
   train_pipeline_parallelism_size /
   train_tensor_parallelism_size)
  ```
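
A minimal sanity-check sketch for the constraints above (standalone, not part of the repository; variable names follow the argument tables):

```python
def check_rl_config(num_inferencer, num_trainer, num_gpus,
                    inference_batch_size, inference_microbatch_size,
                    train_batch_size, train_minibatch_size, train_microbatch_size,
                    train_pp_size=1, train_tp_size=1):
    """Raise AssertionError if the process-count / batch-size constraints are violated."""
    assert num_inferencer + num_trainer == num_gpus
    assert num_inferencer % num_trainer == 0
    assert (num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0
    assert train_batch_size >= train_minibatch_size >= train_microbatch_size
    assert inference_batch_size >= inference_microbatch_size
    # rollout generation rate must be a multiple of what the trainers consume per step
    consumed_per_step = num_trainer * train_batch_size / train_pp_size / train_tp_size
    assert (num_inferencer * inference_batch_size) % consumed_per_step == 0

# Illustrative values: 4 trainers + 4 inferencers on one 8-device machine
check_rl_config(4, 4, 8,
                inference_batch_size=8, inference_microbatch_size=1,
                train_batch_size=8, train_minibatch_size=8, train_microbatch_size=2)
```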

---

## 🧪 Example: single machine 8-GPU Zero2 Strategy

```bash
# Short flags map to the long options in the tables above, e.g. -t/--num-trainers, -i/--num-inferencer,
# -b/--backend, -ibs/--inference-batch-size, -imbs/--inference-microbatch-size, -tbs/--train-batch-size,
# -tMbs/--train-minibatch-size, -tmbs/--train-microbatch-size, -rt/--reward-type, -s/--system-prompt, -p/--project
python rl_example.py \
    --dataset /path/to/train_data.jsonl \
    --model /path/to/Qwen2.5-3B/ \
    -t 4 -i 4 \
    -b vllm \
    -rt boxed \
    -g 4 \
    -ibs 1 \
    -tbs 2 \
    -tMbs 1 \
    -tmbs 2 \
    -imbs 1 \
    -s "Please reason step by step, and put your final answer within \\boxed{}." \
    -p GRPO-Train-Align-Debug
```

## 🧪 Example: multi-machine TP+PP Strategy

### Create a Ray cluster on multiple machines

For example, suppose we have 4 nodes whose IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, and 10.0.0.6, and we use 10.0.0.3 as the master node. First, start a Ray cluster on 10.0.0.3:

```bash
ray start --head --node-ip-address=10.0.0.3
```

Then, on each worker node (10.0.0.4/10.0.0.5/10.0.0.6), join the Ray cluster:

```bash
ray start --address='10.0.0.3:6379'
```
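
Before launching training, you can confirm on the master node that all four nodes have joined the cluster:

```bash
ray status
```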

Modify `plugin_config` in ./applications/ColossalChat/rl_example.py:

```python
plugin_config={
    "tp_size": 4,
    "pp_size": 2,
    "microbatch_size": max(
        1, args.train_microbatch_size // 2
    ),  # microbatch size should be set to train_microbatch_size // pp_size
    "zero_stage": 1,
    "max_norm": 1.0,
},  # for pp, tp
```

```bash
# Hint: replace the model path and dataset path below with your own
python rl_example.py \
    -m /path/to/Qwen2.5-Math-7B/ \
    -d /path/to/train_data.jsonl \
    --master_address '10.0.0.3' \
    -t 16 \
    -i 16 \
    -p GRPO-Train-Align-Debug \
    -g 2 \
    -ibs 1 \
    -tbs 2 \
    -tMbs 1 \
    -tmbs 2 \
    -imbs 1 \
    -b vllm \
    -e 2 \
    -rt boxed \
    -s "Please reason step by step, and put your final answer within \\boxed{}."
```

## Acknowledgement

---

@@ -13,7 +13,6 @@ from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.initialize import launch
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch

@@ -33,6 +32,7 @@ class BaseConsumer:
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",

@@ -55,8 +55,9 @@ class BaseConsumer:
         self.model_config = model_config
         self.plugin_config = plugin_config

-        self.device = get_current_device()
+        self.device = "npu"
         self.lr_scheduler = None
+        self.generate_config = generate_config

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)

@@ -73,24 +74,28 @@ class BaseConsumer:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.sp_rank = dist.get_rank(self.plugin.sp_group)
         self.pp_rank = dist.get_rank(self.plugin.pp_group)

         self.dp_size = dist.get_world_size(self.plugin.dp_group)
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.sp_size = dist.get_world_size(self.plugin.sp_group)
         self.pp_size = dist.get_world_size(self.plugin.pp_group)

         # Init Hybrid ray process group
         for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}")
         if self.pp_size > 1:
             # use hybrid tp + pp
             if self.tp_rank == 0 and self.dp_rank == 0:
                 cc.init_collective_group(
-                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                    self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}"
                 )
         else:
             if self.rank == 0:
-                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
+                cc.init_collective_group(
+                    self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model"
+                )

         self.buffer = []
         self.recv_cnt = 0

@@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
+        self.max_length = generate_config["max_tokens"]

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:

@@ -80,9 +80,42 @@ def launch_distributed(
         f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
     )

-    procs = []
+    nodes = ray.nodes()
+    node_info = {
+        node["NodeID"]: {
+            # "num_gpus": node["Resources"].get("GPU", 0),
+            "num_gpus": node["Resources"].get("NPU", 0),
+            "address": node["NodeManagerAddress"],
+        }  # Default to 0 if no GPUs are available
+        for node in nodes
+    }
+    gpu_to_node_id = []
+    gpu_to_ip_address = []
+    for node_id in node_info:
+        for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
+            gpu_to_node_id.append(node_id)
+            gpu_to_ip_address.append(node_info[node_id]["address"])
+
+    producer_procs = []

     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        node_id = gpu_to_node_id[0]
+        producer_ip_address = gpu_to_ip_address[0]
+        for _ in range(num_proc_per_producer):
+            gpu_to_node_id.pop(0)
+            gpu_to_ip_address.pop(0)
+        print(f"Schedule Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
+
+        producer = SimpleProducer.options(
+            # num_cpus=1,
+            # num_cpus=num_proc_per_producer,
+            num_gpus=0,
+            resources={"NPU": num_proc_per_producer},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,

@@ -108,20 +141,35 @@ def launch_distributed(
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
+    ray.get([p.setup.remote() for p in producer_procs])
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
+    consumer_master_ip_address = gpu_to_ip_address[0]
+    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+    consumer_procs = []
     for i in range(num_consumer_procs):
-        consumer = core_consumer.options(num_gpus=1).remote(
+        node_id = gpu_to_node_id[0]
+        consumer_ip_address = gpu_to_ip_address[0]
+        gpu_to_node_id.pop(0)
+        gpu_to_ip_address.pop(0)
+        print(f"Schedule Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
+        consumer = core_consumer.options(
+            resources={"NPU": 1},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
             world_size=num_consumer_procs,
-            master_addr=master_addr,
+            master_addr=consumer_master_ip_address,
             master_port=master_port,
             num_update_per_episode=num_update_per_episode,
             num_recv_per_update=num_recv_per_update,

@@ -138,6 +186,6 @@ def launch_distributed(
             run_name=run_name,
             wandb_group_name=wandb_group_name,
         )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
-    ray.get([p.loop.remote() for p in procs])
+        consumer_procs.append(consumer)
+    ray.get([p.setup.remote() for p in consumer_procs])
+    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])

@@ -11,15 +11,13 @@ import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from ray.util.collective import allreduce
-from ray.util.collective.types import Backend, ReduceOp
+from ray.util.collective.types import ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

-from colossalai.utils import get_current_device
-
 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_append_to_jsonl_file
+from .utils import safe_append_to_jsonl_file

 try:
     from vllm import SamplingParams

@@ -150,7 +148,8 @@ class BaseProducer:
             raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
         else:
             print("No eval dataset provided, skip eval")
-        self.device = get_current_device()
+
+        self.device = "npu"

         # init backend
         if backend in BACKEND_MAP:

@@ -162,17 +161,15 @@ class BaseProducer:

     def setup(self) -> None:
         cc.init_collective_group(
-            world_size=self.num_producers,
-            rank=self.producer_idx,
-            backend=Backend.NCCL,
-            group_name="producer_group",
+            1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}"
         )
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
-                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+                cc.init_collective_group(
+                    self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}"
+                )
         else:
-            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError

@@ -250,7 +247,6 @@ class BaseProducer:
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
-                outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )

@@ -1,9 +1,9 @@
-transformers==4.39.3
+transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
 colossalai>=0.4.7
-torch>=2.1.0
+torch==2.5.1
 langchain
 tokenizers
 fastapi

@@ -22,3 +22,9 @@ sentencepiece==0.1.99
 flash-attn
 tiktoken
 jsonlines
+math-verify==0.7.0
+
+# The following packages are built into the image.
+# torch_npu==2.5.1
+# fuyao-ray==2.43.0
+# vllm-ascend==0.7.3

@@ -151,7 +151,9 @@ if __name__ == "__main__":
         args.top_k = -1

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+    train_model_config = dict(
+        path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager"
+    )
     generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)

     if args.backend == "transformers":

@@ -247,17 +249,14 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "zero_stage": 2,
-        },  # for zero
-        # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # },  # for pp, tp
+            "tp_size": 2,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 1,
+            "max_norm": 1.0,
+        },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,

@@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):

         assert (
             self.num_microbatches >= self.stage_manager.num_stages
-        ), "Number of microbatch should be larger than number of stages"
+        ), f"Number of microbatch should be larger than number of stages"

         if self.forward_only:
             self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1

@@ -2,6 +2,7 @@ import math
 from typing import List, Optional, Tuple, Union

 import torch
+import torch_npu
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (

@@ -143,14 +144,8 @@ class Qwen2PipelineForwards:
             # for the other stages, hidden_states is the output of the previous stage
             if shard_config.enable_flash_attention:
                 # in this case, attention_mask is a dict rather than a tensor
-                mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-                attention_mask = ColoAttention.prepare_attn_kwargs(
-                    mask_shape,
-                    hidden_states.dtype,
-                    hidden_states.device,
-                    q_padding_mask=attention_mask,
-                    is_causal=True,
-                )
+                (batch_size, 1, seq_length, seq_length_with_past)
+                attention_mask = None
             else:
                 if self._attn_implementation == "flash_attention_2":
                     # 2d mask is passed through the layers

@@ -488,6 +483,144 @@ class Qwen2PipelineForwards:
         return {"hidden_states": hidden_states}


+def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+    def forward(
+        self: Qwen2Attention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if sp_mode is not None:
+            assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+            assert (sp_size is not None) and (
+                sp_group is not None
+            ), "Must specify sp_size and sp_group for sequence parallel"
+
+        bsz, q_len, _ = hidden_states.size()
+        # sp: modify sp_len when sequence parallel mode is ring
+        if sp_mode in ["split_gather", "ring"]:
+            q_len *= sp_size
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        # sp: all-to-all communication when introducing sequence parallel
+        if sp_mode == "all_to_all":
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            bsz, q_len, _ = query_states.size()
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        # Because the input can be padded, the absolute sequence length depends on the max position id.
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # Activate slicing cache only if the config has a value `sliding_windows` attribute
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
+                slicing_tokens = 1 - self.config.sliding_window
+
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
+
+                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+                if past_key.shape[-2] != self.config.sliding_window - 1:
+                    raise ValueError(
+                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+                        f" {past_key.shape}"
+                    )
+
+                if attention_mask is not None:
+                    attention_mask = attention_mask[:, slicing_tokens:]
+                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if shard_config.enable_flash_attention:
+            scale = 1.0 / math.sqrt(query_states.shape[-1])
+            attn_output = torch_npu.npu_fusion_attention(
+                query_states,
+                key_states,
+                value_states,
+                head_num=query_states.size(1),
+                input_layout="BNSD",
+                sparse_mode=1,
+                atten_mask=None,
+                scale=scale,
+                pre_tockens=65536,
+                next_tockens=65536,
+            )
+            attn_output = attn_output[0]
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+                raise ValueError(
+                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                    f" {attn_output.size()}"
+                )
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        if sp_mode == "all_to_all":
+            attn_output = attn_output.reshape(bsz, q_len, -1)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
+        else:
+            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+    return forward
+
+
 def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
     def forward(
         self: Qwen2Attention,

@@ -824,7 +957,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            force_sp_output_gather=False,
         )

         hidden_states = outputs[0]

@@ -19,7 +19,7 @@ from colossalai.shardformer.layer import (
 from ..modeling.qwen2 import (
     Qwen2PipelineForwards,
     get_lm_forward_with_dist_cross_entropy,
-    get_qwen2_flash_attention_forward,
+    get_qwen2_flash_attention_npu_forward,
     get_qwen2_model_forward_for_flash_attn,
 )

@@ -304,7 +304,7 @@ class Qwen2Policy(Policy):
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+                    "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
                 target_key=attn_cls,