Add new implementations of RL algorithms (#6383)
* add new algorithm
* move common calculations
* delete data
* move common calculations of rewards
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -101,6 +101,7 @@ class GRPOConsumer(BaseConsumer):
             clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
             beta=grpo_config.get("beta", 0.01),
             loss_variation=grpo_config.get("loss_variation", "sample_level"),
+            adv=grpo_config.get("algo"),
         )

         # Reference model is initialized from policy model.
@@ -137,6 +138,8 @@ class GRPOConsumer(BaseConsumer):
             eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )

+        self.adv = grpo_config.get("algo")
+
     def setup(self):
         super().setup()
         if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -204,9 +207,23 @@ class GRPOConsumer(BaseConsumer):
                 # [minibatch_size x num_generations]
                 reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)

-                reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-                # [minibatch_size x num_generations]
-                advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+                if self.adv == "GRPO" or self.adv == "DAPO":
+
+                    reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+                    # [minibatch_size x num_generations]
+                    advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+
+                elif self.adv == "REINFORCE_PPB":
+
+                    # [minibatch_size x num_generations]
+                    advantages = ((reward - reward_mean)).unsqueeze(dim=-1)
+
+                elif self.adv == "RLOO":
+
+                    advantages = (
+                        reward * self.num_generations / (self.num_generations - 1)
+                        - reward_mean * self.num_generations / (self.num_generations - 1)
+                    ).unsqueeze(dim=-1)

                 # [minibatch_size x num_of_generation]
                 loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
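
For orientation, here is a minimal, self-contained sketch of the three advantage estimators this hunk now switches between, using the same formulas as the diff; the function and tensor names are illustrative and not part of the repository. The comments also spell out why the RLOO expression is the leave-one-out baseline.

import torch


def sketch_advantages(group_reward: torch.Tensor, algo: str, eps: float = 1e-4) -> torch.Tensor:
    # group_reward: [minibatch_size, num_generations], one scalar reward per sampled completion
    G = group_reward.size(1)
    mean = group_reward.mean(dim=1, keepdim=True)
    if algo in ("GRPO", "DAPO"):
        # group-normalized advantage: (r - mean) / (std + eps)
        std = group_reward.std(dim=1, keepdim=True)
        return (group_reward - mean) / (std + eps)
    if algo == "REINFORCE_PPB":
        # mean baseline only; no std normalization at this stage
        return group_reward - mean
    if algo == "RLOO":
        # leave-one-out baseline: r_i - (sum_j r_j - r_i) / (G - 1)
        #                       = G / (G - 1) * (r_i - mean)
        # which matches reward * G/(G-1) - reward_mean * G/(G-1) in the diff
        return group_reward * G / (G - 1) - mean * G / (G - 1)
    raise ValueError(f"Unsupported advantage estimator: {algo}")
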
@@ -358,10 +375,34 @@ class GRPOConsumer(BaseConsumer):
                         per_token_kl = 0.0
                         kl.append(torch.tensor(0.0))

+                    inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1)
+
+                    if self.adv == "REINFORCE_PPB":
+
+                        inputs["advantages"] = inputs["advantages"] - self.policy_loss_fn.beta * per_token_kl
+                        advantages_forward_micro_batch_mean = torch.sum(
+                            inputs["advantages"] * inputs["action_mask"]
+                        ) / (torch.sum(inputs["action_mask"]) + 1e-4)
+                        advantages_forward_micro_batch_std = torch.rsqrt(
+                            torch.sum(
+                                (inputs["advantages"] - advantages_forward_micro_batch_mean) ** 2
+                                * inputs["action_mask"]
+                            )
+                            / (torch.sum(inputs["action_mask"]) + 1e-4)
+                            + 1e-8
+                        )
+                        inputs["advantages"] = (
+                            (inputs["advantages"] - advantages_forward_micro_batch_mean)
+                            * inputs["action_mask"]
+                            / (advantages_forward_micro_batch_std)
+                        )
+
+                        per_token_kl = 0.0
+
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
                         inputs["old_action_log_probs"],
-                        inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        inputs["advantages"],
                         per_token_kl,
                         inputs["action_mask"],
                         loss_mask=inputs["loss_mask"],
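
The REINFORCE_PPB branch above folds the KL penalty into the advantages (advantages minus beta times the per-token KL) and then re-normalizes them over the valid action tokens. For context, a conventional masked-whitening helper is sketched below; the name, epsilons, and exact scaling are illustrative rather than a drop-in copy of the code above.

import torch


def masked_whiten(adv: torch.Tensor, action_mask: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # adv, action_mask: [batch, seq_len]; mask is 1 for generated tokens, 0 for padding
    denom = action_mask.sum() + eps
    mean = (adv * action_mask).sum() / denom
    var = (((adv - mean) ** 2) * action_mask).sum() / denom
    # center, zero out padded positions, and scale toward unit variance
    return (adv - mean) * action_mask * torch.rsqrt(var + 1e-8)
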
@@ -420,10 +461,39 @@ class GRPOConsumer(BaseConsumer):
                             per_token_kl = 0.0
                             kl = None

+                        (
+                            advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1)
+                            - self.policy_loss_fn.beta * per_token_kl
+                        )
+
+                        if self.adv == "REINFORCE_PPB":
+
+                            advantages_forward_micro_batch = (
+                                advantages_forward_micro_batch - self.policy_loss_fn.beta * per_token_kl
+                            )
+                            advantages_forward_micro_batch_mean = torch.sum(
+                                advantages_forward_micro_batch * action_mask_forward_micro_batch
+                            ) / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                            advantages_forward_micro_batch_std = torch.rsqrt(
+                                torch.sum(
+                                    (advantages_forward_micro_batch - advantages_forward_micro_batch_mean) ** 2
+                                    * action_mask_forward_micro_batch
+                                )
+                                / (torch.sum(action_mask_forward_micro_batch) + 1e-4)
+                                + 1e-8
+                            )
+                            advantages_forward_micro_batch = (
+                                (advantages_forward_micro_batch - advantages_forward_micro_batch_mean)
+                                * action_mask_forward_micro_batch
+                                / (advantages_forward_micro_batch_std)
+                            )
+
+                            per_token_kl = 0.0
+
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
                             old_action_log_probs_micro_batch,
-                            advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                            advantages_forward_micro_batch,
                             per_token_kl,
                             action_mask_forward_micro_batch,
                             loss_mask=loss_mask_forward_micro_batch,
@@ -9,7 +9,13 @@ from .consumer import SimpleConsumer
 from .grpo_consumer import GRPOConsumer
 from .producer import SimpleProducer

-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
+ALGO_MAP = {
+    "Simple": SimpleConsumer,
+    "GRPO": GRPOConsumer,
+    "DAPO": GRPOConsumer,
+    "REINFORCE_PPB": GRPOConsumer,
+    "RLOO": GRPOConsumer,
+}


 def get_jsonl_size_fast(path: str) -> int:
@@ -66,6 +72,7 @@ def launch_distributed(
     core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
+
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)

     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

     dataset_path = train_dataset_config["path"]
@@ -16,6 +16,7 @@ class PolicyLoss(nn.Module):
         clip_eps_high: float = 0.2,
         beta: float = 0.01,
         loss_variation: str = "sample_level",
+        adv: str = "GRPO",
     ) -> None:
         super().__init__()
         self.clip_eps_low = clip_eps_low
@@ -23,6 +24,7 @@ class PolicyLoss(nn.Module):
         self.beta = beta
         self.loss_variation = loss_variation
         assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
+        self.adv = adv

     def forward(
         self,
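
Together with the GRPOConsumer change at the top of this diff, the new keyword appears to thread the algorithm name from the run config into the loss object. A minimal construction sketch follows; the import path is assumed from the repository layout, and every argument not shown is left at its default.

from coati.distributed.loss import PolicyLoss  # assumed module path

loss_fn = PolicyLoss(
    clip_eps_low=0.2,
    clip_eps_high=0.2,
    beta=0.01,                      # KL penalty coefficient
    loss_variation="sample_level",
    adv="RLOO",                     # one of "GRPO", "DAPO", "REINFORCE_PPB", "RLOO"
)
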
@@ -118,6 +118,9 @@ class BaseProducer:
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
         self.tokenizer.padding_side = "left"

+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
         # init dataloader
         train_dataset_path = train_dataset_config.pop("path")
         self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
@@ -137,7 +137,7 @@ if __name__ == "__main__":
     )

     # GRPO parameters
-    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
     parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
     parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
     parser.add_argument(
@@ -292,6 +292,7 @@ if __name__ == "__main__":
     if args.algo == "GRPO":
         # Default Settings
         grpo_config = {
+            "algo": "GRPO",
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
             "beta": args.kl_coeff,  # KL penalty coefficient
@@ -313,6 +314,7 @@ if __name__ == "__main__":
     elif args.algo == "DAPO":
         # DAPO variant settings
         grpo_config = {
+            "algo": "DAPO",
             "filter_range": [0.01, 0.99],  # only filter out all zero batch and all one batch
             "lr": args.learning_rate,
             "train_microbatch_size": args.train_microbatch_size,
@@ -339,6 +341,50 @@ if __name__ == "__main__":
                 else None
             ),
         }
+    elif args.algo == "REINFORCE_PPB":
+        # Default Settings
+        grpo_config = {
+            "algo": "REINFORCE_PPB",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
+    elif args.algo == "RLOO":
+        # Default Settings
+        grpo_config = {
+            "algo": "RLOO",
+            "lr": args.learning_rate,
+            "train_microbatch_size": args.train_microbatch_size,
+            "beta": args.kl_coeff,  # KL penalty coefficient
+            "loss_variation": "sample_level",
+            "reward_fn_type": args.reward_type,
+            "max_length": args.max_new_tokens + args.max_prompt_tokens,
+            "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
+        }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
     if args.reward_type == "code":
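
Putting the pieces together: args.algo selects one of the config blocks above, the config's "algo" key is stored on the consumer as self.adv and passed along as the new adv keyword, and ALGO_MAP maps all four names onto GRPOConsumer. A condensed toy sketch of that wiring is shown below; the class names are stand-ins for the repository's real classes and only the dispatch path is modeled.

class ToyPolicyLoss:
    def __init__(self, adv: str = "GRPO"):
        self.adv = adv  # advantage estimator selected by name


class ToyGRPOConsumer:
    def __init__(self, grpo_config: dict):
        self.adv = grpo_config.get("algo")  # e.g. "RLOO"
        self.policy_loss_fn = ToyPolicyLoss(adv=grpo_config.get("algo"))


ALGO_MAP = {name: ToyGRPOConsumer for name in ("GRPO", "DAPO", "REINFORCE_PPB", "RLOO")}

grpo_config = {"algo": "RLOO", "lr": 1e-6, "beta": 0.01}  # trimmed-down config block
consumer = ALGO_MAP.get(grpo_config["algo"], ToyGRPOConsumer)(grpo_config)
assert consumer.policy_loss_fn.adv == "RLOO"
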