Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-27 19:36:13 +00:00
[shardformer] optimize seq parallelism (#6086)
* [shardformer] optimize seq parallelism
* [shardformer] fix gpt2 fused linear col
* [plugin] update gemini plugin
* [plugin] update moe hybrid plugin
* [test] update gpt2 fused linear test
* [shardformer] fix gpt2 fused linear reduce
This commit is contained in:
parent 6b2c506fc5
commit dc2cdaf3e8
@@ -322,7 +322,6 @@ class GeminiPlugin(DPPluginBase):
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
        fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
@@ -366,7 +365,6 @@ class GeminiPlugin(DPPluginBase):
        enable_flash_attention: bool = False,
        enable_sequence_parallelism: bool = False,
        enable_jit_fused: bool = False,
        enable_sequence_overlap: bool = False,
        enable_async_reduce: bool = True,
        use_fp8: bool = False,
        verbose: bool = False,
@@ -428,7 +426,6 @@ class GeminiPlugin(DPPluginBase):
        self.enable_flash_attention = enable_flash_attention
        self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_overlap = enable_sequence_overlap
        self.verbose = verbose

        self.tp_size = tp_size
@@ -455,7 +452,6 @@ class GeminiPlugin(DPPluginBase):
            enable_flash_attention=self.enable_flash_attention,
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=self.enable_sequence_parallelism,
            enable_sequence_overlap=self.enable_sequence_overlap,
        )

    def __del__(self):
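For reference, a minimal usage sketch of the plugin after this change (illustrative values only; the enable_sequence_overlap argument removed above is simply no longer passed):

# Illustrative only; assumes the distributed backend was initialized, e.g. via colossalai.launch_from_torch().
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    tp_size=2,                         # sequence parallelism here requires tensor parallelism
    enable_sequence_parallelism=True,  # no overlap knob anymore after this commit
    enable_flash_attention=False,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)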
@@ -951,7 +951,6 @@ class HybridParallelPlugin(PipelinePluginBase):
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
        sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -1002,7 +1001,6 @@ class HybridParallelPlugin(PipelinePluginBase):
        enable_jit_fused: bool = False,
        enable_sequence_parallelism: bool = False,
        sequence_parallelism_mode: str = None,
        enable_sequence_overlap: bool = False,
        parallel_output: bool = True,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
@@ -1174,7 +1172,6 @@ class HybridParallelPlugin(PipelinePluginBase):
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            sequence_parallelism_mode=sequence_parallelism_mode,
            enable_sequence_overlap=enable_sequence_overlap,
            parallel_output=parallel_output,
            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
            gradient_checkpoint_config=gradient_checkpoint_config,
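A similar hedged sketch for the hybrid plugin, using the split_gather mode this commit optimizes (sizes and the boost call are illustrative, not part of the diff):

# Illustrative only; assumes 4 ranks launched via colossalai.launch_from_torch().
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="split_gather",  # "ring" and "all_to_all" are the other documented modes
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)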
@@ -140,7 +140,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
        sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather".
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
@@ -189,7 +188,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        enable_jit_fused: bool = False,
        enable_sequence_parallelism: bool = False,
        sequence_parallelism_mode: str = None,
        enable_sequence_overlap: bool = False,
        parallel_output: bool = True,
        num_microbatches: Optional[int] = None,
        microbatch_size: Optional[int] = None,
@@ -351,7 +349,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            enable_jit_fused=self.enable_jit_fused,
            enable_sequence_parallelism=enable_sequence_parallelism,
            sequence_parallelism_mode=sequence_parallelism_mode,
            enable_sequence_overlap=enable_sequence_overlap,
            parallel_output=parallel_output,
            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
            gradient_checkpoint_config=gradient_checkpoint_config,
@@ -102,7 +102,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_allreduce and fp8_communication:
        if fp8_communication or not ctx.async_grad_allreduce:
            _reduce(grad_input, group=ctx.process_group, fp8_communication=fp8_communication, fp8_format="e5m2")
        elif ctx.async_grad_allreduce:
            # Asynchronous all-reduce
@@ -216,10 +216,12 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
        for k in recv_tensors:
            send_tensors[k], recv_tensors[k] = recv_tensors[k], send_tensors[k]

    input_tensors = []
    output_tensors = []

    handles = communicate_step()
    # first round: special case, retrive from local tensor
    input_tensors.append(input_to_gather)
    output_tensors.append(func(**input_to_gather, **input_local))
    for i in range(group_size - 2):
        for handle in handles:
@@ -230,14 +232,25 @@ def _ring_as_gather(func, input_to_gather=None, input_local=None, process_group=
        handles = communicate_step()

        # actual computation
        input_tensors.append(send_tensors)
        output_tensors.append(func(**send_tensors, **input_local))

    # final round: special case, no need to send/recv again
    for handle in handles:
        handle.wait()
    input_tensors.append(send_tensors)
    output_tensors.append(func(**recv_tensors, **input_local))

    return torch.cat(output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim)
    gathered_input = {}
    for k in input_to_gather:
        input_shards = [d[k] for d in input_tensors[group_size - cur_rank :] + input_tensors[: group_size - cur_rank]]
        gathered_input[k] = torch.cat(input_shards, dim=gather_dim)

    gathered_output = torch.cat(
        output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank], dim=gather_dim
    )

    return gathered_output, gathered_input


class _GatherForwardReduceScatterBackward(torch.autograd.Function):
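To make the reordering at the end of _ring_as_gather concrete, here is a single-process sketch of the rotation; it assumes the ring produces shards in ascending rank order starting from the local rank (which is what the slice implies), and all values are made up for illustration:

import torch

# Pretend we are cur_rank = 1 in a group of size 4. The per-step results are collected in
# ring order; rotating the list by group_size - cur_rank restores global rank order before
# concatenation, exactly as in _ring_as_gather.
group_size, cur_rank, gather_dim = 4, 1, 0
ring_order = [(cur_rank + i) % group_size for i in range(group_size)]    # [1, 2, 3, 0]
output_tensors = [torch.full((2, 3), float(r)) for r in ring_order]      # one result per ring step

rotated = output_tensors[group_size - cur_rank :] + output_tensors[: group_size - cur_rank]
gathered_output = torch.cat(rotated, dim=gather_dim)                     # shards now ordered 0..3
assert [int(t[0, 0]) for t in rotated] == [0, 1, 2, 3]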
@@ -293,29 +306,30 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True, ring=False):
    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False):
        ctx.save_for_backward(input_, weight, bias)
        ctx.use_bias = bias is not None
        ctx.process_group = process_group
        ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
        ctx.dim = dim
        ctx.overlap = overlap

        if ring is True:
            input_to_gather = {"input": input_}
            input_local = {"weight": weight}

            output = _ring_as_gather(
            output, input_dict = _ring_as_gather(
                F.linear,
                input_to_gather=input_to_gather,
                input_local=input_local,
                process_group=process_group,
            )
            ctx.gathered_input = input_dict["input"]

            if bias is not None:
                output += bias
        else:
            input_parallel = _gather(input_, dim, process_group)
            ctx.gathered_input = input_parallel
            if bias is not None:
                output = F.linear(input_parallel, weight, bias)
            else:
@@ -329,100 +343,50 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
        use_bias = ctx.use_bias
        dim = ctx.dim
        process_group = ctx.process_group
        overlap = ctx.overlap

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
        if use_bias:
            bias = bias.view(bias.shape)

        if not overlap:
            input_parallel = _gather(input_, dim, process_group)
        input_parallel = ctx.gathered_input

            total_input = input_parallel
            grad_input = grad_output.matmul(weight)
            grad_output = grad_output.contiguous()
            # Convert the tensor shapes to 2D for execution compatibility
            if len(grad_output.shape) > 2:
                grad_output = grad_output.view(-1, grad_output.shape[-1])
                total_input = total_input.view(-1, total_input.shape[-1])
        total_input = input_parallel
        grad_input = grad_output.matmul(weight)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_reduce_scatter:
            # Asynchronous reduce-scatter
            input_list = [
                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
            ]
            output = torch.empty(
                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
            ).contiguous()
            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
                grad_weight = None
            elif grad.dtype == torch.float16:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
                grad_weight = None
            else:
                grad_weight = grad_output.t().matmul(total_input)
        else:
            grad_weight = grad_output.t().matmul(total_input)

        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_reduce_scatter:
            handle.wait()

        else:
            input_ = input_.contiguous()
            world_size = dist.get_world_size(process_group)
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

            # do all gather in is async way
            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
            # calculate gradient and prepare data asynchronously with all-gather
            # calculate
            grad_input = grad_output.matmul(weight)
            grad_output = grad_output.contiguous()
            # Convert the tensor shapes to 2D for execution compatibility
            if len(grad_output.shape) > 2:
                grad_output = grad_output.view(-1, grad_output.shape[-1])
            grad_bias = grad_output.sum(dim=0) if use_bias else None
            # prepare data
            if ctx.async_grad_reduce_scatter:
                # Asynchronous reduce-scatter
                input_list = [
                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
                ]
                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
            # wait until all-gather finished
            gather_handle.wait()
            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
            # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

            # do reduce-scatter in async way
            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
            # calculate gradient
            if len(input_parallel.shape) > 2:
                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])

            if _grad_accum_fusion_available and weight.grad is not None:
                grad = weight.grad
                if grad.dtype == torch.float32:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input_parallel, grad_output, grad)
                    grad_weight = None
                elif grad.dtype == torch.float16:
                    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input_parallel, grad_output, grad)
                    grad_weight = None
                else:
                    grad_weight = grad_output.t().matmul(input_parallel)
        if _grad_accum_fusion_available and weight.grad is not None:
            grad = weight.grad
            if grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
                grad_weight = None
            elif grad.dtype == torch.float16:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
                grad_weight = None
            else:
                grad_weight = grad_output.t().matmul(input_parallel)
                # grad_weight = grad_output.t().matmul(input_parallel)
            # wait until reduce-scatter finished
            reducescatter_handle.wait()
                grad_weight = grad_output.t().matmul(total_input)
        else:
            grad_weight = grad_output.t().matmul(total_input)

        return output, grad_weight, grad_bias, None, None, None, None, None
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_reduce_scatter:
            handle.wait()

        return output, grad_weight, grad_bias, None, None, None, None


def _ring_as_reducescatter(
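The communication/compute overlap this backward now uses in every case — launch an asynchronous reduce-scatter of grad_input, run the weight-gradient GEMM while it is in flight, then wait — can be sketched in isolation as follows (a hedged sketch, not the autograd class itself; the process group and an even chunking along the sequence dimension are assumed):

import torch
import torch.distributed as dist

def reduce_scatter_overlapped_wgrad(grad_output_2d, total_input_2d, grad_input, process_group, dim=1):
    """Sketch: overlap an async reduce-scatter of grad_input with the weight-gradient GEMM."""
    world_size = dist.get_world_size(process_group)
    # Split grad_input along the sequence dimension; each rank keeps one reduced shard.
    # Assumes the dimension divides evenly across ranks.
    input_list = [chunk.contiguous() for chunk in torch.chunk(grad_input, world_size, dim=dim)]
    output = torch.empty_like(input_list[dist.get_rank(process_group)])
    handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)

    # This matmul runs while the collective is in flight; with CUDA_DEVICE_MAX_CONNECTIONS=1
    # the communication kernel is scheduled first, as the comments in the diff note.
    grad_weight = grad_output_2d.t().matmul(total_input_2d)

    handle.wait()  # the scattered gradient for the local input shard is now ready
    return output, grad_weight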
@@ -553,7 +517,7 @@ class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])
            total_input = total_input.reshape(-1, total_input.shape[-1])
        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

@@ -611,34 +575,30 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
    """

    @staticmethod
    def forward(
        ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
    ):
    def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication):
        ctx.save_for_backward(input_, weight, bias)
        ctx.use_bias = bias is not None
        ctx.process_group = process_group
        ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
        ctx.dim = dim
        ctx.overlap = overlap
        ctx.fp8_communication = fp8_communication

        if ring is True:
            input_to_gather = {}
            input_local = {}
            input_to_gather["input"] = input_
            input_local["other"] = weight
            input_to_gather = {"input": input_}
            input_local = {"other": weight}

            output = _ring_as_gather(
            output, input_dict = _ring_as_gather(
                torch.matmul,
                input_to_gather=input_to_gather,
                input_local=input_local,
                process_group=process_group,
                gather_dim=dim,
            )
            ctx.gathered_input = input_dict["input"]

        else:
            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e4m3")

            ctx.gathered_input = input_parallel
            output = torch.matmul(input_parallel, weight)

        if bias is not None:
@@ -651,76 +611,39 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
        use_bias = ctx.use_bias
        dim = ctx.dim
        process_group = ctx.process_group
        overlap = ctx.overlap
        fp8_communication = ctx.fp8_communication

        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
        weight = weight.view(weight.shape)
        if use_bias:
            bias = bias.view(bias.shape)

        if not overlap:
            input_parallel = _gather(input_, dim, process_group, fp8_communication, fp8_format="e5m2")
        input_parallel = ctx.gathered_input

            total_input = input_parallel
            grad_input = grad_output.matmul(weight.T)
            grad_output = grad_output.contiguous()
            # Convert the tensor shapes to 2D for execution compatibility
            if len(grad_output.shape) > 2:
                grad_output = grad_output.view(-1, grad_output.shape[-1])
                total_input = total_input.view(-1, total_input.shape[-1])
        total_input = input_parallel
        grad_input = grad_output.matmul(weight.T)
        grad_output = grad_output.contiguous()
        # Convert the tensor shapes to 2D for execution compatibility
        if len(grad_output.shape) > 2:
            grad_output = grad_output.view(-1, grad_output.shape[-1])
            total_input = total_input.view(-1, total_input.shape[-1])

        if ctx.async_grad_reduce_scatter:
            # Asynchronous reduce-scatter
            input_list = [
                item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
            ]
            output = torch.empty(
                input_.shape, dtype=input_parallel.dtype, device=input_parallel.device
            ).contiguous()
            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
            # all-reduce scheduled first and have GPU resources allocated

        grad_weight = total_input.t().matmul(grad_output)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_reduce_scatter:
            handle.wait()

        else:
            world_size = dist.get_world_size(process_group)
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]

            # do all gather in is async way
            gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
            # calculate gradient and prepare data asynchronously with all-gather
            # calculate
            grad_input = grad_output.matmul(weight.T)
            grad_output = grad_output.contiguous()
            # Convert the tensor shapes to 2D for execution compatibility
            if len(grad_output.shape) > 2:
                grad_output = grad_output.view(-1, grad_output.shape[-1])
            grad_bias = grad_output.sum(dim=0) if use_bias else None
            # prepare data
            if ctx.async_grad_reduce_scatter:
                # Asynchronous reduce-scatter
                input_list = [
                    item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
                ]
                output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
            # wait until all-gather finished
            gather_handle.wait()
            output = torch.empty(input_.shape, dtype=input_parallel.dtype, device=input_parallel.device).contiguous()
            handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            # Rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
            # all-reduce scheduled first and have GPU resources allocated

            # do reduce-scatter in async way
            reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
            input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
            # calculate gradient
            if len(input_parallel.shape) > 2:
                input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
            grad_weight = input_parallel.t().matmul(grad_output)
            # wait until reduce-scatter finished
            reducescatter_handle.wait()
        grad_weight = total_input.t().matmul(grad_output)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        return output, grad_weight, grad_bias, None, None, None, None, None, None
        if ctx.async_grad_reduce_scatter:
            handle.wait()

        return output, grad_weight, grad_bias, None, None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
@@ -1050,10 +973,10 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre


def linear_gather_forward_reducescatter_backward(
    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False
):
    return _LinearWithGatherForwardReduceScatterBackward.apply(
        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring
        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring
    )


@@ -1070,10 +993,10 @@ def linear_reducescatter_forward_gather_backward(input_, weight, bias=None, proc


def matmul_gather_forward_reducescatter_backward(
    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False, fp8_communication=False
    input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring=False, fp8_communication=False
):
    return _MatmulWithGatherForwardReduceScatterBackward.apply(
        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring, fp8_communication
        input_, weight, bias, process_group, async_grad_reduce_scatter, dim, ring, fp8_communication
    )

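Call sites therefore drop the overlap flag; a hypothetical example of the new call shape (the wrapper and its internal import path are taken from this diff, everything else is a placeholder):

from torch import Tensor
from torch.distributed import ProcessGroup
from colossalai.shardformer.layer._operation import linear_gather_forward_reducescatter_backward

def column_linear_sp(input_: Tensor, weight: Tensor, bias: Tensor, pg: ProcessGroup, ring: bool = False) -> Tensor:
    # input_ is sharded along the sequence dimension (dim=1); the op all-gathers it in forward
    # and reduce-scatters the gradient in backward. There is no overlap argument anymore.
    return linear_gather_forward_reducescatter_backward(input_, weight, bias, pg, True, 1, ring=ring)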
@@ -23,17 +23,15 @@ from colossalai.tensor.d_tensor.api import (
)

from ._operation import (
    gather_forward_reducescatter_backward,
    gather_forward_split_backward,
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
    reduce_forward,
    reducescatter_forward_gather_backward,
    split_forward_gather_backward,
)
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp

__all__ = ["Linear1D_Col", "Linear1D_Row"]
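is_share_sp_tp is the helper these layers now use to merge the split_gather and ring branches. Its body is not part of this diff; judging from the call sites below, a plausible sketch is:

# Assumed shape of the helper (not shown in this commit): it only needs to say whether a
# sequence-parallel mode shares the tensor-parallel group, which split_gather and ring do.
def is_share_sp_tp(sp_mode: str) -> bool:
    return sp_mode in ("split_gather", "ring")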
@@ -55,7 +53,6 @@ class Linear1D_Col(ParallelModule):
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
@@ -78,7 +75,6 @@ class Linear1D_Col(ParallelModule):
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        overlap: torch.cuda.Stream = None,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
@@ -95,7 +91,6 @@ class Linear1D_Col(ParallelModule):
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.process_group = process_group
@@ -202,16 +197,15 @@ class Linear1D_Col(ParallelModule):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None

        if self.seq_parallel_mode == "split_gather":
            input_parallel = gather_forward_reducescatter_backward(
                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
            )
            output_parallel = linear_with_async_comm(
                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
            )
        elif self.seq_parallel_mode == "ring":
        if is_share_sp_tp(self.seq_parallel_mode):
            output_parallel = linear_gather_forward_reducescatter_backward(
                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
                input_parallel,
                self.weight,
                bias,
                self.process_group,
                True,
                self.seq_parallel_dim,
                ring=self.seq_parallel_mode == "ring",
            )
        else:
            output_parallel = linear_with_async_comm(
@@ -428,18 +422,13 @@ class Linear1D_Row(ParallelModule):
                handle.wait()
            output = torch.cat(output_parallel_list, dim=-1)
        else:
            if self.seq_parallel_mode == "split_gather":
                output_parallel = F.linear(input_, self.weight)
                output = reducescatter_forward_gather_backward(
                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                )
            elif self.seq_parallel_mode == "ring":
            if is_share_sp_tp(self.seq_parallel_mode):
                output = linear_reducescatter_forward_gather_backward(
                    input_,
                    self.weight,
                    process_group=self.process_group,
                    dim=self.seq_parallel_dim,
                    ring=True,
                    ring=self.seq_parallel_mode == "ring",
                )
            else:
                output_parallel = F.linear(input_, self.weight)
@@ -551,7 +540,6 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
            to all GPUs, otherwise, every GPU will have its output
            which is :math:`Y_i = XA_i`, defaults to False
        seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
        overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
        skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
            which is preserved for kernel fusion, defaults to False
        weight_initializer (`typing.Callable`):
@@ -25,19 +25,17 @@ from colossalai.tensor.d_tensor.api import (
)

from ._operation import (
    gather_forward_reducescatter_backward,
    linear_gather_forward_reducescatter_backward,
    linear_reducescatter_forward_gather_backward,
    linear_with_async_comm,
    matmul_gather_forward_reducescatter_backward,
    matmul_with_async_comm,
    reduce_backward,
    reduce_forward,
    reducescatter_forward_gather_backward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
from .utils import create_randomizer_with_offset, is_share_sp_tp

__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
@@ -222,10 +220,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        async_communication: bool = False,
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        overlap: bool = False,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
@@ -240,12 +236,10 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
        self.out_features = out_features
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.split_sizes = split_sizes
        self.process_group = process_group
        self.async_communication = async_communication
        self.fp8_communication = fp8_communication

        assert (
@@ -370,7 +364,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        if self.seq_parallel_mode == "split_gather":
        if is_share_sp_tp(self.seq_parallel_mode):
            input_parallel = input_
            output_parallel = matmul_gather_forward_reducescatter_backward(
                input_parallel,
@@ -379,31 +373,18 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
                self.process_group,
                True,
                1,
                self.overlap,
                fp8_communication=self.fp8_communication,
            )
        elif self.seq_parallel_mode == "ring":
            input_parallel = input_
            output_parallel = matmul_gather_forward_reducescatter_backward(
                input_parallel,
                self.weight,
                bias,
                self.process_group,
                True,
                1,
                self.overlap,
                True,
                ring=self.seq_parallel_mode == "ring",
                fp8_communication=self.fp8_communication,
            )
        elif self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
            # Set up backprop all-reduce.
            input_parallel = reduce_backward(input_, self.process_group)
            input_parallel = input_
            output_parallel = matmul_with_async_comm(
                input_parallel,
                self.weight,
                bias,
                self.process_group,
                self.async_communication,
                True,
                fp8_communication=self.fp8_communication,
            )
        else:
@@ -620,7 +601,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
        if self.seq_parallel_mode is None or self.seq_parallel_mode == "ring_attn":
            output_parallel = torch.matmul(input_, self.weight)
            output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
        elif self.seq_parallel_mode == "split_gather":
        elif is_share_sp_tp(self.seq_parallel_mode):
            output_parallel = torch.matmul(input_, self.weight)
            output = reducescatter_forward_gather_backward(
                output_parallel,
@@ -628,13 +609,6 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                1,
                self.fp8_communication,
            )
        elif self.seq_parallel_mode == "ring":
            output_parallel = torch.matmul(input_, self.weight)
            output = reducescatter_forward_gather_backward(
                output_parallel,
                self.process_group,
                1,
            )
        else:
            raise NotImplementedError(f"seq_parallel_mode={self.seq_parallel_mode} is not supported!")

@@ -691,7 +665,6 @@ class FusedLinear1D_Col(ParallelModule):
        gather_output: bool = False,
        seq_parallel_mode: str = None,
        seq_parallel_dim: int = 1,
        overlap: torch.cuda.Stream = None,
        skip_bias_add: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
@@ -706,7 +679,6 @@ class FusedLinear1D_Col(ParallelModule):
        self.gather_output = gather_output
        self.seq_parallel_mode = seq_parallel_mode
        self.seq_parallel_dim = seq_parallel_dim
        self.overlap = overlap
        self.skip_bias_add = skip_bias_add
        self.device = device
        self.split_sizes = split_sizes
@@ -830,16 +802,15 @@ class FusedLinear1D_Col(ParallelModule):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None

        if self.seq_parallel_mode == "split_gather":
            input_parallel = gather_forward_reducescatter_backward(
                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
            )
            output_parallel = linear_with_async_comm(
                input_parallel, self.weight, bias, self.process_group, False, fp8_communication=self.fp8_communication
            )
        elif self.seq_parallel_mode == "ring":
        if is_share_sp_tp(self.seq_parallel_mode):
            output_parallel = linear_gather_forward_reducescatter_backward(
                input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap, True
                input_parallel,
                self.weight,
                bias,
                self.process_group,
                True,
                self.seq_parallel_dim,
                ring=self.seq_parallel_mode == "ring",
            )
        else:
            output_parallel = linear_with_async_comm(
@@ -1031,18 +1002,13 @@ class FusedLinear1D_Row(ParallelModule):
            )
            input_ = split_forward_gather_backward_fused_qkv(input_, self.split_sizes, self.process_group)

        if self.seq_parallel_mode == "split_gather":
            output_parallel = F.linear(input_, self.weight)
            output = reducescatter_forward_gather_backward(
                output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
            )
        elif self.seq_parallel_mode == "ring":
        if is_share_sp_tp(self.seq_parallel_mode):
            output = linear_reducescatter_forward_gather_backward(
                input_,
                self.weight,
                process_group=self.process_group,
                dim=self.seq_parallel_dim,
                ring=True,
                ring=self.seq_parallel_mode == "ring",
            )
        else:
            output_parallel = F.linear(input_, self.weight)
@@ -73,7 +73,6 @@ class BertPolicy(Policy):
            )
            sp_mode = "split_gather"

        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
@@ -97,7 +96,6 @@ class BertPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -106,7 +104,6 @@ class BertPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -115,7 +112,6 @@ class BertPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -140,7 +136,6 @@ class BertPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "skip_bias_add": self.enable_bias_gelu_fused,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
@@ -57,7 +57,6 @@ class BloomPolicy(Policy):
            )
            sp_mode = "split_gather"

        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
@@ -78,7 +77,6 @@ class BloomPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -99,7 +97,6 @@ class BloomPolicy(Policy):
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -67,7 +67,6 @@ class ChatGLMPolicy(Policy):
                f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
            )
            sp_mode = "split_gather"
        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode in ["split_gather"]

        if sp_mode == "all_to_all":
@@ -127,7 +126,6 @@ class ChatGLMPolicy(Policy):
                    kwargs={
                        "seq_parallel_mode": sp_mode,
                        "seq_parallel_dim": 0,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -65,7 +65,6 @@ class GPT2Policy(Policy):
                f"For GPT2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
            )
            self.shard_config.sequence_parallelism_mode = sp_mode = "split_gather"
        overlap = self.shard_config.enable_sequence_overlap
        sp_partial_derived = sp_mode in ["split_gather", "ring"]
        use_flash_attention = self.shard_config.enable_flash_attention
        if self.shard_config.enable_tensor_parallelism:
@@ -94,7 +93,6 @@ class GPT2Policy(Policy):
                    kwargs={
                        "split_sizes": [self.model.config.hidden_size] * 3,
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -109,7 +107,6 @@ class GPT2Policy(Policy):
                    kwargs={
                        "split_sizes": [self.model.config.n_inner or 4 * self.model.config.hidden_size],
                        "seq_parallel_mode": sp_mode,
                        "overlap": overlap,
                        "skip_bias_add": self.enable_bias_gelu_fused,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
@@ -51,7 +51,6 @@ class GPTJPolicy(Policy):
            self.shard_config.enable_sequence_parallelism = False
            warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
            assert (
                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -76,7 +75,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.k_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -84,7 +82,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.q_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -92,7 +89,6 @@ class GPTJPolicy(Policy):
                    suffix="attn.v_proj",
                    target_module=col_nn.Linear1D_Col,
                    kwargs={
                        "overlap": overlap,
                        "fp8_communication": self.shard_config.fp8_communication,
                    },
                ),
@@ -26,7 +26,6 @@ class ShardConfig:
        enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
        fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism. Defaults to False.
@@ -44,7 +43,6 @@ class ShardConfig:
    enable_jit_fused: bool = False
    enable_sequence_parallelism: bool = False
    sequence_parallelism_mode: str = None
    enable_sequence_overlap: bool = False
    parallel_output: bool = True
    make_vocab_size_divisible_by: int = 64
    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
@@ -84,24 +82,12 @@ class ShardConfig:
                assert (
                    self.enable_tensor_parallelism
                ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is True"
            elif self.sequence_parallelism_mode in ["all_to_all"]:
                # assert (
                #     not self.enable_tensor_parallelism
                # ), f"sequence parallelism mode {self.sequence_parallelism_mode} can only be used when enable_tensor_parallelism is False"
                if self.enable_sequence_overlap:
                    self.enable_sequence_overlap = False
                    warnings.warn(
                        f"The enable_sequence_overlap flag will be ignored in sequence parallelism mode {self.sequence_parallelism_mode}"
                    )
        else:
            if self.sequence_parallelism_mode:
                self.sequence_parallelism_mode = None
                warnings.warn(
                    f"The sequence_parallelism_mode will be ignored when enable_sequence_parallelism is False"
                )
            assert (
                not self.enable_sequence_overlap
            ), f"enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True"

        # get the tensor parallel size
        if not self.enable_tensor_parallelism:
@@ -134,4 +120,3 @@ class ShardConfig:
        # This can cause non-in-place param sharding when used without ZeRO.
        # It may also slow down training when seq len is small. Plz enable manually.
        # self.enable_sequence_parallelism = True
        # self.enable_sequence_overlap = True
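A hedged example of a ShardConfig that exercises the optimized shared-TP sequence-parallel path after this change (the process group is assumed to come from the plugin or booster; values are illustrative):

from torch.distributed import ProcessGroup
from colossalai.shardformer import ShardConfig

def make_sp_shard_config(tp_group: ProcessGroup) -> ShardConfig:
    # Illustrative; enable_sequence_overlap is no longer a field after this commit.
    return ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
        enable_sequence_parallelism=True,
        sequence_parallelism_mode="split_gather",  # "ring" now shares the same layer code path
    )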
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
        return x


def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str):
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = Conv1D(192, 48).cuda()
    with ctx:
@@ -52,7 +52,6 @@ def check_linear_conv_1d_col(lazy_init: bool, seq_parallel_mode: str, overlap: b
        gather_output=True,
        seq_parallel_mode=seq_parallel_mode,
        split_sizes=[64] * 3,
        overlap=overlap,
    )

    assert linear.weight.shape == torch.Size([48, 192])
@@ -121,9 +120,8 @@ def check_linear_conv_1d_row(lazy_init: bool, seq_parallel_mode: bool):

@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", ["split_gather", None])
@parameterize("overlap", [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool, overlap: bool):
    check_linear_conv_1d_col(lazy_init, seq_parallel_mode, overlap)
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel_mode: bool):
    check_linear_conv_1d_col(lazy_init, seq_parallel_mode)
    check_linear_conv_1d_row(lazy_init, seq_parallel_mode)