mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[feat] refactored extension module (#5298)
* [feat] refactored extension module * polish * polish * polish * polish * polish * polish * polish * polish * polish * polish
This commit is contained in:
@@ -8,13 +8,12 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.kernel import FusedScaleMaskSoftmax
|
||||
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
|
||||
from colossalai.legacy.context import seed
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK
|
||||
from colossalai.legacy.registry import LAYERS
|
||||
from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
Reference in New Issue
Block a user