mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 06:05:26 +00:00
fixed plugin micro-batch size
This commit is contained in:
parent
0f794f7294
commit
263a9cbe7a
@@ -71,7 +71,7 @@ class BaseConsumer:
|
||||
and "num_microbatches" not in self.plugin_config
|
||||
and "microbatch_size" not in self.plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = self.minibatch_size
|
||||
plugin_config["microbatch_size"] = max(1, self.minibatch_size // plugin_config.get("pp_size", 1))
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
|
@@ -54,7 +54,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
and "num_microbatches" not in plugin_config
|
||||
and "microbatch_size" not in plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
|
||||
plugin_config["microbatch_size"] = max(
|
||||
1, grpo_config.get("train_microbatch_size") // plugin_config.get("pp_size", 1)
|
||||
)
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
|
@@ -209,17 +209,19 @@ if __name__ == "__main__":
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
grpo_config=grpo_config,
|
||||
plugin_config={
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# currently not support tp/pp
|
||||
# plugin_config={
|
||||
# "tp_size": 2,
|
||||
# "pp_size": 2,
|
||||
# "microbatch_size": max(1, args.train_microbatch_size // 2),
|
||||
# "zero_stage": 0,
|
||||
# "max_norm": 1.0,
|
||||
# }, # for pp
|
||||
# "zero_stage": 2,
|
||||
# }, # for zero
|
||||
# currently not support tp/pp
|
||||
plugin_config={
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"microbatch_size": max(
|
||||
1, args.train_microbatch_size // 2
|
||||
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||
"zero_stage": 0,
|
||||
"max_norm": 1.0,
|
||||
}, # for pp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=args.master_port,
|
||||
|
Loading…
Reference in New Issue
Block a user