cherry pick zero bubble RL

This commit is contained in:
YeAnbang
2025-11-06 15:12:51 +08:00
parent 2336d7f6d6
commit c865de32a5
4 changed files with 8 additions and 17 deletions

View File

@@ -425,20 +425,6 @@ class GRPOConsumer(BaseConsumer):
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
mini_batch_entropies.append(
all_reduce_mean(
(
(
(
entropy_from_logits(policy_model_logits[:, -num_action:])
* action_mask_forward_micro_batch
).sum(-1)
)
/ action_mask_forward_micro_batch.sum(-1)
).detach(),
self.plugin,
)
)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,

View File

@@ -64,8 +64,10 @@ class Distributor:
)
self.profiler.exit(f"sync_model_consumer_pp_{i}")
self.weight_version[i] += 1
for i in range(self.consumer_pp_size):
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
if all(
[signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
):
for i in range(self.consumer_pp_size):
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
# Broadcast the model state dict to all producers
ray.get(
@@ -116,4 +118,4 @@ class Distributor:
ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
def get_weight_version(self):
return min(self.weight_version)
return self.weight_version[0]

View File

@@ -244,6 +244,7 @@ class BaseProducer:
f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model"
)
)
for pp_idx in range(self.consumer_pp_size):
print(
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)

View File

@@ -0,0 +1,2 @@
ray==2.49.2
pygloo>=0.2.0 # you need to build from source: https://github.com/ray-project/pygloo commit 82ae2d72222aefcac54a8e88995735ede3abe9cf https://github.com/ray-project/pygloo/blob/main/README.md