diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 816cab50a..22de6911e 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -117,26 +117,102 @@ class BaseConsumer:
             # receive data from producers
             for r in range(self.num_producers):
                 print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                self.buffer.extend(
-                    unbind_batch(
-                        ray_broadcast_tensor_dict(
-                            None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                        )
-                    )
+                raw_batch = ray_broadcast_tensor_dict(
+                    None, src=0, device=self.device, group_name=f"sync_data_{r}"
                 )
-            while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                batches = self.buffer[
-                    self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
-                ]
-                batch = bind_batch(batches)
-                batch = post_recv(batch)
-                loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+                # calculate the group reward and related metrics before filtering; only the filtered groups
+                # will be used for training (an incomplete set), so the logging metrics must be computed here
+                # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                raw_batch_with_reward = self.calculate_reward(
+                    {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                )
+                raw_batch_with_reward = {
+                    k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                    for k, v in raw_batch_with_reward.items()
+                }
+                # [batch_size, num_generations] -> [batch_size]
+                reward = raw_batch_with_reward["reward"][:, :, 0]
+                format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                response_len = (
+                    raw_batch_with_reward["response_idx"][:, :, 1]
+                    - raw_batch_with_reward["response_idx"][:, :, 0]
+                    + 1
+                ).type(torch.float32)
+                effective_group_mask = None
+                if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                    # filter the groups based on their mean answer accuracy
+                    group_ans_acc_mean = ans_acc.mean(dim=1)
+                    effective_group_mask = torch.logical_and(
+                        group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                    )
+                raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
+                for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                    self.buffer.append(
+                        [
+                            (
+                                group_with_reward
+                                if effective_group_mask is None or effective_group_mask[group_idx]
+                                else None
+                            ),
+                            reward[group_idx],
+                            format_acc[group_idx],
+                            ans_acc[group_idx],
+                            response_len[group_idx],
+                        ]
+                    )
+                if effective_group_mask is not None:
+                    print(
+                        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                    )
+                # map each effective group to its raw group index in the buffer
+                effective_group_to_raw_group_mapping = {}
+                for buffer_idx in range(len(self.buffer)):
+                    if self.buffer[buffer_idx][0] is not None:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                            buffer_idx
+                        )
+                print(
+                    f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+                )

-                if excessive_prompts_idx is not None:
-                    excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                    self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                else:
-                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                # on each dp_rank, we use minibatch_size effective samples to form a batch
+                batches = [
+                    self.buffer[effective_group_to_raw_group_mapping[i]]
+                    for i in range(
+                        self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                    )
+                ]
+                # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                # each mini-batch uses the first self.dp_size * minibatch_size effective samples
+                raw_mini_batches = self.buffer[
+                    : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                ]  # include the last effective sample
+                raw_mini_batches_metric_dict = {
+                    "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                    "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                    "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                    "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                }
+                batch = bind_batch([t[0] for t in batches])
+                batch = post_recv(batch)
+                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                self.buffer = self.buffer[
+                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                ]
+                # recalculate the effective group to raw group mapping
+                effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                effective_group_to_raw_group_mapping = {}
+                for buffer_idx in range(len(self.buffer)):
+                    if self.buffer[buffer_idx][0] is not None:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                            buffer_idx
+                        )
+                assert (
+                    len(effective_group_to_raw_group_mapping)
+                    == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                )
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                 i += 1
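The consumer now stores every received group in self.buffer as a five-element list: the group tensors (or None if the group was filtered out), followed by its reward, format accuracy, answer accuracy, and response length. Training only consumes entries whose first slot is not None, located through the effective-to-raw mapping. A minimal standalone sketch of that bookkeeping, with toy values and a hypothetical helper name that are illustrative rather than part of the patch:

    # Illustrative sketch only: toy stand-ins for the [group, reward, format_acc, ans_acc, response_len]
    # entries appended in the consumer loop; filtered-out groups keep their metrics but store None in slot 0.
    def build_effective_mapping(buffer):
        # mirrors the mapping loop in the patch: dense effective index -> raw buffer index
        mapping = {}
        for buffer_idx in range(len(buffer)):
            if buffer[buffer_idx][0] is not None:
                mapping[len(mapping)] = buffer_idx
        return mapping

    buffer = [
        ["group_a", 0.75, 1.0, 0.50, 120.0],  # effective group
        [None, 0.0, 0.0, 0.0, 380.0],         # filtered out, kept only for logging metrics
        ["group_b", 0.25, 1.0, 0.25, 96.0],   # effective group
    ]
    mapping = build_effective_mapping(buffer)
    assert mapping == {0: 0, 1: 2}
    # Training reads buffer[: mapping[last_needed] + 1] for the logging metrics and the mapped entries
    # for gradients, then drops everything up to and including the last consumed effective group.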
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 1666ac582..89fae6fc7 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,5 +1,5 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -72,21 +72,18 @@ class GRPOConsumer(BaseConsumer):
             self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
-        self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.raw_train_batch_reward = []
+        self.raw_train_batch_format_acc = []
+        self.raw_train_batch_ans_acc = []
+        self.raw_train_batch_response_len = []
         self.accum_count = 0
         self.generate_config = generate_config
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -122,16 +119,7 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = (
-            {
-                "think_start": {"text": "<think>", "num_occur": 1},
-                "think_end": {"text": "</think>", "num_occur": 1},
-                "answer_start": {"text": "<answer>", "num_occur": 1},
-                "answer_end": {"text": "</answer>", "num_occur": 1},
-            }
-            if grpo_config.get("reward_fn_type") == "think_answer_tags"
-            else None
-        )
+        response_format_tags = grpo_config.get("response_format_tags", None)
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -187,24 +175,21 @@ class GRPOConsumer(BaseConsumer):
             Format:
                 [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
-        # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
+        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
+        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
+        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
+        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
         train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        reward_group = self.reward_model(
-            data["input_ids"],
-            gt_answer=data["gt_answer"],
-            response_idx=data["response_idx"],
-        )
-
-        reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))

         # [minibatch_size, num_generations]
@@ -216,67 +201,35 @@ class GRPOConsumer(BaseConsumer):
             reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
             # [minibatch_size x num_generations]
             advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-            # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-            group_ans_acc = (
-                ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-            )
+
             # [minibatch_size x num_of_generation]
-            loss_mask = (
-                torch.ones(action_mask.size(0), device=action_mask.device).bool()
-                if self.filter_range is None
-                else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-            )
+            loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
             # filter out overlength samples
             if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-                old_loss_mask = loss_mask.clone()
                 loss_mask = torch.logical_and(
                     loss_mask,
                     action_mask[:, -1] == False,
                 )
-
-                self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
-                self.overlength_samples = all_reduce_sum(
-                    torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+            if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
+                # filter out samples whose group answer accuracy falls outside the range
+                # (if dynamic batching is enabled, out-of-range groups are already filtered before training)
+                group_ans_acc_mean = (
+                    ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
                 )
-                self.total_overlength_samples += self.overlength_samples.item()
-
-            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-            # [minibatch_size] -> calculate the number of effective prompts
-            effective_prompts_mask = prompt_level_mask.any(dim=1)
-            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-            self.effective_prompt_count += effective_prompts.item()
-            excessive_prompts_idx = None
+                loss_mask = torch.logical_and(
+                    loss_mask,
+                    torch.logical_and(
+                        group_ans_acc_mean > self.filter_range[0],
+                        group_ans_acc_mean < self.filter_range[1],
+                    ),
+                )
+            self.effective_prompt_count += group_reward.size(0) * self.dp_size

             mean_kl, mean_loss = [], []

             if self.grpo_config.get("dynamic_batching", True):
                 need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-                excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-                if excessive_prompts > 0:
-                    excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                    # Only count excessive prompts if they are greater than 1 per rank.
-                    # TODO: customize excessive prompts calculation.
-                    if excessive_prompts_per_rank != 0:
-                        # Mask excessive prompts to False
-                        true_indices = torch.nonzero(effective_prompts_mask)
-                        # Make sure the indices are not empty.
-                        if true_indices.numel() > 0:
-                            true_indices = true_indices.squeeze(-1)
-                            if excessive_prompts_per_rank <= len(true_indices):
-                                excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                            else:
-                                excessive_prompts_idx = true_indices
-                            effective_prompts_mask[excessive_prompts_idx] = False
-
-                            for mask_idx in range(len(effective_prompts_mask)):
-                                if effective_prompts_mask[mask_idx] == False:
-                                    # Update loss mask.
-                                    loss_mask[mask_idx] = False
-                else:
-                    excessive_prompts_idx = torch.empty([0])
             else:
                 # If dynamic batching is disabled, we need to use all samples for training.
                 need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -286,12 +239,10 @@ class GRPOConsumer(BaseConsumer):
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()

         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
-                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
             }
         )
@@ -483,22 +434,16 @@ class GRPOConsumer(BaseConsumer):
         self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
         if self.policy_loss_fn.beta > 0:
             self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-        self.accum_reward.add_(reward.data)
-        self.accum_format_acc.add_(format_acc.data)
-        self.accum_ans_acc.add_(ans_acc.data)
         self.accum_advantages.add_(advantages.data)
-        self.accum_response_length.add_(response_length.data)
         self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all_reduce, as raw_train_batch_* are not split across dp ranks
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
            self.effective_prompt_count = 0
            self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
            loss_scalar = self.accum_loss.item()
            if not self.plugin.pp_size > 1 or (
                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@@ -506,27 +451,39 @@ class GRPOConsumer(BaseConsumer):
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
+                raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
+                raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
+                raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
+                raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
+                raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
+                overlength_samples_ratio = (
+                    (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
+                )  # not an exact figure, but a close estimate
+                self.raw_train_batch_reward = []
+                self.raw_train_batch_format_acc = []
+                self.raw_train_batch_ans_acc = []
+                self.raw_train_batch_response_len = []
                 to_log_msg = [
                     f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
-                    f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                    f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
-                    f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
+                    f"Reward: {raw_batch_reward_mean:.4f}",
+                    f"format Reward: {raw_batch_format_acc_mean:.4f}",
+                    f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                    f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
-                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
+                    f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
-                    "metrics/reward": self.accum_reward.item() / self.accum_count,
-                    "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                    "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                    "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                    "metrics/reward": raw_batch_reward_mean,
+                    "metrics/format_acc": raw_batch_format_acc_mean,
+                    "metrics/ans_acc": raw_batch_ans_acc_mean,
+                    "metrics/response_length": raw_batch_response_len_mean,
                     "train/loss": self.accum_loss.item() / self.accum_count,
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
-                    "train/percentage_overlength_samples": overlength_samples_percentage,
+                    "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
@@ -534,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
                 if self.wandb_run is not None:
                     self.wandb_run.log(metrics)
             self.accum_loss.zero_()
-            self.accum_reward.zero_()
-            self.accum_ans_acc.zero_()
-            self.accum_format_acc.zero_()
             self.accum_kl.zero_()
             self.accum_advantages.zero_()
-            self.accum_response_length.zero_()
             self.accum_count = 0
-
-            if excessive_prompts_idx is not None:
-                # All gather excessive prompts index across DP ranks.
-                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-            return loss_scalar, excessive_prompts_idx
+            return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None
+
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                containing the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def state_dict(self):
         self.policy_model._force_wait_all_gather()
+ """ + reward_model_output = self.reward_model( + rollout["input_ids"], + gt_answer=rollout["gt_answer"], + response_idx=rollout["response_idx"], + ) + # [num_of_generation] + reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device) + format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device) + ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device) + + rollout["reward"] = reward.view((-1, 1)) + rollout["format_acc"] = format_acc.view((-1, 1)) + rollout["ans_acc"] = ans_acc.view((-1, 1)) + return rollout def state_dict(self): self.policy_model._force_wait_all_gather() diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py index c9e8d2ab2..764325cdd 100644 --- a/applications/ColossalChat/coati/distributed/launch.py +++ b/applications/ColossalChat/coati/distributed/launch.py @@ -100,6 +100,7 @@ def launch_distributed( eval_dataset_config=eval_dataset_config, eval_interval=eval_interval, evaluation_function_type=grpo_config["reward_fn_type"], + response_format_tags=grpo_config["response_format_tags"], eval_save_dir=eval_save_dir, eval_generation_config=eval_generation_config, project_name=project_name, diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 78964afa1..623ed7ab9 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -46,6 +46,7 @@ class BaseProducer: eval_dataset_config=None, eval_interval=-1, # disable evaluation evaluation_function_type="think_answer_tags", + response_format_tags=None, eval_save_dir: str = "./eval", project_name: str = None, run_name: str = None, @@ -148,6 +149,7 @@ class BaseProducer: self.evaluation_function = boxed_math_reward_fn else: raise ValueError(f"Unknown evaluation function type {evaluation_function_type}") + self.response_format_tags = response_format_tags else: print("No eval dataset provided, skip eval") self.device = get_current_device() @@ -217,6 +219,7 @@ class BaseProducer: eval_outputs["response_idx"][m][n], tokenizer=self.tokenizer, eval_mode=True, + tags=self.response_format_tags, ) for m in range(eval_outputs["input_ids"].size(0)) for n in range(eval_outputs["input_ids"].size(1)) @@ -245,11 +248,10 @@ class BaseProducer: self.eval_mode = False self.latest_eval_step = self.consumer_global_step outputs = self.rollout(**batch) - - print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs = pre_send(outputs) ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}" @@ -324,6 +326,7 @@ class SimpleProducer(BaseProducer): eval_dataset_config=None, eval_interval=-1, # disable evaluation evaluation_function_type="think_answer_tags", + response_format_tags=None, eval_save_dir: str = "./eval", eval_generation_config={}, project_name: str = None, @@ -349,6 +352,7 @@ class SimpleProducer(BaseProducer): eval_dataset_config=eval_dataset_config, eval_interval=eval_interval, evaluation_function_type=evaluation_function_type, + response_format_tags=response_format_tags, 
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 5aec7f5a6..b4f565617 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -121,6 +121,34 @@ if __name__ == "__main__":
     parser.add_argument(
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer).",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer).",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="ZeRO stage for the trainer (consumer).",
+    )
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -134,8 +162,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equal to train batch size, and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -178,7 +206,7 @@ if __name__ == "__main__":
                 enforce_eager=True,
                 enable_chunked_prefill=True,
                 max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-                tensor_parallel_size=1,
+                tensor_parallel_size=args.producer_tensor_parallel_size,
             )
         )
         generate_config.update(
@@ -203,6 +231,16 @@ if __name__ == "__main__":
             "reward_fn_type": args.reward_type,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -222,13 +260,23 @@ if __name__ == "__main__":
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")

     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
         num_consumer_procs=args.num_trainers,
         num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
@@ -247,17 +295,14 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "zero_stage": 2,
-        },  # for zero
-        # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # },  # for pp, tp
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // args.pipeline_parallel_size
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": args.zero_stage,
+            "max_norm": 1.0,
+        },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
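As a usage note, a minimal sketch of how the new trainer flags map onto plugin_config; the Namespace values below are hypothetical, while the key names mirror the patch above:

    from argparse import Namespace

    # hypothetical CLI values: -tp 2 -pp 2 -zero 1 with --train-microbatch-size 8
    args = Namespace(tensor_parallel_size=2, pipeline_parallel_size=2, zero_stage=1, train_microbatch_size=8)
    plugin_config = {
        "tp_size": args.tensor_parallel_size,
        "pp_size": args.pipeline_parallel_size,
        # microbatch size should be train_microbatch_size // pp_size, but never below 1
        "microbatch_size": max(1, args.train_microbatch_size // args.pipeline_parallel_size),
        "zero_stage": args.zero_stage,
        "max_norm": 1.0,
    }
    assert plugin_config["microbatch_size"] == 4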