[tensor] refine linear and add gather for layernorm (#893)
* refine linear and add function to ColoTensor
* add gather for layernorm
* polish
* polish
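The gather is needed because LayerNorm normalizes over the full hidden dimension, so the sharded output of a 1D tensor-parallel linear must be gathered before LayerNorm is applied. A single-process toy illustration of why (plain PyTorch, not part of the commit; shapes are made up for the example):

import torch
import torch.nn.functional as F

# Simulate the hidden dim being sharded across 2 tensor-parallel ranks.
full = torch.randn(2, 8)
shards = full.chunk(2, dim=-1)

# Normalizing each shard separately uses per-shard mean/variance...
sharded_ln = torch.cat([F.layer_norm(s, s.shape[-1:]) for s in shards], dim=-1)
# ...while the correct result uses statistics over the full hidden dim.
full_ln = F.layer_norm(full, full.shape[-1:])

print(torch.allclose(sharded_ln, full_ln))  # False in general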
@@ -4,11 +4,16 @@ from colossalai.context.parallel_mode import ParallelMode
 
 
 class ComputePattern(Enum):
     Activation = 0  # TODO(jzy) A tmp place to store Activation info. Find a better place in future.
     TP1DRow = 1
     TP1DCol = 2
     ZeRO = 3
     DP = 4
 
+class ShardPattern(Enum):
+    NA = 0
+    Row = 1
+    Col = 2
+
 
 class ParallelAction(object):
@@ -18,6 +23,7 @@ class ParallelAction(object):
         self.parallel_mode = parallel_mode
+        self.gather_out = gather_out
 
 
 class TensorSpec(object):
     """
     It contains two aspects of information:
@@ -42,8 +48,9 @@ class TensorSpec(object):
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.
 
-    def __init__(self, parallel_action_list: List[ParallelAction] = []):
+    def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
         self._parallel_action_list = parallel_action_list
+        self._shard_pattern = shard_pattern
         self.sort()
 
     @property
@@ -57,6 +64,10 @@ class TensorSpec(object):
     @property
     def compute_patterns(self):
         return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
 
+    @property
+    def shard_pattern(self):
+        return self._shard_pattern
+
     def sort(self):
         if len(self._parallel_action_list) > 0:
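Taken together, the new pieces might be used like this. A minimal sketch, not from the commit: it assumes ParallelAction also accepts priority and compute_pattern arguments (implied by the compute_pattern attribute and TensorSpec.sort() above), and that these classes are importable from colossalai.tensor; ParallelMode.PARALLEL_1D is the existing 1D tensor-parallel mode.

from colossalai.context.parallel_mode import ParallelMode
# Assumed import path; the diff defines these classes but does not show
# which module re-exports them.
from colossalai.tensor import ComputePattern, ParallelAction, ShardPattern, TensorSpec

# A row-parallel linear whose output is gathered afterwards -- gather_out
# is the flag this commit adds to ParallelAction.
action = ParallelAction(
    priority=1,                              # assumed arg, consumed by TensorSpec.sort()
    compute_pattern=ComputePattern.TP1DRow,
    parallel_mode=ParallelMode.PARALLEL_1D,
    gather_out=True,                         # new in this commit
)

# shard_pattern is the other new piece: it records how the tensor's data
# is currently laid out; ShardPattern.NA (the default) means unsharded.
spec = TensorSpec([action], shard_pattern=ShardPattern.Row)

assert spec.shard_pattern == ShardPattern.Row
assert ComputePattern.TP1DRow in spec.compute_patterns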
||||