[autoparallel] implemented linear projection strategy generator (#1639)

Frank Lee 2022-09-26 16:58:14 +08:00 committed by GitHub
parent 154d3ef432
commit 45b39a692a
7 changed files with 564 additions and 134 deletions

View File

@@ -1,46 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
-from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
-__all__ = ['LinearModuleHandler']
+__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
class DotProductStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class MatVecStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""TODO: to be implemented"""
pass
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""TODO: to be implemented"""
pass
def validate(self, *args, **kwargs) -> bool:
"""TODO: to be implemented"""
pass
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
@operator_registry.register(torch.nn.Linear)
@@ -49,9 +15,10 @@ class LinearModuleHandler(ModuleHandler):
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
"""
-def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
+def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
-generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
+generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -97,9 +64,10 @@ class LinearFunctionHandler(NodeHandler):
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
"""
-def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
+def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
-generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
+generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
@@ -108,8 +76,15 @@ class LinearFunctionHandler(NodeHandler):
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
-type=OperationDataType.ARG,
+type=data_type,
data=self.node.args[1]._meta_data,
logical_shape=self.node.args[1]._meta_data.shape[::-1])
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
@@ -117,8 +92,13 @@ class LinearFunctionHandler(NodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.node.args[2] is not None:
# check if the bias operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]),
-type=OperationDataType.ARG,
+type=data_type,
data=self.node.args[2]._meta_data)
mapping['bias'] = physical_bias_operand
return mapping
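
As a quick sanity check on the reversed logical shape recorded for the weight: F.linear computes input @ weight.t(), with the weight stored as (out_features, in_features), so reversing the shape gives the (in_features, out_features) layout that a plain [M, P] x [P, N] projection generator expects. The standalone snippet below (shapes borrowed from the unit test at the end of this diff; the snippet itself is not part of the commit) verifies this.

import torch
import torch.nn.functional as F

# weight is stored as (out_features, in_features) = (20, 10), as in the unit test below
x = torch.rand(4, 10)
weight = torch.rand(20, 10)
bias = torch.rand(20)

out = F.linear(x, weight, bias)            # computes x @ weight.t() + bias
assert out.shape == torch.Size([4, 20])

# reversing the weight shape yields the logical [P, N] = (10, 20) layout,
# matching `mapping['other'].logical_shape == torch.Size([10, 20])` in the test
logical_shape = weight.shape[::-1]
assert tuple(logical_shape) == (10, 20)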

View File

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List
-from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2
+from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
from ..strategy import StrategyGenerator_V2
class NodeHandler(ABC):
@@ -26,14 +27,14 @@ class NodeHandler(ABC):
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.strategy_generator = self.register_strategy_generator()
def register_strategy(self) -> StrategiesVector:
"""
Register different sharding strategies for the current node.
"""
-operand_mapping = self.get_operand_mapping()
+strategy_generators = self.get_strategy_generator()
-for generator in self.strategy_generator:
+operand_mapping = self.get_operation_data_mapping()
for generator in strategy_generators:
strategies = generator.generate(operand_mapping)
self.strategies_vector.extend(strategies)
@@ -46,7 +47,7 @@ class NodeHandler(ABC):
return strategy
@abstractmethod
-def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
+def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
"""
Define which generators should be used by this NodeHandler object.
"""
@@ -81,6 +82,8 @@ class ModuleHandler(NodeHandler):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
print("created")
# set attributes to access module parameters for convenience
assert self.node.graph.owning_module is not None, \
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
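
For readers skimming the refactor, this is the shape of the new control flow: the handler no longer builds its generators in __init__; register_strategy() asks get_strategy_generator() for them, and each new-style generator is constructed with the operation-data mapping up front (the generate(operand_mapping) call in the hunk above still passes the mapping explicitly). The dependency-free sketch below mirrors that flow; every name here is a hypothetical stand-in, not colossalai API.

from typing import Dict, List


class DemoProjectionGenerator:
    # stand-in for a strategy generator: it receives the op-data mapping at construction time
    def __init__(self, op_data_mapping: Dict[str, str], device_mesh: object) -> None:
        self.op_data = op_data_mapping
        self.device_mesh = device_mesh

    def generate(self) -> List[str]:
        return [f"projection strategy over {sorted(self.op_data)}"]


class DemoLinearHandler:
    # stand-in for a node handler: generators are created on demand, not cached in __init__
    def __init__(self, device_mesh: object) -> None:
        self.device_mesh = device_mesh
        self.strategies_vector: List[str] = []

    def get_operation_data_mapping(self) -> Dict[str, str]:
        return {"input": "x", "other": "weight", "bias": "bias", "output": "out"}

    def get_strategy_generator(self) -> List[DemoProjectionGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        return [DemoProjectionGenerator(op_data_mapping, self.device_mesh)]

    def register_strategy(self) -> List[str]:
        for generator in self.get_strategy_generator():
            self.strategies_vector.extend(generator.generate())
        return self.strategies_vector


print(DemoLinearHandler(device_mesh=None).register_strategy())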

View File

@@ -7,6 +7,7 @@ from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
@@ -90,18 +91,12 @@ class TrainCycleItem:
total: Any
class CommunicationType(Enum):
FWD_ALL_REDUCE = 0
BWD_ALL_REDUCE = 1
@dataclass
-class CommunicationAction:
+class MemoryCost:
"""
The actions
"""
-type: CommunicationType
+activation: int = 0
-mesh_dim: int
+parameter: int = 0
@dataclass
@@ -126,7 +121,7 @@ class ShardingStrategy_V2:
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
input_resharding_costs: Dict[OperationData, List[float]] = None
-communication_actions: Dict[OperationData, List[CommunicationAction]] = None
+communication_actions: Dict[OperationData, CommSpec] = None
@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
@@ -152,79 +147,6 @@ class ShardingStrategy_V2:
return specs
class StrategyGenerator_V2(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
comm_cost = TrainCycleItem(fwd=0, bwd=0)
def _compute_and_add(data: OperationData, action: CommunicationAction):
sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
dtype = operand.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
# compute the fwd
if action.type == CommunicationType.FWD_ALL_REDUCE:
comm_cost.fwd += cost
elif action.type == CommunicationType.BWD_ALL_REDUCE:
comm_cost.fwd += cost
else:
raise ValueError(f"Found unknown CommunicationType {action.type}")
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, actions in strategy.communication_actions:
for action in actions:
_compute_and_add(operand, action)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy
@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
@abstractmethod
def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
"""
Generate all possible sharding strategies for this operation.
"""
pass
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible

View File

@@ -0,0 +1,7 @@
from .strategy_generator import StrategyGenerator_V2
from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator
__all__ = [
'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator'
]

View File

@@ -0,0 +1,364 @@
import operator
import torch
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
class DotProductStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class MatVecStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
class LinearProjectionStrategyGenerator(StrategyGenerator_V2):
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
dim_n_val = sharded_other_shape[-1]
dim_p_val = sharded_other_shape[0]
fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
input_size = self._compute_size_in_bytes(strategy, "input")
other_size = self._compute_size_in_bytes(strategy, "other")
if "bias" in self.op_data:
bias_size = self._compute_size_in_bytes(strategy, "bias")
else:
bias_size = 0
output_size = self._compute_size_in_bytes(strategy, "output")
fwd_mem_cost = MemoryCost(activation=output_size, parameter=other_size + bias_size)
bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size)
total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size,
parameter=other_size + bias_size)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self) -> List[ShardingStrategy_V2]:
strategies = []
# SS = SR x RS
strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))
# RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1))
strategies.append(self.split_rhs_space_both_contract(1, 0))
# RR= RS x SR
strategies.append(self.recompute_split_both_contract(0))
strategies.append(self.recompute_split_both_contract(1))
# RS = RR x RS
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))
# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# update meta info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
# get sharding spec mapping
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
# get sharding specs
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0]
},
"other": {
self.dim_p: [mesh_dim_0],
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0)
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim]
},
"other": {
self.dim_p: [mesh_dim]
},
"bias": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'output': output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
self.dim_q: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {'input': input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {"other": other_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
# get sharding spec
dim_partition_dict_mapping = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'output': output_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
# get sharding spec
dim_partition_dict_mapping = {
"input": {},
"other": {
self.dim_q: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'input': input_comm_spec}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def validate(self) -> bool:
assert "input" in self.op_data
assert "other" in self.op_data
# make sure the other operand has exactly 2 dimensions
input_data = self.op_data['input']
other_data = self.op_data['other']
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
# check that the bias has the same last logical dimension as the other operand
has_bias = "bias" in self.op_data
if has_bias:
bias_data = self.op_data['bias']
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
"""TODO: to be implemented"""
pass
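
To make the MNP compute-cost formula in update_compute_cost concrete, here is a worked example under assumed shapes (the shapes and mesh are illustrative only, not taken from this commit): a logical [64, 1024] x [1024, 4096] projection on a 2 x 2 device mesh with the S0S1 = S0R x RS1 strategy, so each device holds a [32, 1024] input shard and a [1024, 2048] weight shard.

import operator
from functools import reduce

# per-device (sharded) shapes under the assumed S0S1 = S0R x RS1 strategy
sharded_input_shape = (32, 1024)    # [M, P]
sharded_other_shape = (1024, 2048)  # [P, N]

dim_m = reduce(operator.mul, sharded_input_shape[:-1])  # 32 (product of all leading input dims)
dim_n = sharded_other_shape[-1]                         # 2048
dim_p = sharded_other_shape[0]                          # 1024

fwd_compute_cost = dim_m * dim_n * dim_p   # 67,108,864 multiplications per device
bwd_compute_cost = 2 * fwd_compute_cost    # gradients w.r.t. both input and weight
total_compute_cost = fwd_compute_cost + bwd_compute_cost

print(fwd_compute_cost, bwd_compute_cost, total_compute_cost)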

View File

@@ -0,0 +1,154 @@
import operator
import torch
from functools import reduce
from abc import ABC, abstractmethod
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List, Union, Any
from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
class StrategyGenerator_V2(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
"""
A factory method to produce a ShardingStrategy object.
Args:
sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object.
communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object.
"""
sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
return ShardingStrategy_V2(name=name,
sharding_specs=sharding_specs,
communication_actions=communication_actions)
def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
"""
A utility method to convert the dim partition dict to a ShardingSpec object.
Args:
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
op_data = self.op_data[op_data_name]
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict)
results[op_data_name] = sharding_spec
return results
def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
"""
Convert the key of the dictionary from the operation data name to an OperationData object.
"""
results = {}
for k, v in mapping.items():
op_data = self.op_data[k]
results[op_data] = v
return results
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]]):
"""
A factory method to produce a CommSpec object.
"""
# use the flattened device mesh when the same action is applied to two mesh axes
if isinstance(logical_process_axis, list) and len(logical_process_axis) == 2:
sharding_spec.device_mesh = sharding_spec.device_mesh.flatten()
logical_process_axis = 0
return CommSpec(comm_pattern=communication_pattern,
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
comm_cost = TrainCycleItem(fwd=0, bwd=0)
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
cost = size_per_elem_bytes * num_ele_in_comm
# add the cost to the corresponding phase
# TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
# it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
# so the total cost belongs entirely to either fwd or bwd.
if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
comm_cost.fwd += cost
elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
comm_cost.bwd += cost
else:
raise ValueError(f"Found unknown CollectiveCommPattern {comm_spec.comm_pattern}")
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, comm_spec in strategy.communication_actions.items():
_compute_and_add(operand, comm_spec)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy
@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy_V2, key: str):
"""
Compute the size of a tensor in bytes.
Args:
strategy (ShardingStrategy): the ShardingStrategy generated.
key (str): the name of the operation data defined by the generator.
"""
op_data = self.op_data[key]
sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
dtype = self.op_data[key].data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
@abstractmethod
def generate(self) -> List[ShardingStrategy_V2]:
"""
Generate all possible sharding strategies for this operation.
"""
pass
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
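
The byte-size helper _compute_size_in_bytes boils down to: read the element size off an empty tensor of the operand's dtype and multiply by the number of elements in the per-device sharded shape. A self-contained illustration, with an assumed shape and dtype:

import operator
from functools import reduce

import torch

sharded_shape = (32, 2048)   # assumed per-device shape of an operand
dtype = torch.float16

size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)

print(num_bytes)  # 131072 bytes, i.e. 128 KiB per device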

View File

@@ -84,13 +84,13 @@ def test_linear_function_handler():
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([20, 10])
-assert mapping['other'].type == OperationDataType.ARG
+assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size([20])
-assert mapping['bias'].type == OperationDataType.ARG
+assert mapping['bias'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([10, 20])
assert mapping['output'].name == "linear"
@@ -100,5 +100,5 @@ def test_linear_function_handler():
if __name__ == '__main__':
-# test_linear_module_handler()
+test_linear_module_handler()
test_linear_function_handler()