hotfix entropy calculation

YeAnbang 2025-07-21 17:58:48 +08:00
parent 4cf5ce20bf
commit b6db7da815

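For context, a minimal sketch of the masked per-sequence entropy that the relocated computation inside _criterion produces, assuming entropy_from_logits is the standard per-token softmax entropy over the vocabulary; masked_mean_entropy is a hypothetical helper written for illustration, not code from this repository.

import torch

def masked_mean_entropy(action_logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # Per-token softmax entropy over the vocabulary: H_t = -sum_v p_v * log p_v.
    log_probs = torch.log_softmax(action_logits, dim=-1)
    token_entropy = -(log_probs.exp() * log_probs).sum(dim=-1)  # [batch, num_action]
    # Zero out non-action (padded) positions, then average over valid tokens per sequence,
    # mirroring (entropy * action_mask).sum(-1) / action_mask.sum(-1) in the diff below.
    return (token_entropy * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)

The diff computes this quantity from the live logits inside the pipeline criterion and, in the final hunk, averages it across ranks with all_reduce_mean instead of taking only a local mean.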

@@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer):
input_ids_forward_micro_batch = data["input_ids"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
+ old_action_log_probs_micro_batch = old_action_log_probs[
+ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+ ]
attention_mask_forward_micro_batch = data["attention_mask"][
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
]
@@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer):
"action_mask": action_mask_forward_micro_batch,
"advantages": advantages_forward_micro_batch,
"loss_mask": loss_mask_forward_micro_batch,
"old_action_log_probs": old_action_log_probs_micro_batch,
"source": self.rank,
}
if reference_action_log_probs is not None:
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
kl = []
- policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
def _criterion(outputs, inputs):
action_logits = outputs.logits
- policy_model_logits.copy_(action_logits)
+ mini_batch_entropies.append(
+ (
+ ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+ / inputs["action_mask"].sum(-1)
+ ).detach()
+ )
action_log_probs = memory_efficient_logprob(
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
@@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
- action_log_probs,
+ inputs["old_action_log_probs"],
inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
inputs["action_mask"],
@@ -363,20 +371,6 @@ class GRPOConsumer(BaseConsumer):
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
mean_kl.append(kl)
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
- mini_batch_entropies.append(
- all_reduce_mean(
- (
- (
- (
- entropy_from_logits(policy_model_logits[:, -num_action:])
- * action_mask_forward_micro_batch
- ).sum(-1)
- )
- / action_mask_forward_micro_batch.sum(-1)
- ).detach(),
- self.plugin,
- )
- )
else:
policy_model_logits = self.policy_model(
input_ids=input_ids_forward_micro_batch,
@@ -415,7 +409,7 @@ class GRPOConsumer(BaseConsumer):
loss, _ = self.policy_loss_fn(
action_log_probs,
- old_action_log_probs,
+ old_action_log_probs_micro_batch,
advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
per_token_kl,
action_mask_forward_micro_batch,
@@ -455,7 +449,7 @@ class GRPOConsumer(BaseConsumer):
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
advantages = all_reduce_mean(advantages.mean(), self.plugin)
response_length = all_reduce_mean(response_length.mean(), self.plugin)
- entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+ entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
self.accum_entropy.add_(entropy.data)
if self.policy_loss_fn.beta > 0:
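The other half of the fix concerns which log-probabilities reach self.policy_loss_fn: the new old_action_log_probs_micro_batch slice (and inputs["old_action_log_probs"] in the pipeline branch) supplies the rollout-time values for the same micro-batch. As a rough reference only, a generic clipped-surrogate sketch of the quantity such a loss typically builds from these arguments; the function name, clip_eps default, and normalization below are assumptions, not ColossalAI's PolicyLoss implementation.

import torch

def clipped_surrogate_loss(log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
    # Importance ratio between the current policy and the policy that generated the
    # rollout; old_log_probs must be the rollout-time values for this micro-batch slice.
    ratio = torch.exp(log_probs - old_log_probs.detach())
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Token-level masking and normalization over valid action tokens.
    loss = -torch.min(unclipped, clipped)
    return (loss * action_mask).sum() / action_mask.sum().clamp(min=1)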