From 258b43317c4a5cafb8d3da0ff63c8843443bc448 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 21 Mar 2023 13:24:18 +0800
Subject: [PATCH] [hotfix] layout converting issue (#3188)

---
 colossalai/tensor/d_tensor/layout_converter.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index a4f4c9c2d..cf02aac30 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost,
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpecException
+from colossalai.tensor.d_tensor.misc import LayoutException
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 
 from .sharding_spec import ShardingSpec
@@ -145,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                     entire_shape=source_layout.entire_shape)
 
                 valid_spec_dict[new_layout] = comm_spec
-            except ShardingSpecException:
+            except LayoutException:
                 pass
         return valid_spec_dict
 
@@ -255,7 +255,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                         device_type=source_layout.device_type,
                                         entire_shape=source_layout.entire_shape)
                     valid_spec_dict[new_layout] = comm_spec
-                except ShardingSpecException:
+                except LayoutException:
                     pass
 
         return valid_spec_dict
@@ -343,7 +343,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                         device_type=source_layout.device_type,
                                         entire_shape=source_layout.entire_shape)
                     valid_spec_dict[new_layout] = comm_spec
-                except ShardingSpecException:
+                except LayoutException:
                     pass
         return valid_spec_dict