[autoparallel] added dot handler (#1475)

This commit is contained in:
Frank Lee
2022-08-22 10:32:17 +08:00
committed by GitHub
parent d08566fb61
commit 628c7e3fc8
5 changed files with 364 additions and 26 deletions

View File

@@ -1,9 +1,7 @@
from lib2to3.pytree import Base
import operator
from functools import reduce
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder
@@ -26,25 +24,6 @@ class ConvHandler(OperatorHanlder):
assert self.input_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
'''
Compute the resharding costs with this specific strategy.
Note: The resharding_cost of weight is NOT counted.
Argument:
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs[self.input_index] = []
for stategy in self.input_node.strategies_vector.strategies:
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
resharding_costs[self.input_index].append(resharding_cost)
def _generate_compute_cost(self, bs, channel_in, channel_out):
'''
Compute the computation cost per device with this specific strategy.

View File

@@ -1,4 +1,8 @@
import operator
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder
from functools import reduce
class DotHandler(OperatorHanlder):
@@ -6,7 +10,226 @@ class DotHandler(OperatorHanlder):
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _generate_compute_cost(self, input_shape, weight_shape):
# TODO: consider bias addition
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
return compute_cost
# TODO: refactor the dot handler in my local branch to align with the latest main branch
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# linear layer weight is transposed during init
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = {}
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
# compute computation cost
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost
# no all-reduce required for this case
communication_cost = 0
# create and register strategy
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.strategies.append(sharding_strategies)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = {}
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.strategies.append(sharding_strategies)
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = {}
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.strategies.append(sharding_strategies)
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict_for_input = {1: [mesh_dim]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = {}
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.strategies.append(sharding_strategies)
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = {}
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.strategies.append(sharding_strategies)
def register_strategy_into_strategies_vector(self):
'''
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
Output:
'''
# SS = SR x RS
self.split_lhs_space_rhs_space(0, 1)
self.split_lhs_space_rhs_space(1, 0)
# SR = SS x SR
self.split_lhs_space_both_contract(0, 1)
self.split_lhs_space_both_contract(1, 0)
# RS = RS x SS
self.split_rhs_space_both_contract(0, 1)
self.split_rhs_space_both_contract(1, 0)
# RR= RS x SR
self.recompute_split_both_contract(0)
self.recompute_split_both_contract(1)
# RS = RR x RS
self.split_rhs_space_only(0)
self.split_rhs_space_only(1)

View File

@@ -43,3 +43,23 @@ class OperatorHanlder(ABC):
entire_shape=tensor.shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
'''
Compute the resharding costs with this specific strategy.
Note: The resharding_cost of weight is NOT counted.
Argument:
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs[self.input_index] = []
for stategy in self.input_node.strategies_vector.strategies:
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
resharding_costs[self.input_index].append(resharding_cost)
return resharding_cost

View File

@@ -42,10 +42,13 @@ class StrategiesVector:
strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
'''
def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
self.node = node
self.in_nodes = in_nodes
self.following_nodes = following_nodes
if strategies is None:
strategies = []
self.strategies = strategies
def check_merge(self):