[tensor] hijack addmm for colo tensor (#923)

* hijack addmm for colo tensor

* fix bugs

* polish unit test

* polish comments
This commit is contained in:
ver217
2022-05-09 18:55:49 +08:00
committed by GitHub
parent 534afb018a
commit 45b9124df4
5 changed files with 210 additions and 8 deletions

View File

@@ -142,12 +142,15 @@ class ColoTensor(object):
# Model Parameters
if self._shard_spec.num_action == 1:
parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
ComputePattern.TP1DCol_Embedding]:
if parallel_action.compute_pattern in [
ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
]:
self._shard_1d(parallel_action=parallel_action, dim=-1)
self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear().
elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
ComputePattern.TP1DRow_Embedding]:
# We bind our ComputePattern to the weight, which has to be transposed when linear() is called.
self._shard_pattern = ShardPattern.Col
elif parallel_action.compute_pattern in [
ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
]:
self._shard_1d(parallel_action=parallel_action, dim=0)
self._shard_pattern = ShardPattern.Row
else: