From 2ca1e3c630f3ff4b4f066b11ff9809989e1bac5e Mon Sep 17 00:00:00 2001
From: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Date: Mon, 28 Apr 2025 17:10:00 +0800
Subject: [PATCH] fix pp+tp, fix dataloader (#6280)

---
 .../ColossalChat/coati/distributed/consumer.py      |  6 +++++-
 .../ColossalChat/coati/distributed/grpo_consumer.py | 12 +++++++++---
 .../ColossalChat/coati/distributed/launch.py        |  1 -
 .../ColossalChat/coati/distributed/producer.py      |  2 +-
 applications/ColossalChat/rl_example.py             |  4 ++--
 5 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 47f08cc0b..439d4d702 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -66,7 +66,11 @@ class BaseConsumer:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
-        if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
+        if (
+            self.plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in self.plugin_config
+            and "microbatch_size" not in self.plugin_config
+        ):
             plugin_config["microbatch_size"] = self.minibatch_size
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 29307e1d3..3da4a4f47 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -49,6 +49,12 @@ class GRPOConsumer(BaseConsumer):
                 UserWarning,
             )
             minibatch_size = batch_size
+        if (
+            plugin_config.get("pp_size", 1) > 1
+            and "num_microbatches" not in plugin_config
+            and "microbatch_size" not in plugin_config
+        ):
+            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
         super().__init__(
             num_producers,
             num_episodes,
@@ -373,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
                         loss_mask=inputs["loss_mask"],
                         total_effective_tokens_in_batch=total_effective_tokens_count,
                     )
-                    return loss, num_excessive_samples // self.num_generations
+                    return loss
 
                 policy_model_outputs = self.booster.execute_pipeline(
                     iter([data_policy_forward]),
@@ -468,10 +474,10 @@ class GRPOConsumer(BaseConsumer):
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_sample_count = 0
             self.total_sample_count = 0
+            loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
-                loss_scalar = self.accum_loss.item()
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
@@ -507,7 +513,7 @@ class GRPOConsumer(BaseConsumer):
                     self.accum_advantages.zero_()
                     self.accum_response_length.zero_()
                     self.accum_count = 0
-                return loss_scalar, num_excessive_samples // self.num_generations
+            return loss_scalar, num_excessive_samples // self.num_generations
         else:
             return None, num_excessive_samples // self.num_generations
 
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index b193dc8bc..4e100ebd1 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -33,7 +33,6 @@ def launch_distributed(
     inference_batch_size: int,
     inference_microbatch_size: int,
     train_batch_size: int,
-    train_microbatch_size: int,
     train_minibatch_size: int,
     dataset_config: Dict[str, Any],
     dataloaders_config: Dict[str, Any],
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index b0624f72a..fcb1a184a 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -68,6 +68,7 @@ class BaseProducer:
                 seed=42,
             ),
             num_workers=4,
+            drop_last=True,
         )
         self.device = get_current_device()
 
@@ -116,7 +117,6 @@ class BaseProducer:
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                 )
-
                 if (i + 1) % self.num_microbatches == 0 and (
                     episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
                 ):
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index b20e9dc38..bdfcadbb0 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -198,7 +198,6 @@ if __name__ == "__main__":
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_minibatch_size=args.train_minibatch_size,
-        train_microbatch_size=args.train_microbatch_size,
         dataset_config={
             "path": args.dataset,
             "max_length": args.max_prompt_tokens,
@@ -216,7 +215,8 @@ if __name__ == "__main__":
         # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
-        #     "microbatch_size": args.train_microbatch_size // 2,
+        #     "pp_size": 2,
+        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
         # },  # for pp
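
Note (illustration only, not part of the applied patch): the core of the pp+tp fix is the shared fallback in consumer.py and grpo_consumer.py: when pipeline parallelism is enabled and the caller set neither num_microbatches nor microbatch_size, a microbatch size is derived instead of overriding an explicit setting. A minimal standalone sketch of that behavior follows; the helper name resolve_microbatch_size is hypothetical, and only the heuristic (half the train microbatch size, floored at 1) comes from the patch.

    # Hypothetical helper mirroring the fallback the patch adds; not part of the diff.
    def resolve_microbatch_size(plugin_config: dict, train_microbatch_size: int) -> dict:
        """Fill in `microbatch_size` only when pipeline parallelism is on and the caller
        set neither `num_microbatches` nor `microbatch_size`, so explicit values win."""
        if (
            plugin_config.get("pp_size", 1) > 1
            and "num_microbatches" not in plugin_config
            and "microbatch_size" not in plugin_config
        ):
            # Same heuristic as the patch: half the training microbatch size, at least 1.
            plugin_config["microbatch_size"] = max(1, train_microbatch_size // 2)
        return plugin_config

    if __name__ == "__main__":
        # pp on, nothing set -> derives microbatch_size = max(1, 8 // 2) = 4
        print(resolve_microbatch_size({"tp_size": 2, "pp_size": 2, "zero_stage": 0}, 8))
        # pp on but num_microbatches given -> config left untouched
        print(resolve_microbatch_size({"pp_size": 2, "num_microbatches": 4}, 8))

The dataloader half of the fix is independent: the producer's DataLoader now passes drop_last=True, so a short final batch is discarded rather than forwarded to the consumers.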