[autoparallel] update_getattr_handler (#2193)

Authored by YuliangLiu0306 on 2022-12-26 21:57:39 +08:00, committed by GitHub
parent f10ce01e31
commit 4851f2d607
6 changed files with 136 additions and 58 deletions


@@ -35,25 +35,59 @@ class AddmmModel(nn.Module):
        return x


-def check_linear_function_handler(rank, input_shape, world_size, port):
+class AddmmModel_with_param(nn.Module):
+
+    def __init__(self, weight_shape, bias_shape):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.rand(weight_shape))
+        self.bias = torch.nn.Parameter(torch.rand(bias_shape))
+
+    def forward(self, m1):
+        x = torch.addmm(self.bias, m1, self.weight, beta=3, alpha=2)
+        return x
+
+
+def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    model = AddmmModel().cuda()
+    if model_cls == AddmmModel:
+        model = AddmmModel().cuda()
+    else:
+        model = AddmmModel_with_param(weight_shape=(8, 16), bias_shape=input_shape).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(input_shape).cuda()
-    m1 = torch.rand(4, 8).cuda()
-    m2 = torch.rand(8, 16).cuda()
-    # the index of addmm node in computation graph
-    node_index = 4
-    # strategy number of linear node
-    strategy_number = 14
-    # construct input args
-    input_args = [input, m1, m2]
-    # construct meta arg names
-    meta_arg_names = ['input', 'm1', 'm2']
+    if model_cls == AddmmModel:
+        input = torch.rand(input_shape).cuda()
+        m1 = torch.rand(4, 8).cuda()
+        m2 = torch.rand(8, 16).cuda()
+        # construct input args
+        input_args = [input, m1, m2]
+        # construct meta arg names
+        meta_arg_names = ['input', 'm1', 'm2']
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+        # the index of addmm node in computation graph
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
+    else:
+        m1 = torch.rand(4, 8).cuda()
+        # construct input args
+        input_args = [m1]
+        # construct meta arg names
+        meta_arg_names = ['m1']
+        meta_args_for_tracer = {}
+        for meta_arg, input_arg in zip(meta_arg_names, input_args):
+            meta_args_for_tracer[meta_arg] = input_arg.to('meta')
+        # the index of addmm node in computation graph
+        node_index = 4
+        # strategy number of linear node
+        strategy_number = 14
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
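
For reference, the torch.addmm call used by AddmmModel_with_param computes beta * bias + alpha * (m1 @ weight), so with beta=3 and alpha=2 the module returns 3 * bias + 2 * (m1 @ weight). A minimal standalone sketch with the same shapes as the test (m1 (4, 8), weight (8, 16), bias broadcast from (16,)):

import torch

m1 = torch.rand(4, 8)
weight = torch.rand(8, 16)
bias = torch.rand(16)  # broadcast across the 4 rows, like bias_shape=(16,) above

out = torch.addmm(bias, m1, weight, beta=3, alpha=2)
expected = 3 * bias + 2 * (m1 @ weight)
assert torch.allclose(out, expected)
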
@@ -73,12 +107,7 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
    # %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
    # %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
    # return add
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(input_shape).to('meta'),
-                             'm1': torch.rand(4, 8).to('meta'),
-                             'm2': torch.rand(8, 16).to('meta'),
-                         })
+    graph = tracer.trace(model, meta_args=meta_args_for_tracer)
    gm = ColoGraphModule(model, graph)
    # [input_1, m1, m2, addmm, output]
    node_list = list(graph.nodes)
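
Each entry of meta_args_for_tracer built above is a tensor moved to PyTorch's meta device, which keeps only shape and dtype, so the tracer can propagate shapes without allocating real CUDA memory. A minimal sketch of what such an entry looks like:

import torch

m1_meta = torch.rand(4, 8).to('meta')
print(m1_meta.shape)   # torch.Size([4, 8])
print(m1_meta.device)  # meta -- no data is stored for this tensor
# e.g. meta_args_for_tracer = {'m1': m1_meta}
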
@@ -155,11 +184,13 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('input_shape', [(16,), (4, 16)])
+@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
@rerun_if_address_is_in_use()
-def test_addmm_handler(input_shape):
+def test_addmm_handler(input_shape, model_cls):
    world_size = 4
-    run_func_function = partial(check_linear_function_handler,
+    run_func_function = partial(check_addmm_function_handler,
                                input_shape=input_shape,
+                                model_cls=model_cls,
                                world_size=world_size,
                                port=free_port())
    mp.spawn(run_func_function, nprocs=world_size)
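
The two stacked @parameterize decorators expand test_addmm_handler over every (input_shape, model_cls) pair, i.e. 2 shapes x 2 model classes = 4 configurations. A rough, illustrative re-implementation of that expansion (not ColossalAI's actual decorator) for readers unfamiliar with the pattern:

import functools

def parameterize_sketch(arg_name, values):
    """Illustrative only: run the wrapped function once per value of arg_name."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[arg_name] = value
                fn(*args, **kwargs)
        return wrapper
    return decorator

# Stacking two such decorators yields the cross product of their value lists.
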


@@ -39,6 +39,7 @@ def test_getattr_handler():
                                     strategies_vector=getattr_strategies_vector)
+    getattr_handler.register_strategy(compute_resharding_cost=False)
    # check operation data mapping
    mapping = getattr_handler.get_operation_data_mapping()
@@ -51,7 +52,15 @@ def test_getattr_handler():
    assert mapping['output'].data.shape == torch.Size((16, 4, 3, 3))
    assert mapping['output'].type == OperationDataType.OUTPUT
    strategy_name_list = [val.name for val in getattr_handler.strategies_vector]
-    assert "Replica Attribute" in strategy_name_list
+    assert 'get_attr [S0, S1, R, R]' in strategy_name_list
+    assert 'get_attr [S1, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S01, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S01, R, R]' in strategy_name_list
+    assert 'get_attr [S0, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S0, R, R]' in strategy_name_list
+    assert 'get_attr [S1, R, R, R]' in strategy_name_list
+    assert 'get_attr [R, S1, R, R]' in strategy_name_list
+    assert 'get_attr [R, R, R, R]' in strategy_name_list

if __name__ == '__main__':
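
Each strategy name above encodes a sharding spec for the (16, 4, 3, 3) parameter on the 2x2 device mesh: S0/S1 shard a tensor dimension along mesh axis 0 or 1, S01 shards it along both, and R leaves it replicated. An ad-hoc sketch (not ColossalAI API) of the per-device shard shape each spec implies:

global_shape = (16, 4, 3, 3)
mesh_shape = (2, 2)  # mesh axis 0 -> S0, mesh axis 1 -> S1

def local_shard_shape(spec):
    sizes = {'S0': mesh_shape[0], 'S1': mesh_shape[1],
             'S01': mesh_shape[0] * mesh_shape[1], 'R': 1}
    return tuple(dim // sizes[s] for dim, s in zip(global_shape, spec))

print(local_shard_shape(['S0', 'S1', 'R', 'R']))  # (8, 2, 3, 3)
print(local_shard_shape(['S01', 'R', 'R', 'R']))  # (4, 4, 3, 3)
print(local_shard_shape(['R', 'R', 'R', 'R']))    # (16, 4, 3, 3)
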


@@ -149,10 +149,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
            param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
        else:
            if 'weight' in name:
-                param_sharding_spec = list(graph.nodes)[4].sharding_spec
-            elif 'bias' in name:
-                param_sharding_spec = list(graph.nodes)[5].sharding_spec
+                param_sharding_spec = None
+                for node in list(graph.nodes):
+                    if 'weight' in node.name:
+                        param_sharding_spec = node.sharding_spec
+            elif 'bias' in name:
+                param_sharding_spec = None
+                for node in list(graph.nodes):
+                    if 'bias' in node.name:
+                        param_sharding_spec = node.sharding_spec
+        assert param_sharding_spec is not None
        grad_sharded = param_to_shard_dict[name].grad
        grad_to_compare = param_to_compare_dict[name].grad
        global_grad = to_global(grad_sharded, param_sharding_spec)
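
The replacement above drops the hard-coded node indices [4] and [5], which only matched the graph traced from the original AddmmModel; for parameter-holding models the weight and bias show up as get_attr nodes at different positions, so the spec is now found by node name. A hypothetical helper capturing that lookup (the name is illustrative, not part of the test utilities):

def find_sharding_spec_by_keyword(graph, keyword):
    """Hypothetical helper: return the sharding spec of the first graph node
    whose name contains `keyword`, or None if no such node exists."""
    for node in graph.nodes:
        if keyword in node.name:
            return node.sharding_spec
    return None

# e.g. param_sharding_spec = find_sharding_spec_by_keyword(graph, 'weight')
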