mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[autoparallel] added dot handler (#1475)
@@ -1,9 +1,7 @@
import operator
from functools import reduce
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder

@@ -26,25 +24,6 @@ class ConvHandler(OperatorHanlder):
        assert self.input_data.dim() in (3, 4, 5), \
            'The dimension of the input fed into a conv op should be in the range [3, 5].'

    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.

        Note: the resharding cost of the weight is NOT counted.

        Argument:
            resharding_costs (Dict[int, List[float]]): the resharding costs generated in this method are appended to this dictionary.
                resharding_costs[i][j] is the cost of transforming the i-th argument in the output node's argument list,
                under the j-th strategy in its strategies_vector, to the sharding spec required by this strategy.
            sharding_spec_for_input (ShardingSpec): the ShardingSpec of the input node.
        '''
        # The resharding cost of the weight is not counted here; weight-sharing cases may require counting it as well.
        resharding_costs[self.input_index] = []
        for strategy in self.input_node.strategies_vector.strategies:
            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, sharding_spec_for_input)
            resharding_costs[self.input_index].append(resharding_cost)

    def _generate_compute_cost(self, bs, channel_in, channel_out):
        '''
        Compute the computation cost per device with this specific strategy.
@@ -1,4 +1,8 @@
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder


class DotHandler(OperatorHanlder):
@@ -6,7 +10,226 @@ class DotHandler(OperatorHanlder):
    """
    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _generate_compute_cost(self, input_shape, weight_shape):
        # TODO: consider bias addition
        # FLOPs = prod(input_shape) * out_features * 2, i.e. one multiply and one add per MAC
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
        return compute_cost
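As a sanity check, a minimal sketch of what this formula counts (the shapes below are hypothetical, not from the source): each output element of a linear layer takes in_features multiply-accumulates, i.e. two FLOPs, and the weight is stored transposed as (out_features, in_features), so weight_shape[0] is out_features.

import operator
from functools import reduce

# hypothetical shapes: input (batch=4, in_features=1024), transposed weight (out=2048, in=1024)
input_shape = (4, 1024)
weight_shape = (2048, 1024)

# the formula multiplies the full input numel (batch * in_features) by out_features,
# then by 2 to count one multiply and one add per multiply-accumulate
flops = reduce(operator.mul, input_shape) * weight_shape[0] * 2
print(flops)  # 16777216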

    # TODO: refactor the dot handler in my local branch to align with the latest main branch
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle the case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the linear layer weight is transposed during init, so dim 0 is the output dimension
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost: no all-reduce is required for this case
        communication_cost = 0

        # create and register the strategy
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategy)
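To make the memory accounting concrete, here is a small illustrative calculation (all values hypothetical): a float32 output of shape (1024, 1024) on a 2 x 4 device mesh, sharded along both mesh dimensions as in the strategy above, leaves each device holding 1/8 of the output.

import torch

numel = 1024 * 1024                                                          # hypothetical output numel
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()  # 4 bytes
sharding_size = 2 * 4                                                        # both mesh dimensions shard the output

memory_cost = numel * size_per_elem_bytes / sharding_size
print(memory_cost)  # 524288.0 bytes of output per device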

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # since the weight of the linear layer is transposed, the contraction
        # dimension is dim 1; it must be sharded along the same mesh dimension
        # as the input's contraction dimension (mesh_dim_1)
        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy: the contraction dimension
        # is sharded along mesh_dim_1, so the partial sums are all-reduced over it
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategy)
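The all-reduce above is needed because sharding the contraction (in_features) dimension leaves each device with a partial sum of the full matmul. A minimal sketch demonstrating this with plain torch (hypothetical shapes, simulating two devices in one process):

import torch

x = torch.randn(4, 8)   # input: (batch, in_features)
w = torch.randn(6, 8)   # transposed linear weight: (out_features, in_features)
full = x @ w.t()

# each "device" holds half of the contraction dimension and computes a partial result
part0 = x[:, :4] @ w[:, :4].t()
part1 = x[:, 4:] @ w[:, 4:].t()

# summing the partials (what the all-reduce does) recovers the full output
assert torch.allclose(full, part0 + part1, atol=1e-5)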

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case RS = RS x SS
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the weight is transposed to (out_features, in_features): dim 0 (out) follows
        # mesh_dim_1 and dim 1 (in, the contraction dimension) follows mesh_dim_0
        dim_partition_dict_for_weight = {0: [mesh_dim_1], 1: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy: the output is sharded along mesh_dim_1
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy: the contraction dimension
        # is sharded along mesh_dim_0, so the partial sums are all-reduced over it
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategy)

    def recompute_split_both_contract(self, mesh_dim):
        # handle the case RR = RS x SR
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        dim_partition_dict_for_input = {1: [mesh_dim]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the weight is transposed, so the contraction dimension is dim 1
        dim_partition_dict_for_weight = {1: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy: the output is fully replicated
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategy)

    def split_rhs_space_only(self, mesh_dim):
        # handle the case RS = RR x RS
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # the weight is transposed, so sharding dim 0 shards the output dimension
        dim_partition_dict_for_weight = {0: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate the resharding costs for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy: the contraction dimension
        # is not sharded, so each device computes its output shard completely and
        # no all-reduce is required
        communication_cost = 0
        sharding_strategy = ShardingStrategy(name,
                                             output_sharding_spec=sharding_spec_for_output,
                                             compute_cost=compute_cost,
                                             communication_cost=communication_cost,
                                             memory_cost=memory_cost,
                                             resharding_costs=resharding_costs,
                                             input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategy)

    def register_strategy_into_strategies_vector(self):
        '''
        Generate every possible sharding strategy for a dot (linear) node and record them all in the strategies_vector.

        Output:
            None; the generated strategies are appended to self.strategies_vector in place.
        '''
        # SS = SR x RS
        self.split_lhs_space_rhs_space(0, 1)
        self.split_lhs_space_rhs_space(1, 0)

        # SR = SS x SR
        self.split_lhs_space_both_contract(0, 1)
        self.split_lhs_space_both_contract(1, 0)

        # RS = RS x SS
        self.split_rhs_space_both_contract(0, 1)
        self.split_rhs_space_both_contract(1, 0)

        # RR = RS x SR
        self.recompute_split_both_contract(0)
        self.recompute_split_both_contract(1)

        # RS = RR x RS
        self.split_rhs_space_only(0)
        self.split_rhs_space_only(1)
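A hypothetical sketch of how a consumer might use the vector populated above: on a 2D device mesh this method registers ten strategies, and a greedy consumer could simply pick the cheapest one (a real solver would also weigh the resharding costs against neighbouring nodes). The attribute accesses below assume ShardingStrategy stores its constructor arguments as attributes.

# `handler` is a DotHandler whose register_strategy_into_strategies_vector() has run
best = min(handler.strategies_vector.strategies,
           key=lambda s: s.compute_cost + s.communication_cost)
print(best.name)  # e.g. 'S0S1 = S0R x RS1'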
@@ -43,3 +43,23 @@ class OperatorHanlder(ABC):
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec

    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.

        Note: the resharding cost of the weight is NOT counted.

        Argument:
            resharding_costs (Dict[int, List[float]]): the resharding costs generated in this method are appended to this dictionary.
                resharding_costs[i][j] is the cost of transforming the i-th argument in the output node's argument list,
                under the j-th strategy in its strategies_vector, to the sharding spec required by this strategy.
            sharding_spec_for_input (ShardingSpec): the ShardingSpec of the input node.
        '''
        # The resharding cost of the weight is not counted here; weight-sharing cases may require counting it as well.
        resharding_costs[self.input_index] = []
        for strategy in self.input_node.strategies_vector.strategies:
            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, sharding_spec_for_input)
            resharding_costs[self.input_index].append(resharding_cost)
        # return the mutated dictionary rather than the last per-strategy cost
        return resharding_costs
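For illustration, a hypothetical picture of the structure this method builds (the numbers are made up): for an input node whose strategies_vector holds three strategies, one transformation cost is recorded per upstream strategy under the input's argument index.

resharding_costs = {}      # Dict[int, List[float]]
input_index = 0            # hypothetical position of the input in the argument list
resharding_costs[input_index] = [0.0, 1.5, 3.2]   # one cost per upstream strategy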
@@ -42,10 +42,13 @@ class StrategiesVector:
        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
    '''

    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
        self.node = node
        self.in_nodes = in_nodes
        self.following_nodes = following_nodes
        # avoid a shared mutable default: create a fresh list per instance
        if strategies is None:
            strategies = []
        self.strategies = strategies
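The signature change from strategies=[] to strategies=None in this hunk fixes the classic mutable-default-argument pitfall: Python evaluates a default value once, at function definition time, so every StrategiesVector would otherwise have shared one list. A minimal demonstration of the pitfall (hypothetical function names):

def broken(acc=[]):        # the same list object is reused across calls
    acc.append(1)
    return acc

def fixed(acc=None):       # a fresh list is created on every call
    if acc is None:
        acc = []
    acc.append(1)
    return acc

print(broken(), broken())  # [1, 1] [1, 1] -- both results point at the shared list
print(fixed(), fixed())    # [1] [1]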

    def check_merge(self):