From 083766d54ca2fab54fa6770bb05401f4ee44c525 Mon Sep 17 00:00:00 2001
From: sglucas
Date: Wed, 3 Sep 2025 13:48:06 +0800
Subject: [PATCH] Add new implementations of RL algorithms (#6383)

* add new algorithm

* move common calculations

* delete data

* move common calculations of rewards

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../coati/distributed/grpo_consumer.py        | 80 +++++++++++++++++--
 .../ColossalChat/coati/distributed/launch.py  |  9 ++-
 .../ColossalChat/coati/distributed/loss.py    |  2 +
 .../coati/distributed/producer.py             |  3 +
 applications/ColossalChat/rl_example.py       | 48 ++++++++++-
 5 files changed, 135 insertions(+), 7 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index ee72e0290..40d362340 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -101,6 +101,7 @@ class GRPOConsumer(BaseConsumer):
             clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
             beta=grpo_config.get("beta", 0.01),
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
+            adv=grpo_config.get("algo"),
         )
 
         # Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ class GRPOConsumer(BaseConsumer):
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )
 
+        self.adv = grpo_config.get("algo")
+
     def setup(self):
         super().setup()
         if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -204,9 +207,23 @@ class GRPOConsumer(BaseConsumer):
         # [minibatch_size x num_generations]
         reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
 
-        reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-        # [minibatch_size x num_generations]
-        advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+        if self.adv == "GRPO" or self.adv == "DAPO":
+
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            # [minibatch_size x num_generations]
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+
+        elif self.adv == "REINFORCE_PPB":
+
+            # [minibatch_size x num_generations]
+            advantages = ((reward - reward_mean)).unsqueeze(dim=-1)
+
+        elif self.adv == "RLOO":
+
+            advantages = (
+                reward * self.num_generations / (self.num_generations - 1)
+                - reward_mean * self.num_generations / (self.num_generations - 1)
+            ).unsqueeze(dim=-1)
 
         # [minibatch_size x num_of_generation]
         loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
@@ -358,10 +375,34 @@ class GRPOConsumer(BaseConsumer):
                             per_token_kl = 0.0
                             kl.append(torch.tensor(0.0))
 
+                        inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
+
+                        if self.adv == "REINFORCE_PPB":
+
+                            inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+                            advantages_forward_micro_batch_mean = torch.sum(
+                                inputs["advantages"] * inputs["action_mask"]
+                            ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+                            advantages_forward_micro_batch_std = torch.rsqrt(
+                                torch.sum(
+                                    (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+                                    * inputs["action_mask"]
+                                )
+                                / (torch.sum(inputs["action_mask"]) + 1e-4)
+                                + 1e-8
+                            )
+                            inputs["advantages"] = (
+                                (inputs["advantages"] - advantages_forward_micro_batch_mean)
+                                * inputs["action_mask"]
+                                / (advantages_forward_micro_batch_std)
+                            )
+
+                            per_token_kl = 0.0
+
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
                             inputs["old_action_log_probs"],
-                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                            inputs["advantages"],
                             per_token_kl,
                             inputs["action_mask"],
                             loss_mask=inputs["loss_mask"],
@@ -420,10 +461,39 @@ class GRPOConsumer(BaseConsumer):
                         per_token_kl = 0.0
                         kl = None
 
+                    (
+                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1)
+                        - self.policy_loss_fn.beta * per_token_kl
+                    )
+
+                    if self.adv == "REINFORCE_PPB":
+
+                        advantages_forward_micro_batch = (
+                            advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+                        )
+                        advantages_forward_micro_batch_mean = torch.sum(
+                            advantages_forward_micro_batch * action_mask_forward_micro_batch
+                        ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                        advantages_forward_micro_batch_std = torch.rsqrt(
+                            torch.sum(
+                                (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+                                * action_mask_forward_micro_batch
+                            )
+                            / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                            + 1e-8
+                        )
+                        advantages_forward_micro_batch = (
+                            (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+                            * action_mask_forward_micro_batch
+                            / (advantages_forward_micro_batch_std)
+                        )
+
+                        per_token_kl = 0.0
+
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
                         old_action_log_probs_micro_batch,
-                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        advantages_forward_micro_batch,
                         per_token_kl,
                         action_mask_forward_micro_batch,
                         loss_mask=loss_mask_forward_micro_batch,
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a48246c87..d60312e2b 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -9,7 +9,13 @@ from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer
 
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {
+    "Simple": SimpleConsumer,
+    "GRPO": GRPOConsumer,
+    "DAPO": GRPOConsumer,
+    "REINFORCE_PPB": GRPOConsumer,
+    "RLOO": GRPOConsumer,
+}
 
 
 def get_jsonl_size_fast(path: str) -> int:
@@ -66,6 +72,7 @@ def launch_distributed(
     core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
 
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+    assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
     dataset_path = train_dataset_config["path"]
 
diff --git a/applications/ColossalChat/coati/distributed/loss.py b/applications/ColossalChat/coati/distributed/loss.py
index 36057b24f..ab38f987f 100644
--- a/applications/ColossalChat/coati/distributed/loss.py
+++ b/applications/ColossalChat/coati/distributed/loss.py
@@ -16,6 +16,7 @@ class PolicyLoss(nn.Module):
         clip_eps_high: float = 0.2,
         beta: float = 0.01,
         loss_variation: str = "sample_level",
+        adv: str = "GRPO",
     ) -> None:
         super().__init__()
         self.clip_eps_low = clip_eps_low
@@ -23,6 +24,7 @@ class PolicyLoss(nn.Module):
         self.beta = beta
         self.loss_variation = loss_variation
         assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
+        self.adv = adv
 
     def forward(
         self,
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index fbec2319b..38a85b9b1 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -118,6 +118,9 @@ class BaseProducer:
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
         self.tokenizer.padding_side = "left"
 
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         # init dataloader
         train_dataset_path = train_dataset_config.pop("path")
         self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 42ec582f6..b584b940c 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -137,7 +137,7 @@ if __name__ == "__main__":
     )
 
     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
@@ -292,6 +292,7 @@ if __name__ == "__main__":
     if args.algo == "GRPO":
         # Default Settings
         grpo_config = {
+            "algo": "GRPO",
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
@@ -313,6 +314,7 @@ if __name__ == "__main__":
     elif args.algo == "DAPO":
         # DAPO variant settings
         grpo_config = {
+            "algo": "DAPO",
            "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
            "lr": args.learning_rate,
            "train_microbatch_size": args.train_microbatch_size,
@@ -339,6 +341,50 @@ if __name__ == "__main__":
                 else None
             ),
         }
+    elif args.algo == "REINFORCE_PPB":
+        # Default Settings
+        grpo_config = {
+            "algo": "REINFORCE_PPB",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
+    elif args.algo == "RLOO":
+        # Default Settings
+        grpo_config = {
+            "algo": "RLOO",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
     if args.reward_type == "code":
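
Note on the advantage estimators dispatched on `self.adv` above: GRPO/DAPO normalize each group's rewards by their standard deviation, REINFORCE_PPB only centers them (whitening happens later, once the KL term is folded in), and RLOO subtracts a leave-one-out baseline, which is algebraically n/(n-1) * (r - group_mean). The standalone sketch below reproduces that math outside the consumer; the helper name compute_advantages and the example rewards are hypothetical and not part of the patch.

# Minimal sketch of the per-group advantage estimators (hypothetical helper, not part
# of the patch). group_reward: [minibatch_size, num_generations] of scalar rewards.
import torch


def compute_advantages(group_reward: torch.Tensor, algo: str) -> torch.Tensor:
    num_generations = group_reward.size(1)
    reward = group_reward.flatten()  # [minibatch_size * num_generations]
    reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)

    if algo in ("GRPO", "DAPO"):
        # Group-normalized advantage: (r - mean) / (std + eps).
        reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
        advantages = (reward - reward_mean) / (reward_std + 1e-4)
    elif algo == "REINFORCE_PPB":
        # Mean-centered only; whitening is applied later, after the KL term is folded in.
        advantages = reward - reward_mean
    elif algo == "RLOO":
        # Leave-one-out baseline: r_i minus the mean of the other n-1 samples,
        # which simplifies to n/(n-1) * (r_i - group_mean). Requires num_generations > 1.
        advantages = (reward - reward_mean) * num_generations / (num_generations - 1)
    else:
        raise ValueError(f"Unsupported algorithm: {algo}")
    return advantages.unsqueeze(dim=-1)  # [minibatch_size * num_generations, 1]


if __name__ == "__main__":
    rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]])
    for algo in ("GRPO", "REINFORCE_PPB", "RLOO"):
        print(algo, compute_advantages(rewards, algo).squeeze(-1))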
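
The REINFORCE_PPB branches inside the train step additionally fold the per-token KL penalty into the advantage, re-normalize over the unmasked tokens, and then zero per_token_kl before calling policy_loss_fn so the penalty is not applied twice. A self-contained sketch of that post-processing follows, assuming the intended scale is the masked standard deviation (the usual (x - mean) / std whitening); the function name and example tensors are illustrative only.

# Sketch of the REINFORCE_PPB post-processing: subtract the KL penalty from the
# advantage, then whiten over the valid (non-padded) tokens only.
# Names are hypothetical; shapes follow the diff: advantages [B, 1] (or [B, T]),
# per_token_kl and action_mask [B, T].
import torch


def whiten_kl_adjusted_advantages(
    advantages: torch.Tensor,
    per_token_kl: torch.Tensor,
    action_mask: torch.Tensor,
    beta: float,
    eps: float = 1e-4,
) -> torch.Tensor:
    # Broadcast the sequence-level advantage to every token and absorb the KL penalty,
    # so the loss no longer needs a separate KL term (the caller then passes kl = 0).
    adv = advantages - beta * per_token_kl

    num_valid = action_mask.sum() + eps
    mean = (adv * action_mask).sum() / num_valid
    var = (((adv - mean) ** 2) * action_mask).sum() / num_valid
    std = torch.sqrt(var + 1e-8)

    # Standard whitening; padded positions are zeroed by the mask.
    return (adv - mean) * action_mask / std


if __name__ == "__main__":
    adv = torch.tensor([[0.5], [-0.5]])                      # [B, 1]
    kl = torch.tensor([[0.1, 0.2, 0.0], [0.0, 0.3, 0.1]])    # [B, T]
    mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])  # [B, T]
    print(whiten_kl_adjusted_advantages(adv, kl, mask, beta=0.01))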