Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] shard param and buffer as expected (#1753)
* [autoparallel] shard param and buffer as expected
* fix unit test issue
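In short: the annotation pass now returns a third value, comm_actions_dict, describing the collectives each node needs at runtime; the compiled GraphModule takes that dict as an extra argument; and parameters/buffers are physically sharded, so the test compares each rank's gradients against narrow() slices of an unsharded reference model. A minimal sketch of that per-rank comparison idea (pure PyTorch; the shard sizes and names here are illustrative, not the test's exact code):

    import torch

    # Reference (unsharded) gradient for a 16-channel parameter.
    full_grad = torch.arange(16.0)
    world_size = 4
    shard = full_grad.numel() // world_size    # 4 channels per rank
    for rank in range(world_size):
        # Each rank holds one contiguous slice; the test checks it with
        # assert_close_loose against the matching narrow() of the reference.
        local = full_grad.narrow(0, rank * shard, shard)
        assert local.numel() == shard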
@@ -1,28 +1,32 @@
 import copy
 from copy import deepcopy
 from functools import partial
 
 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
-from colossalai import device
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
+from torch.fx import GraphModule
 from torchvision.models import resnet34, resnet50
 
 from colossalai import device
 from colossalai.auto_parallel.tensor_shard.constants import *
-from colossalai.testing import assert_close_loose, assert_close
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
 
 seed = 128
 cudnn_benchmark = False
@@ -108,16 +112,17 @@ class Bottleneck(nn.Module):
 def check_apply_bottleneck(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    input = torch.rand(256, 64, 64, 64).cuda()
+    input = torch.rand(4, 4, 4, 4).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
-    entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
     tracer = ColoTracer()
-    model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
     #     %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
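The flip from init_process_group=False to init_process_group=True matters here, likely because the communication actions emitted by the new pass run real collectives at runtime and therefore need initialized process groups. A minimal sketch of the mesh setup, reusing only calls that appear in this diff (assumes a 4-process NCCL launch, e.g. via colossalai's launch):

    import torch

    from colossalai.device.device_mesh import DeviceMesh

    # Four physical devices arranged as a 2x2 logical mesh:
    # [[0, 1],
    #  [2, 3]]
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    # Process groups per logical mesh axis, used by the runtime collectives.
    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()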
@@ -130,9 +135,8 @@ def check_apply_bottleneck(rank, world_size, port):
     # %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
     # %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
     # return relu_2
-    input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
-    cuda_rng_state = torch.cuda.get_rng_state()
-    origin_output = model(input)
+    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
+
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
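The tracing sample lives on the meta device, so no real memory is allocated: only shapes and dtypes flow through the tracer, and the sample's shape now matches the real (4, 4, 4, 4) input. A tiny illustration of what meta tensors carry (plain PyTorch, independent of the test):

    import torch

    m = torch.rand(4, 4, 4, 4).to('meta')
    print(m.shape, m.dtype, m.device)    # torch.Size([4, 4, 4, 4]) torch.float32 meta
    n = m + m                            # shape/dtype propagation works symbolically...
    # ...but there is no storage behind it: n.cpu() or n.tolist() would fail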
@@ -147,16 +151,42 @@ def check_apply_bottleneck(rank, world_size, port):
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    print(solution)
     device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    for index, node in enumerate(graph.nodes):
+        print(node.name, node.strategies_vector[solution[index]].name)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
+    cuda_rng_state = torch.cuda.get_rng_state()
+    origin_output = test_model(test_input)
     torch.cuda.set_rng_state(cuda_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
 
     assert output.shape == origin_output.shape
-    assert output.equal(origin_output)
+    assert_close(output, origin_output)
+    print("*******************backward starting*******************")
+    cuda_rng_state = torch.cuda.get_rng_state()
+    output.sum().backward()
+    torch.cuda.set_rng_state(cuda_rng_state)
+    origin_output.sum().backward()
+    if rank == 0:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 1:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 2:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+
+    if rank == 3:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
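The narrow() offsets line up with the sharding on the 2x2 mesh, assuming the file's Bottleneck keeps torchvision's expansion of 4: Bottleneck(4, 4, 1) gives conv3 and bn3 16 output channels (planes x expansion), bn3's weight is split four ways (4 channels per rank at offsets 0, 4, 8, 12), and conv3's weight is split two ways along one mesh axis (8 channels per shard: ranks 0-1 at offset 0, ranks 2-3 at offset 8). A hypothetical helper reproducing the offsets the test hard-codes:

    # Hypothetical: reproduce the per-rank shard offsets used in the asserts above.
    def expected_offsets(rank: int, num_channels: int = 16):
        bn3_size = num_channels // 4              # 4-way shard -> 4 channels each
        conv3_size = num_channels // 2            # 2-way shard -> 8 channels each
        bn3_offset = rank * bn3_size              # 0, 4, 8, 12
        conv3_offset = (rank // 2) * conv3_size   # ranks 0,1 -> 0; ranks 2,3 -> 8
        return bn3_offset, conv3_offset

    for r in range(4):
        print(r, expected_offsets(r))   # (0, 0) (4, 0) (8, 8) (12, 8)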
||||