Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-23 19:49:30 +00:00)
hotfix entropy calculation
parent 4cf5ce20bf
commit b6db7da815
@@ -250,6 +250,9 @@ class GRPOConsumer(BaseConsumer):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
@@ -306,17 +309,22 @@ class GRPOConsumer(BaseConsumer):
                         "action_mask": action_mask_forward_micro_batch,
                         "advantages": advantages_forward_micro_batch,
                         "loss_mask": loss_mask_forward_micro_batch,
+                        "old_action_log_probs": old_action_log_probs_micro_batch,
                         "source": self.rank,
                     }
                     if reference_action_log_probs is not None:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs

                     kl = []
-                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)

                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
-                        policy_model_logits.copy_(action_logits)
+                        mini_batch_entropies.append(
+                            (
+                                ((entropy_from_logits(action_logits[:, -num_action:]) * inputs["action_mask"]).sum(-1))
+                                / inputs["action_mask"].sum(-1)
+                            ).detach()
+                        )
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
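The lines added inside _criterion turn the current micro-batch's logits into one entropy value per sequence: per-token entropies over the response span, masked by action_mask and averaged per sequence. Below is a minimal sketch (not the repository's code) of that computation; it assumes entropy_from_logits returns the per-token entropy over the vocabulary, and the helper names are illustrative only.

import torch
import torch.nn.functional as F

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab] -> per-token entropy [batch, seq]
    logp = F.log_softmax(logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)

def per_sequence_entropy(action_logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # Restrict to the response (action) tokens, zero out padding via the mask,
    # then take a per-sequence mean: sum_t(H_t * mask_t) / sum_t(mask_t).
    # Detached, because the entropy is only logged, not optimized.
    num_action = action_mask.size(1)
    token_entropy = entropy_from_logits_sketch(action_logits[:, -num_action:])
    return ((token_entropy * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)).detach()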
@@ -339,7 +347,7 @@ class GRPOConsumer(BaseConsumer):

                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
@@ -363,20 +371,6 @@ class GRPOConsumer(BaseConsumer):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                        mini_batch_entropies.append(
-                            all_reduce_mean(
-                                (
-                                    (
-                                        (
-                                            entropy_from_logits(policy_model_logits[:, -num_action:])
-                                            * action_mask_forward_micro_batch
-                                        ).sum(-1)
-                                    )
-                                    / action_mask_forward_micro_batch.sum(-1)
-                                ).detach(),
-                                self.plugin,
-                            )
-                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -415,7 +409,7 @@ class GRPOConsumer(BaseConsumer):

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
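Both call sites above now pass the stored old (behaviour-policy) log probabilities as the second argument instead of repeating the current action_log_probs. The exact implementation of self.policy_loss_fn is not shown in this diff; assuming a standard PPO/GRPO-style clipped surrogate, the distinction matters because the importance ratio exp(logp - old_logp) degenerates to all ones when both arguments are the same tensor, so clipping never engages. A generic, hedged sketch of such a loss:

import torch

def clipped_surrogate_sketch(log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
    # Per-token importance ratio between the current policy and the policy
    # that generated the rollout; identical inputs would make this all ones.
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token_loss = -torch.min(surr1, surr2)
    # Masked mean over response tokens only.
    return (per_token_loss * action_mask).sum() / action_mask.sum()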
@@ -455,7 +449,7 @@ class GRPOConsumer(BaseConsumer):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                 self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
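With the per-micro-batch all_reduce_mean removed from the pipeline branch (the deleted block above), the final hunk performs the cross-rank averaging once: concatenate the per-micro-batch entropies, take the local mean, and all-reduce it. A rough standalone equivalent using plain torch.distributed, where dist_avg stands in for the repository's all_reduce_mean(tensor, plugin) helper (names below are illustrative):

import torch
import torch.distributed as dist

def dist_avg(x: torch.Tensor) -> torch.Tensor:
    # Average a tensor across all data-parallel ranks (assumes the default
    # process group has already been initialized by the trainer).
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        x = x / dist.get_world_size()
    return x

def aggregate_entropy(mini_batch_entropies: list[torch.Tensor]) -> torch.Tensor:
    # mini_batch_entropies: one 1-D tensor of per-sequence entropies per micro-batch.
    local_mean = torch.cat(mini_batch_entropies, dim=0).mean()
    return dist_avg(local_mean)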