Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-09 11:58:06 +00:00)
[autoparallel] integrate auto parallel with torch fx (#1479)
This commit is contained in:
parent 8fb09a950a
commit ede326298b
__init__.py (new file):

@@ -0,0 +1,6 @@
+from .operator_handler import OperatorHandler
+from .dot_handler import DotHandler
+from .conv_handler import ConvHandler
+from .sharding_strategy import ShardingStrategy, StrategiesVector
+
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy']
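A minimal import sketch for the symbols re-exported by this new __init__.py, assuming the package is colossalai.auto_parallel.solver (the absolute import used in conv_handler.py below points at that package; the exact file path is not visible in this diff):

    # hypothetical usage of the new package-level exports
    from colossalai.auto_parallel.solver import (ConvHandler, DotHandler, OperatorHandler,
                                                 ShardingStrategy, StrategiesVector)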
conv_handler.py:

@@ -1,17 +1,20 @@
 import operator
 from functools import reduce
 import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler


-class ConvHandler(OperatorHanlder):
+class ConvHandler(OperatorHandler):
     """
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
         self._sanity_check()

     def _sanity_check(self):

@@ -42,7 +45,7 @@ class ConvHandler(OperatorHanlder):
         # 1D: (L) * N * Cout * Cin * kernel
         # 2D: (H * W) * N * Cout * Cin * kernel
         # 3D: (H * W * D) * N * Cout * Cin * kernel
-        output_size = self.output.shape[2:]
+        output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)

@@ -59,11 +62,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]

@@ -73,7 +75,7 @@ class ConvHandler(OperatorHanlder):

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -87,7 +89,7 @@ class ConvHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@@ -99,11 +101,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]

@@ -113,7 +114,7 @@ class ConvHandler(OperatorHanlder):

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -127,7 +128,7 @@ class ConvHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -139,11 +140,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -153,7 +153,7 @@ class ConvHandler(OperatorHanlder):

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -167,7 +167,7 @@ class ConvHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_weight_out_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

@@ -179,11 +179,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -193,7 +192,7 @@ class ConvHandler(OperatorHanlder):

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -208,7 +207,7 @@ class ConvHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def non_split(self):
         name = f'RR = RR x RR'

@@ -220,11 +219,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -234,7 +232,7 @@ class ConvHandler(OperatorHanlder):

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes

@@ -248,9 +246,9 @@ class ConvHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.

@@ -315,3 +313,5 @@ class ConvHandler(OperatorHanlder):

         # RR= RR x RR
         self.non_split()
+
+        return self.strategies_vector
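For orientation, a short sketch of how the reworked ConvHandler is driven after this change, mirroring the updated test_conv_handler further down in this commit; conv_node, device_mesh and shape_consistency_manager are assumed to already exist (a traced torch.fx call_module node, a DeviceMesh and a ShapeConsistencyManager):

    # sketch of the new ConvHandler entry point (see the updated test below)
    strategies_vector = StrategiesVector(node=conv_node)
    conv_handler = ConvHandler(node=conv_node,
                               device_mesh=device_mesh,
                               strategies_vector=strategies_vector,
                               shape_consistency_manager=shape_consistency_manager)
    strategies_vector = conv_handler.register_strategy()   # returns the populated StrategiesVector
    strategy_names = [strategy.name for strategy in strategies_vector]
    # e.g. 'S0S1 = S0R x RS1', 'S0R = S0S1 x S1R', ..., 'RR = RR x RR'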
dot_handler.py:

@@ -1,15 +1,21 @@
 import operator
 import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
 from functools import reduce


-class DotHandler(OperatorHanlder):
+class DotHandler(OperatorHandler):
     """
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
+
     def _generate_compute_cost(self, input_shape, weight_shape):
         # TODO: consider bias addition
         compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2

@@ -27,18 +33,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute computation cost
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -55,7 +60,7 @@ class DotHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         # handle the case SR = SS x SR

@@ -70,18 +75,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -95,7 +99,7 @@ class DotHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -107,18 +111,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -132,7 +135,7 @@ class DotHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def recompute_split_both_contract(self, mesh_dim):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

@@ -144,18 +147,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes

@@ -168,7 +170,7 @@ class DotHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_rhs_space_only(self, mesh_dim):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

@@ -180,18 +182,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -205,9 +206,9 @@ class DotHandler(OperatorHanlder):
            memory_cost=memory_cost,
            resharding_costs=resharding_costs,
            input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.

@@ -233,3 +234,4 @@ class DotHandler(OperatorHanlder):
         # RS = RR x RS
         self.split_rhs_space_only(0)
         self.split_rhs_space_only(1)
+        return self.strategies_vector
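A small worked example of the cost model used by DotHandler._generate_compute_cost above: the product of all input dimensions times the output feature dimension, doubled to count both the multiply and the add. The shapes below are made up for illustration:

    # illustrating DotHandler._generate_compute_cost with hypothetical shapes
    import operator
    from functools import reduce

    input_shape = (16, 1024)       # (batch, in_features), hypothetical
    weight_shape = (4096, 1024)    # (out_features, in_features), hypothetical
    compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
    print(compute_cost)            # 16 * 1024 * 4096 * 2 = 134217728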
operator_handler.py:

@@ -1,15 +1,18 @@
+import torch
+import torch.nn as nn
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
-import torch.nn as nn
+from typing import Dict
 from colossalai.device.device_mesh import DeviceMesh
-from .sharding_strategy import StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec

+from .sharding_strategy import StrategiesVector
+

-class OperatorHanlder(ABC):
+class OperatorHandler(ABC):
     '''
-    The OperatorHanlder is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.

     Argument:
         input_node(Node): the input node in node argument list.

@@ -21,30 +24,43 @@ class OperatorHanlder(ABC):
         shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
     '''

-    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
-                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                  shape_consistency_manager: ShapeConsistencyManager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
+        self.node = node
+        self.predecessor_node = list(node._input_nodes.keys())
+        self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shape_consistency_manager = shape_consistency_manager

+        # find the module and its parameters associated with this node
+        # this can be used to compute the compute/communication/sharding cost
+        if self.node.op == 'call_module':
+            module = node.graph.owning_module.get_submodule(node.target)
+            named_parameters = list(module.named_parameters(recurse=False))
+            # convert named parameters from list to dict
+            named_parameters = {k: v for k, v in named_parameters}
+        else:
+            module = None
+            named_parameters = None
+        self.module = module
+        self.module_named_parameters = named_parameters
+
     @abstractmethod
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         pass

-    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
+        """
+        Generate the sharding spec of the tensor based on the given dim_partition_dict
+        where the key is the tensor dimension and the value is the mesh dimension for sharding.
+        """
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
         return sharding_spec

-    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+    def _generate_resharding_costs(self, sharding_spec_for_input):
         '''
         Compute the resharding costs with this specific strategy.

@@ -58,8 +74,10 @@ class OperatorHanlder(ABC):
             sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs[self.input_index] = []
-        for stategy in self.input_node.strategies_vector.strategies:
-            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
-            resharding_costs[self.input_index].append(resharding_cost)
+        resharding_costs = {}
+        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+            resharding_costs[input_node] = []
+            for strategy in input_node.strategies_vector:
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
+                resharding_costs[input_node].append(resharding_cost)
         return resharding_cost
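The rewritten constructor above resolves the nn.Module behind a call_module node and caches its parameters, instead of receiving the weight explicitly. A standalone sketch of that lookup on a freshly traced module (the Linear layer here is only an example):

    # sketch: mapping a call_module node back to its module and parameters
    import torch
    from torch.fx import symbolic_trace

    gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 4)))
    node = [n for n in gm.graph.nodes if n.op == 'call_module'][0]
    submodule = node.graph.owning_module.get_submodule(node.target)
    named_parameters = {k: v for k, v in submodule.named_parameters(recurse=False)}
    print(list(named_parameters))  # ['weight', 'bias'] for the Linear layer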
sharding_strategy.py:

@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List
+from torch.fx.node import Node

+__all__ = ['ShardingStrategy', 'StrategiesVector']
+

 @dataclass

@@ -30,26 +33,21 @@ class ShardingStrategy:
     input_shardings: ShardingSpec = None


-class StrategiesVector:
+class StrategiesVector(list):
     '''
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
     strategies of the node.

     Argument:
-        node(Node): node to build corresponding strategies_vector.
-        in_nodes(List[Node]): input nodes in the argument list of the node.
-        following_nodes(List[Node]): the nodes take the target node as their argument.
-        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
+        node (Node): node for which the list of sharding strategies are generated.
     '''

-    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
+    def __init__(self, node: Node):
+        super().__init__()
         self.node = node
-        self.in_nodes = in_nodes
-        self.following_nodes = following_nodes
-        if strategies is None:
-            strategies = []
-        self.strategies = strategies
+        # fetch its input and output nodes
+        self.predecessor_nodes = list(node._input_nodes.keys())
+        self.successor_ndoes = list(node.users.keys())

     def check_merge(self):
         pass
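Because StrategiesVector now subclasses list, a node's strategies live directly in the vector rather than behind a separate .strategies attribute. A minimal sketch of the change in call sites (node and sharding_strategy are assumed to exist):

    # before this commit:
    #   strategies_vector.strategies.append(sharding_strategy)
    #   names = [s.name for s in strategies_vector.strategies]

    # after this commit: the vector is the list
    strategies_vector = StrategiesVector(node)
    strategies_vector.append(sharding_strategy)
    names = [strategy.name for strategy in strategies_vector]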
test_conv_handler:

@@ -47,7 +47,9 @@ def test_conv_handler():
     # [x, mul, conv, output]
     nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:

@@ -68,28 +70,19 @@ def test_conv_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[0],
-                                                   in_nodes=[nodes[1], 2],
-                                                   strategies=strategies_for_input)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    conv_handler = ConvHandler(input_node=nodes[1],
-                               input_index=0,
-                               weight=dict(gm.named_modules())[nodes[2].name].weight,
-                               output_node=nodes[2],
+    # generate conv strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    conv_handler = ConvHandler(node=nodes[2],
                                device_mesh=device_mesh,
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
-    conv_handler.register_strategy_into_strategies_vector()
+    conv_handler.register_strategy()

     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
test_dot_handler:

@@ -47,7 +47,9 @@ def test_dot_handler():
     # [x, mul, linear, output]
     nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:

@@ -67,26 +69,19 @@ def test_dot_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    dot_handler = DotHandler(input_node=nodes[1],
-                             input_index=0,
-                             weight=dict(gm.named_modules())[nodes[2].name].weight,
-                             output_node=nodes[2],
+    # generate dot strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    dot_handler = DotHandler(node=nodes[2],
                              device_mesh=device_mesh,
                              strategies_vector=strategies_vector,
                              shape_consistency_manager=shape_consistency_manager)
-    dot_handler.register_strategy_into_strategies_vector()
+    strategies_vector = dot_handler.register_strategy()

     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in strategies_vector]

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list