[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2

* [shardformer] update shardformer to use flash attention 2, fix
Author: flybird1111
Date: 2023-08-09 14:32:19 +08:00
Committed by: Hongxin Liu
Parent: ed4c448488
Commit: 7a3dfd0c64
9 changed files with 10 additions and 11 deletions

@@ -1,8 +1,9 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .mha.mha import ColoAttention
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax

 __all__ = [
-    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention'
+    'LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax', 'ColoAttention',
+    'AttnMaskType'
 ]
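
For context, a minimal sketch of how the two exports touched by this hunk (`ColoAttention` and the newly re-exported `AttnMaskType`) might be used together. The constructor and forward signatures below are assumptions inferred from the import paths in this diff, not confirmed by it; verify them against `mha/mha.py` and `scaled_softmax.py`.

```python
# Hypothetical usage sketch -- signatures are assumptions, not taken from this diff.
import torch
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

# Assumed constructor: (embed_dim, num_heads, dropout); check mha/mha.py.
attn = ColoAttention(embed_dim=64, num_heads=4, dropout=0.0)

# Flash attention expects (batch, seq_len, num_heads, head_dim) in fp16/bf16 on GPU.
q = torch.randn(2, 16, 4, 16, dtype=torch.float16, device='cuda')
k, v = torch.randn_like(q), torch.randn_like(q)

# Assumed forward: AttnMaskType.causal selects the triangular-mask attention path.
out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
```

Re-exporting `AttnMaskType` from the package root lets shardformer policies select the mask variant (e.g. `causal` vs `padding`) without importing from the `scaled_softmax` submodule directly.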