diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5e109e1eb..40d362340 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -425,20 +425,6 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) - mini_batch_entropies.append( - all_reduce_mean( - ( - ( - ( - entropy_from_logits(policy_model_logits[:, -num_action:]) - * action_mask_forward_micro_batch - ).sum(-1) - ) - / action_mask_forward_micro_batch.sum(-1) - ).detach(), - self.plugin, - ) - ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py index ea04ae13c..8a710ec07 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/distributor.py @@ -64,8 +64,10 @@ class Distributor: ) self.profiler.exit(f"sync_model_consumer_pp_{i}") self.weight_version[i] += 1 - for i in range(self.consumer_pp_size): - if signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model": + if all( + [signal.get(f"producer_{self.distributor_id}_pp_{i}", None) == "ready_sync_model" for i in range(self.consumer_pp_size)] + ): + for i in range(self.consumer_pp_size): self.profiler.enter(f"sync_model_producer_{self.distributor_id}_pp_{i}") # Broadcast the model state dict to all producers ray.get( @@ -116,4 +118,4 @@ class Distributor: ray.get(self.shared_signal_actor.set_signal.remote("distributor_weight_version", last_weight_version)) def get_weight_version(self): - return min(self.weight_version) + return self.weight_version[0] diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py index 31c314dd5..7179b1da4 100644 --- a/applications/ColossalChat/coati/distributed/zero_bubble/producer.py +++ b/applications/ColossalChat/coati/distributed/zero_bubble/producer.py @@ -244,6 +244,7 @@ class BaseProducer: f"producer_{self.producer_idx}_pp_{pp_idx}", "ready_sync_model" ) ) + for pp_idx in range(self.consumer_pp_size): print( f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}" ) diff --git a/applications/ColossalChat/coati/distributed/zero_bubble/requirements.txt b/applications/ColossalChat/coati/distributed/zero_bubble/requirements.txt new file mode 100644 index 000000000..8c4fb8eed --- /dev/null +++ b/applications/ColossalChat/coati/distributed/zero_bubble/requirements.txt @@ -0,0 +1,2 @@ +ray==2.49.2 +pygloo>=0.2.0 # you need to build from source: https://github.com/ray-project/pygloo commit 82ae2d72222aefcac54a8e88995735ede3abe9cf https://github.com/ray-project/pygloo/blob/main/README.md \ No newline at end of file