This commit is contained in:
YeAnbang 2025-07-21 18:04:20 +08:00
parent e774edeb80
commit 5c5cb1863b

View File

@ -371,20 +371,6 @@ class GRPOConsumer(BaseConsumer):
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
mini_batch_entropies.append(
all_reduce_mean(
(
(
(
entropy_from_logits(policy_model_logits[:, -num_action:])
* action_mask_forward_micro_batch
).sum(-1)
)
/ action_mask_forward_micro_batch.sum(-1)
).detach(),
self.plugin,
)
)
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,