diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
index df9eb6498..d005ac813 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
@@ -6,3 +6,4 @@ from .linear import *
 from .norm import *
 from .pooling import *
 from .tensor import *
+from .where import *
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
new file mode 100644
index 000000000..c67eb40bc
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
@@ -0,0 +1,60 @@
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["where_meta_info"]
+
+
+@meta_register.register(torch.where)
+def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+    """torch.where meta information generator
+
+    Returns:
+        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+    """
+
+    condition_tensor, x_tensor, y_tensor, output_tensor = [arg.data for arg in args]
+
+    # compute cost
+    fwd_compute_cost = 0
+
+    # if we need to broadcast the condition tensor, during backward we need to do a reduce_sum
+    bwd_compute_cost = 0
+    if x_tensor.shape != output_tensor.shape:
+        bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [x_tensor])
+    if y_tensor.shape != output_tensor.shape:
+        bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [y_tensor])
+
+    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+    # memory cost
+    # during the forward phase, torch.where will allocate memory for output tensor and condition tensor
+    # during the backward phase, torch.where will allocate temp memory which is 3 times as output tensor, then generate
+    # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
+    # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
+    fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
+    bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+                              parameter=0,
+                              temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
+                              activation_size([x_tensor, y_tensor]),
+                              buffer=0)
+
+    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+
+    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+
+    # store fwd_in, fwd_buffer, fwd_out
+    fwd_in = [condition_tensor]
+    fwd_buffer = []
+    fwd_out = [output_tensor]
+
+    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py
new file mode 100644
index 000000000..20156f9ab
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py
@@ -0,0 +1,104 @@
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    MemoryCost,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    StrategiesVector,
+    TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
+
+if torch.__version__ >= '1.12.0':
+    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+def test_where_meta_info():
+    meta_func = meta_register.get(torch.where)
+
+    # construct meta tensors
+    condition_tensor = torch.rand(1, 1, 1024, 1024) > 0.5
+    condition_tensor = condition_tensor.to(device="meta")
+    x_tensor = torch.rand(8, 16, 1024, 1024, device="meta")
+    y_tensor = torch.tensor(0, device="meta")
+    output_tensor = torch.rand(8, 16, 1024, 1024)
+
+    # construct operation data
+    condition_data = OperationData(
+        name="condition",
+        data=condition_tensor,
+        type=OperationDataType.ARG,
+        logical_shape=condition_tensor.shape,
+    )
+    x_data = OperationData(
+        name="x",
+        data=x_tensor,
+        type=OperationDataType.ARG,
+        logical_shape=x_tensor.shape,
+    )
+    y_data = OperationData(
+        name="y",
+        data=y_tensor,
+        type=OperationDataType.ARG,
+        logical_shape=y_tensor.shape,
+    )
+    output_data = OperationData(
+        name="output",
+        data=output_tensor,
+        type=OperationDataType.OUTPUT,
+        logical_shape=output_tensor.shape,
+    )
+
+    # construct args and kwargs
+    args = [condition_data, x_data, y_data, output_data]
+    kwargs = {'inplace': False}
+
+    # estimated results
+    compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+    # actual results
+    condition_real_tensor = torch.rand(1, 1, 1024, 1024) > 0.5
+    condition_real_tensor = condition_real_tensor.to(device="cuda")
+    x_real_tensor = torch.rand(8, 16, 1024, 1024, device="cuda")
+    y_real_tensor = torch.tensor(0.0, device="cuda")
+
+    x_real_tensor.requires_grad = True
+    y_real_tensor.requires_grad = True
+
+    # fwd
+    torch.cuda.reset_peak_memory_stats()
+    mem_stamp0 = torch.cuda.memory_allocated()
+    output_real_tensor = torch.where(condition_real_tensor, x_real_tensor, y_real_tensor)
+    fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+    fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+    # bwd
+    upstream_grad = torch.rand_like(output_real_tensor)
+    torch.cuda.reset_peak_memory_stats()
+    mem_stamp0 = torch.cuda.memory_allocated()
+    torch.autograd.backward(output_real_tensor, upstream_grad)
+    bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+    bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+    compute_cost: TrainCycleItem
+    memory_cost: TrainCycleItem
+
+    print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost,
+                  memory_cost, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak)
+
+
+if __name__ == '__main__':
+    test_where_meta_info()
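The reduce_sum charged to bwd_compute_cost in where_meta_info can be checked against eager PyTorch: when an operand of torch.where is broadcast to the output shape, autograd sums the upstream gradient back down to that operand's original shape. A minimal standalone sketch (plain PyTorch with small illustrative shapes, independent of the meta registry):

import torch

# x is broadcast along dim 0 and y is a scalar; both require gradients
condition = torch.rand(4, 8) > 0.5
x = torch.rand(1, 8, requires_grad=True)
y = torch.tensor(0.0, requires_grad=True)

out = torch.where(condition, x, y)    # broadcast to shape (4, 8)
out.backward(torch.ones_like(out))

# autograd reduces the (4, 8) upstream gradient back to each operand's shape,
# which is the sum accounted for by flop_mapping[torch.ops.aten.sum.dim_IntList]
assert x.grad.shape == x.shape        # (1, 8): summed over the broadcast dim
assert y.grad.shape == y.shape        # scalar: summed over every element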
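For the shapes used in test_where_meta_info, the memory formulas in where_meta_info evaluate as follows. This standalone sketch only mirrors that arithmetic and assumes activation_size amounts to numel() * element_size() per tensor (an assumption for illustration, not the colossalai implementation itself):

import torch

def nbytes(t: torch.Tensor) -> int:
    # assumed stand-in for activation_size on a single tensor
    return t.numel() * t.element_size()

condition = torch.empty(1, 1, 1024, 1024, dtype=torch.bool, device="meta")   # 1 MiB
x = torch.empty(8, 16, 1024, 1024, device="meta")                            # 512 MiB
y = torch.empty((), device="meta")                                           # 4 B
output = torch.empty(8, 16, 1024, 1024, device="meta")                       # 512 MiB

# forward: condition, x, y and the freshly allocated output count as activation
fwd_activation = nbytes(condition) + nbytes(x) + nbytes(y) + nbytes(output)

# backward: gradients for x and y remain, the condition tensor is freed,
# and roughly 3x the output size is charged as temporary workspace
bwd_activation = nbytes(x) + nbytes(y) - nbytes(condition)
bwd_temp = 3 * nbytes(output) + nbytes(condition) - (nbytes(x) + nbytes(y))

print(f"fwd activation: {fwd_activation / 2**20:.1f} MiB")   # ~1025 MiB
print(f"bwd activation: {bwd_activation / 2**20:.1f} MiB")   # ~511 MiB
print(f"bwd temp:       {bwd_temp / 2**20:.1f} MiB")         # ~1025 MiB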