mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler
* polish code

parent 8b8937d901
commit 474111ecb5
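Illustrative sketch (not from the commit): in both conv handlers, the change below replaces an open-coded swap of the first two entries of the weight's dim_partition_dict (the removed lines in the hunks that follow) with a single call to the renamed utility transpose_partition_dim, and fixes the "tranpose_" typo across imports, __all__ lists and call sites. A minimal, self-contained sketch of the swap being centralized; the example dict is illustrative and assumed to map a tensor dim to the device-mesh dims it is sharded over.

# Minimal sketch of the dim swap that transpose_partition_dim(sharding_spec, 0, 1)
# performs on a sharding spec's dim_partition_dict; the example dict below is
# illustrative only (tensor dim -> list of device-mesh dims).
def swap_first_two_dims(dim_partition_dict: dict) -> dict:
    first_dim_partition = dim_partition_dict.pop(1, None)
    second_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if second_dim_partition:
        dim_partition_dict[1] = second_dim_partition
    return dim_partition_dict


print(swap_first_two_dims({0: [0], 2: [1]}))    # {2: [1], 1: [0]}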
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F

 from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import transpose_partition_dim
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -55,20 +56,7 @@ class ConvModuleHandler(ModuleHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and second dim of the conv module weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
@@ -110,7 +98,7 @@ class ConvFunctionHandler(NodeHandler):

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

-        if "bias" in self.node.kwargs:
+        if "bias" in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
             if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
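Illustrative sketch (not from the commit): the extra is-not-None check means an explicitly passed bias=None no longer enters the bias branch. A toy sketch of the tightened condition, with made-up kwargs dicts:

# Toy sketch of the tightened bias check; the kwargs dicts below are made up.
def has_usable_bias(node_kwargs: dict) -> bool:
    return "bias" in node_kwargs and node_kwargs["bias"] is not None


print(has_usable_bias({}))                    # False: no bias keyword at all
print(has_usable_bias({"bias": None}))        # False: explicit bias=None is now skipped
print(has_usable_bias({"bias": object()}))    # True: a real bias operand is processed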
@@ -128,19 +116,5 @@ class ConvFunctionHandler(NodeHandler):
         """
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
-                assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and second dim of the conv function weight
-                first_dim_partition = dim_partition_dict.pop(1, None)
-                second_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if second_dim_partition:
-                    dim_partition_dict[1] = second_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                transpose_partition_dim(sharding_spec, 0, 1)
         return strategy
@@ -3,7 +3,7 @@ from typing import Dict, List, Union
 import torch
 import torch.nn.functional as F

-from colossalai.auto_parallel.tensor_shard.utils import tranpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError

@@ -30,7 +30,7 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    tranpose_partition_dim(sharding_spec, 0, -1)
+    transpose_partition_dim(sharding_spec, 0, -1)
     return strategy
@@ -5,13 +5,13 @@ from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
     generate_sharding_size,
-    tranpose_partition_dim,
+    transpose_partition_dim,
     update_partition_dim,
 )

 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]

@@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape
+    assert sharding_spec.entire_shape == tensor.shape, \
+        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
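Illustrative sketch (not from the commit): with the message attached, a shape mismatch now reports both shapes instead of a bare AssertionError. A toy example with made-up shapes:

# Toy example of the improved assertion message; the shapes are made up.
import torch

entire_shape = torch.Size([16, 4, 3, 3])
tensor_shape = torch.Size([4, 16, 3, 3])
try:
    assert entire_shape == tensor_shape, \
        f'The entire_shape of the sharding spec {entire_shape} does not match the tensor shape {tensor_shape}'
except AssertionError as e:
    print(e)    # ... torch.Size([16, 4, 3, 3]) does not match the tensor shape torch.Size([4, 16, 3, 3])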
@@ -8,12 +8,12 @@ import torch
 from colossalai.tensor.sharding_spec import ShardingSpec

 __all__ = [
-    'tranpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
+    'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
     'enumerate_all_possible_2d_sharding', 'generate_sharding_size'
 ]


-def tranpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
+def transpose_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
     """
     Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.
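Illustrative sketch (not from the commit): the renamed helper can be exercised on its own. A minimal usage sketch, assuming the constructor signatures that appear elsewhere in this diff (DeviceMesh(physical_mesh_id, mesh_shape) and ShardingSpec(device_mesh, entire_shape, dim_partition_dict)) and a dim_partition_dict that maps a tensor dim to the device-mesh dims it is sharded over:

import torch

from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# 4 devices arranged as a 2x2 mesh, as in the tests below
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# weight of Conv2d(4, 16, 3): physical shape [16, 4, 3, 3], dim 0 sharded over mesh dim 0
spec = ShardingSpec(device_mesh, torch.Size([16, 4, 3, 3]), {0: [0]})

# in-place swap of the mesh dims recorded for tensor dims 0 and 1
transpose_partition_dim(spec, 0, 1)
# expected: spec.dim_partition_dict == {1: [0]}, mirroring the manual swap removed above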
@@ -5,12 +5,12 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import parameterize


-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_module_handler():
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1).to('meta'))
+@parameterize('bias', [True, False])
+def test_conv_module_handler(bias):
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
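Illustrative sketch (not from the commit): the tests drop the AUTO_PARALLEL environment gate and instead run once per bias value via colossalai.testing.parameterize. As a rough mental model only, a simplified, hypothetical stand-in for such a decorator (not ColossalAI's actual implementation):

# Hypothetical, simplified stand-in for colossalai.testing.parameterize,
# shown only to illustrate how test_conv_module_handler(bias) receives both values.
import functools


def parameterize(arg_name, values):

    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, arg_name: value})

        return wrapper

    return decorator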
@@ -49,6 +49,7 @@ def test_conv_module_handler():
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])

-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
@@ -99,6 +100,24 @@ def test_conv_module_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list

+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+

 class ConvModel(nn.Module):
@@ -110,8 +129,8 @@ class ConvModel(nn.Module):
         return x


-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_function_handler():
+@parameterize('bias', [True, False])
+def test_conv_function_handler(bias):
     model = ConvModel()
     tracer = ColoTracer()
     # graph():
@@ -119,18 +138,20 @@ def test_conv_function_handler():
     #     %others : torch.Tensor [#users=1] = placeholder[target=others]
     #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %others), kwargs = {})
     #     return conv2d
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(4, 4, 64, 64).to('meta'),
-                             "others": torch.rand(16, 4, 3, 3).to('meta'),
-                             "bias": torch.rand(16).to('meta')
-                         })
+    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta'), "others": torch.rand(16, 4, 3, 3).to('meta')}
+    if bias:
+        meta_args['bias'] = torch.rand(16).to('meta')
+    graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[3]
+
+    if bias:
+        conv_mod_node = list(graph.nodes)[3]
+    else:
+        conv_mod_node = list(graph.nodes)[2]
     strategies_vector = StrategiesVector(conv_mod_node)

     # build handler
@@ -157,6 +178,7 @@ def test_conv_function_handler():
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])

-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([16])
+    if bias:
+        assert mapping['bias'].name == "bias"
+        assert mapping['bias'].data.is_meta
+        assert mapping['bias'].data.shape == torch.Size([16])
@@ -207,6 +229,24 @@ def test_conv_function_handler():
     # RS01 = RR x RS01
     assert 'RS01 = RR x RS01' in strategy_name_list

+    for strategy in strategies_vector:
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('conv2d')
+
+        if bias:
+            bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+
+        # make sure the sharding matches across different operation data
+        assert output_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[2:] == output_sharding_spec.sharding_sequence[2:]
+        assert input_sharding_spec.sharding_sequence[1] == weight_sharding_spec.sharding_sequence[1]
+
+        if bias:
+            assert bias_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[0]
+            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
+

 if __name__ == '__main__':
     test_conv_module_handler()