[autoparallel] add sequential order to communication actions (#1735)
@@ -4,11 +4,12 @@ from enum import Enum
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
-from colossalai.tensor.shape_consistency import CommSpec
-from colossalai.tensor.sharding_spec import ShardingSpec
 from torch.fx.node import Node
 
-from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
+from colossalai.tensor.shape_consistency import CommSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
 
@@ -84,6 +85,38 @@ class MemoryCost:
     buffer: int = 0
 
 
+class CommType(Enum):
+    """
+    CommType describes the sequential order of a communication action and a computation action.
+
+    Meaning:
+        BEFORE: the communication action happens just before the computation operation.
+        AFTER: the communication action happens after the computation operation.
+        HOOK: the communication action is used to perform the gradient all-reduce.
+        IMPLICIT: the communication action happens during kernel execution, e.g. in SyncBatchNorm.
+    """
+    BEFORE = 0
+    AFTER = 1
+    HOOK = 2
+    IMPLICIT = 3
+
+
+@dataclass
+class CommAction:
+    """
+    CommAction is used to record a communication action.
+
+    Args:
+        comm_spec: expresses the communication pattern and the process groups that execute the communication action.
+        comm_type: describes the sequential order of the communication action and the computation action.
+        arg_index: records the position of the tensor that joins the communication; we cannot use the name of the
+            node or of op_data at runtime, because the node's args may be changed by graph transform passes.
+    """
+    comm_spec: CommSpec = None
+    comm_type: CommType = None
+    arg_index: int = -1
+
+
 @dataclass
 class ShardingStrategy:
     """
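This hunk carries the substance of the commit: each communication action now records where it sits relative to the computation (CommType) and which positional argument it communicates (arg_index). To illustrate how arg_index is meant to be used, here is a minimal runnable sketch with stand-in classes mirroring the ones added above; node_args, the placeholder comm_spec string, and the tensor names are hypothetical, not part of this commit:

from dataclasses import dataclass
from enum import Enum
from typing import Any


class CommType(Enum):
    BEFORE = 0
    AFTER = 1
    HOOK = 2
    IMPLICIT = 3


@dataclass
class CommAction:
    comm_spec: Any = None      # stand-in slot for colossalai's CommSpec
    comm_type: CommType = None
    arg_index: int = -1


# Hypothetical node args: positions survive graph transform passes, names may not.
node_args = ("activation_tensor", "weight_tensor")
action = CommAction(comm_spec="all-reduce over TP group",   # placeholder spec
                    comm_type=CommType.BEFORE,
                    arg_index=0)

if action.arg_index >= 0:
    tensor_to_comm = node_args[action.arg_index]   # look the tensor up by position
    # a real pass would now execute action.comm_spec on tensor_to_comm,
    # before the op runs, because action.comm_type is CommType.BEFORE

Looking tensors up by index rather than by node name is the point of arg_index: graph transform passes may rename or rewrite nodes, but the positional slot of the communicating tensor stays stable.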
@@ -102,7 +135,7 @@ class ShardingStrategy:
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    communication_actions: Dict[OperationData, CommSpec] = None
+    communication_actions: Dict[OperationData, CommAction] = None
     resharding_costs: Dict[Node, List[TrainCycleItem]] = None
 
     @property
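Since the dict values are now CommAction rather than bare CommSpec, consumers of communication_actions have to unwrap them. A hedged sketch of that call-site migration, reusing the stand-in CommType/CommAction from the sketch above; Strategy here is an illustrative stand-in for ShardingStrategy, not the project's class:

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Strategy:
    # mirrors the retyped field: Dict[OperationData, CommAction] in the real class
    communication_actions: Dict[str, CommAction] = field(default_factory=dict)


strategy = Strategy(communication_actions={
    "input": CommAction(comm_spec="all-gather along dim 0",   # placeholder spec
                        comm_type=CommType.BEFORE,
                        arg_index=0),
})

for op_data, action in strategy.communication_actions.items():
    # old consumers used the dict value directly as a CommSpec; now they read
    # action.comm_spec and honor the sequential order given by action.comm_type
    if action.comm_type is CommType.BEFORE:
        print(f"insert {action.comm_spec!r} before the op consuming {op_data}")
    elif action.comm_type is CommType.AFTER:
        print(f"insert {action.comm_spec!r} after the op producing {op_data}")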