[example] integrate seq-parallel tutorial with CI (#2463)
@@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
 
+        try:
+            from colossalai._C import scaled_masked_softmax
+        except ImportError:
+            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
+            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
+        self.scaled_masked_softmax = scaled_masked_softmax
+
         assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
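The hunk above replaces a hard dependency on the prebuilt colossalai._C extension with a lazy fallback: if the prebuilt kernel cannot be imported, it is JIT-compiled through the op builder and cached on the module instance. A minimal sketch of the same import pattern, with comments marking the two paths (that ScaledMaskedSoftmaxBuilder().load() returns the compiled kernel module is taken from the diff itself, not verified against the current API):

try:
    # Fast path: the CUDA kernel was prebuilt and ships with colossalai._C.
    from colossalai._C import scaled_masked_softmax
except ImportError:
    # Fallback: JIT-compile the kernel via the op builder on first use.
    from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
    scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()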
@@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):
 
         return probs
 
-    @staticmethod
-    def get_batch_per_block(sq, sk, b, np):
-        try:
-            import colossalai._C.scaled_masked_softmax
-        except ImportError:
-            raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
-
-        return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+    def get_batch_per_block(self, sq, sk, b, np):
+        return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
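With the kernel module stored on the instance, get_batch_per_block no longer needs to be a static method that hard-imports colossalai._C and raises when the extension is absent; it simply forwards to whichever module was loaded in __init__. A hedged usage sketch, assuming the import fallback from the previous sketch has produced scaled_masked_softmax (the numeric arguments are illustrative only; sq, sk, b and np are assumed to mean query length, key length, batch size and number of heads, following Megatron-style naming):

# Placeholder query/key lengths, batch size and head count.
batch_per_block = scaled_masked_softmax.get_batch_per_block(256, 256, 4, 16)
print(batch_per_block)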