[autoparallel] integrate auto parallel with torch fx (#1479)

Frank Lee 2022-08-23 14:23:08 +08:00 committed by GitHub
parent 8fb09a950a
commit ede326298b
7 changed files with 132 additions and 120 deletions

View File

@@ -0,0 +1,6 @@
+from .operator_handler import OperatorHandler
+from .dot_handler import DotHandler
+from .conv_handler import ConvHandler
+from .sharding_strategy import ShardingStrategy, StrategiesVector
+
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy']
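The initializer above re-exports the handlers and the strategy containers from one place. A minimal usage sketch, assuming this file is the `colossalai.auto_parallel.solver` package initializer that the handlers below import from:

from colossalai.auto_parallel.solver import ConvHandler, DotHandler, ShardingStrategy, StrategiesVector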

View File

@@ -1,17 +1,20 @@
import operator
from functools import reduce
import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler


-class ConvHandler(OperatorHanlder):
+class ConvHandler(OperatorHandler):
    """
    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
        self._sanity_check()

    def _sanity_check(self):
@@ -42,7 +45,7 @@ class ConvHandler(OperatorHanlder):
        # 1D: (L) * N * Cout * Cin * kernel
        # 2D: (H * W) * N * Cout * Cin * kernel
        # 3D: (H * W * D) * N * Cout * Cin * kernel
-        output_size = self.output.shape[2:]
+        output_size = self.output_data.shape[2:]
        output_size_product = reduce(operator.mul, output_size, 1)
        kernel_size = self.weight.shape[2:]
        kernel_size_product = reduce(operator.mul, kernel_size, 1)
@@ -59,11 +62,10 @@ class ConvHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -73,7 +75,7 @@ class ConvHandler(OperatorHanlder):
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -87,7 +89,7 @@ class ConvHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
@@ -99,11 +101,10 @@ class ConvHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -113,7 +114,7 @@ class ConvHandler(OperatorHanlder):
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -127,7 +128,7 @@ class ConvHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@@ -139,11 +140,10 @@ class ConvHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
@@ -153,7 +153,7 @@ class ConvHandler(OperatorHanlder):
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -167,7 +167,7 @@ class ConvHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_weight_out_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
@@ -179,11 +179,10 @@ class ConvHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
@@ -193,7 +192,7 @@ class ConvHandler(OperatorHanlder):
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -208,7 +207,7 @@ class ConvHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def non_split(self):
        name = f'RR = RR x RR'
@@ -220,11 +219,10 @@ class ConvHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        bs = self.input_data.shape[0]
@@ -234,7 +232,7 @@ class ConvHandler(OperatorHanlder):
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes
@@ -248,9 +246,9 @@ class ConvHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
@@ -315,3 +313,5 @@ class ConvHandler(OperatorHanlder):
        # RR= RR x RR
        self.non_split()
+
+        return self.strategies_vector
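The per-strategy bookkeeping above repeats one pattern: the compute cost follows the comment's formula (H * W) * N * Cout * Cin * kernel with the sharded dimensions divided by the mesh size, and the memory cost is the sharded output's byte size. A rough, self-contained sketch of that arithmetic; the mesh and tensor shapes below are made up for illustration and are not taken from the commit:

import operator
from functools import reduce

import torch

device_mesh_shape = (2, 2)      # illustrative 2 x 2 logical device mesh
input_shape = (4, 8, 32, 32)    # N, C_in, H, W
weight_shape = (16, 8, 3, 3)    # C_out, C_in, kH, kW
output_shape = (4, 16, 32, 32)  # N, C_out, H, W (assuming padding keeps H and W)

# compute cost for an S0S1 = S0R x RS1 strategy: (H * W) * N * Cout * Cin * kernel,
# with the batch split over mesh dim 0 and the output channels over mesh dim 1
bs = input_shape[0] // device_mesh_shape[0]
channel_in = input_shape[1]
channel_out = weight_shape[0] // device_mesh_shape[1]
output_size_product = reduce(operator.mul, output_shape[2:], 1)
kernel_size_product = reduce(operator.mul, weight_shape[2:], 1)
compute_cost = output_size_product * bs * channel_out * channel_in * kernel_size_product

# memory cost: bytes of the output tensor divided by the number of shards
numel = reduce(operator.mul, output_shape, 1)
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()
sharding_size = device_mesh_shape[0] * device_mesh_shape[1]
memory_cost = numel * size_per_elem_bytes / sharding_size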

View File

@@ -1,15 +1,21 @@
import operator
import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
from functools import reduce


-class DotHandler(OperatorHanlder):
+class DotHandler(OperatorHandler):
    """
    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
    """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
+
    def _generate_compute_cost(self, input_shape, weight_shape):
        # TODO: consider bias addition
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
@@ -27,18 +33,17 @@ class DotHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute computation cost
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -55,7 +60,7 @@ class DotHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
@@ -70,18 +75,17 @@ class DotHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -95,7 +99,7 @@ class DotHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@@ -107,18 +111,17 @@ class DotHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -132,7 +135,7 @@ class DotHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
@@ -144,18 +147,17 @@ class DotHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes
@@ -168,7 +170,7 @@ class DotHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
@@ -180,18 +182,17 @@ class DotHandler(OperatorHanlder):
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim]
        memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -205,9 +206,9 @@ class DotHandler(OperatorHanlder):
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
@@ -233,3 +234,4 @@ class DotHandler(OperatorHanlder):
        # RS = RR x RS
        self.split_rhs_space_only(0)
        self.split_rhs_space_only(1)
+        return self.strategies_vector
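The linear compute-cost helper shown earlier in this file multiplies the full input volume by the number of output features and doubles it for the multiply-accumulate. A quick sanity check with illustrative shapes (bias is still a TODO in the commit); the standalone function name here is a hypothetical mirror of the handler method:

import operator
from functools import reduce

def generate_compute_cost(input_shape, weight_shape):
    # mirrors DotHandler._generate_compute_cost from the diff above
    return reduce(operator.mul, input_shape) * weight_shape[0] * 2

# a (4, 8) activation against a (16, 8) weight: 4 * 8 * 16 * 2 = 1024
assert generate_compute_cost((4, 8), (16, 8)) == 1024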

View File

@ -1,15 +1,18 @@
import torch
import torch.nn as nn
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch.fx.node import Node from torch.fx.node import Node
import torch.nn as nn from typing import Dict
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from .sharding_strategy import StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import StrategiesVector
class OperatorHanlder(ABC):
class OperatorHandler(ABC):
''' '''
The OperatorHanlder is an abstract class used to generate every possible strategies for a operator node. The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
Argument: Argument:
input_node(Node): the input node in node argument list. input_node(Node): the input node in node argument list.
@@ -21,30 +24,43 @@ class OperatorHanlder(ABC):
        shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
    '''

-    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
-                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 shape_consistency_manager: ShapeConsistencyManager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
+        self.node = node
+        self.predecessor_node = list(node._input_nodes.keys())
+        self.successor_node = list(node.users.keys())
        self.device_mesh = device_mesh
        self.strategies_vector = strategies_vector
        self.shape_consistency_manager = shape_consistency_manager

+        # find the module and its parameters associated with this node
+        # this can be used to compute the compute/communication/sharding cost
+        if self.node.op == 'call_module':
+            module = node.graph.owning_module.get_submodule(node.target)
+            named_parameters = list(module.named_parameters(recurse=False))
+            # convert named parameters from list to dict
+            named_parameters = {k: v for k, v in named_parameters}
+        else:
+            module = None
+            named_parameters = None
+        self.module = module
+        self.module_named_parameters = named_parameters
+
    @abstractmethod
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
        pass

-    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
+        """
+        Generate the sharding spec of the tensor based on the given dim_partition_dict
+        where the key is the tensor dimension and the value is the mesh dimension for sharding.
+        """
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec

-    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+    def _generate_resharding_costs(self, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.
@@ -58,8 +74,10 @@ class OperatorHanlder(ABC):
        sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
        '''
        # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs[self.input_index] = []
-        for stategy in self.input_node.strategies_vector.strategies:
-            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
-            resharding_costs[self.input_index].append(resharding_cost)
+        resharding_costs = {}
+        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+            resharding_costs[input_node] = []
+            for strategy in input_node.strategies_vector:
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
+                resharding_costs[input_node].append(resharding_cost)
+        return resharding_cost
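For context, the `call_module` branch of the new constructor resolves a traced node back to its `nn.Module` so the handler can read the weight directly from the owning GraphModule. A small, self-contained sketch of that lookup on a toy traced model; the model and its names below are illustrative, not part of the commit:

import torch.nn as nn
from torch.fx import symbolic_trace


class TinyNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 16)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(TinyNet())
for node in gm.graph.nodes:
    if node.op == 'call_module':
        # same lookup as OperatorHandler.__init__: fetch the submodule from the
        # owning GraphModule and collect its immediate parameters into a dict
        module = node.graph.owning_module.get_submodule(node.target)
        named_parameters = dict(module.named_parameters(recurse=False))
        assert 'weight' in named_parameters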

View File

@@ -1,6 +1,9 @@
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List
+from torch.fx.node import Node
+
+__all__ = ['ShardingStrategy', 'StrategiesVector']


@dataclass
@@ -30,26 +33,21 @@ class ShardingStrategy:
    input_shardings: ShardingSpec = None


-class StrategiesVector:
+class StrategiesVector(list):
    '''
    Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
    strategies of the node.

    Argument:
-        node(Node): node to build corresponding strategies_vector.
-        in_nodes(List[Node]): input nodes in the argument list of the node.
-        following_nodes(List[Node]): the nodes take the target node as their argument.
-        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
+        node (Node): node for which the list of sharding strategies are generated.
    '''

-    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
+    def __init__(self, node: Node):
+        super().__init__()
        self.node = node
-        self.in_nodes = in_nodes
-        self.following_nodes = following_nodes
-
-        if strategies is None:
-            strategies = []
-        self.strategies = strategies
+        # fetch its input and output nodes
+        self.predecessor_nodes = list(node._input_nodes.keys())
+        self.successor_ndoes = list(node.users.keys())

    def check_merge(self):
        pass
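With `StrategiesVector` now subclassing `list`, the handlers append `ShardingStrategy` objects to it directly and iterate over it without the old `.strategies` attribute. A rough usage sketch, assuming the import path used elsewhere in this commit and a throwaway traced model just to obtain an fx node:

import torch.nn as nn
from torch.fx import symbolic_trace
from colossalai.auto_parallel.solver.sharding_strategy import StrategiesVector

gm = symbolic_trace(nn.Sequential(nn.Conv2d(4, 8, 3)))
conv_node = [node for node in gm.graph.nodes if node.op == 'call_module'][0]

strategies_vector = StrategiesVector(conv_node)
assert isinstance(strategies_vector, list)      # behaves like a plain list of strategies
assert len(strategies_vector) == 0              # handlers fill it via register_strategy()
print(strategies_vector.predecessor_nodes)      # the fx nodes feeding the conv node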

View File

@@ -47,7 +47,9 @@ def test_conv_handler():
    # [x, mul, conv, output]
    nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(nodes[1])
    sharding_option = (None, 0, 1)
    for first_sharding_index in sharding_option:
        for second_sharding_index in sharding_option:
@@ -68,28 +70,19 @@ def test_conv_handler():
            sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                         entire_shape=entire_shape,
                                         sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[0],
-                                                   in_nodes=[nodes[1], 2],
-                                                   strategies=strategies_for_input)
    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    conv_handler = ConvHandler(input_node=nodes[1],
-                               input_index=0,
-                               weight=dict(gm.named_modules())[nodes[2].name].weight,
-                               output_node=nodes[2],
+    # generate conv strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    conv_handler = ConvHandler(node=nodes[2],
                               device_mesh=device_mesh,
                               strategies_vector=strategies_vector,
                               shape_consistency_manager=shape_consistency_manager)
-    conv_handler.register_strategy_into_strategies_vector()
+    conv_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list

View File

@@ -47,7 +47,9 @@ def test_dot_handler():
    # [x, mul, linear, output]
    nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1])
    sharding_option = (None, 0, 1)
    for first_sharding_index in sharding_option:
        for second_sharding_index in sharding_option:
@@ -67,26 +69,19 @@ def test_dot_handler():
            sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                         entire_shape=entire_shape,
                                         sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    dot_handler = DotHandler(input_node=nodes[1],
-                             input_index=0,
-                             weight=dict(gm.named_modules())[nodes[2].name].weight,
-                             output_node=nodes[2],
+    # generate dot strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    dot_handler = DotHandler(node=nodes[2],
                             device_mesh=device_mesh,
                             strategies_vector=strategies_vector,
                             shape_consistency_manager=shape_consistency_manager)
-    dot_handler.register_strategy_into_strategies_vector()
+    strategies_vector = dot_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in strategies_vector]

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list