mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[builder] builder for scaled_upper_triang_masked_softmax (#2234)
This commit is contained in:
@@ -23,27 +23,20 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, scale):
|
||||
try:
|
||||
import colossalai._C.scaled_upper_triang_masked_softmax
|
||||
except ImportError:
|
||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||
from colossalai.kernel import scaled_upper_triang_masked_softmax
|
||||
|
||||
scale_t = torch.tensor([scale])
|
||||
softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
||||
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
||||
|
||||
ctx.save_for_backward(softmax_results, scale_t)
|
||||
return softmax_results
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
try:
|
||||
import colossalai._C.scaled_upper_triang_masked_softmax
|
||||
except ImportError:
|
||||
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
|
||||
from colossalai.kernel import scaled_upper_triang_masked_softmax
|
||||
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
|
||||
scale_t[0])
|
||||
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||
|
||||
return input_grads, None
|
||||
|
||||
|
Reference in New Issue
Block a user