support n_behind, add profiling

This commit is contained in:
YeAnbang
2025-06-20 03:14:00 +00:00
parent e3d56cbd86
commit ff6696a9bb
8 changed files with 233 additions and 29 deletions

View File

@@ -67,6 +67,27 @@ if __name__ == "__main__":
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
@@ -97,6 +118,13 @@ if __name__ == "__main__":
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
@@ -117,6 +145,13 @@ if __name__ == "__main__":
default=100,
help="Interval for evaluation. Evaluate every ei training steps.",
)
parser.add_argument(
"-nb",
"--n-behind",
type=int,
default=0,
help="Number of producer batches to rollout to fill the data buffer before trainer starts to decrease bubble time",
)
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -128,32 +163,7 @@ if __name__ == "__main__":
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
)
args = parser.parse_args()
@@ -353,4 +363,6 @@ if __name__ == "__main__":
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_save_dir=args.rollout_save_dir,
enable_profiling=args.enable_profiling,
n_behind=args.n_behind,
)