mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-23 03:33:05 +00:00
add entropy (#6363)
This commit is contained in:
parent
f5c155ab48
commit
4cf5ce20bf
@ -6,7 +6,7 @@ import torch
|
||||
import wandb
|
||||
from coati.distributed.consumer import BaseConsumer
|
||||
from coati.distributed.loss import PolicyLoss
|
||||
from coati.distributed.utils import memory_efficient_logprob
|
||||
from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
|
||||
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
|
||||
self.accum_loss = torch.zeros(1, device=self.device)
|
||||
self.accum_kl = torch.zeros(1, device=self.device)
|
||||
self.accum_entropy = torch.zeros(1, device=self.device)
|
||||
self.accum_advantages = torch.zeros(1, device=self.device)
|
||||
self.raw_train_batch_reward = []
|
||||
self.raw_train_batch_format_acc = []
|
||||
@ -244,6 +245,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
else self.booster.no_sync(self.policy_model, self.optimizer)
|
||||
)
|
||||
with ctx:
|
||||
mini_batch_entropies = []
|
||||
for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
|
||||
input_ids_forward_micro_batch = data["input_ids"][
|
||||
forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
|
||||
@ -310,9 +312,11 @@ class GRPOConsumer(BaseConsumer):
|
||||
data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
|
||||
|
||||
kl = []
|
||||
policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
|
||||
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
policy_model_logits.copy_(action_logits)
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
@ -359,6 +363,20 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
|
||||
mean_kl.append(kl)
|
||||
mean_loss.append(all_reduce_mean(loss, self.plugin).data)
|
||||
mini_batch_entropies.append(
|
||||
all_reduce_mean(
|
||||
(
|
||||
(
|
||||
(
|
||||
entropy_from_logits(policy_model_logits[:, -num_action:])
|
||||
* action_mask_forward_micro_batch
|
||||
).sum(-1)
|
||||
)
|
||||
/ action_mask_forward_micro_batch.sum(-1)
|
||||
).detach(),
|
||||
self.plugin,
|
||||
)
|
||||
)
|
||||
else:
|
||||
policy_model_logits = self.policy_model(
|
||||
input_ids=input_ids_forward_micro_batch,
|
||||
@ -412,6 +430,20 @@ class GRPOConsumer(BaseConsumer):
|
||||
kl = all_reduce_mean(kl.mean(), self.plugin)
|
||||
mean_kl.append(kl.data)
|
||||
mean_loss.append(loss.data)
|
||||
mini_batch_entropies.append(
|
||||
all_reduce_mean(
|
||||
(
|
||||
(
|
||||
(
|
||||
entropy_from_logits(policy_model_logits[:, -num_action:])
|
||||
* action_mask_forward_micro_batch
|
||||
).sum(-1)
|
||||
)
|
||||
/ action_mask_forward_micro_batch.sum(-1)
|
||||
).detach(),
|
||||
self.plugin,
|
||||
)
|
||||
)
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
@ -423,7 +455,9 @@ class GRPOConsumer(BaseConsumer):
|
||||
ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
|
||||
advantages = all_reduce_mean(advantages.mean(), self.plugin)
|
||||
response_length = all_reduce_mean(response_length.mean(), self.plugin)
|
||||
entropy = torch.cat(mini_batch_entropies, dim=0).mean()
|
||||
self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
|
||||
self.accum_entropy.add_(entropy.data)
|
||||
if self.policy_loss_fn.beta > 0:
|
||||
self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
|
||||
self.accum_advantages.add_(advantages.data)
|
||||
@ -464,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
f"Response Length: {raw_batch_response_len_mean:.4f}",
|
||||
f"Sample_utilization: {sample_utilization:.4f}",
|
||||
f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
|
||||
f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
|
||||
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
|
||||
print("\n".join(to_log_msg))
|
||||
metrics = {
|
||||
@ -475,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
"train/advantages": self.accum_advantages.item() / self.accum_count,
|
||||
"train/learning_rate": self.lr_scheduler.get_last_lr()[0],
|
||||
"train/sample_utilization": sample_utilization,
|
||||
"train/entropy": self.accum_entropy.item() / self.accum_count,
|
||||
"train/overlength_samples_ratio": overlength_samples_ratio,
|
||||
"rollout/temperature": data["temperature"].cpu().numpy()[0][0],
|
||||
}
|
||||
@ -484,6 +520,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.wandb_run.log(metrics)
|
||||
self.accum_loss.zero_()
|
||||
self.accum_kl.zero_()
|
||||
self.accum_entropy.zero_()
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_count = 0
|
||||
return loss_scalar
|
||||
|
@ -110,6 +110,16 @@ def memory_efficient_logprob(
|
||||
return action_log_probs
|
||||
|
||||
|
||||
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate entropy
|
||||
Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
|
||||
"""
|
||||
p = torch.nn.functional.softmax(logits, dim=-1)
|
||||
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
|
||||
return entropy
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
Compute the masked mean of a tensor along a specified dimension.
|
||||
|
Loading…
Reference in New Issue
Block a user