[autoparallel] refactor handlers which reshape input tensors (#2615)

* [autoparallel] refactor handlers which reshape input tensors

* polish
YuliangLiu0306
2023-02-08 15:02:49 +08:00
committed by GitHub
parent 28398f1c70
commit 37df666f38
15 changed files with 307 additions and 365 deletions
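At a glance, the refactor only renames and relocates the reshape-family handlers: ReshapeHandler becomes DefaultReshapeHandler, and the shape-manipulation handlers leave the experimental subpackage. A minimal sketch of the import change, assembled from the hunks below (only paths that actually appear in this diff are shown):

# Before this commit (paths removed in the hunks below):
#   from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
#   from colossalai.auto_parallel.tensor_shard.node_handler.experimental import (
#       PermuteHandler, SplitHandler, TransposeHandler, ViewHandler)

# After this commit: the renamed DefaultReshapeHandler and the promoted handlers
# are all importable from the top-level node_handler package.
from colossalai.auto_parallel.tensor_shard.node_handler import (
    DefaultReshapeHandler,
    PermuteHandler,
    SplitHandler,
    TransposeHandler,
    ViewHandler,
)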


@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -51,9 +51,9 @@ def test_reshape_handler():
                                        strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    reshape_handler = ReshapeHandler(node=reshape_node,
-                                     device_mesh=device_mesh,
-                                     strategies_vector=reshape_strategies_vector)
+    reshape_handler = DefaultReshapeHandler(node=reshape_node,
+                                            device_mesh=device_mesh,
+                                            strategies_vector=reshape_strategies_vector)
     reshape_handler.register_strategy(compute_resharding_cost=False)
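For context, the handler under test operates on nodes of a graph traced with ColoTracer and on a DeviceMesh, both imported at the top of this file. The sketch below fills in that surrounding setup; the model, tensor shapes, mesh layout, and node index are assumptions made for illustration, and only the handler construction at the end mirrors the hunk above.

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class ConvReshapeModel(nn.Module):
    # hypothetical module: a conv whose output is reshaped, so the traced graph
    # contains a conv node followed by a view node
    def forward(self, input, other):
        conv_out = nn.functional.conv2d(input, other)
        return conv_out.view(4, -1)


model = ConvReshapeModel()
tracer = ColoTracer()
graph = tracer.trace(model,
                     meta_args={
                         'input': torch.rand(4, 4, 64, 64).to('meta'),
                         'other': torch.rand(16, 4, 3, 3).to('meta'),
                     })
gm = ColoGraphModule(model, graph)

# four devices arranged as a 2x2 logical mesh (layout chosen for illustration)
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

reshape_node = list(graph.nodes)[3]    # index of the view node in this particular trace
reshape_strategies_vector = StrategiesVector(reshape_node)

# Same construction as in the hunk above, using the renamed handler class.
# (In the full test, a ConvFunctionHandler has already registered and attached
# strategies on the conv node before this point.)
reshape_handler = DefaultReshapeHandler(node=reshape_node,
                                        device_mesh=device_mesh,
                                        strategies_vector=reshape_strategies_vector)
reshape_handler.register_strategy(compute_resharding_cost=False)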


@@ -5,10 +5,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -153,7 +153,9 @@ def test_getitem_from_tuple_handler():
     )
     input_handler.register_strategy(compute_resharding_cost=False)
     setattr(input_node, 'strategies_vector', input_strategies_vector)
-    split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
+    split_handler = DefaultReshapeHandler(node=split_node,
+                                          device_mesh=device_mesh,
+                                          strategies_vector=split_strategies_vector)
     split_handler.register_strategy(compute_resharding_cost=False)
     setattr(split_node, 'strategies_vector', split_strategies_vector)
     getitem_handler = GetItemHandler(node=getitem_node,
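The split call in this test returns a tuple of tensors, so two handlers are chained: the renamed DefaultReshapeHandler registers strategies on the split node itself, and GetItemHandler then covers the node that indexes into that tuple. Below is a condensed sketch of the chain as it appears in the test, reusing the hunk's variable names (it is a fragment of that test, not a standalone script; the GetItemHandler arguments beyond the node are assumed to follow the same pattern as the other handlers):

# Strategies for the tuple-producing split node, exactly as in the hunk above.
split_handler = DefaultReshapeHandler(node=split_node,
                                      device_mesh=device_mesh,
                                      strategies_vector=split_strategies_vector)
split_handler.register_strategy(compute_resharding_cost=False)
setattr(split_node, 'strategies_vector', split_strategies_vector)

# Strategies for the node that selects one element of the split's tuple output
# (argument names beyond `node` are assumed here).
getitem_handler = GetItemHandler(node=getitem_node,
                                 device_mesh=device_mesh,
                                 strategies_vector=getitem_strategies_vector)
getitem_handler.register_strategy(compute_resharding_cost=False)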


@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
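This hunk is import-only: PermuteHandler and TransposeHandler move from the experimental subpackage to the top-level node_handler package, and nothing about how the test exercises them changes. For orientation, the kind of ops these two handlers generate strategies for looks roughly like this (the module below is invented for illustration and is not part of the diff):

import torch.nn as nn


class PermuteTransposeModel(nn.Module):
    # illustrative only: a conv feeding a permute and a transpose, the two node
    # types covered by PermuteHandler and TransposeHandler respectively
    def forward(self, input, other):
        x = nn.functional.conv2d(input, other)
        permuted = x.permute(0, 2, 3, 1)    # -> permute node (PermuteHandler)
        transposed = x.transpose(0, 2)      # -> transpose node (TransposeHandler)
        return permuted, transposed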


@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
@@ -156,8 +156,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
     # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
     assert len(split_strategies_vector) == len(previous_strategies_vector)
     strategy_name_list = [strategy.name for strategy in split_strategies_vector]
-    for name in strategy_name_list:
-        print(name)
     if model_cls.__name__ == 'ConvSplitModel':
         if split_dim == 0:
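The comment and assertion kept in this hunk capture the key property of these handlers: a reshape-type handler is a "following" handler, so it does not enumerate sharding strategies of its own but derives one strategy from each strategy already registered on the predecessor (conv or linear) node, which is why the two vectors must have equal length. A small sketch of what the check expresses, reusing the hunk's variable names (the positional pairing is an assumption for illustration):

# previous_strategies_vector was registered on the node that feeds the split;
# split_strategies_vector is what the handler derived from it, one per entry.
assert len(split_strategies_vector) == len(previous_strategies_vector)
follow_map = {
    prev.name: derived.name
    for prev, derived in zip(previous_strategies_vector, split_strategies_vector)
}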


@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
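The last hunk mirrors the previous ones for ViewHandler, which now also comes from the top-level node_handler package instead of experimental. It targets nodes produced by calls such as Tensor.view; a minimal illustrative module of that shape (again invented here, not taken from the diff):

import torch.nn as nn


class LinearViewModel(nn.Module):
    # illustrative only: the view call below becomes the node that ViewHandler
    # generates strategies for, following the linear node's registered strategies
    def forward(self, input, other):
        x = nn.functional.linear(input, other)
        return x.view(x.size(0), -1)    # -> view node (ViewHandler)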