Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[autoparallel] add sequential order to communication actions (#1735)
@@ -1,3 +1,4 @@
+import copy
 from functools import partial
 
 import pytest
@@ -6,15 +7,22 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
 
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                           StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
-                                                                                solution_annotatation_pass)
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
 
@@ -27,6 +35,7 @@ class ConvModel(nn.Module):
 
     def forward(self, x):
         x = self.conv(x)
         x = torch.flatten(x)
         return x
 
 
@@ -38,12 +47,13 @@ def check_apply(rank, world_size, port):
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
     entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
     tracer = ColoTracer()
     model = ConvModel(4, 4).cuda()
-    origin_output = model(input)
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
 
     input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
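A note on the mesh setup above: mesh_shape = (2, 2) arranges the four physical ranks into the logical grid sketched in the "# [[0, 1] / #  [2, 3]]" comment. The following plain-PyTorch snippet illustrates that mapping, assuming physical_mesh_id simply enumerates ranks 0 through 3 (its definition sits above this hunk and is not shown); it is only a sketch of the layout, not the DeviceMesh API.

import torch

# Assumed stand-in for the test's physical_mesh_id: the four ranks 0..3.
physical_mesh_id = torch.arange(4)
mesh_shape = (2, 2)

# Reshaping the rank ids reproduces the logical mesh from the comment.
logical_mesh = physical_mesh_id.reshape(mesh_shape)
print(logical_mesh)  # tensor([[0, 1],
                     #         [2, 3]])

# Row 0 holds ranks (0, 1) and row 1 holds ranks (2, 3); the gradient checks added in
# the next hunk select the expected weight-gradient slice by mesh row.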
@@ -62,16 +72,30 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    origin_output = test_model(test_input)
     assert output.equal(origin_output)
+    origin_loss = origin_output.sum()
+    loss = output.sum()
+
+    origin_loss.backward()
+    loss.backward()
+
+    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
+    grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
+
+    if rank in (0, 1):
+        assert_close(gm.conv.weight.grad.data, grad_0.data)
+    elif rank in (2, 3):
+        assert_close(gm.conv.weight.grad.data, grad_1.data)
 
 
 # skip this test due to pulp not installed in CI environment
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
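The new per-rank assertions compare each rank's sharded convolution weight gradient against a narrow() slice of the reference model's gradient. Below is a hypothetical single-process sketch of why that comparison holds when the output channels are sharded and the loss is a plain sum; it uses vanilla PyTorch and torch.testing.assert_close rather than the ColossalAI passes, and it assumes a 3x3 kernel for ConvModel's conv layer, which this diff does not show.

import torch
import torch.nn as nn
from torch.testing import assert_close

torch.manual_seed(0)
x = torch.rand(4, 4, 8, 8)

# Unsharded reference conv (4 output channels) and a "shard" owning channels 0-1,
# mimicking what the first mesh row would hold. kernel_size=3 is an assumption.
full = nn.Conv2d(4, 4, kernel_size=3)
shard = nn.Conv2d(4, 2, kernel_size=3)
with torch.no_grad():
    shard.weight.copy_(full.weight[:2])
    shard.bias.copy_(full.bias[:2])

# The test's loss is output.sum(), which decomposes per output channel, so the shard's
# weight gradient must equal the matching rows of the full weight gradient.
full(x).flatten().sum().backward()
shard(x).flatten().sum().backward()

grad_0 = full.weight.grad.narrow(0, 0, 2)  # rows that ranks (0, 1) should reproduce
assert_close(shard.weight.grad, grad_0)

The same reasoning with narrow(0, 2, 2) covers the second mesh row, ranks (2, 3).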