mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-04 09:40:11 +00:00
[Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once. This change considerably speeds up the construction of DimSpec objects: the difference_dict is the same for every DimSpec object, so a single shared copy is enough.
* Fix documentation of DimSpec's difference method
This commit is contained in:
parent
c068ef0fa0
commit
45c49dde96
@ -1,4 +1,3 @@
|
|||||||
from copy import deepcopy
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from ..utils import merge_same_dim_mesh_list
|
from ..utils import merge_same_dim_mesh_list
|
||||||
@ -23,10 +22,11 @@ class DimSpec:
|
|||||||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DIFFERENCE_DICT = None
|
||||||
|
|
||||||
def __init__(self, shard_list):
|
def __init__(self, shard_list):
|
||||||
self.is_replica = len(shard_list) == 0
|
self.is_replica = len(shard_list) == 0
|
||||||
self.shard_list = shard_list
|
self.shard_list = shard_list
|
||||||
self.build_difference_2d_dict()
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return str(self) == str(other)
|
return str(self) == str(other)
|
||||||
@ -39,24 +39,43 @@ class DimSpec:
|
|||||||
target += str(dim)
|
target += str(dim)
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def _convert_str_to_shard_list(self, str_spec):
|
@property
|
||||||
|
def difference_dict(self):
|
||||||
"""
|
"""
|
||||||
Convert str_spec into shard_list.
|
Returns the difference dict, and lazily initializes it when needed
|
||||||
|
|
||||||
|
Return:
|
||||||
|
difference_dict(Dict[Tuple[str, str], Union[int, float, str]]):
|
||||||
|
difference dict
|
||||||
|
"""
|
||||||
|
if self._DIFFERENCE_DICT is None:
|
||||||
|
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||||
|
|
||||||
|
return self._DIFFERENCE_DICT
|
||||||
|
|
||||||
|
def dim_diff(self, other):
|
||||||
|
"""
|
||||||
|
The difference between two DimSpec.
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
str_spec(str): dim spec in str type.
|
other(DimSpec): the dim spec to compare with.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
difference(int): the difference between two DimSpec.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
dim_spec = DimSpec([0])
|
||||||
|
other_dim_spec = DimSpec([0, 1])
|
||||||
|
print(dim_spec.dim_diff(other_dim_spec))
|
||||||
|
|
||||||
|
Output:
|
||||||
|
5
|
||||||
"""
|
"""
|
||||||
|
difference = self.difference_dict[(str(self), str(other))]
|
||||||
|
return difference
|
||||||
|
|
||||||
if str_spec == "R":
|
@classmethod
|
||||||
return []
|
def _build_difference_2d_dict(cls):
|
||||||
if str_spec == "S0":
|
|
||||||
return [0]
|
|
||||||
if str_spec == "S1":
|
|
||||||
return [1]
|
|
||||||
if str_spec == "S01":
|
|
||||||
return [0, 1]
|
|
||||||
|
|
||||||
def build_difference_2d_dict(self):
|
|
||||||
"""
|
"""
|
||||||
Build a difference mapping for 2D device mesh case. It will be used to
|
Build a difference mapping for 2D device mesh case. It will be used to
|
||||||
compute the difference between DimSpec pairs.
|
compute the difference between DimSpec pairs.
|
||||||
@ -67,9 +86,8 @@ class DimSpec:
|
|||||||
difference_dict = {}
|
difference_dict = {}
|
||||||
for source_spec in source_spec_list:
|
for source_spec in source_spec_list:
|
||||||
for target_spec in target_spec_list:
|
for target_spec in target_spec_list:
|
||||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
|
||||||
|
|
||||||
# source same as target
|
# source same as target
|
||||||
if source_shard_list == target_shard_list:
|
if source_shard_list == target_shard_list:
|
||||||
@ -112,30 +130,27 @@ class DimSpec:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
difference = NAN
|
difference = NAN
|
||||||
difference_dict[spec_pair] = difference
|
difference_dict[(source_spec, target_spec)] = difference
|
||||||
|
|
||||||
self.difference_dict = difference_dict
|
return difference_dict
|
||||||
|
|
||||||
def dim_diff(self, other):
|
@staticmethod
|
||||||
|
def _convert_str_to_shard_list(str_spec):
|
||||||
"""
|
"""
|
||||||
The difference between two _DimSpec.
|
Convert str_spec into shard_list.
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
other(_DimSpec): the dim spec to compare with.
|
str_spec(str): dim spec in str type.
|
||||||
|
|
||||||
Return:
|
|
||||||
difference(int): the difference between two _DimSpec.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
dim_spec = _DimSpec([0])
|
|
||||||
other_dim_spec = _DimSpec([0, 1])
|
|
||||||
print(dim_spec.difference(other_dim_spec))
|
|
||||||
|
|
||||||
Output:
|
|
||||||
5
|
|
||||||
"""
|
"""
|
||||||
difference = self.difference_dict[(str(self), str(other))]
|
|
||||||
return difference
|
if str_spec == "R":
|
||||||
|
return []
|
||||||
|
if str_spec == "S0":
|
||||||
|
return [0]
|
||||||
|
if str_spec == "S1":
|
||||||
|
return [1]
|
||||||
|
if str_spec == "S01":
|
||||||
|
return [0, 1]
|
||||||
|
|
||||||
|
|
||||||
class ShardingSpec:
|
class ShardingSpec:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import operator
|
import operator
|
||||||
from copy import deepcopy
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -27,10 +26,11 @@ class _DimSpec:
|
|||||||
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
Otherwise, the element in shard_list means the data will be sharded in that dimension.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DIFFERENCE_DICT = None
|
||||||
|
|
||||||
def __init__(self, shard_list):
|
def __init__(self, shard_list):
|
||||||
self.is_replica = len(shard_list) == 0
|
self.is_replica = len(shard_list) == 0
|
||||||
self.shard_list = shard_list
|
self.shard_list = shard_list
|
||||||
self.build_difference_2d_dict()
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return str(self) == str(other)
|
return str(self) == str(other)
|
||||||
@ -43,27 +43,46 @@ class _DimSpec:
|
|||||||
target += str(dim)
|
target += str(dim)
|
||||||
return target
|
return target
|
||||||
|
|
||||||
def _convert_str_to_shard_list(self, str_spec):
|
@property
|
||||||
|
def difference_dict(self):
|
||||||
"""
|
"""
|
||||||
Convert str_spec into shard_list.
|
Returns the difference dict, and lazily initializes it when needed
|
||||||
|
|
||||||
|
Return:
|
||||||
|
difference_dict(Dict[Tuple[str, str], Union[int, float, str]]):
|
||||||
|
difference dict
|
||||||
|
"""
|
||||||
|
if self._DIFFERENCE_DICT is None:
|
||||||
|
self._DIFFERENCE_DICT = self._build_difference_2d_dict()
|
||||||
|
|
||||||
|
return self._DIFFERENCE_DICT
|
||||||
|
|
||||||
|
def difference(self, other):
|
||||||
|
"""
|
||||||
|
The difference between two _DimSpec.
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
str_spec(str): dim spec in str type.
|
other(_DimSpec): the dim spec to compare with.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
difference(int): the difference between two _DimSpec.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
dim_spec = _DimSpec([0])
|
||||||
|
other_dim_spec = _DimSpec([0, 1])
|
||||||
|
print(dim_spec.difference(other_dim_spec))
|
||||||
|
|
||||||
|
Output:
|
||||||
|
5
|
||||||
"""
|
"""
|
||||||
|
difference = self.difference_dict[(str(self), str(other))]
|
||||||
|
return difference
|
||||||
|
|
||||||
if str_spec == "R":
|
@classmethod
|
||||||
return []
|
def _build_difference_2d_dict(cls):
|
||||||
if str_spec == "S0":
|
|
||||||
return [0]
|
|
||||||
if str_spec == "S1":
|
|
||||||
return [1]
|
|
||||||
if str_spec == "S01":
|
|
||||||
return [0, 1]
|
|
||||||
|
|
||||||
def build_difference_2d_dict(self):
|
|
||||||
"""
|
"""
|
||||||
Build a difference mapping for 2D device mesh case. It will be used to
|
Build a difference mapping for 2D device mesh case. It will be used to
|
||||||
compute the difference between DimSpec pairs.
|
compute the difference between _DimSpec pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
source_spec_list = ["R", "S0", "S1", "S01"]
|
source_spec_list = ["R", "S0", "S1", "S01"]
|
||||||
@ -71,9 +90,8 @@ class _DimSpec:
|
|||||||
difference_dict = {}
|
difference_dict = {}
|
||||||
for source_spec in source_spec_list:
|
for source_spec in source_spec_list:
|
||||||
for target_spec in target_spec_list:
|
for target_spec in target_spec_list:
|
||||||
spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
|
source_shard_list = cls._convert_str_to_shard_list(source_spec)
|
||||||
source_shard_list = self._convert_str_to_shard_list(source_spec)
|
target_shard_list = cls._convert_str_to_shard_list(target_spec)
|
||||||
target_shard_list = self._convert_str_to_shard_list(target_spec)
|
|
||||||
|
|
||||||
# source same as target
|
# source same as target
|
||||||
if source_shard_list == target_shard_list:
|
if source_shard_list == target_shard_list:
|
||||||
@ -116,30 +134,27 @@ class _DimSpec:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
difference = NAN
|
difference = NAN
|
||||||
difference_dict[spec_pair] = difference
|
difference_dict[(source_spec, target_spec)] = difference
|
||||||
|
|
||||||
self.difference_dict = difference_dict
|
return difference_dict
|
||||||
|
|
||||||
def difference(self, other):
|
@staticmethod
|
||||||
|
def _convert_str_to_shard_list(str_spec):
|
||||||
"""
|
"""
|
||||||
The difference between two _DimSpec.
|
Convert str_spec into shard_list.
|
||||||
|
|
||||||
Argument:
|
Argument:
|
||||||
other(_DimSpec): the dim spec to compare with.
|
str_spec(str): dim spec in str type.
|
||||||
|
|
||||||
Return:
|
|
||||||
difference(int): the difference between two _DimSpec.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
dim_spec = _DimSpec([0])
|
|
||||||
other_dim_spec = _DimSpec([0, 1])
|
|
||||||
print(dim_spec.difference(other_dim_spec))
|
|
||||||
|
|
||||||
Output:
|
|
||||||
5
|
|
||||||
"""
|
"""
|
||||||
difference = self.difference_dict[(str(self), str(other))]
|
|
||||||
return difference
|
if str_spec == "R":
|
||||||
|
return []
|
||||||
|
if str_spec == "S0":
|
||||||
|
return [0]
|
||||||
|
if str_spec == "S1":
|
||||||
|
return [1]
|
||||||
|
if str_spec == "S01":
|
||||||
|
return [0, 1]
|
||||||
|
|
||||||
|
|
||||||
class ShardingSpecException(Exception):
|
class ShardingSpecException(Exception):
|
||||||
|
Loading…
Reference in New Issue
Block a user