mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-30 15:00:39 +00:00
[Feature] Support Distributed LogProb for GRPO Training (#6247)
* [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
parent
bc0171d392
commit
7795d4c50d
@ -73,8 +73,6 @@ class BaseConsumer:
|
||||
)
|
||||
if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
|
||||
plugin_config["microbatch_size"] = self.microbatch_size
|
||||
if self.plugin_config.get("tp_size", 1) > 1:
|
||||
plugin_config["parallel_output"] = False
|
||||
plugin_config.update(self.plugin_config)
|
||||
self.plugin = HybridParallelPlugin(**plugin_config)
|
||||
self.booster = Booster(plugin=self.plugin)
|
||||
|
@ -120,14 +120,18 @@ class GRPOConsumer(BaseConsumer):
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
|
||||
action_log_probs = calc_action_log_probs(
|
||||
policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
reference_model_logits = self.reference_model(
|
||||
input_ids=data["input_ids"],
|
||||
attention_mask=data["attention_mask"],
|
||||
)["logits"]
|
||||
reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
|
||||
reference_action_log_probs = calc_action_log_probs(
|
||||
reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
|
||||
)
|
||||
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
|
@ -2,6 +2,8 @@ from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.shardformer.layer.loss import dist_log_prob
|
||||
|
||||
|
||||
def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
|
||||
batches = []
|
||||
@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
|
||||
return per_label_logps.squeeze(-1)
|
||||
|
||||
|
||||
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
|
||||
def calc_action_log_probs(
|
||||
logits: torch.Tensor,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int,
|
||||
shard_config,
|
||||
vocab_size: int = None,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate action log probs.
|
||||
|
||||
Args:
|
||||
output (torch.Tensor): Output tensor of Actor.forward.logits.
|
||||
logits (torch.Tensor): Output tensor of Actor.forward.logits.
|
||||
sequences (torch.LongTensor): Input sequences.
|
||||
num_actions (int): Number of actions.
|
||||
shard_config
|
||||
vocab_size
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
# labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
# logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
|
||||
log_probs = log_probs.squeeze(-1)
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
||||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import (
|
||||
@ -28,6 +28,8 @@ __all__ = [
|
||||
"DropoutForReplicatedInput",
|
||||
"cross_entropy_1d",
|
||||
"dist_cross_entropy",
|
||||
"dist_log_prob_1d",
|
||||
"dist_log_prob",
|
||||
"BaseLayerNorm",
|
||||
"LayerNorm",
|
||||
"RMSNorm",
|
||||
|
@ -3,13 +3,21 @@ import torch.distributed as dist
|
||||
from torch.autograd import Function
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn.functional import log_softmax
|
||||
|
||||
from colossalai.shardformer.layer._operation import reduce_forward
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .utils import is_share_sp_tp
|
||||
|
||||
__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
|
||||
__all__ = [
|
||||
"DistCrossEntropy",
|
||||
"cross_entropy_1d",
|
||||
"dist_cross_entropy",
|
||||
"DistLogProb",
|
||||
"dist_log_prob_1d",
|
||||
"dist_log_prob",
|
||||
]
|
||||
|
||||
_IGNORE_IDX = -100
|
||||
|
||||
@ -137,6 +145,98 @@ class DistCrossEntropy(Function):
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
class DistLogProb(Function):
|
||||
r"""
|
||||
Overwrite the forward and backward function to calculate the log prob before gather
|
||||
|
||||
Args:
|
||||
Function (:class:`torch.autograd.Function`): default
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
vocab_logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
process_group: ProcessGroup,
|
||||
vocab_size: int,
|
||||
dtype=torch.float32,
|
||||
):
|
||||
|
||||
##################
|
||||
# Step1:Find the global maximum value of logits
|
||||
##################
|
||||
logits_max = torch.max(vocab_logits, dim=-1)[0]
|
||||
handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
|
||||
|
||||
##################
|
||||
# Step2:Find the local mask. local mask will be use to select log_probs value in Step 4.
|
||||
# For accleration, we overlap Step 2 and Step 3
|
||||
##################
|
||||
rank = dist.get_rank(group=process_group)
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
if vocab_size is None:
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
else:
|
||||
global_vocab_size = vocab_size
|
||||
partition_vocab_size = global_vocab_size // world_size
|
||||
# down and up threshold for local logits
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_threshold = rank * delta
|
||||
up_threshold = down_threshold + delta
|
||||
if up_threshold > global_vocab_size:
|
||||
up_threshold = global_vocab_size
|
||||
# mask
|
||||
mask = (target < down_threshold) | (target >= up_threshold)
|
||||
masked_target = target.clone() - down_threshold
|
||||
masked_target[mask] = 0
|
||||
masked_target_1d = masked_target.view(-1).contiguous()
|
||||
handle.wait()
|
||||
|
||||
##################
|
||||
# Step3:Calculate global summation exp logits
|
||||
##################
|
||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||
exp_logits = torch.exp(vocab_logits)
|
||||
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32) # local summation exp logits
|
||||
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
|
||||
|
||||
##################
|
||||
# Step4:Calculate local prob. We first cal log_softmax, then select log probs via local mask
|
||||
##################
|
||||
log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1)) # cal log_softmax
|
||||
log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
|
||||
log_probs[mask.unsqueeze(-1)] = 0 # set masked val to zero
|
||||
dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
|
||||
|
||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
|
||||
ctx.dtype = dtype
|
||||
return log_probs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
|
||||
##################
|
||||
# Step1:Find the global sofmax value
|
||||
##################
|
||||
softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
|
||||
|
||||
##################
|
||||
# Step2:Update softmax value based on local target index
|
||||
##################
|
||||
partion_vocab_size = softmax_logits.shape[-1]
|
||||
softmax_logits_2d = softmax_logits.view(-1, partion_vocab_size)
|
||||
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
|
||||
softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
|
||||
|
||||
##################
|
||||
# Step3:Calculate grad_output, which is the gradient of the loss function with respect to the output of logsoftmax
|
||||
##################
|
||||
grad_logits = -softmax_logits.mul_(grad_output)
|
||||
return grad_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
def cross_entropy_1d(
|
||||
vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
@ -149,6 +249,16 @@ def cross_entropy_1d(
|
||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
|
||||
|
||||
|
||||
def dist_log_prob_1d(
|
||||
vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
process_group: ProcessGroup = None,
|
||||
vocab_size: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
) -> torch.Tensor:
|
||||
return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
|
||||
|
||||
|
||||
def dist_cross_entropy(
|
||||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
@ -243,3 +353,41 @@ def dist_cross_entropy(
|
||||
loss, num_nonzero = loss[0], loss[1].detach()
|
||||
loss = (loss / num_nonzero).squeeze()
|
||||
return loss
|
||||
|
||||
|
||||
def dist_log_prob(
|
||||
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
|
||||
logits: torch.Tensor, # [B, S, Vocab_size]
|
||||
shard_config: ShardConfig,
|
||||
vocab_size: int,
|
||||
dtype: torch.dtype,
|
||||
seq_dim: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper to compute log prob for most shardformer models supporting PP, TP.
|
||||
"""
|
||||
# Split labels if not gather output
|
||||
parallel_output = shard_config.parallel_output
|
||||
is_tp = shard_config.enable_tensor_parallelism
|
||||
|
||||
# TODO:support sp
|
||||
labels = labels[..., 1:]
|
||||
logits = logits[..., :-1, :]
|
||||
labels = labels.contiguous()
|
||||
logits = logits.contiguous()
|
||||
assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
|
||||
|
||||
# Flatten the tokens
|
||||
if is_tp and parallel_output:
|
||||
log_prob = dist_log_prob_1d(
|
||||
logits,
|
||||
labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=vocab_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
log_prob = log_softmax(logits)
|
||||
log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
|
||||
return log_prob
|
||||
|
@ -832,7 +832,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
@ -430,8 +430,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=not self.shard_config.parallel_output,
|
||||
fp8_communication=self.shard_config.fp8_communication,
|
||||
use_zbv=use_zbv,
|
||||
),
|
||||
)
|
||||
],
|
||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||
|
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
52
tests/test_shardformer/test_layer/test_dist_log_prob.py
Normal file
@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
import torch
|
||||
from coati.distributed.utils import log_probs_from_logits
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer.layer import dist_log_prob_1d
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
|
||||
)
|
||||
|
||||
|
||||
def check_dist_log_prob(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
|
||||
|
||||
# prepare data
|
||||
pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
|
||||
labels = torch.randint(8, (2, 4)).cuda()
|
||||
|
||||
logprob = log_probs_from_logits(pred, labels)
|
||||
|
||||
pred.retain_grad()
|
||||
logprob.mean().backward()
|
||||
|
||||
dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
|
||||
dist_pred.requires_grad = True
|
||||
dist_logprob = dist_log_prob_1d(dist_pred, labels)
|
||||
|
||||
dist_pred.retain_grad()
|
||||
dist_logprob.squeeze(-1).mean().backward()
|
||||
|
||||
assert torch.allclose(
|
||||
logprob, dist_logprob.squeeze(-1), atol=1e-5
|
||||
), f"dist cross entropy logprob is not equal to orgin logprob\n{logprob}\n{dist_logprob.squeeze(-1)}"
|
||||
|
||||
pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
|
||||
assert torch.allclose(
|
||||
pred_grad_partial, dist_pred.grad
|
||||
), f"dist grad is not equal to orgin grad\n{pred.grad}\n{dist_pred.grad}"
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_dist_log_prob():
|
||||
spawn(check_dist_log_prob, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dist_log_prob()
|
Loading…
Reference in New Issue
Block a user