Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-11-29 22:37:14 +00:00
Add new implementations of RL algorithms (#6383)
* add new algorithm
* move common calculations
* delete data
* move common calculations of rewards
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -137,7 +137,7 @@ if __name__ == "__main__":
     )

     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
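The first hunk only widens the `--algo` choices; argparse rejects any other value at launch. A minimal standalone sketch of the same parser fragment (the example input is illustrative, not from this commit):

import argparse

# Same choices as the updated line above; argparse exits with an error
# for any value outside this list.
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])

args = parser.parse_args(["-a", "RLOO"])  # e.g. launched with `-a RLOO`
print(args.algo)  # -> RLOO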
@@ -292,6 +292,7 @@ if __name__ == "__main__":
     if args.algo == "GRPO":
         # Default Settings
         grpo_config = {
+            "algo": "GRPO",
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
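The `beta` entry is described in the diff only as the KL penalty coefficient. For orientation, GRPO-style trainers commonly weight a per-token KL estimate against a frozen reference model by `beta`; a minimal sketch using the common k3 estimator (an assumption about usage, not code from this commit):

import torch

def kl_penalty(logp: torch.Tensor, ref_logp: torch.Tensor, beta: float) -> torch.Tensor:
    # logp / ref_logp: log-probabilities of the sampled tokens under the
    # current policy and the frozen reference model.
    log_ratio = ref_logp - logp
    kl = torch.exp(log_ratio) - log_ratio - 1.0  # k3 estimator, always >= 0
    return beta * kl  # added to the loss as a penalty term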
@@ -313,6 +314,7 @@ if __name__ == "__main__":
     elif args.algo == "DAPO":
         # DAPO variant settings
         grpo_config = {
+            "algo": "DAPO",
             "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
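The DAPO branch carries the one key the other algorithms lack: `filter_range`, whose inline comment says it only filters out all-zero and all-one batches. With binary rewards, such a group gives every response the same reward, so every group-relative advantage vanishes. A minimal sketch of that filter (function name hypothetical, not from this commit):

import torch

def keep_group(rewards: torch.Tensor, filter_range=(0.01, 0.99)) -> bool:
    # With 0/1 rewards, a group mean outside (0.01, 0.99) means the group
    # is all-wrong or all-right and contributes no learning signal.
    mean_reward = rewards.float().mean().item()
    return filter_range[0] <= mean_reward <= filter_range[1]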
@@ -339,6 +341,50 @@ if __name__ == "__main__":
                 else None
             ),
         }
+    elif args.algo == "REINFORCE_PPB":
+        # Default Settings
+        grpo_config = {
+            "algo": "REINFORCE_PPB",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
+    elif args.algo == "RLOO":
+        # Default Settings
+        grpo_config = {
+            "algo": "RLOO",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
     if args.reward_type == "code":
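The two new branches reuse the GRPO config shape verbatim (same `beta`, `loss_variation`, and `response_format_tags` handling) and differ only in the `algo` key; per the commit message, the shared reward calculations were moved into common code, which suggests the algorithm-specific baselines live there rather than in these config dicts. As background, RLOO is conventionally defined by a leave-one-out baseline over the k sampled responses per prompt; a minimal sketch under that assumption (not this repo's implementation):

import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards: shape (k,), one scalar reward per sampled response.
    # Each response is baselined by the mean reward of the other k - 1:
    #   A_i = r_i - mean_{j != i} r_j
    k = rewards.numel()
    assert k > 1, "leave-one-out needs at least two samples per prompt"
    loo_baseline = (rewards.sum() - rewards) / (k - 1)
    return rewards - loo_baseline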