move prompt-level-filtering to buffer side

YeAnbang
2025-05-15 18:16:50 +08:00
parent 957e3a521a
commit 55eee129d2
2 changed files with 81 additions and 24 deletions


@@ -113,18 +113,24 @@ class BaseConsumer:
         ) as pbar:
             for step in pbar:
                 i = 0
-                allow_sync_model = False
+                allow_sync_model = True
                 for _ in range(self.num_recv_per_update):
                     # receive data from producers
                     for r in range(self.num_producers):
                         print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                        self.buffer.extend(
-                            unbind_batch(
-                                ray_broadcast_tensor_dict(
-                                    None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                )
-                            )
-                        )
+                        raw_batch = unbind_batch(
+                            ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
+                        )
+                        filtered_batch = [
+                            t
+                            for t in [
+                                self.prompt_level_filtering(self.calculate_group_reward(group))
+                                for group in raw_batch
+                            ]
+                            if t is not None
+                        ]
+                        self.buffer.extend(filtered_batch)
                     while len(self.buffer) >= self.dp_size * self.minibatch_size:
                         batches = self.buffer[
                             self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
@@ -177,7 +183,7 @@ class BaseConsumer:
                         )
                     del state_dict
                     torch.cuda.empty_cache()
-                allow_sync_model = False
+                allow_sync_model = True
 
 
 @ray.remote
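
For readers outside this codebase, here is a minimal, self-contained sketch of the receive-filter-buffer pattern the new consumer code follows. It assumes `calculate_group_reward` attaches a mean reward to each group and that `prompt_level_filtering` returns `None` for groups whose rewards carry no learning signal (for example, prompts that every rollout got fully right or fully wrong). The dict layout, bounds, and both function bodies below are illustrative stand-ins, not the repository's actual implementations.

import torch

# Hypothetical stand-in: the real method computes rewards from rollouts.
def calculate_group_reward(group: dict) -> dict:
    group["reward_mean"] = group["reward"].float().mean().item()
    return group

# Hypothetical stand-in: keep a group only if its mean reward lies within
# the bounds; all-wrong (~0) and all-right (~1) prompts give a GRPO-style
# learner no advantage signal, so the filter returns None to discard them.
def prompt_level_filtering(group: dict, low: float = 0.01, high: float = 0.99):
    return group if low <= group["reward_mean"] <= high else None

buffer: list[dict] = []
raw_batch = [
    {"reward": torch.tensor([0.0, 0.0, 0.0, 0.0])},  # all wrong -> dropped
    {"reward": torch.tensor([1.0, 0.0, 1.0, 0.0])},  # mixed     -> kept
    {"reward": torch.tensor([1.0, 1.0, 1.0, 1.0])},  # all right -> dropped
]

# Same shape as the diff: score each group, filter, then buffer survivors.
filtered_batch = [
    t
    for t in [prompt_level_filtering(calculate_group_reward(g)) for g in raw_batch]
    if t is not None
]
buffer.extend(filtered_batch)
assert len(buffer) == 1

Filtering before `self.buffer.extend` (rather than when minibatches are drawn) keeps `len(self.buffer)` an accurate count of trainable groups, so the `while len(self.buffer) >= self.dp_size * self.minibatch_size` check above never assembles a minibatch that would shrink after filtering.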