Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy
* polish code
* polish code
* polish code
@@ -1,15 +1,22 @@
import torch
from dataclasses import dataclass
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
from colossalai.context.singleton_meta import SingletonMeta
import torch.distributed as dist
import math
from functools import reduce
import operator
from torch.distributed import ReduceOp

__all__ = [
    'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions',
    'set_shape_consistency_options'
]


class CollectiveCommPattern(Enum):
    ALLGATHER = 'all_gather'
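At runtime, a CommSpec built around this enum executes by dispatching to the matching torch.distributed collective and writing the result back into the tensor (note the `tensor.data = tensor` context line in the next hunk). A minimal sketch of that dispatch for the ALLGATHER pattern, assuming it sits alongside the enum above; apply_comm_pattern is a hypothetical helper, not part of the module:

import torch
import torch.distributed as dist


def apply_comm_pattern(tensor, pattern, group=None):
    # Hypothetical dispatcher (sketch only): route a CollectiveCommPattern
    # member to the corresponding torch.distributed collective. Assumes the
    # process group is already initialized and dim 0 is the sharded dim.
    if pattern == CollectiveCommPattern.ALLGATHER:
        world_size = dist.get_world_size(group)
        shards = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(shards, tensor, group=group)
        # Reassemble the gathered shards and write back in place, mirroring
        # how CommSpec publishes results via `tensor.data`.
        tensor.data = torch.cat(shards, dim=0)
        return tensor
    raise NotImplementedError(f"unsupported pattern: {pattern}")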
@@ -152,14 +159,40 @@ class CommSpec:
         tensor.data = tensor


-class ShapeConsistencyManager:
+@dataclass
+class ShapeConsistencyOptions:
+    """
+    ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency.
+    """
+    # TODO: shape consistency option is not implemented yet
+    pass

-    def __init__(self, consistency_option=None):
-        self.consistency_option = consistency_option
+
+def set_shape_consistency_options(options: ShapeConsistencyOptions):
+    """
+    Configure the shape consistency manager via function call.
+    """
+    manager = ShapeConsistencyManager()
+    manager.options = options
+
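Taken together, the new options plumbing is used with a single call (a usage sketch, assuming the file is colossalai/tensor/shape_consistency.py; ShapeConsistencyOptions carries no fields yet, per the TODO above, and the singleton manager it configures is defined in the additions that follow):

from colossalai.tensor.shape_consistency import (ShapeConsistencyManager, ShapeConsistencyOptions,
                                                 set_shape_consistency_options)

# The options dataclass is still empty, but the plumbing works end to end.
set_shape_consistency_options(ShapeConsistencyOptions())

# Any later instantiation observes the configured options, because the
# manager is a process-wide singleton (see the class defined next).
assert isinstance(ShapeConsistencyManager().options, ShapeConsistencyOptions)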
+class ShapeConsistencyManager(metaclass=SingletonMeta):
+
+    def __init__(self):
+        self._options = None
         self.total_communication_cost = 0
         self.total_transform_steps = 0
         self.cached_spec_pairs_transform_path = {}

+    @property
+    def options(self):
+        return self._options
+
+    @options.setter
+    def options(self, options_: ShapeConsistencyOptions):
+        assert isinstance(options_, ShapeConsistencyOptions)
+        self._options = options_
+
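Two behaviours of the refactored class are worth calling out: SingletonMeta makes every instantiation return the same cached instance, and the options setter rejects anything that is not a ShapeConsistencyOptions. A quick sketch of both:

from colossalai.tensor.shape_consistency import ShapeConsistencyManager, ShapeConsistencyOptions

m1 = ShapeConsistencyManager()
m2 = ShapeConsistencyManager()
assert m1 is m2    # SingletonMeta hands back the one shared instance

m1.options = ShapeConsistencyOptions()    # passes the isinstance check
try:
    m1.options = "not an options object"  # fails the assert in the setter
except AssertionError:
    pass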
     def get_all_all_gather_spec(self, source_spec, orig_cost):
         '''
         Get all valid sharding specs from source_spec with single all-gather operation, and
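The docstring is cut off by the end of the hunk, but the method name and the all_gather_simulator import indicate it enumerates every sharding spec reachable from source_spec with one all-gather. A simplified standalone sketch of that enumeration, with the spec reduced to a plain dim -> device-mesh-axes dict and assuming, as in all_gather_simulator, that one all-gather frees the last mesh axis a dimension is sharded over (the real method works on ShardingSpec objects and also returns communication costs):

from copy import deepcopy
from typing import Dict, List


def enumerate_all_gather_targets(dim_partition: Dict[int, List[int]]) -> List[Dict[int, List[int]]]:
    # For each sharded tensor dim, a single all-gather can release the last
    # device-mesh axis it is sharded over, yielding a coarser sharding spec.
    targets = []
    for dim, mesh_axes in dim_partition.items():
        new_partition = deepcopy(dim_partition)
        new_partition[dim] = mesh_axes[:-1]
        if not new_partition[dim]:
            del new_partition[dim]    # dim becomes fully replicated
        targets.append(new_partition)
    return targets


# Dim 0 sharded over mesh axes [0, 1] -> one all-gather leaves it sharded over [0];
# dims 0 and 1 each sharded -> two candidate one-step specs.
assert enumerate_all_gather_targets({0: [0, 1]}) == [{0: [0]}]
assert enumerate_all_gather_targets({0: [0], 1: [1]}) == [{1: [1]}, {0: [0]}]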