[feat] Support boxed math reward (#6284)
* fix pp+tp, fix dataloader
* fixed plugin micro-batch size
* support boxed reward
* add boxed reward
* fix pp state dict incomplete issue
* Revert "fix pp state dict incomplete issue"
This reverts commit 6c1b3b694f.
@@ -86,6 +86,14 @@ if __name__ == "__main__":
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
+    parser.add_argument(
+        "-rt",
+        "--reward-type",
+        type=str,
+        default="think_answer_tags",
+        choices=["think_answer_tags", "boxed"],
+        help="Reward type for GRPO.",
+    )
 
     # Logging/Checkpointing parameters
     parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
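Context for the new flag: a "boxed" reward scores the final answer wrapped in LaTeX \boxed{...} (the usual convention in math datasets), instead of requiring <think>/<answer> tags. A minimal sketch of the extraction step such a reward implies; the function name and brace handling here are illustrative assumptions, not the repository's implementation:

def extract_boxed_answer(text):
    r"""Return the contents of the last \boxed{...} in text, or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    depth = 0
    for i in range(start + 6, len(text)):  # start + 6 is the opening brace
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[start + 7 : i]  # text between the outer braces
    return None  # unbalanced braces: treat as no answer

print(extract_boxed_answer("Thus x = 7, so the answer is \\boxed{7}."))  # -> 7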
@@ -136,8 +144,8 @@ if __name__ == "__main__":
                 max_length=args.max_new_tokens + args.max_prompt_tokens,
                 do_sample=True,
                 max_new_tokens=None,
-                early_stopping=False,
-                stop_strings=["</answer>"],
+                early_stopping=False if args.reward_type == "think_answer_tags" else True,
+                stop_strings=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     elif args.backend == "vllm":
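The transformers branch above stops generation at "</answer>" only for the tag-based reward; a boxed completion has no closing tag to stop on, so it runs until EOS or the length cap. For reference, the tag format the think_answer_tags reward expects is along these lines (regex and function name are illustrative, not the repository's code):

import re

def extract_tagged_answer(completion):
    """Pull the text between <answer> tags; None if the tags are missing."""
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return match.group(1).strip() if match else None

print(extract_tagged_answer("<think>2+2=4</think><answer>4</answer>"))  # -> 4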
@@ -153,9 +161,9 @@ if __name__ == "__main__":
         generate_config.update(
             dict(
                 max_tokens=args.max_new_tokens,  # max new tokens
-                ignore_eos=True,
+                ignore_eos=True if args.reward_type == "think_answer_tags" else False,
                 include_stop_str_in_output=True,
-                stop=["</answer>"],
+                stop=["</answer>"] if args.reward_type == "think_answer_tags" else None,
             )
         )
     else:
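The keys in the vLLM branch mirror fields of vllm.SamplingParams: ignore_eos=True is what keeps tag-based generation running until "</answer>" (or the token budget), while boxed generation terminates normally at EOS. A standalone sketch of the equivalent sampling parameters, assuming vLLM is installed and with placeholder values standing in for args:

from vllm import SamplingParams

reward_type = "boxed"   # stands in for args.reward_type
max_new_tokens = 1024   # stands in for args.max_new_tokens

sampling_params = SamplingParams(
    max_tokens=max_new_tokens,
    ignore_eos=(reward_type == "think_answer_tags"),
    include_stop_str_in_output=True,
    stop=["</answer>"] if reward_type == "think_answer_tags" else None,
)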
@@ -168,6 +176,7 @@ if __name__ == "__main__":
         "train_microbatch_size": args.train_microbatch_size,
         "beta": args.kl_coeff,  # KL penalty coefficient
         "loss_variation": "sample_level",
+        "reward_fn_type": args.reward_type,
     }
 elif args.algo == "DAPO":
     # DAPO variant settings
@@ -185,6 +194,7 @@ if __name__ == "__main__":
         "max_length": args.max_new_tokens + args.max_prompt_tokens,
         "cache_length": min(1024, int(args.max_new_tokens / 4)),
         "filter_truncated_response": True,
+        "reward_fn_type": args.reward_type,
     }
 else:
     raise ValueError(f"Unsupported algorithm: {args.algo}")
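Both algorithm configs now carry "reward_fn_type", so the consumer presumably dispatches on it to pick the scoring function. A hypothetical wiring, reusing the extractors sketched above (all names here are assumptions for illustration, not the repository's API):

def make_reward_fn(reward_fn_type):
    """Map the config key to a 0/1 exact-match reward (illustrative only)."""
    extractors = {
        "think_answer_tags": extract_tagged_answer,
        "boxed": extract_boxed_answer,
    }
    try:
        extract = extractors[reward_fn_type]
    except KeyError:
        raise ValueError(f"Unsupported reward type: {reward_fn_type}")

    def reward_fn(completion, gt_answer):
        pred = extract(completion)
        return 1.0 if pred is not None and pred.strip() == gt_answer.strip() else 0.0

    return reward_fn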
@@ -212,14 +222,15 @@ if __name__ == "__main__":
         plugin_config={
             "zero_stage": 2,
         }, # for zero
-        # currently not support tp/pp
         # plugin_config={
         #     "tp_size": 2,
         #     "pp_size": 2,
-        #     "microbatch_size": max(1, args.train_microbatch_size // 2),
+        #     "microbatch_size": max(
+        #         1, args.train_microbatch_size // 2
+        #     ), # microbatch size should be set to train_microbatch_size // pp_size
         #     "zero_stage": 0,
         #     "max_norm": 1.0,
-        # }, # for pp
+        # }, # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
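The commented-out block documents the intended pp/tp settings; per its comment, the pipeline microbatch size is the train microbatch size divided by the number of pipeline stages. A worked example with illustrative values:

train_microbatch_size = 8   # illustrative value
pp_size = 2                 # number of pipeline stages

plugin_config = {
    "tp_size": 2,
    "pp_size": pp_size,
    # 8 // 2 = 4 sequences per pipeline microbatch
    "microbatch_size": max(1, train_microbatch_size // pp_size),
    "zero_stage": 0,  # the pp/tp template above pairs pp with ZeRO stage 0
    "max_norm": 1.0,
}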