mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-24 14:33:20 +00:00

[autoparallel] where_handler_v2 (#1688)

* where generator
* [autoparallel] where_handler_v2

This commit is contained in:
parent 31d2f03d27
commit 319d654f79
@@ -0,0 +1,87 @@
import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
from ..strategy import WhereGenerator, StrategyGenerator_V2
from .broadcast import recover_sharding_spec_for_broadcast_shape
from typing import List, Dict
from .registry import operator_registry
import operator
import copy

__all__ = ['WhereHandler']


@operator_registry.register(torch.where)
class WhereHandler(NodeHandler):
    """
    A WhereHandler which deals with the sharding strategies for torch.where.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        logical_op_data_mapping, _ = self.get_operation_data_mapping()
        generators = []
        generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # generate strategies with the broadcast (logical) shape shared by all operands;
        # the strategies are transformed back to each operand's physical shape in self.post_process
        physical_condition_operand = OperationData(name=str(self.node.args[0]),
                                                   type=OperationDataType.ARG,
                                                   data=self.node.args[0]._meta_data)
        physical_x_operand = OperationData(name=str(self.node.args[1]),
                                           type=OperationDataType.ARG,
                                           data=self.node.args[1]._meta_data)
        physical_y_operand = OperationData(name=str(self.node.args[2]),
                                           type=OperationDataType.ARG,
                                           data=self.node.args[2]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
        physical_mapping = {
            "condition": physical_condition_operand,
            "x": physical_x_operand,
            "y": physical_y_operand,
            "output": physical_output
        }
        logical_shape_for_all = self.node._meta_data.shape
        logical_mapping = {}
        for key, physical_operand in physical_mapping.items():
            logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
                                                                                    logical_shape_for_all)

        return logical_mapping, physical_mapping

    def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape):
        logical_operand = copy.deepcopy(physical_operand)
        logical_operand.logical_shape = target_shape
        return logical_operand

    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.
        """
        strategy_generators = self.get_strategy_generator()

        for generator in strategy_generators:
            strategies = generator.generate()
            strategies_vector = map(self.post_process, strategies)
            # compute the resharding costs based on the previous node's strategies if specified
            if compute_resharding_cost:
                strategies = list(map(self.update_resharding_cost, strategies))
            self.strategies_vector.extend(strategies)

        self.strategies_vector = list(strategies_vector)
        return self.strategies_vector

    def post_process(self, strategy: ShardingStrategy_V2):
        logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
        for key in logical_op_data_mapping.keys():
            logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
            logical_shape = logical_op_data_mapping[key].logical_shape
            physical_shape = physical_op_data_mapping[key].logical_shape
            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
                                                                               physical_shape)
            strategy.sharding_specs.pop(logical_op_data_mapping[key])
            strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
        strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
        return strategy
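For context, a minimal sketch of the broadcasting behaviour the handler relies on: torch.where broadcasts condition, x and y to a common shape, which is why the handler assigns the output's (broadcast) shape as the logical shape of every operand and only recovers the physical shapes in post_process. The shapes below are the ones used by the test added later in this commit.

    import torch

    # operand shapes used by the test in this commit
    condition = torch.rand(4, 4, 64, 64) > 0.5
    x = torch.rand(4, 1, 64, 64)
    y = torch.rand(1, 4, 64, 64)

    # all three operands broadcast to the output shape (4, 4, 64, 64)
    print(torch.broadcast_shapes(condition.shape, x.shape, y.shape))   # torch.Size([4, 4, 64, 64])
    print(torch.where(condition, x, y).shape)                          # torch.Size([4, 4, 64, 64])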
@@ -5,11 +5,12 @@ from .batch_norm_generator import BatchNormStrategyGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
+from .where_generator import WhereGenerator
 from .reshape_generator import ReshapeGenerator
 
 __all__ = [
     'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
     'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
     'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
-    'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator'
+    'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'WhereGenerator', 'ReshapeGenerator'
 ]
@@ -163,7 +163,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
 
     def generate(self):
         '''
-        Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
+        Generate all possible strategies for a LayerNorm node, and record them in the strategies_vector.
         '''
         strategy_list = []
         input_data_dim = len(self.op_data["input"].logical_shape)
colossalai/auto_parallel/solver/strategy/where_generator.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2, FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy

__all__ = ['WhereGenerator']


class WhereGenerator(StrategyGenerator_V2):
    """
    WhereGenerator is a generic class to generate strategies for a where operation.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy_V2):
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'condition': self._compute_size_in_bytes(strategy, "condition"),
            'x': self._compute_size_in_bytes(strategy, "x"),
            'y': self._compute_size_in_bytes(strategy, "y"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute the forward cost incurred
        # fwd_cost = condition + x + y + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)

        # compute the backward cost incurred
        # bwd_cost = condition_grad + x_grad + y_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

        # compute the total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def _generate_strategy_with_dim_partition(self, dim_partition):
        dim_partition_dict_mapping = {
            "condition": dim_partition,
            "x": dim_partition,
            "y": dim_partition,
            "output": dim_partition
        }

        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
        communication_action_mapping = {}

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)

        return strategy

    def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length):
        dim_partition_list = []
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length))
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length))
        dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length))
        # append {} for the non-split case
        dim_partition_list.append({})

        return dim_partition_list

    def generate(self):
        '''
        Generate all possible strategies for a where node, and record them in the strategies_vector.
        '''
        strategy_list = []

        dimension_length = len(self.op_data["output"].logical_shape)
        dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length)
        for dim_partition in dim_partition_list:
            strategy = self._generate_strategy_with_dim_partition(dim_partition)
            strategy_list.append(strategy)

        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
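As a sanity check on the enumeration above, here is a hedged counting sketch (not the library helpers) of the dim-partition candidates for a rank-4 output on a 2D device mesh. It assumes the 1D helper shards one tensor dimension on a single mesh dimension and the 2D helper covers both two distinct tensor dimensions and a single tensor dimension sharded by both mesh dimensions; the total matches the 25 strategies the test below expects.

    from itertools import permutations

    rank = 4    # rank of the where output used in the test below

    one_d = [{dim: [mesh]} for mesh in (0, 1) for dim in range(rank)]                # 8: one tensor dim on one mesh dim
    two_d_pairs = [{d0: [0], d1: [1]} for d0, d1 in permutations(range(rank), 2)]    # 12: two tensor dims, one per mesh dim
    two_d_same = [{dim: [0, 1]} for dim in range(rank)]                              # 4: one tensor dim on both mesh dims
    no_split = [{}]                                                                  # 1: fully replicated

    print(len(one_d) + len(two_d_pairs) + len(two_d_same) + len(no_split))           # 25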
@@ -0,0 +1,85 @@
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.where_handler_v2 import WhereHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, condition, x, y):
        output = torch.where(condition, x, y)
        return output


def test_where_handler():
    model = ConvModel()
    tracer = ColoTracer()
    # graph():
    #     %condition : torch.Tensor [#users=1] = placeholder[target=condition]
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %y : torch.Tensor [#users=1] = placeholder[target=y]
    #     %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
    #     return where
    graph = tracer.trace(model,
                         meta_args={
                             "condition": torch.rand(4, 4, 64, 64).to('meta'),
                             "x": torch.rand(4, 1, 64, 64).to('meta'),
                             "y": torch.rand(1, 4, 64, 64).to('meta')
                         })
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    where_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(where_node)

    # build handler
    handler = WhereHandler(node=where_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping, _ = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['condition'].name == "condition"
    assert mapping['condition'].data.is_meta
    assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64])
    assert mapping['condition'].type == OperationDataType.ARG
    assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64])

    assert mapping['x'].name == "x"
    assert mapping['x'].data.is_meta
    assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64])
    assert mapping['x'].type == OperationDataType.ARG
    assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64])

    assert mapping['y'].name == "y"
    assert mapping['y'].data.is_meta
    assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64])
    assert mapping['y'].type == OperationDataType.ARG
    assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64])

    assert mapping['output'].name == "where"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT

    handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]
    # 4*3 + 4*3/2*2 + 1
    assert len(strategy_name_list) == 25


if __name__ == '__main__':
    test_where_handler()
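To illustrate what post_process recovers for these shapes, here is a hedged, standalone sketch of the broadcast-recovery idea (a simplified stand-in, not the library's recover_sharding_spec_for_broadcast_shape): sharding placed on a dimension that only exists through broadcasting, i.e. a physical size-1 dimension, cannot be kept on the physical operand and is dropped.

    def recover_dim_partition_sketch(logical_partition, logical_shape, physical_shape):
        """Hypothetical helper for illustration only."""
        offset = len(logical_shape) - len(physical_shape)
        physical_partition = {}
        for logical_dim, mesh_dims in logical_partition.items():
            physical_dim = logical_dim - offset
            # keep the sharding only if the dimension exists physically and was not broadcast from size 1
            if physical_dim >= 0 and physical_shape[physical_dim] == logical_shape[logical_dim]:
                physical_partition[physical_dim] = mesh_dims
        return physical_partition

    # x has physical shape (4, 1, 64, 64) but logical shape (4, 4, 64, 64) in the test above:
    print(recover_dim_partition_sketch({0: [0], 1: [1]}, (4, 4, 64, 64), (4, 1, 64, 64)))
    # {0: [0]}  -- the shard on the broadcast dimension 1 is dropped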