Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-25 03:31:56 +00:00
[autoparallel] adapt solver with self attention (#2037)
* [autoparallel] adapt solver with self attention
* polish code
@@ -26,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
     # TODO: contiguous maybe need some extra processes.
     torch.Tensor.contiguous
 ]
-RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.transpose,
+    torch.split,
+    torch.permute,
+    operator.getitem,
+]
 RESHAPE_METHOD_OP = [
     torch.Tensor.view,
     torch.Tensor.unsqueeze,
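The point of the expanded lists is that the strategy handlers and the merge logic below can now treat transpose, split, permute and getitem the same way they already treat flatten and reshape. Below is a minimal sketch (not part of the commit) of how such constant lists are typically consulted to classify a torch.fx node; the is_reshape_node helper is illustrative only.

import operator

import torch


def is_reshape_node(node) -> bool:
    # torch.fx 'call_function' nodes keep the callable itself as target,
    # while 'call_method' nodes keep the method name as a string.
    reshape_func_op = [torch.flatten, torch.reshape, torch.transpose, torch.split, torch.permute, operator.getitem]
    reshape_method_op = [torch.Tensor.view, torch.Tensor.unsqueeze]
    if node.op == 'call_function':
        return node.target in reshape_func_op
    if node.op == 'call_method':
        # resolve the string target to the bound method on torch.Tensor
        return getattr(torch.Tensor, node.target, None) in reshape_method_op
    return False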
@@ -9,7 +9,14 @@ from torch.fx.node import Node
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec

-from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
+from .constants import (
+    BCAST_FUNC_OP,
+    ELEMENTWISE_FUNC_OP,
+    ELEMENTWISE_METHOD_OP,
+    ELEMENTWISE_MODULE_OP,
+    RESHAPE_FUNC_OP,
+    RESHAPE_METHOD_OP,
+)

 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -249,8 +256,15 @@ class StrategiesVector(list):
             # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
             if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
                 merge_label = True
-            # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+            # we could merge reshape op, because their computation costs are negligible.
             if self.node.target in RESHAPE_FUNC_OP:
                 merge_label = True

+        if self.node.op == 'call_method':
+            # we could merge reshape op, because their computation costs are negligible.
+            method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+            if method in RESHAPE_METHOD_OP:
+                merge_label = True
+            if method in ELEMENTWISE_METHOD_OP:
+                merge_label = True
         return merge_label
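For a call_method node the target is only the method name (e.g. 'view'), so the new branch resolves it on the class of the first argument's meta tensor before testing membership in RESHAPE_METHOD_OP / ELEMENTWISE_METHOD_OP. A standalone illustration of that resolution step, using torch.Tensor directly instead of a node's _meta_data:

import torch

RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.unsqueeze]

# same lookup as getattr(self.node.args[0]._meta_data.__class__, self.node.target)
method = getattr(torch.Tensor, 'view')
print(method in RESHAPE_METHOD_OP)    # True -> the node can be merged into its predecessor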
@@ -63,14 +63,40 @@ class CostGraph:
                 edge_cost[(j, i)] = resharding_cost_item.total
             self.edge_costs[node_pair] = edge_cost
             # add parents and children attribute to node
-            parent_nodes = [node for node in strategies_vector.predecessor_nodes]
-            children_nodes = [node for node in strategies_vector.successor_nodes]
+            # parent_nodes = [node for node in strategies_vector.predecessor_nodes]
+            # children_nodes = [node for node in strategies_vector.successor_nodes]
+            parent_nodes = []
+            children_nodes = []
+
+            def _check_tensor_in_node(data):
+                """
+                This method is used to check whether the data has a tensor inside or not.
+                """
+                has_tensor_flag = False
+                if isinstance(data, torch.Tensor):
+                    return True
+                elif isinstance(data, (tuple, list)):
+                    for d in data:
+                        has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
+                return has_tensor_flag
+
+            for node in strategies_vector.predecessor_nodes:
+                if _check_tensor_in_node(node._meta_data):
+                    parent_nodes.append(node)
+            for node in strategies_vector.successor_nodes:
+                if _check_tensor_in_node(node._meta_data):
+                    children_nodes.append(node)
+
             setattr(dst_node, 'parents', parent_nodes)
             setattr(dst_node, 'children', children_nodes)

             if self.simplify and strategies_vector.check_merge():
                 for followed_node in strategies_vector.predecessor_nodes:
-                    self.merge_pair.append((followed_node, dst_node))
+                    # we only merge node pairs which src node has a tensor element inside.
+                    # This is necessay because the node without a tensor element inside will not
+                    # be assigned any strategy.
+                    if _check_tensor_in_node(followed_node._meta_data):
+                        self.merge_pair.append((followed_node, dst_node))

     def get_edge_cost(self, src_node, dst_node):
         return self.edge_costs[(src_node, dst_node)]
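_check_tensor_in_node recursively walks a node's _meta_data, so nodes whose meta data is a plain Python value (a getitem index, split sizes, and so on) are kept out of parents/children and out of merge_pair; such nodes never get strategies assigned. A small standalone version of the same recursion with made-up inputs:

import torch


def check_tensor_in_node(data):
    # True if a torch.Tensor appears anywhere inside (possibly nested) data
    if isinstance(data, torch.Tensor):
        return True
    if isinstance(data, (tuple, list)):
        return any(check_tensor_in_node(d) for d in data)
    return False


print(check_tensor_in_node([torch.zeros(2), 3]))    # True
print(check_tensor_in_node((1, 2)))                 # False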
@@ -154,12 +154,16 @@ class Solver:
                 if self.forward_only:
                     origin_communication_cost = communication_cost_item.fwd
                     compute_cost = compute_cost_item.fwd
+                    # extract MemoryCost item from the memory TrainCycleItem
                     memory_cost = memory_cost_item.fwd
                 else:
                     origin_communication_cost = communication_cost_item.total
                     compute_cost = compute_cost_item.total
+                    # extract MemoryCost item from the memory TrainCycleItem
                     memory_cost = memory_cost_item.total

+                # extract the memory cost in float from MemoryCost item and sum them up
+                memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
                 compute_costs.append(compute_cost)
                 # node in extra_node_costs means it has some extra communication
                 # cost from node merging, so we need to add those extra communication
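The solver now receives a single float per strategy: the fields of the MemoryCost record picked from the TrainCycleItem (fwd or total) are summed. A minimal sketch, assuming a MemoryCost layout with exactly the three attributes referenced above:

from dataclasses import dataclass


@dataclass
class MemoryCost:
    # assumed field names, mirroring the attributes summed in the diff
    parameter: float = 0.0
    activation: float = 0.0
    buffer: float = 0.0


fwd = MemoryCost(parameter=4.0, activation=16.0, buffer=1.0)
memory_cost = fwd.parameter + fwd.activation + fwd.buffer
print(memory_cost)    # 21.0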
@@ -366,6 +370,8 @@ class Solver:
         for liveness_stage in liveness_set:
             mem = 0
             for live_variable in liveness_stage.unique_live_vars:
+                if live_variable.node not in self.node_index_dict:
+                    continue
                 node_index = self.node_index_dict[live_variable.node]
                 mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
             prob += mem <= memory_budget
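The added guard skips live variables whose producing node was never indexed (it received no strategies, e.g. a node without a tensor inside), so the memory-budget constraint only sums over nodes the solver actually optimizes. A toy pulp constraint of the same shape, with hypothetical per-strategy memory numbers:

from pulp import LpMinimize, LpProblem, LpVariable, lpSum

prob = LpProblem("memory_budget_sketch", LpMinimize)
# one binary selector per (node, strategy); two nodes with two strategies each
s = [[LpVariable(f"s_{i}_{j}", cat="Binary") for j in range(2)] for i in range(2)]
m = [[8.0, 2.0], [6.0, 3.0]]    # hypothetical memory cost of each strategy

mem = 0
for i in range(2):
    mem += lpSum(s[i][j] * m[i][j] for j in range(2))
prob += mem <= 10.0             # the budget constraint added per liveness stage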
@@ -53,17 +53,38 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
     while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
         if original_dimension_size == tgt_dimension_size:
             reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
-            origin_index += 1
-            tgt_index += 1
+            # if the origin_dims has no element, it means the original tensor has been fully matched.
+            # Therefore, we do not have to increase the origin_index for that case.
+            if len(origin_dims) > 0:
+                origin_index += 1
+            # if the tgt_dims has no element, it means the original tensor has been fully matched.
+            # Therefore, we do not have to increase the tgt_index for that case.
+            if len(tgt_dims) > 0:
+                tgt_index += 1
             # the last step of loop should always end with condition
             # so we need to manually skip the preparation for next step
             # in the last step.
-            if origin_index == len(origin_shape):
+            if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
                 continue
-            original_dimension_size = origin_shape[origin_index]
-            tgt_dimension_size = tgt_shape[tgt_index]
-            origin_dims = [origin_len - origin_index - 1]
-            tgt_dims = [tgt_len - tgt_index - 1]
+
+            # If origin_index equals to origin_len, we just need to set the original_dimension_size
+            # to 1 to match the remaining '1's in the target tensor shape.
+            if origin_index == len(origin_shape):
+                original_dimension_size = 1
+                origin_dims = []
+            else:
+                original_dimension_size = origin_shape[origin_index]
+                origin_dims = [origin_len - origin_index - 1]
+
+            # If tgt_index equals to tgt_len, we just need to set the tgt_dimension_size
+            # to 1 to match the remaining '1's in the original tensor shape.
+            if tgt_index == len(tgt_shape):
+                tgt_dimension_size = 1
+                tgt_dims = []
+            else:
+                tgt_dimension_size = tgt_shape[tgt_index]
+                tgt_dims = [tgt_len - tgt_index - 1]
+
             previous_label = PreviousStatus.RESET

         elif original_dimension_size > tgt_dimension_size:
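The new branches let the matching loop keep going when one shape is exhausted while the other still has trailing size-1 dimensions (as view/unsqueeze patterns produce): the exhausted side contributes a dimension size of 1 and an empty dim group instead of indexing past the end. The underlying idea of detect_reshape_mapping is to pair groups of dimensions whose element counts match; the simplified sketch below shows only that grouping (positive dim indices, no trailing-singleton handling) and is not the library implementation:

def group_dims_by_product(origin_shape, tgt_shape):
    # greedy sketch: extend whichever side has the smaller running product
    mapping, i, j = {}, 0, 0
    while i < len(origin_shape) and j < len(tgt_shape):
        o_dims, t_dims = [i], [j]
        o_prod, t_prod = origin_shape[i], tgt_shape[j]
        while o_prod != t_prod:
            if o_prod < t_prod:
                i += 1
                o_dims.append(i)
                o_prod *= origin_shape[i]
            else:
                j += 1
                t_dims.append(j)
                t_prod *= tgt_shape[j]
        mapping[tuple(o_dims)] = tuple(t_dims)
        i, j = i + 1, j + 1
    return mapping


# a (8, 4, 16) -> (8, 64) reshape groups the last two source dims together
print(group_dims_by_product((8, 4, 16), (8, 64)))    # {(0,): (0,), (1, 2): (1,)}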
@@ -141,6 +162,9 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
     """
     sharded_dims = list(input_dim_partition_dict.keys())
     for input_dims in reshape_mapping_dict.keys():
+        # if input_dims has no element, we could just skip this iteration.
+        if len(input_dims) == 0:
+            continue
         min_element = min(input_dims)
        for dim in input_dims:
             if dim in sharded_dims and dim is not min_element:
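With operator.getitem and torch.split now treated as reshape ops, a mapping entry's input-dim tuple can legitimately be empty, and min(()) would raise ValueError; the added guard skips those entries before any dims are inspected. A tiny illustration with a made-up mapping and sharding:

# hypothetical mapping for a split/getitem pattern: the empty tuple means this
# output group consumes no input dimension, so there is nothing to check
reshape_mapping_dict = {(): (0,), (1, 0): (1,)}
sharded_dims = [1]

keep_sharding = True
for input_dims in reshape_mapping_dict:
    if len(input_dims) == 0:
        continue                      # skipping avoids min(()) below
    min_element = min(input_dims)
    for dim in input_dims:
        if dim in sharded_dims and dim != min_element:
            keep_sharding = False     # a non-leading sharded dim cannot survive the reshape
print(keep_sharding)                  # False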