diff --git a/colossalai/auto_parallel/__init__.py b/colossalai/auto_parallel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 8f227076d..025a13fc6 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,3 +1,9 @@
+from dataclasses import dataclass
+from colossalai.tensor.sharding_spec import ShardingSpec
+from typing import Dict, List
+
+
+@dataclass
 class ShardingStrategy:
     '''
     ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
@@ -15,21 +21,13 @@ class ShardingStrategy:
         input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
     '''

-    def __init__(self,
-                 name,
-                 output_sharding_spec,
-                 compute_cost=0,
-                 communication_cost=0,
-                 memory_cost=0,
-                 resharding_costs=None,
-                 input_shardings=None):
-        self.name = name
-        self.output_sharding_spec = output_sharding_spec
-        self.compute_cost = compute_cost
-        self.communication_cost = communication_cost
-        self.memory_cost = memory_cost
-        self.resharding_costs = resharding_costs
-        self.input_shardings = input_shardings
+    name: str
+    output_sharding_spec: ShardingSpec
+    compute_cost: float = 0.
+    communication_cost: float = 0.
+    memory_cost: float = 0.
+    resharding_costs: Dict[int, List[float]] = None
+    input_shardings: List[ShardingSpec] = None


 class StrategiesVector:
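
To illustrate what this refactor buys, here is a minimal, self-contained sketch of the new dataclass form. The ShardingSpec class below is a stand-in so the sketch runs on its own (the real class lives in colossalai.tensor.sharding_spec and requires a device mesh to construct), and the Optional[...] annotations are an editorial addition to make the None defaults explicit; the field layout otherwise mirrors the diff. The strategy name and sharding sequence strings are illustrative values, not taken from this commit.

from dataclasses import dataclass
from typing import Dict, List, Optional


# Stand-in for colossalai.tensor.sharding_spec.ShardingSpec, used only so
# this sketch is runnable without a DeviceMesh. It is an assumption for
# illustration, not part of this diff.
class ShardingSpec:

    def __init__(self, sharding_sequence: str):
        self.sharding_sequence = sharding_sequence

    def __repr__(self):
        return f'ShardingSpec({self.sharding_sequence})'


@dataclass
class ShardingStrategy:
    # Same fields as the diff; Optional[...] here just makes the
    # None defaults explicit.
    name: str
    output_sharding_spec: ShardingSpec
    compute_cost: float = 0.
    communication_cost: float = 0.
    memory_cost: float = 0.
    resharding_costs: Optional[Dict[int, List[float]]] = None
    input_shardings: Optional[List[ShardingSpec]] = None


# @dataclass generates __init__, __repr__ and __eq__ from the field
# declarations, replacing the hand-written boilerplate this diff removes.
strategy = ShardingStrategy(name='S0R = S0R x RR',
                            output_sharding_spec=ShardingSpec('[S0, R]'),
                            compute_cost=1.0)
print(strategy)

Because the generated __init__ keeps the same parameter names and defaults as the deleted one, existing call sites that construct ShardingStrategy by keyword should continue to work unchanged.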