mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
address conversation
This commit is contained in:
@@ -117,14 +117,12 @@ class BaseConsumer:
|
||||
# receive data from producers
|
||||
for r in range(self.num_producers):
|
||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||
raw_batch = unbind_batch(
|
||||
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
|
||||
)
|
||||
raw_batch = ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
|
||||
recv_effective_count = 0
|
||||
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
|
||||
# we need to calculate the metrics before filtering here for logging
|
||||
for group in raw_batch:
|
||||
group_with_reward = self.calculate_group_reward(group)
|
||||
raw_batch_with_reward = unbind_batch(self.calculate_reward(raw_batch))
|
||||
for group_with_reward in raw_batch_with_reward:
|
||||
group_reward_mean = group_with_reward["reward"].mean().cpu().item()
|
||||
group_format_acc_mean = group_with_reward["format_acc"].mean().cpu().item()
|
||||
group_ans_acc_mean = group_with_reward["ans_acc"].mean().cpu().item()
|
||||
@@ -139,7 +137,8 @@ class BaseConsumer:
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
filtered_group = self.prompt_level_filtering(group_with_reward)
|
||||
if self.grpo_config.get("dynamic_batching", True):
|
||||
filtered_group = self.prompt_level_filtering(group_with_reward)
|
||||
recv_effective_count += 1 if filtered_group is not None else 0
|
||||
self.buffer.append(
|
||||
[
|
||||
|
Reference in New Issue
Block a user