[example] integrate seq-parallel tutorial with CI (#2463)

This commit is contained in:
Frank Lee
2023-01-13 14:40:05 +08:00
committed by GitHub
parent 8e85d2440a
commit 8b7495dd54
7 changed files with 72 additions and 170 deletions

View File

@@ -114,6 +114,13 @@ class FusedScaleMaskSoftmax(nn.Module):
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
try:
from colossalai._C import scaled_masked_softmax
except ImportError:
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
self.scaled_masked_softmax = scaled_masked_softmax
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
@@ -178,11 +185,5 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
try:
import colossalai._C.scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
def get_batch_per_block(self, sq, sk, b, np):
return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)