[autoparallel] support addbmm computation (#2102)

This commit is contained in:
YuliangLiu0306
2022-12-08 21:15:11 +08:00
committed by GitHub
parent d3d4630495
commit 0fecbb9e20
7 changed files with 179 additions and 65 deletions

View File

@@ -5,7 +5,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
class AddBMMTensorMethodModule(nn.Module):
def __init__(self, using_kwargs):
super().__init__()
self.using_kwargs = using_kwargs
def forward(self, bias, x1, x2):
return bias.addbmm(x1, x2)
if self.using_kwargs:
output = bias.addbmm(x1, x2, alpha=2, beta=3)
else:
output = bias.addbmm(x1, x2)
return output
class AddBMMTorchFunctionModule(nn.Module):
def __init__(self, using_kwargs):
super().__init__()
self.using_kwargs = using_kwargs
def forward(self, bias, x1, x2):
return torch.addbmm(bias, x1, x2)
if self.using_kwargs:
output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
else:
output = torch.addbmm(bias, x1, x2)
return output
def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = module().cuda()
model = module(using_kwargs).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph = tracer.trace(model,
meta_args={
'bias': torch.rand(*bias_shape).to('meta'),
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)
bmm_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(bmm_mod_node)
# build handler
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size(bias_shape)
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([8, 8])
assert mapping['output'].name == "addbmm"
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([8, 8])
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
for name in strategy_name_list:
print(name)
# one batch dim
assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (1, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
model = module().cuda()
model = module(using_kwargs).cuda()
x1 = torch.rand(4, 8, 16).cuda()
x2 = torch.rand(4, 16, 8).cuda()
bias = torch.rand(bias_shape).cuda()
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
# graph():
# %bias : torch.Tensor [#users=1] = placeholder[target=bias]
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
# return add
graph = tracer.trace(model,
meta_args={
'bias': torch.rand(*bias_shape).to('meta'),
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
'x2': torch.rand(4, 16, 8).to('meta')
})
gm = ColoGraphModule(model, graph)
linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)
bmm_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(bmm_mod_node)
# build handler
handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
assert mapping['bias'].name == "bias"
assert mapping['bias'].data.is_meta
assert mapping['bias'].data.shape == torch.Size(bias_shape)
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([8, 8])
assert mapping['output'].name == "addbmm"
assert mapping['output'].name == "bmm"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([8, 8])
assert mapping['output'].data.shape == torch.Size([4, 8, 8])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@pytest.mark.skip("skip due to bias cases not ready")
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape):
def test_2d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_2d_device_mesh,
module=module,
bias_shape=bias_shape,
world_size=world_size,
using_kwargs=using_kwargs,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape):
def test_1d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_1d_device_mesh,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)