mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 04:02:17 +00:00
[autoparallel] add experimental permute handler (#2029)
This commit is contained in:
parent
95c4532fff
commit
81330b0352
@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
|
|||||||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||||
str(node))
|
str(node))
|
||||||
|
|
||||||
# experimental pass for torch.Tensor.view
|
|
||||||
# Arguments of view op will be divided in the sharded dimensions.
|
|
||||||
for node in nodes:
|
|
||||||
if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
|
|
||||||
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
|
|
||||||
device_mesh = node.sharding_spec.device_mesh
|
|
||||||
new_args = []
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, Node):
|
|
||||||
if isinstance(arg._meta_data, int):
|
|
||||||
new_args.append(arg._meta_data)
|
|
||||||
else:
|
|
||||||
new_args.append(arg)
|
|
||||||
else:
|
|
||||||
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
|
|
||||||
new_args.append(arg)
|
|
||||||
|
|
||||||
for dim, shard_dims in output_dim_partition_dict.items():
|
|
||||||
total_shard_size = 1
|
|
||||||
for shard_dim in shard_dims:
|
|
||||||
total_shard_size *= device_mesh.shape[shard_dim]
|
|
||||||
new_args[dim + 1] //= total_shard_size
|
|
||||||
node.args = tuple(new_args)
|
|
||||||
|
|
||||||
# the dict to get input sharding specs of user node
|
# the dict to get input sharding specs of user node
|
||||||
sharding_spec_convert_dict = {}
|
sharding_spec_convert_dict = {}
|
||||||
# the dict to record comm actions of nodes
|
# the dict to record comm actions of nodes
|
||||||
@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
|
|||||||
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
||||||
|
|
||||||
|
|
||||||
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
|
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||||
|
"""
|
||||||
|
This pass will process node args to adapt the distributed tensor layout.
|
||||||
|
"""
|
||||||
|
mod_graph = gm.graph
|
||||||
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
# skip the placeholder node added in _solution_annotation pass
|
||||||
|
if not hasattr(node, 'sharding_spec'):
|
||||||
|
continue
|
||||||
|
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
|
||||||
|
device_mesh = node.sharding_spec.device_mesh
|
||||||
|
new_args = []
|
||||||
|
|
||||||
|
if node.op == 'call_method':
|
||||||
|
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||||
|
# process the node with (input, *shape) style args
|
||||||
|
if method in (torch.Tensor.view, torch.Tensor.reshape):
|
||||||
|
for arg in node.args:
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
if isinstance(arg._meta_data, int):
|
||||||
|
new_args.append(arg._meta_data)
|
||||||
|
else:
|
||||||
|
new_args.append(arg)
|
||||||
|
else:
|
||||||
|
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
|
||||||
|
new_args.append(arg)
|
||||||
|
|
||||||
|
for dim, shard_dims in output_dim_partition_dict.items():
|
||||||
|
# we will skip the dim with -1 value
|
||||||
|
if new_args[dim + 1] == -1:
|
||||||
|
continue
|
||||||
|
total_shard_size = 1
|
||||||
|
for shard_dim in shard_dims:
|
||||||
|
total_shard_size *= device_mesh.shape[shard_dim]
|
||||||
|
new_args[dim + 1] //= total_shard_size
|
||||||
|
node.args = tuple(new_args)
|
||||||
|
|
||||||
|
elif node.op == 'call_function':
|
||||||
|
target = node.target
|
||||||
|
# process the node with (input, torch.Size) style args
|
||||||
|
if target in (torch.reshape,):
|
||||||
|
for arg in node.args:
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
if isinstance(arg._meta_data, (tuple, list)):
|
||||||
|
new_args.append(list(arg._meta_data))
|
||||||
|
else:
|
||||||
|
new_args.append(arg)
|
||||||
|
else:
|
||||||
|
assert isinstance(
|
||||||
|
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
|
||||||
|
new_args.append(list(arg))
|
||||||
|
|
||||||
|
for dim, shard_dims in output_dim_partition_dict.items():
|
||||||
|
# we will skip the dim with -1 value
|
||||||
|
if new_args[1][dim] == -1:
|
||||||
|
continue
|
||||||
|
total_shard_size = 1
|
||||||
|
for shard_dim in shard_dims:
|
||||||
|
total_shard_size *= device_mesh.shape[shard_dim]
|
||||||
|
new_args[1][dim] //= total_shard_size
|
||||||
|
node.args = tuple(new_args)
|
||||||
|
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||||
"""
|
"""
|
||||||
Apply the sharding action to the module parameters and buffers following the
|
Apply the sharding action to the module parameters and buffers following the
|
||||||
instructions of solver solution.
|
instructions of solver solution.
|
||||||
@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
|
|||||||
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
|
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
|
||||||
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
|
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
|
||||||
gm, solution)
|
gm, solution)
|
||||||
|
gm = _node_args_converting(gm, device_mesh)
|
||||||
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
|
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
|
||||||
# gm = implicit_comm_action_apply(gm)
|
# gm = implicit_comm_action_apply(gm)
|
||||||
gm = _module_params_sharding(gm, device_mesh)
|
gm = _module_params_sharding(gm, device_mesh)
|
||||||
|
@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
|
|||||||
from .binary_elementwise_handler import BinaryElementwiseHandler
|
from .binary_elementwise_handler import BinaryElementwiseHandler
|
||||||
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
|
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
|
||||||
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
|
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
|
||||||
|
from .experimental import PermuteHandler, ViewHandler
|
||||||
from .getatrr_handler import GetattrHandler
|
from .getatrr_handler import GetattrHandler
|
||||||
from .getitem_handler import GetItemHandler
|
from .getitem_handler import GetItemHandler
|
||||||
from .layer_norm_handler import LayerNormModuleHandler
|
from .layer_norm_handler import LayerNormModuleHandler
|
||||||
@ -21,5 +22,5 @@ __all__ = [
|
|||||||
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
|
||||||
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
|
||||||
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
|
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
|
||||||
'GetItemHandler', 'GetattrHandler'
|
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
|
||||||
]
|
]
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
from .view_generator import ViewGenerator
|
from .permute_handler import PermuteHandler
|
||||||
|
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
|
||||||
|
from .transpose_handler import TransposeHandler
|
||||||
from .view_handler import ViewHandler
|
from .view_handler import ViewHandler
|
||||||
|
|
||||||
__all__ = ['ViewGenerator', 'ViewHandler']
|
__all__ = [
|
||||||
|
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator'
|
||||||
|
]
|
||||||
|
@ -0,0 +1,76 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...sharding_strategy import OperationData, OperationDataType
|
||||||
|
from ..node_handler import NodeHandler
|
||||||
|
from ..registry import operator_registry
|
||||||
|
from ..strategy import StrategyGenerator
|
||||||
|
from .reshape_generator import PermuteGenerator
|
||||||
|
|
||||||
|
__all__ = ['PermuteHandler']
|
||||||
|
|
||||||
|
|
||||||
|
@operator_registry.register(torch.Tensor.permute)
|
||||||
|
@operator_registry.register(torch.permute)
|
||||||
|
class PermuteHandler(NodeHandler):
|
||||||
|
"""
|
||||||
|
A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||||
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
|
generators = []
|
||||||
|
generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||||
|
return generators
|
||||||
|
|
||||||
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
|
# check if the input operand is a parameter
|
||||||
|
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||||
|
data_type = OperationDataType.PARAM
|
||||||
|
else:
|
||||||
|
data_type = OperationDataType.ARG
|
||||||
|
|
||||||
|
input_data = self.node.args[0]._meta_data
|
||||||
|
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||||
|
|
||||||
|
permute_dims = []
|
||||||
|
if self.node.op == 'call_method':
|
||||||
|
# torch.Tensor.permute (input, *dims)
|
||||||
|
for arg in self.node.args:
|
||||||
|
if isinstance(arg, torch.fx.Node):
|
||||||
|
if isinstance(arg._meta_data, int):
|
||||||
|
permute_dims.append(arg._meta_data)
|
||||||
|
else:
|
||||||
|
assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
|
||||||
|
permute_dims.append(arg)
|
||||||
|
else:
|
||||||
|
# torch.permute (input, dims)
|
||||||
|
for arg in self.node.args:
|
||||||
|
if isinstance(arg, torch.fx.Node):
|
||||||
|
if isinstance(arg._meta_data, (tuple, list)):
|
||||||
|
permute_dims.extend(arg._meta_data)
|
||||||
|
else:
|
||||||
|
assert isinstance(
|
||||||
|
arg,
|
||||||
|
(tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
|
||||||
|
permute_dims.extend(arg)
|
||||||
|
|
||||||
|
num_dims = self.node._meta_data.dim()
|
||||||
|
for i in range(num_dims):
|
||||||
|
# recover negative value to positive
|
||||||
|
if permute_dims[i] < 0:
|
||||||
|
permute_dims[i] += num_dims
|
||||||
|
|
||||||
|
physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))
|
||||||
|
|
||||||
|
output_data = self.node._meta_data
|
||||||
|
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||||
|
|
||||||
|
mapping = {
|
||||||
|
"input": physical_input_operand,
|
||||||
|
"permute_dims": physical_shape_operand,
|
||||||
|
"output": physical_output_operand
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapping
|
@ -17,12 +17,12 @@ from colossalai.auto_parallel.tensor_shard.utils import (
|
|||||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
__all__ = ['ViewGenerator']
|
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']
|
||||||
|
|
||||||
|
|
||||||
class ViewGenerator(FollowingStrategyGenerator):
|
class ReshapeGenerator(FollowingStrategyGenerator):
|
||||||
"""
|
"""
|
||||||
ViewGenerator which deals with the sharding strategies of view op.
|
ReshapeGenerator is the base class for all the reshape operation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validate(self) -> bool:
|
def validate(self) -> bool:
|
||||||
@ -61,6 +61,15 @@ class ViewGenerator(FollowingStrategyGenerator):
|
|||||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||||
strategy.memory_cost = memory_cost
|
strategy.memory_cost = memory_cost
|
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
|
return super().collate_strategies()
|
||||||
|
|
||||||
|
|
||||||
|
class ViewGenerator(ReshapeGenerator):
|
||||||
|
"""
|
||||||
|
ViewGenerator deals with the sharding strategies of view op.
|
||||||
|
"""
|
||||||
|
|
||||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
strategy_list = []
|
strategy_list = []
|
||||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||||
@ -136,3 +145,85 @@ class ViewGenerator(FollowingStrategyGenerator):
|
|||||||
strategy_list.append(strategy)
|
strategy_list.append(strategy)
|
||||||
|
|
||||||
return strategy_list
|
return strategy_list
|
||||||
|
|
||||||
|
|
||||||
|
class PermuteGenerator(ReshapeGenerator):
|
||||||
|
"""
|
||||||
|
PermuteGenerator deals with the sharding strategies of permute op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
|
strategy_list = []
|
||||||
|
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||||
|
dim_partition_dict_mapping = {}
|
||||||
|
communication_action_mapping = {}
|
||||||
|
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||||
|
|
||||||
|
permute_dims = self.op_data['permute_dims'].data
|
||||||
|
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||||
|
dim_partition_dict_for_output = {}
|
||||||
|
for dim_index, permute_dim in enumerate(permute_dims):
|
||||||
|
if permute_dim in dim_partition_dict_for_input:
|
||||||
|
dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
|
||||||
|
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"input": dim_partition_dict_for_input,
|
||||||
|
"output": dim_partition_dict_for_output,
|
||||||
|
}
|
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||||
|
|
||||||
|
# add index into name to pass the duplicated check
|
||||||
|
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||||
|
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||||
|
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||||
|
|
||||||
|
strategy = self.get_sharding_strategy(name=name,
|
||||||
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
strategy_list.append(strategy)
|
||||||
|
|
||||||
|
return strategy_list
|
||||||
|
|
||||||
|
|
||||||
|
class TransposeGenerator(ReshapeGenerator):
|
||||||
|
"""
|
||||||
|
TransposeGenerator deals with the sharding strategies of permute op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||||
|
strategy_list = []
|
||||||
|
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||||
|
dim_partition_dict_mapping = {}
|
||||||
|
communication_action_mapping = {}
|
||||||
|
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||||
|
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||||
|
dim_partition_dict_for_output = {}
|
||||||
|
|
||||||
|
transpose_dims = self.op_data['transpose_dims'].data
|
||||||
|
dim_0 = transpose_dims[0]
|
||||||
|
dim_1 = transpose_dims[1]
|
||||||
|
for dim, sharded_dims in dim_partition_dict_for_input.items():
|
||||||
|
if dim == dim_0:
|
||||||
|
dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
|
||||||
|
elif dim == dim_1:
|
||||||
|
dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
|
||||||
|
else:
|
||||||
|
dim_partition_dict_for_output[dim] = sharded_dims
|
||||||
|
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"input": dim_partition_dict_for_input,
|
||||||
|
"output": dim_partition_dict_for_output,
|
||||||
|
}
|
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||||
|
|
||||||
|
# add index into name to pass the duplicated check
|
||||||
|
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||||
|
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||||
|
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||||
|
|
||||||
|
strategy = self.get_sharding_strategy(name=name,
|
||||||
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
strategy_list.append(strategy)
|
||||||
|
|
||||||
|
return strategy_list
|
@ -0,0 +1,65 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...sharding_strategy import OperationData, OperationDataType
|
||||||
|
from ..node_handler import NodeHandler
|
||||||
|
from ..registry import operator_registry
|
||||||
|
from ..strategy import StrategyGenerator
|
||||||
|
from .reshape_generator import TransposeGenerator
|
||||||
|
|
||||||
|
__all__ = ['TransposeHandler']
|
||||||
|
|
||||||
|
|
||||||
|
@operator_registry.register(torch.Tensor.transpose)
|
||||||
|
@operator_registry.register(torch.transpose)
|
||||||
|
class TransposeHandler(NodeHandler):
|
||||||
|
"""
|
||||||
|
A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||||
|
op_data_mapping = self.get_operation_data_mapping()
|
||||||
|
generators = []
|
||||||
|
generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||||
|
return generators
|
||||||
|
|
||||||
|
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||||
|
# check if the input operand is a parameter
|
||||||
|
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||||
|
data_type = OperationDataType.PARAM
|
||||||
|
else:
|
||||||
|
data_type = OperationDataType.ARG
|
||||||
|
|
||||||
|
input_data = self.node.args[0]._meta_data
|
||||||
|
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||||
|
|
||||||
|
transpose_dims = []
|
||||||
|
# torch.transpose (input, dim0, dim1)
|
||||||
|
for arg in self.node.args:
|
||||||
|
if isinstance(arg, torch.fx.Node):
|
||||||
|
if isinstance(arg._meta_data, int):
|
||||||
|
transpose_dims.append(arg._meta_data)
|
||||||
|
else:
|
||||||
|
transpose_dims.append(arg)
|
||||||
|
|
||||||
|
num_dims = self.node._meta_data.dim()
|
||||||
|
for i in range(2):
|
||||||
|
# recover negative value to positive
|
||||||
|
if transpose_dims[i] < 0:
|
||||||
|
transpose_dims[i] += num_dims
|
||||||
|
|
||||||
|
physical_shape_operand = OperationData(name='transpose_dims',
|
||||||
|
type=OperationDataType.ARG,
|
||||||
|
data=list(transpose_dims))
|
||||||
|
|
||||||
|
output_data = self.node._meta_data
|
||||||
|
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||||
|
|
||||||
|
mapping = {
|
||||||
|
"input": physical_input_operand,
|
||||||
|
"transpose_dims": physical_shape_operand,
|
||||||
|
"output": physical_output_operand
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapping
|
@ -6,11 +6,13 @@ from ...sharding_strategy import OperationData, OperationDataType
|
|||||||
from ..node_handler import NodeHandler
|
from ..node_handler import NodeHandler
|
||||||
from ..registry import operator_registry
|
from ..registry import operator_registry
|
||||||
from ..strategy import StrategyGenerator
|
from ..strategy import StrategyGenerator
|
||||||
from .view_generator import ViewGenerator
|
from .reshape_generator import ViewGenerator
|
||||||
|
|
||||||
__all__ = ['ViewHandler']
|
__all__ = ['ViewHandler']
|
||||||
|
|
||||||
|
|
||||||
|
@operator_registry.register(torch.Tensor.reshape)
|
||||||
|
@operator_registry.register(torch.reshape)
|
||||||
@operator_registry.register(torch.Tensor.view)
|
@operator_registry.register(torch.Tensor.view)
|
||||||
class ViewHandler(NodeHandler):
|
class ViewHandler(NodeHandler):
|
||||||
"""
|
"""
|
||||||
|
@ -10,13 +10,9 @@ from .strategy import ReshapeGenerator, StrategyGenerator
|
|||||||
__all__ = ['ReshapeHandler']
|
__all__ = ['ReshapeHandler']
|
||||||
|
|
||||||
|
|
||||||
@operator_registry.register(torch.reshape)
|
|
||||||
@operator_registry.register(torch.Tensor.split)
|
@operator_registry.register(torch.Tensor.split)
|
||||||
@operator_registry.register(torch.split)
|
@operator_registry.register(torch.split)
|
||||||
@operator_registry.register(torch.flatten)
|
@operator_registry.register(torch.flatten)
|
||||||
@operator_registry.register(torch.Tensor.transpose)
|
|
||||||
@operator_registry.register(torch.Tensor.permute)
|
|
||||||
@operator_registry.register(torch.Tensor.view)
|
|
||||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||||
class ReshapeHandler(NodeHandler):
|
class ReshapeHandler(NodeHandler):
|
||||||
"""
|
"""
|
||||||
|
@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec):
|
|||||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
||||||
]
|
]
|
||||||
|
# without this contiguous operation, the all gather may get some unexpected results.
|
||||||
|
tensor = tensor.contiguous()
|
||||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||||
return output
|
return output
|
||||||
|
@ -0,0 +1,339 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
|
||||||
|
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
|
||||||
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||||
|
from colossalai.initialize import launch
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||||
|
|
||||||
|
|
||||||
|
class ConvReshapeModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, reshape_dims, call_function):
|
||||||
|
super().__init__()
|
||||||
|
self.reshape_dims = reshape_dims
|
||||||
|
self.call_function = call_function
|
||||||
|
|
||||||
|
def forward(self, input, other):
|
||||||
|
conv_node = nn.functional.conv2d(input, other, bias=None)
|
||||||
|
# permute_node = torch.permute(conv_node, self.permute_dims)
|
||||||
|
if self.call_function == torch.permute:
|
||||||
|
permute_node = self.call_function(conv_node, self.reshape_dims)
|
||||||
|
else:
|
||||||
|
permute_node = self.call_function(conv_node, *self.reshape_dims)
|
||||||
|
return permute_node
|
||||||
|
|
||||||
|
|
||||||
|
class LinearReshapeModel(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, reshape_dims, call_function):
|
||||||
|
super().__init__()
|
||||||
|
self.reshape_dims = reshape_dims
|
||||||
|
self.call_function = call_function
|
||||||
|
|
||||||
|
def forward(self, input, other):
|
||||||
|
linear_node = nn.functional.linear(input, other, bias=None)
|
||||||
|
# permute_node = torch.permute(linear_node, self.tgt_shape)
|
||||||
|
if self.call_function == torch.permute:
|
||||||
|
permute_node = self.call_function(linear_node, self.reshape_dims)
|
||||||
|
else:
|
||||||
|
permute_node = self.call_function(linear_node, *self.reshape_dims)
|
||||||
|
return permute_node
|
||||||
|
|
||||||
|
|
||||||
|
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
if call_function == torch.permute:
|
||||||
|
reshape_dims = reshape_dims[0]
|
||||||
|
elif call_function == torch.transpose:
|
||||||
|
reshape_dims = reshape_dims[1]
|
||||||
|
model = model_cls(reshape_dims, call_function).cuda()
|
||||||
|
|
||||||
|
if model_cls.__name__ == 'ConvReshapeModel':
|
||||||
|
input = torch.rand(8, 8, 66, 66).to('cuda')
|
||||||
|
other = torch.rand(16, 8, 3, 3).to('cuda')
|
||||||
|
# index of conv node in computation graph
|
||||||
|
node_index = 2
|
||||||
|
# total number of conv strategies
|
||||||
|
strategy_number = 16
|
||||||
|
if model_cls.__name__ == 'LinearReshapeModel':
|
||||||
|
input = torch.rand(8, 16, 64, 32).to('cuda')
|
||||||
|
other = torch.rand(64, 32).to('cuda')
|
||||||
|
# index of linear node in computation graph
|
||||||
|
node_index = 2
|
||||||
|
# total number of linear strategies
|
||||||
|
strategy_number = 23
|
||||||
|
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||||
|
|
||||||
|
numerical_test_for_node_strategy(model=model,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
node_index=node_index,
|
||||||
|
strategy_number=strategy_number,
|
||||||
|
input_args=[input, other],
|
||||||
|
meta_arg_names=['input', 'other'],
|
||||||
|
node_type='following')
|
||||||
|
tracer = ColoTracer()
|
||||||
|
if model_cls.__name__ == 'ConvReshapeModel':
|
||||||
|
# graph():
|
||||||
|
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||||
|
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||||
|
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
|
||||||
|
# %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
|
||||||
|
# return permute
|
||||||
|
graph = tracer.trace(model,
|
||||||
|
meta_args={
|
||||||
|
"input": torch.rand(8, 8, 66, 66).to('meta'),
|
||||||
|
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
||||||
|
})
|
||||||
|
|
||||||
|
if model_cls.__name__ == 'LinearReshapeModel':
|
||||||
|
# graph():
|
||||||
|
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||||
|
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||||
|
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
|
||||||
|
# %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
|
||||||
|
# return permute
|
||||||
|
graph = tracer.trace(model,
|
||||||
|
meta_args={
|
||||||
|
"input": torch.rand(8, 16, 64, 32).to('meta'),
|
||||||
|
"other": torch.rand(64, 32).to('meta'),
|
||||||
|
})
|
||||||
|
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
|
|
||||||
|
previous_mod_node = list(graph.nodes)[2]
|
||||||
|
reshape_node = list(graph.nodes)[3]
|
||||||
|
view_strategies_vector = StrategiesVector(reshape_node)
|
||||||
|
previous_strategies_vector = StrategiesVector(previous_mod_node)
|
||||||
|
|
||||||
|
# build handler
|
||||||
|
if model_cls.__name__ == 'ConvReshapeModel':
|
||||||
|
|
||||||
|
conv_handler = ConvFunctionHandler(node=previous_mod_node,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=previous_strategies_vector)
|
||||||
|
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
|
||||||
|
|
||||||
|
if model_cls.__name__ == 'LinearReshapeModel':
|
||||||
|
assert len(previous_strategies_vector) == 0
|
||||||
|
linear_handler = LinearFunctionHandler(node=previous_mod_node,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=previous_strategies_vector)
|
||||||
|
linear_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
|
||||||
|
|
||||||
|
if call_function == torch.permute:
|
||||||
|
reshape_handler = PermuteHandler(node=reshape_node,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=view_strategies_vector)
|
||||||
|
else:
|
||||||
|
reshape_handler = TransposeHandler(node=reshape_node,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
strategies_vector=view_strategies_vector)
|
||||||
|
|
||||||
|
reshape_handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
|
||||||
|
# check operation data mapping
|
||||||
|
mapping = reshape_handler.get_operation_data_mapping()
|
||||||
|
|
||||||
|
for name, op_data in mapping.items():
|
||||||
|
op_data: OperationData
|
||||||
|
# make sure they have valid values
|
||||||
|
assert op_data.data is not None
|
||||||
|
|
||||||
|
if model_cls.__name__ == 'ConvReshapeModel':
|
||||||
|
assert mapping['input'].name == "conv2d"
|
||||||
|
else:
|
||||||
|
assert mapping['input'].name == "linear"
|
||||||
|
assert mapping['input'].data.is_meta
|
||||||
|
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
|
||||||
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
|
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
|
||||||
|
|
||||||
|
if call_function == torch.permute:
|
||||||
|
assert mapping['output'].name == "permute"
|
||||||
|
assert mapping['output'].data.is_meta
|
||||||
|
assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape
|
||||||
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
else:
|
||||||
|
assert mapping['output'].name == "transpose"
|
||||||
|
assert mapping['output'].data.is_meta
|
||||||
|
assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape
|
||||||
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
|
||||||
|
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
|
||||||
|
assert len(view_strategies_vector) == len(previous_strategies_vector)
|
||||||
|
strategy_name_list = [strategy.name for strategy in view_strategies_vector]
|
||||||
|
if rank == 0:
|
||||||
|
for name in strategy_name_list:
|
||||||
|
print(name)
|
||||||
|
if model_cls.__name__ == 'ConvReshapeModel':
|
||||||
|
|
||||||
|
if reshape_dims in ((0, 2, 1, 3), (1, 2)):
|
||||||
|
assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list
|
||||||
|
assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
|
||||||
|
|
||||||
|
if reshape_dims == (2, 0, 1, 3):
|
||||||
|
assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list
|
||||||
|
assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
|
||||||
|
|
||||||
|
if reshape_dims == (1, 3):
|
||||||
|
assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list
|
||||||
|
assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list
|
||||||
|
|
||||||
|
if model_cls.__name__ == 'LinearReshapeModel':
|
||||||
|
|
||||||
|
if reshape_dims == ((0, 2, 1, 3), (1, 2)):
|
||||||
|
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
|
||||||
|
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
|
||||||
|
assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
|
||||||
|
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
|
||||||
|
assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
|
||||||
|
assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
|
||||||
|
assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
|
||||||
|
assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
|
||||||
|
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
|
||||||
|
|
||||||
|
if reshape_dims == (2, 0, 1, 3):
|
||||||
|
assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
|
||||||
|
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
|
||||||
|
assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
|
||||||
|
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
|
||||||
|
assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
|
||||||
|
assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
|
||||||
|
assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
|
||||||
|
assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
|
||||||
|
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
|
||||||
|
|
||||||
|
if reshape_dims == (1, 3):
|
||||||
|
assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
|
||||||
|
assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
|
||||||
|
assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
|
||||||
|
assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
|
||||||
|
assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
|
||||||
|
assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
|
||||||
|
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
|
||||||
|
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
|
||||||
|
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
|
||||||
|
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
|
||||||
|
assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
|
||||||
|
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
|
||||||
|
assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
|
||||||
|
assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
|
||||||
|
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
|
||||||
|
assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
|
||||||
|
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
|
||||||
|
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
|
||||||
|
assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list
|
||||||
|
|
||||||
|
|
||||||
|
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
@parameterize('call_function', [torch.permute, torch.transpose])
|
||||||
|
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
|
||||||
|
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
|
||||||
|
def test_view_handler(call_function, reshape_dims, model_cls):
|
||||||
|
world_size = 4
|
||||||
|
run_func = partial(check_view_handler,
|
||||||
|
call_function=call_function,
|
||||||
|
reshape_dims=reshape_dims,
|
||||||
|
model_cls=model_cls,
|
||||||
|
world_size=world_size,
|
||||||
|
port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_view_handler()
|
@ -84,7 +84,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
|
|||||||
# return view
|
# return view
|
||||||
graph = tracer.trace(model,
|
graph = tracer.trace(model,
|
||||||
meta_args={
|
meta_args={
|
||||||
"input": torch.rand(8, 16, 66, 66).to('meta'),
|
"input": torch.rand(8, 8, 66, 66).to('meta'),
|
||||||
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user