Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 17:17:05 +00:00
[autoparallel] gpt2lp runtime test (#2113)
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationDataType,
     ShardingStrategy,
 )
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()
 
 
-def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
+def _solution_annotatation(gm: torch.fx.GraphModule,
+                           solution: List[int],
+                           strategies_constructor: StrategiesConstructor = None):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required at runtime into the graph as placeholder nodes.
     """
     mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
+    # TODO: in a future PR, strategies_constructor should be a required argument
+    # instead of an optional one, because nodes with no strategy do not need to be
+    # considered in the runtime preparation pass.
+    if strategies_constructor is not None:
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        no_strategy_nodes = strategies_constructor.no_strategy_nodes
+    else:
+        nodes = tuple(mod_graph.nodes)
+        no_strategy_nodes = []
 
     # the dict to get the origin sharding spec of each node
     origin_node_sharding_spec_dict = {}
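The new optional strategies_constructor argument changes which nodes the pass walks. Below is a minimal, self-contained sketch of that branch, using hypothetical stand-ins (plain strings instead of torch.fx nodes, dataclasses instead of the real StrategiesConstructor), not the ColossalAI classes themselves:

# Stand-in classes (hypothetical) illustrating the node-selection branch above.
from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class FakeStrategiesVector:
    node: Any    # stands in for a torch.fx.Node

@dataclass
class FakeStrategiesConstructor:
    leaf_strategies: List[FakeStrategiesVector] = field(default_factory=list)
    no_strategy_nodes: List[Any] = field(default_factory=list)

def pick_nodes(all_graph_nodes, strategies_constructor=None):
    # Mirrors the added branch: with a constructor, annotate only nodes that
    # have strategies and remember the ones that do not.
    if strategies_constructor is not None:
        nodes = [sv.node for sv in strategies_constructor.leaf_strategies]
        no_strategy_nodes = strategies_constructor.no_strategy_nodes
    else:
        nodes = tuple(all_graph_nodes)
        no_strategy_nodes = []
    return nodes, no_strategy_nodes

constructor = FakeStrategiesConstructor(
    leaf_strategies=[FakeStrategiesVector('linear'), FakeStrategiesVector('gelu')],
    no_strategy_nodes=['getitem'])
print(pick_nodes(['linear', 'gelu', 'getitem'], constructor))
# (['linear', 'gelu'], ['getitem'])
print(pick_nodes(['linear', 'gelu', 'getitem']))
# (('linear', 'gelu', 'getitem'), [])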
@@ -44,7 +55,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     for index, node in enumerate(nodes):
         target_sharding_specs = []
         for user_node in node.strategies_vector.successor_nodes:
-            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            if user_node in no_strategy_nodes:
+                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            else:
+                target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
         setattr(node, 'target_sharding_specs', target_sharding_specs)
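This hunk adds a fallback for consumers without a strategy: the producer keeps its own output spec rather than converting to a spec the consumer cannot name. A toy sketch of that decision, with plain dict lookups standing in for get_sharding_spec_by_name:

# Toy version: dicts stand in for best_strategy.get_sharding_spec_by_name.
def pick_target_spec(node_name, producer_specs, consumer_specs, consumer_has_no_strategy):
    # A no-strategy consumer cannot say how it wants its input laid out, so the
    # producer's own output spec is kept and no conversion is requested.
    if consumer_has_no_strategy:
        return producer_specs[node_name]
    return consumer_specs[node_name]

producer_specs = {'linear1': 'S0R'}    # output spec of the producer's best strategy
consumer_specs = {'linear1': 'RS1'}    # input spec the consumer's best strategy expects
print(pick_target_spec('linear1', producer_specs, consumer_specs, False))   # RS1
print(pick_target_spec('linear1', producer_specs, consumer_specs, True))    # S0R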
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     new_args.append(arg)
 
                 for dim, shard_dims in output_dim_partition_dict.items():
-                    # we will skip the dim with -1 value
-                    if new_args[dim + 1] == -1:
-                        continue
                     total_shard_size = 1
                     for shard_dim in shard_dims:
                         total_shard_size *= device_mesh.shape[shard_dim]
-                    new_args[dim + 1] //= total_shard_size
+                    # There are two ways to use torch.view:
+                    # 1. torch.view(input, *shape)
+                    # 2. torch.view(input, shape)
+                    if isinstance(new_args[1], int):
+                        new_args[dim + 1] //= total_shard_size
+                    else:
+                        new_args[1] = list(new_args[1])
+                        new_args[1][dim] //= total_shard_size
                 node.args = tuple(new_args)
 
         elif node.op == 'call_function':
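The arithmetic this hunk implements: a size passed to view()/reshape() must shrink by the product of the device-mesh extents its dimension is sharded over, and the two call styles view(input, *shape) and view(input, shape) store that size in different places in node.args. A runnable sketch under those assumptions (convert_view_args is illustrative, not a ColossalAI API):

import torch

def convert_view_args(args, dim_partition_dict, mesh_shape):
    """args follows torch.Tensor.view conventions: (input, *shape) or (input, shape)."""
    new_args = list(args)
    for dim, shard_dims in dim_partition_dict.items():
        total_shard_size = 1
        for shard_dim in shard_dims:
            total_shard_size *= mesh_shape[shard_dim]
        # Two call styles: view(input, *shape) keeps ints at new_args[1:],
        # view(input, shape) keeps the whole shape at new_args[1].
        if isinstance(new_args[1], int):
            new_args[dim + 1] //= total_shard_size
        else:
            new_args[1] = list(new_args[1])
            new_args[1][dim] //= total_shard_size
    return tuple(new_args)

x = torch.empty(8, 16)
# Dim 0 sharded over mesh dims 0 and 1 of a (2, 2) mesh: 8 // (2 * 2) = 2.
print(convert_view_args((x, 8, 16), {0: [0, 1]}, (2, 2))[1:])    # (2, 16)
print(convert_view_args((x, (8, 16)), {0: [0, 1]}, (2, 2))[1])   # [2, 16]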
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
             setattr(param, 'sharding_spec', origin_sharding_spec)
             # TODO: build a ColoParameter class to manage the distributed parameters
-            param_sharded = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
-        else:
-            param_sharded = param
-        setattr(target_module, name, param_sharded)
+            # we could use .data here, because all the operations just happen before the real training
+            # loop, so we don't need to track these operations in the autograd graph.
+            param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+
+        setattr(target_module, name, param)
         comm_actions = node.best_strategy.communication_actions
         for operation_data, comm_action in comm_actions.items():
             comm_spec_to_use = comm_action.comm_spec
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 
                 param.register_hook(hook_fn)
 
-            wrapper(param_sharded, comm_spec_to_use)
+            wrapper(param, comm_spec_to_use)
 
     sharded_buffer_dict = {}
     # apply the sharding spec to buffers
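The wrapper/hook_fn pair that this hunk now applies to param instead of param_sharded registers a backward hook so the strategy's communication action runs on the gradient. A hedged stand-alone sketch of that pattern, using torch.distributed.all_reduce directly instead of the comm_spec machinery from the diff:

import torch
import torch.distributed as dist

def wrapper(param: torch.nn.Parameter, process_group=None):
    def hook_fn(grad: torch.Tensor) -> torch.Tensor:
        # Run the communication action on the gradient; here a plain
        # all-reduce, skipped when no process group has been initialized.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad, group=process_group)
        return grad

    param.register_hook(hook_fn)

param = torch.nn.Parameter(torch.randn(4, 4))
wrapper(param)
(param * 2).sum().backward()     # hook fires when param.grad is produced
print(param.grad.shape)          # torch.Size([4, 4])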
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
             setattr(target, 'sharding_spec', origin_sharding_spec)
             # TODO: build a ColoParameter class to manage the distributed parameters
-            target_sharded = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
-        else:
-            target_sharded = target
-        setattr(target_module, atoms[-1], target_sharded)
+            # we could use .data here, because all the operations just happen before the real training
+            # loop, so we don't need to track these operations in the autograd graph.
+            target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+
+        assert hasattr(target_module, atoms[-1])
+        setattr(target_module, atoms[-1], target)
 
         comm_actions = node.best_strategy.communication_actions
         for operation_data, comm_action in comm_actions.items():
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
 
                 param.register_hook(hook_fn)
 
-            wrapper(target_sharded, comm_spec_to_use)
+            wrapper(target, comm_spec_to_use)
     return gm
 
 
@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
     pass
 
 
-def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
+def runtime_preparation_pass(gm: torch.fx.GraphModule,
+                             solution: List[int],
+                             device_mesh: DeviceMesh,
+                             strategies_constructor: StrategiesConstructor = None):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
-        gm, solution)
+        gm, solution, strategies_constructor)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented once the implementation of implicit_comm_action_apply_pass is completed.
     # gm = implicit_comm_action_apply(gm)
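A hedged sketch of calling the updated entry point. The module path for runtime_preparation_pass is an assumption (this diff does not show the file path), and the tracing/solver setup that produces gm, solution, device_mesh, and strategies_constructor is elided:

from typing import List

import torch

from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass  # assumed path
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh

def prepare(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh,
            strategies_constructor: StrategiesConstructor):
    # With strategies_constructor given, the pass annotates only nodes that have
    # strategies and resolves no-strategy users via the fallback above; without
    # it, the old behavior (annotate every graph node) is kept.
    return runtime_preparation_pass(gm, solution, device_mesh,
                                    strategies_constructor=strategies_constructor)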
@@ -41,6 +41,7 @@ class StrategiesConstructor:
         self.leaf_strategies = []
         self.strategy_map = {}
         self.solver_options = solver_options
+        self.no_strategy_nodes = []
 
     def remove_duplicated_strategy(self, strategies_vector):
         '''
@@ -78,12 +79,11 @@ class StrategiesConstructor:
 
             return _check_no_strategy_for_data(node._meta_data)
 
-        no_strategy_node = []
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
 
             if _check_no_strategy_for_node(node):
-                no_strategy_node.append(node)
+                self.no_strategy_nodes.append(node)
                 pass
 
             # placeholder node