[shardformer] supported fused qkv checkpoint (#4073)

commit 70c58cfd4f
parent 0803a61412
Author: Frank Lee
Date:   2023-06-23 16:07:09 +08:00

10 changed files with 420 additions and 88 deletions
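The patch's subject is checkpoint support for models whose attention projections are stored as one fused QKV matrix. As a rough illustration of the idea (plain torch, hypothetical helper names, not the shardformer API): when such a weight is sharded for tensor parallelism, Q, K and V must each be split independently and re-fused per rank, and that interleaving has to be undone when the per-rank shards are gathered back into a full checkpoint. The sketch assumes the fused weight is laid out as [3 * hidden, in_dim] and divides evenly across ranks.

import torch

def shard_fused_qkv(fused_weight: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # Split Q, K and V separately, keep this rank's slice of each, then re-fuse,
    # so every rank holds [q_shard; k_shard; v_shard] rather than a contiguous
    # 1/tp_size block of the fused matrix.
    q, k, v = fused_weight.chunk(3, dim=0)
    return torch.cat([t.chunk(tp_size, dim=0)[rank] for t in (q, k, v)], dim=0)

def gather_fused_qkv(shards) -> torch.Tensor:
    # Inverse operation, used when assembling a full checkpoint from per-rank shards.
    qs, ks, vs = zip(*(s.chunk(3, dim=0) for s in shards))
    return torch.cat([torch.cat(qs), torch.cat(ks), torch.cat(vs)], dim=0)

Concatenating naively sharded fused matrices would scramble the Q/K/V ordering, which is exactly what checkpoint conversion has to avoid.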


@@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device
from ._operation import (
    gather_forward_split_backward,
    linear_with_async_comm,
    reduce_input,
    reduce_forward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
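The import hunk above swaps reduce_input for reduce_forward; the matching call-site change appears in Linear1D_Row.forward further down. The reduction a row-parallel linear needs is an all-reduce of partial outputs in the forward pass that behaves as identity in the backward pass, so gradients are not summed twice. A minimal sketch of that pattern with a custom autograd function, illustrative only and not the body of colossalai's _operation module:

import torch
import torch.distributed as dist

class _ReduceForward(torch.autograd.Function):
    # All-reduce (sum) in forward, identity in backward.

    @staticmethod
    def forward(ctx, input_, process_group):
        dist.all_reduce(input_, group=process_group)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward argument; process_group gets None.
        return grad_output, None

def reduce_forward_sketch(input_, process_group):
    return _ReduceForward.apply(input_, process_group)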
@@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule):
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
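For context, the column-parallel forward shown above feeds a replicated input through a weight shard that holds a slice of the output features; linear_with_async_comm is used so the backward all-reduce of the input gradient can overlap with the weight-gradient computation. Ignoring that overlap, the math reduces to a plain F.linear on the local shard. A sketch with illustrative names:

import torch
import torch.nn.functional as F

def column_parallel_forward(input_: torch.Tensor,
                            weight_shard: torch.Tensor,
                            bias_shard: torch.Tensor = None,
                            skip_bias_add: bool = False):
    # input_ is replicated on every rank; weight_shard holds out_features / tp_size rows,
    # so the result is a column shard of the full output.
    if skip_bias_add:
        # Hand the bias back separately so it can be fused into a later kernel.
        return F.linear(input_, weight_shard), bias_shard
    return F.linear(input_, weight_shard, bias_shard)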
@@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule):
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)
        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')
        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()
        factory_kwargs = {'device': device, 'dtype': dtype}
        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
        sharded_weight = shard_colwise(weight, self.process_group)
        self.weight = sharded_tensor_to_param(sharded_weight)
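In this __init__ hunk the full [out_features, in_features] weight is sharded along its last dimension with shard_colwise and wrapped back into a parameter that carries its distributed layout via sharded_tensor_to_param. A bare-torch sketch of the per-rank slice, assuming an even split and ignoring the d_tensor bookkeeping:

import torch
import torch.distributed as dist

def shard_row_parallel_weight(full_weight: torch.Tensor, process_group) -> torch.nn.Parameter:
    # Linear1D_Row consumes an input that is already split along its feature
    # dimension, so each rank keeps only the matching columns of the weight.
    tp_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    local = full_weight.chunk(tp_size, dim=-1)[rank].contiguous()
    return torch.nn.Parameter(local)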
@@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule):
            output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
            output = reduce_input(output_parallel, self.process_group)
            output = reduce_forward(output_parallel, self.process_group)
        if not self.skip_bias_add:
            if self.bias is not None:
@@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule):
            return output
        else:
            return output, self.bias
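The last hunks show Linear1D_Row.forward computing a partial product from its local input and weight shards, summing the partials with reduce_forward, and then either adding the bias or returning it separately when skip_bias_add is set. A single-process sanity check of that decomposition, with the all-reduce replaced by an explicit sum:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp_size, in_features, out_features = 4, 16, 8
x = torch.randn(2, in_features)
w = torch.randn(out_features, in_features)

# One piece per "rank": split input and weight along the shared feature dimension.
x_shards = x.chunk(tp_size, dim=-1)
w_shards = w.chunk(tp_size, dim=-1)

# Each rank produces a partial output; reduce_forward would sum these across ranks.
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
output = torch.stack(partials).sum(dim=0)

assert torch.allclose(output, F.linear(x, w), atol=1e-5)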