mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 21:03:13 +00:00
support n_behind, add profiling
This commit is contained in:
@@ -67,6 +67,27 @@ if __name__ == "__main__":
|
||||
default=2,
|
||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pp",
|
||||
"--pipeline-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-zero",
|
||||
"--zero-stage",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
@@ -97,6 +118,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
|
||||
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
|
||||
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
||||
parser.add_argument(
|
||||
"-ptp",
|
||||
"--producer-tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
|
||||
# GRPO parameters
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
|
||||
@@ -117,6 +145,13 @@ if __name__ == "__main__":
|
||||
default=100,
|
||||
help="Interval for evaluation. Evaluate every ei training steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-nb",
|
||||
"--n-behind",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of producer batches to rollout to fill the data buffer before trainer starts to decrease bubble time",
|
||||
)
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
@@ -128,32 +163,7 @@ if __name__ == "__main__":
|
||||
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-tp",
|
||||
"--tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pp",
|
||||
"--pipeline-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-zero",
|
||||
"--zero-stage",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ptp",
|
||||
"--producer-tensor-parallel-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
|
||||
"--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -353,4 +363,6 @@ if __name__ == "__main__":
|
||||
eval_generation_config=eval_generation_config,
|
||||
log_rollout_interval=20,
|
||||
rollout_save_dir=args.rollout_save_dir,
|
||||
enable_profiling=args.enable_profiling,
|
||||
n_behind=args.n_behind,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user