Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-27 11:31:58 +00:00
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* Support overall loss, update KTO logging
* [Docs] clarify launch port
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Hotfix] README link (#5966)
* update ignore
* update readme
* run style
* update readme
* [Hotfix] Avoid fused RMSnorm import error without apex (#5985)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Chat] fix readme (#5989)
* fix readme
* fix readme, tokenization fully tested
* fix readme, tokenization fully tested
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix sync condition (#6000)
* [plugin] add cast inputs option for zero (#6003)
* [pre-commit.ci] pre-commit autoupdate (#5995)
updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](https://github.com/psf/black-pre-commit-mirror/compare/24.4.2...24.8.0)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] Bypass the huggingface bug to solve the mask mismatch problem (#5991)
* [Feature] Zigzag Ring attention (#5905)
* halfway
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* unified cross entropy func for all shardformer models
* remove redundant lines
* add basic ring attn; debug cross entropy
* fwd bwd logic complete
* fwd bwd logic complete; add experimental triton rescale
* precision tests passed
* precision tests passed
* fix typos and remove misc files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* add sp_mode to benchmark; fix varlen interface
* update softmax_lse shape by new interface
* change tester name
* remove buffer clone; support packed seq layout
* add varlen tests
* fix typo
* all tests passed
* add dkv_group; fix mask
* remove debug statements
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [misc] update compatibility (#6008)
* [misc] update compatibility
* [misc] update requirements
* [devops] disable requirements cache
* [test] fix torch ddp test
* [test] fix rerun on address in use
* [test] fix lazy init
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix the merge
* overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix the merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix the merge
* fix
* fix
* fix the merge
* fix
* [misc] Use dist logger in plugins (#6011)
* use dist logger in plugins
* remove trash
* print on rank 0
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* fix
* fix
* fix
* fix
* fix the merge
* fix
* fix
* fix
* fix
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: root <root@notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9-0.notebook-8f919155-6035-47b4-9c6f-1be133b9e2c9.colossal-ai.svc.cluster.local>
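The `[Auto Parallel]` commit above says the `DimSpec` difference dict is the same for every object, so a single shared copy is enough and it only needs to be built once. The sketch below illustrates that caching pattern under stated assumptions. The class `DimSpecSketch`, its four states (`R`, `S0`, `S1`, `S01`) and its cost function (the number of mesh axes on which two states differ) are all hypothetical. They are not ColossalAI's actual `DimSpec` API or cost model.

```python
from typing import Dict, Optional, Tuple

# Hypothetical, simplified sharding spec used only to illustrate the caching idea;
# this is NOT ColossalAI's DimSpec class or its real cost model.
_AXES = {"R": (), "S0": (0,), "S1": (1,), "S01": (0, 1)}


class DimSpecSketch:
    # Built once and shared by every instance, instead of being rebuilt
    # (or deep-copied) inside each __init__.
    _difference_dict: Optional[Dict[Tuple[str, str], int]] = None
    _build_count = 0  # only for demonstrating that the table is built once

    def __init__(self, state: str):
        if state not in _AXES:
            raise ValueError(f"unknown state {state!r}")
        self.state = state

    @classmethod
    def _get_difference_dict(cls) -> Dict[Tuple[str, str], int]:
        if cls._difference_dict is None:
            cls._build_count += 1
            # Made-up cost: number of mesh axes on which the two states differ.
            cls._difference_dict = {
                (a, b): len(set(_AXES[a]) ^ set(_AXES[b])) for a in _AXES for b in _AXES
            }
        return cls._difference_dict

    def difference(self, other: "DimSpecSketch") -> int:
        # Every instance reads the same shared table.
        return self._get_difference_dict()[(self.state, other.state)]


# Construct many specs: none of them builds or copies the table in __init__.
specs = [DimSpecSketch(s) for s in ("R", "S0", "S1", "S01") for _ in range(1000)]
total = sum(specs[0].difference(s) for s in specs)
```

Because the table lives on the class, construction cost no longer grows with the size of the table. The first `difference` call builds the table, and later calls only do a dictionary lookup.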
142 lines
5.6 KiB
Python
import math
from copy import copy

import torch
from torch.testing import assert_close

from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer.attn import invert_mask
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import get_current_device, set_seed

DTYPE = [torch.float16, torch.bfloat16]
B, N, S, D = 2, 8, 256, 32

# Per-dtype tolerances for assert_close; an empty dict falls back to torch's defaults.
TOL_MAP = {
    torch.float16: {"atol": 5e-4, "rtol": 2e-3},
    torch.bfloat16: {},
}


def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
    # Reference attention: softmax(q @ k^T / sqrt(head_dim) + attn_mask) @ v, with the softmax in fp32.
    head_dim = q.size(-1)
    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
    attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
    attn_output = torch.matmul(attn_weights, v)
    return attn_output


def gen_padded_kwargs(dtype: torch.dtype):
    # Sample 0 has its first S // 4 positions padded (left padding); sample 1 is unpadded.
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, : S // 4] = 0
    return (
        ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask),
        padding_mask,
    )


def gen_padded_causal_kwargs(dtype: torch.dtype):
    # Sample 0 has its second half padded (right padding), combined with a causal mask.
    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
    padding_mask[0, S // 2 :] = 0
    return (
        ColoAttention.prepare_attn_kwargs(
            (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True
        ),
        padding_mask,
    )


def gen_causal_kwargs(dtype: torch.dtype):
    return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None


def gen_custom_kwargs(dtype: torch.dtype):
    # 1 = attend, 0 = masked. Sample 0 gets a block-diagonal mask (two S/2 x S/2 blocks);
    # sample 1 may only attend to the first S // 4 keys.
    attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device())
    attn_mask[0, : S // 2, S // 2 :] = 0
    attn_mask[0, S // 2 :, : S // 2] = 0
    attn_mask[1, :, S // 4 :] = 0
    # Convert to an additive mask of shape [B, 1, S, S].
    attn_mask = invert_mask(attn_mask).unsqueeze(1)
    # Every query row must keep at least one zero (unmasked) entry.
    assert not torch.all(attn_mask != 0, dim=-1).any()
    return {"attention_mask": attn_mask}, None


def post_process_kwargs_for_raw_attn(attn_kwargs: dict):
    # Raw kernels take an `is_causal` flag instead of ColoAttention's `attention_mask_type`.
    if "attention_mask_type" in attn_kwargs:
        attn_kwargs = copy(attn_kwargs)
        mask_type = attn_kwargs.pop("attention_mask_type")
        attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
    return attn_kwargs


def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None):
    tols = TOL_MAP[dtype]
    q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
    q_flash = q.clone().detach().requires_grad_(True)
    k_flash = k.clone().detach().requires_grad_(True)
    v_flash = v.clone().detach().requires_grad_(True)
    attn_mask = attn_kwargs.get("attention_mask", None)
    ref_output = attention_ref(q, k, v, attn_mask)
    output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs)
    if padding_mask is not None:
        # Outputs at padded query positions are not compared: zero them in both results.
        # [B, Sq] -> [B, 1, Sq, 1]
        padding_mask = padding_mask[:, None, :, None].logical_not()
        ref_output = ref_output.masked_fill(padding_mask, 0)
        output = output.masked_fill(padding_mask, 0)

    assert_close(output, ref_output, **tols)
    output.mean().backward()
    ref_output.mean().backward()
    assert_close(q.grad, q_flash.grad, **tols)
    assert_close(k.grad, k_flash.grad, **tols)
    assert_close(v.grad, v_flash.grad, **tols)


@clear_cache_before_run()
@parameterize("dtype", DTYPE)
def test_flash_attn_func(dtype: torch.dtype):
    torch.backends.cudnn.deterministic = True
    set_seed(0)
    # (func, name, need_postprocess)
    avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
    for ext_cls in FlashAttentionLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_attn_funcs.append((ext.load(), ext.name, True))
    for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY:
        ext = ext_cls()
        if ext.is_available():
            ext.assert_compatible()
            avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))

    test_sets = {
        "none": (lambda dtype: ({}, None), avail_attn_funcs),
        "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs),
        "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs),
        "causal": (gen_causal_kwargs, avail_attn_funcs),
        "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs),
    }

    for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items():
        attn_kwargs, padding_mask = gen_kwargs_func(dtype)
        for attn_func, name, need_postprocess in attn_funcs:
            print(f"{dtype}, {name}, {mask_type}")
            if need_postprocess:
                check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
            else:
                check_attn_func(dtype, attn_func, attn_kwargs, padding_mask)


if __name__ == "__main__":
    test_flash_attn_func()
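`attention_ref` in the test computes softmax(q kᵀ / √d + mask) v, where the mask is added to the scores. Under the usual convention, a mask entry is zero to keep a position and a large negative number to block it. `gen_custom_kwargs` also asserts that no query row ends up fully masked. The sketch below is a minimal, dependency-free version of the same computation for a single head. It assumes a blocking value of `-1e9`; the test itself builds its mask with ColossalAI's `invert_mask`, whose exact constant is not shown here. The names `keep_to_additive`, `masked_attention` and `NEG` are hypothetical.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]

NEG = -1e9  # hypothetical "blocked" value for the additive mask


def keep_to_additive(keep: List[List[int]]) -> Matrix:
    """Turn a 1 (attend) / 0 (block) matrix into an additive mask."""
    return [[0.0 if k else NEG for k in row] for row in keep]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention(q: Matrix, k: Matrix, v: Matrix, mask: Optional[Matrix] = None) -> Matrix:
    """Single-head softmax(q k^T / sqrt(d) + mask) v, like attention_ref without dropout."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled dot product of query i with every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        w = softmax(scores)
        # Weighted sum of value rows.
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = keep_to_additive([[1, 0], [1, 1]])  # query 0 may only see key 0
out_causal = masked_attention(q, k, v, causal)
out_none = masked_attention(q, k, v)
out_full = masked_attention(q, k, v, keep_to_additive([[0, 0], [1, 1]]))  # row 0 fully blocked
```

Softmax does not change when the same constant is added to every score in a row. A row whose keys are all blocked with a finite value therefore behaves like an unmasked row (`out_full[0]` matches `out_none[0]`), and with `-inf` it would turn into NaN instead. That is one plausible reason the test rejects fully masked rows: different kernels may treat them differently.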