mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 06:05:26 +00:00

register agpo

This commit is contained in:
parent eb6b5dd62e
commit e08626d740
@@ -7,7 +7,7 @@ from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "AGPO": GRPOConsumer}
 
 
 def get_jsonl_size_fast(path: str) -> int:
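For context, ALGO_MAP is the registry used to resolve the chosen algorithm name to a consumer class, so adding the "AGPO" entry here is what makes the new choice routable. Note that AGPO reuses GRPOConsumer; the variant is differentiated only by the grpo_config built in the example script (third hunk below). A minimal sketch of the lookup, with the exact call site not shown in this diff, so names are assumptions:

core_consumer = ALGO_MAP.get(args.algo)  # e.g. GRPOConsumer for "AGPO"
if core_consumer is None:
    raise NotImplementedError(f"Algorithm {args.algo} is not registered")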
@@ -83,7 +83,7 @@ if __name__ == "__main__":
     parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
 
     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["GRPO", "DAPO", "AGPO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
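With the expanded choices list, the new algorithm can be selected directly from the command line, e.g. (the example script's name is assumed, not shown in this diff):

python rl_example.py --algo AGPO

Since the default remains "GRPO", existing launch commands are unaffected.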
@@ -196,6 +196,20 @@ if __name__ == "__main__":
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
         }
+    elif args.algo == "AGPO":
+        # AGPO variant settings
+        grpo_config = {
+            "filter_range": [0.01, 0.99],
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "dynamic_batching": True,
+            "clip_eps_low": 0.2,
+            "clip_eps_high": 0.28,
+            "beta": 0,  # no KL penalty
+            "loss_variation": "token_level",
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "reward_fn_type": args.reward_type,
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
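The AGPO block above closely mirrors DAPO-style settings: an asymmetric clip range (0.2 low, 0.28 high), beta=0 (no KL penalty), and a token-level loss. As a rough illustration of what those three knobs control inside the consumer, here is a sketch of a token-level clipped policy loss; tensor names and the exact reduction are assumptions for illustration, not code from this commit:

import torch

def token_level_clipped_loss(log_probs, old_log_probs, advantages, loss_mask,
                             clip_eps_low=0.2, clip_eps_high=0.28):
    # Per-token importance ratio between the current and behavior policies.
    ratio = torch.exp(log_probs - old_log_probs)
    # Asymmetric ("clip higher") bounds: a looser upper bound lets the policy
    # raise the probability of promising low-probability tokens more aggressively.
    clipped = torch.clamp(ratio, 1.0 - clip_eps_low, 1.0 + clip_eps_high)
    per_token = -torch.min(ratio * advantages, clipped * advantages)
    # "token_level" loss variation: average over every valid token in the batch,
    # so long responses are not down-weighted relative to short ones.
    return (per_token * loss_mask).sum() / loss_mask.sum().clamp(min=1)

With beta set to 0, no KL term is added to this objective, matching the "no KL penalty" comment in the config.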