mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-25 19:21:17 +00:00
* [fix] fix qwen VocabParallelLMHead1D and gather output * fix tp bug * fix consumer * [feat] Support Distributed LogProb for GRPO Training * [fix] fix loss func * [fix] fix log prob plugin * [fix] fix qwen modeling param * [fix] rm comments * [fix] rm hard-code;fix non-dist version * [fix] fix test file param name and benchmark tp gather output=True/False * [fix] rm non-dist version in dist log prob * [fix] fix comments * [fix] fix dis log prob plugin * [fix] fix test case * [fix] fix qwen VocabParallelLMHead1D and gather output * [fix] fix DistLogProb comments * [fix] restore tp size * [fix] fix comments * [fix] fix comment; fix LogSoftmax usage --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com>
51 lines
1.5 KiB
Python
from ._operation import all_to_all_comm
|
|
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
|
|
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
|
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
|
from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
|
|
from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
|
|
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
|
from .parallel_module import ParallelModule
|
|
from .qkv_fused_linear import (
|
|
FusedLinear,
|
|
FusedLinear1D_Col,
|
|
FusedLinear1D_Row,
|
|
GPT2FusedLinearConv,
|
|
GPT2FusedLinearConv1D_Col,
|
|
GPT2FusedLinearConv1D_Row,
|
|
)
|
|
|
|
__all__ = [
|
|
"Embedding1D",
|
|
"VocabParallelEmbedding1D",
|
|
"LinearWithGradAccum",
|
|
"Linear1D_Col",
|
|
"Linear1D_Row",
|
|
"GPT2FusedLinearConv",
|
|
"GPT2FusedLinearConv1D_Row",
|
|
"GPT2FusedLinearConv1D_Col",
|
|
"DropoutForParallelInput",
|
|
"DropoutForReplicatedInput",
|
|
"cross_entropy_1d",
|
|
"dist_cross_entropy",
|
|
"dist_log_prob_1d",
|
|
"dist_log_prob",
|
|
"BaseLayerNorm",
|
|
"LayerNorm",
|
|
"RMSNorm",
|
|
"FusedLayerNorm",
|
|
"FusedRMSNorm",
|
|
"FusedLinear1D_Col",
|
|
"FusedLinear",
|
|
"ParallelModule",
|
|
"PaddingEmbedding",
|
|
"PaddingLMHead",
|
|
"VocabParallelLMHead1D",
|
|
"AttnMaskType",
|
|
"ColoAttention",
|
|
"RingAttention",
|
|
"get_pad_info",
|
|
"all_to_all_comm",
|
|
"FusedLinear1D_Row",
|
|
]
|