From 3a54e1c9b707978fe2775707c1d195e65d9b5da9 Mon Sep 17 00:00:00 2001
From: Frank Lee <somerlee.9@gmail.com>
Date: Fri, 19 Aug 2022 15:51:54 +0800
Subject: [PATCH] [autoparallel] standardize the code structure (#1469)

---
 colossalai/auto_parallel/__init__.py          |  0
 colossalai/auto_parallel/solver/__init__.py   |  0
 .../auto_parallel/solver/sharding_strategy.py | 28 +++++++++----------
 3 files changed, 13 insertions(+), 15 deletions(-)
 create mode 100644 colossalai/auto_parallel/__init__.py
 create mode 100644 colossalai/auto_parallel/solver/__init__.py

diff --git a/colossalai/auto_parallel/__init__.py b/colossalai/auto_parallel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 8f227076d..025a13fc6 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,3 +1,9 @@
+from dataclasses import dataclass
+from colossalai.tensor.sharding_spec import ShardingSpec
+from typing import Dict, List
+
+
+@dataclass
 class ShardingStrategy:
     '''
     ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
@@ -15,21 +21,13 @@ class ShardingStrategy:
         input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
     '''
 
-    def __init__(self,
-                 name,
-                 output_sharding_spec,
-                 compute_cost=0,
-                 communication_cost=0,
-                 memory_cost=0,
-                 resharding_costs=None,
-                 input_shardings=None):
-        self.name = name
-        self.output_sharding_spec = output_sharding_spec
-        self.compute_cost = compute_cost
-        self.communication_cost = communication_cost
-        self.memory_cost = memory_cost
-        self.resharding_costs = resharding_costs
-        self.input_shardings = input_shardings
+    name: str
+    output_sharding_spec: ShardingSpec
+    compute_cost: float = 0.
+    communication_cost: float = 0.
+    memory_cost: float = 0.
+    resharding_costs: Dict[int, List[float]] = None
+    input_shardings: ShardingSpec = None
 
 
 class StrategiesVector: