mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 15:11:20 +00:00
[autoparallel] adapt handlers with attention block (#1990)
* [autoparallel] adapt handlers with attention block * polish
This commit is contained in:
@@ -12,6 +12,7 @@ __all__ = ['ReshapeHandler']
|
||||
|
||||
@operator_registry.register(torch.reshape)
|
||||
@operator_registry.register(torch.Tensor.split)
|
||||
@operator_registry.register(torch.split)
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.Tensor.transpose)
|
||||
@operator_registry.register(torch.Tensor.permute)
|
||||
|
@@ -220,7 +220,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
@@ -256,7 +258,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
@@ -302,7 +306,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
||||
logical_process_axis=[mesh_dim_0],
|
||||
comm_type=CommType.IMPLICIT)
|
||||
|
||||
communication_action_mapping = {"output": output_comm_action}
|
||||
# TODO: Temporary solution has no communication cost,
|
||||
# above action should be added after the SyncBN replace pass completed.
|
||||
communication_action_mapping = {}
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@@ -69,7 +69,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for strategy in self.predecessor_node.strategies_vector:
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
|
||||
@@ -96,7 +96,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
|
||||
arg_index=0)
|
||||
communication_action_mapping["input"] = input_communication_action
|
||||
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
@@ -121,7 +121,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
|
||||
strategy_list = []
|
||||
index = self.op_data["index"].data
|
||||
|
||||
for strategy in self.predecessor_node.strategies_vector:
|
||||
for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
# the sharding spec for input in this case is a tuple of ShardingSpec.
|
||||
sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict
|
||||
@@ -132,8 +132,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
sharding_spec_mapping["input"] = sharding_spec_for_input
|
||||
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
|
||||
input_sharding_info = f"get the {index} element from ("
|
||||
for sharding_spec in sharding_spec_for_input:
|
||||
input_sharding_info += f'{sharding_spec.sharding_sequence}, '
|
||||
input_sharding_info += ")"
|
||||
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
|
@@ -1,9 +1,12 @@
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
ignore_sharding_exception,
|
||||
)
|
||||
|
||||
from .strategy_generator import StrategyGenerator
|
||||
|
||||
@@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
@ignore_sharding_exception
|
||||
def _generate_strategy_with_dim_partition(self, dim_partition):
|
||||
dim_partition_dict_mapping = {
|
||||
"condition": dim_partition,
|
||||
|
@@ -14,6 +14,11 @@ __all__ = ['UnaryElementwiseHandler']
|
||||
@operator_registry.register(torch.Tensor.type)
|
||||
@operator_registry.register(torch.abs)
|
||||
@operator_registry.register(torch.nn.ReLU)
|
||||
# TODO: softmax need to be relocated
|
||||
@operator_registry.register(torch.nn.functional.softmax)
|
||||
@operator_registry.register(torch.nn.modules.dropout.Dropout)
|
||||
@operator_registry.register(torch.Tensor.contiguous)
|
||||
@operator_registry.register(torch.nn.functional.dropout)
|
||||
class UnaryElementwiseHandler(NodeHandler):
|
||||
"""
|
||||
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
|
||||
|
@@ -57,24 +57,6 @@ class WhereHandler(NodeHandler):
|
||||
logical_operand.logical_shape = target_shape
|
||||
return logical_operand
|
||||
|
||||
def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
|
||||
"""
|
||||
Register different sharding strategies for the current node.
|
||||
"""
|
||||
strategy_generators = self.get_strategy_generator()
|
||||
|
||||
for generator in strategy_generators:
|
||||
strategies = generator.generate()
|
||||
strategies_vector = map(self.post_process, strategies)
|
||||
# compute the resharding costs based on the previous node
|
||||
# strategies if specified
|
||||
if compute_resharding_cost:
|
||||
strategies = list(map(self.update_resharding_cost, strategies))
|
||||
self.strategies_vector.extend(strategies)
|
||||
|
||||
self.strategies_vector = list(strategies_vector)
|
||||
return self.strategies_vector
|
||||
|
||||
def post_process(self, strategy: ShardingStrategy):
|
||||
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
|
||||
for key in logical_op_data_mapping.keys():
|
||||
|
Reference in New Issue
Block a user