[Tensor] add ColoTensor TP1Dcol Embedding (#899)

Ziyue Jiang
2022-04-28 17:45:06 +08:00
committed by GitHub
parent e46e423c00
commit 2c0d19d755
9 changed files with 173 additions and 27 deletions
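
This commit teaches ColoTensor to shard an embedding weight column-wise under 1D tensor parallelism (the new `ComputePattern.TP1DCol_Embedding`). For orientation, here is a hedged sketch of the forward pass a column-sharded embedding implies: each rank looks up tokens against its slice of the hidden dimension, then the slices are all-gathered and concatenated to rebuild the full activation. The helper name `col_parallel_embedding` and the all-gather placement are assumptions for illustration, not the patch's own code.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def col_parallel_embedding(ids: torch.Tensor, weight_shard: torch.Tensor) -> torch.Tensor:
    """Sketch of a TP1D column-parallel embedding forward (hypothetical helper).

    `weight_shard` is (vocab, hidden // world_size): each rank embeds with its
    slice of the hidden dim, then slices are concatenated across ranks.
    Assumes the default process group is initialized.
    """
    local = F.embedding(ids, weight_shard)    # (..., hidden // world_size)
    parts = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(parts, local)
    return torch.cat(parts, dim=-1)           # (..., hidden)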


@@ -142,14 +142,19 @@ class ColoTensor(object):
         if self._shard_pattern is not ShardPattern.NA:    # reshard
             self.gather()
         # Model Parameters
-        if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-            self._shard_1d(parallel_action=parallel_action, dim=-1)
-            self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
-        elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
-            self._shard_1d(parallel_action=parallel_action, dim=0)
-            self._shard_pattern = ShardPattern.Row
+        if self._shard_spec.num_action == 1:
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                self._shard_spec.compute_patterns[0])
+            if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
+                    ComputePattern.TP1DCol_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+                self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
+            elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
+                    ComputePattern.TP1DRow_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+                self._shard_pattern = ShardPattern.Row
+            else:
+                raise NotImplementedError

     def gather(self):
         assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
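
In short, the new dispatch reads the single ParallelAction from the shard spec and splits the weight 1D by pattern: TP1DRow_Linear and TP1DCol_Embedding shard the last dimension (the input features of a (out, in) Linear weight, or the embedding dimension of a (vocab, hidden) embedding weight), while TP1DCol_Linear and TP1DRow_Embedding shard dimension 0 (output features, or vocabulary rows). A minimal sketch of what such a 1D shard amounts to, assuming an initialized torch.distributed process group and an evenly divisible dimension; `shard_1d` is a hypothetical stand-in for `ColoTensor._shard_1d`, not the library's API:

import torch
import torch.distributed as dist

def shard_1d(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """Return this rank's 1D chunk of `tensor` along `dim` (hypothetical helper)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert tensor.size(dim) % world_size == 0, \
        'shard dim must divide evenly across tensor-parallel ranks'
    chunk = tensor.size(dim) // world_size
    # narrow() keeps only the slice [rank * chunk, (rank + 1) * chunk) along `dim`
    return tensor.narrow(dim, rank * chunk, chunk).contiguous()

# e.g. a (vocab, hidden) embedding weight under TP1DCol_Embedding keeps all
# vocab rows but only hidden // world_size columns:
#   weight = shard_1d(weight, dim=-1)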