[feat] Support boxed math reward (#6284)

* fix pp+tp, fix dataloader

* fix plugin micro-batch size

* support boxed reward

* add boxed reward

* fix pp state dict incomplete issue

* Revert "fix pp state dict incomplete issue"

This reverts commit 6c1b3b694f.
Author: YeAnbang
Date: 2025-04-29 16:46:47 +08:00
Committed by: GitHub
Parent: 2ca1e3c630
Commit: 14f237ce7e
5 changed files with 118 additions and 12 deletions
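
For context on the new "boxed" option wired up below: the "think_answer_tags" reward parses the answer out of <think>...</think><answer>...</answer> tags (hence the "</answer>" stop string in the diff), while the boxed reward extracts a LaTeX \boxed{...} expression, the usual convention for math datasets. The following is a minimal sketch of a boxed-style reward, not the repository's implementation; the function name and the plain string comparison are assumptions (a real verifier would normalize or symbolically compare the LaTeX).

import re

def boxed_math_reward(completion: str, gt_answer: str) -> float:
    """Toy boxed reward: 1.0 iff the last \\boxed{...} matches the ground truth."""
    # Take the last \boxed{...} span as the model's final answer.
    matches = re.findall(r"\\boxed\{([^{}]*)\}", completion)
    if not matches:
        return 0.0  # no boxed answer, no reward
    # Naive exact match; real verifiers normalize the LaTeX first.
    return 1.0 if matches[-1].strip() == gt_answer.strip() else 0.0

# e.g. boxed_math_reward("2 + 2 = 4, so \\boxed{4}.", "4") -> 1.0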

@@ -86,6 +86,14 @@ if __name__ == "__main__":
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+    parser.add_argument(
+        "-rt",
+        "--reward-type",
+        type=str,
+        default="think_answer_tags",
+        choices=["think_answer_tags", "boxed"],
+        help="Reward type for GRPO.",
+    )
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
@@ -136,8 +144,8 @@ if __name__ == "__main__":
                 max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"],
+                early_stopping=False if args.reward_type == "think_answer_tags" else True,
+                stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     elif args.backend == "vllm":
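
The transformers-backend change makes the stop conditions reward-dependent: tag-based rewards halt decoding at "</answer>", while boxed rewards let decoding run to EOS or max_new_tokens. Roughly, the two completion shapes these settings anticipate look like this (strings are purely illustrative):

# With reward_type="think_answer_tags", decoding halts at the stop string:
tagged = "<think>2 + 2 = 4</think><answer>4</answer>"
# With reward_type="boxed", there is no stop string; the model ends at EOS
# (or max_new_tokens) and the reward later parses the \boxed{...} span:
boxed = "Adding the numbers gives 2 + 2 = 4, so the answer is \\boxed{4}."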
@@ -153,9 +161,9 @@ if __name__ == "__main__":
         generate_config.update(
             dict(
                 max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True,
+                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
-                stop=["</answer>"],
+                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     else:
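
On the vLLM side, these keys correspond one-to-one to fields of vllm.SamplingParams. A self-contained sketch of the equivalent construction (the helper name is hypothetical):

from vllm import SamplingParams

def make_sampling_params(reward_type: str, max_new_tokens: int) -> SamplingParams:
    tag_based = reward_type == "think_answer_tags"
    return SamplingParams(
        max_tokens=max_new_tokens,
        ignore_eos=tag_based,  # tag rewards keep decoding past EOS until the tag closes
        include_stop_str_in_output=True,  # keep "</answer>" in the text for parsing
        stop=["</answer>"] if tag_based else None,
    )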
@@ -168,6 +176,7 @@ if __name__ == "__main__":
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
             "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -185,6 +194,7 @@ if __name__ == "__main__":
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
+            "reward_fn_type": args.reward_type,
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
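
Downstream, the new "reward_fn_type" entry presumably selects the reward implementation inside the trainer. A hypothetical dispatch to make the intent concrete (names and stubs are illustrative, not the repository's code):

def tag_based_reward(completion: str, gt_answer: str) -> float:
    ...  # parse <answer>...</answer> and compare

def boxed_math_reward(completion: str, gt_answer: str) -> float:
    ...  # parse \boxed{...} and compare (see the sketch above)

REWARD_FNS = {
    "think_answer_tags": tag_based_reward,
    "boxed": boxed_math_reward,
}

def get_reward_fn(reward_fn_type: str):
    return REWARD_FNS[reward_fn_type]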
@@ -212,14 +222,15 @@ if __name__ == "__main__":
         plugin_config={
             "zero_stage": 2,
         },  # for zero
-        # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
         #     "pp_size": 2,
-        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
+        #     "microbatch_size": max(
+        #         1, args.train_microbatch_size // 2
+        #     ),  # microbatch size should be set to train_microbatch_size // pp_size
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
-        # },  # for pp
+        # },  # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
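
The amended comment in the commented-out pp/tp block records a sizing rule: with pipeline parallelism, the plugin's microbatch_size should be train_microbatch_size // pp_size, presumably so a train micro-batch splits evenly into pp_size pipeline chunks. A quick worked example (values are illustrative):

train_microbatch_size = 8  # samples per training micro-batch
pp_size = 2                # pipeline stages
microbatch_size = max(1, train_microbatch_size // pp_size)
assert microbatch_size == 4  # each pipeline chunk carries 4 samples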