[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)

* Remove unnecessary calls to deepcopy

* Build DimSpec's difference dict only once

This change considerably speeds up the construction of DimSpec objects. The difference_dict is identical for every DimSpec object, so a single shared copy is enough.

* Fix documentation of DimSpec's difference method
This commit is contained in:
Stephan Kö 2024-07-15 12:05:06 +08:00 committed by GitHub
parent c068ef0fa0
commit 45c49dde96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 73 deletions

View File

@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Dict, List from typing import Dict, List
from ..utils import merge_same_dim_mesh_list from ..utils import merge_same_dim_mesh_list
@ -23,10 +22,11 @@ class DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension. Otherwise, the element in shard_list means the data will be sharded in that dimension.
""" """
_DIFFERENCE_DICT = None
def __init__(self, shard_list): def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0 self.is_replica = len(shard_list) == 0
self.shard_list = shard_list self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other): def __eq__(self, other):
return str(self) == str(other) return str(self) == str(other)
@ -39,24 +39,43 @@ class DimSpec:
target += str(dim) target += str(dim)
return target return target
def _convert_str_to_shard_list(self, str_spec): @property
def difference_dict(self):
""" """
Convert str_spec into shard_list. Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[str, str], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def dim_diff(self, other):
"""
The difference between two DimSpec.
Argument: Argument:
str_spec(str): dim spec in str type. other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two DimSpec.
Example:
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))
Output:
5
""" """
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R": @classmethod
return [] def _build_difference_2d_dict(cls):
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
def build_difference_2d_dict(self):
""" """
Build a difference mapping for 2D device mesh case. It will be used to Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs. compute the difference between DimSpec pairs.
@ -67,9 +86,8 @@ class DimSpec:
difference_dict = {} difference_dict = {}
for source_spec in source_spec_list: for source_spec in source_spec_list:
for target_spec in target_spec_list: for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = cls._convert_str_to_shard_list(source_spec)
source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = cls._convert_str_to_shard_list(target_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
# source same as target # source same as target
if source_shard_list == target_shard_list: if source_shard_list == target_shard_list:
@ -112,30 +130,27 @@ class DimSpec:
else: else:
difference = NAN difference = NAN
difference_dict[spec_pair] = difference difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict return difference_dict
def dim_diff(self, other): @staticmethod
def _convert_str_to_shard_list(str_spec):
""" """
The difference between two _DimSpec. Convert str_spec into shard_list.
Argument: Argument:
other(_DimSpec): the dim spec to compare with. str_spec(str): dim spec in str type.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
""" """
difference = self.difference_dict[(str(self), str(other))]
return difference if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpec: class ShardingSpec:

View File

@ -1,5 +1,4 @@
import operator import operator
from copy import deepcopy
from functools import reduce from functools import reduce
import torch import torch
@ -27,10 +26,11 @@ class _DimSpec:
Otherwise, the element in shard_list means the data will be sharded in that dimension. Otherwise, the element in shard_list means the data will be sharded in that dimension.
""" """
_DIFFERENCE_DICT = None
def __init__(self, shard_list): def __init__(self, shard_list):
self.is_replica = len(shard_list) == 0 self.is_replica = len(shard_list) == 0
self.shard_list = shard_list self.shard_list = shard_list
self.build_difference_2d_dict()
def __eq__(self, other): def __eq__(self, other):
return str(self) == str(other) return str(self) == str(other)
@ -43,27 +43,46 @@ class _DimSpec:
target += str(dim) target += str(dim)
return target return target
def _convert_str_to_shard_list(self, str_spec): @property
def difference_dict(self):
""" """
Convert str_spec into shard_list. Returns the difference dict, and lazily initializes it when needed
Return:
difference_dict(Dict[Tuple[str, str], Union[int, float, str]]):
difference dict
"""
if self._DIFFERENCE_DICT is None:
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
return self._DIFFERENCE_DICT
def difference(self, other):
"""
The difference between two _DimSpec.
Argument: Argument:
str_spec(str): dim spec in str type. other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
""" """
difference = self.difference_dict[(str(self), str(other))]
return difference
if str_spec == "R": @classmethod
return [] def _build_difference_2d_dict(cls):
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
def build_difference_2d_dict(self):
""" """
Build a difference mapping for 2D device mesh case. It will be used to Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs. compute the difference between _DimSpec pairs.
""" """
source_spec_list = ["R", "S0", "S1", "S01"] source_spec_list = ["R", "S0", "S1", "S01"]
@ -71,9 +90,8 @@ class _DimSpec:
difference_dict = {} difference_dict = {}
for source_spec in source_spec_list: for source_spec in source_spec_list:
for target_spec in target_spec_list: for target_spec in target_spec_list:
spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) source_shard_list = cls._convert_str_to_shard_list(source_spec)
source_shard_list = self._convert_str_to_shard_list(source_spec) target_shard_list = cls._convert_str_to_shard_list(target_spec)
target_shard_list = self._convert_str_to_shard_list(target_spec)
# source same as target # source same as target
if source_shard_list == target_shard_list: if source_shard_list == target_shard_list:
@ -116,30 +134,27 @@ class _DimSpec:
else: else:
difference = NAN difference = NAN
difference_dict[spec_pair] = difference difference_dict[(source_spec, target_spec)] = difference
self.difference_dict = difference_dict return difference_dict
def difference(self, other): @staticmethod
def _convert_str_to_shard_list(str_spec):
""" """
The difference between two _DimSpec. Convert str_spec into shard_list.
Argument: Argument:
other(_DimSpec): the dim spec to compare with. str_spec(str): dim spec in str type.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
""" """
difference = self.difference_dict[(str(self), str(other))]
return difference if str_spec == "R":
return []
if str_spec == "S0":
return [0]
if str_spec == "S1":
return [1]
if str_spec == "S01":
return [0, 1]
class ShardingSpecException(Exception): class ShardingSpecException(Exception):