From e1ca2d22ae92aa646ebbc357936f007b1dbb5d30 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 28 May 2025 10:43:13 +0800 Subject: [PATCH 1/5] [ColossalRL] Support ColossalRL on Ascend (#6324) * [fix] support npu * [feat] multinode 14B * [feat] enlarge seqlen * [fix] * [fix] ready to updated * [fix] ready to merge grpo-latest * [fix] rm comments * [feat] support msprof-analyze, add analsys result * [feat] support ColossalaiRL on Ascend * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [feat] rm comments in qwen modeling * [Doc] Drafted README.md * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [feat] fix ascend readme format * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] fix readme * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] fix Readme, rm irrelevant testcase * [fix] fix some adapt modification * [fix] rm comments in modeling qwen * [fix] rm comm, test and debug print * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> --- .../ColossalChat/coati/distributed/README.md | 291 +++++++++++++++++- .../coati/distributed/consumer.py | 15 +- .../coati/distributed/inference_backend.py | 1 + .../ColossalChat/coati/distributed/launch.py | 64 +++- .../coati/distributed/producer.py | 22 +- applications/ColossalChat/requirements.txt | 10 +- applications/ColossalChat/rl_example.py | 23 +- colossalai/pipeline/schedule/one_f_one_b.py | 2 +- colossalai/shardformer/modeling/qwen2.py | 150 ++++++++- colossalai/shardformer/policies/qwen2.py | 4 +- 10 files changed, 527 insertions(+), 55 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/README.md b/applications/ColossalChat/coati/distributed/README.md index b7bac2b2d..e0773d838 100644 --- a/applications/ColossalChat/coati/distributed/README.md +++ b/applications/ColossalChat/coati/distributed/README.md @@ -1,6 +1,291 @@ -# Requirements +# Distributed RL Framework for Language Model Fine-Tuning + +This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization using libraries like VLLM. + +--- + +## ๐Ÿš€ Features + +* **Distributed Training with Ray**: Scalable to multiple machines and GPUs. +* **Support for GRPO and DAPO**: Choose your preferred policy optimization algorithm. +* **Model Backends**: Support `vllm` as inference backends. +* **Rollout and Policy Decoupling**: Efficient generation and consumption of data through parallel inferencer-trainer architecture. +* **Evaluation Integration**: Easily plug in task-specific eval datasets. +* **Checkpoints and Logging**: Configurable intervals and directories. + +--- + +## ๐Ÿ›  Installation + +### Prepare Develop Environment + +Install Colossalai & ColossalChat +```bash +git clone https://github.com/hpcaitech/ColossalAI.git +git checkout grpo-latest-ascend +pip install -e . 
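+# NOTE: the checkout and editable install above are run from inside the cloned repo (cd ColossalAI after cloning).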
+ +cd ./applications/ColossalChat +pip install -e . +``` +Install Fuyao Ray. +Please update CANN before install fuyao ray +```bash +# Install CANN +source /usr/local/Ascend/ascend-toolkit/set_env.sh +./Ascend-cann-kernels-910b_8.1.RC1.alpha001_linux-aarch64.run --devel + +# Clone Fuyao Ray. Fuyao Ray is not an open source project, it will be inherited in the ColossalRL images. +git clone https://gitee.com/openfuyao/ray.git +cd ray +git pull origin pull/5/head + +# Install ray +pip install ray==2.43.0 --no-cache-dir + +# Create soft-link from fuyao-ray to ray site-package +cd .. +ln -s ./ray/python/ray/ /usr/local/python3.10/lib/python3.10/site-packages/ray + +# Install Fuyao Ray +cd ray +python python/ray/setup-dev.py +``` + +Prepare Model & dataset +```bash +huggingface-cli download --local-dir-use-symlinks False Qwen/Qwen2.5-7B --local-dir /models/Qwen/Qwen2.5-7B +``` + +### Set Distributed Config +Now, we need to set distributed config for multi-node. + +First, we set host ip config. +For example. I need to configure a cluster of 4 nodes, then I do +```bash +vim /etc/hosts +``` +Then write IP node map to /etc/hosts +```bash +10.0.0.3 npu-3 +10.0.0.4 npu-4 +10.0.0.5 npu-5 +10.0.0.6 npu-6 +``` + +Set Ascend Multi-Node Config ```bash -pip install cupy-cuda12x -python -m cupyx.tools.install_library --cuda 12.x --library nccl +export ATB_LLM_HCCL_ENABLE=1 +export ATB_LLM_COMM_BACKEND="hccl" +export HCCL_CONNECT_TIMEOUT=7200 +export WORLD_SIZE=32 +export HCCL_EXEC_TIMEOUT=7200 +export HCCL_SOCKET_IFNAME=eno0 +export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200 ``` + +## ๐Ÿง  Data Format + +Each data sample in the training or evaluation `.jsonl` file should follow this format: + +```json +{ + "messages": { + "role": "user", + "content": "Simplify $\\sqrt[3]{1+8} \\cdot \\sqrt[3]{1+\\sqrt[3]{8}}$. Let's think step by step and output the final answer within \\boxed{}." 
+ }, + "gt_answer": "3" +} +``` + +--- + +## โš™๏ธ Hyperparameters & Arguments + +| Argument | Description | Example | +| ---------------- | --------------------------------------- | ----------------- | +| `--model` | Model path or identifier | `/path/to/model` | +| `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` | +| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` | +| `--project` | Project name | `Project1` | +| `--num-episodes` | Number of training episodes | `1` | + +### Distributed Training + +| Argument | Description | Example | +| ----------------------------- | ------------------------------------- | ------- | +| `--num-trainers` | Number of trainer processes | `4` | +| `--num-inferencer` | Number of inferencer processes | `4` | +| `--inference-batch-size` | Prompts per inference step | `8` | +| `--inference-microbatch-size` | Per-GPU batch size for inference | `8` | +| `--train-batch-size` | Prompts per trainer step per dp group | `8` | +| `--train-minibatch-size` | Mini-batch size before forward pass | `8` | +| `--train-microbatch-size` | Per-GPU batch size for training | `2` | + +### Sampling + +| Argument | Description | Example | +| --------------------- | --------------------- | -------------- | +| `--backend` | Generation backend, choose from `vllm` | `vllm` | +| `--temperature` | Sampling temperature for generation | `1.0` | +| `--top-k` | Top-K sampling parameter for generation | `None` | +| `--top-p` | Top-P sampling parameter for generation | `1.0` | +| `--system-prompt` | System prompt, default to the system prompt for `think_answer_tags` format | `Please reason step by step, and put your final answer within \\boxed{}.` | +| `--max-new-tokens` | Max generation tokens | `3584` | +| `--max-prompt-tokens` | Max prompt tokens | `512` | + +### GRPO Specific + +| Argument | Description | Example | +| ----------------- | ---------------------------- | ------------------- | +| `--algo` | Algorithm (`GRPO` or `DAPO`), for more customization refer to [GRPO Settings](#๏ธ-grpo-settings) | `GRPO` | +| `--learning-rate` | Learning rate | `1e-6` | +| `--kl-coeff` | KL penalty coefficient | `0.01` | +| `--reward-type` | Reward signal type (choose from 'think_answer_tags', 'boxed') | `think_answer_tags` | +| `--eval-interval` | Evaluation interval in number of training steps (positive value to enable evaluation) | `100` | + +### Logging and Checkpointing + +| Argument | Description | Example | +| -------------------- | ------------------------- | ------------ | +| `--save-interval` | Training steps between checkpoints | `20` | +| `--save-dir` | Checkpoint directory | `./model` | +| `--eval-save-dir` | Evaluation save path | `./eval` | +| `--rollout-save-dir` | Rollout logs directory | `./rollouts` | + +### Miscellaneous + +| Argument | Description | Example | +| ------------------ | --------------------------------------- | ------- | +| `--ray_dir` | Custom Ray temp dir of a running Ray cluster (optional) | `None` | +| `--master_address` | Master address of a running Ray cluster | `None` | +| `--master_port` | Master port for torch DDP | `29506` | + +--- + +## โš™๏ธ GRPO Settings + +In addition to the two default training settings we provided--- original `GRPO` and `DAPO`, users can customize their training by changing the following hyperparameters in `grpo_config` in `rl_example.py`. 
+ +| Argument Name | Description | Default | +| ----------------------------- | ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- | +| `filter_range` | Filters out rollout group if the success rate within that group is out of this range.| `[0.01, 0.99]` | +| `dynamic_batching` | Enables dynamic batching as described in the [DAPO paper](https://arxiv.org/abs/2503.14476). | `True` | +| `clip_eps_low` | epsilon_low in DAPO in equation in [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.2` | +| `clip_eps_high` | epsilon_high in DAPO equation in [DAPO paper](https://arxiv.org/abs/2503.14476) | `0.28` | +| `skip_threshold` | If ratio is above this threshold, the sample is skipped to avoid instability. | `20.0` | +| `loss_variation` | Type of loss variation. Supports `"token_level"` for token-wise policy gradient loss and `sample_level` for original GRPO loss. | `"token_level"` | +| `soft_over_length_punishment` | Whether to use soft overlength penalty in [DAPO paper](https://arxiv.org/abs/2503.14476) or not. | `True` | +| `cache_length` | `L_cache` parameter for soft overlength penalty in e.q. 13 in [DAPO paper](https://arxiv.org/abs/2503.14476) | `min(1024, int(args.max_new_tokens / 4))` | +| `filter_truncated_response` | Mask out truncated responses in loss calculation. | `True` | + + + +## ๐Ÿ”„ Constraints and Notes + +* `num_inferencer + num_trainer == NUM_GPUs` +* `num_inferencer % num_trainer == 0` +* `(num_inferencer * inference_batch_size) % (num_trainer * train_batch_size) == 0` +* `train_batch_size >= train_minibatch_size >= train_microbatch_size` +* `inference_batch_size >= inference_microbatch_size` +* Set microbatch sizes based on **VRAM capacity** +* To use tensor parallelism on inferencer + * set backend to `vllm` + * change `tensor_parallel_size` in `inference_model_config` in rl_example.py + * set `num_inferencer = NUM_INFERENCE_GPUs / tensor_parallel_size` +* To set tensor parallelism / pipeline parallelism / zero stage + * change corresponding settings in `plugin_config` in rl_example.py +* Ensure rollout generation rate matches trainer consumption: + + ``` + num_inferencer * inference_batch_size % ( + num_trainer * train_batch_size / + train_pipeline_parallelism_size / + train_tensor_parallelism_size + ) == 0 + ``` +* Model weights sync every: + + ``` + (num_inferencer * inference_batch_size) / + (num_trainer * train_batch_size / + train_pipeline_parallelism_size / + train_tensor_parallelism_size) + ``` + +--- + +## ๐Ÿงช Example: single machine 8-GPU Zero2 Strategy + +```bash +python rl_example.py \ + --dataset /path/to/train_data.jsonl \ + --model /path/to/Qwen2.5-3B/ \ + -t 4 -i 4 \ + -b vllm \ + -ibs 2 -tbs 4 -tMbs 1 -tmbs 4 -imbs 1 \ + -rt boxed \ + -g 4 \ + -ibs 1 \ + -tbs 2 \ + -tMbs 1 \ + -tmbs 2 \ + -imbs 1 \ + -s "Please reason step by step, and put your final answer within \\boxed{}." \ + -tMbs 8 \ + -p GRPO-Train-Align-Debug \ +``` + +## ๐Ÿงช Example: multi-machine TP+PP Strategy + +### Create ray cluster on multi-machine +For example, now we have 4 nodes and their IPs are 10.0.0.3, 10.0.0.4, 10.0.0.5, 10.0.0.6. +We use 10.0.0.3 as master node. 
First we start a ray cluster on 10.0.0.3: +```bash +ray start --head --node-ip-address=10.0.0.3 +``` + +Then, for each slave node (10.0.0.4/10.0.0.5/10.0.0.6), we add to the ray cluser by following code: +```bash +ray start --address='10.0.0.3:6379' +``` + +Modify plugin_config in ./applications/ColossalChat/rl_example.py +```python +plugin_config={ + "tp_size": 4, + "pp_size": 2, + "microbatch_size": max( + 1, args.train_microbatch_size // 2 + ), # microbatch size should be set to train_microbatch_size // pp_size + "zero_stage": 1, + "max_norm": 1.0, + }, # for pp, tp +``` + +```bash +# Hint1: replace /models/Qwen/Qwen2.5-7B to your model path +# replace /datasets/train-alignment.jsonl to your dataset path +python rl_example.py + -m /path/to/Qwen2.5-Math-7B/ \ + -d /path/to/train_data.jsonl \ + --master_address '10.0.0.3' + -t 16 \ + -i 16 \ + -p GRPO-Train-Align-Debug \ + -g 2 \ + -ibs 1 \ + -tbs 2 \ + -tMbs 1 \ + -tmbs 2 \ + -imbs 1 \ + -b vllm \ + -e 2 \ + -rt boxed \ + -s "Please reason step by step, and put your final answer within \\boxed{}." +``` + +## Acknowledgement + +--- diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 816cab50a..b5e748d19 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -13,7 +13,6 @@ from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin from colossalai.initialize import launch from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device from .comm import ray_broadcast_tensor_dict from .utils import bind_batch, post_recv, unbind_batch @@ -33,6 +32,7 @@ class BaseConsumer: batch_size: int, model_config: Dict[str, Any], plugin_config: Dict[str, Any], + generate_config: Dict[str, Any], minibatch_size: int = 1, save_interval: int = 100, save_dir: str = "./model", @@ -55,8 +55,9 @@ class BaseConsumer: self.model_config = model_config self.plugin_config = plugin_config - self.device = get_current_device() + self.device = "npu" self.lr_scheduler = None + self.generate_config = generate_config def setup(self) -> None: launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0) @@ -73,24 +74,28 @@ class BaseConsumer: self.booster = Booster(plugin=self.plugin) self.dp_rank = dist.get_rank(self.plugin.dp_group) self.tp_rank = dist.get_rank(self.plugin.tp_group) + self.sp_rank = dist.get_rank(self.plugin.sp_group) self.pp_rank = dist.get_rank(self.plugin.pp_group) self.dp_size = dist.get_world_size(self.plugin.dp_group) self.tp_size = dist.get_world_size(self.plugin.tp_group) + self.sp_size = dist.get_world_size(self.plugin.sp_group) self.pp_size = dist.get_world_size(self.plugin.pp_group) # Init Hybrid ray process group for i in range(self.num_producers): - cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}") + cc.init_collective_group(self.world_size + 1, self.rank + 1, backend="hccl", group_name=f"sync_data_{i}") if self.pp_size > 1: # use hybrid tp + pp if self.tp_rank == 0 and self.dp_rank == 0: cc.init_collective_group( - self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}" + self.num_producers + 1, self.num_producers, backend="hccl", group_name=f"sync_model_{self.pp_rank}" ) else: if self.rank == 0: - cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model") + cc.init_collective_group( + 
self.num_producers + 1, self.num_producers, backend="hccl", group_name="sync_model" + ) self.buffer = [] self.recv_cnt = 0 diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index ce2a2f28c..7988802a3 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -210,6 +210,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): self.model_config = model_config self.tokenizer = tokenizer self.num_generations = num_generations + self.max_length = generate_config["max_tokens"] @torch.no_grad() def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index c9e8d2ab2..50169a49f 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -80,9 +80,42 @@ def launch_distributed( f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl", ) - procs = [] + nodes = ray.nodes() + node_info = { + node["NodeID"]: { + # "num_gpus": node["Resources"].get("GPU", 0), + "num_gpus": node["Resources"].get("NPU", 0), + "address": node["NodeManagerAddress"], + } # Default to 0 if no GPUs are available + for node in nodes + } + gpu_to_node_id = [] + gpu_to_ip_address = [] + for node_id in node_info: + for idx in range(int(node_info[node_id]["num_gpus"])): # use num_gpus instead of num_npus + gpu_to_node_id.append(node_id) + gpu_to_ip_address.append(node_info[node_id]["address"]) + + producer_procs = [] + for i in range(num_producers): - producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote( + node_id = gpu_to_node_id[0] + producer_ip_address = gpu_to_ip_address[0] + for _ in range(num_proc_per_producer): + gpu_to_node_id.pop(0) + gpu_to_ip_address.pop(0) + print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}") + + producer = SimpleProducer.options( + # num_cpus=1, + # num_cpus=num_proc_per_producer, + num_gpus=0, + resources={"NPU": num_proc_per_producer}, + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=False, + ), + ).remote( producer_idx=i, num_producers=num_producers, num_consumer_procs=num_consumer_procs, @@ -108,20 +141,35 @@ def launch_distributed( log_rollout_interval=log_rollout_interval, rollout_log_file=rollout_log_file, ) - procs.append(producer) + producer_procs.append(producer) + ray.get([p.setup.remote() for p in producer_procs]) generate_config_consumer = copy.deepcopy(generate_config) generate_config_consumer.update( dict( backend=inference_backend, ) ) + consumer_master_ip_address = gpu_to_ip_address[0] + print(f"Use {consumer_master_ip_address} as master address for torch DDP.") + consumer_procs = [] for i in range(num_consumer_procs): - consumer = core_consumer.options(num_gpus=1).remote( + node_id = gpu_to_node_id[0] + consumer_ip_address = gpu_to_ip_address[0] + gpu_to_node_id.pop(0) + gpu_to_ip_address.pop(0) + print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}") + consumer = core_consumer.options( + resources={"NPU": 1}, + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=node_id, + soft=False, + ), + ).remote( num_producers=num_producers, 
num_episodes=num_episodes, rank=i, world_size=num_consumer_procs, - master_addr=master_addr, + master_addr=consumer_master_ip_address, master_port=master_port, num_update_per_episode=num_update_per_episode, num_recv_per_update=num_recv_per_update, @@ -138,6 +186,6 @@ def launch_distributed( run_name=run_name, wandb_group_name=wandb_group_name, ) - procs.append(consumer) - ray.get([p.setup.remote() for p in procs]) - ray.get([p.loop.remote() for p in procs]) + consumer_procs.append(consumer) + ray.get([p.setup.remote() for p in consumer_procs]) + ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)]) diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 78964afa1..66a3c5967 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -11,15 +11,13 @@ import wandb from coati.dataset.loader import RawConversationDataset from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn from ray.util.collective import allreduce -from ray.util.collective.types import Backend, ReduceOp +from ray.util.collective.types import ReduceOp from torch.utils.data import DataLoader, DistributedSampler from transformers import AutoTokenizer -from colossalai.utils import get_current_device - from .comm import ray_broadcast_tensor_dict from .inference_backend import BACKEND_MAP -from .utils import pre_send, safe_append_to_jsonl_file +from .utils import safe_append_to_jsonl_file try: from vllm import SamplingParams @@ -150,7 +148,8 @@ class BaseProducer: raise ValueError(f"Unknown evaluation function type {evaluation_function_type}") else: print("No eval dataset provided, skip eval") - self.device = get_current_device() + + self.device = "npu" # init backend if backend in BACKEND_MAP: @@ -162,17 +161,15 @@ class BaseProducer: def setup(self) -> None: cc.init_collective_group( - world_size=self.num_producers, - rank=self.producer_idx, - backend=Backend.NCCL, - group_name="producer_group", + 1 + self.num_consumer_procs, 0, backend="hccl", group_name=f"sync_data_{self.producer_idx}" ) - cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}") if self.consumer_pp_size > 1: for i in range(self.consumer_pp_size): - cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}") + cc.init_collective_group( + self.num_producers + 1, self.producer_idx, backend="hccl", group_name=f"sync_model_{i}" + ) else: - cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model") + cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend="hccl", group_name="sync_model") def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: raise NotImplementedError @@ -250,7 +247,6 @@ class BaseProducer: outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) - outputs = pre_send(outputs) ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" ) diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index 472080101..e1b8291ab 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -1,9 +1,9 @@ -transformers==4.39.3 +transformers==4.47.0 tqdm 
datasets==2.14.7 loralib colossalai>=0.4.7 -torch>=2.1.0 +torch==2.5.1 langchain tokenizers fastapi @@ -22,3 +22,9 @@ sentencepiece==0.1.99 flash-attn tiktoken jsonlines +math-verify==0.7.0 + +# The following packages be built into the image. +# torch_npu==2.5.1 +# fuyao-ray==2.43.0 +# vllm-ascend==0.7.3 diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 5aec7f5a6..4efeb9f9c 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -151,7 +151,9 @@ if __name__ == "__main__": args.top_k = -1 inference_model_config = dict(path=args.model) - train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) + train_model_config = dict( + path=args.model, use_flash_attention_2=False, use_cache=False, attn_implementation="eager" + ) generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) if args.backend == "transformers": @@ -247,17 +249,14 @@ if __name__ == "__main__": train_model_config=train_model_config, grpo_config=grpo_config, plugin_config={ - "zero_stage": 2, - }, # for zero - # plugin_config={ - # "tp_size": 2, - # "pp_size": 2, - # "microbatch_size": max( - # 1, args.train_microbatch_size // 2 - # ), # microbatch size should be set to train_microbatch_size // pp_size - # "zero_stage": 0, - # "max_norm": 1.0, - # }, # for pp, tp + "tp_size": 2, + "pp_size": 2, + "microbatch_size": max( + 1, args.train_microbatch_size // 2 + ), # microbatch size should be set to train_microbatch_size // pp_size + "zero_stage": 1, + "max_norm": 1.0, + }, # for pp, tp inference_backend=args.backend, master_addr="localhost", master_port=args.master_port, diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index dcffa858c..1f8582a5b 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -92,7 +92,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert ( self.num_microbatches >= self.stage_manager.num_stages - ), "Number of microbatch should be larger than number of stages" + ), f"Number of microbatch should be larger than number of stages" if self.forward_only: self.num_microbatches = (self.batch_size - 1) // self.microbatch_size + 1 diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py index 27571309e..332563684 100644 --- a/colossalai/shardformer/modeling/qwen2.py +++ b/colossalai/shardformer/modeling/qwen2.py @@ -2,6 +2,7 @@ import math from typing import List, Optional, Tuple, Union import torch +import torch_npu from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( @@ -143,14 +144,8 @@ class Qwen2PipelineForwards: # for the other stages, hidden_states is the output of the previous stage if shard_config.enable_flash_attention: # in this case, attention_mask is a dict rather than a tensor - mask_shape = (batch_size, 1, seq_length, seq_length_with_past) - attention_mask = ColoAttention.prepare_attn_kwargs( - mask_shape, - hidden_states.dtype, - hidden_states.device, - q_padding_mask=attention_mask, - is_causal=True, - ) + (batch_size, 1, seq_length, seq_length_with_past) + attention_mask = None else: if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers @@ -488,6 +483,144 @@ class Qwen2PipelineForwards: return {"hidden_states": hidden_states} +def 
get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): + def forward( + self: Qwen2Attention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if sp_mode is not None: + assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode" + assert (sp_size is not None) and ( + sp_group is not None + ), "Must specify sp_size and sp_group for sequence parallel" + + bsz, q_len, _ = hidden_states.size() + # sp: modify sp_len when sequence parallel mode is ring + if sp_mode in ["split_gather", "ring"]: + q_len *= sp_size + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # sp: all-to-all comminucation when introducing sequence parallel + if sp_mode == "all_to_all": + query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication) + key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication) + value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication) + bsz, q_len, _ = query_states.size() + + query_states = query_states.view(bsz, q_len, self.num_heads, -1).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, -1).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # Because the input can be padded, the absolute sequence length depends on the max position id. 
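+        # Compute RoPE cos/sin from position_ids and rotate the query/key states accordingly.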
+ cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if shard_config.enable_flash_attention: + scale = 1.0 / math.sqrt(query_states.shape[-1]) + attn_output = torch_npu.npu_fusion_attention( + query_states, + key_states, + value_states, + head_num=query_states.size(1), + input_layout="BNSD", + sparse_mode=1, + atten_mask=None, + scale=scale, + pre_tockens=65536, + next_tockens=65536, + ) + attn_output = attn_output[0] + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + attn_output = attn_output.transpose(1, 2).contiguous() + if sp_mode == "all_to_all": + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = all_to_all_comm( + attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication + ) + else: + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + return forward + + def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None): def forward( self: Qwen2Attention, @@ -824,7 +957,6 @@ def 
get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - force_sp_output_gather=False, ) hidden_states = outputs[0] diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 0adcdfdbd..823527df6 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -19,7 +19,7 @@ from colossalai.shardformer.layer import ( from ..modeling.qwen2 import ( Qwen2PipelineForwards, get_lm_forward_with_dist_cross_entropy, - get_qwen2_flash_attention_forward, + get_qwen2_flash_attention_npu_forward, get_qwen2_model_forward_for_flash_attn, ) @@ -304,7 +304,7 @@ class Qwen2Policy(Policy): if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: self.append_or_create_method_replacement( description={ - "forward": get_qwen2_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group), + "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group), }, policy=policy, target_key=attn_cls, From f4e3063dc35e60d9fdb4e140db2378a0371d5e7a Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 28 May 2025 11:35:35 +0800 Subject: [PATCH 2/5] [Ascend] Update README (#6331) * update readme * [fix] add vllm & vllm-ascend installation --------- Co-authored-by: Tong Li Co-authored-by: duanjunwen <935724073@qq.com> --- .../ColossalChat/coati/distributed/README.md | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/README.md b/applications/ColossalChat/coati/distributed/README.md index e0773d838..68c5e5c68 100644 --- a/applications/ColossalChat/coati/distributed/README.md +++ b/applications/ColossalChat/coati/distributed/README.md @@ -2,6 +2,8 @@ This repository implements a distributed Reinforcement Learning (RL) training framework designed to fine-tune large language models using algorithms such as **GRPO** and **DAPO**. It supports multi-node and multi-GPU setups, scalable rollout generation, and policy optimization using libraries like VLLM. +**Please note that we are still under intensive development, stay tuned.** + --- ## ๐Ÿš€ Features @@ -28,6 +30,15 @@ pip install -e . cd ./applications/ColossalChat pip install -e . ``` + +Install vllm and vllm-ascend +```bash +apt update -y +apt install -y libnuma-dev +pip install vllm==0.7.3 +pip install vllm-ascend==0.7.3 --extra-index https://download.pytorch.org/whl/cpu/ +``` + Install Fuyao Ray. Please update CANN before install fuyao ray ```bash @@ -85,6 +96,23 @@ export HCCL_SOCKET_IFNAME=eno0 export RAY_COLLECTIVE_MEET_TIMEOUT_SECONDS=7200 ``` + +## Architecture Design + +
+Producer-Consumer Pattern: a classic software design pattern used for managing resources, data, or tasks between two different processes or threads. + +* Producer: inference engine which rollouts out examples and saves them into a shared buffer. +* Consumer: training framework which takes training examples from the shared buffer and train the policy model. + +Key features for Producer-Consumer Pattern: +* Buffer: Acts as a shared queue where the producer adds data and the consumer removes data. +* Concurrency: Rollout and training can work concurrently. + ## ๐Ÿง  Data Format Each data sample in the training or evaluation `.jsonl` file should follow this format: @@ -287,5 +315,4 @@ python rl_example.py ``` ## Acknowledgement - ---- +Colossal-RL is a distributed version of ColossalChat and inspired by a few awesome open-source projects. We would like to express our gratitude to the Fuyao-ray team and the vllm-ascend team for their support throughout the development of the this project. We also thank the following awesome open-source projects and algorithms: GRPO, DAPO, TRL, Verl, OpenRLHF, StreamRL, Qwen, Logic-RL. From 180cea709be0589a3e0d74f6210b4231af6125ca Mon Sep 17 00:00:00 2001 From: Tong Li Date: Wed, 28 May 2025 13:13:21 +0800 Subject: [PATCH 3/5] update to conform to json format --- applications/ColossalChat/coati/distributed/README.md | 2 +- applications/ColossalChat/rl_example.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/README.md b/applications/ColossalChat/coati/distributed/README.md index 68c5e5c68..e68dfb812 100644 --- a/applications/ColossalChat/coati/distributed/README.md +++ b/applications/ColossalChat/coati/distributed/README.md @@ -135,7 +135,7 @@ Each data sample in the training or evaluation `.jsonl` file should follow this | ---------------- | --------------------------------------- | ----------------- | | `--model` | Model path or identifier | `/path/to/model` | | `--dataset` | Path to training `.jsonl` | `/path/to/train_data.jsonl` | -| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{'eval_1':'/path/to/eval_1.jsonl'}` | +| `--eval-dataset` | JSON of task\:eval\_dataset\_path pairs | `{"eval_1":"/path/to/eval_1.jsonl"}` | | `--project` | Project name | `Project1` | | `--num-episodes` | Number of training episodes | `1` | diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 4efeb9f9c..718f07fd4 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -17,7 +17,7 @@ if __name__ == "__main__": default=None, help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. 
\ - The key is the task name, and the value is the path to the jsonl file", + The key is the task name, and the value is the path to the jsonl file, please replace sinple quotes with double quotes to conform to json format.", ) parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") From 84f523a080083da68bcedf86737eef056889f6a5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Wed, 4 Jun 2025 09:51:31 +0800 Subject: [PATCH 4/5] [Hotfix] fix requirsments (#6338) * [fix] fix colossalai ascend requirments * [fix] fix colossalai chat requirements * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] fix requirments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- applications/ColossalChat/requirements.txt | 6 ++---- requirements/requirements.txt | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt index e1b8291ab..df579f2a7 100755 --- a/applications/ColossalChat/requirements.txt +++ b/applications/ColossalChat/requirements.txt @@ -2,7 +2,6 @@ transformers==4.47.0 tqdm datasets==2.14.7 loralib -colossalai>=0.4.7 torch==2.5.1 langchain tokenizers @@ -15,11 +14,10 @@ packaging autoflake==2.2.1 black==23.9.1 tensorboard -six==1.16.0 +six==1.17.0 datasets ninja==1.11.1 -sentencepiece==0.1.99 -flash-attn +sentencepiece==0.2.0 tiktoken jsonlines math-verify==0.7.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index f357c45fd..9c110a1f4 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.2.0,<=2.4.1 +torch==2.5.1 safetensors einops pydantic @@ -16,11 +16,11 @@ ray sentencepiece google protobuf -transformers==4.39.3 +transformers==4.47.0 peft>=0.7.1,<=0.13.2 bitsandbytes>=0.39.0 rpyc==6.0.0 fastapi -uvicorn==0.29.0 +uvicorn galore_torch diffusers==0.29.0 From 9379a896772b2b190b2343775a8e3b0d1df10e33 Mon Sep 17 00:00:00 2001 From: xysheng-colossal Date: Mon, 23 Jun 2025 11:49:13 +0800 Subject: [PATCH 5/5] [feat][npu] Merge form grpo-latest (#6346) * move prompt-level-filtering to buffer side * move prompt-level-filtering to buffer side * remove redundant code and fix bugs * fix metric calculation * fix missing tags parameter * address conversation * add overlength sample count (#6332) Co-authored-by: Tong Li * address conversation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typ and parameter description * [feat] Update requriments and set return logits False --------- Co-authored-by: YeAnbang Co-authored-by: Tong Li Co-authored-by: Tong Li Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../coati/distributed/consumer.py | 112 +++++++++-- .../coati/distributed/grpo_consumer.py | 190 +++++++++--------- .../ColossalChat/coati/distributed/launch.py | 4 +- .../coati/distributed/producer.py | 8 +- applications/ColossalChat/requirements.txt | 2 +- applications/ColossalChat/rl_example.py | 70 ++++++- colossalai/shardformer/modeling/qwen2.py | 20 +- requirements/requirements.txt | 8 +- 8 files changed, 270 insertions(+), 144 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py 
b/applications/ColossalChat/coati/distributed/consumer.py index b5e748d19..9926f0cdf 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -122,26 +122,102 @@ class BaseConsumer: # receive data from producers for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") - self.buffer.extend( - unbind_batch( - ray_broadcast_tensor_dict( - None, src=0, device=self.device, group_name=f"sync_data_{r}" - ) - ) + raw_batch = ray_broadcast_tensor_dict( + None, src=0, device=self.device, group_name=f"sync_data_{r}" ) - while len(self.buffer) >= self.dp_size * self.minibatch_size: - batches = self.buffer[ - self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size - ] - batch = bind_batch(batches) - batch = post_recv(batch) - loss, excessive_prompts_idx = self.step(i, pbar, **batch) + # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), + # we need to calculate the metrics before filtering here for logging + # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...] + raw_batch_with_reward = self.calculate_reward( + {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()} + ) + raw_batch_with_reward = { + k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v + for k, v in raw_batch_with_reward.items() + } + # [batch_size, num_generations] -> [batch_size] + reward = raw_batch_with_reward["reward"][:, :, 0] + format_acc = raw_batch_with_reward["format_acc"][:, :, 0] + ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0] + response_len = ( + raw_batch_with_reward["response_idx"][:, :, 1] + - raw_batch_with_reward["response_idx"][:, :, 0] + + 1 + ).type(torch.float32) + effective_group_mask = None + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True): + # filter the group based on the reward and accuracy + group_ans_acc_mean = ans_acc.mean(dim=1) + effective_group_mask = torch.logical_and( + group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1] + ) + raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]] + for group_idx, group_with_reward in enumerate(raw_batch_with_reward): + self.buffer.append( + [ + ( + group_with_reward + if effective_group_mask is None or effective_group_mask[group_idx] + else None + ), + reward[group_idx], + format_acc[group_idx], + ans_acc[group_idx], + response_len[group_idx], + ] + ) + if effective_group_mask is not None: + print( + f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" + ) + # mapping the effective group to the raw group for indexing + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( + buffer_idx + ) + print( + f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}" + ) - if excessive_prompts_idx is not None: - excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx] - self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :] - else: - self.buffer = self.buffer[self.dp_size * self.minibatch_size :] + while 
len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: + # on each dp_rank, we use minibatch_size effective samples to form a batch + batches = [ + self.buffer[effective_group_to_raw_group_mapping[i]] + for i in range( + self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size + ) + ] + # every dp_rank will receive a complete mini-batch, no need to sync within step() later + # each mini-batch use the first self.dp_size * minibatch_size effective samples + raw_mini_batches = self.buffer[ + : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 + ] # include the last effective sample + raw_mini_batches_metric_dict = { + "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches], + "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches], + "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches], + "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches], + } + batch = bind_batch([t[0] for t in batches]) + batch = post_recv(batch) + loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict) + self.buffer = self.buffer[ + effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 : + ] + # recalculate the effective group to raw group mapping + effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping) + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( + buffer_idx + ) + assert ( + len(effective_group_to_raw_group_mapping) + == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size + ) if loss is not None: pbar.set_postfix({"loss": loss}) i += 1 diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index eaf3521b6..4ca23e911 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -1,5 +1,5 @@ from contextlib import nullcontext -from typing import Any, Optional +from typing import Any, Dict, Optional import ray import torch @@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs -from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum +from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -72,19 +72,18 @@ class GRPOConsumer(BaseConsumer): self.policy_model.gradient_checkpointing_enable() self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) - self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) - self.accum_format_acc = torch.zeros(1, device=self.device) - self.accum_ans_acc = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) - self.accum_response_length = torch.zeros(1, device=self.device) + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = 
[] + self.raw_train_batch_response_len = [] self.accum_count = 0 self.generate_config = generate_config self.grpo_config = grpo_config self.project_name = project_name self.effective_sample_count = 0 self.effective_prompt_count = 0 - self.total_sample_count = 0 self.project_name = project_name self.run_name = run_name self.wandb_group_name = wandb_group_name @@ -120,16 +119,7 @@ class GRPOConsumer(BaseConsumer): "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." ) # Initialize verifiable reward. - response_format_tags = ( - { - "think_start": {"text": "", "num_occur": 1}, - "think_end": {"text": "", "num_occur": 1}, - "answer_start": {"text": "", "num_occur": 1}, - "answer_end": {"text": "", "num_occur": 1}, - } - if grpo_config.get("reward_fn_type") == "think_answer_tags" - else None - ) + response_format_tags = grpo_config.get("response_format_tags", None) reward_model_kwargs = { k: v for k, v in grpo_config.items() @@ -185,24 +175,21 @@ class GRPOConsumer(BaseConsumer): Format: [minibatch_size, num_of_generation, prompt_length + response_length] --- ............. """ - # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length] - data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} + data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k} + self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"]) + self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"]) + self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"]) + self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"]) action_mask = data["action_mask"] num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) - reward_group = self.reward_model( - data["input_ids"], - gt_answer=data["gt_answer"], - response_idx=data["response_idx"], - ) - - reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) - format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) - ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + reward = data["reward"].view((-1)) + format_acc = data["format_acc"].view((-1)) + ans_acc = data["ans_acc"].view((-1)) # [minibatch_size, num_generations] @@ -214,16 +201,9 @@ class GRPOConsumer(BaseConsumer): reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) # [minibatch_size x num_generations] advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) - # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), - group_ans_acc = ( - ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0) - ) + # [minibatch_size x num_of_generation] - loss_mask = ( - torch.ones(action_mask.size(0), device=action_mask.device).bool() - if self.filter_range is None - else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1]) - ) + loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() # filter out overlength samples if self.filter_truncated_response and action_mask.size(1) == self.max_length: @@ -231,42 +211,25 @@ class GRPOConsumer(BaseConsumer): loss_mask, 
action_mask[:, -1] == False, ) - prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations) - - # [minibatch_size] -> calculate the number of effective prompts - effective_prompts_mask = prompt_level_mask.any(dim=1) - effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin) - self.effective_prompt_count += effective_prompts.item() - excessive_prompts_idx = None + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False: + # filter out samples with reward outside the range + # if dynamic batching is enabled, we filter out out of range groups before training + group_ans_acc_mean = ( + ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1) + ) + loss_mask = torch.logical_and( + loss_mask, + torch.logical_and( + group_ans_acc_mean > self.filter_range[0], + group_ans_acc_mean < self.filter_range[1], + ), + ) + self.effective_prompt_count += group_reward.size(0) * self.dp_size mean_kl, mean_loss = [], [] if self.grpo_config.get("dynamic_batching", True): need_update = self.effective_prompt_count >= self.batch_size * self.dp_size - excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size - - if excessive_prompts > 0: - excessive_prompts_per_rank = excessive_prompts // self.dp_size - # Only count excessive prompts if they are greater than 1 per rank. - # TODO: customize excessive prompts calculation. - if excessive_prompts_per_rank != 0: - # Mask excessive prompts to False - true_indices = torch.nonzero(effective_prompts_mask) - # Make sure the indices are not empty. - if true_indices.numel() > 0: - true_indices = true_indices.squeeze(-1) - if excessive_prompts_per_rank <= len(true_indices): - excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:] - else: - excessive_prompts_idx = true_indices - effective_prompts_mask[excessive_prompts_idx] = False - - for mask_idx in range(len(effective_prompts_mask)): - if effective_prompts_mask[mask_idx] == False: - # Update loss mask. - loss_mask[mask_idx] = False - else: - excessive_prompts_idx = torch.empty([0]) else: # If dynamic batching is disabled, we need to use all samples for training. 
need_update = (step_idx + 1) % self.num_microbatches == 0 @@ -276,12 +239,10 @@ class GRPOConsumer(BaseConsumer): total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() - self.total_sample_count += total_samples.item() pbar.set_postfix( { "Global Step": self.global_step, - "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}", - "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}", + "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples", } ) @@ -399,7 +360,7 @@ class GRPOConsumer(BaseConsumer): criterion=_criterion, optimizer=self.optimizer, return_loss=True, - return_outputs=True, + return_outputs=False, ) loss = policy_model_outputs["loss"] @@ -473,20 +434,16 @@ class GRPOConsumer(BaseConsumer): self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) if self.policy_loss_fn.beta > 0: self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) - self.accum_reward.add_(reward.data) - self.accum_format_acc.add_(format_acc.data) - self.accum_ans_acc.add_(ans_acc.data) self.accum_advantages.add_(advantages.data) - self.accum_response_length.add_(response_length.data) self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() self.global_step += 1 - sample_utilization = self.effective_sample_count / self.total_sample_count + # no need to run all reduce as raw_train_batch_* are not splited across dp rank + sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations self.effective_prompt_count = 0 self.effective_sample_count = 0 - self.total_sample_count = 0 loss_scalar = self.accum_loss.item() if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 @@ -494,25 +451,39 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): + raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item() + raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item() + raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item() + raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0) + raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item() + overlength_samples_ratio = ( + (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() + ) # not an exact figure, but a close estimate + self.raw_train_batch_reward = [] + self.raw_train_batch_format_acc = [] + self.raw_train_batch_ans_acc = [] + self.raw_train_batch_response_len = [] to_log_msg = [ f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", - f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", - f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", - f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Reward: {raw_batch_reward_mean:.4f}", + f"format Reward: {raw_batch_format_acc_mean:.4f}", + f"Acc Reward: 
+                        f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                         f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                        f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                        f"Response Length: {raw_batch_response_len_mean:.4f}",
                         f"Sample_utilization: {sample_utilization:.4f}",
+                        f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                     ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                     print("\n".join(to_log_msg))
                     metrics = {
-                        "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                        "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                        "metrics/reward": raw_batch_reward_mean,
+                        "metrics/format_acc": raw_batch_format_acc_mean,
+                        "metrics/ans_acc": raw_batch_ans_acc_mean,
+                        "metrics/response_length": raw_batch_response_len_mean,
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
                         "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                         "train/sample_utilization": sample_utilization,
+                        "train/overlength_samples_ratio": overlength_samples_ratio,
                         "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                     if self.policy_loss_fn.beta > 0:
@@ -520,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
                     if self.wandb_run is not None:
                         self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
-                self.accum_reward.zero_()
-                self.accum_ans_acc.zero_()
-                self.accum_format_acc.zero_()
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
-                self.accum_response_length.zero_()
                 self.accum_count = 0
-
-                if excessive_prompts_idx is not None:
-                    # All gather excessive prompts index across DP ranks.
-                    excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                    excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-                return loss_scalar, excessive_prompts_idx
+                return loss_scalar
             else:
-                return None, excessive_prompts_idx
+                return None
+
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt,
+                containing the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+ """ + reward_model_output = self.reward_model( + rollout["input_ids"], + gt_answer=rollout["gt_answer"], + response_idx=rollout["response_idx"], + ) + # [num_of_generation] + reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device) + format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device) + ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device) + + rollout["reward"] = reward.view((-1, 1)) + rollout["format_acc"] = format_acc.view((-1, 1)) + rollout["ans_acc"] = ans_acc.view((-1, 1)) + return rollout def state_dict(self): self.policy_model._force_wait_all_gather() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 50169a49f..ef6ef5104 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -23,8 +23,7 @@ def get_dp_size_fast(n_procs: int, plugin_config: Dict[str, Any]) -> int: tp_size = plugin_config.get("tp_size", 1) pp_size = plugin_config.get("pp_size", 1) ep_size = plugin_config.get("ep_size", 1) - sp_size = plugin_config.get("sp_size", 1) - return n_procs // (tp_size * pp_size * ep_size * sp_size) + return n_procs // (tp_size * pp_size * ep_size) def launch_distributed( @@ -133,6 +132,7 @@ def launch_distributed( eval_dataset_config=eval_dataset_config, eval_interval=eval_interval, evaluation_function_type=grpo_config["reward_fn_type"], + response_format_tags=grpo_config["response_format_tags"], eval_save_dir=eval_save_dir, eval_generation_config=eval_generation_config, project_name=project_name, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 66a3c5967..47bdc5dd5 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -44,6 +44,7 @@ class BaseProducer: eval_dataset_config=None, eval_interval=-1, # disable evaluation evaluation_function_type="think_answer_tags", + response_format_tags=None, eval_save_dir: str = "./eval", project_name: str = None, run_name: str = None, @@ -146,6 +147,7 @@ class BaseProducer: self.evaluation_function = boxed_math_reward_fn else: raise ValueError(f"Unknown evaluation function type {evaluation_function_type}") + self.response_format_tags = response_format_tags else: print("No eval dataset provided, skip eval") @@ -214,6 +216,7 @@ class BaseProducer: eval_outputs["response_idx"][m][n], tokenizer=self.tokenizer, eval_mode=True, + tags=self.response_format_tags, ) for m in range(eval_outputs["input_ids"].size(0)) for n in range(eval_outputs["input_ids"].size(1)) @@ -242,11 +245,10 @@ class BaseProducer: self.eval_mode = False self.latest_eval_step = self.consumer_global_step outputs = self.rollout(**batch) - - print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" ) @@ -320,6 +322,7 @@ class SimpleProducer(BaseProducer): eval_dataset_config=None, eval_interval=-1, # disable evaluation evaluation_function_type="think_answer_tags", + 
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -345,6 +348,7 @@ class SimpleProducer(BaseProducer):
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,
diff --git a/applications/ColossalChat/requirements.txt b/applications/ColossalChat/requirements.txt
index df579f2a7..3d913ebeb 100755
--- a/applications/ColossalChat/requirements.txt
+++ b/applications/ColossalChat/requirements.txt
@@ -1,4 +1,3 @@
-transformers==4.47.0
 tqdm
 datasets==2.14.7
 loralib
@@ -26,3 +25,4 @@ math-verify==0.7.0
 # torch_npu==2.5.1
 # fuyao-ray==2.43.0
 # vllm-ascend==0.7.3
+# transformers==4.47.0
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 718f07fd4..0b7bde6b0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -121,6 +121,34 @@ if __name__ == "__main__":
     parser.add_argument(
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -134,8 +162,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equal to train batch size, and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -180,12 +208,12 @@ if __name__ == "__main__":
             enforce_eager=True,
             enable_chunked_prefill=True,
             max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-            tensor_parallel_size=1,
+            tensor_parallel_size=args.producer_tensor_parallel_size,
         )
     )
     generate_config.update(
         dict(
-            max_tokens=args.max_new_tokens,  # max new tokens
+            max_tokens=args.max_new_tokens + args.max_prompt_tokens,  # max new tokens
            ignore_eos=True if args.reward_type == "think_answer_tags" else False,
            include_stop_str_in_output=True,
            stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
@@ -205,6 +233,16 @@ if __name__ == "__main__":
             "reward_fn_type": args.reward_type,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -224,13 +262,23 @@ if __name__ == "__main__":
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")

     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
         num_consumer_procs=args.num_trainers,
         num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
@@ -249,13 +297,17 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "tp_size": 2,
-            "pp_size": 2,
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
             "microbatch_size": max(
-                1, args.train_microbatch_size // 2
+                1, args.train_microbatch_size // args.pipeline_parallel_size
             ),  # microbatch size should be set to train_microbatch_size // pp_size
-            "zero_stage": 1,
+            "zero_stage": args.zero_stage,
             "max_norm": 1.0,
+            "enable_flash_attention": True,
+            "sp_size": args.tensor_parallel_size,
+            "enable_sequence_parallelism": True,
+            "sequence_parallelism_mode": "split_gather",  # ["split_gather", "ring", "all_to_all"]
         },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
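Reviewer note: a hedged usage sketch of the new CLI knobs wired up in `rl_example.py` above. Only `-tp`, `-pp`, `-zero`, and `-ptp` are introduced by this patch; the model, dataset, and batch-size arguments keep their existing names and are omitted here, so treat this as an illustration rather than a copy-paste command.

```bash
# Hedged sketch: trainer runs with tp=2, pp=2, ZeRO stage 1; each producer gets an
# inference engine with tensor parallel size 2 (and therefore 2 devices, since
# num_proc_per_producer now defaults to the producer tensor parallel size).
# All other rl_example.py arguments are omitted and keep their existing names.
python rl_example.py -tp 2 -pp 2 -zero 1 -ptp 2
```

Note that the plugin config above also derives `sp_size` from `-tp` and enables sequence parallelism in `split_gather` mode, so the trainer's tensor-parallel group doubles as its sequence-parallel group.
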
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 332563684..de838185d 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -132,7 +132,12 @@ class Qwen2PipelineForwards:
             else:
                 position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        if (
+            not shard_config.enable_flash_attention
+            and attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -144,7 +149,6 @@ class Qwen2PipelineForwards:
         # for the other stages, hidden_states is the output of the previous stage
         if shard_config.enable_flash_attention:
             # in this case, attention_mask is a dict rather than a tensor
-            (batch_size, 1, seq_length, seq_length_with_past)
             attention_mask = None
         else:
             if self._attn_implementation == "flash_attention_2":
@@ -616,7 +620,7 @@ def get_qwen2_flash_attention_npu_forward(shard_config: ShardConfig, sp_mode=Non

         attn_output = self.o_proj(attn_output)

-        return attn_output, None, past_key_value
+        return attn_output, None

     return forward

@@ -805,15 +809,7 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = inputs_embeds

         if shard_config.enable_flash_attention:
-            # in this case, attention_mask is a dict rather than a tensor
-            mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
-            attention_mask = ColoAttention.prepare_attn_kwargs(
-                mask_shape,
-                hidden_states.dtype,
-                hidden_states.device,
-                q_padding_mask=attention_mask,
-                is_causal=True,
-            )
+            attention_mask = None
         else:
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 9c110a1f4..e459e28d1 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,15 +8,12 @@ click
 fabric
 contexttimer
 ninja
-torch==2.5.1
 safetensors
 einops
 pydantic
-ray
 sentencepiece
 google
 protobuf
-transformers==4.47.0
 peft>=0.7.1,<=0.13.2
 bitsandbytes>=0.39.0
 rpyc==6.0.0
@@ -24,3 +21,8 @@ fastapi
 uvicorn
 galore_torch
 diffusers==0.29.0
+
+# The following packages should be built into the image.
+# torch==2.5.1
+# ray
+# transformers==4.47.0
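The dependency changes in the two requirements files move the heavy pins into comments on the assumption that they ship pre-installed in the training image. As a hedged sketch only, a manual install for an environment that lacks them might look like the following; exact wheel availability (especially for the Ascend packages) depends on your mirror and image and is an assumption here.

```bash
# Pins taken from the commented-out entries in requirements/requirements.txt and
# applications/ColossalChat/requirements.txt; skip anything your image already provides.
pip install torch==2.5.1 transformers==4.47.0
# Ascend-specific wheels (availability varies by mirror/image):
pip install torch_npu==2.5.1 vllm-ascend==0.7.3
# ray is expected to come from the bundled fuyao-ray (2.43.0) rather than PyPI.
```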