diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index ecfd65458..6385df2cf 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -113,18 +113,24 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
+                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = unbind_batch(
+                                ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
+                            filtered_batch = [
+                                t
+                                for t in [
+                                    self.prompt_level_filtering(self.calculate_group_reward(group))
+                                    for group in raw_batch
+                                ]
+                                if t is not None
+                            ]
+
+                            self.buffer.extend(filtered_batch)
                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
                             batches = self.buffer[
                                 self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
@@ -177,7 +183,7 @@ class BaseConsumer:
                                 )
                                 del state_dict
                                 torch.cuda.empty_cache()
-                                allow_sync_model = False
+                                allow_sync_model = True


 @ray.remote
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 7ea2dbf18..fcf7b0740 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -1,5 +1,5 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
@@ -179,7 +179,6 @@ class GRPOConsumer(BaseConsumer):
             Format:
                 [minibatch_size, num_of_generation, prompt_length + response_length] --- .............
         """
-
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
         data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
         action_mask = data["action_mask"]
@@ -188,15 +187,9 @@
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
         train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        reward_group = self.reward_model(
-            data["input_ids"],
-            gt_answer=data["gt_answer"],
-            response_idx=data["response_idx"],
-        )
-
-        reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))

         # [minibatch_size, num_generations]

@@ -213,11 +206,7 @@
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
         # [minibatch_size x num_of_generation]
-        loss_mask = (
-            torch.ones(action_mask.size(0), device=action_mask.device).bool()
-            if self.filter_range is None
-            else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-        )
+        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
@@ -525,6 +514,68 @@
         else:
             return None, excessive_prompts_idx

+    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt,
+                containing the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_group = self.reward_model(
+            rollout_group["input_ids"],
+            gt_answer=rollout_group["gt_answer"],
+            response_idx=rollout_group["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+
+        rollout_group["reward"] = reward.view((-1, 1))
+        rollout_group["format_acc"] = format_acc.view((-1, 1))
+        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout_group
+
+    def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        """
+        rollout_group: Dict[str, Any]
+            a group of samples generated by the model from the same prompt,
+            containing the following keys:
+            "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+            "action_mask": torch.Tensor, [num_of_generation, response_length]
+            "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+            "response_idx": int, torch.Tensor, [num_of_generation, 2]
+            "gt_answer": torch.Tensor, [num_of_generation, 128]
+            "temperature": torch.Tensor, [] (scalar)
+            "reward": torch.Tensor, [num_of_generation]
+            "format_acc": torch.Tensor, [num_of_generation]
+            "ans_acc": torch.Tensor, [num_of_generation]
+        """
+        if self.filter_range is not None:
+            # filter out prompts whose accuracy is too high or too low (out of range)
+            group_ans_acc = torch.mean(rollout_group["ans_acc"])
+            if group_ans_acc < self.filter_range[0] or group_ans_acc > self.filter_range[1]:
+                # filter out the prompt
+                return None
+            else:
+                return rollout_group
+        else:
+            # no filter
+            return rollout_group
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
         model = self.policy_model.unwrap()
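For orientation only (not part of the diff): a minimal standalone sketch of the new consumer-side flow, with the reward model stubbed out and FILTER_RANGE as an illustrative stand-in for self.filter_range. Each received prompt group first gets per-sample reward/format_acc/ans_acc attached, then the whole group is dropped when its mean answer accuracy falls outside the filter range, and only surviving groups are appended to the buffer.

# Standalone sketch of the prompt-level filtering flow introduced above.
# The reward values are stubbed; in the real consumer they come from self.reward_model.
from typing import Any, Dict, List, Optional

import torch

FILTER_RANGE = (0.01, 0.99)  # illustrative stand-in for self.filter_range


def calculate_group_reward(group: Dict[str, Any]) -> Dict[str, Any]:
    # Stub: fake per-generation scores for one prompt group.
    num_gen = group["input_ids"].size(0)
    group["ans_acc"] = torch.randint(0, 2, (num_gen, 1)).float()
    group["format_acc"] = torch.ones(num_gen, 1)
    group["reward"] = group["ans_acc"].clone()
    return group


def prompt_level_filtering(group: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    # Drop the whole prompt group when its mean answer accuracy is out of range,
    # i.e. the prompt is too easy or too hard to give a useful training signal.
    group_ans_acc = group["ans_acc"].mean()
    if group_ans_acc < FILTER_RANGE[0] or group_ans_acc > FILTER_RANGE[1]:
        return None
    return group


raw_batch: List[Dict[str, Any]] = [
    {"input_ids": torch.zeros(8, 16, dtype=torch.long)} for _ in range(4)
]
filtered_batch = [
    t
    for t in [prompt_level_filtering(calculate_group_reward(g)) for g in raw_batch]
    if t is not None
]
print(f"kept {len(filtered_batch)} of {len(raw_batch)} prompt groups")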