[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)

This commit is contained in:
Ziyue Jiang
2022-04-26 10:15:26 +08:00
committed by GitHub
parent 11f54c7b6b
commit 26d4ab8b03
6 changed files with 85 additions and 58 deletions

View File

@@ -1,8 +1,6 @@
from enum import Enum
from typing import Tuple, List
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
class ComputePattern(Enum):
TP1DRow = 1
@@ -12,17 +10,13 @@ class ComputePattern(Enum):
class ParallelAction(object):
priority = 0
compute_pattern = ComputePattern.DP
process_group = gpc.get_group(ParallelMode.DATA)
def __init__(self, priority, compute_pattern, process_group) -> None:
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
self.priority = priority
self.compute_pattern = compute_pattern
self.process_group = process_group
self.parallel_mode = parallel_mode
class TensorSpec(Enum):
class TensorSpec(object):
"""
It contains two aspects of information:
First, How are tensors distributed in Heterougenous memory space.
@@ -44,4 +38,28 @@ class TensorSpec(Enum):
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1DRow.
# After Linear Op, we split the tensors according to ZeRO.
parallel_action_list: List[ParallelAction] = []
def __init__(self, parallel_action_list: List[ParallelAction] = []):
self._parallel_action_list = parallel_action_list
self.sort()
@property
def parallel_action_list(self):
return self._parallel_action_list
@property
def num_action(self):
return len(self._parallel_action_list)
@property
def compute_patterns(self):
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
def sort(self):
if len(self._parallel_action_list) > 0:
self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
for parallel_action in self._parallel_action_list:
if parallel_action.compute_pattern == compute_pattern:
return parallel_action
return None