fix filtering, still buggy

YeAnbang 2025-04-16 14:10:43 +08:00
parent 24c01864df
commit cc4faa7300
6 changed files with 67 additions and 85 deletions

View File

@@ -63,8 +63,8 @@ class GRPOConsumer(BaseConsumer):
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = 0
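The renamed accum_format_acc / accum_ans_acc buffers now track accuracy fractions rather than reward points, but the bookkeeping is unchanged: each step adds a batch mean, accum_count counts the batches, and logging divides the two before zeroing everything. A minimal standalone sketch of that pattern (the class and names below are illustrative, not the consumer's API):

import torch

class MetricAccumulator:
    """Running average of per-batch means, mirroring the accum_* buffers."""

    def __init__(self, device: str = "cpu"):
        self.total = torch.zeros(1, device=device)
        self.count = 0

    def add(self, batch_mean: torch.Tensor) -> None:
        self.total.add_(batch_mean.detach())
        self.count += 1

    def average(self) -> float:
        return (self.total / max(self.count, 1)).item()

    def reset(self) -> None:
        self.total.zero_()
        self.count = 0

ans_acc = MetricAccumulator()
ans_acc.add(torch.tensor(0.25))
ans_acc.add(torch.tensor(0.75))
print(ans_acc.average())  # 0.5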
@@ -72,10 +72,19 @@ class GRPOConsumer(BaseConsumer):
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
self.total_sample_count = 0
self.policy_loss_fn = PolicyLoss(
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
# Reference model is initialized from policy model.
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
if self.policy_loss_fn.beta > 0:
self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.reference_model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(path)
self.pad_token_id = self.tokenizer.pad_token_id
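Gating the reference model on beta > 0 means a second copy of the weights is only loaded (and later boosted) when the KL penalty is actually used. A rough sketch of the idea, assuming a Hugging Face causal LM; the helper name is illustrative:

from transformers import AutoModelForCausalLM

def maybe_load_reference_model(path: str, beta: float, **model_config):
    """Only load a frozen reference policy when the KL coefficient is nonzero."""
    if beta <= 0:
        return None  # KL term disabled -> no reference model needed
    ref = AutoModelForCausalLM.from_pretrained(path, **model_config)
    ref.eval()
    for p in ref.parameters():
        p.requires_grad_(False)
    return ref

The same guard has to appear at every later use site (boosting, the per-token KL computation, and the KL logging), which is why the beta > 0 condition shows up again further down in this diff.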
@@ -108,14 +117,6 @@ class GRPOConsumer(BaseConsumer):
self.reward_model = VerifiableReward(
reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags, **reward_model_kwargs
)
self.policy_loss_fn = PolicyLoss(
clip_eps_low=grpo_config.get("clip_eps_low", 0.2),
clip_eps_high=grpo_config.get("clip_eps_high", 0.2),
skip_threshold=grpo_config.get("skip_threshold", 20.0),
beta=grpo_config.get("beta", 0.01),
loss_variation=grpo_config.get("loss_variation", "sample_level"),
)
self.global_step = 0
self.use_wandb = use_wandb
@@ -139,7 +140,8 @@ class GRPOConsumer(BaseConsumer):
self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
)
self.reference_model, *_ = self.booster.boost(self.reference_model)
if self.policy_loss_fn.beta > 0:
self.reference_model, *_ = self.booster.boost(self.reference_model)
self.plugin.logger.set_level("ERROR")
def step(self, step_idx: int, **kwargs) -> Optional[float]:
@@ -165,15 +167,14 @@ class GRPOConsumer(BaseConsumer):
forward_batch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
reward_group = self.reward_model(
int(step_idx / self.num_microbatches),
data["input_ids"],
gt_answer=data["gt_answer"],
response_idx=data["response_idx"],
)
reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
@@ -186,18 +187,13 @@ class GRPOConsumer(BaseConsumer):
# [batch_size x num_generations]
advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
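For reference, the advantage above is the standard GRPO group normalization: each rollout's reward is centered and scaled by the statistics of the num_generations rollouts that share its prompt. A small self-contained sketch (epsilon and shapes are assumptions):

import torch

def group_normalized_advantages(reward: torch.Tensor, num_generations: int, eps: float = 1e-4):
    """reward: flat [batch_size * num_generations] tensor, rollouts grouped by prompt."""
    grouped = reward.view(-1, num_generations)        # [batch_size, num_generations]
    mean = grouped.mean(dim=1, keepdim=True)          # per-prompt mean
    std = grouped.std(dim=1, keepdim=True)            # per-prompt std
    return ((grouped - mean) / (std + eps)).view(-1)  # normalized within each group

r = torch.tensor([10.0, 10.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0])
print(group_normalized_advantages(r, num_generations=4))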
# filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
reward_mean_no_length_penalty = (
(format_reward + acc_reward)
.view(-1, self.num_generations)
.mean(dim=1)
.repeat_interleave(self.num_generations, dim=0)
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(
reward_mean_no_length_penalty > self.filter_range[0], reward_mean < self.filter_range[1]
)
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
)
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
@@ -205,21 +201,23 @@ class GRPOConsumer(BaseConsumer):
loss_mask,
action_mask[:, -1] == False,
)
# for i in range(loss_mask.size(0)):
# if loss_mask[i] == False:
# print(data["input_ids"].size(), data["input_ids"][i], action_mask[i], "mean reward", reward_mean_no_length_penalty.size(), reward_mean_no_length_penalty[i])
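This is the actual fix: the mask is now built from the group's mean answer accuracy (group_ans_acc) instead of the length-penalized reward mean, so with a filter_range like [0.01, 0.99] a prompt group is dropped exactly when every rollout is wrong or every rollout is correct, i.e. when the group provides no contrast for the advantage. A minimal sketch of the masking, with illustrative shapes:

import torch

def accuracy_filter_mask(ans_acc: torch.Tensor, num_generations: int, filter_range=(0.01, 0.99)):
    """ans_acc: [batch_size * num_generations] binary per-rollout correctness."""
    group_acc = (
        ans_acc.view(-1, num_generations)
        .mean(dim=1)
        .repeat_interleave(num_generations, dim=0)   # broadcast the group mean back to every rollout
    )
    low, high = filter_range
    return torch.logical_and(group_acc > low, group_acc < high)

# group 1: all correct (filtered out), group 2: mixed (kept)
acc = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0])
print(accuracy_filter_mask(acc, num_generations=4))
# tensor([False, False, False, False,  True,  True,  True,  True])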
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
print(
loss_mask,
self.effective_sample_count,
self.total_sample_count,
self.batch_size * self.dp_size * self.num_generations * 0.75,
)
mean_kl, mean_loss = [], []
# update gradients only once at least 0.75 * batch_size * dp_size * num_generations valid samples have been collected, in case many samples are invalid and get filtered out.
# this balances efficiency against accuracy.
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations * 0.75
if need_update:
print(f"***** Update gradient based on {self.effective_sample_count} valid samples *****")
self.effective_sample_count = 0
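Since filtering throws away a variable fraction of each rollout batch, the consumer now keeps accumulating batches and only takes an optimizer step once roughly 75% of batch_size * dp_size * num_generations valid samples have been seen; the counts are summed across data-parallel ranks first (the all_reduce_sum calls above) so every rank makes the same decision. A toy sketch of that gating, with a fake sampler standing in for the rollout pipeline (names are illustrative):

import random

def training_loop(batches, target_effective: int):
    """Accumulate valid samples across batches; 'step' only when the target is reached."""
    effective, total = 0, 0
    for loss_mask in batches:                  # loss_mask: list of bools, one per sample
        effective += sum(loss_mask)            # samples that survived filtering
        total += len(loss_mask)
        if effective >= target_effective:      # enough signal accumulated -> take a step
            utilization = effective / total    # fraction of rollouts actually trained on
            print(f"step: {effective} valid / {total} total (utilization {utilization:.2f})")
            effective, total = 0, 0

random.seed(0)
fake_batches = [[random.random() > 0.4 for _ in range(8)] for _ in range(6)]
training_loop(fake_batches, target_effective=12)

The utilization printed here corresponds to the train/sample_utilization metric logged further down in this diff.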
# Gradient must be synchronized if zero2 is enabled. https://github.com/hpcaitech/ColossalAI/blob/44d4053fec005fe0b06b6bc755fdc962463145df/colossalai/booster/plugin/hybrid_parallel_plugin.py#L1500
ctx = (
@@ -319,7 +317,7 @@ class GRPOConsumer(BaseConsumer):
per_token_kl = 0.0
kl.append(0.0)
loss, skip_update, _ = self.policy_loss_fn(
loss, _ = self.policy_loss_fn(
action_log_probs,
action_log_probs,
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
@@ -381,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
per_token_kl = 0.0
kl = None
loss, skip_update, _ = self.policy_loss_fn(
loss, _ = self.policy_loss_fn(
action_log_probs,
old_action_log_probs,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
@@ -390,8 +388,7 @@ class GRPOConsumer(BaseConsumer):
loss_mask=loss_mask_forward_micro_batch,
)
if not skip_update:
self.booster.backward(loss, self.optimizer)
self.booster.backward(loss, self.optimizer)
loss = all_reduce_mean(loss, self.plugin)
# Calculate accumulate value.
if kl is not None:
@@ -402,22 +399,25 @@ class GRPOConsumer(BaseConsumer):
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
if self.policy_loss_fn.beta > 0:
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
self.accum_reward.add_(reward.data)
self.accum_format_reward.add_(format_reward.data)
self.accum_acc_reward.add_(acc_reward.data)
self.accum_format_acc.add_(format_acc.data)
self.accum_ans_acc.add_(ans_acc.data)
self.accum_advantages.add_(advantages.data)
self.accum_response_length.add_(response_length.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
self.optimizer.zero_grad()
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_sample_count = 0
self.total_sample_count = 0
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
@@ -428,8 +428,8 @@ class GRPOConsumer(BaseConsumer):
to_log_msg = (
f"Loss: {self.accum_loss.item() / self.accum_count:.4f} \
Reward: {self.accum_reward.item() / self.accum_count:.4f} \
Format Reward: {self.accum_format_reward.item() / self.accum_count:.4f} \
Acc Reward: {self.accum_acc_reward.item() / self.accum_count:.4f} \
Format Reward: {self.accum_format_acc.item() / self.accum_count:.4f} \
Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f} \
Advantages: {self.accum_advantages.item() / self.accum_count:.4f} \
Response Length: {self.accum_response_length.item() / self.accum_count:.4f}"
+ f" KL: {self.accum_kl.item() / self.accum_count:.4f}"
@@ -439,12 +439,13 @@ class GRPOConsumer(BaseConsumer):
print(to_log_msg)
metrics = {
"metrics/reward": self.accum_reward.item() / self.accum_count,
"metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
"metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
"metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
"metrics/response_length": self.accum_response_length.item() / self.accum_count,
"train/loss": self.accum_loss.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
"train/sample_utilization": sample_utilization,
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
}
if self.policy_loss_fn.beta > 0:
@@ -453,12 +454,11 @@ class GRPOConsumer(BaseConsumer):
self.wandb_run.log(metrics)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_acc_reward.zero_()
self.accum_format_reward.zero_()
self.accum_ans_acc.zero_()
self.accum_format_acc.zero_()
self.accum_kl.zero_()
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar
@@ -507,8 +507,8 @@ class GRPOEvalConsumer(BaseConsumer):
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
self.policy_model.train()
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_format_reward = torch.zeros(1, device=self.device)
self.accum_acc_reward = torch.zeros(1, device=self.device)
self.accum_format_acc = torch.zeros(1, device=self.device)
self.accum_ans_acc = torch.zeros(1, device=self.device)
self.accum_response_length = torch.zeros(1, device=self.device)
self.accum_count = torch.zeros(1, device=self.device)
@@ -545,8 +545,8 @@ class GRPOEvalConsumer(BaseConsumer):
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
reward = [value[0].item() for value in reward_group]
format_reward = [value[1].item() for value in reward_group]
acc_reward = [value[2].item() for value in reward_group]
format_acc = [value[1].item() for value in reward_group]
ans_acc = [value[2].item() for value in reward_group]
response_length = [(data["response_idx"][i][1] - data["response_idx"][i][0]).item() for i in range(len(reward))]
response = self.tokenizer.batch_decode(data["input_ids"], skip_special_tokens=True)
@@ -557,8 +557,8 @@ class GRPOEvalConsumer(BaseConsumer):
{
"response": response[i],
"reward": reward[i],
"format_reward": format_reward[i],
"acc_reward": acc_reward[i],
"format_acc": format_acc[i],
"ans_acc": ans_acc[i],
"response_length": response_length[i],
},
ensure_ascii=False,
@@ -567,20 +567,20 @@ class GRPOEvalConsumer(BaseConsumer):
)
self.accum_reward += sum(reward)
self.accum_format_reward += sum(format_reward)
self.accum_acc_reward += sum(acc_reward)
self.accum_format_acc += sum(format_acc)
self.accum_ans_acc += sum(ans_acc)
self.accum_response_length += sum(response_length)
self.accum_count += len(reward)
# print results
total_count = all_reduce_mean(self.accum_count, self.plugin)
mean_reward = all_reduce_mean(self.accum_reward, self.plugin) / total_count
mean_format_reward = all_reduce_mean(self.accum_format_reward, self.plugin) / total_count
mean_acc_reward = all_reduce_mean(self.accum_acc_reward, self.plugin) / total_count
mean_format_acc = all_reduce_mean(self.accum_format_acc, self.plugin) / total_count
mean_ans_acc = all_reduce_mean(self.accum_ans_acc, self.plugin) / total_count
mean_response_length = all_reduce_mean(self.accum_response_length, self.plugin) / total_count
if rank == 0:
print(
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_reward}, Mean Acc Reward: {mean_acc_reward}, Mean Response Length: {mean_response_length}"
f"Step {step_idx}: Mean Reward: {mean_reward}, Mean Format Reward: {mean_format_acc}, Mean Acc Reward: {mean_ans_acc}, Mean Response Length: {mean_response_length}"
)
return None

View File

@@ -14,14 +14,12 @@ class PolicyLoss(nn.Module):
self,
clip_eps_low: float = 0.2,
clip_eps_high: float = 0.2,
skip_threshold: float = 20.0,
beta: float = 0.01,
loss_variation: str = "sample_level",
) -> None:
super().__init__()
self.clip_eps_low = clip_eps_low
self.clip_eps_high = clip_eps_high
self.skip_threshold = skip_threshold
self.beta = beta
self.loss_variation = loss_variation
assert loss_variation in ["sample_level", "token_level"], f"Unsupported loss variation: {loss_variation}"
@@ -35,7 +33,6 @@ class PolicyLoss(nn.Module):
action_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
if action_mask is None:
ratio = (log_probs - log_probs.detach()).exp()
else:
@@ -43,7 +40,7 @@ class PolicyLoss(nn.Module):
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps_low, 1 + self.clip_eps_high) * advantages
if self.beta <= 0:
if self.beta == 0:
# skip kl term if kl coefficient is zero
per_token_kl = 0.0
loss = -torch.min(surr1, surr2) + self.beta * per_token_kl
@@ -68,5 +65,7 @@ class PolicyLoss(nn.Module):
loss = loss * loss_mask
total_tokens = total_tokens * loss_mask
loss = loss.sum() / (total_tokens.sum() + 1e-8)
else:
raise ValueError(f"Unsupported loss variation: {self.loss_variation}")
return loss, skip, ratio.max()
return loss, ratio.max()
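With skip_threshold and the skip flag removed, the loss module simply returns the clipped-surrogate loss (plus the optional KL term) and the maximum ratio; protection against degenerate batches now lives in the consumer's filtering instead of in skipped updates. A condensed sketch of the sample-level variant of such a loss (shapes and the masked averaging are assumptions, not the module's exact internals):

import torch

def grpo_policy_loss(log_probs, old_log_probs, advantages, action_mask,
                     clip_eps_low=0.2, clip_eps_high=0.2, per_token_kl=None, beta=0.0):
    """Clipped-surrogate policy loss with sample-level aggregation.

    log_probs, old_log_probs, advantages, action_mask: [batch, seq_len];
    old_log_probs is treated as a constant (no gradient).
    """
    ratio = ((log_probs - old_log_probs.detach()) * action_mask).exp()
    surr1 = ratio * advantages
    surr2 = ratio.clamp(1 - clip_eps_low, 1 + clip_eps_high) * advantages
    loss = -torch.min(surr1, surr2)
    if beta > 0 and per_token_kl is not None:
        loss = loss + beta * per_token_kl
    # sample level: mean over each response's tokens, then mean over samples
    per_sample = (loss * action_mask).sum(dim=1) / (action_mask.sum(dim=1) + 1e-8)
    return per_sample.mean(), ratio.max()

In the consumer, loss_mask additionally zeroes out filtered samples before the averaging, so they contribute nothing to the gradient.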

View File

@@ -9,8 +9,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
format_score = 0.0
acc_score = 10.0
reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0)
acc_reward = torch.tensor(0.0)
format_acc = torch.tensor(0.0)
ans_acc = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1]
length_reward = 0.0
@@ -32,7 +32,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
# Check format accuracy
if format_valid:
format_reward += format_score
format_acc += 1
reward += format_score
# Check answer accuracy
@@ -40,12 +40,12 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
):
acc_reward += acc_score
ans_acc += 1
reward += acc_score
reward = reward + length_reward
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
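After this rename the reward function's contract is [reward, format_acc, ans_acc], where the last two entries are 0/1 indicators (format matched / answer matched) rather than weighted score points; that is what lets the consumer average them into a group accuracy for filtering. A stripped-down sketch of a function with that contract (the format check and answer extraction are stand-ins, not the repo's verification utilities):

import torch

def toy_math_reward_fn(decoded_response: str, gt_answer: str,
                       format_score: float = 0.0, acc_score: float = 10.0):
    """Return [reward, format_acc, ans_acc]; the last two are 0/1 flags."""
    reward, format_acc, ans_acc = 0.0, 0.0, 0.0

    # stand-in format check: the answer must be wrapped in \boxed{...}
    format_valid = "\\boxed{" in decoded_response
    if format_valid:
        format_acc += 1
        reward += format_score

    # stand-in answer extraction and comparison
    final_answer = decoded_response.split("\\boxed{")[-1].rstrip("}") if format_valid else None
    if final_answer is not None and final_answer.strip().lower() == gt_answer.strip().lower():
        ans_acc += 1
        reward += acc_score

    return torch.tensor([reward, format_acc, ans_acc])

print(toy_math_reward_fn("The answer is \\boxed{42}", "42"))  # tensor([10.,  1.,  1.])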
def gsm8k_reward_fn(input_ids, **kwargs):

View File

@@ -14,7 +14,6 @@ class VerifiableReward:
def __call__(
self,
step: int,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
response_idx: List[torch.Tensor] = None,
@@ -30,7 +29,6 @@ class VerifiableReward:
reward_batch = torch.stack(
[
reward_fn(
step,
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],

View File

@@ -128,20 +128,6 @@ def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor
return tensor
# def all_reduce_sum(tensor: torch.Tensor, ) -> torch.Tensor:
# """
# Performs an all-reduce operation to sum the values of the given tensor across all processes.
# Args:
# tensor (torch.Tensor): The input tensor to be reduced.
# Returns:
# torch.Tensor: The reduced tensor with the sum of values across all processes.
# """
# dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
# return tensor
def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
"""
Performs an all-reduce operation to sum the values of the given tensor across all processes.

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
# DAPO variant settings
grpo_config = {
"filter_range": [0.05, 9.0],
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": 1e-6,
"train_microbatch_size": args.train_microbatch_size,
"clip_eps_low": 0.2,
@@ -144,8 +144,7 @@ if __name__ == "__main__":
grpo_config=grpo_config,
plugin_config={
"zero_stage": 2,
},
# for zero
}, # for zero
# plugin_config={
# "pp_size": 2,
# "tp_size": 2,