[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
Frank Lee, 2022-10-19 12:53:06 +08:00 (committed by GitHub)
parent cbe9a4cb45, commit eee84908d4
36 changed files with 459 additions and 303 deletions


@@ -1,12 +1,15 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
from cProfile import run
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
        return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)


@@ -1,12 +1,13 @@
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MatmulModel(nn.Module):
@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
        return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)


@@ -0,0 +1,37 @@
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
    """
    This function checks whether the ShardingSpec is valid for the physical tensor.
    The check covers two items:
    1. the sharding spec covers every dimension of the physical tensor
    2. the size of each sharded dimension is divisible by the number of devices it is sharded over
    """
    # make sure all dims are covered in the sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for a {sharding_len}-dimension tensor, but the given tensor has {tensor_num_dim} dimensions ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        if str(dim_spec).startswith('S'):
            devices_str = str(dim_spec).lstrip('S')
            num_devices = 1

            if '0' in devices_str:
                num_devices *= num_devices_in_col
            if '1' in devices_str:
                num_devices *= num_devices_in_row

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has size {dim_size}, which cannot be evenly sharded over {num_devices} devices.'
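For reference, a minimal sketch of how this helper might be called, assuming the ShardingSpec(device_mesh, entire_shape, dim_partition_dict) constructor and the 2x2 DeviceMesh built by the surrounding tests:

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import is_sharding_spec_valid

# 2x2 logical mesh over 4 physical devices, as in the tests in this commit
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# shard dim 0 of a (16, 8, 6) tensor over both mesh axes, i.e. S01,R,R
tensor = torch.rand(16, 8, 6)
sharding_spec = ShardingSpec(device_mesh, tensor.shape, dim_partition_dict={0: [0, 1]})

# passes: the spec covers all 3 dims and 16 is divisible by the 4 devices sharding dim 0;
# a spec that sharded a dimension of size 3 over 2 devices would trip the assert instead
is_sharding_spec_valid(sharding_spec, tensor)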


@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
    is_sharding_spec_valid
def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
@@ -91,6 +92,12 @@ def test_linear_module_handler():
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

        # make sure the sharding spec is valid
        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@@ -101,7 +108,7 @@ def test_linear_module_handler():
def test_linear_function_handler():
    model = nn.Linear(16, 32).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)

    physical_mesh_id = torch.arange(0, 4)
@@ -117,11 +124,13 @@ def test_linear_function_handler():
    # # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    print(mapping['input'].logical_shape)
    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16])
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16])
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
@@ -137,7 +146,7 @@ def test_linear_function_handler():
    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 32])
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -167,11 +176,18 @@ def test_linear_function_handler():
    for strategy in strategies_vector:
        strategy: ShardingStrategy
        print(strategy)
        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

        # make sure the sharding spec is valid
        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]


@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
    ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \


@@ -1,16 +1,18 @@
from functools import partial
from lib2to3 import pgen2
import colossalai
import torch
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
from colossalai.nn._ops._utils import gather_forward_split_backward
def run_dist(rank, world_size, port):
@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # create mlp vars
    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()


@@ -1,6 +1,7 @@
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
@@ -11,7 +12,7 @@ def test_sharding_spec():
    # [8, 9, 10,11],
    # [12,13,14,15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 8, 6))
    entire_shape = torch.Size((16, 8, 6))
    dim_partition_dict = {0: [0, 1]}
    # DistSpec:
    # shard_sequence: S01,R,R
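For orientation, a small sketch of the arithmetic behind the S01,R,R spec and the new (16, 8, 6) shape above; the device counts are assumed from the 2x2 mesh built in this test:

# dim 0 is sharded over both axes of the 2x2 mesh -> 2 * 2 = 4 devices
num_devices = 2 * 2
assert 16 % num_devices == 0                     # the divisibility rule checked elsewhere in this commit
local_shard_shape = (16 // num_devices, 8, 6)    # each device holds a (4, 8, 6) shard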