[autoparallel] refactor handlers which reshape input tensors (#2615)
* [autoparallel] refactor handlers which reshape input tensors

* polish
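The commit renames ReshapeHandler to DefaultReshapeHandler and promotes the experimental reshape-family handlers (PermuteHandler, TransposeHandler, SplitHandler, ViewHandler) to the top-level node_handler package; the test hunks below update imports and call sites accordingly. A condensed sketch of the updated flow, reconstructed from the first test below (the toy model, tensor shapes, and node indices are illustrative assumptions, not part of the commit):

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class ConvReshapeModel(nn.Module):
    # toy model: a conv followed by a reshape, so the traced graph
    # contains one node for each handler exercised below
    def forward(self, input, other):
        conv_node = nn.functional.conv2d(input, other)
        reshape_node = conv_node.view(2, -1)
        return reshape_node


model = ConvReshapeModel()
tracer = ColoTracer()
graph = tracer.trace(model,
                     meta_args={
                         'input': torch.rand(4, 4, 64, 64).to('meta'),
                         'other': torch.rand(16, 4, 3, 3).to('meta')
                     })
gm = ColoGraphModule(model, graph)

physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

conv_mod_node = list(graph.nodes)[2]    # the conv2d call
reshape_node = list(graph.nodes)[3]     # the view call

# DefaultReshapeHandler is a "following" handler, so the predecessor's
# strategies must be registered and attached to its node first.
conv_strategies_vector = StrategiesVector(conv_mod_node)
conv_handler = ConvFunctionHandler(node=conv_mod_node,
                                   device_mesh=device_mesh,
                                   strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)

# the only change at the call site is the class name
reshape_strategies_vector = StrategiesVector(reshape_node)
reshape_handler = DefaultReshapeHandler(node=reshape_node,
                                        device_mesh=device_mesh,
                                        strategies_vector=reshape_strategies_vector)
reshape_handler.register_strategy(compute_resharding_cost=False)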
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 
+from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -51,9 +51,9 @@ def test_reshape_handler():
                                        strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    reshape_handler = ReshapeHandler(node=reshape_node,
-                                     device_mesh=device_mesh,
-                                     strategies_vector=reshape_strategies_vector)
+    reshape_handler = DefaultReshapeHandler(node=reshape_node,
+                                            device_mesh=device_mesh,
+                                            strategies_vector=reshape_strategies_vector)
 
     reshape_handler.register_strategy(compute_resharding_cost=False)
@@ -5,10 +5,10 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
@@ -153,7 +153,9 @@ def test_getitem_from_tuple_handler():
     )
     input_handler.register_strategy(compute_resharding_cost=False)
     setattr(input_node, 'strategies_vector', input_strategies_vector)
-    split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
+    split_handler = DefaultReshapeHandler(node=split_node,
+                                          device_mesh=device_mesh,
+                                          strategies_vector=split_strategies_vector)
     split_handler.register_strategy(compute_resharding_cost=False)
     setattr(split_node, 'strategies_vector', split_strategies_vector)
     getitem_handler = GetItemHandler(node=getitem_node,
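Note the two import spellings across these hunks: the first file pulls DefaultReshapeHandler from the node_handler package root, this one from its default_reshape_handler module. Both should resolve to the same class (assuming the package root re-exports it, which the first hunk's import implies); a quick identity check:

from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler as FromPackage
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler as FromModule

# one class, importable from either path
assert FromPackage is FromModule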
@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
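This hunk and the two like it below make the same move for the former experimental handlers: PermuteHandler, TransposeHandler, SplitHandler, and ViewHandler are now imported from the node_handler package root rather than from its experimental submodule, so a single grouped import covers all four:

# new import surface exercised by these tests; the old
# node_handler.experimental paths no longer appear
from colossalai.auto_parallel.tensor_shard.node_handler import (
    PermuteHandler,
    SplitHandler,
    TransposeHandler,
    ViewHandler,
)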
@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
@@ -156,8 +156,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
     # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
     assert len(split_strategies_vector) == len(previous_strategies_vector)
     strategy_name_list = [strategy.name for strategy in split_strategies_vector]
-    for name in strategy_name_list:
-        print(name)
+
     if model_cls.__name__ == 'ConvSplitModel':
 
         if split_dim == 0:
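The assertion in this hunk encodes the "following strategy handler" contract stated in the comment: a reshape-family handler does not enumerate sharding strategies of its own but derives exactly one strategy from each strategy of its predecessor, so the two vectors always have equal length. A toy illustration of that invariant (plain Python, not the library's internals; the strategy names are hypothetical):

# a follower maps each predecessor strategy to one strategy of its own,
# so the counts must match
previous_strategies = ['S0R', 'S1R', 'RR']    # hypothetical predecessor names
following_strategies = [f'{name} -> follow' for name in previous_strategies]
assert len(following_strategies) == len(previous_strategies)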
@@ -5,8 +5,8 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh