From 263a9cbe7a7d314f6086872ef1b1a73de0db3239 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Mon, 28 Apr 2025 16:18:50 +0800
Subject: [PATCH] fixed plugin micro-batch size

---
 .../coati/distributed/consumer.py       |  2 +-
 .../coati/distributed/grpo_consumer.py  |  4 +++-
 applications/ColossalChat/rl_example.py | 22 ++++++++++---------
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 439d4d702..40fa67e1a 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -71,7 +71,7 @@ class BaseConsumer:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 3da4a4f47..859b6bdd1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bdfcadbb0..6d7964c6f 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -209,17 +209,19 @@ if __name__ == "__main__":
         num_generations=args.num_generations,
         train_model_config=train_model_config,
         grpo_config=grpo_config,
-        plugin_config={
-            "zero_stage": 2,
-        },  # for zero
-        # currently not support tp/pp
         # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # },  # for pp
+        #     "zero_stage": 2,
+        # },  # for zero
+        # currently not support tp/pp
+        plugin_config={
+            "tp_size": 2,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 0,
+            "max_norm": 1.0,
+        },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
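
Note (editor's addition, not part of the commit): all three hunks apply the same
sizing rule. When pipeline parallelism is enabled and the caller sets neither
"num_microbatches" nor "microbatch_size", the micro-batch size defaults to the
per-step batch size divided by pp_size, clamped to at least 1, so that each of
the pp_size pipeline stages receives a micro-batch. A minimal standalone sketch
of that rule follows; default_microbatch_size is a hypothetical helper, the real
logic lives inline in BaseConsumer and GRPOConsumer, and the chunking behavior
of HybridParallelPlugin is assumed rather than verified here.

    def default_microbatch_size(batch_size: int, plugin_config: dict) -> int:
        """One micro-batch per pipeline stage, never smaller than one sample."""
        pp_size = plugin_config.get("pp_size", 1)  # 1 means pipeline parallelism is off
        return max(1, batch_size // pp_size)

    # A minibatch of 8 on a 2-stage pipeline gives micro-batches of 4,
    # i.e. two micro-batches in flight per optimizer step:
    assert default_microbatch_size(8, {"pp_size": 2}) == 4
    assert default_microbatch_size(8, {}) == 8              # no pp: whole minibatch
    assert default_microbatch_size(1, {"pp_size": 4}) == 1  # clamped by max(1, ...)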
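A second editor's note: the rl_example.py hunk hardcodes "// 2" to match the
enabled "pp_size": 2, while its inline comment states the general rule,
train_microbatch_size // pp_size. A sanity check one could run before building
the plugin is sketched below; check_pp_config is illustrative only, not an API
in the repository.

    def check_pp_config(plugin_config: dict, train_microbatch_size: int) -> None:
        """Raise if microbatch_size drifts from train_microbatch_size // pp_size."""
        pp_size = plugin_config.get("pp_size", 1)
        microbatch_size = plugin_config.get("microbatch_size")
        if pp_size > 1 and microbatch_size is not None:
            expected = max(1, train_microbatch_size // pp_size)
            if microbatch_size != expected:
                raise ValueError(
                    f"microbatch_size={microbatch_size}, expected {expected} "
                    f"(= train_microbatch_size {train_microbatch_size} // pp_size {pp_size})"
                )

    # Matches the config enabled in the hunk, taking train_microbatch_size = 4:
    check_pp_config({"tp_size": 2, "pp_size": 2, "microbatch_size": 2, "zero_stage": 0}, 4)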