diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py index 2167e6ac2..4c1d2f3be 100644 --- a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py @@ -1,18 +1,20 @@ -import warnings - -import time -import numpy as np import multiprocessing -from torch.fx.node import Node -from torch.fx.graph import Graph -from .graph_analysis import GraphAnalyser -from .cost_graph import CostGraph -from .strategies_constructor import StrategiesConstructor +import time +import warnings from typing import Dict + +import numpy as np +from torch.fx.graph import Graph +from torch.fx.node import Node + from .constants import INFINITY_COST +from .cost_graph import CostGraph +from .graph_analysis import GraphAnalyser +from .strategies_constructor import StrategiesConstructor + try: import pulp - from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus + from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum except: warnings.warn(f'please install the pulp') diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 2184c3f47..299184b29 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -1,10 +1,16 @@ from abc import ABC, abstractmethod from typing import Dict, List, Union +import torch from torch.fx.node import Node -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, ShardingStrategy, StrategiesVector, - TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + ShardingStrategy, + StrategiesVector, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager @@ -98,6 +104,12 @@ class NodeHandler(ABC): self.strategies_vector.extend(post_processed_strategies) + # validating the correctness of the sharding strategy + for strategy in self.strategies_vector: + for op_data, sharding_spec in strategy.sharding_specs.items(): + if op_data.data is not None and isinstance(op_data.data, torch.Tensor): + check_sharding_spec_validity(sharding_spec, op_data.data) + return self.strategies_vector def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: @@ -116,8 +128,8 @@ class NodeHandler(ABC): def get_operation_data_mapping(self) -> Dict[str, OperationData]: """ Returns the mapping between the logical operation data to its physical data. - A logical operation data is a data associated with an operation, which can be input and output. It is - defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" + A logical operation data is a data associated with an operation, which can be input and output. It is + defined by the strategy generator, for example, a matrix multiplication operation has two operands "input" and "other" and one result "output". For a nn.Linear module, the physical operand for "input" is the module input, the physical operand for "other" is the module weight, and the physical result for "output" is the module output. 
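Note on the hunk above: `get_operation_data_mapping()` maps logical operand names to physical tensors, and `register_strategy()` now validates every tensor-backed operand against its sharding spec. The sketch below only illustrates the logical-to-physical mapping for an `nn.Linear` module as the docstring describes it; the `Linear(16, 32)` module, the dummy shapes, and the plain dict are made up for illustration and stand in for the `OperationData` objects the real handler returns.

```python
import torch
import torch.nn as nn

# Hypothetical module and input, chosen only for illustration; the real handler
# works on torch.fx nodes and meta tensors.
module = nn.Linear(16, 32)
dummy_input = torch.rand(4, 16)

# Logical name -> physical tensor. The real get_operation_data_mapping() returns
# OperationData objects rather than raw tensors; a plain dict keeps the sketch
# self-contained. A 'bias' operand is handled the same way when the module has one.
operation_data_mapping = {
    'input': dummy_input,           # logical "input"  -> the module input
    'other': module.weight,         # logical "other"  -> the module weight
    'output': module(dummy_input),  # logical "output" -> the module output
}
```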
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py index 716ffc917..e648fff39 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py @@ -3,7 +3,7 @@ import operator from functools import reduce from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import StrategyGenerator @@ -31,8 +31,8 @@ class BatchNormStrategyGenerator(StrategyGenerator): For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]). ''' input_op_data = self.op_data['input'] - assert input_op_data.dim() in (3, 4, - 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' + assert input_op_data.data.dim() in ( + 3, 4, 5), 'The dim of the input fed into the batch norm op should be in the range of [3, 5].' def update_compute_cost(self, strategy: ShardingStrategy): ''' diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index d56b80a09..c570ac871 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -1,12 +1,17 @@ -from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape) +from .broadcast import BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape from .factory import generate_resharding_costs, generate_sharding_spec -from .misc import ignore_sharding_exception -from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size, - switch_partition_dim, update_partition_dim) +from .misc import check_sharding_spec_validity, ignore_sharding_exception +from .sharding import ( + enumerate_all_possible_1d_sharding, + enumerate_all_possible_2d_sharding, + generate_sharding_size, + switch_partition_dim, + update_partition_dim, +) __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', - 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim', - 'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding', - 'generate_sharding_size' + 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity', + 'switch_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index d174988b8..9a445869f 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,7 +1,9 @@ import functools +import torch + from colossalai.logging import get_dist_logger -from colossalai.tensor.sharding_spec import ShardingSpecException +from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException __all__ =
['ignore_sharding_exception'] @@ -29,3 +31,37 @@ def ignore_sharding_exception(func): return None return wrapper + + +def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor): + """ + This function checks whether the ShardingSpec is valid for the physical tensor. + This check includes 2 items: + 1. the sharding spec covers all dimensions of the physical tensor + 2. every sharded tensor dimension is divisible by the number of devices it is sharded over. + # + """ + # make sure all dims are covered in sharding spec + sharding_len = len(sharding_spec.sharding_sequence) + tensor_num_dim = tensor.dim() + num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] + num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + assert sharding_len == tensor_num_dim, \ + f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' + + # make sure the sharding is valid for each dim + for i in range(tensor_num_dim): + dim_size = tensor.shape[i] + dim_spec = sharding_spec.sharding_sequence[i] + + if str(dim_spec).startswith('S'): + devices_str = str(dim_spec).lstrip('S') + num_devices = 1 + + if '0' in devices_str: + num_devices *= num_devices_in_col + if '1' in devices_str: + num_devices *= num_devices_in_row + + assert dim_size >= num_devices and dim_size % num_devices == 0, \ + f'The dimension at index {i} has size {dim_size}, which cannot be evenly sharded over {num_devices} devices.' diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py deleted file mode 100644 index 695f79722..000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/common.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - -from colossalai.tensor.sharding_spec import ShardingSpec - - -def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor): - """ - This function checks whether the ShardingSpec is valid for the physical tensor. - This check includes 2 items: - 1. the sharding spec covers all dimensions of the physical tensor - 2. the sharding spec for each dimension is divisible by the number of devices. - # - """ - # make sure all dims are covered in sharding spec - sharding_len = len(sharding_spec.sharding_sequence) - tensor_num_dim = tensor.dim() - num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] - num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] - assert sharding_len == tensor_num_dim, \ - f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' - - # make sure the sharding is valid for each dim - for i in range(tensor_num_dim): - dim_size = tensor.shape[i] - dim_spec = sharding_spec.sharding_sequence[i] - - if str(dim_spec).startswith('S'): - devices_str = str(dim_spec).lstrip('S') - num_devices = 1 - - if '0' in devices_str: - num_devices *= num_devices_in_col - if '1' in devices_str: - num_devices *= num_devices_in_row - - assert dim_size >= num_devices and dim_size % num_devices == 0, \ - f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
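The relocated helper (and the test-local copy it replaces) enforces one rule per tensor dimension: a dimension tagged `S0`, `S1`, or `S01` must be at least as large as, and evenly divisible by, the product of the device-mesh axes it is sharded over, while `R` (replicated) dimensions are unconstrained. Below is a standalone restatement of that rule, assuming a 2 x 2 device mesh; the helper functions are illustrative names, not part of the library.

```python
def num_shards(dim_spec: str, mesh_shape=(2, 2)) -> int:
    """Number of devices a tensor dimension is split over for a given dim spec."""
    if not dim_spec.startswith('S'):
        return 1    # replicated ('R') dimensions are not split
    devices = 1
    if '0' in dim_spec:
        devices *= mesh_shape[0]    # sharded along mesh axis 0
    if '1' in dim_spec:
        devices *= mesh_shape[1]    # sharded along mesh axis 1
    return devices


def dim_is_valid(dim_size: int, dim_spec: str, mesh_shape=(2, 2)) -> bool:
    """Mirrors the per-dimension assertion in check_sharding_spec_validity."""
    devices = num_shards(dim_spec, mesh_shape)
    return dim_size >= devices and dim_size % devices == 0


assert dim_is_valid(8, 'S01')        # 8 elements over 2 * 2 = 4 devices -> OK
assert dim_is_valid(6, 'S0')         # 6 elements over 2 devices -> OK
assert not dim_is_valid(6, 'S01')    # 6 % 4 != 0 -> the spec is rejected
```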
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index ddce9f5eb..b2d6754a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,11 +1,10 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler) -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.fx.tracer.meta_patch.patched_module import linear def test_conv_module_handler(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index 8934571f9..d2f26e704 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,13 +1,15 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler) -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, - StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + OperationData, + OperationDataType, + ShardingStrategy, + StrategiesVector, +) from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer -from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \ - is_sharding_spec_valid def test_linear_module_handler(): @@ -92,12 +94,6 @@ def test_linear_module_handler(): bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('_0') - # make sure the sharding spec is valid - is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16)) - is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight')) - is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias')) - is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32])) - # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1] @@ -182,12 +178,6 @@ def test_linear_function_handler(): bias_sharding_spec = strategy.get_sharding_spec_by_name('bias') output_sharding_spec = strategy.get_sharding_spec_by_name('linear') - # make sure the sharding spec is valid - is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16)) - is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight')) - 
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias')) - is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32])) - # make sure the sharding matches across different operation data assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1] assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
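Since validation now runs automatically inside `NodeHandler.register_strategy`, the explicit per-test checks above can be dropped. A test that still wants a direct assertion can import the relocated helper; the sketch below mirrors the deleted calls for the function-handler test, with `strategy` and `model` assumed to come from the surrounding test body.

```python
import torch

from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity


def assert_linear_specs_valid(strategy, model):
    # mirrors the deleted is_sharding_spec_valid calls in test_linear_function_handler
    bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
    output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
    check_sharding_spec_validity(bias_sharding_spec, model.get_parameter('bias'))
    check_sharding_spec_validity(output_sharding_spec, torch.rand([2, 2, 4, 32]))
```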