mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
@@ -22,7 +20,7 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
eps: a value added to the denominator for numerical stability
|
||||
"""
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
@@ -31,8 +29,9 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_,
|
||||
bias_, ctx.eps)
|
||||
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
|
||||
input_, ctx.normalized_shape, weight_, bias_, ctx.eps
|
||||
)
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
return output
|
||||
|
||||
@@ -40,11 +39,9 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
= fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar,
|
||||
input_, ctx.normalized_shape,
|
||||
weight_, bias_, ctx.eps)
|
||||
grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
|
||||
grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps
|
||||
)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
@@ -195,8 +192,9 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
input_list = [
|
||||
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
|
||||
device=input_parallel.device).contiguous()
|
||||
output = torch.empty(
|
||||
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
||||
).contiguous()
|
||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# reduce-scatter scheduled first and have GPU resources allocated
|
||||
@@ -260,8 +258,9 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
|
||||
|
||||
# do reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||
assert (
|
||||
new_shape[dim] % dist.get_world_size(process_group) == 0
|
||||
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
|
||||
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
|
||||
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
@@ -329,8 +328,9 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
|
||||
input_list = [
|
||||
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
|
||||
]
|
||||
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
|
||||
device=input_parallel.device).contiguous()
|
||||
output = torch.empty(
|
||||
input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
|
||||
).contiguous()
|
||||
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
|
||||
# Delay the start of weight gradient computation shortly (3us) to have
|
||||
# reduce-scatter scheduled first and have GPU resources allocated
|
||||
@@ -473,9 +473,10 @@ def _split(input_, dim=-1, process_group=None):
|
||||
|
||||
# Split along last dimension.
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, \
|
||||
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
|
||||
f'cannot split tensor evenly'
|
||||
assert dim_size % world_size == 0, (
|
||||
f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
|
||||
f"cannot split tensor evenly"
|
||||
)
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
rank = dist.get_rank(process_group)
|
||||
@@ -502,7 +503,7 @@ def _gather(input_, dim=-1, process_group=None):
|
||||
|
||||
|
||||
def _reduce_scatter(input_, dim=1, process_group=None):
|
||||
""" Do reduce-scatter operation.
|
||||
"""Do reduce-scatter operation.
|
||||
|
||||
Args:
|
||||
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
|
||||
@@ -515,8 +516,9 @@ def _reduce_scatter(input_, dim=1, process_group=None):
|
||||
|
||||
# reduce-scatter
|
||||
new_shape = list(input_.shape)
|
||||
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
|
||||
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
|
||||
assert (
|
||||
new_shape[dim] % dist.get_world_size(process_group) == 0
|
||||
), f"The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). "
|
||||
new_shape[dim] = new_shape[dim] // world_size
|
||||
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
|
||||
dist.reduce_scatter(output, input_, group=process_group)
|
||||
@@ -532,20 +534,24 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
|
||||
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
|
||||
|
||||
|
||||
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
|
||||
overlap):
|
||||
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||
async_grad_reduce_scatter, dim, overlap)
|
||||
def linear_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
|
||||
):
|
||||
return _LinearWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
|
||||
)
|
||||
|
||||
|
||||
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
|
||||
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
|
||||
|
||||
|
||||
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
|
||||
overlap):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
|
||||
async_grad_reduce_scatter, dim, overlap)
|
||||
def matmul_gather_forward_reducescatter_backward(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
|
||||
):
|
||||
return _MatmulWithGatherForwardReduceScatterBackward.apply(
|
||||
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap
|
||||
)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, dim, process_group):
|
||||
|
Reference in New Issue
Block a user