From ee939d9aa506b4c5266fd2136e23bc6503b1d5b7 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 29 May 2025 10:25:59 +0800 Subject: [PATCH] address conversation --- .../coati/distributed/consumer.py | 67 ++++++++++--------- .../coati/distributed/grpo_consumer.py | 34 +--------- .../coati/distributed/producer.py | 3 +- 3 files changed, 36 insertions(+), 68 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 6f8f1b497..593e0f4ec 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -118,48 +118,49 @@ class BaseConsumer: for r in range(self.num_producers): print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}") raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}") - recv_effective_count = 0 # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete), # we need to calculate the metrics before filtering here for logging - raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch)) - for group_with_reward in raw_batch_with_reward: - group_reward_mean = group_with_reward["reward"].mean().cpu().item() - group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item() - group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item() - group_response_len = ( - ( - group_with_reward["response_idx"][:, 1] - - group_with_reward["response_idx"][:, 0] - + 1 - ) - .type(torch.float32) - .mean() - .cpu() - .item() + # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...] + raw_batch_with_reward = self.calculate_reward({k:v.view(-1, v.size(-1)) if k!='temperature' else v for k, v in raw_batch.items()}) + raw_batch_with_reward = {k: v.view(-1, self.num_generations, v.size(-1)) if k!='temperature' else v for k, v in raw_batch_with_reward.items()} + # [batch_size, num_generations] -> [batch_size] + group_reward_mean = raw_batch_with_reward["reward"][:,:,0].mean(dim=-1) + group_format_acc_mean = raw_batch_with_reward["format_acc"][:,:,0].mean(dim=-1) + group_ans_acc_mean = raw_batch_with_reward["ans_acc"][:,:,0].mean(dim=-1) + group_response_len = ( + (raw_batch_with_reward["response_idx"][:, :, 1] - raw_batch_with_reward["response_idx"][:, :, 0] + 1) + .type(torch.float32) + .mean(dim=-1) + ) + effective_group_mask = None + if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True): + # filter the group based on the reward and accuracy + effective_group_mask = torch.logical_and( + group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1] ) - if self.grpo_config.get("dynamic_batching", True): - filtered_group = self.prompt_level_filtering(group_with_reward) - recv_effective_count += 1 if filtered_group is not None else 0 + raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]] + for group_idx, group_with_reward in enumerate(raw_batch_with_reward): self.buffer.append( [ - filtered_group, - group_reward_mean, - group_format_acc_mean, - group_ans_acc_mean, - group_response_len, + group_with_reward if effective_group_mask is None or effective_group_mask[group_idx] else None, + group_reward_mean[group_idx], + group_format_acc_mean[group_idx], + group_ans_acc_mean[group_idx], + group_response_len[group_idx], ] ) - if self.filter_range is not None: + if effective_group_mask is not None: print( - f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {recv_effective_count}" + f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups" ) - # mapping the effective group to the raw group for indexing - effective_group_to_raw_group_mapping = {} - for buffer_idx in range(len(self.buffer)): - if self.buffer[buffer_idx][0] is not None: - effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( - buffer_idx - ) + # mapping the effective group to the raw group for indexing + effective_group_to_raw_group_mapping = {} + for buffer_idx in range(len(self.buffer)): + if self.buffer[buffer_idx][0] is not None: + effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = ( + buffer_idx + ) + pbar.set_postfix({"Collect Effective Prompt": f"{len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"}) while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size: # on each dp_rank, we use minibatch_size effective samples to form a batch diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 891a6f842..50702683f 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -84,7 +84,6 @@ class GRPOConsumer(BaseConsumer): self.project_name = project_name self.effective_sample_count = 0 self.effective_prompt_count = 0 - self.total_sample_count = 0 self.project_name = project_name self.run_name = run_name self.wandb_group_name = wandb_group_name @@ -429,11 +428,9 @@ class GRPOConsumer(BaseConsumer): self.optimizer.step() self.optimizer.zero_grad() self.global_step += 1 - # no need to run all_reduce_sum on total_sample_count, because all dp ranks recieves a complete inference batch from all producers. - sample_utilization = self.effective_sample_count / self.total_sample_count + sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations self.effective_prompt_count = 0 self.effective_sample_count = 0 - self.total_sample_count = 0 loss_scalar = self.accum_loss.item() if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 @@ -520,35 +517,6 @@ class GRPOConsumer(BaseConsumer): rollout["ans_acc"] = ans_acc.view((-1, 1)) return rollout - def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]: - """ - rollout_group: Dict[str, Any] - a group of samples generated by the model from the same prompt - contain the following keys: - "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length] - "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length] - "action_mask": torch.Tensor, [num_of_generation, response_length] - "action_log_probs": torch.Tensor, [num_of_generation, response_length] - "response_idx": int, torch.Tensor, [num_of_generation, 2] - "gt_answer": torch.Tensor, [num_of_generation, 128] - "temperature": torch.Tensor, [] (scalar) - "reward": torch.Tensor, [num_of_generation] - "format_acc": torch.Tensor, [num_of_generation] - "ans_acc": torch.Tensor, [num_of_generation] - """ - self.total_sample_count += rollout_group["input_ids"].size(0) - if self.filter_range is not None: - # filter prompt whoes accuracy is too high or too low (out of range) - group_ans_acc = torch.mean(rollout_group["ans_acc"]) - if group_ans_acc < self.filter_range[0] or group_ans_acc > self.filter_range[1]: - # filter out the prompt - return None - else: - return rollout_group - else: - # no filter - return rollout_group - def state_dict(self): self.policy_model._force_wait_all_gather() model = self.policy_model.unwrap() diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index f5bdc6835..623ed7ab9 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -248,11 +248,10 @@ class BaseProducer: self.eval_mode = False self.latest_eval_step = self.consumer_global_step outputs = self.rollout(**batch) - - print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs["temperature"] = torch.tensor( [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0) ).to(outputs["input_ids"].device) + print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") outputs = pre_send(outputs) ray_broadcast_tensor_dict( outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"