Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 15:01:43 +00:00)

Merge pull request #6312 from hpcaitech/grpo-latest-dev
[feat] Move prompt-level-filtering to buffer side

Commit: ceb7065d6d
@@ -117,26 +117,102 @@ class BaseConsumer:
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
unbind_batch(
ray_broadcast_tensor_dict(
raw_batch = ray_broadcast_tensor_dict(
None, src=0, device=self.device, group_name=f"sync_data_{r}"
)
# calculate group reward and apply filtering. As only the filtered groups will be used for training (an incomplete set),
# we need to calculate the metrics before filtering here for logging
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
raw_batch_with_reward = self.calculate_reward(
{k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
)
raw_batch_with_reward = {
k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
for k, v in raw_batch_with_reward.items()
}
# [batch_size, num_generations] -> [batch_size]
reward = raw_batch_with_reward["reward"][:, :, 0]
format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
response_len = (
raw_batch_with_reward["response_idx"][:, :, 1]
- raw_batch_with_reward["response_idx"][:, :, 0]
+ 1
).type(torch.float32)
effective_group_mask = None
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
# filter the group based on the reward and accuracy
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
)
while len(self.buffer) >= self.dp_size * self.minibatch_size:
batches = self.buffer[
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
raw_batch_with_reward = unbind_batch(raw_batch_with_reward)  # List[Dict[str, torch.Tensor]]
for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
self.buffer.append(
[
(
group_with_reward
if effective_group_mask is None or effective_group_mask[group_idx]
else None
),
reward[group_idx],
format_acc[group_idx],
ans_acc[group_idx],
response_len[group_idx],
]
batch = bind_batch(batches)
batch = post_recv(batch)
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
)
if effective_group_mask is not None:
print(
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
)
# mapping the effective group to the raw group for indexing
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
print(
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
)

if excessive_prompts_idx is not None:
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
else:
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
# on each dp_rank, we use minibatch_size effective samples to form a batch
batches = [
self.buffer[effective_group_to_raw_group_mapping[i]]
for i in range(
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
)
]
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
# each mini-batch uses the first self.dp_size * minibatch_size effective samples
raw_mini_batches = self.buffer[
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
]  # include the last effective sample
raw_mini_batches_metric_dict = {
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
}
batch = bind_batch([t[0] for t in batches])
batch = post_recv(batch)
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
self.buffer = self.buffer[
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
]
# recalculate the effective group to raw group mapping
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(self.buffer)):
if self.buffer[buffer_idx][0] is not None:
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
buffer_idx
)
assert (
len(effective_group_to_raw_group_mapping)
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
)
if loss is not None:
pbar.set_postfix({"loss": loss})
i += 1
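For reference, a minimal standalone sketch of the buffer-side prompt-level filtering this hunk introduces: out-of-range groups are kept in the buffer as None placeholders (so their metrics can still be logged), and an effective-to-raw index mapping is rebuilt to form mini-batches. This is not part of the patch; names, the buffer layout, and the filter range are illustrative.

# Hypothetical, self-contained sketch of buffer-side prompt-level filtering (not the repo API).
import torch

def filter_groups(ans_acc: torch.Tensor, filter_range=(0.01, 0.99)) -> torch.Tensor:
    # ans_acc: [batch_size, num_generations]; keep a prompt (group) only if its mean
    # accuracy lies strictly inside the range, i.e. it is neither all-wrong nor all-correct.
    group_mean = ans_acc.mean(dim=1)  # [batch_size]
    return (group_mean > filter_range[0]) & (group_mean < filter_range[1])

def append_to_buffer(buffer, groups, ans_acc, mask):
    # Filtered-out groups stay in the buffer as None placeholders so that per-group
    # metrics (here just ans_acc) can still be logged for the raw batch later.
    for idx, group in enumerate(groups):
        buffer.append([group if bool(mask[idx]) else None, ans_acc[idx]])

def effective_mapping(buffer):
    # Map effective index (used to slice mini-batches) -> raw buffer index.
    return {eff: raw for eff, raw in enumerate(i for i, entry in enumerate(buffer) if entry[0] is not None)}

ans_acc = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # all correct -> dropped
                        [0.0, 0.0, 0.0, 0.0],   # all wrong   -> dropped
                        [1.0, 0.0, 1.0, 0.0]])  # mixed       -> kept
mask = filter_groups(ans_acc)
buffer = []
append_to_buffer(buffer, ["group0", "group1", "group2"], ans_acc, mask)
print(mask.tolist(), effective_mapping(buffer))  # [False, False, True] {0: 2}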
@@ -1,5 +1,5 @@
from contextlib import nullcontext
from typing import Any, Optional
from typing import Any, Dict, Optional

import ray
import torch
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -72,21 +72,18 @@ class GRPOConsumer(BaseConsumer):
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
self.accum_count = 0
self.generate_config = generate_config
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
self.effective_prompt_count = 0
self.total_sample_count = 0
self.overlength_samples = 0
self.total_overlength_samples = 0
self.project_name = project_name
self.run_name = run_name
self.wandb_group_name = wandb_group_name
@@ -122,16 +119,7 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if grpo_config.get("reward_fn_type") == "think_answer_tags"
else None
)
response_format_tags = grpo_config.get("response_format_tags", None)
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
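The hard-coded tag dictionary now travels through grpo_config, and the consumer simply reads it back. A hedged sketch of the round trip; build_grpo_config is a hypothetical helper, the key names follow the diff.

# Illustrative only: pass response_format_tags via grpo_config and read it on the consumer side.
def build_grpo_config(reward_fn_type: str) -> dict:
    tags = (
        {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        if reward_fn_type == "think_answer_tags"
        else None
    )
    return {"reward_fn_type": reward_fn_type, "response_format_tags": tags}

grpo_config = build_grpo_config("think_answer_tags")
response_format_tags = grpo_config.get("response_format_tags", None)  # what the consumer now does
print(sorted(response_format_tags))  # ['answer_end', 'answer_start', 'think_end', 'think_start']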
@@ -187,24 +175,21 @@ class GRPOConsumer(BaseConsumer):
Format:
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""

# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

reward_group = self.reward_model(
data["input_ids"],
gt_answer=data["gt_answer"],
response_idx=data["response_idx"],
)

reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
reward = data["reward"].view((-1))
format_acc = data["format_acc"].view((-1))
ans_acc = data["ans_acc"].view((-1))

# [minibatch_size, num_generations]
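Since rewards are now computed on the buffer side, step() only reshapes the incoming tensors and reads the precomputed columns back, while the raw_train_mini_batch_* lists are passed through untouched for logging. A small sketch under assumed shapes (not the repository code):

# Sketch with assumed shapes: flatten the generation dim, then read precomputed reward columns.
import torch

minibatch_size, num_generations, seq_len = 2, 4, 8
kwargs = {
    "input_ids": torch.zeros(minibatch_size, num_generations, seq_len, dtype=torch.long),
    "reward": torch.rand(minibatch_size, num_generations, 1),
    "format_acc": torch.randint(0, 2, (minibatch_size, num_generations, 1)).float(),
    "ans_acc": torch.randint(0, 2, (minibatch_size, num_generations, 1)).float(),
    "raw_train_mini_batch_reward": [torch.rand(num_generations)],  # logging-only, kept as a list
}

# [minibatch, num_gen, ...] -> [minibatch * num_gen, ...], skipping the logging-only keys
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
reward = data["reward"].view(-1)       # [minibatch * num_gen]
format_acc = data["format_acc"].view(-1)
ans_acc = data["ans_acc"].view(-1)
print(reward.shape)                    # torch.Size([8])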
@@ -216,67 +201,35 @@ class GRPOConsumer(BaseConsumer):
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [minibatch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all samples get full score) or too low (all samples get 0 score),
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)

# [minibatch_size x num_of_generation]
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
)
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()

# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
old_loss_mask = loss_mask.clone()
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)

self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
self.overlength_samples = all_reduce_sum(
torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
# filter out samples with reward outside the range
# if dynamic batching is enabled, we filter out out of range groups before training
group_ans_acc_mean = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
)
self.total_overlength_samples += self.overlength_samples.item()

prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)

# [minibatch_size] -> calculate the number of effective prompts
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
self.effective_prompt_count += effective_prompts.item()
excessive_prompts_idx = None
loss_mask = torch.logical_and(
loss_mask,
torch.logical_and(
group_ans_acc_mean > self.filter_range[0],
group_ans_acc_mean < self.filter_range[1],
),
)
self.effective_prompt_count += group_reward.size(0) * self.dp_size

mean_kl, mean_loss = [], []

if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size

if excessive_prompts > 0:
excessive_prompts_per_rank = excessive_prompts // self.dp_size
# Only count excessive prompts if they are greater than 1 per rank.
# TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask)
# Make sure the indices are not empty.
if true_indices.numel() > 0:
true_indices = true_indices.squeeze(-1)
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False

for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
else:
excessive_prompts_idx = torch.empty([0])
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
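A compact sketch of the group-normalized advantage computation and the in-step reward-range mask that this hunk retains for the non-dynamic-batching path. Shapes, the filter range, and the 1e-4 std floor mirror the diff; everything else is an assumption for illustration.

# Sketch: group-normalized advantages and the reward-range loss mask (assumed shapes).
import torch

num_prompts, num_generations = 3, 4
reward = torch.rand(num_prompts * num_generations)
ans_acc = torch.randint(0, 2, (num_prompts * num_generations,)).float()
filter_range = (0.01, 0.99)

group_reward = reward.view(num_prompts, num_generations)
reward_mean = group_reward.mean(dim=1).repeat_interleave(num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)  # [N*G, 1]

# Without dynamic batching, out-of-range groups are masked here instead of being dropped
# from the buffer: all-correct or all-wrong prompts contribute no gradient signal.
group_ans_acc_mean = ans_acc.view(num_prompts, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
loss_mask = (group_ans_acc_mean > filter_range[0]) & (group_ans_acc_mean < filter_range[1])
print(advantages.shape, loss_mask.float().mean().item())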
@@ -286,12 +239,10 @@ class GRPOConsumer(BaseConsumer):
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
pbar.set_postfix(
{
"Global Step": self.global_step,
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
}
)
@@ -483,22 +434,16 @@ class GRPOConsumer(BaseConsumer):
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_acc.add_(format_acc.data)
self.accum_ans_acc.add_(ans_acc.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
# no need to run all-reduce as raw_train_batch_* are not split across dp ranks
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.total_sample_count = 0
self.total_overlength_samples = 0
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
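Sample utilization no longer needs an all-reduce: every rank already holds the full raw batch lists, so the denominator is just the number of raw prompt groups times the generation count. A small sketch under that assumption (values illustrative):

# Sketch: sample utilization without an all-reduce, assuming one list entry per raw prompt group.
import torch

num_generations = 4
raw_train_batch_reward = [torch.rand(num_generations) for _ in range(6)]  # 6 raw prompts
effective_sample_count = 4 * num_generations                              # e.g. 4 prompts survived filtering

total_samples = len(raw_train_batch_reward) * num_generations
sample_utilization = effective_sample_count / len(raw_train_batch_reward) / num_generations
assert abs(sample_utilization - effective_sample_count / total_samples) < 1e-9
print(f"{sample_utilization:.2f}")  # 0.67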
@@ -506,27 +451,39 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
overlength_samples_ratio = (
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
)  # not an exact figure, but a close estimate
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []
self.raw_train_batch_response_len = []
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
f"Reward: {raw_batch_reward_mean:.4f}",
f"format Reward: {raw_batch_format_acc_mean:.4f}",
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
f"Response Length: {raw_batch_response_len_mean:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"metrics/reward": raw_batch_reward_mean,
"metrics/format_acc": raw_batch_format_acc_mean,
"metrics/ans_acc": raw_batch_ans_acc_mean,
"metrics/response_length": raw_batch_response_len_mean,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"train/percentage_overlength_samples": overlength_samples_percentage,
"train/overlength_samples_ratio": overlength_samples_ratio,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0:
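The overlength ratio is now estimated from the logged response lengths: any response whose length reaches the response-length cap is treated as truncated, which the diff itself notes is a close estimate rather than an exact count. A tiny sketch, with max_response_len standing in for action_mask.size(-1):

# Sketch of the overlength-ratio estimate (max_response_len is an assumed stand-in).
import torch

max_response_len = 512
raw_batch_response_len = torch.tensor([120.0, 512.0, 300.0, 512.0])
overlength_samples_ratio = (raw_batch_response_len >= max_response_len).float().mean().item()
print(overlength_samples_ratio)  # 0.5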
@@ -534,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
if self.wandb_run is not None:
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_ans_acc.zero_()
self.accum_format_acc.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0

if excessive_prompts_idx is not None:
# All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx
return loss_scalar
else:
return None, excessive_prompts_idx
return None

def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
"""
Calculate the group reward for the given rollout group.

Args:
rollout_group (Dict[str, Any]):
a group of samples generated by the model from the same prompt
containing the following keys:
"input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
"attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
"action_mask": torch.Tensor, [num_of_generation, response_length]
"action_log_probs": torch.Tensor, [num_of_generation, response_length]
"response_idx": int, torch.Tensor, [num_of_generation, 2]
"gt_answer": torch.Tensor, [num_of_generation, 128]
"temperature": torch.Tensor, [] (scalar)

Returns:
Dict[str, Any]: The new group data with calculated reward.
"""
reward_model_output = self.reward_model(
rollout["input_ids"],
gt_answer=rollout["gt_answer"],
response_idx=rollout["response_idx"],
)
# [num_of_generation]
reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)

rollout["reward"] = reward.view((-1, 1))
rollout["format_acc"] = format_acc.view((-1, 1))
rollout["ans_acc"] = ans_acc.view((-1, 1))
return rollout

def state_dict(self):
self.policy_model._force_wait_all_gather()
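A hypothetical standalone version of the new calculate_reward helper, with a stub in place of the consumer's VerifiableReward instance, to show how the per-sequence (reward, format_acc, ans_acc) triples end up attached to the rollout dict before it is buffered. Names other than the dict keys are assumptions.

# Standalone sketch of calculate_reward with a stub reward model (not the real VerifiableReward).
import torch

def fake_reward_model(input_ids, gt_answer=None, response_idx=None):
    # Returns one (reward, format_acc, ans_acc) triple per sequence.
    return [(1.0, 1.0, 0.0) for _ in range(input_ids.size(0))]

def calculate_reward(rollout: dict) -> dict:
    out = fake_reward_model(rollout["input_ids"],
                            gt_answer=rollout.get("gt_answer"),
                            response_idx=rollout.get("response_idx"))
    device = rollout["input_ids"].device
    rollout["reward"] = torch.tensor([v[0] for v in out], device=device).view(-1, 1)
    rollout["format_acc"] = torch.tensor([v[1] for v in out], device=device).view(-1, 1)
    rollout["ans_acc"] = torch.tensor([v[2] for v in out], device=device).view(-1, 1)
    return rollout

rollout = {"input_ids": torch.zeros(8, 16, dtype=torch.long)}
print(calculate_reward(rollout)["reward"].shape)  # torch.Size([8, 1])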
@@ -100,6 +100,7 @@ def launch_distributed(
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
response_format_tags=grpo_config["response_format_tags"],
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
@@ -46,6 +46,7 @@ class BaseProducer:
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
@@ -148,6 +149,7 @@ class BaseProducer:
self.evaluation_function = boxed_math_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags
else:
print("No eval dataset provided, skip eval")
self.device = get_current_device()
@@ -217,6 +219,7 @@ class BaseProducer:
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
tags=self.response_format_tags,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
@@ -245,11 +248,10 @@ class BaseProducer:
self.eval_mode = False
self.latest_eval_step = self.consumer_global_step
outputs = self.rollout(**batch)

print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
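The producer now attaches the sampling temperature as a per-sequence tensor before the payload is broadcast to consumers, so the consumer can log it and exclude it from reshaping. A sketch with stand-ins for the real generate_config and rollout outputs:

# Sketch: attach the sampling temperature to the rollout payload before broadcasting.
import torch

generate_config = {"temperature": 0.8}                       # stand-in for self.model.generate_config
outputs = {"input_ids": torch.zeros(4, 16, dtype=torch.long)}  # stand-in for self.rollout(**batch)
outputs["temperature"] = torch.tensor(
    [generate_config["temperature"]] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
print(outputs["temperature"])  # tensor([0.8000, 0.8000, 0.8000, 0.8000])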
@@ -324,6 +326,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
response_format_tags=None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
@@ -349,6 +352,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
response_format_tags=response_format_tags,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
@@ -121,6 +121,34 @@ if __name__ == "__main__":
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
parser.add_argument(
"-tp",
"--tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-pp",
"--pipeline-parallel-size",
type=int,
default=1,
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-zero",
"--zero-stage",
type=int,
default=0,
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
)
parser.add_argument(
"-ptp",
"--producer-tensor-parallel-size",
type=int,
default=1,
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
)
args = parser.parse_args()

if args.train_minibatch_size is None:
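For reference, a minimal argparse sketch that mirrors the new parallelism flags and their defaults from this hunk; the parsed attribute names follow argparse's dash-to-underscore convention, and the hard-coded argument list is only a usage illustration.

# Sketch mirroring the new CLI flags (defaults taken from the diff; argv is illustrative).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-tp", "--tensor-parallel-size", type=int, default=1)
parser.add_argument("-pp", "--pipeline-parallel-size", type=int, default=1)
parser.add_argument("-zero", "--zero-stage", type=int, default=0)
parser.add_argument("-ptp", "--producer-tensor-parallel-size", type=int, default=1)
args = parser.parse_args(["-tp", "2", "-pp", "2", "-zero", "1", "-ptp", "2"])
print(args.tensor_parallel_size, args.pipeline_parallel_size,
      args.zero_stage, args.producer_tensor_parallel_size)  # 2 2 1 2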
@@ -134,8 +162,8 @@ if __name__ == "__main__":
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
assert (
args.train_minibatch_size <= args.train_batch_size
), "Train mini batch size must be less than or equals to train batch size"
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"

if args.master_address is None:
# Default settings: Using single machine
@@ -178,7 +206,7 @@ if __name__ == "__main__":
enforce_eager=True,
enable_chunked_prefill=True,
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
tensor_parallel_size=1,
tensor_parallel_size=args.producer_tensor_parallel_size,
)
)
generate_config.update(
@@ -203,6 +231,16 @@ if __name__ == "__main__":
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
elif args.algo == "DAPO":
# DAPO variant settings
@@ -222,13 +260,23 @@ if __name__ == "__main__":
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
"response_format_tags": (
{
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
if args.reward_type == "think_answer_tags"
else None
),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")

launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
num_consumer_procs=args.num_trainers,
num_episodes=args.num_episodes,
inference_batch_size=args.inference_batch_size,
@@ -247,17 +295,14 @@ if __name__ == "__main__":
train_model_config=train_model_config,
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
}, # for zero
# plugin_config={
#   "tp_size": 2,
#   "pp_size": 2,
#   "microbatch_size": max(
#       1, args.train_microbatch_size // 2
#   ), # microbatch size should be set to train_microbatch_size // pp_size
#   "zero_stage": 0,
#   "max_norm": 1.0,
# }, # for pp, tp
"tp_size": args.tensor_parallel_size,
"pp_size": args.pipeline_parallel_size,
"microbatch_size": max(
1, args.train_microbatch_size // args.pipeline_parallel_size
), # microbatch size should be set to train_microbatch_size // pp_size
"zero_stage": args.zero_stage,
"max_norm": 1.0,
}, # for pp, tp
inference_backend=args.backend,
master_addr="localhost",
master_port=args.master_port,
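The hard-coded plugin_config is now derived from the CLI arguments. A short sketch of the resulting dictionary under one possible setting (the concrete values here are assumptions, the keys and the microbatch rule follow the diff):

# Sketch: plugin_config derived from CLI args (illustrative values).
tensor_parallel_size, pipeline_parallel_size, zero_stage = 2, 2, 1
train_microbatch_size = 8

plugin_config = {
    "tp_size": tensor_parallel_size,
    "pp_size": pipeline_parallel_size,
    # microbatch size should be train_microbatch_size // pp_size, floored at 1
    "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),
    "zero_stage": zero_stage,
    "max_norm": 1.0,
}
print(plugin_config["microbatch_size"])  # 4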