mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-17 07:26:29 +00:00
fixed plugin micro-batch size
This commit is contained in:
parent
0f794f7294
commit
263a9cbe7a
@ -71,7 +71,7 @@ class BaseConsumer:
|
|||||||
and "num_microbatches" not in self.plugin_config
|
and "num_microbatches" not in self.plugin_config
|
||||||
and "microbatch_size" not in self.plugin_config
|
and "microbatch_size" not in self.plugin_config
|
||||||
):
|
):
|
||||||
plugin_config["microbatch_size"] = self.minibatch_size
|
plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
|
||||||
plugin_config.update(self.plugin_config)
|
plugin_config.update(self.plugin_config)
|
||||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||||
self.booster = Booster(plugin=self.plugin)
|
self.booster = Booster(plugin=self.plugin)
|
||||||
|
@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
and "num_microbatches" not in plugin_config
|
and "num_microbatches" not in plugin_config
|
||||||
and "microbatch_size" not in plugin_config
|
and "microbatch_size" not in plugin_config
|
||||||
):
|
):
|
||||||
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
|
plugin_config["microbatch_size"] = max(
|
||||||
|
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
|
||||||
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_producers,
|
num_producers,
|
||||||
num_episodes,
|
num_episodes,
|
||||||
|
@ -209,17 +209,19 @@ if __name__ == "__main__":
|
|||||||
num_generations=args.num_generations,
|
num_generations=args.num_generations,
|
||||||
train_model_config=train_model_config,
|
train_model_config=train_model_config,
|
||||||
grpo_config=grpo_config,
|
grpo_config=grpo_config,
|
||||||
plugin_config={
|
|
||||||
"zero_stage": 2,
|
|
||||||
}, # for zero
|
|
||||||
# currently not support tp/pp
|
|
||||||
# plugin_config={
|
# plugin_config={
|
||||||
# "tp_size": 2,
|
# "zero_stage": 2,
|
||||||
# "pp_size": 2,
|
# }, # for zero
|
||||||
# "microbatch_size": max(1, args.train_microbatch_size // 2),
|
# currently not support tp/pp
|
||||||
# "zero_stage": 0,
|
plugin_config={
|
||||||
# "max_norm": 1.0,
|
"tp_size": 2,
|
||||||
# }, # for pp
|
"pp_size": 2,
|
||||||
|
"microbatch_size": max(
|
||||||
|
1, args.train_microbatch_size // 2
|
||||||
|
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||||
|
"zero_stage": 0,
|
||||||
|
"max_norm": 1.0,
|
||||||
|
}, # for pp
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr="localhost",
|
master_addr="localhost",
|
||||||
master_port=args.master_port,
|
master_port=args.master_port,
|
||||||
|
Loading…
Reference in New Issue
Block a user