Mirror of https://github.com/hpcaitech/ColossalAI.git

Commit: address conversation
@@ -218,8 +218,6 @@ class GRPOConsumer(BaseConsumer):
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            assert excessive_prompts <= 0, "Debug: Excessive prompts should always be less than 0. Bug!!!!"
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
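This hunk drops a debug assertion on `excessive_prompts` and leaves only the update gate itself. For context, a minimal sketch of how that gate behaves, assuming dynamic batching is enabled; the attribute names `effective_prompt_count`, `batch_size`, and `dp_size` come from the diff, while the standalone helper and the example numbers below are hypothetical:

```python
# Hypothetical standalone version of the dynamic-batching gate kept by this hunk.
# In the consumer these values are instance attributes; a free function is used
# here only for illustration.
def need_update(effective_prompt_count: int, batch_size: int, dp_size: int) -> bool:
    # Trigger an optimizer step once enough effective (unfiltered) prompts
    # have accumulated to fill one global batch across data-parallel ranks.
    return effective_prompt_count >= batch_size * dp_size

print(need_update(7, 4, 2))  # False: 7 < 4 * 2, keep accumulating
print(need_update(8, 4, 2))  # True: one full global batch is ready
```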
@@ -488,7 +486,7 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
@@ -507,20 +505,20 @@ class GRPOConsumer(BaseConsumer):
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)

-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """
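For reference, a minimal self-contained sketch of the unpacking pattern the renamed `calculate_reward` uses. The stand-in reward model below is hypothetical; it returns one `(reward, format_acc, ans_acc)` triple per generated response, which is what the `value[0]` / `value[1]` / `value[2]` indexing in the diff implies:

```python
import torch

# Hypothetical stand-in for self.reward_model: one scoring triple per response.
def fake_reward_model(input_ids, gt_answer=None, response_idx=None):
    return [(1.0, 1.0, 0.0), (0.5, 1.0, 1.0)]

# A toy rollout with two generations of length 8 (shapes are illustrative).
rollout = {"input_ids": torch.zeros(2, 8, dtype=torch.long)}
output = fake_reward_model(rollout["input_ids"])

device = rollout["input_ids"].device
reward = torch.tensor([v[0] for v in output]).to(device)
format_acc = torch.tensor([v[1] for v in output]).to(device)
ans_acc = torch.tensor([v[2] for v in output]).to(device)

# Each score becomes a [num_of_generation, 1] column, matching .view((-1, 1)).
rollout["reward"] = reward.view((-1, 1))
rollout["format_acc"] = format_acc.view((-1, 1))
rollout["ans_acc"] = ans_acc.view((-1, 1))
print(rollout["reward"].shape)  # torch.Size([2, 1])
```

Beyond the `rollout_group` -> `rollout` and `reward_group` -> `reward_model_output` renames, the computation itself is unchanged.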