fix metric calculation

This commit is contained in:
YeAnbang
2025-05-20 18:14:05 +08:00
parent 116621d004
commit 37663386bc
4 changed files with 147 additions and 48 deletions

View File

@@ -121,6 +121,34 @@ if __name__ == "__main__":
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-ptp",
"--produce-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@@ -178,7 +206,7 @@ if __name__ == "__main__":
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
tensor_parallel_size=args.produce_tensor_parallel_size,
)
)
generate_config.update(
@@ -228,7 +256,7 @@ if __name__ == "__main__":
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
@@ -247,17 +275,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 2,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp, tp
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
@@ -273,5 +298,5 @@ if __name__ == "__main__":
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
rollout_save_dir=args.rollout_save_dir,
)