add DAPO support

This commit is contained in:
YeAnbang 2025-04-15 18:28:35 +08:00
parent 1723a02860
commit 6e71e2a3ce
9 changed files with 321 additions and 155 deletions

View File

@ -12,7 +12,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -38,7 +38,7 @@ class GRPOConsumer(BaseConsumer):
num_generations=8,
use_wandb=True,
generate_config=None,
training_config={},
grpo_config={},
project_name=None,
):
super().__init__(
@ -59,7 +59,7 @@ class GRPOConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
@ -69,8 +69,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
self.generate_config = generate_config
self.training_config = training_config
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@ -79,10 +80,21 @@ class GRPOConsumer(BaseConsumer):
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
self.num_generations = num_generations
self.filter_range = training_config.get("filter_range", None)
self.filter_range = grpo_config.get("filter_range", None)
if self.filter_range is not None:
assert len(self.filter_range) == 2, "Filter range should have 2 values."
self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
if self.filter_truncated_response:
self.max_length = 0
if "max_tokens" in self.generate_config:
self.max_length = self.generate_config["max_tokens"]
elif "max_new_tokens" in self.generate_config:
self.max_length = self.generate_config["max_new_tokens"]
else:
raise ValueError(
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
response_format_tags = {
"think_start": {"text": "<think>", "num_occur": 1},
@ -90,11 +102,20 @@ class GRPOConsumer(BaseConsumer):
"answer_start": {"text": "<answer>", "num_occur": 1},
"answer_end": {"text": "</answer>", "num_occur": 1},
}
reward_model_kwargs = {
k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
}
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
)
self.policy_loss_fn = PolicyLoss()
self.policy_loss_fn = PolicyLoss(
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
skip_threshold=grpo_config.get("skip_threshold", 20.0),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
self.global_step = 0
self.use_wandb = use_wandb
@ -102,7 +123,7 @@ class GRPOConsumer(BaseConsumer):
optimizer=self.optimizer,
total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
warmup_steps=0,
eta_min=0.1 * training_config.get("lr", 1e-6),
eta_min=0.1 * grpo_config.get("lr", 1e-6),
)
def setup(self):
@ -141,9 +162,65 @@ class GRPOConsumer(BaseConsumer):
num_action = action_mask.shape[1]
old_action_log_probs = data["action_log_probs"]
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward_group = self.reward_model(
int(step_idx / self.num_microbatches),
data["input_ids"],
gt_answer=data["gt_answer"],
response_idx=data["response_idx"],
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
reward_mean_no_length_penalty = (
(format_reward + acc_reward)
.view(-1, self.num_generations)
.mean(dim=1)
.repeat_interleave(self.num_generations, dim=0)
)
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(
reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1]
)
)
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)
# for i in range(loss_mask.size(0)):
# if loss_mask[i] == False:
# print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i])
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
self.effective_sample_count += effective_samples.item()
mean_kl, mean_loss = [], []
# update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
# balance between efficiency and accuracy
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
if need_update:
print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****")
self.effective_sample_count = 0
need_update = (step_idx + 1) % self.num_microbatches == 0
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
nullcontext()
@ -151,32 +228,6 @@ class GRPOConsumer(BaseConsumer):
else self.booster.no_sync(self.policy_model, self.optimizer)
)
with ctx:
reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
reward_mean = group_reward.mean(dim=1)
# [batch_size x num_generations]
reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
loss_mask = (
None
if self.filter_range is None
else torch.logical_and(
reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
).repeat_interleave(self.num_generations, dim=0)
)
mean_kl, mean_loss = [], []
for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
input_ids_forward_micro_batch = data["input_ids"][
@ -199,47 +250,50 @@ class GRPOConsumer(BaseConsumer):
if self.plugin.pp_size > 1:
# Support training with PP.
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
with torch.no_grad():
reference_model_outputs = self.booster.execute_pipeline(
iter(
[
{
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
}
]
),
self.reference_model,
criterion=lambda outputs, inputs: torch.tensor(
[0.0], device=action_mask.device
), # dummy criterion
optimizer=None,
return_loss=False,
return_outputs=True,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
if self.booster.plugin.stage_manager.is_last_stage():
reference_model_logits = reference_model_outputs["outputs"]["logits"]
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
else:
# Dummy reference logprobs for data iterator.
reference_action_log_probs = None
data_policy_forward = {
"input_ids": input_ids_forward_micro_batch,
"attention_mask": attention_mask_forward_micro_batch,
"action_mask": action_mask_forward_micro_batch,
"reference_action_log_probs": reference_action_log_probs,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"source": self.rank,
}
if reference_action_log_probs is not None:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
kl = []
@ -251,15 +305,20 @@ class GRPOConsumer(BaseConsumer):
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(appox_kl.mean())
if "reference_action_log_probs" in inputs:
per_token_kl = (
torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
- (inputs["reference_action_log_probs"] - action_log_probs)
- 1
)
appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
inputs["action_mask"], dim=-1
)
kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
kl.append(0.0)
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
@ -298,25 +357,29 @@ class GRPOConsumer(BaseConsumer):
self.plugin.shard_config,
)
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
if self.policy_loss_fn.beta > 0:
with torch.no_grad():
reference_model_logits = self.reference_model(
input_ids=input_ids_forward_micro_batch,
attention_mask=attention_mask_forward_micro_batch,
).logits
reference_action_log_probs = calc_action_log_probs(
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)
- (reference_action_log_probs - action_log_probs)
- 1
)
kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
action_mask_forward_micro_batch, dim=-1
)
else:
per_token_kl = 0.0
kl = None
loss, skip_update, _ = self.policy_loss_fn(
action_log_probs,
@ -330,9 +393,10 @@ class GRPOConsumer(BaseConsumer):
if not skip_update:
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
# Calculate accumulate value.
mean_kl.append(kl.data)
if kl is not None:
kl = all_reduce_mean(kl.mean(), self.plugin)
mean_kl.append(kl.data)
mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@ -343,7 +407,8 @@ class GRPOConsumer(BaseConsumer):
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
@ -360,35 +425,32 @@ class GRPOConsumer(BaseConsumer):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
print(
"Loss:",
self.accum_loss.item() / self.accum_count,
"\nReward:",
self.accum_reward.item() / self.accum_count,
"\nFormat Reward:",
self.accum_format_reward.item() / self.accum_count,
"\nAcc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"\nKL:",
self.accum_kl.item() / self.accum_count,
"\nAdvantages:",
self.accum_advantages.item() / self.accum_count,
"\nResponse Length:",
self.accum_response_length.item() / self.accum_count,
)
self.wandb_run.log(
{
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
to_log_msg = (
f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \
Reward: {self.accum_reward.item() / self.accum_count:.4f} \
Format Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \
Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \
Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \
Response Length: {self.accum_response_length.item() / self.accum_count:.4f}"
+ f" KL: {self.accum_kl.item() / self.accum_count:.4f}"
if self.policy_loss_fn.beta > 0
else ""
)
print(to_log_msg)
metrics = {
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0:
metrics["train/kl"] = self.accum_kl.item() / self.accum_count
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_acc_reward.zero_()

View File

@ -40,6 +40,7 @@ def launch_distributed(
inference_model_config: Dict[str, Any],
generate_config: Dict[str, Any],
train_model_config: Dict[str, Any],
grpo_config: Dict[str, Any],
plugin_config: Dict[str, Any],
tokenizer_config: Optional[Dict[str, Any]] = None,
inference_backend: str = "transformers",
@ -103,11 +104,7 @@ def launch_distributed(
plugin_config=plugin_config,
microbatch_size=train_minibatch_size,
generate_config=generate_config_consumer,
training_config={
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": train_microbatch_size,
},
grpo_config=grpo_config,
num_generations=num_generations,
project_name=project_name,
)

View File

@ -2,7 +2,7 @@ from typing import Optional
import torch
import torch.nn as nn
from coati.distributed.utils import masked_mean
from coati.distributed.utils import masked_mean, masked_sum
class PolicyLoss(nn.Module):
@ -10,11 +10,21 @@ class PolicyLoss(nn.Module):
Policy Loss for PPO
"""
def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
def __init__(
self,
clip_eps_low: float = 0.2,
clip_eps_high: float = 0.2,
skip_threshold: float = 20.0,
beta: float = 0.01,
loss_variation: str = "sample_level",
) -> None:
super().__init__()
self.clip_eps = clip_eps
self.clip_eps_low = clip_eps_low
self.clip_eps_high = clip_eps_high
self.skip_threshold = skip_threshold
self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
def forward(
self,
@ -32,14 +42,31 @@ class PolicyLoss(nn.Module):
ratio = ((log_probs - log_probs.detach()) * action_mask).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
if self.beta <= 0:
# skip kl term if kl coefficient is zero
per_token_kl = 0.0
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
if self.loss_variation == "sample_level":
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
if loss_mask is not None:
loss = loss * loss_mask
loss = loss.mean()
elif self.loss_variation == "token_level":
total_tokens = 0
if action_mask is not None:
loss = masked_sum(loss, action_mask)
total_tokens = action_mask.sum(dim=1)
else:
loss = loss.sum(dim=1)
total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
if loss_mask is not None:
loss = loss * loss_mask
total_tokens = total_tokens * loss_mask
loss = loss.sum() / (total_tokens.sum() + 1e-8)
return loss, skip, ratio.max()

View File

@ -124,12 +124,12 @@ class BaseProducer:
self.load_state_dict(state_dict)
del state_dict
torch.cuda.empty_cache()
# linear annealing for 1 episode, temperature from initial to 0.7
# linear annealing for 1 episode, temperature from initial to 0.9
if episode <= 0:
ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
"temperature"
] + ratio * 0.7
] + ratio * 0.9
@ray.remote

View File

@ -3,14 +3,29 @@ import torch
from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
soft_over_length_punishment = kwargs["soft_over_length_punishment"]
format_score = 1.0
acc_score = 9.0
tokenizer = kwargs["tokenizer"]
if step > 30:
format_score = 0.0
acc_score = 10.0
reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0)
acc_reward = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
if soft_over_length_punishment:
max_length = kwargs.get("max_length", 1024 * 4)
cache_length = kwargs.get("cache_length", 512)
res_length = e.item() - s.item() + 1
if res_length >= max_length:
length_reward = -1.0 * 2
elif res_length > max_length - cache_length:
length_reward = ((max_length - cache_length) - res_length) / cache_length * 2
if gt_answer is None:
return reward
@ -33,6 +48,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
acc_reward += acc_score
reward += acc_score
reward = reward + length_reward
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)

View File

@ -14,6 +14,7 @@ class VerifiableReward:
def __call__(
self,
step: int,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
response_idx: List[torch.Tensor] = None,
@ -29,6 +30,7 @@ class VerifiableReward:
reward_batch = torch.stack(
[
reward_fn(
step,
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],

View File

@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
mask_sum = mask.sum(dim=dim)
mean = tensor / (mask_sum + 1e-8)
return mean
def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""
Compute the masked sum of a tensor along a specified dimension.
Args:
tensor (torch.Tensor): The input tensor.
mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
dim (int, optional): The dimension along which to compute the sum. Default is 1.
Returns:
torch.Tensor: The masked sum tensor.
"""
tensor = tensor * mask
return tensor.sum(dim=dim)

View File

@ -128,7 +128,21 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
return tensor
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor:
# """
# Performs an all-reduce operation to sum the values of the given tensor across all processes.
# Args:
# tensor (torch.Tensor): The input tensor to be reduced.
# Returns:
# torch.Tensor: The reduced tensor with the sum of values across all processes.
# """
# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
# return tensor
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
"""
Performs an all-reduce operation to sum the values of the given tensor across all processes.
@ -138,5 +152,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The reduced tensor with the sum of values across all processes.
"""
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
# All reduce sum across DP group
if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
return tensor

View File

@ -60,8 +60,8 @@ if __name__ == "__main__":
ray.init(address="local", namespace="ray-example")
inference_model_config = dict(path=args.model)
train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)
if args.backend == "transformers":
inference_model_config.update(
@ -102,6 +102,29 @@ if __name__ == "__main__":
)
)
# Default Settings
# grpo_config = {
# "filter_range": [0.05, 9.0],
# "lr": 1e-6,
# "train_microbatch_size": train_microbatch_size,
# }
# DAPO variant settings
grpo_config = {
"filter_range": [0.05, 9.0],
"lr": 1e-6,
"train_microbatch_size": args.train_microbatch_size,
"clip_eps_low": 0.2,
"clip_eps_high": 0.28,
"skip_threshold": 20.0,
"beta": 0.0, # no KL penalty
"loss_variation": "token_level",
"soft_over_length_punishment": True,
"max_length": 1024 * 2,
"cache_length": 256,
"filter_truncated_response": True,
}
launch_distributed(
num_producers=args.num_inferencer,
num_proc_per_producer=1,
@ -118,14 +141,17 @@ if __name__ == "__main__":
generate_config=generate_config,
num_generations=args.num_generations,
train_model_config=train_model_config,
# plugin_config={}, # for zero
grpo_config=grpo_config,
plugin_config={
"pp_size": 2,
"tp_size": 2,
"microbatch_size": args.train_microbatch_size // 2,
"zero_stage": 0,
"max_norm": 1.0,
}, # for pp
"zero_stage": 2,
}, # for zero
# plugin_config={
# "pp_size": 2,
# "tp_size": 2,
# "microbatch_size": args.train_microbatch_size // 2,
# "zero_stage": 0,
# "max_norm": 1.0,
# }, # for pp
inference_backend=args.backend,
master_addr="localhost",
master_port=29506,