[autoparallel] adapt solver with self attention (#2037)

* [autoparallel] adapt solver with self attention

* polish code
YuliangLiu0306 authored on 2022-12-01 17:53:15 +08:00, committed by GitHub
parent d3499c98d4
commit 1c1fe44305
6 changed files with 320 additions and 13 deletions

View File

@@ -26,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.transpose,
torch.split,
torch.permute,
operator.getitem,
]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
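The expanded RESHAPE_FUNC_OP list also covers operator.getitem, which appears when a traced graph indexes the tuple returned by torch.split, as in a fused qkv projection of self attention. A minimal standalone sketch (the Toy module and its sizes are made up for illustration) showing how each index access becomes a call_function node targeting operator.getitem:

import operator

import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):

    def forward(self, x):
        # hypothetical fused qkv-style split of a 12-wide input into three 4-wide chunks
        qkv = torch.split(x, 4, dim=-1)
        # each index access on the split output is recorded as operator.getitem
        q, k, v = qkv[0], qkv[1], qkv[2]
        return q + k + v


gm = symbolic_trace(Toy())
print([node.op for node in gm.graph.nodes if node.target is operator.getitem])
# expected output: ['call_function', 'call_function', 'call_function']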

View File

@@ -9,7 +9,14 @@ from torch.fx.node import Node
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
RESHAPE_FUNC_OP,
RESHAPE_METHOD_OP,
)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
@@ -249,8 +256,15 @@ class StrategiesVector(list):
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
# we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
# we could merge reshape op, because their computation costs are negligible.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
if self.node.op == 'call_method':
# we could merge reshape and element-wise method ops, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
merge_label = True
if method in ELEMENTWISE_METHOD_OP:
merge_label = True
return merge_label
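A minimal standalone sketch of the call_method resolution used in check_merge above: the node's target string is looked up on the meta tensor's class and the resulting unbound method is checked against the constant lists. The meta tensor and the 'view' target are stand-ins for node.args[0]._meta_data and node.target:

import torch

RESHAPE_METHOD_OP = [
    torch.Tensor.view,
    torch.Tensor.unsqueeze,
]

# stand-ins for node.args[0]._meta_data and node.target of a call_method node
meta_data = torch.empty(2, 3, device='meta')
target = 'view'

# resolve the method name on the meta tensor's class, as check_merge does above
method = getattr(meta_data.__class__, target)
print(method is torch.Tensor.view)    # True
print(method in RESHAPE_METHOD_OP)    # True, so such a node would get merge_label = True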

View File

@@ -63,14 +63,40 @@ class CostGraph:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
children_nodes = [node for node in strategies_vector.successor_nodes]
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []
def _check_tensor_in_node(data):
"""
This method is used to check whether the data has a tensor inside or not.
"""
has_tensor_flag = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
return has_tensor_flag
for node in strategies_vector.predecessor_nodes:
if _check_tensor_in_node(node._meta_data):
parent_nodes.append(node)
for node in strategies_vector.successor_nodes:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
# we only merge node pairs whose src node has a tensor element inside.
# This is necessary because a node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
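The helper above treats a node as tensor-bearing only if a torch.Tensor is found somewhere inside its _meta_data. An assumed standalone usage (the helper is copied verbatim, the example inputs are made up):

import torch


def _check_tensor_in_node(data):
    """
    This method is used to check whether the data has a tensor inside or not.
    """
    has_tensor_flag = False
    if isinstance(data, torch.Tensor):
        return True
    elif isinstance(data, (tuple, list)):
        for d in data:
            has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
    return has_tensor_flag


# a split node's _meta_data is typically a tuple of meta tensors -> tensor-bearing
print(_check_tensor_in_node((torch.empty(2, 2, device='meta'), torch.empty(2, 2, device='meta'))))    # True
# nodes carrying only python scalars are skipped as parents/children and as merge sources
print(_check_tensor_in_node([1, 2, 3]))    # False
print(_check_tensor_in_node(4))            # False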

View File

@@ -154,12 +154,16 @@ class Solver:
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.total
# extract the memory cost components as floats from the MemoryCost item and sum them up
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
compute_costs.append(compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
@@ -366,6 +370,8 @@ class Solver:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
if live_variable.node not in self.node_index_dict:
continue
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
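A minimal standalone sketch, with made-up node counts and costs, of how this memory-budget constraint is assembled with PuLP: s[i][j] is the binary variable selecting strategy j for node i, m[i][j] is that strategy's memory cost, and the selected memory of all nodes live in a liveness stage must stay under the budget:

from pulp import LpMinimize, LpProblem, LpVariable, lpSum

# hypothetical sizes: node 0 has 2 candidate strategies, node 1 has 3
num_strategies = [2, 3]
# hypothetical per-strategy memory costs (parameter + activation + buffer already summed)
m = [[4.0, 2.0], [6.0, 3.0, 1.5]]
memory_budget = 8.0

prob = LpProblem("memory_budget_demo", LpMinimize)
# s[i][j] == 1 means node i selects strategy j
s = [[LpVariable(f"s_{i}_{j}", cat="Binary") for j in range(n)] for i, n in enumerate(num_strategies)]

# every node must select exactly one strategy
for row in s:
    prob += lpSum(row) == 1

# a single liveness stage in which both nodes are alive: selected memory must fit the budget
mem = 0
for node_index in [0, 1]:
    mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget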

View File

@@ -53,17 +53,38 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
if original_dimension_size == tgt_dimension_size:
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
origin_index += 1
tgt_index += 1
# if origin_dims has no element, the original tensor has been fully matched,
# so we do not have to increase origin_index in that case.
if len(origin_dims) > 0:
origin_index += 1
# if tgt_dims has no element, the target tensor has been fully matched,
# so we do not have to increase tgt_index in that case.
if len(tgt_dims) > 0:
tgt_index += 1
# the last iteration of the loop should always end with this condition,
# so we need to manually skip the preparation for the next iteration
# during the last one.
if origin_index == len(origin_shape):
if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
continue
original_dimension_size = origin_shape[origin_index]
tgt_dimension_size = tgt_shape[tgt_index]
origin_dims = [origin_len - origin_index - 1]
tgt_dims = [tgt_len - tgt_index - 1]
# If origin_index equals origin_len, we just need to set original_dimension_size
# to 1 to match the remaining '1's in the target tensor shape.
if origin_index == len(origin_shape):
original_dimension_size = 1
origin_dims = []
else:
original_dimension_size = origin_shape[origin_index]
origin_dims = [origin_len - origin_index - 1]
# If tgt_index equals tgt_len, we just need to set tgt_dimension_size
# to 1 to match the remaining '1's in the original tensor shape.
if tgt_index == len(tgt_shape):
tgt_dimension_size = 1
tgt_dims = []
else:
tgt_dimension_size = tgt_shape[tgt_index]
tgt_dims = [tgt_len - tgt_index - 1]
previous_label = PreviousStatus.RESET
elif original_dimension_size > tgt_dimension_size:
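A simplified standalone sketch of the edge case handled above (it uses forward dim indices rather than the rear-counted indices of detect_reshape_mapping, so the exact mapping convention is illustrative only): once one shape is exhausted, the leftover dims on the other side can only be size 1 and are matched against an empty dim group:

import torch


def map_trailing_ones(origin_shape: torch.Size, tgt_shape: torch.Size):
    """Match leading dims one-to-one, then map the remaining size-1 target dims to an empty group."""
    mapping = {}
    common = min(len(origin_shape), len(tgt_shape))
    for i in range(common):
        assert origin_shape[i] == tgt_shape[i]
        mapping[(i,)] = (i,)
    # once the origin shape is exhausted, the leftover target dims can only be 1s,
    # so they are matched against an empty tuple of origin dims
    extra = tuple(range(common, len(tgt_shape)))
    if extra:
        assert all(tgt_shape[j] == 1 for j in extra)
        mapping[()] = extra
    return mapping


print(map_trailing_ones(torch.Size([8, 6]), torch.Size([8, 6, 1])))
# {(0,): (0,), (1,): (1,), (): (2,)}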
@@ -141,6 +162,9 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
"""
sharded_dims = list(input_dim_partition_dict.keys())
for input_dims in reshape_mapping_dict.keys():
# if input_dims has no element, we can just skip this iteration.
if len(input_dims) == 0:
continue
min_element = min(input_dims)
for dim in input_dims:
if dim in sharded_dims and dim is not min_element:
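A minimal standalone repro (with a made-up mapping and sharding dict) of why the empty-group guard is needed: an empty input_dims tuple has no minimum element, so min() would raise without the continue:

# hypothetical reshape_mapping_dict produced by trailing size-1 dims; the values are made up
reshape_mapping_dict = {(1, 0): (0,), (): (1, 2)}
sharded_dims = [0]

for input_dims in reshape_mapping_dict.keys():
    if len(input_dims) == 0:
        # without this guard, min(()) below raises "ValueError: min() arg is an empty sequence"
        continue
    min_element = min(input_dims)
    print(f"group {input_dims}: min element {min_element}, "
          f"sharded dims in group: {[d for d in input_dims if d in sharded_dims]}")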