Add new implementations of RL algorithms (#6383)

* add new algorithm

* move common calculations

* delete data

* move common calculations of rewards

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
sglucas
2025-09-03 13:48:06 +08:00
committed by GitHub
parent 48a673dcb0
commit 083766d54c
5 changed files with 135 additions and 7 deletions

View File

@@ -101,6 +101,7 @@ class GRPOConsumer(BaseConsumer):
clip_eps_high=grpo_config.get("clip_eps_high", 0.2), clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01), beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"), loss_variation=grpo_config.get("loss_variation", "sample_level"),
adv=grpo_config.get("algo"),
) )
# Reference model is initialized from policy model. # Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ class GRPOConsumer(BaseConsumer):
eta_min=0.1 * grpo_config.get("lr", 1e-6), eta_min=0.1 * grpo_config.get("lr", 1e-6),
) )
self.adv = grpo_config.get("algo")
def setup(self): def setup(self):
super().setup() super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or ( if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -204,9 +207,23 @@ class GRPOConsumer(BaseConsumer):
# [minibatch_size x num_generations] # [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0) reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0) if self.adv == "GRPO" or self.adv == "DAPO":
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1) reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
elif self.adv == "REINFORCE_PPB":
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean)).unsqueeze(dim=-1)
elif self.adv == "RLOO":
advantages = (
reward * self.num_generations / (self.num_generations - 1)
- reward_mean * self.num_generations / (self.num_generations - 1)
).unsqueeze(dim=-1)
# [minibatch_size x num_of_generation] # [minibatch_size x num_of_generation]
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool() loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
@@ -358,10 +375,34 @@ class GRPOConsumer(BaseConsumer):
per_token_kl = 0.0 per_token_kl = 0.0
kl.append(torch.tensor(0.0)) kl.append(torch.tensor(0.0))
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
if self.adv == "REINFORCE_PPB":
inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
advantages_forward_micro_batch_mean = torch.sum(
inputs["advantages"] * inputs["action_mask"]
) / (torch.sum(inputs["action_mask"]) + 1e-4)
advantages_forward_micro_batch_std = torch.rsqrt(
torch.sum(
(inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
* inputs["action_mask"]
)
/ (torch.sum(inputs["action_mask"]) + 1e-4)
+ 1e-8
)
inputs["advantages"] = (
(inputs["advantages"] - advantages_forward_micro_batch_mean)
* inputs["action_mask"]
/ (advantages_forward_micro_batch_std)
)
per_token_kl = 0.0
loss, _ = self.policy_loss_fn( loss, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,
inputs["old_action_log_probs"], inputs["old_action_log_probs"],
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), inputs["advantages"],
per_token_kl, per_token_kl,
inputs["action_mask"], inputs["action_mask"],
loss_mask=inputs["loss_mask"], loss_mask=inputs["loss_mask"],
@@ -420,10 +461,39 @@ class GRPOConsumer(BaseConsumer):
per_token_kl = 0.0 per_token_kl = 0.0
kl = None kl = None
(
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1)
- self.policy_loss_fn.beta * per_token_kl
)
if self.adv == "REINFORCE_PPB":
advantages_forward_micro_batch = (
advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
)
advantages_forward_micro_batch_mean = torch.sum(
advantages_forward_micro_batch * action_mask_forward_micro_batch
) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
advantages_forward_micro_batch_std = torch.rsqrt(
torch.sum(
(advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
* action_mask_forward_micro_batch
)
/ (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+ 1e-8
)
advantages_forward_micro_batch = (
(advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
* action_mask_forward_micro_batch
/ (advantages_forward_micro_batch_std)
)
per_token_kl = 0.0
loss, _ = self.policy_loss_fn( loss, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,
old_action_log_probs_micro_batch, old_action_log_probs_micro_batch,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), advantages_forward_micro_batch,
per_token_kl, per_token_kl,
action_mask_forward_micro_batch, action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch, loss_mask=loss_mask_forward_micro_batch,

View File

@@ -9,7 +9,13 @@ from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer from .producer import SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer} ALGO_MAP = {
"Simple": SimpleConsumer,
"GRPO": GRPOConsumer,
"DAPO": GRPOConsumer,
"REINFORCE_PPB": GRPOConsumer,
"RLOO": GRPOConsumer,
}
def get_jsonl_size_fast(path: str) -> int: def get_jsonl_size_fast(path: str) -> int:
@@ -66,6 +72,7 @@ def launch_distributed(
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer) core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config) train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0 assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
dataset_path = train_dataset_config["path"] dataset_path = train_dataset_config["path"]

View File

@@ -16,6 +16,7 @@ class PolicyLoss(nn.Module):
clip_eps_high: float = 0.2, clip_eps_high: float = 0.2,
beta: float = 0.01, beta: float = 0.01,
loss_variation: str = "sample_level", loss_variation: str = "sample_level",
adv: str = "GRPO",
) -> None: ) -> None:
super().__init__() super().__init__()
self.clip_eps_low = clip_eps_low self.clip_eps_low = clip_eps_low
@@ -23,6 +24,7 @@ class PolicyLoss(nn.Module):
self.beta = beta self.beta = beta
self.loss_variation = loss_variation self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}" assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
self.adv = adv
def forward( def forward(
self, self,

View File

@@ -118,6 +118,9 @@ class BaseProducer:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left" self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# init dataloader # init dataloader
train_dataset_path = train_dataset_config.pop("path") train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config) self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)

View File

@@ -137,7 +137,7 @@ if __name__ == "__main__":
) )
# GRPO parameters # GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"]) parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.") parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.") parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument( parser.add_argument(
@@ -292,6 +292,7 @@ if __name__ == "__main__":
if args.algo == "GRPO": if args.algo == "GRPO":
# Default Settings # Default Settings
grpo_config = { grpo_config = {
"algo": "GRPO",
"lr": args.learning_rate, "lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size, "train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient "beta": args.kl_coeff, # KL penalty coefficient
@@ -313,6 +314,7 @@ if __name__ == "__main__":
elif args.algo == "DAPO": elif args.algo == "DAPO":
# DAPO variant settings # DAPO variant settings
grpo_config = { grpo_config = {
"algo": "DAPO",
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch "filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate, "lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size, "train_microbatch_size": args.train_microbatch_size,
@@ -339,6 +341,50 @@ if __name__ == "__main__":
else None else None
), ),
} }
elif args.algo == "REINFORCE_PPB":
# Default Settings
grpo_config = {
"algo": "REINFORCE_PPB",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "RLOO":
# Default Settings
grpo_config = {
"algo": "RLOO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
"loss_variation": "sample_level",
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else: else:
raise ValueError(f"Unsupported algorithm: {args.algo}") raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code": if args.reward_type == "code":