mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer
* fix all node handler tests
* polish
* polish
@@ -70,14 +70,28 @@ class MetaInfo:
         if self._strategy is not None and self._target is not None:
             self.compute_metainfo()
 
-    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec) -> torch.Tensor:
+    def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec):
         """
         Compute sharded opdata based on the given data and sharding spec.
         """
-        return OperationData(name=operation_data.name,
-                             data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
-                             type=operation_data.type,
-                             logical_shape=operation_data.logical_shape)
+
+        if isinstance(sharding_spec, ShardingSpec):
+            op_data = OperationData(name=operation_data.name,
+                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+                                    type=operation_data.type,
+                                    logical_shape=operation_data.logical_shape)
+        elif isinstance(sharding_spec, (list, tuple)):
+            data = operation_data.data
+            assert isinstance(data, (list, tuple)), f"Data should be list or tuple, but got {type(data)}."
+            assert len(data) == len(sharding_spec), "Length of data and sharding spec should be the same."
+            sharded_data = []
+            for d, s in zip(data, sharding_spec):
+                sharded_data.append(torch.zeros(s.get_sharded_shape_per_device(), device="meta"))
+            op_data = OperationData(name=operation_data.name, data=sharded_data, type=operation_data.type)
+        else:
+            raise ValueError(f"Sharding spec should be ShardingSpec or list, but got {type(sharding_spec)}.")
+
+        return op_data
 
     def compute_metainfo(self):
         """
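The new list/tuple branch materializes each shard as a PyTorch meta tensor: it carries only the per-device sharded shape and allocates no storage. A minimal sketch of that pattern in plain PyTorch, not ColossalAI code; the made-up per_device_shapes list stands in for [s.get_sharded_shape_per_device() for s in sharding_spec]:

    import torch

    def sharded_placeholders(per_device_shapes):
        # each entry plays the role of one ShardingSpec's
        # get_sharded_shape_per_device() result
        return [torch.zeros(shape, device="meta") for shape in per_device_shapes]

    shards = sharded_placeholders([(4, 8), (4, 8)])
    print([s.shape for s in shards])    # [torch.Size([4, 8]), torch.Size([4, 8])]
    print(shards[0].is_meta)            # True: meta tensors allocate no storage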
@@ -387,12 +387,13 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
     # This stream is created for overlapping the communication and computation.
     reduction_stream = torch.cuda.Stream()
 
-    def _add_hook_for_grad_communication(node, param):
+    def _add_hook_for_grad_communication(node, param, name=None):
 
         comm_actions = node.best_strategy.communication_actions
 
-        def _filter_param_to_hook(node, op_data, comm_action):
-            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
+        def _filter_param_to_hook(node, op_data, comm_action, name):
+
+            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == name and comm_action.comm_type == CommType.HOOK:
                 return True
             if node.op == 'get_attr' and isinstance(
                     node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
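For context on why _filter_param_to_hook checks two node kinds: in a traced torch.fx graph, a parameter can reach an op implicitly through a 'call_module' node or explicitly as a 'get_attr' node, so the filter must recognize both. A small self-contained illustration (the module M is made up for this example):

    import torch
    import torch.fx

    class M(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)             # params hidden behind call_module
            self.scale = torch.nn.Parameter(torch.ones(2))  # surfaces as get_attr

        def forward(self, x):
            return self.linear(x) * self.scale

    gm = torch.fx.symbolic_trace(M())
    print([(node.op, node.target) for node in gm.graph.nodes])
    # [('placeholder', 'x'), ('call_module', 'linear'), ('get_attr', 'scale'),
    #  ('call_function', <built-in function mul>), ('output', 'output')]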
@@ -402,7 +403,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
         for operation_data, comm_action in comm_actions.items():
             comm_spec_to_use = comm_action.comm_spec
             # register hook to the parameters
-            if _filter_param_to_hook(node, operation_data, comm_action):
+            if _filter_param_to_hook(node, operation_data, comm_action, name=name):
 
                 def wrapper(param, comm_spec, stream, overlap):
 
@@ -442,7 +443,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
                 param = _shard_param(param, target_sharding_spec)
 
             setattr(target_module, name, param)
-            _add_hook_for_grad_communication(node, param)
+            _add_hook_for_grad_communication(node, param, name)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
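Taken together, the three hunks above thread the parameter's module attribute name through to the hook filter instead of relying on param.name. As a rough sketch of the hook pattern itself (simplified and hypothetical, not the ColossalAI wrapper; it assumes torch.distributed is initialized and CUDA is available), the gradient communication runs on the side stream so it can overlap with the rest of the backward pass:

    import torch
    import torch.distributed as dist

    def add_grad_comm_hook(param: torch.nn.Parameter, stream: torch.cuda.Stream):

        def hook(grad: torch.Tensor) -> torch.Tensor:
            # make sure the gradient is fully produced before communicating
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                dist.all_reduce(grad)    # communication off the default stream
            # a matching wait is needed before the optimizer consumes grad
            return grad

        param.register_hook(hook)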
@@ -81,7 +81,10 @@ class AddBMMFunctionHandler(NodeHandler):
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        generator = BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh)
+        # addbmm will shrink the first batch dim
+        generator.squeeze_batch_dim = True
+        generators.append(generator)
         return generators
 
     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
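The squeeze_batch_dim flag the handler now sets reflects a property of the op itself: torch.addbmm sums the batched matmul over the batch dimension, so its output is 2-D even though its matrix inputs are 3-D. A quick shape check:

    import torch

    bias   = torch.randn(3, 5)       # (n, p)
    batch1 = torch.randn(4, 3, 2)    # (b, n, m)
    batch2 = torch.randn(4, 2, 5)    # (b, m, p)

    out = torch.addbmm(bias, batch1, batch2)
    print(out.shape)                 # torch.Size([3, 5]) -- batch dim is gone

    # equivalent formulation: sum of the per-batch matmuls plus the bias
    ref = bias + torch.bmm(batch1, batch2).sum(dim=0)
    print(torch.allclose(out, ref, atol=1e-6))    # True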
@@ -776,10 +776,6 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             bias_op_data = self.op_data['bias']
             assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
 
-        if self.op_data['output'].data.dim() == 2:
-            # addbmm will shrink the first batch dim
-            self.squeeze_batch_dim = True
-
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
                                                                          self.op_data['output'].data.shape)
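For reference, the (unchanged) cost formula in update_compute_cost counts one multiply-accumulate per output element per contracted element: for a batched matmul (b, n, m) @ (b, m, p) -> (b, n, p) that is input.shape[-1] * prod(output.shape). A worked example with made-up shapes:

    import operator
    from functools import reduce

    input_shape  = (4, 3, 2)    # (b, n, m)
    output_shape = (4, 3, 5)    # (b, n, p)

    fwd_compute_cost = input_shape[-1] * reduce(operator.mul, output_shape)
    print(fwd_compute_cost)     # 2 * (4 * 3 * 5) = 120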