mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-12-03 04:13:20 +00:00
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Tuple, List
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
class ComputePattern(Enum):
|
||||
TP1DRow = 1
|
||||
@@ -12,17 +10,13 @@ class ComputePattern(Enum):
|
||||
|
||||
|
||||
class ParallelAction(object):
|
||||
priority = 0
|
||||
compute_pattern = ComputePattern.DP
|
||||
process_group = gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
def __init__(self, priority, compute_pattern, process_group) -> None:
|
||||
def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
|
||||
self.priority = priority
|
||||
self.compute_pattern = compute_pattern
|
||||
self.process_group = process_group
|
||||
self.parallel_mode = parallel_mode
|
||||
|
||||
|
||||
class TensorSpec(Enum):
|
||||
class TensorSpec(object):
|
||||
"""
|
||||
It contains two aspects of information:
|
||||
First, How are tensors distributed in Heterougenous memory space.
|
||||
@@ -44,4 +38,28 @@ class TensorSpec(Enum):
|
||||
# Before Linear Op, we gather the tensors according to ZeRO.
|
||||
# We perform Linear Op according to compute pattern of TP1DRow.
|
||||
# After Linear Op, we split the tensors according to ZeRO.
|
||||
parallel_action_list: List[ParallelAction] = []
|
||||
def __init__(self, parallel_action_list: List[ParallelAction] = []):
|
||||
self._parallel_action_list = parallel_action_list
|
||||
self.sort()
|
||||
|
||||
@property
|
||||
def parallel_action_list(self):
|
||||
return self._parallel_action_list
|
||||
|
||||
@property
|
||||
def num_action(self):
|
||||
return len(self._parallel_action_list)
|
||||
|
||||
@property
|
||||
def compute_patterns(self):
|
||||
return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
|
||||
|
||||
def sort(self):
|
||||
if len(self._parallel_action_list) > 0:
|
||||
self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
|
||||
|
||||
def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
for parallel_action in self._parallel_action_list:
|
||||
if parallel_action.compute_pattern == compute_pattern:
|
||||
return parallel_action
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user