mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
[feat] Support DAPO (#6263)
* update help information * update style * fix * minor fix * support PP training * add pp support * remove unused code * address conversation * fix memory leakage support tp+pp * move empty cache * move empty cache * add DAPO support * remove format reward * fix filtering, still buggy * small fix * add DAPO support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tested multi-node training; fix bind_batch bug * fix conversation; support sleep mode * support reusing excessive samples * add dynamic batching control flag * add dynamic batching control flag * refactored * fix logging --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b823c6eec7
commit
26d859f68e
@ -16,7 +16,7 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .comm import ray_broadcast_tensor_dict
|
||||
from .utils import bind_batch, post_recv, unbind_batch
|
||||
from .utils import bind_batch, pad_batch, post_recv, unbind_batch
|
||||
|
||||
|
||||
class BaseConsumer:
|
||||
@ -33,7 +33,7 @@ class BaseConsumer:
|
||||
batch_size: int,
|
||||
model_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
microbatch_size: int = 1,
|
||||
minibatch_size: int = 1,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
):
|
||||
@ -46,11 +46,11 @@ class BaseConsumer:
|
||||
self.num_update_per_episode = num_update_per_episode
|
||||
self.num_recv_per_update = num_recv_per_update
|
||||
self.batch_size = batch_size
|
||||
self.microbatch_size = microbatch_size
|
||||
self.minibatch_size = minibatch_size
|
||||
self.save_interval = save_interval
|
||||
self.save_dir = save_dir
|
||||
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // microbatch_size
|
||||
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
|
||||
self.num_microbatches = batch_size // minibatch_size
|
||||
|
||||
self.model_config = model_config
|
||||
self.plugin_config = plugin_config
|
||||
@ -67,7 +67,7 @@ class BaseConsumer:
|
||||
|
||||
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
|
||||
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
||||
plugin_config["microbatch_size"] = self.microbatch_size
|
||||
plugin_config["microbatch_size"] = self.minibatch_size
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
@ -105,18 +105,26 @@ class BaseConsumer:
|
||||
)
|
||||
)
|
||||
)
|
||||
while len(self.buffer) >= self.dp_size * self.microbatch_size:
|
||||
while len(self.buffer) >= self.dp_size * self.minibatch_size:
|
||||
batches = self.buffer[
|
||||
self.dp_rank * self.microbatch_size : (self.dp_rank + 1) * self.microbatch_size
|
||||
self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
|
||||
]
|
||||
self.buffer = self.buffer[self.dp_size * self.microbatch_size :]
|
||||
batch = pad_batch(
|
||||
batches
|
||||
) # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
|
||||
batch = bind_batch(batches)
|
||||
batch = post_recv(batch)
|
||||
loss = self.step(i, **batch)
|
||||
loss, num_excessive_prompts = self.step(i, pbar, **batch)
|
||||
self.buffer = (
|
||||
self.buffer[
|
||||
(self.dp_rank + 1) * self.minibatch_size
|
||||
- num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
|
||||
]
|
||||
+ self.buffer[self.dp_size * self.minibatch_size :]
|
||||
)
|
||||
if loss is not None:
|
||||
pbar.set_postfix({"loss": loss})
|
||||
i += 1
|
||||
assert len(self.buffer) == 0
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
|
||||
@ -154,7 +162,9 @@ class SimpleConsumer(BaseConsumer):
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
minibatch_size=1,
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
@ -168,7 +178,7 @@ class SimpleConsumer(BaseConsumer):
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size,
|
||||
minibatch_size,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
@ -181,7 +191,7 @@ class SimpleConsumer(BaseConsumer):
|
||||
super().setup()
|
||||
self.model, self.optimizer, *_ = self.booster.boost(self.model, self.optimizer)
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
|
||||
labels = kwargs["input_ids"].clone()
|
||||
labels[kwargs["attention_mask"] == 0] = -100
|
||||
kwargs["labels"] = labels
|
||||
|
@ -1,18 +1,16 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from coati.distributed.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.reward.reward_fn import math_reward_fn
|
||||
from coati.distributed.reward.verifiable_reward import VerifiableReward
|
||||
from coati.distributed.utils import calc_action_log_probs
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
@ -34,13 +32,23 @@ class GRPOConsumer(BaseConsumer):
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
minibatch_size=1,
|
||||
num_generations=8,
|
||||
use_wandb=True,
|
||||
generate_config=None,
|
||||
training_config={},
|
||||
grpo_config={},
|
||||
project_name=None,
|
||||
save_interval: int = 100,
|
||||
save_dir="./model",
|
||||
):
|
||||
print(f"Using GRPO config: {grpo_config}")
|
||||
if grpo_config.get("loss_variation", "sample_level") == "token_level":
|
||||
if batch_size != minibatch_size:
|
||||
warnings.warn(
|
||||
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
|
||||
UserWarning,
|
||||
)
|
||||
minibatch_size = batch_size
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
@ -53,36 +61,58 @@ class GRPOConsumer(BaseConsumer):
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size,
|
||||
minibatch_size,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.policy_model.gradient_checkpointing_enable()
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_format_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_acc_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_format_acc = torch.zeros(1, device=self.device)
|
||||
self.accum_ans_acc = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = 0
|
||||
self.generate_config = generate_config
|
||||
self.training_config = training_config
|
||||
self.grpo_config = grpo_config
|
||||
self.project_name = project_name
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
|
||||
self.policy_loss_fn = PolicyLoss(
|
||||
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
|
||||
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
|
||||
beta=grpo_config.get("beta", 0.01),
|
||||
loss_variation=grpo_config.get("loss_variation", "sample_level"),
|
||||
)
|
||||
|
||||
# Reference model is initialized from policy model.
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.reference_model.eval()
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.reference_model.eval()
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
self.filter_range = training_config.get("filter_range", None)
|
||||
self.filter_range = grpo_config.get("filter_range", None)
|
||||
if self.filter_range is not None:
|
||||
assert len(self.filter_range) == 2, "Filter range should have 2 values."
|
||||
|
||||
self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
|
||||
if self.filter_truncated_response:
|
||||
self.max_length = 0
|
||||
if "max_tokens" in self.generate_config:
|
||||
self.max_length = self.generate_config["max_tokens"]
|
||||
elif "max_new_tokens" in self.generate_config:
|
||||
self.max_length = self.generate_config["max_new_tokens"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
|
||||
)
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
@ -90,11 +120,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
reward_model_kwargs = {
|
||||
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
|
||||
)
|
||||
|
||||
self.policy_loss_fn = PolicyLoss()
|
||||
self.global_step = 0
|
||||
self.use_wandb = use_wandb
|
||||
|
||||
@ -102,7 +133,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
optimizer=self.optimizer,
|
||||
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
|
||||
warmup_steps=0,
|
||||
eta_min=0.1 * training_config.get("lr", 1e-6),
|
||||
eta_min=0.1 * grpo_config.get("lr", 1e-6),
|
||||
)
|
||||
|
||||
def setup(self):
|
||||
@ -118,10 +149,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
|
||||
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
|
||||
)
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.reference_model, *_ = self.booster.boost(self.reference_model)
|
||||
self.plugin.logger.set_level("ERROR")
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
|
||||
"""
|
||||
Step data from policy model:
|
||||
[{
|
||||
@ -132,18 +164,108 @@ class GRPOConsumer(BaseConsumer):
|
||||
},
|
||||
...]
|
||||
Format:
|
||||
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||
[minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
|
||||
"""
|
||||
|
||||
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
|
||||
# Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
|
||||
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
|
||||
action_mask = data["action_mask"]
|
||||
num_action = action_mask.shape[1]
|
||||
old_action_log_probs = data["action_log_probs"]
|
||||
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
|
||||
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
|
||||
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"],
|
||||
gt_answer=data["gt_answer"],
|
||||
response_idx=data["response_idx"],
|
||||
)
|
||||
|
||||
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
|
||||
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
||||
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
||||
|
||||
# [minibatch_size, num_generations]
|
||||
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
reward_mean = group_reward.mean(dim=1)
|
||||
# [minibatch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [minibatch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||
group_ans_acc = (
|
||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
loss_mask = (
|
||||
torch.ones(action_mask.size(0), device=action_mask.device).bool()
|
||||
if self.filter_range is None
|
||||
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
|
||||
)
|
||||
# filter out overlength samples
|
||||
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
|
||||
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
|
||||
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
|
||||
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
|
||||
self.effective_sample_count += effective_samples.item()
|
||||
self.total_sample_count += total_samples.item()
|
||||
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
if self.grpo_config.get("dynamic_batching", True):
|
||||
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
|
||||
# to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
|
||||
num_excessive_samples = (
|
||||
int(
|
||||
(self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
|
||||
/ self.num_generations
|
||||
/ self.dp_size
|
||||
)
|
||||
* self.num_generations
|
||||
)
|
||||
if num_excessive_samples > 0:
|
||||
data = {
|
||||
k: (
|
||||
v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
|
||||
if k
|
||||
in [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"action_log_probs",
|
||||
"action_mask",
|
||||
"response_idx",
|
||||
"gt_answer",
|
||||
]
|
||||
else v
|
||||
)
|
||||
for k, v in data.items()
|
||||
}
|
||||
action_mask = action_mask[
|
||||
: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
|
||||
]
|
||||
loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
|
||||
advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
|
||||
else:
|
||||
num_excessive_samples = 0
|
||||
else:
|
||||
# If dynamic batching is disabled, we need to use all samples for training.
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
num_excessive_samples = 0
|
||||
|
||||
pbar.set_postfix(
|
||||
{
|
||||
"Step": self.global_step + 1,
|
||||
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
|
||||
}
|
||||
)
|
||||
|
||||
need_update = (step_idx + 1) % self.num_microbatches == 0
|
||||
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
|
||||
ctx = (
|
||||
nullcontext()
|
||||
@ -151,95 +273,71 @@ class GRPOConsumer(BaseConsumer):
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
|
||||
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
|
||||
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
|
||||
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
|
||||
|
||||
# [batch_size, num_generations]
|
||||
|
||||
group_reward = reward.view(-1, self.num_generations)
|
||||
reward_mean = group_reward.mean(dim=1)
|
||||
# [batch_size x num_generations]
|
||||
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
|
||||
# [batch_size x num_generations]
|
||||
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
|
||||
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
|
||||
loss_mask = (
|
||||
None
|
||||
if self.filter_range is None
|
||||
else torch.logical_and(
|
||||
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
|
||||
).repeat_interleave(self.num_generations, dim=0)
|
||||
)
|
||||
mean_kl, mean_loss = [], []
|
||||
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
attention_mask_forward_micro_batch = data["attention_mask"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
action_mask_forward_micro_batch = action_mask[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
loss_mask_forward_micro_batch = (
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
|
||||
loss_mask[forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size]
|
||||
if loss_mask is not None
|
||||
else None
|
||||
)
|
||||
advantages_forward_micro_batch = advantages[
|
||||
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
]
|
||||
|
||||
if self.plugin.pp_size > 1:
|
||||
# Support training with PP.
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
with torch.no_grad():
|
||||
reference_model_outputs = self.booster.execute_pipeline(
|
||||
iter(
|
||||
[
|
||||
{
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
}
|
||||
]
|
||||
),
|
||||
self.reference_model,
|
||||
criterion=lambda outputs, inputs: torch.tensor(
|
||||
[0.0], device=action_mask.device
|
||||
), # dummy criterion
|
||||
optimizer=None,
|
||||
return_loss=False,
|
||||
return_outputs=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_outputs = self.booster.execute_pipeline(
|
||||
iter(
|
||||
[
|
||||
{
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
}
|
||||
]
|
||||
),
|
||||
self.reference_model,
|
||||
criterion=lambda outputs, inputs: torch.tensor(
|
||||
[0.0], device=action_mask.device
|
||||
), # dummy criterion
|
||||
optimizer=None,
|
||||
return_loss=False,
|
||||
return_outputs=True,
|
||||
)
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_model_logits = reference_model_outputs["outputs"]["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_model_logits = reference_model_outputs["outputs"]["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
else:
|
||||
# Dummy reference logprobs for data iterator.
|
||||
reference_action_log_probs = None
|
||||
else:
|
||||
# Dummy reference logprobs for data iterator.
|
||||
reference_action_log_probs = None
|
||||
|
||||
data_policy_forward = {
|
||||
"input_ids": input_ids_forward_micro_batch,
|
||||
"attention_mask": attention_mask_forward_micro_batch,
|
||||
"action_mask": action_mask_forward_micro_batch,
|
||||
"reference_action_log_probs": reference_action_log_probs,
|
||||
"advantages": advantages_forward_micro_batch,
|
||||
"loss_mask": loss_mask_forward_micro_batch,
|
||||
"source": self.rank,
|
||||
}
|
||||
if reference_action_log_probs is not None:
|
||||
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
|
||||
|
||||
kl = []
|
||||
|
||||
@ -251,24 +349,30 @@ class GRPOConsumer(BaseConsumer):
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
||||
inputs["action_mask"], dim=-1
|
||||
)
|
||||
kl.append(appox_kl.mean())
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
if "reference_action_log_probs" in inputs:
|
||||
per_token_kl = (
|
||||
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
|
||||
inputs["action_mask"], dim=-1
|
||||
)
|
||||
kl.append(appox_kl.mean())
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl.append(0.0)
|
||||
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
action_log_probs,
|
||||
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
inputs["action_mask"],
|
||||
loss_mask=inputs["loss_mask"],
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
return loss
|
||||
return loss, num_excessive_samples // self.num_generations
|
||||
|
||||
policy_model_outputs = self.booster.execute_pipeline(
|
||||
iter([data_policy_forward]),
|
||||
@ -298,61 +402,71 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
attention_mask=attention_mask_forward_micro_batch,
|
||||
).logits
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
- (reference_action_log_probs - action_log_probs)
|
||||
- 1
|
||||
)
|
||||
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
|
||||
action_mask_forward_micro_batch, dim=-1
|
||||
)
|
||||
else:
|
||||
per_token_kl = 0.0
|
||||
kl = None
|
||||
|
||||
loss, skip_update, _ = self.policy_loss_fn(
|
||||
loss, _ = self.policy_loss_fn(
|
||||
action_log_probs,
|
||||
old_action_log_probs,
|
||||
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
|
||||
per_token_kl,
|
||||
action_mask_forward_micro_batch,
|
||||
loss_mask=loss_mask_forward_micro_batch,
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
|
||||
if not skip_update:
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
self.booster.backward(loss, self.optimizer)
|
||||
loss = all_reduce_mean(loss, self.plugin)
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
# Calculate accumulate value.
|
||||
mean_kl.append(kl.data)
|
||||
if kl is not None:
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
reward = all_reduce_mean(reward.mean(), self.plugin)
|
||||
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
|
||||
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
|
||||
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
|
||||
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_reward.add_(reward.data)
|
||||
self.accum_format_reward.add_(format_reward.data)
|
||||
self.accum_acc_reward.add_(acc_reward.data)
|
||||
self.accum_format_acc.add_(format_acc.data)
|
||||
self.accum_ans_acc.add_(ans_acc.data)
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
self.accum_response_length.add_(response_length.data)
|
||||
self.accum_count += 1
|
||||
if need_update:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.global_step += 1
|
||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
@ -360,167 +474,41 @@ class GRPOConsumer(BaseConsumer):
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
print(
|
||||
"Loss:",
|
||||
self.accum_loss.item() / self.accum_count,
|
||||
"\nReward:",
|
||||
self.accum_reward.item() / self.accum_count,
|
||||
"\nFormat Reward:",
|
||||
self.accum_format_reward.item() / self.accum_count,
|
||||
"\nAcc Reward:",
|
||||
self.accum_acc_reward.item() / self.accum_count,
|
||||
"\nKL:",
|
||||
self.accum_kl.item() / self.accum_count,
|
||||
"\nAdvantages:",
|
||||
self.accum_advantages.item() / self.accum_count,
|
||||
"\nResponse Length:",
|
||||
self.accum_response_length.item() / self.accum_count,
|
||||
)
|
||||
self.wandb_run.log(
|
||||
{
|
||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
||||
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
|
||||
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
|
||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
"train/kl": self.accum_kl.item() / self.accum_count,
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
)
|
||||
to_log_msg = [
|
||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||
f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
|
||||
f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
|
||||
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
|
||||
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
|
||||
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
"metrics/reward": self.accum_reward.item() / self.accum_count,
|
||||
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
|
||||
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
|
||||
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
|
||||
"train/loss": self.accum_loss.item() / self.accum_count,
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"train/sample_utilization": sample_utilization,
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
|
||||
|
||||
self.wandb_run.log(metrics)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_reward.zero_()
|
||||
self.accum_acc_reward.zero_()
|
||||
self.accum_format_reward.zero_()
|
||||
self.accum_ans_acc.zero_()
|
||||
self.accum_format_acc.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
model = self.policy_model.unwrap()
|
||||
state_dict = model.state_dict()
|
||||
return state_dict
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GRPOEvalConsumer(BaseConsumer):
|
||||
def __init__(
|
||||
self,
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size=1,
|
||||
num_generations=4,
|
||||
use_wandb=True,
|
||||
log_dir="./results",
|
||||
):
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
rank,
|
||||
world_size,
|
||||
master_addr,
|
||||
master_port,
|
||||
num_update_per_episode,
|
||||
num_recv_per_update,
|
||||
batch_size,
|
||||
model_config,
|
||||
plugin_config,
|
||||
microbatch_size,
|
||||
)
|
||||
path = model_config.pop("path")
|
||||
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
|
||||
self.policy_model.train()
|
||||
self.accum_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_format_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_acc_reward = torch.zeros(1, device=self.device)
|
||||
self.accum_response_length = torch.zeros(1, device=self.device)
|
||||
self.accum_count = torch.zeros(1, device=self.device)
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.num_generations = num_generations
|
||||
|
||||
# Initialize verifiable reward.
|
||||
response_format_tags = {
|
||||
"think_start": {"text": "<think>", "num_occur": 1},
|
||||
"think_end": {"text": "</think>", "num_occur": 1},
|
||||
"answer_start": {"text": "<answer>", "num_occur": 1},
|
||||
"answer_end": {"text": "</answer>", "num_occur": 1},
|
||||
}
|
||||
self.reward_model = VerifiableReward(
|
||||
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
|
||||
)
|
||||
|
||||
self.log_dir = log_dir
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
return loss_scalar, num_excessive_samples // self.num_generations
|
||||
else:
|
||||
os.system(f"rm -rf {self.log_dir}/*")
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
self.policy_model, _, *_ = self.booster.boost(self.policy_model)
|
||||
|
||||
def step(self, step_idx: int, **kwargs) -> Optional[float]:
|
||||
rank = dist.get_rank()
|
||||
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
|
||||
kwargs["input_ids"].size(0)
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
|
||||
)
|
||||
reward = [value[0].item() for value in reward_group]
|
||||
format_reward = [value[1].item() for value in reward_group]
|
||||
acc_reward = [value[2].item() for value in reward_group]
|
||||
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
|
||||
|
||||
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
|
||||
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
|
||||
for i in range(len(response)):
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"response": response[i],
|
||||
"reward": reward[i],
|
||||
"format_reward": format_reward[i],
|
||||
"acc_reward": acc_reward[i],
|
||||
"response_length": response_length[i],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
self.accum_reward += sum(reward)
|
||||
self.accum_format_reward += sum(format_reward)
|
||||
self.accum_acc_reward += sum(acc_reward)
|
||||
self.accum_response_length += sum(response_length)
|
||||
self.accum_count += len(reward)
|
||||
|
||||
# print results
|
||||
total_count = all_reduce_mean(self.accum_count, self.plugin)
|
||||
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
|
||||
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
|
||||
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
|
||||
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
|
||||
)
|
||||
return None
|
||||
return None, num_excessive_samples // self.num_generations
|
||||
|
||||
def state_dict(self):
|
||||
self.policy_model._force_wait_all_gather()
|
||||
|
@ -184,6 +184,7 @@ class SGLangInferenceBackend(BaseInferenceBackend):
|
||||
class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
DEFAULT_MODEL_CONFIG = dict(
|
||||
trust_remote_code=True,
|
||||
enable_sleep_mode=False,
|
||||
)
|
||||
FORCE_GENERATE_CONFIG = dict(
|
||||
logprobs=0,
|
||||
@ -205,6 +206,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
|
||||
generate_config.update(self.FORCE_GENERATE_CONFIG)
|
||||
generate_config.update({"n": num_generations})
|
||||
self.generate_config = SamplingParams(**generate_config)
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
self.num_generations = num_generations
|
||||
|
||||
|
@ -4,10 +4,10 @@ from typing import Any, Dict, Optional
|
||||
import ray
|
||||
|
||||
from .consumer import SimpleConsumer
|
||||
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
|
||||
from .grpo_consumer import GRPOConsumer
|
||||
from .producer import SimpleProducer
|
||||
|
||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
|
||||
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
|
||||
|
||||
|
||||
def get_jsonl_size_fast(path: str) -> int:
|
||||
@ -40,6 +40,7 @@ def launch_distributed(
|
||||
inference_model_config: Dict[str, Any],
|
||||
generate_config: Dict[str, Any],
|
||||
train_model_config: Dict[str, Any],
|
||||
grpo_config: Dict[str, Any],
|
||||
plugin_config: Dict[str, Any],
|
||||
tokenizer_config: Optional[Dict[str, Any]] = None,
|
||||
inference_backend: str = "transformers",
|
||||
@ -48,6 +49,8 @@ def launch_distributed(
|
||||
master_port: int = 29500,
|
||||
core_algo: str = "GRPO",
|
||||
project_name: Optional[str] = None,
|
||||
save_interval: int = 100,
|
||||
save_dir: str = "./model",
|
||||
):
|
||||
|
||||
if core_algo not in ALGO_MAP:
|
||||
@ -101,15 +104,13 @@ def launch_distributed(
|
||||
batch_size=train_batch_size,
|
||||
model_config=train_model_config,
|
||||
plugin_config=plugin_config,
|
||||
microbatch_size=train_minibatch_size,
|
||||
minibatch_size=train_minibatch_size,
|
||||
generate_config=generate_config_consumer,
|
||||
training_config={
|
||||
"filter_range": [0.05, 9.0],
|
||||
"lr": 1e-6,
|
||||
"train_microbatch_size": train_microbatch_size,
|
||||
},
|
||||
grpo_config=grpo_config,
|
||||
num_generations=num_generations,
|
||||
project_name=project_name,
|
||||
save_interval=save_interval,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
procs.append(consumer)
|
||||
ray.get([p.setup.remote() for p in procs])
|
||||
|
@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.distributed.utils import masked_mean
|
||||
from coati.distributed.utils import masked_mean, masked_sum
|
||||
|
||||
|
||||
class PolicyLoss(nn.Module):
|
||||
@ -10,11 +10,19 @@ class PolicyLoss(nn.Module):
|
||||
Policy Loss for PPO
|
||||
"""
|
||||
|
||||
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
clip_eps_low: float = 0.2,
|
||||
clip_eps_high: float = 0.2,
|
||||
beta: float = 0.01,
|
||||
loss_variation: str = "sample_level",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
self.skip_threshold = skip_threshold
|
||||
self.clip_eps_low = clip_eps_low
|
||||
self.clip_eps_high = clip_eps_high
|
||||
self.beta = beta
|
||||
self.loss_variation = loss_variation
|
||||
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -24,22 +32,37 @@ class PolicyLoss(nn.Module):
|
||||
per_token_kl: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
loss_mask: Optional[torch.Tensor] = None,
|
||||
total_effective_tokens_in_batch: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
skip = False
|
||||
if action_mask is None:
|
||||
ratio = (log_probs - log_probs.detach()).exp()
|
||||
else:
|
||||
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
|
||||
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
|
||||
if self.beta == 0:
|
||||
# skip kl term if kl coefficient is zero
|
||||
per_token_kl = 0.0
|
||||
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
|
||||
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
if self.loss_variation == "sample_level":
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
loss = loss.mean()
|
||||
elif self.loss_variation == "token_level":
|
||||
if action_mask is not None:
|
||||
loss = masked_sum(loss, action_mask)
|
||||
else:
|
||||
loss = loss.sum(dim=1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
loss = loss.sum() / (total_effective_tokens_in_batch + 1e-8)
|
||||
else:
|
||||
loss = loss.mean(dim=1)
|
||||
if loss_mask is not None:
|
||||
loss = loss * loss_mask
|
||||
loss = loss.mean()
|
||||
return loss, skip, ratio.max()
|
||||
raise ValueError(f"Unsupported loss variation: {self.loss_variation}")
|
||||
|
||||
return loss, ratio.max()
|
||||
|
@ -103,7 +103,14 @@ class BaseProducer:
|
||||
|
||||
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
|
||||
outputs["temperature"] = torch.tensor(
|
||||
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
|
||||
[
|
||||
(
|
||||
self.model.generate_config["temperature"]
|
||||
if isinstance(self.model.generate_config.temperature, dict)
|
||||
else self.model.generate_config.temperature
|
||||
)
|
||||
]
|
||||
* outputs["input_ids"].size(0)
|
||||
).to(outputs["input_ids"].device)
|
||||
outputs = pre_send(outputs)
|
||||
ray_broadcast_tensor_dict(
|
||||
@ -113,10 +120,15 @@ class BaseProducer:
|
||||
if (i + 1) % self.num_microbatches == 0 and (
|
||||
episode != self.num_episodes - 1 or i != num_valid_microbatches - 1
|
||||
):
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.sleep() # revict KV_cache to avoid OOM
|
||||
# don't sync model for last iteration
|
||||
print(
|
||||
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
state_dict = ray_broadcast_tensor_dict(
|
||||
None, self.num_producers, device=self.device, group_name="sync_model"
|
||||
@ -124,12 +136,21 @@ class BaseProducer:
|
||||
self.load_state_dict(state_dict)
|
||||
del state_dict
|
||||
torch.cuda.empty_cache()
|
||||
# linear annealing for 1 episode, temperature from initial to 0.7
|
||||
if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
|
||||
"enable_sleep_mode", False
|
||||
):
|
||||
self.model.llm.wake_up()
|
||||
# linear annealing for 1 episode, temperature from initial to 0.9
|
||||
if episode <= 0:
|
||||
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
|
||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.7
|
||||
if isinstance(self.model.generate_config.temperature, dict):
|
||||
self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
else:
|
||||
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
|
||||
"temperature"
|
||||
] + ratio * 0.9
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
@ -4,13 +4,22 @@ from .reward_utils import extract_solution, validate_response_structure
|
||||
|
||||
|
||||
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
format_score = 1.0
|
||||
acc_score = 9.0
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
|
||||
acc_score = 10.0
|
||||
reward = torch.tensor(0.0)
|
||||
format_reward = torch.tensor(0.0)
|
||||
acc_reward = torch.tensor(0.0)
|
||||
format_acc = torch.tensor(0.0)
|
||||
ans_acc = torch.tensor(0.0)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
|
||||
length_reward = 0.0
|
||||
if soft_over_length_punishment:
|
||||
max_length = kwargs.get("max_length", 1024 * 4)
|
||||
cache_length = kwargs.get("cache_length", 512)
|
||||
res_length = e.item() - s.item() + 1
|
||||
if max_length - cache_length < res_length < max_length:
|
||||
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
|
||||
|
||||
if gt_answer is None:
|
||||
return reward
|
||||
|
||||
@ -22,18 +31,20 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
|
||||
# Check format accuracy
|
||||
if format_valid:
|
||||
format_reward += format_score
|
||||
reward += format_score
|
||||
format_acc += 1
|
||||
|
||||
# Check answer accuracy
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if (
|
||||
final_answer is not None
|
||||
format_valid
|
||||
and final_answer is not None
|
||||
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
|
||||
):
|
||||
acc_reward += acc_score
|
||||
ans_acc += 1
|
||||
reward += acc_score
|
||||
|
||||
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
|
||||
reward = reward + length_reward
|
||||
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
|
||||
def gsm8k_reward_fn(input_ids, **kwargs):
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
@ -26,6 +27,27 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
|
||||
return batch
|
||||
|
||||
|
||||
def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
|
||||
max_len = defaultdict(int)
|
||||
for sample in batches:
|
||||
for k in sample:
|
||||
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
|
||||
max_len[k] = max(max_len[k], sample[k].size(-1))
|
||||
for idx, sample in enumerate(batches):
|
||||
for k in sample:
|
||||
if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
|
||||
# right pad with 0s
|
||||
if k in ["attention_mask", "action_mask"]:
|
||||
batches[idx][k] = torch.nn.functional.pad(
|
||||
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
|
||||
)
|
||||
else:
|
||||
batches[idx][k] = torch.nn.functional.pad(
|
||||
batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
|
||||
)
|
||||
return batches
|
||||
|
||||
|
||||
def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
# compress mask to save bandwidth
|
||||
if "attention_mask" in batch:
|
||||
@ -113,3 +135,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
|
||||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
Compute the masked sum of a tensor along a specified dimension.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The input tensor.
|
||||
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
|
||||
dim (int, optional): The dimension along which to compute the sum. Default is 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The masked sum tensor.
|
||||
|
||||
"""
|
||||
tensor = tensor * mask
|
||||
return tensor.sum(dim=dim)
|
||||
|
@ -128,7 +128,7 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
|
||||
return tensor
|
||||
|
||||
|
||||
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
|
||||
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||
"""
|
||||
Performs an all-reduce operation to sum the values of the given tensor across all processes.
|
||||
|
||||
@ -138,5 +138,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
|
||||
Returns:
|
||||
torch.Tensor: The reduced tensor with the sum of values across all processes.
|
||||
"""
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
# All reduce sum across DP group
|
||||
if plugin is not None:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
|
||||
else:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
return tensor
|
||||
|
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@ -8,15 +9,17 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
|
||||
# Distributed training parameters
|
||||
parser.add_argument("-t", "--num-trainers", type=int, default=2)
|
||||
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
|
||||
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
|
||||
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
|
||||
parser.add_argument(
|
||||
"-ibs",
|
||||
"--inference-batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
default=None,
|
||||
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -37,7 +40,7 @@ if __name__ == "__main__":
|
||||
"-tMbs",
|
||||
"--train-minibatch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
default=None,
|
||||
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -47,22 +50,78 @@ if __name__ == "__main__":
|
||||
default=2,
|
||||
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ray_dir", type=str, default=None, help="Custom temperary directory for storing ray cluster data, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
|
||||
)
|
||||
|
||||
# Sampling parameters
|
||||
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
|
||||
parser.add_argument("-temp", "--temperature", type=float, default=1.0, help="Temperature for sampling.")
|
||||
parser.add_argument(
|
||||
"-topk",
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Top k for sampling. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-topp",
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Top p for sampling. Please check the generation arguments documentation for your backend.",
|
||||
)
|
||||
parser.add_argument("-s", "--system-prompt", type=str, default=None, help="System prompt for data construction.")
|
||||
parser.add_argument("-mnt", "--max-new-tokens", type=int, default=1024 * 4 - 512, help="Max length for generation.")
|
||||
parser.add_argument("-mpt", "--max-prompt-tokens", type=int, default=512, help="Max length for prompt.")
|
||||
|
||||
# GRPO parameters
|
||||
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
|
||||
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
|
||||
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
|
||||
|
||||
# Logging/Checkpointing parameters
|
||||
parser.add_argument("-si", "--save-interval", type=int, default=100, help="Interval for saving checkpoints.")
|
||||
parser.add_argument("-sd", "--save-dir", type=str, default="./model", help="Directory for saving checkpoints.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
|
||||
if args.train_minibatch_size is None:
|
||||
# Default settings: Using train batch size as mini batch size
|
||||
args.train_minibatch_size = args.train_batch_size
|
||||
if args.inference_batch_size is None:
|
||||
# Default settings: Using train batch size as inference batch size, sync every inference model every train step
|
||||
args.inference_batch_size = args.train_batch_size
|
||||
assert (
|
||||
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
|
||||
and args.train_microbatch_size > 0
|
||||
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
|
||||
assert (
|
||||
args.train_minibatch_size <= args.train_batch_size
|
||||
), "Train mini batch size must be less than or equals to train batch size"
|
||||
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
if args.master_address is None:
|
||||
# Default settings: Using single machine
|
||||
ray.init(address="local", namespace="ray-example")
|
||||
else:
|
||||
# For ray distributed multi-machine training, Please change _node_ip_address to your IP address of your master node
|
||||
ray.init(_node_ip_address=args.master_address, namespace="ray-example", _temp_dir=args.ray_dir)
|
||||
|
||||
if args.top_k is None:
|
||||
if args.backend == "transformers":
|
||||
args.top_k = 50
|
||||
elif args.backend == "vllm":
|
||||
args.top_k = -1
|
||||
|
||||
inference_model_config = dict(path=args.model)
|
||||
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
|
||||
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
|
||||
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
|
||||
|
||||
if args.backend == "transformers":
|
||||
inference_model_config.update(
|
||||
@ -73,7 +132,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_length=1024 + 512,
|
||||
max_length=args.max_new_tokens + args.max_prompt_tokens,
|
||||
do_sample=True,
|
||||
max_new_tokens=None,
|
||||
early_stopping=False,
|
||||
@ -81,31 +140,57 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
elif args.backend == "vllm":
|
||||
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
gpu_memory_utilization=0.7,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_tokens=2048,
|
||||
max_tokens=args.max_new_tokens, # max new tokens
|
||||
ignore_eos=True,
|
||||
include_stop_str_in_output=True,
|
||||
stop=["</answer>"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
inference_model_config.update(
|
||||
dict(
|
||||
mem_fraction_static=0.6,
|
||||
)
|
||||
)
|
||||
generate_config.update(
|
||||
dict(
|
||||
max_new_tokens=256,
|
||||
ignore_eos=True,
|
||||
)
|
||||
)
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
if args.algo == "GRPO":
|
||||
# Default Settings
|
||||
grpo_config = {
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"beta": args.kl_coeff, # KL penalty coefficient
|
||||
"loss_variation": "sample_level",
|
||||
}
|
||||
elif args.algo == "DAPO":
|
||||
# DAPO variant settings
|
||||
grpo_config = {
|
||||
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
|
||||
"lr": args.learning_rate,
|
||||
"train_microbatch_size": args.train_microbatch_size,
|
||||
"dynamic_batching": True,
|
||||
"clip_eps_low": 0.2,
|
||||
"clip_eps_high": 0.28,
|
||||
"skip_threshold": 20.0,
|
||||
"beta": 0, # no KL penalty for DAPO
|
||||
"loss_variation": "token_level",
|
||||
"soft_over_length_punishment": True,
|
||||
"max_length": args.max_new_tokens + args.max_prompt_tokens,
|
||||
"cache_length": min(1024, int(args.max_new_tokens / 4)),
|
||||
"filter_truncated_response": True,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported algorithm: {args.algo}")
|
||||
|
||||
launch_distributed(
|
||||
num_producers=args.num_inferencer,
|
||||
num_proc_per_producer=1,
|
||||
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
|
||||
num_consumer_procs=args.num_trainers,
|
||||
num_episodes=1,
|
||||
inference_batch_size=args.inference_batch_size,
|
||||
@ -113,15 +198,22 @@ if __name__ == "__main__":
|
||||
train_batch_size=args.train_batch_size,
|
||||
train_minibatch_size=args.train_minibatch_size,
|
||||
train_microbatch_size=args.train_microbatch_size,
|
||||
dataset_config={"path": args.dataset, "max_length": 300, "system_prompt": args.system_prompt},
|
||||
dataset_config={
|
||||
"path": args.dataset,
|
||||
"max_length": args.max_prompt_tokens,
|
||||
"system_prompt": args.system_prompt,
|
||||
},
|
||||
dataloaders_config={},
|
||||
inference_model_config=inference_model_config,
|
||||
generate_config=generate_config,
|
||||
num_generations=args.num_generations,
|
||||
train_model_config=train_model_config,
|
||||
plugin_config={}, # Default setting: zero.
|
||||
grpo_config=grpo_config,
|
||||
plugin_config={
|
||||
"zero_stage": 2,
|
||||
}, # for zero
|
||||
# currently not support tp/pp
|
||||
# plugin_config={
|
||||
# "pp_size": 2,
|
||||
# "tp_size": 2,
|
||||
# "microbatch_size": args.train_microbatch_size // 2,
|
||||
# "zero_stage": 0,
|
||||
@ -129,7 +221,9 @@ if __name__ == "__main__":
|
||||
# }, # for pp
|
||||
inference_backend=args.backend,
|
||||
master_addr="localhost",
|
||||
master_port=29506,
|
||||
master_port=args.master_port,
|
||||
core_algo=args.algo,
|
||||
project_name=args.project,
|
||||
save_interval=args.save_interval,
|
||||
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user