From 26d859f68e12f3da00bff974d864f54fb9e6df72 Mon Sep 17 00:00:00 2001 From: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Date: Fri, 25 Apr 2025 17:39:17 +0800 Subject: [PATCH] [feat] Support DAPO (#6263) * update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache * add DAPO support * remove format reward * fix filtering, still buggy * small fix * add DAPO support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tested multi-node training; fix bind_batch bug * fix conversation; support sleep mode * support reusing excessive samples * add dynamic batching control flag * add dynamic batching control flag * refactored * fix logging --------- Co-authored-by: Tong Li Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../coati/distributed/consumer.py | 38 +- .../coati/distributed/grpo_consumer.py | 550 +++++++++--------- .../coati/distributed/inference_backend.py | 2 + .../ColossalChat/coati/distributed/launch.py | 17 +- .../ColossalChat/coati/distributed/loss.py | 49 +- .../coati/distributed/producer.py | 31 +- .../coati/distributed/reward/reward_fn.py | 31 +- .../ColossalChat/coati/distributed/utils.py | 39 ++ .../ColossalChat/coati/trainer/utils.py | 8 +- applications/ColossalChat/rl_example.py | 146 ++++- 10 files changed, 552 insertions(+), 359 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index b7b865b26..28d83fa40 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device from .comm import ray_broadcast_tensor_dict -from .utils import bind_batch, post_recv, unbind_batch +from .utils import bind_batch, pad_batch, post_recv, unbind_batch class BaseConsumer: @@ -33,7 +33,7 @@ class BaseConsumer: batch_size: int, model_config: Dict[str, Any], plugin_config: Dict[str, Any], - microbatch_size: int = 1, + minibatch_size: int = 1, save_interval: int = 100, save_dir: str = "./model", ): @@ -46,11 +46,11 @@ class BaseConsumer: self.num_update_per_episode = num_update_per_episode self.num_recv_per_update = num_recv_per_update self.batch_size = batch_size - self.microbatch_size = microbatch_size + self.minibatch_size = minibatch_size self.save_interval = save_interval self.save_dir = save_dir - assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size" - self.num_microbatches = batch_size // microbatch_size + assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size" + self.num_microbatches = batch_size // minibatch_size self.model_config = model_config self.plugin_config = plugin_config @@ -67,7 +67,7 @@ class BaseConsumer: plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: - plugin_config["microbatch_size"] = self.microbatch_size + plugin_config["microbatch_size"] = self.minibatch_size plugin_config.update(self.plugin_config) self.plugin = HybridParallelPlugin(**plugin_config) self.booster = Booster(plugin=self.plugin) @@ -105,18 +105,26 @@ class BaseConsumer: ) ) ) - while len(self.buffer) >= self.dp_size * self.microbatch_size: + while len(self.buffer) >= self.dp_size * self.minibatch_size: batches = self.buffer[ - self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size + self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size ] - self.buffer = self.buffer[self.dp_size * self.microbatch_size :] + batch = pad_batch( + batches + ) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking batch = bind_batch(batches) batch = post_recv(batch) - loss = self.step(i, **batch) + loss, num_excessive_prompts = self.step(i, pbar, **batch) + self.buffer = ( + self.buffer[ + (self.dp_rank + 1) * self.minibatch_size + - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size + ] + + self.buffer[self.dp_size * self.minibatch_size :] + ) if loss is not None: pbar.set_postfix({"loss": loss}) i += 1 - assert len(self.buffer) == 0 if self.lr_scheduler is not None: self.lr_scheduler.step() if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode: @@ -154,7 +162,9 @@ class SimpleConsumer(BaseConsumer): batch_size, model_config, plugin_config, - microbatch_size=1, + minibatch_size=1, + save_interval: int = 100, + save_dir="./model", ): super().__init__( num_producers, @@ -168,7 +178,7 @@ class SimpleConsumer(BaseConsumer): batch_size, model_config, plugin_config, - microbatch_size, + minibatch_size, ) path = model_config.pop("path") self.model = AutoModelForCausalLM.from_pretrained(path, **model_config) @@ -181,7 +191,7 @@ class SimpleConsumer(BaseConsumer): super().setup() self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) - def step(self, step_idx: int, **kwargs) -> Optional[float]: + def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: labels = kwargs["input_ids"].clone() labels[kwargs["attention_mask"] == 0] = -100 kwargs["labels"] = labels diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index e23254d1b..d4589b0e2 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -1,18 +1,16 @@ -import json -import os +import warnings from contextlib import nullcontext -from typing import Optional +from typing import Any, Optional import ray import torch -import torch.distributed as dist import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.utils import calc_action_log_probs -from coati.trainer.utils import all_reduce_mean +from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -34,13 +32,23 @@ class GRPOConsumer(BaseConsumer): batch_size, model_config, plugin_config, - microbatch_size=1, + minibatch_size=1, num_generations=8, use_wandb=True, generate_config=None, - training_config={}, + grpo_config={}, project_name=None, + save_interval: int = 100, + save_dir="./model", ): + print(f"Using GRPO config: {grpo_config}") + if grpo_config.get("loss_variation", "sample_level") == "token_level": + if batch_size != minibatch_size: + warnings.warn( + f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}", + UserWarning, + ) + minibatch_size = batch_size super().__init__( num_producers, num_episodes, @@ -53,36 +61,58 @@ class GRPOConsumer(BaseConsumer): batch_size, model_config, plugin_config, - microbatch_size, + minibatch_size, + save_dir=save_dir, ) path = model_config.pop("path") self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model.train() self.policy_model.gradient_checkpointing_enable() - self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) + self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) - self.accum_format_reward = torch.zeros(1, device=self.device) - self.accum_acc_reward = torch.zeros(1, device=self.device) + self.accum_format_acc = torch.zeros(1, device=self.device) + self.accum_ans_acc = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.accum_response_length = torch.zeros(1, device=self.device) self.accum_count = 0 self.generate_config = generate_config - self.training_config = training_config + self.grpo_config = grpo_config self.project_name = project_name + self.effective_sample_count = 0 + self.total_sample_count = 0 + + self.policy_loss_fn = PolicyLoss( + clip_eps_low=grpo_config.get("clip_eps_low", 0.2), + clip_eps_high=grpo_config.get("clip_eps_high", 0.2), + beta=grpo_config.get("beta", 0.01), + loss_variation=grpo_config.get("loss_variation", "sample_level"), + ) # Reference model is initialized from policy model. - self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) - self.reference_model.eval() + if self.policy_loss_fn.beta > 0: + self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) + self.reference_model.eval() self.tokenizer = AutoTokenizer.from_pretrained(path) self.pad_token_id = self.tokenizer.pad_token_id self.num_generations = num_generations - self.filter_range = training_config.get("filter_range", None) + self.filter_range = grpo_config.get("filter_range", None) if self.filter_range is not None: assert len(self.filter_range) == 2, "Filter range should have 2 values." + self.filter_truncated_response = grpo_config.get("filter_truncated_response", False) + if self.filter_truncated_response: + self.max_length = 0 + if "max_tokens" in self.generate_config: + self.max_length = self.generate_config["max_tokens"] + elif "max_new_tokens" in self.generate_config: + self.max_length = self.generate_config["max_new_tokens"] + else: + raise ValueError( + "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config." + ) # Initialize verifiable reward. response_format_tags = { "think_start": {"text": "", "num_occur": 1}, @@ -90,11 +120,12 @@ class GRPOConsumer(BaseConsumer): "answer_start": {"text": "", "num_occur": 1}, "answer_end": {"text": "", "num_occur": 1}, } + reward_model_kwargs = { + k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"] + } self.reward_model = VerifiableReward( - reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags + reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs ) - - self.policy_loss_fn = PolicyLoss() self.global_step = 0 self.use_wandb = use_wandb @@ -102,7 +133,7 @@ class GRPOConsumer(BaseConsumer): optimizer=self.optimizer, total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, warmup_steps=0, - eta_min=0.1 * training_config.get("lr", 1e-6), + eta_min=0.1 * grpo_config.get("lr", 1e-6), ) def setup(self): @@ -118,10 +149,11 @@ class GRPOConsumer(BaseConsumer): self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler ) - self.reference_model, *_ = self.booster.boost(self.reference_model) + if self.policy_loss_fn.beta > 0: + self.reference_model, *_ = self.booster.boost(self.reference_model) self.plugin.logger.set_level("ERROR") - def step(self, step_idx: int, **kwargs) -> Optional[float]: + def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]: """ Step data from policy model: [{ @@ -132,18 +164,108 @@ class GRPOConsumer(BaseConsumer): }, ...] Format: - [batch_size, num_of_generation, prompt_length + response_length] --- ............. + [minibatch_size, num_of_generation, prompt_length + response_length] --- ............. """ - # Reshape to [batch_size x num_of_generation, prompt_length + response_length] + # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length] data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} action_mask = data["action_mask"] num_action = action_mask.shape[1] old_action_log_probs = data["action_log_probs"] response_length = torch.sum(action_mask, dim=1).to(torch.float32) - forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) + train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0)) + + reward_group = self.reward_model( + data["input_ids"], + gt_answer=data["gt_answer"], + response_idx=data["response_idx"], + ) + + reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) + format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) + ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) + + # [minibatch_size, num_generations] + + group_reward = reward.view(-1, self.num_generations) + reward_mean = group_reward.mean(dim=1) + # [minibatch_size x num_generations] + reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) + + reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) + # [minibatch_size x num_generations] + advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) + # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), + group_ans_acc = ( + ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0) + ) + loss_mask = ( + torch.ones(action_mask.size(0), device=action_mask.device).bool() + if self.filter_range is None + else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1]) + ) + # filter out overlength samples + if self.filter_truncated_response and action_mask.size(1) == self.max_length: + loss_mask = torch.logical_and( + loss_mask, + action_mask[:, -1] == False, + ) + effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask + effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) + total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) + total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) + self.effective_sample_count += effective_samples.item() + self.total_sample_count += total_samples.item() + + mean_kl, mean_loss = [], [] + + if self.grpo_config.get("dynamic_batching", True): + need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations + # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration. + num_excessive_samples = ( + int( + (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations) + / self.num_generations + / self.dp_size + ) + * self.num_generations + ) + if num_excessive_samples > 0: + data = { + k: ( + v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)] + if k + in [ + "input_ids", + "attention_mask", + "action_log_probs", + "action_mask", + "response_idx", + "gt_answer", + ] + else v + ) + for k, v in data.items() + } + action_mask = action_mask[ + : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0) + ] + loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)] + advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)] + else: + num_excessive_samples = 0 + else: + # If dynamic batching is disabled, we need to use all samples for training. + need_update = (step_idx + 1) % self.num_microbatches == 0 + num_excessive_samples = 0 + + pbar.set_postfix( + { + "Step": self.global_step + 1, + "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}", + } + ) - need_update = (step_idx + 1) % self.num_microbatches == 0 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 ctx = ( nullcontext() @@ -151,95 +273,71 @@ class GRPOConsumer(BaseConsumer): else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: - reward_group = self.reward_model( - data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] - ) - - reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device) - format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device) - acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device) - - # [batch_size, num_generations] - - group_reward = reward.view(-1, self.num_generations) - reward_mean = group_reward.mean(dim=1) - # [batch_size x num_generations] - reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) - reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) - # [batch_size x num_generations] - advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) - # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score), - loss_mask = ( - None - if self.filter_range is None - else torch.logical_and( - reward_mean > self.filter_range[0], reward_mean < self.filter_range[1] - ).repeat_interleave(self.num_generations, dim=0) - ) - mean_kl, mean_loss = [], [] - - for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size): + for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): input_ids_forward_micro_batch = data["input_ids"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] attention_mask_forward_micro_batch = data["attention_mask"][ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] action_mask_forward_micro_batch = action_mask[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] loss_mask_forward_micro_batch = ( - loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] + loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size] if loss_mask is not None else None ) advantages_forward_micro_batch = advantages[ - forward_micro_batch_start : forward_micro_batch_start + forward_batch_size + forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size ] if self.plugin.pp_size > 1: # Support training with PP. + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_outputs = self.booster.execute_pipeline( + iter( + [ + { + "input_ids": input_ids_forward_micro_batch, + "attention_mask": attention_mask_forward_micro_batch, + } + ] + ), + self.reference_model, + criterion=lambda outputs, inputs: torch.tensor( + [0.0], device=action_mask.device + ), # dummy criterion + optimizer=None, + return_loss=False, + return_outputs=True, + ) - with torch.no_grad(): - reference_model_outputs = self.booster.execute_pipeline( - iter( - [ - { - "input_ids": input_ids_forward_micro_batch, - "attention_mask": attention_mask_forward_micro_batch, - } - ] - ), - self.reference_model, - criterion=lambda outputs, inputs: torch.tensor( - [0.0], device=action_mask.device - ), # dummy criterion - optimizer=None, - return_loss=False, - return_outputs=True, - ) - - if self.booster.plugin.stage_manager.is_last_stage(): - reference_model_logits = reference_model_outputs["outputs"]["logits"] - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) + if self.booster.plugin.stage_manager.is_last_stage(): + reference_model_logits = reference_model_outputs["outputs"]["logits"] + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + else: + # Dummy reference logprobs for data iterator. + reference_action_log_probs = None else: - # Dummy reference logprobs for data iterator. reference_action_log_probs = None data_policy_forward = { "input_ids": input_ids_forward_micro_batch, "attention_mask": attention_mask_forward_micro_batch, "action_mask": action_mask_forward_micro_batch, - "reference_action_log_probs": reference_action_log_probs, "advantages": advantages_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch, "source": self.rank, } + if reference_action_log_probs is not None: + data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] @@ -251,24 +349,30 @@ class GRPOConsumer(BaseConsumer): num_action, self.plugin.shard_config, ) - per_token_kl = ( - torch.exp(inputs["reference_action_log_probs"] - action_log_probs) - - (inputs["reference_action_log_probs"] - action_log_probs) - - 1 - ) - appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( - inputs["action_mask"], dim=-1 - ) - kl.append(appox_kl.mean()) - loss, skip_update, _ = self.policy_loss_fn( + if "reference_action_log_probs" in inputs: + per_token_kl = ( + torch.exp(inputs["reference_action_log_probs"] - action_log_probs) + - (inputs["reference_action_log_probs"] - action_log_probs) + - 1 + ) + appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( + inputs["action_mask"], dim=-1 + ) + kl.append(appox_kl.mean()) + else: + per_token_kl = 0.0 + kl.append(0.0) + + loss, _ = self.policy_loss_fn( action_log_probs, action_log_probs, inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, inputs["action_mask"], loss_mask=inputs["loss_mask"], + total_effective_tokens_in_batch=total_effective_tokens_count, ) - return loss + return loss, num_excessive_samples // self.num_generations policy_model_outputs = self.booster.execute_pipeline( iter([data_policy_forward]), @@ -298,61 +402,71 @@ class GRPOConsumer(BaseConsumer): self.plugin.shard_config, ) - with torch.no_grad(): - reference_model_logits = self.reference_model( - input_ids=input_ids_forward_micro_batch, - attention_mask=attention_mask_forward_micro_batch, - ).logits - reference_action_log_probs = calc_action_log_probs( - reference_model_logits / self.generate_config["temperature"], - input_ids_forward_micro_batch, - num_action, - self.plugin.shard_config, - ) - per_token_kl = ( - torch.exp(reference_action_log_probs - action_log_probs) - - (reference_action_log_probs - action_log_probs) - - 1 - ) - kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( - action_mask_forward_micro_batch, dim=-1 - ) + if self.policy_loss_fn.beta > 0: + with torch.no_grad(): + reference_model_logits = self.reference_model( + input_ids=input_ids_forward_micro_batch, + attention_mask=attention_mask_forward_micro_batch, + ).logits + reference_action_log_probs = calc_action_log_probs( + reference_model_logits / self.generate_config["temperature"], + input_ids_forward_micro_batch, + num_action, + self.plugin.shard_config, + ) + per_token_kl = ( + torch.exp(reference_action_log_probs - action_log_probs) + - (reference_action_log_probs - action_log_probs) + - 1 + ) + kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( + action_mask_forward_micro_batch, dim=-1 + ) + else: + per_token_kl = 0.0 + kl = None - loss, skip_update, _ = self.policy_loss_fn( + loss, _ = self.policy_loss_fn( action_log_probs, old_action_log_probs, advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), per_token_kl, action_mask_forward_micro_batch, loss_mask=loss_mask_forward_micro_batch, + total_effective_tokens_in_batch=total_effective_tokens_count, ) - if not skip_update: - self.booster.backward(loss, self.optimizer) + self.booster.backward(loss, self.optimizer) loss = all_reduce_mean(loss, self.plugin) - kl = all_reduce_mean(kl.mean(), self.plugin) # Calculate accumulate value. - mean_kl.append(kl.data) + if kl is not None: + kl = all_reduce_mean(kl.mean(), self.plugin) + mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) - format_reward = all_reduce_mean(format_reward.mean(), self.plugin) - acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) + format_acc = all_reduce_mean(format_acc.mean(), self.plugin) + ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin) self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) - self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) + if self.policy_loss_fn.beta > 0: + self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) self.accum_reward.add_(reward.data) - self.accum_format_reward.add_(format_reward.data) - self.accum_acc_reward.add_(acc_reward.data) + self.accum_format_acc.add_(format_acc.data) + self.accum_ans_acc.add_(ans_acc.data) self.accum_advantages.add_(advantages.data) self.accum_response_length.add_(response_length.data) self.accum_count += 1 if need_update: self.optimizer.step() self.optimizer.zero_grad() + self.global_step += 1 + sample_utilization = self.effective_sample_count / self.total_sample_count + self.effective_sample_count = 0 + self.total_sample_count = 0 if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): @@ -360,167 +474,41 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): - print( - "Loss:", - self.accum_loss.item() / self.accum_count, - "\nReward:", - self.accum_reward.item() / self.accum_count, - "\nFormat Reward:", - self.accum_format_reward.item() / self.accum_count, - "\nAcc Reward:", - self.accum_acc_reward.item() / self.accum_count, - "\nKL:", - self.accum_kl.item() / self.accum_count, - "\nAdvantages:", - self.accum_advantages.item() / self.accum_count, - "\nResponse Length:", - self.accum_response_length.item() / self.accum_count, - ) - self.wandb_run.log( - { - "metrics/reward": self.accum_reward.item() / self.accum_count, - "metrics/format_reward": self.accum_format_reward.item() / self.accum_count, - "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, - "metrics/response_length": self.accum_response_length.item() / self.accum_count, - "train/loss": self.accum_loss.item() / self.accum_count, - "train/kl": self.accum_kl.item() / self.accum_count, - "train/advantages": self.accum_advantages.item() / self.accum_count, - "train/learning_rate": self.lr_scheduler.get_last_lr()[0], - "rollout/temperature": data["temperature"].cpu().numpy()[0][0], - } - ) + to_log_msg = [ + f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", + f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", + f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", + f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", + f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", + ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) + print("\n".join(to_log_msg)) + metrics = { + "metrics/reward": self.accum_reward.item() / self.accum_count, + "metrics/format_acc": self.accum_format_acc.item() / self.accum_count, + "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count, + "metrics/response_length": self.accum_response_length.item() / self.accum_count, + "train/loss": self.accum_loss.item() / self.accum_count, + "train/advantages": self.accum_advantages.item() / self.accum_count, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0], + "train/sample_utilization": sample_utilization, + "rollout/temperature": data["temperature"].cpu().numpy()[0][0], + } + if self.policy_loss_fn.beta > 0: + metrics["train/kl"] = self.accum_kl.item() / self.accum_count + + self.wandb_run.log(metrics) self.accum_loss.zero_() self.accum_reward.zero_() - self.accum_acc_reward.zero_() - self.accum_format_reward.zero_() + self.accum_ans_acc.zero_() + self.accum_format_acc.zero_() self.accum_kl.zero_() self.accum_advantages.zero_() self.accum_response_length.zero_() - self.accum_count = 0 - return loss_scalar - - def state_dict(self): - self.policy_model._force_wait_all_gather() - model = self.policy_model.unwrap() - state_dict = model.state_dict() - return state_dict - - -@ray.remote -class GRPOEvalConsumer(BaseConsumer): - def __init__( - self, - num_producers, - num_episodes, - rank, - world_size, - master_addr, - master_port, - num_update_per_episode, - num_recv_per_update, - batch_size, - model_config, - plugin_config, - microbatch_size=1, - num_generations=4, - use_wandb=True, - log_dir="./results", - ): - super().__init__( - num_producers, - num_episodes, - rank, - world_size, - master_addr, - master_port, - num_update_per_episode, - num_recv_per_update, - batch_size, - model_config, - plugin_config, - microbatch_size, - ) - path = model_config.pop("path") - self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) - self.policy_model.train() - self.accum_reward = torch.zeros(1, device=self.device) - self.accum_format_reward = torch.zeros(1, device=self.device) - self.accum_acc_reward = torch.zeros(1, device=self.device) - self.accum_response_length = torch.zeros(1, device=self.device) - self.accum_count = torch.zeros(1, device=self.device) - - self.tokenizer = AutoTokenizer.from_pretrained(path) - self.pad_token_id = self.tokenizer.pad_token_id - self.num_generations = num_generations - - # Initialize verifiable reward. - response_format_tags = { - "think_start": {"text": "", "num_occur": 1}, - "think_end": {"text": "", "num_occur": 1}, - "answer_start": {"text": "", "num_occur": 1}, - "answer_end": {"text": "", "num_occur": 1}, - } - self.reward_model = VerifiableReward( - reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags - ) - - self.log_dir = log_dir - if not os.path.exists(self.log_dir): - os.makedirs(self.log_dir) + return loss_scalar, num_excessive_samples // self.num_generations else: - os.system(f"rm -rf {self.log_dir}/*") - - def setup(self): - super().setup() - self.policy_model, _, *_ = self.booster.boost(self.policy_model) - - def step(self, step_idx: int, **kwargs) -> Optional[float]: - rank = dist.get_rank() - data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()} - kwargs["input_ids"].size(0) - reward_group = self.reward_model( - data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"] - ) - reward = [value[0].item() for value in reward_group] - format_reward = [value[1].item() for value in reward_group] - acc_reward = [value[2].item() for value in reward_group] - response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))] - - response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True) - with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f: - for i in range(len(response)): - f.write( - json.dumps( - { - "response": response[i], - "reward": reward[i], - "format_reward": format_reward[i], - "acc_reward": acc_reward[i], - "response_length": response_length[i], - }, - ensure_ascii=False, - ) - + "\n" - ) - - self.accum_reward += sum(reward) - self.accum_format_reward += sum(format_reward) - self.accum_acc_reward += sum(acc_reward) - self.accum_response_length += sum(response_length) - self.accum_count += len(reward) - - # print results - total_count = all_reduce_mean(self.accum_count, self.plugin) - mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count - mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count - mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count - mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count - if rank == 0: - print( - f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}" - ) - return None + return None, num_excessive_samples // self.num_generations def state_dict(self): self.policy_model._force_wait_all_gather() diff --git a/applications/ColossalChat/coati/distributed/inference_backend.py b/applications/ColossalChat/coati/distributed/inference_backend.py index 17c71c8a8..7d32cd52a 100644 --- a/applications/ColossalChat/coati/distributed/inference_backend.py +++ b/applications/ColossalChat/coati/distributed/inference_backend.py @@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend): class VLLMInferenceBackend(BaseInferenceBackend): DEFAULT_MODEL_CONFIG = dict( trust_remote_code=True, + enable_sleep_mode=False, ) FORCE_GENERATE_CONFIG = dict( logprobs=0, @@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend): generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update({"n": num_generations}) self.generate_config = SamplingParams(**generate_config) + self.model_config = model_config self.tokenizer = tokenizer self.num_generations = num_generations diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index 699d90a8c..b193dc8bc 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -4,10 +4,10 @@ from typing import Any, Dict, Optional import ray from .consumer import SimpleConsumer -from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer +from .grpo_consumer import GRPOConsumer from .producer import SimpleProducer -ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer} +ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer} def get_jsonl_size_fast(path: str) -> int: @@ -40,6 +40,7 @@ def launch_distributed( inference_model_config: Dict[str, Any], generate_config: Dict[str, Any], train_model_config: Dict[str, Any], + grpo_config: Dict[str, Any], plugin_config: Dict[str, Any], tokenizer_config: Optional[Dict[str, Any]] = None, inference_backend: str = "transformers", @@ -48,6 +49,8 @@ def launch_distributed( master_port: int = 29500, core_algo: str = "GRPO", project_name: Optional[str] = None, + save_interval: int = 100, + save_dir: str = "./model", ): if core_algo not in ALGO_MAP: @@ -101,15 +104,13 @@ def launch_distributed( batch_size=train_batch_size, model_config=train_model_config, plugin_config=plugin_config, - microbatch_size=train_minibatch_size, + minibatch_size=train_minibatch_size, generate_config=generate_config_consumer, - training_config={ - "filter_range": [0.05, 9.0], - "lr": 1e-6, - "train_microbatch_size": train_microbatch_size, - }, + grpo_config=grpo_config, num_generations=num_generations, project_name=project_name, + save_interval=save_interval, + save_dir=save_dir, ) procs.append(consumer) ray.get([p.setup.remote() for p in procs]) diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py index 90ad09736..36057b24f 100644 --- a/applications/ColossalChat/coati/distributed/loss.py +++ b/applications/ColossalChat/coati/distributed/loss.py @@ -2,7 +2,7 @@ from typing import Optional import torch import torch.nn as nn -from coati.distributed.utils import masked_mean +from coati.distributed.utils import masked_mean, masked_sum class PolicyLoss(nn.Module): @@ -10,11 +10,19 @@ class PolicyLoss(nn.Module): Policy Loss for PPO """ - def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: + def __init__( + self, + clip_eps_low: float = 0.2, + clip_eps_high: float = 0.2, + beta: float = 0.01, + loss_variation: str = "sample_level", + ) -> None: super().__init__() - self.clip_eps = clip_eps - self.skip_threshold = skip_threshold + self.clip_eps_low = clip_eps_low + self.clip_eps_high = clip_eps_high self.beta = beta + self.loss_variation = loss_variation + assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}" def forward( self, @@ -24,22 +32,37 @@ class PolicyLoss(nn.Module): per_token_kl: torch.Tensor, action_mask: Optional[torch.Tensor] = None, loss_mask: Optional[torch.Tensor] = None, + total_effective_tokens_in_batch: torch.Tensor = None, ) -> torch.Tensor: - skip = False if action_mask is None: ratio = (log_probs - log_probs.detach()).exp() else: ratio = ((log_probs - log_probs.detach()) * action_mask).exp() surr1 = ratio * advantages - surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages + surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages + if self.beta == 0: + # skip kl term if kl coefficient is zero + per_token_kl = 0.0 loss = -torch.min(surr1, surr2) + self.beta * per_token_kl - if action_mask is not None: - loss = masked_mean(loss, action_mask) + if self.loss_variation == "sample_level": + if action_mask is not None: + loss = masked_mean(loss, action_mask) + else: + loss = loss.mean(dim=1) + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.mean() + elif self.loss_variation == "token_level": + if action_mask is not None: + loss = masked_sum(loss, action_mask) + else: + loss = loss.sum(dim=1) + if loss_mask is not None: + loss = loss * loss_mask + loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8) else: - loss = loss.mean(dim=1) - if loss_mask is not None: - loss = loss * loss_mask - loss = loss.mean() - return loss, skip, ratio.max() + raise ValueError(f"Unsupported loss variation: {self.loss_variation}") + + return loss, ratio.max() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 2c6a24a36..b0624f72a 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -103,7 +103,14 @@ class BaseProducer: print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( - [self.model.generate_config.temperature] * outputs["input_ids"].size(0) + [ + ( + self.model.generate_config["temperature"] + if isinstance(self.model.generate_config.temperature, dict) + else self.model.generate_config.temperature + ) + ] + * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) outputs = pre_send(outputs) ray_broadcast_tensor_dict( @@ -113,10 +120,15 @@ class BaseProducer: if (i + 1) % self.num_microbatches == 0 and ( episode != self.num_episodes - 1 or i != num_valid_microbatches - 1 ): + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.sleep() # revict KV_cache to avoid OOM # don't sync model for last iteration print( f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) + torch.cuda.empty_cache() state_dict = ray_broadcast_tensor_dict( None, self.num_producers, device=self.device, group_name="sync_model" @@ -124,12 +136,21 @@ class BaseProducer: self.load_state_dict(state_dict) del state_dict torch.cuda.empty_cache() - # linear annealing for 1 episode, temperature from initial to 0.7 + if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get( + "enable_sleep_mode", False + ): + self.model.llm.wake_up() + # linear annealing for 1 episode, temperature from initial to 0.9 if episode <= 0: ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) - self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ - "temperature" - ] + ratio * 0.7 + if isinstance(self.model.generate_config.temperature, dict): + self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 + else: + self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ + "temperature" + ] + ratio * 0.9 @ray.remote diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 53bc15e25..fd9129130 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -4,13 +4,22 @@ from .reward_utils import extract_solution, validate_response_structure def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): - format_score = 1.0 - acc_score = 9.0 tokenizer = kwargs["tokenizer"] + soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False) + acc_score = 10.0 reward = torch.tensor(0.0) - format_reward = torch.tensor(0.0) - acc_reward = torch.tensor(0.0) + format_acc = torch.tensor(0.0) + ans_acc = torch.tensor(0.0) s, e = response_idx[0], response_idx[1] + + length_reward = 0.0 + if soft_over_length_punishment: + max_length = kwargs.get("max_length", 1024 * 4) + cache_length = kwargs.get("cache_length", 512) + res_length = e.item() - s.item() + 1 + if max_length - cache_length < res_length < max_length: + length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score + if gt_answer is None: return reward @@ -22,18 +31,20 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): # Check format accuracy if format_valid: - format_reward += format_score - reward += format_score + format_acc += 1 - # Check answer accuracy + # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid if ( - final_answer is not None + format_valid + and final_answer is not None and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() ): - acc_reward += acc_score + ans_acc += 1 reward += acc_score - return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) + reward = reward + length_reward + + return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device) def gsm8k_reward_fn(input_ids, **kwargs): diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py index 919e4434f..ce4923dc4 100644 --- a/applications/ColossalChat/coati/distributed/utils.py +++ b/applications/ColossalChat/coati/distributed/utils.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import Any, Dict, List import torch @@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor return batch +def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]: + max_len = defaultdict(int) + for sample in batches: + for k in sample: + if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]: + max_len[k] = max(max_len[k], sample[k].size(-1)) + for idx, sample in enumerate(batches): + for k in sample: + if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]: + # right pad with 0s + if k in ["attention_mask", "action_mask"]: + batches[idx][k] = torch.nn.functional.pad( + batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False + ) + else: + batches[idx][k] = torch.nn.functional.pad( + batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0 + ) + return batches + + def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # compress mask to save bandwidth if "attention_mask" in batch: @@ -113,3 +135,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch mask_sum = mask.sum(dim=dim) mean = tensor / (mask_sum + 1e-8) return mean + + +def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: + """ + Compute the masked sum of a tensor along a specified dimension. + + Args: + tensor (torch.Tensor): The input tensor. + mask (torch.Tensor): The mask tensor with the same shape as the input tensor. + dim (int, optional): The dimension along which to compute the sum. Default is 1. + + Returns: + torch.Tensor: The masked sum tensor. + + """ + tensor = tensor * mask + return tensor.sum(dim=dim) diff --git a/applications/ColossalChat/coati/trainer/utils.py b/applications/ColossalChat/coati/trainer/utils.py index 22a5f492e..5153ce3ad 100755 --- a/applications/ColossalChat/coati/trainer/utils.py +++ b/applications/ColossalChat/coati/trainer/utils.py @@ -128,7 +128,7 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor return tensor -def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: +def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor: """ Performs an all-reduce operation to sum the values of the given tensor across all processes. @@ -138,5 +138,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: The reduced tensor with the sum of values across all processes. """ - dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + # All reduce sum across DP group + if plugin is not None: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group) + else: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) return tensor diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index f42a660b7..9c0ec7922 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -1,4 +1,5 @@ import argparse +import os import ray import torch @@ -8,15 +9,17 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") + parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + + # Distributed training parameters parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") - parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") parser.add_argument( "-ibs", "--inference-batch-size", type=int, - default=64, + default=None, help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", ) parser.add_argument( @@ -37,7 +40,7 @@ if __name__ == "__main__": "-tMbs", "--train-minibatch-size", type=int, - default=1, + default=None, help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", ) parser.add_argument( @@ -47,22 +50,78 @@ if __name__ == "__main__": default=2, help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.", ) + parser.add_argument( + "--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional" + ) + parser.add_argument( + "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" + ) + parser.add_argument( + "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional" + ) + + # Sampling parameters parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) - parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) + parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.") + parser.add_argument( + "-topk", + "--top-k", + type=int, + default=None, + help="Top k for sampling. Please check the generation arguments documentation for your backend.", + ) + parser.add_argument( + "-topp", + "--top-p", + type=float, + default=1.0, + help="Top p for sampling. Please check the generation arguments documentation for your backend.", + ) parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.") + parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.") + parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.") + + # GRPO parameters + parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"]) + parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.") + parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.") + + # Logging/Checkpointing parameters + parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.") + parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.") + args = parser.parse_args() - assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" + if args.train_minibatch_size is None: + # Default settings: Using train batch size as mini batch size + args.train_minibatch_size = args.train_batch_size + if args.inference_batch_size is None: + # Default settings: Using train batch size as inference batch size, sync every inference model every train step + args.inference_batch_size = args.train_batch_size assert ( args.train_minibatch_size * args.num_generations >= args.train_microbatch_size and args.train_microbatch_size > 0 ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" + assert ( + args.train_minibatch_size <= args.train_batch_size + ), "Train mini batch size must be less than or equals to train batch size" - ray.init(address="local", namespace="ray-example") + if args.master_address is None: + # Default settings: Using single machine + ray.init(address="local", namespace="ray-example") + else: + # For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node + ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir) + + if args.top_k is None: + if args.backend == "transformers": + args.top_k = 50 + elif args.backend == "vllm": + args.top_k = -1 inference_model_config = dict(path=args.model) train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) - generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) + generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature) if args.backend == "transformers": inference_model_config.update( @@ -73,7 +132,7 @@ if __name__ == "__main__": ) generate_config.update( dict( - max_length=1024 + 512, + max_length=args.max_new_tokens + args.max_prompt_tokens, do_sample=True, max_new_tokens=None, early_stopping=False, @@ -81,31 +140,57 @@ if __name__ == "__main__": ) ) elif args.backend == "vllm": - inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) + inference_model_config.update( + dict( + gpu_memory_utilization=0.7, + enforce_eager=True, + enable_chunked_prefill=True, + max_model_len=args.max_new_tokens + args.max_prompt_tokens, + tensor_parallel_size=1, + ) + ) generate_config.update( dict( - max_tokens=2048, + max_tokens=args.max_new_tokens, # max new tokens ignore_eos=True, include_stop_str_in_output=True, stop=[""], ) ) else: - inference_model_config.update( - dict( - mem_fraction_static=0.6, - ) - ) - generate_config.update( - dict( - max_new_tokens=256, - ignore_eos=True, - ) - ) + raise ValueError(f"Unsupported backend: {args.backend}") + + if args.algo == "GRPO": + # Default Settings + grpo_config = { + "lr": args.learning_rate, + "train_microbatch_size": args.train_microbatch_size, + "beta": args.kl_coeff, # KL penalty coefficient + "loss_variation": "sample_level", + } + elif args.algo == "DAPO": + # DAPO variant settings + grpo_config = { + "filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch + "lr": args.learning_rate, + "train_microbatch_size": args.train_microbatch_size, + "dynamic_batching": True, + "clip_eps_low": 0.2, + "clip_eps_high": 0.28, + "skip_threshold": 20.0, + "beta": 0, # no KL penalty for DAPO + "loss_variation": "token_level", + "soft_over_length_punishment": True, + "max_length": args.max_new_tokens + args.max_prompt_tokens, + "cache_length": min(1024, int(args.max_new_tokens / 4)), + "filter_truncated_response": True, + } + else: + raise ValueError(f"Unsupported algorithm: {args.algo}") launch_distributed( num_producers=args.num_inferencer, - num_proc_per_producer=1, + num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1), num_consumer_procs=args.num_trainers, num_episodes=1, inference_batch_size=args.inference_batch_size, @@ -113,15 +198,22 @@ if __name__ == "__main__": train_batch_size=args.train_batch_size, train_minibatch_size=args.train_minibatch_size, train_microbatch_size=args.train_microbatch_size, - dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt}, + dataset_config={ + "path": args.dataset, + "max_length": args.max_prompt_tokens, + "system_prompt": args.system_prompt, + }, dataloaders_config={}, inference_model_config=inference_model_config, generate_config=generate_config, num_generations=args.num_generations, train_model_config=train_model_config, - plugin_config={}, # Default setting: zero. + grpo_config=grpo_config, + plugin_config={ + "zero_stage": 2, + }, # for zero + # currently not support tp/pp # plugin_config={ - # "pp_size": 2, # "tp_size": 2, # "microbatch_size": args.train_microbatch_size // 2, # "zero_stage": 0, @@ -129,7 +221,9 @@ if __name__ == "__main__": # }, # for pp inference_backend=args.backend, master_addr="localhost", - master_port=29506, + master_port=args.master_port, core_algo=args.algo, project_name=args.project, + save_interval=args.save_interval, + save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")), )