[feat] Support prompt level dynamic (#6300)

* adjust to dynamic prompt bs

* remove debug

* update pad seq (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* adjust to dynamic prompt bs

* remove debug

* fix dp issue

* fix

* fix default settings

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li
2025-05-14 16:40:35 +08:00
committed by GitHub
parent b920af427b
commit aca547623f
4 changed files with 123 additions and 93 deletions

View File

@@ -9,7 +9,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
# Distributed training parameters
@@ -20,7 +20,7 @@ if __name__ == "__main__":
"-ibs",
"--inference-batch-size",
type=int,
default=None,
default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
@@ -41,7 +41,7 @@ if __name__ == "__main__":
"-tMbs",
"--train-minibatch-size",
type=int,
default=None,
default=8,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
@@ -58,7 +58,7 @@ if __name__ == "__main__":
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
parser.add_argument(
"--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
)
# Sampling parameters
@@ -223,7 +223,7 @@ if __name__ == "__main__":
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 1,
# "tp_size": 2,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2