Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[coloattention] modify coloattention (#5627)
* modify coloattention
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -4,11 +4,7 @@ from copy import copy
 import torch
 from torch.testing import assert_close
 
-from colossalai.kernel.kernel_loader import (
-    FlashAttentionLoader,
-    FlashAttentionWithCustomMaskLoader,
-    FlashAttentionWithPaddingMaskLoader,
-)
+from colossalai.kernel.kernel_loader import FlashAttentionLoader, FlashAttentionWithCustomMaskLoader
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
+from colossalai.shardformer.layer.attn import invert_mask
 from colossalai.testing import clear_cache_before_run, parameterize
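The hunk above collapses the multi-line kernel_loader import into a single line (dropping FlashAttentionWithPaddingMaskLoader, whose registry loop is deleted in the next hunk) and pulls invert_mask into the test. As a rough illustration of what a mask-inversion helper of this kind usually does, here is a minimal sketch, assuming the common convention of mapping a {0, 1} keep-mask to an additive pre-softmax bias; invert_mask_sketch is a hypothetical stand-in, not ColossalAI's actual implementation of colossalai.shardformer.layer.attn.invert_mask:

```python
# Illustrative sketch only -- shows the usual semantics of a mask-inversion
# helper: turn a {0, 1} attention mask into an additive bias that can be
# summed onto pre-softmax attention scores.
import torch


def invert_mask_sketch(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Map 1 -> 0.0 (attend) and 0 -> -inf (masked out)."""
    inverted = torch.zeros_like(mask, dtype=dtype)
    inverted.masked_fill_(mask == 0, float("-inf"))
    return inverted


# Example: batch of 1, sequence length 4, last position padded out.
mask = torch.tensor([[1, 1, 1, 0]])
print(invert_mask_sketch(mask))
# tensor([[0., 0., 0., -inf]])
```

The additive form works because softmax sends a -inf score to a zero attention weight, so masked positions contribute nothing to the output.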
@@ -119,11 +115,6 @@ def test_flash_attn_func(dtype: torch.dtype):
         if ext.is_available():
             ext.assert_compatible()
             avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
-    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
-        ext = ext_cls()
-        if ext.is_available():
-            ext.assert_compatible()
-            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
 
     test_sets = {
         "none": (lambda dtype: ({}, None), avail_attn_funcs),
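The deleted loop is the loader-registry probe pattern this test applies to each backend: instantiate every extension class in REGISTRY, keep only those whose is_available() check passes, validate each with assert_compatible(), and collect the kernel returned by load(). A minimal runnable sketch of that pattern, where _DummyExtension and _DummyLoader are hypothetical stand-ins for ColossalAI's extension classes:

```python
# Minimal sketch of the loader/registry pattern the test exercises.
# Names other than is_available/assert_compatible/load/REGISTRY are
# hypothetical stand-ins, not ColossalAI APIs.
from typing import Callable, List, Tuple


class _DummyExtension:
    name = "dummy_flash_attn"

    def is_available(self) -> bool:
        # A real extension would check for CUDA, package versions, etc.
        return True

    def assert_compatible(self) -> None:
        # A real extension raises here if the runtime is unsupported.
        pass

    def load(self) -> Callable:
        # A real extension returns a compiled/loaded attention kernel.
        return lambda q, k, v, **kwargs: q


class _DummyLoader:
    REGISTRY = [_DummyExtension]


avail_attn_funcs: List[Tuple[Callable, str, bool]] = []
for ext_cls in _DummyLoader.REGISTRY:
    ext = ext_cls()
    if ext.is_available():
        ext.assert_compatible()
        avail_attn_funcs.append((ext.load(), ext.name, True))

print([name for _, name, _ in avail_attn_funcs])  # ['dummy_flash_attn']
```

Keying test_sets by mask type, with each entry holding a mask-factory callable and the list of kernels collected above, then lets the test pair every mask variant only with the kernels that support it.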