diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py old mode 100644 new mode 100755 index 885d06e6d..07869e5ad --- a/colossalai/nn/layer/parallel_3d/_operation.py +++ b/colossalai/nn/layer/parallel_3d/_operation.py @@ -281,7 +281,7 @@ def vocab_parallel_classifier_3d( @torch.jit.script -def norm_forward(x, mean, sqr_mean, weight, bias, eps): +def norm_forward(x: Tensor, mean: Tensor, sqr_mean: Tensor, weight: Tensor, bias: Tensor, eps: float): mu = x - mean var = sqr_mean - mean**2 sigma = torch.sqrt(var + eps) @@ -292,7 +292,7 @@ def norm_forward(x, mean, sqr_mean, weight, bias, eps): @torch.jit.script -def norm_backward(grad, mu, sigma, weight): +def norm_backward(grad: Tensor, mu: Tensor, sigma: Tensor, weight: Tensor): # dbias, dweight = grad, grad * mu / sigma dz = grad * weight dmu = dz / sigma