[autoparallel]integrate auto parallel feature with new tracer (#3408)

* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
This commit is contained in:
YuliangLiu0306
2023-04-04 17:40:45 +08:00
committed by GitHub
parent 573af84184
commit ffcdbf0f65
46 changed files with 396 additions and 470 deletions

View File

@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import elementwise_flop_counter
from ..registry import meta_register

View File

@@ -2,9 +2,9 @@ from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
"""Meta information generator for binary elementwise operations
NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
this behavior, it is critical for better memory estimation.
Returns:

View File

@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=activation_size([input_tensor, weight_tensor, bias_tensor])
if has_bias else activation_size([input_tensor, weight_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,

View File

@@ -2,9 +2,9 @@ from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
@@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
# have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)

View File

@@ -3,6 +3,8 @@ from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
@@ -11,8 +13,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
@@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
@@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
parameter=activation_size(weight_tensor),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]),
parameter=activation_size(weight_tensor),
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
@@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Check dimension
if all(len(tensor.shape) == 1 for tensor in input_tensors):
# Dot
fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
# gemv case 1: matrix-vector multiplication
# &
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.mv.default](
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=activation_size(input_tensors[1]),
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
@@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.mv.default](
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=activation_size(input_tensors[1]),
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
@@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
@@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
temp=activation_size(output_tensors))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
@@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
)
fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
else:
# Case 2: batch dimensions are different
@@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
)
fwd_mem_cost = MemoryCost(
activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
activation_size([extended_input_0, extended_input_1]),
temp=activation_size([extended_input_0, extended_input_1]))
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

View File

@@ -4,8 +4,6 @@ from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register

View File

@@ -2,6 +2,8 @@ from typing import Callable, Dict, List, Tuple, Union
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
@@ -10,8 +12,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
TrainCycleItem,
)
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register
@@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, mean_tensor, var_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=activation_size([mean_tensor, var_tensor]))
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=activation_size([mean_tensor, var_tensor]),
buffer=activation_size([mean_tensor, var_tensor]))
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([mean_tensor, var_tensor]),
buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
@@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes(
[input_tensor, output_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=activation_size([running_mean, running_var]))
buffer=compute_size_in_bytes([running_mean, running_var]))
bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
parameter=activation_size([weight_tensor, bias_tensor]),
temp=activation_size([running_mean, running_var]),
buffer=activation_size([running_mean, running_var]))
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=compute_size_in_bytes([running_mean, running_var]),
buffer=compute_size_in_bytes([running_mean, running_var]))
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,

View File

@@ -2,9 +2,9 @@ from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
@@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor))
bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor))
fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor))
bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation)
@@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# calculate memory cost
# NOTE: the index matrix will be discarded in backward phase
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix]))
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix),
temp=activation_size(index_matrix))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
temp=compute_size_in_bytes(index_matrix))
# total cost
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)

View File

@@ -2,9 +2,9 @@ from typing import Callable, List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register
@@ -35,11 +35,11 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# memory costs
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
parameter=0,
temp=activation_size(outputs) * bwd_mem_tmp_factor,
temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
buffer=0)
total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,

View File

@@ -2,9 +2,9 @@ from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping
from ..registry import meta_register