[autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler

* polish
YuliangLiu0306 authored on 2023-01-12 09:35:10 +08:00; committed by GitHub
parent c9ec5190a0
commit 8221fd7485
3 changed files with 74 additions and 23 deletions

@@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]


-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
+class BEOpModelWithNodeConst(nn.Module):
+
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
+        self.op = op
+        self.const = const
+
+    def forward(self, x1):
+        out = self.op(x1, self.const)
+        return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    class BinaryElementwiseOpModel(nn.Module):
-
-        def __init__(self, op, const):
-            super().__init__()
-            self.op = op
-            self.const = const
-
-        def forward(self, x1):
-            out = self.op(x1, self.const)
-            return out
-
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()

     # the index of binary-elementwise node in computation graph
     node_index = 1
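
The two new model classes cover different scalar-operand paths: BEOpModelWithNodeConst derives its constant from x1.dim(), which the tracer records as an extra graph node, while BEOpModelWithIntConst bakes a plain Python int into the call. A minimal sketch with vanilla torch.fx (standing in for the ColoTracer used in the test, which should produce the same node sequence for these models) shows why the binary-elementwise op lands at graph index 2 in one case and index 1 in the other:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class BEOpModelWithNodeConst(nn.Module):

    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, x1):
        const = x1.dim()    # traced as a call_method node, not a Python int
        return self.op(x1, const)


class BEOpModelWithIntConst(nn.Module):

    def __init__(self, op, const):
        super().__init__()
        self.op = op
        self.const = const

    def forward(self, x1):
        return self.op(x1, self.const)    # the int is folded into the call as a constant


node_graph = symbolic_trace(BEOpModelWithNodeConst(torch.add)).graph
int_graph = symbolic_trace(BEOpModelWithIntConst(torch.add, 2)).graph
print([node.op for node in node_graph.nodes])
# ['placeholder', 'call_method', 'call_function', 'output'] -> op node at index 2
print([node.op for node in int_graph.nodes])
# ['placeholder', 'call_function', 'output'] -> op node at index 1
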
@@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    # assert False
     gm = ColoGraphModule(model, graph)
-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)

     # build handler
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
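
The split-out test is gated on the AUTO_PARALLEL environment flag via run_on_environment_flag, which comes from colossalai.testing. A rough stand-in for the gating idea (a sketch, not the library's actual implementation):

import os

import pytest


def run_on_environment_flag(name):
    """Skip the decorated test unless the named environment variable is set,
    e.g. run with: AUTO_PARALLEL=1 pytest <test file>."""
    return pytest.mark.skipif(not os.environ.get(name),
                              reason=f'environment flag {name} is not set')
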
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()

@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                  ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
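
Context for the target_node change: node_index is meant to count strategy-bearing nodes, and the solution vector is indexed the same way (solution_len is taken from leaf_strategies just below), but graph.nodes also contains nodes the constructor assigns no strategies to, such as the call_method node created by x1.dim() in the new test, so indexing graph.nodes directly drifts once such a node appears. An illustration with plain torch.fx, assuming for the sketch that non-tensor-producing nodes are the ones skipped:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class NodeConstModel(nn.Module):

    def forward(self, x1):
        return torch.add(x1, x1.dim())


all_nodes = list(symbolic_trace(NodeConstModel()).graph.nodes)
# The dim() node yields an int, so there is nothing to shard there; filtering it
# (and the output node) out shifts every later index.
strategy_nodes = [n for n in all_nodes if n.op != 'output' and n.target != 'dim']
print(all_nodes.index(strategy_nodes[1]))    # 2: graph index of the add node,
                                             # whose strategy index is only 1
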
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
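
The preparation pass is now handed the strategies_constructor alongside the solver solution, and runtime_apply_pass rewrites the traced graph, which is why gm.recompile() follows: recompiling regenerates forward() from the mutated graph. That last step is plain torch.fx mechanics:

import torch
from torch.fx import symbolic_trace


class M(torch.nn.Module):

    def forward(self, x):
        return torch.add(x, 1)


gm = symbolic_trace(M())
for node in gm.graph.nodes:
    if node.target is torch.add:
        node.target = torch.mul    # edit the graph in place, as a rewrite pass would
gm.recompile()                     # regenerate forward() from the mutated graph
print(gm(torch.ones(2)))           # tensor([1., 1.]) -- the edit took effect
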