remove redundant code and fix bugs

This commit is contained in:
YeAnbang
2025-05-16 14:08:23 +08:00
parent a528921944
commit 11a5854b50
5 changed files with 27 additions and 59 deletions

View File

@@ -121,14 +121,14 @@ class BaseConsumer:
raw_batch = unbind_batch(
ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
)
filtered_batch = [
t
for t in [
self.prompt_level_filtering(self.calculate_group_reward(group))
for group in raw_batch
]
if t is not None
processed_batch = [
self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
]
filtered_batch = [t for t in processed_batch if t is not None]
if self.filter_range is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
)
self.buffer.extend(filtered_batch)
while len(self.buffer) >= self.dp_size * self.minibatch_size:
@@ -137,13 +137,8 @@ class BaseConsumer:
]
batch = bind_batch(batches)
batch = post_recv(batch)
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
if excessive_prompts_idx is not None:
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
else:
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
loss = self.step(i, pbar, **batch)
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
if loss is not None:
allow_sync_model = True
pbar.set_postfix({"loss": loss})