mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-08-06 02:24:31 +00:00
[Feature] Support Distributed LogProb for GRPO Training (#6247)
* [fix] fix qwen VocabParallelLMHead1D and gather output
* fix tp bug
* fix consumer
* [feat] Support Distributed LogProb for GRPO Training
* [fix] fix loss func
* [fix] fix log prob plugin
* [fix] fix qwen modeling param
* [fix] rm comments
* [fix] rm hard-code; fix non-dist version
* [fix] fix test file param name and benchmark tp gather output=True/False
* [fix] rm non-dist version in dist log prob
* [fix] fix comments
* [fix] fix dist log prob plugin
* [fix] fix test case
* [fix] fix qwen VocabParallelLMHead1D and gather output
* [fix] fix DistLogProb comments
* [fix] restore tp size
* [fix] fix comments
* [fix] fix comment; fix LogSoftmax usage

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
parent bc0171d392
commit 7795d4c50d
@@ -73,8 +73,6 @@ class BaseConsumer:
         )
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
-        if self.plugin_config.get("tp_size", 1) > 1:
-            plugin_config["parallel_output"] = False
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
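
Context for the removal above (a hedged reading of the diff, not text from the commit): with distributed log prob support, the consumer no longer has to force the plugin to gather full-vocab logits when tensor parallelism is enabled. A minimal sketch of what the configuration change allows; the values and comments are illustrative only:

# Sketch only: illustrates the config difference, not the actual consumer code.
plugin_config = dict(tp_size=2, pp_size=1)
# Before this commit the consumer also forced:
#     plugin_config["parallel_output"] = False   # gather full-vocab logits on every rank
# After this commit the flag is left to HybridParallelPlugin's default (or to whatever
# the user passes in plugin_config), so the LM head can keep its logits sharded along
# the vocab dimension and calc_action_log_probs consumes them through dist_log_prob.
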
@@ -120,14 +120,18 @@ class GRPOConsumer(BaseConsumer):
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
                 )["logits"]
-                action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                )

                 with torch.no_grad():
                     reference_model_logits = self.reference_model(
                         input_ids=data["input_ids"],
                         attention_mask=data["attention_mask"],
                     )["logits"]
-                reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+                reference_action_log_probs = calc_action_log_probs(
+                    reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                )

                 per_token_kl = (
                     torch.exp(reference_action_log_probs - action_log_probs)
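
The hunk ends mid-expression, so the rest of per_token_kl is not shown here. For background only, GRPO implementations commonly complete this kind of expression with the k3 KL estimator sketched below; treat it as an assumption about intent, not the literal continuation of this file:

import torch

# k3 estimator of KL(pi_theta || pi_ref) per token, from per-token log probs:
#   exp(ref - act) - (ref - act) - 1, which is >= 0 and equals 0 when the policies agree.
def per_token_kl_k3(action_log_probs: torch.Tensor, reference_action_log_probs: torch.Tensor) -> torch.Tensor:
    delta = reference_action_log_probs - action_log_probs
    return torch.exp(delta) - delta - 1
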
@@ -2,6 +2,8 @@ from typing import Any, Dict, List

 import torch

+from colossalai.shardformer.layer.loss import dist_log_prob
+

 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return per_label_logps.squeeze(-1)


-def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(
+    logits: torch.Tensor,
+    sequences: torch.LongTensor,
+    num_actions: int,
+    shard_config,
+    vocab_size: int = None,
+) -> torch.Tensor:
     """Calculate action log probs.

     Args:
-        output (torch.Tensor): Output tensor of Actor.forward.logits.
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
         sequences (torch.LongTensor): Input sequences.
         num_actions (int): Number of actions.
+        shard_config: Shard config used for the distributed log prob computation.
+        vocab_size (int, optional): Full vocabulary size; needed when the logits are vocab-sharded.

     Returns:
         torch.Tensor: Action log probs.
     """
-    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    # dist_log_prob expects (labels, logits): labels are [B, S], logits are [B, S, vocab_size];
+    # it shifts both internally, so the full sequences and logits are passed here.
+    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
+    log_probs = log_probs.squeeze(-1)
     return log_probs[:, -num_actions:]
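
For orientation, a minimal self-contained sketch of what the reworked calc_action_log_probs returns in the non-tensor-parallel case; the shapes and names below are illustrative, not taken from the diff. Logits at position t score the token at position t + 1, and only the log probs of the final num_actions tokens are kept:

import torch
import torch.nn.functional as F

B, S, V, num_actions = 2, 6, 11, 3          # batch, sequence length, vocab size, action tokens
logits = torch.randn(B, S, V)
sequences = torch.randint(0, V, (B, S))

# Equivalent of the plain log_softmax fallback path: log-softmax over the vocab,
# then pick the log prob of each next token.
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
per_token = log_probs.gather(dim=-1, index=sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
action_log_probs = per_token[:, -num_actions:]
print(action_log_probs.shape)  # torch.Size([2, 3])
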
@@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d, dist_cross_entropy
+from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import (
@@ -28,6 +28,8 @@ __all__ = [
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
     "dist_cross_entropy",
+    "dist_log_prob_1d",
+    "dist_log_prob",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",
@@ -3,13 +3,21 @@ import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+from torch.nn.functional import log_softmax

 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig

 from .utils import is_share_sp_tp

-__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+__all__ = [
+    "DistCrossEntropy",
+    "cross_entropy_1d",
+    "dist_cross_entropy",
+    "DistLogProb",
+    "dist_log_prob_1d",
+    "dist_log_prob",
+]

 _IGNORE_IDX = -100
@@ -137,6 +145,98 @@ class DistCrossEntropy(Function):
         return grad_logits, None, None, None, None, None, None


+class DistLogProb(Function):
+    r"""
+    Overwrite the forward and backward functions to calculate the log prob on vocab-sharded logits before gathering.
+
+    Args:
+        Function (:class:`torch.autograd.Function`): default
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        vocab_logits: torch.Tensor,
+        target: torch.Tensor,
+        process_group: ProcessGroup,
+        vocab_size: int,
+        dtype=torch.float32,
+    ):
+        ##################
+        # Step 1: Find the global maximum value of the logits
+        ##################
+        logits_max = torch.max(vocab_logits, dim=-1)[0]
+        handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
+
+        ##################
+        # Step 2: Find the local mask. The local mask will be used to select log_probs values in Step 4.
+        # For acceleration, we overlap Step 2 and Step 3.
+        ##################
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        if vocab_size is None:
+            partition_vocab_size = vocab_logits.size()[-1]
+            global_vocab_size = partition_vocab_size * world_size
+        else:
+            global_vocab_size = vocab_size
+            partition_vocab_size = global_vocab_size // world_size
+        # lower and upper threshold for the local vocab slice
+        delta = (global_vocab_size + world_size - 1) // world_size
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        if up_threshold > global_vocab_size:
+            up_threshold = global_vocab_size
+        # mask out targets that do not belong to this rank's vocab slice
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
+        masked_target[mask] = 0
+        masked_target_1d = masked_target.view(-1).contiguous()
+        handle.wait()
+
+        ##################
+        # Step 3: Calculate the global sum of exp(logits)
+        ##################
+        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+        exp_logits = torch.exp(vocab_logits)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)  # local sum of exp(logits)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+        ##################
+        # Step 4: Calculate the local log prob. We first compute log_softmax, then select log probs via the local mask
+        ##################
+        log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # compute log_softmax
+        log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+        log_probs[mask.unsqueeze(-1)] = 0  # set masked values to zero
+        dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
+
+        ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+        ctx.dtype = dtype
+        return log_probs
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+        ##################
+        # Step 1: Recover the global softmax value from the saved exp(logits) and the global sum
+        ##################
+        softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+        ##################
+        # Step 2: Update the softmax value based on the local target index
+        ##################
+        partition_vocab_size = softmax_logits.shape[-1]
+        softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+        softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
+
+        ##################
+        # Step 3: Calculate grad_logits, the gradient of the loss with respect to the input logits
+        # (grad_output is the gradient with respect to the log-softmax output)
+        ##################
+        grad_logits = -softmax_logits.mul_(grad_output)
+        return grad_logits, None, None, None, None, None, None
+
+
 def cross_entropy_1d(
     vocab_logits: torch.Tensor,
     labels: torch.Tensor,
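
To make the forward math above concrete, here is a hedged, single-process sketch that emulates the tensor-parallel computation by chunking the vocab dimension and replacing the all-reduces with explicit max/sum over shards, then checks the result against a plain log_softmax. The world size, shapes, and names are illustrative only:

import torch

torch.manual_seed(0)
world_size, B, S, V = 2, 2, 4, 8
logits = torch.randn(B, S, V)
target = torch.randint(0, V, (B, S))

# Vocab-sharded view of the logits, as each TP rank would hold it.
shards = logits.chunk(world_size, dim=-1)
delta = (V + world_size - 1) // world_size

# "all_reduce(MAX)" over shards, then "all_reduce(SUM)" of the local sums of exp(logits - max).
global_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]
sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

# Each "rank" picks the log prob only for targets inside its vocab slice; masked
# entries contribute zero, and summing over ranks emulates the final all_reduce(SUM).
log_prob = torch.zeros(B, S, 1)
for rank, shard in enumerate(shards):
    down, up = rank * delta, min((rank + 1) * delta, V)
    mask = (target < down) | (target >= up)
    masked_target = (target - down).masked_fill(mask, 0)
    local = (shard - global_max.unsqueeze(-1)) - torch.log(sum_exp).unsqueeze(-1)
    picked = local.gather(dim=-1, index=masked_target.unsqueeze(-1))
    picked[mask.unsqueeze(-1)] = 0
    log_prob += picked

expected = torch.log_softmax(logits, dim=-1).gather(dim=-1, index=target.unsqueeze(-1))
assert torch.allclose(log_prob, expected, atol=1e-6)
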
@@ -149,6 +249,16 @@ def cross_entropy_1d(
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


+def dist_log_prob_1d(
+    vocab_logits: torch.Tensor,
+    labels: torch.Tensor,
+    process_group: ProcessGroup = None,
+    vocab_size: int = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
 def dist_cross_entropy(
     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
     logits: torch.Tensor,  # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
     loss, num_nonzero = loss[0], loss[1].detach()
     loss = (loss / num_nonzero).squeeze()
     return loss
+
+
+def dist_log_prob(
+    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+    logits: torch.Tensor,  # [B, S, Vocab_size]
+    shard_config: ShardConfig,
+    vocab_size: int,
+    dtype: torch.dtype,
+    seq_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Helper to compute log prob for most shardformer models supporting PP, TP.
+    """
+    # Decide whether the logits are vocab-sharded (TP with parallel_output) or already gathered
+    parallel_output = shard_config.parallel_output
+    is_tp = shard_config.enable_tensor_parallelism
+
+    # TODO: support sequence parallelism (SP)
+    labels = labels[..., 1:]
+    logits = logits[..., :-1, :]
+    labels = labels.contiguous()
+    logits = logits.contiguous()
+    assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+    # Flatten the tokens
+    if is_tp and parallel_output:
+        log_prob = dist_log_prob_1d(
+            logits,
+            labels,
+            process_group=shard_config.tensor_parallel_process_group,
+            vocab_size=vocab_size,
+            dtype=dtype,
+        )
+    else:
+        log_prob = log_softmax(logits, dim=-1)
+        log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+    return log_prob
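
A hedged usage sketch of the helper above for the single-GPU, non-tensor-parallel case. SimpleNamespace stands in for a real ShardConfig purely for illustration (real call sites pass plugin.shard_config), and the import assumes a ColossalAI installation that includes this commit:

from types import SimpleNamespace

import torch

from colossalai.shardformer.layer import dist_log_prob  # exported by the __init__ change above

# Stand-in for ShardConfig: no TP, so the plain log_softmax fallback branch is taken.
shard_config = SimpleNamespace(
    enable_tensor_parallelism=False, parallel_output=False, tensor_parallel_process_group=None
)

logits = torch.randn(2, 6, 32)           # [B, S, V]
labels = torch.randint(0, 32, (2, 6))    # [B, S]
log_prob = dist_log_prob(labels, logits, shard_config, vocab_size=32, dtype=logits.dtype)
print(log_prob.shape)  # [B, S - 1, 1]: per-token log prob of each next token
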
@@ -832,7 +832,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         loss = None
         if labels is not None:
             loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
|
@ -430,8 +430,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
|
|||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="lm_head",
|
suffix="lm_head",
|
||||||
target_module=Linear1D_Col,
|
target_module=VocabParallelLMHead1D,
|
||||||
kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
|
kwargs=dict(
|
||||||
|
gather_output=not self.shard_config.parallel_output,
|
||||||
|
fp8_communication=self.shard_config.fp8_communication,
|
||||||
|
use_zbv=use_zbv,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
|
||||||
|
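
A short reading of the policy change above, written as comments rather than as text from the commit:

# gather_output = not shard_config.parallel_output
#   - parallel_output=True  -> gather_output=False: VocabParallelLMHead1D keeps the logits
#     sharded along the vocab dimension on each TP rank, and DistLogProb / dist_log_prob
#     (or dist_cross_entropy) consume the sharded tensor directly.
#   - parallel_output=False -> gather_output=True: the LM head all-gathers the logits,
#     and the plain log_softmax fallback path applies.
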
tests/test_shardformer/test_layer/test_dist_log_prob.py (new file)
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from coati.distributed.utils import log_probs_from_logits
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer import dist_log_prob_1d
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(
+    parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
+)
+
+
+def check_dist_log_prob(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+    labels = torch.randint(8, (2, 4)).cuda()
+
+    logprob = log_probs_from_logits(pred, labels)
+
+    pred.retain_grad()
+    logprob.mean().backward()
+
+    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+    dist_pred.requires_grad = True
+    dist_logprob = dist_log_prob_1d(dist_pred, labels)
+
+    dist_pred.retain_grad()
+    dist_logprob.squeeze(-1).mean().backward()
+
+    assert torch.allclose(
+        logprob, dist_logprob.squeeze(-1), atol=1e-5
+    ), f"dist log prob is not equal to origin log prob\n{logprob}\n{dist_logprob.squeeze(-1)}"
+
+    pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
+    assert torch.allclose(
+        pred_grad_partial, dist_pred.grad
+    ), f"dist grad is not equal to origin grad\n{pred.grad}\n{dist_pred.grad}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_log_prob():
+    spawn(check_dist_log_prob, 2)
+
+
+if __name__ == "__main__":
+    test_dist_log_prob()
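
A hedged note on exercising the new test (these are standard pytest invocations, not commands from the diff). Since the test spawns a world size of 2 over the NCCL backend, it assumes at least two GPUs are available:

# Run through pytest:
#   pytest tests/test_shardformer/test_layer/test_dist_log_prob.py
# or execute the file directly, since it calls test_dist_log_prob() under __main__:
#   python tests/test_shardformer/test_layer/test_dist_log_prob.py
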