[feat] GRPO with distributed implementation (#6230)

* add reward related function

* add simple grpo

* update grpo

* polish

* modify data loader

* grpo consumer

* update loss

* update reward fn

* update example

* update loader

* add algo selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add save

* update select algo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update grpo

* update reward fn

* update reward

* fix reward score

* add response length

* detach

* fix tp bug

* fix consumer

* convert to 8 generation

* print results

* setup update

* fix transformers backend

* [Feature] Support Distributed LogProb for GRPO Training (#6247)

* [fix] fix qwen VocabParallelLMHead1D and gather output

* fix tp bug

* fix consumer

* [feat] Support Distributed LogProb for GRPO Training

* [fix] fix loss func

* [fix] fix log prob plugin

* [fix] fix qwen modeling param

* [fix] rm comments

* [fix] rm hard-code;fix non-dist version

* [fix] fix test file param name and benchmark tp gather output=True/False

* [fix] rm non-dist version in dist log prob

* [fix] fix comments

* [fix] fix dis log prob plugin

* [fix] fix test case

* [fix] fix qwen VocabParallelLMHead1D and gather output

* [fix] fix DistLogProb comments

* [fix] restore tp size

* [fix] fix comments

* [fix] fix comment; fix LogSoftmax usage

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* fix vllm

* fix logprob, add filtering, temperature annealing, lr descent

* simplify vllm preprocessing input ids

* update logging

* [feat] add microbatch forwarding (#6251)

* add microbatch forwarding

* fix forward microbatch

* fix producer OOM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change project name

* fix temperature annealing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address conversation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Distributed RLHF] Integration of PP (#6257)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* [hot-fix] Fix memory leakage bug, support TP+PP (#6258)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

* fix memory leakage support tp+pp

* move empty cache

* move empty cache

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
This commit is contained in:
Tong Li
2025-04-21 10:43:49 +08:00
committed by GitHub
parent 2bb71c6248
commit 7bb7e80476
19 changed files with 1264 additions and 64 deletions

View File

@@ -356,10 +356,24 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
system_element = {
"role": "system",
"content": system_prompt,
}
# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
chat = [chat["messages"]]
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
@@ -372,14 +386,10 @@ def apply_chat_template_and_mask(
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
# Left padding for generation.
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
@@ -389,6 +399,13 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
if gt_answer is not None:
gt_answer = tokenizer.encode(
gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,

View File

@@ -1,3 +1,4 @@
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional
@@ -33,6 +34,8 @@ class BaseConsumer:
model_config: Dict[str, Any],
plugin_config: Dict[str, Any],
microbatch_size: int = 1,
save_interval: int = 100,
save_dir: str = "./model",
):
self.num_producers = num_producers
self.num_episodes = num_episodes
@@ -44,14 +47,16 @@ class BaseConsumer:
self.num_recv_per_update = num_recv_per_update
self.batch_size = batch_size
self.microbatch_size = microbatch_size
self.save_interval = save_interval
self.save_dir = save_dir
assert batch_size % microbatch_size == 0, "batch_size should be divisible by microbatch_size"
self.num_microbatches = batch_size // microbatch_size
self.model_config = model_config
self.plugin_config = plugin_config
assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
self.device = get_current_device()
self.lr_scheduler = None
def setup(self) -> None:
for i in range(self.num_producers):
@@ -60,18 +65,15 @@ class BaseConsumer:
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
plugin_config = dict(
tp_size=1,
pp_size=1,
precision="bf16",
zero_stage=1,
)
plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
plugin_config["microbatch_size"] = self.microbatch_size
plugin_config.update(self.plugin_config)
self.plugin = HybridParallelPlugin(**plugin_config)
self.booster = Booster(plugin=self.plugin)
self.dp_rank = dist.get_rank(self.plugin.dp_group)
self.tp_rank = dist.get_rank(self.plugin.tp_group)
self.dp_size = dist.get_world_size(self.plugin.dp_group)
self.buffer = []
@@ -94,7 +96,6 @@ class BaseConsumer:
i = 0
for _ in range(self.num_recv_per_update):
# receive data from producers
for r in range(self.num_producers):
print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
self.buffer.extend(
@@ -116,13 +117,26 @@ class BaseConsumer:
pbar.set_postfix({"loss": loss})
i += 1
assert len(self.buffer) == 0
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if (step + 1) % self.save_interval == 0:
if self.rank == 0:
print(f"Start saving policy model at step {step + 1}.")
save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
self.booster.save_model(self.policy_model, save_path, shard=True)
if self.rank == 0:
print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
torch.cuda.empty_cache()
state_dict = self.state_dict()
if self.rank == 0:
ray_broadcast_tensor_dict(
state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
)
del state_dict
torch.cuda.empty_cache()
@ray.remote

View File

@@ -0,0 +1,529 @@
import json
import os
from contextlib import nullcontext
from typing import Optional
import ray
import torch
import torch.distributed as dist
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
@ray.remote
class GRPOConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
num_generations=8,
use_wandb=True,
generate_config=None,
training_config={},
project_name=None,
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
self.generate_config = generate_config
self.training_config = training_config
self.project_name = project_name
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = training_config.get("filter_range", None)
if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values."
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)
self.policy_loss_fn = PolicyLoss()
self.global_step = 0
self.use_wandb = use_wandb
self.lr_scheduler = CosineAnnealingWarmupLR(
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
warmup_steps=0,
eta_min=0.1 * training_config.get("lr", 1e-6),
)
def setup(self):
super().setup()
if self.use_wandb and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, **kwargs) -> Optional[float]:
"""
Step data from policy model:
[{
"input_ids": torch.Tensor,
"attention_mask": torch.Tensor,
"action_mask": torch.Tensor,
"action_log_probs": torch.Tensor,
},
...]
Format:
[batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
"""
# Reshape to [batch_size x num_of_generation, prompt_length + response_length]
data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
action_mask = data["action_mask"]
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
need_update = (step_idx + 1) % self.num_microbatches == 0
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
nullcontext()
if need_update or self.booster.plugin.zero_stage == 2
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
loss_mask = (
None
if self.filter_range is None
else torch.logical_and(
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
action_mask_forward_micro_batch = action_mask[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
loss_mask_forward_micro_batch = (
loss_mask[forward_micro_batch_start : forward_micro_batch_start + forward_batch_size]
if loss_mask is not None
else None
)
advantages_forward_micro_batch = advantages[
forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
]
if self.plugin.pp_size > 1:
# Support training with PP.
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
data_policy_forward = {
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch,
"reference_action_log_probs": reference_action_log_probs,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"source": self.rank,
}
kl = []
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = calc_action_log_probs(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(appox_kl.mean())
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
loss_mask=inputs["loss_mask"],
)
return loss
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
self.policy_model,
criterion=_criterion,
optimizer=self.optimizer,
return_loss=True,
return_outputs=True,
)
loss = policy_model_outputs["loss"]
if self.booster.plugin.stage_manager.is_last_stage():
if len(kl) > 0:
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
action_log_probs = calc_action_log_probs(
policy_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
loss_mask=loss_mask_forward_micro_batch,
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value.
mean_kl.append(kl.data)
mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
loss_scalar = self.accum_loss.item()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"\nReward:",
self.accum_reward.item() / self.accum_count,
"\nFormat Reward:",
self.accum_format_reward.item() / self.accum_count,
"\nAcc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"\nKL:",
self.accum_kl.item() / self.accum_count,
"\nAdvantages:",
self.accum_advantages.item() / self.accum_count,
"\nResponse Length:",
self.accum_response_length.item() / self.accum_count,
)
self.wandb_run.log(
{
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict
@ray.remote
class GRPOEvalConsumer(BaseConsumer):
def __init__(
self,
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size=1,
num_generations=4,
use_wandb=True,
log_dir="./results",
):
super().__init__(
num_producers,
num_episodes,
rank,
world_size,
master_addr,
master_port,
num_update_per_episode,
num_recv_per_update,
batch_size,
model_config,
plugin_config,
microbatch_size,
)
path = model_config.pop("path")
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = torch.zeros(1, device=self.device)
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
)
self.log_dir = log_dir
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
else:
os.system(f"rm -rf {self.log_dir}/*")
def setup(self):
super().setup()
self.policy_model, _, *_ = self.booster.boost(self.policy_model)
def step(self, step_idx: int, **kwargs) -> Optional[float]:
rank = dist.get_rank()
data = {k: v.view(-1, v.size(-1)).cpu() for k, v in kwargs.items()}
kwargs["input_ids"].size(0)
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = [value[0].item() for value in reward_group]
format_reward = [value[1].item() for value in reward_group]
acc_reward = [value[2].item() for value in reward_group]
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
with open(f"{self.log_dir}/eval_results_rank_{rank}.jsonl", "a", encoding="utf8") as f:
for i in range(len(response)):
f.write(
json.dumps(
{
"response": response[i],
"reward": reward[i],
"format_reward": format_reward[i],
"acc_reward": acc_reward[i],
"response_length": response_length[i],
},
ensure_ascii=False,
)
+ "\n"
)
self.accum_reward += sum(reward)
self.accum_format_reward += sum(format_reward)
self.accum_acc_reward += sum(acc_reward)
self.accum_response_length += sum(response_length)
self.accum_count += len(reward)
# print results
total_count = all_reduce_mean(self.accum_count, self.plugin)
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
if rank == 0:
print(
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
)
return None
def state_dict(self):
self.policy_model._force_wait_all_gather()
model = self.policy_model.unwrap()
state_dict = model.state_dict()
return state_dict

View File

@@ -53,7 +53,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
)
FORCE_GENERATE_CONFIG = dict(output_logits=True, return_dict_in_generate=True)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
model_config.update(self.FORCE_MODEL_CONFIG)
path = model_config.pop("path")
@@ -61,12 +67,22 @@ class TransformersInferenceBackend(BaseInferenceBackend):
self.generate_config = generate_config.copy()
self.generate_config.update(self.FORCE_GENERATE_CONFIG)
self.tokenizer = tokenizer
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
micro_batch_size = input_ids.size(0)
input_ids = input_ids.to(get_current_device())
attention_mask = attention_mask.to(get_current_device())
out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
gt_answer = None
if "gt_answer" in kwargs:
gt_answer = kwargs.pop("gt_answer")
if self.num_generations > 1:
input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
out = self.model.generate(
input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
)
input_len = input_ids.shape[-1]
new_token_ids = out.sequences[:, input_len:]
# get log probs
@@ -76,10 +92,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
action_log_probs = torch.cat(action_log_probs, dim=1)
# get action mask
response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
if self.tokenizer.eos_token_id is not None:
for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
action_mask[indices[0], indices[1] + 1 :] = 0
response_idx[:, 0] = input_len
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
@@ -91,7 +110,15 @@ class TransformersInferenceBackend(BaseInferenceBackend):
"attention_mask": attention_mask,
"action_log_probs": action_log_probs,
"action_mask": action_mask,
"response_idx": response_idx,
}
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
if gt_answer is not None:
# repeat gt_answer for each prompt.
data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
@@ -99,7 +126,13 @@ class TransformersInferenceBackend(BaseInferenceBackend):
class SGLangInferenceBackend(BaseInferenceBackend):
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if sgl is None:
raise ImportError("sglang is not installed")
path = model_config.pop("path")
@@ -156,29 +189,46 @@ class VLLMInferenceBackend(BaseInferenceBackend):
logprobs=0,
)
def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
def __init__(
self,
model_config: Dict[str, Any],
generate_config: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
num_generations: int = 8,
):
if LLM is None:
raise ImportError("vllm is not installed")
model_config = update_by_default(model_config, self.DEFAULT_MODEL_CONFIG)
path = model_config.pop("path")
self.llm = LLM(path, **model_config)
self.llm = LLM(model=path, **model_config)
generate_config = generate_config.copy()
generate_config.update(self.FORCE_GENERATE_CONFIG)
generate_config.update({"n": num_generations})
self.generate_config = SamplingParams(**generate_config)
self.tokenizer = tokenizer
self.num_generations = num_generations
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
micro_batch_size = input_ids.size(0)
response_start_idx = input_ids.size(1)
first_non_padding_token_idx = (input_ids != self.tokenizer.pad_token_id).int().argmax(dim=1)
micro_batch_input_ids = input_ids.tolist()
micro_batch_input_ids_no_padding = [
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
]
outputs = self.llm.generate(
prompt_token_ids=input_ids.tolist(), sampling_params=self.generate_config, use_tqdm=False
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
)
out_tokens = []
out_len = []
log_probs = []
response_idx = []
for out in outputs:
for output_i in out.outputs:
out_len.append(len(output_i.token_ids))
out_tokens.append(list(output_i.token_ids))
response_idx.append((response_start_idx, response_start_idx + len(output_i.token_ids) - 1))
assert len(output_i.logprobs) == len(output_i.token_ids)
p = [m[t].logprob for m, t in zip(output_i.logprobs, output_i.token_ids)]
log_probs.append(p)
@@ -195,6 +245,8 @@ class VLLMInferenceBackend(BaseInferenceBackend):
out_tokens = torch.tensor(out_tokens)
log_probs = torch.tensor(log_probs)
response_idx = torch.tensor(response_idx)
if attention_mask.size(0) != action_mask.size(0):
assert action_mask.size(0) % attention_mask.size(0) == 0
num_returns = action_mask.size(0) // attention_mask.size(0)
@@ -209,7 +261,14 @@ class VLLMInferenceBackend(BaseInferenceBackend):
"attention_mask": attention_mask,
"action_log_probs": log_probs,
"action_mask": action_mask,
"response_idx": response_idx,
}
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
if "gt_answer" in kwargs:
# repeat gt_answer for each prompt.
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
data = {k: v.to(get_current_device()) for k, v in data.items()}
return data

View File

@@ -1,10 +1,14 @@
import copy
from typing import Any, Dict, Optional
import ray
from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer, GRPOEvalConsumer
from .producer import SimpleProducer
ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "EvalGRPO": GRPOEvalConsumer}
def get_jsonl_size_fast(path: str) -> int:
with open(path) as f:
@@ -30,6 +34,7 @@ def launch_distributed(
inference_microbatch_size: int,
train_batch_size: int,
train_microbatch_size: int,
train_minibatch_size: int,
dataset_config: Dict[str, Any],
dataloaders_config: Dict[str, Any],
inference_model_config: Dict[str, Any],
@@ -38,9 +43,18 @@ def launch_distributed(
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
num_generations: int = 8,
master_addr: str = "localhost",
master_port: int = 29500,
core_algo: str = "GRPO",
project_name: Optional[str] = None,
):
if core_algo not in ALGO_MAP:
raise NotImplementedError(f"{core_algo} is not supported yet.")
else:
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
@@ -65,10 +79,17 @@ def launch_distributed(
tokenizer_config=tokenizer_config,
microbatch_size=inference_microbatch_size,
backend=inference_backend,
num_generations=num_generations,
)
procs.append(producer)
generate_config_consumer = copy.deepcopy(generate_config)
generate_config_consumer.update(
dict(
backend=inference_backend,
)
)
for i in range(num_consumer_procs):
consumer = SimpleConsumer.options(num_gpus=1).remote(
consumer = core_consumer.options(num_gpus=1).remote(
num_producers=num_producers,
num_episodes=num_episodes,
rank=i,
@@ -80,7 +101,15 @@ def launch_distributed(
batch_size=train_batch_size,
model_config=train_model_config,
plugin_config=plugin_config,
microbatch_size=train_microbatch_size,
microbatch_size=train_minibatch_size,
generate_config=generate_config_consumer,
training_config={
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": train_microbatch_size,
},
num_generations=num_generations,
project_name=project_name,
)
procs.append(consumer)
ray.get([p.setup.remote() for p in procs])

View File

@@ -0,0 +1,45 @@
from typing import Optional
import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean
class PolicyLoss(nn.Module):
"""
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
super().__init__()
self.clip_eps = clip_eps
self.skip_threshold = skip_threshold
self.beta = beta
def forward(
self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
per_token_kl: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp()
else:
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
return loss, skip, ratio.max()

View File

@@ -100,7 +100,11 @@ class BaseProducer:
if i >= num_valid_microbatches:
break
outputs = self.rollout(**batch)
print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
outputs["temperature"] = torch.tensor(
[self.model.generate_config.temperature] * outputs["input_ids"].size(0)
).to(outputs["input_ids"].device)
outputs = pre_send(outputs)
ray_broadcast_tensor_dict(
outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -113,10 +117,19 @@ class BaseProducer:
print(
f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
)
state_dict = ray_broadcast_tensor_dict(
None, self.num_producers, device=self.device, group_name="sync_model"
)
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
# linear annealing for 1 episode, temperature from initial to 0.7
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.7
@ray.remote
@@ -135,6 +148,7 @@ class SimpleProducer(BaseProducer):
tokenizer_config=None,
microbatch_size=1,
backend="transformers",
num_generations: int = 8,
):
super().__init__(
producer_idx,
@@ -150,11 +164,15 @@ class SimpleProducer(BaseProducer):
microbatch_size,
backend,
)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer)
self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs):
return self.model.generate(input_ids, attention_mask, **kwargs)
rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
if self.producer_idx == 1:
print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
return rollouts
def load_state_dict(self, state_dict):
self.model.load_state_dict(state_dict)

View File

@@ -0,0 +1,61 @@
import torch
from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_score = 1.0
acc_score = 9.0
tokenizer = kwargs["tokenizer"]
reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0)
acc_reward = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1]
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
# Check format accuracy
if format_valid:
format_reward += format_score
reward += format_score
# Check answer accuracy
if (
final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
):
acc_reward += acc_score
reward += acc_score
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
def gsm8k_reward_fn(input_ids, **kwargs):
gt_answer = kwargs["gt_answer"]
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
is_valid = True
try:
int(final_answer.strip())
except Exception:
is_valid = False
format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not is_valid or not format_valid:
return reward
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
return reward

View File

@@ -0,0 +1,76 @@
# Copyright Unakar
# Modified from https://github.com/Unakar/Logic-RL/blob/086373176ac198c97277ff50f4b6e7e1bfe669d3/verl/utils/reward_score/kk.py#L99
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, Optional, Tuple
def validate_response_structure(processed_str: str, tags: Dict = None) -> bool:
"""Performs comprehensive validation of response structure.
Args:
processed_str: Processed response string from the model
Returns:
Boolean indicating whether all formatting requirements are met
"""
validation_passed = True
# Check required tags
if tags is None:
tags = {
"think_start": {"text": "<think>", "num_occur": 1},
"think_end": {"text": "</think>", "num_occur": 1},
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
positions = {}
for tag_name, tag_info in tags.items():
tag_str = tag_info["text"]
expected_count = tag_info["num_occur"]
count = processed_str.count(tag_str)
positions[tag_name] = pos = processed_str.find(tag_str)
if count != expected_count:
validation_passed = False
# Verify tag order
if (
positions["think_start"] > positions["think_end"]
or positions["think_end"] > positions["answer_start"]
or positions["answer_start"] > positions["answer_end"]
):
validation_passed = False
if len(processed_str) - positions["answer_end"] != len(tags["answer_end"]["text"]):
validation_passed = False
return validation_passed
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
"""Extracts the final answer from the model's response string.
Args:
solution_str: Raw response string from the language model
Returns:
Tuple containing (extracted_answer, processed_string)
"""
# Extract final answer using XML-style tags
answer_pattern = r"<answer>(.*?)</answer>"
matches = list(re.finditer(answer_pattern, solution_str, re.DOTALL))
if not matches:
return None, solution_str
final_answer = matches[-1].group(1).strip()
return final_answer, solution_str

View File

@@ -0,0 +1,43 @@
"""
Function-based reward verification module.
"""
from typing import Any, Dict, List
import torch
class VerifiableReward:
def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
self.reward_fns = reward_fns
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
response_idx: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
# Initialize reward
rewards = torch.zeros((bs, 3), device=input_ids.device)
# Loop through reward functions
for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
rewards += reward_batch
return rewards

View File

@@ -2,6 +2,8 @@ from typing import Any, Dict, List
import torch
from colossalai.shardformer.layer.loss import dist_log_prob
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
batches = []
@@ -64,3 +66,50 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
log_probs = torch.log_softmax(logits, dim=-1)
per_label_logps = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
return per_label_logps.squeeze(-1)
def calc_action_log_probs(
logits: torch.Tensor,
sequences: torch.LongTensor,
num_actions: int,
shard_config,
vocab_size: int = None,
) -> torch.Tensor:
"""Calculate action log probs.
Args:
logits (torch.Tensor): Output tensor of Actor.forward.logits.
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.
shard_config
vocab_size
Returns:
torch.Tensor: Action log probs.
"""
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
# logits: torch.Tensor, # [B, S, Vocab_size]
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
log_probs = log_probs.squeeze(-1)
return log_probs[:, -num_actions:]
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked mean of a tensor along a specified dimension.
Args:
tensor (torch.Tensor): The input tensor.
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
dim (int, optional): The dimension along which to compute the mean. Default is 1.
Returns:
torch.Tensor: The masked mean tensor.
"""
tensor = tensor * mask
tensor = tensor.sum(dim=dim)
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean

View File

@@ -10,54 +10,83 @@ if __name__ == "__main__":
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-t", "--num-trainers", type=int, default=2)
parser.add_argument("-i", "--num-inferencer", type=int, default=2)
parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
parser.add_argument("-b", "--backend", type=str, default="transformers")
parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument(
"-ibs",
"--inference-batch-size",
type=int,
default=64,
help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
)
parser.add_argument(
"-imbs",
"--inference-microbatch-size",
type=int,
default=8,
help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
)
parser.add_argument(
"-tbs",
"--train-batch-size",
type=int,
default=32,
help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
)
parser.add_argument(
"-tMbs",
"--train-minibatch-size",
type=int,
default=1,
help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
)
parser.add_argument(
"-tmbs",
"--train-microbatch-size",
type=int,
default=2,
help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
)
parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
args = parser.parse_args()
assert args.train_minibatch_size > 0, "Train mini batch size must be greater than 0"
assert (
args.train_minibatch_size * args.num_generations >= args.train_microbatch_size
and args.train_microbatch_size > 0
), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model)
generate_config = dict(
top_k=50,
top_p=0.8,
)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
if args.backend == "transformers":
inference_model_config.update(
dict(
attn_implementation="flash_attention_2",
use_flash_attention_2=True,
torch_dtype=torch.bfloat16,
)
)
train_model_config.update(
dict(
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False,
)
)
generate_config.update(
dict(
max_length=512,
max_length=1024 + 512,
do_sample=True,
max_new_tokens=None,
early_stopping=False,
stop_strings=["</answer>"],
)
)
elif args.backend == "vllm":
inference_model_config.update(
dict(
gpu_memory_utilization=0.6,
)
)
inference_model_config.update(dict(gpu_memory_utilization=0.7, enforce_eager=True, enable_chunked_prefill=True))
generate_config.update(
dict(
max_tokens=256,
max_tokens=2048,
ignore_eos=True,
include_stop_str_in_output=True,
stop=["</answer>"],
)
)
else:
@@ -77,18 +106,29 @@ if __name__ == "__main__":
num_producers=args.num_inferencer,
num_proc_per_producer=1,
num_consumer_procs=args.num_trainers,
num_episodes=1,
num_episodes=10,
inference_batch_size=args.inference_batch_size,
inference_microbatch_size=args.inference_microbatch_size,
train_batch_size=args.train_batch_size,
train_minibatch_size=args.train_minibatch_size,
train_microbatch_size=args.train_microbatch_size,
dataset_config={"path": args.dataset, "max_length": 256},
dataset_config={"path": args.dataset, "max_length": 300},
dataloaders_config={},
inference_model_config=inference_model_config,
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
plugin_config={},
# plugin_config={}, # for zero
plugin_config={
"pp_size": 2,
"tp_size": 2,
"microbatch_size": args.train_microbatch_size // 2,
"zero_stage": 0,
"max_norm": 1.0,
}, # for pp
inference_backend=args.backend,
master_addr="localhost",
master_port=29504,
master_port=29506,
core_algo=args.algo,
project_name=args.project,
)