From f54ae56f12282a129a7ae16466af3dcd9d9fee4f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 16 Jul 2025 16:44:23 +0800 Subject: [PATCH] add entropy --- .../coati/distributed/grpo_consumer.py | 39 ++++++++++++++++++- .../ColossalChat/coati/distributed/utils.py | 10 +++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index f8ce1afde..754f78097 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -6,7 +6,7 @@ import torch import wandb from coati.distributed.consumer import BaseConsumer from coati.distributed.loss import PolicyLoss -from coati.distributed.utils import memory_efficient_logprob +from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob from coati.trainer.utils import all_reduce_mean, all_reduce_sum from transformers import AutoModelForCausalLM, AutoTokenizer @@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer): self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6)) self.accum_loss = torch.zeros(1, device=self.device) self.accum_kl = torch.zeros(1, device=self.device) + self.accum_entropy = torch.zeros(1, device=self.device) self.accum_advantages = torch.zeros(1, device=self.device) self.raw_train_batch_reward = [] self.raw_train_batch_format_acc = [] @@ -244,6 +245,7 @@ class GRPOConsumer(BaseConsumer): else self.booster.no_sync(self.policy_model, self.optimizer) ) with ctx: + mini_batch_entropies = [] for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size): input_ids_forward_micro_batch = data["input_ids"][ forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size @@ -310,9 +312,11 @@ class GRPOConsumer(BaseConsumer): data_policy_forward["reference_action_log_probs"] = reference_action_log_probs kl = [] + policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device) def _criterion(outputs, inputs): action_logits = outputs.logits + policy_model_logits.copy_(action_logits) action_log_probs = memory_efficient_logprob( action_logits / self.generate_config["temperature"], inputs["input_ids"], @@ -359,6 +363,20 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) else: policy_model_logits = self.policy_model( input_ids=input_ids_forward_micro_batch, @@ -412,6 +430,20 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(kl.mean(), self.plugin) mean_kl.append(kl.data) mean_loss.append(loss.data) + mini_batch_entropies.append( + all_reduce_mean( + ( + ( + ( + entropy_from_logits(policy_model_logits[:, -num_action:]) + * action_mask_forward_micro_batch + ).sum(-1) + ) + / action_mask_forward_micro_batch.sum(-1) + ).detach(), + self.plugin, + ) + ) if not self.plugin.pp_size > 1 or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() @@ -423,7 +455,9 @@ class GRPOConsumer(BaseConsumer): ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin) advantages = all_reduce_mean(advantages.mean(), 
                                                 self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)
@@ -464,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                     f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                    f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -475,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/entropy": self.accum_entropy.item() / self.accum_count,
                     "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
@@ -484,6 +520,7 @@ class GRPOConsumer(BaseConsumer):
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
             return loss_scalar
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index d46243114..466914cc0 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -110,6 +110,16 @@ def memory_efficient_logprob(
     return action_log_probs
 
 
+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate entropy
+    Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
+    """
+    p = torch.nn.functional.softmax(logits, dim=-1)
+    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
+    return entropy
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     Compute the masked mean of a tensor along a specified dimension.
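
For reference, entropy_from_logits relies on the identity H = logsumexp(z) - sum(softmax(z) * z), which is algebraically the same as the textbook -sum(p * log p) but avoids materializing the log-probabilities separately. Below is a minimal standalone sketch, not part of the patch, with illustrative tensor names and sizes (logits, action_mask, num_action are assumptions for the example); it checks that identity and mirrors the per-sequence masked mean over response tokens that feeds mini_batch_entropies in grpo_consumer.py.

import torch
import torch.nn.functional as F


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # H = logsumexp(z) - sum(softmax(z) * z); same value as -sum(p * log p).
    p = F.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len, num_action, vocab = 2, 8, 5, 11  # illustrative sizes, not from the repo
    logits = torch.randn(batch, seq_len, vocab)

    # Sanity check against the textbook definition -sum(p * log p).
    log_p = F.log_softmax(logits, dim=-1)
    reference = -(log_p.exp() * log_p).sum(-1)
    assert torch.allclose(entropy_from_logits(logits), reference, atol=1e-5)

    # Per-sequence mean entropy over the last num_action (response) tokens only,
    # mirroring (entropy * action_mask).sum(-1) / action_mask.sum(-1) in the patch.
    action_mask = torch.ones(batch, num_action)
    action_mask[:, -1] = 0.0  # pretend the last position is padding
    token_entropy = entropy_from_logits(logits[:, -num_action:])  # (batch, num_action)
    per_sequence_entropy = (token_entropy * action_mask).sum(-1) / action_mask.sum(-1)
    print(per_sequence_entropy)  # one averaged entropy value per sequence

Dividing by action_mask.sum(-1) averages only over generated tokens, so prompt and padding positions do not dilute the logged signal; the batch-level metric is then the mean of these per-sequence entropies, as in torch.cat(mini_batch_entropies, dim=0).mean().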