mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-08 11:27:24 +00:00
[autoparallel] implemented linear projection strategy generator (#1639)
This commit is contained in:
parent
154d3ef432
commit
45b39a692a
@ -1,46 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from .node_handler import ModuleHandler, NodeHandler
|
from .node_handler import ModuleHandler, NodeHandler
|
||||||
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
|
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
|
||||||
|
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from .registry import operator_registry
|
from .registry import operator_registry
|
||||||
|
|
||||||
__all__ = ['LinearModuleHandler']
|
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
|
||||||
|
|
||||||
|
|
||||||
class DotProductStrategyGenerator(StrategyGenerator_V2):
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MatVecStrategyGenerator(StrategyGenerator_V2):
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
|
|
||||||
|
|
||||||
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def validate(self, *args, **kwargs) -> bool:
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
|
|
||||||
"""TODO: to be implemented"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@operator_registry.register(torch.nn.Linear)
|
@operator_registry.register(torch.nn.Linear)
|
||||||
@ -49,9 +15,10 @@ class LinearModuleHandler(ModuleHandler):
|
|||||||
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
||||||
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
generators = []
|
generators = []
|
||||||
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
|
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||||
return generators
|
return generators
|
||||||
|
|
||||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
@ -97,9 +64,10 @@ class LinearFunctionHandler(NodeHandler):
|
|||||||
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
||||||
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
generators = []
|
generators = []
|
||||||
generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
|
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
|
||||||
return generators
|
return generators
|
||||||
|
|
||||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
@ -108,8 +76,15 @@ class LinearFunctionHandler(NodeHandler):
|
|||||||
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
physical_input_operand = OperationData(name=str(self.node.args[0]),
|
||||||
type=OperationDataType.ARG,
|
type=OperationDataType.ARG,
|
||||||
data=self.node.args[0]._meta_data)
|
data=self.node.args[0]._meta_data)
|
||||||
|
|
||||||
|
# check if the other operand is a parameter
|
||||||
|
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
|
||||||
|
data_type = OperationDataType.PARAM
|
||||||
|
else:
|
||||||
|
data_type = OperationDataType.ARG
|
||||||
|
|
||||||
physical_other_operand = OperationData(name=str(self.node.args[1]),
|
physical_other_operand = OperationData(name=str(self.node.args[1]),
|
||||||
type=OperationDataType.ARG,
|
type=data_type,
|
||||||
data=self.node.args[1]._meta_data,
|
data=self.node.args[1]._meta_data,
|
||||||
logical_shape=self.node.args[1]._meta_data.shape[::-1])
|
logical_shape=self.node.args[1]._meta_data.shape[::-1])
|
||||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||||
@ -117,8 +92,13 @@ class LinearFunctionHandler(NodeHandler):
|
|||||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||||
|
|
||||||
if self.node.args[2] is not None:
|
if self.node.args[2] is not None:
|
||||||
|
# check if the other operand is a parameter
|
||||||
|
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
|
||||||
|
data_type = OperationDataType.PARAM
|
||||||
|
else:
|
||||||
|
data_type = OperationDataType.ARG
|
||||||
physical_bias_operand = OperationData(name=str(self.node.args[2]),
|
physical_bias_operand = OperationData(name=str(self.node.args[2]),
|
||||||
type=OperationDataType.ARG,
|
type=data_type,
|
||||||
data=self.node.args[2]._meta_data)
|
data=self.node.args[2]._meta_data)
|
||||||
mapping['bias'] = physical_bias_operand
|
mapping['bias'] = physical_bias_operand
|
||||||
return mapping
|
return mapping
|
||||||
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2
|
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
|
||||||
|
from ..strategy import StrategyGenerator_V2
|
||||||
|
|
||||||
|
|
||||||
class NodeHandler(ABC):
|
class NodeHandler(ABC):
|
||||||
@ -26,14 +27,14 @@ class NodeHandler(ABC):
|
|||||||
self.successor_node = list(node.users.keys())
|
self.successor_node = list(node.users.keys())
|
||||||
self.device_mesh = device_mesh
|
self.device_mesh = device_mesh
|
||||||
self.strategies_vector = strategies_vector
|
self.strategies_vector = strategies_vector
|
||||||
self.strategy_generator = self.register_strategy_generator()
|
|
||||||
|
|
||||||
def register_strategy(self) -> StrategiesVector:
|
def register_strategy(self) -> StrategiesVector:
|
||||||
"""
|
"""
|
||||||
Register different sharding strategies for the current node.
|
Register different sharding strategies for the current node.
|
||||||
"""
|
"""
|
||||||
operand_mapping = self.get_operand_mapping()
|
strategy_generators = self.get_strategy_generator()
|
||||||
for generator in self.strategy_generator:
|
operand_mapping = self.get_operation_data_mapping()
|
||||||
|
for generator in strategy_generators:
|
||||||
strategies = generator.generate(operand_mapping)
|
strategies = generator.generate(operand_mapping)
|
||||||
self.strategies_vector.extend(strategies)
|
self.strategies_vector.extend(strategies)
|
||||||
|
|
||||||
@ -46,7 +47,7 @@ class NodeHandler(ABC):
|
|||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
|
||||||
"""
|
"""
|
||||||
Define which generators should be used by this NodeHandler object.
|
Define which generators should be used by this NodeHandler object.
|
||||||
"""
|
"""
|
||||||
@ -81,6 +82,8 @@ class ModuleHandler(NodeHandler):
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
print("created")
|
||||||
|
|
||||||
# set attributes to access module parameters for convenience
|
# set attributes to access module parameters for convenience
|
||||||
assert self.node.graph.owning_module is not None, \
|
assert self.node.graph.owning_module is not None, \
|
||||||
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
|
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
|
||||||
|
@ -7,6 +7,7 @@ from functools import reduce
|
|||||||
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||||
from typing import Dict, List, Union, Tuple, Any
|
from typing import Dict, List, Union, Tuple, Any
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from .constants import *
|
from .constants import *
|
||||||
@ -90,18 +91,12 @@ class TrainCycleItem:
|
|||||||
total: Any
|
total: Any
|
||||||
|
|
||||||
|
|
||||||
class CommunicationType(Enum):
|
|
||||||
FWD_ALL_REDUCE = 0
|
|
||||||
BWD_ALL_REDUCE = 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommunicationAction:
|
class MemoryCost:
|
||||||
"""
|
"""
|
||||||
The actions
|
|
||||||
"""
|
"""
|
||||||
type: CommunicationType
|
activation: int = 0
|
||||||
mesh_dim: int
|
parameter: int = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -126,7 +121,7 @@ class ShardingStrategy_V2:
|
|||||||
communication_cost: TrainCycleItem = None
|
communication_cost: TrainCycleItem = None
|
||||||
memory_cost: TrainCycleItem = None
|
memory_cost: TrainCycleItem = None
|
||||||
input_resharding_costs: Dict[OperationData, List[float]] = None
|
input_resharding_costs: Dict[OperationData, List[float]] = None
|
||||||
communication_actions: Dict[OperationData, List[CommunicationAction]] = None
|
communication_actions: Dict[OperationData, CommSpec] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
|
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
|
||||||
@ -152,79 +147,6 @@ class ShardingStrategy_V2:
|
|||||||
return specs
|
return specs
|
||||||
|
|
||||||
|
|
||||||
class StrategyGenerator_V2(ABC):
|
|
||||||
"""
|
|
||||||
StrategyGenerator is used to generate the same group of sharding strategies.
|
|
||||||
|
|
||||||
TODO: remove the original strategy_generator.py after refactoring
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, device_mesh: DeviceMesh):
|
|
||||||
self.device_mesh = device_mesh
|
|
||||||
|
|
||||||
def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
"""
|
|
||||||
Compute the communication cost involved in the forward and backward iteration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
comm_cost = TrainCycleItem(fwd=0, bwd=0)
|
|
||||||
|
|
||||||
def _compute_and_add(data: OperationData, action: CommunicationAction):
|
|
||||||
sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
|
|
||||||
dtype = operand.data.dtype
|
|
||||||
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
|
||||||
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
|
|
||||||
cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
|
|
||||||
|
|
||||||
# compute the fwd
|
|
||||||
if action.type == CommunicationType.FWD_ALL_REDUCE:
|
|
||||||
comm_cost.fwd += cost
|
|
||||||
elif action.type == CommunicationType.BWD_ALL_REDUCE:
|
|
||||||
comm_cost.fwd += cost
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Found unknown CommunicationType {action.type}")
|
|
||||||
|
|
||||||
# check if communication action exists
|
|
||||||
# if so, loop over each action and compute the cost of each action
|
|
||||||
if strategy.communication_actions is not None:
|
|
||||||
for operand, actions in strategy.communication_actions:
|
|
||||||
for action in actions:
|
|
||||||
_compute_and_add(operand, action)
|
|
||||||
|
|
||||||
# update the communication cost attribute in-place
|
|
||||||
strategy.communication_cost = comm_cost
|
|
||||||
return strategy
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
"""
|
|
||||||
Customize this method to compute the computation flops.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
|
||||||
"""
|
|
||||||
Customize this method to compute the memory cost in bytes.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
|
|
||||||
"""
|
|
||||||
Generate all possible sharding strategies for this operation.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate(self, *args, **kwargs) -> bool:
|
|
||||||
"""
|
|
||||||
Validate if the operands are of desired shape.
|
|
||||||
If True, means this generator can be used for the current operation.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StrategiesVector(list):
|
class StrategiesVector(list):
|
||||||
'''
|
'''
|
||||||
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
|
||||||
|
7
colossalai/auto_parallel/solver/strategy/__init__.py
Normal file
7
colossalai/auto_parallel/solver/strategy/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from .strategy_generator import StrategyGenerator_V2
|
||||||
|
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
|
||||||
|
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator'
|
||||||
|
]
|
@ -0,0 +1,364 @@
|
|||||||
|
from cmath import log
|
||||||
|
from distutils.log import Log
|
||||||
|
import operator
|
||||||
|
import torch
|
||||||
|
from functools import reduce
|
||||||
|
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
|
||||||
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||||
|
from .strategy_generator import StrategyGenerator_V2
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class DotProductStrategyGenerator(StrategyGenerator_V2):
    """
    Placeholder for the sharding strategy generator of dot-product operations.

    TODO: to be implemented
    """
|
||||||
|
|
||||||
|
|
||||||
|
class MatVecStrategyGenerator(StrategyGenerator_V2):
    """
    Placeholder for the sharding strategy generator of matrix-vector operations.

    TODO: to be implemented
    """
|
||||||
|
|
||||||
|
|
||||||
|
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
    """
    Generate the sharding strategies of a linear projection ``output = input x other (+ bias)``.

    The operands are looked up by name ("input", "other", "bias", "output") in ``self.op_data``.
    Strategy names follow the pattern ``<output> = <input> x <other>`` where ``S<k>`` means the
    dimension is sharded along device mesh axis ``k`` and ``R`` means it is replicated.

    NOTE(review): ``self.dim_p``, ``self.dim_q``, ``self.get_communication_spec`` and
    ``self._compute_size_in_bytes`` are used here but are not defined in this class; they are
    presumably provided by the base class or set elsewhere -- TODO confirm.
    """

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Compute the forward/backward multiplication count and store it in ``strategy.compute_cost``.

        Args:
            strategy (ShardingStrategy_V2): the strategy whose sharding specs are used to get the per-device shapes.

        Returns:
            ShardingStrategy_V2: the same strategy object, updated in-place.
        """
        # C = AB
        # C: [M, N], A: [M, P], B: [P, N]
        # fwd cost = MNP (only count mul)
        # bwd: 2 x fwd_cost
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()

        # all leading dimensions of the input are folded into M;
        # the initial value 1 keeps the reduction valid when the input has a single dimension
        dim_m_val = reduce(operator.mul, sharded_input_shape[:-1], 1)
        dim_n_val = sharded_other_shape[-1]
        dim_p_val = sharded_other_shape[0]

        fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
        bwd_compute_cost = fwd_compute_cost * 2

        # the forward entry must hold the forward cost (it previously held the backward cost)
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost
        return strategy

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """
        Compute the forward/backward/total memory cost and store it in ``strategy.memory_cost``.

        Args:
            strategy (ShardingStrategy_V2): the strategy to be updated.

        Returns:
            ShardingStrategy_V2: the same strategy object, updated in-place.
        """
        input_size = self._compute_size_in_bytes(strategy, "input")
        # the size of the second operand must come from "other" (it was previously read from "input")
        other_size = self._compute_size_in_bytes(strategy, "other")

        if "bias" in self.op_data:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
        else:
            bias_size = 0
        output_size = self._compute_size_in_bytes(strategy, "output")

        # NOTE(review): the backward and total formulas are kept exactly as written;
        # the total is not the sum of the fwd and bwd entries -- TODO confirm the intended accounting.
        fwd_mem_cost = MemoryCost(activation=output_size, parameter=other_size + bias_size)
        bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size)
        total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size,
                                    parameter=other_size + bias_size)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost
        return strategy

    def generate(self, operand_mapping=None) -> List[ShardingStrategy_V2]:
        """
        Generate all sharding strategies of the linear projection and fill in their cost information.

        Args:
            operand_mapping: accepted only so that callers which still pass the operation data mapping
                (e.g. ``generator.generate(operand_mapping)``) keep working. It is not used because the
                mapping is already given to the constructor and stored in ``self.op_data``.

        Returns:
            List[ShardingStrategy_V2]: the generated strategies.
        """
        strategies = []

        # SS = SR x RS
        strategies.append(self.split_lhs_space_rhs_space(0, 1))
        strategies.append(self.split_lhs_space_rhs_space(1, 0))

        # SR = SS x SR
        strategies.append(self.split_lhs_space_both_contract(0, 1))
        strategies.append(self.split_lhs_space_both_contract(1, 0))

        # RS = RS x SS
        strategies.append(self.split_rhs_space_both_contract(0, 1))
        strategies.append(self.split_rhs_space_both_contract(1, 0))

        # RR = RS x SR
        strategies.append(self.recompute_split_both_contract(0))
        strategies.append(self.recompute_split_both_contract(1))

        # RS = RR x RS
        strategies.append(self.split_rhs_space_only(0))
        strategies.append(self.split_rhs_space_only(1))

        # S01R = S01R x RR
        strategies.append(self.split_lhs_1st_dim_1d(0, 1))

        # RR = RS01 x S01R
        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))

        # RS01 = RR x RS01
        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

        # update meta info on cost
        for strategy in strategies:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategies

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``SS = SR x RS``.

        Args:
            mesh_dim_0: the mesh axis which shards the first dimension of the input and output.
            mesh_dim_1: the mesh axis which shards the ``dim_q`` dimension of other and the last dimension of bias and output.
        """
        # handle case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        # the spec registered under "other" must be built from the sharding spec of "other"
        # (it was previously built from the sharding spec of "output")
        other_comm_spec = self.get_communication_spec(
            sharding_spec_mapping["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``SR = SS x SR``.

        Args:
            mesh_dim_0: the mesh axis which shards the first dimension of the input and output.
            mesh_dim_1: the mesh axis which shards the last dimension of the input and the ``dim_p`` dimension of other.
        """
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        # get sharding spec mapping
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action mapping
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1)

        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``RS = RS x SS``.

        Args:
            mesh_dim_0: the mesh axis which shards the last dimension of the input and the ``dim_p`` dimension of other.
            mesh_dim_1: the mesh axis which shards the ``dim_q`` dimension of other and the last dimension of bias and output.
        """
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        # get sharding specs
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                self.dim_p: [mesh_dim_0],
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def recompute_split_both_contract(self, mesh_dim):
        """
        Build the strategy ``RR = RS x SR``.

        Args:
            mesh_dim: the mesh axis which shards the last dimension of the input and the ``dim_p`` dimension of other.
        """
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                self.dim_p: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_space_only(self, mesh_dim):
        """
        Build the strategy ``RS = RR x RS``.

        Args:
            mesh_dim: the mesh axis which shards the ``dim_q`` dimension of other and the last dimension of bias and output.
        """
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'input': input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``S01R = S01R x RR``, the first dimension of the input and output
        is sharded along both mesh axes.

        Args:
            mesh_dim_0: the first mesh axis used for sharding.
            mesh_dim_1: the second mesh axis used for sharding.
        """
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"other": other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``RR = RS01 x S01R``, the last dimension of the input and the ``dim_p``
        dimension of other are sharded along both mesh axes.

        Args:
            mesh_dim_0: the first mesh axis used for sharding.
            mesh_dim_1: the second mesh axis used for sharding.
        """
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {'output': output_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Build the strategy ``RS01 = RR x RS01``, the ``dim_q`` dimension of other and the last
        dimension of bias and output are sharded along both mesh axes.

        Args:
            mesh_dim_0: the first mesh axis used for sharding.
            mesh_dim_1: the second mesh axis used for sharding.
        """
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {'input': input_comm_spec}

        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def validate(self) -> bool:
        """
        Check that the operands stored in ``self.op_data`` form a valid linear projection.

        Returns:
            bool: True if all checks pass.

        Raises:
            AssertionError: if a required operand is missing or the shapes do not match.
        """
        assert "input" in self.op_data
        assert "other" in self.op_data

        # make sure the other has 2 dim
        input_data = self.op_data['input']
        other_data = self.op_data['other']
        assert input_data.data.dim() > 0 and other_data.data.dim() == 2
        assert other_data.logical_shape[0] == input_data.logical_shape[-1]

        # check if the bias has a valid last dimension
        has_bias = "bias" in self.op_data

        if has_bias:
            bias_data = self.op_data['bias']
            assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]

        # report success explicitly, the method previously fell through and returned None
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
    """
    Placeholder for the sharding strategy generator of batched matrix multiplication.

    TODO: to be implemented
    """
|
154
colossalai/auto_parallel/solver/strategy/strategy_generator.py
Normal file
154
colossalai/auto_parallel/solver/strategy/strategy_generator.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import operator
|
||||||
|
import torch
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
from functools import reduce
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from typing import Dict, List, Union, Any
|
||||||
|
from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyGenerator_V2(ABC):
|
||||||
|
"""
|
||||||
|
StrategyGenerator is used to generate the same group of sharding strategies.
|
||||||
|
|
||||||
|
TODO: remove the original strategy_generator.py after refactoring
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh):
    """
    Keep the operands and the device mesh which the strategies are generated for.

    Args:
        operation_data_mapping (Dict[str, OperationData]): the mapping between the operation data name and the OperationData object, stored as ``self.op_data``.
        device_mesh (DeviceMesh): the device mesh, stored as ``self.device_mesh``.
    """
    self.device_mesh = device_mesh
    self.op_data = operation_data_mapping
|
||||||
|
|
||||||
|
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
|
||||||
|
communication_action_mapping: Dict[str, CommSpec]):
|
||||||
|
"""
|
||||||
|
A factory method to produce a ShardingStrategy object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object.
|
||||||
|
communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object.
|
||||||
|
"""
|
||||||
|
sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
|
||||||
|
communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
|
||||||
|
return ShardingStrategy_V2(name=name,
|
||||||
|
sharding_specs=sharding_specs,
|
||||||
|
communication_actions=communication_actions)
|
||||||
|
|
||||||
|
def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
|
||||||
|
"""
|
||||||
|
A utility method to convert the the dim partition dict to a ShardingSpec object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for op_data_name, dim_partition_dict in mapping.items():
|
||||||
|
op_data = self.op_data[op_data_name]
|
||||||
|
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||||
|
entire_shape=op_data.logical_shape,
|
||||||
|
dim_partition_dict=dim_partition_dict)
|
||||||
|
results[op_data_name] = sharding_spec
|
||||||
|
return results
|
||||||
|
|
||||||
|
def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Convert the key of the dictionary from the operation data name to an OperationData object.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
for k, v in mapping.items():
|
||||||
|
op_data = self.op_data[k]
|
||||||
|
results[op_data] = v
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
|
||||||
|
logical_process_axis: Union[int, List[int]]):
|
||||||
|
"""
|
||||||
|
A factory method to produce a CommSpec object.
|
||||||
|
"""
|
||||||
|
# use flatten device mesh the same action is applied to two axes
|
||||||
|
if isinstance(logical_process_axis, list) and len(logical_process_axis) == 2:
|
||||||
|
sharding_spec.device_mesh = sharding_spec.device_mesh.flatten()
|
||||||
|
logical_process_axis = 0
|
||||||
|
return CommSpec(comm_pattern=communication_pattern,
|
||||||
|
sharding_spec=sharding_spec,
|
||||||
|
logical_process_axis=logical_process_axis)
|
||||||
|
|
||||||
|
def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
||||||
|
"""
|
||||||
|
Compute the communication cost involved in the forward and backward iteration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
comm_cost = TrainCycleItem(fwd=0, bwd=0)
|
||||||
|
|
||||||
|
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
|
||||||
|
num_ele_in_comm = comm_spec.get_comm_cost()
|
||||||
|
dtype = operand.data.dtype
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
cost = size_per_elem_bytes * num_ele_in_comm
|
||||||
|
|
||||||
|
# compute the fwd
|
||||||
|
# TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
|
||||||
|
# it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
|
||||||
|
# so total cost is either for fwd or bwd.
|
||||||
|
if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
|
||||||
|
comm_cost.fwd += cost
|
||||||
|
elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
|
||||||
|
comm_cost.fwd += cost
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Found unknown CommunicationType {comm_spec.comm_pattern}")
|
||||||
|
|
||||||
|
# check if communication action exists
|
||||||
|
# if so, loop over each action and compute the cost of each action
|
||||||
|
if strategy.communication_actions is not None:
|
||||||
|
for operand, comm_spec in strategy.communication_actions:
|
||||||
|
_compute_and_add(operand, comm_spec)
|
||||||
|
|
||||||
|
# update the communication cost attribute in-place
|
||||||
|
strategy.communication_cost = comm_cost
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
||||||
|
"""
|
||||||
|
Customize this method to compute the computation flops.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
|
||||||
|
"""
|
||||||
|
Customize this method to compute the memory cost in bytes.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _compute_size_in_bytes(self, strategy: ShardingStrategy_V2, key: str):
|
||||||
|
"""
|
||||||
|
Compute the size of a tensor in bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy (ShardingStrategy): the ShardingStrategy generated.
|
||||||
|
key (str): the name of the operation data defined by the generator.
|
||||||
|
|
||||||
|
"""
|
||||||
|
op_data = self.op_data[key]
|
||||||
|
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
|
||||||
|
dtype = self.op_data[key].data.dtype
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self) -> List[ShardingStrategy_V2]:
|
||||||
|
"""
|
||||||
|
Generate all possible sharding strategies for this operation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate(self, *args, **kwargs) -> bool:
|
||||||
|
"""
|
||||||
|
Validate if the operands are of desired shape.
|
||||||
|
If True, means this generator can be used for the current operation.
|
||||||
|
"""
|
||||||
|
pass
|
@ -84,13 +84,13 @@ def test_linear_function_handler():
|
|||||||
assert mapping['other'].name == "weight"
|
assert mapping['other'].name == "weight"
|
||||||
assert mapping['other'].data.is_meta
|
assert mapping['other'].data.is_meta
|
||||||
assert mapping['other'].data.shape == torch.Size([20, 10])
|
assert mapping['other'].data.shape == torch.Size([20, 10])
|
||||||
assert mapping['other'].type == OperationDataType.ARG
|
assert mapping['other'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||||
|
|
||||||
assert mapping['bias'].name == "bias"
|
assert mapping['bias'].name == "bias"
|
||||||
assert mapping['bias'].data.is_meta
|
assert mapping['bias'].data.is_meta
|
||||||
assert mapping['bias'].data.shape == torch.Size([20])
|
assert mapping['bias'].data.shape == torch.Size([20])
|
||||||
assert mapping['bias'].type == OperationDataType.ARG
|
assert mapping['bias'].type == OperationDataType.PARAM
|
||||||
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
assert mapping['other'].logical_shape == torch.Size([10, 20])
|
||||||
|
|
||||||
assert mapping['output'].name == "linear"
|
assert mapping['output'].name == "linear"
|
||||||
@ -100,5 +100,5 @@ def test_linear_function_handler():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# test_linear_module_handler()
|
test_linear_module_handler()
|
||||||
test_linear_function_handler()
|
test_linear_function_handler()
|
||||||
|
Loading…
Reference in New Issue
Block a user