diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 026056783..6f8f1b497 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -117,14 +117,12 @@ class BaseConsumer:
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            raw_batch = unbind_batch(
-                                ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                            )
+                            raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             recv_effective_count = 0
                             # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
                             # we need to calculate the metrics before filtering here for logging
-                            for group in raw_batch:
-                                group_with_reward = self.calculate_group_reward(group)
+                            raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
+                            for group_with_reward in raw_batch_with_reward:
                                 group_reward_mean = group_with_reward["reward"].mean().cpu().item()
                                 group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
                                 group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
@@ -139,7 +137,8 @@ class BaseConsumer:
                                     .cpu()
                                     .item()
                                 )
-                                filtered_group = self.prompt_level_filtering(group_with_reward)
+                                if self.grpo_config.get("dynamic_batching", True):
+                                    filtered_group = self.prompt_level_filtering(group_with_reward)
                                 recv_effective_count += 1 if filtered_group is not None else 0
                                 self.buffer.append(
                                     [
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 152689e06..891a6f842 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -218,8 +218,6 @@ class GRPOConsumer(BaseConsumer):

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
@@ -488,7 +486,7 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.

@@ -507,20 +505,20 @@ class GRPOConsumer(BaseConsumer):
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)

-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """