mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-06 14:12:07 +00:00
[autoparallel] added dot handler (#1475)
This commit is contained in:
parent
d08566fb61
commit
628c7e3fc8
@ -1,9 +1,7 @@
|
|||||||
from lib2to3.pytree import Base
|
|
||||||
import operator
|
import operator
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import torch
|
import torch
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
|
||||||
from .operator_handler import OperatorHanlder
|
from .operator_handler import OperatorHanlder
|
||||||
|
|
||||||
|
|
||||||
@ -26,25 +24,6 @@ class ConvHandler(OperatorHanlder):
|
|||||||
assert self.input_data.dim() in (3, 4,
|
assert self.input_data.dim() in (3, 4,
|
||||||
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
|
||||||
|
|
||||||
def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
|
|
||||||
'''
|
|
||||||
Compute the resharding costs with this specific strategy.
|
|
||||||
|
|
||||||
Note: The resharding_cost of weight is NOT counted.
|
|
||||||
|
|
||||||
Argument:
|
|
||||||
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
|
|
||||||
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
|
||||||
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
|
||||||
strategy.
|
|
||||||
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
|
||||||
'''
|
|
||||||
# The resharding_cost of weight is counted due to sharing weight cases.
|
|
||||||
resharding_costs[self.input_index] = []
|
|
||||||
for stategy in self.input_node.strategies_vector.strategies:
|
|
||||||
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
|
|
||||||
resharding_costs[self.input_index].append(resharding_cost)
|
|
||||||
|
|
||||||
def _generate_compute_cost(self, bs, channel_in, channel_out):
|
def _generate_compute_cost(self, bs, channel_in, channel_out):
|
||||||
'''
|
'''
|
||||||
Compute the computation cost per device with this specific strategy.
|
Compute the computation cost per device with this specific strategy.
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
|
import operator
|
||||||
|
import torch
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
|
||||||
from .operator_handler import OperatorHanlder
|
from .operator_handler import OperatorHanlder
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
|
||||||
class DotHandler(OperatorHanlder):
|
class DotHandler(OperatorHanlder):
|
||||||
@ -6,7 +10,226 @@ class DotHandler(OperatorHanlder):
|
|||||||
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def _generate_compute_cost(self, input_shape, weight_shape):
|
||||||
super().__init__(*args, **kwargs)
|
# TODO: consider bias addition
|
||||||
|
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
|
||||||
|
return compute_cost
|
||||||
|
|
||||||
# TODO: refactor the dot handler in my local branch to align with the latest main branch
|
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
# handle case SS = SR x RS
|
||||||
|
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
|
||||||
|
|
||||||
|
dim_partition_dict_for_input = {0: [mesh_dim_0]}
|
||||||
|
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
# linear layer weight is transposed during init
|
||||||
|
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
|
||||||
|
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||||
|
|
||||||
|
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||||
|
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
# generate resharding cost for this strategy
|
||||||
|
resharding_costs = {}
|
||||||
|
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||||
|
|
||||||
|
# compute computation cost
|
||||||
|
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||||
|
|
||||||
|
# compute the memory cost of this strategy
|
||||||
|
dtype = self.input_data.dtype
|
||||||
|
numel = self.output.numel()
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
|
||||||
|
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||||
|
|
||||||
|
# compute the communication cost
|
||||||
|
# no all-reduce required for this case
|
||||||
|
communication_cost = 0
|
||||||
|
|
||||||
|
# create and register strategy
|
||||||
|
sharding_strategies = ShardingStrategy(name,
|
||||||
|
output_sharding_spec=sharding_spec_for_ouput,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
communication_cost=communication_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||||
|
self.strategies_vector.strategies.append(sharding_strategies)
|
||||||
|
|
||||||
|
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
# handle the case SR = SS x SR
|
||||||
|
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
|
||||||
|
|
||||||
|
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||||
|
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
# since weight of the linear layer is transposed
|
||||||
|
# the actual dim to be sharded is 1
|
||||||
|
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
|
||||||
|
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||||
|
|
||||||
|
dim_partition_dict_for_output = {0: [mesh_dim_0]}
|
||||||
|
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
|
||||||
|
|
||||||
|
# generate resharding cost for this strategy
|
||||||
|
resharding_costs = {}
|
||||||
|
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||||
|
|
||||||
|
# compute the computation cost of this strategy
|
||||||
|
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||||
|
|
||||||
|
# compute the memory cost of this strategy
|
||||||
|
dtype = self.input_data.dtype
|
||||||
|
numel = self.output.numel()
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||||
|
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||||
|
|
||||||
|
# compute the communication cost of this strategy
|
||||||
|
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||||
|
sharding_strategies = ShardingStrategy(name,
|
||||||
|
output_sharding_spec=sharding_spec_for_ouput,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
communication_cost=communication_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||||
|
self.strategies_vector.strategies.append(sharding_strategies)
|
||||||
|
|
||||||
|
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
|
||||||
|
|
||||||
|
dim_partition_dict_for_input = {1: [mesh_dim_0]}
|
||||||
|
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
|
||||||
|
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||||
|
|
||||||
|
dim_partition_dict_for_output = {1: [mesh_dim_1]}
|
||||||
|
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
# generate resharding cost for this strategy
|
||||||
|
resharding_costs = {}
|
||||||
|
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||||
|
|
||||||
|
# compute the computation cost of this strategy
|
||||||
|
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||||
|
|
||||||
|
# compute the memory cost of this strategy
|
||||||
|
dtype = self.input_data.dtype
|
||||||
|
numel = self.output.numel()
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
sharding_size = self.device_mesh.shape[mesh_dim_0]
|
||||||
|
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||||
|
|
||||||
|
# compute the communication cost of this strategy
|
||||||
|
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
|
||||||
|
sharding_strategies = ShardingStrategy(name,
|
||||||
|
output_sharding_spec=sharding_spec_for_ouput,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
communication_cost=communication_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||||
|
self.strategies_vector.strategies.append(sharding_strategies)
|
||||||
|
|
||||||
|
def recompute_split_both_contract(self, mesh_dim):
|
||||||
|
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
|
||||||
|
|
||||||
|
dim_partition_dict_for_input = {1: [mesh_dim]}
|
||||||
|
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
dim_partition_dict_for_weight = {1: [mesh_dim]}
|
||||||
|
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||||
|
|
||||||
|
dim_partition_dict_for_output = {}
|
||||||
|
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
|
||||||
|
|
||||||
|
# generate resharding cost for this strategy
|
||||||
|
resharding_costs = {}
|
||||||
|
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||||
|
|
||||||
|
# compute the computation cost of this strategy
|
||||||
|
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||||
|
|
||||||
|
# compute the memory cost of this strategy
|
||||||
|
dtype = self.input_data.dtype
|
||||||
|
numel = self.output.numel()
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
memory_cost = numel * size_per_elem_bytes
|
||||||
|
|
||||||
|
# compute the communication cost of this strategy
|
||||||
|
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
|
||||||
|
sharding_strategies = ShardingStrategy(name,
|
||||||
|
output_sharding_spec=sharding_spec_for_ouput,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
communication_cost=communication_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||||
|
self.strategies_vector.strategies.append(sharding_strategies)
|
||||||
|
|
||||||
|
def split_rhs_space_only(self, mesh_dim):
|
||||||
|
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
|
||||||
|
|
||||||
|
dim_partition_dict_for_input = {}
|
||||||
|
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
|
||||||
|
|
||||||
|
dim_partition_dict_for_weight = {0: [mesh_dim]}
|
||||||
|
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
|
||||||
|
|
||||||
|
dim_partition_dict_for_output = {1: [mesh_dim]}
|
||||||
|
sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
|
||||||
|
|
||||||
|
# generate resharding cost for this strategy
|
||||||
|
resharding_costs = {}
|
||||||
|
self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
|
||||||
|
|
||||||
|
# compute the computation cost of this strategy
|
||||||
|
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
|
||||||
|
|
||||||
|
# compute the memory cost of this strategy
|
||||||
|
dtype = self.input_data.dtype
|
||||||
|
numel = self.output.numel()
|
||||||
|
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
|
||||||
|
sharding_size = self.device_mesh.shape[mesh_dim]
|
||||||
|
memory_cost = numel * size_per_elem_bytes / sharding_size
|
||||||
|
|
||||||
|
# compute the communication cost of this strategy
|
||||||
|
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
|
||||||
|
sharding_strategies = ShardingStrategy(name,
|
||||||
|
output_sharding_spec=sharding_spec_for_ouput,
|
||||||
|
compute_cost=compute_cost,
|
||||||
|
communication_cost=communication_cost,
|
||||||
|
memory_cost=memory_cost,
|
||||||
|
resharding_costs=resharding_costs,
|
||||||
|
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
|
||||||
|
self.strategies_vector.strategies.append(sharding_strategies)
|
||||||
|
|
||||||
|
def register_strategy_into_strategies_vector(self):
|
||||||
|
'''
|
||||||
|
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
||||||
|
|
||||||
|
Output:
|
||||||
|
|
||||||
|
'''
|
||||||
|
# SS = SR x RS
|
||||||
|
self.split_lhs_space_rhs_space(0, 1)
|
||||||
|
self.split_lhs_space_rhs_space(1, 0)
|
||||||
|
|
||||||
|
# SR = SS x SR
|
||||||
|
self.split_lhs_space_both_contract(0, 1)
|
||||||
|
self.split_lhs_space_both_contract(1, 0)
|
||||||
|
|
||||||
|
# RS = RS x SS
|
||||||
|
self.split_rhs_space_both_contract(0, 1)
|
||||||
|
self.split_rhs_space_both_contract(1, 0)
|
||||||
|
|
||||||
|
# RR= RS x SR
|
||||||
|
self.recompute_split_both_contract(0)
|
||||||
|
self.recompute_split_both_contract(1)
|
||||||
|
|
||||||
|
# RS = RR x RS
|
||||||
|
self.split_rhs_space_only(0)
|
||||||
|
self.split_rhs_space_only(1)
|
||||||
|
@ -43,3 +43,23 @@ class OperatorHanlder(ABC):
|
|||||||
entire_shape=tensor.shape,
|
entire_shape=tensor.shape,
|
||||||
dim_partition_dict=dim_partition_dict)
|
dim_partition_dict=dim_partition_dict)
|
||||||
return sharding_spec
|
return sharding_spec
|
||||||
|
|
||||||
|
def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
|
||||||
|
'''
|
||||||
|
Compute the resharding costs with this specific strategy.
|
||||||
|
|
||||||
|
Note: The resharding_cost of weight is NOT counted.
|
||||||
|
|
||||||
|
Argument:
|
||||||
|
resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
|
||||||
|
Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
|
||||||
|
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
|
||||||
|
strategy.
|
||||||
|
sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
|
||||||
|
'''
|
||||||
|
# The resharding_cost of weight is counted due to sharing weight cases.
|
||||||
|
resharding_costs[self.input_index] = []
|
||||||
|
for stategy in self.input_node.strategies_vector.strategies:
|
||||||
|
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
|
||||||
|
resharding_costs[self.input_index].append(resharding_cost)
|
||||||
|
return resharding_cost
|
||||||
|
@ -42,10 +42,13 @@ class StrategiesVector:
|
|||||||
strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
|
strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
|
def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
|
||||||
self.node = node
|
self.node = node
|
||||||
self.in_nodes = in_nodes
|
self.in_nodes = in_nodes
|
||||||
self.following_nodes = following_nodes
|
self.following_nodes = following_nodes
|
||||||
|
|
||||||
|
if strategies is None:
|
||||||
|
strategies = []
|
||||||
self.strategies = strategies
|
self.strategies = strategies
|
||||||
|
|
||||||
def check_merge(self):
|
def check_merge(self):
|
||||||
|
113
tests/test_auto_parallel/test_dot_handler.py
Normal file
113
tests/test_auto_parallel/test_dot_handler.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
import torch.nn as nn
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from colossalai.fx.proxy import ColoProxy
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||||
|
from colossalai.auto_parallel.solver.dot_handler import DotHandler
|
||||||
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
|
||||||
|
class LinearModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(in_features, out_features)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x * 2
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_dot_handler():
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
# [[0, 1]
|
||||||
|
# [2, 3]]
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
|
entire_shape = torch.Size((4, 8))
|
||||||
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
tracer = ColoTracer()
|
||||||
|
model = LinearModel(8, 16)
|
||||||
|
input_sample = {'x': torch.rand(4, 8).to('meta')}
|
||||||
|
# graph():
|
||||||
|
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||||
|
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
|
||||||
|
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
|
||||||
|
# return conv
|
||||||
|
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||||
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
|
gm.recompile()
|
||||||
|
# [x, mul, linear, output]
|
||||||
|
nodes = [node for node in gm.graph.nodes]
|
||||||
|
|
||||||
|
strategies_for_input = []
|
||||||
|
sharding_option = (None, 0, 1)
|
||||||
|
for first_sharding_index in sharding_option:
|
||||||
|
for second_sharding_index in sharding_option:
|
||||||
|
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
|
||||||
|
continue
|
||||||
|
if first_sharding_index is None:
|
||||||
|
first_dim_spec = _DimSpec([])
|
||||||
|
else:
|
||||||
|
first_dim_spec = _DimSpec([first_sharding_index])
|
||||||
|
|
||||||
|
if second_sharding_index is None:
|
||||||
|
second_dim_spec = _DimSpec([])
|
||||||
|
else:
|
||||||
|
second_dim_spec = _DimSpec([second_sharding_index])
|
||||||
|
|
||||||
|
sharding_sequence = [first_dim_spec, second_dim_spec]
|
||||||
|
sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||||
|
entire_shape=entire_shape,
|
||||||
|
sharding_sequence=sharding_sequence)
|
||||||
|
strategies_for_input.append(sharding_spec)
|
||||||
|
|
||||||
|
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
|
||||||
|
strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
|
||||||
|
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
|
||||||
|
|
||||||
|
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
|
||||||
|
nodes[1],
|
||||||
|
])
|
||||||
|
dot_handler = DotHandler(input_node=nodes[1],
|
||||||
|
input_index=0,
|
||||||
|
weight=dict(gm.named_modules())[nodes[2].name].weight,
|
||||||
|
output_node=nodes[2],
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=strategies_vector,
|
||||||
|
shape_consistency_manager=shape_consistency_manager)
|
||||||
|
dot_handler.register_strategy_into_strategies_vector()
|
||||||
|
|
||||||
|
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
|
||||||
|
strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
|
||||||
|
|
||||||
|
# SS = SR x RS
|
||||||
|
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||||
|
|
||||||
|
# SR = SS x SR
|
||||||
|
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS = RS x SS
|
||||||
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
assert 'RS1 = RS0 x S0S1' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS x SR
|
||||||
|
assert 'RR = RS0 x S0R' in strategy_name_list
|
||||||
|
assert 'RR = RS1 x S1R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS= RR x RS
|
||||||
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_dot_handler()
|
Loading…
Reference in New Issue
Block a user