[autoparallel] add sequential order to communication actions (#1735)

YuliangLiu0306
2022-10-20 18:48:18 +08:00
committed by GitHub
parent b893342f95
commit a4ce180e85
7 changed files with 293 additions and 90 deletions


@@ -4,11 +4,12 @@ from enum import Enum
from typing import Any, Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -84,6 +85,38 @@ class MemoryCost:
buffer: int = 0
class CommType(Enum):
"""
CommType describes the sequential order of a communication action relative to a computation action.
Meaning:
BEFORE: the communication action happens just before the computation operation.
AFTER: the communication action happens after the computation operation.
HOOK: the communication action performs the gradient all-reduce.
IMPLICIT: the communication action happens during kernel execution, e.g. in SyncBatchNorm.
"""
BEFORE = 0
AFTER = 1
HOOK = 2
IMPLICIT = 3
@dataclass
class CommAction:
"""
CommAction is used to record a communication action.
Args:
comm_spec: expresses the communication pattern and the process groups on which the communication action executes.
comm_type: describes the sequential order of the communication action relative to the computation action.
arg_index: records the position of the tensor that joins the communication. The node name or op_data cannot be used at runtime,
because graph transform passes may change the node's args.
"""
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
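
Taken together, CommType and CommAction let a strategy record not only what to communicate (comm_spec) but when (comm_type) and on which argument (arg_index). Below is a minimal, self-contained usage sketch that mirrors the definitions in this diff; it is an illustration, not code from the commit, and the None comm_spec stands in for a real CommSpec, whose construction is project-specific and not shown here.

# Usage sketch (assumption, not from this commit): how a strategy
# generator could record an all-reduce that must run after the compute op.
from dataclasses import dataclass
from enum import Enum

class CommType(Enum):
    BEFORE = 0
    AFTER = 1
    HOOK = 2
    IMPLICIT = 3

@dataclass
class CommAction:
    comm_spec: object = None    # a CommSpec in the real code; stubbed here
    comm_type: CommType = None
    arg_index: int = -1

# An output all-reduce scheduled after the kernel, acting on argument 0:
action = CommAction(comm_spec=None, comm_type=CommType.AFTER, arg_index=0)
assert action.comm_type is CommType.AFTER and action.arg_index == 0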
@dataclass
class ShardingStrategy:
"""
@@ -102,7 +135,7 @@ class ShardingStrategy:
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
communication_actions: Dict[OperationData, CommAction] = None
resharding_costs: Dict[Node, List[TrainCycleItem]] = None
@property
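
With communication_actions now mapping each OperationData to a full CommAction rather than a bare CommSpec, a runtime pass can dispatch on comm_type to decide when each action fires. The sketch below is a hedged illustration of that idea, reusing CommType and CommAction from the sketch above; the insert_/register_ helpers and the string keys are hypothetical stand-ins, not ColossalAI APIs.

# Hypothetical scheduling pass; the helpers below are illustrative
# stubs (assumptions), not part of the library.
def insert_comm_before(node, arg_index, comm_spec):
    print(f"before {node}: reshard arg {arg_index} via {comm_spec}")

def insert_comm_after(node, comm_spec):
    print(f"after {node}: communicate output via {comm_spec}")

def register_grad_allreduce_hook(node, comm_spec):
    print(f"hook on {node}: gradient all-reduce via {comm_spec}")

def schedule_comm_actions(communication_actions, node):
    for op_data, action in communication_actions.items():
        if action.comm_type is CommType.BEFORE:
            # Input resharding: runs before the compute node, on the
            # argument recorded by arg_index.
            insert_comm_before(node, action.arg_index, action.comm_spec)
        elif action.comm_type is CommType.AFTER:
            insert_comm_after(node, action.comm_spec)
        elif action.comm_type is CommType.HOOK:
            register_grad_allreduce_hook(node, action.comm_spec)
        # CommType.IMPLICIT needs no explicit scheduling: the kernel
        # itself (e.g. SyncBatchNorm) performs the communication.

# A string key stands in for a real OperationData here:
schedule_comm_actions({"output": action}, node="linear_1")

Dispatching on arg_index rather than on node or op_data names is what keeps this scheme robust to later graph transform passes, per the CommAction docstring above.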