mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler
* polish
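The change extends the handler tests to cover a binary elementwise op whose second operand is a scalar constant. Such a constant can reach the op in two ways: as a value computed by another graph node (e.g. `x1.dim()`) or as a literal Python int baked into the call. Eagerly the two are indistinguishable; a quick sanity check of that claim:

import torch

x1 = torch.rand(4, 4)
# x1.dim() evaluates to 2 for a 2-D tensor, so both forms compute the same result...
assert torch.equal(torch.add(x1, x1.dim()), torch.add(x1, 2))
# ...but under symbolic tracing they produce different graphs
# (see the node-index sketch after the second hunk below).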
@@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]


-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
+class BEOpModelWithNodeConst(nn.Module):
+
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
+
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
+        self.op = op
+        self.const = const
+
+    def forward(self, x1):
+        out = self.op(x1, self.const)
+        return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    class BinaryElementwiseOpModel(nn.Module):
-
-        def __init__(self, op, const):
-            super().__init__()
-            self.op = op
-            self.const = const
-
-        def forward(self, x1):
-            out = self.op(x1, self.const)
-            return out
-
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()
     # the index of binary-elementwise node in computation graph
     node_index = 1
@@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    # assert False
     gm = ColoGraphModule(model, graph)

-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)

     # build handler
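The branch on `model_cls` above exists because tracing `x1.dim()` materializes an extra node in front of the binary op. A minimal sketch with plain `torch.fx.symbolic_trace` (the test itself uses ColoTracer; the assumption here is that the node layout for these small models comes out the same, which matches the graph printed above):

import torch
import torch.fx
import torch.nn as nn

class NodeConstModel(nn.Module):
    # mirrors BEOpModelWithNodeConst: the constant is produced by a graph node
    def forward(self, x1):
        const = x1.dim()              # traced as a call_method node
        return torch.add(x1, const)

class IntConstModel(nn.Module):
    # mirrors BEOpModelWithIntConst: the constant is a plain Python int
    def forward(self, x1):
        return torch.add(x1, 2)       # the int is folded into the call args

print([n.op for n in torch.fx.symbolic_trace(NodeConstModel()).graph.nodes])
# ['placeholder', 'call_method', 'call_function', 'output']  -> add sits at index 2
print([n.op for n in torch.fx.symbolic_trace(IntConstModel()).graph.nodes])
# ['placeholder', 'call_function', 'output']                 -> add sits at index 1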
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
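For context, `parameterize` here is ColossalAI's testing decorator: it reruns the wrapped function once per listed value, and stacked decorators multiply out, so the new `_with_int` test in the next hunk runs 1 x 2 x 2 = 4 combinations. A rough behavioral sketch (assumed semantics, not the library's actual implementation):

def parameterize(name, values):
    # Assumed behavior: call the wrapped function once per value,
    # injecting it as a keyword argument.
    def decorator(fn):
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[name] = value
                fn(*args, **kwargs)
        return wrapper
    return decorator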
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)


+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
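Both tests follow the same launch pattern: bind every argument except the rank with `functools.partial`, then let `torch.multiprocessing.spawn` supply the rank. `mp.spawn` passes the process index as the first positional argument, which is why each `check_*` function takes `rank` first. A self-contained sketch of the pattern, with `check_fn` as a stand-in for the real check functions:

from functools import partial
import torch.multiprocessing as mp

def check_fn(rank, world_size, port):
    # stand-in for check_binary_elementwise_handler_with_tensor / _with_int
    print(f'rank {rank} of {world_size} would init NCCL on port {port}')

if __name__ == '__main__':
    world_size = 4
    run_func = partial(check_fn, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)   # spawn injects rank as the first arg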
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):


 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                   ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
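The rewritten lookup resolves `node_index` against the strategy-bearing nodes rather than raw graph order, presumably so that nodes without strategies cannot shift the index. An equivalent, slightly more readable form of the new line, using the same names as the diff:

# Each entry of leaf_strategies is a StrategiesVector; .node points back
# at the FX node the strategies were built for.
strategy_nodes = [sv.node for sv in strategies_constructor.leaf_strategies]
target_node = strategy_nodes[node_index]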
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         ret = solver.call_solver_serialized_args()
         solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()