[feat] Support DAPO (#6263)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address review conversations

* fix memory leakage; support tp+pp

* move empty cache

* move empty cache

* add DAPO support

* remove format reward

* fix filtering, still buggy

* small fix

* add DAPO support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tested multi-node training; fix bind_batch bug

* fix conversation; support sleep mode

* support reusing excessive samples

* add dynamic batching control flag

* add dynamic batching control flag

* refactored

* fix logging

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YeAnbang 2025-04-25 17:39:17 +08:00 committed by GitHub
parent b823c6eec7
commit 26d859f68e
10 changed files with 552 additions and 359 deletions

View File

@@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .comm import ray_broadcast_tensor_dict from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch from .utils import bind_batch, pad_batch, post_recv, unbind_batch
class BaseConsumer: class BaseConsumer:
@@ -33,7 +33,7 @@ class BaseConsumer:
batch_size: int, batch_size: int,
model_config: Dict[str, Any], model_config: Dict[str, Any],
plugin_config: Dict[str, Any], plugin_config: Dict[str, Any],
microbatch_size: int = 1, minibatch_size: int = 1,
save_interval: int = 100, save_interval: int = 100,
save_dir: str = "./model", save_dir: str = "./model",
): ):
@@ -46,11 +46,11 @@ class BaseConsumer:
self.num_update_per_episode = num_update_per_episode self.num_update_per_episode = num_update_per_episode
self.num_recv_per_update = num_recv_per_update self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size self.batch_size = batch_size
self.microbatch_size = microbatch_size self.minibatch_size = minibatch_size
self.save_interval = save_interval self.save_interval = save_interval
self.save_dir = save_dir self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size" assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size self.num_microbatches = batch_size // minibatch_size
self.model_config = model_config self.model_config = model_config
self.plugin_config = plugin_config self.plugin_config = plugin_config
@@ -67,7 +67,7 @@ class BaseConsumer:
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2) plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config: if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size plugin_config["microbatch_size"] = self.minibatch_size
plugin_config.update(self.plugin_config) plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config) self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin) self.booster = Booster(plugin=self.plugin)
@@ -105,18 +105,26 @@ class BaseConsumer:
) )
) )
) )
while len(self.buffer) >= self.dp_size * self.microbatch_size: while len(self.buffer) >= self.dp_size * self.minibatch_size:
batches = self.buffer[ batches = self.buffer[
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
] ]
self.buffer = self.buffer[self.dp_size * self.microbatch_size :] batch = pad_batch(
batches
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
batch = bind_batch(batches) batch = bind_batch(batches)
batch = post_recv(batch) batch = post_recv(batch)
loss = self.step(i, **batch) loss, num_excessive_prompts = self.step(i, pbar, **batch)
self.buffer = (
self.buffer[
(self.dp_rank + 1) * self.minibatch_size
- num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
]
+ self.buffer[self.dp_size * self.minibatch_size :]
)
if loss is not None: if loss is not None:
pbar.set_postfix({"loss": loss}) pbar.set_postfix({"loss": loss})
i += 1 i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None: if self.lr_scheduler is not None:
self.lr_scheduler.step() self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode: if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
@@ -154,7 +162,9 @@ class SimpleConsumer(BaseConsumer):
batch_size, batch_size,
model_config, model_config,
plugin_config, plugin_config,
microbatch_size=1, minibatch_size=1,
save_interval: int = 100,
save_dir="./model",
): ):
super().__init__( super().__init__(
num_producers, num_producers,
@@ -168,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
batch_size, batch_size,
model_config, model_config,
plugin_config, plugin_config,
microbatch_size, minibatch_size,
) )
path = model_config.pop("path") path = model_config.pop("path")
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -181,7 +191,7 @@ class SimpleConsumer(BaseConsumer):
super().setup() super().setup()
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer) self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
def step(self, step_idx: int, **kwargs) -> Optional[float]: def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
labels = kwargs["input_ids"].clone() labels = kwargs["input_ids"].clone()
labels[kwargs["attention_mask"] == 0] = -100 labels[kwargs["attention_mask"] == 0] = -100
kwargs["labels"] = labels kwargs["labels"] = labels

View File

@@ -1,18 +1,16 @@
import json import warnings
import os
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional from typing import Any, Optional
import ray import ray
import torch import torch
import torch.distributed as dist
import wandb import wandb
from coati.distributed.consumer import BaseConsumer from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -34,13 +32,23 @@ class GRPOConsumer(BaseConsumer):
batch_size, batch_size,
model_config, model_config,
plugin_config, plugin_config,
microbatch_size=1, minibatch_size=1,
num_generations=8, num_generations=8,
use_wandb=True, use_wandb=True,
generate_config=None, generate_config=None,
training_config={}, grpo_config={},
project_name=None, project_name=None,
save_interval: int = 100,
save_dir="./model",
): ):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
if batch_size != minibatch_size:
warnings.warn(
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
UserWarning,
)
minibatch_size = batch_size
super().__init__( super().__init__(
num_producers, num_producers,
num_episodes, num_episodes,
@@ -53,36 +61,58 @@ class GRPOConsumer(BaseConsumer):
batch_size, batch_size,
model_config, model_config,
plugin_config, plugin_config,
microbatch_size, minibatch_size,
save_dir=save_dir,
) )
path = model_config.pop("path") path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config) self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train() self.policy_model.train()
self.policy_model.gradient_checkpointing_enable() self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6)) self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device) self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device) self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device) self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device) self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device) self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0 self.accum_count = 0
self.generate_config = generate_config self.generate_config = generate_config
self.training_config = training_config self.grpo_config = grpo_config
self.project_name = project_name self.project_name = project_name
self.effective_sample_count = 0
self.total_sample_count = 0
self.policy_loss_fn = PolicyLoss(
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
# Reference model is initialized from policy model. # Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config) if self.policy_loss_fn.beta > 0:
self.reference_model.eval() self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path) self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations self.num_generations = num_generations
self.filter_range = training_config.get("filter_range", None) self.filter_range = grpo_config.get("filter_range", None)
if self.filter_range is not None: if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values." assert len(self.filter_range) == 2, "Filter range should have 2 values."
self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
if self.filter_truncated_response:
self.max_length = 0
if "max_tokens" in self.generate_config:
self.max_length = self.generate_config["max_tokens"]
elif "max_new_tokens" in self.generate_config:
self.max_length = self.generate_config["max_new_tokens"]
else:
raise ValueError(
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward. # Initialize verifiable reward.
response_format_tags = { response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1}, "think_start": {"text": "<think>", "num_occur": 1},
@@ -90,11 +120,12 @@ class GRPOConsumer(BaseConsumer):
"answer_start": {"text": "<answer>", "num_occur": 1}, "answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1}, "answer_end": {"text": "</answer>", "num_occur": 1},
} }
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
self.reward_model = VerifiableReward( self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
) )
self.policy_loss_fn = PolicyLoss()
self.global_step = 0 self.global_step = 0
self.use_wandb = use_wandb self.use_wandb = use_wandb
@@ -102,7 +133,7 @@ class GRPOConsumer(BaseConsumer):
optimizer=self.optimizer, optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode, total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
warmup_steps=0, warmup_steps=0,
eta_min=0.1 * training_config.get("lr", 1e-6), eta_min=0.1 * grpo_config.get("lr", 1e-6),
) )
def setup(self): def setup(self):
@@ -118,10 +149,11 @@ class GRPOConsumer(BaseConsumer):
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost( self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
) )
self.reference_model, *_ = self.booster.boost(self.reference_model) if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR") self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, **kwargs) -> Optional[float]: def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
""" """
Step data from policy model: Step data from policy model:
[{ [{
@@ -132,18 +164,108 @@ class GRPOConsumer(BaseConsumer):
}, },
...] ...]
Format: Format:
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>. [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
""" """
# Reshape to [batch_size x num_of_generation, prompt_length + response_length] # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()} data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"] action_mask = data["action_mask"]
num_action = action_mask.shape[1] num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"] old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32) response_length = torch.sum(action_mask, dim=1).to(torch.float32)
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0)) train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward_group = self.reward_model(
data["input_ids"],
gt_answer=data["gt_answer"],
response_idx=data["response_idx"],
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [minibatch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [minibatch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
)
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
# to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
num_excessive_samples = (
int(
(self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
/ self.num_generations
/ self.dp_size
)
* self.num_generations
)
if num_excessive_samples > 0:
data = {
k: (
v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
if k
in [
"input_ids",
"attention_mask",
"action_log_probs",
"action_mask",
"response_idx",
"gt_answer",
]
else v
)
for k, v in data.items()
}
action_mask = action_mask[
: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
]
loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
else:
num_excessive_samples = 0
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
num_excessive_samples = 0
pbar.set_postfix(
{
"Step": self.global_step + 1,
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
}
)
need_update = (step_idx + 1) % self.num_microbatches == 0
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500 # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = ( ctx = (
nullcontext() nullcontext()
@@ -151,95 +273,71 @@ class GRPOConsumer(BaseConsumer):
else self.booster.no_sync(self.policy_model, self.optimizer) else self.booster.no_sync(self.policy_model, self.optimizer)
) )
with ctx: with ctx:
reward_group = self.reward_model( for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
loss_mask = (
None
if self.filter_range is None
else torch.logical_and(
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][ input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
] ]
attention_mask_forward_micro_batch = data["attention_mask"][ attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
] ]
action_mask_forward_micro_batch = action_mask[ action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
] ]
loss_mask_forward_micro_batch = ( loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size] loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
if loss_mask is not None if loss_mask is not None
else None else None
) )
advantages_forward_micro_batch = advantages[ advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
] ]
if self.plugin.pp_size > 1: if self.plugin.pp_size > 1:
# Support training with PP. # Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
with torch.no_grad(): if self.booster.plugin.stage_manager.is_last_stage():
reference_model_outputs = self.booster.execute_pipeline( reference_model_logits = reference_model_outputs["outputs"]["logits"]
iter( reference_action_log_probs = calc_action_log_probs(
[ reference_model_logits / self.generate_config["temperature"],
{ input_ids_forward_micro_batch,
"input_ids": input_ids_forward_micro_batch, num_action,
"attention_mask": attention_mask_forward_micro_batch, self.plugin.shard_config,
} )
] else:
), # Dummy reference logprobs for data iterator.
self.reference_model, reference_action_log_probs = None
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
else: else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None reference_action_log_probs = None
data_policy_forward = { data_policy_forward = {
"input_ids": input_ids_forward_micro_batch, "input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch, "attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch, "action_mask": action_mask_forward_micro_batch,
"reference_action_log_probs": reference_action_log_probs,
"advantages": advantages_forward_micro_batch, "advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch, "loss_mask": loss_mask_forward_micro_batch,
"source": self.rank, "source": self.rank,
} }
if reference_action_log_probs is not None:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
kl = [] kl = []
@@ -251,24 +349,30 @@ class GRPOConsumer(BaseConsumer):
num_action, num_action,
self.plugin.shard_config, self.plugin.shard_config,
) )
per_token_kl = ( if "reference_action_log_probs" in inputs:
torch.exp(inputs["reference_action_log_probs"] - action_log_probs) per_token_kl = (
- (inputs["reference_action_log_probs"] - action_log_probs) torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- 1 - (inputs["reference_action_log_probs"] - action_log_probs)
) - 1
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum( )
inputs["action_mask"], dim=-1 appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
) inputs["action_mask"], dim=-1
kl.append(appox_kl.mean()) )
loss, skip_update, _ = self.policy_loss_fn( kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
kl.append(0.0)
loss, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,
action_log_probs, action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1), inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl, per_token_kl,
inputs["action_mask"], inputs["action_mask"],
loss_mask=inputs["loss_mask"], loss_mask=inputs["loss_mask"],
total_effective_tokens_in_batch=total_effective_tokens_count,
) )
return loss return loss, num_excessive_samples // self.num_generations
policy_model_outputs = self.booster.execute_pipeline( policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]), iter([data_policy_forward]),
@@ -298,61 +402,71 @@ class GRPOConsumer(BaseConsumer):
self.plugin.shard_config, self.plugin.shard_config,
) )
with torch.no_grad(): if self.policy_loss_fn.beta > 0:
reference_model_logits = self.reference_model( with torch.no_grad():
input_ids=input_ids_forward_micro_batch, reference_model_logits = self.reference_model(
attention_mask=attention_mask_forward_micro_batch, input_ids=input_ids_forward_micro_batch,
).logits attention_mask=attention_mask_forward_micro_batch,
reference_action_log_probs = calc_action_log_probs( ).logits
reference_model_logits / self.generate_config["temperature"], reference_action_log_probs = calc_action_log_probs(
input_ids_forward_micro_batch, reference_model_logits / self.generate_config["temperature"],
num_action, input_ids_forward_micro_batch,
self.plugin.shard_config, num_action,
) self.plugin.shard_config,
per_token_kl = ( )
torch.exp(reference_action_log_probs - action_log_probs) per_token_kl = (
- (reference_action_log_probs - action_log_probs) torch.exp(reference_action_log_probs - action_log_probs)
- 1 - (reference_action_log_probs - action_log_probs)
) - 1
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum( )
action_mask_forward_micro_batch, dim=-1 kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
) action_mask_forward_micro_batch, dim=-1
)
else:
per_token_kl = 0.0
kl = None
loss, skip_update, _ = self.policy_loss_fn( loss, _ = self.policy_loss_fn(
action_log_probs, action_log_probs,
old_action_log_probs, old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1), advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl, per_token_kl,
action_mask_forward_micro_batch, action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch, loss_mask=loss_mask_forward_micro_batch,
total_effective_tokens_in_batch=total_effective_tokens_count,
) )
if not skip_update: self.booster.backward(loss, self.optimizer)
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin) loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value. # Calculate accumulate value.
mean_kl.append(kl.data) if kl is not None:
kl = all_reduce_mean(kl.mean(), self.plugin)
mean_kl.append(kl.data)
mean_loss.append(loss.data) mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or ( if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
): ):
reward = all_reduce_mean(reward.mean(), self.plugin) reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin) format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin) ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin) response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss)) self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl)) if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data) self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data) self.accum_format_acc.add_(format_acc.data)
self.accum_acc_reward.add_(acc_reward.data) self.accum_ans_acc.add_(ans_acc.data)
self.accum_advantages.add_(advantages.data) self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data) self.accum_response_length.add_(response_length.data)
self.accum_count += 1 self.accum_count += 1
if need_update: if need_update:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_sample_count = 0
self.total_sample_count = 0
if not self.plugin.pp_size > 1 or ( if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
): ):
@@ -360,167 +474,41 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or ( if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
): ):
print( to_log_msg = [
"Loss:", f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
self.accum_loss.item() / self.accum_count, f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
"\nReward:", f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
self.accum_reward.item() / self.accum_count, f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
"\nFormat Reward:", f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
self.accum_format_reward.item() / self.accum_count, f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
"\nAcc Reward:", ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
self.accum_acc_reward.item() / self.accum_count, print("\n".join(to_log_msg))
"\nKL:", metrics = {
self.accum_kl.item() / self.accum_count, "metrics/reward": self.accum_reward.item() / self.accum_count,
"\nAdvantages:", "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
self.accum_advantages.item() / self.accum_count, "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
"\nResponse Length:", "metrics/response_length": self.accum_response_length.item() / self.accum_count,
self.accum_response_length.item() / self.accum_count, "train/loss": self.accum_loss.item() / self.accum_count,
) "train/advantages": self.accum_advantages.item() / self.accum_count,
self.wandb_run.log( "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
{ "train/sample_utilization": sample_utilization,
"metrics/reward": self.accum_reward.item() / self.accum_count, "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count, }
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count, if self.policy_loss_fn.beta > 0:
"metrics/response_length": self.accum_response_length.item() / self.accum_count, metrics["train/kl"] = self.accum_kl.item() / self.accum_count
"train/loss": self.accum_loss.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count, self.wandb_run.log(metrics)
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_() self.accum_loss.zero_()
self.accum_reward.zero_() self.accum_reward.zero_()
self.accum_acc_reward.zero_() self.accum_ans_acc.zero_()
self.accum_format_reward.zero_() self.accum_format_acc.zero_()
self.accum_kl.zero_() self.accum_kl.zero_()
self.accum_advantages.zero_() self.accum_advantages.zero_()
self.accum_response_length.zero_() self.accum_response_length.zero_()
self.accum_count = 0 self.accum_count = 0
return loss_scalar return loss_scalar, num_excessive_samples // self.num_generations
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict
@ray.remote
class GRPOEvalConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
use_wandb=True,
log_dir="./results",
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = torch.zeros(1, device=self.device)
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)
self.log_dir = log_dir
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
else: else:
os.system(f"rm -rf {self.log_dir}/*") return None, num_excessive_samples // self.num_generations
def setup(self):
super().setup()
self.policy_model, _, *_ = self.booster.boost(self.policy_model)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
rank = dist.get_rank()
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
kwargs["input_ids"].size(0)
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = [value[0].item() for value in reward_group]
format_reward = [value[1].item() for value in reward_group]
acc_reward = [value[2].item() for value in reward_group]
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
for i in range(len(response)):
f.write(
json.dumps(
{
"response": response[i],
"reward": reward[i],
"format_reward": format_reward[i],
"acc_reward": acc_reward[i],
"response_length": response_length[i],
},
ensure_ascii=False,
)
+ "\n"
)
self.accum_reward += sum(reward)
self.accum_format_reward += sum(format_reward)
self.accum_acc_reward += sum(acc_reward)
self.accum_response_length += sum(response_length)
self.accum_count += len(reward)
# print results
total_count = all_reduce_mean(self.accum_count, self.plugin)
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
if rank == 0:
print(
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
)
return None
def state_dict(self): def state_dict(self):
self.policy_model._force_wait_all_gather() self.policy_model._force_wait_all_gather()
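The dynamic-batching branch in `step` above decides when enough non-filtered samples have been collected for an optimizer step and how many surplus samples to hand back to the consumer buffer. A minimal sketch of just that arithmetic (pure Python; the threshold `batch_size * dp_size * num_generations` and the rounding to whole prompt groups follow the code in the diff, which only acts on the value when it is positive):

```python
def dynamic_batching_accounting(effective_sample_count, batch_size, dp_size, num_generations):
    """Sketch of the DAPO dynamic-batching bookkeeping in GRPOConsumer.step.

    Returns (need_update, num_excessive_samples): whether the optimizer step
    should run, and how many trailing samples on this rank (whole prompt
    groups of `num_generations` samples) to set aside for the next iteration.
    """
    target = batch_size * dp_size * num_generations
    need_update = effective_sample_count >= target
    num_excessive_samples = (
        int((effective_sample_count - target) / num_generations / dp_size) * num_generations
    )
    # The real consumer only uses this value when it is > 0; clamp for clarity.
    return need_update, max(num_excessive_samples, 0)

if __name__ == "__main__":
    # e.g. tBs=4 prompts, dp_size=2, 8 generations per prompt -> 64 effective samples needed
    print(dynamic_batching_accounting(70, batch_size=4, dp_size=2, num_generations=8))
    # (True, 0): the 6 surplus samples are less than one prompt group per rank, nothing is re-queued
    print(dynamic_batching_accounting(96, batch_size=4, dp_size=2, num_generations=8))
    # (True, 16): two whole prompt groups (16 samples) per rank are kept for the next step
```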

View File

@@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
class VLLMInferenceBackend(BaseInferenceBackend): class VLLMInferenceBackend(BaseInferenceBackend):
DEFAULT_MODEL_CONFIG = dict( DEFAULT_MODEL_CONFIG = dict(
trust_remote_code=True, trust_remote_code=True,
enable_sleep_mode=False,
) )
FORCE_GENERATE_CONFIG = dict( FORCE_GENERATE_CONFIG = dict(
logprobs=0, logprobs=0,
@@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
generate_config.update(self.FORCE_GENERATE_CONFIG) generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations}) generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config) self.generate_config = SamplingParams(**generate_config)
self.model_config = model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.num_generations = num_generations self.num_generations = num_generations
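The new `enable_sleep_mode` default and the stored `model_config` are what the producer later checks before calling `llm.sleep()` / `llm.wake_up()` around weight sync. A rough sketch of the idea against vLLM directly (assumes a vLLM build that supports sleep mode; the model name is only a placeholder):

```python
# Sketch only: offload the vLLM engine's GPU memory while the trainer holds the
# device, then restore it before the next generation round.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True, gpu_memory_utilization=0.7)
params = SamplingParams(temperature=1.0, max_tokens=128, logprobs=0)

outputs = llm.generate(["1 + 1 = ?"], params)  # rollout phase

llm.sleep()    # free GPU memory before the policy update / weight broadcast
# ... consumer trains and broadcasts new weights here ...
llm.wake_up()  # restore the engine before the next rollout
```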

View File

@@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
import ray import ray
from .consumer import SimpleConsumer from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer from .producer import SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer} ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
def get_jsonl_size_fast(path: str) -> int: def get_jsonl_size_fast(path: str) -> int:
@@ -40,6 +40,7 @@ def launch_distributed(
inference_model_config: Dict[str, Any], inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any], generate_config: Dict[str, Any],
train_model_config: Dict[str, Any], train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
plugin_config: Dict[str, Any], plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None, tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers", inference_backend: str = "transformers",
@@ -48,6 +49,8 @@ def launch_distributed(
master_port: int = 29500, master_port: int = 29500,
core_algo: str = "GRPO", core_algo: str = "GRPO",
project_name: Optional[str] = None, project_name: Optional[str] = None,
save_interval: int = 100,
save_dir: str = "./model",
): ):
if core_algo not in ALGO_MAP: if core_algo not in ALGO_MAP:
@@ -101,15 +104,13 @@ def launch_distributed(
batch_size=train_batch_size, batch_size=train_batch_size,
model_config=train_model_config, model_config=train_model_config,
plugin_config=plugin_config, plugin_config=plugin_config,
microbatch_size=train_minibatch_size, minibatch_size=train_minibatch_size,
generate_config=generate_config_consumer, generate_config=generate_config_consumer,
training_config={ grpo_config=grpo_config,
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": train_microbatch_size,
},
num_generations=num_generations, num_generations=num_generations,
project_name=project_name, project_name=project_name,
save_interval=save_interval,
save_dir=save_dir,
) )
procs.append(consumer) procs.append(consumer)
ray.get([p.setup.remote() for p in procs]) ray.get([p.setup.remote() for p in procs])
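Since DAPO reuses the GRPO consumer, algorithm selection is just a dictionary lookup plus a different `grpo_config`. A compressed sketch of that dispatch (consumer classes replaced by strings for illustration; the error message is hypothetical, the diff only shows the membership check):

```python
# Sketch of the core_algo dispatch in launch_distributed; the real map binds
# Ray actor classes, plain strings stand in for them here.
ALGO_MAP = {"Simple": "SimpleConsumer", "GRPO": "GRPOConsumer", "DAPO": "GRPOConsumer"}

def resolve_consumer(core_algo: str) -> str:
    if core_algo not in ALGO_MAP:
        raise NotImplementedError(f"{core_algo} is not supported yet")
    return ALGO_MAP[core_algo]

print(resolve_consumer("DAPO"))  # GRPOConsumer, configured entirely through grpo_config
```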

View File

@@ -2,7 +2,7 @@ from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from coati.distributed.utils import masked_mean from coati.distributed.utils import masked_mean, masked_sum
class PolicyLoss(nn.Module): class PolicyLoss(nn.Module):
@@ -10,11 +10,19 @@ class PolicyLoss(nn.Module):
Policy Loss for PPO Policy Loss for PPO
""" """
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None: def __init__(
self,
clip_eps_low: float = 0.2,
clip_eps_high: float = 0.2,
beta: float = 0.01,
loss_variation: str = "sample_level",
) -> None:
super().__init__() super().__init__()
self.clip_eps = clip_eps self.clip_eps_low = clip_eps_low
self.skip_threshold = skip_threshold self.clip_eps_high = clip_eps_high
self.beta = beta self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
def forward( def forward(
self, self,
@@ -24,22 +32,37 @@ class PolicyLoss(nn.Module):
per_token_kl: torch.Tensor, per_token_kl: torch.Tensor,
action_mask: Optional[torch.Tensor] = None, action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None, loss_mask: Optional[torch.Tensor] = None,
total_effective_tokens_in_batch: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
skip = False
if action_mask is None: if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp() ratio = (log_probs - log_probs.detach()).exp()
else: else:
ratio = ((log_probs - log_probs.detach()) * action_mask).exp() ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
surr1 = ratio * advantages surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
if self.beta == 0:
# skip kl term if kl coefficient is zero
per_token_kl = 0.0
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
if action_mask is not None: if self.loss_variation == "sample_level":
loss = masked_mean(loss, action_mask) if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
elif self.loss_variation == "token_level":
if action_mask is not None:
loss = masked_sum(loss, action_mask)
else:
loss = loss.sum(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
else: else:
loss = loss.mean(dim=1) raise ValueError(f"Unsupported loss variation: {self.loss_variation}")
if loss_mask is not None:
loss = loss * loss_mask return loss, ratio.max()
loss = loss.mean()
return loss, skip, ratio.max()
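The new `loss_variation` switch is the heart of the DAPO loss: sample-level averages the per-token loss within each response before averaging over responses, while token-level sums masked tokens and divides by the total number of effective (non-filtered, non-padded) response tokens in the batch. A self-contained sketch of just that aggregation step, assuming `loss` is the per-token policy-gradient loss and the masks have the same meaning as above:

```python
import torch

def aggregate(loss, action_mask, loss_mask, variation, total_effective_tokens=None):
    """Sketch of PolicyLoss's two aggregation modes.

    loss:        [B, T] per-token loss
    action_mask: [B, T] 1 for response tokens, 0 for padding
    loss_mask:   [B]    1 for samples kept after filtering, 0 otherwise
    """
    if variation == "sample_level":
        per_sample = (loss * action_mask).sum(dim=1) / (action_mask.sum(dim=1) + 1e-8)
        return (per_sample * loss_mask).mean()
    elif variation == "token_level":
        per_sample = (loss * action_mask).sum(dim=1)
        return (per_sample * loss_mask).sum() / (total_effective_tokens + 1e-8)
    raise ValueError(f"Unsupported loss variation: {variation}")

loss = torch.tensor([[1.0, 1.0, 0.0], [2.0, 0.0, 0.0]])
action_mask = torch.tensor([[1, 1, 0], [1, 0, 0]], dtype=torch.float)
loss_mask = torch.tensor([1.0, 1.0])
print(aggregate(loss, action_mask, loss_mask, "sample_level"))  # (1.0 + 2.0) / 2 = 1.5
print(aggregate(loss, action_mask, loss_mask, "token_level",
                total_effective_tokens=torch.tensor(3.0)))      # 4.0 / 3 ≈ 1.33
```

Token-level aggregation gives long responses proportionally more weight in the gradient, which is why the consumer forces `minibatch_size = batch_size` in that mode: the normalizer must be the token count of the whole training batch.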

View File

@@ -103,7 +103,14 @@ class BaseProducer:
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}") print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor( outputs["temperature"] = torch.tensor(
[self.model.generate_config.temperature] * outputs["input_ids"].size(0) [
(
self.model.generate_config["temperature"]
if isinstance(self.model.generate_config.temperature, dict)
else self.model.generate_config.temperature
)
]
* outputs["input_ids"].size(0)
).to(outputs["input_ids"].device) ).to(outputs["input_ids"].device)
outputs = pre_send(outputs) outputs = pre_send(outputs)
ray_broadcast_tensor_dict( ray_broadcast_tensor_dict(
@@ -113,10 +120,15 @@ class BaseProducer:
if (i + 1) % self.num_microbatches == 0 and ( if (i + 1) % self.num_microbatches == 0 and (
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1 episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
): ):
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.sleep() # revict KV_cache to avoid OOM
# don't sync model for last iteration # don't sync model for last iteration
print( print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}" f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
) )
torch.cuda.empty_cache()
state_dict = ray_broadcast_tensor_dict( state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model" None, self.num_producers, device=self.device, group_name="sync_model"
@@ -124,12 +136,21 @@ class BaseProducer:
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
del state_dict del state_dict
torch.cuda.empty_cache() torch.cuda.empty_cache()
# linear annealing for 1 episode, temperature from initial to 0.7 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
"enable_sleep_mode", False
):
self.model.llm.wake_up()
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0: if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader) ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[ if isinstance(self.model.generate_config.temperature, dict):
"temperature" self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
] + ratio * 0.7 "temperature"
] + ratio * 0.9
else:
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.9
@ray.remote @ray.remote
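The producer now anneals the sampling temperature linearly over the first episode, from the configured value down to 0.9 (previously 0.7). A tiny sketch of the schedule, with `init_temperature` standing for `generate_config["temperature"]`:

```python
def annealed_temperature(step, num_steps, init_temperature, final_temperature=0.9):
    """Linear anneal over episode 0: step 0 keeps init_temperature, the last
    step approaches final_temperature (mirrors the producer loop above)."""
    ratio = 1 - (num_steps - step) / num_steps  # == step / num_steps
    return (1 - ratio) * init_temperature + ratio * final_temperature

for step in (0, 50, 99):
    print(step, round(annealed_temperature(step, 100, init_temperature=1.0), 3))
# 0 -> 1.0, 50 -> 0.95, 99 -> 0.901 (approaches 0.9 by the end of episode 0)
```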

View File

@@ -4,13 +4,22 @@ from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_score = 1.0
acc_score = 9.0
tokenizer = kwargs["tokenizer"] tokenizer = kwargs["tokenizer"]
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
acc_score = 10.0
reward = torch.tensor(0.0) reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0) format_acc = torch.tensor(0.0)
acc_reward = torch.tensor(0.0) ans_acc = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1] s, e = response_idx[0], response_idx[1]
length_reward = 0.0
if soft_over_length_punishment:
max_length = kwargs.get("max_length", 1024 * 4)
cache_length = kwargs.get("cache_length", 512)
res_length = e.item() - s.item() + 1
if max_length - cache_length < res_length < max_length:
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
if gt_answer is None: if gt_answer is None:
return reward return reward
@@ -22,18 +31,20 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
# Check format accuracy # Check format accuracy
if format_valid: if format_valid:
format_reward += format_score format_acc += 1
reward += format_score
# Check answer accuracy # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
if ( if (
final_answer is not None format_valid
and final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
): ):
acc_reward += acc_score ans_acc += 1
reward += acc_score reward += acc_score
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device) reward = reward + length_reward
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
def gsm8k_reward_fn(input_ids, **kwargs): def gsm8k_reward_fn(input_ids, **kwargs):
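The soft over-length punishment added to `math_reward_fn` is DAPO's length-aware penalty: responses that fall inside the last `cache_length` tokens before `max_length` lose reward linearly, down to roughly `-acc_score` right at the limit, and the term is simply added to the total reward. A small sketch of just the penalty, using the same defaults as the reward function above:

```python
def soft_over_length_penalty(res_length, max_length=4096, cache_length=512, acc_score=10.0):
    """Linear penalty for responses that approach max_length (DAPO-style).
    Returns 0 while res_length <= max_length - cache_length, then decreases
    linearly toward -acc_score as res_length reaches max_length."""
    if max_length - cache_length < res_length < max_length:
        return ((max_length - cache_length) - res_length) / cache_length * acc_score
    return 0.0

for length in (3000, 3584, 3840, 4095):
    print(length, round(soft_over_length_penalty(length), 2))
# 3000 -> 0.0, 3584 -> 0.0, 3840 -> -5.0, 4095 -> -9.98
```

Fully truncated responses are handled separately on the consumer side by `filter_truncated_response`, which drops them from the loss via `loss_mask`.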

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
@@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
return batch return batch
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
max_len = defaultdict(int)
for sample in batches:
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
max_len[k] = max(max_len[k], sample[k].size(-1))
for idx, sample in enumerate(batches):
for k in sample:
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
# right pad with 0s
if k in ["attention_mask", "action_mask"]:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
)
else:
batches[idx][k] = torch.nn.functional.pad(
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
)
return batches
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# compress mask to save bandwidth # compress mask to save bandwidth
if "attention_mask" in batch: if "attention_mask" in batch:
@@ -113,3 +135,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mask_sum = mask.sum(dim=dim) mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8) mean = tensor / (mask_sum + 1e-8)
return mean return mean
def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked sum of a tensor along a specified dimension.
Args:
tensor (torch.Tensor): The input tensor.
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
dim (int, optional): The dimension along which to compute the sum. Default is 1.
Returns:
torch.Tensor: The masked sum tensor.
"""
tensor = tensor * mask
return tensor.sum(dim=dim)
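`pad_batch` exists because, with dynamic batching, the minibatches pulled from the buffer can come from different generation rounds and therefore have different sequence lengths; everything is right-padded to the longest sample before stacking. A quick sketch of the padding step on two toy samples (only `input_ids`/`attention_mask` shown; `action_log_probs`/`action_mask` are treated the same way, and the final `torch.cat` is an assumption about how `bind_batch` stacks samples):

```python
import torch
import torch.nn.functional as F

# Two rollout minibatches with different sequence lengths, shaped [num_samples, seq_len].
batches = [
    {"input_ids": torch.ones(2, 6, dtype=torch.long), "attention_mask": torch.ones(2, 6, dtype=torch.long)},
    {"input_ids": torch.ones(3, 4, dtype=torch.long), "attention_mask": torch.ones(3, 4, dtype=torch.long)},
]

max_len = max(sample["input_ids"].size(-1) for sample in batches)
for sample in batches:
    pad = max_len - sample["input_ids"].size(-1)
    # Right-pad ids with 0 and masks with 0 (not attended), as pad_batch does.
    sample["input_ids"] = F.pad(sample["input_ids"], (0, pad), "constant", 0)
    sample["attention_mask"] = F.pad(sample["attention_mask"], (0, pad), "constant", 0)

stacked = torch.cat([sample["input_ids"] for sample in batches], dim=0)
print(stacked.shape)  # torch.Size([5, 6])
```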

View File

@@ -128,7 +128,7 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
return tensor return tensor
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
""" """
Performs an all-reduce operation to sum the values of the given tensor across all processes. Performs an all-reduce operation to sum the values of the given tensor across all processes.
@@ -138,5 +138,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
Returns: Returns:
torch.Tensor: The reduced tensor with the sum of values across all processes. torch.Tensor: The reduced tensor with the sum of values across all processes.
""" """
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) # All reduce sum across DP group
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
return tensor return tensor
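Scoping the all-reduce to `plugin.dp_group` matters once TP/PP are enabled: the effective-sample and effective-token counts should be summed once per data-parallel replica, presumably so they are not multiplied by the tensor-parallel degree. A minimal runnable sketch of the call pattern (single process, gloo backend, so the reduction is a no-op; in the trainer the group comes from the HybridParallelPlugin):

```python
import torch
import torch.distributed as dist

# Single-process group just to exercise the API; in practice rank/world_size
# come from the launcher and the group is plugin.dp_group.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1)
dp_group = dist.new_group(ranks=[0])

effective_samples = torch.tensor(12.0)
dist.all_reduce(effective_samples, op=dist.ReduceOp.SUM, group=dp_group)  # sum across DP ranks only
print(effective_samples.item())  # 12.0 with a single rank

dist.destroy_process_group()
```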

View File

@@ -1,4 +1,5 @@
import argparse import argparse
import os
import ray import ray
import torch import torch
@@ -8,15 +9,17 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
# Distributed training parameters
parser.add_argument("-t", "--num-trainers", type=int, default=2) parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2) parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.") parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument( parser.add_argument(
"-ibs", "-ibs",
"--inference-batch-size", "--inference-batch-size",
type=int, type=int,
default=64, default=None,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.", help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
) )
parser.add_argument( parser.add_argument(
@@ -37,7 +40,7 @@ if __name__ == "__main__":
"-tMbs", "-tMbs",
"--train-minibatch-size", "--train-minibatch-size",
type=int, type=int,
default=1, default=None,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs", help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
) )
parser.add_argument( parser.add_argument(
@@ -47,22 +50,78 @@ if __name__ == "__main__":
default=2, default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.", help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
) )
parser.add_argument(
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
)
parser.add_argument(
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
)
parser.add_argument(
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
)
# Sampling parameters
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"]) parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"]) parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument(
"-topk",
"--top-k",
type=int,
default=None,
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-topp",
"--top-p",
type=float,
default=1.0,
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
)
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.") parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
# GRPO parameters
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
# Logging/Checkpointing parameters
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
args = parser.parse_args() args = parser.parse_args()
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0" if args.train_minibatch_size is None:
# Default settings: Using train batch size as mini batch size
args.train_minibatch_size = args.train_batch_size
if args.inference_batch_size is None:
# Default settings: Using train batch size as inference batch size, sync every inference model every train step
args.inference_batch_size = args.train_batch_size
assert ( assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0 and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations" ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert (
args.train_minibatch_size <= args.train_batch_size
), "Train mini batch size must be less than or equals to train batch size"
ray.init(address="local", namespace="ray-example") if args.master_address is None:
# Default settings: Using single machine
ray.init(address="local", namespace="ray-example")
else:
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
if args.top_k is None:
if args.backend == "transformers":
args.top_k = 50
elif args.backend == "vllm":
args.top_k = -1
inference_model_config = dict(path=args.model) inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False) train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9) generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers": if args.backend == "transformers":
inference_model_config.update( inference_model_config.update(
@@ -73,7 +132,7 @@ if __name__ == "__main__":
) )
generate_config.update( generate_config.update(
dict( dict(
max_length=1024 + 512, max_length=args.max_new_tokens + args.max_prompt_tokens,
do_sample=True, do_sample=True,
max_new_tokens=None, max_new_tokens=None,
early_stopping=False, early_stopping=False,
@@ -81,31 +140,57 @@ if __name__ == "__main__":
) )
) )
elif args.backend == "vllm": elif args.backend == "vllm":
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True)) inference_model_config.update(
dict(
gpu_memory_utilization=0.7,
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
)
)
generate_config.update( generate_config.update(
dict( dict(
max_tokens=2048, max_tokens=args.max_new_tokens, # max new tokens
ignore_eos=True, ignore_eos=True,
include_stop_str_in_output=True, include_stop_str_in_output=True,
stop=["</answer>"], stop=["</answer>"],
) )
) )
else: else:
inference_model_config.update( raise ValueError(f"Unsupported backend: {args.backend}")
dict(
mem_fraction_static=0.6, if args.algo == "GRPO":
) # Default Settings
) grpo_config = {
generate_config.update( "lr": args.learning_rate,
dict( "train_microbatch_size": args.train_microbatch_size,
max_new_tokens=256, "beta": args.kl_coeff, # KL penalty coefficient
ignore_eos=True, "loss_variation": "sample_level",
) }
) elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"dynamic_batching": True,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"skip_threshold": 20.0,
"beta": 0, # no KL penalty for DAPO
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
launch_distributed( launch_distributed(
num_producers=args.num_inferencer, num_producers=args.num_inferencer,
num_proc_per_producer=1, num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_consumer_procs=args.num_trainers, num_consumer_procs=args.num_trainers,
num_episodes=1, num_episodes=1,
inference_batch_size=args.inference_batch_size, inference_batch_size=args.inference_batch_size,
@@ -113,15 +198,22 @@ if __name__ == "__main__":
train_batch_size=args.train_batch_size, train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size, train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size, train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt}, dataset_config={
"path": args.dataset,
"max_length": args.max_prompt_tokens,
"system_prompt": args.system_prompt,
},
dataloaders_config={}, dataloaders_config={},
inference_model_config=inference_model_config, inference_model_config=inference_model_config,
generate_config=generate_config, generate_config=generate_config,
num_generations=args.num_generations, num_generations=args.num_generations,
train_model_config=train_model_config, train_model_config=train_model_config,
plugin_config={}, # Default setting: zero. grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# currently not support tp/pp
# plugin_config={ # plugin_config={
# "pp_size": 2,
# "tp_size": 2, # "tp_size": 2,
# "microbatch_size": args.train_microbatch_size // 2, # "microbatch_size": args.train_microbatch_size // 2,
# "zero_stage": 0, # "zero_stage": 0,
@@ -129,7 +221,9 @@ if __name__ == "__main__":
# }, # for pp # }, # for pp
inference_backend=args.backend, inference_backend=args.backend,
master_addr="localhost", master_addr="localhost",
master_port=29506, master_port=args.master_port,
core_algo=args.algo, core_algo=args.algo,
project_name=args.project, project_name=args.project,
save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
) )
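As a closing illustration, the `-g/--num-generations` value configured here drives the group-relative advantage normalization in the consumer: rewards are grouped per prompt, normalized by the group mean and standard deviation, and broadcast back to each of the `num_generations` samples. A compact sketch of that computation, with the same 1e-4 stabilizer as in `GRPOConsumer.step`:

```python
import torch

def group_relative_advantages(reward: torch.Tensor, num_generations: int) -> torch.Tensor:
    """reward: flat [num_prompts * num_generations] reward vector."""
    group = reward.view(-1, num_generations)                      # [num_prompts, G]
    mean = group.mean(dim=1).repeat_interleave(num_generations)   # broadcast back per sample
    std = group.std(dim=1).repeat_interleave(num_generations)
    return ((reward - mean) / (std + 1e-4)).unsqueeze(-1)         # [num_prompts * G, 1]

rewards = torch.tensor([10.0, 0.0, 0.0, 0.0,      # prompt 1: one correct answer out of 4
                        10.0, 10.0, 10.0, 10.0])  # prompt 2: all correct
print(group_relative_advantages(rewards, num_generations=4).squeeze(-1))
# prompt 1: the correct sample gets a positive advantage, the others negative
# prompt 2: advantages are ~0; groups where every sample is correct (or wrong)
#           are exactly what DAPO's filter_range screens out of the loss
```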