Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 13:00:52 +00:00
[shardformer] supported fused qkv checkpoint (#4073)
@@ -15,12 +15,11 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device

from ._operation import (
    gather_forward_split_backward,
    linear_with_async_comm,
    reduce_input,
    reduce_forward,
    split_forward_gather_backward,
)
from .parallel_module import ParallelModule
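For readers unfamiliar with the d_tensor helpers imported above: shard_colwise and shard_rowwise split a full tensor across the tensor-parallel process group, and sharded_tensor_to_param wraps the local shard as a module parameter. The single-process sketch below uses plain torch.chunk, not ColossalAI's API, to illustrate the layouts they are assumed to produce: shard_rowwise along the first dimension, shard_colwise along the last.

import torch

# Illustration only: emulate the layouts that shard_colwise / shard_rowwise
# are assumed to produce, using plain torch.chunk on a single process.
world_size = 2
out_features, in_features = 8, 6
weight = torch.randn(out_features, in_features)

# Assumed shard_rowwise layout: split dim 0 (out_features) across ranks,
# the slicing a column-parallel Linear1D_Col weight needs.
rowwise_shards = torch.chunk(weight, world_size, dim=0)

# Assumed shard_colwise layout: split the last dim (in_features) across ranks,
# the slicing used for the row-parallel Linear1D_Row weight below.
colwise_shards = torch.chunk(weight, world_size, dim=-1)

print([s.shape for s in rowwise_shards])   # [torch.Size([4, 6]), torch.Size([4, 6])]
print([s.shape for s in colwise_shards])   # [torch.Size([8, 3]), torch.Size([8, 3])]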
@@ -148,9 +147,10 @@ class Linear1D_Col(ParallelModule):
        assert input_.shape[-1] == self.weight.shape[-1], \
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])

        # Set up backprop all-reduce.
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_

        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
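The forward above is the column-parallel path: every rank keeps the full input, multiplies it by its slice of the weight (split along the output dimension) via linear_with_async_comm, and produces a slice of the output. A minimal single-process sketch of that arithmetic, with illustrative names and plain F.linear standing in for the fused communication-aware op:

import torch
import torch.nn.functional as F

# Single-process illustration of column-parallel linear: every "rank" sees
# the full input, holds a slice of the weight along the output dimension,
# and produces a slice of the output (concatenating the slices recovers the
# full result, which is what gathering the output would do).
world_size = 2
batch, in_features, out_features = 4, 6, 8

x = torch.randn(batch, in_features)
full_weight = torch.randn(out_features, in_features)
full_bias = torch.randn(out_features)

weight_shards = torch.chunk(full_weight, world_size, dim=0)
bias_shards = torch.chunk(full_bias, world_size, dim=0)

partial_outputs = [F.linear(x, w, b) for w, b in zip(weight_shards, bias_shards)]
gathered = torch.cat(partial_outputs, dim=-1)

assert torch.allclose(gathered, F.linear(x, full_weight, full_bias), atol=1e-5)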
@@ -209,17 +209,14 @@ class Linear1D_Row(ParallelModule):
        self.parallel_input = parallel_input
        self.skip_bias_add = skip_bias_add
        self.process_group = process_group
        self.num_partitions = dist.get_world_size(self.process_group)

        if skip_bias_add and not bias:
            raise ValueError('cannot skip bias addition if bias is None')

        # Parameters.
        # Initialize weight.
        if device is None:
            device = get_current_device()

        factory_kwargs = {'device': device, 'dtype': dtype}

        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
        sharded_weight = shard_colwise(weight, self.process_group)
        self.weight = sharded_tensor_to_param(sharded_weight)
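This __init__ change materializes the full (out_features, in_features) weight, shards it column-wise over the process group, and registers only the local shard as the parameter. A toy single-process sketch of that pattern, using plain slicing where the real code uses shard_colwise and sharded_tensor_to_param (which, as assumed here, also record the sharding layout so checkpoints can be handled as full tensors):

import torch
import torch.nn as nn

# Minimal single-process sketch of the initialization pattern above: build
# the full (out_features, in_features) weight, then register only this
# rank's column slice as the module parameter. Plain slicing stands in for
# shard_colwise / sharded_tensor_to_param.
class ToyRowParallelLinear(nn.Module):

    def __init__(self, in_features: int, out_features: int, rank: int, world_size: int):
        super().__init__()
        full_weight = torch.empty(out_features, in_features)
        nn.init.kaiming_uniform_(full_weight)
        local_in = in_features // world_size
        shard = full_weight[:, rank * local_in:(rank + 1) * local_in]
        self.weight = nn.Parameter(shard.contiguous())

layer = ToyRowParallelLinear(in_features=8, out_features=4, rank=0, world_size=2)
print(layer.weight.shape)  # torch.Size([4, 4])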
@@ -327,8 +324,7 @@ class Linear1D_Row(ParallelModule):
            output = torch.cat(output_parallel_list, dim=-1)
        else:
            output_parallel = F.linear(input_, self.weight)
            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
            output = reduce_input(output_parallel, self.process_group)
            output = reduce_forward(output_parallel, self.process_group)

        if not self.skip_bias_add:
            if self.bias is not None:
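reduce_forward (replacing reduce_input here) sums the partial outputs across the tensor-parallel group. Below is a sketch of the autograd op usually assumed behind such a call: all-reduce in the forward pass, gradient passed through unchanged in the backward pass. The class and function names are illustrative, not ColossalAI's.

import torch
import torch.distributed as dist

# Sketch of a tensor-parallel "reduce in forward" op: sum partial outputs
# across the process group in forward, let gradients flow through unchanged
# in backward.
class _ReduceForward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, partial_output: torch.Tensor, process_group) -> torch.Tensor:
        output = partial_output.clone()
        if dist.is_initialized() and dist.get_world_size(process_group) > 1:
            dist.all_reduce(output, group=process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Identity for the tensor input, no gradient for the process group.
        return grad_output, None


def toy_reduce_forward(partial_output: torch.Tensor, process_group=None) -> torch.Tensor:
    return _ReduceForward.apply(partial_output, process_group)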
@@ -336,5 +332,3 @@ class Linear1D_Row(ParallelModule):
            return output
        else:
            return output, self.bias
            return output, self.bias
            return output, self.bias
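Putting the row-parallel forward together: the input is split along its feature dimension, each rank applies its column shard of the weight, the partial results are reduced, and the bias is applied once after the reduction (or returned separately when skip_bias_add is set). A single-process numeric check of that math, with torch.stack(...).sum(0) standing in for the all-reduce:

import torch
import torch.nn.functional as F

# Single-process check of the row-parallel forward path: split the input
# along its last dimension, multiply each slice by the matching column shard
# of the weight, sum the partial results (the job of reduce_forward), and
# add the bias exactly once after the reduction.
world_size = 2
batch, in_features, out_features = 4, 6, 8

x = torch.randn(batch, in_features)
weight = torch.randn(out_features, in_features)
bias = torch.randn(out_features)

x_shards = torch.chunk(x, world_size, dim=-1)
w_shards = torch.chunk(weight, world_size, dim=-1)

partial = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
output = torch.stack(partial).sum(dim=0) + bias   # all-reduce stand-in, then bias once

assert torch.allclose(output, F.linear(x, weight, bias), atol=1e-5)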