address conversation

Commit 2a39d3afd9, parent 4b1c515f52
Author: YeAnbang
Date: 2025-05-28 17:34:11 +08:00
2 changed files with 17 additions and 42 deletions

@@ -117,14 +117,12 @@ class BaseConsumer:
                 # receive data from producers
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                    raw_batch = unbind_batch(
-                        ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
-                    )
+                    raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                     recv_effective_count = 0
                     # calculate group reward and other metrics before filtering. Since only the filtered
                     # groups are used for training (an incomplete subset), the logging metrics must be
                     # computed here, before filtering.
-                    for group in raw_batch:
-                        group_with_reward = self.calculate_group_reward(group)
+                    raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
+                    for group_with_reward in raw_batch_with_reward:
                         group_reward_mean = group_with_reward["reward"].mean().cpu().item()
                         group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
                         group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
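
Note: the change above moves `unbind_batch` to after the reward calculation, so the reward model is invoked once on the whole broadcast batch instead of once per group. A minimal sketch of the `unbind_batch` semantics this relies on (a hypothetical implementation for illustration, not the project's):

import torch
from typing import Dict, List

def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
    # Hypothetical sketch: split every tensor along dim 0 and regroup by index.
    keys = list(batch.keys())
    columns = [torch.unbind(batch[k], dim=0) for k in keys]
    return [dict(zip(keys, row)) for row in zip(*columns)]

# Because rewards are attached to the whole batch first, every unbound group
# already carries its "reward", "format_acc", and "ans_acc" entries.
batch = {"input_ids": torch.zeros(4, 8, dtype=torch.long), "reward": torch.rand(4, 1)}
groups = unbind_batch(batch)
assert len(groups) == 4 and groups[0]["reward"].shape == (1,)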
@@ -139,6 +137,7 @@ class BaseConsumer:
                                 .cpu()
                                 .item()
                             )
+                        if self.grpo_config.get("dynamic_batching", True):
                             filtered_group = self.prompt_level_filtering(group_with_reward)
                             recv_effective_count += 1 if filtered_group is not None else 0
                     self.buffer.append(
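
The newly added guard makes prompt-level filtering conditional on the `dynamic_batching` flag, which defaults to on. A minimal sketch of the resulting behavior, using a hypothetical `maybe_filter` helper that is not part of the codebase:

from typing import Any, Callable, Dict, Optional

def maybe_filter(
    grpo_config: Dict[str, Any],
    group: Dict[str, Any],
    prompt_level_filtering: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]],
) -> Optional[Dict[str, Any]]:
    # Mirrors the new guard: filter only when dynamic batching is enabled
    # (the default); with the flag off, every group is kept for training.
    if grpo_config.get("dynamic_batching", True):
        return prompt_level_filtering(group)
    return group

# With dynamic batching disabled, even a group the filter would drop is kept.
kept = maybe_filter({"dynamic_batching": False}, {"reward": 0.0}, lambda g: None)
assert kept == {"reward": 0.0}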

@@ -218,30 +218,6 @@ class GRPOConsumer(BaseConsumer):
         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask)
-                    # Make sure the indices are not empty.
-                    if true_indices.numel() > 0:
-                        true_indices = true_indices.squeeze(-1)
-                        if excessive_prompts_per_rank <= len(true_indices):
-                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                        else:
-                            excessive_prompts_idx = true_indices
-                        effective_prompts_mask[excessive_prompts_idx] = False
-                        for mask_idx in range(len(effective_prompts_mask)):
-                            if effective_prompts_mask[mask_idx] == False:
-                                # Update loss mask.
-                                loss_mask[mask_idx] = False
-            else:
-                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
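
With the excess-prompt masking deleted, the dynamic-batching path reduces to a single threshold check, and surplus effective prompts appear to stay in the training batch rather than being masked out of the loss. A worked sketch with hypothetical numbers:

# Hypothetical sizes: 8 prompts per rank, 4 data-parallel ranks.
batch_size, dp_size = 8, 4
effective_prompt_count = 33  # prompts that survived filtering so far

need_update = effective_prompt_count >= batch_size * dp_size
print(need_update)  # True: 33 >= 32, so a training step is triggered
# Before this commit, the 33 - 32 = 1 surplus prompt would have had its
# loss mask set to False; after it, the surplus simply remains in the batch.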
@@ -510,7 +486,7 @@ class GRPOConsumer(BaseConsumer):
         else:
             return None

-    def calculate_group_reward(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
         """
         Calculate the group reward for the given rollout group.
@@ -529,20 +505,20 @@ class GRPOConsumer(BaseConsumer):
         Returns:
             Dict[str, Any]: The new group data with calculated reward.
         """
-        reward_group = self.reward_model(
-            rollout_group["input_ids"],
-            gt_answer=rollout_group["gt_answer"],
-            response_idx=rollout_group["response_idx"],
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
         )
         # [num_of_generation]
-        reward = torch.tensor([value[0] for value in reward_group]).to(rollout_group["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(rollout_group["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(rollout_group["input_ids"].device)
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
-        rollout_group["reward"] = reward.view((-1, 1))
-        rollout_group["format_acc"] = format_acc.view((-1, 1))
-        rollout_group["ans_acc"] = ans_acc.view((-1, 1))
-        return rollout_group
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout

     def prompt_level_filtering(self, rollout_group: Dict[str, Any]) -> Dict[str, Any]:
         """