diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py
index 4feeacd98..1f2281cc4 100644
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py
@@ -6,9 +6,9 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 
 from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
 from .operator_handler import OperatorHandler
@@ -82,13 +82,13 @@ class MatVecStrategyGenerator(StrategyGenerator):
 
 class MatMulStrategyGenerator(StrategyGenerator):
     """
-    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is 
+    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
     a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
 
     A matmul can be formulated as [n, p] x [p, q] = [n, q]
 
     Args:
-        is_linear (bool): whether this generator is used for nn.Linear and F.linear. 
+        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This will incur extra transformation of the dim partitioning as the weight is transposed.
     """
 
@@ -255,7 +255,7 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
     """
     Generate sharding strategies for the batched matrix multiplication.
 
-    A batched matrix multiplication can be viewed as 
+    A batched matrix multiplication can be viewed as
     [b, i, k] x [b, k, j] -> [b, i, j]
     """
 
@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
 
         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
         communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
         communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
 
         communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
         communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
         communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
         communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
             activation_memory_cost, 0)
         communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
             input_grad_memory_cost, 0)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index b4ba3b7cd..a5e3f649a 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -5,14 +5,14 @@ from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
 from .experimental import PermuteHandler, ViewHandler
-from .getatrr_handler import GetattrHandler
+from .getattr_handler import GetattrHandler
 from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .output_handler import OuputHandler
-from .placeholder_handler import PlacehodlerHandler
+from .output_handler import OutputHandler
+from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
 from .softmax_handler import SoftmaxHandler
@@ -24,7 +24,7 @@ from .where_handler import WhereHandler
 __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
-    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
+    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
     'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
similarity index 100%
rename from colossalai/auto_parallel/tensor_shard/node_handler/getatrr_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
index d2edfa83c..ed120a8c3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py
@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
 from .node_handler import NodeHandler
 from .strategy import OutputGenerator, StrategyGenerator
 
-__all__ = ['OuputHandler']
+__all__ = ['OutputHandler']
 
 
-class OuputHandler(NodeHandler):
+class OutputHandler(NodeHandler):
     """
-    A OuputHandler which deals with the sharding strategies for Output Node.
+    A OutputHandler which deals with the sharding strategies for Output Node.
     """
 
     def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
index c72a5d3bf..e4f40fc93 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py
@@ -8,12 +8,12 @@ from ..sharding_strategy import OperationData, OperationDataType, StrategiesVect
 from .node_handler import NodeHandler
 from .strategy import PlaceholderGenerator, StrategyGenerator
 
-__all__ = ['PlacehodlerHandler']
+__all__ = ['PlaceholderHandler']
 
 
-class PlacehodlerHandler(NodeHandler):
+class PlaceholderHandler(NodeHandler):
     """
-    A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
+    A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
     """
 
     def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 5c40b83f9..042b9bb4b 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -9,8 +9,8 @@ from torch.fx import Graph, Node
 
 from colossalai.auto_parallel.tensor_shard.node_handler import (
     GetattrHandler,
-    OuputHandler,
-    PlacehodlerHandler,
+    OutputHandler,
+    PlaceholderHandler,
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
@@ -93,7 +93,7 @@ class StrategiesConstructor:
                 else:
                     assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
                     placeholder_option = 'replicated'
-                placeholder_handler = PlacehodlerHandler(node,
+                placeholder_handler = PlaceholderHandler(node,
                                                          self.device_mesh,
                                                          strategies_vector,
                                                          placeholder_option=placeholder_option)
@@ -140,7 +140,7 @@ class StrategiesConstructor:
                 else:
                     assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
                     output_option = 'replicated'
-                output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
+                output_handler = OutputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
                 output_handler.register_strategy()
 
             self.remove_duplicated_strategy(strategies_vector)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
index d3af5ac6f..681e93a5f 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.getatrr_handler import GetattrHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index 3547767dc..3c35da61b 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -7,7 +7,7 @@ import torch.nn as nn
 
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
@@ -145,7 +145,7 @@ def test_getitem_from_tuple_handler():
     split_strategies_vector = StrategiesVector(split_node)
 
     # build handler
-    input_handler = PlacehodlerHandler(
+    input_handler = PlaceholderHandler(
         node=input_node,
         device_mesh=device_mesh,
         strategies_vector=input_strategies_vector,
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
index 16eb98300..26376c429 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -39,10 +39,10 @@ def test_output_handler(output_option):
     output_strategies_vector = StrategiesVector(output_node)
 
     # build handler
-    otuput_handler = OuputHandler(node=output_node,
-                                  device_mesh=device_mesh,
-                                  strategies_vector=output_strategies_vector,
-                                  output_option=output_option)
+    otuput_handler = OutputHandler(node=output_node,
+                                   device_mesh=device_mesh,
+                                   strategies_vector=output_strategies_vector,
+                                   output_option=output_option)
     otuput_handler.register_strategy(compute_resharding_cost=False)
 
     # check operation data mapping
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
index 0aafb9e0b..9bc453a27 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -36,7 +36,7 @@ def test_placeholder_handler(placeholder_option):
     placeholder_node = list(graph.nodes)[0]
     placeholder_strategies_vector = StrategiesVector(placeholder_node)
     # build handler
-    placeholder_handler = PlacehodlerHandler(node=placeholder_node,
+    placeholder_handler = PlaceholderHandler(node=placeholder_node,
                                              device_mesh=device_mesh,
                                              strategies_vector=placeholder_strategies_vector,
                                              placeholder_option=placeholder_option)