mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[autoparallel] add bias addtion function class (#2098)
* [autoparallel] add bias addtion function class * polish code * polish
This commit is contained in:
@@ -8,7 +8,7 @@ import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from typing_extensions import Self
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import ADDMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
@@ -19,7 +19,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
@@ -31,7 +31,7 @@ class AddmmModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input, m1, m2):
|
||||
x = torch.addmm(input, m1, m2)
|
||||
x = torch.addmm(input, m1, m2, beta=3, alpha=2)
|
||||
return x
|
||||
|
||||
|
||||
@@ -47,9 +47,9 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
|
||||
m1 = torch.rand(4, 8).cuda()
|
||||
m2 = torch.rand(8, 16).cuda()
|
||||
# the index of addmm node in computation graph
|
||||
node_index = 3
|
||||
node_index = 4
|
||||
# strategy number of linear node
|
||||
strategy_number = 10
|
||||
strategy_number = 14
|
||||
# construct input args
|
||||
input_args = [input, m1, m2]
|
||||
# construct meta arg names
|
||||
@@ -59,9 +59,20 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=input_args,
|
||||
meta_arg_names=meta_arg_names)
|
||||
meta_arg_names=meta_arg_names,
|
||||
node_type='bias_module')
|
||||
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
|
||||
# %m2 : torch.Tensor [#users=1] = placeholder[target=m2]
|
||||
# %transpose : [#users=1] = call_function[target=torch.transpose](args = (%m2, 0, 1), kwargs = {})
|
||||
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%m1, %transpose), kwargs = {})
|
||||
# %mul : [#users=1] = call_function[target=operator.mul](args = (%input_1, 3), kwargs = {})
|
||||
# %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
|
||||
# %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
|
||||
# return add
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(input_shape).to('meta'),
|
||||
@@ -71,11 +82,11 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
|
||||
gm = ColoGraphModule(model, graph)
|
||||
# [input_1, m1, m2, addmm, output]
|
||||
node_list = list(graph.nodes)
|
||||
addmm_node = node_list[3]
|
||||
strategies_vector = StrategiesVector(addmm_node)
|
||||
linear_node = node_list[4]
|
||||
strategies_vector = StrategiesVector(linear_node)
|
||||
|
||||
# build handler
|
||||
handler = ADDMMFunctionHandler(node=addmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
handler = LinearFunctionHandler(node=linear_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
|
||||
|
||||
handler.register_strategy(compute_resharding_cost=False)
|
||||
strategy_name_list = [val.name for val in strategies_vector]
|
||||
@@ -88,30 +99,22 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([4, 8])
|
||||
|
||||
assert mapping['other'].name == "m2"
|
||||
assert mapping['other'].data.shape == torch.Size([8, 16])
|
||||
assert mapping['other'].name == "transpose"
|
||||
assert mapping['other'].data.shape == torch.Size([16, 8])
|
||||
assert mapping['other'].type == OperationDataType.ARG
|
||||
assert mapping['other'].logical_shape == torch.Size([8, 16])
|
||||
|
||||
assert mapping['bias'].name == "input_1"
|
||||
assert mapping['bias'].data.shape == torch.Size(input_shape)
|
||||
assert mapping['bias'].type == OperationDataType.ARG
|
||||
assert mapping['bias'].logical_shape == torch.Size([4, 16])
|
||||
|
||||
assert mapping['output'].name == "addmm"
|
||||
assert mapping['output'].name == "linear"
|
||||
assert mapping['output'].data.shape == torch.Size([4, 16])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
# one strategy will be converted to different physical sharding spec
|
||||
assert len(strategy_name_list) > 8
|
||||
|
||||
# SS = SR x RS
|
||||
assert 'S0S1 = S0R x RS1' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0' in strategy_name_list
|
||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||
|
||||
# SR = SS x SR
|
||||
assert 'S0R = S0S1 x S1R' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R' in strategy_name_list
|
||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||
|
||||
# RS = RS x SS
|
||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||
@@ -125,23 +128,33 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
|
||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||
|
||||
# S01R = S01R x RR
|
||||
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||
|
||||
# RR = RS01 x S01R
|
||||
assert 'RR = RS01 x S01R' in strategy_name_list
|
||||
|
||||
# RS01 = RR x RS01
|
||||
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||
|
||||
# RR = RR x RR
|
||||
assert 'RR = RR x RR' in strategy_name_list
|
||||
|
||||
for strategy in strategies_vector:
|
||||
strategy: ShardingStrategy
|
||||
input_sharding_spec = strategy.get_sharding_spec_by_name('m1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('m2')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('addmm')
|
||||
bias_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||
weight_sharding_spec = strategy.get_sharding_spec_by_name('transpose')
|
||||
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
|
||||
|
||||
# make sure the sharding matches across different operation data
|
||||
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
|
||||
assert weight_sharding_spec.sharding_sequence[0] == input_sharding_spec.sharding_sequence[1]
|
||||
assert weight_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[1]
|
||||
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
|
||||
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[1]
|
||||
assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[1]
|
||||
|
||||
|
||||
@parameterize('input_shape', [(16,), (4, 16)])
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@parameterize('input_shape', [(16,), (4, 16)])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_addmm_handler(input_shape):
|
||||
world_size = 4
|
||||
|
Reference in New Issue
Block a user