register agpo

This commit is contained in:
Chen Li 2025-05-13 16:08:38 +08:00
parent eb6b5dd62e
commit e08626d740
2 changed files with 16 additions and 2 deletions

View File

@ -7,7 +7,7 @@ from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "AGPO": GRPOConsumer}
def get_jsonl_size_fast(path: str) -> int:

View File

@ -83,7 +83,7 @@ if __name__ == "__main__":
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["GRPO", "DAPO", "AGPO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@ -196,6 +196,20 @@ if __name__ == "__main__":
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
}
elif args.algo == "AGPO":
# AGPO variant settings
grpo_config = {
"filter_range": [0.01, 0.99],
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"dynamic_batching": True,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"beta": 0, # no KL penalty
"loss_variation": "token_level",
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"reward_fn_type": args.reward_type,
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")