[autoparallel] refactored shape consistency to remove redundancy (#1591)

* [autoparallel] refactored shape consistency to remove redundancy

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-09-13 18:30:18 +08:00
committed by GitHub
parent d164449d00
commit 27fe8af60c
13 changed files with 220 additions and 234 deletions

View File

@@ -1,15 +1,22 @@
import torch
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
from colossalai.context.singleton_meta import SingletonMeta
import torch.distributed as dist
import math
from functools import reduce
import operator
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
'set_shape_consistency_options'
]
class CollectiveCommPattern(Enum):
ALLGATHER = 'all_gather'
@@ -152,14 +159,40 @@ class CommSpec:
tensor.data = tensor
class ShapeConsistencyManager:
@dataclass
class ShapeConsistencyOptions:
"""
ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
"""
# TODO: shape consistency option is not implemented yet
pass
def __init__(self, consistency_option=None):
self.consistency_option = consistency_option
def set_shape_consistency_options(options: ShapeConsistencyOptions):
"""
Configure the shape consistency manager via function call.
"""
manager = ShapeConsistencyManager()
manager.options = options
class ShapeConsistencyManager(metaclass=SingletonMeta):
def __init__(self):
self._options = None
self.total_communication_cost = 0
self.total_transform_steps = 0
self.cached_spec_pairs_transform_path = {}
@property
def options(self):
return self._options
@options.setter
def options(self, options_: ShapeConsistencyOptions):
assert isinstance(options_, ShapeConsistencyOptions)
self._options = options_
def get_all_all_gather_spec(self, source_spec, orig_cost):
'''
Get all valid sharding specs from source_spec with single all-gather operation, and