add DAPO support

Repository: hpcaitech/ColossalAI (mirror of https://github.com/hpcaitech/ColossalAI.git)
Commit: 6e71e2a3ce (parent: 1723a02860)
@@ -12,7 +12,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -38,7 +38,7 @@ class GRPOConsumer(BaseConsumer):
         num_generations=8,
         use_wandb=True,
         generate_config=None,
-        training_config={},
+        grpo_config={},
         project_name=None,
     ):
         super().__init__(
@@ -59,7 +59,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=training_config.get("lr", 1e-6))
+        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
@@ -69,8 +69,9 @@ class GRPOConsumer(BaseConsumer):
         self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0
         self.generate_config = generate_config
-        self.training_config = training_config
+        self.grpo_config = grpo_config
         self.project_name = project_name
+        self.effective_sample_count = 0

         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -79,10 +80,21 @@ class GRPOConsumer(BaseConsumer):
         self.tokenizer = AutoTokenizer.from_pretrained(path)
         self.pad_token_id = self.tokenizer.pad_token_id
         self.num_generations = num_generations
-        self.filter_range = training_config.get("filter_range", None)
+        self.filter_range = grpo_config.get("filter_range", None)
         if self.filter_range is not None:
             assert len(self.filter_range) == 2, "Filter range should have 2 values."

+        self.filter_truncated_response = grpo_config.get("filter_truncated_response", False)
+        if self.filter_truncated_response:
+            self.max_length = 0
+            if "max_tokens" in self.generate_config:
+                self.max_length = self.generate_config["max_tokens"]
+            elif "max_new_tokens" in self.generate_config:
+                self.max_length = self.generate_config["max_new_tokens"]
+            else:
+                raise ValueError(
+                    "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
+                )
         # Initialize verifiable reward.
         response_format_tags = {
             "think_start": {"text": "<think>", "num_occur": 1},
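The consumer now derives the maximum response length from the generation config, accepting either the vLLM-style "max_tokens" key or the transformers-style "max_new_tokens" key. A standalone sketch of that lookup (the helper name is illustrative, not part of the commit):

def resolve_max_length(generate_config: dict) -> int:
    # vLLM uses "max_tokens"; transformers uses "max_new_tokens".
    if "max_tokens" in generate_config:
        return generate_config["max_tokens"]
    if "max_new_tokens" in generate_config:
        return generate_config["max_new_tokens"]
    raise ValueError(
        "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
    )

# Example: resolve_max_length({"max_new_tokens": 2048}) -> 2048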
@@ -90,11 +102,20 @@ class GRPOConsumer(BaseConsumer):
             "answer_start": {"text": "<answer>", "num_occur": 1},
             "answer_end": {"text": "</answer>", "num_occur": 1},
         }
+        reward_model_kwargs = {
+            k: v for k, v in grpo_config.items() if k in ["soft_over_length_punishment", "max_length", "cache_length"]
+        }
         self.reward_model = VerifiableReward(
-            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
+            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
         )

-        self.policy_loss_fn = PolicyLoss()
+        self.policy_loss_fn = PolicyLoss(
+            clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
+            clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
+            skip_threshold=grpo_config.get("skip_threshold", 20.0),
+            beta=grpo_config.get("beta", 0.01),
+            loss_variation=grpo_config.get("loss_variation", "sample_level"),
+        )
         self.global_step = 0
         self.use_wandb = use_wandb

@@ -102,7 +123,7 @@ class GRPOConsumer(BaseConsumer):
             optimizer=self.optimizer,
             total_steps=min(self.num_episodes, 4) * self.num_update_per_episode,
             warmup_steps=0,
-            eta_min=0.1 * training_config.get("lr", 1e-6),
+            eta_min=0.1 * grpo_config.get("lr", 1e-6),
         )

     def setup(self):
@@ -141,18 +162,13 @@ class GRPOConsumer(BaseConsumer):
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
-        forward_batch_size = self.training_config.get("train_microbatch_size", data["input_ids"].size(0))
+        forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        need_update = (step_idx + 1) % self.num_microbatches == 0
-        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
-        ctx = (
-            nullcontext()
-            if need_update or self.booster.plugin.zero_stage == 2
-            else self.booster.no_sync(self.policy_model, self.optimizer)
-        )
-        with ctx:
         reward_group = self.reward_model(
-            data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
+            int(step_idx / self.num_microbatches),
+            data["input_ids"],
+            gt_answer=data["gt_answer"],
+            response_idx=data["response_idx"],
         )

         reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
@@ -165,19 +181,54 @@ class GRPOConsumer(BaseConsumer):
         reward_mean = group_reward.mean(dim=1)
         # [batch_size x num_generations]
         reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
+
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [batch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
         # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
+        reward_mean_no_length_penalty = (
+            (format_reward + acc_reward)
+            .view(-1, self.num_generations)
+            .mean(dim=1)
+            .repeat_interleave(self.num_generations, dim=0)
+        )
         loss_mask = (
-            None
+            torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(
-                reward_mean > self.filter_range[0], reward_mean < self.filter_range[1]
-            ).repeat_interleave(self.num_generations, dim=0)
+                reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1]
+            )
         )
+        # filter out overlength samples
+        if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+            loss_mask = torch.logical_and(
+                loss_mask,
+                action_mask[:, -1] == False,
+            )
+        # for i in range(loss_mask.size(0)):
+        # if loss_mask[i] == False:
+        # print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i])
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+
         mean_kl, mean_loss = [], []
+
+        # update gradient only if at least 0.7*batch_size*num_generation valid samples are collected in case a lot of samples are invalid and got filtered out.
+        # balance between efficiency and accuracy
+        need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
+        if need_update:
+            print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****")
+            self.effective_sample_count = 0
+
+        # Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
+        ctx = (
+            nullcontext()
+            if need_update or self.booster.plugin.zero_stage == 2
+            else self.booster.no_sync(self.policy_model, self.optimizer)
+        )
+        with ctx:

             for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
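The block above is the DAPO-style dynamic sampling: rollouts whose group reward falls outside filter_range, or whose response was truncated at max_length, are masked out, and the optimizer step is deferred until enough valid samples have accumulated across data-parallel ranks. A simplified, single-process sketch of that gating logic (thresholds follow the code above; the function names are illustrative):

import torch

def update_effective_count(loss_mask: torch.Tensor, running_count: int) -> int:
    # loss_mask marks which rollouts still contribute to the loss.
    return running_count + int(loss_mask.sum().item())

def should_update(running_count: int, batch_size: int, dp_size: int, num_generations: int) -> bool:
    # Defer the optimizer step until ~75% of the nominal batch is valid.
    return running_count >= batch_size * dp_size * num_generations * 0.75

# Example: 4 prompts x 8 generations on 1 rank -> threshold of 24 valid samples.
mask = torch.tensor([True] * 20 + [False] * 12)
count = update_effective_count(mask, running_count=0)
print(count, should_update(count, batch_size=4, dp_size=1, num_generations=8))  # 20 False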
@@ -199,7 +250,7 @@ class GRPOConsumer(BaseConsumer):

                 if self.plugin.pp_size > 1:
                     # Support training with PP.
+                    if self.policy_loss_fn.beta > 0:
                         with torch.no_grad():
                             reference_model_outputs = self.booster.execute_pipeline(
                                 iter(
@@ -230,16 +281,19 @@ class GRPOConsumer(BaseConsumer):
                         else:
                             # Dummy reference logprobs for data iterator.
                             reference_action_log_probs = None
+                    else:
+                        reference_action_log_probs = None

                     data_policy_forward = {
                         "input_ids": input_ids_forward_micro_batch,
                         "attention_mask": attention_mask_forward_micro_batch,
                         "action_mask": action_mask_forward_micro_batch,
-                        "reference_action_log_probs": reference_action_log_probs,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
                         "source": self.rank,
                     }
+                    if reference_action_log_probs is not None:
+                        data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                     kl = []

@@ -251,6 +305,7 @@ class GRPOConsumer(BaseConsumer):
                             num_action,
                             self.plugin.shard_config,
                         )
+                        if "reference_action_log_probs" in inputs:
                             per_token_kl = (
                                 torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
                                 - (inputs["reference_action_log_probs"] - action_log_probs)
@@ -260,6 +315,10 @@ class GRPOConsumer(BaseConsumer):
                                 inputs["action_mask"], dim=-1
                             )
                             kl.append(appox_kl.mean())
+                        else:
+                            per_token_kl = 0.0
+                            kl.append(0.0)

                         loss, skip_update, _ = self.policy_loss_fn(
                             action_log_probs,
                             action_log_probs,
@@ -298,6 +357,7 @@ class GRPOConsumer(BaseConsumer):
                         self.plugin.shard_config,
                     )

+                    if self.policy_loss_fn.beta > 0:
                         with torch.no_grad():
                             reference_model_logits = self.reference_model(
                                 input_ids=input_ids_forward_micro_batch,
@@ -317,6 +377,9 @@ class GRPOConsumer(BaseConsumer):
                         kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
                             action_mask_forward_micro_batch, dim=-1
                         )
+                    else:
+                        per_token_kl = 0.0
+                        kl = None

                     loss, skip_update, _ = self.policy_loss_fn(
                         action_log_probs,
@@ -330,8 +393,9 @@ class GRPOConsumer(BaseConsumer):
                    if not skip_update:
                        self.booster.backward(loss, self.optimizer)
                    loss = all_reduce_mean(loss, self.plugin)
-                   kl = all_reduce_mean(kl.mean(), self.plugin)
                    # Calculate accumulate value.
+                   if kl is not None:
+                       kl = all_reduce_mean(kl.mean(), self.plugin)
                        mean_kl.append(kl.data)
                    mean_loss.append(loss.data)
            if not self.plugin.pp_size > 1 or (
@@ -343,6 +407,7 @@ class GRPOConsumer(BaseConsumer):
                advantages = all_reduce_mean(advantages.mean(), self.plugin)
                response_length = all_reduce_mean(response_length.mean(), self.plugin)
                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                if self.policy_loss_fn.beta > 0:
                    self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                self.accum_reward.add_(reward.data)
                self.accum_format_reward.add_(format_reward.data)
@@ -360,35 +425,32 @@ class GRPOConsumer(BaseConsumer):
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
-                    print(
-                        "Loss:",
-                        self.accum_loss.item() / self.accum_count,
-                        "\nReward:",
-                        self.accum_reward.item() / self.accum_count,
-                        "\nFormat Reward:",
-                        self.accum_format_reward.item() / self.accum_count,
-                        "\nAcc Reward:",
-                        self.accum_acc_reward.item() / self.accum_count,
-                        "\nKL:",
-                        self.accum_kl.item() / self.accum_count,
-                        "\nAdvantages:",
-                        self.accum_advantages.item() / self.accum_count,
-                        "\nResponse Length:",
-                        self.accum_response_length.item() / self.accum_count,
+                    to_log_msg = (
+                        f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \
+                        Reward: {self.accum_reward.item() / self.accum_count:.4f} \
+                        Format Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \
+                        Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \
+                        Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \
+                        Response Length: {self.accum_response_length.item() / self.accum_count:.4f}"
+                        + f" KL: {self.accum_kl.item() / self.accum_count:.4f}"
+                        if self.policy_loss_fn.beta > 0
+                        else ""
                     )
-                    self.wandb_run.log(
-                        {
+                    print(to_log_msg)
+                    metrics = {
                         "metrics/reward": self.accum_reward.item() / self.accum_count,
                         "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
                         "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                         "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
-                    )
+                    if self.policy_loss_fn.beta > 0:
+                        metrics["train/kl"] = self.accum_kl.item() / self.accum_count
+
+                    self.wandb_run.log(metrics)
                     self.accum_loss.zero_()
                     self.accum_reward.zero_()
                     self.accum_acc_reward.zero_()
@@ -40,6 +40,7 @@ def launch_distributed(
     inference_model_config: Dict[str, Any],
     generate_config: Dict[str, Any],
     train_model_config: Dict[str, Any],
+    grpo_config: Dict[str, Any],
     plugin_config: Dict[str, Any],
     tokenizer_config: Optional[Dict[str, Any]] = None,
     inference_backend: str = "transformers",
@@ -103,11 +104,7 @@ def launch_distributed(
             plugin_config=plugin_config,
             microbatch_size=train_minibatch_size,
             generate_config=generate_config_consumer,
-            training_config={
-                "filter_range": [0.05, 9.0],
-                "lr": 1e-6,
-                "train_microbatch_size": train_microbatch_size,
-            },
+            grpo_config=grpo_config,
             num_generations=num_generations,
             project_name=project_name,
         )
@@ -2,7 +2,7 @@ from typing import Optional

 import torch
 import torch.nn as nn
-from coati.distributed.utils import masked_mean
+from coati.distributed.utils import masked_mean, masked_sum


 class PolicyLoss(nn.Module):
@@ -10,11 +10,21 @@ class PolicyLoss(nn.Module):
     Policy Loss for PPO
     """

-    def __init__(self, clip_eps: float = 0.2, skip_threshold: float = 20.0, beta: float = 0.01) -> None:
+    def __init__(
+        self,
+        clip_eps_low: float = 0.2,
+        clip_eps_high: float = 0.2,
+        skip_threshold: float = 20.0,
+        beta: float = 0.01,
+        loss_variation: str = "sample_level",
+    ) -> None:
         super().__init__()
-        self.clip_eps = clip_eps
+        self.clip_eps_low = clip_eps_low
+        self.clip_eps_high = clip_eps_high
         self.skip_threshold = skip_threshold
         self.beta = beta
+        self.loss_variation = loss_variation
+        assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"

     def forward(
         self,
@@ -32,9 +42,13 @@ class PolicyLoss(nn.Module):
         ratio = ((log_probs - log_probs.detach()) * action_mask).exp()

         surr1 = ratio * advantages
-        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+        surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
+        if self.beta <= 0:
+            # skip kl term if kl coefficient is zero
+            per_token_kl = 0.0
         loss = -torch.min(surr1, surr2) + self.beta * per_token_kl

+        if self.loss_variation == "sample_level":
             if action_mask is not None:
                 loss = masked_mean(loss, action_mask)
             else:
@@ -42,4 +56,17 @@ class PolicyLoss(nn.Module):
             if loss_mask is not None:
                 loss = loss * loss_mask
             loss = loss.mean()
+        elif self.loss_variation == "token_level":
+            total_tokens = 0
+            if action_mask is not None:
+                loss = masked_sum(loss, action_mask)
+                total_tokens = action_mask.sum(dim=1)
+            else:
+                loss = loss.sum(dim=1)
+                total_tokens = torch.ones_like(loss, device=loss.device) * log_probs.size(1)
+            if loss_mask is not None:
+                loss = loss * loss_mask
+                total_tokens = total_tokens * loss_mask
+            loss = loss.sum() / (total_tokens.sum() + 1e-8)
+
         return loss, skip, ratio.max()
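With loss_variation="token_level", the clipped surrogate is summed over tokens and normalized once by the total token count of the batch rather than per sample, which is the DAPO-style aggregation; the clip range is also allowed to be asymmetric (clip_eps_low vs clip_eps_high). A self-contained sketch of that computation, written under those assumptions rather than copied from PolicyLoss:

import torch

def token_level_clipped_loss(log_probs, old_log_probs, advantages, action_mask,
                             clip_eps_low=0.2, clip_eps_high=0.28):
    # per-token probability ratio between new and old policy
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    per_token_loss = -torch.min(surr1, surr2)
    # token-level aggregation: sum over all valid tokens, divide by total valid tokens
    return (per_token_loss * action_mask).sum() / (action_mask.sum() + 1e-8)

# Toy example: 2 responses, 3 tokens each, the second token of response 2 is padding.
log_probs = torch.zeros(2, 3)
old_log_probs = torch.zeros(2, 3)
advantages = torch.ones(2, 3)
action_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
print(token_level_clipped_loss(log_probs, old_log_probs, advantages, action_mask))  # ~ -1.0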
@@ -124,12 +124,12 @@ class BaseProducer:
                         self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
-                # linear annealing for 1 episode, temperature from initial to 0.7
+                # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.dataloader) - i) / len(self.dataloader)
                     self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
                         "temperature"
-                    ] + ratio * 0.7
+                    ] + ratio * 0.9


 @ray.remote
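The producer change only raises the annealing floor: during the first episode the sampling temperature is interpolated linearly from the configured value down to 0.9 (previously 0.7) as the dataloader is consumed. A small sketch of that schedule, with illustrative names:

def annealed_temperature(step: int, total_steps: int, initial_temperature: float, floor: float = 0.9) -> float:
    # ratio goes from 0 at the first batch to 1 at the last batch of the episode
    ratio = 1 - (total_steps - step) / total_steps
    return (1 - ratio) * initial_temperature + ratio * floor

# Example with initial temperature 1.0 over 100 batches:
for step in (0, 50, 100):
    print(step, round(annealed_temperature(step, 100, 1.0), 3))
# 0 1.0, 50 0.95, 100 0.9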
@@ -3,14 +3,29 @@ import torch
 from .reward_utils import extract_solution, validate_response_structure


-def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+def math_reward_fn(step, input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    soft_over_length_punishment = kwargs["soft_over_length_punishment"]
     format_score = 1.0
     acc_score = 9.0
-    tokenizer = kwargs["tokenizer"]
+    if step > 30:
+        format_score = 0.0
+        acc_score = 10.0
     reward = torch.tensor(0.0)
     format_reward = torch.tensor(0.0)
     acc_reward = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
+
+    length_reward = 0.0
+    if soft_over_length_punishment:
+        max_length = kwargs.get("max_length", 1024 * 4)
+        cache_length = kwargs.get("cache_length", 512)
+        res_length = e.item() - s.item() + 1
+        if res_length >= max_length:
+            length_reward = -1.0 * 2
+        elif res_length > max_length - cache_length:
+            length_reward = ((max_length - cache_length) - res_length) / cache_length * 2
+
     if gt_answer is None:
         return reward

@@ -33,6 +48,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             acc_reward += acc_score
             reward += acc_score

+    reward = reward + length_reward
+
     return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)

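The soft over-length punishment mirrors DAPO's overlong reward shaping: responses that end inside the last cache_length tokens before max_length receive a linearly growing penalty, and responses at or beyond max_length receive the full penalty of -2. A worked sketch using the values set in the example script later in this commit (max_length=2048, cache_length=256):

def length_penalty(res_length: int, max_length: int = 2048, cache_length: int = 256) -> float:
    if res_length >= max_length:
        return -1.0 * 2
    if res_length > max_length - cache_length:
        # interpolates from 0 at (max_length - cache_length) down to -2 at max_length
        return ((max_length - cache_length) - res_length) / cache_length * 2
    return 0.0

print(length_penalty(1500))  # 0.0  (well under the soft window)
print(length_penalty(1920))  # -1.0 (halfway through the soft window)
print(length_penalty(2048))  # -2.0 (truncated at the cap)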
@@ -14,6 +14,7 @@ class VerifiableReward:

     def __call__(
         self,
+        step: int,
         input_ids: torch.LongTensor,
         gt_answer: List[torch.Tensor] = None,
         response_idx: List[torch.Tensor] = None,
@@ -29,6 +30,7 @@ class VerifiableReward:
         reward_batch = torch.stack(
             [
                 reward_fn(
+                    step,
                     input_ids[i],
                     gt_answer=gt_answer[i],
                     response_idx=response_idx[i],
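Passing the consumer step through VerifiableReward.__call__ lets math_reward_fn change its weighting over time: while the step is at most 30 the format reward contributes 1.0 and the accuracy reward 9.0, after which formatting is no longer rewarded and accuracy alone is worth 10.0. A sketch of that schedule (helper name illustrative):

def reward_weights(step: int):
    # returns (format_score, acc_score) as set inside math_reward_fn
    if step > 30:
        return 0.0, 10.0
    return 1.0, 9.0

print(reward_weights(10))  # (1.0, 9.0)
print(reward_weights(31))  # (0.0, 10.0)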
@@ -113,3 +113,20 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
     mask_sum = mask.sum(dim=dim)
     mean = tensor / (mask_sum + 1e-8)
     return mean
+
+
+def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+    """
+    Compute the masked sum of a tensor along a specified dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+        mask (torch.Tensor): The mask tensor with the same shape as the input tensor.
+        dim (int, optional): The dimension along which to compute the sum. Default is 1.
+
+    Returns:
+        torch.Tensor: The masked sum tensor.
+
+    """
+    tensor = tensor * mask
+    return tensor.sum(dim=dim)
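masked_sum complements the existing masked_mean: the former feeds the token-level loss variant (divide once by the global token count), the latter the sample-level variant (divide per sample). A toy comparison, assuming the same call signatures as above:

import torch

def masked_mean(tensor, mask, dim=1):
    return (tensor * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)

def masked_sum(tensor, mask, dim=1):
    return (tensor * mask).sum(dim=dim)

loss = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
print(masked_mean(loss, mask))                    # ~tensor([1., 2.])  per-sample averages
print(masked_sum(loss, mask))                     # tensor([3., 4.])   per-sample token sums
print(masked_sum(loss, mask).sum() / mask.sum())  # tensor(1.4000)     token-level average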
@@ -128,7 +128,21 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
     return tensor


-def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
+# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor:
+# """
+# Performs an all-reduce operation to sum the values of the given tensor across all processes.
+
+# Args:
+# tensor (torch.Tensor): The input tensor to be reduced.
+
+# Returns:
+# torch.Tensor: The reduced tensor with the sum of values across all processes.
+# """
+# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
+# return tensor
+
+
+def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     """
     Performs an all-reduce operation to sum the values of the given tensor across all processes.

@@ -138,5 +152,9 @@ def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: The reduced tensor with the sum of values across all processes.
     """
+    # All reduce sum across DP group
+    if plugin is not None:
+        dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
+    else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
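Summing the per-rank count of valid samples is what makes the dynamic-sampling threshold global: each consumer contributes its local torch.sum(loss_mask) and every rank sees the same total. A minimal sketch of such an all-reduce; the single-rank setup and port here are only for illustration, and the repository version additionally routes through plugin.dp_group when a plugin is given:

import torch
import torch.distributed as dist

def all_reduce_sum(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Sums the tensor in place across all ranks of the given process group.
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor

if __name__ == "__main__":
    # Single-process group just to make the call runnable; with N ranks the
    # result would be the sum of every rank's local count.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29507", rank=0, world_size=1)
    local_valid_samples = torch.tensor(20.0)
    print(all_reduce_sum(local_valid_samples))  # tensor(20.) with world_size=1
    dist.destroy_process_group()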
@@ -60,8 +60,8 @@ if __name__ == "__main__":
     ray.init(address="local", namespace="ray-example")

     inference_model_config = dict(path=args.model)
-    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
-    generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
+    train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
+    generate_config = dict(top_k=-1, top_p=1.0, temperature=1.0)

     if args.backend == "transformers":
         inference_model_config.update(
@@ -102,6 +102,29 @@ if __name__ == "__main__":
             )
         )

+    # Default Settings
+    # grpo_config = {
+    # "filter_range": [0.05, 9.0],
+    # "lr": 1e-6,
+    # "train_microbatch_size": train_microbatch_size,
+    # }
+
+    # DAPO variant settings
+    grpo_config = {
+        "filter_range": [0.05, 9.0],
+        "lr": 1e-6,
+        "train_microbatch_size": args.train_microbatch_size,
+        "clip_eps_low": 0.2,
+        "clip_eps_high": 0.28,
+        "skip_threshold": 20.0,
+        "beta": 0.0, # no KL penalty
+        "loss_variation": "token_level",
+        "soft_over_length_punishment": True,
+        "max_length": 1024 * 2,
+        "cache_length": 256,
+        "filter_truncated_response": True,
+    }
+
     launch_distributed(
         num_producers=args.num_inferencer,
         num_proc_per_producer=1,
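Under these settings the consumer optimizes a clip-higher, token-level surrogate with no KL term. Written out in the usual GRPO/DAPO notation (with r_{i,t} the per-token probability ratio and A_hat_{i,t} the group-normalized advantage; the symbols are assumed here rather than taken from the code), the per-batch loss is roughly:

\mathcal{L} = -\frac{1}{\sum_i |o_i|} \sum_{i} \sum_{t=1}^{|o_i|}
\min\!\Big( r_{i,t}\,\hat{A}_{i,t},\ \mathrm{clip}\big(r_{i,t},\, 1-\varepsilon_{\text{low}},\, 1+\varepsilon_{\text{high}}\big)\,\hat{A}_{i,t} \Big),
\qquad \varepsilon_{\text{low}} = 0.2,\ \varepsilon_{\text{high}} = 0.28,

i.e. the ratio is clamped to [0.8, 1.28] before being multiplied by the advantage, and the soft over-length penalty and truncation filter above act on the rewards and loss mask rather than on this objective directly.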
@@ -118,14 +141,17 @@ if __name__ == "__main__":
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        # plugin_config={}, # for zero
+        grpo_config=grpo_config,
         plugin_config={
-            "pp_size": 2,
-            "tp_size": 2,
-            "microbatch_size": args.train_microbatch_size // 2,
-            "zero_stage": 0,
-            "max_norm": 1.0,
-        }, # for pp
+            "zero_stage": 2,
+        }, # for zero
+        # plugin_config={
+        # "pp_size": 2,
+        # "tp_size": 2,
+        # "microbatch_size": args.train_microbatch_size // 2,
+        # "zero_stage": 0,
+        # "max_norm": 1.0,
+        # }, # for pp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=29506,