[tensor] refine linear and add gather for layernorm (#893)

* refine linear and add function to ColoTensor

* add gather for layernorm

* polish

* polish
Ziyue Jiang
2022-04-28 10:55:40 +08:00
committed by GitHub
parent 26c49639d8
commit cb182da7c5
7 changed files with 225 additions and 123 deletions
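In plain terms, the layernorm change gathers an activation that is sharded across the tensor-parallel group back into its full hidden dimension before LayerNorm runs, since LayerNorm normalizes over the whole hidden size. The sketch below is not the code from this commit; the helper name `gather_for_layernorm`, the sharding dimension, and the process group argument are assumptions used only for illustration.

import torch
import torch.distributed as dist

def gather_for_layernorm(x: torch.Tensor, group=None) -> torch.Tensor:
    # Hypothetical helper: reassemble a tensor that is sharded along its last
    # (hidden) dimension across the tensor-parallel group.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x.contiguous(), group=group)
    return torch.cat(shards, dim=-1)  # restore the full hidden dimension

# full_x = gather_for_layernorm(sharded_x)
# y = torch.nn.functional.layer_norm(full_x, normalized_shape=(full_x.shape[-1],))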


@@ -4,11 +4,16 @@ from colossalai.context.parallel_mode import ParallelMode
 class ComputePattern(Enum):
     Activation = 0  # TODO(jzy) A tmp place to store Activation info. Find a better place in future.
     TP1DRow = 1
     TP1DCol = 2
     ZeRO = 3
     DP = 4
 
+class ShardPattern(Enum):
+    NA = 0
+    Row = 1
+    Col = 2
+
 class ParallelAction(object):
@@ -18,6 +23,7 @@ class ParallelAction(object):
         self.parallel_mode = parallel_mode
         self.gather_out = gather_out
 
 class TensorSpec(object):
     """
     It contains two aspects of information:
@@ -42,8 +48,9 @@ class TensorSpec(object):
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.
-    def __init__(self, parallel_action_list: List[ParallelAction] = []):
+    def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
         self._parallel_action_list = parallel_action_list
+        self._shard_pattern = shard_pattern
         self.sort()
 
     @property
@@ -57,6 +64,10 @@ class TensorSpec(object):
     @property
     def compute_patterns(self):
         return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
 
+    @property
+    def shard_pattern(self):
+        return self._shard_pattern
+
     def sort(self):
         if len(self._parallel_action_list) > 0:
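To see how the new spec surface fits together, here is a brief usage sketch based only on the lines shown in this diff; the import path is an assumption, and ParallelAction is omitted because its full constructor is not visible here.

from colossalai.tensor import TensorSpec, ShardPattern  # import path is an assumption

# Default spec: no parallel actions and ShardPattern.NA, so existing callers keep working.
default_spec = TensorSpec()
assert default_spec.shard_pattern == ShardPattern.NA

# A spec for a row-sharded tensor; downstream ops (e.g. the layernorm gather)
# can branch on the new shard_pattern property.
row_spec = TensorSpec(shard_pattern=ShardPattern.Row)
if row_spec.shard_pattern != ShardPattern.NA:
    pass  # gather the shards back into a full tensor before running LayerNorm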