Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-29 08:47:17 +00:00)
[autoparallel] support more flexible data type (#1967)

parent 5bec3b2168
commit 05020e50d0
@@ -4,6 +4,7 @@ from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .getatrr_handler import GetattrHandler
+from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
@@ -19,5 +20,6 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
+    'GetItemHandler', 'GetattrHandler'
 ]
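The newly exported GetItemHandler goes together with the torch.Tensor.split registration further down in this diff: split returns a tuple of tensors, and indexing that tuple shows up as a getitem node in the traced graph. As a rough illustration (toy module and names of my own, not part of this commit), this is the kind of node sequence those handlers have to cover:

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    class SplitModel(nn.Module):
        # hypothetical toy model, only used to show the traced node sequence
        def forward(self, x):
            pieces = x.split(2, dim=0)    # tuple of tensors
            return pieces[0]              # indexing the tuple becomes a getitem node

    traced = symbolic_trace(SplitModel())
    for node in traced.graph.nodes:
        print(node.op, node.target)
    # node sequence: placeholder -> call_method 'split' -> call_function getitem -> output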
@@ -51,6 +51,10 @@ class NodeHandler(ABC):
         for node in self.predecessor_node:
             node_name = str(node)
             # get the current sharding spec generated by this node handler
+
+            # TODO: we need to check this in future
+            if not isinstance(node._meta_data, torch.Tensor):
+                continue
             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]
 
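Some context for the guard added above: once operands may be tuples of tensors (e.g. the output of split), a predecessor's _meta_data is no longer guaranteed to be a single tensor, and the resharding-cost bookkeeping below only knows how to handle tensors, so such nodes are skipped for now (hence the TODO). A minimal sketch of the same skip pattern, with hypothetical names, assuming each node carries a _meta_data attribute as in the traced graph here:

    import torch

    def tensor_predecessors(predecessor_nodes):
        # yield only predecessors whose meta data is a plain tensor;
        # tuple-valued outputs (e.g. from split) are skipped for now
        for node in predecessor_nodes:
            if not isinstance(getattr(node, '_meta_data', None), torch.Tensor):
                continue
            yield node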
@@ -11,7 +11,9 @@ __all__ = ['ReshapeHandler']
 
 
 @operator_registry.register(torch.reshape)
+@operator_registry.register(torch.Tensor.split)
 @operator_registry.register(torch.flatten)
+@operator_registry.register(torch.Tensor.transpose)
 @operator_registry.register(torch.Tensor.permute)
 @operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
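torch.Tensor.split and torch.Tensor.transpose are added to the operators routed to ReshapeHandler; both only rearrange or partition the input without computing new element values, which is presumably why they can reuse the reshape strategies. A quick shape check, purely illustrative:

    import torch

    x = torch.randn(4, 8)
    print(x.transpose(0, 1).shape)        # torch.Size([8, 4])
    print([t.shape for t in x.split(2)])  # [torch.Size([2, 8]), torch.Size([2, 8])]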
@@ -26,6 +28,24 @@ class ReshapeHandler(NodeHandler):
         generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
         return generators
 
+    def infer_logical_shape(self, data):
+        """
+        Infer the logical shape of an operand.
+
+        Note: this is only needed for operands whose data is not a single tensor,
+        e.g. a tuple of tensors.
+        """
+        if isinstance(data, torch.Tensor):
+            return data.shape
+        else:
+            assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
+            logical_shape = []
+            for tensor in data:
+                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
+                logical_shape.append(tensor.shape)
+            logical_shape = tuple(logical_shape)
+            return logical_shape
+
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
@@ -36,10 +56,19 @@ class ReshapeHandler(NodeHandler):
         else:
             data_type = OperationDataType.ARG
+
+        input_data = self.node.args[0]._meta_data
+        input_logical_shape = self.infer_logical_shape(input_data)
         physical_input_operand = OperationData(name=str(self.node.args[0]),
                                                type=data_type,
-                                               data=self.node.args[0]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+                                               data=input_data,
+                                               logical_shape=input_logical_shape)
+
+        output_data = self.node._meta_data
+        output_logical_shape = self.infer_logical_shape(output_data)
+        physical_output = OperationData(name=str(self.node),
+                                        type=OperationDataType.OUTPUT,
+                                        data=output_data,
+                                        logical_shape=output_logical_shape)
+
         mapping = {"input": physical_input_operand, "output": physical_output}
 
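The new infer_logical_shape method and the logical_shape arguments above let an OperationData describe a tuple-of-tensors operand by a tuple of shapes instead of a single shape. The same logic restated as a self-contained sketch (outside the handler class; behaviour matches the method added above):

    import torch

    def infer_logical_shape(data):
        # a tensor maps to its own shape; a tuple of tensors maps to a tuple of shapes
        if isinstance(data, torch.Tensor):
            return data.shape
        assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
        return tuple(t.shape for t in data)

    x = torch.randn(4, 6)
    print(infer_logical_shape(x))           # torch.Size([4, 6])
    print(infer_logical_shape(x.split(2)))  # (torch.Size([2, 6]), torch.Size([2, 6]))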
@@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
             for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                 dim_size = len(logical_shape)
                 dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
                                              entire_shape=logical_shape,
                                              dim_partition_dict=dim_partition_dict_element)
+                sharding_spec.append(sharding_spec_element)
         else:
             assert isinstance(
                 op_data.data, torch.Tensor
@@ -193,18 +194,40 @@ class StrategyGenerator(ABC):
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.
 
         """
         op_data = self.op_data[key]
-        sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
-        if len(sharded_shape) == 0:
-            num_elements = 1
+
+        def _compute_size_in_bytes_helper(sharding_spec, meta_data):
+            sharded_shape = sharding_spec.get_sharded_shape_per_device()
+            if len(sharded_shape) == 0:
+                num_elements = 1
+            else:
+                num_elements = reduce(operator.mul, sharded_shape)
+            dtype = getattr(meta_data, 'dtype')
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            return num_elements * size_per_elem_bytes
+
+        if isinstance(op_data.data, tuple):
+            assert isinstance(strategy.sharding_specs[op_data], list), \
+                'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+            total_bytes = 0
+            for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
+                meta_data = op_data.data[index]
+                if isinstance(meta_data, torch.Tensor):
+                    element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
+                else:
+                    # if meta_data is not a tensor, we count the memory as 0
+                    element_bytes = 0
+                total_bytes += element_bytes
+
         else:
-            num_elements = reduce(operator.mul, sharded_shape)
-        dtype = self.op_data[key].data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return num_elements * size_per_elem_bytes
+            if isinstance(op_data.data, torch.Tensor):
+                total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
+            else:
+                # if op_data.data is not a tensor, we count the memory as 0
+                total_bytes = 0
+
+        return total_bytes
 
     def generate(self) -> List[ShardingStrategy]:
         """
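This memory-cost change is the heart of the "more flexible data type" support: the per-spec byte count is factored into a helper and then applied once for a tensor operand, or summed over the elements of a tuple operand, with non-tensor elements counted as 0 bytes. The helper's arithmetic restated as a self-contained sketch (the real helper reads the sharded shape from a ShardingSpec; here the shape is passed in directly):

    import operator
    from functools import reduce

    import torch

    def size_in_bytes(sharded_shape, dtype):
        # elements per device (1 for a zero-dim shape) times the element size of the dtype
        num_elements = 1 if len(sharded_shape) == 0 else reduce(operator.mul, sharded_shape)
        return num_elements * torch.tensor([], dtype=dtype).element_size()

    print(size_in_bytes((2, 6), torch.float32))  # 48
    print(size_in_bytes((), torch.float16))      # 2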
@@ -10,6 +10,8 @@ from .strategy import StrategyGenerator, UnaryElementwiseGenerator
 __all__ = ['UnaryElementwiseHandler']
 
 
+@operator_registry.register(torch.Tensor.to)
+@operator_registry.register(torch.Tensor.type)
 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
 class UnaryElementwiseHandler(NodeHandler):
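torch.Tensor.to and torch.Tensor.type are dtype/device casts: the output has the same shape as the input, so they fit the unary-elementwise strategies. A small illustrative trace (toy module of my own, not from the commit):

    import torch
    from torch.fx import symbolic_trace

    class CastModel(torch.nn.Module):
        def forward(self, x):
            # shape-preserving and elementwise: cast to fp16, then take the absolute value
            return torch.abs(x.to(torch.float16))

    traced = symbolic_trace(CastModel())
    print(traced.graph)  # shows call_method 'to' followed by call_function torch.abs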