Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-17 23:18:36 +00:00)
[autoparallel] add numerical test for node strategies (#1760)
* [autoparallel] add numerical test for node strategies
* polish code
* polish code
@@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
    """
    origin_sharding_spec = origin_dict[node_index]
    target_sharding_spec = input_dict[node_index][user_node_index]

    return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)

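runtime_apply looks up the sharding spec the producer's output currently carries and the spec a particular consumer expects, then converts the runtime tensor between the two. A hedged illustration of the lookup layout it assumes (the indices and variable names below are placeholders, not values from this commit):

# origin_dict[node_index]                  -> ShardingSpec the node's output currently carries
# input_dict[node_index][user_node_index]  -> ShardingSpec that this particular consumer expects
converted_tensor = runtime_apply(tensor, origin_dict, input_dict, node_index=3, user_node_index=0)
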
@@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
        if not hasattr(node, 'best_strategy') or node.op == 'output':
            continue

        for user_node in node.strategies_vector.successor_nodes:
            user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
        for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function',
                                                               runtime_apply,
                                                               args=(node, origin_dict_node, input_dict_node,
                                                                     node_to_index_dict[node], user_node_index))

            origin_index_args = user_node.args.index(node)
            new_args = list(user_node.args)
            new_args[origin_index_args] = shape_consistency_node
            user_node.args = new_args
            new_kwargs = dict(user_node.kwargs)
            # the origin node may be a positional argument or key word argument of user node
            if node in new_args:
                # substitute the origin node with shape_consistency_node
                origin_index_args = new_args.index(node)
                new_args[origin_index_args] = shape_consistency_node
                user_node.args = new_args
            elif str(node) in new_kwargs:
                # substitute the origin node with shape_consistency_node
                new_kwargs[str(node)] = shape_consistency_node
                user_node.kwargs = new_kwargs

    return gm

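Because a producer may reach its consumer either positionally or as a keyword argument, the pass now checks both `args` and `kwargs` before substituting the freshly inserted node. A minimal, self-contained torch.fx sketch of that substitution idiom on a toy graph (an illustration of the pattern only, not the autoparallel pass itself):

import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):

    def forward(self, x):
        return torch.add(x, x)


gm = symbolic_trace(Toy())
graph = gm.graph
for node in list(graph.nodes):
    if node.op != 'call_function':
        continue
    for producer in node.all_input_nodes:
        # insert a stand-in node (here just a clone) in front of the consumer
        with graph.inserting_before(node):
            stand_in = graph.create_node('call_function', torch.clone, args=(producer,))
        new_args = list(node.args)
        new_kwargs = dict(node.kwargs)
        # the producer may appear positionally or as a keyword argument
        if producer in new_args:
            new_args[new_args.index(producer)] = stand_in
            node.args = tuple(new_args)
        elif str(producer) in new_kwargs:
            new_kwargs[str(producer)] = stand_in
            node.kwargs = new_kwargs
graph.lint()
gm.recompile()
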
@@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):

        comm_actions = node.best_strategy.communication_actions
        for op_data, comm_action in comm_actions.items():
            comm_object = node.args[comm_action.arg_index]

            if op_data.type == OperationDataType.PARAM:
                continue
            if comm_action.comm_type == CommType.BEFORE:
                if comm_action.key_for_kwarg is not None:
                    comm_object = node.kwargs[comm_action.key_for_kwarg]
                else:
                    comm_object = node.args[comm_action.arg_index]
                with mod_graph.inserting_before(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                 runtime_comm_spec_apply,
                                                                 args=(comm_object, comm_actions_dict_node,
                                                                       node_to_index_dict[node], op_data.name))
                new_args = list(node.args)
                new_args[comm_action.arg_index] = comm_spec_apply_node
                node.args = new_args
                # the origin node may be a positional argument or key word argument of user node
                if comm_action.key_for_kwarg is not None:
                    # substitute the origin node with comm_spec_apply_node
                    new_kwargs = dict(node.kwargs)
                    new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
                    node.kwargs = new_kwargs
                else:
                    # substitute the origin node with comm_spec_apply_node
                    new_args = list(node.args)
                    new_args[comm_action.arg_index] = comm_spec_apply_node
                    node.args = new_args

            elif comm_action.comm_type == CommType.AFTER:
                with mod_graph.inserting_after(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
@@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                    if user == comm_spec_apply_node:
                        continue
                    new_args = list(user.args)
                    new_args[new_args.index(node)] = comm_spec_apply_node
                    user.args = tuple(new_args)
                    new_kwargs = dict(user.kwargs)
                    # the origin node may be a positional argument or key word argument of user node
                    if node in new_args:
                        # substitute the origin node with comm_spec_apply_node
                        new_args[new_args.index(node)] = comm_spec_apply_node
                        user.args = tuple(new_args)
                    elif str(node) in new_kwargs:
                        # substitute the origin node with comm_spec_apply_node
                        new_kwargs[str(node)] = comm_spec_apply_node
                        user.kwargs = new_kwargs

    return gm

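For `CommType.AFTER`, the communication node is inserted after the producing op and every existing consumer is rewired to read from it, skipping the communication node itself so no cycle is created. This hand-rolled rewiring is conceptually close to torch.fx's `Node.replace_all_uses_with`; a hedged sketch of that alternative, assuming a torch.fx version that accepts the user filter callback:

# sketch: redirect all current consumers of `node` to the freshly inserted comm node,
# except the comm node itself (which must keep reading from `node`)
node.replace_all_uses_with(comm_spec_apply_node,
                           delete_user_cb=lambda user: user is not comm_spec_apply_node)
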
@@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
            if target_sharding_spec.dim_partition_dict != {}:
                origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                setattr(param, 'sharding_spec', origin_sharding_spec)
                # TODO: build a ColoParamter class to manager the distributed parameters
                param_sharded = torch.nn.Parameter(
                    shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                             target_sharding_spec).detach().clone())

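The parameter path treats every weight as fully replicated (an empty `dim_partition_dict`) and lets the shape consistency manager materialize the local shard the solver asked for. A hedged sketch of that conversion for a single weight (names such as `weight` and `device_mesh` are assumed to be in scope):

# replicated view of the weight -> local shard split along dim 0 over mesh axis 0
origin_spec = ShardingSpec(device_mesh, weight.shape, {})
target_spec = ShardingSpec(device_mesh, weight.shape, {0: [0]})
local_shard = shape_consistency_manager.apply_for_autoparallel_runtime(weight.data, origin_spec, target_spec)
sharded_weight = torch.nn.Parameter(local_shard.detach().clone())
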
@@ -4,7 +4,6 @@ import warnings
from functools import reduce
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
@@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    ShardingStrategy,
    TrainCycleItem,
)

from colossalai.auto_parallel.tensor_shard.utils import \
    ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern

from .strategy_generator import StrategyGenerator

@@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator):
                sharding_spec=sharding_spec_mapping["input"],
                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                logical_process_axis=mesh_dim_1,
                comm_type=CommType.BEFORE)
                comm_type=CommType.BEFORE,
                arg_index=0)
            communication_action_mapping = {"input": input_comm_action}

        if self.is_param("other"):
@@ -223,8 +220,7 @@ class ConvStrategyGenerator(StrategyGenerator):
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1,
            comm_type=CommType.AFTER,
            arg_index=0)
            comm_type=CommType.AFTER)

        communication_action_mapping = {"output": output_comm_action}

@@ -277,8 +273,7 @@ class ConvStrategyGenerator(StrategyGenerator):
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.AFTER,
            arg_index=0)
            comm_type=CommType.AFTER)
        input_comm_action = self.get_communication_action(
            sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -316,8 +311,7 @@ class ConvStrategyGenerator(StrategyGenerator):
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.AFTER,
            arg_index=0)
            comm_type=CommType.AFTER)

        communication_action_mapping = {"output": output_comm_action}

@@ -351,7 +345,8 @@ class ConvStrategyGenerator(StrategyGenerator):
            sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0,
            comm_type=CommType.BEFORE)
            comm_type=CommType.BEFORE,
            arg_index=0)

        communication_action_mapping = {"input": input_comm_action}

@@ -441,8 +436,7 @@ class ConvStrategyGenerator(StrategyGenerator):
            sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1],
            comm_type=CommType.AFTER,
            arg_index=0)
            comm_type=CommType.AFTER)

        communication_action_mapping = {"output": output_comm_action}

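The pattern across these hunks is consistent: the change removes `arg_index` from the `CommType.AFTER` actions and adds it to the `CommType.BEFORE` actions, since a BEFORE action wraps one of the op's inputs (so the pass needs to know which argument), while an AFTER action wraps the op's output and needs no argument index. A condensed illustration of the two call shapes, with argument names mirroring the surrounding class:

# BEFORE: communication happens on input argument 0 of the op
input_comm_action = self.get_communication_action(sharding_spec_mapping["input"],
                                                  communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                                                  logical_process_axis=mesh_dim_0,
                                                  comm_type=CommType.BEFORE,
                                                  arg_index=0)

# AFTER: communication happens on the op's output, so no arg_index is required
output_comm_action = self.get_communication_action(sharding_spec_mapping["output"],
                                                   communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
                                                   logical_process_axis=mesh_dim_0,
                                                   comm_type=CommType.AFTER)
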
@@ -109,7 +109,8 @@ class StrategyGenerator(ABC):
                                 communication_pattern: CollectiveCommPattern,
                                 logical_process_axis: Union[int, List[int]],
                                 comm_type: CommType,
                                 arg_index: int = -1) -> CommAction:
                                 arg_index: int = -1,
                                 key_for_kwarg: any = None) -> CommAction:
        """
        A factory method to produce a CommAction object.
        """
@@ -117,7 +118,8 @@ class StrategyGenerator(ABC):
                                             communication_pattern=communication_pattern,
                                             logical_process_axis=logical_process_axis),
                          comm_type=comm_type,
                          arg_index=arg_index)
                          arg_index=arg_index,
                          key_for_kwarg=key_for_kwarg)

    def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
        """
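With `key_for_kwarg`, a strategy can now tag an operand that the op receives as a keyword argument (for example a `bias=` input) instead of by position. A hedged example call, assuming a `bias` entry exists in the sharding spec mapping:

bias_comm_action = self.get_communication_action(sharding_spec_mapping["bias"],
                                                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                                                 logical_process_axis=mesh_dim_0,
                                                 comm_type=CommType.BEFORE,
                                                 key_for_kwarg='bias')
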
@@ -115,6 +115,7 @@ class CommAction:
    comm_spec: CommSpec = None
    comm_type: CommType = None
    arg_index: int = -1
    key_for_kwarg: any = None


@dataclass

@@ -1,5 +1,6 @@
from functools import reduce
import operator
from functools import reduce

import torch
import torch.distributed as dist

@@ -11,7 +12,7 @@ class DeviceMesh:
    can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
    own latency and bandwidth. We use alpha-beta model to model the
    communication cost.

    Arguments:
        physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
        mesh_shape (torch.Size): shape of logical view.

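As the docstring says, the same 16 physical devices can be viewed as a 1x16 or a 4x4 logical mesh. A hedged construction example following that description (the `init_process_group` keyword mirrors the note in `create_process_groups_for_logical_mesh` further below):

import torch

# 16 physical devices arranged as a 4x4 logical mesh
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
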
@@ -64,6 +65,18 @@ class DeviceMesh:
    def logical_mesh_id(self):
        return self._logical_mesh_id

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k != 'process_groups_dict':
                setattr(result, k, __import__("copy").deepcopy(v, memo))
            else:
                setattr(result, k, v)

        return result

    def flatten(self):
        """
        Flatten the logical mesh into an effective 1d logical mesh,
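The new `__deepcopy__` copies every attribute except `process_groups_dict`: process groups are live communicator handles, so the copy shares them by reference instead of trying to duplicate them. A hedged usage note, assuming the mesh was built with `init_process_group=True`:

import copy

mesh_copy = copy.deepcopy(device_mesh)
# tensors and metadata are duplicated, but the live process groups are shared
assert mesh_copy.process_groups_dict is device_mesh.process_groups_dict
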
@@ -90,7 +103,7 @@ class DeviceMesh:
    def create_process_groups_for_logical_mesh(self):
        '''
        This method is used to initialize the logical process groups which will be used in communications
        among logical device mesh.
        among logical device mesh.
        Note: if init_process_group set to False, you have to call this method manually. Otherwise,
        the communication related function, such as ShapeConsistencyManager.apply will raise errors.
        '''

@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
    pass


def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
    shape_consistency_manager = ShapeConsistencyManager()
    global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
    with torch.no_grad():
        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
                                                                                 global_sharding_spec)
    return global_tensor


def set_shape_consistency_options(options: ShapeConsistencyOptions):
    """
    Configure the shape consistency manager via function call.
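`to_global` converts a sharded tensor back to its fully replicated form by targeting an empty `dim_partition_dict`, which is what makes an element-wise comparison against a single-device run possible. A hedged sketch of the numerical check this enables (variable names are illustrative):

from torch.testing import assert_close

# gather the local shard back to the global view and compare with the unsharded reference
global_output = to_global(sharded_output, output_sharding_spec)
assert_close(global_output, reference_output, rtol=1e-5, atol=1e-5)
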
@@ -6,7 +6,6 @@ from functools import reduce
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']

@@ -23,7 +22,7 @@ class _DimSpec:
    This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
            Otherwise, the element in shard_list means the data will be sharded in that dimension.
    '''

@@ -62,7 +61,7 @@ class _DimSpec:

    def build_difference_2d_dict(self):
        '''
        Build a difference maping for 2D device mesh case. It will be used to
        Build a difference maping for 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.
        '''

@@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
class ShardingSpec:
    '''
    Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
    to, the entire shape of the tensor before sharded, and the sharding sequence looks like
    to, the entire shape of the tensor before sharded, and the sharding sequence looks like
    [R, R, S0, S1].

    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        entire_shape(torch.Size): The entire shape of tensor before sharded.
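A ShardingSpec ties together a device mesh, the tensor's global shape, and a per-dimension partition; `dim_partition_dict` maps tensor dimensions to the mesh axes they are split over. A hedged example that yields the `[R, R, S0, S1]` sequence mentioned above (`device_mesh` is assumed to be a 2D mesh already in scope):

# 4-D tensor on a 4x4 mesh: dim 2 sharded over mesh axis 0, dim 3 over mesh axis 1
entire_shape = torch.Size((64, 64, 64, 64))
dim_partition_dict = {2: [0], 3: [1]}
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
print(sharding_spec.sharding_sequence)    # expected to read as [R, R, S0, S1]
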
@@ -260,10 +259,10 @@ class ShardingSpec:
        # device_mesh_shape: (4, 4)
        sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
        print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compared with.