[autoparallel] add experimental permute handler (#2029)

YuliangLiu0306 2022-11-27 20:26:52 +08:00 committed by GitHub
parent 95c4532fff
commit 81330b0352
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 657 additions and 37 deletions

View File

@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
            str(node))

    # experimental pass for torch.Tensor.view
    # Arguments of view op will be divided in the sharded dimensions.
    for node in nodes:
        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
            device_mesh = node.sharding_spec.device_mesh
            new_args = []
            for arg in node.args:
                if isinstance(arg, Node):
                    if isinstance(arg._meta_data, int):
                        new_args.append(arg._meta_data)
                    else:
                        new_args.append(arg)
                else:
                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
                    new_args.append(arg)

            for dim, shard_dims in output_dim_partition_dict.items():
                total_shard_size = 1
                for shard_dim in shard_dims:
                    total_shard_size *= device_mesh.shape[shard_dim]
                new_args[dim + 1] //= total_shard_size

            node.args = tuple(new_args)

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
    # the dict to record comm actions of nodes
@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
    return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict


def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    This pass will process node args to adapt the distributed tensor layout.
    """
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)

    for node in nodes:
        # skip the placeholder node added in _solution_annotation pass
        if not hasattr(node, 'sharding_spec'):
            continue
        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
        device_mesh = node.sharding_spec.device_mesh
        new_args = []

        if node.op == 'call_method':
            method = getattr(node.args[0]._meta_data.__class__, node.target)
            # process the node with (input, *shape) style args
            if method in (torch.Tensor.view, torch.Tensor.reshape):
                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, int):
                            new_args.append(arg._meta_data)
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
                        new_args.append(arg)

                for dim, shard_dims in output_dim_partition_dict.items():
                    # we will skip the dim with -1 value
                    if new_args[dim + 1] == -1:
                        continue
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    new_args[dim + 1] //= total_shard_size

                node.args = tuple(new_args)

        elif node.op == 'call_function':
            target = node.target
            # process the node with (input, torch.Size) style args
            if target in (torch.reshape,):
                for arg in node.args:
                    if isinstance(arg, Node):
                        if isinstance(arg._meta_data, (tuple, list)):
                            new_args.append(list(arg._meta_data))
                        else:
                            new_args.append(arg)
                    else:
                        assert isinstance(
                            arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
                        new_args.append(list(arg))

                for dim, shard_dims in output_dim_partition_dict.items():
                    # we will skip the dim with -1 value
                    if new_args[1][dim] == -1:
                        continue
                    total_shard_size = 1
                    for shard_dim in shard_dims:
                        total_shard_size *= device_mesh.shape[shard_dim]
                    new_args[1][dim] //= total_shard_size

                node.args = tuple(new_args)

    return gm


def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
    """
    Apply the sharding action to the module parameters and buffers following the
    instructions of solver solution.
@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
        gm, solution)
    gm = _node_args_converting(gm, device_mesh)
    # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass is completed.
    # gm = implicit_comm_action_apply(gm)
    gm = _module_params_sharding(gm, device_mesh)
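
A minimal standalone sketch (not part of this commit) of the shape-argument rewriting that _node_args_converting performs for view/reshape nodes: every output dimension that is sharded has its requested size divided by the product of the device-mesh extents along its shard axes, so the op is applied to the local shard. The mesh shape and tensor shapes below are made up for illustration.

def convert_view_shape_args(shape_args, dim_partition_dict, mesh_shape):
    """Divide the sharded dims of a view/reshape target shape by their total shard size."""
    new_args = list(shape_args)
    for dim, shard_dims in dim_partition_dict.items():
        if new_args[dim] == -1:    # -1 is inferred by PyTorch, leave it untouched
            continue
        total_shard_size = 1
        for shard_dim in shard_dims:
            total_shard_size *= mesh_shape[shard_dim]
        new_args[dim] //= total_shard_size
    return new_args

# global target shape (8, 16, 64) with dim 0 sharded on mesh axis 0 and dim 2
# sharded on both axes of a 2 x 2 mesh -> local target shape (4, 16, 16)
assert convert_view_shape_args([8, 16, 64], {0: [0], 2: [0, 1]}, (2, 2)) == [4, 16, 16]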

View File

@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getatrr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
@ -21,5 +22,5 @@ __all__ = [
    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
]

View File

@ -1,4 +1,8 @@
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler

__all__ = [
    'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler'
]

View File

@ -0,0 +1,76 @@
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator

__all__ = ['PermuteHandler']


@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
    """
    A PermuteHandler which deals with the sharding strategies for torch.permute and torch.Tensor.permute.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        permute_dims = []
        if self.node.op == 'call_method':
            # torch.Tensor.permute (input, *dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, int):
                        permute_dims.append(arg._meta_data)
                else:
                    assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
                    permute_dims.append(arg)
        else:
            # torch.permute (input, dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, (tuple, list)):
                        permute_dims.extend(arg._meta_data)
                else:
                    assert isinstance(
                        arg,
                        (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
                    permute_dims.extend(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(num_dims):
            # recover negative value to positive
            if permute_dims[i] < 0:
                permute_dims[i] += num_dims

        physical_shape_operand = OperationData(name='permute_dims', type=OperationDataType.ARG, data=list(permute_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "permute_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
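
For reference, a small illustrative sketch (assumed helper, not ColossalAI code) of the permute_dims handling above: the handler flattens both calling styles, x.permute(0, 2, 1, -1) and torch.permute(x, (0, 2, 1, -1)), into one flat list and maps negative indices back to positive ones.

def normalize_permute_dims(dim_args, num_dims):
    """Flatten (d0, d1, ...) or ((d0, d1, ...),) style dims and recover negative values."""
    permute_dims = []
    for arg in dim_args:
        if isinstance(arg, (tuple, list)):
            permute_dims.extend(arg)    # torch.permute(input, dims) style
        else:
            permute_dims.append(arg)    # torch.Tensor.permute(input, *dims) style
    return [d + num_dims if d < 0 else d for d in permute_dims]

assert normalize_permute_dims((0, 2, 1, -1), num_dims=4) == [0, 2, 1, 3]
assert normalize_permute_dims(((0, 2, 1, -1),), num_dims=4) == [0, 2, 1, 3]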

View File

@ -17,12 +17,12 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']


class ReshapeGenerator(FollowingStrategyGenerator):
    """
    ReshapeGenerator is the base class for all the reshape operations.
    """

    def validate(self) -> bool:
@ -61,6 +61,15 @@ class ViewGenerator(FollowingStrategyGenerator):
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        return super().collate_strategies()


class ViewGenerator(ReshapeGenerator):
    """
    ViewGenerator deals with the sharding strategies of view op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
@ -136,3 +145,85 @@ class ViewGenerator(FollowingStrategyGenerator):
            strategy_list.append(strategy)

        return strategy_list


class PermuteGenerator(ReshapeGenerator):
    """
    PermuteGenerator deals with the sharding strategies of permute op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            permute_dims = self.op_data['permute_dims'].data
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}
            for dim_index, permute_dim in enumerate(permute_dims):
                if permute_dim in dim_partition_dict_for_input:
                    dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list


class TransposeGenerator(ReshapeGenerator):
    """
    TransposeGenerator deals with the sharding strategies of transpose op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}

            transpose_dims = self.op_data['transpose_dims'].data
            dim_0 = transpose_dims[0]
            dim_1 = transpose_dims[1]
            for dim, sharded_dims in dim_partition_dict_for_input.items():
                if dim == dim_0:
                    dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
                elif dim == dim_1:
                    dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
                else:
                    dim_partition_dict_for_output[dim] = sharded_dims

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
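
The core of both generators is remapping the predecessor's dim_partition_dict onto the output layout. A hedged sketch of the two mappings using plain dicts in place of sharding-spec objects (the helper names are invented for the example):

def permute_partition(input_partition, permute_dims):
    """Output dim i inherits the shard axes of input dim permute_dims[i]."""
    output_partition = {}
    for dim_index, permute_dim in enumerate(permute_dims):
        if permute_dim in input_partition:
            output_partition[dim_index] = input_partition[permute_dim]
    return output_partition

def transpose_partition(input_partition, dim_0, dim_1):
    """Swap the shard axes of dim_0 and dim_1 and keep every other dim unchanged."""
    output_partition = {}
    for dim, shard_axes in input_partition.items():
        if dim == dim_0:
            output_partition[dim_1] = shard_axes
        elif dim == dim_1:
            output_partition[dim_0] = shard_axes
        else:
            output_partition[dim] = shard_axes
    return output_partition

# [S0, S1, R, R] permuted with (0, 2, 1, 3) becomes [S0, R, S1, R]
assert permute_partition({0: [0], 1: [1]}, (0, 2, 1, 3)) == {0: [0], 2: [1]}
# [S0, S1, R, R] with dims 1 and 2 transposed also becomes [S0, R, S1, R]
assert transpose_partition({0: [0], 1: [1]}, 1, 2) == {0: [0], 2: [1]}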

View File

@ -0,0 +1,65 @@
from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator

__all__ = ['TransposeHandler']


@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
    """
    A TransposeHandler which deals with the sharding strategies for torch.transpose and torch.Tensor.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        transpose_dims = []
        # torch.transpose (input, dim0, dim1)
        for arg in self.node.args:
            if isinstance(arg, torch.fx.Node):
                if isinstance(arg._meta_data, int):
                    transpose_dims.append(arg._meta_data)
            else:
                transpose_dims.append(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(2):
            # recover negative value to positive
            if transpose_dims[i] < 0:
                transpose_dims[i] += num_dims

        physical_shape_operand = OperationData(name='transpose_dims',
                                               type=OperationDataType.ARG,
                                               data=list(transpose_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "transpose_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
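
Both handlers rely on the decorator-based operator registry to map a torch op onto its handler class. A toy sketch of that pattern (this is not ColossalAI's actual registry implementation, only the general mechanism it uses):

class ToyOperatorRegistry:
    """Minimal decorator-based registry mapping an op to a handler class."""

    def __init__(self):
        self._store = {}

    def register(self, op):
        def wrapper(handler_cls):
            self._store[op] = handler_cls
            return handler_cls
        return wrapper

    def get(self, op):
        return self._store[op]

toy_registry = ToyOperatorRegistry()

@toy_registry.register('transpose')
class ToyTransposeHandler:
    pass

assert toy_registry.get('transpose') is ToyTransposeHandler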

View File

@ -6,11 +6,13 @@ from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import ViewGenerator

__all__ = ['ViewHandler']


@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
    """

View File

@ -10,13 +10,9 @@ from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.view)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
    """

View File

@ -23,6 +23,8 @@ def _all_gather(tensor, comm_spec):
        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
        for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
    ]
    # without this contiguous operation, all_gather may produce unexpected results
    tensor = tensor.contiguous()
    dist.all_gather(tensor_list, tensor, group=process_group)
    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
    return output

View File

@ -0,0 +1,339 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
class ConvReshapeModel(nn.Module):
def __init__(self, reshape_dims, call_function):
super().__init__()
self.reshape_dims = reshape_dims
self.call_function = call_function
def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other, bias=None)
# permute_node = torch.permute(conv_node, self.permute_dims)
if self.call_function == torch.permute:
permute_node = self.call_function(conv_node, self.reshape_dims)
else:
permute_node = self.call_function(conv_node, *self.reshape_dims)
return permute_node
class LinearReshapeModel(nn.Module):
def __init__(self, reshape_dims, call_function):
super().__init__()
self.reshape_dims = reshape_dims
self.call_function = call_function
def forward(self, input, other):
linear_node = nn.functional.linear(input, other, bias=None)
# permute_node = torch.permute(linear_node, self.tgt_shape)
if self.call_function == torch.permute:
permute_node = self.call_function(linear_node, self.reshape_dims)
else:
permute_node = self.call_function(linear_node, *self.reshape_dims)
return permute_node
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if call_function == torch.permute:
reshape_dims = reshape_dims[0]
elif call_function == torch.transpose:
reshape_dims = reshape_dims[1]
model = model_cls(reshape_dims, call_function).cuda()
if model_cls.__name__ == 'ConvReshapeModel':
input = torch.rand(8, 8, 66, 66).to('cuda')
other = torch.rand(16, 8, 3, 3).to('cuda')
# index of conv node in computation graph
node_index = 2
# total number of conv strategies
strategy_number = 16
if model_cls.__name__ == 'LinearReshapeModel':
input = torch.rand(8, 16, 64, 32).to('cuda')
other = torch.rand(64, 32).to('cuda')
# index of linear node in computation graph
node_index = 2
# total number of linear strategies
strategy_number = 23
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
if model_cls.__name__ == 'ConvReshapeModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
# return permute
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
if model_cls.__name__ == 'LinearReshapeModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_function[target=torch.permute](args = (%linear, (0, 2, 1, 3)), kwargs = {})
# return permute
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
gm = ColoGraphModule(model, graph)
previous_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3]
view_strategies_vector = StrategiesVector(reshape_node)
previous_strategies_vector = StrategiesVector(previous_mod_node)
# build handler
if model_cls.__name__ == 'ConvReshapeModel':
conv_handler = ConvFunctionHandler(node=previous_mod_node,
device_mesh=device_mesh,
strategies_vector=previous_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
if model_cls.__name__ == 'LinearReshapeModel':
assert len(previous_strategies_vector) == 0
linear_handler = LinearFunctionHandler(node=previous_mod_node,
device_mesh=device_mesh,
strategies_vector=previous_strategies_vector)
linear_handler.register_strategy(compute_resharding_cost=False)
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
if call_function == torch.permute:
reshape_handler = PermuteHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=view_strategies_vector)
else:
reshape_handler = TransposeHandler(node=reshape_node,
device_mesh=device_mesh,
strategies_vector=view_strategies_vector)
reshape_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = reshape_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
if model_cls.__name__ == 'ConvReshapeModel':
assert mapping['input'].name == "conv2d"
else:
assert mapping['input'].name == "linear"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
if call_function == torch.permute:
assert mapping['output'].name == "permute"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.permute(torch.rand(8, 16, 64, 64), reshape_dims).shape
assert mapping['output'].type == OperationDataType.OUTPUT
else:
assert mapping['output'].name == "transpose"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.transpose(torch.rand(8, 16, 64, 64), *reshape_dims).shape
assert mapping['output'].type == OperationDataType.OUTPUT
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(view_strategies_vector) == len(previous_strategies_vector)
strategy_name_list = [strategy.name for strategy in view_strategies_vector]
if rank == 0:
for name in strategy_name_list:
print(name)
if model_cls.__name__ == 'ConvReshapeModel':
if reshape_dims in ((0, 2, 1, 3), (1, 2)):
assert '[S0, S1, R, R] -> [S0, R, S1, R]_0' in strategy_name_list
assert '[S1, S0, R, R] -> [S1, R, S0, R]_1' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
if reshape_dims == (2, 0, 1, 3):
assert '[S0, S1, R, R] -> [R, S0, S1, R]_0' in strategy_name_list
assert '[S1, S0, R, R] -> [R, S1, S0, R]_1' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_2' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_3' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_4' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_5' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_10' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
assert '[S01, R, R, R] -> [R, S01, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_15' in strategy_name_list
if reshape_dims == (1, 3):
assert '[S0, S1, R, R] -> [S0, R, R, S1]_0' in strategy_name_list
assert '[S1, S0, R, R] -> [S1, R, R, S0]_1' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_2' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_3' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_4' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_5' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_10' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_12' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, R, S01]_15' in strategy_name_list
if model_cls.__name__ == 'LinearReshapeModel':
if reshape_dims in ((0, 2, 1, 3), (1, 2)):
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
if reshape_dims == (2, 0, 1, 3):
assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
if reshape_dims == (1, 3):
assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('call_function', [torch.permute, torch.transpose])
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
world_size = 4
run_func = partial(check_view_handler,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_view_handler()
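
For readers of the assertions above: a strategy name such as [S0, S1, R, R] -> [S0, R, S1, R]_0 lists the input and output sharding sequences plus the index of the predecessor strategy it follows. A tiny sketch (notation only, helper name invented) that derives the output sequence from the input sequence and the permute dims:

def permute_sharding_sequence(input_sequence, permute_dims):
    """Output position i takes the annotation of input position permute_dims[i]."""
    return [input_sequence[d] for d in permute_dims]

# matches the first ConvReshapeModel assertion for permute dims (0, 2, 1, 3)
assert permute_sharding_sequence(['S0', 'S1', 'R', 'R'], (0, 2, 1, 3)) == ['S0', 'R', 'S1', 'R']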

View File

@ -84,7 +84,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# return view
graph = tracer.trace(model,
                     meta_args={
                         "input": torch.rand(8, 8, 66, 66).to('meta'),
                         "other": torch.rand(16, 8, 3, 3).to('meta'),
                     })