ColossalAI/colossalai/shardformer/layer/loss.py
Wang Binluo eea37da6fa
[fp8] Merge feature/fp8_comm to main branch of Colossalai (#6016)
2024-08-22 09:21:34 +08:00

import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss

from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

from .utils import is_share_sp_tp

__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]

_IGNORE_IDX = -100


class DistCrossEntropy(Function):
    r"""
    Overwrite the forward and backward function to calculate the cross entropy loss before gather

    Args:
        Function (:class:`torch.autograd.Function`): default
    """

    @staticmethod
    def forward(
        ctx,
        vocab_logits: torch.Tensor,
        target: torch.Tensor,
        ignore_index: int,
        process_group: ProcessGroup,
        vocab_size: int,
        dtype=torch.float32,
        mode="mean",
    ):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
and can be rewriten as:
loss = log(sum(exp(x[i])) - x[class]
To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i]
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
:class:`torch.Tensor`: The cross entropy loss
"""
        assert mode in ["mean", "sum"]
        # get the max
        logits_max = torch.max(vocab_logits, dim=-1)[0]
        handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)

        # mask the target on the local device
        rank = dist.get_rank(group=process_group)
        world_size = dist.get_world_size(group=process_group)
        if vocab_size is None:
            partition_vocab_size = vocab_logits.size()[-1]
            global_vocab_size = partition_vocab_size * world_size
        else:
            global_vocab_size = vocab_size
            partition_vocab_size = global_vocab_size // world_size

        # ids in [down, up) => mask False; ids on other devices and -100 => mask True
        delta = (global_vocab_size + world_size - 1) // world_size
        down_threshold = rank * delta
        up_threshold = down_threshold + delta
        if up_threshold > global_vocab_size:
            up_threshold = global_vocab_size
        mask = (target < down_threshold) | (target >= up_threshold)
        masked_target = target.clone() - down_threshold
        masked_target[mask] = 0
        masked_target_1d = masked_target.view(-1).contiguous()
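
        # Worked example (illustrative comment, assuming world_size = 2 and a global vocab of 8):
        # rank 0 owns ids [0, 4) and rank 1 owns ids [4, 8), so delta = 4. On rank 1, a target
        # id of 6 maps to local index 2 with mask False, while a target id of 1 gets mask True
        # and local index 0; its (meaningless) local logit is zeroed below and the true x[class]
        # value is recovered from the owning rank by the all-reduce that follows.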
        # subtract the max so that sum(exp(...)) does not overflow and log() does not return nan
        handle.wait()
        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
        # reshape the logits and target
        # reshape the vocab_logits to [batch_size * seq_len, vocab_size]
        # reshape the labels to [batch_size * seq_len]
        self_vocab_size = vocab_logits.size()[-1]
        logits_2d = vocab_logits.view(-1, self_vocab_size)

        # extract x[class] and set x[other device] to zero
        idx = torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device)
        pred_logits_1d = logits_2d[idx, masked_target_1d].contiguous()
        pred_logits = pred_logits_1d.view_as(target)
        pred_logits[mask] = 0.0

        # all-reduce to get the full x[i, y]
        handle = dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group, async_op=True)
        exp_logits = vocab_logits
        torch.exp(vocab_logits, out=exp_logits)
        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

        # calculate the loss
        # loss = log(sum(exp(x[i]))) - x[class]
        handle.wait()
        loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
        if mode == "mean":
            num_non_zero = torch.sum(loss != 0.0)
            ctx.inv_num_non_zero = 1.0 / num_non_zero
            loss = torch.sum(loss).div_(num_non_zero)
        else:
            loss = torch.sum(loss)

        # calculate the softmax for the backward pass
        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
        exp_logits[target == ignore_index] = 0.0
        ctx.save_for_backward(exp_logits, mask, masked_target_1d)
        ctx.dtype = dtype
        ctx.mode = mode

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # retrieve the saved tensors
        if ctx.mode == "mean":
            grad_output = grad_output * ctx.inv_num_non_zero
        exp_logits, mask, masked_target_1d = ctx.saved_tensors

        # use the saved softmax (exp_logits) as the input grad
        grad_logits = exp_logits
        partition_vocab_size = grad_logits.shape[-1]
        grad_logits_2d = grad_logits.view(-1, partition_vocab_size)

        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
        grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

        grad_logits.mul_(grad_output.unsqueeze(dim=-1))
        return grad_logits, None, None, None, None, None, None
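

# Illustrative sketch, not part of the original module: DistCrossEntropy.forward relies on
# the identity CE = log(sum_i exp(x_i)) - x[class], which is unchanged when max(x) is
# subtracted from every logit, keeping exp() in a representable range. The helper below is
# a hedged single-device reference of that rewrite for intuition only; its name and
# signature are assumptions, not ColossalAI API.
def _single_device_stable_cross_entropy_sketch(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Numerically stable cross entropy on unsharded logits of shape [N, vocab_size]."""
    logits_max = logits.max(dim=-1, keepdim=True)[0]
    shifted = logits - logits_max  # same loss value, bounded exp()
    log_sum_exp = torch.log(torch.exp(shifted).sum(dim=-1))
    picked = shifted[torch.arange(shifted.size(0), device=logits.device), target]
    return (log_sum_exp - picked).mean()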


def cross_entropy_1d(
    vocab_logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = _IGNORE_IDX,
    process_group: ProcessGroup = None,
    vocab_size: int = None,
    dtype: torch.dtype = None,
    mode: str = "mean",
) -> torch.Tensor:
    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)
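

# Hedged usage sketch (an assumption, not part of the original module): under tensor
# parallelism each rank holds a vocabulary shard of the LM-head output, so cross_entropy_1d
# is called on the local shard with the TP process group and the *global* vocabulary size.
# `tp_group` is assumed to be an already-initialized ProcessGroup.
def _cross_entropy_1d_usage_sketch(
    shard_logits: torch.Tensor,  # [batch_size * seq_len, vocab_size // tp_size], this rank's shard
    labels: torch.Tensor,  # [batch_size * seq_len], global token ids
    tp_group: ProcessGroup,
    global_vocab_size: int,
) -> torch.Tensor:
    return cross_entropy_1d(
        shard_logits,
        labels,
        process_group=tp_group,
        vocab_size=global_vocab_size,
        dtype=torch.float32,
        mode="mean",
    )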


def dist_cross_entropy(
    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
    logits: torch.Tensor,  # [B, S, Vocab_size]
    shard_config: ShardConfig,
    out_features: int,
    vocab_size: int,
    dtype: torch.dtype,
    seq_dim: int = 1,
) -> torch.Tensor:
    """
    Helper to compute cross entropy loss for most shardformer models supporting PP, TP and SP.
    """
    # Split labels if not gather output
    sp_group = shard_config.sequence_parallel_process_group
    sp_rank = dist.get_rank(sp_group)
    sp_size = shard_config.sequence_parallel_size
    sp_mode = shard_config.sequence_parallelism_mode
    parallel_output = shard_config.parallel_output
    is_tp = shard_config.enable_tensor_parallelism
    is_packed = labels.dim() == 2
    if is_packed:
        bs, seq_len = labels.shape
    else:
        # padded sequence
        seq_len = labels.shape[-1]
        logits = logits.reshape(-1, *logits.shape[2:])
        seq_dim = 0

    # Shift labels to predict the next token, and remove the tail logit predicting <EOS>
    is_sp = sp_size > 1 and (not is_share_sp_tp(sp_mode))
    split_labels_here = seq_len // sp_size == logits.size(seq_dim)  # ring attn splits labels before forward

    if sp_mode == "ring_attn":
        # For Zigzag Ring Attention, labels should've been split and
        # shifted by RingAttention.prepare_varlen_batch()
        if sp_rank == 0:
            logits = logits[..., :-1, :]
            logits = torch.cat([logits, torch.full_like(logits[:, :1, :], _IGNORE_IDX)], dim=seq_dim)
    elif is_sp:
        # Shift only once: either before splitting or on the last rank without splitting
        if split_labels_here or (sp_rank == sp_size - 1):
            labels = labels[..., 1:]
        if split_labels_here:
            labels = labels.split(seq_len // sp_size, dim=-1)[sp_rank]

        if sp_rank == sp_size - 1:
            logits = logits[..., :-1, :]
            # Pad logits and labels to the same shape across all ranks for the TP all_reduce
            if is_tp and parallel_output:
                # If this is a packed sequence (label dim is 1), each seq already has the end label token padded.
                # torch.cat is faster than F.pad...
                pad_shape = (logits.shape[0], 1, *logits.shape[2:]) if is_packed else (1, *logits.shape[1:])
                padding = torch.full(pad_shape, _IGNORE_IDX, dtype=logits.dtype, device=logits.device)
                logits = torch.cat([logits, padding], dim=seq_dim)
                pad_shape = (labels.shape[0], 1) if is_packed else (1,)
                padding = torch.full(pad_shape, _IGNORE_IDX, dtype=labels.dtype, device=labels.device)
                labels = torch.cat([labels, padding], dim=seq_dim)
    else:
        labels = labels[..., 1:]
        logits = logits[..., :-1, :]
    labels = labels.contiguous()
    logits = logits.contiguous()
    num_nonzero = (labels != _IGNORE_IDX).sum()
    assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"

    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=_IGNORE_IDX, reduction="sum")
    labels = labels.view(-1)
    if is_tp and parallel_output:
        # Cross entropy with all-reduce for TP
        new_vocab_size = logits.shape[-1]
        logits = logits.view(-1, new_vocab_size)
        loss = cross_entropy_1d(
            logits,
            labels,
            process_group=shard_config.tensor_parallel_process_group,
            vocab_size=out_features,
            dtype=dtype,
            mode="sum",
        )
    else:
        # NOTE if using TP without parallel_output, the output is gathered in VocabParallelLMHead1D
        logits = logits.view(-1, vocab_size)
        loss = loss_fct(logits, labels)

    # Reduce the loss instead of gathering logits over the seq dim for savings
    if split_labels_here or sp_mode == "ring_attn":
        # Get the global non-zero count
        loss = torch.stack((loss, num_nonzero))
        # Rescale to offset the grad / (DP * SP) in HybridParallelPlugin
        loss = reduce_forward(loss, sp_group, grad_scale=sp_size)
        loss, num_nonzero = loss[0], loss[1].detach()
    loss = (loss / num_nonzero).squeeze()
    return loss
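

# Minimal single-process sketch (an assumption, not part of the original module) of the
# next-token convention applied by dist_cross_entropy above: the logit at position i is
# scored against the label at position i + 1, and the trailing logit is dropped.
if __name__ == "__main__":
    torch.manual_seed(0)
    bs, seq_len, vocab = 2, 8, 16
    example_logits = torch.randn(bs, seq_len, vocab)
    example_labels = torch.randint(0, vocab, (bs, seq_len))
    shifted_logits = example_logits[..., :-1, :].contiguous().view(-1, vocab)
    shifted_labels = example_labels[..., 1:].contiguous().view(-1)
    demo_loss = CrossEntropyLoss(ignore_index=_IGNORE_IDX)(shifted_logits, shifted_labels)
    print(f"single-process shifted cross entropy: {demo_loss.item():.4f}")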