fix format (#376)

This commit is contained in:
Jiang Zhuo 2022-03-10 17:15:59 +08:00 committed by Frank Lee
parent ce886a9062
commit 5a4a3b77d9
2 changed files with 26 additions and 23 deletions

View File

@ -27,12 +27,11 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors

View File

@ -81,6 +81,7 @@ class _ReduceGrad(torch.autograd.Function):
:param input_: input matrix
:param parallel_mode: parallel mode
"""
@staticmethod
def symbolic(graph, input_):
return input_
@ -102,6 +103,7 @@ class _ReduceInput(torch.autograd.Function):
:param input_: input matrix
:param parallel_mode: parallel mode
"""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@ -123,6 +125,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
:param parallel_mode: parallel mode
:param dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@ -146,6 +149,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
:param parallel_mode: parallel mode
:param dim: dimension
"""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)