mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
move prompt-level-filtering to buffer side
This commit is contained in:
@@ -113,18 +113,24 @@ class BaseConsumer:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
+                    allow_sync_model = True
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                            self.buffer.extend(
-                                unbind_batch(
-                                    ray_broadcast_tensor_dict(
-                                        None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                                    )
-                                )
+                            raw_batch = unbind_batch(
+                                ray_broadcast_tensor_dict(None, src=0, device=self.device, group_name=f"sync_data_{r}")
                             )
+                            filtered_batch = [
+                                t
+                                for t in [
+                                    self.prompt_level_filtering(self.calculate_group_reward(group))
+                                    for group in raw_batch
+                                ]
+                                if t is not None
+                            ]
+
+                            self.buffer.extend(filtered_batch)
                         while len(self.buffer) >= self.dp_size * self.minibatch_size:
                             batches = self.buffer[
                                 self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
@@ -177,7 +183,7 @@ class BaseConsumer:
                             )
                         del state_dict
                         torch.cuda.empty_cache()
-                        allow_sync_model = False
+                        allow_sync_model = True
 
 
 @ray.remote
Reference in New Issue
Block a user