diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 9503d65eb..38b31398d 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -107,9 +107,14 @@ class BaseConsumer:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -127,15 +132,15 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -149,29 +154,31 @@ class BaseConsumer:
                            print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                    if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if self.pp_size > 1:
-                            print(
-                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
-                            )
-                        else:
-                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                        torch.cuda.empty_cache()
-                        state_dict = self.state_dict()
-                        if self.pp_size > 1:
-                            if self.tp_rank == 0 and self.dp_rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict,
-                                    src=self.num_producers,
-                                    device=self.device,
-                                    group_name=f"sync_model_{self.pp_rank}",
+                        if allow_sync_model:
+                            if self.pp_size > 1:
+                                print(
+                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                                 )
-                        else:
-                            if self.rank == 0:
-                                ray_broadcast_tensor_dict(
-                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                )
-                        del state_dict
-                        torch.cuda.empty_cache()
+                            else:
+                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                            torch.cuda.empty_cache()
+                            state_dict = self.state_dict()
+                            if self.pp_size > 1:
+                                if self.tp_rank == 0 and self.dp_rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict,
+                                        src=self.num_producers,
+                                        device=self.device,
+                                        group_name=f"sync_model_{self.pp_rank}",
+                                    )
+                            else:
+                                if self.rank == 0:
+                                    ray_broadcast_tensor_dict(
+                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                    )
+                            del state_dict
+                            torch.cuda.empty_cache()
+                            allow_sync_model = False


 @ray.remote
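The consumer-side change above replaces the old count-based buffer trimming with index-based carry-over: step() now returns the indices of prompts that exceeded the per-step budget, and those entries are kept for the next iteration instead of being sliced off. A minimal sketch of that bookkeeping on a toy buffer (plain Python lists and made-up sizes; not part of the patch):

# Illustrative sketch only: mirrors the buffer rebuild in BaseConsumer.loop above.
dp_size, minibatch_size = 2, 4
buffer = [f"prompt_{i}" for i in range(10)]  # 8 entries consumed this step, 2 still queued
excessive_prompts_idx = [5, 6]  # indices reported back by step() as over budget

# Keep the over-budget prompts for the next iteration, followed by the entries
# that were never handed to step() in the first place.
excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
buffer = excessive_prompts + buffer[dp_size * minibatch_size :]
print(buffer)  # ['prompt_5', 'prompt_6', 'prompt_8', 'prompt_9']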
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 877ff98ec..31b687639 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

@@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0

         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
+        # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )

+        # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
+
+        mean_kl, mean_loss = [], []
+
+        if self.grpo_config.get("dynamic_batching", True):
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
+        else:
+            # If dynamic batching is disabled, we need to use all samples for training.
+            need_update = (step_idx + 1) % self.num_microbatches == 0
+
         effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         self.total_sample_count += total_samples.item()

-        mean_kl, mean_loss = [], []
-
-        if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
-        else:
-            # If dynamic batching is disabled, we need to use all samples for training.
-            need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
-
         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )

@@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
                         kl.append(appox_kl.mean())
                     else:
                         per_token_kl = 0.0
-                        kl.append(0.0)
+                        kl.append(torch.tensor(0.0))

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
@@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx

     def state_dict(self):
         self.policy_model._force_wait_all_gather()
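With this change the GRPO consumer accumulates effective prompts (a prompt counts if at least one of its generations survives the accuracy and overlength filters) and defers the overshoot beyond batch_size * dp_size to the next step instead of slicing off trailing samples. A rough sketch of that selection on dummy tensors (made-up sizes and mask values; the real code loops over indices rather than using repeat_interleave):

# Illustrative sketch only: prompt-level accounting similar to GRPOConsumer.step above.
import torch

num_generations, minibatch_size = 4, 4
# One bool per generation: True if the sample survives reward/overlength filtering.
loss_mask = torch.tensor([1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1], dtype=torch.bool)

# A prompt is "effective" if any of its generations survived.
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)  # tensor([True, False, True, True])

# Suppose the accumulated effective-prompt count overshoots the target by one prompt per
# rank: the last effective prompt is deferred and masked out of this step's loss.
excessive_prompts_per_rank = 1
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]  # tensor([3])
effective_prompts_mask[excessive_prompts_idx] = False
loss_mask &= effective_prompts_mask.repeat_interleave(num_generations)
print(excessive_prompts_idx.tolist(), loss_mask.int().tolist())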
+ """ + # Gather tensors across DP group + if plugin is not None: + all_tensor_lists = [None] * plugin.dp_size + dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group) + gathered_tensor_list = [] + for tensors in all_tensor_lists: + gathered_tensor_list.extend(tensors) + else: + all_tensor_lists = [None] * dist.get_world_size() + dist.all_gather_object(all_tensor_lists, local_tensor_list) + gathered_tensor_list = [] + for tensors in all_tensor_lists: + gathered_tensor_list.extend(tensors) + return gathered_tensor_list diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 788e60c2e..f5609c890 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -9,7 +9,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") - parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.") parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") # Distributed training parameters @@ -20,7 +20,7 @@ if __name__ == "__main__": "-ibs", "--inference-batch-size", type=int, - default=None, + default=64, help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", ) parser.add_argument( @@ -41,7 +41,7 @@ if __name__ == "__main__": "-tMbs", "--train-minibatch-size", type=int, - default=None, + default=8, help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", ) parser.add_argument( @@ -58,7 +58,7 @@ if __name__ == "__main__": "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional" ) parser.add_argument( - "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional" + "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional" ) # Sampling parameters @@ -223,7 +223,7 @@ if __name__ == "__main__": "zero_stage": 2, }, # for zero # plugin_config={ - # "tp_size": 1, + # "tp_size": 2, # "pp_size": 2, # "microbatch_size": max( # 1, args.train_microbatch_size // 2