mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-16 23:16:56 +00:00
register agpo
This commit is contained in:
parent
eb6b5dd62e
commit
e08626d740
@ -7,7 +7,7 @@ from .consumer import SimpleConsumer
|
|||||||
from .grpo_consumer import GRPOConsumer
|
from .grpo_consumer import GRPOConsumer
|
||||||
from .producer import SimpleProducer
|
from .producer import SimpleProducer
|
||||||
|
|
||||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
|
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer, "AGPO": GRPOConsumer}
|
||||||
|
|
||||||
|
|
||||||
def get_jsonl_size_fast(path: str) -> int:
|
def get_jsonl_size_fast(path: str) -> int:
|
||||||
|
@ -83,7 +83,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
||||||
|
|
||||||
# GRPO parameters
|
# GRPO parameters
|
||||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
|
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["GRPO", "DAPO", "AGPO"])
|
||||||
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
|
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
|
||||||
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
|
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -196,6 +196,20 @@ if __name__ == "__main__":
|
|||||||
"filter_truncated_response": True,
|
"filter_truncated_response": True,
|
||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
}
|
}
|
||||||
|
elif args.algo == "AGPO":
|
||||||
|
# AGPO variant settings
|
||||||
|
grpo_config = {
|
||||||
|
"filter_range": [0.01, 0.99],
|
||||||
|
"lr": args.learning_rate,
|
||||||
|
"train_microbatch_size": args.train_microbatch_size,
|
||||||
|
"dynamic_batching": True,
|
||||||
|
"clip_eps_low": 0.2,
|
||||||
|
"clip_eps_high": 0.28,
|
||||||
|
"beta": 0, # no KL penalty
|
||||||
|
"loss_variation": "token_level",
|
||||||
|
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||||
|
"reward_fn_type": args.reward_type,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user