[kernel] added kernel loader to softmax autograd function (#3093)

* [kernel] added kernel loader to softmax autograd function

* [release] v0.2.6
This commit is contained in:
Frank Lee
2023-03-10 14:27:09 +08:00
committed by GitHub
parent fff98f06ed
commit 95a36eae63

View File

@@ -180,4 +180,9 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)