From e08626d740a8f2cb3dac56ec08f97ed11cfbd85e Mon Sep 17 00:00:00 2001
From: Chen Li
Date: Tue, 13 May 2025 16:08:38 +0800
Subject: [PATCH] register agpo

---
 .../ColossalChat/coati/distributed/launch.py |  2 +-
 applications/ColossalChat/rl_example.py      | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a346d1d4f..b64529593 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -7,7 +7,7 @@ from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "AGPO": GRPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 788e60c2e..d6565b922 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -83,7 +83,7 @@ if __name__ == "__main__":
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
 
     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["GRPO", "DAPO", "AGPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
@@ -196,6 +196,20 @@ if __name__ == "__main__":
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
         }
+    elif args.algo == "AGPO":
+        # AGPO variant settings
+        grpo_config = {
+            "filter_range": [0.01, 0.99],
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "dynamic_batching": True,
+            "clip_eps_low": 0.2,
+            "clip_eps_high": 0.28,
+            "beta": 0,  # no KL penalty
+            "loss_variation": "token_level",
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "reward_fn_type": args.reward_type,
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
 
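
Note: "AGPO" becomes a third key mapping to GRPOConsumer, so the variant
differs from GRPO and DAPO only through the grpo_config built in
rl_example.py (clip_eps_low/high of 0.2/0.28, beta=0 to disable the KL
penalty, and token-level loss). A minimal sketch of the dispatch this
registration enables, assuming launch.py resolves the name via
ALGO_MAP.get (as the map's role suggests) and that the package is
importable as coati.distributed:

    # Sketch only: checks that the new "AGPO" key reuses the GRPO consumer.
    # Import paths follow the ColossalChat layout of the files patched above.
    from coati.distributed.grpo_consumer import GRPOConsumer
    from coati.distributed.launch import ALGO_MAP

    consumer_cls = ALGO_MAP.get("AGPO")  # key added by this patch
    assert consumer_cls is GRPOConsumer  # AGPO shares the GRPO consumer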