fixed plugin micro-batch size

Author: YeAnbang
Date:   2025-04-28 16:18:50 +08:00
Parent: 0f794f7294
Commit: 263a9cbe7a

3 changed files with 16 additions and 12 deletions

@@ -71,7 +71,7 @@ class BaseConsumer:
             and "num_microbatches" not in self.plugin_config
             and "microbatch_size" not in self.plugin_config
         ):
-            plugin_config["microbatch_size"] = self.minibatch_size
+            plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
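
With pipeline parallelism enabled, HybridParallelPlugin chops each training mini-batch into micro-batches of microbatch_size, so the old default (the full minibatch_size as a single micro-batch) leaves a multi-stage pipeline with only one chunk to schedule. Below is a minimal sketch of the arithmetic behind the new default, not the repository's code, assuming (as in a typical 1F1B schedule) that the number of micro-batches per step is minibatch_size // microbatch_size; the values are illustrative.

def default_microbatch_size(minibatch_size: int, plugin_config: dict) -> int:
    # New default: one micro-batch per pipeline stage (never smaller than 1).
    pp_size = plugin_config.get("pp_size", 1)
    return max(1, minibatch_size // pp_size)

# Example: minibatch_size=8, pp_size=2
#   old default -> microbatch_size=8, i.e. 1 micro-batch (pipeline mostly idle)
#   new default -> microbatch_size=4, i.e. 2 micro-batches, one per stage in flight
assert default_microbatch_size(8, {"pp_size": 2}) == 4
assert default_microbatch_size(8, {}) == 8  # without pipeline parallelism the default is unchanged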


@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
             and "num_microbatches" not in plugin_config
             and "microbatch_size" not in plugin_config
         ):
-            plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
+            plugin_config["microbatch_size"] = max(
+                1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
+            )
         super().__init__(
             num_producers,
             num_episodes,
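
Same idea in GRPOConsumer: the previous default hard-coded a divisor of 2, which only matches a two-stage pipeline, while the fix divides by whatever pp_size the plugin_config actually requests. A worked example with hypothetical values (train_microbatch_size=8, pp_size=4):

train_microbatch_size, pp_size = 8, 4  # hypothetical values
# old: hard-coded divisor of 2 -> microbatch_size=4 -> only 2 micro-batches for 4 stages
assert max(1, train_microbatch_size // 2) == 4
# new: divide by the configured pp_size -> microbatch_size=2 -> 4 micro-batches, one per stage
assert max(1, train_microbatch_size // pp_size) == 2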


@@ -209,17 +209,19 @@ if __name__ == "__main__":
         num_generations=args.num_generations,
         train_model_config=train_model_config,
         grpo_config=grpo_config,
-        plugin_config={
-            "zero_stage": 2,
-        },  # for zero
-        # currently not support tp/pp
         # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # },  # for pp
+        #     "zero_stage": 2,
+        # },  # for zero
+        # currently not support tp/pp
+        plugin_config={
+            "tp_size": 2,
+            "pp_size": 2,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // 2
+            ),  # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": 0,
+            "max_norm": 1.0,
+        },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
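
The example script now exercises the tp/pp configuration directly; as the first hunk shows, BaseConsumer simply unpacks this dict into HybridParallelPlugin and wraps it in a Booster. Below is a minimal standalone sketch of that setup, not the repository's code, with illustrative values: it assumes a distributed launch such as torchrun --nproc_per_node=4 (tp_size=2 times pp_size=2 needs four ranks), a stand-in train_microbatch_size of 4, and the no-argument launch_from_torch signature, which varies slightly across ColossalAI versions.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # older releases may require a config argument

tp_size, pp_size = 2, 2
train_microbatch_size = 4  # stand-in for args.train_microbatch_size
plugin_config = {
    "tp_size": tp_size,
    "pp_size": pp_size,
    # one micro-batch per pipeline stage, per the inline comment in the diff
    "microbatch_size": max(1, train_microbatch_size // pp_size),
    "zero_stage": 0,  # the example disables ZeRO when pipeline parallelism is active
    "max_norm": 1.0,  # gradient clipping threshold
}
plugin = HybridParallelPlugin(**plugin_config)
booster = Booster(plugin=plugin)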