address conversation

This commit is contained in:
YeAnbang
2025-05-28 17:34:11 +08:00
parent 78a06f5ce3
commit 4c3656870a
2 changed files with 17 additions and 20 deletions

View File

@@ -117,14 +117,12 @@ class BaseConsumer:
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
raw_batch = unbind_batch(
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
)
raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
recv_effective_count = 0
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
# we need to calculate the metrics before filtering here for logging
for group in raw_batch:
group_with_reward = self.calculate_group_reward(group)
raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
for group_with_reward in raw_batch_with_reward:
group_reward_mean = group_with_reward["reward"].mean().cpu().item()
group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
@@ -139,7 +137,8 @@ class BaseConsumer:
.cpu()
.item()
)
filtered_group = self.prompt_level_filtering(group_with_reward)
if self.grpo_config.get("dynamic_batching", True):
filtered_group = self.prompt_level_filtering(group_with_reward)
recv_effective_count += 1 if filtered_group is not None else 0
self.buffer.append(
[