mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-04-26 17:53:08 +00:00
cherry pick zero bubble RL
This commit is contained in:
@@ -425,20 +425,6 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
|
||||
mean_kl.append(kl)
|
||||
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
|
||||
mini_batch_entropies.append(
|
||||
all_reduce_mean(
|
||||
(
|
||||
(
|
||||
(
|
||||
entropy_from_logits(policy_model_logits[:, -num_action:])
|
||||
* action_mask_forward_micro_batch
|
||||
).sum(-1)
|
||||
)
|
||||
/ action_mask_forward_micro_batch.sum(-1)
|
||||
).detach(),
|
||||
self.plugin,
|
||||
)
|
||||
)
|
||||
else:
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
|
||||
@@ -64,8 +64,10 @@ class Distributor:
|
||||
)
|
||||
self.profiler.exit(f"sync_model_consumer_pp_{i}")
|
||||
self.weight_version[i] += 1
|
||||
for i in range(self.consumer_pp_size):
|
||||
if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model":
|
||||
if all(
|
||||
[signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)]
|
||||
):
|
||||
for i in range(self.consumer_pp_size):
|
||||
self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}")
|
||||
# Broadcast the model state dict to all producers
|
||||
ray.get(
|
||||
@@ -116,4 +118,4 @@ class Distributor:
|
||||
ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version))
|
||||
|
||||
def get_weight_version(self):
|
||||
return min(self.weight_version)
|
||||
return self.weight_version[0]
|
||||
|
||||
@@ -244,6 +244,7 @@ class BaseProducer:
|
||||
f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model"
|
||||
)
|
||||
)
|
||||
for pp_idx in range(self.consumer_pp_size):
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
ray==2.49.2
|
||||
pygloo>=0.2.0 # you need to build from source: https://github.com/ray-project/pygloo commit 82ae2d72222aefcac54a8e88995735ede3abe9cf https://github.com/ray-project/pygloo/blob/main/README.md
|
||||
Reference in New Issue
Block a user