[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-05-29 10:16:55 +00:00
parent 7b921acc8a
commit 0d008110e7
2 changed files with 34 additions and 16 deletions

View File

@@ -211,10 +211,12 @@ class GRPOConsumer(BaseConsumer):
loss_mask,
action_mask[:, -1] == False,
)
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False)==False:
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
# filter out samples with reward outside the range
# if dynamic batching is enabled, we filter out out of range groups before training
group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
group_ans_acc_mean = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
)
loss_mask = torch.logical_and(
loss_mask,
torch.logical_and(
@@ -454,7 +456,9 @@ class GRPOConsumer(BaseConsumer):
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate
overlength_samples_ratio = (
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
) # not an exact figure, but a close estimate
self.raw_train_batch_reward = []
self.raw_train_batch_format_acc = []
self.raw_train_batch_ans_acc = []