Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* use one cross entropy function for all shardformer models

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
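The last bullet is the core refactor: instead of every shardformer model carrying its own copy of the shift-and-flatten loss code, the modeling files can delegate to the single helper added in the diff below. A minimal caller-side sketch, assuming the helper lives in colossalai.shardformer.layer.loss (the diff does not show the filename) and using an illustrative LM-head wrapper rather than the actual Llama modeling code:

# Caller-side sketch only: the wrapper function and the import path are assumptions,
# not the exact ColossalAI modeling code; dist_cross_entropy and ShardConfig come from the diff.
import torch
import torch.nn as nn

from colossalai.shardformer.layer.loss import dist_cross_entropy  # module path assumed
from colossalai.shardformer.shard import ShardConfig


def lm_forward_loss(hidden_states: torch.Tensor,
                    lm_head: nn.Linear,
                    labels: torch.Tensor,
                    shard_config: ShardConfig) -> torch.Tensor:
    # Project hidden states to (possibly vocab-sharded) logits.
    logits = lm_head(hidden_states)
    # One shared loss path for PP/TP/SP instead of per-model copies.
    return dist_cross_entropy(
        labels=labels,
        logits=logits,
        shard_config=shard_config,
        out_features=lm_head.out_features,  # output size of the LM head
        vocab_size=logits.size(-1),         # vocab dim used by the non-TP branch
        dtype=torch.float32,                # dtype forwarded to the TP path
    )

In real training the shard_config comes from the plugin/booster setup; the point is only that the branching on parallelism happens inside the helper, not in each model.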
@@ -2,8 +2,11 @@ import torch
 import torch.distributed as dist
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
 
-__all__ = ["DistCrossEntropy", "cross_entropy_1d"]
+from colossalai.shardformer.shard import ShardConfig
+
+__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
 
 
 class DistCrossEntropy(Function):
@@ -132,3 +135,43 @@ def cross_entropy_1d(
     dtype: torch.dtype = None,
 ) -> torch.Tensor:
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
+
+
+def dist_cross_entropy(
+    labels: torch.Tensor,
+    logits: torch.Tensor,
+    shard_config: ShardConfig,
+    out_features: int,
+    vocab_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Helper to compute cross entropy loss for most shardformer models,
+    compatible with PP, TP and SP.
+    """
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_labels = shift_labels.view(-1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
+            # Cross entropy with all-reduce for TP
+            new_vocab_size = logits.shape[-1]
+            shift_logits = shift_logits.view(-1, new_vocab_size)
+            loss = cross_entropy_1d(
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=out_features,
+                dtype=dtype,
+            )
+        else:
+            # NOTE: if TP is used without parallel_output, the output is already gathered;
+            # see VocabParallelLMHead1D.
+            shift_logits = shift_logits.view(-1, vocab_size)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        return loss
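For intuition about the else branch (no tensor-parallel all-reduce), here is a self-contained, single-process sketch using plain PyTorch and made-up tensor shapes; it only demonstrates the causal shift-and-flatten performed above, not the distributed paths:

# Plain-PyTorch illustration of the shift-and-flatten in the non-TP branch.
# Shapes and values are arbitrary; no ColossalAI or distributed setup is needed.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)          # model output
labels = torch.randint(0, vocab, (batch, seq_len))   # next-token targets

# Tokens < n predict n: score the logits at position t against the label at t + 1.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())

The TP branch computes the same quantity, but each rank holds only a slice of the vocab dimension, so cross_entropy_1d performs the softmax reduction with all-reduces over shard_config.tensor_parallel_process_group.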