shardformer fp8
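
This change threads an fp8_communication flag (default False) through GPT2FusedLinearConv1D_Col and GPT2FusedLinearConv1D_Row so the collectives in their forward passes (the backward all-reduce, all-gather, split, reduce-scatter, and forward all-reduce helpers) can optionally quantize their payloads to FP8 before communicating. FP8 trades accuracy for bandwidth, which is why the flag is opt-in; the standalone snippet below (not part of this commit; assumes PyTorch >= 2.1 for the float8 dtypes) shows the round-trip error such quantization introduces:

import torch

x = torch.randn(4, 1024)
fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for e4m3
scale = x.abs().amax() / fp8_max                 # per-tensor scale factor
x_fp8 = (x / scale).to(torch.float8_e4m3fn)      # quantize: 1 byte per element
x_back = x_fp8.to(torch.float32) * scale         # dequantize
print((x - x_back).abs().max())                  # typically on the order of 1e-2 here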
@@ -183,6 +183,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
     ):
         super().__init__()
 
@@ -197,6 +198,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
         self.n_fused = n_fused
         self.process_group = process_group
         self.async_communication = async_communication
+        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -314,14 +316,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
 
         if self.seq_parallel_mode is None:
             # Set up backprop all-reduce.
-            input_parallel = reduce_backward(input_, self.process_group)
+            input_parallel = reduce_backward(input_, self.process_group, fp8_communication=self.fp8_communication)
             output_parallel = matmul_with_async_comm(
-                input_parallel, self.weight, bias, self.process_group, self.async_communication
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                self.async_communication,
+                fp8_communication=self.fp8_communication,
             )
         elif self.seq_parallel_mode == "split_gather":
             input_parallel = input_
             output_parallel = matmul_gather_forward_reducescatter_backward(
-                input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
+                input_parallel,
+                self.weight,
+                bias,
+                self.process_group,
+                True,
+                1,
+                self.overlap,
+                fp8_communication=self.fp8_communication,
             )
         elif self.seq_parallel_mode == "ring":
             input_parallel = input_
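
For intuition, here is a minimal sketch of what an FP8 all-reduce can look like. The name and structure are illustrative, not the implementation ColossalAI dispatches to when fp8_communication=True. NCCL cannot reduce float8 directly, so each rank quantizes with a per-tensor scale, all-gathers the byte payloads plus scales, then dequantizes and sums locally (assumes torch >= 2.1 and an initialized process group):

import torch
import torch.distributed as dist

def all_reduce_fp8_sketch(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Illustrative sketch only, not ColossalAI's API.
    world_size = dist.get_world_size(group)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Per-tensor scale so the largest magnitude maps onto the fp8 range.
    scale = (tensor.abs().amax().clamp(min=1e-12) / fp8_max).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    # Exchange 1-byte payloads and scales, then reduce in full precision.
    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    out = torch.zeros_like(tensor)
    for p, s in zip(payloads, scales):
        out += p.view(torch.float8_e4m3fn).to(tensor.dtype) * s
    return out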
@@ -331,7 +345,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
 
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
         else:
             output = output_parallel
 
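
The gather path applies the same quantize-communicate-dequantize pattern to an all-gather; again a hedged sketch rather than the library's code:

import torch
import torch.distributed as dist

def all_gather_fp8_sketch(tensor: torch.Tensor, group=None, dim: int = -1) -> torch.Tensor:
    # Illustrative sketch only; assumes torch >= 2.1 and an initialized group.
    world_size = dist.get_world_size(group)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (tensor.abs().amax().clamp(min=1e-12) / fp8_max).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    chunks = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(chunks, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Dequantize each rank's shard with its own scale before concatenating.
    parts = [c.view(torch.float8_e4m3fn).to(tensor.dtype) * s
             for c, s in zip(chunks, scales)]
    return torch.cat(parts, dim=dim)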
@@ -379,6 +395,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
         stream_chunk_num: int = 1,
+        fp8_communication: bool = False,
     ):
         super().__init__()
 
@@ -392,6 +409,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         self.process_group = process_group
         self.seq_parallel_mode = seq_parallel_mode
         self.num_partitions = dist.get_world_size(self.process_group)
+        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -514,7 +532,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
             input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
         )
-        input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
+        input_ = split_forward_gather_backward(
+            input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+        )
 
         if self.stream_chunk_num > 1:
             if self.training:
@@ -535,13 +555,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reduce_forward(output_parallel, self.process_group)
+                output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+                output = reducescatter_forward_gather_backward(
+                    output_parallel,
+                    self.process_group,
+                    1,
+                    self.fp8_communication,
+                )
             elif self.seq_parallel_mode == "ring":
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
+                output = reducescatter_forward_gather_backward(
+                    output_parallel, self.process_group, 1, self.fp8_communication
+                )
 
         if not self.skip_bias_add:
             if self.bias is not None: