[autoparallel] integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer at module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
Author: YuliangLiu0306
Date: 2023-04-04 17:40:45 +08:00 (committed by GitHub)
Parent: 573af84184
Commit: ffcdbf0f65
46 changed files with 396 additions and 470 deletions


@@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union
 import torch
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec
 from ..registry import meta_register
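
The two size helpers are swapped one-for-one in the hunks below, so compute_size_in_bytes from the new analyzer is treated as a drop-in replacement for the old activation_size: the total byte footprint of a tensor or a collection of tensors. A minimal sketch of the assumed semantics (an illustration, not the library's actual code):

    import torch

    def size_in_bytes(x):
        # assumed behaviour: numel * element_size, summed over collections
        if isinstance(x, torch.Tensor):
            return x.numel() * x.element_size()
        return sum(size_in_bytes(t) for t in x)

    assert size_in_bytes(torch.empty(4, 8)) == 32 * 4                      # fp32: 4 B/element
    assert size_in_bytes([torch.empty(4, 8), torch.empty(16)]) == 48 * 4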
@@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)

     # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=activation_size([weight_tensor, bias_tensor]),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
                                  temp=0,
                                  buffer=0)
@@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # NOTE: Linear don't have buffer and temp in forward and backward phase
     # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)

     # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
-    bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
-                                 parameter=activation_size(weight_tensor),
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes(weight_tensor),
                                  temp=0,
                                  buffer=0)
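
To make the cost model above concrete, a worked example with hypothetical shapes (batch size 4, Linear(8 -> 16), fp32), following the byte-size semantics sketched earlier:

    # input: (4, 8) = 32 elements, output: (4, 16) = 64 elements
    # weight: (16, 8) = 128 elements, bias: (16,) = 16 elements, 4 B/element
    fwd_activation = (32 + 64) * 4          # 384 B: input + output
    fwd_parameter = (128 + 16) * 4          # 576 B: weight + bias
    bwd_activation = (32 + 128 + 16) * 4    # 704 B: grads for input, weight, bias
    bwd_parameter = (128 + 16) * 4          # 576 B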
@@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # Check dimension
     if all(len(tensor.shape) == 1 for tensor in input_tensors):
         # Dot
-        fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
         # gemv case 1: matrix-vector multiplication
         # &
         # batched gemv case 1: batched matrix-vector multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)

         # combine the dimensions of output
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
             [output_tensors[0].reshape(-1), input_tensors[1]],
             output_tensors) + \
-                           flop_mapping[torch.ops.aten.mv.default](
+                           flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
             output_tensors)

-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
         # gemv case 2: vector-matrix multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
         bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
-                           flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
+                           flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)

     elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
         # batched gemv case 2: vector-batched matrix multiplication
-        fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+        fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
             [output_tensors[0].reshape(-1)])
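
The transpose-and-reshape above flattens the batched operand to 2-D so that this case, too, can be priced by the generic matmul handler. A self-contained illustration with hypothetical shapes (vector of length 8, batched matrix of shape (5, 8, 3)):

    import torch

    x = torch.empty(8)                                    # input_tensors[0]
    w = torch.empty(5, 8, 3)                              # input_tensors[1]
    flat = w.transpose(-2, -1).reshape(-1, w.shape[-2])   # (5 * 3, 8)
    out = torch.matmul(x, w)                              # (5, 3)
    assert flat.shape == (15, 8)
    assert out.reshape(-1).shape == (15,)                 # matches output_tensors[0].reshape(-1)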
@@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [output_tensors[0].reshape(-1), input_tensors[0]],
             output_tensors
         ) + \
-                           flop_mapping[torch.ops.aten.mv.default](
+                           flop_mapping[torch.ops.aten.matmul.default](
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
             output_tensors
         )
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]),
+                                  temp=compute_size_in_bytes(input_tensors[1]),
                                   buffer=0)

     elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
@@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
         )
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)

     elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
         # batched gemm case 2: matrix-batched matrix multiplication
@@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
         )
-        fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
-                                  temp=activation_size(output_tensors))
-        bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+        fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
+                                  compute_size_in_bytes(input_tensors[1]),
+                                  temp=compute_size_in_bytes(output_tensors))
+        bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
                                   parameter=0,
-                                  temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
+                                  temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))

     elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
         # Batched matrix-batched matrix multiplication
@@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
                 [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
             )
-            fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
-            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
+            fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
+            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))

         else:
             # Case 2: batch dimensions are different
@@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
             )
             fwd_mem_cost = MemoryCost(
-                activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
-            bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
-                                      activation_size([extended_input_0, extended_input_1]),
-                                      temp=activation_size([extended_input_0, extended_input_1]))
+                activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
+            bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
+                                      compute_size_in_bytes([extended_input_0, extended_input_1]),
+                                      temp=compute_size_in_bytes([extended_input_0, extended_input_1]))

     # compute cost
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
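
With dot.default and mv.default gone, every branch of matmul_meta_info now keys flop_mapping on torch.ops.aten.matmul.default, so a single handler prices all cases once operands have been reshaped to 2-D. A sketch of a shape-based handler with that contract (an assumption for illustration, not the analyzer's actual implementation), shown for the plain 2-D gemm case:

    import torch

    def matmul_flops(inputs, outputs):
        # one multiply-accumulate per output element per reduced dimension
        (m, k), (k2, n) = inputs[0].shape, inputs[1].shape
        assert k == k2, 'inner dimensions must match'
        return m * k * n

    a, b = torch.empty(4, 8), torch.empty(8, 16)
    out = torch.empty(4, 16)
    assert matmul_flops([a, b], [out]) == 4 * 8 * 16    # 512 MACs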