[autoparallel] shard param and buffer as expected (#1753)

* [autoparallel] shard param and buffer as expected

* fix unit test issue
YuliangLiu0306 authored on 2022-10-21 15:45:13 +08:00, committed by GitHub
parent cdb7d5e7d2
commit 980ed21723
6 changed files with 129 additions and 106 deletions
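The commit is exercised by the four-GPU test below, which shards the parameters and buffers of a Bottleneck block over a 2 x 2 logical device mesh and checks the sharded module's forward and backward results against an unsharded copy. As a quick illustration of the mesh layout the test configures (the reshape here only mimics the [[0, 1], [2, 3]] comment in the diff; it is not how DeviceMesh is implemented internally):

    import torch

    # Four physical devices arranged as a 2x2 logical mesh, as in the test
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    logical_mesh = physical_mesh_id.reshape(mesh_shape)
    print(logical_mesh)
    # tensor([[0, 1],
    #         [2, 3]])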

@@ -1,28 +1,32 @@
+import copy
+from copy import deepcopy
 from functools import partial
+import pytest
 import torch
 import torch.multiprocessing as mp
 from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
-from colossalai import device
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from torch.fx import GraphModule
 from torchvision.models import resnet34, resnet50
-from colossalai import device
 from colossalai.auto_parallel.tensor_shard.constants import *
-from colossalai.testing import assert_close_loose, assert_close
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
 
 seed = 128
 cudnn_benchmark = False
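For context, distributed tests in this suite are launched through torch.multiprocessing, which is why partial, mp, free_port and rerun_if_address_is_in_use are imported above. A minimal sketch of such an entry point, assuming a test function name and a four-process world size that are not shown in this diff:

    from functools import partial

    import torch.multiprocessing as mp

    from colossalai.testing import rerun_if_address_is_in_use
    from colossalai.utils import free_port

    @rerun_if_address_is_in_use()
    def test_apply_bottleneck():
        # hypothetical wrapper: spawns one worker per mesh device and passes
        # the rank as the first positional argument of check_apply_bottleneck
        world_size = 4
        run_func = partial(check_apply_bottleneck, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)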
@@ -108,16 +112,17 @@ class Bottleneck(nn.Module):
 def check_apply_bottleneck(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    input = torch.rand(256, 64, 64, 64).cuda()
+    input = torch.rand(4, 4, 4, 4).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
-    entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
     tracer = ColoTracer()
-    model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
     #     %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
@@ -130,9 +135,8 @@ def check_apply_bottleneck(rank, world_size, port):
     #     %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
     #     %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
     #     return relu_2
-    input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
-    cuda_rng_state = torch.cuda.get_rng_state()
-    origin_output = model(input)
+    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
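The input_sample above is moved to the 'meta' device so that tracing only propagates shapes and dtypes and never allocates GPU memory. A small self-contained illustration of meta tensors (plain PyTorch, independent of this diff):

    import torch

    # A meta tensor carries shape and dtype but has no storage behind it.
    x = torch.rand(4, 4, 4, 4).to('meta')
    print(x.shape)    # torch.Size([4, 4, 4, 4])
    print(x.device)   # meta
    print(x.is_meta)  # True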
@@ -147,16 +151,42 @@ def check_apply_bottleneck(rank, world_size, port):
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     print(solution)
     device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    for index, node in enumerate(graph.nodes):
+        print(node.name, node.strategies_vector[solution[index]].name)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
+    cuda_rng_state = torch.cuda.get_rng_state()
+    origin_output = test_model(test_input)
     torch.cuda.set_rng_state(cuda_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
     assert output.shape == origin_output.shape
-    assert output.equal(origin_output)
+    assert_close(output, origin_output)
+    print("*******************backward starting*******************")
+    cuda_rng_state = torch.cuda.get_rng_state()
+    output.sum().backward()
+    torch.cuda.set_rng_state(cuda_rng_state)
+    origin_output.sum().backward()
+    if rank == 0:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 1:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 2:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+    if rank == 3:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
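The narrow offsets in the per-rank checks above follow from the block's channel count: the Bottleneck used here follows torchvision's design with expansion = 4, so Bottleneck(4, 4, 1) gives the conv3/bn3 pair 16 output channels. The solution shards bn3.weight four ways (4 channels per rank) and conv3.weight two ways along dim 0 (8 rows shared within each mesh row). A stand-alone sketch of that shard arithmetic (full and the shard lists are illustrative and not part of the test):

    import torch

    full = torch.arange(16.)  # stands in for a 16-channel gradient

    # bn3.weight is sharded four ways: rank r holds channels [4r, 4r + 4)
    bn3_shards = [full.narrow(0, r * 4, 4) for r in range(4)]
    assert torch.cat(bn3_shards).equal(full)

    # conv3.weight is sharded two ways along dim 0: ranks 0/1 hold rows
    # 0..7 and ranks 2/3 hold rows 8..15, matching the asserted offsets
    conv3_shards = [full.narrow(0, (r // 2) * 8, 8) for r in range(4)]
    assert conv3_shards[0].equal(conv3_shards[1])
    assert conv3_shards[2].equal(conv3_shards[3])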