mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -211,10 +211,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
loss_mask,
|
||||
action_mask[:, -1] == False,
|
||||
)
|
||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False)==False:
|
||||
if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
|
||||
# filter out samples with reward outside the range
|
||||
# if dynamic batching is enabled, we filter out out of range groups before training
|
||||
group_ans_acc_mean = ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
|
||||
group_ans_acc_mean = (
|
||||
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
|
||||
)
|
||||
loss_mask = torch.logical_and(
|
||||
loss_mask,
|
||||
torch.logical_and(
|
||||
@@ -454,7 +456,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
|
||||
raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
|
||||
raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
|
||||
overlength_samples_ratio = (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item() # not an exact figure, but a close estimate
|
||||
overlength_samples_ratio = (
|
||||
(raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
|
||||
) # not an exact figure, but a close estimate
|
||||
self.raw_train_batch_reward = []
|
||||
self.raw_train_batch_format_acc = []
|
||||
self.raw_train_batch_ans_acc = []
|
||||
|
Reference in New Issue
Block a user