[autoparallel] add sequential order to communication actions (#1735)

Author: YuliangLiu0306 (committed by GitHub)
Date: 2022-10-20 18:48:18 +08:00
Parent: b893342f95
Commit: a4ce180e85
7 changed files with 293 additions and 90 deletions


@@ -1,3 +1,4 @@
+import copy
 from functools import partial
 import pytest
@@ -6,15 +7,22 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
-                                                                                solution_annotatation_pass)
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
@@ -27,6 +35,7 @@ class ConvModel(nn.Module):
     def forward(self, x):
         x = self.conv(x)
+        x = torch.flatten(x)
         return x
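
The hunk above touches the toy model used by the test: its forward now flattens the conv output before returning it. For orientation, here is a minimal sketch of the full module. Only the ConvModel(4, 4) call and the forward body are visible in this diff, so the constructor arguments, kernel size, and padding below are guesses:

    import torch
    import torch.nn as nn

    class ConvModel(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            # kernel_size and padding are assumptions; the diff only
            # shows self.conv being called in forward
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            x = self.conv(x)
            x = torch.flatten(x)
            return x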
@@ -38,12 +47,13 @@ def check_apply(rank, world_size, port):
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
     entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
     tracer = ColoTracer()
     model = ConvModel(4, 4).cuda()
-    origin_output = model(input)
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
     input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
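
The mesh comment in the hunk above spells out how the four ranks fold into a 2 x 2 logical device mesh. A standalone sketch of that mapping in plain PyTorch (no colossalai required):

    import torch

    physical_mesh_id = torch.arange(4)                  # ranks 0..3
    mesh_shape = (2, 2)
    logical_mesh = physical_mesh_id.reshape(mesh_shape)
    print(logical_mesh)
    # tensor([[0, 1],
    #         [2, 3]])

Ranks that share a row (0/1 and 2/3) end up holding the same weight shard in this test, which is why the gradient assertions later group the ranks that way; process groups along each mesh axis are presumably what create_process_groups_for_logical_mesh builds.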
@@ -62,16 +72,30 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    origin_output = test_model(test_input)
     assert output.equal(origin_output)
+    origin_loss = origin_output.sum()
+    loss = output.sum()
+    origin_loss.backward()
+    loss.backward()
+    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
+    grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
+    if rank in (0, 1):
+        assert_close(gm.conv.weight.grad.data, grad_0.data)
+    elif rank in (2, 3):
+        assert_close(gm.conv.weight.grad.data, grad_1.data)


 # skip this test due to pulp not installed in CI environment
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
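
The new assertions at the end of check_apply verify gradients, not just the forward output: both the sharded gm and the replicated test_model run a backward pass from a sum loss, and each rank compares its local conv.weight.grad against the matching slice of the reference gradient. The two narrow(0, ...) calls split the weight's output-channel dimension in half, so ranks 0 and 1 (the first mesh row) check against channels 0-1 and ranks 2 and 3 against channels 2-3. A self-contained sketch of just that slicing:

    import torch

    full_grad = torch.randn(4, 4, 3, 3)    # stand-in for conv.weight.grad
    grad_0 = full_grad.narrow(0, 0, 2)     # channels 0-1, held by ranks 0 and 1
    grad_1 = full_grad.narrow(0, 2, 2)     # channels 2-3, held by ranks 2 and 3
    assert torch.equal(torch.cat([grad_0, grad_1], dim=0), full_grad)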
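
As for the headline change, solution_annotatation_pass now returns a third value, comm_actions_dict, which is passed into the recompiled GraphModule so that communication actions execute in a fixed, sequential order. Determinism matters here: every rank must issue the same collectives in the same sequence, or mismatched all-reduce/all-gather calls can deadlock the job. The sketch below is a hypothetical illustration of that idea only; the dict layout and action signatures are invented for this example and are not the ColossalAI API:

    def run_comm_actions(tensor, comm_actions_dict):
        # Iterate in ascending key order so every rank walks the
        # actions in the same sequence.
        for idx in sorted(comm_actions_dict):
            tensor = comm_actions_dict[idx](tensor)
        return tensor

    # Toy usage with no-op "collectives" standing in for all-reduce/all-gather:
    actions = {0: lambda t: t * 1, 1: lambda t: t + 0}
    assert run_comm_actions(3.0, actions) == 3.0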