From 3a54e1c9b707978fe2775707c1d195e65d9b5da9 Mon Sep 17 00:00:00 2001 From: Frank Lee <somerlee.9@gmail.com> Date: Fri, 19 Aug 2022 15:51:54 +0800 Subject: [PATCH] [autoparallel] standardize the code structure (#1469) --- colossalai/auto_parallel/__init__.py | 0 colossalai/auto_parallel/solver/__init__.py | 0 .../auto_parallel/solver/sharding_strategy.py | 28 +++++++++---------- 3 files changed, 13 insertions(+), 15 deletions(-) create mode 100644 colossalai/auto_parallel/__init__.py create mode 100644 colossalai/auto_parallel/solver/__init__.py diff --git a/colossalai/auto_parallel/__init__.py b/colossalai/auto_parallel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index 8f227076d..025a13fc6 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -1,3 +1,9 @@ +from dataclasses import dataclass +from colossalai.tensor.sharding_spec import ShardingSpec +from typing import Dict, List + + +@dataclass class ShardingStrategy: ''' ShardingStrategy is a structure containing sharding strategies of inputs and output of this node @@ -15,21 +21,13 @@ class ShardingStrategy: input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes. ''' - def __init__(self, - name, - output_sharding_spec, - compute_cost=0, - communication_cost=0, - memory_cost=0, - resharding_costs=None, - input_shardings=None): - self.name = name - self.output_sharding_spec = output_sharding_spec - self.compute_cost = compute_cost - self.communication_cost = communication_cost - self.memory_cost = memory_cost - self.resharding_costs = resharding_costs - self.input_shardings = input_shardings + name: str + output_sharding_spec: ShardingSpec + compute_cost: float = 0. + communication_cost: float = 0. + memory_cost: float = 0. + resharding_costs: Dict[int, List[float]] = None + input_shardings: ShardingSpec = None class StrategiesVector: