mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-17 08:51:59 +00:00
fix format (#376)
This commit is contained in:
parent
ce886a9062
commit
5a4a3b77d9
@ -7,7 +7,7 @@ except:
|
|||||||
|
|
||||||
|
|
||||||
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||||
r"""
|
r"""
|
||||||
Layernorm
|
Layernorm
|
||||||
|
|
||||||
:param input: input maxtrix
|
:param input: input maxtrix
|
||||||
@ -20,27 +20,26 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
|||||||
:param eps: a value added to the denominator for numerical stability
|
:param eps: a value added to the denominator for numerical stability
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||||
ctx.normalized_shape = normalized_shape
|
ctx.normalized_shape = normalized_shape
|
||||||
ctx.eps = eps
|
ctx.eps = eps
|
||||||
input_ = input.contiguous()
|
input_ = input.contiguous()
|
||||||
weight_ = weight.contiguous()
|
weight_ = weight.contiguous()
|
||||||
bias_ = bias.contiguous()
|
bias_ = bias.contiguous()
|
||||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
|
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
|
||||||
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
|
bias_, ctx.eps)
|
||||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||||
|
grad_input = grad_weight = grad_bias = None
|
||||||
|
grad_input, grad_weight, grad_bias \
|
||||||
|
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||||
|
grad_output.contiguous(), mean, invvar,
|
||||||
|
input_, ctx.normalized_shape,
|
||||||
|
weight_, bias_, ctx.eps)
|
||||||
|
|
||||||
@staticmethod
|
return grad_input, grad_weight, grad_bias, None, None
|
||||||
def backward(ctx, grad_output):
|
|
||||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
|
||||||
grad_input = grad_weight = grad_bias = None
|
|
||||||
grad_input, grad_weight, grad_bias \
|
|
||||||
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
|
||||||
grad_output.contiguous(), mean, invvar,
|
|
||||||
input_, ctx.normalized_shape,
|
|
||||||
weight_, bias_, ctx.eps)
|
|
||||||
|
|
||||||
return grad_input, grad_weight, grad_bias, None, None
|
|
||||||
|
@ -81,6 +81,7 @@ class _ReduceGrad(torch.autograd.Function):
|
|||||||
:param input_: input matrix
|
:param input_: input matrix
|
||||||
:param parallel_mode: parallel mode
|
:param parallel_mode: parallel mode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(graph, input_):
|
def symbolic(graph, input_):
|
||||||
return input_
|
return input_
|
||||||
@ -102,6 +103,7 @@ class _ReduceInput(torch.autograd.Function):
|
|||||||
:param input_: input matrix
|
:param input_: input matrix
|
||||||
:param parallel_mode: parallel mode
|
:param parallel_mode: parallel mode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(graph, input_):
|
def symbolic(graph, input_):
|
||||||
return _reduce(input_)
|
return _reduce(input_)
|
||||||
@ -123,6 +125,7 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
|
|||||||
:param parallel_mode: parallel mode
|
:param parallel_mode: parallel mode
|
||||||
:param dim: dimension
|
:param dim: dimension
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(graph, input_):
|
def symbolic(graph, input_):
|
||||||
return _split(input_)
|
return _split(input_)
|
||||||
@ -146,6 +149,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
|
|||||||
:param parallel_mode: parallel mode
|
:param parallel_mode: parallel mode
|
||||||
:param dim: dimension
|
:param dim: dimension
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def symbolic(graph, input_):
|
def symbolic(graph, input_):
|
||||||
return _gather(input_)
|
return _gather(input_)
|
||||||
|
Loading…
Reference in New Issue
Block a user