# colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
from typing import List, Tuple

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..registry import meta_register

__all__ = ["embedding_meta_info"]


@meta_register.register(torch.nn.Embedding)
def embedding_meta_info(
    *args, **kwargs
) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """torch.nn.Embedding metainfo generator

    Args:
        *args: operation-data objects exposing ``.type`` (an ``OperationDataType``)
            and ``.data`` (a tensor). The first entry of each of the types
            ``ARG`` (input), ``PARAM`` (weight) and ``OUTPUT`` (output) is used.
        **kwargs: accepted for interface compatibility; not read by this function.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        compute cost, memory cost, forward inputs, forward buffers and forward outputs

    Raises:
        StopIteration: if ``args`` contains no entry of one of the three required types.
    """
    input_tensor = next(item for item in args if item.type == OperationDataType.ARG).data
    weight_tensor = next(item for item in args if item.type == OperationDataType.PARAM).data
    output_tensor = next(item for item in args if item.type == OperationDataType.OUTPUT).data

    # compute cost
    fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
    bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
                                                                                     [weight_tensor])

    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # memory cost
    # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
    # NOTE: during the backward phase of torch.nn.Embedding, it seems that when the input is large enough there
    # is some temp memory. The reason is not known yet, so for now we assume there is no temp memory, as the
    # temp memory is significantly smaller than the gradient memory
    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
                                 parameter=0,
                                 temp=0,
                                 buffer=0)
    # backward is sized by the weight tensor (presumably its gradient -- see the NOTE above)
    bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)

    # only the activation term is accumulated; the remaining MemoryCost fields keep their defaults
    total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)

    memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)

    # store fwd_in, fwd_buffer, fwd_out
    fwd_in = [torch.zeros_like(input_tensor)]
    fwd_buffer = []
    fwd_out = [torch.zeros_like(output_tensor)]

    return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
# tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    MemoryCost,
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results


def _torch_at_least(major: int, minor: int) -> bool:
    """Return True when the installed torch release is at least ``major.minor``.

    The version is compared numerically: a plain string comparison orders
    '1.9.0' after '1.12.0', which would defeat the version guard below.
    """
    # drop any local build suffix such as '+cu117' before splitting the release fields
    release_fields = torch.__version__.split('+')[0].split('.')
    return (int(release_fields[0]), int(release_fields[1])) >= (major, minor)


_TORCH_GE_1_12 = _torch_at_least(1, 12)

if _TORCH_GE_1_12:
    from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register


@pytest.mark.skipif(not _TORCH_GE_1_12, reason="need pytorch 1.12.0 or higher for aten level operations")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="need a CUDA device to measure real memory usage")
def test_embedding_meta_info():
    """Print estimated vs. measured cost of ``torch.nn.Embedding``.

    The estimate comes from the meta function registered for ``torch.nn.Embedding``,
    evaluated on meta tensors. The measurement runs a real forward and backward
    pass on CUDA and records allocated / peak memory deltas. Both are handed to
    ``print_results``; this test makes no assertions on the numbers.
    """
    meta_func = meta_register.get(torch.nn.Embedding)

    # construct meta tensors
    input_tensor = torch.randint(0, 50256, (8, 1024), device="meta")
    weight_tensor = torch.rand(50257, 1024, device="meta")
    output_tensor = torch.rand(8, 1024, 1024, device="meta")

    # construct operation data
    input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor)
    weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor)
    output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor)

    # construct args and kwargs
    args = [input_data, weight_data, output_data]
    kwargs = {'inplace': False}

    # estimated results
    compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)

    # actual results
    input_real_tensor = torch.randint(0, 50256, (8, 1024), device="cuda")
    embedding_module = torch.nn.Embedding(50257, 1024).cuda()

    # fwd: measure memory relative to the stamp taken just before the forward call
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    output_real_tensor = embedding_module(input_real_tensor)
    fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    # bwd: the upstream gradient is created before the stamp so it is not counted
    upstream_grad = torch.rand_like(output_real_tensor)
    torch.cuda.reset_peak_memory_stats()
    mem_stamp0 = torch.cuda.memory_allocated()
    torch.autograd.backward(output_real_tensor, upstream_grad)
    bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
    bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0

    print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
                  bwd_allocated, bwd_peak)


if __name__ == '__main__':
    test_embedding_meta_info()