[autoparallel] add embedding handler (#2089)

* [autoparallel] add embedding handler

* fix bugs
YuliangLiu0306 2022-12-07 09:41:46 +08:00 committed by GitHub
parent 1fca5d79ea
commit 7f72eb0510
6 changed files with 844 additions and 8 deletions

View File

@ -3,6 +3,7 @@ from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
from .experimental import PermuteHandler, ViewHandler
from .getatrr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
@ -23,5 +24,6 @@ __all__ = [
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler'
]

View File

@ -0,0 +1,230 @@
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator
__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']
def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output
of the embedding operation.
Args:
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
input_name (str): the name of the OperationData object for the input.
output_name (str): the name of the OperationData object for the output.
"""
# the result will be a list of strategies
sharding_strategies = []
# get operation data
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
# map the last logical output dimension back to the corresponding physical dimension
last_logical_output_dims = len(output_op_data.logical_shape) - 1
last_physical_output_dims = output_op_data.data.dim() - 1
# get logger for debug message
logger = get_dist_logger()
# The input of the embedding operation can be multi-dimensional, but the sharding spec is only generated for the
# logical 1D non-matrix dimension. That logical non-matrix dimension can map to any dimension of the
# physical input shape, so we enumerate all possible cases.
if input_sharding_spec.dim_partition_dict:
# if input_sharding_spec.dim_partition_dict is non-empty, the generated sharding strategy
# does shard the non-matrix dimension, so we need to do the enumeration
num_input_dims = input_op_data.data.dim()
for i in range(num_input_dims):
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with the i-th dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {0: i}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Error occurred when converting the logical sharding spec to the physical one. Error details: {e}'
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
# so no enumeration is needed,
# but we still need to convert the logical shape to the physical shape
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
dim_mapping = {}
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping=dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@operator_registry.register(torch.nn.Embedding)
class EmbeddingModuleHandler(ModuleHandler):
"""
An EmbeddingModuleHandler which deals with the sharding strategies for the nn.Embedding module.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In the nn.Embedding operation, all dimensions of the input are treated as the batch dimension,
# and the sharding spec is generated based on the logical 1D tensor.
# After that, the logical sharding info is enumerated among all the physical dimensions.
# Finally, the input is transformed back to its original shape in self.post_process.
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=input_meta_data,
logical_shape=input_logical_shape)
physical_other_operand = OperationData(name="weight",
type=OperationDataType.PARAM,
data=self.named_parameters['weight'])
# As with the input, in the nn.Embedding operation all dimensions of the output are treated as
# (batch dimension, embedding dimension), and the sharding spec is generated based
# on the logical 2D tensor.
# After that, the logical sharding info of the batch dimension is enumerated among all the physical dimensions.
# Finally, the output is transformed back to its original shape in self.post_process.
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(name=str(self.node),
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=output_logical_shape)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(F.embedding)
class EmbeddingFunctionHandler(NodeHandler):
"""
An EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# In the F.embedding operation, all dimensions of the input are treated as the batch dimension,
# and the sharding spec is generated based on the logical 1D tensor.
# After that, the logical sharding info is enumerated among all the physical dimensions.
# Finally, the input is transformed back to its original shape in self.post_process.
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
physical_other_operand = OperationData(name=str(self.node.args[1]),
type=data_type,
data=self.node.args[1]._meta_data)
# As with the input, in the F.embedding operation all dimensions of the output are treated as
# (batch dimension, embedding dimension), and the sharding spec is generated based
# on the logical 2D tensor.
# After that, the logical sharding info of the batch dimension is enumerated among all the physical dimensions.
# Finally, the output is transformed back to its original shape in self.post_process.
output_meta_data = self.node._meta_data
output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
physical_output = OperationData(
name=str(self.node),
type=OperationDataType.OUTPUT,
data=self.node._meta_data,
logical_shape=output_logical_shape,
)
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
return mapping
def post_process(self, strategy: ShardingStrategy):
"""
Convert the sharding spec from the logical shape to the physical shape.
"""
# create multiple sharding strategies for the inputs
# as the input can be multi-dimensional and the partition dim is only 2D,
# we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
input_name=str(
self.node.args[0]),
output_name=str(self.node))
return strategies
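
# A minimal standalone sketch (plain PyTorch, hypothetical shapes) of the logical-to-physical
# conversion that post_process relies on: the N-D index tensor is flattened to a logical 1-D tensor,
# and a shard on logical dim 0 can land on any single physical dim of the input, which is why the
# helper above enumerates range(num_input_dims) and emits one physical strategy per candidate dim.
if __name__ == '__main__':
    import torch

    indices = torch.randint(0, 16, (4, 16, 16))    # physical shape of the index tensor
    weight = torch.randn(16, 32)                    # embedding table (num_embeddings, embedding_dim)

    print(indices.view(-1).shape)                   # logical input shape: torch.Size([1024])
    print((indices.numel(), weight.shape[-1]))      # logical output shape: (1024, 32)

    # the logical strategy 'S0R = S0 x RR' expands into one physical strategy per input dimension,
    # matching the '_0', '_1', '_2' suffixes checked in the unit test later in this commit
    for i in range(indices.dim()):
        print(f"S0R = S0 x RR_{i}: shard physical input dim {i} (size {indices.shape[i]})")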

View File

@ -1,6 +1,7 @@
from .batch_norm_generator import BatchNormStrategyGenerator
from .binary_elementwise_generator import BinaryElementwiseStrategyGenerator
from .conv_strategy_generator import ConvStrategyGenerator
from .embedding_generator import EmbeddingStrategyGenerator
from .getattr_generator import GetattrGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator
@ -25,5 +26,5 @@ __all__ = [
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
'TensorConstructorGenerator', 'EmbeddingStrategyGenerator'
]

View File

@ -0,0 +1,310 @@
import copy
import operator
import warnings
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
class EmbeddingStrategyGenerator(StrategyGenerator):
"""
EmbeddingStrategyGenerator is a generic class to generate strategies for nn.Embedding or F.embedding.
The operation data is defined as `output = input x other`.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
Note: the computation cost for the embedding handler is currently estimated as a dense computation,
so it may not be accurate.
'''
# TODO: estimate the embedding computation cost as a sparse operation
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
other_size_product = reduce(operator.mul, sharded_other_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
forward_compute_cost = input_size_product * other_size_product
backward_activation_cost = other_size_product * output_size_product / sharded_output_shape[-1]
backward_weight_cost = input_size_product * other_size_product
backward_compute_cost = backward_weight_cost + backward_activation_cost
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def non_split(self):
name = f'RR = R x RR'
dim_partition_dict_mapping = {
"input": {},
"other": {},
"output": {},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
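# Note on the dim_partition_dict mappings used throughout this generator: each key is a dimension of
# the (logical) tensor and each value is the list of device-mesh axes that dimension is sharded over.
# For example, {0: [mesh_dim_0]} shards dim 0 over one mesh axis (the 'S{mesh_dim_0}' notation),
# {0: [mesh_dim_0, mesh_dim_1]} shards dim 0 over both axes at once (the 'S01' notation), and an
# empty dict means the tensor is fully replicated (the 'R' notation).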
@ignore_sharding_exception
def split_input(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0} x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0]
},
"other": {},
"output": {
0: [mesh_dim_0],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
communication_action_mapping = {}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_and_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0} x RS{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0],
},
"other": {
1: [mesh_dim_1],
},
"output": {
0: [mesh_dim_0],
1: [mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_1d_parallel_on_input(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1} x RR'
dim_partition_dict_mapping = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"output": {
0: [mesh_dim_0, mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
communication_action_mapping = {}
if self.is_param("other"):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_embedding_dim(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = R x RS{mesh_dim_0}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0],
},
"output": {
1: [mesh_dim_0],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_1d_parallel_on_embedding_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = R x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
"input": {},
"other": {
1: [mesh_dim_0, mesh_dim_1],
},
"output": {
1: [mesh_dim_0, mesh_dim_1],
},
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping = {"input": input_comm_action}
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# RR= R x RR
strategies.append(self.non_split())
# SR = S x RR
strategies.append(self.split_input(0))
strategies.append(self.split_input(1))
# SS = S x RS
strategies.append(self.split_input_and_embedding_dim(0, 1))
strategies.append(self.split_input_and_embedding_dim(1, 0))
# S01R = S01 x RR
strategies.append(self.split_1d_parallel_on_input(0, 1))
# RS = R x RS
strategies.append(self.split_embedding_dim(0))
strategies.append(self.split_embedding_dim(1))
# RS01 = R x RS01
strategies.append(self.split_1d_parallel_on_embedding_dim(0, 1))
return strategies
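
# A rough standalone illustration (hypothetical per-device shapes, plain Python) of the dense
# compute-cost estimate performed in update_compute_cost above, e.g. for the S0S1 = S0 x RS1 strategy
# applied to a (4, 16, 16) index tensor and a (16, 32) embedding table on a 2 x 2 device mesh:
if __name__ == '__main__':
    import operator
    from functools import reduce

    sharded_input_shape = (2, 16, 16)        # input split on dim 0 across 2 devices
    sharded_other_shape = (16, 16)           # embedding table split along the embedding dim
    sharded_output_shape = (2, 16, 16, 16)   # output inherits both splits

    input_size_product = reduce(operator.mul, sharded_input_shape)
    other_size_product = reduce(operator.mul, sharded_other_shape)
    output_size_product = reduce(operator.mul, sharded_output_shape)

    # same dense estimate as the generator: forward term plus the two backward terms
    fwd_cost = input_size_product * other_size_product
    bwd_cost = input_size_product * other_size_product \
        + other_size_product * output_size_product / sharded_output_shape[-1]
    print(f'estimated compute cost per device: fwd={fwd_cost}, bwd={bwd_cost}')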

View File

@ -0,0 +1,286 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
EmbeddingFunctionHandler,
EmbeddingModuleHandler,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
EMBEDDING_DIMS = 32
class EmbeddingModule(nn.Module):
def __init__(self, num_embeddings, embedding_dims):
super().__init__()
self.embedding = nn.Embedding(num_embeddings, embedding_dims)
def forward(self, input):
x = self.embedding(input)
return x
def check_embedding_module_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = EmbeddingModule(num_embeddings=NUM_EMBEDDINGS, embedding_dims=EMBEDDING_DIMS).cuda()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %embedding : [#users=1] = call_module[target=embedding](args = (%input_1,), kwargs = {})
# return embedding
input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS
input = input.to(torch.int64).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# index of embedding node in computation graph
node_index = 1
# total number of embedding strategies
strategy_number = 19
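# 19 comes from expanding the 9 logical strategies in EmbeddingStrategyGenerator over the physical
# input dimensions: the 5 strategies that shard the input's logical dim 0 (2 x SR, 2 x SS, 1 x S01R)
# yield 3 physical variants each for this 3-D input, and the remaining 4 (RR, 2 x RS, RS01) yield
# one variant each, i.e. 5 * 3 + 4 = 19.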
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')})
gm = ColoGraphModule(model, graph)
embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node)
# build handler
handler = EmbeddingModuleHandler(node=embedding_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
# assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([1024])
assert mapping['other'].name == "weight"
assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['other'].type == OperationDataType.PARAM
assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['output'].name == "embedding"
assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS])
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS])
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# RR = RR x RR
assert 'RR = R x RR' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0 x RR_0' in strategy_name_list
assert 'S0R = S0 x RR_1' in strategy_name_list
assert 'S0R = S0 x RR_2' in strategy_name_list
assert 'S1R = S1 x RR_0' in strategy_name_list
assert 'S1R = S1 x RR_1' in strategy_name_list
assert 'S1R = S1 x RR_2' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0 x RS1_0' in strategy_name_list
assert 'S0S1 = S0 x RS1_1' in strategy_name_list
assert 'S0S1 = S0 x RS1_2' in strategy_name_list
assert 'S1S0 = S1 x RS0_0' in strategy_name_list
assert 'S1S0 = S1 x RS0_1' in strategy_name_list
assert 'S1S0 = S1 x RS0_2' in strategy_name_list
# RS= RR x RS
assert 'RS0 = R x RS0' in strategy_name_list
assert 'RS1 = R x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01 x RR_0' in strategy_name_list
assert 'S01R = S01 x RR_1' in strategy_name_list
assert 'S01R = S01 x RR_2' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = R x RS01' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
output_sharding_spec = strategy.get_sharding_spec_by_name('embedding')
# make sure the sharding matches across different operation data
assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1]
class EmbeddingFunction(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, others):
x = nn.functional.embedding(input, others)
return x
def check_embedding_function_handler(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = EmbeddingFunction().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(4, 16, 16) * NUM_EMBEDDINGS
input = input.to(torch.int64).cuda()
others = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).cuda()
input_args = [input, others]
meta_arg_names = ['input', 'others']
input_kwargs = {}
# total number of embedding strategies
strategy_number = 19
node_index = 2
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs)
tracer = ColoTracer()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
# %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
# return embedding
meta_args = {
"input": torch.rand(4, 16, 16).to('meta'),
"others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node)
# build handler
handler = EmbeddingFunctionHandler(node=embedding_node,
device_mesh=device_mesh,
strategies_vector=strategies_vector)
# check operation data mapping
mapping = handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.logical_shape is not None
assert op_data.data is not None
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([1024])
assert mapping['other'].name == "others"
assert mapping['other'].data.is_meta
assert mapping['other'].data.shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([NUM_EMBEDDINGS, EMBEDDING_DIMS])
assert mapping['output'].name == "embedding"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS])
assert mapping['output'].type == OperationDataType.OUTPUT
assert mapping['output'].logical_shape == torch.Size([1024, EMBEDDING_DIMS])
handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# RR = RR x RR
assert 'RR = R x RR' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0 x RR_0' in strategy_name_list
assert 'S0R = S0 x RR_1' in strategy_name_list
assert 'S0R = S0 x RR_2' in strategy_name_list
assert 'S1R = S1 x RR_0' in strategy_name_list
assert 'S1R = S1 x RR_1' in strategy_name_list
assert 'S1R = S1 x RR_2' in strategy_name_list
# SS = SR x RS
assert 'S0S1 = S0 x RS1_0' in strategy_name_list
assert 'S0S1 = S0 x RS1_1' in strategy_name_list
assert 'S0S1 = S0 x RS1_2' in strategy_name_list
assert 'S1S0 = S1 x RS0_0' in strategy_name_list
assert 'S1S0 = S1 x RS0_1' in strategy_name_list
assert 'S1S0 = S1 x RS0_2' in strategy_name_list
# RS= RR x RS
assert 'RS0 = R x RS0' in strategy_name_list
assert 'RS1 = R x RS1' in strategy_name_list
# S01R = S01R x RR
assert 'S01R = S01 x RR_0' in strategy_name_list
assert 'S01R = S01 x RR_1' in strategy_name_list
assert 'S01R = S01 x RR_2' in strategy_name_list
# RS01 = RR x RS01
assert 'RS01 = R x RS01' in strategy_name_list
for strategy in strategies_vector:
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
output_sharding_spec = strategy.get_sharding_spec_by_name('embedding')
# make sure the sharding matches across different operation data
assert output_sharding_spec.sharding_sequence[-1] == weight_sharding_spec.sharding_sequence[-1]
assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence[:-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
world_size = 4
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
world_size = 4
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_embedding_module_handler()
test_embedding_function_handler()
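
# A quick single-process sanity check (hypothetical helper, not used by the tests above) of the shape
# relationships the assertions rely on: the (4, 16, 16) index tensor flattens to the logical shape
# (1024,), and the embedding output of shape (4, 16, 16, 32) flattens to the logical shape (1024, 32).
# Call _shape_sanity_check() manually to run it; no distributed launch is required.
def _shape_sanity_check():
    inp = torch.randint(0, NUM_EMBEDDINGS, (4, 16, 16))
    weight = torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS)
    out = nn.functional.embedding(inp, weight)
    assert inp.view(-1).shape == torch.Size([1024])
    assert out.shape == torch.Size([4, 16, 16, EMBEDDING_DIMS])
    assert out.view(-1, out.shape[-1]).shape == torch.Size([1024, EMBEDDING_DIMS])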

View File

@ -13,7 +13,7 @@ from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close
def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
@ -32,8 +32,12 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
param.register_hook(hook_fn)
arg_to_compare = copy.deepcopy(input_tensor)
# only Tensors of floating point and complex dtype can require gradients
if arg_to_compare.dtype != torch.int64:
arg_to_compare.requires_grad = True
wrapper(arg_to_compare, arg_index)
args_to_compare.append(arg_to_compare)
for name, input_kwarg in input_kwargs.items():
@ -46,8 +50,12 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
param.register_hook(hook_fn)
kwarg_to_compare = copy.deepcopy(input_kwarg)
# only Tensors of floating point and complex dtype can require gradients
if kwarg_to_compare.dtype != torch.int64:
kwarg_to_compare.requires_grad = True
wrapper(kwarg_to_compare, name)
kwargs_to_compare[name] = kwarg_to_compare
return model_to_compare, args_to_compare, kwargs_to_compare
@ -160,7 +168,6 @@ def assert_close_helper(first: torch.Tensor,
""" """
This method is used to check whether the average difference between two tensors is as close as expected. This method is used to check whether the average difference between two tensors is as close as expected.
""" """
# average_diff_tensor = ((first - second)/(second+0.1)).sum()/second.numel()
try:
if isinstance(first, (tuple, list)):
for first_element, second_element in zip(first, second):