Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[Feature] llama shardformer fp8 support (#5938)
* add llama shardformer fp8
* Llama Shardformer Parity
* fix typo
* fix all reduce
* fix pytest failure
* fix reduce op and move function to fp8.py
* fix typo
```diff
@@ -84,6 +84,7 @@ class Linear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
         **kwargs,
     ):
         super().__init__(weight=weight, bias_=bias_, **kwargs)
@@ -98,6 +99,7 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
+        self.fp8_communication = fp8_communication

         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
```
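In short, the change threads a new `fp8_communication: bool = False` flag through `Linear1D_Col` and `Linear1D_Row`: the constructors store it, and the forward passes hand it to the communication ops (`linear_with_async_comm`, `gather_forward_split_backward`, `reduce_forward`, and friends) so the tensor-parallel collectives can run in fp8. A minimal usage sketch of the new flag, assuming a 2-rank tensor-parallel setup and the usual `from_native_module` conversion path (the setup details here are illustrative, not taken from the PR):

```python
# Minimal usage sketch. Assumes torch.distributed is already initialized
# (e.g. launched under torchrun with 2 ranks); not the PR's test code.
import torch.nn as nn
import torch.distributed as dist

from colossalai.shardformer.layer import Linear1D_Col

tp_group = dist.new_group(ranks=[0, 1])  # 2-way tensor-parallel group
linear = nn.Linear(1024, 4096)

# Shard the output dimension across the group; with fp8_communication=True
# the forward-pass collectives route through the fp8 helpers.
col_linear = Linear1D_Col.from_native_module(
    linear,
    process_group=tp_group,
    gather_output=True,
    fp8_communication=True,  # the flag added by this PR
)
```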
```diff
@@ -201,10 +203,12 @@ class Linear1D_Col(ParallelModule):
         bias = self.bias if not self.skip_bias_add else None

         if self.seq_parallel_mode is None:
-            output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
+            output_parallel = linear_with_async_comm(
+                input_parallel, self.weight, bias, self.process_group, True, fp8_communication=self.fp8_communication
+            )
         elif self.seq_parallel_mode == "split_gather":
             input_parallel = gather_forward_reducescatter_backward(
-                input_parallel, self.process_group, self.seq_parallel_dim
+                input_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
             )
             output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, False)
         elif self.seq_parallel_mode == "ring":
@@ -214,7 +218,9 @@ class Linear1D_Col(ParallelModule):

         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
         else:
             output = output_parallel

```
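The fp8 helpers these ops call live in `fp8.py` and are not part of this diff. Purely as a sketch of the idea behind an fp8 all-gather like the one `gather_forward_split_backward` would use when the flag is set (all names and details below are assumptions, not the PR's code):

```python
import torch
import torch.distributed as dist

def all_gather_fp8_sketch(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Illustrative sketch only; the real helpers in fp8.py may differ.

    Quantize to float8_e4m3fn with a per-tensor scale, all-gather the
    low-precision payload plus the scales, then dequantize and concatenate.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (tensor.abs().max() / fp8_max).clamp(min=1e-12).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn)

    world_size = dist.get_world_size(group)
    # fp8 dtypes are not supported by every comm backend, so ship raw bytes.
    byte_view = payload.view(torch.uint8)
    bufs = [torch.empty_like(byte_view) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(bufs, byte_view, group=group)
    dist.all_gather(scales, scale, group=group)

    # Dequantize each rank's shard and concatenate along the sharded dim.
    shards = [b.view(torch.float8_e4m3fn).to(tensor.dtype) * s
              for b, s in zip(bufs, scales)]
    return torch.cat(shards, dim=-1)
```

The payoff is halved communication volume versus bf16 at the cost of a quantization error per collective, which is why the feature stays opt-in behind a flag.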
```diff
@@ -264,6 +270,7 @@ class Linear1D_Row(ParallelModule):
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
         stream_chunk_num: int = 1,
+        fp8_communication: bool = False,
     ):
         super().__init__()

@@ -278,6 +285,7 @@ class Linear1D_Row(ParallelModule):
         self.seq_parallel_mode = seq_parallel_mode
         self.seq_parallel_dim = seq_parallel_dim
         self.num_partitions = dist.get_world_size(self.process_group)
+        self.fp8_communication = fp8_communication

         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
```
```diff
@@ -398,7 +406,9 @@ class Linear1D_Row(ParallelModule):
             ), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
                 input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
             )
-            input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
+            input_ = split_forward_gather_backward(
+                input_, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )

         if self.stream_chunk_num > 1:
             if self.training:
```
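`split_forward_gather_backward` splits the input along `dim` in the forward pass and all-gathers gradients in the backward pass, so the flag here only affects backward communication. A rough autograd sketch of that shape (illustrative names, not ColossalAI's implementation):

```python
import torch
import torch.distributed as dist

class SplitForwardGatherBackwardSketch(torch.autograd.Function):
    """Forward: keep this rank's chunk of the input along `dim`.
    Backward: all-gather the gradient chunks back to full size.
    With fp8_communication=True, only that backward all-gather would
    go through the fp8 path."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        rank = dist.get_rank(group)
        world_size = dist.get_world_size(group)
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad):
        world_size = dist.get_world_size(ctx.group)
        bufs = [torch.empty_like(grad) for _ in range(world_size)]
        dist.all_gather(bufs, grad.contiguous(), group=ctx.group)
        return torch.cat(bufs, dim=ctx.dim), None, None
```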
```diff
@@ -418,11 +428,11 @@ class Linear1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
-                output = reduce_forward(output_parallel, self.process_group)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = linear_with_async_comm(input_, self.weight, None, self.process_group, False)
                 output = reducescatter_forward_gather_backward(
-                    output_parallel, self.process_group, self.seq_parallel_dim
+                    output_parallel, self.process_group, self.seq_parallel_dim, fp8_communication=self.fp8_communication
                 )
             elif self.seq_parallel_mode == "ring":
                 output = linear_reducescatter_forward_gather_backward(
```
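`reduce_forward` is an all-reduce over the tensor-parallel group, and an all-reduce is the trickiest collective to run in fp8 because partial sums cannot safely be accumulated at such low precision (presumably the "fix reduce op" item in the commit message touches this path). A naive but correct sketch, again with hypothetical names rather than the PR's code:

```python
import torch
import torch.distributed as dist

def all_reduce_fp8_sketch(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Illustrative sketch only; not the PR's implementation.

    Sums cannot be accumulated directly in fp8 without large error, so
    gather every rank's quantized payload and accumulate in the original
    dtype. A real implementation would rather reduce-scatter the chunks
    to cut bandwidth instead of all-gathering full tensors.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (tensor.abs().max() / fp8_max).clamp(min=1e-12).reshape(1)
    payload = (tensor / scale).to(torch.float8_e4m3fn).view(torch.uint8)

    world_size = dist.get_world_size(group)
    bufs = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(bufs, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Accumulate the dequantized partial results in full precision.
    out = torch.zeros_like(tensor)
    for b, s in zip(bufs, scales):
        out += b.view(torch.float8_e4m3fn).to(tensor.dtype) * s
    return out
```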