Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-23 11:44:15 +00:00)
add entropy (#6363)

commit 4cf5ce20bf (parent f5c155ab48)

The change threads an entropy metric through GRPOConsumer: a new entropy_from_logits helper computes per-token entropy from the policy logits, which is then masked to the action tokens, averaged per sequence and per micro-batch, accumulated across steps, and reported alongside the existing loss/KL metrics in both the console log and the wandb metrics.
@@ -6,7 +6,7 @@ import torch
 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -75,6 +75,7 @@ class GRPOConsumer(BaseConsumer):
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -244,6 +245,7 @@ class GRPOConsumer(BaseConsumer):
                 else self.booster.no_sync(self.policy_model, self.optimizer)
             )
             with ctx:
+                mini_batch_entropies = []
                 for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                     input_ids_forward_micro_batch = data["input_ids"][
                         forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -310,9 +312,11 @@ class GRPOConsumer(BaseConsumer):
                             data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                         kl = []
+                        policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                         def _criterion(outputs, inputs):
                             action_logits = outputs.logits
+                            policy_model_logits.copy_(action_logits)
                             action_log_probs = memory_efficient_logprob(
                                 action_logits / self.generate_config["temperature"],
                                 inputs["input_ids"],
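
The two added lines above preallocate a buffer next to the pipeline criterion and copy the policy logits into it from inside the closure, so the logits are still available for the entropy computation after the pipeline step returns. Below is a minimal single-process sketch of that capture pattern (not the booster/pipeline API itself); the shapes and the plain _criterion function are made up for illustration:

    import torch

    batch, seq, vocab = 2, 6, 11
    # Buffer allocated outside the closure; the closure writes into it in place.
    captured_logits = torch.zeros(batch, seq, vocab)

    def _criterion(action_logits: torch.Tensor) -> torch.Tensor:
        # Stash a copy of the logits in the enclosing buffer, then return a
        # scalar "loss" the way a pipeline criterion would.
        captured_logits.copy_(action_logits)
        return action_logits.mean()

    loss = _criterion(torch.randn(batch, seq, vocab))
    # captured_logits now holds the logits and can be reused, e.g. for entropy.
    print(loss.item(), captured_logits.shape)
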
@@ -359,6 +363,20 @@ class GRPOConsumer(BaseConsumer):
                                 kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                                 mean_kl.append(kl)
                             mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                            mini_batch_entropies.append(
+                                all_reduce_mean(
+                                    (
+                                        (
+                                            (
+                                                entropy_from_logits(policy_model_logits[:, -num_action:])
+                                                * action_mask_forward_micro_batch
+                                            ).sum(-1)
+                                        )
+                                        / action_mask_forward_micro_batch.sum(-1)
+                                    ).detach(),
+                                    self.plugin,
+                                )
+                            )
                     else:
                         policy_model_logits = self.policy_model(
                             input_ids=input_ids_forward_micro_batch,
@@ -412,6 +430,20 @@ class GRPOConsumer(BaseConsumer):
                             kl = all_reduce_mean(kl.mean(), self.plugin)
                             mean_kl.append(kl.data)
                         mean_loss.append(loss.data)
+                        mini_batch_entropies.append(
+                            all_reduce_mean(
+                                (
+                                    (
+                                        (
+                                            entropy_from_logits(policy_model_logits[:, -num_action:])
+                                            * action_mask_forward_micro_batch
+                                        ).sum(-1)
+                                    )
+                                    / action_mask_forward_micro_batch.sum(-1)
+                                ).detach(),
+                                self.plugin,
+                            )
+                        )
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1
                     and self.booster.plugin.stage_manager.is_last_stage()
@@ -423,7 +455,9 @@ class GRPOConsumer(BaseConsumer):
                     ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                     advantages = all_reduce_mean(advantages.mean(), self.plugin)
                     response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                    entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                     self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                    self.accum_entropy.add_(entropy.data)
                     if self.policy_loss_fn.beta > 0:
                         self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                     self.accum_advantages.add_(advantages.data)
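
Putting the two branches together, here is a standalone sketch of the metric accumulated above: per-token entropies are masked to the action tokens, averaged per sequence, collected per micro-batch, then concatenated and averaged. The shapes and mask below are invented for illustration, and the cross-rank all_reduce_mean is omitted because this is a single-process sketch:

    import torch

    def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
        # Same formula as the helper added in the last hunk below:
        # H = logsumexp(logits) - sum(softmax(logits) * logits), per position.
        p = torch.nn.functional.softmax(logits, dim=-1)
        return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)

    batch, num_action, vocab = 2, 5, 11
    logits = torch.randn(batch, num_action, vocab)  # stands in for policy_model_logits[:, -num_action:]
    action_mask = torch.ones(batch, num_action)     # stands in for action_mask_forward_micro_batch
    action_mask[1, 3:] = 0                          # pretend the second response is shorter

    # Per-sequence mean entropy over valid action tokens (one entry per micro-batch).
    mini_batch_entropies = [
        ((entropy_from_logits(logits) * action_mask).sum(-1) / action_mask.sum(-1)).detach()
    ]

    # Aggregated exactly as above before being added to accum_entropy.
    entropy = torch.cat(mini_batch_entropies, dim=0).mean()
    print(entropy.item())
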
@@ -464,6 +498,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                     f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                    f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -475,6 +510,7 @@ class GRPOConsumer(BaseConsumer):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/entropy": self.accum_entropy.item() / self.accum_count,
                     "train/overlength_samples_ratio": overlength_samples_ratio,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
@@ -484,6 +520,7 @@ class GRPOConsumer(BaseConsumer):
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
             return loss_scalar
The last hunk adds the entropy helper next to memory_efficient_logprob, i.e. in coati.distributed.utils:

@@ -110,6 +110,16 @@ def memory_efficient_logprob(
     return action_log_probs
 
 
+def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate entropy
+    Reference: https://github.com/volcengine/verl/blob/96b730bbed80292a439f0c0057d3920ab8b28d52/verl/utils/torch_functional.py#L145
+    """
+    p = torch.nn.functional.softmax(logits, dim=-1)
+    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)
+    return entropy
+
+
 def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
     """
     Compute the masked mean of a tensor along a specified dimension.
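
As a sanity check (not part of the commit), the logsumexp form above is algebraically the Shannon entropy of the softmax distribution: since log p_i = logits_i - logsumexp(logits), it follows that -sum_i p_i log p_i = logsumexp(logits) - sum_i p_i * logits_i, so no explicit log-probabilities are needed. A quick comparison against torch.distributions.Categorical on random logits:

    import torch

    def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
        p = torch.nn.functional.softmax(logits, dim=-1)
        return torch.logsumexp(logits, dim=-1) - torch.sum(p * logits, dim=-1)

    logits = torch.randn(4, 7, 13)  # arbitrary (batch, seq, vocab) toy shape
    reference = torch.distributions.Categorical(logits=logits).entropy()
    assert torch.allclose(entropy_from_logits(logits), reference, atol=1e-5)
    print("entropy_from_logits matches Categorical.entropy()")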