Merge pull request #6312 from hpcaitech/grpo-latest-dev

[feat] Move prompt-level-filtering to buffer side
This commit is contained in:
YeAnbang 2025-06-05 15:51:38 +08:00 committed by GitHub
commit ceb7065d6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 252 additions and 144 deletions

View File

@ -117,26 +117,102 @@ class BaseConsumer:
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
unbind_batch(
ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
)
raw_batch = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
while len(self.buffer) >= self.dp_size * self.minibatch_size:
batches = self.buffer[
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
]
batch = bind_batch(batches)
batch = post_recv(batch)
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
# we need to calculate the metrics before filtering here for logging
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
raw_batch_with_reward = self.calculate_reward(
{k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
)
raw_batch_with_reward = {
k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
for k, v in raw_batch_with_reward.items()
}
# [batch_size, num_generations] -> [batch_size]
reward = raw_batch_with_reward["reward"][:, :, 0]
format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
response_len = (
raw_batch_with_reward["response_idx"][:, :, 1]
- raw_batch_with_reward["response_idx"][:, :, 0]
+ 1
).type(torch.float32)
effective_group_mask = None
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
# filter the group based on the reward and accuracy
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
)
raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
self.buffer.append(
[
(
group_with_reward
if effective_group_mask is None or effective_group_mask[group_idx]
else None
),
reward[group_idx],
format_acc[group_idx],
ans_acc[group_idx],
response_len[group_idx],
]
)
if effective_group_mask is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)
if excessive_prompts_idx is not None:
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
else:
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch use the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
] # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1

View File

@ -1,5 +1,5 @@
from contextlib import nullcontext
from typing import Any, Optional
from typing import Any, Dict, Optional
import ray
import torch
@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -72,21 +72,18 @@ class GRPOConsumer(BaseConsumer):
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
self.accum_count = 0
self.generate_config = generate_config
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
self.effective_prompt_count = 0
self.total_sample_count = 0
self.overlength_samples = 0
self.total_overlength_samples = 0
self.project_name = project_name
self.run_name = run_name
self.wandb_group_name = wandb_group_name
@ -122,16 +119,7 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if grpo_config.get("reward_fn_type") == "think_answer_tags"
else None
)
response_format_tags = grpo_config.get("response_format_tags", None)
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
@ -187,24 +175,21 @@ class GRPOConsumer(BaseConsumer):
Format:
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward_group = self.reward_model(
data["input_ids"],
gt_answer=data["gt_answer"],
response_idx=data["response_idx"],
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))
# [minibatch_size, num_generations]
@ -216,67 +201,35 @@ class GRPOConsumer(BaseConsumer):
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)
# [minibatch_size x num_of_generation]
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
)
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)
self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
self.overlength_samples = all_reduce_sum(
torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
# filter out samples with reward outside the range
# if dynamic batching is enabled, we filter out out of range groups before training
group_ans_acc_mean = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
)
self.total_overlength_samples += self.overlength_samples.item()
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
# [minibatch_size] -> calculate the number of effective prompts
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
self.effective_prompt_count += effective_prompts.item()
excessive_prompts_idx = None
loss_mask = torch.logical_and(
loss_mask,
torch.logical_and(
group_ans_acc_mean > self.filter_range[0],
group_ans_acc_mean < self.filter_range[1],
),
)
self.effective_prompt_count += group_reward.size(0) * self.dp_size
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
if excessive_prompts > 0:
excessive_prompts_per_rank = excessive_prompts // self.dp_size
# Only count excessive prompts if they are greater than 1 per rank.
# TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask)
# Make sure the indices are not empty.
if true_indices.numel() > 0:
true_indices = true_indices.squeeze(-1)
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
else:
excessive_prompts_idx = torch.empty([0])
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
@ -286,12 +239,10 @@ class GRPOConsumer(BaseConsumer):
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
pbar.set_postfix(
{
"Global Step": self.global_step,
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
}
)
@ -483,22 +434,16 @@ class GRPOConsumer(BaseConsumer):
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_acc.add_(format_acc.data)
self.accum_ans_acc.add_(ans_acc.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.total_sample_count = 0
self.total_overlength_samples = 0
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@ -506,27 +451,39 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
overlength_samples_ratio = (
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
) # not an exact figure, but a close estimate
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
f"Reward: {raw_batch_reward_mean:.4f}",
f"format Reward: {raw_batch_format_acc_mean:.4f}",
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"metrics/reward": raw_batch_reward_mean,
"metrics/format_acc": raw_batch_format_acc_mean,
"metrics/ans_acc": raw_batch_ans_acc_mean,
"metrics/response_length": raw_batch_response_len_mean,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/percentage_overlength_samples": overlength_samples_percentage,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0:
@ -534,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
if self.wandb_run is not None:
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_ans_acc.zero_()
self.accum_format_acc.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
if excessive_prompts_idx is not None:
# All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx
return loss_scalar
else:
return None, excessive_prompts_idx
return None
def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate the group reward for the given rollout group.
Args:
rollout_group (Dict[str, Any]):
a group of samples generated by the model from the same prompt
contain the following keys:
"input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
"attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
"action_mask": torch.Tensor, [num_of_generation, response_length]
"action_log_probs": torch.Tensor, [num_of_generation, response_length]
"response_idx": int, torch.Tensor, [num_of_generation, 2]
"gt_answer": torch.Tensor, [num_of_generation, 128]
"temperature": torch.Tensor, [] (scalar)
Returns:
Dict[str, Any]: The new group data with calculated reward.
"""
reward_model_output = self.reward_model(
rollout["input_ids"],
gt_answer=rollout["gt_answer"],
response_idx=rollout["response_idx"],
)
# [num_of_generation]
reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
rollout["reward"] = reward.view((-1, 1))
rollout["format_acc"] = format_acc.view((-1, 1))
rollout["ans_acc"] = ans_acc.view((-1, 1))
return rollout
def state_dict(self):
self.policy_model._force_wait_all_gather()

View File

@ -100,6 +100,7 @@ def launch_distributed(
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
response_format_tags=grpo_config["response_format_tags"],
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,

View File

@ -46,6 +46,7 @@ class BaseProducer:
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
@ -148,6 +149,7 @@ class BaseProducer:
self.evaluation_function = boxed_math_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()
@ -217,6 +219,7 @@ class BaseProducer:
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
tags=self.response_format_tags,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
@ -245,11 +248,10 @@ class BaseProducer:
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@ -324,6 +326,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
@ -349,6 +352,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
response_format_tags=response_format_tags,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,

View File

@ -121,6 +121,34 @@ if __name__ == "__main__":
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@ -134,8 +162,8 @@ if __name__ == "__main__":
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert (
args.train_minibatch_size <= args.train_batch_size
), "Train mini batch size must be less than or equals to train batch size"
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
if args.master_address is None:
# Default settings: Using single machine
@ -178,7 +206,7 @@ if __name__ == "__main__":
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
tensor_parallel_size=args.producer_tensor_parallel_size,
)
)
generate_config.update(
@ -203,6 +231,16 @@ if __name__ == "__main__":
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "DAPO":
# DAPO variant settings
@ -222,13 +260,23 @@ if __name__ == "__main__":
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
@ -247,17 +295,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# plugin_config={
# "tp_size": 2,
# "pp_size": 2,
# "microbatch_size": max(
# 1, args.train_microbatch_size // 2
# ), # microbatch size should be set to train_microbatch_size // pp_size
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp, tp
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,