mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[shardformer] fix linear 1d row and support uneven splits for fused qkv linear (#6084)
* [tp] hotfix linear row * [tp] support uneven split for fused linear * [tp] support sp for fused linear * [tp] fix gpt2 mlp policy * [tp] fix gather fused and add fused linear row
This commit is contained in:
@@ -6,7 +6,7 @@ from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHe
|
||||
from .loss import cross_entropy_1d, dist_cross_entropy
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
from .qkv_fused_linear import FusedLinear1D_Col, FusedLinear1D_Row, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||
|
||||
__all__ = [
|
||||
"Embedding1D",
|
||||
@@ -34,4 +34,5 @@ __all__ = [
|
||||
"RingAttention",
|
||||
"get_pad_info",
|
||||
"all_to_all_comm",
|
||||
"FusedLinear1D_Row",
|
||||
]
|
||||
|
Reference in New Issue
Block a user