[shardformer] supported fused qkv checkpoint (#4073)

This commit is contained in:
Frank Lee
2023-06-23 16:07:09 +08:00
parent 0803a61412
commit 70c58cfd4f
10 changed files with 420 additions and 88 deletions

View File

@@ -15,7 +15,7 @@ from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
from ._operation import gather_forward_split_backward, reduce_input
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .utils import create_randomizer_with_offset
@@ -276,5 +276,5 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, self.process_group)
output = reduce_forward(output_parallel, self.process_group)
return output