From 2629f9717dbb1f0dc0ecad3f3bda18a4b44e3785 Mon Sep 17 00:00:00 2001
From: YH <100389977+yhna940@users.noreply.github.com>
Date: Sat, 6 May 2023 18:55:37 +0900
Subject: [PATCH] [tensor] Refactor handle_trans_spec in DistSpecManager

---
 colossalai/tensor/dist_spec_mgr.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py
index 865798923..c968050de 100644
--- a/colossalai/tensor/dist_spec_mgr.py
+++ b/colossalai/tensor/dist_spec_mgr.py
@@ -4,10 +4,8 @@ import torch
 import torch.distributed as dist
 # from colossalai.nn.layer.utils import divide
 from numpy import prod
-from packaging import version
 
-from colossalai.logging import get_dist_logger
-from colossalai.tensor.distspec import _DistSpec
+from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
 from colossalai.tensor.process_group import ProcessGroup
 
 
@@ -171,11 +169,21 @@ class DistSpecManager:
                           pg: ProcessGroup) -> torch.Tensor:
         assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
         assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
-        forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
+
+        trans_func_key = (old_dist_spec.placement, dist_spec.placement)
+        trans_funcs = {
+            (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
+            (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
+            (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
+            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
+        }
+
+        forward_trans_handle = trans_funcs[trans_func_key]
         if not DistSpecManager._use_autograd_function:
             return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
-        backward_trans_handle = getattr(DistSpecManager,
-                                        f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
+
+        backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
+
         return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
                                        backward_trans_handle)