[coloattention] modify coloattention (#5627)

* modify coloattention

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-04-25 10:47:14 +08:00
Committed by: GitHub
Parent commit: 7ee569b05f
Commit: 148506c828

3 changed files with 16 additions and 22 deletions


@@ -4,11 +4,7 @@ from copy import copy
 import torch
 from torch.testing import assert_close
 
-from colossalai.kernel.kernel_loader import (
-    FlashAttentionLoader,
-    FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer.attn import invert_mask
 from colossalai.testing import clear_cache_before_run, parameterize
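
With FlashAttentionWithPaddingMaskLoader removed, the test only imports the two remaining loaders; padding-mask attention presumably goes through the general FlashAttentionLoader path. A minimal sketch of the consolidated usage, assuming only what the hunk shows (each loader exposes load(), which returns the best available kernel callable):

    from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader

    # load() probes the registered backends and returns the first usable kernel
    attn_func = FlashAttentionLoader().load()
    custom_mask_attn_func = FlashAttentionWithCustomMaskLoader().load()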
@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
         if ext.is_available():
             ext.assert_compatible()
             avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
-    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
-        ext = ext_cls()
-        if ext.is_available():
-            ext.assert_compatible()
-            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
     test_sets = {
         "none": (lambda dtype: ({}, None), avail_attn_funcs),