Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 22:19:38 +00:00)
[autoparallel] handled illegal sharding strategy (#1728)
* [autoparallel] handled illegal sharding strategy
* polish code
New file (empty): tests/test_auto_parallel/__init__.py
@@ -1,12 +1,15 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-from cProfile import run
+import pytest
+import torch
+import torch.nn as nn
+from torch.fx import GraphModule
+
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class ConvModel(nn.Module):

@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
@@ -1,12 +1,13 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
+import pytest
+import torch
+import torch.nn as nn
+from torch.fx import GraphModule
+
 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class MatmulModel(nn.Module):

@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
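Both test files above are now gated by run_on_environment_flag(name='AUTO_PARALLEL'), which skips the test unless the flag is set in the environment. A minimal sketch of the idea (a hypothetical reimplementation for illustration, not the actual code in colossalai.testing.pytest_wrapper):

import os
from functools import wraps

import pytest


def run_on_environment_flag(name: str):
    """Skip the decorated test unless the environment variable `name` is set to 1."""

    def decorator(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            # assumption: the real wrapper keys off os.environ in a similar way
            if os.environ.get(name, '0') != '1':
                pytest.skip(f'Environment variable {name} is not set, skipping test.')
            return func(*args, **kwargs)

        return wrapper

    return decorator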
@@ -0,0 +1,37 @@
+import torch
+
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
+    """
+    This function checks whether the ShardingSpec is valid for the physical tensor.
+    This check includes 2 items:
+        1. the sharding spec covers all dimensions of the physical tensor
+        2. the size of each sharded dimension is divisible by the number of devices it is sharded over
+    """
+    # make sure all dims are covered in the sharding spec
+    sharding_len = len(sharding_spec.sharding_sequence)
+    tensor_num_dim = tensor.dim()
+    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    assert sharding_len == tensor_num_dim, \
+        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for a {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+
+    # make sure the sharding is valid for each dim
+    for i in range(tensor_num_dim):
+        dim_size = tensor.shape[i]
+        dim_spec = sharding_spec.sharding_sequence[i]
+
+        if str(dim_spec).startswith('S'):
+            devices_str = str(dim_spec).lstrip('S')
+            num_devices = 1
+
+            if '0' in devices_str:
+                num_devices *= num_devices_in_col
+            if '1' in devices_str:
+                num_devices *= num_devices_in_row
+
+            assert dim_size >= num_devices and dim_size % num_devices == 0, \
+                f'The dimension at index {i} has size {dim_size}, but it is sharded over {num_devices} devices.'
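For orientation, a minimal usage sketch of the new helper (assuming the DeviceMesh(physical_mesh_id, mesh_shape) and ShardingSpec(device_mesh, entire_shape, dim_partition_dict) constructors that appear in the test hunks below):

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import is_sharding_spec_valid

# 4 physical devices arranged as a 2 x 2 logical mesh
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

# S01,R: dim 0 is sharded over both mesh axes, i.e. 2 * 2 = 4 devices
spec = ShardingSpec(device_mesh, torch.Size((16, 8)), {0: [0, 1]})

# passes: 16 >= 4 and 16 % 4 == 0
is_sharding_spec_valid(spec, torch.rand(16, 8))

# would raise: a size-6 dim cannot be evenly sharded over 4 devices
# is_sharding_spec_valid(spec, torch.rand(6, 8))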
@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
                                                                      StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.tensor.sharding_spec import ShardingSpec
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
+    is_sharding_spec_valid


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))

     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
@@ -91,6 +92,12 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@@ -101,7 +108,7 @@ def test_linear_module_handler():
 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -117,11 +124,13 @@
     # # check operation data mapping
     mapping = handler.get_operation_data_mapping()

+    print(mapping['input'].logical_shape)
+
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
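The logical_shape assertion above reflects that the linear handler folds all leading dimensions of an N-D input into one before treating the op as a 2D matmul, so a [2, 2, 4, 16] input becomes a [16, 16] logical operand. A rough plain-PyTorch illustration of that flattening (not the handler's internal code):

import torch

x = torch.rand(2, 2, 4, 16)
# merge every leading dim into one: 2 * 2 * 4 = 16 rows, 16 features
logical = x.reshape(-1, x.shape[-1])
assert logical.shape == torch.Size([16, 16])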
@@ -137,7 +146,7 @@

     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -167,11 +176,18 @@

     for strategy in strategies_vector:
         strategy: ShardingStrategy
+        print(strategy)
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
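The two trailing asserts encode the sharding rule for a linear op (output = input @ weight.T): the non-contracted input dims must be sharded exactly like the corresponding output dims, and the weight's in-feature axis (index 1) must be sharded like the input's last axis, since those two are contracted together. A toy numeric check of the batch-dim case, independent of ColossalAI:

import torch
import torch.nn.functional as F

x = torch.rand(16, 8)    # [batch, in_features]
w = torch.rand(32, 8)    # [out_features, in_features]
y = F.linear(x, w)       # [16, 32]

# shard the batch dim of x across 4 devices: each shard yields the
# matching slice of y, so input and output batch sharding must agree
x_shards = torch.chunk(x, 4, dim=0)
y_shards = [F.linear(s, w) for s in x_shards]
assert torch.allclose(torch.cat(y_shards, dim=0), y)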
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
     ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
@@ -1,16 +1,18 @@
 from functools import partial
-from lib2to3 import pgen2
-import colossalai
-import torch

+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.nn._ops._utils import gather_forward_split_backward
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
+from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from functools import partial
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
-from colossalai.nn._ops._utils import gather_forward_split_backward


 def run_dist(rank, world_size, port):
@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # create mlp vars
-    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
     w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
     b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
@@ -1,6 +1,7 @@
 import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
+
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


 def test_sharding_spec():
@@ -11,7 +12,7 @@ def test_sharding_spec():
     # [8, 9, 10,11],
     # [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    entire_shape = torch.Size((4, 8, 6))
+    entire_shape = torch.Size((16, 8, 6))
     dim_partition_dict = {0: [0, 1]}
     # DistSpec:
     #     shard_sequence: S01,R,R
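With the new shape, dim 0 (sharded as S01 over both axes of the 2 x 2 mesh, i.e. 4 devices) splits into per-device shards of 4. A quick numeric sketch of the divisibility rule the new helper enforces (illustration only, not the ShardingSpec internals):

entire_shape = (16, 8, 6)
num_devices = 2 * 2                        # S01: sharded over both mesh axes
assert entire_shape[0] % num_devices == 0
shard_dim0 = entire_shape[0] // num_devices
print((shard_dim0,) + entire_shape[1:])    # (4, 8, 6) per device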