[Feature] Support Distributed LogProb for GRPO Training (#6247)

* [fix] fix qwen VocabParallelLMHead1D and gather output

* fix tp bug

* fix consumer

* [feat] Support Distributed LogProb for GRPO Training

* [fix] fix loss func

* [fix] fix log prob plugin

* [fix] fix qwen modeling param

* [fix] rm comments

* [fix] rm hard-code;fix non-dist version

* [fix] fix test file param name and benchmark tp gather output=True/False

* [fix] rm non-dist version in dist log prob

* [fix] fix comments

* [fix] fix dist log prob plugin

* [fix] fix test case

* [fix] fix qwen VocabParallelLMHead1D and gather output

* [fix] fix DistLogProb comments

* [fix] restore tp size

* [fix] fix comments

* [fix] fix comment; fix LogSoftmax usage

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Authored by duanjunwen on 2025-03-18 17:47:55 +08:00; committed by GitHub.
Parent: bc0171d392
Commit: 7795d4c50d
8 changed files with 233 additions and 12 deletions


@@ -73,8 +73,6 @@ class BaseConsumer:
         )
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
-        if self.plugin_config.get("tp_size", 1) > 1:
-            plugin_config["parallel_output"] = False
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
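The two deleted lines forced `parallel_output=False` whenever `tp_size > 1`, which made the LM head all-gather full-vocabulary logits on every rank. With the distributed log-prob path added in this commit, the consumer can leave `parallel_output` at its default and work on vocab-sharded logits. A minimal sketch of the resulting plugin setup, assuming the standard `HybridParallelPlugin` constructor; the `tp_size`/`precision`/`zero_stage` values are illustrative, not the consumer's actual defaults, and the code needs an already-initialized distributed context (e.g. `colossalai.launch`):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative config: parallel_output is no longer forced to False under TP,
# so it keeps its default (True) and the LM head emits vocab-sharded logits.
plugin_config = dict(tp_size=2, pp_size=1, precision="bf16", zero_stage=1)
plugin = HybridParallelPlugin(**plugin_config)
booster = Booster(plugin=plugin)
```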


@@ -120,14 +120,18 @@ class GRPOConsumer(BaseConsumer):
                 input_ids=data["input_ids"],
                 attention_mask=data["attention_mask"],
             )["logits"]
-            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
+            action_log_probs = calc_action_log_probs(
+                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+            )
 
             with torch.no_grad():
                 reference_model_logits = self.reference_model(
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
                 )["logits"]
-            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+            reference_action_log_probs = calc_action_log_probs(
+                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+            )
 
             per_token_kl = (
                 torch.exp(reference_action_log_probs - action_log_probs)


@@ -2,6 +2,8 @@ from typing import Any, Dict, List
 
 import torch
 
+from colossalai.shardformer.layer.loss import dist_log_prob
+
 
 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return per_label_logps.squeeze(-1)
 
 
-def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(
+    logits: torch.Tensor,
+    sequences: torch.LongTensor,
+    num_actions: int,
+    shard_config,
+    vocab_size: int = None,
+) -> torch.Tensor:
     """Calculate action log probs.
 
     Args:
-        output (torch.Tensor): Output tensor of Actor.forward.logits.
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
         sequences (torch.LongTensor): Input sequences.
         num_actions (int): Number of actions.
+        shard_config: ShardConfig of the running plugin, used to pick the distributed or plain log-prob path.
+        vocab_size (int, optional): Global vocabulary size when logits are vocab-sharded. Defaults to None.
 
     Returns:
         torch.Tensor: Action log probs.
     """
-    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    # labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+    # logits: torch.Tensor,  # [B, S, Vocab_size]
+    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
+    log_probs = log_probs.squeeze(-1)
     return log_probs[:, -num_actions:]
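A small shape sketch for the new signature. It exercises the non-tensor-parallel branch, so it runs on a single device; the tensors and sizes are dummies, and it assumes `ShardConfig(enable_tensor_parallelism=False)` can be built outside a distributed launch:

```python
import torch

from coati.distributed.utils import calc_action_log_probs
from colossalai.shardformer.shard import ShardConfig

B, S, V, num_actions = 2, 6, 16, 3
logits = torch.randn(B, S, V)         # model output logits, [B, S, V]
sequences = torch.randint(V, (B, S))  # prompt + action token ids, [B, S]

shard_config = ShardConfig(enable_tensor_parallelism=False)  # plain log_softmax path
action_log_probs = calc_action_log_probs(logits, sequences, num_actions, shard_config)
print(action_log_probs.shape)  # torch.Size([2, 3]): one log prob per action token
```

Under tensor parallelism with `parallel_output=True`, the same call receives vocab-sharded logits of shape `[B, S, V // tp_size]`, and `shard_config.tensor_parallel_process_group` routes it through `dist_log_prob_1d` instead.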


@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d, dist_cross_entropy
+from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import (
@@ -28,6 +28,8 @@ __all__ = [
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
     "dist_cross_entropy",
+    "dist_log_prob_1d",
+    "dist_log_prob",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",


@@ -3,13 +3,21 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+from torch.nn.functional import log_softmax
 
 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig
 
 from .utils import is_share_sp_tp
 
-__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+__all__ = [
+    "DistCrossEntropy",
+    "cross_entropy_1d",
+    "dist_cross_entropy",
+    "DistLogProb",
+    "dist_log_prob_1d",
+    "dist_log_prob",
+]
 
 _IGNORE_IDX = -100
@@ -137,6 +145,98 @@ class DistCrossEntropy(Function):
         return grad_logits, None, None, None, None, None, None
 
 
+class DistLogProb(Function):
+    r"""
+    Overwrite the forward and backward function to calculate the log prob before gather
+
+    Args:
+        Function (:class:`torch.autograd.Function`): default
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        vocab_logits: torch.Tensor,
+        target: torch.Tensor,
+        process_group: ProcessGroup,
+        vocab_size: int,
+        dtype=torch.float32,
+    ):
+        ##################
+        # Step 1: Find the global maximum value of the logits
+        ##################
+        logits_max = torch.max(vocab_logits, dim=-1)[0]
+        handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
+
+        ##################
+        # Step 2: Find the local mask. The local mask will be used to select log_probs values in Step 4.
+        # For acceleration, we overlap Step 2 and Step 3.
+        ##################
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        if vocab_size is None:
+            partition_vocab_size = vocab_logits.size()[-1]
+            global_vocab_size = partition_vocab_size * world_size
+        else:
+            global_vocab_size = vocab_size
+            partition_vocab_size = global_vocab_size // world_size
+
+        # lower and upper threshold of this rank's vocab partition
+        delta = (global_vocab_size + world_size - 1) // world_size
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        if up_threshold > global_vocab_size:
+            up_threshold = global_vocab_size
+        # mask out target indices that fall outside this rank's partition
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
+        masked_target[mask] = 0
+        masked_target_1d = masked_target.view(-1).contiguous()
+        handle.wait()
+
+        ##################
+        # Step 3: Calculate the global sum of exp(logits)
+        ##################
+        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+        exp_logits = torch.exp(vocab_logits)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)  # local sum of exp(logits)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+        ##################
+        # Step 4: Calculate the local log prob. We first compute log_softmax, then select log probs via the local mask
+        ##################
+        log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # compute log_softmax
+        log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+        log_probs[mask.unsqueeze(-1)] = 0  # set masked values to zero
+        dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
+
+        ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+        ctx.dtype = dtype
+        return log_probs
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+        ##################
+        # Step 1: Compute the global softmax value
+        ##################
+        softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+        ##################
+        # Step 2: Update the softmax value at the local target indices
+        ##################
+        partition_vocab_size = softmax_logits.shape[-1]
+        softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+        softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
+
+        ##################
+        # Step 3: Calculate grad_logits, the gradient of the loss with respect to the input vocab_logits
+        ##################
+        grad_logits = -softmax_logits.mul_(grad_output)
+        return grad_logits, None, None, None, None, None, None
+
+
 def cross_entropy_1d(
     vocab_logits: torch.Tensor,
     labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
 
 
+def dist_log_prob_1d(
+    vocab_logits: torch.Tensor,
+    labels: torch.Tensor,
+    process_group: ProcessGroup = None,
+    vocab_size: int = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
 def dist_cross_entropy(
     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
     logits: torch.Tensor,  # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
     loss, num_nonzero = loss[0], loss[1].detach()
     loss = (loss / num_nonzero).squeeze()
     return loss
+
+
+def dist_log_prob(
+    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+    logits: torch.Tensor,  # [B, S, Vocab_size]
+    shard_config: ShardConfig,
+    vocab_size: int,
+    dtype: torch.dtype,
+    seq_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Helper to compute log prob for most shardformer models supporting PP, TP.
+    """
+    # Split labels if not gather output
+    parallel_output = shard_config.parallel_output
+    is_tp = shard_config.enable_tensor_parallelism
+
+    # TODO: support sequence parallelism
+    labels = labels[..., 1:]
+    logits = logits[..., :-1, :]
+    labels = labels.contiguous()
+    logits = logits.contiguous()
+    assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+    # Flatten the tokens
+    if is_tp and parallel_output:
+        log_prob = dist_log_prob_1d(
+            logits,
+            labels,
+            process_group=shard_config.tensor_parallel_process_group,
+            vocab_size=vocab_size,
+            dtype=dtype,
+        )
+    else:
+        log_prob = log_softmax(logits, dim=-1)
+        log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+    return log_prob
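For reference, this is the math the new autograd function implements, written out from the code above (per token, with target index y, full-vocabulary logits z, and the max/sum reductions realized as all-reduces over the tensor-parallel group):

```latex
% Forward: numerically stable log-softmax selected at the target index
\log p(y \mid z) = (z_y - m) - \log \sum_{j=1}^{V} \exp(z_j - m),
\qquad m = \max_{j} z_j

% Backward: gradient of the selected log prob w.r.t. each logit z_k,
% which DistLogProb.backward realizes as (one-hot minus softmax) times grad_output
\frac{\partial \log p(y \mid z)}{\partial z_k} = \mathbb{1}[k = y] - \operatorname{softmax}(z)_k
```

Subtracting the global max before exponentiating keeps the sum finite in low-precision dtypes, and the final SUM all-reduce over the masked, gathered values works because exactly one rank owns the target index of each token.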


@@ -832,7 +832,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         loss = None
         if labels is not None:
             loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output


@@ -430,8 +430,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="lm_head",
-                        target_module=Linear1D_Col,
-                        kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+                        target_module=VocabParallelLMHead1D,
+                        kwargs=dict(
+                            gather_output=not self.shard_config.parallel_output,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
                     )
                 ],
                 method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
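The policy change and the `parallel_output` flag fit together as `gather_output=not parallel_output`: with the plugin default (`parallel_output=True`) the `VocabParallelLMHead1D` keeps its vocabulary shard and each TP rank emits `[B, S, V // tp_size]` logits for `dist_log_prob_1d`, while `parallel_output=False` restores the old gathered `[B, S, V]` behaviour. A purely illustrative sketch of that shape relationship (no real TP layer is built here):

```python
# Illustrative only: how gather_output relates the logits shape to parallel_output.
tp_size, B, S, V = 2, 1, 8, 32

parallel_output = True                # HybridParallelPlugin / ShardConfig default
gather_output = not parallel_output   # what the policy passes to VocabParallelLMHead1D

local_vocab = V if gather_output else V // tp_size
print((B, S, local_vocab))            # (1, 8, 16): vocab-sharded logits per TP rank
```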


@@ -0,0 +1,52 @@
+import pytest
+import torch
+from coati.distributed.utils import log_probs_from_logits
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer import dist_log_prob_1d
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(
+    parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
+)
+
+
+def check_dist_log_prob(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+    labels = torch.randint(8, (2, 4)).cuda()
+
+    logprob = log_probs_from_logits(pred, labels)
+
+    pred.retain_grad()
+    logprob.mean().backward()
+
+    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+    dist_pred.requires_grad = True
+    dist_logprob = dist_log_prob_1d(dist_pred, labels)
+
+    dist_pred.retain_grad()
+    dist_logprob.squeeze(-1).mean().backward()
+
+    assert torch.allclose(
+        logprob, dist_logprob.squeeze(-1), atol=1e-5
+    ), f"dist log prob is not equal to origin log prob\n{logprob}\n{dist_logprob.squeeze(-1)}"
+
+    pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
+    assert torch.allclose(
+        pred_grad_partial, dist_pred.grad
+    ), f"dist grad is not equal to origin grad\n{pred.grad}\n{dist_pred.grad}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_log_prob():
+    spawn(check_dist_log_prob, 2)
+
+
+if __name__ == "__main__":
+    test_dist_log_prob()
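The distributed check above is spawned with `world_size=2`, so it needs two CUDA devices; the `__main__` guard runs the same spawn directly. For a quick single-process sanity check of the reference path the test compares against, the helper should agree with a plain `log_softmax` + `gather`, assuming `log_probs_from_logits` is the usual gather-of-log-softmax helper shown in the utils diff earlier:

```python
import torch
from coati.distributed.utils import log_probs_from_logits

pred = torch.randn(2, 4, 8)
labels = torch.randint(8, (2, 4))

# Reference: full-vocab log-softmax, then pick the log prob of each label token.
expected = torch.log_softmax(pred, dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(log_probs_from_logits(pred, labels), expected, atol=1e-5)
```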