[autoparallel] Patch tensor related operations meta information (#2789)

* [autoparallel] tensor related meta information prototype

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information

* [autoparallel] tensor related meta information
This commit is contained in:
Boyuan Yao 2023-02-20 17:38:55 +08:00 committed by GitHub
parent a5721229d9
commit 7ea6bc7f69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 183 additions and 1 deletions

View File

@ -5,3 +5,4 @@ from .embedding import *
from .linear import *
from .norm import *
from .pooling import *
from .tensor import *

View File

@ -14,7 +14,6 @@ __all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
@meta_register.register(torch.flatten)
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for AdaptiveAvgPool
The aten graph of AdaptiveAvgPool is

View File

@ -0,0 +1,79 @@
from typing import Callable, List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
__all__ = ["tensor_related_metainfo"]
def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: float = 0) -> Callable:
"""torch.Tensor related metainfo generator template
Args:
bwd_mem_out_factor (float, optional): backward activation memory cost factor. Defaults to 1.
bwd_mem_tmp_factor (float, optional): backward temp memory cost factor. Defaults to 0.
Returns:
Callable: torch.Tensor related metainfo generator
"""
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.Tensor related metainfo generator
Returns:
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
"""
outputs = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
# compute costs are all zero
compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
# memory costs
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
parameter=0,
temp=activation_size(outputs) * bwd_mem_tmp_factor,
buffer=0)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
if isinstance(outputs, tuple) or isinstance(outputs, list) or isinstance(outputs, dict):
# tuple of tensors
fwd_out = [torch.zeros_like(tensor) for tensor in outputs]
else:
# enaged_tensors is a single tensor
fwd_out = [torch.zeros_like(outputs)]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
return meta_func
# register torch.Tensor related metainfo
# (0, 0)
meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
torch.arange])(tensor_related_metainfo(0, 0))
# (1, 0)
meta_register.register([
torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
torch.Tensor.split, torch.split, torch.Tensor.view
])(tensor_related_metainfo(1, 0))
# (1, 1)
meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))

View File

@ -0,0 +1,103 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
class SplitModule(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x.split(512, dim=0)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
def test_tensor_meta_info():
"""test tensor related meta information
We will just use torch.Tensor.split for the test
"""
meta_func = meta_register.get(torch.Tensor.split)
# construct meta tensors
input_tensor = torch.rand(1024, 1024, device="meta")
output_tensor = input_tensor.split(512, dim=0)
# construct operation data
input_data = OperationData(
name="input",
data=input_tensor,
type=OperationDataType.ARG,
logical_shape=input_tensor.shape,
)
output_data = OperationData(
name="output",
data=output_tensor,
type=OperationDataType.OUTPUT,
logical_shape=input_tensor.shape,
)
split_info_data = OperationData(
name='split_info',
type=OperationDataType.ARG,
data=0,
logical_shape=None,
)
# construct args
args = [input_data, output_data, split_info_data]
kwargs = {'inplace': False}
# estimated results
compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
# actual results
model = SplitModule()
input_real_tensor = torch.rand(1024, 1024).cuda()
input_real_tensor.requires_grad = True
# fwd
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
output_real_tensor = model(input_real_tensor)
fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
# bwd
upstream_grad = [torch.rand_like(tensor) for tensor in output_real_tensor]
torch.cuda.reset_peak_memory_stats()
mem_stamp0 = torch.cuda.memory_allocated()
torch.autograd.backward(output_real_tensor, upstream_grad)
bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak,
bwd_allocated, bwd_peak)
if __name__ == "__main__":
test_tensor_meta_info()