mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 15:01:43 +00:00
Merge pull request #6312 from hpcaitech/grpo-latest-dev
[feat] Move prompt-level-filtering to buffer side
This commit is contained in:
commit
ceb7065d6d
@ -117,26 +117,102 @@ class BaseConsumer:
|
|||||||
# receive data from producers
|
# receive data from producers
|
||||||
for r in range(self.num_producers):
|
for r in range(self.num_producers):
|
||||||
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
|
||||||
self.buffer.extend(
|
raw_batch = ray_broadcast_tensor_dict(
|
||||||
unbind_batch(
|
|
||||||
ray_broadcast_tensor_dict(
|
|
||||||
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
None, src=0, device=self.device, group_name=f"sync_data_{r}"
|
||||||
)
|
)
|
||||||
|
# calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
|
||||||
|
# we need to calculate the metrics before filtering here for logging
|
||||||
|
# [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
|
||||||
|
raw_batch_with_reward = self.calculate_reward(
|
||||||
|
{k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
|
||||||
)
|
)
|
||||||
|
raw_batch_with_reward = {
|
||||||
|
k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
|
||||||
|
for k, v in raw_batch_with_reward.items()
|
||||||
|
}
|
||||||
|
# [batch_size, num_generations] -> [batch_size]
|
||||||
|
reward = raw_batch_with_reward["reward"][:, :, 0]
|
||||||
|
format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
|
||||||
|
ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
|
||||||
|
response_len = (
|
||||||
|
raw_batch_with_reward["response_idx"][:, :, 1]
|
||||||
|
- raw_batch_with_reward["response_idx"][:, :, 0]
|
||||||
|
+ 1
|
||||||
|
).type(torch.float32)
|
||||||
|
effective_group_mask = None
|
||||||
|
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
|
||||||
|
# filter the group based on the reward and accuracy
|
||||||
|
group_ans_acc_mean = ans_acc.mean(dim=1)
|
||||||
|
effective_group_mask = torch.logical_and(
|
||||||
|
group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
|
||||||
)
|
)
|
||||||
while len(self.buffer) >= self.dp_size * self.minibatch_size:
|
raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
|
||||||
batches = self.buffer[
|
for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
|
||||||
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
|
self.buffer.append(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
group_with_reward
|
||||||
|
if effective_group_mask is None or effective_group_mask[group_idx]
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
reward[group_idx],
|
||||||
|
format_acc[group_idx],
|
||||||
|
ans_acc[group_idx],
|
||||||
|
response_len[group_idx],
|
||||||
]
|
]
|
||||||
batch = bind_batch(batches)
|
)
|
||||||
batch = post_recv(batch)
|
if effective_group_mask is not None:
|
||||||
loss, excessive_prompts_idx = self.step(i, pbar, **batch)
|
print(
|
||||||
|
f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
|
||||||
|
)
|
||||||
|
# mapping the effective group to the raw group for indexing
|
||||||
|
effective_group_to_raw_group_mapping = {}
|
||||||
|
for buffer_idx in range(len(self.buffer)):
|
||||||
|
if self.buffer[buffer_idx][0] is not None:
|
||||||
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||||
|
buffer_idx
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
|
||||||
|
)
|
||||||
|
|
||||||
if excessive_prompts_idx is not None:
|
while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
|
||||||
excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
|
# on each dp_rank, we use minibatch_size effective samples to form a batch
|
||||||
self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
|
batches = [
|
||||||
else:
|
self.buffer[effective_group_to_raw_group_mapping[i]]
|
||||||
self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
|
for i in range(
|
||||||
|
self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
|
||||||
|
)
|
||||||
|
]
|
||||||
|
# every dp_rank will receive a complete mini-batch, no need to sync within step() later
|
||||||
|
# each mini-batch use the first self.dp_size * minibatch_size effective samples
|
||||||
|
raw_mini_batches = self.buffer[
|
||||||
|
: effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
|
||||||
|
] # include the last effective sample
|
||||||
|
raw_mini_batches_metric_dict = {
|
||||||
|
"raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
|
||||||
|
"raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
|
||||||
|
}
|
||||||
|
batch = bind_batch([t[0] for t in batches])
|
||||||
|
batch = post_recv(batch)
|
||||||
|
loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
|
||||||
|
self.buffer = self.buffer[
|
||||||
|
effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
|
||||||
|
]
|
||||||
|
# recalculate the effective group to raw group mapping
|
||||||
|
effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
|
||||||
|
effective_group_to_raw_group_mapping = {}
|
||||||
|
for buffer_idx in range(len(self.buffer)):
|
||||||
|
if self.buffer[buffer_idx][0] is not None:
|
||||||
|
effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
|
||||||
|
buffer_idx
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(effective_group_to_raw_group_mapping)
|
||||||
|
== effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
|
||||||
|
)
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
pbar.set_postfix({"loss": loss})
|
pbar.set_postfix({"loss": loss})
|
||||||
i += 1
|
i += 1
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import Any, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
|
|||||||
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
|
||||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||||
from coati.distributed.utils import calc_action_log_probs
|
from coati.distributed.utils import calc_action_log_probs
|
||||||
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
|
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||||
@ -72,21 +72,18 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.policy_model.gradient_checkpointing_enable()
|
self.policy_model.gradient_checkpointing_enable()
|
||||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||||
self.accum_loss = torch.zeros(1, device=self.device)
|
self.accum_loss = torch.zeros(1, device=self.device)
|
||||||
self.accum_reward = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_kl = torch.zeros(1, device=self.device)
|
self.accum_kl = torch.zeros(1, device=self.device)
|
||||||
self.accum_format_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_ans_acc = torch.zeros(1, device=self.device)
|
|
||||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
self.raw_train_batch_reward = []
|
||||||
|
self.raw_train_batch_format_acc = []
|
||||||
|
self.raw_train_batch_ans_acc = []
|
||||||
|
self.raw_train_batch_response_len = []
|
||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
self.generate_config = generate_config
|
self.generate_config = generate_config
|
||||||
self.grpo_config = grpo_config
|
self.grpo_config = grpo_config
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.total_sample_count = 0
|
|
||||||
self.overlength_samples = 0
|
|
||||||
self.total_overlength_samples = 0
|
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
self.run_name = run_name
|
self.run_name = run_name
|
||||||
self.wandb_group_name = wandb_group_name
|
self.wandb_group_name = wandb_group_name
|
||||||
@ -122,16 +119,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||||
)
|
)
|
||||||
# Initialize verifiable reward.
|
# Initialize verifiable reward.
|
||||||
response_format_tags = (
|
response_format_tags = grpo_config.get("response_format_tags", None)
|
||||||
{
|
|
||||||
"think_start": {"text": "<think>", "num_occur": 1},
|
|
||||||
"think_end": {"text": "</think>", "num_occur": 1},
|
|
||||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
|
||||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
|
||||||
}
|
|
||||||
if grpo_config.get("reward_fn_type") == "think_answer_tags"
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
reward_model_kwargs = {
|
reward_model_kwargs = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in grpo_config.items()
|
for k, v in grpo_config.items()
|
||||||
@ -187,24 +175,21 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
Format:
|
Format:
|
||||||
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
||||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
|
||||||
|
self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
|
||||||
|
self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
|
||||||
|
self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
|
||||||
|
self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
|
||||||
action_mask = data["action_mask"]
|
action_mask = data["action_mask"]
|
||||||
num_action = action_mask.shape[1]
|
num_action = action_mask.shape[1]
|
||||||
old_action_log_probs = data["action_log_probs"]
|
old_action_log_probs = data["action_log_probs"]
|
||||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||||
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
|
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||||
|
|
||||||
reward_group = self.reward_model(
|
reward = data["reward"].view((-1))
|
||||||
data["input_ids"],
|
format_acc = data["format_acc"].view((-1))
|
||||||
gt_answer=data["gt_answer"],
|
ans_acc = data["ans_acc"].view((-1))
|
||||||
response_idx=data["response_idx"],
|
|
||||||
)
|
|
||||||
|
|
||||||
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
|
|
||||||
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
|
||||||
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
|
||||||
|
|
||||||
# [minibatch_size, num_generations]
|
# [minibatch_size, num_generations]
|
||||||
|
|
||||||
@ -216,67 +201,35 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||||
# [minibatch_size x num_generations]
|
# [minibatch_size x num_generations]
|
||||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
|
||||||
group_ans_acc = (
|
|
||||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
|
|
||||||
)
|
|
||||||
# [minibatch_size x num_of_generation]
|
# [minibatch_size x num_of_generation]
|
||||||
loss_mask = (
|
loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
||||||
torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
|
||||||
if self.filter_range is None
|
|
||||||
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
|
|
||||||
)
|
|
||||||
|
|
||||||
# filter out overlength samples
|
# filter out overlength samples
|
||||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||||
old_loss_mask = loss_mask.clone()
|
|
||||||
loss_mask = torch.logical_and(
|
loss_mask = torch.logical_and(
|
||||||
loss_mask,
|
loss_mask,
|
||||||
action_mask[:, -1] == False,
|
action_mask[:, -1] == False,
|
||||||
)
|
)
|
||||||
|
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
|
||||||
self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
|
# filter out samples with reward outside the range
|
||||||
self.overlength_samples = all_reduce_sum(
|
# if dynamic batching is enabled, we filter out out of range groups before training
|
||||||
torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
|
group_ans_acc_mean = (
|
||||||
|
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
|
||||||
)
|
)
|
||||||
self.total_overlength_samples += self.overlength_samples.item()
|
loss_mask = torch.logical_and(
|
||||||
|
loss_mask,
|
||||||
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
|
torch.logical_and(
|
||||||
|
group_ans_acc_mean > self.filter_range[0],
|
||||||
# [minibatch_size] -> calculate the number of effective prompts
|
group_ans_acc_mean < self.filter_range[1],
|
||||||
effective_prompts_mask = prompt_level_mask.any(dim=1)
|
),
|
||||||
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
|
)
|
||||||
self.effective_prompt_count += effective_prompts.item()
|
self.effective_prompt_count += group_reward.size(0) * self.dp_size
|
||||||
excessive_prompts_idx = None
|
|
||||||
|
|
||||||
mean_kl, mean_loss = [], []
|
mean_kl, mean_loss = [], []
|
||||||
|
|
||||||
if self.grpo_config.get("dynamic_batching", True):
|
if self.grpo_config.get("dynamic_batching", True):
|
||||||
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
|
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
|
||||||
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
|
|
||||||
|
|
||||||
if excessive_prompts > 0:
|
|
||||||
excessive_prompts_per_rank = excessive_prompts // self.dp_size
|
|
||||||
# Only count excessive prompts if they are greater than 1 per rank.
|
|
||||||
# TODO: customize excessive prompts calculation.
|
|
||||||
if excessive_prompts_per_rank != 0:
|
|
||||||
# Mask excessive prompts to False
|
|
||||||
true_indices = torch.nonzero(effective_prompts_mask)
|
|
||||||
# Make sure the indices are not empty.
|
|
||||||
if true_indices.numel() > 0:
|
|
||||||
true_indices = true_indices.squeeze(-1)
|
|
||||||
if excessive_prompts_per_rank <= len(true_indices):
|
|
||||||
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
|
|
||||||
else:
|
|
||||||
excessive_prompts_idx = true_indices
|
|
||||||
effective_prompts_mask[excessive_prompts_idx] = False
|
|
||||||
|
|
||||||
for mask_idx in range(len(effective_prompts_mask)):
|
|
||||||
if effective_prompts_mask[mask_idx] == False:
|
|
||||||
# Update loss mask.
|
|
||||||
loss_mask[mask_idx] = False
|
|
||||||
else:
|
|
||||||
excessive_prompts_idx = torch.empty([0])
|
|
||||||
else:
|
else:
|
||||||
# If dynamic batching is disabled, we need to use all samples for training.
|
# If dynamic batching is disabled, we need to use all samples for training.
|
||||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||||
@ -286,12 +239,10 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||||
self.effective_sample_count += effective_samples.item()
|
self.effective_sample_count += effective_samples.item()
|
||||||
self.total_sample_count += total_samples.item()
|
|
||||||
pbar.set_postfix(
|
pbar.set_postfix(
|
||||||
{
|
{
|
||||||
"Global Step": self.global_step,
|
"Global Step": self.global_step,
|
||||||
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
|
"Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
|
||||||
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -483,22 +434,16 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||||
self.accum_reward.add_(reward.data)
|
|
||||||
self.accum_format_acc.add_(format_acc.data)
|
|
||||||
self.accum_ans_acc.add_(ans_acc.data)
|
|
||||||
self.accum_advantages.add_(advantages.data)
|
self.accum_advantages.add_(advantages.data)
|
||||||
self.accum_response_length.add_(response_length.data)
|
|
||||||
self.accum_count += 1
|
self.accum_count += 1
|
||||||
if need_update:
|
if need_update:
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
# no need to run all reduce as raw_train_batch_* are not splited across dp rank
|
||||||
overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
|
sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
|
||||||
self.effective_prompt_count = 0
|
self.effective_prompt_count = 0
|
||||||
self.effective_sample_count = 0
|
self.effective_sample_count = 0
|
||||||
self.total_sample_count = 0
|
|
||||||
self.total_overlength_samples = 0
|
|
||||||
loss_scalar = self.accum_loss.item()
|
loss_scalar = self.accum_loss.item()
|
||||||
if not self.plugin.pp_size > 1 or (
|
if not self.plugin.pp_size > 1 or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
@ -506,27 +451,39 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||||
):
|
):
|
||||||
|
raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
|
||||||
|
raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
|
||||||
|
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
|
||||||
|
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
|
||||||
|
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
|
||||||
|
overlength_samples_ratio = (
|
||||||
|
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
|
||||||
|
) # not an exact figure, but a close estimate
|
||||||
|
self.raw_train_batch_reward = []
|
||||||
|
self.raw_train_batch_format_acc = []
|
||||||
|
self.raw_train_batch_ans_acc = []
|
||||||
|
self.raw_train_batch_response_len = []
|
||||||
to_log_msg = [
|
to_log_msg = [
|
||||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||||
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
|
f"Reward: {raw_batch_reward_mean:.4f}",
|
||||||
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
|
f"format Reward: {raw_batch_format_acc_mean:.4f}",
|
||||||
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
|
f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
|
||||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||||
f"Sample_utilization: {sample_utilization:.4f}",
|
f"Sample_utilization: {sample_utilization:.4f}",
|
||||||
f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
|
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
|
||||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||||
print("\n".join(to_log_msg))
|
print("\n".join(to_log_msg))
|
||||||
metrics = {
|
metrics = {
|
||||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
"metrics/reward": raw_batch_reward_mean,
|
||||||
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
|
"metrics/format_acc": raw_batch_format_acc_mean,
|
||||||
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
|
"metrics/ans_acc": raw_batch_ans_acc_mean,
|
||||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
"metrics/response_length": raw_batch_response_len_mean,
|
||||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||||
"train/sample_utilization": sample_utilization,
|
"train/sample_utilization": sample_utilization,
|
||||||
"train/percentage_overlength_samples": overlength_samples_percentage,
|
"train/overlength_samples_ratio": overlength_samples_ratio,
|
||||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||||
}
|
}
|
||||||
if self.policy_loss_fn.beta > 0:
|
if self.policy_loss_fn.beta > 0:
|
||||||
@ -534,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
if self.wandb_run is not None:
|
if self.wandb_run is not None:
|
||||||
self.wandb_run.log(metrics)
|
self.wandb_run.log(metrics)
|
||||||
self.accum_loss.zero_()
|
self.accum_loss.zero_()
|
||||||
self.accum_reward.zero_()
|
|
||||||
self.accum_ans_acc.zero_()
|
|
||||||
self.accum_format_acc.zero_()
|
|
||||||
self.accum_kl.zero_()
|
self.accum_kl.zero_()
|
||||||
self.accum_advantages.zero_()
|
self.accum_advantages.zero_()
|
||||||
self.accum_response_length.zero_()
|
|
||||||
self.accum_count = 0
|
self.accum_count = 0
|
||||||
|
return loss_scalar
|
||||||
if excessive_prompts_idx is not None:
|
|
||||||
# All gather excessive prompts index across DP ranks.
|
|
||||||
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
|
|
||||||
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
|
|
||||||
return loss_scalar, excessive_prompts_idx
|
|
||||||
else:
|
else:
|
||||||
return None, excessive_prompts_idx
|
return None
|
||||||
|
|
||||||
|
def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculate the group reward for the given rollout group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rollout_group (Dict[str, Any]):
|
||||||
|
a group of samples generated by the model from the same prompt
|
||||||
|
contain the following keys:
|
||||||
|
"input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
|
||||||
|
"attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
|
||||||
|
"action_mask": torch.Tensor, [num_of_generation, response_length]
|
||||||
|
"action_log_probs": torch.Tensor, [num_of_generation, response_length]
|
||||||
|
"response_idx": int, torch.Tensor, [num_of_generation, 2]
|
||||||
|
"gt_answer": torch.Tensor, [num_of_generation, 128]
|
||||||
|
"temperature": torch.Tensor, [] (scalar)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: The new group data with calculated reward.
|
||||||
|
"""
|
||||||
|
reward_model_output = self.reward_model(
|
||||||
|
rollout["input_ids"],
|
||||||
|
gt_answer=rollout["gt_answer"],
|
||||||
|
response_idx=rollout["response_idx"],
|
||||||
|
)
|
||||||
|
# [num_of_generation]
|
||||||
|
reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
|
||||||
|
format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
|
||||||
|
ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
|
||||||
|
|
||||||
|
rollout["reward"] = reward.view((-1, 1))
|
||||||
|
rollout["format_acc"] = format_acc.view((-1, 1))
|
||||||
|
rollout["ans_acc"] = ans_acc.view((-1, 1))
|
||||||
|
return rollout
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
self.policy_model._force_wait_all_gather()
|
self.policy_model._force_wait_all_gather()
|
||||||
|
@ -100,6 +100,7 @@ def launch_distributed(
|
|||||||
eval_dataset_config=eval_dataset_config,
|
eval_dataset_config=eval_dataset_config,
|
||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=grpo_config["reward_fn_type"],
|
evaluation_function_type=grpo_config["reward_fn_type"],
|
||||||
|
response_format_tags=grpo_config["response_format_tags"],
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
eval_generation_config=eval_generation_config,
|
eval_generation_config=eval_generation_config,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
@ -46,6 +46,7 @@ class BaseProducer:
|
|||||||
eval_dataset_config=None,
|
eval_dataset_config=None,
|
||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
|
response_format_tags=None,
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
project_name: str = None,
|
project_name: str = None,
|
||||||
run_name: str = None,
|
run_name: str = None,
|
||||||
@ -148,6 +149,7 @@ class BaseProducer:
|
|||||||
self.evaluation_function = boxed_math_reward_fn
|
self.evaluation_function = boxed_math_reward_fn
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
|
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
|
||||||
|
self.response_format_tags = response_format_tags
|
||||||
else:
|
else:
|
||||||
print("No eval dataset provided, skip eval")
|
print("No eval dataset provided, skip eval")
|
||||||
self.device = get_current_device()
|
self.device = get_current_device()
|
||||||
@ -217,6 +219,7 @@ class BaseProducer:
|
|||||||
eval_outputs["response_idx"][m][n],
|
eval_outputs["response_idx"][m][n],
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
eval_mode=True,
|
eval_mode=True,
|
||||||
|
tags=self.response_format_tags,
|
||||||
)
|
)
|
||||||
for m in range(eval_outputs["input_ids"].size(0))
|
for m in range(eval_outputs["input_ids"].size(0))
|
||||||
for n in range(eval_outputs["input_ids"].size(1))
|
for n in range(eval_outputs["input_ids"].size(1))
|
||||||
@ -245,11 +248,10 @@ class BaseProducer:
|
|||||||
self.eval_mode = False
|
self.eval_mode = False
|
||||||
self.latest_eval_step = self.consumer_global_step
|
self.latest_eval_step = self.consumer_global_step
|
||||||
outputs = self.rollout(**batch)
|
outputs = self.rollout(**batch)
|
||||||
|
|
||||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
|
||||||
outputs["temperature"] = torch.tensor(
|
outputs["temperature"] = torch.tensor(
|
||||||
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
[self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
|
||||||
).to(outputs["input_ids"].device)
|
).to(outputs["input_ids"].device)
|
||||||
|
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||||
outputs = pre_send(outputs)
|
outputs = pre_send(outputs)
|
||||||
ray_broadcast_tensor_dict(
|
ray_broadcast_tensor_dict(
|
||||||
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
|
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
|
||||||
@ -324,6 +326,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_dataset_config=None,
|
eval_dataset_config=None,
|
||||||
eval_interval=-1, # disable evaluation
|
eval_interval=-1, # disable evaluation
|
||||||
evaluation_function_type="think_answer_tags",
|
evaluation_function_type="think_answer_tags",
|
||||||
|
response_format_tags=None,
|
||||||
eval_save_dir: str = "./eval",
|
eval_save_dir: str = "./eval",
|
||||||
eval_generation_config={},
|
eval_generation_config={},
|
||||||
project_name: str = None,
|
project_name: str = None,
|
||||||
@ -349,6 +352,7 @@ class SimpleProducer(BaseProducer):
|
|||||||
eval_dataset_config=eval_dataset_config,
|
eval_dataset_config=eval_dataset_config,
|
||||||
eval_interval=eval_interval,
|
eval_interval=eval_interval,
|
||||||
evaluation_function_type=evaluation_function_type,
|
evaluation_function_type=evaluation_function_type,
|
||||||
|
response_format_tags=response_format_tags,
|
||||||
eval_save_dir=eval_save_dir,
|
eval_save_dir=eval_save_dir,
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
@ -121,6 +121,34 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-tp",
|
||||||
|
"--tensor-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-pp",
|
||||||
|
"--pipeline-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-zero",
|
||||||
|
"--zero-stage",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-ptp",
|
||||||
|
"--producer-tensor-parallel-size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.train_minibatch_size is None:
|
if args.train_minibatch_size is None:
|
||||||
@ -134,8 +162,8 @@ if __name__ == "__main__":
|
|||||||
and args.train_microbatch_size > 0
|
and args.train_microbatch_size > 0
|
||||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||||
assert (
|
assert (
|
||||||
args.train_minibatch_size <= args.train_batch_size
|
args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
|
||||||
), "Train mini batch size must be less than or equals to train batch size"
|
), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"
|
||||||
|
|
||||||
if args.master_address is None:
|
if args.master_address is None:
|
||||||
# Default settings: Using single machine
|
# Default settings: Using single machine
|
||||||
@ -178,7 +206,7 @@ if __name__ == "__main__":
|
|||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
enable_chunked_prefill=True,
|
enable_chunked_prefill=True,
|
||||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=args.producer_tensor_parallel_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
generate_config.update(
|
generate_config.update(
|
||||||
@ -203,6 +231,16 @@ if __name__ == "__main__":
|
|||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||||
"max_new_tokens": args.max_new_tokens,
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"response_format_tags": (
|
||||||
|
{
|
||||||
|
"think_start": {"text": "<think>", "num_occur": 1},
|
||||||
|
"think_end": {"text": "</think>", "num_occur": 1},
|
||||||
|
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||||
|
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||||
|
}
|
||||||
|
if args.reward_type == "think_answer_tags"
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
elif args.algo == "DAPO":
|
elif args.algo == "DAPO":
|
||||||
# DAPO variant settings
|
# DAPO variant settings
|
||||||
@ -222,13 +260,23 @@ if __name__ == "__main__":
|
|||||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||||
"filter_truncated_response": True,
|
"filter_truncated_response": True,
|
||||||
"reward_fn_type": args.reward_type,
|
"reward_fn_type": args.reward_type,
|
||||||
|
"response_format_tags": (
|
||||||
|
{
|
||||||
|
"think_start": {"text": "<think>", "num_occur": 1},
|
||||||
|
"think_end": {"text": "</think>", "num_occur": 1},
|
||||||
|
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||||
|
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||||
|
}
|
||||||
|
if args.reward_type == "think_answer_tags"
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||||
|
|
||||||
launch_distributed(
|
launch_distributed(
|
||||||
num_producers=args.num_inferencer,
|
num_producers=args.num_inferencer,
|
||||||
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
|
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
|
||||||
num_consumer_procs=args.num_trainers,
|
num_consumer_procs=args.num_trainers,
|
||||||
num_episodes=args.num_episodes,
|
num_episodes=args.num_episodes,
|
||||||
inference_batch_size=args.inference_batch_size,
|
inference_batch_size=args.inference_batch_size,
|
||||||
@ -247,17 +295,14 @@ if __name__ == "__main__":
|
|||||||
train_model_config=train_model_config,
|
train_model_config=train_model_config,
|
||||||
grpo_config=grpo_config,
|
grpo_config=grpo_config,
|
||||||
plugin_config={
|
plugin_config={
|
||||||
"zero_stage": 2,
|
"tp_size": args.tensor_parallel_size,
|
||||||
}, # for zero
|
"pp_size": args.pipeline_parallel_size,
|
||||||
# plugin_config={
|
"microbatch_size": max(
|
||||||
# "tp_size": 2,
|
1, args.train_microbatch_size // args.pipeline_parallel_size
|
||||||
# "pp_size": 2,
|
), # microbatch size should be set to train_microbatch_size // pp_size
|
||||||
# "microbatch_size": max(
|
"zero_stage": args.zero_stage,
|
||||||
# 1, args.train_microbatch_size // 2
|
"max_norm": 1.0,
|
||||||
# ), # microbatch size should be set to train_microbatch_size // pp_size
|
}, # for pp, tp
|
||||||
# "zero_stage": 0,
|
|
||||||
# "max_norm": 1.0,
|
|
||||||
# }, # for pp, tp
|
|
||||||
inference_backend=args.backend,
|
inference_backend=args.backend,
|
||||||
master_addr="localhost",
|
master_addr="localhost",
|
||||||
master_port=args.master_port,
|
master_port=args.master_port,
|
||||||
|
Loading…
Reference in New Issue
Block a user