fix pp+tp, fix dataloader (#6280)

This commit is contained in:
YeAnbang
2025-04-28 17:10:00 +08:00
committed by GitHub
parent 28795f560c
commit 2ca1e3c630
5 changed files with 17 additions and 8 deletions

View File

@@ -66,7 +66,11 @@ class BaseConsumer:
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
if (
self.plugin_config.get("pp_size", 1) > 1
and "num_microbatches" not in self.plugin_config
and "microbatch_size" not in self.plugin_config
):
plugin_config["microbatch_size"] = self.minibatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)