diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 380a2ee1b..1e85cccb3 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -73,8 +73,6 @@ class BaseConsumer:
         )
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
-        if self.plugin_config.get("tp_size", 1) > 1:
-            plugin_config["parallel_output"] = False
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 55dfd09ab..b1edb89bb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,14 +120,18 @@ class GRPOConsumer(BaseConsumer):
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )["logits"]
-            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
+            action_log_probs = calc_action_log_probs(
+                policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+            )

            with torch.no_grad():
                reference_model_logits = self.reference_model(
                    input_ids=data["input_ids"],
                    attention_mask=data["attention_mask"],
                )["logits"]
-            reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+            reference_action_log_probs = calc_action_log_probs(
+                reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+            )

            per_token_kl = (
                torch.exp(reference_action_log_probs - action_log_probs)
diff --git a/applications/ColossalChat/coati/distributed/utils.py b/applications/ColossalChat/coati/distributed/utils.py
index 98b54815b..919e4434f 100644
--- a/applications/ColossalChat/coati/distributed/utils.py
+++ b/applications/ColossalChat/coati/distributed/utils.py
@@ -2,6 +2,8 @@ from typing import Any, Dict, List

 import torch

+from colossalai.shardformer.layer.loss import dist_log_prob
+

 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T
     return per_label_logps.squeeze(-1)


-def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(
+    logits: torch.Tensor,
+    sequences: torch.LongTensor,
+    num_actions: int,
+    shard_config,
+    vocab_size: int = None,
+) -> torch.Tensor:
     """Calculate action log probs.

     Args:
-        output (torch.Tensor): Output tensor of Actor.forward.logits.
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
         sequences (torch.LongTensor): Input sequences.
         num_actions (int): Number of actions.
+        shard_config: Shard config of the booster plugin; used for tensor-parallel log prob computation.
+        vocab_size (int, optional): Global vocabulary size. Defaults to None, in which case it is inferred from the logits.
+
     Returns:
         torch.Tensor: Action log probs.
""" - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + # labels: torch.Tensor, # [B, S] or [B, S, Vocab_size] + # logits: torch.Tensor, # [B, S, Vocab_size] + log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype) + log_probs = log_probs.squeeze(-1) return log_probs[:, -num_actions:] diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 0bd1b6092..a1b80bf56 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -3,7 +3,7 @@ from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D -from .loss import cross_entropy_1d, dist_cross_entropy +from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .parallel_module import ParallelModule from .qkv_fused_linear import ( @@ -28,6 +28,8 @@ __all__ = [ "DropoutForReplicatedInput", "cross_entropy_1d", "dist_cross_entropy", + "dist_log_prob_1d", + "dist_log_prob", "BaseLayerNorm", "LayerNorm", "RMSNorm", diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py index 0e2241af9..51419a38a 100644 --- a/colossalai/shardformer/layer/loss.py +++ b/colossalai/shardformer/layer/loss.py @@ -3,13 +3,21 @@ import torch.distributed as dist from torch.autograd import Function from torch.distributed import ProcessGroup from torch.nn import CrossEntropyLoss +from torch.nn.functional import log_softmax from colossalai.shardformer.layer._operation import reduce_forward from colossalai.shardformer.shard import ShardConfig from .utils import is_share_sp_tp -__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"] +__all__ = [ + "DistCrossEntropy", + "cross_entropy_1d", + "dist_cross_entropy", + "DistLogProb", + "dist_log_prob_1d", + "dist_log_prob", +] _IGNORE_IDX = -100 @@ -137,6 +145,98 @@ class DistCrossEntropy(Function): return grad_logits, None, None, None, None, None, None +class DistLogProb(Function): + r""" + Overwrite the forward and backward function to calculate the log prob before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward( + ctx, + vocab_logits: torch.Tensor, + target: torch.Tensor, + process_group: ProcessGroup, + vocab_size: int, + dtype=torch.float32, + ): + + ################## + # Step1:Find the global maximum value of logits + ################## + logits_max = torch.max(vocab_logits, dim=-1)[0] + handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True) + + ################## + # Step2:Find the local mask. local mask will be use to select log_probs value in Step 4. 
+        # For acceleration, we overlap Step 2 and Step 3.
+        ##################
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        if vocab_size is None:
+            partition_vocab_size = vocab_logits.size()[-1]
+            global_vocab_size = partition_vocab_size * world_size
+        else:
+            global_vocab_size = vocab_size
+            partition_vocab_size = global_vocab_size // world_size
+        # lower and upper thresholds of the local vocab partition
+        delta = (global_vocab_size + world_size - 1) // world_size
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        if up_threshold > global_vocab_size:
+            up_threshold = global_vocab_size
+        # mask out targets that fall outside the local vocab partition
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
+        masked_target[mask] = 0
+        masked_target_1d = masked_target.view(-1).contiguous()
+        handle.wait()
+
+        ##################
+        # Step 3: Calculate the global sum of exp(logits)
+        ##################
+        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+        exp_logits = torch.exp(vocab_logits)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)  # local sum of exp(logits)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+        ##################
+        # Step 4: Calculate the local log probs. We first compute log_softmax, then select the log probs via the local mask.
+        ##################
+        log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # log_softmax
+        log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+        log_probs[mask.unsqueeze(-1)] = 0  # set masked values to zero
+        dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
+
+        ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+        ctx.dtype = dtype
+        return log_probs
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+        ##################
+        # Step 1: Compute the global softmax values
+        ##################
+        softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+        ##################
+        # Step 2: Update the softmax values based on the local target indices
+        ##################
+        partition_vocab_size = softmax_logits.shape[-1]
+        softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+        softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
+
+        ##################
+        # Step 3: Calculate grad_logits, the gradient of the loss with respect to the logits
+        ##################
+        grad_logits = -softmax_logits.mul_(grad_output)
+        return grad_logits, None, None, None, None, None, None
+
+
 def cross_entropy_1d(
     vocab_logits: torch.Tensor,
     labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


+def dist_log_prob_1d(
+    vocab_logits: torch.Tensor,
+    labels: torch.Tensor,
+    process_group: ProcessGroup = None,
+    vocab_size: int = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
 def dist_cross_entropy(
     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
     logits: torch.Tensor,  # [B, S, Vocab_size]
@@ -243,3 +353,41 @@
         loss, num_nonzero = loss[0], loss[1].detach()
         loss = (loss / num_nonzero).squeeze()
     return loss
+
+
+def dist_log_prob(
+    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+    logits: torch.Tensor,  # [B, S, Vocab_size]
+    shard_config: ShardConfig,
+    vocab_size: int,
+    dtype: torch.dtype,
+    seq_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Helper to compute log probs for most shardformer models supporting PP and TP.
+    """
+    # Whether the logits are still sharded over the vocab dimension across TP ranks
+    parallel_output = shard_config.parallel_output
+    is_tp = shard_config.enable_tensor_parallelism
+
+    # TODO: support sequence parallelism
+    labels = labels[..., 1:]
+    logits = logits[..., :-1, :]
+    labels = labels.contiguous()
+    logits = logits.contiguous()
+    assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+    # Compute per-token log probs
+    if is_tp and parallel_output:
+        log_prob = dist_log_prob_1d(
+            logits,
+            labels,
+            process_group=shard_config.tensor_parallel_process_group,
+            vocab_size=vocab_size,
+            dtype=dtype,
+        )
+    else:
+        log_prob = log_softmax(logits, dim=-1)
+        log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+    return log_prob
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 5417bf4eb..3371771e0 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -823,7 +823,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         loss = None
         if labels is not None:
             loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index a883b7d32..32f7d37b1 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -412,8 +412,12 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="lm_head",
-                        target_module=Linear1D_Col,
-                        kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+                        target_module=VocabParallelLMHead1D,
+                        kwargs=dict(
+                            gather_output=not self.shard_config.parallel_output,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
                    )
                ],
                method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
diff --git a/tests/test_shardformer/test_layer/test_dist_log_prob.py b/tests/test_shardformer/test_layer/test_dist_log_prob.py
new file mode 100644
index 000000000..05a6a5d47
--- /dev/null
+++ b/tests/test_shardformer/test_layer/test_dist_log_prob.py
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from coati.distributed.utils import log_probs_from_logits
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer import dist_log_prob_1d
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(
+    parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
+)
+
+
+def check_dist_log_prob(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+    labels = torch.randint(8, (2, 4)).cuda()
+
+    logprob = log_probs_from_logits(pred, labels)
+
+    pred.retain_grad()
+    logprob.mean().backward()
+
+    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+    dist_pred.requires_grad = True
+    dist_logprob = dist_log_prob_1d(dist_pred, labels)
+
+    dist_pred.retain_grad()
+    dist_logprob.squeeze(-1).mean().backward()
+
+    assert torch.allclose(
+        logprob, dist_logprob.squeeze(-1), atol=1e-5
+    ), f"dist log prob is not equal to origin log prob\n{logprob}\n{dist_logprob.squeeze(-1)}"
+
+    pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
+    assert torch.allclose(
+        pred_grad_partial, dist_pred.grad
+    ), f"dist grad is not equal to origin grad\n{pred.grad}\n{dist_pred.grad}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_log_prob():
+    spawn(check_dist_log_prob, 2)
+
+
+if __name__ == "__main__":
+    test_dist_log_prob()
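
For reference, the math implemented by DistLogProb.forward can be checked on a single process by chunking the vocab dimension manually and replacing each all-reduce with a reduction over the chunks. The sketch below is illustrative only and not part of the patch (the function name sharded_log_prob_reference and the manual world_size chunking are assumptions of the sketch); it assumes the vocab size divides evenly by world_size, as in the test above, and should agree with a plain log_softmax + gather to within numerical tolerance.

# Single-process sketch of the vocab-sharded log-prob computation (illustrative, not part of the patch).
import torch

def sharded_log_prob_reference(vocab_logits: torch.Tensor, target: torch.Tensor, world_size: int = 2):
    # vocab_logits: [B, S, V], target: [B, S]; assumes V is divisible by world_size
    shards = vocab_logits.chunk(world_size, dim=-1)
    partition_vocab_size = shards[0].size(-1)

    # Step 1: global max over the vocab dimension (all-reduce MAX in the patch)
    logits_max = torch.stack([s.max(dim=-1)[0] for s in shards]).max(dim=0)[0]

    # Step 3: global sum of exp(logits - max) (all-reduce SUM in the patch)
    sum_exp = sum(torch.exp(s - logits_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

    log_probs = torch.zeros_like(target, dtype=vocab_logits.dtype)
    for rank, shard in enumerate(shards):
        down, up = rank * partition_vocab_size, (rank + 1) * partition_vocab_size
        # Step 2: mask targets that do not live in this rank's vocab partition
        mask = (target < down) | (target >= up)
        masked_target = (target - down).masked_fill(mask, 0)
        # Step 4: local log-softmax, gather the target column, zero out masked entries
        local = shard - logits_max.unsqueeze(-1) - torch.log(sum_exp).unsqueeze(-1)
        picked = local.gather(dim=-1, index=masked_target.unsqueeze(-1)).squeeze(-1)
        log_probs += picked.masked_fill(mask, 0.0)  # all-reduce SUM in the patch
    return log_probs

if __name__ == "__main__":
    pred = torch.randn(2, 4, 8)
    labels = torch.randint(8, (2, 4))
    ref = torch.log_softmax(pred, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(sharded_log_prob_reference(pred, labels), ref, atol=1e-5)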