Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
remove redundant code and fix bugs
@@ -121,14 +121,14 @@ class BaseConsumer:
                 raw_batch = unbind_batch(
                     ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                 )
-                filtered_batch = [
-                    t
-                    for t in [
-                        self.prompt_level_filtering(self.calculate_group_reward(group))
-                        for group in raw_batch
-                    ]
-                    if t is not None
-                ]
+                processed_batch = [
+                    self.prompt_level_filtering(self.calculate_group_reward(group)) for group in raw_batch
+                ]
+                filtered_batch = [t for t in processed_batch if t is not None]
+                if self.filter_range is not None:
+                    print(
+                        f"[T{dist.get_rank()}] Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}"
+                    )
 
                 self.buffer.extend(filtered_batch)
                 while len(self.buffer) >= self.dp_size * self.minibatch_size:
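The first hunk flattens the nested comprehension into two passes, so the pre-filter count stays available for the log line that fires when filter_range is set. A minimal sketch of the new flow, assuming simplified stand-ins for calculate_group_reward and prompt_level_filtering (the real methods live on BaseConsumer):

from typing import Optional

def calculate_group_reward(group: dict) -> dict:
    # Hypothetical stand-in: attach a mean reward to the group.
    group["reward"] = sum(group["scores"]) / len(group["scores"])
    return group

def prompt_level_filtering(group: dict, filter_range=(0.1, 0.9)) -> Optional[dict]:
    # Hypothetical stand-in: drop groups whose reward falls outside filter_range.
    lo, hi = filter_range
    return group if lo <= group["reward"] <= hi else None

raw_batch = [{"scores": [0.0, 0.0]}, {"scores": [0.5, 0.7]}, {"scores": [1.0, 1.0]}]

# Two flat passes instead of one nested comprehension: keeping the intermediate
# processed_batch around lets the consumer report how many groups were dropped.
processed_batch = [prompt_level_filtering(calculate_group_reward(g)) for g in raw_batch]
filtered_batch = [t for t in processed_batch if t is not None]
print(f"Filter recv data: {len(processed_batch)} -> {len(filtered_batch)}")  # 3 -> 1 here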
@@ -137,13 +137,8 @@ class BaseConsumer:
                     ]
                     batch = bind_batch(batches)
                     batch = post_recv(batch)
-                    loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-
-                    if excessive_prompts_idx is not None:
-                        excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                        self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                    else:
-                        self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                    loss = self.step(i, pbar, **batch)
+                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                     if loss is not None:
                         allow_sync_model = True
                     pbar.set_postfix({"loss": loss})
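The second hunk removes the excessive_prompts_idx bookkeeping: step() is treated as returning only the loss, and the consumed slice is always trimmed from the front of the buffer. A minimal sketch of that buffer behavior, with hypothetical dp_size / minibatch_size values and the training step elided:

dp_size, minibatch_size = 2, 4
buffer = list(range(11))  # stand-in for buffered sample groups

while len(buffer) >= dp_size * minibatch_size:
    batch = buffer[: dp_size * minibatch_size]
    # ... bind_batch / post_recv / self.step(...) would run on `batch` here ...
    # The consumed slice is always dropped; no excessive prompts are re-queued.
    buffer = buffer[dp_size * minibatch_size :]

print(len(buffer))  # 3 leftover groups wait for the next sync round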