Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 22:10:37 +00:00)
[autoparallel] support addbmm computation (#2102)
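Background for the diff below (an illustration, not part of the commit): torch.addbmm(bias, x1, x2, beta=b, alpha=a) computes b * bias + a * sum_i(x1[i] @ x2[i]). The graph comments added to the tests show ColoTracer decomposing the default call into a bmm -> sum -> add chain, which is why the handler under test changes from AddBMMFunctionHandler to a BMMFunctionHandler applied to the bmm node. A minimal sketch of that equivalence, using the same shapes and the alpha=2, beta=3 values from the tests' using_kwargs branch:

import torch

bias = torch.rand(8, 8)
x1 = torch.rand(4, 8, 16)
x2 = torch.rand(4, 16, 8)

# fused op, as called by the using_kwargs branch of the test modules
fused = torch.addbmm(bias, x1, x2, alpha=2, beta=3)

# decomposed form: batched matmul, sum over the batch dim, then the scaled bias add
decomposed = 3 * bias + 2 * torch.sum(torch.bmm(x1, x2), dim=0)

assert torch.allclose(fused, decomposed, atol=1e-5)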
@@ -5,7 +5,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -19,20 +19,36 @@ from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import n
 
 class AddBMMTensorMethodModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return bias.addbmm(x1, x2)
+        if self.using_kwargs:
+            output = bias.addbmm(x1, x2, alpha=2, beta=3)
+        else:
+            output = bias.addbmm(x1, x2)
+        return output
 
 
 class AddBMMTorchFunctionModule(nn.Module):
 
+    def __init__(self, using_kwargs):
+        super().__init__()
+        self.using_kwargs = using_kwargs
+
     def forward(self, bias, x1, x2):
-        return torch.addbmm(bias, x1, x2)
+        if self.using_kwargs:
+            output = torch.addbmm(bias, x1, x2, alpha=2, beta=3)
+        else:
+            output = torch.addbmm(bias, x1, x2)
+        return output
 
 
-def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
@@ -54,6 +70,14 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                                      input_args=input_args,
                                      meta_arg_names=meta_arg_names)
     tracer = ColoTracer()
+    # graph():
+    #     %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    #     %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    #     %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    #     %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    #     return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
@@ -62,11 +86,11 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
                          })
     gm = ColoGraphModule(model, graph)
 
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
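The test picks list(graph.nodes)[3] because, in the traced graph shown in the comments, nodes 0-2 are the bias/x1/x2 placeholders and node 3 is the torch.bmm call. A small sanity check of that ordering while debugging, reusing the ColoTracer call and the AddBMMTorchFunctionModule class from the test file above (a sketch only; assumes a working colossalai installation):

import torch

from colossalai.fx import ColoTracer

tracer = ColoTracer()
graph = tracer.trace(AddBMMTorchFunctionModule(using_kwargs=False),
                     meta_args={
                         'bias': torch.rand(8, 8).to('meta'),
                         'x1': torch.rand(4, 8, 16).to('meta'),
                         'x2': torch.rand(4, 16, 8).to('meta')
                     })
for idx, node in enumerate(graph.nodes):
    # expected: 0-2 placeholders, 3 torch.bmm, 4 torch.sum, 5 operator.add, 6 output
    print(idx, node.op, node.target)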
@@ -89,19 +113,15 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
     for name in strategy_name_list:
         print(name)
     # one batch dim
     assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
 
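The mapping assertions change because the handled node is now the bmm call rather than the fused addbmm: torch.bmm keeps the batch dimension, so the node output is [4, 8, 8], while the old addbmm output (batch-summed, bias added) was [8, 8]; the bias operand drops out of the mapping entirely because the bias addition now lives in the separate operator.add node. A quick shape check of the two forms (illustration only):

import torch

x1 = torch.rand(4, 8, 16)
x2 = torch.rand(4, 16, 8)
bias = torch.rand(8, 8)

print(torch.bmm(x1, x2).shape)             # torch.Size([4, 8, 8]) -- shape asserted for the bmm node
print(torch.addbmm(bias, x1, x2).shape)    # torch.Size([8, 8])    -- shape asserted by the old test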
@@ -123,23 +143,21 @@ def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
+def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (1, 4)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = module().cuda()
+    model = module(using_kwargs).cuda()
     x1 = torch.rand(4, 8, 16).cuda()
     x2 = torch.rand(4, 16, 8).cuda()
     bias = torch.rand(bias_shape).cuda()
@@ -159,6 +177,14 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                                      meta_arg_names=meta_arg_names)
 
     tracer = ColoTracer()
+    # graph():
+    #     %bias : torch.Tensor [#users=1] = placeholder[target=bias]
+    #     %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
+    #     %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
+    #     %bmm : [#users=1] = call_function[target=torch.bmm](args = (%x1, %x2), kwargs = {})
+    #     %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%bmm, 0), kwargs = {})
+    #     %add : [#users=1] = call_function[target=operator.add](args = (%sum_1, %bias), kwargs = {})
+    #     return add
     graph = tracer.trace(model,
                          meta_args={
                              'bias': torch.rand(*bias_shape).to('meta'),
@@ -166,11 +192,11 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
                              'x2': torch.rand(4, 16, 8).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
-    linear_mod_node = list(graph.nodes)[3]
-    strategies_vector = StrategiesVector(linear_mod_node)
+    bmm_mod_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(bmm_mod_node)
 
     # build handler
-    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+    handler = BMMFunctionHandler(node=bmm_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
 
     # check operation data mapping
     mapping = handler.get_operation_data_mapping()
@@ -193,15 +219,9 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     assert mapping['other'].type == OperationDataType.ARG
     assert mapping['other'].logical_shape == torch.Size([4, 16, 8])
 
-    assert mapping['bias'].name == "bias"
-    assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size(bias_shape)
-    assert mapping['bias'].type == OperationDataType.ARG
-    assert mapping['bias'].logical_shape == torch.Size([8, 8])
-
-    assert mapping['output'].name == "addbmm"
+    assert mapping['output'].name == "bmm"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([8, 8])
+    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -213,14 +233,12 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
     for strategy in strategies_vector:
         input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
         other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
-        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
-        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')
 
         # make sure the sharding matches across different operation data
-        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
+        assert input_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[0]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
-        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
 @pytest.mark.skip("skip due to bias cases not ready")
@@ -228,13 +246,15 @@ def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_2d_device_mesh(module, bias_shape):
+def test_2d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_2d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
                        world_size=world_size,
+                       using_kwargs=using_kwargs,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
@@ -244,12 +264,14 @@ def test_2d_device_mesh(module, bias_shape):
 @pytest.mark.dist
 @parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
 @parameterize('bias_shape', [[8], [1, 8], [8, 8]])
+@parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
-def test_1d_device_mesh(module, bias_shape):
+def test_1d_device_mesh(module, bias_shape, using_kwargs):
     world_size = 4
     run_func = partial(check_1d_device_mesh,
                        module=module,
                        bias_shape=bias_shape,
+                       using_kwargs=using_kwargs,
                        world_size=world_size,
                        port=free_port())
     mp.spawn(run_func, nprocs=world_size)
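A note on the launcher pattern used by both tests: functools.partial pre-binds the keyword arguments (module, bias_shape, using_kwargs, world_size, port), and mp.spawn then calls the resulting function once per process, supplying the process index as the first positional argument (the rank parameter of the check functions). A stripped-down sketch of the same pattern; the worker body and the hard-coded port are placeholders, not the tests' real logic:

from functools import partial

import torch.multiprocessing as mp


def worker(rank, world_size, port):
    # stand-in for check_1d/2d_device_mesh; mp.spawn supplies `rank` automatically
    print(f"rank {rank} of {world_size}, rendezvous port {port}")


if __name__ == '__main__':
    run_func = partial(worker, world_size=4, port=29500)    # 29500 is an arbitrary example port
    mp.spawn(run_func, nprocs=4)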